From 89cf2233d3b97b59ec39365490375ac315da96f2 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 23 May 2023 15:06:59 +0000 Subject: [PATCH] zml/aio: enable reading metadata from index.json for sharded safetensor files, allowing metadata storage alongside model config. --- zml/aio/safetensors.zig | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/zml/aio/safetensors.zig b/zml/aio/safetensors.zig index c87eb0f..b5f3c80 100644 --- a/zml/aio/safetensors.zig +++ b/zml/aio/safetensors.zig @@ -38,10 +38,10 @@ fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std. const json_data = try allocator.alloc(u8, (try file.stat()).size); _ = try r.readAtLeast(json_data, json_data.len); - const metadata = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data, .{ .allocate = .alloc_if_needed }); + const index = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data, .{ .allocate = .alloc_if_needed }); var loaded_files = std.StringHashMap(void).init(allocator); - const weight_map = metadata.object.get("weight_map").?.object; + const weight_map = index.object.get("weight_map").?.object; var it = weight_map.iterator(); while (it.next()) |entry| { const filename = entry.value_ptr.string; @@ -55,6 +55,11 @@ fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std. const full_filename = try std.fs.path.join(allocator, &.{ std.fs.path.dirname(path).?, filename }); try loadFile(allocator, store, files, full_filename); } + + if (index.object.get("__metadata__")) |metadata| { + var prefix_buf: [1024]u8 = undefined; + try json.parseMetadata(allocator, store, StringBuilder.initBuffer(&prefix_buf), metadata); + } } fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void {