zml: Remove pjrtx wrapper, migrate remaining helpers to their native modules, and fix blocking issue in Event.await.
This commit is contained in:
parent
0c126c2e12
commit
dfa71018a5
@ -398,3 +398,5 @@ pub const Mutex = struct {
|
|||||||
_ = self.inner.recv();
|
_ = self.inner.recv();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub const inCoro = libcoro.inCoro;
|
||||||
|
|||||||
@ -2,7 +2,6 @@ const builtin = @import("builtin");
|
|||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const zml = @import("zml.zig");
|
const zml = @import("zml.zig");
|
||||||
const pjrt = @import("pjrtx.zig");
|
|
||||||
const c = @import("c");
|
const c = @import("c");
|
||||||
const posix = @import("posix.zig");
|
const posix = @import("posix.zig");
|
||||||
|
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const testing = std.testing;
|
const testing = std.testing;
|
||||||
|
|
||||||
const meta = @import("meta.zig");
|
|
||||||
const pjrt = @import("pjrt");
|
const pjrt = @import("pjrt");
|
||||||
|
|
||||||
|
const meta = @import("meta.zig");
|
||||||
const Context = @import("context.zig").Context;
|
const Context = @import("context.zig").Context;
|
||||||
const Data = @import("dtype.zig").Data;
|
const Data = @import("dtype.zig").Data;
|
||||||
const DataType = @import("dtype.zig").DataType;
|
const DataType = @import("dtype.zig").DataType;
|
||||||
@ -53,6 +53,7 @@ pub const Buffer = struct {
|
|||||||
const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
|
const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
|
||||||
const byte_strides = host_buffer.strides() orelse host_buffer.shape().computeStrides().constSlice();
|
const byte_strides = host_buffer.strides() orelse host_buffer.shape().computeStrides().constSlice();
|
||||||
|
|
||||||
|
var events: std.BoundedArray(*pjrt.Event, MAX_NUM_SHARDS) = .{};
|
||||||
const devices = platform.getDevices();
|
const devices = platform.getDevices();
|
||||||
for (0..n_partitions) |i| {
|
for (0..n_partitions) |i| {
|
||||||
// If no sharding if found, the given buffer is replicated on all devices.
|
// If no sharding if found, the given buffer is replicated on all devices.
|
||||||
@ -61,7 +62,7 @@ pub const Buffer = struct {
|
|||||||
break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size });
|
break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size });
|
||||||
} else host_buffer;
|
} else host_buffer;
|
||||||
|
|
||||||
const pjrt_buffer = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, .{
|
const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, .{
|
||||||
.data = buf.data,
|
.data = buf.data,
|
||||||
.buffer_type = buffer_type,
|
.buffer_type = buffer_type,
|
||||||
.dims = buf.shape().dims(),
|
.dims = buf.shape().dims(),
|
||||||
@ -70,8 +71,13 @@ pub const Buffer = struct {
|
|||||||
.host_buffer_semantics = .ImmutableUntilTransferCompletes,
|
.host_buffer_semantics = .ImmutableUntilTransferCompletes,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
events.appendAssumeCapacity(event);
|
||||||
res._shards.appendAssumeCapacity(pjrt_buffer);
|
res._shards.appendAssumeCapacity(pjrt_buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (events.constSlice()) |event| {
|
||||||
|
try platform.awaitEvent(event);
|
||||||
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,18 +1,17 @@
|
|||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const mlir = @import("mlir");
|
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
|
const mlir = @import("mlir");
|
||||||
|
const pjrt = @import("pjrt");
|
||||||
|
|
||||||
const platform = @import("platform.zig");
|
const platform = @import("platform.zig");
|
||||||
const pjrtx = @import("pjrtx.zig");
|
|
||||||
|
|
||||||
const available_targets = @import("platform.zig").available_targets;
|
|
||||||
const Target = @import("platform.zig").Target;
|
const Target = @import("platform.zig").Target;
|
||||||
const Platform = @import("platform.zig").Platform;
|
const Platform = @import("platform.zig").Platform;
|
||||||
|
|
||||||
const log = std.log.scoped(.zml);
|
const log = std.log.scoped(.zml);
|
||||||
|
|
||||||
const PjrtApiMap = std.EnumArray(Target, ?*const pjrtx.Api);
|
const PjrtApiMap = std.EnumArray(Target, ?*const pjrt.Api);
|
||||||
const PlatformsMap = std.EnumArray(Target, ?Platform);
|
const PlatformsMap = std.EnumArray(Target, ?Platform);
|
||||||
|
|
||||||
/// Every program using ZML must start with a `zml.Context.init(.{});`
|
/// Every program using ZML must start with a `zml.Context.init(.{});`
|
||||||
@ -27,7 +26,7 @@ pub const Context = struct {
|
|||||||
fn call() void {
|
fn call() void {
|
||||||
inline for (platform.available_targets) |t| {
|
inline for (platform.available_targets) |t| {
|
||||||
if (canLoad(t)) {
|
if (canLoad(t)) {
|
||||||
if (pjrtx.Api.loadFrom(platformToLibrary(t))) |api| {
|
if (pjrt.Api.loadFrom(platformToLibrary(t))) |api| {
|
||||||
Context.apis.set(t, api);
|
Context.apis.set(t, api);
|
||||||
} else |_| {}
|
} else |_| {}
|
||||||
}
|
}
|
||||||
@ -107,7 +106,7 @@ pub const Context = struct {
|
|||||||
return std.mem.eql(u8, &buf, GoogleComputeEngine);
|
return std.mem.eql(u8, &buf, GoogleComputeEngine);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn pjrtApi(target: Target) *const pjrtx.Api {
|
pub fn pjrtApi(target: Target) *const pjrt.Api {
|
||||||
return Context.apis.get(target).?;
|
return Context.apis.get(target).?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
const pjrt = @import("pjrt");
|
||||||
|
|
||||||
const runfiles = @import("runfiles");
|
const runfiles = @import("runfiles");
|
||||||
|
|
||||||
@ -7,7 +8,6 @@ const xla_pb = @import("//xla:xla_proto");
|
|||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
const mlir = @import("mlir.zig");
|
const mlir = @import("mlir.zig");
|
||||||
const ops = @import("ops.zig");
|
const ops = @import("ops.zig");
|
||||||
const pjrt = @import("pjrtx.zig");
|
|
||||||
const protobuf = @import("io/protobuf");
|
const protobuf = @import("io/protobuf");
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
const aio = @import("aio.zig");
|
const aio = @import("aio.zig");
|
||||||
@ -1118,8 +1118,21 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
|||||||
}
|
}
|
||||||
|
|
||||||
const options_bytes = try options.encode(arena);
|
const options_bytes = try options.encode(arena);
|
||||||
const loaded_executable = try platform.pjrt_client.compile(platform.pjrt_api, arena, module, options_bytes);
|
|
||||||
errdefer unreachable; // errdefer loaded_executable.deinit();
|
var mlir_bytecode = std.ArrayList(u8).init(arena);
|
||||||
|
defer mlir_bytecode.deinit();
|
||||||
|
// Note: we may need to restore IR downgrade if we need to support old pjrt plugins.
|
||||||
|
module.op().writeBytecode(mlir_bytecode.writer());
|
||||||
|
|
||||||
|
const loaded_executable = try asynk.call(pjrt.Client.compile, .{
|
||||||
|
platform.pjrt_client, platform.pjrt_api, .{
|
||||||
|
.bytecode = mlir_bytecode.items,
|
||||||
|
.bytecode_format = .mlir,
|
||||||
|
.compile_options_pb = options_bytes,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
errdefer loaded_executable.deinit();
|
||||||
|
|
||||||
if (platform.compilation_options.cache_location) |compilation_cache_location| {
|
if (platform.compilation_options.cache_location) |compilation_cache_location| {
|
||||||
log.debug("Storing module to {s}", .{compilation_cache_location});
|
log.debug("Storing module to {s}", .{compilation_cache_location});
|
||||||
|
|||||||
270
zml/pjrtx.zig
270
zml/pjrtx.zig
@ -1,270 +0,0 @@
|
|||||||
const builtin = @import("builtin");
|
|
||||||
const std = @import("std");
|
|
||||||
|
|
||||||
const mlir = @import("mlir");
|
|
||||||
const pjrt = @import("pjrt");
|
|
||||||
const dtype = @import("dtype.zig");
|
|
||||||
const meta = @import("meta.zig");
|
|
||||||
const asynk = @import("async");
|
|
||||||
|
|
||||||
pub const Profiler = pjrt.Profiler;
|
|
||||||
pub const ApiError = pjrt.ApiError;
|
|
||||||
pub const ErrorCode = pjrt.ErrorCode;
|
|
||||||
|
|
||||||
const Target = @import("platform.zig").Target;
|
|
||||||
|
|
||||||
const log = std.log.scoped(.zml);
|
|
||||||
|
|
||||||
pub const Buffer = pjrt.Buffer;
|
|
||||||
pub const Device = pjrt.Device;
|
|
||||||
pub const DeviceDescription = pjrt.DeviceDescription;
|
|
||||||
pub const Api = pjrt.Api;
|
|
||||||
pub const NamedValue = pjrt.NamedValue;
|
|
||||||
pub const ClientInitError = pjrt.ClientInitError;
|
|
||||||
pub const CompileError = std.mem.Allocator.Error || ApiError;
|
|
||||||
pub const Error = pjrt.Error;
|
|
||||||
pub const GetCostAnalysisError = pjrt.GetCostAnalysisError;
|
|
||||||
pub const SerializeResult = pjrt.SerializeResult;
|
|
||||||
pub const Executable = pjrt.Executable;
|
|
||||||
pub const ExecuteError = ApiError;
|
|
||||||
|
|
||||||
test {
|
|
||||||
std.testing.refAllDecls(Client);
|
|
||||||
std.testing.refAllDecls(Event);
|
|
||||||
std.testing.refAllDecls(LoadedExecutable);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn InnerMixin(comptime innerT: type) type {
|
|
||||||
return struct {
|
|
||||||
inline fn inner(self: anytype) if (@typeInfo(@TypeOf(self)).Pointer.is_const) *const innerT else *innerT {
|
|
||||||
return @ptrCast(self);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const Client = opaque {
|
|
||||||
const inner = InnerMixin(pjrt.Client).inner;
|
|
||||||
|
|
||||||
pub fn init(api: *const Api, create_options: []const NamedValue) ClientInitError!*Client {
|
|
||||||
return @ptrCast(try pjrt.Client.init(api, create_options));
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn deinit(self: *Client, api: *const Api) void {
|
|
||||||
self.inner().deinit(api);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn getPlatformName(self: *const Client, api: *const Api) []const u8 {
|
|
||||||
return self.inner().getPlatformName(api);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn getDevices(self: *const Client, api: *const Api) []const *const Device {
|
|
||||||
return self.inner().getDevices(api);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn getAddressableDevices(self: *const Client, api: *const Api) []const *const Device {
|
|
||||||
return self.inner().getAddressableDevices(api);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const BufferFromHostBufferArgs = pjrt.Client.BufferFromHostBufferArgs;
|
|
||||||
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) !*Buffer {
|
|
||||||
const buffer, const event_ = try self.inner().bufferFromHostBuffer(api, args);
|
|
||||||
const event: *Event = @ptrCast(event_);
|
|
||||||
try event.await_(api);
|
|
||||||
return buffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {
|
|
||||||
return @ptrCast(try asynk.call(pjrt.Client.deserializeAndLoad, .{ self.inner(), api, bytes }));
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const CreateViewOfDeviceBufferArgs = pjrt.Client.CreateViewOfDeviceBufferArgs;
|
|
||||||
pub fn createViewOfDeviceBuffer(self: *const Client, api: *const Api, args: CreateViewOfDeviceBufferArgs) ApiError!*Buffer {
|
|
||||||
var args_ = args;
|
|
||||||
args_.on_delete_callback = args_.on_delete_callback orelse &(struct {
|
|
||||||
fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.C) void {}
|
|
||||||
}.call);
|
|
||||||
const buf = try self.inner().createViewOfDeviceBuffer(api, args_);
|
|
||||||
return @ptrCast(buf);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn downgradeStableHLO(self: Client, operation: mlir.Operation) mlir.Operation {
|
|
||||||
var cloned = operation.clone() catch unreachable;
|
|
||||||
cloned.walk(.pre_order, .{ .api_version = self.getApiVersion() }, struct {
|
|
||||||
const OpsStaticMap = std.StaticStringMap([]const [:0]const u8);
|
|
||||||
const convertPre40Ops = OpsStaticMap.initComptime(.{
|
|
||||||
.{ "stablehlo.broadcast", &.{"broadcast_sizes"} },
|
|
||||||
.{ "stablehlo.dynamic_slice", &.{"slice_sizes"} },
|
|
||||||
.{ "stablehlo.fft", &.{"fft_length"} },
|
|
||||||
.{ "stablehlo.pad", &.{ "edge_padding_low", "edge_padding_high", "interior_padding" } },
|
|
||||||
.{ "stablehlo.reverse", &.{"dimensions"} },
|
|
||||||
.{ "stablehlo.slice", &.{ "start_indices", "limit_indices", "strides" } },
|
|
||||||
.{ "stablehlo.transpose", &.{"permutation"} },
|
|
||||||
});
|
|
||||||
const convertOps = OpsStaticMap.initComptime(.{
|
|
||||||
.{ "stablehlo.broadcast_in_dim", &.{"broadcast_dimensions"} },
|
|
||||||
.{ "stablehlo.convolution", &.{ "window_strides", "rhs_dilation", "lhs_dilation", "window_reversal" } },
|
|
||||||
.{ "stablehlo.dynamic_broadcast_in_dim", &.{ "broadcast_dimensions", "known_expanding_dimensions", "known_nonexpanding_dimensions" } },
|
|
||||||
.{ "stablehlo.dynamic_convolution", &.{ "window_strides", "rhs_dilation", "lhs_dilation", "window_reversal" } },
|
|
||||||
.{ "stablehlo.gather", &.{"slice_sizes"} },
|
|
||||||
.{ "stablehlo.map", &.{"dimensions"} },
|
|
||||||
.{ "stablehlo.reduce", &.{"dimensions"} },
|
|
||||||
.{ "stablehlo.reduce_window", &.{ "window_dimensions", "window_strides", "base_dilations", "window_dilations" } },
|
|
||||||
.{ "stablehlo.select_and_scatter", &.{ "window_dimensions", "window_strides" } },
|
|
||||||
});
|
|
||||||
|
|
||||||
fn convert(map: OpsStaticMap, op: mlir.Operation) void {
|
|
||||||
if (map.get(op.name().str())) |attrs| {
|
|
||||||
for (attrs) |attr_name| {
|
|
||||||
if (op.getAttributeByName(attr_name)) |attr| {
|
|
||||||
if (attr.as(mlir.DenseArrayAttribute(.bool))) |attr_| {
|
|
||||||
op.setAttributeByName(attr_name, attr_.toElements().as(mlir.Attribute).?);
|
|
||||||
} else if (attr.as(mlir.DenseArrayAttribute(.i64))) |attr_| {
|
|
||||||
op.setAttributeByName(attr_name, attr_.toElements().as(mlir.Attribute).?);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn walk(wctx: anytype, op: mlir.Operation) mlir.Operation.WalkResult {
|
|
||||||
// Keep in sync with https://github.com/openxla/xla/blob/a05ff095226aa2301903c2b475017b248d2c5ef3/xla/pjrt/mlir_to_hlo.cc#L101
|
|
||||||
if (wctx.api_version.minor < 40) {
|
|
||||||
convert(convertPre40Ops, op);
|
|
||||||
}
|
|
||||||
convert(convertOps, op);
|
|
||||||
|
|
||||||
return .advance;
|
|
||||||
}
|
|
||||||
}.walk);
|
|
||||||
return cloned;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn compileSync(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable {
|
|
||||||
var buffer = std.ArrayList(u8).init(allocator);
|
|
||||||
defer buffer.deinit();
|
|
||||||
// Note: we may need to restore IR downgrade if we need to support old pjrt plugins.
|
|
||||||
module.op().writeBytecode(buffer.writer());
|
|
||||||
|
|
||||||
return @ptrCast(try self.inner().compile(api, .{
|
|
||||||
.bytecode = buffer.items,
|
|
||||||
.bytecode_format = .mlir,
|
|
||||||
.compile_options_pb = compile_options_pb,
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
fn compileSync2(self: *const Client, api: *const Api, module: []const u8, compile_options_pb: []const u8) CompileError!*LoadedExecutable {
|
|
||||||
return @ptrCast(try self.inner().compile(api, .{
|
|
||||||
.bytecode = module,
|
|
||||||
.bytecode_format = .mlir,
|
|
||||||
.compile_options_pb = compile_options_pb,
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn compile(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable {
|
|
||||||
return try asynk.call(compileSync, .{ self, api, allocator, module, compile_options_pb });
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn compile2(self: *const Client, api: *const Api, module: []const u8, compile_options_pb: []const u8) CompileError!*LoadedExecutable {
|
|
||||||
return try asynk.call(compileSync2, .{ self, api, module, compile_options_pb });
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the Profiler for this API.
|
|
||||||
/// Not all platform have a profiling api, for those the profiler object will do nothing.
|
|
||||||
/// Platforms with known profiler extensions: cuda, xpu
|
|
||||||
pub fn getProfiler(self: *const Client, api: *const Api, options: pjrt.Profiler.Options) pjrt.Profiler {
|
|
||||||
return self.inner().getProfiler(api, options);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const Event = opaque {
|
|
||||||
pub const inner = InnerMixin(pjrt.Event).inner;
|
|
||||||
|
|
||||||
pub fn deinit(self: *Event, api: *const Api) void {
|
|
||||||
self.inner().deinit(api);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn isReady(self: *const Event, api: *const Api) bool {
|
|
||||||
return self.inner().isReady(api);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn getEventError(self: *const Event, api: *const Api) ?*Error {
|
|
||||||
return self.inner().getEventError(api);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn await_(self: *Event, api: *const Api) !void {
|
|
||||||
defer self.deinit(api);
|
|
||||||
try self.inner().await_(api);
|
|
||||||
|
|
||||||
var ctx = struct {
|
|
||||||
err: ?*pjrt.Error = null,
|
|
||||||
notif: asynk.Notification,
|
|
||||||
ready: bool = false,
|
|
||||||
}{
|
|
||||||
.notif = try asynk.Notification.init(),
|
|
||||||
};
|
|
||||||
defer ctx.notif.deinit();
|
|
||||||
|
|
||||||
try self.inner().onReady(api, &(struct {
|
|
||||||
fn call(err: ?*pjrt.Error, user_arg: ?*anyopaque) callconv(.C) void {
|
|
||||||
const ctx_: *@TypeOf(ctx) = @ptrCast(@alignCast(user_arg.?));
|
|
||||||
ctx_.err = err;
|
|
||||||
@atomicStore(bool, &ctx_.ready, true, .seq_cst);
|
|
||||||
ctx_.notif.notify() catch @panic("Unable to notify");
|
|
||||||
}
|
|
||||||
}.call), &ctx);
|
|
||||||
|
|
||||||
while (!ctx.ready) {
|
|
||||||
try ctx.notif.wait();
|
|
||||||
}
|
|
||||||
if (ctx.err) |e| {
|
|
||||||
defer e.deinit(api);
|
|
||||||
return e.getCode(api).toApiError();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const LoadedExecutable = opaque {
|
|
||||||
const inner = InnerMixin(pjrt.LoadedExecutable).inner;
|
|
||||||
|
|
||||||
// pub fn deinit(self: *LoadedExecutable, api: *const Api) void {
|
|
||||||
// self.inner().deinit(api);
|
|
||||||
// }
|
|
||||||
|
|
||||||
pub fn delete(self: *LoadedExecutable, api: *const Api) void {
|
|
||||||
self.inner().delete(api);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn isDeleted(self: *const LoadedExecutable, api: *const Api) bool {
|
|
||||||
return self.inner().isDeleted(api);
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO fix me
|
|
||||||
// pub fn getAddressableDevices(self: *const LoadedExecutable, api: *const Api) []*const Device {
|
|
||||||
// return self.inner().getAddressableDevices(api);
|
|
||||||
// }
|
|
||||||
|
|
||||||
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: struct {
|
|
||||||
arguments: []const [*]const *const Buffer,
|
|
||||||
num_args: usize,
|
|
||||||
results: []const [*]*Buffer,
|
|
||||||
events: []*Event,
|
|
||||||
non_donatable_input_indices: []const i64 = &.{},
|
|
||||||
}) ExecuteError!void {
|
|
||||||
try self.inner().execute(api, .{
|
|
||||||
.num_args = args.num_args,
|
|
||||||
.arguments = @ptrCast(args.arguments),
|
|
||||||
.results = @ptrCast(args.results),
|
|
||||||
.events = @ptrCast(args.events),
|
|
||||||
.non_donatable_input_indices = args.non_donatable_input_indices,
|
|
||||||
});
|
|
||||||
|
|
||||||
for (args.events) |event| {
|
|
||||||
// TODO(Corentin): Maybe better handle the error here.
|
|
||||||
event.await_(api) catch return error.Unknown;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn getExecutable(self: *LoadedExecutable, api: *const Api) ApiError!*Executable {
|
|
||||||
return try self.inner().getExecutable(api);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
@ -1,11 +1,11 @@
|
|||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
const aio = @import("aio.zig");
|
const pjrt = @import("pjrt");
|
||||||
|
const asynk = @import("async");
|
||||||
|
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
const module = @import("module.zig");
|
const module = @import("module.zig");
|
||||||
const pjrt = @import("pjrtx.zig");
|
|
||||||
const pjrt_core = @import("pjrt");
|
|
||||||
const log = std.log.scoped(.zml);
|
const log = std.log.scoped(.zml);
|
||||||
|
|
||||||
pub const Target = enum {
|
pub const Target = enum {
|
||||||
@ -58,7 +58,7 @@ pub const Platform = struct {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getDevices(self: Platform) []const *const pjrt_core.Device {
|
pub fn getDevices(self: Platform) []const *const pjrt.Device {
|
||||||
const all_devices = self.pjrt_client.getAddressableDevices(self.pjrt_api);
|
const all_devices = self.pjrt_client.getAddressableDevices(self.pjrt_api);
|
||||||
if (all_devices.len > MAX_NUM_DEVICES) {
|
if (all_devices.len > MAX_NUM_DEVICES) {
|
||||||
return all_devices[0..MAX_NUM_DEVICES];
|
return all_devices[0..MAX_NUM_DEVICES];
|
||||||
@ -94,7 +94,40 @@ pub const Platform = struct {
|
|||||||
/// Returns the Profiler for this API.
|
/// Returns the Profiler for this API.
|
||||||
/// Not all platform have a profiling api, for those the profiler object will do nothing.
|
/// Not all platform have a profiling api, for those the profiler object will do nothing.
|
||||||
/// Platforms with known profiler extensions: cuda, xpu
|
/// Platforms with known profiler extensions: cuda, xpu
|
||||||
pub fn getProfiler(self: Platform, options: pjrt_core.Profiler.Options) pjrt_core.Profiler {
|
pub fn getProfiler(self: Platform, options: pjrt.Profiler.Options) pjrt.Profiler {
|
||||||
return self.pjrt_client.getProfiler(self.pjrt_api, options);
|
return self.pjrt_client.getProfiler(self.pjrt_api, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Suspend the current co-routine while awaiting for a pjrt event to be over.
|
||||||
|
pub fn awaitEvent(self: Platform, event: *pjrt.Event) !void {
|
||||||
|
defer event.deinit(self.pjrt_api);
|
||||||
|
// If we aren't in a coroutine just use the normal blocking api.
|
||||||
|
if (!asynk.inCoro()) {
|
||||||
|
return try event.await_(self.pjrt_api);
|
||||||
|
}
|
||||||
|
|
||||||
|
var ctx = struct {
|
||||||
|
err: ?*pjrt.Error = null,
|
||||||
|
notif: asynk.Notification,
|
||||||
|
ready: bool = false,
|
||||||
|
}{
|
||||||
|
.notif = try asynk.Notification.init(),
|
||||||
|
};
|
||||||
|
defer ctx.notif.deinit();
|
||||||
|
|
||||||
|
try event.onReady(self.pjrt_api, &(struct {
|
||||||
|
fn call(err: ?*pjrt.Error, user_arg: ?*anyopaque) callconv(.C) void {
|
||||||
|
const ctx_: *@TypeOf(ctx) = @ptrCast(@alignCast(user_arg.?));
|
||||||
|
ctx_.err = err;
|
||||||
|
@atomicStore(bool, &ctx_.ready, true, .seq_cst);
|
||||||
|
ctx_.notif.notify() catch @panic("Unable to notify");
|
||||||
|
}
|
||||||
|
}.call), &ctx);
|
||||||
|
// Suspend
|
||||||
|
try ctx.notif.wait();
|
||||||
|
if (ctx.err) |e| {
|
||||||
|
defer e.deinit(self.pjrt_api);
|
||||||
|
return e.getCode(self.pjrt_api).toApiError();
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -3,7 +3,6 @@ const std = @import("std");
|
|||||||
const assert = std.debug.assert;
|
const assert = std.debug.assert;
|
||||||
const testing = std.testing;
|
const testing = std.testing;
|
||||||
|
|
||||||
const pjrt = @import("pjrtx.zig");
|
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
const mlir = @import("mlir.zig");
|
const mlir = @import("mlir.zig");
|
||||||
const ops = @import("ops.zig");
|
const ops = @import("ops.zig");
|
||||||
@ -3487,46 +3486,6 @@ test "argMax" {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn dynamicSlice1d() void {
|
|
||||||
const zml = @import("zml.zig");
|
|
||||||
const platform = zml.testing.env();
|
|
||||||
var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator);
|
|
||||||
defer arena_state.deinit();
|
|
||||||
const allocator = arena_state.allocator();
|
|
||||||
const T = f32;
|
|
||||||
|
|
||||||
{
|
|
||||||
defer _ = arena_state.reset(.retain_capacity);
|
|
||||||
const x = try zml.Buffer.fromArray(platform, [10]T{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
|
|
||||||
const z = try zml.Buffer.scalar(platform, 4, .i32);
|
|
||||||
var comp = zml.module.CompilationContext.init(allocator, "test", platform, .{});
|
|
||||||
defer comp.deinit();
|
|
||||||
var x_tensor = x.shape();
|
|
||||||
var args: struct { i8, u63, zml.Shape } = .{ 0, 2, z.shape() };
|
|
||||||
var dynamicSlice = try zml.compileRaw(allocator, &comp, Tensor.dynamicSlice1d, &x_tensor, &args);
|
|
||||||
|
|
||||||
var res: [1]*pjrt.Buffer = undefined;
|
|
||||||
dynamicSlice.call(&.{ x._data, z._data }, &res);
|
|
||||||
try testing.expectEqual([2]T{ 4, 5 }, try zml.Buffer.fromPjrtBuffer(platform, res[0]).getValue([2]T));
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
// Strided
|
|
||||||
var x = try zml.Buffer.fromArray(platform, [2][5]T{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, 8, 9 } });
|
|
||||||
var z = try zml.Buffer.scalar(platform, 3, .i32);
|
|
||||||
|
|
||||||
var comp = zml.module.CompilationContext.init(allocator, "test", platform, .{});
|
|
||||||
defer comp.deinit();
|
|
||||||
var x_tensor = x.shape();
|
|
||||||
var args: struct { i8, u63, zml.Tensor } = .{ 1, 2, z.shape() };
|
|
||||||
var dynamicSlice = try zml.compileRaw(allocator, &comp, Tensor.dynamicSlice1d, &x_tensor, &args);
|
|
||||||
|
|
||||||
var res: [1]*pjrt.Buffer = undefined;
|
|
||||||
dynamicSlice.call(&.{ x._data, z._data }, &res);
|
|
||||||
try testing.expectEqualSlices(T, &.{ 3, 4, 8, 9 }, &(try zml.Buffer.fromPjrtBuffer(platform, res[0]).getValue([4]T)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
test "dynamicUpdateSlice1d" {
|
test "dynamicUpdateSlice1d" {
|
||||||
const zml = @import("zml.zig");
|
const zml = @import("zml.zig");
|
||||||
const platform = zml.testing.env();
|
const platform = zml.testing.env();
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user