Refactor torch module: merge PickleData into Parser as torch.File, rename value file to py_object.zig, use buffered reader for pickle and zip headers, adjust intermediate result handling, simplify Python dict representation, separate kwargs from args, and add extensive tests for long integers, protocol 0, zipped pickle, and a complex PyTorch Conv2d case; also streamline BufferStore initialization.

This commit is contained in:
Tarry Singh 2023-04-20 15:43:18 +00:00
parent 837f8fb111
commit 11006ca08d
13 changed files with 1142 additions and 1621 deletions

View File

@ -62,7 +62,7 @@ zig_cc_test(
name = "test", name = "test",
data = [ data = [
"aio/torch/simple.pt", "aio/torch/simple.pt",
"aio/torch/simple_test.pickle", "aio/torch/simple_test_4.pickle",
], ],
test_runner = ":test_runner", test_runner = ":test_runner",
deps = [":zml"], deps = [":zml"],

View File

@ -104,6 +104,15 @@ pub const BufferStore = struct {
buffers: Buffers = .{}, buffers: Buffers = .{},
_metadata: Metadatas = .{}, _metadata: Metadatas = .{},
/// Create an empty BufferStore. Takes ownership of the given files.
pub fn init(allocator: std.mem.Allocator, files: []const MemoryMappedFile) error{OutOfMemory}!BufferStore {
var self: zml.aio.BufferStore = .{
.arena = std.heap.ArenaAllocator.init(allocator),
};
self.files = try self.arena.allocator().dupe(MemoryMappedFile, files);
return self;
}
pub fn deinit(self: BufferStore) void { pub fn deinit(self: BufferStore) void {
for (self.files) |*file| file.deinit(); for (self.files) |*file| file.deinit();
self.arena.deinit(); self.arena.deinit();
@ -255,7 +264,7 @@ pub const MemoryMappedFile = struct {
}; };
} }
pub fn mappedSlice(self: *MemoryMappedFile, start: usize, len: usize) []const u8 { pub fn mappedSlice(self: MemoryMappedFile, start: usize, len: usize) []const u8 {
return self.data[self.data_offset + start ..][0..len]; return self.data[self.data_offset + start ..][0..len];
} }
@ -578,7 +587,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
return if (buffer_store.get(prefix)) |host_buffer| { return if (buffer_store.get(prefix)) |host_buffer| {
// obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us. // obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us.
var buf_with_metadata = host_buffer; var buf_with_metadata = host_buffer;
log.warn("loading {s} ({})", .{ prefix, obj._shape }); log.debug("Loading buffer {s} ({})", .{ prefix, obj._shape });
zml.meta.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix }); zml.meta.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix });
buf_with_metadata._shape = obj._shape; buf_with_metadata._shape = obj._shape;
obj.* = try zml.Buffer.from(platform, buf_with_metadata); obj.* = try zml.Buffer.from(platform, buf_with_metadata);

View File

@ -1,11 +1,12 @@
const asynk = @import("async");
const eval = @import("torch/eval.zig");
const std = @import("std"); const std = @import("std");
const log = std.log.scoped(.zml_aio);
const asynk = @import("async");
const yaml = @import("zig-yaml"); const yaml = @import("zig-yaml");
const eval = @import("torch/eval.zig");
const zml = @import("../zml.zig"); const zml = @import("../zml.zig");
const File = @import("torch/file.zig").File;
const parser = @import("torch/parser.zig");
const StringBuilder = std.ArrayListUnmanaged(u8); const StringBuilder = std.ArrayListUnmanaged(u8);
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore { pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
@ -14,8 +15,11 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
}; };
errdefer res.arena.deinit(); errdefer res.arena.deinit();
// TODO(cryptodeal): this is incorrect, you should use a temporary arena for all intermediary allocations.
const arena = res.arena.allocator(); const arena = res.arena.allocator();
// TODO(cryptodeal): mapped_file will never be closed in case of success.
// You need to store it inside the result.
var mapped_file = try zml.aio.MemoryMappedFile.init(try asynk.File.open(path, .{})); var mapped_file = try zml.aio.MemoryMappedFile.init(try asynk.File.open(path, .{}));
errdefer mapped_file.deinit(); errdefer mapped_file.deinit();
@ -37,13 +41,11 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
try zml.aio.yaml.parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), parsed.docs.items[0]); try zml.aio.yaml.parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), parsed.docs.items[0]);
} else if (std.mem.endsWith(u8, file.name, ".ckpt") or std.mem.endsWith(u8, file.name, ".pt")) { } else if (std.mem.endsWith(u8, file.name, ".ckpt") or std.mem.endsWith(u8, file.name, ".pt")) {
const start = try mapped_file.file.getPos(); const start = try mapped_file.file.getPos();
var tmp: zml.aio.torch.PickleData = .{ var torch_file = try File.fromTarFile(arena, mapped_file, file);
.data = try parser.Parser.fromTarFile(arena, mapped_file, file), const ops = try torch_file.parsePickle(arena);
.stack = undefined, const values = try eval.evaluate(arena, ops, true);
};
tmp.stack = try eval.evaluate(arena, tmp.data.ops, true);
try tmp.parseModel(arena, &res); try torch_file.parseModel(values, &res);
// Since we directly manipulate the file handle pointer, // Since we directly manipulate the file handle pointer,
// reset to the end of file so iterator does not error // reset to the end of file so iterator does not error
// and avoid `skipBytes`. // and avoid `skipBytes`.

View File

@ -2,24 +2,18 @@ const asynk = @import("async");
const std = @import("std"); const std = @import("std");
const zml = @import("../zml.zig"); const zml = @import("../zml.zig");
const HostBuffer = @import("../hostbuffer.zig").HostBuffer;
const eval = @import("torch/eval.zig"); const eval = @import("torch/eval.zig");
const value = @import("torch/value.zig"); const py = @import("torch/py.zig");
const parser = @import("torch/parser.zig"); const File = @import("torch/file.zig").File;
const PersId = value.PersId;
const Sequence = value.Sequence;
const Value = value.Value;
const ValueType = value.ValueType;
const StringBuilder = std.ArrayListUnmanaged(u8); const StringBuilder = std.ArrayListUnmanaged(u8);
const log = std.log.scoped(.zml_io); const log = std.log.scoped(.zml_aio);
test { test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(@This());
std.testing.refAllDecls(eval); std.testing.refAllDecls(eval);
std.testing.refAllDecls(value); std.testing.refAllDecls(py);
std.testing.refAllDecls(parser); std.testing.refAllDecls(File);
} }
/// Opens and loads a BufferStore from the torch file at the given path. /// Opens and loads a BufferStore from the torch file at the given path.
@ -35,392 +29,14 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
defer arena.deinit(); defer arena.deinit();
const tmp_alloc = arena.allocator(); const tmp_alloc = arena.allocator();
const _parser = try parser.Parser.init(tmp_alloc, file); const mmap_file = try zml.aio.MemoryMappedFile.init(file);
const stack = try eval.evaluate(tmp_alloc, _parser.ops, true); var torch_file = try File.init(tmp_alloc, mmap_file);
// But we create the HostBuffer objects inside the result BufferStore arena. const ops = try torch_file.parsePickle(tmp_alloc);
var res: zml.aio.BufferStore = .{ const py_values = try eval.evaluate(tmp_alloc, ops, true);
.arena = std.heap.ArenaAllocator.init(allocator),
}; // file ownership is transferred to the BufferStore
res.files = try res.arena.allocator().dupe(zml.aio.MemoryMappedFile, &.{_parser.buffer_file}); var res = try zml.aio.BufferStore.init(allocator, &.{torch_file.buffer_file});
var tmp: PickleData = .{ .data = _parser, .stack = stack }; try torch_file.parseModel(py_values, &res);
try tmp.parseModel(res.arena.allocator(), &res);
return res; return res;
} }
// TODO: rename me to PytorchFile
/// In-memory view of a decoded Pytorch checkpoint (.pt / .ckpt):
/// `stack` is the result of evaluating the pickle opcodes and
/// `data` is the underlying parser (zip/tar layout + memory-mapped file).
pub const PickleData = struct {
// Evaluated top-level python values from the pickle stream.
stack: []const Value,
// Pickle parser state; owns the file map used to resolve tensor storage.
data: parser.Parser,
/// True when `object` is a raw call of `module.class` whose first argument is a sequence.
fn basicTypeCheck(object: *const value.Object, module: []const u8, class: []const u8) bool {
return switch (object.member) {
.raw => |raw| return (object.args[0] == .seq and
std.mem.eql(u8, module, raw.global.module) and
std.mem.eql(u8, class, raw.global.class)),
else => false,
};
}
/// Walks every evaluated pickle value and populates `store` with buffers and metadata.
/// Keys are dotted paths accumulated in a fixed 1KiB prefix buffer.
pub fn parseModel(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore) !void {
for (self.stack) |item| {
var prefix_buf: [1024]u8 = undefined;
try self.parseValue(allocator, store, StringBuilder.initBuffer(&prefix_buf), item);
}
}
/// Recursively converts one pickle value into BufferStore entries.
/// Scalars become metadata under `prefix`; tensors become buffers;
/// containers recurse with an extended dotted prefix.
pub fn parseValue(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: Value) !void {
switch (v) {
// Python call / object: try torch-specific handling first, otherwise
// recurse into the callee and each argument.
.app, .object, .global => |object| {
if (!(try self.parseTorchGlobal(allocator, store, prefix, v))) {
try self.parseValue(allocator, store, prefix, object.member);
for (object.args) |item| {
// if possible, coerce to `kv_tuple` (only if key val doesn't match root of prefix)
if (item == .seq and item.seq.type == .tuple and item.seq.values.len == 2 and item.seq.values[0] == .string) {
try self.parseValue(allocator, store, prefix, .{ .seq = .{ .type = .kv_tuple, .values = item.seq.values } });
} else try self.parseValue(allocator, store, prefix, item);
}
}
},
.build => |build| {
// `build` contains info about python struct being constructed
switch (build.member) {
.object => |obj| switch (obj.member) {
.raw => |raw| switch (raw) {
.global => |global| {
// in this case, we can capture the name of the python type
// which can be used for codegen (e.g. `torch.nn.modules.conv.Conv2d`)
var new_prefix = prefix;
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.appendSliceAssumeCapacity("_gen_type_helper");
const key = try allocator.dupe(u8, new_prefix.items);
const d = try store._metadata.getOrPut(allocator, key);
if (d.found_existing) {
log.err("Duplicate key: {s}", .{new_prefix.items});
allocator.free(key);
} else {
const val = try std.mem.join(allocator, ".", &.{ global.module, global.class });
d.value_ptr.* = .{ .string = val };
}
},
else => try self.parseValue(allocator, store, prefix, build.member), // parse normally
},
else => try self.parseValue(allocator, store, prefix, build.member), // parse normally
},
else => try self.parseValue(allocator, store, prefix, build.member), // parse normally
}
try self.parseValue(allocator, store, prefix, build.args);
},
// Persistent id: follow the reference it points at.
.pers_id => |pers_id| try self.parseValue(allocator, store, prefix, pers_id.ref),
.seq => |seq| {
switch (seq.type) {
// Homogeneous scalar sequences are stored as a single metadata
// slice; mixed sequences fall back to one entry per element.
.list, .tuple, .set, .frozen_set => {
if (seq.values.len == 0) return;
var valid_slice = true;
switch (seq.values[0]) {
inline .int64, .float64, .boolval => |val0, tag| {
const ItemType = switch (tag) {
.int64 => i64,
.float64 => f64,
.boolval => bool,
else => unreachable,
};
var values: std.ArrayListUnmanaged(ItemType) = .{};
try values.append(allocator, val0);
for (seq.values[1..], 1..) |val, i| {
if (std.meta.activeTag(val) != tag) valid_slice = false;
if (valid_slice) {
try values.append(allocator, @field(val, @tagName(tag)));
} else {
// Heterogeneous from here on: index-suffixed entries.
var new_prefix = prefix;
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
try self.parseValue(allocator, store, new_prefix, val);
}
}
if (valid_slice) {
try store._metadata.put(
allocator,
try allocator.dupe(u8, prefix.items),
try zml.aio.Metadata.copySlice(allocator, values.items),
);
} else {
// Re-emit the already-collected homogeneous head as
// individual index-suffixed metadata entries.
for (values.items, 0..) |val, i| {
var new_prefix = prefix;
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
const new_tag = switch (tag) {
.int64 => "int",
.float64 => "float",
.boolval => "bool",
else => unreachable, // we are already inside a switch
};
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), @unionInit(zml.aio.Metadata, new_tag, val));
}
}
},
else => {
for (seq.values, 0..) |item, i| {
var new_prefix = prefix;
if (v.isPrimitive()) {
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
}
try self.parseValue(allocator, store, new_prefix, item);
}
},
}
},
// Dicts are flat lists of kv_tuples; recurse without extending the prefix.
.dict => for (seq.values) |item| {
try self.parseValue(allocator, store, prefix, item);
},
.kv_tuple => {
const key, const val = seq.values[0..2].*;
switch (key) {
.string => |s| {
// Handle Pytorch specific fields
if (std.mem.eql(u8, s, "_modules") or std.mem.eql(u8, s, "_parameters") or std.mem.eql(u8, s, "_buffers")) {
try self.parseValue(allocator, store, prefix, val);
} else {
var new_prefix = prefix;
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.appendSliceAssumeCapacity(s);
try self.parseValue(allocator, store, new_prefix, val);
}
},
.int64 => |int| {
var new_prefix = prefix;
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{});
try self.parseValue(allocator, store, new_prefix, val);
},
inline else => |_, tag| std.debug.panic("Unexpected key type: {s}", .{@tagName(tag)}),
}
},
}
},
// Leaf scalars are recorded as metadata under the accumulated prefix;
// duplicate keys keep the first value and log a warning.
.bytes => |val| {
const key = try allocator.dupe(u8, prefix.items);
const d = try store._metadata.getOrPut(allocator, key);
if (d.found_existing) {
log.warn("Duplicate key: {s}", .{prefix.items});
allocator.free(key);
} else d.value_ptr.* = .{ .string = val };
},
inline .float64, .int64, .boolval, .bigint, .string => |val| {
const key = try allocator.dupe(u8, prefix.items);
const d = try store._metadata.getOrPut(allocator, key);
if (d.found_existing) {
log.warn("Duplicate key: {s}", .{prefix.items});
allocator.free(key);
} else {
d.value_ptr.* = zml.aio.Metadata.wrap(val);
}
},
else => {},
}
}
/// Recognizes torch-specific globals: rebuilt tensors, `torch.Size`, and
/// `fractions.Fraction`. Returns true when `v` was fully consumed here.
fn parseTorchGlobal(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: Value) !bool {
return switch (v) {
.global => |object| {
if (try self.parseTensor(allocator, object)) |host_buffer| {
const key = try allocator.dupe(u8, prefix.items);
const entry = try store.buffers.getOrPut(allocator, key);
if (entry.found_existing) {
log.warn("Duplicate key: {s}", .{prefix.items});
allocator.free(key);
}
entry.value_ptr.* = host_buffer;
return true;
} else if (basicTypeCheck(object, "torch", "Size")) {
const size = object.args[0].seq.values[0].seq.values;
const key = try allocator.dupe(u8, prefix.items);
const entry = try store._metadata.getOrPut(allocator, key);
if (entry.found_existing) {
log.warn("Duplicate key: {s}", .{prefix.items});
allocator.free(key);
}
const d = try allocator.alloc(i64, size.len);
for (d, 0..) |*di, i| di.* = size[i].int64;
entry.value_ptr.* = .{ .array_int = d };
return true;
} else if (basicTypeCheck(object, "fractions", "Fraction")) {
// Fractions are split into two int metadata entries:
// <prefix>.numerator and <prefix>.denominator.
const fraction_str = object.args[0].seq.values[0].string;
if (std.mem.indexOfScalar(u8, fraction_str, '/')) |split_idx| {
{
var new_prefix = prefix;
new_prefix.appendSliceAssumeCapacity(".numerator");
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) });
}
{
var new_prefix = prefix;
new_prefix.appendSliceAssumeCapacity(".denominator");
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) });
}
return true;
}
}
return false;
},
else => false,
};
}
/// Handles `torch._utils._rebuild_tensor_v2(storage, offset, dims, strides, ...)`.
/// Returns a HostBuffer that aliases the memory-mapped checkpoint (no copy),
/// or null when `object` is not a tensor rebuild call.
fn parseTensor(self: *PickleData, tmp_allocator: std.mem.Allocator, object: *value.Object) !?zml.HostBuffer {
if (!basicTypeCheck(object, "torch._utils", "_rebuild_tensor_v2")) {
return null;
}
const args = object.args[0].seq.values;
if (args.len < 4 or
args[0] != .pers_id or
args[1] != .int64 or
args[2] != .seq or args[2].seq.type != .tuple or
args[3] != .seq or args[3].seq.type != .tuple)
{
log.err("Unexpected value in call to torch._utils._rebuild_tensor_v2", .{});
return error.InvalidInput;
}
const pid: *PersId = args[0].pers_id;
var offset: u64 = @intCast(args[1].int64);
const raw_dims: Sequence = args[2].seq;
const raw_strides: Sequence = args[3].seq;
const dims = try parseDims(raw_dims.values);
var strides = try parseDims(raw_strides.values);
const dtype, const storage_file = try parseStorage(pid.ref);
// Pytorch store "item" strides, while ZML uses byte strides.
for (strides.slice()) |*s| s.* *= dtype.sizeOf();
// Same thing for the offset.
offset = offset * dtype.sizeOf();
const filename = try std.mem.join(tmp_allocator, "", &.{ self.data.zip_prefix, "data/", storage_file });
defer tmp_allocator.free(filename);
// The offset in the pickle is the offset inside the storage_file.
// But .pt are made of several files, so we need to append the file offset.
const storage = try self.getStorage(filename);
return HostBuffer.fromStridedSlice(
zml.Shape.init(dims.constSlice(), dtype),
storage[offset..],
strides.constSlice(),
);
}
/// Decodes a persistent-id tuple ("storage", torch.<T>Storage, filename, device, ...)
/// into the tensor dtype and the archive-relative storage file name.
fn parseStorage(val: value.Value) !struct { zml.DataType, []const u8 } {
if (val != .seq) return error.InvalidInput;
const sargs = val.seq.values;
if (val.seq.type == .tuple and
sargs.len >= 5 and
sargs[0] == .string and std.mem.eql(u8, sargs[0].string, "storage") and
sargs[1] == .raw and sargs[1].raw == .global and
sargs[2] == .string and
sargs[3] == .string)
{
const op = sargs[1].raw.global;
const storage_file = sargs[2].string;
// const sdev = sargs[3].string;
if (!std.mem.eql(u8, "torch", op.module) or
!std.mem.endsWith(u8, op.class, "Storage"))
return error.InvalidInput;
return .{
try storageToDtype(op.class),
storage_file,
};
} else {
return error.InvalidInput;
}
}
/// Given the name of one of the files in the .pt tarball,
/// return the slice of the memory-mapped .pt corresponding to it.
fn getStorage(self: *PickleData, filename: []const u8) ![]const u8 {
const maybe_entry = self.data.file_map.get(filename);
if (maybe_entry == null) {
// NOTE(review): the argument list below is garbled in this listing —
// presumably `.{filename}`; confirm against repository history.
std.log.err("Could not find file ending in `{s}` in archive", .(unknown));
return error.TensorNotFound;
}
const entry = maybe_entry.?;
const base_offset: u64 = if (self.data.tar_file) |t| t.start else 0;
const file_offset: u64 = base_offset + entry.file_offset;
const file = self.data.buffer_file.file;
// NOTE(review): `base_offset` is folded into `start` below, but this seek uses
// `entry.file_offset` alone — for tar-embedded archives the two differ; verify.
try file.seekTo(entry.file_offset);
const local_header = try file.reader().readStructEndian(std.zip.LocalFileHeader, .little);
if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig))
return error.ZipBadFileOffset;
// Cross-check the local header against the central-directory entry.
if (local_header.compressed_size != 0 and
local_header.compressed_size != entry.compressed_size)
return error.ZipMismatchCompLen;
if (local_header.uncompressed_size != 0 and
local_header.uncompressed_size != entry.uncompressed_size)
return error.ZipMismatchUncompLen;
if (local_header.filename_len != entry.filename_len)
return error.ZipMismatchFilenameLen;
// Payload starts right after the fixed header + variable-length fields.
const start = file_offset +
@sizeOf(std.zip.LocalFileHeader) +
@as(u64, local_header.filename_len) +
@as(u64, local_header.extra_len);
return self.data.buffer_file.mappedSlice(start, entry.uncompressed_size);
}
/// Converts a pickle value list into tensor dims; every element must be an int64.
fn parseDims(values: []Value) error{InvalidInput}!zml.Shape.DimsArray {
zml.meta.assert(values.len <= zml.Tensor.MAX_RANK, "Found Pytorch tensor with unsupported rank {}", .{values.len});
var result: zml.Shape.DimsArray = .{};
for (values) |val| {
switch (val) {
.int64 => |d| result.appendAssumeCapacity(d),
else => return error.InvalidInput,
}
}
return result;
}
};
/// Maps a `torch.<T>Storage` class name (e.g. "FloatStorage") to the matching `zml.DataType`.
/// Returns `error.UnsupportedDataType` (after logging) for unknown storage classes.
/// TODO: make this future proof, storage type are going to get replaced with torch.UntypedStorage
/// See https://pytorch.org/docs/stable/storage.html
fn storageToDtype(storage_type: []const u8) !zml.DataType {
    // Comptime lookup table keyed by the `<T>` part of `<T>Storage`.
    const dtypes = std.StaticStringMap(zml.DataType).initComptime(.{
        .{ "Double", .f64 },
        .{ "Float", .f32 },
        .{ "Half", .f16 },
        .{ "Long", .i64 },
        .{ "Int", .i32 },
        .{ "Short", .i16 },
        .{ "Char", .i8 },
        .{ "Byte", .u8 },
        .{ "Bool", .bool },
        .{ "BFloat16", .bf16 },
        .{ "ComplexDouble", .c128 },
        .{ "ComplexFloat", .c64 },
        // QUInt8Storage
        // QInt8Storage
        // QInt32Storage
        // QUInt4x2Storage
        // QUInt2x4Storage
    });
    // Strip the trailing "Storage" suffix before the lookup.
    const torch_type = storage_type[0 .. storage_type.len - "Storage".len];
    if (dtypes.get(torch_type)) |dtype| return dtype;
    log.err("Unsupported torch storage type: {s}", .{storage_type});
    return error.UnsupportedDataType;
}

