zml: add support for NVTX tracing

This commit is contained in:
Tarry Singh 2024-08-21 14:41:40 +00:00
parent a5e588f53b
commit 63ef78efcc
4 changed files with 63 additions and 3 deletions

View File

@ -10,6 +10,22 @@ def _kwargs(**kwargs):
def _cc_import(**kwargs): def _cc_import(**kwargs):
return """cc_import({})""".format(_kwargs(**kwargs)) return """cc_import({})""".format(_kwargs(**kwargs))
def _cc_import_glob_hdrs(name, hdrs_glob, shared_library, deps = []):
return """\
filegroup(
name = "{name}_files",
srcs = glob(["{hdrs_glob}"]),
visibility = ["//visibility:public"],
)
cc_import(
name = "{name}",
shared_library = {shared_library},
hdrs = [":{name}_files"],
deps = {deps},
visibility = ["//visibility:public"],
)
""".format(name = name, hdrs_glob = hdrs_glob, shared_library = repr(shared_library), deps = repr(deps))
def _filegroup(**kwargs): def _filegroup(**kwargs):
return """filegroup({})""".format(_kwargs(**kwargs)) return """filegroup({})""".format(_kwargs(**kwargs))
@ -59,6 +75,7 @@ common_apt_packages = module_extension(
packages = struct( packages = struct(
read = _read, read = _read,
cc_import = _cc_import, cc_import = _cc_import,
cc_import_glob_hdrs = _cc_import_glob_hdrs,
filegroup = _filegroup, filegroup = _filegroup,
load_ = _load, load_ = _load,
) )

View File

@ -25,6 +25,11 @@ CUDA_PACKAGES = {
name = "cupti", name = "cupti",
shared_library = "lib/libcupti.so.12", shared_library = "lib/libcupti.so.12",
), ),
"cuda_nvtx": packages.cc_import_glob_hdrs(
name = "nvtx",
hdrs_glob = "include/nvtx3/**/*.h",
shared_library = "lib/libnvToolsExt.so.1",
),
"libcufft": packages.cc_import( "libcufft": packages.cc_import(
name = "cufft", name = "cufft",
shared_library = "lib/libcufft.so.11", shared_library = "lib/libcufft.so.11",

View File

@ -35,6 +35,7 @@ cc_import(
":zmlxcuda", ":zmlxcuda",
"@cuda_cudart//:cudart", "@cuda_cudart//:cudart",
"@cuda_cupti//:cupti", "@cuda_cupti//:cupti",
"@cuda_nvtx//:nvtx",
"@cuda_nvcc//:nvptxcompiler", "@cuda_nvcc//:nvptxcompiler",
"@cuda_nvcc//:nvvm", "@cuda_nvcc//:nvvm",
"@cuda_nvrtc//:nvrtc", "@cuda_nvrtc//:nvrtc",

View File

@ -1,13 +1,50 @@
const builtin = @import("builtin"); const builtin = @import("builtin");
const c = @import("c");
pub const Tracer = switch (builtin.os.tag) { pub const Tracer = switch (builtin.os.tag) {
.macos => MacOsTracer, .macos => MacOsTracer,
.linux => if (@hasDecl(c, "ZML_RUNTIME_CUDA")) CudaTracer else FakeTracer,
else => FakeTracer, else => FakeTracer,
}; };
const MacOsTracer = struct { const CudaTracer = struct {
const c = @import("c"); extern fn cudaProfilerStart() c_int;
extern fn cudaProfilerStop() c_int;
extern fn nvtxMarkA(message: [*:0]const u8) void;
extern fn nvtxRangeStartA(message: [*:0]const u8) c_int;
extern fn nvtxRangeEnd(id: c_int) void;
pub fn init(name: [:0]const u8) CudaTracer {
_ = name;
_ = cudaProfilerStart();
return .{};
}
pub fn deinit(self: *const CudaTracer) void {
_ = self;
_ = cudaProfilerStop();
}
pub fn event(self: *const CudaTracer, message: [:0]const u8) void {
_ = self;
nvtxMarkA(message.ptr);
}
pub fn frameStart(self: *const CudaTracer, message: [:0]const u8) u64 {
_ = self;
return @intCast(nvtxRangeStartA(message.ptr));
}
pub fn frameEnd(self: *const CudaTracer, interval_id: u64, message: [:0]const u8) void {
_ = self;
_ = message;
nvtxRangeEnd(@intCast(interval_id));
return;
}
};
const MacOsTracer = struct {
logger: c.os_log_t, logger: c.os_log_t,
pub fn init(name: [:0]const u8) MacOsTracer { pub fn init(name: [:0]const u8) MacOsTracer {
@ -40,7 +77,7 @@ const FakeTracer = struct {
return .{}; return .{};
} }
pub fn event(self: *const MacOsTracer, message: [:0]const u8) void { pub fn event(self: *const FakeTracer, message: [:0]const u8) void {
_ = self; _ = self;
_ = message; _ = message;
return; return;