const asynk = @import("async");
const std = @import("std");
const zml = @import("../zml.zig");
const HostBuffer = @import("../hostbuffer.zig").HostBuffer;
const toVoidSlice = @import("utils.zig").toVoidSlice;
const eval = @import("torch/eval.zig");
const utils = @import("torch/utils.zig");
const value = @import("torch/value.zig");
const Decoder = @import("torch/parser.zig").Decoder;
const PersId = value.PersId;
const PickleMemo = eval.PickleMemo;
const PickleStack = eval.PickleStack;
const Sequence = value.Sequence;
const Value = value.Value;
const ValueType = value.ValueType;

const StringBuilder = std.ArrayListUnmanaged(u8);
const Allocator = std.mem.Allocator;
const log = std.log.scoped(.zml_io);

/// Element type names accepted by `dtypeFromStr`.
/// Several names are aliases that map to the same `zml.DataType`.
const TorchType = enum {
    float64,
    double,
    float32,
    float,
    float16,
    half,
    bfloat16,
    int64,
    long,
    int32,
    int,
    int16,
    short,
    int8,
    char,
    uint8,
    byte,
};

/// Maps an element type name (one of the `TorchType` field names) to the
/// matching `zml.DataType`.
/// Returns `error.UnknownTensorType` when `str` is not a `TorchType` name.
fn dtypeFromStr(str: []const u8) !zml.DataType {
    if (std.meta.stringToEnum(TorchType, str)) |torch_type| {
        return switch (torch_type) {
            .double, .float64 => .f64,
            .float, .float32 => .f32,
            .half, .float16 => .f16,
            .bfloat16 => .bf16,
            .long, .int64 => .i64,
            .int, .int32 => .i32,
            .short, .int16 => .i16,
            .char, .int8 => .i8,
            .byte, .uint8 => .u8,
        };
    }
    return error.UnknownTensorType;
}

/// Opens and loads a BufferStore from the torch file at the given path.
///
/// Every allocation made while loading comes from the arena stored in the
/// returned store, so the caller releases everything through that store.
/// If loading fails, both the arena and the opened file are released before
/// the error is returned.
pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
    const file = asynk.File.open(path, .{}) catch |err| {
        log.err("Failed to open {s}: {}", .{ path, err });
        return err;
    };
    errdefer file.close() catch unreachable;

    var res: zml.aio.BufferStore = .{
        .arena = std.heap.ArenaAllocator.init(allocator),
    };
    // Without this, any failure below would leak everything allocated so far.
    errdefer res.arena.deinit();
    const arena = res.arena.allocator();

    var tmp: PickleData = .{
        .data = try Decoder.init(arena, file),
        .memo = undefined,
        .stack = undefined,
    };
    tmp.stack, tmp.memo = try eval.evaluate(arena, tmp.data.ops, true);
    res.files = try arena.dupe(zml.aio.MemoryMappedFile, &.{tmp.data.buffer_file});
    try tmp.parseModel(arena, &res);
    return res;
}

