Update rules_zig: add zig_srcs target, fix source handling bug, clean up BUILD files, adjust async/coro.zig tests, and disable nemo and yaml model loaders.

This commit is contained in:
Tarry Singh 2025-03-13 12:27:21 +00:00
parent 6fc1148206
commit f27a524f31
24 changed files with 238 additions and 89 deletions

View File

@ -17,7 +17,7 @@ bazel_dep(name = "rules_proto", version = "7.1.0")
bazel_dep(name = "rules_python", version = "0.40.0")
bazel_dep(name = "rules_rust", version = "0.60.0")
bazel_dep(name = "rules_uv", version = "0.65.0")
bazel_dep(name = "rules_zig", version = "20250530.0-5084f1f")
bazel_dep(name = "rules_zig", version = "20250613.0-567662a")
bazel_dep(name = "sentencepiece", version = "20240618.0-d7ace0a")
bazel_dep(name = "toolchains_llvm_bootstrapped", version = "0.2.3")
bazel_dep(name = "toolchains_protoc", version = "0.4.1")

View File

@ -1,4 +1,6 @@
load("@rules_zig//zig:defs.bzl", "zig_library")
load("@rules_zig//zig:defs.bzl", "zig_library", "zig_test")
load("@zml//bazel:zig_srcs.bzl", "zig_srcs")
zig_library(
name = "async",
@ -18,3 +20,14 @@ zig_library(
"@libxev//:xev",
],
)
zig_test(
name = "test",
deps = [":async"],
testonly = False,
)
zig_srcs(
name = "sources",
zig_bin = ":test",
)

View File

