const std = @import("std");

const async = @import("async");
const runtimes = @import("runtimes");
const zml = @import("zml");
const cu = zml.platform_specific;

pub const std_options: std.Options = .{
    .log_level = .info,
    .logFn = async.logFn(std.log.defaultLog),
};

const log = std.log.scoped(.@"examples/custom_call");

/// Demonstration of the custom call mechanism.
///
/// * defines a function to compute the grayscale version of an image.
///   * using a simple for loop on CPU
///   * using a custom PTX kernel on Cuda, with manual cuLaunchKernel
/// * uses `zml.callback.call` to create an executable calling our callback and our kernel.
pub const GrayScale = struct {
    // Mandatory fields to work with ZML custom call api
    pub var type_id: zml.pjrt.ffi.TypeId = undefined;
    pub const callback_config: zml.callback.Config = .{};

    platform: zml.Platform,
    // Custom field to store our cuda module and function.
    // Only initialized when the cuda runtime is compiled in AND the platform targets cuda;
    // it is left `undefined` (and must not be read) otherwise.
    cu_data: [2]*anyopaque,

    // Set by ZML before `call` is entered
    results: [1]zml.Buffer = undefined,
    stream: *zml.pjrt.Stream = undefined,

    /// Creates the callback state for `platform`.
    ///
    /// On a cuda platform this loads the embedded PTX module and looks up the
    /// `rgba_to_grayscale` kernel. On any other platform no cuda call is made.
    /// Returns `error.CudaError` if a cuda driver call fails.
    pub fn init(platform: zml.Platform) !GrayScale {
        var cu_data: [2]*anyopaque = undefined;
        if (comptime runtimes.isEnabled(.cuda)) {
            // Having the cuda runtime compiled in does not mean we are running on it:
            // only touch the driver when the selected platform is actually cuda.
            if (platform.target == .cuda) {
                var module: cu.CUmodule = undefined;
                try cuda.check(cuda.moduleLoadData.?(&module, grayscale_ptx));
                // Don't leak the module if the function lookup below fails.
                errdefer _ = cuda.moduleUnload.?(module);
                log.info("Loaded Grayscale cuda module", .{});

                var function: cu.CUfunction = undefined;
                try cuda.check(cuda.moduleGetFunction.?(&function, module, "rgba_to_grayscale"));
                log.info("Found Grayscale cuda function", .{});

                cu_data = .{ module.?, function.? };
            }
        }
        return .{
            .platform = platform,
            .cu_data = cu_data,
        };
    }

    /// Unloads the cuda module loaded by `init`, if any.
    /// The result of the unload call is ignored: deinit cannot fail.
    pub fn deinit(self: *GrayScale) void {
        if (comptime runtimes.isEnabled(.cuda)) {
            // Mirrors the condition in `init`: `cu_data` is only valid on a cuda platform.
            if (self.platform.target == .cuda) {
                const module: cu.CUmodule = @ptrCast(self.cu_data[0]);
                _ = cuda.moduleUnload.?(module);
            }
        }
    }

    /// Callback entry point: converts `rgb_d` to grayscale into `self.results[0]`,
    /// dispatching on the platform target. Panics on platforms other than cpu and cuda.
    pub fn call(self: *GrayScale, rgb_d: zml.Buffer) !void {
        switch (self.platform.target) {
            .cpu => grayScaleCpu(rgb_d, self.results[0]),
            // Only try to compile `grayScaleCuda` if we have cuda symbols.
            .cuda => if (comptime runtimes.isEnabled(.cuda))
                try self.grayScaleCuda(rgb_d, self.results[0])
            else
                unreachable,
            else => @panic("Platform not supported"),
        }
    }

    /// CPU implementation: reads 3 consecutive bytes (R, G, B) per output byte and
    /// writes `(299 * R + 587 * G + 114 * B) / 1000`, rounded down.
    pub fn grayScaleCpu(rgb_d: zml.Buffer, gray_d: zml.Buffer) void {
        const rgb_h = rgb_d.asHostBuffer().items(u8);
        const gray_h = gray_d.asHostBuffer().mutItems(u8);
        for (gray_h, 0..) |*gray, i| {
            const px = rgb_h[i * 3 .. i * 3 + 3];
            // Widen to u32 so the weighted sum cannot overflow.
            const R: u32 = @intCast(px[0]);
            const G: u32 = @intCast(px[1]);
            const B: u32 = @intCast(px[2]);
            const gray_u32: u32 = @divFloor(299 * R + 587 * G + 114 * B, 1000);
            gray.* = @intCast(gray_u32);
        }
    }

    /// Cuda implementation: enqueues the `rgba_to_grayscale` kernel on `self.stream`.
    /// Returns `error.CudaError` if the launch fails.
    pub fn grayScaleCuda(self: GrayScale, rgb_d: zml.Buffer, gray_d: zml.Buffer) !void {
        // The kernel declares each parameter as a 16-byte blob and reads 64-bit values
        // at offset 0 (both params) and offset 8 (second param).
        // NOTE(review): this presumably relies on a Zig slice being laid out as
        // { ptr, len } — confirm if the slice representation ever changes.
        var args: [2][]u8 = .{
            rgb_d.opaqueDeviceMemoryDataPointer()[0..rgb_d.shape().byteSize()],
            gray_d.opaqueDeviceMemoryDataPointer()[0..gray_d.shape().byteSize()],
        };
        var args_ptr: [2:null]?*anyopaque = .{ @ptrCast(&args[0]), @ptrCast(&args[1]) };

        // This is a naive kernel with one block per pixel.
        try cuda.check(cuda.launchKernel.?(
            @ptrCast(self.cu_data[1]), // function
            @intCast(rgb_d.shape().count() / 3), // grid dim x: one block per pixel
            1, // grid dim y
            1, // grid dim z
            1, // block dim x: a single thread per block
            1, // block dim y
            1, // block dim z
            0, // shared mem
            @ptrCast(self.stream),
            &args_ptr,
            null,
        ));
        // Note: no explicit synchronization, we just enqueue work in the stream.
    }

    const cuda = struct {
        // Here we leverage ZML sandboxing to access cuda symbols and their definitions.
        // The symbols are weakly linked, hence the `.?` at every call site.
        const moduleLoadData = @extern(*const @TypeOf(cu.cuModuleLoadData), .{ .name = "cuModuleLoadData", .linkage = .weak });
        const moduleUnload = @extern(*const @TypeOf(cu.cuModuleUnload), .{ .name = "cuModuleUnload", .linkage = .weak });
        const moduleGetFunction = @extern(*const @TypeOf(cu.cuModuleGetFunction), .{ .name = "cuModuleGetFunction", .linkage = .weak });
        const launchKernel = @extern(*const @TypeOf(cu.cuLaunchKernel), .{ .name = "cuLaunchKernel", .linkage = .weak });

        /// Converts a cuda result code into a Zig error, logging the code on failure.
        pub fn check(result: cu.CUresult) error{CudaError}!void {
            if (result == cu.CUDA_SUCCESS) return;
            log.err("cuda error: {}", .{result});
            return error.CudaError;
        }
    };
};

