aio: refactor PyTorch model parsing for better readability and optimize slice handling
This commit is contained in:
parent
ebdb8db213
commit
5e1688cbfd
@ -101,12 +101,12 @@ pub const PickleData = struct {
|
||||
|
||||
fn isTensor(v: Value) bool {
|
||||
if (basicTypeCheck(v, "torch._utils", "_rebuild_tensor_v2")) {
|
||||
const args = v.global.args[0].seq[1];
|
||||
const args = v.global.args[0].seq.values;
|
||||
if (args.len >= 5 and
|
||||
args[0] == .pers_id and
|
||||
args[1] == .int and
|
||||
args[2] == .seq and args[2].seq[0] == .tuple and
|
||||
args[3] == .seq and args[3].seq[0] == .tuple)
|
||||
args[1] == .int64 and
|
||||
args[2] == .seq and args[2].seq.type == .tuple and
|
||||
args[3] == .seq and args[3].seq.type == .tuple)
|
||||
{
|
||||
return true;
|
||||
} else @panic("Unexpected value in call to torch._utils._rebuild_tensor_v2");
|
||||
@ -119,7 +119,7 @@ pub const PickleData = struct {
|
||||
var result: [zml.Tensor.MAX_RANK]i64 = undefined;
|
||||
for (values, result[0..values.len]) |val, *elem| {
|
||||
switch (val) {
|
||||
.int => |int| elem.* = int,
|
||||
.int64 => |int| elem.* = int,
|
||||
else => @panic("Bad value for shape item"),
|
||||
}
|
||||
}
|
||||
@ -174,15 +174,15 @@ pub const PickleData = struct {
|
||||
return switch (v) {
|
||||
.global => |object| {
|
||||
if (isTensor(v)) {
|
||||
const args = object.args[0].seq[1];
|
||||
const pidval: *PersId, var offs: u64, const raw_shape: Sequence, const raw_strides: Sequence = .{ args[0].pers_id, @intCast(args[1].int), args[2].seq, args[3].seq };
|
||||
const rank = raw_shape[1].len;
|
||||
const shape = dimsFromValues(raw_shape[1]);
|
||||
var strides = dimsFromValues(raw_strides[1]);
|
||||
const args = object.args[0].seq.values;
|
||||
const pidval: *PersId, var offs: u64, const raw_shape: Sequence, const raw_strides: Sequence = .{ args[0].pers_id, @intCast(args[1].int64), args[2].seq, args[3].seq };
|
||||
const rank = raw_shape.values.len;
|
||||
const shape = dimsFromValues(raw_shape.values);
|
||||
var strides = dimsFromValues(raw_strides.values);
|
||||
const stype: []const u8, const sfile: []const u8, const sdev: []const u8 = switch (pidval.ref) {
|
||||
.seq => |seq| blk: {
|
||||
const sargs = seq[1];
|
||||
if (seq[0] == .tuple and
|
||||
const sargs = seq.values;
|
||||
if (seq.type == .tuple and
|
||||
sargs.len >= 5 and
|
||||
sargs[0] == .string and std.mem.eql(u8, sargs[0].string, "storage") and
|
||||
sargs[1] == .raw and sargs[1].raw == .global and
|
||||
@ -231,7 +231,7 @@ pub const PickleData = struct {
|
||||
);
|
||||
return true;
|
||||
} else if (basicTypeCheck(v, "torch", "Size")) {
|
||||
const size = object.args[0].seq[1][0].seq[1];
|
||||
const size = object.args[0].seq.values[0].seq.values;
|
||||
const key = try allocator.dupe(u8, prefix.items);
|
||||
const entry = try store._metadata.getOrPut(allocator, key);
|
||||
if (entry.found_existing) {
|
||||
@ -239,11 +239,11 @@ pub const PickleData = struct {
|
||||
allocator.free(key);
|
||||
}
|
||||
const d = try allocator.alloc(i64, size.len);
|
||||
for (d, 0..) |*di, i| di.* = size[i].int;
|
||||
for (d, 0..) |*di, i| di.* = size[i].int64;
|
||||
entry.value_ptr.* = .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(d) } };
|
||||
return true;
|
||||
} else if (basicTypeCheck(v, "fractions", "Fraction")) {
|
||||
const fraction_str = object.args[0].seq[1][0].string;
|
||||
const fraction_str = object.args[0].seq.values[0].string;
|
||||
if (std.mem.indexOfScalar(u8, fraction_str, '/')) |split_idx| {
|
||||
{
|
||||
var new_prefix = prefix;
|
||||
@ -271,8 +271,8 @@ pub const PickleData = struct {
|
||||
try self.parseValue(allocator, store, prefix, object.member);
|
||||
for (object.args) |item| {
|
||||
// if possible, coerce to `kv_tuple` (only if key val doesn't match root of prefix)
|
||||
if (item == .seq and item.seq[0] == .tuple and item.seq[1].len == 2 and item.seq[1][0] == .string) {
|
||||
try self.parseValue(allocator, store, prefix, .{ .seq = .{ .kv_tuple, item.seq[1] } });
|
||||
if (item == .seq and item.seq.type == .tuple and item.seq.values.len == 2 and item.seq.values[0] == .string) {
|
||||
try self.parseValue(allocator, store, prefix, .{ .seq = .{ .type = .kv_tuple, .values = item.seq.values } });
|
||||
} else try self.parseValue(allocator, store, prefix, item);
|
||||
}
|
||||
}
|
||||
@ -312,59 +312,54 @@ pub const PickleData = struct {
|
||||
try self.parseValue(allocator, store, prefix, build.args);
|
||||
},
|
||||
.pers_id => |pers_id| try self.parseValue(allocator, store, prefix, pers_id.ref),
|
||||
.seq => |*seq| switch (seq[0]) {
|
||||
.seq => |seq| {
|
||||
switch (seq.type) {
|
||||
.list, .tuple, .set, .frozen_set => {
|
||||
const elemCheck = struct {
|
||||
fn call(comptime T: ValueType) fn (v: Value) bool {
|
||||
return struct {
|
||||
fn call(val: Value) bool {
|
||||
return val == T;
|
||||
}
|
||||
}.call;
|
||||
}
|
||||
}.call;
|
||||
|
||||
if (seq[1].len > 0 and switch (seq[1][0]) {
|
||||
inline .int, .bool, .float => |_, tag| utils.allTrue(seq[1][1..], elemCheck(tag)),
|
||||
else => false,
|
||||
}) {
|
||||
const out: []u8 = switch (seq[1][0]) {
|
||||
.int => blk: {
|
||||
const d = try allocator.alloc(i64, seq[1].len);
|
||||
for (seq[1], 0..) |item, i| {
|
||||
d[i] = item.int;
|
||||
}
|
||||
break :blk std.mem.sliceAsBytes(d);
|
||||
},
|
||||
.float => blk: {
|
||||
const d = try allocator.alloc(f64, seq[1].len);
|
||||
for (seq[1], 0..) |item, i| {
|
||||
d[i] = item.float;
|
||||
}
|
||||
break :blk std.mem.sliceAsBytes(d);
|
||||
},
|
||||
else => blk: {
|
||||
const d = try allocator.alloc(bool, seq[1].len);
|
||||
for (seq[1], 0..) |item, i| {
|
||||
d[i] = item.bool;
|
||||
}
|
||||
break :blk std.mem.sliceAsBytes(d);
|
||||
},
|
||||
if (seq.values.len == 0) return;
|
||||
var valid_slice = true;
|
||||
switch (seq.values[0]) {
|
||||
inline .int64, .float64, .boolval => |val0, tag| {
|
||||
const ItemType = switch (tag) {
|
||||
.int64 => i64,
|
||||
.float64 => f64,
|
||||
.boolval => bool,
|
||||
else => unreachable,
|
||||
};
|
||||
const key = try allocator.dupe(u8, prefix.items);
|
||||
const d = try store._metadata.getOrPut(allocator, key);
|
||||
if (d.found_existing) {
|
||||
log.warn("Duplicate key: {s}", .{prefix.items});
|
||||
allocator.free(key);
|
||||
allocator.free(out);
|
||||
} else d.value_ptr.* = @unionInit(zml.aio.Value, "array", .{ .item_type = switch (seq[1][0]) {
|
||||
.int => .int64,
|
||||
.float => .float64,
|
||||
.string => .string,
|
||||
else => .boolval,
|
||||
}, .data = out });
|
||||
var values: std.ArrayListUnmanaged(ItemType) = .{};
|
||||
try values.append(allocator, val0);
|
||||
for (seq.values[1..], 1..) |val, i| {
|
||||
if (std.meta.activeTag(val) != tag) valid_slice = false;
|
||||
if (valid_slice) {
|
||||
try values.append(allocator, @field(val, @tagName(tag)));
|
||||
} else {
|
||||
for (seq[1], 0..) |item, i| {
|
||||
var new_prefix = prefix;
|
||||
if (prefix.items.len > 0) {
|
||||
new_prefix.appendAssumeCapacity('.');
|
||||
}
|
||||
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
|
||||
try self.parseValue(allocator, store, new_prefix, val);
|
||||
}
|
||||
}
|
||||
|
||||
if (valid_slice) {
|
||||
try store._metadata.put(
|
||||
allocator,
|
||||
try allocator.dupe(u8, prefix.items),
|
||||
.{ .array = .{ .item_type = std.meta.stringToEnum(zml.aio.Value.Slice.ItemType, @tagName(tag)).?, .data = std.mem.sliceAsBytes(try values.toOwnedSlice(allocator)) } },
|
||||
);
|
||||
} else {
|
||||
for (values.items, 0..) |val, i| {
|
||||
var new_prefix = prefix;
|
||||
if (prefix.items.len > 0) {
|
||||
new_prefix.appendAssumeCapacity('.');
|
||||
}
|
||||
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), @unionInit(zml.aio.Value, @tagName(tag), val));
|
||||
}
|
||||
}
|
||||
},
|
||||
else => {
|
||||
for (seq.values, 0..) |item, i| {
|
||||
var new_prefix = prefix;
|
||||
if (v.isPrimitive()) {
|
||||
if (prefix.items.len > 0) {
|
||||
@ -374,18 +369,17 @@ pub const PickleData = struct {
|
||||
}
|
||||
try self.parseValue(allocator, store, new_prefix, item);
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
.dict => {
|
||||
for (seq[1]) |item| {
|
||||
.dict => for (seq.values) |item| {
|
||||
try self.parseValue(allocator, store, prefix, item);
|
||||
}
|
||||
},
|
||||
.kv_tuple => {
|
||||
const key = seq[1][0];
|
||||
const val = seq[1][1];
|
||||
const key, const val = seq.values[0..2].*;
|
||||
switch (key) {
|
||||
.string => |s| {
|
||||
// Handle PyTorch-specific fields
|
||||
if (std.mem.eql(u8, s, "_modules") or std.mem.eql(u8, s, "_parameters") or std.mem.eql(u8, s, "_buffers")) {
|
||||
try self.parseValue(allocator, store, prefix, val);
|
||||
} else {
|
||||
@ -397,7 +391,7 @@ pub const PickleData = struct {
|
||||
try self.parseValue(allocator, store, new_prefix, val);
|
||||
}
|
||||
},
|
||||
.int => |int| {
|
||||
.int64 => |int| {
|
||||
var new_prefix = prefix;
|
||||
if (prefix.items.len > 0) {
|
||||
new_prefix.appendAssumeCapacity('.');
|
||||
@ -408,6 +402,7 @@ pub const PickleData = struct {
|
||||
inline else => |_, tag| std.debug.panic("Unexpected key type: {s}", .{@tagName(tag)}),
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
.bytes => |val| {
|
||||
const key = try allocator.dupe(u8, prefix.items);
|
||||
@ -417,18 +412,13 @@ pub const PickleData = struct {
|
||||
allocator.free(key);
|
||||
} else d.value_ptr.* = .{ .array = .{ .item_type = .uint8, .data = @constCast(val) } };
|
||||
},
|
||||
inline .float, .int, .bool, .bigint, .string => |val, tag| {
|
||||
inline .float64, .int64, .boolval, .bigint, .string => |val, tag| {
|
||||
const key = try allocator.dupe(u8, prefix.items);
|
||||
const d = try store._metadata.getOrPut(allocator, key);
|
||||
if (d.found_existing) {
|
||||
log.warn("Duplicate key: {s}", .{prefix.items});
|
||||
allocator.free(key);
|
||||
} else d.value_ptr.* = @unionInit(zml.aio.Value, switch (tag) {
|
||||
.int => "int64",
|
||||
.float => "float64",
|
||||
.bool => "boolval",
|
||||
else => @tagName(tag),
|
||||
}, val);
|
||||
} else d.value_ptr.* = @unionInit(zml.aio.Value, @tagName(tag), val);
|
||||
},
|
||||
else => {},
|
||||
}
|
||||
|
||||
@ -81,7 +81,7 @@ pub const PickleMemo = struct {
|
||||
}
|
||||
},
|
||||
.seq => |*v| {
|
||||
for (v[1]) |*item| {
|
||||
for (v.values) |*item| {
|
||||
if (item.containsRef()) {
|
||||
item.* = try self.resolve(allocator, item.*, recursive);
|
||||
}
|
||||
@ -148,7 +148,7 @@ pub const PickleMemo = struct {
|
||||
try self.resolveAllRefs(allocator, depth + 1, v.member, fix_values),
|
||||
try self.resolveAllRefs(allocator, depth + 1, v.args, fix_values),
|
||||
) },
|
||||
.seq => |v| .{ .seq = .{ v[0], try self.resolveAllRefsIter(allocator, depth + 1, v[1], fix_values) } },
|
||||
.seq => |v| .{ .seq = .{ .type = v.type, .values = try self.resolveAllRefsIter(allocator, depth + 1, v.values, fix_values) } },
|
||||
.pers_id => |v| .{ .pers_id = try PersId.init(allocator, try self.resolveAllRefs(allocator, depth + 1, v.ref, fix_values)) },
|
||||
else => try val.clone(allocator),
|
||||
};
|
||||
@ -252,7 +252,7 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
|
||||
const kv = try alloc.alloc(Value, 2);
|
||||
kv[0] = items[idx];
|
||||
kv[1] = items[idx + 1];
|
||||
kv_items.appendAssumeCapacity(.{ .seq = .{ .kv_tuple, kv } });
|
||||
kv_items.appendAssumeCapacity(.{ .seq = .{ .type = .kv_tuple, .values = kv } });
|
||||
}
|
||||
return kv_items.toOwnedSlice();
|
||||
}
|
||||
@ -283,19 +283,19 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
|
||||
const member = try memo.resolve(allocator, try stack.pop(), true);
|
||||
break :blk .{ .build = try Build.init(allocator, member, args) };
|
||||
}),
|
||||
.empty_dict => try stack.values.append(.{ .seq = .{ .dict, &[_]Value{} } }),
|
||||
.empty_dict => try stack.values.append(.{ .seq = .{ .type = .dict, .values = &[_]Value{} } }),
|
||||
.get => |v| try stack.values.append(.{ .ref = try std.fmt.parseInt(u32, v, 10) }),
|
||||
inline .binget, .long_binget => |v| try stack.values.append(.{ .ref = v }),
|
||||
.empty_list => try stack.values.append(.{ .seq = .{ .list, &[_]Value{} } }),
|
||||
.empty_list => try stack.values.append(.{ .seq = .{ .type = .list, .values = &[_]Value{} } }),
|
||||
.binput, .long_binput => |v| {
|
||||
try memo.insert(v, try stack.pop());
|
||||
try stack.values.append(.{ .ref = v });
|
||||
},
|
||||
.tuple => try stack.values.append(blk: {
|
||||
const popped = try stack.popMark(allocator);
|
||||
break :blk .{ .seq = .{ .tuple, popped } };
|
||||
break :blk .{ .seq = .{ .type = .tuple, .values = popped } };
|
||||
}),
|
||||
.empty_tuple => try stack.values.append(.{ .seq = .{ .tuple, &[_]Value{} } }),
|
||||
.empty_tuple => try stack.values.append(.{ .seq = .{ .type = .tuple, .values = &[_]Value{} } }),
|
||||
.setitem => {
|
||||
const v, const k = .{ try stack.pop(), try stack.pop() };
|
||||
const top = try stack.lastMut();
|
||||
@ -303,11 +303,11 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
|
||||
switch (rtop.*) {
|
||||
.global => |obj| {
|
||||
obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1);
|
||||
obj.args[obj.args.len - 1] = .{ .seq = .{ .tuple, try allocator.dupe(Value, &.{ k, v }) } };
|
||||
obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ k, v }) } };
|
||||
},
|
||||
.seq => |*tup| {
|
||||
tup[1] = try assuredResize(Value, allocator, tup[1], tup[1].len + 1);
|
||||
tup[1][tup[1].len - 1] = .{ .seq = .{ .tuple, try allocator.dupe(Value, &.{ k, v }) } };
|
||||
tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1);
|
||||
tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ k, v }) } };
|
||||
},
|
||||
else => {
|
||||
return error.BadStackTopForSetItem;
|
||||
@ -323,11 +323,11 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
|
||||
switch (rtop.*) {
|
||||
.global => |obj| {
|
||||
obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1);
|
||||
obj.args[obj.args.len - 1] = .{ .seq = .{ .tuple, kv_items } };
|
||||
obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } };
|
||||
},
|
||||
.seq => |*tup| {
|
||||
tup[1] = try assuredResize(Value, allocator, tup[1], tup[1].len + 1);
|
||||
tup[1][tup[1].len - 1] = .{ .seq = .{ .tuple, kv_items } };
|
||||
tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1);
|
||||
tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } };
|
||||
},
|
||||
else => {
|
||||
defer allocator.free(kv_items);
|
||||
@ -339,17 +339,17 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
|
||||
.tuple1 => try stack.values.append(blk: {
|
||||
const tup_values = try allocator.alloc(Value, 1);
|
||||
tup_values[0] = try stack.pop();
|
||||
break :blk .{ .seq = .{ .tuple, tup_values } };
|
||||
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
|
||||
}),
|
||||
.tuple2 => try stack.values.append(blk: {
|
||||
const tup_values = try allocator.alloc(Value, 2);
|
||||
inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop();
|
||||
break :blk .{ .seq = .{ .tuple, tup_values } };
|
||||
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
|
||||
}),
|
||||
.tuple3 => try stack.values.append(blk: {
|
||||
const tup_values = try allocator.alloc(Value, 3);
|
||||
inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop();
|
||||
break :blk .{ .seq = .{ .tuple, tup_values } };
|
||||
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
|
||||
}),
|
||||
.append => {
|
||||
const v = try stack.pop();
|
||||
@ -361,8 +361,8 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
|
||||
obj.args[obj.args.len - 1] = v;
|
||||
},
|
||||
.seq => |*tup| {
|
||||
tup[1] = try assuredResize(Value, allocator, tup[1], tup[1].len + 1);
|
||||
tup[1][tup[1].len - 1] = v;
|
||||
tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1);
|
||||
tup.values[tup.values.len - 1] = v;
|
||||
},
|
||||
else => {
|
||||
return error.BadStackTopForAppend;
|
||||
@ -381,9 +381,9 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
|
||||
@memcpy(obj.args[obj_len..], postmark);
|
||||
},
|
||||
.seq => |*tup| {
|
||||
const tup_len = tup[1].len;
|
||||
tup[1] = try assuredResize(Value, allocator, tup[1], tup_len + postmark.len);
|
||||
@memcpy(tup[1][tup_len..], postmark);
|
||||
const tup_len = tup.values.len;
|
||||
tup.values = try assuredResize(Value, allocator, tup.values, tup_len + postmark.len);
|
||||
@memcpy(tup.values[tup_len..], postmark);
|
||||
},
|
||||
else => {
|
||||
return error.BadStackTopForAppends;
|
||||
@ -394,12 +394,12 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
|
||||
const popped = try stack.popMark(allocator);
|
||||
defer allocator.free(popped);
|
||||
const kv_items = try makeKVList(allocator, popped);
|
||||
break :blk .{ .seq = .{ .dict, kv_items } };
|
||||
break :blk .{ .seq = .{ .type = .dict, .values = kv_items } };
|
||||
}),
|
||||
.list => try stack.values.append(.{ .seq = .{ .list, try stack.popMark(allocator) } }),
|
||||
.list => try stack.values.append(.{ .seq = .{ .type = .list, .values = try stack.popMark(allocator) } }),
|
||||
.inst => |v| try stack.values.append(blk: {
|
||||
const tup_items = try allocator.dupe(Value, &.{ .{ .string = v[0] }, .{ .string = v[1] } });
|
||||
break :blk .{ .object = try Object.init(allocator, .{ .seq = .{ .tuple, tup_items } }, try stack.popMark(allocator)) };
|
||||
break :blk .{ .object = try Object.init(allocator, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try stack.popMark(allocator)) };
|
||||
}),
|
||||
.obj => try stack.values.append(blk: {
|
||||
const markidx = try stack.findMark();
|
||||
@ -418,7 +418,7 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
|
||||
args[0] = try stack.pop();
|
||||
break :blk .{ .object = try Object.init(allocator, try stack.pop(), args) };
|
||||
}),
|
||||
.empty_set => try stack.values.append(.{ .seq = .{ .set, &[_]Value{} } }),
|
||||
.empty_set => try stack.values.append(.{ .seq = .{ .type = .set, .values = &[_]Value{} } }),
|
||||
.additems => {
|
||||
const postmark = try stack.popMark(allocator);
|
||||
defer allocator.free(postmark);
|
||||
@ -431,19 +431,19 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
|
||||
@memcpy(obj.args[obj_len..], postmark);
|
||||
},
|
||||
.seq => |*tup| {
|
||||
const tup_len = tup[1].len;
|
||||
tup[1] = try assuredResize(Value, allocator, tup[1], tup_len + postmark.len);
|
||||
@memcpy(tup[1][tup_len..], postmark);
|
||||
const tup_len = tup.values.len;
|
||||
tup.values = try assuredResize(Value, allocator, tup.values, tup_len + postmark.len);
|
||||
@memcpy(tup.values[tup_len..], postmark);
|
||||
},
|
||||
else => {
|
||||
return error.BadStackTopForSetItem;
|
||||
},
|
||||
}
|
||||
},
|
||||
.frozenset => try stack.values.append(.{ .seq = .{ .frozen_set, try stack.popMark(allocator) } }),
|
||||
.frozenset => try stack.values.append(.{ .seq = .{ .type = .frozen_set, .values = try stack.popMark(allocator) } }),
|
||||
.newobj_ex => try stack.values.append(blk: {
|
||||
const kwargs, const args, const cls = .{ try stack.pop(), try stack.pop(), try stack.pop() };
|
||||
const new_seq: Sequence = .{ .tuple, try allocator.dupe(Value, &.{ args, kwargs }) };
|
||||
const new_seq: Sequence = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ args, kwargs }) };
|
||||
break :blk .{ .object = try Object.init(allocator, cls, try allocator.dupe(Value, &.{.{ .seq = new_seq }})) };
|
||||
}),
|
||||
.stack_global => try stack.values.append(blk: {
|
||||
@ -451,7 +451,7 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs:
|
||||
try memo.resolve(allocator, try stack.pop(), true),
|
||||
try memo.resolve(allocator, try stack.pop(), true),
|
||||
};
|
||||
const new_seq: Sequence = .{ .tuple, try allocator.dupe(Value, &.{ gn, mn }) };
|
||||
const new_seq: Sequence = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ gn, mn }) };
|
||||
break :blk .{ .object = try Object.init(allocator, .{ .seq = new_seq }, &[_]Value{}) };
|
||||
}),
|
||||
.memoize => {
|
||||
|
||||
@ -388,6 +388,7 @@ const TarStream = struct {
|
||||
};
|
||||
|
||||
test "Read pickle (simple)" {
|
||||
const Value = @import("value.zig").Value;
|
||||
var arena = std.heap.ArenaAllocator.init(testing.allocator);
|
||||
defer arena.deinit();
|
||||
const allocator = arena.allocator();
|
||||
@ -402,64 +403,36 @@ test "Read pickle (simple)" {
|
||||
try testing.expect(vals.stack.len == 2);
|
||||
// skip first value (frame)
|
||||
try testing.expect(vals.stack[1] == .seq);
|
||||
try testing.expect(vals.stack[1].seq[0] == .dict);
|
||||
const entries = vals.stack[1].seq[1][0].seq[1];
|
||||
try testing.expect(vals.stack[1].seq.type == .dict);
|
||||
const entries = vals.stack[1].seq.values[0].seq.values;
|
||||
try testing.expect(entries.len == 5);
|
||||
for (entries, 0..) |kv, i| {
|
||||
try testing.expect(kv == .seq);
|
||||
try testing.expect(kv.seq[0] == .kv_tuple);
|
||||
switch (i) {
|
||||
0 => {
|
||||
const key = kv.seq[1][0];
|
||||
try testing.expect(key == .string);
|
||||
try testing.expectEqualStrings("hello", key.string);
|
||||
const value = kv.seq[1][1];
|
||||
try testing.expect(value == .string);
|
||||
try testing.expectEqualStrings("world", value.string);
|
||||
},
|
||||
1 => {
|
||||
const key = kv.seq[1][0];
|
||||
try testing.expect(key == .string);
|
||||
try testing.expectEqualStrings("int", key.string);
|
||||
const value = kv.seq[1][1];
|
||||
try testing.expect(value == .int);
|
||||
try testing.expect(value.int == 1);
|
||||
},
|
||||
2 => {
|
||||
const key = kv.seq[1][0];
|
||||
try testing.expect(key == .string);
|
||||
try testing.expectEqualStrings("float", key.string);
|
||||
const value = kv.seq[1][1];
|
||||
try testing.expect(value == .float);
|
||||
try testing.expectEqual(@as(f64, 3.141592), value.float);
|
||||
},
|
||||
3 => {
|
||||
const key = kv.seq[1][0];
|
||||
try testing.expect(key == .string);
|
||||
try testing.expectEqualStrings("list", key.string);
|
||||
const value = kv.seq[1][1];
|
||||
try testing.expect(value == .seq);
|
||||
try testing.expect(value.seq[0] == .list);
|
||||
for (value.seq[1], 0..) |item, j| {
|
||||
try testing.expect(item == .int);
|
||||
try testing.expect(item.int == @as(i64, @intCast(j)));
|
||||
}
|
||||
},
|
||||
4 => {
|
||||
const key = kv.seq[1][0];
|
||||
try testing.expect(key == .string);
|
||||
try testing.expectEqualStrings("tuple", key.string);
|
||||
const value = kv.seq[1][1];
|
||||
try testing.expect(value == .seq);
|
||||
try testing.expect(value.seq[0] == .tuple);
|
||||
try testing.expect(value.seq[1][0] == .string);
|
||||
try testing.expectEqualStrings("a", value.seq[1][0].string);
|
||||
try testing.expect(value.seq[1][1] == .int);
|
||||
try testing.expect(value.seq[1][1].int == 10);
|
||||
},
|
||||
else => unreachable,
|
||||
}
|
||||
}
|
||||
const expected: []const Value = &.{
|
||||
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "hello" }, .{ .string = "world" } })) } },
|
||||
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "int" }, .{ .int64 = 1 } })) } },
|
||||
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "float" }, .{ .float64 = 3.141592 } })) } },
|
||||
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{
|
||||
.{ .string = "list" },
|
||||
.{ .seq = .{ .type = .list, .values = @constCast(@as([]const Value, &.{
|
||||
.{ .int64 = 0 },
|
||||
.{ .int64 = 1 },
|
||||
.{ .int64 = 2 },
|
||||
.{ .int64 = 3 },
|
||||
.{ .int64 = 4 },
|
||||
})) } },
|
||||
})) } },
|
||||
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{
|
||||
.{ .string = "tuple" },
|
||||
.{ .seq = .{
|
||||
.type = .tuple,
|
||||
.values = @constCast(@as([]const Value, &.{
|
||||
.{ .string = "a" },
|
||||
.{ .int64 = 10 },
|
||||
})),
|
||||
} },
|
||||
})) } },
|
||||
};
|
||||
|
||||
try std.testing.expectEqualDeep(expected, entries);
|
||||
}
|
||||
|
||||
test "Read pickle (zipped)" {
|
||||
|
||||
@ -65,7 +65,10 @@ pub const Build = struct {
|
||||
}
|
||||
};
|
||||
|
||||
pub const Sequence = struct { SequenceType, []Value };
|
||||
pub const Sequence = struct {
|
||||
type: SequenceType,
|
||||
values: []Value,
|
||||
};
|
||||
|
||||
pub const PersId = struct {
|
||||
allocator: std.mem.Allocator,
|
||||
@ -100,11 +103,11 @@ pub const ValueType = enum {
|
||||
seq,
|
||||
string,
|
||||
bytes,
|
||||
int,
|
||||
int64,
|
||||
bigint,
|
||||
float,
|
||||
float64,
|
||||
raw_num,
|
||||
bool,
|
||||
boolval,
|
||||
none,
|
||||
};
|
||||
|
||||
@ -160,7 +163,7 @@ pub const Value = union(ValueType) {
|
||||
/// An integer, but not the crazy kind that comes as a string
|
||||
/// that has to be parsed. You can look in `Value.raw_num` for
|
||||
/// those.
|
||||
int: i64,
|
||||
int64: i64,
|
||||
|
||||
/// An integer that can't fit in i64.
|
||||
bigint: big_int.Managed,
|
||||
@ -168,13 +171,13 @@ pub const Value = union(ValueType) {
|
||||
/// A float, but not the crazy kind that comes as a string
|
||||
/// that has to be parsed. You can look in `Value.raw_num` for
|
||||
/// those.
|
||||
float: f64,
|
||||
float64: f64,
|
||||
|
||||
/// Some kind of weird number we can't handle.
|
||||
raw_num: PickleOp,
|
||||
|
||||
/// A boolean value.
|
||||
bool: bool,
|
||||
boolval: bool,
|
||||
|
||||
/// Python `None`.
|
||||
none: void,
|
||||
@ -184,8 +187,8 @@ pub const Value = union(ValueType) {
|
||||
.raw, .raw_num => |v| v.deinit(allocator),
|
||||
inline .app, .object, .global, .build, .pers_id => |v| v.deinit(),
|
||||
.seq => |v| {
|
||||
for (v[1]) |*val| val.deinit(allocator);
|
||||
allocator.free(v[1]);
|
||||
for (v.values) |*val| val.deinit(allocator);
|
||||
allocator.free(v.values);
|
||||
},
|
||||
.string, .bytes => |v| allocator.free(v),
|
||||
.bigint => self.bigint.deinit(),
|
||||
@ -205,7 +208,7 @@ pub const Value = union(ValueType) {
|
||||
try writeIndents(indents + 1, writer);
|
||||
try writer.print(".{s} = ", .{@tagName(std.meta.activeTag(value))});
|
||||
switch (value) {
|
||||
inline .ref, .int, .float => |v| try writer.print("{d} ", .{v}),
|
||||
inline .ref, .int64, .float64 => |v| try writer.print("{d} ", .{v}),
|
||||
.app, .object, .global => |v| {
|
||||
try writer.writeAll(".{\n");
|
||||
try internalFormat(v.member, indents + 2, writer);
|
||||
@ -242,13 +245,13 @@ pub const Value = union(ValueType) {
|
||||
.seq => |v| {
|
||||
try writer.writeAll(".{\n");
|
||||
try writeIndents(indents + 2, writer);
|
||||
try writer.print(".{s},\n", .{@tagName(v[0])});
|
||||
try writer.print(".{s},\n", .{@tagName(v.type)});
|
||||
try writeIndents(indents + 2, writer);
|
||||
if (v[1].len > 0) {
|
||||
if (v.values.len > 0) {
|
||||
try writer.writeAll(".{\n");
|
||||
for (v[1], 0..) |arg, i| {
|
||||
for (v.values, 0..) |arg, i| {
|
||||
try internalFormat(arg, indents + 3, writer);
|
||||
if (i < v[1].len - 1) try writer.writeAll(",");
|
||||
if (i < v.values.len - 1) try writer.writeAll(",");
|
||||
try writer.writeByte('\n');
|
||||
}
|
||||
try writeIndents(indents + 2, writer);
|
||||
@ -283,8 +286,8 @@ pub const Value = union(ValueType) {
|
||||
inline .raw, .raw_num => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)),
|
||||
inline .app, .object, .global, .build, .pers_id => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)),
|
||||
.seq => |seq| blk: {
|
||||
const new_val: Sequence = .{ seq[0], try allocator.alloc(Value, seq[1].len) };
|
||||
for (seq[1], 0..) |v, i| new_val[1][i] = try v.clone(allocator);
|
||||
const new_val: Sequence = .{ .type = seq.type, .values = try allocator.alloc(Value, seq.values.len) };
|
||||
for (seq.values, 0..) |v, i| new_val.values[i] = try v.clone(allocator);
|
||||
break :blk .{ .seq = new_val };
|
||||
},
|
||||
inline .string, .bytes => |v, tag| @unionInit(Value, @tagName(tag), try allocator.dupe(u8, v)),
|
||||
@ -295,8 +298,8 @@ pub const Value = union(ValueType) {
|
||||
|
||||
pub fn isPrimitive(self: Value) bool {
|
||||
return switch (self) {
|
||||
.int, .bigint, .float, .string, .bytes, .bool, .none => true,
|
||||
.seq => |seq| utils.allTrue(seq[1], Value.isPrimitive),
|
||||
.int64, .bigint, .float64, .string, .bytes, .boolval, .none => true,
|
||||
.seq => |seq| utils.allTrue(seq.values, Value.isPrimitive),
|
||||
else => false,
|
||||
};
|
||||
}
|
||||
@ -316,7 +319,7 @@ pub const Value = union(ValueType) {
|
||||
},
|
||||
.pers_id => |v| return v.ref.containsRef(),
|
||||
.seq => |v| {
|
||||
for (v[1]) |val| if (val.containsRef()) return true;
|
||||
for (v.values) |val| if (val.containsRef()) return true;
|
||||
return false;
|
||||
},
|
||||
else => return false,
|
||||
@ -336,7 +339,7 @@ pub const Value = union(ValueType) {
|
||||
pub fn coerceFromRaw(self: Value, allocator: std.mem.Allocator) !Value {
|
||||
return switch (self) {
|
||||
.raw => |raw_val| switch (raw_val) {
|
||||
.binint, .binint1, .binint2 => |val| .{ .int = val },
|
||||
.binint, .binint1, .binint2 => |val| .{ .int64 = val },
|
||||
.long1, .long4 => |b| if (b.len != 0) {
|
||||
var bint = try big_int.Managed.initCapacity(allocator, std.math.big.int.calcTwosCompLimbCount(b.len));
|
||||
var mutable = bint.toMutable();
|
||||
@ -345,10 +348,10 @@ pub const Value = union(ValueType) {
|
||||
const max_comp = bint.toConst().order(BI64MAX);
|
||||
if ((min_comp == .gt or min_comp == .eq) and (max_comp == .lt or max_comp == .eq)) {
|
||||
defer bint.deinit();
|
||||
return .{ .int = try bint.to(i64) };
|
||||
return .{ .int64 = try bint.to(i64) };
|
||||
} else return .{ .bigint = bint };
|
||||
} else .{ .raw_num = raw_val },
|
||||
.binfloat => |val| .{ .float = val },
|
||||
.binfloat => |val| .{ .float64 = val },
|
||||
.binunicode, .binunicode8, .short_binunicode => |s| .{ .string = s },
|
||||
.binbytes, .binbytes8, .short_binbytes, .bytearray8 => |b| .{ .bytes = b },
|
||||
// This isn't how Pickle actually works but we just try to UTF8 decode the
|
||||
@ -356,17 +359,17 @@ pub const Value = union(ValueType) {
|
||||
// actually cares they can just fix values themselves or recover the raw bytes
|
||||
// from the UTF8 string (it's guaranteed to be reversible, as far as I know).
|
||||
.binstring, .short_binstring => |b| if (std.unicode.utf8ValidateSlice(b)) .{ .string = b } else .{ .bytes = b },
|
||||
.newtrue => .{ .bool = true },
|
||||
.newfalse => .{ .bool = false },
|
||||
.newtrue => .{ .boolval = true },
|
||||
.newfalse => .{ .boolval = false },
|
||||
.none => .{ .none = {} },
|
||||
inline .int,
|
||||
.float,
|
||||
.long,
|
||||
=> |v, tag| {
|
||||
if (tag == .int and std.mem.eql(u8, v, "01")) {
|
||||
return .{ .bool = true };
|
||||
return .{ .boolval = true };
|
||||
} else if (tag == .int and std.mem.eql(u8, v, "00")) {
|
||||
return .{ .bool = false };
|
||||
return .{ .boolval = false };
|
||||
} else {
|
||||
return .{ .raw_num = raw_val };
|
||||
}
|
||||
@ -389,8 +392,8 @@ pub const Value = union(ValueType) {
|
||||
v.ref = try v.ref.coerceFromRaw(allocator);
|
||||
break :blk self;
|
||||
},
|
||||
.seq => |*v| blk: {
|
||||
for (v[1]) |*val| {
|
||||
.seq => |v| blk: {
|
||||
for (v.values) |*val| {
|
||||
val.* = try val.coerceFromRaw(allocator);
|
||||
}
|
||||
break :blk self;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user