const std = @import("std"); const builtin = @import("builtin"); const c = @import("c"); const stdx = @import("stdx"); pub const ffi = @import("ffi.zig"); const log = std.log.scoped(.pjrt); test { std.testing.refAllDecls(@This()); } // We could calculate it like PJRT does, but it turns out that some of those // were wrong in PJRT itself [1], which gets propagated to binary plugins. In // order to mirror that, we just the value as computed by PJRT itself, through // comptime reflection. We could make the argument to remove that one day since // [1] has been fixed. The problem is that this problem could happen again in // as the way PJRT does it is not very robust. // // 1. https://github.com/openxla/xla/issues/10032 pub fn pjrtStructSize(comptime T: type) usize { // unsafe on purpose, we want this to fail if that ever changes const typedef_name = comptime blk: { const needle = ".struct_"; const idx = std.mem.indexOf(u8, @typeName(T), needle).?; break :blk @typeName(T)[idx + needle.len ..]; }; return @field(c, typedef_name ++ "_STRUCT_SIZE"); } pub inline fn pjrtStruct(v: anytype) @TypeOf(v) { var ret = v; ret.struct_size = pjrtStructSize(@TypeOf(v)); return ret; } pub const ApiError = error{ Cancelled, Unknown, InvalidArgument, DeadlineExceeded, NotFound, AlreadyExists, PermissionDenied, ResourceExhausted, FailedPrecondition, Aborted, OutOfRange, Unimplemented, Internal, Unavailable, DataLoss, Unauthenticated, }; fn InnerMixin(comptime innerT: type) type { return struct { fn inner(self: anytype) *innerT { return @ptrCast(@alignCast(@constCast(self))); } }; } pub const Api = struct { pub const Version = struct { major: i64, minor: i64, }; const Funcs = std.meta.FieldEnum(c.PJRT_Api); inner: c.PJRT_Api, pub fn loadFrom(library: [:0]const u8) !*const Api { var lib: std.DynLib = switch (builtin.os.tag) { .linux => blk: { // We use RTLD_GLOBAL so that symbols from NEEDED libraries are available in the global namespace. const handle = std.c.dlopen(library, .{ .LAZY = true, .GLOBAL = true, .NODELETE = true }) orelse { log.err("Unable to dlopen plugin: {s}", .{library}); return error.FileNotFound; }; break :blk .{ .inner = .{ .handle = handle } }; }, else => std.DynLib.open(library) catch |err| { log.err("Unable to dlopen plugin: {s}", .{library}); return err; }, }; const DynGetPjrtApi = lib.lookup(*const fn () callconv(.c) *const Api, "GetPjrtApi") orelse { std.debug.panic("Unable to find GetPjrtApi symbol in library: {s}", .{library}); }; const api = DynGetPjrtApi(); log.info("Loaded library: {s}", .{library}); _ = api.call(.PJRT_Plugin_Initialize, .{}) catch unreachable; return api; } fn CallFnArgType(comptime func: Funcs) type { const fti = @typeInfo(@FieldType(c.PJRT_Api, @tagName(func))); const fn_ptr = @typeInfo(fti.optional.child); const fn_type_info = @typeInfo(fn_ptr.pointer.child); const arg_array_type_info = @typeInfo(fn_type_info.@"fn".params[0].type.?); return arg_array_type_info.pointer.child; } inline fn call(self: *const Api, comptime method: Funcs, arg: CallFnArgType(method)) ApiError!@TypeOf(arg) { var ret = pjrtStruct(arg); const fn_ptr = @field(&self.inner, @tagName(method)).?; const result = fn_ptr(&ret); if (@TypeOf(result) == void) { return ret; } if (result) |pjrt_c_error| { const pjrt_error: *Error = @ptrCast(pjrt_c_error); log.err("[{s}] {s}", .{ @tagName(method), pjrt_error.getMessage(self) }); return pjrt_error.getCode(self).toApiError(); } return ret; } pub fn lookupExtension(self: *const Api, comptime ExtensionT: type, ext_id: c_int) ?*const ExtensionT { var cur: [*c]const c.PJRT_Extension_Base = @ptrCast(@alignCast(self.inner.extension_start)); while (cur != null) : (cur = cur.*.next) { if (cur.*.type == ext_id) { return @ptrCast(@alignCast(cur)); } } return null; } pub inline fn version(self: *const Api) Version { return .{ .major = @intCast(self.inner.pjrt_api_version.major_version), .minor = @intCast(self.inner.pjrt_api_version.minor_version), }; } pub fn stablehloCurrentVersion(self: *const Api) ?[]const u8 { const state = struct { var buf: [32]u8 = undefined; var str: ?[:0]const u8 = null; }; if (state.str) |str| { return str; } if (self.getPluginAttribute("stablehlo_current_version")) |v| { stdx.debug.assert(v.kind() == .int64list, "fetched attribute \"stablehlo_current_version\" from the plugin with type `{}`, expected `.int64list`", .{v.kind()}); stdx.debug.assert(v.inner.value_size == 3, "expect version format to have 3 elements representing `major.minor.patch` format, got {} elements", .{v.inner.value_size}); const value = v.inner.unnamed_0.int64_array_value[0..v.inner.value_size]; state.str = std.fmt.bufPrintZ(&state.buf, "{d}.{d}.{d}", .{ value[0], value[1], value[2] }) catch unreachable; } return state.str; } pub fn createExecuteContext(api: *const Api) ApiError!*ExecuteContext { const ret = try api.call(.PJRT_ExecuteContext_Create, .{}); return @ptrCast(ret.context.?); } pub fn ffi(api: *const Api) ?Ffi { if (api.lookupExtension(c.PJRT_FFI_Extension, c.PJRT_Extension_Type_FFI)) |ext| { return .{ .inner = ext }; } return null; } fn getPluginAttribute(api: *const Api, key: []const u8) ?NamedValue { const attributes = api.getPluginAttributes(); for (attributes) |attr| { if (std.mem.eql(u8, attr.name(), key)) { return attr; } } return null; } fn getPluginAttributes(api: *const Api) []const NamedValue { const ret = api.call(.PJRT_Plugin_Attributes, .{ .extension_start = null, }) catch unreachable; if (ret.attributes == null) return &.{}; return @ptrCast(ret.attributes[0..ret.num_attributes]); } }; pub const ErrorCode = enum(c.PJRT_Error_Code) { cancelled = c.PJRT_Error_Code_CANCELLED, unknown = c.PJRT_Error_Code_UNKNOWN, invalid_argument = c.PJRT_Error_Code_INVALID_ARGUMENT, deadline_exceeded = c.PJRT_Error_Code_DEADLINE_EXCEEDED, not_found = c.PJRT_Error_Code_NOT_FOUND, already_exists = c.PJRT_Error_Code_ALREADY_EXISTS, permission_denied = c.PJRT_Error_Code_PERMISSION_DENIED, resource_exhausted = c.PJRT_Error_Code_RESOURCE_EXHAUSTED, failed_precondition = c.PJRT_Error_Code_FAILED_PRECONDITION, aborted = c.PJRT_Error_Code_ABORTED, out_of_range = c.PJRT_Error_Code_OUT_OF_RANGE, unimplemented = c.PJRT_Error_Code_UNIMPLEMENTED, internal = c.PJRT_Error_Code_INTERNAL, unavailable = c.PJRT_Error_Code_UNAVAILABLE, data_loss = c.PJRT_Error_Code_DATA_LOSS, unauthenticated = c.PJRT_Error_Code_UNAUTHENTICATED, pub fn toApiError(code: ErrorCode) ApiError { return switch (code) { .cancelled => ApiError.Cancelled, .unknown => ApiError.Unknown, .invalid_argument => ApiError.InvalidArgument, .deadline_exceeded => ApiError.DeadlineExceeded, .not_found => ApiError.NotFound, .already_exists => ApiError.AlreadyExists, .permission_denied => ApiError.PermissionDenied, .resource_exhausted => ApiError.ResourceExhausted, .failed_precondition => ApiError.FailedPrecondition, .aborted => ApiError.Aborted, .out_of_range => ApiError.OutOfRange, .unimplemented => ApiError.Unimplemented, .internal => ApiError.Internal, .unavailable => ApiError.Unavailable, .data_loss => ApiError.DataLoss, .unauthenticated => ApiError.Unauthenticated, }; } }; pub const Error = opaque { pub fn deinit(self: *Error, api: *const Api) void { _ = api.call(.PJRT_Error_Destroy, .{ .@"error" = @ptrCast(self), }) catch unreachable; } pub fn getCode(self: *Error, api: *const Api) ErrorCode { const ret = api.call(.PJRT_Error_GetCode, .{ .@"error" = @ptrCast(self), }) catch unreachable; return @enumFromInt(ret.code); } pub fn getMessage(self: *Error, api: *const Api) []const u8 { const ret = api.call(.PJRT_Error_Message, .{ .@"error" = @ptrCast(self), }) catch unreachable; return ret.message[0..ret.message_size]; } }; pub const ClientInitError = error{LoadingFailed} || ApiError; pub const ShapeSpec = extern struct { comptime { std.debug.assert(@sizeOf(ShapeSpec) == @sizeOf(c.PJRT_ShapeSpec)); } inner: c.PJRT_ShapeSpec, pub fn init(dims_: []const usize, bt: BufferType) ShapeSpec { return .{ .inner = pjrtStruct(c.PJRT_ShapeSpec{ .dims = @ptrCast(@constCast(dims_.ptr)), .num_dims = dims.len, .buffer_type = @intFromEnum(bt), }), }; } pub fn dims(self: ShapeSpec) []usize { return self.inner.dims[0..self.inner.num_dims]; } pub fn bufferType(self: ShapeSpec) BufferType { return @enumFromInt(self.inner.buffer_type); } }; pub const Stream = opaque {}; pub const Client = opaque { const inner = InnerMixin(c.PJRT_Client).inner; pub const ProgramFormat = enum { hlo, mlir, }; pub fn init(api: *const Api, create_options: []const NamedValue) ClientInitError!*Client { // log.info("Loaded PJRT runtime plugin: {s}", .{api.Platform}); const ret = try api.call(.PJRT_Client_Create, .{ .create_options = @ptrCast(create_options.ptr), .num_options = create_options.len, .kv_get_callback = null, .kv_put_callback = null, .kv_put_user_arg = null, .kv_get_user_arg = null, }); return @ptrCast(ret.client.?); } pub fn deinit(self: *Client, api: *const Api) void { _ = api.call(.PJRT_Client_Destroy, .{ .client = self.inner(), }) catch {}; } pub fn getPlatformName(self: *const Client, api: *const Api) []const u8 { const ret = api.call(.PJRT_Client_PlatformName, .{ .client = self.inner(), }) catch unreachable; return ret.platform_name[0..ret.platform_name_size]; } pub fn getDevices(self: *const Client, api: *const Api) []const *Device { const ret = api.call(.PJRT_Client_Devices, .{ .client = self.inner(), }) catch unreachable; return @ptrCast(ret.devices[0..ret.num_devices]); } pub fn getAddressableDevices(self: *const Client, api: *const Api) []const *Device { const ret = api.call(.PJRT_Client_AddressableDevices, .{ .client = self.inner(), }) catch unreachable; return @ptrCast(ret.addressable_devices[0..ret.num_addressable_devices]); } pub const CompileArgs = struct { bytecode: []const u8, bytecode_format: ProgramFormat, compile_options_pb: []const u8, }; pub fn compile(self: *const Client, api: *const Api, args: CompileArgs) ApiError!*LoadedExecutable { const bytecode_format_ = @tagName(args.bytecode_format); const ret = try api.call(.PJRT_Client_Compile, .{ .program = &pjrtStruct(c.PJRT_Program{ .code = @ptrCast(@constCast(args.bytecode.ptr)), .code_size = args.bytecode.len, .format = @ptrCast(@constCast(bytecode_format_.ptr)), .format_size = bytecode_format_.len, }), .compile_options = @ptrCast(@constCast(args.compile_options_pb.ptr)), .compile_options_size = args.compile_options_pb.len, .client = self.inner(), }); return @ptrCast(ret.executable.?); } pub const BufferFromHostBufferArgs = struct { data: [*]const u8, buffer_type: BufferType, dims: []const i64, byte_strides: ?[]const i64, device: ?*const Device = null, host_buffer_semantics: HostBufferSemantics, memory: ?*const Memory = null, }; pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } { const ret = try api.call(.PJRT_Client_BufferFromHostBuffer, .{ .client = self.inner(), .data = @constCast(args.data), .type = @intFromEnum(args.buffer_type), .dims = @ptrCast(@constCast(args.dims.ptr)), .num_dims = args.dims.len, .byte_strides = if (args.byte_strides) |bs| @ptrCast(@constCast(bs.ptr)) else null, .num_byte_strides = if (args.byte_strides) |bs| bs.len else 0, .host_buffer_semantics = @intFromEnum(args.host_buffer_semantics), .device = @ptrCast(@constCast(args.device)), .memory = @ptrCast(@constCast(args.memory)), .device_layout = null, // TODO .done_with_host_buffer = null, .buffer = null, }); return .{ @ptrCast(ret.buffer.?), @ptrCast(ret.done_with_host_buffer), }; } pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable { const ret = try api.call(.PJRT_Executable_DeserializeAndLoad, .{ .client = self.inner(), .serialized_executable = bytes.ptr, .serialized_executable_size = bytes.len, }); return @ptrCast(ret.loaded_executable.?); } pub const CreateViewOfDeviceBufferArgs = struct { data: *anyopaque, dims: []const i64, element_type: BufferType, layout: MemoryLayout, device: *const Device, on_delete_callback: *const fn (device_buffer_ptr: ?*anyopaque, ctx: ?*anyopaque) callconv(.c) void = &struct { fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.c) void {} }.call, on_delete_callback_arg: ?*anyopaque = null, stream: ?*const Stream = null, }; pub fn createViewOfDeviceBuffer(self: *const Client, api: *const Api, args: CreateViewOfDeviceBufferArgs) ApiError!*Buffer { const layout = args.layout.toCStruct(); const ret = try api.call(.PJRT_Client_CreateViewOfDeviceBuffer, .{ .client = self.inner(), .device_buffer_ptr = @ptrCast(@constCast(args.data)), .dims = args.dims.ptr, .num_dims = args.dims.len, .element_type = @intFromEnum(args.element_type), .layout = @ptrCast(@constCast(&layout)), .device = @ptrCast(@constCast(args.device)), .on_delete_callback = args.on_delete_callback, .on_delete_callback_arg = args.on_delete_callback_arg, .stream = @bitCast(@intFromPtr(args.stream)), }); return @ptrCast(ret.buffer.?); } pub fn addressableMemories(self: *const Client, api: *const Api) []*const Memory { const ret = api.call(.PJRT_Client_AddressableMemories, .{ .client = self.inner(), }) catch unreachable; if (ret.addressable_memories) |memories| { return @ptrCast(@constCast(memories[0..ret.num_addressable_memories])); } return &.{}; } pub fn dmaMap(self: *const Client, api: *const Api, data: []const u8) ApiError!void { try api.call(.PJRT_Client_DmaMap, .{ .client = self.inner(), .data = @ptrCast(@constCast(data.ptr)), .size = @intCast(data.len), }); } pub fn dmaUnmap(self: *const Client, api: *const Api, data: []const u8) ApiError!void { try api.call(.PJRT_Client_DmaUnmap, .{ .client = self.inner(), .data = @ptrCast(@constCast(data.ptr)), }); } pub const CreateBuffersForAsyncHostToDeviceArgs = struct { shape_specs: []const ShapeSpec, device_layouts: ?[]*const MemoryLayout = null, memory: *const Memory, }; pub fn createBuffersForAsyncHostToDevice(self: *const Client, api: *const Api, args: CreateBuffersForAsyncHostToDeviceArgs) ApiError!*AsyncHostToDeviceTransferManager { const ret = try api.call(.PJRT_Client_CreateBuffersForAsyncHostToDevice, .{ .client = self.inner(), .shape_specs = @ptrCast(args.shape_specs.ptr), .num_shape_specs = args.shape_specs.len, .device_layouts = if (args.device_layouts) |layouts| @ptrCast(@constCast(layouts.ptr)) else null, .num_device_layouts = if (args.device_layouts) |layouts| @intCast(layouts.len) else 0, .memory = @ptrCast(@constCast(args.memory)), }); return @ptrCast(ret.transfer_manager.?); } pub const CreateUninitializedBufferArgs = struct { dims: []const i64, element_type: BufferType, layout: MemoryLayout, device: ?*const Device = null, memory: ?*const Memory = null, }; pub fn createUninitializedBuffer(self: *const Client, api: *const Api, args: CreateUninitializedBufferArgs) ApiError!*Buffer { var layout = args.layout.toCStruct(); const ret = try api.call(.PJRT_Client_CreateUninitializedBuffer, .{ .client = self.inner(), .shape_dims = args.dims.ptr, .shape_num_dims = @intCast(args.dims.len), .shape_element_type = @intFromEnum(args.element_type), .shape_layout = @ptrCast(&layout), .device = @ptrCast(@constCast(args.device)), .memory = @ptrCast(@constCast(args.memory)), }); return @ptrCast(ret.buffer.?); } }; pub const MemoryStats = struct { // Number of bytes in use. bytes_in_use: u64, // out // The peak bytes in use. peak_bytes_in_use: u64, // out peak_bytes_in_use_is_set: bool, // out // Number of allocations. num_allocs: u64, // out num_allocs_is_set: bool, // out // The largest single allocation seen. largest_alloc_size: u64, // out largest_alloc_size_is_set: bool, // out // The upper limit of user-allocatable device memory in bytes. bytes_limit: u64, // out bytes_limit_is_set: bool, // out // Number of bytes reserved. bytes_reserved: u64, // out bytes_reserved_is_set: bool, // out // The peak number of bytes reserved. peak_bytes_reserved: u64, // out peak_bytes_reserved_is_set: bool, // out // The upper limit on the number bytes of reservable memory. bytes_reservable_limit: u64, // out bytes_reservable_limit_is_set: bool, // out // Largest free block size in bytes. largest_free_block_bytes: u64, // out largest_free_block_bytes_is_set: bool, // out // Number of bytes of memory held by the allocator. This may be higher than // bytes_in_use if the allocator holds a pool of memory (e.g. BFCAllocator). pool_bytes: u64, // out pool_bytes_is_set: bool, // out peak_pool_bytes: u64, // out peak_pool_bytes_is_set: bool, // out }; pub const Device = opaque { const inner = InnerMixin(c.PJRT_Device).inner; pub fn getDescription(self: *const Device, api: *const Api) *const DeviceDescription { const ret = api.call(.PJRT_Device_GetDescription, .{ .device = self.inner(), }) catch unreachable; return @ptrCast(ret.device_description.?); } pub fn isAddressable(self: *const Device, api: *const Api) bool { const ret = api.call(.PJRT_Device_IsAddressable, .{ .device = self.inner(), }) catch unreachable; return ret.is_addressable; } pub fn getLocalHardwareId(self: *const Device, api: *const Api) usize { const ret = api.call(.PJRT_Device_LocalHardwareId, .{ .device = self.inner(), }) catch unreachable; return @intCast(ret.local_hardware_id); } pub fn addressableMemories(self: *const Device, api: *const Api) ApiError![]const *Memory { const ret = try api.call(.PJRT_Device_AddressableMemories, .{ .device = self.inner(), }); return @ptrCast(ret.memories[0..ret.num_memories]); } pub fn memoryStats(self: *const Device, api: *const Api) ApiError!MemoryStats { const ret = try api.call(.PJRT_Device_MemoryStats, .{ .device = self.inner(), }); return .{ .bytes_in_use = @intCast(ret.bytes_in_use), .peak_bytes_in_use = @intCast(ret.peak_bytes_in_use), .peak_bytes_in_use_is_set = ret.peak_bytes_in_use_is_set, .num_allocs = @intCast(ret.num_allocs), .num_allocs_is_set = ret.num_allocs_is_set, .largest_alloc_size = @intCast(ret.largest_alloc_size), .largest_alloc_size_is_set = ret.largest_alloc_size_is_set, .bytes_limit = @intCast(ret.bytes_limit), .bytes_limit_is_set = ret.bytes_limit_is_set, .bytes_reserved = @intCast(ret.bytes_reserved), .bytes_reserved_is_set = ret.bytes_reserved_is_set, .peak_bytes_reserved = @intCast(ret.peak_bytes_reserved), .peak_bytes_reserved_is_set = ret.peak_bytes_reserved_is_set, .bytes_reservable_limit = @intCast(ret.bytes_reservable_limit), .bytes_reservable_limit_is_set = ret.bytes_reservable_limit_is_set, .largest_free_block_bytes = @intCast(ret.largest_free_block_bytes), .largest_free_block_bytes_is_set = ret.largest_free_block_bytes_is_set, .pool_bytes = @intCast(ret.pool_bytes), .pool_bytes_is_set = ret.pool_bytes_is_set, .peak_pool_bytes = @intCast(ret.peak_pool_bytes), .peak_pool_bytes_is_set = ret.peak_pool_bytes_is_set, }; } }; pub const DeviceDescription = opaque { const inner = InnerMixin(c.PJRT_DeviceDescription).inner; pub fn getId(self: *const DeviceDescription, api: *const Api) usize { const ret = api.call(.PJRT_DeviceDescription_Id, .{ .device_description = self.inner(), }) catch unreachable; return @intCast(ret.id); } pub fn getProcessIndex(self: *const DeviceDescription, api: *const Api) usize { const ret = api.call(.PJRT_DeviceDescription_ProcessIndex, .{ .device_description = self.inner(), }) catch unreachable; return @intCast(ret.process_index); } pub fn getKind(self: *const DeviceDescription, api: *const Api) []const u8 { const ret = api.call(.PJRT_DeviceDescription_Kind, .{ .device_description = self.inner(), }) catch unreachable; return ret.device_kind[0..ret.device_kind_size]; } pub fn debugString(self: *const DeviceDescription, api: *const Api) []const u8 { const ret = api.call(.PJRT_DeviceDescription_DebugString, .{ .device_description = self.inner(), }) catch unreachable; return ret.debug_string[0..ret.debug_string_size]; } pub fn toString(self: *const DeviceDescription, api: *const Api) []const u8 { const ret = api.call(.PJRT_DeviceDescription_ToString, .{ .device_description = self.inner(), }) catch unreachable; return ret.to_string[0..ret.to_string_size]; } }; pub const GetCostAnalysisError = std.mem.Allocator.Error || ApiError; pub const SerializeResult = struct { bytes: []const u8, handle: *anyopaque, deleter: *const fn (?*anyopaque) callconv(.c) void, pub fn deinit(self: *SerializeResult) void { self.deleter(self.handle); self.bytes = &.{}; self.* = undefined; } }; pub const ExecuteContext = opaque { pub fn deinit(self: *ExecuteContext, api: *const Api) void { _ = api.call(.PJRT_ExecuteContext_Destroy, .{ .context = @ptrCast(self), }) catch {}; } }; pub const Executable = opaque { const inner = InnerMixin(c.PJRT_Executable).inner; pub fn deinit(self: *Executable, api: *const Api) void { _ = api.call(.PJRT_Executable_Destroy, .{ .executable = self.inner(), }) catch unreachable; } pub fn getCostAnalysis(self: *const Executable, api: *const Api) GetCostAnalysisError![]*const NamedValue { const ret = try api.call(.PJRT_Executable_GetCostAnalysis, .{ .executable = self.inner(), }); const values: [*]*const NamedValue = @ptrCast(ret.properties); return values[0..ret.num_properties]; } pub fn serialize(self: *const Executable, api: *const Api) ApiError!SerializeResult { const ret = try api.call(.PJRT_Executable_Serialize, .{ .executable = self.inner(), }); return .{ .bytes = ret.serialized_bytes[0..ret.serialized_bytes_size], .handle = ret.serialized_executable.?, .deleter = @ptrCast(ret.serialized_executable_deleter.?), }; } pub fn getCompiledMemoryStats(self: *const Executable, api: *const Api) ApiError!CompiledMemoryStats { const ret = try api.call(.PJRT_Executable_GetCompiledMemoryStats, .{ .executable = self.inner(), }); return .{ .generated_code_size_in_bytes = @intCast(ret.generated_code_size_in_bytes), .argument_size_in_bytes = @intCast(ret.argument_size_in_bytes), .output_size_in_bytes = @intCast(ret.output_size_in_bytes), .alias_size_in_bytes = @intCast(ret.alias_size_in_bytes), .temp_size_in_bytes = @intCast(ret.temp_size_in_bytes), .host_generated_code_size_in_bytes = @intCast(ret.host_generated_code_size_in_bytes), .host_argument_size_in_bytes = @intCast(ret.host_argument_size_in_bytes), .host_output_size_in_bytes = @intCast(ret.host_output_size_in_bytes), .host_alias_size_in_bytes = @intCast(ret.host_alias_size_in_bytes), .host_temp_size_in_bytes = @intCast(ret.host_temp_size_in_bytes), }; } }; pub const CompiledMemoryStats = struct { // Mirrors xla::CompiledMemoryStats. // Device default memory (e.g., HBM for GPU/TPU) usage stats. generated_code_size_in_bytes: u64, argument_size_in_bytes: u64, output_size_in_bytes: u64, // much: How argument is reused for output. alias_size_in_bytes: u64, temp_size_in_bytes: u64, // memory: Host usage stats. host_generated_code_size_in_bytes: u64, host_argument_size_in_bytes: u64, host_output_size_in_bytes: u64, host_alias_size_in_bytes: u64, host_temp_size_in_bytes: u64, }; pub const LoadedExecutable = opaque { const inner = InnerMixin(c.PJRT_LoadedExecutable).inner; pub fn deinit(self: *LoadedExecutable, api: *const Api) void { _ = api.call(.PJRT_LoadedExecutable_Destroy, .{ .executable = self.inner(), }) catch {}; self.* = undefined; } pub fn delete(self: *LoadedExecutable, api: *const Api) void { _ = api.call(.PJRT_LoadedExecutable_Delete, .{ .executable = self.inner(), }) catch unreachable; } pub fn isDeleted(self: *const LoadedExecutable, api: *const Api) bool { const ret = api.call(.PJRT_LoadedExecutable_IsDeleted, .{ .executable = self.inner(), }) catch unreachable; return ret.is_deleted; } pub fn getAddressableDevices(self: *const LoadedExecutable, api: *const Api) []Device { const ret = api.call(.PJRT_LoadedExecutable_AddressableDevices, .{ .executable = self.inner(), }) catch unreachable; return @ptrCast(ret.addressable_devices); } pub const ExecuteArgs = struct { num_args: usize, arguments: []const [*]const *const Buffer, results: []const [*]*Buffer, events: []?*Event, non_donatable_input_indices: []const i64 = &.{}, context: ?*ExecuteContext, }; pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ApiError!void { var options = pjrtStruct(c.PJRT_ExecuteOptions{ .send_callbacks = null, .recv_callbacks = null, .num_send_ops = 0, .num_recv_ops = 0, .launch_id = 0, .non_donatable_input_indices = @ptrCast(args.non_donatable_input_indices.ptr), .num_non_donatable_input_indices = args.non_donatable_input_indices.len, .context = @ptrCast(args.context), }); _ = try api.call(.PJRT_LoadedExecutable_Execute, .{ .executable = self.inner(), .options = @ptrCast(&options), .argument_lists = @ptrCast(args.arguments.ptr), .num_devices = @intCast(args.arguments.len), .num_args = args.num_args, .output_lists = @ptrCast(args.results.ptr), .device_complete_events = @ptrCast(args.events.ptr), .execute_device = null, }); } pub fn getExecutable(self: *const LoadedExecutable, api: *const Api) ApiError!*Executable { const ret = try api.call(.PJRT_LoadedExecutable_GetExecutable, .{ .loaded_executable = self.inner(), }); return @ptrCast(ret.executable.?); } }; pub const BufferType = enum(c.PJRT_Buffer_Type) { invalid = c.PJRT_Buffer_Type_INVALID, bool = c.PJRT_Buffer_Type_PRED, i4 = c.PJRT_Buffer_Type_S4, i8 = c.PJRT_Buffer_Type_S8, i16 = c.PJRT_Buffer_Type_S16, i32 = c.PJRT_Buffer_Type_S32, i64 = c.PJRT_Buffer_Type_S64, u4 = c.PJRT_Buffer_Type_U4, u8 = c.PJRT_Buffer_Type_U8, u16 = c.PJRT_Buffer_Type_U16, u32 = c.PJRT_Buffer_Type_U32, u64 = c.PJRT_Buffer_Type_U64, f16 = c.PJRT_Buffer_Type_F16, f32 = c.PJRT_Buffer_Type_F32, f64 = c.PJRT_Buffer_Type_F64, bf16 = c.PJRT_Buffer_Type_BF16, c64 = c.PJRT_Buffer_Type_C64, c128 = c.PJRT_Buffer_Type_C128, f8e5m2 = c.PJRT_Buffer_Type_F8E5M2, f8e4m3fn = c.PJRT_Buffer_Type_F8E4M3FN, f8e4m3b11fnuz = c.PJRT_Buffer_Type_F8E4M3B11FNUZ, f8e5m2fnuz = c.PJRT_Buffer_Type_F8E5M2FNUZ, f8e4m3fnuz = c.PJRT_Buffer_Type_F8E4M3FNUZ, }; pub const MemoryLayoutType = enum(c.PJRT_Buffer_MemoryLayout_Type) { tiled = c.PJRT_Buffer_MemoryLayout_Type_Tiled, strides = c.PJRT_Buffer_MemoryLayout_Type_Strides, }; pub const MemoryLayout = union(MemoryLayoutType) { pub const Type = MemoryLayoutType; pub const Tiled = struct { minor_to_major: []const i64, tile_dims: []const i64, tile_dims_sizes: []const usize, }; pub const Strides = struct { byte_strides: []const i64, }; tiled: Tiled, strides: Strides, fn toCStruct(self: MemoryLayout) c.PJRT_Buffer_MemoryLayout { return pjrtStruct(switch (self) { .tiled => |v| c.PJRT_Buffer_MemoryLayout{ .type = c.PJRT_Buffer_MemoryLayout_Type_Tiled, .unnamed_0 = .{ .tiled = c.PJRT_Buffer_MemoryLayout_Tiled{ .minor_to_major = v.minor_to_major.ptr, .minor_to_major_size = v.minor_to_major.len, .tile_dims = v.tile_dims.ptr, .tile_dim_sizes = v.tile_dims_sizes.ptr, .num_tiles = v.tile_dims_sizes.len, }, }, }, .strides => |v| c.PJRT_Buffer_MemoryLayout{ .type = c.PJRT_Buffer_MemoryLayout_Type_Strides, .unnamed_0 = .{ .strides = c.PJRT_Buffer_MemoryLayout_Strides{ .byte_strides = v.byte_strides.ptr, .num_byte_strides = v.byte_strides.len, }, }, }, }); } }; pub const HostBufferSemantics = enum(c.PJRT_HostBufferSemantics) { ImmutableOnlyDuringCall = c.PJRT_HostBufferSemantics_kImmutableOnlyDuringCall, ImmutableUntilTransferCompletes = c.PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes, ImmutableZeroCopy = c.PJRT_HostBufferSemantics_kImmutableZeroCopy, MutableZeroCopy = c.PJRT_HostBufferSemantics_kMutableZeroCopy, }; pub const Buffer = opaque { const inner = InnerMixin(c.PJRT_Buffer).inner; pub fn deinit(self: *Buffer, api: *const Api) void { _ = api.call(.PJRT_Buffer_Destroy, .{ .buffer = self.inner(), }) catch unreachable; } pub fn getDevice(self: *const Buffer, api: *const Api) ApiError!*Device { const ret = try api.call(.PJRT_Buffer_Device, .{ .buffer = self.inner(), }); return @ptrCast(ret.device.?); } pub fn delete(self: *Buffer, api: *const Api) void { _ = api.call(.PJRT_Buffer_Delete, .{ .buffer = self.inner(), }) catch unreachable; } pub fn isDeleted(self: *const Buffer, api: *const Api) bool { const ret = api.call(.PJRT_Buffer_IsDeleted, .{ .buffer = self.inner(), }) catch unreachable; return ret.is_deleted; } pub fn isOnCpu(self: *const Buffer, api: *const Api) bool { const ret = api.call(.PJRT_Buffer_IsOnCpu, .{ .buffer = self.inner(), }) catch unreachable; return ret.is_on_cpu; } pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!?*Event { const ret = try api.call(.PJRT_Buffer_ToHostBuffer, .{ .src = self.inner(), .dst = @ptrCast(dst.ptr), .dst_size = dst.len, }); return @ptrCast(ret.event); } pub fn getElementType(self: *const Buffer, api: *const Api) BufferType { const ret = api.call(.PJRT_Buffer_ElementType, .{ .buffer = self.inner(), }) catch unreachable; return @enumFromInt(ret.type); } pub fn getDimensions(self: *const Buffer, api: *const Api) []const i64 { const ret = api.call(.PJRT_Buffer_Dimensions, .{ .buffer = self.inner(), }) catch unreachable; if (ret.num_dims == 0) { return &.{}; } return ret.dims[0..ret.num_dims]; } pub fn getUnpaddedDimensions(self: *const Buffer, api: *const Api) ApiError![]const i64 { const ret = try api.call(.PJRT_Buffer_UnpaddedDimensions, .{ .buffer = self.inner(), }); return ret.dims[0..ret.num_dims]; } pub fn getOnDeviceSizeInBytes(self: *const Buffer, api: *const Api) ApiError!usize { const ret = try api.call(.PJRT_Buffer_OnDeviceSizeInBytes, .{ .buffer = self.inner(), }); return ret.on_device_size_in_bytes; } pub fn copyToDevice(self: *const Buffer, api: *const Api, device: Device) ApiError!*Buffer { const ret = try api.call(.PJRT_Buffer_CopyToDevice, .{ .buffer = self.inner(), .dst_device = device.inner, }); return @ptrCast(ret.dst_buffer.?); } pub fn getReadyEvent(self: *const Buffer, api: *const Api) *Event { const ret = api.call(.PJRT_Buffer_ReadyEvent, .{ .buffer = self.inner(), }) catch unreachable; return @ptrCast(ret.event.?); } pub fn getOpaqueDeviceMemoryDataPointer(self: *const Buffer, api: *const Api) ApiError!*anyopaque { const ret = try api.call(.PJRT_Buffer_OpaqueDeviceMemoryDataPointer, .{ .buffer = self.inner(), }); return ret.device_memory_ptr.?; } pub fn copyRawToHost(self: *const Buffer, api: *const Api, dst: []u8, offset: i64) ApiError!?*Event { const ret = try api.call(.PJRT_Buffer_CopyRawToHost, .{ .buffer = self.inner(), .dst = @ptrCast(dst.ptr), .offset = offset, .transfer_size = @intCast(dst.len), }); return @ptrCast(ret.event); } pub fn copyToMemory(self: *const Buffer, api: *const Api, dst_memory: *const Memory) ApiError!*Buffer { const ret = try api.call(.PJRT_Buffer_CopyToMemory, .{ .buffer = self.inner(), .dst_memory = @ptrCast(@constCast(dst_memory)), }); return @ptrCast(ret.dst_buffer); } pub fn memory(self: *const Buffer, api: *const Api) *const Memory { const ret = api.call(.PJRT_Buffer_Memory, .{ .buffer = self.inner(), }) catch unreachable; return @ptrCast(ret.memory); } pub fn increaseExternalReferenceCount(self: *const Buffer, api: *const Api) ApiError!void { _ = try api.call(.PJRT_Buffer_IncreaseExternalReferenceCount, .{ .buffer = self.inner(), }); } pub fn decreaseExternalReferenceCount(self: *const Buffer, api: *const Api) ApiError!void { _ = try api.call(.PJRT_Buffer_DecreaseExternalReferenceCount, .{ .buffer = self.inner(), }); } }; pub const Event = opaque { const inner = InnerMixin(c.PJRT_Event).inner; pub fn deinit(self: *Event, api: *const Api) void { _ = api.call(.PJRT_Event_Destroy, .{ .event = self.inner(), }) catch unreachable; } pub fn isReady(self: *const Event, api: *const Api) bool { const ret = api.call(.PJRT_Event_IsReady, .{ .event = self.inner(), }) catch unreachable; return ret.is_ready; } pub fn getEventError(self: *const Event, api: *const Api) ?*Error { var args: Api.CallFnArgType(.PJRT_Event_Error) = .{ .event = self.inner() }; args = pjrtStruct(args); const result: ?*c.PJRT_Error = api.inner.PJRT_Event_Error.?(&args); return @ptrCast(result); } pub fn await(self: *const Event, api: *const Api) ApiError!void { _ = try api.call(.PJRT_Event_Await, .{ .event = self.inner(), }); } pub fn onReady(self: *Event, api: *const Api, func: *const fn (err: ?*Error, user_arg: ?*anyopaque) callconv(.c) void, user_arg: ?*anyopaque) ApiError!void { _ = try api.call(.PJRT_Event_OnReady, .{ .event = self.inner(), .callback = @ptrCast(func), .user_arg = user_arg, }); } }; pub const Memory = opaque { pub const Kind = enum { device, pinned_host, unpinned_host, }; const inner = InnerMixin(c.PJRT_Memory).inner; pub fn id(self: *const Memory, api: *const Api) usize { const ret = api.call(.PJRT_Memory_Id, .{ .memory = self.inner() }) catch unreachable; return @intCast(ret.id); } pub fn kind(self: *const Memory, api: *const Api) Kind { const ret = api.call(.PJRT_Memory_Kind, .{ .memory = self.inner() }) catch unreachable; const kind_ = ret.kind orelse unreachable; return std.meta.stringToEnum(Kind, kind_[0..ret.kind_size]) orelse unreachable; } pub fn kindId(self: *const Memory, api: *const Api) u32 { const ret = api.call(.PJRT_Memory_Kind_Id, .{ .memory = self.inner(), }) catch unreachable; return @bitCast(ret.kind_id); } pub fn debugString(self: *const Memory, api: *const Api) []const u8 { const ret = api.call(.PJRT_Memory_DebugString, .{ .memory = self.inner(), }) catch unreachable; if (ret.debug_string) |debug_string| { return debug_string[0..ret.debug_string_size]; } return &.{}; } pub fn toString(self: *const Memory, api: *const Api) []const u8 { const ret = api.call(.PJRT_Memory_ToString, .{ .memory = self.inner(), }) catch unreachable; if (ret.to_string) |to_string| { return to_string[0..ret.to_string_size]; } return &.{}; } pub fn addressableByDevices(self: *const Memory, api: *const Api) []*Device { const ret = api.call(.PJRT_Memory_AddressableByDevices, .{ .event = self.inner(), }) catch unreachable; if (ret.devices) |devices| { return devices[0..ret.num_devices]; } return &.{}; } }; pub const AsyncHostToDeviceTransferManager = opaque { const inner = InnerMixin(c.PJRT_AsyncHostToDeviceTransferManager).inner; pub fn deinit(self: *AsyncHostToDeviceTransferManager, api: *const Api) void { _ = api.call(.PJRT_AsyncHostToDeviceTransferManager_Destroy, .{ .transfer_manager = self.inner(), }) catch unreachable; } pub fn transferData(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize, data: []const u8, offset: i64, is_last_transfer: bool) ApiError!*Event { const ret = try api.call(.PJRT_AsyncHostToDeviceTransferManager_TransferData, .{ .transfer_manager = self.inner(), .buffer_index = @intCast(buffer_index), .data = data.ptr, .offset = offset, .transfer_size = @intCast(data.len), .is_last_transfer = is_last_transfer, }); return @ptrCast(ret.done_with_h2d_transfer.?); } pub fn retrieveBuffer(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize) ApiError!*Buffer { const ret = try api.call(.PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer, .{ .transfer_manager = self.inner(), .buffer_index = @intCast(buffer_index), }); return @ptrCast(ret.buffer_out.?); } pub fn device(self: *AsyncHostToDeviceTransferManager, api: *const Api) ApiError!*Device { const ret = try api.call(.PJRT_AsyncHostToDeviceTransferManager_Device, .{ .transfer_manager = self.inner(), }); return @ptrCast(ret.device_out.?); } pub fn bufferCount(self: *AsyncHostToDeviceTransferManager, api: *const Api) ApiError!usize { const ret = try api.call(.PJRT_AsyncHostToDeviceTransferManager_BufferCount, .{ .transfer_manager = self.inner(), }); return ret.buffer_count; } pub fn bufferSize(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize) ApiError!usize { const ret = try api.call(.PJRT_AsyncHostToDeviceTransferManager_BufferSize, .{ .transfer_manager = self.inner(), .buffer_index = @intCast(buffer_index), }); return ret.buffer_size; } pub fn setBufferError(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize, error_code: c.PJRT_Error_Code, error_message: []const u8) ApiError!void { _ = try api.call(.PJRT_AsyncHostToDeviceTransferManager_SetBufferError, .{ .transfer_manager = self.inner(), .buffer_index = @intCast(buffer_index), .error_code = error_code, .error_message = error_message.ptr, .error_message_size = error_message.len, }); } pub fn addMetadata(self: *AsyncHostToDeviceTransferManager, api: *const Api, transfer_metadata: []const NamedValue) ApiError!void { _ = try api.call(.PJRT_AsyncHostToDeviceTransferManager_AddMetadata, .{ .transfer_manager = self.inner(), .transfer_metadata = transfer_metadata.ptr, .num_metadata = transfer_metadata.len, }); } }; pub const NamedValue = extern struct { comptime { std.debug.assert(@sizeOf(NamedValue) == @sizeOf(c.PJRT_NamedValue)); } inner: c.PJRT_NamedValue, pub const Kind = enum(c.PJRT_NamedValue_Type) { string = c.PJRT_NamedValue_kString, int64 = c.PJRT_NamedValue_kInt64, int64list = c.PJRT_NamedValue_kInt64List, float = c.PJRT_NamedValue_kFloat, bool = c.PJRT_NamedValue_kBool, }; pub fn kind(self: NamedValue) Kind { return @enumFromInt(self.inner.type); } pub fn name(self: NamedValue) []const u8 { return self.inner.name[0..self.inner.name_size]; } pub fn from(name_: []const u8, value: anytype) NamedValue { return switch (@TypeOf(value)) { []u8, []const u8 => fromString(name_, value), i64 => fromInt64(name_, value), []i64, []const i64 => fromInt64List(name_, value), f32 => fromFloat(name_, value), bool => fromBool(name_, value), else => fromString(name_, @tagName(value)), }; } pub fn fromString(name_: []const u8, value: []const u8) NamedValue { return .{ .inner = pjrtStruct(c.PJRT_NamedValue{ .name = @ptrCast(@constCast(name_.ptr)), .name_size = name_.len, .type = c.PJRT_NamedValue_kString, .unnamed_0 = .{ .string_value = @ptrCast(@constCast(value.ptr)) }, .value_size = value.len, }) }; } pub fn fromInt64(name_: []const u8, value: i64) NamedValue { return .{ .inner = pjrtStruct(c.PJRT_NamedValue{ .name = @ptrCast(@constCast(name_.ptr)), .name_size = name_.len, .type = c.PJRT_NamedValue_kInt64, .unnamed_0 = .{ .int64_value = value }, .value_size = 1, }) }; } pub fn fromInt64List(name_: []const u8, value: []const i64) NamedValue { return .{ .inner = pjrtStruct(c.PJRT_NamedValue{ .name = @ptrCast(@constCast(name_.ptr)), .name_size = name_.len, .type = c.PJRT_NamedValue_kInt64List, .unnamed_0 = .{ .int64_array_value = @ptrCast(@constCast(value.ptr)) }, .value_size = value.len, }) }; } pub fn fromFloat(name_: []const u8, value: f32) NamedValue { return .{ .inner = pjrtStruct(c.PJRT_NamedValue{ .name = @ptrCast(@constCast(name_.ptr)), .name_size = name_.len, .type = c.PJRT_NamedValue_kFloat, .unnamed_0 = .{ .float_value = value }, .value_size = 1, }) }; } pub fn fromBool(name_: []const u8, value: bool) NamedValue { return .{ .inner = pjrtStruct(c.PJRT_NamedValue{ .name = @ptrCast(@constCast(name_.ptr)), .name_size = name_.len, .type = c.PJRT_NamedValue_kBool, .unnamed_0 = .{ .bool_value = value }, .value_size = 1, }) }; } pub fn format(self: NamedValue, writer: *std.Io.Writer) !void { try writer.print("{s}{{ .name = {s},", .{ @typeName(NamedValue), self.inner.name[0..self.inner.name_size] }); const u = self.inner.unnamed_0; switch (self.kind()) { .string => try writer.print(" .string = {s} ", .{u.string_value[0..self.inner.value_size]}), .int64 => try writer.print(" .int64 = {d} ", .{u.int64_value}), .int64list => try writer.print(" .int64list = {d} ", .{u.int64_array_value[0..self.inner.value_size]}), .float => try writer.print(" .float = {d} ", .{u.float_value}), .bool => try writer.print(" .bool = {} ", .{u.bool_value}), } try writer.writeAll("}"); } }; pub const Ffi = extern struct { inner: *const c.PJRT_FFI, pub const UserData = extern struct { type_id: i64, user_data: *anyopaque, fn toCStruct(self: UserData) c.PJRT_FFI_UserData { return .{ .type_id = self.type_id, .data = self.user_data, }; } }; // todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize // introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5 pub fn register( self: *const Ffi, api: *const Api, target_name: []const u8, platform_name: []const u8, func: *const ffi.Handler, traits: ffi.HandlerTraits, ) ApiError!void { var ret = pjrtStruct(c.PJRT_FFI_Register_Handler_Args{ .target_name = target_name.ptr, .target_name_size = target_name.len, .handler = @ptrCast(@constCast(func)), .platform_name = platform_name.ptr, .platform_name_size = platform_name.len, .traits = @bitCast(traits), }); const result = self.inner.register_handler.?(&ret); if (result) |pjrt_c_error| { const pjrt_error: *Error = @ptrCast(pjrt_c_error); log.err("registerFfi error: {s}", .{pjrt_error.getMessage(api)}); return pjrt_error.getCode(api).toApiError(); } } pub fn registerTypeId(self: *const Ffi, api: *const Api, type_name: []const u8) ApiError!ffi.TypeId { var ret = pjrtStruct(c.PJRT_FFI_TypeID_Register_Args{ .type_name = type_name.ptr, .type_name_size = type_name.len, .type_id = 0, // let the plugin assign a unique type ID }); const result = self.inner.type_id_register.?(&ret); if (result) |pjrt_c_error| { const pjrt_error: *Error = @ptrCast(pjrt_c_error); return pjrt_error.getCode(api).toApiError(); } return .{ .type_id = ret.type_id }; } pub fn addUserData(self: *const Ffi, api: *const Api, context: *ExecuteContext, user_data: UserData) ApiError!void { var ret = pjrtStruct(c.PJRT_FFI_UserData_Add_Args{ .context = @ptrCast(context), .user_data = user_data.toCStruct(), }); const result = self.inner.user_data_add.?(&ret); if (result) |pjrt_c_error| { const pjrt_error: *Error = @ptrCast(pjrt_c_error); log.err("addUserData error: {s}", .{pjrt_error.getMessage(api)}); return pjrt_error.getCode(api).toApiError(); } } };