Remove gguf and torch loader implementations and related BUILD and test assets.
This commit is contained in:
parent
01db09c24b
commit
e3b7705e3d
4
third_party/mnist/repo.bzl
vendored
4
third_party/mnist/repo.bzl
vendored
@ -3,7 +3,7 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
|||||||
def repo():
|
def repo():
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "mnist",
|
name = "mnist",
|
||||||
sha256 = "075905e433ea0cce13c3fc08832448ab86225d089b5d412be67f59c29388fb19",
|
sha256 = "9c7d9a5ef9c245084996f6d2ec66ef176e51186e6a5b22efdcc3828d644941ca",
|
||||||
url = "https://mirror.zml.ai/data/mnist.tar.zst",
|
url = "https://mirror.zml.ai/data/mnist_safetensors.tar.zst",
|
||||||
build_file_content = """exports_files(glob(["**"]), visibility = ["//visibility:public"])""",
|
build_file_content = """exports_files(glob(["**"]), visibility = ["//visibility:public"])""",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -19,15 +19,35 @@ cc_library(
|
|||||||
|
|
||||||
zig_library(
|
zig_library(
|
||||||
name = "zml",
|
name = "zml",
|
||||||
srcs = glob([
|
srcs = [
|
||||||
"*.zig",
|
"aio.zig",
|
||||||
"aio/**/*.zig",
|
"aio/json.zig",
|
||||||
"nn/**/*.zig",
|
"aio/safetensors.zig",
|
||||||
# TODO: test_runner.zig should not be here.
|
"aio/tinyllama.zig",
|
||||||
# It's here for now because it seems that test_runner property in zig_test is misbehaving.
|
"buffer.zig",
|
||||||
# See https://github.com/zml/rules_zig/issues/2
|
"context.zig",
|
||||||
|
"dtype.zig",
|
||||||
|
"exe.zig",
|
||||||
|
"floats.zig",
|
||||||
|
"helpers.zig",
|
||||||
|
"hostbuffer.zig",
|
||||||
|
"meta.zig",
|
||||||
|
"mlirx.zig",
|
||||||
|
"module.zig",
|
||||||
|
"nn.zig",
|
||||||
|
"nn/cuda.zig",
|
||||||
|
"ops.zig",
|
||||||
|
"pjrtx.zig",
|
||||||
|
"platform.zig",
|
||||||
|
"posix.zig",
|
||||||
|
"quantization.zig",
|
||||||
|
"shape.zig",
|
||||||
|
"tensor.zig",
|
||||||
"test_runner.zig",
|
"test_runner.zig",
|
||||||
]),
|
"testing.zig",
|
||||||
|
"torch.zig",
|
||||||
|
"zml.zig",
|
||||||
|
],
|
||||||
copts = ["-lc"],
|
copts = ["-lc"],
|
||||||
main = "zml.zig",
|
main = "zml.zig",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
@ -52,10 +72,6 @@ zig_library(
|
|||||||
|
|
||||||
zig_test(
|
zig_test(
|
||||||
name = "test",
|
name = "test",
|
||||||
data = [
|
|
||||||
"aio/torch/simple.pt",
|
|
||||||
"aio/torch/simple_test_4.pickle",
|
|
||||||
],
|
|
||||||
test_runner = ":test_runner",
|
test_runner = ":test_runner",
|
||||||
deps = [":zml"],
|
deps = [":zml"],
|
||||||
)
|
)
|
||||||
|
|||||||
15
zml/aio.zig
15
zml/aio.zig
@ -4,24 +4,17 @@ const asynk = @import("async");
|
|||||||
const c = @import("c");
|
const c = @import("c");
|
||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
pub const gguf = @import("aio/gguf.zig");
|
|
||||||
// pub const nemo = @import("aio/nemo.zig");
|
|
||||||
pub const safetensors = @import("aio/safetensors.zig");
|
pub const safetensors = @import("aio/safetensors.zig");
|
||||||
pub const tinyllama = @import("aio/tinyllama.zig");
|
pub const tinyllama = @import("aio/tinyllama.zig");
|
||||||
pub const torch = @import("aio/torch.zig");
|
|
||||||
// pub const yaml = @import("aio/yaml.zig");
|
|
||||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||||
const posix = @import("posix.zig");
|
const posix = @import("posix.zig");
|
||||||
const zml = @import("zml.zig");
|
const zml = @import("zml.zig");
|
||||||
|
|
||||||
pub const log = std.log.scoped(.@"zml/aio");
|
pub const log = std.log.scoped(.@"zml/aio");
|
||||||
|
|
||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
std.testing.refAllDecls(gguf);
|
|
||||||
// std.testing.refAllDecls(nemo);
|
|
||||||
std.testing.refAllDecls(safetensors);
|
std.testing.refAllDecls(safetensors);
|
||||||
std.testing.refAllDecls(torch);
|
|
||||||
// std.testing.refAllDecls(yaml);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO error set for weight loading
|
// TODO error set for weight loading
|
||||||
@ -32,12 +25,6 @@ pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8)
|
|||||||
try safetensors.open(allocator, model_path)
|
try safetensors.open(allocator, model_path)
|
||||||
else if (std.mem.endsWith(u8, model_path, ".safetensors.index.json"))
|
else if (std.mem.endsWith(u8, model_path, ".safetensors.index.json"))
|
||||||
try safetensors.open(allocator, model_path)
|
try safetensors.open(allocator, model_path)
|
||||||
else if (std.mem.endsWith(u8, model_path, ".gguf"))
|
|
||||||
try gguf.open(allocator, model_path)
|
|
||||||
else if (std.mem.endsWith(u8, model_path, ".pt"))
|
|
||||||
try torch.open(allocator, model_path)
|
|
||||||
else if (std.mem.endsWith(u8, model_path, ".tinyllama"))
|
|
||||||
try tinyllama.open(allocator, model_path)
|
|
||||||
else {
|
else {
|
||||||
std.debug.panic("File extension not recognized: {s}", .{model_path});
|
std.debug.panic("File extension not recognized: {s}", .{model_path});
|
||||||
};
|
};
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user