/// Graph-level function: emits a call to the `GrayScale` callback.
/// The output shape is the input shape with dimension 0 divided by 3
/// (`@divExact`, so dimension 0 must be a multiple of 3).
pub fn grayscale(rgb: zml.Tensor) zml.Tensor {
    const gray_shape = rgb.shape().setDim(0, @divExact(rgb.dim(0), 3));
    const result = zml.callback.call(GrayScale, .{rgb.print()}, &.{gray_shape});
    return result[0];
}

pub fn main() !void {
    try async.AsyncThread.main(std.heap.smp_allocator, asyncMain);
}

/// Compiles `grayscale`, runs it on 12 identical pixels and checks the result.
pub fn asyncMain() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    var context = try zml.Context.init();
    defer context.deinit();

    const platform = context.autoPlatform(.{});
    context.printAvailablePlatforms(platform);

    // Register our custom call
    try zml.callback.register(GrayScale, platform);

    // Compile MLIR code containing our custom call.
    const rgb_shape = zml.Shape.init(.{12 * 3}, .u8);
    const exe = try zml.compileFn(allocator, grayscale, .{rgb_shape}, platform);
    defer exe.deinit();

    // Provide runtime information needed by our custom call.
    var gray_op: GrayScale = try .init(platform);
    defer gray_op.deinit();
    try exe.bind(GrayScale, &gray_op);

    // Load data and run the executable.
    const rgb_h: [12][3]u8 = @splat(.{ 0xFF, 0xAA, 0x00 });
    const rgb_d = try zml.Buffer.fromBytes(platform, rgb_shape, @ptrCast(&rgb_h));
    defer rgb_d.deinit();
    var gray_d: zml.Buffer = exe.call(.{rgb_d});
    defer gray_d.deinit();

    // Inspect results
    const gray_h = try gray_d.toHostAlloc(allocator);
    defer gray_h.deinit(allocator);
    std.debug.print("Grayscale conversion of {any} -> {d}\n", .{ rgb_h, gray_h });
    // (299 * 0xFF + 587 * 0xAA + 114 * 0x00) / 1000 == 176 == 0xB0
    try std.testing.expectEqualSlices(u8, &@as([12]u8, @splat(0xB0)), gray_h.items(u8));
}

