aio: correct refAllDecls handling for yaml and nemo modules
This commit is contained in:
parent
7dcd8b516c
commit
897786e440
@ -21,14 +21,12 @@ const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
|||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
std.testing.refAllDecls(gguf);
|
std.testing.refAllDecls(gguf);
|
||||||
// TODO(@cryptodeal)
|
std.testing.refAllDecls(nemo);
|
||||||
// std.testing.refAllDecls(nemo);
|
|
||||||
std.testing.refAllDecls(safetensors);
|
std.testing.refAllDecls(safetensors);
|
||||||
std.testing.refAllDecls(sentencepiece);
|
std.testing.refAllDecls(sentencepiece);
|
||||||
std.testing.refAllDecls(tinyllama);
|
std.testing.refAllDecls(tinyllama);
|
||||||
std.testing.refAllDecls(torch);
|
std.testing.refAllDecls(torch);
|
||||||
// TODO(@cryptodeal)
|
std.testing.refAllDecls(yaml);
|
||||||
// std.testing.refAllDecls(yaml);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Detects the format of the model file (based on its filename) and opens it.
|
/// Detects the format of the model file (based on its filename) and opens it.
|
||||||
|
|||||||
@ -34,7 +34,7 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
|
|||||||
const parsed = try yaml.Yaml.load(arena, yaml_data);
|
const parsed = try yaml.Yaml.load(arena, yaml_data);
|
||||||
|
|
||||||
var prefix_buf: [1024]u8 = undefined;
|
var prefix_buf: [1024]u8 = undefined;
|
||||||
try zml.aio.yaml.parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), parsed.docs.items[0].map);
|
try zml.aio.yaml.parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), parsed.docs.items[0]);
|
||||||
} else if (std.mem.endsWith(u8, file.name, ".ckpt") or std.mem.endsWith(u8, file.name, ".pt")) {
|
} else if (std.mem.endsWith(u8, file.name, ".ckpt") or std.mem.endsWith(u8, file.name, ".pt")) {
|
||||||
const start = try mapped_file.file.getPos();
|
const start = try mapped_file.file.getPos();
|
||||||
var tmp: zml.aio.torch.PickleData = .{
|
var tmp: zml.aio.torch.PickleData = .{
|
||||||
|
|||||||
@ -148,7 +148,7 @@ pub const PickleData = struct {
|
|||||||
if (local_header.last_modification_date != entry.last_modification_date)
|
if (local_header.last_modification_date != entry.last_modification_date)
|
||||||
return error.ZipMismatchModDate;
|
return error.ZipMismatchModDate;
|
||||||
|
|
||||||
if (@as(u16, @bitCast(local_header.flags)) != @as(u16, @bitCast(entry.flags)))
|
if (@as(u16, @bitCast(local_header.flags)) != entry.flags)
|
||||||
return error.ZipMismatchFlags;
|
return error.ZipMismatchFlags;
|
||||||
if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32)
|
if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32)
|
||||||
return error.ZipMismatchCrc32;
|
return error.ZipMismatchCrc32;
|
||||||
|
|||||||
@ -15,12 +15,42 @@ test {
|
|||||||
|
|
||||||
pub const Decoder = struct {
|
pub const Decoder = struct {
|
||||||
buffer_file: zml.aio.MemoryMappedFile,
|
buffer_file: zml.aio.MemoryMappedFile,
|
||||||
file_map: std.StringArrayHashMapUnmanaged(std.zip.Iterator(asynk.File.SeekableStream).Entry) = .{},
|
file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{},
|
||||||
tar_file: ?TarStream = null,
|
tar_file: ?TarStream = null,
|
||||||
ops: []PickleOp,
|
ops: []PickleOp,
|
||||||
is_zip_file: bool,
|
is_zip_file: bool,
|
||||||
zip_prefix: []const u8 = &[_]u8{},
|
zip_prefix: []const u8 = &[_]u8{},
|
||||||
|
|
||||||
|
pub const FileEntry = struct {
|
||||||
|
version_needed_to_extract: u16,
|
||||||
|
flags: u16,
|
||||||
|
compression_method: std.zip.CompressionMethod,
|
||||||
|
last_modification_time: u16,
|
||||||
|
last_modification_date: u16,
|
||||||
|
header_zip_offset: u64,
|
||||||
|
crc32: u32,
|
||||||
|
filename_len: u32,
|
||||||
|
compressed_size: u64,
|
||||||
|
uncompressed_size: u64,
|
||||||
|
file_offset: u64,
|
||||||
|
|
||||||
|
pub fn init(entry: anytype) FileEntry {
|
||||||
|
return .{
|
||||||
|
.version_needed_to_extract = entry.version_needed_to_extract,
|
||||||
|
.flags = @as(u16, @bitCast(entry.flags)),
|
||||||
|
.compression_method = entry.compression_method,
|
||||||
|
.last_modification_time = entry.last_modification_time,
|
||||||
|
.last_modification_date = entry.last_modification_date,
|
||||||
|
.header_zip_offset = entry.header_zip_offset,
|
||||||
|
.crc32 = entry.crc32,
|
||||||
|
.filename_len = entry.filename_len,
|
||||||
|
.compressed_size = entry.compressed_size,
|
||||||
|
.uncompressed_size = entry.uncompressed_size,
|
||||||
|
.file_offset = entry.file_offset,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const magic = "PK\x03\x04";
|
const magic = "PK\x03\x04";
|
||||||
|
|
||||||
pub fn fromTarFile(allocator: Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !Decoder {
|
pub fn fromTarFile(allocator: Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !Decoder {
|
||||||
@ -64,12 +94,12 @@ pub const Decoder = struct {
|
|||||||
self.* = undefined;
|
self.* = undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parseOps(self: *Decoder, allocator: Allocator, seekable_stream: asynk.File.SeekableStream) ![]PickleOp {
|
fn parseOps(self: *Decoder, allocator: Allocator, seekable_stream: anytype) ![]PickleOp {
|
||||||
// TODO(SuperAuguste): deflate using `std.compress.flate`'s `decompressor`
|
// TODO(SuperAuguste): deflate using `std.compress.flate`'s `decompressor`
|
||||||
// TODO(SuperAuguste): explore swapping in non-generic reader here instead of using switch(?)
|
// TODO(SuperAuguste): explore swapping in non-generic reader here instead of using switch(?)
|
||||||
// not sure if that'd actually be beneficial in any way
|
// not sure if that'd actually be beneficial in any way
|
||||||
|
|
||||||
var iter = try std.zip.Iterator(asynk.File.SeekableStream).init(seekable_stream);
|
var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream);
|
||||||
var filename_buf: [std.fs.max_path_bytes]u8 = undefined;
|
var filename_buf: [std.fs.max_path_bytes]u8 = undefined;
|
||||||
while (try iter.next()) |entry| {
|
while (try iter.next()) |entry| {
|
||||||
const filename = filename_buf[0..entry.filename_len];
|
const filename = filename_buf[0..entry.filename_len];
|
||||||
@ -78,7 +108,7 @@ pub const Decoder = struct {
|
|||||||
if (len != filename.len) return error.ZipBadFileOffset;
|
if (len != filename.len) return error.ZipBadFileOffset;
|
||||||
if (isBadFilename(filename)) return error.ZipBadFilename;
|
if (isBadFilename(filename)) return error.ZipBadFilename;
|
||||||
std.mem.replaceScalar(u8, filename, '\\', '/'); // normalize path separators
|
std.mem.replaceScalar(u8, filename, '\\', '/'); // normalize path separators
|
||||||
try self.file_map.put(allocator, try allocator.dupe(u8, filename), entry);
|
try self.file_map.put(allocator, try allocator.dupe(u8, filename), FileEntry.init(entry));
|
||||||
}
|
}
|
||||||
|
|
||||||
var file_iter = self.file_map.iterator();
|
var file_iter = self.file_map.iterator();
|
||||||
@ -101,7 +131,7 @@ pub const Decoder = struct {
|
|||||||
if (local_header.last_modification_date != entry.last_modification_date)
|
if (local_header.last_modification_date != entry.last_modification_date)
|
||||||
return error.ZipMismatchModDate;
|
return error.ZipMismatchModDate;
|
||||||
|
|
||||||
if (@as(u16, @bitCast(local_header.flags)) != @as(u16, @bitCast(entry.flags)))
|
if (@as(u16, @bitCast(local_header.flags)) != entry.flags)
|
||||||
return error.ZipMismatchFlags;
|
return error.ZipMismatchFlags;
|
||||||
if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32)
|
if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32)
|
||||||
return error.ZipMismatchCrc32;
|
return error.ZipMismatchCrc32;
|
||||||
|
|||||||
@ -19,12 +19,8 @@ pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
|
|||||||
const yaml_data = try file.reader().readAllAlloc(arena, (try file.metadata()).size());
|
const yaml_data = try file.reader().readAllAlloc(arena, (try file.metadata()).size());
|
||||||
const parsed = try yaml.Yaml.load(arena, yaml_data);
|
const parsed = try yaml.Yaml.load(arena, yaml_data);
|
||||||
|
|
||||||
const map = parsed.docs.items[0].map;
|
var prefix_buf: [1024]u8 = undefined;
|
||||||
var map_iter = map.iterator();
|
try parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), parsed.docs.items[0]);
|
||||||
while (map_iter.next()) |entry| {
|
|
||||||
var prefix_buf: [1024]u8 = undefined;
|
|
||||||
try parseMetadata(arena, &res, StringBuilder.initBuffer(&prefix_buf), entry.key, entry.value);
|
|
||||||
}
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user