Fix empty JSON array handling in safetensor metadata loader and refactor torch loader (make ops slices const and improve readability).

This commit is contained in:
Tarry Singh 2023-03-28 16:17:00 +00:00
parent aae37738a5
commit ef922e3aea
7 changed files with 160 additions and 165 deletions

View File

@@ -27,77 +27,86 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
return res; return res;
} }
pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: StringBuilder, val: std.json.Value) !void { pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, val: std.json.Value) !void {
const metadata = &store._metadata; const metadata = &store._metadata;
switch (val) { const key = prefix.items;
.null => try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .null = {} }), return switch (val) {
.bool => |v| try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .boolval = v }), .null => try metadata.put(allocator, try allocator.dupe(u8, key), .{ .null = {} }),
.integer => |v| try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .int64 = v }), .bool => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .boolval = v }),
.float => |v| try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .float64 = v }), .integer => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .int64 = v }),
.number_string, .string => |v| try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .string = try allocator.dupe(u8, v) }), .float => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .float64 = v }),
.array => |v| switch (validSlice(v)) { .number_string, .string => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .string = try allocator.dupe(u8, v) }),
true => { .array => |v| {
if (v.items.len == 0) return; if (v.items.len == 0) return;
switch (v.items[0]) { return if (validSlice(v)) |item_type| {
.bool => { const data, const dtype: zml.aio.Value.Slice.ItemType = switch (item_type) {
.bool => blk: {
const values = try allocator.alloc(bool, v.items.len); const values = try allocator.alloc(bool, v.items.len);
errdefer allocator.free(values);
for (v.items, 0..) |item, i| values[i] = item.bool; for (v.items, 0..) |item, i| values[i] = item.bool;
try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .boolval, .data = std.mem.sliceAsBytes(values) } }); break :blk .{ std.mem.sliceAsBytes(values), .boolval };
}, },
.integer => { .integer => blk: {
const values = try allocator.alloc(i64, v.items.len); const values = try allocator.alloc(i64, v.items.len);
errdefer allocator.free(values);
for (v.items, 0..) |item, i| values[i] = item.integer; for (v.items, 0..) |item, i| values[i] = item.integer;
try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(values) } }); break :blk .{ std.mem.sliceAsBytes(values), .int64 };
}, },
.float => { .float => blk: {
const values = try allocator.alloc(f64, v.items.len); const values = try allocator.alloc(f64, v.items.len);
errdefer allocator.free(values);
for (v.items, 0..) |item, i| values[i] = item.float; for (v.items, 0..) |item, i| values[i] = item.float;
try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .float64, .data = std.mem.sliceAsBytes(values) } }); break :blk .{ std.mem.sliceAsBytes(values), .float64 };
}, },
inline .string, .number_string => |_, tag| { inline .string, .number_string => |tag| blk: {
const values = try allocator.alloc([]const u8, v.items.len); const values = try allocator.alloc([]const u8, v.items.len);
errdefer allocator.free(values);
for (v.items, 0..) |item, i| { for (v.items, 0..) |item, i| {
values[i] = @field(item, @tagName(tag)); values[i] = @field(item, @tagName(tag));
} }
try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .string, .data = std.mem.sliceAsBytes(values) } }); break :blk .{ std.mem.sliceAsBytes(values), .string };
}, },
else => unreachable, .null, .array, .object => unreachable,
};
try metadata.put(
allocator,
try allocator.dupe(u8, key),
.{ .array = .{ .item_type = dtype, .data = data } },
);
} else {
for (v.items, 0..) |item, i| {
var new_prefix = prefix;
if (prefix.items.len > 0)
new_prefix.appendAssumeCapacity('.');
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
try parseMetadata(allocator, store, new_prefix, item);
} }
}, };
false => for (v.items, 0..) |item, i| {
var new_key = key;
if (key.items.len > 0)
new_key.appendAssumeCapacity('.');
new_key.items.len += std.fmt.formatIntBuf(new_key.unusedCapacitySlice(), i, 10, .lower, .{});
try parseMetadata(allocator, store, new_key, item);
},
}, },
.object => |v| { .object => |v| {
var obj_iter = v.iterator(); var obj_iter = v.iterator();
while (obj_iter.next()) |entry| { while (obj_iter.next()) |entry| {
var new_key = key; var new_prefix = prefix;
if (key.items.len > 0) if (prefix.items.len > 0)
new_key.appendAssumeCapacity('.'); new_prefix.appendAssumeCapacity('.');
new_key.appendSliceAssumeCapacity(entry.key_ptr.*); new_prefix.appendSliceAssumeCapacity(entry.key_ptr.*);
try parseMetadata(allocator, store, new_key, entry.value_ptr.*); try parseMetadata(allocator, store, new_prefix, entry.value_ptr.*);
} }
}, },
} };
} }
fn validSlice(v: std.json.Array) bool { /// We can only create a Zig slice out of json array, if all values
const item_type = std.meta.activeTag(v.items[0]); /// in the array have the same type.
fn validSlice(v: std.json.Array) ?std.meta.Tag(std.json.Value) {
if (v.items.len == 0) return null;
const item_type: std.meta.Tag(std.json.Value) = v.items[0];
switch (item_type) { switch (item_type) {
.null, .array, .object => return false, .null, .array, .object => return null,
else => {}, else => {},
} }
for (v.items[1..]) |item| for (v.items[1..]) |item| {
if (item_type != std.meta.activeTag(item)) return false; if (item != item_type)
return null;
return true; }
return item_type;
} }

