From 6d720126ace395a73151887d1b82fc2d7a66d591 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Mon, 5 Jun 2023 13:42:45 +0000 Subject: [PATCH] =?UTF-8?q?Add=20PJRT=20custom=20call=20integration=20with?= =?UTF-8?q?=20generic=20zmlHostBufferCallback=20to=20copy=20tensors=20to?= =?UTF-8?q?=20host=20and=20invoke=20user=20callbacks.=20Introduce=20Tensor?= =?UTF-8?q?.print()=20method=20to=20output=20runtime=20tensor=20values=20(?= =?UTF-8?q?CUDA=E2=80=91specific,=20uses=20a=20pre=E2=80=91allocated=20hos?= =?UTF-8?q?t=20buffer).?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pjrt/pjrt.zig | 33 +++++++++++++++++++++ zml/context.zig | 78 +++++++++++++++++++++++++++++++++++++++++++++++++ zml/ops.zig | 59 +++++++++++++++++++++++++++++++++---- zml/tensor.zig | 21 +++++++++++++ 4 files changed, 185 insertions(+), 6 deletions(-) diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index 30562f3..9ecf289 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -139,6 +139,14 @@ pub const Api = struct { .minor = @intCast(self.inner.pjrt_api_version.minor_version), }; } + + pub fn customCallRegistry(api: *const Api) ?CustomCallRegistry { + if (api.lookupExtension(c.PJRT_Gpu_Custom_Call, c.PJRT_Extension_Type_Gpu_Custom_Call)) |ext| { + return .{ .inner = ext.custom_call.? }; + } + // log.warn("No Custom Call registry found for platform: {}", .{self}); + return null; + } }; pub const ErrorCode = enum(c.PJRT_Error_Code) { @@ -854,3 +862,28 @@ pub const NamedValue = extern struct { try writer.writeAll("}"); } }; + +/// Custom call signature arguments are: +/// * a pointer to a platform specific stream handle +/// * a pointer to an unspecified list of platform specific buffer handle +/// * a context struct passed as a slice of bytes +pub const CustomCall = fn (*anyopaque, [*]*anyopaque, [*]const u8, usize) callconv(.C) void; + +pub const CustomCallRegistry = extern struct { + inner: *const c.PJRT_Gpu_Register_Custom_Call, + + pub fn register(self: *const CustomCallRegistry, api: *const Api, api_version: usize, name: []const u8, func: *const CustomCall) ApiError!void { + var ret = pjrtStruct(c.PJRT_Gpu_Register_Custom_Call_Args{ + .function_name = name.ptr, + .function_name_size = name.len, + .api_version = @intCast(api_version), + .custom_call_function = @ptrCast(@constCast(func)), + }); + const result = self.inner(&ret); + if (result) |pjrt_c_error| { + const pjrt_error: *Error = @ptrCast(pjrt_c_error); + log.err("[GpuRegisterCustomCall] {s}", .{pjrt_error.getMessage(api)}); + return pjrt_error.getCode(api).toApiError(); + } + } +}; diff --git a/zml/context.zig b/zml/context.zig index 30c48f3..5997cf8 100644 --- a/zml/context.zig +++ b/zml/context.zig @@ -9,6 +9,7 @@ const platform = @import("platform.zig"); const pjrt = @import("pjrtx.zig"); const available_targets = @import("platform.zig").available_targets; +const HostBuffer = @import("hostbuffer.zig").HostBuffer; const Target = @import("platform.zig").Target; const Platform = @import("platform.zig").Platform; @@ -90,6 +91,9 @@ pub const Context = struct { log.err("No device found for platform {} !", .{target}); continue; } + if (target == .cuda) { + try cuda.registerZmlCustomCalls(p); + } platforms.set(target, p); num_platforms += 1; } @@ -140,4 +144,78 @@ pub const Context = struct { } return platform_ orelse @panic("No platform found !"); } + + pub const HostCallbackCtx = struct { + host: HostBuffer, + mutex: std.Thread.Mutex = std.Thread.Mutex{}, + }; + pub const HostCallback = fn (HostBuffer) void; +}; + +const cuda = struct { + var runtime: Runtime = undefined; + + pub fn registerZmlCustomCalls(cuda_platform: Platform) !void { + std.debug.assert(cuda_platform.target == .cuda); + + cuda.runtime = try Runtime.init(); + const registry = cuda_platform.pjrt_api.customCallRegistry().?; + try registry.register(cuda_platform.pjrt_api, 0, "zmlHostBufferCallback", &hostBufferCallback); + } + + pub const Stream = opaque {}; + pub const MemcpyKind = enum(c_int) { + host_to_host = 0, + host_to_device = 1, + device_to_host = 2, + device_to_device = 3, + default = 4, + }; + + pub const Runtime = struct { + memcpyAsync: MemcpyAsync, + streamSynchronize: StreamSynchronize, + + const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: *Stream) callconv(.C) c_int; + const StreamSynchronize = *const fn (stream: *Stream) callconv(.C) c_int; + + pub fn init() !Runtime { + var cudart = try std.DynLib.open("libcudart.so"); + defer cudart.close(); + + return .{ + .memcpyAsync = cudart.lookup(Runtime.MemcpyAsync, "cudaMemcpyAsync") orelse return error.NotFound, + .streamSynchronize = cudart.lookup(Runtime.StreamSynchronize, "cudaStreamSynchronize") orelse return error.NotFound, + }; + } + }; + + fn getContext(args: [*]const u8, args_len: usize) struct { *const Context.HostCallback, *Context.HostCallbackCtx } { + std.debug.assert(args_len == @sizeOf(*anyopaque) * 2); + + const raw_fn_ptr: usize = @bitCast(args[0..@sizeOf(*anyopaque)].*); + const fn_ptr: *const Context.HostCallback = @ptrFromInt(raw_fn_ptr); + + const raw_ctx_ptr: usize = @bitCast(args[@sizeOf(*anyopaque)..][0..@sizeOf(*anyopaque)].*); + const ctx_ptr: *Context.HostCallbackCtx = @ptrFromInt(raw_ctx_ptr); + return .{ fn_ptr, ctx_ptr }; + } + + fn hostBufferCallback(opaque_stream: *anyopaque, buffers: [*]*anyopaque, args: [*]const u8, args_len: usize) callconv(.C) void { + const stream: *Stream = @ptrCast(opaque_stream); + const src: *anyopaque = buffers[0]; + const callback, const ctx = getContext(args, args_len); + + // Add synchronization because this is called from the device driver. + ctx.mutex.lock(); + defer ctx.mutex.unlock(); + + const host_dst: []u8 = @constCast(ctx.host.data); + const memcpy_result = cuda.runtime.memcpyAsync(host_dst.ptr, src, host_dst.len, .device_to_host, stream); + _ = memcpy_result; + const synchronize_result = cuda.runtime.streamSynchronize(stream); + _ = synchronize_result; + + callback(ctx.host); + } }; diff --git a/zml/ops.zig b/zml/ops.zig index 74c6060..8d65b18 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -1,25 +1,26 @@ const std = @import("std"); const mlir = @import("mlir.zig"); -const buffer = @import("buffer.zig"); const helpers = @import("helpers.zig"); const module = @import("module.zig"); const meta = @import("meta.zig"); +const Buffer = @import("buffer.zig").Buffer; const CompilationContext = module.CompilationContext; -const Tensor = @import("tensor.zig").Tensor; -const Shape = @import("shape.zig").Shape; +const Context = @import("context.zig").Context; const Data = @import("dtype.zig").Data; const DataType = @import("dtype.zig").DataType; -const Buffer = buffer.Buffer; const EnumLiteral = @TypeOf(.enum_literal); +const HostBuffer = @import("hostbuffer.zig").HostBuffer; +const Shape = @import("shape.zig").Shape; +const Tensor = @import("tensor.zig").Tensor; const dialect = struct { const stablehlo = @import("mlir/dialects").stablehlo; }; const assert = std.debug.assert; -const log = std.log.scoped(.zml_tensor); +const log = std.log.scoped(.zml); test { std.testing.refAllDecls(@This()); @@ -621,7 +622,7 @@ pub fn identityCustomCall(name: [:0]const u8, input: Tensor, context: *anyopaque @memcpy(backend_config[0..8], address[0..8]); const ctx = CompilationContext.current(); - const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "name={s}", .{name}); + const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "custom_call({s})", .{name}); const op = dialect.stablehlo.custom_call( ctx.mlirCtx(), @@ -638,3 +639,49 @@ pub fn identityCustomCall(name: [:0]const u8, input: Tensor, context: *anyopaque ); return Tensor._result(input.shape(), op.result(0)); } + +/// At runtime the given tensor will be materialized and copied to host, +/// and the callback will be called on it. +pub fn addHostCallback( + callback: *const fn (HostBuffer) void, + input: Tensor, +) Tensor { + // TODO: implement addCallback that exposes a pjrt.Buffer, so that the user can decide if they need to copy. + if (input.getContext().target() != .cuda) return input; + + const len = input.byteSize(); + // Reserve memory to be able to log the runtime Buffer later during the computation. + // This memory is leaked, we currently have no way to tie this lifetime to the lifetime of the module being compiled. + const HostCallbackCtx = Context.HostCallbackCtx; + const full_data = std.heap.page_allocator.alignedAlloc(u8, 32, len + 2 * @sizeOf(HostCallbackCtx)) catch { + log.err("Failed to pre-allocate buffer to print {}.", .{input}); + return input; + }; + + // Save the HostBuffer inside the same memory slice, so that it's still present at runtime. + // Use an fba to have the stable buffer at an aligned offset. + var fba = std.heap.FixedBufferAllocator.init(full_data[len..]); + const stable_ctx_ptr = fba.allocator().create(HostCallbackCtx) catch unreachable; + stable_ctx_ptr.* = .{ + .host = HostBuffer.fromBytes(input.shape(), full_data[0..len]), + }; + + const backend_config: [2:null]?*const anyopaque = .{ callback, stable_ctx_ptr }; + const ctx = CompilationContext.current(); + + const loc = ctx.mlirCtx().location(@src()); + const op = dialect.stablehlo.custom_call( + ctx.mlirCtx(), + &.{input.value()}, + .{ + .api_version = 1, + .has_side_effect = false, + .call_target_name = "zmlHostBufferCallback", + .backend_config = @ptrCast(std.mem.sliceAsBytes(&backend_config)), + .output_operand_aliases = &.{0}, + }, + &.{input.value().getType()}, + loc, + ); + return Tensor._result(input.shape(), op.result(0)); +} diff --git a/zml/tensor.zig b/zml/tensor.zig index fdb6875..82e5b6b 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -3629,6 +3629,27 @@ pub const Tensor = struct { } }.binaryOpHelper; } + + /// Insert code that will print the content of the given buffer at runtime. + /// Only for debug purpose, it has the following limitations: + /// * only supported on Cuda atm + /// * only prints the first 1024 values + /// * pre allocates a buffer on the host to copy the content of the device buffer, + /// this buffer won't be freed. You will have one buffer per "print" call in the IR. + /// * does device to host synchronization so it will slow down the program execution. + pub fn print(input: Tensor) Tensor { + return ops.addHostCallback(&printCallback, input); + } + + fn printCallback(host_buffer: HostBuffer) void { + switch (host_buffer.dtype()) { + inline else => |dt| { + const items = host_buffer.items(dt.toZigType()); + const n = @min(items.len, 1024); + std.debug.print("Device buffer: {}: {any}\n", .{ host_buffer.shape(), items[0..n] }); + }, + } + } }; fn initPoolArg(rank: usize, data: []const i64) [Tensor.MAX_RANK]i64 {