//! Parser for PyTorch `.pt` checkpoint files: detects whether the input is a
//! zip archive (containing a pickled `data.pkl`) or a bare/tar-wrapped pickle
//! stream, and parses the pickle opcodes out of it.
const asynk = @import("async");
const std = @import("std");
const testing = std.testing;
const Allocator = std.mem.Allocator;
const zml = @import("../../zml.zig");
const pickle = @import("pickle.zig");
// Reference all declarations (including Parser's nested ones) so their
// `test` blocks are discovered and run by `zig test`.
test {
std.testing.refAllDecls(@This());
std.testing.refAllDecls(Parser);
}
pub const Parser = struct {
    // TODO: move the file logic to torch.PytorchFile
    // the Pickle parser shouldn't have to deal with the zip archive stuff used by Pytorch

    /// Memory-mapped view of the checkpoint file; owned by the Parser and
    /// released in `deinit`.
    buffer_file: zml.aio.MemoryMappedFile,
    /// Zip central-directory entries, keyed by normalized ('/'-separated) filename.
    /// NOTE(review): keys are duped with the init allocator but never freed in
    /// `deinit`; callers appear to rely on an arena (see tests) — confirm.
    file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{},
    /// Set when the pickle lives inside a tar archive (see `fromTarFile`).
    tar_file: ?TarStream = null,
    /// Parsed pickle opcodes of `data.pkl` (zip case) or of the raw stream.
    ops: []const pickle.Op,
    is_zip_file: bool,
    /// Path prefix inside the zip archive leading up to `data.pkl`
    /// (e.g. "archive/"); empty for non-zip inputs.
    zip_prefix: []const u8 = &[_]u8{},

    /// Subset of `std.zip` central-directory information kept per archive member.
    pub const FileEntry = struct {
        version_needed_to_extract: u16,
        flags: u16,
        compression_method: std.zip.CompressionMethod,
        last_modification_time: u16,
        last_modification_date: u16,
        header_zip_offset: u64,
        crc32: u32,
        filename_len: u32,
        compressed_size: u64,
        uncompressed_size: u64,
        file_offset: u64,

        /// Copies the relevant fields out of a `std.zip.Iterator` entry.
        pub fn init(entry: anytype) FileEntry {
            return .{
                .version_needed_to_extract = entry.version_needed_to_extract,
                .flags = @as(u16, @bitCast(entry.flags)),
                .compression_method = entry.compression_method,
                .last_modification_time = entry.last_modification_time,
                .last_modification_date = entry.last_modification_date,
                .header_zip_offset = entry.header_zip_offset,
                .crc32 = entry.crc32,
                .filename_len = entry.filename_len,
                .compressed_size = entry.compressed_size,
                .uncompressed_size = entry.uncompressed_size,
                .file_offset = entry.file_offset,
            };
        }
    };

    /// Zip local-file-header signature; used to sniff the container format.
    const magic = "PK\x03\x04";

    /// Builds a Parser from a file embedded in a tar archive.
    /// Ownership of `mapped` transfers to the returned Parser (released by `deinit`).
    /// NOTE(review): on error, `mapped` is neither deinited here nor returned —
    /// confirm callers clean it up on failure.
    pub fn fromTarFile(allocator: Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !Parser {
        const tar_stream = try TarStream.init(file);
        const file_magic = try tar_stream.reader().readBytesNoEof(magic.len);
        try tar_stream.seekTo(0);
        var self: Parser = .{
            .buffer_file = mapped,
            .tar_file = tar_stream,
            .ops = undefined,
            .is_zip_file = std.mem.eql(u8, &file_magic, magic),
        };
        if (!self.is_zip_file) {
            // Bare pickle stream: parse it directly out of the tar entry.
            const reader = tar_stream.reader();
            self.ops = try pickle.parse(allocator, reader, try tar_stream.getEndPos());
        } else {
            self.ops = try self.parseOps(allocator, self.tar_file.?.seekableStream());
        }
        return self;
    }

    /// Opens `file`, auto-detects whether it is a zip archive (modern `.pt`)
    /// or a bare pickle stream (legacy format), and parses the pickle ops.
    pub fn init(allocator: Allocator, file: asynk.File) !Parser {
        const file_magic = try file.reader().readBytesNoEof(magic.len);
        try file.seekTo(0);
        var self: Parser = .{
            .buffer_file = try zml.aio.MemoryMappedFile.init(file),
            .is_zip_file = std.mem.eql(u8, &file_magic, magic),
            .ops = undefined,
        };
        // Bug fix: release the mapping if parsing fails below; it leaked before.
        errdefer self.buffer_file.deinit();
        if (!self.is_zip_file) {
            const reader = self.buffer_file.file.reader();
            self.ops = try pickle.parse(allocator, reader, try reader.context.getEndPos());
        } else {
            self.ops = try self.parseOps(allocator, self.buffer_file.file.seekableStream());
        }
        return self;
    }

    pub fn deinit(self: *Parser) void {
        // NOTE(review): `ops` and `file_map` (including duped keys) are not
        // freed here — they belong to the allocator passed to init; an arena
        // is expected (see tests). Confirm before using a general allocator.
        self.buffer_file.deinit();
        self.* = undefined;
    }

    /// Walks the zip central directory, records every member in `file_map`,
    /// then locates the member whose name ends in `data.pkl`, validates its
    /// local header against the central directory (mirroring `std.zip`), and
    /// parses its pickle ops. Only `store` (uncompressed) members are
    /// supported for now; compressed members panic with a TODO.
    fn parseOps(self: *Parser, allocator: Allocator, seekable_stream: anytype) ![]const pickle.Op {
        var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream);
        var filename_buf: [std.fs.max_path_bytes]u8 = undefined;
        while (try iter.next()) |entry| {
            // Bug fix: reject names longer than our buffer instead of slicing
            // out of bounds on a corrupt or hostile archive.
            if (entry.filename_len > filename_buf.len) return error.ZipBadFilename;
            const filename = filename_buf[0..entry.filename_len];
            try seekable_stream.seekTo(entry.header_zip_offset + @sizeOf(std.zip.CentralDirectoryFileHeader));
            const len = try seekable_stream.context.reader().readAll(filename);
            if (len != filename.len) return error.ZipBadFileOffset;
            if (isBadFilename(filename)) return error.ZipBadFilename;
            std.mem.replaceScalar(u8, filename, '\\', '/'); // normalize path separators
            try self.file_map.put(allocator, try allocator.dupe(u8, filename), FileEntry.init(entry));
        }
        var file_iter = self.file_map.iterator();
        while (file_iter.next()) |e| {
            const entry = e.value_ptr.*;
            const filename = e.key_ptr.*;
            if (std.mem.indexOf(u8, filename, "data.pkl")) |idx| {
                self.zip_prefix = filename[0..idx];
                // Cross-check the local file header against the central
                // directory entry; any mismatch means a corrupt archive.
                const local_data_header_offset: u64 = local_data_header_offset: {
                    const local_header = blk: {
                        try seekable_stream.seekTo(entry.file_offset);
                        break :blk try seekable_stream.context.reader().readStructEndian(std.zip.LocalFileHeader, .little);
                    };
                    if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig))
                        return error.ZipBadFileOffset;
                    if (local_header.version_needed_to_extract != entry.version_needed_to_extract)
                        return error.ZipMismatchVersionNeeded;
                    if (local_header.last_modification_time != entry.last_modification_time)
                        return error.ZipMismatchModTime;
                    if (local_header.last_modification_date != entry.last_modification_date)
                        return error.ZipMismatchModDate;
                    if (@as(u16, @bitCast(local_header.flags)) != entry.flags)
                        return error.ZipMismatchFlags;
                    // A zero in the local header means "see the data descriptor /
                    // central directory", so only a non-zero mismatch is an error.
                    if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32)
                        return error.ZipMismatchCrc32;
                    if (local_header.compressed_size != 0 and
                        local_header.compressed_size != entry.compressed_size)
                        return error.ZipMismatchCompLen;
                    if (local_header.uncompressed_size != 0 and
                        local_header.uncompressed_size != entry.uncompressed_size)
                        return error.ZipMismatchUncompLen;
                    if (local_header.filename_len != entry.filename_len)
                        return error.ZipMismatchFilenameLen;
                    break :local_data_header_offset @as(u64, local_header.filename_len) +
                        @as(u64, local_header.extra_len);
                };
                // Absolute offset of the member's data: fixed header + variable part.
                const local_data_file_offset: u64 =
                    @as(u64, entry.file_offset) +
                    @as(u64, @sizeOf(std.zip.LocalFileHeader)) +
                    local_data_header_offset;
                try seekable_stream.seekTo(local_data_file_offset);
                switch (entry.compression_method) {
                    .store => {
                        return pickle.parse(allocator, seekable_stream.context.reader(), entry.uncompressed_size);
                    },
                    .deflate => {
                        // TODO(cryptodeal): handle decompress
                        @panic("TODO support use of `deflate`");
                    },
                    else => @panic("TODO support other modes of compression"),
                }
            }
        }
        std.log.err("Could not find file ending in `data.pkl` in archive", .{});
        return error.PickleNotFound;
    }
};
/// Seekable view of a single tar entry: translates entry-relative offsets
/// into absolute offsets on the underlying file.
const TarStream = struct {
    pub const SeekableStream = std.io.SeekableStream(
        TarStream,
        asynk.File.SeekError,
        asynk.File.GetSeekPosError,
        TarStream.seekTo,
        TarStream.seekBy,
        TarStream.getPos,
        TarStream.getEndPos,
    );

    /// The tar entry being read.
    file: std.tar.Iterator(asynk.File.Reader).File,
    /// Absolute position of the entry's first byte in the underlying file.
    start: usize,

    pub fn init(file: std.tar.Iterator(asynk.File.Reader).File) !TarStream {
        const begin = try file.parent_reader.context.getPos();
        return .{ .file = file, .start = begin };
    }

    pub fn reader(self: TarStream) std.tar.Iterator(asynk.File.Reader).File.Reader {
        return self.file.reader();
    }

    pub fn seekTo(self: TarStream, offset: u64) !void {
        // Entry-relative offset -> absolute offset in the backing file.
        return self.file.parent_reader.context.seekTo(self.start + offset);
    }

    pub fn seekBy(self: TarStream, offset: i64) !void {
        return self.file.parent_reader.context.seekBy(offset);
    }

    pub fn getPos(self: TarStream) !u64 {
        const absolute = try self.file.parent_reader.context.getPos();
        return absolute - self.start;
    }

    pub fn getEndPos(self: TarStream) !u64 {
        // The entry's declared size is the logical end of this stream.
        return self.file.size;
    }

    pub fn seekableStream(self: TarStream) TarStream.SeekableStream {
        return .{ .context = self };
    }
};
test "Read pickle (zipped)" {
var arena = std.heap.ArenaAllocator.init(testing.allocator);
defer arena.deinit();
const allocator = arena.allocator();
const file = try asynk.File.open("zml/aio/torch/simple.pt", .{ .mode = .read_only });
var data = try Parser.init(allocator, file);
defer data.deinit();
}
/// Rejects archive member names that could escape the extraction root:
/// empty names, absolute paths, and any `..` path component.
fn isBadFilename(filename: []const u8) bool {
    if (filename.len == 0) return true;
    if (filename[0] == '/') return true;
    var components = std.mem.splitScalar(u8, filename, '/');
    while (components.next()) |component| {
        if (std.mem.eql(u8, component, "..")) return true;
    }
    return false;
}