Add preliminary implementation for custom call support.

This commit is contained in:
Tarry Singh 2024-12-10 09:36:37 +00:00
parent 1d5b79111a
commit 6aa9aa5a7b
7 changed files with 291 additions and 111 deletions

View File

@ -91,27 +91,25 @@ fn TransmuteMixin(comptime T: type, comptime InnerT: type) type {
pub const Api = opaque { pub const Api = opaque {
pub const inner = TransmuteMixin(Api, c.XLA_FFI_Api).to; pub const inner = TransmuteMixin(Api, c.XLA_FFI_Api).to;
pub fn getStream(self: *const Api, context: ?*ExecutionContext) ApiError!*anyopaque { pub fn stream(self: *const Api, context: *const ExecutionContext) *Stream {
var ret = pjrtStruct(c.XLA_FFI_Stream_Get_Args{ var ret = pjrtStruct(c.XLA_FFI_Stream_Get_Args{
.ctx = if (context) |ctx| ctx.inner() else null, .ctx = @constCast(context.inner()),
}); });
const result = self.inner().XLA_FFI_Stream_Get.?(&ret); const result = self.inner().XLA_FFI_Stream_Get.?(&ret);
if (result) |ffi_error| { if (result) |ffi_error| {
const err = Error.fromInner(ffi_error); const err = Error.fromInner(ffi_error);
defer err.destroy(self); defer err.destroy(self);
log.err("[Api.getStream] {s}", .{err.getMessage(self)}); log.err("[Api.getStream] {s}", .{err.getMessage(self)});
// TODO(Corentin): Retrieve error code from Error when implemented in XLA. @panic("failed to get stream");
return error.Unknown;
} }
return ret.stream.?; return @ptrCast(ret.stream.?);
} }
pub fn allocateDeviceMemory(self: *const Api, context: ?*ExecutionContext, size: usize, alignment: usize) ApiError!*anyopaque { pub fn allocateDeviceMemory(self: *const Api, context: *const ExecutionContext, size: usize, alignment: usize) ApiError!*anyopaque {
var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Allocate_Args{ var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Allocate_Args{
.ctx = if (context) |ctx| ctx.inner() else null, .ctx = @constCast(context.inner()),
.size = size, .size = size,
.alignment = alignment, .alignment = alignment,
}); });
@ -129,9 +127,9 @@ pub const Api = opaque {
return ret.data.?; return ret.data.?;
} }
pub fn freeDeviceMemory(self: *const Api, context: ?*ExecutionContext, data: *anyopaque, size: usize) ApiError!void { pub fn freeDeviceMemory(self: *const Api, context: *const ExecutionContext, data: *anyopaque, size: usize) ApiError!void {
var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Free_Args{ var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Free_Args{
.ctx = if (context) |ctx| ctx.inner() else null, .ctx = @constCast(context.inner()),
.size = size, .size = size,
.data = data, .data = data,
}); });
@ -165,36 +163,13 @@ pub const ExecutionStage = enum(c.XLA_FFI_ExecutionStage) {
pub const ExecutionContext = opaque { pub const ExecutionContext = opaque {
pub const inner = TransmuteMixin(ExecutionContext, c.XLA_FFI_ExecutionContext).to; pub const inner = TransmuteMixin(ExecutionContext, c.XLA_FFI_ExecutionContext).to;
// pub fn attach(self: *ExecutionContext, api: *const Api, value: anytype) ApiError!void { pub fn Context(comptime T: type) type {
// // register type id ==> typeid return struct {
// const typename_ = "zml." ++ @typeName(@TypeOf(value)); pub fn get(self: *const ExecutionContext, api: *const Api) ApiError!*T {
const type_id: TypeId = .{ .type_id = T.type_id };
// var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Register_Args{
// .ctx = self.inner(),
// .handler = @ptrCast(@alignCast(handler)),
// });
// const result = api.inner().XLA_FFI_ExecutionContext_Register.?(&ret);
// var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Register_Args{
// .ctx = self.inner(),
// .handler = @ptrCast(@alignCast(handler)),
// });
// const result = api.inner().XLA_FFI_ExecutionContext_Register.?(&ret);
// if (result) |ffi_error| {
// const err = Error.fromInner(ffi_error);
// defer err.destroy(api);
// log.err("[ExecutionContext.register] {s}", .{err.getMessage(api)});
// // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
// return error.Unknown;
// }
// }
pub fn get(self: *ExecutionContext, api: *const Api, type_id: *TypeId) ApiError!*anyopaque {
var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Get_Args{ var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Get_Args{
.ctx = self.inner(), .ctx = @constCast(self.inner()),
.type_id = @ptrCast(@alignCast(type_id)), .type_id = @constCast(&type_id.toCStruct()),
}); });
const result = api.inner().XLA_FFI_ExecutionContext_Get.?(&ret); const result = api.inner().XLA_FFI_ExecutionContext_Get.?(&ret);
@ -207,12 +182,65 @@ pub const ExecutionContext = opaque {
return error.Unknown; return error.Unknown;
} }
return ret.data.?; if (ret.data == null) return error.NotFound;
return @ptrCast(@alignCast(ret.data.?));
}
};
} }
// TODO getDeviceOrdinal() pub fn getDeviceOrdinal(self: *const ExecutionContext, api: *const Api) ApiError!i32 {
var ret = pjrtStruct(c.XLA_FFI_DeviceOrdinal_Get_Args{
.ctx = @constCast(self.inner()),
});
const result = api.inner().XLA_FFI_DeviceOrdinal_Get.?(&ret);
if (result) |ffi_error| {
const err = Error.fromInner(ffi_error);
defer err.destroy(api);
log.err("[ExecutionContext.getDeviceOrdinal] {s}", .{err.getMessage(api)});
// TODO(Corentin): Retrieve error code from Error when implemented in XLA.
return error.Unknown;
}
return ret.device_ordinal;
}
/// Submits `task` through the FFI API's `XLA_FFI_ThreadPool_Schedule` entry
/// point, forwarding `data` as the task argument.
///
/// On failure the FFI error message is logged, the FFI error object is
/// destroyed, and `error.Unknown` is returned (the real error code is not
/// retrieved yet, see the TODO below).
pub fn scheduleTask(self: *const ExecutionContext, api: *const Api, task: *const Task, data: *anyopaque) ApiError!void {
    var ret = pjrtStruct(c.XLA_FFI_ThreadPool_Schedule_Args{
        // The C struct takes a mutable context pointer.
        .ctx = @constCast(self.inner()),
        .task = @ptrCast(@alignCast(task)),
        .data = @ptrCast(@alignCast(data)),
    });

    const result = api.inner().XLA_FFI_ThreadPool_Schedule.?(&ret);
    if (result) |ffi_error| {
        const err = Error.fromInner(ffi_error);
        defer err.destroy(api);
        log.err("[ExecutionContext.scheduleTask] {s}", .{err.getMessage(api)});
        // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
        return error.Unknown;
    }
}
/// Builds a `TypeId` whose `type_id` is the 64-bit FNV-1a hash of
/// `type_name`, reinterpreted bit-for-bit as a signed integer.
fn getTypeId(type_name: []const u8) TypeId {
    const hashed: u64 = std.hash.Fnv1a_64.hash(type_name);
    return .{ .type_id = @as(i64, @bitCast(hashed)) };
}
}; };
const TypeId = c.XLA_FFI_TypeId;
const Task = fn (*anyopaque) void;
const Stream = @import("pjrt.zig").Stream;
const ByteSpan = extern struct { const ByteSpan = extern struct {
ptr: [*]const u8, ptr: [*]const u8,
len: usize, len: usize,
@ -222,17 +250,13 @@ const ByteSpan = extern struct {
} }
}; };
pub const TypeId = extern struct {
type_id: i64,
};
pub const DataType = enum(c.XLA_FFI_DataType) { pub const DataType = enum(c.XLA_FFI_DataType) {
invalid = c.XLA_FFI_DataType_INVALID, invalid = c.XLA_FFI_DataType_INVALID,
pred = c.XLA_FFI_DataType_PRED, pred = c.XLA_FFI_DataType_PRED,
s8 = c.XLA_FFI_DataType_S8, i8 = c.XLA_FFI_DataType_S8,
s16 = c.XLA_FFI_DataType_S16, i16 = c.XLA_FFI_DataType_S16,
s32 = c.XLA_FFI_DataType_S32, i32 = c.XLA_FFI_DataType_S32,
s64 = c.XLA_FFI_DataType_S64, i64 = c.XLA_FFI_DataType_S64,
u8 = c.XLA_FFI_DataType_U8, u8 = c.XLA_FFI_DataType_U8,
u16 = c.XLA_FFI_DataType_U16, u16 = c.XLA_FFI_DataType_U16,
u32 = c.XLA_FFI_DataType_U32, u32 = c.XLA_FFI_DataType_U32,
@ -289,9 +313,8 @@ pub const Args = extern struct {
buffer = c.XLA_FFI_ArgType_BUFFER, buffer = c.XLA_FFI_ArgType_BUFFER,
}; };
pub fn get(self: Args, i: usize) *const Buffer { pub fn buffers(self: Args) []*const Buffer {
std.debug.assert(self.types[0..self.len][i] == .buffer); return self.ptr[0..self.len];
return self.ptr[0..self.len][i];
} }
}; };
@ -306,9 +329,8 @@ pub const Rets = extern struct {
buffer = c.XLA_FFI_RetType_BUFFER, buffer = c.XLA_FFI_RetType_BUFFER,
}; };
pub fn get(self: Rets, i: usize) *const Buffer { pub fn buffers(self: Rets) []*const Buffer {
std.debug.assert(self.types[0..self.len][i] == .buffer); return self.ptr[0..self.len];
return self.ptr[0..self.len][i];
} }
}; };
@ -346,8 +368,18 @@ pub const Attrs = extern struct {
dtype: DataType, dtype: DataType,
len: usize, len: usize,
data: [*]const u8, data: [*]const u8,
/// Reinterprets `data` as a many-item pointer to `T` and returns its first
/// `len` elements.
/// NOTE(review): assumes `len` counts elements of `T` (not bytes) and that
/// `T` matches `dtype` — confirm against callers.
pub fn slice(self: Array, T: type) []const T {
    return @as([*]const T, @alignCast(@ptrCast(self.data)))[0..self.len];
}
}; };
/// Views the bytes of `self.data` as `self.len` consecutive values of `T`.
/// NOTE(review): assumes `len` is an element count and that `data` is
/// suitably aligned for `T` — confirm against callers.
pub fn slice(self: Array, T: type) []const T {
    const count = self.len;
    const typed: [*]const T = @alignCast(@ptrCast(self.data));
    return typed[0..count];
}
pub fn getByIndex(self: Attrs, comptime attr_type: AttrType, index: usize) ?*const @FieldType(Attr, @tagName(attr_type)) { pub fn getByIndex(self: Attrs, comptime attr_type: AttrType, index: usize) ?*const @FieldType(Attr, @tagName(attr_type)) {
const attr = self.ptr[0..self.len][index]; const attr = self.ptr[0..self.len][index];
const actual_type = self.types[index]; const actual_type = self.types[index];
@ -370,8 +402,8 @@ pub const Attrs = extern struct {
pub const CallFrame = extern struct { pub const CallFrame = extern struct {
struct_size: usize, struct_size: usize,
extension_start: ?*ExtensionBase, extension_start: ?*ExtensionBase,
api: ?*const Api, api: *const Api,
ctx: ?*const ExecutionContext, ctx: *const ExecutionContext,
stage: ExecutionStage, stage: ExecutionStage,
args: Args, args: Args,
results: Rets, results: Rets,
@ -438,7 +470,7 @@ pub const Error = opaque {
pub const inner = TransmuteMixin(Error, c.XLA_FFI_Error).to; pub const inner = TransmuteMixin(Error, c.XLA_FFI_Error).to;
pub const fromInner = TransmuteMixin(Error, c.XLA_FFI_Error).from; pub const fromInner = TransmuteMixin(Error, c.XLA_FFI_Error).from;
pub fn create(api: *const Api, error_code: ErrorCode, message: [:0]const u8) *Error { pub fn create(api: *const Api, error_code: ErrorCode, message: []const u8) *Error {
var ret = pjrtStruct(c.XLA_FFI_Error_Create_Args{ var ret = pjrtStruct(c.XLA_FFI_Error_Create_Args{
.message = message.ptr, .message = message.ptr,
.errc = @intFromEnum(error_code), .errc = @intFromEnum(error_code),

View File

@ -31,7 +31,7 @@ fn pjrtStructSize(comptime T: type) usize {
return @field(c, typedef_name ++ "_STRUCT_SIZE"); return @field(c, typedef_name ++ "_STRUCT_SIZE");
} }
inline fn pjrtStruct(v: anytype) @TypeOf(v) { pub inline fn pjrtStruct(v: anytype) @TypeOf(v) {
var ret = v; var ret = v;
ret.struct_size = pjrtStructSize(@TypeOf(v)); ret.struct_size = pjrtStructSize(@TypeOf(v));
return ret; return ret;
@ -160,9 +160,14 @@ pub const Api = struct {
return state.str; return state.str;
} }
pub fn customCallRegistry(api: *const Api) ?CustomCallRegistry { pub fn createExecuteContext(api: *const Api) ApiError!*ExecuteContext {
const ret = try api.call(.PJRT_ExecuteContext_Create, .{});
return @ptrCast(ret.context.?);
}
pub fn ffi(api: *const Api) ?FFI {
if (api.lookupExtension(c.PJRT_FFI_Extension, c.PJRT_Extension_Type_FFI)) |ext| { if (api.lookupExtension(c.PJRT_FFI_Extension, c.PJRT_Extension_Type_FFI)) |ext| {
return .{ .inner = ext.register_handler.? }; return .{ .inner = ext };
} }
return null; return null;
} }
@ -279,6 +284,8 @@ pub const ShapeSpec = extern struct {
} }
}; };
/// Opaque stream handle type. It has no declarations of its own and is only
/// meaningful behind a pointer (e.g. the optional `stream` field of
/// `CreateViewOfDeviceBufferArgs`).
pub const Stream = opaque {};
pub const Client = opaque { pub const Client = opaque {
const inner = InnerMixin(c.PJRT_Client).inner; const inner = InnerMixin(c.PJRT_Client).inner;
@ -414,7 +421,7 @@ pub const Client = opaque {
fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.C) void {} fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.C) void {}
}.call, }.call,
on_delete_callback_arg: ?*anyopaque = null, on_delete_callback_arg: ?*anyopaque = null,
stream: ?isize = null, stream: ?*const Stream = null,
}; };
pub fn createViewOfDeviceBuffer(self: *const Client, api: *const Api, args: CreateViewOfDeviceBufferArgs) ApiError!*Buffer { pub fn createViewOfDeviceBuffer(self: *const Client, api: *const Api, args: CreateViewOfDeviceBufferArgs) ApiError!*Buffer {
@ -429,7 +436,7 @@ pub const Client = opaque {
.device = @ptrCast(@constCast(args.device)), .device = @ptrCast(@constCast(args.device)),
.on_delete_callback = args.on_delete_callback, .on_delete_callback = args.on_delete_callback,
.on_delete_callback_arg = args.on_delete_callback_arg, .on_delete_callback_arg = args.on_delete_callback_arg,
.stream = if (args.stream) |stream| stream else 0, .stream = @bitCast(@intFromPtr(args.stream)),
}); });
return @ptrCast(ret.buffer.?); return @ptrCast(ret.buffer.?);
} }
@ -444,20 +451,19 @@ pub const Client = opaque {
return &.{}; return &.{};
} }
pub fn dmaMap(self: *const Client, api: *const Api, data: []const u8) ApiError!*Buffer { pub fn dmaMap(self: *const Client, api: *const Api, data: []const u8) ApiError!void {
const ret = try api.call(.PJRT_Client_DMA_Map, .{ try api.call(.PJRT_Client_DmaMap, .{
.client = self.inner(), .client = self.inner(),
.data = @ptrCast(@constCast(data.ptr)), .data = @ptrCast(@constCast(data.ptr)),
.size = @intCast(data.len), .size = @intCast(data.len),
}); });
return @ptrCast(ret.buffer.?);
} }
pub fn dmaUnmap(self: *const Client, api: *const Api, data: []const u8) void { pub fn dmaUnmap(self: *const Client, api: *const Api, data: []const u8) ApiError!void {
_ = api.call(.PJRT_Client_DMA_Unmap, .{ try api.call(.PJRT_Client_DmaUnmap, .{
.client = self.inner(), .client = self.inner(),
.data = @ptrCast(@constCast(data.ptr)), .data = @ptrCast(@constCast(data.ptr)),
}) catch unreachable; });
} }
pub const CreateBuffersForAsyncHostToDeviceArgs = struct { pub const CreateBuffersForAsyncHostToDeviceArgs = struct {
@ -564,6 +570,14 @@ pub const SerializeResult = struct {
} }
}; };
/// Opaque handle to a PJRT execute context.
pub const ExecuteContext = opaque {
    /// Destroys the context through `PJRT_ExecuteContext_Destroy`.
    ///
    /// Teardown stays best-effort (the function returns `void`), but a
    /// failure is logged instead of being dropped silently.
    pub fn deinit(self: *ExecuteContext, api: *const Api) void {
        // The if/else form discards whatever payload `api.call` returns on
        // success while still giving access to the error on failure.
        if (api.call(.PJRT_ExecuteContext_Destroy, .{
            .context = @ptrCast(self),
        })) |_| {} else |err| {
            log.err("PJRT_ExecuteContext_Destroy failed: {s}", .{@errorName(err)});
        }
    }
};
pub const Executable = opaque { pub const Executable = opaque {
const inner = InnerMixin(c.PJRT_Executable).inner; const inner = InnerMixin(c.PJRT_Executable).inner;
@ -630,6 +644,7 @@ pub const LoadedExecutable = opaque {
results: []const [*]*Buffer, results: []const [*]*Buffer,
events: []?*Event, events: []?*Event,
non_donatable_input_indices: []const i64 = &.{}, non_donatable_input_indices: []const i64 = &.{},
context: ?*ExecuteContext,
}; };
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ApiError!void { pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ApiError!void {
var options = pjrtStruct(c.PJRT_ExecuteOptions{ var options = pjrtStruct(c.PJRT_ExecuteOptions{
@ -640,6 +655,7 @@ pub const LoadedExecutable = opaque {
.launch_id = 0, .launch_id = 0,
.non_donatable_input_indices = @ptrCast(args.non_donatable_input_indices.ptr), .non_donatable_input_indices = @ptrCast(args.non_donatable_input_indices.ptr),
.num_non_donatable_input_indices = args.non_donatable_input_indices.len, .num_non_donatable_input_indices = args.non_donatable_input_indices.len,
.context = @ptrCast(args.context),
}); });
_ = try api.call(.PJRT_LoadedExecutable_Execute, .{ _ = try api.call(.PJRT_LoadedExecutable_Execute, .{
.executable = self.inner(), .executable = self.inner(),
@ -653,7 +669,7 @@ pub const LoadedExecutable = opaque {
}); });
} }
pub fn getExecutable(self: *LoadedExecutable, api: *const Api) ApiError!*Executable { pub fn getExecutable(self: *const LoadedExecutable, api: *const Api) ApiError!*Executable {
const ret = try api.call(.PJRT_LoadedExecutable_GetExecutable, .{ const ret = try api.call(.PJRT_LoadedExecutable_GetExecutable, .{
.loaded_executable = self.inner(), .loaded_executable = self.inner(),
}); });
@ -818,7 +834,7 @@ pub const Buffer = opaque {
return ret.on_device_size_in_bytes; return ret.on_device_size_in_bytes;
} }
pub fn copyToDevice(self: *const Buffer, api: *const Api, device: Device) ApiError!Buffer { pub fn copyToDevice(self: *const Buffer, api: *const Api, device: Device) ApiError!*Buffer {
const ret = try api.call(.PJRT_Buffer_CopyToDevice, .{ const ret = try api.call(.PJRT_Buffer_CopyToDevice, .{
.buffer = self.inner(), .buffer = self.inner(),
.dst_device = device.inner, .dst_device = device.inner,
@ -850,7 +866,7 @@ pub const Buffer = opaque {
return @ptrCast(ret.event); return @ptrCast(ret.event);
} }
pub fn copyToMemory(self: *const Buffer, api: *const Api, dst_memory: *const Memory) ApiError!?*Buffer { pub fn copyToMemory(self: *const Buffer, api: *const Api, dst_memory: *const Memory) ApiError!*Buffer {
const ret = try api.call(.PJRT_Buffer_CopyToMemory, .{ const ret = try api.call(.PJRT_Buffer_CopyToMemory, .{
.buffer = self.inner(), .buffer = self.inner(),
.dst_memory = @ptrCast(@constCast(dst_memory)), .dst_memory = @ptrCast(@constCast(dst_memory)),
@ -932,8 +948,8 @@ pub const Memory = opaque {
pub fn kind(self: *const Memory, api: *const Api) Kind { pub fn kind(self: *const Memory, api: *const Api) Kind {
const ret = api.call(.PJRT_Memory_Kind, .{ .memory = self.inner() }) catch unreachable; const ret = api.call(.PJRT_Memory_Kind, .{ .memory = self.inner() }) catch unreachable;
const kind_ = ret.kind orelse unreachable[0..ret.kind_size]; const kind_ = ret.kind orelse unreachable;
return std.meta.stringToEnum(Kind, kind_) orelse unreachable; return std.meta.stringToEnum(Kind, kind_[0..ret.kind_size]) orelse unreachable;
} }
pub fn kindId(self: *const Memory, api: *const Api) u32 { pub fn kindId(self: *const Memory, api: *const Api) u32 {
@ -1044,21 +1060,6 @@ pub const AsyncHostToDeviceTransferManager = opaque {
} }
}; };
pub const ExecutionContext = opaque {
const inner = InnerMixin(c.PJRT_ExecutionContext).inner;
pub fn init(api: *const Api) ApiError!*ExecutionContext {
const ret = try api.call(.PJRT_ExecutionContext_Create, .{});
return @ptrCast(ret.context.?);
}
pub fn deinit(self: *ExecutionContext, api: *const Api) void {
_ = api.call(.PJRT_ExecutionContext_Destroy, .{
.context = self.inner(),
}) catch unreachable;
}
};
pub const NamedValue = extern struct { pub const NamedValue = extern struct {
comptime { comptime {
std.debug.assert(@sizeOf(NamedValue) == @sizeOf(c.PJRT_NamedValue)); std.debug.assert(@sizeOf(NamedValue) == @sizeOf(c.PJRT_NamedValue));
@ -1164,17 +1165,34 @@ pub const NamedValue = extern struct {
} }
}; };
// todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize pub const FFI = extern struct {
// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5 inner: *const c.PJRT_FFI,
pub const CustomCallRegistry = extern struct {
inner: *const c.PJRT_FFI_Register_Handler,
pub fn registerFfi( pub const UserData = extern struct {
self: *const CustomCallRegistry, type_id: i64,
user_data: *anyopaque,
/// Converts this value into the C `PJRT_FFI_UserData` layout:
/// `type_id` is copied as-is and `user_data` becomes the C `data` member.
fn toCStruct(self: UserData) c.PJRT_FFI_UserData {
    const converted: c.PJRT_FFI_UserData = .{
        .type_id = self.type_id,
        .data = self.user_data,
    };
    return converted;
}
};
/// Optional settings accepted by `register`.
pub const RegisterFfiOptions = struct {
/// Handler traits forwarded as the `traits` member of the register call.
/// Defaults to the raw value 0, i.e. none of the named traits.
traits: RegisterHandlerTraits = @enumFromInt(0),
};
// todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize
// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5
pub fn register(
self: *const FFI,
api: *const Api, api: *const Api,
target_name: []const u8, target_name: []const u8,
platform_name: []const u8, platform_name: []const u8,
func: *const ffi.Handler, func: *const ffi.Handler,
options: RegisterFfiOptions,
) ApiError!void { ) ApiError!void {
var ret = pjrtStruct(c.PJRT_FFI_Register_Handler_Args{ var ret = pjrtStruct(c.PJRT_FFI_Register_Handler_Args{
.api_version = 1, .api_version = 1,
@ -1183,12 +1201,51 @@ pub const CustomCallRegistry = extern struct {
.handler = @ptrCast(@constCast(func)), .handler = @ptrCast(@constCast(func)),
.platform_name = platform_name.ptr, .platform_name = platform_name.ptr,
.platform_name_size = platform_name.len, .platform_name_size = platform_name.len,
.traits = @intFromEnum(options.traits),
}); });
const result = self.inner(&ret); const result = self.inner.register_handler.?(&ret);
if (result) |pjrt_c_error| { if (result) |pjrt_c_error| {
const pjrt_error: *Error = @ptrCast(pjrt_c_error); const pjrt_error: *Error = @ptrCast(pjrt_c_error);
log.err("[GpuRegisterCustomCall] {s}", .{pjrt_error.getMessage(api)}); log.err("registerFfi error: {s}", .{pjrt_error.getMessage(api)});
return pjrt_error.getCode(api).toApiError();
}
}
/// Registers `T` with the plugin's FFI type registry under `@typeName(T)`
/// and stores the id written back by the plugin into `T.type_id`.
///
/// `T` must therefore expose an assignable `type_id` declaration. The
/// request is sent with `type_id = 0` so that the plugin picks the id.
///
/// On failure the plugin's error message is logged and the error code is
/// returned as an `ApiError`; `T.type_id` is left untouched in that case.
pub fn registerTypeId(self: *const FFI, api: *const Api, T: type) ApiError!void {
    const type_name = @typeName(T);
    var ret = pjrtStruct(c.PJRT_FFI_TypeID_Register_Args{
        .type_name = type_name.ptr,
        .type_name_size = type_name.len,
        .type_id = 0, // let the plugin assign a unique type ID
    });

    const result = self.inner.type_id_register.?(&ret);
    if (result) |pjrt_c_error| {
        const pjrt_error: *Error = @ptrCast(pjrt_c_error);
        // Log the plugin message like the other FFI entry points do.
        log.err("registerTypeId error: {s}", .{pjrt_error.getMessage(api)});
        return pjrt_error.getCode(api).toApiError();
    }

    T.type_id = ret.type_id;
}
pub fn addUserData(self: *const FFI, api: *const Api, context: *ExecuteContext, user_data: UserData) ApiError!void {
var ret = pjrtStruct(c.PJRT_FFI_UserData_Add_Args{
.context = @ptrCast(context),
.user_data = user_data.toCStruct(),
});
const result = self.inner.user_data_add.?(&ret);
if (result) |pjrt_c_error| {
const pjrt_error: *Error = @ptrCast(pjrt_c_error);
log.err("addUserData error: {s}", .{pjrt_error.getMessage(api)});
return pjrt_error.getCode(api).toApiError(); return pjrt_error.getCode(api).toApiError();
} }
} }
}; };
/// Traits that can be attached to an FFI handler at registration time.
/// Backed by `c.PJRT_FFI_Handler_TraitsBits` and non-exhaustive, so raw
/// values without a named tag (such as 0) are representable.
pub const RegisterHandlerTraits = enum(c.PJRT_FFI_Handler_TraitsBits) {
command_buffer_compatible = c.PJRT_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE,
_,
};
/// Wrapper holding a pointer to a `PJRT_FFI_Register_Handler` entry point.
/// NOTE(review): handler registration in the visible code goes through
/// `FFI.register`; this type is presumably kept only for compatibility —
/// confirm whether it is still used.
pub const CustomCallRegistry = extern struct {
inner: *const c.PJRT_FFI_Register_Handler,
};

View File

@ -224,7 +224,7 @@ pub const Buffer = struct {
/// Creates a Buffer from a pointer into device memory. /// Creates a Buffer from a pointer into device memory.
/// This allows to interface with other libraries producing buffers. /// This allows to interface with other libraries producing buffers.
pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?isize, device_data: *anyopaque) Buffer { pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const pjrt.Stream, device_data: *anyopaque) Buffer {
const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: { const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
var res: [Shape.MAX_RANK]i64 = undefined; var res: [Shape.MAX_RANK]i64 = undefined;
for (0..Shape.MAX_RANK) |i| { for (0..Shape.MAX_RANK) |i| {

View File

@ -37,6 +37,7 @@ pub const Context = struct {
inline for (comptime std.enums.values(runtimes.Platform)) |t| { inline for (comptime std.enums.values(runtimes.Platform)) |t| {
if (runtimes.load(t)) |api| { if (runtimes.load(t)) |api| {
Context.apis.set(t, api); Context.apis.set(t, api);
if (t == .cuda) cuda.init();
} else |_| {} } else |_| {}
} }
} }
@ -218,10 +219,10 @@ pub const Context = struct {
const CustomCall = struct { const CustomCall = struct {
pub fn registerZmlCustomCalls(platform: Platform) !void { pub fn registerZmlCustomCalls(platform: Platform) !void {
const registry = platform.pjrt_api.customCallRegistry(); const maybe_ffi = platform.pjrt_api.ffi();
if (registry) |reg| { if (maybe_ffi) |ffi| {
try reg.registerFfi(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback); try ffi.register(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback, .{});
} else { } else {
stdx.debug.panic("Registering custom calls failed", .{}); stdx.debug.panic("Registering custom calls failed", .{});
} }
@ -240,12 +241,12 @@ const CustomCall = struct {
const input_buffers = stdx.stackSlice(8, HostBuffer, call_frame.args.len); const input_buffers = stdx.stackSlice(8, HostBuffer, call_frame.args.len);
for (input_buffers, 0..) |*b, i| { for (input_buffers, 0..) |*b, i| {
b.* = hostBufferFromPinnedBuffer(call_frame.args.get(i)); b.* = hostBufferFromPinnedBuffer(call_frame.args.buffers()[i]);
} }
const output_buffers = stdx.stackSlice(8, HostBuffer, call_frame.results.len); const output_buffers = stdx.stackSlice(8, HostBuffer, call_frame.results.len);
for (output_buffers, 0..) |*b, i| { for (output_buffers, 0..) |*b, i| {
b.* = hostBufferFromPinnedBuffer(call_frame.results.get(i)); b.* = hostBufferFromPinnedBuffer(call_frame.results.buffers()[i]);
} }
callback(user_ctx, input_buffers, output_buffers); callback(user_ctx, input_buffers, output_buffers);
@ -258,10 +259,10 @@ fn getShape(buffer_desc: *const pjrt.ffi.Buffer) Shape {
const dt: DataType = switch (buffer_desc.dtype) { const dt: DataType = switch (buffer_desc.dtype) {
.invalid => @panic("invalid ffi"), .invalid => @panic("invalid ffi"),
.pred => .bool, .pred => .bool,
.s8 => .i8, .i8 => .i8,
.s16 => .i16, .i16 => .i16,
.s32 => .i32, .i32 => .i32,
.s64 => .i64, .i64 => .i64,
.token, .f8e4m3, .f8e3m4 => @panic("Unsupported ffi type"), .token, .f8e4m3, .f8e3m4 => @panic("Unsupported ffi type"),
inline else => |t| @field(DataType, @tagName(t)), inline else => |t| @field(DataType, @tagName(t)),
}; };
@ -278,3 +279,69 @@ fn hostBufferFromPinnedBuffer(buffer_desc: *const pjrt.ffi.Buffer) HostBuffer {
buffer_desc.data[0..buffer_shape.byteSize()], buffer_desc.data[0..buffer_shape.byteSize()],
); );
} }
/// Minimal bindings to a handful of CUDA runtime entry points, resolved at
/// runtime from `libcudart.so.12` by `init`.
pub const cuda = struct {
    // Until `init` resolves them, every entry point holds the sentinel
    // address 0xdeadc00da00, so a call made before a successful `init`
    // jumps to a recognizable bogus address.
    pub var streamSynchronize: StreamSynchronize = @ptrFromInt(0xdeadc00da00);
    pub var cuLaunchHostFunc: CuLaunchHostFunc = @ptrFromInt(0xdeadc00da00);
    var _memcpyAsync: MemcpyAsync = @ptrFromInt(0xdeadc00da00);
    var _memcpyBlocking: MemcpyBlocking = @ptrFromInt(0xdeadc00da00);

    /// Handle of the loaded runtime library. It is kept (and never closed)
    /// because the function pointers above point into that library and must
    /// not outlive it.
    var _cudart: ?std.DynLib = null;

    /// Values passed as the `kind` argument of the memcpy entry points.
    pub const MemcpyKind = enum(c_int) {
        host_to_host = 0,
        host_to_device = 1,
        device_to_host = 2,
        device_to_device = 3,
        inferred = 4,
    };

    const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: ?*anyopaque) callconv(.C) c_int;
    const MemcpyBlocking = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind) callconv(.C) c_int;
    const StreamSynchronize = *const fn (stream: *anyopaque) callconv(.C) c_int;
    const CuLaunchHostFunc = *const fn (stream: *anyopaque, host_func: *const fn (user_data: *const anyopaque) callconv(.c) void, user_data: *const anyopaque) callconv(.c) c_int;

    /// Opens `libcudart.so.12` and resolves the entry points used by this
    /// namespace.
    ///
    /// If the library cannot be opened an error is logged and the sentinel
    /// pointers are left in place. If the library opens but a symbol is
    /// missing, this panics.
    pub fn init() void {
        _cudart = std.DynLib.open("libcudart.so.12") catch {
            log.err("cudart not found, callback will segfault", .{});
            return;
        };
        // Deliberately no `close()`: closing the handle could unload the
        // library and leave the pointers stored below dangling.
        const cudart = &_cudart.?;

        _memcpyAsync = cudart.lookup(MemcpyAsync, "cudaMemcpyAsync") orelse {
            @panic("cudaMemcpyAsync not found");
        };
        _memcpyBlocking = cudart.lookup(MemcpyBlocking, "cudaMemcpy") orelse {
            @panic("cudaMemcpy not found");
        };
        streamSynchronize = cudart.lookup(StreamSynchronize, "cudaStreamSynchronize") orelse {
            @panic("cudaStreamSynchronize not found");
        };
        cuLaunchHostFunc = cudart.lookup(CuLaunchHostFunc, "cudaLaunchHostFunc") orelse {
            @panic("cudaLaunchHostFunc not found");
        };
    }

    /// Copies `dst.len` bytes from `src` into `dst` with `cudaMemcpy`
    /// (`device_to_host`). Panics through `check` on a non-zero status.
    pub fn memcpyToHostBlocking(dst: []u8, src: *const anyopaque) void {
        const err = _memcpyBlocking(dst.ptr, src, dst.len, .device_to_host);
        check(err);
    }

    /// Copies `src` to `dst` with `cudaMemcpy` (`host_to_device`).
    /// Panics through `check` on a non-zero status.
    pub fn memcpyToDeviceBlocking(dst: *anyopaque, src: []const u8) void {
        const err = _memcpyBlocking(dst, src.ptr, src.len, .host_to_device);
        check(err);
    }

    /// Issues `cudaMemcpyAsync` (`host_to_device`) of `src` to `dst` on
    /// `stream`. Panics through `check` on a non-zero status.
    pub fn memcpyToDeviceAsync(dst: *anyopaque, src: []const u8, stream: ?*anyopaque) void {
        const err = _memcpyAsync(dst, src.ptr, src.len, .host_to_device, stream);
        check(err);
    }

    /// Issues `cudaMemcpyAsync` (`device_to_host`) of `dst.len` bytes from
    /// `src` into `dst` on `stream`. Panics through `check` on a non-zero
    /// status.
    pub fn memcpyToHostAsync(dst: []u8, src: *const anyopaque, stream: ?*anyopaque) void {
        const err = _memcpyAsync(dst.ptr, src, dst.len, .device_to_host, stream);
        check(err);
    }

    /// Returns when `err` is 0; otherwise panics with the numeric status.
    pub fn check(err: c_int) void {
        if (err == 0) return;
        stdx.debug.panic("CUDA error: {d}", .{err});
    }
};

View File

@ -12,7 +12,7 @@ const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape; const Shape = @import("shape.zig").Shape;
const ShapeOf = @import("tensor.zig").ShapeOf; const ShapeOf = @import("tensor.zig").ShapeOf;
const log = std.log.scoped(.zml); const log = std.log.scoped(.@"zml/exe");
test { test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(@This());
@ -135,6 +135,9 @@ pub const BaseExe = struct {
/// The PJRT executable representing the compiled module. /// The PJRT executable representing the compiled module.
exe: *pjrt.LoadedExecutable, exe: *pjrt.LoadedExecutable,
/// The execution context for this executable.
context: ?*pjrt.ExecuteContext = null,
/// Pre-allocated slice of buffers to use as inputs when the module is called. /// Pre-allocated slice of buffers to use as inputs when the module is called.
input_per_device: []const [*]*pjrt.Buffer, input_per_device: []const [*]*pjrt.Buffer,
@ -199,6 +202,9 @@ pub const BaseExe = struct {
} }
pub fn deinit(self: BaseExe) void { pub fn deinit(self: BaseExe) void {
if (self.context) |ctx| {
ctx.deinit(self.platform.pjrt_api);
}
self._arena.deinit(); self._arena.deinit();
} }
@ -220,6 +226,7 @@ pub const BaseExe = struct {
// even if it has been marked as "can be donated" during compilation. // even if it has been marked as "can be donated" during compilation.
// TODO: expose it ? // TODO: expose it ?
.non_donatable_input_indices = &.{}, .non_donatable_input_indices = &.{},
.context = self.context,
}) catch |err| { }) catch |err| {
std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err}); std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err});
}; };
@ -288,11 +295,13 @@ pub const BaseExe = struct {
} }
pub fn clone(self: BaseExe, parent_allocator: std.mem.Allocator) !BaseExe { pub fn clone(self: BaseExe, parent_allocator: std.mem.Allocator) !BaseExe {
return .init(parent_allocator, self.platform, self.exe, .{ var exe: BaseExe = try .init(parent_allocator, self.platform, self.exe, .{
.input_shapes = self.input_shapes, .n_in = self.input_buffer_count,
.result_shapes = self.result_shapes, .result_shapes = self.result_shapes,
.n_devices = self.num_devices, .n_devices = self.num_devices,
}); });
exe.context = self.context;
return exe;
} }
}; };

View File

@ -8,6 +8,7 @@ pub const ffi = pjrt.ffi;
pub const Profiler = pjrt.Profiler; pub const Profiler = pjrt.Profiler;
pub const ApiError = pjrt.ApiError; pub const ApiError = pjrt.ApiError;
pub const ErrorCode = pjrt.ErrorCode; pub const ErrorCode = pjrt.ErrorCode;
pub const ExecuteContext = pjrt.ExecuteContext;
pub const BufferType = pjrt.BufferType; pub const BufferType = pjrt.BufferType;
pub const Device = pjrt.Device; pub const Device = pjrt.Device;
pub const DeviceDescription = pjrt.DeviceDescription; pub const DeviceDescription = pjrt.DeviceDescription;
@ -20,6 +21,7 @@ pub const SerializeResult = pjrt.SerializeResult;
pub const Executable = pjrt.Executable; pub const Executable = pjrt.Executable;
pub const ExecuteError = ApiError; pub const ExecuteError = ApiError;
pub const Memory = pjrt.Memory; pub const Memory = pjrt.Memory;
pub const Stream = pjrt.Stream;
const log = std.log.scoped(.zml); const log = std.log.scoped(.zml);
@ -120,7 +122,7 @@ pub const Client = opaque {
return self.inner().addressableMemories(api); return self.inner().addressableMemories(api);
} }
pub fn memoryByKind(self: *const Client, api: *const Api, kind: Memory.Kind) ?*Memory { pub fn memoryByKind(self: *const Client, api: *const Api, kind: Memory.Kind) ?*const Memory {
for (self.addressableMemories(api)) |mem| { for (self.addressableMemories(api)) |mem| {
if (mem.kind(api) == kind) { if (mem.kind(api) == kind) {
return mem; return mem;
@ -182,7 +184,7 @@ pub const Buffer = opaque {
} }
pub fn copyToMemory(self: *const Buffer, api: *const Api, memory_: *const Memory) ApiError!*Buffer { pub fn copyToMemory(self: *const Buffer, api: *const Api, memory_: *const Memory) ApiError!*Buffer {
return @ptrCast(self.inner().copyToMemory(api, memory_)); return @ptrCast(try self.inner().copyToMemory(api, memory_));
} }
pub fn getReadyEvent(self: *const Buffer, api: *const Api) ?*Event { pub fn getReadyEvent(self: *const Buffer, api: *const Api) ?*Event {
@ -262,6 +264,7 @@ pub const LoadedExecutable = opaque {
results: []const [*]*Buffer, results: []const [*]*Buffer,
events: []?*Event, events: []?*Event,
non_donatable_input_indices: []const i64 = &.{}, non_donatable_input_indices: []const i64 = &.{},
context: ?*ExecuteContext,
}; };
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ExecuteError!void { pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ExecuteError!void {
@ -271,6 +274,7 @@ pub const LoadedExecutable = opaque {
.results = @ptrCast(args.results), .results = @ptrCast(args.results),
.events = @ptrCast(args.events), .events = @ptrCast(args.events),
.non_donatable_input_indices = args.non_donatable_input_indices, .non_donatable_input_indices = args.non_donatable_input_indices,
.context = args.context,
} }); } });
} }

View File

@ -75,6 +75,17 @@ pub const Platform = struct {
return res; return res;
} }
/// Registers `T` as an FFI type with this platform's PJRT plugin.
///
/// `T` must declare `type_id`; this is verified at compile time since both
/// `T` and the presence of the declaration are comptime-known. Panics at
/// runtime when the platform's PJRT API exposes no FFI extension, and
/// propagates any error returned by the registration call.
pub fn registerFFIType(self: Platform, comptime T: type) !void {
    if (!@hasDecl(T, "type_id")) {
        @compileError("registerFFIType requires type " ++ @typeName(T) ++ " to have a `type_id` i64 declaration");
    }

    if (self.pjrt_api.ffi()) |ffi| {
        try ffi.registerTypeId(self.pjrt_api, T);
    } else {
        stdx.debug.panic("registerFFIType is not available for target {s}", .{@tagName(self.target)});
    }
}
pub fn deinit(self: *Platform) void { pub fn deinit(self: *Platform) void {
self.pjrt_client.deinit(self.pjrt_api); self.pjrt_client.deinit(self.pjrt_api);
} }