Add PJRT custom-call integration with a generic zmlHostBufferCallback that copies tensors to the host and invokes user callbacks. Introduce a Tensor.print() method to output runtime tensor values (CUDA-specific; uses a pre-allocated host buffer).

This commit is contained in:
Tarry Singh 2023-06-05 13:42:45 +00:00
parent bf23eef0d9
commit 6d720126ac
4 changed files with 185 additions and 6 deletions

View File

@ -139,6 +139,14 @@ pub const Api = struct {
.minor = @intCast(self.inner.pjrt_api_version.minor_version), .minor = @intCast(self.inner.pjrt_api_version.minor_version),
}; };
} }
/// Returns the GPU custom-call registry exposed by this PJRT plugin,
/// or null when the plugin does not provide the GPU custom-call extension.
pub fn customCallRegistry(api: *const Api) ?CustomCallRegistry {
    const ext = api.lookupExtension(c.PJRT_Gpu_Custom_Call, c.PJRT_Extension_Type_Gpu_Custom_Call) orelse return null;
    return .{ .inner = ext.custom_call.? };
}
}; };
pub const ErrorCode = enum(c.PJRT_Error_Code) { pub const ErrorCode = enum(c.PJRT_Error_Code) {
@ -854,3 +862,28 @@ pub const NamedValue = extern struct {
try writer.writeAll("}"); try writer.writeAll("}");
} }
}; };
/// Signature of a user-provided custom call. Arguments are:
/// * a pointer to a platform specific stream handle
/// * a pointer to an unspecified list of platform specific buffer handles
/// * a context struct passed as a slice of bytes (pointer + length)
pub const CustomCall = fn (*anyopaque, [*]*anyopaque, [*]const u8, usize) callconv(.C) void;

pub const CustomCallRegistry = extern struct {
    inner: *const c.PJRT_Gpu_Register_Custom_Call,

    /// Registers `func` under `name` through the PJRT GPU custom-call extension.
    /// On failure, logs the plugin-reported message and returns the mapped ApiError.
    pub fn register(self: *const CustomCallRegistry, api: *const Api, api_version: usize, name: []const u8, func: *const CustomCall) ApiError!void {
        var call_args = pjrtStruct(c.PJRT_Gpu_Register_Custom_Call_Args{
            .function_name = name.ptr,
            .function_name_size = name.len,
            .api_version = @intCast(api_version),
            // The C API takes a non-const void pointer; the cast is required at this boundary.
            .custom_call_function = @ptrCast(@constCast(func)),
        });
        const maybe_err = self.inner(&call_args);
        if (maybe_err) |pjrt_c_error| {
            const pjrt_error: *Error = @ptrCast(pjrt_c_error);
            log.err("[GpuRegisterCustomCall] {s}", .{pjrt_error.getMessage(api)});
            return pjrt_error.getCode(api).toApiError();
        }
    }
};

View File

