diff --git a/zml/aio/torch.zig b/zml/aio/torch.zig index 2de206e..20434d4 100644 --- a/zml/aio/torch.zig +++ b/zml/aio/torch.zig @@ -101,12 +101,12 @@ pub const PickleData = struct { fn isTensor(v: Value) bool { if (basicTypeCheck(v, "torch._utils", "_rebuild_tensor_v2")) { - const args = v.global.args[0].seq[1]; + const args = v.global.args[0].seq.values; if (args.len >= 5 and args[0] == .pers_id and - args[1] == .int and - args[2] == .seq and args[2].seq[0] == .tuple and - args[3] == .seq and args[3].seq[0] == .tuple) + args[1] == .int64 and + args[2] == .seq and args[2].seq.type == .tuple and + args[3] == .seq and args[3].seq.type == .tuple) { return true; } else @panic("Unexpected value in call to torch._utils._rebuild_tensor_v2"); @@ -119,7 +119,7 @@ pub const PickleData = struct { var result: [zml.Tensor.MAX_RANK]i64 = undefined; for (values, result[0..values.len]) |val, *elem| { switch (val) { - .int => |int| elem.* = int, + .int64 => |int| elem.* = int, else => @panic("Bad value for shape item"), } } @@ -174,15 +174,15 @@ pub const PickleData = struct { return switch (v) { .global => |object| { if (isTensor(v)) { - const args = object.args[0].seq[1]; - const pidval: *PersId, var offs: u64, const raw_shape: Sequence, const raw_strides: Sequence = .{ args[0].pers_id, @intCast(args[1].int), args[2].seq, args[3].seq }; - const rank = raw_shape[1].len; - const shape = dimsFromValues(raw_shape[1]); - var strides = dimsFromValues(raw_strides[1]); + const args = object.args[0].seq.values; + const pidval: *PersId, var offs: u64, const raw_shape: Sequence, const raw_strides: Sequence = .{ args[0].pers_id, @intCast(args[1].int64), args[2].seq, args[3].seq }; + const rank = raw_shape.values.len; + const shape = dimsFromValues(raw_shape.values); + var strides = dimsFromValues(raw_strides.values); const stype: []const u8, const sfile: []const u8, const sdev: []const u8 = switch (pidval.ref) { .seq => |seq| blk: { - const sargs = seq[1]; - if (seq[0] == .tuple and + const sargs = seq.values; + if (seq.type == .tuple and sargs.len >= 5 and sargs[0] == .string and std.mem.eql(u8, sargs[0].string, "storage") and sargs[1] == .raw and sargs[1].raw == .global and @@ -231,7 +231,7 @@ pub const PickleData = struct { ); return true; } else if (basicTypeCheck(v, "torch", "Size")) { - const size = object.args[0].seq[1][0].seq[1]; + const size = object.args[0].seq.values[0].seq.values; const key = try allocator.dupe(u8, prefix.items); const entry = try store._metadata.getOrPut(allocator, key); if (entry.found_existing) { @@ -239,11 +239,11 @@ pub const PickleData = struct { allocator.free(key); } const d = try allocator.alloc(i64, size.len); - for (d, 0..) |*di, i| di.* = size[i].int; + for (d, 0..) |*di, i| di.* = size[i].int64; entry.value_ptr.* = .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(d) } }; return true; } else if (basicTypeCheck(v, "fractions", "Fraction")) { - const fraction_str = object.args[0].seq[1][0].string; + const fraction_str = object.args[0].seq.values[0].string; if (std.mem.indexOfScalar(u8, fraction_str, '/')) |split_idx| { { var new_prefix = prefix; @@ -271,8 +271,8 @@ pub const PickleData = struct { try self.parseValue(allocator, store, prefix, object.member); for (object.args) |item| { // if possible, coerce to `kv_tuple` (only if key val doesn't match root of prefix) - if (item == .seq and item.seq[0] == .tuple and item.seq[1].len == 2 and item.seq[1][0] == .string) { - try self.parseValue(allocator, store, prefix, .{ .seq = .{ .kv_tuple, item.seq[1] } }); + if (item == .seq and item.seq.type == .tuple and item.seq.values.len == 2 and item.seq.values[0] == .string) { + try self.parseValue(allocator, store, prefix, .{ .seq = .{ .type = .kv_tuple, .values = item.seq.values } }); } else try self.parseValue(allocator, store, prefix, item); } } @@ -312,102 +312,97 @@ pub const PickleData = struct { try self.parseValue(allocator, store, prefix, build.args); }, .pers_id => |pers_id| try self.parseValue(allocator, store, prefix, pers_id.ref), - .seq => |*seq| switch (seq[0]) { - .list, .tuple, .set, .frozen_set => { - const elemCheck = struct { - fn call(comptime T: ValueType) fn (v: Value) bool { - return struct { - fn call(val: Value) bool { - return val == T; + .seq => |seq| { + switch (seq.type) { + .list, .tuple, .set, .frozen_set => { + if (seq.values.len == 0) return; + var valid_slice = true; + switch (seq.values[0]) { + inline .int64, .float64, .boolval => |val0, tag| { + const ItemType = switch (tag) { + .int64 => i64, + .float64 => f64, + .boolval => bool, + else => unreachable, + }; + var values: std.ArrayListUnmanaged(ItemType) = .{}; + try values.append(allocator, val0); + for (seq.values[1..], 1..) |val, i| { + if (std.meta.activeTag(val) != tag) valid_slice = false; + if (valid_slice) { + try values.append(allocator, @field(val, @tagName(tag))); + } else { + var new_prefix = prefix; + if (prefix.items.len > 0) { + new_prefix.appendAssumeCapacity('.'); + } + new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); + try self.parseValue(allocator, store, new_prefix, val); + } } - }.call; - } - }.call; - if (seq[1].len > 0 and switch (seq[1][0]) { - inline .int, .bool, .float => |_, tag| utils.allTrue(seq[1][1..], elemCheck(tag)), - else => false, - }) { - const out: []u8 = switch (seq[1][0]) { - .int => blk: { - const d = try allocator.alloc(i64, seq[1].len); - for (seq[1], 0..) |item, i| { - d[i] = item.int; + if (valid_slice) { + try store._metadata.put( + allocator, + try allocator.dupe(u8, prefix.items), + .{ .array = .{ .item_type = std.meta.stringToEnum(zml.aio.Value.Slice.ItemType, @tagName(tag)).?, .data = std.mem.sliceAsBytes(try values.toOwnedSlice(allocator)) } }, + ); + } else { + for (values.items, 0..) |val, i| { + var new_prefix = prefix; + if (prefix.items.len > 0) { + new_prefix.appendAssumeCapacity('.'); + } + new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); + try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), @unionInit(zml.aio.Value, @tagName(tag), val)); + } } - break :blk std.mem.sliceAsBytes(d); }, - .float => blk: { - const d = try allocator.alloc(f64, seq[1].len); - for (seq[1], 0..) |item, i| { - d[i] = item.float; + else => { + for (seq.values, 0..) |item, i| { + var new_prefix = prefix; + if (v.isPrimitive()) { + if (prefix.items.len > 0) { + new_prefix.appendAssumeCapacity('.'); + } + new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); + } + try self.parseValue(allocator, store, new_prefix, item); } - break :blk std.mem.sliceAsBytes(d); }, - else => blk: { - const d = try allocator.alloc(bool, seq[1].len); - for (seq[1], 0..) |item, i| { - d[i] = item.bool; - } - break :blk std.mem.sliceAsBytes(d); - }, - }; - const key = try allocator.dupe(u8, prefix.items); - const d = try store._metadata.getOrPut(allocator, key); - if (d.found_existing) { - log.warn("Duplicate key: {s}", .{prefix.items}); - allocator.free(key); - allocator.free(out); - } else d.value_ptr.* = @unionInit(zml.aio.Value, "array", .{ .item_type = switch (seq[1][0]) { - .int => .int64, - .float => .float64, - .string => .string, - else => .boolval, - }, .data = out }); - } else { - for (seq[1], 0..) |item, i| { - var new_prefix = prefix; - if (v.isPrimitive()) { - if (prefix.items.len > 0) { - new_prefix.appendAssumeCapacity('.'); - } - new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); - } - try self.parseValue(allocator, store, new_prefix, item); } - } - }, - .dict => { - for (seq[1]) |item| { + }, + .dict => for (seq.values) |item| { try self.parseValue(allocator, store, prefix, item); - } - }, - .kv_tuple => { - const key = seq[1][0]; - const val = seq[1][1]; - switch (key) { - .string => |s| { - if (std.mem.eql(u8, s, "_modules") or std.mem.eql(u8, s, "_parameters") or std.mem.eql(u8, s, "_buffers")) { - try self.parseValue(allocator, store, prefix, val); - } else { + }, + .kv_tuple => { + const key, const val = seq.values[0..2].*; + switch (key) { + .string => |s| { + // Handle Pytorch specific fields + if (std.mem.eql(u8, s, "_modules") or std.mem.eql(u8, s, "_parameters") or std.mem.eql(u8, s, "_buffers")) { + try self.parseValue(allocator, store, prefix, val); + } else { + var new_prefix = prefix; + if (prefix.items.len > 0) { + new_prefix.appendAssumeCapacity('.'); + } + new_prefix.appendSliceAssumeCapacity(s); + try self.parseValue(allocator, store, new_prefix, val); + } + }, + .int64 => |int| { var new_prefix = prefix; if (prefix.items.len > 0) { new_prefix.appendAssumeCapacity('.'); } - new_prefix.appendSliceAssumeCapacity(s); + new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{}); try self.parseValue(allocator, store, new_prefix, val); - } - }, - .int => |int| { - var new_prefix = prefix; - if (prefix.items.len > 0) { - new_prefix.appendAssumeCapacity('.'); - } - new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{}); - try self.parseValue(allocator, store, new_prefix, val); - }, - inline else => |_, tag| std.debug.panic("Unexpected key type: {s}", .{@tagName(tag)}), - } - }, + }, + inline else => |_, tag| std.debug.panic("Unexpected key type: {s}", .{@tagName(tag)}), + } + }, + } }, .bytes => |val| { const key = try allocator.dupe(u8, prefix.items); @@ -417,18 +412,13 @@ pub const PickleData = struct { allocator.free(key); } else d.value_ptr.* = .{ .array = .{ .item_type = .uint8, .data = @constCast(val) } }; }, - inline .float, .int, .bool, .bigint, .string => |val, tag| { + inline .float64, .int64, .boolval, .bigint, .string => |val, tag| { const key = try allocator.dupe(u8, prefix.items); const d = try store._metadata.getOrPut(allocator, key); if (d.found_existing) { log.warn("Duplicate key: {s}", .{prefix.items}); allocator.free(key); - } else d.value_ptr.* = @unionInit(zml.aio.Value, switch (tag) { - .int => "int64", - .float => "float64", - .bool => "boolval", - else => @tagName(tag), - }, val); + } else d.value_ptr.* = @unionInit(zml.aio.Value, @tagName(tag), val); }, else => {}, } diff --git a/zml/aio/torch/eval.zig b/zml/aio/torch/eval.zig index 246a714..b04656a 100644 --- a/zml/aio/torch/eval.zig +++ b/zml/aio/torch/eval.zig @@ -81,7 +81,7 @@ pub const PickleMemo = struct { } }, .seq => |*v| { - for (v[1]) |*item| { + for (v.values) |*item| { if (item.containsRef()) { item.* = try self.resolve(allocator, item.*, recursive); } @@ -148,7 +148,7 @@ pub const PickleMemo = struct { try self.resolveAllRefs(allocator, depth + 1, v.member, fix_values), try self.resolveAllRefs(allocator, depth + 1, v.args, fix_values), ) }, - .seq => |v| .{ .seq = .{ v[0], try self.resolveAllRefsIter(allocator, depth + 1, v[1], fix_values) } }, + .seq => |v| .{ .seq = .{ .type = v.type, .values = try self.resolveAllRefsIter(allocator, depth + 1, v.values, fix_values) } }, .pers_id => |v| .{ .pers_id = try PersId.init(allocator, try self.resolveAllRefs(allocator, depth + 1, v.ref, fix_values)) }, else => try val.clone(allocator), }; @@ -252,7 +252,7 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs: const kv = try alloc.alloc(Value, 2); kv[0] = items[idx]; kv[1] = items[idx + 1]; - kv_items.appendAssumeCapacity(.{ .seq = .{ .kv_tuple, kv } }); + kv_items.appendAssumeCapacity(.{ .seq = .{ .type = .kv_tuple, .values = kv } }); } return kv_items.toOwnedSlice(); } @@ -283,19 +283,19 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs: const member = try memo.resolve(allocator, try stack.pop(), true); break :blk .{ .build = try Build.init(allocator, member, args) }; }), - .empty_dict => try stack.values.append(.{ .seq = .{ .dict, &[_]Value{} } }), + .empty_dict => try stack.values.append(.{ .seq = .{ .type = .dict, .values = &[_]Value{} } }), .get => |v| try stack.values.append(.{ .ref = try std.fmt.parseInt(u32, v, 10) }), inline .binget, .long_binget => |v| try stack.values.append(.{ .ref = v }), - .empty_list => try stack.values.append(.{ .seq = .{ .list, &[_]Value{} } }), + .empty_list => try stack.values.append(.{ .seq = .{ .type = .list, .values = &[_]Value{} } }), .binput, .long_binput => |v| { try memo.insert(v, try stack.pop()); try stack.values.append(.{ .ref = v }); }, .tuple => try stack.values.append(blk: { const popped = try stack.popMark(allocator); - break :blk .{ .seq = .{ .tuple, popped } }; + break :blk .{ .seq = .{ .type = .tuple, .values = popped } }; }), - .empty_tuple => try stack.values.append(.{ .seq = .{ .tuple, &[_]Value{} } }), + .empty_tuple => try stack.values.append(.{ .seq = .{ .type = .tuple, .values = &[_]Value{} } }), .setitem => { const v, const k = .{ try stack.pop(), try stack.pop() }; const top = try stack.lastMut(); @@ -303,11 +303,11 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs: switch (rtop.*) { .global => |obj| { obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1); - obj.args[obj.args.len - 1] = .{ .seq = .{ .tuple, try allocator.dupe(Value, &.{ k, v }) } }; + obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ k, v }) } }; }, .seq => |*tup| { - tup[1] = try assuredResize(Value, allocator, tup[1], tup[1].len + 1); - tup[1][tup[1].len - 1] = .{ .seq = .{ .tuple, try allocator.dupe(Value, &.{ k, v }) } }; + tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1); + tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ k, v }) } }; }, else => { return error.BadStackTopForSetItem; @@ -323,11 +323,11 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs: switch (rtop.*) { .global => |obj| { obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1); - obj.args[obj.args.len - 1] = .{ .seq = .{ .tuple, kv_items } }; + obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } }; }, .seq => |*tup| { - tup[1] = try assuredResize(Value, allocator, tup[1], tup[1].len + 1); - tup[1][tup[1].len - 1] = .{ .seq = .{ .tuple, kv_items } }; + tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1); + tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } }; }, else => { defer allocator.free(kv_items); @@ -339,17 +339,17 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs: .tuple1 => try stack.values.append(blk: { const tup_values = try allocator.alloc(Value, 1); tup_values[0] = try stack.pop(); - break :blk .{ .seq = .{ .tuple, tup_values } }; + break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; }), .tuple2 => try stack.values.append(blk: { const tup_values = try allocator.alloc(Value, 2); inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop(); - break :blk .{ .seq = .{ .tuple, tup_values } }; + break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; }), .tuple3 => try stack.values.append(blk: { const tup_values = try allocator.alloc(Value, 3); inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop(); - break :blk .{ .seq = .{ .tuple, tup_values } }; + break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; }), .append => { const v = try stack.pop(); @@ -361,8 +361,8 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs: obj.args[obj.args.len - 1] = v; }, .seq => |*tup| { - tup[1] = try assuredResize(Value, allocator, tup[1], tup[1].len + 1); - tup[1][tup[1].len - 1] = v; + tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1); + tup.values[tup.values.len - 1] = v; }, else => { return error.BadStackTopForAppend; @@ -381,9 +381,9 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs: @memcpy(obj.args[obj_len..], postmark); }, .seq => |*tup| { - const tup_len = tup[1].len; - tup[1] = try assuredResize(Value, allocator, tup[1], tup_len + postmark.len); - @memcpy(tup[1][tup_len..], postmark); + const tup_len = tup.values.len; + tup.values = try assuredResize(Value, allocator, tup.values, tup_len + postmark.len); + @memcpy(tup.values[tup_len..], postmark); }, else => { return error.BadStackTopForAppends; @@ -394,12 +394,12 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs: const popped = try stack.popMark(allocator); defer allocator.free(popped); const kv_items = try makeKVList(allocator, popped); - break :blk .{ .seq = .{ .dict, kv_items } }; + break :blk .{ .seq = .{ .type = .dict, .values = kv_items } }; }), - .list => try stack.values.append(.{ .seq = .{ .list, try stack.popMark(allocator) } }), + .list => try stack.values.append(.{ .seq = .{ .type = .list, .values = try stack.popMark(allocator) } }), .inst => |v| try stack.values.append(blk: { const tup_items = try allocator.dupe(Value, &.{ .{ .string = v[0] }, .{ .string = v[1] } }); - break :blk .{ .object = try Object.init(allocator, .{ .seq = .{ .tuple, tup_items } }, try stack.popMark(allocator)) }; + break :blk .{ .object = try Object.init(allocator, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try stack.popMark(allocator)) }; }), .obj => try stack.values.append(blk: { const markidx = try stack.findMark(); @@ -418,7 +418,7 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs: args[0] = try stack.pop(); break :blk .{ .object = try Object.init(allocator, try stack.pop(), args) }; }), - .empty_set => try stack.values.append(.{ .seq = .{ .set, &[_]Value{} } }), + .empty_set => try stack.values.append(.{ .seq = .{ .type = .set, .values = &[_]Value{} } }), .additems => { const postmark = try stack.popMark(allocator); defer allocator.free(postmark); @@ -431,19 +431,19 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs: @memcpy(obj.args[obj_len..], postmark); }, .seq => |*tup| { - const tup_len = tup[1].len; - tup[1] = try assuredResize(Value, allocator, tup[1], tup_len + postmark.len); - @memcpy(tup[1][tup_len..], postmark); + const tup_len = tup.values.len; + tup.values = try assuredResize(Value, allocator, tup.values, tup_len + postmark.len); + @memcpy(tup.values[tup_len..], postmark); }, else => { return error.BadStackTopForSetItem; }, } }, - .frozenset => try stack.values.append(.{ .seq = .{ .frozen_set, try stack.popMark(allocator) } }), + .frozenset => try stack.values.append(.{ .seq = .{ .type = .frozen_set, .values = try stack.popMark(allocator) } }), .newobj_ex => try stack.values.append(blk: { const kwargs, const args, const cls = .{ try stack.pop(), try stack.pop(), try stack.pop() }; - const new_seq: Sequence = .{ .tuple, try allocator.dupe(Value, &.{ args, kwargs }) }; + const new_seq: Sequence = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ args, kwargs }) }; break :blk .{ .object = try Object.init(allocator, cls, try allocator.dupe(Value, &.{.{ .seq = new_seq }})) }; }), .stack_global => try stack.values.append(blk: { @@ -451,7 +451,7 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs: try memo.resolve(allocator, try stack.pop(), true), try memo.resolve(allocator, try stack.pop(), true), }; - const new_seq: Sequence = .{ .tuple, try allocator.dupe(Value, &.{ gn, mn }) }; + const new_seq: Sequence = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ gn, mn }) }; break :blk .{ .object = try Object.init(allocator, .{ .seq = new_seq }, &[_]Value{}) }; }), .memoize => { diff --git a/zml/aio/torch/parser.zig b/zml/aio/torch/parser.zig index 5075299..e7a574e 100644 --- a/zml/aio/torch/parser.zig +++ b/zml/aio/torch/parser.zig @@ -388,6 +388,7 @@ const TarStream = struct { }; test "Read pickle (simple)" { + const Value = @import("value.zig").Value; var arena = std.heap.ArenaAllocator.init(testing.allocator); defer arena.deinit(); const allocator = arena.allocator(); @@ -402,64 +403,36 @@ test "Read pickle (simple)" { try testing.expect(vals.stack.len == 2); // skip first value (frame) try testing.expect(vals.stack[1] == .seq); - try testing.expect(vals.stack[1].seq[0] == .dict); - const entries = vals.stack[1].seq[1][0].seq[1]; + try testing.expect(vals.stack[1].seq.type == .dict); + const entries = vals.stack[1].seq.values[0].seq.values; try testing.expect(entries.len == 5); - for (entries, 0..) |kv, i| { - try testing.expect(kv == .seq); - try testing.expect(kv.seq[0] == .kv_tuple); - switch (i) { - 0 => { - const key = kv.seq[1][0]; - try testing.expect(key == .string); - try testing.expectEqualStrings("hello", key.string); - const value = kv.seq[1][1]; - try testing.expect(value == .string); - try testing.expectEqualStrings("world", value.string); - }, - 1 => { - const key = kv.seq[1][0]; - try testing.expect(key == .string); - try testing.expectEqualStrings("int", key.string); - const value = kv.seq[1][1]; - try testing.expect(value == .int); - try testing.expect(value.int == 1); - }, - 2 => { - const key = kv.seq[1][0]; - try testing.expect(key == .string); - try testing.expectEqualStrings("float", key.string); - const value = kv.seq[1][1]; - try testing.expect(value == .float); - try testing.expectEqual(@as(f64, 3.141592), value.float); - }, - 3 => { - const key = kv.seq[1][0]; - try testing.expect(key == .string); - try testing.expectEqualStrings("list", key.string); - const value = kv.seq[1][1]; - try testing.expect(value == .seq); - try testing.expect(value.seq[0] == .list); - for (value.seq[1], 0..) |item, j| { - try testing.expect(item == .int); - try testing.expect(item.int == @as(i64, @intCast(j))); - } - }, - 4 => { - const key = kv.seq[1][0]; - try testing.expect(key == .string); - try testing.expectEqualStrings("tuple", key.string); - const value = kv.seq[1][1]; - try testing.expect(value == .seq); - try testing.expect(value.seq[0] == .tuple); - try testing.expect(value.seq[1][0] == .string); - try testing.expectEqualStrings("a", value.seq[1][0].string); - try testing.expect(value.seq[1][1] == .int); - try testing.expect(value.seq[1][1].int == 10); - }, - else => unreachable, - } - } + const expected: []const Value = &.{ + .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "hello" }, .{ .string = "world" } })) } }, + .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "int" }, .{ .int64 = 1 } })) } }, + .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "float" }, .{ .float64 = 3.141592 } })) } }, + .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ + .{ .string = "list" }, + .{ .seq = .{ .type = .list, .values = @constCast(@as([]const Value, &.{ + .{ .int64 = 0 }, + .{ .int64 = 1 }, + .{ .int64 = 2 }, + .{ .int64 = 3 }, + .{ .int64 = 4 }, + })) } }, + })) } }, + .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ + .{ .string = "tuple" }, + .{ .seq = .{ + .type = .tuple, + .values = @constCast(@as([]const Value, &.{ + .{ .string = "a" }, + .{ .int64 = 10 }, + })), + } }, + })) } }, + }; + + try std.testing.expectEqualDeep(expected, entries); } test "Read pickle (zipped)" { diff --git a/zml/aio/torch/value.zig b/zml/aio/torch/value.zig index 302f85e..9de3cc2 100644 --- a/zml/aio/torch/value.zig +++ b/zml/aio/torch/value.zig @@ -65,7 +65,10 @@ pub const Build = struct { } }; -pub const Sequence = struct { SequenceType, []Value }; +pub const Sequence = struct { + type: SequenceType, + values: []Value, +}; pub const PersId = struct { allocator: std.mem.Allocator, @@ -100,11 +103,11 @@ pub const ValueType = enum { seq, string, bytes, - int, + int64, bigint, - float, + float64, raw_num, - bool, + boolval, none, }; @@ -160,7 +163,7 @@ pub const Value = union(ValueType) { /// An integer, but not the crazy kind that comes as a string /// that has to be parsed. You can look in `Value.raw_num` for /// those. - int: i64, + int64: i64, /// An integer that can't fit in i64. bigint: big_int.Managed, @@ -168,13 +171,13 @@ pub const Value = union(ValueType) { /// An float, but not the crazy kind that comes as a string /// that has to be parsed. You can look in `Value.raw_num` for /// those. - float: f64, + float64: f64, /// Some kind of weird number we can't handle. raw_num: PickleOp, /// A boolean value. - bool: bool, + boolval: bool, /// Python `None`. none: void, @@ -184,8 +187,8 @@ pub const Value = union(ValueType) { .raw, .raw_num => |v| v.deinit(allocator), inline .app, .object, .global, .build, .pers_id => |v| v.deinit(), .seq => |v| { - for (v[1]) |*val| val.deinit(allocator); - allocator.free(v[1]); + for (v.values) |*val| val.deinit(allocator); + allocator.free(v.values); }, .string, .bytes => |v| allocator.free(v), .bigint => self.bigint.deinit(), @@ -205,7 +208,7 @@ pub const Value = union(ValueType) { try writeIndents(indents + 1, writer); try writer.print(".{s} = ", .{@tagName(std.meta.activeTag(value))}); switch (value) { - inline .ref, .int, .float => |v| try writer.print("{d} ", .{v}), + inline .ref, .int64, .float64 => |v| try writer.print("{d} ", .{v}), .app, .object, .global => |v| { try writer.writeAll(".{\n"); try internalFormat(v.member, indents + 2, writer); @@ -242,13 +245,13 @@ pub const Value = union(ValueType) { .seq => |v| { try writer.writeAll(".{\n"); try writeIndents(indents + 2, writer); - try writer.print(".{s},\n", .{@tagName(v[0])}); + try writer.print(".{s},\n", .{@tagName(v.type)}); try writeIndents(indents + 2, writer); - if (v[1].len > 0) { + if (v.values.len > 0) { try writer.writeAll(".{\n"); - for (v[1], 0..) |arg, i| { + for (v.values, 0..) |arg, i| { try internalFormat(arg, indents + 3, writer); - if (i < v[1].len - 1) try writer.writeAll(","); + if (i < v.values.len - 1) try writer.writeAll(","); try writer.writeByte('\n'); } try writeIndents(indents + 2, writer); @@ -283,8 +286,8 @@ pub const Value = union(ValueType) { inline .raw, .raw_num => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)), inline .app, .object, .global, .build, .pers_id => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)), .seq => |seq| blk: { - const new_val: Sequence = .{ seq[0], try allocator.alloc(Value, seq[1].len) }; - for (seq[1], 0..) |v, i| new_val[1][i] = try v.clone(allocator); + const new_val: Sequence = .{ .type = seq.type, .values = try allocator.alloc(Value, seq.values.len) }; + for (seq.values, 0..) |v, i| new_val.values[i] = try v.clone(allocator); break :blk .{ .seq = new_val }; }, inline .string, .bytes => |v, tag| @unionInit(Value, @tagName(tag), try allocator.dupe(u8, v)), @@ -295,8 +298,8 @@ pub const Value = union(ValueType) { pub fn isPrimitive(self: Value) bool { return switch (self) { - .int, .bigint, .float, .string, .bytes, .bool, .none => true, - .seq => |seq| utils.allTrue(seq[1], Value.isPrimitive), + .int64, .bigint, .float64, .string, .bytes, .boolval, .none => true, + .seq => |seq| utils.allTrue(seq.values, Value.isPrimitive), else => false, }; } @@ -316,7 +319,7 @@ pub const Value = union(ValueType) { }, .pers_id => |v| return v.ref.containsRef(), .seq => |v| { - for (v[1]) |val| if (val.containsRef()) return true; + for (v.values) |val| if (val.containsRef()) return true; return false; }, else => return false, @@ -336,7 +339,7 @@ pub const Value = union(ValueType) { pub fn coerceFromRaw(self: Value, allocator: std.mem.Allocator) !Value { return switch (self) { .raw => |raw_val| switch (raw_val) { - .binint, .binint1, .binint2 => |val| .{ .int = val }, + .binint, .binint1, .binint2 => |val| .{ .int64 = val }, .long1, .long4 => |b| if (b.len != 0) { var bint = try big_int.Managed.initCapacity(allocator, std.math.big.int.calcTwosCompLimbCount(b.len)); var mutable = bint.toMutable(); @@ -345,10 +348,10 @@ pub const Value = union(ValueType) { const max_comp = bint.toConst().order(BI64MAX); if ((min_comp == .gt or min_comp == .eq) and (max_comp == .lt or max_comp == .eq)) { defer bint.deinit(); - return .{ .int = try bint.to(i64) }; + return .{ .int64 = try bint.to(i64) }; } else return .{ .bigint = bint }; } else .{ .raw_num = raw_val }, - .binfloat => |val| .{ .float = val }, + .binfloat => |val| .{ .float64 = val }, .binunicode, .binunicode8, .short_binunicode => |s| .{ .string = s }, .binbytes, .binbytes8, .short_binbytes, .bytearray8 => |b| .{ .bytes = b }, // This isn't how Pickle actually works but we just try to UTF8 decode the @@ -356,17 +359,17 @@ pub const Value = union(ValueType) { // actually cares they can just fix values themselves or recover the raw bytes // from the UTF8 string (it's guaranteed to be reversible, as far as I know). .binstring, .short_binstring => |b| if (std.unicode.utf8ValidateSlice(b)) .{ .string = b } else .{ .bytes = b }, - .newtrue => .{ .bool = true }, - .newfalse => .{ .bool = false }, + .newtrue => .{ .boolval = true }, + .newfalse => .{ .boolval = false }, .none => .{ .none = {} }, inline .int, .float, .long, => |v, tag| { if (tag == .int and std.mem.eql(u8, v, "01")) { - return .{ .bool = true }; + return .{ .boolval = true }; } else if (tag == .int and std.mem.eql(u8, v, "00")) { - return .{ .bool = false }; + return .{ .boolval = false }; } else { return .{ .raw_num = raw_val }; } @@ -389,8 +392,8 @@ pub const Value = union(ValueType) { v.ref = try v.ref.coerceFromRaw(allocator); break :blk self; }, - .seq => |*v| blk: { - for (v[1]) |*val| { + .seq => |v| blk: { + for (v.values) |*val| { val.* = try val.coerceFromRaw(allocator); } break :blk self;