Add preliminary implementation for custom call support.
This commit is contained in:
parent
1d5b79111a
commit
6aa9aa5a7b
148
pjrt/ffi.zig
148
pjrt/ffi.zig
@ -91,27 +91,25 @@ fn TransmuteMixin(comptime T: type, comptime InnerT: type) type {
|
|||||||
pub const Api = opaque {
|
pub const Api = opaque {
|
||||||
pub const inner = TransmuteMixin(Api, c.XLA_FFI_Api).to;
|
pub const inner = TransmuteMixin(Api, c.XLA_FFI_Api).to;
|
||||||
|
|
||||||
pub fn getStream(self: *const Api, context: ?*ExecutionContext) ApiError!*anyopaque {
|
pub fn stream(self: *const Api, context: *const ExecutionContext) *Stream {
|
||||||
var ret = pjrtStruct(c.XLA_FFI_Stream_Get_Args{
|
var ret = pjrtStruct(c.XLA_FFI_Stream_Get_Args{
|
||||||
.ctx = if (context) |ctx| ctx.inner() else null,
|
.ctx = @constCast(context.inner()),
|
||||||
});
|
});
|
||||||
const result = self.inner().XLA_FFI_Stream_Get.?(&ret);
|
const result = self.inner().XLA_FFI_Stream_Get.?(&ret);
|
||||||
|
|
||||||
if (result) |ffi_error| {
|
if (result) |ffi_error| {
|
||||||
const err = Error.fromInner(ffi_error);
|
const err = Error.fromInner(ffi_error);
|
||||||
defer err.destroy(self);
|
defer err.destroy(self);
|
||||||
log.err("[Api.getStream] {s}", .{err.getMessage(self)});
|
log.err("[Api.getStream] {s}", .{err.getMessage(self)});
|
||||||
|
|
||||||
// TODO(Corentin): Retrieve error code from Error when implemented in XLA.
|
@panic("failed to get stream");
|
||||||
return error.Unknown;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ret.stream.?;
|
return @ptrCast(ret.stream.?);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn allocateDeviceMemory(self: *const Api, context: ?*ExecutionContext, size: usize, alignment: usize) ApiError!*anyopaque {
|
pub fn allocateDeviceMemory(self: *const Api, context: *const ExecutionContext, size: usize, alignment: usize) ApiError!*anyopaque {
|
||||||
var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Allocate_Args{
|
var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Allocate_Args{
|
||||||
.ctx = if (context) |ctx| ctx.inner() else null,
|
.ctx = @constCast(context.inner()),
|
||||||
.size = size,
|
.size = size,
|
||||||
.alignment = alignment,
|
.alignment = alignment,
|
||||||
});
|
});
|
||||||
@ -129,9 +127,9 @@ pub const Api = opaque {
|
|||||||
return ret.data.?;
|
return ret.data.?;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn freeDeviceMemory(self: *const Api, context: ?*ExecutionContext, data: *anyopaque, size: usize) ApiError!void {
|
pub fn freeDeviceMemory(self: *const Api, context: *const ExecutionContext, data: *anyopaque, size: usize) ApiError!void {
|
||||||
var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Free_Args{
|
var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Free_Args{
|
||||||
.ctx = if (context) |ctx| ctx.inner() else null,
|
.ctx = @constCast(context.inner()),
|
||||||
.size = size,
|
.size = size,
|
||||||
.data = data,
|
.data = data,
|
||||||
});
|
});
|
||||||
@ -165,36 +163,13 @@ pub const ExecutionStage = enum(c.XLA_FFI_ExecutionStage) {
|
|||||||
pub const ExecutionContext = opaque {
|
pub const ExecutionContext = opaque {
|
||||||
pub const inner = TransmuteMixin(ExecutionContext, c.XLA_FFI_ExecutionContext).to;
|
pub const inner = TransmuteMixin(ExecutionContext, c.XLA_FFI_ExecutionContext).to;
|
||||||
|
|
||||||
// pub fn attach(self: *ExecutionContext, api: *const Api, value: anytype) ApiError!void {
|
pub fn Context(comptime T: type) type {
|
||||||
// // register type id ==> typeid
|
return struct {
|
||||||
// const typename_ = "zml." ++ @typeName(@TypeOf(value));
|
pub fn get(self: *const ExecutionContext, api: *const Api) ApiError!*T {
|
||||||
|
const type_id: TypeId = .{ .type_id = T.type_id };
|
||||||
// var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Register_Args{
|
|
||||||
// .ctx = self.inner(),
|
|
||||||
// .handler = @ptrCast(@alignCast(handler)),
|
|
||||||
// });
|
|
||||||
// const result = api.inner().XLA_FFI_ExecutionContext_Register.?(&ret);
|
|
||||||
|
|
||||||
// var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Register_Args{
|
|
||||||
// .ctx = self.inner(),
|
|
||||||
// .handler = @ptrCast(@alignCast(handler)),
|
|
||||||
// });
|
|
||||||
// const result = api.inner().XLA_FFI_ExecutionContext_Register.?(&ret);
|
|
||||||
|
|
||||||
// if (result) |ffi_error| {
|
|
||||||
// const err = Error.fromInner(ffi_error);
|
|
||||||
// defer err.destroy(api);
|
|
||||||
// log.err("[ExecutionContext.register] {s}", .{err.getMessage(api)});
|
|
||||||
|
|
||||||
// // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
|
|
||||||
// return error.Unknown;
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
pub fn get(self: *ExecutionContext, api: *const Api, type_id: *TypeId) ApiError!*anyopaque {
|
|
||||||
var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Get_Args{
|
var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Get_Args{
|
||||||
.ctx = self.inner(),
|
.ctx = @constCast(self.inner()),
|
||||||
.type_id = @ptrCast(@alignCast(type_id)),
|
.type_id = @constCast(&type_id.toCStruct()),
|
||||||
});
|
});
|
||||||
const result = api.inner().XLA_FFI_ExecutionContext_Get.?(&ret);
|
const result = api.inner().XLA_FFI_ExecutionContext_Get.?(&ret);
|
||||||
|
|
||||||
@ -207,12 +182,65 @@ pub const ExecutionContext = opaque {
|
|||||||
return error.Unknown;
|
return error.Unknown;
|
||||||
}
|
}
|
||||||
|
|
||||||
return ret.data.?;
|
if (ret.data == null) return error.NotFound;
|
||||||
|
return @ptrCast(@alignCast(ret.data.?));
|
||||||
|
}
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO getDeviceOrdinal()
|
pub fn getDeviceOrdinal(self: *const ExecutionContext, api: *const Api) ApiError!i32 {
|
||||||
|
var ret = pjrtStruct(c.XLA_FFI_DeviceOrdinal_Get_Args{
|
||||||
|
.ctx = @constCast(self.inner()),
|
||||||
|
});
|
||||||
|
const result = api.inner().XLA_FFI_DeviceOrdinal_Get.?(&ret);
|
||||||
|
|
||||||
|
if (result) |ffi_error| {
|
||||||
|
const err = Error.fromInner(ffi_error);
|
||||||
|
defer err.destroy(api);
|
||||||
|
log.err("[ExecutionContext.getDeviceOrdinal] {s}", .{err.getMessage(api)});
|
||||||
|
|
||||||
|
// TODO(Corentin): Retrieve error code from Error when implemented in XLA.
|
||||||
|
return error.Unknown;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret.device_ordinal;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn scheduleTask(self: *const ExecutionContext, api: *const Api, task: *const Task, data: *anyopaque) ApiError!void {
|
||||||
|
var ret = pjrtStruct(c.XLA_FFI_ThreadPool_Schedule_Args{
|
||||||
|
.ctx = @constCast(self.inner()),
|
||||||
|
.task = @ptrCast(@alignCast(task)),
|
||||||
|
.data = @ptrCast(@alignCast(data)),
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = api.inner().XLA_FFI_ThreadPool_Schedule.?(&ret);
|
||||||
|
|
||||||
|
if (result) |ffi_error| {
|
||||||
|
const err = Error.fromInner(ffi_error);
|
||||||
|
defer err.destroy(api);
|
||||||
|
std.debug.print("error: {any} \n", .{err});
|
||||||
|
log.err("[ExecutionContext.get] {s}", .{err.getMessage(api)});
|
||||||
|
|
||||||
|
// TODO(Corentin): Retrieve error code from Error when implemented in XLA.
|
||||||
|
return error.Unknown;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn getTypeId(type_name: []const u8) TypeId {
|
||||||
|
const id: i64 = @bitCast(std.hash.Fnv1a_64.hash(type_name));
|
||||||
|
|
||||||
|
return .{
|
||||||
|
.type_id = id,
|
||||||
|
};
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const TypeId = c.XLA_FFI_TypeId;
|
||||||
|
|
||||||
|
const Task = fn (*anyopaque) void;
|
||||||
|
|
||||||
|
const Stream = @import("pjrt.zig").Stream;
|
||||||
|
|
||||||
const ByteSpan = extern struct {
|
const ByteSpan = extern struct {
|
||||||
ptr: [*]const u8,
|
ptr: [*]const u8,
|
||||||
len: usize,
|
len: usize,
|
||||||
@ -222,17 +250,13 @@ const ByteSpan = extern struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const TypeId = extern struct {
|
|
||||||
type_id: i64,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const DataType = enum(c.XLA_FFI_DataType) {
|
pub const DataType = enum(c.XLA_FFI_DataType) {
|
||||||
invalid = c.XLA_FFI_DataType_INVALID,
|
invalid = c.XLA_FFI_DataType_INVALID,
|
||||||
pred = c.XLA_FFI_DataType_PRED,
|
pred = c.XLA_FFI_DataType_PRED,
|
||||||
s8 = c.XLA_FFI_DataType_S8,
|
i8 = c.XLA_FFI_DataType_S8,
|
||||||
s16 = c.XLA_FFI_DataType_S16,
|
i16 = c.XLA_FFI_DataType_S16,
|
||||||
s32 = c.XLA_FFI_DataType_S32,
|
i32 = c.XLA_FFI_DataType_S32,
|
||||||
s64 = c.XLA_FFI_DataType_S64,
|
i64 = c.XLA_FFI_DataType_S64,
|
||||||
u8 = c.XLA_FFI_DataType_U8,
|
u8 = c.XLA_FFI_DataType_U8,
|
||||||
u16 = c.XLA_FFI_DataType_U16,
|
u16 = c.XLA_FFI_DataType_U16,
|
||||||
u32 = c.XLA_FFI_DataType_U32,
|
u32 = c.XLA_FFI_DataType_U32,
|
||||||
@ -289,9 +313,8 @@ pub const Args = extern struct {
|
|||||||
buffer = c.XLA_FFI_ArgType_BUFFER,
|
buffer = c.XLA_FFI_ArgType_BUFFER,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn get(self: Args, i: usize) *const Buffer {
|
pub fn buffers(self: Args) []*const Buffer {
|
||||||
std.debug.assert(self.types[0..self.len][i] == .buffer);
|
return self.ptr[0..self.len];
|
||||||
return self.ptr[0..self.len][i];
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -306,9 +329,8 @@ pub const Rets = extern struct {
|
|||||||
buffer = c.XLA_FFI_RetType_BUFFER,
|
buffer = c.XLA_FFI_RetType_BUFFER,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn get(self: Rets, i: usize) *const Buffer {
|
pub fn buffers(self: Rets) []*const Buffer {
|
||||||
std.debug.assert(self.types[0..self.len][i] == .buffer);
|
return self.ptr[0..self.len];
|
||||||
return self.ptr[0..self.len][i];
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -346,8 +368,18 @@ pub const Attrs = extern struct {
|
|||||||
dtype: DataType,
|
dtype: DataType,
|
||||||
len: usize,
|
len: usize,
|
||||||
data: [*]const u8,
|
data: [*]const u8,
|
||||||
|
|
||||||
|
pub fn slice(self: Array, T: type) []const T {
|
||||||
|
const ptr: [*]const T = @alignCast(@ptrCast(self.data));
|
||||||
|
return ptr[0..self.len];
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub fn slice(self: Array, T: type) []const T {
|
||||||
|
const ptr: [*]const T = @alignCast(@ptrCast(self.data));
|
||||||
|
return ptr[0..self.len];
|
||||||
|
}
|
||||||
|
|
||||||
pub fn getByIndex(self: Attrs, comptime attr_type: AttrType, index: usize) ?*const @FieldType(Attr, @tagName(attr_type)) {
|
pub fn getByIndex(self: Attrs, comptime attr_type: AttrType, index: usize) ?*const @FieldType(Attr, @tagName(attr_type)) {
|
||||||
const attr = self.ptr[0..self.len][index];
|
const attr = self.ptr[0..self.len][index];
|
||||||
const actual_type = self.types[index];
|
const actual_type = self.types[index];
|
||||||
@ -370,8 +402,8 @@ pub const Attrs = extern struct {
|
|||||||
pub const CallFrame = extern struct {
|
pub const CallFrame = extern struct {
|
||||||
struct_size: usize,
|
struct_size: usize,
|
||||||
extension_start: ?*ExtensionBase,
|
extension_start: ?*ExtensionBase,
|
||||||
api: ?*const Api,
|
api: *const Api,
|
||||||
ctx: ?*const ExecutionContext,
|
ctx: *const ExecutionContext,
|
||||||
stage: ExecutionStage,
|
stage: ExecutionStage,
|
||||||
args: Args,
|
args: Args,
|
||||||
results: Rets,
|
results: Rets,
|
||||||
@ -438,7 +470,7 @@ pub const Error = opaque {
|
|||||||
pub const inner = TransmuteMixin(Error, c.XLA_FFI_Error).to;
|
pub const inner = TransmuteMixin(Error, c.XLA_FFI_Error).to;
|
||||||
pub const fromInner = TransmuteMixin(Error, c.XLA_FFI_Error).from;
|
pub const fromInner = TransmuteMixin(Error, c.XLA_FFI_Error).from;
|
||||||
|
|
||||||
pub fn create(api: *const Api, error_code: ErrorCode, message: [:0]const u8) *Error {
|
pub fn create(api: *const Api, error_code: ErrorCode, message: []const u8) *Error {
|
||||||
var ret = pjrtStruct(c.XLA_FFI_Error_Create_Args{
|
var ret = pjrtStruct(c.XLA_FFI_Error_Create_Args{
|
||||||
.message = message.ptr,
|
.message = message.ptr,
|
||||||
.errc = @intFromEnum(error_code),
|
.errc = @intFromEnum(error_code),
|
||||||
|
|||||||
135
pjrt/pjrt.zig
135
pjrt/pjrt.zig
@ -31,7 +31,7 @@ fn pjrtStructSize(comptime T: type) usize {
|
|||||||
return @field(c, typedef_name ++ "_STRUCT_SIZE");
|
return @field(c, typedef_name ++ "_STRUCT_SIZE");
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fn pjrtStruct(v: anytype) @TypeOf(v) {
|
pub inline fn pjrtStruct(v: anytype) @TypeOf(v) {
|
||||||
var ret = v;
|
var ret = v;
|
||||||
ret.struct_size = pjrtStructSize(@TypeOf(v));
|
ret.struct_size = pjrtStructSize(@TypeOf(v));
|
||||||
return ret;
|
return ret;
|
||||||
@ -160,9 +160,14 @@ pub const Api = struct {
|
|||||||
return state.str;
|
return state.str;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn customCallRegistry(api: *const Api) ?CustomCallRegistry {
|
pub fn createExecuteContext(api: *const Api) ApiError!*ExecuteContext {
|
||||||
|
const ret = try api.call(.PJRT_ExecuteContext_Create, .{});
|
||||||
|
return @ptrCast(ret.context.?);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ffi(api: *const Api) ?FFI {
|
||||||
if (api.lookupExtension(c.PJRT_FFI_Extension, c.PJRT_Extension_Type_FFI)) |ext| {
|
if (api.lookupExtension(c.PJRT_FFI_Extension, c.PJRT_Extension_Type_FFI)) |ext| {
|
||||||
return .{ .inner = ext.register_handler.? };
|
return .{ .inner = ext };
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
@ -279,6 +284,8 @@ pub const ShapeSpec = extern struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub const Stream = opaque {};
|
||||||
|
|
||||||
pub const Client = opaque {
|
pub const Client = opaque {
|
||||||
const inner = InnerMixin(c.PJRT_Client).inner;
|
const inner = InnerMixin(c.PJRT_Client).inner;
|
||||||
|
|
||||||
@ -414,7 +421,7 @@ pub const Client = opaque {
|
|||||||
fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.C) void {}
|
fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.C) void {}
|
||||||
}.call,
|
}.call,
|
||||||
on_delete_callback_arg: ?*anyopaque = null,
|
on_delete_callback_arg: ?*anyopaque = null,
|
||||||
stream: ?isize = null,
|
stream: ?*const Stream = null,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn createViewOfDeviceBuffer(self: *const Client, api: *const Api, args: CreateViewOfDeviceBufferArgs) ApiError!*Buffer {
|
pub fn createViewOfDeviceBuffer(self: *const Client, api: *const Api, args: CreateViewOfDeviceBufferArgs) ApiError!*Buffer {
|
||||||
@ -429,7 +436,7 @@ pub const Client = opaque {
|
|||||||
.device = @ptrCast(@constCast(args.device)),
|
.device = @ptrCast(@constCast(args.device)),
|
||||||
.on_delete_callback = args.on_delete_callback,
|
.on_delete_callback = args.on_delete_callback,
|
||||||
.on_delete_callback_arg = args.on_delete_callback_arg,
|
.on_delete_callback_arg = args.on_delete_callback_arg,
|
||||||
.stream = if (args.stream) |stream| stream else 0,
|
.stream = @bitCast(@intFromPtr(args.stream)),
|
||||||
});
|
});
|
||||||
return @ptrCast(ret.buffer.?);
|
return @ptrCast(ret.buffer.?);
|
||||||
}
|
}
|
||||||
@ -444,20 +451,19 @@ pub const Client = opaque {
|
|||||||
return &.{};
|
return &.{};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dmaMap(self: *const Client, api: *const Api, data: []const u8) ApiError!*Buffer {
|
pub fn dmaMap(self: *const Client, api: *const Api, data: []const u8) ApiError!void {
|
||||||
const ret = try api.call(.PJRT_Client_DMA_Map, .{
|
try api.call(.PJRT_Client_DmaMap, .{
|
||||||
.client = self.inner(),
|
.client = self.inner(),
|
||||||
.data = @ptrCast(@constCast(data.ptr)),
|
.data = @ptrCast(@constCast(data.ptr)),
|
||||||
.size = @intCast(data.len),
|
.size = @intCast(data.len),
|
||||||
});
|
});
|
||||||
return @ptrCast(ret.buffer.?);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dmaUnmap(self: *const Client, api: *const Api, data: []const u8) void {
|
pub fn dmaUnmap(self: *const Client, api: *const Api, data: []const u8) ApiError!void {
|
||||||
_ = api.call(.PJRT_Client_DMA_Unmap, .{
|
try api.call(.PJRT_Client_DmaUnmap, .{
|
||||||
.client = self.inner(),
|
.client = self.inner(),
|
||||||
.data = @ptrCast(@constCast(data.ptr)),
|
.data = @ptrCast(@constCast(data.ptr)),
|
||||||
}) catch unreachable;
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const CreateBuffersForAsyncHostToDeviceArgs = struct {
|
pub const CreateBuffersForAsyncHostToDeviceArgs = struct {
|
||||||
@ -564,6 +570,14 @@ pub const SerializeResult = struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub const ExecuteContext = opaque {
|
||||||
|
pub fn deinit(self: *ExecuteContext, api: *const Api) void {
|
||||||
|
_ = api.call(.PJRT_ExecuteContext_Destroy, .{
|
||||||
|
.context = @ptrCast(self),
|
||||||
|
}) catch {};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
pub const Executable = opaque {
|
pub const Executable = opaque {
|
||||||
const inner = InnerMixin(c.PJRT_Executable).inner;
|
const inner = InnerMixin(c.PJRT_Executable).inner;
|
||||||
|
|
||||||
@ -630,6 +644,7 @@ pub const LoadedExecutable = opaque {
|
|||||||
results: []const [*]*Buffer,
|
results: []const [*]*Buffer,
|
||||||
events: []?*Event,
|
events: []?*Event,
|
||||||
non_donatable_input_indices: []const i64 = &.{},
|
non_donatable_input_indices: []const i64 = &.{},
|
||||||
|
context: ?*ExecuteContext,
|
||||||
};
|
};
|
||||||
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ApiError!void {
|
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ApiError!void {
|
||||||
var options = pjrtStruct(c.PJRT_ExecuteOptions{
|
var options = pjrtStruct(c.PJRT_ExecuteOptions{
|
||||||
@ -640,6 +655,7 @@ pub const LoadedExecutable = opaque {
|
|||||||
.launch_id = 0,
|
.launch_id = 0,
|
||||||
.non_donatable_input_indices = @ptrCast(args.non_donatable_input_indices.ptr),
|
.non_donatable_input_indices = @ptrCast(args.non_donatable_input_indices.ptr),
|
||||||
.num_non_donatable_input_indices = args.non_donatable_input_indices.len,
|
.num_non_donatable_input_indices = args.non_donatable_input_indices.len,
|
||||||
|
.context = @ptrCast(args.context),
|
||||||
});
|
});
|
||||||
_ = try api.call(.PJRT_LoadedExecutable_Execute, .{
|
_ = try api.call(.PJRT_LoadedExecutable_Execute, .{
|
||||||
.executable = self.inner(),
|
.executable = self.inner(),
|
||||||
@ -653,7 +669,7 @@ pub const LoadedExecutable = opaque {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getExecutable(self: *LoadedExecutable, api: *const Api) ApiError!*Executable {
|
pub fn getExecutable(self: *const LoadedExecutable, api: *const Api) ApiError!*Executable {
|
||||||
const ret = try api.call(.PJRT_LoadedExecutable_GetExecutable, .{
|
const ret = try api.call(.PJRT_LoadedExecutable_GetExecutable, .{
|
||||||
.loaded_executable = self.inner(),
|
.loaded_executable = self.inner(),
|
||||||
});
|
});
|
||||||
@ -818,7 +834,7 @@ pub const Buffer = opaque {
|
|||||||
return ret.on_device_size_in_bytes;
|
return ret.on_device_size_in_bytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn copyToDevice(self: *const Buffer, api: *const Api, device: Device) ApiError!Buffer {
|
pub fn copyToDevice(self: *const Buffer, api: *const Api, device: Device) ApiError!*Buffer {
|
||||||
const ret = try api.call(.PJRT_Buffer_CopyToDevice, .{
|
const ret = try api.call(.PJRT_Buffer_CopyToDevice, .{
|
||||||
.buffer = self.inner(),
|
.buffer = self.inner(),
|
||||||
.dst_device = device.inner,
|
.dst_device = device.inner,
|
||||||
@ -850,7 +866,7 @@ pub const Buffer = opaque {
|
|||||||
return @ptrCast(ret.event);
|
return @ptrCast(ret.event);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn copyToMemory(self: *const Buffer, api: *const Api, dst_memory: *const Memory) ApiError!?*Buffer {
|
pub fn copyToMemory(self: *const Buffer, api: *const Api, dst_memory: *const Memory) ApiError!*Buffer {
|
||||||
const ret = try api.call(.PJRT_Buffer_CopyToMemory, .{
|
const ret = try api.call(.PJRT_Buffer_CopyToMemory, .{
|
||||||
.buffer = self.inner(),
|
.buffer = self.inner(),
|
||||||
.dst_memory = @ptrCast(@constCast(dst_memory)),
|
.dst_memory = @ptrCast(@constCast(dst_memory)),
|
||||||
@ -932,8 +948,8 @@ pub const Memory = opaque {
|
|||||||
|
|
||||||
pub fn kind(self: *const Memory, api: *const Api) Kind {
|
pub fn kind(self: *const Memory, api: *const Api) Kind {
|
||||||
const ret = api.call(.PJRT_Memory_Kind, .{ .memory = self.inner() }) catch unreachable;
|
const ret = api.call(.PJRT_Memory_Kind, .{ .memory = self.inner() }) catch unreachable;
|
||||||
const kind_ = ret.kind orelse unreachable[0..ret.kind_size];
|
const kind_ = ret.kind orelse unreachable;
|
||||||
return std.meta.stringToEnum(Kind, kind_) orelse unreachable;
|
return std.meta.stringToEnum(Kind, kind_[0..ret.kind_size]) orelse unreachable;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn kindId(self: *const Memory, api: *const Api) u32 {
|
pub fn kindId(self: *const Memory, api: *const Api) u32 {
|
||||||
@ -1044,21 +1060,6 @@ pub const AsyncHostToDeviceTransferManager = opaque {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const ExecutionContext = opaque {
|
|
||||||
const inner = InnerMixin(c.PJRT_ExecutionContext).inner;
|
|
||||||
|
|
||||||
pub fn init(api: *const Api) ApiError!*ExecutionContext {
|
|
||||||
const ret = try api.call(.PJRT_ExecutionContext_Create, .{});
|
|
||||||
return @ptrCast(ret.context.?);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn deinit(self: *ExecutionContext, api: *const Api) void {
|
|
||||||
_ = api.call(.PJRT_ExecutionContext_Destroy, .{
|
|
||||||
.context = self.inner(),
|
|
||||||
}) catch unreachable;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const NamedValue = extern struct {
|
pub const NamedValue = extern struct {
|
||||||
comptime {
|
comptime {
|
||||||
std.debug.assert(@sizeOf(NamedValue) == @sizeOf(c.PJRT_NamedValue));
|
std.debug.assert(@sizeOf(NamedValue) == @sizeOf(c.PJRT_NamedValue));
|
||||||
@ -1164,17 +1165,34 @@ pub const NamedValue = extern struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize
|
pub const FFI = extern struct {
|
||||||
// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5
|
inner: *const c.PJRT_FFI,
|
||||||
pub const CustomCallRegistry = extern struct {
|
|
||||||
inner: *const c.PJRT_FFI_Register_Handler,
|
|
||||||
|
|
||||||
pub fn registerFfi(
|
pub const UserData = extern struct {
|
||||||
self: *const CustomCallRegistry,
|
type_id: i64,
|
||||||
|
user_data: *anyopaque,
|
||||||
|
|
||||||
|
fn toCStruct(self: UserData) c.PJRT_FFI_UserData {
|
||||||
|
return .{
|
||||||
|
.type_id = self.type_id,
|
||||||
|
.data = self.user_data,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const RegisterFfiOptions = struct {
|
||||||
|
traits: RegisterHandlerTraits = @enumFromInt(0),
|
||||||
|
};
|
||||||
|
|
||||||
|
// todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize
|
||||||
|
// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5
|
||||||
|
pub fn register(
|
||||||
|
self: *const FFI,
|
||||||
api: *const Api,
|
api: *const Api,
|
||||||
target_name: []const u8,
|
target_name: []const u8,
|
||||||
platform_name: []const u8,
|
platform_name: []const u8,
|
||||||
func: *const ffi.Handler,
|
func: *const ffi.Handler,
|
||||||
|
options: RegisterFfiOptions,
|
||||||
) ApiError!void {
|
) ApiError!void {
|
||||||
var ret = pjrtStruct(c.PJRT_FFI_Register_Handler_Args{
|
var ret = pjrtStruct(c.PJRT_FFI_Register_Handler_Args{
|
||||||
.api_version = 1,
|
.api_version = 1,
|
||||||
@ -1183,12 +1201,51 @@ pub const CustomCallRegistry = extern struct {
|
|||||||
.handler = @ptrCast(@constCast(func)),
|
.handler = @ptrCast(@constCast(func)),
|
||||||
.platform_name = platform_name.ptr,
|
.platform_name = platform_name.ptr,
|
||||||
.platform_name_size = platform_name.len,
|
.platform_name_size = platform_name.len,
|
||||||
|
.traits = @intFromEnum(options.traits),
|
||||||
});
|
});
|
||||||
const result = self.inner(&ret);
|
const result = self.inner.register_handler.?(&ret);
|
||||||
if (result) |pjrt_c_error| {
|
if (result) |pjrt_c_error| {
|
||||||
const pjrt_error: *Error = @ptrCast(pjrt_c_error);
|
const pjrt_error: *Error = @ptrCast(pjrt_c_error);
|
||||||
log.err("[GpuRegisterCustomCall] {s}", .{pjrt_error.getMessage(api)});
|
log.err("registerFfi error: {s}", .{pjrt_error.getMessage(api)});
|
||||||
|
return pjrt_error.getCode(api).toApiError();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn registerTypeId(self: *const FFI, api: *const Api, T: type) ApiError!void {
|
||||||
|
const type_name = @typeName(T);
|
||||||
|
var ret = pjrtStruct(c.PJRT_FFI_TypeID_Register_Args{
|
||||||
|
.type_name = type_name.ptr,
|
||||||
|
.type_name_size = type_name.len,
|
||||||
|
.type_id = 0, // let the plugin assign a unique type ID
|
||||||
|
});
|
||||||
|
const result = self.inner.type_id_register.?(&ret);
|
||||||
|
if (result) |pjrt_c_error| {
|
||||||
|
const pjrt_error: *Error = @ptrCast(pjrt_c_error);
|
||||||
|
return pjrt_error.getCode(api).toApiError();
|
||||||
|
}
|
||||||
|
|
||||||
|
T.type_id = ret.type_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn addUserData(self: *const FFI, api: *const Api, context: *ExecuteContext, user_data: UserData) ApiError!void {
|
||||||
|
var ret = pjrtStruct(c.PJRT_FFI_UserData_Add_Args{
|
||||||
|
.context = @ptrCast(context),
|
||||||
|
.user_data = user_data.toCStruct(),
|
||||||
|
});
|
||||||
|
const result = self.inner.user_data_add.?(&ret);
|
||||||
|
if (result) |pjrt_c_error| {
|
||||||
|
const pjrt_error: *Error = @ptrCast(pjrt_c_error);
|
||||||
|
log.err("addUserData error: {s}", .{pjrt_error.getMessage(api)});
|
||||||
return pjrt_error.getCode(api).toApiError();
|
return pjrt_error.getCode(api).toApiError();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub const RegisterHandlerTraits = enum(c.PJRT_FFI_Handler_TraitsBits) {
|
||||||
|
command_buffer_compatible = c.PJRT_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE,
|
||||||
|
_,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const CustomCallRegistry = extern struct {
|
||||||
|
inner: *const c.PJRT_FFI_Register_Handler,
|
||||||
|
};
|
||||||
|
|||||||
@ -224,7 +224,7 @@ pub const Buffer = struct {
|
|||||||
|
|
||||||
/// Creates a Buffer from a pointer into device memory.
|
/// Creates a Buffer from a pointer into device memory.
|
||||||
/// This allows to interface with other libraries producing buffers.
|
/// This allows to interface with other libraries producing buffers.
|
||||||
pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?isize, device_data: *anyopaque) Buffer {
|
pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const pjrt.Stream, device_data: *anyopaque) Buffer {
|
||||||
const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
|
const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
|
||||||
var res: [Shape.MAX_RANK]i64 = undefined;
|
var res: [Shape.MAX_RANK]i64 = undefined;
|
||||||
for (0..Shape.MAX_RANK) |i| {
|
for (0..Shape.MAX_RANK) |i| {
|
||||||
|
|||||||
@ -37,6 +37,7 @@ pub const Context = struct {
|
|||||||
inline for (comptime std.enums.values(runtimes.Platform)) |t| {
|
inline for (comptime std.enums.values(runtimes.Platform)) |t| {
|
||||||
if (runtimes.load(t)) |api| {
|
if (runtimes.load(t)) |api| {
|
||||||
Context.apis.set(t, api);
|
Context.apis.set(t, api);
|
||||||
|
if (t == .cuda) cuda.init();
|
||||||
} else |_| {}
|
} else |_| {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -218,10 +219,10 @@ pub const Context = struct {
|
|||||||
|
|
||||||
const CustomCall = struct {
|
const CustomCall = struct {
|
||||||
pub fn registerZmlCustomCalls(platform: Platform) !void {
|
pub fn registerZmlCustomCalls(platform: Platform) !void {
|
||||||
const registry = platform.pjrt_api.customCallRegistry();
|
const maybe_ffi = platform.pjrt_api.ffi();
|
||||||
|
|
||||||
if (registry) |reg| {
|
if (maybe_ffi) |ffi| {
|
||||||
try reg.registerFfi(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback);
|
try ffi.register(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback, .{});
|
||||||
} else {
|
} else {
|
||||||
stdx.debug.panic("Registering custom calls failed", .{});
|
stdx.debug.panic("Registering custom calls failed", .{});
|
||||||
}
|
}
|
||||||
@ -240,12 +241,12 @@ const CustomCall = struct {
|
|||||||
|
|
||||||
const input_buffers = stdx.stackSlice(8, HostBuffer, call_frame.args.len);
|
const input_buffers = stdx.stackSlice(8, HostBuffer, call_frame.args.len);
|
||||||
for (input_buffers, 0..) |*b, i| {
|
for (input_buffers, 0..) |*b, i| {
|
||||||
b.* = hostBufferFromPinnedBuffer(call_frame.args.get(i));
|
b.* = hostBufferFromPinnedBuffer(call_frame.args.buffers()[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
const output_buffers = stdx.stackSlice(8, HostBuffer, call_frame.results.len);
|
const output_buffers = stdx.stackSlice(8, HostBuffer, call_frame.results.len);
|
||||||
for (output_buffers, 0..) |*b, i| {
|
for (output_buffers, 0..) |*b, i| {
|
||||||
b.* = hostBufferFromPinnedBuffer(call_frame.results.get(i));
|
b.* = hostBufferFromPinnedBuffer(call_frame.results.buffers()[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
callback(user_ctx, input_buffers, output_buffers);
|
callback(user_ctx, input_buffers, output_buffers);
|
||||||
@ -258,10 +259,10 @@ fn getShape(buffer_desc: *const pjrt.ffi.Buffer) Shape {
|
|||||||
const dt: DataType = switch (buffer_desc.dtype) {
|
const dt: DataType = switch (buffer_desc.dtype) {
|
||||||
.invalid => @panic("invalid ffi"),
|
.invalid => @panic("invalid ffi"),
|
||||||
.pred => .bool,
|
.pred => .bool,
|
||||||
.s8 => .i8,
|
.i8 => .i8,
|
||||||
.s16 => .i16,
|
.i16 => .i16,
|
||||||
.s32 => .i32,
|
.i32 => .i32,
|
||||||
.s64 => .i64,
|
.i64 => .i64,
|
||||||
.token, .f8e4m3, .f8e3m4 => @panic("Unsupported ffi type"),
|
.token, .f8e4m3, .f8e3m4 => @panic("Unsupported ffi type"),
|
||||||
inline else => |t| @field(DataType, @tagName(t)),
|
inline else => |t| @field(DataType, @tagName(t)),
|
||||||
};
|
};
|
||||||
@ -278,3 +279,69 @@ fn hostBufferFromPinnedBuffer(buffer_desc: *const pjrt.ffi.Buffer) HostBuffer {
|
|||||||
buffer_desc.data[0..buffer_shape.byteSize()],
|
buffer_desc.data[0..buffer_shape.byteSize()],
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Runtime bindings to the few CUDA runtime (`libcudart`) entry points used by this file.
///
/// Every binding starts out pointing at an obviously invalid placeholder address and is
/// only replaced by a real symbol once `init` has completed successfully. Calling any of
/// the wrappers before that will jump to the placeholder address.
pub const cuda = struct {
    /// Placeholder stored in each binding until `init` resolves the real symbol.
    const unresolved_address = 0xdeadc00da00;

    pub var streamSynchronize: StreamSynchronize = @ptrFromInt(unresolved_address);
    /// NOTE: despite the `cu` prefix this is bound to the runtime symbol `cudaLaunchHostFunc`.
    pub var cuLaunchHostFunc: CuLaunchHostFunc = @ptrFromInt(unresolved_address);
    var _memcpyAsync: MemcpyAsync = @ptrFromInt(unresolved_address);
    var _memcpyBlocking: MemcpyBlocking = @ptrFromInt(unresolved_address);

    /// Handle of the loaded runtime library. It is intentionally kept open for the
    /// lifetime of the process: the function pointers above point into the library, so
    /// closing the handle could leave them dangling.
    var _cudart: ?std.DynLib = null;

    /// Direction of a copy, passed as the `kind` argument of the memcpy bindings.
    pub const MemcpyKind = enum(c_int) {
        host_to_host = 0,
        host_to_device = 1,
        device_to_host = 2,
        device_to_device = 3,
        inferred = 4,
    };

    const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: ?*anyopaque) callconv(.c) c_int;
    const MemcpyBlocking = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind) callconv(.c) c_int;
    const StreamSynchronize = *const fn (stream: *anyopaque) callconv(.c) c_int;
    const CuLaunchHostFunc = *const fn (stream: *anyopaque, host_func: *const fn (user_data: *const anyopaque) callconv(.c) void, user_data: *const anyopaque) callconv(.c) c_int;

    /// Opens `libcudart.so.12` and resolves all bindings.
    ///
    /// If the library cannot be opened, an error is logged and the bindings keep their
    /// placeholder value. If the library opens but one of the symbols is missing, this
    /// panics. Calling `init` again after a successful load is a no-op.
    pub fn init() void {
        // Already resolved: do not open a second handle.
        if (_cudart != null) return;

        var cudart = std.DynLib.open("libcudart.so.12") catch {
            log.err("cudart not found, callback will segfault", .{});
            return;
        };

        _memcpyAsync = cudart.lookup(MemcpyAsync, "cudaMemcpyAsync") orelse
            @panic("cudaMemcpyAsync not found");
        _memcpyBlocking = cudart.lookup(MemcpyBlocking, "cudaMemcpy") orelse
            @panic("cudaMemcpy not found");
        streamSynchronize = cudart.lookup(StreamSynchronize, "cudaStreamSynchronize") orelse
            @panic("cudaStreamSynchronize not found");
        cuLaunchHostFunc = cudart.lookup(CuLaunchHostFunc, "cudaLaunchHostFunc") orelse
            @panic("cudaLaunchHostFunc not found");

        // Keep the library loaded instead of closing it: the pointers resolved above
        // must stay valid for as long as they can be called.
        _cudart = cudart;
    }

    /// Copies `dst.len` bytes from `src` into `dst` with kind `.device_to_host`,
    /// using the blocking memcpy binding. Panics (via `check`) on a non-zero status.
    pub fn memcpyToHostBlocking(dst: []u8, src: *const anyopaque) void {
        check(_memcpyBlocking(dst.ptr, src, dst.len, .device_to_host));
    }

    /// Copies `src.len` bytes from `src` into `dst` with kind `.host_to_device`,
    /// using the blocking memcpy binding. Panics (via `check`) on a non-zero status.
    pub fn memcpyToDeviceBlocking(dst: *anyopaque, src: []const u8) void {
        check(_memcpyBlocking(dst, src.ptr, src.len, .host_to_device));
    }

    /// Enqueues a copy of `src.len` bytes from `src` into `dst` with kind
    /// `.host_to_device` on `stream`. Panics (via `check`) on a non-zero status.
    pub fn memcpyToDeviceAsync(dst: *anyopaque, src: []const u8, stream: ?*anyopaque) void {
        check(_memcpyAsync(dst, src.ptr, src.len, .host_to_device, stream));
    }

    /// Enqueues a copy of `dst.len` bytes from `src` into `dst` with kind
    /// `.device_to_host` on `stream`. Panics (via `check`) on a non-zero status.
    pub fn memcpyToHostAsync(dst: []u8, src: *const anyopaque, stream: ?*anyopaque) void {
        check(_memcpyAsync(dst.ptr, src, dst.len, .device_to_host, stream));
    }

    /// Panics with the numeric status when `err` is non-zero; returns normally otherwise.
    pub fn check(err: c_int) void {
        if (err == 0) return;
        stdx.debug.panic("CUDA error: {d}", .{err});
    }
};
|
||||||
|
|||||||
15
zml/exe.zig
15
zml/exe.zig
@ -12,7 +12,7 @@ const Platform = @import("platform.zig").Platform;
|
|||||||
const Shape = @import("shape.zig").Shape;
|
const Shape = @import("shape.zig").Shape;
|
||||||
const ShapeOf = @import("tensor.zig").ShapeOf;
|
const ShapeOf = @import("tensor.zig").ShapeOf;
|
||||||
|
|
||||||
const log = std.log.scoped(.zml);
|
const log = std.log.scoped(.@"zml/exe");
|
||||||
|
|
||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
@ -135,6 +135,9 @@ pub const BaseExe = struct {
|
|||||||
/// The PJRT executable representing the compiled module.
|
/// The PJRT executable representing the compiled module.
|
||||||
exe: *pjrt.LoadedExecutable,
|
exe: *pjrt.LoadedExecutable,
|
||||||
|
|
||||||
|
/// The execution context for this executable.
|
||||||
|
context: ?*pjrt.ExecuteContext = null,
|
||||||
|
|
||||||
/// Pre-allocated slice of buffers to use as inputs when the module is called.
|
/// Pre-allocated slice of buffers to use as inputs when the module is called.
|
||||||
input_per_device: []const [*]*pjrt.Buffer,
|
input_per_device: []const [*]*pjrt.Buffer,
|
||||||
|
|
||||||
@ -199,6 +202,9 @@ pub const BaseExe = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn deinit(self: BaseExe) void {
|
pub fn deinit(self: BaseExe) void {
|
||||||
|
if (self.context) |ctx| {
|
||||||
|
ctx.deinit(self.platform.pjrt_api);
|
||||||
|
}
|
||||||
self._arena.deinit();
|
self._arena.deinit();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -220,6 +226,7 @@ pub const BaseExe = struct {
|
|||||||
// even if it has been marked as "can be donated" during compilation.
|
// even if it has been marked as "can be donated" during compilation.
|
||||||
// TODO: expose it ?
|
// TODO: expose it ?
|
||||||
.non_donatable_input_indices = &.{},
|
.non_donatable_input_indices = &.{},
|
||||||
|
.context = self.context,
|
||||||
}) catch |err| {
|
}) catch |err| {
|
||||||
std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err});
|
std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err});
|
||||||
};
|
};
|
||||||
@ -288,11 +295,13 @@ pub const BaseExe = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn clone(self: BaseExe, parent_allocator: std.mem.Allocator) !BaseExe {
|
pub fn clone(self: BaseExe, parent_allocator: std.mem.Allocator) !BaseExe {
|
||||||
return .init(parent_allocator, self.platform, self.exe, .{
|
var exe: BaseExe = try .init(parent_allocator, self.platform, self.exe, .{
|
||||||
.input_shapes = self.input_shapes,
|
.n_in = self.input_buffer_count,
|
||||||
.result_shapes = self.result_shapes,
|
.result_shapes = self.result_shapes,
|
||||||
.n_devices = self.num_devices,
|
.n_devices = self.num_devices,
|
||||||
});
|
});
|
||||||
|
exe.context = self.context;
|
||||||
|
return exe;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,7 @@ pub const ffi = pjrt.ffi;
|
|||||||
pub const Profiler = pjrt.Profiler;
|
pub const Profiler = pjrt.Profiler;
|
||||||
pub const ApiError = pjrt.ApiError;
|
pub const ApiError = pjrt.ApiError;
|
||||||
pub const ErrorCode = pjrt.ErrorCode;
|
pub const ErrorCode = pjrt.ErrorCode;
|
||||||
|
pub const ExecuteContext = pjrt.ExecuteContext;
|
||||||
pub const BufferType = pjrt.BufferType;
|
pub const BufferType = pjrt.BufferType;
|
||||||
pub const Device = pjrt.Device;
|
pub const Device = pjrt.Device;
|
||||||
pub const DeviceDescription = pjrt.DeviceDescription;
|
pub const DeviceDescription = pjrt.DeviceDescription;
|
||||||
@ -20,6 +21,7 @@ pub const SerializeResult = pjrt.SerializeResult;
|
|||||||
pub const Executable = pjrt.Executable;
|
pub const Executable = pjrt.Executable;
|
||||||
pub const ExecuteError = ApiError;
|
pub const ExecuteError = ApiError;
|
||||||
pub const Memory = pjrt.Memory;
|
pub const Memory = pjrt.Memory;
|
||||||
|
pub const Stream = pjrt.Stream;
|
||||||
|
|
||||||
const log = std.log.scoped(.zml);
|
const log = std.log.scoped(.zml);
|
||||||
|
|
||||||
@ -120,7 +122,7 @@ pub const Client = opaque {
|
|||||||
return self.inner().addressableMemories(api);
|
return self.inner().addressableMemories(api);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn memoryByKind(self: *const Client, api: *const Api, kind: Memory.Kind) ?*Memory {
|
pub fn memoryByKind(self: *const Client, api: *const Api, kind: Memory.Kind) ?*const Memory {
|
||||||
for (self.addressableMemories(api)) |mem| {
|
for (self.addressableMemories(api)) |mem| {
|
||||||
if (mem.kind(api) == kind) {
|
if (mem.kind(api) == kind) {
|
||||||
return mem;
|
return mem;
|
||||||
@ -182,7 +184,7 @@ pub const Buffer = opaque {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn copyToMemory(self: *const Buffer, api: *const Api, memory_: *const Memory) ApiError!*Buffer {
|
pub fn copyToMemory(self: *const Buffer, api: *const Api, memory_: *const Memory) ApiError!*Buffer {
|
||||||
return @ptrCast(self.inner().copyToMemory(api, memory_));
|
return @ptrCast(try self.inner().copyToMemory(api, memory_));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getReadyEvent(self: *const Buffer, api: *const Api) ?*Event {
|
pub fn getReadyEvent(self: *const Buffer, api: *const Api) ?*Event {
|
||||||
@ -262,6 +264,7 @@ pub const LoadedExecutable = opaque {
|
|||||||
results: []const [*]*Buffer,
|
results: []const [*]*Buffer,
|
||||||
events: []?*Event,
|
events: []?*Event,
|
||||||
non_donatable_input_indices: []const i64 = &.{},
|
non_donatable_input_indices: []const i64 = &.{},
|
||||||
|
context: ?*ExecuteContext,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ExecuteError!void {
|
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ExecuteError!void {
|
||||||
@ -271,6 +274,7 @@ pub const LoadedExecutable = opaque {
|
|||||||
.results = @ptrCast(args.results),
|
.results = @ptrCast(args.results),
|
||||||
.events = @ptrCast(args.events),
|
.events = @ptrCast(args.events),
|
||||||
.non_donatable_input_indices = args.non_donatable_input_indices,
|
.non_donatable_input_indices = args.non_donatable_input_indices,
|
||||||
|
.context = args.context,
|
||||||
} });
|
} });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -75,6 +75,17 @@ pub const Platform = struct {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Registers the type `T` with the platform's FFI layer via `ffi.registerTypeId`.
///
/// `T` must expose a `type_id` declaration; this is verified at compile time.
/// Panics at runtime when the platform's PJRT API exposes no FFI extension.
/// Errors returned by `registerTypeId` are propagated to the caller.
pub fn registerFFIType(self: Platform, comptime T: type) !void {
    // `T` is comptime-known, so a missing declaration is reported at build time
    // rather than as a runtime panic.
    comptime {
        if (!@hasDecl(T, "type_id")) {
            @compileError("registerFFIType requires type " ++ @typeName(T) ++ " to have a `type_id` i64 declaration");
        }
    }

    const ffi = self.pjrt_api.ffi() orelse {
        stdx.debug.panic("registerFFIType is not available for target {s}", .{@tagName(self.target)});
    };
    try ffi.registerTypeId(self.pjrt_api, T);
}
|
||||||
|
|
||||||
pub fn deinit(self: *Platform) void {
|
pub fn deinit(self: *Platform) void {
|
||||||
self.pjrt_client.deinit(self.pjrt_api);
|
self.pjrt_client.deinit(self.pjrt_api);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user