Add example and Bazel build for zml.callback demonstrating a manual CUDA kernel invocation.
This commit is contained in:
parent
01da2184fe
commit
1fa056a790
11
examples/callback/BUILD.bazel
Normal file
11
examples/callback/BUILD.bazel
Normal file
@ -0,0 +1,11 @@
|
||||
load("@rules_zig//zig:defs.bzl", "zig_binary")
|
||||
|
||||
zig_binary(
|
||||
name = "callback",
|
||||
main = "main.zig",
|
||||
deps = [
|
||||
"@zml//async",
|
||||
"@zml//zml",
|
||||
"@zml//runtimes",
|
||||
],
|
||||
)
|
||||
218
examples/callback/main.zig
Normal file
218
examples/callback/main.zig
Normal file
@ -0,0 +1,218 @@
|
||||
const std = @import("std");
|
||||
|
||||
const asynk = @import("async");
|
||||
const runtimes = @import("runtimes");
|
||||
const zml = @import("zml");
|
||||
const cu = zml.platform_specific;
|
||||
|
||||
pub const std_options: std.Options = .{
|
||||
.log_level = .info,
|
||||
.logFn = asynk.logFn(std.log.defaultLog),
|
||||
};
|
||||
|
||||
const log = std.log.scoped(.@"examples/custom_call");
|
||||
|
||||
/// Demonstration of the custom_call mechanism.
|
||||
///
|
||||
/// * defines a function to compute the grayscale version of an image.
|
||||
/// * using simple for loop on CPU
|
||||
/// * using a custom PTX kernel on Cuda, with manual cuLaunchKernel
|
||||
/// * Use `zml.customCall` to create an executable calling our callback and our kernel.
|
||||
pub const GrayScale = struct {
|
||||
// Mandatory fields to work with ZML custom call api
|
||||
pub var type_id: zml.pjrt.ffi.TypeId = undefined;
|
||||
pub const callback_config: zml.callback.Config = .{};
|
||||
platform: zml.Platform,
|
||||
|
||||
// Custom field to store our cuda module and function.
|
||||
cu_data: [2]*anyopaque,
|
||||
|
||||
// Set by ZML before `call` is entered
|
||||
results: [1]zml.Buffer = undefined,
|
||||
stream: *zml.pjrt.Stream = undefined,
|
||||
|
||||
pub fn init(platform: zml.Platform) !GrayScale {
|
||||
var cu_data: [2]*anyopaque = undefined;
|
||||
if (comptime runtimes.isEnabled(.cuda)) {
|
||||
var module: cu.CUmodule = undefined;
|
||||
try cuda.check(cuda.moduleLoadData.?(&module, grayscale_ptx));
|
||||
log.info("Loaded Grayscale cuda module", .{});
|
||||
var function: cu.CUfunction = undefined;
|
||||
try cuda.check(cuda.moduleGetFunction.?(&function, module, "rgba_to_grayscale"));
|
||||
log.info("Found Grayscale cuda function", .{});
|
||||
cu_data = .{ module.?, function.? };
|
||||
}
|
||||
return .{
|
||||
.platform = platform,
|
||||
.cu_data = cu_data,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn deinit(self: *GrayScale) void {
|
||||
if (comptime runtimes.isEnabled(.cuda)) {
|
||||
const module: cu.CUmodule = @ptrCast(self.cu_data[0]);
|
||||
_ = cuda.moduleUnload.?(module);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn call(self: *GrayScale, rgb_d: zml.Buffer) !void {
|
||||
switch (self.platform.target) {
|
||||
.cpu => grayScaleCpu(rgb_d, self.results[0]),
|
||||
// Only try to compile `grayScaleCuda` if we have cuda symbols.
|
||||
.cuda => if (comptime runtimes.isEnabled(.cuda))
|
||||
try self.grayScaleCuda(rgb_d, self.results[0])
|
||||
else
|
||||
unreachable,
|
||||
else => @panic("Platform not supported"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn grayScaleCpu(rgb_d: zml.Buffer, gray_d: zml.Buffer) void {
|
||||
const rgb_h = rgb_d.asHostBuffer().items(u8);
|
||||
const gray_h = gray_d.asHostBuffer().mutItems(u8);
|
||||
|
||||
for (gray_h, 0..) |*gray, i| {
|
||||
const px = rgb_h[i * 3 .. i * 3 + 3];
|
||||
const R: u32 = @intCast(px[0]);
|
||||
const G: u32 = @intCast(px[1]);
|
||||
const B: u32 = @intCast(px[2]);
|
||||
const gray_u32: u32 = @divFloor(299 * R + 587 * G + 114 * B, 1000);
|
||||
gray.* = @intCast(gray_u32);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn grayScaleCuda(self: GrayScale, rgb_d: zml.Buffer, gray_d: zml.Buffer) !void {
|
||||
var args: [2][]u8 = .{
|
||||
rgb_d.opaqueDeviceMemoryDataPointer()[0..rgb_d.shape().byteSize()],
|
||||
gray_d.opaqueDeviceMemoryDataPointer()[0..gray_d.shape().byteSize()],
|
||||
};
|
||||
var args_ptr: [2:null]?*anyopaque = .{ @ptrCast(&args[0]), @ptrCast(&args[1]) };
|
||||
// This is a naive kernel with one block per pixel.
|
||||
try cuda.check(cuda.launchKernel.?(
|
||||
@ptrCast(self.cu_data[1]), // function
|
||||
@intCast(rgb_d.shape().count() / 3), // num blocks x
|
||||
1, // num blocks y
|
||||
1, // num blocks z
|
||||
1, // num grids x
|
||||
1, // num grids y
|
||||
1, // num grids z
|
||||
0, // shared mem
|
||||
@ptrCast(self.stream),
|
||||
&args_ptr,
|
||||
null,
|
||||
));
|
||||
// Note: no explicit synchronization, we just enqueue work in the stream.
|
||||
}
|
||||
|
||||
const cuda = struct {
|
||||
// Here we leverage ZML sandboxing to access cuda symbols and their definitions.
|
||||
const moduleLoadData = @extern(*const @TypeOf(cu.cuModuleLoadData), .{ .name = "cuModuleLoadData", .linkage = .weak });
|
||||
const moduleUnload = @extern(*const @TypeOf(cu.cuModuleUnload), .{ .name = "cuModuleUnload", .linkage = .weak });
|
||||
const moduleGetFunction = @extern(*const @TypeOf(cu.cuModuleGetFunction), .{ .name = "cuModuleGetFunction", .linkage = .weak });
|
||||
const launchKernel = @extern(*const @TypeOf(cu.cuLaunchKernel), .{ .name = "cuLaunchKernel", .linkage = .weak });
|
||||
|
||||
pub fn check(result: cu.CUresult) error{CudaError}!void {
|
||||
if (result == cu.CUDA_SUCCESS) return;
|
||||
std.log.err("cuda error: {}", .{result});
|
||||
return error.CudaError;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
pub fn grayscale(rgb: zml.Tensor) zml.Tensor {
|
||||
const gray_shape = rgb.shape().setDim(0, @divExact(rgb.dim(0), 3));
|
||||
const result = zml.callback.call(GrayScale, .{rgb.print()}, &.{gray_shape});
|
||||
return result[0];
|
||||
}
|
||||
|
||||
pub fn main() !void {
|
||||
try asynk.AsyncThread.main(std.heap.smp_allocator, asyncMain);
|
||||
}
|
||||
|
||||
pub fn asyncMain() !void {
|
||||
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
|
||||
defer _ = gpa.deinit();
|
||||
const allocator = gpa.allocator();
|
||||
|
||||
var context = try zml.Context.init();
|
||||
defer context.deinit();
|
||||
|
||||
const platform = context.autoPlatform(.{});
|
||||
context.printAvailablePlatforms(platform);
|
||||
|
||||
// Register our custom call
|
||||
try zml.callback.register(GrayScale, platform);
|
||||
|
||||
// Compile MLIR code containing our custom call.
|
||||
const rgb_shape = zml.Shape.init(.{12 * 3}, .u8);
|
||||
const exe = try zml.compileFn(allocator, grayscale, .{rgb_shape}, platform);
|
||||
defer exe.deinit();
|
||||
|
||||
// Provide runtime information needed by our custom call.
|
||||
var gray_op: GrayScale = try .init(platform);
|
||||
defer gray_op.deinit();
|
||||
try exe.bind(GrayScale, &gray_op);
|
||||
|
||||
// Load data and run the executable.
|
||||
const rgb_h: [12][3]u8 = @splat(.{ 0xFF, 0xAA, 0x00 });
|
||||
const rgb_d = try zml.Buffer.fromBytes(platform, rgb_shape, @ptrCast(&rgb_h));
|
||||
defer rgb_d.deinit();
|
||||
|
||||
var gray_d: zml.Buffer = exe.call(.{rgb_d});
|
||||
defer gray_d.deinit();
|
||||
|
||||
// Inspect results
|
||||
const gray_h = try gray_d.toHostAlloc(allocator);
|
||||
defer gray_h.deinit(allocator);
|
||||
|
||||
std.debug.print("Grayscale conversion of {any} -> {d}\n", .{ rgb_h, gray_h });
|
||||
try std.testing.expectEqualSlices(u8, &@as([12]u8, @splat(0xB0)), gray_h.items(u8));
|
||||
}
|
||||
|
||||
// Compiled with Zig and a little help from `https://github.com/gwenzek/cudaz`
|
||||
const grayscale_ptx =
|
||||
\\.version 4.0
|
||||
\\.target sm_32
|
||||
\\.address_size 64
|
||||
\\
|
||||
\\.visible .entry rgba_to_grayscale(
|
||||
\\ .param .align 8 .b8 rgba_to_grayscale_param_0[16],
|
||||
\\ .param .align 8 .b8 rgba_to_grayscale_param_1[16]
|
||||
\\)
|
||||
\\{
|
||||
\\ .reg .pred %p<2>;
|
||||
\\ .reg .b16 %rs<4>;
|
||||
\\ .reg .b32 %r<13>;
|
||||
\\ .reg .b64 %rd<9>;
|
||||
\\
|
||||
\\ ld.param.u64 %rd5, [rgba_to_grayscale_param_1+8];
|
||||
\\ mov.u32 %r1, %tid.x;
|
||||
\\ mov.u32 %r3, %ntid.x;
|
||||
\\ mov.u32 %r2, %ctaid.x;
|
||||
\\ mad.lo.s32 %r4, %r2, %r3, %r1;
|
||||
\\ cvt.u64.u32 %rd1, %r4;
|
||||
\\ setp.gt.u64 %p1, %rd5, %rd1;
|
||||
\\ @%p1 bra $L__BB0_2;
|
||||
\\ bra.uni $L__BB0_1;
|
||||
\\$L__BB0_2:
|
||||
\\ ld.param.u64 %rd4, [rgba_to_grayscale_param_1];
|
||||
\\ ld.param.u64 %rd2, [rgba_to_grayscale_param_0];
|
||||
\\ cvt.u32.u64 %r5, %rd1;
|
||||
\\ mul.lo.s32 %r6, %r5, 3;
|
||||
\\ cvt.u64.u32 %rd6, %r6;
|
||||
\\ add.s64 %rd7, %rd2, %rd6;
|
||||
\\ ld.u8 %rs1, [%rd7];
|
||||
\\ ld.u8 %rs2, [%rd7+1];
|
||||
\\ ld.u8 %rs3, [%rd7+2];
|
||||
\\ mul.wide.u16 %r7, %rs1, 299;
|
||||
\\ mul.wide.u16 %r8, %rs2, 587;
|
||||
\\ add.s32 %r9, %r8, %r7;
|
||||
\\ mul.wide.u16 %r10, %rs3, 114;
|
||||
\\ add.s32 %r11, %r9, %r10;
|
||||
\\ mul.hi.u32 %r12, %r11, 4294968;
|
||||
\\ add.s64 %rd8, %rd4, %rd1;
|
||||
\\ st.u8 [%rd8], %r12;
|
||||
\\$L__BB0_1:
|
||||
\\ ret;
|
||||
\\}
|
||||
;
|
||||
Loading…
Reference in New Issue
Block a user