2023-01-02 14:28:25 +00:00
|
|
|
const std = @import("std");
|
2023-04-20 15:43:18 +00:00
|
|
|
/// Logger for this file, scoped under the `.zml_aio` tag.
const log = std.log.scoped(.zml_aio);
|
2023-01-02 14:28:25 +00:00
|
|
|
|
2023-04-20 15:43:18 +00:00
|
|
|
const asynk = @import("async");
|
|
|
|
|
const yaml = @import("zig-yaml");
|
2023-01-02 14:28:25 +00:00
|
|
|
|
2023-04-20 15:43:18 +00:00
|
|
|
const eval = @import("torch/eval.zig");
|
|
|
|
|
const zml = @import("../zml.zig");
|
|
|
|
|
const File = @import("torch/file.zig").File;
|
2023-01-02 14:28:25 +00:00
|
|
|
/// Unmanaged byte list. `open` initializes one over a fixed stack buffer and hands it
/// to `zml.aio.yaml.parseMetadata` as the prefix builder.
const StringBuilder = std.ArrayListUnmanaged(u8);
|
|
|
|
|
|
|
|
|
|
/// Opens the tar archive at `path` and loads its entries into a `zml.aio.BufferStore`.
///
/// Every tar entry is dispatched on its name:
/// - `*.yaml`: read fully, parsed as YAML, and its first document is forwarded to
///   `zml.aio.yaml.parseMetadata`.
/// - `*.ckpt` / `*.pt`: handled as a torch file: pickle ops are parsed, evaluated,
///   and the resulting values are registered in the store through `parseModel`.
/// - `./model_weights/`: panics, sharded weights are not supported.
/// - anything else is skipped.
///
/// The returned store holds an arena created from `allocator`; every allocation made
/// while loading goes through that arena. On error, the arena and the mapped file are
/// released before returning.
///
/// Returns `error.EmptyYamlDocument` when a `.yaml` entry contains no document, plus
/// any error raised while opening, iterating or parsing the archive.
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
    var res: zml.aio.BufferStore = .{
        .arena = std.heap.ArenaAllocator.init(allocator),
    };
    errdefer res.arena.deinit();

    // TODO(cryptodeal): this is incorrect, you should use a temporary arena for all intermediary allocations.
    const arena = res.arena.allocator();

    // TODO(cryptodeal): mapped_file will never be closed in case of success.
    // You need to store it inside the result.
    var mapped_file = try zml.aio.MemoryMappedFile.init(try asynk.File.open(path, .{}));
    errdefer mapped_file.deinit();

    // Scratch buffers the tar iterator uses to expose entry names and link names.
    var file_name_buffer: [std.fs.max_path_bytes]u8 = undefined;
    var link_name_buffer: [std.fs.max_path_bytes]u8 = undefined;
    var tar_iter = std.tar.iterator(
        mapped_file.file.reader(),
        .{
            .file_name_buffer = &file_name_buffer,
            .link_name_buffer = &link_name_buffer,
        },
    );

    while (try tar_iter.next()) |file| {
        if (std.mem.endsWith(u8, file.name, ".yaml")) {
            const yaml_data = try file.reader().readAllAlloc(arena, file.size);
            const parsed = try yaml.Yaml.load(arena, yaml_data);

            // An entry without any YAML document has nothing to index below:
            // report it as an error instead of reading out of bounds.
            if (parsed.docs.items.len == 0) return error.EmptyYamlDocument;

            var prefix_buf: [1024]u8 = undefined;
            try zml.aio.yaml.parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), parsed.docs.items[0]);
        } else if (std.mem.endsWith(u8, file.name, ".ckpt") or std.mem.endsWith(u8, file.name, ".pt")) {
            // Remember where this entry's payload starts before the torch parser
            // moves the underlying file position.
            const start = try mapped_file.file.getPos();

            var torch_file = try File.fromTarFile(arena, mapped_file, file);
            const ops = try torch_file.parsePickle(arena);
            const values = try eval.evaluate(arena, ops, true);
            try torch_file.parseModel(values, &res);

            // Since we directly manipulate the file handle pointer,
            // reset to the end of file so iterator does not error
            // and avoid `skipBytes`.
            try mapped_file.file.seekTo(start + file.size);
            file.unread_bytes.* = 0;
        } else if (std.mem.eql(u8, file.name, "./model_weights/")) {
            @panic(".NeMo sharded weights are not yet supported");
        } else {
            continue;
        }
    }

    return res;
}
|