aio: correct refAllDecls handling for yaml and nemo modules

This commit is contained in:
Tarry Singh 2023-01-31 11:58:58 +00:00
parent 7dcd8b516c
commit 897786e440
5 changed files with 41 additions and 17 deletions

View File

@ -21,14 +21,12 @@ const HostBuffer = @import("hostbuffer.zig").HostBuffer;
test {
std.testing.refAllDecls(@This());
std.testing.refAllDecls(gguf);
// TODO(@cryptodeal)
// std.testing.refAllDecls(nemo);
std.testing.refAllDecls(nemo);
std.testing.refAllDecls(safetensors);
std.testing.refAllDecls(sentencepiece);
std.testing.refAllDecls(tinyllama);
std.testing.refAllDecls(torch);
// TODO(@cryptodeal)
// std.testing.refAllDecls(yaml);
std.testing.refAllDecls(yaml);
}
/// Detects the format of the model file (base on filename) and open it.

View File

@ -34,7 +34,7 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
const parsed = try yaml.Yaml.load(arena, yaml_data);
var prefix_buf: [1024]u8 = undefined;
try zml.aio.yaml.parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), parsed.docs.items[0].map);
try zml.aio.yaml.parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), parsed.docs.items[0]);
} else if (std.mem.endsWith(u8, file.name, ".ckpt") or std.mem.endsWith(u8, file.name, ".pt")) {
const start = try mapped_file.file.getPos();
var tmp: zml.aio.torch.PickleData = .{

View File

@ -148,7 +148,7 @@ pub const PickleData = struct {
if (local_header.last_modification_date != entry.last_modification_date)
return error.ZipMismatchModDate;
if (@as(u16, @bitCast(local_header.flags)) != @as(u16, @bitCast(entry.flags)))
if (@as(u16, @bitCast(local_header.flags)) != entry.flags)
return error.ZipMismatchFlags;
if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32)
return error.ZipMismatchCrc32;

View File

@ -15,12 +15,42 @@ test {
pub const Decoder = struct {
buffer_file: zml.aio.MemoryMappedFile,
file_map: std.StringArrayHashMapUnmanaged(std.zip.Iterator(asynk.File.SeekableStream).Entry) = .{},
file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{},
tar_file: ?TarStream = null,
ops: []PickleOp,
is_zip_file: bool,
zip_prefix: []const u8 = &[_]u8{},
pub const FileEntry = struct {
version_needed_to_extract: u16,
flags: u16,
compression_method: std.zip.CompressionMethod,
last_modification_time: u16,
last_modification_date: u16,
header_zip_offset: u64,
crc32: u32,
filename_len: u32,
compressed_size: u64,
uncompressed_size: u64,
file_offset: u64,
pub fn init(entry: anytype) FileEntry {
return .{
.version_needed_to_extract = entry.version_needed_to_extract,
.flags = @as(u16, @bitCast(entry.flags)),
.compression_method = entry.compression_method,
.last_modification_time = entry.last_modification_time,
.last_modification_date = entry.last_modification_date,
.header_zip_offset = entry.header_zip_offset,
.crc32 = entry.crc32,
.filename_len = entry.filename_len,
.compressed_size = entry.compressed_size,
.uncompressed_size = entry.uncompressed_size,
.file_offset = entry.file_offset,
};
}
};
const magic = "PK\x03\x04";
pub fn fromTarFile(allocator: Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !Decoder {
@ -64,12 +94,12 @@ pub const Decoder = struct {
self.* = undefined;
}
fn parseOps(self: *Decoder, allocator: Allocator, seekable_stream: asynk.File.SeekableStream) ![]PickleOp {
fn parseOps(self: *Decoder, allocator: Allocator, seekable_stream: anytype) ![]PickleOp {
// TODO(SuperAuguste): deflate using `std.compress.flate`'s `decompressor`
// TODO(SuperAuguste): explore swapping in non-generic reader here instead of using switch(?)
// not sure if that'd actually be beneficial in any way
var iter = try std.zip.Iterator(asynk.File.SeekableStream).init(seekable_stream);
var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream);
var filename_buf: [std.fs.max_path_bytes]u8 = undefined;
while (try iter.next()) |entry| {
const filename = filename_buf[0..entry.filename_len];
@ -78,7 +108,7 @@ pub const Decoder = struct {
if (len != filename.len) return error.ZipBadFileOffset;
if (isBadFilename(filename)) return error.ZipBadFilename;
std.mem.replaceScalar(u8, filename, '\\', '/'); // normalize path separators
try self.file_map.put(allocator, try allocator.dupe(u8, filename), entry);
try self.file_map.put(allocator, try allocator.dupe(u8, filename), FileEntry.init(entry));
}
var file_iter = self.file_map.iterator();
@ -101,7 +131,7 @@ pub const Decoder = struct {
if (local_header.last_modification_date != entry.last_modification_date)
return error.ZipMismatchModDate;
if (@as(u16, @bitCast(local_header.flags)) != @as(u16, @bitCast(entry.flags)))
if (@as(u16, @bitCast(local_header.flags)) != entry.flags)
return error.ZipMismatchFlags;
if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32)
return error.ZipMismatchCrc32;

View File

@ -19,12 +19,8 @@ pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
const yaml_data = try file.reader().readAllAlloc(arena, (try file.metadata()).size());
const parsed = try yaml.Yaml.load(arena, yaml_data);
const map = parsed.docs.items[0].map;
var map_iter = map.iterator();
while (map_iter.next()) |entry| {
var prefix_buf: [1024]u8 = undefined;
try parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), entry.key, entry.value);
}
var prefix_buf: [1024]u8 = undefined;
try parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), parsed.docs.items[0]);
return res;
}