View File

@ -1,652 +0,0 @@
const std = @import("std");
/// BTreeMap Node implementation.
/// B-Tree node holding up to `2*B-1` key/value pairs and `2*B` child edges.
/// Slots at index >= `len` are undefined (keys/values) or null (edges);
/// `edges[0] == null` marks a leaf node.
pub fn NodeType(comptime K: type, comptime V: type, comptime B: u32) type {
return struct {
const Self = @This();
keys: [2 * B - 1]K = [_]K{undefined} ** (2 * B - 1),
values: [2 * B - 1]V = [_]V{undefined} ** (2 * B - 1),
len: usize = 0,
edges: [2 * B]?*Self = [_]?*Self{null} ** (2 * B),
pub const KV = struct { key: K, value: V };
// Key/value plus the edge that travels with it during splits/merges.
const KVE = struct { key: K, value: V, edge: ?*Self };
const Entry = struct { key_ptr: *K, value_ptr: *V };
/// Initializes an empty Node.
pub fn initEmpty(allocator: std.mem.Allocator) !*Self {
const res: *Self = try allocator.create(Self);
res.* = .{};
return res;
}
/// Initializes a Node with a single Entry.
pub fn initKeyValue(allocator: std.mem.Allocator, entry: struct { K, V }) !*Self {
const key, const value = entry;
var res = try Self.initEmpty(allocator);
res.keys[0] = key;
res.values[0] = value;
res.len = 1;
return res;
}
// Builds the right-hand node of a split from slices of the splitting sibling.
fn initFromSplit(allocator: std.mem.Allocator, keys: []K, values: []V, edges: []?*Self) !*Self {
var out = try Self.initEmpty(allocator);
std.mem.copyBackwards(K, out.keys[0..], keys);
std.mem.copyBackwards(V, out.values[0..], values);
std.mem.copyBackwards(?*Self, out.edges[0..], edges);
out.len = keys.len;
return out;
}
// Total number of entries in this subtree (recursive).
pub fn count(self: Self) usize {
var len: usize = self.len;
for (0..self.len + 1) |i| {
if (!self.isLeaf()) {
len += self.edges[i].?.count();
}
}
return len;
}
// Linear search within this node. Returns (found, index): when found,
// `index` is the key slot; otherwise it is the child edge to descend into.
pub fn search(self: Self, key: K) std.meta.Tuple(&.{ bool, usize }) {
var i: usize = 0;
while (i < self.len) : (i += 1) {
if (eql(key, self.keys[i])) {
return .{ true, i };
} else if (lt(key, self.keys[i])) {
return .{ false, i };
}
}
return .{ false, self.len };
}
/// Inserts at `index`, splitting first when the node is full.
/// Returns the median key/value plus the new right node for the caller
/// to push up into the parent, or null when no split was needed.
pub fn insertOrSplit(
self: *Self,
allocator: std.mem.Allocator,
index: usize,
key: K,
value: V,
edge: ?*Self,
) !?KVE {
if (self.isFull()) {
var split_result = try self.split(allocator);
// Insert into whichever half the index lands in after the split.
switch (index < B) {
true => self.insert(index, key, value, edge),
false => split_result.edge.?.insert(index - B, key, value, edge),
}
return split_result;
}
self.insert(index, key, value, edge);
return null;
}
// Replaces values[index], returning the previous value.
pub fn swapValue(self: *Self, index: usize, value: V) V {
const out = self.values[index];
self.values[index] = value;
return out;
}
// Replaces the key/value pair at `index`, returning the previous pair.
pub fn swapKeyValue(self: *Self, index: usize, key: K, value: V) KV {
const out = .{ .key = self.keys[index], .value = self.values[index] };
self.values[index] = value;
self.keys[index] = key;
return out;
}
/// Removes the entry at `index` together with its RIGHT edge,
/// shifting the tail left. Returns the removed key/value/edge.
pub fn orderedRemove(self: *Self, index: usize) KVE {
const out: KVE = .{
.key = self.keys[index],
.value = self.values[index],
.edge = self.edges[index + 1],
};
std.mem.copyForwards(K, self.keys[index..], self.keys[index + 1 .. self.len]);
std.mem.copyForwards(V, self.values[index..], self.values[index + 1 .. self.len]);
self.keys[self.len - 1] = undefined;
self.values[self.len - 1] = undefined;
if (!self.isLeaf()) {
std.mem.copyForwards(?*Self, self.edges[index + 1 ..], self.edges[index + 2 .. self.len + 1]);
self.edges[self.len] = null;
}
self.len -= 1;
return out;
}
// Removes and returns the last entry (with its right edge).
fn pop(self: *Self) KVE {
return self.orderedRemove(self.len - 1);
}
// Removes and returns the first entry together with the LEFTMOST edge.
fn shift(self: *Self) KVE {
const out: KVE = .{
.key = self.keys[0],
.value = self.values[0],
.edge = self.edges[0],
};
std.mem.copyForwards(K, self.keys[0..], self.keys[1..self.len]);
std.mem.copyForwards(V, self.values[0..], self.values[1..self.len]);
self.keys[self.len - 1] = undefined;
self.values[self.len - 1] = undefined;
if (!self.isLeaf()) {
std.mem.copyForwards(
?*Self,
self.edges[0..],
self.edges[1 .. self.len + 1],
);
self.edges[self.len] = null;
}
self.len -= 1;
return out;
}
// Shifts the tail right and writes the new entry at `index`.
// The new edge is placed to the RIGHT of the inserted key.
fn insert(self: *Self, index: usize, key: K, value: V, edge: ?*Self) void {
std.mem.copyBackwards(
K,
self.keys[index + 1 .. self.len + 1],
self.keys[index..self.len],
);
self.keys[index] = key;
std.mem.copyBackwards(V, self.values[index + 1 .. self.len + 1], self.values[index..self.len]);
self.values[index] = value;
if (!self.isLeaf()) {
std.mem.copyBackwards(?*Self, self.edges[index + 2 .. self.len + 2], self.edges[index + 1 .. self.len + 1]);
self.edges[index + 1] = edge;
}
self.len += 1;
}
// Appends an entry (and its right edge) at the end; caller ensures capacity.
fn append(self: *Self, key: K, value: V, edge: ?*Self) void {
self.keys[self.len] = key;
self.values[self.len] = value;
self.edges[self.len + 1] = edge;
self.len += 1;
}
// Prepends an entry, placing `edge` as the new leftmost child.
fn unshift(self: *Self, key: K, value: V, edge: ?*Self) void {
std.mem.copyBackwards(K, self.keys[1 .. self.len + 1], self.keys[0..self.len]);
self.keys[0] = key;
std.mem.copyBackwards(V, self.values[1 .. self.len + 1], self.values[0..self.len]);
self.values[0] = value;
if (!self.isLeaf()) {
std.mem.copyBackwards(?*Self, self.edges[1 .. self.len + 2], self.edges[0 .. self.len + 1]);
self.edges[0] = edge;
}
self.len += 1;
}
/// Rebalances child `index` by rotating one entry from its right sibling
/// through the separator key. Returns false when nothing can be borrowed.
pub fn borrowRight(self: *Self, index: usize) bool {
if (index == self.len) return false;
var from = self.edges[index + 1].?;
if (from.len > B - 1) {
var to = self.edges[index].?;
const borrowed = from.shift();
to.append(self.keys[index], self.values[index], borrowed.edge);
_ = self.swapKeyValue(index, borrowed.key, borrowed.value);
return true;
}
return false;
}
/// Mirror of `borrowRight`: rotates one entry from the left sibling.
pub fn borrowLeft(self: *Self, index: usize) bool {
if (index == 0) return false;
var from = self.edges[index - 1].?;
if (from.len > B - 1) {
var to = self.edges[index].?;
const borrowed = from.pop();
to.unshift(self.keys[index - 1], self.values[index - 1], borrowed.edge);
_ = self.swapKeyValue(index - 1, borrowed.key, borrowed.value);
return true;
}
return false;
}
/// Merges the right sibling into `edges[left_edge_index]`, pulling the
/// separator key down between them, then frees the emptied right node.
pub fn mergeEdges(self: *Self, allocator: std.mem.Allocator, left_edge_index: usize) void {
var left = self.edges[left_edge_index].?;
const removed = self.orderedRemove(left_edge_index);
left.append(removed.key, removed.value, null);
std.mem.copyBackwards(K, left.keys[left.len..], removed.edge.?.keys[0..removed.edge.?.len]);
std.mem.copyBackwards(V, left.values[left.len..], removed.edge.?.values[0..removed.edge.?.len]);
std.mem.copyBackwards(?*Self, left.edges[left.len..], removed.edge.?.edges[0 .. removed.edge.?.len + 1]);
left.len += removed.edge.?.len;
allocator.destroy(removed.edge.?);
}
// Splits a full node around the median; `self` keeps the lower half and
// the returned KVE carries the median entry plus the new right node.
fn split(self: *Self, allocator: std.mem.Allocator) !KVE {
const median = B - 1;
const new_key = self.keys[median];
const new_value = self.values[median];
const new_node = try Self.initFromSplit(
allocator,
self.keys[median + 1 .. self.len],
self.values[median + 1 .. self.len],
self.edges[median + 1 .. self.len + 1],
);
@memset(self.keys[median..], undefined);
@memset(self.values[median..], undefined);
@memset(self.edges[median + 1 ..], null);
self.len = median;
return .{ .key = new_key, .value = new_value, .edge = new_node };
}
pub fn isLeaf(self: Self) bool {
return self.edges[0] == null;
}
pub fn isFull(self: Self) bool {
return self.len == 2 * B - 1;
}
// Below minimum B-tree occupancy; the parent must rebalance this node.
pub fn isLacking(self: Self) bool {
return self.len < B - 1;
}
};
}
pub fn BTreeMap(comptime K: type, comptime V: type) type {
return struct {
const Self = @This();
const B = 6;
const Node = NodeType(K, V, B);
const KV = Node.KV;
const SearchResult = std.meta.Tuple(&.{ bool, usize });
const StackEntry = struct { node: *Node, index: usize };
allocator: std.mem.Allocator,
root: ?*Node = null,
pub fn init(allocator: std.mem.Allocator) Self {
return .{ .allocator = allocator };
}
pub fn deinit(self: Self) !void {
if (self.root == null) return;
var stack = std.ArrayList(*Node).init(self.allocator);
defer stack.deinit();
if (self.root) |root| {
try stack.append(root);
}
while (stack.popOrNull()) |node| {
if (!node.isLeaf()) {
for (0..node.len + 1) |i| {
try stack.append(node.edges[i].?);
}
}
self.allocator.destroy(node);
}
}
pub fn count(self: Self) usize {
if (self.root == null) return 0;
var len: usize = 0;
if (self.root) |node| {
len += node.count();
}
return len;
}
pub fn isEmpty(self: *const Self) bool {
if (self.root == null) return true;
return self.root.?.len == 0;
}
pub fn get(self: Self, key: K) ?V {
var current = self.root;
while (current) |node| {
const found, const index = node.search(key);
switch (found) {
true => return node.values[index],
false => current = node.edges[index],
}
}
return null;
}
pub fn getPtr(self: Self, key: K) ?*V {
var current = self.root;
while (current) |node| {
const found, const index = node.search(key);
switch (found) {
true => return &node.values[index],
false => current = node.edges[index],
}
}
return null;
}
pub fn fetchPut(self: *Self, key: K, value: V) !?KV {
if (self.root == null) {
self.root = try Node.initKeyValue(self.allocator, .{ key, value });
return null;
}
var stack = std.ArrayList(StackEntry).init(self.allocator);
defer stack.deinit();
var current = self.root;
var search_result: SearchResult = undefined;
while (current) |node| {
search_result = node.search(key);
if (search_result[0]) {
return .{ .key = key, .value = node.swapValue(search_result[1], value) };
}
current = node.edges[search_result[1]];
try stack.append(.{ .node = node, .index = search_result[1] });
}
var stack_next: ?StackEntry = stack.pop();
var split_result = try stack_next.?.node.insertOrSplit(
self.allocator,
stack_next.?.index,
key,
value,
null,
);
if (split_result == null) {
return null;
}
stack_next = stack.popOrNull();
while (split_result) |split_result_unwrapped| {
if (stack_next) |stack_next_unwrapped| {
split_result = try stack_next_unwrapped.node.insertOrSplit(
self.allocator,
stack_next_unwrapped.index,
split_result_unwrapped.key,
split_result_unwrapped.value,
split_result_unwrapped.edge,
);
stack_next = stack.popOrNull();
} else {
var new_root = try Node.initKeyValue(
self.allocator,
.{ split_result_unwrapped.key, split_result_unwrapped.value },
);
new_root.edges[0] = self.root;
new_root.edges[1] = split_result_unwrapped.edge;
self.root = new_root;
return null;
}
} else return null;
}
pub fn fetchRemove(self: *Self, key: K) !?KV {
var stack = std.ArrayList(StackEntry).init(self.allocator);
defer stack.deinit();
var current = self.root;
var search_result: SearchResult = undefined;
var found_key_ptr: ?*K = null;
var found_value_ptr: ?*V = null;
while (current) |node| {
search_result = node.search(key);
if (search_result[0]) {
found_key_ptr = &node.keys[search_result[1]];
found_value_ptr = &node.values[search_result[1]];
if (!node.isLeaf()) search_result[1] += 1;
}
try stack.append(.{
.node = node,
.index = search_result[1],
});
current = node.edges[search_result[1]];
if (search_result[0]) break;
} else return null;
while (current) |node| {
try stack.append(.{ .node = node, .index = 0 });
current = node.edges[0];
}
var current_stack = stack.pop();
const out: KV = .{ .key = found_key_ptr.?.*, .value = found_value_ptr.?.* };
found_key_ptr.?.* = current_stack.node.keys[current_stack.index];
found_value_ptr.?.* = current_stack.node.values[current_stack.index];
_ = current_stack.node.orderedRemove(current_stack.index);
if (current_stack.node == self.root) return out;
while (current_stack.node.isLacking()) {
current_stack = stack.pop();
if (current_stack.node.borrowRight(current_stack.index)) return out;
if (current_stack.node.borrowLeft(current_stack.index)) return out;
if (current_stack.index == current_stack.node.len) {
current_stack.node.mergeEdges(self.allocator, current_stack.index - 1);
} else {
current_stack.node.mergeEdges(self.allocator, current_stack.index);
}
if (current_stack.node == self.root) {
if (self.root.?.len == 0) {
const new_root = current_stack.node.edges[0].?;
self.allocator.destroy(self.root.?);
self.root.? = new_root;
}
break;
}
}
return out;
}
/// In-order iterator over the tree, driven by an explicit stack of
/// (node, index) positions instead of parent pointers.
const Iterator = struct {
stack: std.ArrayList(StackEntry),
// True right after popping back up from a child: tells `next` to emit the
// parent's current entry instead of descending into the same child again.
backwards: bool,
pub fn deinit(it: Iterator) void {
it.stack.deinit();
}
/// Returns the next entry in key order, or null when traversal is complete.
pub fn next(it: *Iterator) ?Node.Entry {
while (it.topStackItem()) |item| {
if (!item.node.isLeaf() and !it.backwards) {
// Descend into the child to the left of the current entry first.
const child = item.node.edges[item.index].?;
// NOTE(review): OOM on append is treated as impossible here — confirm.
it.stack.append(StackEntry{ .node = child, .index = 0 }) catch unreachable;
} else {
if (item.index < item.node.len) {
// Emit the current entry, then advance to the next slot in this node.
const out: Node.Entry = .{ .key_ptr = &item.node.keys[item.index], .value_ptr = &item.node.values[item.index] };
item.index += 1;
it.backwards = false;
return out;
} else {
// Node exhausted: pop back to the parent and flag the ascent.
_ = it.stack.popOrNull();
it.backwards = true;
}
}
} else return null;
}
// Pointer to the top stack entry, or null when the stack is empty.
fn topStackItem(it: *Iterator) ?*StackEntry {
return switch (it.stack.items.len) {
0 => null,
else => &it.stack.items[it.stack.items.len - 1],
};
}
};
/// Returns an in-order iterator positioned at the start of the tree.
/// The iterator owns its stack; call `deinit` on it when done.
pub fn iterator(self: *const Self) Iterator {
    var entries = std.ArrayList(StackEntry).init(self.allocator);
    // Seed the traversal with the root; an empty tree yields an empty stack.
    if (self.root) |root_node| entries.append(.{ .node = root_node, .index = 0 }) catch unreachable;
    return .{ .stack = entries, .backwards = false };
}
};
}
/// Compares two of any type for equality. Containers are compared on a field-by-field basis,
/// where possible. Pointers are followed if the addresses are not equal.
/// Unsupported combinations (untagged unions, sentinel-less many-pointers) are compile errors.
fn eql(a: anytype, b: @TypeOf(a)) bool {
const T = @TypeOf(a);
switch (@typeInfo(T)) {
.Struct => |info| {
// Recursively compare every field.
inline for (info.fields) |field_info| {
if (!eql(@field(a, field_info.name), @field(b, field_info.name))) return false;
}
return true;
},
.ErrorUnion => {
// Equal only when both are payloads that compare equal, or both are the same error.
if (a) |a_p| {
if (b) |b_p| return eql(a_p, b_p) else |_| return false;
} else |a_e| {
if (b) |_| return false else |b_e| return a_e == b_e;
}
},
.Union => |info| {
if (info.tag_type) |UnionTag| {
const tag_a = std.meta.activeTag(a);
const tag_b = std.meta.activeTag(b);
if (tag_a != tag_b) return false;
// Same tag: compare the active payloads.
inline for (info.fields) |field_info| {
if (@field(UnionTag, field_info.name) == tag_a) {
return eql(@field(a, field_info.name), @field(b, field_info.name));
}
}
return false;
}
@compileError("Cannot compare untagged union type " ++ @typeName(T));
},
.Array => {
// Element-wise comparison.
if (a.len != b.len) return false;
for (a, 0..) |e, i|
if (!eql(e, b[i])) return false;
return true;
},
.Vector => |info| {
var i: usize = 0;
while (i < info.len) : (i += 1) {
if (!eql(a[i], b[i])) return false;
}
return true;
},
.Pointer => |info| {
return switch (info.size) {
// Identical addresses short-circuit; otherwise follow the pointer.
.One => if (a == b) true else eql(a.*, b.*),
.Many => if (a == b) true else {
if (info.sentinel) {
if (std.mem.len(a) != std.mem.len(b)) return false;
var i: usize = 0;
while (i < std.mem.len(a)) : (i += 1)
if (!eql(a[i], b[i])) return false;
return true;
}
@compileError("Cannot compare many-item Pointers without sentinel value");
},
.C => if (a == b) true else @compileError("Cannot compare C pointers"),
.Slice => if (a.ptr == b.ptr and a.len == b.len) true else {
if (a.len != b.len) return false;
for (a, 0..) |_, i|
if (!eql(a[i], b[i])) return false;
return true;
},
};
},
.Optional => {
// Two nulls are equal; null vs non-null is not.
if (a == null and b == null) return true;
if (a == null or b == null) return false;
return eql(a.?, b.?);
},
else => return a == b,
}
}
/// Structural less-than comparison for arbitrary types. Numbers use `<`,
/// structs must provide their own `lt` method, sequences compare
/// lexicographically, and optionals containing null are never less-than.
/// Unsupported types are compile errors.
fn lt(a: anytype, b: @TypeOf(a)) bool {
const T = @TypeOf(a);
switch (@typeInfo(T)) {
.Int, .ComptimeInt, .Float, .ComptimeFloat => {
return a < b;
},
.Struct => {
if (!@hasDecl(T, "lt")) {
@compileError("Type `" ++ @typeName(T) ++ "` must implement a `lt` comparison method.");
}
return T.lt(a, b);
},
.Union => |info| {
if (info.tag_type) |UnionTag| {
const tag_a = std.meta.activeTag(a);
const tag_b = std.meta.activeTag(b);
// if tags are not equal, perform comparison based on tag
if (tag_a != tag_b) {
return std.ascii.lessThanIgnoreCase(@tagName(tag_a), @tagName(tag_b));
}
// if tags are equal, compare based on the active field
inline for (info.fields) |field_info| {
if (@field(UnionTag, field_info.name) == tag_a) {
return lt(@field(a, field_info.name), @field(b, field_info.name));
}
}
return false;
}
@compileError("Cannot perform `lt` check on untagged union type " ++ @typeName(T));
},
.Array => {
// Lexicographic: first differing element decides.
for (a, 0..) |_, i| {
if (lt(a[i], b[i])) {
return true;
} else if (eql(a[i], b[i])) {
continue;
} else {
return false;
}
}
return false;
},
.Vector => |info| {
var i: usize = 0;
while (i < info.len) : (i += 1) {
if (lt(a[i], b[i])) {
return true;
} else if (eql(a[i], b[i])) {
continue;
} else {
return false;
}
}
return false;
},
.Pointer => |info| {
switch (info.size) {
.One => return lt(a.*, b.*),
.Slice => {
// Lexicographic over the common prefix; ties broken by length.
const n = @min(a.len, b.len);
for (a[0..n], 0..) |_, i| {
if (lt(a[i], b[i])) {
return true;
} else if (eql(a[i], b[i])) {
continue;
} else {
return false;
}
}
return lt(a.len, b.len);
},
.Many => {
if (info.sentinel) {
const n = @min(std.mem.len(a), std.mem.len(b));
var i: usize = 0;
while (i < n) : (i += 1) {
if (lt(a[i], b[i])) {
return true;
} else if (eql(a[i], b[i])) {
continue;
} else {
return false;
}
}
return lt(std.mem.len(a), std.mem.len(b));
}
@compileError("Cannot compare many-item pointer to unknown number of items without sentinel value");
},
.C => @compileError("Cannot compare C pointers"),
}
},
.Optional => {
// A null on either side is never considered less-than.
if (a == null or b == null) return false;
return lt(a.?, b.?);
},
else => {
@compileError("Cannot compare type '" ++ @typeName(T) ++ "'");
},
}
}
/// Strict "greater than": true exactly when `a` is neither less than nor
/// equal to `b`. Note this is deliberately not `lt(b, a)`: for optionals,
/// `lt` treats null as incomparable, so the two formulations differ.
pub fn gt(a: anytype, b: @TypeOf(a)) bool {
    if (eql(a, b)) return false;
    return !lt(a, b);
}