@ -1,13 +1,19 @@
const std = @import("std");
const stdx = @import("stdx");
const xev = @import("xev").Dynamic;
const XevThreadPool = @import("xev").ThreadPool;
const aio = @import("asyncio.zig");
const channel_mod = @import("channel.zig");
const coro = @import("coro.zig");
const executor = @import("executor.zig");
const channel_mod = @import("channel.zig");
const aio = @import("asyncio.zig");
const stack = @import("stack.zig");
const XevThreadPool = @import("xev").ThreadPool;
test {
std.testing.refAllDecls(@This());
std.testing.refAllDecls(coro);
}
pub const Condition = struct {
inner: executor.Condition,

View File

@ -265,29 +265,6 @@ const CoroT = struct {
}
};
/// Estimates the remaining stack size in the currently running coroutine
pub noinline fn remainingStackSize() usize {
var dummy: usize = 0;
dummy += 1;
const addr = @intFromPtr(&dummy);
// Check if the stack was already overflowed
const current = xframe();
StackOverflow.check(current) catch return 0;
// Check if the stack is currently overflowed
const bottom = @intFromPtr(current.stack.ptr);
if (addr < bottom) {
return 0;
}
// Debug check that we're actually in the stack
const top = @intFromPtr(current.stack.ptr + current.stack.len);
std.debug.assert(addr < top); // should never have popped beyond the top
return addr - bottom;
}
// ============================================================================
/// Thread-local coroutine runtime
@ -450,7 +427,9 @@ fn testSetIdx(val: usize) void {
}
fn testFn() void {
std.debug.assert(remainingStackSize() > 2048);
// Check if the stack was already overflowed
const current = xframe();
std.debug.assert(current.stack.remaining().len > 2048);
testSetIdx(2);
xsuspend();
testSetIdx(4);

34
bazel/zig_srcs.bzl Normal file
View File

@ -0,0 +1,34 @@
load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar")
load("@rules_zig//zig:defs.bzl", "zig_binary", "BINARY_KIND")
def zig_srcs(name, zig_bin="", zig_lib=""):
"""For a given zig_library, recursively extract all zig sources into a tarball.
This also includes the files translated from C headers.
It's also possible to pass zig_lib instead of zig_bin in which case,
The rule takes care of creating an intermediary binary from the lib.
"""
if zig_bin == "":
zig_bin = "{}_bin".format(name)
zig_binary(
name = zig_bin,
kind = BINARY_KIND.bc,
tags = ["manual", "@rules_zig//zig/lib:libc"],
deps = [zig_lib],
)
native.filegroup(
name = "{}_files".format(name),
srcs = [zig_bin],
output_group = "srcs",
)
mtree_spec(
name = "{}_mtree".format(name),
srcs = [":{}_files".format(name)],
)
tar(
name = name,
srcs = ["{}_files".format(name)],
args = [],
mtree = "{}_mtree".format(name),
)

View File

@ -1,5 +1,7 @@
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_zig//zig:defs.bzl", "zig_library")
load("//bazel:zig_srcs.bzl", "zig_srcs")
load("//bazel:zig.bzl", "zig_cc_test")
cc_library(
@ -30,3 +32,13 @@ zig_cc_test(
name = "test",
deps = [":mlir"],
)
cc_static_library(
name="mlir_static",
deps = ["c"]
)
zig_srcs(
name = "sources",
zig_bin = ":test_test_lib",
)

View File

@ -1,5 +1,6 @@
load("@rules_zig//zig:defs.bzl", "zig_library")
load("//bazel:zig.bzl", "zig_cc_test")
load("//bazel:zig_srcs.bzl", "zig_srcs")
zig_library(
name = "dialects",
@ -24,6 +25,19 @@ zig_cc_test(
deps = [":dialects"],
)
zig_srcs(
name = "sources",
zig_bin = ":test_test_lib",
)
cc_static_library(
name="mlir_static",
deps = [
"//mlir:c",
"@stablehlo//:stablehlo_dialect_capi",
]
)
zig_library(
name = "stablehlo",
import_name = "mlir/dialects/stablehlo",

2
mlir/mlir.zig Normal file → Executable file
View File

@ -8,6 +8,8 @@ const log = std.log.scoped(.mlir);
test {
std.testing.refAllDecls(@This());
_ = try Context.init();
}
const Error = error{

View File

@ -1,15 +1,7 @@
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_zig//zig:defs.bzl", "zig_library")
load("@zml//bazel:zig.bzl", "zig_cc_binary")
load("//bazel:zig_proto_library.bzl", "zig_proto_library")
cc_library(
name = "dlfcn",
hdrs = ["dlfcn.h"],
target_compatible_with = [
"@platforms//os:linux",
],
)
load("@zml//bazel:zig_srcs.bzl", "zig_srcs")
load("@zml//bazel:zig_proto_library.bzl", "zig_proto_library")
zig_library(
name = "pjrt",
@ -32,10 +24,12 @@ zig_library(
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
] + select({
"@platforms//os:linux": [":dlfcn"],
"//conditions:default": [],
}),
],
)
zig_srcs(
name = "sources",
zig_lib = ":pjrt",
)
zig_proto_library(

View File

@ -1 +0,0 @@
#include <dlfcn.h>

View File

@ -74,20 +74,19 @@ pub const Api = struct {
inner: c.PJRT_Api,
pub fn loadFrom(library: []const u8) !*const Api {
pub fn loadFrom(library: [:0]const u8) !*const Api {
var lib: std.DynLib = switch (builtin.os.tag) {
.linux => blk: {
const library_c = try std.posix.toPosixPath(library);
break :blk .{
.inner = .{
.handle = c.dlopen(&library_c, c.RTLD_LAZY | c.RTLD_LOCAL | c.RTLD_NODELETE) orelse {
const handle = std.c.dlopen(library, .{ .LAZY = true, .GLOBAL = false, .NODELETE = true }) orelse {
log.err("Unable to dlopen plugin: {s}", .{library});
return error.FileNotFound;
},
},
};
break :blk .{ .inner = .{ .handle = handle } };
},
else => std.DynLib.open(library) catch |err| {
log.err("Unable to dlopen plugin: {s}", .{library});
return err;
},
else => try std.DynLib.open(library),
};
const DynGetPjrtApi = lib.lookup(*const fn () callconv(.C) *const Api, "GetPjrtApi") orelse {
std.debug.panic("Unable to find GetPjrtApi symbol in library: {s}", .{library});

View File

@ -36,7 +36,8 @@ fn setupXlaGpuCudaDirFlag() !void {
defer arena.deinit();
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
stdx.debug.panic("Unable to find CUDA directory", .{});
log.warn("Unable to find CUDA directory. Using system defaults.", .{});
return;
};
const source_repo = bazel_builtin.current_repository;

View File

@ -1,4 +1,5 @@
load("@rules_zig//zig:defs.bzl", "zig_library")
load("@rules_zig//zig:defs.bzl", "zig_library", "zig_test")
load("@zml//bazel:zig_srcs.bzl", "zig_srcs")
zig_library(
name = "stdx",
@ -17,3 +18,14 @@ zig_library(
main = "stdx.zig",
visibility = ["//visibility:public"],
)
zig_test(
name = "test",
deps = [":stdx"],
testonly = False,
)
zig_srcs(
name = "sources",
zig_bin = ":test",
)

View File

@ -91,7 +91,7 @@ test SPSC {
try testing.expect(q.empty());
// Elems
var elems: [10]Elem = .{.{}} ** 10;
var elems: [10]Elem = @splat(.{});
// One
try testing.expect(q.pop() == null);
@ -207,7 +207,7 @@ test MPSC {
q.init();
// Elems
var elems: [10]Elem = .{.{}} ** 10;
var elems: [10]Elem = @splat(.{});
// One
try testing.expect(q.pop() == null);

View File

@ -8,6 +8,11 @@ pub const meta = @import("meta.zig");
pub const queue = @import("queue.zig");
pub const time = @import("time.zig");
test {
const std = @import("std");
std.testing.refAllDecls(@This());
}
pub inline fn stackSlice(comptime max_len: usize, T: type, len: usize) []T {
debug.assert(len <= max_len, "stackSlice can only create a slice of up to {} elements, got: {}", .{ max_len, len });
var storage: [max_len]T = undefined;

View File

@ -0,0 +1,68 @@
module(
name = "rules_zig",
version = "20250613.0-567662a",
compatibility_level = 1,
)
bazel_dep(name = "aspect_bazel_lib", version = "2.8.1")
bazel_dep(name = "bazel_skylib", version = "1.7.1")
bazel_dep(name = "platforms", version = "0.0.10")
zig = use_extension("//zig:extensions.bzl", "zig")
zig.index(file = "//zig/private:versions.json")
use_repo(zig, "zig_toolchains")
register_toolchains("@rules_zig//zig/target:all")
register_toolchains("@zig_toolchains//:all")
zig_dev = use_extension(
"//zig:extensions.bzl",
"zig",
dev_dependency = True,
)
zig_dev.toolchain(zig_version = "0.13.0")
zig_dev.toolchain(zig_version = "0.12.1")
zig_dev.toolchain(zig_version = "0.12.0")
zig_dev.toolchain(zig_version = "0.11.0")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "stardoc", version = "0.7.0", dev_dependency = True, repo_name = "io_bazel_stardoc")
bazel_dep(name = "gazelle", version = "0.38.0", dev_dependency = True, repo_name = "bazel_gazelle")
bazel_dep(name = "bazel_skylib_gazelle_plugin", version = "1.7.1", dev_dependency = True)
bazel_dep(
name = "buildifier_prebuilt",
version = "7.3.1",
dev_dependency = True,
)
bazel_dep(name = "rules_multirun", version = "0.9.0", dev_dependency = True)
bazel_dep(name = "rules_python", version = "0.35.0", dev_dependency = True)
bazel_dep(
name = "rules_bazel_integration_test",
version = "0.25.0",
dev_dependency = True,
)
bazel_binaries = use_extension(
"@rules_bazel_integration_test//:extensions.bzl",
"bazel_binaries",
dev_dependency = True,
)
# NOTE: Keep in sync with WORKSPACE.
bazel_binaries.download(version_file = "//:.bazelversion")
bazel_binaries.download(version = "7.0.0")
use_repo(
bazel_binaries,
"bazel_binaries",
"bazel_binaries_bazelisk",
"build_bazel_bazel_.bazelversion",
"build_bazel_bazel_7_0_0",
)
# TODO[AH] Should be an implicit transitive dependency through rules_bazel_integration_test.
# However, if we do not include it explicitly, then the runfiles resolution for
# cgrindel_bazel_starlib/shlib/lib/message.sh fails in
# rules_bazel_integration_test/tools/update_deleted_packages.sh when invoked
# through the rules_multirun target //util:update.
bazel_dep(name = "cgrindel_bazel_starlib", version = "0.21.0", dev_dependency = True)

View File

@ -0,0 +1,5 @@
{
"strip_prefix": "rules_zig-567662a3ce5e87894950d56c69d8de5ad6e0b5f0",
"url": "https://github.com/zml/rules_zig/archive/567662a3ce5e87894950d56c69d8de5ad6e0b5f0.tar.gz",
"integrity": "sha256-PjFpDJXO0BGx2CAZ54Ppv6d4TwTDOWS+mVangKQFvZc="
}

View File

@ -16,7 +16,8 @@
"20240912.0-41bfe84",
"20240913.0-1957d05",
"20250314.0-b9739c6",
"20250519.0-233b207"
"20250519.0-233b207",
"20250613.0-567662a"
],
"yanked_versions": {}
}

View File

@ -1,8 +1,8 @@
load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar")
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_zig//zig:defs.bzl", "zig_library")
load("//bazel:zig.bzl", "zig_cc_test")
load("//bazel:zig_proto_library.bzl", "zig_proto_library")
load("//bazel:zig_srcs.bzl", "zig_srcs")
cc_library(
name = "posix",
@ -25,7 +25,6 @@ zig_library(
visibility = ["//visibility:public"],
deps = [
":posix",
":sentencepiece_model_proto",
":xla_proto",
"//async",
"//mlir",
@ -35,9 +34,7 @@ zig_library(
"//stdx",
"//zml/tokenizer",
"//zml/tools",
"@rules_zig//zig/lib:libc",
"@rules_zig//zig/runfiles",
"@zig-yaml//:zig-yaml",
],
)
@ -47,11 +44,6 @@ zig_proto_library(
deps = ["@xla//xla/pjrt/proto:compile_options_proto"],
)
zig_proto_library(
name = "sentencepiece_model_proto",
import_name = "//sentencepiece:model_proto",
deps = ["@sentencepiece//:sentencepiece_model_proto"],
)
# All ZML Tests
@ -71,21 +63,7 @@ filegroup(
visibility = ["//visibility:public"],
)
filegroup(
name = "srcs",
srcs = [":test_test_lib"],
output_group = "srcs",
)
mtree_spec(
name = "mtree",
srcs = [":srcs"],
)
tar(
zig_srcs(
name = "sources",
srcs = [":srcs"],
args = [
],
mtree = ":mtree",
zig_bin = ":test_test_lib",
)

View File

@ -5,11 +5,11 @@ const c = @import("c");
const stdx = @import("stdx");
pub const gguf = @import("aio/gguf.zig");
pub const nemo = @import("aio/nemo.zig");
// pub const nemo = @import("aio/nemo.zig");
pub const safetensors = @import("aio/safetensors.zig");
pub const tinyllama = @import("aio/tinyllama.zig");
pub const torch = @import("aio/torch.zig");
pub const yaml = @import("aio/yaml.zig");
// pub const yaml = @import("aio/yaml.zig");
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const posix = @import("posix.zig");
const zml = @import("zml.zig");
@ -18,10 +18,10 @@ pub const log = std.log.scoped(.@"zml/aio");
test {
std.testing.refAllDecls(@This());
std.testing.refAllDecls(gguf);
std.testing.refAllDecls(nemo);
// std.testing.refAllDecls(nemo);
std.testing.refAllDecls(safetensors);
std.testing.refAllDecls(torch);
std.testing.refAllDecls(yaml);
// std.testing.refAllDecls(yaml);
}
// TODO error set for weight loading

View File

@ -28,3 +28,11 @@ zig_library(
"//ffi:zig",
],
)
cc_static_library(
name="hftokenizer_static",
deps = [
":hftokenizers_rs",
"//ffi:cc",
]
)

View File

@ -1,5 +1,6 @@
load("@rules_zig//zig:defs.bzl", "zig_library")
load("//bazel:swig.bzl", "swig_cc_library")
load("//bazel:zig_srcs.bzl", "zig_srcs")
swig_cc_library(
name = "sentencepiece_swig",
@ -21,3 +22,8 @@ zig_library(
"//ffi:zig",
],
)
zig_srcs(
name = "sources",
zig_lib = ":sentencepiece",
)

View File

@ -1,10 +1,18 @@
const std = @import("std");
const asynk = @import("async");
const hftokenizers = @import("hftokenizers");
const sentencepiece = @import("sentencepiece");
const asynk = @import("async");
const homemade = @import("homemade.zig");
test {
std.testing.refAllDecls(@This());
std.testing.refAllDecls(hftokenizers);
std.testing.refAllDecls(sentencepiece);
std.testing.refAllDecls(homemade);
}
const Tokenizers = enum {
hftokenizers,
sentencepiece,

View File

@ -22,3 +22,8 @@ zig_library(
"//conditions:default": [],
}),
)
cc_static_library(
name = "macos_static_tools",
deps = ["macos_c"]
)