Fix empty JSON array handling in safetensor metadata loader and refactor torch loader (make ops slices const and improve readability).

This commit is contained in:
Tarry Singh 2023-03-28 16:17:00 +00:00
parent aae37738a5
commit ef922e3aea
7 changed files with 160 additions and 165 deletions

View File

@ -27,77 +27,86 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
return res;
}
pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: StringBuilder, val: std.json.Value) !void {
pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, val: std.json.Value) !void {
const metadata = &store._metadata;
switch (val) {
.null => try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .null = {} }),
.bool => |v| try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .boolval = v }),
.integer => |v| try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .int64 = v }),
.float => |v| try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .float64 = v }),
.number_string, .string => |v| try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .string = try allocator.dupe(u8, v) }),
.array => |v| switch (validSlice(v)) {
true => {
if (v.items.len == 0) return;
switch (v.items[0]) {
.bool => {
const key = prefix.items;
return switch (val) {
.null => try metadata.put(allocator, try allocator.dupe(u8, key), .{ .null = {} }),
.bool => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .boolval = v }),
.integer => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .int64 = v }),
.float => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .float64 = v }),
.number_string, .string => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .string = try allocator.dupe(u8, v) }),
.array => |v| {
if (v.items.len == 0) return;
return if (validSlice(v)) |item_type| {
const data, const dtype: zml.aio.Value.Slice.ItemType = switch (item_type) {
.bool => blk: {
const values = try allocator.alloc(bool, v.items.len);
errdefer allocator.free(values);
for (v.items, 0..) |item, i| values[i] = item.bool;
try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .boolval, .data = std.mem.sliceAsBytes(values) } });
break :blk .{ std.mem.sliceAsBytes(values), .boolval };
},
.integer => {
.integer => blk: {
const values = try allocator.alloc(i64, v.items.len);
errdefer allocator.free(values);
for (v.items, 0..) |item, i| values[i] = item.integer;
try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(values) } });
break :blk .{ std.mem.sliceAsBytes(values), .int64 };
},
.float => {
.float => blk: {
const values = try allocator.alloc(f64, v.items.len);
errdefer allocator.free(values);
for (v.items, 0..) |item, i| values[i] = item.float;
try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .float64, .data = std.mem.sliceAsBytes(values) } });
break :blk .{ std.mem.sliceAsBytes(values), .float64 };
},
inline .string, .number_string => |_, tag| {
inline .string, .number_string => |tag| blk: {
const values = try allocator.alloc([]const u8, v.items.len);
errdefer allocator.free(values);
for (v.items, 0..) |item, i| {
values[i] = @field(item, @tagName(tag));
}
try metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .string, .data = std.mem.sliceAsBytes(values) } });
break :blk .{ std.mem.sliceAsBytes(values), .string };
},
else => unreachable,
.null, .array, .object => unreachable,
};
try metadata.put(
allocator,
try allocator.dupe(u8, key),
.{ .array = .{ .item_type = dtype, .data = data } },
);
} else {
for (v.items, 0..) |item, i| {
var new_prefix = prefix;
if (prefix.items.len > 0)
new_prefix.appendAssumeCapacity('.');
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
try parseMetadata(allocator, store, new_prefix, item);
}
},
false => for (v.items, 0..) |item, i| {
var new_key = key;
if (key.items.len > 0)
new_key.appendAssumeCapacity('.');
new_key.items.len += std.fmt.formatIntBuf(new_key.unusedCapacitySlice(), i, 10, .lower, .{});
try parseMetadata(allocator, store, new_key, item);
},
};
},
.object => |v| {
var obj_iter = v.iterator();
while (obj_iter.next()) |entry| {
var new_key = key;
if (key.items.len > 0)
new_key.appendAssumeCapacity('.');
new_key.appendSliceAssumeCapacity(entry.key_ptr.*);
try parseMetadata(allocator, store, new_key, entry.value_ptr.*);
var new_prefix = prefix;
if (prefix.items.len > 0)
new_prefix.appendAssumeCapacity('.');
new_prefix.appendSliceAssumeCapacity(entry.key_ptr.*);
try parseMetadata(allocator, store, new_prefix, entry.value_ptr.*);
}
},
}
};
}
fn validSlice(v: std.json.Array) bool {
const item_type = std.meta.activeTag(v.items[0]);
/// We can only create a Zig slice out of json array, if all values
/// in the array have the same type.
fn validSlice(v: std.json.Array) ?std.meta.Tag(std.json.Value) {
if (v.items.len == 0) return null;
const item_type: std.meta.Tag(std.json.Value) = v.items[0];
switch (item_type) {
.null, .array, .object => return false,
.null, .array, .object => return null,
else => {},
}
for (v.items[1..]) |item|
if (item_type != std.meta.activeTag(item)) return false;
for (v.items[1..]) |item| {
if (item != item_type)
return null;
}
return true;
return item_type;
}

View File

@ -11,30 +11,6 @@ const StringBuilder = std.ArrayListUnmanaged(u8);
const Allocator = std.mem.Allocator;
const log = std.log.scoped(.zml_io);
/// Maps a safetensor dtype string (e.g. "F32", "BF16") to the matching `zml.DataType`.
/// Unknown strings are logged and rejected with `error.UnsupportedDataType`.
fn stringToDtype(v: []const u8) !zml.DataType {
    const Case = enum { F64, F32, F16, BF16, F8_E4M3, I64, I32, I16, I8, U64, U32, U16, U8, BOOL };
    // Reject unknown tags up front so the happy path is a single switch.
    const case = std.meta.stringToEnum(Case, v) orelse {
        std.log.err("Unsupported type-string: {s}\n", .{v});
        return error.UnsupportedDataType;
    };
    return switch (case) {
        .F64 => .f64,
        .F32 => .f32,
        .F16 => .f16,
        .BF16 => .bf16,
        .F8_E4M3 => .f8e4m3fn,
        .I64 => .i64,
        .I32 => .i32,
        .I16 => .i16,
        .I8 => .i8,
        .U64 => .u64,
        .U32 => .u32,
        .U16 => .u16,
        .U8 => .u8,
        .BOOL => .bool,
    };
}
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
var res: zml.aio.BufferStore = .{
.arena = std.heap.ArenaAllocator.init(allocator),
@ -93,9 +69,13 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.Array
const json_header_length: usize = @intCast(try r.readInt(u64, std.builtin.Endian.little));
const json_data = try allocator.alloc(u8, json_header_length);
_ = try r.readAtLeast(json_data, json_header_length);
const metadata = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data, .{ .allocate = .alloc_if_needed });
const n = try r.readAll(json_data);
if (n != json_header_length) {
log.err("Failed to read the full {} bytes of json header from file {s}", .{ n, path });
return error.CorruptedFile;
}
const metadata = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data[0..n], .{});
var buffer_file = try MemoryMappedFile.init(file);
errdefer buffer_file.deinit();
buffer_file.data_offset = 8 + json_header_length;
@ -138,3 +118,27 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.Array
try store.buffers.put(allocator, try allocator.dupe(u8, key), buf);
}
}
/// Translates a safetensor dtype tag (e.g. "F32", "BF16") into a `zml.DataType`.
/// Unknown tags are logged and rejected with `error.UnsupportedDataType`.
fn stringToDtype(safetensor_type: []const u8) !zml.DataType {
    // Exhaustive list of the dtype tags the safetensors format can carry.
    const Tag = enum { F64, F32, F16, BF16, F8_E4M3, I64, I32, I16, I8, U64, U32, U16, U8, BOOL };
    const tag = std.meta.stringToEnum(Tag, safetensor_type) orelse {
        log.err("Unsupported safetensor data type: {s}", .{safetensor_type});
        return error.UnsupportedDataType;
    };
    return switch (tag) {
        .F64 => .f64,
        .F32 => .f32,
        .F16 => .f16,
        .BF16 => .bf16,
        .F8_E4M3 => .f8e4m3fn,
        .I64 => .i64,
        .I32 => .i32,
        .I16 => .i16,
        .I8 => .i8,
        .U64 => .u64,
        .U32 => .u32,
        .U16 => .u16,
        .U8 => .u8,
        .BOOL => .bool,
    };
}

