aio: refactor PyTorch model parsing for better readability and optimize slice handling

This commit is contained in:
Tarry Singh 2023-01-25 12:16:27 +00:00
parent ebdb8db213
commit 5e1688cbfd
4 changed files with 189 additions and 223 deletions

View File

@ -101,12 +101,12 @@ pub const PickleData = struct {
fn isTensor(v: Value) bool { fn isTensor(v: Value) bool {
if (basicTypeCheck(v, "torch._utils", "_rebuild_tensor_v2")) { if (basicTypeCheck(v, "torch._utils", "_rebuild_tensor_v2")) {
const args = v.global.args[0].seq[1]; const args = v.global.args[0].seq.values;
if (args.len >= 5 and if (args.len >= 5 and
args[0] == .pers_id and args[0] == .pers_id and
args[1] == .int and args[1] == .int64 and
args[2] == .seq and args[2].seq[0] == .tuple and args[2] == .seq and args[2].seq.type == .tuple and
args[3] == .seq and args[3].seq[0] == .tuple) args[3] == .seq and args[3].seq.type == .tuple)
{ {
return true; return true;
} else @panic("Unexpected value in call to torch._utils._rebuild_tensor_v2"); } else @panic("Unexpected value in call to torch._utils._rebuild_tensor_v2");
@ -119,7 +119,7 @@ pub const PickleData = struct {
var result: [zml.Tensor.MAX_RANK]i64 = undefined; var result: [zml.Tensor.MAX_RANK]i64 = undefined;
for (values, result[0..values.len]) |val, *elem| { for (values, result[0..values.len]) |val, *elem| {
switch (val) { switch (val) {
.int => |int| elem.* = int, .int64 => |int| elem.* = int,
else => @panic("Bad value for shape item"), else => @panic("Bad value for shape item"),
} }
} }
@ -174,15 +174,15 @@ pub const PickleData = struct {
return switch (v) { return switch (v) {
.global => |object| { .global => |object| {
if (isTensor(v)) { if (isTensor(v)) {
const args = object.args[0].seq[1]; const args = object.args[0].seq.values;
const pidval: *PersId, var offs: u64, const raw_shape: Sequence, const raw_strides: Sequence = .{ args[0].pers_id, @intCast(args[1].int), args[2].seq, args[3].seq }; const pidval: *PersId, var offs: u64, const raw_shape: Sequence, const raw_strides: Sequence = .{ args[0].pers_id, @intCast(args[1].int64), args[2].seq, args[3].seq };
const rank = raw_shape[1].len; const rank = raw_shape.values.len;
const shape = dimsFromValues(raw_shape[1]); const shape = dimsFromValues(raw_shape.values);
var strides = dimsFromValues(raw_strides[1]); var strides = dimsFromValues(raw_strides.values);
const stype: []const u8, const sfile: []const u8, const sdev: []const u8 = switch (pidval.ref) { const stype: []const u8, const sfile: []const u8, const sdev: []const u8 = switch (pidval.ref) {
.seq => |seq| blk: { .seq => |seq| blk: {
const sargs = seq[1]; const sargs = seq.values;
if (seq[0] == .tuple and if (seq.type == .tuple and
sargs.len >= 5 and sargs.len >= 5 and
sargs[0] == .string and std.mem.eql(u8, sargs[0].string, "storage") and sargs[0] == .string and std.mem.eql(u8, sargs[0].string, "storage") and
sargs[1] == .raw and sargs[1].raw == .global and sargs[1] == .raw and sargs[1].raw == .global and
@ -231,7 +231,7 @@ pub const PickleData = struct {
); );
return true; return true;
} else if (basicTypeCheck(v, "torch", "Size")) { } else if (basicTypeCheck(v, "torch", "Size")) {
const size = object.args[0].seq[1][0].seq[1]; const size = object.args[0].seq.values[0].seq.values;
const key = try allocator.dupe(u8, prefix.items); const key = try allocator.dupe(u8, prefix.items);
const entry = try store._metadata.getOrPut(allocator, key); const entry = try store._metadata.getOrPut(allocator, key);
if (entry.found_existing) { if (entry.found_existing) {
@ -239,11 +239,11 @@ pub const PickleData = struct {
allocator.free(key); allocator.free(key);
} }
const d = try allocator.alloc(i64, size.len); const d = try allocator.alloc(i64, size.len);
for (d, 0..) |*di, i| di.* = size[i].int; for (d, 0..) |*di, i| di.* = size[i].int64;
entry.value_ptr.* = .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(d) } }; entry.value_ptr.* = .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(d) } };
return true; return true;
} else if (basicTypeCheck(v, "fractions", "Fraction")) { } else if (basicTypeCheck(v, "fractions", "Fraction")) {
const fraction_str = object.args[0].seq[1][0].string; const fraction_str = object.args[0].seq.values[0].string;
if (std.mem.indexOfScalar(u8, fraction_str, '/')) |split_idx| { if (std.mem.indexOfScalar(u8, fraction_str, '/')) |split_idx| {
{ {
var new_prefix = prefix; var new_prefix = prefix;
@ -271,8 +271,8 @@ pub const PickleData = struct {
try self.parseValue(allocator, store, prefix, object.member); try self.parseValue(allocator, store, prefix, object.member);
for (object.args) |item| { for (object.args) |item| {
// if possible, coerce to `kv_tuple` (only if key val doesn't match root of prefix) // if possible, coerce to `kv_tuple` (only if key val doesn't match root of prefix)
if (item == .seq and item.seq[0] == .tuple and item.seq[1].len == 2 and item.seq[1][0] == .string) { if (item == .seq and item.seq.type == .tuple and item.seq.values.len == 2 and item.seq.values[0] == .string) {
try self.parseValue(allocator, store, prefix, .{ .seq = .{ .kv_tuple, item.seq[1] } }); try self.parseValue(allocator, store, prefix, .{ .seq = .{ .type = .kv_tuple, .values = item.seq.values } });
} else try self.parseValue(allocator, store, prefix, item); } else try self.parseValue(allocator, store, prefix, item);
} }
} }
@ -312,102 +312,97 @@ pub const PickleData = struct {
try self.parseValue(allocator, store, prefix, build.args); try self.parseValue(allocator, store, prefix, build.args);
}, },
.pers_id => |pers_id| try self.parseValue(allocator, store, prefix, pers_id.ref), .pers_id => |pers_id| try self.parseValue(allocator, store, prefix, pers_id.ref),
.seq => |*seq| switch (seq[0]) { .seq => |seq| {
.list, .tuple, .set, .frozen_set => { switch (seq.type) {
const elemCheck = struct { .list, .tuple, .set, .frozen_set => {
fn call(comptime T: ValueType) fn (v: Value) bool { if (seq.values.len == 0) return;
return struct { var valid_slice = true;
fn call(val: Value) bool { switch (seq.values[0]) {
return val == T; inline .int64, .float64, .boolval => |val0, tag| {
const ItemType = switch (tag) {
.int64 => i64,
.float64 => f64,
.boolval => bool,
else => unreachable,
};
var values: std.ArrayListUnmanaged(ItemType) = .{};
try values.append(allocator, val0);
for (seq.values[1..], 1..) |val, i| {
if (std.meta.activeTag(val) != tag) valid_slice = false;
if (valid_slice) {
try values.append(allocator, @field(val, @tagName(tag)));
} else {
var new_prefix = prefix;
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
try self.parseValue(allocator, store, new_prefix, val);
}
} }
}.call;
}
}.call;
if (seq[1].len > 0 and switch (seq[1][0]) { if (valid_slice) {
inline .int, .bool, .float => |_, tag| utils.allTrue(seq[1][1..], elemCheck(tag)), try store._metadata.put(
else => false, allocator,
}) { try allocator.dupe(u8, prefix.items),
const out: []u8 = switch (seq[1][0]) { .{ .array = .{ .item_type = std.meta.stringToEnum(zml.aio.Value.Slice.ItemType, @tagName(tag)).?, .data = std.mem.sliceAsBytes(try values.toOwnedSlice(allocator)) } },
.int => blk: { );
const d = try allocator.alloc(i64, seq[1].len); } else {
for (seq[1], 0..) |item, i| { for (values.items, 0..) |val, i| {
d[i] = item.int; var new_prefix = prefix;
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), @unionInit(zml.aio.Value, @tagName(tag), val));
}
} }
break :blk std.mem.sliceAsBytes(d);
}, },
.float => blk: { else => {
const d = try allocator.alloc(f64, seq[1].len); for (seq.values, 0..) |item, i| {
for (seq[1], 0..) |item, i| { var new_prefix = prefix;
d[i] = item.float; if (v.isPrimitive()) {
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
}
try self.parseValue(allocator, store, new_prefix, item);
} }
break :blk std.mem.sliceAsBytes(d);
}, },
else => blk: {
const d = try allocator.alloc(bool, seq[1].len);
for (seq[1], 0..) |item, i| {
d[i] = item.bool;
}
break :blk std.mem.sliceAsBytes(d);
},
};
const key = try allocator.dupe(u8, prefix.items);
const d = try store._metadata.getOrPut(allocator, key);
if (d.found_existing) {
log.warn("Duplicate key: {s}", .{prefix.items});
allocator.free(key);
allocator.free(out);
} else d.value_ptr.* = @unionInit(zml.aio.Value, "array", .{ .item_type = switch (seq[1][0]) {
.int => .int64,
.float => .float64,
.string => .string,
else => .boolval,
}, .data = out });
} else {
for (seq[1], 0..) |item, i| {
var new_prefix = prefix;
if (v.isPrimitive()) {
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
}
try self.parseValue(allocator, store, new_prefix, item);
} }
} },
}, .dict => for (seq.values) |item| {
.dict => {
for (seq[1]) |item| {
try self.parseValue(allocator, store, prefix, item); try self.parseValue(allocator, store, prefix, item);
} },
}, .kv_tuple => {
.kv_tuple => { const key, const val = seq.values[0..2].*;
const key = seq[1][0]; switch (key) {
const val = seq[1][1]; .string => |s| {
switch (key) { // Handle Pytorch specific fields
.string => |s| { if (std.mem.eql(u8, s, "_modules") or std.mem.eql(u8, s, "_parameters") or std.mem.eql(u8, s, "_buffers")) {
if (std.mem.eql(u8, s, "_modules") or std.mem.eql(u8, s, "_parameters") or std.mem.eql(u8, s, "_buffers")) { try self.parseValue(allocator, store, prefix, val);
try self.parseValue(allocator, store, prefix, val); } else {
} else { var new_prefix = prefix;
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.appendSliceAssumeCapacity(s);
try self.parseValue(allocator, store, new_prefix, val);
}
},
.int64 => |int| {
var new_prefix = prefix; var new_prefix = prefix;
if (prefix.items.len > 0) { if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.'); new_prefix.appendAssumeCapacity('.');
} }
new_prefix.appendSliceAssumeCapacity(s); new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{});
try self.parseValue(allocator, store, new_prefix, val); try self.parseValue(allocator, store, new_prefix, val);
} },
}, inline else => |_, tag| std.debug.panic("Unexpected key type: {s}", .{@tagName(tag)}),
.int => |int| { }
var new_prefix = prefix; },
if (prefix.items.len > 0) { }
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{});
try self.parseValue(allocator, store, new_prefix, val);
},
inline else => |_, tag| std.debug.panic("Unexpected key type: {s}", .{@tagName(tag)}),
}
},
}, },
.bytes => |val| { .bytes => |val| {
const key = try allocator.dupe(u8, prefix.items); const key = try allocator.dupe(u8, prefix.items);
@ -417,18 +412,13 @@ pub const PickleData = struct {
allocator.free(key); allocator.free(key);
} else d.value_ptr.* = .{ .array = .{ .item_type = .uint8, .data = @constCast(val) } }; } else d.value_ptr.* = .{ .array = .{ .item_type = .uint8, .data = @constCast(val) } };
}, },
inline .float, .int, .bool, .bigint, .string => |val, tag| { inline .float64, .int64, .boolval, .bigint, .string => |val, tag| {
const key = try allocator.dupe(u8, prefix.items); const key = try allocator.dupe(u8, prefix.items);
const d = try store._metadata.getOrPut(allocator, key); const d = try store._metadata.getOrPut(allocator, key);
if (d.found_existing) { if (d.found_existing) {
log.warn("Duplicate key: {s}", .{prefix.items}); log.warn("Duplicate key: {s}", .{prefix.items});
allocator.free(key); allocator.free(key);
} else d.value_ptr.* = @unionInit(zml.aio.Value, switch (tag) { } else d.value_ptr.* = @unionInit(zml.aio.Value, @tagName(tag), val);
.int => "int64",
.float => "float64",
.bool => "boolval",
else => @tagName(tag),
}, val);
}, },
else => {}, else => {},
} }

View File

@ -81,7 +81,7 @@ pub const PickleMemo = struct {
} }
}, },
.seq => |*v| { .seq => |*v| {
for (v[1]) |*item| { for (v.values) |*item| {
if (item.containsRef()) { if (item.containsRef()) {
item.* = try self.resolve(allocator, item.*, recursive); item.* = try self.resolve(allocator, item.*, recursive);
} }
@ -148,7 +148,7 @@ pub const PickleMemo = struct {
try self.resolveAllRefs(allocator, depth + 1, v.member, fix_values), try self.resolveAllRefs(allocator, depth + 1, v.member, fix_values),
try self.resolveAllRefs(allocator, depth + 1, v.args, fix_values), try self.resolveAllRefs(allocator, depth + 1, v.args, fix_values),
) }, ) },
.seq => |v| .{ .seq = .{ v[0], try self.resolveAllRefsIter(allocator, depth + 1, v[1], fix_values) } }, .seq => |v| .{ .seq = .{ .type = v.type, .values = try self.resolveAllRefsIter(allocator, depth + 1, v.values, fix_values) } },
.pers_id => |v| .{ .pers_id = try PersId.init(allocator, try self.resolveAllRefs(allocator, depth + 1, v.ref, fix_values)) }, .pers_id => |v| .{ .pers_id = try PersId.init(allocator, try self.resolveAllRefs(allocator, depth + 1, v.ref, fix_values)) },
else => try val.clone(allocator), else => try val.clone(allocator),
}; };
@ -252,7 +252,7 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
const kv = try alloc.alloc(Value, 2); const kv = try alloc.alloc(Value, 2);
kv[0] = items[idx]; kv[0] = items[idx];
kv[1] = items[idx + 1]; kv[1] = items[idx + 1];
kv_items.appendAssumeCapacity(.{ .seq = .{ .kv_tuple, kv } }); kv_items.appendAssumeCapacity(.{ .seq = .{ .type = .kv_tuple, .values = kv } });
} }
return kv_items.toOwnedSlice(); return kv_items.toOwnedSlice();
} }
@ -283,19 +283,19 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
const member = try memo.resolve(allocator, try stack.pop(), true); const member = try memo.resolve(allocator, try stack.pop(), true);
break :blk .{ .build = try Build.init(allocator, member, args) }; break :blk .{ .build = try Build.init(allocator, member, args) };
}), }),
.empty_dict => try stack.values.append(.{ .seq = .{ .dict, &[_]Value{} } }), .empty_dict => try stack.values.append(.{ .seq = .{ .type = .dict, .values = &[_]Value{} } }),
.get => |v| try stack.values.append(.{ .ref = try std.fmt.parseInt(u32, v, 10) }), .get => |v| try stack.values.append(.{ .ref = try std.fmt.parseInt(u32, v, 10) }),
inline .binget, .long_binget => |v| try stack.values.append(.{ .ref = v }), inline .binget, .long_binget => |v| try stack.values.append(.{ .ref = v }),
.empty_list => try stack.values.append(.{ .seq = .{ .list, &[_]Value{} } }), .empty_list => try stack.values.append(.{ .seq = .{ .type = .list, .values = &[_]Value{} } }),
.binput, .long_binput => |v| { .binput, .long_binput => |v| {
try memo.insert(v, try stack.pop()); try memo.insert(v, try stack.pop());
try stack.values.append(.{ .ref = v }); try stack.values.append(.{ .ref = v });
}, },
.tuple => try stack.values.append(blk: { .tuple => try stack.values.append(blk: {
const popped = try stack.popMark(allocator); const popped = try stack.popMark(allocator);
break :blk .{ .seq = .{ .tuple, popped } }; break :blk .{ .seq = .{ .type = .tuple, .values = popped } };
}), }),
.empty_tuple => try stack.values.append(.{ .seq = .{ .tuple, &[_]Value{} } }), .empty_tuple => try stack.values.append(.{ .seq = .{ .type = .tuple, .values = &[_]Value{} } }),
.setitem => { .setitem => {
const v, const k = .{ try stack.pop(), try stack.pop() }; const v, const k = .{ try stack.pop(), try stack.pop() };
const top = try stack.lastMut(); const top = try stack.lastMut();
@ -303,11 +303,11 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
switch (rtop.*) { switch (rtop.*) {
.global => |obj| { .global => |obj| {
obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1); obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1);
obj.args[obj.args.len - 1] = .{ .seq = .{ .tuple, try allocator.dupe(Value, &.{ k, v }) } }; obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ k, v }) } };
}, },
.seq => |*tup| { .seq => |*tup| {
tup[1] = try assuredResize(Value, allocator, tup[1], tup[1].len + 1); tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1);
tup[1][tup[1].len - 1] = .{ .seq = .{ .tuple, try allocator.dupe(Value, &.{ k, v }) } }; tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ k, v }) } };
}, },
else => { else => {
return error.BadStackTopForSetItem; return error.BadStackTopForSetItem;
@ -323,11 +323,11 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
switch (rtop.*) { switch (rtop.*) {
.global => |obj| { .global => |obj| {
obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1); obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1);
obj.args[obj.args.len - 1] = .{ .seq = .{ .tuple, kv_items } }; obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } };
}, },
.seq => |*tup| { .seq => |*tup| {
tup[1] = try assuredResize(Value, allocator, tup[1], tup[1].len + 1); tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1);
tup[1][tup[1].len - 1] = .{ .seq = .{ .tuple, kv_items } }; tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } };
}, },
else => { else => {
defer allocator.free(kv_items); defer allocator.free(kv_items);
@ -339,17 +339,17 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
.tuple1 => try stack.values.append(blk: { .tuple1 => try stack.values.append(blk: {
const tup_values = try allocator.alloc(Value, 1); const tup_values = try allocator.alloc(Value, 1);
tup_values[0] = try stack.pop(); tup_values[0] = try stack.pop();
break :blk .{ .seq = .{ .tuple, tup_values } }; break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}), }),
.tuple2 => try stack.values.append(blk: { .tuple2 => try stack.values.append(blk: {
const tup_values = try allocator.alloc(Value, 2); const tup_values = try allocator.alloc(Value, 2);
inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop(); inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop();
break :blk .{ .seq = .{ .tuple, tup_values } }; break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}), }),
.tuple3 => try stack.values.append(blk: { .tuple3 => try stack.values.append(blk: {
const tup_values = try allocator.alloc(Value, 3); const tup_values = try allocator.alloc(Value, 3);
inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop(); inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop();
break :blk .{ .seq = .{ .tuple, tup_values } }; break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}), }),
.append => { .append => {
const v = try stack.pop(); const v = try stack.pop();
@ -361,8 +361,8 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
obj.args[obj.args.len - 1] = v; obj.args[obj.args.len - 1] = v;
}, },
.seq => |*tup| { .seq => |*tup| {
tup[1] = try assuredResize(Value, allocator, tup[1], tup[1].len + 1); tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1);
tup[1][tup[1].len - 1] = v; tup.values[tup.values.len - 1] = v;
}, },
else => { else => {
return error.BadStackTopForAppend; return error.BadStackTopForAppend;
@ -381,9 +381,9 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
@memcpy(obj.args[obj_len..], postmark); @memcpy(obj.args[obj_len..], postmark);
}, },
.seq => |*tup| { .seq => |*tup| {
const tup_len = tup[1].len; const tup_len = tup.values.len;
tup[1] = try assuredResize(Value, allocator, tup[1], tup_len + postmark.len); tup.values = try assuredResize(Value, allocator, tup.values, tup_len + postmark.len);
@memcpy(tup[1][tup_len..], postmark); @memcpy(tup.values[tup_len..], postmark);
}, },
else => { else => {
return error.BadStackTopForAppends; return error.BadStackTopForAppends;
@ -394,12 +394,12 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
const popped = try stack.popMark(allocator); const popped = try stack.popMark(allocator);
defer allocator.free(popped); defer allocator.free(popped);
const kv_items = try makeKVList(allocator, popped); const kv_items = try makeKVList(allocator, popped);
break :blk .{ .seq = .{ .dict, kv_items } }; break :blk .{ .seq = .{ .type = .dict, .values = kv_items } };
}), }),
.list => try stack.values.append(.{ .seq = .{ .list, try stack.popMark(allocator) } }), .list => try stack.values.append(.{ .seq = .{ .type = .list, .values = try stack.popMark(allocator) } }),
.inst => |v| try stack.values.append(blk: { .inst => |v| try stack.values.append(blk: {
const tup_items = try allocator.dupe(Value, &.{ .{ .string = v[0] }, .{ .string = v[1] } }); const tup_items = try allocator.dupe(Value, &.{ .{ .string = v[0] }, .{ .string = v[1] } });
break :blk .{ .object = try Object.init(allocator, .{ .seq = .{ .tuple, tup_items } }, try stack.popMark(allocator)) }; break :blk .{ .object = try Object.init(allocator, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try stack.popMark(allocator)) };
}), }),
.obj => try stack.values.append(blk: { .obj => try stack.values.append(blk: {
const markidx = try stack.findMark(); const markidx = try stack.findMark();
@ -418,7 +418,7 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
args[0] = try stack.pop(); args[0] = try stack.pop();
break :blk .{ .object = try Object.init(allocator, try stack.pop(), args) }; break :blk .{ .object = try Object.init(allocator, try stack.pop(), args) };
}), }),
.empty_set => try stack.values.append(.{ .seq = .{ .set, &[_]Value{} } }), .empty_set => try stack.values.append(.{ .seq = .{ .type = .set, .values = &[_]Value{} } }),
.additems => { .additems => {
const postmark = try stack.popMark(allocator); const postmark = try stack.popMark(allocator);
defer allocator.free(postmark); defer allocator.free(postmark);
@ -431,19 +431,19 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
@memcpy(obj.args[obj_len..], postmark); @memcpy(obj.args[obj_len..], postmark);
}, },
.seq => |*tup| { .seq => |*tup| {
const tup_len = tup[1].len; const tup_len = tup.values.len;
tup[1] = try assuredResize(Value, allocator, tup[1], tup_len + postmark.len); tup.values = try assuredResize(Value, allocator, tup.values, tup_len + postmark.len);
@memcpy(tup[1][tup_len..], postmark); @memcpy(tup.values[tup_len..], postmark);
}, },
else => { else => {
return error.BadStackTopForSetItem; return error.BadStackTopForSetItem;
}, },
} }
}, },
.frozenset => try stack.values.append(.{ .seq = .{ .frozen_set, try stack.popMark(allocator) } }), .frozenset => try stack.values.append(.{ .seq = .{ .type = .frozen_set, .values = try stack.popMark(allocator) } }),
.newobj_ex => try stack.values.append(blk: { .newobj_ex => try stack.values.append(blk: {
const kwargs, const args, const cls = .{ try stack.pop(), try stack.pop(), try stack.pop() }; const kwargs, const args, const cls = .{ try stack.pop(), try stack.pop(), try stack.pop() };
const new_seq: Sequence = .{ .tuple, try allocator.dupe(Value, &.{ args, kwargs }) }; const new_seq: Sequence = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ args, kwargs }) };
break :blk .{ .object = try Object.init(allocator, cls, try allocator.dupe(Value, &.{.{ .seq = new_seq }})) }; break :blk .{ .object = try Object.init(allocator, cls, try allocator.dupe(Value, &.{.{ .seq = new_seq }})) };
}), }),
.stack_global => try stack.values.append(blk: { .stack_global => try stack.values.append(blk: {
@ -451,7 +451,7 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
try memo.resolve(allocator, try stack.pop(), true), try memo.resolve(allocator, try stack.pop(), true),
try memo.resolve(allocator, try stack.pop(), true), try memo.resolve(allocator, try stack.pop(), true),
}; };
const new_seq: Sequence = .{ .tuple, try allocator.dupe(Value, &.{ gn, mn }) }; const new_seq: Sequence = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ gn, mn }) };
break :blk .{ .object = try Object.init(allocator, .{ .seq = new_seq }, &[_]Value{}) }; break :blk .{ .object = try Object.init(allocator, .{ .seq = new_seq }, &[_]Value{}) };
}), }),
.memoize => { .memoize => {

View File

@ -388,6 +388,7 @@ const TarStream = struct {
}; };
test "Read pickle (simple)" { test "Read pickle (simple)" {
const Value = @import("value.zig").Value;
var arena = std.heap.ArenaAllocator.init(testing.allocator); var arena = std.heap.ArenaAllocator.init(testing.allocator);
defer arena.deinit(); defer arena.deinit();
const allocator = arena.allocator(); const allocator = arena.allocator();
@ -402,64 +403,36 @@ test "Read pickle (simple)" {
try testing.expect(vals.stack.len == 2); try testing.expect(vals.stack.len == 2);
// skip first value (frame) // skip first value (frame)
try testing.expect(vals.stack[1] == .seq); try testing.expect(vals.stack[1] == .seq);
try testing.expect(vals.stack[1].seq[0] == .dict); try testing.expect(vals.stack[1].seq.type == .dict);
const entries = vals.stack[1].seq[1][0].seq[1]; const entries = vals.stack[1].seq.values[0].seq.values;
try testing.expect(entries.len == 5); try testing.expect(entries.len == 5);
for (entries, 0..) |kv, i| { const expected: []const Value = &.{
try testing.expect(kv == .seq); .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "hello" }, .{ .string = "world" } })) } },
try testing.expect(kv.seq[0] == .kv_tuple); .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "int" }, .{ .int64 = 1 } })) } },
switch (i) { .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "float" }, .{ .float64 = 3.141592 } })) } },
0 => { .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{
const key = kv.seq[1][0]; .{ .string = "list" },
try testing.expect(key == .string); .{ .seq = .{ .type = .list, .values = @constCast(@as([]const Value, &.{
try testing.expectEqualStrings("hello", key.string); .{ .int64 = 0 },
const value = kv.seq[1][1]; .{ .int64 = 1 },
try testing.expect(value == .string); .{ .int64 = 2 },
try testing.expectEqualStrings("world", value.string); .{ .int64 = 3 },
}, .{ .int64 = 4 },
1 => { })) } },
const key = kv.seq[1][0]; })) } },
try testing.expect(key == .string); .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{
try testing.expectEqualStrings("int", key.string); .{ .string = "tuple" },
const value = kv.seq[1][1]; .{ .seq = .{
try testing.expect(value == .int); .type = .tuple,
try testing.expect(value.int == 1); .values = @constCast(@as([]const Value, &.{
}, .{ .string = "a" },
2 => { .{ .int64 = 10 },
const key = kv.seq[1][0]; })),
try testing.expect(key == .string); } },
try testing.expectEqualStrings("float", key.string); })) } },
const value = kv.seq[1][1]; };
try testing.expect(value == .float);
try testing.expectEqual(@as(f64, 3.141592), value.float); try std.testing.expectEqualDeep(expected, entries);
},
3 => {
const key = kv.seq[1][0];
try testing.expect(key == .string);
try testing.expectEqualStrings("list", key.string);
const value = kv.seq[1][1];
try testing.expect(value == .seq);
try testing.expect(value.seq[0] == .list);
for (value.seq[1], 0..) |item, j| {
try testing.expect(item == .int);
try testing.expect(item.int == @as(i64, @intCast(j)));
}
},
4 => {
const key = kv.seq[1][0];
try testing.expect(key == .string);
try testing.expectEqualStrings("tuple", key.string);
const value = kv.seq[1][1];
try testing.expect(value == .seq);
try testing.expect(value.seq[0] == .tuple);
try testing.expect(value.seq[1][0] == .string);
try testing.expectEqualStrings("a", value.seq[1][0].string);
try testing.expect(value.seq[1][1] == .int);
try testing.expect(value.seq[1][1].int == 10);
},
else => unreachable,
}
}
} }
test "Read pickle (zipped)" { test "Read pickle (zipped)" {

View File

@ -65,7 +65,10 @@ pub const Build = struct {
} }
}; };
pub const Sequence = struct { SequenceType, []Value }; pub const Sequence = struct {
type: SequenceType,
values: []Value,
};
pub const PersId = struct { pub const PersId = struct {
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
@ -100,11 +103,11 @@ pub const ValueType = enum {
seq, seq,
string, string,
bytes, bytes,
int, int64,
bigint, bigint,
float, float64,
raw_num, raw_num,
bool, boolval,
none, none,
}; };
@ -160,7 +163,7 @@ pub const Value = union(ValueType) {
/// An integer, but not the crazy kind that comes as a string /// An integer, but not the crazy kind that comes as a string
/// that has to be parsed. You can look in `Value.raw_num` for /// that has to be parsed. You can look in `Value.raw_num` for
/// those. /// those.
int: i64, int64: i64,
/// An integer that can't fit in i64. /// An integer that can't fit in i64.
bigint: big_int.Managed, bigint: big_int.Managed,
@ -168,13 +171,13 @@ pub const Value = union(ValueType) {
/// A float, but not the crazy kind that comes as a string /// A float, but not the crazy kind that comes as a string
/// that has to be parsed. You can look in `Value.raw_num` for /// that has to be parsed. You can look in `Value.raw_num` for
/// those. /// those.
float: f64, float64: f64,
/// Some kind of weird number we can't handle. /// Some kind of weird number we can't handle.
raw_num: PickleOp, raw_num: PickleOp,
/// A boolean value. /// A boolean value.
bool: bool, boolval: bool,
/// Python `None`. /// Python `None`.
none: void, none: void,
@ -184,8 +187,8 @@ pub const Value = union(ValueType) {
.raw, .raw_num => |v| v.deinit(allocator), .raw, .raw_num => |v| v.deinit(allocator),
inline .app, .object, .global, .build, .pers_id => |v| v.deinit(), inline .app, .object, .global, .build, .pers_id => |v| v.deinit(),
.seq => |v| { .seq => |v| {
for (v[1]) |*val| val.deinit(allocator); for (v.values) |*val| val.deinit(allocator);
allocator.free(v[1]); allocator.free(v.values);
}, },
.string, .bytes => |v| allocator.free(v), .string, .bytes => |v| allocator.free(v),
.bigint => self.bigint.deinit(), .bigint => self.bigint.deinit(),
@ -205,7 +208,7 @@ pub const Value = union(ValueType) {
try writeIndents(indents + 1, writer); try writeIndents(indents + 1, writer);
try writer.print(".{s} = ", .{@tagName(std.meta.activeTag(value))}); try writer.print(".{s} = ", .{@tagName(std.meta.activeTag(value))});
switch (value) { switch (value) {
inline .ref, .int, .float => |v| try writer.print("{d} ", .{v}), inline .ref, .int64, .float64 => |v| try writer.print("{d} ", .{v}),
.app, .object, .global => |v| { .app, .object, .global => |v| {
try writer.writeAll(".{\n"); try writer.writeAll(".{\n");
try internalFormat(v.member, indents + 2, writer); try internalFormat(v.member, indents + 2, writer);
@ -242,13 +245,13 @@ pub const Value = union(ValueType) {
.seq => |v| { .seq => |v| {
try writer.writeAll(".{\n"); try writer.writeAll(".{\n");
try writeIndents(indents + 2, writer); try writeIndents(indents + 2, writer);
try writer.print(".{s},\n", .{@tagName(v[0])}); try writer.print(".{s},\n", .{@tagName(v.type)});
try writeIndents(indents + 2, writer); try writeIndents(indents + 2, writer);
if (v[1].len > 0) { if (v.values.len > 0) {
try writer.writeAll(".{\n"); try writer.writeAll(".{\n");
for (v[1], 0..) |arg, i| { for (v.values, 0..) |arg, i| {
try internalFormat(arg, indents + 3, writer); try internalFormat(arg, indents + 3, writer);
if (i < v[1].len - 1) try writer.writeAll(","); if (i < v.values.len - 1) try writer.writeAll(",");
try writer.writeByte('\n'); try writer.writeByte('\n');
} }
try writeIndents(indents + 2, writer); try writeIndents(indents + 2, writer);
@ -283,8 +286,8 @@ pub const Value = union(ValueType) {
inline .raw, .raw_num => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)), inline .raw, .raw_num => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)),
inline .app, .object, .global, .build, .pers_id => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)), inline .app, .object, .global, .build, .pers_id => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)),
.seq => |seq| blk: { .seq => |seq| blk: {
const new_val: Sequence = .{ seq[0], try allocator.alloc(Value, seq[1].len) }; const new_val: Sequence = .{ .type = seq.type, .values = try allocator.alloc(Value, seq.values.len) };
for (seq[1], 0..) |v, i| new_val[1][i] = try v.clone(allocator); for (seq.values, 0..) |v, i| new_val.values[i] = try v.clone(allocator);
break :blk .{ .seq = new_val }; break :blk .{ .seq = new_val };
}, },
inline .string, .bytes => |v, tag| @unionInit(Value, @tagName(tag), try allocator.dupe(u8, v)), inline .string, .bytes => |v, tag| @unionInit(Value, @tagName(tag), try allocator.dupe(u8, v)),
@ -295,8 +298,8 @@ pub const Value = union(ValueType) {
pub fn isPrimitive(self: Value) bool { pub fn isPrimitive(self: Value) bool {
return switch (self) { return switch (self) {
.int, .bigint, .float, .string, .bytes, .bool, .none => true, .int64, .bigint, .float64, .string, .bytes, .boolval, .none => true,
.seq => |seq| utils.allTrue(seq[1], Value.isPrimitive), .seq => |seq| utils.allTrue(seq.values, Value.isPrimitive),
else => false, else => false,
}; };
} }
@ -316,7 +319,7 @@ pub const Value = union(ValueType) {
}, },
.pers_id => |v| return v.ref.containsRef(), .pers_id => |v| return v.ref.containsRef(),
.seq => |v| { .seq => |v| {
for (v[1]) |val| if (val.containsRef()) return true; for (v.values) |val| if (val.containsRef()) return true;
return false; return false;
}, },
else => return false, else => return false,
@ -336,7 +339,7 @@ pub const Value = union(ValueType) {
pub fn coerceFromRaw(self: Value, allocator: std.mem.Allocator) !Value { pub fn coerceFromRaw(self: Value, allocator: std.mem.Allocator) !Value {
return switch (self) { return switch (self) {
.raw => |raw_val| switch (raw_val) { .raw => |raw_val| switch (raw_val) {
.binint, .binint1, .binint2 => |val| .{ .int = val }, .binint, .binint1, .binint2 => |val| .{ .int64 = val },
.long1, .long4 => |b| if (b.len != 0) { .long1, .long4 => |b| if (b.len != 0) {
var bint = try big_int.Managed.initCapacity(allocator, std.math.big.int.calcTwosCompLimbCount(b.len)); var bint = try big_int.Managed.initCapacity(allocator, std.math.big.int.calcTwosCompLimbCount(b.len));
var mutable = bint.toMutable(); var mutable = bint.toMutable();
@ -345,10 +348,10 @@ pub const Value = union(ValueType) {
const max_comp = bint.toConst().order(BI64MAX); const max_comp = bint.toConst().order(BI64MAX);
if ((min_comp == .gt or min_comp == .eq) and (max_comp == .lt or max_comp == .eq)) { if ((min_comp == .gt or min_comp == .eq) and (max_comp == .lt or max_comp == .eq)) {
defer bint.deinit(); defer bint.deinit();
return .{ .int = try bint.to(i64) }; return .{ .int64 = try bint.to(i64) };
} else return .{ .bigint = bint }; } else return .{ .bigint = bint };
} else .{ .raw_num = raw_val }, } else .{ .raw_num = raw_val },
.binfloat => |val| .{ .float = val }, .binfloat => |val| .{ .float64 = val },
.binunicode, .binunicode8, .short_binunicode => |s| .{ .string = s }, .binunicode, .binunicode8, .short_binunicode => |s| .{ .string = s },
.binbytes, .binbytes8, .short_binbytes, .bytearray8 => |b| .{ .bytes = b }, .binbytes, .binbytes8, .short_binbytes, .bytearray8 => |b| .{ .bytes = b },
// This isn't how Pickle actually works but we just try to UTF8 decode the // This isn't how Pickle actually works but we just try to UTF8 decode the
@ -356,17 +359,17 @@ pub const Value = union(ValueType) {
// actually cares they can just fix values themselves or recover the raw bytes // actually cares they can just fix values themselves or recover the raw bytes
// from the UTF8 string (it's guaranteed to be reversible, as far as I know). // from the UTF8 string (it's guaranteed to be reversible, as far as I know).
.binstring, .short_binstring => |b| if (std.unicode.utf8ValidateSlice(b)) .{ .string = b } else .{ .bytes = b }, .binstring, .short_binstring => |b| if (std.unicode.utf8ValidateSlice(b)) .{ .string = b } else .{ .bytes = b },
.newtrue => .{ .bool = true }, .newtrue => .{ .boolval = true },
.newfalse => .{ .bool = false }, .newfalse => .{ .boolval = false },
.none => .{ .none = {} }, .none => .{ .none = {} },
inline .int, inline .int,
.float, .float,
.long, .long,
=> |v, tag| { => |v, tag| {
if (tag == .int and std.mem.eql(u8, v, "01")) { if (tag == .int and std.mem.eql(u8, v, "01")) {
return .{ .bool = true }; return .{ .boolval = true };
} else if (tag == .int and std.mem.eql(u8, v, "00")) { } else if (tag == .int and std.mem.eql(u8, v, "00")) {
return .{ .bool = false }; return .{ .boolval = false };
} else { } else {
return .{ .raw_num = raw_val }; return .{ .raw_num = raw_val };
} }
@ -389,8 +392,8 @@ pub const Value = union(ValueType) {
v.ref = try v.ref.coerceFromRaw(allocator); v.ref = try v.ref.coerceFromRaw(allocator);
break :blk self; break :blk self;
}, },
.seq => |*v| blk: { .seq => |v| blk: {
for (v[1]) |*val| { for (v.values) |*val| {
val.* = try val.coerceFromRaw(allocator); val.* = try val.coerceFromRaw(allocator);
} }
break :blk self; break :blk self;