pjrt,zml: add memory bindings
This preliminary PR binds PJRT memory endpoints and adds them to `zml.Buffer`. A follow up PR will properly integrate it inside `zml.Buffer`
This commit is contained in:
parent
190c6978d2
commit
13eff4e661
@ -780,6 +780,28 @@ pub fn sharding(ctx: mlir.Context, inputs: []const mlir.Value, sharding_spec: ml
|
||||
});
|
||||
}
|
||||
|
||||
pub fn annotate_device_placement(ctx: mlir.Context, inputs: []const mlir.Value, memory_kind: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation {
|
||||
const frontend_attributes = mlir.DictionaryAttribute.init(
|
||||
ctx,
|
||||
&.{
|
||||
mlir.NamedAttribute.init(mlir.Identifier.get(ctx, "_xla_buffer_placement"), memory_kind.asAttr()),
|
||||
},
|
||||
).asAttr();
|
||||
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
|
||||
.operands = inputs,
|
||||
.results = res_types,
|
||||
.attributes = &.{
|
||||
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, 1).asAttr() },
|
||||
.{ "call_target_name", mlir.StringAttribute.init(ctx, "annotate_device_placement").asAttr() },
|
||||
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, true).asAttr() },
|
||||
.{ "backend_config", mlir.StringAttribute.init(ctx, &.{}).asAttr() },
|
||||
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, &.{}).asAttr() },
|
||||
.{ "mhlo.frontend_attributes", frontend_attributes },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
}
|
||||
|
||||
pub const DotDimensionNumbersAttribute = struct {
|
||||
_inner: c.MlirAttribute,
|
||||
|
||||
|
||||
309
pjrt/pjrt.zig
309
pjrt/pjrt.zig
@ -142,15 +142,21 @@ pub const Api = struct {
|
||||
};
|
||||
}
|
||||
|
||||
pub fn stablehloCurrentVersion(self: *const Api, buf: []u8) ?[]u8 {
|
||||
pub fn stablehloCurrentVersion(self: *const Api) ?[]const u8 {
|
||||
const state = struct {
|
||||
var buf: [32]u8 = undefined;
|
||||
var str: ?[:0]const u8 = null;
|
||||
};
|
||||
if (state.str) |str| {
|
||||
return str;
|
||||
}
|
||||
if (self.getPluginAttribute("stablehlo_current_version")) |v| {
|
||||
stdx.debug.assert(v.kind() == .int64list, "fetched attribute \"stablehlo_current_version\" from the plugin with type `{}`, expected `.int64list`", .{v.kind()});
|
||||
stdx.debug.assert(v.inner.value_size == 3, "expect version format to have 3 elements representing `major.minor.patch` format, got {} elements", .{v.inner.value_size});
|
||||
const value = v.inner.unnamed_0.int64_array_value[0..v.inner.value_size];
|
||||
return std.fmt.bufPrint(buf, "{d}.{d}.{d}", .{ value[0], value[1], value[2] }) catch unreachable;
|
||||
state.str = std.fmt.bufPrintZ(&state.buf, "{d}.{d}.{d}", .{ value[0], value[1], value[2] }) catch unreachable;
|
||||
}
|
||||
|
||||
return null;
|
||||
return state.str;
|
||||
}
|
||||
|
||||
pub fn customCallRegistry(api: *const Api) ?CustomCallRegistry {
|
||||
@ -203,22 +209,22 @@ pub const ErrorCode = enum(c.PJRT_Error_Code) {
|
||||
|
||||
pub fn toApiError(code: ErrorCode) ApiError {
|
||||
return switch (code) {
|
||||
.cancelled => error.Cancelled,
|
||||
.unknown => error.Unknown,
|
||||
.invalid_argument => error.InvalidArgument,
|
||||
.deadline_exceeded => error.DeadlineExceeded,
|
||||
.not_found => error.NotFound,
|
||||
.already_exists => error.AlreadyExists,
|
||||
.permission_denied => error.PermissionDenied,
|
||||
.resource_exhausted => error.ResourceExhausted,
|
||||
.failed_precondition => error.FailedPrecondition,
|
||||
.aborted => error.Aborted,
|
||||
.out_of_range => error.OutOfRange,
|
||||
.unimplemented => error.Unimplemented,
|
||||
.internal => error.Internal,
|
||||
.unavailable => error.Unavailable,
|
||||
.data_loss => error.DataLoss,
|
||||
.unauthenticated => error.Unauthenticated,
|
||||
.cancelled => ApiError.Cancelled,
|
||||
.unknown => ApiError.Unknown,
|
||||
.invalid_argument => ApiError.InvalidArgument,
|
||||
.deadline_exceeded => ApiError.DeadlineExceeded,
|
||||
.not_found => ApiError.NotFound,
|
||||
.already_exists => ApiError.AlreadyExists,
|
||||
.permission_denied => ApiError.PermissionDenied,
|
||||
.resource_exhausted => ApiError.ResourceExhausted,
|
||||
.failed_precondition => ApiError.FailedPrecondition,
|
||||
.aborted => ApiError.Aborted,
|
||||
.out_of_range => ApiError.OutOfRange,
|
||||
.unimplemented => ApiError.Unimplemented,
|
||||
.internal => ApiError.Internal,
|
||||
.unavailable => ApiError.Unavailable,
|
||||
.data_loss => ApiError.DataLoss,
|
||||
.unauthenticated => ApiError.Unauthenticated,
|
||||
};
|
||||
}
|
||||
};
|
||||
@ -247,6 +253,32 @@ pub const Error = opaque {
|
||||
|
||||
pub const ClientInitError = error{LoadingFailed} || ApiError;
|
||||
|
||||
pub const ShapeSpec = extern struct {
|
||||
comptime {
|
||||
std.debug.assert(@sizeOf(ShapeSpec) == @sizeOf(c.PJRT_ShapeSpec));
|
||||
}
|
||||
|
||||
inner: c.PJRT_ShapeSpec,
|
||||
|
||||
pub fn init(dims_: []const usize, bt: BufferType) ShapeSpec {
|
||||
return .{
|
||||
.inner = pjrtStruct(c.PJRT_ShapeSpec{
|
||||
.dims = @ptrCast(@constCast(dims_.ptr)),
|
||||
.num_dims = dims.len,
|
||||
.buffer_type = @intFromEnum(bt),
|
||||
}),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn dims(self: ShapeSpec) []usize {
|
||||
return self.inner.dims[0..self.inner.num_dims];
|
||||
}
|
||||
|
||||
pub fn bufferType(self: ShapeSpec) BufferType {
|
||||
return @enumFromInt(self.inner.buffer_type);
|
||||
}
|
||||
};
|
||||
|
||||
pub const Client = opaque {
|
||||
const inner = InnerMixin(c.PJRT_Client).inner;
|
||||
|
||||
@ -322,8 +354,9 @@ pub const Client = opaque {
|
||||
buffer_type: BufferType,
|
||||
dims: []const i64,
|
||||
byte_strides: ?[]const i64,
|
||||
device: *const Device,
|
||||
device: ?*const Device = null,
|
||||
host_buffer_semantics: HostBufferSemantics,
|
||||
memory: ?*const Memory = null,
|
||||
};
|
||||
|
||||
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } {
|
||||
@ -337,7 +370,7 @@ pub const Client = opaque {
|
||||
.num_byte_strides = if (args.byte_strides) |bs| bs.len else 0,
|
||||
.host_buffer_semantics = @intFromEnum(args.host_buffer_semantics),
|
||||
.device = @ptrCast(@constCast(args.device)),
|
||||
.memory = null, // TODO
|
||||
.memory = @ptrCast(@constCast(args.memory)),
|
||||
.device_layout = null, // TODO
|
||||
.done_with_host_buffer = null,
|
||||
.buffer = null,
|
||||
@ -377,7 +410,9 @@ pub const Client = opaque {
|
||||
element_type: BufferType,
|
||||
layout: MemoryLayout,
|
||||
device: *const Device,
|
||||
on_delete_callback: ?*const fn (device_buffer_ptr: ?*anyopaque, ctx: ?*anyopaque) callconv(.C) void = null,
|
||||
on_delete_callback: *const fn (device_buffer_ptr: ?*anyopaque, ctx: ?*anyopaque) callconv(.C) void = &struct {
|
||||
fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.C) void {}
|
||||
}.call,
|
||||
on_delete_callback_arg: ?*anyopaque = null,
|
||||
stream: ?isize = null,
|
||||
};
|
||||
@ -398,6 +433,50 @@ pub const Client = opaque {
|
||||
});
|
||||
return @ptrCast(ret.buffer.?);
|
||||
}
|
||||
|
||||
pub fn addressableMemories(self: *const Client, api: *const Api) []*const Memory {
|
||||
const ret = api.call(.PJRT_Client_AddressableMemories, .{
|
||||
.client = self.inner(),
|
||||
}) catch unreachable;
|
||||
if (ret.addressable_memories) |memories| {
|
||||
return @constCast(@ptrCast(memories[0..ret.num_addressable_memories]));
|
||||
}
|
||||
return &.{};
|
||||
}
|
||||
|
||||
pub fn dmaMap(self: *const Client, api: *const Api, data: []const u8) ApiError!*Buffer {
|
||||
const ret = try api.call(.PJRT_Client_DMA_Map, .{
|
||||
.client = self.inner(),
|
||||
.data = @ptrCast(@constCast(data.ptr)),
|
||||
.size = @intCast(data.len),
|
||||
});
|
||||
return @ptrCast(ret.buffer.?);
|
||||
}
|
||||
|
||||
pub fn dmaUnmap(self: *const Client, api: *const Api, data: []const u8) void {
|
||||
_ = api.call(.PJRT_Client_DMA_Unmap, .{
|
||||
.client = self.inner(),
|
||||
.data = @ptrCast(@constCast(data.ptr)),
|
||||
}) catch unreachable;
|
||||
}
|
||||
|
||||
pub const CreateBuffersForAsyncHostToDeviceArgs = struct {
|
||||
shape_specs: []const ShapeSpec,
|
||||
device_layouts: ?[]*const MemoryLayout = null,
|
||||
memory: *const Memory,
|
||||
};
|
||||
|
||||
pub fn createBuffersForAsyncHostToDevice(self: *const Client, api: *const Api, args: CreateBuffersForAsyncHostToDeviceArgs) ApiError!*AsyncHostToDeviceTransferManager {
|
||||
const ret = try api.call(.PJRT_Client_CreateBuffersForAsyncHostToDevice, .{
|
||||
.client = self.inner(),
|
||||
.shape_specs = @ptrCast(args.shape_specs.ptr),
|
||||
.num_shape_specs = args.shape_specs.len,
|
||||
.device_layouts = if (args.device_layouts) |layouts| @ptrCast(@constCast(layouts.ptr)) else null,
|
||||
.num_device_layouts = if (args.device_layouts) |layouts| @intCast(layouts.len) else 0,
|
||||
.memory = @ptrCast(@constCast(args.memory)),
|
||||
});
|
||||
return @ptrCast(ret.transfer_manager.?);
|
||||
}
|
||||
};
|
||||
|
||||
pub const Device = opaque {
|
||||
@ -753,6 +832,43 @@ pub const Buffer = opaque {
|
||||
});
|
||||
return ret.device_memory_ptr.?;
|
||||
}
|
||||
|
||||
pub fn copyRawToHost(self: *const Buffer, api: *const Api, dst: []u8, offset: i64) ApiError!?*Event {
|
||||
const ret = try api.call(.PJRT_Buffer_CopyRawToHost, .{
|
||||
.buffer = self.inner(),
|
||||
.dst = @ptrCast(dst.ptr),
|
||||
.offset = offset,
|
||||
.transfer_size = @intCast(dst.len),
|
||||
});
|
||||
return @ptrCast(ret.event);
|
||||
}
|
||||
|
||||
pub fn copyToMemory(self: *const Buffer, api: *const Api, dst_memory: *const Memory) ApiError!?*Buffer {
|
||||
const ret = try api.call(.PJRT_Buffer_CopyToMemory, .{
|
||||
.buffer = self.inner(),
|
||||
.dst_memory = @ptrCast(@constCast(dst_memory)),
|
||||
});
|
||||
return @ptrCast(ret.dst_buffer);
|
||||
}
|
||||
|
||||
pub fn memory(self: *const Buffer, api: *const Api) *const Memory {
|
||||
const ret = api.call(.PJRT_Buffer_Memory, .{
|
||||
.buffer = self.inner(),
|
||||
}) catch unreachable;
|
||||
return @ptrCast(ret.memory);
|
||||
}
|
||||
|
||||
pub fn increaseExternalReferenceCount(self: *const Buffer, api: *const Api) ApiError!void {
|
||||
_ = try api.call(.PJRT_Buffer_IncreaseExternalReferenceCount, .{
|
||||
.buffer = self.inner(),
|
||||
});
|
||||
}
|
||||
|
||||
pub fn decreaseExternalReferenceCount(self: *const Buffer, api: *const Api) ApiError!void {
|
||||
_ = try api.call(.PJRT_Buffer_DecreaseExternalReferenceCount, .{
|
||||
.buffer = self.inner(),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
pub const Event = opaque {
|
||||
@ -793,6 +909,153 @@ pub const Event = opaque {
|
||||
}
|
||||
};
|
||||
|
||||
pub const Memory = opaque {
|
||||
pub const Kind = enum {
|
||||
device,
|
||||
pinned_host,
|
||||
unpinned_host,
|
||||
};
|
||||
|
||||
const inner = InnerMixin(c.PJRT_Memory).inner;
|
||||
|
||||
pub fn id(self: *const Memory, api: *const Api) usize {
|
||||
const ret = api.call(.PJRT_Memory_Id, .{
|
||||
.memory = self.inner(),
|
||||
}) catch unreachable;
|
||||
return @intCast(ret.id);
|
||||
}
|
||||
|
||||
pub fn kind(self: *const Memory, api: *const Api) Kind {
|
||||
const ret = api.call(.PJRT_Memory_Kind, .{
|
||||
.memory = self.inner(),
|
||||
}) catch unreachable;
|
||||
const kind_ = ret.kind orelse unreachable;
|
||||
return std.meta.stringToEnum(Kind, kind_[0..ret.kind_size]) orelse unreachable;
|
||||
}
|
||||
|
||||
pub fn kindId(self: *const Memory, api: *const Api) u32 {
|
||||
const ret = api.call(.PJRT_Memory_Kind_Id, .{
|
||||
.memory = self.inner(),
|
||||
}) catch unreachable;
|
||||
return @bitCast(ret.kind_id);
|
||||
}
|
||||
|
||||
pub fn debugString(self: *const Memory, api: *const Api) []const u8 {
|
||||
const ret = api.call(.PJRT_Memory_DebugString, .{
|
||||
.memory = self.inner(),
|
||||
}) catch unreachable;
|
||||
if (ret.debug_string) |debug_string| {
|
||||
return debug_string[0..ret.debug_string_size];
|
||||
}
|
||||
return &.{};
|
||||
}
|
||||
|
||||
pub fn toString(self: *const Memory, api: *const Api) []const u8 {
|
||||
const ret = api.call(.PJRT_Memory_ToString, .{
|
||||
.memory = self.inner(),
|
||||
}) catch unreachable;
|
||||
if (ret.to_string) |to_string| {
|
||||
return to_string[0..ret.to_string_size];
|
||||
}
|
||||
return &.{};
|
||||
}
|
||||
|
||||
pub fn addressableByDevices(self: *const Memory, api: *const Api) []*Device {
|
||||
const ret = api.call(.PJRT_Memory_AddressableByDevices, .{
|
||||
.event = self.inner(),
|
||||
}) catch unreachable;
|
||||
if (ret.devices) |devices| {
|
||||
return devices[0..ret.num_devices];
|
||||
}
|
||||
return &.{};
|
||||
}
|
||||
};
|
||||
|
||||
pub const AsyncHostToDeviceTransferManager = opaque {
|
||||
const inner = InnerMixin(c.PJRT_AsyncHostToDeviceTransferManager).inner;
|
||||
|
||||
pub fn deinit(self: *AsyncHostToDeviceTransferManager, api: *const Api) void {
|
||||
_ = api.call(.PJRT_AsyncHostToDeviceTransferManager_Destroy, .{
|
||||
.transfer_manager = self.inner(),
|
||||
}) catch unreachable;
|
||||
}
|
||||
|
||||
pub fn transferData(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize, data: []const u8, offset: i64, is_last_transfer: bool) ApiError!*Event {
|
||||
const ret = try api.call(.PJRT_AsyncHostToDeviceTransferManager_TransferData, .{
|
||||
.transfer_manager = self.inner(),
|
||||
.buffer_index = @intCast(buffer_index),
|
||||
.data = data.ptr,
|
||||
.offset = offset,
|
||||
.transfer_size = @intCast(data.len),
|
||||
.is_last_transfer = is_last_transfer,
|
||||
});
|
||||
return @ptrCast(ret.done_with_h2d_transfer.?);
|
||||
}
|
||||
|
||||
pub fn retrieveBuffer(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize) ApiError!*Buffer {
|
||||
const ret = try api.call(.PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer, .{
|
||||
.transfer_manager = self.inner(),
|
||||
.buffer_index = @intCast(buffer_index),
|
||||
});
|
||||
return @ptrCast(ret.buffer_out.?);
|
||||
}
|
||||
|
||||
pub fn device(self: *AsyncHostToDeviceTransferManager, api: *const Api) ApiError!*Device {
|
||||
const ret = try api.call(.PJRT_AsyncHostToDeviceTransferManager_Device, .{
|
||||
.transfer_manager = self.inner(),
|
||||
});
|
||||
return @ptrCast(ret.device_out.?);
|
||||
}
|
||||
|
||||
pub fn bufferCount(self: *AsyncHostToDeviceTransferManager, api: *const Api) ApiError!usize {
|
||||
const ret = try api.call(.PJRT_AsyncHostToDeviceTransferManager_BufferCount, .{
|
||||
.transfer_manager = self.inner(),
|
||||
});
|
||||
return ret.buffer_count;
|
||||
}
|
||||
|
||||
pub fn bufferSize(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize) ApiError!usize {
|
||||
const ret = try api.call(.PJRT_AsyncHostToDeviceTransferManager_BufferSize, .{
|
||||
.transfer_manager = self.inner(),
|
||||
.buffer_index = @intCast(buffer_index),
|
||||
});
|
||||
return ret.buffer_size;
|
||||
}
|
||||
|
||||
pub fn setBufferError(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize, error_code: c.PJRT_Error_Code, error_message: []const u8) ApiError!void {
|
||||
_ = try api.call(.PJRT_AsyncHostToDeviceTransferManager_SetBufferError, .{
|
||||
.transfer_manager = self.inner(),
|
||||
.buffer_index = @intCast(buffer_index),
|
||||
.error_code = error_code,
|
||||
.error_message = error_message.ptr,
|
||||
.error_message_size = error_message.len,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn addMetadata(self: *AsyncHostToDeviceTransferManager, api: *const Api, transfer_metadata: []const NamedValue) ApiError!void {
|
||||
_ = try api.call(.PJRT_AsyncHostToDeviceTransferManager_AddMetadata, .{
|
||||
.transfer_manager = self.inner(),
|
||||
.transfer_metadata = transfer_metadata.ptr,
|
||||
.num_metadata = transfer_metadata.len,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
pub const ExecutionContext = opaque {
|
||||
const inner = InnerMixin(c.PJRT_ExecutionContext).inner;
|
||||
|
||||
pub fn init(api: *const Api) ApiError!*ExecutionContext {
|
||||
const ret = try api.call(.PJRT_ExecutionContext_Create, .{});
|
||||
return @ptrCast(ret.context.?);
|
||||
}
|
||||
|
||||
pub fn deinit(self: *ExecutionContext, api: *const Api) void {
|
||||
_ = api.call(.PJRT_ExecutionContext_Destroy, .{
|
||||
.context = self.inner(),
|
||||
}) catch unreachable;
|
||||
}
|
||||
};
|
||||
|
||||
pub const NamedValue = extern struct {
|
||||
comptime {
|
||||
std.debug.assert(@sizeOf(NamedValue) == @sizeOf(c.PJRT_NamedValue));
|
||||
|
||||
@ -27,6 +27,29 @@ const log = std.log.scoped(.zml);
|
||||
/// * loading weights from disk directly to the `device zml.aio.loadBuffers`
|
||||
/// * can be created by calling `HostBuffer.toDevice(platform)`.
|
||||
pub const Buffer = struct {
|
||||
pub const Memory = enum(@typeInfo(pjrt.Memory.Kind).Enum.tag_type) {
|
||||
host = @intFromEnum(pjrt.Memory.Kind.unpinned_host),
|
||||
host_pinned = @intFromEnum(pjrt.Memory.Kind.pinned_host),
|
||||
device = @intFromEnum(pjrt.Memory.Kind.device),
|
||||
};
|
||||
|
||||
pub const Shard = struct {
|
||||
api: *const pjrt.Api,
|
||||
buffer: *pjrt.Buffer,
|
||||
ready_event: ?*pjrt.Event = null,
|
||||
ready: bool = false,
|
||||
|
||||
pub fn awaitt(self: *Shard) !void {
|
||||
if (self.ready) {
|
||||
return;
|
||||
}
|
||||
if (self.ready_event orelse self.buffer.getReadyEvent(self.api)) |ev| {
|
||||
try ev.awaitt(self.api);
|
||||
}
|
||||
self.ready = true;
|
||||
}
|
||||
};
|
||||
|
||||
_shape: Shape,
|
||||
_api: *const pjrt.Api,
|
||||
_shards: Shards,
|
||||
|
||||
132
zml/pjrtx.zig
132
zml/pjrtx.zig
@ -17,7 +17,6 @@ const log = std.log.scoped(.zml);
|
||||
pub const Profiler = pjrt.Profiler;
|
||||
pub const ApiError = pjrt.ApiError;
|
||||
pub const ErrorCode = pjrt.ErrorCode;
|
||||
pub const Buffer = pjrt.Buffer;
|
||||
pub const BufferType = pjrt.BufferType;
|
||||
pub const Device = pjrt.Device;
|
||||
pub const DeviceDescription = pjrt.DeviceDescription;
|
||||
@ -30,6 +29,7 @@ pub const GetCostAnalysisError = pjrt.GetCostAnalysisError;
|
||||
pub const SerializeResult = pjrt.SerializeResult;
|
||||
pub const Executable = pjrt.Executable;
|
||||
pub const ExecuteError = ApiError;
|
||||
pub const Memory = pjrt.Memory;
|
||||
|
||||
fn InnerMixin(comptime innerT: type) type {
|
||||
return struct {
|
||||
@ -69,7 +69,7 @@ pub const Client = opaque {
|
||||
const event: *Event = @ptrCast(event__);
|
||||
try event.await_(api);
|
||||
}
|
||||
return buffer;
|
||||
return @ptrCast(buffer);
|
||||
}
|
||||
|
||||
pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {
|
||||
@ -78,12 +78,7 @@ pub const Client = opaque {
|
||||
|
||||
pub const CreateViewOfDeviceBufferArgs = pjrt.Client.CreateViewOfDeviceBufferArgs;
|
||||
pub fn createViewOfDeviceBuffer(self: *const Client, api: *const Api, args: CreateViewOfDeviceBufferArgs) ApiError!*Buffer {
|
||||
var args_ = args;
|
||||
args_.on_delete_callback = args_.on_delete_callback orelse &(struct {
|
||||
fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.C) void {}
|
||||
}.call);
|
||||
const buf = try self.inner().createViewOfDeviceBuffer(api, args_);
|
||||
return @ptrCast(buf);
|
||||
return @ptrCast(try self.inner().createViewOfDeviceBuffer(api, args));
|
||||
}
|
||||
|
||||
fn compileSync(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable {
|
||||
@ -97,13 +92,11 @@ pub const Client = opaque {
|
||||
var serialized_buffer = std.ArrayList(u8).init(allocator);
|
||||
defer serialized_buffer.deinit();
|
||||
|
||||
// spec ref: https://github.com/openxla/xla/blob/39967ad6782a861ca029ab8d1a2b25f7e0c3902b/xla/pjrt/pjrt_c_api_client.cc#L399
|
||||
var requested_stablehlo_version_buf: [32]u8 = undefined;
|
||||
const requested_stablehlo_version = api.stablehloCurrentVersion(&requested_stablehlo_version_buf);
|
||||
const stablehlo_version = if (requested_stablehlo_version) |requested_version| blk: {
|
||||
break :blk dialects.stablehlo.stablehloGetSmallerVersion(requested_version, dialects.stablehlo.getCurrentVersion());
|
||||
} else blk: {
|
||||
break :blk dialects.stablehlo.stablehloVersionFromCompatibilityRequirement(c.WEEK_12);
|
||||
const stablehlo_version = blk: {
|
||||
if (api.stablehloCurrentVersion()) |requested_version| {
|
||||
break :blk dialects.stablehlo.stablehloGetSmallerVersion(requested_version, dialects.stablehlo.getCurrentVersion());
|
||||
}
|
||||
break :blk dialects.stablehlo.getMinimumVersion();
|
||||
};
|
||||
|
||||
dialects.stablehlo.serializePortableArtifact(bytecode.items, stablehlo_version, serialized_buffer.writer()) catch |err| {
|
||||
@ -128,6 +121,79 @@ pub const Client = opaque {
|
||||
pub fn getProfiler(self: *const Client, api: *const Api, options: pjrt.Profiler.Options) pjrt.Profiler {
|
||||
return self.inner().getProfiler(api, options);
|
||||
}
|
||||
|
||||
pub fn addressableMemories(self: *const Client, api: *const Api) []*const Memory {
|
||||
return self.inner().addressableMemories(api);
|
||||
}
|
||||
|
||||
pub fn memoryByKind(self: *const Client, api: *const Api, kind: Memory.Kind) ?*Memory {
|
||||
for (self.addressableMemories(api)) |mem| {
|
||||
if (mem.kind(api) == kind) {
|
||||
return mem;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
pub const Buffer = opaque {
|
||||
pub const inner = InnerMixin(pjrt.Buffer).inner;
|
||||
|
||||
pub fn deinit(self: *Buffer, api: *const Api) void {
|
||||
self.inner().deinit(api);
|
||||
}
|
||||
|
||||
pub fn getDevice(self: *const Buffer, api: *const Api) ApiError!*Device {
|
||||
return try self.inner().getDevice(api);
|
||||
}
|
||||
|
||||
pub fn delete(self: *Buffer, api: *const Api) void {
|
||||
self.inner().delete(api);
|
||||
}
|
||||
|
||||
pub fn isDeleted(self: *const Buffer, api: *const Api) bool {
|
||||
return self.inner().isDeleted(api);
|
||||
}
|
||||
|
||||
pub fn isOnCpu(self: *const Buffer, api: *const Api) bool {
|
||||
return self.inner().isOnCpu(api);
|
||||
}
|
||||
|
||||
pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!?*Event {
|
||||
return @ptrCast(try self.inner().toHostBuffer(api, dst));
|
||||
}
|
||||
|
||||
pub fn getElementType(self: *const Buffer, api: *const Api) BufferType {
|
||||
return self.inner().getElementType(api);
|
||||
}
|
||||
|
||||
pub fn getDimensions(self: *const Buffer, api: *const Api) []const i64 {
|
||||
return self.inner().getDimensions(api);
|
||||
}
|
||||
|
||||
pub fn getUnpaddedDimensions(self: *const Buffer, api: *const Api) ApiError![]const i64 {
|
||||
return try self.inner().getUnpaddedDimensions(api);
|
||||
}
|
||||
|
||||
pub fn getOnDeviceSizeInBytes(self: *const Buffer, api: *const Api) ApiError!usize {
|
||||
return try self.inner().getOnDeviceSizeInBytes(api);
|
||||
}
|
||||
|
||||
pub fn copyToDevice(self: *const Buffer, api: *const Api, device: Device) ApiError!*Buffer {
|
||||
return @ptrCast(self.inner().copyToDevice(api, device));
|
||||
}
|
||||
|
||||
pub fn copyToMemory(self: *const Buffer, api: *const Api, memory: *const Memory) ApiError!*Buffer {
|
||||
return @ptrCast(self.inner().copyToMemory(api, memory));
|
||||
}
|
||||
|
||||
pub fn getReadyEvent(self: *const Buffer, api: *const Api) ?*Event {
|
||||
return @ptrCast(self.inner().getReadyEvent(api));
|
||||
}
|
||||
|
||||
pub fn getOpaqueDeviceMemoryDataPointer(self: *const Buffer, api: *const Api) ApiError!*anyopaque {
|
||||
return try self.inner().getOpaqueDeviceMemoryDataPointer(api);
|
||||
}
|
||||
};
|
||||
|
||||
pub const Event = opaque {
|
||||
@ -214,3 +280,39 @@ pub const LoadedExecutable = opaque {
|
||||
return try self.inner().getExecutable(api);
|
||||
}
|
||||
};
|
||||
|
||||
pub const AsyncHostToDeviceTransferManager = opaque {
|
||||
const inner = InnerMixin(pjrt.AsyncHostToDeviceTransferManager).inner;
|
||||
|
||||
pub fn deinit(self: *AsyncHostToDeviceTransferManager, api: *const Api) void {
|
||||
self.inner().deinit(api);
|
||||
}
|
||||
|
||||
pub fn transferData(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize, data: []const u8, offset: i64, is_last_transfer: bool) ApiError!*Event {
|
||||
return @ptrCast(try self.inner().transferData(api, buffer_index, data, offset, is_last_transfer));
|
||||
}
|
||||
|
||||
pub fn retrieveBuffer(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize) ApiError!*Buffer {
|
||||
return @ptrCast(try self.inner().retrieveBuffer(api, buffer_index));
|
||||
}
|
||||
|
||||
pub fn device(self: *AsyncHostToDeviceTransferManager, api: *const Api) *Device {
|
||||
return @ptrCast(self.inner().device(api));
|
||||
}
|
||||
|
||||
pub fn bufferCount(self: *AsyncHostToDeviceTransferManager, api: *const Api) usize {
|
||||
return self.inner().bufferCount(api);
|
||||
}
|
||||
|
||||
pub fn bufferSize(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize) usize {
|
||||
return self.inner().bufferSize(api, buffer_index);
|
||||
}
|
||||
|
||||
pub fn setBufferError(self: *AsyncHostToDeviceTransferManager, api: *const Api, buffer_index: usize, error_code: ErrorCode, error_message: []const u8) void {
|
||||
self.inner().setBufferError(api, buffer_index, error_code, error_message);
|
||||
}
|
||||
|
||||
pub fn addMetadata(self: *AsyncHostToDeviceTransferManager, api: *const Api, transfer_metadata: []const NamedValue) void {
|
||||
return self.inner().addMetadata(api, transfer_metadata);
|
||||
}
|
||||
};
|
||||
|
||||
Loading…
Reference in New Issue
Block a user