2023-01-02 14:28:25 +00:00
|
|
|
const std = @import("std");
|
|
|
|
|
|
2025-08-29 11:03:59 +00:00
|
|
|
const async = @import("async");
|
2025-08-07 15:09:27 +00:00
|
|
|
|
|
|
|
|
const zml = @import("../zml.zig");
|
2023-01-02 14:28:25 +00:00
|
|
|
const eval = @import("torch/eval.zig");
|
2023-04-20 15:43:18 +00:00
|
|
|
const File = @import("torch/file.zig").File;
|
2023-01-02 14:28:25 +00:00
|
|
|
|
|
|
|
|
const StringBuilder = std.ArrayListUnmanaged(u8);
|
2023-06-21 14:45:14 +00:00
|
|
|
const log = std.log.scoped(.@"zml/aio");
|
2023-01-02 14:28:25 +00:00
|
|
|
|
2023-04-04 17:20:53 +00:00
|
|
|
test {
    // Pull every declaration (and thus every nested test) of this module and
    // of the torch helper modules into the test build.
    inline for (.{ @This(), eval, @import("torch/py.zig"), File }) |namespace| {
        std.testing.refAllDecls(namespace);
    }
}
|
|
|
|
|
|
2023-01-02 14:28:25 +00:00
|
|
|
/// Opens the PyTorch checkpoint at `path` and loads it into a `zml.aio.BufferStore`.
///
/// The file is memory-mapped and the mapping is handed to the returned store
/// via `BufferStore.initWithFiles`. Intermediate parsing state (the pickle ops
/// and the evaluated Python values) is allocated from a temporary arena that is
/// freed before this function returns.
///
/// A failure to open `path` is logged and then returned. Errors from mapping,
/// parsing, or building the store are propagated unchanged.
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
    const file = async.File.open(path, .{}) catch |err| {
        log.err("Failed to open {s}: {}", .{ path, err });
        return err;
    };
    // Closing is I/O and can fail. On this path we are already returning another
    // error, so log a close failure instead of asserting it is impossible:
    // `catch unreachable` would be undefined behavior in release builds.
    errdefer file.close() catch |close_err| {
        log.warn("Failed to close {s}: {s}", .{ path, @errorName(close_err) });
    };

    // Temporary memory needed to parse the pytorch file.
    var arena = std.heap.ArenaAllocator.init(allocator);
    defer arena.deinit();
    const tmp_alloc = arena.allocator();

    // NOTE(review): if any step below fails, `mmap_file` is never deinitialized,
    // and once `initWithFiles` succeeds, `res` is not deinitialized either; only
    // `file` is closed. Before adding errdefers for them, confirm whether their
    // deinit also closes `file` (that would cause a double close with the
    // errdefer above).
    const mmap_file = try zml.aio.MemoryMappedFile.init(file);
    var torch_file = try async.callBlocking(File.init, .{ tmp_alloc, mmap_file });

    const ops = try torch_file.parsePickle(tmp_alloc);
    const py_values = try eval.evaluate(tmp_alloc, ops, true);

    // file ownership is transferred to the BufferStore
    var res = try zml.aio.BufferStore.initWithFiles(allocator, &.{mmap_file});
    try torch_file.parseModel(py_values, &res);
    return res;
}
|