Add stdx utilities and rework async signature inference; tidy executable logging.

This commit is contained in:
Tarry Singh 2023-06-21 14:45:14 +00:00
parent c30aa018dc
commit 9b7eea8ac2
40 changed files with 1490 additions and 1241 deletions

File diff suppressed because it is too large Load Diff

View File

@ -1,26 +1,86 @@
const std = @import("std"); const std = @import("std");
pub fn FnSignature(comptime func: anytype, comptime argsT: ?type) type { pub fn ArgsTuple(comptime funcT: anytype, comptime argsT: ?type) type {
return struct { const params = @typeInfo(funcT).Fn.params;
pub const FuncT = if (@TypeOf(func) == type) func else @TypeOf(func); if (params.len == 0) {
pub const ArgsT = blk: { return @TypeOf(.{});
if (@typeInfo(FuncT).Fn.params.len == 0) {
break :blk @TypeOf(.{});
} }
break :blk argsT orelse std.meta.ArgsTuple(FuncT);
if (@typeInfo(funcT).Fn.is_generic == false) {
return std.meta.ArgsTuple(funcT);
}
const args = std.meta.fields(argsT orelse @compileError("generic function requires an explicit ArgsTuple"));
var tuple_fields: [params.len]std.builtin.Type.StructField = undefined;
inline for (params, args, 0..) |param, arg, i| {
if (param.type == null) {
tuple_fields[i] = arg;
continue;
}
const T = param.type.?;
var num_buf: [32]u8 = undefined;
tuple_fields[i] = .{
.name = blk: {
const s = std.fmt.formatIntBuf(&num_buf, i, 10, .lower, .{});
num_buf[s] = 0;
break :blk num_buf[0..s :0];
},
.type = T,
.default_value = null,
.is_comptime = false,
.alignment = if (@sizeOf(T) > 0) @alignOf(T) else 0,
}; };
}
return @Type(.{
.Struct = .{
.is_tuple = true,
.layout = .auto,
.decls = &.{},
.fields = &tuple_fields,
},
});
}
pub fn TupleRange(comptime T: type, comptime start: usize, comptime end: usize) type {
const fields = std.meta.fields(T);
var new_fields: [end - start]std.builtin.Type.StructField = undefined;
inline for (start..end, 0..) |i, j| {
var new_field = fields[i];
var num_buf: [32]u8 = undefined;
new_field.name = blk: {
const s = std.fmt.formatIntBuf(&num_buf, j, 10, .lower, .{});
num_buf[s] = 0;
break :blk num_buf[0..s :0];
};
new_fields[j] = new_field;
}
return @Type(.{
.Struct = .{
.is_tuple = true,
.layout = .auto,
.decls = &.{},
.fields = &new_fields,
},
});
}
pub fn FnSignature(comptime func: anytype, comptime argsT: ?type) type {
return FnSignatureX(func, ArgsTuple(@TypeOf(func), argsT));
}
pub fn FnSignatureX(comptime func: anytype, comptime argsT: type) type {
return struct {
pub const FuncT = @TypeOf(func);
pub const ArgsT = argsT;
pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined))); pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined)));
pub const ReturnPayloadT = blk: { pub const ReturnPayloadT = switch (@typeInfo(ReturnT)) {
break :blk switch (@typeInfo(ReturnT)) {
.ErrorUnion => |u| u.payload, .ErrorUnion => |u| u.payload,
else => ReturnT, else => ReturnT,
}; };
}; pub const ReturnErrorSet: ?type = switch (@typeInfo(ReturnT)) {
pub const ReturnErrorSet: ?type = blk: {
break :blk switch (@typeInfo(ReturnT)) {
.ErrorUnion => |u| u.error_set, .ErrorUnion => |u| u.error_set,
else => null, else => null,
}; };
}; };
};
} }

View File

