diff --git a/MODULE.bazel b/MODULE.bazel index 6df826e..45fb05d 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -17,7 +17,7 @@ bazel_dep(name = "rules_proto", version = "7.1.0") bazel_dep(name = "rules_python", version = "0.40.0") bazel_dep(name = "rules_rust", version = "0.60.0") bazel_dep(name = "rules_uv", version = "0.65.0") -bazel_dep(name = "rules_zig", version = "20250530.0-5084f1f") +bazel_dep(name = "rules_zig", version = "20250613.0-567662a") bazel_dep(name = "sentencepiece", version = "20240618.0-d7ace0a") bazel_dep(name = "toolchains_llvm_bootstrapped", version = "0.2.3") bazel_dep(name = "toolchains_protoc", version = "0.4.1") diff --git a/async/BUILD.bazel b/async/BUILD.bazel index 9d6efff..65e44d0 100644 --- a/async/BUILD.bazel +++ b/async/BUILD.bazel @@ -1,4 +1,6 @@ -load("@rules_zig//zig:defs.bzl", "zig_library") +load("@rules_zig//zig:defs.bzl", "zig_library", "zig_test") +load("@zml//bazel:zig_srcs.bzl", "zig_srcs") + zig_library( name = "async", @@ -18,3 +20,14 @@ zig_library( "@libxev//:xev", ], ) + +zig_test( + name = "test", + deps = [":async"], + testonly = False, +) + +zig_srcs( + name = "sources", + zig_bin = ":test", +) diff --git a/async/async.zig b/async/async.zig index 78e225b..f14608b 100644 --- a/async/async.zig +++ b/async/async.zig @@ -1,13 +1,19 @@ const std = @import("std"); + const stdx = @import("stdx"); const xev = @import("xev").Dynamic; +const XevThreadPool = @import("xev").ThreadPool; + +const aio = @import("asyncio.zig"); +const channel_mod = @import("channel.zig"); const coro = @import("coro.zig"); const executor = @import("executor.zig"); -const channel_mod = @import("channel.zig"); -const aio = @import("asyncio.zig"); const stack = @import("stack.zig"); -const XevThreadPool = @import("xev").ThreadPool; +test { + std.testing.refAllDecls(@This()); + std.testing.refAllDecls(coro); +} pub const Condition = struct { inner: executor.Condition, diff --git a/async/coro.zig b/async/coro.zig index cfad9c8..be26b95 100644 --- a/async/coro.zig +++ b/async/coro.zig @@ -265,29 +265,6 @@ const CoroT = struct { } }; -/// Estimates the remaining stack size in the currently running coroutine -pub noinline fn remainingStackSize() usize { - var dummy: usize = 0; - dummy += 1; - const addr = @intFromPtr(&dummy); - - // Check if the stack was already overflowed - const current = xframe(); - StackOverflow.check(current) catch return 0; - - // Check if the stack is currently overflowed - const bottom = @intFromPtr(current.stack.ptr); - if (addr < bottom) { - return 0; - } - - // Debug check that we're actually in the stack - const top = @intFromPtr(current.stack.ptr + current.stack.len); - std.debug.assert(addr < top); // should never have popped beyond the top - - return addr - bottom; -} - // ============================================================================ /// Thread-local coroutine runtime @@ -450,7 +427,9 @@ fn testSetIdx(val: usize) void { } fn testFn() void { - std.debug.assert(remainingStackSize() > 2048); + // Check if the stack was already overflowed + const current = xframe(); + std.debug.assert(current.stack.remaining().len > 2048); testSetIdx(2); xsuspend(); testSetIdx(4); diff --git a/bazel/zig_srcs.bzl b/bazel/zig_srcs.bzl new file mode 100644 index 0000000..2aa7956 --- /dev/null +++ b/bazel/zig_srcs.bzl @@ -0,0 +1,34 @@ +load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar") +load("@rules_zig//zig:defs.bzl", "zig_binary", "BINARY_KIND") + +def zig_srcs(name, zig_bin="", zig_lib=""): + """For a given zig_library, recursively extract all zig sources into a tarball. + + This also includes the files translated from C headers. + It's also possible to pass zig_lib instead of zig_bin in which case, + The rule takes care of creating an intermediary binary from the lib. + """ + if zig_bin == "": + zig_bin = "{}_bin".format(name) + zig_binary( + name = zig_bin, + kind = BINARY_KIND.bc, + tags = ["manual", "@rules_zig//zig/lib:libc"], + deps = [zig_lib], + ) + + native.filegroup( + name = "{}_files".format(name), + srcs = [zig_bin], + output_group = "srcs", + ) + mtree_spec( + name = "{}_mtree".format(name), + srcs = [":{}_files".format(name)], + ) + tar( + name = name, + srcs = ["{}_files".format(name)], + args = [], + mtree = "{}_mtree".format(name), + ) diff --git a/mlir/BUILD.bazel b/mlir/BUILD.bazel index 755d6be..63b9ede 100644 --- a/mlir/BUILD.bazel +++ b/mlir/BUILD.bazel @@ -1,5 +1,7 @@ load("@rules_cc//cc:defs.bzl", "cc_library") + load("@rules_zig//zig:defs.bzl", "zig_library") +load("//bazel:zig_srcs.bzl", "zig_srcs") load("//bazel:zig.bzl", "zig_cc_test") cc_library( @@ -30,3 +32,13 @@ zig_cc_test( name = "test", deps = [":mlir"], ) + +cc_static_library( + name="mlir_static", + deps = ["c"] +) + +zig_srcs( + name = "sources", + zig_bin = ":test_test_lib", +) diff --git a/mlir/dialects/BUILD.bazel b/mlir/dialects/BUILD.bazel index c4c8220..bcb3a1b 100644 --- a/mlir/dialects/BUILD.bazel +++ b/mlir/dialects/BUILD.bazel @@ -1,5 +1,6 @@ load("@rules_zig//zig:defs.bzl", "zig_library") load("//bazel:zig.bzl", "zig_cc_test") +load("//bazel:zig_srcs.bzl", "zig_srcs") zig_library( name = "dialects", @@ -24,6 +25,19 @@ zig_cc_test( deps = [":dialects"], ) +zig_srcs( + name = "sources", + zig_bin = ":test_test_lib", +) + +cc_static_library( + name="mlir_static", + deps = [ + "//mlir:c", + "@stablehlo//:stablehlo_dialect_capi", + ] +) + zig_library( name = "stablehlo", import_name = "mlir/dialects/stablehlo", diff --git a/mlir/mlir.zig b/mlir/mlir.zig old mode 100644 new mode 100755 index 918c304..10d26c9 --- a/mlir/mlir.zig +++ b/mlir/mlir.zig @@ -8,6 +8,8 @@ const log = std.log.scoped(.mlir); test { std.testing.refAllDecls(@This()); + + _ = try Context.init(); } const Error = error{ diff --git a/pjrt/BUILD.bazel b/pjrt/BUILD.bazel index 3ed3483..682b569 100644 --- a/pjrt/BUILD.bazel +++ b/pjrt/BUILD.bazel @@ -1,15 +1,7 @@ -load("@rules_cc//cc:defs.bzl", "cc_library") load("@rules_zig//zig:defs.bzl", "zig_library") load("@zml//bazel:zig.bzl", "zig_cc_binary") -load("//bazel:zig_proto_library.bzl", "zig_proto_library") - -cc_library( - name = "dlfcn", - hdrs = ["dlfcn.h"], - target_compatible_with = [ - "@platforms//os:linux", - ], -) +load("@zml//bazel:zig_srcs.bzl", "zig_srcs") +load("@zml//bazel:zig_proto_library.bzl", "zig_proto_library") zig_library( name = "pjrt", @@ -32,10 +24,12 @@ zig_library( "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs", - ] + select({ - "@platforms//os:linux": [":dlfcn"], - "//conditions:default": [], - }), + ], +) + +zig_srcs( + name = "sources", + zig_lib = ":pjrt", ) zig_proto_library( diff --git a/pjrt/dlfcn.h b/pjrt/dlfcn.h deleted file mode 100644 index 9e446a2..0000000 --- a/pjrt/dlfcn.h +++ /dev/null @@ -1 +0,0 @@ -#include diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index 18260f9..a0be0b2 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -74,20 +74,19 @@ pub const Api = struct { inner: c.PJRT_Api, - pub fn loadFrom(library: []const u8) !*const Api { + pub fn loadFrom(library: [:0]const u8) !*const Api { var lib: std.DynLib = switch (builtin.os.tag) { .linux => blk: { - const library_c = try std.posix.toPosixPath(library); - break :blk .{ - .inner = .{ - .handle = c.dlopen(&library_c, c.RTLD_LAZY | c.RTLD_LOCAL | c.RTLD_NODELETE) orelse { - log.err("Unable to dlopen plugin: {s}", .{library}); - return error.FileNotFound; - }, - }, + const handle = std.c.dlopen(library, .{ .LAZY = true, .GLOBAL = false, .NODELETE = true }) orelse { + log.err("Unable to dlopen plugin: {s}", .{library}); + return error.FileNotFound; }; + break :blk .{ .inner = .{ .handle = handle } }; + }, + else => std.DynLib.open(library) catch |err| { + log.err("Unable to dlopen plugin: {s}", .{library}); + return err; }, - else => try std.DynLib.open(library), }; const DynGetPjrtApi = lib.lookup(*const fn () callconv(.C) *const Api, "GetPjrtApi") orelse { std.debug.panic("Unable to find GetPjrtApi symbol in library: {s}", .{library}); diff --git a/runtimes/cuda/cuda.zig b/runtimes/cuda/cuda.zig index bc74969..46eb69d 100644 --- a/runtimes/cuda/cuda.zig +++ b/runtimes/cuda/cuda.zig @@ -36,7 +36,8 @@ fn setupXlaGpuCudaDirFlag() !void { defer arena.deinit(); var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse { - stdx.debug.panic("Unable to find CUDA directory", .{}); + log.warn("Unable to find CUDA directory. Using system defaults.", .{}); + return; }; const source_repo = bazel_builtin.current_repository; diff --git a/stdx/BUILD.bazel b/stdx/BUILD.bazel index f2fb4ea..59abb65 100644 --- a/stdx/BUILD.bazel +++ b/stdx/BUILD.bazel @@ -1,4 +1,5 @@ -load("@rules_zig//zig:defs.bzl", "zig_library") +load("@rules_zig//zig:defs.bzl", "zig_library", "zig_test") +load("@zml//bazel:zig_srcs.bzl", "zig_srcs") zig_library( name = "stdx", @@ -17,3 +18,14 @@ zig_library( main = "stdx.zig", visibility = ["//visibility:public"], ) + +zig_test( + name = "test", + deps = [":stdx"], + testonly = False, +) + +zig_srcs( + name = "sources", + zig_bin = ":test", +) diff --git a/stdx/queue.zig b/stdx/queue.zig index e4722af..ebbf058 100644 --- a/stdx/queue.zig +++ b/stdx/queue.zig @@ -91,7 +91,7 @@ test SPSC { try testing.expect(q.empty()); // Elems - var elems: [10]Elem = .{.{}} ** 10; + var elems: [10]Elem = @splat(.{}); // One try testing.expect(q.pop() == null); @@ -207,7 +207,7 @@ test MPSC { q.init(); // Elems - var elems: [10]Elem = .{.{}} ** 10; + var elems: [10]Elem = @splat(.{}); // One try testing.expect(q.pop() == null); diff --git a/stdx/stdx.zig b/stdx/stdx.zig index 226c122..cda9635 100644 --- a/stdx/stdx.zig +++ b/stdx/stdx.zig @@ -8,6 +8,11 @@ pub const meta = @import("meta.zig"); pub const queue = @import("queue.zig"); pub const time = @import("time.zig"); +test { + const std = @import("std"); + std.testing.refAllDecls(@This()); +} + pub inline fn stackSlice(comptime max_len: usize, T: type, len: usize) []T { debug.assert(len <= max_len, "stackSlice can only create a slice of up to {} elements, got: {}", .{ max_len, len }); var storage: [max_len]T = undefined; diff --git a/third_party/modules/rules_zig/20250613.0-567662a/MODULE.bazel b/third_party/modules/rules_zig/20250613.0-567662a/MODULE.bazel new file mode 100644 index 0000000..4189634 --- /dev/null +++ b/third_party/modules/rules_zig/20250613.0-567662a/MODULE.bazel @@ -0,0 +1,68 @@ +module( + name = "rules_zig", + version = "20250613.0-567662a", + compatibility_level = 1, +) + +bazel_dep(name = "aspect_bazel_lib", version = "2.8.1") +bazel_dep(name = "bazel_skylib", version = "1.7.1") +bazel_dep(name = "platforms", version = "0.0.10") + +zig = use_extension("//zig:extensions.bzl", "zig") +zig.index(file = "//zig/private:versions.json") +use_repo(zig, "zig_toolchains") + +register_toolchains("@rules_zig//zig/target:all") + +register_toolchains("@zig_toolchains//:all") + +zig_dev = use_extension( + "//zig:extensions.bzl", + "zig", + dev_dependency = True, +) +zig_dev.toolchain(zig_version = "0.13.0") +zig_dev.toolchain(zig_version = "0.12.1") +zig_dev.toolchain(zig_version = "0.12.0") +zig_dev.toolchain(zig_version = "0.11.0") + +bazel_dep(name = "rules_cc", version = "0.0.9") +bazel_dep(name = "stardoc", version = "0.7.0", dev_dependency = True, repo_name = "io_bazel_stardoc") +bazel_dep(name = "gazelle", version = "0.38.0", dev_dependency = True, repo_name = "bazel_gazelle") +bazel_dep(name = "bazel_skylib_gazelle_plugin", version = "1.7.1", dev_dependency = True) +bazel_dep( + name = "buildifier_prebuilt", + version = "7.3.1", + dev_dependency = True, +) +bazel_dep(name = "rules_multirun", version = "0.9.0", dev_dependency = True) +bazel_dep(name = "rules_python", version = "0.35.0", dev_dependency = True) +bazel_dep( + name = "rules_bazel_integration_test", + version = "0.25.0", + dev_dependency = True, +) + +bazel_binaries = use_extension( + "@rules_bazel_integration_test//:extensions.bzl", + "bazel_binaries", + dev_dependency = True, +) + +# NOTE: Keep in sync with WORKSPACE. +bazel_binaries.download(version_file = "//:.bazelversion") +bazel_binaries.download(version = "7.0.0") +use_repo( + bazel_binaries, + "bazel_binaries", + "bazel_binaries_bazelisk", + "build_bazel_bazel_.bazelversion", + "build_bazel_bazel_7_0_0", +) + +# TODO[AH] Should be an implicit transitive dependency through rules_bazel_integration_test. +# However, if we do not include it explicitly, then the runfiles resolution for +# cgrindel_bazel_starlib/shlib/lib/message.sh fails in +# rules_bazel_integration_test/tools/update_deleted_packages.sh when invoked +# through the rules_multirun target //util:update. +bazel_dep(name = "cgrindel_bazel_starlib", version = "0.21.0", dev_dependency = True) diff --git a/third_party/modules/rules_zig/20250613.0-567662a/source.json b/third_party/modules/rules_zig/20250613.0-567662a/source.json new file mode 100644 index 0000000..4dbdafe --- /dev/null +++ b/third_party/modules/rules_zig/20250613.0-567662a/source.json @@ -0,0 +1,5 @@ +{ + "strip_prefix": "rules_zig-567662a3ce5e87894950d56c69d8de5ad6e0b5f0", + "url": "https://github.com/zml/rules_zig/archive/567662a3ce5e87894950d56c69d8de5ad6e0b5f0.tar.gz", + "integrity": "sha256-PjFpDJXO0BGx2CAZ54Ppv6d4TwTDOWS+mVangKQFvZc=" +} diff --git a/third_party/modules/rules_zig/metadata.json b/third_party/modules/rules_zig/metadata.json index 51e22e7..90794cf 100644 --- a/third_party/modules/rules_zig/metadata.json +++ b/third_party/modules/rules_zig/metadata.json @@ -16,7 +16,8 @@ "20240912.0-41bfe84", "20240913.0-1957d05", "20250314.0-b9739c6", - "20250519.0-233b207" + "20250519.0-233b207", + "20250613.0-567662a" ], "yanked_versions": {} } diff --git a/zml/BUILD.bazel b/zml/BUILD.bazel index b346b04..218c484 100644 --- a/zml/BUILD.bazel +++ b/zml/BUILD.bazel @@ -1,8 +1,8 @@ -load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar") load("@rules_cc//cc:defs.bzl", "cc_library") load("@rules_zig//zig:defs.bzl", "zig_library") load("//bazel:zig.bzl", "zig_cc_test") load("//bazel:zig_proto_library.bzl", "zig_proto_library") +load("//bazel:zig_srcs.bzl", "zig_srcs") cc_library( name = "posix", @@ -25,7 +25,6 @@ zig_library( visibility = ["//visibility:public"], deps = [ ":posix", - ":sentencepiece_model_proto", ":xla_proto", "//async", "//mlir", @@ -35,9 +34,7 @@ zig_library( "//stdx", "//zml/tokenizer", "//zml/tools", - "@rules_zig//zig/lib:libc", "@rules_zig//zig/runfiles", - "@zig-yaml//:zig-yaml", ], ) @@ -47,11 +44,6 @@ zig_proto_library( deps = ["@xla//xla/pjrt/proto:compile_options_proto"], ) -zig_proto_library( - name = "sentencepiece_model_proto", - import_name = "//sentencepiece:model_proto", - deps = ["@sentencepiece//:sentencepiece_model_proto"], -) # All ZML Tests @@ -71,21 +63,7 @@ filegroup( visibility = ["//visibility:public"], ) -filegroup( - name = "srcs", - srcs = [":test_test_lib"], - output_group = "srcs", -) - -mtree_spec( - name = "mtree", - srcs = [":srcs"], -) - -tar( +zig_srcs( name = "sources", - srcs = [":srcs"], - args = [ - ], - mtree = ":mtree", + zig_bin = ":test_test_lib", ) diff --git a/zml/aio.zig b/zml/aio.zig index 56fc14e..467def7 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -5,11 +5,11 @@ const c = @import("c"); const stdx = @import("stdx"); pub const gguf = @import("aio/gguf.zig"); -pub const nemo = @import("aio/nemo.zig"); +// pub const nemo = @import("aio/nemo.zig"); pub const safetensors = @import("aio/safetensors.zig"); pub const tinyllama = @import("aio/tinyllama.zig"); pub const torch = @import("aio/torch.zig"); -pub const yaml = @import("aio/yaml.zig"); +// pub const yaml = @import("aio/yaml.zig"); const HostBuffer = @import("hostbuffer.zig").HostBuffer; const posix = @import("posix.zig"); const zml = @import("zml.zig"); @@ -18,10 +18,10 @@ pub const log = std.log.scoped(.@"zml/aio"); test { std.testing.refAllDecls(@This()); std.testing.refAllDecls(gguf); - std.testing.refAllDecls(nemo); + // std.testing.refAllDecls(nemo); std.testing.refAllDecls(safetensors); std.testing.refAllDecls(torch); - std.testing.refAllDecls(yaml); + // std.testing.refAllDecls(yaml); } // TODO error set for weight loading diff --git a/zml/tokenizer/hftokenizers/BUILD.bazel b/zml/tokenizer/hftokenizers/BUILD.bazel index d2ac5ac..68188e3 100644 --- a/zml/tokenizer/hftokenizers/BUILD.bazel +++ b/zml/tokenizer/hftokenizers/BUILD.bazel @@ -28,3 +28,11 @@ zig_library( "//ffi:zig", ], ) + +cc_static_library( + name="hftokenizer_static", + deps = [ + ":hftokenizers_rs", + "//ffi:cc", + ] +) diff --git a/zml/tokenizer/sentencepiece/BUILD.bazel b/zml/tokenizer/sentencepiece/BUILD.bazel index 354b289..457d811 100644 --- a/zml/tokenizer/sentencepiece/BUILD.bazel +++ b/zml/tokenizer/sentencepiece/BUILD.bazel @@ -1,5 +1,6 @@ load("@rules_zig//zig:defs.bzl", "zig_library") load("//bazel:swig.bzl", "swig_cc_library") +load("//bazel:zig_srcs.bzl", "zig_srcs") swig_cc_library( name = "sentencepiece_swig", @@ -21,3 +22,8 @@ zig_library( "//ffi:zig", ], ) + +zig_srcs( + name = "sources", + zig_lib = ":sentencepiece", +) diff --git a/zml/tokenizer/tokenizer.zig b/zml/tokenizer/tokenizer.zig index a28b3a0..443e2fb 100644 --- a/zml/tokenizer/tokenizer.zig +++ b/zml/tokenizer/tokenizer.zig @@ -1,10 +1,18 @@ const std = @import("std"); + +const asynk = @import("async"); const hftokenizers = @import("hftokenizers"); const sentencepiece = @import("sentencepiece"); -const asynk = @import("async"); const homemade = @import("homemade.zig"); +test { + std.testing.refAllDecls(@This()); + std.testing.refAllDecls(hftokenizers); + std.testing.refAllDecls(sentencepiece); + std.testing.refAllDecls(homemade); +} + const Tokenizers = enum { hftokenizers, sentencepiece, diff --git a/zml/tools/BUILD.bazel b/zml/tools/BUILD.bazel index 7ffe3a5..7579fb9 100644 --- a/zml/tools/BUILD.bazel +++ b/zml/tools/BUILD.bazel @@ -22,3 +22,8 @@ zig_library( "//conditions:default": [], }), ) + +cc_static_library( + name = "macos_static_tools", + deps = ["macos_c"] +)