From cc969bd53296892ede27c8839cdb7d91b49ca0f9 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Wed, 20 Aug 2025 10:27:54 +0000 Subject: [PATCH] Add experimental `zml.callback` API (renamed from custom_call) and fix `tensor.print()`; update PJRT bindings, host buffer utilities, and related core ZML modules. --- pjrt/ffi.zig | 148 ++++++++---------- pjrt/pjrt.zig | 32 ++-- runtimes/cuda/BUILD.bazel | 2 +- stdx/fmt.zig | 160 ++++++++----------- zml/BUILD.bazel | 3 +- zml/buffer.zig | 57 +++++-- zml/callback.zig | 313 ++++++++++++++++++++++++++++++++++++++ zml/context.zig | 172 +++------------------ zml/exe.zig | 49 ++++-- zml/hostbuffer.zig | 37 ++--- zml/module.zig | 8 +- zml/ops.zig | 27 ---- zml/pjrtx.zig | 11 +- zml/platform.zig | 16 +- zml/tensor.zig | 18 +-- zml/testing.zig | 7 +- zml/zml.zig | 3 +- 17 files changed, 596 insertions(+), 467 deletions(-) create mode 100644 zml/callback.zig diff --git a/pjrt/ffi.zig b/pjrt/ffi.zig index d2517a0..2ca639a 100644 --- a/pjrt/ffi.zig +++ b/pjrt/ffi.zig @@ -2,12 +2,21 @@ const std = @import("std"); const c = @import("c"); +pub const TypeId = c.XLA_FFI_TypeId; const stdx = @import("stdx"); -const pjrtStruct = @import("pjrt.zig").pjrtStruct; +const pjrt = @import("pjrt.zig"); +const Stream = @import("pjrt.zig").Stream; const log = std.log.scoped(.pjrt); +comptime { + if (@typeInfo(TypeId).@"struct".fields.len != 1) @compileError("TypeId has changed"); +} + +/// The signature of a generic custom call. +pub const Handler = fn (*CallFrame) callconv(.c) ?*Error; + pub const ApiVersion = extern struct { pub const major = c.XLA_FFI_API_MAJOR; pub const minor = c.XLA_FFI_API_MINOR; @@ -29,13 +38,13 @@ pub const ExtensionBase = extern struct { }; // Based of https://github.com/openxla/xla/blob/145f836bd5175dc5dd262f716a0c59af2b0297a0/xla/ffi/api/c_api.h#L449 -pub const HandlerTraits = packed struct(u32) { +pub const HandlerTraits = packed struct(c_uint) { /// Calls to FFI handler are safe to trace into the command buffer. 
/// It means that calls to FFI handler always launch exactly the same device operations (can depend on attribute values) /// that can be captured and then replayed. - command_buffer_compatible: u1, + command_buffer_compatible: bool, - __unassigned__: u31, + __unassigned__: u31 = 0, }; pub const Metadata = extern struct { @@ -49,25 +58,6 @@ pub const MetadataExtension = extern struct { metadata: ?*Metadata, }; -pub const ApiError = error{ - Cancelled, - Unknown, - InvalidArgument, - DeadlineExceeded, - NotFound, - AlreadyExists, - PermissionDenied, - ResourceExhausted, - FailedPrecondition, - Aborted, - OutOfRange, - Unimplemented, - Internal, - Unavailable, - DataLoss, - Unauthenticated, -}; - fn TransmuteMixin(comptime T: type, comptime InnerT: type) type { return struct { pub fn to(self: anytype) switch (@TypeOf(self)) { @@ -91,8 +81,8 @@ fn TransmuteMixin(comptime T: type, comptime InnerT: type) type { pub const Api = opaque { pub const inner = TransmuteMixin(Api, c.XLA_FFI_Api).to; - pub fn stream(self: *const Api, context: *const ExecutionContext) *Stream { - var ret = pjrtStruct(c.XLA_FFI_Stream_Get_Args{ + pub fn stream(self: *const Api, context: *const ExecutionContext) *pjrt.Stream { + var ret = pjrt.pjrtStruct(c.XLA_FFI_Stream_Get_Args{ .ctx = @constCast(context.inner()), }); const result = self.inner().XLA_FFI_Stream_Get.?(&ret); @@ -107,8 +97,8 @@ pub const Api = opaque { return @ptrCast(ret.stream.?); } - pub fn allocateDeviceMemory(self: *const Api, context: *const ExecutionContext, size: usize, alignment: usize) ApiError!*anyopaque { - var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Allocate_Args{ + pub fn allocateDeviceMemory(self: *const Api, context: *const ExecutionContext, size: usize, alignment: usize) pjrt.ApiError!*anyopaque { + var ret = pjrt.pjrtStruct(c.XLA_FFI_DeviceMemory_Allocate_Args{ .ctx = @constCast(context.inner()), .size = size, .alignment = alignment, @@ -127,8 +117,8 @@ pub const Api = opaque { return ret.data.?; } - pub fn 
freeDeviceMemory(self: *const Api, context: *const ExecutionContext, data: *anyopaque, size: usize) ApiError!void { - var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Free_Args{ + pub fn freeDeviceMemory(self: *const Api, context: *const ExecutionContext, data: *anyopaque, size: usize) pjrt.ApiError!void { + var ret = pjrt.pjrtStruct(c.XLA_FFI_DeviceMemory_Free_Args{ .ctx = @constCast(context.inner()), .size = size, .data = data, @@ -163,33 +153,31 @@ pub const ExecutionStage = enum(c.XLA_FFI_ExecutionStage) { pub const ExecutionContext = opaque { pub const inner = TransmuteMixin(ExecutionContext, c.XLA_FFI_ExecutionContext).to; - pub fn Context(comptime T: type) type { - return struct { - pub fn get(self: *const ExecutionContext, api: *const Api) ApiError!*T { - const type_id: TypeId = .{ .type_id = T.type_id }; - var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Get_Args{ - .ctx = @constCast(self.inner()), - .type_id = @constCast(&type_id.toCStruct()), - }); - const result = api.inner().XLA_FFI_ExecutionContext_Get.?(&ret); - - if (result) |ffi_error| { - const err = Error.fromInner(ffi_error); - defer err.destroy(api); - log.err("[ExecutionContext.get] {s}", .{err.getMessage(api)}); - - // TODO(Corentin): Retrieve error code from Error when implemented in XLA. - return error.Unknown; - } - - if (ret.data == null) return error.NotFound; - return @ptrCast(@alignCast(ret.data.?)); - } + pub fn getContext(self: *const ExecutionContext, type_id: TypeId, api: *const Api) pjrt.ApiError!*anyopaque { + var ret: c.XLA_FFI_ExecutionContext_Get_Args = .{ + .struct_size = pjrt.pjrtStructSize(c.XLA_FFI_ExecutionContext_Get_Args), + .extension_start = api.inner().extension_start, + .ctx = @ptrCast(@constCast(self)), + .type_id = @constCast(&type_id), + .data = undefined, // set by XLA_FFI_ExecutionContext_Get. 
}; + const maybe_err = api.inner().XLA_FFI_ExecutionContext_Get.?(&ret); + + if (maybe_err) |ffi_error| { + const err = Error.fromInner(ffi_error); + defer err.destroy(api); + log.err("[ExecutionContext.get] {s}", .{err.getMessage(api)}); + + // TODO(Corentin): Retrieve error code from Error when implemented in XLA. + return error.Unknown; + } + + if (ret.data == null) return error.NotFound; + return ret.data.?; } - pub fn getDeviceOrdinal(self: *const ExecutionContext, api: *const Api) ApiError!i32 { - var ret = pjrtStruct(c.XLA_FFI_DeviceOrdinal_Get_Args{ + pub fn getDeviceOrdinal(self: *const ExecutionContext, api: *const Api) pjrt.ApiError!i32 { + var ret = pjrt.pjrtStruct(c.XLA_FFI_DeviceOrdinal_Get_Args{ .ctx = @constCast(self.inner()), }); const result = api.inner().XLA_FFI_DeviceOrdinal_Get.?(&ret); @@ -206,8 +194,10 @@ pub const ExecutionContext = opaque { return ret.device_ordinal; } - pub fn scheduleTask(self: *const ExecutionContext, api: *const Api, task: *const Task, data: *anyopaque) ApiError!void { - var ret = pjrtStruct(c.XLA_FFI_ThreadPool_Schedule_Args{ + const Task = fn (*anyopaque) void; + + pub fn scheduleTask(self: *const ExecutionContext, api: *const Api, task: *const Task, data: *anyopaque) pjrt.ApiError!void { + var ret = pjrt.pjrtStruct(c.XLA_FFI_ThreadPool_Schedule_Args{ .ctx = @constCast(self.inner()), .task = @ptrCast(@alignCast(task)), .data = @ptrCast(@alignCast(data)), @@ -225,23 +215,9 @@ pub const ExecutionContext = opaque { return error.Unknown; } } - - fn getTypeId(type_name: []const u8) TypeId { - const id: i64 = @bitCast(std.hash.Fnv1a_64.hash(type_name)); - - return .{ - .type_id = id, - }; - } }; -const TypeId = c.XLA_FFI_TypeId; - -const Task = fn (*anyopaque) void; - -const Stream = @import("pjrt.zig").Stream; - -const ByteSpan = extern struct { +pub const ByteSpan = extern struct { ptr: [*]const u8, len: usize, @@ -252,7 +228,7 @@ const ByteSpan = extern struct { pub const DataType = enum(c.XLA_FFI_DataType) { invalid = 
c.XLA_FFI_DataType_INVALID, - pred = c.XLA_FFI_DataType_PRED, + bool = c.XLA_FFI_DataType_PRED, i8 = c.XLA_FFI_DataType_S8, i16 = c.XLA_FFI_DataType_S16, i32 = c.XLA_FFI_DataType_S32, @@ -399,6 +375,8 @@ pub const Attrs = extern struct { } }; +/// All informations needed by the user callback, +/// including the list of input/ouput buffers to work on. pub const CallFrame = extern struct { struct_size: usize, extension_start: ?*ExtensionBase, @@ -422,9 +400,11 @@ pub const CallFrame = extern struct { } return false; } -}; -pub const Handler = fn (*CallFrame) callconv(.c) ?*Error; + pub fn stream(call_frame: CallFrame) ?*const pjrt.Stream { + return call_frame.api.stream(call_frame.ctx); + } +}; pub const ErrorCode = enum(c.XLA_FFI_Error_Code) { cancelled = c.XLA_FFI_Error_Code_CANCELLED, @@ -444,7 +424,7 @@ pub const ErrorCode = enum(c.XLA_FFI_Error_Code) { data_loss = c.XLA_FFI_Error_Code_DATA_LOSS, unauthenticated = c.XLA_FFI_Error_Code_UNAUTHENTICATED, - pub fn toApiError(code: ErrorCode) ApiError { + pub fn toApiError(code: ErrorCode) pjrt.ApiError { return switch (code) { .cancelled => error.Cancelled, .unknown => error.Unknown, @@ -470,8 +450,10 @@ pub const Error = opaque { pub const inner = TransmuteMixin(Error, c.XLA_FFI_Error).to; pub const fromInner = TransmuteMixin(Error, c.XLA_FFI_Error).from; + pub const ok: ?*Error = null; + pub fn create(api: *const Api, error_code: ErrorCode, message: []const u8) *Error { - var ret = pjrtStruct(c.XLA_FFI_Error_Create_Args{ + var ret = pjrt.pjrtStruct(c.XLA_FFI_Error_Create_Args{ .message = message.ptr, .errc = @intFromEnum(error_code), }); @@ -479,12 +461,12 @@ pub const Error = opaque { } pub fn destroy(err: *Error, api: *const Api) void { - var ret = pjrtStruct(c.XLA_FFI_Error_Destroy_Args{ .@"error" = err.inner() }); + var ret = pjrt.pjrtStruct(c.XLA_FFI_Error_Destroy_Args{ .@"error" = err.inner() }); api.inner().XLA_FFI_Error_Destroy.?(&ret); } pub fn getMessage(err: *Error, api: *const Api) [:0]const u8 { - var 
ret = pjrtStruct(c.XLA_FFI_Error_GetMessage_Args{ + var ret = pjrt.pjrtStruct(c.XLA_FFI_Error_GetMessage_Args{ .@"error" = err.inner(), }); api.inner().XLA_FFI_Error_GetMessage.?(&ret); @@ -496,8 +478,8 @@ pub const Future = opaque { pub const inner = TransmuteMixin(Future, c.XLA_FFI_Future).to; pub const fromInner = TransmuteMixin(Future, c.XLA_FFI_Future).from; - pub fn create(api: *const Api) ApiError!*Future { - var ret = pjrtStruct(c.XLA_FFI_Future_Create_Args{}); + pub fn create(api: *const Api) pjrt.ApiError!*Future { + var ret = pjrt.pjrtStruct(c.XLA_FFI_Future_Create_Args{}); const result = api.inner().XLA_FFI_Future_Create.?(&ret); if (result) |ffi_error| { @@ -512,8 +494,8 @@ pub const Future = opaque { return fromInner(ret.future.?); } - pub fn setAvailable(self: *Future, api: *const Api) ApiError!void { - var ret = pjrtStruct(c.XLA_FFI_Future_SetAvailable_Args{ + pub fn setAvailable(self: *Future, api: *const Api) pjrt.ApiError!void { + var ret = pjrt.pjrtStruct(c.XLA_FFI_Future_SetAvailable_Args{ .future = self.inner(), }); @@ -529,8 +511,8 @@ pub const Future = opaque { } } - pub fn setError(self: *Future, api: *const Api, err: *Error) ApiError!void { - var ret = pjrtStruct(c.XLA_FFI_Future_SetError_Args{ + pub fn setError(self: *Future, api: *const Api, err: *Error) pjrt.ApiError!void { + var ret = pjrt.pjrtStruct(c.XLA_FFI_Future_SetError_Args{ .future = self.inner(), .@"error" = err.inner(), }); diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index 6e35856..08ccf21 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -20,7 +20,7 @@ test { // as the way PJRT does it is not very robust. // // 1. 
https://github.com/openxla/xla/issues/10032 -fn pjrtStructSize(comptime T: type) usize { +pub fn pjrtStructSize(comptime T: type) usize { // unsafe on purpose, we want this to fail if that ever changes const typedef_name = comptime blk: { const needle = ".struct_"; @@ -164,7 +164,7 @@ pub const Api = struct { return @ptrCast(ret.context.?); } - pub fn ffi(api: *const Api) ?FFI { + pub fn ffi(api: *const Api) ?Ffi { if (api.lookupExtension(c.PJRT_FFI_Extension, c.PJRT_Extension_Type_FFI)) |ext| { return .{ .inner = ext }; } @@ -1278,7 +1278,7 @@ pub const NamedValue = extern struct { } }; -pub const FFI = extern struct { +pub const Ffi = extern struct { inner: *const c.PJRT_FFI, pub const UserData = extern struct { @@ -1293,19 +1293,15 @@ pub const FFI = extern struct { } }; - pub const RegisterFfiOptions = struct { - traits: RegisterHandlerTraits = @enumFromInt(0), - }; - // todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize // introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5 pub fn register( - self: *const FFI, + self: *const Ffi, api: *const Api, target_name: []const u8, platform_name: []const u8, func: *const ffi.Handler, - options: RegisterFfiOptions, + traits: ffi.HandlerTraits, ) ApiError!void { var ret = pjrtStruct(c.PJRT_FFI_Register_Handler_Args{ .target_name = target_name.ptr, @@ -1313,7 +1309,7 @@ pub const FFI = extern struct { .handler = @ptrCast(@constCast(func)), .platform_name = platform_name.ptr, .platform_name_size = platform_name.len, - .traits = @intFromEnum(options.traits), + .traits = @bitCast(traits), }); const result = self.inner.register_handler.?(&ret); if (result) |pjrt_c_error| { @@ -1323,8 +1319,7 @@ pub const FFI = extern struct { } } - pub fn registerTypeId(self: *const FFI, api: *const Api, T: type) ApiError!void { - const type_name = @typeName(T); + pub fn registerTypeId(self: *const Ffi, api: *const Api, 
type_name: []const u8) ApiError!ffi.TypeId { var ret = pjrtStruct(c.PJRT_FFI_TypeID_Register_Args{ .type_name = type_name.ptr, .type_name_size = type_name.len, @@ -1336,10 +1331,10 @@ pub const FFI = extern struct { return pjrt_error.getCode(api).toApiError(); } - T.type_id = ret.type_id; + return .{ .type_id = ret.type_id }; } - pub fn addUserData(self: *const FFI, api: *const Api, context: *ExecuteContext, user_data: UserData) ApiError!void { + pub fn addUserData(self: *const Ffi, api: *const Api, context: *ExecuteContext, user_data: UserData) ApiError!void { var ret = pjrtStruct(c.PJRT_FFI_UserData_Add_Args{ .context = @ptrCast(context), .user_data = user_data.toCStruct(), @@ -1352,12 +1347,3 @@ pub const FFI = extern struct { } } }; - -pub const RegisterHandlerTraits = enum(c.PJRT_FFI_Handler_TraitsBits) { - command_buffer_compatible = c.PJRT_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE, - _, -}; - -pub const CustomCallRegistry = extern struct { - inner: *const c.PJRT_FFI_Register_Handler, -}; diff --git a/runtimes/cuda/BUILD.bazel b/runtimes/cuda/BUILD.bazel index 96cc587..83109cf 100644 --- a/runtimes/cuda/BUILD.bazel +++ b/runtimes/cuda/BUILD.bazel @@ -5,7 +5,7 @@ zig_shared_library( name = "zmlxcuda", # Use Clang's compiler-rt, but disable stack checking # to avoid requiring on the _zig_probe_stack symbol. 
- copts = ["-fno-stack-check"], + copts = ["-fno-stack-check", "-fllvm"], main = "zmlxcuda.zig", shared_lib_name = "libzmlxcuda.so.0", visibility = ["@libpjrt_cuda//:__subpackages__"], diff --git a/stdx/fmt.zig b/stdx/fmt.zig index 259b3cd..145db76 100644 --- a/stdx/fmt.zig +++ b/stdx/fmt.zig @@ -1,145 +1,117 @@ const std = @import("std"); -pub const Fmt = union(enum) { - int: IntFmt, - float: FloatFmt, - generic: void, +pub fn slice(any_slice: anytype) FmtSlice(std.meta.Elem(@TypeOf(any_slice))) { + return .{ .slice = any_slice }; +} - pub fn parse(T: type, comptime fmt_: []const u8) Fmt { - return switch (@typeInfo(T)) { - .float, .comptime_float => .{ .float = FloatFmt.parseComptime(fmt_) }, - .int, .comptime_int => .{ .int = IntFmt.parseComptime(fmt_) }, - else => .{ .generic = {} }, - }; - } -}; +fn FmtSlice(T: type) type { + return struct { + slice: []const T, -pub const FullFormatOptions = struct { - fmt: Fmt, - options: std.fmt.FormatOptions, -}; - -pub const IntFmt = struct { - base: u8, - case: std.fmt.Case = .lower, - - pub fn parseComptime(comptime fmt_: []const u8) IntFmt { - return parse(fmt_) catch @panic("invalid fmt for int: " ++ fmt_); - } - - pub fn parse(fmt_: []const u8) error{InvalidArgument}!IntFmt { - return if (fmt_.len == 0 or std.mem.eql(u8, fmt_, "d")) - .{ .base = 10, .case = .lower } - else if (std.mem.eql(u8, fmt_, "x")) - .{ .base = 16, .case = .lower } - else if (std.mem.eql(u8, fmt_, "X")) - .{ .base = 16, .case = .upper } - else if (std.mem.eql(u8, fmt_, "o")) - .{ .base = 8, .case = .upper } - else - // TODO: unicode/ascii - error.InvalidArgument; - } -}; - -pub const FloatFmt = enum(u8) { - scientific = @intFromEnum(std.fmt.Number.Mode.scientific), - decimal = @intFromEnum(std.fmt.Number.Mode.decimal), - hex, - - pub fn parseComptime(comptime fmt_: []const u8) FloatFmt { - return parse(fmt_) catch @panic("invalid fmt for float: " ++ fmt_); - } - - pub fn parse(fmt_: []const u8) error{InvalidArgument}!FloatFmt { - return if 
(fmt_.len == 0 or std.mem.eql(u8, fmt_, "e")) - .scientific - else if (std.mem.eql(u8, fmt_, "d")) - .decimal - else if (std.mem.eql(u8, fmt_, "x")) - .hex - else - error.InvalidArgument; - } -}; - -pub fn formatValue(value: anytype, full: FullFormatOptions, writer: anytype) !void { - return switch (@typeInfo(@TypeOf(value))) { - .comptime_float, .float => try formatFloatValue(value, full, writer), - .comptime_int, .int => try formatIntValue(value, full, writer), - else => try formatAnyValue(value, full, writer), + pub fn formatNumber(f: @This(), writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void { + return switch (@typeInfo(T)) { + .comptime_float, .float => try formatFloatSlice(f.slice, n, writer), + .comptime_int, .int => try formatIntSlice(f.slice, n, writer), + .bool => try formatBoolSlice(f.slice, n, writer), + .@"struct" => if (@hasField(T, "re") and @hasField(T, "im")) { + try formatComplexSlice(f.slice, n, writer); + } else if (@hasDecl(T, "toF32")) { + try formatFloatSlice(f.slice, n, writer); + } else { + try formatSliceAny(f.slice, n, writer); + }, + else => @compileError("FmtSlice doesn't support type: " ++ @typeName(T)), + }; + } }; } -pub fn formatFloatValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void { +pub fn formatFloat(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void { const x = switch (@typeInfo(@TypeOf(value))) { .@"struct" => value.toF32(), .float => value, - else => @compileError("formatFloatValue expects a float, got: " ++ @typeName(@TypeOf(value))), - }; - try switch (full.fmt.float) { - .scientific => writer.printFloat(x, .{ .mode = .scientific, .precision = full.options.precision }), - .decimal => writer.printFloat(x, .{ .mode = .decimal, .precision = full.options.precision }), - .hex => writer.printFloatHexOptions(x, .{ .mode = .hex }), + else => @compileError("formatFloat expects a float, got: " ++ @typeName(@TypeOf(value))), }; + return writer.printFloat(x, spec); } -pub fn 
formatIntValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void { +pub fn formatInt(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void { switch (@typeInfo(@TypeOf(value))) { .int => {}, - else => @compileError("formatIntValue expects an int, got: " ++ @typeName(@TypeOf(value))), + else => @compileError("formatInt expects an int, got: " ++ @typeName(@TypeOf(value))), } - return writer.printInt(value, full.fmt.int.base, full.fmt.int.case, full.options); + return writer.printInt(value, spec.mode.base().?, spec.case, .{ .alignment = spec.alignment, .fill = spec.fill }); } -pub fn formatAnyValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void { +pub fn formatComplex(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void { + try writer.writeAll(".{.re="); + try writer.printFloat(value.re, spec); + try writer.writeAll(", .im="); + try writer.printFloat(value.im, spec); + try writer.writeAll("}"); +} + +pub fn formatBool(value: bool, spec: std.fmt.Number, writer: *std.Io.Writer) !void { + try writer.alignBufferOptions(if (value) "1" else "0", .{ .alignment = spec.alignment, .fill = spec.fill }); +} + +pub fn formatAny(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void { var buf: [48]u8 = undefined; - const s = std.fmt.bufPrint(&buf, "{any}", .{value}) catch blk: { + const T = @TypeOf(value); + const fmt = if (@hasDecl(T, "formatNumber")) "{d}" else "{f}"; + + const s = std.fmt.bufPrint(&buf, fmt, .{value}) catch blk: { buf[45..].* = "...".*; break :blk buf[0..]; }; - return try writer.alignBufferOptions(s, full.options); + return try writer.alignBufferOptions(s, .{ .alignment = spec.alignment, .fill = spec.fill }); } -pub fn formatSliceCustom(fmt_func: anytype, values: anytype, full: FullFormatOptions, writer: anytype) !void { - - // Write first rows - const num_cols: usize = full.options.width orelse 12; +pub fn formatSliceCustom(fmt_func: anytype, values: anytype, spec: 
std.fmt.Number, writer: *std.Io.Writer) !void { + // use the format "width" for the number of columns instead of individual width. + const num_cols: usize = spec.width orelse 12; + var my_options = spec; + my_options.width = null; const n: usize = values.len; + _ = try writer.write("{"); if (n <= num_cols) { for (values, 0..) |v, i| { // Force inlining so that the switch and the buffer can be done once. - try @call(.always_inline, fmt_func, .{ v, full, writer }); + try @call(.always_inline, fmt_func, .{ v, my_options, writer }); if (i < n - 1) _ = try writer.write(","); } } else { const half = @divFloor(num_cols, 2); for (values[0..half]) |v| { - try @call(.always_inline, fmt_func, .{ v, full, writer }); + try @call(.always_inline, fmt_func, .{ v, my_options, writer }); _ = try writer.write(","); } _ = try writer.write(" ..., "); for (values[n - half ..], 0..) |v, i| { - try @call(.always_inline, fmt_func, .{ v, full, writer }); + try @call(.always_inline, fmt_func, .{ v, my_options, writer }); if (i < half - 1) _ = try writer.write(","); } } _ = try writer.write("}"); } -pub fn formatAny(values: anytype, full: FullFormatOptions, writer: anytype) !void { - return try formatSliceCustom(formatAnyValue, values, full, writer); +pub fn formatSliceAny(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void { + return try formatSliceCustom(formatAny, values, spec, writer); } -pub fn formatFloatSlice(values: anytype, full: FullFormatOptions, writer: anytype) !void { - return try formatSliceCustom(formatFloatValue, values, full, writer); +pub fn formatFloatSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void { + return try formatSliceCustom(formatFloat, values, spec, writer); } -pub fn formatIntSlice(values: anytype, full: FullFormatOptions, writer: anytype) !void { - return try formatSliceCustom(formatIntValue, values, full, writer); +pub fn formatIntSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void { + return try 
formatSliceCustom(formatInt, values, spec, writer); } -pub fn formatAnySlice(values: anytype, full: FullFormatOptions, writer: anytype) !void { - return try formatSliceCustom(formatAnyValue, values, full, writer); +pub fn formatComplexSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void { + return try formatSliceCustom(formatComplex, values, spec, writer); +} + +pub fn formatBoolSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void { + return try formatSliceCustom(formatBool, values, spec, writer); } diff --git a/zml/BUILD.bazel b/zml/BUILD.bazel index aaa44f0..3d2ae08 100644 --- a/zml/BUILD.bazel +++ b/zml/BUILD.bazel @@ -31,6 +31,7 @@ zig_library( "aio/torch/py.zig", "buffer.zig", "context.zig", + "callback.zig", "dtype.zig", "exe.zig", "floats.zig", @@ -53,7 +54,7 @@ zig_library( "torch.zig", "zml.zig", ], - copts = ["-lc"], + copts = ["-lc", "-freference-trace=20"], main = "zml.zig", visibility = ["//visibility:public"], deps = [ diff --git a/zml/buffer.zig b/zml/buffer.zig index 3553bb9..f502d70 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -49,7 +49,7 @@ pub const Buffer = struct { pub const FromOptions = struct { wait: bool = true, - memory: ?pjrt.Memory.Kind = null, + memory: ?Memory = null, }; /// Copies the content of the given buffer from host memory to the accelerator memory. 
@@ -89,15 +89,20 @@ pub const Buffer = struct { .byte_strides = byte_strides, .host_buffer_semantics = .ImmutableUntilTransferCompletes, }; - if (opts.memory) |memory_kind| { - const memories = try devices[i].addressableMemories(platform.pjrt_api); - const memory = for (memories) |m| { - const kind = m.kind(platform.pjrt_api); - if (kind == memory_kind) break m; - } else return error.NotFound; - args.memory = memory; - } else { + if (platform.target == .cpu or opts.memory == null) { args.device = devices[i]; + } else { + const memory = opts.memory.?; + const device_memories = try devices[i].addressableMemories(platform.pjrt_api); + // TODO measure the cost of this and consider caching on Zig side inside the platform. + const selected_memory = for (device_memories) |m| { + const kind = m.kind(platform.pjrt_api); + if (kind == memory.toPjrtMemory()) break m; + } else { + log.warn("Platform {s} doesn't have memory {s}", .{ @tagName(platform.target), @tagName(memory) }); + return error.NotFound; + }; + args.memory = selected_memory; } const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, args); @@ -179,10 +184,10 @@ pub const Buffer = struct { return try from(platform, host_buffer, opts); } - pub fn asPinnedHostBuffer(self: Buffer) HostBuffer { - // TODO restore assert + pub fn asHostBuffer(self: Buffer) HostBuffer { + // TODO: skip this check on cpu // const memory = self.getMemory().kind(self._api); - // stdx.debug.assert(memory == .pinned_host, "asPinnedHostBuffer({}) expects a buffer allocated on host memory, got {}. see `toMemory`", .{ self, memory }); + // stdx.debug.assert((memory == .pinned_host) or (memory == .unpinned_host), "asHostBuffer({f}) expects a buffer allocated on host memory, got {t}. 
see `copyToMemory`", .{ self, memory }); const ptr: [*]u8 = @ptrCast(self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable); return HostBuffer.fromBytes(self._shape, ptr[0..self._shape.byteSize()]); } @@ -299,6 +304,12 @@ pub const Buffer = struct { }; } + pub fn opaqueDeviceMemoryDataPointer(self: Buffer) [*]u8 { + stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer", .{}); + const opaque_ptr: *anyopaque = self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable; + return @ptrCast(opaque_ptr); + } + /// Fetches the content of the given buffer into a stack variable of the given type. pub fn getValue(self: Buffer, T: type) !T { stdx.debug.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {f} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) }); @@ -390,13 +401,31 @@ pub const Buffer = struct { return @reduce(.Or, self._shape._sharding_info); } - pub fn copyToMemory(self: Buffer, memory: *const pjrt.Memory) !Buffer { + pub const CopyToMemoryOpts = struct { + wait: bool = true, + }; + + pub fn copyToMemory(self: Buffer, platform: Platform, memory: Memory, opts: CopyToMemoryOpts) !Buffer { + const pjrt_memory = platform.pjrt_client.memoryByKind(self._api, memory.toPjrtMemory()); + if (pjrt_memory == null) { + stdx.debug.panic("Memory destination `{s}` for {f}", .{ memory.pjrtName(), self }); + } + var new_shards: Buffer.Shards = .{}; for (self._shards.slice()) |shard| { - const new_shard = try shard.copyToMemory(self._api, memory); + const new_shard = try shard.copyToMemory(self._api, pjrt_memory.?); new_shards.appendAssumeCapacity(new_shard); } + if (opts.wait) { + for (new_shards.constSlice()) |shard| { + const event = shard.getReadyEvent(self._api); + if (event) |e| { + try e.awaitBlocking(self._api); + } + } + } + return Buffer{ ._shape = self._shape, ._shards = new_shards, ._api = self._api }; } diff --git 
a/zml/callback.zig b/zml/callback.zig new file mode 100644 index 0000000..4d6ab90 --- /dev/null +++ b/zml/callback.zig @@ -0,0 +1,313 @@ +const std = @import("std"); + +const asynk = @import("async"); +const mlir = @import("mlir"); +const pjrt = @import("pjrt"); +const stablehlo = @import("mlir/dialects").stablehlo; +const stdx = @import("stdx"); + +const Buffer = @import("buffer.zig").Buffer; +const CompilationContext = @import("module.zig").CompilationContext; +const DataType = @import("dtype.zig").DataType; +const HostBuffer = @import("hostbuffer.zig").HostBuffer; +const mlirx = @import("mlirx.zig"); +const pjrtx = @import("pjrtx.zig"); +const Platform = @import("platform.zig").Platform; +const Shape = @import("shape.zig").Shape; +const Tensor = @import("tensor.zig").Tensor; + +const log = std.log.scoped(.@"zml/callback"); + +/// Inserts a user-defined callback into the computation graph. +/// The callback is defined with a struct, that store runtime information needed by the callback. +/// +/// ❗Experimental API❗ +/// +/// ```zig +/// pub const MyCallback = struct { +/// // a unique type_id will be set by the PJRT plugin during registration. +/// pub var type_id: pjrt.ffi.TypeId = undefined; +/// +/// pub const callback_config: zml.callback.Config = .{ +/// // assumption this custom call makes about the input / output buffers +/// }; +/// +/// // Required, this will tell the callback in which env it runs. +/// platform: zml.Platform, +/// // data needed by the callback +/// my_data: []const u8, +/// +/// // storage modified by the runtime to tell the callback where it should write its results. +/// // Normally the callback doesn't need to allocate as the input and output buffers are given. 
+/// results: [1]Buffer = undefined, +/// +/// pub fn init(my_data: []const u8) !MyCallback { +/// return .{ .my_data = my_data }; +/// } +/// +/// pub fn call(callback: *MyCallback, input: Buffer) !void { +/// // Do something with `input` and `callback.my_data`, write the results inside `callback.results[0]` +/// } +/// }; +/// ``` +/// +/// See eg the implementation of the `zml.callback.Print` callback, for a practical example. +/// +/// Note calling this during the compilation of a module, isn't enough: +/// +/// * backend need to be made aware of the callback, see `zml.Platform.registerCallback` +/// * executable need to know the specific data needed by `MyCallback`, see `zml.Exe.bind` +pub fn call( + comptime Callback: type, + inputs: TensorArgs(Callback), + output_shapes: []const Shape, +) []Tensor { + checkIsValidCallback(Callback); + + const ctx = CompilationContext.current(); + const allocator = ctx.allocator(); + const mlir_ctx = ctx.mlirCtx(); + const platform = ctx._platform; + const pjrt_api = platform.pjrt_api; + + if (pjrt_api.ffi() == null) { + stdx.debug.panic("Custom calls are not supported for target {s}", .{@tagName(platform.target)}); + } + + const output_tensors = allocator.alloc(Tensor, output_shapes.len) catch @panic("OOM"); + // Note: we don't always free output_tensor, because it's returned to the caller. + // It's also why we allocate it first so that it doesn't fragment the arena. 
+ errdefer allocator.free(output_tensors); + + const output_types = allocator.alloc(mlir.Type, output_shapes.len) catch @panic("OOM"); + defer allocator.free(output_types); + for (output_types, output_shapes) |*output_type, output_shape| { + output_type.* = mlirx.tensorType(mlir_ctx, output_shape); + } + const input_values = allocator.alloc(mlir.Value, inputs.len) catch @panic("OOM"); + defer allocator.free(input_values); + for (input_values, inputs) |*input_value, input_tensor| { + input_value.* = input_tensor.value(); + } + + const target_name = "zml$" ++ @typeName(Callback); + const op = stablehlo.custom_call( + mlir_ctx, + input_values, + .{ + .call_target_name = target_name, + .api_version = .typed_ffi, + .backend_config = .dict(mlir_ctx, &.{}), + .additional_attributes = &.{.{ "mhlo.frontend_attributes", .dict(mlir_ctx, &.{}) }}, + .has_side_effect = true, + .output_operand_aliases = Callback.callback_config.output_operand_aliases, + }, + output_types, + mlir_ctx.location(@src()), + ); + + for (output_tensors, output_shapes, 0..) |*output_tensor, output_shape, i| { + output_tensor.* = Tensor._result(output_shape, op.result(i)); + } + return output_tensors; +} + +/// Describe properties of a callback +/// +/// * output_operand_aliases: the callback reuse input buffer to write the output +/// * copy_inputs_to_host_pinned: the callback need to work on host visible buffers +/// * traits: PJRT specified properties of the callback +pub const Config = struct { + output_operand_aliases: []const i64 = &.{}, + copy_inputs_to_host_pinned: bool = false, + // TODO: document precisely what `command_buffer_compatible` is doing and its limitations. + traits: pjrt.ffi.HandlerTraits = .{ .command_buffer_compatible = false }, + // TODO: handle sharded inputs +}; + +/// Compile-time check that a callback has all informations we require. 
+pub fn checkIsValidCallback(Callback: type) void { + stdx.debug.assertComptime(@hasDecl(Callback, "call"), "Expected callback {} to have a call method", .{Callback}); + const ArgsT = stdx.meta.FnArgs(Callback.call); + inline for (@typeInfo(ArgsT).@"struct".fields[1..]) |field| { + stdx.debug.assertComptime(field.type == Buffer, "Expected callback {}.call arguments to be of type zml.Buffer, got {}", .{ Callback, field.type }); + } + + stdx.debug.assertComptime(@hasDecl(Callback, "type_id") and @TypeOf(Callback.type_id) == pjrt.ffi.TypeId, "Expected callback {} to have a field `pub var type_id: pjrt.ffi.TypeId`", .{Callback}); + stdx.debug.assertComptime(@hasDecl(Callback, "callback_config") and @TypeOf(Callback.callback_config) == Config, "Expected callback {} to have a field `pub const callback_config: zml.CustomCallOptions`", .{Callback}); +} + +pub fn register(Callback: type, platform: Platform) pjrt.ApiError!void { + checkIsValidCallback(Callback); + + const ffi = platform.pjrt_api.ffi() orelse return error.Unavailable; + const target_name = "zml$" ++ @typeName(Callback); + + const proxy_cb = proxy(Callback); + Callback.type_id = try ffi.registerTypeId(platform.pjrt_api, @typeName(Callback)); + try ffi.register(platform.pjrt_api, target_name, @tagName(platform.target), &proxy_cb, Callback.callback_config.traits); + log.debug("Registered custom call {} with target name \"{s}\"", .{ Callback, target_name }); +} + +fn proxy(Callback: type) pjrt.ffi.Handler { + return struct { + pub fn cb(call_frame: *pjrt.ffi.CallFrame) callconv(.c) ?*pjrt.ffi.Error { + return CallbackImpl(Callback, call_frame); + } + }.cb; +} + +fn CallbackImpl(comptime Callback: type, call_frame: *pjrt.ffi.CallFrame) ?*pjrt.ffi.Error { + if (call_frame.registeringHook()) return null; + + const opts = Callback.callback_config; + + const execution_context = call_frame.ctx; + log.debug("Custom call {s} called !", .{@typeName(Callback)}); + const user_ctx_opaque = 
execution_context.getContext(Callback.type_id, call_frame.api) catch { + log.err("{} user data was never given for current executable", .{Callback}); + return .create(call_frame.api, .failed_precondition, "failed to fetch user context" ++ @typeName(Callback)); + }; + const user_ctx: *Callback = @ptrCast(@alignCast(user_ctx_opaque)); + // We actually have one more constraint here, we force the Callback to have a platform field, + // and to correctly set it. + // Is this good ? We could also simplify this by registering ourselves the `Platform` type id. + const platform: Platform = user_ctx.platform; + + // Hook to get a cuda stream in the callback. + if (@hasField(Callback, "stream") and platform.target != .cpu) { + const stream = call_frame.api.stream(execution_context); + user_ctx.stream = stream; + } + + var callback_args: std.meta.ArgsTuple(@TypeOf(Callback.call)) = undefined; + callback_args[0] = user_ctx; + + inline for (1..callback_args.len, call_frame.args.buffers()) |i, ffi_buffer| { + const shape = shapeFromFfi(ffi_buffer); + var zml_buffer: Buffer = if (platform.target == .cpu) + .asViewOfHostBuffer(platform, .fromBytes(shape, ffi_buffer.data[0..shape.byteSize()])) + else + .asViewOfDeviceBuffer(platform, shape, null, ffi_buffer.data); + if (opts.copy_inputs_to_host_pinned and platform.target != .cpu) { + log.debug("Copying argument {d} {f} {*} to host_pinned memory !", .{ i, zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer() }); + zml_buffer = zml_buffer.copyToMemory(platform, .host_pinned, .{ .wait = true }) catch |err| { + log.err("Failed to copy input buffer {d} {f} {*} to host_pinned: {}", .{ i, zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer(), err }); + return .create(call_frame.api, .resource_exhausted, "host pinned OOM"); + }; + log.debug("--> {f} {*} ({})", .{ zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer(), @as(*const f32, @ptrCast(@alignCast(zml_buffer.opaqueDeviceMemoryDataPointer()))).* }); + } + callback_args[i] = 
zml_buffer; + } + + defer { + if (opts.copy_inputs_to_host_pinned and platform.target != .cpu) { + inline for (1..callback_args.len) |i| callback_args[i].deinit(); + } + } + + for (0..call_frame.results.len) |i| { + const ffi_buffer = call_frame.results.buffers()[i]; + const ffi_buffer_shape = shapeFromFfi(ffi_buffer); + + if (platform.target == .cpu) { + user_ctx.results[i] = Buffer.asViewOfHostBuffer(platform, HostBuffer.fromBytes(ffi_buffer_shape, ffi_buffer.data[0..ffi_buffer_shape.byteSize()])); + } else { + user_ctx.results[i] = Buffer.asViewOfDeviceBuffer(platform, shapeFromFfi(ffi_buffer), null, ffi_buffer.data); + } + } + + @call(.auto, Callback.call, callback_args) catch |err| { + log.err("Callback {} failed with {}", .{ Callback, err }); + return .create(call_frame.api, .internal, "internal callback error"); + }; + + return .ok; +} + +/// Internal custom calls. +/// These are not meant to be used by users, but rather by the library itself. +pub const internal_callbacks = [_]type{ + Print, +}; + +pub fn registerInternalCallbacks(platform: Platform) !void { + inline for (internal_callbacks) |Callback| { + try register(Callback, platform); + // log.debug("Registered internal custom call {s} with type_id {d}", .{ @typeName(Callback), Callback.type_id.type_id }); + } +} + +/// Allocate user data data needed by the ZML provided custom calls. +pub fn bindInternalCallbacks( + arena: std.mem.Allocator, + platform: Platform, + ffi: pjrt.Ffi, + execute_context: *pjrt.ExecuteContext, +) (std.mem.Allocator.Error || pjrt.ApiError)!void { + // Atm we don't have a mechanism to detect which ZML callbacks the executable needs, + // so we always allocate. 
+ { + // Print + const print_ptr = try arena.create(Print); + print_ptr.* = try .init(platform); + try addUserData(Print, platform.pjrt_api, ffi, execute_context, print_ptr); + } +} + +pub fn addUserData( + Callback: type, + api: *const pjrt.Api, + ffi: pjrt.Ffi, + execute_context: *pjrt.ExecuteContext, + user_data: *Callback, +) pjrt.ApiError!void { + try ffi.addUserData( + api, + execute_context, + .{ .type_id = Callback.type_id.type_id, .user_data = @ptrCast(user_data) }, + ); + log.debug("Bound {s}@{x} with type id {d} on {any}", .{ @typeName(Callback), @intFromPtr(user_data), Callback.type_id.type_id, execute_context }); +} + +/// The print callback +pub const Print = struct { + // a unique type_id will be set by the PJRT plugin during registration. + pub var type_id: pjrt.ffi.TypeId = undefined; + + pub const callback_config: Config = .{ + // Print callback pretends to modify the given input buffer, but just returns it unmodified. + .output_operand_aliases = &.{0}, + // It also needs PJRT to copy the data on the host first so it can print it. + .copy_inputs_to_host_pinned = true, + // Print is fairly predictable and can be captured in an execution graph. 
+ .traits = .{ .command_buffer_compatible = false }, + }; + + platform: Platform, + results: [1]Buffer = undefined, + + pub fn init(platform: Platform) !Print { + return .{ .platform = platform }; + } + + pub fn call(_: *Print, input: Buffer) !void { + std.log.defaultLog(.info, .zml, "Device buffer: {f}: {d:20.3}", .{ input, input.asHostBuffer() }); + } +}; + +fn shapeFromFfi(ffi_buffer: *const pjrt.ffi.Buffer) Shape { + const dt: DataType = switch (ffi_buffer.dtype) { + .invalid => stdx.debug.panic("Invalid FFI dtype {any} used by {any}", .{ ffi_buffer.dtype, ffi_buffer }), + .token, .f8e4m3, .f8e3m4 => stdx.debug.panic("Unsupported FFI dtype {any} used by {any}", .{ ffi_buffer.dtype, ffi_buffer }), + inline else => |t| @field(DataType, @tagName(t)), + }; + return Shape.init(ffi_buffer.dims(), dt); +} + +fn TensorArgs(Callback: type) type { + const ArgsT = stdx.meta.FnArgs(Callback.call); + + const args = @typeInfo(ArgsT).@"struct".fields; + return [args.len - 1]Tensor; +} diff --git a/zml/context.zig b/zml/context.zig index 15acb93..27d635a 100644 --- a/zml/context.zig +++ b/zml/context.zig @@ -7,16 +7,19 @@ const runfiles = @import("runfiles"); const runtimes = @import("runtimes"); const stdx = @import("stdx"); -const DataType = @import("dtype.zig").DataType; -const HostBuffer = @import("hostbuffer.zig").HostBuffer; const pjrt = @import("pjrtx.zig"); -const Platform = @import("platform.zig").Platform; -const Shape = @import("shape.zig").Shape; -const Target = @import("platform.zig").Target; -const zml_platform = @import("platform.zig"); -const PjrtApiMap = std.EnumArray(Target, ?*const pjrt.Api); -const PlatformsMap = std.EnumArray(Target, ?Platform); +const zml = struct { + pub const callback = @import("callback.zig"); + pub const HostBuffer = @import("hostbuffer.zig").HostBuffer; + pub const Platform = @import("platform.zig").Platform; + pub const platform = @import("platform.zig"); + pub const Shape = @import("shape.zig").Shape; + pub const Target = 
@import("platform.zig").Target; +}; + +const PjrtApiMap = std.EnumArray(zml.Target, ?*const pjrt.Api); +const PlatformsMap = std.EnumArray(zml.Target, ?zml.Platform); const log = std.log.scoped(.@"zml/context"); test { @@ -94,7 +97,7 @@ pub const Context = struct { return .{ .platforms = PlatformsMap.initFill(null) }; } - fn platformToLibrary(comptime target: Target) []const u8 { + fn platformToLibrary(comptime target: zml.Target) []const u8 { const ext = switch (builtin.os.tag) { .windows => ".dll", .macos, .ios, .watchos => ".dylib", @@ -105,7 +108,7 @@ pub const Context = struct { }; } - pub fn pjrtApi(target: Target) *const pjrt.Api { + pub fn pjrtApi(target: zml.Target) *const pjrt.Api { return Context.apis.get(target).?; } @@ -119,12 +122,12 @@ pub const Context = struct { self.* = undefined; } - const prefered_targets = [_]Target{ .tpu, .neuron, .cuda, .rocm, .cpu }; + const prefered_targets = [_]zml.Target{ .tpu, .neuron, .cuda, .rocm, .cpu }; /// Automatically selects the best Platform loaded in the current Context. /// /// For example, if supported, this will select a platform corresponding to an accelerator (GPU, TPU, ...). - pub fn autoPlatform(self: *Context, opts: Platform.CreateOptions) Platform { + pub fn autoPlatform(self: *Context, opts: zml.Platform.CreateOptions) zml.Platform { stdx.debug.assert(prefered_targets.len == apis.values.len, "New target need to be inserted inside `zml.Context.preferred_targets`", .{}); return self.platformByPreferences(opts, &prefered_targets); @@ -133,7 +136,7 @@ pub const Context = struct { /// Given a list of preferred targets to select the best Platform /// /// For example, if supported, this will select a platform corresponding to an accelerator (GPU, TPU, ...). 
- pub fn platformByPreferences(self: *Context, opts: Platform.CreateOptions, prefered: []const Target) Platform { + pub fn platformByPreferences(self: *Context, opts: zml.Platform.CreateOptions, prefered: []const zml.Target) zml.Platform { // Try prefered targets. for (prefered) |target| { if (apis.get(target) == null) continue; @@ -150,7 +153,7 @@ pub const Context = struct { // CPU should only be use as fallback. if (target == .cpu) continue; if (entry.value.* == null) continue; - if (std.mem.indexOfScalar(Target, prefered, target) != null) continue; + if (std.mem.indexOfScalar(zml.Target, prefered, target) != null) continue; return self.platform(target, opts) catch |err| { log.err("Failed to load platform .{s}: {}", .{ @tagName(target), err }); continue; @@ -164,25 +167,25 @@ pub const Context = struct { }; } - pub fn platform(self: *Context, target: Target, opts: Platform.CreateOptions) !Platform { + pub fn platform(self: *Context, target: zml.Target, opts: zml.Platform.CreateOptions) !zml.Platform { if (self.platforms.get(target)) |p| { return p; } const api = Context.apis.get(target); if (api == null) return error.PlatformNotCompiled; - const p = try Platform.init(target, api.?, opts); + const p = try zml.Platform.init(target, api.?, opts); if (p.getDevices().len == 0) { log.err("No device found for platform {} !", .{target}); return error.NoDevicesFound; } - try CustomCall.registerZmlCustomCalls(p); - self.platforms.set(target, p); + try zml.callback.registerInternalCallbacks(p); + return p; } - pub fn printAvailablePlatforms(self: Context, selected: Platform) void { + pub fn printAvailablePlatforms(self: Context, selected: zml.Platform) void { // List available targets log.info("Available Platforms:", .{}); const selected_prefix = "✅"; @@ -190,7 +193,7 @@ pub const Context = struct { const selected_postfix = "(AUTO-SELECTED)"; const not_selected_postfix = ""; - for (zml_platform.available_targets) |target| { + for (zml.platform.available_targets) |target| { 
log.info(" {s} {s} {s}", .{ if (target == selected.target) selected_prefix else not_selected_prefix, @tagName(target), @@ -211,133 +214,4 @@ pub const Context = struct { } } } - - pub const HostCallback = fn (?*anyopaque, []const HostBuffer, []const HostBuffer) void; -}; - -const CustomCall = struct { - pub fn registerZmlCustomCalls(platform: Platform) !void { - const ffi = platform.pjrt_api.ffi() orelse { - log.warn("Registering custom calls failed: No FFI Extension found in {s} PJRT Plugin.", .{@tagName(platform.target)}); - return; - }; - try ffi.register(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback, .{}); - } - - fn hostBufferCallback(call_frame: *pjrt.ffi.CallFrame) callconv(.c) ?*pjrt.ffi.Error { - if (call_frame.registeringHook()) return null; - - const callback_attr = call_frame.attrs.getByName(.scalar, "callback") orelse unreachable; - std.debug.assert(callback_attr.dtype == .u64); - const callback: *const Context.HostCallback = @ptrFromInt(callback_attr.get(usize)); - - const user_ctx_ptr = call_frame.attrs.getByName(.scalar, "user_context") orelse unreachable; - std.debug.assert(user_ctx_ptr.dtype == .u64); - const user_ctx: ?*anyopaque = @ptrFromInt(user_ctx_ptr.get(usize)); - - const input_buffers = stdx.stackSlice(8, HostBuffer, call_frame.args.len); - for (input_buffers, 0..) |*b, i| { - b.* = hostBufferFromPinnedBuffer(call_frame.args.buffers()[i]); - } - - const output_buffers = stdx.stackSlice(8, HostBuffer, call_frame.results.len); - for (output_buffers, 0..) 
|*b, i| { - b.* = hostBufferFromPinnedBuffer(call_frame.results.buffers()[i]); - } - - callback(user_ctx, input_buffers, output_buffers); - return null; - } -}; - -fn getShape(buffer_desc: *const pjrt.ffi.Buffer) Shape { - // log.warn("received buffer {}", .{buffer_desc}); - const dt: DataType = switch (buffer_desc.dtype) { - .invalid => @panic("invalid ffi"), - .pred => .bool, - .i8 => .i8, - .i16 => .i16, - .i32 => .i32, - .i64 => .i64, - .token, .f8e4m3, .f8e3m4 => @panic("Unsupported ffi type"), - inline else => |t| @field(DataType, @tagName(t)), - }; - return Shape.init(buffer_desc.dims(), dt); -} - -/// Create a HostBuffer from a ffi description of a buffer. -/// Normally the ffi describe device buffer but we assume they are located in pinned memory, -/// and therefore the data pointer is readable both from host and from device. -fn hostBufferFromPinnedBuffer(buffer_desc: *const pjrt.ffi.Buffer) HostBuffer { - const buffer_shape = getShape(buffer_desc); - return HostBuffer.fromBytes( - buffer_shape, - buffer_desc.data[0..buffer_shape.byteSize()], - ); -} - -pub const cuda = struct { - pub var streamSynchronize: StreamSynchronize = @ptrFromInt(0xdeadc00da00); - pub var cuLaunchHostFunc: CuLaunchHostFunc = @ptrFromInt(0xdeadc00da00); - var _memcpyAsync: MemcpyAsync = @ptrFromInt(0xdeadc00da00); - var _memcpyBlocking: MemcpyBlocking = @ptrFromInt(0xdeadc00da00); - - pub const MemcpyKind = enum(c_int) { - host_to_host = 0, - host_to_device = 1, - device_to_host = 2, - device_to_device = 3, - inferred = 4, - }; - - const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: ?*anyopaque) callconv(.c) c_int; - const MemcpyBlocking = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind) callconv(.c) c_int; - const StreamSynchronize = *const fn (stream: *anyopaque) callconv(.c) c_int; - const CuLaunchHostFunc = *const fn (stream: *anyopaque, host_func: *const fn (user_data: *const 
anyopaque) callconv(.c) void, user_data: *const anyopaque) callconv(.c) c_int; - - pub fn init() void { - var cudart = std.DynLib.open("libcudart.so.12") catch { - log.err("cudart not found, callback will segfault", .{}); - return; - }; - defer cudart.close(); - - _memcpyAsync = cudart.lookup(MemcpyAsync, "cudaMemcpyAsync") orelse { - @panic("cudaMemcpyAsync not found"); - }; - _memcpyBlocking = cudart.lookup(MemcpyBlocking, "cudaMemcpy") orelse { - @panic("cudaMemcpy not found"); - }; - streamSynchronize = cudart.lookup(StreamSynchronize, "cudaStreamSynchronize") orelse { - @panic("cudaStreamSynchronize not found"); - }; - cuLaunchHostFunc = cudart.lookup(CuLaunchHostFunc, "cudaLaunchHostFunc") orelse { - @panic("cudaLaunchHostFunc not found"); - }; - } - - pub fn memcpyToHostBlocking(dst: []u8, src: *const anyopaque) void { - const err = _memcpyBlocking(dst.ptr, src, dst.len, .device_to_host); - check(err); - } - - pub fn memcpyToDeviceBlocking(dst: *anyopaque, src: []const u8) void { - const err = _memcpyBlocking(dst, src.ptr, src.len, .host_to_device); - check(err); - } - - pub fn memcpyToDeviceAsync(dst: *anyopaque, src: []const u8, stream: ?*anyopaque) void { - const err = _memcpyAsync(dst, src.ptr, src.len, .host_to_device, stream); - check(err); - } - - pub fn memcpyToHostAsync(dst: []u8, src: *const anyopaque, stream: ?*anyopaque) void { - const err = _memcpyAsync(dst.ptr, src, dst.len, .device_to_host, stream); - check(err); - } - - pub fn check(err: c_int) void { - if (err == 0) return; - stdx.debug.panic("CUDA error: {d}", .{err}); - } }; diff --git a/zml/exe.zig b/zml/exe.zig index 625fecf..d4c426b 100644 --- a/zml/exe.zig +++ b/zml/exe.zig @@ -5,6 +5,7 @@ const stdx = @import("stdx"); const aio = @import("aio.zig"); const Buffer = @import("buffer.zig").Buffer; const Bufferized = @import("tensor.zig").Bufferized; +const callback = @import("callback.zig"); const CompilationContext = @import("module.zig").CompilationContext; const meta = 
@import("meta.zig");
 const pjrt = @import("pjrtx.zig");
@@ -154,7 +155,7 @@ pub const BaseExe = struct {
     exe: *pjrt.LoadedExecutable,
 
     /// The execution context for this executable.
-    context: ?*pjrt.ExecuteContext = null,
+    execute_context: ?*pjrt.ExecuteContext,
 
     /// Pre-allocated slice of buffers to use as inputs when the module is called.
     input_per_device: []const [*]*pjrt.Buffer,
@@ -205,9 +206,18 @@ pub const BaseExe = struct {
         const all_shapes = try allocator.alloc(Shape, n_in + n_out);
         @memcpy(all_shapes[0..n_in], args.input_shapes);
         @memcpy(all_shapes[n_in..], args.result_shapes);
+
+        var execute_context: ?*pjrt.ExecuteContext = null;
+        if (platform.pjrt_api.ffi()) |ffi| {
+            execute_context = try platform.pjrt_api.createExecuteContext();
+            log.info("Created execution context {*} for {*}", .{ execute_context, exe });
+            try callback.bindInternalCallbacks(allocator, platform, ffi, execute_context.?);
+        }
+
         return .{
             .platform = platform,
             .exe = exe,
+            .execute_context = execute_context,
             .ready_buffer_count = 0,
             .input_buffer_count = @intCast(n_in),
             .num_devices = args.n_devices,
@@ -220,7 +230,7 @@ pub const BaseExe = struct {
     }
 
     pub fn deinit(self: BaseExe) void {
-        if (self.context) |ctx| {
+        if (self.execute_context) |ctx| {
             ctx.deinit(self.platform.pjrt_api);
         }
         self._arena.deinit();
@@ -244,16 +254,16 @@ pub const BaseExe = struct {
            // even if it has been marked as "can be donated" during compilation.
            // TODO: expose it ?
.non_donatable_input_indices = &.{}, - .context = self.context, + .context = self.execute_context, }) catch |err| { std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err}); }; - for (events[0..sharding.num_partitions]) |e| { - if (e) |ev| { - ev.await_(self.platform.pjrt_api) catch unreachable; - } - } + // for (events[0..sharding.num_partitions]) |e| { + // if (e) |ev| { + // ev.await_(self.platform.pjrt_api) catch unreachable; + // } + // } } pub fn _unsafeAssignResults(self: BaseExe, T: type, result: *T) void { @@ -285,6 +295,17 @@ pub const BaseExe = struct { stdx.debug.internalAssert(local_ctx.index == self.result_shapes.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ self.output_per_device.len, @typeName(T), local_ctx.index }); } + pub fn bind(exe: BaseExe, Callback: type, op: *Callback) !void { + stdx.debug.assert(exe.execute_context != null, "Exe doesn't have an execution context", .{}); + const pjrt_api = exe.platform.pjrt_api; + + if (pjrt_api.ffi()) |ffi| { + try callback.addUserData(Callback, pjrt_api, ffi, exe.execute_context.?, op); + } else { + stdx.debug.panic("Callbacks are not supported for target {s}", .{@tagName(exe.platform.target)}); + } + } + pub fn serialize(self: BaseExe, writer: anytype) !void { var executable = try self.exe.getExecutable(self.platform.pjrt_api); var serialize_result = try executable.serialize(self.platform.pjrt_api); @@ -314,11 +335,11 @@ pub const BaseExe = struct { pub fn clone(self: BaseExe, parent_allocator: std.mem.Allocator) !BaseExe { var exe: BaseExe = try .init(parent_allocator, self.platform, self.exe, .{ - .n_in = self.input_buffer_count, + .input_shapes = self.input_shapes, .result_shapes = self.result_shapes, .n_devices = self.num_devices, }); - exe.context = self.context; + exe.execute_context = self.execute_context; return exe; } }; @@ -348,6 +369,14 @@ pub fn Exe(ArgsT: type, 
ReturnT: type) type { return new; } + /// For a given customCall inside this executable, + /// provide a pointer to runtime data. + /// The caller keeps memory ownership and need to ensure that the value + /// stays alive as long as the executable. + pub fn bind(self: Self, comptime T: type, value: *T) !void { + try self.inner.bind(T, value); + } + pub fn serialize(self: Self, writer: anytype) !void { return try self.inner.serialize(writer); } diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig index c2099b0..848fee8 100644 --- a/zml/hostbuffer.zig +++ b/zml/hostbuffer.zig @@ -325,37 +325,18 @@ pub const HostBuffer = struct { self: HostBuffer, writer: anytype, ) !void { - // TODO debug option - // try writer.print("HostBuffer(.{f})@0x{x}", .{ self._shape, @intFromPtr(self._data) }); try writer.print("HostBuffer(.{f})", .{self._shape}); } - /// Formatter for a HostBuffer that also print the values not just the shape. - /// Usage: `std.log.info("my buffer: {}", .{buffer.pretty()});` - pub fn pretty(self: HostBuffer) PrettyPrinter { - return .{ .x = self }; + pub fn formatNumber(self: HostBuffer, writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void { + return self.prettyPrintIndented(writer, 4, 0, n); } - pub const PrettyPrinter = struct { - x: HostBuffer, - - // TODO(0.15.0) revisit pretty printer - pub fn format(self: PrettyPrinter, writer: anytype) !void { - const fmt_: stdx.fmt.Fmt = switch (self.x.dtype().class()) { - .integer => .parse(i32, "d"), - .float => .parse(f32, "d"), - else => .parse(void, ""), - }; - const options: std.fmt.FormatOptions = .{}; - try prettyPrint(self.x, writer, .{ .fmt = fmt_, .options = options }); - } - }; - - pub fn prettyPrint(self: HostBuffer, writer: *std.Io.Writer, options: stdx.fmt.FullFormatOptions) !void { + pub fn prettyPrint(self: HostBuffer, writer: *std.Io.Writer, options: std.fmt.Number) !void { return self.prettyPrintIndented(writer, 4, 0, options); } - fn prettyPrintIndented(self: HostBuffer, writer: 
*std.Io.Writer, num_rows: u8, indent_level: u8, options: stdx.fmt.FullFormatOptions) !void { + fn prettyPrintIndented(self: HostBuffer, writer: *std.Io.Writer, num_rows: u8, indent_level: u8, options: std.fmt.Number) !void { if (self.rank() == 0) { // Special case input tensor is a scalar return switch (self.dtype()) { @@ -363,9 +344,10 @@ pub const HostBuffer = struct { const val: dt.toZigType() = self.items(dt.toZigType())[0]; return switch (comptime dt.class()) { // Since we have custom floats, we need to explicitly convert to float32 ourselves. - .float => stdx.fmt.formatFloatValue(floats.floatCast(f32, val), options, writer), - .integer => stdx.fmt.formatIntValue(val, options, writer), - .bool, .complex => stdx.fmt.formatAnyValue(val, options, writer), + .float => stdx.fmt.formatFloat(floats.floatCast(f32, val), options, writer), + .integer => stdx.fmt.formatInt(val, options, writer), + .bool => stdx.fmt.formatBool(val, options, writer), + .complex => stdx.fmt.formatComplex(val, options, writer), }; }, }; @@ -380,7 +362,8 @@ pub const HostBuffer = struct { switch (comptime dt.class()) { .float => try stdx.fmt.formatFloatSlice(values, options, writer), .integer => try stdx.fmt.formatIntSlice(values, options, writer), - .bool, .complex => try stdx.fmt.formatAnySlice(values, options, writer), + .complex => try stdx.fmt.formatComplexSlice(values, options, writer), + .bool => try stdx.fmt.formatBoolSlice(values, options, writer), } }, } diff --git a/zml/module.zig b/zml/module.zig index 33d22f7..6263d3f 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -1178,9 +1178,9 @@ pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: std.hash.Str .@"anyframe", .@"fn" => hash(hasher, @intFromPtr(key), strat), .pointer => |info| switch (info.size) { .one => switch (strat) { - .shallow => hash(hasher, @intFromPtr(key), .Shallow), - .deep => hash(hasher, key.*, .Shallow), - .deeprecursive => switch (@typeInfo(info.child)) { + .Shallow => hash(hasher, 
@intFromPtr(key), .Shallow), + .Deep => hash(hasher, key.*, .Shallow), + .DeepRecursive => switch (@typeInfo(info.child)) { .@"opaque", .@"fn" => hash(hasher, @intFromPtr(key), .Shallow), else => hash(hasher, key.*, .DeepRecursive), }, @@ -1196,7 +1196,7 @@ pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: std.hash.Str .many, .c, => switch (strat) { - .shallow => hash(hasher, @intFromPtr(key), .Shallow), + .Shallow => hash(hasher, @intFromPtr(key), .Shallow), else => @compileError( \\ unknown-length pointers and C pointers cannot be hashed deeply. \\ Consider providing your own hash function. diff --git a/zml/ops.zig b/zml/ops.zig index c26eac3..ee12508 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -764,33 +764,6 @@ pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base return res; } -pub const HostCallbackOpt = struct { - has_side_effect: bool = false, - output_operand_aliases: []const i64 = &.{}, -}; - -pub fn addHostCallback( - callback: *const Context.HostCallback, - blkctx: ?*anyopaque, - inputs: []const Tensor, - output_shapes: []const Shape, - opts: HostCallbackOpt, -) []Tensor { - return customCall( - "zmlHostBufferCallback", - inputs, - output_shapes, - .{ - .callback = @intFromPtr(callback), - .user_context = @intFromPtr(blkctx), - }, - .{ - .has_side_effect = opts.has_side_effect, - .output_operand_aliases = opts.output_operand_aliases, - }, - ); -} - pub const TritonOps = struct { debug: bool = false, name: [:0]const u8, diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index 84343a6..8658aeb 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -207,6 +207,13 @@ pub const Event = opaque { return self.inner().getEventError(api); } + pub fn awaitBlocking(self: *Event, api: *const Api) ApiError!void { + if (self.isReady(api)) { + return; + } + try self.inner().await_(api); + } + pub fn await_(self: *Event, api: *const Api) ApiError!void { defer self.deinit(api); @@ -264,14 +271,14 @@ pub const LoadedExecutable = opaque { 
}; pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ExecuteError!void { - try asynk.callBlocking(pjrt.LoadedExecutable.execute, .{ self.inner(), api, pjrt.LoadedExecutable.ExecuteArgs{ + try self.inner().execute(api, pjrt.LoadedExecutable.ExecuteArgs{ .num_args = args.num_args, .arguments = @ptrCast(args.arguments), .results = @ptrCast(args.results), .events = @ptrCast(args.events), .non_donatable_input_indices = args.non_donatable_input_indices, .context = args.context, - } }); + }); } pub fn getExecutable(self: *LoadedExecutable, api: *const Api) ApiError!*Executable { diff --git a/zml/platform.zig b/zml/platform.zig index 838cac4..0bc792e 100644 --- a/zml/platform.zig +++ b/zml/platform.zig @@ -22,6 +22,11 @@ pub const Platform = struct { target: Target, pjrt_api: *const pjrt.Api, pjrt_client: *pjrt.Client, + + // This make the pjrt struct quite fat, but is only used during compilation. + // TODO: Reconsider having it here, and maybe pass explicitly to compile, + // or create an intermediary struct: + // `const comp = platform.compiler(compile_opts); const exe = comp.compile(...);` compilation_options: CompilationOptions = .{}, pub const MAX_NUM_DEVICES: u8 = 32; @@ -71,17 +76,6 @@ pub const Platform = struct { return res; } - pub fn registerFFIType(self: Platform, comptime T: type) !void { - if (self.pjrt_api.ffi()) |ffi| { - if (!@hasDecl(T, "type_id")) { - stdx.debug.panic("registerFFIType requires type {s} to have a `type_id` i64 field ", .{@typeName(T)}); - } - try ffi.registerTypeId(self.pjrt_api, T); - } else { - stdx.debug.panic("registerFFIType is not available for target {s}", .{@tagName(self.target)}); - } - } - pub fn deinit(self: *Platform) void { self.pjrt_client.deinit(self.pjrt_api); } diff --git a/zml/tensor.zig b/zml/tensor.zig index 26f670d..63ab44d 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -5,6 +5,7 @@ const mlir = @import("mlir"); const stdx = @import("stdx"); const Buffer = 
@import("buffer.zig").Buffer; +const callback = @import("callback.zig"); const CompilationContext = @import("module.zig").CompilationContext; const Data = @import("dtype.zig").Data; const DataType = @import("dtype.zig").DataType; @@ -3824,22 +3825,7 @@ pub const Tensor = struct { /// Only for debug purpose, it inserts device to host synchronization /// so it will slow down the program execution. pub fn print(input: Tensor) Tensor { - // TODO: find a way of doing print that doesn't involve a H2D copy. - return ops.addHostCallback( - &printCallback, - null, - &.{input}, - &.{input.shape()}, - .{ .output_operand_aliases = &.{0} }, - )[0]; - } - - fn printCallback(_: ?*anyopaque, inputs: []const HostBuffer, outputs: []const HostBuffer) void { - const host_buffer = inputs[0]; - std.log.defaultLog(.info, .zml, "Device buffer: {f}: {f}", .{ host_buffer.shape(), host_buffer.pretty() }); - // This is true because of the operand aliases. - // Since the result is already pointing to the input we don't need to modify the buffer. 
-        std.debug.assert(host_buffer._data == outputs[0]._data);
+        return callback.call(callback.Print, .{input}, &.{input.shape()})[0];
     }
 };
diff --git a/zml/testing.zig b/zml/testing.zig
index b236400..5fec596 100644
--- a/zml/testing.zig
+++ b/zml/testing.zig
@@ -51,14 +51,13 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
         if (should_free_left) left.deinit(allocator);
         if (should_free_right) right.deinit(allocator);
     }
-    errdefer log.err("\n--> Left: {f}\n--> Right: {f}", .{ left.pretty(), right.pretty() });
-
+    errdefer log.err("\n--> Left: {0f}{0d:24.3}\n--> Right: {1f}{1d:24.3}", .{ left, right });
     if (!std.mem.eql(i64, left.shape().dims(), right.shape().dims())) {
         log.err("left.shape() {f} != right.shape() {f}", .{ left.shape(), right.shape() });
         return error.TestUnexpectedResult;
     }
     if (left.dtype() != right.dtype() and !(left.dtype() == .f16 and right.dtype() == .bf16)) {
-        log.err("left.dtype ({}) != right.dtype ({})", .{ left.dtype(), right.dtype() });
+        log.err("left.dtype ({any}) != right.dtype ({any})", .{ left.dtype(), right.dtype() });
         return error.TestUnexpectedResult;
     }
     switch (left.dtype()) {
@@ -89,7 +88,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
         const right_data = right.items(R);
         for (left_data, right_data, 0..)
|l, r, i| { if (!approxEq(f32, zml.floats.floatCast(f32, l), zml.floats.floatCast(f32, r), tolerance)) { - log.err("left.data != right_data.\n < {any:.3} \n > {any:.3}\n error at idx {any}: {any:.3} != {any:.3}", .{ center(left_data, i), center(right_data, i), i, left_data[i], right_data[i] }); + log.err("left.data != right_data.\n < {d:40.3} \n > {d:40.3}\n error at idx {d}: {d:.3} != {d:.3}", .{ stdx.fmt.slice(center(left_data, i)), stdx.fmt.slice(center(right_data, i)), i, left_data[i], right_data[i] }); return error.TestUnexpectedResult; } } diff --git a/zml/zml.zig b/zml/zml.zig index b5cd803..55559f5 100644 --- a/zml/zml.zig +++ b/zml/zml.zig @@ -6,11 +6,13 @@ // Namespaces const std = @import("std"); +pub const platform_specific = @import("c"); pub const tokenizer = @import("zml/tokenizer"); pub const aio = @import("aio.zig"); pub const Buffer = @import("buffer.zig").Buffer; pub const Bufferized = @import("tensor.zig").Bufferized; +pub const callback = @import("callback.zig"); pub const CompilationOptions = @import("platform.zig").CompilationOptions; pub const context = @import("context.zig"); pub const Context = @import("context.zig").Context; @@ -43,7 +45,6 @@ pub const Tensor = @import("tensor.zig").Tensor; pub const testing = @import("testing.zig"); pub const torch = @import("torch.zig"); -// pub const tokenizer = @import("tokenizer.zig"); pub const tools = struct { pub const Tracer = @import("tools/tracer.zig").Tracer; };