//! Radix/zml/callback.zig — user-defined callbacks inserted into the computation graph.

const std = @import("std");
const async = @import("async");
const mlir = @import("mlir");
const pjrt = @import("pjrt");
const stablehlo = @import("mlir/dialects").stablehlo;
const stdx = @import("stdx");
const Buffer = @import("buffer.zig").Buffer;
const CompilationContext = @import("module.zig").CompilationContext;
const DataType = @import("dtype.zig").DataType;
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const mlirx = @import("mlirx.zig");
const pjrtx = @import("pjrtx.zig");
const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape;
const Tensor = @import("tensor.zig").Tensor;
const log = std.log.scoped(.@"zml/callback");
/// Inserts a user-defined callback into the computation graph.
/// The callback is defined with a struct that stores the runtime information needed by the callback.
///
/// ❗Experimental API❗
///
/// ```zig
/// pub const MyCallback = struct {
///     // a unique type_id will be set by the PJRT plugin during registration.
///     pub var type_id: pjrt.ffi.TypeId = undefined;
///
///     pub const callback_config: zml.callback.Config = .{
///         // assumptions this custom call makes about the input / output buffers
///     };
///
///     // Required, this will tell the callback in which env it runs.
///     platform: zml.Platform,
///     // data needed by the callback
///     my_data: []const u8,
///
///     // storage modified by the runtime to tell the callback where it should write its results.
///     // Normally the callback doesn't need to allocate as the input and output buffers are given.
///     results: [1]Buffer = undefined,
///
///     pub fn init(my_data: []const u8) !MyCallback {
///         return .{ .my_data = my_data };
///     }
///
///     pub fn call(callback: *MyCallback, input: Buffer) !void {
///         // Do something with `input` and `callback.my_data`, write the results inside `callback.results[0]`
///     }
/// };
/// ```
///
/// See eg the implementation of the `zml.callback.Print` callback, for a practical example.
///
/// Note that calling this during the compilation of a module isn't enough:
///
/// * the backend needs to be made aware of the callback, see `zml.Platform.registerCallback`
/// * the executable needs to know the specific data needed by `MyCallback`, see `zml.Exe.bind`
pub fn call(
    comptime Callback: type,
    inputs: TensorArgs(Callback),
    output_shapes: []const Shape,
) []Tensor {
    checkIsValidCallback(Callback);
    const ctx = CompilationContext.current();
    const allocator = ctx.allocator();
    const mlir_ctx = ctx.mlirCtx();
    const platform = ctx._platform;
    const pjrt_api = platform.pjrt_api;
    // Typed FFI custom calls require the PJRT plugin to expose the FFI extension.
    if (pjrt_api.ffi() == null) {
        stdx.debug.panic("Custom calls are not supported for target {s}", .{@tagName(platform.target)});
    }
    // Ownership of `output_tensors` is transferred to the caller, so it is never freed here.
    // It is allocated first so that it doesn't fragment the arena.
    // Note: a previous `errdefer allocator.free(output_tensors)` was dead code — this function
    // has no error return, and every allocation failure panics instead of returning an error.
    const output_tensors = allocator.alloc(Tensor, output_shapes.len) catch @panic("OOM");
    // Scratch slices below are only needed while building the MLIR op.
    const output_types = allocator.alloc(mlir.Type, output_shapes.len) catch @panic("OOM");
    defer allocator.free(output_types);
    for (output_types, output_shapes) |*output_type, output_shape| {
        output_type.* = mlirx.tensorType(mlir_ctx, output_shape);
    }
    const input_values = allocator.alloc(mlir.Value, inputs.len) catch @panic("OOM");
    defer allocator.free(input_values);
    for (input_values, inputs) |*input_value, input_tensor| {
        input_value.* = input_tensor.value();
    }
    // The "zml$" prefix namespaces our custom calls; `register` builds the same name.
    const target_name = "zml$" ++ @typeName(Callback);
    const op = stablehlo.custom_call(
        mlir_ctx,
        input_values,
        .{
            .call_target_name = target_name,
            .api_version = .typed_ffi,
            .backend_config = .dict(mlir_ctx, &.{}),
            .additional_attributes = &.{.{ "mhlo.frontend_attributes", .dict(mlir_ctx, &.{}) }},
            .has_side_effect = Callback.callback_config.has_side_effect,
            .output_operand_aliases = Callback.callback_config.output_operand_aliases,
        },
        output_types,
        mlir_ctx.location(@src()),
    );
    // Wrap each MLIR result value into a Tensor with its declared shape.
    for (output_tensors, output_shapes, 0..) |*output_tensor, output_shape, i| {
        output_tensor.* = Tensor._result(output_shape, op.result(i));
    }
    return output_tensors;
}
/// Describe properties of a callback
///
/// * output_operand_aliases: the callback reuse input buffer to write the output
/// * copy_inputs_to_host_pinned: the callback need to work on host visible buffers
/// * traits: PJRT specified properties of the callback
pub const Config = struct {
    /// Indices of inputs whose buffers are reused (aliased) for the outputs.
    /// Forwarded verbatim to the stablehlo `custom_call` op in `call`.
    output_operand_aliases: []const i64 = &.{},
    /// When true, `CallbackImpl` copies each device input to host_pinned memory
    /// (and frees the copy afterwards) before invoking `Callback.call`.
    /// Only applies on non-CPU targets.
    copy_inputs_to_host_pinned: bool = false,
    /// Indicates that the handler is compatible with command buffers (ie Cuda graphs).
    /// The Cuda backend traces the execution of the callback and records it as a Cuda graph.
    /// To be compatible with CUDA graphs, the callback need:
    /// 1. to not use any synchronous operations
    /// 2. to not allocate any new host/device memory
    /// See: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#prohibited-and-unhandled-operations.
    traits: pjrt.ffi.HandlerTraits = .{ .command_buffer_compatible = false },
    // TODO: handle sharded inputs
    /// Forwarded to the custom_call op; when true the compiler must not
    /// dead-code-eliminate or reorder the call.
    has_side_effect: bool = true,
};
/// Compile-time check that a callback has all informations we require:
/// a `call` method whose non-receiver arguments are all `zml.Buffer`,
/// a `pub var type_id: pjrt.ffi.TypeId`, and a `pub const callback_config: Config`.
/// Emits a compile error describing the first missing piece.
pub fn checkIsValidCallback(Callback: type) void {
    stdx.debug.assertComptime(@hasDecl(Callback, "call"), "Expected callback {} to have a call method", .{Callback});
    const ArgsT = stdx.meta.FnArgs(Callback.call);
    // Skip the first field: it is the `*Callback` receiver, not a buffer argument.
    inline for (@typeInfo(ArgsT).@"struct".fields[1..]) |field| {
        stdx.debug.assertComptime(field.type == Buffer, "Expected callback {}.call arguments to be of type zml.Buffer, got {}", .{ Callback, field.type });
    }
    stdx.debug.assertComptime(@hasDecl(Callback, "type_id") and @TypeOf(Callback.type_id) == pjrt.ffi.TypeId, "Expected callback {} to have a field `pub var type_id: pjrt.ffi.TypeId`", .{Callback});
    // Fixed: the message used to reference the stale name `zml.CustomCallOptions`;
    // the expected type is `zml.callback.Config`.
    stdx.debug.assertComptime(@hasDecl(Callback, "callback_config") and @TypeOf(Callback.callback_config) == Config, "Expected callback {} to have a field `pub const callback_config: zml.callback.Config`", .{Callback});
}
/// Registers `Callback` with the platform's PJRT FFI extension.
/// Fills in `Callback.type_id` and binds the "zml$<type name>" target to a
/// C-ABI trampoline wrapping `Callback.call`.
/// Returns `error.Unavailable` when the plugin has no FFI extension.
pub fn register(Callback: type, platform: Platform) pjrt.ApiError!void {
    checkIsValidCallback(Callback);
    const ffi = platform.pjrt_api.ffi() orelse return error.Unavailable;
    // Must match the target name produced in `call`.
    const target_name = "zml$" ++ @typeName(Callback);
    const handler = proxy(Callback);
    const type_info: pjrt.Ffi.TypeInfo = .{};
    // The plugin assigns the unique id that `CallbackImpl` later uses to look up user data.
    Callback.type_id = try ffi.registerTypeId(platform.pjrt_api, @typeName(Callback), &type_info.toCStruct());
    try ffi.register(platform.pjrt_api, target_name, @tagName(platform.target), &handler, Callback.callback_config.traits);
    log.debug("Registered custom call {} with target name \"{s}\"", .{ Callback, target_name });
}
/// Builds a C-ABI handler that forwards PJRT FFI call frames to `CallbackImpl`
/// specialized for `Callback`.
fn proxy(Callback: type) pjrt.ffi.Handler {
    const Trampoline = struct {
        fn handle(frame: *pjrt.ffi.CallFrame) callconv(.c) ?*pjrt.ffi.Error {
            return CallbackImpl(Callback, frame);
        }
    };
    return Trampoline.handle;
}
/// Runtime entry point shared by all callbacks: decodes the PJRT FFI call frame,
/// wraps inputs/outputs as `Buffer`s, and invokes `Callback.call`.
/// Returns null on success, or a PJRT FFI error describing the failure.
fn CallbackImpl(comptime Callback: type, call_frame: *pjrt.ffi.CallFrame) ?*pjrt.ffi.Error {
    // PJRT probes the handler once at registration time; nothing to do then.
    if (call_frame.registeringHook()) return null;
    const opts = Callback.callback_config;
    const execution_context = call_frame.ctx;
    log.debug("Custom call {s} called !", .{@typeName(Callback)});
    // The user context must have been bound on the executable (see `addUserData`).
    const user_ctx_opaque = execution_context.getContext(Callback.type_id, call_frame.api) catch {
        log.err("{} user data was never given for current executable", .{Callback});
        // Fixed: the concatenation was missing a separator and produced
        // e.g. "failed to fetch user contextPrint".
        return .create(call_frame.api, .failed_precondition, "failed to fetch user context for " ++ @typeName(Callback));
    };
    const user_ctx: *Callback = @ptrCast(@alignCast(user_ctx_opaque));
    // We actually have one more constraint here, we force the Callback to have a platform field,
    // and to correctly set it.
    // Is this good ? We could also simplify this by registering ourselves the `Platform` type id.
    const platform: Platform = user_ctx.platform;
    // Hook to get a cuda stream in the callback.
    if (@hasField(Callback, "stream") and platform.target != .cpu) {
        const stream = call_frame.api.stream(execution_context);
        user_ctx.stream = stream;
    }
    // Slot 0 is the receiver; slots 1.. are one Buffer per FFI input buffer.
    var callback_args: std.meta.ArgsTuple(@TypeOf(Callback.call)) = undefined;
    callback_args[0] = user_ctx;
    inline for (1..callback_args.len, call_frame.args.buffers()) |i, ffi_buffer| {
        const shape = shapeFromFfi(ffi_buffer);
        // On CPU the FFI data pointer is directly host-addressable; otherwise it
        // is treated as an opaque device pointer.
        var zml_buffer: Buffer = if (platform.target == .cpu)
            .asViewOfHostBuffer(platform, .fromBytes(shape, ffi_buffer.data[0..shape.byteSize()]))
        else
            .asViewOfDeviceBuffer(platform, shape, null, ffi_buffer.data);
        if (opts.copy_inputs_to_host_pinned and platform.target != .cpu) {
            // log.debug("Copying argument {d} {f} {x} to host_pinned memory !", .{ i, zml_buffer, @intFromPtr(zml_buffer.devicePtr()) });
            // Blocking copy: the callback expects the data to be readable on return.
            zml_buffer = zml_buffer.copyToMemory(platform, .host_pinned, .{ .wait = true }) catch |err| {
                log.err("Failed to copy input buffer {d} {f} {x} to host_pinned: {}", .{ i, zml_buffer, @intFromPtr(zml_buffer.devicePtr()), err });
                return .create(call_frame.api, .resource_exhausted, "host pinned OOM");
            };
            // log.debug("--> {f} {x}", .{ zml_buffer, @intFromPtr(zml_buffer.devicePtr()) });
        }
        callback_args[i] = zml_buffer;
    }
    defer {
        // Only the host_pinned copies made above are owned here and need freeing;
        // plain views don't allocate.
        if (opts.copy_inputs_to_host_pinned and platform.target != .cpu) {
            inline for (1..callback_args.len) |i| callback_args[i].deinit();
        }
    }
    // NOTE(review): this assumes `Callback` has a `results` array at least as long as
    // `call_frame.results.len`; `checkIsValidCallback` does not verify it — confirm.
    for (0..call_frame.results.len) |i| {
        const ffi_buffer = call_frame.results.buffers()[i];
        const ffi_buffer_shape = shapeFromFfi(ffi_buffer);
        if (platform.target == .cpu) {
            user_ctx.results[i] = Buffer.asViewOfHostBuffer(platform, HostBuffer.fromBytes(ffi_buffer_shape, ffi_buffer.data[0..ffi_buffer_shape.byteSize()]));
        } else {
            user_ctx.results[i] = Buffer.asViewOfDeviceBuffer(platform, shapeFromFfi(ffi_buffer), null, ffi_buffer.data);
        }
    }
    // Any Zig error from the user callback is reported to PJRT as an internal error.
    @call(.auto, Callback.call, callback_args) catch |err| {
        log.err("Callback {} failed with {}", .{ Callback, err });
        return .create(call_frame.api, .internal, "internal callback error");
    };
    return .ok;
}
/// Internal custom calls.
/// These are not meant to be used by users, but rather by the library itself.
/// Registered by `registerInternalCallbacks` and bound by `bindInternalCallbacks`.
pub const internal_callbacks = [_]type{
    Print,
};
/// Registers every ZML-provided callback (see `internal_callbacks`) on the given platform.
pub fn registerInternalCallbacks(platform: Platform) !void {
    inline for (internal_callbacks) |InternalCallback| {
        try register(InternalCallback, platform);
    }
}
/// Allocates (on `arena`) the per-executable user data required by the ZML
/// provided custom calls, and binds it to `execute_context`.
pub fn bindInternalCallbacks(
    arena: std.mem.Allocator,
    platform: Platform,
    ffi: pjrt.Ffi,
    execute_context: *pjrt.ExecuteContext,
) (std.mem.Allocator.Error || pjrt.ApiError)!void {
    // There is currently no way to detect which ZML callbacks an executable
    // actually needs, so user data is allocated for all of them.
    const print_state = try arena.create(Print);
    print_state.* = try .init(platform);
    try addUserData(Print, platform.pjrt_api, ffi, execute_context, print_state);
}
/// Attaches `user_data` to `execute_context` under `Callback.type_id`, so that
/// `CallbackImpl` can retrieve it at execution time via `getContext`.
pub fn addUserData(
    Callback: type,
    api: *const pjrt.Api,
    ffi: pjrt.Ffi,
    execute_context: *pjrt.ExecuteContext,
    user_data: *Callback,
) pjrt.ApiError!void {
    const id = Callback.type_id.type_id;
    try ffi.addUserData(api, execute_context, .{
        .type_id = id,
        .user_data = @ptrCast(user_data),
    });
    log.debug("Bound {s}@{x} with type id {d} on {any}", .{ @typeName(Callback), @intFromPtr(user_data), id, execute_context });
}
/// The print callback: logs the contents of its single input buffer and
/// returns the buffer unchanged (via output aliasing).
pub const Print = struct {
    // a unique type_id will be set by the PJRT plugin during registration.
    pub var type_id: pjrt.ffi.TypeId = undefined;
    pub const callback_config: Config = .{
        // Print callback pretends to modify the given input buffer, but just returns it unmodified.
        .output_operand_aliases = &.{0},
        // It also needs PJRT to copy the data on the host first so it can print it.
        .copy_inputs_to_host_pinned = true,
        // NOTE(review): the host-pinned copy above is a blocking copy, and `call`
        // logs synchronously, which are prohibited operations inside CUDA graphs —
        // hence command_buffer_compatible stays false.
        .traits = .{ .command_buffer_compatible = false },
        // NOTE(review): despite printing, this is marked side-effect-free —
        // presumably the output alias keeps the op alive; confirm it isn't DCE'd.
        .has_side_effect = false,
    };
    // Required by the callback contract: `CallbackImpl` reads `platform` to decode buffers.
    platform: Platform,
    // Written by `CallbackImpl` with the (aliased) output buffer; unused by `call`.
    results: [1]Buffer = undefined,
    pub fn init(platform: Platform) !Print {
        return .{ .platform = platform };
    }
    pub fn call(_: *Print, input: Buffer) !void {
        // Input was copied to host_pinned memory by CallbackImpl, so asHostBuffer is readable.
        std.log.defaultLog(.info, .zml, "Device buffer: {f}: {d:20.3}", .{ input, input.asHostBuffer() });
    }
};
/// Translates an FFI buffer descriptor (dims + dtype) into a zml `Shape`.
/// Panics on dtypes that have no `DataType` equivalent.
fn shapeFromFfi(ffi_buffer: *const pjrt.ffi.Buffer) Shape {
    const data_type: DataType = switch (ffi_buffer.dtype) {
        .invalid => stdx.debug.panic("Invalid FFI dtype {any} used by {any}", .{ ffi_buffer.dtype, ffi_buffer }),
        .token, .f8e4m3, .f8e3m4 => stdx.debug.panic("Unsupported FFI dtype {any} used by {any}", .{ ffi_buffer.dtype, ffi_buffer }),
        // Every remaining tag maps 1:1 onto the DataType member of the same name.
        inline else => |tag| @field(DataType, @tagName(tag)),
    };
    return Shape.init(ffi_buffer.dims(), data_type);
}
/// Computes the input type for `call`: one `Tensor` per buffer argument of
/// `Callback.call`, excluding the receiver.
fn TensorArgs(Callback: type) type {
    const fields = @typeInfo(stdx.meta.FnArgs(Callback.call)).@"struct".fields;
    // fields[0] is the *Callback receiver, hence the -1.
    return [fields.len - 1]Tensor;
}