Radix/zml/aio/nemo.zig

59 lines
2.5 KiB
Zig

const std = @import("std");
const log = std.log.scoped(.@"zml/aio");
const asynk = @import("async");
const yaml = @import("zig-yaml");
const eval = @import("torch/eval.zig");
const zml = @import("../zml.zig");
const File = @import("torch/file.zig").File;
const StringBuilder = std.ArrayListUnmanaged(u8);
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
var res: zml.aio.BufferStore = .{
.arena = std.heap.ArenaAllocator.init(allocator),
};
errdefer res.arena.deinit();
// TODO(cryptodeal): this is incorrect, you should use a temporary arena for all intermediary allocations.
const arena = res.arena.allocator();
// TODO(cryptodeal): mapped_file will never be close in case of success.
// You need to store it inside the result.
var mapped_file = try zml.aio.MemoryMappedFile.init(try asynk.File.open(path, .{}));
errdefer mapped_file.deinit();
var file_name_buffer: [std.fs.max_path_bytes]u8 = undefined;
var link_name_buffer: [std.fs.max_path_bytes]u8 = undefined;
var tar_iter = std.tar.iterator(
mapped_file.file.reader(),
.{
.file_name_buffer = &file_name_buffer,
.link_name_buffer = &link_name_buffer,
},
);
while (try tar_iter.next()) |file| {
if (std.mem.endsWith(u8, file.name, ".yaml")) {
const yaml_data = try file.reader().readAllAlloc(arena, file.size);
const parsed = try yaml.Yaml.load(arena, yaml_data);
var prefix_buf: [1024]u8 = undefined;
try zml.aio.yaml.parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), parsed.docs.items[0]);
} else if (std.mem.endsWith(u8, file.name, ".ckpt") or std.mem.endsWith(u8, file.name, ".pt")) {
const start = try mapped_file.file.getPos();
var torch_file = try File.fromTarFile(arena, mapped_file, file);
const ops = try torch_file.parsePickle(arena);
const values = try eval.evaluate(arena, ops, true);
try torch_file.parseModel(values, &res);
// Since we directly manipulate the file handle pointer,
// reset to the end of file so iterator does not error
// and avoid `skipBytes`.
try mapped_file.file.seekTo(start + file.size);
file.unread_bytes.* = 0;
} else if (std.mem.eql(u8, file.name, "./model_weights/")) @panic(".NeMo sharded weights are not yet supported") else continue;
}
return res;
}