View File

@@ -11,30 +11,6 @@ const StringBuilder = std.ArrayListUnmanaged(u8);
const Allocator = std.mem.Allocator; const Allocator = std.mem.Allocator;
const log = std.log.scoped(.zml_io); const log = std.log.scoped(.zml_io);
fn stringToDtype(v: []const u8) !zml.DataType {
const Case = enum { F64, F32, F16, BF16, F8_E4M3, I64, I32, I16, I8, U64, U32, U16, U8, BOOL };
if (std.meta.stringToEnum(Case, v)) |case| {
return switch (case) {
.F64 => .f64,
.F32 => .f32,
.F16 => .f16,
.BF16 => .bf16,
.F8_E4M3 => .f8e4m3fn,
.I64 => .i64,
.I32 => .i32,
.I16 => .i16,
.I8 => .i8,
.U64 => .u64,
.U32 => .u32,
.U16 => .u16,
.U8 => .u8,
.BOOL => .bool,
};
}
std.log.err("Unsupported type-string: {s}\n", .{v});
return error.UnsupportedDataType;
}
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore { pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
var res: zml.aio.BufferStore = .{ var res: zml.aio.BufferStore = .{
.arena = std.heap.ArenaAllocator.init(allocator), .arena = std.heap.ArenaAllocator.init(allocator),
@@ -93,9 +69,13 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.Array
const json_header_length: usize = @intCast(try r.readInt(u64, std.builtin.Endian.little)); const json_header_length: usize = @intCast(try r.readInt(u64, std.builtin.Endian.little));
const json_data = try allocator.alloc(u8, json_header_length); const json_data = try allocator.alloc(u8, json_header_length);
_ = try r.readAtLeast(json_data, json_header_length); const n = try r.readAll(json_data);
const metadata = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data, .{ .allocate = .alloc_if_needed }); if (n != json_header_length) {
log.err("Failed to read the full {} bytes of json header from file {s}", .{ n, path });
return error.CorruptedFile;
}
const metadata = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data[0..n], .{});
var buffer_file = try MemoryMappedFile.init(file); var buffer_file = try MemoryMappedFile.init(file);
errdefer buffer_file.deinit(); errdefer buffer_file.deinit();
buffer_file.data_offset = 8 + json_header_length; buffer_file.data_offset = 8 + json_header_length;
@@ -138,3 +118,27 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.Array
try store.buffers.put(allocator, try allocator.dupe(u8, key), buf); try store.buffers.put(allocator, try allocator.dupe(u8, key), buf);
} }
} }
fn stringToDtype(safetensor_type: []const u8) !zml.DataType {
const map = std.StaticStringMap(zml.DataType).initComptime(.{
.{ "F64", .f64 },
.{ "F32", .f32 },
.{ "F16", .f16 },
.{ "BF16", .bf16 },
.{ "F8_E4M3", .f8e4m3fn },
.{ "I64", .i64 },
.{ "I32", .i32 },
.{ "I16", .i16 },
.{ "I8", .i8 },
.{ "U64", .u64 },
.{ "U32", .u32 },
.{ "U16", .u16 },
.{ "U8", .u8 },
.{ "BOOL", .bool },
});
return map.get(safetensor_type) orelse {
log.err("Unsupported safetensor data type: {s}", .{safetensor_type});
return error.UnsupportedDataType;
};
}