View File

@ -2,42 +2,33 @@ const std = @import("std");
const zml = @import("../../zml.zig"); const zml = @import("../../zml.zig");
const meta = zml.meta; const meta = zml.meta;
const value = @import("value.zig"); const py = @import("py.zig");
const pickle = @import("pickle.zig"); const pickle = @import("pickle.zig");
const BTreeMap = @import("b_tree_map.zig").BTreeMap;
const Build = value.Build;
const Object = value.Object;
const PersId = value.PersId;
const Sequence = value.Sequence;
const SequenceType = value.SequenceType;
const Value = value.Value;
const MAX_DEPTH: usize = 250; const MAX_DEPTH: usize = 250;
const MAX_PROTOCOL: u8 = 5; const MAX_PROTOCOL: u8 = 5;
pub const PickleMemo = struct { pub const PickleMemo = struct {
allocator: std.mem.Allocator, map: std.AutoHashMap(u32, py.Any),
map: BTreeMap(u32, Value),
pub fn init(allocator: std.mem.Allocator) PickleMemo { pub fn init(allocator: std.mem.Allocator) PickleMemo {
return .{ return .{
.allocator = allocator, .map = std.AutoHashMap(u32, py.Any).init(allocator),
.map = BTreeMap(u32, Value).init(allocator),
}; };
} }
pub fn deinit(self: *PickleMemo) void { pub fn deinit(self: *PickleMemo) void {
const allocator = self.map.allocator;
var iterator = self.map.iterator(); var iterator = self.map.iterator();
defer iterator.deinit(); defer iterator.deinit();
while (iterator.next()) |entry| { while (iterator.next()) |entry| {
entry.value_ptr.deinit(self.allocator); entry.value_ptr.deinit(allocator);
} }
self.map.deinit() catch unreachable; self.map.deinit() catch unreachable;
self.* = undefined; self.* = undefined;
} }
pub fn resolve(self: *PickleMemo, allocator: std.mem.Allocator, op: Value, recursive: bool) !Value { pub fn resolve(self: *PickleMemo, allocator: std.mem.Allocator, op: py.Any, recursive: bool) !py.Any {
var used_op = op; var used_op = op;
while (used_op == .ref) { while (used_op == .ref) {
var count: usize = 0; var count: usize = 0;
@ -67,12 +58,12 @@ pub const PickleMemo = struct {
} }
} }
}, },
.build => |v| { .set_state => |v| {
if (v.member.containsRef()) { if (v.obj.containsRef()) {
v.member = try self.resolve(allocator, v.member, recursive); v.obj = try self.resolve(allocator, v.obj, recursive);
} }
if (v.args.containsRef()) { if (v.state.containsRef()) {
v.args = try self.resolve(allocator, v.args, recursive); v.state = try self.resolve(allocator, v.state, recursive);
} }
}, },
.pers_id => |v| { .pers_id => |v| {
@ -93,11 +84,11 @@ pub const PickleMemo = struct {
return used_op; return used_op;
} }
pub fn insert(self: *PickleMemo, mid: u32, val: Value) !void { pub fn insert(self: *PickleMemo, mid: u32, val: py.Any) !void {
_ = try self.map.fetchPut(mid, val); _ = try self.map.fetchPut(mid, val);
} }
pub fn resolveMut(self: *PickleMemo, op: *Value, recursive: bool) !*Value { pub fn resolveMut(self: *PickleMemo, op: *py.Any, recursive: bool) !*py.Any {
if (op.* != .ref) return op; if (op.* != .ref) return op;
var lastmid = op.ref; var lastmid = op.ref;
var count: usize = 0; var count: usize = 0;
@ -122,34 +113,35 @@ pub const PickleMemo = struct {
}); });
} }
const MemoError = std.math.big.int.Managed.ConvertError || std.mem.Allocator.Error || error{BadMemoRef}; const MemoError = py.Any.UnpickleError || error{BadMemoRef};
pub fn resolveAllRefsIter(self: *PickleMemo, allocator: std.mem.Allocator, depth: usize, vals: []Value, fix_values: bool) MemoError![]Value { pub fn resolveAllRefsIter(self: *PickleMemo, allocator: std.mem.Allocator, depth: usize, vals: []py.Any, fix_values: bool) MemoError![]py.Any {
if (depth >= MAX_DEPTH) { if (depth >= MAX_DEPTH) {
return vals; return vals;
} }
const res = try allocator.alloc(Value, vals.len); const res = try allocator.alloc(py.Any, vals.len);
for (vals, 0..) |v, i| { for (vals, 0..) |v, i| {
res[i] = try self.resolveAllRefs(allocator, depth + 1, v, fix_values); res[i] = try self.resolveAllRefs(allocator, depth + 1, v, fix_values);
} }
return res; return res;
} }
pub fn resolveAllRefs(self: *PickleMemo, allocator: std.mem.Allocator, depth: usize, val: Value, fix_values: bool) !Value { pub fn resolveAllRefs(self: *PickleMemo, allocator: std.mem.Allocator, depth: usize, val: py.Any, fix_values: bool) !py.Any {
var output: Value = switch (val) { var output: py.Any = switch (val) {
.ref => try self.resolve(allocator, val, true), .ref => try self.resolve(allocator, val, true),
inline .app, .object, .global => |v, tag| @unionInit(Value, @tagName(tag), try Object.init( inline .app, .object, .global => |v, tag| @unionInit(py.Any, @tagName(tag), try py.Object.init(
allocator, allocator,
try self.resolveAllRefs(allocator, depth + 1, v.member, fix_values), try self.resolveAllRefs(allocator, depth + 1, v.member, fix_values),
try self.resolveAllRefsIter(allocator, depth + 1, v.args, fix_values), try self.resolveAllRefsIter(allocator, depth + 1, v.args, fix_values),
try self.resolveAllRefsIter(allocator, depth + 1, v.kwargs, fix_values),
)), )),
.build => |v| .{ .build = try Build.init( .set_state => |v| .{ .set_state = try py.SetState.init(
allocator, allocator,
try self.resolveAllRefs(allocator, depth + 1, v.member, fix_values), try self.resolveAllRefs(allocator, depth + 1, v.obj, fix_values),
try self.resolveAllRefs(allocator, depth + 1, v.args, fix_values), try self.resolveAllRefs(allocator, depth + 1, v.state, fix_values),
) }, ) },
.seq => |v| .{ .seq = .{ .type = v.type, .values = try self.resolveAllRefsIter(allocator, depth + 1, v.values, fix_values) } }, .seq => |v| .{ .seq = .{ .type = v.type, .values = try self.resolveAllRefsIter(allocator, depth + 1, v.values, fix_values) } },
.pers_id => |v| .{ .pers_id = try PersId.init(allocator, try self.resolveAllRefs(allocator, depth + 1, v.ref, fix_values)) }, .pers_id => |v| .{ .pers_id = try py.PersId.init(allocator, try self.resolveAllRefs(allocator, depth + 1, v.ref, fix_values)) },
else => try val.clone(allocator), else => try val.clone(allocator),
}; };
if (fix_values) { if (fix_values) {
@ -159,29 +151,9 @@ pub const PickleMemo = struct {
} }
}; };
pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) ![]const Value { pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) ![]const py.Any {
var stack = std.ArrayList(Value).init(arena); var stack = std.ArrayList(py.Any).init(arena);
var memo = PickleMemo.init(arena); var memo = PickleMemo.init(arena);
errdefer memo.deinit();
const makeKVList = (struct {
pub fn call(alloc: std.mem.Allocator, items: []const Value) ![]Value {
meta.assert(items.len & 1 == 0, "Bad value for setitems", .{});
var kv_items = try std.ArrayList(Value).initCapacity(alloc, items.len);
errdefer kv_items.deinit();
var idx: usize = 0;
while (idx < items.len) : (idx += 2) {
if (idx + 1 >= items.len) {
return error.MissingValueItem;
}
const kv = try alloc.alloc(Value, 2);
kv[0] = items[idx];
kv[1] = items[idx + 1];
kv_items.appendAssumeCapacity(.{ .seq = .{ .type = .kv_tuple, .values = kv } });
}
return kv_items.toOwnedSlice();
}
}).call;
for (x) |op| { for (x) |op| {
switch (op) { switch (op) {
@ -189,47 +161,50 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
.frame => {}, .frame => {},
.stop => break, .stop => break,
.pop => _ = try pop(&stack), .pop => _ = try pop(&stack),
.pop_mark => try popMarkDiscard(&stack), .pop_mark => _ = try popMark(&stack),
.dup => if (stack.getLastOrNull()) |item| .dup => if (stack.getLastOrNull()) |item|
try stack.append(try item.clone(arena)) try stack.append(try item.clone(arena))
else else
return error.CannotDupEmptyStack, return error.CannotDupEmptyStack,
.persid => |v| try stack.append(.{ .pers_id = try PersId.init(arena, .{ .string = try arena.dupe(u8, v) }) }), .persid => |v| try stack.append(.{ .pers_id = try py.PersId.init(arena, .{ .string = try arena.dupe(u8, v) }) }),
.binpersid => try stack.append(.{ .pers_id = try PersId.init(arena, try pop(&stack)) }), .binpersid => try stack.append(.{ .pers_id = try py.PersId.init(arena, try pop(&stack)) }),
.reduce => try stack.append(.{ .global = blk: { .reduce => try stack.append(.{ .global = blk: {
const values = try arena.alloc(Value, 1); var args = try pop(&stack);
values[0] = try memo.resolve(arena, try pop(&stack), true); args = try memo.resolve(arena, args, true);
break :blk try Object.init(arena, try memo.resolve(arena, try pop(&stack), true), values); if (args != .seq) return error.InvalidInput;
var func = try pop(&stack);
func = try memo.resolve(arena, func, true);
break :blk try py.Object.init(arena, func, args.seq.values, &.{});
} }), } }),
.build => try stack.append(blk: { .build => try stack.append(blk: {
const args = try memo.resolve(arena, try pop(&stack), true); const args = try memo.resolve(arena, try pop(&stack), true);
const member = try memo.resolve(arena, try pop(&stack), true); const member = try memo.resolve(arena, try pop(&stack), true);
break :blk .{ .build = try Build.init(arena, member, args) }; break :blk .{ .set_state = try py.SetState.init(arena, member, args) };
}), }),
.empty_dict => try stack.append(.{ .seq = .{ .type = .dict, .values = &[_]Value{} } }), .empty_dict => try stack.append(.{ .seq = .{ .type = .dict, .values = &[_]py.Any{} } }),
.get => |v| try stack.append(.{ .ref = v }), .get => |v| try stack.append(.{ .ref = v }),
.empty_list => try stack.append(.{ .seq = .{ .type = .list, .values = &[_]Value{} } }), .empty_list => try stack.append(.{ .seq = .{ .type = .list, .values = &[_]py.Any{} } }),
.put => |v| { .put => |v| {
try memo.insert(v, try pop(&stack)); try memo.insert(v, try pop(&stack));
try stack.append(.{ .ref = v }); try stack.append(.{ .ref = v });
}, },
.tuple => try stack.append(blk: { .tuple => try stack.append(blk: {
const popped = try popMark(&stack, arena); const popped = try popMark(&stack);
break :blk .{ .seq = .{ .type = .tuple, .values = popped } }; break :blk .{ .seq = .{ .type = .tuple, .values = try arena.dupe(py.Any, popped) } };
}), }),
.empty_tuple => try stack.append(.{ .seq = .{ .type = .tuple, .values = &[_]Value{} } }), .empty_tuple => try stack.append(.{ .seq = .{ .type = .tuple, .values = &[_]py.Any{} } }),
.setitem => { .setitem => {
const v, const k = .{ try pop(&stack), try pop(&stack) }; const v = try memo.resolve(arena, try pop(&stack), true);
const k = try memo.resolve(arena, try pop(&stack), true);
const top = try lastMut(&stack); const top = try lastMut(&stack);
const rtop = try memo.resolveMut(top, true); const rtop = try memo.resolveMut(top, true);
switch (rtop.*) { switch (rtop.*) {
.global => |obj| { .global => |obj| {
obj.args = try assuredResize(Value, arena, obj.args, obj.args.len + 1); try append(arena, &obj.kwargs, &.{ k, v });
obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ k, v }) } };
}, },
.seq => |*tup| { .seq => |*dict| {
tup.values = try assuredResize(Value, arena, tup.values, tup.values.len + 1); if (dict.type != .dict) return error.BadStackTopForSetItem;
tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ k, v }) } }; try append(arena, &dict.values, &.{ k, v });
}, },
else => { else => {
return error.BadStackTopForSetItem; return error.BadStackTopForSetItem;
@ -237,39 +212,35 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
} }
}, },
.setitems => { .setitems => {
const popped = try popMark(&stack, arena); const popped = try memo.resolveAllRefsIter(arena, 0, try popMark(&stack), true);
defer arena.free(popped);
const kv_items = try makeKVList(arena, popped);
const top = try lastMut(&stack); const top = try lastMut(&stack);
const rtop = try memo.resolveMut(top, true); const rtop = try memo.resolveMut(top, true);
switch (rtop.*) { switch (rtop.*) {
.global => |obj| { .global => |obj| {
obj.args = try assuredResize(Value, arena, obj.args, obj.args.len + 1); try append(arena, &obj.kwargs, popped);
obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } };
}, },
.seq => |*tup| { .seq => |*dict| {
tup.values = try assuredResize(Value, arena, tup.values, tup.values.len + 1); if (dict.type != .dict) return error.BadStackTopForSetItems;
tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } }; try append(arena, &dict.values, popped);
}, },
else => { else => {
defer arena.free(kv_items);
return error.BadStackTopForSetItems; return error.BadStackTopForSetItems;
}, },
} }
}, },
.proto => |proto| meta.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}), .proto => |proto| meta.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}),
.tuple1 => try stack.append(blk: { .tuple1 => try stack.append(blk: {
const tup_values = try arena.alloc(Value, 1); const tup_values = try arena.alloc(py.Any, 1);
tup_values[0] = try pop(&stack); tup_values[0] = try pop(&stack);
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}), }),
.tuple2 => try stack.append(blk: { .tuple2 => try stack.append(blk: {
const tup_values = try arena.alloc(Value, 2); const tup_values = try arena.alloc(py.Any, 2);
inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack); inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack);
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}), }),
.tuple3 => try stack.append(blk: { .tuple3 => try stack.append(blk: {
const tup_values = try arena.alloc(Value, 3); const tup_values = try arena.alloc(py.Any, 3);
inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack); inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack);
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}), }),
@ -279,12 +250,12 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
const rtop = try memo.resolveMut(top, true); const rtop = try memo.resolveMut(top, true);
switch (rtop.*) { switch (rtop.*) {
.global => |obj| { .global => |obj| {
obj.args = try assuredResize(Value, arena, obj.args, obj.args.len + 1); // can this happen ?
obj.args[obj.args.len - 1] = v; try append(arena, &obj.args, &.{v});
}, },
.seq => |*tup| { .seq => |*seq| {
tup.values = try assuredResize(Value, arena, tup.values, tup.values.len + 1); if (seq.type != .list) return error.BadStackTopForAppend;
tup.values[tup.values.len - 1] = v; try append(arena, &seq.values, &.{v});
}, },
else => { else => {
return error.BadStackTopForAppend; return error.BadStackTopForAppend;
@ -292,83 +263,75 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
} }
}, },
.appends => { .appends => {
const postmark = try popMark(&stack, arena); const postmark = try popMark(&stack);
defer arena.free(postmark);
const top = try lastMut(&stack); const top = try lastMut(&stack);
const rtop = try memo.resolveMut(top, true); const rtop = try memo.resolveMut(top, true);
switch (rtop.*) { switch (rtop.*) {
.global => |obj| { .global => try append(arena, &rtop.global.args, postmark),
const obj_len = obj.args.len; .seq => |*seq| {
obj.args = try assuredResize(Value, arena, obj.args, obj_len + postmark.len); if (seq.type != .list) return error.BadStackTopForAppend;
@memcpy(obj.args[obj_len..], postmark); try append(arena, &seq.values, postmark);
},
.seq => |*tup| {
const tup_len = tup.values.len;
tup.values = try assuredResize(Value, arena, tup.values, tup_len + postmark.len);
@memcpy(tup.values[tup_len..], postmark);
}, },
else => { else => {
return error.BadStackTopForAppends; return error.BadStackTopForAppends;
}, },
} }
}, },
.dict => try stack.append(blk: { .dict => try stack.append(.{ .seq = .{
const popped = try popMark(&stack, arena); .type = .dict,
defer arena.free(popped); .values = try arena.dupe(py.Any, try popMark(&stack)),
const kv_items = try makeKVList(arena, popped); } }),
break :blk .{ .seq = .{ .type = .dict, .values = kv_items } }; .list => try stack.append(.{ .seq = .{
}), .type = .list,
.list => try stack.append(.{ .seq = .{ .type = .list, .values = try popMark(&stack, arena) } }), .values = try arena.dupe(py.Any, try popMark(&stack)),
.inst => |v| try stack.append(blk: { } }),
const tup_items = try arena.dupe(Value, &.{ .{ .string = v.module }, .{ .string = v.class } }); .inst => |v| try stack.append(.{ .object = try py.Object.init(
break :blk .{ .object = try Object.init(arena, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try popMark(&stack, arena)) }; arena,
}), try py.tuple(&.{ .{ .string = v.module }, .{ .string = v.class } }).clone(arena),
try arena.dupe(py.Any, try popMark(&stack)),
&.{},
) }),
.obj => try stack.append(blk: { .obj => try stack.append(blk: {
const mark = try findMark(&stack); const mark = try findMark(&stack);
const args = try arena.dupe(Value, stack.items[mark + 2 ..]); const args = try arena.dupe(py.Any, stack.items[mark + 2 ..]);
const member = stack.items[mark + 1]; const member = stack.items[mark + 1];
break :blk .{ .object = try Object.init(arena, member, args) }; break :blk .{ .object = try py.Object.init(arena, member, args, &.{}) };
}), }),
.newobj => try stack.append(blk: { .newobj => try stack.append(blk: {
const args = try arena.alloc(Value, 1); const args = try arena.alloc(py.Any, 1);
args[0] = try pop(&stack); args[0] = try pop(&stack);
break :blk .{ .object = try Object.init(arena, try pop(&stack), args) }; break :blk .{ .object = try py.Object.init(arena, try pop(&stack), args, &.{}) };
}), }),
.empty_set => try stack.append(.{ .seq = .{ .type = .set, .values = &[_]Value{} } }), .empty_set => try stack.append(.{ .seq = .{ .type = .set, .values = &[_]py.Any{} } }),
.additems => { .additems => {
const postmark = try popMark(&stack, arena); const postmark = try popMark(&stack);
defer arena.free(postmark);
const top = try lastMut(&stack); const top = try lastMut(&stack);
const rtop = try memo.resolveMut(top, true); const rtop = try memo.resolveMut(top, true);
switch (rtop.*) { switch (rtop.*) {
.global => |obj| { .seq => |*seq| {
const obj_len = obj.args.len; if (seq.type != .set) return error.BadStackTopForAppend;
obj.args = try assuredResize(Value, arena, obj.args, obj_len + postmark.len); try append(arena, &seq.values, postmark);
@memcpy(obj.args[obj_len..], postmark);
},
.seq => |*tup| {
const tup_len = tup.values.len;
tup.values = try assuredResize(Value, arena, tup.values, tup_len + postmark.len);
@memcpy(tup.values[tup_len..], postmark);
}, },
else => { else => {
return error.BadStackTopForSetItem; return error.BadStackTopForAppends;
}, },
} }
}, },
.frozenset => try stack.append(.{ .seq = .{ .type = .frozen_set, .values = try popMark(&stack, arena) } }), .frozenset => try stack.append(.{ .seq = .{
.type = .frozen_set,
.values = try arena.dupe(py.Any, try popMark(&stack)),
} }),
.newobj_ex => try stack.append(blk: { .newobj_ex => try stack.append(blk: {
const kwargs, const args, const cls = .{ try pop(&stack), try pop(&stack), try pop(&stack) }; const kwargs, const args, const cls = .{ try pop(&stack), try pop(&stack), try pop(&stack) };
const new_seq: Sequence = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ args, kwargs }) }; break :blk .{ .object = try py.Object.init(arena, cls, args.seq.values, kwargs.seq.values) };
break :blk .{ .object = try Object.init(arena, cls, try arena.dupe(Value, &.{.{ .seq = new_seq }})) };
}), }),
.stack_global => try stack.append(blk: { .stack_global => try stack.append(blk: {
const gn, const mn = .{ const gn, const mn = .{
try memo.resolve(arena, try pop(&stack), true), try memo.resolve(arena, try pop(&stack), true),
try memo.resolve(arena, try pop(&stack), true), try memo.resolve(arena, try pop(&stack), true),
}; };
const new_seq: Sequence = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ gn, mn }) }; const new_seq: py.Sequence = .{ .type = .tuple, .values = try arena.dupe(py.Any, &.{ gn, mn }) };
break :blk .{ .object = try Object.init(arena, .{ .seq = new_seq }, &[_]Value{}) }; break :blk .{ .object = try py.Object.init(arena, .{ .seq = new_seq }, &.{}, &.{}) };
}), }),
.memoize => { .memoize => {
const item = stack.getLastOrNull() orelse { const item = stack.getLastOrNull() orelse {
@ -385,23 +348,17 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
return stack.toOwnedSlice(); return stack.toOwnedSlice();
} }
// TODO: this is a unmanaged array list, minus the optimisation. We should use that instead fn append(allocator: std.mem.Allocator, current: *[]py.Any, values: []const py.Any) !void {
fn assuredResize(comptime T: type, allocator: std.mem.Allocator, old: []T, new_length: usize) ![]T { var array_list = std.ArrayListUnmanaged(py.Any).fromOwnedSlice(current.*);
if (allocator.resize(old, new_length)) { try array_list.appendSlice(allocator, values);
return old; current.* = array_list.items;
} else {
defer allocator.free(old);
const new = try allocator.alloc(T, new_length);
@memcpy(new[0..old.len], old);
return new;
}
} }
test evaluate { test evaluate {
var arena = std.heap.ArenaAllocator.init(std.testing.allocator); var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
defer arena.deinit(); defer arena.deinit();
const allocator = arena.allocator(); const allocator = arena.allocator();
const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test.pickle", .{ .mode = .read_only }); const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only });
var buffered_reader = std.io.bufferedReader(file.reader()); var buffered_reader = std.io.bufferedReader(file.reader());
const ops = try pickle.parse(allocator, buffered_reader.reader(), 4096); const ops = try pickle.parse(allocator, buffered_reader.reader(), 4096);
@ -411,64 +368,62 @@ test evaluate {
try std.testing.expect(vals.len == 1); try std.testing.expect(vals.len == 1);
try std.testing.expect(vals[0] == .seq); try std.testing.expect(vals[0] == .seq);
try std.testing.expect(vals[0].seq.type == .dict); try std.testing.expect(vals[0].seq.type == .dict);
const entries = vals[0].seq.values[0].seq.values; const entries = vals[0].seq.values;
try std.testing.expect(entries.len == 5); const expected: []const py.Any = &.{
const expected: []const Value = &.{ // Key, followed by its value
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .string = "hello" }, .{ .string = "world" } }) } }, .{ .string = "hello" }, .{ .string = "world" },
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .string = "int" }, .{ .int64 = 1 } }) } }, .{ .string = "int" }, .{ .int64 = 1 },
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .string = "float" }, .{ .float64 = 3.141592 } }) } }, .{ .string = "float" }, .{ .float64 = 3.141592 },
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{
.{ .string = "list" }, .{ .string = "list" },
.{ .seq = .{ .type = .list, .values = @constCast(&[_]Value{ .{
.{ .int64 = 0 }, .seq = .{
.{ .int64 = 1 }, .type = .list,
.{ .int64 = 2 }, .values = @constCast(&[_]py.Any{
.{ .int64 = 3 }, .{ .int64 = 255 },
.{ .int64 = 4 }, .{ .int64 = 1234 },
}) } }, .{ .int64 = -123 },
}) } }, .{ .int64 = 1_000_000_000 },
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .int64 = 999_000_000_000 },
.{ .bigint = (try std.math.big.int.Managed.initSet(allocator, 999_000_000_000_000_000_000_000_000_000)).toConst() },
}),
},
},
.{ .string = "bool" }, .{ .boolval = false },
.{ .string = "tuple" }, .{ .string = "tuple" },
.{ .seq = .{ .{ .seq = .{
.type = .tuple, .type = .tuple,
.values = @constCast(&[_]Value{ .values = @constCast(&[_]py.Any{
.{ .string = "a" }, .{ .string = "a" },
.{ .int64 = 10 }, .{ .int64 = 10 },
}), }),
} }, } },
}) } },
}; };
try std.testing.expectEqualDeep(expected, entries); try std.testing.expectEqualDeep(expected, entries);
} }
pub fn pop(values: *std.ArrayList(Value)) !Value { pub fn pop(values: *std.ArrayList(py.Any)) !py.Any {
if (values.items.len == 0) { if (values.items.len == 0) {
return error.StackUnderrun; return error.StackUnderrun;
} }
return values.pop(); return values.pop();
} }
fn popMarkDiscard(values: *std.ArrayList(Value)) !void { fn popMark(values: *std.ArrayList(py.Any)) ![]py.Any {
const mark = try findMark(values);
values.shrinkRetainingCapacity(mark);
}
fn popMark(values: *std.ArrayList(Value), allocator: std.mem.Allocator) ![]Value {
const mark = try findMark(values); const mark = try findMark(values);
const popping = values.items[mark + 1 ..]; const popping = values.items[mark + 1 ..];
values.shrinkRetainingCapacity(mark); values.shrinkRetainingCapacity(mark);
return try allocator.dupe(Value, popping); return popping;
} }
fn lastMut(values: *std.ArrayList(Value)) !*Value { fn lastMut(values: *std.ArrayList(py.Any)) !*py.Any {
if (values.items.len == 0) { if (values.items.len == 0) {
return error.UnexpectedEmptyStack; return error.UnexpectedEmptyStack;
} }
return &values.items[values.items.len - 1]; return &values.items[values.items.len - 1];
} }
fn findMark(values: *std.ArrayList(Value)) !usize { fn findMark(values: *std.ArrayList(py.Any)) !usize {
const len = values.items.len; const len = values.items.len;
for (0..len) |i| { for (0..len) |i| {
const idx = (len - 1) - i; const idx = (len - 1) - i;

680
zml/aio/torch/file.zig Normal file
View File

@ -0,0 +1,680 @@
const std = @import("std");
const testing = std.testing;
const log = std.log.scoped(.zml_aio);
const asynk = @import("async");
const zml = @import("../../zml.zig");
const pickle = @import("pickle.zig");
const py = @import("py.zig");
const eval = @import("eval.zig");
const HostBuffer = zml.HostBuffer;
// TODO(cryptodeal): use zml.aio.PrefixBuilder instead
const StringBuilder = std.ArrayListUnmanaged(u8);
// Reference all declarations in this module and in `File` so their nested
// `test` blocks are discovered and run by the test runner.
test {
std.testing.refAllDecls(@This());
std.testing.refAllDecls(File);
}
pub const File = struct {
/// Memory-mapped backing file holding the raw bytes of the checkpoint.
buffer_file: zml.aio.MemoryMappedFile,
/// Map names to sub file
file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{},
/// Set when the pickle data lives inside a tar archive; reads then go
/// through this stream instead of `buffer_file.file` directly.
tar_file: ?TarStream = null,
/// True when the backing data starts with the zip local-file-header magic.
is_zip_file: bool,
// NOTE(review): presumably a common path prefix shared by entries inside the
// zip archive — not established by the code visible here; confirm.
zip_prefix: []const u8 = &.{},
/// Offset and length of the pickle payload within the backing file/stream.
pickle_subfile: struct { start: u64 = 0, len: usize },
/// Metadata for one entry of a zip archive, copied out of the entry type
/// yielded by `std.zip.Iterator` so it can outlive the iteration.
pub const FileEntry = struct {
version_needed_to_extract: u16,
flags: u16,
compression_method: std.zip.CompressionMethod,
last_modification_time: u16,
last_modification_date: u16,
header_zip_offset: u64,
crc32: u32,
filename_len: u32,
compressed_size: u64,
uncompressed_size: u64,
file_offset: u64,
/// Copies the relevant fields from a zip iterator entry. `anytype` so any
/// entry type with matching field names works.
pub fn init(entry: anytype) FileEntry {
return .{
.version_needed_to_extract = entry.version_needed_to_extract,
// Store the flags as their raw u16 bit pattern.
.flags = @as(u16, @bitCast(entry.flags)),
.compression_method = entry.compression_method,
.last_modification_time = entry.last_modification_time,
.last_modification_date = entry.last_modification_date,
.header_zip_offset = entry.header_zip_offset,
.crc32 = entry.crc32,
.filename_len = entry.filename_len,
.compressed_size = entry.compressed_size,
.uncompressed_size = entry.uncompressed_size,
.file_offset = entry.file_offset,
};
}
};
const magic = "PK\x03\x04";
/// Build a `File` from a single entry embedded inside a tar archive.
/// Sniffs the zip magic to decide whether the entry is a zip container
/// (a modern `.pt`) or a bare pickle stream.
pub fn fromTarFile(allocator: std.mem.Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !File {
    const stream = try TarStream.init(file);
    const header = try stream.reader().readBytesNoEof(magic.len);
    // Rewind: the sniffed bytes belong to whatever gets parsed next.
    try stream.seekTo(0);
    const zipped = std.mem.eql(u8, &header, magic);
    var self: File = .{
        .buffer_file = mapped,
        .tar_file = stream,
        .is_zip_file = zipped,
        .pickle_subfile = .{ .len = try stream.getEndPos() },
    };
    if (zipped) {
        try self.parseZipHeaders(allocator, stream.seekableStream());
    }
    return self;
}
/// Build a `File` directly from a memory-mapped file.
/// Sniffs the zip magic to decide whether this is a zip container or a
/// bare pickle stream.
pub fn init(allocator: std.mem.Allocator, mmap_file: zml.aio.MemoryMappedFile) !File {
    const header = try mmap_file.file.reader().readBytesNoEof(magic.len);
    // Rewind so later reads see the file from the start.
    try mmap_file.file.seekTo(0);
    const zipped = std.mem.eql(u8, &header, magic);
    var self: File = .{
        .buffer_file = mmap_file,
        .is_zip_file = zipped,
        .pickle_subfile = .{ .len = mmap_file.data.len },
    };
    if (zipped) {
        try self.parseZipHeaders(allocator, mmap_file.file.seekableStream());
    }
    return self;
}
/// Release the underlying memory-mapped file.
/// NOTE(review): entries duped into `file_map` by `parseZipHeaders` are not
/// freed here; they are expected to be owned by the caller's allocator/arena.
pub fn close(self: *File) void {
    self.buffer_file.deinit();
}
/// Parse the pickle opcode stream of this file, starting at
/// `pickle_subfile.start`. The returned ops are allocated with `allocator`
/// and owned by the caller.
pub fn parsePickle(self: *File, allocator: std.mem.Allocator) ![]const pickle.Op {
    if (self.tar_file) |tar_file| {
        try tar_file.seekTo(self.pickle_subfile.start);
        var buffered = std.io.bufferedReader(tar_file.reader());
        return try pickle.parse(allocator, buffered.reader(), self.pickle_subfile.len);
    }
    const file = self.buffer_file.file;
    try file.seekTo(self.pickle_subfile.start);
    var buffered = std.io.bufferedReader(file.reader());
    return try pickle.parse(allocator, buffered.reader(), self.pickle_subfile.len);
}
/// Scan the zip central directory, record every archive member into
/// `self.file_map`, then locate the `*data.pkl` member and remember its
/// bounds in `self.pickle_subfile` for later parsing.
/// Returns `error.PickleNotFound` when no member name ends in "data.pkl".
fn parseZipHeaders(self: *File, allocator: std.mem.Allocator, seekable_stream: anytype) !void {
    var file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{};
    var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream);
    var filename_buf: [std.fs.max_path_bytes]u8 = undefined;
    while (try iter.next()) |entry| {
        const filename = filename_buf[0..entry.filename_len];
        // The name is stored right after the fixed-size central directory header.
        try seekable_stream.seekTo(entry.header_zip_offset + @sizeOf(std.zip.CentralDirectoryFileHeader));
        const len = try seekable_stream.context.reader().readAll(filename);
        if (len != filename.len) return error.ZipBadFileOffset;
        if (isBadFilename(filename)) return error.ZipBadFilename;
        std.mem.replaceScalar(u8, filename, '\\', '/'); // normalize path separators
        // Keys must be duped: `filename_buf` is reused on every iteration.
        try file_map.put(allocator, try allocator.dupe(u8, filename), FileEntry.init(entry));
    }
    self.file_map = file_map;
    var file_iter = file_map.iterator();
    while (file_iter.next()) |e| {
        const entry = e.value_ptr.*;
        const filename = e.key_ptr.*;
        if (!std.mem.endsWith(u8, filename, "data.pkl")) continue;
        // e.g. "model/data.pkl" -> "model/"; used later to resolve storage files.
        self.zip_prefix = filename[0 .. filename.len - "data.pkl".len];
        const local_data_header_offset: u64 = local_data_header_offset: {
            switch (entry.compression_method) {
                .store => {},
                .deflate => {
                    // TODO(cryptodeal): handle decompress
                    @panic("TODO support use of `deflate`");
                },
                else => @panic("TODO support other modes of compression"),
            }
            const local_header = blk: {
                try seekable_stream.seekTo(entry.file_offset);
                break :blk try seekable_stream.context.reader().readStructEndian(std.zip.LocalFileHeader, .little);
            };
            // Cross-check the local file header against the central directory
            // record, mirroring the validation done by std.zip.
            if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig))
                return error.ZipBadFileOffset;
            if (local_header.version_needed_to_extract != entry.version_needed_to_extract)
                return error.ZipMismatchVersionNeeded;
            if (local_header.last_modification_time != entry.last_modification_time)
                return error.ZipMismatchModTime;
            if (local_header.last_modification_date != entry.last_modification_date)
                return error.ZipMismatchModDate;
            if (@as(u16, @bitCast(local_header.flags)) != entry.flags)
                return error.ZipMismatchFlags;
            // crc32/sizes may legitimately be zero in the local header when a
            // data descriptor is used, hence the `!= 0` guards.
            if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32)
                return error.ZipMismatchCrc32;
            if (local_header.compressed_size != 0 and
                local_header.compressed_size != entry.compressed_size)
                return error.ZipMismatchCompLen;
            if (local_header.uncompressed_size != 0 and
                local_header.uncompressed_size != entry.uncompressed_size)
                return error.ZipMismatchUncompLen;
            if (local_header.filename_len != entry.filename_len)
                return error.ZipMismatchFilenameLen;
            break :local_data_header_offset @as(u64, local_header.filename_len) +
                @as(u64, local_header.extra_len);
        };
        // Absolute offset of the pickle payload inside the archive.
        const local_data_file_offset: u64 =
            @as(u64, entry.file_offset) +
            @as(u64, @sizeOf(std.zip.LocalFileHeader)) +
            local_data_header_offset;
        self.pickle_subfile = .{ .start = local_data_file_offset, .len = entry.uncompressed_size };
        return;
    }
    log.err("Could not find file ending in `data.pkl` in archive", .{});
    return error.PickleNotFound;
}
/// True when `object.member` is a raw global whose module and class names
/// match `module` and `class` exactly.
fn basicTypeCheck(object: *const py.Object, module: []const u8, class: []const u8) bool {
    switch (object.member) {
        .raw => |raw| {
            const global = raw.global;
            return std.mem.eql(u8, module, global.module) and
                std.mem.eql(u8, class, global.class);
        },
        else => return false,
    }
}
/// Walk every top-level python value and populate `store` with the tensors
/// and metadata found along the way.
pub fn parseModel(self: File, values: []const py.Any, store: *zml.aio.BufferStore) !void {
    // Key prefixes are built inside a fixed stack buffer; nesting whose
    // joined path exceeds 1024 bytes is not supported.
    var prefix_buf: [1024]u8 = undefined;
    const allocator = store.arena.allocator();
    for (values) |value| {
        try self.parseValue(allocator, store, StringBuilder.initBuffer(&prefix_buf), value);
    }
}
/// Recursively walk one decoded python value, extending `prefix` with path
/// components and recording leaves into `store`: tensors become buffers
/// (via `parseTorchGlobal`), scalars/strings become metadata entries.
/// `prefix` is passed by value, so mutations only affect the current subtree.
pub fn parseValue(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: py.Any) !void {
    // log.warn("Parsing {}", .{v});
    switch (v) {
        .app, .object, .global => |object| {
            // If it's a known torch global (tensor, Size, Fraction) it is
            // fully handled there; otherwise recurse into the object.
            if (!(try self.parseTorchGlobal(allocator, store, prefix, v))) {
                try self.parseValue(allocator, store, prefix, object.member);
                for (object.args) |item| {
                    try self.parseValue(allocator, store, prefix, item);
                }
                // kwargs are stored flat: [key0, val0, key1, val1, ...].
                if (object.kwargs.len % 2 != 0) return error.InvalidInput;
                const n_kwargs = @divExact(object.kwargs.len, 2);
                for (0..n_kwargs) |i| {
                    const key, const val = object.kwargs[2 * i ..][0..2].*;
                    // kwargs can only be keyed by string.
                    if (key != .string) return error.InvalidInput;
                    // Handle Pytorch specific fields
                    const s = key.string;
                    if (std.mem.eql(u8, s, "_modules") or std.mem.eql(u8, s, "_parameters") or std.mem.eql(u8, s, "_buffers")) {
                        // Pytorch wrapper dicts: recurse without extending the prefix.
                        try self.parseValue(allocator, store, prefix, val);
                    } else {
                        var new_prefix = prefix;
                        if (prefix.items.len > 0) {
                            new_prefix.appendAssumeCapacity('.');
                        }
                        new_prefix.appendSliceAssumeCapacity(s);
                        try self.parseValue(allocator, store, new_prefix, val);
                    }
                }
            }
        },
        .set_state => |set_state| {
            // `set_state` contains info about python struct being constructed
            switch (set_state.obj) {
                .object => |obj| switch (obj.member) {
                    .raw => |raw| switch (raw) {
                        .global => |global| {
                            // in this case, we can capture the name of the python type
                            // which can be used for codegen (e.g. `torch.nn.modules.conv.Conv2d`)
                            var new_prefix = prefix;
                            if (prefix.items.len > 0) {
                                new_prefix.appendAssumeCapacity('.');
                            }
                            new_prefix.appendSliceAssumeCapacity("_gen_type_helper");
                            const key = try allocator.dupe(u8, new_prefix.items);
                            const d = try store._metadata.getOrPut(allocator, key);
                            if (d.found_existing) {
                                log.err("Duplicate key: {s}", .{new_prefix.items});
                                allocator.free(key);
                            } else {
                                const val = try std.mem.join(allocator, ".", &.{ global.module, global.class });
                                d.value_ptr.* = .{ .string = val };
                            }
                        },
                        else => try self.parseValue(allocator, store, prefix, set_state.obj), // parse normally
                    },
                    else => try self.parseValue(allocator, store, prefix, set_state.obj), // parse normally
                },
                else => try self.parseValue(allocator, store, prefix, set_state.obj), // parse normally
            }
            // The state payload carries the actual attributes.
            try self.parseValue(allocator, store, prefix, set_state.state);
        },
        // Persistent ids just wrap another value.
        .pers_id => |pers_id| try self.parseValue(allocator, store, prefix, pers_id.ref),
        .seq => |seq| {
            switch (seq.type) {
                .list, .tuple, .set, .frozen_set => {
                    if (seq.values.len == 0) return;
                    // Homogeneous scalar sequences are stored as one metadata
                    // slice; as soon as a mismatched element shows up we fall
                    // back to per-index entries ("prefix.N").
                    var valid_slice = true;
                    switch (seq.values[0]) {
                        inline .int64, .float64, .boolval => |val0, tag| {
                            const ItemType = switch (tag) {
                                .int64 => i64,
                                .float64 => f64,
                                .boolval => bool,
                                else => unreachable,
                            };
                            var values: std.ArrayListUnmanaged(ItemType) = .{};
                            try values.append(allocator, val0);
                            for (seq.values[1..], 1..) |val, i| {
                                if (std.meta.activeTag(val) != tag) valid_slice = false;
                                if (valid_slice) {
                                    try values.append(allocator, @field(val, @tagName(tag)));
                                } else {
                                    var new_prefix = prefix;
                                    if (prefix.items.len > 0) {
                                        new_prefix.appendAssumeCapacity('.');
                                    }
                                    new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
                                    try self.parseValue(allocator, store, new_prefix, val);
                                }
                            }
                            if (valid_slice) {
                                try store._metadata.put(
                                    allocator,
                                    try allocator.dupe(u8, prefix.items),
                                    try zml.aio.Metadata.copySlice(allocator, values.items),
                                );
                            } else {
                                // Re-emit the elements collected before the
                                // mismatch as individual "prefix.N" entries.
                                for (values.items, 0..) |val, i| {
                                    var new_prefix = prefix;
                                    if (prefix.items.len > 0) {
                                        new_prefix.appendAssumeCapacity('.');
                                    }
                                    new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
                                    const new_tag = switch (tag) {
                                        .int64 => "int",
                                        .float64 => "float",
                                        .boolval => "bool",
                                        else => unreachable, // we are already inside a switch
                                    };
                                    try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), @unionInit(zml.aio.Metadata, new_tag, val));
                                }
                            }
                        },
                        else => {
                            for (seq.values, 0..) |item, i| {
                                var new_prefix = prefix;
                                // NOTE(review): this tests the *sequence* `v`,
                                // not `item` — confirm this is intentional.
                                if (v.isPrimitive()) {
                                    if (prefix.items.len > 0) {
                                        new_prefix.appendAssumeCapacity('.');
                                    }
                                    new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
                                }
                                try self.parseValue(allocator, store, new_prefix, item);
                            }
                        },
                    }
                },
                .dict => {
                    // Dicts are stored flat: [key0, val0, key1, val1, ...].
                    const n = @divExact(seq.values.len, 2);
                    log.info("found dict with {} entries", .{n});
                    for (0..n) |i| {
                        const key, const val = seq.values[2 * i ..][0..2].*;
                        switch (key) {
                            .string => |s| {
                                // Handle Pytorch specific fields.
                                if (std.mem.eql(u8, s, "_modules") or std.mem.eql(u8, s, "_parameters") or std.mem.eql(u8, s, "_buffers")) {
                                    try self.parseValue(allocator, store, prefix, val);
                                } else {
                                    var new_prefix = prefix;
                                    if (prefix.items.len > 0) {
                                        new_prefix.appendAssumeCapacity('.');
                                    }
                                    new_prefix.appendSliceAssumeCapacity(s);
                                    try self.parseValue(allocator, store, new_prefix, val);
                                }
                            },
                            .int64 => |int| {
                                // Integer keys become "prefix.N".
                                var new_prefix = prefix;
                                if (prefix.items.len > 0) {
                                    new_prefix.appendAssumeCapacity('.');
                                }
                                new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{});
                                try self.parseValue(allocator, store, new_prefix, val);
                            },
                            inline else => |_, tag| {
                                log.debug("Ignoring unsupported key type found in torch file: {s}", .{@tagName(tag)});
                                continue;
                            },
                        }
                    }
                },
            }
        },
        .bytes => |val| {
            // Raw bytes are stored as a string metadata entry.
            const key = try allocator.dupe(u8, prefix.items);
            const d = try store._metadata.getOrPut(allocator, key);
            if (d.found_existing) {
                log.warn("Duplicate key: {s}", .{prefix.items});
                allocator.free(key);
            } else d.value_ptr.* = .{ .string = val };
        },
        inline .float64, .int64, .boolval, .bigint, .string => |val| {
            // Scalar leaves become metadata entries keyed by the full prefix.
            const key = try allocator.dupe(u8, prefix.items);
            const d = try store._metadata.getOrPut(allocator, key);
            if (d.found_existing) {
                log.warn("Duplicate key: {s}", .{prefix.items});
                allocator.free(key);
            } else {
                d.value_ptr.* = zml.aio.Metadata.wrap(val);
            }
        },
        // Remaining variants (e.g. None) carry no data worth storing.
        else => {},
    }
}
/// Recognize torch-specific globals and record them into `store`.
/// Returns true when `v` was fully handled here (a tensor, a `torch.Size`,
/// or a `fractions.Fraction`); false means the caller should keep recursing.
fn parseTorchGlobal(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: py.Any) !bool {
    return switch (v) {
        .global => |object| {
            if (try self.parseTensor(allocator, object)) |host_buffer| {
                const key = try allocator.dupe(u8, prefix.items);
                const entry = try store.buffers.getOrPut(allocator, key);
                if (entry.found_existing) {
                    log.warn("Duplicate key: {s}", .{prefix.items});
                    allocator.free(key);
                }
                // On duplicate, the previous buffer is overwritten.
                entry.value_ptr.* = host_buffer;
                return true;
            } else if (basicTypeCheck(object, "torch", "Size")) {
                // torch.Size(...) -> store dims as an array_int metadata entry.
                const size = object.args;
                const key = try allocator.dupe(u8, prefix.items);
                const entry = try store._metadata.getOrPut(allocator, key);
                if (entry.found_existing) {
                    log.warn("Duplicate key: {s}", .{prefix.items});
                    allocator.free(key);
                }
                const d = try allocator.alloc(i64, size.len);
                for (d, 0..) |*di, i| di.* = size[i].int64;
                entry.value_ptr.* = .{ .array_int = d };
                return true;
            } else if (basicTypeCheck(object, "fractions", "Fraction")) {
                // Fraction("num/den") -> two int entries:
                // "<prefix>.numerator" and "<prefix>.denominator".
                const fraction_str = object.args[0].string;
                if (std.mem.indexOfScalar(u8, fraction_str, '/')) |split_idx| {
                    {
                        var new_prefix = prefix;
                        new_prefix.appendSliceAssumeCapacity(".numerator");
                        try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) });
                    }
                    {
                        var new_prefix = prefix;
                        new_prefix.appendSliceAssumeCapacity(".denominator");
                        try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) });
                    }
                    return true;
                }
            }
            // A Fraction without '/' (or any other global) is not handled here.
            return false;
        },
        else => false,
    };
}
/// Decode a call to `torch._utils._rebuild_tensor_v2` into a `HostBuffer`
/// backed directly by the memory-mapped archive (zero copy).
/// Returns null when `object` is not such a call; `error.InvalidInput` when
/// the call shape doesn't match the expected signature.
fn parseTensor(self: File, tmp_allocator: std.mem.Allocator, object: *py.Object) !?zml.HostBuffer {
    if (!basicTypeCheck(object, "torch._utils", "_rebuild_tensor_v2")) {
        return null;
    }
    // Expected args: (storage persistent id, storage offset, dims, strides, ...)
    const args = object.args;
    if (args.len < 4 or
        args[0] != .pers_id or
        args[1] != .int64 or
        args[2] != .seq or args[2].seq.type != .tuple or
        args[3] != .seq or args[3].seq.type != .tuple)
    {
        log.err("Unexpected py.Any in call to torch._utils._rebuild_tensor_v2: {}", .{object.*});
        return error.InvalidInput;
    }
    const pid: *py.PersId = args[0].pers_id;
    var offset: u64 = @intCast(args[1].int64);
    const raw_dims: py.Sequence = args[2].seq;
    const raw_strides: py.Sequence = args[3].seq;
    const dims = try parseDims(raw_dims.values);
    var strides = try parseDims(raw_strides.values);
    const dtype, const storage_file = try parseStorage(pid.ref);
    // Pytorch store "item" strides, while ZML uses byte strides.
    for (strides.slice()) |*s| s.* *= dtype.sizeOf();
    // Same thing for the offset.
    offset = offset * dtype.sizeOf();
    const filename = try std.mem.join(tmp_allocator, "", &.{ self.zip_prefix, "data/", storage_file });
    defer tmp_allocator.free(filename);
    // The offset in the pickle is the offset inside the storage_file.
    // But .pt are made of several files, so we need to append the file offset.
    const storage = try self.getStorage(filename);
    return HostBuffer.fromStridedSlice(
        zml.Shape.init(dims.constSlice(), dtype),
        storage[offset..],
        strides.constSlice(),
    );
}
/// Decode the persistent-id tuple pytorch uses to describe a tensor storage:
/// ("storage", <torch.XxxStorage global>, <storage file name>, <device>, ...).
/// Returns the element dtype and the storage file name.
fn parseStorage(val: py.Any) !struct { zml.DataType, []const u8 } {
    if (val != .seq) return error.InvalidInput;
    if (val.seq.type != .tuple) return error.InvalidInput;
    const sargs = val.seq.values;
    if (sargs.len < 5) return error.InvalidInput;
    if (sargs[0] != .string or !std.mem.eql(u8, sargs[0].string, "storage")) return error.InvalidInput;
    if (sargs[1] != .raw or sargs[1].raw != .global) return error.InvalidInput;
    if (sargs[2] != .string) return error.InvalidInput;
    // sargs[3] is the device string; it must be present but is unused.
    if (sargs[3] != .string) return error.InvalidInput;

    const op = sargs[1].raw.global;
    if (!std.mem.eql(u8, "torch", op.module) or
        !std.mem.endsWith(u8, op.class, "Storage"))
        return error.InvalidInput;
    return .{ try storageToDtype(op.class), sargs[2].string };
}
/// Given the name of one of the files in the .pt archive,
/// return the slice of the memory-mapped .pt corresponding to it.
/// Returns `error.TensorNotFound` when `filename` is not in the archive.
fn getStorage(self: File, filename: []const u8) ![]const u8 {
    const entry = self.file_map.get(filename) orelse {
        // Fixed: the format-args tuple was corrupted (`.(unknown)`), which
        // does not compile; pass the missing filename argument.
        std.log.err("Could not find file `{s}` in archive", .{filename});
        return error.TensorNotFound;
    };
    // When the zip archive lives inside a tar, all zip offsets are relative
    // to the start of the tar entry, not to the start of the mapped file.
    const base_offset: u64 = if (self.tar_file) |t| t.start else 0;
    const file_offset: u64 = base_offset + entry.file_offset;
    const file = self.buffer_file.file;
    // Fixed: seek at the absolute offset. Seeking at `entry.file_offset`
    // ignored `base_offset` and read the wrong header for tar-embedded
    // archives, while `mappedSlice` below already used the absolute offset.
    try file.seekTo(file_offset);
    const local_header = try file.reader().readStructEndian(std.zip.LocalFileHeader, .little);
    if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig))
        return error.ZipBadFileOffset;
    // Sizes may be zero in the local header when a data descriptor is used.
    if (local_header.compressed_size != 0 and
        local_header.compressed_size != entry.compressed_size)
        return error.ZipMismatchCompLen;
    if (local_header.uncompressed_size != 0 and
        local_header.uncompressed_size != entry.uncompressed_size)
        return error.ZipMismatchUncompLen;
    if (local_header.filename_len != entry.filename_len)
        return error.ZipMismatchFilenameLen;
    // The payload starts after the local header, its filename and extra field.
    const start = file_offset +
        @sizeOf(std.zip.LocalFileHeader) +
        @as(u64, local_header.filename_len) +
        @as(u64, local_header.extra_len);
    return self.buffer_file.mappedSlice(start, entry.uncompressed_size);
}
/// Convert a python tuple of ints into a dims/strides array.
/// Fails with `error.InvalidInput` on any non-integer element; panics via
/// `zml.meta.assert` when the rank exceeds what zml supports.
fn parseDims(values: []py.Any) error{InvalidInput}!zml.Shape.DimsArray {
    zml.meta.assert(values.len <= zml.Tensor.MAX_RANK, "Found Pytorch tensor with unsupported rank {}", .{values.len});
    var dims: zml.Shape.DimsArray = .{};
    for (values) |value| {
        if (value != .int64) return error.InvalidInput;
        dims.appendAssumeCapacity(value.int64);
    }
    return dims;
}
};
/// Map a legacy `torch.<Type>Storage` class name to a `zml.DataType`.
/// TODO: make this future proof, storage types are going to get replaced with torch.UntypedStorage
/// See https://pytorch.org/docs/stable/storage.html
fn storageToDtype(storage_type: []const u8) !zml.DataType {
    // Strip the trailing "Storage" to get the bare torch type name.
    const torch_type = storage_type[0 .. storage_type.len - "Storage".len];
    const dtype_map = std.StaticStringMap(zml.DataType).initComptime(.{
        // Floating point types.
        .{ "Double", .f64 },
        .{ "Float", .f32 },
        .{ "Half", .f16 },
        .{ "BFloat16", .bf16 },
        // Complex types.
        .{ "ComplexDouble", .c128 },
        .{ "ComplexFloat", .c64 },
        // Integer and boolean types.
        .{ "Long", .i64 },
        .{ "Int", .i32 },
        .{ "Short", .i16 },
        .{ "Char", .i8 },
        .{ "Byte", .u8 },
        .{ "Bool", .bool },
        // Quantized storages are not supported:
        // QUInt8, QInt8, QInt32, QUInt4x2, QUInt2x4.
    });
    if (dtype_map.get(torch_type)) |dtype| return dtype;
    log.err("Unsupported torch storage type: {s}", .{storage_type});
    return error.UnsupportedDataType;
}
/// Seekable view over a single entry of a tar archive.
/// Offsets given to the stream are relative to the entry start; reads and
/// seeks are forwarded to the tar's parent reader.
const TarStream = struct {
    pub const SeekableStream = std.io.SeekableStream(
        TarStream,
        asynk.File.SeekError,
        asynk.File.GetSeekPosError,
        TarStream.seekTo,
        TarStream.seekBy,
        TarStream.getPos,
        TarStream.getEndPos,
    );

    file: std.tar.Iterator(asynk.File.Reader).File,
    // Absolute position of this entry inside the parent tar file.
    start: usize,

    pub fn init(file: std.tar.Iterator(asynk.File.Reader).File) !TarStream {
        return .{
            .file = file,
            .start = try file.parent_reader.context.getPos(),
        };
    }

    pub fn reader(self: TarStream) std.tar.Iterator(asynk.File.Reader).File.Reader {
        return self.file.reader();
    }

    pub fn seekTo(self: TarStream, offset: u64) !void {
        // Translate the entry-relative offset to an absolute one.
        return self.file.parent_reader.context.seekTo(self.start + offset);
    }

    pub fn seekBy(self: TarStream, offset: i64) !void {
        // Relative seeks need no translation.
        return self.file.parent_reader.context.seekBy(offset);
    }

    pub fn getPos(self: TarStream) !u64 {
        return try self.file.parent_reader.context.getPos() - self.start;
    }

    pub fn getEndPos(self: TarStream) !u64 {
        return self.file.size;
    }

    pub fn seekableStream(self: TarStream) TarStream.SeekableStream {
        return .{ .context = self };
    }
};
test "Read pickle (zipped)" {
    // test file created with following python snippet:
    //
    // import torch
    // torch.manual_seed(0)
    // model = torch.nn.Conv2d(2, 2, 3, stride=2, padding=[2, 4], dtype=torch.float16)
    // tensor = torch.tensor([[2, 4, 3, 2]], dtype=torch.uint8)
    // torch.save({ "model": model, "tensor": tensor}, "simple.pt")
    const file = try asynk.File.open("zml/aio/torch/simple.pt", .{ .mode = .read_only });
    const mmap_file = try zml.aio.MemoryMappedFile.init(file);
    var store = try zml.aio.BufferStore.init(testing.allocator, &.{mmap_file});
    defer store.deinit();
    {
        // Parsing scratch memory lives in its own arena: the store must not
        // retain pointers into it once parsing is done.
        var tmp_arena = std.heap.ArenaAllocator.init(testing.allocator);
        defer tmp_arena.deinit();
        const tmp_alloc = tmp_arena.allocator();
        var torch_file = try File.init(tmp_alloc, mmap_file);
        // We don't close the file directly, it will be closed by the store.
        const ops = try torch_file.parsePickle(tmp_alloc);
        // Expected op count for this fixture (see `python -m pickletools`).
        try std.testing.expectEqual(302, ops.len);
        const py_values = try eval.evaluate(tmp_alloc, ops, true);
        try torch_file.parseModel(py_values, &store);
    }
    // now we have freed the tmp_arena.
    // all data needed should have been copied into the store arena.
    try zml.testing.expectEqualShapes(
        zml.Shape.init(.{ 1, 4 }, .u8),
        store.get("tensor").?.shape(),
    );
    try zml.testing.expectEqualShapes(
        zml.Shape.init(.{ 2, 2, 3, 3 }, .f16),
        store.get("model.weight").?.shape(),
    );
    try zml.testing.expectEqualShapes(
        zml.Shape.init(.{2}, .f16),
        store.get("model.bias").?.shape(),
    );
}
/// Reject archive member names that could escape the extraction root:
/// empty names, absolute paths, and any path containing a ".." component.
fn isBadFilename(filename: []const u8) bool {
    if (filename.len == 0) return true;
    if (filename[0] == '/') return true;
    var components = std.mem.splitScalar(u8, filename, '/');
    while (components.next()) |component| {
        if (std.mem.eql(u8, component, "..")) return true;
    }
    return false;
}

