diff --git a/third_party/mnist/repo.bzl b/third_party/mnist/repo.bzl index 025845e..fdabe71 100644 --- a/third_party/mnist/repo.bzl +++ b/third_party/mnist/repo.bzl @@ -3,7 +3,7 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") def repo(): http_archive( name = "mnist", - sha256 = "075905e433ea0cce13c3fc08832448ab86225d089b5d412be67f59c29388fb19", - url = "https://mirror.zml.ai/data/mnist.tar.zst", + sha256 = "9c7d9a5ef9c245084996f6d2ec66ef176e51186e6a5b22efdcc3828d644941ca", + url = "https://mirror.zml.ai/data/mnist_safetensors.tar.zst", build_file_content = """exports_files(glob(["**"]), visibility = ["//visibility:public"])""", ) diff --git a/zml/BUILD.bazel b/zml/BUILD.bazel index 7901965..e0722f7 100644 --- a/zml/BUILD.bazel +++ b/zml/BUILD.bazel @@ -19,15 +19,35 @@ cc_library( zig_library( name = "zml", - srcs = glob([ - "*.zig", - "aio/**/*.zig", - "nn/**/*.zig", - # TODO: test_runner.zig should not be here. - # It's here for now because it seems that test_runner property in zig_test is misbehaving. - # See https://github.com/zml/rules_zig/issues/2 + srcs = [ + "aio.zig", + "aio/json.zig", + "aio/safetensors.zig", + "aio/tinyllama.zig", + "buffer.zig", + "context.zig", + "dtype.zig", + "exe.zig", + "floats.zig", + "helpers.zig", + "hostbuffer.zig", + "meta.zig", + "mlirx.zig", + "module.zig", + "nn.zig", + "nn/cuda.zig", + "ops.zig", + "pjrtx.zig", + "platform.zig", + "posix.zig", + "quantization.zig", + "shape.zig", + "tensor.zig", "test_runner.zig", - ]), + "testing.zig", + "torch.zig", + "zml.zig", + ], copts = ["-lc"], main = "zml.zig", visibility = ["//visibility:public"], @@ -52,10 +72,6 @@ zig_library( zig_test( name = "test", - data = [ - "aio/torch/simple.pt", - "aio/torch/simple_test_4.pickle", - ], test_runner = ":test_runner", deps = [":zml"], ) diff --git a/zml/aio.zig b/zml/aio.zig index 467def7..08f352e 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -4,24 +4,17 @@ const asynk = @import("async"); const c = @import("c"); const stdx = @import("stdx"); -pub const gguf = @import("aio/gguf.zig"); -// pub const nemo = @import("aio/nemo.zig"); pub const safetensors = @import("aio/safetensors.zig"); pub const tinyllama = @import("aio/tinyllama.zig"); -pub const torch = @import("aio/torch.zig"); -// pub const yaml = @import("aio/yaml.zig"); const HostBuffer = @import("hostbuffer.zig").HostBuffer; const posix = @import("posix.zig"); const zml = @import("zml.zig"); pub const log = std.log.scoped(.@"zml/aio"); + test { std.testing.refAllDecls(@This()); - std.testing.refAllDecls(gguf); - // std.testing.refAllDecls(nemo); std.testing.refAllDecls(safetensors); - std.testing.refAllDecls(torch); - // std.testing.refAllDecls(yaml); } // TODO error set for weight loading @@ -32,12 +25,6 @@ pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8) try safetensors.open(allocator, model_path) else if (std.mem.endsWith(u8, model_path, ".safetensors.index.json")) try safetensors.open(allocator, model_path) - else if (std.mem.endsWith(u8, model_path, ".gguf")) - try gguf.open(allocator, model_path) - else if (std.mem.endsWith(u8, model_path, ".pt")) - try torch.open(allocator, model_path) - else if (std.mem.endsWith(u8, model_path, ".tinyllama")) - try tinyllama.open(allocator, model_path) else { std.debug.panic("File extension not recognized: {s}", .{model_path}); };