From ef922e3aeaa84b3a77affd2734522c813535d8a0 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 28 Mar 2023 16:17:00 +0000 Subject: [PATCH] Fix empty JSON array handling in safetensor metadata loader and refactor torch loader (make ops slices const and improve readability). --- zml/aio/json.zig | 97 ++++++++++++++++++++++------------------ zml/aio/safetensors.zig | 56 ++++++++++++----------- zml/aio/torch.zig | 86 ++++++++++++++++------------------- zml/aio/torch/eval.zig | 2 +- zml/aio/torch/ops.zig | 61 +++++++++++++------------ zml/aio/torch/parser.zig | 21 +++------ zml/aio/torch/value.zig | 2 +- 7 files changed, 160 insertions(+), 165 deletions(-) diff --git a/zml/aio/json.zig b/zml/aio/json.zig index b08cb17..c8ba494 100644 --- a/zml/aio/json.zig +++ b/zml/aio/json.zig @@ -27,77 +27,86 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore return res; } -pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: StringBuilder, val: std.json.Value) !void { +pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, val: std.json.Value) !void { const metadata = &store._metadata; - switch (val) { - .null => try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .null = {} }), - .bool => |v| try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .boolval = v }), - .integer => |v| try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .int64 = v }), - .float => |v| try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .float64 = v }), - .number_string, .string => |v| try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .string = try allocator.dupe(u8, v) }), - .array => |v| switch (validSlice(v)) { - true => { - if (v.items.len == 0) return; - switch (v.items[0]) { - .bool => { + const key = prefix.items; + return switch (val) { + .null => try metadata.put(allocator, try allocator.dupe(u8, key), .{ .null = {} }), + .bool => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .boolval = v }), + .integer => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .int64 = v }), + .float => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .float64 = v }), + .number_string, .string => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .string = try allocator.dupe(u8, v) }), + .array => |v| { + if (v.items.len == 0) return; + return if (validSlice(v)) |item_type| { + const data, const dtype: zml.aio.Value.Slice.ItemType = switch (item_type) { + .bool => blk: { const values = try allocator.alloc(bool, v.items.len); - errdefer allocator.free(values); for (v.items, 0..) |item, i| values[i] = item.bool; - try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .boolval, .data = std.mem.sliceAsBytes(values) } }); + break :blk .{ std.mem.sliceAsBytes(values), .boolval }; }, - .integer => { + .integer => blk: { const values = try allocator.alloc(i64, v.items.len); - errdefer allocator.free(values); for (v.items, 0..) |item, i| values[i] = item.integer; - try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(values) } }); + break :blk .{ std.mem.sliceAsBytes(values), .int64 }; }, - .float => { + .float => blk: { const values = try allocator.alloc(f64, v.items.len); - errdefer allocator.free(values); for (v.items, 0..) |item, i| values[i] = item.float; - try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .float64, .data = std.mem.sliceAsBytes(values) } }); + break :blk .{ std.mem.sliceAsBytes(values), .float64 }; }, - inline .string, .number_string => |_, tag| { + inline .string, .number_string => |tag| blk: { const values = try allocator.alloc([]const u8, v.items.len); - errdefer allocator.free(values); for (v.items, 0..) |item, i| { values[i] = @field(item, @tagName(tag)); } - try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .string, .data = std.mem.sliceAsBytes(values) } }); + break :blk .{ std.mem.sliceAsBytes(values), .string }; }, - else => unreachable, + .null, .array, .object => unreachable, + }; + try metadata.put( + allocator, + try allocator.dupe(u8, key), + .{ .array = .{ .item_type = dtype, .data = data } }, + ); + } else { + for (v.items, 0..) |item, i| { + var new_prefix = prefix; + if (prefix.items.len > 0) + new_prefix.appendAssumeCapacity('.'); + new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); + try parseMetadata(allocator, store, new_prefix, item); } - }, - false => for (v.items, 0..) |item, i| { - var new_key = key; - if (key.items.len > 0) - new_key.appendAssumeCapacity('.'); - new_key.items.len += std.fmt.formatIntBuf(new_key.unusedCapacitySlice(), i, 10, .lower, .{}); - try parseMetadata(allocator, store, new_key, item); - }, + }; }, .object => |v| { var obj_iter = v.iterator(); while (obj_iter.next()) |entry| { - var new_key = key; - if (key.items.len > 0) - new_key.appendAssumeCapacity('.'); - new_key.appendSliceAssumeCapacity(entry.key_ptr.*); - try parseMetadata(allocator, store, new_key, entry.value_ptr.*); + var new_prefix = prefix; + if (prefix.items.len > 0) + new_prefix.appendAssumeCapacity('.'); + new_prefix.appendSliceAssumeCapacity(entry.key_ptr.*); + try parseMetadata(allocator, store, new_prefix, entry.value_ptr.*); } }, - } + }; } -fn validSlice(v: std.json.Array) bool { - const item_type = std.meta.activeTag(v.items[0]); +/// We can only create a Zig slice out of json array, if all values +/// in the array have the same type. +fn validSlice(v: std.json.Array) ?std.meta.Tag(std.json.Value) { + if (v.items.len == 0) return null; + + const item_type: std.meta.Tag(std.json.Value) = v.items[0]; switch (item_type) { - .null, .array, .object => return false, + .null, .array, .object => return null, else => {}, } - for (v.items[1..]) |item| - if (item_type != std.meta.activeTag(item)) return false; + for (v.items[1..]) |item| { + if (item != item_type) + return null; + } - return true; + return item_type; } diff --git a/zml/aio/safetensors.zig b/zml/aio/safetensors.zig index c56e4b0..519be29 100644 --- a/zml/aio/safetensors.zig +++ b/zml/aio/safetensors.zig @@ -11,30 +11,6 @@ const StringBuilder = std.ArrayListUnmanaged(u8); const Allocator = std.mem.Allocator; const log = std.log.scoped(.zml_io); -fn stringToDtype(v: []const u8) !zml.DataType { - const Case = enum { F64, F32, F16, BF16, F8_E4M3, I64, I32, I16, I8, U64, U32, U16, U8, BOOL }; - if (std.meta.stringToEnum(Case, v)) |case| { - return switch (case) { - .F64 => .f64, - .F32 => .f32, - .F16 => .f16, - .BF16 => .bf16, - .F8_E4M3 => .f8e4m3fn, - .I64 => .i64, - .I32 => .i32, - .I16 => .i16, - .I8 => .i8, - .U64 => .u64, - .U32 => .u32, - .U16 => .u16, - .U8 => .u8, - .BOOL => .bool, - }; - } - std.log.err("Unsupported type-string: {s}\n", .{v}); - return error.UnsupportedDataType; -} - pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore { var res: zml.aio.BufferStore = .{ .arena = std.heap.ArenaAllocator.init(allocator), @@ -93,9 +69,13 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.Array const json_header_length: usize = @intCast(try r.readInt(u64, std.builtin.Endian.little)); const json_data = try allocator.alloc(u8, json_header_length); - _ = try r.readAtLeast(json_data, json_header_length); - const metadata = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data, .{ .allocate = .alloc_if_needed }); + const n = try r.readAll(json_data); + if (n != json_header_length) { + log.err("Failed to read the full {} bytes of json header from file {s}", .{ n, path }); + return error.CorruptedFile; + } + const metadata = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data[0..n], .{}); var buffer_file = try MemoryMappedFile.init(file); errdefer buffer_file.deinit(); buffer_file.data_offset = 8 + json_header_length; @@ -138,3 +118,27 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.Array try store.buffers.put(allocator, try allocator.dupe(u8, key), buf); } } + +fn stringToDtype(safetensor_type: []const u8) !zml.DataType { + const map = std.StaticStringMap(zml.DataType).initComptime(.{ + .{ "F64", .f64 }, + .{ "F32", .f32 }, + .{ "F16", .f16 }, + .{ "BF16", .bf16 }, + .{ "F8_E4M3", .f8e4m3fn }, + .{ "I64", .i64 }, + .{ "I32", .i32 }, + .{ "I16", .i16 }, + .{ "I8", .i8 }, + .{ "U64", .u64 }, + .{ "U32", .u32 }, + .{ "U16", .u16 }, + .{ "U8", .u8 }, + .{ "BOOL", .bool }, + }); + + return map.get(safetensor_type) orelse { + log.err("Unsupported safetensor data type: {s}", .{safetensor_type}); + return error.UnsupportedDataType; + }; +} diff --git a/zml/aio/torch.zig b/zml/aio/torch.zig index b5e2404..5620a69 100644 --- a/zml/aio/torch.zig +++ b/zml/aio/torch.zig @@ -20,41 +20,6 @@ const StringBuilder = std.ArrayListUnmanaged(u8); const Allocator = std.mem.Allocator; const log = std.log.scoped(.zml_io); -const TorchType = enum { - float64, - double, - float32, - float, - float16, - half, - bfloat16, - int64, - long, - int32, - int, - int16, - short, - int8, - char, - uint8, - byte, -}; - -fn dtypeFromStr(str: []const u8) !zml.DataType { - const case = std.meta.stringToEnum(TorchType, str) orelse return error.UnknownTensorType; - return switch (case) { - .float64, .double => .f64, - .float32, .float => .f32, - .float16, .half => .f16, - .bfloat16 => .bf16, - .int64, .long => .i64, - .int32, .int => .i32, - .int16, .short => .i16, - .int8, .char => .i8, - .uint8, .byte => .u8, - }; -} - /// Opens and loads a BufferStore from the torch file at the given path. pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore { const file = asynk.File.open(path, .{}) catch |err| { @@ -80,6 +45,37 @@ pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore { return res; } +/// Convert from a torch.Storage to a `zml.DataType`. +/// TODO: make this future proof, storage type are going to get replaced with torch.UntypedStorage +/// See https://pytorch.org/docs/stable/storage.html +fn storageToDtype(storage_type: []const u8) !zml.DataType { + const torch_type = storage_type[0 .. storage_type.len - "Storage".len]; + const map = std.StaticStringMap(zml.DataType).initComptime(.{ + .{ "Double", .f64 }, + .{ "Float", .f32 }, + .{ "Half", .f16 }, + .{ "Long", .i64 }, + .{ "Int", .i32 }, + .{ "Short", .i16 }, + .{ "Char", .i8 }, + .{ "Byte", .u8 }, + .{ "Bool", .bool }, + .{ "BFloat16", .bf16 }, + .{ "ComplexDouble", .c128 }, + .{ "ComplexFloat", .c64 }, + // QUInt8Storage + // QInt8Storage + // QInt32Storage + // QUInt4x2Storage + // QUInt2x4Storage + }); + + return map.get(torch_type) orelse { + log.err("Unsupported torch storage type: {s}", .{storage_type}); + return error.UnsupportedDataType; + }; +} + pub const PickleData = struct { stack: PickleStack, memo: PickleMemo, @@ -89,7 +85,7 @@ pub const PickleData = struct { return switch (v) { .global => |object| switch (object.member) { .raw => |raw| { - if (std.mem.eql(u8, ns, raw.global[0]) and std.mem.eql(u8, name, raw.global[1]) and object.args[0] == .seq) { + if (std.mem.eql(u8, ns, raw.global.module) and std.mem.eql(u8, name, raw.global.class) and object.args[0] == .seq) { return true; } else return false; }, @@ -179,7 +175,7 @@ pub const PickleData = struct { const rank = raw_shape.values.len; const shape = dimsFromValues(raw_shape.values); var strides = dimsFromValues(raw_strides.values); - const stype: []const u8, const sfile: []const u8, const sdev: []const u8 = switch (pidval.ref) { + const storage_type, const sfile = switch (pidval.ref) { .seq => |seq| blk: { const sargs = seq.values; if (seq.type == .tuple and @@ -191,17 +187,16 @@ pub const PickleData = struct { { const op = sargs[1].raw.global; const sfile = sargs[2].string; - const sdev = sargs[3].string; - const styp = op[1]; - if (std.mem.eql(u8, "torch", op[0]) and std.mem.endsWith(u8, styp, "Storage")) { - break :blk .{ std.ascii.lowerString(styp[0 .. styp.len - 7], styp[0 .. styp.len - 7]), sfile, sdev }; + // const sdev = sargs[3].string; + if (std.mem.eql(u8, "torch", op.module) and std.mem.endsWith(u8, op.class, "Storage")) { + break :blk .{ op.class, sfile }; } else @panic("Unexpected storage type part of persistant ID"); } else @panic("Unexpected value for persistant ID"); }, else => @panic("Unexpected value for persistant ID"), }; - _ = sdev; - const data_type = try dtypeFromStr(stype); + + const data_type = try storageToDtype(storage_type); for (strides[0..rank]) |*s| s.* *= data_type.sizeOf(); var sfile_buf = std.ArrayList(u8).init(allocator); @@ -296,10 +291,7 @@ pub const PickleData = struct { log.err("Duplicate key: {s}", .{new_prefix.items}); allocator.free(key); } else { - const val = try allocator.alloc(u8, global[0].len + 1 + global[1].len); - @memcpy(val[0..global[0].len], global[0]); - val[global[0].len] = '.'; - @memcpy(val[global[0].len + 1 ..], global[1]); + const val = try std.mem.join(allocator, ".", &.{ global.module, global.class }); d.value_ptr.* = .{ .string = val }; } }, diff --git a/zml/aio/torch/eval.zig b/zml/aio/torch/eval.zig index b04656a..d54f02d 100644 --- a/zml/aio/torch/eval.zig +++ b/zml/aio/torch/eval.zig @@ -398,7 +398,7 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs: }), .list => try stack.values.append(.{ .seq = .{ .type = .list, .values = try stack.popMark(allocator) } }), .inst => |v| try stack.values.append(blk: { - const tup_items = try allocator.dupe(Value, &.{ .{ .string = v[0] }, .{ .string = v[1] } }); + const tup_items = try allocator.dupe(Value, &.{ .{ .string = v.module }, .{ .string = v.class } }); break :blk .{ .object = try Object.init(allocator, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try stack.popMark(allocator)) }; }), .obj => try stack.values.append(blk: { diff --git a/zml/aio/torch/ops.zig b/zml/aio/torch/ops.zig index a12a556..1a00f5d 100644 --- a/zml/aio/torch/ops.zig +++ b/zml/aio/torch/ops.zig @@ -7,35 +7,35 @@ pub const PickleOp = union(RawPickleOp) { pop, pop_mark, dup, - float: []u8, - int: []u8, + float: []const u8, + int: []const u8, binint: i32, binint1: u8, - long: []u8, + long: []const u8, binint2: u16, none, - persid: []u8, + persid: []const u8, binpersid, reduce, - string: []u8, - binstring: []u8, - short_binstring: []u8, - unicode: []u8, - binunicode: []u8, + string: []const u8, + binstring: []const u8, + short_binstring: []const u8, + unicode: []const u8, + binunicode: []const u8, append, build, - global: [2][]u8, + global: PyType, dict, empty_dict, appends, - get: []u8, + get: []const u8, binget: u8, - inst: [2][]u8, + inst: PyType, long_binget: u32, list, empty_list, obj, - put: []u8, + put: []const u8, binput: u8, long_binput: u32, setitem, @@ -53,13 +53,13 @@ pub const PickleOp = union(RawPickleOp) { tuple3, newtrue, newfalse, - long1: []u8, - long4: []u8, - binbytes: []u8, - short_binbytes: []u8, - short_binunicode: []u8, - binunicode8: []u8, - binbytes8: []u8, + long1: []const u8, + long4: []const u8, + binbytes: []const u8, + short_binbytes: []const u8, + short_binunicode: []const u8, + binunicode8: []const u8, + binbytes8: []const u8, empty_set, additems, frozenset, @@ -67,10 +67,12 @@ pub const PickleOp = union(RawPickleOp) { stack_global, memoize, frame: u64, - bytearray8: []u8, + bytearray8: []const u8, next_buffer, readonly_buffer, + pub const PyType = struct { module: []const u8, class: []const u8 }; + pub fn deinit(self: PickleOp, allocator: std.mem.Allocator) void { switch (self) { .float, @@ -93,10 +95,9 @@ pub const PickleOp = union(RawPickleOp) { .binbytes8, .bytearray8, => |v| allocator.free(v), - .global, .inst => |fields| { - inline for (fields) |field| { - allocator.free(field); - } + .global, .inst => |py_type| { + allocator.free(py_type.module); + allocator.free(py_type.class); }, else => {}, } @@ -131,12 +132,10 @@ pub const PickleOp = union(RawPickleOp) { return res; }, inline .global, .inst => |v, tag| { - var out: std.meta.Tuple(&.{ []u8, []u8 }) = undefined; - inline for (0..2) |i| { - out[i] = try allocator.alloc(u8, v[i].len); - @memcpy(out[i], v[i]); - } - @field(res, @tagName(tag)) = out; + @field(res, @tagName(tag)) = PyType{ + .module = try allocator.dupe(u8, v.module), + .class = try allocator.dupe(u8, v.class), + }; return res; }, else => self, diff --git a/zml/aio/torch/parser.zig b/zml/aio/torch/parser.zig index a17ac85..754b055 100644 --- a/zml/aio/torch/parser.zig +++ b/zml/aio/torch/parser.zig @@ -240,13 +240,12 @@ pub const Decoder = struct { }, .append => try results.append(.{ .append = {} }), .build => try results.append(.{ .build = {} }), - .global => { - const buf0 = try reader.readUntilDelimiterAlloc(allocator, '\n', len); - errdefer allocator.free(buf0); - const buf1 = try reader.readUntilDelimiterAlloc(allocator, '\n', len); - errdefer allocator.free(buf1); - _ = (buf1.len + 1); - try results.append(.{ .global = .{ buf0, buf1 } }); + .global, .inst => { + const module = try reader.readUntilDelimiterAlloc(allocator, '\n', len); + errdefer allocator.free(module); + const class = try reader.readUntilDelimiterAlloc(allocator, '\n', len); + errdefer allocator.free(class); + try results.append(.{ .global = .{ .module = module, .class = class } }); }, .dict => try results.append(.{ .dict = {} }), .empty_dict => try results.append(.{ .empty_dict = {} }), @@ -257,14 +256,6 @@ pub const Decoder = struct { try results.append(.{ .get = buf }); }, .binget => try results.append(.{ .binget = try reader.readByte() }), - .inst => { - const buf0 = try reader.readUntilDelimiterAlloc(allocator, '\n', len); - errdefer allocator.free(buf0); - const buf1 = try reader.readUntilDelimiterAlloc(allocator, '\n', len); - errdefer allocator.free(buf1); - _ = (buf1.len + 1); - try results.append(.{ .inst = .{ buf0, buf1 } }); - }, .long_binget => try results.append(.{ .long_binget = try reader.readInt(u32, .little) }), .list => try results.append(.{ .list = {} }), .empty_list => try results.append(.{ .empty_list = {} }), diff --git a/zml/aio/torch/value.zig b/zml/aio/torch/value.zig index 9de3cc2..9c6acc4 100644 --- a/zml/aio/torch/value.zig +++ b/zml/aio/torch/value.zig @@ -265,7 +265,7 @@ pub const Value = union(ValueType) { }, .string => |v| try writer.print("\"{s}\"", .{v}), .raw => |v| switch (v) { - .global => |raw_global| try writer.print("\"{s}\", \"{s}\"", .{ raw_global[0], raw_global[1] }), + .global => |py_type| try writer.print("\"{s}\", \"{s}\"", .{ py_type.module, py_type.class }), else => try writer.print("{any}", .{v}), }, inline else => |v| {