/// Decoded pickle data together with the stack and memo obtained by
/// evaluating its ops, plus the logic that turns those values into
/// `zml.aio.BufferStore` entries.
pub const PickleData = struct {
    stack: PickleStack,
    memo: PickleMemo,
    data: Decoder,

    /// Returns true when `v` is a `.global` whose member is the raw global
    /// `ns`.`name` and whose first argument is a sequence.
    /// Any other shape (including an empty argument list or a raw member that
    /// is not a global) yields false.
    fn basicTypeCheck(v: Value, ns: []const u8, name: []const u8) bool {
        return switch (v) {
            .global => |object| switch (object.member) {
                // Guard the union variant and the argument count before
                // touching `raw.global` / `object.args[0]`.
                .raw => |raw| raw == .global and
                    object.args.len > 0 and
                    object.args[0] == .seq and
                    std.mem.eql(u8, ns, raw.global[0]) and
                    std.mem.eql(u8, name, raw.global[1]),
                else => false,
            },
            else => false,
        };
    }

    /// Returns true when `v` is a call to `torch._utils._rebuild_tensor_v2`
    /// with the argument layout expected by `parseTorchGlobal`.
    /// Panics when the callee matches but the arguments do not.
    fn isTensor(v: Value) bool {
        if (basicTypeCheck(v, "torch._utils", "_rebuild_tensor_v2")) {
            const args = v.global.args[0].seq[1];
            if (args.len >= 5 and
                args[0] == .pers_id and
                args[1] == .int and
                args[2] == .seq and args[2].seq[0] == .tuple and
                args[3] == .seq and args[3].seq[0] == .tuple)
            {
                return true;
            } else @panic("Unexpected value in call to torch._utils._rebuild_tensor_v2");
        }
        return false;
    }

    /// Copies the `.int` payload of each value into a fixed-size array.
    /// Only the first `values.len` entries are initialized; the rest stay
    /// undefined. Panics if any value is not an `.int`.
    fn dimsFromValues(values: []Value) [zml.Tensor.MAX_RANK]i64 {
        std.debug.assert(values.len <= zml.Tensor.MAX_RANK);
        var result: [zml.Tensor.MAX_RANK]i64 = undefined;
        for (values, result[0..values.len]) |val, *elem| {
            switch (val) {
                .int => |int| elem.* = int,
                else => @panic("Bad value for shape item"),
            }
        }
        return result;
    }

    /// Walks every item of the evaluated stack and records what it finds in
    /// `store`. Each item starts from an empty key prefix backed by a
    /// 1024-byte stack buffer.
    pub fn parseModel(self: *PickleData, allocator: Allocator, store: *zml.aio.BufferStore) !void {
        for (self.stack.stack) |item| {
            var prefix_buf: [1024]u8 = undefined;
            try self.parseValue(allocator, store, StringBuilder.initBuffer(&prefix_buf), item);
        }
    }

    /// Looks up `sfile` in the decoder's `file_map`, reads the zip local file
    /// header at the recorded offset and checks it against the recorded
    /// entry. Returns the stream position just past the header, file name and
    /// extra field.
    ///
    /// Returns `error.TensorNotFound` when `sfile` is not in the map, and one
    /// of the `error.Zip*` errors when the local header disagrees with the
    /// recorded entry.
    fn tensorOffset(self: *PickleData, seekable_stream: anytype, sfile: []const u8) !u64 {
        const entry = self.data.file_map.get(sfile) orelse {
            log.err("Could not find file ending in `{s}` in archive", .{sfile});
            return error.TensorNotFound;
        };

        const local_header = blk: {
            try seekable_stream.seekTo(entry.file_offset);
            break :blk try seekable_stream.context.reader().readStructEndian(std.zip.LocalFileHeader, .little);
        };

        if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig))
            return error.ZipBadFileOffset;
        if (local_header.version_needed_to_extract != entry.version_needed_to_extract)
            return error.ZipMismatchVersionNeeded;
        if (local_header.last_modification_time != entry.last_modification_time)
            return error.ZipMismatchModTime;
        if (local_header.last_modification_date != entry.last_modification_date)
            return error.ZipMismatchModDate;
        if (@as(u16, @bitCast(local_header.flags)) != @as(u16, @bitCast(entry.flags)))
            return error.ZipMismatchFlags;
        // A zero crc/size in the local header is not treated as a mismatch.
        if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32)
            return error.ZipMismatchCrc32;
        if (local_header.compressed_size != 0 and
            local_header.compressed_size != entry.compressed_size)
            return error.ZipMismatchCompLen;
        if (local_header.uncompressed_size != 0 and
            local_header.uncompressed_size != entry.uncompressed_size)
            return error.ZipMismatchUncompLen;
        if (local_header.filename_len != entry.filename_len)
            return error.ZipMismatchFilenameLen;

        return (try seekable_stream.context.getPos()) +
            @as(u64, local_header.filename_len) +
            @as(u64, local_header.extra_len);
    }

    /// Handles the globals this loader knows about and stores the result
    /// under the key `prefix`:
    ///   - `torch._utils._rebuild_tensor_v2` -> a `HostBuffer` in `store.buffers`
    ///   - `torch.Size`                      -> an int64 array in `store._metadata`
    ///   - `fractions.Fraction`              -> `<prefix>.numerator` / `<prefix>.denominator`
    ///
    /// Returns true when `v` was consumed, false when the caller should keep
    /// walking it generically.
    fn parseTorchGlobal(self: *PickleData, allocator: Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: Value) !bool {
        return switch (v) {
            .global => |object| {
                if (isTensor(v)) {
                    const args = object.args[0].seq[1];
                    const pidval: *PersId, var offs: u64, const raw_shape: Sequence, const raw_strides: Sequence = .{ args[0].pers_id, @intCast(args[1].int), args[2].seq, args[3].seq };
                    const rank = raw_shape[1].len;
                    const shape = dimsFromValues(raw_shape[1]);
                    var strides = dimsFromValues(raw_strides[1]);
                    // The persistent id is expected to be a tuple
                    // ("storage", torch.<Type>Storage, <file>, <device>, ...).
                    const stype: []const u8, const sfile: []const u8, const sdev: []const u8 = switch (pidval.ref) {
                        .seq => |seq| blk: {
                            const sargs = seq[1];
                            if (seq[0] == .tuple and
                                sargs.len >= 5 and
                                sargs[0] == .string and std.mem.eql(u8, sargs[0].string, "storage") and
                                sargs[1] == .raw and sargs[1].raw == .global and
                                sargs[2] == .string and
                                sargs[3] == .string)
                            {
                                const op = sargs[1].raw.global;
                                const sfile = sargs[2].string;
                                const sdev = sargs[3].string;
                                const styp = op[1];
                                if (std.mem.eql(u8, "torch", op[0]) and std.mem.endsWith(u8, styp, "Storage")) {
                                    // Strip the "Storage" suffix (7 bytes) and lowercase in place.
                                    break :blk .{ std.ascii.lowerString(styp[0 .. styp.len - 7], styp[0 .. styp.len - 7]), sfile, sdev };
                                } else @panic("Unexpected storage type part of persistant ID");
                            } else @panic("Unexpected value for persistant ID");
                        },
                        else => @panic("Unexpected value for persistant ID"),
                    };
                    _ = sdev;
                    const data_type = try dtypeFromStr(stype);
                    // Scale strides and offset by the element size.
                    for (strides[0..rank]) |*s| s.* *= data_type.sizeOf();

                    var sfile_buf = std.ArrayList(u8).init(allocator);
                    defer sfile_buf.deinit();
                    try sfile_buf.writer().print("{s}data/{s}", .{ self.data.zip_prefix, sfile });

                    // find offsets for tensor zip file
                    const absolute_offset = blk: {
                        if (self.data.tar_file) |t| {
                            break :blk try self.tensorOffset(t.seekableStream(), sfile_buf.items);
                        } else {
                            break :blk try self.tensorOffset(self.data.buffer_file.file.seekableStream(), sfile_buf.items);
                        }
                    };
                    offs = offs * data_type.sizeOf();

                    const key = try allocator.dupe(u8, prefix.items);
                    const entry = try store.buffers.getOrPut(allocator, key);
                    if (entry.found_existing) {
                        // The map keeps its existing key; the value below still overwrites.
                        log.warn("Duplicate key: {s}", .{prefix.items});
                        allocator.free(key);
                    }
                    const out_shape = zml.Shape.init(shape[0..rank], data_type);
                    entry.value_ptr.* = HostBuffer.fromStridedSlice(
                        out_shape,
                        self.data.buffer_file.mappedSlice((if (self.data.tar_file) |t| t.start else 0) + absolute_offset + offs, out_shape.byteSize()),
                        strides[0..rank],
                    );
                    return true;
                } else if (basicTypeCheck(v, "torch", "Size")) {
                    const size = object.args[0].seq[1][0].seq[1];
                    const key = try allocator.dupe(u8, prefix.items);
                    const entry = try store._metadata.getOrPut(allocator, key);
                    if (entry.found_existing) {
                        log.warn("Duplicate key: {s}", .{prefix.items});
                        allocator.free(key);
                    }
                    const d = try allocator.alloc(i64, size.len);
                    for (d, 0..) |*di, i| di.* = size[i].int;
                    entry.value_ptr.* = .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(d) } };
                    return true;
                } else if (basicTypeCheck(v, "fractions", "Fraction")) {
                    const fraction_str = object.args[0].seq[1][0].string;
                    // Only "<numerator>/<denominator>" strings are handled here.
                    if (std.mem.indexOfScalar(u8, fraction_str, '/')) |split_idx| {
                        {
                            var new_prefix = prefix;
                            new_prefix.appendSliceAssumeCapacity(".numerator");
                            try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) });
                        }
                        {
                            var new_prefix = prefix;
                            new_prefix.appendSliceAssumeCapacity(".denominator");
                            try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) });
                        }
                        return true;
                    }
                }
                return false;
            },
            else => false,
        };
    }

    /// Recursively walks `v`, building dotted keys in `prefix` and recording
    /// tensors in `store.buffers` and scalar/array values in `store._metadata`.
    ///
    /// `prefix` is passed by value: each recursion level extends its own copy
    /// of the length while sharing the caller's backing buffer. Appends use
    /// the `AssumeCapacity` variants, so keys must fit in that buffer.
    pub fn parseValue(self: *PickleData, allocator: Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: Value) !void {
        switch (v) {
            .app, .object, .global => |object| {
                if (!(try self.parseTorchGlobal(allocator, store, prefix, v))) {
                    try self.parseValue(allocator, store, prefix, object.member);
                    for (object.args) |item| {
                        // if possible, coerce to `kv_tuple` (only if key val doesn't match root of prefix)
                        if (item == .seq and item.seq[0] == .tuple and item.seq[1].len == 2 and item.seq[1][0] == .string) {
                            try self.parseValue(allocator, store, prefix, .{ .seq = .{ .kv_tuple, item.seq[1] } });
                        } else try self.parseValue(allocator, store, prefix, item);
                    }
                }
            },
            .build => |build| {
                // `build` contains info about python struct being constructed
                switch (build.member) {
                    .object => |obj| switch (obj.member) {
                        .raw => |raw| switch (raw) {
                            .global => |global| {
                                // in this case, we can capture the name of the python type
                                // which can be used for codegen (e.g. `torch.nn.modules.conv.Conv2d`)
                                var new_prefix = prefix;
                                if (prefix.items.len > 0) {
                                    new_prefix.appendAssumeCapacity('.');
                                }
                                new_prefix.appendSliceAssumeCapacity("_gen_type_helper");
                                const key = try allocator.dupe(u8, new_prefix.items);
                                const d = try store._metadata.getOrPut(allocator, key);
                                if (d.found_existing) {
                                    log.err("Duplicate key: {s}", .{new_prefix.items});
                                    allocator.free(key);
                                } else {
                                    // Store "<module>.<name>" as the type helper string.
                                    const val = try allocator.alloc(u8, global[0].len + 1 + global[1].len);
                                    @memcpy(val[0..global[0].len], global[0]);
                                    val[global[0].len] = '.';
                                    @memcpy(val[global[0].len + 1 ..], global[1]);
                                    d.value_ptr.* = .{ .string = val };
                                }
                            },
                            else => try self.parseValue(allocator, store, prefix, build.member), // parse normally
                        },
                        else => try self.parseValue(allocator, store, prefix, build.member), // parse normally
                    },
                    else => try self.parseValue(allocator, store, prefix, build.member), // parse normally
                }
                try self.parseValue(allocator, store, prefix, build.args);
            },
            .pers_id => |pers_id| try self.parseValue(allocator, store, prefix, pers_id.ref),
            .seq => |*seq| switch (seq[0]) {
                .list, .tuple, .set, .frozen_set => {
                    // Builds a predicate "value has tag T" for `utils.allTrue`.
                    const elemCheck = struct {
                        fn call(comptime T: ValueType) fn (v: Value) bool {
                            return struct {
                                fn call(val: Value) bool {
                                    return val == T;
                                }
                            }.call;
                        }
                    }.call;

                    // A non-empty sequence whose elements are all int, all
                    // bool or all float is stored as a single packed array.
                    if (seq[1].len > 0 and switch (seq[1][0]) {
                        inline .int, .bool, .float => |_, tag| utils.allTrue(seq[1][1..], elemCheck(tag)),
                        else => false,
                    }) {
                        const out: []u8 = switch (seq[1][0]) {
                            .int => blk: {
                                const d = try allocator.alloc(i64, seq[1].len);
                                for (seq[1], 0..) |item, i| {
                                    d[i] = item.int;
                                }
                                break :blk std.mem.sliceAsBytes(d);
                            },
                            .float => blk: {
                                const d = try allocator.alloc(f64, seq[1].len);
                                for (seq[1], 0..) |item, i| {
                                    d[i] = item.float;
                                }
                                break :blk std.mem.sliceAsBytes(d);
                            },
                            else => blk: {
                                const d = try allocator.alloc(bool, seq[1].len);
                                for (seq[1], 0..) |item, i| {
                                    d[i] = item.bool;
                                }
                                break :blk std.mem.sliceAsBytes(d);
                            },
                        };
                        const key = try allocator.dupe(u8, prefix.items);
                        const d = try store._metadata.getOrPut(allocator, key);
                        if (d.found_existing) {
                            log.warn("Duplicate key: {s}", .{prefix.items});
                            allocator.free(key);
                            allocator.free(out);
                        } else d.value_ptr.* = @unionInit(zml.aio.Value, "array", .{ .item_type = switch (seq[1][0]) {
                            .int => .int64,
                            .float => .float64,
                            .string => .string,
                            else => .boolval,
                        }, .data = out });
                    } else {
                        for (seq[1], 0..) |item, i| {
                            var new_prefix = prefix;
                            if (v.isPrimitive()) {
                                if (prefix.items.len > 0) {
                                    new_prefix.appendAssumeCapacity('.');
                                }
                                // Append the element index as a decimal key component.
                                new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
                            }
                            try self.parseValue(allocator, store, new_prefix, item);
                        }
                    }
                },
                .dict => {
                    for (seq[1]) |item| {
                        try self.parseValue(allocator, store, prefix, item);
                    }
                },
                .kv_tuple => {
                    const key = seq[1][0];
                    const val = seq[1][1];
                    switch (key) {
                        .string => |s| {
                            // These container names are skipped so they do not
                            // show up as a key component.
                            if (std.mem.eql(u8, s, "_modules") or
                                std.mem.eql(u8, s, "_parameters") or
                                std.mem.eql(u8, s, "_buffers"))
                            {
                                try self.parseValue(allocator, store, prefix, val);
                            } else {
                                var new_prefix = prefix;
                                if (prefix.items.len > 0) {
                                    new_prefix.appendAssumeCapacity('.');
                                }
                                new_prefix.appendSliceAssumeCapacity(s);
                                try self.parseValue(allocator, store, new_prefix, val);
                            }
                        },
                        .int => |int| {
                            var new_prefix = prefix;
                            if (prefix.items.len > 0) {
                                new_prefix.appendAssumeCapacity('.');
                            }
                            new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{});
                            try self.parseValue(allocator, store, new_prefix, val);
                        },
                        inline else => |_, tag| std.debug.panic("Unexpected key type: {s}", .{@tagName(tag)}),
                    }
                },
            },
            .bytes => |val| {
                const key = try allocator.dupe(u8, prefix.items);
                const d = try store._metadata.getOrPut(allocator, key);
                if (d.found_existing) {
                    log.warn("Duplicate key: {s}", .{prefix.items});
                    allocator.free(key);
                } else d.value_ptr.* = .{ .array = .{ .item_type = .uint8, .data = @constCast(val) } };
            },
            inline .float, .int, .bool, .bigint, .string => |val, tag| {
                const key = try allocator.dupe(u8, prefix.items);
                const d = try store._metadata.getOrPut(allocator, key);
                if (d.found_existing) {
                    log.warn("Duplicate key: {s}", .{prefix.items});
                    allocator.free(key);
                } else d.value_ptr.* = @unionInit(zml.aio.Value, switch (tag) {
                    .int => "int64",
                    .float => "float64",
                    .bool => "boolval",
                    else => @tagName(tag),
                }, val);
            },
            else => {},
        }
    }
};

test {
    std.testing.refAllDecls(@This());
}