Add PJRT custom call integration with generic zmlHostBufferCallback to copy tensors to host and invoke user callbacks. Introduce Tensor.print() method to output runtime tensor values (CUDA‑specific, uses a pre‑allocated host buffer).
This commit is contained in:
parent
bf23eef0d9
commit
6d720126ac
@ -139,6 +139,14 @@ pub const Api = struct {
|
|||||||
.minor = @intCast(self.inner.pjrt_api_version.minor_version),
|
.minor = @intCast(self.inner.pjrt_api_version.minor_version),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the GPU custom call registry exposed by this PJRT plugin.
///
/// Returns null when the plugin doesn't advertise the
/// `PJRT_Extension_Type_Gpu_Custom_Call` extension, or when the extension
/// is present but its `custom_call` entry point is not set.
pub fn customCallRegistry(api: *const Api) ?CustomCallRegistry {
    if (api.lookupExtension(c.PJRT_Gpu_Custom_Call, c.PJRT_Extension_Type_Gpu_Custom_Call)) |ext| {
        // The function already returns an optional, so a missing entry point
        // is reported as null rather than tripping a forced unwrap.
        const register_fn = ext.custom_call orelse return null;
        return .{ .inner = register_fn };
    }
    // No custom call extension on this platform.
    return null;
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const ErrorCode = enum(c.PJRT_Error_Code) {
|
pub const ErrorCode = enum(c.PJRT_Error_Code) {
|
||||||
@ -854,3 +862,28 @@ pub const NamedValue = extern struct {
|
|||||||
try writer.writeAll("}");
|
try writer.writeAll("}");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// C ABI signature of a function that can be registered as a custom call.
/// In order, the arguments are:
/// * `stream`: a pointer to a platform specific stream handle
/// * `buffers`: a pointer to an unspecified list of platform specific buffer handles
/// * `context` and `context_len`: a context struct passed as a slice of bytes (pointer, then length)
pub const CustomCall = fn (stream: *anyopaque, buffers: [*]*anyopaque, context: [*]const u8, context_len: usize) callconv(.C) void;
|
||||||
|
|
||||||
|
/// Thin wrapper around the `PJRT_Gpu_Register_Custom_Call` entry point
/// exposed by the GPU custom call extension.
pub const CustomCallRegistry = extern struct {
    inner: *const c.PJRT_Gpu_Register_Custom_Call,

    /// Registers `func` under `name` with the plugin.
    /// `api_version` is forwarded as-is to the plugin.
    /// On failure the plugin's message is logged and the PJRT error code
    /// is converted to an `ApiError`.
    pub fn register(self: *const CustomCallRegistry, api: *const Api, api_version: usize, name: []const u8, func: *const CustomCall) ApiError!void {
        var register_args = pjrtStruct(c.PJRT_Gpu_Register_Custom_Call_Args{
            .function_name = name.ptr,
            .function_name_size = name.len,
            .api_version = @intCast(api_version),
            .custom_call_function = @ptrCast(@constCast(func)),
        });

        // A non-null return value is an error object owned by the plugin.
        if (self.inner(&register_args)) |raw_error| {
            const failure: *Error = @ptrCast(raw_error);
            log.err("[GpuRegisterCustomCall] {s}", .{failure.getMessage(api)});
            return failure.getCode(api).toApiError();
        }
    }
};
|
||||||
|
|||||||
@ -9,6 +9,7 @@ const platform = @import("platform.zig");
|
|||||||
const pjrt = @import("pjrtx.zig");
|
const pjrt = @import("pjrtx.zig");
|
||||||
|
|
||||||
const available_targets = @import("platform.zig").available_targets;
|
const available_targets = @import("platform.zig").available_targets;
|
||||||
|
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||||
const Target = @import("platform.zig").Target;
|
const Target = @import("platform.zig").Target;
|
||||||
const Platform = @import("platform.zig").Platform;
|
const Platform = @import("platform.zig").Platform;
|
||||||
|
|
||||||
@ -90,6 +91,9 @@ pub const Context = struct {
|
|||||||
log.err("No device found for platform {} !", .{target});
|
log.err("No device found for platform {} !", .{target});
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
if (target == .cuda) {
|
||||||
|
try cuda.registerZmlCustomCalls(p);
|
||||||
|
}
|
||||||
platforms.set(target, p);
|
platforms.set(target, p);
|
||||||
num_platforms += 1;
|
num_platforms += 1;
|
||||||
}
|
}
|
||||||
@ -140,4 +144,78 @@ pub const Context = struct {
|
|||||||
}
|
}
|
||||||
return platform_ orelse @panic("No platform found !");
|
return platform_ orelse @panic("No platform found !");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// State shared between the code that inserts a host callback and the
/// custom call that runs it.
pub const HostCallbackCtx = struct {
    /// Host-side buffer the device data is copied into before the callback is invoked.
    host: HostBuffer,
    /// Held by the custom call while it writes to and reads from `host`.
    mutex: std.Thread.Mutex = .{},
};
|
||||||
|
/// Signature of a user callback; it is handed the `HostBuffer` stored in a `HostCallbackCtx`.
pub const HostCallback = fn (HostBuffer) void;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// CUDA-specific glue: resolves the few CUDA runtime entry points we need
/// and registers the ZML custom calls on the CUDA PJRT plugin.
const cuda = struct {
    /// Populated by `registerZmlCustomCalls`; must not be used before that call succeeds.
    var runtime: Runtime = undefined;

    /// Status code returned by CUDA runtime functions on success (`cudaSuccess`).
    const cuda_success: c_int = 0;

    /// Loads the CUDA runtime and registers `hostBufferCallback` under the
    /// name "zmlHostBufferCallback".
    ///
    /// Returns `error.NotFound` when a CUDA runtime symbol is missing or when
    /// the plugin exposes no custom call registry. Other errors come from
    /// opening the library or from the registration itself.
    pub fn registerZmlCustomCalls(cuda_platform: Platform) !void {
        std.debug.assert(cuda_platform.target == .cuda);

        cuda.runtime = try Runtime.init();
        const registry = cuda_platform.pjrt_api.customCallRegistry() orelse {
            log.err("CUDA PJRT plugin doesn't expose a custom call registry, can't register zmlHostBufferCallback", .{});
            return error.NotFound;
        };
        try registry.register(cuda_platform.pjrt_api, 0, "zmlHostBufferCallback", &hostBufferCallback);
    }

    /// Opaque handle to a CUDA stream.
    pub const Stream = opaque {};

    /// Direction of a copy, with the same numeric values as the CUDA runtime's `cudaMemcpyKind`.
    pub const MemcpyKind = enum(c_int) {
        host_to_host = 0,
        host_to_device = 1,
        device_to_host = 2,
        device_to_device = 3,
        default = 4,
    };

    /// Function pointers resolved at runtime from `libcudart.so`.
    pub const Runtime = struct {
        memcpyAsync: MemcpyAsync,
        streamSynchronize: StreamSynchronize,

        const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: *Stream) callconv(.C) c_int;
        const StreamSynchronize = *const fn (stream: *Stream) callconv(.C) c_int;

        /// Opens `libcudart.so` and looks up the required symbols.
        /// Returns `error.NotFound` if one of them is missing.
        pub fn init() !Runtime {
            var cudart = try std.DynLib.open("libcudart.so");
            // The returned function pointers point into the library, so the handle
            // is deliberately kept open for the lifetime of the process on success.
            // It is only closed when we fail to resolve a symbol.
            errdefer cudart.close();

            return .{
                .memcpyAsync = cudart.lookup(Runtime.MemcpyAsync, "cudaMemcpyAsync") orelse return error.NotFound,
                .streamSynchronize = cudart.lookup(Runtime.StreamSynchronize, "cudaStreamSynchronize") orelse return error.NotFound,
            };
        }
    };

    /// Decodes the custom call payload: two pointer-sized values, first the
    /// user callback, then the `HostCallbackCtx`.
    fn getContext(args: [*]const u8, args_len: usize) struct { *const Context.HostCallback, *Context.HostCallbackCtx } {
        std.debug.assert(args_len == @sizeOf(*anyopaque) * 2);

        // The bytes are copied out before being reinterpreted, so `args`
        // doesn't need to be pointer-aligned.
        const raw_fn_ptr: usize = @bitCast(args[0..@sizeOf(*anyopaque)].*);
        const fn_ptr: *const Context.HostCallback = @ptrFromInt(raw_fn_ptr);

        const raw_ctx_ptr: usize = @bitCast(args[@sizeOf(*anyopaque)..][0..@sizeOf(*anyopaque)].*);
        const ctx_ptr: *Context.HostCallbackCtx = @ptrFromInt(raw_ctx_ptr);
        return .{ fn_ptr, ctx_ptr };
    }

    /// Custom call entry point: copies the first buffer to the host buffer
    /// stored in the context, waits for the copy to finish, then invokes the
    /// user callback on the host buffer.
    /// If the copy or the synchronization fails, the error is logged and the
    /// callback is not invoked, so it never sees stale or partial data.
    fn hostBufferCallback(opaque_stream: *anyopaque, buffers: [*]*anyopaque, args: [*]const u8, args_len: usize) callconv(.C) void {
        const stream: *Stream = @ptrCast(opaque_stream);
        const src: *anyopaque = buffers[0];
        const callback, const ctx = getContext(args, args_len);

        // Add synchronization because this is called from the device driver.
        ctx.mutex.lock();
        defer ctx.mutex.unlock();

        const host_dst: []u8 = @constCast(ctx.host.data);
        const memcpy_result = cuda.runtime.memcpyAsync(host_dst.ptr, src, host_dst.len, .device_to_host, stream);
        if (memcpy_result != cuda_success) {
            log.err("zmlHostBufferCallback: cudaMemcpyAsync failed with CUDA error code {d}", .{memcpy_result});
            return;
        }

        const synchronize_result = cuda.runtime.streamSynchronize(stream);
        if (synchronize_result != cuda_success) {
            log.err("zmlHostBufferCallback: cudaStreamSynchronize failed with CUDA error code {d}", .{synchronize_result});
            return;
        }

        callback(ctx.host);
    }
};
|
||||||
|
|||||||
59
zml/ops.zig
59
zml/ops.zig
@ -1,25 +1,26 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const mlir = @import("mlir.zig");
|
const mlir = @import("mlir.zig");
|
||||||
const buffer = @import("buffer.zig");
|
|
||||||
|
|
||||||
const helpers = @import("helpers.zig");
|
const helpers = @import("helpers.zig");
|
||||||
const module = @import("module.zig");
|
const module = @import("module.zig");
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
|
|
||||||
|
const Buffer = @import("buffer.zig").Buffer;
|
||||||
const CompilationContext = module.CompilationContext;
|
const CompilationContext = module.CompilationContext;
|
||||||
const Tensor = @import("tensor.zig").Tensor;
|
const Context = @import("context.zig").Context;
|
||||||
const Shape = @import("shape.zig").Shape;
|
|
||||||
const Data = @import("dtype.zig").Data;
|
const Data = @import("dtype.zig").Data;
|
||||||
const DataType = @import("dtype.zig").DataType;
|
const DataType = @import("dtype.zig").DataType;
|
||||||
const Buffer = buffer.Buffer;
|
|
||||||
const EnumLiteral = @TypeOf(.enum_literal);
|
const EnumLiteral = @TypeOf(.enum_literal);
|
||||||
|
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||||
|
const Shape = @import("shape.zig").Shape;
|
||||||
|
const Tensor = @import("tensor.zig").Tensor;
|
||||||
|
|
||||||
const dialect = struct {
|
const dialect = struct {
|
||||||
const stablehlo = @import("mlir/dialects").stablehlo;
|
const stablehlo = @import("mlir/dialects").stablehlo;
|
||||||
};
|
};
|
||||||
|
|
||||||
const assert = std.debug.assert;
|
const assert = std.debug.assert;
|
||||||
const log = std.log.scoped(.zml_tensor);
|
const log = std.log.scoped(.zml);
|
||||||
|
|
||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
@ -621,7 +622,7 @@ pub fn identityCustomCall(name: [:0]const u8, input: Tensor, context: *anyopaque
|
|||||||
@memcpy(backend_config[0..8], address[0..8]);
|
@memcpy(backend_config[0..8], address[0..8]);
|
||||||
const ctx = CompilationContext.current();
|
const ctx = CompilationContext.current();
|
||||||
|
|
||||||
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "name={s}", .{name});
|
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "custom_call({s})", .{name});
|
||||||
|
|
||||||
const op = dialect.stablehlo.custom_call(
|
const op = dialect.stablehlo.custom_call(
|
||||||
ctx.mlirCtx(),
|
ctx.mlirCtx(),
|
||||||
@ -638,3 +639,49 @@ pub fn identityCustomCall(name: [:0]const u8, input: Tensor, context: *anyopaque
|
|||||||
);
|
);
|
||||||
return Tensor._result(input.shape(), op.result(0));
|
return Tensor._result(input.shape(), op.result(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Inserts a "zmlHostBufferCallback" custom call on `input`:
/// at runtime the tensor is materialized, copied to host,
/// and `callback` is called on the resulting `HostBuffer`.
///
/// On targets other than Cuda, or if the host buffer can't be pre-allocated,
/// `input` is returned untouched and no callback is inserted.
pub fn addHostCallback(
    callback: *const fn (HostBuffer) void,
    input: Tensor,
) Tensor {
    // TODO: implement addCallback that exposes a pjrt.Buffer, so that the user can decide if they need to copy.
    if (input.getContext().target() != .cuda) return input;

    const payload_len = input.byteSize();
    const CallbackCtx = Context.HostCallbackCtx;

    // A single allocation holds the host copy of the tensor followed by the callback context.
    // It is reserved now so that the runtime Buffer can be logged later during the computation.
    // This memory is leaked, we currently have no way to tie this lifetime to the lifetime of the module being compiled.
    const storage = std.heap.page_allocator.alignedAlloc(u8, 32, payload_len + 2 * @sizeOf(CallbackCtx)) catch {
        log.err("Failed to pre-allocate buffer to print {}.", .{input});
        return input;
    };
    const payload = storage[0..payload_len];
    const trailer = storage[payload_len..];

    // The context lives in the trailer so that it's still present at runtime.
    // Going through an fba places it at a correctly aligned offset;
    // the trailer is twice the size of the context to leave room for that.
    var trailer_fba = std.heap.FixedBufferAllocator.init(trailer);
    const callback_ctx = trailer_fba.allocator().create(CallbackCtx) catch unreachable;
    callback_ctx.* = .{ .host = HostBuffer.fromBytes(input.shape(), payload) };

    // The custom call receives the raw bytes of these two pointers as its backend config.
    const backend_config: [2:null]?*const anyopaque = .{ callback, callback_ctx };
    const ctx = CompilationContext.current();

    const loc = ctx.mlirCtx().location(@src());
    const op = dialect.stablehlo.custom_call(
        ctx.mlirCtx(),
        &.{input.value()},
        .{
            .api_version = 1,
            .has_side_effect = false,
            .call_target_name = "zmlHostBufferCallback",
            .backend_config = @ptrCast(std.mem.sliceAsBytes(&backend_config)),
            .output_operand_aliases = &.{0},
        },
        &.{input.value().getType()},
        loc,
    );
    return Tensor._result(input.shape(), op.result(0));
}
|
||||||
|
|||||||
@ -3629,6 +3629,27 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
}.binaryOpHelper;
|
}.binaryOpHelper;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Inserts code that prints the content of this tensor at runtime,
/// and returns the tensor produced by the inserted callback.
/// This is a debugging helper with the following limitations:
/// * only supported on Cuda atm
/// * only prints the first 1024 values
/// * pre allocates a buffer on the host to copy the content of the device buffer,
/// this buffer won't be freed. You will have one buffer per "print" call in the IR.
/// * does device to host synchronization so it will slow down the program execution.
pub fn print(input: Tensor) Tensor {
    const callback: *const fn (HostBuffer) void = &printCallback;
    return ops.addHostCallback(callback, input);
}
|
||||||
|
|
||||||
|
/// Host callback used by `print`: writes the shape and at most the first
/// 1024 items of `host_buffer` through `std.debug.print`.
fn printCallback(host_buffer: HostBuffer) void {
    const max_printed = 1024;
    switch (host_buffer.dtype()) {
        // `inline else` makes the dtype comptime-known, so the items can be
        // viewed with the matching Zig type.
        inline else => |dtype_tag| {
            const values = host_buffer.items(dtype_tag.toZigType());
            const shown = values[0..@min(values.len, max_printed)];
            std.debug.print("Device buffer: {}: {any}\n", .{ host_buffer.shape(), shown });
        },
    }
}
|
||||||
};
|
};
|
||||||
|
|
||||||
fn initPoolArg(rank: usize, data: []const i64) [Tensor.MAX_RANK]i64 {
|
fn initPoolArg(rank: usize, data: []const i64) [Tensor.MAX_RANK]i64 {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user