Add PJRT custom call integration with generic zmlHostBufferCallback to copy tensors to host and invoke user callbacks. Introduce Tensor.print() method to output runtime tensor values (CUDA‑specific, uses a pre‑allocated host buffer).
This commit is contained in:
parent
bf23eef0d9
commit
6d720126ac
@ -139,6 +139,14 @@ pub const Api = struct {
|
||||
.minor = @intCast(self.inner.pjrt_api_version.minor_version),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn customCallRegistry(api: *const Api) ?CustomCallRegistry {
|
||||
if (api.lookupExtension(c.PJRT_Gpu_Custom_Call, c.PJRT_Extension_Type_Gpu_Custom_Call)) |ext| {
|
||||
return .{ .inner = ext.custom_call.? };
|
||||
}
|
||||
// log.warn("No Custom Call registry found for platform: {}", .{self});
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
pub const ErrorCode = enum(c.PJRT_Error_Code) {
|
||||
@ -854,3 +862,28 @@ pub const NamedValue = extern struct {
|
||||
try writer.writeAll("}");
|
||||
}
|
||||
};
|
||||
|
||||
/// Custom call signature arguments are:
|
||||
/// * a pointer to a platform specific stream handle
|
||||
/// * a pointer to an unspecified list of platform specific buffer handle
|
||||
/// * a context struct passed as a slice of bytes
|
||||
pub const CustomCall = fn (*anyopaque, [*]*anyopaque, [*]const u8, usize) callconv(.C) void;
|
||||
|
||||
pub const CustomCallRegistry = extern struct {
|
||||
inner: *const c.PJRT_Gpu_Register_Custom_Call,
|
||||
|
||||
pub fn register(self: *const CustomCallRegistry, api: *const Api, api_version: usize, name: []const u8, func: *const CustomCall) ApiError!void {
|
||||
var ret = pjrtStruct(c.PJRT_Gpu_Register_Custom_Call_Args{
|
||||
.function_name = name.ptr,
|
||||
.function_name_size = name.len,
|
||||
.api_version = @intCast(api_version),
|
||||
.custom_call_function = @ptrCast(@constCast(func)),
|
||||
});
|
||||
const result = self.inner(&ret);
|
||||
if (result) |pjrt_c_error| {
|
||||
const pjrt_error: *Error = @ptrCast(pjrt_c_error);
|
||||
log.err("[GpuRegisterCustomCall] {s}", .{pjrt_error.getMessage(api)});
|
||||
return pjrt_error.getCode(api).toApiError();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -9,6 +9,7 @@ const platform = @import("platform.zig");
|
||||
const pjrt = @import("pjrtx.zig");
|
||||
|
||||
const available_targets = @import("platform.zig").available_targets;
|
||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||
const Target = @import("platform.zig").Target;
|
||||
const Platform = @import("platform.zig").Platform;
|
||||
|
||||
@ -90,6 +91,9 @@ pub const Context = struct {
|
||||
log.err("No device found for platform {} !", .{target});
|
||||
continue;
|
||||
}
|
||||
if (target == .cuda) {
|
||||
try cuda.registerZmlCustomCalls(p);
|
||||
}
|
||||
platforms.set(target, p);
|
||||
num_platforms += 1;
|
||||
}
|
||||
@ -140,4 +144,78 @@ pub const Context = struct {
|
||||
}
|
||||
return platform_ orelse @panic("No platform found !");
|
||||
}
|
||||
|
||||
pub const HostCallbackCtx = struct {
|
||||
host: HostBuffer,
|
||||
mutex: std.Thread.Mutex = std.Thread.Mutex{},
|
||||
};
|
||||
pub const HostCallback = fn (HostBuffer) void;
|
||||
};
|
||||
|
||||
const cuda = struct {
|
||||
var runtime: Runtime = undefined;
|
||||
|
||||
pub fn registerZmlCustomCalls(cuda_platform: Platform) !void {
|
||||
std.debug.assert(cuda_platform.target == .cuda);
|
||||
|
||||
cuda.runtime = try Runtime.init();
|
||||
const registry = cuda_platform.pjrt_api.customCallRegistry().?;
|
||||
try registry.register(cuda_platform.pjrt_api, 0, "zmlHostBufferCallback", &hostBufferCallback);
|
||||
}
|
||||
|
||||
pub const Stream = opaque {};
|
||||
pub const MemcpyKind = enum(c_int) {
|
||||
host_to_host = 0,
|
||||
host_to_device = 1,
|
||||
device_to_host = 2,
|
||||
device_to_device = 3,
|
||||
default = 4,
|
||||
};
|
||||
|
||||
pub const Runtime = struct {
|
||||
memcpyAsync: MemcpyAsync,
|
||||
streamSynchronize: StreamSynchronize,
|
||||
|
||||
const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: *Stream) callconv(.C) c_int;
|
||||
const StreamSynchronize = *const fn (stream: *Stream) callconv(.C) c_int;
|
||||
|
||||
pub fn init() !Runtime {
|
||||
var cudart = try std.DynLib.open("libcudart.so");
|
||||
defer cudart.close();
|
||||
|
||||
return .{
|
||||
.memcpyAsync = cudart.lookup(Runtime.MemcpyAsync, "cudaMemcpyAsync") orelse return error.NotFound,
|
||||
.streamSynchronize = cudart.lookup(Runtime.StreamSynchronize, "cudaStreamSynchronize") orelse return error.NotFound,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
fn getContext(args: [*]const u8, args_len: usize) struct { *const Context.HostCallback, *Context.HostCallbackCtx } {
|
||||
std.debug.assert(args_len == @sizeOf(*anyopaque) * 2);
|
||||
|
||||
const raw_fn_ptr: usize = @bitCast(args[0..@sizeOf(*anyopaque)].*);
|
||||
const fn_ptr: *const Context.HostCallback = @ptrFromInt(raw_fn_ptr);
|
||||
|
||||
const raw_ctx_ptr: usize = @bitCast(args[@sizeOf(*anyopaque)..][0..@sizeOf(*anyopaque)].*);
|
||||
const ctx_ptr: *Context.HostCallbackCtx = @ptrFromInt(raw_ctx_ptr);
|
||||
return .{ fn_ptr, ctx_ptr };
|
||||
}
|
||||
|
||||
fn hostBufferCallback(opaque_stream: *anyopaque, buffers: [*]*anyopaque, args: [*]const u8, args_len: usize) callconv(.C) void {
|
||||
const stream: *Stream = @ptrCast(opaque_stream);
|
||||
const src: *anyopaque = buffers[0];
|
||||
const callback, const ctx = getContext(args, args_len);
|
||||
|
||||
// Add synchronization because this is called from the device driver.
|
||||
ctx.mutex.lock();
|
||||
defer ctx.mutex.unlock();
|
||||
|
||||
const host_dst: []u8 = @constCast(ctx.host.data);
|
||||
const memcpy_result = cuda.runtime.memcpyAsync(host_dst.ptr, src, host_dst.len, .device_to_host, stream);
|
||||
_ = memcpy_result;
|
||||
const synchronize_result = cuda.runtime.streamSynchronize(stream);
|
||||
_ = synchronize_result;
|
||||
|
||||
callback(ctx.host);
|
||||
}
|
||||
};
|
||||
|
||||
59
zml/ops.zig
59
zml/ops.zig
@ -1,25 +1,26 @@
|
||||
const std = @import("std");
|
||||
const mlir = @import("mlir.zig");
|
||||
const buffer = @import("buffer.zig");
|
||||
|
||||
const helpers = @import("helpers.zig");
|
||||
const module = @import("module.zig");
|
||||
const meta = @import("meta.zig");
|
||||
|
||||
const Buffer = @import("buffer.zig").Buffer;
|
||||
const CompilationContext = module.CompilationContext;
|
||||
const Tensor = @import("tensor.zig").Tensor;
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
const Context = @import("context.zig").Context;
|
||||
const Data = @import("dtype.zig").Data;
|
||||
const DataType = @import("dtype.zig").DataType;
|
||||
const Buffer = buffer.Buffer;
|
||||
const EnumLiteral = @TypeOf(.enum_literal);
|
||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
const Tensor = @import("tensor.zig").Tensor;
|
||||
|
||||
const dialect = struct {
|
||||
const stablehlo = @import("mlir/dialects").stablehlo;
|
||||
};
|
||||
|
||||
const assert = std.debug.assert;
|
||||
const log = std.log.scoped(.zml_tensor);
|
||||
const log = std.log.scoped(.zml);
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(@This());
|
||||
@ -621,7 +622,7 @@ pub fn identityCustomCall(name: [:0]const u8, input: Tensor, context: *anyopaque
|
||||
@memcpy(backend_config[0..8], address[0..8]);
|
||||
const ctx = CompilationContext.current();
|
||||
|
||||
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "name={s}", .{name});
|
||||
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "custom_call({s})", .{name});
|
||||
|
||||
const op = dialect.stablehlo.custom_call(
|
||||
ctx.mlirCtx(),
|
||||
@ -638,3 +639,49 @@ pub fn identityCustomCall(name: [:0]const u8, input: Tensor, context: *anyopaque
|
||||
);
|
||||
return Tensor._result(input.shape(), op.result(0));
|
||||
}
|
||||
|
||||
/// At runtime the given tensor will be materialized and copied to host,
|
||||
/// and the callback will be called on it.
|
||||
pub fn addHostCallback(
|
||||
callback: *const fn (HostBuffer) void,
|
||||
input: Tensor,
|
||||
) Tensor {
|
||||
// TODO: implement addCallback that exposes a pjrt.Buffer, so that the user can decide if they need to copy.
|
||||
if (input.getContext().target() != .cuda) return input;
|
||||
|
||||
const len = input.byteSize();
|
||||
// Reserve memory to be able to log the runtime Buffer later during the computation.
|
||||
// This memory is leaked, we currently have no way to tie this lifetime to the lifetime of the module being compiled.
|
||||
const HostCallbackCtx = Context.HostCallbackCtx;
|
||||
const full_data = std.heap.page_allocator.alignedAlloc(u8, 32, len + 2 * @sizeOf(HostCallbackCtx)) catch {
|
||||
log.err("Failed to pre-allocate buffer to print {}.", .{input});
|
||||
return input;
|
||||
};
|
||||
|
||||
// Save the HostBuffer inside the same memory slice, so that it's still present at runtime.
|
||||
// Use an fba to have the stable buffer at an aligned offset.
|
||||
var fba = std.heap.FixedBufferAllocator.init(full_data[len..]);
|
||||
const stable_ctx_ptr = fba.allocator().create(HostCallbackCtx) catch unreachable;
|
||||
stable_ctx_ptr.* = .{
|
||||
.host = HostBuffer.fromBytes(input.shape(), full_data[0..len]),
|
||||
};
|
||||
|
||||
const backend_config: [2:null]?*const anyopaque = .{ callback, stable_ctx_ptr };
|
||||
const ctx = CompilationContext.current();
|
||||
|
||||
const loc = ctx.mlirCtx().location(@src());
|
||||
const op = dialect.stablehlo.custom_call(
|
||||
ctx.mlirCtx(),
|
||||
&.{input.value()},
|
||||
.{
|
||||
.api_version = 1,
|
||||
.has_side_effect = false,
|
||||
.call_target_name = "zmlHostBufferCallback",
|
||||
.backend_config = @ptrCast(std.mem.sliceAsBytes(&backend_config)),
|
||||
.output_operand_aliases = &.{0},
|
||||
},
|
||||
&.{input.value().getType()},
|
||||
loc,
|
||||
);
|
||||
return Tensor._result(input.shape(), op.result(0));
|
||||
}
|
||||
|
||||
@ -3629,6 +3629,27 @@ pub const Tensor = struct {
|
||||
}
|
||||
}.binaryOpHelper;
|
||||
}
|
||||
|
||||
/// Insert code that will print the content of the given buffer at runtime.
|
||||
/// Only for debug purpose, it has the following limitations:
|
||||
/// * only supported on Cuda atm
|
||||
/// * only prints the first 1024 values
|
||||
/// * pre allocates a buffer on the host to copy the content of the device buffer,
|
||||
/// this buffer won't be freed. You will have one buffer per "print" call in the IR.
|
||||
/// * does device to host synchronization so it will slow down the program execution.
|
||||
pub fn print(input: Tensor) Tensor {
|
||||
return ops.addHostCallback(&printCallback, input);
|
||||
}
|
||||
|
||||
fn printCallback(host_buffer: HostBuffer) void {
|
||||
switch (host_buffer.dtype()) {
|
||||
inline else => |dt| {
|
||||
const items = host_buffer.items(dt.toZigType());
|
||||
const n = @min(items.len, 1024);
|
||||
std.debug.print("Device buffer: {}: {any}\n", .{ host_buffer.shape(), items[0..n] });
|
||||
},
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
fn initPoolArg(rank: usize, data: []const i64) [Tensor.MAX_RANK]i64 {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user