@ -2,16 +2,22 @@ const std = @import("std");
const xev = @import("xev"); const xev = @import("xev");
const FnSignature = @import("meta.zig").FnSignature; const FnSignature = @import("meta.zig").FnSignature;
const NormalizedTuple = @import("meta.zig").NormalizedTuple;
pub fn Frame(comptime func: anytype) type { pub fn Frame(comptime func: anytype) type {
const Signature = FnSignature(func, null); const Signature = FnSignature(func, null);
return FrameEx(func, Signature.ArgsT); return FrameExx(func, Signature);
} }
pub fn FrameEx(comptime func: anytype, comptime argsT: type) type { pub fn FrameEx(comptime func: anytype, comptime argsT: type) type {
const Signature = FnSignature(func, argsT);
return FrameExx(func, Signature);
}
pub fn FrameExx(comptime func: anytype, comptime Signature: type) type {
return struct { return struct {
const Self = @This(); const Self = @This();
const Signature = FnSignature(func, argsT); const Signature_ = Signature;
const Task = struct { const Task = struct {
_task: xev.ThreadPool.Task = .{ .callback = &Self.run }, _task: xev.ThreadPool.Task = .{ .callback = &Self.run },
event: std.Thread.ResetEvent = .{}, event: std.Thread.ResetEvent = .{},
@ -27,7 +33,8 @@ pub fn FrameEx(comptime func: anytype, comptime argsT: type) type {
task.event.set(); task.event.set();
} }
pub fn await_(self: *Self) Signature.ReturnT { pub const await_ = wait;
pub fn wait(self: *Self) Signature.ReturnT {
defer { defer {
AsyncThread.current.mutex.lock(); AsyncThread.current.mutex.lock();
AsyncThread.current.allocator.destroy(self._task); AsyncThread.current.allocator.destroy(self._task);
@ -39,11 +46,7 @@ pub fn FrameEx(comptime func: anytype, comptime argsT: type) type {
}; };
} }
pub fn asyncc(comptime func: anytype, args: FnSignature(func, null).ArgsT) !FrameEx(func, @TypeOf(args)) { pub fn asyncc(comptime func: anytype, args: anytype) !FrameEx(func, @TypeOf(args)) {
return asyncGeneric(func, args);
}
pub fn asyncGeneric(comptime func: anytype, args: anytype) !FrameEx(func, @TypeOf(args)) {
const FrameT = FrameEx(func, @TypeOf(args)); const FrameT = FrameEx(func, @TypeOf(args));
AsyncThread.current.mutex.lock(); AsyncThread.current.mutex.lock();
@ -58,15 +61,11 @@ pub fn asyncGeneric(comptime func: anytype, args: anytype) !FrameEx(func, @TypeO
return .{ ._task = task }; return .{ ._task = task };
} }
pub fn callBlocking(comptime func: anytype, args: FnSignature(func, null).ArgsT) @TypeOf(callBlockingGeneric(func, args)) { pub inline fn callBlocking(comptime func: anytype, args: anytype) FnSignature(func, @TypeOf(args)).ReturnT {
return callBlockingGeneric(func, args);
}
pub fn callBlockingGeneric(comptime func: anytype, args: anytype) FnSignature(func, @TypeOf(args)).ReturnT {
return @call(.auto, func, args); return @call(.auto, func, args);
} }
pub fn sleep(ms: u64) !void { pub inline fn sleep(ms: u64) !void {
std.time.sleep(ms * std.time.ns_per_ms); std.time.sleep(ms * std.time.ns_per_ms);
} }
@ -77,7 +76,7 @@ pub const AsyncThread = struct {
thread_pool: xev.ThreadPool, thread_pool: xev.ThreadPool,
mutex: std.Thread.Mutex, mutex: std.Thread.Mutex,
pub fn main(allocator_: std.mem.Allocator, comptime func: anytype, args: anytype) !void { pub fn main(allocator_: std.mem.Allocator, comptime mainFunc: anytype) !void {
current = .{ current = .{
.allocator = allocator_, .allocator = allocator_,
.thread_pool = xev.ThreadPool.init(.{}), .thread_pool = xev.ThreadPool.init(.{}),
@ -89,7 +88,7 @@ pub const AsyncThread = struct {
current.thread_pool.deinit(); current.thread_pool.deinit();
} }
return @call(.auto, func, args); return try mainFunc();
} }
}; };
@ -114,15 +113,15 @@ pub const Notification = struct {
} }
}; };
pub fn StdIn() !File { pub fn getStdIn() !File {
return File.init(std.io.getStdIn()) catch @panic("Unable to open stdin"); return File.init(std.io.getStdIn()) catch @panic("Unable to open stdin");
} }
pub fn StdOut() File { pub fn getStdOut() File {
return File.init(std.io.getStdOut()) catch @panic("Unable to open stdout"); return File.init(std.io.getStdOut()) catch @panic("Unable to open stdout");
} }
pub fn StdErr() File { pub fn getStdErr() File {
return File.init(std.io.getStdErr()) catch @panic("Unable to open stderr"); return File.init(std.io.getStdErr()) catch @panic("Unable to open stderr");
} }
@ -217,3 +216,23 @@ pub const File = struct {
}; };
pub const Mutex = std.Thread.Mutex; pub const Mutex = std.Thread.Mutex;
pub fn logFn(
comptime message_level: std.log.Level,
comptime scope: @Type(.EnumLiteral),
comptime format: []const u8,
args: anytype,
) void {
const level_txt = comptime message_level.asText();
const prefix2 = if (scope == .default) ": " else "(" ++ @tagName(scope) ++ "): ";
const stderr = getStdErr().writer();
var bw = std.io.bufferedWriter(stderr);
const writer = bw.writer();
std.debug.lockStdErr();
defer std.debug.unlockStdErr();
nosuspend {
writer.print(level_txt ++ prefix2 ++ format ++ "\n", args) catch return;
bw.flush() catch return;
}
}

View File

@ -504,13 +504,14 @@ pub const LoadedExecutable = opaque {
return @ptrCast(ret.addressable_devices); return @ptrCast(ret.addressable_devices);
} }
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: struct { pub const ExecuteArgs = struct {
num_args: usize, num_args: usize,
arguments: []const [*]const *const Buffer, arguments: []const [*]const *const Buffer,
results: []const [*]*Buffer, results: []const [*]*Buffer,
events: []?*Event, events: []?*Event,
non_donatable_input_indices: []const i64 = &.{}, non_donatable_input_indices: []const i64 = &.{},
}) ApiError!void { };
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ApiError!void {
var options = pjrtStruct(c.PJRT_ExecuteOptions{ var options = pjrtStruct(c.PJRT_ExecuteOptions{
.send_callbacks = null, .send_callbacks = null,
.recv_callbacks = null, .recv_callbacks = null,

View File

@ -2,7 +2,7 @@ const std = @import("std");
const c = @import("c"); const c = @import("c");
const tsl_proto = @import("//tsl:profiler_options_proto"); const tsl_proto = @import("//tsl:profiler_options_proto");
const log = std.log.scoped(.zml_profiler); const log = std.log.scoped(.@"zml/profiler");
/// Pjrt Profiler extension /// Pjrt Profiler extension
pub const Profiler = struct { pub const Profiler = struct {

13
stdx/BUILD.bazel Normal file
View File

@ -0,0 +1,13 @@
# Bazel build definition for the stdx utility library.
load("@rules_zig//zig:defs.bzl", "zig_library")

zig_library(
    name = "stdx",
    # Source files compiled into the library alongside the main module.
    # NOTE(review): signature.zig is pulled in via meta.zig, not stdx.zig.
    srcs = [
        "debug.zig",
        "math.zig",
        "meta.zig",
        "signature.zig",
    ],
    # Root module of the library.
    main = "stdx.zig",
    visibility = ["//visibility:public"],
)

33
stdx/debug.zig Normal file
View File

@ -0,0 +1,33 @@
const std = @import("std");
/// Input-validation assertion: panics with the caller-supplied source
/// location (file, function, line) when `check` is false.
pub inline fn guard(check: bool, src: std.builtin.SourceLocation) void {
    const location = .{ src.file, src.fn_name, src.line };
    assert(check, "Invalid inputs {s}@{s}:{d}", location);
}
/// Like `assert`, but prefixes the message with "internal error: " to mark
/// violated internal invariants rather than bad caller input.
pub inline fn internalAssert(check: bool, comptime msg: []const u8, args: anytype) void {
    const prefixed = "internal error: " ++ msg;
    assert(check, prefixed, args);
}
/// Panics with the formatted message when `check` is false; otherwise a no-op.
pub inline fn assert(check: bool, comptime msg: []const u8, args: anytype) void {
    if (check) return;
    panic(msg, args);
}
/// Formats `format` with `args` and aborts via `std.debug.panic`.
/// Thin forwarding wrapper so call sites can use `stdx.debug.panic`; never returns.
pub inline fn panic(comptime format: []const u8, args: anytype) noreturn {
    std.debug.panic(format, args);
}
/// Emits a formatted message through `@compileLog` (compile-time diagnostic).
pub inline fn compileLog(comptime msg: []const u8, comptime args: anytype) void {
    const rendered = std.fmt.comptimePrint(msg, args);
    @compileLog(rendered);
}
/// Raises a compile error whose message is formatted from `msg` and `args`.
pub inline fn compileError(comptime msg: []const u8, comptime args: anytype) noreturn {
    const rendered = std.fmt.comptimePrint(msg, args);
    @compileError(rendered);
}
/// Compile-time assertion: raises a formatted compile error when `check` is false.
pub inline fn assertComptime(comptime check: bool, comptime msg: []const u8, comptime args: anytype) void {
    if (!check) {
        compileError(msg, args);
    }
}

25
stdx/math.zig Normal file
View File

@ -0,0 +1,25 @@
/// Floored division after converting both operands to `T` with `floatCast`.
/// `T` is expected to be a float type (that is what `floatCast` targets).
pub inline fn divFloor(comptime T: type, numerator: anytype, denominator: anytype) T {
    const num = floatCast(T, numerator);
    const den = floatCast(T, denominator);
    return @divFloor(num, den);
}
/// Exact division after converting both operands to `T` with `floatCast`.
/// `@divExact` asserts the division has no remainder.
pub inline fn divExact(comptime T: type, numerator: anytype, denominator: anytype) T {
    const num = floatCast(T, numerator);
    const den = floatCast(T, denominator);
    return @divExact(num, den);
}
/// Truncated (toward-zero) division after converting both operands to `T`
/// with `floatCast`.
pub inline fn divTrunc(comptime T: type, numerator: anytype, denominator: anytype) T {
    const num = floatCast(T, numerator);
    const den = floatCast(T, denominator);
    return @divTrunc(num, den);
}
/// Converts `x` (integer or float, runtime or comptime) to the float type `T`.
/// Fix: the original routed `comptime_float` through the `else` branch, where
/// `@floatFromInt` rejects a non-integer operand at compile time.
pub inline fn floatCast(comptime T: type, x: anytype) T {
    return switch (@typeInfo(@TypeOf(x))) {
        .Float => @floatCast(x),
        // comptime_float coerces directly to any float type.
        .ComptimeFloat => x,
        else => @floatFromInt(x),
    };
}
/// Converts `x` (integer or float, runtime or comptime) to the integer type `T`.
/// Fix: the original routed `comptime_int` (e.g. a bare literal) through the
/// `else` branch, where `@intFromFloat` rejects a non-float operand at compile time.
pub inline fn intCast(comptime T: type, x: anytype) T {
    return switch (@typeInfo(@TypeOf(x))) {
        .Int => @intCast(x),
        // comptime_int coerces directly to any integer type.
        .ComptimeInt => x,
        else => @intFromFloat(x),
    };
}

158
stdx/meta.zig Normal file
View File

@ -0,0 +1,158 @@
const std = @import("std");
const debug = @import("debug.zig");
const compileError = debug.compileError;
pub const FnSignature = @import("signature.zig").FnSignature;
/// Returns true when `T` is a struct type (tuples included).
pub fn isStruct(comptime T: type) bool {
    return @typeInfo(T) == .Struct;
}
/// Returns true when `T` is a tuple (a struct whose `is_tuple` flag is set).
pub fn isTuple(comptime T: type) bool {
    const info = @typeInfo(T);
    if (info != .Struct) return false;
    return info.Struct.is_tuple;
}
/// Returns true when `T` is a struct whose fields all have type `Elem`.
/// Vacuously true for an empty struct.
pub fn isStructOf(comptime T: type, comptime Elem: type) bool {
    const info = @typeInfo(T);
    if (info != .Struct) return false;
    inline for (info.Struct.fields) |field| {
        if (field.type != Elem) return false;
    }
    return true;
}
/// Returns true when `T` is a struct and the predicate `f` holds for every
/// field type. Vacuously true for an empty struct.
pub fn isStructOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
    const info = @typeInfo(T);
    if (info != .Struct) return false;
    inline for (info.Struct.fields) |field| {
        if (!f(field.type)) return false;
    }
    return true;
}
/// Returns true when `T` is a tuple whose fields all have type `Elem`.
pub fn isTupleOf(comptime T: type, comptime Elem: type) bool {
    if (!isTuple(T)) return false;
    return isStructOf(T, Elem);
}
/// Returns true when `T` is a tuple and predicate `f` holds for every field type.
pub fn isTupleOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
    if (!isTuple(T)) return false;
    return isStructOfAny(T, f);
}
/// Returns true when `T` is a slice of `Elem`, or a single-item pointer to an
/// array of `Elem` (which Zig coerces to a slice).
pub fn isSliceOf(comptime T: type, comptime Elem: type) bool {
    const info = @typeInfo(T);
    if (info != .Pointer) return false;
    const ptr = info.Pointer;
    switch (ptr.size) {
        .Slice => return ptr.child == Elem,
        .One => {
            // As Zig does, treat a pointer to an array like a slice.
            const child = @typeInfo(ptr.child);
            if (child != .Array) return false;
            return child.Array.child == Elem;
        },
        else => return false,
    }
}
/// Returns true for runtime integer types and `comptime_int`.
pub fn isInteger(comptime T: type) bool {
    const info = @typeInfo(T);
    return info == .Int or info == .ComptimeInt;
}
/// Returns true when `T` is a slice (or single-item pointer to an array,
/// which Zig coerces to a slice) whose element type satisfies predicate `f`.
/// Fix: made consistent with `isSliceOf`, which already accepts the
/// pointer-to-array form; the original only accepted true slices.
pub fn isSliceOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
    return switch (@typeInfo(T)) {
        .Pointer => |info| switch (info.size) {
            .Slice => f(info.child),
            .One => switch (@typeInfo(info.child)) {
                // As Zig does, treat a pointer to an array like a slice.
                .Array => |arr_info| f(arr_info.child),
                else => false,
            },
            else => false,
        },
        else => false,
    };
}
/// Enum of the public declarations of `T` (after stripping one level of
/// single-item pointer), via `std.meta.DeclEnum`.
/// Fix: the original called `std.meta.declarations(T)` before unwrapping,
/// so a pointer input failed the emptiness check even though the final
/// `std.meta.DeclEnum(UnwrapPtr(T))` call expected the unwrapped type.
pub fn DeclEnum(comptime T: type) type {
    const D = UnwrapPtr(T);
    const decl_infos = std.meta.declarations(D);
    if (decl_infos.len == 0) {
        compileError("Struct {} has no declarations", .{D});
    }
    return std.meta.DeclEnum(D);
}
/// Strips one level of single-item pointer: `*T` becomes `T`.
/// Slices, many-item pointers, and non-pointer types are returned unchanged.
pub fn UnwrapPtr(comptime T: type) type {
    const info = @typeInfo(T);
    if (info != .Pointer) return T;
    if (info.Pointer.size != .One) return T;
    return info.Pointer.child;
}
/// Element type of `T` when `T` can be viewed as a slice: a real slice, or a
/// single-item pointer to an array (which Zig coerces to a slice).
/// Raises a compile error otherwise.
/// Fix: the original called the two-argument `compileError(msg, args)` helper
/// with a single argument, so the error path itself failed to compile with an
/// arity error instead of the intended message; use the builtin directly.
pub fn asSlice(comptime T: type) type {
    const err_msg = "Type " ++ @typeName(T) ++ " can't be interpreted as a slice";
    return switch (@typeInfo(T)) {
        .Pointer => |info| switch (info.size) {
            .Slice => info.child,
            .One => switch (@typeInfo(info.child)) {
                // As Zig does, treat a pointer to an array like a slice.
                .Array => |arr_info| arr_info.child,
                else => @compileError(err_msg),
            },
            else => @compileError(err_msg),
        },
        else => @compileError(err_msg),
    };
}
/// `TupleRangeX` with optional bounds: `start` defaults to 0 and `end` to the
/// field count of `T`.
pub fn TupleRange(comptime T: type, comptime start: ?usize, comptime end: ?usize) type {
    const lo = start orelse 0;
    const hi = end orelse std.meta.fields(T).len;
    return TupleRangeX(T, lo, hi);
}
/// Builds a tuple type from fields `start..end` (end exclusive) of `T`,
/// renumbering field names to "0", "1", ... as tuple fields require.
/// Fix: the original pointed each field name into a `num_buf` comptime
/// variable scoped to the loop iteration; `std.fmt.comptimePrint` yields a
/// stable comptime string constant instead.
pub fn TupleRangeX(comptime T: type, comptime start: usize, comptime end: usize) type {
    const fields = std.meta.fields(T);
    var new_fields: [end - start]std.builtin.Type.StructField = undefined;
    inline for (start..end, 0..) |i, j| {
        var new_field = fields[i];
        // Tuple fields are named by their decimal index.
        new_field.name = std.fmt.comptimePrint("{d}", .{j});
        new_fields[j] = new_field;
    }
    return @Type(.{
        .Struct = .{
            .is_tuple = true,
            .layout = .auto,
            .decls = &.{},
            .fields = &new_fields,
        },
    });
}
/// Type of the `n`-th parameter of `func`; compile error when that parameter
/// is `anytype` (its type is not statically known).
/// Fix: the original called the two-argument `compileError(msg, args)` helper
/// with one argument, so the error path failed with an arity error instead of
/// the intended message; use the builtin directly.
pub fn FnParam(comptime func: anytype, comptime n: comptime_int) type {
    const params = @typeInfo(@TypeOf(func)).Fn.params;
    return params[n].type orelse @compileError("anytype is not supported");
}
/// Argument tuple type of `func`, derived via `FnSignature` with no explicit
/// args tuple (so `func` must not have `anytype` parameters).
pub fn FnArgs(comptime func: anytype) type {
    const Signature = FnSignature(func, null);
    return Signature.ArgsT;
}
/// Return type of `func`, derived via `FnSignature` with no explicit args
/// tuple (so `func` must not have `anytype` parameters).
pub fn FnResult(comptime func: anytype) type {
    const Signature = FnSignature(func, null);
    return Signature.ReturnT;
}

65
stdx/signature.zig Normal file
View File

@ -0,0 +1,65 @@
const std = @import("std");
const compileError = @import("meta.zig").compileError;
/// Tuple type describing the parameters of function type `funcT`.
/// Non-generic functions use `std.meta.ArgsTuple`; a generic function needs a
/// caller-supplied `argsT` whose fields fill in the `anytype` parameter slots.
/// Fixes: (1) the original used `compileError` imported from meta.zig, where
/// that binding is declared `const`, not `pub` — the error path could not
/// resolve it; the builtin is used directly. (2) field names pointed into a
/// per-iteration `num_buf` comptime variable; `std.fmt.comptimePrint` yields a
/// stable constant.
pub fn ArgsTuple(comptime funcT: anytype, comptime argsT: ?type) type {
    const params = @typeInfo(funcT).Fn.params;
    if (params.len == 0) {
        return @TypeOf(.{});
    }
    if (!@typeInfo(funcT).Fn.is_generic) {
        return std.meta.ArgsTuple(funcT);
    }
    const args = std.meta.fields(argsT orelse @compileError("generic function requires an explicit ArgsTuple"));
    var tuple_fields: [params.len]std.builtin.Type.StructField = undefined;
    // Note: the multi-object loop requires argsT to have exactly one field per
    // parameter; a mismatch is a compile error.
    inline for (params, args, 0..) |param, arg, i| {
        if (param.type == null) {
            // anytype slot: take the concrete field from the caller's tuple.
            tuple_fields[i] = arg;
            continue;
        }
        const T = param.type.?;
        tuple_fields[i] = .{
            .name = std.fmt.comptimePrint("{d}", .{i}),
            .type = T,
            .default_value = null,
            .is_comptime = false,
            .alignment = if (@sizeOf(T) > 0) @alignOf(T) else 0,
        };
    }
    return @Type(.{
        .Struct = .{
            .is_tuple = true,
            .layout = .auto,
            .decls = &.{},
            .fields = &tuple_fields,
        },
    });
}
/// Comptime signature info (FuncT / ArgsT / ReturnT / ...) for `func`.
/// `argsT` supplies the argument tuple for generic functions; pass null for
/// non-generic functions (see `ArgsTuple`).
pub fn FnSignature(comptime func: anytype, comptime argsT: ?type) type {
    const Args = ArgsTuple(@TypeOf(func), argsT);
    return FnSignatureX(func, Args);
}
/// Internal helper: packages the signature facts for `func` given a concrete
/// argument tuple type `argsT`.
fn FnSignatureX(comptime func: anytype, comptime argsT: type) type {
    return struct {
        pub const FuncT = @TypeOf(func);
        pub const ArgsT = argsT;
        // Infer the return type by type-checking a call with placeholder args.
        pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined)));
        // Payload of the return's error union, or the return type itself.
        pub const ReturnPayloadT = if (@typeInfo(ReturnT) == .ErrorUnion)
            @typeInfo(ReturnT).ErrorUnion.payload
        else
            ReturnT;
        // Error set of the return's error union, or null when it has none.
        pub const ReturnErrorSet: ?type = if (@typeInfo(ReturnT) == .ErrorUnion)
            @typeInfo(ReturnT).ErrorUnion.error_set
        else
            null;
    };
}

3
stdx/stdx.zig Normal file
View File

@ -0,0 +1,3 @@
//! Root module of the stdx utility library.
//! signature.zig is not exported here directly; meta.zig re-exports its
//! `FnSignature`.
pub const math = @import("math.zig");
pub const meta = @import("meta.zig");
pub const debug = @import("debug.zig");

View File

@ -32,6 +32,7 @@ zig_library(
"//mlir/dialects", "//mlir/dialects",
"//pjrt", "//pjrt",
"//runtimes", "//runtimes",
"//stdx",
"//zml/tools", "//zml/tools",
"@rules_zig//zig/lib:libc", "@rules_zig//zig/lib:libc",
"@rules_zig//zig/runfiles", "@rules_zig//zig/runfiles",

View File

@ -1,8 +1,10 @@
const builtin = @import("builtin");
const asynk = @import("async"); const asynk = @import("async");
const std = @import("std"); const builtin = @import("builtin");
const zml = @import("zml.zig");
const c = @import("c"); const c = @import("c");
const std = @import("std");
const stdx = @import("stdx");
const zml = @import("zml.zig");
const posix = @import("posix.zig"); const posix = @import("posix.zig");
pub const gguf = @import("aio/gguf.zig"); pub const gguf = @import("aio/gguf.zig");
@ -13,7 +15,7 @@ pub const tinyllama = @import("aio/tinyllama.zig");
pub const torch = @import("aio/torch.zig"); pub const torch = @import("aio/torch.zig");
pub const yaml = @import("aio/yaml.zig"); pub const yaml = @import("aio/yaml.zig");
pub const log = std.log.scoped(.zml_aio); pub const log = std.log.scoped(.@"zml/aio");
const HostBuffer = @import("hostbuffer.zig").HostBuffer; const HostBuffer = @import("hostbuffer.zig").HostBuffer;
test { test {
@ -256,7 +258,11 @@ pub const MemoryMappedFile = struct {
0, 0,
}); });
try asynk.callBlocking(posix.madvise, .{ data_.ptr, @intCast(data_.len), @intCast(c.MADV_SEQUENTIAL) }); try asynk.callBlocking(posix.madvise, .{
data_.ptr,
@as(usize, @intCast(data_.len)),
@as(u32, @intCast(c.MADV_SEQUENTIAL)),
});
return .{ return .{
.file = file, .file = file,
@ -600,7 +606,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
// obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us. // obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us.
var buf_with_metadata = host_buffer; var buf_with_metadata = host_buffer;
log.debug("Loading buffer {s} ({})", .{ prefix, obj._shape }); log.debug("Loading buffer {s} ({})", .{ prefix, obj._shape });
zml.meta.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix }); stdx.debug.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix });
buf_with_metadata._shape = obj._shape; buf_with_metadata._shape = obj._shape;
obj.* = try zml.Buffer.from(platform, buf_with_metadata); obj.* = try zml.Buffer.from(platform, buf_with_metadata);
} else { } else {

View File

@ -8,7 +8,7 @@ const HostBuffer = @import("../hostbuffer.zig").HostBuffer;
const Allocator = std.mem.Allocator; const Allocator = std.mem.Allocator;
const assert = std.debug.assert; const assert = std.debug.assert;
const log = std.log.scoped(.zml_io); const log = std.log.scoped(.@"zml/io");
pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore { pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
var file = try core.GgufFile.open(path); var file = try core.GgufFile.open(path);

View File

@ -3,7 +3,7 @@ const std = @import("std");
const zml = @import("../../zml.zig"); const zml = @import("../../zml.zig");
const assert = std.debug.assert; const assert = std.debug.assert;
const log = std.log.scoped(.zml_io); const log = std.log.scoped(.@"zml/io");
pub const GgufErrors = error{ pub const GgufErrors = error{
ValueTypeMismatch, ValueTypeMismatch,

View File

@ -1,5 +1,5 @@
const std = @import("std"); const std = @import("std");
const log = std.log.scoped(.zml_aio); const log = std.log.scoped(.@"zml/aio");
const asynk = @import("async"); const asynk = @import("async");
const yaml = @import("zig-yaml"); const yaml = @import("zig-yaml");

View File

@ -7,7 +7,7 @@ const MemoryMappedFile = @import("../aio.zig").MemoryMappedFile;
const StringBuilder = std.ArrayListUnmanaged(u8); const StringBuilder = std.ArrayListUnmanaged(u8);
const Allocator = std.mem.Allocator; const Allocator = std.mem.Allocator;
const log = std.log.scoped(.zml_io); const log = std.log.scoped(.@"zml/io");
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore { pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
var res: zml.aio.BufferStore = .{ var res: zml.aio.BufferStore = .{

View File

@ -1,8 +1,8 @@
/// Tools to load models from https://huggingface.co/karpathy/tinyllamas/ /// Tools to load models from https://huggingface.co/karpathy/tinyllamas/
/// Originally made to be run with https://github.com/karpathy/llama2.c /// Originally made to be run with https://github.com/karpathy/llama2.c
const std = @import("std");
const asynk = @import("async"); const asynk = @import("async");
const std = @import("std");
const stdx = @import("stdx");
const zml = @import("../zml.zig"); const zml = @import("../zml.zig");
@ -86,7 +86,7 @@ pub fn open(allocator: std.mem.Allocator, model_path: []const u8) !zml.aio.Buffe
const weights_size = off; const weights_size = off;
std.log.info("Loaded a tinyllama file of {} bytes.\nThis is the parsed configuration of this llama model: {}", .{ weights_size, c }); std.log.info("Loaded a tinyllama file of {} bytes.\nThis is the parsed configuration of this llama model: {}", .{ weights_size, c });
if (file.stat() catch null) |stat| { if (file.stat() catch null) |stat| {
zml.meta.assert(weights_size == stat.size, "Expected to have a tinyllama file of {} bytes but file only got {} !\nThis is the parsed configuration of this llama model: {}", .{ weights_size, stat.size, c }); stdx.debug.assert(weights_size == stat.size, "Expected to have a tinyllama file of {} bytes but file only got {} !\nThis is the parsed configuration of this llama model: {}", .{ weights_size, stat.size, c });
} }
{ {

View File

@ -7,7 +7,7 @@ const py = @import("torch/py.zig");
const File = @import("torch/file.zig").File; const File = @import("torch/file.zig").File;
const StringBuilder = std.ArrayListUnmanaged(u8); const StringBuilder = std.ArrayListUnmanaged(u8);
const log = std.log.scoped(.zml_aio); const log = std.log.scoped(.@"zml/aio");
test { test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(@This());

View File

@ -1,6 +1,5 @@
const std = @import("std"); const std = @import("std");
const zml = @import("../../zml.zig"); const stdx = @import("stdx");
const meta = zml.meta;
const py = @import("py.zig"); const py = @import("py.zig");
const pickle = @import("pickle.zig"); const pickle = @import("pickle.zig");
@ -228,7 +227,7 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
}, },
} }
}, },
.proto => |proto| meta.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}), .proto => |proto| stdx.debug.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}),
.tuple1 => try stack.append(blk: { .tuple1 => try stack.append(blk: {
const tup_values = try arena.alloc(py.Any, 1); const tup_values = try arena.alloc(py.Any, 1);
tup_values[0] = try pop(&stack); tup_values[0] = try pop(&stack);

View File

@ -1,8 +1,6 @@
const std = @import("std");
const testing = std.testing;
const log = std.log.scoped(.zml_aio);
const asynk = @import("async"); const asynk = @import("async");
const std = @import("std");
const stdx = @import("stdx");
const zml = @import("../../zml.zig"); const zml = @import("../../zml.zig");
const pickle = @import("pickle.zig"); const pickle = @import("pickle.zig");
@ -10,6 +8,9 @@ const py = @import("py.zig");
const eval = @import("eval.zig"); const eval = @import("eval.zig");
const HostBuffer = zml.HostBuffer; const HostBuffer = zml.HostBuffer;
const testing = std.testing;
const log = std.log.scoped(.@"zml/aio");
// TODO(cryptodeal): use zml.aio.PrefixBuilder instead // TODO(cryptodeal): use zml.aio.PrefixBuilder instead
const StringBuilder = std.ArrayListUnmanaged(u8); const StringBuilder = std.ArrayListUnmanaged(u8);
@ -329,7 +330,7 @@ pub const File = struct {
}, },
.dict => { .dict => {
const n = @divExact(seq.values.len, 2); const n = @divExact(seq.values.len, 2);
log.info("found dict with {} entries", .{n}); log.debug("found dict with {} entries", .{n});
for (0..n) |i| { for (0..n) |i| {
const key, const val = seq.values[2 * i ..][0..2].*; const key, const val = seq.values[2 * i ..][0..2].*;
switch (key) { switch (key) {
@ -534,7 +535,7 @@ pub const File = struct {
} }
fn parseDims(values: []py.Any) error{InvalidInput}!zml.Shape.DimsArray { fn parseDims(values: []py.Any) error{InvalidInput}!zml.Shape.DimsArray {
zml.meta.assert(values.len <= zml.Tensor.MAX_RANK, "Found Pytorch tensor with unsupported rank {}", .{values.len}); stdx.debug.assert(values.len <= zml.Tensor.MAX_RANK, "Found Pytorch tensor with unsupported rank {}", .{values.len});
var result: zml.Shape.DimsArray = .{}; var result: zml.Shape.DimsArray = .{};
for (values) |val| { for (values) |val| {
switch (val) { switch (val) {

View File

@ -1,6 +1,6 @@
const std = @import("std"); const std = @import("std");
const log = std.log.scoped(.zml_aio); const log = std.log.scoped(.@"zml/aio");
/// All possible pickle operators. /// All possible pickle operators.
/// Reference: https://github.com/python/cpython/blob/3.13/Lib/pickletools.py /// Reference: https://github.com/python/cpython/blob/3.13/Lib/pickletools.py

View File

@ -1,6 +1,6 @@
const std = @import("std"); const std = @import("std");
const math = std.math; const math = std.math;
const log = std.log.scoped(.zml_aio); const log = std.log.scoped(.@"zml/aio");
const pickle = @import("pickle.zig"); const pickle = @import("pickle.zig");

View File

@ -1,9 +1,11 @@
const asynk = @import("async");
const std = @import("std"); const std = @import("std");
const testing = std.testing; const stdx = @import("stdx");
const meta = @import("meta.zig"); const meta = @import("meta.zig");
const pjrt = @import("pjrtx.zig"); const pjrt = @import("pjrtx.zig");
const asynk = @import("async");
const testing = std.testing;
const Context = @import("context.zig").Context; const Context = @import("context.zig").Context;
const Data = @import("dtype.zig").Data; const Data = @import("dtype.zig").Data;
@ -42,12 +44,12 @@ pub const Buffer = struct {
// We shard only on the first axis so that the chunks are still contiguous. // We shard only on the first axis so that the chunks are still contiguous.
// TODO: support more advanced sharding specs // TODO: support more advanced sharding specs
meta.assert(platform.sharding().num_replicas == 1, "ZML doesn't support num_replicas > 1 for now, got: {}", .{platform.sharding()}); stdx.debug.assert(platform.sharding().num_replicas == 1, "ZML doesn't support num_replicas > 1 for now, got: {}", .{platform.sharding()});
const sharding_ax: ?u3 = std.simd.firstTrue(host_buffer.shape()._sharding_info); const sharding_ax: ?u3 = std.simd.firstTrue(host_buffer.shape()._sharding_info);
const n_partitions = platform.sharding().num_partitions; const n_partitions = platform.sharding().num_partitions;
const chunk_size = if (sharding_ax) |ax| cs: { const chunk_size = if (sharding_ax) |ax| cs: {
// This kind of sharding error should be detected earlier on. // This kind of sharding error should be detected earlier on.
meta.assert(@rem(host_buffer.dim(ax), n_partitions) == 0, "Buffer.from({}) expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ host_buffer, ax, n_partitions }); stdx.debug.assert(@rem(host_buffer.dim(ax), n_partitions) == 0, "Buffer.from({}) expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ host_buffer, ax, n_partitions });
break :cs @divExact(host_buffer.dim(ax), n_partitions); break :cs @divExact(host_buffer.dim(ax), n_partitions);
} else 0; } else 0;
@ -88,8 +90,8 @@ pub const Buffer = struct {
/// Wraps pre-exisiting `pjrt.Buffer` shards into one `zml.Buffer`. /// Wraps pre-exisiting `pjrt.Buffer` shards into one `zml.Buffer`.
pub fn fromPjrtBuffers(platform: Platform, shape_: Shape, pjrt_buffers: []const *pjrt.Buffer) Buffer { pub fn fromPjrtBuffers(platform: Platform, shape_: Shape, pjrt_buffers: []const *pjrt.Buffer) Buffer {
meta.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len }); stdx.debug.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len });
meta.assert(pjrt_buffers.len > 0, "fromPjrtBuffers expects at least one buffer, got 0.", .{}); stdx.debug.assert(pjrt_buffers.len > 0, "fromPjrtBuffers expects at least one buffer, got 0.", .{});
var shards: Shards = .{}; var shards: Shards = .{};
shards.appendSliceAssumeCapacity(pjrt_buffers); shards.appendSliceAssumeCapacity(pjrt_buffers);
return .{ return .{
@ -190,9 +192,9 @@ pub const Buffer = struct {
/// Fetches the content of the given buffer into a stack variable of the given type. /// Fetches the content of the given buffer into a stack variable of the given type.
pub fn getValue(self: Buffer, T: type) !T { pub fn getValue(self: Buffer, T: type) !T {
meta.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) }); stdx.debug.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) });
var res: T = undefined; var res: T = undefined;
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res)); const maybe_event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res));
if (maybe_event) |event| { if (maybe_event) |event| {
try event.await_(self._api); try event.await_(self._api);
@ -204,7 +206,7 @@ pub const Buffer = struct {
/// and return a new `HostBuffer` object with the same shape. /// and return a new `HostBuffer` object with the same shape.
/// The returned `HostBuffer` doesn't own the memory. /// The returned `HostBuffer` doesn't own the memory.
pub fn toHost(self: Buffer, output: []u8) !HostBuffer { pub fn toHost(self: Buffer, output: []u8) !HostBuffer {
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, output); const maybe_event = try self._shards.get(0).toHostBuffer(self._api, output);
if (maybe_event) |event| { if (maybe_event) |event| {
try event.await_(self._api); try event.await_(self._api);
@ -216,7 +218,7 @@ pub const Buffer = struct {
/// The returned `HostBuffer` does own the memory. /// The returned `HostBuffer` does own the memory.
pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer { pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer {
const output = try HostBuffer.empty(allocator, self.shape()); const output = try HostBuffer.empty(allocator, self.shape());
meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data)); const maybe_event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data));
if (maybe_event) |event| { if (maybe_event) |event| {
try event.await_(self._api); try event.await_(self._api);

View File

@ -1,22 +1,22 @@
const builtin = @import("builtin"); const builtin = @import("builtin");
const std = @import("std");
const mlir = @import("mlir");
const c = @import("c"); const c = @import("c");
const mlir = @import("mlir");
const runfiles = @import("runfiles"); const runfiles = @import("runfiles");
const runtimes = @import("runtimes"); const runtimes = @import("runtimes");
const std = @import("std");
const stdx = @import("stdx");
const platform = @import("platform.zig"); const platform = @import("platform.zig");
const pjrt = @import("pjrtx.zig"); const pjrt = @import("pjrtx.zig");
const available_targets = @import("platform.zig").available_targets;
const HostBuffer = @import("hostbuffer.zig").HostBuffer; const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const Target = @import("platform.zig").Target;
const Platform = @import("platform.zig").Platform;
const log = std.log.scoped(.zml);
const PjrtApiMap = std.EnumArray(Target, ?*const pjrt.Api); const PjrtApiMap = std.EnumArray(Target, ?*const pjrt.Api);
const Platform = @import("platform.zig").Platform;
const PlatformsMap = std.EnumArray(Target, ?Platform); const PlatformsMap = std.EnumArray(Target, ?Platform);
const Target = @import("platform.zig").Target;
const available_targets = @import("platform.zig").available_targets;
const log = std.log.scoped(.@"zml/context");
/// Every program using ZML must start with a `zml.Context.init(.{});` /// Every program using ZML must start with a `zml.Context.init(.{});`
/// The ZML context contains global state to interact with the different /// The ZML context contains global state to interact with the different
@ -145,6 +145,36 @@ pub const Context = struct {
return platform_ orelse @panic("No platform found !"); return platform_ orelse @panic("No platform found !");
} }
pub fn printAvailablePlatforms(self: Context, selected: platform.Platform) void {
// List available targets
log.info("Available Platforms:", .{});
const selected_prefix = "";
const not_selected_prefix = "";
const selected_postfix = "(AUTO-SELECTED)";
const not_selected_postfix = "";
for (platform.available_targets) |target| {
log.info(" {s} {s} {s}", .{
if (target == selected.target) selected_prefix else not_selected_prefix,
@tagName(target),
if (target == selected.target) selected_postfix else not_selected_postfix,
});
// now the platform's devices
if (self.platforms.get(target)) |pfm| {
for (pfm.getDevices(), 0..) |device, index| {
const deviceKind = device.getDescription(pfm.pjrt_api).getKind(pfm.pjrt_api);
log.info(" ◦ #{d}: {s}", .{
index,
deviceKind,
});
// we only list 1 CPU device
if (target == .cpu) break;
}
}
}
}
pub const HostCallbackCtx = struct { pub const HostCallbackCtx = struct {
host: HostBuffer, host: HostBuffer,
mutex: std.Thread.Mutex = std.Thread.Mutex{}, mutex: std.Thread.Mutex = std.Thread.Mutex{},

View File

@ -6,7 +6,7 @@ const Shape = @import("shape.zig").Shape;
const Tensor = @import("tensor.zig").Tensor; const Tensor = @import("tensor.zig").Tensor;
const EnumLiteral = @TypeOf(.enum_literal); const EnumLiteral = @TypeOf(.enum_literal);
const log = std.log.scoped(.zml_tensor); const log = std.log.scoped(.@"zml/tensor");
test { test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(@This());

View File

@ -1,6 +1,6 @@
const std = @import("std"); const std = @import("std");
const stdx = @import("stdx");
const meta = @import("meta.zig");
const Buffer = @import("buffer.zig").Buffer; const Buffer = @import("buffer.zig").Buffer;
const Data = @import("dtype.zig").Data; const Data = @import("dtype.zig").Data;
const DataType = @import("dtype.zig").DataType; const DataType = @import("dtype.zig").DataType;
@ -108,13 +108,13 @@ pub const HostBuffer = struct {
/// The memory is initialized with increasing numbers. /// The memory is initialized with increasing numbers.
/// The caller owns the memory, and need to call `deinit()`. /// The caller owns the memory, and need to call `deinit()`.
pub fn arange(allocator: std.mem.Allocator, args: ArangeArgs, dt: DataType) !HostBuffer { pub fn arange(allocator: std.mem.Allocator, args: ArangeArgs, dt: DataType) !HostBuffer {
meta.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end }); stdx.debug.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end });
meta.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step}); stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});
const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable; const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable;
const b = dt.sizeOf(); const b = dt.sizeOf();
const res = try empty(allocator, Shape.init(.{n_steps}, dt)); const res = try empty(allocator, Shape.init(.{n_steps}, dt));
meta.assert(dt.class() == .integer, "arange expects type to be integer, got {} instead.", .{dt}); stdx.debug.assert(dt.class() == .integer, "arange expects type to be integer, got {} instead.", .{dt});
var data_ = @constCast(res.data); var data_ = @constCast(res.data);
switch (dt) { switch (dt) {
inline else => { inline else => {
@ -201,7 +201,7 @@ pub const HostBuffer = struct {
} }
pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer { pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer {
meta.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {}", .{self}); stdx.debug.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {}", .{self});
var res = self; var res = self;
res._shape = self._shape.reshape(shape_); res._shape = self._shape.reshape(shape_);
return res; return res;
@ -219,9 +219,9 @@ pub const HostBuffer = struct {
const start: i64 = if (s.start < 0) s.start + d else s.start; const start: i64 = if (s.start < 0) s.start + d else s.start;
var end = s.end orelse d; var end = s.end orelse d;
if (end < 0) end += d; if (end < 0) end += d;
meta.assert(start >= 0 and start < d, "slice1d({}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, start }); stdx.debug.assert(start >= 0 and start < d, "slice1d({}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, start });
meta.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, end }); stdx.debug.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, end });
meta.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({})", .{ self, ax, start, end }); stdx.debug.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({})", .{ self, ax, start, end });
// If strides weren't set it means original buffer is contiguous. // If strides weren't set it means original buffer is contiguous.
// But it won't be anymore after slicing. The strides don't change though. // But it won't be anymore after slicing. The strides don't change though.

View File

@ -1,4 +1,8 @@
const std = @import("std"); const std = @import("std");
const stdx = @import("stdx");
const FnParam = stdx.meta.FnParam;
const asSlice = stdx.meta.asSlice;
const testing = std.testing; const testing = std.testing;
@ -6,215 +10,6 @@ test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(@This());
} }
/// Computes floating point value division between two integers.
pub fn divFloat(T: type, numerator: anytype, denominator: anytype) T {
return toFloat(T, numerator) / toFloat(T, denominator);
}
fn toFloat(T: type, x: anytype) T {
return switch (@typeInfo(@TypeOf(x))) {
.Float => @floatCast(x),
else => @floatFromInt(x),
};
}
pub fn guard(check: bool, src: std.builtin.SourceLocation) void {
assert(check, "Invalid inputs {s}@{s}:{d}", .{ src.file, src.fn_name, src.line });
}
pub inline fn internalAssert(check: bool, comptime msg: []const u8, args: anytype) void {
assert(check, "ZML internal error: " ++ msg, args);
}
pub fn assert(check: bool, comptime msg: []const u8, args: anytype) void {
if (!check) panic(msg, args);
}
pub fn panic(comptime msg: []const u8, args: anytype) noreturn {
std.log.err(msg, args);
@panic(msg);
}
pub fn compileLog(comptime msg: []const u8, comptime args: anytype) void {
@compileLog(std.fmt.comptimePrint(msg, args));
}
pub fn compileError(comptime msg: []const u8, comptime args: anytype) noreturn {
@compileError(std.fmt.comptimePrint(msg, args));
}
pub fn assertComptime(comptime check: bool, comptime msg: []const u8, comptime args: anytype) void {
if (check == false) {
compileError(msg, args);
}
}
pub fn isStruct(comptime T: type) bool {
return switch (@typeInfo(T)) {
.Struct => true,
else => false,
};
}
pub fn isTuple(comptime T: type) bool {
return switch (@typeInfo(T)) {
.Struct => |info| info.is_tuple,
else => false,
};
}
pub fn isStructOf(comptime T: type, comptime Elem: type) bool {
return switch (@typeInfo(T)) {
.Struct => |info| blk: {
inline for (info.fields) |field| {
if (field.type != Elem) {
break :blk false;
}
}
break :blk true;
},
else => false,
};
}
pub fn isStructOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
return switch (@typeInfo(T)) {
.Struct => |info| blk: {
inline for (info.fields) |field| {
if (f(field.type) == false) {
break :blk false;
}
}
break :blk true;
},
else => false,
};
}
pub fn isTupleOf(comptime T: type, comptime Elem: type) bool {
return isTuple(T) and isStructOf(T, Elem);
}
pub fn isTupleOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
return isTuple(T) and isStructOfAny(T, f);
}
pub fn isSliceOf(comptime T: type, comptime Elem: type) bool {
return switch (@typeInfo(T)) {
.Pointer => |info| switch (info.size) {
.Slice => info.child == Elem,
.One => switch (@typeInfo(info.child)) {
// As Zig, convert pointer to Array as a slice.
.Array => |arr_info| arr_info.child == Elem,
else => false,
},
else => false,
},
else => false,
};
}
pub fn asSlice(comptime T: type) type {
const err_msg = "Type " ++ @typeName(T) ++ " can't be interpreted as a slice";
return switch (@typeInfo(T)) {
.Pointer => |info| switch (info.size) {
.Slice => info.child,
.One => switch (@typeInfo(info.child)) {
// As Zig, convert pointer to Array as a slice.
.Array => |arr_info| arr_info.child,
else => @compileError(err_msg),
},
else => @compileError(err_msg),
},
else => @compileError(err_msg),
};
}
pub fn isInteger(comptime T: type) bool {
return switch (@typeInfo(T)) {
.Int, .ComptimeInt => true,
else => false,
};
}
pub fn isSliceOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
return switch (@typeInfo(T)) {
.Pointer => |info| info.size == .Slice and f(info.child),
else => false,
};
}
pub fn DeclEnum(comptime T: type) type {
const field_infos = std.meta.declarations(T);
if (field_infos.len == 0) compileError("Struct {} has no declarations", .{T});
return std.meta.DeclEnum(UnwrapPtr(T));
}
pub fn UnwrapPtr(comptime T: type) type {
return switch (@typeInfo(T)) {
.Pointer => |info| switch (info.size) {
.One => info.child,
else => T,
},
else => T,
};
}
pub fn FnParam(func: anytype, n: comptime_int) type {
return @typeInfo(@TypeOf(func)).Fn.params[n].type orelse @compileError("anytype not supported in callbacks");
}
pub fn FnParams(func: anytype) type {
return std.meta.ArgsTuple(@TypeOf(func));
}
pub fn FnResult(func: anytype) type {
return @typeInfo(@TypeOf(func)).Fn.return_type.?;
}
pub fn FnResultPayload(func: anytype) type {
const return_type = @typeInfo(@TypeOf(func)).Fn.return_type.?;
const payload_type = switch (@typeInfo(return_type)) {
.ErrorUnion => |u| u.payload,
else => return_type,
};
return payload_type;
}
pub fn FnResultErrorSet(func: anytype) ?type {
const return_type = @typeInfo(@TypeOf(func)).Fn.return_type.?;
const error_set = switch (@typeInfo(return_type)) {
.ErrorUnion => |u| u.error_set,
else => null,
};
return error_set;
}
pub fn Signature(comptime func: anytype, comptime argsT: ?type) type {
return struct {
pub const FuncT = if (@TypeOf(func) == type) func else @TypeOf(func);
pub const ArgsT = blk: {
if (@typeInfo(FuncT).Fn.params.len == 0) {
break :blk @TypeOf(.{});
}
break :blk argsT orelse std.meta.ArgsTuple(FuncT);
};
pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined)));
pub const ReturnPayloadT = blk: {
break :blk switch (@typeInfo(ReturnT)) {
.ErrorUnion => |u| u.payload,
else => ReturnT,
};
};
pub const ReturnErrorSet: ?type = blk: {
break :blk switch (@typeInfo(ReturnT)) {
.ErrorUnion => |u| u.error_set,
else => null,
};
};
};
}
pub fn MapType(From: type, To: type) type { pub fn MapType(From: type, To: type) type {
return struct { return struct {
pub fn map(T: type) type { pub fn map(T: type) type {
@ -299,7 +94,7 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
const type_info_to_ptr = @typeInfo(@TypeOf(to)); const type_info_to_ptr = @typeInfo(@TypeOf(to));
if (type_info_to_ptr != .Pointer) { if (type_info_to_ptr != .Pointer) {
@compileError("convertType is expecting a mutable `to` argument but received: " ++ @typeName(@TypeOf(to))); stdx.debug.compileError("convertType is expecting a mutable `to` argument but received: " ++ @typeName(@TypeOf(to)));
} }
const ToStruct = type_info_to_ptr.Pointer.child; const ToStruct = type_info_to_ptr.Pointer.child;
const type_info_to = @typeInfo(ToStruct); const type_info_to = @typeInfo(ToStruct);
@ -348,7 +143,7 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
} else if (field.default_value) |_| { } else if (field.default_value) |_| {
@field(to, field.name) = null; @field(to, field.name) = null;
} else { } else {
compileError("Mapping {} to {} failed. Missing field {s}", .{ FromStruct, ToStruct, field.name }); stdx.meta.compileError("Mapping {} to {} failed. Missing field {s}", .{ FromStruct, ToStruct, field.name });
}, },
else => @field(to, field.name) = @field(from, field.name), else => @field(to, field.name) = @field(from, field.name),
} }
@ -374,7 +169,7 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
} }
to.* = items; to.* = items;
}, },
else => @compileError("zml.meta.mapAlloc doesn't support: " ++ @typeName(FromStruct)), else => stdx.meta.compileError("zml.meta.mapAlloc doesn't support: " ++ @typeName(FromStruct)),
}, },
.Optional => if (from) |f| { .Optional => if (from) |f| {
to.* = @as(@typeInfo(type_info_to_ptr.Pointer.child).Optional.child, undefined); to.* = @as(@typeInfo(type_info_to_ptr.Pointer.child).Optional.child, undefined);
@ -383,7 +178,7 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
to.* = null; to.* = null;
}, },
.Int, .Float => to.* = from, .Int, .Float => to.* = from,
else => @compileError("zml.meta.mapAlloc doesn't support: " ++ @typeName(FromStruct)), else => stdx.meta.compileError("zml.meta.mapAlloc doesn't support: " ++ @typeName(FromStruct)),
} }
} }
@ -444,12 +239,12 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
const type_info_v = @typeInfo(T); const type_info_v = @typeInfo(T);
const K = switch (@typeInfo(FnParam(cb, 1))) { const K = switch (@typeInfo(FnParam(cb, 1))) {
.Pointer => |info| info.child, .Pointer => |info| info.child,
else => @compileError("zml.meta.visit is expecting a pointer value as second parameter in callback to use but found " ++ @typeName(FnParam(cb, 1))), else => stdx.meta.compileError("zml.meta.visit is expecting a pointer value as second parameter in callback to use but found " ++ @typeName(FnParam(cb, 1))),
}; };
if (type_info_v != .Pointer) { if (type_info_v != .Pointer) {
const Callback = @TypeOf(cb); const Callback = @TypeOf(cb);
@compileError("zml.meta.visit is expecting a pointer input to go with following callback signature: " ++ @typeName(Callback) ++ " but received: " ++ @typeName(T)); stdx.meta.compileError("zml.meta.visit is expecting a pointer input to go with following callback signature: " ++ @typeName(Callback) ++ " but received: " ++ @typeName(T));
} }
const ptr_info = type_info_v.Pointer; const ptr_info = type_info_v.Pointer;
if (@typeInfo(ptr_info.child) == .Fn) return; if (@typeInfo(ptr_info.child) == .Fn) return;
@ -512,7 +307,7 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
} }
} }
}, },
else => @compileError("Only single pointer and slice are supported. Received " ++ @typeName(T)), else => stdx.meta.compileError("Only single pointer and slice are supported. Received " ++ @typeName(T)),
} }
} }
@ -601,10 +396,8 @@ test visit {
/// Only T elements of values will be looked at. /// Only T elements of values will be looked at.
/// This only works for simple types, in particular `zip` doesn't follow pointers. /// This only works for simple types, in particular `zip` doesn't follow pointers.
/// Which means that zip only allocate temp memory, and nothing need to be freed after the call. /// Which means that zip only allocate temp memory, and nothing need to be freed after the call.
pub fn zip(func: anytype, allocator: std.mem.Allocator, values: anytype, args: anytype) error{OutOfMemory}!asSlice(@TypeOf(values)) { pub fn zip(comptime func: anytype, allocator: std.mem.Allocator, values: anytype, args: anytype) error{OutOfMemory}!asSlice(@TypeOf(values)) {
const sliceT = @typeInfo(FnParam(func, 0)); const sliceT = @typeInfo(FnParam(func, 0));
assertComptime(sliceT == .Pointer and sliceT.Pointer.size == .Slice and sliceT.Pointer.child == FnResult(func), "zip requires a `fn([]const T, Args) T`, received: {}", .{@TypeOf(func)});
const T = sliceT.Pointer.child; const T = sliceT.Pointer.child;
const V = asSlice(@TypeOf(values)); const V = asSlice(@TypeOf(values));
if (V == T) { if (V == T) {
@ -613,13 +406,13 @@ pub fn zip(func: anytype, allocator: std.mem.Allocator, values: anytype, args: a
// const fn_args // const fn_args
return switch (@typeInfo(V)) { return switch (@typeInfo(V)) {
.Pointer => @compileError("zip only accept by value arguments. Received: " ++ @typeName(V)), .Pointer => stdx.meta.compileError("zip only accept by value arguments. Received: " ++ @typeName(V)),
.Struct => |struct_info| { .Struct => |struct_info| {
var out: V = values[0]; var out: V = values[0];
inline for (struct_info.fields) |f| { inline for (struct_info.fields) |f| {
if (f.is_comptime) continue; if (f.is_comptime) continue;
if (@typeInfo(f.type) == .Pointer) { if (@typeInfo(f.type) == .Pointer) {
@compileError("zip doesn't follow pointers and don't accept struct containing them. Received: " ++ @typeName(V)); stdx.meta.compileError("zip doesn't follow pointers and don't accept struct containing them. Received: " ++ @typeName(V));
} }
var fields = try allocator.alloc(f.type, values.len); var fields = try allocator.alloc(f.type, values.len);
defer allocator.free(fields); defer allocator.free(fields);
@ -632,7 +425,7 @@ pub fn zip(func: anytype, allocator: std.mem.Allocator, values: anytype, args: a
}, },
.Array => |arr_info| { .Array => |arr_info| {
if (@typeInfo(arr_info.child) == .Pointer) { if (@typeInfo(arr_info.child) == .Pointer) {
@compileError("zip doesn't follow pointers and don't accept struct containing them. Received: " ++ @typeName(V)); stdx.meta.compileError("zip doesn't follow pointers and don't accept struct containing them. Received: " ++ @typeName(V));
} }
var out: V = undefined; var out: V = undefined;
var slice = try allocator.alloc(arr_info.child, values.len); var slice = try allocator.alloc(arr_info.child, values.len);
@ -645,7 +438,7 @@ pub fn zip(func: anytype, allocator: std.mem.Allocator, values: anytype, args: a
} }
return out; return out;
}, },
.Union, .Optional => @compileError("zip doesn't yet support " ++ @typeName(V)), .Union, .Optional => stdx.meta.compileError("zip doesn't yet support " ++ @typeName(V)),
else => values[0], else => values[0],
}; };
} }
@ -668,11 +461,11 @@ test zip {
/// Given a func(X) -> Y or a func(Ctx, X) -> Y, /// Given a func(X) -> Y or a func(Ctx, X) -> Y,
/// finds all X in the given object, and write the result of func(X) into an arraylist. /// finds all X in the given object, and write the result of func(X) into an arraylist.
pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(FnResult(func)), obj: anytype) error{OutOfMemory}!void { pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT), obj: anytype) error{OutOfMemory}!void {
assertComptime(@typeInfo(@TypeOf(func)).Fn.params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)}); stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).Fn.params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)});
const LocalContext = struct { const LocalContext = struct {
func_ctx: _CollectCtx(func), func_ctx: _CollectCtx(func),
out: *std.ArrayList(FnResult(func)), out: *std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT),
oom: bool = false, oom: bool = false,
}; };
var context = LocalContext{ .func_ctx = func_ctx, .out = out }; var context = LocalContext{ .func_ctx = func_ctx, .out = out };
@ -691,10 +484,10 @@ pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(F
fn _CollectCtx(func: anytype) type { fn _CollectCtx(func: anytype) type {
const params = @typeInfo(@TypeOf(func)).Fn.params; const params = @typeInfo(@TypeOf(func)).Fn.params;
if (params.len == 1) return void; if (params.len == 1) return void;
return params[0].type orelse @compileError("anytype not supported in collect"); return params[0].type orelse stdx.meta.compileError("anytype not supported in collect");
} }
fn _CollectArg(func: anytype) type { fn _CollectArg(func: anytype) type {
const params = @typeInfo(@TypeOf(func)).Fn.params; const params = @typeInfo(@TypeOf(func)).Fn.params;
return params[params.len - 1].type orelse @compileError("anytype not supported in collect"); return params[params.len - 1].type orelse stdx.meta.compileError("anytype not supported in collect");
} }

View File

@ -2,14 +2,14 @@ const mlir = @This();
const builtin = @import("builtin"); const builtin = @import("builtin");
const std = @import("std"); const std = @import("std");
const stdx = @import("stdx");
const dtype = @import("dtype.zig"); const dtype = @import("dtype.zig");
const meta = @import("meta.zig");
const Shape = @import("shape.zig").Shape; const Shape = @import("shape.zig").Shape;
const Tensor = @import("tensor.zig").Tensor; const Tensor = @import("tensor.zig").Tensor;
const log = std.log.scoped(.zml_mlir); const log = std.log.scoped(.@"zml/mlir");
pub usingnamespace @import("mlir"); pub usingnamespace @import("mlir");
@ -128,7 +128,7 @@ pub const ext = struct {
} }
} }
meta.panic("Could not convert mlir.Type to DataType: {}", .{mlir_type}); stdx.debug.panic("Could not convert mlir.Type to DataType: {}", .{mlir_type});
} }
}; };
@ -148,7 +148,7 @@ pub const ext = struct {
const int_attr = mlir.IntegerAttribute(int_type).init(ctx, @intCast(val)); const int_attr = mlir.IntegerAttribute(int_type).init(ctx, @intCast(val));
return int_attr.as(mlir.Attribute).?; return int_attr.as(mlir.Attribute).?;
}, },
inline else => |_, tag| meta.panic("Unsupported data type: {any}", .{tag}), inline else => |_, tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}),
} }
} }
}; };
@ -169,7 +169,7 @@ pub const ext = struct {
.f16 => mlir.DenseIntOrFPElementsAttribute(.f16).init(result_type, data.constSlice()).as(mlir.Attribute).?, .f16 => mlir.DenseIntOrFPElementsAttribute(.f16).init(result_type, data.constSlice()).as(mlir.Attribute).?,
.f32 => mlir.DenseIntOrFPElementsAttribute(.f32).init(result_type, data.constSlice()).as(mlir.Attribute).?, .f32 => mlir.DenseIntOrFPElementsAttribute(.f32).init(result_type, data.constSlice()).as(mlir.Attribute).?,
.f64 => mlir.DenseIntOrFPElementsAttribute(.f64).init(result_type, data.constSlice()).as(mlir.Attribute).?, .f64 => mlir.DenseIntOrFPElementsAttribute(.f64).init(result_type, data.constSlice()).as(mlir.Attribute).?,
inline else => |tag| meta.panic("Unsupported data type: {any}", .{tag}), inline else => |tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}),
}; };
} }
}; };