@ -9,6 +9,7 @@ const platform = @import("platform.zig");
const pjrt = @import("pjrtx.zig"); const pjrt = @import("pjrtx.zig");
const available_targets = @import("platform.zig").available_targets; const available_targets = @import("platform.zig").available_targets;
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const Target = @import("platform.zig").Target; const Target = @import("platform.zig").Target;
const Platform = @import("platform.zig").Platform; const Platform = @import("platform.zig").Platform;
@ -90,6 +91,9 @@ pub const Context = struct {
log.err("No device found for platform {} !", .{target}); log.err("No device found for platform {} !", .{target});
continue; continue;
} }
if (target == .cuda) {
try cuda.registerZmlCustomCalls(p);
}
platforms.set(target, p); platforms.set(target, p);
num_platforms += 1; num_platforms += 1;
} }
@ -140,4 +144,78 @@ pub const Context = struct {
} }
return platform_ orelse @panic("No platform found !"); return platform_ orelse @panic("No platform found !");
} }
/// Runtime context handed to a host callback through the custom-call
/// backend_config bytes. The mutex serializes concurrent invocations
/// that reuse the same pre-allocated host buffer.
pub const HostCallbackCtx = struct {
    // Pre-allocated host destination the device buffer is copied into.
    host: HostBuffer,
    mutex: std.Thread.Mutex = std.Thread.Mutex{},
};
/// User callback invoked once the tensor content is available on the host.
pub const HostCallback = fn (HostBuffer) void;
};
const cuda = struct {
var runtime: Runtime = undefined;
/// Resolves the CUDA runtime symbols and registers zml's custom calls
/// with the PJRT GPU custom-call extension of `cuda_platform`.
/// Must only be called for a CUDA platform.
pub fn registerZmlCustomCalls(cuda_platform: Platform) !void {
    std.debug.assert(cuda_platform.target == .cuda);
    runtime = try Runtime.init();

    const registry = cuda_platform.pjrt_api.customCallRegistry().?;
    try registry.register(cuda_platform.pjrt_api, 0, "zmlHostBufferCallback", &hostBufferCallback);
}
/// Opaque handle to a platform stream (passed through to the CUDA runtime).
pub const Stream = opaque {};

/// Mirrors the CUDA runtime `cudaMemcpyKind` enum values
/// (passed to `cudaMemcpyAsync` below).
pub const MemcpyKind = enum(c_int) {
    host_to_host = 0,
    host_to_device = 1,
    device_to_host = 2,
    device_to_device = 3,
    default = 4,
};
/// Thin wrapper over the few libcudart entry points we need,
/// resolved at runtime via dlopen so zml has no link-time CUDA dependency.
pub const Runtime = struct {
    memcpyAsync: MemcpyAsync,
    streamSynchronize: StreamSynchronize,

    const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: *Stream) callconv(.C) c_int;
    const StreamSynchronize = *const fn (stream: *Stream) callconv(.C) c_int;

    /// Loads libcudart and resolves the required symbols.
    /// Returns error.NotFound when a symbol is missing.
    ///
    /// The library handle is intentionally never closed: closing it right
    /// after lookup (the previous `defer cudart.close()`) would dlclose
    /// libcudart, and if we held the last reference the library would be
    /// unloaded, leaving the resolved function pointers dangling. The handle
    /// is kept for the lifetime of the process instead.
    pub fn init() !Runtime {
        var cudart = try std.DynLib.open("libcudart.so");
        return .{
            .memcpyAsync = cudart.lookup(Runtime.MemcpyAsync, "cudaMemcpyAsync") orelse return error.NotFound,
            .streamSynchronize = cudart.lookup(Runtime.StreamSynchronize, "cudaStreamSynchronize") orelse return error.NotFound,
        };
    }
};
/// Decodes the custom-call backend_config bytes back into the user callback
/// and its context. Layout: two raw pointers packed back to back
/// (must match the `backend_config` written at compile time in ops.zig).
fn getContext(args: [*]const u8, args_len: usize) struct { *const Context.HostCallback, *Context.HostCallbackCtx } {
    const ptr_size = @sizeOf(*anyopaque);
    std.debug.assert(args_len == 2 * ptr_size);

    // `args` has no alignment guarantee: copy the bytes into an aligned
    // scratch array before reinterpreting them as addresses.
    var addresses: [2]usize = undefined;
    @memcpy(std.mem.asBytes(&addresses), args[0 .. 2 * ptr_size]);

    const fn_ptr: *const Context.HostCallback = @ptrFromInt(addresses[0]);
    const ctx_ptr: *Context.HostCallbackCtx = @ptrFromInt(addresses[1]);
    return .{ fn_ptr, ctx_ptr };
}
/// PJRT custom-call entry point ("zmlHostBufferCallback").
/// Copies the device buffer to the pre-allocated host buffer, synchronizes
/// the stream, then invokes the user callback with the host data.
fn hostBufferCallback(opaque_stream: *anyopaque, buffers: [*]*anyopaque, args: [*]const u8, args_len: usize) callconv(.C) void {
    const stream: *Stream = @ptrCast(opaque_stream);
    const src: *anyopaque = buffers[0];
    const callback, const ctx = getContext(args, args_len);
    // Add synchronization because this is called from the device driver.
    ctx.mutex.lock();
    defer ctx.mutex.unlock();
    const host_dst: []u8 = @constCast(ctx.host.data);
    // Previously the CUDA return codes were silently discarded and the user
    // callback ran even after a failed copy, handing it garbage bytes.
    // Check both calls and skip the callback on failure.
    const memcpy_result = cuda.runtime.memcpyAsync(host_dst.ptr, src, host_dst.len, .device_to_host, stream);
    if (memcpy_result != 0) {
        log.err("zmlHostBufferCallback: cudaMemcpyAsync failed with code {d}", .{memcpy_result});
        return;
    }
    const synchronize_result = cuda.runtime.streamSynchronize(stream);
    if (synchronize_result != 0) {
        log.err("zmlHostBufferCallback: cudaStreamSynchronize failed with code {d}", .{synchronize_result});
        return;
    }
    callback(ctx.host);
}
}; };

View File