// Compiled with Zig and a little help from `https://github.com/gwenzek/cudaz`
const grayscale_ptx =
    \\.version 4.0
    \\.target sm_32
    \\.address_size 64
    \\
    \\.visible .entry rgba_to_grayscale(
    \\ .param .align 8 .b8 rgba_to_grayscale_param_0[16],
    \\ .param .align 8 .b8 rgba_to_grayscale_param_1[16]
    \\)
    \\{
    \\ .reg .pred %p<2>;
    \\ .reg .b16 %rs<4>;
    \\ .reg .b32 %r<13>;
    \\ .reg .b64 %rd<9>;
    \\
    \\ ld.param.u64 %rd5, [rgba_to_grayscale_param_1+8];
    \\ mov.u32 %r1, %tid.x;
    \\ mov.u32 %r3, %ntid.x;
    \\ mov.u32 %r2, %ctaid.x;
    \\ mad.lo.s32 %r4, %r2, %r3, %r1;
    \\ cvt.u64.u32 %rd1, %r4;
    \\ setp.gt.u64 %p1, %rd5, %rd1;
    \\ @%p1 bra $L__BB0_2;
    \\ bra.uni $L__BB0_1;
    \\$L__BB0_2:
    \\ ld.param.u64 %rd4, [rgba_to_grayscale_param_1];
    \\ ld.param.u64 %rd2, [rgba_to_grayscale_param_0];
    \\ cvt.u32.u64 %r5, %rd1;
    \\ mul.lo.s32 %r6, %r5, 3;
    \\ cvt.u64.u32 %rd6, %r6;
    \\ add.s64 %rd7, %rd2, %rd6;
    \\ ld.u8 %rs1, [%rd7];
    \\ ld.u8 %rs2, [%rd7+1];
    \\ ld.u8 %rs3, [%rd7+2];
    \\ mul.wide.u16 %r7, %rs1, 299;
    \\ mul.wide.u16 %r8, %rs2, 587;
    \\ add.s32 %r9, %r8, %r7;
    \\ mul.wide.u16 %r10, %rs3, 114;
    \\ add.s32 %r11, %r9, %r10;
    \\ mul.hi.u32 %r12, %r11, 4294968;
    \\ add.s64 %rd8, %rd4, %rd1;
    \\ st.u8 [%rd8], %r12;
    \\$L__BB0_1:
    \\ ret;
    \\}
;