2024-01-16 14:13:45 +00:00
|
|
|
const builtin = @import("builtin");
|
2024-01-15 09:41:42 +00:00
|
|
|
const std = @import("std");
|
2023-05-15 09:36:41 +00:00
|
|
|
const asynk = @import("async");
|
2024-01-15 09:41:42 +00:00
|
|
|
const pjrt = @import("pjrt");
|
2024-02-06 09:31:48 +00:00
|
|
|
const c = @import("c");
|
|
|
|
|
|
|
|
|
|
/// Directory string searched for inside `LD_LIBRARY_PATH` by
/// `hasCudaPathInLDPath`; `load` logs a warning when it is found there.
const nvidiaLibsPath = "/usr/local/cuda/lib64";
|
2023-05-15 09:36:41 +00:00
|
|
|
|
|
|
|
|
/// Reports whether the imported `c` namespace declares `ZML_RUNTIME_CUDA`.
///
/// The answer depends only on declarations, so it can be evaluated at
/// comptime (as `load` does).
pub fn isEnabled() bool {
    const cuda_runtime_declared = @hasDecl(c, "ZML_RUNTIME_CUDA");
    return cuda_runtime_declared;
}
|
|
|
|
|
|
|
|
|
|
/// Returns `true` when `/dev/nvidiactl` can be accessed read-only.
///
/// Any error from the access check is treated as "no device" and yields
/// `false`; the error itself is discarded.
/// NOTE(review): presumably `/dev/nvidiactl` is the NVIDIA driver's control
/// node — confirm.
fn hasNvidiaDevice() bool {
    const device_reachable = probe: {
        asynk.File.access("/dev/nvidiactl", .{ .mode = .read_only }) catch break :probe false;
        break :probe true;
    };
    return device_reachable;
}
|
|
|
|
|
|
2024-02-06 09:31:48 +00:00
|
|
|
/// Returns `true` when the `LD_LIBRARY_PATH` environment variable is set and
/// its value contains `nvidiaLibsPath` as a substring.
///
/// This is a plain substring search over the whole variable, not a
/// per-entry comparison of the `:`-separated list.
fn hasCudaPathInLDPath() bool {
    const ld_library_path = c.getenv("LD_LIBRARY_PATH");
    // `and` short-circuits, so the pointer is only spanned once it is known
    // to be non-null.
    return ld_library_path != null and
        std.mem.indexOf(u8, std.mem.span(ld_library_path), nvidiaLibsPath) != null;
}
|
|
|
|
|
|
2023-05-15 09:36:41 +00:00
|
|
|
/// Loads the CUDA PJRT plugin from `libpjrt_cuda.so`.
///
/// Returns `error.Unavailable` when any of the following holds:
///   - `isEnabled()` is false (checked at comptime),
///   - the target OS is not Linux (checked at comptime),
///   - `hasNvidiaDevice()` reports no device at runtime.
///
/// A warning is logged, without failing, when `hasCudaPathInLDPath()` is true.
/// Errors from `asynk.callBlocking` / `pjrt.Api.loadFrom` are propagated.
pub fn load() !*const pjrt.Api {
    // Both build-time conditions are resolved at comptime.
    if (comptime (!isEnabled() or builtin.os.tag != .linux)) return error.Unavailable;

    if (!hasNvidiaDevice()) return error.Unavailable;

    if (hasCudaPathInLDPath()) {
        std.log.warn("Detected {s} in LD_LIBRARY_PATH. This can lead to undefined behaviors and crashes", .{nvidiaLibsPath});
    }

    // The library load goes through `asynk.callBlocking`.
    const api = try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_cuda.so"});
    return api;
}
|