diff --git a/runtimes/common/packages.bzl b/runtimes/common/packages.bzl
index b5e20cf..bf30258 100644
--- a/runtimes/common/packages.bzl
+++ b/runtimes/common/packages.bzl
@@ -10,6 +10,27 @@ def _kwargs(**kwargs):
 def _cc_import(**kwargs):
     return """cc_import({})""".format(_kwargs(**kwargs))
 
+# Returns BUILD source declaring a public `<name>_files` filegroup that globs
+# `hdrs_glob`, followed by a public `cc_import` called `name` that exposes
+# those headers together with `shared_library` and `deps`.
+# The glob list, library path and deps are all rendered through repr() so the
+# generated text stays a valid Starlark literal whatever the values contain.
+def _cc_import_glob_hdrs(name, hdrs_glob, shared_library, deps = []):
+    return """\
+filegroup(
+    name = "{name}_files",
+    srcs = glob({hdrs_glob}),
+    visibility = ["//visibility:public"],
+)
+cc_import(
+    name = "{name}",
+    shared_library = {shared_library},
+    hdrs = [":{name}_files"],
+    deps = {deps},
+    visibility = ["//visibility:public"],
+)
+""".format(name = name, hdrs_glob = repr([hdrs_glob]), shared_library = repr(shared_library), deps = repr(deps))
+
 def _filegroup(**kwargs):
     return """filegroup({})""".format(_kwargs(**kwargs))
 
@@ -59,6 +80,7 @@ common_apt_packages = module_extension(
 packages = struct(
     read = _read,
     cc_import = _cc_import,
+    cc_import_glob_hdrs = _cc_import_glob_hdrs,
     filegroup = _filegroup,
     load_ = _load,
 )
diff --git a/runtimes/cuda/cuda.bzl b/runtimes/cuda/cuda.bzl
index 58fec19..82d1e44 100644
--- a/runtimes/cuda/cuda.bzl
+++ b/runtimes/cuda/cuda.bzl
@@ -25,6 +25,11 @@ CUDA_PACKAGES = {
         name = "cupti",
         shared_library = "lib/libcupti.so.12",
     ),
+    "cuda_nvtx": packages.cc_import_glob_hdrs(
+        name = "nvtx",
+        hdrs_glob = "include/nvtx3/**/*.h",
+        shared_library = "lib/libnvToolsExt.so.1",
+    ),
     "libcufft": packages.cc_import(
         name = "cufft",
         shared_library = "lib/libcufft.so.11",
diff --git a/runtimes/cuda/libpjrt_cuda.BUILD.bazel b/runtimes/cuda/libpjrt_cuda.BUILD.bazel
index 4412cd5..8881a01 100644
--- a/runtimes/cuda/libpjrt_cuda.BUILD.bazel
+++ b/runtimes/cuda/libpjrt_cuda.BUILD.bazel
@@ -35,6 +35,7 @@ cc_import(
         ":zmlxcuda",
         "@cuda_cudart//:cudart",
         "@cuda_cupti//:cupti",
+        "@cuda_nvtx//:nvtx",
         "@cuda_nvcc//:nvptxcompiler",
         "@cuda_nvcc//:nvvm",
         "@cuda_nvrtc//:nvrtc",
diff --git a/zml/tools/tracer.zig b/zml/tools/tracer.zig
index ed7d7fe..0576956
100644
--- a/zml/tools/tracer.zig
+++ b/zml/tools/tracer.zig
@@ -1,13 +1,58 @@
 const builtin = @import("builtin");
+const c = @import("c");
 
 pub const Tracer = switch (builtin.os.tag) {
     .macos => MacOsTracer,
+    .linux => if (@hasDecl(c, "ZML_RUNTIME_CUDA")) CudaTracer else FakeTracer,
     else => FakeTracer,
 };
 
-const MacOsTracer = struct {
-    const c = @import("c");
+/// Tracer that forwards to the CUDA profiler (`cudaProfiler*`) and NVTX (`nvtx*`) C entry points.
+const CudaTracer = struct {
+    extern fn cudaProfilerStart() c_int;
+    extern fn cudaProfilerStop() c_int;
+    extern fn nvtxMarkA(message: [*:0]const u8) void;
+    // NVTX range ids are `nvtxRangeId_t`, a 64-bit unsigned integer; declaring
+    // them as `c_int` would truncate the id and mismatch the C ABI.
+    extern fn nvtxRangeStartA(message: [*:0]const u8) u64;
+    extern fn nvtxRangeEnd(id: u64) void;
+
+    /// Calls `cudaProfilerStart` (its status is ignored); `name` is unused here.
+    pub fn init(name: [:0]const u8) CudaTracer {
+        _ = name;
+        _ = cudaProfilerStart();
+        return .{};
+    }
+
+    /// Calls `cudaProfilerStop` (its status is ignored).
+    pub fn deinit(self: *const CudaTracer) void {
+        _ = self;
+        _ = cudaProfilerStop();
+    }
+
+    /// Passes `message` to `nvtxMarkA`.
+    pub fn event(self: *const CudaTracer, message: [:0]const u8) void {
+        _ = self;
+        nvtxMarkA(message.ptr);
+    }
+
+    /// Passes `message` to `nvtxRangeStartA` and returns the range id to hand back to `frameEnd`.
+    pub fn frameStart(self: *const CudaTracer, message: [:0]const u8) u64 {
+        _ = self;
+        return nvtxRangeStartA(message.ptr);
+    }
+
+    /// Passes `interval_id` to `nvtxRangeEnd`; `message` is unused.
+    pub fn frameEnd(self: *const CudaTracer, interval_id: u64, message: [:0]const u8) void {
+        _ = self;
+        _ = message;
+        nvtxRangeEnd(interval_id);
+        return;
+    }
+};
+
+const MacOsTracer = struct {
 
     logger: c.os_log_t,
 
     pub fn init(name: [:0]const u8) MacOsTracer {
@@ -40,7 +85,7 @@ const FakeTracer = struct {
         return .{};
     }
 
-    pub fn event(self: *const MacOsTracer, message: [:0]const u8) void {
+    pub fn event(self: *const FakeTracer, message: [:0]const u8) void {
         _ = self;
         _ = message;
         return;