View File

@ -20,41 +20,6 @@ const StringBuilder = std.ArrayListUnmanaged(u8);
const Allocator = std.mem.Allocator;
const log = std.log.scoped(.zml_io);
/// Tensor element-type names accepted by torch, including both the long
/// spelling (e.g. "float32") and the short alias (e.g. "float") for each type.
const TorchType = enum {
    float64,
    double,
    float32,
    float,
    float16,
    half,
    bfloat16,
    int64,
    long,
    int32,
    int,
    int16,
    short,
    int8,
    char,
    uint8,
    byte,
};

/// Resolves a torch dtype name (e.g. "float32", "half", "byte") to a `zml.DataType`.
/// Returns `error.UnknownTensorType` when the name is not a known torch type.
fn dtypeFromStr(str: []const u8) !zml.DataType {
    const case = std.meta.stringToEnum(TorchType, str) orelse return error.UnknownTensorType;
    // Long and short spellings of the same type collapse to one zml dtype.
    return switch (case) {
        .float64, .double => .f64,
        .float32, .float => .f32,
        .float16, .half => .f16,
        .bfloat16 => .bf16,
        .int64, .long => .i64,
        .int32, .int => .i32,
        .int16, .short => .i16,
        .int8, .char => .i8,
        .uint8, .byte => .u8,
    };
}
/// Opens and loads a BufferStore from the torch file at the given path.
pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
const file = asynk.File.open(path, .{}) catch |err| {
@ -80,6 +45,37 @@ pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
return res;
}
/// Convert from a torch.<type>Storage name to a `zml.DataType`.
/// Logs and returns `error.UnsupportedDataType` for names that don't end in
/// "Storage" or whose prefix isn't a supported storage type.
/// TODO: make this future proof, storage types are going to be replaced with torch.UntypedStorage
/// See https://pytorch.org/docs/stable/storage.html
fn storageToDtype(storage_type: []const u8) !zml.DataType {
    // Guard the suffix strip below: a name shorter than "Storage" (or missing
    // the suffix entirely) would otherwise underflow the slice upper bound.
    // One caller already checks this, but the function must be safe on its own.
    if (!std.mem.endsWith(u8, storage_type, "Storage")) {
        log.err("Unsupported torch storage type: {s}", .{storage_type});
        return error.UnsupportedDataType;
    }
    const torch_type = storage_type[0 .. storage_type.len - "Storage".len];
    const map = std.StaticStringMap(zml.DataType).initComptime(.{
        .{ "Double", .f64 },
        .{ "Float", .f32 },
        .{ "Half", .f16 },
        .{ "Long", .i64 },
        .{ "Int", .i32 },
        .{ "Short", .i16 },
        .{ "Char", .i8 },
        .{ "Byte", .u8 },
        .{ "Bool", .bool },
        .{ "BFloat16", .bf16 },
        .{ "ComplexDouble", .c128 },
        .{ "ComplexFloat", .c64 },
        // QUInt8Storage
        // QInt8Storage
        // QInt32Storage
        // QUInt4x2Storage
        // QUInt2x4Storage
    });
    return map.get(torch_type) orelse {
        log.err("Unsupported torch storage type: {s}", .{storage_type});
        return error.UnsupportedDataType;
    };
}
pub const PickleData = struct {
stack: PickleStack,
memo: PickleMemo,
@ -89,7 +85,7 @@ pub const PickleData = struct {
return switch (v) {
.global => |object| switch (object.member) {
.raw => |raw| {
if (std.mem.eql(u8, ns, raw.global[0]) and std.mem.eql(u8, name, raw.global[1]) and object.args[0] == .seq) {
if (std.mem.eql(u8, ns, raw.global.module) and std.mem.eql(u8, name, raw.global.class) and object.args[0] == .seq) {
return true;
} else return false;
},
@ -179,7 +175,7 @@ pub const PickleData = struct {
const rank = raw_shape.values.len;
const shape = dimsFromValues(raw_shape.values);
var strides = dimsFromValues(raw_strides.values);
const stype: []const u8, const sfile: []const u8, const sdev: []const u8 = switch (pidval.ref) {
const storage_type, const sfile = switch (pidval.ref) {
.seq => |seq| blk: {
const sargs = seq.values;
if (seq.type == .tuple and
@ -191,17 +187,16 @@ pub const PickleData = struct {
{
const op = sargs[1].raw.global;
const sfile = sargs[2].string;
const sdev = sargs[3].string;
const styp = op[1];
if (std.mem.eql(u8, "torch", op[0]) and std.mem.endsWith(u8, styp, "Storage")) {
break :blk .{ std.ascii.lowerString(styp[0 .. styp.len - 7], styp[0 .. styp.len - 7]), sfile, sdev };
// const sdev = sargs[3].string;
if (std.mem.eql(u8, "torch", op.module) and std.mem.endsWith(u8, op.class, "Storage")) {
break :blk .{ op.class, sfile };
} else @panic("Unexpected storage type part of persistant ID");
} else @panic("Unexpected value for persistant ID");
},
else => @panic("Unexpected value for persistant ID"),
};
_ = sdev;
const data_type = try dtypeFromStr(stype);
const data_type = try storageToDtype(storage_type);
for (strides[0..rank]) |*s| s.* *= data_type.sizeOf();
var sfile_buf = std.ArrayList(u8).init(allocator);
@ -296,10 +291,7 @@ pub const PickleData = struct {
log.err("Duplicate key: {s}", .{new_prefix.items});
allocator.free(key);
} else {
const val = try allocator.alloc(u8, global[0].len + 1 + global[1].len);
@memcpy(val[0..global[0].len], global[0]);
val[global[0].len] = '.';
@memcpy(val[global[0].len + 1 ..], global[1]);
const val = try std.mem.join(allocator, ".", &.{ global.module, global.class });
d.value_ptr.* = .{ .string = val };
}
},

View File

@ -398,7 +398,7 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
}),
.list => try stack.values.append(.{ .seq = .{ .type = .list, .values = try stack.popMark(allocator) } }),
.inst => |v| try stack.values.append(blk: {
const tup_items = try allocator.dupe(Value, &.{ .{ .string = v[0] }, .{ .string = v[1] } });
const tup_items = try allocator.dupe(Value, &.{ .{ .string = v.module }, .{ .string = v.class } });
break :blk .{ .object = try Object.init(allocator, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try stack.popMark(allocator)) };
}),
.obj => try stack.values.append(blk: {

View File

@ -7,35 +7,35 @@ pub const PickleOp = union(RawPickleOp) {
pop,
pop_mark,
dup,
float: []u8,
int: []u8,
float: []const u8,
int: []const u8,
binint: i32,
binint1: u8,
long: []u8,
long: []const u8,
binint2: u16,
none,
persid: []u8,
persid: []const u8,
binpersid,
reduce,
string: []u8,
binstring: []u8,
short_binstring: []u8,
unicode: []u8,
binunicode: []u8,
string: []const u8,
binstring: []const u8,
short_binstring: []const u8,
unicode: []const u8,
binunicode: []const u8,
append,
build,
global: [2][]u8,
global: PyType,
dict,
empty_dict,
appends,
get: []u8,
get: []const u8,
binget: u8,
inst: [2][]u8,
inst: PyType,
long_binget: u32,
list,
empty_list,
obj,
put: []u8,
put: []const u8,
binput: u8,
long_binput: u32,
setitem,
@ -53,13 +53,13 @@ pub const PickleOp = union(RawPickleOp) {
tuple3,
newtrue,
newfalse,
long1: []u8,
long4: []u8,
binbytes: []u8,
short_binbytes: []u8,
short_binunicode: []u8,
binunicode8: []u8,
binbytes8: []u8,
long1: []const u8,
long4: []const u8,
binbytes: []const u8,
short_binbytes: []const u8,
short_binunicode: []const u8,
binunicode8: []const u8,
binbytes8: []const u8,
empty_set,
additems,
frozenset,
@ -67,10 +67,12 @@ pub const PickleOp = union(RawPickleOp) {
stack_global,
memoize,
frame: u64,
bytearray8: []u8,
bytearray8: []const u8,
next_buffer,
readonly_buffer,
pub const PyType = struct { module: []const u8, class: []const u8 };
pub fn deinit(self: PickleOp, allocator: std.mem.Allocator) void {
switch (self) {
.float,
@ -93,10 +95,9 @@ pub const PickleOp = union(RawPickleOp) {
.binbytes8,
.bytearray8,
=> |v| allocator.free(v),
.global, .inst => |fields| {
inline for (fields) |field| {
allocator.free(field);
}
.global, .inst => |py_type| {
allocator.free(py_type.module);
allocator.free(py_type.class);
},
else => {},
}
@ -131,12 +132,10 @@ pub const PickleOp = union(RawPickleOp) {
return res;
},
inline .global, .inst => |v, tag| {
var out: std.meta.Tuple(&.{ []u8, []u8 }) = undefined;
inline for (0..2) |i| {
out[i] = try allocator.alloc(u8, v[i].len);
@memcpy(out[i], v[i]);
}
@field(res, @tagName(tag)) = out;
@field(res, @tagName(tag)) = PyType{
.module = try allocator.dupe(u8, v.module),
.class = try allocator.dupe(u8, v.class),
};
return res;
},
else => self,

View File

@ -240,13 +240,12 @@ pub const Decoder = struct {
},
.append => try results.append(.{ .append = {} }),
.build => try results.append(.{ .build = {} }),
.global => {
const buf0 = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(buf0);
const buf1 = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(buf1);
_ = (buf1.len + 1);
try results.append(.{ .global = .{ buf0, buf1 } });
.global, .inst => {
const module = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(module);
const class = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(class);
try results.append(.{ .global = .{ .module = module, .class = class } });
},
.dict => try results.append(.{ .dict = {} }),
.empty_dict => try results.append(.{ .empty_dict = {} }),
@ -257,14 +256,6 @@ pub const Decoder = struct {
try results.append(.{ .get = buf });
},
.binget => try results.append(.{ .binget = try reader.readByte() }),
.inst => {
const buf0 = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(buf0);
const buf1 = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(buf1);
_ = (buf1.len + 1);
try results.append(.{ .inst = .{ buf0, buf1 } });
},
.long_binget => try results.append(.{ .long_binget = try reader.readInt(u32, .little) }),
.list => try results.append(.{ .list = {} }),
.empty_list => try results.append(.{ .empty_list = {} }),

View File

@ -265,7 +265,7 @@ pub const Value = union(ValueType) {
},
.string => |v| try writer.print("\"{s}\"", .{v}),
.raw => |v| switch (v) {
.global => |raw_global| try writer.print("\"{s}\", \"{s}\"", .{ raw_global[0], raw_global[1] }),
.global => |py_type| try writer.print("\"{s}\", \"{s}\"", .{ py_type.module, py_type.class }),
else => try writer.print("{any}", .{v}),
},
inline else => |v| {