zml/aio: enable reading metadata from index.json for sharded safetensor files, allowing metadata storage alongside model config.

This commit is contained in:
Tarry Singh 2023-05-23 15:06:59 +00:00
parent 2f54e2a5f3
commit 89cf2233d3

View File

@ -38,10 +38,10 @@ fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.
const json_data = try allocator.alloc(u8, (try file.stat()).size); const json_data = try allocator.alloc(u8, (try file.stat()).size);
_ = try r.readAtLeast(json_data, json_data.len); _ = try r.readAtLeast(json_data, json_data.len);
const metadata = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data, .{ .allocate = .alloc_if_needed }); const index = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data, .{ .allocate = .alloc_if_needed });
var loaded_files = std.StringHashMap(void).init(allocator); var loaded_files = std.StringHashMap(void).init(allocator);
const weight_map = metadata.object.get("weight_map").?.object; const weight_map = index.object.get("weight_map").?.object;
var it = weight_map.iterator(); var it = weight_map.iterator();
while (it.next()) |entry| { while (it.next()) |entry| {
const filename = entry.value_ptr.string; const filename = entry.value_ptr.string;
@ -55,6 +55,11 @@ fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.
const full_filename = try std.fs.path.join(allocator, &.{ std.fs.path.dirname(path).?, filename }); const full_filename = try std.fs.path.join(allocator, &.{ std.fs.path.dirname(path).?, filename });
try loadFile(allocator, store, files, full_filename); try loadFile(allocator, store, files, full_filename);
} }
if (index.object.get("__metadata__")) |metadata| {
var prefix_buf: [1024]u8 = undefined;
try json.parseMetadata(allocator, store, StringBuilder.initBuffer(&prefix_buf), metadata);
}
} }
fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void { fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void {