diff --git a/zml/BUILD.bazel b/zml/BUILD.bazel index 8d5bdde..1997182 100644 --- a/zml/BUILD.bazel +++ b/zml/BUILD.bazel @@ -62,7 +62,7 @@ zig_cc_test( name = "test", data = [ "aio/torch/simple.pt", - "aio/torch/simple_test.pickle", + "aio/torch/simple_test_4.pickle", ], test_runner = ":test_runner", deps = [":zml"], diff --git a/zml/aio.zig b/zml/aio.zig index 50ccc22..81631a0 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -104,6 +104,15 @@ pub const BufferStore = struct { buffers: Buffers = .{}, _metadata: Metadatas = .{}, + /// Create an empty BufferStore. Takes owneship of the given files. + pub fn init(allocator: std.mem.Allocator, files: []const MemoryMappedFile) error{OutOfMemory}!BufferStore { + var self: zml.aio.BufferStore = .{ + .arena = std.heap.ArenaAllocator.init(allocator), + }; + self.files = try self.arena.allocator().dupe(MemoryMappedFile, files); + return self; + } + pub fn deinit(self: BufferStore) void { for (self.files) |*file| file.deinit(); self.arena.deinit(); @@ -255,7 +264,7 @@ pub const MemoryMappedFile = struct { }; } - pub fn mappedSlice(self: *MemoryMappedFile, start: usize, len: usize) []const u8 { + pub fn mappedSlice(self: MemoryMappedFile, start: usize, len: usize) []const u8 { return self.data[self.data_offset + start ..][0..len]; } @@ -578,7 +587,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi return if (buffer_store.get(prefix)) |host_buffer| { // obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us. var buf_with_metadata = host_buffer; - log.warn("loading {s} ({})", .{ prefix, obj._shape }); + log.debug("Loading buffer {s} ({})", .{ prefix, obj._shape }); zml.meta.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix }); buf_with_metadata._shape = obj._shape; obj.* = try zml.Buffer.from(platform, buf_with_metadata); diff --git a/zml/aio/nemo.zig b/zml/aio/nemo.zig index 3e7039d..d90de12 100644 --- a/zml/aio/nemo.zig +++ b/zml/aio/nemo.zig @@ -1,11 +1,12 @@ -const asynk = @import("async"); -const eval = @import("torch/eval.zig"); const std = @import("std"); +const log = std.log.scoped(.zml_aio); + +const asynk = @import("async"); const yaml = @import("zig-yaml"); + +const eval = @import("torch/eval.zig"); const zml = @import("../zml.zig"); - -const parser = @import("torch/parser.zig"); - +const File = @import("torch/file.zig").File; const StringBuilder = std.ArrayListUnmanaged(u8); pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore { @@ -14,8 +15,11 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore }; errdefer res.arena.deinit(); + // TODO(cryptodeal): this is incorrect, you should use a temporary arena for all intermediary allocations. const arena = res.arena.allocator(); + // TODO(cryptodeal): mapped_file will never be close in case of success. + // You need to store it inside the result. var mapped_file = try zml.aio.MemoryMappedFile.init(try asynk.File.open(path, .{})); errdefer mapped_file.deinit(); @@ -37,13 +41,11 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore try zml.aio.yaml.parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), parsed.docs.items[0]); } else if (std.mem.endsWith(u8, file.name, ".ckpt") or std.mem.endsWith(u8, file.name, ".pt")) { const start = try mapped_file.file.getPos(); - var tmp: zml.aio.torch.PickleData = .{ - .data = try parser.Parser.fromTarFile(arena, mapped_file, file), - .stack = undefined, - }; - tmp.stack = try eval.evaluate(arena, tmp.data.ops, true); + var torch_file = try File.fromTarFile(arena, mapped_file, file); + const ops = try torch_file.parsePickle(arena); + const values = try eval.evaluate(arena, ops, true); - try tmp.parseModel(arena, &res); + try torch_file.parseModel(values, &res); // Since we directly manipulate the file handle pointer, // reset to the end of file so iterator does not error // and avoid `skipBytes`. diff --git a/zml/aio/torch.zig b/zml/aio/torch.zig index 232513f..cde998b 100644 --- a/zml/aio/torch.zig +++ b/zml/aio/torch.zig @@ -2,24 +2,18 @@ const asynk = @import("async"); const std = @import("std"); const zml = @import("../zml.zig"); -const HostBuffer = @import("../hostbuffer.zig").HostBuffer; - const eval = @import("torch/eval.zig"); -const value = @import("torch/value.zig"); -const parser = @import("torch/parser.zig"); -const PersId = value.PersId; -const Sequence = value.Sequence; -const Value = value.Value; -const ValueType = value.ValueType; +const py = @import("torch/py.zig"); +const File = @import("torch/file.zig").File; const StringBuilder = std.ArrayListUnmanaged(u8); -const log = std.log.scoped(.zml_io); +const log = std.log.scoped(.zml_aio); test { std.testing.refAllDecls(@This()); std.testing.refAllDecls(eval); - std.testing.refAllDecls(value); - std.testing.refAllDecls(parser); + std.testing.refAllDecls(py); + std.testing.refAllDecls(File); } /// Opens and loads a BufferStore from the torch file at the given path. @@ -35,392 +29,14 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore defer arena.deinit(); const tmp_alloc = arena.allocator(); - const _parser = try parser.Parser.init(tmp_alloc, file); - const stack = try eval.evaluate(tmp_alloc, _parser.ops, true); + const mmap_file = try zml.aio.MemoryMappedFile.init(file); + var torch_file = try File.init(tmp_alloc, mmap_file); - // But we create the HostBuffer objects inside the result BufferStore arena. - var res: zml.aio.BufferStore = .{ - .arena = std.heap.ArenaAllocator.init(allocator), - }; - res.files = try res.arena.allocator().dupe(zml.aio.MemoryMappedFile, &.{_parser.buffer_file}); - var tmp: PickleData = .{ .data = _parser, .stack = stack }; - try tmp.parseModel(res.arena.allocator(), &res); + const ops = try torch_file.parsePickle(tmp_alloc); + const py_values = try eval.evaluate(tmp_alloc, ops, true); + + // file ownership is transferred to the BufferStore + var res = try zml.aio.BufferStore.init(allocator, &.{torch_file.buffer_file}); + try torch_file.parseModel(py_values, &res); return res; } - -// TODO: rename me to PytorchFile -pub const PickleData = struct { - stack: []const Value, - data: parser.Parser, - - fn basicTypeCheck(object: *const value.Object, module: []const u8, class: []const u8) bool { - return switch (object.member) { - .raw => |raw| return (object.args[0] == .seq and - std.mem.eql(u8, module, raw.global.module) and - std.mem.eql(u8, class, raw.global.class)), - else => false, - }; - } - - pub fn parseModel(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore) !void { - for (self.stack) |item| { - var prefix_buf: [1024]u8 = undefined; - try self.parseValue(allocator, store, StringBuilder.initBuffer(&prefix_buf), item); - } - } - - pub fn parseValue(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: Value) !void { - switch (v) { - .app, .object, .global => |object| { - if (!(try self.parseTorchGlobal(allocator, store, prefix, v))) { - try self.parseValue(allocator, store, prefix, object.member); - for (object.args) |item| { - // if possible, coerce to `kv_tuple` (only if key val doesn't match root of prefix) - if (item == .seq and item.seq.type == .tuple and item.seq.values.len == 2 and item.seq.values[0] == .string) { - try self.parseValue(allocator, store, prefix, .{ .seq = .{ .type = .kv_tuple, .values = item.seq.values } }); - } else try self.parseValue(allocator, store, prefix, item); - } - } - }, - .build => |build| { - // `build` contains info about python struct being constructed - switch (build.member) { - .object => |obj| switch (obj.member) { - .raw => |raw| switch (raw) { - .global => |global| { - // in this case, we can capture the name of the python type - // which can be used for codegen (e.g. `torch.nn.modules.conv.Conv2d`) - var new_prefix = prefix; - if (prefix.items.len > 0) { - new_prefix.appendAssumeCapacity('.'); - } - new_prefix.appendSliceAssumeCapacity("_gen_type_helper"); - const key = try allocator.dupe(u8, new_prefix.items); - const d = try store._metadata.getOrPut(allocator, key); - if (d.found_existing) { - log.err("Duplicate key: {s}", .{new_prefix.items}); - allocator.free(key); - } else { - const val = try std.mem.join(allocator, ".", &.{ global.module, global.class }); - d.value_ptr.* = .{ .string = val }; - } - }, - else => try self.parseValue(allocator, store, prefix, build.member), // parse normally - }, - else => try self.parseValue(allocator, store, prefix, build.member), // parse normally - }, - else => try self.parseValue(allocator, store, prefix, build.member), // parse normally - } - try self.parseValue(allocator, store, prefix, build.args); - }, - .pers_id => |pers_id| try self.parseValue(allocator, store, prefix, pers_id.ref), - .seq => |seq| { - switch (seq.type) { - .list, .tuple, .set, .frozen_set => { - if (seq.values.len == 0) return; - var valid_slice = true; - switch (seq.values[0]) { - inline .int64, .float64, .boolval => |val0, tag| { - const ItemType = switch (tag) { - .int64 => i64, - .float64 => f64, - .boolval => bool, - else => unreachable, - }; - var values: std.ArrayListUnmanaged(ItemType) = .{}; - try values.append(allocator, val0); - for (seq.values[1..], 1..) |val, i| { - if (std.meta.activeTag(val) != tag) valid_slice = false; - if (valid_slice) { - try values.append(allocator, @field(val, @tagName(tag))); - } else { - var new_prefix = prefix; - if (prefix.items.len > 0) { - new_prefix.appendAssumeCapacity('.'); - } - new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); - try self.parseValue(allocator, store, new_prefix, val); - } - } - - if (valid_slice) { - try store._metadata.put( - allocator, - try allocator.dupe(u8, prefix.items), - try zml.aio.Metadata.copySlice(allocator, values.items), - ); - } else { - for (values.items, 0..) |val, i| { - var new_prefix = prefix; - if (prefix.items.len > 0) { - new_prefix.appendAssumeCapacity('.'); - } - new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); - const new_tag = switch (tag) { - .int64 => "int", - .float64 => "float", - .boolval => "bool", - else => unreachable, // we are already inside a switch - }; - try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), @unionInit(zml.aio.Metadata, new_tag, val)); - } - } - }, - else => { - for (seq.values, 0..) |item, i| { - var new_prefix = prefix; - if (v.isPrimitive()) { - if (prefix.items.len > 0) { - new_prefix.appendAssumeCapacity('.'); - } - new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); - } - try self.parseValue(allocator, store, new_prefix, item); - } - }, - } - }, - .dict => for (seq.values) |item| { - try self.parseValue(allocator, store, prefix, item); - }, - .kv_tuple => { - const key, const val = seq.values[0..2].*; - switch (key) { - .string => |s| { - // Handle Pytorch specific fields - if (std.mem.eql(u8, s, "_modules") or std.mem.eql(u8, s, "_parameters") or std.mem.eql(u8, s, "_buffers")) { - try self.parseValue(allocator, store, prefix, val); - } else { - var new_prefix = prefix; - if (prefix.items.len > 0) { - new_prefix.appendAssumeCapacity('.'); - } - new_prefix.appendSliceAssumeCapacity(s); - try self.parseValue(allocator, store, new_prefix, val); - } - }, - .int64 => |int| { - var new_prefix = prefix; - if (prefix.items.len > 0) { - new_prefix.appendAssumeCapacity('.'); - } - new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{}); - try self.parseValue(allocator, store, new_prefix, val); - }, - inline else => |_, tag| std.debug.panic("Unexpected key type: {s}", .{@tagName(tag)}), - } - }, - } - }, - .bytes => |val| { - const key = try allocator.dupe(u8, prefix.items); - const d = try store._metadata.getOrPut(allocator, key); - if (d.found_existing) { - log.warn("Duplicate key: {s}", .{prefix.items}); - allocator.free(key); - } else d.value_ptr.* = .{ .string = val }; - }, - inline .float64, .int64, .boolval, .bigint, .string => |val| { - const key = try allocator.dupe(u8, prefix.items); - const d = try store._metadata.getOrPut(allocator, key); - if (d.found_existing) { - log.warn("Duplicate key: {s}", .{prefix.items}); - allocator.free(key); - } else { - d.value_ptr.* = zml.aio.Metadata.wrap(val); - } - }, - else => {}, - } - } - - fn parseTorchGlobal(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: Value) !bool { - return switch (v) { - .global => |object| { - if (try self.parseTensor(allocator, object)) |host_buffer| { - const key = try allocator.dupe(u8, prefix.items); - const entry = try store.buffers.getOrPut(allocator, key); - if (entry.found_existing) { - log.warn("Duplicate key: {s}", .{prefix.items}); - allocator.free(key); - } - entry.value_ptr.* = host_buffer; - return true; - } else if (basicTypeCheck(object, "torch", "Size")) { - const size = object.args[0].seq.values[0].seq.values; - const key = try allocator.dupe(u8, prefix.items); - const entry = try store._metadata.getOrPut(allocator, key); - if (entry.found_existing) { - log.warn("Duplicate key: {s}", .{prefix.items}); - allocator.free(key); - } - const d = try allocator.alloc(i64, size.len); - for (d, 0..) |*di, i| di.* = size[i].int64; - entry.value_ptr.* = .{ .array_int = d }; - return true; - } else if (basicTypeCheck(object, "fractions", "Fraction")) { - const fraction_str = object.args[0].seq.values[0].string; - if (std.mem.indexOfScalar(u8, fraction_str, '/')) |split_idx| { - { - var new_prefix = prefix; - new_prefix.appendSliceAssumeCapacity(".numerator"); - try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) }); - } - { - var new_prefix = prefix; - new_prefix.appendSliceAssumeCapacity(".denominator"); - try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) }); - } - return true; - } - } - return false; - }, - else => false, - }; - } - - fn parseTensor(self: *PickleData, tmp_allocator: std.mem.Allocator, object: *value.Object) !?zml.HostBuffer { - if (!basicTypeCheck(object, "torch._utils", "_rebuild_tensor_v2")) { - return null; - } - - const args = object.args[0].seq.values; - if (args.len < 4 or - args[0] != .pers_id or - args[1] != .int64 or - args[2] != .seq or args[2].seq.type != .tuple or - args[3] != .seq or args[3].seq.type != .tuple) - { - log.err("Unexpected value in call to torch._utils._rebuild_tensor_v2", .{}); - return error.InvalidInput; - } - - const pid: *PersId = args[0].pers_id; - var offset: u64 = @intCast(args[1].int64); - const raw_dims: Sequence = args[2].seq; - const raw_strides: Sequence = args[3].seq; - const dims = try parseDims(raw_dims.values); - var strides = try parseDims(raw_strides.values); - - const dtype, const storage_file = try parseStorage(pid.ref); - // Pytorch store "item" strides, while ZML uses byte strides. - for (strides.slice()) |*s| s.* *= dtype.sizeOf(); - // Same thing for the offset. - offset = offset * dtype.sizeOf(); - - const filename = try std.mem.join(tmp_allocator, "", &.{ self.data.zip_prefix, "data/", storage_file }); - defer tmp_allocator.free(filename); - - // The offset in the pickle is the offset inside the storage_file. - // But .pt are made of several files, so we need to append the file offset. - const storage = try self.getStorage(filename); - return HostBuffer.fromStridedSlice( - zml.Shape.init(dims.constSlice(), dtype), - storage[offset..], - strides.constSlice(), - ); - } - - fn parseStorage(val: value.Value) !struct { zml.DataType, []const u8 } { - if (val != .seq) return error.InvalidInput; - const sargs = val.seq.values; - if (val.seq.type == .tuple and - sargs.len >= 5 and - sargs[0] == .string and std.mem.eql(u8, sargs[0].string, "storage") and - sargs[1] == .raw and sargs[1].raw == .global and - sargs[2] == .string and - sargs[3] == .string) - { - const op = sargs[1].raw.global; - const storage_file = sargs[2].string; - // const sdev = sargs[3].string; - if (!std.mem.eql(u8, "torch", op.module) or - !std.mem.endsWith(u8, op.class, "Storage")) - return error.InvalidInput; - - return .{ - try storageToDtype(op.class), - storage_file, - }; - } else { - return error.InvalidInput; - } - } - - /// Given the name of one of the files in the .pt tarball, - /// return the slice of the memory-mapped .pt corresponding to it. - fn getStorage(self: *PickleData, filename: []const u8) ![]const u8 { - const maybe_entry = self.data.file_map.get(filename); - if (maybe_entry == null) { - std.log.err("Could not find file ending in `{s}` in archive", .{filename}); - return error.TensorNotFound; - } - const entry = maybe_entry.?; - const base_offset: u64 = if (self.data.tar_file) |t| t.start else 0; - const file_offset: u64 = base_offset + entry.file_offset; - const file = self.data.buffer_file.file; - try file.seekTo(entry.file_offset); - const local_header = try file.reader().readStructEndian(std.zip.LocalFileHeader, .little); - - if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig)) - return error.ZipBadFileOffset; - if (local_header.compressed_size != 0 and - local_header.compressed_size != entry.compressed_size) - return error.ZipMismatchCompLen; - if (local_header.uncompressed_size != 0 and - local_header.uncompressed_size != entry.uncompressed_size) - return error.ZipMismatchUncompLen; - if (local_header.filename_len != entry.filename_len) - return error.ZipMismatchFilenameLen; - - const start = file_offset + - @sizeOf(std.zip.LocalFileHeader) + - @as(u64, local_header.filename_len) + - @as(u64, local_header.extra_len); - return self.data.buffer_file.mappedSlice(start, entry.uncompressed_size); - } - - fn parseDims(values: []Value) error{InvalidInput}!zml.Shape.DimsArray { - zml.meta.assert(values.len <= zml.Tensor.MAX_RANK, "Found Pytorch tensor with unsupported rank {}", .{values.len}); - var result: zml.Shape.DimsArray = .{}; - for (values) |val| { - switch (val) { - .int64 => |d| result.appendAssumeCapacity(d), - else => return error.InvalidInput, - } - } - return result; - } -}; - -/// Convert from a torch.Storage to a `zml.DataType`. -/// TODO: make this future proof, storage type are going to get replaced with torch.UntypedStorage -/// See https://pytorch.org/docs/stable/storage.html -fn storageToDtype(storage_type: []const u8) !zml.DataType { - const torch_type = storage_type[0 .. storage_type.len - "Storage".len]; - const map = std.StaticStringMap(zml.DataType).initComptime(.{ - .{ "Double", .f64 }, - .{ "Float", .f32 }, - .{ "Half", .f16 }, - .{ "Long", .i64 }, - .{ "Int", .i32 }, - .{ "Short", .i16 }, - .{ "Char", .i8 }, - .{ "Byte", .u8 }, - .{ "Bool", .bool }, - .{ "BFloat16", .bf16 }, - .{ "ComplexDouble", .c128 }, - .{ "ComplexFloat", .c64 }, - // QUInt8Storage - // QInt8Storage - // QInt32Storage - // QUInt4x2Storage - // QUInt2x4Storage - }); - - return map.get(torch_type) orelse { - log.err("Unsupported torch storage type: {s}", .{storage_type}); - return error.UnsupportedDataType; - }; -} diff --git a/zml/aio/torch/b_tree_map.zig b/zml/aio/torch/b_tree_map.zig deleted file mode 100644 index 19cc8bf..0000000 --- a/zml/aio/torch/b_tree_map.zig +++ /dev/null @@ -1,652 +0,0 @@ -const std = @import("std"); - -/// BTreeMap Node implementation. -pub fn NodeType(comptime K: type, comptime V: type, comptime B: u32) type { - return struct { - const Self = @This(); - keys: [2 * B - 1]K = [_]K{undefined} ** (2 * B - 1), - values: [2 * B - 1]V = [_]V{undefined} ** (2 * B - 1), - len: usize = 0, - edges: [2 * B]?*Self = [_]?*Self{null} ** (2 * B), - - pub const KV = struct { key: K, value: V }; - const KVE = struct { key: K, value: V, edge: ?*Self }; - const Entry = struct { key_ptr: *K, value_ptr: *V }; - - /// Initializes an empty Node. - pub fn initEmpty(allocator: std.mem.Allocator) !*Self { - const res: *Self = try allocator.create(Self); - res.* = .{}; - return res; - } - - /// Initializes a Node with a single Entry. - pub fn initKeyValue(allocator: std.mem.Allocator, entry: struct { K, V }) !*Self { - const key, const value = entry; - var res = try Self.initEmpty(allocator); - res.keys[0] = key; - res.values[0] = value; - res.len = 1; - return res; - } - - fn initFromSplit(allocator: std.mem.Allocator, keys: []K, values: []V, edges: []?*Self) !*Self { - var out = try Self.initEmpty(allocator); - std.mem.copyBackwards(K, out.keys[0..], keys); - std.mem.copyBackwards(V, out.values[0..], values); - std.mem.copyBackwards(?*Self, out.edges[0..], edges); - out.len = keys.len; - return out; - } - - pub fn count(self: Self) usize { - var len: usize = self.len; - for (0..self.len + 1) |i| { - if (!self.isLeaf()) { - len += self.edges[i].?.count(); - } - } - return len; - } - - // Searches the Node for a key. - pub fn search(self: Self, key: K) std.meta.Tuple(&.{ bool, usize }) { - var i: usize = 0; - while (i < self.len) : (i += 1) { - if (eql(key, self.keys[i])) { - return .{ true, i }; - } else if (lt(key, self.keys[i])) { - return .{ false, i }; - } - } - return .{ false, self.len }; - } - - pub fn insertOrSplit( - self: *Self, - allocator: std.mem.Allocator, - index: usize, - key: K, - value: V, - edge: ?*Self, - ) !?KVE { - if (self.isFull()) { - var split_result = try self.split(allocator); - switch (index < B) { - true => self.insert(index, key, value, edge), - false => split_result.edge.?.insert(index - B, key, value, edge), - } - return split_result; - } - self.insert(index, key, value, edge); - return null; - } - - pub fn swapValue(self: *Self, index: usize, value: V) V { - const out = self.values[index]; - self.values[index] = value; - return out; - } - - pub fn swapKeyValue(self: *Self, index: usize, key: K, value: V) KV { - const out = .{ .key = self.keys[index], .value = self.values[index] }; - self.values[index] = value; - self.keys[index] = key; - return out; - } - - pub fn orderedRemove(self: *Self, index: usize) KVE { - const out: KVE = .{ - .key = self.keys[index], - .value = self.values[index], - .edge = self.edges[index + 1], - }; - std.mem.copyForwards(K, self.keys[index..], self.keys[index + 1 .. self.len]); - std.mem.copyForwards(V, self.values[index..], self.values[index + 1 .. self.len]); - self.keys[self.len - 1] = undefined; - self.values[self.len - 1] = undefined; - if (!self.isLeaf()) { - std.mem.copyForwards(?*Self, self.edges[index + 1 ..], self.edges[index + 2 .. self.len + 1]); - self.edges[self.len] = null; - } - self.len -= 1; - return out; - } - - fn pop(self: *Self) KVE { - return self.orderedRemove(self.len - 1); - } - - fn shift(self: *Self) KVE { - const out: KVE = .{ - .key = self.keys[0], - .value = self.values[0], - .edge = self.edges[0], - }; - std.mem.copyForwards(K, self.keys[0..], self.keys[1..self.len]); - std.mem.copyForwards(V, self.values[0..], self.values[1..self.len]); - self.keys[self.len - 1] = undefined; - self.values[self.len - 1] = undefined; - if (!self.isLeaf()) { - std.mem.copyForwards( - ?*Self, - self.edges[0..], - self.edges[1 .. self.len + 1], - ); - self.edges[self.len] = null; - } - self.len -= 1; - return out; - } - - fn insert(self: *Self, index: usize, key: K, value: V, edge: ?*Self) void { - std.mem.copyBackwards( - K, - self.keys[index + 1 .. self.len + 1], - self.keys[index..self.len], - ); - self.keys[index] = key; - std.mem.copyBackwards(V, self.values[index + 1 .. self.len + 1], self.values[index..self.len]); - self.values[index] = value; - if (!self.isLeaf()) { - std.mem.copyBackwards(?*Self, self.edges[index + 2 .. self.len + 2], self.edges[index + 1 .. self.len + 1]); - self.edges[index + 1] = edge; - } - self.len += 1; - } - - fn append(self: *Self, key: K, value: V, edge: ?*Self) void { - self.keys[self.len] = key; - self.values[self.len] = value; - self.edges[self.len + 1] = edge; - self.len += 1; - } - - fn unshift(self: *Self, key: K, value: V, edge: ?*Self) void { - std.mem.copyBackwards(K, self.keys[1 .. self.len + 1], self.keys[0..self.len]); - self.keys[0] = key; - std.mem.copyBackwards(V, self.values[1 .. self.len + 1], self.values[0..self.len]); - self.values[0] = value; - if (!self.isLeaf()) { - std.mem.copyBackwards(?*Self, self.edges[1 .. self.len + 2], self.edges[0 .. self.len + 1]); - self.edges[0] = edge; - } - self.len += 1; - } - - pub fn borrowRight(self: *Self, index: usize) bool { - if (index == self.len) return false; - var from = self.edges[index + 1].?; - if (from.len > B - 1) { - var to = self.edges[index].?; - const borrowed = from.shift(); - to.append(self.keys[index], self.values[index], borrowed.edge); - _ = self.swapKeyValue(index, borrowed.key, borrowed.value); - return true; - } - return false; - } - - pub fn borrowLeft(self: *Self, index: usize) bool { - if (index == 0) return false; - var from = self.edges[index - 1].?; - if (from.len > B - 1) { - var to = self.edges[index].?; - const borrowed = from.pop(); - to.unshift(self.keys[index - 1], self.values[index - 1], borrowed.edge); - _ = self.swapKeyValue(index - 1, borrowed.key, borrowed.value); - return true; - } - return false; - } - - pub fn mergeEdges(self: *Self, allocator: std.mem.Allocator, left_edge_index: usize) void { - var left = self.edges[left_edge_index].?; - const removed = self.orderedRemove(left_edge_index); - left.append(removed.key, removed.value, null); - std.mem.copyBackwards(K, left.keys[left.len..], removed.edge.?.keys[0..removed.edge.?.len]); - std.mem.copyBackwards(V, left.values[left.len..], removed.edge.?.values[0..removed.edge.?.len]); - std.mem.copyBackwards(?*Self, left.edges[left.len..], removed.edge.?.edges[0 .. removed.edge.?.len + 1]); - left.len += removed.edge.?.len; - allocator.destroy(removed.edge.?); - } - - fn split(self: *Self, allocator: std.mem.Allocator) !KVE { - const median = B - 1; - const new_key = self.keys[median]; - const new_value = self.values[median]; - const new_node = try Self.initFromSplit( - allocator, - self.keys[median + 1 .. self.len], - self.values[median + 1 .. self.len], - self.edges[median + 1 .. self.len + 1], - ); - @memset(self.keys[median..], undefined); - @memset(self.values[median..], undefined); - @memset(self.edges[median + 1 ..], null); - self.len = median; - return .{ .key = new_key, .value = new_value, .edge = new_node }; - } - - pub fn isLeaf(self: Self) bool { - return self.edges[0] == null; - } - - pub fn isFull(self: Self) bool { - return self.len == 2 * B - 1; - } - - pub fn isLacking(self: Self) bool { - return self.len < B - 1; - } - }; -} - -pub fn BTreeMap(comptime K: type, comptime V: type) type { - return struct { - const Self = @This(); - - const B = 6; - const Node = NodeType(K, V, B); - const KV = Node.KV; - const SearchResult = std.meta.Tuple(&.{ bool, usize }); - const StackEntry = struct { node: *Node, index: usize }; - - allocator: std.mem.Allocator, - root: ?*Node = null, - - pub fn init(allocator: std.mem.Allocator) Self { - return .{ .allocator = allocator }; - } - - pub fn deinit(self: Self) !void { - if (self.root == null) return; - var stack = std.ArrayList(*Node).init(self.allocator); - defer stack.deinit(); - if (self.root) |root| { - try stack.append(root); - } - while (stack.popOrNull()) |node| { - if (!node.isLeaf()) { - for (0..node.len + 1) |i| { - try stack.append(node.edges[i].?); - } - } - self.allocator.destroy(node); - } - } - - pub fn count(self: Self) usize { - if (self.root == null) return 0; - var len: usize = 0; - if (self.root) |node| { - len += node.count(); - } - return len; - } - - pub fn isEmpty(self: *const Self) bool { - if (self.root == null) return true; - return self.root.?.len == 0; - } - - pub fn get(self: Self, key: K) ?V { - var current = self.root; - while (current) |node| { - const found, const index = node.search(key); - switch (found) { - true => return node.values[index], - false => current = node.edges[index], - } - } - return null; - } - - pub fn getPtr(self: Self, key: K) ?*V { - var current = self.root; - while (current) |node| { - const found, const index = node.search(key); - switch (found) { - true => return &node.values[index], - false => current = node.edges[index], - } - } - return null; - } - - pub fn fetchPut(self: *Self, key: K, value: V) !?KV { - if (self.root == null) { - self.root = try Node.initKeyValue(self.allocator, .{ key, value }); - return null; - } - var stack = std.ArrayList(StackEntry).init(self.allocator); - defer stack.deinit(); - var current = self.root; - var search_result: SearchResult = undefined; - while (current) |node| { - search_result = node.search(key); - if (search_result[0]) { - return .{ .key = key, .value = node.swapValue(search_result[1], value) }; - } - current = node.edges[search_result[1]]; - try stack.append(.{ .node = node, .index = search_result[1] }); - } - var stack_next: ?StackEntry = stack.pop(); - var split_result = try stack_next.?.node.insertOrSplit( - self.allocator, - stack_next.?.index, - key, - value, - null, - ); - if (split_result == null) { - return null; - } - stack_next = stack.popOrNull(); - while (split_result) |split_result_unwrapped| { - if (stack_next) |stack_next_unwrapped| { - split_result = try stack_next_unwrapped.node.insertOrSplit( - self.allocator, - stack_next_unwrapped.index, - split_result_unwrapped.key, - split_result_unwrapped.value, - split_result_unwrapped.edge, - ); - stack_next = stack.popOrNull(); - } else { - var new_root = try Node.initKeyValue( - self.allocator, - .{ split_result_unwrapped.key, split_result_unwrapped.value }, - ); - new_root.edges[0] = self.root; - new_root.edges[1] = split_result_unwrapped.edge; - self.root = new_root; - return null; - } - } else return null; - } - - pub fn fetchRemove(self: *Self, key: K) !?KV { - var stack = std.ArrayList(StackEntry).init(self.allocator); - defer stack.deinit(); - var current = self.root; - var search_result: SearchResult = undefined; - var found_key_ptr: ?*K = null; - var found_value_ptr: ?*V = null; - while (current) |node| { - search_result = node.search(key); - if (search_result[0]) { - found_key_ptr = &node.keys[search_result[1]]; - found_value_ptr = &node.values[search_result[1]]; - if (!node.isLeaf()) search_result[1] += 1; - } - try stack.append(.{ - .node = node, - .index = search_result[1], - }); - current = node.edges[search_result[1]]; - if (search_result[0]) break; - } else return null; - while (current) |node| { - try stack.append(.{ .node = node, .index = 0 }); - current = node.edges[0]; - } - var current_stack = stack.pop(); - const out: KV = .{ .key = found_key_ptr.?.*, .value = found_value_ptr.?.* }; - found_key_ptr.?.* = current_stack.node.keys[current_stack.index]; - found_value_ptr.?.* = current_stack.node.values[current_stack.index]; - _ = current_stack.node.orderedRemove(current_stack.index); - if (current_stack.node == self.root) return out; - while (current_stack.node.isLacking()) { - current_stack = stack.pop(); - if (current_stack.node.borrowRight(current_stack.index)) return out; - if (current_stack.node.borrowLeft(current_stack.index)) return out; - if (current_stack.index == current_stack.node.len) { - current_stack.node.mergeEdges(self.allocator, current_stack.index - 1); - } else { - current_stack.node.mergeEdges(self.allocator, current_stack.index); - } - if (current_stack.node == self.root) { - if (self.root.?.len == 0) { - const new_root = current_stack.node.edges[0].?; - self.allocator.destroy(self.root.?); - self.root.? = new_root; - } - break; - } - } - return out; - } - - const Iterator = struct { - stack: std.ArrayList(StackEntry), - backwards: bool, - - pub fn deinit(it: Iterator) void { - it.stack.deinit(); - } - - pub fn next(it: *Iterator) ?Node.Entry { - while (it.topStackItem()) |item| { - if (!item.node.isLeaf() and !it.backwards) { - const child = item.node.edges[item.index].?; - it.stack.append(StackEntry{ .node = child, .index = 0 }) catch unreachable; - } else { - if (item.index < item.node.len) { - const out: Node.Entry = .{ .key_ptr = &item.node.keys[item.index], .value_ptr = &item.node.values[item.index] }; - item.index += 1; - it.backwards = false; - return out; - } else { - _ = it.stack.popOrNull(); - it.backwards = true; - } - } - } else return null; - } - - fn topStackItem(it: *Iterator) ?*StackEntry { - return switch (it.stack.items.len) { - 0 => null, - else => &it.stack.items[it.stack.items.len - 1], - }; - } - }; - - pub fn iterator(self: *const Self) Iterator { - var new_stack = std.ArrayList(StackEntry).init(self.allocator); - if (self.root) |root| { - new_stack.append(.{ .node = root, .index = 0 }) catch unreachable; - } - return Iterator{ - .stack = new_stack, - .backwards = false, - }; - } - }; -} - -/// Compares two of any type for equality. Containers are compared on a field-by-field basis, -/// where possible. Pointers are followed if the addresses are not equal. -fn eql(a: anytype, b: @TypeOf(a)) bool { - const T = @TypeOf(a); - switch (@typeInfo(T)) { - .Struct => |info| { - inline for (info.fields) |field_info| { - if (!eql(@field(a, field_info.name), @field(b, field_info.name))) return false; - } - return true; - }, - .ErrorUnion => { - if (a) |a_p| { - if (b) |b_p| return eql(a_p, b_p) else |_| return false; - } else |a_e| { - if (b) |_| return false else |b_e| return a_e == b_e; - } - }, - .Union => |info| { - if (info.tag_type) |UnionTag| { - const tag_a = std.meta.activeTag(a); - const tag_b = std.meta.activeTag(b); - if (tag_a != tag_b) return false; - - inline for (info.fields) |field_info| { - if (@field(UnionTag, field_info.name) == tag_a) { - return eql(@field(a, field_info.name), @field(b, field_info.name)); - } - } - return false; - } - - @compileError("Cannot compare untagged union type " ++ @typeName(T)); - }, - .Array => { - if (a.len != b.len) return false; - for (a, 0..) |e, i| - if (!eql(e, b[i])) return false; - return true; - }, - .Vector => |info| { - var i: usize = 0; - while (i < info.len) : (i += 1) { - if (!eql(a[i], b[i])) return false; - } - return true; - }, - .Pointer => |info| { - return switch (info.size) { - .One => if (a == b) true else eql(a.*, b.*), - .Many => if (a == b) true else { - if (info.sentinel) { - if (std.mem.len(a) != std.mem.len(b)) return false; - var i: usize = 0; - while (i < std.mem.len(a)) : (i += 1) - if (!eql(a[i], b[i])) return false; - return true; - } - @compileError("Cannot compare many-item Pointers without sentinel value"); - }, - .C => if (a == b) true else @compileError("Cannot compare C pointers"), - .Slice => if (a.ptr == b.ptr and a.len == b.len) true else { - if (a.len != b.len) return false; - for (a, 0..) |_, i| - if (!eql(a[i], b[i])) return false; - return true; - }, - }; - }, - .Optional => { - if (a == null and b == null) return true; - if (a == null or b == null) return false; - return eql(a.?, b.?); - }, - else => return a == b, - } -} - -fn lt(a: anytype, b: @TypeOf(a)) bool { - const T = @TypeOf(a); - - switch (@typeInfo(T)) { - .Int, .ComptimeInt, .Float, .ComptimeFloat => { - return a < b; - }, - .Struct => { - if (!@hasDecl(T, "lt")) { - @compileError("Type `" ++ @typeName(T) ++ "` must implement a `lt` comparison method."); - } - return T.lt(a, b); - }, - .Union => |info| { - if (info.tag_type) |UnionTag| { - const tag_a = std.meta.activeTag(a); - const tag_b = std.meta.activeTag(b); - // if tags are not equal, perform comparison based on tag - if (tag_a != tag_b) { - return std.ascii.lessThanIgnoreCase(@tagName(tag_a), @tagName(tag_b)); - } - // if tags are equal, compare based on the active field - inline for (info.fields) |field_info| { - if (@field(UnionTag, field_info.name) == tag_a) { - return lt(@field(a, field_info.name), @field(b, field_info.name)); - } - } - return false; - } - - @compileError("Cannot perform `lt` check on untagged union type " ++ @typeName(T)); - }, - .Array => { - for (a, 0..) |_, i| { - if (lt(a[i], b[i])) { - return true; - } else if (eql(a[i], b[i])) { - continue; - } else { - return false; - } - } - return false; - }, - .Vector => |info| { - var i: usize = 0; - while (i < info.len) : (i += 1) { - if (lt(a[i], b[i])) { - return true; - } else if (eql(a[i], b[i])) { - continue; - } else { - return false; - } - } - return false; - }, - .Pointer => |info| { - switch (info.size) { - .One => return lt(a.*, b.*), - .Slice => { - const n = @min(a.len, b.len); - for (a[0..n], 0..) |_, i| { - if (lt(a[i], b[i])) { - return true; - } else if (eql(a[i], b[i])) { - continue; - } else { - return false; - } - } - return lt(a.len, b.len); - }, - .Many => { - if (info.sentinel) { - const n = @min(std.mem.len(a), std.mem.len(b)); - var i: usize = 0; - while (i < n) : (i += 1) { - if (lt(a[i], b[i])) { - return true; - } else if (eql(a[i], b[i])) { - continue; - } else { - return false; - } - } - return lt(std.mem.len(a), std.mem.len(b)); - } - @compileError("Cannot compare many-item pointer to unknown number of items without sentinel value"); - }, - .C => @compileError("Cannot compare C pointers"), - } - }, - .Optional => { - if (a == null or b == null) return false; - return lt(a.?, b.?); - }, - else => { - @compileError("Cannot compare type '" ++ @typeName(T) ++ "'"); - }, - } -} - -pub fn gt(a: anytype, b: @TypeOf(a)) bool { - return !lt(a, b) and !eql(a, b); -} diff --git a/zml/aio/torch/eval.zig b/zml/aio/torch/eval.zig index 175febe..d62fe3c 100644 --- a/zml/aio/torch/eval.zig +++ b/zml/aio/torch/eval.zig @@ -2,42 +2,33 @@ const std = @import("std"); const zml = @import("../../zml.zig"); const meta = zml.meta; -const value = @import("value.zig"); +const py = @import("py.zig"); const pickle = @import("pickle.zig"); -const BTreeMap = @import("b_tree_map.zig").BTreeMap; - -const Build = value.Build; -const Object = value.Object; -const PersId = value.PersId; -const Sequence = value.Sequence; -const SequenceType = value.SequenceType; -const Value = value.Value; const MAX_DEPTH: usize = 250; const MAX_PROTOCOL: u8 = 5; pub const PickleMemo = struct { - allocator: std.mem.Allocator, - map: BTreeMap(u32, Value), + map: std.AutoHashMap(u32, py.Any), pub fn init(allocator: std.mem.Allocator) PickleMemo { return .{ - .allocator = allocator, - .map = BTreeMap(u32, Value).init(allocator), + .map = std.AutoHashMap(u32, py.Any).init(allocator), }; } pub fn deinit(self: *PickleMemo) void { + const allocator = self.map.allocator; var iterator = self.map.iterator(); defer iterator.deinit(); while (iterator.next()) |entry| { - entry.value_ptr.deinit(self.allocator); + entry.value_ptr.deinit(allocator); } self.map.deinit() catch unreachable; self.* = undefined; } - pub fn resolve(self: *PickleMemo, allocator: std.mem.Allocator, op: Value, recursive: bool) !Value { + pub fn resolve(self: *PickleMemo, allocator: std.mem.Allocator, op: py.Any, recursive: bool) !py.Any { var used_op = op; while (used_op == .ref) { var count: usize = 0; @@ -67,12 +58,12 @@ pub const PickleMemo = struct { } } }, - .build => |v| { - if (v.member.containsRef()) { - v.member = try self.resolve(allocator, v.member, recursive); + .set_state => |v| { + if (v.obj.containsRef()) { + v.obj = try self.resolve(allocator, v.obj, recursive); } - if (v.args.containsRef()) { - v.args = try self.resolve(allocator, v.args, recursive); + if (v.state.containsRef()) { + v.state = try self.resolve(allocator, v.state, recursive); } }, .pers_id => |v| { @@ -93,11 +84,11 @@ pub const PickleMemo = struct { return used_op; } - pub fn insert(self: *PickleMemo, mid: u32, val: Value) !void { + pub fn insert(self: *PickleMemo, mid: u32, val: py.Any) !void { _ = try self.map.fetchPut(mid, val); } - pub fn resolveMut(self: *PickleMemo, op: *Value, recursive: bool) !*Value { + pub fn resolveMut(self: *PickleMemo, op: *py.Any, recursive: bool) !*py.Any { if (op.* != .ref) return op; var lastmid = op.ref; var count: usize = 0; @@ -122,34 +113,35 @@ pub const PickleMemo = struct { }); } - const MemoError = std.math.big.int.Managed.ConvertError || std.mem.Allocator.Error || error{BadMemoRef}; + const MemoError = py.Any.UnpickleError || error{BadMemoRef}; - pub fn resolveAllRefsIter(self: *PickleMemo, allocator: std.mem.Allocator, depth: usize, vals: []Value, fix_values: bool) MemoError![]Value { + pub fn resolveAllRefsIter(self: *PickleMemo, allocator: std.mem.Allocator, depth: usize, vals: []py.Any, fix_values: bool) MemoError![]py.Any { if (depth >= MAX_DEPTH) { return vals; } - const res = try allocator.alloc(Value, vals.len); + const res = try allocator.alloc(py.Any, vals.len); for (vals, 0..) |v, i| { res[i] = try self.resolveAllRefs(allocator, depth + 1, v, fix_values); } return res; } - pub fn resolveAllRefs(self: *PickleMemo, allocator: std.mem.Allocator, depth: usize, val: Value, fix_values: bool) !Value { - var output: Value = switch (val) { + pub fn resolveAllRefs(self: *PickleMemo, allocator: std.mem.Allocator, depth: usize, val: py.Any, fix_values: bool) !py.Any { + var output: py.Any = switch (val) { .ref => try self.resolve(allocator, val, true), - inline .app, .object, .global => |v, tag| @unionInit(Value, @tagName(tag), try Object.init( + inline .app, .object, .global => |v, tag| @unionInit(py.Any, @tagName(tag), try py.Object.init( allocator, try self.resolveAllRefs(allocator, depth + 1, v.member, fix_values), try self.resolveAllRefsIter(allocator, depth + 1, v.args, fix_values), + try self.resolveAllRefsIter(allocator, depth + 1, v.kwargs, fix_values), )), - .build => |v| .{ .build = try Build.init( + .set_state => |v| .{ .set_state = try py.SetState.init( allocator, - try self.resolveAllRefs(allocator, depth + 1, v.member, fix_values), - try self.resolveAllRefs(allocator, depth + 1, v.args, fix_values), + try self.resolveAllRefs(allocator, depth + 1, v.obj, fix_values), + try self.resolveAllRefs(allocator, depth + 1, v.state, fix_values), ) }, .seq => |v| .{ .seq = .{ .type = v.type, .values = try self.resolveAllRefsIter(allocator, depth + 1, v.values, fix_values) } }, - .pers_id => |v| .{ .pers_id = try PersId.init(allocator, try self.resolveAllRefs(allocator, depth + 1, v.ref, fix_values)) }, + .pers_id => |v| .{ .pers_id = try py.PersId.init(allocator, try self.resolveAllRefs(allocator, depth + 1, v.ref, fix_values)) }, else => try val.clone(allocator), }; if (fix_values) { @@ -159,29 +151,9 @@ pub const PickleMemo = struct { } }; -pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) ![]const Value { - var stack = std.ArrayList(Value).init(arena); +pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) ![]const py.Any { + var stack = std.ArrayList(py.Any).init(arena); var memo = PickleMemo.init(arena); - errdefer memo.deinit(); - - const makeKVList = (struct { - pub fn call(alloc: std.mem.Allocator, items: []const Value) ![]Value { - meta.assert(items.len & 1 == 0, "Bad value for setitems", .{}); - var kv_items = try std.ArrayList(Value).initCapacity(alloc, items.len); - errdefer kv_items.deinit(); - var idx: usize = 0; - while (idx < items.len) : (idx += 2) { - if (idx + 1 >= items.len) { - return error.MissingValueItem; - } - const kv = try alloc.alloc(Value, 2); - kv[0] = items[idx]; - kv[1] = items[idx + 1]; - kv_items.appendAssumeCapacity(.{ .seq = .{ .type = .kv_tuple, .values = kv } }); - } - return kv_items.toOwnedSlice(); - } - }).call; for (x) |op| { switch (op) { @@ -189,47 +161,50 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo .frame => {}, .stop => break, .pop => _ = try pop(&stack), - .pop_mark => try popMarkDiscard(&stack), + .pop_mark => _ = try popMark(&stack), .dup => if (stack.getLastOrNull()) |item| try stack.append(try item.clone(arena)) else return error.CannotDupEmptyStack, - .persid => |v| try stack.append(.{ .pers_id = try PersId.init(arena, .{ .string = try arena.dupe(u8, v) }) }), - .binpersid => try stack.append(.{ .pers_id = try PersId.init(arena, try pop(&stack)) }), + .persid => |v| try stack.append(.{ .pers_id = try py.PersId.init(arena, .{ .string = try arena.dupe(u8, v) }) }), + .binpersid => try stack.append(.{ .pers_id = try py.PersId.init(arena, try pop(&stack)) }), .reduce => try stack.append(.{ .global = blk: { - const values = try arena.alloc(Value, 1); - values[0] = try memo.resolve(arena, try pop(&stack), true); - break :blk try Object.init(arena, try memo.resolve(arena, try pop(&stack), true), values); + var args = try pop(&stack); + args = try memo.resolve(arena, args, true); + if (args != .seq) return error.InvalidInput; + var func = try pop(&stack); + func = try memo.resolve(arena, func, true); + break :blk try py.Object.init(arena, func, args.seq.values, &.{}); } }), .build => try stack.append(blk: { const args = try memo.resolve(arena, try pop(&stack), true); const member = try memo.resolve(arena, try pop(&stack), true); - break :blk .{ .build = try Build.init(arena, member, args) }; + break :blk .{ .set_state = try py.SetState.init(arena, member, args) }; }), - .empty_dict => try stack.append(.{ .seq = .{ .type = .dict, .values = &[_]Value{} } }), + .empty_dict => try stack.append(.{ .seq = .{ .type = .dict, .values = &[_]py.Any{} } }), .get => |v| try stack.append(.{ .ref = v }), - .empty_list => try stack.append(.{ .seq = .{ .type = .list, .values = &[_]Value{} } }), + .empty_list => try stack.append(.{ .seq = .{ .type = .list, .values = &[_]py.Any{} } }), .put => |v| { try memo.insert(v, try pop(&stack)); try stack.append(.{ .ref = v }); }, .tuple => try stack.append(blk: { - const popped = try popMark(&stack, arena); - break :blk .{ .seq = .{ .type = .tuple, .values = popped } }; + const popped = try popMark(&stack); + break :blk .{ .seq = .{ .type = .tuple, .values = try arena.dupe(py.Any, popped) } }; }), - .empty_tuple => try stack.append(.{ .seq = .{ .type = .tuple, .values = &[_]Value{} } }), + .empty_tuple => try stack.append(.{ .seq = .{ .type = .tuple, .values = &[_]py.Any{} } }), .setitem => { - const v, const k = .{ try pop(&stack), try pop(&stack) }; + const v = try memo.resolve(arena, try pop(&stack), true); + const k = try memo.resolve(arena, try pop(&stack), true); const top = try lastMut(&stack); const rtop = try memo.resolveMut(top, true); switch (rtop.*) { .global => |obj| { - obj.args = try assuredResize(Value, arena, obj.args, obj.args.len + 1); - obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ k, v }) } }; + try append(arena, &obj.kwargs, &.{ k, v }); }, - .seq => |*tup| { - tup.values = try assuredResize(Value, arena, tup.values, tup.values.len + 1); - tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ k, v }) } }; + .seq => |*dict| { + if (dict.type != .dict) return error.BadStackTopForSetItem; + try append(arena, &dict.values, &.{ k, v }); }, else => { return error.BadStackTopForSetItem; @@ -237,39 +212,35 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo } }, .setitems => { - const popped = try popMark(&stack, arena); - defer arena.free(popped); - const kv_items = try makeKVList(arena, popped); + const popped = try memo.resolveAllRefsIter(arena, 0, try popMark(&stack), true); const top = try lastMut(&stack); const rtop = try memo.resolveMut(top, true); switch (rtop.*) { .global => |obj| { - obj.args = try assuredResize(Value, arena, obj.args, obj.args.len + 1); - obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } }; + try append(arena, &obj.kwargs, popped); }, - .seq => |*tup| { - tup.values = try assuredResize(Value, arena, tup.values, tup.values.len + 1); - tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } }; + .seq => |*dict| { + if (dict.type != .dict) return error.BadStackTopForSetItems; + try append(arena, &dict.values, popped); }, else => { - defer arena.free(kv_items); return error.BadStackTopForSetItems; }, } }, .proto => |proto| meta.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}), .tuple1 => try stack.append(blk: { - const tup_values = try arena.alloc(Value, 1); + const tup_values = try arena.alloc(py.Any, 1); tup_values[0] = try pop(&stack); break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; }), .tuple2 => try stack.append(blk: { - const tup_values = try arena.alloc(Value, 2); + const tup_values = try arena.alloc(py.Any, 2); inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack); break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; }), .tuple3 => try stack.append(blk: { - const tup_values = try arena.alloc(Value, 3); + const tup_values = try arena.alloc(py.Any, 3); inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack); break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; }), @@ -279,12 +250,12 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo const rtop = try memo.resolveMut(top, true); switch (rtop.*) { .global => |obj| { - obj.args = try assuredResize(Value, arena, obj.args, obj.args.len + 1); - obj.args[obj.args.len - 1] = v; + // can this happen ? + try append(arena, &obj.args, &.{v}); }, - .seq => |*tup| { - tup.values = try assuredResize(Value, arena, tup.values, tup.values.len + 1); - tup.values[tup.values.len - 1] = v; + .seq => |*seq| { + if (seq.type != .list) return error.BadStackTopForAppend; + try append(arena, &seq.values, &.{v}); }, else => { return error.BadStackTopForAppend; @@ -292,83 +263,75 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo } }, .appends => { - const postmark = try popMark(&stack, arena); - defer arena.free(postmark); + const postmark = try popMark(&stack); const top = try lastMut(&stack); const rtop = try memo.resolveMut(top, true); switch (rtop.*) { - .global => |obj| { - const obj_len = obj.args.len; - obj.args = try assuredResize(Value, arena, obj.args, obj_len + postmark.len); - @memcpy(obj.args[obj_len..], postmark); - }, - .seq => |*tup| { - const tup_len = tup.values.len; - tup.values = try assuredResize(Value, arena, tup.values, tup_len + postmark.len); - @memcpy(tup.values[tup_len..], postmark); + .global => try append(arena, &rtop.global.args, postmark), + .seq => |*seq| { + if (seq.type != .list) return error.BadStackTopForAppend; + try append(arena, &seq.values, postmark); }, else => { return error.BadStackTopForAppends; }, } }, - .dict => try stack.append(blk: { - const popped = try popMark(&stack, arena); - defer arena.free(popped); - const kv_items = try makeKVList(arena, popped); - break :blk .{ .seq = .{ .type = .dict, .values = kv_items } }; - }), - .list => try stack.append(.{ .seq = .{ .type = .list, .values = try popMark(&stack, arena) } }), - .inst => |v| try stack.append(blk: { - const tup_items = try arena.dupe(Value, &.{ .{ .string = v.module }, .{ .string = v.class } }); - break :blk .{ .object = try Object.init(arena, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try popMark(&stack, arena)) }; - }), + .dict => try stack.append(.{ .seq = .{ + .type = .dict, + .values = try arena.dupe(py.Any, try popMark(&stack)), + } }), + .list => try stack.append(.{ .seq = .{ + .type = .list, + .values = try arena.dupe(py.Any, try popMark(&stack)), + } }), + .inst => |v| try stack.append(.{ .object = try py.Object.init( + arena, + try py.tuple(&.{ .{ .string = v.module }, .{ .string = v.class } }).clone(arena), + try arena.dupe(py.Any, try popMark(&stack)), + &.{}, + ) }), .obj => try stack.append(blk: { const mark = try findMark(&stack); - const args = try arena.dupe(Value, stack.items[mark + 2 ..]); + const args = try arena.dupe(py.Any, stack.items[mark + 2 ..]); const member = stack.items[mark + 1]; - break :blk .{ .object = try Object.init(arena, member, args) }; + break :blk .{ .object = try py.Object.init(arena, member, args, &.{}) }; }), .newobj => try stack.append(blk: { - const args = try arena.alloc(Value, 1); + const args = try arena.alloc(py.Any, 1); args[0] = try pop(&stack); - break :blk .{ .object = try Object.init(arena, try pop(&stack), args) }; + break :blk .{ .object = try py.Object.init(arena, try pop(&stack), args, &.{}) }; }), - .empty_set => try stack.append(.{ .seq = .{ .type = .set, .values = &[_]Value{} } }), + .empty_set => try stack.append(.{ .seq = .{ .type = .set, .values = &[_]py.Any{} } }), .additems => { - const postmark = try popMark(&stack, arena); - defer arena.free(postmark); + const postmark = try popMark(&stack); const top = try lastMut(&stack); const rtop = try memo.resolveMut(top, true); switch (rtop.*) { - .global => |obj| { - const obj_len = obj.args.len; - obj.args = try assuredResize(Value, arena, obj.args, obj_len + postmark.len); - @memcpy(obj.args[obj_len..], postmark); - }, - .seq => |*tup| { - const tup_len = tup.values.len; - tup.values = try assuredResize(Value, arena, tup.values, tup_len + postmark.len); - @memcpy(tup.values[tup_len..], postmark); + .seq => |*seq| { + if (seq.type != .set) return error.BadStackTopForAppend; + try append(arena, &seq.values, postmark); }, else => { - return error.BadStackTopForSetItem; + return error.BadStackTopForAppends; }, } }, - .frozenset => try stack.append(.{ .seq = .{ .type = .frozen_set, .values = try popMark(&stack, arena) } }), + .frozenset => try stack.append(.{ .seq = .{ + .type = .frozen_set, + .values = try arena.dupe(py.Any, try popMark(&stack)), + } }), .newobj_ex => try stack.append(blk: { const kwargs, const args, const cls = .{ try pop(&stack), try pop(&stack), try pop(&stack) }; - const new_seq: Sequence = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ args, kwargs }) }; - break :blk .{ .object = try Object.init(arena, cls, try arena.dupe(Value, &.{.{ .seq = new_seq }})) }; + break :blk .{ .object = try py.Object.init(arena, cls, args.seq.values, kwargs.seq.values) }; }), .stack_global => try stack.append(blk: { const gn, const mn = .{ try memo.resolve(arena, try pop(&stack), true), try memo.resolve(arena, try pop(&stack), true), }; - const new_seq: Sequence = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ gn, mn }) }; - break :blk .{ .object = try Object.init(arena, .{ .seq = new_seq }, &[_]Value{}) }; + const new_seq: py.Sequence = .{ .type = .tuple, .values = try arena.dupe(py.Any, &.{ gn, mn }) }; + break :blk .{ .object = try py.Object.init(arena, .{ .seq = new_seq }, &.{}, &.{}) }; }), .memoize => { const item = stack.getLastOrNull() orelse { @@ -385,23 +348,17 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo return stack.toOwnedSlice(); } -// TODO: this is a unmanaged array list, minus the optimisation. We should use that instead -fn assuredResize(comptime T: type, allocator: std.mem.Allocator, old: []T, new_length: usize) ![]T { - if (allocator.resize(old, new_length)) { - return old; - } else { - defer allocator.free(old); - const new = try allocator.alloc(T, new_length); - @memcpy(new[0..old.len], old); - return new; - } +fn append(allocator: std.mem.Allocator, current: *[]py.Any, values: []const py.Any) !void { + var array_list = std.ArrayListUnmanaged(py.Any).fromOwnedSlice(current.*); + try array_list.appendSlice(allocator, values); + current.* = array_list.items; } test evaluate { var arena = std.heap.ArenaAllocator.init(std.testing.allocator); defer arena.deinit(); const allocator = arena.allocator(); - const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test.pickle", .{ .mode = .read_only }); + const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only }); var buffered_reader = std.io.bufferedReader(file.reader()); const ops = try pickle.parse(allocator, buffered_reader.reader(), 4096); @@ -411,64 +368,62 @@ test evaluate { try std.testing.expect(vals.len == 1); try std.testing.expect(vals[0] == .seq); try std.testing.expect(vals[0].seq.type == .dict); - const entries = vals[0].seq.values[0].seq.values; - try std.testing.expect(entries.len == 5); - const expected: []const Value = &.{ - .{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .string = "hello" }, .{ .string = "world" } }) } }, - .{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .string = "int" }, .{ .int64 = 1 } }) } }, - .{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .string = "float" }, .{ .float64 = 3.141592 } }) } }, - .{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ - .{ .string = "list" }, - .{ .seq = .{ .type = .list, .values = @constCast(&[_]Value{ - .{ .int64 = 0 }, - .{ .int64 = 1 }, - .{ .int64 = 2 }, - .{ .int64 = 3 }, - .{ .int64 = 4 }, - }) } }, - }) } }, - .{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ - .{ .string = "tuple" }, - .{ .seq = .{ - .type = .tuple, - .values = @constCast(&[_]Value{ - .{ .string = "a" }, - .{ .int64 = 10 }, + const entries = vals[0].seq.values; + const expected: []const py.Any = &.{ + // Key, followed by its value + .{ .string = "hello" }, .{ .string = "world" }, + .{ .string = "int" }, .{ .int64 = 1 }, + .{ .string = "float" }, .{ .float64 = 3.141592 }, + .{ .string = "list" }, + .{ + .seq = .{ + .type = .list, + .values = @constCast(&[_]py.Any{ + .{ .int64 = 255 }, + .{ .int64 = 1234 }, + .{ .int64 = -123 }, + .{ .int64 = 1_000_000_000 }, + .{ .int64 = 999_000_000_000 }, + .{ .bigint = (try std.math.big.int.Managed.initSet(allocator, 999_000_000_000_000_000_000_000_000_000)).toConst() }, }), - } }, - }) } }, + }, + }, + .{ .string = "bool" }, .{ .boolval = false }, + .{ .string = "tuple" }, + .{ .seq = .{ + .type = .tuple, + .values = @constCast(&[_]py.Any{ + .{ .string = "a" }, + .{ .int64 = 10 }, + }), + } }, }; try std.testing.expectEqualDeep(expected, entries); } -pub fn pop(values: *std.ArrayList(Value)) !Value { +pub fn pop(values: *std.ArrayList(py.Any)) !py.Any { if (values.items.len == 0) { return error.StackUnderrun; } return values.pop(); } -fn popMarkDiscard(values: *std.ArrayList(Value)) !void { - const mark = try findMark(values); - values.shrinkRetainingCapacity(mark); -} - -fn popMark(values: *std.ArrayList(Value), allocator: std.mem.Allocator) ![]Value { +fn popMark(values: *std.ArrayList(py.Any)) ![]py.Any { const mark = try findMark(values); const popping = values.items[mark + 1 ..]; values.shrinkRetainingCapacity(mark); - return try allocator.dupe(Value, popping); + return popping; } -fn lastMut(values: *std.ArrayList(Value)) !*Value { +fn lastMut(values: *std.ArrayList(py.Any)) !*py.Any { if (values.items.len == 0) { return error.UnexpectedEmptyStack; } return &values.items[values.items.len - 1]; } -fn findMark(values: *std.ArrayList(Value)) !usize { +fn findMark(values: *std.ArrayList(py.Any)) !usize { const len = values.items.len; for (0..len) |i| { const idx = (len - 1) - i; diff --git a/zml/aio/torch/file.zig b/zml/aio/torch/file.zig new file mode 100644 index 0000000..afe2331 --- /dev/null +++ b/zml/aio/torch/file.zig @@ -0,0 +1,680 @@ +const std = @import("std"); +const testing = std.testing; +const log = std.log.scoped(.zml_aio); + +const asynk = @import("async"); + +const zml = @import("../../zml.zig"); +const pickle = @import("pickle.zig"); +const py = @import("py.zig"); +const eval = @import("eval.zig"); +const HostBuffer = zml.HostBuffer; + +// TODO(cryptodeal): use zml.aio.PrefixBuilder instead +const StringBuilder = std.ArrayListUnmanaged(u8); + +test { + std.testing.refAllDecls(@This()); + std.testing.refAllDecls(File); +} + +pub const File = struct { + buffer_file: zml.aio.MemoryMappedFile, + /// Map names to sub file + file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{}, + tar_file: ?TarStream = null, + is_zip_file: bool, + zip_prefix: []const u8 = &.{}, + pickle_subfile: struct { start: u64 = 0, len: usize }, + + pub const FileEntry = struct { + version_needed_to_extract: u16, + flags: u16, + compression_method: std.zip.CompressionMethod, + last_modification_time: u16, + last_modification_date: u16, + header_zip_offset: u64, + crc32: u32, + filename_len: u32, + compressed_size: u64, + uncompressed_size: u64, + file_offset: u64, + + pub fn init(entry: anytype) FileEntry { + return .{ + .version_needed_to_extract = entry.version_needed_to_extract, + .flags = @as(u16, @bitCast(entry.flags)), + .compression_method = entry.compression_method, + .last_modification_time = entry.last_modification_time, + .last_modification_date = entry.last_modification_date, + .header_zip_offset = entry.header_zip_offset, + .crc32 = entry.crc32, + .filename_len = entry.filename_len, + .compressed_size = entry.compressed_size, + .uncompressed_size = entry.uncompressed_size, + .file_offset = entry.file_offset, + }; + } + }; + + const magic = "PK\x03\x04"; + + pub fn fromTarFile(allocator: std.mem.Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !File { + const tar_file = try TarStream.init(file); + const file_magic = try tar_file.reader().readBytesNoEof(magic.len); + try tar_file.seekTo(0); + var res: File = .{ + .buffer_file = mapped, + .tar_file = tar_file, + .is_zip_file = std.mem.eql(u8, &file_magic, magic), + .pickle_subfile = .{ .len = try tar_file.getEndPos() }, + }; + if (res.is_zip_file) { + try res.parseZipHeaders(allocator, tar_file.seekableStream()); + } + return res; + } + + pub fn init(allocator: std.mem.Allocator, mmap_file: zml.aio.MemoryMappedFile) !File { + const file_magic = try mmap_file.file.reader().readBytesNoEof(magic.len); + try mmap_file.file.seekTo(0); + var res: File = .{ + .buffer_file = mmap_file, + .is_zip_file = std.mem.eql(u8, &file_magic, magic), + .pickle_subfile = .{ .len = mmap_file.data.len }, + }; + + if (res.is_zip_file) { + try res.parseZipHeaders(allocator, mmap_file.file.seekableStream()); + } + return res; + } + + pub fn close(self: *File) void { + self.buffer_file.deinit(); + } + + pub fn parsePickle(self: *File, allocator: std.mem.Allocator) ![]const pickle.Op { + return if (self.tar_file) |tar_file| { + try tar_file.seekTo(self.pickle_subfile.start); + var buffered = std.io.bufferedReader(tar_file.reader()); + return try pickle.parse(allocator, buffered.reader(), self.pickle_subfile.len); + } else { + const file = self.buffer_file.file; + try file.seekTo(self.pickle_subfile.start); + var buffered = std.io.bufferedReader(file.reader()); + return try pickle.parse(allocator, buffered.reader(), self.pickle_subfile.len); + }; + } + + fn parseZipHeaders(self: *File, allocator: std.mem.Allocator, seekable_stream: anytype) !void { + var file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{}; + + var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream); + var filename_buf: [std.fs.max_path_bytes]u8 = undefined; + while (try iter.next()) |entry| { + const filename = filename_buf[0..entry.filename_len]; + try seekable_stream.seekTo(entry.header_zip_offset + @sizeOf(std.zip.CentralDirectoryFileHeader)); + const len = try seekable_stream.context.reader().readAll(filename); + if (len != filename.len) return error.ZipBadFileOffset; + if (isBadFilename(filename)) return error.ZipBadFilename; + std.mem.replaceScalar(u8, filename, '\\', '/'); // normalize path separators + try file_map.put(allocator, try allocator.dupe(u8, filename), FileEntry.init(entry)); + } + + self.file_map = file_map; + var file_iter = file_map.iterator(); + while (file_iter.next()) |e| { + const entry = e.value_ptr.*; + const filename = e.key_ptr.*; + if (!std.mem.endsWith(u8, filename, "data.pkl")) continue; + + self.zip_prefix = filename[0 .. filename.len - "data.pkl".len]; + + const local_data_header_offset: u64 = local_data_header_offset: { + switch (entry.compression_method) { + .store => {}, + .deflate => { + // TODO(cryptodeal): handle decompress + @panic("TODO support use of `deflate`"); + }, + else => @panic("TODO support other modes of compression"), + } + const local_header = blk: { + try seekable_stream.seekTo(entry.file_offset); + break :blk try seekable_stream.context.reader().readStructEndian(std.zip.LocalFileHeader, .little); + }; + if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig)) + return error.ZipBadFileOffset; + if (local_header.version_needed_to_extract != entry.version_needed_to_extract) + return error.ZipMismatchVersionNeeded; + if (local_header.last_modification_time != entry.last_modification_time) + return error.ZipMismatchModTime; + if (local_header.last_modification_date != entry.last_modification_date) + return error.ZipMismatchModDate; + + if (@as(u16, @bitCast(local_header.flags)) != entry.flags) + return error.ZipMismatchFlags; + if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32) + return error.ZipMismatchCrc32; + if (local_header.compressed_size != 0 and + local_header.compressed_size != entry.compressed_size) + return error.ZipMismatchCompLen; + if (local_header.uncompressed_size != 0 and + local_header.uncompressed_size != entry.uncompressed_size) + return error.ZipMismatchUncompLen; + if (local_header.filename_len != entry.filename_len) + return error.ZipMismatchFilenameLen; + + break :local_data_header_offset @as(u64, local_header.filename_len) + + @as(u64, local_header.extra_len); + }; + + const local_data_file_offset: u64 = + @as(u64, entry.file_offset) + + @as(u64, @sizeOf(std.zip.LocalFileHeader)) + + local_data_header_offset; + self.pickle_subfile = .{ .start = local_data_file_offset, .len = entry.uncompressed_size }; + return; + } + + log.err("Could not find file ending in `data.pkl` in archive", .{}); + return error.PickleNotFound; + } + + fn basicTypeCheck(object: *const py.Object, module: []const u8, class: []const u8) bool { + return switch (object.member) { + .raw => |raw| return (std.mem.eql(u8, module, raw.global.module) and + std.mem.eql(u8, class, raw.global.class)), + else => false, + }; + } + + pub fn parseModel(self: File, values: []const py.Any, store: *zml.aio.BufferStore) !void { + var prefix_buf: [1024]u8 = undefined; + const allocator = store.arena.allocator(); + for (values) |item| { + try self.parseValue(allocator, store, StringBuilder.initBuffer(&prefix_buf), item); + } + } + + pub fn parseValue(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: py.Any) !void { + // log.warn("Parsing {}", .{v}); + switch (v) { + .app, .object, .global => |object| { + if (!(try self.parseTorchGlobal(allocator, store, prefix, v))) { + try self.parseValue(allocator, store, prefix, object.member); + for (object.args) |item| { + try self.parseValue(allocator, store, prefix, item); + } + if (object.kwargs.len % 2 != 0) return error.InvalidInput; + const n_kwargs = @divExact(object.kwargs.len, 2); + + for (0..n_kwargs) |i| { + const key, const val = object.kwargs[2 * i ..][0..2].*; + // kwargs can only be keyed by string. + if (key != .string) return error.InvalidInput; + // Handle Pytorch specific fields + const s = key.string; + if (std.mem.eql(u8, s, "_modules") or std.mem.eql(u8, s, "_parameters") or std.mem.eql(u8, s, "_buffers")) { + try self.parseValue(allocator, store, prefix, val); + } else { + var new_prefix = prefix; + if (prefix.items.len > 0) { + new_prefix.appendAssumeCapacity('.'); + } + new_prefix.appendSliceAssumeCapacity(s); + try self.parseValue(allocator, store, new_prefix, val); + } + } + } + }, + .set_state => |set_state| { + // `set_state` contains info about python struct being constructed + switch (set_state.obj) { + .object => |obj| switch (obj.member) { + .raw => |raw| switch (raw) { + .global => |global| { + // in this case, we can capture the name of the python type + // which can be used for codegen (e.g. `torch.nn.modules.conv.Conv2d`) + var new_prefix = prefix; + if (prefix.items.len > 0) { + new_prefix.appendAssumeCapacity('.'); + } + new_prefix.appendSliceAssumeCapacity("_gen_type_helper"); + const key = try allocator.dupe(u8, new_prefix.items); + const d = try store._metadata.getOrPut(allocator, key); + if (d.found_existing) { + log.err("Duplicate key: {s}", .{new_prefix.items}); + allocator.free(key); + } else { + const val = try std.mem.join(allocator, ".", &.{ global.module, global.class }); + d.value_ptr.* = .{ .string = val }; + } + }, + else => try self.parseValue(allocator, store, prefix, set_state.obj), // parse normally + }, + else => try self.parseValue(allocator, store, prefix, set_state.obj), // parse normally + }, + else => try self.parseValue(allocator, store, prefix, set_state.obj), // parse normally + } + try self.parseValue(allocator, store, prefix, set_state.state); + }, + .pers_id => |pers_id| try self.parseValue(allocator, store, prefix, pers_id.ref), + .seq => |seq| { + switch (seq.type) { + .list, .tuple, .set, .frozen_set => { + if (seq.values.len == 0) return; + var valid_slice = true; + switch (seq.values[0]) { + inline .int64, .float64, .boolval => |val0, tag| { + const ItemType = switch (tag) { + .int64 => i64, + .float64 => f64, + .boolval => bool, + else => unreachable, + }; + var values: std.ArrayListUnmanaged(ItemType) = .{}; + try values.append(allocator, val0); + for (seq.values[1..], 1..) |val, i| { + if (std.meta.activeTag(val) != tag) valid_slice = false; + if (valid_slice) { + try values.append(allocator, @field(val, @tagName(tag))); + } else { + var new_prefix = prefix; + if (prefix.items.len > 0) { + new_prefix.appendAssumeCapacity('.'); + } + new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); + try self.parseValue(allocator, store, new_prefix, val); + } + } + + if (valid_slice) { + try store._metadata.put( + allocator, + try allocator.dupe(u8, prefix.items), + try zml.aio.Metadata.copySlice(allocator, values.items), + ); + } else { + for (values.items, 0..) |val, i| { + var new_prefix = prefix; + if (prefix.items.len > 0) { + new_prefix.appendAssumeCapacity('.'); + } + new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); + const new_tag = switch (tag) { + .int64 => "int", + .float64 => "float", + .boolval => "bool", + else => unreachable, // we are already inside a switch + }; + try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), @unionInit(zml.aio.Metadata, new_tag, val)); + } + } + }, + else => { + for (seq.values, 0..) |item, i| { + var new_prefix = prefix; + if (v.isPrimitive()) { + if (prefix.items.len > 0) { + new_prefix.appendAssumeCapacity('.'); + } + new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); + } + try self.parseValue(allocator, store, new_prefix, item); + } + }, + } + }, + .dict => { + const n = @divExact(seq.values.len, 2); + log.info("found dict with {} entries", .{n}); + for (0..n) |i| { + const key, const val = seq.values[2 * i ..][0..2].*; + switch (key) { + .string => |s| { + // Handle Pytorch specific fields. + if (std.mem.eql(u8, s, "_modules") or std.mem.eql(u8, s, "_parameters") or std.mem.eql(u8, s, "_buffers")) { + try self.parseValue(allocator, store, prefix, val); + } else { + var new_prefix = prefix; + if (prefix.items.len > 0) { + new_prefix.appendAssumeCapacity('.'); + } + new_prefix.appendSliceAssumeCapacity(s); + + try self.parseValue(allocator, store, new_prefix, val); + } + }, + .int64 => |int| { + var new_prefix = prefix; + if (prefix.items.len > 0) { + new_prefix.appendAssumeCapacity('.'); + } + new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{}); + try self.parseValue(allocator, store, new_prefix, val); + }, + inline else => |_, tag| { + log.debug("Ignoring unsupported key type found in torch file: {s}", .{@tagName(tag)}); + continue; + }, + } + } + }, + } + }, + .bytes => |val| { + const key = try allocator.dupe(u8, prefix.items); + const d = try store._metadata.getOrPut(allocator, key); + if (d.found_existing) { + log.warn("Duplicate key: {s}", .{prefix.items}); + allocator.free(key); + } else d.value_ptr.* = .{ .string = val }; + }, + inline .float64, .int64, .boolval, .bigint, .string => |val| { + const key = try allocator.dupe(u8, prefix.items); + const d = try store._metadata.getOrPut(allocator, key); + if (d.found_existing) { + log.warn("Duplicate key: {s}", .{prefix.items}); + allocator.free(key); + } else { + d.value_ptr.* = zml.aio.Metadata.wrap(val); + } + }, + else => {}, + } + } + + fn parseTorchGlobal(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: py.Any) !bool { + return switch (v) { + .global => |object| { + if (try self.parseTensor(allocator, object)) |host_buffer| { + const key = try allocator.dupe(u8, prefix.items); + const entry = try store.buffers.getOrPut(allocator, key); + if (entry.found_existing) { + log.warn("Duplicate key: {s}", .{prefix.items}); + allocator.free(key); + } + entry.value_ptr.* = host_buffer; + return true; + } else if (basicTypeCheck(object, "torch", "Size")) { + const size = object.args; + const key = try allocator.dupe(u8, prefix.items); + const entry = try store._metadata.getOrPut(allocator, key); + if (entry.found_existing) { + log.warn("Duplicate key: {s}", .{prefix.items}); + allocator.free(key); + } + const d = try allocator.alloc(i64, size.len); + for (d, 0..) |*di, i| di.* = size[i].int64; + entry.value_ptr.* = .{ .array_int = d }; + return true; + } else if (basicTypeCheck(object, "fractions", "Fraction")) { + const fraction_str = object.args[0].string; + if (std.mem.indexOfScalar(u8, fraction_str, '/')) |split_idx| { + { + var new_prefix = prefix; + new_prefix.appendSliceAssumeCapacity(".numerator"); + try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) }); + } + { + var new_prefix = prefix; + new_prefix.appendSliceAssumeCapacity(".denominator"); + try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) }); + } + return true; + } + } + return false; + }, + else => false, + }; + } + + fn parseTensor(self: File, tmp_allocator: std.mem.Allocator, object: *py.Object) !?zml.HostBuffer { + if (!basicTypeCheck(object, "torch._utils", "_rebuild_tensor_v2")) { + return null; + } + + const args = object.args; + if (args.len < 4 or + args[0] != .pers_id or + args[1] != .int64 or + args[2] != .seq or args[2].seq.type != .tuple or + args[3] != .seq or args[3].seq.type != .tuple) + { + log.err("Unexpected py.Any in call to torch._utils._rebuild_tensor_v2: {}", .{object.*}); + return error.InvalidInput; + } + + const pid: *py.PersId = args[0].pers_id; + var offset: u64 = @intCast(args[1].int64); + const raw_dims: py.Sequence = args[2].seq; + const raw_strides: py.Sequence = args[3].seq; + const dims = try parseDims(raw_dims.values); + var strides = try parseDims(raw_strides.values); + + const dtype, const storage_file = try parseStorage(pid.ref); + // Pytorch store "item" strides, while ZML uses byte strides. + for (strides.slice()) |*s| s.* *= dtype.sizeOf(); + // Same thing for the offset. + offset = offset * dtype.sizeOf(); + + const filename = try std.mem.join(tmp_allocator, "", &.{ self.zip_prefix, "data/", storage_file }); + defer tmp_allocator.free(filename); + + // The offset in the pickle is the offset inside the storage_file. + // But .pt are made of several files, so we need to append the file offset. + const storage = try self.getStorage(filename); + return HostBuffer.fromStridedSlice( + zml.Shape.init(dims.constSlice(), dtype), + storage[offset..], + strides.constSlice(), + ); + } + + fn parseStorage(val: py.Any) !struct { zml.DataType, []const u8 } { + if (val != .seq) return error.InvalidInput; + const sargs = val.seq.values; + if (val.seq.type == .tuple and + sargs.len >= 5 and + sargs[0] == .string and std.mem.eql(u8, sargs[0].string, "storage") and + sargs[1] == .raw and sargs[1].raw == .global and + sargs[2] == .string and + sargs[3] == .string) + { + const op = sargs[1].raw.global; + const storage_file = sargs[2].string; + // const sdev = sargs[3].string; + if (!std.mem.eql(u8, "torch", op.module) or + !std.mem.endsWith(u8, op.class, "Storage")) + return error.InvalidInput; + + return .{ + try storageToDtype(op.class), + storage_file, + }; + } else { + return error.InvalidInput; + } + } + + /// Given the name of one of the files in the .pt tarball, + /// return the slice of the memory-mapped .pt corresponding to it. + fn getStorage(self: File, filename: []const u8) ![]const u8 { + const maybe_entry = self.file_map.get(filename); + if (maybe_entry == null) { + std.log.err("Could not find file ending in `{s}` in archive", .{filename}); + return error.TensorNotFound; + } + const entry = maybe_entry.?; + const base_offset: u64 = if (self.tar_file) |t| t.start else 0; + const file_offset: u64 = base_offset + entry.file_offset; + const file = self.buffer_file.file; + try file.seekTo(entry.file_offset); + const local_header = try file.reader().readStructEndian(std.zip.LocalFileHeader, .little); + + if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig)) + return error.ZipBadFileOffset; + if (local_header.compressed_size != 0 and + local_header.compressed_size != entry.compressed_size) + return error.ZipMismatchCompLen; + if (local_header.uncompressed_size != 0 and + local_header.uncompressed_size != entry.uncompressed_size) + return error.ZipMismatchUncompLen; + if (local_header.filename_len != entry.filename_len) + return error.ZipMismatchFilenameLen; + + const start = file_offset + + @sizeOf(std.zip.LocalFileHeader) + + @as(u64, local_header.filename_len) + + @as(u64, local_header.extra_len); + return self.buffer_file.mappedSlice(start, entry.uncompressed_size); + } + + fn parseDims(values: []py.Any) error{InvalidInput}!zml.Shape.DimsArray { + zml.meta.assert(values.len <= zml.Tensor.MAX_RANK, "Found Pytorch tensor with unsupported rank {}", .{values.len}); + var result: zml.Shape.DimsArray = .{}; + for (values) |val| { + switch (val) { + .int64 => |d| result.appendAssumeCapacity(d), + else => return error.InvalidInput, + } + } + return result; + } +}; + +/// Convert from a torch.Storage to a `zml.DataType`. +/// TODO: make this future proof, storage type are going to get replaced with torch.UntypedStorage +/// See https://pytorch.org/docs/stable/storage.html +fn storageToDtype(storage_type: []const u8) !zml.DataType { + const torch_type = storage_type[0 .. storage_type.len - "Storage".len]; + const map = std.StaticStringMap(zml.DataType).initComptime(.{ + .{ "Double", .f64 }, + .{ "Float", .f32 }, + .{ "Half", .f16 }, + .{ "Long", .i64 }, + .{ "Int", .i32 }, + .{ "Short", .i16 }, + .{ "Char", .i8 }, + .{ "Byte", .u8 }, + .{ "Bool", .bool }, + .{ "BFloat16", .bf16 }, + .{ "ComplexDouble", .c128 }, + .{ "ComplexFloat", .c64 }, + // QUInt8Storage + // QInt8Storage + // QInt32Storage + // QUInt4x2Storage + // QUInt2x4Storage + }); + + return map.get(torch_type) orelse { + log.err("Unsupported torch storage type: {s}", .{storage_type}); + return error.UnsupportedDataType; + }; +} + +const TarStream = struct { + pub const SeekableStream = std.io.SeekableStream( + TarStream, + asynk.File.SeekError, + asynk.File.GetSeekPosError, + TarStream.seekTo, + TarStream.seekBy, + TarStream.getPos, + TarStream.getEndPos, + ); + + file: std.tar.Iterator(asynk.File.Reader).File, + start: usize, + + pub fn init(file: std.tar.Iterator(asynk.File.Reader).File) !TarStream { + return .{ + .file = file, + .start = try file.parent_reader.context.getPos(), + }; + } + + pub fn reader(file: TarStream) std.tar.Iterator(asynk.File.Reader).File.Reader { + return file.file.reader(); + } + + pub fn seekTo(self: TarStream, offset: u64) !void { + return self.file.parent_reader.context.seekTo(self.start + offset); + } + + pub fn seekBy(self: TarStream, offset: i64) !void { + return self.file.parent_reader.context.seekBy(offset); + } + + pub fn getPos(self: TarStream) !u64 { + return try self.file.parent_reader.context.getPos() - self.start; + } + + pub fn getEndPos(self: TarStream) !u64 { + return self.file.size; + } + + pub fn seekableStream(self: TarStream) TarStream.SeekableStream { + return .{ .context = self }; + } +}; + +test "Read pickle (zipped)" { + // test file created with following python snippet: + // + // import torch + // torch.manual_seed(0) + // model = torch.nn.Conv2d(2, 2, 3, stride=2, padding=[2, 4], dtype=torch.float16) + // tensor = torch.tensor([[2, 4, 3, 2]], dtype=torch.uint8) + // torch.save({ "model": model, "tensor": tensor}, "simple.pt") + const file = try asynk.File.open("zml/aio/torch/simple.pt", .{ .mode = .read_only }); + const mmap_file = try zml.aio.MemoryMappedFile.init(file); + var store = try zml.aio.BufferStore.init(testing.allocator, &.{mmap_file}); + defer store.deinit(); + + { + var tmp_arena = std.heap.ArenaAllocator.init(testing.allocator); + defer tmp_arena.deinit(); + const tmp_alloc = tmp_arena.allocator(); + var torch_file = try File.init(tmp_alloc, mmap_file); + // We don't close the file directly, it will be closed by the store. + + const ops = try torch_file.parsePickle(tmp_alloc); + try std.testing.expectEqual(302, ops.len); + + const py_values = try eval.evaluate(tmp_alloc, ops, true); + try torch_file.parseModel(py_values, &store); + } + + // now we have freed the tmp_arena. + // all data needed should have been copied into the store arena. + try zml.testing.expectEqualShapes( + zml.Shape.init(.{ 1, 4 }, .u8), + store.get("tensor").?.shape(), + ); + try zml.testing.expectEqualShapes( + zml.Shape.init(.{ 2, 2, 3, 3 }, .f16), + store.get("model.weight").?.shape(), + ); + try zml.testing.expectEqualShapes( + zml.Shape.init(.{2}, .f16), + store.get("model.bias").?.shape(), + ); +} + +fn isBadFilename(filename: []const u8) bool { + if (filename.len == 0 or filename[0] == '/') + return true; + + var it = std.mem.splitScalar(u8, filename, '/'); + while (it.next()) |part| { + if (std.mem.eql(u8, part, "..")) + return true; + } + + return false; +} diff --git a/zml/aio/torch/parser.zig b/zml/aio/torch/parser.zig deleted file mode 100644 index 6d9485a..0000000 --- a/zml/aio/torch/parser.zig +++ /dev/null @@ -1,237 +0,0 @@ -const asynk = @import("async"); -const std = @import("std"); -const testing = std.testing; -const Allocator = std.mem.Allocator; - -const zml = @import("../../zml.zig"); -const pickle = @import("pickle.zig"); - -test { - std.testing.refAllDecls(@This()); - std.testing.refAllDecls(Parser); -} - -pub const Parser = struct { - // TODO: move the file logic to torch.PytorchFile - // the Pickle parser shouldn't have to deal with the zip archive stuff used by Pytorch - buffer_file: zml.aio.MemoryMappedFile, - file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{}, - tar_file: ?TarStream = null, - ops: []const pickle.Op, - is_zip_file: bool, - zip_prefix: []const u8 = &[_]u8{}, - - pub const FileEntry = struct { - version_needed_to_extract: u16, - flags: u16, - compression_method: std.zip.CompressionMethod, - last_modification_time: u16, - last_modification_date: u16, - header_zip_offset: u64, - crc32: u32, - filename_len: u32, - compressed_size: u64, - uncompressed_size: u64, - file_offset: u64, - - pub fn init(entry: anytype) FileEntry { - return .{ - .version_needed_to_extract = entry.version_needed_to_extract, - .flags = @as(u16, @bitCast(entry.flags)), - .compression_method = entry.compression_method, - .last_modification_time = entry.last_modification_time, - .last_modification_date = entry.last_modification_date, - .header_zip_offset = entry.header_zip_offset, - .crc32 = entry.crc32, - .filename_len = entry.filename_len, - .compressed_size = entry.compressed_size, - .uncompressed_size = entry.uncompressed_size, - .file_offset = entry.file_offset, - }; - } - }; - - const magic = "PK\x03\x04"; - - pub fn fromTarFile(allocator: Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !Parser { - const tar_stream = try TarStream.init(file); - const file_magic = try tar_stream.reader().readBytesNoEof(magic.len); - try tar_stream.seekTo(0); - var self: Parser = .{ - .buffer_file = mapped, - .tar_file = tar_stream, - .ops = undefined, - .is_zip_file = std.mem.eql(u8, &file_magic, magic), - }; - if (!self.is_zip_file) { - const reader = tar_stream.reader(); - self.ops = try pickle.parse(allocator, reader, try tar_stream.getEndPos()); - } else { - self.ops = try self.parseOps(allocator, self.tar_file.?.seekableStream()); - } - return self; - } - - pub fn init(allocator: Allocator, file: asynk.File) !Parser { - const file_magic = try file.reader().readBytesNoEof(magic.len); - try file.seekTo(0); - var self: Parser = .{ - .buffer_file = try zml.aio.MemoryMappedFile.init(file), - .is_zip_file = std.mem.eql(u8, &file_magic, magic), - .ops = undefined, - }; - if (!self.is_zip_file) { - const reader = self.buffer_file.file.reader(); - self.ops = try pickle.parse(allocator, reader, try reader.context.getEndPos()); - } else { - self.ops = try self.parseOps(allocator, self.buffer_file.file.seekableStream()); - } - return self; - } - - pub fn deinit(self: *Parser) void { - self.buffer_file.deinit(); - self.* = undefined; - } - - fn parseOps(self: *Parser, allocator: Allocator, seekable_stream: anytype) ![]const pickle.Op { - var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream); - var filename_buf: [std.fs.max_path_bytes]u8 = undefined; - while (try iter.next()) |entry| { - const filename = filename_buf[0..entry.filename_len]; - try seekable_stream.seekTo(entry.header_zip_offset + @sizeOf(std.zip.CentralDirectoryFileHeader)); - const len = try seekable_stream.context.reader().readAll(filename); - if (len != filename.len) return error.ZipBadFileOffset; - if (isBadFilename(filename)) return error.ZipBadFilename; - std.mem.replaceScalar(u8, filename, '\\', '/'); // normalize path separators - try self.file_map.put(allocator, try allocator.dupe(u8, filename), FileEntry.init(entry)); - } - - var file_iter = self.file_map.iterator(); - while (file_iter.next()) |e| { - const entry = e.value_ptr.*; - const filename = e.key_ptr.*; - if (std.mem.indexOf(u8, filename, "data.pkl")) |idx| { - self.zip_prefix = filename[0..idx]; - const local_data_header_offset: u64 = local_data_header_offset: { - const local_header = blk: { - try seekable_stream.seekTo(entry.file_offset); - break :blk try seekable_stream.context.reader().readStructEndian(std.zip.LocalFileHeader, .little); - }; - if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig)) - return error.ZipBadFileOffset; - if (local_header.version_needed_to_extract != entry.version_needed_to_extract) - return error.ZipMismatchVersionNeeded; - if (local_header.last_modification_time != entry.last_modification_time) - return error.ZipMismatchModTime; - if (local_header.last_modification_date != entry.last_modification_date) - return error.ZipMismatchModDate; - - if (@as(u16, @bitCast(local_header.flags)) != entry.flags) - return error.ZipMismatchFlags; - if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32) - return error.ZipMismatchCrc32; - if (local_header.compressed_size != 0 and - local_header.compressed_size != entry.compressed_size) - return error.ZipMismatchCompLen; - if (local_header.uncompressed_size != 0 and - local_header.uncompressed_size != entry.uncompressed_size) - return error.ZipMismatchUncompLen; - if (local_header.filename_len != entry.filename_len) - return error.ZipMismatchFilenameLen; - - break :local_data_header_offset @as(u64, local_header.filename_len) + - @as(u64, local_header.extra_len); - }; - - const local_data_file_offset: u64 = - @as(u64, entry.file_offset) + - @as(u64, @sizeOf(std.zip.LocalFileHeader)) + - local_data_header_offset; - try seekable_stream.seekTo(local_data_file_offset); - - switch (entry.compression_method) { - .store => { - return pickle.parse(allocator, seekable_stream.context.reader(), entry.uncompressed_size); - }, - .deflate => { - // TODO(cryptodeal): handle decompress - @panic("TODO support use of `deflate`"); - }, - else => @panic("TODO support other modes of compression"), - } - } - } - - std.log.err("Could not find file ending in `data.pkl` in archive", .{}); - return error.PickleNotFound; - } -}; - -const TarStream = struct { - pub const SeekableStream = std.io.SeekableStream( - TarStream, - asynk.File.SeekError, - asynk.File.GetSeekPosError, - TarStream.seekTo, - TarStream.seekBy, - TarStream.getPos, - TarStream.getEndPos, - ); - - file: std.tar.Iterator(asynk.File.Reader).File, - start: usize, - - pub fn init(file: std.tar.Iterator(asynk.File.Reader).File) !TarStream { - return .{ - .file = file, - .start = try file.parent_reader.context.getPos(), - }; - } - - pub fn reader(file: TarStream) std.tar.Iterator(asynk.File.Reader).File.Reader { - return file.file.reader(); - } - - pub fn seekTo(self: TarStream, offset: u64) !void { - return self.file.parent_reader.context.seekTo(self.start + offset); - } - - pub fn seekBy(self: TarStream, offset: i64) !void { - return self.file.parent_reader.context.seekBy(offset); - } - - pub fn getPos(self: TarStream) !u64 { - return try self.file.parent_reader.context.getPos() - self.start; - } - - pub fn getEndPos(self: TarStream) !u64 { - return self.file.size; - } - - pub fn seekableStream(self: TarStream) TarStream.SeekableStream { - return .{ .context = self }; - } -}; - -test "Read pickle (zipped)" { - var arena = std.heap.ArenaAllocator.init(testing.allocator); - defer arena.deinit(); - const allocator = arena.allocator(); - const file = try asynk.File.open("zml/aio/torch/simple.pt", .{ .mode = .read_only }); - var data = try Parser.init(allocator, file); - defer data.deinit(); -} - -fn isBadFilename(filename: []const u8) bool { - if (filename.len == 0 or filename[0] == '/') - return true; - - var it = std.mem.splitScalar(u8, filename, '/'); - while (it.next()) |part| { - if (std.mem.eql(u8, part, "..")) - return true; - } - - return false; -} diff --git a/zml/aio/torch/pickle.zig b/zml/aio/torch/pickle.zig index 22d80a8..5aeb529 100644 --- a/zml/aio/torch/pickle.zig +++ b/zml/aio/torch/pickle.zig @@ -673,10 +673,10 @@ pub const OpCode = enum(u8) { /// because operators having same semantics, but different encoding have been merged. /// ex: string, binstring, short_binstring -> string. pub const Op = union(enum) { - // Initially numbers were represented by strings... - int: []const u8, - binint: i32, + int: i32, + // Python can represent arbitrary long integers long: []const u8, + binlong: []const u8, string: []const u8, bytes: []const u8, bytearray: []u8, @@ -767,26 +767,32 @@ pub fn parse(allocator: std.mem.Allocator, reader: anytype, max_line_len: usize) var results = std.ArrayList(Op).init(allocator); errdefer results.deinit(); const len = max_line_len; + var _buf: std.BoundedArray(u8, 12) = .{}; while (true) { const b = try reader.readByte(); const code: OpCode = @enumFromInt(b); const op: Op = switch (code) { .int => blk: { - const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len); + _buf.len = 0; + try reader.streamUntilDelimiter(_buf.writer(), '\n', _buf.capacity() + 1); + const buf = _buf.constSlice(); // Legacy hack, see OpCode.int documentation // We do this parsing right away to simplify downstream code. - if (std.mem.eql(u8, "00", buf)) break :blk .{ .bool = false }; - if (std.mem.eql(u8, "01", buf)) break :blk .{ .bool = true }; - break :blk .{ .int = buf }; + break :blk if (std.mem.eql(u8, "00", buf)) + .{ .bool = false } + else if (std.mem.eql(u8, "01", buf)) + .{ .bool = true } + else + .{ .int = try std.fmt.parseInt(i32, buf, 10) }; }, - .binint => .{ .binint = try reader.readInt(i32, .little) }, - .binint1 => .{ .binint = try reader.readByte() }, - .binint2 => .{ .binint = try reader.readInt(u16, .little) }, + .binint => .{ .int = try reader.readInt(i32, .little) }, + .binint1 => .{ .int = try reader.readByte() }, + .binint2 => .{ .int = try reader.readInt(u16, .little) }, // TODO: long should handle the trailing 'L' -> add a test. .long => .{ .long = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, - .long1 => .{ .long = try _readSlice(reader, allocator, 1) }, - .long4 => .{ .long = try _readSlice(reader, allocator, 4) }, + .long1 => .{ .binlong = try _readSlice(reader, allocator, 1) }, + .long4 => .{ .binlong = try _readSlice(reader, allocator, 4) }, .string => .{ .string = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, .binstring => .{ .string = try _readSlice(reader, allocator, 4) }, .short_binstring => .{ .string = try _readSlice(reader, allocator, 1) }, @@ -825,12 +831,9 @@ pub fn parse(allocator: std.mem.Allocator, reader: anytype, max_line_len: usize) .dup => .dup, .mark => .mark, .pop_mark => .pop_mark, - .get => blk: { - const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len); - defer allocator.free(buf); - // If we fail to parse delay the error to the evaluation. - const n = std.fmt.parseInt(u32, buf, 10) catch std.math.maxInt(u32); - break :blk .{ .get = n }; + // If we fail to parse delay the error to the evaluation. + .get => .{ + .get = _readDigits(u32, reader, &_buf) catch std.math.maxInt(u32), }, .binget => .{ .get = try reader.readByte() }, .long_binget => .{ .get = try reader.readInt(u32, .little) }, @@ -887,9 +890,9 @@ pub fn parse(allocator: std.mem.Allocator, reader: anytype, max_line_len: usize) return results.toOwnedSlice(); } -test parse { +test "parse protocol 4" { const allocator = std.testing.allocator; - const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test.pickle", .{ .mode = .read_only }); + const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only }); var buffered_reader = std.io.bufferedReader(file.reader()); const ops = try parse(allocator, buffered_reader.reader(), 4096); defer { @@ -898,11 +901,10 @@ test parse { allocator.free(ops); } - try std.testing.expect(ops.len == 35); - // this can be obtained by running: `python -m pickletools simple_test.pickle` - const expected = [_]Op{ + // this can be obtained by running: `python -m pickletools simple_test_4.pickle` + var expected = [_]Op{ .{ .proto = 4 }, - .{ .frame = 83 }, + .{ .frame = 119 }, .empty_dict, .memoize, .mark, @@ -912,7 +914,7 @@ test parse { .memoize, .{ .unicode = "int" }, .memoize, - .{ .binint = 1 }, + .{ .int = 1 }, .{ .unicode = "float" }, .memoize, .{ .binfloat = 3.141592 }, @@ -921,17 +923,21 @@ test parse { .empty_list, .memoize, .mark, - .{ .binint = 0 }, - .{ .binint = 1 }, - .{ .binint = 2 }, - .{ .binint = 3 }, - .{ .binint = 4 }, + .{ .int = 255 }, + .{ .int = 1234 }, + .{ .int = -123 }, + .{ .int = 1_000_000_000 }, + .{ .binlong = &writeIntBuff(u48, 999_000_000_000) }, + .{ .binlong = &writeIntBuff(u104, 999_000_000_000_000_000_000_000_000_000) }, .appends, + .{ .unicode = "bool" }, + .memoize, + .{ .bool = false }, .{ .unicode = "tuple" }, .memoize, .{ .unicode = "a" }, .memoize, - .{ .binint = 10 }, + .{ .int = 10 }, .tuple2, .memoize, .setitems, @@ -940,6 +946,109 @@ test parse { try std.testing.expectEqualDeep(&expected, ops); } +test "parse protocol 0" { + // We also test protocol 0, cause it's more text oriented. + const allocator = std.testing.allocator; + const pickle_0 = + \\(dp0 + \\Vhello + \\p1 + \\Vworld + \\p2 + \\sVint + \\p3 + \\I1 + \\sVfloat + \\p4 + \\F3.141592 + \\sVlist + \\p5 + \\(lp6 + \\I255 + \\aI1234 + \\aI-123 + \\aI1000000000 + \\aL999000000000L + \\aL999000000000000000000000000000L + \\asVbool + \\p7 + \\I00 + \\sVtuple + \\p8 + \\(Va + \\p9 + \\I10 + \\tp10 + \\s. + ; + + var stream = std.io.fixedBufferStream(pickle_0); + const ops = try parse(allocator, stream.reader(), 4096); + defer { + // Test we are correctly freeing every allocation. + for (ops) |op| op.deinit(allocator); + allocator.free(ops); + } + + var expected = [_]Op{ + .mark, + .dict, + .{ .put = 0 }, + .{ .unicode = "hello" }, + .{ .put = 1 }, + .{ .unicode = "world" }, + .{ .put = 2 }, + .setitem, + .{ .unicode = "int" }, + .{ .put = 3 }, + .{ .int = 1 }, + .setitem, + .{ .unicode = "float" }, + .{ .put = 4 }, + .{ .float = "3.141592" }, + .setitem, + .{ .unicode = "list" }, + .{ .put = 5 }, + .mark, + .list, + .{ .put = 6 }, + .{ .int = 255 }, + .append, + .{ .int = 1234 }, + .append, + .{ .int = -123 }, + .append, + .{ .int = 1_000_000_000 }, + .append, + .{ .long = "999000000000L" }, + .append, + .{ .long = "999000000000000000000000000000L" }, + .append, + .setitem, + .{ .unicode = "bool" }, + .{ .put = 7 }, + .{ .bool = false }, + .setitem, + .{ .unicode = "tuple" }, + .{ .put = 8 }, + .mark, + .{ .unicode = "a" }, + .{ .put = 9 }, + .{ .int = 10 }, + .tuple, + .{ .put = 10 }, + .setitem, + .stop, + }; + try std.testing.expectEqualDeep(&expected, ops); +} + +fn _readDigits(comptime T: type, reader: anytype, buffer: *std.BoundedArray(u8, 12)) !T { + buffer.len = 0; + try reader.streamUntilDelimiter(buffer.writer(), '\n', 13); + return std.fmt.parseInt(T, buffer.constSlice(), 10); +} + fn _readSlice(reader: anytype, allocator: std.mem.Allocator, comptime len_bytes: u8) ![]u8 { const T = std.meta.Int(.unsigned, 8 * len_bytes); const str_len: u64 = try reader.readInt(T, .little); @@ -948,3 +1057,9 @@ fn _readSlice(reader: anytype, allocator: std.mem.Allocator, comptime len_bytes: _ = try reader.read(buf); return buf; } + +fn writeIntBuff(comptime T: type, value: T) [@divExact(@typeInfo(T).Int.bits, 8)]u8 { + var res: [@divExact(@typeInfo(T).Int.bits, 8)]u8 = undefined; + std.mem.writeInt(T, &res, value, .little); + return res; +} diff --git a/zml/aio/torch/value.zig b/zml/aio/torch/py.zig similarity index 60% rename from zml/aio/torch/value.zig rename to zml/aio/torch/py.zig index 17cd79f..452e4d3 100644 --- a/zml/aio/torch/value.zig +++ b/zml/aio/torch/py.zig @@ -1,101 +1,111 @@ const std = @import("std"); -const big_int = std.math.big.int; +const math = std.math; +const log = std.log.scoped(.zml_aio); const pickle = @import("pickle.zig"); -/// The types of sequences that exist. -pub const SequenceType = enum { - list, - dict, - kv_tuple, - tuple, - set, - frozen_set, -}; - +/// Correspond to a function/constructor call pub const Object = struct { - allocator: std.mem.Allocator, - member: Value, - args: []Value, + member: Any, + args: []Any, + kwargs: []Any, - pub fn init(allocator: std.mem.Allocator, member: Value, args: []Value) !*Object { + pub fn init(allocator: std.mem.Allocator, member: Any, args: []Any, kwargs: []Any) !*Object { const self = try allocator.create(Object); - self.* = .{ .allocator = allocator, .member = member, .args = args }; + self.* = .{ .member = member, .args = args, .kwargs = kwargs }; return self; } pub fn clone(self: *Object, allocator: std.mem.Allocator) std.mem.Allocator.Error!*Object { const res = try allocator.create(Object); - res.* = .{ .allocator = allocator, .member = try self.member.clone(allocator), .args = try allocator.alloc(Value, self.args.len) }; + res.* = .{ + .member = try self.member.clone(allocator), + .args = try allocator.alloc(Any, self.args.len), + .kwargs = try allocator.alloc(Any, self.kwargs.len), + }; for (self.args, 0..) |v, i| res.args[i] = try v.clone(allocator); + for (self.kwargs, 0..) |v, i| res.kwargs[i] = try v.clone(allocator); return res; } - pub fn deinit(self: *Object) void { - self.member.deinit(self.allocator); - for (self.args) |*v| v.deinit(self.allocator); - self.allocator.free(self.args); + pub fn deinit(self: *Object, allocator: std.mem.Allocator) void { + self.member.deinit(allocator); + for (self.args) |*v| v.deinit(allocator); + allocator.free(self.args); + allocator.destroy(self); + } +}; + +/// Correspond to the __set_state__ call when pickle finishes building an object. +pub const SetState = struct { + obj: Any, + state: Any, + + pub fn init(allocator: std.mem.Allocator, obj: Any, state: Any) !*SetState { + const res = try allocator.create(SetState); + res.* = .{ .obj = obj, .state = state }; + return res; + } + + pub fn clone(self: *SetState, allocator: std.mem.Allocator) std.mem.Allocator.Error!*SetState { + const res = try allocator.create(SetState); + res.* = .{ .obj = try self.obj.clone(allocator), .state = try self.state.clone(allocator) }; + return res; + } + + pub fn deinit(self: *SetState, allocator: std.mem.Allocator) void { + self.obj.deinit(allocator); + self.state.deinit(allocator); self.allocator.destroy(self); } }; -pub const Build = struct { - allocator: std.mem.Allocator, - member: Value, - args: Value, - - pub fn init(allocator: std.mem.Allocator, member: Value, args: Value) !*Build { - const self = try allocator.create(Build); - self.* = .{ .allocator = allocator, .member = member, .args = args }; - return self; - } - - pub fn clone(self: *Build, allocator: std.mem.Allocator) std.mem.Allocator.Error!*Build { - const res = try allocator.create(Build); - res.* = .{ .allocator = allocator, .member = try self.member.clone(allocator), .args = try self.args.clone(allocator) }; - return res; - } - - pub fn deinit(self: *Build) void { - self.member.deinit(self.allocator); - self.args.deinit(self.allocator); - self.allocator.destroy(self); - } +/// The types of sequences that exist. +pub const SequenceType = enum { + list, + dict, + tuple, + set, + frozen_set, }; pub const Sequence = struct { type: SequenceType, - values: []Value, + values: []Any, }; -pub const PersId = struct { - allocator: std.mem.Allocator, - ref: Value, +pub fn tuple(values: []const Any) Any { + // tuple are readonly, but sequence in general aren't + return .{ .seq = .{ .type = .tuple, .values = @constCast(values) } }; +} - pub fn init(allocator: std.mem.Allocator, ref: Value) !*PersId { +pub const PersId = struct { + ref: Any, + + pub fn init(allocator: std.mem.Allocator, ref: Any) !*PersId { const self = try allocator.create(PersId); - self.* = .{ .allocator = allocator, .ref = ref }; + self.* = .{ .ref = ref }; return self; } pub fn clone(self: *PersId, allocator: std.mem.Allocator) std.mem.Allocator.Error!*PersId { const res = try allocator.create(PersId); - res.* = .{ .allocator = allocator, .ref = try self.ref.clone(allocator) }; + res.* = .{ .ref = try self.ref.clone(allocator) }; return res; } - pub fn deinit(self: *PersId) void { - self.ref.deinit(self.allocator); - self.allocator.destroy(self); + pub fn deinit(self: *PersId, allocator: std.mem.Allocator) void { + self.ref.deinit(allocator); + allocator.destroy(self); } }; -pub const ValueType = enum { +pub const Kind = enum { raw, ref, app, object, - build, + set_state, pers_id, global, seq, @@ -110,7 +120,7 @@ pub const ValueType = enum { }; /// A pickle operator that has been interpreted. -pub const Value = union(ValueType) { +pub const Any = union(Kind) { /// Types that we can't handle or just had to give up on processing. raw: pickle.Op, @@ -128,9 +138,10 @@ pub const Value = union(ValueType) { /// thing, the second one is the arguments it got applied to. object: *Object, - /// Something we tried to build. The first tuple member is the - /// thing, the second one is the arguments it got applied to. - build: *Build, + /// Correspond to the __set_state__ call when pickle finishes building an object. + /// The first tuple member is the target object, + /// the second one is the "state" argument + set_state: *SetState, /// References to persistant storage. They basically could be anything. /// You kind of have to know what the thing you're trying to @@ -164,7 +175,7 @@ pub const Value = union(ValueType) { int64: i64, /// An integer that can't fit in i64. - bigint: big_int.Managed, + bigint: math.big.int.Const, /// An float, but not the crazy kind that comes as a string /// that has to be parsed. You can look in `Value.raw_num` for @@ -180,7 +191,7 @@ pub const Value = union(ValueType) { /// Python `None`. none: void, - pub fn deinit(self: *Value, allocator: std.mem.Allocator) void { + pub fn deinit(self: *Any, allocator: std.mem.Allocator) void { switch (self.*) { .raw, .raw_num => |v| v.deinit(allocator), inline .app, .object, .global, .build, .pers_id => |v| v.deinit(), @@ -189,7 +200,7 @@ pub const Value = union(ValueType) { allocator.free(v.values); }, .string, .bytes => |v| allocator.free(v), - .bigint => self.bigint.deinit(), + .bigint => |big| allocator.free(big.limbs), else => {}, } self.* = undefined; @@ -200,7 +211,7 @@ pub const Value = union(ValueType) { // try writer.writeByteNTimes('\t'); } - fn internalFormat(value: Value, indents: usize, writer: anytype) !void { + fn internalFormat(value: Any, indents: usize, writer: anytype) !void { try writeIndents(indents, writer); try writer.writeAll(".{\n"); try writeIndents(indents + 1, writer); @@ -209,9 +220,12 @@ pub const Value = union(ValueType) { inline .ref, .int64, .float64 => |v| try writer.print("{d} ", .{v}), .app, .object, .global => |v| { try writer.writeAll(".{\n"); + try writeIndents(indents + 2, writer); + try writer.writeAll(".fn ="); try internalFormat(v.member, indents + 2, writer); try writer.writeAll(",\n"); try writeIndents(indents + 2, writer); + try writer.writeAll(".args = "); if (v.args.len > 0) { try writer.writeAll(".{\n"); for (v.args, 0..) |arg, i| { @@ -220,6 +234,20 @@ pub const Value = union(ValueType) { try writer.writeByte('\n'); } try writeIndents(indents + 2, writer); + try writer.writeAll("},\n"); + } else { + try writer.writeAll(".{},\n"); + } + try writeIndents(indents + 2, writer); + try writer.writeAll(".kwargs ="); + if (v.kwargs.len > 0) { + try writer.writeAll(".{\n"); + for (v.kwargs, 0..) |arg, i| { + try internalFormat(arg, indents + 3, writer); + if (i < v.kwargs.len - 1) try writer.writeAll(","); + try writer.writeByte('\n'); + } + try writeIndents(indents + 2, writer); try writer.writeAll("}\n"); } else { try writer.writeAll(".{}\n"); @@ -227,16 +255,16 @@ pub const Value = union(ValueType) { try writeIndents(indents + 1, writer); try writer.writeAll("}"); }, - .build => |v| { + .set_state => |v| { try writer.writeAll(".{\n"); - try internalFormat(v.member, indents + 2, writer); + try internalFormat(v.obj, indents + 2, writer); try writer.writeAll(",\n"); - try internalFormat(v.args, indents + 2, writer); + try internalFormat(v.state, indents + 2, writer); try writer.writeAll(",\n"); try writeIndents(indents + 1, writer); try writer.writeAll("}"); }, - inline .pers_id => |v| { + .pers_id => |v| { try writer.writeByte('\n'); try internalFormat(v.ref, indents + 2, writer); }, @@ -275,26 +303,26 @@ pub const Value = union(ValueType) { try writer.writeByte('}'); } - pub fn format(self: Value, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { + pub fn format(self: Any, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { return internalFormat(self, 0, writer); } - pub fn clone(self: Value, allocator: std.mem.Allocator) !Value { + pub fn clone(self: Any, allocator: std.mem.Allocator) !Any { return switch (self) { - inline .raw, .raw_num => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)), - inline .app, .object, .global, .build, .pers_id => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)), + inline .raw, .raw_num => |v, tag| @unionInit(Any, @tagName(tag), try v.clone(allocator)), + inline .app, .object, .global, .set_state, .pers_id => |v, tag| @unionInit(Any, @tagName(tag), try v.clone(allocator)), .seq => |seq| { - const values = try allocator.alloc(Value, seq.values.len); + const values = try allocator.alloc(Any, seq.values.len); for (seq.values, 0..) |v, i| values[i] = try v.clone(allocator); return .{ .seq = .{ .type = seq.type, .values = values } }; }, - inline .string, .bytes => |v, tag| @unionInit(Value, @tagName(tag), try allocator.dupe(u8, v)), - .bigint => |v| .{ .bigint = try v.clone() }, + inline .string, .bytes => |v, tag| @unionInit(Any, @tagName(tag), try allocator.dupe(u8, v)), + .bigint => |v| .{ .bigint = (try v.toManaged(allocator)).toConst() }, else => self, }; } - pub fn isPrimitive(self: Value) bool { + pub fn isPrimitive(self: Any) bool { return switch (self) { .int64, .bigint, .float64, .string, .bytes, .boolval, .none => true, .seq => |seq| { @@ -307,7 +335,7 @@ pub const Value = union(ValueType) { }; } - pub fn containsRef(self: Value) bool { + pub fn containsRef(self: Any) bool { switch (self) { .ref => return true, .app, .object, .global => |v| { @@ -315,9 +343,9 @@ pub const Value = union(ValueType) { for (v.args) |arg| if (arg.containsRef()) return true; return false; }, - .build => |v| { - if (v.member.containsRef()) return true; - if (v.args.containsRef()) return true; + .set_state => |v| { + if (v.obj.containsRef()) return true; + if (v.state.containsRef()) return true; return false; }, .pers_id => |v| return v.ref.containsRef(), @@ -329,44 +357,49 @@ pub const Value = union(ValueType) { } } - const BI64MIN = big_int.Const{ - .limbs = &.{@intCast(@abs(std.math.minInt(i64)))}, - .positive = false, - }; + pub const UnpickleError = error{ InvalidCharacter, OutOfMemory }; - const BI64MAX = big_int.Const{ - .limbs = &.{@intCast(std.math.maxInt(i64))}, - .positive = true, - }; - - pub fn coerceFromRaw(self: Value, allocator: std.mem.Allocator) !Value { + pub fn coerceFromRaw(self: Any, allocator: std.mem.Allocator) UnpickleError!Any { return switch (self) { .raw => |raw_val| switch (raw_val) { - .binint => |val| .{ .int64 = val }, - .long => |b| if (b.len != 0) { - // TODO: handle trailing 'L' - var bint = try big_int.Managed.initCapacity(allocator, std.math.big.int.calcTwosCompLimbCount(b.len)); - var mutable = bint.toMutable(); - mutable.readTwosComplement(b, b.len, .little, .signed); - const min_comp = bint.toConst().order(BI64MIN); - const max_comp = bint.toConst().order(BI64MAX); - if ((min_comp == .gt or min_comp == .eq) and (max_comp == .lt or max_comp == .eq)) { - defer bint.deinit(); - return .{ .int64 = try bint.to(i64) }; - } else return .{ .bigint = bint }; - } else .{ .raw_num = raw_val }, + .none => .none, + .bool => |b| .{ .boolval = b }, + .float => |b| .{ .float64 = std.fmt.parseFloat(f64, b) catch std.math.nan(f64) }, + .int => |val| .{ .int64 = val }, + .long => |digits| { + const n = std.fmt.parseInt(i64, digits[0 .. digits.len - 1], 10) catch |err| { + switch (err) { + error.Overflow => { + log.warn("Not parsing long integer: {s}", .{digits}); + return self; + }, + error.InvalidCharacter => return error.InvalidCharacter, + } + }; + return .{ .int64 = n }; + }, + .binlong => |bytes| if (bytes.len <= 8) + .{ .int64 = std.mem.readVarInt(i64, bytes, .little) } + else { + // Note: we need to copy here, because Zig big int limbs are usize aligned, + // whereas pickle big int are byte aligned. + const n_limbs = std.math.divCeil(usize, bytes.len, @sizeOf(math.big.Limb)) catch unreachable; + var big = (try math.big.int.Managed.initCapacity(allocator, n_limbs)).toMutable(); + big.readTwosComplement(bytes, bytes.len * 8, .little, .signed); + + return .{ .bigint = big.toConst() }; + }, .binfloat => |val| .{ .float64 = val }, .unicode => |s| .{ .string = s }, - .bytes => |b| .{ .bytes = b }, + inline .bytes, .bytearray => |b| .{ .bytes = b }, // This isn't how Pickle actually works but we just try to UTF8 decode the // string and if it fails, we make it a bytes value instead. If anyone // actually cares they can just fix values themselves or recover the raw bytes // from the UTF8 string (it's guaranteed to be reversible, as far as I know). - .string => |b| if (std.unicode.utf8ValidateSlice(b)) .{ .string = b } else .{ .bytes = b }, - .bool => |b| .{ .boolval = b }, - .none => .{ .none = {} }, - // TODO .int should be handled like .long - .int, .float => .{ .raw_num = raw_val }, + .string => |b| if (std.unicode.utf8ValidateSlice(b)) + .{ .string = b } + else + .{ .bytes = b }, else => self, }, .app, .object, .global => |v| blk: { @@ -376,9 +409,9 @@ pub const Value = union(ValueType) { } break :blk self; }, - .build => |v| blk: { - v.member = try v.member.coerceFromRaw(allocator); - v.args = try v.args.coerceFromRaw(allocator); + .set_state => |v| blk: { + v.obj = try v.obj.coerceFromRaw(allocator); + v.state = try v.state.coerceFromRaw(allocator); break :blk self; }, .pers_id => |v| blk: { diff --git a/zml/aio/torch/simple.pt b/zml/aio/torch/simple.pt index beeb39d..f6978e6 100644 Binary files a/zml/aio/torch/simple.pt and b/zml/aio/torch/simple.pt differ diff --git a/zml/aio/torch/simple_test.pickle b/zml/aio/torch/simple_test.pickle deleted file mode 100644 index 19d7260..0000000 Binary files a/zml/aio/torch/simple_test.pickle and /dev/null differ diff --git a/zml/aio/torch/simple_test_4.pickle b/zml/aio/torch/simple_test_4.pickle new file mode 100644 index 0000000..4a51765 Binary files /dev/null and b/zml/aio/torch/simple_test_4.pickle differ