From 897786e4408afbc26c534612aa8c3a2d4d846bd2 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 31 Jan 2023 11:58:58 +0000 Subject: [PATCH] aio: correct refAllDecls handling for yaml and nemo modules --- zml/aio.zig | 6 ++---- zml/aio/nemo.zig | 2 +- zml/aio/torch.zig | 2 +- zml/aio/torch/parser.zig | 40 +++++++++++++++++++++++++++++++++++----- zml/aio/yaml.zig | 8 ++------ 5 files changed, 41 insertions(+), 17 deletions(-) diff --git a/zml/aio.zig b/zml/aio.zig index ce1c47b..1d185f9 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -21,14 +21,12 @@ const HostBuffer = @import("hostbuffer.zig").HostBuffer; test { std.testing.refAllDecls(@This()); std.testing.refAllDecls(gguf); - // TODO(@cryptodeal) - // std.testing.refAllDecls(nemo); + std.testing.refAllDecls(nemo); std.testing.refAllDecls(safetensors); std.testing.refAllDecls(sentencepiece); std.testing.refAllDecls(tinyllama); std.testing.refAllDecls(torch); - // TODO(@cryptodeal) - // std.testing.refAllDecls(yaml); + std.testing.refAllDecls(yaml); } /// Detects the format of the model file (base on filename) and open it. diff --git a/zml/aio/nemo.zig b/zml/aio/nemo.zig index 7fb5f30..d52854e 100644 --- a/zml/aio/nemo.zig +++ b/zml/aio/nemo.zig @@ -34,7 +34,7 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore const parsed = try yaml.Yaml.load(arena, yaml_data); var prefix_buf: [1024]u8 = undefined; - try zml.aio.yaml.parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), parsed.docs.items[0].map); + try zml.aio.yaml.parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), parsed.docs.items[0]); } else if (std.mem.endsWith(u8, file.name, ".ckpt") or std.mem.endsWith(u8, file.name, ".pt")) { const start = try mapped_file.file.getPos(); var tmp: zml.aio.torch.PickleData = .{ diff --git a/zml/aio/torch.zig b/zml/aio/torch.zig index 20434d4..b5e2404 100644 --- a/zml/aio/torch.zig +++ b/zml/aio/torch.zig @@ -148,7 +148,7 @@ pub const PickleData = struct { if (local_header.last_modification_date != entry.last_modification_date) return error.ZipMismatchModDate; - if (@as(u16, @bitCast(local_header.flags)) != @as(u16, @bitCast(entry.flags))) + if (@as(u16, @bitCast(local_header.flags)) != entry.flags) return error.ZipMismatchFlags; if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32) return error.ZipMismatchCrc32; diff --git a/zml/aio/torch/parser.zig b/zml/aio/torch/parser.zig index e7a574e..9d3feec 100644 --- a/zml/aio/torch/parser.zig +++ b/zml/aio/torch/parser.zig @@ -15,12 +15,42 @@ test { pub const Decoder = struct { buffer_file: zml.aio.MemoryMappedFile, - file_map: std.StringArrayHashMapUnmanaged(std.zip.Iterator(asynk.File.SeekableStream).Entry) = .{}, + file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{}, tar_file: ?TarStream = null, ops: []PickleOp, is_zip_file: bool, zip_prefix: []const u8 = &[_]u8{}, + pub const FileEntry = struct { + version_needed_to_extract: u16, + flags: u16, + compression_method: std.zip.CompressionMethod, + last_modification_time: u16, + last_modification_date: u16, + header_zip_offset: u64, + crc32: u32, + filename_len: u32, + compressed_size: u64, + uncompressed_size: u64, + file_offset: u64, + + pub fn init(entry: anytype) FileEntry { + return .{ + .version_needed_to_extract = entry.version_needed_to_extract, + .flags = @as(u16, @bitCast(entry.flags)), + .compression_method = entry.compression_method, + .last_modification_time = entry.last_modification_time, + .last_modification_date = entry.last_modification_date, + .header_zip_offset = entry.header_zip_offset, + .crc32 = entry.crc32, + .filename_len = entry.filename_len, + .compressed_size = entry.compressed_size, + .uncompressed_size = entry.uncompressed_size, + .file_offset = entry.file_offset, + }; + } + }; + const magic = "PK\x03\x04"; pub fn fromTarFile(allocator: Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !Decoder { @@ -64,12 +94,12 @@ pub const Decoder = struct { self.* = undefined; } - fn parseOps(self: *Decoder, allocator: Allocator, seekable_stream: asynk.File.SeekableStream) ![]PickleOp { + fn parseOps(self: *Decoder, allocator: Allocator, seekable_stream: anytype) ![]PickleOp { // TODO(SuperAuguste): deflate using `std.compress.flate`'s `decompressor` // TODO(SuperAuguste): explore swapping in non-generic reader here instead of using switch(?) // not sure if that'd actually be beneficial in any way - var iter = try std.zip.Iterator(asynk.File.SeekableStream).init(seekable_stream); + var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream); var filename_buf: [std.fs.max_path_bytes]u8 = undefined; while (try iter.next()) |entry| { const filename = filename_buf[0..entry.filename_len]; @@ -78,7 +108,7 @@ pub const Decoder = struct { if (len != filename.len) return error.ZipBadFileOffset; if (isBadFilename(filename)) return error.ZipBadFilename; std.mem.replaceScalar(u8, filename, '\\', '/'); // normalize path separators - try self.file_map.put(allocator, try allocator.dupe(u8, filename), entry); + try self.file_map.put(allocator, try allocator.dupe(u8, filename), FileEntry.init(entry)); } var file_iter = self.file_map.iterator(); @@ -101,7 +131,7 @@ pub const Decoder = struct { if (local_header.last_modification_date != entry.last_modification_date) return error.ZipMismatchModDate; - if (@as(u16, @bitCast(local_header.flags)) != @as(u16, @bitCast(entry.flags))) + if (@as(u16, @bitCast(local_header.flags)) != entry.flags) return error.ZipMismatchFlags; if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32) return error.ZipMismatchCrc32; diff --git a/zml/aio/yaml.zig b/zml/aio/yaml.zig index 8200144..8ceffe0 100644 --- a/zml/aio/yaml.zig +++ b/zml/aio/yaml.zig @@ -19,12 +19,8 @@ pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore { const yaml_data = try file.reader().readAllAlloc(arena, (try file.metadata()).size()); const parsed = try yaml.Yaml.load(arena, yaml_data); - const map = parsed.docs.items[0].map; - var map_iter = map.iterator(); - while (map_iter.next()) |entry| { - var prefix_buf: [1024]u8 = undefined; - try parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), entry.key, entry.value); - } + var prefix_buf: [1024]u8 = undefined; + try parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), parsed.docs.items[0]); return res; }