Add preliminary implementation for custom call support.

This commit is contained in:
Tarry Singh 2024-12-10 09:36:37 +00:00
parent 1d5b79111a
commit 6aa9aa5a7b
7 changed files with 291 additions and 111 deletions

View File

@ -91,27 +91,25 @@ fn TransmuteMixin(comptime T: type, comptime InnerT: type) type {
pub const Api = opaque {
pub const inner = TransmuteMixin(Api, c.XLA_FFI_Api).to;
pub fn getStream(self: *const Api, context: ?*ExecutionContext) ApiError!*anyopaque {
pub fn stream(self: *const Api, context: *const ExecutionContext) *Stream {
var ret = pjrtStruct(c.XLA_FFI_Stream_Get_Args{
.ctx = if (context) |ctx| ctx.inner() else null,
.ctx = @constCast(context.inner()),
});
const result = self.inner().XLA_FFI_Stream_Get.?(&ret);
if (result) |ffi_error| {
const err = Error.fromInner(ffi_error);
defer err.destroy(self);
log.err("[Api.getStream] {s}", .{err.getMessage(self)});
// TODO(Corentin): Retrieve error code from Error when implemented in XLA.
return error.Unknown;
@panic("failed to get stream");
}
return ret.stream.?;
return @ptrCast(ret.stream.?);
}
pub fn allocateDeviceMemory(self: *const Api, context: ?*ExecutionContext, size: usize, alignment: usize) ApiError!*anyopaque {
pub fn allocateDeviceMemory(self: *const Api, context: *const ExecutionContext, size: usize, alignment: usize) ApiError!*anyopaque {
var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Allocate_Args{
.ctx = if (context) |ctx| ctx.inner() else null,
.ctx = @constCast(context.inner()),
.size = size,
.alignment = alignment,
});
@ -129,9 +127,9 @@ pub const Api = opaque {
return ret.data.?;
}
pub fn freeDeviceMemory(self: *const Api, context: ?*ExecutionContext, data: *anyopaque, size: usize) ApiError!void {
pub fn freeDeviceMemory(self: *const Api, context: *const ExecutionContext, data: *anyopaque, size: usize) ApiError!void {
var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Free_Args{
.ctx = if (context) |ctx| ctx.inner() else null,
.ctx = @constCast(context.inner()),
.size = size,
.data = data,
});
@ -165,36 +163,13 @@ pub const ExecutionStage = enum(c.XLA_FFI_ExecutionStage) {
pub const ExecutionContext = opaque {
pub const inner = TransmuteMixin(ExecutionContext, c.XLA_FFI_ExecutionContext).to;
// pub fn attach(self: *ExecutionContext, api: *const Api, value: anytype) ApiError!void {
// // register type id ==> typeid
// const typename_ = "zml." ++ @typeName(@TypeOf(value));
// var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Register_Args{
// .ctx = self.inner(),
// .handler = @ptrCast(@alignCast(handler)),
// });
// const result = api.inner().XLA_FFI_ExecutionContext_Register.?(&ret);
// var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Register_Args{
// .ctx = self.inner(),
// .handler = @ptrCast(@alignCast(handler)),
// });
// const result = api.inner().XLA_FFI_ExecutionContext_Register.?(&ret);
// if (result) |ffi_error| {
// const err = Error.fromInner(ffi_error);
// defer err.destroy(api);
// log.err("[ExecutionContext.register] {s}", .{err.getMessage(api)});
// // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
// return error.Unknown;
// }
// }
pub fn get(self: *ExecutionContext, api: *const Api, type_id: *TypeId) ApiError!*anyopaque {
pub fn Context(comptime T: type) type {
return struct {
pub fn get(self: *const ExecutionContext, api: *const Api) ApiError!*T {
const type_id: TypeId = .{ .type_id = T.type_id };
var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Get_Args{
.ctx = self.inner(),
.type_id = @ptrCast(@alignCast(type_id)),
.ctx = @constCast(self.inner()),
.type_id = @constCast(&type_id.toCStruct()),
});
const result = api.inner().XLA_FFI_ExecutionContext_Get.?(&ret);
@ -207,12 +182,65 @@ pub const ExecutionContext = opaque {
return error.Unknown;
}
return ret.data.?;
if (ret.data == null) return error.NotFound;
return @ptrCast(@alignCast(ret.data.?));
}
};
}
// TODO getDeviceOrdinal()
/// Queries the device ordinal associated with this execution context through
/// the `XLA_FFI_DeviceOrdinal_Get` entry point of `api`.
///
/// Returns `error.Unknown` (after logging the FFI error message) when the
/// call reports an error.
pub fn getDeviceOrdinal(self: *const ExecutionContext, api: *const Api) ApiError!i32 {
    var args = pjrtStruct(c.XLA_FFI_DeviceOrdinal_Get_Args{
        .ctx = @constCast(self.inner()),
    });

    // A null result means success: the ordinal has been written into `args`.
    const ffi_error = api.inner().XLA_FFI_DeviceOrdinal_Get.?(&args) orelse return args.device_ordinal;

    const failure = Error.fromInner(ffi_error);
    defer failure.destroy(api);
    log.err("[ExecutionContext.getDeviceOrdinal] {s}", .{failure.getMessage(api)});
    // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
    return error.Unknown;
}
/// Submits `task` together with its opaque `data` argument to the
/// `XLA_FFI_ThreadPool_Schedule` entry point of `api`, using this execution
/// context.
///
/// NOTE(review): presumably the runtime later invokes `task(data)` on its own
/// thread pool, so `data` must outlive that call — confirm against the XLA FFI
/// C API header.
///
/// Returns `error.Unknown` (after logging the FFI error message) when the
/// call reports an error.
pub fn scheduleTask(self: *const ExecutionContext, api: *const Api, task: *const Task, data: *anyopaque) ApiError!void {
    var ret = pjrtStruct(c.XLA_FFI_ThreadPool_Schedule_Args{
        .ctx = @constCast(self.inner()),
        .task = @ptrCast(@alignCast(task)),
        .data = @ptrCast(@alignCast(data)),
    });
    const result = api.inner().XLA_FFI_ThreadPool_Schedule.?(&ret);
    if (result) |ffi_error| {
        const err = Error.fromInner(ffi_error);
        // The message is read before the error object is released.
        defer err.destroy(api);
        log.err("[ExecutionContext.scheduleTask] {s}", .{err.getMessage(api)});
        // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
        return error.Unknown;
    }
}
/// Derives a `TypeId` from `type_name` by reinterpreting the 64-bit FNV-1a
/// hash of the name as a signed integer.
fn getTypeId(type_name: []const u8) TypeId {
    const hashed: u64 = std.hash.Fnv1a_64.hash(type_name);
    const signed: i64 = @bitCast(hashed);
    return .{ .type_id = signed };
}
};
const TypeId = c.XLA_FFI_TypeId;
const Task = fn (*anyopaque) void;
const Stream = @import("pjrt.zig").Stream;
const ByteSpan = extern struct {
ptr: [*]const u8,
len: usize,
@ -222,17 +250,13 @@ const ByteSpan = extern struct {
}
};
pub const TypeId = extern struct {
type_id: i64,
};
pub const DataType = enum(c.XLA_FFI_DataType) {
invalid = c.XLA_FFI_DataType_INVALID,
pred = c.XLA_FFI_DataType_PRED,
s8 = c.XLA_FFI_DataType_S8,
s16 = c.XLA_FFI_DataType_S16,
s32 = c.XLA_FFI_DataType_S32,
s64 = c.XLA_FFI_DataType_S64,
i8 = c.XLA_FFI_DataType_S8,
i16 = c.XLA_FFI_DataType_S16,
i32 = c.XLA_FFI_DataType_S32,
i64 = c.XLA_FFI_DataType_S64,
u8 = c.XLA_FFI_DataType_U8,
u16 = c.XLA_FFI_DataType_U16,
u32 = c.XLA_FFI_DataType_U32,
@ -289,9 +313,8 @@ pub const Args = extern struct {
buffer = c.XLA_FFI_ArgType_BUFFER,
};
pub fn get(self: Args, i: usize) *const Buffer {
std.debug.assert(self.types[0..self.len][i] == .buffer);
return self.ptr[0..self.len][i];
pub fn buffers(self: Args) []*const Buffer {
return self.ptr[0..self.len];
}
};
@ -306,9 +329,8 @@ pub const Rets = extern struct {
buffer = c.XLA_FFI_RetType_BUFFER,
};
pub fn get(self: Rets, i: usize) *const Buffer {
std.debug.assert(self.types[0..self.len][i] == .buffer);
return self.ptr[0..self.len][i];
pub fn buffers(self: Rets) []*const Buffer {
return self.ptr[0..self.len];
}
};
@ -346,8 +368,18 @@ pub const Attrs = extern struct {
dtype: DataType,
len: usize,
data: [*]const u8,
pub fn slice(self: Array, T: type) []const T {
const ptr: [*]const T = @alignCast(@ptrCast(self.data));
return ptr[0..self.len];
}
};
pub fn slice(self: Array, T: type) []const T {
const ptr: [*]const T = @alignCast(@ptrCast(self.data));
return ptr[0..self.len];
}
pub fn getByIndex(self: Attrs, comptime attr_type: AttrType, index: usize) ?*const @FieldType(Attr, @tagName(attr_type)) {
const attr = self.ptr[0..self.len][index];
const actual_type = self.types[index];
@ -370,8 +402,8 @@ pub const Attrs = extern struct {
pub const CallFrame = extern struct {
struct_size: usize,
extension_start: ?*ExtensionBase,
api: ?*const Api,
ctx: ?*const ExecutionContext,
api: *const Api,
ctx: *const ExecutionContext,
stage: ExecutionStage,
args: Args,
results: Rets,
@ -438,7 +470,7 @@ pub const Error = opaque {
pub const inner = TransmuteMixin(Error, c.XLA_FFI_Error).to;
pub const fromInner = TransmuteMixin(Error, c.XLA_FFI_Error).from;
pub fn create(api: *const Api, error_code: ErrorCode, message: [:0]const u8) *Error {
pub fn create(api: *const Api, error_code: ErrorCode, message: []const u8) *Error {
var ret = pjrtStruct(c.XLA_FFI_Error_Create_Args{
.message = message.ptr,
.errc = @intFromEnum(error_code),

View File

@ -31,7 +31,7 @@ fn pjrtStructSize(comptime T: type) usize {
return @field(c, typedef_name ++ "_STRUCT_SIZE");
}
inline fn pjrtStruct(v: anytype) @TypeOf(v) {
pub inline fn pjrtStruct(v: anytype) @TypeOf(v) {
var ret = v;
ret.struct_size = pjrtStructSize(@TypeOf(v));
return ret;
@ -160,9 +160,14 @@ pub const Api = struct {
return state.str;
}
pub fn customCallRegistry(api: *const Api) ?CustomCallRegistry {
/// Creates a new execute context by calling `PJRT_ExecuteContext_Create`.
///
/// Propagates any `ApiError` reported by the call. The returned pointer is
/// the `context` field of the call result, cast to `*ExecuteContext`.
pub fn createExecuteContext(api: *const Api) ApiError!*ExecuteContext {
    const created = try api.call(.PJRT_ExecuteContext_Create, .{});
    const raw_context = created.context.?;
    return @ptrCast(raw_context);
}
pub fn ffi(api: *const Api) ?FFI {
if (api.lookupExtension(c.PJRT_FFI_Extension, c.PJRT_Extension_Type_FFI)) |ext| {
return .{ .inner = ext.register_handler.? };
return .{ .inner = ext };
}
return null;
}
@ -279,6 +284,8 @@ pub const ShapeSpec = extern struct {
}
};
/// Opaque handle type with no fields or methods; only ever used behind a pointer.
/// NOTE(review): presumably stands for a platform stream (e.g. a CUDA stream) — confirm with callers.
pub const Stream = opaque {};
pub const Client = opaque {
const inner = InnerMixin(c.PJRT_Client).inner;
@ -414,7 +421,7 @@ pub const Client = opaque {
fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.C) void {}
}.call,
on_delete_callback_arg: ?*anyopaque = null,
stream: ?isize = null,
stream: ?*const Stream = null,
};
pub fn createViewOfDeviceBuffer(self: *const Client, api: *const Api, args: CreateViewOfDeviceBufferArgs) ApiError!*Buffer {
@ -429,7 +436,7 @@ pub const Client = opaque {
.device = @ptrCast(@constCast(args.device)),
.on_delete_callback = args.on_delete_callback,
.on_delete_callback_arg = args.on_delete_callback_arg,
.stream = if (args.stream) |stream| stream else 0,
.stream = @bitCast(@intFromPtr(args.stream)),
});
return @ptrCast(ret.buffer.?);
}
@ -444,20 +451,19 @@ pub const Client = opaque {
return &.{};
}
pub fn dmaMap(self: *const Client, api: *const Api, data: []const u8) ApiError!*Buffer {
const ret = try api.call(.PJRT_Client_DMA_Map, .{
pub fn dmaMap(self: *const Client, api: *const Api, data: []const u8) ApiError!void {
try api.call(.PJRT_Client_DmaMap, .{
.client = self.inner(),
.data = @ptrCast(@constCast(data.ptr)),
.size = @intCast(data.len),
});
return @ptrCast(ret.buffer.?);
}
pub fn dmaUnmap(self: *const Client, api: *const Api, data: []const u8) void {
_ = api.call(.PJRT_Client_DMA_Unmap, .{
pub fn dmaUnmap(self: *const Client, api: *const Api, data: []const u8) ApiError!void {
try api.call(.PJRT_Client_DmaUnmap, .{
.client = self.inner(),
.data = @ptrCast(@constCast(data.ptr)),
}) catch unreachable;
});
}
pub const CreateBuffersForAsyncHostToDeviceArgs = struct {
@ -564,6 +570,14 @@ pub const SerializeResult = struct {
}
};
/// Opaque handle to a PJRT execute context.
pub const ExecuteContext = opaque {
    /// Destroys the context by calling `PJRT_ExecuteContext_Destroy`.
    ///
    /// Destruction is best-effort: a failure cannot be propagated from a
    /// `void` deinit, so it is logged instead of being silently discarded.
    pub fn deinit(self: *ExecuteContext, api: *const Api) void {
        if (api.call(.PJRT_ExecuteContext_Destroy, .{
            .context = @ptrCast(self),
        })) |_| {
            // Destroyed successfully; the call result carries nothing we need.
        } else |err| {
            log.err("[ExecuteContext.deinit] PJRT_ExecuteContext_Destroy failed: {s}", .{@errorName(err)});
        }
    }
};
pub const Executable = opaque {
const inner = InnerMixin(c.PJRT_Executable).inner;
@ -630,6 +644,7 @@ pub const LoadedExecutable = opaque {
results: []const [*]*Buffer,
events: []?*Event,
non_donatable_input_indices: []const i64 = &.{},
context: ?*ExecuteContext,
};
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ApiError!void {
var options = pjrtStruct(c.PJRT_ExecuteOptions{
@ -640,6 +655,7 @@ pub const LoadedExecutable = opaque {
.launch_id = 0,
.non_donatable_input_indices = @ptrCast(args.non_donatable_input_indices.ptr),
.num_non_donatable_input_indices = args.non_donatable_input_indices.len,
.context = @ptrCast(args.context),
});
_ = try api.call(.PJRT_LoadedExecutable_Execute, .{
.executable = self.inner(),
@ -653,7 +669,7 @@ pub const LoadedExecutable = opaque {
});
}
pub fn getExecutable(self: *LoadedExecutable, api: *const Api) ApiError!*Executable {
pub fn getExecutable(self: *const LoadedExecutable, api: *const Api) ApiError!*Executable {
const ret = try api.call(.PJRT_LoadedExecutable_GetExecutable, .{
.loaded_executable = self.inner(),
});
@ -818,7 +834,7 @@ pub const Buffer = opaque {
return ret.on_device_size_in_bytes;
}
pub fn copyToDevice(self: *const Buffer, api: *const Api, device: Device) ApiError!Buffer {
pub fn copyToDevice(self: *const Buffer, api: *const Api, device: Device) ApiError!*Buffer {
const ret = try api.call(.PJRT_Buffer_CopyToDevice, .{
.buffer = self.inner(),
.dst_device = device.inner,
@ -850,7 +866,7 @@ pub const Buffer = opaque {
return @ptrCast(ret.event);
}
pub fn copyToMemory(self: *const Buffer, api: *const Api, dst_memory: *const Memory) ApiError!?*Buffer {
pub fn copyToMemory(self: *const Buffer, api: *const Api, dst_memory: *const Memory) ApiError!*Buffer {
const ret = try api.call(.PJRT_Buffer_CopyToMemory, .{
.buffer = self.inner(),
.dst_memory = @ptrCast(@constCast(dst_memory)),
@ -932,8 +948,8 @@ pub const Memory = opaque {
pub fn kind(self: *const Memory, api: *const Api) Kind {
const ret = api.call(.PJRT_Memory_Kind, .{ .memory = self.inner() }) catch unreachable;
const kind_ = ret.kind orelse unreachable[0..ret.kind_size];
return std.meta.stringToEnum(Kind, kind_) orelse unreachable;
const kind_ = ret.kind orelse unreachable;
return std.meta.stringToEnum(Kind, kind_[0..ret.kind_size]) orelse unreachable;
}
pub fn kindId(self: *const Memory, api: *const Api) u32 {
@ -1044,21 +1060,6 @@ pub const AsyncHostToDeviceTransferManager = opaque {
}
};
pub const ExecutionContext = opaque {
const inner = InnerMixin(c.PJRT_ExecutionContext).inner;
pub fn init(api: *const Api) ApiError!*ExecutionContext {
const ret = try api.call(.PJRT_ExecutionContext_Create, .{});
return @ptrCast(ret.context.?);
}
pub fn deinit(self: *ExecutionContext, api: *const Api) void {
_ = api.call(.PJRT_ExecutionContext_Destroy, .{
.context = self.inner(),
}) catch unreachable;
}
};
pub const NamedValue = extern struct {
comptime {
std.debug.assert(@sizeOf(NamedValue) == @sizeOf(c.PJRT_NamedValue));
@ -1164,17 +1165,34 @@ pub const NamedValue = extern struct {
}
};
// todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize
// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5
pub const CustomCallRegistry = extern struct {
inner: *const c.PJRT_FFI_Register_Handler,
pub const FFI = extern struct {
inner: *const c.PJRT_FFI,
pub fn registerFfi(
self: *const CustomCallRegistry,
/// A piece of user data tagged with a type id, convertible to the C
/// `PJRT_FFI_UserData` representation.
pub const UserData = extern struct {
    type_id: i64,
    user_data: *anyopaque,

    /// Builds the C struct, mapping `type_id` to `type_id` and `user_data` to `data`.
    fn toCStruct(self: UserData) c.PJRT_FFI_UserData {
        const as_c: c.PJRT_FFI_UserData = .{
            .type_id = self.type_id,
            .data = self.user_data,
        };
        return as_c;
    }
};
pub const RegisterFfiOptions = struct {
traits: RegisterHandlerTraits = @enumFromInt(0),
};
// todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize
// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5
pub fn register(
self: *const FFI,
api: *const Api,
target_name: []const u8,
platform_name: []const u8,
func: *const ffi.Handler,
options: RegisterFfiOptions,
) ApiError!void {
var ret = pjrtStruct(c.PJRT_FFI_Register_Handler_Args{
.api_version = 1,
@ -1183,12 +1201,51 @@ pub const CustomCallRegistry = extern struct {
.handler = @ptrCast(@constCast(func)),
.platform_name = platform_name.ptr,
.platform_name_size = platform_name.len,
.traits = @intFromEnum(options.traits),
});
const result = self.inner(&ret);
const result = self.inner.register_handler.?(&ret);
if (result) |pjrt_c_error| {
const pjrt_error: *Error = @ptrCast(pjrt_c_error);
log.err("[GpuRegisterCustomCall] {s}", .{pjrt_error.getMessage(api)});
log.err("registerFfi error: {s}", .{pjrt_error.getMessage(api)});
return pjrt_error.getCode(api).toApiError();
}
}
/// Registers `T` under its `@typeName` through the extension's
/// `type_id_register` entry point and stores the resulting id in `T.type_id`.
///
/// `T` must expose a mutable `type_id` declaration that the returned id can
/// be assigned to.
///
/// On failure the plugin's error message is logged and its error code is
/// converted to an `ApiError`; `T.type_id` is left untouched.
pub fn registerTypeId(self: *const FFI, api: *const Api, T: type) ApiError!void {
    const type_name = @typeName(T);
    var ret = pjrtStruct(c.PJRT_FFI_TypeID_Register_Args{
        .type_name = type_name.ptr,
        .type_name_size = type_name.len,
        .type_id = 0, // let the plugin assign a unique type ID
    });
    const result = self.inner.type_id_register.?(&ret);
    if (result) |pjrt_c_error| {
        const pjrt_error: *Error = @ptrCast(pjrt_c_error);
        // Log the plugin message, as the other FFI extension wrappers do.
        log.err("registerTypeId error: {s}", .{pjrt_error.getMessage(api)});
        return pjrt_error.getCode(api).toApiError();
    }

    T.type_id = ret.type_id;
}
/// Attaches `user_data` to `context` through the extension's `user_data_add`
/// entry point.
///
/// On failure the plugin's error message is logged and its error code is
/// converted to an `ApiError`.
pub fn addUserData(self: *const FFI, api: *const Api, context: *ExecuteContext, user_data: UserData) ApiError!void {
    var args = pjrtStruct(c.PJRT_FFI_UserData_Add_Args{
        .context = @ptrCast(context),
        .user_data = user_data.toCStruct(),
    });

    // A null result means the call succeeded.
    const pjrt_c_error = self.inner.user_data_add.?(&args) orelse return;

    const pjrt_error: *Error = @ptrCast(pjrt_c_error);
    log.err("addUserData error: {s}", .{pjrt_error.getMessage(api)});
    return pjrt_error.getCode(api).toApiError();
}
};
/// Trait bits passed as `traits` when registering an FFI handler.
/// Backed by `c.PJRT_FFI_Handler_TraitsBits`.
pub const RegisterHandlerTraits = enum(c.PJRT_FFI_Handler_TraitsBits) {
command_buffer_compatible = c.PJRT_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE,
// Non-exhaustive: any other integer value is representable (e.g. `@enumFromInt(0)` for "no traits").
_,
};
pub const CustomCallRegistry = extern struct {
inner: *const c.PJRT_FFI_Register_Handler,
};

View File

@ -224,7 +224,7 @@ pub const Buffer = struct {
/// Creates a Buffer from a pointer into device memory.
/// This allows to interface with other libraries producing buffers.
pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?isize, device_data: *anyopaque) Buffer {
pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const pjrt.Stream, device_data: *anyopaque) Buffer {
const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
var res: [Shape.MAX_RANK]i64 = undefined;
for (0..Shape.MAX_RANK) |i| {

View File

@ -37,6 +37,7 @@ pub const Context = struct {
inline for (comptime std.enums.values(runtimes.Platform)) |t| {
if (runtimes.load(t)) |api| {
Context.apis.set(t, api);
if (t == .cuda) cuda.init();
} else |_| {}
}
}
@ -218,10 +219,10 @@ pub const Context = struct {
const CustomCall = struct {
pub fn registerZmlCustomCalls(platform: Platform) !void {
const registry = platform.pjrt_api.customCallRegistry();
const maybe_ffi = platform.pjrt_api.ffi();
if (registry) |reg| {
try reg.registerFfi(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback);
if (maybe_ffi) |ffi| {
try ffi.register(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback, .{});
} else {
stdx.debug.panic("Registering custom calls failed", .{});
}
@ -240,12 +241,12 @@ const CustomCall = struct {
const input_buffers = stdx.stackSlice(8, HostBuffer, call_frame.args.len);
for (input_buffers, 0..) |*b, i| {
b.* = hostBufferFromPinnedBuffer(call_frame.args.get(i));
b.* = hostBufferFromPinnedBuffer(call_frame.args.buffers()[i]);
}
const output_buffers = stdx.stackSlice(8, HostBuffer, call_frame.results.len);
for (output_buffers, 0..) |*b, i| {
b.* = hostBufferFromPinnedBuffer(call_frame.results.get(i));
b.* = hostBufferFromPinnedBuffer(call_frame.results.buffers()[i]);
}
callback(user_ctx, input_buffers, output_buffers);
@ -258,10 +259,10 @@ fn getShape(buffer_desc: *const pjrt.ffi.Buffer) Shape {
const dt: DataType = switch (buffer_desc.dtype) {
.invalid => @panic("invalid ffi"),
.pred => .bool,
.s8 => .i8,
.s16 => .i16,
.s32 => .i32,
.s64 => .i64,
.i8 => .i8,
.i16 => .i16,
.i32 => .i32,
.i64 => .i64,
.token, .f8e4m3, .f8e3m4 => @panic("Unsupported ffi type"),
inline else => |t| @field(DataType, @tagName(t)),
};
@ -278,3 +279,69 @@ fn hostBufferFromPinnedBuffer(buffer_desc: *const pjrt.ffi.Buffer) HostBuffer {
buffer_desc.data[0..buffer_shape.byteSize()],
);
}
/// Minimal runtime binding to a few CUDA runtime (`libcudart`) entry points.
///
/// Every function pointer starts out as a recognizable poison address and is
/// replaced by `init` once `libcudart.so.12` has been opened. Calling any of
/// the wrappers before a successful `init` jumps to that poison address.
pub const cuda = struct {
    pub var streamSynchronize: StreamSynchronize = @ptrFromInt(0xdeadc00da00);
    pub var cuLaunchHostFunc: CuLaunchHostFunc = @ptrFromInt(0xdeadc00da00);
    var _memcpyAsync: MemcpyAsync = @ptrFromInt(0xdeadc00da00);
    var _memcpyBlocking: MemcpyBlocking = @ptrFromInt(0xdeadc00da00);

    /// Handle to the opened runtime library. It is intentionally never closed:
    /// the function pointers above point into the library's mapped code, so
    /// closing the handle could unload it and leave them dangling.
    var _cudart: ?std.DynLib = null;

    /// Direction of a memcpy.
    /// NOTE(review): values presumably mirror `cudaMemcpyKind` — confirm against the CUDA runtime headers.
    pub const MemcpyKind = enum(c_int) {
        host_to_host = 0,
        host_to_device = 1,
        device_to_host = 2,
        device_to_device = 3,
        inferred = 4,
    };

    const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: ?*anyopaque) callconv(.C) c_int;
    const MemcpyBlocking = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind) callconv(.C) c_int;
    const StreamSynchronize = *const fn (stream: *anyopaque) callconv(.C) c_int;
    const CuLaunchHostFunc = *const fn (stream: *anyopaque, host_func: *const fn (user_data: *const anyopaque) callconv(.c) void, user_data: *const anyopaque) callconv(.c) c_int;

    /// Opens `libcudart.so.12` and resolves the function pointers.
    ///
    /// If the library cannot be opened, an error is logged and the pointers
    /// keep their poison value. If the library opens but a symbol is missing,
    /// this panics. Calling `init` again after a successful load is a no-op.
    pub fn init() void {
        if (_cudart != null) return;

        _cudart = std.DynLib.open("libcudart.so.12") catch {
            log.err("cudart not found, callback will segfault", .{});
            return;
        };
        // The handle stays open for the rest of the process (see `_cudart`).
        const cudart = &_cudart.?;

        _memcpyAsync = cudart.lookup(MemcpyAsync, "cudaMemcpyAsync") orelse {
            @panic("cudaMemcpyAsync not found");
        };
        _memcpyBlocking = cudart.lookup(MemcpyBlocking, "cudaMemcpy") orelse {
            @panic("cudaMemcpy not found");
        };
        streamSynchronize = cudart.lookup(StreamSynchronize, "cudaStreamSynchronize") orelse {
            @panic("cudaStreamSynchronize not found");
        };
        cuLaunchHostFunc = cudart.lookup(CuLaunchHostFunc, "cudaLaunchHostFunc") orelse {
            @panic("cudaLaunchHostFunc not found");
        };
    }

    /// Copies `dst.len` bytes from `src` into `dst` with a blocking
    /// device-to-host memcpy. Panics on a non-zero status.
    pub fn memcpyToHostBlocking(dst: []u8, src: *const anyopaque) void {
        const err = _memcpyBlocking(dst.ptr, src, dst.len, .device_to_host);
        check(err);
    }

    /// Copies `src.len` bytes from `src` into `dst` with a blocking
    /// host-to-device memcpy. Panics on a non-zero status.
    pub fn memcpyToDeviceBlocking(dst: *anyopaque, src: []const u8) void {
        const err = _memcpyBlocking(dst, src.ptr, src.len, .host_to_device);
        check(err);
    }

    /// Issues a host-to-device memcpy of `src.len` bytes on `stream` through
    /// the async entry point. Panics on a non-zero status.
    pub fn memcpyToDeviceAsync(dst: *anyopaque, src: []const u8, stream: ?*anyopaque) void {
        const err = _memcpyAsync(dst, src.ptr, src.len, .host_to_device, stream);
        check(err);
    }

    /// Issues a device-to-host memcpy of `dst.len` bytes on `stream` through
    /// the async entry point. Panics on a non-zero status.
    pub fn memcpyToHostAsync(dst: []u8, src: *const anyopaque, stream: ?*anyopaque) void {
        const err = _memcpyAsync(dst.ptr, src, dst.len, .device_to_host, stream);
        check(err);
    }

    /// Panics with the numeric status unless `err` is zero.
    pub fn check(err: c_int) void {
        if (err == 0) return;
        stdx.debug.panic("CUDA error: {d}", .{err});
    }
};

View File

@ -12,7 +12,7 @@ const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape;
const ShapeOf = @import("tensor.zig").ShapeOf;
const log = std.log.scoped(.zml);
const log = std.log.scoped(.@"zml/exe");
test {
std.testing.refAllDecls(@This());
@ -135,6 +135,9 @@ pub const BaseExe = struct {
/// The PJRT executable representing the compiled module.
exe: *pjrt.LoadedExecutable,
/// The execution context for this executable.
context: ?*pjrt.ExecuteContext = null,
/// Pre-allocated slice of buffers to use as inputs when the module is called.
input_per_device: []const [*]*pjrt.Buffer,
@ -199,6 +202,9 @@ pub const BaseExe = struct {
}
/// Releases the resources held by this executable wrapper: first the execute
/// context, when one is set, then the backing arena.
pub fn deinit(self: BaseExe) void {
    const maybe_context = self.context;
    if (maybe_context) |execute_context| execute_context.deinit(self.platform.pjrt_api);
    self._arena.deinit();
}
@ -220,6 +226,7 @@ pub const BaseExe = struct {
// even if it has been marked as "can be donated" during compilation.
// TODO: expose it ?
.non_donatable_input_indices = &.{},
.context = self.context,
}) catch |err| {
std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err});
};
@ -288,11 +295,13 @@ pub const BaseExe = struct {
}
pub fn clone(self: BaseExe, parent_allocator: std.mem.Allocator) !BaseExe {
return .init(parent_allocator, self.platform, self.exe, .{
.input_shapes = self.input_shapes,
var exe: BaseExe = try .init(parent_allocator, self.platform, self.exe, .{
.n_in = self.input_buffer_count,
.result_shapes = self.result_shapes,
.n_devices = self.num_devices,
});
exe.context = self.context;
return exe;
}
};

View File

@ -8,6 +8,7 @@ pub const ffi = pjrt.ffi;
pub const Profiler = pjrt.Profiler;
pub const ApiError = pjrt.ApiError;
pub const ErrorCode = pjrt.ErrorCode;
pub const ExecuteContext = pjrt.ExecuteContext;
pub const BufferType = pjrt.BufferType;
pub const Device = pjrt.Device;
pub const DeviceDescription = pjrt.DeviceDescription;
@ -20,6 +21,7 @@ pub const SerializeResult = pjrt.SerializeResult;
pub const Executable = pjrt.Executable;
pub const ExecuteError = ApiError;
pub const Memory = pjrt.Memory;
pub const Stream = pjrt.Stream;
const log = std.log.scoped(.zml);
@ -120,7 +122,7 @@ pub const Client = opaque {
return self.inner().addressableMemories(api);
}
pub fn memoryByKind(self: *const Client, api: *const Api, kind: Memory.Kind) ?*Memory {
pub fn memoryByKind(self: *const Client, api: *const Api, kind: Memory.Kind) ?*const Memory {
for (self.addressableMemories(api)) |mem| {
if (mem.kind(api) == kind) {
return mem;
@ -182,7 +184,7 @@ pub const Buffer = opaque {
}
pub fn copyToMemory(self: *const Buffer, api: *const Api, memory_: *const Memory) ApiError!*Buffer {
return @ptrCast(self.inner().copyToMemory(api, memory_));
return @ptrCast(try self.inner().copyToMemory(api, memory_));
}
pub fn getReadyEvent(self: *const Buffer, api: *const Api) ?*Event {
@ -262,6 +264,7 @@ pub const LoadedExecutable = opaque {
results: []const [*]*Buffer,
events: []?*Event,
non_donatable_input_indices: []const i64 = &.{},
context: ?*ExecuteContext,
};
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ExecuteError!void {
@ -271,6 +274,7 @@ pub const LoadedExecutable = opaque {
.results = @ptrCast(args.results),
.events = @ptrCast(args.events),
.non_donatable_input_indices = args.non_donatable_input_indices,
.context = args.context,
} });
}

View File

@ -75,6 +75,17 @@ pub const Platform = struct {
return res;
}
/// Registers `T` as an FFI user-data type with the platform's PJRT plugin.
///
/// `T` must expose a `type_id` declaration; this is verified at compile time.
/// The id assigned during registration is written to `T.type_id` by
/// `registerTypeId`.
///
/// Panics if the plugin for this platform does not provide the FFI extension;
/// propagates the error returned by `registerTypeId` otherwise.
pub fn registerFFIType(self: Platform, comptime T: type) !void {
    // `T` is comptime-known, so a missing declaration is a programming error
    // that can be reported while compiling instead of at runtime.
    if (!@hasDecl(T, "type_id")) {
        @compileError("registerFFIType requires type " ++ @typeName(T) ++ " to have a `type_id` i64 declaration");
    }

    if (self.pjrt_api.ffi()) |ffi| {
        try ffi.registerTypeId(self.pjrt_api, T);
    } else {
        stdx.debug.panic("registerFFIType is not available for target {s}", .{@tagName(self.target)});
    }
}
pub fn deinit(self: *Platform) void {
self.pjrt_client.deinit(self.pjrt_api);
}