View File

@@ -20,41 +20,6 @@ const StringBuilder = std.ArrayListUnmanaged(u8);
const Allocator = std.mem.Allocator; const Allocator = std.mem.Allocator;
const log = std.log.scoped(.zml_io); const log = std.log.scoped(.zml_io);
const TorchType = enum {
float64,
double,
float32,
float,
float16,
half,
bfloat16,
int64,
long,
int32,
int,
int16,
short,
int8,
char,
uint8,
byte,
};
fn dtypeFromStr(str: []const u8) !zml.DataType {
const case = std.meta.stringToEnum(TorchType, str) orelse return error.UnknownTensorType;
return switch (case) {
.float64, .double => .f64,
.float32, .float => .f32,
.float16, .half => .f16,
.bfloat16 => .bf16,
.int64, .long => .i64,
.int32, .int => .i32,
.int16, .short => .i16,
.int8, .char => .i8,
.uint8, .byte => .u8,
};
}
/// Opens and loads a BufferStore from the torch file at the given path. /// Opens and loads a BufferStore from the torch file at the given path.
pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore { pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
const file = asynk.File.open(path, .{}) catch |err| { const file = asynk.File.open(path, .{}) catch |err| {
@@ -80,6 +45,37 @@ pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
return res; return res;
} }
/// Convert from a torch.<type>Storage to a `zml.DataType`.
/// TODO: make this future proof, storage type are going to get replaced with torch.UntypedStorage
/// See https://pytorch.org/docs/stable/storage.html
fn storageToDtype(storage_type: []const u8) !zml.DataType {
const torch_type = storage_type[0 .. storage_type.len - "Storage".len];
const map = std.StaticStringMap(zml.DataType).initComptime(.{
.{ "Double", .f64 },
.{ "Float", .f32 },
.{ "Half", .f16 },
.{ "Long", .i64 },
.{ "Int", .i32 },
.{ "Short", .i16 },
.{ "Char", .i8 },
.{ "Byte", .u8 },
.{ "Bool", .bool },
.{ "BFloat16", .bf16 },
.{ "ComplexDouble", .c128 },
.{ "ComplexFloat", .c64 },
// QUInt8Storage
// QInt8Storage
// QInt32Storage
// QUInt4x2Storage
// QUInt2x4Storage
});
return map.get(torch_type) orelse {
log.err("Unsupported torch storage type: {s}", .{storage_type});
return error.UnsupportedDataType;
};
}
pub const PickleData = struct { pub const PickleData = struct {
stack: PickleStack, stack: PickleStack,
memo: PickleMemo, memo: PickleMemo,
@@ -89,7 +85,7 @@ pub const PickleData = struct {
return switch (v) { return switch (v) {
.global => |object| switch (object.member) { .global => |object| switch (object.member) {
.raw => |raw| { .raw => |raw| {
if (std.mem.eql(u8, ns, raw.global[0]) and std.mem.eql(u8, name, raw.global[1]) and object.args[0] == .seq) { if (std.mem.eql(u8, ns, raw.global.module) and std.mem.eql(u8, name, raw.global.class) and object.args[0] == .seq) {
return true; return true;
} else return false; } else return false;
}, },
@@ -179,7 +175,7 @@ pub const PickleData = struct {
const rank = raw_shape.values.len; const rank = raw_shape.values.len;
const shape = dimsFromValues(raw_shape.values); const shape = dimsFromValues(raw_shape.values);
var strides = dimsFromValues(raw_strides.values); var strides = dimsFromValues(raw_strides.values);
const stype: []const u8, const sfile: []const u8, const sdev: []const u8 = switch (pidval.ref) { const storage_type, const sfile = switch (pidval.ref) {
.seq => |seq| blk: { .seq => |seq| blk: {
const sargs = seq.values; const sargs = seq.values;
if (seq.type == .tuple and if (seq.type == .tuple and
@@ -191,17 +187,16 @@ pub const PickleData = struct {
{ {
const op = sargs[1].raw.global; const op = sargs[1].raw.global;
const sfile = sargs[2].string; const sfile = sargs[2].string;
const sdev = sargs[3].string; // const sdev = sargs[3].string;
const styp = op[1]; if (std.mem.eql(u8, "torch", op.module) and std.mem.endsWith(u8, op.class, "Storage")) {
if (std.mem.eql(u8, "torch", op[0]) and std.mem.endsWith(u8, styp, "Storage")) { break :blk .{ op.class, sfile };
break :blk .{ std.ascii.lowerString(styp[0 .. styp.len - 7], styp[0 .. styp.len - 7]), sfile, sdev };
} else @panic("Unexpected storage type part of persistant ID"); } else @panic("Unexpected storage type part of persistant ID");
} else @panic("Unexpected value for persistant ID"); } else @panic("Unexpected value for persistant ID");
}, },
else => @panic("Unexpected value for persistant ID"), else => @panic("Unexpected value for persistant ID"),
}; };
_ = sdev;
const data_type = try dtypeFromStr(stype); const data_type = try storageToDtype(storage_type);
for (strides[0..rank]) |*s| s.* *= data_type.sizeOf(); for (strides[0..rank]) |*s| s.* *= data_type.sizeOf();
var sfile_buf = std.ArrayList(u8).init(allocator); var sfile_buf = std.ArrayList(u8).init(allocator);
@@ -296,10 +291,7 @@ pub const PickleData = struct {
log.err("Duplicate key: {s}", .{new_prefix.items}); log.err("Duplicate key: {s}", .{new_prefix.items});
allocator.free(key); allocator.free(key);
} else { } else {
const val = try allocator.alloc(u8, global[0].len + 1 + global[1].len); const val = try std.mem.join(allocator, ".", &.{ global.module, global.class });
@memcpy(val[0..global[0].len], global[0]);
val[global[0].len] = '.';
@memcpy(val[global[0].len + 1 ..], global[1]);
d.value_ptr.* = .{ .string = val }; d.value_ptr.* = .{ .string = val };
} }
}, },

View File

@@ -398,7 +398,7 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
}), }),
.list => try stack.values.append(.{ .seq = .{ .type = .list, .values = try stack.popMark(allocator) } }), .list => try stack.values.append(.{ .seq = .{ .type = .list, .values = try stack.popMark(allocator) } }),
.inst => |v| try stack.values.append(blk: { .inst => |v| try stack.values.append(blk: {
const tup_items = try allocator.dupe(Value, &.{ .{ .string = v[0] }, .{ .string = v[1] } }); const tup_items = try allocator.dupe(Value, &.{ .{ .string = v.module }, .{ .string = v.class } });
break :blk .{ .object = try Object.init(allocator, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try stack.popMark(allocator)) }; break :blk .{ .object = try Object.init(allocator, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try stack.popMark(allocator)) };
}), }),
.obj => try stack.values.append(blk: { .obj => try stack.values.append(blk: {

View File

@@ -7,35 +7,35 @@ pub const PickleOp = union(RawPickleOp) {
pop, pop,
pop_mark, pop_mark,
dup, dup,
float: []u8, float: []const u8,
int: []u8, int: []const u8,
binint: i32, binint: i32,
binint1: u8, binint1: u8,
long: []u8, long: []const u8,
binint2: u16, binint2: u16,
none, none,
persid: []u8, persid: []const u8,
binpersid, binpersid,
reduce, reduce,
string: []u8, string: []const u8,
binstring: []u8, binstring: []const u8,
short_binstring: []u8, short_binstring: []const u8,
unicode: []u8, unicode: []const u8,
binunicode: []u8, binunicode: []const u8,
append, append,
build, build,
global: [2][]u8, global: PyType,
dict, dict,
empty_dict, empty_dict,
appends, appends,
get: []u8, get: []const u8,
binget: u8, binget: u8,
inst: [2][]u8, inst: PyType,
long_binget: u32, long_binget: u32,
list, list,
empty_list, empty_list,
obj, obj,
put: []u8, put: []const u8,
binput: u8, binput: u8,
long_binput: u32, long_binput: u32,
setitem, setitem,
@@ -53,13 +53,13 @@ pub const PickleOp = union(RawPickleOp) {
tuple3, tuple3,
newtrue, newtrue,
newfalse, newfalse,
long1: []u8, long1: []const u8,
long4: []u8, long4: []const u8,
binbytes: []u8, binbytes: []const u8,
short_binbytes: []u8, short_binbytes: []const u8,
short_binunicode: []u8, short_binunicode: []const u8,
binunicode8: []u8, binunicode8: []const u8,
binbytes8: []u8, binbytes8: []const u8,
empty_set, empty_set,
additems, additems,
frozenset, frozenset,
@@ -67,10 +67,12 @@ pub const PickleOp = union(RawPickleOp) {
stack_global, stack_global,
memoize, memoize,
frame: u64, frame: u64,
bytearray8: []u8, bytearray8: []const u8,
next_buffer, next_buffer,
readonly_buffer, readonly_buffer,
pub const PyType = struct { module: []const u8, class: []const u8 };
pub fn deinit(self: PickleOp, allocator: std.mem.Allocator) void { pub fn deinit(self: PickleOp, allocator: std.mem.Allocator) void {
switch (self) { switch (self) {
.float, .float,
@@ -93,10 +95,9 @@ pub const PickleOp = union(RawPickleOp) {
.binbytes8, .binbytes8,
.bytearray8, .bytearray8,
=> |v| allocator.free(v), => |v| allocator.free(v),
.global, .inst => |fields| { .global, .inst => |py_type| {
inline for (fields) |field| { allocator.free(py_type.module);
allocator.free(field); allocator.free(py_type.class);
}
}, },
else => {}, else => {},
} }
@@ -131,12 +132,10 @@ pub const PickleOp = union(RawPickleOp) {
return res; return res;
}, },
inline .global, .inst => |v, tag| { inline .global, .inst => |v, tag| {
var out: std.meta.Tuple(&.{ []u8, []u8 }) = undefined; @field(res, @tagName(tag)) = PyType{
inline for (0..2) |i| { .module = try allocator.dupe(u8, v.module),
out[i] = try allocator.alloc(u8, v[i].len); .class = try allocator.dupe(u8, v.class),
@memcpy(out[i], v[i]); };
}
@field(res, @tagName(tag)) = out;
return res; return res;
}, },
else => self, else => self,

View File

@@ -240,13 +240,12 @@ pub const Decoder = struct {
}, },
.append => try results.append(.{ .append = {} }), .append => try results.append(.{ .append = {} }),
.build => try results.append(.{ .build = {} }), .build => try results.append(.{ .build = {} }),
.global => { .global, .inst => {
const buf0 = try reader.readUntilDelimiterAlloc(allocator, '\n', len); const module = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(buf0); errdefer allocator.free(module);
const buf1 = try reader.readUntilDelimiterAlloc(allocator, '\n', len); const class = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(buf1); errdefer allocator.free(class);
_ = (buf1.len + 1); try results.append(.{ .global = .{ .module = module, .class = class } });
try results.append(.{ .global = .{ buf0, buf1 } });
}, },
.dict => try results.append(.{ .dict = {} }), .dict => try results.append(.{ .dict = {} }),
.empty_dict => try results.append(.{ .empty_dict = {} }), .empty_dict => try results.append(.{ .empty_dict = {} }),
@@ -257,14 +256,6 @@ pub const Decoder = struct {
try results.append(.{ .get = buf }); try results.append(.{ .get = buf });
}, },
.binget => try results.append(.{ .binget = try reader.readByte() }), .binget => try results.append(.{ .binget = try reader.readByte() }),
.inst => {
const buf0 = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(buf0);
const buf1 = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(buf1);
_ = (buf1.len + 1);
try results.append(.{ .inst = .{ buf0, buf1 } });
},
.long_binget => try results.append(.{ .long_binget = try reader.readInt(u32, .little) }), .long_binget => try results.append(.{ .long_binget = try reader.readInt(u32, .little) }),
.list => try results.append(.{ .list = {} }), .list => try results.append(.{ .list = {} }),
.empty_list => try results.append(.{ .empty_list = {} }), .empty_list => try results.append(.{ .empty_list = {} }),

View File

@@ -265,7 +265,7 @@ pub const Value = union(ValueType) {
}, },
.string => |v| try writer.print("\"{s}\"", .{v}), .string => |v| try writer.print("\"{s}\"", .{v}),
.raw => |v| switch (v) { .raw => |v| switch (v) {
.global => |raw_global| try writer.print("\"{s}\", \"{s}\"", .{ raw_global[0], raw_global[1] }), .global => |py_type| try writer.print("\"{s}\", \"{s}\"", .{ py_type.module, py_type.class }),
else => try writer.print("{any}", .{v}), else => try writer.print("{any}", .{v}),
}, },
inline else => |v| { inline else => |v| {