Fix CUDA and ROCm sandbox discovery, update epoll libxev patch to prevent high CPU usage, enable XLA GPU latency‑hiding scheduler, and upgrade cuDNN to 9.6.0.

This commit is contained in:
Tarry Singh 2024-01-15 09:41:42 +00:00
parent 5b8e42f9a9
commit 434cee3a6c
21 changed files with 4157 additions and 8602 deletions

View File

@ -71,7 +71,7 @@ cuda = use_extension("//runtimes/cuda:cuda.bzl", "cuda_packages")
use_repo(cuda, "libpjrt_cuda") use_repo(cuda, "libpjrt_cuda")
rocm = use_extension("//runtimes/rocm:rocm.bzl", "rocm_packages") rocm = use_extension("//runtimes/rocm:rocm.bzl", "rocm_packages")
use_repo(rocm, "libpjrt_rocm") use_repo(rocm, "libpjrt_rocm", "hipblaslt", "rocblas")
tpu = use_extension("//runtimes/tpu:tpu.bzl", "tpu_packages") tpu = use_extension("//runtimes/tpu:tpu.bzl", "tpu_packages")
use_repo(tpu, "libpjrt_tpu") use_repo(tpu, "libpjrt_tpu")
@ -84,7 +84,7 @@ use_repo(zls, "zls_aarch64-macos", "zls_x86_64-linux")
register_toolchains("//third_party/zls:all") register_toolchains("//third_party/zls:all")
bazel_dep(name = "libxev", version = "20241119.0-6afcde9") bazel_dep(name = "libxev", version = "20241208.0-db6a52b")
bazel_dep(name = "llvm-raw", version = "20250102.0-f739aa4") bazel_dep(name = "llvm-raw", version = "20250102.0-f739aa4")
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm") llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")

View File

@ -27,6 +27,8 @@ zig_library(
"//runtimes:cuda.enabled": [ "//runtimes:cuda.enabled": [
":libpjrt_cuda", ":libpjrt_cuda",
"//async", "//async",
"//stdx",
"@rules_zig//zig/runfiles",
], ],
"//conditions:default": [":empty"], "//conditions:default": [":empty"],
}), }),

View File

@ -5,7 +5,7 @@ load("//bazel:http_deb_archive.bzl", "http_deb_archive")
ARCH = "linux-x86_64" ARCH = "linux-x86_64"
CUDA_VERSION = "12.6.3" CUDA_VERSION = "12.6.3"
CUDNN_VERSION = "9.5.1" CUDNN_VERSION = "9.6.0"
def _filegroup(name, srcs): def _filegroup(name, srcs):
return """\ return """\

View File

