470 lines
19 KiB
Zig
470 lines
19 KiB
Zig
const std = @import("std");
|
|
|
|
const stdx = @import("stdx");
|
|
|
|
const aio = @import("aio.zig");
|
|
const Buffer = @import("buffer.zig").Buffer;
|
|
const Bufferized = @import("tensor.zig").Bufferized;
|
|
const callback = @import("callback.zig");
|
|
const CompilationContext = @import("module.zig").CompilationContext;
|
|
const meta = @import("meta.zig");
|
|
const pjrt = @import("pjrtx.zig");
|
|
const Platform = @import("platform.zig").Platform;
|
|
const Shape = @import("shape.zig").Shape;
|
|
const ShapeOf = @import("tensor.zig").ShapeOf;
|
|
|
|
const log = std.log.scoped(.@"zml/exe");
|
|
|
|
// Reference all declarations so that nested `test` blocks in this file
// are discovered and run by `zig test`.
test {
    std.testing.refAllDecls(@This());
}
|
|
|
|
/// Compiles a Model struct with the given configuration and shapes, for the given platform.
/// The steps are:
/// * lookup at tensors available in the store and create a `model: Model` struct with them
/// * call `model.init(init_args)` to set the fields of the model that aren't Tensor, ie hyperparameters/config
/// * generate MLIR by calling `model.forward` with tensors of the given shapes and other arguments
///
/// Equivalent to `compileWithPrefix` with an empty weight-lookup prefix.
pub fn compile(
    allocator: std.mem.Allocator,
    comptime func: anytype,
    init_args: anytype,
    args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
    buffer_store: aio.BufferStore,
    platform: Platform,
) !FnExe(func) {
    return compileWithPrefix(allocator, func, init_args, args_shapes, buffer_store, platform, "");
}
|
|
|
|
/// Compiles a Model struct with the given configuration and shapes, for the given platform.
/// Uses a prefix for looking up model weights in the buffer store.
/// The steps are:
/// * lookup at tensors available in the store and create a `model: Model` struct with them
/// * call `model.init(init_args)` to set the fields of the model that aren't Tensor, ie hyperparameters/config
/// * generate MLIR by calling `model.forward` with tensors of the given shapes and other arguments
pub fn compileWithPrefix(
    allocator: std.mem.Allocator,
    comptime func: anytype,
    init_args: anytype,
    args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
    buffer_store: aio.BufferStore,
    platform: Platform,
    prefix: []const u8,
) !FnExe(func) {
    const Model = ModuleSignature(func).ModelT;

    // The model struct is only needed while compiling: back it with a
    // temporary arena that is freed before returning.
    var model_arena = std.heap.ArenaAllocator.init(allocator);
    defer model_arena.deinit();

    var model = try aio.populateModelWithPrefix(Model, model_arena.allocator(), buffer_store, prefix);

    // When the model declares an `init`, forward `init_args` to it so it can
    // fill in its non-Tensor fields.
    // TODO(Corentin,@Improvement): Add a warning/error if there is no init function but init_args is non-void.
    if (@hasDecl(Model, "init")) {
        @call(.auto, Model.init, .{&model} ++ init_args);
    }

    return compileModel(allocator, func, model, args_shapes, platform);
}
|
|
|
|
/// Compiles a Model struct with the given configuration and shapes, for the given platform.
/// Generate MLIR by calling `model.forward` with tensors of the given shapes and other arguments.
pub fn compileModel(
    allocator: std.mem.Allocator,
    comptime func: anytype,
    model: ModuleSignature(func).ModelT,
    args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
    platform: Platform,
) !FnExe(func) {
    // The compiled module is named after the model type: `@typeName(Model) ++ ".forward"`.
    const module_name = @typeName(ModuleSignature(func).ModelT) ++ ".forward";
    log.info("Compiling {s} with {}", .{ module_name, args_shapes });

    var comp_ctx = try CompilationContext.init(allocator, module_name, platform);
    defer comp_ctx.deinit();

    return .{ .inner = try comp_ctx.compileInternal(allocator, func, .{model} ++ args_shapes) };
}
|
|
|
|
/// Compiles a function with the given configuration and shapes, for the given platform.
/// Generate MLIR by calling the given function with tensors of the given shapes.
pub fn compileFn(
    allocator: std.mem.Allocator,
    comptime func: anytype,
    args: ShapeOf(stdx.meta.FnArgs(func)),
    platform: Platform,
) !FnExe(func) {
    // Derive a readable module name from the function's type name.
    var fn_name = try prettyFnName(func, allocator);
    defer fn_name.deinit(allocator);

    var comp_ctx = try CompilationContext.init(allocator, fn_name.items, platform);
    defer comp_ctx.deinit();

    return .{ .inner = try comp_ctx.compileInternal(allocator, func, args) };
}
|
|
|
|
/// Executable type for `func`: an `Exe` whose `call` takes all of `func`'s
/// arguments (bufferized) and returns its (bufferized) result.
pub fn FnExe(comptime func: anytype) type {
    return Exe(stdx.meta.FnArgs(func), stdx.meta.FnResult(func));
}
|
|
|
|
/// Represents a ZML model, compiled into a PJRT executable, and ready to call.
/// The buffers for the model weights are saved inside the struct and will be used in `call`.
/// You only need to pass the remaining arguments.
/// Creating a `ModuleExe` is a two steps process:
///
/// ```
/// const exe: zml.FnExe(MyModel.forward) = try zml.compile(allocator, MyModel.forward, init_args, model_shapes, buffer_store, platform);
/// const module: zml.ModuleExe(MyModel.forward) = exe.prepare(model_buffers);
/// ```
pub fn ModuleExe(comptime func: anytype) type {
    // Delegate to ModuleSignature: it performs the same "at least one
    // argument" check and already splits the model argument (head) from the
    // remaining call arguments (tail).
    const sign = ModuleSignature(func);
    return Exe(sign.ArgsT, sign.ReturnT);
}
|
|
|
|
// Making this a struct forces all fields to be evaluated on creation,
// which gives a better error stacktrace
// than delaying the error to when the object fields are read.
const Sign = struct {
    // Type of the function's first parameter (the model/module).
    ModelT: type,
    // Tuple type of the remaining parameters.
    ArgsT: type,
    // Result type of the function.
    ReturnT: type,
};
|
|
|
|
/// Splits the signature of `func` into (model type, remaining argument types, result type).
/// The first parameter of `func` is treated as the module/model.
pub fn ModuleSignature(comptime func: anytype) Sign {
    const AllArgsT = stdx.meta.FnArgs(func);
    const len = @typeInfo(AllArgsT).@"struct".fields.len;
    // The message used to say "ModuleExe expects" — a copy-paste from ModuleExe.
    // This helper backs every compile entry point, so name it accurately.
    stdx.debug.assertComptime(len > 0, "ModuleSignature expects a function with at least one argument where the first one is treated as the module, got {}", .{func});

    return .{
        .ModelT = stdx.meta.Head(AllArgsT),
        .ArgsT = stdx.meta.Tail(AllArgsT),
        .ReturnT = stdx.meta.FnResult(func),
    };
}
|
|
|
|
/// Represents an MLIR module compiled into a PJRT executable.
/// The BaseExe is a plain old struct and doesn't have information about Zig types.
///
/// It also contains pre-allocated buffers so that we can pass them to PJRT_LoadedExecutable_Execute
/// without allocations.
pub const BaseExe = struct {
    /// The platform for which this module was compiled.
    platform: Platform,

    /// The PJRT executable representing the compiled module.
    exe: *pjrt.LoadedExecutable,

    /// The execution context for this executable.
    /// Only non-null when the PJRT plugin exposes the FFI extension.
    execute_context: ?*pjrt.ExecuteContext,

    /// Pre-allocated slice of buffers to use as inputs when the module is called.
    /// One entry per device.
    input_per_device: []const [*]*pjrt.Buffer,

    /// Pre-allocated slice of buffers to use as outputs when the module is called.
    /// One entry per device.
    output_per_device: []const [*]*pjrt.Buffer,

    /// Number of buffers already fed to the executable (see `prepare`).
    ready_buffer_count: u32,

    /// Total number of buffers needed by this executable.
    input_buffer_count: u32,

    /// Shapes of the inputs (backed by `_arena`).
    input_shapes: []Shape,
    /// Shapes of the results (backed by `_arena`).
    result_shapes: []Shape,

    /// Num devices used (>1 for sharded executable)
    num_devices: u8,

    /// Allocator backing memory
    _arena: std.heap.ArenaAllocator,

    /// Allocates the per-device input/output pointer tables and shape copies
    /// in a private arena, and creates the PJRT execute context when the
    /// platform supports FFI callbacks.
    pub fn init(
        parent_allocator: std.mem.Allocator,
        platform: Platform,
        exe: *pjrt.LoadedExecutable,
        args: struct { input_shapes: []const Shape, result_shapes: []const Shape, n_devices: u8 },
    ) !BaseExe {
        var arena = std.heap.ArenaAllocator.init(parent_allocator);
        errdefer arena.deinit();
        const allocator = arena.allocator();
        const n_in = args.input_shapes.len;
        const n_out = args.result_shapes.len;
        const n_devices = args.n_devices;

        // Allocate once for all the *pjrt.Buffer we need to store ...
        const all_buffers = try allocator.alloc(*pjrt.Buffer, (n_in + n_out) * n_devices);
        const all_input_buffers, const all_output_buffers = splitBuffer(*pjrt.Buffer, all_buffers, .{ n_in * n_devices, n_out * n_devices });

        // ... and once for all the [*]*pjrt.Buffer.
        const all_per_device = try allocator.alloc([*]*pjrt.Buffer, 2 * n_devices);
        const input_per_device, const output_per_device = splitBuffer([*]*pjrt.Buffer, all_per_device, .{ n_devices, n_devices });

        // Device `i` sees its own contiguous window of n_in inputs / n_out outputs.
        for (0..n_devices) |i| {
            input_per_device[i] = all_input_buffers[i * n_in ..].ptr;
            output_per_device[i] = all_output_buffers[i * n_out ..].ptr;
        }

        // Copy the shapes so this BaseExe owns them.
        const all_shapes = try allocator.alloc(Shape, n_in + n_out);
        @memcpy(all_shapes[0..n_in], args.input_shapes);
        @memcpy(all_shapes[n_in..], args.result_shapes);

        var execute_context: ?*pjrt.ExecuteContext = null;
        if (platform.pjrt_api.ffi()) |ffi| {
            execute_context = try platform.pjrt_api.createExecuteContext();
            // Log after creation: previously this logged before the call,
            // which always printed a null context pointer.
            log.info("Created context execution {*} for {*}", .{ execute_context, exe });
            try callback.bindInternalCallbacks(allocator, platform, ffi, execute_context.?);
        }

        return .{
            .platform = platform,
            .exe = exe,
            .execute_context = execute_context,
            .ready_buffer_count = 0,
            .input_buffer_count = @intCast(n_in),
            .num_devices = n_devices,
            .input_per_device = input_per_device,
            .output_per_device = output_per_device,
            .input_shapes = all_shapes[0..n_in],
            .result_shapes = all_shapes[n_in..],
            ._arena = arena,
        };
    }

    /// Releases the execute context (if any) and all arena-backed memory.
    pub fn deinit(self: BaseExe) void {
        if (self.execute_context) |ctx| {
            ctx.deinit(self.platform.pjrt_api);
        }
        self._arena.deinit();
    }

    /// Runs the executable, asserting that every input buffer has been provided.
    pub fn call(self: BaseExe) void {
        stdx.debug.assert(self.input_buffer_count == self.ready_buffer_count, "BaseExe isn't ready to be called, expected {} buffer inputs got {}", .{ self.input_buffer_count, self.ready_buffer_count });
        return self._unsafeCall();
    }

    /// Runs the executable without checking `ready_buffer_count`.
    /// Panics if PJRT reports an execution error.
    pub fn _unsafeCall(self: BaseExe) void {
        var events = [_]?*pjrt.Event{null} ** Platform.MAX_NUM_DEVICES;
        const sharding = self.platform.sharding();

        self.exe.execute(self.platform.pjrt_api, .{
            .arguments = self.input_per_device,
            .num_args = self.input_buffer_count,
            .results = self.output_per_device,
            .events = events[0..sharding.num_partitions],
            // this allows to tell a specific buffer shouldn't be donated,
            // even if it has been marked as "can be donated" during compilation.
            // TODO: expose it ?
            .non_donatable_input_indices = &.{},
            .context = self.execute_context,
        }) catch |err| {
            std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err});
        };

        // for (events[0..sharding.num_partitions]) |e| {
        //     if (e) |ev| {
        //         ev.await_(self.platform.pjrt_api) catch unreachable;
        //     }
        // }
    }

    /// Walks `result` and overwrites every `Buffer` it contains with the
    /// corresponding output of the last execution, gathering one shard per device.
    pub fn _unsafeAssignResults(self: BaseExe, T: type, result: *T) void {
        const LocalContext = struct {
            index: u32,
            platform: Platform,
            outputs: []const [*]*pjrt.Buffer,
            output_shapes: []Shape,
        };
        var local_ctx: LocalContext = .{
            .index = 0,
            .platform = self.platform,
            .outputs = self.output_per_device,
            .output_shapes = self.result_shapes,
        };
        meta.visit((struct {
            fn cb(ctx: *LocalContext, buffer: *Buffer) void {
                const i = ctx.index;
                ctx.index += 1;
                // Keep counting past the end so the assert below can report
                // how many Buffers the result type actually contains.
                if (i >= ctx.output_shapes.len) return;

                // Collect the i-th output of every device into one sharded Buffer.
                var shards: Buffer.Shards = .{};
                for (ctx.outputs) |buff| {
                    shards.appendAssumeCapacity(buff[i]);
                }
                buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.output_shapes[i], shards.constSlice());
            }
        }).cb, &local_ctx, result);
        // Report the number of result tensors (previously this passed
        // `output_per_device.len`, which is the device count).
        stdx.debug.internalAssert(local_ctx.index == self.result_shapes.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ self.result_shapes.len, @typeName(T), local_ctx.index });
    }

    /// For a given customCall inside this executable, provide a pointer to
    /// runtime data. Requires an execute context (FFI support).
    /// The caller keeps ownership of `op` and must keep it alive as long as
    /// the executable.
    pub fn bind(exe: BaseExe, Callback: type, op: *Callback) !void {
        stdx.debug.assert(exe.execute_context != null, "Exe doesn't have an execution context", .{});
        const pjrt_api = exe.platform.pjrt_api;

        if (pjrt_api.ffi()) |ffi| {
            try callback.addUserData(Callback, pjrt_api, ffi, exe.execute_context.?, op);
        } else {
            stdx.debug.panic("Callbacks are not supported for target {s}", .{@tagName(exe.platform.target)});
        }
    }

    /// Writes the serialized PJRT executable to `writer`.
    pub fn serialize(self: BaseExe, writer: anytype) !void {
        var executable = try self.exe.getExecutable(self.platform.pjrt_api);
        var serialize_result = try executable.serialize(self.platform.pjrt_api);
        defer serialize_result.deinit();
        try writer.writeAll(serialize_result.bytes);
    }

    // pub fn deserialize(allocator: std.mem.Allocator, platform: Platform, reader: anytype) !Self {
    //     const bytes = try reader.readToEndAlloc(allocator, max_pjrt_executable_size);
    //     defer allocator.free(bytes);
    //     return platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, bytes);
    // }

    /// Feeds the buffers found in `x` as the next inputs of the executable,
    /// advancing `ready_buffer_count` accordingly.
    pub fn prepare(self: *BaseExe, x: anytype) void {
        const n = fillBuffers(&x, self.input_shapes, self.input_per_device, self.ready_buffer_count);
        self.ready_buffer_count += n;
    }

    /// Builds a sharded `Buffer` view of output `i`, one shard per device.
    pub fn getOutputBuffer(self: BaseExe, i: usize) Buffer {
        var shards: Buffer.Shards = .{};
        for (self.output_per_device) |dev_out| {
            shards.appendAssumeCapacity(dev_out[i]);
        }

        return Buffer.fromPjrtBuffers(self.platform, self.result_shapes[i], shards.constSlice());
    }

    /// Creates a copy of this BaseExe with its own bookkeeping arena.
    /// NOTE(review): the clone shares `exe` and `execute_context` with the
    /// original; calling `deinit` on both will deinit the context twice —
    /// confirm the intended ownership strategy.
    pub fn clone(self: BaseExe, parent_allocator: std.mem.Allocator) !BaseExe {
        var exe: BaseExe = try .init(parent_allocator, self.platform, self.exe, .{
            .input_shapes = self.input_shapes,
            .result_shapes = self.result_shapes,
            .n_devices = self.num_devices,
        });
        exe.execute_context = self.execute_context;
        return exe;
    }
};
|
|
|
|
/// Represents a ZML function, compiled into a PJRT executable.
/// The signature of the Exe reflects the arguments that are needed for `call`.
pub fn Exe(ArgsT: type, ReturnT: type) type {
    return struct {
        const Self = @This();

        /// The raw untyped compiled module.
        inner: BaseExe,

        /// Releases the underlying BaseExe (arena and execute context).
        pub fn deinit(self: Self) void {
            self.inner.deinit();
        }

        /// Hardcode the first argument of the function to the given buffers.
        /// Returns an Exe with one less argument in `call`.
        /// In functional languages this is known as partial application.
        ///
        /// **Warning:** the new Exe reuses the underlying memory of the previous one.
        /// The caller is responsible to come up with a strategy to call `deinit` exactly once.
        pub fn prepare(self: Self, first_arg: Bufferized(stdx.meta.Head(ArgsT))) Exe(stdx.meta.Tail(ArgsT), ReturnT) {
            var new: Exe(stdx.meta.Tail(ArgsT), ReturnT) = .{ .inner = self.inner };
            new.inner.prepare(first_arg);
            return new;
        }

        /// For a given customCall inside this executable,
        /// provide a pointer to runtime data.
        /// The caller keeps memory ownership and need to ensure that the value
        /// stays alive as long as the executable.
        pub fn bind(self: Self, comptime T: type, value: *T) !void {
            try self.inner.bind(T, value);
        }

        /// Writes the serialized PJRT executable to `writer`.
        pub fn serialize(self: Self, writer: anytype) !void {
            return try self.inner.serialize(writer);
        }

        /// The platform this executable was compiled for.
        pub fn platform(self: Self) Platform {
            return self.inner.platform;
        }

        /// Runs the executable with the given buffers (plus any buffers
        /// already bound via `prepare`) and returns the bufferized result.
        pub fn call(self: Self, args: Bufferized(ArgsT)) Bufferized(ReturnT) {
            const total_ready = fillBuffers(&args, self.inner.input_shapes, self.inner.input_per_device, self.inner.ready_buffer_count);
            std.debug.assert(total_ready == self.inner.input_buffer_count);
            self.inner._unsafeCall();
            // `result` is fully populated by `_unsafeAssignResults` below.
            var result: Bufferized(ReturnT) = undefined;
            self.inner._unsafeAssignResults(Bufferized(ReturnT), &result);
            return result;
        }
    };
}
|
|
|
|
/// Partitions `buffer` into `lengths.len` contiguous sub-slices, where the
/// k-th sub-slice has length `lengths[k]`.
/// Asserts that the lengths sum exactly to `buffer.len`.
fn splitBuffer(T: type, buffer: []T, lengths: anytype) [lengths.len][]T {
    var parts: [lengths.len][]T = undefined;
    var offset: usize = 0;
    inline for (&parts, lengths) |*part, part_len| {
        const end = offset + part_len;
        part.* = buffer[offset..end];
        offset = end;
    }
    // The requested lengths must account for every element of `buffer`.
    std.debug.assert(offset == buffer.len);
    return parts;
}
|
|
|
|
/// Visit the given struct and fill the `buffers` slice with the buffer associated with encountered Tensor.
/// Returns `start` plus the number of `Buffer`s visited, i.e. the index one
/// past the last slot written.
fn fillBuffers(v: anytype, shapes: []const Shape, buffers: []const [*]*pjrt.Buffer, start: u32) u32 {
    // Context threaded through `meta.visit`'s callback.
    const LocalContext = struct {
        index: u32,
        buffers: []const [*]*pjrt.Buffer,
        shapes: []const Shape,
    };
    var context: LocalContext = .{
        .index = start,
        .buffers = buffers,
        .shapes = shapes,
    };
    meta.visit((struct {
        fn cb(ctx: *LocalContext, buffer: *const Buffer) void {
            // stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index });
            // `buffers` has one entry per device: an input must carry exactly
            // one shard per device of the executable.
            const model_sharding = ctx.buffers.len;
            stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {d}-sharded tensor into a {d}-sharded model", .{ buffer._shards.len, ctx.buffers.len });
            stdx.debug.assert(ctx.shapes[ctx.index].eql(buffer.shape()), "Executable expected argument {} to have shape {f}, got {f}", .{ ctx.index, ctx.shapes[ctx.index], buffer.shape() });
            // Scatter each shard into its device's input table at this
            // argument's position.
            for (buffer._shards.constSlice(), 0..) |shard, d| {
                ctx.buffers[d][ctx.index] = shard;
            }
            ctx.index += 1;
        }
    }).cb, &context, v);
    return context.index;
}
|
|
|
|
/// Builds a human-readable name for `func` from its type name, stripping the
/// noisy module qualifiers of ZML types (`tensor.Tensor`, `shape.Shape`).
/// Caller owns the returned list and must `deinit` it with the same allocator.
fn prettyFnName(
    comptime func: anytype,
    allocator: std.mem.Allocator,
) !std.ArrayListUnmanaged(u8) {
    const full_noisy_name = @typeName(@TypeOf(func));
    const og_len = full_noisy_name.len;
    // Every replacement shrinks (or preserves) the text, so the original
    // length is a safe upper bound for the output buffer.
    const buffer = try allocator.alloc(u8, og_len);
    errdefer comptime unreachable; // No errors below this point.
    var out: []u8 = buffer;

    // Replace the longer, dotted pattern first: once "tensor.Tensor" has been
    // shortened to "Tensor", "tensor.Tensor." can no longer match. The
    // previous ordering made this pass dead code.
    {
        const verbose = "tensor.Tensor.";
        const compact = "";
        const num_replacements = std.mem.replace(u8, full_noisy_name, verbose, compact, buffer);
        out.len = out.len + num_replacements * compact.len - num_replacements * verbose.len;
    }

    {
        const verbose = "tensor.Tensor";
        const compact = "Tensor";
        // In-place pass: `out` aliases `buffer`, which is fine because the
        // replacement never grows the text (same as the original code).
        const num_replacements = std.mem.replace(u8, out, verbose, compact, buffer);
        out.len = out.len + num_replacements * compact.len - num_replacements * verbose.len;
    }

    {
        const verbose = "shape.Shape";
        const compact = "Shape";
        const num_replacements = std.mem.replace(u8, out, verbose, compact, buffer);
        out.len = out.len + num_replacements * compact.len - num_replacements * verbose.len;
    }

    return .{ .items = out, .capacity = og_len };
}
|