From 0189b7107049c2502538507bfdb6b57f94160deb Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Fri, 7 Apr 2023 16:45:58 +0000 Subject: [PATCH] Rename `zml.aio.Value` to `zml.aio.Metadata`, simplify its type variants, and update torch pickle/eval APIs accordingly. --- zml/aio.zig | 116 +++- zml/aio/gguf.zig | 36 +- zml/aio/gguf/core.zig | 70 +-- zml/aio/json.zig | 24 +- zml/aio/nemo.zig | 4 +- zml/aio/tinyllama.zig | 20 +- zml/aio/torch.zig | 34 +- zml/aio/torch/eval.zig | 374 +++++++------ zml/aio/torch/parser.zig | 248 +-------- zml/aio/torch/pickle.zig | 1100 +++++++++++++++++++++++++++++++------- zml/aio/torch/value.zig | 38 +- zml/aio/yaml.zig | 10 +- zml/meta.zig | 2 + 13 files changed, 1307 insertions(+), 769 deletions(-) diff --git a/zml/aio.zig b/zml/aio.zig index 61d928d..50ccc22 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -14,7 +14,6 @@ pub const torch = @import("aio/torch.zig"); pub const yaml = @import("aio/yaml.zig"); pub const log = std.log.scoped(.zml_aio); -pub const Value = @import("aio/value.zig").Value; const HostBuffer = @import("hostbuffer.zig").HostBuffer; test { @@ -56,7 +55,7 @@ pub fn detectFormatAndLoadTokenizer(allocator: std.mem.Allocator, tokenizer_path else if (std.mem.endsWith(u8, tokenizer_path, ".tinyllama")) try zml.aio.tinyllama.loadTokenizer(allocator, tokenizer_path, 32000) else { - zml.log.err("Failed to recognized tokenizer format of: {s}", .{tokenizer_path}); + log.err("Failed to recognized tokenizer format of: {s}", .{tokenizer_path}); return error.FormatNotRecognized; }; } @@ -87,8 +86,8 @@ pub fn populateModelWithPrefix(comptime Model: type, allocator: std.mem.Allocato try prefix_builder.push(allocator, prefix); defer prefix_builder.deinit(allocator); - var unique_id = zml.Tensor.reserveIdRange(@intCast(store.buffers.count())); - const ok = _populateStruct(allocator, &prefix_builder, &unique_id, store, &model, true) catch |err| { + const unique_id = zml.Tensor.reserveIdRange(@intCast(store.buffers.count())); + const ok = 
_populateStruct(allocator, &prefix_builder, unique_id, store, &model, true) catch |err| { std.debug.panic("Can't populate model of type {s}: {s}", .{ @typeName(type), @errorName(err) }); }; if (!ok) return error.TensorNotFound; @@ -98,12 +97,12 @@ pub fn populateModelWithPrefix(comptime Model: type, allocator: std.mem.Allocato /// A struct containing all the buffers and metadata found in a model file. pub const BufferStore = struct { pub const Buffers = std.StringArrayHashMapUnmanaged(HostBuffer); - pub const Metadata = std.StringArrayHashMapUnmanaged(Value); + pub const Metadatas = std.StringArrayHashMapUnmanaged(Metadata); arena: std.heap.ArenaAllocator, files: []MemoryMappedFile = &.{}, buffers: Buffers = .{}, - _metadata: Metadata = .{}, + _metadata: Metadatas = .{}, pub fn deinit(self: BufferStore) void { for (self.files) |*file| file.deinit(); @@ -135,7 +134,7 @@ pub const BufferStore = struct { return if (maybe_max_index) |index| index + 1 else 0; } - pub fn metadata(self: BufferStore, key: []const u8, comptime tag: std.meta.FieldEnum(Value)) ?std.meta.FieldType(Value, tag) { + pub fn metadata(self: BufferStore, key: []const u8, comptime tag: std.meta.FieldEnum(Metadata)) ?std.meta.FieldType(Metadata, tag) { const wrapped_value = self._metadata.get(key) orelse return null; if (wrapped_value != tag) { @@ -145,14 +144,86 @@ pub const BufferStore = struct { return @field(wrapped_value, @tagName(tag)); } - pub fn metadataSlice(self: BufferStore, key: []const u8, comptime tag: Value.Slice.ItemType) ?[]const Value.Slice.toZigType(tag) { + pub fn metadataSlice(self: BufferStore, key: []const u8, comptime tag: Metadata.ItemType) ?[]const tag.toZigType() { const wrapped_value = self._metadata.get(key) orelse return null; - - if (wrapped_value != .array or wrapped_value.array.item_type != tag) { - return null; + const true_tag = std.meta.stringToEnum(std.meta.FieldEnum(Metadata), @tagName(tag)).?; + if (wrapped_value == true_tag) { + return @field(wrapped_value, 
"array_" ++ @tagName(tag)); + } + + return null; + } +}; + +pub const Metadata = union(enum) { + null: void, + int: i64, + float: f64, + bool: bool, + string: []const u8, + + array_bool: []const bool, + array_int: []const i64, + array_float: []const f64, + array_string: []const []const u8, + + pub const ItemType = enum { + int, + float, + bool, + string, + + pub fn toZigType(comptime kind: ItemType) type { + return switch (kind) { + .int => i64, + .float => f64, + .bool => bool, + .string => []const u8, + }; + } + }; + + pub fn wrap(x: anytype) Metadata { + return switch (@TypeOf(x)) { + inline u8, i8, u16, i16, u32, i32, u64, i64 => .{ .int = @intCast(x) }, + inline f16, f32, f64 => .{ .float = @floatCast(x) }, + bool => .{ .bool = x }, + []const u8 => .{ .string = x }, + else => @panic("Unsupported type for zml.aio.Value: " ++ @typeName(@TypeOf(x))), + }; + } + + pub fn copySlice(allocator: std.mem.Allocator, any_slice: anytype) !Metadata { + return switch (@TypeOf(any_slice[0])) { + inline u8, i8, u16, i16, u32, i32, u64, i64 => { + const res = try allocator.alloc(i64, any_slice.len); + for (res, any_slice) |*r, val| r.* = @intCast(val); + return .{ .array_int = res }; + }, + inline f16, f32, f64 => { + const res = try allocator.alloc(f64, any_slice.len); + for (res, any_slice) |*r, val| r.* = @floatCast(val); + return .{ .array_float = res }; + }, + bool => .{ .array_bool = try allocator.dupe(bool, any_slice) }, + []const u8 => .{ .array_string = try allocator.dupe([]const u8, @alignCast(any_slice)) }, + else => @panic("Unsupported type for zml.aio.Value: " ++ @typeName(@TypeOf(any_slice))), + }; + } + + pub fn format( + self: Metadata, + comptime fmt: []const u8, + options: std.fmt.FormatOptions, + writer: anytype, + ) !void { + _ = fmt; + _ = options; + switch (self) { + .null => _ = try writer.write("null"), + inline .bool, .array_bool => |b| try writer.print("{any}", .{b}), + inline else => |v| try writer.print("{d}", .{v}), } - const T = 
Value.Slice.toZigType(tag); - return @alignCast(std.mem.bytesAsSlice(T, wrapped_value.array.data)); } }; @@ -244,7 +315,7 @@ const PrefixBuilder = struct { fn _populateStruct( allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, - unique_id: *u64, + unique_id: u64, buffer_store: BufferStore, obj: anytype, required: bool, @@ -260,17 +331,17 @@ fn _populateStruct( const prefix = prefix_builder.data.items; if (T == zml.Tensor) { - return if (buffer_store.get(prefix)) |buffer| { + return if (buffer_store.buffers.getIndex(prefix)) |entry_idx| { + const buffer = buffer_store.get(prefix).?; obj.* = zml.Tensor{ ._shape = buffer.shape(), - ._id = .{ .buffer_id = unique_id.* }, + ._id = .{ .buffer_id = unique_id + entry_idx }, ._donation = .input_buffer, }; - unique_id.* += 1; return true; } else { if (required) { - std.log.err("Tensor not found: {s} ({d})", .{ prefix, buffer_store.buffers.count() }); + log.err("Tensor not found: {s} ({d})", .{ prefix, buffer_store.buffers.count() }); } return false; }; @@ -290,7 +361,7 @@ fn _populateStruct( defer prefix_builder.pop(); const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required); if (!found) { - std.log.err("Not able to load {s} as {s}", .{ prefix, @typeName(ptr_info.child) }); + log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(ptr_info.child) }); return false; } } @@ -299,7 +370,7 @@ fn _populateStruct( } return true; } else { - std.log.err("{s} - {s}: {s} type not supported", .{ @src().fn_name, prefix, @typeName(T) }); + log.err("{s} - {s}: {s} type not supported", .{ @src().fn_name, prefix, @typeName(T) }); return false; } }, @@ -346,7 +417,7 @@ fn _populateStruct( }, .Void => true, else => if (required) { - std.log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) }); + log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) }); return error.UnsupportedMetadataType; } else return false, }; @@ -431,7 +502,7 @@ pub fn 
loadBuffers( zml.meta.assertComptime(@TypeOf(init_args) == void or @TypeOf(init_args) == @TypeOf(.{}), "Model of type {} has no init function, so `loadBuffers` should be call with init_args set to {{}} (void)", .{Model}); } - return loadModelBuffers(Model, model, buffer_store, allocator, platform); + return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, ""); } /// Creates a bufferized version of a Model from the given BufferStore. For details about @@ -449,6 +520,7 @@ pub fn loadModelBuffers( ) !zml.Bufferized(Model) { return try loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, ""); } + /// Creates a bufferized version of a Model from the given BufferStore and the given prefix. /// For details about bufferization, see the documentation of Bufferized(T). /// diff --git a/zml/aio/gguf.zig b/zml/aio/gguf.zig index 05a6c87..9740871 100644 --- a/zml/aio/gguf.zig +++ b/zml/aio/gguf.zig @@ -36,24 +36,24 @@ pub fn getGgufTokenizer(self: zml.aio.BufferStore, allocator: std.mem.Allocator) log.err("GGUF File: Tokens not found", .{}); return error.TokensNotFound; }; - const scores = self.metadataSlice("tokenizer.ggml.scores", .float32) orelse { + const scores = self.metadataSlice("tokenizer.ggml.scores", .float) orelse { log.err("GGUF File: Scores not found", .{}); return error.ScoresNotFound; }; assert(tokens.len == scores.len); const tokenizer_type = self.metadata("tokenizer.ggml.model", .string) orelse "llama"; const tokenizer_impl: zml.tokenizer.KnownImplementation = if (std.mem.eql(u8, tokenizer_type, "gpt2")) .gpt2 else .sentencepiece; - const bos = self.metadata("tokenizer.ggml.bos_token_id", .uint32); - const eos = self.metadata("tokenizer.ggml.eos_token_id", .uint32); - const unk = self.metadata("tokenizer.ggml.unknown_token_id", .uint32); - const pad = self.metadata("tokenizer.ggml.padding_token_id", .uint32); + const bos = self.metadata("tokenizer.ggml.bos_token_id", .int); + const eos = 
self.metadata("tokenizer.ggml.eos_token_id", .int); + const unk = self.metadata("tokenizer.ggml.unknown_token_id", .int); + const pad = self.metadata("tokenizer.ggml.padding_token_id", .int); const NOT_FOUND = std.math.maxInt(u32); const special_tokens: zml.tokenizer.Tokenizer.SpecialTokens = .{ - .bos = bos.?, - .eos = eos.?, - .unk = unk orelse NOT_FOUND, - .pad = pad orelse NOT_FOUND, + .bos = @intCast(bos.?), + .eos = @intCast(eos.?), + .unk = @intCast(unk orelse NOT_FOUND), + .pad = @intCast(pad orelse NOT_FOUND), }; const gguf_normalizer = if (tokenizer_impl == .gpt2) @@ -85,10 +85,10 @@ pub fn getGgufTokenizer(self: zml.aio.BufferStore, allocator: std.mem.Allocator) for (tokens, 0..tokens.len) |t, i| { if (tokenizer_impl == .gpt2) { decoded.clearRetainingCapacity(); - try tokenizer.addToken(scores[i], try gpt2_unicode.?.decode(&decoded, t)); + try tokenizer.addToken(@floatCast(scores[i]), try gpt2_unicode.?.decode(&decoded, t)); // log.debug("token: {s} -> {s}", .{t, decoded.items}); } else { - try tokenizer.addToken(scores[i], t); + try tokenizer.addToken(@floatCast(scores[i]), t); } } @@ -112,7 +112,19 @@ fn loadMetadata(allocator: Allocator, store: *zml.aio.BufferStore, file: *core.G log.warn("Found duplicated metadata key: {s}", .{entry.name}); continue; } - res.value_ptr.* = entry.val.asLoaderValue(); + res.value_ptr.* = switch (entry.val) { + .array => |arr| switch (arr.child) { + inline .uint8, .int8, .uint16, .int16, .uint32, .int32, .float32, .bool, .string, .uint64, .int64, .float64 => |tag| blk: { + const T = std.meta.FieldType(core.GgufValue, tag); + break :blk try zml.aio.Metadata.copySlice(allocator, std.mem.bytesAsSlice(T, arr.data)); + }, + else => blk: { + log.warn("ignoring array metadata", .{}); + break :blk .null; + }, + }, + inline else => |v| zml.aio.Metadata.wrap(v), + }; } else |err| switch (err) { error.EndOfMetadata => {}, else => return err, diff --git a/zml/aio/gguf/core.zig b/zml/aio/gguf/core.zig index a7feadf..c8ed6a2 100644 
--- a/zml/aio/gguf/core.zig +++ b/zml/aio/gguf/core.zig @@ -176,20 +176,20 @@ pub const GgufValueType = enum(u32) { } }; -pub const ValueType = enum { - uint8, - int8, - uint16, - int16, - uint32, - int32, - float32, - uint64, - int64, - float64, - boolval, - string, - array, +pub const ValueType = enum(u8) { + uint8 = 0, + int8 = 1, + uint16 = 2, + int16 = 3, + uint32 = 4, + int32 = 5, + float32 = 6, + bool = 7, + string = 8, + array = 9, + uint64 = 10, + int64 = 11, + float64 = 12, }; // Union of possible values. @@ -201,47 +201,20 @@ pub const GgufValue = union(ValueType) { uint32: u32, int32: i32, float32: f32, + bool: bool, + string: []const u8, + array: Array, uint64: u64, int64: i64, float64: f64, - boolval: bool, - string: []const u8, - array: Array, pub const Array = struct { // Any value type is valid, including arrays. - child: GgufValueType, + child: ValueType, // Number of elements, not bytes len: usize, data: []u8, }; - - pub fn asLoaderValue(self: GgufValue) zml.aio.Value { - return switch (self) { - .array => |v| .{ - .array = .{ - .item_type = switch (v.child) { - .bool => .boolval, - .uint8 => .uint8, - .int8 => .int8, - .uint16 => .uint16, - .int16 => .int16, - .uint32 => .uint32, - .int32 => .int32, - .float32 => .float32, - .uint64 => .uint64, - .int64 => .int64, - .float64 => .float64, - .string => .string, - // TODO: .array => .array, - else => unreachable, - }, - .data = v.data, - }, - }, - inline else => |v, tag| @unionInit(zml.aio.Value, @tagName(tag), v), - }; - } }; // Header @@ -403,6 +376,9 @@ pub const GgufFile = struct { fn readArrayHeader(self: *GgufFile, allocator: std.mem.Allocator) !GgufValue.Array { const child = try self.readValueType(); + if (@intFromEnum(child) > @intFromEnum(ValueType.float64)) { + return error.UnsupportedGgufType; + } const len: usize = try self.readInt(u64); const data = switch (child) { // Since strings have variable lenghts, we need to read them one by one @@ -414,7 +390,7 @@ pub const GgufFile = struct 
{ else => try self.readAlloc(allocator, len * child.sizeOf()), }; return .{ - .child = child, + .child = @enumFromInt(@intFromEnum(child)), .len = len, .data = data, }; @@ -429,7 +405,7 @@ pub const GgufFile = struct { .uint32 => .{ .uint32 = try self.readInt(u32) }, .int32 => .{ .int32 = try self.readInt(i32) }, .float32 => .{ .float32 = @bitCast(try self.readInt(u32)) }, - .bool => .{ .boolval = try self.readInt(u8) != 0 }, + .bool => .{ .bool = try self.readInt(u8) != 0 }, .string => .{ .string = try self.readString(allocator) }, .array => .{ .array = try self.readArrayHeader(allocator) }, .uint64 => .{ .uint64 = try self.readInt(u64) }, diff --git a/zml/aio/json.zig b/zml/aio/json.zig index 0c1634a..9c02a5d 100644 --- a/zml/aio/json.zig +++ b/zml/aio/json.zig @@ -30,44 +30,40 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, prefix: const metadata = &store._metadata; const key = prefix.items; return switch (val) { - .null => try metadata.put(allocator, try allocator.dupe(u8, key), .{ .null = {} }), - .bool => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .boolval = v }), - .integer => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .int64 = v }), - .float => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .float64 = v }), + .null => try metadata.put(allocator, try allocator.dupe(u8, key), .null), + .bool => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .bool = v }), + .integer => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .int = v }), + .float => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .float = v }), .number_string, .string => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .string = try allocator.dupe(u8, v) }), .array => |v| { if (v.items.len == 0) return; return if (validSlice(v)) |item_type| { - const data, const dtype: zml.aio.Value.Slice.ItemType = switch (item_type) { + const data: zml.aio.Metadata = 
switch (item_type) { .bool => blk: { const values = try allocator.alloc(bool, v.items.len); for (v.items, 0..) |item, i| values[i] = item.bool; - break :blk .{ std.mem.sliceAsBytes(values), .boolval }; + break :blk .{ .array_bool = values }; }, .integer => blk: { const values = try allocator.alloc(i64, v.items.len); for (v.items, 0..) |item, i| values[i] = item.integer; - break :blk .{ std.mem.sliceAsBytes(values), .int64 }; + break :blk .{ .array_int = values }; }, .float => blk: { const values = try allocator.alloc(f64, v.items.len); for (v.items, 0..) |item, i| values[i] = item.float; - break :blk .{ std.mem.sliceAsBytes(values), .float64 }; + break :blk .{ .array_float = values }; }, inline .string, .number_string => |tag| blk: { const values = try allocator.alloc([]const u8, v.items.len); for (v.items, 0..) |item, i| { values[i] = @field(item, @tagName(tag)); } - break :blk .{ std.mem.sliceAsBytes(values), .string }; + break :blk .{ .array_string = values }; }, .null, .array, .object => unreachable, }; - try metadata.put( - allocator, - try allocator.dupe(u8, key), - .{ .array = .{ .item_type = dtype, .data = data } }, - ); + try metadata.put(allocator, try allocator.dupe(u8, key), data); } else { for (v.items, 0..) 
|item, i| { var new_prefix = prefix; diff --git a/zml/aio/nemo.zig b/zml/aio/nemo.zig index feacc63..3e7039d 100644 --- a/zml/aio/nemo.zig +++ b/zml/aio/nemo.zig @@ -39,10 +39,10 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore const start = try mapped_file.file.getPos(); var tmp: zml.aio.torch.PickleData = .{ .data = try parser.Parser.fromTarFile(arena, mapped_file, file), - .memo = undefined, .stack = undefined, }; - tmp.stack, tmp.memo = try eval.evaluate(arena, tmp.data.ops, true); + tmp.stack = try eval.evaluate(arena, tmp.data.ops, true); + try tmp.parseModel(arena, &res); // Since we directly manipulate the file handle pointer, // reset to the end of file so iterator does not error diff --git a/zml/aio/tinyllama.zig b/zml/aio/tinyllama.zig index 3f65149..499aff6 100644 --- a/zml/aio/tinyllama.zig +++ b/zml/aio/tinyllama.zig @@ -91,17 +91,17 @@ pub fn open(allocator: std.mem.Allocator, model_path: []const u8) !zml.aio.Buffe { try res._metadata.ensureUnusedCapacity(arena, 11); - res._metadata.putAssumeCapacityNoClobber("dim", .{ .int64 = c.dim }); - res._metadata.putAssumeCapacityNoClobber("hidden_dim", .{ .int64 = c.hidden_dim }); - res._metadata.putAssumeCapacityNoClobber("n_layers", .{ .int64 = c.n_layers }); - res._metadata.putAssumeCapacityNoClobber("num_heads", .{ .int64 = c.n_heads }); - res._metadata.putAssumeCapacityNoClobber("num_kv_heads", .{ .int64 = c.n_kv_heads }); - res._metadata.putAssumeCapacityNoClobber("vocab_size", .{ .int64 = c.vocab.size }); - res._metadata.putAssumeCapacityNoClobber("has_lm_head", .{ .boolval = c.vocab.has_lm_head }); - res._metadata.putAssumeCapacityNoClobber("max_seq_len", .{ .int64 = c.seq_len }); + res._metadata.putAssumeCapacityNoClobber("dim", .{ .int = c.dim }); + res._metadata.putAssumeCapacityNoClobber("hidden_dim", .{ .int = c.hidden_dim }); + res._metadata.putAssumeCapacityNoClobber("n_layers", .{ .int = c.n_layers }); + res._metadata.putAssumeCapacityNoClobber("num_heads", 
.{ .int = c.n_heads }); + res._metadata.putAssumeCapacityNoClobber("num_kv_heads", .{ .int = c.n_kv_heads }); + res._metadata.putAssumeCapacityNoClobber("vocab_size", .{ .int = c.vocab.size }); + res._metadata.putAssumeCapacityNoClobber("has_lm_head", .{ .bool = c.vocab.has_lm_head }); + res._metadata.putAssumeCapacityNoClobber("max_seq_len", .{ .int = c.seq_len }); res._metadata.putAssumeCapacityNoClobber("rope_impl", .{ .string = "interleaved" }); - res._metadata.putAssumeCapacityNoClobber("rope_freq_base", .{ .float64 = 10_000 }); - res._metadata.putAssumeCapacityNoClobber("rms_norm_eps", .{ .float64 = 1e-6 }); + res._metadata.putAssumeCapacityNoClobber("rope_freq_base", .{ .float = 10_000 }); + res._metadata.putAssumeCapacityNoClobber("rms_norm_eps", .{ .float = 1e-6 }); } return res; diff --git a/zml/aio/torch.zig b/zml/aio/torch.zig index 395ae60..232513f 100644 --- a/zml/aio/torch.zig +++ b/zml/aio/torch.zig @@ -16,6 +16,7 @@ const StringBuilder = std.ArrayListUnmanaged(u8); const log = std.log.scoped(.zml_io); test { + std.testing.refAllDecls(@This()); std.testing.refAllDecls(eval); std.testing.refAllDecls(value); std.testing.refAllDecls(parser); @@ -35,22 +36,21 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore const tmp_alloc = arena.allocator(); const _parser = try parser.Parser.init(tmp_alloc, file); - const stack, const memo = try eval.evaluate(tmp_alloc, _parser.ops, true); + const stack = try eval.evaluate(tmp_alloc, _parser.ops, true); // But we create the HostBuffer objects inside the result BufferStore arena. 
var res: zml.aio.BufferStore = .{ .arena = std.heap.ArenaAllocator.init(allocator), }; res.files = try res.arena.allocator().dupe(zml.aio.MemoryMappedFile, &.{_parser.buffer_file}); - var tmp: PickleData = .{ .data = _parser, .memo = memo, .stack = stack }; + var tmp: PickleData = .{ .data = _parser, .stack = stack }; try tmp.parseModel(res.arena.allocator(), &res); return res; } // TODO: rename me to PytorchFile pub const PickleData = struct { - stack: eval.PickleStack, - memo: eval.PickleMemo, + stack: []const Value, data: parser.Parser, fn basicTypeCheck(object: *const value.Object, module: []const u8, class: []const u8) bool { @@ -63,7 +63,7 @@ pub const PickleData = struct { } pub fn parseModel(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore) !void { - for (self.stack.stack) |item| { + for (self.stack) |item| { var prefix_buf: [1024]u8 = undefined; try self.parseValue(allocator, store, StringBuilder.initBuffer(&prefix_buf), item); } @@ -147,7 +147,7 @@ pub const PickleData = struct { try store._metadata.put( allocator, try allocator.dupe(u8, prefix.items), - .{ .array = .{ .item_type = std.meta.stringToEnum(zml.aio.Value.Slice.ItemType, @tagName(tag)).?, .data = std.mem.sliceAsBytes(try values.toOwnedSlice(allocator)) } }, + try zml.aio.Metadata.copySlice(allocator, values.items), ); } else { for (values.items, 0..) 
|val, i| { @@ -156,7 +156,13 @@ pub const PickleData = struct { new_prefix.appendAssumeCapacity('.'); } new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); - try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), @unionInit(zml.aio.Value, @tagName(tag), val)); + const new_tag = switch (tag) { + .int64 => "int", + .float64 => "float", + .boolval => "bool", + else => unreachable, // we are already inside a switch + }; + try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), @unionInit(zml.aio.Metadata, new_tag, val)); } } }, @@ -212,15 +218,17 @@ pub const PickleData = struct { if (d.found_existing) { log.warn("Duplicate key: {s}", .{prefix.items}); allocator.free(key); - } else d.value_ptr.* = .{ .array = .{ .item_type = .uint8, .data = @constCast(val) } }; + } else d.value_ptr.* = .{ .string = val }; }, - inline .float64, .int64, .boolval, .bigint, .string => |val, tag| { + inline .float64, .int64, .boolval, .bigint, .string => |val| { const key = try allocator.dupe(u8, prefix.items); const d = try store._metadata.getOrPut(allocator, key); if (d.found_existing) { log.warn("Duplicate key: {s}", .{prefix.items}); allocator.free(key); - } else d.value_ptr.* = @unionInit(zml.aio.Value, @tagName(tag), val); + } else { + d.value_ptr.* = zml.aio.Metadata.wrap(val); + } }, else => {}, } @@ -248,7 +256,7 @@ pub const PickleData = struct { } const d = try allocator.alloc(i64, size.len); for (d, 0..) 
|*di, i| di.* = size[i].int64; - entry.value_ptr.* = .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(d) } }; + entry.value_ptr.* = .{ .array_int = d }; return true; } else if (basicTypeCheck(object, "fractions", "Fraction")) { const fraction_str = object.args[0].seq.values[0].string; @@ -256,12 +264,12 @@ pub const PickleData = struct { { var new_prefix = prefix; new_prefix.appendSliceAssumeCapacity(".numerator"); - try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) }); + try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) }); } { var new_prefix = prefix; new_prefix.appendSliceAssumeCapacity(".denominator"); - try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) }); + try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) }); } return true; } diff --git a/zml/aio/torch/eval.zig b/zml/aio/torch/eval.zig index d5070c8..175febe 100644 --- a/zml/aio/torch/eval.zig +++ b/zml/aio/torch/eval.zig @@ -159,84 +159,9 @@ pub const PickleMemo = struct { } }; -pub const InternalStack = struct { - allocator: std.mem.Allocator, - values: std.ArrayList(Value), - - pub fn init(allocator: std.mem.Allocator) InternalStack { - return .{ - .allocator = allocator, - .values = std.ArrayList(Value).init(allocator), - }; - } - - pub fn deinit(self: *InternalStack) void { - for (0..self.values.items.len) |i| self.values.items[i].deinit(self.allocator); - self.values.deinit(); - self.* = undefined; - } - - pub fn pop(self: *InternalStack) !Value { - if (self.values.items.len == 0) { - return error.StackUnderrun; - } - return self.values.pop(); - } - - pub fn popMark(self: *InternalStack, allocator: 
?std.mem.Allocator) ![]Value { - const markidx = try self.findMark(); - var postmark: []Value = &[_]Value{}; - if (allocator) |a| { - postmark = try a.alloc(Value, self.values.items.len - (markidx + 1)); - @memcpy(postmark, self.values.items[markidx + 1 ..]); - } - self.values.shrinkAndFree(markidx); - return postmark; - } - - pub fn lastMut(self: *InternalStack) !*Value { - if (self.values.items.len == 0) { - return error.UnexpectedEmptyStack; - } - return &self.values.items[self.values.items.len - 1]; - } - - pub fn findMark(self: *InternalStack) !usize { - const len = self.values.items.len; - for (0..len) |i| { - const idx = (len - 1) - i; - const val = self.values.items[idx]; - if (val == .raw and val.raw == .mark) { - return idx; - } - } - zml.log.warn("pytorch loader: missing mark", .{}); - return 0; - } - - pub fn toPickleStack(self: *InternalStack) !PickleStack { - return .{ .stack = try self.values.toOwnedSlice(), .allocator = self.allocator }; - } -}; - -pub const PickleStack = struct { - stack: []Value, - allocator: std.mem.Allocator, - - pub fn init(allocator: std.mem.Allocator, values: []Value) PickleStack { - return .{ .allocator = allocator, .stack = values }; - } - - pub fn deinit(self: *PickleStack) void { - for (self.stack) |*v| v.deinit(self.allocator); - self.allocator.free(self.stack); - } -}; - -pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) !struct { PickleStack, PickleMemo } { - var stack = InternalStack.init(allocator); - defer stack.deinit(); - var memo = PickleMemo.init(allocator); +pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) ![]const Value { + var stack = std.ArrayList(Value).init(arena); + var memo = PickleMemo.init(arena); errdefer memo.deinit(); const makeKVList = (struct { @@ -258,56 +183,53 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs } }).call; - outer: for (x) |op| { + for (x) |op| { switch (op) { - .mark => try 
stack.values.append(.{ .raw = op }), - .stop => break :outer, - .pop => _ = try stack.pop(), - .pop_mark => _ = try stack.popMark(allocator), - .dup => { - if (stack.values.getLastOrNull()) |item| { - try stack.values.append(try item.clone(allocator)); - } else { - return error.CannotDupEmptyStack; - } - }, - .persid => |v| try stack.values.append(.{ .pers_id = try PersId.init(allocator, .{ .string = try allocator.dupe(u8, v) }) }), - .binpersid => try stack.values.append(.{ .pers_id = try PersId.init(allocator, try stack.pop()) }), - .reduce => try stack.values.append(.{ .global = blk: { - const values = try allocator.alloc(Value, 1); - values[0] = try memo.resolve(allocator, try stack.pop(), true); - break :blk try Object.init(allocator, try memo.resolve(allocator, try stack.pop(), true), values); + .mark => try stack.append(.{ .raw = op }), + .frame => {}, + .stop => break, + .pop => _ = try pop(&stack), + .pop_mark => try popMarkDiscard(&stack), + .dup => if (stack.getLastOrNull()) |item| + try stack.append(try item.clone(arena)) + else + return error.CannotDupEmptyStack, + .persid => |v| try stack.append(.{ .pers_id = try PersId.init(arena, .{ .string = try arena.dupe(u8, v) }) }), + .binpersid => try stack.append(.{ .pers_id = try PersId.init(arena, try pop(&stack)) }), + .reduce => try stack.append(.{ .global = blk: { + const values = try arena.alloc(Value, 1); + values[0] = try memo.resolve(arena, try pop(&stack), true); + break :blk try Object.init(arena, try memo.resolve(arena, try pop(&stack), true), values); } }), - .build => try stack.values.append(blk: { - const args = try memo.resolve(allocator, try stack.pop(), true); - const member = try memo.resolve(allocator, try stack.pop(), true); - break :blk .{ .build = try Build.init(allocator, member, args) }; + .build => try stack.append(blk: { + const args = try memo.resolve(arena, try pop(&stack), true); + const member = try memo.resolve(arena, try pop(&stack), true); + break :blk .{ .build = try 
Build.init(arena, member, args) }; }), - .empty_dict => try stack.values.append(.{ .seq = .{ .type = .dict, .values = &[_]Value{} } }), - .get => |v| try stack.values.append(.{ .ref = try std.fmt.parseInt(u32, v, 10) }), - inline .binget, .long_binget => |v| try stack.values.append(.{ .ref = v }), - .empty_list => try stack.values.append(.{ .seq = .{ .type = .list, .values = &[_]Value{} } }), - .binput, .long_binput => |v| { - try memo.insert(v, try stack.pop()); - try stack.values.append(.{ .ref = v }); + .empty_dict => try stack.append(.{ .seq = .{ .type = .dict, .values = &[_]Value{} } }), + .get => |v| try stack.append(.{ .ref = v }), + .empty_list => try stack.append(.{ .seq = .{ .type = .list, .values = &[_]Value{} } }), + .put => |v| { + try memo.insert(v, try pop(&stack)); + try stack.append(.{ .ref = v }); }, - .tuple => try stack.values.append(blk: { - const popped = try stack.popMark(allocator); + .tuple => try stack.append(blk: { + const popped = try popMark(&stack, arena); break :blk .{ .seq = .{ .type = .tuple, .values = popped } }; }), - .empty_tuple => try stack.values.append(.{ .seq = .{ .type = .tuple, .values = &[_]Value{} } }), + .empty_tuple => try stack.append(.{ .seq = .{ .type = .tuple, .values = &[_]Value{} } }), .setitem => { - const v, const k = .{ try stack.pop(), try stack.pop() }; - const top = try stack.lastMut(); + const v, const k = .{ try pop(&stack), try pop(&stack) }; + const top = try lastMut(&stack); const rtop = try memo.resolveMut(top, true); switch (rtop.*) { .global => |obj| { - obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1); - obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ k, v }) } }; + obj.args = try assuredResize(Value, arena, obj.args, obj.args.len + 1); + obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ k, v }) } }; }, .seq => |*tup| { - tup.values = try assuredResize(Value, allocator, 
tup.values, tup.values.len + 1); - tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ k, v }) } }; + tup.values = try assuredResize(Value, arena, tup.values, tup.values.len + 1); + tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ k, v }) } }; }, else => { return error.BadStackTopForSetItem; @@ -315,53 +237,53 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs } }, .setitems => { - const popped = try stack.popMark(allocator); - defer allocator.free(popped); - const kv_items = try makeKVList(allocator, popped); - const top = try stack.lastMut(); + const popped = try popMark(&stack, arena); + defer arena.free(popped); + const kv_items = try makeKVList(arena, popped); + const top = try lastMut(&stack); const rtop = try memo.resolveMut(top, true); switch (rtop.*) { .global => |obj| { - obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1); + obj.args = try assuredResize(Value, arena, obj.args, obj.args.len + 1); obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } }; }, .seq => |*tup| { - tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1); + tup.values = try assuredResize(Value, arena, tup.values, tup.values.len + 1); tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } }; }, else => { - defer allocator.free(kv_items); + defer arena.free(kv_items); return error.BadStackTopForSetItems; }, } }, .proto => |proto| meta.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}), - .tuple1 => try stack.values.append(blk: { - const tup_values = try allocator.alloc(Value, 1); - tup_values[0] = try stack.pop(); + .tuple1 => try stack.append(blk: { + const tup_values = try arena.alloc(Value, 1); + tup_values[0] = try pop(&stack); break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; }), - .tuple2 => try 
stack.values.append(blk: { - const tup_values = try allocator.alloc(Value, 2); - inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop(); + .tuple2 => try stack.append(blk: { + const tup_values = try arena.alloc(Value, 2); + inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack); break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; }), - .tuple3 => try stack.values.append(blk: { - const tup_values = try allocator.alloc(Value, 3); - inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop(); + .tuple3 => try stack.append(blk: { + const tup_values = try arena.alloc(Value, 3); + inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack); break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; }), .append => { - const v = try stack.pop(); - const top = try stack.lastMut(); + const v = try pop(&stack); + const top = try lastMut(&stack); const rtop = try memo.resolveMut(top, true); switch (rtop.*) { .global => |obj| { - obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1); + obj.args = try assuredResize(Value, arena, obj.args, obj.args.len + 1); obj.args[obj.args.len - 1] = v; }, .seq => |*tup| { - tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1); + tup.values = try assuredResize(Value, arena, tup.values, tup.values.len + 1); tup.values[tup.values.len - 1] = v; }, else => { @@ -370,19 +292,19 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs } }, .appends => { - const postmark = try stack.popMark(allocator); - defer allocator.free(postmark); - const top = try stack.lastMut(); + const postmark = try popMark(&stack, arena); + defer arena.free(postmark); + const top = try lastMut(&stack); const rtop = try memo.resolveMut(top, true); switch (rtop.*) { .global => |obj| { const obj_len = obj.args.len; - obj.args = try assuredResize(Value, allocator, obj.args, obj_len + postmark.len); + 
obj.args = try assuredResize(Value, arena, obj.args, obj_len + postmark.len); @memcpy(obj.args[obj_len..], postmark); }, .seq => |*tup| { const tup_len = tup.values.len; - tup.values = try assuredResize(Value, allocator, tup.values, tup_len + postmark.len); + tup.values = try assuredResize(Value, arena, tup.values, tup_len + postmark.len); @memcpy(tup.values[tup_len..], postmark); }, else => { @@ -390,49 +312,43 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs }, } }, - .dict => try stack.values.append(blk: { - const popped = try stack.popMark(allocator); - defer allocator.free(popped); - const kv_items = try makeKVList(allocator, popped); + .dict => try stack.append(blk: { + const popped = try popMark(&stack, arena); + defer arena.free(popped); + const kv_items = try makeKVList(arena, popped); break :blk .{ .seq = .{ .type = .dict, .values = kv_items } }; }), - .list => try stack.values.append(.{ .seq = .{ .type = .list, .values = try stack.popMark(allocator) } }), - .inst => |v| try stack.values.append(blk: { - const tup_items = try allocator.dupe(Value, &.{ .{ .string = v.module }, .{ .string = v.class } }); - break :blk .{ .object = try Object.init(allocator, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try stack.popMark(allocator)) }; + .list => try stack.append(.{ .seq = .{ .type = .list, .values = try popMark(&stack, arena) } }), + .inst => |v| try stack.append(blk: { + const tup_items = try arena.dupe(Value, &.{ .{ .string = v.module }, .{ .string = v.class } }); + break :blk .{ .object = try Object.init(arena, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try popMark(&stack, arena)) }; }), - .obj => try stack.values.append(blk: { - const markidx = try stack.findMark(); - const args = try allocator.alloc(Value, stack.values.items.len - (markidx + 2)); - @memcpy(args, stack.values.items[markidx + 2 ..]); - const member = stack.values.items[markidx + 1]; - break :blk .{ .object = try Object.init(allocator, 
member, args) }; + .obj => try stack.append(blk: { + const mark = try findMark(&stack); + const args = try arena.dupe(Value, stack.items[mark + 2 ..]); + const member = stack.items[mark + 1]; + break :blk .{ .object = try Object.init(arena, member, args) }; }), - .put => |v| { - const mid = try std.fmt.parseInt(u32, v, 10); - try memo.insert(mid, try stack.pop()); - try stack.values.append(.{ .ref = mid }); - }, - .newobj => try stack.values.append(blk: { - const args = try allocator.alloc(Value, 1); - args[0] = try stack.pop(); - break :blk .{ .object = try Object.init(allocator, try stack.pop(), args) }; + .newobj => try stack.append(blk: { + const args = try arena.alloc(Value, 1); + args[0] = try pop(&stack); + break :blk .{ .object = try Object.init(arena, try pop(&stack), args) }; }), - .empty_set => try stack.values.append(.{ .seq = .{ .type = .set, .values = &[_]Value{} } }), + .empty_set => try stack.append(.{ .seq = .{ .type = .set, .values = &[_]Value{} } }), .additems => { - const postmark = try stack.popMark(allocator); - defer allocator.free(postmark); - const top = try stack.lastMut(); + const postmark = try popMark(&stack, arena); + defer arena.free(postmark); + const top = try lastMut(&stack); const rtop = try memo.resolveMut(top, true); switch (rtop.*) { .global => |obj| { const obj_len = obj.args.len; - obj.args = try assuredResize(Value, allocator, obj.args, obj_len + postmark.len); + obj.args = try assuredResize(Value, arena, obj.args, obj_len + postmark.len); @memcpy(obj.args[obj_len..], postmark); }, .seq => |*tup| { const tup_len = tup.values.len; - tup.values = try assuredResize(Value, allocator, tup.values, tup_len + postmark.len); + tup.values = try assuredResize(Value, arena, tup.values, tup_len + postmark.len); @memcpy(tup.values[tup_len..], postmark); }, else => { @@ -440,36 +356,33 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs }, } }, - .frozenset => try stack.values.append(.{ .seq = .{ .type = 
.frozen_set, .values = try stack.popMark(allocator) } }), - .newobj_ex => try stack.values.append(blk: { - const kwargs, const args, const cls = .{ try stack.pop(), try stack.pop(), try stack.pop() }; - const new_seq: Sequence = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ args, kwargs }) }; - break :blk .{ .object = try Object.init(allocator, cls, try allocator.dupe(Value, &.{.{ .seq = new_seq }})) }; + .frozenset => try stack.append(.{ .seq = .{ .type = .frozen_set, .values = try popMark(&stack, arena) } }), + .newobj_ex => try stack.append(blk: { + const kwargs, const args, const cls = .{ try pop(&stack), try pop(&stack), try pop(&stack) }; + const new_seq: Sequence = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ args, kwargs }) }; + break :blk .{ .object = try Object.init(arena, cls, try arena.dupe(Value, &.{.{ .seq = new_seq }})) }; }), - .stack_global => try stack.values.append(blk: { + .stack_global => try stack.append(blk: { const gn, const mn = .{ - try memo.resolve(allocator, try stack.pop(), true), - try memo.resolve(allocator, try stack.pop(), true), + try memo.resolve(arena, try pop(&stack), true), + try memo.resolve(arena, try pop(&stack), true), }; - const new_seq: Sequence = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ gn, mn }) }; - break :blk .{ .object = try Object.init(allocator, .{ .seq = new_seq }, &[_]Value{}) }; + const new_seq: Sequence = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ gn, mn }) }; + break :blk .{ .object = try Object.init(arena, .{ .seq = new_seq }, &[_]Value{}) }; }), .memoize => { - const item = stack.values.getLastOrNull() orelse { + const item = stack.getLastOrNull() orelse { return error.StackUnderrun; }; - try memo.insert(@intCast(memo.map.count()), try item.clone(allocator)); + try memo.insert(@intCast(memo.map.count()), try item.clone(arena)); }, - else => try stack.values.append(.{ .raw = try op.clone(allocator) }), + else => try stack.append(.{ .raw = try 
op.clone(arena) }), } } - if (!resolve_refs) { - return .{ try stack.toPickleStack(), memo }; + if (resolve_refs) { + return try memo.resolveAllRefsIter(arena, 0, stack.items, true); } - return .{ - PickleStack.init(allocator, try memo.resolveAllRefsIter(allocator, 0, stack.values.items, true)), - memo, - }; + return stack.toOwnedSlice(); } // TODO: this is a unmanaged array list, minus the optimisation. We should use that instead @@ -483,3 +396,86 @@ fn assuredResize(comptime T: type, allocator: std.mem.Allocator, old: []T, new_l return new; } } + +test evaluate { + var arena = std.heap.ArenaAllocator.init(std.testing.allocator); + defer arena.deinit(); + const allocator = arena.allocator(); + const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test.pickle", .{ .mode = .read_only }); + var buffered_reader = std.io.bufferedReader(file.reader()); + const ops = try pickle.parse(allocator, buffered_reader.reader(), 4096); + + const vals = try evaluate(allocator, ops, true); + defer allocator.free(vals); + + try std.testing.expect(vals.len == 1); + try std.testing.expect(vals[0] == .seq); + try std.testing.expect(vals[0].seq.type == .dict); + const entries = vals[0].seq.values[0].seq.values; + try std.testing.expect(entries.len == 5); + const expected: []const Value = &.{ + .{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .string = "hello" }, .{ .string = "world" } }) } }, + .{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .string = "int" }, .{ .int64 = 1 } }) } }, + .{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .string = "float" }, .{ .float64 = 3.141592 } }) } }, + .{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ + .{ .string = "list" }, + .{ .seq = .{ .type = .list, .values = @constCast(&[_]Value{ + .{ .int64 = 0 }, + .{ .int64 = 1 }, + .{ .int64 = 2 }, + .{ .int64 = 3 }, + .{ .int64 = 4 }, + }) } }, + }) } }, + .{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ + .{ 
.string = "tuple" }, + .{ .seq = .{ + .type = .tuple, + .values = @constCast(&[_]Value{ + .{ .string = "a" }, + .{ .int64 = 10 }, + }), + } }, + }) } }, + }; + + try std.testing.expectEqualDeep(expected, entries); +} + +pub fn pop(values: *std.ArrayList(Value)) !Value { + if (values.items.len == 0) { + return error.StackUnderrun; + } + return values.pop(); +} + +fn popMarkDiscard(values: *std.ArrayList(Value)) !void { + const mark = try findMark(values); + values.shrinkRetainingCapacity(mark); +} + +fn popMark(values: *std.ArrayList(Value), allocator: std.mem.Allocator) ![]Value { + const mark = try findMark(values); + const popping = values.items[mark + 1 ..]; + values.shrinkRetainingCapacity(mark); + return try allocator.dupe(Value, popping); +} + +fn lastMut(values: *std.ArrayList(Value)) !*Value { + if (values.items.len == 0) { + return error.UnexpectedEmptyStack; + } + return &values.items[values.items.len - 1]; +} + +fn findMark(values: *std.ArrayList(Value)) !usize { + const len = values.items.len; + for (0..len) |i| { + const idx = (len - 1) - i; + const val = values.items[idx]; + if (val == .raw and val.raw == .mark) { + return idx; + } + } + return error.MarkNotFound; +} diff --git a/zml/aio/torch/parser.zig b/zml/aio/torch/parser.zig index 95341bd..6d9485a 100644 --- a/zml/aio/torch/parser.zig +++ b/zml/aio/torch/parser.zig @@ -17,7 +17,7 @@ pub const Parser = struct { buffer_file: zml.aio.MemoryMappedFile, file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{}, tar_file: ?TarStream = null, - ops: []pickle.Op, + ops: []const pickle.Op, is_zip_file: bool, zip_prefix: []const u8 = &[_]u8{}, @@ -65,7 +65,7 @@ pub const Parser = struct { }; if (!self.is_zip_file) { const reader = tar_stream.reader(); - self.ops = try parse(allocator, reader, try tar_stream.getEndPos()); + self.ops = try pickle.parse(allocator, reader, try tar_stream.getEndPos()); } else { self.ops = try self.parseOps(allocator, self.tar_file.?.seekableStream()); } @@ -82,7 +82,7 @@ pub 
const Parser = struct { }; if (!self.is_zip_file) { const reader = self.buffer_file.file.reader(); - self.ops = try parse(allocator, reader, try reader.context.getEndPos()); + self.ops = try pickle.parse(allocator, reader, try reader.context.getEndPos()); } else { self.ops = try self.parseOps(allocator, self.buffer_file.file.seekableStream()); } @@ -94,7 +94,7 @@ pub const Parser = struct { self.* = undefined; } - fn parseOps(self: *Parser, allocator: Allocator, seekable_stream: anytype) ![]pickle.Op { + fn parseOps(self: *Parser, allocator: Allocator, seekable_stream: anytype) ![]const pickle.Op { var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream); var filename_buf: [std.fs.max_path_bytes]u8 = undefined; while (try iter.next()) |entry| { @@ -152,7 +152,7 @@ pub const Parser = struct { switch (entry.compression_method) { .store => { - return parse(allocator, seekable_stream.context.reader(), entry.uncompressed_size); + return pickle.parse(allocator, seekable_stream.context.reader(), entry.uncompressed_size); }, .deflate => { // TODO(cryptodeal): handle decompress @@ -166,196 +166,6 @@ pub const Parser = struct { std.log.err("Could not find file ending in `data.pkl` in archive", .{}); return error.PickleNotFound; } - - fn parse(allocator: Allocator, reader: anytype, len: usize) ![]pickle.Op { - var results = std.ArrayList(pickle.Op).init(allocator); - errdefer results.deinit(); - outer: while (true) { - const b = try reader.readByte(); - switch (@as(pickle.OpCode, @enumFromInt(b))) { - .mark => try results.append(.{ .mark = {} }), - .stop => { - try results.append(.{ .stop = {} }); - break :outer; - }, - .pop => try results.append(.{ .pop = {} }), - .pop_mark => try results.append(.{ .pop_mark = {} }), - .dup => try results.append(.{ .dup = {} }), - .float => { - const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len); - errdefer allocator.free(buf); - try results.append(.{ .float = buf }); - }, - .int => { - const buf = try 
reader.readUntilDelimiterAlloc(allocator, '\n', len); - errdefer allocator.free(buf); - try results.append(.{ .int = buf }); - }, - .binint => try results.append(.{ .binint = try reader.readInt(i32, .little) }), - .binint1 => try results.append(.{ .binint1 = try reader.readByte() }), - .long => { - const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len); - errdefer allocator.free(buf); - try results.append(.{ .long = buf }); - }, - .binint2 => try results.append(.{ .binint2 = try reader.readInt(u16, .little) }), - .none => try results.append(.{ .none = {} }), - .persid => { - const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len); - errdefer allocator.free(buf); - try results.append(.{ .persid = buf }); - }, - .binpersid => try results.append(.{ .binpersid = {} }), - .reduce => try results.append(.{ .reduce = {} }), - .string => { - const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len); - errdefer allocator.free(buf); - try results.append(.{ .string = buf }); - }, - .binstring => { - const str_len = try reader.readInt(u32, .little); - const buf = try allocator.alloc(u8, str_len); - errdefer allocator.free(buf); - _ = try reader.read(buf); - try results.append(.{ .binstring = buf }); - }, - .short_binstring => { - const str_len = try reader.readByte(); - const buf = try allocator.alloc(u8, str_len); - errdefer allocator.free(buf); - _ = try reader.read(buf); - try results.append(.{ .short_binstring = buf }); - }, - .unicode => { - const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len); - errdefer allocator.free(buf); - try results.append(.{ .unicode = buf }); - }, - .binunicode => { - const str_len = try reader.readInt(u32, .little); - const buf = try allocator.alloc(u8, str_len); - errdefer allocator.free(buf); - _ = try reader.read(buf); - try results.append(.{ .binunicode = buf }); - }, - .append => try results.append(.{ .append = {} }), - .build => try results.append(.{ .build = {} }), - .global, .inst => { 
- const module = try reader.readUntilDelimiterAlloc(allocator, '\n', len); - errdefer allocator.free(module); - const class = try reader.readUntilDelimiterAlloc(allocator, '\n', len); - errdefer allocator.free(class); - try results.append(.{ .global = .{ .module = module, .class = class } }); - }, - .dict => try results.append(.{ .dict = {} }), - .empty_dict => try results.append(.{ .empty_dict = {} }), - .appends => try results.append(.{ .appends = {} }), - .get => { - const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len); - errdefer allocator.free(buf); - try results.append(.{ .get = buf }); - }, - .binget => try results.append(.{ .binget = try reader.readByte() }), - .long_binget => try results.append(.{ .long_binget = try reader.readInt(u32, .little) }), - .list => try results.append(.{ .list = {} }), - .empty_list => try results.append(.{ .empty_list = {} }), - .obj => try results.append(.{ .obj = {} }), - .put => { - const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len); - errdefer allocator.free(buf); - try results.append(.{ .put = buf }); - }, - .binput => { - try results.append(.{ .binput = try reader.readByte() }); - }, - .long_binput => { - try results.append(.{ .long_binput = try reader.readInt(u32, .little) }); - }, - .setitem => try results.append(.{ .setitem = {} }), - .tuple => try results.append(.{ .tuple = {} }), - .empty_tuple => try results.append(.{ .empty_tuple = {} }), - .setitems => try results.append(.{ .setitems = {} }), - .binfloat => try results.append(.{ .binfloat = @bitCast(try reader.readInt(u64, .big)) }), - .proto => try results.append(.{ .proto = try reader.readByte() }), - .newobj => try results.append(.{ .newobj = {} }), - .ext1 => try results.append(.{ .ext1 = try reader.readByte() }), - .ext2 => try results.append(.{ .ext2 = try reader.readInt(i16, .little) }), - .ext4 => try results.append(.{ .ext4 = try reader.readInt(i32, .little) }), - .tuple1 => try results.append(.{ .tuple1 = {} }), - .tuple2 
=> try results.append(.{ .tuple2 = {} }), - .tuple3 => try results.append(.{ .tuple3 = {} }), - .newtrue => try results.append(.{ .newtrue = {} }), - .newfalse => try results.append(.{ .newfalse = {} }), - .long1 => { - const str_len = try reader.readByte(); - const buf = try allocator.alloc(u8, str_len); - errdefer allocator.free(buf); - _ = try reader.read(buf); - try results.append(.{ .long1 = buf }); - }, - .long4 => { - const str_len = try reader.readInt(u32, .little); - const buf = try allocator.alloc(u8, str_len); - errdefer allocator.free(buf); - _ = try reader.read(buf); - try results.append(.{ .long4 = buf }); - }, - .binbytes => { - const str_len = try reader.readInt(u32, .little); - const buf = try allocator.alloc(u8, str_len); - errdefer allocator.free(buf); - _ = try reader.read(buf); - try results.append(.{ .binbytes = buf }); - }, - .binbytes8 => { - const str_len = try reader.readInt(u64, .little); - const buf = try allocator.alloc(u8, str_len); - errdefer allocator.free(buf); - _ = try reader.read(buf); - try results.append(.{ .binbytes8 = buf }); - }, - .short_binbytes => { - const str_len = try reader.readByte(); - const buf = try allocator.alloc(u8, str_len); - errdefer allocator.free(buf); - _ = try reader.read(buf); - try results.append(.{ .short_binbytes = buf }); - }, - .binunicode8 => { - const str_len = try reader.readInt(u64, .little); - const buf = try allocator.alloc(u8, str_len); - errdefer allocator.free(buf); - _ = try reader.read(buf); - try results.append(.{ .binunicode8 = buf }); - }, - .short_binunicode => { - const str_len = try reader.readByte(); - const buf = try allocator.alloc(u8, str_len); - errdefer allocator.free(buf); - _ = try reader.read(buf); - try results.append(.{ .binunicode8 = buf }); - }, - .empty_set => try results.append(.{ .empty_set = {} }), - .additems => try results.append(.{ .additems = {} }), - .frozenset => try results.append(.{ .frozenset = {} }), - .newobj_ex => try results.append(.{ .newobj_ex = {} 
}), - .stack_global => try results.append(.{ .stack_global = {} }), - .memoize => try results.append(.{ .memoize = {} }), - .frame => try results.append(.{ .frame = try reader.readInt(u64, .little) }), - .bytearray8 => { - const str_len = try reader.readInt(u64, .little); - const buf = try allocator.alloc(u8, str_len); - errdefer allocator.free(buf); - _ = try reader.read(buf); - try results.append(.{ .bytearray8 = buf }); - }, - .next_buffer => try results.append(.{ .next_buffer = {} }), - .readonly_buffer => try results.append(.{ .readonly_buffer = {} }), - _ => {}, - } - } - return results.toOwnedSlice(); - } }; const TarStream = struct { @@ -404,54 +214,6 @@ const TarStream = struct { } }; -test "Read pickle (simple)" { - const Value = @import("value.zig").Value; - var arena = std.heap.ArenaAllocator.init(testing.allocator); - defer arena.deinit(); - const allocator = arena.allocator(); - const eval = @import("eval.zig"); - const file = try asynk.File.open("zml/aio/torch/simple_test.pickle", .{ .mode = .read_only }); - var data = try Parser.init(allocator, file); - defer data.deinit(); - var vals, var memo = try eval.evaluate(allocator, data.ops, true); - defer vals.deinit(); - defer memo.deinit(); - - try testing.expect(vals.stack.len == 2); - // skip first value (frame) - try testing.expect(vals.stack[1] == .seq); - try testing.expect(vals.stack[1].seq.type == .dict); - const entries = vals.stack[1].seq.values[0].seq.values; - try testing.expect(entries.len == 5); - const expected: []const Value = &.{ - .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "hello" }, .{ .string = "world" } })) } }, - .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "int" }, .{ .int64 = 1 } })) } }, - .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "float" }, .{ .float64 = 3.141592 } })) } }, - .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const 
Value, &.{ - .{ .string = "list" }, - .{ .seq = .{ .type = .list, .values = @constCast(@as([]const Value, &.{ - .{ .int64 = 0 }, - .{ .int64 = 1 }, - .{ .int64 = 2 }, - .{ .int64 = 3 }, - .{ .int64 = 4 }, - })) } }, - })) } }, - .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ - .{ .string = "tuple" }, - .{ .seq = .{ - .type = .tuple, - .values = @constCast(@as([]const Value, &.{ - .{ .string = "a" }, - .{ .int64 = 10 }, - })), - } }, - })) } }, - }; - - try std.testing.expectEqualDeep(expected, entries); -} - test "Read pickle (zipped)" { var arena = std.heap.ArenaAllocator.init(testing.allocator); defer arena.deinit(); diff --git a/zml/aio/torch/pickle.zig b/zml/aio/torch/pickle.zig index 3d77508..22d80a8 100644 --- a/zml/aio/torch/pickle.zig +++ b/zml/aio/torch/pickle.zig @@ -1,226 +1,950 @@ const std = @import("std"); +const log = std.log.scoped(.zml_aio); + +/// All possible pickle operators. +/// Reference: https://github.com/python/cpython/blob/3.13/Lib/pickletools.py +pub const OpCode = enum(u8) { + /// Push an integer or bool. + /// + /// The argument is a newline-terminated decimal literal string. + /// + /// The intent may have been that this always fit in a short Python int, + /// but INT can be generated in pickles written on a 64-bit box that + /// require a Python long on a 32-bit box. The difference between this + /// and LONG then is that INT skips a trailing 'L', and produces a short + /// int whenever possible. + /// + /// Another difference is due to that, when bool was introduced as a + /// distinct type in 2.3, builtin names True and False were also added to + /// 2.2.2, mapping to ints 1 and 0. For compatibility in both directions, + /// True gets pickled as INT + "I01\n", and False as INT + "I00\n". + /// Leading zeroes are never produced for a genuine integer. The 2.3 + /// (and later) unpicklers special-case these and return bool instead; + /// earlier unpicklers ignore the leading "0" and return the int. 
+ int = 'I', + + /// Push a four-byte signed integer. + /// Introduced in protocol 1. + /// + /// This handles the full range of Python (short) integers on a 32-bit + /// box, directly as binary bytes (1 for the opcode and 4 for the integer). + /// If the integer is non-negative and fits in 1 or 2 bytes, pickling via + /// BININT1 or BININT2 saves space. + binint = 'J', + + /// Push a one-byte unsigned integer. + /// Introduced in protocol 1. + /// + /// This is a space optimization for pickling very small non-negative ints, + /// in range(256). + binint1 = 'K', + + /// Push a two-byte unsigned integer. + /// Introduced in protocol 1. + /// + /// This is a space optimization for pickling small positive ints, in + /// range(256, 2**16). Integers in range(256) can also be pickled via + /// BININT2, but BININT1 instead saves a byte. + binint2 = 'M', + + /// Push a long integer. + /// + /// The same as INT, except that the literal ends with 'L', and always + /// unpickles to a Python long. There doesn't seem a real purpose to the + /// trailing 'L'. + /// + /// Note that LONG takes time quadratic in the number of digits when + /// unpickling (this is simply due to the nature of decimal->binary + /// conversion). Proto 2 added linear-time (in C; still quadratic-time + /// in Python) LONG1 and LONG4 opcodes. + long = 'L', + + /// Long integer using one-byte length. + /// Introduced in protocol 2. + /// + /// A more efficient encoding of a Python long; the long1 encoding + long1 = 0x8a, + + /// Long integer using four-byte length. + /// Introduced in protocol 2. + /// + /// A more efficient encoding of a Python long; the long4 encoding + long4 = 0x8b, + + /// Push a Python string object. + /// + /// The argument is a repr-style string, with bracketing quote characters, + /// and perhaps embedded escapes. The argument extends until the next + /// newline character. These are usually decoded into a str instance + /// using the encoding given to the Unpickler constructor. 
or the default, + /// 'ASCII'. If the encoding given was 'bytes' however, they will be + /// decoded as bytes object instead. + string = 'S', + + /// Push a Python string object. + /// Introduced in protocol 1. + /// + /// There are two arguments: the first is a 4-byte little-endian + /// signed int giving the number of bytes in the string, and the + /// second is that many bytes, which are taken literally as the string + /// content. These are usually decoded into a str instance using the + /// encoding given to the Unpickler constructor. or the default, + /// 'ASCII'. If the encoding given was 'bytes' however, they will be + /// decoded as bytes object instead. + binstring = 'T', + + /// Push a Python string object. + /// Introduced in protocol 1. + /// + /// There are two arguments: the first is a 1-byte unsigned int giving + /// the number of bytes in the string, and the second is that many + /// bytes, which are taken literally as the string content. These are + /// usually decoded into a str instance using the encoding given to + /// the Unpickler constructor. or the default, 'ASCII'. If the + /// encoding given was 'bytes' however, they will be decoded as bytes + /// object instead. + short_binstring = 'U', + + /// Push a Python bytes object. + /// Introduced in protocol 3. + /// + /// There are two arguments: the first is a 4-byte little-endian unsigned int + /// giving the number of bytes, and the second is that many bytes, which are + /// taken literally as the bytes content. + binbytes = 'B', + + /// Push a Python bytes object. + /// Introduced in protocol 3. + /// + /// There are two arguments: the first is a 1-byte unsigned int giving + /// the number of bytes, and the second is that many bytes, which are taken + /// literally as the string content. + short_binbytes = 'C', + + /// Push a Python bytes object. + /// Introduced in protocol 4. 
+ /// + /// There are two arguments: the first is an 8-byte unsigned int giving + /// the number of bytes in the string, and the second is that many bytes, + /// which are taken literally as the string content. + binbytes8 = 0x8e, + + /// Push a Python bytearray object. + /// Introduced in protocol 5. + /// + /// There are two arguments: the first is an 8-byte unsigned int giving + /// the number of bytes in the bytearray, and the second is that many bytes, + /// which are taken literally as the bytearray content. + bytearray8 = 0x96, + + /// Introduced in protocol 5. + next_buffer = 0x97, + + /// Introduced in protocol 5. + readonly_buffer = 0x98, + + none = 'N', + + /// Introduced in protocol 2. + newtrue = 0x88, + + /// Introduced in protocol 2. + newfalse = 0x89, + + /// Push a Python Unicode string object. + /// + /// The argument is a raw-unicode-escape encoding of a Unicode string, + /// and so may contain embedded escape sequences. The argument extends + /// until the next newline character. + unicode = 'V', + + /// Push a Python Unicode string object. + /// Introduced in protocol 4. + /// + /// There are two arguments: the first is a 1-byte little-endian signed int + /// giving the number of bytes in the string. The second is that many + /// bytes, and is the UTF-8 encoding of the Unicode string. + short_binunicode = 0x8c, + + /// Push a Python Unicode string object. + /// Introduced in protocol 1. + /// + /// There are two arguments: the first is a 4-byte little-endian unsigned int + /// giving the number of bytes in the string. The second is that many + /// bytes, and is the UTF-8 encoding of the Unicode string. + binunicode = 'X', + + /// Push a Python Unicode string object. + /// Introduced in protocol 4. + /// + /// There are two arguments: the first is an 8-byte little-endian signed int + /// giving the number of bytes in the string. The second is that many + /// bytes, and is the UTF-8 encoding of the Unicode string. 
+ binunicode8 = 0x8d, + + /// Newline-terminated decimal float literal. + /// + /// The argument is repr(a_float), and in general requires 17 significant + /// digits for roundtrip conversion to be an identity (this is so for + /// IEEE-754 double precision values, which is what Python float maps to + /// on most boxes). + /// + /// In general, FLOAT cannot be used to transport infinities, NaNs, or + /// minus zero across boxes (or even on a single box, if the platform C + /// library can't read the strings it produces for such things -- Windows + /// is like that), but may do less damage than BINFLOAT on boxes with + /// greater precision or dynamic range than IEEE-754 double. + float = 'F', + + /// Float stored in binary form, with 8 bytes of data. + /// Introduced in protocol 1. + /// + /// This generally requires less than half the space of FLOAT encoding. + /// In general, BINFLOAT cannot be used to transport infinities, NaNs, or + /// minus zero, raises an exception if the exponent exceeds the range of + /// an IEEE-754 double, and retains no more than 53 bits of precision (if + /// there are more than that, "add a half and chop" rounding is used to + /// cut it back to 53 significant bits). + binfloat = 'G', + + /// Introduced in protocol 1. + empty_list = ']', + + /// Append an object to a list. + /// + /// Stack before: ... pylist anyobject + /// Stack after: ... pylist+[anyobject] + /// + /// although pylist is really extended in-place. + append = 'a', + + /// Extend a list by a slice of stack objects. + /// Introduced in protocol 1. + /// + /// Stack before: ... pylist markobject stackslice + /// Stack after: ... pylist+stackslice + /// + /// although pylist is really extended in-place. + appends = 'e', + + /// Build a list out of the topmost stack slice, after markobject. 
+ /// + /// All the stack entries following the topmost markobject are placed into + /// a single Python list, which single list object replaces all of the + /// stack from the topmost markobject onward. For example, + /// + /// Stack before: ... markobject 1 2 3 'abc' + /// Stack after: ... [1, 2, 3, 'abc'] + list = 'l', + + /// Introduced in protocol 1. + empty_tuple = ')', + + /// Build a tuple out of the topmost stack slice, after markobject. + /// + /// All the stack entries following the topmost markobject are placed into + /// a single Python tuple, which single tuple object replaces all of the + /// stack from the topmost markobject onward. For example, + /// + /// Stack before: ... markobject 1 2 3 'abc' + /// Stack after: ... (1, 2, 3, 'abc') + tuple = 't', + + /// Build a one-tuple out of the topmost item on the stack. + /// Introduced in protocol 2. + /// + /// This code pops one value off the stack and pushes a tuple of + /// length 1 whose one item is that value back onto it. In other + /// words: + /// + /// stack[-1] = tuple(stack[-1:]) + tuple1 = 0x85, + + /// Build a two-tuple out of the top two items on the stack. + /// Introduced in protocol 2. + /// + /// This code pops two values off the stack and pushes a tuple of + /// length 2 whose items are those values back onto it. In other + /// words: + /// + /// stack[-2:] = [tuple(stack[-2:])] + tuple2 = 0x86, + + /// Build a three-tuple out of the top three items on the stack. + /// Introduced in protocol 2. + /// + /// This code pops three values off the stack and pushes a tuple of + /// length 3 whose items are those values back onto it. In other + /// words: + /// + /// stack[-3:] = [tuple(stack[-3:])] + tuple3 = 0x87, + + /// Introduced in protocol 1. + empty_dict = '}', + + /// Build a dict out of the topmost stack slice, after markobject. 
+ /// + /// All the stack entries following the topmost markobject are placed into + /// a single Python dict, which single dict object replaces all of the + /// stack from the topmost markobject onward. The stack slice alternates + /// key, value, key, value, .... For example, + /// + /// Stack before: ... markobject 1 2 3 'abc' + /// Stack after: ... {1: 2, 3: 'abc'} + dict = 'd', + + /// Add a key+value pair to an existing dict. + /// + /// Stack before: ... pydict key value + /// Stack after: ... pydict + /// + /// where pydict has been modified via pydict[key] = value. + setitem = 's', + + /// Add an arbitrary number of key+value pairs to an existing dict. + /// Introduced in protocol 1. + /// + /// The slice of the stack following the topmost markobject is taken as + /// an alternating sequence of keys and values, added to the dict + /// immediately under the topmost markobject. Everything at and after the + /// topmost markobject is popped, leaving the mutated dict at the top + /// of the stack. + /// + /// Stack before: ... pydict markobject key_1 value_1 ... key_n value_n + /// Stack after: ... pydict + /// + /// where pydict has been modified via pydict[key_i] = value_i for i in + /// 1, 2, ..., n, and in that order. + setitems = 'u', + + /// Introduced in protocol 4. + empty_set = 0x8f, + + /// Add an arbitrary number of items to an existing set. + /// Introduced in protocol 4. + /// + /// The slice of the stack following the topmost markobject is taken as + /// a sequence of items, added to the set immediately under the topmost + /// markobject. Everything at and after the topmost markobject is popped, + /// leaving the mutated set at the top of the stack. + /// + /// Stack before: ... pyset markobject item_1 ... item_n + /// Stack after: ... pyset + /// + /// where pyset has been modified via pyset.add(item_i) = item_i for i in + /// 1, 2, ..., n, and in that order. 
+ additems = 0x90, + + /// Build a frozenset out of the topmost slice, after markobject. + /// Introduced in protocol 4. + /// + /// All the stack entries following the topmost markobject are placed into + /// a single Python frozenset, which single frozenset object replaces all + /// of the stack from the topmost markobject onward. For example, + /// + /// Stack before: ... markobject 1 2 3 + /// Stack after: ... frozenset({1, 2, 3}) + frozenset = 0x91, + + pop = '0', + + dup = '2', + + /// Push markobject onto the stack. + /// + /// markobject is a unique object, used by other opcodes to identify a + /// region of the stack containing a variable number of objects for them + /// to work on. See markobject.doc for more detail. + mark = '(', + + /// Pop all the stack objects at and above the topmost markobject. + /// Introduced in protocol 1. + /// + /// When an opcode using a variable number of stack objects is done, + /// POP_MARK is used to remove those objects, and to remove the markobject + /// that delimited their starting position on the stack. + pop_mark = '1', + + /// Read an object from the memo and push it on the stack. + /// + /// The index of the memo object to push is given by the newline-terminated + /// decimal string following. BINGET and LONG_BINGET are space-optimized + /// versions. + get = 'g', + + /// Read an object from the memo and push it on the stack. + /// Introduced in protocol 1. + /// + /// The index of the memo object to push is given by the 1-byte unsigned + /// integer following. + binget = 'h', + + /// Read an object from the memo and push it on the stack. + /// Introduced in protocol 1. + /// + /// The index of the memo object to push is given by the 4-byte unsigned + /// little-endian integer following. + long_binget = 'j', + + /// Store the stack top into the memo. The stack is not popped. + /// + /// The index of the memo location to write into is given by the newline- + /// terminated decimal string following. 
BINPUT and LONG_BINPUT are + /// space-optimized versions. + put = 'p', + + /// Store the stack top into the memo. The stack is not popped. + /// Introduced in protocol 1. + /// + /// The index of the memo location to write into is given by the 1-byte + /// unsigned integer following. + binput = 'q', + + /// Store the stack top into the memo. The stack is not popped. + /// Introduced in protocol 1. + /// + /// The index of the memo location to write into is given by the 4-byte + /// unsigned little-endian integer following. + long_binput = 'r', + + /// Store the stack top into the memo. The stack is not popped. + /// Introduced in protocol 4. + /// + /// The index of the memo location to write is the number of + /// elements currently present in the memo. + memoize = 0x94, + + /// Extension code. + /// Introduced in protocol 2. + /// + /// This code and the similar EXT2 and EXT4 allow using a registry + /// of popular objects that are pickled by name, typically classes. + /// It is envisioned that through a global negotiation and + /// registration process, third parties can set up a mapping between + /// ints and object names. + /// + /// In order to guarantee pickle interchangeability, the extension + /// code registry ought to be global, although a range of codes may + /// be reserved for private use. + /// + /// EXT1 has a 1-byte integer argument. This is used to index into the + /// extension registry, and the object at that index is pushed on the stack. + ext1 = 0x82, + + /// Extension code. + /// Introduced in protocol 2. + /// + /// See EXT1. EXT2 has a two-byte integer argument. + ext2 = 0x83, + + /// Extension code. + /// Introduced in protocol 2. + /// + /// See EXT1. EXT4 has a four-byte integer argument. + ext4 = 0x84, + + /// Push a global object (module.attr) on the stack. + /// + /// Two newline-terminated strings follow the GLOBAL opcode. The first is + /// taken as a module name, and the second as a class name. 
The class + /// object module.class is pushed on the stack. More accurately, the + /// object returned by self.find_class(module, class) is pushed on the + /// stack, so unpickling subclasses can override this form of lookup. + global = 'c', + + /// Push a global object (module.attr) on the stack. + /// Introduced in protocol 4. + stack_global = 0x93, + + /// Push an object built from a callable and an argument tuple. + /// + /// The opcode is named to remind of the __reduce__() method. + /// + /// Stack before: ... callable pytuple + /// Stack after: ... callable(*pytuple) + /// + /// The callable and the argument tuple are the first two items returned + /// by a __reduce__ method. Applying the callable to the argtuple is + /// supposed to reproduce the original object, or at least get it started. + /// If the __reduce__ method returns a 3-tuple, the last component is an + /// argument to be passed to the object's __setstate__, and then the REDUCE + /// opcode is followed by code to create setstate's argument, and then a + /// BUILD opcode to apply __setstate__ to that argument. + /// + /// If not isinstance(callable, type), REDUCE complains unless the + /// callable has been registered with the copyreg module's + /// safe_constructors dict, or the callable has a magic + /// '__safe_for_unpickling__' attribute with a true value. I'm not sure + /// why it does this, but I've sure seen this complaint often enough when + /// I didn't want to . + reduce = 'R', + + /// Finish building an object, via __setstate__ or dict update. + /// + /// Stack before: ... anyobject argument + /// Stack after: ... anyobject + /// + /// where anyobject may have been mutated, as follows: + /// + /// If the object has a __setstate__ method, + /// + /// anyobject.__setstate__(argument) + /// + /// is called. 
+ /// + /// Else the argument must be a dict, the object must have a __dict__, and + /// the object is updated via + /// + /// anyobject.__dict__.update(argument) + build = 'b', + + /// Build a class instance. + /// + /// This is the protocol 0 version of protocol 1's OBJ opcode. + /// INST is followed by two newline-terminated strings, giving a + /// module and class name, just as for the GLOBAL opcode (and see + /// GLOBAL for more details about that). self.find_class(module, name) + /// is used to get a class object. + /// + /// In addition, all the objects on the stack following the topmost + /// markobject are gathered into a tuple and popped (along with the + /// topmost markobject), just as for the TUPLE opcode. + /// + /// Now it gets complicated. If all of these are true: + /// + /// + The argtuple is empty (markobject was at the top of the stack + /// at the start). + /// + /// + The class object does not have a __getinitargs__ attribute. + /// + /// then we want to create an old-style class instance without invoking + /// its __init__() method (pickle has waffled on this over the years; not + /// calling __init__() is current wisdom). In this case, an instance of + /// an old-style dummy class is created, and then we try to rebind its + /// __class__ attribute to the desired class object. If this succeeds, + /// the new instance object is pushed on the stack, and we're done. + /// + /// Else (the argtuple is not empty, it's not an old-style class object, + /// or the class object does have a __getinitargs__ attribute), the code + /// first insists that the class object have a __safe_for_unpickling__ + /// attribute. Unlike as for the __safe_for_unpickling__ check in REDUCE, + /// it doesn't matter whether this attribute has a true or false value, it + /// only matters whether it exists (XXX this is a bug). If + /// __safe_for_unpickling__ doesn't exist, UnpicklingError is raised. 
+ /// + /// Else (the class object does have a __safe_for_unpickling__ attr), + /// the class object obtained from INST's arguments is applied to the + /// argtuple obtained from the stack, and the resulting instance object + /// is pushed on the stack. + /// + /// NOTE: checks for __safe_for_unpickling__ went away in Python 2.3. + /// NOTE: the distinction between old-style and new-style classes does + /// not make sense in Python 3. + inst = 'i', + + /// Build a class instance. + /// Introduced in protocol 1. + /// + /// This is the protocol 1 version of protocol 0's INST opcode, and is + /// very much like it. The major difference is that the class object + /// is taken off the stack, allowing it to be retrieved from the memo + /// repeatedly if several instances of the same class are created. This + /// can be much more efficient (in both time and space) than repeatedly + /// embedding the module and class names in INST opcodes. + /// + /// Unlike INST, OBJ takes no arguments from the opcode stream. Instead + /// the class object is taken off the stack, immediately above the + /// topmost markobject: + /// + /// Stack before: ... markobject classobject stackslice + /// Stack after: ... new_instance_object + /// + /// As for INST, the remainder of the stack above the markobject is + /// gathered into an argument tuple, and then the logic seems identical, + /// except that no __safe_for_unpickling__ check is done (XXX this is + /// a bug). See INST for the gory details. + /// + /// NOTE: In Python 2.3, INST and OBJ are identical except for how they + /// get the class object. That was always the intent; the implementations + /// had diverged for accidental reasons. + obj = 'o', + + /// Build an object instance. + /// Introduced in protocol 2. + /// + /// The stack before should be thought of as containing a class + /// object followed by an argument tuple (the tuple being the stack + /// top). Call these cls and args. 
They are popped off the stack, + /// and the value returned by cls.__new__(cls, *args) is pushed back + /// onto the stack. + newobj = 0x81, + + /// Build an object instance. + /// Introduced in protocol 4. + /// + /// The stack before should be thought of as containing a class + /// object followed by an argument tuple and by a keyword argument dict + /// (the dict being the stack top). Call these cls, args and kwargs. They are + /// popped off the stack, and the value returned by + /// cls.__new__(cls, *args, **kwargs) is pushed back onto the stack. + newobj_ex = 0x92, + + /// Protocol version indicator. + /// Introduced in protocol 2. + /// + /// For protocol 2 and above, a pickle must start with this opcode. + /// The argument is the protocol version, an int in range(2, 256). + proto = 0x80, + + /// Stop the unpickling machine. + /// + /// Every pickle ends with this opcode. The object at the top of the stack + /// is popped, and that's the result of unpickling. The stack should be + /// empty then. + stop = '.', + + /// Indicate the beginning of a new frame. + /// Introduced in protocol 4. + /// + /// The unpickler may use this opcode to safely prefetch data from its + /// underlying stream. + frame = 0x95, + + /// Push an object identified by a persistent ID. + /// + /// The pickle module doesn't define what a persistent ID means. PERSID's + /// argument is a newline-terminated str-style (no embedded escapes, no + /// bracketing quote characters) string, which *is* "the persistent ID". + /// The unpickler passes this string to self.persistent_load(). Whatever + /// object that returns is pushed on the stack. There is no implementation + /// of persistent_load() in Python's unpickler: it must be supplied by an + /// unpickler subclass. + persid = 'P', + + /// Push an object identified by a persistent ID. + /// Introduced in protocol 1.
+ /// + /// Like PERSID, except the persistent ID is popped off the stack (instead + /// of being a string embedded in the opcode bytestream). The persistent + /// ID is passed to self.persistent_load(), and whatever object that + /// returns is pushed on the stack. See PERSID for more detail. + binpersid = 'Q', + + _, +}; + +// The above enum was generated with the following Python code, +// run inside pickletools.py +// +// def generate_zig(): +// print("""/// All possible pickle operators. +// /// Reference: https://github.com/python/cpython/blob/3.13/Lib/pickletools.py +// pub const OpCode = enum(u8) { +// """) +// for op in opcodes: +// lines = [_cleanup(l) for l in op.doc.split("\n")[:-1]] +// if op.proto > 0: +// lines.insert(1, _cleanup(f"Introduced in protocol {op.proto}.")) +// doc = "\n".join(lines) +// op_code = op.code.__repr__() +// if op_code.startswith("'\\x"): +// op_code = "0x" + op_code[3:-1] +// print(f"""{doc} +// {op.name.lower()} = {op_code}, +// """) +// +// print(" _,") +// print("};") +// +// def _cleanup(line: str) -> str: +// indent = " " +// if (line.startswith(indent)): +// line = line[len(indent):] +// line = line.replace(". ", ". ") +// if line: +// line = " " + line +// line = " ///" + line +// return line + /// A decoded Pickle operation in its natural state. -pub const Op = union(OpCode) { - mark, - stop, - pop, - pop_mark, - dup, - float: []const u8, +/// This is a bit different from the OpCode enum, +/// because operators having the same semantics but different encodings have been merged. +/// ex: string, binstring, short_binstring -> string. +pub const Op = union(enum) { + // Initially numbers were represented by strings...
int: []const u8, binint: i32, - binint1: u8, long: []const u8, - binint2: u16, - none, - persid: []const u8, - binpersid, - reduce, string: []const u8, - binstring: []const u8, - short_binstring: []const u8, + bytes: []const u8, + bytearray: []u8, + next_buffer, + readonly_buffer, + none, + bool: bool, unicode: []const u8, - binunicode: []const u8, - append, - build, - global: PyType, - dict, - empty_dict, - appends, - get: []const u8, - binget: u8, - inst: PyType, - long_binget: u32, - list, - empty_list, - obj, - put: []const u8, - binput: u8, - long_binput: u32, - setitem, - tuple, - empty_tuple, - setitems, + float: []const u8, binfloat: f64, - proto: u8, - newobj, - ext1: u8, - ext2: i16, - ext4: i32, + empty_list, + append, + appends, + list, + empty_tuple, + tuple, tuple1, tuple2, tuple3, - newtrue, - newfalse, - long1: []const u8, - long4: []const u8, - binbytes: []const u8, - short_binbytes: []const u8, - short_binunicode: []const u8, - binunicode8: []const u8, - binbytes8: []const u8, + empty_dict, + dict, + setitem, + setitems, empty_set, additems, frozenset, - newobj_ex, - stack_global, + pop, + dup, + mark, + pop_mark, + get: u32, + put: u32, memoize, - frame: u64, - bytearray8: []const u8, - next_buffer, - readonly_buffer, + ext1: u8, + ext2: i16, + ext4: i32, + global: PyType, + stack_global, + reduce, + build, + inst: PyType, + obj, + newobj, + newobj_ex, + proto: u8, + stop, + frame: u64, // new frame and its size + persid: []const u8, + binpersid, pub const PyType = struct { module: []const u8, class: []const u8 }; pub fn deinit(self: Op, allocator: std.mem.Allocator) void { switch (self) { - .float, - .int, - .long, - .persid, - .string, - .binstring, - .short_binstring, - .unicode, - .binunicode, - .get, - .put, - .long1, - .long4, - .binbytes, - .short_binbytes, - .short_binunicode, - .binunicode8, - .binbytes8, - .bytearray8, - => |v| allocator.free(v), - .global, .inst => |py_type| { - allocator.free(py_type.module); - 
allocator.free(py_type.class); + // Use a switch on the type of the stored data, + // this is easier than listing every opcode. + inline else => |v| switch (@TypeOf(v)) { + void, bool, u8, u16, u32, u64, i16, i32, f64 => {}, + []const u8, []u8 => allocator.free(v), + PyType => { + allocator.free(v.module); + allocator.free(v.class); + }, + else => @compileError("please explicit how to free this new opcode: " ++ @typeName(@TypeOf(v))), }, - else => {}, } } pub fn clone(self: Op, allocator: std.mem.Allocator) !Op { - var res = self; return switch (self) { - inline .float, - .int, - .long, - .persid, - .string, - .binstring, - .short_binstring, - .unicode, - .binunicode, - .get, - .put, - .long1, - .long4, - .binbytes, - .short_binbytes, - .short_binunicode, - .binunicode8, - .binbytes8, - .bytearray8, - => |v, tag| { - const cloned = try allocator.alloc(u8, v.len); - @memcpy(cloned, v); - @field(res, @tagName(tag)) = cloned; - return res; - }, - inline .global, .inst => |v, tag| { - @field(res, @tagName(tag)) = PyType{ + // Use a switch on the type of the stored data, + // this is easier than listing every opcode. + inline else => |v, tag| switch (@TypeOf(v)) { + void, bool, u8, u16, u32, u64, i16, i32, f64 => self, + []const u8, []u8 => @unionInit(Op, @tagName(tag), try allocator.dupe(u8, v)), + PyType => @unionInit(Op, @tagName(tag), .{ .module = try allocator.dupe(u8, v.module), .class = try allocator.dupe(u8, v.class), - }; - return res; + }), + else => @compileError("please explicit how to clone this new opcode: " ++ @typeName(@TypeOf(v))), }, - else => self, }; } }; -/// The values for the possible opcodes are in this enum.
-/// Reference: https://github.com/python/cpython/blob/3.13/Lib/pickletools.py -pub const OpCode = enum(u8) { - mark = '(', // push special markobject on stack - stop = '.', // every pickle ends with stop - pop = '0', // discard topmost stack item - pop_mark = '1', // discard stack top through topmost markobject - dup = '2', // duplicate top stack item - float = 'F', // push float object; decimal string argument - int = 'I', // push integer or bool; decimal string argument - binint = 'J', // push four-byte signed int - binint1 = 'K', // push 1-byte unsigned int - long = 'L', // push long; decimal string argument - binint2 = 'M', // push 2-byte unsigned int - none = 'N', // push None - persid = 'P', // push persistent object; id is taken from string arg - binpersid = 'Q', // " " " ; " " " " stack - reduce = 'R', // apply callable to argtuple, both on stack - string = 'S', // push string; NL-terminated string argument - binstring = 'T', // push string; counted binary string argument - short_binstring = 'U', // " " ; " " " " < 256 bytes - unicode = 'V', // push Unicode string; raw-unicode-escaped'd argument - binunicode = 'X', // " " " ; counted UTF-8 string argument - append = 'a', // append stack top to list below it - build = 'b', // call __setstate__ or __dict__.update() - global = 'c', // push self.find_class(modname, name); 2 string args - dict = 'd', // build a dict from stack items - empty_dict = '}', // push empty dict - appends = 'e', // extend list on stack by topmost stack slice - get = 'g', // push item from memo on stack; index is string arg - binget = 'h', // " " " " " " ; " " 1-byte arg - inst = 'i', // build & push class instance - long_binget = 'j', // push item from memo on stack; index is 4-byte arg - list = 'l', // build list from topmost stack items - empty_list = ']', // push empty list - obj = 'o', // build & push class instance - put = 'p', // store stack top in memo; index is string arg - binput = 'q', // " " " " " ; " " 1-byte arg - 
long_binput = 'r', // " " " " " ; " " 4-byte arg - setitem = 's', // add key+value pair to dict - tuple = 't', // build tuple from topmost stack items - empty_tuple = ')', // push empty tuple - setitems = 'u', // modify dict by adding topmost key+value pairs - binfloat = 'G', // push float; arg is 8-byte float encoding +/// Read a stream of bytes, and interpret it as a stream of Pickle operators. +pub fn parse(allocator: std.mem.Allocator, reader: anytype, max_line_len: usize) ![]const Op { + var results = std.ArrayList(Op).init(allocator); + errdefer results.deinit(); + const len = max_line_len; - // Protocol 2 - proto = '\x80', // identify pickle protocol - newobj = '\x81', // build object by applying cls.__new__ to argtuple - ext1 = '\x82', // push object from extension registry; 1-byte index - ext2 = '\x83', // ditto, but 2-byte index - ext4 = '\x84', // ditto, but 4-byte index - tuple1 = '\x85', // build 1-tuple from stack top - tuple2 = '\x86', // build 2-tuple from two topmost stack items - tuple3 = '\x87', // build 3-tuple from three topmost stack items - newtrue = '\x88', // push True - newfalse = '\x89', // push False - long1 = '\x8a', // push long from < 256 bytes - long4 = '\x8b', // push really big long + while (true) { + const b = try reader.readByte(); + const code: OpCode = @enumFromInt(b); + const op: Op = switch (code) { + .int => blk: { + const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len); + // Legacy hack, see OpCode.int documentation: "00"/"01" are booleans. + // We parse them right away to simplify downstream code, and free the text that .bool doesn't keep. + const legacy_bool: ?bool = if (std.mem.eql(u8, "00", buf)) false else if (std.mem.eql(u8, "01", buf)) true else null; + defer if (legacy_bool != null) allocator.free(buf); + break :blk if (legacy_bool) |value| Op{ .bool = value } else Op{ .int = buf }; + }, + .binint => .{ .binint = try reader.readInt(i32, .little) }, + .binint1 => .{ .binint = try reader.readByte() }, + .binint2 => .{ .binint = try reader.readInt(u16, .little) }, + // TODO: long should handle the trailing 'L' -> add a test.
+ .long => .{ .long = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, + .long1 => .{ .long = try _readSlice(reader, allocator, 1) }, + .long4 => .{ .long = try _readSlice(reader, allocator, 4) }, + .string => .{ .string = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, + .binstring => .{ .string = try _readSlice(reader, allocator, 4) }, + .short_binstring => .{ .string = try _readSlice(reader, allocator, 1) }, + .binbytes => .{ .bytes = try _readSlice(reader, allocator, 4) }, + .binbytes8 => .{ .bytes = try _readSlice(reader, allocator, 8) }, + .short_binbytes => .{ .bytes = try _readSlice(reader, allocator, 1) }, + .bytearray8 => .{ .bytearray = try _readSlice(reader, allocator, 8) }, + .next_buffer => .next_buffer, + .readonly_buffer => .readonly_buffer, + .none => .none, + .newtrue => .{ .bool = true }, + .newfalse => .{ .bool = false }, + .unicode => .{ .unicode = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, + .short_binunicode => .{ .unicode = try _readSlice(reader, allocator, 1) }, + .binunicode => .{ .unicode = try _readSlice(reader, allocator, 4) }, + .binunicode8 => .{ .unicode = try _readSlice(reader, allocator, 8) }, + .float => .{ .float = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, + .binfloat => .{ .binfloat = @bitCast(try reader.readInt(u64, .big)) }, + .empty_list => .empty_list, + .append => .append, + .appends => .appends, + .list => .list, + .empty_tuple => .empty_tuple, + .tuple => .tuple, + .tuple1 => .tuple1, + .tuple2 => .tuple2, + .tuple3 => .tuple3, + .empty_dict => .empty_dict, + .dict => .dict, + .setitem => .setitem, + .setitems => .setitems, + .empty_set => .empty_set, + .additems => .additems, + .frozenset => .frozenset, + .pop => .pop, + .dup => .dup, + .mark => .mark, + .pop_mark => .pop_mark, + .get => blk: { + const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len); + defer allocator.free(buf); + // If we fail to parse delay the error to the evaluation. 
+ const n = std.fmt.parseInt(u32, buf, 10) catch std.math.maxInt(u32); + break :blk .{ .get = n }; + }, + .binget => .{ .get = try reader.readByte() }, + .long_binget => .{ .get = try reader.readInt(u32, .little) }, + .put => blk: { + const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len); + defer allocator.free(buf); + const n = std.fmt.parseInt(u32, buf, 10) catch std.math.maxInt(u32); + break :blk .{ .put = n }; + }, + .binput => .{ .put = try reader.readByte() }, + .long_binput => .{ .put = try reader.readInt(u32, .little) }, + .memoize => .memoize, + .ext1 => .{ .ext1 = try reader.readByte() }, + .ext2 => .{ .ext2 = try reader.readInt(i16, .little) }, + .ext4 => .{ .ext4 = try reader.readInt(i32, .little) }, + .global => .{ .global = .{ + .module = try reader.readUntilDelimiterAlloc(allocator, '\n', len), + .class = try reader.readUntilDelimiterAlloc(allocator, '\n', len), + } }, + .stack_global => .stack_global, + .reduce => .reduce, + .build => .build, + .inst => .{ .inst = .{ + .module = try reader.readUntilDelimiterAlloc(allocator, '\n', len), + .class = try reader.readUntilDelimiterAlloc(allocator, '\n', len), + } }, + .obj => .obj, + .newobj => .newobj, + .newobj_ex => .newobj_ex, + .proto => blk: { + const version = try reader.readByte(); + if (version > 5) log.warn("zml.aio.torch.pickle.parse expects a Python pickle object of version <=5, got version {}. Will try to interpret anyway, but this may lead to more errors.", .{version}); + break :blk .{ .proto = version }; + }, + .stop => .stop, + // This is not documented in pickletools but in https://peps.python.org/pep-3154/ + // The frame size is stored right after the frame header. + // The loader is allowed to prefetch framesize from the underlying reader, + // and ops are not allowed to cross a frame boundary. + // We don't prefetch because we assume the reader is going to use some kind of buffered reader. 
+ // We could try to enforce frame boundaries, but we would need to track + // how many bytes we are reading from the stream. + .frame => .{ .frame = try reader.readInt(u64, .little) }, + .persid => .{ .persid = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, + .binpersid => .binpersid, + _ => |unk_tag| { + log.err("Unknown pickle operator {}, note we are only supporting pickle protocol up to version 5.", .{unk_tag}); + return error.NotSupported; + }, + }; + try results.append(op); + if (op == .stop) break; + } + return results.toOwnedSlice(); +} - // Protocol 3 - binbytes = 'B', // push bytes; counted binary string argument - short_binbytes = 'C', // " " ; " " " " < 256 bytes +test parse { + const allocator = std.testing.allocator; + const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test.pickle", .{ .mode = .read_only }); + var buffered_reader = std.io.bufferedReader(file.reader()); + const ops = try parse(allocator, buffered_reader.reader(), 4096); + defer { + // Test we are correctly freeing every allocation.
+ for (ops) |op| op.deinit(allocator); + allocator.free(ops); + } - // Protocol 4 - short_binunicode = '\x8c', // push short string; UTF-8 length < 256 bytes - binunicode8 = '\x8d', // push very long string - binbytes8 = '\x8e', // push very long bytes string - empty_set = '\x8f', // push empty set on the stack - additems = '\x90', // modify set by adding topmost stack items - frozenset = '\x91', // build frozenset from topmost stack items - newobj_ex = '\x92', // like newobj but work with keyword only arguments - stack_global = '\x93', // same as GLOBAL but using names on the stacks - memoize = '\x94', // store top of the stack in memo - frame = '\x95', // indicate the beginning of a new frame + try std.testing.expect(ops.len == 35); + // this can be obtained by running: `python -m pickletools simple_test.pickle` + const expected = [_]Op{ + .{ .proto = 4 }, + .{ .frame = 83 }, + .empty_dict, + .memoize, + .mark, + .{ .unicode = "hello" }, + .memoize, + .{ .unicode = "world" }, + .memoize, + .{ .unicode = "int" }, + .memoize, + .{ .binint = 1 }, + .{ .unicode = "float" }, + .memoize, + .{ .binfloat = 3.141592 }, + .{ .unicode = "list" }, + .memoize, + .empty_list, + .memoize, + .mark, + .{ .binint = 0 }, + .{ .binint = 1 }, + .{ .binint = 2 }, + .{ .binint = 3 }, + .{ .binint = 4 }, + .appends, + .{ .unicode = "tuple" }, + .memoize, + .{ .unicode = "a" }, + .memoize, + .{ .binint = 10 }, + .tuple2, + .memoize, + .setitems, + .stop, + }; + try std.testing.expectEqualDeep(@as([]const Op, &expected), ops); +} - // Protocol 5 - bytearray8 = '\x96', // push bytearray - next_buffer = '\x97', // push next out-of-band buffer - readonly_buffer = '\x98', // make top of stack readonly - _, -}; +fn _readSlice(reader: anytype, allocator: std.mem.Allocator, comptime len_bytes: u8) ![]u8 { + const T = std.meta.Int(.unsigned, 8 * len_bytes); + const str_len: u64 = try reader.readInt(T, .little); + const buf = try allocator.alloc(u8, str_len); + errdefer allocator.free(buf); + try
reader.readNoEof(buf); + return buf; +} diff --git a/zml/aio/torch/value.zig b/zml/aio/torch/value.zig index b9cb4bb..17cd79f 100644 --- a/zml/aio/torch/value.zig +++ b/zml/aio/torch/value.zig @@ -109,7 +109,7 @@ pub const ValueType = enum { none, }; -/// A processed value. +/// A pickle operator that has been interpreted. pub const Value = union(ValueType) { /// Types that we can't handle or just had to give up on processing. raw: pickle.Op, @@ -283,10 +283,10 @@ pub const Value = union(ValueType) { return switch (self) { inline .raw, .raw_num => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)), inline .app, .object, .global, .build, .pers_id => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)), - .seq => |seq| blk: { - const new_val: Sequence = .{ .type = seq.type, .values = try allocator.alloc(Value, seq.values.len) }; - for (seq.values, 0..) |v, i| new_val.values[i] = try v.clone(allocator); - break :blk .{ .seq = new_val }; + .seq => |seq| { + const values = try allocator.alloc(Value, seq.values.len); + for (seq.values, 0..)
|v, i| values[i] = try v.clone(allocator); + return .{ .seq = .{ .type = seq.type, .values = values } }; }, inline .string, .bytes => |v, tag| @unionInit(Value, @tagName(tag), try allocator.dupe(u8, v)), .bigint => |v| .{ .bigint = try v.clone() }, @@ -342,8 +342,9 @@ pub const Value = union(ValueType) { pub fn coerceFromRaw(self: Value, allocator: std.mem.Allocator) !Value { return switch (self) { .raw => |raw_val| switch (raw_val) { - .binint, .binint1, .binint2 => |val| .{ .int64 = val }, - .long1, .long4 => |b| if (b.len != 0) { + .binint => |val| .{ .int64 = val }, + .long => |b| if (b.len != 0) { + // TODO: handle trailing 'L' var bint = try big_int.Managed.initCapacity(allocator, std.math.big.int.calcTwosCompLimbCount(b.len)); var mutable = bint.toMutable(); mutable.readTwosComplement(b, b.len, .little, .signed); @@ -355,28 +356,17 @@ pub const Value = union(ValueType) { } else return .{ .bigint = bint }; } else .{ .raw_num = raw_val }, .binfloat => |val| .{ .float64 = val }, - .binunicode, .binunicode8, .short_binunicode => |s| .{ .string = s }, - .binbytes, .binbytes8, .short_binbytes, .bytearray8 => |b| .{ .bytes = b }, + .unicode => |s| .{ .string = s }, + .bytes => |b| .{ .bytes = b }, // This isn't how Pickle actually works but we just try to UTF8 decode the // string and if it fails, we make it a bytes value instead. If anyone // actually cares they can just fix values themselves or recover the raw bytes // from the UTF8 string (it's guaranteed to be reversible, as far as I know). 
- .binstring, .short_binstring => |b| if (std.unicode.utf8ValidateSlice(b)) .{ .string = b } else .{ .bytes = b }, - .newtrue => .{ .boolval = true }, - .newfalse => .{ .boolval = false }, + .string => |b| if (std.unicode.utf8ValidateSlice(b)) .{ .string = b } else .{ .bytes = b }, + .bool => |b| .{ .boolval = b }, .none => .{ .none = {} }, - inline .int, - .float, - .long, - => |v, tag| { - if (tag == .int and std.mem.eql(u8, v, "01")) { - return .{ .boolval = true }; - } else if (tag == .int and std.mem.eql(u8, v, "00")) { - return .{ .boolval = false }; - } else { - return .{ .raw_num = raw_val }; - } - }, + // TODO .int should be handled like .long + .int, .float => .{ .raw_num = raw_val }, else => self, }, .app, .object, .global => |v| blk: { diff --git a/zml/aio/yaml.zig b/zml/aio/yaml.zig index dc1e553..2fc4690 100644 --- a/zml/aio/yaml.zig +++ b/zml/aio/yaml.zig @@ -25,8 +25,8 @@ pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore { pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: StringBuilder, val: yaml.Value) !void { switch (val) { - .int => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .int64 = v }), - .float => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .float64 = v }), + .int => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .int = v }), + .float => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .float = v }), .string => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .string = v }), .list => |v| switch (validSlice(v)) { true => { @@ -36,13 +36,13 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: Str const values = try allocator.alloc(i64, v.len); errdefer allocator.free(values); for (v, 0..) 
|item, i| values[i] = item.int; - try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(values) } }); + try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array_int = values }); }, .float => { const values = try allocator.alloc(f64, v.len); errdefer allocator.free(values); for (v, 0..) |item, i| values[i] = item.float; - try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .float64, .data = std.mem.sliceAsBytes(values) } }); + try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array_float = values }); }, .string => { const values = try allocator.alloc([]const u8, v.len); @@ -50,7 +50,7 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: Str for (v, 0..) |item, i| { values[i] = try allocator.dupe(u8, item.string); } - try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .string, .data = std.mem.sliceAsBytes(values) } }); + try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array_string = values }); }, .list => unreachable, else => {}, diff --git a/zml/meta.zig b/zml/meta.zig index 532afc7..826e159 100644 --- a/zml/meta.zig +++ b/zml/meta.zig @@ -145,6 +145,8 @@ pub fn isSliceOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool } pub fn DeclEnum(comptime T: type) type { + const field_infos = std.meta.declarations(T); + if (field_infos.len == 0) compileError("Struct {} has no declarations", .{T}); return std.meta.DeclEnum(UnwrapPtr(T)); }