View File

@ -1,237 +0,0 @@
const asynk = @import("async");
const std = @import("std");
const testing = std.testing;
const Allocator = std.mem.Allocator;
const zml = @import("../../zml.zig");
const pickle = @import("pickle.zig");
test {
std.testing.refAllDecls(@This());
std.testing.refAllDecls(Parser);
}
pub const Parser = struct {
// TODO: move the file logic to torch.PytorchFile
// the Pickle parser shouldn't have to deal with the zip archive stuff used by Pytorch
buffer_file: zml.aio.MemoryMappedFile,
file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{},
tar_file: ?TarStream = null,
ops: []const pickle.Op,
is_zip_file: bool,
zip_prefix: []const u8 = &[_]u8{},
pub const FileEntry = struct {
version_needed_to_extract: u16,
flags: u16,
compression_method: std.zip.CompressionMethod,
last_modification_time: u16,
last_modification_date: u16,
header_zip_offset: u64,
crc32: u32,
filename_len: u32,
compressed_size: u64,
uncompressed_size: u64,
file_offset: u64,
pub fn init(entry: anytype) FileEntry {
return .{
.version_needed_to_extract = entry.version_needed_to_extract,
.flags = @as(u16, @bitCast(entry.flags)),
.compression_method = entry.compression_method,
.last_modification_time = entry.last_modification_time,
.last_modification_date = entry.last_modification_date,
.header_zip_offset = entry.header_zip_offset,
.crc32 = entry.crc32,
.filename_len = entry.filename_len,
.compressed_size = entry.compressed_size,
.uncompressed_size = entry.uncompressed_size,
.file_offset = entry.file_offset,
};
}
};
const magic = "PK\x03\x04";
pub fn fromTarFile(allocator: Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !Parser {
const tar_stream = try TarStream.init(file);
const file_magic = try tar_stream.reader().readBytesNoEof(magic.len);
try tar_stream.seekTo(0);
var self: Parser = .{
.buffer_file = mapped,
.tar_file = tar_stream,
.ops = undefined,
.is_zip_file = std.mem.eql(u8, &file_magic, magic),
};
if (!self.is_zip_file) {
const reader = tar_stream.reader();
self.ops = try pickle.parse(allocator, reader, try tar_stream.getEndPos());
} else {
self.ops = try self.parseOps(allocator, self.tar_file.?.seekableStream());
}
return self;
}
pub fn init(allocator: Allocator, file: asynk.File) !Parser {
const file_magic = try file.reader().readBytesNoEof(magic.len);
try file.seekTo(0);
var self: Parser = .{
.buffer_file = try zml.aio.MemoryMappedFile.init(file),
.is_zip_file = std.mem.eql(u8, &file_magic, magic),
.ops = undefined,
};
if (!self.is_zip_file) {
const reader = self.buffer_file.file.reader();
self.ops = try pickle.parse(allocator, reader, try reader.context.getEndPos());
} else {
self.ops = try self.parseOps(allocator, self.buffer_file.file.seekableStream());
}
return self;
}
pub fn deinit(self: *Parser) void {
self.buffer_file.deinit();
self.* = undefined;
}
fn parseOps(self: *Parser, allocator: Allocator, seekable_stream: anytype) ![]const pickle.Op {
var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream);
var filename_buf: [std.fs.max_path_bytes]u8 = undefined;
while (try iter.next()) |entry| {
const filename = filename_buf[0..entry.filename_len];
try seekable_stream.seekTo(entry.header_zip_offset + @sizeOf(std.zip.CentralDirectoryFileHeader));
const len = try seekable_stream.context.reader().readAll(filename);
if (len != filename.len) return error.ZipBadFileOffset;
if (isBadFilename(filename)) return error.ZipBadFilename;
std.mem.replaceScalar(u8, filename, '\\', '/'); // normalize path separators
try self.file_map.put(allocator, try allocator.dupe(u8, filename), FileEntry.init(entry));
}
var file_iter = self.file_map.iterator();
while (file_iter.next()) |e| {
const entry = e.value_ptr.*;
const filename = e.key_ptr.*;
if (std.mem.indexOf(u8, filename, "data.pkl")) |idx| {
self.zip_prefix = filename[0..idx];
const local_data_header_offset: u64 = local_data_header_offset: {
const local_header = blk: {
try seekable_stream.seekTo(entry.file_offset);
break :blk try seekable_stream.context.reader().readStructEndian(std.zip.LocalFileHeader, .little);
};
if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig))
return error.ZipBadFileOffset;
if (local_header.version_needed_to_extract != entry.version_needed_to_extract)
return error.ZipMismatchVersionNeeded;
if (local_header.last_modification_time != entry.last_modification_time)
return error.ZipMismatchModTime;
if (local_header.last_modification_date != entry.last_modification_date)
return error.ZipMismatchModDate;
if (@as(u16, @bitCast(local_header.flags)) != entry.flags)
return error.ZipMismatchFlags;
if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32)
return error.ZipMismatchCrc32;
if (local_header.compressed_size != 0 and
local_header.compressed_size != entry.compressed_size)
return error.ZipMismatchCompLen;
if (local_header.uncompressed_size != 0 and
local_header.uncompressed_size != entry.uncompressed_size)
return error.ZipMismatchUncompLen;
if (local_header.filename_len != entry.filename_len)
return error.ZipMismatchFilenameLen;
break :local_data_header_offset @as(u64, local_header.filename_len) +
@as(u64, local_header.extra_len);
};
const local_data_file_offset: u64 =
@as(u64, entry.file_offset) +
@as(u64, @sizeOf(std.zip.LocalFileHeader)) +
local_data_header_offset;
try seekable_stream.seekTo(local_data_file_offset);
switch (entry.compression_method) {
.store => {
return pickle.parse(allocator, seekable_stream.context.reader(), entry.uncompressed_size);
},
.deflate => {
// TODO(cryptodeal): handle decompress
@panic("TODO support use of `deflate`");
},
else => @panic("TODO support other modes of compression"),
}
}
}
std.log.err("Could not find file ending in `data.pkl` in archive", .{});
return error.PickleNotFound;
}
};
const TarStream = struct {
pub const SeekableStream = std.io.SeekableStream(
TarStream,
asynk.File.SeekError,
asynk.File.GetSeekPosError,
TarStream.seekTo,
TarStream.seekBy,
TarStream.getPos,
TarStream.getEndPos,
);
file: std.tar.Iterator(asynk.File.Reader).File,
start: usize,
pub fn init(file: std.tar.Iterator(asynk.File.Reader).File) !TarStream {
return .{
.file = file,
.start = try file.parent_reader.context.getPos(),
};
}
pub fn reader(file: TarStream) std.tar.Iterator(asynk.File.Reader).File.Reader {
return file.file.reader();
}
pub fn seekTo(self: TarStream, offset: u64) !void {
return self.file.parent_reader.context.seekTo(self.start + offset);
}
pub fn seekBy(self: TarStream, offset: i64) !void {
return self.file.parent_reader.context.seekBy(offset);
}
pub fn getPos(self: TarStream) !u64 {
return try self.file.parent_reader.context.getPos() - self.start;
}
pub fn getEndPos(self: TarStream) !u64 {
return self.file.size;
}
pub fn seekableStream(self: TarStream) TarStream.SeekableStream {
return .{ .context = self };
}
};
test "Read pickle (zipped)" {
var arena = std.heap.ArenaAllocator.init(testing.allocator);
defer arena.deinit();
const allocator = arena.allocator();
const file = try asynk.File.open("zml/aio/torch/simple.pt", .{ .mode = .read_only });
var data = try Parser.init(allocator, file);
defer data.deinit();
}
fn isBadFilename(filename: []const u8) bool {
if (filename.len == 0 or filename[0] == '/')
return true;
var it = std.mem.splitScalar(u8, filename, '/');
while (it.next()) |part| {
if (std.mem.eql(u8, part, ".."))
return true;
}
return false;
}

