From 13eff4e661fb26d028b67f473c5f22d80c1fb93d Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Thu, 11 Apr 2024 15:43:24 +0000 Subject: [PATCH] pjrt,zml: add memory bindings This preliminary PR binds PJRT memory endpoints and adds them to `zml.Buffer`. A follow up PR will properly integrate it inside `zml.Buffer` --- mlir/dialects/stablehlo.zig | 22 +++ pjrt/pjrt.zig | 309 +++++++++++++++++++++++++++++++++--- zml/buffer.zig | 23 +++ zml/pjrtx.zig | 132 +++++++++++++-- 4 files changed, 448 insertions(+), 38 deletions(-) diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig index a5c59cb..d7d8673 100644 --- a/mlir/dialects/stablehlo.zig +++ b/mlir/dialects/stablehlo.zig @@ -780,6 +780,28 @@ pub fn sharding(ctx: mlir.Context, inputs: []const mlir.Value, sharding_spec: ml }); } +pub fn annotate_device_placement(ctx: mlir.Context, inputs: []const mlir.Value, memory_kind: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation { + const frontend_attributes = mlir.DictionaryAttribute.init( + ctx, + &.{ + mlir.NamedAttribute.init(mlir.Identifier.get(ctx, "_xla_buffer_placement"), memory_kind.asAttr()), + }, + ).asAttr(); + return mlir.Operation.make(ctx, "stablehlo.custom_call", .{ + .operands = inputs, + .results = res_types, + .attributes = &.{ + .{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, 1).asAttr() }, + .{ "call_target_name", mlir.StringAttribute.init(ctx, "annotate_device_placement").asAttr() }, + .{ "has_side_effect", mlir.BoolAttribute.init(ctx, true).asAttr() }, + .{ "backend_config", mlir.StringAttribute.init(ctx, &.{}).asAttr() }, + .{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, &.{}).asAttr() }, + .{ "mhlo.frontend_attributes", frontend_attributes }, + }, + .location = location, + }); +} + pub const DotDimensionNumbersAttribute = struct { _inner: c.MlirAttribute, diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index 9cfa5fb..f48b636 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -142,15 +142,21 @@ pub const Api = struct { }; } - pub fn stablehloCurrentVersion(self: *const Api, buf: []u8) ?[]u8 { + pub fn stablehloCurrentVersion(self: *const Api) ?[]const u8 { + const state = struct { + var buf: [32]u8 = undefined; + var str: ?[:0]const u8 = null; + }; + if (state.str) |str| { + return str; + } if (self.getPluginAttribute("stablehlo_current_version")) |v| { stdx.debug.assert(v.kind() == .int64list, "fetched attribute \"stablehlo_current_version\" from the plugin with type `{}`, expected `.int64list`", .{v.kind()}); stdx.debug.assert(v.inner.value_size == 3, "expect version format to have 3 elements representing `major.minor.patch` format, got {} elements", .{v.inner.value_size}); const value = v.inner.unnamed_0.int64_array_value[0..v.inner.value_size]; - return std.fmt.bufPrint(buf, "{d}.{d}.{d}", .{ value[0], value[1], value[2] }) catch unreachable; + state.str = std.fmt.bufPrintZ(&state.buf, "{d}.{d}.{d}", .{ value[0], value[1], value[2] }) catch unreachable; } - - return null; + return state.str; } pub fn customCallRegistry(api: *const Api) ?CustomCallRegistry { @@ -203,22 +209,22 @@ pub const ErrorCode = enum(c.PJRT_Error_Code) { pub fn toApiError(code: ErrorCode) ApiError { return switch (code) { - .cancelled => error.Cancelled, - .unknown => error.Unknown, - .invalid_argument => error.InvalidArgument, - .deadline_exceeded => error.DeadlineExceeded, - .not_found => error.NotFound, - .already_exists => error.AlreadyExists, - .permission_denied => error.PermissionDenied, - .resource_exhausted => error.ResourceExhausted, - .failed_precondition => error.FailedPrecondition, - .aborted => error.Aborted, - .out_of_range => error.OutOfRange, - .unimplemented => error.Unimplemented, - .internal => error.Internal, - .unavailable => error.Unavailable, - .data_loss => error.DataLoss, - .unauthenticated => error.Unauthenticated, + .cancelled => ApiError.Cancelled, + .unknown => ApiError.Unknown, + .invalid_argument => ApiError.InvalidArgument, + .deadline_exceeded => ApiError.DeadlineExceeded, + .not_found => ApiError.NotFound, + .already_exists => ApiError.AlreadyExists, + .permission_denied => ApiError.PermissionDenied, + .resource_exhausted => ApiError.ResourceExhausted, + .failed_precondition => ApiError.FailedPrecondition, + .aborted => ApiError.Aborted, + .out_of_range => ApiError.OutOfRange, + .unimplemented => ApiError.Unimplemented, + .internal => ApiError.Internal, + .unavailable => ApiError.Unavailable, + .data_loss => ApiError.DataLoss, + .unauthenticated => ApiError.Unauthenticated, }; } }; @@ -247,6 +253,32 @@ pub const Error = opaque { pub const ClientInitError = error{LoadingFailed} || ApiError; +pub const ShapeSpec = extern struct { + comptime { + std.debug.assert(@sizeOf(ShapeSpec) == @sizeOf(c.PJRT_ShapeSpec)); + } + + inner: c.PJRT_ShapeSpec, + + pub fn init(dims_: []const usize, bt: BufferType) ShapeSpec { + return .{ + .inner = pjrtStruct(c.PJRT_ShapeSpec{ + .dims = @ptrCast(@constCast(dims_.ptr)), + .num_dims = dims.len, + .buffer_type = @intFromEnum(bt), + }), + }; + } + + pub fn dims(self: ShapeSpec) []usize { + return self.inner.dims[0..self.inner.num_dims]; + } + + pub fn bufferType(self: ShapeSpec) BufferType { + return @enumFromInt(self.inner.buffer_type); + } +}; + pub const Client = opaque { const inner = InnerMixin(c.PJRT_Client).inner; @@ -322,8 +354,9 @@ pub const Client = opaque { buffer_type: BufferType, dims: []const i64, byte_strides: ?[]const i64, - device: *const Device, + device: ?*const Device = null, host_buffer_semantics: HostBufferSemantics, + memory: ?*const Memory = null, }; pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } { @@ -337,7 +370,7 @@ pub const Client = opaque { .num_byte_strides = if (args.byte_strides) |bs| bs.len else 0, .host_buffer_semantics = @intFromEnum(args.host_buffer_semantics), .device = @ptrCast(@constCast(args.device)), - .memory = null, // TODO + .memory = @ptrCast(@constCast(args.memory)), .device_layout = null, // TODO .done_with_host_buffer = null, .buffer = null, @@ -377,7 +410,9 @@ pub const Client = opaque { element_type: BufferType, layout: MemoryLayout, device: *const Device, - on_delete_callback: ?*const fn (device_buffer_ptr: ?*anyopaque, ctx: ?*anyopaque) callconv(.C) void = null, + on_delete_callback: *const fn (device_buffer_ptr: ?*anyopaque, ctx: ?*anyopaque) callconv(.C) void = &struct { + fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.C) void {} + }.call, on_delete_callback_arg: ?*anyopaque = null, stream: ?isize = null, }; @@ -398,6 +433,50 @@ pub const Client = opaque { }); return @ptrCast(ret.buffer.?); } + + pub fn addressableMemories(self: *const Client, api: *const Api) []*const Memory { + const ret = api.call(.PJRT_Client_AddressableMemories, .{ + .client = self.inner(), + }) catch unreachable; + if (ret.addressable_memories) |memories| { + return @constCast(@ptrCast(memories[0..ret.num_addressable_memories])); + } + return &.{}; + } + + pub fn dmaMap(self: *const Client, api: *const Api, data: []const u8) ApiError!*Buffer { + const ret = try api.call(.PJRT_Client_DMA_Map, .{ + .client = self.inner(), + .data = @ptrCast(@constCast(data.ptr)), + .size = @intCast(data.len), + }); + return @ptrCast(ret.buffer.?); + } + + pub fn dmaUnmap(self: *const Client, api: *const Api, data: []const u8) void { + _ = api.call(.PJRT_Client_DMA_Unmap, .{ + .client = self.inner(), + .data = @ptrCast(@constCast(data.ptr)), + }) catch unreachable; + } + + pub const CreateBuffersForAsyncHostToDeviceArgs = struct { + shape_specs: []const ShapeSpec, + device_layouts: ?[]*const MemoryLayout = null, + memory: *const Memory, + }; + + pub fn createBuffersForAsyncHostToDevice(self: *const Client, api: *const Api, args: CreateBuffersForAsyncHostToDeviceArgs) ApiError!*AsyncHostToDeviceTransferManager { + const ret = try api.call(.PJRT_Client_CreateBuffersForAsyncHostToDevice, .{ + .client = self.inner(), + .shape_specs = @ptrCast(args.shape_specs.ptr), + .num_shape_specs = args.shape_specs.len, + .device_layouts = if (args.device_layouts) |layouts| @ptrCast(@constCast(layouts.ptr)) else null, + .num_device_layouts = if (args.device_layouts) |layouts| @intCast(layouts.len) else 0, + .memory = @ptrCast(@constCast(args.memory)), + }); + return @ptrCast(ret.transfer_manager.?); + } }; pub const Device = opaque { @@ -753,6 +832,43 @@ pub const Buffer = opaque { }); return ret.device_memory_ptr.?; } + + pub fn copyRawToHost(self: *const Buffer, api: *const Api, dst: []u8, offset: i64) ApiError!?*Event { + const ret = try api.call(.PJRT_Buffer_CopyRawToHost, .{ + .buffer = self.inner(), + .dst = @ptrCast(dst.ptr), + .offset = offset, + .transfer_size = @intCast(dst.len), + }); + return @ptrCast(ret.event); + } + + pub fn copyToMemory(self: *const Buffer, api: *const Api, dst_memory: *const Memory) ApiError!?*Buffer { + const ret = try api.call(.PJRT_Buffer_CopyToMemory, .{ + .buffer = self.inner(), + .dst_memory = @ptrCast(@constCast(dst_memory)), + }); + return @ptrCast(ret.dst_buffer); + } + + pub fn memory(self: *const Buffer, api: *const Api) *const Memory { + const ret = api.call(.PJRT_Buffer_Memory, .{ + .buffer = self.inner(), + }) catch unreachable; + return @ptrCast(ret.memory); + } + + pub fn increaseExternalReferenceCount(self: *const Buffer, api: *const Api) ApiError!void { + _ = try api.call(.PJRT_Buffer_IncreaseExternalReferenceCount, .{ + .buffer = self.inner(), + }); + } + + pub fn decreaseExternalReferenceCount(self: *const Buffer, api: *const Api) ApiError!void { + _ = try api.call(.PJRT_Buffer_DecreaseExternalReferenceCount, .{ + .buffer = self.inner(), + }); + } }; pub const Event = opaque { @@ -793,6 +909,153 @@ pub const Event = opaque { } }; +pub const Memory = opaque { + pub const Kind = enum { + device, + pinned_host, + unpinned_host, + }; + + const inner = InnerMixin(c.PJRT_Memory).inner; + + pub fn id(self: *const Memory, api: *const Api) usize { + const ret = api.call(.PJRT_Memory_Id, .{ + .memory = self.inner(), + }) catch unreachable; + return @intCast(ret.id); + } + + pub fn kind(self: *const Memory, api: *const Api) Kind { + const ret = api.call(.PJRT_Memory_Kind, .{ + .memory = self.inner(), + }) catch unreachable; + const kind_ = ret.kind orelse unreachable; + return std.meta.stringToEnum(Kind, kind_[0..ret.kind_size]) orelse unreachable; + } + + pub fn kindId(self: *const Memory, api: *const Api) u32 { + const ret = api.call(.PJRT_Memory_Kind_Id, .{ + .memory = self.inner(), + }) catch unreachable; + return @bitCast(ret.kind_id); + } + + pub fn debugString(self: *const Memory, api: *const Api) []const u8 { + const ret = api.call(.PJRT_Memory_DebugString, .{ + .memory = self.inner(), + }) catch unreachable; + if (ret.debug_string) |debug_string| { + return debug_string[0..ret.debug_string_size]; + } + return &.{}; + } + + pub fn toString(self: *const Memory, api: *const Api) []const u8 { + const ret = api.call(.PJRT_Memory_ToString, .{ + .memory = self.inner(), + }) catch unreachable; + if (ret.to_string) |to_string| { + return to_string[0..ret.to_string_size]; + } + return &.{}; + } + + pub fn addressableByDevices(self: *const Memory, api: *const Api) []*Device { + const ret = api.call(.PJRT_Memory_AddressableByDevices, .{ + .event = self.inner(), + }) catch unreachable; + if (ret.devices) |devices| { + return devices[0..ret.num_devices]; + } + return &.{}; + } +}; + +pub const AsyncHostToDeviceTransferManager = opaque { + const inner = InnerMixin(c.PJRT_AsyncHostToDeviceTransferManager).inner; + + pub fn deinit(self: *AsyncHostToDeviceTransferManager, api: *const Api) void { + _ = api.call(.PJRT_AsyncHostToDeviceTransferManager_Destroy, .{ + .transfer_manager = self.inner(), + }) catch unreachable; + } + + pub fn transferData(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize, data: []const u8, offset: i64, is_last_transfer: bool) ApiError!*Event { + const ret = try api.call(.PJRT_AsyncHostToDeviceTransferManager_TransferData, .{ + .transfer_manager = self.inner(), + .buffer_index = @intCast(buffer_index), + .data = data.ptr, + .offset = offset, + .transfer_size = @intCast(data.len), + .is_last_transfer = is_last_transfer, + }); + return @ptrCast(ret.done_with_h2d_transfer.?); + } + + pub fn retrieveBuffer(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize) ApiError!*Buffer { + const ret = try api.call(.PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer, .{ + .transfer_manager = self.inner(), + .buffer_index = @intCast(buffer_index), + }); + return @ptrCast(ret.buffer_out.?); + } + + pub fn device(self: *AsyncHostToDeviceTransferManager, api: *const Api) ApiError!*Device { + const ret = try api.call(.PJRT_AsyncHostToDeviceTransferManager_Device, .{ + .transfer_manager = self.inner(), + }); + return @ptrCast(ret.device_out.?); + } + + pub fn bufferCount(self: *AsyncHostToDeviceTransferManager, api: *const Api) ApiError!usize { + const ret = try api.call(.PJRT_AsyncHostToDeviceTransferManager_BufferCount, .{ + .transfer_manager = self.inner(), + }); + return ret.buffer_count; + } + + pub fn bufferSize(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize) ApiError!usize { + const ret = try api.call(.PJRT_AsyncHostToDeviceTransferManager_BufferSize, .{ + .transfer_manager = self.inner(), + .buffer_index = @intCast(buffer_index), + }); + return ret.buffer_size; + } + + pub fn setBufferError(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize, error_code: c.PJRT_Error_Code, error_message: []const u8) ApiError!void { + _ = try api.call(.PJRT_AsyncHostToDeviceTransferManager_SetBufferError, .{ + .transfer_manager = self.inner(), + .buffer_index = @intCast(buffer_index), + .error_code = error_code, + .error_message = error_message.ptr, + .error_message_size = error_message.len, + }); + } + + pub fn addMetadata(self: *AsyncHostToDeviceTransferManager, api: *const Api, transfer_metadata: []const NamedValue) ApiError!void { + _ = try api.call(.PJRT_AsyncHostToDeviceTransferManager_AddMetadata, .{ + .transfer_manager = self.inner(), + .transfer_metadata = transfer_metadata.ptr, + .num_metadata = transfer_metadata.len, + }); + } +}; + +pub const ExecutionContext = opaque { + const inner = InnerMixin(c.PJRT_ExecutionContext).inner; + + pub fn init(api: *const Api) ApiError!*ExecutionContext { + const ret = try api.call(.PJRT_ExecutionContext_Create, .{}); + return @ptrCast(ret.context.?); + } + + pub fn deinit(self: *ExecutionContext, api: *const Api) void { + _ = api.call(.PJRT_ExecutionContext_Destroy, .{ + .context = self.inner(), + }) catch unreachable; + } +}; + pub const NamedValue = extern struct { comptime { std.debug.assert(@sizeOf(NamedValue) == @sizeOf(c.PJRT_NamedValue)); diff --git a/zml/buffer.zig b/zml/buffer.zig index a9fba67..2369138 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -27,6 +27,29 @@ const log = std.log.scoped(.zml); /// * loading weights from disk directly to the `device zml.aio.loadBuffers` /// * can be created by calling `HostBuffer.toDevice(platform)`. pub const Buffer = struct { + pub const Memory = enum(@typeInfo(pjrt.Memory.Kind).Enum.tag_type) { + host = @intFromEnum(pjrt.Memory.Kind.unpinned_host), + host_pinned = @intFromEnum(pjrt.Memory.Kind.pinned_host), + device = @intFromEnum(pjrt.Memory.Kind.device), + }; + + pub const Shard = struct { + api: *const pjrt.Api, + buffer: *pjrt.Buffer, + ready_event: ?*pjrt.Event = null, + ready: bool = false, + + pub fn awaitt(self: *Shard) !void { + if (self.ready) { + return; + } + if (self.ready_event orelse self.buffer.getReadyEvent(self.api)) |ev| { + try ev.awaitt(self.api); + } + self.ready = true; + } + }; + _shape: Shape, _api: *const pjrt.Api, _shards: Shards, diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index 7408e07..ae17f51 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -17,7 +17,6 @@ const log = std.log.scoped(.zml); pub const Profiler = pjrt.Profiler; pub const ApiError = pjrt.ApiError; pub const ErrorCode = pjrt.ErrorCode; -pub const Buffer = pjrt.Buffer; pub const BufferType = pjrt.BufferType; pub const Device = pjrt.Device; pub const DeviceDescription = pjrt.DeviceDescription; @@ -30,6 +29,7 @@ pub const GetCostAnalysisError = pjrt.GetCostAnalysisError; pub const SerializeResult = pjrt.SerializeResult; pub const Executable = pjrt.Executable; pub const ExecuteError = ApiError; +pub const Memory = pjrt.Memory; fn InnerMixin(comptime innerT: type) type { return struct { @@ -69,7 +69,7 @@ pub const Client = opaque { const event: *Event = @ptrCast(event__); try event.await_(api); } - return buffer; + return @ptrCast(buffer); } pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable { @@ -78,12 +78,7 @@ pub const Client = opaque { pub const CreateViewOfDeviceBufferArgs = pjrt.Client.CreateViewOfDeviceBufferArgs; pub fn createViewOfDeviceBuffer(self: *const Client, api: *const Api, args: CreateViewOfDeviceBufferArgs) ApiError!*Buffer { - var args_ = args; - args_.on_delete_callback = args_.on_delete_callback orelse &(struct { - fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.C) void {} - }.call); - const buf = try self.inner().createViewOfDeviceBuffer(api, args_); - return @ptrCast(buf); + return @ptrCast(try self.inner().createViewOfDeviceBuffer(api, args)); } fn compileSync(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable { @@ -97,13 +92,11 @@ pub const Client = opaque { var serialized_buffer = std.ArrayList(u8).init(allocator); defer serialized_buffer.deinit(); - // spec ref: https://github.com/openxla/xla/blob/39967ad6782a861ca029ab8d1a2b25f7e0c3902b/xla/pjrt/pjrt_c_api_client.cc#L399 - var requested_stablehlo_version_buf: [32]u8 = undefined; - const requested_stablehlo_version = api.stablehloCurrentVersion(&requested_stablehlo_version_buf); - const stablehlo_version = if (requested_stablehlo_version) |requested_version| blk: { - break :blk dialects.stablehlo.stablehloGetSmallerVersion(requested_version, dialects.stablehlo.getCurrentVersion()); - } else blk: { - break :blk dialects.stablehlo.stablehloVersionFromCompatibilityRequirement(c.WEEK_12); + const stablehlo_version = blk: { + if (api.stablehloCurrentVersion()) |requested_version| { + break :blk dialects.stablehlo.stablehloGetSmallerVersion(requested_version, dialects.stablehlo.getCurrentVersion()); + } + break :blk dialects.stablehlo.getMinimumVersion(); }; dialects.stablehlo.serializePortableArtifact(bytecode.items, stablehlo_version, serialized_buffer.writer()) catch |err| { @@ -128,6 +121,79 @@ pub const Client = opaque { pub fn getProfiler(self: *const Client, api: *const Api, options: pjrt.Profiler.Options) pjrt.Profiler { return self.inner().getProfiler(api, options); } + + pub fn addressableMemories(self: *const Client, api: *const Api) []*const Memory { + return self.inner().addressableMemories(api); + } + + pub fn memoryByKind(self: *const Client, api: *const Api, kind: Memory.Kind) ?*Memory { + for (self.addressableMemories(api)) |mem| { + if (mem.kind(api) == kind) { + return mem; + } + } + return null; + } +}; + +pub const Buffer = opaque { + pub const inner = InnerMixin(pjrt.Buffer).inner; + + pub fn deinit(self: *Buffer, api: *const Api) void { + self.inner().deinit(api); + } + + pub fn getDevice(self: *const Buffer, api: *const Api) ApiError!*Device { + return try self.inner().getDevice(api); + } + + pub fn delete(self: *Buffer, api: *const Api) void { + self.inner().delete(api); + } + + pub fn isDeleted(self: *const Buffer, api: *const Api) bool { + return self.inner().isDeleted(api); + } + + pub fn isOnCpu(self: *const Buffer, api: *const Api) bool { + return self.inner().isOnCpu(api); + } + + pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!?*Event { + return @ptrCast(try self.inner().toHostBuffer(api, dst)); + } + + pub fn getElementType(self: *const Buffer, api: *const Api) BufferType { + return self.inner().getElementType(api); + } + + pub fn getDimensions(self: *const Buffer, api: *const Api) []const i64 { + return self.inner().getDimensions(api); + } + + pub fn getUnpaddedDimensions(self: *const Buffer, api: *const Api) ApiError![]const i64 { + return try self.inner().getUnpaddedDimensions(api); + } + + pub fn getOnDeviceSizeInBytes(self: *const Buffer, api: *const Api) ApiError!usize { + return try self.inner().getOnDeviceSizeInBytes(api); + } + + pub fn copyToDevice(self: *const Buffer, api: *const Api, device: Device) ApiError!*Buffer { + return @ptrCast(self.inner().copyToDevice(api, device)); + } + + pub fn copyToMemory(self: *const Buffer, api: *const Api, memory: *const Memory) ApiError!*Buffer { + return @ptrCast(self.inner().copyToMemory(api, memory)); + } + + pub fn getReadyEvent(self: *const Buffer, api: *const Api) ?*Event { + return @ptrCast(self.inner().getReadyEvent(api)); + } + + pub fn getOpaqueDeviceMemoryDataPointer(self: *const Buffer, api: *const Api) ApiError!*anyopaque { + return try self.inner().getOpaqueDeviceMemoryDataPointer(api); + } }; pub const Event = opaque { @@ -214,3 +280,39 @@ pub const LoadedExecutable = opaque { return try self.inner().getExecutable(api); } }; + +pub const AsyncHostToDeviceTransferManager = opaque { + const inner = InnerMixin(pjrt.AsyncHostToDeviceTransferManager).inner; + + pub fn deinit(self: *AsyncHostToDeviceTransferManager, api: *const Api) void { + self.inner().deinit(api); + } + + pub fn transferData(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize, data: []const u8, offset: i64, is_last_transfer: bool) ApiError!*Event { + return @ptrCast(try self.inner().transferData(api, buffer_index, data, offset, is_last_transfer)); + } + + pub fn retrieveBuffer(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize) ApiError!*Buffer { + return @ptrCast(try self.inner().retrieveBuffer(api, buffer_index)); + } + + pub fn device(self: *AsyncHostToDeviceTransferManager, api: *const Api) *Device { + return @ptrCast(self.inner().device(api)); + } + + pub fn bufferCount(self: *AsyncHostToDeviceTransferManager, api: *const Api) usize { + return self.inner().bufferCount(api); + } + + pub fn bufferSize(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize) usize { + return self.inner().bufferSize(api, buffer_index); + } + + pub fn setBufferError(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize, error_code: ErrorCode, error_message: []const u8) void { + self.inner().setBufferError(api, buffer_index, error_code, error_message); + } + + pub fn addMetadata(self: *AsyncHostToDeviceTransferManager, api: *const Api, transfer_metadata: []const NamedValue) void { + return self.inner().addMetadata(api, transfer_metadata); + } +};