Radix/pjrt/ffi.zig

524 lines
18 KiB
Zig

/// Bindings for PJRT custom call declaration / execution.
const std = @import("std");
const c = @import("c");
pub const TypeId = c.XLA_FFI_TypeId;
const stdx = @import("stdx");
const pjrt = @import("pjrt.zig");
const Stream = @import("pjrt.zig").Stream;
const log = std.log.scoped(.pjrt);
comptime {
    // `TypeId` is re-exported verbatim from the C header: refuse to build as soon as
    // it stops being a struct with exactly one field.
    const type_id_fields = @typeInfo(TypeId).@"struct".fields;
    if (type_id_fields.len != 1) @compileError("TypeId has changed");
}
/// The signature of a generic custom call.
/// The handler receives the `CallFrame` of the invocation, uses the C calling
/// convention, and returns an optional `Error` pointer (`Error.ok` is the null value).
pub const Handler = fn (*CallFrame) callconv(.c) ?*Error;
/// FFI API version record, laid out as an `extern struct`.
pub const ApiVersion = extern struct {
/// Major FFI API version, taken from the imported C header.
pub const major = c.XLA_FFI_API_MAJOR;
/// Minor FFI API version, taken from the imported C header.
pub const minor = c.XLA_FFI_API_MINOR;
// NOTE(review): presumably the byte size of this struct, as expected by the C API — confirm against the header.
struct_size: usize,
/// Optional head of a chain of extensions (see `ExtensionBase.next`).
extension_start: ?*ExtensionBase,
/// Set to `ApiVersion.major` by `CallFrame.registeringHook`.
major_version: i32,
/// Set to `ApiVersion.minor` by `CallFrame.registeringHook`.
minor_version: i32,
};
/// Tag stored in `ExtensionBase.type`; the integer values come from the C header.
pub const ExtensionType = enum(c.XLA_FFI_Extension_Type) {
/// Metadata extension: the tag `CallFrame.registeringHook` looks for.
metadata = c.XLA_FFI_Extension_Metadata,
};
/// Common header of an extension record; records are linked through `next`.
pub const ExtensionBase = extern struct {
// NOTE(review): presumably the byte size of the concrete extension struct — confirm against the C header.
struct_size: usize,
/// Which kind of extension this header belongs to.
type: ExtensionType,
/// Next extension in the chain, or null when there is none.
next: ?*ExtensionBase,
};
// Based on https://github.com/openxla/xla/blob/145f836bd5175dc5dd262f716a0c59af2b0297a0/xla/ffi/api/c_api.h#L449
/// Bit flags describing an FFI handler, packed into a single `c_uint`.
pub const HandlerTraits = packed struct(c_uint) {
/// Calls to FFI handler are safe to trace into the command buffer.
/// It means that calls to FFI handler always launch exactly the same device operations (can depend on attribute values)
/// that can be captured and then replayed.
command_buffer_compatible: bool,
/// Remaining bits of the backing integer; no flag is assigned to them here and they default to zero.
__unassigned__: u31 = 0,
};
/// Handler metadata record: API version plus handler traits.
/// `CallFrame.registeringHook` fills in `api_version` through a `MetadataExtension`.
pub const Metadata = extern struct {
// NOTE(review): presumably the byte size of this struct, as expected by the C API — confirm against the header.
struct_size: usize,
/// FFI API version; its major/minor fields are written by `CallFrame.registeringHook`.
api_version: ApiVersion,
/// Handler trait flags. `CallFrame.registeringHook` does not write this field.
traits: HandlerTraits,
};
/// Extension record that carries a pointer to a `Metadata` struct.
pub const MetadataExtension = extern struct {
/// Extension header. `CallFrame.registeringHook` recovers the enclosing
/// `MetadataExtension` from a pointer to this field with `@fieldParentPtr`.
extension_base: ExtensionBase,
/// Metadata to fill in; unwrapped with `.?` by `CallFrame.registeringHook`.
metadata: ?*Metadata,
};
/// Builds a namespace of pointer-cast helpers between a wrapper type `T` and the
/// type `InnerT` it stands for.
///
/// * `to` turns `*T` / `*const T` into `*InnerT` / `*const InnerT`.
/// * `from` turns `*InnerT` / `*const InnerT` into `*T` / `*const T`.
///
/// Constness is preserved in both directions. Passing any other type is rejected
/// at compile time with a message naming the expected and the actual type.
fn TransmuteMixin(comptime T: type, comptime InnerT: type) type {
    return struct {
        pub fn to(self: anytype) switch (@TypeOf(self)) {
            *T => *InnerT,
            *const T => *const InnerT,
            else => @compileError("TransmuteMixin.to expects *" ++ @typeName(T) ++ " or *const " ++ @typeName(T) ++ ", got " ++ @typeName(@TypeOf(self))),
        } {
            // The result type (and therefore the cast target) is picked by the switch above.
            return @ptrCast(@alignCast(self));
        }

        pub fn from(self: anytype) switch (@TypeOf(self)) {
            *InnerT => *T,
            *const InnerT => *const T,
            else => @compileError("TransmuteMixin.from expects *" ++ @typeName(InnerT) ++ " or *const " ++ @typeName(InnerT) ++ ", got " ++ @typeName(@TypeOf(self))),
        } {
            return @ptrCast(@alignCast(self));
        }
    };
}
/// Opaque Zig-side handle for the C `XLA_FFI_Api` table of FFI entry points.
pub const Api = opaque {
    /// Casts this handle to a pointer to the underlying `c.XLA_FFI_Api`.
    pub const inner = TransmuteMixin(Api, c.XLA_FFI_Api).to;

    /// Returns the stream associated with `context`.
    ///
    /// If the runtime reports an error, its message is logged, the error object is
    /// destroyed and the function panics.
    pub fn stream(self: *const Api, context: *const ExecutionContext) *pjrt.Stream {
        var request = pjrt.pjrtStruct(c.XLA_FFI_Stream_Get_Args{
            .ctx = @constCast(context.inner()),
        });
        if (self.inner().XLA_FFI_Stream_Get.?(&request)) |raw_error| {
            const ffi_err = Error.fromInner(raw_error);
            defer ffi_err.destroy(self);
            log.err("[Api.getStream] {s}", .{ffi_err.getMessage(self)});
            @panic("failed to get stream");
        }
        return @ptrCast(request.stream.?);
    }

    /// Asks the runtime for `size` bytes of device memory with the given `alignment`.
    ///
    /// Returns the pointer reported by the runtime. On failure the FFI error message
    /// is logged, the error object is destroyed and `error.Unknown` is returned.
    pub fn allocateDeviceMemory(self: *const Api, context: *const ExecutionContext, size: usize, alignment: usize) pjrt.ApiError!*anyopaque {
        var request = pjrt.pjrtStruct(c.XLA_FFI_DeviceMemory_Allocate_Args{
            .ctx = @constCast(context.inner()),
            .size = size,
            .alignment = alignment,
        });
        if (self.inner().XLA_FFI_DeviceMemory_Allocate.?(&request)) |raw_error| {
            const ffi_err = Error.fromInner(raw_error);
            defer ffi_err.destroy(self);
            log.err("[Api.allocateDeviceMemory] {s}", .{ffi_err.getMessage(self)});
            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }
        const memory = request.data.?;
        return memory;
    }

    /// Hands `data` (of `size` bytes) back to the runtime through `XLA_FFI_DeviceMemory_Free`.
    ///
    /// On failure the FFI error message is logged, the error object is destroyed and
    /// `error.Unknown` is returned.
    pub fn freeDeviceMemory(self: *const Api, context: *const ExecutionContext, data: *anyopaque, size: usize) pjrt.ApiError!void {
        var request = pjrt.pjrtStruct(c.XLA_FFI_DeviceMemory_Free_Args{
            .ctx = @constCast(context.inner()),
            .size = size,
            .data = data,
        });
        if (self.inner().XLA_FFI_DeviceMemory_Free.?(&request)) |raw_error| {
            const ffi_err = Error.fromInner(raw_error);
            defer ffi_err.destroy(self);
            log.err("[Api.freeDeviceMemory] {s}", .{ffi_err.getMessage(self)});
            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }
    }

    // TODO(Corentin): Implement remaining methods if needed:
    // * XLA_FFI_ThreadPool_Schedule (currently reachable through `ExecutionContext.scheduleTask`)
    // * XLA_FFI_Handler_Register
    // * XLA_FFI_TypeId_Register
    // * XLA_FFI_State_Set
    // * XLA_FFI_State_Get
};
/// Stage tag carried in `CallFrame.stage`; each member maps to the
/// `XLA_FFI_ExecutionStage_*` constant of the same name from the C header.
pub const ExecutionStage = enum(c.XLA_FFI_ExecutionStage) {
instantiate = c.XLA_FFI_ExecutionStage_INSTANTIATE,
prepare = c.XLA_FFI_ExecutionStage_PREPARE,
initialize = c.XLA_FFI_ExecutionStage_INITIALIZE,
execute = c.XLA_FFI_ExecutionStage_EXECUTE,
};
/// Opaque Zig-side handle for the C `XLA_FFI_ExecutionContext`.
pub const ExecutionContext = opaque {
    /// Casts this handle to a pointer to the underlying `c.XLA_FFI_ExecutionContext`.
    pub const inner = TransmuteMixin(ExecutionContext, c.XLA_FFI_ExecutionContext).to;

    /// Looks up the user data registered under `type_id` in this execution context.
    ///
    /// Returns `error.Unknown` (after logging the FFI error message) when the runtime
    /// reports a failure, and `error.NotFound` when the call succeeds but yields no data.
    pub fn getContext(self: *const ExecutionContext, type_id: TypeId, api: *const Api) pjrt.ApiError!*anyopaque {
        var ret: c.XLA_FFI_ExecutionContext_Get_Args = .{
            .struct_size = pjrt.pjrtStructSize(c.XLA_FFI_ExecutionContext_Get_Args),
            .extension_start = api.inner().extension_start,
            .ctx = @constCast(self.inner()),
            .type_id = @constCast(&type_id),
            // Filled in by XLA_FFI_ExecutionContext_Get. Starts as null so the check
            // below never reads an uninitialized pointer.
            .data = null,
        };
        const maybe_err = api.inner().XLA_FFI_ExecutionContext_Get.?(&ret);
        if (maybe_err) |ffi_error| {
            const err = Error.fromInner(ffi_error);
            defer err.destroy(api);
            log.err("[ExecutionContext.get] {s}", .{err.getMessage(api)});
            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }
        return ret.data orelse return error.NotFound;
    }

    /// Returns the device ordinal the runtime reports for this execution context.
    ///
    /// On failure the FFI error message is logged and `error.Unknown` is returned.
    pub fn getDeviceOrdinal(self: *const ExecutionContext, api: *const Api) pjrt.ApiError!i32 {
        var ret = pjrt.pjrtStruct(c.XLA_FFI_DeviceOrdinal_Get_Args{
            .ctx = @constCast(self.inner()),
        });
        const result = api.inner().XLA_FFI_DeviceOrdinal_Get.?(&ret);
        if (result) |ffi_error| {
            const err = Error.fromInner(ffi_error);
            defer err.destroy(api);
            log.err("[ExecutionContext.getDeviceOrdinal] {s}", .{err.getMessage(api)});
            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }
        return ret.device_ordinal;
    }

    // NOTE(review): `Task` uses Zig's default calling convention, yet the pointer is
    // cast and handed to a C API in `scheduleTask`; it probably should be
    // `callconv(.c)`. Left unchanged because that would alter what callers may pass.
    const Task = fn (*anyopaque) void;

    /// Submits `task` with its `data` argument through `XLA_FFI_ThreadPool_Schedule`.
    ///
    /// On failure the FFI error message is logged and `error.Unknown` is returned.
    pub fn scheduleTask(self: *const ExecutionContext, api: *const Api, task: *const Task, data: *anyopaque) pjrt.ApiError!void {
        var ret = pjrt.pjrtStruct(c.XLA_FFI_ThreadPool_Schedule_Args{
            .ctx = @constCast(self.inner()),
            .task = @ptrCast(@alignCast(task)),
            .data = @ptrCast(@alignCast(data)),
        });
        const result = api.inner().XLA_FFI_ThreadPool_Schedule.?(&ret);
        if (result) |ffi_error| {
            const err = Error.fromInner(ffi_error);
            defer err.destroy(api);
            log.err("[ExecutionContext.scheduleTask] {s}", .{err.getMessage(api)});
            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }
    }
};
/// Pointer/length pair describing a run of bytes, laid out as an `extern struct`.
pub const ByteSpan = extern struct {
    ptr: [*]const u8,
    len: usize,

    /// Views the span as a Zig slice of `len` bytes starting at `ptr`.
    pub fn slice(self: ByteSpan) []const u8 {
        const start = self.ptr;
        const count = self.len;
        return start[0..count];
    }
};
/// Element type tag used by `Buffer.dtype` and by attribute values.
/// Each member maps to an `XLA_FFI_DataType_*` constant from the C header, renamed
/// to Zig-style type names where one exists (`S32` -> `i32`, `PRED` -> `bool`, ...).
pub const DataType = enum(c.XLA_FFI_DataType) {
invalid = c.XLA_FFI_DataType_INVALID,
// `PRED` is exposed under the name `bool`.
bool = c.XLA_FFI_DataType_PRED,
// Signed integers.
i8 = c.XLA_FFI_DataType_S8,
i16 = c.XLA_FFI_DataType_S16,
i32 = c.XLA_FFI_DataType_S32,
i64 = c.XLA_FFI_DataType_S64,
// Unsigned integers.
u8 = c.XLA_FFI_DataType_U8,
u16 = c.XLA_FFI_DataType_U16,
u32 = c.XLA_FFI_DataType_U32,
u64 = c.XLA_FFI_DataType_U64,
// Floating point and complex.
f16 = c.XLA_FFI_DataType_F16,
f32 = c.XLA_FFI_DataType_F32,
f64 = c.XLA_FFI_DataType_F64,
bf16 = c.XLA_FFI_DataType_BF16,
c64 = c.XLA_FFI_DataType_C64,
c128 = c.XLA_FFI_DataType_C128,
token = c.XLA_FFI_DataType_TOKEN,
// 8-bit float variants.
f8e5m2 = c.XLA_FFI_DataType_F8E5M2,
f8e3m4 = c.XLA_FFI_DataType_F8E3M4,
f8e4m3 = c.XLA_FFI_DataType_F8E4M3,
f8e4m3fn = c.XLA_FFI_DataType_F8E4M3FN,
f8e4m3b11fnuz = c.XLA_FFI_DataType_F8E4M3B11FNUZ,
f8e5m2fnuz = c.XLA_FFI_DataType_F8E5M2FNUZ,
f8e4m3fnuz = c.XLA_FFI_DataType_F8E4M3FNUZ,
};
/// Description of one buffer handed to a custom call: element type, data pointer
/// and dimensions, laid out as an `extern struct`.
pub const Buffer = extern struct {
    struct_size: usize,
    extension_start: ?*c.XLA_FFI_Extension_Base,
    dtype: DataType,
    data: [*]u8,
    rank: u64,
    _dims: [*]const i64,

    /// Returns the `rank` dimension sizes stored behind `_dims`.
    pub fn dims(self: Buffer) []const i64 {
        const dim_count = self.rank;
        return self._dims[0..dim_count];
    }

    /// Writes a short description of the buffer: dimensions, dtype tag and data address.
    pub fn format(buffer: Buffer, writer: *std.Io.Writer) !void {
        const shape = buffer.dims();
        const address = @intFromPtr(buffer.data);
        try writer.print("FfiBuffer({any}, .{t})@0x{x}", .{ shape, buffer.dtype, address });
    }
};
/// Argument list of a custom call, laid out as an `extern struct`.
pub const Args = extern struct {
    /// Kind of each argument; only buffers are declared here.
    pub const Type = enum(c.XLA_FFI_ArgType) {
        buffer = c.XLA_FFI_ArgType_BUFFER,
    };

    struct_size: usize,
    extension_start: ?*const c.XLA_FFI_Extension_Base,
    len: u64,
    types: [*]const Type,
    ptr: [*]*const Buffer,

    /// Returns the `len` argument buffer pointers stored behind `ptr`.
    pub fn buffers(self: Args) []*const Buffer {
        const count = self.len;
        return self.ptr[0..count];
    }
};
/// Result list of a custom call, laid out as an `extern struct`.
pub const Rets = extern struct {
    /// Kind of each result; only buffers are declared here.
    pub const Type = enum(c.XLA_FFI_RetType) {
        buffer = c.XLA_FFI_RetType_BUFFER,
    };

    struct_size: usize,
    extension_start: ?*const c.XLA_FFI_Extension_Base,
    len: u64,
    types: [*]const Type,
    ptr: [*]*const Buffer,

    /// Returns the `len` result buffer pointers stored behind `ptr`.
    pub fn buffers(self: Rets) []*const Buffer {
        const count = self.len;
        return self.ptr[0..count];
    }
};
/// Kind of an attribute entry in `Attrs.types`; values come from the C header.
/// `Attrs.getByIndex` / `Attrs.getByName` take one of these tags at comptime and
/// use its name to select the matching field of the attribute union.
pub const AttrType = enum(c.XLA_FFI_AttrType) {
array = c.XLA_FFI_AttrType_ARRAY,
dictionary = c.XLA_FFI_AttrType_DICTIONARY,
scalar = c.XLA_FFI_AttrType_SCALAR,
string = c.XLA_FFI_AttrType_STRING,
};
/// Attribute list of a custom call: `len` entries described by three parallel
/// arrays (`types`, `names`, `ptr`), laid out as an `extern struct`.
pub const Attrs = extern struct {
    struct_size: usize,
    extension_start: ?*ExtensionBase,
    len: u64,
    types: [*]const AttrType,
    names: [*]const *const ByteSpan,
    ptr: [*]const *const Attr,

    /// Payload an attribute pointer may refer to. Field names must match `AttrType`
    /// tag names, because the lookup functions select the field with `@tagName`.
    ///
    /// NOTE(review): `string` assumes string attributes are passed as a pointer to an
    /// `XLA_FFI_ByteSpan` — confirm against the XLA FFI headers. `dictionary` is not
    /// modeled, so looking up `.dictionary` is still a compile error.
    const Attr = extern union {
        scalar: Scalar,
        array: Array,
        string: ByteSpan,
    };

    /// A single value of type `dtype`, stored behind `value`.
    pub const Scalar = extern struct {
        dtype: DataType,
        value: *const anyopaque,

        /// Reads the value as a `T`. The caller is responsible for choosing a `T`
        /// that matches `dtype`; no check is performed here.
        pub fn get(self: Scalar, T: type) T {
            const ptr: *const T = @ptrCast(@alignCast(self.value));
            return ptr.*;
        }
    };

    /// `len` values of type `dtype`, stored contiguously behind `data`.
    pub const Array = extern struct {
        dtype: DataType,
        len: usize,
        data: [*]const u8,

        /// Views the array as `len` elements of type `T`. The caller is responsible
        /// for choosing a `T` that matches `dtype`; no check is performed here.
        pub fn slice(self: Array, T: type) []const T {
            const ptr: [*]const T = @ptrCast(@alignCast(self.data));
            return ptr[0..self.len];
        }
    };

    /// Namespace-level alias of `Array.slice`, kept for callers that spell it
    /// `Attrs.slice(array, T)`.
    pub fn slice(self: Array, T: type) []const T {
        return self.slice(T);
    }

    /// Returns the attribute at `index` if it has type `attr_type`.
    ///
    /// Returns null when `index` is out of range or when the stored type differs.
    pub fn getByIndex(self: Attrs, comptime attr_type: AttrType, index: usize) ?*const @FieldType(Attr, @tagName(attr_type)) {
        // `types` and `ptr` are many-item pointers without a length, so the range
        // has to be checked against `len` by hand.
        if (index >= self.len) return null;
        if (self.types[index] != attr_type) return null;
        return @ptrCast(self.ptr[index]);
    }

    /// Returns the first attribute named `name` if it has type `attr_type`.
    ///
    /// Returns null when no attribute has that name, or when the first match has a
    /// different type.
    pub fn getByName(self: Attrs, comptime attr_type: AttrType, name: []const u8) ?*const @FieldType(Attr, @tagName(attr_type)) {
        for (self.names[0..self.len], 0..) |attr_name, i| {
            if (std.mem.eql(u8, attr_name.slice(), name)) {
                return self.getByIndex(attr_type, i);
            }
        }
        return null;
    }
};
/// Everything the user callback needs,
/// including the lists of input and output buffers to work on.
pub const CallFrame = extern struct {
    struct_size: usize,
    extension_start: ?*ExtensionBase,
    api: *const Api,
    ctx: *const ExecutionContext,
    stage: ExecutionStage,
    args: Args,
    results: Rets,
    attrs: Attrs,
    future: ?*Future,

    /// The registry mechanism first calls the custom call in registration mode
    /// and expects us to report which version of XLA we were compiled against.
    ///
    /// When the frame starts with a metadata extension, the API version is written
    /// into it and `true` is returned, meaning the calling custom call should return
    /// early. Otherwise nothing is touched and `false` is returned.
    pub fn registeringHook(call_frame: *CallFrame) bool {
        const extension = call_frame.extension_start orelse return false;
        if (extension.type != .metadata) return false;

        // The extension header is the `extension_base` field of a `MetadataExtension`.
        const metadata_extension: *MetadataExtension = @fieldParentPtr("extension_base", extension);
        const metadata = metadata_extension.metadata.?;
        metadata.api_version.major_version = ApiVersion.major;
        metadata.api_version.minor_version = ApiVersion.minor;
        return true;
    }

    /// Returns the stream of this call, obtained through `Api.stream` with the frame's context.
    pub fn stream(call_frame: CallFrame) ?*const pjrt.Stream {
        const api = call_frame.api;
        return api.stream(call_frame.ctx);
    }
};
/// Error codes understood by the FFI layer; each member maps to the
/// `XLA_FFI_Error_Code_*` constant of the same name from the C header.
pub const ErrorCode = enum(c.XLA_FFI_Error_Code) {
cancelled = c.XLA_FFI_Error_Code_CANCELLED,
unknown = c.XLA_FFI_Error_Code_UNKNOWN,
invalid_argument = c.XLA_FFI_Error_Code_INVALID_ARGUMENT,
deadline_exceeded = c.XLA_FFI_Error_Code_DEADLINE_EXCEEDED,
not_found = c.XLA_FFI_Error_Code_NOT_FOUND,
already_exists = c.XLA_FFI_Error_Code_ALREADY_EXISTS,
permission_denied = c.XLA_FFI_Error_Code_PERMISSION_DENIED,
resource_exhausted = c.XLA_FFI_Error_Code_RESOURCE_EXHAUSTED,
failed_precondition = c.XLA_FFI_Error_Code_FAILED_PRECONDITION,
aborted = c.XLA_FFI_Error_Code_ABORTED,
out_of_range = c.XLA_FFI_Error_Code_OUT_OF_RANGE,
unimplemented = c.XLA_FFI_Error_Code_UNIMPLEMENTED,
internal = c.XLA_FFI_Error_Code_INTERNAL,
unavailable = c.XLA_FFI_Error_Code_UNAVAILABLE,
data_loss = c.XLA_FFI_Error_Code_DATA_LOSS,
unauthenticated = c.XLA_FFI_Error_Code_UNAUTHENTICATED,
/// Converts the code into the corresponding member of `pjrt.ApiError`.
/// The switch is exhaustive, so a new enum member forces a new mapping here.
pub fn toApiError(code: ErrorCode) pjrt.ApiError {
return switch (code) {
.cancelled => error.Cancelled,
.unknown => error.Unknown,
.invalid_argument => error.InvalidArgument,
.deadline_exceeded => error.DeadlineExceeded,
// NOTE(review): this maps to `FfiNotFound`, whereas `ExecutionContext.getContext`
// returns `error.NotFound` — confirm the two are meant to be distinct.
.not_found => error.FfiNotFound,
.already_exists => error.AlreadyExists,
.permission_denied => error.PermissionDenied,
.resource_exhausted => error.ResourceExhausted,
.failed_precondition => error.FailedPrecondition,
.aborted => error.Aborted,
.out_of_range => error.OutOfRange,
.unimplemented => error.Unimplemented,
.internal => error.Internal,
.unavailable => error.Unavailable,
.data_loss => error.DataLoss,
.unauthenticated => error.Unauthenticated,
};
}
};
/// Opaque Zig-side handle for the C `XLA_FFI_Error`.
pub const Error = opaque {
    /// Casts this handle to a pointer to the underlying `c.XLA_FFI_Error`.
    pub const inner = TransmuteMixin(Error, c.XLA_FFI_Error).to;
    /// Wraps a pointer to a `c.XLA_FFI_Error` as an `Error` handle.
    pub const fromInner = TransmuteMixin(Error, c.XLA_FFI_Error).from;
    /// The "no error" value a handler can return.
    pub const ok: ?*Error = null;

    /// Maximum number of message bytes forwarded by `create`; longer messages are truncated.
    const max_message_len = 4096;

    /// Creates an FFI error carrying `error_code` and `message`.
    ///
    /// `message` is a plain slice with no terminator, whereas the C API takes the
    /// message through a bare `char` pointer with no length. It is therefore copied
    /// into a NUL-terminated stack buffer first, truncated to `max_message_len` bytes.
    /// Panics if the runtime returns a null error object.
    ///
    /// NOTE(review): this relies on `XLA_FFI_Error_Create` copying the message before
    /// it returns — confirm against the XLA implementation.
    pub fn create(api: *const Api, error_code: ErrorCode, message: []const u8) *Error {
        var message_z: [max_message_len + 1]u8 = undefined;
        const copied_len: usize = @min(message.len, max_message_len);
        @memcpy(message_z[0..copied_len], message[0..copied_len]);
        message_z[copied_len] = 0;

        var ret = pjrt.pjrtStruct(c.XLA_FFI_Error_Create_Args{
            .message = &message_z,
            .errc = @intFromEnum(error_code),
        });
        return fromInner(api.inner().XLA_FFI_Error_Create.?(&ret).?);
    }

    /// Releases the error object through `XLA_FFI_Error_Destroy`.
    pub fn destroy(err: *Error, api: *const Api) void {
        var ret = pjrt.pjrtStruct(c.XLA_FFI_Error_Destroy_Args{ .@"error" = err.inner() });
        api.inner().XLA_FFI_Error_Destroy.?(&ret);
    }

    /// Returns the message attached to the error as a NUL-terminated string.
    ///
    /// NOTE(review): the returned memory presumably belongs to the error object, so
    /// it should not be used after `destroy` — confirm against the XLA headers.
    pub fn getMessage(err: *Error, api: *const Api) [:0]const u8 {
        var ret = pjrt.pjrtStruct(c.XLA_FFI_Error_GetMessage_Args{
            .@"error" = err.inner(),
        });
        api.inner().XLA_FFI_Error_GetMessage.?(&ret);
        return std.mem.span(ret.message);
    }
};
/// Opaque Zig-side handle for the C `XLA_FFI_Future`.
pub const Future = opaque {
    /// Casts this handle to a pointer to the underlying `c.XLA_FFI_Future`.
    pub const inner = TransmuteMixin(Future, c.XLA_FFI_Future).to;
    /// Wraps a pointer to a `c.XLA_FFI_Future` as a `Future` handle.
    pub const fromInner = TransmuteMixin(Future, c.XLA_FFI_Future).from;

    /// Creates a new future through `XLA_FFI_Future_Create`.
    ///
    /// On failure the FFI error message is logged, the error object is destroyed and
    /// `error.Unknown` is returned.
    pub fn create(api: *const Api) pjrt.ApiError!*Future {
        var request = pjrt.pjrtStruct(c.XLA_FFI_Future_Create_Args{});
        if (api.inner().XLA_FFI_Future_Create.?(&request)) |raw_error| {
            const ffi_err = Error.fromInner(raw_error);
            defer ffi_err.destroy(api);
            log.err("[Future.create] {s}", .{ffi_err.getMessage(api)});
            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }
        const created = request.future.?;
        return fromInner(created);
    }

    /// Marks the future as available through `XLA_FFI_Future_SetAvailable`.
    ///
    /// On failure the FFI error message is logged, the error object is destroyed and
    /// `error.Unknown` is returned.
    pub fn setAvailable(self: *Future, api: *const Api) pjrt.ApiError!void {
        var request = pjrt.pjrtStruct(c.XLA_FFI_Future_SetAvailable_Args{
            .future = self.inner(),
        });
        if (api.inner().XLA_FFI_Future_SetAvailable.?(&request)) |raw_error| {
            const ffi_err = Error.fromInner(raw_error);
            defer ffi_err.destroy(api);
            log.err("[Future.setAvailable] {s}", .{ffi_err.getMessage(api)});
            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }
    }

    /// Attaches `err` to the future through `XLA_FFI_Future_SetError`.
    ///
    /// If that call itself fails, the error it reports (not `err`) is logged and
    /// destroyed, and `error.Unknown` is returned.
    pub fn setError(self: *Future, api: *const Api, err: *Error) pjrt.ApiError!void {
        var request = pjrt.pjrtStruct(c.XLA_FFI_Future_SetError_Args{
            .future = self.inner(),
            .@"error" = err.inner(),
        });
        if (api.inner().XLA_FFI_Future_SetError.?(&request)) |raw_error| {
            const call_err = Error.fromInner(raw_error);
            defer call_err.destroy(api);
            log.err("[Future.setError] {s}", .{call_err.getMessage(api)});
            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }
    }
};