@ -1,7 +1,12 @@
const builtin = @import("builtin"); const std = @import("std");
const asynk = @import("async"); const asynk = @import("async");
const pjrt = @import("pjrt"); const bazel_builtin = @import("bazel_builtin");
const builtin = @import("builtin");
const c = @import("c"); const c = @import("c");
const pjrt = @import("pjrt");
const runfiles = @import("runfiles");
const stdx = @import("stdx");
pub fn isEnabled() bool { pub fn isEnabled() bool {
return @hasDecl(c, "ZML_RUNTIME_CUDA"); return @hasDecl(c, "ZML_RUNTIME_CUDA");
@ -12,6 +17,23 @@ fn hasNvidiaDevice() bool {
return true; return true;
} }
/// Points XLA at the sandboxed CUDA toolkit before the PJRT plugin is loaded.
/// Resolves the Bazel runfile `libpjrt_cuda/sandbox` and prepends
/// `--xla_gpu_cuda_data_dir=<path>` to the `XLA_FLAGS` environment variable,
/// keeping any flags the user already set. Panics if the runfiles tree
/// cannot be located.
fn setupXlaGpuCudaDirFlag() !void {
    // All scratch allocations go through an arena freed on return; the env
    // var survives because setenv(3) copies its argument into the environment.
    var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
    defer arena.deinit();
    var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
        stdx.debug.panic("Unable to find CUDA directory", .{});
    };
    const source_repo = bazel_builtin.current_repository;
    const r = r_.withSourceRepo(source_repo);
    // NOTE(review): `.?` asserts the sandbox runfile exists; a missing entry
    // would indicate a packaging error, so crashing here is intentional —
    // confirm that is the desired failure mode.
    const cuda_data_dir = (try r.rlocationAlloc(arena.allocator(), "libpjrt_cuda/sandbox")).?;
    // Missing XLA_FLAGS is treated the same as empty; the new flag goes first.
    const xla_flags = std.process.getEnvVarOwned(arena.allocator(), "XLA_FLAGS") catch "";
    const new_xla_flagsZ = try std.fmt.allocPrintZ(arena.allocator(), "--xla_gpu_cuda_data_dir={s} {s}", .{ cuda_data_dir, xla_flags });
    _ = c.setenv("XLA_FLAGS", new_xla_flagsZ, 1);
}
pub fn load() !*const pjrt.Api { pub fn load() !*const pjrt.Api {
if (comptime !isEnabled()) { if (comptime !isEnabled()) {
return error.Unavailable; return error.Unavailable;
@ -23,5 +45,9 @@ pub fn load() !*const pjrt.Api {
return error.Unavailable; return error.Unavailable;
} }
// CUDA path has to be set _before_ loading the PJRT plugin.
// See https://github.com/openxla/xla/issues/21428
try setupXlaGpuCudaDirFlag();
return try pjrt.Api.loadFrom("libpjrt_cuda.so"); return try pjrt.Api.loadFrom("libpjrt_cuda.so");
} }

View File

@ -1,77 +0,0 @@
{
"release_date": "2024-10-25",
"release_label": "9.5.1",
"release_product": "cudnn",
"cudnn": {
"name": "NVIDIA CUDA Deep Neural Network library",
"license": "cudnn",
"license_path": "cudnn/LICENSE.txt",
"version": "9.5.1.17",
"linux-x86_64": {
"cuda11": {
"relative_path": "cudnn/linux-x86_64/cudnn-linux-x86_64-9.5.1.17_cuda11-archive.tar.xz",
"sha256": "b1f5050cd2bfd7fa9d3d0dd00d417cc2124692d8421295e12f841be6c8e3a426",
"md5": "5da3b0533fcd3d6a9020d08f3b78ddba",
"size": "736935276"
},
"cuda12": {
"relative_path": "cudnn/linux-x86_64/cudnn-linux-x86_64-9.5.1.17_cuda12-archive.tar.xz",
"sha256": "35dd20b9c68324ae1288ac36f66ab1f318d2bfecfafb703a82617aa283272be4",
"md5": "a8604f6b80f42ec60e98ba9c8f681572",
"size": "744697316"
}
},
"cuda_variant": [
"11",
"12"
],
"linux-sbsa": {
"cuda11": {
"relative_path": "cudnn/linux-sbsa/cudnn-linux-sbsa-9.5.1.17_cuda11-archive.tar.xz",
"sha256": "ad68d12ee351b5f3478078fc8188eefb8712721c3e501c9345ec5ffb0b85fae8",
"md5": "a9438457a47b2bca7951a19736e8d4e8",
"size": "735387008"
},
"cuda12": {
"relative_path": "cudnn/linux-sbsa/cudnn-linux-sbsa-9.5.1.17_cuda12-archive.tar.xz",
"sha256": "340c49b32c133b0321c5c5b00d14fb64887dcac83ee8fd24195d9191061f1ad7",
"md5": "83c9f3f9eddadd0c1941d7f3e763174c",
"size": "743147752"
}
},
"windows-x86_64": {
"cuda11": {
"relative_path": "cudnn/windows-x86_64/cudnn-windows-x86_64-9.5.1.17_cuda11-archive.zip",
"sha256": "8318e93ab017af2356d3b6cf35aab2238e2a51c426450842eb4ade12e4619bbb",
"md5": "b7c456ddab820ec335a724be7a969091",
"size": "554195447"
},
"cuda12": {
"relative_path": "cudnn/windows-x86_64/cudnn-windows-x86_64-9.5.1.17_cuda12-archive.zip",
"sha256": "3a4cecc8b6d6aa7f6777620e6f2c129b76be635357c4506f2c4ccdbe0e2a1641",
"md5": "fda9196a60fb8e2b4c78e8a19ff056a3",
"size": "557597538"
}
},
"linux-aarch64": {
"cuda12": {
"relative_path": "cudnn/linux-aarch64/cudnn-linux-aarch64-9.5.1.17_cuda12-archive.tar.xz",
"sha256": "0099b8e4081ac146f802e769cdd30d9e01a289ea0fd056e64e44297a13e1aa0c",
"md5": "9d20deeb313a05c442fbff036ca29581",
"size": "780854928"
}
}
},
"cudnn_samples": {
"name": "NVIDIA cuDNN samples",
"license": "cudnn",
"license_path": "cudnn_samples/LICENSE.txt",
"version": "9.5.1.17",
"source": {
"relative_path": "cudnn_samples/source/cudnn_samples-source-9.5.1.17-archive.tar.xz",
"sha256": "bb79dc528c6a3b2a019a60d4af13cb4cb3d56146b692b3f3badec3fd8bfc98e7",
"md5": "76fe86423261f1ae984b00b1de2e40f3",
"size": "1664836"
}
}
}

View File

@ -0,0 +1,77 @@
{
"release_date": "2024-12-02",
"release_label": "9.6.0",
"release_product": "cudnn",
"cudnn": {
"name": "NVIDIA CUDA Deep Neural Network library",
"license": "cudnn",
"license_path": "cudnn/LICENSE.txt",
"version": "9.6.0.74",
"linux-x86_64": {
"cuda11": {
"relative_path": "cudnn/linux-x86_64/cudnn-linux-x86_64-9.6.0.74_cuda11-archive.tar.xz",
"sha256": "9717b0022d4f5ea88ccd9796bef7ad1cc5d04b3bd53f690041767aabfb98d14d",
"md5": "aa68c3eda5ad616c0eead9f646d4885b",
"size": "655044868"
},
"cuda12": {
"relative_path": "cudnn/linux-x86_64/cudnn-linux-x86_64-9.6.0.74_cuda12-archive.tar.xz",
"sha256": "72595f0d17d952cf568c1d76e370a9c303bb08c2f80888a8cf33e316a65d46a8",
"md5": "16afe7a88f576525d45d9e00c4ac4277",
"size": "662367552"
}
},
"cuda_variant": [
"11",
"12"
],
"linux-sbsa": {
"cuda11": {
"relative_path": "cudnn/linux-sbsa/cudnn-linux-sbsa-9.6.0.74_cuda11-archive.tar.xz",
"sha256": "5f7440fd8269f7a7986bff89dd6924e4145644a94958ab49f146b6f8c0230d46",
"md5": "b125b9988b0c71592561163e6c64d8c6",
"size": "653894704"
},
"cuda12": {
"relative_path": "cudnn/linux-sbsa/cudnn-linux-sbsa-9.6.0.74_cuda12-archive.tar.xz",
"sha256": "f71fb008833fa92f9eac02c0b786a21f5e383470235ddeb1eee98fe370148ace",
"md5": "9a35d77abda9f279b3fe32023d3b2e47",
"size": "661271556"
}
},
"windows-x86_64": {
"cuda11": {
"relative_path": "cudnn/windows-x86_64/cudnn-windows-x86_64-9.6.0.74_cuda11-archive.zip",
"sha256": "388cce5d31919ef41a4231d21ed3fb7b60609d0bd6baaf6173cd9583a00b25c9",
"md5": "6a8c3bea2bad7ca83f356a7832f3671f",
"size": "492545214"
},
"cuda12": {
"relative_path": "cudnn/windows-x86_64/cudnn-windows-x86_64-9.6.0.74_cuda12-archive.zip",
"sha256": "65ca0f2d77a46de1def35e289780b8d8729ef2fa39cf8dd0c8448e381dd2978c",
"md5": "b969339363b43cc80f6184929a6633fa",
"size": "495911494"
}
},
"linux-aarch64": {
"cuda12": {
"relative_path": "cudnn/linux-aarch64/cudnn-linux-aarch64-9.6.0.74_cuda12-archive.tar.xz",
"sha256": "6f907bf97731d30ffd55dcc53fe8aa666b583b2c0c6b20e88c7341f98bb0b594",
"md5": "751f516ce47fb0e504b878bfc97176e4",
"size": "766134984"
}
}
},
"cudnn_samples": {
"name": "NVIDIA cuDNN samples",
"license": "cudnn",
"license_path": "cudnn_samples/LICENSE.txt",
"version": "9.6.0.74",
"source": {
"relative_path": "cudnn_samples/source/cudnn_samples-source-9.6.0.74-archive.tar.xz",
"sha256": "2cad2fb38ef359a1956daf73b6b4c0faf826a865fcc3dc791437ca83863a6cb9",
"md5": "8e46375025bdd14ba55a695e12ee3694",
"size": "1667016"
}
}
}

View File

@ -1,8 +1,23 @@
load("@rules_zig//zig:defs.bzl", "zig_library") load("@rules_zig//zig:defs.bzl", "zig_library")
filegroup( cc_library(
name = "zmlxrocm_srcs", name = "zmlxrocm_lib",
srcs = ["zmlxrocm.cc"], srcs = ["zmlxrocm.c"],
linkopts = [
"-lc",
"-ldl",
],
)
cc_shared_library(
name = "zmlxrocm_so",
shared_lib_name = "libzmlxrocm.so.0",
deps = [":zmlxrocm_lib"],
)
cc_import(
name = "zmlxrocm",
shared_library = ":zmlxrocm_so",
visibility = ["@libpjrt_rocm//:__subpackages__"], visibility = ["@libpjrt_rocm//:__subpackages__"],
) )
@ -37,6 +52,8 @@ zig_library(
"//runtimes:rocm.enabled": [ "//runtimes:rocm.enabled": [
":libpjrt_rocm", ":libpjrt_rocm",
"//async", "//async",
"//stdx",
"@rules_zig//zig/runfiles",
], ],
"//conditions:default": [":empty"], "//conditions:default": [":empty"],
}), }),

