Rename and simplify modules in zml/aio/torch: replace redundant qualified names, remove generic utilities, inline code, reorder functions for top‑to‑bottom readability, and extract parsing logic into parseTensor and parseStorage functions.
This commit is contained in:
parent
66881899ca
commit
e25f70d923
@ -1,6 +1,5 @@
|
||||
const asynk = @import("async");
|
||||
const std = @import("std");
|
||||
const utils = @import("../utils.zig");
|
||||
const zml = @import("../../zml.zig");
|
||||
|
||||
const assert = std.debug.assert;
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
const asynk = @import("async");
|
||||
const std = @import("std");
|
||||
const utils = @import("utils.zig");
|
||||
const zml = @import("../zml.zig");
|
||||
|
||||
const StringBuilder = std.ArrayListUnmanaged(u8);
|
||||
|
||||
@ -4,7 +4,7 @@ const std = @import("std");
|
||||
const yaml = @import("zig-yaml");
|
||||
const zml = @import("../zml.zig");
|
||||
|
||||
const Decoder = @import("torch/parser.zig").Decoder;
|
||||
const parser = @import("torch/parser.zig");
|
||||
|
||||
const StringBuilder = std.ArrayListUnmanaged(u8);
|
||||
|
||||
@ -38,7 +38,7 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
|
||||
} else if (std.mem.endsWith(u8, file.name, ".ckpt") or std.mem.endsWith(u8, file.name, ".pt")) {
|
||||
const start = try mapped_file.file.getPos();
|
||||
var tmp: zml.aio.torch.PickleData = .{
|
||||
.data = try Decoder.fromTarFile(arena, mapped_file, file),
|
||||
.data = try parser.Parser.fromTarFile(arena, mapped_file, file),
|
||||
.memo = undefined,
|
||||
.stack = undefined,
|
||||
};
|
||||
|
||||
@ -1,10 +1,8 @@
|
||||
const asynk = @import("async");
|
||||
const std = @import("std");
|
||||
const zml = @import("../zml.zig");
|
||||
const helpers = @import("../helpers.zig");
|
||||
const utils = @import("utils.zig");
|
||||
const json = @import("json.zig");
|
||||
const HostBuffer = @import("../hostbuffer.zig").HostBuffer;
|
||||
const HostBuffer = zml.HostBuffer;
|
||||
const MemoryMappedFile = @import("../aio.zig").MemoryMappedFile;
|
||||
|
||||
const StringBuilder = std.ArrayListUnmanaged(u8);
|
||||
|
||||
@ -4,262 +4,72 @@ const zml = @import("../zml.zig");
|
||||
|
||||
const HostBuffer = @import("../hostbuffer.zig").HostBuffer;
|
||||
|
||||
const toVoidSlice = @import("utils.zig").toVoidSlice;
|
||||
const eval = @import("torch/eval.zig");
|
||||
const utils = @import("torch/utils.zig");
|
||||
const value = @import("torch/value.zig");
|
||||
const Decoder = @import("torch/parser.zig").Decoder;
|
||||
const parser = @import("torch/parser.zig");
|
||||
const PersId = value.PersId;
|
||||
const PickleMemo = eval.PickleMemo;
|
||||
const PickleStack = eval.PickleStack;
|
||||
const Sequence = value.Sequence;
|
||||
const Value = value.Value;
|
||||
const ValueType = value.ValueType;
|
||||
|
||||
const StringBuilder = std.ArrayListUnmanaged(u8);
|
||||
const Allocator = std.mem.Allocator;
|
||||
const log = std.log.scoped(.zml_io);
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(eval);
|
||||
std.testing.refAllDecls(value);
|
||||
std.testing.refAllDecls(parser);
|
||||
}
|
||||
|
||||
/// Opens and loads a BufferStore from the torch file at the given path.
|
||||
pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
|
||||
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
|
||||
const file = asynk.File.open(path, .{}) catch |err| {
|
||||
log.err("Failed to open {s}: {}", .{ path, err });
|
||||
return err;
|
||||
};
|
||||
errdefer file.close() catch unreachable;
|
||||
|
||||
// Temporary memory needed to parse the pytorch file.
|
||||
var arena = std.heap.ArenaAllocator.init(allocator);
|
||||
defer arena.deinit();
|
||||
const tmp_alloc = arena.allocator();
|
||||
|
||||
const _parser = try parser.Parser.init(tmp_alloc, file);
|
||||
const stack, const memo = try eval.evaluate(tmp_alloc, _parser.ops, true);
|
||||
|
||||
// But we create the HostBuffer objects inside the result BufferStore arena.
|
||||
var res: zml.aio.BufferStore = .{
|
||||
.arena = std.heap.ArenaAllocator.init(allocator),
|
||||
};
|
||||
|
||||
const arena = res.arena.allocator();
|
||||
|
||||
var tmp: PickleData = .{
|
||||
.data = try Decoder.init(arena, file),
|
||||
.memo = undefined,
|
||||
.stack = undefined,
|
||||
};
|
||||
tmp.stack, tmp.memo = try eval.evaluate(arena, tmp.data.ops, true);
|
||||
res.files = try arena.dupe(zml.aio.MemoryMappedFile, &.{tmp.data.buffer_file});
|
||||
try tmp.parseModel(arena, &res);
|
||||
res.files = try res.arena.allocator().dupe(zml.aio.MemoryMappedFile, &.{_parser.buffer_file});
|
||||
var tmp: PickleData = .{ .data = _parser, .memo = memo, .stack = stack };
|
||||
try tmp.parseModel(res.arena.allocator(), &res);
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Convert from a torch.<type>Storage to a `zml.DataType`.
|
||||
/// TODO: make this future proof, storage type are going to get replaced with torch.UntypedStorage
|
||||
/// See https://pytorch.org/docs/stable/storage.html
|
||||
fn storageToDtype(storage_type: []const u8) !zml.DataType {
|
||||
const torch_type = storage_type[0 .. storage_type.len - "Storage".len];
|
||||
const map = std.StaticStringMap(zml.DataType).initComptime(.{
|
||||
.{ "Double", .f64 },
|
||||
.{ "Float", .f32 },
|
||||
.{ "Half", .f16 },
|
||||
.{ "Long", .i64 },
|
||||
.{ "Int", .i32 },
|
||||
.{ "Short", .i16 },
|
||||
.{ "Char", .i8 },
|
||||
.{ "Byte", .u8 },
|
||||
.{ "Bool", .bool },
|
||||
.{ "BFloat16", .bf16 },
|
||||
.{ "ComplexDouble", .c128 },
|
||||
.{ "ComplexFloat", .c64 },
|
||||
// QUInt8Storage
|
||||
// QInt8Storage
|
||||
// QInt32Storage
|
||||
// QUInt4x2Storage
|
||||
// QUInt2x4Storage
|
||||
});
|
||||
|
||||
return map.get(torch_type) orelse {
|
||||
log.err("Unsupported torch storage type: {s}", .{storage_type});
|
||||
return error.UnsupportedDataType;
|
||||
};
|
||||
}
|
||||
|
||||
// TODO: rename me to PytorchFile
|
||||
pub const PickleData = struct {
|
||||
stack: PickleStack,
|
||||
memo: PickleMemo,
|
||||
data: Decoder,
|
||||
stack: eval.PickleStack,
|
||||
memo: eval.PickleMemo,
|
||||
data: parser.Parser,
|
||||
|
||||
fn basicTypeCheck(v: Value, ns: []const u8, name: []const u8) bool {
|
||||
return switch (v) {
|
||||
.global => |object| switch (object.member) {
|
||||
.raw => |raw| {
|
||||
if (std.mem.eql(u8, ns, raw.global.module) and std.mem.eql(u8, name, raw.global.class) and object.args[0] == .seq) {
|
||||
return true;
|
||||
} else return false;
|
||||
},
|
||||
else => false,
|
||||
},
|
||||
fn basicTypeCheck(object: *const value.Object, module: []const u8, class: []const u8) bool {
|
||||
return switch (object.member) {
|
||||
.raw => |raw| return (object.args[0] == .seq and
|
||||
std.mem.eql(u8, module, raw.global.module) and
|
||||
std.mem.eql(u8, class, raw.global.class)),
|
||||
else => false,
|
||||
};
|
||||
}
|
||||
|
||||
fn isTensor(v: Value) bool {
|
||||
if (basicTypeCheck(v, "torch._utils", "_rebuild_tensor_v2")) {
|
||||
const args = v.global.args[0].seq.values;
|
||||
if (args.len >= 5 and
|
||||
args[0] == .pers_id and
|
||||
args[1] == .int64 and
|
||||
args[2] == .seq and args[2].seq.type == .tuple and
|
||||
args[3] == .seq and args[3].seq.type == .tuple)
|
||||
{
|
||||
return true;
|
||||
} else @panic("Unexpected value in call to torch._utils._rebuild_tensor_v2");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
fn dimsFromValues(values: []Value) [zml.Tensor.MAX_RANK]i64 {
|
||||
std.debug.assert(values.len <= zml.Tensor.MAX_RANK);
|
||||
var result: [zml.Tensor.MAX_RANK]i64 = undefined;
|
||||
for (values, result[0..values.len]) |val, *elem| {
|
||||
switch (val) {
|
||||
.int64 => |int| elem.* = int,
|
||||
else => @panic("Bad value for shape item"),
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
pub fn parseModel(self: *PickleData, allocator: Allocator, store: *zml.aio.BufferStore) !void {
|
||||
pub fn parseModel(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore) !void {
|
||||
for (self.stack.stack) |item| {
|
||||
var prefix_buf: [1024]u8 = undefined;
|
||||
try self.parseValue(allocator, store, StringBuilder.initBuffer(&prefix_buf), item);
|
||||
}
|
||||
}
|
||||
|
||||
fn tensorOffset(self: *PickleData, seekable_stream: anytype, sfile: []const u8) !u64 {
|
||||
if (self.data.file_map.get(sfile)) |entry| {
|
||||
const local_header = blk: {
|
||||
try seekable_stream.seekTo(entry.file_offset);
|
||||
break :blk try seekable_stream.context.reader().readStructEndian(std.zip.LocalFileHeader, .little);
|
||||
};
|
||||
if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig))
|
||||
return error.ZipBadFileOffset;
|
||||
if (local_header.version_needed_to_extract != entry.version_needed_to_extract)
|
||||
return error.ZipMismatchVersionNeeded;
|
||||
if (local_header.last_modification_time != entry.last_modification_time)
|
||||
return error.ZipMismatchModTime;
|
||||
if (local_header.last_modification_date != entry.last_modification_date)
|
||||
return error.ZipMismatchModDate;
|
||||
|
||||
if (@as(u16, @bitCast(local_header.flags)) != entry.flags)
|
||||
return error.ZipMismatchFlags;
|
||||
if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32)
|
||||
return error.ZipMismatchCrc32;
|
||||
if (local_header.compressed_size != 0 and
|
||||
local_header.compressed_size != entry.compressed_size)
|
||||
return error.ZipMismatchCompLen;
|
||||
if (local_header.uncompressed_size != 0 and
|
||||
local_header.uncompressed_size != entry.uncompressed_size)
|
||||
return error.ZipMismatchUncompLen;
|
||||
if (local_header.filename_len != entry.filename_len)
|
||||
return error.ZipMismatchFilenameLen;
|
||||
|
||||
return (try seekable_stream.context.getPos()) +
|
||||
@as(u64, local_header.filename_len) +
|
||||
@as(u64, local_header.extra_len);
|
||||
}
|
||||
|
||||
std.log.err("Could not find file ending in `{s}` in archive", .{sfile});
|
||||
return error.TensorNotFound;
|
||||
}
|
||||
|
||||
fn parseTorchGlobal(self: *PickleData, allocator: Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: Value) !bool {
|
||||
return switch (v) {
|
||||
.global => |object| {
|
||||
if (isTensor(v)) {
|
||||
const args = object.args[0].seq.values;
|
||||
const pidval: *PersId, var offs: u64, const raw_shape: Sequence, const raw_strides: Sequence = .{ args[0].pers_id, @intCast(args[1].int64), args[2].seq, args[3].seq };
|
||||
const rank = raw_shape.values.len;
|
||||
const shape = dimsFromValues(raw_shape.values);
|
||||
var strides = dimsFromValues(raw_strides.values);
|
||||
const storage_type, const sfile = switch (pidval.ref) {
|
||||
.seq => |seq| blk: {
|
||||
const sargs = seq.values;
|
||||
if (seq.type == .tuple and
|
||||
sargs.len >= 5 and
|
||||
sargs[0] == .string and std.mem.eql(u8, sargs[0].string, "storage") and
|
||||
sargs[1] == .raw and sargs[1].raw == .global and
|
||||
sargs[2] == .string and
|
||||
sargs[3] == .string)
|
||||
{
|
||||
const op = sargs[1].raw.global;
|
||||
const sfile = sargs[2].string;
|
||||
// const sdev = sargs[3].string;
|
||||
if (std.mem.eql(u8, "torch", op.module) and std.mem.endsWith(u8, op.class, "Storage")) {
|
||||
break :blk .{ op.class, sfile };
|
||||
} else @panic("Unexpected storage type part of persistant ID");
|
||||
} else @panic("Unexpected value for persistant ID");
|
||||
},
|
||||
else => @panic("Unexpected value for persistant ID"),
|
||||
};
|
||||
|
||||
const data_type = try storageToDtype(storage_type);
|
||||
for (strides[0..rank]) |*s| s.* *= data_type.sizeOf();
|
||||
|
||||
var sfile_buf = std.ArrayList(u8).init(allocator);
|
||||
defer sfile_buf.deinit();
|
||||
try sfile_buf.writer().print("{s}data/{s}", .{ self.data.zip_prefix, sfile });
|
||||
|
||||
// find offsets for tensor zip file
|
||||
const absolute_offset = blk: {
|
||||
if (self.data.tar_file) |t| {
|
||||
break :blk try self.tensorOffset(t.seekableStream(), sfile_buf.items);
|
||||
} else {
|
||||
break :blk try self.tensorOffset(self.data.buffer_file.file.seekableStream(), sfile_buf.items);
|
||||
}
|
||||
};
|
||||
offs = offs * data_type.sizeOf();
|
||||
const key = try allocator.dupe(u8, prefix.items);
|
||||
const entry = try store.buffers.getOrPut(allocator, key);
|
||||
if (entry.found_existing) {
|
||||
log.warn("Duplicate key: {s}", .{prefix.items});
|
||||
allocator.free(key);
|
||||
}
|
||||
const out_shape = zml.Shape.init(shape[0..rank], data_type);
|
||||
entry.value_ptr.* = HostBuffer.fromStridedSlice(
|
||||
out_shape,
|
||||
self.data.buffer_file.mappedSlice((if (self.data.tar_file) |t| t.start else 0) + absolute_offset + offs, out_shape.byteSize()),
|
||||
strides[0..rank],
|
||||
);
|
||||
return true;
|
||||
} else if (basicTypeCheck(v, "torch", "Size")) {
|
||||
const size = object.args[0].seq.values[0].seq.values;
|
||||
const key = try allocator.dupe(u8, prefix.items);
|
||||
const entry = try store._metadata.getOrPut(allocator, key);
|
||||
if (entry.found_existing) {
|
||||
log.warn("Duplicate key: {s}", .{prefix.items});
|
||||
allocator.free(key);
|
||||
}
|
||||
const d = try allocator.alloc(i64, size.len);
|
||||
for (d, 0..) |*di, i| di.* = size[i].int64;
|
||||
entry.value_ptr.* = .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(d) } };
|
||||
return true;
|
||||
} else if (basicTypeCheck(v, "fractions", "Fraction")) {
|
||||
const fraction_str = object.args[0].seq.values[0].string;
|
||||
if (std.mem.indexOfScalar(u8, fraction_str, '/')) |split_idx| {
|
||||
{
|
||||
var new_prefix = prefix;
|
||||
new_prefix.appendSliceAssumeCapacity(".numerator");
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) });
|
||||
}
|
||||
{
|
||||
var new_prefix = prefix;
|
||||
new_prefix.appendSliceAssumeCapacity(".denominator");
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) });
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
},
|
||||
else => false,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn parseValue(self: *PickleData, allocator: Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: Value) !void {
|
||||
pub fn parseValue(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: Value) !void {
|
||||
switch (v) {
|
||||
.app, .object, .global => |object| {
|
||||
if (!(try self.parseTorchGlobal(allocator, store, prefix, v))) {
|
||||
@ -415,8 +225,194 @@ pub const PickleData = struct {
|
||||
else => {},
|
||||
}
|
||||
}
|
||||
|
||||
fn parseTorchGlobal(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: Value) !bool {
|
||||
return switch (v) {
|
||||
.global => |object| {
|
||||
if (try self.parseTensor(allocator, object)) |host_buffer| {
|
||||
const key = try allocator.dupe(u8, prefix.items);
|
||||
const entry = try store.buffers.getOrPut(allocator, key);
|
||||
if (entry.found_existing) {
|
||||
log.warn("Duplicate key: {s}", .{prefix.items});
|
||||
allocator.free(key);
|
||||
}
|
||||
entry.value_ptr.* = host_buffer;
|
||||
return true;
|
||||
} else if (basicTypeCheck(object, "torch", "Size")) {
|
||||
const size = object.args[0].seq.values[0].seq.values;
|
||||
const key = try allocator.dupe(u8, prefix.items);
|
||||
const entry = try store._metadata.getOrPut(allocator, key);
|
||||
if (entry.found_existing) {
|
||||
log.warn("Duplicate key: {s}", .{prefix.items});
|
||||
allocator.free(key);
|
||||
}
|
||||
const d = try allocator.alloc(i64, size.len);
|
||||
for (d, 0..) |*di, i| di.* = size[i].int64;
|
||||
entry.value_ptr.* = .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(d) } };
|
||||
return true;
|
||||
} else if (basicTypeCheck(object, "fractions", "Fraction")) {
|
||||
const fraction_str = object.args[0].seq.values[0].string;
|
||||
if (std.mem.indexOfScalar(u8, fraction_str, '/')) |split_idx| {
|
||||
{
|
||||
var new_prefix = prefix;
|
||||
new_prefix.appendSliceAssumeCapacity(".numerator");
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) });
|
||||
}
|
||||
{
|
||||
var new_prefix = prefix;
|
||||
new_prefix.appendSliceAssumeCapacity(".denominator");
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) });
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
},
|
||||
else => false,
|
||||
};
|
||||
}
|
||||
|
||||
fn parseTensor(self: *PickleData, tmp_allocator: std.mem.Allocator, object: *value.Object) !?zml.HostBuffer {
|
||||
if (!basicTypeCheck(object, "torch._utils", "_rebuild_tensor_v2")) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const args = object.args[0].seq.values;
|
||||
if (args.len < 4 or
|
||||
args[0] != .pers_id or
|
||||
args[1] != .int64 or
|
||||
args[2] != .seq or args[2].seq.type != .tuple or
|
||||
args[3] != .seq or args[3].seq.type != .tuple)
|
||||
{
|
||||
log.err("Unexpected value in call to torch._utils._rebuild_tensor_v2", .{});
|
||||
return error.InvalidInput;
|
||||
}
|
||||
|
||||
const pid: *PersId = args[0].pers_id;
|
||||
var offset: u64 = @intCast(args[1].int64);
|
||||
const raw_dims: Sequence = args[2].seq;
|
||||
const raw_strides: Sequence = args[3].seq;
|
||||
const dims = try parseDims(raw_dims.values);
|
||||
var strides = try parseDims(raw_strides.values);
|
||||
|
||||
const dtype, const storage_file = try parseStorage(pid.ref);
|
||||
// Pytorch store "item" strides, while ZML uses byte strides.
|
||||
for (strides.slice()) |*s| s.* *= dtype.sizeOf();
|
||||
// Same thing for the offset.
|
||||
offset = offset * dtype.sizeOf();
|
||||
|
||||
const filename = try std.mem.join(tmp_allocator, "", &.{ self.data.zip_prefix, "data/", storage_file });
|
||||
defer tmp_allocator.free(filename);
|
||||
|
||||
// The offset in the pickle is the offset inside the storage_file.
|
||||
// But .pt are made of several files, so we need to append the file offset.
|
||||
const storage = try self.getStorage(filename);
|
||||
return HostBuffer.fromStridedSlice(
|
||||
zml.Shape.init(dims.constSlice(), dtype),
|
||||
storage[offset..],
|
||||
strides.constSlice(),
|
||||
);
|
||||
}
|
||||
|
||||
fn parseStorage(val: value.Value) !struct { zml.DataType, []const u8 } {
|
||||
if (val != .seq) return error.InvalidInput;
|
||||
const sargs = val.seq.values;
|
||||
if (val.seq.type == .tuple and
|
||||
sargs.len >= 5 and
|
||||
sargs[0] == .string and std.mem.eql(u8, sargs[0].string, "storage") and
|
||||
sargs[1] == .raw and sargs[1].raw == .global and
|
||||
sargs[2] == .string and
|
||||
sargs[3] == .string)
|
||||
{
|
||||
const op = sargs[1].raw.global;
|
||||
const storage_file = sargs[2].string;
|
||||
// const sdev = sargs[3].string;
|
||||
if (!std.mem.eql(u8, "torch", op.module) or
|
||||
!std.mem.endsWith(u8, op.class, "Storage"))
|
||||
return error.InvalidInput;
|
||||
|
||||
return .{
|
||||
try storageToDtype(op.class),
|
||||
storage_file,
|
||||
};
|
||||
} else {
|
||||
return error.InvalidInput;
|
||||
}
|
||||
}
|
||||
|
||||
/// Given the name of one of the files in the .pt tarball,
|
||||
/// return the slice of the memory-mapped .pt corresponding to it.
|
||||
fn getStorage(self: *PickleData, filename: []const u8) ![]const u8 {
|
||||
const maybe_entry = self.data.file_map.get(filename);
|
||||
if (maybe_entry == null) {
|
||||
std.log.err("Could not find file ending in `{s}` in archive", .{filename});
|
||||
return error.TensorNotFound;
|
||||
}
|
||||
const entry = maybe_entry.?;
|
||||
const base_offset: u64 = if (self.data.tar_file) |t| t.start else 0;
|
||||
const file_offset: u64 = base_offset + entry.file_offset;
|
||||
const file = self.data.buffer_file.file;
|
||||
try file.seekTo(entry.file_offset);
|
||||
const local_header = try file.reader().readStructEndian(std.zip.LocalFileHeader, .little);
|
||||
|
||||
if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig))
|
||||
return error.ZipBadFileOffset;
|
||||
if (local_header.compressed_size != 0 and
|
||||
local_header.compressed_size != entry.compressed_size)
|
||||
return error.ZipMismatchCompLen;
|
||||
if (local_header.uncompressed_size != 0 and
|
||||
local_header.uncompressed_size != entry.uncompressed_size)
|
||||
return error.ZipMismatchUncompLen;
|
||||
if (local_header.filename_len != entry.filename_len)
|
||||
return error.ZipMismatchFilenameLen;
|
||||
|
||||
const start = file_offset +
|
||||
@sizeOf(std.zip.LocalFileHeader) +
|
||||
@as(u64, local_header.filename_len) +
|
||||
@as(u64, local_header.extra_len);
|
||||
return self.data.buffer_file.mappedSlice(start, entry.uncompressed_size);
|
||||
}
|
||||
|
||||
fn parseDims(values: []Value) error{InvalidInput}!zml.Shape.DimsArray {
|
||||
zml.meta.assert(values.len <= zml.Tensor.MAX_RANK, "Found Pytorch tensor with unsupported rank {}", .{values.len});
|
||||
var result: zml.Shape.DimsArray = .{};
|
||||
for (values) |val| {
|
||||
switch (val) {
|
||||
.int64 => |d| result.appendAssumeCapacity(d),
|
||||
else => return error.InvalidInput,
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(@This());
|
||||
/// Convert from a torch.<type>Storage to a `zml.DataType`.
|
||||
/// TODO: make this future proof, storage type are going to get replaced with torch.UntypedStorage
|
||||
/// See https://pytorch.org/docs/stable/storage.html
|
||||
fn storageToDtype(storage_type: []const u8) !zml.DataType {
|
||||
const torch_type = storage_type[0 .. storage_type.len - "Storage".len];
|
||||
const map = std.StaticStringMap(zml.DataType).initComptime(.{
|
||||
.{ "Double", .f64 },
|
||||
.{ "Float", .f32 },
|
||||
.{ "Half", .f16 },
|
||||
.{ "Long", .i64 },
|
||||
.{ "Int", .i32 },
|
||||
.{ "Short", .i16 },
|
||||
.{ "Char", .i8 },
|
||||
.{ "Byte", .u8 },
|
||||
.{ "Bool", .bool },
|
||||
.{ "BFloat16", .bf16 },
|
||||
.{ "ComplexDouble", .c128 },
|
||||
.{ "ComplexFloat", .c64 },
|
||||
// QUInt8Storage
|
||||
// QInt8Storage
|
||||
// QInt32Storage
|
||||
// QUInt4x2Storage
|
||||
// QUInt2x4Storage
|
||||
});
|
||||
|
||||
return map.get(torch_type) orelse {
|
||||
log.err("Unsupported torch storage type: {s}", .{storage_type});
|
||||
return error.UnsupportedDataType;
|
||||
};
|
||||
}
|
||||
|
||||
@ -3,8 +3,8 @@ const zml = @import("../../zml.zig");
|
||||
const meta = zml.meta;
|
||||
|
||||
const value = @import("value.zig");
|
||||
const pickle = @import("pickle.zig");
|
||||
const BTreeMap = @import("b_tree_map.zig").BTreeMap;
|
||||
const PickleOp = @import("ops.zig").PickleOp;
|
||||
|
||||
const Build = value.Build;
|
||||
const Object = value.Object;
|
||||
@ -233,7 +233,7 @@ pub const PickleStack = struct {
|
||||
}
|
||||
};
|
||||
|
||||
pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs: bool) !struct { PickleStack, PickleMemo } {
|
||||
pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) !struct { PickleStack, PickleMemo } {
|
||||
var stack = InternalStack.init(allocator);
|
||||
defer stack.deinit();
|
||||
var memo = PickleMemo.init(allocator);
|
||||
|
||||
@ -1,23 +1,23 @@
|
||||
const asynk = @import("async");
|
||||
const std = @import("std");
|
||||
const zml = @import("../../zml.zig");
|
||||
|
||||
const utils = @import("utils.zig");
|
||||
const PickleOp = @import("ops.zig").PickleOp;
|
||||
const RawPickleOp = @import("ops.zig").RawPickleOp;
|
||||
|
||||
const Allocator = std.mem.Allocator;
|
||||
const testing = std.testing;
|
||||
const Allocator = std.mem.Allocator;
|
||||
|
||||
const zml = @import("../../zml.zig");
|
||||
const pickle = @import("pickle.zig");
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(@This());
|
||||
std.testing.refAllDecls(Parser);
|
||||
}
|
||||
|
||||
pub const Decoder = struct {
|
||||
pub const Parser = struct {
|
||||
// TODO: move the file logic to torch.PytorchFile
|
||||
// the Pickle parser shouldn't have to deal with the zip archive stuff used by Pytorch
|
||||
buffer_file: zml.aio.MemoryMappedFile,
|
||||
file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{},
|
||||
tar_file: ?TarStream = null,
|
||||
ops: []PickleOp,
|
||||
ops: []pickle.Op,
|
||||
is_zip_file: bool,
|
||||
zip_prefix: []const u8 = &[_]u8{},
|
||||
|
||||
@ -53,11 +53,11 @@ pub const Decoder = struct {
|
||||
|
||||
const magic = "PK\x03\x04";
|
||||
|
||||
pub fn fromTarFile(allocator: Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !Decoder {
|
||||
pub fn fromTarFile(allocator: Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !Parser {
|
||||
const tar_stream = try TarStream.init(file);
|
||||
const file_magic = try tar_stream.reader().readBytesNoEof(magic.len);
|
||||
try tar_stream.seekTo(0);
|
||||
var self: Decoder = .{
|
||||
var self: Parser = .{
|
||||
.buffer_file = mapped,
|
||||
.tar_file = tar_stream,
|
||||
.ops = undefined,
|
||||
@ -72,10 +72,10 @@ pub const Decoder = struct {
|
||||
return self;
|
||||
}
|
||||
|
||||
pub fn init(allocator: Allocator, file: asynk.File) !Decoder {
|
||||
pub fn init(allocator: Allocator, file: asynk.File) !Parser {
|
||||
const file_magic = try file.reader().readBytesNoEof(magic.len);
|
||||
try file.seekTo(0);
|
||||
var self: Decoder = .{
|
||||
var self: Parser = .{
|
||||
.buffer_file = try zml.aio.MemoryMappedFile.init(file),
|
||||
.is_zip_file = std.mem.eql(u8, &file_magic, magic),
|
||||
.ops = undefined,
|
||||
@ -89,12 +89,12 @@ pub const Decoder = struct {
|
||||
return self;
|
||||
}
|
||||
|
||||
pub fn deinit(self: *Decoder) void {
|
||||
pub fn deinit(self: *Parser) void {
|
||||
self.buffer_file.deinit();
|
||||
self.* = undefined;
|
||||
}
|
||||
|
||||
fn parseOps(self: *Decoder, allocator: Allocator, seekable_stream: anytype) ![]PickleOp {
|
||||
fn parseOps(self: *Parser, allocator: Allocator, seekable_stream: anytype) ![]pickle.Op {
|
||||
var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream);
|
||||
var filename_buf: [std.fs.max_path_bytes]u8 = undefined;
|
||||
while (try iter.next()) |entry| {
|
||||
@ -167,12 +167,12 @@ pub const Decoder = struct {
|
||||
return error.PickleNotFound;
|
||||
}
|
||||
|
||||
fn parse(allocator: Allocator, reader: anytype, len: usize) ![]PickleOp {
|
||||
var results = std.ArrayList(PickleOp).init(allocator);
|
||||
fn parse(allocator: Allocator, reader: anytype, len: usize) ![]pickle.Op {
|
||||
var results = std.ArrayList(pickle.Op).init(allocator);
|
||||
errdefer results.deinit();
|
||||
outer: while (true) {
|
||||
const b = try reader.readByte();
|
||||
switch (@as(RawPickleOp, @enumFromInt(b))) {
|
||||
switch (@as(pickle.OpCode, @enumFromInt(b))) {
|
||||
.mark => try results.append(.{ .mark = {} }),
|
||||
.stop => {
|
||||
try results.append(.{ .stop = {} });
|
||||
@ -351,7 +351,7 @@ pub const Decoder = struct {
|
||||
},
|
||||
.next_buffer => try results.append(.{ .next_buffer = {} }),
|
||||
.readonly_buffer => try results.append(.{ .readonly_buffer = {} }),
|
||||
else => {},
|
||||
_ => {},
|
||||
}
|
||||
}
|
||||
return results.toOwnedSlice();
|
||||
@ -411,7 +411,7 @@ test "Read pickle (simple)" {
|
||||
const allocator = arena.allocator();
|
||||
const eval = @import("eval.zig");
|
||||
const file = try asynk.File.open("zml/aio/torch/simple_test.pickle", .{ .mode = .read_only });
|
||||
var data = try Decoder.init(allocator, file);
|
||||
var data = try Parser.init(allocator, file);
|
||||
defer data.deinit();
|
||||
var vals, var memo = try eval.evaluate(allocator, data.ops, true);
|
||||
defer vals.deinit();
|
||||
@ -457,11 +457,11 @@ test "Read pickle (zipped)" {
|
||||
defer arena.deinit();
|
||||
const allocator = arena.allocator();
|
||||
const file = try asynk.File.open("zml/aio/torch/simple.pt", .{ .mode = .read_only });
|
||||
var data = try Decoder.init(allocator, file);
|
||||
var data = try Parser.init(allocator, file);
|
||||
defer data.deinit();
|
||||
}
|
||||
|
||||
pub fn isBadFilename(filename: []const u8) bool {
|
||||
fn isBadFilename(filename: []const u8) bool {
|
||||
if (filename.len == 0 or filename[0] == '/')
|
||||
return true;
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
const std = @import("std");
|
||||
|
||||
/// A decoded Pickle operation in its natural state.
|
||||
pub const PickleOp = union(RawPickleOp) {
|
||||
pub const Op = union(OpCode) {
|
||||
mark,
|
||||
stop,
|
||||
pop,
|
||||
@ -73,7 +73,7 @@ pub const PickleOp = union(RawPickleOp) {
|
||||
|
||||
pub const PyType = struct { module: []const u8, class: []const u8 };
|
||||
|
||||
pub fn deinit(self: PickleOp, allocator: std.mem.Allocator) void {
|
||||
pub fn deinit(self: Op, allocator: std.mem.Allocator) void {
|
||||
switch (self) {
|
||||
.float,
|
||||
.int,
|
||||
@ -103,7 +103,7 @@ pub const PickleOp = union(RawPickleOp) {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clone(self: PickleOp, allocator: std.mem.Allocator) !PickleOp {
|
||||
pub fn clone(self: Op, allocator: std.mem.Allocator) !Op {
|
||||
var res = self;
|
||||
return switch (self) {
|
||||
inline .float,
|
||||
@ -144,7 +144,8 @@ pub const PickleOp = union(RawPickleOp) {
|
||||
};
|
||||
|
||||
/// The values for the possible opcodes are in this enum.
|
||||
pub const RawPickleOp = enum(u8) {
|
||||
/// Reference: https://github.com/python/cpython/blob/3.13/Lib/pickletools.py
|
||||
pub const OpCode = enum(u8) {
|
||||
mark = '(', // push special markobject on stack
|
||||
stop = '.', // every pickle ends with stop
|
||||
pop = '0', // discard topmost stack item
|
||||
@ -1,23 +0,0 @@
|
||||
const std = @import("std");
|
||||
|
||||
const Value = @import("value.zig").Value;
|
||||
|
||||
pub fn allTrue(values: []const Value, func: fn (v: Value) bool) bool {
|
||||
for (values) |v| {
|
||||
if (!func(v)) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
pub fn isBadFilename(filename: []const u8) bool {
|
||||
if (filename.len == 0 or filename[0] == '/')
|
||||
return true;
|
||||
|
||||
var it = std.mem.splitScalar(u8, filename, '/');
|
||||
while (it.next()) |part| {
|
||||
if (std.mem.eql(u8, part, ".."))
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
@ -1,10 +1,8 @@
|
||||
const std = @import("std");
|
||||
const utils = @import("utils.zig");
|
||||
|
||||
const PickleOp = @import("ops.zig").PickleOp;
|
||||
|
||||
const big_int = std.math.big.int;
|
||||
|
||||
const pickle = @import("pickle.zig");
|
||||
|
||||
/// The types of sequences that exist.
|
||||
pub const SequenceType = enum {
|
||||
list,
|
||||
@ -114,7 +112,7 @@ pub const ValueType = enum {
|
||||
/// A processed value.
|
||||
pub const Value = union(ValueType) {
|
||||
/// Types that we can't handle or just had to give up on processing.
|
||||
raw: PickleOp,
|
||||
raw: pickle.Op,
|
||||
|
||||
/// A reference. You might be able to look it up in the memo map
|
||||
/// unless there's something weird going on like recursive references.
|
||||
@ -174,7 +172,7 @@ pub const Value = union(ValueType) {
|
||||
float64: f64,
|
||||
|
||||
/// Some kind of weird number we can't handle.
|
||||
raw_num: PickleOp,
|
||||
raw_num: pickle.Op,
|
||||
|
||||
/// A boolean value.
|
||||
boolval: bool,
|
||||
@ -299,7 +297,12 @@ pub const Value = union(ValueType) {
|
||||
pub fn isPrimitive(self: Value) bool {
|
||||
return switch (self) {
|
||||
.int64, .bigint, .float64, .string, .bytes, .boolval, .none => true,
|
||||
.seq => |seq| utils.allTrue(seq.values, Value.isPrimitive),
|
||||
.seq => |seq| {
|
||||
for (seq.values) |v| {
|
||||
if (!v.isPrimitive()) return false;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
else => false,
|
||||
};
|
||||
}
|
||||
|
||||
@ -1,7 +0,0 @@
|
||||
pub fn toVoidSlice(data: anytype) []void {
|
||||
const info = @typeInfo(@TypeOf(data));
|
||||
if (info != .Pointer or info.Pointer.size != .Slice) {
|
||||
@compileError("toVoidSlice expects a slice");
|
||||
}
|
||||
return @as([*]void, @ptrCast(@alignCast(data.ptr)))[0..data.len];
|
||||
}
|
||||
@ -1,5 +1,4 @@
|
||||
const std = @import("std");
|
||||
const utils = @import("utils.zig");
|
||||
const yaml = @import("zig-yaml");
|
||||
const zml = @import("../zml.zig");
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user