zml: remove nemo, sentencepiece loaders, remove zig-yaml
Progress towards
This commit is contained in:
parent
fe56f03f5d
commit
78679817df
@ -22,7 +22,6 @@ bazel_dep(name = "rules_uv", version = "0.87.0")
|
||||
bazel_dep(name = "rules_zig", version = "20250714.0-b14a4f1")
|
||||
bazel_dep(name = "toolchains_llvm_bootstrapped", version = "0.2.4")
|
||||
bazel_dep(name = "with_cfg.bzl", version = "0.11.0")
|
||||
bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf")
|
||||
|
||||
bazel_dep(name = "buildifier_prebuilt", version = "8.2.0.2", dev_dependency = True)
|
||||
|
||||
@ -155,12 +154,14 @@ apt.install(
|
||||
manifest = "//runtimes/cuda:packages.yaml",
|
||||
)
|
||||
use_repo(apt, "apt_cuda")
|
||||
|
||||
apt.install(
|
||||
name = "apt_rocm",
|
||||
lock = "//runtimes/rocm:packages.lock.json",
|
||||
manifest = "//runtimes/rocm:packages.yaml",
|
||||
)
|
||||
use_repo(apt, "apt_rocm")
|
||||
|
||||
apt.install(
|
||||
name = "apt_neuron",
|
||||
lock = "//runtimes/neuron:packages.lock.json",
|
||||
|
||||
107
zml/aio/json.zig
107
zml/aio/json.zig
@ -1,107 +0,0 @@
|
||||
const asynk = @import("async");
|
||||
const std = @import("std");
|
||||
const zml = @import("../zml.zig");
|
||||
|
||||
const StringBuilder = std.ArrayListUnmanaged(u8);
|
||||
const Allocator = std.mem.Allocator;
|
||||
|
||||
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
|
||||
const file = try std.fs.cwd().openFile(path, .{});
|
||||
defer file.close();
|
||||
var res: zml.aio.BufferStore = .{
|
||||
.arena = std.heap.ArenaAllocator.init(allocator),
|
||||
};
|
||||
errdefer res.arena.deinit();
|
||||
const arena = res.arena.allocator();
|
||||
|
||||
const json_data = try file.reader().readAllAlloc(arena, (try file.metadata()).size());
|
||||
const metadata = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data, .{ .allocate = .alloc_if_needed });
|
||||
|
||||
var it = metadata.object.iterator();
|
||||
while (it.next()) |entry| {
|
||||
var prefix_buf: [1024]u8 = undefined;
|
||||
try parseMetadata(allocator, &res, StringBuilder.initBuffer(&prefix_buf), entry.value_ptr.*);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, val: std.json.Value) !void {
|
||||
const metadata = &store._metadata;
|
||||
const key = prefix.items;
|
||||
return switch (val) {
|
||||
.null => try metadata.put(allocator, try allocator.dupe(u8, key), .null),
|
||||
.bool => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .bool = v }),
|
||||
.integer => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .int = v }),
|
||||
.float => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .float = v }),
|
||||
.number_string, .string => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .string = try allocator.dupe(u8, v) }),
|
||||
.array => |v| {
|
||||
if (v.items.len == 0) return;
|
||||
return if (validSlice(v)) |item_type| {
|
||||
const data: zml.aio.Metadata = switch (item_type) {
|
||||
.bool => blk: {
|
||||
const values = try allocator.alloc(bool, v.items.len);
|
||||
for (v.items, 0..) |item, i| values[i] = item.bool;
|
||||
break :blk .{ .array_bool = values };
|
||||
},
|
||||
.integer => blk: {
|
||||
const values = try allocator.alloc(i64, v.items.len);
|
||||
for (v.items, 0..) |item, i| values[i] = item.integer;
|
||||
break :blk .{ .array_int = values };
|
||||
},
|
||||
.float => blk: {
|
||||
const values = try allocator.alloc(f64, v.items.len);
|
||||
for (v.items, 0..) |item, i| values[i] = item.float;
|
||||
break :blk .{ .array_float = values };
|
||||
},
|
||||
inline .string, .number_string => |tag| blk: {
|
||||
const values = try allocator.alloc([]const u8, v.items.len);
|
||||
for (v.items, 0..) |item, i| {
|
||||
values[i] = @field(item, @tagName(tag));
|
||||
}
|
||||
break :blk .{ .array_string = values };
|
||||
},
|
||||
.null, .array, .object => unreachable,
|
||||
};
|
||||
try metadata.put(allocator, try allocator.dupe(u8, key), data);
|
||||
} else {
|
||||
for (v.items, 0..) |item, i| {
|
||||
var new_prefix = prefix;
|
||||
if (prefix.items.len > 0)
|
||||
new_prefix.appendAssumeCapacity('.');
|
||||
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
|
||||
try parseMetadata(allocator, store, new_prefix, item);
|
||||
}
|
||||
};
|
||||
},
|
||||
.object => |v| {
|
||||
var obj_iter = v.iterator();
|
||||
while (obj_iter.next()) |entry| {
|
||||
var new_prefix = prefix;
|
||||
if (prefix.items.len > 0)
|
||||
new_prefix.appendAssumeCapacity('.');
|
||||
new_prefix.appendSliceAssumeCapacity(entry.key_ptr.*);
|
||||
try parseMetadata(allocator, store, new_prefix, entry.value_ptr.*);
|
||||
}
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/// We can only create a Zig slice out of json array, if all values
|
||||
/// in the array have the same type.
|
||||
fn validSlice(v: std.json.Array) ?std.meta.Tag(std.json.Value) {
|
||||
if (v.items.len == 0) return null;
|
||||
|
||||
const item_type: std.meta.Tag(std.json.Value) = v.items[0];
|
||||
switch (item_type) {
|
||||
.null, .array, .object => return null,
|
||||
else => {},
|
||||
}
|
||||
|
||||
for (v.items[1..]) |item| {
|
||||
if (item != item_type)
|
||||
return null;
|
||||
}
|
||||
|
||||
return item_type;
|
||||
}
|
||||
@ -1,58 +0,0 @@
|
||||
const std = @import("std");
|
||||
const log = std.log.scoped(.@"zml/aio");
|
||||
|
||||
const asynk = @import("async");
|
||||
const yaml = @import("zig-yaml");
|
||||
|
||||
const eval = @import("torch/eval.zig");
|
||||
const zml = @import("../zml.zig");
|
||||
const File = @import("torch/file.zig").File;
|
||||
const StringBuilder = std.ArrayListUnmanaged(u8);
|
||||
|
||||
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
|
||||
var res: zml.aio.BufferStore = .{
|
||||
.arena = std.heap.ArenaAllocator.init(allocator),
|
||||
};
|
||||
errdefer res.arena.deinit();
|
||||
|
||||
// TODO(cryptodeal): this is incorrect, you should use a temporary arena for all intermediary allocations.
|
||||
const arena = res.arena.allocator();
|
||||
|
||||
// TODO(cryptodeal): mapped_file will never be close in case of success.
|
||||
// You need to store it inside the result.
|
||||
var mapped_file = try zml.aio.MemoryMappedFile.init(try asynk.File.open(path, .{}));
|
||||
errdefer mapped_file.deinit();
|
||||
|
||||
var file_name_buffer: [std.fs.max_path_bytes]u8 = undefined;
|
||||
var link_name_buffer: [std.fs.max_path_bytes]u8 = undefined;
|
||||
var tar_iter = std.tar.iterator(
|
||||
mapped_file.file.reader(),
|
||||
.{
|
||||
.file_name_buffer = &file_name_buffer,
|
||||
.link_name_buffer = &link_name_buffer,
|
||||
},
|
||||
);
|
||||
while (try tar_iter.next()) |file| {
|
||||
if (std.mem.endsWith(u8, file.name, ".yaml")) {
|
||||
const yaml_data = try file.reader().readAllAlloc(arena, file.size);
|
||||
const parsed = try yaml.Yaml.load(arena, yaml_data);
|
||||
|
||||
var prefix_buf: [1024]u8 = undefined;
|
||||
try zml.aio.yaml.parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), parsed.docs.items[0]);
|
||||
} else if (std.mem.endsWith(u8, file.name, ".ckpt") or std.mem.endsWith(u8, file.name, ".pt")) {
|
||||
const start = try mapped_file.file.getPos();
|
||||
var torch_file = try File.fromTarFile(arena, mapped_file, file);
|
||||
const ops = try torch_file.parsePickle(arena);
|
||||
const values = try eval.evaluate(arena, ops, true);
|
||||
|
||||
try torch_file.parseModel(values, &res);
|
||||
// Since we directly manipulate the file handle pointer,
|
||||
// reset to the end of file so iterator does not error
|
||||
// and avoid `skipBytes`.
|
||||
try mapped_file.file.seekTo(start + file.size);
|
||||
file.unread_bytes.* = 0;
|
||||
} else if (std.mem.eql(u8, file.name, "./model_weights/")) @panic(".NeMo sharded weights are not yet supported") else continue;
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
@ -1,12 +1,13 @@
|
||||
const asynk = @import("async");
|
||||
const std = @import("std");
|
||||
const zml = @import("../zml.zig");
|
||||
const json = @import("json.zig");
|
||||
const HostBuffer = zml.HostBuffer;
|
||||
const Allocator = std.mem.Allocator;
|
||||
|
||||
const asynk = @import("async");
|
||||
|
||||
const MemoryMappedFile = @import("../aio.zig").MemoryMappedFile;
|
||||
const zml = @import("../zml.zig");
|
||||
const HostBuffer = zml.HostBuffer;
|
||||
|
||||
const StringBuilder = std.ArrayListUnmanaged(u8);
|
||||
const Allocator = std.mem.Allocator;
|
||||
const log = std.log.scoped(.@"zml/io");
|
||||
|
||||
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
|
||||
@ -55,11 +56,6 @@ fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.
|
||||
const full_filename = try std.fs.path.join(allocator, &.{ std.fs.path.dirname(path).?, filename });
|
||||
try loadFile(allocator, store, files, full_filename);
|
||||
}
|
||||
|
||||
if (index.object.get("__metadata__")) |metadata| {
|
||||
var prefix_buf: [1024]u8 = undefined;
|
||||
try json.parseMetadata(allocator, store, StringBuilder.initBuffer(&prefix_buf), metadata);
|
||||
}
|
||||
}
|
||||
|
||||
fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void {
|
||||
@ -89,11 +85,6 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.Array
|
||||
var it = metadata.object.iterator();
|
||||
while (it.next()) |entry| {
|
||||
const key = entry.key_ptr.*;
|
||||
if (std.mem.eql(u8, key, "__metadata__")) {
|
||||
var prefix_buf: [1024]u8 = undefined;
|
||||
try json.parseMetadata(allocator, store, StringBuilder.initBuffer(&prefix_buf), entry.value_ptr.*);
|
||||
continue;
|
||||
}
|
||||
const val = entry.value_ptr.*;
|
||||
const shape_field = val.object.get("shape").?.array;
|
||||
if (shape_field.items.len > zml.Shape.MAX_RANK) {
|
||||
|
||||
@ -1,38 +0,0 @@
|
||||
const std = @import("std");
|
||||
const asynk = @import("async");
|
||||
const zml = @import("../zml.zig");
|
||||
|
||||
const sentencepiece_proto = @import("//sentencepiece:model_proto");
|
||||
const Normalizer = zml.tokenizer.Normalizer;
|
||||
|
||||
fn parseTokenId(id: ?i32) u32 {
|
||||
if (id) |idx| {
|
||||
if (idx > 0) return @intCast(idx);
|
||||
}
|
||||
|
||||
return std.math.maxInt(u32);
|
||||
}
|
||||
|
||||
pub fn normalizerFromSpec(spec: sentencepiece_proto.NormalizerSpec) Normalizer {
|
||||
std.log.info("NormalizerSpec: {}", .{spec});
|
||||
if (spec.normalization_rule_tsv) |rule_tsv| {
|
||||
if (!rule_tsv.isEmpty()) {
|
||||
std.debug.panic("SentencePiece model with normalization rules not supported: model.normalizer_spec.normalization_rule_tsv: {s}", .{spec.normalization_rule_tsv.?.getSlice()});
|
||||
}
|
||||
}
|
||||
if (!std.mem.eql(u8, spec.name.?.getSlice(), "identity")) std.debug.panic("Normalizer only supports NormalizerSpec with name \"identity\", got \"{s}\"", .{spec.name.?.getSlice()});
|
||||
if (!spec.escape_whitespaces.?) std.debug.panic("Normalizer only supports NormalizerSpec with \"escape_whitespaces\" flag set", .{});
|
||||
if (spec.remove_extra_whitespaces) |_| {} else std.debug.panic("Normalizer only supports NormalizerSpec with \"remove_extra_whitespaces\" flag set", .{});
|
||||
|
||||
return Normalizer.init(
|
||||
.{
|
||||
.remove_extra_whitespaces = spec.remove_extra_whitespaces orelse false,
|
||||
.add_dummy_prefix = spec.add_dummy_prefix orelse false,
|
||||
.add_dummy_suffix = false,
|
||||
.lower_case_ascii = false,
|
||||
.split_on_punct_ascii = false,
|
||||
.use_nfc = false,
|
||||
},
|
||||
if (spec.escape_whitespaces orelse false) Normalizer.sentencepiece_space else null,
|
||||
);
|
||||
}
|
||||
@ -1,93 +0,0 @@
|
||||
const std = @import("std");
|
||||
const yaml = @import("zig-yaml");
|
||||
const zml = @import("../zml.zig");
|
||||
|
||||
const Allocator = std.mem.Allocator;
|
||||
|
||||
const StringBuilder = std.ArrayListUnmanaged(u8);
|
||||
|
||||
pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
|
||||
const file = try std.fs.cwd().openFile(path, .{});
|
||||
defer file.close();
|
||||
var res: zml.aio.BufferStore = .{
|
||||
.arena = std.heap.ArenaAllocator.init(allocator),
|
||||
};
|
||||
errdefer res.arena.deinit();
|
||||
const arena = res.arena.allocator();
|
||||
|
||||
const yaml_data = try file.reader().readAllAlloc(arena, (try file.metadata()).size());
|
||||
const parsed = try yaml.Yaml.load(arena, yaml_data);
|
||||
|
||||
var prefix_buf: [1024]u8 = undefined;
|
||||
try parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), parsed.docs.items[0]);
|
||||
return res;
|
||||
}
|
||||
|
||||
pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: StringBuilder, val: yaml.Value) !void {
|
||||
switch (val) {
|
||||
.int => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .int = v }),
|
||||
.float => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .float = v }),
|
||||
.string => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .string = v }),
|
||||
.list => |v| switch (validSlice(v)) {
|
||||
true => {
|
||||
if (v.len == 0) return;
|
||||
switch (v[0]) {
|
||||
.int => {
|
||||
const values = try allocator.alloc(i64, v.len);
|
||||
errdefer allocator.free(values);
|
||||
for (v, 0..) |item, i| values[i] = item.int;
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array_int = values });
|
||||
},
|
||||
.float => {
|
||||
const values = try allocator.alloc(f64, v.len);
|
||||
errdefer allocator.free(values);
|
||||
for (v, 0..) |item, i| values[i] = item.float;
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array_float = values });
|
||||
},
|
||||
.string => {
|
||||
const values = try allocator.alloc([]const u8, v.len);
|
||||
errdefer allocator.free(values);
|
||||
for (v, 0..) |item, i| {
|
||||
values[i] = try allocator.dupe(u8, item.string);
|
||||
}
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array_string = values });
|
||||
},
|
||||
.list => unreachable,
|
||||
else => {},
|
||||
}
|
||||
},
|
||||
false => for (v, 0..) |item, i| {
|
||||
var new_key = key;
|
||||
if (key.items.len > 0)
|
||||
new_key.appendAssumeCapacity('.');
|
||||
new_key.items.len += std.fmt.formatIntBuf(new_key.unusedCapacitySlice(), i, 10, .lower, .{});
|
||||
try parseMetadata(allocator, store, new_key, item);
|
||||
},
|
||||
},
|
||||
.map => {
|
||||
var map_iter = val.map.iterator();
|
||||
while (map_iter.next()) |entry| {
|
||||
var new_prefix = key;
|
||||
if (key.items.len > 0)
|
||||
new_prefix.appendAssumeCapacity('.');
|
||||
new_prefix.appendSliceAssumeCapacity(entry.key_ptr.*);
|
||||
try parseMetadata(allocator, store, new_prefix, entry.value_ptr.*);
|
||||
}
|
||||
},
|
||||
else => {},
|
||||
}
|
||||
}
|
||||
|
||||
fn validSlice(v: []yaml.Value) bool {
|
||||
if (v.len == 0) return false;
|
||||
const item_type = std.meta.activeTag(v[0]);
|
||||
switch (item_type) {
|
||||
.empty, .list, .map => return false,
|
||||
else => {},
|
||||
}
|
||||
|
||||
for (v[1..]) |item|
|
||||
if (item_type != std.meta.activeTag(item)) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
56
zml/zml.zig
56
zml/zml.zig
@ -3,37 +3,20 @@
|
||||
//! compiling it for various accelerators and targets, and executing it.
|
||||
//!
|
||||
|
||||
// Namespaces
|
||||
const std = @import("std");
|
||||
|
||||
pub const tokenizer = @import("zml/tokenizer");
|
||||
|
||||
pub const aio = @import("aio.zig");
|
||||
pub const Buffer = @import("buffer.zig").Buffer;
|
||||
pub const Bufferized = @import("tensor.zig").Bufferized;
|
||||
pub const CompilationOptions = @import("platform.zig").CompilationOptions;
|
||||
pub const context = @import("context.zig");
|
||||
pub const Context = @import("context.zig").Context;
|
||||
pub const Data = @import("dtype.zig").Data;
|
||||
pub const DataType = @import("dtype.zig").DataType;
|
||||
pub const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||
pub const Platform = @import("platform.zig").Platform;
|
||||
pub const Shape = @import("shape.zig").Shape;
|
||||
pub const ShapeOf = @import("tensor.zig").ShapeOf;
|
||||
pub const Target = @import("platform.zig").Target;
|
||||
pub const Tensor = @import("tensor.zig").Tensor;
|
||||
|
||||
// Namespaces
|
||||
pub const context = @import("context.zig");
|
||||
pub const exe = @import("exe.zig");
|
||||
pub const floats = @import("floats.zig");
|
||||
pub const helpers = @import("helpers.zig");
|
||||
pub const nn = @import("nn.zig");
|
||||
pub const module = @import("module.zig");
|
||||
pub const meta = @import("meta.zig");
|
||||
pub const platform = @import("platform.zig");
|
||||
pub const mlir = @import("mlirx.zig");
|
||||
pub const pjrt = @import("pjrtx.zig");
|
||||
pub const testing = @import("testing.zig");
|
||||
pub const torch = @import("torch.zig");
|
||||
|
||||
// pub const tokenizer = @import("tokenizer.zig");
|
||||
pub const tokenizer = @import("zml/tokenizer");
|
||||
|
||||
pub const call = ops.call;
|
||||
pub const compile = exe.compile;
|
||||
pub const compileWithPrefix = exe.compileWithPrefix;
|
||||
pub const compileFn = exe.compileFn;
|
||||
@ -41,19 +24,32 @@ pub const compileModel = exe.compileModel;
|
||||
pub const FnExe = exe.FnExe;
|
||||
pub const ModuleExe = exe.ModuleExe;
|
||||
pub const ModuleSignature = exe.ModuleSignature;
|
||||
|
||||
pub const floats = @import("floats.zig");
|
||||
pub const helpers = @import("helpers.zig");
|
||||
pub const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||
pub const meta = @import("meta.zig");
|
||||
pub const mlir = @import("mlirx.zig");
|
||||
pub const module = @import("module.zig");
|
||||
pub const nn = @import("nn.zig");
|
||||
pub const ops = @import("ops.zig");
|
||||
pub const call = ops.call;
|
||||
pub const pjrt = @import("pjrtx.zig");
|
||||
pub const platform = @import("platform.zig");
|
||||
pub const Platform = @import("platform.zig").Platform;
|
||||
pub const Shape = @import("shape.zig").Shape;
|
||||
pub const ShapeOf = @import("tensor.zig").ShapeOf;
|
||||
pub const Target = @import("platform.zig").Target;
|
||||
pub const Tensor = @import("tensor.zig").Tensor;
|
||||
pub const testing = @import("testing.zig");
|
||||
pub const torch = @import("torch.zig");
|
||||
|
||||
// pub const tokenizer = @import("tokenizer.zig");
|
||||
pub const tools = struct {
|
||||
pub const Tracer = @import("tools/tracer.zig").Tracer;
|
||||
};
|
||||
|
||||
pub const aio = @import("aio.zig");
|
||||
pub const sentencepiece = @import("aio/sentencepiece.zig");
|
||||
|
||||
pub const log = std.log.scoped(.zml);
|
||||
|
||||
const std = @import("std");
|
||||
|
||||
test {
|
||||
// NOTE : testing entrypoint.
|
||||
// Don't forget to import your module if you want to declare tests declarations that will be run by //zml:test
|
||||
|
||||
Loading…
Reference in New Issue
Block a user