Add stdx utilities and rework async signature inference; tidy executable logging.
This commit is contained in:
parent
c30aa018dc
commit
9b7eea8ac2
1172
MODULE.bazel.lock
1172
MODULE.bazel.lock
File diff suppressed because it is too large
Load Diff
@ -1,26 +1,86 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
|
pub fn ArgsTuple(comptime funcT: anytype, comptime argsT: ?type) type {
|
||||||
|
const params = @typeInfo(funcT).Fn.params;
|
||||||
|
if (params.len == 0) {
|
||||||
|
return @TypeOf(.{});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (@typeInfo(funcT).Fn.is_generic == false) {
|
||||||
|
return std.meta.ArgsTuple(funcT);
|
||||||
|
}
|
||||||
|
|
||||||
|
const args = std.meta.fields(argsT orelse @compileError("generic function requires an explicit ArgsTuple"));
|
||||||
|
var tuple_fields: [params.len]std.builtin.Type.StructField = undefined;
|
||||||
|
inline for (params, args, 0..) |param, arg, i| {
|
||||||
|
if (param.type == null) {
|
||||||
|
tuple_fields[i] = arg;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const T = param.type.?;
|
||||||
|
var num_buf: [32]u8 = undefined;
|
||||||
|
tuple_fields[i] = .{
|
||||||
|
.name = blk: {
|
||||||
|
const s = std.fmt.formatIntBuf(&num_buf, i, 10, .lower, .{});
|
||||||
|
num_buf[s] = 0;
|
||||||
|
break :blk num_buf[0..s :0];
|
||||||
|
},
|
||||||
|
.type = T,
|
||||||
|
.default_value = null,
|
||||||
|
.is_comptime = false,
|
||||||
|
.alignment = if (@sizeOf(T) > 0) @alignOf(T) else 0,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return @Type(.{
|
||||||
|
.Struct = .{
|
||||||
|
.is_tuple = true,
|
||||||
|
.layout = .auto,
|
||||||
|
.decls = &.{},
|
||||||
|
.fields = &tuple_fields,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn TupleRange(comptime T: type, comptime start: usize, comptime end: usize) type {
|
||||||
|
const fields = std.meta.fields(T);
|
||||||
|
var new_fields: [end - start]std.builtin.Type.StructField = undefined;
|
||||||
|
inline for (start..end, 0..) |i, j| {
|
||||||
|
var new_field = fields[i];
|
||||||
|
var num_buf: [32]u8 = undefined;
|
||||||
|
new_field.name = blk: {
|
||||||
|
const s = std.fmt.formatIntBuf(&num_buf, j, 10, .lower, .{});
|
||||||
|
num_buf[s] = 0;
|
||||||
|
break :blk num_buf[0..s :0];
|
||||||
|
};
|
||||||
|
new_fields[j] = new_field;
|
||||||
|
}
|
||||||
|
return @Type(.{
|
||||||
|
.Struct = .{
|
||||||
|
.is_tuple = true,
|
||||||
|
.layout = .auto,
|
||||||
|
.decls = &.{},
|
||||||
|
.fields = &new_fields,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
pub fn FnSignature(comptime func: anytype, comptime argsT: ?type) type {
|
pub fn FnSignature(comptime func: anytype, comptime argsT: ?type) type {
|
||||||
|
return FnSignatureX(func, ArgsTuple(@TypeOf(func), argsT));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn FnSignatureX(comptime func: anytype, comptime argsT: type) type {
|
||||||
return struct {
|
return struct {
|
||||||
pub const FuncT = if (@TypeOf(func) == type) func else @TypeOf(func);
|
pub const FuncT = @TypeOf(func);
|
||||||
pub const ArgsT = blk: {
|
pub const ArgsT = argsT;
|
||||||
if (@typeInfo(FuncT).Fn.params.len == 0) {
|
|
||||||
break :blk @TypeOf(.{});
|
|
||||||
}
|
|
||||||
break :blk argsT orelse std.meta.ArgsTuple(FuncT);
|
|
||||||
};
|
|
||||||
pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined)));
|
pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined)));
|
||||||
pub const ReturnPayloadT = blk: {
|
pub const ReturnPayloadT = switch (@typeInfo(ReturnT)) {
|
||||||
break :blk switch (@typeInfo(ReturnT)) {
|
.ErrorUnion => |u| u.payload,
|
||||||
.ErrorUnion => |u| u.payload,
|
else => ReturnT,
|
||||||
else => ReturnT,
|
|
||||||
};
|
|
||||||
};
|
};
|
||||||
pub const ReturnErrorSet: ?type = blk: {
|
pub const ReturnErrorSet: ?type = switch (@typeInfo(ReturnT)) {
|
||||||
break :blk switch (@typeInfo(ReturnT)) {
|
.ErrorUnion => |u| u.error_set,
|
||||||
.ErrorUnion => |u| u.error_set,
|
else => null,
|
||||||
else => null,
|
|
||||||
};
|
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,16 +2,22 @@ const std = @import("std");
|
|||||||
const xev = @import("xev");
|
const xev = @import("xev");
|
||||||
|
|
||||||
const FnSignature = @import("meta.zig").FnSignature;
|
const FnSignature = @import("meta.zig").FnSignature;
|
||||||
|
const NormalizedTuple = @import("meta.zig").NormalizedTuple;
|
||||||
|
|
||||||
pub fn Frame(comptime func: anytype) type {
|
pub fn Frame(comptime func: anytype) type {
|
||||||
const Signature = FnSignature(func, null);
|
const Signature = FnSignature(func, null);
|
||||||
return FrameEx(func, Signature.ArgsT);
|
return FrameExx(func, Signature);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn FrameEx(comptime func: anytype, comptime argsT: type) type {
|
pub fn FrameEx(comptime func: anytype, comptime argsT: type) type {
|
||||||
|
const Signature = FnSignature(func, argsT);
|
||||||
|
return FrameExx(func, Signature);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn FrameExx(comptime func: anytype, comptime Signature: type) type {
|
||||||
return struct {
|
return struct {
|
||||||
const Self = @This();
|
const Self = @This();
|
||||||
const Signature = FnSignature(func, argsT);
|
const Signature_ = Signature;
|
||||||
const Task = struct {
|
const Task = struct {
|
||||||
_task: xev.ThreadPool.Task = .{ .callback = &Self.run },
|
_task: xev.ThreadPool.Task = .{ .callback = &Self.run },
|
||||||
event: std.Thread.ResetEvent = .{},
|
event: std.Thread.ResetEvent = .{},
|
||||||
@ -27,7 +33,8 @@ pub fn FrameEx(comptime func: anytype, comptime argsT: type) type {
|
|||||||
task.event.set();
|
task.event.set();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn await_(self: *Self) Signature.ReturnT {
|
pub const await_ = wait;
|
||||||
|
pub fn wait(self: *Self) Signature.ReturnT {
|
||||||
defer {
|
defer {
|
||||||
AsyncThread.current.mutex.lock();
|
AsyncThread.current.mutex.lock();
|
||||||
AsyncThread.current.allocator.destroy(self._task);
|
AsyncThread.current.allocator.destroy(self._task);
|
||||||
@ -39,11 +46,7 @@ pub fn FrameEx(comptime func: anytype, comptime argsT: type) type {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn asyncc(comptime func: anytype, args: FnSignature(func, null).ArgsT) !FrameEx(func, @TypeOf(args)) {
|
pub fn asyncc(comptime func: anytype, args: anytype) !FrameEx(func, @TypeOf(args)) {
|
||||||
return asyncGeneric(func, args);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn asyncGeneric(comptime func: anytype, args: anytype) !FrameEx(func, @TypeOf(args)) {
|
|
||||||
const FrameT = FrameEx(func, @TypeOf(args));
|
const FrameT = FrameEx(func, @TypeOf(args));
|
||||||
|
|
||||||
AsyncThread.current.mutex.lock();
|
AsyncThread.current.mutex.lock();
|
||||||
@ -58,15 +61,11 @@ pub fn asyncGeneric(comptime func: anytype, args: anytype) !FrameEx(func, @TypeO
|
|||||||
return .{ ._task = task };
|
return .{ ._task = task };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn callBlocking(comptime func: anytype, args: FnSignature(func, null).ArgsT) @TypeOf(callBlockingGeneric(func, args)) {
|
pub inline fn callBlocking(comptime func: anytype, args: anytype) FnSignature(func, @TypeOf(args)).ReturnT {
|
||||||
return callBlockingGeneric(func, args);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn callBlockingGeneric(comptime func: anytype, args: anytype) FnSignature(func, @TypeOf(args)).ReturnT {
|
|
||||||
return @call(.auto, func, args);
|
return @call(.auto, func, args);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn sleep(ms: u64) !void {
|
pub inline fn sleep(ms: u64) !void {
|
||||||
std.time.sleep(ms * std.time.ns_per_ms);
|
std.time.sleep(ms * std.time.ns_per_ms);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -77,7 +76,7 @@ pub const AsyncThread = struct {
|
|||||||
thread_pool: xev.ThreadPool,
|
thread_pool: xev.ThreadPool,
|
||||||
mutex: std.Thread.Mutex,
|
mutex: std.Thread.Mutex,
|
||||||
|
|
||||||
pub fn main(allocator_: std.mem.Allocator, comptime func: anytype, args: anytype) !void {
|
pub fn main(allocator_: std.mem.Allocator, comptime mainFunc: anytype) !void {
|
||||||
current = .{
|
current = .{
|
||||||
.allocator = allocator_,
|
.allocator = allocator_,
|
||||||
.thread_pool = xev.ThreadPool.init(.{}),
|
.thread_pool = xev.ThreadPool.init(.{}),
|
||||||
@ -89,7 +88,7 @@ pub const AsyncThread = struct {
|
|||||||
current.thread_pool.deinit();
|
current.thread_pool.deinit();
|
||||||
}
|
}
|
||||||
|
|
||||||
return @call(.auto, func, args);
|
return try mainFunc();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -114,15 +113,15 @@ pub const Notification = struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn StdIn() !File {
|
pub fn getStdIn() !File {
|
||||||
return File.init(std.io.getStdIn()) catch @panic("Unable to open stdin");
|
return File.init(std.io.getStdIn()) catch @panic("Unable to open stdin");
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn StdOut() File {
|
pub fn getStdOut() File {
|
||||||
return File.init(std.io.getStdOut()) catch @panic("Unable to open stdout");
|
return File.init(std.io.getStdOut()) catch @panic("Unable to open stdout");
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn StdErr() File {
|
pub fn getStdErr() File {
|
||||||
return File.init(std.io.getStdErr()) catch @panic("Unable to open stderr");
|
return File.init(std.io.getStdErr()) catch @panic("Unable to open stderr");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -217,3 +216,23 @@ pub const File = struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
pub const Mutex = std.Thread.Mutex;
|
pub const Mutex = std.Thread.Mutex;
|
||||||
|
|
||||||
|
pub fn logFn(
|
||||||
|
comptime message_level: std.log.Level,
|
||||||
|
comptime scope: @Type(.EnumLiteral),
|
||||||
|
comptime format: []const u8,
|
||||||
|
args: anytype,
|
||||||
|
) void {
|
||||||
|
const level_txt = comptime message_level.asText();
|
||||||
|
const prefix2 = if (scope == .default) ": " else "(" ++ @tagName(scope) ++ "): ";
|
||||||
|
const stderr = getStdErr().writer();
|
||||||
|
var bw = std.io.bufferedWriter(stderr);
|
||||||
|
const writer = bw.writer();
|
||||||
|
|
||||||
|
std.debug.lockStdErr();
|
||||||
|
defer std.debug.unlockStdErr();
|
||||||
|
nosuspend {
|
||||||
|
writer.print(level_txt ++ prefix2 ++ format ++ "\n", args) catch return;
|
||||||
|
bw.flush() catch return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -504,13 +504,14 @@ pub const LoadedExecutable = opaque {
|
|||||||
return @ptrCast(ret.addressable_devices);
|
return @ptrCast(ret.addressable_devices);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: struct {
|
pub const ExecuteArgs = struct {
|
||||||
num_args: usize,
|
num_args: usize,
|
||||||
arguments: []const [*]const *const Buffer,
|
arguments: []const [*]const *const Buffer,
|
||||||
results: []const [*]*Buffer,
|
results: []const [*]*Buffer,
|
||||||
events: []?*Event,
|
events: []?*Event,
|
||||||
non_donatable_input_indices: []const i64 = &.{},
|
non_donatable_input_indices: []const i64 = &.{},
|
||||||
}) ApiError!void {
|
};
|
||||||
|
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ApiError!void {
|
||||||
var options = pjrtStruct(c.PJRT_ExecuteOptions{
|
var options = pjrtStruct(c.PJRT_ExecuteOptions{
|
||||||
.send_callbacks = null,
|
.send_callbacks = null,
|
||||||
.recv_callbacks = null,
|
.recv_callbacks = null,
|
||||||
|
|||||||
@ -2,7 +2,7 @@ const std = @import("std");
|
|||||||
const c = @import("c");
|
const c = @import("c");
|
||||||
const tsl_proto = @import("//tsl:profiler_options_proto");
|
const tsl_proto = @import("//tsl:profiler_options_proto");
|
||||||
|
|
||||||
const log = std.log.scoped(.zml_profiler);
|
const log = std.log.scoped(.@"zml/profiler");
|
||||||
|
|
||||||
/// Pjrt Profiler extension
|
/// Pjrt Profiler extension
|
||||||
pub const Profiler = struct {
|
pub const Profiler = struct {
|
||||||
|
|||||||
13
stdx/BUILD.bazel
Normal file
13
stdx/BUILD.bazel
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||||
|
|
||||||
|
zig_library(
|
||||||
|
name = "stdx",
|
||||||
|
srcs = [
|
||||||
|
"debug.zig",
|
||||||
|
"math.zig",
|
||||||
|
"meta.zig",
|
||||||
|
"signature.zig",
|
||||||
|
],
|
||||||
|
main = "stdx.zig",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
33
stdx/debug.zig
Normal file
33
stdx/debug.zig
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
|
||||||
|
pub inline fn guard(check: bool, src: std.builtin.SourceLocation) void {
|
||||||
|
assert(check, "Invalid inputs {s}@{s}:{d}", .{ src.file, src.fn_name, src.line });
|
||||||
|
}
|
||||||
|
|
||||||
|
pub inline fn internalAssert(check: bool, comptime msg: []const u8, args: anytype) void {
|
||||||
|
assert(check, "internal error: " ++ msg, args);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub inline fn assert(check: bool, comptime msg: []const u8, args: anytype) void {
|
||||||
|
if (!check) {
|
||||||
|
panic(msg, args);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub inline fn panic(comptime format: []const u8, args: anytype) noreturn {
|
||||||
|
std.debug.panic(format, args);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub inline fn compileLog(comptime msg: []const u8, comptime args: anytype) void {
|
||||||
|
@compileLog(std.fmt.comptimePrint(msg, args));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub inline fn compileError(comptime msg: []const u8, comptime args: anytype) noreturn {
|
||||||
|
@compileError(std.fmt.comptimePrint(msg, args));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub inline fn assertComptime(comptime check: bool, comptime msg: []const u8, comptime args: anytype) void {
|
||||||
|
if (check == false) {
|
||||||
|
compileError(msg, args);
|
||||||
|
}
|
||||||
|
}
|
||||||
25
stdx/math.zig
Normal file
25
stdx/math.zig
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
pub inline fn divFloor(comptime T: type, numerator: anytype, denominator: anytype) T {
|
||||||
|
return @divFloor(floatCast(T, numerator), floatCast(T, denominator));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub inline fn divExact(comptime T: type, numerator: anytype, denominator: anytype) T {
|
||||||
|
return @divExact(floatCast(T, numerator), floatCast(T, denominator));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub inline fn divTrunc(comptime T: type, numerator: anytype, denominator: anytype) T {
|
||||||
|
return @divTrunc(floatCast(T, numerator), floatCast(T, denominator));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub inline fn floatCast(comptime T: type, x: anytype) T {
|
||||||
|
return switch (@typeInfo(@TypeOf(x))) {
|
||||||
|
.Float => @floatCast(x),
|
||||||
|
else => @floatFromInt(x),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub inline fn intCast(comptime T: type, x: anytype) T {
|
||||||
|
return switch (@typeInfo(@TypeOf(x))) {
|
||||||
|
.Int => @intCast(x),
|
||||||
|
else => @intFromFloat(x),
|
||||||
|
};
|
||||||
|
}
|
||||||
158
stdx/meta.zig
Normal file
158
stdx/meta.zig
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
const debug = @import("debug.zig");
|
||||||
|
|
||||||
|
const compileError = debug.compileError;
|
||||||
|
|
||||||
|
pub const FnSignature = @import("signature.zig").FnSignature;
|
||||||
|
|
||||||
|
pub fn isStruct(comptime T: type) bool {
|
||||||
|
return switch (@typeInfo(T)) {
|
||||||
|
.Struct => true,
|
||||||
|
else => false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn isTuple(comptime T: type) bool {
|
||||||
|
return switch (@typeInfo(T)) {
|
||||||
|
.Struct => |info| info.is_tuple,
|
||||||
|
else => false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn isStructOf(comptime T: type, comptime Elem: type) bool {
|
||||||
|
return switch (@typeInfo(T)) {
|
||||||
|
.Struct => |info| blk: {
|
||||||
|
inline for (info.fields) |field| {
|
||||||
|
if (field.type != Elem) {
|
||||||
|
break :blk false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break :blk true;
|
||||||
|
},
|
||||||
|
else => false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn isStructOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
|
||||||
|
return switch (@typeInfo(T)) {
|
||||||
|
.Struct => |info| blk: {
|
||||||
|
inline for (info.fields) |field| {
|
||||||
|
if (f(field.type) == false) {
|
||||||
|
break :blk false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break :blk true;
|
||||||
|
},
|
||||||
|
else => false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn isTupleOf(comptime T: type, comptime Elem: type) bool {
|
||||||
|
return isTuple(T) and isStructOf(T, Elem);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn isTupleOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
|
||||||
|
return isTuple(T) and isStructOfAny(T, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn isSliceOf(comptime T: type, comptime Elem: type) bool {
|
||||||
|
return switch (@typeInfo(T)) {
|
||||||
|
.Pointer => |info| switch (info.size) {
|
||||||
|
.Slice => info.child == Elem,
|
||||||
|
.One => switch (@typeInfo(info.child)) {
|
||||||
|
// As Zig, convert pointer to Array as a slice.
|
||||||
|
.Array => |arr_info| arr_info.child == Elem,
|
||||||
|
else => false,
|
||||||
|
},
|
||||||
|
else => false,
|
||||||
|
},
|
||||||
|
else => false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn isInteger(comptime T: type) bool {
|
||||||
|
return switch (@typeInfo(T)) {
|
||||||
|
.Int, .ComptimeInt => true,
|
||||||
|
else => false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn isSliceOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
|
||||||
|
return switch (@typeInfo(T)) {
|
||||||
|
.Pointer => |info| info.size == .Slice and f(info.child),
|
||||||
|
else => false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn DeclEnum(comptime T: type) type {
|
||||||
|
const field_infos = std.meta.declarations(T);
|
||||||
|
if (field_infos.len == 0) {
|
||||||
|
compileError("Struct {} has no declarations", .{T});
|
||||||
|
}
|
||||||
|
return std.meta.DeclEnum(UnwrapPtr(T));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn UnwrapPtr(comptime T: type) type {
|
||||||
|
return switch (@typeInfo(T)) {
|
||||||
|
.Pointer => |info| switch (info.size) {
|
||||||
|
.One => info.child,
|
||||||
|
else => T,
|
||||||
|
},
|
||||||
|
else => T,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn asSlice(comptime T: type) type {
|
||||||
|
const err_msg = "Type " ++ @typeName(T) ++ " can't be interpreted as a slice";
|
||||||
|
return switch (@typeInfo(T)) {
|
||||||
|
.Pointer => |info| switch (info.size) {
|
||||||
|
.Slice => info.child,
|
||||||
|
.One => switch (@typeInfo(info.child)) {
|
||||||
|
// As Zig, convert pointer to Array as a slice.
|
||||||
|
.Array => |arr_info| arr_info.child,
|
||||||
|
else => compileError(err_msg),
|
||||||
|
},
|
||||||
|
else => compileError(err_msg),
|
||||||
|
},
|
||||||
|
else => compileError(err_msg),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn TupleRange(comptime T: type, comptime start: ?usize, comptime end: ?usize) type {
|
||||||
|
return TupleRangeX(T, start orelse 0, end orelse std.meta.fields(T).len);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn TupleRangeX(comptime T: type, comptime start: usize, comptime end: usize) type {
|
||||||
|
const fields = std.meta.fields(T);
|
||||||
|
var new_fields: [end - start]std.builtin.Type.StructField = undefined;
|
||||||
|
inline for (start..end, 0..) |i, j| {
|
||||||
|
var new_field = fields[i];
|
||||||
|
var num_buf: [32]u8 = undefined;
|
||||||
|
new_field.name = blk: {
|
||||||
|
const s = std.fmt.formatIntBuf(&num_buf, j, 10, .lower, .{});
|
||||||
|
num_buf[s] = 0;
|
||||||
|
break :blk num_buf[0..s :0];
|
||||||
|
};
|
||||||
|
new_fields[j] = new_field;
|
||||||
|
}
|
||||||
|
return @Type(.{
|
||||||
|
.Struct = .{
|
||||||
|
.is_tuple = true,
|
||||||
|
.layout = .auto,
|
||||||
|
.decls = &.{},
|
||||||
|
.fields = &new_fields,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn FnParam(comptime func: anytype, comptime n: comptime_int) type {
|
||||||
|
return @typeInfo(@TypeOf(func)).Fn.params[n].type orelse compileError("anytype is not supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn FnArgs(comptime func: anytype) type {
|
||||||
|
return FnSignature(func, null).ArgsT;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn FnResult(comptime func: anytype) type {
|
||||||
|
return FnSignature(func, null).ReturnT;
|
||||||
|
}
|
||||||
65
stdx/signature.zig
Normal file
65
stdx/signature.zig
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
|
||||||
|
const compileError = @import("meta.zig").compileError;
|
||||||
|
|
||||||
|
pub fn ArgsTuple(comptime funcT: anytype, comptime argsT: ?type) type {
|
||||||
|
const params = @typeInfo(funcT).Fn.params;
|
||||||
|
if (params.len == 0) {
|
||||||
|
return @TypeOf(.{});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (@typeInfo(funcT).Fn.is_generic == false) {
|
||||||
|
return std.meta.ArgsTuple(funcT);
|
||||||
|
}
|
||||||
|
|
||||||
|
const args = std.meta.fields(argsT orelse compileError("generic function requires an explicit ArgsTuple", .{}));
|
||||||
|
var tuple_fields: [params.len]std.builtin.Type.StructField = undefined;
|
||||||
|
inline for (params, args, 0..) |param, arg, i| {
|
||||||
|
if (param.type == null) {
|
||||||
|
tuple_fields[i] = arg;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const T = param.type.?;
|
||||||
|
var num_buf: [32]u8 = undefined;
|
||||||
|
tuple_fields[i] = .{
|
||||||
|
.name = blk: {
|
||||||
|
const s = std.fmt.formatIntBuf(&num_buf, i, 10, .lower, .{});
|
||||||
|
num_buf[s] = 0;
|
||||||
|
break :blk num_buf[0..s :0];
|
||||||
|
},
|
||||||
|
.type = T,
|
||||||
|
.default_value = null,
|
||||||
|
.is_comptime = false,
|
||||||
|
.alignment = if (@sizeOf(T) > 0) @alignOf(T) else 0,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return @Type(.{
|
||||||
|
.Struct = .{
|
||||||
|
.is_tuple = true,
|
||||||
|
.layout = .auto,
|
||||||
|
.decls = &.{},
|
||||||
|
.fields = &tuple_fields,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn FnSignature(comptime func: anytype, comptime argsT: ?type) type {
|
||||||
|
return FnSignatureX(func, ArgsTuple(@TypeOf(func), argsT));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn FnSignatureX(comptime func: anytype, comptime argsT: type) type {
|
||||||
|
return struct {
|
||||||
|
pub const FuncT = @TypeOf(func);
|
||||||
|
pub const ArgsT = argsT;
|
||||||
|
pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined)));
|
||||||
|
pub const ReturnPayloadT = switch (@typeInfo(ReturnT)) {
|
||||||
|
.ErrorUnion => |u| u.payload,
|
||||||
|
else => ReturnT,
|
||||||
|
};
|
||||||
|
pub const ReturnErrorSet: ?type = switch (@typeInfo(ReturnT)) {
|
||||||
|
.ErrorUnion => |u| u.error_set,
|
||||||
|
else => null,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
}
|
||||||
3
stdx/stdx.zig
Normal file
3
stdx/stdx.zig
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
pub const math = @import("math.zig");
|
||||||
|
pub const meta = @import("meta.zig");
|
||||||
|
pub const debug = @import("debug.zig");
|
||||||
@ -32,6 +32,7 @@ zig_library(
|
|||||||
"//mlir/dialects",
|
"//mlir/dialects",
|
||||||
"//pjrt",
|
"//pjrt",
|
||||||
"//runtimes",
|
"//runtimes",
|
||||||
|
"//stdx",
|
||||||
"//zml/tools",
|
"//zml/tools",
|
||||||
"@rules_zig//zig/lib:libc",
|
"@rules_zig//zig/lib:libc",
|
||||||
"@rules_zig//zig/runfiles",
|
"@rules_zig//zig/runfiles",
|
||||||
|
|||||||
18
zml/aio.zig
18
zml/aio.zig
@ -1,8 +1,10 @@
|
|||||||
const builtin = @import("builtin");
|
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
const std = @import("std");
|
const builtin = @import("builtin");
|
||||||
const zml = @import("zml.zig");
|
|
||||||
const c = @import("c");
|
const c = @import("c");
|
||||||
|
const std = @import("std");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
|
const zml = @import("zml.zig");
|
||||||
const posix = @import("posix.zig");
|
const posix = @import("posix.zig");
|
||||||
|
|
||||||
pub const gguf = @import("aio/gguf.zig");
|
pub const gguf = @import("aio/gguf.zig");
|
||||||
@ -13,7 +15,7 @@ pub const tinyllama = @import("aio/tinyllama.zig");
|
|||||||
pub const torch = @import("aio/torch.zig");
|
pub const torch = @import("aio/torch.zig");
|
||||||
pub const yaml = @import("aio/yaml.zig");
|
pub const yaml = @import("aio/yaml.zig");
|
||||||
|
|
||||||
pub const log = std.log.scoped(.zml_aio);
|
pub const log = std.log.scoped(.@"zml/aio");
|
||||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||||
|
|
||||||
test {
|
test {
|
||||||
@ -256,7 +258,11 @@ pub const MemoryMappedFile = struct {
|
|||||||
0,
|
0,
|
||||||
});
|
});
|
||||||
|
|
||||||
try asynk.callBlocking(posix.madvise, .{ data_.ptr, @intCast(data_.len), @intCast(c.MADV_SEQUENTIAL) });
|
try asynk.callBlocking(posix.madvise, .{
|
||||||
|
data_.ptr,
|
||||||
|
@as(usize, @intCast(data_.len)),
|
||||||
|
@as(u32, @intCast(c.MADV_SEQUENTIAL)),
|
||||||
|
});
|
||||||
|
|
||||||
return .{
|
return .{
|
||||||
.file = file,
|
.file = file,
|
||||||
@ -600,7 +606,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
// obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us.
|
// obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us.
|
||||||
var buf_with_metadata = host_buffer;
|
var buf_with_metadata = host_buffer;
|
||||||
log.debug("Loading buffer {s} ({})", .{ prefix, obj._shape });
|
log.debug("Loading buffer {s} ({})", .{ prefix, obj._shape });
|
||||||
zml.meta.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix });
|
stdx.debug.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix });
|
||||||
buf_with_metadata._shape = obj._shape;
|
buf_with_metadata._shape = obj._shape;
|
||||||
obj.* = try zml.Buffer.from(platform, buf_with_metadata);
|
obj.* = try zml.Buffer.from(platform, buf_with_metadata);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -8,7 +8,7 @@ const HostBuffer = @import("../hostbuffer.zig").HostBuffer;
|
|||||||
const Allocator = std.mem.Allocator;
|
const Allocator = std.mem.Allocator;
|
||||||
const assert = std.debug.assert;
|
const assert = std.debug.assert;
|
||||||
|
|
||||||
const log = std.log.scoped(.zml_io);
|
const log = std.log.scoped(.@"zml/io");
|
||||||
|
|
||||||
pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
|
pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
|
||||||
var file = try core.GgufFile.open(path);
|
var file = try core.GgufFile.open(path);
|
||||||
|
|||||||
@ -3,7 +3,7 @@ const std = @import("std");
|
|||||||
const zml = @import("../../zml.zig");
|
const zml = @import("../../zml.zig");
|
||||||
|
|
||||||
const assert = std.debug.assert;
|
const assert = std.debug.assert;
|
||||||
const log = std.log.scoped(.zml_io);
|
const log = std.log.scoped(.@"zml/io");
|
||||||
|
|
||||||
pub const GgufErrors = error{
|
pub const GgufErrors = error{
|
||||||
ValueTypeMismatch,
|
ValueTypeMismatch,
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const log = std.log.scoped(.zml_aio);
|
const log = std.log.scoped(.@"zml/aio");
|
||||||
|
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
const yaml = @import("zig-yaml");
|
const yaml = @import("zig-yaml");
|
||||||
|
|||||||
@ -7,7 +7,7 @@ const MemoryMappedFile = @import("../aio.zig").MemoryMappedFile;
|
|||||||
|
|
||||||
const StringBuilder = std.ArrayListUnmanaged(u8);
|
const StringBuilder = std.ArrayListUnmanaged(u8);
|
||||||
const Allocator = std.mem.Allocator;
|
const Allocator = std.mem.Allocator;
|
||||||
const log = std.log.scoped(.zml_io);
|
const log = std.log.scoped(.@"zml/io");
|
||||||
|
|
||||||
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
|
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
|
||||||
var res: zml.aio.BufferStore = .{
|
var res: zml.aio.BufferStore = .{
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
/// Tools to load models from https://huggingface.co/karpathy/tinyllamas/
|
/// Tools to load models from https://huggingface.co/karpathy/tinyllamas/
|
||||||
/// Originally made to be run with https://github.com/karpathy/llama2.c
|
/// Originally made to be run with https://github.com/karpathy/llama2.c
|
||||||
const std = @import("std");
|
|
||||||
|
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
|
const std = @import("std");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const zml = @import("../zml.zig");
|
const zml = @import("../zml.zig");
|
||||||
|
|
||||||
@ -86,7 +86,7 @@ pub fn open(allocator: std.mem.Allocator, model_path: []const u8) !zml.aio.Buffe
|
|||||||
const weights_size = off;
|
const weights_size = off;
|
||||||
std.log.info("Loaded a tinyllama file of {} bytes.\nThis is the parsed configuration of this llama model: {}", .{ weights_size, c });
|
std.log.info("Loaded a tinyllama file of {} bytes.\nThis is the parsed configuration of this llama model: {}", .{ weights_size, c });
|
||||||
if (file.stat() catch null) |stat| {
|
if (file.stat() catch null) |stat| {
|
||||||
zml.meta.assert(weights_size == stat.size, "Expected to have a tinyllama file of {} bytes but file only got {} !\nThis is the parsed configuration of this llama model: {}", .{ weights_size, stat.size, c });
|
stdx.debug.assert(weights_size == stat.size, "Expected to have a tinyllama file of {} bytes but file only got {} !\nThis is the parsed configuration of this llama model: {}", .{ weights_size, stat.size, c });
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
|||||||
@ -7,7 +7,7 @@ const py = @import("torch/py.zig");
|
|||||||
const File = @import("torch/file.zig").File;
|
const File = @import("torch/file.zig").File;
|
||||||
|
|
||||||
const StringBuilder = std.ArrayListUnmanaged(u8);
|
const StringBuilder = std.ArrayListUnmanaged(u8);
|
||||||
const log = std.log.scoped(.zml_aio);
|
const log = std.log.scoped(.@"zml/aio");
|
||||||
|
|
||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const zml = @import("../../zml.zig");
|
const stdx = @import("stdx");
|
||||||
const meta = zml.meta;
|
|
||||||
|
|
||||||
const py = @import("py.zig");
|
const py = @import("py.zig");
|
||||||
const pickle = @import("pickle.zig");
|
const pickle = @import("pickle.zig");
|
||||||
@ -228,7 +227,7 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
.proto => |proto| meta.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}),
|
.proto => |proto| stdx.debug.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}),
|
||||||
.tuple1 => try stack.append(blk: {
|
.tuple1 => try stack.append(blk: {
|
||||||
const tup_values = try arena.alloc(py.Any, 1);
|
const tup_values = try arena.alloc(py.Any, 1);
|
||||||
tup_values[0] = try pop(&stack);
|
tup_values[0] = try pop(&stack);
|
||||||
|
|||||||
@ -1,8 +1,6 @@
|
|||||||
const std = @import("std");
|
|
||||||
const testing = std.testing;
|
|
||||||
const log = std.log.scoped(.zml_aio);
|
|
||||||
|
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
|
const std = @import("std");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const zml = @import("../../zml.zig");
|
const zml = @import("../../zml.zig");
|
||||||
const pickle = @import("pickle.zig");
|
const pickle = @import("pickle.zig");
|
||||||
@ -10,6 +8,9 @@ const py = @import("py.zig");
|
|||||||
const eval = @import("eval.zig");
|
const eval = @import("eval.zig");
|
||||||
const HostBuffer = zml.HostBuffer;
|
const HostBuffer = zml.HostBuffer;
|
||||||
|
|
||||||
|
const testing = std.testing;
|
||||||
|
const log = std.log.scoped(.@"zml/aio");
|
||||||
|
|
||||||
// TODO(cryptodeal): use zml.aio.PrefixBuilder instead
|
// TODO(cryptodeal): use zml.aio.PrefixBuilder instead
|
||||||
const StringBuilder = std.ArrayListUnmanaged(u8);
|
const StringBuilder = std.ArrayListUnmanaged(u8);
|
||||||
|
|
||||||
@ -329,7 +330,7 @@ pub const File = struct {
|
|||||||
},
|
},
|
||||||
.dict => {
|
.dict => {
|
||||||
const n = @divExact(seq.values.len, 2);
|
const n = @divExact(seq.values.len, 2);
|
||||||
log.info("found dict with {} entries", .{n});
|
log.debug("found dict with {} entries", .{n});
|
||||||
for (0..n) |i| {
|
for (0..n) |i| {
|
||||||
const key, const val = seq.values[2 * i ..][0..2].*;
|
const key, const val = seq.values[2 * i ..][0..2].*;
|
||||||
switch (key) {
|
switch (key) {
|
||||||
@ -534,7 +535,7 @@ pub const File = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn parseDims(values: []py.Any) error{InvalidInput}!zml.Shape.DimsArray {
|
fn parseDims(values: []py.Any) error{InvalidInput}!zml.Shape.DimsArray {
|
||||||
zml.meta.assert(values.len <= zml.Tensor.MAX_RANK, "Found Pytorch tensor with unsupported rank {}", .{values.len});
|
stdx.debug.assert(values.len <= zml.Tensor.MAX_RANK, "Found Pytorch tensor with unsupported rank {}", .{values.len});
|
||||||
var result: zml.Shape.DimsArray = .{};
|
var result: zml.Shape.DimsArray = .{};
|
||||||
for (values) |val| {
|
for (values) |val| {
|
||||||
switch (val) {
|
switch (val) {
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
const log = std.log.scoped(.zml_aio);
|
const log = std.log.scoped(.@"zml/aio");
|
||||||
|
|
||||||
/// All possible pickle operators.
|
/// All possible pickle operators.
|
||||||
/// Reference: https://github.com/python/cpython/blob/3.13/Lib/pickletools.py
|
/// Reference: https://github.com/python/cpython/blob/3.13/Lib/pickletools.py
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const math = std.math;
|
const math = std.math;
|
||||||
const log = std.log.scoped(.zml_aio);
|
const log = std.log.scoped(.@"zml/aio");
|
||||||
|
|
||||||
const pickle = @import("pickle.zig");
|
const pickle = @import("pickle.zig");
|
||||||
|
|
||||||
|
|||||||
@ -1,9 +1,11 @@
|
|||||||
|
const asynk = @import("async");
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const testing = std.testing;
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
const pjrt = @import("pjrtx.zig");
|
const pjrt = @import("pjrtx.zig");
|
||||||
const asynk = @import("async");
|
|
||||||
|
const testing = std.testing;
|
||||||
|
|
||||||
const Context = @import("context.zig").Context;
|
const Context = @import("context.zig").Context;
|
||||||
const Data = @import("dtype.zig").Data;
|
const Data = @import("dtype.zig").Data;
|
||||||
@ -42,12 +44,12 @@ pub const Buffer = struct {
|
|||||||
|
|
||||||
// We shard only on the first axis so that the chunks are still contiguous.
|
// We shard only on the first axis so that the chunks are still contiguous.
|
||||||
// TODO: support more advanced sharding specs
|
// TODO: support more advanced sharding specs
|
||||||
meta.assert(platform.sharding().num_replicas == 1, "ZML doesn't support num_replicas > 1 for now, got: {}", .{platform.sharding()});
|
stdx.debug.assert(platform.sharding().num_replicas == 1, "ZML doesn't support num_replicas > 1 for now, got: {}", .{platform.sharding()});
|
||||||
const sharding_ax: ?u3 = std.simd.firstTrue(host_buffer.shape()._sharding_info);
|
const sharding_ax: ?u3 = std.simd.firstTrue(host_buffer.shape()._sharding_info);
|
||||||
const n_partitions = platform.sharding().num_partitions;
|
const n_partitions = platform.sharding().num_partitions;
|
||||||
const chunk_size = if (sharding_ax) |ax| cs: {
|
const chunk_size = if (sharding_ax) |ax| cs: {
|
||||||
// This kind of sharding error should be detected earlier on.
|
// This kind of sharding error should be detected earlier on.
|
||||||
meta.assert(@rem(host_buffer.dim(ax), n_partitions) == 0, "Buffer.from({}) expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ host_buffer, ax, n_partitions });
|
stdx.debug.assert(@rem(host_buffer.dim(ax), n_partitions) == 0, "Buffer.from({}) expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ host_buffer, ax, n_partitions });
|
||||||
break :cs @divExact(host_buffer.dim(ax), n_partitions);
|
break :cs @divExact(host_buffer.dim(ax), n_partitions);
|
||||||
} else 0;
|
} else 0;
|
||||||
|
|
||||||
@ -88,8 +90,8 @@ pub const Buffer = struct {
|
|||||||
|
|
||||||
/// Wraps pre-exisiting `pjrt.Buffer` shards into one `zml.Buffer`.
|
/// Wraps pre-exisiting `pjrt.Buffer` shards into one `zml.Buffer`.
|
||||||
pub fn fromPjrtBuffers(platform: Platform, shape_: Shape, pjrt_buffers: []const *pjrt.Buffer) Buffer {
|
pub fn fromPjrtBuffers(platform: Platform, shape_: Shape, pjrt_buffers: []const *pjrt.Buffer) Buffer {
|
||||||
meta.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len });
|
stdx.debug.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len });
|
||||||
meta.assert(pjrt_buffers.len > 0, "fromPjrtBuffers expects at least one buffer, got 0.", .{});
|
stdx.debug.assert(pjrt_buffers.len > 0, "fromPjrtBuffers expects at least one buffer, got 0.", .{});
|
||||||
var shards: Shards = .{};
|
var shards: Shards = .{};
|
||||||
shards.appendSliceAssumeCapacity(pjrt_buffers);
|
shards.appendSliceAssumeCapacity(pjrt_buffers);
|
||||||
return .{
|
return .{
|
||||||
@ -190,9 +192,9 @@ pub const Buffer = struct {
|
|||||||
|
|
||||||
/// Fetches the content of the given buffer into a stack variable of the given type.
|
/// Fetches the content of the given buffer into a stack variable of the given type.
|
||||||
pub fn getValue(self: Buffer, T: type) !T {
|
pub fn getValue(self: Buffer, T: type) !T {
|
||||||
meta.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) });
|
stdx.debug.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) });
|
||||||
var res: T = undefined;
|
var res: T = undefined;
|
||||||
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
|
stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
|
||||||
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res));
|
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res));
|
||||||
if (maybe_event) |event| {
|
if (maybe_event) |event| {
|
||||||
try event.await_(self._api);
|
try event.await_(self._api);
|
||||||
@ -204,7 +206,7 @@ pub const Buffer = struct {
|
|||||||
/// and return a new `HostBuffer` object with the same shape.
|
/// and return a new `HostBuffer` object with the same shape.
|
||||||
/// The returned `HostBuffer` doesn't own the memory.
|
/// The returned `HostBuffer` doesn't own the memory.
|
||||||
pub fn toHost(self: Buffer, output: []u8) !HostBuffer {
|
pub fn toHost(self: Buffer, output: []u8) !HostBuffer {
|
||||||
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
|
stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
|
||||||
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, output);
|
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, output);
|
||||||
if (maybe_event) |event| {
|
if (maybe_event) |event| {
|
||||||
try event.await_(self._api);
|
try event.await_(self._api);
|
||||||
@ -216,7 +218,7 @@ pub const Buffer = struct {
|
|||||||
/// The returned `HostBuffer` does own the memory.
|
/// The returned `HostBuffer` does own the memory.
|
||||||
pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer {
|
pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer {
|
||||||
const output = try HostBuffer.empty(allocator, self.shape());
|
const output = try HostBuffer.empty(allocator, self.shape());
|
||||||
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
|
stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
|
||||||
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data));
|
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data));
|
||||||
if (maybe_event) |event| {
|
if (maybe_event) |event| {
|
||||||
try event.await_(self._api);
|
try event.await_(self._api);
|
||||||
|
|||||||
@ -1,22 +1,22 @@
|
|||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
const std = @import("std");
|
|
||||||
const mlir = @import("mlir");
|
|
||||||
const c = @import("c");
|
const c = @import("c");
|
||||||
|
const mlir = @import("mlir");
|
||||||
const runfiles = @import("runfiles");
|
const runfiles = @import("runfiles");
|
||||||
const runtimes = @import("runtimes");
|
const runtimes = @import("runtimes");
|
||||||
|
const std = @import("std");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const platform = @import("platform.zig");
|
const platform = @import("platform.zig");
|
||||||
const pjrt = @import("pjrtx.zig");
|
const pjrt = @import("pjrtx.zig");
|
||||||
|
|
||||||
const available_targets = @import("platform.zig").available_targets;
|
|
||||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||||
const Target = @import("platform.zig").Target;
|
|
||||||
const Platform = @import("platform.zig").Platform;
|
|
||||||
|
|
||||||
const log = std.log.scoped(.zml);
|
|
||||||
|
|
||||||
const PjrtApiMap = std.EnumArray(Target, ?*const pjrt.Api);
|
const PjrtApiMap = std.EnumArray(Target, ?*const pjrt.Api);
|
||||||
|
const Platform = @import("platform.zig").Platform;
|
||||||
const PlatformsMap = std.EnumArray(Target, ?Platform);
|
const PlatformsMap = std.EnumArray(Target, ?Platform);
|
||||||
|
const Target = @import("platform.zig").Target;
|
||||||
|
|
||||||
|
const available_targets = @import("platform.zig").available_targets;
|
||||||
|
const log = std.log.scoped(.@"zml/context");
|
||||||
|
|
||||||
/// Every program using ZML must start with a `zml.Context.init(.{});`
|
/// Every program using ZML must start with a `zml.Context.init(.{});`
|
||||||
/// The ZML context contains global state to interact with the different
|
/// The ZML context contains global state to interact with the different
|
||||||
@ -145,6 +145,36 @@ pub const Context = struct {
|
|||||||
return platform_ orelse @panic("No platform found !");
|
return platform_ orelse @panic("No platform found !");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn printAvailablePlatforms(self: Context, selected: platform.Platform) void {
|
||||||
|
// List available targets
|
||||||
|
log.info("Available Platforms:", .{});
|
||||||
|
const selected_prefix = "✅";
|
||||||
|
const not_selected_prefix = "• ";
|
||||||
|
const selected_postfix = "(AUTO-SELECTED)";
|
||||||
|
const not_selected_postfix = "";
|
||||||
|
|
||||||
|
for (platform.available_targets) |target| {
|
||||||
|
log.info(" {s} {s} {s}", .{
|
||||||
|
if (target == selected.target) selected_prefix else not_selected_prefix,
|
||||||
|
@tagName(target),
|
||||||
|
if (target == selected.target) selected_postfix else not_selected_postfix,
|
||||||
|
});
|
||||||
|
|
||||||
|
// now the platform's devices
|
||||||
|
if (self.platforms.get(target)) |pfm| {
|
||||||
|
for (pfm.getDevices(), 0..) |device, index| {
|
||||||
|
const deviceKind = device.getDescription(pfm.pjrt_api).getKind(pfm.pjrt_api);
|
||||||
|
log.info(" ◦ #{d}: {s}", .{
|
||||||
|
index,
|
||||||
|
deviceKind,
|
||||||
|
});
|
||||||
|
// we only list 1 CPU device
|
||||||
|
if (target == .cpu) break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub const HostCallbackCtx = struct {
|
pub const HostCallbackCtx = struct {
|
||||||
host: HostBuffer,
|
host: HostBuffer,
|
||||||
mutex: std.Thread.Mutex = std.Thread.Mutex{},
|
mutex: std.Thread.Mutex = std.Thread.Mutex{},
|
||||||
|
|||||||
@ -6,7 +6,7 @@ const Shape = @import("shape.zig").Shape;
|
|||||||
const Tensor = @import("tensor.zig").Tensor;
|
const Tensor = @import("tensor.zig").Tensor;
|
||||||
|
|
||||||
const EnumLiteral = @TypeOf(.enum_literal);
|
const EnumLiteral = @TypeOf(.enum_literal);
|
||||||
const log = std.log.scoped(.zml_tensor);
|
const log = std.log.scoped(.@"zml/tensor");
|
||||||
|
|
||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const meta = @import("meta.zig");
|
|
||||||
const Buffer = @import("buffer.zig").Buffer;
|
const Buffer = @import("buffer.zig").Buffer;
|
||||||
const Data = @import("dtype.zig").Data;
|
const Data = @import("dtype.zig").Data;
|
||||||
const DataType = @import("dtype.zig").DataType;
|
const DataType = @import("dtype.zig").DataType;
|
||||||
@ -108,13 +108,13 @@ pub const HostBuffer = struct {
|
|||||||
/// The memory is initialized with increasing numbers.
|
/// The memory is initialized with increasing numbers.
|
||||||
/// The caller owns the memory, and need to call `deinit()`.
|
/// The caller owns the memory, and need to call `deinit()`.
|
||||||
pub fn arange(allocator: std.mem.Allocator, args: ArangeArgs, dt: DataType) !HostBuffer {
|
pub fn arange(allocator: std.mem.Allocator, args: ArangeArgs, dt: DataType) !HostBuffer {
|
||||||
meta.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end });
|
stdx.debug.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end });
|
||||||
meta.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});
|
stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});
|
||||||
|
|
||||||
const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable;
|
const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable;
|
||||||
const b = dt.sizeOf();
|
const b = dt.sizeOf();
|
||||||
const res = try empty(allocator, Shape.init(.{n_steps}, dt));
|
const res = try empty(allocator, Shape.init(.{n_steps}, dt));
|
||||||
meta.assert(dt.class() == .integer, "arange expects type to be integer, got {} instead.", .{dt});
|
stdx.debug.assert(dt.class() == .integer, "arange expects type to be integer, got {} instead.", .{dt});
|
||||||
var data_ = @constCast(res.data);
|
var data_ = @constCast(res.data);
|
||||||
switch (dt) {
|
switch (dt) {
|
||||||
inline else => {
|
inline else => {
|
||||||
@ -201,7 +201,7 @@ pub const HostBuffer = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer {
|
pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer {
|
||||||
meta.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {}", .{self});
|
stdx.debug.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {}", .{self});
|
||||||
var res = self;
|
var res = self;
|
||||||
res._shape = self._shape.reshape(shape_);
|
res._shape = self._shape.reshape(shape_);
|
||||||
return res;
|
return res;
|
||||||
@ -219,9 +219,9 @@ pub const HostBuffer = struct {
|
|||||||
const start: i64 = if (s.start < 0) s.start + d else s.start;
|
const start: i64 = if (s.start < 0) s.start + d else s.start;
|
||||||
var end = s.end orelse d;
|
var end = s.end orelse d;
|
||||||
if (end < 0) end += d;
|
if (end < 0) end += d;
|
||||||
meta.assert(start >= 0 and start < d, "slice1d({}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, start });
|
stdx.debug.assert(start >= 0 and start < d, "slice1d({}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, start });
|
||||||
meta.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, end });
|
stdx.debug.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, end });
|
||||||
meta.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({})", .{ self, ax, start, end });
|
stdx.debug.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({})", .{ self, ax, start, end });
|
||||||
|
|
||||||
// If strides weren't set it means original buffer is contiguous.
|
// If strides weren't set it means original buffer is contiguous.
|
||||||
// But it won't be anymore after slicing. The strides don't change though.
|
// But it won't be anymore after slicing. The strides don't change though.
|
||||||
|
|||||||
249
zml/meta.zig
249
zml/meta.zig
@ -1,4 +1,8 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
|
const FnParam = stdx.meta.FnParam;
|
||||||
|
const asSlice = stdx.meta.asSlice;
|
||||||
|
|
||||||
const testing = std.testing;
|
const testing = std.testing;
|
||||||
|
|
||||||
@ -6,215 +10,6 @@ test {
|
|||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Computes floating point value division between two integers.
|
|
||||||
pub fn divFloat(T: type, numerator: anytype, denominator: anytype) T {
|
|
||||||
return toFloat(T, numerator) / toFloat(T, denominator);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn toFloat(T: type, x: anytype) T {
|
|
||||||
return switch (@typeInfo(@TypeOf(x))) {
|
|
||||||
.Float => @floatCast(x),
|
|
||||||
else => @floatFromInt(x),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn guard(check: bool, src: std.builtin.SourceLocation) void {
|
|
||||||
assert(check, "Invalid inputs {s}@{s}:{d}", .{ src.file, src.fn_name, src.line });
|
|
||||||
}
|
|
||||||
|
|
||||||
pub inline fn internalAssert(check: bool, comptime msg: []const u8, args: anytype) void {
|
|
||||||
assert(check, "ZML internal error: " ++ msg, args);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn assert(check: bool, comptime msg: []const u8, args: anytype) void {
|
|
||||||
if (!check) panic(msg, args);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn panic(comptime msg: []const u8, args: anytype) noreturn {
|
|
||||||
std.log.err(msg, args);
|
|
||||||
@panic(msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn compileLog(comptime msg: []const u8, comptime args: anytype) void {
|
|
||||||
@compileLog(std.fmt.comptimePrint(msg, args));
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn compileError(comptime msg: []const u8, comptime args: anytype) noreturn {
|
|
||||||
@compileError(std.fmt.comptimePrint(msg, args));
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn assertComptime(comptime check: bool, comptime msg: []const u8, comptime args: anytype) void {
|
|
||||||
if (check == false) {
|
|
||||||
compileError(msg, args);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn isStruct(comptime T: type) bool {
|
|
||||||
return switch (@typeInfo(T)) {
|
|
||||||
.Struct => true,
|
|
||||||
else => false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn isTuple(comptime T: type) bool {
|
|
||||||
return switch (@typeInfo(T)) {
|
|
||||||
.Struct => |info| info.is_tuple,
|
|
||||||
else => false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn isStructOf(comptime T: type, comptime Elem: type) bool {
|
|
||||||
return switch (@typeInfo(T)) {
|
|
||||||
.Struct => |info| blk: {
|
|
||||||
inline for (info.fields) |field| {
|
|
||||||
if (field.type != Elem) {
|
|
||||||
break :blk false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break :blk true;
|
|
||||||
},
|
|
||||||
else => false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn isStructOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
|
|
||||||
return switch (@typeInfo(T)) {
|
|
||||||
.Struct => |info| blk: {
|
|
||||||
inline for (info.fields) |field| {
|
|
||||||
if (f(field.type) == false) {
|
|
||||||
break :blk false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break :blk true;
|
|
||||||
},
|
|
||||||
else => false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn isTupleOf(comptime T: type, comptime Elem: type) bool {
|
|
||||||
return isTuple(T) and isStructOf(T, Elem);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn isTupleOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
|
|
||||||
return isTuple(T) and isStructOfAny(T, f);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn isSliceOf(comptime T: type, comptime Elem: type) bool {
|
|
||||||
return switch (@typeInfo(T)) {
|
|
||||||
.Pointer => |info| switch (info.size) {
|
|
||||||
.Slice => info.child == Elem,
|
|
||||||
.One => switch (@typeInfo(info.child)) {
|
|
||||||
// As Zig, convert pointer to Array as a slice.
|
|
||||||
.Array => |arr_info| arr_info.child == Elem,
|
|
||||||
else => false,
|
|
||||||
},
|
|
||||||
else => false,
|
|
||||||
},
|
|
||||||
else => false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn asSlice(comptime T: type) type {
|
|
||||||
const err_msg = "Type " ++ @typeName(T) ++ " can't be interpreted as a slice";
|
|
||||||
return switch (@typeInfo(T)) {
|
|
||||||
.Pointer => |info| switch (info.size) {
|
|
||||||
.Slice => info.child,
|
|
||||||
.One => switch (@typeInfo(info.child)) {
|
|
||||||
// As Zig, convert pointer to Array as a slice.
|
|
||||||
.Array => |arr_info| arr_info.child,
|
|
||||||
else => @compileError(err_msg),
|
|
||||||
},
|
|
||||||
else => @compileError(err_msg),
|
|
||||||
},
|
|
||||||
else => @compileError(err_msg),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn isInteger(comptime T: type) bool {
|
|
||||||
return switch (@typeInfo(T)) {
|
|
||||||
.Int, .ComptimeInt => true,
|
|
||||||
else => false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn isSliceOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
|
|
||||||
return switch (@typeInfo(T)) {
|
|
||||||
.Pointer => |info| info.size == .Slice and f(info.child),
|
|
||||||
else => false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn DeclEnum(comptime T: type) type {
|
|
||||||
const field_infos = std.meta.declarations(T);
|
|
||||||
if (field_infos.len == 0) compileError("Struct {} has no declarations", .{T});
|
|
||||||
return std.meta.DeclEnum(UnwrapPtr(T));
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn UnwrapPtr(comptime T: type) type {
|
|
||||||
return switch (@typeInfo(T)) {
|
|
||||||
.Pointer => |info| switch (info.size) {
|
|
||||||
.One => info.child,
|
|
||||||
else => T,
|
|
||||||
},
|
|
||||||
else => T,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn FnParam(func: anytype, n: comptime_int) type {
|
|
||||||
return @typeInfo(@TypeOf(func)).Fn.params[n].type orelse @compileError("anytype not supported in callbacks");
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn FnParams(func: anytype) type {
|
|
||||||
return std.meta.ArgsTuple(@TypeOf(func));
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn FnResult(func: anytype) type {
|
|
||||||
return @typeInfo(@TypeOf(func)).Fn.return_type.?;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn FnResultPayload(func: anytype) type {
|
|
||||||
const return_type = @typeInfo(@TypeOf(func)).Fn.return_type.?;
|
|
||||||
const payload_type = switch (@typeInfo(return_type)) {
|
|
||||||
.ErrorUnion => |u| u.payload,
|
|
||||||
else => return_type,
|
|
||||||
};
|
|
||||||
return payload_type;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn FnResultErrorSet(func: anytype) ?type {
|
|
||||||
const return_type = @typeInfo(@TypeOf(func)).Fn.return_type.?;
|
|
||||||
const error_set = switch (@typeInfo(return_type)) {
|
|
||||||
.ErrorUnion => |u| u.error_set,
|
|
||||||
else => null,
|
|
||||||
};
|
|
||||||
return error_set;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn Signature(comptime func: anytype, comptime argsT: ?type) type {
|
|
||||||
return struct {
|
|
||||||
pub const FuncT = if (@TypeOf(func) == type) func else @TypeOf(func);
|
|
||||||
pub const ArgsT = blk: {
|
|
||||||
if (@typeInfo(FuncT).Fn.params.len == 0) {
|
|
||||||
break :blk @TypeOf(.{});
|
|
||||||
}
|
|
||||||
break :blk argsT orelse std.meta.ArgsTuple(FuncT);
|
|
||||||
};
|
|
||||||
pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined)));
|
|
||||||
pub const ReturnPayloadT = blk: {
|
|
||||||
break :blk switch (@typeInfo(ReturnT)) {
|
|
||||||
.ErrorUnion => |u| u.payload,
|
|
||||||
else => ReturnT,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
pub const ReturnErrorSet: ?type = blk: {
|
|
||||||
break :blk switch (@typeInfo(ReturnT)) {
|
|
||||||
.ErrorUnion => |u| u.error_set,
|
|
||||||
else => null,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn MapType(From: type, To: type) type {
|
pub fn MapType(From: type, To: type) type {
|
||||||
return struct {
|
return struct {
|
||||||
pub fn map(T: type) type {
|
pub fn map(T: type) type {
|
||||||
@ -299,7 +94,7 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
|
|||||||
|
|
||||||
const type_info_to_ptr = @typeInfo(@TypeOf(to));
|
const type_info_to_ptr = @typeInfo(@TypeOf(to));
|
||||||
if (type_info_to_ptr != .Pointer) {
|
if (type_info_to_ptr != .Pointer) {
|
||||||
@compileError("convertType is expecting a mutable `to` argument but received: " ++ @typeName(@TypeOf(to)));
|
stdx.debug.compileError("convertType is expecting a mutable `to` argument but received: " ++ @typeName(@TypeOf(to)));
|
||||||
}
|
}
|
||||||
const ToStruct = type_info_to_ptr.Pointer.child;
|
const ToStruct = type_info_to_ptr.Pointer.child;
|
||||||
const type_info_to = @typeInfo(ToStruct);
|
const type_info_to = @typeInfo(ToStruct);
|
||||||
@ -348,7 +143,7 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
|
|||||||
} else if (field.default_value) |_| {
|
} else if (field.default_value) |_| {
|
||||||
@field(to, field.name) = null;
|
@field(to, field.name) = null;
|
||||||
} else {
|
} else {
|
||||||
compileError("Mapping {} to {} failed. Missing field {s}", .{ FromStruct, ToStruct, field.name });
|
stdx.meta.compileError("Mapping {} to {} failed. Missing field {s}", .{ FromStruct, ToStruct, field.name });
|
||||||
},
|
},
|
||||||
else => @field(to, field.name) = @field(from, field.name),
|
else => @field(to, field.name) = @field(from, field.name),
|
||||||
}
|
}
|
||||||
@ -374,7 +169,7 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
|
|||||||
}
|
}
|
||||||
to.* = items;
|
to.* = items;
|
||||||
},
|
},
|
||||||
else => @compileError("zml.meta.mapAlloc doesn't support: " ++ @typeName(FromStruct)),
|
else => stdx.meta.compileError("zml.meta.mapAlloc doesn't support: " ++ @typeName(FromStruct)),
|
||||||
},
|
},
|
||||||
.Optional => if (from) |f| {
|
.Optional => if (from) |f| {
|
||||||
to.* = @as(@typeInfo(type_info_to_ptr.Pointer.child).Optional.child, undefined);
|
to.* = @as(@typeInfo(type_info_to_ptr.Pointer.child).Optional.child, undefined);
|
||||||
@ -383,7 +178,7 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
|
|||||||
to.* = null;
|
to.* = null;
|
||||||
},
|
},
|
||||||
.Int, .Float => to.* = from,
|
.Int, .Float => to.* = from,
|
||||||
else => @compileError("zml.meta.mapAlloc doesn't support: " ++ @typeName(FromStruct)),
|
else => stdx.meta.compileError("zml.meta.mapAlloc doesn't support: " ++ @typeName(FromStruct)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -444,12 +239,12 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
|||||||
const type_info_v = @typeInfo(T);
|
const type_info_v = @typeInfo(T);
|
||||||
const K = switch (@typeInfo(FnParam(cb, 1))) {
|
const K = switch (@typeInfo(FnParam(cb, 1))) {
|
||||||
.Pointer => |info| info.child,
|
.Pointer => |info| info.child,
|
||||||
else => @compileError("zml.meta.visit is expecting a pointer value as second parameter in callback to use but found " ++ @typeName(FnParam(cb, 1))),
|
else => stdx.meta.compileError("zml.meta.visit is expecting a pointer value as second parameter in callback to use but found " ++ @typeName(FnParam(cb, 1))),
|
||||||
};
|
};
|
||||||
|
|
||||||
if (type_info_v != .Pointer) {
|
if (type_info_v != .Pointer) {
|
||||||
const Callback = @TypeOf(cb);
|
const Callback = @TypeOf(cb);
|
||||||
@compileError("zml.meta.visit is expecting a pointer input to go with following callback signature: " ++ @typeName(Callback) ++ " but received: " ++ @typeName(T));
|
stdx.meta.compileError("zml.meta.visit is expecting a pointer input to go with following callback signature: " ++ @typeName(Callback) ++ " but received: " ++ @typeName(T));
|
||||||
}
|
}
|
||||||
const ptr_info = type_info_v.Pointer;
|
const ptr_info = type_info_v.Pointer;
|
||||||
if (@typeInfo(ptr_info.child) == .Fn) return;
|
if (@typeInfo(ptr_info.child) == .Fn) return;
|
||||||
@ -512,7 +307,7 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
else => @compileError("Only single pointer and slice are supported. Received " ++ @typeName(T)),
|
else => stdx.meta.compileError("Only single pointer and slice are supported. Received " ++ @typeName(T)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -601,10 +396,8 @@ test visit {
|
|||||||
/// Only T elements of values will be looked at.
|
/// Only T elements of values will be looked at.
|
||||||
/// This only works for simple types, in particular `zip` doesn't follow pointers.
|
/// This only works for simple types, in particular `zip` doesn't follow pointers.
|
||||||
/// Which means that zip only allocate temp memory, and nothing need to be freed after the call.
|
/// Which means that zip only allocate temp memory, and nothing need to be freed after the call.
|
||||||
pub fn zip(func: anytype, allocator: std.mem.Allocator, values: anytype, args: anytype) error{OutOfMemory}!asSlice(@TypeOf(values)) {
|
pub fn zip(comptime func: anytype, allocator: std.mem.Allocator, values: anytype, args: anytype) error{OutOfMemory}!asSlice(@TypeOf(values)) {
|
||||||
const sliceT = @typeInfo(FnParam(func, 0));
|
const sliceT = @typeInfo(FnParam(func, 0));
|
||||||
assertComptime(sliceT == .Pointer and sliceT.Pointer.size == .Slice and sliceT.Pointer.child == FnResult(func), "zip requires a `fn([]const T, Args) T`, received: {}", .{@TypeOf(func)});
|
|
||||||
|
|
||||||
const T = sliceT.Pointer.child;
|
const T = sliceT.Pointer.child;
|
||||||
const V = asSlice(@TypeOf(values));
|
const V = asSlice(@TypeOf(values));
|
||||||
if (V == T) {
|
if (V == T) {
|
||||||
@ -613,13 +406,13 @@ pub fn zip(func: anytype, allocator: std.mem.Allocator, values: anytype, args: a
|
|||||||
// const fn_args
|
// const fn_args
|
||||||
|
|
||||||
return switch (@typeInfo(V)) {
|
return switch (@typeInfo(V)) {
|
||||||
.Pointer => @compileError("zip only accept by value arguments. Received: " ++ @typeName(V)),
|
.Pointer => stdx.meta.compileError("zip only accept by value arguments. Received: " ++ @typeName(V)),
|
||||||
.Struct => |struct_info| {
|
.Struct => |struct_info| {
|
||||||
var out: V = values[0];
|
var out: V = values[0];
|
||||||
inline for (struct_info.fields) |f| {
|
inline for (struct_info.fields) |f| {
|
||||||
if (f.is_comptime) continue;
|
if (f.is_comptime) continue;
|
||||||
if (@typeInfo(f.type) == .Pointer) {
|
if (@typeInfo(f.type) == .Pointer) {
|
||||||
@compileError("zip doesn't follow pointers and don't accept struct containing them. Received: " ++ @typeName(V));
|
stdx.meta.compileError("zip doesn't follow pointers and don't accept struct containing them. Received: " ++ @typeName(V));
|
||||||
}
|
}
|
||||||
var fields = try allocator.alloc(f.type, values.len);
|
var fields = try allocator.alloc(f.type, values.len);
|
||||||
defer allocator.free(fields);
|
defer allocator.free(fields);
|
||||||
@ -632,7 +425,7 @@ pub fn zip(func: anytype, allocator: std.mem.Allocator, values: anytype, args: a
|
|||||||
},
|
},
|
||||||
.Array => |arr_info| {
|
.Array => |arr_info| {
|
||||||
if (@typeInfo(arr_info.child) == .Pointer) {
|
if (@typeInfo(arr_info.child) == .Pointer) {
|
||||||
@compileError("zip doesn't follow pointers and don't accept struct containing them. Received: " ++ @typeName(V));
|
stdx.meta.compileError("zip doesn't follow pointers and don't accept struct containing them. Received: " ++ @typeName(V));
|
||||||
}
|
}
|
||||||
var out: V = undefined;
|
var out: V = undefined;
|
||||||
var slice = try allocator.alloc(arr_info.child, values.len);
|
var slice = try allocator.alloc(arr_info.child, values.len);
|
||||||
@ -645,7 +438,7 @@ pub fn zip(func: anytype, allocator: std.mem.Allocator, values: anytype, args: a
|
|||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
},
|
},
|
||||||
.Union, .Optional => @compileError("zip doesn't yet support " ++ @typeName(V)),
|
.Union, .Optional => stdx.meta.compileError("zip doesn't yet support " ++ @typeName(V)),
|
||||||
else => values[0],
|
else => values[0],
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -668,11 +461,11 @@ test zip {
|
|||||||
|
|
||||||
/// Given a func(X) -> Y or a func(Ctx, X) -> Y,
|
/// Given a func(X) -> Y or a func(Ctx, X) -> Y,
|
||||||
/// finds all X in the given object, and write the result of func(X) into an arraylist.
|
/// finds all X in the given object, and write the result of func(X) into an arraylist.
|
||||||
pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(FnResult(func)), obj: anytype) error{OutOfMemory}!void {
|
pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT), obj: anytype) error{OutOfMemory}!void {
|
||||||
assertComptime(@typeInfo(@TypeOf(func)).Fn.params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)});
|
stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).Fn.params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)});
|
||||||
const LocalContext = struct {
|
const LocalContext = struct {
|
||||||
func_ctx: _CollectCtx(func),
|
func_ctx: _CollectCtx(func),
|
||||||
out: *std.ArrayList(FnResult(func)),
|
out: *std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT),
|
||||||
oom: bool = false,
|
oom: bool = false,
|
||||||
};
|
};
|
||||||
var context = LocalContext{ .func_ctx = func_ctx, .out = out };
|
var context = LocalContext{ .func_ctx = func_ctx, .out = out };
|
||||||
@ -691,10 +484,10 @@ pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(F
|
|||||||
fn _CollectCtx(func: anytype) type {
|
fn _CollectCtx(func: anytype) type {
|
||||||
const params = @typeInfo(@TypeOf(func)).Fn.params;
|
const params = @typeInfo(@TypeOf(func)).Fn.params;
|
||||||
if (params.len == 1) return void;
|
if (params.len == 1) return void;
|
||||||
return params[0].type orelse @compileError("anytype not supported in collect");
|
return params[0].type orelse stdx.meta.compileError("anytype not supported in collect");
|
||||||
}
|
}
|
||||||
|
|
||||||
fn _CollectArg(func: anytype) type {
|
fn _CollectArg(func: anytype) type {
|
||||||
const params = @typeInfo(@TypeOf(func)).Fn.params;
|
const params = @typeInfo(@TypeOf(func)).Fn.params;
|
||||||
return params[params.len - 1].type orelse @compileError("anytype not supported in collect");
|
return params[params.len - 1].type orelse stdx.meta.compileError("anytype not supported in collect");
|
||||||
}
|
}
|
||||||
|
|||||||
10
zml/mlir.zig
10
zml/mlir.zig
@ -2,14 +2,14 @@ const mlir = @This();
|
|||||||
|
|
||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const dtype = @import("dtype.zig");
|
const dtype = @import("dtype.zig");
|
||||||
const meta = @import("meta.zig");
|
|
||||||
|
|
||||||
const Shape = @import("shape.zig").Shape;
|
const Shape = @import("shape.zig").Shape;
|
||||||
const Tensor = @import("tensor.zig").Tensor;
|
const Tensor = @import("tensor.zig").Tensor;
|
||||||
|
|
||||||
const log = std.log.scoped(.zml_mlir);
|
const log = std.log.scoped(.@"zml/mlir");
|
||||||
|
|
||||||
pub usingnamespace @import("mlir");
|
pub usingnamespace @import("mlir");
|
||||||
|
|
||||||
@ -128,7 +128,7 @@ pub const ext = struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
meta.panic("Could not convert mlir.Type to DataType: {}", .{mlir_type});
|
stdx.debug.panic("Could not convert mlir.Type to DataType: {}", .{mlir_type});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -148,7 +148,7 @@ pub const ext = struct {
|
|||||||
const int_attr = mlir.IntegerAttribute(int_type).init(ctx, @intCast(val));
|
const int_attr = mlir.IntegerAttribute(int_type).init(ctx, @intCast(val));
|
||||||
return int_attr.as(mlir.Attribute).?;
|
return int_attr.as(mlir.Attribute).?;
|
||||||
},
|
},
|
||||||
inline else => |_, tag| meta.panic("Unsupported data type: {any}", .{tag}),
|
inline else => |_, tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -169,7 +169,7 @@ pub const ext = struct {
|
|||||||
.f16 => mlir.DenseIntOrFPElementsAttribute(.f16).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
.f16 => mlir.DenseIntOrFPElementsAttribute(.f16).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
||||||
.f32 => mlir.DenseIntOrFPElementsAttribute(.f32).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
.f32 => mlir.DenseIntOrFPElementsAttribute(.f32).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
||||||
.f64 => mlir.DenseIntOrFPElementsAttribute(.f64).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
.f64 => mlir.DenseIntOrFPElementsAttribute(.f64).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
||||||
inline else => |tag| meta.panic("Unsupported data type: {any}", .{tag}),
|
inline else => |tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
132
zml/module.zig
132
zml/module.zig
@ -1,32 +1,31 @@
|
|||||||
|
const asynk = @import("async");
|
||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
const std = @import("std");
|
const dialect = @import("mlir/dialects");
|
||||||
|
const protobuf = @import("io/protobuf");
|
||||||
const runfiles = @import("runfiles");
|
const runfiles = @import("runfiles");
|
||||||
|
const std = @import("std");
|
||||||
|
const stdx = @import("stdx");
|
||||||
const xla_pb = @import("//xla:xla_proto");
|
const xla_pb = @import("//xla:xla_proto");
|
||||||
|
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
const mlir = @import("mlir.zig");
|
const mlir = @import("mlir.zig");
|
||||||
const ops = @import("ops.zig");
|
const ops = @import("ops.zig");
|
||||||
const pjrt = @import("pjrtx.zig");
|
const pjrt = @import("pjrtx.zig");
|
||||||
const protobuf = @import("io/protobuf");
|
|
||||||
const asynk = @import("async");
|
|
||||||
const aio = @import("aio.zig");
|
const aio = @import("aio.zig");
|
||||||
|
|
||||||
const dialect = @import("mlir/dialects");
|
const Buffer = @import("buffer.zig").Buffer;
|
||||||
|
const Bufferized = @import("tensor.zig").Bufferized;
|
||||||
const assert = std.debug.assert;
|
|
||||||
const Context = @import("context.zig").Context;
|
const Context = @import("context.zig").Context;
|
||||||
const Location = mlir.Location;
|
const Location = mlir.Location;
|
||||||
const Platform = @import("platform.zig").Platform;
|
const Platform = @import("platform.zig").Platform;
|
||||||
|
const Shape = @import("shape.zig").Shape;
|
||||||
|
const ShapeOf = @import("tensor.zig").ShapeOf;
|
||||||
const Target = @import("platform.zig").Target;
|
const Target = @import("platform.zig").Target;
|
||||||
const Tensor = @import("tensor.zig").Tensor;
|
const Tensor = @import("tensor.zig").Tensor;
|
||||||
const ShapeOf = @import("tensor.zig").ShapeOf;
|
|
||||||
const Shape = @import("shape.zig").Shape;
|
|
||||||
const Buffer = @import("buffer.zig").Buffer;
|
|
||||||
const Bufferized = @import("tensor.zig").Bufferized;
|
|
||||||
const Tracer = @import("tools/tracer.zig").Tracer;
|
const Tracer = @import("tools/tracer.zig").Tracer;
|
||||||
|
|
||||||
const log = std.log.scoped(.zml_module);
|
const assert = std.debug.assert;
|
||||||
|
const log = std.log.scoped(.@"zml/module");
|
||||||
|
|
||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
@ -101,7 +100,7 @@ pub const CompilationContext = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn deactivate(self: *CompilationContext) void {
|
pub fn deactivate(self: *CompilationContext) void {
|
||||||
std.debug.assert(_current != null and _current.? == self);
|
assert(_current != null and _current.? == self);
|
||||||
_current = self._previous;
|
_current = self._previous;
|
||||||
self._previous = null;
|
self._previous = null;
|
||||||
}
|
}
|
||||||
@ -163,7 +162,7 @@ pub const CompilationContext = struct {
|
|||||||
// So we create a copy of the arguments, and replace values
|
// So we create a copy of the arguments, and replace values
|
||||||
// by the block arguments.
|
// by the block arguments.
|
||||||
var blk_args = args;
|
var blk_args = args;
|
||||||
assert(assignBlockArguments(&blk_args, block, 0) == N);
|
std.debug.assert(assignBlockArguments(&blk_args, block, 0) == N);
|
||||||
|
|
||||||
const loc = self.mlirCtx().location(@src());
|
const loc = self.mlirCtx().location(@src());
|
||||||
const block_res = @call(.auto, func, S.blkArgs(blkctx, blk_args));
|
const block_res = @call(.auto, func, S.blkArgs(blkctx, blk_args));
|
||||||
@ -209,9 +208,9 @@ pub const CompilationContext = struct {
|
|||||||
|
|
||||||
var input_shapes = try std.ArrayList(Shape).initCapacity(arena, tensor_count);
|
var input_shapes = try std.ArrayList(Shape).initCapacity(arena, tensor_count);
|
||||||
meta.collect(Tensor.shape, {}, &input_shapes, model) catch unreachable;
|
meta.collect(Tensor.shape, {}, &input_shapes, model) catch unreachable;
|
||||||
meta.internalAssert(input_shapes.items.len == model_tensor_count, "model has changed ?", .{});
|
stdx.debug.internalAssert(input_shapes.items.len == model_tensor_count, "model has changed ?", .{});
|
||||||
meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable;
|
meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable;
|
||||||
meta.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});
|
stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});
|
||||||
|
|
||||||
const input_types = try arena.alloc(mlir.Type, tensor_count);
|
const input_types = try arena.alloc(mlir.Type, tensor_count);
|
||||||
for (input_types, input_shapes.items) |*t, sh| t.* = mlir.ext.mlirType(mlir_ctx, sh);
|
for (input_types, input_shapes.items) |*t, sh| t.* = mlir.ext.mlirType(mlir_ctx, sh);
|
||||||
@ -311,7 +310,7 @@ pub const CompilationContext = struct {
|
|||||||
// This will break the day we writer another attribute before donation.
|
// This will break the day we writer another attribute before donation.
|
||||||
// When the time come, do a more fancy lookup here to check if an argument
|
// When the time come, do a more fancy lookup here to check if an argument
|
||||||
// is donated twice.
|
// is donated twice.
|
||||||
meta.assert(attributes[a].len == 0, "Donation error ! Argument {} has been donated twice ! To {} and to {}", .{ a, index, attributes[a].buffer[0] });
|
stdx.debug.assert(attributes[a].len == 0, "Donation error ! Argument {} has been donated twice ! To {} and to {}", .{ a, index, attributes[a].buffer[0] });
|
||||||
attributes[a].appendAssumeCapacity(
|
attributes[a].appendAssumeCapacity(
|
||||||
mlir.NamedAttribute.init(
|
mlir.NamedAttribute.init(
|
||||||
mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"),
|
mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"),
|
||||||
@ -507,7 +506,7 @@ pub const CompilationContext = struct {
|
|||||||
extractValues(args, values[function.n_model..]);
|
extractValues(args, values[function.n_model..]);
|
||||||
|
|
||||||
const op = dialect.func.call(self.mlirCtx(), function.name, values, function.res_types, loc);
|
const op = dialect.func.call(self.mlirCtx(), function.name, values, function.res_types, loc);
|
||||||
var res: meta.FnResult(func) = undefined;
|
var res: stdx.meta.FnResult(func) = undefined;
|
||||||
assignResults(&res, function.res_shapes, op);
|
assignResults(&res, function.res_shapes, op);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
@ -531,7 +530,7 @@ pub const CompilationContext = struct {
|
|||||||
|
|
||||||
const res = ctx.self._buffer_to_arg.getOrPutAssumeCapacity(tensor._id);
|
const res = ctx.self._buffer_to_arg.getOrPutAssumeCapacity(tensor._id);
|
||||||
if (res.found_existing) {
|
if (res.found_existing) {
|
||||||
std.debug.panic("Failed compilation because received two tensors arguments with the same ID: {} and {}({}).", .{ res.key_ptr.*, tensor, tensor._id });
|
stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {} and {}({}).", .{ res.key_ptr.*, tensor, tensor._id });
|
||||||
} else {
|
} else {
|
||||||
res.value_ptr.* = .{ arg_value, .{ .arg = @intCast(ctx.index) } };
|
res.value_ptr.* = .{ arg_value, .{ .arg = @intCast(ctx.index) } };
|
||||||
}
|
}
|
||||||
@ -677,9 +676,9 @@ fn fillBuffers(v: anytype, buffers: []const [*]*pjrt.Buffer, start: u32, len: u3
|
|||||||
};
|
};
|
||||||
meta.visit((struct {
|
meta.visit((struct {
|
||||||
fn cb(ctx: *LocalContext, buffer: *const Buffer) void {
|
fn cb(ctx: *LocalContext, buffer: *const Buffer) void {
|
||||||
// meta.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index });
|
// stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index });
|
||||||
const model_sharding = ctx.buffers.len;
|
const model_sharding = ctx.buffers.len;
|
||||||
meta.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len });
|
stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len });
|
||||||
for (buffer._shards.constSlice(), 0..) |shard, d| {
|
for (buffer._shards.constSlice(), 0..) |shard, d| {
|
||||||
ctx.buffers[d][ctx.index] = shard;
|
ctx.buffers[d][ctx.index] = shard;
|
||||||
}
|
}
|
||||||
@ -718,7 +717,7 @@ pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjr
|
|||||||
buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.buffer_shapes[i], shards.constSlice());
|
buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.buffer_shapes[i], shards.constSlice());
|
||||||
}
|
}
|
||||||
}).cb, &local_ctx, v);
|
}).cb, &local_ctx, v);
|
||||||
meta.internalAssert(local_ctx.index == expected_count, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index });
|
stdx.debug.internalAssert(local_ctx.index == expected_count, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index });
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Visit the given struct and assign op results to each tensor found.
|
/// Visit the given struct and assign op results to each tensor found.
|
||||||
@ -761,6 +760,13 @@ const BaseExe = struct {
|
|||||||
num_devices: u8,
|
num_devices: u8,
|
||||||
/// Allocator backing result_buffer_shapes and deinit by ExeWithWeights
|
/// Allocator backing result_buffer_shapes and deinit by ExeWithWeights
|
||||||
_allocator: std.heap.ArenaAllocator,
|
_allocator: std.heap.ArenaAllocator,
|
||||||
|
|
||||||
|
pub fn serialize(self: BaseExe, writer: anytype) !void {
|
||||||
|
var executable = try self.exe.getExecutable(self.pjrt_api);
|
||||||
|
var serialize_result = try executable.serialize(self.platform.pjrt_api);
|
||||||
|
defer serialize_result.deinit();
|
||||||
|
try writer.writeAll(serialize_result.bytes);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Represents a ZML model, compiled into a PJRT executable.
|
/// Represents a ZML model, compiled into a PJRT executable.
|
||||||
@ -779,6 +785,16 @@ pub fn Exe(comptime func: anytype) type {
|
|||||||
pub fn prepare(self: Self, allocator: std.mem.Allocator, model: Bufferized(Signature.ModelT)) !ExeWithWeights(func) {
|
pub fn prepare(self: Self, allocator: std.mem.Allocator, model: Bufferized(Signature.ModelT)) !ExeWithWeights(func) {
|
||||||
return ExeWithWeights(func).initFromModel(allocator, self.inner, model);
|
return ExeWithWeights(func).initFromModel(allocator, self.inner, model);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn serialize(self: Self, writer: anytype) !void {
|
||||||
|
return try self.inner.serialize(writer);
|
||||||
|
}
|
||||||
|
|
||||||
|
// pub fn deserialize(allocator: std.mem.Allocator, platform: Platform, reader: anytype) !Self {
|
||||||
|
// const bytes = try reader.readToEndAlloc(allocator, max_pjrt_executable_size);
|
||||||
|
// defer allocator.free(bytes);
|
||||||
|
// return platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, bytes);
|
||||||
|
// }
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -906,7 +922,7 @@ fn compileInternal(
|
|||||||
var timer = std.time.Timer.start() catch null;
|
var timer = std.time.Timer.start() catch null;
|
||||||
const tensor_args = context.tensorFromShapes(ModuleSignature(func).ArgsT, arena, args);
|
const tensor_args = context.tensorFromShapes(ModuleSignature(func).ArgsT, arena, args);
|
||||||
// Run in a dedicated thread because compilation relies on `threadlocal`.
|
// Run in a dedicated thread because compilation relies on `threadlocal`.
|
||||||
const f = try asynk.callBlockingGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args });
|
const f = try asynk.callBlocking(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args });
|
||||||
context._module.getBody().appendOperation(f.mlir_fn);
|
context._module.getBody().appendOperation(f.mlir_fn);
|
||||||
|
|
||||||
const sharding = context._platform.sharding();
|
const sharding = context._platform.sharding();
|
||||||
@ -927,7 +943,7 @@ fn compileInternal(
|
|||||||
|
|
||||||
if (timer) |*t| {
|
if (timer) |*t| {
|
||||||
const time_ms = @divFloor(t.lap(), std.time.ns_per_ms);
|
const time_ms = @divFloor(t.lap(), std.time.ns_per_ms);
|
||||||
if (time_ms > 1000) log.info("Compilation took {d:.3}s", .{meta.divFloat(f32, time_ms, 1000)});
|
if (time_ms > 1000) log.info("Compilation took {d:.3}s", .{stdx.math.divFloor(f32, time_ms, 1000)});
|
||||||
}
|
}
|
||||||
|
|
||||||
var arena_state_exe = std.heap.ArenaAllocator.init(allocator);
|
var arena_state_exe = std.heap.ArenaAllocator.init(allocator);
|
||||||
@ -945,12 +961,7 @@ fn compileInternal(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compiles a Model struct with the given configuration and shapes, for the given platform.
|
pub fn load(
|
||||||
/// The steps are:
|
|
||||||
/// * lookup at tensors available in the store and create a `model: Model` struct with them
|
|
||||||
/// * call `model.init(init_args)` to fields of the model that aren't Tensor, ie hyperparemeters/config
|
|
||||||
/// * generate MLIR by calling `model.forward` with tensor of the given shapes and other arguments
|
|
||||||
pub fn compile(
|
|
||||||
allocator: std.mem.Allocator,
|
allocator: std.mem.Allocator,
|
||||||
comptime Model: type,
|
comptime Model: type,
|
||||||
init_args: anytype,
|
init_args: anytype,
|
||||||
@ -973,33 +984,62 @@ pub fn compile(
|
|||||||
return compileModel(allocator, model, func, args_shapes, platform);
|
return compileModel(allocator, model, func, args_shapes, platform);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Compiles a Model struct with the given configuration and shapes, for the given platform.
|
||||||
|
/// The steps are:
|
||||||
|
/// * lookup at tensors available in the store and create a `model: Model` struct with them
|
||||||
|
/// * call `model.init(init_args)` to fields of the model that aren't Tensor, ie hyperparemeters/config
|
||||||
|
/// * generate MLIR by calling `model.forward` with tensor of the given shapes and other arguments
|
||||||
|
pub fn compile(
|
||||||
|
allocator: std.mem.Allocator,
|
||||||
|
comptime func: anytype,
|
||||||
|
init_args: anytype,
|
||||||
|
args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
|
||||||
|
buffer_store: aio.BufferStore,
|
||||||
|
platform: Platform,
|
||||||
|
) !Exe(func) {
|
||||||
|
const ModelT = ModuleSignature(func).ModelT;
|
||||||
|
|
||||||
|
var arena_state = std.heap.ArenaAllocator.init(allocator);
|
||||||
|
defer arena_state.deinit();
|
||||||
|
const arena = arena_state.allocator();
|
||||||
|
var model = try aio.populateModel(ModelT, arena, buffer_store);
|
||||||
|
|
||||||
|
// If the Model has a "init" function, call it with the given parameters.
|
||||||
|
if (@hasDecl(ModelT, "init")) {
|
||||||
|
// TODO(Corentin,@Improvement): Add a warning/error if there is no init function but init_args is non-void.
|
||||||
|
@call(.auto, ModelT.init, .{@as(*ModelT, &model)} ++ init_args);
|
||||||
|
}
|
||||||
|
|
||||||
|
return compileModel(allocator, func, model, args_shapes, platform);
|
||||||
|
}
|
||||||
|
|
||||||
/// Compiles a Model struct with the given configuration and shapes, for the given platform.
|
/// Compiles a Model struct with the given configuration and shapes, for the given platform.
|
||||||
/// Generate MLIR by calling `model.forward` with tensor of the given shapes and other arguments
|
/// Generate MLIR by calling `model.forward` with tensor of the given shapes and other arguments
|
||||||
pub fn compileModel(
|
pub fn compileModel(
|
||||||
allocator: std.mem.Allocator,
|
allocator: std.mem.Allocator,
|
||||||
model: anytype,
|
comptime func: anytype,
|
||||||
comptime func: @TypeOf(.literal),
|
model: ModuleSignature(func).ModelT,
|
||||||
args_shapes: ShapeOf(ModuleSignature(@field(@TypeOf(model), @tagName(func))).ArgsT),
|
args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
|
||||||
platform: Platform,
|
platform: Platform,
|
||||||
) !Exe(@field(@TypeOf(model), @tagName(func))) {
|
) !Exe(func) {
|
||||||
const Model = @TypeOf(model);
|
const ModelT = ModuleSignature(func).ModelT;
|
||||||
const name = @typeName(Model) ++ "." ++ @tagName(func);
|
const name = @typeName(ModelT) ++ ".forward";
|
||||||
log.info("Compiling {s} with {}", .{ name, args_shapes });
|
log.info("Compiling {s} with {}", .{ name, args_shapes });
|
||||||
|
|
||||||
var context = try CompilationContext.init(allocator, name, platform);
|
var context = try CompilationContext.init(allocator, name, platform);
|
||||||
defer context.deinit();
|
defer context.deinit();
|
||||||
|
|
||||||
const raw_module = try compileInternal(allocator, &context, @field(Model, @tagName(func)), model, args_shapes);
|
const raw_module = try compileInternal(allocator, &context, func, model, args_shapes);
|
||||||
|
|
||||||
return Exe(@field(Model, @tagName(func))){ .inner = raw_module };
|
return .{ .inner = raw_module };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compiles a function with the given configuration and shapes, for the given platform.
|
/// Compiles a function with the given configuration and shapes, for the given platform.
|
||||||
/// Generate MLIR by calling the given function with tensor of the given shapes.
|
/// Generate MLIR by calling the given function with tensor of the given shapes.
|
||||||
pub fn compileFn(
|
pub fn compileFn(
|
||||||
allocator: std.mem.Allocator,
|
allocator: std.mem.Allocator,
|
||||||
func: anytype,
|
comptime func: anytype,
|
||||||
args: ShapeOf(meta.FnParams(func)),
|
args: ShapeOf(stdx.meta.FnArgs(func)),
|
||||||
platform: Platform,
|
platform: Platform,
|
||||||
) !ExeWithWeights(FnWithVoidArg(func)) {
|
) !ExeWithWeights(FnWithVoidArg(func)) {
|
||||||
const name = @typeName(@TypeOf(func));
|
const name = @typeName(@TypeOf(func));
|
||||||
@ -1008,7 +1048,7 @@ pub fn compileFn(
|
|||||||
|
|
||||||
const Local = struct {
|
const Local = struct {
|
||||||
// This is the function we will actually compile.
|
// This is the function we will actually compile.
|
||||||
pub fn forward(_: void, inner_args: meta.FnParams(func)) meta.FnResult(func) {
|
pub fn forward(_: void, inner_args: stdx.meta.FnArgs(func)) stdx.meta.FnResult(func) {
|
||||||
return @call(.auto, func, inner_args);
|
return @call(.auto, func, inner_args);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -1019,10 +1059,10 @@ pub fn compileFn(
|
|||||||
return try ExeWithWeights(FnWithVoidArg(func)).initFromModel(allocator, raw_module, void_model);
|
return try ExeWithWeights(FnWithVoidArg(func)).initFromModel(allocator, raw_module, void_model);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn FnWithVoidArg(func: anytype) type {
|
fn FnWithVoidArg(comptime func: anytype) type {
|
||||||
const fn_info = @typeInfo(@TypeOf(func)).Fn;
|
const fn_info = @typeInfo(@TypeOf(func)).Fn;
|
||||||
const void_param = std.builtin.Type.Fn.Param{ .is_generic = false, .is_noalias = false, .type = void };
|
const void_param = std.builtin.Type.Fn.Param{ .is_generic = false, .is_noalias = false, .type = void };
|
||||||
meta.assertComptime(!fn_info.is_generic, "Can't do reflection on generic function: {}", .{@TypeOf(func)});
|
stdx.debug.assertComptime(!fn_info.is_generic, "Can't do reflection on generic function: {}", .{@TypeOf(func)});
|
||||||
return @Type(.{ .Fn = .{
|
return @Type(.{ .Fn = .{
|
||||||
.calling_convention = fn_info.calling_convention,
|
.calling_convention = fn_info.calling_convention,
|
||||||
.is_generic = false,
|
.is_generic = false,
|
||||||
@ -1268,10 +1308,10 @@ pub fn ModuleSignature(comptime func: anytype) Sign {
|
|||||||
const FuncT = if (@TypeOf(func) == type) func else @TypeOf(func);
|
const FuncT = if (@TypeOf(func) == type) func else @TypeOf(func);
|
||||||
return .{
|
return .{
|
||||||
.FuncT = FuncT,
|
.FuncT = FuncT,
|
||||||
.ModelT = @typeInfo(FuncT).Fn.params[0].type orelse @compileError("cannot create,ModuleSignature for function with an 'anytype' parameter"),
|
.ModelT = @typeInfo(FuncT).Fn.params[0].type orelse @compileError("cannot create ModuleSignature for function with an 'anytype' parameter"),
|
||||||
.ArgsT = blk: {
|
.ArgsT = blk: {
|
||||||
const function_info = @typeInfo(FuncT);
|
const function_info = @typeInfo(FuncT);
|
||||||
if (function_info.Fn.params[1..].len == 0) {
|
if (function_info.Fn.params.len < 2) {
|
||||||
break :blk @TypeOf(.{});
|
break :blk @TypeOf(.{});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
50
zml/nn.zig
50
zml/nn.zig
@ -1,20 +1,20 @@
|
|||||||
//! Common layer definition and functions for Neural Networks (NN)
|
//! Common layer definition and functions for Neural Networks (NN)
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const assert = std.debug.assert;
|
const stdx = @import("stdx");
|
||||||
const testing = std.testing;
|
|
||||||
|
|
||||||
const zml = @import("zml.zig");
|
const cuda = @import("nn/cuda.zig");
|
||||||
const meta = @import("meta.zig");
|
|
||||||
const helpers = @import("helpers.zig");
|
const helpers = @import("helpers.zig");
|
||||||
|
const meta = @import("meta.zig");
|
||||||
const ops = @import("ops.zig");
|
const ops = @import("ops.zig");
|
||||||
|
const zml = @import("zml.zig");
|
||||||
|
|
||||||
const DataType = @import("dtype.zig").DataType;
|
const DataType = @import("dtype.zig").DataType;
|
||||||
const Shape = @import("shape.zig").Shape;
|
const Shape = @import("shape.zig").Shape;
|
||||||
const Tensor = @import("tensor.zig").Tensor;
|
const Tensor = @import("tensor.zig").Tensor;
|
||||||
|
|
||||||
const log = std.log.scoped(.zml_tensor);
|
const assert = std.debug.assert;
|
||||||
|
const log = std.log.scoped(.@"zml/tensor");
|
||||||
const cuda = @import("nn/cuda.zig");
|
const testing = std.testing;
|
||||||
|
|
||||||
test {
|
test {
|
||||||
_ = cuda;
|
_ = cuda;
|
||||||
@ -41,8 +41,8 @@ pub const TokenEmbedding = struct {
|
|||||||
weight: Tensor,
|
weight: Tensor,
|
||||||
|
|
||||||
pub fn forward(self: TokenEmbedding, idx: Tensor) Tensor {
|
pub fn forward(self: TokenEmbedding, idx: Tensor) Tensor {
|
||||||
meta.assert(idx.dtype().isInteger(), "TokenEmbedding expects an integer input, received: {}", .{idx});
|
stdx.debug.assert(idx.dtype().isInteger(), "TokenEmbedding expects an integer input, received: {}", .{idx});
|
||||||
meta.assert(self.weight.rank() == 2, "TokenEmbedding expects it's weight Tensor to be a 2D matrix, got {}", .{self.weight});
|
stdx.debug.assert(self.weight.rank() == 2, "TokenEmbedding expects it's weight Tensor to be a 2D matrix, got {}", .{self.weight});
|
||||||
return self.weight.gatherValues(0, idx, .{});
|
return self.weight.gatherValues(0, idx, .{});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -159,7 +159,7 @@ pub const CosSin = [2]Tensor;
|
|||||||
/// See: https://paperswithcode.com/method/rope
|
/// See: https://paperswithcode.com/method/rope
|
||||||
pub fn rope(x: Tensor, cos_sin_cache: CosSin, opts: RopeOpts) Tensor {
|
pub fn rope(x: Tensor, cos_sin_cache: CosSin, opts: RopeOpts) Tensor {
|
||||||
const cos, const sin = cos_sin_cache;
|
const cos, const sin = cos_sin_cache;
|
||||||
meta.assert(x.dim(-1) == 2 * cos.dim(-1), "Couldn't compute rope({}, {}, {})", .{ x, cos, sin });
|
stdx.debug.assert(x.dim(-1) == 2 * cos.dim(-1), "Couldn't compute rope({}, {}, {})", .{ x, cos, sin });
|
||||||
// broadcast cos / sin to .{ batch, .seq, .half_dim }
|
// broadcast cos / sin to .{ batch, .seq, .half_dim }
|
||||||
const x_real, const x_imag = splitRealImg(x, opts.impl);
|
const x_real, const x_imag = splitRealImg(x, opts.impl);
|
||||||
const has_tags = cos.shape().tag(0) != Shape.TagUnknown;
|
const has_tags = cos.shape().tag(0) != Shape.TagUnknown;
|
||||||
@ -178,9 +178,9 @@ pub fn rope(x: Tensor, cos_sin_cache: CosSin, opts: RopeOpts) Tensor {
|
|||||||
|
|
||||||
pub fn ropeCosSin(sh: anytype, dtype: DataType, opts: RopeOpts) CosSin {
|
pub fn ropeCosSin(sh: anytype, dtype: DataType, opts: RopeOpts) CosSin {
|
||||||
const shape = Shape.init(sh, dtype);
|
const shape = Shape.init(sh, dtype);
|
||||||
meta.assert(shape.rank() == 2, "ropeCosSin({}) shape need to exactly have 2 axes", .{shape});
|
stdx.debug.assert(shape.rank() == 2, "ropeCosSin({}) shape need to exactly have 2 axes", .{shape});
|
||||||
const seq_len, const head_dim = .{ shape.dim(0), shape.dim(1) };
|
const seq_len, const head_dim = .{ shape.dim(0), shape.dim(1) };
|
||||||
meta.assert(@mod(head_dim, 2) == 0, "ropeCosSin requires an even head_dim, got {}", .{head_dim});
|
stdx.debug.assert(@mod(head_dim, 2) == 0, "ropeCosSin requires an even head_dim, got {}", .{head_dim});
|
||||||
|
|
||||||
// compute sin and cos in f32 before downcasting to x type.
|
// compute sin and cos in f32 before downcasting to x type.
|
||||||
const inv_freq = invFreq(head_dim, opts.freq_base, .f32);
|
const inv_freq = invFreq(head_dim, opts.freq_base, .f32);
|
||||||
@ -364,8 +364,8 @@ pub fn upsample(
|
|||||||
) Tensor {
|
) Tensor {
|
||||||
// TODO(james): make `nearest` compatible with resizeBilinear and resizeBicubic, and wrap them here.
|
// TODO(james): make `nearest` compatible with resizeBilinear and resizeBicubic, and wrap them here.
|
||||||
// resize* have API which are more explicit, this assume you want to scale the N-2 last axes.
|
// resize* have API which are more explicit, this assume you want to scale the N-2 last axes.
|
||||||
meta.assert(3 <= input.rank() and input.rank() <= 5, "upsample is only implemented for (3,4,5)-D tensors, received {}", .{input});
|
stdx.debug.assert(3 <= input.rank() and input.rank() <= 5, "upsample is only implemented for (3,4,5)-D tensors, received {}", .{input});
|
||||||
meta.assert(opts.scale_factor.len == 1 or opts.scale_factor.len == input.rank() - 2, "scale factors", .{});
|
stdx.debug.assert(opts.scale_factor.len == 1 or opts.scale_factor.len == input.rank() - 2, "scale factors", .{});
|
||||||
return switch (opts.mode) {
|
return switch (opts.mode) {
|
||||||
.nearest => {
|
.nearest => {
|
||||||
var scale_factors: [3]f64 = undefined;
|
var scale_factors: [3]f64 = undefined;
|
||||||
@ -398,7 +398,7 @@ pub fn nearest(input: Tensor, scale_factor: []const f64) Tensor {
|
|||||||
var res = input;
|
var res = input;
|
||||||
for (spatial_dims) |d| {
|
for (spatial_dims) |d| {
|
||||||
const n = out_shape.dim(d);
|
const n = out_shape.dim(d);
|
||||||
const ratio = meta.divFloat(f32, input.dim(d), n);
|
const ratio = stdx.math.divFloor(f32, input.dim(d), n);
|
||||||
const offsets = Tensor.arange(.{ .end = n }, .f32).addConstant(0.5).scale(ratio).floor().convert(.i32);
|
const offsets = Tensor.arange(.{ .end = n }, .f32).addConstant(0.5).scale(ratio).floor().convert(.i32);
|
||||||
res = res.gatherValues(d, offsets, .{ .indices_are_sorted = true });
|
res = res.gatherValues(d, offsets, .{ .indices_are_sorted = true });
|
||||||
}
|
}
|
||||||
@ -576,7 +576,7 @@ pub fn resizeLinear1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Te
|
|||||||
|
|
||||||
const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype();
|
const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype();
|
||||||
const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype);
|
const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype);
|
||||||
const ratio = og_len.convert(dtype).scale(meta.divFloat(f32, 1, new_len));
|
const ratio = og_len.convert(dtype).scale(stdx.math.divFloor(f32, 1, new_len));
|
||||||
const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio);
|
const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio);
|
||||||
const left = scaled.floor();
|
const left = scaled.floor();
|
||||||
const right = left.addConstant(1);
|
const right = left.addConstant(1);
|
||||||
@ -638,7 +638,7 @@ pub fn resizeCubic1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Ten
|
|||||||
const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype();
|
const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype();
|
||||||
const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype);
|
const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype);
|
||||||
|
|
||||||
const ratio = og_len.convert(dtype).scale(meta.divFloat(f32, 1, new_len));
|
const ratio = og_len.convert(dtype).scale(stdx.math.divFloor(f32, 1, new_len));
|
||||||
const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio);
|
const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio);
|
||||||
const t = scaled.sub(scaled.floor());
|
const t = scaled.sub(scaled.floor());
|
||||||
const pos = Tensor.stack(&.{
|
const pos = Tensor.stack(&.{
|
||||||
@ -693,11 +693,11 @@ pub fn causalAttnMask(
|
|||||||
attn_window_len: ?u32,
|
attn_window_len: ?u32,
|
||||||
) Tensor {
|
) Tensor {
|
||||||
const attn_shape = Shape.init(attn_shape_, dtype);
|
const attn_shape = Shape.init(attn_shape_, dtype);
|
||||||
meta.assert(attn_shape.rank() == 2, "causalAttnMask({}) shape need to be exactly 2 axes", .{attn_shape});
|
stdx.debug.assert(attn_shape.rank() == 2, "causalAttnMask({}) shape need to be exactly 2 axes", .{attn_shape});
|
||||||
const qlen = attn_shape.dim(-2);
|
const qlen = attn_shape.dim(-2);
|
||||||
const q_idx = Tensor.iota(attn_shape, .i32, -2);
|
const q_idx = Tensor.iota(attn_shape, -2);
|
||||||
const klen = attn_shape.dim(-1);
|
const klen = attn_shape.dim(-1);
|
||||||
const k_idx = Tensor.iota(attn_shape, .i32, -1);
|
const k_idx = Tensor.iota(attn_shape, -1);
|
||||||
|
|
||||||
// all elements > main diagonal must be 0
|
// all elements > main diagonal must be 0
|
||||||
// (q_idx - window_len < k_idx <= q_idx)
|
// (q_idx - window_len < k_idx <= q_idx)
|
||||||
@ -748,16 +748,16 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
|||||||
|
|
||||||
const err_template = "sdpa(q: {}, k: {}, v: {}, attn: {?}) is invalid ! ";
|
const err_template = "sdpa(q: {}, k: {}, v: {}, attn: {?}) is invalid ! ";
|
||||||
const err_args = .{ q, k, v, opts.attn_mask };
|
const err_args = .{ q, k, v, opts.attn_mask };
|
||||||
meta.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ "q is missing tags {{.h, .q, .hd}}", err_args);
|
stdx.debug.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ "q is missing tags {{.h, .q, .hd}}", err_args);
|
||||||
meta.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args);
|
stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args);
|
||||||
meta.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "v is missing tags {{.h, .k, .hd}}", err_args);
|
stdx.debug.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "v is missing tags {{.h, .k, .hd}}", err_args);
|
||||||
|
|
||||||
if (opts.allow_cudnn and cuda.canUseCudnnSdpa(q.dim(.hd), q.dtype())) {
|
if (opts.allow_cudnn and cuda.canUseCudnnSdpa(q.dim(.hd), q.dtype())) {
|
||||||
return cuda.sdpa(q, k, v, opts);
|
return cuda.sdpa(q, k, v, opts);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (q.dim(.h) != k.dim(.h)) {
|
if (q.dim(.h) != k.dim(.h)) {
|
||||||
meta.assert(@mod(q.dim(.h), k.dim(.h)) == 0, err_template ++ "Different number of heads for keys and queries, but can't repeat keys.", err_args);
|
stdx.debug.assert(@mod(q.dim(.h), k.dim(.h)) == 0, err_template ++ "Different number of heads for keys and queries, but can't repeat keys.", err_args);
|
||||||
// Note: we don't try to repeat queries.
|
// Note: we don't try to repeat queries.
|
||||||
// Repeating keys is the interesting optimisation cause it reduces KV cache memory usage.
|
// Repeating keys is the interesting optimisation cause it reduces KV cache memory usage.
|
||||||
const num_rep: u63 = @intCast(@divExact(q.dim(.h), k.dim(.h)));
|
const num_rep: u63 = @intCast(@divExact(q.dim(.h), k.dim(.h)));
|
||||||
@ -766,7 +766,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
|||||||
const attn_mask = if (opts.attn_mask) |m| m else null;
|
const attn_mask = if (opts.attn_mask) |m| m else null;
|
||||||
|
|
||||||
const dims = helpers.collectDims(.{ .h, .q, .k, .hd }, &.{ q, k, v, attn_mask }, .strict) catch {
|
const dims = helpers.collectDims(.{ .h, .q, .k, .hd }, &.{ q, k, v, attn_mask }, .strict) catch {
|
||||||
meta.panic(err_template ++ "Inputs have incompatible shapes.", err_args);
|
stdx.debug.panic(err_template ++ "Inputs have incompatible shapes.", err_args);
|
||||||
};
|
};
|
||||||
const sqrtHeadDim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(dims.hd)));
|
const sqrtHeadDim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(dims.hd)));
|
||||||
const scale_logit = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, k.dtype());
|
const scale_logit = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, k.dtype());
|
||||||
|
|||||||
28
zml/ops.zig
28
zml/ops.zig
@ -1,11 +1,13 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const mlir = @import("mlir.zig");
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
|
const buffer = @import("buffer.zig");
|
||||||
const helpers = @import("helpers.zig");
|
const helpers = @import("helpers.zig");
|
||||||
const module = @import("module.zig");
|
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
|
const mlir = @import("mlir.zig");
|
||||||
|
const module = @import("module.zig");
|
||||||
|
|
||||||
const Buffer = @import("buffer.zig").Buffer;
|
const Buffer = buffer.Buffer;
|
||||||
const CompilationContext = module.CompilationContext;
|
const CompilationContext = module.CompilationContext;
|
||||||
const Context = @import("context.zig").Context;
|
const Context = @import("context.zig").Context;
|
||||||
const Data = @import("dtype.zig").Data;
|
const Data = @import("dtype.zig").Data;
|
||||||
@ -20,14 +22,14 @@ const dialect = struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const assert = std.debug.assert;
|
const assert = std.debug.assert;
|
||||||
const log = std.log.scoped(.zml);
|
const log = std.log.scoped(.@"zml/tensor");
|
||||||
|
|
||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate an MLIR call to the given member function with the given tensors.
|
/// Generate an MLIR call to the given member function with the given tensors.
|
||||||
pub fn call(self: anytype, comptime func: meta.DeclEnum(@TypeOf(self)), args: anytype) @TypeOf(@call(.auto, @field(meta.UnwrapPtr(@TypeOf(self)), @tagName(func)), .{self} ++ args)) {
|
pub fn call(self: anytype, comptime func: stdx.meta.DeclEnum(@TypeOf(self)), args: anytype) @TypeOf(@call(.auto, @field(stdx.meta.UnwrapPtr(@TypeOf(self)), @tagName(func)), .{self} ++ args)) {
|
||||||
// TODO: this should use `self.getContext().callFunc(self, args)`
|
// TODO: this should use `self.getContext().callFunc(self, args)`
|
||||||
|
|
||||||
return @call(.auto, @field(@TypeOf(self), @tagName(func)), .{self} ++ args);
|
return @call(.auto, @field(@TypeOf(self), @tagName(func)), .{self} ++ args);
|
||||||
@ -121,8 +123,8 @@ test "simple while" {
|
|||||||
|
|
||||||
pub fn reduce(
|
pub fn reduce(
|
||||||
comptime body_fn: anytype,
|
comptime body_fn: anytype,
|
||||||
inputs: meta.FnParam(body_fn, 0),
|
inputs: stdx.meta.FnParam(body_fn, 0),
|
||||||
inits: meta.FnParam(body_fn, 0),
|
inits: stdx.meta.FnParam(body_fn, 0),
|
||||||
axes: []const i64,
|
axes: []const i64,
|
||||||
) BlockSignNoCtx(body_fn).Return {
|
) BlockSignNoCtx(body_fn).Return {
|
||||||
// TODO: actualAxes
|
// TODO: actualAxes
|
||||||
@ -155,7 +157,7 @@ pub fn reduce(
|
|||||||
|
|
||||||
// `stablehlo.reduce` drops axes. We want to avoid that to propagate tags.
|
// `stablehlo.reduce` drops axes. We want to avoid that to propagate tags.
|
||||||
// So we need to broadcast the output of `stablehlo.reduce` to the input shapes.
|
// So we need to broadcast the output of `stablehlo.reduce` to the input shapes.
|
||||||
// To that order, we initialize `result` to `inputs`, then we use meta.visit,
|
// To that order, we initialize `result` to `inputs`, then we use stdx.meta.visit,
|
||||||
// to find the correct mlir.Value, but we first broadcast before creating the final
|
// to find the correct mlir.Value, but we first broadcast before creating the final
|
||||||
// Tensor struct.
|
// Tensor struct.
|
||||||
var broadcasting_axes: std.BoundedArray(i64, Tensor.MAX_RANK) = .{};
|
var broadcasting_axes: std.BoundedArray(i64, Tensor.MAX_RANK) = .{};
|
||||||
@ -217,10 +219,10 @@ pub const ReduceWindowOpts = struct {
|
|||||||
|
|
||||||
pub fn reduceWindow(
|
pub fn reduceWindow(
|
||||||
comptime body_fn: anytype,
|
comptime body_fn: anytype,
|
||||||
inputs: meta.FnParam(body_fn, 0),
|
inputs: stdx.meta.FnParam(body_fn, 0),
|
||||||
inits: meta.FnParam(body_fn, 0),
|
inits: stdx.meta.FnParam(body_fn, 0),
|
||||||
opts: ReduceWindowOpts,
|
opts: ReduceWindowOpts,
|
||||||
) meta.FnResult(body_fn) {
|
) stdx.meta.FnResult(body_fn) {
|
||||||
const BodyS = comptime BlockSignNoCtx(body_fn);
|
const BodyS = comptime BlockSignNoCtx(body_fn);
|
||||||
comptime {
|
comptime {
|
||||||
if (BodyS.Return != @TypeOf(inputs)) @compileError("reduce body function need to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn));
|
if (BodyS.Return != @TypeOf(inputs)) @compileError("reduce body function need to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn));
|
||||||
@ -263,7 +265,7 @@ pub fn reduceWindow(
|
|||||||
pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_: anytype) BlockSign(func).Return {
|
pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_: anytype) BlockSign(func).Return {
|
||||||
const num_steps: u32, const step_tag = blk: {
|
const num_steps: u32, const step_tag = blk: {
|
||||||
const dims, const tags = Shape.parseDimensions(num_steps_);
|
const dims, const tags = Shape.parseDimensions(num_steps_);
|
||||||
meta.assert(dims.len == 1, "zml.for_ only supports one num_step, Received: {any}", .{num_steps_});
|
stdx.debug.assert(dims.len == 1, "zml.for_ only supports one num_step, Received: {any}", .{num_steps_});
|
||||||
break :blk .{ @intCast(dims.get(0)), tags.get(0) };
|
break :blk .{ @intCast(dims.get(0)), tags.get(0) };
|
||||||
};
|
};
|
||||||
const S = comptime BlockSign(func);
|
const S = comptime BlockSign(func);
|
||||||
@ -290,7 +292,7 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_:
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn updateResBuffer(inputs: []const Tensor, idx: Tensor) Tensor {
|
fn updateResBuffer(inputs: []const Tensor, idx: Tensor) Tensor {
|
||||||
meta.internalAssert(inputs.len == 2, "too many tensors", .{});
|
stdx.debug.internalAssert(inputs.len == 2, "too many tensors", .{});
|
||||||
const res, const step_res = inputs[0..2].*;
|
const res, const step_res = inputs[0..2].*;
|
||||||
return res.dynamicUpdateSlice1d(step_res.insertAxes(0, .{._}), 0, idx);
|
return res.dynamicUpdateSlice1d(step_res.insertAxes(0, .{._}), 0, idx);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,22 +1,21 @@
|
|||||||
const builtin = @import("builtin");
|
|
||||||
const std = @import("std");
|
|
||||||
|
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
|
const builtin = @import("builtin");
|
||||||
|
const dialects = @import("mlir/dialects");
|
||||||
const mlir = @import("mlir");
|
const mlir = @import("mlir");
|
||||||
|
|
||||||
const pjrt = @import("pjrt");
|
const pjrt = @import("pjrt");
|
||||||
|
const std = @import("std");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const dtype = @import("dtype.zig");
|
const dtype = @import("dtype.zig");
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
const dialects = @import("mlir/dialects");
|
|
||||||
|
|
||||||
pub const Profiler = pjrt.Profiler;
|
|
||||||
pub const ApiError = pjrt.ApiError;
|
|
||||||
pub const ErrorCode = pjrt.ErrorCode;
|
|
||||||
|
|
||||||
const Target = @import("platform.zig").Target;
|
const Target = @import("platform.zig").Target;
|
||||||
|
|
||||||
const log = std.log.scoped(.zml);
|
const log = std.log.scoped(.zml);
|
||||||
|
|
||||||
|
pub const Profiler = pjrt.Profiler;
|
||||||
|
pub const ApiError = pjrt.ApiError;
|
||||||
|
pub const ErrorCode = pjrt.ErrorCode;
|
||||||
pub const Buffer = pjrt.Buffer;
|
pub const Buffer = pjrt.Buffer;
|
||||||
pub const BufferType = pjrt.BufferType;
|
pub const BufferType = pjrt.BufferType;
|
||||||
pub const Device = pjrt.Device;
|
pub const Device = pjrt.Device;
|
||||||
@ -181,14 +180,16 @@ pub const LoadedExecutable = opaque {
|
|||||||
return self.inner().getAddressableDevices(api);
|
return self.inner().getAddressableDevices(api);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: struct {
|
pub const ExecuteArgs = struct {
|
||||||
arguments: []const [*]const *const Buffer,
|
arguments: []const [*]const *const Buffer,
|
||||||
num_args: usize,
|
num_args: usize,
|
||||||
results: []const [*]*Buffer,
|
results: []const [*]*Buffer,
|
||||||
events: []?*Event,
|
events: []?*Event,
|
||||||
non_donatable_input_indices: []const i64 = &.{},
|
non_donatable_input_indices: []const i64 = &.{},
|
||||||
}) ExecuteError!void {
|
};
|
||||||
try asynk.callBlocking(pjrt.LoadedExecutable.execute, .{ self.inner(), api, .{
|
|
||||||
|
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ExecuteError!void {
|
||||||
|
try asynk.callBlocking(pjrt.LoadedExecutable.execute, .{ self.inner(), api, pjrt.LoadedExecutable.ExecuteArgs{
|
||||||
.num_args = args.num_args,
|
.num_args = args.num_args,
|
||||||
.arguments = @ptrCast(args.arguments),
|
.arguments = @ptrCast(args.arguments),
|
||||||
.results = @ptrCast(args.results),
|
.results = @ptrCast(args.results),
|
||||||
|
|||||||
@ -1,28 +1,18 @@
|
|||||||
const builtin = @import("builtin");
|
|
||||||
const std = @import("std");
|
|
||||||
|
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
|
const builtin = @import("builtin");
|
||||||
const runtimes = @import("runtimes");
|
const runtimes = @import("runtimes");
|
||||||
|
const std = @import("std");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
const module = @import("module.zig");
|
const module = @import("module.zig");
|
||||||
const pjrt = @import("pjrtx.zig");
|
const pjrt = @import("pjrtx.zig");
|
||||||
|
|
||||||
const log = std.log.scoped(.zml);
|
const log = std.log.scoped(.zml);
|
||||||
|
|
||||||
pub const Target = runtimes.Platform;
|
pub const Target = runtimes.Platform;
|
||||||
|
|
||||||
pub const available_targets = switch (builtin.os.tag) {
|
pub const available_targets = std.enums.values(Target);
|
||||||
.macos => [_]Target{
|
|
||||||
.cpu,
|
|
||||||
},
|
|
||||||
.linux => [_]Target{
|
|
||||||
.cpu,
|
|
||||||
.cuda,
|
|
||||||
.rocm,
|
|
||||||
.tpu,
|
|
||||||
},
|
|
||||||
else => [_]Target{},
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const CompilationOptions = struct {
|
pub const CompilationOptions = struct {
|
||||||
xla_dump_to: ?[]const u8 = null,
|
xla_dump_to: ?[]const u8 = null,
|
||||||
|
|||||||
113
zml/shape.zig
113
zml/shape.zig
@ -1,11 +1,12 @@
|
|||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const testing = std.testing;
|
const testing = std.testing;
|
||||||
|
|
||||||
const meta = @import("meta.zig");
|
|
||||||
const DataType = @import("dtype.zig").DataType;
|
const DataType = @import("dtype.zig").DataType;
|
||||||
|
|
||||||
const EnumLiteral = @TypeOf(.enum_literal);
|
const EnumLiteral = @TypeOf(.enum_literal);
|
||||||
|
|
||||||
const log = std.log.scoped(.shape);
|
const log = std.log.scoped(.shape);
|
||||||
|
|
||||||
test {
|
test {
|
||||||
@ -39,7 +40,7 @@ pub const Shape = struct {
|
|||||||
return .{ v._dims, v._tags };
|
return .{ v._dims, v._tags };
|
||||||
}
|
}
|
||||||
|
|
||||||
if (comptime meta.isSliceOfAny(T, meta.isInteger)) {
|
if (comptime stdx.meta.isSliceOfAny(T, stdx.meta.isInteger)) {
|
||||||
var dims_ = DimsArray.init(0) catch unreachable;
|
var dims_ = DimsArray.init(0) catch unreachable;
|
||||||
var tags_ = TagsArray.init(0) catch unreachable;
|
var tags_ = TagsArray.init(0) catch unreachable;
|
||||||
for (v) |d| {
|
for (v) |d| {
|
||||||
@ -49,19 +50,19 @@ pub const Shape = struct {
|
|||||||
return .{ dims_, tags_ };
|
return .{ dims_, tags_ };
|
||||||
}
|
}
|
||||||
|
|
||||||
if (comptime meta.isStruct(T)) {
|
if (comptime stdx.meta.isStruct(T)) {
|
||||||
var dims_: DimsArray = .{};
|
var dims_: DimsArray = .{};
|
||||||
var tags_: TagsArray = .{};
|
var tags_: TagsArray = .{};
|
||||||
inline for (std.meta.fields(T)) |field| {
|
inline for (std.meta.fields(T)) |field| {
|
||||||
const fv = @field(v, field.name);
|
const fv = @field(v, field.name);
|
||||||
if (comptime meta.isInteger(field.type)) {
|
if (comptime stdx.meta.isInteger(field.type)) {
|
||||||
dims_.appendAssumeCapacity(@intCast(fv));
|
dims_.appendAssumeCapacity(@intCast(fv));
|
||||||
} else if (comptime isAutoDim(fv)) {
|
} else if (comptime isAutoDim(fv)) {
|
||||||
dims_.appendAssumeCapacity(-1);
|
dims_.appendAssumeCapacity(-1);
|
||||||
} else {
|
} else {
|
||||||
meta.compileError("Field {s} should be an integer or an auto dimension", .{field.name});
|
stdx.meta.compileError("Field {s} should be an integer or an auto dimension", .{field.name});
|
||||||
}
|
}
|
||||||
if (comptime meta.isTuple(T)) {
|
if (comptime stdx.meta.isTuple(T)) {
|
||||||
tags_.appendAssumeCapacity(TagUnknown);
|
tags_.appendAssumeCapacity(TagUnknown);
|
||||||
} else {
|
} else {
|
||||||
tags_.appendAssumeCapacity(toTag(field));
|
tags_.appendAssumeCapacity(toTag(field));
|
||||||
@ -71,7 +72,7 @@ pub const Shape = struct {
|
|||||||
return .{ dims_, tags_ };
|
return .{ dims_, tags_ };
|
||||||
}
|
}
|
||||||
|
|
||||||
meta.compileError("expected a dimension tuple eg '.{{ .a = 10, .b = 20}}' or '.{{ 10, 20 }}', got {}", .{T});
|
stdx.meta.compileError("expected a dimension tuple eg '.{{ .a = 10, .b = 20}}' or '.{{ 10, 20 }}', got {}", .{T});
|
||||||
}
|
}
|
||||||
|
|
||||||
test parseDimensions {
|
test parseDimensions {
|
||||||
@ -92,7 +93,7 @@ pub const Shape = struct {
|
|||||||
var axes_ = AxesArray.init(0) catch unreachable;
|
var axes_ = AxesArray.init(0) catch unreachable;
|
||||||
var tags_ = TagsArray.init(0) catch unreachable;
|
var tags_ = TagsArray.init(0) catch unreachable;
|
||||||
|
|
||||||
if (comptime meta.isSliceOfAny(T, isAxisConvertible)) {
|
if (comptime stdx.meta.isSliceOfAny(T, isAxisConvertible)) {
|
||||||
for (v) |d| {
|
for (v) |d| {
|
||||||
axes_.appendAssumeCapacity(self.axis(d));
|
axes_.appendAssumeCapacity(self.axis(d));
|
||||||
tags_.appendAssumeCapacity(self.tag(d));
|
tags_.appendAssumeCapacity(self.tag(d));
|
||||||
@ -100,7 +101,7 @@ pub const Shape = struct {
|
|||||||
return .{ axes_, tags_ };
|
return .{ axes_, tags_ };
|
||||||
}
|
}
|
||||||
|
|
||||||
if (comptime meta.isTupleOfAny(T, isAxisConvertible)) {
|
if (comptime stdx.meta.isTupleOfAny(T, isAxisConvertible)) {
|
||||||
inline for (std.meta.fields(T)) |field| {
|
inline for (std.meta.fields(T)) |field| {
|
||||||
axes_.appendAssumeCapacity(self.axis(@field(v, field.name)));
|
axes_.appendAssumeCapacity(self.axis(@field(v, field.name)));
|
||||||
tags_.appendAssumeCapacity(self.tag(@field(v, field.name)));
|
tags_.appendAssumeCapacity(self.tag(@field(v, field.name)));
|
||||||
@ -108,12 +109,12 @@ pub const Shape = struct {
|
|||||||
return .{ axes_, tags_ };
|
return .{ axes_, tags_ };
|
||||||
}
|
}
|
||||||
|
|
||||||
meta.compileError("Wrong type, got {}. Expected .{{.a, .b}}", .{T});
|
stdx.meta.compileError("Wrong type, got {}. Expected .{{.a, .b}}", .{T});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn parseTags(v: anytype) TagsArray {
|
pub fn parseTags(v: anytype) TagsArray {
|
||||||
const T = @TypeOf(v);
|
const T = @TypeOf(v);
|
||||||
meta.assertComptime(meta.isTupleOf(T, EnumLiteral), "Wrong type, got {}. Expected .{{ .a, .b }}", .{T});
|
stdx.debug.assertComptime(stdx.meta.isTupleOf(T, EnumLiteral), "Wrong type, got {}. Expected .{{ .a, .b }}", .{T});
|
||||||
var tags_ = TagsArray.init(0) catch unreachable;
|
var tags_ = TagsArray.init(0) catch unreachable;
|
||||||
inline for (v) |field| {
|
inline for (v) |field| {
|
||||||
tags_.appendAssumeCapacity(toTag(field));
|
tags_.appendAssumeCapacity(toTag(field));
|
||||||
@ -135,7 +136,7 @@ pub const Shape = struct {
|
|||||||
var res: Shape = .{ ._dtype = dt };
|
var res: Shape = .{ ._dtype = dt };
|
||||||
for (0..rank_) |i| {
|
for (0..rank_) |i| {
|
||||||
res._dims.append(@intCast(i)) catch {
|
res._dims.append(@intCast(i)) catch {
|
||||||
meta.panic("Too many dimensions! Max: {d}, passed: {d}", .{ res._dims.capacity(), rank_ });
|
stdx.debug.panic("Too many dimensions! Max: {d}, passed: {d}", .{ res._dims.capacity(), rank_ });
|
||||||
};
|
};
|
||||||
res._tags.append(TagUnknown) catch unreachable;
|
res._tags.append(TagUnknown) catch unreachable;
|
||||||
}
|
}
|
||||||
@ -162,7 +163,7 @@ pub const Shape = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn isAxisConvertible(comptime T: type) bool {
|
fn isAxisConvertible(comptime T: type) bool {
|
||||||
return meta.isInteger(T) or isTagConvertible(T);
|
return stdx.meta.isInteger(T) or isTagConvertible(T);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn isTagConvertible(comptime T: type) bool {
|
fn isTagConvertible(comptime T: type) bool {
|
||||||
@ -180,12 +181,12 @@ pub const Shape = struct {
|
|||||||
EnumLiteral => @tagName(v).ptr,
|
EnumLiteral => @tagName(v).ptr,
|
||||||
std.builtin.Type.StructField => v.name.ptr,
|
std.builtin.Type.StructField => v.name.ptr,
|
||||||
Tag => v,
|
Tag => v,
|
||||||
else => meta.compileError("Value should be an EnumLiteral, a Shape.Tag or a StructField, got {}", .{T}),
|
else => stdx.meta.compileError("Value should be an EnumLiteral, a Shape.Tag or a StructField, got {}", .{T}),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fn ensureDimsAndTagsAreSync(self: Shape) void {
|
inline fn ensureDimsAndTagsAreSync(self: Shape) void {
|
||||||
meta.assert(self._dims.len == self._tags.len, "Tags and dims have diverged! dims={d} tags={d}", .{ self._dims.len, self._tags.len });
|
stdx.debug.assert(self._dims.len == self._tags.len, "Tags and dims have diverged! dims={d} tags={d}", .{ self._dims.len, self._tags.len });
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tag(self: Shape, ax: anytype) Tag {
|
pub fn tag(self: Shape, ax: anytype) Tag {
|
||||||
@ -220,7 +221,7 @@ pub const Shape = struct {
|
|||||||
pub fn hasTags(self: Shape, tagz: anytype) bool {
|
pub fn hasTags(self: Shape, tagz: anytype) bool {
|
||||||
const T = @TypeOf(tagz);
|
const T = @TypeOf(tagz);
|
||||||
|
|
||||||
if (comptime meta.isSliceOf(T, Tag) or meta.isSliceOf(T, EnumLiteral)) {
|
if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) {
|
||||||
for (tagz) |t| {
|
for (tagz) |t| {
|
||||||
if (self.hasTag(t) == null) {
|
if (self.hasTag(t) == null) {
|
||||||
return false;
|
return false;
|
||||||
@ -229,7 +230,7 @@ pub const Shape = struct {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (comptime meta.isTupleOf(T, Tag) or meta.isTupleOf(T, EnumLiteral)) {
|
if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) {
|
||||||
inline for (tagz) |t| {
|
inline for (tagz) |t| {
|
||||||
if (self.hasTag(t) == null) {
|
if (self.hasTag(t) == null) {
|
||||||
return false;
|
return false;
|
||||||
@ -238,7 +239,7 @@ pub const Shape = struct {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
meta.compileError("Expected tuple of tags, got {any}", .{T});
|
stdx.meta.compileError("Expected tuple of tags, got {any}", .{T});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn isFullyTagged(self: Shape) bool {
|
pub fn isFullyTagged(self: Shape) bool {
|
||||||
@ -252,7 +253,7 @@ pub const Shape = struct {
|
|||||||
self.ensureDimsAndTagsAreSync();
|
self.ensureDimsAndTagsAreSync();
|
||||||
|
|
||||||
const T = @TypeOf(axis_);
|
const T = @TypeOf(axis_);
|
||||||
if (comptime meta.isInteger(T)) {
|
if (comptime stdx.meta.isInteger(T)) {
|
||||||
return self.axisFromInt(@intCast(axis_));
|
return self.axisFromInt(@intCast(axis_));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -260,7 +261,7 @@ pub const Shape = struct {
|
|||||||
return self.axisFromTag(toTag(axis_));
|
return self.axisFromTag(toTag(axis_));
|
||||||
}
|
}
|
||||||
|
|
||||||
meta.compileError("Wrong axis type, expected .literal, or an integer, got: {any}", .{T});
|
stdx.meta.compileError("Wrong axis type, expected .literal, or an integer, got: {any}", .{T});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn axes(self: Shape, axes_: anytype) AxesArray {
|
pub fn axes(self: Shape, axes_: anytype) AxesArray {
|
||||||
@ -274,27 +275,27 @@ pub const Shape = struct {
|
|||||||
|
|
||||||
var res = AxesArray.init(0) catch unreachable;
|
var res = AxesArray.init(0) catch unreachable;
|
||||||
|
|
||||||
if (comptime meta.isSliceOfAny(T, meta.isInteger) or meta.isSliceOf(T, Tag)) {
|
if (comptime stdx.meta.isSliceOfAny(T, stdx.meta.isInteger) or stdx.meta.isSliceOf(T, Tag)) {
|
||||||
for (axes_) |ax| {
|
for (axes_) |ax| {
|
||||||
res.appendAssumeCapacity(self.axis(ax));
|
res.appendAssumeCapacity(self.axis(ax));
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (comptime meta.isStruct(T)) {
|
if (comptime stdx.meta.isStruct(T)) {
|
||||||
inline for (std.meta.fields(T)) |field| {
|
inline for (std.meta.fields(T)) |field| {
|
||||||
res.appendAssumeCapacity(self.axis(@field(axes_, field.name)));
|
res.appendAssumeCapacity(self.axis(@field(axes_, field.name)));
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
meta.compileError("axes expects an int-tuple or a tuple of enum literal, got {}", .{T});
|
stdx.meta.compileError("axes expects an int-tuple or a tuple of enum literal, got {}", .{T});
|
||||||
}
|
}
|
||||||
|
|
||||||
fn axisFromInt(self: Shape, d: isize) u3 {
|
fn axisFromInt(self: Shape, d: isize) u3 {
|
||||||
const rk: i8 = self.rank();
|
const rk: i8 = self.rank();
|
||||||
if (d < -rk or d > rk) {
|
if (d < -rk or d > rk) {
|
||||||
meta.panic("Tensor {} doesn't have dimension: {d}", .{ self, d });
|
stdx.debug.panic("Tensor {} doesn't have dimension: {d}", .{ self, d });
|
||||||
}
|
}
|
||||||
return if (d < 0)
|
return if (d < 0)
|
||||||
@intCast(d + rk)
|
@intCast(d + rk)
|
||||||
@ -323,9 +324,9 @@ pub const Shape = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn axisFromTag(self: Shape, d: Tag) u3 {
|
fn axisFromTag(self: Shape, d: Tag) u3 {
|
||||||
meta.assert(d != TagUnknown, "The unknown tag .{s} can't be used to fetch axis in {}", .{ d, self });
|
stdx.debug.assert(d != TagUnknown, "The unknown tag .{s} can't be used to fetch axis in {}", .{ d, self });
|
||||||
return self.axisFromTagMaybe(d) orelse {
|
return self.axisFromTagMaybe(d) orelse {
|
||||||
meta.panic("Tensor {} doesn't have dimension with tag: {s}", .{ self, d });
|
stdx.debug.panic("Tensor {} doesn't have dimension with tag: {s}", .{ self, d });
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -339,7 +340,7 @@ pub const Shape = struct {
|
|||||||
pub fn count(self: Shape) usize {
|
pub fn count(self: Shape) usize {
|
||||||
var res: i64 = 1;
|
var res: i64 = 1;
|
||||||
for (self.dims()) |d| {
|
for (self.dims()) |d| {
|
||||||
meta.assert(d >= 0, "Can't count elements in shape with negative dimension: {}", .{self});
|
stdx.debug.assert(d >= 0, "Can't count elements in shape with negative dimension: {}", .{self});
|
||||||
res *= d;
|
res *= d;
|
||||||
}
|
}
|
||||||
return @intCast(res);
|
return @intCast(res);
|
||||||
@ -398,12 +399,12 @@ pub const Shape = struct {
|
|||||||
var new_shape: Shape = .{ ._dtype = self.dtype() };
|
var new_shape: Shape = .{ ._dtype = self.dtype() };
|
||||||
new_shape._dims, new_shape._tags = parseDimensions(new_shape_);
|
new_shape._dims, new_shape._tags = parseDimensions(new_shape_);
|
||||||
new_shape.inferMissingAxis(self.count());
|
new_shape.inferMissingAxis(self.count());
|
||||||
meta.assert(self.count() == new_shape.count(), "Can't reshape {d} to {d}", .{ self.dims(), new_shape.dims() });
|
stdx.debug.assert(self.count() == new_shape.count(), "Can't reshape {d} to {d}", .{ self.dims(), new_shape.dims() });
|
||||||
return new_shape;
|
return new_shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inferMissingAxis(self: *Shape, n_: usize) void {
|
fn inferMissingAxis(self: *Shape, n_: usize) void {
|
||||||
meta.assert(std.mem.count(i64, self.dims(), &.{-1}) < 2, "Cannot infer multiple dimensions when reshaping to: {}", .{self.*});
|
stdx.debug.assert(std.mem.count(i64, self.dims(), &.{-1}) < 2, "Cannot infer multiple dimensions when reshaping to: {}", .{self.*});
|
||||||
|
|
||||||
const inferred_ax = std.mem.indexOfScalar(i64, self.dims(), -1) orelse return;
|
const inferred_ax = std.mem.indexOfScalar(i64, self.dims(), -1) orelse return;
|
||||||
// We can't use `self.count()` yet cause we have negative dims.
|
// We can't use `self.count()` yet cause we have negative dims.
|
||||||
@ -481,7 +482,7 @@ pub const Shape = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn insertTag(self: Shape, axis_: anytype, d: i64, tag_: anytype) Shape {
|
pub fn insertTag(self: Shape, axis_: anytype, d: i64, tag_: anytype) Shape {
|
||||||
meta.assert(self.rank() < MAX_RANK - 1, "Can't insert new axis in {}, it's already at max rank.", .{self});
|
stdx.debug.assert(self.rank() < MAX_RANK - 1, "Can't insert new axis in {}, it's already at max rank.", .{self});
|
||||||
|
|
||||||
const ax = if (@TypeOf(axis_) == EnumLiteral and axis_ == .last)
|
const ax = if (@TypeOf(axis_) == EnumLiteral and axis_ == .last)
|
||||||
self.rank()
|
self.rank()
|
||||||
@ -573,23 +574,23 @@ pub const Shape = struct {
|
|||||||
|
|
||||||
var res = self;
|
var res = self;
|
||||||
|
|
||||||
if (comptime meta.isSliceOf(T, Tag) or meta.isSliceOf(T, EnumLiteral)) {
|
if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) {
|
||||||
meta.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {any}", .{ self, tagz });
|
stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {any}", .{ self, tagz });
|
||||||
for (tagz, 0..) |tag_, i| {
|
for (tagz, 0..) |tag_, i| {
|
||||||
res._tags.set(i, toTag(tag_));
|
res._tags.set(i, toTag(tag_));
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (comptime meta.isTupleOf(T, Tag) or meta.isTupleOf(T, EnumLiteral)) {
|
if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) {
|
||||||
meta.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {}", .{ self, tagz });
|
stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {}", .{ self, tagz });
|
||||||
inline for (tagz, 0..) |tag_, i| {
|
inline for (tagz, 0..) |tag_, i| {
|
||||||
res._tags.set(i, toTag(tag_));
|
res._tags.set(i, toTag(tag_));
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
meta.compileError("Expected a tuple of enum literals eg: .{ .a, .b, .c } got: {any}", .{@TypeOf(tagz)});
|
stdx.meta.compileError("Expected a tuple of enum literals eg: .{ .a, .b, .c } got: {any}", .{@TypeOf(tagz)});
|
||||||
}
|
}
|
||||||
|
|
||||||
test withTags {
|
test withTags {
|
||||||
@ -620,23 +621,23 @@ pub const Shape = struct {
|
|||||||
|
|
||||||
var res = self;
|
var res = self;
|
||||||
|
|
||||||
if (comptime meta.isSliceOf(T, Tag) or meta.isSliceOf(T, EnumLiteral)) {
|
if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) {
|
||||||
meta.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {any}", .{ self, tagz });
|
stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {any}", .{ self, tagz });
|
||||||
for (tagz, self.rank() - tagz.len..) |tag_, i| {
|
for (tagz, self.rank() - tagz.len..) |tag_, i| {
|
||||||
res._tags.set(i, toTag(tag_));
|
res._tags.set(i, toTag(tag_));
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (comptime meta.isTupleOf(T, Tag) or meta.isTupleOf(T, EnumLiteral)) {
|
if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) {
|
||||||
meta.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {}", .{ self, tagz });
|
stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {}", .{ self, tagz });
|
||||||
inline for (tagz, self.rank() - tagz.len..) |tag_, i| {
|
inline for (tagz, self.rank() - tagz.len..) |tag_, i| {
|
||||||
res._tags.set(i, toTag(tag_));
|
res._tags.set(i, toTag(tag_));
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
meta.compileError("Expected a tuple of enum literals eg: .{ .a, .b, .c } got: {any}", .{@TypeOf(tagz)});
|
stdx.meta.compileError("Expected a tuple of enum literals eg: .{ .a, .b, .c } got: {any}", .{@TypeOf(tagz)});
|
||||||
}
|
}
|
||||||
|
|
||||||
test withPartialTags {
|
test withPartialTags {
|
||||||
@ -683,7 +684,7 @@ pub const Shape = struct {
|
|||||||
/// Shape.init(.{ .a = 10, .b = 20 }).rename(.{ .b = .batch }); // .{ .a = 10, .batch = 20 };
|
/// Shape.init(.{ .a = 10, .b = 20 }).rename(.{ .b = .batch }); // .{ .a = 10, .batch = 20 };
|
||||||
pub fn rename(self: Shape, renames: anytype) Shape {
|
pub fn rename(self: Shape, renames: anytype) Shape {
|
||||||
const T = @TypeOf(renames);
|
const T = @TypeOf(renames);
|
||||||
meta.assertComptime(meta.isStructOfAny(T, isAxisConvertible), "Must pass a struct of enum literals. Passed: {any}", .{T});
|
stdx.debug.assertComptime(stdx.meta.isStructOfAny(T, isAxisConvertible), "Must pass a struct of enum literals. Passed: {any}", .{T});
|
||||||
var res = self;
|
var res = self;
|
||||||
inline for (std.meta.fields(T)) |field| {
|
inline for (std.meta.fields(T)) |field| {
|
||||||
res._tags.set(self.axis(field), toTag(@field(renames, field.name)));
|
res._tags.set(self.axis(field), toTag(@field(renames, field.name)));
|
||||||
@ -789,7 +790,7 @@ pub const Shape = struct {
|
|||||||
|
|
||||||
pub fn splitAxes(self: Shape, axes_: anytype) Shape {
|
pub fn splitAxes(self: Shape, axes_: anytype) Shape {
|
||||||
const T = @TypeOf(axes_);
|
const T = @TypeOf(axes_);
|
||||||
meta.assertComptime(meta.isStruct(T), "Must pass struct of enum literals like .{ .a = .{ .a1, .a2 } }. Passed: {any}", .{T});
|
stdx.debug.assertComptime(stdx.meta.isStruct(T), "Must pass struct of enum literals like .{ .a = .{ .a1, .a2 } }. Passed: {any}", .{T});
|
||||||
|
|
||||||
var res = self;
|
var res = self;
|
||||||
inline for (std.meta.fields(T)) |field| {
|
inline for (std.meta.fields(T)) |field| {
|
||||||
@ -822,7 +823,7 @@ pub const Shape = struct {
|
|||||||
var new_dim: i64 = 1;
|
var new_dim: i64 = 1;
|
||||||
for (axes__.constSlice(), first_axis..) |ax, counter| {
|
for (axes__.constSlice(), first_axis..) |ax, counter| {
|
||||||
new_dim *= self.dim(ax);
|
new_dim *= self.dim(ax);
|
||||||
meta.assert(ax == counter, "Can't merge shape {} along non-contiguous axes {any}", .{ self, axes_ });
|
stdx.debug.assert(ax == counter, "Can't merge shape {} along non-contiguous axes {any}", .{ self, axes_ });
|
||||||
}
|
}
|
||||||
|
|
||||||
var new_shape = self;
|
var new_shape = self;
|
||||||
@ -863,11 +864,11 @@ pub const Shape = struct {
|
|||||||
|
|
||||||
pub fn mergeAxes(self: Shape, axes_: anytype) Shape {
|
pub fn mergeAxes(self: Shape, axes_: anytype) Shape {
|
||||||
const T = @TypeOf(axes_);
|
const T = @TypeOf(axes_);
|
||||||
meta.assertComptime(meta.isStruct(T), "Must pass struct of enum literals like .{ .a = .{ .a1, .a2 } }. Passed: {any}", .{T});
|
stdx.debug.assertComptime(stdx.meta.isStruct(T), "Must pass struct of enum literals like .{ .a = .{ .a1, .a2 } }. Passed: {any}", .{T});
|
||||||
|
|
||||||
var res = self;
|
var res = self;
|
||||||
inline for (std.meta.fields(T)) |field| {
|
inline for (std.meta.fields(T)) |field| {
|
||||||
meta.assertComptime(meta.isTupleOfAny(field.type, isAxisConvertible) or meta.isSliceOfAny(field.type, isAxisConvertible), "Must pass struct of axes. Passed: {any}", .{field.type});
|
stdx.debug.assertComptime(stdx.meta.isTupleOfAny(field.type, isAxisConvertible) or stdx.meta.isSliceOfAny(field.type, isAxisConvertible), "Must pass struct of axes. Passed: {any}", .{field.type});
|
||||||
res = res.mergeAxis(field, @field(axes_, field.name));
|
res = res.mergeAxis(field, @field(axes_, field.name));
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
@ -912,28 +913,28 @@ pub const Shape = struct {
|
|||||||
var vals_: std.BoundedArray(T, MAX_RANK) = .{};
|
var vals_: std.BoundedArray(T, MAX_RANK) = .{};
|
||||||
var tags_: TagsArray = .{};
|
var tags_: TagsArray = .{};
|
||||||
|
|
||||||
if (comptime meta.isSliceOf(V, T)) {
|
if (comptime stdx.meta.isSliceOf(V, T)) {
|
||||||
for (v) |d| {
|
for (v) |d| {
|
||||||
vals_.appendAssumeCapacity(d);
|
vals_.appendAssumeCapacity(d);
|
||||||
}
|
}
|
||||||
return .{ vals_, tags_ };
|
return .{ vals_, tags_ };
|
||||||
}
|
}
|
||||||
|
|
||||||
if (comptime meta.isStruct(V)) {
|
if (comptime stdx.meta.isStruct(V)) {
|
||||||
const fields = std.meta.fields(V);
|
const fields = std.meta.fields(V);
|
||||||
meta.assertComptime(fields.len <= MAX_RANK, "Too many fields in struct {} ({d}). Max supported is {d}.", .{ V, fields.len, MAX_RANK });
|
stdx.debug.assertComptime(fields.len <= MAX_RANK, "Too many fields in struct {} ({d}). Max supported is {d}.", .{ V, fields.len, MAX_RANK });
|
||||||
inline for (fields) |field| {
|
inline for (fields) |field| {
|
||||||
const fv = @field(v, field.name);
|
const fv = @field(v, field.name);
|
||||||
vals_.appendAssumeCapacity(fv);
|
vals_.appendAssumeCapacity(fv);
|
||||||
|
|
||||||
if (!comptime meta.isTuple(V)) {
|
if (!comptime stdx.meta.isTuple(V)) {
|
||||||
tags_.appendAssumeCapacity(toTag(field));
|
tags_.appendAssumeCapacity(toTag(field));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return .{ vals_, tags_ };
|
return .{ vals_, tags_ };
|
||||||
}
|
}
|
||||||
|
|
||||||
meta.compileError("parseStruct expects struct or tuple, got {}", .{V});
|
stdx.meta.compileError("parseStruct expects struct or tuple, got {}", .{V});
|
||||||
}
|
}
|
||||||
|
|
||||||
test parseStruct {
|
test parseStruct {
|
||||||
@ -948,17 +949,17 @@ pub const Shape = struct {
|
|||||||
const V = @TypeOf(options);
|
const V = @TypeOf(options);
|
||||||
|
|
||||||
var res: std.BoundedArray(T, MAX_RANK) = .{};
|
var res: std.BoundedArray(T, MAX_RANK) = .{};
|
||||||
if (comptime meta.isSliceOf(V, T)) {
|
if (comptime stdx.meta.isSliceOf(V, T)) {
|
||||||
meta.assert(options.len == self.rank(), "expects exactly {} options in slice, for {} got {}", .{ self.rank(), self, options.len });
|
stdx.debug.assert(options.len == self.rank(), "expects exactly {} options in slice, for {} got {}", .{ self.rank(), self, options.len });
|
||||||
for (options) |d| {
|
for (options) |d| {
|
||||||
res.appendAssumeCapacity(d);
|
res.appendAssumeCapacity(d);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (comptime meta.isStruct(V)) {
|
if (comptime stdx.meta.isStruct(V)) {
|
||||||
for (0..self.rank()) |_| res.appendAssumeCapacity(default);
|
for (0..self.rank()) |_| res.appendAssumeCapacity(default);
|
||||||
const fields = std.meta.fields(V);
|
const fields = std.meta.fields(V);
|
||||||
meta.assertComptime(fields.len <= MAX_RANK, "expects up to {} options struct literal, got {}", .{ V, MAX_RANK, fields.len });
|
stdx.debug.assertComptime(fields.len <= MAX_RANK, "expects up to {} options struct literal, got {}", .{ V, MAX_RANK, fields.len });
|
||||||
inline for (fields) |field| {
|
inline for (fields) |field| {
|
||||||
const a = self.axis(field);
|
const a = self.axis(field);
|
||||||
res.buffer[a] = @field(options, field.name);
|
res.buffer[a] = @field(options, field.name);
|
||||||
@ -966,7 +967,7 @@ pub const Shape = struct {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
meta.compileError("parseStruct expects struct or tuple, got {}", .{V});
|
stdx.meta.compileError("parseStruct expects struct or tuple, got {}", .{V});
|
||||||
}
|
}
|
||||||
|
|
||||||
test parseAxesOptions {
|
test parseAxesOptions {
|
||||||
|
|||||||
281
zml/tensor.zig
281
zml/tensor.zig
@ -1,7 +1,6 @@
|
|||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const assert = std.debug.assert;
|
const stdx = @import("stdx");
|
||||||
const testing = std.testing;
|
|
||||||
|
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
const mlir = @import("mlir.zig");
|
const mlir = @import("mlir.zig");
|
||||||
@ -22,7 +21,9 @@ const dialect = struct {
|
|||||||
const stablehlo = @import("mlir/dialects").stablehlo;
|
const stablehlo = @import("mlir/dialects").stablehlo;
|
||||||
};
|
};
|
||||||
|
|
||||||
const scoped_log = std.log.scoped(.zml_tensor);
|
const assert = std.debug.assert;
|
||||||
|
const testing = std.testing;
|
||||||
|
const scoped_log = std.log.scoped(.@"zml/tensor");
|
||||||
|
|
||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(Tensor);
|
std.testing.refAllDecls(Tensor);
|
||||||
@ -99,7 +100,7 @@ pub const Tensor = struct {
|
|||||||
if (builtin.mode == .Debug) {
|
if (builtin.mode == .Debug) {
|
||||||
// Check that the MLIR value actually have the same shape.
|
// Check that the MLIR value actually have the same shape.
|
||||||
const other = fromMlirValue(val);
|
const other = fromMlirValue(val);
|
||||||
meta.internalAssert(sh.eql(other._shape), "Created a {} from Mlir value but expected {}", .{ other._shape, res._shape });
|
stdx.debug.internalAssert(sh.eql(other._shape), "Created a {} from Mlir value but expected {}", .{ other._shape, res._shape });
|
||||||
}
|
}
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
@ -112,7 +113,7 @@ pub const Tensor = struct {
|
|||||||
const ranked_tensor = val.getType().as(mlir.RankedTensorType).?;
|
const ranked_tensor = val.getType().as(mlir.RankedTensorType).?;
|
||||||
const n = ranked_tensor.getRank();
|
const n = ranked_tensor.getRank();
|
||||||
|
|
||||||
meta.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK });
|
stdx.debug.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK });
|
||||||
|
|
||||||
var sh: Shape = .{ ._dtype = mlir.ext.Type.toDType(ranked_tensor.getElementType()) };
|
var sh: Shape = .{ ._dtype = mlir.ext.Type.toDType(ranked_tensor.getElementType()) };
|
||||||
for (0..n) |i| {
|
for (0..n) |i| {
|
||||||
@ -213,7 +214,7 @@ pub const Tensor = struct {
|
|||||||
/// For `reuseBuffer` to be effective, it needs to propagate all the way through the output.
|
/// For `reuseBuffer` to be effective, it needs to propagate all the way through the output.
|
||||||
pub fn reuseBuffer(self: Tensor, origin: Tensor) Tensor {
|
pub fn reuseBuffer(self: Tensor, origin: Tensor) Tensor {
|
||||||
// Note: check donation docs, this may be too permissive.
|
// Note: check donation docs, this may be too permissive.
|
||||||
meta.assert(self.byteSize() == origin.byteSize(), "Can't reuse buffers between tensors of different size: {} and {}", .{ self, origin });
|
stdx.debug.assert(self.byteSize() == origin.byteSize(), "Can't reuse buffers between tensors of different size: {} and {}", .{ self, origin });
|
||||||
|
|
||||||
// TODO: should we store all donations inside the context ?
|
// TODO: should we store all donations inside the context ?
|
||||||
var res = self;
|
var res = self;
|
||||||
@ -262,7 +263,7 @@ pub const Tensor = struct {
|
|||||||
break :gt res;
|
break :gt res;
|
||||||
} else lt: {
|
} else lt: {
|
||||||
// several contiguous elements of self maps to one element of the result
|
// several contiguous elements of self maps to one element of the result
|
||||||
meta.assert(self.dim(-1) * src_bit_size == tgt_bit_size, "bitcast expects elements of the input tensor last dimension to map to one element of the target datatype, got {0} elements (bitsize of {0}x{1}={2}) and {3} (bitsize of {4})", .{ self.dim(-1), src_bit_size, self.dim(-1) * src_bit_size, dt, tgt_bit_size });
|
stdx.debug.assert(self.dim(-1) * src_bit_size == tgt_bit_size, "bitcast expects elements of the input tensor last dimension to map to one element of the target datatype, got {0} elements (bitsize of {0}x{1}={2}) and {3} (bitsize of {4})", .{ self.dim(-1), src_bit_size, self.dim(-1) * src_bit_size, dt, tgt_bit_size });
|
||||||
break :lt self._shape.remove(-1);
|
break :lt self._shape.remove(-1);
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -295,7 +296,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns a Tensor containing the element-wise number of bits set in the input Tensor.
|
/// Returns a Tensor containing the element-wise number of bits set in the input Tensor.
|
||||||
pub fn popcnt(self: Tensor) Tensor {
|
pub fn popcnt(self: Tensor) Tensor {
|
||||||
meta.assert(self.dtype().isInteger(), "popcnt expects tensor type to be an integer, got {}", .{self.dtype()});
|
stdx.debug.assert(self.dtype().isInteger(), "popcnt expects tensor type to be an integer, got {}", .{self.dtype()});
|
||||||
const loc = self.getContext().mlirCtx().location(@src());
|
const loc = self.getContext().mlirCtx().location(@src());
|
||||||
const op = dialect.stablehlo.popcnt(self.getContext().mlirCtx(), self.value(), loc);
|
const op = dialect.stablehlo.popcnt(self.getContext().mlirCtx(), self.value(), loc);
|
||||||
return _result(self._shape, op.result(0));
|
return _result(self._shape, op.result(0));
|
||||||
@ -358,7 +359,7 @@ pub const Tensor = struct {
|
|||||||
/// 'lower' controls the form of the outut Tensor. The output will be lower-triangular if 'lower' is true
|
/// 'lower' controls the form of the outut Tensor. The output will be lower-triangular if 'lower' is true
|
||||||
/// and upper-triangular otherwise.
|
/// and upper-triangular otherwise.
|
||||||
pub fn cholesky(self: Tensor, lower: bool) Tensor {
|
pub fn cholesky(self: Tensor, lower: bool) Tensor {
|
||||||
meta.assert(self.rank() <= 2, "cholesky expects tensor rank to be <= 2, got {}", .{self.rank()});
|
stdx.debug.assert(self.rank() <= 2, "cholesky expects tensor rank to be <= 2, got {}", .{self.rank()});
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "lower={}", .{lower});
|
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "lower={}", .{lower});
|
||||||
const op = dialect.stablehlo.cholesky(self.getContext().mlirCtx(), self.value(), lower, loc);
|
const op = dialect.stablehlo.cholesky(self.getContext().mlirCtx(), self.value(), lower, loc);
|
||||||
@ -367,8 +368,8 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Solves the system of linear equations formed by the input tensors.
|
/// Solves the system of linear equations formed by the input tensors.
|
||||||
pub fn triangularSolve(self: Tensor, other: Tensor, opts: dialect.stablehlo.TriangularSolveOpts) Tensor {
|
pub fn triangularSolve(self: Tensor, other: Tensor, opts: dialect.stablehlo.TriangularSolveOpts) Tensor {
|
||||||
meta.assert(self.dtype() == other.dtype(), "triangularSolve expects tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() });
|
stdx.debug.assert(self.dtype() == other.dtype(), "triangularSolve expects tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() });
|
||||||
meta.assert(self.rank() <= 2 and self.rank() == other.rank(), "triangularSolve expects tensors to have the same rank and be <= 2, got {} and {}", .{ self.rank(), other.rank() });
|
stdx.debug.assert(self.rank() <= 2 and self.rank() == other.rank(), "triangularSolve expects tensors to have the same rank and be <= 2, got {} and {}", .{ self.rank(), other.rank() });
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "opts={}", .{opts});
|
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "opts={}", .{opts});
|
||||||
const op = dialect.stablehlo.triangular_solve(self.getContext().mlirCtx(), self.value(), other.value(), loc, opts);
|
const op = dialect.stablehlo.triangular_solve(self.getContext().mlirCtx(), self.value(), other.value(), loc, opts);
|
||||||
@ -377,7 +378,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns a Tensor containing the element-wise rounding towards the nearest integer, breaking ties away from zero, of the input Tensor.
|
/// Returns a Tensor containing the element-wise rounding towards the nearest integer, breaking ties away from zero, of the input Tensor.
|
||||||
pub fn roundNearestAfz(self: Tensor) Tensor {
|
pub fn roundNearestAfz(self: Tensor) Tensor {
|
||||||
meta.assert(self.dtype().isFloat(), "roundNearestAfz expects tensor type to be a float, got {}", .{self.dtype()});
|
stdx.debug.assert(self.dtype().isFloat(), "roundNearestAfz expects tensor type to be a float, got {}", .{self.dtype()});
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src());
|
const loc = self.getContext().mlirCtx().location(@src());
|
||||||
const op = dialect.stablehlo.round_nearest_afz(self.getContext().mlirCtx(), self.value(), loc);
|
const op = dialect.stablehlo.round_nearest_afz(self.getContext().mlirCtx(), self.value(), loc);
|
||||||
@ -386,7 +387,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns a Tensor containing the element-wise rounding towards the nearest integer, breaking ties towards the even integer, of the input Tensor.
|
/// Returns a Tensor containing the element-wise rounding towards the nearest integer, breaking ties towards the even integer, of the input Tensor.
|
||||||
pub fn roundNearestEven(self: Tensor) Tensor {
|
pub fn roundNearestEven(self: Tensor) Tensor {
|
||||||
meta.assert(self.dtype().isFloat(), "roundNearestEven expects tensor type to be a float, got {}", .{self.dtype()});
|
stdx.debug.assert(self.dtype().isFloat(), "roundNearestEven expects tensor type to be a float, got {}", .{self.dtype()});
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src());
|
const loc = self.getContext().mlirCtx().location(@src());
|
||||||
const op = dialect.stablehlo.round_nearest_even(self.getContext().mlirCtx(), self.value(), loc);
|
const op = dialect.stablehlo.round_nearest_even(self.getContext().mlirCtx(), self.value(), loc);
|
||||||
@ -395,8 +396,8 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns a Tensor of complex number converted from a pair of real and imaginary Tensors.
|
/// Returns a Tensor of complex number converted from a pair of real and imaginary Tensors.
|
||||||
pub fn complex(re: Tensor, im: Tensor) Tensor {
|
pub fn complex(re: Tensor, im: Tensor) Tensor {
|
||||||
meta.assert(re._shape.eql(im._shape), "complex expects tensor shapes to match, got {} and {}", .{ re._shape, im._shape });
|
stdx.debug.assert(re._shape.eql(im._shape), "complex expects tensor shapes to match, got {} and {}", .{ re._shape, im._shape });
|
||||||
meta.assert(re.dtype() == .f32 or re.dtype() == .f64, "complex expects tensors type to be f32 or f64, got {}", .{re.dtype()});
|
stdx.debug.assert(re.dtype() == .f32 or re.dtype() == .f64, "complex expects tensors type to be f32 or f64, got {}", .{re.dtype()});
|
||||||
|
|
||||||
const loc = re.getContext().mlirCtx().location(@src());
|
const loc = re.getContext().mlirCtx().location(@src());
|
||||||
const op = dialect.stablehlo.complex(re.getContext().mlirCtx(), re.value(), im.value(), loc);
|
const op = dialect.stablehlo.complex(re.getContext().mlirCtx(), re.value(), im.value(), loc);
|
||||||
@ -408,7 +409,7 @@ pub const Tensor = struct {
|
|||||||
///
|
///
|
||||||
/// Tensor type can float or complex.
|
/// Tensor type can float or complex.
|
||||||
pub fn real(self: Tensor) Tensor {
|
pub fn real(self: Tensor) Tensor {
|
||||||
meta.assert(self.dtype().isComplex() or self.dtype().isFloat(), "real expects tensor type to be a float or a complex, got {}", .{self.dtype()});
|
stdx.debug.assert(self.dtype().isComplex() or self.dtype().isFloat(), "real expects tensor type to be a float or a complex, got {}", .{self.dtype()});
|
||||||
|
|
||||||
if (self.dtype().isFloat()) {
|
if (self.dtype().isFloat()) {
|
||||||
return self;
|
return self;
|
||||||
@ -428,7 +429,7 @@ pub const Tensor = struct {
|
|||||||
///
|
///
|
||||||
/// Tensor type can float or complex.
|
/// Tensor type can float or complex.
|
||||||
pub fn imag(self: Tensor) Tensor {
|
pub fn imag(self: Tensor) Tensor {
|
||||||
meta.assert(self.dtype().isFloat() or self.dtype().isComplex(), "imag expects tensor type to be a float or a complex, got {}", .{self.dtype()});
|
stdx.debug.assert(self.dtype().isFloat() or self.dtype().isComplex(), "imag expects tensor type to be a float or a complex, got {}", .{self.dtype()});
|
||||||
|
|
||||||
// Real tensors don't have imaginary part.
|
// Real tensors don't have imaginary part.
|
||||||
if (self.dtype().isFloat()) {
|
if (self.dtype().isFloat()) {
|
||||||
@ -450,18 +451,18 @@ pub const Tensor = struct {
|
|||||||
pub fn fft(self: Tensor, opts: dialect.stablehlo.FftOpts) Tensor {
|
pub fn fft(self: Tensor, opts: dialect.stablehlo.FftOpts) Tensor {
|
||||||
// TODO: support tagged API.
|
// TODO: support tagged API.
|
||||||
|
|
||||||
meta.assert(1 <= opts.length.len and opts.length.len <= 3, "fft expects 'opts.length' length to be between 1 and 3 (inclusive), got {}", .{opts.length.len});
|
stdx.debug.assert(1 <= opts.length.len and opts.length.len <= 3, "fft expects 'opts.length' length to be between 1 and 3 (inclusive), got {}", .{opts.length.len});
|
||||||
meta.assert(opts.length.len <= self.rank(), "fft expects 'opts.length' length to be less than tensor rank, got {} and {}", .{ opts.length.len, self.rank() });
|
stdx.debug.assert(opts.length.len <= self.rank(), "fft expects 'opts.length' length to be less than tensor rank, got {} and {}", .{ opts.length.len, self.rank() });
|
||||||
|
|
||||||
const sh = switch (opts.kind) {
|
const sh = switch (opts.kind) {
|
||||||
.FFT, .IFFT => blk: {
|
.FFT, .IFFT => blk: {
|
||||||
meta.assert(self.dtype().isComplex(), "fft({any}) expects tensor type to be complex, got {}", .{ opts, self.dtype() });
|
stdx.debug.assert(self.dtype().isComplex(), "fft({any}) expects tensor type to be complex, got {}", .{ opts, self.dtype() });
|
||||||
|
|
||||||
break :blk self._shape;
|
break :blk self._shape;
|
||||||
},
|
},
|
||||||
.RFFT => blk: {
|
.RFFT => blk: {
|
||||||
meta.assert(self.dtype() == .f32 or self.dtype() == .f64, "fft({}) expects tensor type to be f32 or f64, got {}", .{ opts, self.dtype() });
|
stdx.debug.assert(self.dtype() == .f32 or self.dtype() == .f64, "fft({}) expects tensor type to be f32 or f64, got {}", .{ opts, self.dtype() });
|
||||||
meta.assert(std.mem.eql(i64, self.dims()[self.rank() - opts.length.len ..], opts.length), "fft({}) expects tensor last dimensions to match given lengths, got {} and {}", .{ opts, self.dims()[self.rank() - opts.length.len ..].len, opts.length.len });
|
stdx.debug.assert(std.mem.eql(i64, self.dims()[self.rank() - opts.length.len ..], opts.length), "fft({}) expects tensor last dimensions to match given lengths, got {} and {}", .{ opts, self.dims()[self.rank() - opts.length.len ..].len, opts.length.len });
|
||||||
|
|
||||||
const dt: DataType = switch (self.dtype()) {
|
const dt: DataType = switch (self.dtype()) {
|
||||||
.f32 => .c64,
|
.f32 => .c64,
|
||||||
@ -471,8 +472,8 @@ pub const Tensor = struct {
|
|||||||
break :blk shape_.withDtype(dt);
|
break :blk shape_.withDtype(dt);
|
||||||
},
|
},
|
||||||
.IRFFT => blk: {
|
.IRFFT => blk: {
|
||||||
meta.assert(self.dtype().isComplex(), "fft({any}) expects tensor type to be complex, got {}", .{ opts, self.dtype() });
|
stdx.debug.assert(self.dtype().isComplex(), "fft({any}) expects tensor type to be complex, got {}", .{ opts, self.dtype() });
|
||||||
meta.assert(std.mem.eql(i64, self.dims()[self.rank() - opts.length.len ..], opts.length), "fft({any}) expects tensor last dimensions to match given lengths, got {} and {}", .{ opts, self.dims()[self.rank() - opts.length.len ..].len, opts.length.len });
|
stdx.debug.assert(std.mem.eql(i64, self.dims()[self.rank() - opts.length.len ..], opts.length), "fft({any}) expects tensor last dimensions to match given lengths, got {} and {}", .{ opts, self.dims()[self.rank() - opts.length.len ..].len, opts.length.len });
|
||||||
|
|
||||||
const dt: DataType = switch (self.dtype()) {
|
const dt: DataType = switch (self.dtype()) {
|
||||||
.c64 => .f32,
|
.c64 => .f32,
|
||||||
@ -551,7 +552,7 @@ pub const Tensor = struct {
|
|||||||
16 => .u16,
|
16 => .u16,
|
||||||
32 => .u32,
|
32 => .u32,
|
||||||
64 => .u64,
|
64 => .u64,
|
||||||
else => meta.panic("uniform don't support non-byte aligned dtype. Got: {}", .{shape_}),
|
else => stdx.debug.panic("uniform don't support non-byte aligned dtype. Got: {}", .{shape_}),
|
||||||
};
|
};
|
||||||
|
|
||||||
const rng, const bits = self.bitGenerator(shape_.withDtype(uint_dtype));
|
const rng, const bits = self.bitGenerator(shape_.withDtype(uint_dtype));
|
||||||
@ -635,7 +636,7 @@ pub const Tensor = struct {
|
|||||||
/// Note: this uses stablehlo.rng which is deprecated.
|
/// Note: this uses stablehlo.rng which is deprecated.
|
||||||
/// https://github.com/openxla/stablehlo/blob/main/rfcs/20240503-opset-deprecations.md
|
/// https://github.com/openxla/stablehlo/blob/main/rfcs/20240503-opset-deprecations.md
|
||||||
pub fn normal(sh: Shape, opts: struct { mean: f64 = 0, stddev: f64 = 1 }) Tensor {
|
pub fn normal(sh: Shape, opts: struct { mean: f64 = 0, stddev: f64 = 1 }) Tensor {
|
||||||
meta.assert(sh.dtype().isFloat(), "normal expects tensor type to be a float, got {}", .{sh.dtype()});
|
stdx.debug.assert(sh.dtype().isFloat(), "normal expects tensor type to be a float, got {}", .{sh.dtype()});
|
||||||
|
|
||||||
const ctx = CompilationContext.current().mlirCtx();
|
const ctx = CompilationContext.current().mlirCtx();
|
||||||
const loc = ctx.location(@src()).namedFmt(ctx, "rand.normal({}, opts={})", .{ sh, opts });
|
const loc = ctx.location(@src()).namedFmt(ctx, "rand.normal({}, opts={})", .{ sh, opts });
|
||||||
@ -731,9 +732,9 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns a Tensor containing the element-wise conversion to another floating point type.
|
/// Returns a Tensor containing the element-wise conversion to another floating point type.
|
||||||
pub fn reducePrecision(self: Tensor, exponent_bits: i32, mantissa_bits: i32) Tensor {
|
pub fn reducePrecision(self: Tensor, exponent_bits: i32, mantissa_bits: i32) Tensor {
|
||||||
meta.assert(self.dtype().isFloat(), "reducePrecision expects tensor type to be a float, got {}", .{self.dtype()});
|
stdx.debug.assert(self.dtype().isFloat(), "reducePrecision expects tensor type to be a float, got {}", .{self.dtype()});
|
||||||
meta.assert(1 <= exponent_bits, "reducePrecision expects 'exponent_bits' to be >= 1, got {}", .{exponent_bits});
|
stdx.debug.assert(1 <= exponent_bits, "reducePrecision expects 'exponent_bits' to be >= 1, got {}", .{exponent_bits});
|
||||||
meta.assert(0 <= mantissa_bits, "reducePrecision expects 'mantissa_bits' to be positive, got {}", .{mantissa_bits});
|
stdx.debug.assert(0 <= mantissa_bits, "reducePrecision expects 'mantissa_bits' to be positive, got {}", .{mantissa_bits});
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reducePrecision(exponent_bits={}, mantissa_bits={})", .{ exponent_bits, mantissa_bits });
|
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reducePrecision(exponent_bits={}, mantissa_bits={})", .{ exponent_bits, mantissa_bits });
|
||||||
const op = dialect.stablehlo.reduce_precision(self.getContext().mlirCtx(), self.value(), exponent_bits, mantissa_bits, loc);
|
const op = dialect.stablehlo.reduce_precision(self.getContext().mlirCtx(), self.value(), exponent_bits, mantissa_bits, loc);
|
||||||
@ -741,56 +742,56 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline fn convolution(self: Tensor, other: Tensor, opts: dialect.stablehlo.ConvolutionOpts, loc: mlir.Location) Tensor {
|
inline fn convolution(self: Tensor, other: Tensor, opts: dialect.stablehlo.ConvolutionOpts, loc: mlir.Location) Tensor {
|
||||||
meta.assert(self.rank() == other.rank(), "convolution expects tensor ranks to match, got {} and {}", .{ self.rank(), other.rank() });
|
stdx.debug.assert(self.rank() == other.rank(), "convolution expects tensor ranks to match, got {} and {}", .{ self.rank(), other.rank() });
|
||||||
const N = self.rank();
|
const N = self.rank();
|
||||||
meta.guard(opts.window_strides.len == N - 2, @src());
|
stdx.debug.guard(opts.window_strides.len == N - 2, @src());
|
||||||
for (opts.window_strides) |s| meta.guard(0 < s, @src());
|
for (opts.window_strides) |s| stdx.debug.guard(0 < s, @src());
|
||||||
meta.guard(opts.lhs_dilation.len == N - 2, @src());
|
stdx.debug.guard(opts.lhs_dilation.len == N - 2, @src());
|
||||||
for (opts.lhs_dilation) |d| meta.guard(0 < d, @src());
|
for (opts.lhs_dilation) |d| stdx.debug.guard(0 < d, @src());
|
||||||
meta.guard(opts.rhs_dilation.len == N - 2, @src());
|
stdx.debug.guard(opts.rhs_dilation.len == N - 2, @src());
|
||||||
for (opts.rhs_dilation) |d| meta.guard(0 < d, @src());
|
for (opts.rhs_dilation) |d| stdx.debug.guard(0 < d, @src());
|
||||||
meta.guard(opts.window_reversal.len == N - 2, @src());
|
stdx.debug.guard(opts.window_reversal.len == N - 2, @src());
|
||||||
meta.guard(@rem(self.dim(opts.input_batch_dimension), opts.batch_group_count) == 0, @src());
|
stdx.debug.guard(@rem(self.dim(opts.input_batch_dimension), opts.batch_group_count) == 0, @src());
|
||||||
meta.guard(@rem(self.dim(opts.input_feature_dimension), opts.feature_group_count) == 0, @src());
|
stdx.debug.guard(@rem(self.dim(opts.input_feature_dimension), opts.feature_group_count) == 0, @src());
|
||||||
meta.guard(opts.input_spatial_dimensions.len == N - 2, @src());
|
stdx.debug.guard(opts.input_spatial_dimensions.len == N - 2, @src());
|
||||||
meta.guard(opts.input_batch_dimension != opts.input_feature_dimension, @src());
|
stdx.debug.guard(opts.input_batch_dimension != opts.input_feature_dimension, @src());
|
||||||
meta.guard(0 <= opts.input_batch_dimension and opts.input_batch_dimension < N, @src());
|
stdx.debug.guard(0 <= opts.input_batch_dimension and opts.input_batch_dimension < N, @src());
|
||||||
meta.guard(0 <= opts.input_feature_dimension and opts.input_feature_dimension < N, @src());
|
stdx.debug.guard(0 <= opts.input_feature_dimension and opts.input_feature_dimension < N, @src());
|
||||||
for (opts.input_spatial_dimensions, 0..) |d, i| {
|
for (opts.input_spatial_dimensions, 0..) |d, i| {
|
||||||
meta.guard(d != opts.input_batch_dimension, @src());
|
stdx.debug.guard(d != opts.input_batch_dimension, @src());
|
||||||
meta.guard(d != opts.input_feature_dimension, @src());
|
stdx.debug.guard(d != opts.input_feature_dimension, @src());
|
||||||
meta.guard(0 <= d and d < N, @src());
|
stdx.debug.guard(0 <= d and d < N, @src());
|
||||||
if (i < opts.input_spatial_dimensions.len - 1) continue;
|
if (i < opts.input_spatial_dimensions.len - 1) continue;
|
||||||
meta.guard(std.mem.indexOfScalar(i64, opts.input_spatial_dimensions[i + 1 ..], d) == null, @src());
|
stdx.debug.guard(std.mem.indexOfScalar(i64, opts.input_spatial_dimensions[i + 1 ..], d) == null, @src());
|
||||||
}
|
}
|
||||||
meta.guard(other.dim(opts.kernel_input_feature_dimension) == @divTrunc(self.dim(opts.input_feature_dimension), opts.feature_group_count), @src());
|
stdx.debug.guard(other.dim(opts.kernel_input_feature_dimension) == @divTrunc(self.dim(opts.input_feature_dimension), opts.feature_group_count), @src());
|
||||||
meta.guard(@rem(other.dim(opts.kernel_output_feature_dimension), opts.batch_group_count) == 0, @src());
|
stdx.debug.guard(@rem(other.dim(opts.kernel_output_feature_dimension), opts.batch_group_count) == 0, @src());
|
||||||
meta.guard(@rem(other.dim(opts.kernel_output_feature_dimension), opts.feature_group_count) == 0, @src());
|
stdx.debug.guard(@rem(other.dim(opts.kernel_output_feature_dimension), opts.feature_group_count) == 0, @src());
|
||||||
meta.guard(opts.kernel_spatial_dimensions.len == N - 2, @src());
|
stdx.debug.guard(opts.kernel_spatial_dimensions.len == N - 2, @src());
|
||||||
meta.guard(opts.kernel_input_feature_dimension != opts.kernel_output_feature_dimension, @src());
|
stdx.debug.guard(opts.kernel_input_feature_dimension != opts.kernel_output_feature_dimension, @src());
|
||||||
meta.guard(0 <= opts.kernel_input_feature_dimension and opts.kernel_input_feature_dimension < N, @src());
|
stdx.debug.guard(0 <= opts.kernel_input_feature_dimension and opts.kernel_input_feature_dimension < N, @src());
|
||||||
meta.guard(0 <= opts.kernel_output_feature_dimension and opts.kernel_output_feature_dimension < N, @src());
|
stdx.debug.guard(0 <= opts.kernel_output_feature_dimension and opts.kernel_output_feature_dimension < N, @src());
|
||||||
for (opts.kernel_spatial_dimensions, 0..) |d, i| {
|
for (opts.kernel_spatial_dimensions, 0..) |d, i| {
|
||||||
meta.guard(d != opts.kernel_input_feature_dimension, @src());
|
stdx.debug.guard(d != opts.kernel_input_feature_dimension, @src());
|
||||||
meta.guard(d != opts.kernel_output_feature_dimension, @src());
|
stdx.debug.guard(d != opts.kernel_output_feature_dimension, @src());
|
||||||
meta.guard(0 <= d and d < N, @src());
|
stdx.debug.guard(0 <= d and d < N, @src());
|
||||||
if (i < opts.kernel_spatial_dimensions.len - 1) continue;
|
if (i < opts.kernel_spatial_dimensions.len - 1) continue;
|
||||||
meta.guard(std.mem.indexOfScalar(i64, opts.kernel_spatial_dimensions[i + 1 ..], d) == null, @src());
|
stdx.debug.guard(std.mem.indexOfScalar(i64, opts.kernel_spatial_dimensions[i + 1 ..], d) == null, @src());
|
||||||
}
|
}
|
||||||
meta.guard(opts.output_spatial_dimensions.len == N - 2, @src());
|
stdx.debug.guard(opts.output_spatial_dimensions.len == N - 2, @src());
|
||||||
meta.guard(opts.output_batch_dimension != opts.output_feature_dimension, @src());
|
stdx.debug.guard(opts.output_batch_dimension != opts.output_feature_dimension, @src());
|
||||||
meta.guard(0 <= opts.output_batch_dimension and opts.output_batch_dimension < N, @src());
|
stdx.debug.guard(0 <= opts.output_batch_dimension and opts.output_batch_dimension < N, @src());
|
||||||
meta.guard(0 <= opts.output_feature_dimension and opts.output_feature_dimension < N, @src());
|
stdx.debug.guard(0 <= opts.output_feature_dimension and opts.output_feature_dimension < N, @src());
|
||||||
for (opts.output_spatial_dimensions, 0..) |d, i| {
|
for (opts.output_spatial_dimensions, 0..) |d, i| {
|
||||||
meta.guard(d != opts.output_batch_dimension, @src());
|
stdx.debug.guard(d != opts.output_batch_dimension, @src());
|
||||||
meta.guard(d != opts.output_feature_dimension, @src());
|
stdx.debug.guard(d != opts.output_feature_dimension, @src());
|
||||||
meta.guard(0 <= d and d < N, @src());
|
stdx.debug.guard(0 <= d and d < N, @src());
|
||||||
if (i < opts.output_spatial_dimensions.len - 1) continue;
|
if (i < opts.output_spatial_dimensions.len - 1) continue;
|
||||||
meta.guard(std.mem.indexOfScalar(i64, opts.output_spatial_dimensions[i + 1 ..], d) == null, @src());
|
stdx.debug.guard(std.mem.indexOfScalar(i64, opts.output_spatial_dimensions[i + 1 ..], d) == null, @src());
|
||||||
}
|
}
|
||||||
meta.guard(0 < opts.feature_group_count, @src());
|
stdx.debug.guard(0 < opts.feature_group_count, @src());
|
||||||
meta.guard(0 < opts.batch_group_count, @src());
|
stdx.debug.guard(0 < opts.batch_group_count, @src());
|
||||||
meta.guard(opts.feature_group_count == 1 or opts.batch_group_count == 1, @src());
|
stdx.debug.guard(opts.feature_group_count == 1 or opts.batch_group_count == 1, @src());
|
||||||
var used_opts = opts;
|
var used_opts = opts;
|
||||||
used_opts.pad_shape = &.{ @intCast(N - 2), 2 };
|
used_opts.pad_shape = &.{ @intCast(N - 2), 2 };
|
||||||
used_opts.precision_config = &.{ .DEFAULT, .DEFAULT };
|
used_opts.precision_config = &.{ .DEFAULT, .DEFAULT };
|
||||||
@ -1042,10 +1043,10 @@ pub const Tensor = struct {
|
|||||||
var batching_axes: [MAX_RANK][2]i8 = undefined;
|
var batching_axes: [MAX_RANK][2]i8 = undefined;
|
||||||
var n_batching: u8 = 0;
|
var n_batching: u8 = 0;
|
||||||
for (lhs._shape.tags(), 0..) |l, li| {
|
for (lhs._shape.tags(), 0..) |l, li| {
|
||||||
meta.assert(l != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, lhs });
|
stdx.debug.assert(l != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, lhs });
|
||||||
|
|
||||||
for (rhs._shape.tags(), 0..) |r, ri| {
|
for (rhs._shape.tags(), 0..) |r, ri| {
|
||||||
meta.assert(r != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, rhs });
|
stdx.debug.assert(r != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, rhs });
|
||||||
|
|
||||||
if (l == r) {
|
if (l == r) {
|
||||||
for (contracting_axes) |ct| {
|
for (contracting_axes) |ct| {
|
||||||
@ -1114,7 +1115,7 @@ pub const Tensor = struct {
|
|||||||
contracting_axes: []const [2]i8,
|
contracting_axes: []const [2]i8,
|
||||||
batching_axes: []const [2]i8,
|
batching_axes: []const [2]i8,
|
||||||
) Tensor {
|
) Tensor {
|
||||||
meta.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {} and {}", .{ lhs.dtype(), rhs.dtype() });
|
stdx.debug.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {} and {}", .{ lhs.dtype(), rhs.dtype() });
|
||||||
|
|
||||||
const Axes = std.BoundedArray(i64, MAX_RANK);
|
const Axes = std.BoundedArray(i64, MAX_RANK);
|
||||||
|
|
||||||
@ -1124,7 +1125,7 @@ pub const Tensor = struct {
|
|||||||
var rhs_batching_axes: Axes = .{};
|
var rhs_batching_axes: Axes = .{};
|
||||||
for (batching_axes) |b_axes| {
|
for (batching_axes) |b_axes| {
|
||||||
const l, const r = b_axes;
|
const l, const r = b_axes;
|
||||||
meta.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs });
|
stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs });
|
||||||
var t = lhs._shape.tag(l);
|
var t = lhs._shape.tag(l);
|
||||||
if (t == Shape.TagUnknown) t = rhs._shape.tag(r);
|
if (t == Shape.TagUnknown) t = rhs._shape.tag(r);
|
||||||
res_shape = res_shape.appendDim(lhs._shape.dim(l), t);
|
res_shape = res_shape.appendDim(lhs._shape.dim(l), t);
|
||||||
@ -1137,7 +1138,7 @@ pub const Tensor = struct {
|
|||||||
var rhs_contracting_axes: Axes = .{};
|
var rhs_contracting_axes: Axes = .{};
|
||||||
for (contracting_axes) |c_axes| {
|
for (contracting_axes) |c_axes| {
|
||||||
const l, const r = c_axes;
|
const l, const r = c_axes;
|
||||||
meta.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects contracting dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs });
|
stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects contracting dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs });
|
||||||
lhs_contracting_axes.appendAssumeCapacity(lhs._shape.axis(l));
|
lhs_contracting_axes.appendAssumeCapacity(lhs._shape.axis(l));
|
||||||
rhs_contracting_axes.appendAssumeCapacity(rhs._shape.axis(r));
|
rhs_contracting_axes.appendAssumeCapacity(rhs._shape.axis(r));
|
||||||
}
|
}
|
||||||
@ -1353,7 +1354,7 @@ pub const Tensor = struct {
|
|||||||
else
|
else
|
||||||
toI64(axes__);
|
toI64(axes__);
|
||||||
|
|
||||||
meta.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {} and {}", .{ self.rank(), permutation.len });
|
stdx.debug.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {} and {}", .{ self.rank(), permutation.len });
|
||||||
|
|
||||||
if (std.mem.eql(i64, permutation, no_op[0..self.rank()])) {
|
if (std.mem.eql(i64, permutation, no_op[0..self.rank()])) {
|
||||||
return self;
|
return self;
|
||||||
@ -1386,7 +1387,7 @@ pub const Tensor = struct {
|
|||||||
///
|
///
|
||||||
/// unflatten((d0, d1, axis_m, d3), 2, n) -> (d0, d1, n, d2_m, d3)
|
/// unflatten((d0, d1, axis_m, d3), 2, n) -> (d0, d1, n, d2_m, d3)
|
||||||
pub fn unflatten(self: Tensor, axis_: i8, n: i64) Tensor {
|
pub fn unflatten(self: Tensor, axis_: i8, n: i64) Tensor {
|
||||||
meta.assert(self.rank() < Tensor.MAX_RANK, "unflatten expects input tensor rank to be less than {}, got {}", .{ Tensor.MAX_RANK, self.rank() });
|
stdx.debug.assert(self.rank() < Tensor.MAX_RANK, "unflatten expects input tensor rank to be less than {}, got {}", .{ Tensor.MAX_RANK, self.rank() });
|
||||||
|
|
||||||
const a = if (axis_ >= 0) self.axis(axis_) else self.axis(axis_) + 1;
|
const a = if (axis_ >= 0) self.axis(axis_) else self.axis(axis_) + 1;
|
||||||
const new_dim = std.math.divExact(i64, self.dim(a), n) catch std.debug.panic("unflatten expects chosen dimension to be divisible by 'n' but {} is not divisible by {}", .{ self.dim(a), n });
|
const new_dim = std.math.divExact(i64, self.dim(a), n) catch std.debug.panic("unflatten expects chosen dimension to be divisible by 'n' but {} is not divisible by {}", .{ self.dim(a), n });
|
||||||
@ -1443,7 +1444,7 @@ pub const Tensor = struct {
|
|||||||
pub fn flatten(self: Tensor, axis_: anytype) Tensor {
|
pub fn flatten(self: Tensor, axis_: anytype) Tensor {
|
||||||
const old_shape = self._shape;
|
const old_shape = self._shape;
|
||||||
const a = self.axis(axis_);
|
const a = self.axis(axis_);
|
||||||
// meta.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis });
|
// stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis });
|
||||||
const new_shape = old_shape.remove(a + 1).set(a, old_shape.dim(a) * old_shape.dim(a + 1));
|
const new_shape = old_shape.remove(a + 1).set(a, old_shape.dim(a) * old_shape.dim(a + 1));
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "axis={}", .{axis_});
|
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "axis={}", .{axis_});
|
||||||
@ -1483,7 +1484,7 @@ pub const Tensor = struct {
|
|||||||
var res_shape: Shape = self._shape;
|
var res_shape: Shape = self._shape;
|
||||||
|
|
||||||
for (slices, 0..) |s, a| {
|
for (slices, 0..) |s, a| {
|
||||||
meta.assert(s.step > 0, "slice expects 'step' to be positive, got {} at index {}", .{ s.step, a });
|
stdx.debug.assert(s.step > 0, "slice expects 'step' to be positive, got {} at index {}", .{ s.step, a });
|
||||||
|
|
||||||
const args: Slice = .{
|
const args: Slice = .{
|
||||||
.start = self.wrapIndex(a, s.start),
|
.start = self.wrapIndex(a, s.start),
|
||||||
@ -1549,7 +1550,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Concatenates the input Tensors along the given axis.
|
/// Concatenates the input Tensors along the given axis.
|
||||||
pub fn concatenate(tensors: []const Tensor, axis_: anytype) Tensor {
|
pub fn concatenate(tensors: []const Tensor, axis_: anytype) Tensor {
|
||||||
meta.assert(tensors.len <= 32, "concatenate only supports up to 32 tensors, got {}", .{tensors.len});
|
stdx.debug.assert(tensors.len <= 32, "concatenate only supports up to 32 tensors, got {}", .{tensors.len});
|
||||||
var buffer: [32]mlir.Value = undefined;
|
var buffer: [32]mlir.Value = undefined;
|
||||||
std.debug.assert(tensors.len <= buffer.len);
|
std.debug.assert(tensors.len <= buffer.len);
|
||||||
std.debug.assert(tensors.len > 0);
|
std.debug.assert(tensors.len > 0);
|
||||||
@ -1576,13 +1577,13 @@ pub const Tensor = struct {
|
|||||||
/// - Tensor.stack(&.{x, y, z}, .last, .layers) -> .{ .a, .b, .c, .layers }
|
/// - Tensor.stack(&.{x, y, z}, .last, .layers) -> .{ .a, .b, .c, .layers }
|
||||||
pub fn stack(tensors: []const Tensor, axis_: anytype, tag: anytype) Tensor {
|
pub fn stack(tensors: []const Tensor, axis_: anytype, tag: anytype) Tensor {
|
||||||
// Note: we could ask the compilation context for some memory instead of stack allocating
|
// Note: we could ask the compilation context for some memory instead of stack allocating
|
||||||
meta.assert(tensors.len <= 32, "stack only supports up to 32 tensors, got {}", .{tensors.len});
|
stdx.debug.assert(tensors.len <= 32, "stack only supports up to 32 tensors, got {}", .{tensors.len});
|
||||||
|
|
||||||
const shape0 = tensors[0]._shape;
|
const shape0 = tensors[0]._shape;
|
||||||
const res_shape = shape0.insertTag(axis_, 1, tag);
|
const res_shape = shape0.insertTag(axis_, 1, tag);
|
||||||
|
|
||||||
for (tensors[1..]) |tensor| {
|
for (tensors[1..]) |tensor| {
|
||||||
meta.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {} and {}", .{ tensor._shape, shape0 });
|
stdx.debug.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {} and {}", .{ tensor._shape, shape0 });
|
||||||
}
|
}
|
||||||
|
|
||||||
var reshaped: [32]Tensor = undefined;
|
var reshaped: [32]Tensor = undefined;
|
||||||
@ -1623,7 +1624,7 @@ pub const Tensor = struct {
|
|||||||
/// Repeats a Tensor several times along the given axes.
|
/// Repeats a Tensor several times along the given axes.
|
||||||
pub fn repeat(self: Tensor, n_reps: []const u63) Tensor {
|
pub fn repeat(self: Tensor, n_reps: []const u63) Tensor {
|
||||||
// TODO: this should support the tagged syntax: x.repeat(.{ .a = 3, .b = 2});
|
// TODO: this should support the tagged syntax: x.repeat(.{ .a = 3, .b = 2});
|
||||||
meta.assert(n_reps.len == self.rank(), "repeat expects tensor rank and 'n_reps' length to be equal, got {} and {}", .{ self.rank(), n_reps.len });
|
stdx.debug.assert(n_reps.len == self.rank(), "repeat expects tensor rank and 'n_reps' length to be equal, got {} and {}", .{ self.rank(), n_reps.len });
|
||||||
|
|
||||||
var res = self;
|
var res = self;
|
||||||
for (n_reps, 0..) |n_rep, a| {
|
for (n_reps, 0..) |n_rep, a| {
|
||||||
@ -1647,7 +1648,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Repeats in line each value along the given axes.
|
/// Repeats in line each value along the given axes.
|
||||||
pub fn stutter(self: Tensor, n_reps: []const u63) Tensor {
|
pub fn stutter(self: Tensor, n_reps: []const u63) Tensor {
|
||||||
meta.assert(n_reps.len == self.rank(), "stutter expects tensor rank and 'n_reps' length to be equal, got {} and {}", .{ self.rank(), n_reps.len });
|
stdx.debug.assert(n_reps.len == self.rank(), "stutter expects tensor rank and 'n_reps' length to be equal, got {} and {}", .{ self.rank(), n_reps.len });
|
||||||
|
|
||||||
var res = self;
|
var res = self;
|
||||||
for (n_reps, 0..) |n_rep, a| {
|
for (n_reps, 0..) |n_rep, a| {
|
||||||
@ -1724,8 +1725,8 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns a Tensor containing evenly spaced values within a given interval.
|
/// Returns a Tensor containing evenly spaced values within a given interval.
|
||||||
pub fn arange(args: ArangeArgs, dt: DataType) Tensor {
|
pub fn arange(args: ArangeArgs, dt: DataType) Tensor {
|
||||||
meta.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end });
|
stdx.debug.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end });
|
||||||
meta.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});
|
stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});
|
||||||
|
|
||||||
const ctx = CompilationContext.current();
|
const ctx = CompilationContext.current();
|
||||||
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "{}, dtype={}", .{ args, dt });
|
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "{}, dtype={}", .{ args, dt });
|
||||||
@ -1770,9 +1771,9 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns a Tensor containing 'args.steps' values evenly spaced from 'args.start' to 'args.end', inclusive.
|
/// Returns a Tensor containing 'args.steps' values evenly spaced from 'args.start' to 'args.end', inclusive.
|
||||||
pub fn linspace(args: LinspaceArgs, dt: DataType) Tensor {
|
pub fn linspace(args: LinspaceArgs, dt: DataType) Tensor {
|
||||||
meta.assert(args.start < args.end, "linspace expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end });
|
stdx.debug.assert(args.start < args.end, "linspace expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end });
|
||||||
meta.assert(args.steps > 0, "linspace expects 'args.steps' to be positive, got {}", .{args.steps});
|
stdx.debug.assert(args.steps > 0, "linspace expects 'args.steps' to be positive, got {}", .{args.steps});
|
||||||
meta.assert(dt.isFloat(), "linspace expects type to be a float, got {} (hint: use arange instead)", .{dt});
|
stdx.debug.assert(dt.isFloat(), "linspace expects type to be a float, got {} (hint: use arange instead)", .{dt});
|
||||||
|
|
||||||
const ctx = CompilationContext.current();
|
const ctx = CompilationContext.current();
|
||||||
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "linspace({}, dtype={})", .{ args, dt });
|
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "linspace({}, dtype={})", .{ args, dt });
|
||||||
@ -1824,7 +1825,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns a Tensor containing the result of the outer product between the input Tensors.
|
/// Returns a Tensor containing the result of the outer product between the input Tensors.
|
||||||
pub fn outer(self: Tensor, other: Tensor) Tensor {
|
pub fn outer(self: Tensor, other: Tensor) Tensor {
|
||||||
meta.assert(self.rank() < 2 and other.rank() < 2 and self.rank() + other.rank() != 0, "outer expects tensor ranks to be at most 1, got {} and {}", .{ self.rank(), other.rank() });
|
stdx.debug.assert(self.rank() < 2 and other.rank() < 2 and self.rank() + other.rank() != 0, "outer expects tensor ranks to be at most 1, got {} and {}", .{ self.rank(), other.rank() });
|
||||||
|
|
||||||
if (self.rank() + other.rank() == 1) {
|
if (self.rank() + other.rank() == 1) {
|
||||||
return self.mul(other);
|
return self.mul(other);
|
||||||
@ -1856,7 +1857,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Broadcasts a Tensor to the given shape, adding axes at the beginning.
|
/// Broadcasts a Tensor to the given shape, adding axes at the beginning.
|
||||||
pub fn broadcastLeft(self: Tensor, output_shape: Shape) Tensor {
|
pub fn broadcastLeft(self: Tensor, output_shape: Shape) Tensor {
|
||||||
meta.assert(self.rank() <= output_shape.rank(), "broadcastLeft expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() });
|
stdx.debug.assert(self.rank() <= output_shape.rank(), "broadcastLeft expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() });
|
||||||
|
|
||||||
const a = output_shape.rank() - self.rank();
|
const a = output_shape.rank() - self.rank();
|
||||||
if (self.rank() == output_shape.rank() and std.mem.eql(i64, self.dims(), output_shape.dims())) {
|
if (self.rank() == output_shape.rank() and std.mem.eql(i64, self.dims(), output_shape.dims())) {
|
||||||
@ -1868,7 +1869,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Broadcasts a Tensor to the given shape, adding axes at the end.
|
/// Broadcasts a Tensor to the given shape, adding axes at the end.
|
||||||
pub fn broadcastRight(self: Tensor, output_shape: Shape) Tensor {
|
pub fn broadcastRight(self: Tensor, output_shape: Shape) Tensor {
|
||||||
meta.assert(self.rank() <= output_shape.rank(), "broadcastRight expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() });
|
stdx.debug.assert(self.rank() <= output_shape.rank(), "broadcastRight expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() });
|
||||||
|
|
||||||
if (self.rank() == output_shape.rank() and self._shape.eql(output_shape)) {
|
if (self.rank() == output_shape.rank() and self._shape.eql(output_shape)) {
|
||||||
return self;
|
return self;
|
||||||
@ -1967,7 +1968,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Appends a 1-dim axis, with the given tag.
|
/// Appends a 1-dim axis, with the given tag.
|
||||||
pub fn appendAxes(self: Tensor, t: anytype) Tensor {
|
pub fn appendAxes(self: Tensor, t: anytype) Tensor {
|
||||||
meta.assert(self.rank() < Tensor.MAX_RANK - t.len, "appendAxis expects tensor rank to be small enough in order to extend it, got {} and {} (max is {})", .{ self.rank(), t.len, Tensor.MAX_RANK });
|
stdx.debug.assert(self.rank() < Tensor.MAX_RANK - t.len, "appendAxis expects tensor rank to be small enough in order to extend it, got {} and {} (max is {})", .{ self.rank(), t.len, Tensor.MAX_RANK });
|
||||||
|
|
||||||
return self.insertAxes(.last, t);
|
return self.insertAxes(.last, t);
|
||||||
}
|
}
|
||||||
@ -1975,7 +1976,7 @@ pub const Tensor = struct {
|
|||||||
/// Drops a 1-dim axis at the given index
|
/// Drops a 1-dim axis at the given index
|
||||||
pub fn squeeze(self: Tensor, axis_: anytype) Tensor {
|
pub fn squeeze(self: Tensor, axis_: anytype) Tensor {
|
||||||
const a = self.axis(axis_);
|
const a = self.axis(axis_);
|
||||||
meta.assert(self.dim(a) == 1, "squeeze expects axis to be squeezed to have a dimension of 1, got {}", .{self.dim(a)});
|
stdx.debug.assert(self.dim(a) == 1, "squeeze expects axis to be squeezed to have a dimension of 1, got {}", .{self.dim(a)});
|
||||||
|
|
||||||
const new_shape = self._shape.remove(a);
|
const new_shape = self._shape.remove(a);
|
||||||
// log.debug("squeeze({}, {d}={d}) -> ({})", .{ self, axis, a, new_shape });
|
// log.debug("squeeze({}, {d}={d}) -> ({})", .{ self, axis, a, new_shape });
|
||||||
@ -2023,10 +2024,10 @@ pub const Tensor = struct {
|
|||||||
// scoped_log.debug("gatherValues({}, {any}, {})", .{ self, coord_axes, indices });
|
// scoped_log.debug("gatherValues({}, {any}, {})", .{ self, coord_axes, indices });
|
||||||
const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes);
|
const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes);
|
||||||
|
|
||||||
meta.assert(coord_axes_.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{});
|
stdx.debug.assert(coord_axes_.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{});
|
||||||
for (coord_axes_.constSlice(), 0..) |a, i| {
|
for (coord_axes_.constSlice(), 0..) |a, i| {
|
||||||
if (i > 0) {
|
if (i > 0) {
|
||||||
meta.assert(a == coord_axes_.get(i - 1) + 1, "gatherValues expects 'coord_axes' to be sequential. But {any} aren't sequential in {}", .{ coord_axes, self });
|
stdx.debug.assert(a == coord_axes_.get(i - 1) + 1, "gatherValues expects 'coord_axes' to be sequential. But {any} aren't sequential in {}", .{ coord_axes, self });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2040,7 +2041,7 @@ pub const Tensor = struct {
|
|||||||
// Note: tags are required for batching.
|
// Note: tags are required for batching.
|
||||||
self_kind.appendAssumeCapacity(.batching);
|
self_kind.appendAssumeCapacity(.batching);
|
||||||
indices_batch_axes.appendAssumeCapacity(id_ax);
|
indices_batch_axes.appendAssumeCapacity(id_ax);
|
||||||
meta.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. Axis {s} has been found both in 'self={any}', in 'coord_axes_={any}' and in 'indices={}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices });
|
stdx.debug.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. Axis {s} has been found both in 'self={any}', in 'coord_axes_={any}' and in 'indices={}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices });
|
||||||
} else if (maybe_coord_ax) |_| {
|
} else if (maybe_coord_ax) |_| {
|
||||||
// for gatherValues we collapsed all gathered axes
|
// for gatherValues we collapsed all gathered axes
|
||||||
// (contrary to gatherSlices where we collapse none)
|
// (contrary to gatherSlices where we collapse none)
|
||||||
@ -2057,7 +2058,7 @@ pub const Tensor = struct {
|
|||||||
indices.rank()
|
indices.rank()
|
||||||
else blk: {
|
else blk: {
|
||||||
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
|
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
|
||||||
meta.assert(indices.dim(ax) == coord_axes_.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ coord_axes, coord_axes_.len, indices });
|
stdx.debug.assert(indices.dim(ax) == coord_axes_.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ coord_axes, coord_axes_.len, indices });
|
||||||
break :blk ax;
|
break :blk ax;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -2124,7 +2125,7 @@ pub const Tensor = struct {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const mlir_shape = fromMlirValue(gather_op.result(0)).shape();
|
const mlir_shape = fromMlirValue(gather_op.result(0)).shape();
|
||||||
meta.assert(mlir_shape.eql(res_shape), "gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={}, indices={}. You should transpose one or the other.", .{ self, indices });
|
stdx.debug.assert(mlir_shape.eql(res_shape), "gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={}, indices={}. You should transpose one or the other.", .{ self, indices });
|
||||||
return _result(res_shape, gather_op.result(0));
|
return _result(res_shape, gather_op.result(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2201,16 +2202,16 @@ pub const Tensor = struct {
|
|||||||
const tagged_api = slice_shape.isFullyTagged();
|
const tagged_api = slice_shape.isFullyTagged();
|
||||||
if (tagged_api) {
|
if (tagged_api) {
|
||||||
for (slice_shape.tags()) |t| {
|
for (slice_shape.tags()) |t| {
|
||||||
meta.assert(self._shape.hasTag(t) != null, "gatherSlices expects `slices_shape` to only use tags from `self`. But {s} wasn't found in {}", .{ t, self });
|
stdx.debug.assert(self._shape.hasTag(t) != null, "gatherSlices expects `slices_shape` to only use tags from `self`. But {s} wasn't found in {}", .{ t, self });
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// For untagged api, we require all slices to be specified.
|
// For untagged api, we require all slices to be specified.
|
||||||
// Note: we could relax this and right align the slice.
|
// Note: we could relax this and right align the slice.
|
||||||
meta.assert(slice_shape.rank() == self.rank(), "gatherSlices expects `slice_shape.rank()` to match `self.rank()`. Got: gatherSlices({}, slice={_}). To avoid specifying all axes in `slice_shape`, you can use tags.", .{ self, slice_shape });
|
stdx.debug.assert(slice_shape.rank() == self.rank(), "gatherSlices expects `slice_shape.rank()` to match `self.rank()`. Got: gatherSlices({}, slice={_}). To avoid specifying all axes in `slice_shape`, you can use tags.", .{ self, slice_shape });
|
||||||
}
|
}
|
||||||
|
|
||||||
const index_coord_axis = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
|
const index_coord_axis = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
|
||||||
meta.assert(indices.dim(index_coord_axis) == slice_shape.rank(), "gatherSlices({}, slice={_}, indices) expects 'indices' to be a tensor [..., {}], got {}", .{ self, slice_shape, slice_shape.rank(), indices });
|
stdx.debug.assert(indices.dim(index_coord_axis) == slice_shape.rank(), "gatherSlices({}, slice={_}, indices) expects 'indices' to be a tensor [..., {}], got {}", .{ self, slice_shape, slice_shape.rank(), indices });
|
||||||
|
|
||||||
// Compute result shape
|
// Compute result shape
|
||||||
var res_shape = indices._shape.remove(index_coord_axis).withDtype(self.dtype());
|
var res_shape = indices._shape.remove(index_coord_axis).withDtype(self.dtype());
|
||||||
@ -2228,12 +2229,12 @@ pub const Tensor = struct {
|
|||||||
self_batch_axes.appendAssumeCapacity(@intCast(self_ax));
|
self_batch_axes.appendAssumeCapacity(@intCast(self_ax));
|
||||||
indices_batch_axes.appendAssumeCapacity(indices._shape.axis(t));
|
indices_batch_axes.appendAssumeCapacity(indices._shape.axis(t));
|
||||||
slice_dims.set(self_ax, 1);
|
slice_dims.set(self_ax, 1);
|
||||||
meta.assert(slice_shape.hasTag(t) == null, "gatherSlices expect axes to be either batches or slices axes. Axis {s} has been found both in `slices={_}` and `indices={}`", .{ t, slice_shape, indices });
|
stdx.debug.assert(slice_shape.hasTag(t) == null, "gatherSlices expect axes to be either batches or slices axes. Axis {s} has been found both in `slices={_}` and `indices={}`", .{ t, slice_shape, indices });
|
||||||
} else if (maybe_slice_ax) |slice_ax| {
|
} else if (maybe_slice_ax) |slice_ax| {
|
||||||
// Specified axes contains the start offset of the slices,
|
// Specified axes contains the start offset of the slices,
|
||||||
// and are collected in `start_index_map`.
|
// and are collected in `start_index_map`.
|
||||||
const slice_dim = slice_shape.dim(slice_ax);
|
const slice_dim = slice_shape.dim(slice_ax);
|
||||||
meta.assert(slice_dim <= self._shape.dim(self_ax), "gatherSlices expects `slice_shape` to be smaller than `self.shape()`. On axis {s}, got {} > {}.", .{ t, slice_shape, self._shape });
|
stdx.debug.assert(slice_dim <= self._shape.dim(self_ax), "gatherSlices expects `slice_shape` to be smaller than `self.shape()`. On axis {s}, got {} > {}.", .{ t, slice_shape, self._shape });
|
||||||
slice_dims.set(self_ax, slice_dim);
|
slice_dims.set(self_ax, slice_dim);
|
||||||
res_shape = res_shape.appendDim(slice_dim, t);
|
res_shape = res_shape.appendDim(slice_dim, t);
|
||||||
start_index_map.appendAssumeCapacity(@intCast(self_ax));
|
start_index_map.appendAssumeCapacity(@intCast(self_ax));
|
||||||
@ -2395,7 +2396,7 @@ pub const Tensor = struct {
|
|||||||
const loc = @src();
|
const loc = @src();
|
||||||
// scoped_log.debug("scatterSlices({}, {any}, {}, {})", .{ self, coord_axes, indices, updates });
|
// scoped_log.debug("scatterSlices({}, {any}, {}, {})", .{ self, coord_axes, indices, updates });
|
||||||
|
|
||||||
meta.assert(self.dtype() == updates.dtype(), "scatterSlices expects input and 'updates' tensors to be of the same type, got {} and {}", .{ self.dtype(), updates.dtype() });
|
stdx.debug.assert(self.dtype() == updates.dtype(), "scatterSlices expects input and 'updates' tensors to be of the same type, got {} and {}", .{ self.dtype(), updates.dtype() });
|
||||||
|
|
||||||
const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes);
|
const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes);
|
||||||
const AxisKind = enum { batching, update_window, inserted_window, window_id };
|
const AxisKind = enum { batching, update_window, inserted_window, window_id };
|
||||||
@ -2420,7 +2421,7 @@ pub const Tensor = struct {
|
|||||||
indices.rank()
|
indices.rank()
|
||||||
else blk: {
|
else blk: {
|
||||||
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
|
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
|
||||||
meta.assert(indices.dim(ax) == coord_axes_.len, "scatterSlices({}, coord_axes={any}, indices, updates) expects 'indices' to be a tensor [..., {}], got {}", .{ self, coord_axes, coord_axes_.len, indices });
|
stdx.debug.assert(indices.dim(ax) == coord_axes_.len, "scatterSlices({}, coord_axes={any}, indices, updates) expects 'indices' to be a tensor [..., {}], got {}", .{ self, coord_axes, coord_axes_.len, indices });
|
||||||
|
|
||||||
break :blk ax;
|
break :blk ax;
|
||||||
};
|
};
|
||||||
@ -2435,7 +2436,7 @@ pub const Tensor = struct {
|
|||||||
if (self_kind.get(self_ax) == .batching) {
|
if (self_kind.get(self_ax) == .batching) {
|
||||||
up_kind.appendAssumeCapacity(.batching);
|
up_kind.appendAssumeCapacity(.batching);
|
||||||
} else {
|
} else {
|
||||||
meta.assert(updates.dim(up_ax) <= self.dim(self_ax), "scatterSlices expects the slices described in 'updates' to fit inside 'self', but along axis .{s} it doesn't. Got self={}, updates={}.", .{ t, self, updates });
|
stdx.debug.assert(updates.dim(up_ax) <= self.dim(self_ax), "scatterSlices expects the slices described in 'updates' to fit inside 'self', but along axis .{s} it doesn't. Got self={}, updates={}.", .{ t, self, updates });
|
||||||
up_kind.appendAssumeCapacity(.update_window);
|
up_kind.appendAssumeCapacity(.update_window);
|
||||||
}
|
}
|
||||||
} else if (t == Shape.TagUnknown or indices._shape.hasTag(t) != null) {
|
} else if (t == Shape.TagUnknown or indices._shape.hasTag(t) != null) {
|
||||||
@ -2446,9 +2447,9 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
const n_indices_axes = updates.rank() - _collectAxes(AxisKind, up_kind, .update_window).len;
|
const n_indices_axes = updates.rank() - _collectAxes(AxisKind, up_kind, .update_window).len;
|
||||||
if (single_coord) {
|
if (single_coord) {
|
||||||
meta.assert(n_indices_axes == indices.rank(), "scatterSlices({}, {any}) expects 'updates' to contain all axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates });
|
stdx.debug.assert(n_indices_axes == indices.rank(), "scatterSlices({}, {any}) expects 'updates' to contain all axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates });
|
||||||
} else {
|
} else {
|
||||||
meta.assert(n_indices_axes == indices.rank() - 1, "scatterSlices({}, {any}) expects 'updates' to contain all-but-last axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates });
|
stdx.debug.assert(n_indices_axes == indices.rank() - 1, "scatterSlices({}, {any}) expects 'updates' to contain all-but-last axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates });
|
||||||
}
|
}
|
||||||
|
|
||||||
const ctx = self.getContext();
|
const ctx = self.getContext();
|
||||||
@ -2671,7 +2672,7 @@ pub const Tensor = struct {
|
|||||||
/// * bubbles up Nan
|
/// * bubbles up Nan
|
||||||
/// * in case of equality the smallest index matching the maximum
|
/// * in case of equality the smallest index matching the maximum
|
||||||
pub fn argMax(x: Tensor, axis_: anytype, index_dtype: DataType) ArgMaxRes {
|
pub fn argMax(x: Tensor, axis_: anytype, index_dtype: DataType) ArgMaxRes {
|
||||||
meta.assert(index_dtype.isInteger(), "argMax expect index type to be an integer, got {}", .{index_dtype});
|
stdx.debug.assert(index_dtype.isInteger(), "argMax expect index type to be an integer, got {}", .{index_dtype});
|
||||||
|
|
||||||
const a = x.axis(axis_);
|
const a = x.axis(axis_);
|
||||||
|
|
||||||
@ -2870,7 +2871,7 @@ pub const Tensor = struct {
|
|||||||
padding: [2][2]i64 = .{ .{ 0, 0 }, .{ 0, 0 } },
|
padding: [2][2]i64 = .{ .{ 0, 0 }, .{ 0, 0 } },
|
||||||
}) MaxPoolRes {
|
}) MaxPoolRes {
|
||||||
// TODO: rewrite using modern ZML
|
// TODO: rewrite using modern ZML
|
||||||
meta.guard(self.rank() == 3 or self.rank() == 4, @src());
|
stdx.debug.guard(self.rank() == 3 or self.rank() == 4, @src());
|
||||||
|
|
||||||
// TODO: support maxPool on non last axis
|
// TODO: support maxPool on non last axis
|
||||||
// Note: the problem is initPoolArg assuming last axis
|
// Note: the problem is initPoolArg assuming last axis
|
||||||
@ -3004,14 +3005,14 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn split(self: Tensor, allocator: std.mem.Allocator, split_size_or_sections: []const i64, axis_: i64) ![]Tensor {
|
pub fn split(self: Tensor, allocator: std.mem.Allocator, split_size_or_sections: []const i64, axis_: i64) ![]Tensor {
|
||||||
meta.assert(split_size_or_sections.len > 0, "split expects 'split_size_or_sections' length to be positive, got {}", .{split_size_or_sections.len});
|
stdx.debug.assert(split_size_or_sections.len > 0, "split expects 'split_size_or_sections' length to be positive, got {}", .{split_size_or_sections.len});
|
||||||
|
|
||||||
const a = self.axis(axis_);
|
const a = self.axis(axis_);
|
||||||
const length = self.dim(a);
|
const length = self.dim(a);
|
||||||
if (split_size_or_sections.len != 1) {
|
if (split_size_or_sections.len != 1) {
|
||||||
var split_sum: i64 = 0;
|
var split_sum: i64 = 0;
|
||||||
for (split_size_or_sections) |n| split_sum += n;
|
for (split_size_or_sections) |n| split_sum += n;
|
||||||
meta.assert(split_sum == length, "split expects sum of 'split_size_or_sections' values and axis dimension to be equal, got {} and {}", .{ split_sum, length });
|
stdx.debug.assert(split_sum == length, "split expects sum of 'split_size_or_sections' values and axis dimension to be equal, got {} and {}", .{ split_sum, length });
|
||||||
}
|
}
|
||||||
|
|
||||||
const res = try allocator.alloc(Tensor, split_size_or_sections.len);
|
const res = try allocator.alloc(Tensor, split_size_or_sections.len);
|
||||||
@ -3029,7 +3030,7 @@ pub const Tensor = struct {
|
|||||||
/// Note: this doesn't support tagging, if you have tags,
|
/// Note: this doesn't support tagging, if you have tags,
|
||||||
/// you should use `dynamicSlice` directly.
|
/// you should use `dynamicSlice` directly.
|
||||||
pub fn dynamicSlice1d(self: Tensor, axis_: i8, len: u63, start_indices: Tensor) Tensor {
|
pub fn dynamicSlice1d(self: Tensor, axis_: i8, len: u63, start_indices: Tensor) Tensor {
|
||||||
meta.assert(start_indices.rank() == 0, "dynamicSlice1d expects 'start_indices' tensor rank to be equal to 0, got {}", .{start_indices.rank()});
|
stdx.debug.assert(start_indices.rank() == 0, "dynamicSlice1d expects 'start_indices' tensor rank to be equal to 0, got {}", .{start_indices.rank()});
|
||||||
|
|
||||||
const a = self.axis(axis_);
|
const a = self.axis(axis_);
|
||||||
const new_shape = self._shape.set(a, len);
|
const new_shape = self._shape.set(a, len);
|
||||||
@ -3087,17 +3088,17 @@ pub const Tensor = struct {
|
|||||||
const offset = slice_.start;
|
const offset = slice_.start;
|
||||||
const len = slice_.len;
|
const len = slice_.len;
|
||||||
if (slices_tags.len == 0) {
|
if (slices_tags.len == 0) {
|
||||||
meta.assert(self.rank() == slices.len, "dynamicSlice expects tensor rank and 'slices_' length to be equal, got {} and {}", .{ self.rank(), slices.len });
|
stdx.debug.assert(self.rank() == slices.len, "dynamicSlice expects tensor rank and 'slices_' length to be equal, got {} and {}", .{ self.rank(), slices.len });
|
||||||
|
|
||||||
offset_values[i] = offset.value();
|
offset_values[i] = offset.value();
|
||||||
res_shape._dims.set(i, len);
|
res_shape._dims.set(i, len);
|
||||||
|
|
||||||
meta.assert(len <= self.dim(i), "dynamicSlice expects slices 'len' to be less than or equal to their corresponding dimension in input tensor, got {} and {} for index {}", .{ len, self.dim(i), i });
|
stdx.debug.assert(len <= self.dim(i), "dynamicSlice expects slices 'len' to be less than or equal to their corresponding dimension in input tensor, got {} and {} for index {}", .{ len, self.dim(i), i });
|
||||||
} else {
|
} else {
|
||||||
const t = slices_tags.get(i);
|
const t = slices_tags.get(i);
|
||||||
const a = res_shape.hasTag(t) orelse meta.panic("dynamicSlice expects input tensor to have tags used in 'slices_' but {s} is missing (input shape is {})", .{ t, self._shape });
|
const a = res_shape.hasTag(t) orelse stdx.debug.panic("dynamicSlice expects input tensor to have tags used in 'slices_' but {s} is missing (input shape is {})", .{ t, self._shape });
|
||||||
|
|
||||||
meta.assert(len <= self.dim(a), "dynamicSlice expects slices 'len' to be less than their corresponding dimension in input tensor, got {} and {} for axis {s}", .{ len, self.dim(a), t });
|
stdx.debug.assert(len <= self.dim(a), "dynamicSlice expects slices 'len' to be less than their corresponding dimension in input tensor, got {} and {} for axis {s}", .{ len, self.dim(a), t });
|
||||||
|
|
||||||
offset_values[a] = offset.value();
|
offset_values[a] = offset.value();
|
||||||
res_shape._dims.set(a, len);
|
res_shape._dims.set(a, len);
|
||||||
@ -3149,12 +3150,12 @@ pub const Tensor = struct {
|
|||||||
/// ```
|
/// ```
|
||||||
pub fn dynamicUpdateSlice(self: Tensor, offset_: anytype, update_: Tensor) Tensor {
|
pub fn dynamicUpdateSlice(self: Tensor, offset_: anytype, update_: Tensor) Tensor {
|
||||||
// TODO: add updateSlice for when the offset isn't dynamic
|
// TODO: add updateSlice for when the offset isn't dynamic
|
||||||
meta.assert(self.dtype() == update_.dtype(), "dynamicUpdateSlice expects input and 'update_' tensors to be of the same type, got {} and {}", .{ self.dtype(), update_.dtype() });
|
stdx.debug.assert(self.dtype() == update_.dtype(), "dynamicUpdateSlice expects input and 'update_' tensors to be of the same type, got {} and {}", .{ self.dtype(), update_.dtype() });
|
||||||
|
|
||||||
const offset, const offset_tags = Shape.parseStruct(Tensor, offset_);
|
const offset, const offset_tags = Shape.parseStruct(Tensor, offset_);
|
||||||
// log.debug("offset: {any}, offset_tags: {any}", .{ offset, offset_tags });
|
// log.debug("offset: {any}, offset_tags: {any}", .{ offset, offset_tags });
|
||||||
for (offset.constSlice(), 0..) |start_idx, i| {
|
for (offset.constSlice(), 0..) |start_idx, i| {
|
||||||
meta.assert(start_idx.rank() == 0, "dynamicUpdateSlice expects 'offset_' tensor ranks to be equal to 0, got {} at index {}", .{ start_idx.rank(), i });
|
stdx.debug.assert(start_idx.rank() == 0, "dynamicUpdateSlice expects 'offset_' tensor ranks to be equal to 0, got {} at index {}", .{ start_idx.rank(), i });
|
||||||
}
|
}
|
||||||
|
|
||||||
const tagged_api = update_._shape.isFullyTagged() and self._shape.isFullyTagged() and offset_tags.len > 0;
|
const tagged_api = update_._shape.isFullyTagged() and self._shape.isFullyTagged() and offset_tags.len > 0;
|
||||||
@ -3164,14 +3165,14 @@ pub const Tensor = struct {
|
|||||||
if (tagged_api) {
|
if (tagged_api) {
|
||||||
// Check that all update tags are known.
|
// Check that all update tags are known.
|
||||||
for (update._shape._tags.constSlice()) |t| {
|
for (update._shape._tags.constSlice()) |t| {
|
||||||
meta.assert(self._shape.hasTag(t) != null, "dynamicUpdateSlice expects 'update_' tensor tags to be a subset of input tensor tags but {s} is missing (input shape is {})", .{ t, self._shape });
|
stdx.debug.assert(self._shape.hasTag(t) != null, "dynamicUpdateSlice expects 'update_' tensor tags to be a subset of input tensor tags but {s} is missing (input shape is {})", .{ t, self._shape });
|
||||||
}
|
}
|
||||||
|
|
||||||
var update_shape = self._shape;
|
var update_shape = self._shape;
|
||||||
var prev_ax: i8 = -1;
|
var prev_ax: i8 = -1;
|
||||||
for (self._shape.tags(), 0..) |t, self_ax| {
|
for (self._shape.tags(), 0..) |t, self_ax| {
|
||||||
if (update._shape.hasTag(t)) |up_ax| {
|
if (update._shape.hasTag(t)) |up_ax| {
|
||||||
meta.assert(up_ax == prev_ax + 1, "dynamicUpdateSlice expects 'update_' and input tensor axis to have the same order, got {} and {}. (hint: you need to explicitly transpose 'update_')", .{ update_._shape, self._shape });
|
stdx.debug.assert(up_ax == prev_ax + 1, "dynamicUpdateSlice expects 'update_' and input tensor axis to have the same order, got {} and {}. (hint: you need to explicitly transpose 'update_')", .{ update_._shape, self._shape });
|
||||||
|
|
||||||
update_shape._dims.set(self_ax, update.dim(up_ax));
|
update_shape._dims.set(self_ax, update.dim(up_ax));
|
||||||
prev_ax = up_ax;
|
prev_ax = up_ax;
|
||||||
@ -3182,16 +3183,16 @@ pub const Tensor = struct {
|
|||||||
update = update.reshape(update_shape);
|
update = update.reshape(update_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
meta.assert(self.rank() == update.rank(), "dynamicUpdateSlice expects input and computed update tensors to have the same rank, got {} and {} (hint: it's probably an issue on our side)", .{ self.rank(), update.rank() });
|
stdx.debug.assert(self.rank() == update.rank(), "dynamicUpdateSlice expects input and computed update tensors to have the same rank, got {} and {} (hint: it's probably an issue on our side)", .{ self.rank(), update.rank() });
|
||||||
|
|
||||||
for (self.dims(), update.dims(), 0..) |self_d, up_d, ax| {
|
for (self.dims(), update.dims(), 0..) |self_d, up_d, ax| {
|
||||||
const t = self._shape.debugTag(ax);
|
const t = self._shape.debugTag(ax);
|
||||||
meta.assert(up_d <= self_d, "dynamicUpdateSlice expects 'update_' dimensions to be less than or equal to their corresponding dimension in input tensor, got {} and {} for axis .{s}", .{ up_d, self_d, t });
|
stdx.debug.assert(up_d <= self_d, "dynamicUpdateSlice expects 'update_' dimensions to be less than or equal to their corresponding dimension in input tensor, got {} and {} for axis .{s}", .{ up_d, self_d, t });
|
||||||
|
|
||||||
if (tagged_api and up_d < self_d) {
|
if (tagged_api and up_d < self_d) {
|
||||||
const axis_has_offset = std.mem.indexOfScalar(Shape.Tag, offset_tags.constSlice(), self._shape._tags.get(ax)) != null;
|
const axis_has_offset = std.mem.indexOfScalar(Shape.Tag, offset_tags.constSlice(), self._shape._tags.get(ax)) != null;
|
||||||
|
|
||||||
meta.assert(axis_has_offset, "dynamicUpdateSlice expects 'update_' dimensions to be equal to their corresponding dimension in input tensor, got {} and {} for axis .{s} (hint: you need to provide an offset)", .{ up_d, self_d, t });
|
stdx.debug.assert(axis_has_offset, "dynamicUpdateSlice expects 'update_' dimensions to be equal to their corresponding dimension in input tensor, got {} and {} for axis .{s} (hint: you need to provide an offset)", .{ up_d, self_d, t });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3200,7 +3201,7 @@ pub const Tensor = struct {
|
|||||||
var offset_values: [MAX_RANK]mlir.Value = undefined;
|
var offset_values: [MAX_RANK]mlir.Value = undefined;
|
||||||
if (offset_tags.len == 0) {
|
if (offset_tags.len == 0) {
|
||||||
// Without offset tags we need the same number of offset than rank.
|
// Without offset tags we need the same number of offset than rank.
|
||||||
meta.assert(self.rank() == offset.len, "dynamicUpdateSlice expects input tensor rank and 'offset_' length to be equal, got {} and {}", .{ self.rank(), offset.len });
|
stdx.debug.assert(self.rank() == offset.len, "dynamicUpdateSlice expects input tensor rank and 'offset_' length to be equal, got {} and {}", .{ self.rank(), offset.len });
|
||||||
|
|
||||||
for (offset.constSlice(), 0..) |idx, i| {
|
for (offset.constSlice(), 0..) |idx, i| {
|
||||||
offset_values[i] = idx.value();
|
offset_values[i] = idx.value();
|
||||||
@ -3210,7 +3211,7 @@ pub const Tensor = struct {
|
|||||||
// This is only allowed when using tagged sliced.
|
// This is only allowed when using tagged sliced.
|
||||||
offset_values = .{zero} ** MAX_RANK;
|
offset_values = .{zero} ** MAX_RANK;
|
||||||
for (offset.constSlice(), offset_tags.constSlice()) |start, t| {
|
for (offset.constSlice(), offset_tags.constSlice()) |start, t| {
|
||||||
const a = self._shape.hasTag(t) orelse meta.panic("dynamicUpdateSlice expects input tensor to have tags used in 'offset_' but {s} is missing (input shape is {})", .{ t, self._shape });
|
const a = self._shape.hasTag(t) orelse stdx.debug.panic("dynamicUpdateSlice expects input tensor to have tags used in 'offset_' but {s} is missing (input shape is {})", .{ t, self._shape });
|
||||||
offset_values[a] = start.value();
|
offset_values[a] = start.value();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3329,12 +3330,12 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns a Tensor containing the element-wise result of the given 'cmp' comparison between the two input Tensors.
|
/// Returns a Tensor containing the element-wise result of the given 'cmp' comparison between the two input Tensors.
|
||||||
pub fn cmp(self: Tensor, direction: dialect.stablehlo.ComparisonDirection.Direction, other: Tensor) Tensor {
|
pub fn cmp(self: Tensor, direction: dialect.stablehlo.ComparisonDirection.Direction, other: Tensor) Tensor {
|
||||||
meta.assert(self.dtype() == other.dtype(), "cmp expects input tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() });
|
stdx.debug.assert(self.dtype() == other.dtype(), "cmp expects input tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() });
|
||||||
|
|
||||||
if (self.rank() == 0 and other.rank() != 0) return self.broadcast(other._shape, &.{}).cmp(direction, other);
|
if (self.rank() == 0 and other.rank() != 0) return self.broadcast(other._shape, &.{}).cmp(direction, other);
|
||||||
if (self.rank() != 0 and other.rank() == 0) return self.cmp(direction, other.broadcast(self._shape, &.{}));
|
if (self.rank() != 0 and other.rank() == 0) return self.cmp(direction, other.broadcast(self._shape, &.{}));
|
||||||
|
|
||||||
meta.assert(self._shape.eql(other._shape), "cmp expects input tensor shapes to match, got {} and {}", .{ self._shape, other._shape });
|
stdx.debug.assert(self._shape.eql(other._shape), "cmp expects input tensor shapes to match, got {} and {}", .{ self._shape, other._shape });
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "cmp(.{s})", .{@tagName(direction)});
|
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "cmp(.{s})", .{@tagName(direction)});
|
||||||
const op = dialect.stablehlo.compare(
|
const op = dialect.stablehlo.compare(
|
||||||
@ -3352,7 +3353,7 @@ pub const Tensor = struct {
|
|||||||
/// For each vector in the input tensor,
|
/// For each vector in the input tensor,
|
||||||
/// creates a diagonal-matrix where diagonal values are set to the vector values.
|
/// creates a diagonal-matrix where diagonal values are set to the vector values.
|
||||||
pub fn toDiagonal(self: Tensor, axis_: anytype, new_tags: [2]EnumLiteral) Tensor {
|
pub fn toDiagonal(self: Tensor, axis_: anytype, new_tags: [2]EnumLiteral) Tensor {
|
||||||
meta.assert(self.rank() < MAX_RANK - 1, "toDiagonal expects input up to {} rank, got {}", .{ MAX_RANK - 1, self });
|
stdx.debug.assert(self.rank() < MAX_RANK - 1, "toDiagonal expects input up to {} rank, got {}", .{ MAX_RANK - 1, self });
|
||||||
const a = self.axis(axis_);
|
const a = self.axis(axis_);
|
||||||
const d = self.dim(a);
|
const d = self.dim(a);
|
||||||
var res_shape = self._shape;
|
var res_shape = self._shape;
|
||||||
@ -3409,7 +3410,7 @@ pub const Tensor = struct {
|
|||||||
/// To get the upper triangular part, swap the order of axes:
|
/// To get the upper triangular part, swap the order of axes:
|
||||||
/// `.{ .b = 32, .w = 20, .h = 20 }.triangular(.{ .h, .w }, 0);`
|
/// `.{ .b = 32, .w = 20, .h = 20 }.triangular(.{ .h, .w }, 0);`
|
||||||
pub fn triangular(self: Tensor, axes_: anytype, num_diagonals: i32) Tensor {
|
pub fn triangular(self: Tensor, axes_: anytype, num_diagonals: i32) Tensor {
|
||||||
meta.assertComptime(meta.isTuple(@TypeOf(axes_)) and axes_.len == 2, "triangular expects exactly two axes to work on.", .{});
|
stdx.debug.assertComptime(stdx.meta.isTuple(@TypeOf(axes_)) and axes_.len == 2, "triangular expects exactly two axes to work on.", .{});
|
||||||
const _axes = self.axes(axes_);
|
const _axes = self.axes(axes_);
|
||||||
|
|
||||||
const x = Tensor.iota(self.shape(), _axes.get(0));
|
const x = Tensor.iota(self.shape(), _axes.get(0));
|
||||||
@ -3472,8 +3473,8 @@ pub const Tensor = struct {
|
|||||||
/// For each element at index `i`, if `bool_tensor[i] == true`, `output[i] = on_true[i]`
|
/// For each element at index `i`, if `bool_tensor[i] == true`, `output[i] = on_true[i]`
|
||||||
/// otherwise, if `bool_tensor[i] == false`, `output[i] = on_false[i]`
|
/// otherwise, if `bool_tensor[i] == false`, `output[i] = on_false[i]`
|
||||||
pub fn select(bool_tensor: Tensor, on_true: Tensor, on_false: Tensor) Tensor {
|
pub fn select(bool_tensor: Tensor, on_true: Tensor, on_false: Tensor) Tensor {
|
||||||
meta.assert(bool_tensor.dtype() == .bool, "select expects input tensor type to be a boolean, got {}", .{bool_tensor.dtype()});
|
stdx.debug.assert(bool_tensor.dtype() == .bool, "select expects input tensor type to be a boolean, got {}", .{bool_tensor.dtype()});
|
||||||
meta.assert(on_true.dtype() == on_false.dtype(), "select expects 'on_true' and 'on_false' tensor types to be equal, got {} and {}", .{ on_true.dtype(), on_false.dtype() });
|
stdx.debug.assert(on_true.dtype() == on_false.dtype(), "select expects 'on_true' and 'on_false' tensor types to be equal, got {} and {}", .{ on_true.dtype(), on_false.dtype() });
|
||||||
|
|
||||||
if (bool_tensor.rank() != 0 and on_true.rank() == 0) {
|
if (bool_tensor.rank() != 0 and on_true.rank() == 0) {
|
||||||
return bool_tensor.select(on_true.broad(bool_tensor.shape()), on_false);
|
return bool_tensor.select(on_true.broad(bool_tensor.shape()), on_false);
|
||||||
@ -3482,8 +3483,8 @@ pub const Tensor = struct {
|
|||||||
return bool_tensor.select(on_true, on_false.broad(bool_tensor.shape()));
|
return bool_tensor.select(on_true, on_false.broad(bool_tensor.shape()));
|
||||||
}
|
}
|
||||||
|
|
||||||
meta.assert(bool_tensor._shape.eqlDims(on_true._shape), "select expects input tensor and 'on_true' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_true._shape });
|
stdx.debug.assert(bool_tensor._shape.eqlDims(on_true._shape), "select expects input tensor and 'on_true' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_true._shape });
|
||||||
meta.assert(bool_tensor._shape.eqlDims(on_false._shape), "select expects input tensor and 'on_false' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_false._shape });
|
stdx.debug.assert(bool_tensor._shape.eqlDims(on_false._shape), "select expects input tensor and 'on_false' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_false._shape });
|
||||||
|
|
||||||
const loc = bool_tensor.getContext().mlirCtx().location(@src());
|
const loc = bool_tensor.getContext().mlirCtx().location(@src());
|
||||||
const op = dialect.stablehlo.select(
|
const op = dialect.stablehlo.select(
|
||||||
@ -3538,11 +3539,11 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn _cartesianProduct(vectors: []const Tensor, out: []Tensor) void {
|
fn _cartesianProduct(vectors: []const Tensor, out: []Tensor) void {
|
||||||
meta.assert(vectors.len >= 1, "cartesianProduct expects at least one input.", .{});
|
stdx.debug.assert(vectors.len >= 1, "cartesianProduct expects at least one input.", .{});
|
||||||
meta.assert(vectors.len < Tensor.MAX_RANK, "cartesianProduct expects at most {} input vectors, received {} !", .{ Tensor.MAX_RANK - 1, vectors.len });
|
stdx.debug.assert(vectors.len < Tensor.MAX_RANK, "cartesianProduct expects at most {} input vectors, received {} !", .{ Tensor.MAX_RANK - 1, vectors.len });
|
||||||
for (vectors) |x| {
|
for (vectors) |x| {
|
||||||
meta.assert(x.rank() <= 1, "cartesianProduct expects 0 or 1 rank input vectors. Got: {any}", .{vectors});
|
stdx.debug.assert(x.rank() <= 1, "cartesianProduct expects 0 or 1 rank input vectors. Got: {any}", .{vectors});
|
||||||
meta.assert(vectors[0].dtype() == x.dtype(), "cartesianProduct expects input vectors to have all the same dtype. Got: {any}", .{vectors});
|
stdx.debug.assert(vectors[0].dtype() == x.dtype(), "cartesianProduct expects input vectors to have all the same dtype. Got: {any}", .{vectors});
|
||||||
}
|
}
|
||||||
|
|
||||||
var res_shape = Shape.init(.{}, vectors[0].dtype());
|
var res_shape = Shape.init(.{}, vectors[0].dtype());
|
||||||
@ -3645,7 +3646,7 @@ pub const Tensor = struct {
|
|||||||
) fn (Tensor, Tensor) Tensor {
|
) fn (Tensor, Tensor) Tensor {
|
||||||
return struct {
|
return struct {
|
||||||
pub fn binaryOpHelper(self: Tensor, other: Tensor) Tensor {
|
pub fn binaryOpHelper(self: Tensor, other: Tensor) Tensor {
|
||||||
meta.assert(self.dtype() == other.dtype(), "{s} expects tensor to be of same type, got {} and {}", .{ op_name, self, other });
|
stdx.debug.assert(self.dtype() == other.dtype(), "{s} expects tensor to be of same type, got {} and {}", .{ op_name, self, other });
|
||||||
|
|
||||||
if (self.rank() == 0 and other.rank() != 0) {
|
if (self.rank() == 0 and other.rank() != 0) {
|
||||||
return binaryOpHelper(self.broad(other._shape), other);
|
return binaryOpHelper(self.broad(other._shape), other);
|
||||||
@ -3655,7 +3656,7 @@ pub const Tensor = struct {
|
|||||||
return binaryOpHelper(self, other.broad(self._shape));
|
return binaryOpHelper(self, other.broad(self._shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
meta.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {} and {}", .{ op_name, self._shape, other._shape });
|
stdx.debug.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {} and {}", .{ op_name, self._shape, other._shape });
|
||||||
|
|
||||||
const mlirCtx = self.getContext().mlirCtx();
|
const mlirCtx = self.getContext().mlirCtx();
|
||||||
const location = mlirCtx.location(@src());
|
const location = mlirCtx.location(@src());
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
//! Test runner for unit test based on https://github.com/ziglang/zig/blob/master/lib/compiler/test_runner.zig with async
|
//! Test runner for unit test based on https://github.com/ziglang/zig/blob/master/lib/compiler/test_runner.zig with async
|
||||||
const builtin = @import("builtin");
|
|
||||||
|
|
||||||
const std = @import("std");
|
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
|
const builtin = @import("builtin");
|
||||||
|
const std = @import("std");
|
||||||
|
|
||||||
const io = std.io;
|
const io = std.io;
|
||||||
const testing = std.testing;
|
const testing = std.testing;
|
||||||
const assert = std.debug.assert;
|
const assert = std.debug.assert;
|
||||||
@ -21,10 +21,10 @@ var fba = std.heap.FixedBufferAllocator.init(&fba_buffer);
|
|||||||
|
|
||||||
pub fn main() anyerror!void {
|
pub fn main() anyerror!void {
|
||||||
testing.log_level = log_level;
|
testing.log_level = log_level;
|
||||||
try asynk.AsyncThread.main(testing.allocator, asyncMain, .{});
|
try asynk.AsyncThread.main(testing.allocator, asyncMain);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn asyncMain() void {
|
pub fn asyncMain() !void {
|
||||||
const test_fn_list: []const std.builtin.TestFn = builtin.test_functions;
|
const test_fn_list: []const std.builtin.TestFn = builtin.test_functions;
|
||||||
var ok_count: usize = 0;
|
var ok_count: usize = 0;
|
||||||
var skip_count: usize = 0;
|
var skip_count: usize = 0;
|
||||||
|
|||||||
@ -1,11 +1,12 @@
|
|||||||
const std = @import("std");
|
|
||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
|
const std = @import("std");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const zml = @import("zml.zig");
|
const zml = @import("zml.zig");
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
const shapesOf = @import("tensor.zig").shapesOf;
|
const shapesOf = @import("tensor.zig").shapesOf;
|
||||||
|
|
||||||
const log = std.log.scoped(.zml_testing);
|
const log = std.log.scoped(.@"zml/testing");
|
||||||
|
|
||||||
var _ctx: ?zml.Context = null;
|
var _ctx: ?zml.Context = null;
|
||||||
|
|
||||||
@ -128,7 +129,7 @@ pub fn expectEqualShapes(expected: zml.Shape, actual: zml.Shape) error{TestExpec
|
|||||||
/// Compile a function and immediatly call it with the given buffers.
|
/// Compile a function and immediatly call it with the given buffers.
|
||||||
/// The compiled module is discarded after the call.
|
/// The compiled module is discarded after the call.
|
||||||
/// Useful during testing when a module is typically called only once.
|
/// Useful during testing when a module is typically called only once.
|
||||||
pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bufferized(meta.FnParams(func))) !zml.Bufferized(zml.meta.FnResult(func)) {
|
pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bufferized(stdx.meta.FnArgs(func))) !zml.Bufferized(stdx.meta.FnResult(func)) {
|
||||||
// This simplify test API and also ensure this fn isn't used outside of tests.
|
// This simplify test API and also ensure this fn isn't used outside of tests.
|
||||||
const allocator = std.testing.allocator;
|
const allocator = std.testing.allocator;
|
||||||
var arena = std.heap.ArenaAllocator.init(allocator);
|
var arena = std.heap.ArenaAllocator.init(allocator);
|
||||||
@ -139,7 +140,7 @@ pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bu
|
|||||||
return x.shape();
|
return x.shape();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
var shape_args: zml.ShapeOf(meta.FnParams(func)) = undefined;
|
var shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)) = undefined;
|
||||||
try meta.mapAlloc(Local.bufferToShape, allocator, {}, buffer_args, &shape_args);
|
try meta.mapAlloc(Local.bufferToShape, allocator, {}, buffer_args, &shape_args);
|
||||||
|
|
||||||
const mod = try zml.compileFn(allocator, func, shape_args, platform);
|
const mod = try zml.compileFn(allocator, func, shape_args, platform);
|
||||||
@ -151,7 +152,7 @@ pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bu
|
|||||||
/// Compile a function and immediatly call it with the given buffers.
|
/// Compile a function and immediatly call it with the given buffers.
|
||||||
/// The compiled module is discarded after the call.
|
/// The compiled module is discarded after the call.
|
||||||
/// Useful during testing when a module is typically called only once.
|
/// Useful during testing when a module is typically called only once.
|
||||||
pub fn compileAndCallWithTensors(platform: zml.Platform, func: anytype, shape_args: zml.ShapeOf(meta.FnParams(func)), buffer_args: zml.Bufferized(meta.FnParams(func))) !zml.Bufferized(zml.meta.FnResult(func)) {
|
pub fn compileAndCallWithTensors(platform: zml.Platform, func: anytype, shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)), buffer_args: zml.Bufferized(stdx.meta.FnArgs(func))) !zml.Bufferized(stdx.meta.FnResult(func)) {
|
||||||
// This simplify test API and also ensure this fn isn't used outside of tests.
|
// This simplify test API and also ensure this fn isn't used outside of tests.
|
||||||
const allocator = std.testing.allocator;
|
const allocator = std.testing.allocator;
|
||||||
var arena = std.heap.ArenaAllocator.init(allocator);
|
var arena = std.heap.ArenaAllocator.init(allocator);
|
||||||
|
|||||||
@ -1,13 +1,15 @@
|
|||||||
//! Text tokenizer implementations
|
//! Text tokenizer implementations
|
||||||
const std = @import("std");
|
|
||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
const testing = std.testing;
|
const std = @import("std");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const log = std.log.scoped(.zml_tokenizer);
|
const testing = std.testing;
|
||||||
|
|
||||||
const helpers = @import("helpers.zig");
|
const helpers = @import("helpers.zig");
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
|
|
||||||
|
const log = std.log.scoped(.@"zml/tokenizer");
|
||||||
|
|
||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
std.testing.refAllDecls(Normalizer);
|
std.testing.refAllDecls(Normalizer);
|
||||||
@ -202,7 +204,7 @@ pub const Tokenizer = struct {
|
|||||||
// Detects memory corruption of tokens.
|
// Detects memory corruption of tokens.
|
||||||
if (cur_tok.len == 0 or cur_tok.len > self.max_token_len) @panic("Token looks corrupted !");
|
if (cur_tok.len == 0 or cur_tok.len > self.max_token_len) @panic("Token looks corrupted !");
|
||||||
|
|
||||||
meta.assert(std.mem.eql(u8, cur_tok, input[input_off..][0..cur_tok.len]), "current token '{s}' not found in input string '{s}' !", .{ cur_tok, input[input_off..] });
|
stdx.debug.assert(std.mem.eql(u8, cur_tok, input[input_off..][0..cur_tok.len]), "current token '{s}' not found in input string '{s}' !", .{ cur_tok, input[input_off..] });
|
||||||
}
|
}
|
||||||
const next_tok = self.tokens[tok_buff[i + 1]];
|
const next_tok = self.tokens[tok_buff[i + 1]];
|
||||||
// if `next_tok` is `.unk`, length is 1; otherwise, it's the length of the token.
|
// if `next_tok` is `.unk`, length is 1; otherwise, it's the length of the token.
|
||||||
|
|||||||
@ -1,9 +1,11 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const log = std.log.scoped(.zml_torch);
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const zml = @import("zml.zig");
|
const zml = @import("zml.zig");
|
||||||
|
|
||||||
const Tensor = zml.Tensor;
|
const Tensor = zml.Tensor;
|
||||||
const meta = zml.meta;
|
|
||||||
|
const log = std.log.scoped(.zml_torch);
|
||||||
|
|
||||||
/// Multiplies a matrix or a vector with a tensor,
|
/// Multiplies a matrix or a vector with a tensor,
|
||||||
/// following the semantic of pytorch `@` operator.
|
/// following the semantic of pytorch `@` operator.
|
||||||
@ -14,7 +16,7 @@ const meta = zml.meta;
|
|||||||
/// * `matmul(.{10}, .{10}) -> .{}`
|
/// * `matmul(.{10}, .{10}) -> .{}`
|
||||||
/// * `matmul(.{10}, .{10}) -> .{}`
|
/// * `matmul(.{10}, .{10}) -> .{}`
|
||||||
pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor {
|
pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor {
|
||||||
meta.assert(lhs.rank() >= 1 and rhs.rank() >= 1, "Can't matmul({}, {}) ! The two tensors need to have at least rank 1.", .{ lhs, rhs });
|
stdx.debug.assert(lhs.rank() >= 1 and rhs.rank() >= 1, "Can't matmul({}, {}) ! The two tensors need to have at least rank 1.", .{ lhs, rhs });
|
||||||
|
|
||||||
const contracting = [_][2]i8{.{ -1, if (rhs.rank() >= 2) rhs.rank() - 2 else 0 }};
|
const contracting = [_][2]i8{.{ -1, if (rhs.rank() >= 2) rhs.rank() - 2 else 0 }};
|
||||||
if (lhs.rank() == 1 or rhs.rank() <= 2) {
|
if (lhs.rank() == 1 or rhs.rank() <= 2) {
|
||||||
@ -22,7 +24,7 @@ pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor {
|
|||||||
return lhs.dotGeneral(rhs, &contracting, &.{});
|
return lhs.dotGeneral(rhs, &contracting, &.{});
|
||||||
}
|
}
|
||||||
|
|
||||||
meta.assert(lhs.rank() == 2, "Can't matmul({}, {}) ! One of the two tensors need to have a rank less than 2.", .{ lhs, rhs });
|
stdx.debug.assert(lhs.rank() == 2, "Can't matmul({}, {}) ! One of the two tensors need to have a rank less than 2.", .{ lhs, rhs });
|
||||||
|
|
||||||
// Pytorch treats the extra dimensions of rhs has batching dimensions,
|
// Pytorch treats the extra dimensions of rhs has batching dimensions,
|
||||||
// and implicitly broadcast lhs along those.
|
// and implicitly broadcast lhs along those.
|
||||||
@ -91,7 +93,7 @@ pub fn unsqueeze(
|
|||||||
self: Tensor,
|
self: Tensor,
|
||||||
axis_: anytype,
|
axis_: anytype,
|
||||||
) Tensor {
|
) Tensor {
|
||||||
meta.assert(self.rank() < Tensor.MAX_RANK - 1, "Can't unsqueeze {}, it's already at max rank.", .{self});
|
stdx.debug.assert(self.rank() < Tensor.MAX_RANK - 1, "Can't unsqueeze {}, it's already at max rank.", .{self});
|
||||||
const a = switch (@typeInfo(@TypeOf(axis_))) {
|
const a = switch (@typeInfo(@TypeOf(axis_))) {
|
||||||
.Int, .ComptimeInt => if (axis_ < 0)
|
.Int, .ComptimeInt => if (axis_ < 0)
|
||||||
@as(i8, self.rank()) + 1 + axis_
|
@as(i8, self.rank()) + 1 + axis_
|
||||||
@ -125,9 +127,9 @@ test unsqueeze {
|
|||||||
/// ref: https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#pixelshuffle
|
/// ref: https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#pixelshuffle
|
||||||
pub fn pixelShuffle(tensor: Tensor, upscale_factor: u32) Tensor {
|
pub fn pixelShuffle(tensor: Tensor, upscale_factor: u32) Tensor {
|
||||||
const shape = tensor.shape();
|
const shape = tensor.shape();
|
||||||
meta.assert(shape.hasTags(.{ .c, .w, .h }), "pixelShuffle({}) is invalide. Missing tags {{.c, .w, .h}}", .{tensor});
|
stdx.debug.assert(shape.hasTags(.{ .c, .w, .h }), "pixelShuffle({}) is invalide. Missing tags {{.c, .w, .h}}", .{tensor});
|
||||||
|
|
||||||
meta.assert(@mod(shape.dim(.c), upscale_factor * upscale_factor) == 0, "pixelShuffle({}) is invalide. Number of channels {}, isn't divisible by upscale factor {}**2", .{ tensor, shape.dim(.c), upscale_factor });
|
stdx.debug.assert(@mod(shape.dim(.c), upscale_factor * upscale_factor) == 0, "pixelShuffle({}) is invalide. Number of channels {}, isn't divisible by upscale factor {}**2", .{ tensor, shape.dim(.c), upscale_factor });
|
||||||
|
|
||||||
const s = tensor.splitAxis(.c, .{ .c = -1, .upscale_h = upscale_factor, .upscale_w = upscale_factor });
|
const s = tensor.splitAxis(.c, .{ .c = -1, .upscale_h = upscale_factor, .upscale_w = upscale_factor });
|
||||||
const perm = s.shape().contiguousPerm(.{ .h, .upscale_h, .w, .upscale_w });
|
const perm = s.shape().contiguousPerm(.{ .h, .upscale_h, .w, .upscale_w });
|
||||||
@ -173,7 +175,7 @@ test pixelShuffle {
|
|||||||
/// ref: https://pytorch.org/docs/stable/generated/torch.roll.html
|
/// ref: https://pytorch.org/docs/stable/generated/torch.roll.html
|
||||||
pub fn roll(self: Tensor, shifts: []const i64, axes_: []const u8) Tensor {
|
pub fn roll(self: Tensor, shifts: []const i64, axes_: []const u8) Tensor {
|
||||||
// TODO(hugo) accept following syntax: x.roll(.{ .a = 5, .b = 8 })
|
// TODO(hugo) accept following syntax: x.roll(.{ .a = 5, .b = 8 })
|
||||||
meta.assert(self.rank() > 0 and shifts.len == axes_.len, "Shifts length ({d}) and dims length ({d}) are not equal, we expect the same length.", .{ shifts.len, axes_.len });
|
stdx.debug.assert(self.rank() > 0 and shifts.len == axes_.len, "Shifts length ({d}) and dims length ({d}) are not equal, we expect the same length.", .{ shifts.len, axes_.len });
|
||||||
|
|
||||||
if (shifts.len != 1 or axes_.len != 1) {
|
if (shifts.len != 1 or axes_.len != 1) {
|
||||||
const tail_shifts = shifts[1..shifts.len];
|
const tail_shifts = shifts[1..shifts.len];
|
||||||
@ -219,8 +221,8 @@ pub const MeshgridIndexing = enum { xy, ij };
|
|||||||
/// * for ‘ij’ indexing, outputs are of shape (M, N, P)
|
/// * for ‘ij’ indexing, outputs are of shape (M, N, P)
|
||||||
/// * for ‘xy’ indexing, outputs are of shape (N, M, P)
|
/// * for ‘xy’ indexing, outputs are of shape (N, M, P)
|
||||||
pub fn meshgrid(comptime N: u3, vectors: [N]Tensor, indexing: MeshgridIndexing) [N]Tensor {
|
pub fn meshgrid(comptime N: u3, vectors: [N]Tensor, indexing: MeshgridIndexing) [N]Tensor {
|
||||||
meta.assertComptime(vectors.len >= 1, "Invalid meshgrid. No input.", .{});
|
stdx.debug.assertComptime(vectors.len >= 1, "Invalid meshgrid. No input.", .{});
|
||||||
meta.assertComptime(vectors.len <= Tensor.MAX_RANK, "Invalid meshgrid(...). Too many inputs: {}", .{vectors.len});
|
stdx.debug.assertComptime(vectors.len <= Tensor.MAX_RANK, "Invalid meshgrid(...). Too many inputs: {}", .{vectors.len});
|
||||||
|
|
||||||
if (vectors.len == 1) return vectors;
|
if (vectors.len == 1) return vectors;
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user