2024-09-10 09:14:28 +00:00
|
|
|
const std = @import("std");
|
|
|
|
|
|
2023-05-26 15:54:15 +00:00
|
|
|
const asynk = @import("async");
|
2023-06-21 14:45:14 +00:00
|
|
|
const dialects = @import("mlir/dialects");
|
2023-05-26 15:54:15 +00:00
|
|
|
const mlir = @import("mlir");
|
|
|
|
|
const pjrt = @import("pjrt");
|
2024-09-10 09:14:28 +00:00
|
|
|
// ---------------------------------------------------------------------------
// Names re-exported unchanged from the `pjrt` module, grouped by theme.
// ---------------------------------------------------------------------------

/// `pjrt.ffi` namespace, re-exported as-is.
pub const ffi = pjrt.ffi;

// Core API handle and generic values.
pub const Api = pjrt.Api;
pub const NamedValue = pjrt.NamedValue;
pub const BufferType = pjrt.BufferType;

// Devices, memories and streams.
pub const Device = pjrt.Device;
pub const DeviceDescription = pjrt.DeviceDescription;
pub const Memory = pjrt.Memory;
pub const MemoryStats = pjrt.MemoryStats;
pub const Stream = pjrt.Stream;

// Executables and execution.
pub const Executable = pjrt.Executable;
pub const ExecuteContext = pjrt.ExecuteContext;
pub const SerializeResult = pjrt.SerializeResult;
pub const CompiledMemoryStats = pjrt.CompiledMemoryStats;

// Errors.
pub const Error = pjrt.Error;
pub const ErrorCode = pjrt.ErrorCode;
pub const ApiError = pjrt.ApiError;
pub const ClientInitError = pjrt.ClientInitError;
pub const GetCostAnalysisError = pjrt.GetCostAnalysisError;

/// Error set returned by `LoadedExecutable.execute`; an alias of `ApiError`.
pub const ExecuteError = ApiError;

/// Error set returned by `Client.compile`: allocator failures, an invalid MLIR
/// bytecode version, or any `ApiError`.
pub const CompileError = std.mem.Allocator.Error || error{InvalidMlirBytecodeVersion} || ApiError;