View File

@ -40,10 +40,8 @@ bytecode_select = rule(
}, },
) )
def if_gfx(gfx, value): def if_gfx(gfx, value):
return select({ return select({
"@zml//runtimes/rocm:_{}".format(gfx): value, "@zml//runtimes/rocm:_{}".format(gfx): value,
"//conditions:default": [], "//conditions:default": [],
}) })

View File

@ -7,7 +7,7 @@ string_list_flag(
build_setting_default = ["all"], build_setting_default = ["all"],
visibility = [ visibility = [
"@rocblas//:__subpackages__", "@rocblas//:__subpackages__",
"@hipblaslt-dev//:__subpackages__", "@hipblaslt//:__subpackages__",
], ],
) )
@ -38,7 +38,6 @@ cc_library(
"-lc", "-lc",
"-ldl", "-ldl",
], ],
deps = ["@bazel_tools//tools/cpp/runfiles"],
) )
cc_shared_library( cc_shared_library(
@ -58,18 +57,23 @@ cc_import(
":sandbox", ":sandbox",
"@rocblas//:runfiles", "@rocblas//:runfiles",
] + select({ ] + select({
":_hipblaslt": ["@hipblaslt-dev//:runfiles"], ":_hipblaslt": ["@hipblaslt//:runfiles"],
"//conditions:default": [], "//conditions:default": [],
}), }),
shared_library = "libpjrt_rocm.so", shared_library = "libpjrt_rocm.so",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":zmlxrocm",
"@comgr//:amd_comgr", "@comgr//:amd_comgr",
"@hip-runtime-amd//:amdhip", "@hip-runtime-amd//:amdhip",
"@hipblaslt", "@hipblaslt",
"@hsa-amd-aqlprofile//:hsa-amd-aqlprofile", "@hsa-amd-aqlprofile//:hsa-amd-aqlprofile",
"@hsa-rocr//:hsa-runtime", "@hsa-rocr//:hsa-runtime",
"@libdrm-amdgpu",
"@libdrm",
"@libelf",
"@libnuma",
"@libtinfo",
"@libzstd",
"@miopen-hip//:MIOpen", "@miopen-hip//:MIOpen",
"@rccl", "@rccl",
"@rocblas", "@rocblas",
@ -77,12 +81,7 @@ cc_import(
"@rocm-smi-lib//:rocm_smi", "@rocm-smi-lib//:rocm_smi",
"@rocprofiler-register", "@rocprofiler-register",
"@roctracer", "@roctracer",
"@libelf",
"@libdrm",
"@libnuma",
"@libzstd",
"@libdrm-amdgpu",
"@libtinfo",
"@zlib1g", "@zlib1g",
"@zml//runtimes/rocm:zmlxrocm",
], ],
) )

