256 lines
9.1 KiB
Zig
256 lines
9.1 KiB
Zig
const builtin = @import("builtin");
|
|
const c = @import("c");
|
|
const mlir = @import("mlir");
|
|
const runfiles = @import("runfiles");
|
|
const runtimes = @import("runtimes");
|
|
const std = @import("std");
|
|
const stdx = @import("stdx");
|
|
|
|
const platform = @import("platform.zig");
|
|
const pjrt = @import("pjrtx.zig");
|
|
|
|
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
|
const PjrtApiMap = std.EnumArray(Target, ?*const pjrt.Api);
|
|
const Platform = @import("platform.zig").Platform;
|
|
const PlatformsMap = std.EnumArray(Target, ?Platform);
|
|
const Target = @import("platform.zig").Target;
|
|
|
|
const available_targets = @import("platform.zig").available_targets;
|
|
const log = std.log.scoped(.@"zml/context");
|
|
|
|
/// Every program using ZML must start with a call to `zml.Context.init()`.
|
|
/// The ZML context contains global state to interact with the different
|
|
/// devices available on your system.
|
|
/// Note that the available runtimes depend on how the program was compiled.
|
|
/// For example you need to compile your program with `--//runtimes:cuda=true`
|
|
/// to have the CUDA runtime available.
|
|
pub const Context = struct {
    /// PJRT API table for each target. An entry stays `null` when the
    /// corresponding runtime could not be loaded by `apis_once`.
    var apis = PjrtApiMap.initFill(null);

    /// One-shot initializer that tries to load every runtime listed in
    /// `runtimes.Platform` and records the resulting PJRT API in `apis`.
    var apis_once = std.once(struct {
        fn call() void {
            inline for (comptime std.enums.values(runtimes.Platform)) |t| {
                if (runtimes.load(t)) |api| {
                    Context.apis.set(t, api);
                } else |err| {
                    // Loading is best-effort: a runtime that fails to load is
                    // left as `null`. Keep a trace so the reason is not lost.
                    log.debug("Runtime .{s} not loaded: {}", .{ @tagName(t), err });
                }
            }
        }
    }.call);

    /// One-shot registration of the MLIR "Transforms" passes.
    var mlir_once = std.once(struct {
        fn call() void {
            mlir.registerPasses("Transforms");
        }
    }.call);

    /// One-shot export of the runfiles environment variables into the process
    /// environment. Does nothing when `RUNFILES_MANIFEST_FILE` or
    /// `RUNFILES_DIR` is already set. Panics if the runfiles cannot be set up.
    var runfiles_once = std.once(struct {
        fn call_() !void {
            if (std.process.hasEnvVarConstant("RUNFILES_MANIFEST_FILE") or std.process.hasEnvVarConstant("RUNFILES_DIR")) {
                return;
            }

            // Everything allocated here is scratch memory released on exit.
            var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
            defer arena.deinit();
            const allocator = arena.allocator();

            var env_map = std.process.EnvMap.init(allocator);
            var r = (try runfiles.Runfiles.create(.{ .allocator = allocator })) orelse return;
            try r.environment(&env_map);

            var it = env_map.iterator();
            while (it.next()) |entry| {
                // `setenv` needs NUL-terminated strings.
                const key_z = try allocator.dupeZ(u8, entry.key_ptr.*);
                const value_z = try allocator.dupeZ(u8, entry.value_ptr.*);
                if (c.setenv(key_z.ptr, value_z.ptr, 1) != 0) {
                    log.warn("Unable to export runfiles variable {s}", .{key_z});
                }
            }
        }

        fn call() void {
            call_() catch @panic("Unable to init runfiles env");
        }
    }.call);

    /// Platforms successfully initialized for this context, indexed by target.
    platforms: PlatformsMap,

    /// Creates a ZML Context and returns it.
    ///
    /// Runs the process-wide one-shot initializers, then creates a `Platform`
    /// for every loaded PJRT API. Platforms that fail to initialize or that
    /// expose no device are logged and skipped.
    ///
    /// Returns `error.NoPlatformAvailable` when no platform could be created,
    /// or the error from registering the CUDA custom calls. On error, every
    /// platform already stored in the map is released.
    pub fn init() !Context {
        Context.runfiles_once.call();
        Context.apis_once.call();
        Context.mlir_once.call();

        var platforms = PlatformsMap.initFill(null);
        // Do not leak already-created platforms if a later step fails.
        errdefer deinitPlatforms(&platforms);

        var num_platforms: u8 = 0;
        var it = Context.apis.iterator();
        while (it.next()) |entry| {
            const api = entry.value.* orelse continue;
            const target = entry.key;

            const p = Platform.init(target, api) catch |err| {
                log.err("Failed to load platform .{s}: {}", .{ @tagName(target), err });
                continue;
            };
            if (p.getDevices().len == 0) {
                log.err("No device found for platform {} !", .{target});
                continue;
            }

            // Store the platform before registering custom calls so that the
            // `errdefer` above also releases it when registration fails.
            platforms.set(target, p);
            num_platforms += 1;

            if (target == .cuda) {
                try cuda.registerZmlCustomCalls(p);
            }
        }

        if (num_platforms == 0) {
            log.err("No platform available", .{});
            return error.NoPlatformAvailable;
        }

        return .{
            .platforms = platforms,
        };
    }

    /// Calls `deinit` on every platform present in `platforms`.
    fn deinitPlatforms(platforms: *PlatformsMap) void {
        var it = platforms.iterator();
        while (it.next()) |entry| {
            if (entry.value.*) |*p| {
                p.deinit();
            }
        }
    }

    /// Returns the file name of the PJRT plugin library for `target`,
    /// using the shared-library extension of the OS being compiled for.
    fn platformToLibrary(comptime target: Target) []const u8 {
        const ext = switch (builtin.os.tag) {
            .windows => ".dll",
            .macos, .ios, .watchos => ".dylib",
            else => ".so",
        };
        return "libpjrt_" ++ @tagName(target) ++ ext;
    }

    /// Returns the PJRT API loaded for `target`.
    ///
    /// The entry is unwrapped with `.?`, so the caller must only ask for a
    /// target whose runtime was loaded.
    pub fn pjrtApi(target: Target) *const pjrt.Api {
        return Context.apis.get(target).?;
    }

    /// Releases every platform owned by this context and invalidates `self`.
    pub fn deinit(self: *Context) void {
        deinitPlatforms(&self.platforms);
        self.* = undefined;
    }

    /// Automatically selects the best Platform loaded in the current Context.
    ///
    /// For example, if supported, this will select a platform corresponding to an accelerator (GPU, TPU, ...).
    /// Panics when the context holds no platform.
    pub fn autoPlatform(self: *Context) Platform {
        // The map is walked in enum order and the last platform found wins:
        // the platform with the highest enum value is considered the "best" one.
        var platform_: ?Platform = null;
        var iterator = self.platforms.iterator();
        while (iterator.next()) |entry| {
            if (entry.value.*) |p| {
                platform_ = p;
            }
        }
        return platform_ orelse @panic("No platform found !");
    }

    /// Logs every target in `platform.available_targets`, marking the one
    /// matching `selected`, followed by the devices of each target that has
    /// a platform in this context.
    pub fn printAvailablePlatforms(self: Context, selected: platform.Platform) void {
        // List available targets
        log.info("Available Platforms:", .{});
        const selected_prefix = "✅";
        const not_selected_prefix = "• ";
        const selected_postfix = "(AUTO-SELECTED)";
        const not_selected_postfix = "";

        for (platform.available_targets) |target| {
            const is_selected = target == selected.target;
            log.info(" {s} {s} {s}", .{
                if (is_selected) selected_prefix else not_selected_prefix,
                @tagName(target),
                if (is_selected) selected_postfix else not_selected_postfix,
            });

            // now the platform's devices
            if (self.platforms.get(target)) |pfm| {
                for (pfm.getDevices(), 0..) |device, index| {
                    const device_kind = device.getDescription(pfm.pjrt_api).getKind(pfm.pjrt_api);
                    log.info(" ◦ #{d}: {s}", .{
                        index,
                        device_kind,
                    });
                    // we only list 1 CPU device
                    if (target == .cpu) break;
                }
            }
        }
    }

    /// State handed to a host callback: the host buffer that receives the
    /// data and a mutex serializing the invocations that share it.
    pub const HostCallbackCtx = struct {
        host: HostBuffer,
        mutex: std.Thread.Mutex = .{},
    };

    /// Signature of a user function invoked with the filled host buffer.
    pub const HostCallback = fn (HostBuffer) void;
};
|
|
|
|
/// CUDA-specific glue: loads the few CUDA runtime entry points ZML needs and
/// registers the custom call used to hand device buffers to host callbacks.
const cuda = struct {
    /// Logger for this namespace, using the same scope as the rest of the file.
    const cuda_log = std.log.scoped(.@"zml/context");

    /// CUDA runtime entry points used by `hostBufferCallback`.
    /// Undefined until `registerZmlCustomCalls` has run.
    var runtime: Runtime = undefined;

    /// Loads the CUDA runtime and registers `zmlHostBufferCallback` in the
    /// custom call registry of `cuda_platform`.
    ///
    /// Returns `error.CustomCallRegistryUnavailable` when the PJRT API exposes
    /// no custom call registry, or any error from loading the CUDA runtime or
    /// from the registration itself.
    pub fn registerZmlCustomCalls(cuda_platform: Platform) !void {
        std.debug.assert(cuda_platform.target == .cuda);

        const api = cuda_platform.pjrt_api;
        // Look the registry up first so that nothing is loaded when it is missing.
        const registry = api.customCallRegistry() orelse return error.CustomCallRegistryUnavailable;

        var loaded = try Runtime.init();
        errdefer loaded.library.close();

        // Publish the entry points before the callback becomes reachable.
        cuda.runtime = loaded;
        try registry.register(api, 0, "zmlHostBufferCallback", &hostBufferCallback);
    }

    /// Opaque CUDA stream handle.
    pub const Stream = opaque {};

    /// Copy direction passed to `cudaMemcpyAsync`.
    pub const MemcpyKind = enum(c_int) {
        host_to_host = 0,
        host_to_device = 1,
        device_to_host = 2,
        device_to_device = 3,
        default = 4,
    };

    /// Function pointers resolved from the CUDA runtime shared library.
    pub const Runtime = struct {
        /// Handle of the library the pointers below were resolved from.
        /// It must stay open for as long as those pointers may be called.
        library: std.DynLib,
        memcpyAsync: MemcpyAsync,
        streamSynchronize: StreamSynchronize,

        const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: *Stream) callconv(.C) c_int;
        const StreamSynchronize = *const fn (stream: *Stream) callconv(.C) c_int;

        /// Opens `libcudart.so` and resolves the required symbols.
        ///
        /// Returns the error from opening the library, or `error.NotFound`
        /// when a symbol is missing. The library is closed on failure only:
        /// on success the handle is kept in the returned value so the
        /// resolved function pointers remain valid.
        pub fn init() !Runtime {
            var cudart = try std.DynLib.open("libcudart.so");
            errdefer cudart.close();

            const memcpy_async = cudart.lookup(MemcpyAsync, "cudaMemcpyAsync") orelse return error.NotFound;
            const stream_synchronize = cudart.lookup(StreamSynchronize, "cudaStreamSynchronize") orelse return error.NotFound;

            return .{
                .library = cudart,
                .memcpyAsync = memcpy_async,
                .streamSynchronize = stream_synchronize,
            };
        }
    };

    /// Decodes the opaque custom call arguments: two native pointers stored
    /// back to back, first the callback function, then its context.
    fn getContext(args: [*]const u8, args_len: usize) struct { *const Context.HostCallback, *Context.HostCallbackCtx } {
        const ptr_size = @sizeOf(*anyopaque);
        std.debug.assert(args_len == ptr_size * 2);

        const raw_fn_ptr: usize = @bitCast(args[0..ptr_size].*);
        const fn_ptr: *const Context.HostCallback = @ptrFromInt(raw_fn_ptr);

        const raw_ctx_ptr: usize = @bitCast(args[ptr_size..][0..ptr_size].*);
        const ctx_ptr: *Context.HostCallbackCtx = @ptrFromInt(raw_ctx_ptr);

        return .{ fn_ptr, ctx_ptr };
    }

    /// Custom call entry point: copies the first device buffer into the host
    /// buffer of the callback context, waits for the copy to complete, then
    /// invokes the user callback with that host buffer.
    ///
    /// When the copy or the synchronization reports a failure, the error is
    /// logged and the callback is not invoked, since the host buffer content
    /// would not be valid.
    fn hostBufferCallback(opaque_stream: *anyopaque, buffers: [*]*anyopaque, args: [*]const u8, args_len: usize) callconv(.C) void {
        const stream: *Stream = @ptrCast(opaque_stream);
        const src: *anyopaque = buffers[0];
        const callback, const ctx = getContext(args, args_len);

        // Add synchronization because this is called from the device driver.
        ctx.mutex.lock();
        defer ctx.mutex.unlock();

        const host_dst: []u8 = @constCast(ctx.host.data);

        // CUDA runtime calls return 0 (cudaSuccess) on success.
        const memcpy_status = cuda.runtime.memcpyAsync(host_dst.ptr, src, host_dst.len, .device_to_host, stream);
        if (memcpy_status != 0) {
            cuda_log.err("cudaMemcpyAsync failed with CUDA error code {d}", .{memcpy_status});
            return;
        }

        const synchronize_status = cuda.runtime.streamSynchronize(stream);
        if (synchronize_status != 0) {
            cuda_log.err("cudaStreamSynchronize failed with CUDA error code {d}", .{synchronize_status});
            return;
        }

        callback(ctx.host);
    }
};
|