View File

@ -1,32 +1,31 @@
const asynk = @import("async");
const builtin = @import("builtin"); const builtin = @import("builtin");
const std = @import("std"); const dialect = @import("mlir/dialects");
const protobuf = @import("io/protobuf");
const runfiles = @import("runfiles"); const runfiles = @import("runfiles");
const std = @import("std");
const stdx = @import("stdx");
const xla_pb = @import("//xla:xla_proto"); const xla_pb = @import("//xla:xla_proto");
const meta = @import("meta.zig"); const meta = @import("meta.zig");
const mlir = @import("mlir.zig"); const mlir = @import("mlir.zig");
const ops = @import("ops.zig"); const ops = @import("ops.zig");
const pjrt = @import("pjrtx.zig"); const pjrt = @import("pjrtx.zig");
const protobuf = @import("io/protobuf");
const asynk = @import("async");
const aio = @import("aio.zig"); const aio = @import("aio.zig");
const dialect = @import("mlir/dialects"); const Buffer = @import("buffer.zig").Buffer;
const Bufferized = @import("tensor.zig").Bufferized;
const assert = std.debug.assert;
const Context = @import("context.zig").Context; const Context = @import("context.zig").Context;
const Location = mlir.Location; const Location = mlir.Location;
const Platform = @import("platform.zig").Platform; const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape;
const ShapeOf = @import("tensor.zig").ShapeOf;
const Target = @import("platform.zig").Target; const Target = @import("platform.zig").Target;
const Tensor = @import("tensor.zig").Tensor; const Tensor = @import("tensor.zig").Tensor;
const ShapeOf = @import("tensor.zig").ShapeOf;
const Shape = @import("shape.zig").Shape;
const Buffer = @import("buffer.zig").Buffer;
const Bufferized = @import("tensor.zig").Bufferized;
const Tracer = @import("tools/tracer.zig").Tracer; const Tracer = @import("tools/tracer.zig").Tracer;
const log = std.log.scoped(.zml_module); const assert = std.debug.assert;
const log = std.log.scoped(.@"zml/module");
test { test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(@This());
@ -101,7 +100,7 @@ pub const CompilationContext = struct {
} }
pub fn deactivate(self: *CompilationContext) void { pub fn deactivate(self: *CompilationContext) void {
std.debug.assert(_current != null and _current.? == self); assert(_current != null and _current.? == self);
_current = self._previous; _current = self._previous;
self._previous = null; self._previous = null;
} }
@ -163,7 +162,7 @@ pub const CompilationContext = struct {
// So we create a copy of the arguments, and replace values // So we create a copy of the arguments, and replace values
// by the block arguments. // by the block arguments.
var blk_args = args; var blk_args = args;
assert(assignBlockArguments(&blk_args, block, 0) == N); std.debug.assert(assignBlockArguments(&blk_args, block, 0) == N);
const loc = self.mlirCtx().location(@src()); const loc = self.mlirCtx().location(@src());
const block_res = @call(.auto, func, S.blkArgs(blkctx, blk_args)); const block_res = @call(.auto, func, S.blkArgs(blkctx, blk_args));
@ -209,9 +208,9 @@ pub const CompilationContext = struct {
var input_shapes = try std.ArrayList(Shape).initCapacity(arena, tensor_count); var input_shapes = try std.ArrayList(Shape).initCapacity(arena, tensor_count);
meta.collect(Tensor.shape, {}, &input_shapes, model) catch unreachable; meta.collect(Tensor.shape, {}, &input_shapes, model) catch unreachable;
meta.internalAssert(input_shapes.items.len == model_tensor_count, "model has changed ?", .{}); stdx.debug.internalAssert(input_shapes.items.len == model_tensor_count, "model has changed ?", .{});
meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable; meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable;
meta.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{}); stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});
const input_types = try arena.alloc(mlir.Type, tensor_count); const input_types = try arena.alloc(mlir.Type, tensor_count);
for (input_types, input_shapes.items) |*t, sh| t.* = mlir.ext.mlirType(mlir_ctx, sh); for (input_types, input_shapes.items) |*t, sh| t.* = mlir.ext.mlirType(mlir_ctx, sh);
@ -311,7 +310,7 @@ pub const CompilationContext = struct {
// This will break the day we writer another attribute before donation. // This will break the day we writer another attribute before donation.
// When the time come, do a more fancy lookup here to check if an argument // When the time come, do a more fancy lookup here to check if an argument
// is donated twice. // is donated twice.
meta.assert(attributes[a].len == 0, "Donation error ! Argument {} has been donated twice ! To {} and to {}", .{ a, index, attributes[a].buffer[0] }); stdx.debug.assert(attributes[a].len == 0, "Donation error ! Argument {} has been donated twice ! To {} and to {}", .{ a, index, attributes[a].buffer[0] });
attributes[a].appendAssumeCapacity( attributes[a].appendAssumeCapacity(
mlir.NamedAttribute.init( mlir.NamedAttribute.init(
mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"), mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"),
@ -507,7 +506,7 @@ pub const CompilationContext = struct {
extractValues(args, values[function.n_model..]); extractValues(args, values[function.n_model..]);
const op = dialect.func.call(self.mlirCtx(), function.name, values, function.res_types, loc); const op = dialect.func.call(self.mlirCtx(), function.name, values, function.res_types, loc);
var res: meta.FnResult(func) = undefined; var res: stdx.meta.FnResult(func) = undefined;
assignResults(&res, function.res_shapes, op); assignResults(&res, function.res_shapes, op);
return res; return res;
} }
@ -531,7 +530,7 @@ pub const CompilationContext = struct {
const res = ctx.self._buffer_to_arg.getOrPutAssumeCapacity(tensor._id); const res = ctx.self._buffer_to_arg.getOrPutAssumeCapacity(tensor._id);
if (res.found_existing) { if (res.found_existing) {
std.debug.panic("Failed compilation because received two tensors arguments with the same ID: {} and {}({}).", .{ res.key_ptr.*, tensor, tensor._id }); stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {} and {}({}).", .{ res.key_ptr.*, tensor, tensor._id });
} else { } else {
res.value_ptr.* = .{ arg_value, .{ .arg = @intCast(ctx.index) } }; res.value_ptr.* = .{ arg_value, .{ .arg = @intCast(ctx.index) } };
} }
@ -677,9 +676,9 @@ fn fillBuffers(v: anytype, buffers: []const [*]*pjrt.Buffer, start: u32, len: u3
}; };
meta.visit((struct { meta.visit((struct {
fn cb(ctx: *LocalContext, buffer: *const Buffer) void { fn cb(ctx: *LocalContext, buffer: *const Buffer) void {
// meta.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index }); // stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index });
const model_sharding = ctx.buffers.len; const model_sharding = ctx.buffers.len;
meta.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len }); stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len });
for (buffer._shards.constSlice(), 0..) |shard, d| { for (buffer._shards.constSlice(), 0..) |shard, d| {
ctx.buffers[d][ctx.index] = shard; ctx.buffers[d][ctx.index] = shard;
} }
@ -718,7 +717,7 @@ pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjr
buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.buffer_shapes[i], shards.constSlice()); buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.buffer_shapes[i], shards.constSlice());
} }
}).cb, &local_ctx, v); }).cb, &local_ctx, v);
meta.internalAssert(local_ctx.index == expected_count, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index }); stdx.debug.internalAssert(local_ctx.index == expected_count, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index });
} }
/// Visit the given struct and assign op results to each tensor found. /// Visit the given struct and assign op results to each tensor found.
@ -761,6 +760,13 @@ const BaseExe = struct {
num_devices: u8, num_devices: u8,
/// Allocator backing result_buffer_shapes and deinit by ExeWithWeights /// Allocator backing result_buffer_shapes and deinit by ExeWithWeights
_allocator: std.heap.ArenaAllocator, _allocator: std.heap.ArenaAllocator,
pub fn serialize(self: BaseExe, writer: anytype) !void {
var executable = try self.exe.getExecutable(self.pjrt_api);
var serialize_result = try executable.serialize(self.platform.pjrt_api);
defer serialize_result.deinit();
try writer.writeAll(serialize_result.bytes);
}
}; };
/// Represents a ZML model, compiled into a PJRT executable. /// Represents a ZML model, compiled into a PJRT executable.
@ -779,6 +785,16 @@ pub fn Exe(comptime func: anytype) type {
pub fn prepare(self: Self, allocator: std.mem.Allocator, model: Bufferized(Signature.ModelT)) !ExeWithWeights(func) { pub fn prepare(self: Self, allocator: std.mem.Allocator, model: Bufferized(Signature.ModelT)) !ExeWithWeights(func) {
return ExeWithWeights(func).initFromModel(allocator, self.inner, model); return ExeWithWeights(func).initFromModel(allocator, self.inner, model);
} }
pub fn serialize(self: Self, writer: anytype) !void {
return try self.inner.serialize(writer);
}
// pub fn deserialize(allocator: std.mem.Allocator, platform: Platform, reader: anytype) !Self {
// const bytes = try reader.readToEndAlloc(allocator, max_pjrt_executable_size);
// defer allocator.free(bytes);
// return platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, bytes);
// }
}; };
} }
@ -906,7 +922,7 @@ fn compileInternal(
var timer = std.time.Timer.start() catch null; var timer = std.time.Timer.start() catch null;
const tensor_args = context.tensorFromShapes(ModuleSignature(func).ArgsT, arena, args); const tensor_args = context.tensorFromShapes(ModuleSignature(func).ArgsT, arena, args);
// Run in a dedicated thread because compilation relies on `threadlocal`. // Run in a dedicated thread because compilation relies on `threadlocal`.
const f = try asynk.callBlockingGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args }); const f = try asynk.callBlocking(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args });
context._module.getBody().appendOperation(f.mlir_fn); context._module.getBody().appendOperation(f.mlir_fn);
const sharding = context._platform.sharding(); const sharding = context._platform.sharding();
@ -927,7 +943,7 @@ fn compileInternal(
if (timer) |*t| { if (timer) |*t| {
const time_ms = @divFloor(t.lap(), std.time.ns_per_ms); const time_ms = @divFloor(t.lap(), std.time.ns_per_ms);
if (time_ms > 1000) log.info("Compilation took {d:.3}s", .{meta.divFloat(f32, time_ms, 1000)}); if (time_ms > 1000) log.info("Compilation took {d:.3}s", .{stdx.math.divFloor(f32, time_ms, 1000)});
} }
var arena_state_exe = std.heap.ArenaAllocator.init(allocator); var arena_state_exe = std.heap.ArenaAllocator.init(allocator);
@ -945,12 +961,7 @@ fn compileInternal(
}; };
} }
/// Compiles a Model struct with the given configuration and shapes, for the given platform. pub fn load(
/// The steps are:
/// * lookup at tensors available in the store and create a `model: Model` struct with them
/// * call `model.init(init_args)` to fields of the model that aren't Tensor, ie hyperparemeters/config
/// * generate MLIR by calling `model.forward` with tensor of the given shapes and other arguments
pub fn compile(
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
comptime Model: type, comptime Model: type,
init_args: anytype, init_args: anytype,
@ -973,33 +984,62 @@ pub fn compile(
return compileModel(allocator, model, func, args_shapes, platform); return compileModel(allocator, model, func, args_shapes, platform);
} }
/// Compiles a Model struct with the given configuration and shapes, for the given platform.
/// The steps are:
/// * lookup at tensors available in the store and create a `model: Model` struct with them
/// * call `model.init(init_args)` to fields of the model that aren't Tensor, ie hyperparemeters/config
/// * generate MLIR by calling `model.forward` with tensor of the given shapes and other arguments
pub fn compile(
allocator: std.mem.Allocator,
comptime func: anytype,
init_args: anytype,
args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
buffer_store: aio.BufferStore,
platform: Platform,
) !Exe(func) {
const ModelT = ModuleSignature(func).ModelT;
var arena_state = std.heap.ArenaAllocator.init(allocator);
defer arena_state.deinit();
const arena = arena_state.allocator();
var model = try aio.populateModel(ModelT, arena, buffer_store);
// If the Model has a "init" function, call it with the given parameters.
if (@hasDecl(ModelT, "init")) {
// TODO(Corentin,@Improvement): Add a warning/error if there is no init function but init_args is non-void.
@call(.auto, ModelT.init, .{@as(*ModelT, &model)} ++ init_args);
}
return compileModel(allocator, func, model, args_shapes, platform);
}
/// Compiles a Model struct with the given configuration and shapes, for the given platform. /// Compiles a Model struct with the given configuration and shapes, for the given platform.
/// Generate MLIR by calling `model.forward` with tensor of the given shapes and other arguments /// Generate MLIR by calling `model.forward` with tensor of the given shapes and other arguments
pub fn compileModel( pub fn compileModel(
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
model: anytype, comptime func: anytype,
comptime func: @TypeOf(.literal), model: ModuleSignature(func).ModelT,
args_shapes: ShapeOf(ModuleSignature(@field(@TypeOf(model), @tagName(func))).ArgsT), args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
platform: Platform, platform: Platform,
) !Exe(@field(@TypeOf(model), @tagName(func))) { ) !Exe(func) {
const Model = @TypeOf(model); const ModelT = ModuleSignature(func).ModelT;
const name = @typeName(Model) ++ "." ++ @tagName(func); const name = @typeName(ModelT) ++ ".forward";
log.info("Compiling {s} with {}", .{ name, args_shapes }); log.info("Compiling {s} with {}", .{ name, args_shapes });
var context = try CompilationContext.init(allocator, name, platform); var context = try CompilationContext.init(allocator, name, platform);
defer context.deinit(); defer context.deinit();
const raw_module = try compileInternal(allocator, &context, @field(Model, @tagName(func)), model, args_shapes); const raw_module = try compileInternal(allocator, &context, func, model, args_shapes);
return Exe(@field(Model, @tagName(func))){ .inner = raw_module }; return .{ .inner = raw_module };
} }
/// Compiles a function with the given configuration and shapes, for the given platform. /// Compiles a function with the given configuration and shapes, for the given platform.
/// Generate MLIR by calling the given function with tensor of the given shapes. /// Generate MLIR by calling the given function with tensor of the given shapes.
pub fn compileFn( pub fn compileFn(
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
func: anytype, comptime func: anytype,
args: ShapeOf(meta.FnParams(func)), args: ShapeOf(stdx.meta.FnArgs(func)),
platform: Platform, platform: Platform,
) !ExeWithWeights(FnWithVoidArg(func)) { ) !ExeWithWeights(FnWithVoidArg(func)) {
const name = @typeName(@TypeOf(func)); const name = @typeName(@TypeOf(func));
@ -1008,7 +1048,7 @@ pub fn compileFn(
const Local = struct { const Local = struct {
// This is the function we will actually compile. // This is the function we will actually compile.
pub fn forward(_: void, inner_args: meta.FnParams(func)) meta.FnResult(func) { pub fn forward(_: void, inner_args: stdx.meta.FnArgs(func)) stdx.meta.FnResult(func) {
return @call(.auto, func, inner_args); return @call(.auto, func, inner_args);
} }
}; };
@ -1019,10 +1059,10 @@ pub fn compileFn(
return try ExeWithWeights(FnWithVoidArg(func)).initFromModel(allocator, raw_module, void_model); return try ExeWithWeights(FnWithVoidArg(func)).initFromModel(allocator, raw_module, void_model);
} }
fn FnWithVoidArg(func: anytype) type { fn FnWithVoidArg(comptime func: anytype) type {
const fn_info = @typeInfo(@TypeOf(func)).Fn; const fn_info = @typeInfo(@TypeOf(func)).Fn;
const void_param = std.builtin.Type.Fn.Param{ .is_generic = false, .is_noalias = false, .type = void }; const void_param = std.builtin.Type.Fn.Param{ .is_generic = false, .is_noalias = false, .type = void };
meta.assertComptime(!fn_info.is_generic, "Can't do reflection on generic function: {}", .{@TypeOf(func)}); stdx.debug.assertComptime(!fn_info.is_generic, "Can't do reflection on generic function: {}", .{@TypeOf(func)});
return @Type(.{ .Fn = .{ return @Type(.{ .Fn = .{
.calling_convention = fn_info.calling_convention, .calling_convention = fn_info.calling_convention,
.is_generic = false, .is_generic = false,
@ -1268,10 +1308,10 @@ pub fn ModuleSignature(comptime func: anytype) Sign {
const FuncT = if (@TypeOf(func) == type) func else @TypeOf(func); const FuncT = if (@TypeOf(func) == type) func else @TypeOf(func);
return .{ return .{
.FuncT = FuncT, .FuncT = FuncT,
.ModelT = @typeInfo(FuncT).Fn.params[0].type orelse @compileError("cannot create,ModuleSignature for function with an 'anytype' parameter"), .ModelT = @typeInfo(FuncT).Fn.params[0].type orelse @compileError("cannot create ModuleSignature for function with an 'anytype' parameter"),
.ArgsT = blk: { .ArgsT = blk: {
const function_info = @typeInfo(FuncT); const function_info = @typeInfo(FuncT);
if (function_info.Fn.params[1..].len == 0) { if (function_info.Fn.params.len < 2) {
break :blk @TypeOf(.{}); break :blk @TypeOf(.{});
} }

View File

@ -1,20 +1,20 @@
//! Common layer definition and functions for Neural Networks (NN) //! Common layer definition and functions for Neural Networks (NN)
const std = @import("std"); const std = @import("std");
const assert = std.debug.assert; const stdx = @import("stdx");
const testing = std.testing;
const zml = @import("zml.zig"); const cuda = @import("nn/cuda.zig");
const meta = @import("meta.zig");
const helpers = @import("helpers.zig"); const helpers = @import("helpers.zig");
const meta = @import("meta.zig");
const ops = @import("ops.zig"); const ops = @import("ops.zig");
const zml = @import("zml.zig");
const DataType = @import("dtype.zig").DataType; const DataType = @import("dtype.zig").DataType;
const Shape = @import("shape.zig").Shape; const Shape = @import("shape.zig").Shape;
const Tensor = @import("tensor.zig").Tensor; const Tensor = @import("tensor.zig").Tensor;
const log = std.log.scoped(.zml_tensor); const assert = std.debug.assert;
const log = std.log.scoped(.@"zml/tensor");
const cuda = @import("nn/cuda.zig"); const testing = std.testing;
test { test {
_ = cuda; _ = cuda;
@ -41,8 +41,8 @@ pub const TokenEmbedding = struct {
weight: Tensor, weight: Tensor,
pub fn forward(self: TokenEmbedding, idx: Tensor) Tensor { pub fn forward(self: TokenEmbedding, idx: Tensor) Tensor {
meta.assert(idx.dtype().isInteger(), "TokenEmbedding expects an integer input, received: {}", .{idx}); stdx.debug.assert(idx.dtype().isInteger(), "TokenEmbedding expects an integer input, received: {}", .{idx});
meta.assert(self.weight.rank() == 2, "TokenEmbedding expects it's weight Tensor to be a 2D matrix, got {}", .{self.weight}); stdx.debug.assert(self.weight.rank() == 2, "TokenEmbedding expects it's weight Tensor to be a 2D matrix, got {}", .{self.weight});
return self.weight.gatherValues(0, idx, .{}); return self.weight.gatherValues(0, idx, .{});
} }
}; };
@ -159,7 +159,7 @@ pub const CosSin = [2]Tensor;
/// See: https://paperswithcode.com/method/rope /// See: https://paperswithcode.com/method/rope
pub fn rope(x: Tensor, cos_sin_cache: CosSin, opts: RopeOpts) Tensor { pub fn rope(x: Tensor, cos_sin_cache: CosSin, opts: RopeOpts) Tensor {
const cos, const sin = cos_sin_cache; const cos, const sin = cos_sin_cache;
meta.assert(x.dim(-1) == 2 * cos.dim(-1), "Couldn't compute rope({}, {}, {})", .{ x, cos, sin }); stdx.debug.assert(x.dim(-1) == 2 * cos.dim(-1), "Couldn't compute rope({}, {}, {})", .{ x, cos, sin });
// broadcast cos / sin to .{ batch, .seq, .half_dim } // broadcast cos / sin to .{ batch, .seq, .half_dim }
const x_real, const x_imag = splitRealImg(x, opts.impl); const x_real, const x_imag = splitRealImg(x, opts.impl);
const has_tags = cos.shape().tag(0) != Shape.TagUnknown; const has_tags = cos.shape().tag(0) != Shape.TagUnknown;
@ -178,9 +178,9 @@ pub fn rope(x: Tensor, cos_sin_cache: CosSin, opts: RopeOpts) Tensor {
pub fn ropeCosSin(sh: anytype, dtype: DataType, opts: RopeOpts) CosSin { pub fn ropeCosSin(sh: anytype, dtype: DataType, opts: RopeOpts) CosSin {
const shape = Shape.init(sh, dtype); const shape = Shape.init(sh, dtype);
meta.assert(shape.rank() == 2, "ropeCosSin({}) shape need to exactly have 2 axes", .{shape}); stdx.debug.assert(shape.rank() == 2, "ropeCosSin({}) shape need to exactly have 2 axes", .{shape});
const seq_len, const head_dim = .{ shape.dim(0), shape.dim(1) }; const seq_len, const head_dim = .{ shape.dim(0), shape.dim(1) };
meta.assert(@mod(head_dim, 2) == 0, "ropeCosSin requires an even head_dim, got {}", .{head_dim}); stdx.debug.assert(@mod(head_dim, 2) == 0, "ropeCosSin requires an even head_dim, got {}", .{head_dim});
// compute sin and cos in f32 before downcasting to x type. // compute sin and cos in f32 before downcasting to x type.
const inv_freq = invFreq(head_dim, opts.freq_base, .f32); const inv_freq = invFreq(head_dim, opts.freq_base, .f32);
@ -364,8 +364,8 @@ pub fn upsample(
) Tensor { ) Tensor {
// TODO(james): make `nearest` compatible with resizeBilinear and resizeBicubic, and wrap them here. // TODO(james): make `nearest` compatible with resizeBilinear and resizeBicubic, and wrap them here.
// resize* have API which are more explicit, this assume you want to scale the N-2 last axes. // resize* have API which are more explicit, this assume you want to scale the N-2 last axes.
meta.assert(3 <= input.rank() and input.rank() <= 5, "upsample is only implemented for (3,4,5)-D tensors, received {}", .{input}); stdx.debug.assert(3 <= input.rank() and input.rank() <= 5, "upsample is only implemented for (3,4,5)-D tensors, received {}", .{input});
meta.assert(opts.scale_factor.len == 1 or opts.scale_factor.len == input.rank() - 2, "scale factors", .{}); stdx.debug.assert(opts.scale_factor.len == 1 or opts.scale_factor.len == input.rank() - 2, "scale factors", .{});
return switch (opts.mode) { return switch (opts.mode) {
.nearest => { .nearest => {
var scale_factors: [3]f64 = undefined; var scale_factors: [3]f64 = undefined;
@ -398,7 +398,7 @@ pub fn nearest(input: Tensor, scale_factor: []const f64) Tensor {
var res = input; var res = input;
for (spatial_dims) |d| { for (spatial_dims) |d| {
const n = out_shape.dim(d); const n = out_shape.dim(d);
const ratio = meta.divFloat(f32, input.dim(d), n); const ratio = stdx.math.divFloor(f32, input.dim(d), n);
const offsets = Tensor.arange(.{ .end = n }, .f32).addConstant(0.5).scale(ratio).floor().convert(.i32); const offsets = Tensor.arange(.{ .end = n }, .f32).addConstant(0.5).scale(ratio).floor().convert(.i32);
res = res.gatherValues(d, offsets, .{ .indices_are_sorted = true }); res = res.gatherValues(d, offsets, .{ .indices_are_sorted = true });
} }
@ -576,7 +576,7 @@ pub fn resizeLinear1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Te
const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype(); const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype();
const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype); const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype);
const ratio = og_len.convert(dtype).scale(meta.divFloat(f32, 1, new_len)); const ratio = og_len.convert(dtype).scale(stdx.math.divFloor(f32, 1, new_len));
const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio); const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio);
const left = scaled.floor(); const left = scaled.floor();
const right = left.addConstant(1); const right = left.addConstant(1);
@ -638,7 +638,7 @@ pub fn resizeCubic1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Ten
const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype(); const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype();
const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype); const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype);
const ratio = og_len.convert(dtype).scale(meta.divFloat(f32, 1, new_len)); const ratio = og_len.convert(dtype).scale(stdx.math.divFloor(f32, 1, new_len));
const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio); const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio);
const t = scaled.sub(scaled.floor()); const t = scaled.sub(scaled.floor());
const pos = Tensor.stack(&.{ const pos = Tensor.stack(&.{
@ -693,11 +693,11 @@ pub fn causalAttnMask(
attn_window_len: ?u32, attn_window_len: ?u32,
) Tensor { ) Tensor {
const attn_shape = Shape.init(attn_shape_, dtype); const attn_shape = Shape.init(attn_shape_, dtype);
meta.assert(attn_shape.rank() == 2, "causalAttnMask({}) shape need to be exactly 2 axes", .{attn_shape}); stdx.debug.assert(attn_shape.rank() == 2, "causalAttnMask({}) shape need to be exactly 2 axes", .{attn_shape});
const qlen = attn_shape.dim(-2); const qlen = attn_shape.dim(-2);
const q_idx = Tensor.iota(attn_shape, .i32, -2); const q_idx = Tensor.iota(attn_shape, -2);
const klen = attn_shape.dim(-1); const klen = attn_shape.dim(-1);
const k_idx = Tensor.iota(attn_shape, .i32, -1); const k_idx = Tensor.iota(attn_shape, -1);
// all elements > main diagonal must be 0 // all elements > main diagonal must be 0
// (q_idx - window_len < k_idx <= q_idx) // (q_idx - window_len < k_idx <= q_idx)
@ -748,16 +748,16 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
const err_template = "sdpa(q: {}, k: {}, v: {}, attn: {?}) is invalid ! "; const err_template = "sdpa(q: {}, k: {}, v: {}, attn: {?}) is invalid ! ";
const err_args = .{ q, k, v, opts.attn_mask }; const err_args = .{ q, k, v, opts.attn_mask };
meta.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ "q is missing tags {{.h, .q, .hd}}", err_args); stdx.debug.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ "q is missing tags {{.h, .q, .hd}}", err_args);
meta.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args); stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args);
meta.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "v is missing tags {{.h, .k, .hd}}", err_args); stdx.debug.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "v is missing tags {{.h, .k, .hd}}", err_args);
if (opts.allow_cudnn and cuda.canUseCudnnSdpa(q.dim(.hd), q.dtype())) { if (opts.allow_cudnn and cuda.canUseCudnnSdpa(q.dim(.hd), q.dtype())) {
return cuda.sdpa(q, k, v, opts); return cuda.sdpa(q, k, v, opts);
} }
if (q.dim(.h) != k.dim(.h)) { if (q.dim(.h) != k.dim(.h)) {
meta.assert(@mod(q.dim(.h), k.dim(.h)) == 0, err_template ++ "Different number of heads for keys and queries, but can't repeat keys.", err_args); stdx.debug.assert(@mod(q.dim(.h), k.dim(.h)) == 0, err_template ++ "Different number of heads for keys and queries, but can't repeat keys.", err_args);
// Note: we don't try to repeat queries. // Note: we don't try to repeat queries.
// Repeating keys is the interesting optimisation cause it reduces KV cache memory usage. // Repeating keys is the interesting optimisation cause it reduces KV cache memory usage.
const num_rep: u63 = @intCast(@divExact(q.dim(.h), k.dim(.h))); const num_rep: u63 = @intCast(@divExact(q.dim(.h), k.dim(.h)));
@ -766,7 +766,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
const attn_mask = if (opts.attn_mask) |m| m else null; const attn_mask = if (opts.attn_mask) |m| m else null;
const dims = helpers.collectDims(.{ .h, .q, .k, .hd }, &.{ q, k, v, attn_mask }, .strict) catch { const dims = helpers.collectDims(.{ .h, .q, .k, .hd }, &.{ q, k, v, attn_mask }, .strict) catch {
meta.panic(err_template ++ "Inputs have incompatible shapes.", err_args); stdx.debug.panic(err_template ++ "Inputs have incompatible shapes.", err_args);
}; };
const sqrtHeadDim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(dims.hd))); const sqrtHeadDim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(dims.hd)));
const scale_logit = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, k.dtype()); const scale_logit = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, k.dtype());

View File

@ -1,11 +1,13 @@
const std = @import("std"); const std = @import("std");
const mlir = @import("mlir.zig"); const stdx = @import("stdx");
const buffer = @import("buffer.zig");
const helpers = @import("helpers.zig"); const helpers = @import("helpers.zig");
const module = @import("module.zig");
const meta = @import("meta.zig"); const meta = @import("meta.zig");
const mlir = @import("mlir.zig");
const module = @import("module.zig");
const Buffer = @import("buffer.zig").Buffer; const Buffer = buffer.Buffer;
const CompilationContext = module.CompilationContext; const CompilationContext = module.CompilationContext;
const Context = @import("context.zig").Context; const Context = @import("context.zig").Context;
const Data = @import("dtype.zig").Data; const Data = @import("dtype.zig").Data;
@ -20,14 +22,14 @@ const dialect = struct {
}; };
const assert = std.debug.assert; const assert = std.debug.assert;
const log = std.log.scoped(.zml); const log = std.log.scoped(.@"zml/tensor");
test { test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(@This());
} }
/// Generate an MLIR call to the given member function with the given tensors. /// Generate an MLIR call to the given member function with the given tensors.
pub fn call(self: anytype, comptime func: meta.DeclEnum(@TypeOf(self)), args: anytype) @TypeOf(@call(.auto, @field(meta.UnwrapPtr(@TypeOf(self)), @tagName(func)), .{self} ++ args)) { pub fn call(self: anytype, comptime func: stdx.meta.DeclEnum(@TypeOf(self)), args: anytype) @TypeOf(@call(.auto, @field(stdx.meta.UnwrapPtr(@TypeOf(self)), @tagName(func)), .{self} ++ args)) {
// TODO: this should use `self.getContext().callFunc(self, args)` // TODO: this should use `self.getContext().callFunc(self, args)`
return @call(.auto, @field(@TypeOf(self), @tagName(func)), .{self} ++ args); return @call(.auto, @field(@TypeOf(self), @tagName(func)), .{self} ++ args);
@ -121,8 +123,8 @@ test "simple while" {
pub fn reduce( pub fn reduce(
comptime body_fn: anytype, comptime body_fn: anytype,
inputs: meta.FnParam(body_fn, 0), inputs: stdx.meta.FnParam(body_fn, 0),
inits: meta.FnParam(body_fn, 0), inits: stdx.meta.FnParam(body_fn, 0),
axes: []const i64, axes: []const i64,
) BlockSignNoCtx(body_fn).Return { ) BlockSignNoCtx(body_fn).Return {
// TODO: actualAxes // TODO: actualAxes
@ -155,7 +157,7 @@ pub fn reduce(
// `stablehlo.reduce` drops axes. We want to avoid that to propagate tags. // `stablehlo.reduce` drops axes. We want to avoid that to propagate tags.
// So we need to broadcast the output of `stablehlo.reduce` to the input shapes. // So we need to broadcast the output of `stablehlo.reduce` to the input shapes.
// To that order, we initialize `result` to `inputs`, then we use meta.visit, // To that order, we initialize `result` to `inputs`, then we use stdx.meta.visit,
// to find the correct mlir.Value, but we first broadcast before creating the final // to find the correct mlir.Value, but we first broadcast before creating the final
// Tensor struct. // Tensor struct.
var broadcasting_axes: std.BoundedArray(i64, Tensor.MAX_RANK) = .{}; var broadcasting_axes: std.BoundedArray(i64, Tensor.MAX_RANK) = .{};
@ -217,10 +219,10 @@ pub const ReduceWindowOpts = struct {
pub fn reduceWindow( pub fn reduceWindow(
comptime body_fn: anytype, comptime body_fn: anytype,
inputs: meta.FnParam(body_fn, 0), inputs: stdx.meta.FnParam(body_fn, 0),
inits: meta.FnParam(body_fn, 0), inits: stdx.meta.FnParam(body_fn, 0),
opts: ReduceWindowOpts, opts: ReduceWindowOpts,
) meta.FnResult(body_fn) { ) stdx.meta.FnResult(body_fn) {
const BodyS = comptime BlockSignNoCtx(body_fn); const BodyS = comptime BlockSignNoCtx(body_fn);
comptime { comptime {
if (BodyS.Return != @TypeOf(inputs)) @compileError("reduce body function need to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn)); if (BodyS.Return != @TypeOf(inputs)) @compileError("reduce body function need to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn));
@ -263,7 +265,7 @@ pub fn reduceWindow(
pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_: anytype) BlockSign(func).Return { pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_: anytype) BlockSign(func).Return {
const num_steps: u32, const step_tag = blk: { const num_steps: u32, const step_tag = blk: {
const dims, const tags = Shape.parseDimensions(num_steps_); const dims, const tags = Shape.parseDimensions(num_steps_);
meta.assert(dims.len == 1, "zml.for_ only supports one num_step, Received: {any}", .{num_steps_}); stdx.debug.assert(dims.len == 1, "zml.for_ only supports one num_step, Received: {any}", .{num_steps_});
break :blk .{ @intCast(dims.get(0)), tags.get(0) }; break :blk .{ @intCast(dims.get(0)), tags.get(0) };
}; };
const S = comptime BlockSign(func); const S = comptime BlockSign(func);
@ -290,7 +292,7 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_:
} }
fn updateResBuffer(inputs: []const Tensor, idx: Tensor) Tensor { fn updateResBuffer(inputs: []const Tensor, idx: Tensor) Tensor {
meta.internalAssert(inputs.len == 2, "too many tensors", .{}); stdx.debug.internalAssert(inputs.len == 2, "too many tensors", .{});
const res, const step_res = inputs[0..2].*; const res, const step_res = inputs[0..2].*;
return res.dynamicUpdateSlice1d(step_res.insertAxes(0, .{._}), 0, idx); return res.dynamicUpdateSlice1d(step_res.insertAxes(0, .{._}), 0, idx);
} }

View File

@ -1,22 +1,21 @@
const builtin = @import("builtin");
const std = @import("std");
const asynk = @import("async"); const asynk = @import("async");
const builtin = @import("builtin");
const dialects = @import("mlir/dialects");
const mlir = @import("mlir"); const mlir = @import("mlir");
const pjrt = @import("pjrt"); const pjrt = @import("pjrt");
const std = @import("std");
const stdx = @import("stdx");
const dtype = @import("dtype.zig"); const dtype = @import("dtype.zig");
const meta = @import("meta.zig"); const meta = @import("meta.zig");
const dialects = @import("mlir/dialects");
pub const Profiler = pjrt.Profiler;
pub const ApiError = pjrt.ApiError;
pub const ErrorCode = pjrt.ErrorCode;
const Target = @import("platform.zig").Target; const Target = @import("platform.zig").Target;
const log = std.log.scoped(.zml); const log = std.log.scoped(.zml);
pub const Profiler = pjrt.Profiler;
pub const ApiError = pjrt.ApiError;
pub const ErrorCode = pjrt.ErrorCode;
pub const Buffer = pjrt.Buffer; pub const Buffer = pjrt.Buffer;
pub const BufferType = pjrt.BufferType; pub const BufferType = pjrt.BufferType;
pub const Device = pjrt.Device; pub const Device = pjrt.Device;
@ -181,14 +180,16 @@ pub const LoadedExecutable = opaque {
return self.inner().getAddressableDevices(api); return self.inner().getAddressableDevices(api);
} }
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: struct { pub const ExecuteArgs = struct {
arguments: []const [*]const *const Buffer, arguments: []const [*]const *const Buffer,
num_args: usize, num_args: usize,
results: []const [*]*Buffer, results: []const [*]*Buffer,
events: []?*Event, events: []?*Event,
non_donatable_input_indices: []const i64 = &.{}, non_donatable_input_indices: []const i64 = &.{},
}) ExecuteError!void { };
try asynk.callBlocking(pjrt.LoadedExecutable.execute, .{ self.inner(), api, .{
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ExecuteError!void {
try asynk.callBlocking(pjrt.LoadedExecutable.execute, .{ self.inner(), api, pjrt.LoadedExecutable.ExecuteArgs{
.num_args = args.num_args, .num_args = args.num_args,
.arguments = @ptrCast(args.arguments), .arguments = @ptrCast(args.arguments),
.results = @ptrCast(args.results), .results = @ptrCast(args.results),

View File

@ -1,28 +1,18 @@
const builtin = @import("builtin");
const std = @import("std");
const asynk = @import("async"); const asynk = @import("async");
const builtin = @import("builtin");
const runtimes = @import("runtimes"); const runtimes = @import("runtimes");
const std = @import("std");
const stdx = @import("stdx");
const meta = @import("meta.zig"); const meta = @import("meta.zig");
const module = @import("module.zig"); const module = @import("module.zig");
const pjrt = @import("pjrtx.zig"); const pjrt = @import("pjrtx.zig");
const log = std.log.scoped(.zml); const log = std.log.scoped(.zml);
pub const Target = runtimes.Platform; pub const Target = runtimes.Platform;
pub const available_targets = switch (builtin.os.tag) { pub const available_targets = std.enums.values(Target);
.macos => [_]Target{
.cpu,
},
.linux => [_]Target{
.cpu,
.cuda,
.rocm,
.tpu,
},
else => [_]Target{},
};
pub const CompilationOptions = struct { pub const CompilationOptions = struct {
xla_dump_to: ?[]const u8 = null, xla_dump_to: ?[]const u8 = null,

View File

@ -1,11 +1,12 @@
const builtin = @import("builtin"); const builtin = @import("builtin");
const std = @import("std"); const std = @import("std");
const stdx = @import("stdx");
const testing = std.testing; const testing = std.testing;
const meta = @import("meta.zig");
const DataType = @import("dtype.zig").DataType; const DataType = @import("dtype.zig").DataType;
const EnumLiteral = @TypeOf(.enum_literal); const EnumLiteral = @TypeOf(.enum_literal);
const log = std.log.scoped(.shape); const log = std.log.scoped(.shape);
test { test {
@ -39,7 +40,7 @@ pub const Shape = struct {
return .{ v._dims, v._tags }; return .{ v._dims, v._tags };
} }
if (comptime meta.isSliceOfAny(T, meta.isInteger)) { if (comptime stdx.meta.isSliceOfAny(T, stdx.meta.isInteger)) {
var dims_ = DimsArray.init(0) catch unreachable; var dims_ = DimsArray.init(0) catch unreachable;
var tags_ = TagsArray.init(0) catch unreachable; var tags_ = TagsArray.init(0) catch unreachable;
for (v) |d| { for (v) |d| {
@ -49,19 +50,19 @@ pub const Shape = struct {
return .{ dims_, tags_ }; return .{ dims_, tags_ };
} }
if (comptime meta.isStruct(T)) { if (comptime stdx.meta.isStruct(T)) {
var dims_: DimsArray = .{}; var dims_: DimsArray = .{};
var tags_: TagsArray = .{}; var tags_: TagsArray = .{};
inline for (std.meta.fields(T)) |field| { inline for (std.meta.fields(T)) |field| {
const fv = @field(v, field.name); const fv = @field(v, field.name);
if (comptime meta.isInteger(field.type)) { if (comptime stdx.meta.isInteger(field.type)) {
dims_.appendAssumeCapacity(@intCast(fv)); dims_.appendAssumeCapacity(@intCast(fv));
} else if (comptime isAutoDim(fv)) { } else if (comptime isAutoDim(fv)) {
dims_.appendAssumeCapacity(-1); dims_.appendAssumeCapacity(-1);
} else { } else {
meta.compileError("Field {s} should be an integer or an auto dimension", .{field.name}); stdx.meta.compileError("Field {s} should be an integer or an auto dimension", .{field.name});
} }
if (comptime meta.isTuple(T)) { if (comptime stdx.meta.isTuple(T)) {
tags_.appendAssumeCapacity(TagUnknown); tags_.appendAssumeCapacity(TagUnknown);
} else { } else {
tags_.appendAssumeCapacity(toTag(field)); tags_.appendAssumeCapacity(toTag(field));
@ -71,7 +72,7 @@ pub const Shape = struct {
return .{ dims_, tags_ }; return .{ dims_, tags_ };
} }
meta.compileError("expected a dimension tuple eg '.{{ .a = 10, .b = 20}}' or '.{{ 10, 20 }}', got {}", .{T}); stdx.meta.compileError("expected a dimension tuple eg '.{{ .a = 10, .b = 20}}' or '.{{ 10, 20 }}', got {}", .{T});
} }
test parseDimensions { test parseDimensions {
@ -92,7 +93,7 @@ pub const Shape = struct {
var axes_ = AxesArray.init(0) catch unreachable; var axes_ = AxesArray.init(0) catch unreachable;
var tags_ = TagsArray.init(0) catch unreachable; var tags_ = TagsArray.init(0) catch unreachable;
if (comptime meta.isSliceOfAny(T, isAxisConvertible)) { if (comptime stdx.meta.isSliceOfAny(T, isAxisConvertible)) {
for (v) |d| { for (v) |d| {
axes_.appendAssumeCapacity(self.axis(d)); axes_.appendAssumeCapacity(self.axis(d));
tags_.appendAssumeCapacity(self.tag(d)); tags_.appendAssumeCapacity(self.tag(d));
@ -100,7 +101,7 @@ pub const Shape = struct {
return .{ axes_, tags_ }; return .{ axes_, tags_ };
} }
if (comptime meta.isTupleOfAny(T, isAxisConvertible)) { if (comptime stdx.meta.isTupleOfAny(T, isAxisConvertible)) {
inline for (std.meta.fields(T)) |field| { inline for (std.meta.fields(T)) |field| {
axes_.appendAssumeCapacity(self.axis(@field(v, field.name))); axes_.appendAssumeCapacity(self.axis(@field(v, field.name)));
tags_.appendAssumeCapacity(self.tag(@field(v, field.name))); tags_.appendAssumeCapacity(self.tag(@field(v, field.name)));
@ -108,12 +109,12 @@ pub const Shape = struct {
return .{ axes_, tags_ }; return .{ axes_, tags_ };
} }
meta.compileError("Wrong type, got {}. Expected .{{.a, .b}}", .{T}); stdx.meta.compileError("Wrong type, got {}. Expected .{{.a, .b}}", .{T});
} }
pub fn parseTags(v: anytype) TagsArray { pub fn parseTags(v: anytype) TagsArray {
const T = @TypeOf(v); const T = @TypeOf(v);
meta.assertComptime(meta.isTupleOf(T, EnumLiteral), "Wrong type, got {}. Expected .{{ .a, .b }}", .{T}); stdx.debug.assertComptime(stdx.meta.isTupleOf(T, EnumLiteral), "Wrong type, got {}. Expected .{{ .a, .b }}", .{T});
var tags_ = TagsArray.init(0) catch unreachable; var tags_ = TagsArray.init(0) catch unreachable;
inline for (v) |field| { inline for (v) |field| {
tags_.appendAssumeCapacity(toTag(field)); tags_.appendAssumeCapacity(toTag(field));
@ -135,7 +136,7 @@ pub const Shape = struct {
var res: Shape = .{ ._dtype = dt }; var res: Shape = .{ ._dtype = dt };
for (0..rank_) |i| { for (0..rank_) |i| {
res._dims.append(@intCast(i)) catch { res._dims.append(@intCast(i)) catch {
meta.panic("Too many dimensions! Max: {d}, passed: {d}", .{ res._dims.capacity(), rank_ }); stdx.debug.panic("Too many dimensions! Max: {d}, passed: {d}", .{ res._dims.capacity(), rank_ });
}; };
res._tags.append(TagUnknown) catch unreachable; res._tags.append(TagUnknown) catch unreachable;
} }
@ -162,7 +163,7 @@ pub const Shape = struct {
} }
fn isAxisConvertible(comptime T: type) bool { fn isAxisConvertible(comptime T: type) bool {
return meta.isInteger(T) or isTagConvertible(T); return stdx.meta.isInteger(T) or isTagConvertible(T);
} }
fn isTagConvertible(comptime T: type) bool { fn isTagConvertible(comptime T: type) bool {
@ -180,12 +181,12 @@ pub const Shape = struct {
EnumLiteral => @tagName(v).ptr, EnumLiteral => @tagName(v).ptr,
std.builtin.Type.StructField => v.name.ptr, std.builtin.Type.StructField => v.name.ptr,
Tag => v, Tag => v,
else => meta.compileError("Value should be an EnumLiteral, a Shape.Tag or a StructField, got {}", .{T}), else => stdx.meta.compileError("Value should be an EnumLiteral, a Shape.Tag or a StructField, got {}", .{T}),
}; };
} }
inline fn ensureDimsAndTagsAreSync(self: Shape) void { inline fn ensureDimsAndTagsAreSync(self: Shape) void {
meta.assert(self._dims.len == self._tags.len, "Tags and dims have diverged! dims={d} tags={d}", .{ self._dims.len, self._tags.len }); stdx.debug.assert(self._dims.len == self._tags.len, "Tags and dims have diverged! dims={d} tags={d}", .{ self._dims.len, self._tags.len });
} }
pub fn tag(self: Shape, ax: anytype) Tag { pub fn tag(self: Shape, ax: anytype) Tag {
@ -220,7 +221,7 @@ pub const Shape = struct {
pub fn hasTags(self: Shape, tagz: anytype) bool { pub fn hasTags(self: Shape, tagz: anytype) bool {
const T = @TypeOf(tagz); const T = @TypeOf(tagz);
if (comptime meta.isSliceOf(T, Tag) or meta.isSliceOf(T, EnumLiteral)) { if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) {
for (tagz) |t| { for (tagz) |t| {
if (self.hasTag(t) == null) { if (self.hasTag(t) == null) {
return false; return false;
@ -229,7 +230,7 @@ pub const Shape = struct {
return true; return true;
} }
if (comptime meta.isTupleOf(T, Tag) or meta.isTupleOf(T, EnumLiteral)) { if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) {
inline for (tagz) |t| { inline for (tagz) |t| {
if (self.hasTag(t) == null) { if (self.hasTag(t) == null) {
return false; return false;
@ -238,7 +239,7 @@ pub const Shape = struct {
return true; return true;
} }
meta.compileError("Expected tuple of tags, got {any}", .{T}); stdx.meta.compileError("Expected tuple of tags, got {any}", .{T});
} }
pub fn isFullyTagged(self: Shape) bool { pub fn isFullyTagged(self: Shape) bool {
@ -252,7 +253,7 @@ pub const Shape = struct {
self.ensureDimsAndTagsAreSync(); self.ensureDimsAndTagsAreSync();
const T = @TypeOf(axis_); const T = @TypeOf(axis_);
if (comptime meta.isInteger(T)) { if (comptime stdx.meta.isInteger(T)) {
return self.axisFromInt(@intCast(axis_)); return self.axisFromInt(@intCast(axis_));
} }
@ -260,7 +261,7 @@ pub const Shape = struct {
return self.axisFromTag(toTag(axis_)); return self.axisFromTag(toTag(axis_));
} }
meta.compileError("Wrong axis type, expected .literal, or an integer, got: {any}", .{T}); stdx.meta.compileError("Wrong axis type, expected .literal, or an integer, got: {any}", .{T});
} }
pub fn axes(self: Shape, axes_: anytype) AxesArray { pub fn axes(self: Shape, axes_: anytype) AxesArray {
@ -274,27 +275,27 @@ pub const Shape = struct {
var res = AxesArray.init(0) catch unreachable; var res = AxesArray.init(0) catch unreachable;
if (comptime meta.isSliceOfAny(T, meta.isInteger) or meta.isSliceOf(T, Tag)) { if (comptime stdx.meta.isSliceOfAny(T, stdx.meta.isInteger) or stdx.meta.isSliceOf(T, Tag)) {
for (axes_) |ax| { for (axes_) |ax| {
res.appendAssumeCapacity(self.axis(ax)); res.appendAssumeCapacity(self.axis(ax));
} }
return res; return res;
} }
if (comptime meta.isStruct(T)) { if (comptime stdx.meta.isStruct(T)) {
inline for (std.meta.fields(T)) |field| { inline for (std.meta.fields(T)) |field| {
res.appendAssumeCapacity(self.axis(@field(axes_, field.name))); res.appendAssumeCapacity(self.axis(@field(axes_, field.name)));
} }
return res; return res;
} }
meta.compileError("axes expects an int-tuple or a tuple of enum literal, got {}", .{T}); stdx.meta.compileError("axes expects an int-tuple or a tuple of enum literal, got {}", .{T});
} }
fn axisFromInt(self: Shape, d: isize) u3 { fn axisFromInt(self: Shape, d: isize) u3 {
const rk: i8 = self.rank(); const rk: i8 = self.rank();
if (d < -rk or d > rk) { if (d < -rk or d > rk) {
meta.panic("Tensor {} doesn't have dimension: {d}", .{ self, d }); stdx.debug.panic("Tensor {} doesn't have dimension: {d}", .{ self, d });
} }
return if (d < 0) return if (d < 0)
@intCast(d + rk) @intCast(d + rk)
@ -323,9 +324,9 @@ pub const Shape = struct {
} }
fn axisFromTag(self: Shape, d: Tag) u3 { fn axisFromTag(self: Shape, d: Tag) u3 {
meta.assert(d != TagUnknown, "The unknown tag .{s} can't be used to fetch axis in {}", .{ d, self }); stdx.debug.assert(d != TagUnknown, "The unknown tag .{s} can't be used to fetch axis in {}", .{ d, self });
return self.axisFromTagMaybe(d) orelse { return self.axisFromTagMaybe(d) orelse {
meta.panic("Tensor {} doesn't have dimension with tag: {s}", .{ self, d }); stdx.debug.panic("Tensor {} doesn't have dimension with tag: {s}", .{ self, d });
}; };
} }
@ -339,7 +340,7 @@ pub const Shape = struct {
pub fn count(self: Shape) usize { pub fn count(self: Shape) usize {
var res: i64 = 1; var res: i64 = 1;
for (self.dims()) |d| { for (self.dims()) |d| {
meta.assert(d >= 0, "Can't count elements in shape with negative dimension: {}", .{self}); stdx.debug.assert(d >= 0, "Can't count elements in shape with negative dimension: {}", .{self});
res *= d; res *= d;
} }
return @intCast(res); return @intCast(res);
@ -398,12 +399,12 @@ pub const Shape = struct {
var new_shape: Shape = .{ ._dtype = self.dtype() }; var new_shape: Shape = .{ ._dtype = self.dtype() };
new_shape._dims, new_shape._tags = parseDimensions(new_shape_); new_shape._dims, new_shape._tags = parseDimensions(new_shape_);
new_shape.inferMissingAxis(self.count()); new_shape.inferMissingAxis(self.count());
meta.assert(self.count() == new_shape.count(), "Can't reshape {d} to {d}", .{ self.dims(), new_shape.dims() }); stdx.debug.assert(self.count() == new_shape.count(), "Can't reshape {d} to {d}", .{ self.dims(), new_shape.dims() });
return new_shape; return new_shape;
} }
fn inferMissingAxis(self: *Shape, n_: usize) void { fn inferMissingAxis(self: *Shape, n_: usize) void {
meta.assert(std.mem.count(i64, self.dims(), &.{-1}) < 2, "Cannot infer multiple dimensions when reshaping to: {}", .{self.*}); stdx.debug.assert(std.mem.count(i64, self.dims(), &.{-1}) < 2, "Cannot infer multiple dimensions when reshaping to: {}", .{self.*});
const inferred_ax = std.mem.indexOfScalar(i64, self.dims(), -1) orelse return; const inferred_ax = std.mem.indexOfScalar(i64, self.dims(), -1) orelse return;
// We can't use `self.count()` yet cause we have negative dims. // We can't use `self.count()` yet cause we have negative dims.
@ -481,7 +482,7 @@ pub const Shape = struct {
} }
pub fn insertTag(self: Shape, axis_: anytype, d: i64, tag_: anytype) Shape { pub fn insertTag(self: Shape, axis_: anytype, d: i64, tag_: anytype) Shape {
meta.assert(self.rank() < MAX_RANK - 1, "Can't insert new axis in {}, it's already at max rank.", .{self}); stdx.debug.assert(self.rank() < MAX_RANK - 1, "Can't insert new axis in {}, it's already at max rank.", .{self});
const ax = if (@TypeOf(axis_) == EnumLiteral and axis_ == .last) const ax = if (@TypeOf(axis_) == EnumLiteral and axis_ == .last)
self.rank() self.rank()
@ -573,23 +574,23 @@ pub const Shape = struct {
var res = self; var res = self;
if (comptime meta.isSliceOf(T, Tag) or meta.isSliceOf(T, EnumLiteral)) { if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) {
meta.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {any}", .{ self, tagz }); stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {any}", .{ self, tagz });
for (tagz, 0..) |tag_, i| { for (tagz, 0..) |tag_, i| {
res._tags.set(i, toTag(tag_)); res._tags.set(i, toTag(tag_));
} }
return res; return res;
} }
if (comptime meta.isTupleOf(T, Tag) or meta.isTupleOf(T, EnumLiteral)) { if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) {
meta.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {}", .{ self, tagz }); stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {}", .{ self, tagz });
inline for (tagz, 0..) |tag_, i| { inline for (tagz, 0..) |tag_, i| {
res._tags.set(i, toTag(tag_)); res._tags.set(i, toTag(tag_));
} }
return res; return res;
} }
meta.compileError("Expected a tuple of enum literals eg: .{ .a, .b, .c } got: {any}", .{@TypeOf(tagz)}); stdx.meta.compileError("Expected a tuple of enum literals eg: .{ .a, .b, .c } got: {any}", .{@TypeOf(tagz)});
} }
test withTags { test withTags {
@ -620,23 +621,23 @@ pub const Shape = struct {
var res = self; var res = self;
if (comptime meta.isSliceOf(T, Tag) or meta.isSliceOf(T, EnumLiteral)) { if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) {
meta.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {any}", .{ self, tagz }); stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {any}", .{ self, tagz });
for (tagz, self.rank() - tagz.len..) |tag_, i| { for (tagz, self.rank() - tagz.len..) |tag_, i| {
res._tags.set(i, toTag(tag_)); res._tags.set(i, toTag(tag_));
} }
return res; return res;
} }
if (comptime meta.isTupleOf(T, Tag) or meta.isTupleOf(T, EnumLiteral)) { if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) {
meta.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {}", .{ self, tagz }); stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {}", .{ self, tagz });
inline for (tagz, self.rank() - tagz.len..) |tag_, i| { inline for (tagz, self.rank() - tagz.len..) |tag_, i| {
res._tags.set(i, toTag(tag_)); res._tags.set(i, toTag(tag_));
} }
return res; return res;
} }
meta.compileError("Expected a tuple of enum literals eg: .{ .a, .b, .c } got: {any}", .{@TypeOf(tagz)}); stdx.meta.compileError("Expected a tuple of enum literals eg: .{ .a, .b, .c } got: {any}", .{@TypeOf(tagz)});
} }
test withPartialTags { test withPartialTags {
@ -683,7 +684,7 @@ pub const Shape = struct {
/// Shape.init(.{ .a = 10, .b = 20 }).rename(.{ .b = .batch }); // .{ .a = 10, .batch = 20 }; /// Shape.init(.{ .a = 10, .b = 20 }).rename(.{ .b = .batch }); // .{ .a = 10, .batch = 20 };
pub fn rename(self: Shape, renames: anytype) Shape { pub fn rename(self: Shape, renames: anytype) Shape {
const T = @TypeOf(renames); const T = @TypeOf(renames);
meta.assertComptime(meta.isStructOfAny(T, isAxisConvertible), "Must pass a struct of enum literals. Passed: {any}", .{T}); stdx.debug.assertComptime(stdx.meta.isStructOfAny(T, isAxisConvertible), "Must pass a struct of enum literals. Passed: {any}", .{T});
var res = self; var res = self;
inline for (std.meta.fields(T)) |field| { inline for (std.meta.fields(T)) |field| {
res._tags.set(self.axis(field), toTag(@field(renames, field.name))); res._tags.set(self.axis(field), toTag(@field(renames, field.name)));
@ -789,7 +790,7 @@ pub const Shape = struct {
pub fn splitAxes(self: Shape, axes_: anytype) Shape { pub fn splitAxes(self: Shape, axes_: anytype) Shape {
const T = @TypeOf(axes_); const T = @TypeOf(axes_);
meta.assertComptime(meta.isStruct(T), "Must pass struct of enum literals like .{ .a = .{ .a1, .a2 } }. Passed: {any}", .{T}); stdx.debug.assertComptime(stdx.meta.isStruct(T), "Must pass struct of enum literals like .{ .a = .{ .a1, .a2 } }. Passed: {any}", .{T});
var res = self; var res = self;
inline for (std.meta.fields(T)) |field| { inline for (std.meta.fields(T)) |field| {
@ -822,7 +823,7 @@ pub const Shape = struct {
var new_dim: i64 = 1; var new_dim: i64 = 1;
for (axes__.constSlice(), first_axis..) |ax, counter| { for (axes__.constSlice(), first_axis..) |ax, counter| {
new_dim *= self.dim(ax); new_dim *= self.dim(ax);
meta.assert(ax == counter, "Can't merge shape {} along non-contiguous axes {any}", .{ self, axes_ }); stdx.debug.assert(ax == counter, "Can't merge shape {} along non-contiguous axes {any}", .{ self, axes_ });
} }
var new_shape = self; var new_shape = self;
@ -863,11 +864,11 @@ pub const Shape = struct {
pub fn mergeAxes(self: Shape, axes_: anytype) Shape { pub fn mergeAxes(self: Shape, axes_: anytype) Shape {
const T = @TypeOf(axes_); const T = @TypeOf(axes_);
meta.assertComptime(meta.isStruct(T), "Must pass struct of enum literals like .{ .a = .{ .a1, .a2 } }. Passed: {any}", .{T}); stdx.debug.assertComptime(stdx.meta.isStruct(T), "Must pass struct of enum literals like .{ .a = .{ .a1, .a2 } }. Passed: {any}", .{T});
var res = self; var res = self;
inline for (std.meta.fields(T)) |field| { inline for (std.meta.fields(T)) |field| {
meta.assertComptime(meta.isTupleOfAny(field.type, isAxisConvertible) or meta.isSliceOfAny(field.type, isAxisConvertible), "Must pass struct of axes. Passed: {any}", .{field.type}); stdx.debug.assertComptime(stdx.meta.isTupleOfAny(field.type, isAxisConvertible) or stdx.meta.isSliceOfAny(field.type, isAxisConvertible), "Must pass struct of axes. Passed: {any}", .{field.type});
res = res.mergeAxis(field, @field(axes_, field.name)); res = res.mergeAxis(field, @field(axes_, field.name));
} }
return res; return res;
@ -912,28 +913,28 @@ pub const Shape = struct {
var vals_: std.BoundedArray(T, MAX_RANK) = .{}; var vals_: std.BoundedArray(T, MAX_RANK) = .{};
var tags_: TagsArray = .{}; var tags_: TagsArray = .{};
if (comptime meta.isSliceOf(V, T)) { if (comptime stdx.meta.isSliceOf(V, T)) {
for (v) |d| { for (v) |d| {
vals_.appendAssumeCapacity(d); vals_.appendAssumeCapacity(d);
} }
return .{ vals_, tags_ }; return .{ vals_, tags_ };
} }
if (comptime meta.isStruct(V)) { if (comptime stdx.meta.isStruct(V)) {
const fields = std.meta.fields(V); const fields = std.meta.fields(V);
meta.assertComptime(fields.len <= MAX_RANK, "Too many fields in struct {} ({d}). Max supported is {d}.", .{ V, fields.len, MAX_RANK }); stdx.debug.assertComptime(fields.len <= MAX_RANK, "Too many fields in struct {} ({d}). Max supported is {d}.", .{ V, fields.len, MAX_RANK });
inline for (fields) |field| { inline for (fields) |field| {
const fv = @field(v, field.name); const fv = @field(v, field.name);
vals_.appendAssumeCapacity(fv); vals_.appendAssumeCapacity(fv);
if (!comptime meta.isTuple(V)) { if (!comptime stdx.meta.isTuple(V)) {
tags_.appendAssumeCapacity(toTag(field)); tags_.appendAssumeCapacity(toTag(field));
} }
} }
return .{ vals_, tags_ }; return .{ vals_, tags_ };
} }
meta.compileError("parseStruct expects struct or tuple, got {}", .{V}); stdx.meta.compileError("parseStruct expects struct or tuple, got {}", .{V});
} }
test parseStruct { test parseStruct {
@ -948,17 +949,17 @@ pub const Shape = struct {
const V = @TypeOf(options); const V = @TypeOf(options);
var res: std.BoundedArray(T, MAX_RANK) = .{}; var res: std.BoundedArray(T, MAX_RANK) = .{};
if (comptime meta.isSliceOf(V, T)) { if (comptime stdx.meta.isSliceOf(V, T)) {
meta.assert(options.len == self.rank(), "expects exactly {} options in slice, for {} got {}", .{ self.rank(), self, options.len }); stdx.debug.assert(options.len == self.rank(), "expects exactly {} options in slice, for {} got {}", .{ self.rank(), self, options.len });
for (options) |d| { for (options) |d| {
res.appendAssumeCapacity(d); res.appendAssumeCapacity(d);
} }
} }
if (comptime meta.isStruct(V)) { if (comptime stdx.meta.isStruct(V)) {
for (0..self.rank()) |_| res.appendAssumeCapacity(default); for (0..self.rank()) |_| res.appendAssumeCapacity(default);
const fields = std.meta.fields(V); const fields = std.meta.fields(V);
meta.assertComptime(fields.len <= MAX_RANK, "expects up to {} options struct literal, got {}", .{ V, MAX_RANK, fields.len }); stdx.debug.assertComptime(fields.len <= MAX_RANK, "expects up to {} options struct literal, got {}", .{ V, MAX_RANK, fields.len });
inline for (fields) |field| { inline for (fields) |field| {
const a = self.axis(field); const a = self.axis(field);
res.buffer[a] = @field(options, field.name); res.buffer[a] = @field(options, field.name);
@ -966,7 +967,7 @@ pub const Shape = struct {
return res; return res;
} }
meta.compileError("parseStruct expects struct or tuple, got {}", .{V}); stdx.meta.compileError("parseStruct expects struct or tuple, got {}", .{V});
} }
test parseAxesOptions { test parseAxesOptions {

View File

@ -1,7 +1,6 @@
const builtin = @import("builtin"); const builtin = @import("builtin");
const std = @import("std"); const std = @import("std");
const assert = std.debug.assert; const stdx = @import("stdx");
const testing = std.testing;
const meta = @import("meta.zig"); const meta = @import("meta.zig");
const mlir = @import("mlir.zig"); const mlir = @import("mlir.zig");
@ -22,7 +21,9 @@ const dialect = struct {
const stablehlo = @import("mlir/dialects").stablehlo; const stablehlo = @import("mlir/dialects").stablehlo;
}; };
const scoped_log = std.log.scoped(.zml_tensor); const assert = std.debug.assert;
const testing = std.testing;
const scoped_log = std.log.scoped(.@"zml/tensor");
test { test {
std.testing.refAllDecls(Tensor); std.testing.refAllDecls(Tensor);
@ -99,7 +100,7 @@ pub const Tensor = struct {
if (builtin.mode == .Debug) { if (builtin.mode == .Debug) {
// Check that the MLIR value actually have the same shape. // Check that the MLIR value actually have the same shape.
const other = fromMlirValue(val); const other = fromMlirValue(val);
meta.internalAssert(sh.eql(other._shape), "Created a {} from Mlir value but expected {}", .{ other._shape, res._shape }); stdx.debug.internalAssert(sh.eql(other._shape), "Created a {} from Mlir value but expected {}", .{ other._shape, res._shape });
} }
return res; return res;
@ -112,7 +113,7 @@ pub const Tensor = struct {
const ranked_tensor = val.getType().as(mlir.RankedTensorType).?; const ranked_tensor = val.getType().as(mlir.RankedTensorType).?;
const n = ranked_tensor.getRank(); const n = ranked_tensor.getRank();
meta.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK }); stdx.debug.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK });
var sh: Shape = .{ ._dtype = mlir.ext.Type.toDType(ranked_tensor.getElementType()) }; var sh: Shape = .{ ._dtype = mlir.ext.Type.toDType(ranked_tensor.getElementType()) };
for (0..n) |i| { for (0..n) |i| {
@ -213,7 +214,7 @@ pub const Tensor = struct {
/// For `reuseBuffer` to be effective, it needs to propagate all the way through the output. /// For `reuseBuffer` to be effective, it needs to propagate all the way through the output.
pub fn reuseBuffer(self: Tensor, origin: Tensor) Tensor { pub fn reuseBuffer(self: Tensor, origin: Tensor) Tensor {
// Note: check donation docs, this may be too permissive. // Note: check donation docs, this may be too permissive.
meta.assert(self.byteSize() == origin.byteSize(), "Can't reuse buffers between tensors of different size: {} and {}", .{ self, origin }); stdx.debug.assert(self.byteSize() == origin.byteSize(), "Can't reuse buffers between tensors of different size: {} and {}", .{ self, origin });
// TODO: should we store all donations inside the context ? // TODO: should we store all donations inside the context ?
var res = self; var res = self;
@ -262,7 +263,7 @@ pub const Tensor = struct {
break :gt res; break :gt res;
} else lt: { } else lt: {
// several contiguous elements of self maps to one element of the result // several contiguous elements of self maps to one element of the result
meta.assert(self.dim(-1) * src_bit_size == tgt_bit_size, "bitcast expects elements of the input tensor last dimension to map to one element of the target datatype, got {0} elements (bitsize of {0}x{1}={2}) and {3} (bitsize of {4})", .{ self.dim(-1), src_bit_size, self.dim(-1) * src_bit_size, dt, tgt_bit_size }); stdx.debug.assert(self.dim(-1) * src_bit_size == tgt_bit_size, "bitcast expects elements of the input tensor last dimension to map to one element of the target datatype, got {0} elements (bitsize of {0}x{1}={2}) and {3} (bitsize of {4})", .{ self.dim(-1), src_bit_size, self.dim(-1) * src_bit_size, dt, tgt_bit_size });
break :lt self._shape.remove(-1); break :lt self._shape.remove(-1);
}; };
@ -295,7 +296,7 @@ pub const Tensor = struct {
/// Returns a Tensor containing the element-wise number of bits set in the input Tensor. /// Returns a Tensor containing the element-wise number of bits set in the input Tensor.
pub fn popcnt(self: Tensor) Tensor { pub fn popcnt(self: Tensor) Tensor {
meta.assert(self.dtype().isInteger(), "popcnt expects tensor type to be an integer, got {}", .{self.dtype()}); stdx.debug.assert(self.dtype().isInteger(), "popcnt expects tensor type to be an integer, got {}", .{self.dtype()});
const loc = self.getContext().mlirCtx().location(@src()); const loc = self.getContext().mlirCtx().location(@src());
const op = dialect.stablehlo.popcnt(self.getContext().mlirCtx(), self.value(), loc); const op = dialect.stablehlo.popcnt(self.getContext().mlirCtx(), self.value(), loc);
return _result(self._shape, op.result(0)); return _result(self._shape, op.result(0));
@ -358,7 +359,7 @@ pub const Tensor = struct {
/// 'lower' controls the form of the outut Tensor. The output will be lower-triangular if 'lower' is true /// 'lower' controls the form of the outut Tensor. The output will be lower-triangular if 'lower' is true
/// and upper-triangular otherwise. /// and upper-triangular otherwise.
pub fn cholesky(self: Tensor, lower: bool) Tensor { pub fn cholesky(self: Tensor, lower: bool) Tensor {
meta.assert(self.rank() <= 2, "cholesky expects tensor rank to be <= 2, got {}", .{self.rank()}); stdx.debug.assert(self.rank() <= 2, "cholesky expects tensor rank to be <= 2, got {}", .{self.rank()});
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "lower={}", .{lower}); const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "lower={}", .{lower});
const op = dialect.stablehlo.cholesky(self.getContext().mlirCtx(), self.value(), lower, loc); const op = dialect.stablehlo.cholesky(self.getContext().mlirCtx(), self.value(), lower, loc);
@ -367,8 +368,8 @@ pub const Tensor = struct {
/// Solves the system of linear equations formed by the input tensors. /// Solves the system of linear equations formed by the input tensors.
pub fn triangularSolve(self: Tensor, other: Tensor, opts: dialect.stablehlo.TriangularSolveOpts) Tensor { pub fn triangularSolve(self: Tensor, other: Tensor, opts: dialect.stablehlo.TriangularSolveOpts) Tensor {
meta.assert(self.dtype() == other.dtype(), "triangularSolve expects tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() }); stdx.debug.assert(self.dtype() == other.dtype(), "triangularSolve expects tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() });
meta.assert(self.rank() <= 2 and self.rank() == other.rank(), "triangularSolve expects tensors to have the same rank and be <= 2, got {} and {}", .{ self.rank(), other.rank() }); stdx.debug.assert(self.rank() <= 2 and self.rank() == other.rank(), "triangularSolve expects tensors to have the same rank and be <= 2, got {} and {}", .{ self.rank(), other.rank() });
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "opts={}", .{opts}); const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "opts={}", .{opts});
const op = dialect.stablehlo.triangular_solve(self.getContext().mlirCtx(), self.value(), other.value(), loc, opts); const op = dialect.stablehlo.triangular_solve(self.getContext().mlirCtx(), self.value(), other.value(), loc, opts);
@ -377,7 +378,7 @@ pub const Tensor = struct {
/// Returns a Tensor containing the element-wise rounding towards the nearest integer, breaking ties away from zero, of the input Tensor. /// Returns a Tensor containing the element-wise rounding towards the nearest integer, breaking ties away from zero, of the input Tensor.
pub fn roundNearestAfz(self: Tensor) Tensor { pub fn roundNearestAfz(self: Tensor) Tensor {
meta.assert(self.dtype().isFloat(), "roundNearestAfz expects tensor type to be a float, got {}", .{self.dtype()}); stdx.debug.assert(self.dtype().isFloat(), "roundNearestAfz expects tensor type to be a float, got {}", .{self.dtype()});
const loc = self.getContext().mlirCtx().location(@src()); const loc = self.getContext().mlirCtx().location(@src());
const op = dialect.stablehlo.round_nearest_afz(self.getContext().mlirCtx(), self.value(), loc); const op = dialect.stablehlo.round_nearest_afz(self.getContext().mlirCtx(), self.value(), loc);
@ -386,7 +387,7 @@ pub const Tensor = struct {
/// Returns a Tensor containing the element-wise rounding towards the nearest integer, breaking ties towards the even integer, of the input Tensor. /// Returns a Tensor containing the element-wise rounding towards the nearest integer, breaking ties towards the even integer, of the input Tensor.
pub fn roundNearestEven(self: Tensor) Tensor { pub fn roundNearestEven(self: Tensor) Tensor {
meta.assert(self.dtype().isFloat(), "roundNearestEven expects tensor type to be a float, got {}", .{self.dtype()}); stdx.debug.assert(self.dtype().isFloat(), "roundNearestEven expects tensor type to be a float, got {}", .{self.dtype()});
const loc = self.getContext().mlirCtx().location(@src()); const loc = self.getContext().mlirCtx().location(@src());
const op = dialect.stablehlo.round_nearest_even(self.getContext().mlirCtx(), self.value(), loc); const op = dialect.stablehlo.round_nearest_even(self.getContext().mlirCtx(), self.value(), loc);
@ -395,8 +396,8 @@ pub const Tensor = struct {
/// Returns a Tensor of complex number converted from a pair of real and imaginary Tensors. /// Returns a Tensor of complex number converted from a pair of real and imaginary Tensors.
pub fn complex(re: Tensor, im: Tensor) Tensor { pub fn complex(re: Tensor, im: Tensor) Tensor {
meta.assert(re._shape.eql(im._shape), "complex expects tensor shapes to match, got {} and {}", .{ re._shape, im._shape }); stdx.debug.assert(re._shape.eql(im._shape), "complex expects tensor shapes to match, got {} and {}", .{ re._shape, im._shape });
meta.assert(re.dtype() == .f32 or re.dtype() == .f64, "complex expects tensors type to be f32 or f64, got {}", .{re.dtype()}); stdx.debug.assert(re.dtype() == .f32 or re.dtype() == .f64, "complex expects tensors type to be f32 or f64, got {}", .{re.dtype()});
const loc = re.getContext().mlirCtx().location(@src()); const loc = re.getContext().mlirCtx().location(@src());
const op = dialect.stablehlo.complex(re.getContext().mlirCtx(), re.value(), im.value(), loc); const op = dialect.stablehlo.complex(re.getContext().mlirCtx(), re.value(), im.value(), loc);
@ -408,7 +409,7 @@ pub const Tensor = struct {
/// ///
/// Tensor type can float or complex. /// Tensor type can float or complex.
pub fn real(self: Tensor) Tensor { pub fn real(self: Tensor) Tensor {
meta.assert(self.dtype().isComplex() or self.dtype().isFloat(), "real expects tensor type to be a float or a complex, got {}", .{self.dtype()}); stdx.debug.assert(self.dtype().isComplex() or self.dtype().isFloat(), "real expects tensor type to be a float or a complex, got {}", .{self.dtype()});
if (self.dtype().isFloat()) { if (self.dtype().isFloat()) {
return self; return self;
@ -428,7 +429,7 @@ pub const Tensor = struct {
/// ///
/// Tensor type can float or complex. /// Tensor type can float or complex.
pub fn imag(self: Tensor) Tensor { pub fn imag(self: Tensor) Tensor {
meta.assert(self.dtype().isFloat() or self.dtype().isComplex(), "imag expects tensor type to be a float or a complex, got {}", .{self.dtype()}); stdx.debug.assert(self.dtype().isFloat() or self.dtype().isComplex(), "imag expects tensor type to be a float or a complex, got {}", .{self.dtype()});
// Real tensors don't have imaginary part. // Real tensors don't have imaginary part.
if (self.dtype().isFloat()) { if (self.dtype().isFloat()) {
@ -450,18 +451,18 @@ pub const Tensor = struct {
pub fn fft(self: Tensor, opts: dialect.stablehlo.FftOpts) Tensor { pub fn fft(self: Tensor, opts: dialect.stablehlo.FftOpts) Tensor {
// TODO: support tagged API. // TODO: support tagged API.
meta.assert(1 <= opts.length.len and opts.length.len <= 3, "fft expects 'opts.length' length to be between 1 and 3 (inclusive), got {}", .{opts.length.len}); stdx.debug.assert(1 <= opts.length.len and opts.length.len <= 3, "fft expects 'opts.length' length to be between 1 and 3 (inclusive), got {}", .{opts.length.len});
meta.assert(opts.length.len <= self.rank(), "fft expects 'opts.length' length to be less than tensor rank, got {} and {}", .{ opts.length.len, self.rank() }); stdx.debug.assert(opts.length.len <= self.rank(), "fft expects 'opts.length' length to be less than tensor rank, got {} and {}", .{ opts.length.len, self.rank() });
const sh = switch (opts.kind) { const sh = switch (opts.kind) {
.FFT, .IFFT => blk: { .FFT, .IFFT => blk: {
meta.assert(self.dtype().isComplex(), "fft({any}) expects tensor type to be complex, got {}", .{ opts, self.dtype() }); stdx.debug.assert(self.dtype().isComplex(), "fft({any}) expects tensor type to be complex, got {}", .{ opts, self.dtype() });
break :blk self._shape; break :blk self._shape;
}, },
.RFFT => blk: { .RFFT => blk: {
meta.assert(self.dtype() == .f32 or self.dtype() == .f64, "fft({}) expects tensor type to be f32 or f64, got {}", .{ opts, self.dtype() }); stdx.debug.assert(self.dtype() == .f32 or self.dtype() == .f64, "fft({}) expects tensor type to be f32 or f64, got {}", .{ opts, self.dtype() });
meta.assert(std.mem.eql(i64, self.dims()[self.rank() - opts.length.len ..], opts.length), "fft({}) expects tensor last dimensions to match given lengths, got {} and {}", .{ opts, self.dims()[self.rank() - opts.length.len ..].len, opts.length.len }); stdx.debug.assert(std.mem.eql(i64, self.dims()[self.rank() - opts.length.len ..], opts.length), "fft({}) expects tensor last dimensions to match given lengths, got {} and {}", .{ opts, self.dims()[self.rank() - opts.length.len ..].len, opts.length.len });
const dt: DataType = switch (self.dtype()) { const dt: DataType = switch (self.dtype()) {
.f32 => .c64, .f32 => .c64,
@ -471,8 +472,8 @@ pub const Tensor = struct {
break :blk shape_.withDtype(dt); break :blk shape_.withDtype(dt);
}, },
.IRFFT => blk: { .IRFFT => blk: {
meta.assert(self.dtype().isComplex(), "fft({any}) expects tensor type to be complex, got {}", .{ opts, self.dtype() }); stdx.debug.assert(self.dtype().isComplex(), "fft({any}) expects tensor type to be complex, got {}", .{ opts, self.dtype() });
meta.assert(std.mem.eql(i64, self.dims()[self.rank() - opts.length.len ..], opts.length), "fft({any}) expects tensor last dimensions to match given lengths, got {} and {}", .{ opts, self.dims()[self.rank() - opts.length.len ..].len, opts.length.len }); stdx.debug.assert(std.mem.eql(i64, self.dims()[self.rank() - opts.length.len ..], opts.length), "fft({any}) expects tensor last dimensions to match given lengths, got {} and {}", .{ opts, self.dims()[self.rank() - opts.length.len ..].len, opts.length.len });
const dt: DataType = switch (self.dtype()) { const dt: DataType = switch (self.dtype()) {
.c64 => .f32, .c64 => .f32,
@ -551,7 +552,7 @@ pub const Tensor = struct {
16 => .u16, 16 => .u16,
32 => .u32, 32 => .u32,
64 => .u64, 64 => .u64,
else => meta.panic("uniform don't support non-byte aligned dtype. Got: {}", .{shape_}), else => stdx.debug.panic("uniform don't support non-byte aligned dtype. Got: {}", .{shape_}),
}; };
const rng, const bits = self.bitGenerator(shape_.withDtype(uint_dtype)); const rng, const bits = self.bitGenerator(shape_.withDtype(uint_dtype));
@ -635,7 +636,7 @@ pub const Tensor = struct {
/// Note: this uses stablehlo.rng which is deprecated. /// Note: this uses stablehlo.rng which is deprecated.
/// https://github.com/openxla/stablehlo/blob/main/rfcs/20240503-opset-deprecations.md /// https://github.com/openxla/stablehlo/blob/main/rfcs/20240503-opset-deprecations.md
pub fn normal(sh: Shape, opts: struct { mean: f64 = 0, stddev: f64 = 1 }) Tensor { pub fn normal(sh: Shape, opts: struct { mean: f64 = 0, stddev: f64 = 1 }) Tensor {
meta.assert(sh.dtype().isFloat(), "normal expects tensor type to be a float, got {}", .{sh.dtype()}); stdx.debug.assert(sh.dtype().isFloat(), "normal expects tensor type to be a float, got {}", .{sh.dtype()});
const ctx = CompilationContext.current().mlirCtx(); const ctx = CompilationContext.current().mlirCtx();
const loc = ctx.location(@src()).namedFmt(ctx, "rand.normal({}, opts={})", .{ sh, opts }); const loc = ctx.location(@src()).namedFmt(ctx, "rand.normal({}, opts={})", .{ sh, opts });
@ -731,9 +732,9 @@ pub const Tensor = struct {
/// Returns a Tensor containing the element-wise conversion to another floating point type. /// Returns a Tensor containing the element-wise conversion to another floating point type.
pub fn reducePrecision(self: Tensor, exponent_bits: i32, mantissa_bits: i32) Tensor { pub fn reducePrecision(self: Tensor, exponent_bits: i32, mantissa_bits: i32) Tensor {
meta.assert(self.dtype().isFloat(), "reducePrecision expects tensor type to be a float, got {}", .{self.dtype()}); stdx.debug.assert(self.dtype().isFloat(), "reducePrecision expects tensor type to be a float, got {}", .{self.dtype()});
meta.assert(1 <= exponent_bits, "reducePrecision expects 'exponent_bits' to be >= 1, got {}", .{exponent_bits}); stdx.debug.assert(1 <= exponent_bits, "reducePrecision expects 'exponent_bits' to be >= 1, got {}", .{exponent_bits});
meta.assert(0 <= mantissa_bits, "reducePrecision expects 'mantissa_bits' to be positive, got {}", .{mantissa_bits}); stdx.debug.assert(0 <= mantissa_bits, "reducePrecision expects 'mantissa_bits' to be positive, got {}", .{mantissa_bits});
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reducePrecision(exponent_bits={}, mantissa_bits={})", .{ exponent_bits, mantissa_bits }); const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reducePrecision(exponent_bits={}, mantissa_bits={})", .{ exponent_bits, mantissa_bits });
const op = dialect.stablehlo.reduce_precision(self.getContext().mlirCtx(), self.value(), exponent_bits, mantissa_bits, loc); const op = dialect.stablehlo.reduce_precision(self.getContext().mlirCtx(), self.value(), exponent_bits, mantissa_bits, loc);
@ -741,56 +742,56 @@ pub const Tensor = struct {
} }
inline fn convolution(self: Tensor, other: Tensor, opts: dialect.stablehlo.ConvolutionOpts, loc: mlir.Location) Tensor { inline fn convolution(self: Tensor, other: Tensor, opts: dialect.stablehlo.ConvolutionOpts, loc: mlir.Location) Tensor {
meta.assert(self.rank() == other.rank(), "convolution expects tensor ranks to match, got {} and {}", .{ self.rank(), other.rank() }); stdx.debug.assert(self.rank() == other.rank(), "convolution expects tensor ranks to match, got {} and {}", .{ self.rank(), other.rank() });
const N = self.rank(); const N = self.rank();
meta.guard(opts.window_strides.len == N - 2, @src()); stdx.debug.guard(opts.window_strides.len == N - 2, @src());
for (opts.window_strides) |s| meta.guard(0 < s, @src()); for (opts.window_strides) |s| stdx.debug.guard(0 < s, @src());
meta.guard(opts.lhs_dilation.len == N - 2, @src()); stdx.debug.guard(opts.lhs_dilation.len == N - 2, @src());
for (opts.lhs_dilation) |d| meta.guard(0 < d, @src()); for (opts.lhs_dilation) |d| stdx.debug.guard(0 < d, @src());
meta.guard(opts.rhs_dilation.len == N - 2, @src()); stdx.debug.guard(opts.rhs_dilation.len == N - 2, @src());
for (opts.rhs_dilation) |d| meta.guard(0 < d, @src()); for (opts.rhs_dilation) |d| stdx.debug.guard(0 < d, @src());
meta.guard(opts.window_reversal.len == N - 2, @src()); stdx.debug.guard(opts.window_reversal.len == N - 2, @src());
meta.guard(@rem(self.dim(opts.input_batch_dimension), opts.batch_group_count) == 0, @src()); stdx.debug.guard(@rem(self.dim(opts.input_batch_dimension), opts.batch_group_count) == 0, @src());
meta.guard(@rem(self.dim(opts.input_feature_dimension), opts.feature_group_count) == 0, @src()); stdx.debug.guard(@rem(self.dim(opts.input_feature_dimension), opts.feature_group_count) == 0, @src());
meta.guard(opts.input_spatial_dimensions.len == N - 2, @src()); stdx.debug.guard(opts.input_spatial_dimensions.len == N - 2, @src());
meta.guard(opts.input_batch_dimension != opts.input_feature_dimension, @src()); stdx.debug.guard(opts.input_batch_dimension != opts.input_feature_dimension, @src());
meta.guard(0 <= opts.input_batch_dimension and opts.input_batch_dimension < N, @src()); stdx.debug.guard(0 <= opts.input_batch_dimension and opts.input_batch_dimension < N, @src());
meta.guard(0 <= opts.input_feature_dimension and opts.input_feature_dimension < N, @src()); stdx.debug.guard(0 <= opts.input_feature_dimension and opts.input_feature_dimension < N, @src());
for (opts.input_spatial_dimensions, 0..) |d, i| { for (opts.input_spatial_dimensions, 0..) |d, i| {
meta.guard(d != opts.input_batch_dimension, @src()); stdx.debug.guard(d != opts.input_batch_dimension, @src());
meta.guard(d != opts.input_feature_dimension, @src()); stdx.debug.guard(d != opts.input_feature_dimension, @src());
meta.guard(0 <= d and d < N, @src()); stdx.debug.guard(0 <= d and d < N, @src());
if (i < opts.input_spatial_dimensions.len - 1) continue; if (i < opts.input_spatial_dimensions.len - 1) continue;
meta.guard(std.mem.indexOfScalar(i64, opts.input_spatial_dimensions[i + 1 ..], d) == null, @src()); stdx.debug.guard(std.mem.indexOfScalar(i64, opts.input_spatial_dimensions[i + 1 ..], d) == null, @src());
} }
meta.guard(other.dim(opts.kernel_input_feature_dimension) == @divTrunc(self.dim(opts.input_feature_dimension), opts.feature_group_count), @src()); stdx.debug.guard(other.dim(opts.kernel_input_feature_dimension) == @divTrunc(self.dim(opts.input_feature_dimension), opts.feature_group_count), @src());
meta.guard(@rem(other.dim(opts.kernel_output_feature_dimension), opts.batch_group_count) == 0, @src()); stdx.debug.guard(@rem(other.dim(opts.kernel_output_feature_dimension), opts.batch_group_count) == 0, @src());
meta.guard(@rem(other.dim(opts.kernel_output_feature_dimension), opts.feature_group_count) == 0, @src()); stdx.debug.guard(@rem(other.dim(opts.kernel_output_feature_dimension), opts.feature_group_count) == 0, @src());
meta.guard(opts.kernel_spatial_dimensions.len == N - 2, @src()); stdx.debug.guard(opts.kernel_spatial_dimensions.len == N - 2, @src());
meta.guard(opts.kernel_input_feature_dimension != opts.kernel_output_feature_dimension, @src()); stdx.debug.guard(opts.kernel_input_feature_dimension != opts.kernel_output_feature_dimension, @src());
meta.guard(0 <= opts.kernel_input_feature_dimension and opts.kernel_input_feature_dimension < N, @src()); stdx.debug.guard(0 <= opts.kernel_input_feature_dimension and opts.kernel_input_feature_dimension < N, @src());
meta.guard(0 <= opts.kernel_output_feature_dimension and opts.kernel_output_feature_dimension < N, @src()); stdx.debug.guard(0 <= opts.kernel_output_feature_dimension and opts.kernel_output_feature_dimension < N, @src());
for (opts.kernel_spatial_dimensions, 0..) |d, i| { for (opts.kernel_spatial_dimensions, 0..) |d, i| {
meta.guard(d != opts.kernel_input_feature_dimension, @src()); stdx.debug.guard(d != opts.kernel_input_feature_dimension, @src());
meta.guard(d != opts.kernel_output_feature_dimension, @src()); stdx.debug.guard(d != opts.kernel_output_feature_dimension, @src());
meta.guard(0 <= d and d < N, @src()); stdx.debug.guard(0 <= d and d < N, @src());
if (i < opts.kernel_spatial_dimensions.len - 1) continue; if (i < opts.kernel_spatial_dimensions.len - 1) continue;
meta.guard(std.mem.indexOfScalar(i64, opts.kernel_spatial_dimensions[i + 1 ..], d) == null, @src()); stdx.debug.guard(std.mem.indexOfScalar(i64, opts.kernel_spatial_dimensions[i + 1 ..], d) == null, @src());
} }
meta.guard(opts.output_spatial_dimensions.len == N - 2, @src()); stdx.debug.guard(opts.output_spatial_dimensions.len == N - 2, @src());
meta.guard(opts.output_batch_dimension != opts.output_feature_dimension, @src()); stdx.debug.guard(opts.output_batch_dimension != opts.output_feature_dimension, @src());
meta.guard(0 <= opts.output_batch_dimension and opts.output_batch_dimension < N, @src()); stdx.debug.guard(0 <= opts.output_batch_dimension and opts.output_batch_dimension < N, @src());
meta.guard(0 <= opts.output_feature_dimension and opts.output_feature_dimension < N, @src()); stdx.debug.guard(0 <= opts.output_feature_dimension and opts.output_feature_dimension < N, @src());
for (opts.output_spatial_dimensions, 0..) |d, i| { for (opts.output_spatial_dimensions, 0..) |d, i| {
meta.guard(d != opts.output_batch_dimension, @src()); stdx.debug.guard(d != opts.output_batch_dimension, @src());
meta.guard(d != opts.output_feature_dimension, @src()); stdx.debug.guard(d != opts.output_feature_dimension, @src());
meta.guard(0 <= d and d < N, @src()); stdx.debug.guard(0 <= d and d < N, @src());
if (i < opts.output_spatial_dimensions.len - 1) continue; if (i < opts.output_spatial_dimensions.len - 1) continue;
meta.guard(std.mem.indexOfScalar(i64, opts.output_spatial_dimensions[i + 1 ..], d) == null, @src()); stdx.debug.guard(std.mem.indexOfScalar(i64, opts.output_spatial_dimensions[i + 1 ..], d) == null, @src());
} }
meta.guard(0 < opts.feature_group_count, @src()); stdx.debug.guard(0 < opts.feature_group_count, @src());
meta.guard(0 < opts.batch_group_count, @src()); stdx.debug.guard(0 < opts.batch_group_count, @src());
meta.guard(opts.feature_group_count == 1 or opts.batch_group_count == 1, @src()); stdx.debug.guard(opts.feature_group_count == 1 or opts.batch_group_count == 1, @src());
var used_opts = opts; var used_opts = opts;
used_opts.pad_shape = &.{ @intCast(N - 2), 2 }; used_opts.pad_shape = &.{ @intCast(N - 2), 2 };
used_opts.precision_config = &.{ .DEFAULT, .DEFAULT }; used_opts.precision_config = &.{ .DEFAULT, .DEFAULT };
@ -1042,10 +1043,10 @@ pub const Tensor = struct {
var batching_axes: [MAX_RANK][2]i8 = undefined; var batching_axes: [MAX_RANK][2]i8 = undefined;
var n_batching: u8 = 0; var n_batching: u8 = 0;
for (lhs._shape.tags(), 0..) |l, li| { for (lhs._shape.tags(), 0..) |l, li| {
meta.assert(l != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, lhs }); stdx.debug.assert(l != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, lhs });
for (rhs._shape.tags(), 0..) |r, ri| { for (rhs._shape.tags(), 0..) |r, ri| {
meta.assert(r != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, rhs }); stdx.debug.assert(r != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, rhs });
if (l == r) { if (l == r) {
for (contracting_axes) |ct| { for (contracting_axes) |ct| {
@ -1114,7 +1115,7 @@ pub const Tensor = struct {
contracting_axes: []const [2]i8, contracting_axes: []const [2]i8,
batching_axes: []const [2]i8, batching_axes: []const [2]i8,
) Tensor { ) Tensor {
meta.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {} and {}", .{ lhs.dtype(), rhs.dtype() }); stdx.debug.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {} and {}", .{ lhs.dtype(), rhs.dtype() });
const Axes = std.BoundedArray(i64, MAX_RANK); const Axes = std.BoundedArray(i64, MAX_RANK);
@ -1124,7 +1125,7 @@ pub const Tensor = struct {
var rhs_batching_axes: Axes = .{}; var rhs_batching_axes: Axes = .{};
for (batching_axes) |b_axes| { for (batching_axes) |b_axes| {
const l, const r = b_axes; const l, const r = b_axes;
meta.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs }); stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs });
var t = lhs._shape.tag(l); var t = lhs._shape.tag(l);
if (t == Shape.TagUnknown) t = rhs._shape.tag(r); if (t == Shape.TagUnknown) t = rhs._shape.tag(r);
res_shape = res_shape.appendDim(lhs._shape.dim(l), t); res_shape = res_shape.appendDim(lhs._shape.dim(l), t);
@ -1137,7 +1138,7 @@ pub const Tensor = struct {
var rhs_contracting_axes: Axes = .{}; var rhs_contracting_axes: Axes = .{};
for (contracting_axes) |c_axes| { for (contracting_axes) |c_axes| {
const l, const r = c_axes; const l, const r = c_axes;
meta.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects contracting dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs }); stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects contracting dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs });
lhs_contracting_axes.appendAssumeCapacity(lhs._shape.axis(l)); lhs_contracting_axes.appendAssumeCapacity(lhs._shape.axis(l));
rhs_contracting_axes.appendAssumeCapacity(rhs._shape.axis(r)); rhs_contracting_axes.appendAssumeCapacity(rhs._shape.axis(r));
} }
@ -1353,7 +1354,7 @@ pub const Tensor = struct {
else else
toI64(axes__); toI64(axes__);
meta.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {} and {}", .{ self.rank(), permutation.len }); stdx.debug.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {} and {}", .{ self.rank(), permutation.len });
if (std.mem.eql(i64, permutation, no_op[0..self.rank()])) { if (std.mem.eql(i64, permutation, no_op[0..self.rank()])) {
return self; return self;
@ -1386,7 +1387,7 @@ pub const Tensor = struct {
/// ///
/// unflatten((d0, d1, axis_m, d3), 2, n) -> (d0, d1, n, d2_m, d3) /// unflatten((d0, d1, axis_m, d3), 2, n) -> (d0, d1, n, d2_m, d3)
pub fn unflatten(self: Tensor, axis_: i8, n: i64) Tensor { pub fn unflatten(self: Tensor, axis_: i8, n: i64) Tensor {
meta.assert(self.rank() < Tensor.MAX_RANK, "unflatten expects input tensor rank to be less than {}, got {}", .{ Tensor.MAX_RANK, self.rank() }); stdx.debug.assert(self.rank() < Tensor.MAX_RANK, "unflatten expects input tensor rank to be less than {}, got {}", .{ Tensor.MAX_RANK, self.rank() });
const a = if (axis_ >= 0) self.axis(axis_) else self.axis(axis_) + 1; const a = if (axis_ >= 0) self.axis(axis_) else self.axis(axis_) + 1;
const new_dim = std.math.divExact(i64, self.dim(a), n) catch std.debug.panic("unflatten expects chosen dimension to be divisible by 'n' but {} is not divisible by {}", .{ self.dim(a), n }); const new_dim = std.math.divExact(i64, self.dim(a), n) catch std.debug.panic("unflatten expects chosen dimension to be divisible by 'n' but {} is not divisible by {}", .{ self.dim(a), n });
@ -1443,7 +1444,7 @@ pub const Tensor = struct {
pub fn flatten(self: Tensor, axis_: anytype) Tensor { pub fn flatten(self: Tensor, axis_: anytype) Tensor {
const old_shape = self._shape; const old_shape = self._shape;
const a = self.axis(axis_); const a = self.axis(axis_);
// meta.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis }); // stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis });
const new_shape = old_shape.remove(a + 1).set(a, old_shape.dim(a) * old_shape.dim(a + 1)); const new_shape = old_shape.remove(a + 1).set(a, old_shape.dim(a) * old_shape.dim(a + 1));
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "axis={}", .{axis_}); const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "axis={}", .{axis_});
@ -1483,7 +1484,7 @@ pub const Tensor = struct {
var res_shape: Shape = self._shape; var res_shape: Shape = self._shape;
for (slices, 0..) |s, a| { for (slices, 0..) |s, a| {
meta.assert(s.step > 0, "slice expects 'step' to be positive, got {} at index {}", .{ s.step, a }); stdx.debug.assert(s.step > 0, "slice expects 'step' to be positive, got {} at index {}", .{ s.step, a });
const args: Slice = .{ const args: Slice = .{
.start = self.wrapIndex(a, s.start), .start = self.wrapIndex(a, s.start),
@ -1549,7 +1550,7 @@ pub const Tensor = struct {
/// Concatenates the input Tensors along the given axis. /// Concatenates the input Tensors along the given axis.
pub fn concatenate(tensors: []const Tensor, axis_: anytype) Tensor { pub fn concatenate(tensors: []const Tensor, axis_: anytype) Tensor {
meta.assert(tensors.len <= 32, "concatenate only supports up to 32 tensors, got {}", .{tensors.len}); stdx.debug.assert(tensors.len <= 32, "concatenate only supports up to 32 tensors, got {}", .{tensors.len});
var buffer: [32]mlir.Value = undefined; var buffer: [32]mlir.Value = undefined;
std.debug.assert(tensors.len <= buffer.len); std.debug.assert(tensors.len <= buffer.len);
std.debug.assert(tensors.len > 0); std.debug.assert(tensors.len > 0);
@ -1576,13 +1577,13 @@ pub const Tensor = struct {
/// - Tensor.stack(&.{x, y, z}, .last, .layers) -> .{ .a, .b, .c, .layers } /// - Tensor.stack(&.{x, y, z}, .last, .layers) -> .{ .a, .b, .c, .layers }
pub fn stack(tensors: []const Tensor, axis_: anytype, tag: anytype) Tensor { pub fn stack(tensors: []const Tensor, axis_: anytype, tag: anytype) Tensor {
// Note: we could ask the compilation context for some memory instead of stack allocating // Note: we could ask the compilation context for some memory instead of stack allocating
meta.assert(tensors.len <= 32, "stack only supports up to 32 tensors, got {}", .{tensors.len}); stdx.debug.assert(tensors.len <= 32, "stack only supports up to 32 tensors, got {}", .{tensors.len});
const shape0 = tensors[0]._shape; const shape0 = tensors[0]._shape;
const res_shape = shape0.insertTag(axis_, 1, tag); const res_shape = shape0.insertTag(axis_, 1, tag);
for (tensors[1..]) |tensor| { for (tensors[1..]) |tensor| {
meta.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {} and {}", .{ tensor._shape, shape0 }); stdx.debug.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {} and {}", .{ tensor._shape, shape0 });
} }
var reshaped: [32]Tensor = undefined; var reshaped: [32]Tensor = undefined;
@ -1623,7 +1624,7 @@ pub const Tensor = struct {
/// Repeats a Tensor several times along the given axes. /// Repeats a Tensor several times along the given axes.
pub fn repeat(self: Tensor, n_reps: []const u63) Tensor { pub fn repeat(self: Tensor, n_reps: []const u63) Tensor {
// TODO: this should support the tagged syntax: x.repeat(.{ .a = 3, .b = 2}); // TODO: this should support the tagged syntax: x.repeat(.{ .a = 3, .b = 2});
meta.assert(n_reps.len == self.rank(), "repeat expects tensor rank and 'n_reps' length to be equal, got {} and {}", .{ self.rank(), n_reps.len }); stdx.debug.assert(n_reps.len == self.rank(), "repeat expects tensor rank and 'n_reps' length to be equal, got {} and {}", .{ self.rank(), n_reps.len });
var res = self; var res = self;
for (n_reps, 0..) |n_rep, a| { for (n_reps, 0..) |n_rep, a| {
@ -1647,7 +1648,7 @@ pub const Tensor = struct {
/// Repeats in line each value along the given axes. /// Repeats in line each value along the given axes.
pub fn stutter(self: Tensor, n_reps: []const u63) Tensor { pub fn stutter(self: Tensor, n_reps: []const u63) Tensor {
meta.assert(n_reps.len == self.rank(), "stutter expects tensor rank and 'n_reps' length to be equal, got {} and {}", .{ self.rank(), n_reps.len }); stdx.debug.assert(n_reps.len == self.rank(), "stutter expects tensor rank and 'n_reps' length to be equal, got {} and {}", .{ self.rank(), n_reps.len });
var res = self; var res = self;
for (n_reps, 0..) |n_rep, a| { for (n_reps, 0..) |n_rep, a| {
@ -1724,8 +1725,8 @@ pub const Tensor = struct {
/// Returns a Tensor containing evenly spaced values within a given interval. /// Returns a Tensor containing evenly spaced values within a given interval.
pub fn arange(args: ArangeArgs, dt: DataType) Tensor { pub fn arange(args: ArangeArgs, dt: DataType) Tensor {
meta.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end }); stdx.debug.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end });
meta.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step}); stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "{}, dtype={}", .{ args, dt }); const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "{}, dtype={}", .{ args, dt });
@ -1770,9 +1771,9 @@ pub const Tensor = struct {
/// Returns a Tensor containing 'args.steps' values evenly spaced from 'args.start' to 'args.end', inclusive. /// Returns a Tensor containing 'args.steps' values evenly spaced from 'args.start' to 'args.end', inclusive.
pub fn linspace(args: LinspaceArgs, dt: DataType) Tensor { pub fn linspace(args: LinspaceArgs, dt: DataType) Tensor {
meta.assert(args.start < args.end, "linspace expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end }); stdx.debug.assert(args.start < args.end, "linspace expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end });
meta.assert(args.steps > 0, "linspace expects 'args.steps' to be positive, got {}", .{args.steps}); stdx.debug.assert(args.steps > 0, "linspace expects 'args.steps' to be positive, got {}", .{args.steps});
meta.assert(dt.isFloat(), "linspace expects type to be a float, got {} (hint: use arange instead)", .{dt}); stdx.debug.assert(dt.isFloat(), "linspace expects type to be a float, got {} (hint: use arange instead)", .{dt});
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "linspace({}, dtype={})", .{ args, dt }); const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "linspace({}, dtype={})", .{ args, dt });
@ -1824,7 +1825,7 @@ pub const Tensor = struct {
/// Returns a Tensor containing the result of the outer product between the input Tensors. /// Returns a Tensor containing the result of the outer product between the input Tensors.
pub fn outer(self: Tensor, other: Tensor) Tensor { pub fn outer(self: Tensor, other: Tensor) Tensor {
meta.assert(self.rank() < 2 and other.rank() < 2 and self.rank() + other.rank() != 0, "outer expects tensor ranks to be at most 1, got {} and {}", .{ self.rank(), other.rank() }); stdx.debug.assert(self.rank() < 2 and other.rank() < 2 and self.rank() + other.rank() != 0, "outer expects tensor ranks to be at most 1, got {} and {}", .{ self.rank(), other.rank() });
if (self.rank() + other.rank() == 1) { if (self.rank() + other.rank() == 1) {
return self.mul(other); return self.mul(other);
@ -1856,7 +1857,7 @@ pub const Tensor = struct {
/// Broadcasts a Tensor to the given shape, adding axes at the beginning. /// Broadcasts a Tensor to the given shape, adding axes at the beginning.
pub fn broadcastLeft(self: Tensor, output_shape: Shape) Tensor { pub fn broadcastLeft(self: Tensor, output_shape: Shape) Tensor {
meta.assert(self.rank() <= output_shape.rank(), "broadcastLeft expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() }); stdx.debug.assert(self.rank() <= output_shape.rank(), "broadcastLeft expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() });
const a = output_shape.rank() - self.rank(); const a = output_shape.rank() - self.rank();
if (self.rank() == output_shape.rank() and std.mem.eql(i64, self.dims(), output_shape.dims())) { if (self.rank() == output_shape.rank() and std.mem.eql(i64, self.dims(), output_shape.dims())) {
@ -1868,7 +1869,7 @@ pub const Tensor = struct {
/// Broadcasts a Tensor to the given shape, adding axes at the end. /// Broadcasts a Tensor to the given shape, adding axes at the end.
pub fn broadcastRight(self: Tensor, output_shape: Shape) Tensor { pub fn broadcastRight(self: Tensor, output_shape: Shape) Tensor {
meta.assert(self.rank() <= output_shape.rank(), "broadcastRight expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() }); stdx.debug.assert(self.rank() <= output_shape.rank(), "broadcastRight expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() });
if (self.rank() == output_shape.rank() and self._shape.eql(output_shape)) { if (self.rank() == output_shape.rank() and self._shape.eql(output_shape)) {
return self; return self;
@ -1967,7 +1968,7 @@ pub const Tensor = struct {
/// Appends a 1-dim axis, with the given tag. /// Appends a 1-dim axis, with the given tag.
pub fn appendAxes(self: Tensor, t: anytype) Tensor { pub fn appendAxes(self: Tensor, t: anytype) Tensor {
meta.assert(self.rank() < Tensor.MAX_RANK - t.len, "appendAxis expects tensor rank to be small enough in order to extend it, got {} and {} (max is {})", .{ self.rank(), t.len, Tensor.MAX_RANK }); stdx.debug.assert(self.rank() < Tensor.MAX_RANK - t.len, "appendAxis expects tensor rank to be small enough in order to extend it, got {} and {} (max is {})", .{ self.rank(), t.len, Tensor.MAX_RANK });
return self.insertAxes(.last, t); return self.insertAxes(.last, t);
} }
@ -1975,7 +1976,7 @@ pub const Tensor = struct {
/// Drops a 1-dim axis at the given index /// Drops a 1-dim axis at the given index
pub fn squeeze(self: Tensor, axis_: anytype) Tensor { pub fn squeeze(self: Tensor, axis_: anytype) Tensor {
const a = self.axis(axis_); const a = self.axis(axis_);
meta.assert(self.dim(a) == 1, "squeeze expects axis to be squeezed to have a dimension of 1, got {}", .{self.dim(a)}); stdx.debug.assert(self.dim(a) == 1, "squeeze expects axis to be squeezed to have a dimension of 1, got {}", .{self.dim(a)});
const new_shape = self._shape.remove(a); const new_shape = self._shape.remove(a);
// log.debug("squeeze({}, {d}={d}) -> ({})", .{ self, axis, a, new_shape }); // log.debug("squeeze({}, {d}={d}) -> ({})", .{ self, axis, a, new_shape });
@ -2023,10 +2024,10 @@ pub const Tensor = struct {
// scoped_log.debug("gatherValues({}, {any}, {})", .{ self, coord_axes, indices }); // scoped_log.debug("gatherValues({}, {any}, {})", .{ self, coord_axes, indices });
const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes); const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes);
meta.assert(coord_axes_.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{}); stdx.debug.assert(coord_axes_.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{});
for (coord_axes_.constSlice(), 0..) |a, i| { for (coord_axes_.constSlice(), 0..) |a, i| {
if (i > 0) { if (i > 0) {
meta.assert(a == coord_axes_.get(i - 1) + 1, "gatherValues expects 'coord_axes' to be sequential. But {any} aren't sequential in {}", .{ coord_axes, self }); stdx.debug.assert(a == coord_axes_.get(i - 1) + 1, "gatherValues expects 'coord_axes' to be sequential. But {any} aren't sequential in {}", .{ coord_axes, self });
} }
} }
@ -2040,7 +2041,7 @@ pub const Tensor = struct {
// Note: tags are required for batching. // Note: tags are required for batching.
self_kind.appendAssumeCapacity(.batching); self_kind.appendAssumeCapacity(.batching);
indices_batch_axes.appendAssumeCapacity(id_ax); indices_batch_axes.appendAssumeCapacity(id_ax);
meta.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. Axis {s} has been found both in 'self={any}', in 'coord_axes_={any}' and in 'indices={}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices }); stdx.debug.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. Axis {s} has been found both in 'self={any}', in 'coord_axes_={any}' and in 'indices={}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices });
} else if (maybe_coord_ax) |_| { } else if (maybe_coord_ax) |_| {
// for gatherValues we collapsed all gathered axes // for gatherValues we collapsed all gathered axes
// (contrary to gatherSlices where we collapse none) // (contrary to gatherSlices where we collapse none)
@ -2057,7 +2058,7 @@ pub const Tensor = struct {
indices.rank() indices.rank()
else blk: { else blk: {
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1); const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
meta.assert(indices.dim(ax) == coord_axes_.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ coord_axes, coord_axes_.len, indices }); stdx.debug.assert(indices.dim(ax) == coord_axes_.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ coord_axes, coord_axes_.len, indices });
break :blk ax; break :blk ax;
}; };
@ -2124,7 +2125,7 @@ pub const Tensor = struct {
); );
const mlir_shape = fromMlirValue(gather_op.result(0)).shape(); const mlir_shape = fromMlirValue(gather_op.result(0)).shape();
meta.assert(mlir_shape.eql(res_shape), "gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={}, indices={}. You should transpose one or the other.", .{ self, indices }); stdx.debug.assert(mlir_shape.eql(res_shape), "gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={}, indices={}. You should transpose one or the other.", .{ self, indices });
return _result(res_shape, gather_op.result(0)); return _result(res_shape, gather_op.result(0));
} }
@ -2201,16 +2202,16 @@ pub const Tensor = struct {
const tagged_api = slice_shape.isFullyTagged(); const tagged_api = slice_shape.isFullyTagged();
if (tagged_api) { if (tagged_api) {
for (slice_shape.tags()) |t| { for (slice_shape.tags()) |t| {
meta.assert(self._shape.hasTag(t) != null, "gatherSlices expects `slices_shape` to only use tags from `self`. But {s} wasn't found in {}", .{ t, self }); stdx.debug.assert(self._shape.hasTag(t) != null, "gatherSlices expects `slices_shape` to only use tags from `self`. But {s} wasn't found in {}", .{ t, self });
} }
} else { } else {
// For untagged api, we require all slices to be specified. // For untagged api, we require all slices to be specified.
// Note: we could relax this and right align the slice. // Note: we could relax this and right align the slice.
meta.assert(slice_shape.rank() == self.rank(), "gatherSlices expects `slice_shape.rank()` to match `self.rank()`. Got: gatherSlices({}, slice={_}). To avoid specifying all axes in `slice_shape`, you can use tags.", .{ self, slice_shape }); stdx.debug.assert(slice_shape.rank() == self.rank(), "gatherSlices expects `slice_shape.rank()` to match `self.rank()`. Got: gatherSlices({}, slice={_}). To avoid specifying all axes in `slice_shape`, you can use tags.", .{ self, slice_shape });
} }
const index_coord_axis = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1); const index_coord_axis = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
meta.assert(indices.dim(index_coord_axis) == slice_shape.rank(), "gatherSlices({}, slice={_}, indices) expects 'indices' to be a tensor [..., {}], got {}", .{ self, slice_shape, slice_shape.rank(), indices }); stdx.debug.assert(indices.dim(index_coord_axis) == slice_shape.rank(), "gatherSlices({}, slice={_}, indices) expects 'indices' to be a tensor [..., {}], got {}", .{ self, slice_shape, slice_shape.rank(), indices });
// Compute result shape // Compute result shape
var res_shape = indices._shape.remove(index_coord_axis).withDtype(self.dtype()); var res_shape = indices._shape.remove(index_coord_axis).withDtype(self.dtype());
@ -2228,12 +2229,12 @@ pub const Tensor = struct {
self_batch_axes.appendAssumeCapacity(@intCast(self_ax)); self_batch_axes.appendAssumeCapacity(@intCast(self_ax));
indices_batch_axes.appendAssumeCapacity(indices._shape.axis(t)); indices_batch_axes.appendAssumeCapacity(indices._shape.axis(t));
slice_dims.set(self_ax, 1); slice_dims.set(self_ax, 1);
meta.assert(slice_shape.hasTag(t) == null, "gatherSlices expect axes to be either batches or slices axes. Axis {s} has been found both in `slices={_}` and `indices={}`", .{ t, slice_shape, indices }); stdx.debug.assert(slice_shape.hasTag(t) == null, "gatherSlices expect axes to be either batches or slices axes. Axis {s} has been found both in `slices={_}` and `indices={}`", .{ t, slice_shape, indices });
} else if (maybe_slice_ax) |slice_ax| { } else if (maybe_slice_ax) |slice_ax| {
// Specified axes contains the start offset of the slices, // Specified axes contains the start offset of the slices,
// and are collected in `start_index_map`. // and are collected in `start_index_map`.
const slice_dim = slice_shape.dim(slice_ax); const slice_dim = slice_shape.dim(slice_ax);
meta.assert(slice_dim <= self._shape.dim(self_ax), "gatherSlices expects `slice_shape` to be smaller than `self.shape()`. On axis {s}, got {} > {}.", .{ t, slice_shape, self._shape }); stdx.debug.assert(slice_dim <= self._shape.dim(self_ax), "gatherSlices expects `slice_shape` to be smaller than `self.shape()`. On axis {s}, got {} > {}.", .{ t, slice_shape, self._shape });
slice_dims.set(self_ax, slice_dim); slice_dims.set(self_ax, slice_dim);
res_shape = res_shape.appendDim(slice_dim, t); res_shape = res_shape.appendDim(slice_dim, t);
start_index_map.appendAssumeCapacity(@intCast(self_ax)); start_index_map.appendAssumeCapacity(@intCast(self_ax));
@ -2395,7 +2396,7 @@ pub const Tensor = struct {
const loc = @src(); const loc = @src();
// scoped_log.debug("scatterSlices({}, {any}, {}, {})", .{ self, coord_axes, indices, updates }); // scoped_log.debug("scatterSlices({}, {any}, {}, {})", .{ self, coord_axes, indices, updates });
meta.assert(self.dtype() == updates.dtype(), "scatterSlices expects input and 'updates' tensors to be of the same type, got {} and {}", .{ self.dtype(), updates.dtype() }); stdx.debug.assert(self.dtype() == updates.dtype(), "scatterSlices expects input and 'updates' tensors to be of the same type, got {} and {}", .{ self.dtype(), updates.dtype() });
const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes); const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes);
const AxisKind = enum { batching, update_window, inserted_window, window_id }; const AxisKind = enum { batching, update_window, inserted_window, window_id };
@ -2420,7 +2421,7 @@ pub const Tensor = struct {
indices.rank() indices.rank()
else blk: { else blk: {
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1); const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
meta.assert(indices.dim(ax) == coord_axes_.len, "scatterSlices({}, coord_axes={any}, indices, updates) expects 'indices' to be a tensor [..., {}], got {}", .{ self, coord_axes, coord_axes_.len, indices }); stdx.debug.assert(indices.dim(ax) == coord_axes_.len, "scatterSlices({}, coord_axes={any}, indices, updates) expects 'indices' to be a tensor [..., {}], got {}", .{ self, coord_axes, coord_axes_.len, indices });
break :blk ax; break :blk ax;
}; };
@ -2435,7 +2436,7 @@ pub const Tensor = struct {
if (self_kind.get(self_ax) == .batching) { if (self_kind.get(self_ax) == .batching) {
up_kind.appendAssumeCapacity(.batching); up_kind.appendAssumeCapacity(.batching);
} else { } else {
meta.assert(updates.dim(up_ax) <= self.dim(self_ax), "scatterSlices expects the slices described in 'updates' to fit inside 'self', but along axis .{s} it doesn't. Got self={}, updates={}.", .{ t, self, updates }); stdx.debug.assert(updates.dim(up_ax) <= self.dim(self_ax), "scatterSlices expects the slices described in 'updates' to fit inside 'self', but along axis .{s} it doesn't. Got self={}, updates={}.", .{ t, self, updates });
up_kind.appendAssumeCapacity(.update_window); up_kind.appendAssumeCapacity(.update_window);
} }
} else if (t == Shape.TagUnknown or indices._shape.hasTag(t) != null) { } else if (t == Shape.TagUnknown or indices._shape.hasTag(t) != null) {
@ -2446,9 +2447,9 @@ pub const Tensor = struct {
} }
const n_indices_axes = updates.rank() - _collectAxes(AxisKind, up_kind, .update_window).len; const n_indices_axes = updates.rank() - _collectAxes(AxisKind, up_kind, .update_window).len;
if (single_coord) { if (single_coord) {
meta.assert(n_indices_axes == indices.rank(), "scatterSlices({}, {any}) expects 'updates' to contain all axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates }); stdx.debug.assert(n_indices_axes == indices.rank(), "scatterSlices({}, {any}) expects 'updates' to contain all axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates });
} else { } else {
meta.assert(n_indices_axes == indices.rank() - 1, "scatterSlices({}, {any}) expects 'updates' to contain all-but-last axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates }); stdx.debug.assert(n_indices_axes == indices.rank() - 1, "scatterSlices({}, {any}) expects 'updates' to contain all-but-last axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates });
} }
const ctx = self.getContext(); const ctx = self.getContext();
@ -2671,7 +2672,7 @@ pub const Tensor = struct {
/// * bubbles up Nan /// * bubbles up Nan
/// * in case of equality the smallest index matching the maximum /// * in case of equality the smallest index matching the maximum
pub fn argMax(x: Tensor, axis_: anytype, index_dtype: DataType) ArgMaxRes { pub fn argMax(x: Tensor, axis_: anytype, index_dtype: DataType) ArgMaxRes {
meta.assert(index_dtype.isInteger(), "argMax expect index type to be an integer, got {}", .{index_dtype}); stdx.debug.assert(index_dtype.isInteger(), "argMax expect index type to be an integer, got {}", .{index_dtype});
const a = x.axis(axis_); const a = x.axis(axis_);
@ -2870,7 +2871,7 @@ pub const Tensor = struct {
padding: [2][2]i64 = .{ .{ 0, 0 }, .{ 0, 0 } }, padding: [2][2]i64 = .{ .{ 0, 0 }, .{ 0, 0 } },
}) MaxPoolRes { }) MaxPoolRes {
// TODO: rewrite using modern ZML // TODO: rewrite using modern ZML
meta.guard(self.rank() == 3 or self.rank() == 4, @src()); stdx.debug.guard(self.rank() == 3 or self.rank() == 4, @src());
// TODO: support maxPool on non last axis // TODO: support maxPool on non last axis
// Note: the problem is initPoolArg assuming last axis // Note: the problem is initPoolArg assuming last axis
@ -3004,14 +3005,14 @@ pub const Tensor = struct {
} }
pub fn split(self: Tensor, allocator: std.mem.Allocator, split_size_or_sections: []const i64, axis_: i64) ![]Tensor { pub fn split(self: Tensor, allocator: std.mem.Allocator, split_size_or_sections: []const i64, axis_: i64) ![]Tensor {
meta.assert(split_size_or_sections.len > 0, "split expects 'split_size_or_sections' length to be positive, got {}", .{split_size_or_sections.len}); stdx.debug.assert(split_size_or_sections.len > 0, "split expects 'split_size_or_sections' length to be positive, got {}", .{split_size_or_sections.len});
const a = self.axis(axis_); const a = self.axis(axis_);
const length = self.dim(a); const length = self.dim(a);
if (split_size_or_sections.len != 1) { if (split_size_or_sections.len != 1) {
var split_sum: i64 = 0; var split_sum: i64 = 0;
for (split_size_or_sections) |n| split_sum += n; for (split_size_or_sections) |n| split_sum += n;
meta.assert(split_sum == length, "split expects sum of 'split_size_or_sections' values and axis dimension to be equal, got {} and {}", .{ split_sum, length }); stdx.debug.assert(split_sum == length, "split expects sum of 'split_size_or_sections' values and axis dimension to be equal, got {} and {}", .{ split_sum, length });
} }
const res = try allocator.alloc(Tensor, split_size_or_sections.len); const res = try allocator.alloc(Tensor, split_size_or_sections.len);
@ -3029,7 +3030,7 @@ pub const Tensor = struct {
/// Note: this doesn't support tagging, if you have tags, /// Note: this doesn't support tagging, if you have tags,
/// you should use `dynamicSlice` directly. /// you should use `dynamicSlice` directly.
pub fn dynamicSlice1d(self: Tensor, axis_: i8, len: u63, start_indices: Tensor) Tensor { pub fn dynamicSlice1d(self: Tensor, axis_: i8, len: u63, start_indices: Tensor) Tensor {
meta.assert(start_indices.rank() == 0, "dynamicSlice1d expects 'start_indices' tensor rank to be equal to 0, got {}", .{start_indices.rank()}); stdx.debug.assert(start_indices.rank() == 0, "dynamicSlice1d expects 'start_indices' tensor rank to be equal to 0, got {}", .{start_indices.rank()});
const a = self.axis(axis_); const a = self.axis(axis_);
const new_shape = self._shape.set(a, len); const new_shape = self._shape.set(a, len);
@ -3087,17 +3088,17 @@ pub const Tensor = struct {
const offset = slice_.start; const offset = slice_.start;
const len = slice_.len; const len = slice_.len;
if (slices_tags.len == 0) { if (slices_tags.len == 0) {
meta.assert(self.rank() == slices.len, "dynamicSlice expects tensor rank and 'slices_' length to be equal, got {} and {}", .{ self.rank(), slices.len }); stdx.debug.assert(self.rank() == slices.len, "dynamicSlice expects tensor rank and 'slices_' length to be equal, got {} and {}", .{ self.rank(), slices.len });
offset_values[i] = offset.value(); offset_values[i] = offset.value();
res_shape._dims.set(i, len); res_shape._dims.set(i, len);
meta.assert(len <= self.dim(i), "dynamicSlice expects slices 'len' to be less than or equal to their corresponding dimension in input tensor, got {} and {} for index {}", .{ len, self.dim(i), i }); stdx.debug.assert(len <= self.dim(i), "dynamicSlice expects slices 'len' to be less than or equal to their corresponding dimension in input tensor, got {} and {} for index {}", .{ len, self.dim(i), i });
} else { } else {
const t = slices_tags.get(i); const t = slices_tags.get(i);
const a = res_shape.hasTag(t) orelse meta.panic("dynamicSlice expects input tensor to have tags used in 'slices_' but {s} is missing (input shape is {})", .{ t, self._shape }); const a = res_shape.hasTag(t) orelse stdx.debug.panic("dynamicSlice expects input tensor to have tags used in 'slices_' but {s} is missing (input shape is {})", .{ t, self._shape });
meta.assert(len <= self.dim(a), "dynamicSlice expects slices 'len' to be less than their corresponding dimension in input tensor, got {} and {} for axis {s}", .{ len, self.dim(a), t }); stdx.debug.assert(len <= self.dim(a), "dynamicSlice expects slices 'len' to be less than their corresponding dimension in input tensor, got {} and {} for axis {s}", .{ len, self.dim(a), t });
offset_values[a] = offset.value(); offset_values[a] = offset.value();
res_shape._dims.set(a, len); res_shape._dims.set(a, len);
@ -3149,12 +3150,12 @@ pub const Tensor = struct {
/// ``` /// ```
pub fn dynamicUpdateSlice(self: Tensor, offset_: anytype, update_: Tensor) Tensor { pub fn dynamicUpdateSlice(self: Tensor, offset_: anytype, update_: Tensor) Tensor {
// TODO: add updateSlice for when the offset isn't dynamic // TODO: add updateSlice for when the offset isn't dynamic
meta.assert(self.dtype() == update_.dtype(), "dynamicUpdateSlice expects input and 'update_' tensors to be of the same type, got {} and {}", .{ self.dtype(), update_.dtype() }); stdx.debug.assert(self.dtype() == update_.dtype(), "dynamicUpdateSlice expects input and 'update_' tensors to be of the same type, got {} and {}", .{ self.dtype(), update_.dtype() });
const offset, const offset_tags = Shape.parseStruct(Tensor, offset_); const offset, const offset_tags = Shape.parseStruct(Tensor, offset_);
// log.debug("offset: {any}, offset_tags: {any}", .{ offset, offset_tags }); // log.debug("offset: {any}, offset_tags: {any}", .{ offset, offset_tags });
for (offset.constSlice(), 0..) |start_idx, i| { for (offset.constSlice(), 0..) |start_idx, i| {
meta.assert(start_idx.rank() == 0, "dynamicUpdateSlice expects 'offset_' tensor ranks to be equal to 0, got {} at index {}", .{ start_idx.rank(), i }); stdx.debug.assert(start_idx.rank() == 0, "dynamicUpdateSlice expects 'offset_' tensor ranks to be equal to 0, got {} at index {}", .{ start_idx.rank(), i });
} }
const tagged_api = update_._shape.isFullyTagged() and self._shape.isFullyTagged() and offset_tags.len > 0; const tagged_api = update_._shape.isFullyTagged() and self._shape.isFullyTagged() and offset_tags.len > 0;
@ -3164,14 +3165,14 @@ pub const Tensor = struct {
if (tagged_api) { if (tagged_api) {
// Check that all update tags are known. // Check that all update tags are known.
for (update._shape._tags.constSlice()) |t| { for (update._shape._tags.constSlice()) |t| {
meta.assert(self._shape.hasTag(t) != null, "dynamicUpdateSlice expects 'update_' tensor tags to be a subset of input tensor tags but {s} is missing (input shape is {})", .{ t, self._shape }); stdx.debug.assert(self._shape.hasTag(t) != null, "dynamicUpdateSlice expects 'update_' tensor tags to be a subset of input tensor tags but {s} is missing (input shape is {})", .{ t, self._shape });
} }
var update_shape = self._shape; var update_shape = self._shape;
var prev_ax: i8 = -1; var prev_ax: i8 = -1;
for (self._shape.tags(), 0..) |t, self_ax| { for (self._shape.tags(), 0..) |t, self_ax| {
if (update._shape.hasTag(t)) |up_ax| { if (update._shape.hasTag(t)) |up_ax| {
meta.assert(up_ax == prev_ax + 1, "dynamicUpdateSlice expects 'update_' and input tensor axis to have the same order, got {} and {}. (hint: you need to explicitly transpose 'update_')", .{ update_._shape, self._shape }); stdx.debug.assert(up_ax == prev_ax + 1, "dynamicUpdateSlice expects 'update_' and input tensor axis to have the same order, got {} and {}. (hint: you need to explicitly transpose 'update_')", .{ update_._shape, self._shape });
update_shape._dims.set(self_ax, update.dim(up_ax)); update_shape._dims.set(self_ax, update.dim(up_ax));
prev_ax = up_ax; prev_ax = up_ax;
@ -3182,16 +3183,16 @@ pub const Tensor = struct {
update = update.reshape(update_shape); update = update.reshape(update_shape);
} }
meta.assert(self.rank() == update.rank(), "dynamicUpdateSlice expects input and computed update tensors to have the same rank, got {} and {} (hint: it's probably an issue on our side)", .{ self.rank(), update.rank() }); stdx.debug.assert(self.rank() == update.rank(), "dynamicUpdateSlice expects input and computed update tensors to have the same rank, got {} and {} (hint: it's probably an issue on our side)", .{ self.rank(), update.rank() });
for (self.dims(), update.dims(), 0..) |self_d, up_d, ax| { for (self.dims(), update.dims(), 0..) |self_d, up_d, ax| {
const t = self._shape.debugTag(ax); const t = self._shape.debugTag(ax);
meta.assert(up_d <= self_d, "dynamicUpdateSlice expects 'update_' dimensions to be less than or equal to their corresponding dimension in input tensor, got {} and {} for axis .{s}", .{ up_d, self_d, t }); stdx.debug.assert(up_d <= self_d, "dynamicUpdateSlice expects 'update_' dimensions to be less than or equal to their corresponding dimension in input tensor, got {} and {} for axis .{s}", .{ up_d, self_d, t });
if (tagged_api and up_d < self_d) { if (tagged_api and up_d < self_d) {
const axis_has_offset = std.mem.indexOfScalar(Shape.Tag, offset_tags.constSlice(), self._shape._tags.get(ax)) != null; const axis_has_offset = std.mem.indexOfScalar(Shape.Tag, offset_tags.constSlice(), self._shape._tags.get(ax)) != null;
meta.assert(axis_has_offset, "dynamicUpdateSlice expects 'update_' dimensions to be equal to their corresponding dimension in input tensor, got {} and {} for axis .{s} (hint: you need to provide an offset)", .{ up_d, self_d, t }); stdx.debug.assert(axis_has_offset, "dynamicUpdateSlice expects 'update_' dimensions to be equal to their corresponding dimension in input tensor, got {} and {} for axis .{s} (hint: you need to provide an offset)", .{ up_d, self_d, t });
} }
} }
@ -3200,7 +3201,7 @@ pub const Tensor = struct {
var offset_values: [MAX_RANK]mlir.Value = undefined; var offset_values: [MAX_RANK]mlir.Value = undefined;
if (offset_tags.len == 0) { if (offset_tags.len == 0) {
// Without offset tags we need the same number of offset than rank. // Without offset tags we need the same number of offset than rank.
meta.assert(self.rank() == offset.len, "dynamicUpdateSlice expects input tensor rank and 'offset_' length to be equal, got {} and {}", .{ self.rank(), offset.len }); stdx.debug.assert(self.rank() == offset.len, "dynamicUpdateSlice expects input tensor rank and 'offset_' length to be equal, got {} and {}", .{ self.rank(), offset.len });
for (offset.constSlice(), 0..) |idx, i| { for (offset.constSlice(), 0..) |idx, i| {
offset_values[i] = idx.value(); offset_values[i] = idx.value();
@ -3210,7 +3211,7 @@ pub const Tensor = struct {
// This is only allowed when using tagged sliced. // This is only allowed when using tagged sliced.
offset_values = .{zero} ** MAX_RANK; offset_values = .{zero} ** MAX_RANK;
for (offset.constSlice(), offset_tags.constSlice()) |start, t| { for (offset.constSlice(), offset_tags.constSlice()) |start, t| {
const a = self._shape.hasTag(t) orelse meta.panic("dynamicUpdateSlice expects input tensor to have tags used in 'offset_' but {s} is missing (input shape is {})", .{ t, self._shape }); const a = self._shape.hasTag(t) orelse stdx.debug.panic("dynamicUpdateSlice expects input tensor to have tags used in 'offset_' but {s} is missing (input shape is {})", .{ t, self._shape });
offset_values[a] = start.value(); offset_values[a] = start.value();
} }
} }
@ -3329,12 +3330,12 @@ pub const Tensor = struct {
/// Returns a Tensor containing the element-wise result of the given 'cmp' comparison between the two input Tensors. /// Returns a Tensor containing the element-wise result of the given 'cmp' comparison between the two input Tensors.
pub fn cmp(self: Tensor, direction: dialect.stablehlo.ComparisonDirection.Direction, other: Tensor) Tensor { pub fn cmp(self: Tensor, direction: dialect.stablehlo.ComparisonDirection.Direction, other: Tensor) Tensor {
meta.assert(self.dtype() == other.dtype(), "cmp expects input tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() }); stdx.debug.assert(self.dtype() == other.dtype(), "cmp expects input tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() });
if (self.rank() == 0 and other.rank() != 0) return self.broadcast(other._shape, &.{}).cmp(direction, other); if (self.rank() == 0 and other.rank() != 0) return self.broadcast(other._shape, &.{}).cmp(direction, other);
if (self.rank() != 0 and other.rank() == 0) return self.cmp(direction, other.broadcast(self._shape, &.{})); if (self.rank() != 0 and other.rank() == 0) return self.cmp(direction, other.broadcast(self._shape, &.{}));
meta.assert(self._shape.eql(other._shape), "cmp expects input tensor shapes to match, got {} and {}", .{ self._shape, other._shape }); stdx.debug.assert(self._shape.eql(other._shape), "cmp expects input tensor shapes to match, got {} and {}", .{ self._shape, other._shape });
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "cmp(.{s})", .{@tagName(direction)}); const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "cmp(.{s})", .{@tagName(direction)});
const op = dialect.stablehlo.compare( const op = dialect.stablehlo.compare(
@ -3352,7 +3353,7 @@ pub const Tensor = struct {
/// For each vector in the input tensor, /// For each vector in the input tensor,
/// creates a diagonal-matrix where diagonal values are set to the vector values. /// creates a diagonal-matrix where diagonal values are set to the vector values.
pub fn toDiagonal(self: Tensor, axis_: anytype, new_tags: [2]EnumLiteral) Tensor { pub fn toDiagonal(self: Tensor, axis_: anytype, new_tags: [2]EnumLiteral) Tensor {
meta.assert(self.rank() < MAX_RANK - 1, "toDiagonal expects input up to {} rank, got {}", .{ MAX_RANK - 1, self }); stdx.debug.assert(self.rank() < MAX_RANK - 1, "toDiagonal expects input up to {} rank, got {}", .{ MAX_RANK - 1, self });
const a = self.axis(axis_); const a = self.axis(axis_);
const d = self.dim(a); const d = self.dim(a);
var res_shape = self._shape; var res_shape = self._shape;
@ -3409,7 +3410,7 @@ pub const Tensor = struct {
/// To get the upper triangular part, swap the order of axes: /// To get the upper triangular part, swap the order of axes:
/// `.{ .b = 32, .w = 20, .h = 20 }.triangular(.{ .h, .w }, 0);` /// `.{ .b = 32, .w = 20, .h = 20 }.triangular(.{ .h, .w }, 0);`
pub fn triangular(self: Tensor, axes_: anytype, num_diagonals: i32) Tensor { pub fn triangular(self: Tensor, axes_: anytype, num_diagonals: i32) Tensor {
meta.assertComptime(meta.isTuple(@TypeOf(axes_)) and axes_.len == 2, "triangular expects exactly two axes to work on.", .{}); stdx.debug.assertComptime(stdx.meta.isTuple(@TypeOf(axes_)) and axes_.len == 2, "triangular expects exactly two axes to work on.", .{});
const _axes = self.axes(axes_); const _axes = self.axes(axes_);
const x = Tensor.iota(self.shape(), _axes.get(0)); const x = Tensor.iota(self.shape(), _axes.get(0));
@ -3472,8 +3473,8 @@ pub const Tensor = struct {
/// For each element at index `i`, if `bool_tensor[i] == true`, `output[i] = on_true[i]` /// For each element at index `i`, if `bool_tensor[i] == true`, `output[i] = on_true[i]`
/// otherwise, if `bool_tensor[i] == false`, `output[i] = on_false[i]` /// otherwise, if `bool_tensor[i] == false`, `output[i] = on_false[i]`
pub fn select(bool_tensor: Tensor, on_true: Tensor, on_false: Tensor) Tensor { pub fn select(bool_tensor: Tensor, on_true: Tensor, on_false: Tensor) Tensor {
meta.assert(bool_tensor.dtype() == .bool, "select expects input tensor type to be a boolean, got {}", .{bool_tensor.dtype()}); stdx.debug.assert(bool_tensor.dtype() == .bool, "select expects input tensor type to be a boolean, got {}", .{bool_tensor.dtype()});
meta.assert(on_true.dtype() == on_false.dtype(), "select expects 'on_true' and 'on_false' tensor types to be equal, got {} and {}", .{ on_true.dtype(), on_false.dtype() }); stdx.debug.assert(on_true.dtype() == on_false.dtype(), "select expects 'on_true' and 'on_false' tensor types to be equal, got {} and {}", .{ on_true.dtype(), on_false.dtype() });
if (bool_tensor.rank() != 0 and on_true.rank() == 0) { if (bool_tensor.rank() != 0 and on_true.rank() == 0) {
return bool_tensor.select(on_true.broad(bool_tensor.shape()), on_false); return bool_tensor.select(on_true.broad(bool_tensor.shape()), on_false);
@ -3482,8 +3483,8 @@ pub const Tensor = struct {
return bool_tensor.select(on_true, on_false.broad(bool_tensor.shape())); return bool_tensor.select(on_true, on_false.broad(bool_tensor.shape()));
} }
meta.assert(bool_tensor._shape.eqlDims(on_true._shape), "select expects input tensor and 'on_true' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_true._shape }); stdx.debug.assert(bool_tensor._shape.eqlDims(on_true._shape), "select expects input tensor and 'on_true' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_true._shape });
meta.assert(bool_tensor._shape.eqlDims(on_false._shape), "select expects input tensor and 'on_false' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_false._shape }); stdx.debug.assert(bool_tensor._shape.eqlDims(on_false._shape), "select expects input tensor and 'on_false' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_false._shape });
const loc = bool_tensor.getContext().mlirCtx().location(@src()); const loc = bool_tensor.getContext().mlirCtx().location(@src());
const op = dialect.stablehlo.select( const op = dialect.stablehlo.select(
@ -3538,11 +3539,11 @@ pub const Tensor = struct {
} }
fn _cartesianProduct(vectors: []const Tensor, out: []Tensor) void { fn _cartesianProduct(vectors: []const Tensor, out: []Tensor) void {
meta.assert(vectors.len >= 1, "cartesianProduct expects at least one input.", .{}); stdx.debug.assert(vectors.len >= 1, "cartesianProduct expects at least one input.", .{});
meta.assert(vectors.len < Tensor.MAX_RANK, "cartesianProduct expects at most {} input vectors, received {} !", .{ Tensor.MAX_RANK - 1, vectors.len }); stdx.debug.assert(vectors.len < Tensor.MAX_RANK, "cartesianProduct expects at most {} input vectors, received {} !", .{ Tensor.MAX_RANK - 1, vectors.len });
for (vectors) |x| { for (vectors) |x| {
meta.assert(x.rank() <= 1, "cartesianProduct expects 0 or 1 rank input vectors. Got: {any}", .{vectors}); stdx.debug.assert(x.rank() <= 1, "cartesianProduct expects 0 or 1 rank input vectors. Got: {any}", .{vectors});
meta.assert(vectors[0].dtype() == x.dtype(), "cartesianProduct expects input vectors to have all the same dtype. Got: {any}", .{vectors}); stdx.debug.assert(vectors[0].dtype() == x.dtype(), "cartesianProduct expects input vectors to have all the same dtype. Got: {any}", .{vectors});
} }
var res_shape = Shape.init(.{}, vectors[0].dtype()); var res_shape = Shape.init(.{}, vectors[0].dtype());
@ -3645,7 +3646,7 @@ pub const Tensor = struct {
) fn (Tensor, Tensor) Tensor { ) fn (Tensor, Tensor) Tensor {
return struct { return struct {
pub fn binaryOpHelper(self: Tensor, other: Tensor) Tensor { pub fn binaryOpHelper(self: Tensor, other: Tensor) Tensor {
meta.assert(self.dtype() == other.dtype(), "{s} expects tensor to be of same type, got {} and {}", .{ op_name, self, other }); stdx.debug.assert(self.dtype() == other.dtype(), "{s} expects tensor to be of same type, got {} and {}", .{ op_name, self, other });
if (self.rank() == 0 and other.rank() != 0) { if (self.rank() == 0 and other.rank() != 0) {
return binaryOpHelper(self.broad(other._shape), other); return binaryOpHelper(self.broad(other._shape), other);
@ -3655,7 +3656,7 @@ pub const Tensor = struct {
return binaryOpHelper(self, other.broad(self._shape)); return binaryOpHelper(self, other.broad(self._shape));
} }
meta.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {} and {}", .{ op_name, self._shape, other._shape }); stdx.debug.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {} and {}", .{ op_name, self._shape, other._shape });
const mlirCtx = self.getContext().mlirCtx(); const mlirCtx = self.getContext().mlirCtx();
const location = mlirCtx.location(@src()); const location = mlirCtx.location(@src());

View File

@ -1,8 +1,8 @@
//! Test runner for unit test based on https://github.com/ziglang/zig/blob/master/lib/compiler/test_runner.zig with async //! Test runner for unit test based on https://github.com/ziglang/zig/blob/master/lib/compiler/test_runner.zig with async
const builtin = @import("builtin");
const std = @import("std");
const asynk = @import("async"); const asynk = @import("async");
const builtin = @import("builtin");
const std = @import("std");
const io = std.io; const io = std.io;
const testing = std.testing; const testing = std.testing;
const assert = std.debug.assert; const assert = std.debug.assert;
@ -21,10 +21,10 @@ var fba = std.heap.FixedBufferAllocator.init(&fba_buffer);
pub fn main() anyerror!void { pub fn main() anyerror!void {
testing.log_level = log_level; testing.log_level = log_level;
try asynk.AsyncThread.main(testing.allocator, asyncMain, .{}); try asynk.AsyncThread.main(testing.allocator, asyncMain);
} }
pub fn asyncMain() void { pub fn asyncMain() !void {
const test_fn_list: []const std.builtin.TestFn = builtin.test_functions; const test_fn_list: []const std.builtin.TestFn = builtin.test_functions;
var ok_count: usize = 0; var ok_count: usize = 0;
var skip_count: usize = 0; var skip_count: usize = 0;

View File

@ -1,11 +1,12 @@
const std = @import("std");
const builtin = @import("builtin"); const builtin = @import("builtin");
const std = @import("std");
const stdx = @import("stdx");
const zml = @import("zml.zig"); const zml = @import("zml.zig");
const meta = @import("meta.zig"); const meta = @import("meta.zig");
const shapesOf = @import("tensor.zig").shapesOf; const shapesOf = @import("tensor.zig").shapesOf;
const log = std.log.scoped(.zml_testing); const log = std.log.scoped(.@"zml/testing");
var _ctx: ?zml.Context = null; var _ctx: ?zml.Context = null;
@ -128,7 +129,7 @@ pub fn expectEqualShapes(expected: zml.Shape, actual: zml.Shape) error{TestExpec
/// Compile a function and immediatly call it with the given buffers. /// Compile a function and immediatly call it with the given buffers.
/// The compiled module is discarded after the call. /// The compiled module is discarded after the call.
/// Useful during testing when a module is typically called only once. /// Useful during testing when a module is typically called only once.
pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bufferized(meta.FnParams(func))) !zml.Bufferized(zml.meta.FnResult(func)) { pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bufferized(stdx.meta.FnArgs(func))) !zml.Bufferized(stdx.meta.FnResult(func)) {
// This simplify test API and also ensure this fn isn't used outside of tests. // This simplify test API and also ensure this fn isn't used outside of tests.
const allocator = std.testing.allocator; const allocator = std.testing.allocator;
var arena = std.heap.ArenaAllocator.init(allocator); var arena = std.heap.ArenaAllocator.init(allocator);
@ -139,7 +140,7 @@ pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bu
return x.shape(); return x.shape();
} }
}; };
var shape_args: zml.ShapeOf(meta.FnParams(func)) = undefined; var shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)) = undefined;
try meta.mapAlloc(Local.bufferToShape, allocator, {}, buffer_args, &shape_args); try meta.mapAlloc(Local.bufferToShape, allocator, {}, buffer_args, &shape_args);
const mod = try zml.compileFn(allocator, func, shape_args, platform); const mod = try zml.compileFn(allocator, func, shape_args, platform);
@ -151,7 +152,7 @@ pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bu
/// Compile a function and immediatly call it with the given buffers. /// Compile a function and immediatly call it with the given buffers.
/// The compiled module is discarded after the call. /// The compiled module is discarded after the call.
/// Useful during testing when a module is typically called only once. /// Useful during testing when a module is typically called only once.
pub fn compileAndCallWithTensors(platform: zml.Platform, func: anytype, shape_args: zml.ShapeOf(meta.FnParams(func)), buffer_args: zml.Bufferized(meta.FnParams(func))) !zml.Bufferized(zml.meta.FnResult(func)) { pub fn compileAndCallWithTensors(platform: zml.Platform, func: anytype, shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)), buffer_args: zml.Bufferized(stdx.meta.FnArgs(func))) !zml.Bufferized(stdx.meta.FnResult(func)) {
// This simplify test API and also ensure this fn isn't used outside of tests. // This simplify test API and also ensure this fn isn't used outside of tests.
const allocator = std.testing.allocator; const allocator = std.testing.allocator;
var arena = std.heap.ArenaAllocator.init(allocator); var arena = std.heap.ArenaAllocator.init(allocator);

View File

@ -1,13 +1,15 @@
//! Text tokenizer implementations //! Text tokenizer implementations
const std = @import("std");
const builtin = @import("builtin"); const builtin = @import("builtin");
const testing = std.testing; const std = @import("std");
const stdx = @import("stdx");
const log = std.log.scoped(.zml_tokenizer); const testing = std.testing;
const helpers = @import("helpers.zig"); const helpers = @import("helpers.zig");
const meta = @import("meta.zig"); const meta = @import("meta.zig");
const log = std.log.scoped(.@"zml/tokenizer");
test { test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(@This());
std.testing.refAllDecls(Normalizer); std.testing.refAllDecls(Normalizer);
@ -202,7 +204,7 @@ pub const Tokenizer = struct {
// Detects memory corruption of tokens. // Detects memory corruption of tokens.
if (cur_tok.len == 0 or cur_tok.len > self.max_token_len) @panic("Token looks corrupted !"); if (cur_tok.len == 0 or cur_tok.len > self.max_token_len) @panic("Token looks corrupted !");
meta.assert(std.mem.eql(u8, cur_tok, input[input_off..][0..cur_tok.len]), "current token '{s}' not found in input string '{s}' !", .{ cur_tok, input[input_off..] }); stdx.debug.assert(std.mem.eql(u8, cur_tok, input[input_off..][0..cur_tok.len]), "current token '{s}' not found in input string '{s}' !", .{ cur_tok, input[input_off..] });
} }
const next_tok = self.tokens[tok_buff[i + 1]]; const next_tok = self.tokens[tok_buff[i + 1]];
// if `next_tok` is `.unk`, length is 1; otherwise, it's the length of the token. // if `next_tok` is `.unk`, length is 1; otherwise, it's the length of the token.

View File

@ -1,9 +1,11 @@
const std = @import("std"); const std = @import("std");
const log = std.log.scoped(.zml_torch); const stdx = @import("stdx");
const zml = @import("zml.zig"); const zml = @import("zml.zig");
const Tensor = zml.Tensor; const Tensor = zml.Tensor;
const meta = zml.meta;
const log = std.log.scoped(.zml_torch);
/// Multiplies a matrix or a vector with a tensor, /// Multiplies a matrix or a vector with a tensor,
/// following the semantic of pytorch `@` operator. /// following the semantic of pytorch `@` operator.
@ -14,7 +16,7 @@ const meta = zml.meta;
/// * `matmul(.{10}, .{10}) -> .{}` /// * `matmul(.{10}, .{10}) -> .{}`
/// * `matmul(.{10}, .{10}) -> .{}` /// * `matmul(.{10}, .{10}) -> .{}`
pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor { pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor {
meta.assert(lhs.rank() >= 1 and rhs.rank() >= 1, "Can't matmul({}, {}) ! The two tensors need to have at least rank 1.", .{ lhs, rhs }); stdx.debug.assert(lhs.rank() >= 1 and rhs.rank() >= 1, "Can't matmul({}, {}) ! The two tensors need to have at least rank 1.", .{ lhs, rhs });
const contracting = [_][2]i8{.{ -1, if (rhs.rank() >= 2) rhs.rank() - 2 else 0 }}; const contracting = [_][2]i8{.{ -1, if (rhs.rank() >= 2) rhs.rank() - 2 else 0 }};
if (lhs.rank() == 1 or rhs.rank() <= 2) { if (lhs.rank() == 1 or rhs.rank() <= 2) {
@ -22,7 +24,7 @@ pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor {
return lhs.dotGeneral(rhs, &contracting, &.{}); return lhs.dotGeneral(rhs, &contracting, &.{});
} }
meta.assert(lhs.rank() == 2, "Can't matmul({}, {}) ! One of the two tensors need to have a rank less than 2.", .{ lhs, rhs }); stdx.debug.assert(lhs.rank() == 2, "Can't matmul({}, {}) ! One of the two tensors need to have a rank less than 2.", .{ lhs, rhs });
// Pytorch treats the extra dimensions of rhs has batching dimensions, // Pytorch treats the extra dimensions of rhs has batching dimensions,
// and implicitly broadcast lhs along those. // and implicitly broadcast lhs along those.
@ -91,7 +93,7 @@ pub fn unsqueeze(
self: Tensor, self: Tensor,
axis_: anytype, axis_: anytype,
) Tensor { ) Tensor {
meta.assert(self.rank() < Tensor.MAX_RANK - 1, "Can't unsqueeze {}, it's already at max rank.", .{self}); stdx.debug.assert(self.rank() < Tensor.MAX_RANK - 1, "Can't unsqueeze {}, it's already at max rank.", .{self});
const a = switch (@typeInfo(@TypeOf(axis_))) { const a = switch (@typeInfo(@TypeOf(axis_))) {
.Int, .ComptimeInt => if (axis_ < 0) .Int, .ComptimeInt => if (axis_ < 0)
@as(i8, self.rank()) + 1 + axis_ @as(i8, self.rank()) + 1 + axis_
@ -125,9 +127,9 @@ test unsqueeze {
/// ref: https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#pixelshuffle /// ref: https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#pixelshuffle
pub fn pixelShuffle(tensor: Tensor, upscale_factor: u32) Tensor { pub fn pixelShuffle(tensor: Tensor, upscale_factor: u32) Tensor {
const shape = tensor.shape(); const shape = tensor.shape();
meta.assert(shape.hasTags(.{ .c, .w, .h }), "pixelShuffle({}) is invalide. Missing tags {{.c, .w, .h}}", .{tensor}); stdx.debug.assert(shape.hasTags(.{ .c, .w, .h }), "pixelShuffle({}) is invalide. Missing tags {{.c, .w, .h}}", .{tensor});
meta.assert(@mod(shape.dim(.c), upscale_factor * upscale_factor) == 0, "pixelShuffle({}) is invalide. Number of channels {}, isn't divisible by upscale factor {}**2", .{ tensor, shape.dim(.c), upscale_factor }); stdx.debug.assert(@mod(shape.dim(.c), upscale_factor * upscale_factor) == 0, "pixelShuffle({}) is invalide. Number of channels {}, isn't divisible by upscale factor {}**2", .{ tensor, shape.dim(.c), upscale_factor });
const s = tensor.splitAxis(.c, .{ .c = -1, .upscale_h = upscale_factor, .upscale_w = upscale_factor }); const s = tensor.splitAxis(.c, .{ .c = -1, .upscale_h = upscale_factor, .upscale_w = upscale_factor });
const perm = s.shape().contiguousPerm(.{ .h, .upscale_h, .w, .upscale_w }); const perm = s.shape().contiguousPerm(.{ .h, .upscale_h, .w, .upscale_w });
@ -173,7 +175,7 @@ test pixelShuffle {
/// ref: https://pytorch.org/docs/stable/generated/torch.roll.html /// ref: https://pytorch.org/docs/stable/generated/torch.roll.html
pub fn roll(self: Tensor, shifts: []const i64, axes_: []const u8) Tensor { pub fn roll(self: Tensor, shifts: []const i64, axes_: []const u8) Tensor {
// TODO(hugo) accept following syntax: x.roll(.{ .a = 5, .b = 8 }) // TODO(hugo) accept following syntax: x.roll(.{ .a = 5, .b = 8 })
meta.assert(self.rank() > 0 and shifts.len == axes_.len, "Shifts length ({d}) and dims length ({d}) are not equal, we expect the same length.", .{ shifts.len, axes_.len }); stdx.debug.assert(self.rank() > 0 and shifts.len == axes_.len, "Shifts length ({d}) and dims length ({d}) are not equal, we expect the same length.", .{ shifts.len, axes_.len });
if (shifts.len != 1 or axes_.len != 1) { if (shifts.len != 1 or axes_.len != 1) {
const tail_shifts = shifts[1..shifts.len]; const tail_shifts = shifts[1..shifts.len];
@ -219,8 +221,8 @@ pub const MeshgridIndexing = enum { xy, ij };
/// * for ij indexing, outputs are of shape (M, N, P) /// * for ij indexing, outputs are of shape (M, N, P)
/// * for xy indexing, outputs are of shape (N, M, P) /// * for xy indexing, outputs are of shape (N, M, P)
pub fn meshgrid(comptime N: u3, vectors: [N]Tensor, indexing: MeshgridIndexing) [N]Tensor { pub fn meshgrid(comptime N: u3, vectors: [N]Tensor, indexing: MeshgridIndexing) [N]Tensor {
meta.assertComptime(vectors.len >= 1, "Invalid meshgrid. No input.", .{}); stdx.debug.assertComptime(vectors.len >= 1, "Invalid meshgrid. No input.", .{});
meta.assertComptime(vectors.len <= Tensor.MAX_RANK, "Invalid meshgrid(...). Too many inputs: {}", .{vectors.len}); stdx.debug.assertComptime(vectors.len <= Tensor.MAX_RANK, "Invalid meshgrid(...). Too many inputs: {}", .{vectors.len});
if (vectors.len == 1) return vectors; if (vectors.len == 1) return vectors;