@ -1,25 +1,26 @@
const std = @import("std"); const std = @import("std");
const mlir = @import("mlir.zig"); const mlir = @import("mlir.zig");
const buffer = @import("buffer.zig");
const helpers = @import("helpers.zig"); const helpers = @import("helpers.zig");
const module = @import("module.zig"); const module = @import("module.zig");
const meta = @import("meta.zig"); const meta = @import("meta.zig");
const Buffer = @import("buffer.zig").Buffer;
const CompilationContext = module.CompilationContext; const CompilationContext = module.CompilationContext;
const Tensor = @import("tensor.zig").Tensor; const Context = @import("context.zig").Context;
const Shape = @import("shape.zig").Shape;
const Data = @import("dtype.zig").Data; const Data = @import("dtype.zig").Data;
const DataType = @import("dtype.zig").DataType; const DataType = @import("dtype.zig").DataType;
const Buffer = buffer.Buffer;
const EnumLiteral = @TypeOf(.enum_literal); const EnumLiteral = @TypeOf(.enum_literal);
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const Shape = @import("shape.zig").Shape;
const Tensor = @import("tensor.zig").Tensor;
const dialect = struct { const dialect = struct {
const stablehlo = @import("mlir/dialects").stablehlo; const stablehlo = @import("mlir/dialects").stablehlo;
}; };
const assert = std.debug.assert; const assert = std.debug.assert;
const log = std.log.scoped(.zml_tensor); const log = std.log.scoped(.zml);
test { test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(@This());
@ -621,7 +622,7 @@ pub fn identityCustomCall(name: [:0]const u8, input: Tensor, context: *anyopaque
@memcpy(backend_config[0..8], address[0..8]); @memcpy(backend_config[0..8], address[0..8]);
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "name={s}", .{name}); const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "custom_call({s})", .{name});
const op = dialect.stablehlo.custom_call( const op = dialect.stablehlo.custom_call(
ctx.mlirCtx(), ctx.mlirCtx(),
@ -638,3 +639,49 @@ pub fn identityCustomCall(name: [:0]const u8, input: Tensor, context: *anyopaque
); );
return Tensor._result(input.shape(), op.result(0)); return Tensor._result(input.shape(), op.result(0));
} }
/// At runtime the given tensor will be materialized and copied to host,
/// and the callback will be called on it.
///
/// Limitations:
/// * no-op on non-CUDA targets (the tensor is returned unchanged);
/// * the host staging buffer is allocated here and never freed.
pub fn addHostCallback(
    callback: *const fn (HostBuffer) void,
    input: Tensor,
) Tensor {
    // TODO: implement addCallback that exposes a pjrt.Buffer, so that the user can decide if they need to copy.
    if (input.getContext().target() != .cuda) return input;
    const len = input.byteSize();
    // Reserve memory to be able to log the runtime Buffer later during the computation.
    // This memory is leaked, we currently have no way to tie this lifetime to the lifetime of the module being compiled.
    const HostCallbackCtx = Context.HostCallbackCtx;
    // One slice holds both the staging bytes [0..len] and the ctx struct (after len).
    // 2x sizeof(ctx) leaves slack so the FBA can align the ctx allocation.
    const full_data = std.heap.page_allocator.alignedAlloc(u8, 32, len + 2 * @sizeOf(HostCallbackCtx)) catch {
        log.err("Failed to pre-allocate buffer to print {}.", .{input});
        return input;
    };
    // Save the HostBuffer inside the same memory slice, so that it's still present at runtime.
    // Use an fba to have the stable buffer at an aligned offset.
    var fba = std.heap.FixedBufferAllocator.init(full_data[len..]);
    const stable_ctx_ptr = fba.allocator().create(HostCallbackCtx) catch unreachable;
    stable_ctx_ptr.* = .{
        .host = HostBuffer.fromBytes(input.shape(), full_data[0..len]),
    };
    // Two raw pointers packed back to back; this byte layout must match
    // what cuda.getContext in context.zig decodes at runtime.
    const backend_config: [2:null]?*const anyopaque = .{ callback, stable_ctx_ptr };
    const ctx = CompilationContext.current();
    const loc = ctx.mlirCtx().location(@src());
    const op = dialect.stablehlo.custom_call(
        ctx.mlirCtx(),
        &.{input.value()},
        .{
            .api_version = 1,
            // NOTE(review): has_side_effect = false looks suspicious for a call
            // whose whole point is an observable host side effect; the compiler
            // may be allowed to elide it if the result is unused — TODO confirm.
            .has_side_effect = false,
            .call_target_name = "zmlHostBufferCallback",
            .backend_config = @ptrCast(std.mem.sliceAsBytes(&backend_config)),
            // Output aliases input: the custom call passes the buffer through.
            .output_operand_aliases = &.{0},
        },
        &.{input.value().getType()},
        loc,
    );
    return Tensor._result(input.shape(), op.result(0));
}

View File

@ -3629,6 +3629,27 @@ pub const Tensor = struct {
} }
}.binaryOpHelper; }.binaryOpHelper;
} }
/// Insert code that will print the content of the given buffer at runtime.
/// Debug-only helper with the following limitations:
/// * only supported on Cuda atm
/// * only prints the first 1024 values
/// * pre allocates a host buffer (one per "print" call in the IR) to receive
///   the device content; that buffer is never freed
/// * forces a device-to-host synchronization, slowing down execution.
pub fn print(input: Tensor) Tensor {
    return ops.addHostCallback(&Tensor.printCallback, input);
}
/// Host-side callback for `print`: logs the shape and up to the first
/// 1024 values of the materialized buffer to stderr.
fn printCallback(host_buffer: HostBuffer) void {
    switch (host_buffer.dtype()) {
        inline else => |dt| {
            const values = host_buffer.items(dt.toZigType());
            const count = @min(values.len, 1024);
            std.debug.print("Device buffer: {}: {any}\n", .{ host_buffer.shape(), values[0..count] });
        },
    }
}
}; };
fn initPoolArg(rank: usize, data: []const i64) [Tensor.MAX_RANK]i64 { fn initPoolArg(rank: usize, data: []const i64) [Tensor.MAX_RANK]i64 {