From 6aa9aa5a7b84aa8aca9ffa6c2bae4e8bc12cc1a1 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 10 Dec 2024 09:36:37 +0000 Subject: [PATCH] Add preliminary implementation for custom call support. --- pjrt/ffi.zig | 146 +++++++++++++++++++++++++++++------------------ pjrt/pjrt.zig | 135 ++++++++++++++++++++++++++++++------------- zml/buffer.zig | 2 +- zml/context.zig | 85 ++++++++++++++++++++++++--- zml/exe.zig | 15 ++++- zml/pjrtx.zig | 8 ++- zml/platform.zig | 11 ++++ 7 files changed, 291 insertions(+), 111 deletions(-) diff --git a/pjrt/ffi.zig b/pjrt/ffi.zig index 39e7c24..d4fb6ee 100644 --- a/pjrt/ffi.zig +++ b/pjrt/ffi.zig @@ -91,27 +91,25 @@ fn TransmuteMixin(comptime T: type, comptime InnerT: type) type { pub const Api = opaque { pub const inner = TransmuteMixin(Api, c.XLA_FFI_Api).to; - pub fn getStream(self: *const Api, context: ?*ExecutionContext) ApiError!*anyopaque { + pub fn stream(self: *const Api, context: *const ExecutionContext) *Stream { var ret = pjrtStruct(c.XLA_FFI_Stream_Get_Args{ - .ctx = if (context) |ctx| ctx.inner() else null, + .ctx = @constCast(context.inner()), }); const result = self.inner().XLA_FFI_Stream_Get.?(&ret); - if (result) |ffi_error| { const err = Error.fromInner(ffi_error); defer err.destroy(self); log.err("[Api.getStream] {s}", .{err.getMessage(self)}); - // TODO(Corentin): Retrieve error code from Error when implemented in XLA. - return error.Unknown; + @panic("failed to get stream"); } - return ret.stream.?; + return @ptrCast(ret.stream.?); } - pub fn allocateDeviceMemory(self: *const Api, context: ?*ExecutionContext, size: usize, alignment: usize) ApiError!*anyopaque { + pub fn allocateDeviceMemory(self: *const Api, context: *const ExecutionContext, size: usize, alignment: usize) ApiError!*anyopaque { var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Allocate_Args{ - .ctx = if (context) |ctx| ctx.inner() else null, + .ctx = @constCast(context.inner()), .size = size, .alignment = alignment, }); @@ -129,9 +127,9 @@ pub const Api = opaque { return ret.data.?; } - pub fn freeDeviceMemory(self: *const Api, context: ?*ExecutionContext, data: *anyopaque, size: usize) ApiError!void { + pub fn freeDeviceMemory(self: *const Api, context: *const ExecutionContext, data: *anyopaque, size: usize) ApiError!void { var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Free_Args{ - .ctx = if (context) |ctx| ctx.inner() else null, + .ctx = @constCast(context.inner()), .size = size, .data = data, }); @@ -165,54 +163,84 @@ pub const ExecutionStage = enum(c.XLA_FFI_ExecutionStage) { pub const ExecutionContext = opaque { pub const inner = TransmuteMixin(ExecutionContext, c.XLA_FFI_ExecutionContext).to; - // pub fn attach(self: *ExecutionContext, api: *const Api, value: anytype) ApiError!void { - // // register type id ==> typeid - // const typename_ = "zml." ++ @typeName(@TypeOf(value)); + pub fn Context(comptime T: type) type { + return struct { + pub fn get(self: *const ExecutionContext, api: *const Api) ApiError!*T { + const type_id: TypeId = .{ .type_id = T.type_id }; + var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Get_Args{ + .ctx = @constCast(self.inner()), + .type_id = @constCast(&type_id.toCStruct()), + }); + const result = api.inner().XLA_FFI_ExecutionContext_Get.?(&ret); - // var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Register_Args{ - // .ctx = self.inner(), - // .handler = @ptrCast(@alignCast(handler)), - // }); - // const result = api.inner().XLA_FFI_ExecutionContext_Register.?(&ret); + if (result) |ffi_error| { + const err = Error.fromInner(ffi_error); + defer err.destroy(api); + log.err("[ExecutionContext.get] {s}", .{err.getMessage(api)}); - // var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Register_Args{ - // .ctx = self.inner(), - // .handler = @ptrCast(@alignCast(handler)), - // }); - // const result = api.inner().XLA_FFI_ExecutionContext_Register.?(&ret); + // TODO(Corentin): Retrieve error code from Error when implemented in XLA. + return error.Unknown; + } - // if (result) |ffi_error| { - // const err = Error.fromInner(ffi_error); - // defer err.destroy(api); - // log.err("[ExecutionContext.register] {s}", .{err.getMessage(api)}); + if (ret.data == null) return error.NotFound; + return @ptrCast(@alignCast(ret.data.?)); + } + }; + } - // // TODO(Corentin): Retrieve error code from Error when implemented in XLA. - // return error.Unknown; - // } - // } - - pub fn get(self: *ExecutionContext, api: *const Api, type_id: *TypeId) ApiError!*anyopaque { - var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Get_Args{ - .ctx = self.inner(), - .type_id = @ptrCast(@alignCast(type_id)), + pub fn getDeviceOrdinal(self: *const ExecutionContext, api: *const Api) ApiError!i32 { + var ret = pjrtStruct(c.XLA_FFI_DeviceOrdinal_Get_Args{ + .ctx = @constCast(self.inner()), }); - const result = api.inner().XLA_FFI_ExecutionContext_Get.?(&ret); + const result = api.inner().XLA_FFI_DeviceOrdinal_Get.?(&ret); if (result) |ffi_error| { const err = Error.fromInner(ffi_error); defer err.destroy(api); - log.err("[ExecutionContext.get] {s}", .{err.getMessage(api)}); + log.err("[ExecutionContext.getDeviceOrdinal] {s}", .{err.getMessage(api)}); // TODO(Corentin): Retrieve error code from Error when implemented in XLA. return error.Unknown; } - return ret.data.?; + return ret.device_ordinal; } - // TODO getDeviceOrdinal() + pub fn scheduleTask(self: *const ExecutionContext, api: *const Api, task: *const Task, data: *anyopaque) ApiError!void { + var ret = pjrtStruct(c.XLA_FFI_ThreadPool_Schedule_Args{ + .ctx = @constCast(self.inner()), + .task = @ptrCast(@alignCast(task)), + .data = @ptrCast(@alignCast(data)), + }); + + const result = api.inner().XLA_FFI_ThreadPool_Schedule.?(&ret); + + if (result) |ffi_error| { + const err = Error.fromInner(ffi_error); + defer err.destroy(api); + std.debug.print("error: {any} \n", .{err}); + log.err("[ExecutionContext.get] {s}", .{err.getMessage(api)}); + + // TODO(Corentin): Retrieve error code from Error when implemented in XLA. + return error.Unknown; + } + } + + fn getTypeId(type_name: []const u8) TypeId { + const id: i64 = @bitCast(std.hash.Fnv1a_64.hash(type_name)); + + return .{ + .type_id = id, + }; + } }; +const TypeId = c.XLA_FFI_TypeId; + +const Task = fn (*anyopaque) void; + +const Stream = @import("pjrt.zig").Stream; + const ByteSpan = extern struct { ptr: [*]const u8, len: usize, @@ -222,17 +250,13 @@ const ByteSpan = extern struct { } }; -pub const TypeId = extern struct { - type_id: i64, -}; - pub const DataType = enum(c.XLA_FFI_DataType) { invalid = c.XLA_FFI_DataType_INVALID, pred = c.XLA_FFI_DataType_PRED, - s8 = c.XLA_FFI_DataType_S8, - s16 = c.XLA_FFI_DataType_S16, - s32 = c.XLA_FFI_DataType_S32, - s64 = c.XLA_FFI_DataType_S64, + i8 = c.XLA_FFI_DataType_S8, + i16 = c.XLA_FFI_DataType_S16, + i32 = c.XLA_FFI_DataType_S32, + i64 = c.XLA_FFI_DataType_S64, u8 = c.XLA_FFI_DataType_U8, u16 = c.XLA_FFI_DataType_U16, u32 = c.XLA_FFI_DataType_U32, @@ -289,9 +313,8 @@ pub const Args = extern struct { buffer = c.XLA_FFI_ArgType_BUFFER, }; - pub fn get(self: Args, i: usize) *const Buffer { - std.debug.assert(self.types[0..self.len][i] == .buffer); - return self.ptr[0..self.len][i]; + pub fn buffers(self: Args) []*const Buffer { + return self.ptr[0..self.len]; } }; @@ -306,9 +329,8 @@ pub const Rets = extern struct { buffer = c.XLA_FFI_RetType_BUFFER, }; - pub fn get(self: Rets, i: usize) *const Buffer { - std.debug.assert(self.types[0..self.len][i] == .buffer); - return self.ptr[0..self.len][i]; + pub fn buffers(self: Rets) []*const Buffer { + return self.ptr[0..self.len]; } }; @@ -346,8 +368,18 @@ pub const Attrs = extern struct { dtype: DataType, len: usize, data: [*]const u8, + + pub fn slice(self: Array, T: type) []const T { + const ptr: [*]const T = @alignCast(@ptrCast(self.data)); + return ptr[0..self.len]; + } }; + pub fn slice(self: Array, T: type) []const T { + const ptr: [*]const T = @alignCast(@ptrCast(self.data)); + return ptr[0..self.len]; + } + pub fn getByIndex(self: Attrs, comptime attr_type: AttrType, index: usize) ?*const @FieldType(Attr, @tagName(attr_type)) { const attr = self.ptr[0..self.len][index]; const actual_type = self.types[index]; @@ -370,8 +402,8 @@ pub const Attrs = extern struct { pub const CallFrame = extern struct { struct_size: usize, extension_start: ?*ExtensionBase, - api: ?*const Api, - ctx: ?*const ExecutionContext, + api: *const Api, + ctx: *const ExecutionContext, stage: ExecutionStage, args: Args, results: Rets, @@ -438,7 +470,7 @@ pub const Error = opaque { pub const inner = TransmuteMixin(Error, c.XLA_FFI_Error).to; pub const fromInner = TransmuteMixin(Error, c.XLA_FFI_Error).from; - pub fn create(api: *const Api, error_code: ErrorCode, message: [:0]const u8) *Error { + pub fn create(api: *const Api, error_code: ErrorCode, message: []const u8) *Error { var ret = pjrtStruct(c.XLA_FFI_Error_Create_Args{ .message = message.ptr, .errc = @intFromEnum(error_code), diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index df1a7d4..c834878 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -31,7 +31,7 @@ fn pjrtStructSize(comptime T: type) usize { return @field(c, typedef_name ++ "_STRUCT_SIZE"); } -inline fn pjrtStruct(v: anytype) @TypeOf(v) { +pub inline fn pjrtStruct(v: anytype) @TypeOf(v) { var ret = v; ret.struct_size = pjrtStructSize(@TypeOf(v)); return ret; @@ -160,9 +160,14 @@ pub const Api = struct { return state.str; } - pub fn customCallRegistry(api: *const Api) ?CustomCallRegistry { + pub fn createExecuteContext(api: *const Api) ApiError!*ExecuteContext { + const ret = try api.call(.PJRT_ExecuteContext_Create, .{}); + return @ptrCast(ret.context.?); + } + + pub fn ffi(api: *const Api) ?FFI { if (api.lookupExtension(c.PJRT_FFI_Extension, c.PJRT_Extension_Type_FFI)) |ext| { - return .{ .inner = ext.register_handler.? }; + return .{ .inner = ext }; } return null; } @@ -279,6 +284,8 @@ pub const ShapeSpec = extern struct { } }; +pub const Stream = opaque {}; + pub const Client = opaque { const inner = InnerMixin(c.PJRT_Client).inner; @@ -414,7 +421,7 @@ pub const Client = opaque { fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.C) void {} }.call, on_delete_callback_arg: ?*anyopaque = null, - stream: ?isize = null, + stream: ?*const Stream = null, }; pub fn createViewOfDeviceBuffer(self: *const Client, api: *const Api, args: CreateViewOfDeviceBufferArgs) ApiError!*Buffer { @@ -429,7 +436,7 @@ pub const Client = opaque { .device = @ptrCast(@constCast(args.device)), .on_delete_callback = args.on_delete_callback, .on_delete_callback_arg = args.on_delete_callback_arg, - .stream = if (args.stream) |stream| stream else 0, + .stream = @bitCast(@intFromPtr(args.stream)), }); return @ptrCast(ret.buffer.?); } @@ -444,20 +451,19 @@ pub const Client = opaque { return &.{}; } - pub fn dmaMap(self: *const Client, api: *const Api, data: []const u8) ApiError!*Buffer { - const ret = try api.call(.PJRT_Client_DMA_Map, .{ + pub fn dmaMap(self: *const Client, api: *const Api, data: []const u8) ApiError!void { + try api.call(.PJRT_Client_DmaMap, .{ .client = self.inner(), .data = @ptrCast(@constCast(data.ptr)), .size = @intCast(data.len), }); - return @ptrCast(ret.buffer.?); } - pub fn dmaUnmap(self: *const Client, api: *const Api, data: []const u8) void { - _ = api.call(.PJRT_Client_DMA_Unmap, .{ + pub fn dmaUnmap(self: *const Client, api: *const Api, data: []const u8) ApiError!void { + try api.call(.PJRT_Client_DmaUnmap, .{ .client = self.inner(), .data = @ptrCast(@constCast(data.ptr)), - }) catch unreachable; + }); } pub const CreateBuffersForAsyncHostToDeviceArgs = struct { @@ -564,6 +570,14 @@ pub const SerializeResult = struct { } }; +pub const ExecuteContext = opaque { + pub fn deinit(self: *ExecuteContext, api: *const Api) void { + _ = api.call(.PJRT_ExecuteContext_Destroy, .{ + .context = @ptrCast(self), + }) catch {}; + } +}; + pub const Executable = opaque { const inner = InnerMixin(c.PJRT_Executable).inner; @@ -630,6 +644,7 @@ pub const LoadedExecutable = opaque { results: []const [*]*Buffer, events: []?*Event, non_donatable_input_indices: []const i64 = &.{}, + context: ?*ExecuteContext, }; pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ApiError!void { var options = pjrtStruct(c.PJRT_ExecuteOptions{ @@ -640,6 +655,7 @@ pub const LoadedExecutable = opaque { .launch_id = 0, .non_donatable_input_indices = @ptrCast(args.non_donatable_input_indices.ptr), .num_non_donatable_input_indices = args.non_donatable_input_indices.len, + .context = @ptrCast(args.context), }); _ = try api.call(.PJRT_LoadedExecutable_Execute, .{ .executable = self.inner(), @@ -653,7 +669,7 @@ pub const LoadedExecutable = opaque { }); } - pub fn getExecutable(self: *LoadedExecutable, api: *const Api) ApiError!*Executable { + pub fn getExecutable(self: *const LoadedExecutable, api: *const Api) ApiError!*Executable { const ret = try api.call(.PJRT_LoadedExecutable_GetExecutable, .{ .loaded_executable = self.inner(), }); @@ -818,7 +834,7 @@ pub const Buffer = opaque { return ret.on_device_size_in_bytes; } - pub fn copyToDevice(self: *const Buffer, api: *const Api, device: Device) ApiError!Buffer { + pub fn copyToDevice(self: *const Buffer, api: *const Api, device: Device) ApiError!*Buffer { const ret = try api.call(.PJRT_Buffer_CopyToDevice, .{ .buffer = self.inner(), .dst_device = device.inner, @@ -850,7 +866,7 @@ pub const Buffer = opaque { return @ptrCast(ret.event); } - pub fn copyToMemory(self: *const Buffer, api: *const Api, dst_memory: *const Memory) ApiError!?*Buffer { + pub fn copyToMemory(self: *const Buffer, api: *const Api, dst_memory: *const Memory) ApiError!*Buffer { const ret = try api.call(.PJRT_Buffer_CopyToMemory, .{ .buffer = self.inner(), .dst_memory = @ptrCast(@constCast(dst_memory)), @@ -932,8 +948,8 @@ pub const Memory = opaque { pub fn kind(self: *const Memory, api: *const Api) Kind { const ret = api.call(.PJRT_Memory_Kind, .{ .memory = self.inner() }) catch unreachable; - const kind_ = ret.kind orelse unreachable[0..ret.kind_size]; - return std.meta.stringToEnum(Kind, kind_) orelse unreachable; + const kind_ = ret.kind orelse unreachable; + return std.meta.stringToEnum(Kind, kind_[0..ret.kind_size]) orelse unreachable; } pub fn kindId(self: *const Memory, api: *const Api) u32 { @@ -1044,21 +1060,6 @@ pub const AsyncHostToDeviceTransferManager = opaque { } }; -pub const ExecutionContext = opaque { - const inner = InnerMixin(c.PJRT_ExecutionContext).inner; - - pub fn init(api: *const Api) ApiError!*ExecutionContext { - const ret = try api.call(.PJRT_ExecutionContext_Create, .{}); - return @ptrCast(ret.context.?); - } - - pub fn deinit(self: *ExecutionContext, api: *const Api) void { - _ = api.call(.PJRT_ExecutionContext_Destroy, .{ - .context = self.inner(), - }) catch unreachable; - } -}; - pub const NamedValue = extern struct { comptime { std.debug.assert(@sizeOf(NamedValue) == @sizeOf(c.PJRT_NamedValue)); @@ -1164,17 +1165,34 @@ pub const NamedValue = extern struct { } }; -// todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize -// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5 -pub const CustomCallRegistry = extern struct { - inner: *const c.PJRT_FFI_Register_Handler, +pub const FFI = extern struct { + inner: *const c.PJRT_FFI, - pub fn registerFfi( - self: *const CustomCallRegistry, + pub const UserData = extern struct { + type_id: i64, + user_data: *anyopaque, + + fn toCStruct(self: UserData) c.PJRT_FFI_UserData { + return .{ + .type_id = self.type_id, + .data = self.user_data, + }; + } + }; + + pub const RegisterFfiOptions = struct { + traits: RegisterHandlerTraits = @enumFromInt(0), + }; + + // todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize + // introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5 + pub fn register( + self: *const FFI, api: *const Api, target_name: []const u8, platform_name: []const u8, func: *const ffi.Handler, + options: RegisterFfiOptions, ) ApiError!void { var ret = pjrtStruct(c.PJRT_FFI_Register_Handler_Args{ .api_version = 1, @@ -1183,12 +1201,51 @@ pub const CustomCallRegistry = extern struct { .handler = @ptrCast(@constCast(func)), .platform_name = platform_name.ptr, .platform_name_size = platform_name.len, + .traits = @intFromEnum(options.traits), }); - const result = self.inner(&ret); + const result = self.inner.register_handler.?(&ret); if (result) |pjrt_c_error| { const pjrt_error: *Error = @ptrCast(pjrt_c_error); - log.err("[GpuRegisterCustomCall] {s}", .{pjrt_error.getMessage(api)}); + log.err("registerFfi error: {s}", .{pjrt_error.getMessage(api)}); + return pjrt_error.getCode(api).toApiError(); + } + } + + pub fn registerTypeId(self: *const FFI, api: *const Api, T: type) ApiError!void { + const type_name = @typeName(T); + var ret = pjrtStruct(c.PJRT_FFI_TypeID_Register_Args{ + .type_name = type_name.ptr, + .type_name_size = type_name.len, + .type_id = 0, // let the plugin assign a unique type ID + }); + const result = self.inner.type_id_register.?(&ret); + if (result) |pjrt_c_error| { + const pjrt_error: *Error = @ptrCast(pjrt_c_error); + return pjrt_error.getCode(api).toApiError(); + } + + T.type_id = ret.type_id; + } + + pub fn addUserData(self: *const FFI, api: *const Api, context: *ExecuteContext, user_data: UserData) ApiError!void { + var ret = pjrtStruct(c.PJRT_FFI_UserData_Add_Args{ + .context = @ptrCast(context), + .user_data = user_data.toCStruct(), + }); + const result = self.inner.user_data_add.?(&ret); + if (result) |pjrt_c_error| { + const pjrt_error: *Error = @ptrCast(pjrt_c_error); + log.err("addUserData error: {s}", .{pjrt_error.getMessage(api)}); return pjrt_error.getCode(api).toApiError(); } } }; + +pub const RegisterHandlerTraits = enum(c.PJRT_FFI_Handler_TraitsBits) { + command_buffer_compatible = c.PJRT_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE, + _, +}; + +pub const CustomCallRegistry = extern struct { + inner: *const c.PJRT_FFI_Register_Handler, +}; diff --git a/zml/buffer.zig b/zml/buffer.zig index 1fc57ef..4828844 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -224,7 +224,7 @@ pub const Buffer = struct { /// Creates a Buffer from a pointer into device memory. /// This allows to interface with other libraries producing buffers. - pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?isize, device_data: *anyopaque) Buffer { + pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const pjrt.Stream, device_data: *anyopaque) Buffer { const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: { var res: [Shape.MAX_RANK]i64 = undefined; for (0..Shape.MAX_RANK) |i| { diff --git a/zml/context.zig b/zml/context.zig index 03de62a..33c4255 100644 --- a/zml/context.zig +++ b/zml/context.zig @@ -37,6 +37,7 @@ pub const Context = struct { inline for (comptime std.enums.values(runtimes.Platform)) |t| { if (runtimes.load(t)) |api| { Context.apis.set(t, api); + if (t == .cuda) cuda.init(); } else |_| {} } } @@ -218,10 +219,10 @@ pub const Context = struct { const CustomCall = struct { pub fn registerZmlCustomCalls(platform: Platform) !void { - const registry = platform.pjrt_api.customCallRegistry(); + const maybe_ffi = platform.pjrt_api.ffi(); - if (registry) |reg| { - try reg.registerFfi(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback); + if (maybe_ffi) |ffi| { + try ffi.register(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback, .{}); } else { stdx.debug.panic("Registering custom calls failed", .{}); } @@ -240,12 +241,12 @@ const CustomCall = struct { const input_buffers = stdx.stackSlice(8, HostBuffer, call_frame.args.len); for (input_buffers, 0..) |*b, i| { - b.* = hostBufferFromPinnedBuffer(call_frame.args.get(i)); + b.* = hostBufferFromPinnedBuffer(call_frame.args.buffers()[i]); } const output_buffers = stdx.stackSlice(8, HostBuffer, call_frame.results.len); for (output_buffers, 0..) |*b, i| { - b.* = hostBufferFromPinnedBuffer(call_frame.results.get(i)); + b.* = hostBufferFromPinnedBuffer(call_frame.results.buffers()[i]); } callback(user_ctx, input_buffers, output_buffers); @@ -258,10 +259,10 @@ fn getShape(buffer_desc: *const pjrt.ffi.Buffer) Shape { const dt: DataType = switch (buffer_desc.dtype) { .invalid => @panic("invalid ffi"), .pred => .bool, - .s8 => .i8, - .s16 => .i16, - .s32 => .i32, - .s64 => .i64, + .i8 => .i8, + .i16 => .i16, + .i32 => .i32, + .i64 => .i64, .token, .f8e4m3, .f8e3m4 => @panic("Unsupported ffi type"), inline else => |t| @field(DataType, @tagName(t)), }; @@ -278,3 +279,69 @@ fn hostBufferFromPinnedBuffer(buffer_desc: *const pjrt.ffi.Buffer) HostBuffer { buffer_desc.data[0..buffer_shape.byteSize()], ); } + +pub const cuda = struct { + pub var streamSynchronize: StreamSynchronize = @ptrFromInt(0xdeadc00da00); + pub var cuLaunchHostFunc: CuLaunchHostFunc = @ptrFromInt(0xdeadc00da00); + var _memcpyAsync: MemcpyAsync = @ptrFromInt(0xdeadc00da00); + var _memcpyBlocking: MemcpyBlocking = @ptrFromInt(0xdeadc00da00); + + pub const MemcpyKind = enum(c_int) { + host_to_host = 0, + host_to_device = 1, + device_to_host = 2, + device_to_device = 3, + inferred = 4, + }; + + const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: ?*anyopaque) callconv(.C) c_int; + const MemcpyBlocking = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind) callconv(.C) c_int; + const StreamSynchronize = *const fn (stream: *anyopaque) callconv(.C) c_int; + const CuLaunchHostFunc = *const fn (stream: *anyopaque, host_func: *const fn (user_data: *const anyopaque) callconv(.c) void, user_data: *const anyopaque) callconv(.c) c_int; + + pub fn init() void { + var cudart = std.DynLib.open("libcudart.so.12") catch { + log.err("cudart not found, callback will segfault", .{}); + return; + }; + defer cudart.close(); + + _memcpyAsync = cudart.lookup(MemcpyAsync, "cudaMemcpyAsync") orelse { + @panic("cudaMemcpyAsync not found"); + }; + _memcpyBlocking = cudart.lookup(MemcpyBlocking, "cudaMemcpy") orelse { + @panic("cudaMemcpy not found"); + }; + streamSynchronize = cudart.lookup(StreamSynchronize, "cudaStreamSynchronize") orelse { + @panic("cudaStreamSynchronize not found"); + }; + cuLaunchHostFunc = cudart.lookup(CuLaunchHostFunc, "cudaLaunchHostFunc") orelse { + @panic("cudaLaunchHostFunc not found"); + }; + } + + pub fn memcpyToHostBlocking(dst: []u8, src: *const anyopaque) void { + const err = _memcpyBlocking(dst.ptr, src, dst.len, .device_to_host); + check(err); + } + + pub fn memcpyToDeviceBlocking(dst: *anyopaque, src: []const u8) void { + const err = _memcpyBlocking(dst, src.ptr, src.len, .host_to_device); + check(err); + } + + pub fn memcpyToDeviceAsync(dst: *anyopaque, src: []const u8, stream: ?*anyopaque) void { + const err = _memcpyAsync(dst, src.ptr, src.len, .host_to_device, stream); + check(err); + } + + pub fn memcpyToHostAsync(dst: []u8, src: *const anyopaque, stream: ?*anyopaque) void { + const err = _memcpyAsync(dst.ptr, src, dst.len, .device_to_host, stream); + check(err); + } + + pub fn check(err: c_int) void { + if (err == 0) return; + stdx.debug.panic("CUDA error: {d}", .{err}); + } +}; diff --git a/zml/exe.zig b/zml/exe.zig index 3f4b6aa..2ab8bcb 100644 --- a/zml/exe.zig +++ b/zml/exe.zig @@ -12,7 +12,7 @@ const Platform = @import("platform.zig").Platform; const Shape = @import("shape.zig").Shape; const ShapeOf = @import("tensor.zig").ShapeOf; -const log = std.log.scoped(.zml); +const log = std.log.scoped(.@"zml/exe"); test { std.testing.refAllDecls(@This()); @@ -135,6 +135,9 @@ pub const BaseExe = struct { /// The PJRT executable representing the compiled module. exe: *pjrt.LoadedExecutable, + /// The execution context for this executable. + context: ?*pjrt.ExecuteContext = null, + /// Pre-allocated slice of buffers to use as inputs when the module is called. input_per_device: []const [*]*pjrt.Buffer, @@ -199,6 +202,9 @@ pub const BaseExe = struct { } pub fn deinit(self: BaseExe) void { + if (self.context) |ctx| { + ctx.deinit(self.platform.pjrt_api); + } self._arena.deinit(); } @@ -220,6 +226,7 @@ pub const BaseExe = struct { // even if it has been marked as "can be donated" during compilation. // TODO: expose it ? .non_donatable_input_indices = &.{}, + .context = self.context, }) catch |err| { std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err}); }; @@ -288,11 +295,13 @@ pub const BaseExe = struct { } pub fn clone(self: BaseExe, parent_allocator: std.mem.Allocator) !BaseExe { - return .init(parent_allocator, self.platform, self.exe, .{ - .input_shapes = self.input_shapes, + var exe: BaseExe = try .init(parent_allocator, self.platform, self.exe, .{ + .n_in = self.input_buffer_count, .result_shapes = self.result_shapes, .n_devices = self.num_devices, }); + exe.context = self.context; + return exe; } }; diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index bd094ea..dd1a563 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -8,6 +8,7 @@ pub const ffi = pjrt.ffi; pub const Profiler = pjrt.Profiler; pub const ApiError = pjrt.ApiError; pub const ErrorCode = pjrt.ErrorCode; +pub const ExecuteContext = pjrt.ExecuteContext; pub const BufferType = pjrt.BufferType; pub const Device = pjrt.Device; pub const DeviceDescription = pjrt.DeviceDescription; @@ -20,6 +21,7 @@ pub const SerializeResult = pjrt.SerializeResult; pub const Executable = pjrt.Executable; pub const ExecuteError = ApiError; pub const Memory = pjrt.Memory; +pub const Stream = pjrt.Stream; const log = std.log.scoped(.zml); @@ -120,7 +122,7 @@ pub const Client = opaque { return self.inner().addressableMemories(api); } - pub fn memoryByKind(self: *const Client, api: *const Api, kind: Memory.Kind) ?*Memory { + pub fn memoryByKind(self: *const Client, api: *const Api, kind: Memory.Kind) ?*const Memory { for (self.addressableMemories(api)) |mem| { if (mem.kind(api) == kind) { return mem; @@ -182,7 +184,7 @@ pub const Buffer = opaque { } pub fn copyToMemory(self: *const Buffer, api: *const Api, memory_: *const Memory) ApiError!*Buffer { - return @ptrCast(self.inner().copyToMemory(api, memory_)); + return @ptrCast(try self.inner().copyToMemory(api, memory_)); } pub fn getReadyEvent(self: *const Buffer, api: *const Api) ?*Event { @@ -262,6 +264,7 @@ pub const LoadedExecutable = opaque { results: []const [*]*Buffer, events: []?*Event, non_donatable_input_indices: []const i64 = &.{}, + context: ?*ExecuteContext, }; pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ExecuteError!void { @@ -271,6 +274,7 @@ pub const LoadedExecutable = opaque { .results = @ptrCast(args.results), .events = @ptrCast(args.events), .non_donatable_input_indices = args.non_donatable_input_indices, + .context = args.context, } }); } diff --git a/zml/platform.zig b/zml/platform.zig index ec4ae4c..f0d59ae 100644 --- a/zml/platform.zig +++ b/zml/platform.zig @@ -75,6 +75,17 @@ pub const Platform = struct { return res; } + pub fn registerFFIType(self: Platform, comptime T: type) !void { + if (self.pjrt_api.ffi()) |ffi| { + if (!@hasDecl(T, "type_id")) { + stdx.debug.panic("registerFFIType requires type {s} to have a `type_id` i64 field ", .{@typeName(T)}); + } + try ffi.registerTypeId(self.pjrt_api, T); + } else { + stdx.debug.panic("registerFFIType is not available for target {s}", .{@tagName(self.target)}); + } + } + pub fn deinit(self: *Platform) void { self.pjrt_client.deinit(self.pjrt_api); }