File diff suppressed because it is too large Load Diff

View File

@ -2,9 +2,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("//bazel:dpkg.bzl", "dpkg") load("//bazel:dpkg.bzl", "dpkg")
load("//bazel:http_deb_archive.bzl", "http_deb_archive") load("//bazel:http_deb_archive.bzl", "http_deb_archive")
ROCM_VERSION = "6.2.2" ROCM_VERSION = "6.3.1"
BASE_URL = "https://repo.radeon.com/rocm/apt/{}".format(ROCM_VERSION) BASE_URL = "https://repo.radeon.com/rocm/apt/{}".format(ROCM_VERSION)
STRIP_PREFIX = "opt/rocm-6.2.2" STRIP_PREFIX = "opt/rocm-6.3.1"
def pkg_kwargs(pkg): def pkg_kwargs(pkg):
return dict( return dict(
@ -127,9 +127,17 @@ cc_import(
) )
bytecode_select( bytecode_select(
name = "runfiles", name = "bytecodes",
bytecodes = glob(["lib/rocblas/library/*"]), bytecodes = glob(["lib/rocblas/library/*"]),
enabled_gfx = "@libpjrt_rocm//:gfx", enabled_gfx = "@libpjrt_rocm//:gfx",
)
filegroup(
name = "runfiles",
srcs = [
":bytecodes",
"lib/rocblas/library/TensileManifest.txt",
],
visibility = ["@libpjrt_rocm//:__subpackages__"], visibility = ["@libpjrt_rocm//:__subpackages__"],
) )
""", """,
@ -148,6 +156,7 @@ cc_import(
""", """,
"hipblaslt": """\ "hipblaslt": """\
load("@zml//bazel:cc_import.bzl", "cc_import") load("@zml//bazel:cc_import.bzl", "cc_import")
load("@zml//runtimes/rocm:gfx.bzl", "bytecode_select")
cc_import( cc_import(
name = "hipblaslt", name = "hipblaslt",
shared_library = "lib/libhipblaslt.so.0", shared_library = "lib/libhipblaslt.so.0",
@ -157,9 +166,6 @@ cc_import(
}, },
visibility = ["@libpjrt_rocm//:__subpackages__"], visibility = ["@libpjrt_rocm//:__subpackages__"],
) )
""",
"hipblaslt-dev": """\
load("@zml//runtimes/rocm:gfx.bzl", "bytecode_select")
bytecode_select( bytecode_select(
name = "bytecodes", name = "bytecodes",
@ -173,8 +179,9 @@ bytecode_select(
filegroup( filegroup(
name = "runfiles", name = "runfiles",
srcs = [ srcs = [
"lib/hipblaslt/library/hipblasltExtOpLibrary.dat",
":bytecodes", ":bytecodes",
"lib/hipblaslt/library/hipblasltExtOpLibrary.dat",
"lib/hipblaslt/library/TensileManifest.txt",
], ],
visibility = ["@libpjrt_rocm//:__subpackages__"], visibility = ["@libpjrt_rocm//:__subpackages__"],
) )

View File

@ -1,7 +1,12 @@
const builtin = @import("builtin"); const builtin = @import("builtin");
const std = @import("std");
const asynk = @import("async"); const asynk = @import("async");
const pjrt = @import("pjrt"); const bazel_builtin = @import("bazel_builtin");
const c = @import("c"); const c = @import("c");
const pjrt = @import("pjrt");
const runfiles = @import("runfiles");
const stdx = @import("stdx");
pub fn isEnabled() bool { pub fn isEnabled() bool {
return @hasDecl(c, "ZML_RUNTIME_ROCM"); return @hasDecl(c, "ZML_RUNTIME_ROCM");
@ -14,6 +19,44 @@ fn hasRocmDevices() bool {
return true; return true;
} }
/// Exports the environment variables the ROCm runtime libraries expect,
/// resolving each target through the Bazel runfiles tree. Must run before
/// the PJRT plugin is loaded so the libraries pick the values up.
/// Panics if the runfiles tree or any individual entry cannot be found.
fn setupRocmEnv() !void {
    // Arena-scoped scratch memory; setenv(3) copies the value, so freeing
    // everything on return is safe.
    var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
    defer arena.deinit();
    // Each tuple is: { env var name, runfiles path, use-parent-directory }.
    // When the third field is true the env var is set to the directory
    // containing the resolved file rather than the file itself.
    const paths = .{
        .{ "HIPBLASLT_EXT_OP_LIBRARY_PATH", "hipblaslt/lib/hipblaslt/library/hipblasltExtOpLibrary.dat", false },
        .{ "HIPBLASLT_TENSILE_LIBPATH", "hipblaslt/lib/hipblaslt/library/TensileManifest.txt", true },
        .{ "ROCBLAS_TENSILE_LIBPATH", "rocblas/lib/rocblas/library/TensileManifest.txt", true },
        .{ "ROCM_PATH", "libpjrt_rocm/sandbox", false },
    };
    const r = blk: {
        var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
            stdx.debug.panic("Unable to find Runfiles directory", .{});
        };
        const source_repo = bazel_builtin.current_repository;
        break :blk r_.withSourceRepo(source_repo);
    };
    // `inline for` is required: `paths` is a heterogeneous comptime tuple.
    inline for (paths) |path| {
        const name = path[0];
        const rpath = path[1];
        const dirname = path[2];
        // Treat a lookup error the same as "not found" and panic with context.
        var real_path = r.rlocationAlloc(arena.allocator(), rpath) catch null orelse {
            stdx.debug.panic("Unable to find " ++ name ++ " in " ++ bazel_builtin.current_repository, .{});
        };
        if (dirname) {
            real_path = std.fs.path.dirname(real_path) orelse {
                stdx.debug.panic("Unable to dirname on {s}", .{real_path});
            };
        }
        // dupeZ: setenv requires a NUL-terminated string.
        _ = c.setenv(name, try arena.allocator().dupeZ(u8, real_path), 1);
    }
}
pub fn load() !*const pjrt.Api { pub fn load() !*const pjrt.Api {
if (comptime !isEnabled()) { if (comptime !isEnabled()) {
return error.Unavailable; return error.Unavailable;
@ -25,5 +68,7 @@ pub fn load() !*const pjrt.Api {
return error.Unavailable; return error.Unavailable;
} }
try setupRocmEnv();
return try pjrt.Api.loadFrom("libpjrt_rocm.so"); return try pjrt.Api.loadFrom("libpjrt_rocm.so");
} }

View File

@ -1,39 +1,9 @@
#include <dlfcn.h> #include <dlfcn.h>
#include <errno.h> #include <errno.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h>
#include <fstream> void *zmlxrocm_dlopen(const char *filename, int flags) __attribute__((visibility("default")))
#include <iostream>
#include <string>
#include "tools/cpp/runfiles/runfiles.h"
static void setup_runfiles(int argc, char **argv) __attribute__((constructor))
{
using bazel::tools::cpp::runfiles::Runfiles;
auto runfiles = std::unique_ptr<Runfiles>(Runfiles::Create(argv[0], BAZEL_CURRENT_REPOSITORY));
auto HIPBLASLT_EXT_OP_LIBRARY_PATH =
runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library/hipblasltExtOpLibrary.dat");
if (HIPBLASLT_EXT_OP_LIBRARY_PATH != "")
{
setenv("HIPBLASLT_EXT_OP_LIBRARY_PATH", HIPBLASLT_EXT_OP_LIBRARY_PATH.c_str(), 1);
}
auto HIPBLASLT_TENSILE_LIBPATH = runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library");
if (HIPBLASLT_TENSILE_LIBPATH != "")
{
setenv("HIPBLASLT_TENSILE_LIBPATH", HIPBLASLT_TENSILE_LIBPATH.c_str(), 1);
}
auto ROCBLAS_TENSILE_LIBPATH = runfiles->Rlocation("rocblas/lib/rocblas/library");
setenv("ROCBLAS_TENSILE_LIBPATH", ROCBLAS_TENSILE_LIBPATH.c_str(), 1);
auto ROCM_PATH = runfiles->Rlocation("libpjrt_rocm/sandbox");
setenv("ROCM_PATH", ROCM_PATH.c_str(), 1);
}
extern "C" void *zmlxrocm_dlopen(const char *filename, int flags) __attribute__((visibility("default")))
{ {
if (filename != NULL) if (filename != NULL)
{ {

View File

@ -0,0 +1,7 @@
# Bzlmod manifest for the vendored libxev overlay; the version string pins
# the upstream commit (db6a52b) this snapshot was taken from.
module(
    name = "libxev",
    version = "20241208.0-db6a52b",
    compatibility_level = 1,
)

# libxev's overlay BUILD file compiles the sources with rules_zig.
bazel_dep(name = "rules_zig", version = "20240904.0-010da15")

View File

@ -0,0 +1,13 @@
# Overlay BUILD file for the vendored libxev: builds the upstream sources as
# a Zig library, using the overlay shim `main2.zig` as the module root so the
# Linux backend default can be overridden without patching upstream sources.
load("@rules_zig//zig:defs.bzl", "zig_library")

zig_library(
    name = "xev",
    # Upstream source layout: core, per-backend, per-OS, and watcher modules.
    srcs = glob([
        "src/*.zig",
        "src/backend/*.zig",
        "src/linux/*.zig",
        "src/watcher/*.zig",
    ]),
    # Root is the overlay shim, not upstream's src/main.zig.
    main = "main2.zig",
    visibility = ["//visibility:public"],
)

View File

@ -0,0 +1,7 @@
module(
name = "libxev",
version = "20241208.0-db6a52b",
compatibility_level = 1,
)
bazel_dep(name = "rules_zig", version = "20240904.0-010da15")

View File

@ -0,0 +1,22 @@
//! Overlay root module for the vendored libxev. Wraps upstream
//! `src/main.zig` so the Linux backend default can be chosen here
//! (epoll, matching the eventfd-wakeup patches this vendoring carries)
//! while still letting the application override it via `root.xev_options`.
const builtin = @import("builtin");
const root = @import("root");
const main = @import("src/main.zig");

// Re-export the upstream declarations callers reach through this root.
pub const ThreadPool = main.ThreadPool;
pub const stream = main.stream;

/// Compile-time configuration; override by declaring `pub const xev_options`
/// in the root module.
pub const Options = struct {
    // Linux default is epoll in this overlay (upstream offers other backends).
    linux_backend: main.Backend = .epoll,
};

pub const options: Options = if (@hasDecl(root, "xev_options")) root.xev_options else .{};

// Comptime backend selection per target OS; unsupported targets fail to compile.
const default: main.Backend = switch (builtin.os.tag) {
    .ios, .macos => .kqueue,
    .linux => options.linux_backend,
    .wasi => .wasi_poll,
    .windows => .iocp,
    else => @compileError("Unsupported OS"),
};

// Splice the selected backend's full API into this namespace, so importers
// of `xev` see the same surface as upstream's root module.
pub usingnamespace default.Api();

View File

@ -0,0 +1,159 @@
From 0d1c2f8258072148459d3114b9ccaf43c02e0958 Mon Sep 17 00:00:00 2001
From: Steeve Morin <steeve@zml.ai>
Date: Tue, 19 Nov 2024 16:14:14 +0100
Subject: [PATCH 1/2] backend/epoll: implement eventfd wakeup notification
Tries to mimic what happens in backend/kqueue.
Closes #4
---
src/backend/epoll.zig | 42 ++++++++++++++++++++++++++++++++++++++++++
1 file changed, 42 insertions(+)
diff --git a/src/backend/epoll.zig b/src/backend/epoll.zig
index ae4ec7d..f44d326 100644
--- a/src/backend/epoll.zig
+++ b/src/backend/epoll.zig
@@ -21,6 +21,12 @@ pub const Loop = struct {
fd: posix.fd_t,
+ /// The eventfd that this epoll queue always has a filter for. Writing
+ /// an empty message to this eventfd can be used to wake up the loop
+ /// at any time. Waking up the loop via this eventfd won't trigger any
+ /// particular completion, it just forces tick to cycle.
+ eventfd: xev.Async,
+
/// The number of active completions. This DOES NOT include completions that
/// are queued in the submissions queue.
active: usize = 0,
@@ -56,8 +62,12 @@ pub const Loop = struct {
} = .{},
pub fn init(options: xev.Options) !Loop {
+ var eventfd = try xev.Async.init();
+ errdefer eventfd.deinit();
+
var res: Loop = .{
.fd = try posix.epoll_create1(std.os.linux.EPOLL.CLOEXEC),
+ .eventfd = eventfd,
.thread_pool = options.thread_pool,
.thread_pool_completions = undefined,
.cached_now = undefined,
@@ -68,6 +78,7 @@ pub const Loop = struct {
pub fn deinit(self: *Loop) void {
posix.close(self.fd);
+ self.eventfd.deinit();
}
/// Run the event loop. See RunMode documentation for details on modes.
@@ -262,9 +273,26 @@ pub const Loop = struct {
// Initialize
if (!self.flags.init) {
self.flags.init = true;
+
if (self.thread_pool != null) {
self.thread_pool_completions.init();
}
+
+ var ev: linux.epoll_event = .{
+ .events = linux.EPOLL.IN | linux.EPOLL.RDHUP,
+ .data = .{ .ptr = 0 },
+ };
+ posix.epoll_ctl(
+ self.fd,
+ linux.EPOLL.CTL_ADD,
+ self.eventfd.fd,
+ &ev,
+ ) catch |err| {
+ // We reset initialization because we can't do anything
+ // safely unless we get this mach port registered!
+ self.flags.init = false;
+ return err;
+ };
}
// Submit all the submissions. We copy the submission queue so that
@@ -369,6 +397,10 @@ pub const Loop = struct {
// Process all our events and invoke their completion handlers
for (events[0..n]) |ev| {
+ // Zero data values are internal events that we do nothing
+ // on such as the eventfd wakeup.
+ if (ev.data.ptr == 0) continue;
+
const c: *Completion = @ptrFromInt(@as(usize, @intCast(ev.data.ptr)));
// We get the fd and mark this as in progress we can properly
@@ -415,6 +447,7 @@ pub const Loop = struct {
const pool = self.thread_pool orelse return error.ThreadPoolRequired;
// Setup our completion state so that thread_perform can do stuff
+ c.task_loop = self;
c.task_completions = &self.thread_pool_completions;
c.task = .{ .callback = Loop.thread_perform };
@@ -436,6 +469,14 @@ pub const Loop = struct {
// Add to our completion queue
c.task_completions.push(c);
+
+ // Wake up our main loop
+ c.task_loop.wakeup() catch {};
+ }
+
+ /// Sends an empty message to this loop's eventfd so that it wakes up.
+ fn wakeup(self: *Loop) !void {
+ try self.eventfd.notify();
}
fn start(self: *Loop, completion: *Completion) void {
@@ -800,6 +841,7 @@ pub const Completion = struct {
/// reliable way to get access to the loop and shouldn't be used
/// except internally.
task: ThreadPool.Task = undefined,
+ task_loop: *Loop = undefined,
task_completions: *Loop.TaskCompletionQueue = undefined,
task_result: Result = undefined,
From 38d4dbed71a732b0fc30c1181354ad9d53919402 Mon Sep 17 00:00:00 2001
From: Corentin Godeau <corentin@zml.ai>
Date: Tue, 14 Jan 2025 14:43:54 +0000
Subject: [PATCH 2/2] backend/epoll: read the wakeup eventfd to avoid being
awaken again
---
src/backend/epoll.zig | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/src/backend/epoll.zig b/src/backend/epoll.zig
index f44d326..f84c687 100644
--- a/src/backend/epoll.zig
+++ b/src/backend/epoll.zig
@@ -280,7 +280,7 @@ pub const Loop = struct {
var ev: linux.epoll_event = .{
.events = linux.EPOLL.IN | linux.EPOLL.RDHUP,
- .data = .{ .ptr = 0 },
+ .data = .{ .fd = self.eventfd.fd },
};
posix.epoll_ctl(
self.fd,
@@ -397,9 +397,12 @@ pub const Loop = struct {
// Process all our events and invoke their completion handlers
for (events[0..n]) |ev| {
- // Zero data values are internal events that we do nothing
- // on such as the eventfd wakeup.
- if (ev.data.ptr == 0) continue;
+ // Handle wakeup eventfd
+ if (ev.data.fd == self.eventfd.fd) {
+ var buffer: u64 = undefined;
+ _ = posix.read(self.eventfd.fd, std.mem.asBytes(&buffer)) catch {};
+ continue;
+ }
const c: *Completion = @ptrFromInt(@as(usize, @intCast(ev.data.ptr)));

View File

@ -0,0 +1,14 @@
{
"strip_prefix": "libxev-db6a52bafadf00360e675fefa7926e8e6c0e9931",
"url": "https://github.com/zml/libxev/archive/db6a52bafadf00360e675fefa7926e8e6c0e9931.tar.gz",
"integrity": "sha256-4GT5wkfkZnIjNv20yDiWEzHAhbIiwHHJfS7A4u/LoNQ=",
"overlay": {
"MODULE.bazel": "",
"BUILD.bazel": "",
"main2.zig": ""
},
"patches": {
"128.patch": ""
},
"patch_strip": 1
}

View File

@ -12,7 +12,9 @@
], ],
"versions": [ "versions": [
"20240825.0-dbe2291", "20240825.0-dbe2291",
"20240910.0-a2d9b31" "20240910.0-a2d9b31",
"20241119.0-6afcde9",
"20241208.0-db6a52b"
], ],
"yanked_versions": {} "yanked_versions": {}
} }

View File

@ -2,7 +2,6 @@ const std = @import("std");
const asynk = @import("async"); const asynk = @import("async");
const dialect = @import("mlir/dialects"); const dialect = @import("mlir/dialects");
const runfiles = @import("runfiles");
const stdx = @import("stdx"); const stdx = @import("stdx");
const xla_pb = @import("//xla:xla_proto"); const xla_pb = @import("//xla:xla_proto");
@ -902,10 +901,11 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
} }
} }
switch (platform.target) { switch (platform.target) {
.cuda => cuda_dir: { .cuda => {
// NVIDIA recommends to disable Triton GEMM on JAX: // NVIDIA recommends these settings
// https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables // https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables
setFlag(&options, "xla_gpu_enable_triton_gemm", false); setFlag(&options, "xla_gpu_enable_triton_gemm", false);
setFlag(&options, "xla_gpu_enable_latency_hiding_scheduler", true);
// setFlag(&options, "xla_gpu_enable_cudnn_fmha", true); // setFlag(&options, "xla_gpu_enable_cudnn_fmha", true);
// setFlag(&options, "xla_gpu_fused_attention_use_cudnn_rng", true); // setFlag(&options, "xla_gpu_fused_attention_use_cudnn_rng", true);
// setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true); // setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true);
@ -913,17 +913,6 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
// setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true); // setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true);
// setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true); // setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true);
// setFlag(&options, "xla_gpu_use_runtime_fusion", true); // setFlag(&options, "xla_gpu_use_runtime_fusion", true);
// setFlag(&options, "xla_gpu_enable_latency_hiding_scheduler", true);
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena }) orelse {
log.warn("Bazel runfile not found !", .{});
break :cuda_dir;
};
defer r_.deinit(arena);
const source_repo = @import("bazel_builtin").current_repository;
const r = r_.withSourceRepo(source_repo);
const cuda_data_dir = (try r.rlocationAlloc(arena, "libpjrt_cuda/sandbox")).?;
log.info("xla_gpu_cuda_data_dir: {s}", .{cuda_data_dir});
setFlag(&options, "xla_gpu_cuda_data_dir", cuda_data_dir);
}, },
.rocm => { .rocm => {
// Disable Triton GEMM on ROCM. For some reason it's much, much slower when // Disable Triton GEMM on ROCM. For some reason it's much, much slower when