From e25f70d923142348c4e24ae355692d8118a2f2f3 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 4 Apr 2023 17:20:53 +0000 Subject: [PATCH] =?UTF-8?q?Rename=20and=20simplify=20modules=20in=20`zml/a?= =?UTF-8?q?io/torch`:=20replace=20redundant=20qualified=20names,=20remove?= =?UTF-8?q?=20generic=20utilities,=20inline=20code,=20reorder=20functions?= =?UTF-8?q?=20for=20top=E2=80=91to=E2=80=91bottom=20readability,=20and=20e?= =?UTF-8?q?xtract=20parsing=20logic=20into=20`parseTensor`=20and=20`parseS?= =?UTF-8?q?torage`=20functions.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- zml/aio/gguf/core.zig | 1 - zml/aio/json.zig | 1 - zml/aio/nemo.zig | 4 +- zml/aio/safetensors.zig | 4 +- zml/aio/torch.zig | 442 +++++++++++++------------- zml/aio/torch/eval.zig | 4 +- zml/aio/torch/parser.zig | 44 +-- zml/aio/torch/{ops.zig => pickle.zig} | 9 +- zml/aio/torch/utils.zig | 23 -- zml/aio/torch/value.zig | 17 +- zml/aio/utils.zig | 7 - zml/aio/yaml.zig | 1 - 12 files changed, 261 insertions(+), 296 deletions(-) rename zml/aio/torch/{ops.zig => pickle.zig} (96%) delete mode 100644 zml/aio/torch/utils.zig delete mode 100644 zml/aio/utils.zig diff --git a/zml/aio/gguf/core.zig b/zml/aio/gguf/core.zig index 61b5137..a7feadf 100644 --- a/zml/aio/gguf/core.zig +++ b/zml/aio/gguf/core.zig @@ -1,6 +1,5 @@ const asynk = @import("async"); const std = @import("std"); -const utils = @import("../utils.zig"); const zml = @import("../../zml.zig"); const assert = std.debug.assert; diff --git a/zml/aio/json.zig b/zml/aio/json.zig index c8ba494..0c1634a 100644 --- a/zml/aio/json.zig +++ b/zml/aio/json.zig @@ -1,6 +1,5 @@ const asynk = @import("async"); const std = @import("std"); -const utils = @import("utils.zig"); const zml = @import("../zml.zig"); const StringBuilder = std.ArrayListUnmanaged(u8); diff --git a/zml/aio/nemo.zig b/zml/aio/nemo.zig index d52854e..feacc63 100644 --- a/zml/aio/nemo.zig +++ b/zml/aio/nemo.zig @@ -4,7 +4,7 @@ const std 
= @import("std"); const yaml = @import("zig-yaml"); const zml = @import("../zml.zig"); -const Decoder = @import("torch/parser.zig").Decoder; +const parser = @import("torch/parser.zig"); const StringBuilder = std.ArrayListUnmanaged(u8); @@ -38,7 +38,7 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore } else if (std.mem.endsWith(u8, file.name, ".ckpt") or std.mem.endsWith(u8, file.name, ".pt")) { const start = try mapped_file.file.getPos(); var tmp: zml.aio.torch.PickleData = .{ - .data = try Decoder.fromTarFile(arena, mapped_file, file), + .data = try parser.Parser.fromTarFile(arena, mapped_file, file), .memo = undefined, .stack = undefined, }; diff --git a/zml/aio/safetensors.zig b/zml/aio/safetensors.zig index 519be29..c87eb0f 100644 --- a/zml/aio/safetensors.zig +++ b/zml/aio/safetensors.zig @@ -1,10 +1,8 @@ const asynk = @import("async"); const std = @import("std"); const zml = @import("../zml.zig"); -const helpers = @import("../helpers.zig"); -const utils = @import("utils.zig"); const json = @import("json.zig"); -const HostBuffer = @import("../hostbuffer.zig").HostBuffer; +const HostBuffer = zml.HostBuffer; const MemoryMappedFile = @import("../aio.zig").MemoryMappedFile; const StringBuilder = std.ArrayListUnmanaged(u8); diff --git a/zml/aio/torch.zig b/zml/aio/torch.zig index 5620a69..395ae60 100644 --- a/zml/aio/torch.zig +++ b/zml/aio/torch.zig @@ -4,262 +4,72 @@ const zml = @import("../zml.zig"); const HostBuffer = @import("../hostbuffer.zig").HostBuffer; -const toVoidSlice = @import("utils.zig").toVoidSlice; const eval = @import("torch/eval.zig"); -const utils = @import("torch/utils.zig"); const value = @import("torch/value.zig"); -const Decoder = @import("torch/parser.zig").Decoder; +const parser = @import("torch/parser.zig"); const PersId = value.PersId; -const PickleMemo = eval.PickleMemo; -const PickleStack = eval.PickleStack; const Sequence = value.Sequence; const Value = value.Value; const ValueType = value.ValueType; 
const StringBuilder = std.ArrayListUnmanaged(u8); -const Allocator = std.mem.Allocator; const log = std.log.scoped(.zml_io); +test { + std.testing.refAllDecls(eval); + std.testing.refAllDecls(value); + std.testing.refAllDecls(parser); +} + /// Opens and loads a BufferStore from the torch file at the given path. -pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore { +pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore { const file = asynk.File.open(path, .{}) catch |err| { log.err("Failed to open {s}: {}", .{ path, err }); return err; }; errdefer file.close() catch unreachable; + // Temporary memory needed to parse the pytorch file. + var arena = std.heap.ArenaAllocator.init(allocator); + defer arena.deinit(); + const tmp_alloc = arena.allocator(); + + const _parser = try parser.Parser.init(tmp_alloc, file); + const stack, const memo = try eval.evaluate(tmp_alloc, _parser.ops, true); + + // But we create the HostBuffer objects inside the result BufferStore arena. var res: zml.aio.BufferStore = .{ .arena = std.heap.ArenaAllocator.init(allocator), }; - - const arena = res.arena.allocator(); - - var tmp: PickleData = .{ - .data = try Decoder.init(arena, file), - .memo = undefined, - .stack = undefined, - }; - tmp.stack, tmp.memo = try eval.evaluate(arena, tmp.data.ops, true); - res.files = try arena.dupe(zml.aio.MemoryMappedFile, &.{tmp.data.buffer_file}); - try tmp.parseModel(arena, &res); + res.files = try res.arena.allocator().dupe(zml.aio.MemoryMappedFile, &.{_parser.buffer_file}); + var tmp: PickleData = .{ .data = _parser, .memo = memo, .stack = stack }; + try tmp.parseModel(res.arena.allocator(), &res); return res; } -/// Convert from a torch.Storage to a `zml.DataType`. 
-/// TODO: make this future proof, storage type are going to get replaced with torch.UntypedStorage -/// See https://pytorch.org/docs/stable/storage.html -fn storageToDtype(storage_type: []const u8) !zml.DataType { - const torch_type = storage_type[0 .. storage_type.len - "Storage".len]; - const map = std.StaticStringMap(zml.DataType).initComptime(.{ - .{ "Double", .f64 }, - .{ "Float", .f32 }, - .{ "Half", .f16 }, - .{ "Long", .i64 }, - .{ "Int", .i32 }, - .{ "Short", .i16 }, - .{ "Char", .i8 }, - .{ "Byte", .u8 }, - .{ "Bool", .bool }, - .{ "BFloat16", .bf16 }, - .{ "ComplexDouble", .c128 }, - .{ "ComplexFloat", .c64 }, - // QUInt8Storage - // QInt8Storage - // QInt32Storage - // QUInt4x2Storage - // QUInt2x4Storage - }); - - return map.get(torch_type) orelse { - log.err("Unsupported torch storage type: {s}", .{storage_type}); - return error.UnsupportedDataType; - }; -} - +// TODO: rename me to PytorchFile pub const PickleData = struct { - stack: PickleStack, - memo: PickleMemo, - data: Decoder, + stack: eval.PickleStack, + memo: eval.PickleMemo, + data: parser.Parser, - fn basicTypeCheck(v: Value, ns: []const u8, name: []const u8) bool { - return switch (v) { - .global => |object| switch (object.member) { - .raw => |raw| { - if (std.mem.eql(u8, ns, raw.global.module) and std.mem.eql(u8, name, raw.global.class) and object.args[0] == .seq) { - return true; - } else return false; - }, - else => false, - }, + fn basicTypeCheck(object: *const value.Object, module: []const u8, class: []const u8) bool { + return switch (object.member) { + .raw => |raw| return (object.args[0] == .seq and + std.mem.eql(u8, module, raw.global.module) and + std.mem.eql(u8, class, raw.global.class)), else => false, }; } - fn isTensor(v: Value) bool { - if (basicTypeCheck(v, "torch._utils", "_rebuild_tensor_v2")) { - const args = v.global.args[0].seq.values; - if (args.len >= 5 and - args[0] == .pers_id and - args[1] == .int64 and - args[2] == .seq and args[2].seq.type == .tuple and - 
args[3] == .seq and args[3].seq.type == .tuple) - { - return true; - } else @panic("Unexpected value in call to torch._utils._rebuild_tensor_v2"); - } - return false; - } - - fn dimsFromValues(values: []Value) [zml.Tensor.MAX_RANK]i64 { - std.debug.assert(values.len <= zml.Tensor.MAX_RANK); - var result: [zml.Tensor.MAX_RANK]i64 = undefined; - for (values, result[0..values.len]) |val, *elem| { - switch (val) { - .int64 => |int| elem.* = int, - else => @panic("Bad value for shape item"), - } - } - return result; - } - - pub fn parseModel(self: *PickleData, allocator: Allocator, store: *zml.aio.BufferStore) !void { + pub fn parseModel(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore) !void { for (self.stack.stack) |item| { var prefix_buf: [1024]u8 = undefined; try self.parseValue(allocator, store, StringBuilder.initBuffer(&prefix_buf), item); } } - fn tensorOffset(self: *PickleData, seekable_stream: anytype, sfile: []const u8) !u64 { - if (self.data.file_map.get(sfile)) |entry| { - const local_header = blk: { - try seekable_stream.seekTo(entry.file_offset); - break :blk try seekable_stream.context.reader().readStructEndian(std.zip.LocalFileHeader, .little); - }; - if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig)) - return error.ZipBadFileOffset; - if (local_header.version_needed_to_extract != entry.version_needed_to_extract) - return error.ZipMismatchVersionNeeded; - if (local_header.last_modification_time != entry.last_modification_time) - return error.ZipMismatchModTime; - if (local_header.last_modification_date != entry.last_modification_date) - return error.ZipMismatchModDate; - - if (@as(u16, @bitCast(local_header.flags)) != entry.flags) - return error.ZipMismatchFlags; - if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32) - return error.ZipMismatchCrc32; - if (local_header.compressed_size != 0 and - local_header.compressed_size != entry.compressed_size) - return error.ZipMismatchCompLen; - 
if (local_header.uncompressed_size != 0 and - local_header.uncompressed_size != entry.uncompressed_size) - return error.ZipMismatchUncompLen; - if (local_header.filename_len != entry.filename_len) - return error.ZipMismatchFilenameLen; - - return (try seekable_stream.context.getPos()) + - @as(u64, local_header.filename_len) + - @as(u64, local_header.extra_len); - } - - std.log.err("Could not find file ending in `{s}` in archive", .{sfile}); - return error.TensorNotFound; - } - - fn parseTorchGlobal(self: *PickleData, allocator: Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: Value) !bool { - return switch (v) { - .global => |object| { - if (isTensor(v)) { - const args = object.args[0].seq.values; - const pidval: *PersId, var offs: u64, const raw_shape: Sequence, const raw_strides: Sequence = .{ args[0].pers_id, @intCast(args[1].int64), args[2].seq, args[3].seq }; - const rank = raw_shape.values.len; - const shape = dimsFromValues(raw_shape.values); - var strides = dimsFromValues(raw_strides.values); - const storage_type, const sfile = switch (pidval.ref) { - .seq => |seq| blk: { - const sargs = seq.values; - if (seq.type == .tuple and - sargs.len >= 5 and - sargs[0] == .string and std.mem.eql(u8, sargs[0].string, "storage") and - sargs[1] == .raw and sargs[1].raw == .global and - sargs[2] == .string and - sargs[3] == .string) - { - const op = sargs[1].raw.global; - const sfile = sargs[2].string; - // const sdev = sargs[3].string; - if (std.mem.eql(u8, "torch", op.module) and std.mem.endsWith(u8, op.class, "Storage")) { - break :blk .{ op.class, sfile }; - } else @panic("Unexpected storage type part of persistant ID"); - } else @panic("Unexpected value for persistant ID"); - }, - else => @panic("Unexpected value for persistant ID"), - }; - - const data_type = try storageToDtype(storage_type); - for (strides[0..rank]) |*s| s.* *= data_type.sizeOf(); - - var sfile_buf = std.ArrayList(u8).init(allocator); - defer sfile_buf.deinit(); - try 
sfile_buf.writer().print("{s}data/{s}", .{ self.data.zip_prefix, sfile }); - - // find offsets for tensor zip file - const absolute_offset = blk: { - if (self.data.tar_file) |t| { - break :blk try self.tensorOffset(t.seekableStream(), sfile_buf.items); - } else { - break :blk try self.tensorOffset(self.data.buffer_file.file.seekableStream(), sfile_buf.items); - } - }; - offs = offs * data_type.sizeOf(); - const key = try allocator.dupe(u8, prefix.items); - const entry = try store.buffers.getOrPut(allocator, key); - if (entry.found_existing) { - log.warn("Duplicate key: {s}", .{prefix.items}); - allocator.free(key); - } - const out_shape = zml.Shape.init(shape[0..rank], data_type); - entry.value_ptr.* = HostBuffer.fromStridedSlice( - out_shape, - self.data.buffer_file.mappedSlice((if (self.data.tar_file) |t| t.start else 0) + absolute_offset + offs, out_shape.byteSize()), - strides[0..rank], - ); - return true; - } else if (basicTypeCheck(v, "torch", "Size")) { - const size = object.args[0].seq.values[0].seq.values; - const key = try allocator.dupe(u8, prefix.items); - const entry = try store._metadata.getOrPut(allocator, key); - if (entry.found_existing) { - log.warn("Duplicate key: {s}", .{prefix.items}); - allocator.free(key); - } - const d = try allocator.alloc(i64, size.len); - for (d, 0..) 
|*di, i| di.* = size[i].int64; - entry.value_ptr.* = .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(d) } }; - return true; - } else if (basicTypeCheck(v, "fractions", "Fraction")) { - const fraction_str = object.args[0].seq.values[0].string; - if (std.mem.indexOfScalar(u8, fraction_str, '/')) |split_idx| { - { - var new_prefix = prefix; - new_prefix.appendSliceAssumeCapacity(".numerator"); - try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) }); - } - { - var new_prefix = prefix; - new_prefix.appendSliceAssumeCapacity(".denominator"); - try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) }); - } - return true; - } - } - return false; - }, - else => false, - }; - } - - pub fn parseValue(self: *PickleData, allocator: Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: Value) !void { + pub fn parseValue(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: Value) !void { switch (v) { .app, .object, .global => |object| { if (!(try self.parseTorchGlobal(allocator, store, prefix, v))) { @@ -415,8 +225,194 @@ pub const PickleData = struct { else => {}, } } + + fn parseTorchGlobal(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: Value) !bool { + return switch (v) { + .global => |object| { + if (try self.parseTensor(allocator, object)) |host_buffer| { + const key = try allocator.dupe(u8, prefix.items); + const entry = try store.buffers.getOrPut(allocator, key); + if (entry.found_existing) { + log.warn("Duplicate key: {s}", .{prefix.items}); + allocator.free(key); + } + entry.value_ptr.* = host_buffer; + return true; + } else if (basicTypeCheck(object, "torch", "Size")) { + const size = object.args[0].seq.values[0].seq.values; + const key = try 
allocator.dupe(u8, prefix.items); + const entry = try store._metadata.getOrPut(allocator, key); + if (entry.found_existing) { + log.warn("Duplicate key: {s}", .{prefix.items}); + allocator.free(key); + } + const d = try allocator.alloc(i64, size.len); + for (d, 0..) |*di, i| di.* = size[i].int64; + entry.value_ptr.* = .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(d) } }; + return true; + } else if (basicTypeCheck(object, "fractions", "Fraction")) { + const fraction_str = object.args[0].seq.values[0].string; + if (std.mem.indexOfScalar(u8, fraction_str, '/')) |split_idx| { + { + var new_prefix = prefix; + new_prefix.appendSliceAssumeCapacity(".numerator"); + try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) }); + } + { + var new_prefix = prefix; + new_prefix.appendSliceAssumeCapacity(".denominator"); + try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) }); + } + return true; + } + } + return false; + }, + else => false, + }; + } + + fn parseTensor(self: *PickleData, tmp_allocator: std.mem.Allocator, object: *value.Object) !?zml.HostBuffer { + if (!basicTypeCheck(object, "torch._utils", "_rebuild_tensor_v2")) { + return null; + } + + const args = object.args[0].seq.values; + if (args.len < 4 or + args[0] != .pers_id or + args[1] != .int64 or + args[2] != .seq or args[2].seq.type != .tuple or + args[3] != .seq or args[3].seq.type != .tuple) + { + log.err("Unexpected value in call to torch._utils._rebuild_tensor_v2", .{}); + return error.InvalidInput; + } + + const pid: *PersId = args[0].pers_id; + var offset: u64 = @intCast(args[1].int64); + const raw_dims: Sequence = args[2].seq; + const raw_strides: Sequence = args[3].seq; + const dims = try parseDims(raw_dims.values); + var strides = try parseDims(raw_strides.values); + + const dtype, const storage_file 
= try parseStorage(pid.ref); + // Pytorch store "item" strides, while ZML uses byte strides. + for (strides.slice()) |*s| s.* *= dtype.sizeOf(); + // Same thing for the offset. + offset = offset * dtype.sizeOf(); + + const filename = try std.mem.join(tmp_allocator, "", &.{ self.data.zip_prefix, "data/", storage_file }); + defer tmp_allocator.free(filename); + + // The offset in the pickle is the offset inside the storage_file. + // But .pt are made of several files, so we need to append the file offset. + const storage = try self.getStorage(filename); + return HostBuffer.fromStridedSlice( + zml.Shape.init(dims.constSlice(), dtype), + storage[offset..], + strides.constSlice(), + ); + } + + fn parseStorage(val: value.Value) !struct { zml.DataType, []const u8 } { + if (val != .seq) return error.InvalidInput; + const sargs = val.seq.values; + if (val.seq.type == .tuple and + sargs.len >= 5 and + sargs[0] == .string and std.mem.eql(u8, sargs[0].string, "storage") and + sargs[1] == .raw and sargs[1].raw == .global and + sargs[2] == .string and + sargs[3] == .string) + { + const op = sargs[1].raw.global; + const storage_file = sargs[2].string; + // const sdev = sargs[3].string; + if (!std.mem.eql(u8, "torch", op.module) or + !std.mem.endsWith(u8, op.class, "Storage")) + return error.InvalidInput; + + return .{ + try storageToDtype(op.class), + storage_file, + }; + } else { + return error.InvalidInput; + } + } + + /// Given the name of one of the files in the .pt tarball, + /// return the slice of the memory-mapped .pt corresponding to it. 
+ fn getStorage(self: *PickleData, filename: []const u8) ![]const u8 { + const maybe_entry = self.data.file_map.get(filename); + if (maybe_entry == null) { + std.log.err("Could not find file ending in `{s}` in archive", .{filename}); + return error.TensorNotFound; + } + const entry = maybe_entry.?; + const base_offset: u64 = if (self.data.tar_file) |t| t.start else 0; + const file_offset: u64 = base_offset + entry.file_offset; + const file = self.data.buffer_file.file; + try file.seekTo(entry.file_offset); + const local_header = try file.reader().readStructEndian(std.zip.LocalFileHeader, .little); + + if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig)) + return error.ZipBadFileOffset; + if (local_header.compressed_size != 0 and + local_header.compressed_size != entry.compressed_size) + return error.ZipMismatchCompLen; + if (local_header.uncompressed_size != 0 and + local_header.uncompressed_size != entry.uncompressed_size) + return error.ZipMismatchUncompLen; + if (local_header.filename_len != entry.filename_len) + return error.ZipMismatchFilenameLen; + + const start = file_offset + + @sizeOf(std.zip.LocalFileHeader) + + @as(u64, local_header.filename_len) + + @as(u64, local_header.extra_len); + return self.data.buffer_file.mappedSlice(start, entry.uncompressed_size); + } + + fn parseDims(values: []Value) error{InvalidInput}!zml.Shape.DimsArray { + zml.meta.assert(values.len <= zml.Tensor.MAX_RANK, "Found Pytorch tensor with unsupported rank {}", .{values.len}); + var result: zml.Shape.DimsArray = .{}; + for (values) |val| { + switch (val) { + .int64 => |d| result.appendAssumeCapacity(d), + else => return error.InvalidInput, + } + } + return result; + } }; -test { - std.testing.refAllDecls(@This()); +/// Convert from a torch.Storage to a `zml.DataType`. 
+/// TODO: make this future proof, storage type are going to get replaced with torch.UntypedStorage +/// See https://pytorch.org/docs/stable/storage.html +fn storageToDtype(storage_type: []const u8) !zml.DataType { + const torch_type = storage_type[0 .. storage_type.len - "Storage".len]; + const map = std.StaticStringMap(zml.DataType).initComptime(.{ + .{ "Double", .f64 }, + .{ "Float", .f32 }, + .{ "Half", .f16 }, + .{ "Long", .i64 }, + .{ "Int", .i32 }, + .{ "Short", .i16 }, + .{ "Char", .i8 }, + .{ "Byte", .u8 }, + .{ "Bool", .bool }, + .{ "BFloat16", .bf16 }, + .{ "ComplexDouble", .c128 }, + .{ "ComplexFloat", .c64 }, + // QUInt8Storage + // QInt8Storage + // QInt32Storage + // QUInt4x2Storage + // QUInt2x4Storage + }); + + return map.get(torch_type) orelse { + log.err("Unsupported torch storage type: {s}", .{storage_type}); + return error.UnsupportedDataType; + }; } diff --git a/zml/aio/torch/eval.zig b/zml/aio/torch/eval.zig index d54f02d..d5070c8 100644 --- a/zml/aio/torch/eval.zig +++ b/zml/aio/torch/eval.zig @@ -3,8 +3,8 @@ const zml = @import("../../zml.zig"); const meta = zml.meta; const value = @import("value.zig"); +const pickle = @import("pickle.zig"); const BTreeMap = @import("b_tree_map.zig").BTreeMap; -const PickleOp = @import("ops.zig").PickleOp; const Build = value.Build; const Object = value.Object; @@ -233,7 +233,7 @@ pub const PickleStack = struct { } }; -pub fn evaluate(allocator: std.mem.Allocator, x: []const PickleOp, resolve_refs: bool) !struct { PickleStack, PickleMemo } { +pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) !struct { PickleStack, PickleMemo } { var stack = InternalStack.init(allocator); defer stack.deinit(); var memo = PickleMemo.init(allocator); diff --git a/zml/aio/torch/parser.zig b/zml/aio/torch/parser.zig index 754b055..95341bd 100644 --- a/zml/aio/torch/parser.zig +++ b/zml/aio/torch/parser.zig @@ -1,23 +1,23 @@ const asynk = @import("async"); const std = @import("std"); -const 
zml = @import("../../zml.zig"); - -const utils = @import("utils.zig"); -const PickleOp = @import("ops.zig").PickleOp; -const RawPickleOp = @import("ops.zig").RawPickleOp; - -const Allocator = std.mem.Allocator; const testing = std.testing; +const Allocator = std.mem.Allocator; + +const zml = @import("../../zml.zig"); +const pickle = @import("pickle.zig"); test { std.testing.refAllDecls(@This()); + std.testing.refAllDecls(Parser); } -pub const Decoder = struct { +pub const Parser = struct { + // TODO: move the file logic to torch.PytorchFile + // the Pickle parser shouldn't have to deal with the zip archive stuff used by Pytorch buffer_file: zml.aio.MemoryMappedFile, file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{}, tar_file: ?TarStream = null, - ops: []PickleOp, + ops: []pickle.Op, is_zip_file: bool, zip_prefix: []const u8 = &[_]u8{}, @@ -53,11 +53,11 @@ pub const Decoder = struct { const magic = "PK\x03\x04"; - pub fn fromTarFile(allocator: Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !Decoder { + pub fn fromTarFile(allocator: Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !Parser { const tar_stream = try TarStream.init(file); const file_magic = try tar_stream.reader().readBytesNoEof(magic.len); try tar_stream.seekTo(0); - var self: Decoder = .{ + var self: Parser = .{ .buffer_file = mapped, .tar_file = tar_stream, .ops = undefined, @@ -72,10 +72,10 @@ pub const Decoder = struct { return self; } - pub fn init(allocator: Allocator, file: asynk.File) !Decoder { + pub fn init(allocator: Allocator, file: asynk.File) !Parser { const file_magic = try file.reader().readBytesNoEof(magic.len); try file.seekTo(0); - var self: Decoder = .{ + var self: Parser = .{ .buffer_file = try zml.aio.MemoryMappedFile.init(file), .is_zip_file = std.mem.eql(u8, &file_magic, magic), .ops = undefined, @@ -89,12 +89,12 @@ pub const Decoder = struct { return self; } - pub fn deinit(self: 
*Decoder) void { + pub fn deinit(self: *Parser) void { self.buffer_file.deinit(); self.* = undefined; } - fn parseOps(self: *Decoder, allocator: Allocator, seekable_stream: anytype) ![]PickleOp { + fn parseOps(self: *Parser, allocator: Allocator, seekable_stream: anytype) ![]pickle.Op { var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream); var filename_buf: [std.fs.max_path_bytes]u8 = undefined; while (try iter.next()) |entry| { @@ -167,12 +167,12 @@ pub const Decoder = struct { return error.PickleNotFound; } - fn parse(allocator: Allocator, reader: anytype, len: usize) ![]PickleOp { - var results = std.ArrayList(PickleOp).init(allocator); + fn parse(allocator: Allocator, reader: anytype, len: usize) ![]pickle.Op { + var results = std.ArrayList(pickle.Op).init(allocator); errdefer results.deinit(); outer: while (true) { const b = try reader.readByte(); - switch (@as(RawPickleOp, @enumFromInt(b))) { + switch (@as(pickle.OpCode, @enumFromInt(b))) { .mark => try results.append(.{ .mark = {} }), .stop => { try results.append(.{ .stop = {} }); @@ -351,7 +351,7 @@ pub const Decoder = struct { }, .next_buffer => try results.append(.{ .next_buffer = {} }), .readonly_buffer => try results.append(.{ .readonly_buffer = {} }), - else => {}, + _ => {}, } } return results.toOwnedSlice(); @@ -411,7 +411,7 @@ test "Read pickle (simple)" { const allocator = arena.allocator(); const eval = @import("eval.zig"); const file = try asynk.File.open("zml/aio/torch/simple_test.pickle", .{ .mode = .read_only }); - var data = try Decoder.init(allocator, file); + var data = try Parser.init(allocator, file); defer data.deinit(); var vals, var memo = try eval.evaluate(allocator, data.ops, true); defer vals.deinit(); @@ -457,11 +457,11 @@ test "Read pickle (zipped)" { defer arena.deinit(); const allocator = arena.allocator(); const file = try asynk.File.open("zml/aio/torch/simple.pt", .{ .mode = .read_only }); - var data = try Decoder.init(allocator, file); + var data = 
try Parser.init(allocator, file); defer data.deinit(); } -pub fn isBadFilename(filename: []const u8) bool { +fn isBadFilename(filename: []const u8) bool { if (filename.len == 0 or filename[0] == '/') return true; diff --git a/zml/aio/torch/ops.zig b/zml/aio/torch/pickle.zig similarity index 96% rename from zml/aio/torch/ops.zig rename to zml/aio/torch/pickle.zig index 1a00f5d..3d77508 100644 --- a/zml/aio/torch/ops.zig +++ b/zml/aio/torch/pickle.zig @@ -1,7 +1,7 @@ const std = @import("std"); /// A decoded Pickle operation in its natural state. -pub const PickleOp = union(RawPickleOp) { +pub const Op = union(OpCode) { mark, stop, pop, @@ -73,7 +73,7 @@ pub const PickleOp = union(RawPickleOp) { pub const PyType = struct { module: []const u8, class: []const u8 }; - pub fn deinit(self: PickleOp, allocator: std.mem.Allocator) void { + pub fn deinit(self: Op, allocator: std.mem.Allocator) void { switch (self) { .float, .int, @@ -103,7 +103,7 @@ pub const PickleOp = union(RawPickleOp) { } } - pub fn clone(self: PickleOp, allocator: std.mem.Allocator) !PickleOp { + pub fn clone(self: Op, allocator: std.mem.Allocator) !Op { var res = self; return switch (self) { inline .float, @@ -144,7 +144,8 @@ pub const PickleOp = union(RawPickleOp) { }; /// The values for the possible opcodes are in this enum. 
-pub const RawPickleOp = enum(u8) { +/// Reference: https://github.com/python/cpython/blob/3.13/Lib/pickletools.py +pub const OpCode = enum(u8) { mark = '(', // push special markobject on stack stop = '.', // every pickle ends with stop pop = '0', // discard topmost stack item diff --git a/zml/aio/torch/utils.zig b/zml/aio/torch/utils.zig deleted file mode 100644 index 2b573dd..0000000 --- a/zml/aio/torch/utils.zig +++ /dev/null @@ -1,23 +0,0 @@ -const std = @import("std"); - -const Value = @import("value.zig").Value; - -pub fn allTrue(values: []const Value, func: fn (v: Value) bool) bool { - for (values) |v| { - if (!func(v)) return false; - } - return true; -} - -pub fn isBadFilename(filename: []const u8) bool { - if (filename.len == 0 or filename[0] == '/') - return true; - - var it = std.mem.splitScalar(u8, filename, '/'); - while (it.next()) |part| { - if (std.mem.eql(u8, part, "..")) - return true; - } - - return false; -} diff --git a/zml/aio/torch/value.zig b/zml/aio/torch/value.zig index 9c6acc4..b9cb4bb 100644 --- a/zml/aio/torch/value.zig +++ b/zml/aio/torch/value.zig @@ -1,10 +1,8 @@ const std = @import("std"); -const utils = @import("utils.zig"); - -const PickleOp = @import("ops.zig").PickleOp; - const big_int = std.math.big.int; +const pickle = @import("pickle.zig"); + /// The types of sequences that exist. pub const SequenceType = enum { list, @@ -114,7 +112,7 @@ pub const ValueType = enum { /// A processed value. pub const Value = union(ValueType) { /// Types that we can't handle or just had to give up on processing. - raw: PickleOp, + raw: pickle.Op, /// A reference. You might be able to look it up in the memo map /// unless there's something weird going on like recursive references. @@ -174,7 +172,7 @@ pub const Value = union(ValueType) { float64: f64, /// Some kind of weird number we can't handle. - raw_num: PickleOp, + raw_num: pickle.Op, /// A boolean value. 
boolval: bool, @@ -299,7 +297,12 @@ pub const Value = union(ValueType) { pub fn isPrimitive(self: Value) bool { return switch (self) { .int64, .bigint, .float64, .string, .bytes, .boolval, .none => true, - .seq => |seq| utils.allTrue(seq.values, Value.isPrimitive), + .seq => |seq| { + for (seq.values) |v| { + if (!v.isPrimitive()) return false; + } + return true; + }, else => false, }; } diff --git a/zml/aio/utils.zig b/zml/aio/utils.zig deleted file mode 100644 index 0bf1043..0000000 --- a/zml/aio/utils.zig +++ /dev/null @@ -1,7 +0,0 @@ -pub fn toVoidSlice(data: anytype) []void { - const info = @typeInfo(@TypeOf(data)); - if (info != .Pointer or info.Pointer.size != .Slice) { - @compileError("toVoidSlice expects a slice"); - } - return @as([*]void, @ptrCast(@alignCast(data.ptr)))[0..data.len]; -} diff --git a/zml/aio/yaml.zig b/zml/aio/yaml.zig index 8ceffe0..dc1e553 100644 --- a/zml/aio/yaml.zig +++ b/zml/aio/yaml.zig @@ -1,5 +1,4 @@ const std = @import("std"); -const utils = @import("utils.zig"); const yaml = @import("zig-yaml"); const zml = @import("../zml.zig");