/// Logger scoped to `.zml`, used by the wrappers in this file.
const log = std.log.scoped(.zml);
|
|
|
|
|
|
2023-05-26 15:54:15 +00:00
|
|
|
/// Builds a namespace exposing `inner`, a helper that re-types a pointer to a
/// wrapper type as a pointer to `innerT`.
///
/// Wrapper types in this file pull the helper in with
/// `const inner = InnerMixin(T).inner;` and call it as `self.inner()`.
fn InnerMixin(comptime innerT: type) type {
    return struct {
        /// Pointer-to-`innerT` type carrying the same constness as `SelfPtr`.
        fn InnerPtr(comptime SelfPtr: type) type {
            const is_const = @typeInfo(SelfPtr).pointer.is_const;
            return if (is_const) *const innerT else *innerT;
        }

        /// Casts `self` (a pointer) to a pointer to `innerT`; a const pointer
        /// yields a const pointer, a mutable pointer yields a mutable one.
        inline fn inner(self: anytype) InnerPtr(@TypeOf(self)) {
            return @ptrCast(self);
        }
    };
}
|
|
|
|
|
|
|
|
|
|
/// Typed wrapper around `pjrt.Client`.
///
/// Each method forwards to the wrapped `pjrt.Client`; pointers returned by the
/// wrapped type are re-typed to this file's wrapper types with `@ptrCast`.
pub const Client = opaque {
    const inner = InnerMixin(pjrt.Client).inner;

    /// Creates a client via `pjrt.Client.init` using `options`.
    pub fn init(api: *const Api, options: []const NamedValue) ClientInitError!*Client {
        return @ptrCast(try pjrt.Client.init(api, options));
    }

    /// Forwards to `pjrt.Client.deinit`.
    pub fn deinit(self: *Client, api: *const Api) void {
        self.inner().deinit(api);
    }

    /// Forwards to `pjrt.Client.getPlatformName`.
    pub fn getPlatformName(self: *const Client, api: *const Api) []const u8 {
        return self.inner().getPlatformName(api);
    }

    /// Forwards to `pjrt.Client.getDevices`.
    pub fn getDevices(self: *const Client, api: *const Api) []const *const Device {
        return self.inner().getDevices(api);
    }

    /// Forwards to `pjrt.Client.getAddressableDevices`.
    pub fn getAddressableDevices(self: *const Client, api: *const Api) []const *const Device {
        return self.inner().getAddressableDevices(api);
    }

    pub const BufferFromHostBufferArgs = pjrt.Client.BufferFromHostBufferArgs;

    /// Forwards to `pjrt.Client.bufferFromHostBuffer` and returns the resulting
    /// buffer together with the (optional) event the wrapped call produced.
    pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } {
        const buffer, const event_ = try self.inner().bufferFromHostBuffer(api, args);
        return .{ @ptrCast(buffer), @ptrCast(event_) };
    }

    /// Runs `pjrt.Client.deserializeAndLoad` on `bytes` through
    /// `asynk.callBlocking` and returns the loaded executable.
    pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {
        return @ptrCast(try asynk.callBlocking(pjrt.Client.deserializeAndLoad, .{ self.inner(), api, bytes }));
    }

    pub const CreateViewOfDeviceBufferArgs = pjrt.Client.CreateViewOfDeviceBufferArgs;

    /// Forwards to `pjrt.Client.createViewOfDeviceBuffer`.
    pub fn createViewOfDeviceBuffer(self: *const Client, api: *const Api, args: CreateViewOfDeviceBufferArgs) ApiError!*Buffer {
        return @ptrCast(try self.inner().createViewOfDeviceBuffer(api, args));
    }

    /// Synchronous body of `compile`.
    ///
    /// Steps:
    ///  1. write `module` as MLIR bytecode into a temporary buffer;
    ///  2. pick a StableHLO version: the smaller of the version reported by
    ///     `api.stablehloCurrentVersion()` and the dialect's current version,
    ///     or the dialect's minimum version when the API reports none;
    ///  3. serialize the bytecode to a portable artifact at that version;
    ///  4. pass the artifact to the wrapped client's `compile`.
    ///
    /// Both temporary buffers are allocated from `allocator` and released
    /// before returning. Failures in steps 1 and 3 are logged, then returned.
    fn compileSync(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable {
        var bytecode = std.ArrayList(u8).init(allocator);
        defer bytecode.deinit();
        module.op().writeBytecodeWithConfig(bytecode.writer(), .{ .desiredEmitedVersion = 1 }) catch |err| {
            log.err("failed to write module bytecode: {}", .{err});
            return err;
        };

        var serialized_buffer = std.ArrayList(u8).init(allocator);
        defer serialized_buffer.deinit();

        const stablehlo_version = blk: {
            if (api.stablehloCurrentVersion()) |requested_version| {
                break :blk dialects.stablehlo.stablehloGetSmallerVersion(requested_version, dialects.stablehlo.getCurrentVersion());
            }
            break :blk dialects.stablehlo.getMinimumVersion();
        };

        dialects.stablehlo.serializePortableArtifact(bytecode.items, stablehlo_version, serialized_buffer.writer()) catch |err| {
            log.err("failed to serialize to portable artifact: {}", .{err});
            return err;
        };

        return @ptrCast(try self.inner().compile(api, .{
            .bytecode = serialized_buffer.items,
            .bytecode_format = .mlir,
            .compile_options_pb = compile_options_pb,
        }));
    }

    /// Compiles `module` with the serialized options `compile_options_pb`,
    /// running `compileSync` through `asynk.callBlocking`.
    pub fn compile(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable {
        return try asynk.callBlocking(compileSync, .{ self, api, allocator, module, compile_options_pb });
    }

    /// Forwards to `pjrt.Client.addressableMemories`.
    pub fn addressableMemories(self: *const Client, api: *const Api) []*const Memory {
        return self.inner().addressableMemories(api);
    }

    /// Returns the first addressable memory whose `kind(api)` equals `kind`,
    /// or `null` when none matches.
    pub fn memoryByKind(self: *const Client, api: *const Api, kind: Memory.Kind) ?*const Memory {
        for (self.addressableMemories(api)) |mem| {
            if (mem.kind(api) == kind) {
                return mem;
            }
        }
        return null;
    }

    pub const CreateUninitializedBufferArgs = pjrt.Client.CreateUninitializedBufferArgs;

    /// Forwards to `pjrt.Client.createUninitializedBuffer`.
    pub fn createUninitializedBuffer(self: *const Client, api: *const Api, args: CreateUninitializedBufferArgs) ApiError!*Buffer {
        return @ptrCast(try self.inner().createUninitializedBuffer(api, args));
    }

    /// Misspelled historical name of `createUninitializedBuffer`, kept so that
    /// existing callers keep compiling. Prefer `createUninitializedBuffer`.
    pub const createUnitializedBuffer = createUninitializedBuffer;
};
|
|
|
|
|
|
|
|
|
|
/// Typed wrapper around `pjrt.Buffer`.
///
/// Each method forwards to the wrapped `pjrt.Buffer`; buffer and event
/// pointers returned by the wrapped type are re-typed with `@ptrCast`.
pub const Buffer = opaque {
    pub const inner = InnerMixin(pjrt.Buffer).inner;

    /// Forwards to `pjrt.Buffer.deinit`.
    pub fn deinit(self: *Buffer, api: *const Api) void {
        self.inner().deinit(api);
    }

    /// Forwards to `pjrt.Buffer.getDevice`.
    pub fn getDevice(self: *const Buffer, api: *const Api) ApiError!*Device {
        return try self.inner().getDevice(api);
    }

    /// Forwards to `pjrt.Buffer.delete`.
    pub fn delete(self: *Buffer, api: *const Api) void {
        self.inner().delete(api);
    }

    /// Forwards to `pjrt.Buffer.isDeleted`.
    pub fn isDeleted(self: *const Buffer, api: *const Api) bool {
        return self.inner().isDeleted(api);
    }

    /// Forwards to `pjrt.Buffer.isOnCpu`.
    pub fn isOnCpu(self: *const Buffer, api: *const Api) bool {
        return self.inner().isOnCpu(api);
    }

    /// Forwards to `pjrt.Buffer.memory`.
    pub fn memory(self: *const Buffer, api: *const Api) *const Memory {
        return self.inner().memory(api);
    }

    /// Forwards to `pjrt.Buffer.toHostBuffer` with destination `dst` and
    /// returns the (optional) event produced by the wrapped call.
    pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!?*Event {
        return @ptrCast(try self.inner().toHostBuffer(api, dst));
    }

    /// Forwards to `pjrt.Buffer.getElementType`.
    pub fn getElementType(self: *const Buffer, api: *const Api) BufferType {
        return self.inner().getElementType(api);
    }

    /// Forwards to `pjrt.Buffer.getDimensions`.
    pub fn getDimensions(self: *const Buffer, api: *const Api) []const i64 {
        return self.inner().getDimensions(api);
    }

    /// Forwards to `pjrt.Buffer.getUnpaddedDimensions`.
    pub fn getUnpaddedDimensions(self: *const Buffer, api: *const Api) ApiError![]const i64 {
        return try self.inner().getUnpaddedDimensions(api);
    }

    /// Forwards to `pjrt.Buffer.getOnDeviceSizeInBytes`.
    pub fn getOnDeviceSizeInBytes(self: *const Buffer, api: *const Api) ApiError!usize {
        return try self.inner().getOnDeviceSizeInBytes(api);
    }

    /// Forwards to `pjrt.Buffer.copyToDevice` and returns the new buffer.
    ///
    /// The wrapped call's error union is unwrapped with `try` before the
    /// pointer cast, matching `copyToMemory`; `@ptrCast` cannot be applied to
    /// an error union directly.
    ///
    /// NOTE(review): `device` is taken by value. If `Device` is an opaque type
    /// this parameter cannot be instantiated; confirm whether `*const Device`
    /// was intended (the signature is left unchanged here).
    pub fn copyToDevice(self: *const Buffer, api: *const Api, device: Device) ApiError!*Buffer {
        return @ptrCast(try self.inner().copyToDevice(api, device));
    }

    /// Forwards to `pjrt.Buffer.copyToMemory` and returns the new buffer.
    pub fn copyToMemory(self: *const Buffer, api: *const Api, memory_: *const Memory) ApiError!*Buffer {
        return @ptrCast(try self.inner().copyToMemory(api, memory_));
    }

    /// Forwards to `pjrt.Buffer.getReadyEvent`.
    pub fn getReadyEvent(self: *const Buffer, api: *const Api) ?*Event {
        return @ptrCast(self.inner().getReadyEvent(api));
    }

    /// Forwards to `pjrt.Buffer.getOpaqueDeviceMemoryDataPointer`.
    pub fn getOpaqueDeviceMemoryDataPointer(self: *const Buffer, api: *const Api) ApiError!*anyopaque {
        return try self.inner().getOpaqueDeviceMemoryDataPointer(api);
    }
};
|
|
|
|
|
|
|
|
|
|
/// Typed wrapper around `pjrt.Event`.
pub const Event = opaque {
    pub const inner = InnerMixin(pjrt.Event).inner;

    /// Forwards to `pjrt.Event.deinit`.
    pub fn deinit(self: *Event, api: *const Api) void {
        self.inner().deinit(api);
    }

    /// Forwards to `pjrt.Event.isReady`.
    pub fn isReady(self: *const Event, api: *const Api) bool {
        return self.inner().isReady(api);
    }

    /// Forwards to `pjrt.Event.getEventError`.
    pub fn getEventError(self: *const Event, api: *const Api) ?*Error {
        return self.inner().getEventError(api);
    }

    /// Waits for the event to complete, then releases it.
    ///
    /// The event is always released with `deinit` on return, whether the wait
    /// succeeded or not, so `self` must not be used afterwards.
    ///
    /// Returns the `ApiError` matching the event's error code when the event
    /// completed with an error, or the error from registering the completion
    /// callback.
    pub fn await_(self: *Event, api: *const Api) ApiError!void {
        defer self.deinit(api);

        // Fast path: the event has already completed. Its error (if any) must
        // still be reported, otherwise a failed event would look successful.
        if (self.isReady(api)) {
            if (self.getEventError(api)) |e| {
                defer e.deinit(api);
                return e.getCode(api).toApiError();
            }
            return;
        }

        // Shared between this function and the completion callback below.
        var ctx = struct {
            err: ?*pjrt.Error = null,
            event: asynk.threading.ResetEventSingle = .{},
        }{};

        try self.inner().onReady(api, &(struct {
            fn call(err: ?*pjrt.Error, user_arg: ?*anyopaque) callconv(.C) void {
                const ctx_: *@TypeOf(ctx) = @ptrCast(@alignCast(user_arg.?));
                // Record the error before signalling so the waiter sees it.
                ctx_.err = err;
                ctx_.event.set();
            }
        }.call), &ctx);
        ctx.event.wait();

        if (ctx.err) |e| {
            defer e.deinit(api);
            return e.getCode(api).toApiError();
        }
    }
};
|
|
|
|
|
|
|
|
|
|
/// Typed wrapper around `pjrt.LoadedExecutable`.
pub const LoadedExecutable = opaque {
    const inner = InnerMixin(pjrt.LoadedExecutable).inner;

    /// Forwards to `pjrt.LoadedExecutable.deinit`.
    pub fn deinit(self: *LoadedExecutable, api: *const Api) void {
        self.inner().deinit(api);
    }

    /// Forwards to `pjrt.LoadedExecutable.delete`.
    pub fn delete(self: *LoadedExecutable, api: *const Api) void {
        self.inner().delete(api);
    }

    /// Forwards to `pjrt.LoadedExecutable.isDeleted`.
    pub fn isDeleted(self: *const LoadedExecutable, api: *const Api) bool {
        const deleted = self.inner().isDeleted(api);
        return deleted;
    }

    /// Forwards to `pjrt.LoadedExecutable.getAddressableDevices`.
    pub fn getAddressableDevices(self: *const LoadedExecutable, api: *const Api) []*const Device {
        const devices = self.inner().getAddressableDevices(api);
        return devices;
    }

    /// Arguments of `execute`, expressed with this file's wrapper types.
    /// `execute` copies every field into `pjrt.LoadedExecutable.ExecuteArgs`,
    /// re-typing the buffer and event pointers.
    pub const ExecuteArgs = struct {
        arguments: []const [*]const *const Buffer,
        num_args: usize,
        results: []const [*]*Buffer,
        events: []?*Event,
        non_donatable_input_indices: []const i64 = &.{},
        context: ?*ExecuteContext,
    };

    /// Runs `pjrt.LoadedExecutable.execute` with `args` through
    /// `asynk.callBlocking`.
    pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ExecuteError!void {
        // Translate the wrapper-typed arguments into the wrapped type's form.
        const inner_args = pjrt.LoadedExecutable.ExecuteArgs{
            .num_args = args.num_args,
            .arguments = @ptrCast(args.arguments),
            .results = @ptrCast(args.results),
            .events = @ptrCast(args.events),
            .non_donatable_input_indices = args.non_donatable_input_indices,
            .context = args.context,
        };
        try asynk.callBlocking(pjrt.LoadedExecutable.execute, .{ self.inner(), api, inner_args });
    }

    /// Forwards to `pjrt.LoadedExecutable.getExecutable`.
    pub fn getExecutable(self: *LoadedExecutable, api: *const Api) ApiError!*Executable {
        const executable = try self.inner().getExecutable(api);
        return executable;
    }
};
|
2024-04-11 15:43:24 +00:00
|
|
|
|
|
|
|
|
/// Typed wrapper around `pjrt.AsyncHostToDeviceTransferManager`.
///
/// Each method forwards to the wrapped type; buffer, event and device
/// pointers it returns are re-typed with `@ptrCast`.
pub const AsyncHostToDeviceTransferManager = opaque {
    const Self = @This();
    const inner = InnerMixin(pjrt.AsyncHostToDeviceTransferManager).inner;

    /// Forwards to the wrapped `deinit`.
    pub fn deinit(self: *Self, api: *const Api) void {
        self.inner().deinit(api);
    }

    /// Forwards to the wrapped `transferData` and returns the event it yields.
    pub fn transferData(self: *Self, api: *const Api, buffer_index: usize, data: []const u8, offset: i64, is_last_transfer: bool) ApiError!*Event {
        const event = try self.inner().transferData(api, buffer_index, data, offset, is_last_transfer);
        return @ptrCast(event);
    }

    /// Forwards to the wrapped `retrieveBuffer` for `buffer_index`.
    pub fn retrieveBuffer(self: *Self, api: *const Api, buffer_index: usize) ApiError!*Buffer {
        const buffer = try self.inner().retrieveBuffer(api, buffer_index);
        return @ptrCast(buffer);
    }

    /// Forwards to the wrapped `device`.
    pub fn device(self: *Self, api: *const Api) *Device {
        const dev = self.inner().device(api);
        return @ptrCast(dev);
    }

    /// Forwards to the wrapped `bufferCount`.
    pub fn bufferCount(self: *Self, api: *const Api) usize {
        const count = self.inner().bufferCount(api);
        return count;
    }

    /// Forwards to the wrapped `bufferSize` for `buffer_index`.
    pub fn bufferSize(self: *Self, api: *const Api, buffer_index: usize) usize {
        const size = self.inner().bufferSize(api, buffer_index);
        return size;
    }

    /// Forwards `error_code` and `error_message` for `buffer_index` to the
    /// wrapped `setBufferError`.
    pub fn setBufferError(self: *Self, api: *const Api, buffer_index: usize, error_code: ErrorCode, error_message: []const u8) void {
        self.inner().setBufferError(api, buffer_index, error_code, error_message);
    }

    /// Forwards `transfer_metadata` to the wrapped `addMetadata`.
    pub fn addMetadata(self: *Self, api: *const Api, transfer_metadata: []const NamedValue) void {
        return self.inner().addMetadata(api, transfer_metadata);
    }
};
|