diff --git a/bazel/http_deb_archive.bzl b/bazel/http_deb_archive.bzl index f847e6f..91b52f5 100644 --- a/bazel/http_deb_archive.bzl +++ b/bazel/http_deb_archive.bzl @@ -39,6 +39,7 @@ http_deb_archive = repository_rule( "strip_prefix": attr.string(), "build_file": attr.label(allow_single_file = True), "build_file_content": attr.string(), + "patches": attr.label_list(allow_files = True), "workspace_file": attr.label(allow_single_file = True), "workspace_file_content": attr.string(), }, diff --git a/stdx/stdx.zig b/stdx/stdx.zig index 4698349..3ec8b8b 100644 --- a/stdx/stdx.zig +++ b/stdx/stdx.zig @@ -24,3 +24,13 @@ pub inline fn stackSlice(comptime max_len: usize, T: type, len: usize) []T { } pub const noalloc: std.mem.Allocator = if (builtin.mode == .ReleaseFast) undefined else std.testing.failing_allocator; + +pub fn pinToCore(core_id: usize) void { + if (builtin.os.tag == .linux) { + const CPUSet = std.bit_set.ArrayBitSet(usize, std.os.linux.CPU_SETSIZE * @sizeOf(usize)); + + var set: CPUSet = .initEmpty(); + set.set(core_id); + std.os.linux.sched_setaffinity(0, @ptrCast(&set.masks)) catch {}; + } +} diff --git a/zml/callback.zig b/zml/callback.zig index 5af96d1..ea7a4b2 100644 --- a/zml/callback.zig +++ b/zml/callback.zig @@ -120,7 +120,13 @@ pub fn call( pub const Config = struct { output_operand_aliases: []const i64 = &.{}, copy_inputs_to_host_pinned: bool = false, - // TODO: document precisely what `command_buffer_compatible` is doing and its limitations. + + /// Indicates that the handler is compatible with command buffers (ie Cuda graphs). + /// The Cuda backend traces the execution of the callback and records it as a Cuda graph. + /// To be compatible with CUDA graphs, the callback need: + /// 1. to not use any synchronous operations + /// 2. to not allocate any new host/device memory + /// See: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#prohibited-and-unhandled-operations. traits: pjrt.ffi.HandlerTraits = .{ .command_buffer_compatible = false }, // TODO: handle sharded inputs has_side_effect: bool = true,