aio: refactor PyTorch model parsing for better readability and optimize slice handling

This commit is contained in:
Tarry Singh 2023-01-25 12:16:27 +00:00
parent ebdb8db213
commit 5e1688cbfd
4 changed files with 189 additions and 223 deletions

View File

@ -101,12 +101,12 @@ pub const PickleData = struct {
fn isTensor(v: Value) bool {
if (basicTypeCheck(v, "torch._utils", "_rebuild_tensor_v2")) {
const args = v.global.args[0].seq[1];
const args = v.global.args[0].seq.values;
if (args.len >= 5 and
args[0] == .pers_id and
args[1] == .int and
args[2] == .seq and args[2].seq[0] == .tuple and
args[3] == .seq and args[3].seq[0] == .tuple)
args[1] == .int64 and
args[2] == .seq and args[2].seq.type == .tuple and
args[3] == .seq and args[3].seq.type == .tuple)
{
return true;
} else @panic("Unexpected value in call to torch._utils._rebuild_tensor_v2");
@ -119,7 +119,7 @@ pub const PickleData = struct {
var result: [zml.Tensor.MAX_RANK]i64 = undefined;
for (values, result[0..values.len]) |val, *elem| {
switch (val) {
.int => |int| elem.* = int,
.int64 => |int| elem.* = int,
else => @panic("Bad value for shape item"),
}
}
@ -174,15 +174,15 @@ pub const PickleData = struct {
return switch (v) {
.global => |object| {
if (isTensor(v)) {
const args = object.args[0].seq[1];
const pidval: *PersId, var offs: u64, const raw_shape: Sequence, const raw_strides: Sequence = .{ args[0].pers_id, @intCast(args[1].int), args[2].seq, args[3].seq };
const rank = raw_shape[1].len;
const shape = dimsFromValues(raw_shape[1]);
var strides = dimsFromValues(raw_strides[1]);
const args = object.args[0].seq.values;
const pidval: *PersId, var offs: u64, const raw_shape: Sequence, const raw_strides: Sequence = .{ args[0].pers_id, @intCast(args[1].int64), args[2].seq, args[3].seq };
const rank = raw_shape.values.len;
const shape = dimsFromValues(raw_shape.values);
var strides = dimsFromValues(raw_strides.values);
const stype: []const u8, const sfile: []const u8, const sdev: []const u8 = switch (pidval.ref) {
.seq => |seq| blk: {
const sargs = seq[1];
if (seq[0] == .tuple and
const sargs = seq.values;
if (seq.type == .tuple and
sargs.len >= 5 and
sargs[0] == .string and std.mem.eql(u8, sargs[0].string, "storage") and
sargs[1] == .raw and sargs[1].raw == .global and
@ -231,7 +231,7 @@ pub const PickleData = struct {
);
return true;
} else if (basicTypeCheck(v, "torch", "Size")) {
const size = object.args[0].seq[1][0].seq[1];
const size = object.args[0].seq.values[0].seq.values;
const key = try allocator.dupe(u8, prefix.items);
const entry = try store._metadata.getOrPut(allocator, key);
if (entry.found_existing) {
@ -239,11 +239,11 @@ pub const PickleData = struct {
allocator.free(key);
}
const d = try allocator.alloc(i64, size.len);
for (d, 0..) |*di, i| di.* = size[i].int;
for (d, 0..) |*di, i| di.* = size[i].int64;
entry.value_ptr.* = .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(d) } };
return true;
} else if (basicTypeCheck(v, "fractions", "Fraction")) {
const fraction_str = object.args[0].seq[1][0].string;
const fraction_str = object.args[0].seq.values[0].string;
if (std.mem.indexOfScalar(u8, fraction_str, '/')) |split_idx| {
{
var new_prefix = prefix;
@ -271,8 +271,8 @@ pub const PickleData = struct {
try self.parseValue(allocator, store, prefix, object.member);
for (object.args) |item| {
// if possible, coerce to `kv_tuple` (only if key val doesn't match root of prefix)
if (item == .seq and item.seq[0] == .tuple and item.seq[1].len == 2 and item.seq[1][0] == .string) {
try self.parseValue(allocator, store, prefix, .{ .seq = .{ .kv_tuple, item.seq[1] } });
if (item == .seq and item.seq.type == .tuple and item.seq.values.len == 2 and item.seq.values[0] == .string) {
try self.parseValue(allocator, store, prefix, .{ .seq = .{ .type = .kv_tuple, .values = item.seq.values } });
} else try self.parseValue(allocator, store, prefix, item);
}
}
@ -312,102 +312,97 @@ pub const PickleData = struct {
try self.parseValue(allocator, store, prefix, build.args);
},
.pers_id => |pers_id| try self.parseValue(allocator, store, prefix, pers_id.ref),
.seq => |*seq| switch (seq[0]) {
.list, .tuple, .set, .frozen_set => {
const elemCheck = struct {
fn call(comptime T: ValueType) fn (v: Value) bool {
return struct {
fn call(val: Value) bool {
return val == T;
.seq => |seq| {
switch (seq.type) {
.list, .tuple, .set, .frozen_set => {
if (seq.values.len == 0) return;
var valid_slice = true;
switch (seq.values[0]) {
inline .int64, .float64, .boolval => |val0, tag| {
const ItemType = switch (tag) {
.int64 => i64,
.float64 => f64,
.boolval => bool,
else => unreachable,
};
var values: std.ArrayListUnmanaged(ItemType) = .{};
try values.append(allocator, val0);
for (seq.values[1..], 1..) |val, i| {
if (std.meta.activeTag(val) != tag) valid_slice = false;
if (valid_slice) {
try values.append(allocator, @field(val, @tagName(tag)));
} else {
var new_prefix = prefix;
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
try self.parseValue(allocator, store, new_prefix, val);
}
}
}.call;
}
}.call;
if (seq[1].len > 0 and switch (seq[1][0]) {
inline .int, .bool, .float => |_, tag| utils.allTrue(seq[1][1..], elemCheck(tag)),
else => false,
}) {
const out: []u8 = switch (seq[1][0]) {
.int => blk: {
const d = try allocator.alloc(i64, seq[1].len);
for (seq[1], 0..) |item, i| {
d[i] = item.int;
if (valid_slice) {
try store._metadata.put(
allocator,
try allocator.dupe(u8, prefix.items),
.{ .array = .{ .item_type = std.meta.stringToEnum(zml.aio.Value.Slice.ItemType, @tagName(tag)).?, .data = std.mem.sliceAsBytes(try values.toOwnedSlice(allocator)) } },
);
} else {
for (values.items, 0..) |val, i| {
var new_prefix = prefix;
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), @unionInit(zml.aio.Value, @tagName(tag), val));
}
}
break :blk std.mem.sliceAsBytes(d);
},
.float => blk: {
const d = try allocator.alloc(f64, seq[1].len);
for (seq[1], 0..) |item, i| {
d[i] = item.float;
else => {
for (seq.values, 0..) |item, i| {
var new_prefix = prefix;
if (v.isPrimitive()) {
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
}
try self.parseValue(allocator, store, new_prefix, item);
}
break :blk std.mem.sliceAsBytes(d);
},
else => blk: {
const d = try allocator.alloc(bool, seq[1].len);
for (seq[1], 0..) |item, i| {
d[i] = item.bool;
}
break :blk std.mem.sliceAsBytes(d);
},
};
const key = try allocator.dupe(u8, prefix.items);
const d = try store._metadata.getOrPut(allocator, key);
if (d.found_existing) {
log.warn("Duplicate key: {s}", .{prefix.items});
allocator.free(key);
allocator.free(out);
} else d.value_ptr.* = @unionInit(zml.aio.Value, "array", .{ .item_type = switch (seq[1][0]) {
.int => .int64,
.float => .float64,
.string => .string,
else => .boolval,
}, .data = out });
} else {
for (seq[1], 0..) |item, i| {
var new_prefix = prefix;
if (v.isPrimitive()) {
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
}
try self.parseValue(allocator, store, new_prefix, item);
}
}
},
.dict => {
for (seq[1]) |item| {
},
.dict => for (seq.values) |item| {
try self.parseValue(allocator, store, prefix, item);
}
},
.kv_tuple => {
const key = seq[1][0];
const val = seq[1][1];
switch (key) {
.string => |s| {
if (std.mem.eql(u8, s, "_modules") or std.mem.eql(u8, s, "_parameters") or std.mem.eql(u8, s, "_buffers")) {
try self.parseValue(allocator, store, prefix, val);
} else {
},
.kv_tuple => {
const key, const val = seq.values[0..2].*;
switch (key) {
.string => |s| {
// Handle PyTorch-specific fields
if (std.mem.eql(u8, s, "_modules") or std.mem.eql(u8, s, "_parameters") or std.mem.eql(u8, s, "_buffers")) {
try self.parseValue(allocator, store, prefix, val);
} else {
var new_prefix = prefix;
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.appendSliceAssumeCapacity(s);
try self.parseValue(allocator, store, new_prefix, val);
}
},
.int64 => |int| {
var new_prefix = prefix;
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.appendSliceAssumeCapacity(s);
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{});
try self.parseValue(allocator, store, new_prefix, val);
}
},
.int => |int| {
var new_prefix = prefix;
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{});
try self.parseValue(allocator, store, new_prefix, val);
},
inline else => |_, tag| std.debug.panic("Unexpected key type: {s}", .{@tagName(tag)}),
}
},
},
inline else => |_, tag| std.debug.panic("Unexpected key type: {s}", .{@tagName(tag)}),
}
},
}
},
.bytes => |val| {
const key = try allocator.dupe(u8, prefix.items);
@ -417,18 +412,13 @@ pub const PickleData = struct {
allocator.free(key);
} else d.value_ptr.* = .{ .array = .{ .item_type = .uint8, .data = @constCast(val) } };
},
inline .float, .int, .bool, .bigint, .string => |val, tag| {
inline .float64, .int64, .boolval, .bigint, .string => |val, tag| {
const key = try allocator.dupe(u8, prefix.items);
const d = try store._metadata.getOrPut(allocator, key);
if (d.found_existing) {
log.warn("Duplicate key: {s}", .{prefix.items});
allocator.free(key);
} else d.value_ptr.* = @unionInit(zml.aio.Value, switch (tag) {
.int => "int64",
.float => "float64",
.bool => "boolval",
else => @tagName(tag),
}, val);
} else d.value_ptr.* = @unionInit(zml.aio.Value, @tagName(tag), val);
},
else => {},
}

View File

@ -81,7 +81,7 @@ pub const PickleMemo = struct {
}
},
.seq => |*v| {
for (v[1]) |*item| {
for (v.values) |*item| {
if (item.containsRef()) {
item.* = try self.resolve(allocator, item.*, recursive);
}
@ -148,7 +148,7 @@ pub const PickleMemo = struct {
try self.resolveAllRefs(allocator, depth + 1, v.member, fix_values),
try self.resolveAllRefs(allocator, depth + 1, v.args, fix_values),
) },
.seq => |v| .{ .seq = .{ v[0], try self.resolveAllRefsIter(allocator, depth + 1, v[1], fix_values) } },
.seq => |v| .{ .seq = .{ .type = v.type, .values = try self.resolveAllRefsIter(allocator, depth + 1, v.values, fix_values) } },
.pers_id => |v| .{ .pers_id = try PersId.init(allocator, try self.resolveAllRefs(allocator, depth + 1, v.ref, fix_values)) },
else => try val.clone(allocator),
};
@ -252,7 +252,7 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
const kv = try alloc.alloc(Value, 2);
kv[0] = items[idx];
kv[1] = items[idx + 1];
kv_items.appendAssumeCapacity(.{ .seq = .{ .kv_tuple, kv } });
kv_items.appendAssumeCapacity(.{ .seq = .{ .type = .kv_tuple, .values = kv } });
}
return kv_items.toOwnedSlice();
}
@ -283,19 +283,19 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
const member = try memo.resolve(allocator, try stack.pop(), true);
break :blk .{ .build = try Build.init(allocator, member, args) };
}),
.empty_dict => try stack.values.append(.{ .seq = .{ .dict, &[_]Value{} } }),
.empty_dict => try stack.values.append(.{ .seq = .{ .type = .dict, .values = &[_]Value{} } }),
.get => |v| try stack.values.append(.{ .ref = try std.fmt.parseInt(u32, v, 10) }),
inline .binget, .long_binget => |v| try stack.values.append(.{ .ref = v }),
.empty_list => try stack.values.append(.{ .seq = .{ .list, &[_]Value{} } }),
.empty_list => try stack.values.append(.{ .seq = .{ .type = .list, .values = &[_]Value{} } }),
.binput, .long_binput => |v| {
try memo.insert(v, try stack.pop());
try stack.values.append(.{ .ref = v });
},
.tuple => try stack.values.append(blk: {
const popped = try stack.popMark(allocator);
break :blk .{ .seq = .{ .tuple, popped } };
break :blk .{ .seq = .{ .type = .tuple, .values = popped } };
}),
.empty_tuple => try stack.values.append(.{ .seq = .{ .tuple, &[_]Value{} } }),
.empty_tuple => try stack.values.append(.{ .seq = .{ .type = .tuple, .values = &[_]Value{} } }),
.setitem => {
const v, const k = .{ try stack.pop(), try stack.pop() };
const top = try stack.lastMut();
@ -303,11 +303,11 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
switch (rtop.*) {
.global => |obj| {
obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1);
obj.args[obj.args.len - 1] = .{ .seq = .{ .tuple, try allocator.dupe(Value, &.{ k, v }) } };
obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ k, v }) } };
},
.seq => |*tup| {
tup[1] = try assuredResize(Value, allocator, tup[1], tup[1].len + 1);
tup[1][tup[1].len - 1] = .{ .seq = .{ .tuple, try allocator.dupe(Value, &.{ k, v }) } };
tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1);
tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ k, v }) } };
},
else => {
return error.BadStackTopForSetItem;
@ -323,11 +323,11 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
switch (rtop.*) {
.global => |obj| {
obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1);
obj.args[obj.args.len - 1] = .{ .seq = .{ .tuple, kv_items } };
obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } };
},
.seq => |*tup| {
tup[1] = try assuredResize(Value, allocator, tup[1], tup[1].len + 1);
tup[1][tup[1].len - 1] = .{ .seq = .{ .tuple, kv_items } };
tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1);
tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } };
},
else => {
defer allocator.free(kv_items);
@ -339,17 +339,17 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
.tuple1 => try stack.values.append(blk: {
const tup_values = try allocator.alloc(Value, 1);
tup_values[0] = try stack.pop();
break :blk .{ .seq = .{ .tuple, tup_values } };
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}),
.tuple2 => try stack.values.append(blk: {
const tup_values = try allocator.alloc(Value, 2);
inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop();
break :blk .{ .seq = .{ .tuple, tup_values } };
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}),
.tuple3 => try stack.values.append(blk: {
const tup_values = try allocator.alloc(Value, 3);
inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop();
break :blk .{ .seq = .{ .tuple, tup_values } };
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}),
.append => {
const v = try stack.pop();
@ -361,8 +361,8 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
obj.args[obj.args.len - 1] = v;
},
.seq => |*tup| {
tup[1] = try assuredResize(Value, allocator, tup[1], tup[1].len + 1);
tup[1][tup[1].len - 1] = v;
tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1);
tup.values[tup.values.len - 1] = v;
},
else => {
return error.BadStackTopForAppend;
@ -381,9 +381,9 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
@memcpy(obj.args[obj_len..], postmark);
},
.seq => |*tup| {
const tup_len = tup[1].len;
tup[1] = try assuredResize(Value, allocator, tup[1], tup_len + postmark.len);
@memcpy(tup[1][tup_len..], postmark);
const tup_len = tup.values.len;
tup.values = try assuredResize(Value, allocator, tup.values, tup_len + postmark.len);
@memcpy(tup.values[tup_len..], postmark);
},
else => {
return error.BadStackTopForAppends;
@ -394,12 +394,12 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
const popped = try stack.popMark(allocator);
defer allocator.free(popped);
const kv_items = try makeKVList(allocator, popped);
break :blk .{ .seq = .{ .dict, kv_items } };
break :blk .{ .seq = .{ .type = .dict, .values = kv_items } };
}),
.list => try stack.values.append(.{ .seq = .{ .list, try stack.popMark(allocator) } }),
.list => try stack.values.append(.{ .seq = .{ .type = .list, .values = try stack.popMark(allocator) } }),
.inst => |v| try stack.values.append(blk: {
const tup_items = try allocator.dupe(Value, &.{ .{ .string = v[0] }, .{ .string = v[1] } });
break :blk .{ .object = try Object.init(allocator, .{ .seq = .{ .tuple, tup_items } }, try stack.popMark(allocator)) };
break :blk .{ .object = try Object.init(allocator, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try stack.popMark(allocator)) };
}),
.obj => try stack.values.append(blk: {
const markidx = try stack.findMark();
@ -418,7 +418,7 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
args[0] = try stack.pop();
break :blk .{ .object = try Object.init(allocator, try stack.pop(), args) };
}),
.empty_set => try stack.values.append(.{ .seq = .{ .set, &[_]Value{} } }),
.empty_set => try stack.values.append(.{ .seq = .{ .type = .set, .values = &[_]Value{} } }),
.additems => {
const postmark = try stack.popMark(allocator);
defer allocator.free(postmark);
@ -431,19 +431,19 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
@memcpy(obj.args[obj_len..], postmark);
},
.seq => |*tup| {
const tup_len = tup[1].len;
tup[1] = try assuredResize(Value, allocator, tup[1], tup_len + postmark.len);
@memcpy(tup[1][tup_len..], postmark);
const tup_len = tup.values.len;
tup.values = try assuredResize(Value, allocator, tup.values, tup_len + postmark.len);
@memcpy(tup.values[tup_len..], postmark);
},
else => {
return error.BadStackTopForSetItem;
},
}
},
.frozenset => try stack.values.append(.{ .seq = .{ .frozen_set, try stack.popMark(allocator) } }),
.frozenset => try stack.values.append(.{ .seq = .{ .type = .frozen_set, .values = try stack.popMark(allocator) } }),
.newobj_ex => try stack.values.append(blk: {
const kwargs, const args, const cls = .{ try stack.pop(), try stack.pop(), try stack.pop() };
const new_seq: Sequence = .{ .tuple, try allocator.dupe(Value, &.{ args, kwargs }) };
const new_seq: Sequence = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ args, kwargs }) };
break :blk .{ .object = try Object.init(allocator, cls, try allocator.dupe(Value, &.{.{ .seq = new_seq }})) };
}),
.stack_global => try stack.values.append(blk: {
@ -451,7 +451,7 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
try memo.resolve(allocator, try stack.pop(), true),
try memo.resolve(allocator, try stack.pop(), true),
};
const new_seq: Sequence = .{ .tuple, try allocator.dupe(Value, &.{ gn, mn }) };
const new_seq: Sequence = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ gn, mn }) };
break :blk .{ .object = try Object.init(allocator, .{ .seq = new_seq }, &[_]Value{}) };
}),
.memoize => {

View File

@ -388,6 +388,7 @@ const TarStream = struct {
};
test "Read pickle (simple)" {
const Value = @import("value.zig").Value;
var arena = std.heap.ArenaAllocator.init(testing.allocator);
defer arena.deinit();
const allocator = arena.allocator();
@ -402,64 +403,36 @@ test "Read pickle (simple)" {
try testing.expect(vals.stack.len == 2);
// skip first value (frame)
try testing.expect(vals.stack[1] == .seq);
try testing.expect(vals.stack[1].seq[0] == .dict);
const entries = vals.stack[1].seq[1][0].seq[1];
try testing.expect(vals.stack[1].seq.type == .dict);
const entries = vals.stack[1].seq.values[0].seq.values;
try testing.expect(entries.len == 5);
for (entries, 0..) |kv, i| {
try testing.expect(kv == .seq);
try testing.expect(kv.seq[0] == .kv_tuple);
switch (i) {
0 => {
const key = kv.seq[1][0];
try testing.expect(key == .string);
try testing.expectEqualStrings("hello", key.string);
const value = kv.seq[1][1];
try testing.expect(value == .string);
try testing.expectEqualStrings("world", value.string);
},
1 => {
const key = kv.seq[1][0];
try testing.expect(key == .string);
try testing.expectEqualStrings("int", key.string);
const value = kv.seq[1][1];
try testing.expect(value == .int);
try testing.expect(value.int == 1);
},
2 => {
const key = kv.seq[1][0];
try testing.expect(key == .string);
try testing.expectEqualStrings("float", key.string);
const value = kv.seq[1][1];
try testing.expect(value == .float);
try testing.expectEqual(@as(f64, 3.141592), value.float);
},
3 => {
const key = kv.seq[1][0];
try testing.expect(key == .string);
try testing.expectEqualStrings("list", key.string);
const value = kv.seq[1][1];
try testing.expect(value == .seq);
try testing.expect(value.seq[0] == .list);
for (value.seq[1], 0..) |item, j| {
try testing.expect(item == .int);
try testing.expect(item.int == @as(i64, @intCast(j)));
}
},
4 => {
const key = kv.seq[1][0];
try testing.expect(key == .string);
try testing.expectEqualStrings("tuple", key.string);
const value = kv.seq[1][1];
try testing.expect(value == .seq);
try testing.expect(value.seq[0] == .tuple);
try testing.expect(value.seq[1][0] == .string);
try testing.expectEqualStrings("a", value.seq[1][0].string);
try testing.expect(value.seq[1][1] == .int);
try testing.expect(value.seq[1][1].int == 10);
},
else => unreachable,
}
}
const expected: []const Value = &.{
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "hello" }, .{ .string = "world" } })) } },
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "int" }, .{ .int64 = 1 } })) } },
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "float" }, .{ .float64 = 3.141592 } })) } },
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{
.{ .string = "list" },
.{ .seq = .{ .type = .list, .values = @constCast(@as([]const Value, &.{
.{ .int64 = 0 },
.{ .int64 = 1 },
.{ .int64 = 2 },
.{ .int64 = 3 },
.{ .int64 = 4 },
})) } },
})) } },
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{
.{ .string = "tuple" },
.{ .seq = .{
.type = .tuple,
.values = @constCast(@as([]const Value, &.{
.{ .string = "a" },
.{ .int64 = 10 },
})),
} },
})) } },
};
try std.testing.expectEqualDeep(expected, entries);
}
test "Read pickle (zipped)" {

View File

@ -65,7 +65,10 @@ pub const Build = struct {
}
};
pub const Sequence = struct { SequenceType, []Value };
/// A container value: `type` records which kind of container this is and
/// `values` holds its elements.
pub const Sequence = struct {
/// Kind of container, as a `SequenceType` tag.
type: SequenceType,
/// The elements of the container. `Value.deinit` deinitializes each element
/// and then frees this slice with the allocator it is given.
values: []Value,
};
pub const PersId = struct {
allocator: std.mem.Allocator,
@ -100,11 +103,11 @@ pub const ValueType = enum {
seq,
string,
bytes,
int,
int64,
bigint,
float,
float64,
raw_num,
bool,
boolval,
none,
};
@ -160,7 +163,7 @@ pub const Value = union(ValueType) {
/// An integer, but not the crazy kind that comes as a string
/// that has to be parsed. You can look in `Value.raw_num` for
/// those.
int: i64,
int64: i64,
/// An integer that can't fit in i64.
bigint: big_int.Managed,
@ -168,13 +171,13 @@ pub const Value = union(ValueType) {
/// A float, but not the crazy kind that comes as a string
/// that has to be parsed. You can look in `Value.raw_num` for
/// those.
float: f64,
float64: f64,
/// Some kind of weird number we can't handle.
raw_num: PickleOp,
/// A boolean value.
bool: bool,
boolval: bool,
/// Python `None`.
none: void,
@ -184,8 +187,8 @@ pub const Value = union(ValueType) {
.raw, .raw_num => |v| v.deinit(allocator),
inline .app, .object, .global, .build, .pers_id => |v| v.deinit(),
.seq => |v| {
for (v[1]) |*val| val.deinit(allocator);
allocator.free(v[1]);
for (v.values) |*val| val.deinit(allocator);
allocator.free(v.values);
},
.string, .bytes => |v| allocator.free(v),
.bigint => self.bigint.deinit(),
@ -205,7 +208,7 @@ pub const Value = union(ValueType) {
try writeIndents(indents + 1, writer);
try writer.print(".{s} = ", .{@tagName(std.meta.activeTag(value))});
switch (value) {
inline .ref, .int, .float => |v| try writer.print("{d} ", .{v}),
inline .ref, .int64, .float64 => |v| try writer.print("{d} ", .{v}),
.app, .object, .global => |v| {
try writer.writeAll(".{\n");
try internalFormat(v.member, indents + 2, writer);
@ -242,13 +245,13 @@ pub const Value = union(ValueType) {
.seq => |v| {
try writer.writeAll(".{\n");
try writeIndents(indents + 2, writer);
try writer.print(".{s},\n", .{@tagName(v[0])});
try writer.print(".{s},\n", .{@tagName(v.type)});
try writeIndents(indents + 2, writer);
if (v[1].len > 0) {
if (v.values.len > 0) {
try writer.writeAll(".{\n");
for (v[1], 0..) |arg, i| {
for (v.values, 0..) |arg, i| {
try internalFormat(arg, indents + 3, writer);
if (i < v[1].len - 1) try writer.writeAll(",");
if (i < v.values.len - 1) try writer.writeAll(",");
try writer.writeByte('\n');
}
try writeIndents(indents + 2, writer);
@ -283,8 +286,8 @@ pub const Value = union(ValueType) {
inline .raw, .raw_num => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)),
inline .app, .object, .global, .build, .pers_id => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)),
.seq => |seq| blk: {
const new_val: Sequence = .{ seq[0], try allocator.alloc(Value, seq[1].len) };
for (seq[1], 0..) |v, i| new_val[1][i] = try v.clone(allocator);
const new_val: Sequence = .{ .type = seq.type, .values = try allocator.alloc(Value, seq.values.len) };
for (seq.values, 0..) |v, i| new_val.values[i] = try v.clone(allocator);
break :blk .{ .seq = new_val };
},
inline .string, .bytes => |v, tag| @unionInit(Value, @tagName(tag), try allocator.dupe(u8, v)),
@ -295,8 +298,8 @@ pub const Value = union(ValueType) {
pub fn isPrimitive(self: Value) bool {
return switch (self) {
.int, .bigint, .float, .string, .bytes, .bool, .none => true,
.seq => |seq| utils.allTrue(seq[1], Value.isPrimitive),
.int64, .bigint, .float64, .string, .bytes, .boolval, .none => true,
.seq => |seq| utils.allTrue(seq.values, Value.isPrimitive),
else => false,
};
}
@ -316,7 +319,7 @@ pub const Value = union(ValueType) {
},
.pers_id => |v| return v.ref.containsRef(),
.seq => |v| {
for (v[1]) |val| if (val.containsRef()) return true;
for (v.values) |val| if (val.containsRef()) return true;
return false;
},
else => return false,
@ -336,7 +339,7 @@ pub const Value = union(ValueType) {
pub fn coerceFromRaw(self: Value, allocator: std.mem.Allocator) !Value {
return switch (self) {
.raw => |raw_val| switch (raw_val) {
.binint, .binint1, .binint2 => |val| .{ .int = val },
.binint, .binint1, .binint2 => |val| .{ .int64 = val },
.long1, .long4 => |b| if (b.len != 0) {
var bint = try big_int.Managed.initCapacity(allocator, std.math.big.int.calcTwosCompLimbCount(b.len));
var mutable = bint.toMutable();
@ -345,10 +348,10 @@ pub const Value = union(ValueType) {
const max_comp = bint.toConst().order(BI64MAX);
if ((min_comp == .gt or min_comp == .eq) and (max_comp == .lt or max_comp == .eq)) {
defer bint.deinit();
return .{ .int = try bint.to(i64) };
return .{ .int64 = try bint.to(i64) };
} else return .{ .bigint = bint };
} else .{ .raw_num = raw_val },
.binfloat => |val| .{ .float = val },
.binfloat => |val| .{ .float64 = val },
.binunicode, .binunicode8, .short_binunicode => |s| .{ .string = s },
.binbytes, .binbytes8, .short_binbytes, .bytearray8 => |b| .{ .bytes = b },
// This isn't how Pickle actually works but we just try to UTF8 decode the
@ -356,17 +359,17 @@ pub const Value = union(ValueType) {
// actually cares they can just fix values themselves or recover the raw bytes
// from the UTF8 string (it's guaranteed to be reversible, as far as I know).
.binstring, .short_binstring => |b| if (std.unicode.utf8ValidateSlice(b)) .{ .string = b } else .{ .bytes = b },
.newtrue => .{ .bool = true },
.newfalse => .{ .bool = false },
.newtrue => .{ .boolval = true },
.newfalse => .{ .boolval = false },
.none => .{ .none = {} },
inline .int,
.float,
.long,
=> |v, tag| {
if (tag == .int and std.mem.eql(u8, v, "01")) {
return .{ .bool = true };
return .{ .boolval = true };
} else if (tag == .int and std.mem.eql(u8, v, "00")) {
return .{ .bool = false };
return .{ .boolval = false };
} else {
return .{ .raw_num = raw_val };
}
@ -389,8 +392,8 @@ pub const Value = union(ValueType) {
v.ref = try v.ref.coerceFromRaw(allocator);
break :blk self;
},
.seq => |*v| blk: {
for (v[1]) |*val| {
.seq => |v| blk: {
for (v.values) |*val| {
val.* = try val.coerceFromRaw(allocator);
}
break :blk self;