View File

@ -673,10 +673,10 @@ pub const OpCode = enum(u8) {
/// because operators having same semantics, but different encoding have been merged. /// because operators having same semantics, but different encoding have been merged.
/// ex: string, binstring, short_binstring -> string. /// ex: string, binstring, short_binstring -> string.
pub const Op = union(enum) { pub const Op = union(enum) {
// Initially numbers were represented by strings... int: i32,
int: []const u8, // Python can represent arbitrary long integers
binint: i32,
long: []const u8, long: []const u8,
binlong: []const u8,
string: []const u8, string: []const u8,
bytes: []const u8, bytes: []const u8,
bytearray: []u8, bytearray: []u8,
@ -767,26 +767,32 @@ pub fn parse(allocator: std.mem.Allocator, reader: anytype, max_line_len: usize)
var results = std.ArrayList(Op).init(allocator); var results = std.ArrayList(Op).init(allocator);
errdefer results.deinit(); errdefer results.deinit();
const len = max_line_len; const len = max_line_len;
var _buf: std.BoundedArray(u8, 12) = .{};
while (true) { while (true) {
const b = try reader.readByte(); const b = try reader.readByte();
const code: OpCode = @enumFromInt(b); const code: OpCode = @enumFromInt(b);
const op: Op = switch (code) { const op: Op = switch (code) {
.int => blk: { .int => blk: {
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len); _buf.len = 0;
try reader.streamUntilDelimiter(_buf.writer(), '\n', _buf.capacity() + 1);
const buf = _buf.constSlice();
// Legacy hack, see OpCode.int documentation // Legacy hack, see OpCode.int documentation
// We do this parsing right away to simplify downstream code. // We do this parsing right away to simplify downstream code.
if (std.mem.eql(u8, "00", buf)) break :blk .{ .bool = false }; break :blk if (std.mem.eql(u8, "00", buf))
if (std.mem.eql(u8, "01", buf)) break :blk .{ .bool = true }; .{ .bool = false }
break :blk .{ .int = buf }; else if (std.mem.eql(u8, "01", buf))
.{ .bool = true }
else
.{ .int = try std.fmt.parseInt(i32, buf, 10) };
}, },
.binint => .{ .binint = try reader.readInt(i32, .little) }, .binint => .{ .int = try reader.readInt(i32, .little) },
.binint1 => .{ .binint = try reader.readByte() }, .binint1 => .{ .int = try reader.readByte() },
.binint2 => .{ .binint = try reader.readInt(u16, .little) }, .binint2 => .{ .int = try reader.readInt(u16, .little) },
// TODO: long should handle the trailing 'L' -> add a test. // TODO: long should handle the trailing 'L' -> add a test.
.long => .{ .long = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, .long => .{ .long = try reader.readUntilDelimiterAlloc(allocator, '\n', len) },
.long1 => .{ .long = try _readSlice(reader, allocator, 1) }, .long1 => .{ .binlong = try _readSlice(reader, allocator, 1) },
.long4 => .{ .long = try _readSlice(reader, allocator, 4) }, .long4 => .{ .binlong = try _readSlice(reader, allocator, 4) },
.string => .{ .string = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, .string => .{ .string = try reader.readUntilDelimiterAlloc(allocator, '\n', len) },
.binstring => .{ .string = try _readSlice(reader, allocator, 4) }, .binstring => .{ .string = try _readSlice(reader, allocator, 4) },
.short_binstring => .{ .string = try _readSlice(reader, allocator, 1) }, .short_binstring => .{ .string = try _readSlice(reader, allocator, 1) },
@ -825,12 +831,9 @@ pub fn parse(allocator: std.mem.Allocator, reader: anytype, max_line_len: usize)
.dup => .dup, .dup => .dup,
.mark => .mark, .mark => .mark,
.pop_mark => .pop_mark, .pop_mark => .pop_mark,
.get => blk: {
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
defer allocator.free(buf);
// If we fail to parse delay the error to the evaluation. // If we fail to parse delay the error to the evaluation.
const n = std.fmt.parseInt(u32, buf, 10) catch std.math.maxInt(u32); .get => .{
break :blk .{ .get = n }; .get = _readDigits(u32, reader, &_buf) catch std.math.maxInt(u32),
}, },
.binget => .{ .get = try reader.readByte() }, .binget => .{ .get = try reader.readByte() },
.long_binget => .{ .get = try reader.readInt(u32, .little) }, .long_binget => .{ .get = try reader.readInt(u32, .little) },
@ -887,9 +890,9 @@ pub fn parse(allocator: std.mem.Allocator, reader: anytype, max_line_len: usize)
return results.toOwnedSlice(); return results.toOwnedSlice();
} }
test parse { test "parse protocol 4" {
const allocator = std.testing.allocator; const allocator = std.testing.allocator;
const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test.pickle", .{ .mode = .read_only }); const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only });
var buffered_reader = std.io.bufferedReader(file.reader()); var buffered_reader = std.io.bufferedReader(file.reader());
const ops = try parse(allocator, buffered_reader.reader(), 4096); const ops = try parse(allocator, buffered_reader.reader(), 4096);
defer { defer {
@ -898,11 +901,10 @@ test parse {
allocator.free(ops); allocator.free(ops);
} }
try std.testing.expect(ops.len == 35); // this can be obtained by running: `python -m pickletools simple_test_4.pickle`
// this can be obtained by running: `python -m pickletools simple_test.pickle` var expected = [_]Op{
const expected = [_]Op{
.{ .proto = 4 }, .{ .proto = 4 },
.{ .frame = 83 }, .{ .frame = 119 },
.empty_dict, .empty_dict,
.memoize, .memoize,
.mark, .mark,
@ -912,7 +914,7 @@ test parse {
.memoize, .memoize,
.{ .unicode = "int" }, .{ .unicode = "int" },
.memoize, .memoize,
.{ .binint = 1 }, .{ .int = 1 },
.{ .unicode = "float" }, .{ .unicode = "float" },
.memoize, .memoize,
.{ .binfloat = 3.141592 }, .{ .binfloat = 3.141592 },
@ -921,17 +923,21 @@ test parse {
.empty_list, .empty_list,
.memoize, .memoize,
.mark, .mark,
.{ .binint = 0 }, .{ .int = 255 },
.{ .binint = 1 }, .{ .int = 1234 },
.{ .binint = 2 }, .{ .int = -123 },
.{ .binint = 3 }, .{ .int = 1_000_000_000 },
.{ .binint = 4 }, .{ .binlong = &writeIntBuff(u48, 999_000_000_000) },
.{ .binlong = &writeIntBuff(u104, 999_000_000_000_000_000_000_000_000_000) },
.appends, .appends,
.{ .unicode = "bool" },
.memoize,
.{ .bool = false },
.{ .unicode = "tuple" }, .{ .unicode = "tuple" },
.memoize, .memoize,
.{ .unicode = "a" }, .{ .unicode = "a" },
.memoize, .memoize,
.{ .binint = 10 }, .{ .int = 10 },
.tuple2, .tuple2,
.memoize, .memoize,
.setitems, .setitems,
@ -940,6 +946,109 @@ test parse {
try std.testing.expectEqualDeep(&expected, ops); try std.testing.expectEqualDeep(&expected, ops);
} }
test "parse protocol 0" {
// We also test protocol 0, cause it's more text oriented.
const allocator = std.testing.allocator;
const pickle_0 =
\\(dp0
\\Vhello
\\p1
\\Vworld
\\p2
\\sVint
\\p3
\\I1
\\sVfloat
\\p4
\\F3.141592
\\sVlist
\\p5
\\(lp6
\\I255
\\aI1234
\\aI-123
\\aI1000000000
\\aL999000000000L
\\aL999000000000000000000000000000L
\\asVbool
\\p7
\\I00
\\sVtuple
\\p8
\\(Va
\\p9
\\I10
\\tp10
\\s.
;
var stream = std.io.fixedBufferStream(pickle_0);
const ops = try parse(allocator, stream.reader(), 4096);
defer {
// Test we are correctly freeing every allocation.
for (ops) |op| op.deinit(allocator);
allocator.free(ops);
}
var expected = [_]Op{
.mark,
.dict,
.{ .put = 0 },
.{ .unicode = "hello" },
.{ .put = 1 },
.{ .unicode = "world" },
.{ .put = 2 },
.setitem,
.{ .unicode = "int" },
.{ .put = 3 },
.{ .int = 1 },
.setitem,
.{ .unicode = "float" },
.{ .put = 4 },
.{ .float = "3.141592" },
.setitem,
.{ .unicode = "list" },
.{ .put = 5 },
.mark,
.list,
.{ .put = 6 },
.{ .int = 255 },
.append,
.{ .int = 1234 },
.append,
.{ .int = -123 },
.append,
.{ .int = 1_000_000_000 },
.append,
.{ .long = "999000000000L" },
.append,
.{ .long = "999000000000000000000000000000L" },
.append,
.setitem,
.{ .unicode = "bool" },
.{ .put = 7 },
.{ .bool = false },
.setitem,
.{ .unicode = "tuple" },
.{ .put = 8 },
.mark,
.{ .unicode = "a" },
.{ .put = 9 },
.{ .int = 10 },
.tuple,
.{ .put = 10 },
.setitem,
.stop,
};
try std.testing.expectEqualDeep(&expected, ops);
}
fn _readDigits(comptime T: type, reader: anytype, buffer: *std.BoundedArray(u8, 12)) !T {
buffer.len = 0;
try reader.streamUntilDelimiter(buffer.writer(), '\n', 13);
return std.fmt.parseInt(T, buffer.constSlice(), 10);
}
fn _readSlice(reader: anytype, allocator: std.mem.Allocator, comptime len_bytes: u8) ![]u8 { fn _readSlice(reader: anytype, allocator: std.mem.Allocator, comptime len_bytes: u8) ![]u8 {
const T = std.meta.Int(.unsigned, 8 * len_bytes); const T = std.meta.Int(.unsigned, 8 * len_bytes);
const str_len: u64 = try reader.readInt(T, .little); const str_len: u64 = try reader.readInt(T, .little);
@ -948,3 +1057,9 @@ fn _readSlice(reader: anytype, allocator: std.mem.Allocator, comptime len_bytes:
_ = try reader.read(buf); _ = try reader.read(buf);
return buf; return buf;
} }
fn writeIntBuff(comptime T: type, value: T) [@divExact(@typeInfo(T).Int.bits, 8)]u8 {
var res: [@divExact(@typeInfo(T).Int.bits, 8)]u8 = undefined;
std.mem.writeInt(T, &res, value, .little);
return res;
}

View File

@ -1,101 +1,111 @@
const std = @import("std"); const std = @import("std");
const big_int = std.math.big.int; const math = std.math;
const log = std.log.scoped(.zml_aio);
const pickle = @import("pickle.zig"); const pickle = @import("pickle.zig");
/// The types of sequences that exist. /// Correspond to a function/constructor call
pub const SequenceType = enum {
list,
dict,
kv_tuple,
tuple,
set,
frozen_set,
};
pub const Object = struct { pub const Object = struct {
allocator: std.mem.Allocator, member: Any,
member: Value, args: []Any,
args: []Value, kwargs: []Any,
pub fn init(allocator: std.mem.Allocator, member: Value, args: []Value) !*Object { pub fn init(allocator: std.mem.Allocator, member: Any, args: []Any, kwargs: []Any) !*Object {
const self = try allocator.create(Object); const self = try allocator.create(Object);
self.* = .{ .allocator = allocator, .member = member, .args = args }; self.* = .{ .member = member, .args = args, .kwargs = kwargs };
return self; return self;
} }
pub fn clone(self: *Object, allocator: std.mem.Allocator) std.mem.Allocator.Error!*Object { pub fn clone(self: *Object, allocator: std.mem.Allocator) std.mem.Allocator.Error!*Object {
const res = try allocator.create(Object); const res = try allocator.create(Object);
res.* = .{ .allocator = allocator, .member = try self.member.clone(allocator), .args = try allocator.alloc(Value, self.args.len) }; res.* = .{
.member = try self.member.clone(allocator),
.args = try allocator.alloc(Any, self.args.len),
.kwargs = try allocator.alloc(Any, self.kwargs.len),
};
for (self.args, 0..) |v, i| res.args[i] = try v.clone(allocator); for (self.args, 0..) |v, i| res.args[i] = try v.clone(allocator);
for (self.kwargs, 0..) |v, i| res.kwargs[i] = try v.clone(allocator);
return res; return res;
} }
pub fn deinit(self: *Object) void { pub fn deinit(self: *Object, allocator: std.mem.Allocator) void {
self.member.deinit(self.allocator); self.member.deinit(allocator);
for (self.args) |*v| v.deinit(self.allocator); for (self.args) |*v| v.deinit(allocator);
self.allocator.free(self.args); allocator.free(self.args);
allocator.destroy(self);
}
};
/// Correspond to the __set_state__ call when pickle finishes building an object.
pub const SetState = struct {
obj: Any,
state: Any,
pub fn init(allocator: std.mem.Allocator, obj: Any, state: Any) !*SetState {
const res = try allocator.create(SetState);
res.* = .{ .obj = obj, .state = state };
return res;
}
pub fn clone(self: *SetState, allocator: std.mem.Allocator) std.mem.Allocator.Error!*SetState {
const res = try allocator.create(SetState);
res.* = .{ .obj = try self.obj.clone(allocator), .state = try self.state.clone(allocator) };
return res;
}
pub fn deinit(self: *SetState, allocator: std.mem.Allocator) void {
self.obj.deinit(allocator);
self.state.deinit(allocator);
self.allocator.destroy(self); self.allocator.destroy(self);
} }
}; };
pub const Build = struct { /// The types of sequences that exist.
allocator: std.mem.Allocator, pub const SequenceType = enum {
member: Value, list,
args: Value, dict,
tuple,
pub fn init(allocator: std.mem.Allocator, member: Value, args: Value) !*Build { set,
const self = try allocator.create(Build); frozen_set,
self.* = .{ .allocator = allocator, .member = member, .args = args };
return self;
}
pub fn clone(self: *Build, allocator: std.mem.Allocator) std.mem.Allocator.Error!*Build {
const res = try allocator.create(Build);
res.* = .{ .allocator = allocator, .member = try self.member.clone(allocator), .args = try self.args.clone(allocator) };
return res;
}
pub fn deinit(self: *Build) void {
self.member.deinit(self.allocator);
self.args.deinit(self.allocator);
self.allocator.destroy(self);
}
}; };
pub const Sequence = struct { pub const Sequence = struct {
type: SequenceType, type: SequenceType,
values: []Value, values: []Any,
}; };
pub const PersId = struct { pub fn tuple(values: []const Any) Any {
allocator: std.mem.Allocator, // tuple are readonly, but sequence in general aren't
ref: Value, return .{ .seq = .{ .type = .tuple, .values = @constCast(values) } };
}
pub fn init(allocator: std.mem.Allocator, ref: Value) !*PersId { pub const PersId = struct {
ref: Any,
pub fn init(allocator: std.mem.Allocator, ref: Any) !*PersId {
const self = try allocator.create(PersId); const self = try allocator.create(PersId);
self.* = .{ .allocator = allocator, .ref = ref }; self.* = .{ .ref = ref };
return self; return self;
} }
pub fn clone(self: *PersId, allocator: std.mem.Allocator) std.mem.Allocator.Error!*PersId { pub fn clone(self: *PersId, allocator: std.mem.Allocator) std.mem.Allocator.Error!*PersId {
const res = try allocator.create(PersId); const res = try allocator.create(PersId);
res.* = .{ .allocator = allocator, .ref = try self.ref.clone(allocator) }; res.* = .{ .ref = try self.ref.clone(allocator) };
return res; return res;
} }
pub fn deinit(self: *PersId) void { pub fn deinit(self: *PersId, allocator: std.mem.Allocator) void {
self.ref.deinit(self.allocator); self.ref.deinit(allocator);
self.allocator.destroy(self); allocator.destroy(self);
} }
}; };
pub const ValueType = enum { pub const Kind = enum {
raw, raw,
ref, ref,
app, app,
object, object,
build, set_state,
pers_id, pers_id,
global, global,
seq, seq,
@ -110,7 +120,7 @@ pub const ValueType = enum {
}; };
/// A pickle operator that has been interpreted. /// A pickle operator that has been interpreted.
pub const Value = union(ValueType) { pub const Any = union(Kind) {
/// Types that we can't handle or just had to give up on processing. /// Types that we can't handle or just had to give up on processing.
raw: pickle.Op, raw: pickle.Op,
@ -128,9 +138,10 @@ pub const Value = union(ValueType) {
/// thing, the second one is the arguments it got applied to. /// thing, the second one is the arguments it got applied to.
object: *Object, object: *Object,
/// Something we tried to build. The first tuple member is the /// Correspond to the __set_state__ call when pickle finishes building an object.
/// thing, the second one is the arguments it got applied to. /// The first tuple member is the target object,
build: *Build, /// the second one is the "state" argument
set_state: *SetState,
/// References to persistant storage. They basically could be anything. /// References to persistant storage. They basically could be anything.
/// You kind of have to know what the thing you're trying to /// You kind of have to know what the thing you're trying to
@ -164,7 +175,7 @@ pub const Value = union(ValueType) {
int64: i64, int64: i64,
/// An integer that can't fit in i64. /// An integer that can't fit in i64.
bigint: big_int.Managed, bigint: math.big.int.Const,
/// An float, but not the crazy kind that comes as a string /// An float, but not the crazy kind that comes as a string
/// that has to be parsed. You can look in `Value.raw_num` for /// that has to be parsed. You can look in `Value.raw_num` for
@ -180,7 +191,7 @@ pub const Value = union(ValueType) {
/// Python `None`. /// Python `None`.
none: void, none: void,
pub fn deinit(self: *Value, allocator: std.mem.Allocator) void { pub fn deinit(self: *Any, allocator: std.mem.Allocator) void {
switch (self.*) { switch (self.*) {
.raw, .raw_num => |v| v.deinit(allocator), .raw, .raw_num => |v| v.deinit(allocator),
inline .app, .object, .global, .build, .pers_id => |v| v.deinit(), inline .app, .object, .global, .build, .pers_id => |v| v.deinit(),
@ -189,7 +200,7 @@ pub const Value = union(ValueType) {
allocator.free(v.values); allocator.free(v.values);
}, },
.string, .bytes => |v| allocator.free(v), .string, .bytes => |v| allocator.free(v),
.bigint => self.bigint.deinit(), .bigint => |big| allocator.free(big.limbs),
else => {}, else => {},
} }
self.* = undefined; self.* = undefined;
@ -200,7 +211,7 @@ pub const Value = union(ValueType) {
// try writer.writeByteNTimes('\t'); // try writer.writeByteNTimes('\t');
} }
fn internalFormat(value: Value, indents: usize, writer: anytype) !void { fn internalFormat(value: Any, indents: usize, writer: anytype) !void {
try writeIndents(indents, writer); try writeIndents(indents, writer);
try writer.writeAll(".{\n"); try writer.writeAll(".{\n");
try writeIndents(indents + 1, writer); try writeIndents(indents + 1, writer);
@ -209,9 +220,12 @@ pub const Value = union(ValueType) {
inline .ref, .int64, .float64 => |v| try writer.print("{d} ", .{v}), inline .ref, .int64, .float64 => |v| try writer.print("{d} ", .{v}),
.app, .object, .global => |v| { .app, .object, .global => |v| {
try writer.writeAll(".{\n"); try writer.writeAll(".{\n");
try writeIndents(indents + 2, writer);
try writer.writeAll(".fn =");
try internalFormat(v.member, indents + 2, writer); try internalFormat(v.member, indents + 2, writer);
try writer.writeAll(",\n"); try writer.writeAll(",\n");
try writeIndents(indents + 2, writer); try writeIndents(indents + 2, writer);
try writer.writeAll(".args = ");
if (v.args.len > 0) { if (v.args.len > 0) {
try writer.writeAll(".{\n"); try writer.writeAll(".{\n");
for (v.args, 0..) |arg, i| { for (v.args, 0..) |arg, i| {
@ -220,6 +234,20 @@ pub const Value = union(ValueType) {
try writer.writeByte('\n'); try writer.writeByte('\n');
} }
try writeIndents(indents + 2, writer); try writeIndents(indents + 2, writer);
try writer.writeAll("},\n");
} else {
try writer.writeAll(".{},\n");
}
try writeIndents(indents + 2, writer);
try writer.writeAll(".kwargs =");
if (v.kwargs.len > 0) {
try writer.writeAll(".{\n");
for (v.kwargs, 0..) |arg, i| {
try internalFormat(arg, indents + 3, writer);
if (i < v.kwargs.len - 1) try writer.writeAll(",");
try writer.writeByte('\n');
}
try writeIndents(indents + 2, writer);
try writer.writeAll("}\n"); try writer.writeAll("}\n");
} else { } else {
try writer.writeAll(".{}\n"); try writer.writeAll(".{}\n");
@ -227,16 +255,16 @@ pub const Value = union(ValueType) {
try writeIndents(indents + 1, writer); try writeIndents(indents + 1, writer);
try writer.writeAll("}"); try writer.writeAll("}");
}, },
.build => |v| { .set_state => |v| {
try writer.writeAll(".{\n"); try writer.writeAll(".{\n");
try internalFormat(v.member, indents + 2, writer); try internalFormat(v.obj, indents + 2, writer);
try writer.writeAll(",\n"); try writer.writeAll(",\n");
try internalFormat(v.args, indents + 2, writer); try internalFormat(v.state, indents + 2, writer);
try writer.writeAll(",\n"); try writer.writeAll(",\n");
try writeIndents(indents + 1, writer); try writeIndents(indents + 1, writer);
try writer.writeAll("}"); try writer.writeAll("}");
}, },
inline .pers_id => |v| { .pers_id => |v| {
try writer.writeByte('\n'); try writer.writeByte('\n');
try internalFormat(v.ref, indents + 2, writer); try internalFormat(v.ref, indents + 2, writer);
}, },
@ -275,26 +303,26 @@ pub const Value = union(ValueType) {
try writer.writeByte('}'); try writer.writeByte('}');
} }
pub fn format(self: Value, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { pub fn format(self: Any, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
return internalFormat(self, 0, writer); return internalFormat(self, 0, writer);
} }
pub fn clone(self: Value, allocator: std.mem.Allocator) !Value { pub fn clone(self: Any, allocator: std.mem.Allocator) !Any {
return switch (self) { return switch (self) {
inline .raw, .raw_num => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)), inline .raw, .raw_num => |v, tag| @unionInit(Any, @tagName(tag), try v.clone(allocator)),
inline .app, .object, .global, .build, .pers_id => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)), inline .app, .object, .global, .set_state, .pers_id => |v, tag| @unionInit(Any, @tagName(tag), try v.clone(allocator)),
.seq => |seq| { .seq => |seq| {
const values = try allocator.alloc(Value, seq.values.len); const values = try allocator.alloc(Any, seq.values.len);
for (seq.values, 0..) |v, i| values[i] = try v.clone(allocator); for (seq.values, 0..) |v, i| values[i] = try v.clone(allocator);
return .{ .seq = .{ .type = seq.type, .values = values } }; return .{ .seq = .{ .type = seq.type, .values = values } };
}, },
inline .string, .bytes => |v, tag| @unionInit(Value, @tagName(tag), try allocator.dupe(u8, v)), inline .string, .bytes => |v, tag| @unionInit(Any, @tagName(tag), try allocator.dupe(u8, v)),
.bigint => |v| .{ .bigint = try v.clone() }, .bigint => |v| .{ .bigint = (try v.toManaged(allocator)).toConst() },
else => self, else => self,
}; };
} }
pub fn isPrimitive(self: Value) bool { pub fn isPrimitive(self: Any) bool {
return switch (self) { return switch (self) {
.int64, .bigint, .float64, .string, .bytes, .boolval, .none => true, .int64, .bigint, .float64, .string, .bytes, .boolval, .none => true,
.seq => |seq| { .seq => |seq| {
@ -307,7 +335,7 @@ pub const Value = union(ValueType) {
}; };
} }
pub fn containsRef(self: Value) bool { pub fn containsRef(self: Any) bool {
switch (self) { switch (self) {
.ref => return true, .ref => return true,
.app, .object, .global => |v| { .app, .object, .global => |v| {
@ -315,9 +343,9 @@ pub const Value = union(ValueType) {
for (v.args) |arg| if (arg.containsRef()) return true; for (v.args) |arg| if (arg.containsRef()) return true;
return false; return false;
}, },
.build => |v| { .set_state => |v| {
if (v.member.containsRef()) return true; if (v.obj.containsRef()) return true;
if (v.args.containsRef()) return true; if (v.state.containsRef()) return true;
return false; return false;
}, },
.pers_id => |v| return v.ref.containsRef(), .pers_id => |v| return v.ref.containsRef(),
@ -329,44 +357,49 @@ pub const Value = union(ValueType) {
} }
} }
const BI64MIN = big_int.Const{ pub const UnpickleError = error{ InvalidCharacter, OutOfMemory };
.limbs = &.{@intCast(@abs(std.math.minInt(i64)))},
.positive = false,
};
const BI64MAX = big_int.Const{ pub fn coerceFromRaw(self: Any, allocator: std.mem.Allocator) UnpickleError!Any {
.limbs = &.{@intCast(std.math.maxInt(i64))},
.positive = true,
};
pub fn coerceFromRaw(self: Value, allocator: std.mem.Allocator) !Value {
return switch (self) { return switch (self) {
.raw => |raw_val| switch (raw_val) { .raw => |raw_val| switch (raw_val) {
.binint => |val| .{ .int64 = val }, .none => .none,
.long => |b| if (b.len != 0) { .bool => |b| .{ .boolval = b },
// TODO: handle trailing 'L' .float => |b| .{ .float64 = std.fmt.parseFloat(f64, b) catch std.math.nan(f64) },
var bint = try big_int.Managed.initCapacity(allocator, std.math.big.int.calcTwosCompLimbCount(b.len)); .int => |val| .{ .int64 = val },
var mutable = bint.toMutable(); .long => |digits| {
mutable.readTwosComplement(b, b.len, .little, .signed); const n = std.fmt.parseInt(i64, digits[0 .. digits.len - 1], 10) catch |err| {
const min_comp = bint.toConst().order(BI64MIN); switch (err) {
const max_comp = bint.toConst().order(BI64MAX); error.Overflow => {
if ((min_comp == .gt or min_comp == .eq) and (max_comp == .lt or max_comp == .eq)) { log.warn("Not parsing long integer: {s}", .{digits});
defer bint.deinit(); return self;
return .{ .int64 = try bint.to(i64) }; },
} else return .{ .bigint = bint }; error.InvalidCharacter => return error.InvalidCharacter,
} else .{ .raw_num = raw_val }, }
};
return .{ .int64 = n };
},
.binlong => |bytes| if (bytes.len <= 8)
.{ .int64 = std.mem.readVarInt(i64, bytes, .little) }
else {
// Note: we need to copy here, because Zig big int limbs are usize aligned,
// whereas pickle big int are byte aligned.
const n_limbs = std.math.divCeil(usize, bytes.len, @sizeOf(math.big.Limb)) catch unreachable;
var big = (try math.big.int.Managed.initCapacity(allocator, n_limbs)).toMutable();
big.readTwosComplement(bytes, bytes.len * 8, .little, .signed);
return .{ .bigint = big.toConst() };
},
.binfloat => |val| .{ .float64 = val }, .binfloat => |val| .{ .float64 = val },
.unicode => |s| .{ .string = s }, .unicode => |s| .{ .string = s },
.bytes => |b| .{ .bytes = b }, inline .bytes, .bytearray => |b| .{ .bytes = b },
// This isn't how Pickle actually works but we just try to UTF8 decode the // This isn't how Pickle actually works but we just try to UTF8 decode the
// string and if it fails, we make it a bytes value instead. If anyone // string and if it fails, we make it a bytes value instead. If anyone
// actually cares they can just fix values themselves or recover the raw bytes // actually cares they can just fix values themselves or recover the raw bytes
// from the UTF8 string (it's guaranteed to be reversible, as far as I know). // from the UTF8 string (it's guaranteed to be reversible, as far as I know).
.string => |b| if (std.unicode.utf8ValidateSlice(b)) .{ .string = b } else .{ .bytes = b }, .string => |b| if (std.unicode.utf8ValidateSlice(b))
.bool => |b| .{ .boolval = b }, .{ .string = b }
.none => .{ .none = {} }, else
// TODO .int should be handled like .long .{ .bytes = b },
.int, .float => .{ .raw_num = raw_val },
else => self, else => self,
}, },
.app, .object, .global => |v| blk: { .app, .object, .global => |v| blk: {
@ -376,9 +409,9 @@ pub const Value = union(ValueType) {
} }
break :blk self; break :blk self;
}, },
.build => |v| blk: { .set_state => |v| blk: {
v.member = try v.member.coerceFromRaw(allocator); v.obj = try v.obj.coerceFromRaw(allocator);
v.args = try v.args.coerceFromRaw(allocator); v.state = try v.state.coerceFromRaw(allocator);
break :blk self; break :blk self;
}, },
.pers_id => |v| blk: { .pers_id => |v| blk: {

Binary file not shown.

Binary file not shown.

Binary file not shown.