Rename and simplify modules in zml/aio/torch: replace redundant qualified names, remove generic utilities, inline code, reorder functions for top‑to‑bottom readability, and extract parsing logic into parseTensor and parseStorage functions.

This commit is contained in:
Tarry Singh 2023-04-04 17:20:53 +00:00
parent 66881899ca
commit e25f70d923
12 changed files with 261 additions and 296 deletions

View File

@ -1,6 +1,5 @@
const asynk = @import("async");
const std = @import("std");
const utils = @import("../utils.zig");
const zml = @import("../../zml.zig");
const assert = std.debug.assert;

View File

@ -1,6 +1,5 @@
const asynk = @import("async");
const std = @import("std");
const utils = @import("utils.zig");
const zml = @import("../zml.zig");
const StringBuilder = std.ArrayListUnmanaged(u8);

View File

@ -4,7 +4,7 @@ const std = @import("std");
const yaml = @import("zig-yaml");
const zml = @import("../zml.zig");
const Decoder = @import("torch/parser.zig").Decoder;
const parser = @import("torch/parser.zig");
const StringBuilder = std.ArrayListUnmanaged(u8);
@ -38,7 +38,7 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
} else if (std.mem.endsWith(u8, file.name, ".ckpt") or std.mem.endsWith(u8, file.name, ".pt")) {
const start = try mapped_file.file.getPos();
var tmp: zml.aio.torch.PickleData = .{
.data = try Decoder.fromTarFile(arena, mapped_file, file),
.data = try parser.Parser.fromTarFile(arena, mapped_file, file),
.memo = undefined,
.stack = undefined,
};

View File

@ -1,10 +1,8 @@
const asynk = @import("async");
const std = @import("std");
const zml = @import("../zml.zig");
const helpers = @import("../helpers.zig");
const utils = @import("utils.zig");
const json = @import("json.zig");
const HostBuffer = @import("../hostbuffer.zig").HostBuffer;
const HostBuffer = zml.HostBuffer;
const MemoryMappedFile = @import("../aio.zig").MemoryMappedFile;
const StringBuilder = std.ArrayListUnmanaged(u8);

View File

@ -4,262 +4,72 @@ const zml = @import("../zml.zig");
const HostBuffer = @import("../hostbuffer.zig").HostBuffer;
const toVoidSlice = @import("utils.zig").toVoidSlice;
const eval = @import("torch/eval.zig");
const utils = @import("torch/utils.zig");
const value = @import("torch/value.zig");
const Decoder = @import("torch/parser.zig").Decoder;
const parser = @import("torch/parser.zig");
const PersId = value.PersId;
const PickleMemo = eval.PickleMemo;
const PickleStack = eval.PickleStack;
const Sequence = value.Sequence;
const Value = value.Value;
const ValueType = value.ValueType;
const StringBuilder = std.ArrayListUnmanaged(u8);
const Allocator = std.mem.Allocator;
const log = std.log.scoped(.zml_io);
test {
std.testing.refAllDecls(eval);
std.testing.refAllDecls(value);
std.testing.refAllDecls(parser);
}
/// Opens and loads a BufferStore from the torch file at the given path.
pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
const file = asynk.File.open(path, .{}) catch |err| {
log.err("Failed to open {s}: {}", .{ path, err });
return err;
};
errdefer file.close() catch unreachable;
// Temporary memory needed to parse the pytorch file.
var arena = std.heap.ArenaAllocator.init(allocator);
defer arena.deinit();
const tmp_alloc = arena.allocator();
const _parser = try parser.Parser.init(tmp_alloc, file);
const stack, const memo = try eval.evaluate(tmp_alloc, _parser.ops, true);
// But we create the HostBuffer objects inside the result BufferStore arena.
var res: zml.aio.BufferStore = .{
.arena = std.heap.ArenaAllocator.init(allocator),
};
const arena = res.arena.allocator();
var tmp: PickleData = .{
.data = try Decoder.init(arena, file),
.memo = undefined,
.stack = undefined,
};
tmp.stack, tmp.memo = try eval.evaluate(arena, tmp.data.ops, true);
res.files = try arena.dupe(zml.aio.MemoryMappedFile, &.{tmp.data.buffer_file});
try tmp.parseModel(arena, &res);
res.files = try res.arena.allocator().dupe(zml.aio.MemoryMappedFile, &.{_parser.buffer_file});
var tmp: PickleData = .{ .data = _parser, .memo = memo, .stack = stack };
try tmp.parseModel(res.arena.allocator(), &res);
return res;
}
/// Convert from a torch.<type>Storage to a `zml.DataType`.
/// TODO: make this future proof, storage type are going to get replaced with torch.UntypedStorage
/// See https://pytorch.org/docs/stable/storage.html
fn storageToDtype(storage_type: []const u8) !zml.DataType {
const torch_type = storage_type[0 .. storage_type.len - "Storage".len];
const map = std.StaticStringMap(zml.DataType).initComptime(.{
.{ "Double", .f64 },
.{ "Float", .f32 },
.{ "Half", .f16 },
.{ "Long", .i64 },
.{ "Int", .i32 },
.{ "Short", .i16 },
.{ "Char", .i8 },
.{ "Byte", .u8 },
.{ "Bool", .bool },
.{ "BFloat16", .bf16 },
.{ "ComplexDouble", .c128 },
.{ "ComplexFloat", .c64 },
// QUInt8Storage
// QInt8Storage
// QInt32Storage
// QUInt4x2Storage
// QUInt2x4Storage
});
return map.get(torch_type) orelse {
log.err("Unsupported torch storage type: {s}", .{storage_type});
return error.UnsupportedDataType;
};
}
// TODO: rename me to PytorchFile
pub const PickleData = struct {
stack: PickleStack,
memo: PickleMemo,
data: Decoder,
stack: eval.PickleStack,
memo: eval.PickleMemo,
data: parser.Parser,
fn basicTypeCheck(v: Value, ns: []const u8, name: []const u8) bool {
return switch (v) {
.global => |object| switch (object.member) {
.raw => |raw| {
if (std.mem.eql(u8, ns, raw.global.module) and std.mem.eql(u8, name, raw.global.class) and object.args[0] == .seq) {
return true;
} else return false;
},
else => false,
},
/// Returns true when `object` is a pickled GLOBAL call to `module.class`
/// whose first (and only inspected) argument is a sequence.
fn basicTypeCheck(object: *const value.Object, module: []const u8, class: []const u8) bool {
    switch (object.member) {
        .raw => |raw| {
            // Guard the argument shape first, then match the qualified name.
            if (object.args[0] != .seq) return false;
            if (!std.mem.eql(u8, module, raw.global.module)) return false;
            return std.mem.eql(u8, class, raw.global.class);
        },
        else => return false,
    }
}
fn isTensor(v: Value) bool {
if (basicTypeCheck(v, "torch._utils", "_rebuild_tensor_v2")) {
const args = v.global.args[0].seq.values;
if (args.len >= 5 and
args[0] == .pers_id and
args[1] == .int64 and
args[2] == .seq and args[2].seq.type == .tuple and
args[3] == .seq and args[3].seq.type == .tuple)
{
return true;
} else @panic("Unexpected value in call to torch._utils._rebuild_tensor_v2");
}
return false;
}
fn dimsFromValues(values: []Value) [zml.Tensor.MAX_RANK]i64 {
std.debug.assert(values.len <= zml.Tensor.MAX_RANK);
var result: [zml.Tensor.MAX_RANK]i64 = undefined;
for (values, result[0..values.len]) |val, *elem| {
switch (val) {
.int64 => |int| elem.* = int,
else => @panic("Bad value for shape item"),
}
}
return result;
}
pub fn parseModel(self: *PickleData, allocator: Allocator, store: *zml.aio.BufferStore) !void {
pub fn parseModel(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore) !void {
for (self.stack.stack) |item| {
var prefix_buf: [1024]u8 = undefined;
try self.parseValue(allocator, store, StringBuilder.initBuffer(&prefix_buf), item);
}
}
fn tensorOffset(self: *PickleData, seekable_stream: anytype, sfile: []const u8) !u64 {
if (self.data.file_map.get(sfile)) |entry| {
const local_header = blk: {
try seekable_stream.seekTo(entry.file_offset);
break :blk try seekable_stream.context.reader().readStructEndian(std.zip.LocalFileHeader, .little);
};
if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig))
return error.ZipBadFileOffset;
if (local_header.version_needed_to_extract != entry.version_needed_to_extract)
return error.ZipMismatchVersionNeeded;
if (local_header.last_modification_time != entry.last_modification_time)
return error.ZipMismatchModTime;
if (local_header.last_modification_date != entry.last_modification_date)
return error.ZipMismatchModDate;
if (@as(u16, @bitCast(local_header.flags)) != entry.flags)
return error.ZipMismatchFlags;
if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32)
return error.ZipMismatchCrc32;
if (local_header.compressed_size != 0 and
local_header.compressed_size != entry.compressed_size)
return error.ZipMismatchCompLen;
if (local_header.uncompressed_size != 0 and
local_header.uncompressed_size != entry.uncompressed_size)
return error.ZipMismatchUncompLen;
if (local_header.filename_len != entry.filename_len)
return error.ZipMismatchFilenameLen;
return (try seekable_stream.context.getPos()) +
@as(u64, local_header.filename_len) +
@as(u64, local_header.extra_len);
}
std.log.err("Could not find file ending in `{s}` in archive", .{sfile});
return error.TensorNotFound;
}
fn parseTorchGlobal(self: *PickleData, allocator: Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: Value) !bool {
return switch (v) {
.global => |object| {
if (isTensor(v)) {
const args = object.args[0].seq.values;
const pidval: *PersId, var offs: u64, const raw_shape: Sequence, const raw_strides: Sequence = .{ args[0].pers_id, @intCast(args[1].int64), args[2].seq, args[3].seq };
const rank = raw_shape.values.len;
const shape = dimsFromValues(raw_shape.values);
var strides = dimsFromValues(raw_strides.values);
const storage_type, const sfile = switch (pidval.ref) {
.seq => |seq| blk: {
const sargs = seq.values;
if (seq.type == .tuple and
sargs.len >= 5 and
sargs[0] == .string and std.mem.eql(u8, sargs[0].string, "storage") and
sargs[1] == .raw and sargs[1].raw == .global and
sargs[2] == .string and
sargs[3] == .string)
{
const op = sargs[1].raw.global;
const sfile = sargs[2].string;
// const sdev = sargs[3].string;
if (std.mem.eql(u8, "torch", op.module) and std.mem.endsWith(u8, op.class, "Storage")) {
break :blk .{ op.class, sfile };
} else @panic("Unexpected storage type part of persistant ID");
} else @panic("Unexpected value for persistant ID");
},
else => @panic("Unexpected value for persistant ID"),
};
const data_type = try storageToDtype(storage_type);
for (strides[0..rank]) |*s| s.* *= data_type.sizeOf();
var sfile_buf = std.ArrayList(u8).init(allocator);
defer sfile_buf.deinit();
try sfile_buf.writer().print("{s}data/{s}", .{ self.data.zip_prefix, sfile });
// find offsets for tensor zip file
const absolute_offset = blk: {
if (self.data.tar_file) |t| {
break :blk try self.tensorOffset(t.seekableStream(), sfile_buf.items);
} else {
break :blk try self.tensorOffset(self.data.buffer_file.file.seekableStream(), sfile_buf.items);
}
};
offs = offs * data_type.sizeOf();
const key = try allocator.dupe(u8, prefix.items);
const entry = try store.buffers.getOrPut(allocator, key);
if (entry.found_existing) {
log.warn("Duplicate key: {s}", .{prefix.items});
allocator.free(key);
}
const out_shape = zml.Shape.init(shape[0..rank], data_type);
entry.value_ptr.* = HostBuffer.fromStridedSlice(
out_shape,
self.data.buffer_file.mappedSlice((if (self.data.tar_file) |t| t.start else 0) + absolute_offset + offs, out_shape.byteSize()),
strides[0..rank],
);
return true;
} else if (basicTypeCheck(v, "torch", "Size")) {
const size = object.args[0].seq.values[0].seq.values;
const key = try allocator.dupe(u8, prefix.items);
const entry = try store._metadata.getOrPut(allocator, key);
if (entry.found_existing) {
log.warn("Duplicate key: {s}", .{prefix.items});
allocator.free(key);
}
const d = try allocator.alloc(i64, size.len);
for (d, 0..) |*di, i| di.* = size[i].int64;
entry.value_ptr.* = .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(d) } };
return true;
} else if (basicTypeCheck(v, "fractions", "Fraction")) {
const fraction_str = object.args[0].seq.values[0].string;
if (std.mem.indexOfScalar(u8, fraction_str, '/')) |split_idx| {
{
var new_prefix = prefix;
new_prefix.appendSliceAssumeCapacity(".numerator");
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) });
}
{
var new_prefix = prefix;
new_prefix.appendSliceAssumeCapacity(".denominator");
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) });
}
return true;
}
}
return false;
},
else => false,
};
}
pub fn parseValue(self: *PickleData, allocator: Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: Value) !void {
pub fn parseValue(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: Value) !void {
switch (v) {
.app, .object, .global => |object| {
if (!(try self.parseTorchGlobal(allocator, store, prefix, v))) {
@ -415,8 +225,194 @@ pub const PickleData = struct {
else => {},
}
}
/// Handles the torch-specific GLOBAL values found in the pickle stream:
/// rebuilt tensors, `torch.Size` and `fractions.Fraction`.
/// Returns true when `v` was recognized and stored into `store`
/// (as a buffer or metadata entry), false when the caller should
/// keep walking the value itself.
fn parseTorchGlobal(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: Value) !bool {
    return switch (v) {
        .global => |object| {
            if (try self.parseTensor(allocator, object)) |host_buffer| {
                // Register the tensor under the dotted path accumulated in `prefix`.
                const key = try allocator.dupe(u8, prefix.items);
                const entry = try store.buffers.getOrPut(allocator, key);
                if (entry.found_existing) {
                    // Keep the previous key allocation, drop ours; last value wins.
                    log.warn("Duplicate key: {s}", .{prefix.items});
                    allocator.free(key);
                }
                entry.value_ptr.* = host_buffer;
                return true;
            } else if (basicTypeCheck(object, "torch", "Size")) {
                // torch.Size((d0, d1, ...)) -> stored as an int64 metadata array.
                const size = object.args[0].seq.values[0].seq.values;
                const key = try allocator.dupe(u8, prefix.items);
                const entry = try store._metadata.getOrPut(allocator, key);
                if (entry.found_existing) {
                    log.warn("Duplicate key: {s}", .{prefix.items});
                    allocator.free(key);
                }
                const d = try allocator.alloc(i64, size.len);
                // NOTE(review): assumes every element is already an .int64 — TODO confirm
                // malformed sizes can't reach this point.
                for (d, 0..) |*di, i| di.* = size[i].int64;
                entry.value_ptr.* = .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(d) } };
                return true;
            } else if (basicTypeCheck(object, "fractions", "Fraction")) {
                // fractions.Fraction("num/den") -> two int64 metadata entries,
                // "<prefix>.numerator" and "<prefix>.denominator".
                const fraction_str = object.args[0].seq.values[0].string;
                if (std.mem.indexOfScalar(u8, fraction_str, '/')) |split_idx| {
                    {
                        // `prefix` is copied by value; the append below only mutates our copy.
                        var new_prefix = prefix;
                        new_prefix.appendSliceAssumeCapacity(".numerator");
                        try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) });
                    }
                    {
                        var new_prefix = prefix;
                        new_prefix.appendSliceAssumeCapacity(".denominator");
                        try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) });
                    }
                    return true;
                }
            }
            return false;
        },
        else => false,
    };
}
/// Decodes a `torch._utils._rebuild_tensor_v2` call into a HostBuffer
/// viewing the memory-mapped storage bytes (no copy).
/// Returns null when `object` is not a rebuilt tensor,
/// error.InvalidInput when it is one but malformed.
/// `tmp_allocator` is only used for a scratch filename and freed before returning.
fn parseTensor(self: *PickleData, tmp_allocator: std.mem.Allocator, object: *value.Object) !?zml.HostBuffer {
    if (!basicTypeCheck(object, "torch._utils", "_rebuild_tensor_v2")) {
        return null;
    }
    // Expected args: (storage pers_id, storage_offset, size tuple, stride tuple, ...).
    const args = object.args[0].seq.values;
    if (args.len < 4 or
        args[0] != .pers_id or
        args[1] != .int64 or
        args[2] != .seq or args[2].seq.type != .tuple or
        args[3] != .seq or args[3].seq.type != .tuple)
    {
        log.err("Unexpected value in call to torch._utils._rebuild_tensor_v2", .{});
        return error.InvalidInput;
    }
    const pid: *PersId = args[0].pers_id;
    var offset: u64 = @intCast(args[1].int64);
    const raw_dims: Sequence = args[2].seq;
    const raw_strides: Sequence = args[3].seq;
    const dims = try parseDims(raw_dims.values);
    var strides = try parseDims(raw_strides.values);
    const dtype, const storage_file = try parseStorage(pid.ref);
    // Pytorch store "item" strides, while ZML uses byte strides.
    for (strides.slice()) |*s| s.* *= dtype.sizeOf();
    // Same thing for the offset.
    offset = offset * dtype.sizeOf();
    // Storage files live under "<zip_prefix>data/" inside the archive.
    const filename = try std.mem.join(tmp_allocator, "", &.{ self.data.zip_prefix, "data/", storage_file });
    defer tmp_allocator.free(filename);
    // The offset in the pickle is the offset inside the storage_file.
    // But .pt are made of several files, so we need to append the file offset.
    const storage = try self.getStorage(filename);
    return HostBuffer.fromStridedSlice(
        zml.Shape.init(dims.constSlice(), dtype),
        storage[offset..],
        strides.constSlice(),
    );
}
/// Decodes the persistent-id tuple describing a tensor storage:
/// ("storage", torch.<T>Storage, storage_file, device, ...).
/// Returns the element dtype and the archive file name holding the bytes,
/// or error.InvalidInput when the tuple doesn't match that shape.
fn parseStorage(val: value.Value) !struct { zml.DataType, []const u8 } {
    if (val != .seq) return error.InvalidInput;
    const seq = val.seq;
    if (seq.type != .tuple) return error.InvalidInput;
    const items = seq.values;
    if (items.len < 5) return error.InvalidInput;
    // Check each positional element before touching its payload.
    if (items[0] != .string or !std.mem.eql(u8, items[0].string, "storage")) return error.InvalidInput;
    if (items[1] != .raw or items[1].raw != .global) return error.InvalidInput;
    if (items[2] != .string) return error.InvalidInput;
    if (items[3] != .string) return error.InvalidInput;

    const global = items[1].raw.global;
    const storage_file = items[2].string;
    // items[3] is the device string; unused here.
    if (!std.mem.eql(u8, "torch", global.module) or
        !std.mem.endsWith(u8, global.class, "Storage"))
        return error.InvalidInput;
    return .{ try storageToDtype(global.class), storage_file };
}
/// Given the name of one of the files in the .pt tarball,
/// return the slice of the memory-mapped .pt corresponding to it.
/// Validates the zip local header against the central-directory entry
/// before trusting its filename/extra lengths.
fn getStorage(self: *PickleData, filename: []const u8) ![]const u8 {
    const entry = self.data.file_map.get(filename) orelse {
        std.log.err("Could not find file ending in `{s}` in archive", .{filename});
        return error.TensorNotFound;
    };
    // When the zip archive is itself nested inside a tarball, every zip
    // offset is relative to the start of the tar entry.
    const base_offset: u64 = if (self.data.tar_file) |t| t.start else 0;
    const file_offset: u64 = base_offset + entry.file_offset;
    const file = self.data.buffer_file.file;
    // Bug fix: seek to the absolute local-header position. Seeking to
    // `entry.file_offset` alone ignored `base_offset` and read the header
    // from the wrong place for .pt files nested in a tar archive.
    try file.seekTo(file_offset);
    const local_header = try file.reader().readStructEndian(std.zip.LocalFileHeader, .little);
    if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig))
        return error.ZipBadFileOffset;
    // Zero-valued header fields mean "see central directory"; only compare
    // when the local header actually carries a value.
    if (local_header.compressed_size != 0 and
        local_header.compressed_size != entry.compressed_size)
        return error.ZipMismatchCompLen;
    if (local_header.uncompressed_size != 0 and
        local_header.uncompressed_size != entry.uncompressed_size)
        return error.ZipMismatchUncompLen;
    if (local_header.filename_len != entry.filename_len)
        return error.ZipMismatchFilenameLen;
    // The payload starts right after the fixed-size header, the filename
    // and the extra field.
    const start = file_offset +
        @sizeOf(std.zip.LocalFileHeader) +
        @as(u64, local_header.filename_len) +
        @as(u64, local_header.extra_len);
    return self.data.buffer_file.mappedSlice(start, entry.uncompressed_size);
}
/// Converts a pickled tuple of integers into a fixed-capacity dims array.
/// Asserts the rank fits in zml.Tensor.MAX_RANK; any non-int64 element
/// yields error.InvalidInput.
fn parseDims(values: []Value) error{InvalidInput}!zml.Shape.DimsArray {
    zml.meta.assert(values.len <= zml.Tensor.MAX_RANK, "Found Pytorch tensor with unsupported rank {}", .{values.len});
    var dims: zml.Shape.DimsArray = .{};
    for (values) |item| {
        if (item != .int64) return error.InvalidInput;
        dims.appendAssumeCapacity(item.int64);
    }
    return dims;
}
};
test {
std.testing.refAllDecls(@This());
/// Convert from a torch.<type>Storage to a `zml.DataType`.
/// TODO: make this future proof, storage type are going to get replaced with torch.UntypedStorage
/// See https://pytorch.org/docs/stable/storage.html
fn storageToDtype(storage_type: []const u8) !zml.DataType {
    // Strip the trailing "Storage" suffix to recover the torch scalar name.
    const suffix = "Storage";
    const torch_name = storage_type[0 .. storage_type.len - suffix.len];
    const dtypes = std.StaticStringMap(zml.DataType).initComptime(.{
        .{ "Double", .f64 },
        .{ "Float", .f32 },
        .{ "Half", .f16 },
        .{ "Long", .i64 },
        .{ "Int", .i32 },
        .{ "Short", .i16 },
        .{ "Char", .i8 },
        .{ "Byte", .u8 },
        .{ "Bool", .bool },
        .{ "BFloat16", .bf16 },
        .{ "ComplexDouble", .c128 },
        .{ "ComplexFloat", .c64 },
        // QUInt8Storage
        // QInt8Storage
        // QInt32Storage
        // QUInt4x2Storage
        // QUInt2x4Storage
    });
    if (dtypes.get(torch_name)) |dtype| return dtype;
    log.err("Unsupported torch storage type: {s}", .{storage_type});
    return error.UnsupportedDataType;
}

View File

@ -3,8 +3,8 @@ const zml = @import("../../zml.zig");
const meta = zml.meta;
const value = @import("value.zig");
const pickle = @import("pickle.zig");
const BTreeMap = @import("b_tree_map.zig").BTreeMap;
const PickleOp = @import("ops.zig").PickleOp;
const Build = value.Build;
const Object = value.Object;
@ -233,7 +233,7 @@ pub const PickleStack = struct {
}
};
pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs: bool) !struct { PickleStack, PickleMemo } {
pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) !struct { PickleStack, PickleMemo } {
var stack = InternalStack.init(allocator);
defer stack.deinit();
var memo = PickleMemo.init(allocator);

View File

@ -1,23 +1,23 @@
const asynk = @import("async");
const std = @import("std");
const zml = @import("../../zml.zig");
const utils = @import("utils.zig");
const PickleOp = @import("ops.zig").PickleOp;
const RawPickleOp = @import("ops.zig").RawPickleOp;
const Allocator = std.mem.Allocator;
const testing = std.testing;
const Allocator = std.mem.Allocator;
const zml = @import("../../zml.zig");
const pickle = @import("pickle.zig");
test {
std.testing.refAllDecls(@This());
std.testing.refAllDecls(Parser);
}
pub const Decoder = struct {
pub const Parser = struct {
// TODO: move the file logic to torch.PytorchFile
// the Pickle parser shouldn't have to deal with the zip archive stuff used by Pytorch
buffer_file: zml.aio.MemoryMappedFile,
file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{},
tar_file: ?TarStream = null,
ops: []PickleOp,
ops: []pickle.Op,
is_zip_file: bool,
zip_prefix: []const u8 = &[_]u8{},
@ -53,11 +53,11 @@ pub const Decoder = struct {
const magic = "PK\x03\x04";
pub fn fromTarFile(allocator: Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !Decoder {
pub fn fromTarFile(allocator: Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !Parser {
const tar_stream = try TarStream.init(file);
const file_magic = try tar_stream.reader().readBytesNoEof(magic.len);
try tar_stream.seekTo(0);
var self: Decoder = .{
var self: Parser = .{
.buffer_file = mapped,
.tar_file = tar_stream,
.ops = undefined,
@ -72,10 +72,10 @@ pub const Decoder = struct {
return self;
}
pub fn init(allocator: Allocator, file: asynk.File) !Decoder {
pub fn init(allocator: Allocator, file: asynk.File) !Parser {
const file_magic = try file.reader().readBytesNoEof(magic.len);
try file.seekTo(0);
var self: Decoder = .{
var self: Parser = .{
.buffer_file = try zml.aio.MemoryMappedFile.init(file),
.is_zip_file = std.mem.eql(u8, &file_magic, magic),
.ops = undefined,
@ -89,12 +89,12 @@ pub const Decoder = struct {
return self;
}
pub fn deinit(self: *Decoder) void {
pub fn deinit(self: *Parser) void {
self.buffer_file.deinit();
self.* = undefined;
}
fn parseOps(self: *Decoder, allocator: Allocator, seekable_stream: anytype) ![]PickleOp {
fn parseOps(self: *Parser, allocator: Allocator, seekable_stream: anytype) ![]pickle.Op {
var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream);
var filename_buf: [std.fs.max_path_bytes]u8 = undefined;
while (try iter.next()) |entry| {
@ -167,12 +167,12 @@ pub const Decoder = struct {
return error.PickleNotFound;
}
fn parse(allocator: Allocator, reader: anytype, len: usize) ![]PickleOp {
var results = std.ArrayList(PickleOp).init(allocator);
fn parse(allocator: Allocator, reader: anytype, len: usize) ![]pickle.Op {
var results = std.ArrayList(pickle.Op).init(allocator);
errdefer results.deinit();
outer: while (true) {
const b = try reader.readByte();
switch (@as(RawPickleOp, @enumFromInt(b))) {
switch (@as(pickle.OpCode, @enumFromInt(b))) {
.mark => try results.append(.{ .mark = {} }),
.stop => {
try results.append(.{ .stop = {} });
@ -351,7 +351,7 @@ pub const Decoder = struct {
},
.next_buffer => try results.append(.{ .next_buffer = {} }),
.readonly_buffer => try results.append(.{ .readonly_buffer = {} }),
else => {},
_ => {},
}
}
return results.toOwnedSlice();
@ -411,7 +411,7 @@ test "Read pickle (simple)" {
const allocator = arena.allocator();
const eval = @import("eval.zig");
const file = try asynk.File.open("zml/aio/torch/simple_test.pickle", .{ .mode = .read_only });
var data = try Decoder.init(allocator, file);
var data = try Parser.init(allocator, file);
defer data.deinit();
var vals, var memo = try eval.evaluate(allocator, data.ops, true);
defer vals.deinit();
@ -457,11 +457,11 @@ test "Read pickle (zipped)" {
defer arena.deinit();
const allocator = arena.allocator();
const file = try asynk.File.open("zml/aio/torch/simple.pt", .{ .mode = .read_only });
var data = try Decoder.init(allocator, file);
var data = try Parser.init(allocator, file);
defer data.deinit();
}
pub fn isBadFilename(filename: []const u8) bool {
fn isBadFilename(filename: []const u8) bool {
if (filename.len == 0 or filename[0] == '/')
return true;

View File

@ -1,7 +1,7 @@
const std = @import("std");
/// A decoded Pickle operation in its natural state.
pub const PickleOp = union(RawPickleOp) {
pub const Op = union(OpCode) {
mark,
stop,
pop,
@ -73,7 +73,7 @@ pub const PickleOp = union(RawPickleOp) {
pub const PyType = struct { module: []const u8, class: []const u8 };
pub fn deinit(self: PickleOp, allocator: std.mem.Allocator) void {
pub fn deinit(self: Op, allocator: std.mem.Allocator) void {
switch (self) {
.float,
.int,
@ -103,7 +103,7 @@ pub const PickleOp = union(RawPickleOp) {
}
}
pub fn clone(self: PickleOp, allocator: std.mem.Allocator) !PickleOp {
pub fn clone(self: Op, allocator: std.mem.Allocator) !Op {
var res = self;
return switch (self) {
inline .float,
@ -144,7 +144,8 @@ pub const PickleOp = union(RawPickleOp) {
};
/// The values for the possible opcodes are in this enum.
pub const RawPickleOp = enum(u8) {
/// Reference: https://github.com/python/cpython/blob/3.13/Lib/pickletools.py
pub const OpCode = enum(u8) {
mark = '(', // push special markobject on stack
stop = '.', // every pickle ends with stop
pop = '0', // discard topmost stack item

View File

@ -1,23 +0,0 @@
const std = @import("std");
const Value = @import("value.zig").Value;
/// Returns true when `func` returns true for every element of `values`.
/// Vacuously true for an empty slice.
pub fn allTrue(values: []const Value, func: fn (v: Value) bool) bool {
    for (values) |v| {
        if (!func(v)) return false;
    }
    return true;
}
/// Rejects archive member names that could escape the extraction root:
/// empty names, absolute paths, and any path containing a ".." segment.
pub fn isBadFilename(filename: []const u8) bool {
    if (filename.len == 0) return true;
    if (filename[0] == '/') return true;
    var segments = std.mem.splitScalar(u8, filename, '/');
    while (segments.next()) |segment| {
        if (std.mem.eql(u8, segment, "..")) return true;
    }
    return false;
}

View File

@ -1,10 +1,8 @@
const std = @import("std");
const utils = @import("utils.zig");
const PickleOp = @import("ops.zig").PickleOp;
const big_int = std.math.big.int;
const pickle = @import("pickle.zig");
/// The types of sequences that exist.
pub const SequenceType = enum {
list,
@ -114,7 +112,7 @@ pub const ValueType = enum {
/// A processed value.
pub const Value = union(ValueType) {
/// Types that we can't handle or just had to give up on processing.
raw: PickleOp,
raw: pickle.Op,
/// A reference. You might be able to look it up in the memo map
/// unless there's something weird going on like recursive references.
@ -174,7 +172,7 @@ pub const Value = union(ValueType) {
float64: f64,
/// Some kind of weird number we can't handle.
raw_num: PickleOp,
raw_num: pickle.Op,
/// A boolean value.
boolval: bool,
@ -299,7 +297,12 @@ pub const Value = union(ValueType) {
pub fn isPrimitive(self: Value) bool {
return switch (self) {
.int64, .bigint, .float64, .string, .bytes, .boolval, .none => true,
.seq => |seq| utils.allTrue(seq.values, Value.isPrimitive),
.seq => |seq| {
for (seq.values) |v| {
if (!v.isPrimitive()) return false;
}
return true;
},
else => false,
};
}

View File

@ -1,7 +0,0 @@
/// Reinterprets any slice as a `[]void` of the same length,
/// erasing the element type while keeping the pointer and count.
/// Compile error when `data` is not a slice.
pub fn toVoidSlice(data: anytype) []void {
    const info = @typeInfo(@TypeOf(data));
    if (info != .Pointer or info.Pointer.size != .Slice) {
        @compileError("toVoidSlice expects a slice");
    }
    // Length is preserved in elements, not bytes.
    return @as([*]void, @ptrCast(@alignCast(data.ptr)))[0..data.len];
}

View File

@ -1,5 +1,4 @@
const std = @import("std");
const utils = @import("utils.zig");
const yaml = @import("zig-yaml");
const zml = @import("../zml.zig");