bump runtimes/* code to Zig 0.15.1, restore PyTorch loader using std.fs.File, update CI zig fmt, remove stdx.io, note remaining issues with Neuron and CUDA debug builds
This commit is contained in:
parent
0ed7f5c907
commit
9e3cd6d616
@ -605,7 +605,7 @@ pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
|
|||||||
|
|
||||||
pub fn items(self: Attr) []const dt.ZigType() {
|
pub fn items(self: Attr) []const dt.ZigType() {
|
||||||
const raw_bytes: [*]const u8 = c.mlirDenseElementsAttrGetRawData(self._inner) orelse unreachable;
|
const raw_bytes: [*]const u8 = c.mlirDenseElementsAttrGetRawData(self._inner) orelse unreachable;
|
||||||
const ptr: [*]const dt.ZigType() = @alignCast(@ptrCast(raw_bytes));
|
const ptr: [*]const dt.ZigType() = @ptrCast(@alignCast(raw_bytes));
|
||||||
// Note the mlir API returns us the number of elements, not the number of bytes,
|
// Note the mlir API returns us the number of elements, not the number of bytes,
|
||||||
// that's why we track the element type at comptime to allow items to work.
|
// that's why we track the element type at comptime to allow items to work.
|
||||||
return ptr[0..self.len()];
|
return ptr[0..self.len()];
|
||||||
@ -1743,7 +1743,7 @@ pub const helpers = struct {
|
|||||||
writer: *std.Io.Writer,
|
writer: *std.Io.Writer,
|
||||||
err: ?std.Io.Writer.Error = null,
|
err: ?std.Io.Writer.Error = null,
|
||||||
fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.c) void {
|
fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.c) void {
|
||||||
var ctx: *@This() = @alignCast(@ptrCast(opaque_ctx));
|
var ctx: *@This() = @ptrCast(@alignCast(opaque_ctx));
|
||||||
if (ctx.err) |_| return;
|
if (ctx.err) |_| return;
|
||||||
_ = ctx.writer.write(mlir_str.data[0..mlir_str.length]) catch |err| {
|
_ = ctx.writer.write(mlir_str.data[0..mlir_str.length]) catch |err| {
|
||||||
ctx.err = err;
|
ctx.err = err;
|
||||||
|
|||||||
@ -359,7 +359,7 @@ pub const Attrs = extern struct {
|
|||||||
value: *const anyopaque,
|
value: *const anyopaque,
|
||||||
|
|
||||||
pub fn get(self: Scalar, T: type) T {
|
pub fn get(self: Scalar, T: type) T {
|
||||||
const ptr: *const T = @alignCast(@ptrCast(self.value));
|
const ptr: *const T = @ptrCast(@alignCast(self.value));
|
||||||
return ptr.*;
|
return ptr.*;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -370,13 +370,13 @@ pub const Attrs = extern struct {
|
|||||||
data: [*]const u8,
|
data: [*]const u8,
|
||||||
|
|
||||||
pub fn slice(self: Array, T: type) []const T {
|
pub fn slice(self: Array, T: type) []const T {
|
||||||
const ptr: [*]const T = @alignCast(@ptrCast(self.data));
|
const ptr: [*]const T = @ptrCast(@alignCast(self.data));
|
||||||
return ptr[0..self.len];
|
return ptr[0..self.len];
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn slice(self: Array, T: type) []const T {
|
pub fn slice(self: Array, T: type) []const T {
|
||||||
const ptr: [*]const T = @alignCast(@ptrCast(self.data));
|
const ptr: [*]const T = @ptrCast(@alignCast(self.data));
|
||||||
return ptr[0..self.len];
|
return ptr[0..self.len];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -58,7 +58,7 @@ pub const ApiError = error{
|
|||||||
fn InnerMixin(comptime innerT: type) type {
|
fn InnerMixin(comptime innerT: type) type {
|
||||||
return struct {
|
return struct {
|
||||||
fn inner(self: anytype) *innerT {
|
fn inner(self: anytype) *innerT {
|
||||||
return @ptrCast(@constCast(@alignCast(self)));
|
return @ptrCast(@alignCast(@constCast(self)));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -125,10 +125,10 @@ pub const Api = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn lookupExtension(self: *const Api, comptime ExtensionT: type, ext_id: c_int) ?*const ExtensionT {
|
pub fn lookupExtension(self: *const Api, comptime ExtensionT: type, ext_id: c_int) ?*const ExtensionT {
|
||||||
var cur: [*c]const c.PJRT_Extension_Base = @alignCast(@ptrCast(self.inner.extension_start));
|
var cur: [*c]const c.PJRT_Extension_Base = @ptrCast(@alignCast(self.inner.extension_start));
|
||||||
while (cur != null) : (cur = cur.*.next) {
|
while (cur != null) : (cur = cur.*.next) {
|
||||||
if (cur.*.type == ext_id) {
|
if (cur.*.type == ext_id) {
|
||||||
return @alignCast(@ptrCast(cur));
|
return @ptrCast(@alignCast(cur));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -432,7 +432,7 @@ pub const Client = opaque {
|
|||||||
.client = self.inner(),
|
.client = self.inner(),
|
||||||
}) catch unreachable;
|
}) catch unreachable;
|
||||||
if (ret.addressable_memories) |memories| {
|
if (ret.addressable_memories) |memories| {
|
||||||
return @constCast(@ptrCast(memories[0..ret.num_addressable_memories]));
|
return @ptrCast(@constCast(memories[0..ret.num_addressable_memories]));
|
||||||
}
|
}
|
||||||
return &.{};
|
return &.{};
|
||||||
}
|
}
|
||||||
|
|||||||
@ -28,7 +28,7 @@ fn hasCudaPathInLDPath() bool {
|
|||||||
|
|
||||||
fn setupXlaGpuCudaDirFlag(allocator: std.mem.Allocator, sandbox: []const u8) !void {
|
fn setupXlaGpuCudaDirFlag(allocator: std.mem.Allocator, sandbox: []const u8) !void {
|
||||||
const xla_flags = std.process.getEnvVarOwned(allocator, "XLA_FLAGS") catch "";
|
const xla_flags = std.process.getEnvVarOwned(allocator, "XLA_FLAGS") catch "";
|
||||||
const new_xla_flagsZ = try std.fmt.allocPrintZ(allocator, "{s} --xla_gpu_cuda_data_dir={s}", .{ xla_flags, sandbox });
|
const new_xla_flagsZ = try std.fmt.allocPrintSentinel(allocator, "{s} --xla_gpu_cuda_data_dir={s}", .{ xla_flags, sandbox }, 0);
|
||||||
|
|
||||||
_ = c.setenv("XLA_FLAGS", new_xla_flagsZ, 1);
|
_ = c.setenv("XLA_FLAGS", new_xla_flagsZ, 1);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -38,7 +38,7 @@ var module_def: c.PyModuleDef = .{
|
|||||||
.{},
|
.{},
|
||||||
}),
|
}),
|
||||||
.m_slots = @constCast(&[_]c.PyModuleDef_Slot{
|
.m_slots = @constCast(&[_]c.PyModuleDef_Slot{
|
||||||
.{ .slot = c.Py_mod_exec, .value = @constCast(@ptrCast(&module_exec)) },
|
.{ .slot = c.Py_mod_exec, .value = @ptrCast(@constCast(&module_exec)) },
|
||||||
.{},
|
.{},
|
||||||
}),
|
}),
|
||||||
.m_traverse = null,
|
.m_traverse = null,
|
||||||
|
|||||||
@ -25,10 +25,10 @@ fn isRunningOnEC2() !bool {
|
|||||||
var f = try asynk.File.open("/sys/devices/virtual/dmi/id/sys_vendor", .{ .mode = .read_only });
|
var f = try asynk.File.open("/sys/devices/virtual/dmi/id/sys_vendor", .{ .mode = .read_only });
|
||||||
defer f.close() catch {};
|
defer f.close() catch {};
|
||||||
|
|
||||||
var buf: [AmazonEC2.len]u8 = undefined;
|
var content: [AmazonEC2.len]u8 = undefined;
|
||||||
_ = try f.reader().readAll(&buf);
|
const n_read = try f.pread(&content, 0);
|
||||||
|
|
||||||
return std.mem.eql(u8, &buf, AmazonEC2);
|
return std.mem.eql(u8, content[0..n_read], AmazonEC2);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load() !*const pjrt.Api {
|
pub fn load() !*const pjrt.Api {
|
||||||
@ -45,7 +45,7 @@ pub fn load() !*const pjrt.Api {
|
|||||||
return error.Unavailable;
|
return error.Unavailable;
|
||||||
}
|
}
|
||||||
|
|
||||||
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
|
var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator);
|
||||||
defer arena.deinit();
|
defer arena.deinit();
|
||||||
|
|
||||||
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
|
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
|
||||||
|
|||||||
@ -37,7 +37,7 @@ pub fn load() !*const pjrt.Api {
|
|||||||
return error.Unavailable;
|
return error.Unavailable;
|
||||||
}
|
}
|
||||||
|
|
||||||
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
|
var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator);
|
||||||
defer arena.deinit();
|
defer arena.deinit();
|
||||||
|
|
||||||
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
|
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
|
||||||
|
|||||||
@ -1,12 +1,12 @@
|
|||||||
const builtin = @import("builtin");
|
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
const builtin = @import("builtin");
|
||||||
|
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
const pjrt = @import("pjrt");
|
|
||||||
const c = @import("c");
|
|
||||||
const stdx = @import("stdx");
|
|
||||||
const bazel_builtin = @import("bazel_builtin");
|
const bazel_builtin = @import("bazel_builtin");
|
||||||
|
const c = @import("c");
|
||||||
|
const pjrt = @import("pjrt");
|
||||||
const runfiles = @import("runfiles");
|
const runfiles = @import("runfiles");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const log = std.log.scoped(.@"zml/runtime/tpu");
|
const log = std.log.scoped(.@"zml/runtime/tpu");
|
||||||
|
|
||||||
@ -25,10 +25,10 @@ fn isOnGCP() !bool {
|
|||||||
var f = try asynk.File.open("/sys/devices/virtual/dmi/id/product_name", .{ .mode = .read_only });
|
var f = try asynk.File.open("/sys/devices/virtual/dmi/id/product_name", .{ .mode = .read_only });
|
||||||
defer f.close() catch {};
|
defer f.close() catch {};
|
||||||
|
|
||||||
var buf = [_]u8{0} ** GoogleComputeEngine.len;
|
var content: [GoogleComputeEngine.len]u8 = undefined;
|
||||||
_ = try f.reader().readAll(&buf);
|
const n_read = try f.pread(&content, 0);
|
||||||
|
|
||||||
return std.mem.eql(u8, &buf, GoogleComputeEngine);
|
return std.mem.eql(u8, content[0..n_read], GoogleComputeEngine);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load() !*const pjrt.Api {
|
pub fn load() !*const pjrt.Api {
|
||||||
@ -42,7 +42,7 @@ pub fn load() !*const pjrt.Api {
|
|||||||
return error.Unavailable;
|
return error.Unavailable;
|
||||||
}
|
}
|
||||||
|
|
||||||
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
|
var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator);
|
||||||
defer arena.deinit();
|
defer arena.deinit();
|
||||||
|
|
||||||
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
|
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
|
||||||
|
|||||||
@ -8,7 +8,6 @@ zig_library(
|
|||||||
"flags.zig",
|
"flags.zig",
|
||||||
"fmt.zig",
|
"fmt.zig",
|
||||||
"fs.zig",
|
"fs.zig",
|
||||||
"io.zig",
|
|
||||||
"json.zig",
|
"json.zig",
|
||||||
"math.zig",
|
"math.zig",
|
||||||
"meta.zig",
|
"meta.zig",
|
||||||
|
|||||||
@ -1,4 +0,0 @@
|
|||||||
const std = @import("std");
|
|
||||||
|
|
||||||
pub const BufferedAnyWriter = std.io.BufferedWriter(4096, std.io.AnyWriter);
|
|
||||||
pub const BufferedAnyReader = std.io.BufferedReader(4096, std.io.AnyReader);
|
|
||||||
@ -4,7 +4,6 @@ pub const debug = @import("debug.zig");
|
|||||||
pub const flags = @import("flags.zig");
|
pub const flags = @import("flags.zig");
|
||||||
pub const fmt = @import("fmt.zig");
|
pub const fmt = @import("fmt.zig");
|
||||||
pub const fs = @import("fs.zig");
|
pub const fs = @import("fs.zig");
|
||||||
pub const io = @import("io.zig");
|
|
||||||
pub const json = @import("json.zig");
|
pub const json = @import("json.zig");
|
||||||
pub const math = @import("math.zig");
|
pub const math = @import("math.zig");
|
||||||
pub const meta = @import("meta.zig");
|
pub const meta = @import("meta.zig");
|
||||||
|
|||||||
@ -95,7 +95,7 @@ pub fn serialize(ptr: anytype, arena: *c.upb_Arena) SerializeError![]const u8 {
|
|||||||
|
|
||||||
pub fn parseEx(comptime UpbType: type, arena: *c.upb_Arena, data: []const u8, opts: ParseOptions) ParseError!*UpbType {
|
pub fn parseEx(comptime UpbType: type, arena: *c.upb_Arena, data: []const u8, opts: ParseOptions) ParseError!*UpbType {
|
||||||
const obj = try new(UpbType, arena);
|
const obj = try new(UpbType, arena);
|
||||||
return switch (c.upb_Decode(@ptrCast(@constCast(data)), data.len, @alignCast(@ptrCast(obj)), Minitable(UpbType), null, @bitCast(opts), arena)) {
|
return switch (c.upb_Decode(@ptrCast(@constCast(data)), data.len, @ptrCast(@alignCast(obj)), Minitable(UpbType), null, @bitCast(opts), arena)) {
|
||||||
c.kUpb_DecodeStatus_Ok => obj,
|
c.kUpb_DecodeStatus_Ok => obj,
|
||||||
c.kUpb_DecodeStatus_Malformed => ParseError.Malformed,
|
c.kUpb_DecodeStatus_Malformed => ParseError.Malformed,
|
||||||
c.kUpb_DecodeStatus_OutOfMemory => std.mem.Allocator.Error.OutOfMemory,
|
c.kUpb_DecodeStatus_OutOfMemory => std.mem.Allocator.Error.OutOfMemory,
|
||||||
|
|||||||
@ -24,6 +24,11 @@ zig_library(
|
|||||||
"aio/json.zig",
|
"aio/json.zig",
|
||||||
"aio/safetensors.zig",
|
"aio/safetensors.zig",
|
||||||
"aio/tinyllama.zig",
|
"aio/tinyllama.zig",
|
||||||
|
"aio/torch.zig",
|
||||||
|
"aio/torch/eval.zig",
|
||||||
|
"aio/torch/file.zig",
|
||||||
|
"aio/torch/pickle.zig",
|
||||||
|
"aio/torch/py.zig",
|
||||||
"buffer.zig",
|
"buffer.zig",
|
||||||
"context.zig",
|
"context.zig",
|
||||||
"dtype.zig",
|
"dtype.zig",
|
||||||
@ -72,6 +77,10 @@ zig_library(
|
|||||||
|
|
||||||
zig_test(
|
zig_test(
|
||||||
name = "test",
|
name = "test",
|
||||||
|
data = [
|
||||||
|
"aio/torch/simple.pt",
|
||||||
|
"aio/torch/simple_test_4.pickle",
|
||||||
|
],
|
||||||
test_runner = ":test_runner",
|
test_runner = ":test_runner",
|
||||||
deps = [":zml"],
|
deps = [":zml"],
|
||||||
)
|
)
|
||||||
|
|||||||
12
zml/aio.zig
12
zml/aio.zig
@ -5,16 +5,16 @@ const c = @import("c");
|
|||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
pub const safetensors = @import("aio/safetensors.zig");
|
pub const safetensors = @import("aio/safetensors.zig");
|
||||||
pub const tinyllama = @import("aio/tinyllama.zig");
|
pub const torch = @import("aio/torch.zig");
|
||||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||||
const posix = @import("posix.zig");
|
const posix = @import("posix.zig");
|
||||||
const zml = @import("zml.zig");
|
const zml = @import("zml.zig");
|
||||||
|
|
||||||
pub const log = std.log.scoped(.@"zml/aio");
|
pub const log = std.log.scoped(.@"zml/aio");
|
||||||
|
|
||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
std.testing.refAllDecls(safetensors);
|
std.testing.refAllDecls(safetensors);
|
||||||
|
std.testing.refAllDecls(torch);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO error set for weight loading
|
// TODO error set for weight loading
|
||||||
@ -25,6 +25,12 @@ pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8)
|
|||||||
try safetensors.open(allocator, model_path)
|
try safetensors.open(allocator, model_path)
|
||||||
else if (std.mem.endsWith(u8, model_path, ".safetensors.index.json"))
|
else if (std.mem.endsWith(u8, model_path, ".safetensors.index.json"))
|
||||||
try safetensors.open(allocator, model_path)
|
try safetensors.open(allocator, model_path)
|
||||||
|
else if (std.mem.endsWith(u8, model_path, ".pt"))
|
||||||
|
try torch.open(allocator, model_path)
|
||||||
|
// else if (std.mem.endsWith(u8, model_path, ".gguf"))
|
||||||
|
// try gguf.open(allocator, model_path)
|
||||||
|
// else if (std.mem.endsWith(u8, model_path, ".tinyllama"))
|
||||||
|
// try tinyllama.open(allocator, model_path)
|
||||||
else {
|
else {
|
||||||
std.debug.panic("File extension not recognized: {s}", .{model_path});
|
std.debug.panic("File extension not recognized: {s}", .{model_path});
|
||||||
};
|
};
|
||||||
@ -384,7 +390,7 @@ fn _populateStruct(
|
|||||||
partial_struct = partial_struct or field_found;
|
partial_struct = partial_struct or field_found;
|
||||||
if (!field_found) {
|
if (!field_found) {
|
||||||
if (field.default_value_ptr) |v| {
|
if (field.default_value_ptr) |v| {
|
||||||
@field(obj, field.name) = @as(*const field.type, @alignCast(@ptrCast(v))).*;
|
@field(obj, field.name) = @as(*const field.type, @ptrCast(@alignCast(v))).*;
|
||||||
} else {
|
} else {
|
||||||
if (partial_struct) {
|
if (partial_struct) {
|
||||||
log.warn("Incomplete metadata '{0s}': {1s}. Missing field: '{2s}'. '{0s}' will be ignored.", .{ prefix, @typeName(T), field.name });
|
log.warn("Incomplete metadata '{0s}': {1s}. Missing field: '{2s}'. '{0s}' will be ignored.", .{ prefix, @typeName(T), field.name });
|
||||||
|
|||||||
@ -1,83 +0,0 @@
|
|||||||
const asynk = @import("async");
|
|
||||||
const core = @import("gguf/core.zig");
|
|
||||||
const std = @import("std");
|
|
||||||
const zml = @import("../zml.zig");
|
|
||||||
|
|
||||||
const HostBuffer = @import("../hostbuffer.zig").HostBuffer;
|
|
||||||
|
|
||||||
const Allocator = std.mem.Allocator;
|
|
||||||
const assert = std.debug.assert;
|
|
||||||
|
|
||||||
const log = std.log.scoped(.@"zml/io");
|
|
||||||
|
|
||||||
pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
|
|
||||||
var file = try core.GgufFile.open(path);
|
|
||||||
errdefer file.close();
|
|
||||||
|
|
||||||
var res: zml.aio.BufferStore = .{
|
|
||||||
.arena = std.heap.ArenaAllocator.init(allocator),
|
|
||||||
};
|
|
||||||
errdefer res.arena.deinit();
|
|
||||||
const arena = res.arena.allocator();
|
|
||||||
|
|
||||||
res.files = try arena.dupe(zml.aio.MemoryMappedFile, &.{file.file});
|
|
||||||
|
|
||||||
// metadata must be read in order to read tensors
|
|
||||||
try loadMetadata(arena, &res, &file);
|
|
||||||
try loadBuffers(arena, &res, &file);
|
|
||||||
if (res.buffers.count() != file.header.tensor_count) {
|
|
||||||
log.warn("Expected to find {d} tensors in {s}, only found {d}", .{ file.header.tensor_count, path, res.buffers.count() });
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn loadMetadata(allocator: Allocator, store: *zml.aio.BufferStore, file: *core.GgufFile) !void {
|
|
||||||
try store._metadata.ensureTotalCapacity(allocator, @intCast(file.header.metadata_kv_count));
|
|
||||||
|
|
||||||
while (file.readMetadata(allocator)) |entry| {
|
|
||||||
log.info("Loading MetaData: {s}", .{entry.name});
|
|
||||||
const res = store._metadata.getOrPutAssumeCapacity(entry.name);
|
|
||||||
if (res.found_existing) {
|
|
||||||
// This file seems invalid. Since most metadatas aren't required, continue ahead.
|
|
||||||
log.warn("Found duplicated metadata key: {s}", .{entry.name});
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
res.value_ptr.* = switch (entry.val) {
|
|
||||||
.array => |arr| switch (arr.child) {
|
|
||||||
inline .uint8, .int8, .uint16, .int16, .uint32, .int32, .float32, .bool, .string, .uint64, .int64, .float64 => |tag| blk: {
|
|
||||||
const T = @FieldType(core.GgufValue, @tagName(tag));
|
|
||||||
break :blk try zml.aio.Metadata.copySlice(allocator, std.mem.bytesAsSlice(T, arr.data));
|
|
||||||
},
|
|
||||||
else => blk: {
|
|
||||||
log.warn("ignoring array metadata", .{});
|
|
||||||
break :blk .null;
|
|
||||||
},
|
|
||||||
},
|
|
||||||
inline else => |v| zml.aio.Metadata.wrap(v),
|
|
||||||
};
|
|
||||||
} else |err| switch (err) {
|
|
||||||
error.EndOfMetadata => {},
|
|
||||||
else => return err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn loadBuffers(allocator: Allocator, store: *zml.aio.BufferStore, file: *core.GgufFile) !void {
|
|
||||||
try store.buffers.ensureTotalCapacity(allocator, @intCast(file.header.tensor_count));
|
|
||||||
while (file.readTensorInfo(allocator)) |info| {
|
|
||||||
const res = store.buffers.getOrPutAssumeCapacity(info.name);
|
|
||||||
if (res.found_existing) {
|
|
||||||
// This file seems invalid. Try to continue anyway.
|
|
||||||
log.warn("Found duplicated tensor: {s}", .{info.name});
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: handle quantized types
|
|
||||||
const dtype: zml.DataType = info.t.toDtype() orelse return error.UnsupportedGgufType;
|
|
||||||
const buffer = HostBuffer.fromBytes(zml.Shape.init(info.shape(), dtype), file.file.mappedSlice(info.start, info.byte_len));
|
|
||||||
res.value_ptr.* = buffer;
|
|
||||||
// store the info index.
|
|
||||||
} else |err| switch (err) {
|
|
||||||
error.EndOfMetadata => {},
|
|
||||||
else => return err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,505 +0,0 @@
|
|||||||
const asynk = @import("async");
|
|
||||||
const std = @import("std");
|
|
||||||
const zml = @import("../../zml.zig");
|
|
||||||
|
|
||||||
const assert = std.debug.assert;
|
|
||||||
const log = std.log.scoped(.@"zml/io");
|
|
||||||
|
|
||||||
pub const GgufErrors = error{
|
|
||||||
ValueTypeMismatch,
|
|
||||||
InvalidGguf,
|
|
||||||
UnsupportedGgufType,
|
|
||||||
EndOfMetadata,
|
|
||||||
OutOfMemory,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Enums and structures
|
|
||||||
pub const TensorType = enum(u32) {
|
|
||||||
f32 = 0,
|
|
||||||
f16 = 1,
|
|
||||||
q4_0 = 2,
|
|
||||||
q4_1 = 3,
|
|
||||||
deprecated_q4_2 = 4,
|
|
||||||
deprecated_q4_3 = 5,
|
|
||||||
q5_0 = 6,
|
|
||||||
q5_1 = 7,
|
|
||||||
q8_0 = 8,
|
|
||||||
q8_1 = 9,
|
|
||||||
// k-quantizations
|
|
||||||
q2_k = 10,
|
|
||||||
q3_k = 11,
|
|
||||||
q4_k = 12,
|
|
||||||
q5_k = 13,
|
|
||||||
q6_k = 14,
|
|
||||||
q8_k = 15,
|
|
||||||
i8 = 16,
|
|
||||||
i16 = 17,
|
|
||||||
i32 = 18,
|
|
||||||
|
|
||||||
const MAX_KNOWN_ENUM = 18;
|
|
||||||
|
|
||||||
pub fn canConvertQuant(self: TensorType) bool {
|
|
||||||
return switch (self) {
|
|
||||||
.q8_0, .q4_k, .q6_k, .q2_k, .q4_0, .q4_1 => true,
|
|
||||||
else => false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn toDtype(self: TensorType) ?zml.DataType {
|
|
||||||
return switch (self) {
|
|
||||||
.f32 => .f32,
|
|
||||||
.f16 => .f16,
|
|
||||||
.i8 => .i8,
|
|
||||||
.i16 => .i16,
|
|
||||||
.i32 => .i32,
|
|
||||||
else => null,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn sizeOf(self: TensorType) usize {
|
|
||||||
return self.toDtype().?.sizeOf();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Return the tensor type features
|
|
||||||
pub fn getFeatures(t: TensorType) TensorTypeFeatures {
|
|
||||||
return switch (t) {
|
|
||||||
inline else => |val| @field(TENSOR_TYPE_FEATURES, @tagName(val)),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// GGUF tensor type to features lookup table.
|
|
||||||
pub const TensorTypeFeatures = struct {
|
|
||||||
items_per_block: u29,
|
|
||||||
bytes_per_block: u29,
|
|
||||||
|
|
||||||
pub fn alignment(features: TensorTypeFeatures) u8 {
|
|
||||||
return std.math.log2_int(u29, features.bytes_per_block);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const TENSOR_TYPE_FEATURES: std.enums.EnumFieldStruct(TensorType, TensorTypeFeatures, null) = .{
|
|
||||||
.f32 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(f32) },
|
|
||||||
.f16 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(f16) },
|
|
||||||
.q4_0 = .{ .items_per_block = 32, .bytes_per_block = 18 },
|
|
||||||
.q4_1 = .{ .items_per_block = 32, .bytes_per_block = 20 },
|
|
||||||
.deprecated_q4_2 = .{ .items_per_block = 0, .bytes_per_block = 0 },
|
|
||||||
.deprecated_q4_3 = .{ .items_per_block = 0, .bytes_per_block = 0 },
|
|
||||||
.q5_0 = .{ .items_per_block = 32, .bytes_per_block = 22 },
|
|
||||||
.q5_1 = .{ .items_per_block = 32, .bytes_per_block = 24 },
|
|
||||||
.q8_0 = .{ .items_per_block = 32, .bytes_per_block = 34 },
|
|
||||||
.q8_1 = .{ .items_per_block = 32, .bytes_per_block = 40 },
|
|
||||||
.q2_k = .{ .items_per_block = 256, .bytes_per_block = 82 },
|
|
||||||
.q3_k = .{ .items_per_block = 256, .bytes_per_block = 110 },
|
|
||||||
.q4_k = .{ .items_per_block = 256, .bytes_per_block = 144 },
|
|
||||||
.q5_k = .{ .items_per_block = 256, .bytes_per_block = 176 },
|
|
||||||
.q6_k = .{ .items_per_block = 256, .bytes_per_block = 210 },
|
|
||||||
.q8_k = .{ .items_per_block = 256, .bytes_per_block = 292 },
|
|
||||||
.i8 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(i8) },
|
|
||||||
.i16 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(i16) },
|
|
||||||
.i32 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(i32) },
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const GgufValueType = enum(u32) {
|
|
||||||
// The value is a 8-bit unsigned integer.
|
|
||||||
uint8 = 0,
|
|
||||||
// The value is a 8-bit signed integer.
|
|
||||||
int8 = 1,
|
|
||||||
// The value is a 16-bit unsigned little-endian integer.
|
|
||||||
uint16 = 2,
|
|
||||||
// The value is a 16-bit signed little-endian integer.
|
|
||||||
int16 = 3,
|
|
||||||
// The value is a 32-bit unsigned little-endian integer.
|
|
||||||
uint32 = 4,
|
|
||||||
// The value is a 32-bit signed little-endian integer.
|
|
||||||
int32 = 5,
|
|
||||||
// The value is a 32-bit IEEE754 floating point number.
|
|
||||||
float32 = 6,
|
|
||||||
// The value is a boolean.
|
|
||||||
// 1-byte value where 0 is false and 1 is true.
|
|
||||||
// Anything else is invalid, and should be treated as either the model
|
|
||||||
// being invalid or the reader being buggy.
|
|
||||||
bool = 7,
|
|
||||||
// The value is a UTF-8 non-null-terminated string, with length prepended.
|
|
||||||
string = 8,
|
|
||||||
// The value is an array of other values, with the length and type
|
|
||||||
// prepended. Arrays can be nested, and the length of the array is the
|
|
||||||
// number of elements in the array, not the number of bytes.
|
|
||||||
array = 9,
|
|
||||||
// The value is a 64-bit unsigned little-endian integer.
|
|
||||||
uint64 = 10,
|
|
||||||
// The value is a 64-bit signed little-endian integer.
|
|
||||||
int64 = 11,
|
|
||||||
// The value is a 64-bit IEEE754 floating point number.
|
|
||||||
float64 = 12,
|
|
||||||
// Special values used by the callbacks of gguf_do_with_value().
|
|
||||||
array_start = 100,
|
|
||||||
array_end = 101,
|
|
||||||
|
|
||||||
// Allow other values in case GGUF add more types without us noticing
|
|
||||||
_,
|
|
||||||
|
|
||||||
pub fn sizeOf(self: GgufValueType) usize {
|
|
||||||
return switch (self) {
|
|
||||||
.uint8 => @sizeOf(u8),
|
|
||||||
.int8 => @sizeOf(i8),
|
|
||||||
.uint16 => @sizeOf(u16),
|
|
||||||
.int16 => @sizeOf(i16),
|
|
||||||
.uint32 => @sizeOf(u32),
|
|
||||||
.int32 => @sizeOf(i32),
|
|
||||||
.float32 => @sizeOf(f32),
|
|
||||||
.bool => @sizeOf(bool),
|
|
||||||
.uint64 => @sizeOf(u64),
|
|
||||||
.int64 => @sizeOf(i64),
|
|
||||||
.float64 => @sizeOf(f64),
|
|
||||||
.string => @sizeOf([]u8),
|
|
||||||
else => unreachable,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn arrayTypeCheck(self: GgufValueType, comptime T: type) !void {
|
|
||||||
switch (self) {
|
|
||||||
.string => if (T != []u8 and T != []const u8) return error.ValueTypeMismatch,
|
|
||||||
.uint8 => if (T != u8) return error.ValueTypeMismatch,
|
|
||||||
.int8 => if (T != i8) return error.ValueTypeMismatch,
|
|
||||||
.uint16 => if (T != u16) return error.ValueTypeMismatch,
|
|
||||||
.int16 => if (T != i16) return error.ValueTypeMismatch,
|
|
||||||
.uint32 => if (T != u32) return error.ValueTypeMismatch,
|
|
||||||
.int32 => if (T != i32) return error.ValueTypeMismatch,
|
|
||||||
.float32 => if (T != f32) return error.ValueTypeMismatch,
|
|
||||||
.bool => if (T != bool) return error.ValueTypeMismatch,
|
|
||||||
.uint64 => if (T != u64) return error.ValueTypeMismatch,
|
|
||||||
.int64 => if (T != i64) return error.ValueTypeMismatch,
|
|
||||||
.float64 => if (T != f64) return error.ValueTypeMismatch,
|
|
||||||
else => {},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const ValueType = enum(u8) {
|
|
||||||
uint8 = 0,
|
|
||||||
int8 = 1,
|
|
||||||
uint16 = 2,
|
|
||||||
int16 = 3,
|
|
||||||
uint32 = 4,
|
|
||||||
int32 = 5,
|
|
||||||
float32 = 6,
|
|
||||||
bool = 7,
|
|
||||||
string = 8,
|
|
||||||
array = 9,
|
|
||||||
uint64 = 10,
|
|
||||||
int64 = 11,
|
|
||||||
float64 = 12,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Union of possible values.
|
|
||||||
pub const GgufValue = union(ValueType) {
|
|
||||||
uint8: u8,
|
|
||||||
int8: i8,
|
|
||||||
uint16: u16,
|
|
||||||
int16: i16,
|
|
||||||
uint32: u32,
|
|
||||||
int32: i32,
|
|
||||||
float32: f32,
|
|
||||||
bool: bool,
|
|
||||||
string: []const u8,
|
|
||||||
array: Array,
|
|
||||||
uint64: u64,
|
|
||||||
int64: i64,
|
|
||||||
float64: f64,
|
|
||||||
|
|
||||||
pub const Array = struct {
|
|
||||||
// Any value type is valid, including arrays.
|
|
||||||
child: ValueType,
|
|
||||||
// Number of elements, not bytes
|
|
||||||
len: usize,
|
|
||||||
data: []u8,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
// Header
|
|
||||||
const GgufHeader = extern struct {
|
|
||||||
// Magic number to announce that this is a GGUF file. Must be `GUFF`.
|
|
||||||
magic: [4]u8,
|
|
||||||
// The version of the format implemented.
|
|
||||||
// Must be `3` for version described in this spec.
|
|
||||||
version: u32,
|
|
||||||
// The number of tensors in the file.
|
|
||||||
// This is explicit, instead of being included in the metadata, to ensure
|
|
||||||
// it is always present for loading the tensors.
|
|
||||||
tensor_count: usize,
|
|
||||||
// The number of metadata key-value pairs.
|
|
||||||
metadata_kv_count: usize,
|
|
||||||
|
|
||||||
pub fn validate(self: GgufHeader) !void {
|
|
||||||
if (!std.mem.eql(u8, &self.magic, "GGUF")) {
|
|
||||||
log.err("Invalid GGUF file: wrong header {s}", .{self.magic});
|
|
||||||
return error.InvalidGguf;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Key representation in this library API.
|
|
||||||
pub const GgufMetadataKv = struct {
|
|
||||||
name: []const u8,
|
|
||||||
type_: GgufValueType,
|
|
||||||
val: GgufValue,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Tensor representation in this library API.
|
|
||||||
const GGUF_TENSOR_MAX_DIM: usize = 8; // Future-proof: actual limit is 4.
|
|
||||||
pub const GgufTensorInfo = struct {
|
|
||||||
name: []const u8,
|
|
||||||
t: TensorType, // Tensor type (enum TensorType).
|
|
||||||
rank: usize, // Number of dimensions of the tensor.
|
|
||||||
dims: [GGUF_TENSOR_MAX_DIM]i64, // Dimensions (Eg. [512, 1024, 1, 1]).
|
|
||||||
start: usize, // Offset from start of data section.
|
|
||||||
byte_len: usize, // Total size in bytes.
|
|
||||||
num_weights: usize, // Total number of parameters.
|
|
||||||
|
|
||||||
pub inline fn shape(info: GgufTensorInfo) []const i64 {
|
|
||||||
return info.dims[0..info.rank];
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Return the value type name given the type ID.
|
|
||||||
fn getValueTypeName(t: u32) []const u8 {
|
|
||||||
if (@as(usize, @intCast(t)) >= GGUF_VALUE_NAME.len) return "unknown";
|
|
||||||
return GGUF_VALUE_NAME[@intCast(t)];
|
|
||||||
}
|
|
||||||
|
|
||||||
const GGUF_VALUE_NAME = [_][]const u8{
|
|
||||||
"uint8", "int8", "uint16", "int16", "uint32", "int32",
|
|
||||||
"float32", "bool", "string", "array", "uint64", "int64",
|
|
||||||
"float64",
|
|
||||||
};
|
|
||||||
|
|
||||||
/// GGUF file API
|
|
||||||
/// A memory-mapped view of a .gguf file.
|
|
||||||
/// Format used by GGML models: https://github.com/ggerganov/ggml/
|
|
||||||
pub const GgufFile = struct {
|
|
||||||
header: GgufHeader, // GUFF file header info.
|
|
||||||
size: usize, // Total file size.
|
|
||||||
file: zml.aio.MemoryMappedFile,
|
|
||||||
left_kv: usize, // Number of key-value pairs yet to read.
|
|
||||||
left_tensors: usize, // Number of tensors yet to read.
|
|
||||||
off: usize, // Offset of the next item to parse.
|
|
||||||
alignment: usize = 32, // File data alignment. Default: 32 bytes.
|
|
||||||
|
|
||||||
/// Open and memmap the given file.
|
|
||||||
pub fn open(path: []const u8) !GgufFile {
|
|
||||||
const file = try asynk.File.open(path, .{});
|
|
||||||
const header = try file.reader().readStruct(GgufHeader);
|
|
||||||
try header.validate();
|
|
||||||
return .{
|
|
||||||
.header = header,
|
|
||||||
.size = (try file.stat()).size,
|
|
||||||
.file = try zml.aio.MemoryMappedFile.init(file),
|
|
||||||
.off = @sizeOf(GgufHeader),
|
|
||||||
.left_kv = header.metadata_kv_count,
|
|
||||||
.left_tensors = header.tensor_count,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Unmap the file memory and close the file handle.
|
|
||||||
pub fn close(self: *GgufFile) void {
|
|
||||||
self.file.deinit();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set the context to read the first key-value entry in the GGUF
|
|
||||||
/// file and then all the rest. Is used when creating a new context
|
|
||||||
/// and also when you want to restart scanning the key-value
|
|
||||||
/// items in the file.
|
|
||||||
fn rewind(ctx: *GgufFile) void {
|
|
||||||
ctx.off = @sizeOf(GgufHeader);
|
|
||||||
ctx.left_kv = ctx.header.metadata_kv_count;
|
|
||||||
ctx.left_tensors = ctx.header.tensor_count;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn seek(self: *GgufFile, pos: usize) void {
|
|
||||||
assert(pos < self.size);
|
|
||||||
self.off = pos;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn readInt(self: *GgufFile, comptime T: type) !T {
|
|
||||||
if (self.off + @sizeOf(T) >= self.size) return error.InvalidGguf;
|
|
||||||
const res = self.file.file.reader().readInt(T, .little);
|
|
||||||
self.off += @sizeOf(T);
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn readTensorType(self: *GgufFile) !TensorType {
|
|
||||||
const raw = try self.readInt(u32);
|
|
||||||
if (raw > TensorType.MAX_KNOWN_ENUM) {
|
|
||||||
log.err("Unsupported GGUF tensor type: {d}", .{raw});
|
|
||||||
return error.UnsupportedGgufType;
|
|
||||||
}
|
|
||||||
return @enumFromInt(raw);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn readValueType(self: *GgufFile) !GgufValueType {
|
|
||||||
const raw = try self.readInt(u32);
|
|
||||||
const t: GgufValueType = @enumFromInt(raw);
|
|
||||||
switch (t) {
|
|
||||||
.uint8, .int8, .uint16, .int16, .uint32, .int32, .float32, .bool, .string, .array, .uint64, .int64, .float64, .array_start, .array_end => {},
|
|
||||||
else => {
|
|
||||||
log.err("Unsupported GGUF value type: {s}", .{@tagName(t)});
|
|
||||||
return error.UnsupportedGgufType;
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return t;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn readAlloc(self: *GgufFile, allocator: std.mem.Allocator, len: usize) ![]u8 {
|
|
||||||
const data = try allocator.alloc(u8, len);
|
|
||||||
const read = try self.file.file.reader().readAll(data);
|
|
||||||
if (read != data.len) return error.InvalidGguf;
|
|
||||||
self.off += len;
|
|
||||||
return data;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn skipBytes(self: *GgufFile, len: usize) !void {
|
|
||||||
try self.file.file.seekBy(@intCast(len));
|
|
||||||
self.off += len;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Read the len then the actual bytes.
|
|
||||||
pub fn readString(self: *GgufFile, allocator: std.mem.Allocator) ![]u8 {
|
|
||||||
const len: usize = try self.readInt(u64);
|
|
||||||
return self.readAlloc(allocator, len);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn skipString(self: *GgufFile) !void {
|
|
||||||
const len: usize = try self.readInt(u64);
|
|
||||||
return self.skipBytes(len);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn readArrayHeader(self: *GgufFile, allocator: std.mem.Allocator) !GgufValue.Array {
|
|
||||||
const child = try self.readValueType();
|
|
||||||
if (@intFromEnum(child) > @intFromEnum(ValueType.float64)) {
|
|
||||||
return error.UnsupportedGgufType;
|
|
||||||
}
|
|
||||||
const len: usize = try self.readInt(u64);
|
|
||||||
const data = switch (child) {
|
|
||||||
// Since strings have variable lenghts, we need to read them one by one
|
|
||||||
.string => str: {
|
|
||||||
var data = try allocator.alloc([]u8, len);
|
|
||||||
for (0..len) |i| data[i] = try self.readString(allocator);
|
|
||||||
break :str std.mem.sliceAsBytes(data);
|
|
||||||
},
|
|
||||||
else => try self.readAlloc(allocator, len * child.sizeOf()),
|
|
||||||
};
|
|
||||||
return .{
|
|
||||||
.child = @enumFromInt(@intFromEnum(child)),
|
|
||||||
.len = len,
|
|
||||||
.data = data,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
fn readTypedValue(self: *GgufFile, allocator: std.mem.Allocator, t: GgufValueType) !GgufValue {
|
|
||||||
return switch (t) {
|
|
||||||
.uint8 => .{ .uint8 = try self.readInt(u8) },
|
|
||||||
.int8 => .{ .int8 = try self.readInt(i8) },
|
|
||||||
.uint16 => .{ .uint16 = try self.readInt(u16) },
|
|
||||||
.int16 => .{ .int16 = try self.readInt(i16) },
|
|
||||||
.uint32 => .{ .uint32 = try self.readInt(u32) },
|
|
||||||
.int32 => .{ .int32 = try self.readInt(i32) },
|
|
||||||
.float32 => .{ .float32 = @bitCast(try self.readInt(u32)) },
|
|
||||||
.bool => .{ .bool = try self.readInt(u8) != 0 },
|
|
||||||
.string => .{ .string = try self.readString(allocator) },
|
|
||||||
.array => .{ .array = try self.readArrayHeader(allocator) },
|
|
||||||
.uint64 => .{ .uint64 = try self.readInt(u64) },
|
|
||||||
.int64 => .{ .int64 = try self.readInt(i64) },
|
|
||||||
.float64 => .{ .float64 = @bitCast(try self.readInt(u64)) },
|
|
||||||
else => error.UnsupportedGgufType,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Parses the next metadata entry.
|
|
||||||
/// Returns error.EndOfMetadata if there are no longer metadata to process in this GGUF file.
|
|
||||||
pub fn readMetadata(self: *GgufFile, allocator: std.mem.Allocator) !GgufMetadataKv {
|
|
||||||
if (self.left_kv == 0) return error.EndOfMetadata;
|
|
||||||
self.left_kv -= 1;
|
|
||||||
const name = try self.readString(allocator);
|
|
||||||
const type_ = try self.readValueType();
|
|
||||||
const val: GgufValue = try self.readTypedValue(allocator, type_);
|
|
||||||
return .{ .name = name, .type_ = type_, .val = val };
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the data section offset. This function must be called exactly when
|
|
||||||
// all the key-values are consumed, in the context of the first call of
|
|
||||||
// ctx.getTensor(): this way we will be able to return tensor offsets
|
|
||||||
// as absolute positions and pointers to the mmapped file.
|
|
||||||
fn setDataOffset(self: *GgufFile) !void {
|
|
||||||
const base_off = self.off;
|
|
||||||
|
|
||||||
assert(self.left_kv == 0 and self.left_tensors == self.header.tensor_count);
|
|
||||||
|
|
||||||
for (0..self.left_tensors) |_| try self.skipTensor();
|
|
||||||
const padding: usize = getAlignmentPadding(self.alignment, self.off);
|
|
||||||
self.file.data_offset = self.off + padding;
|
|
||||||
|
|
||||||
try self.file.file.seekTo(base_off);
|
|
||||||
self.off = base_off;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn skipTensor(self: *GgufFile) !void {
|
|
||||||
try self.skipString(); // Skip name
|
|
||||||
const num_dim: u32 = try self.readInt(u32);
|
|
||||||
// dimensions, type, and offset.
|
|
||||||
try self.skipBytes(8 * num_dim + 4 + 8);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Parses the next tensor entry.
|
|
||||||
/// Returns error.EndOfMetadata if there are no longer tensor metadata to process in this GGUF file.
|
|
||||||
pub fn readTensorInfo(self: *GgufFile, allocator: std.mem.Allocator) !GgufTensorInfo {
|
|
||||||
if (self.left_tensors == 0 or self.left_kv != 0) {
|
|
||||||
return error.EndOfMetadata;
|
|
||||||
}
|
|
||||||
|
|
||||||
// We want to return tensor data with offsets relative to the start
|
|
||||||
// of the file, so that the user of the API is able to access tensors
|
|
||||||
// as it iterates over them. To do so, we need to perform a full
|
|
||||||
// scan if this is the first tensor info we are reading.
|
|
||||||
// TODO: explicitly set the data offset in
|
|
||||||
if (self.file.data_offset == 0) try self.setDataOffset();
|
|
||||||
self.left_tensors -= 1;
|
|
||||||
const name = try self.readString(allocator);
|
|
||||||
const num_dim = try self.readInt(u32);
|
|
||||||
assert(@as(usize, @intCast(num_dim)) <= GGUF_TENSOR_MAX_DIM);
|
|
||||||
// Read the dimentions; unused dimensions are left `undefined`.
|
|
||||||
// Note: we reverse the order of the dimensions to match zml convention.
|
|
||||||
var dims: [GGUF_TENSOR_MAX_DIM]i64 = undefined;
|
|
||||||
var num_weights: usize = 1;
|
|
||||||
for (0..num_dim) |j| {
|
|
||||||
const d = try self.readInt(u64);
|
|
||||||
dims[num_dim - 1 - j] = @intCast(d);
|
|
||||||
num_weights *= d;
|
|
||||||
}
|
|
||||||
const t: TensorType = try self.readTensorType();
|
|
||||||
const start = try self.readInt(u64);
|
|
||||||
// To accurately calculate the bytes used by this tensor on the GGUF
|
|
||||||
// file, we need to take into account that quantization methods store
|
|
||||||
// tensors as block of N weights. So first of all we need to understand
|
|
||||||
// the number of padding weights (since the last block may have just
|
|
||||||
// fewer weights stored inside, but still requires to be stored to its full
|
|
||||||
// length). Then we can do the math to see how many blocks we need, and
|
|
||||||
// multiply by the block size to obtain the final total size.
|
|
||||||
const tf = t.getFeatures();
|
|
||||||
const byte_len: usize = (std.math.divCeil(usize, num_weights, tf.items_per_block) catch unreachable) * tf.bytes_per_block;
|
|
||||||
return .{
|
|
||||||
.name = name,
|
|
||||||
.t = t,
|
|
||||||
.rank = num_dim,
|
|
||||||
.dims = dims,
|
|
||||||
.start = start,
|
|
||||||
.byte_len = byte_len,
|
|
||||||
.num_weights = num_weights,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Given an offset or a length, returns the padding needed to align it to alignment.
|
|
||||||
fn getAlignmentPadding(alignment: usize, offset: usize) usize {
|
|
||||||
return @rem((alignment - @rem(offset, alignment)), alignment);
|
|
||||||
}
|
|
||||||
@ -1,9 +1,9 @@
|
|||||||
const asynk = @import("async");
|
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const zml = @import("../zml.zig");
|
|
||||||
|
|
||||||
|
const asynk = @import("async");
|
||||||
|
|
||||||
|
const zml = @import("../zml.zig");
|
||||||
const eval = @import("torch/eval.zig");
|
const eval = @import("torch/eval.zig");
|
||||||
const py = @import("torch/py.zig");
|
|
||||||
const File = @import("torch/file.zig").File;
|
const File = @import("torch/file.zig").File;
|
||||||
|
|
||||||
const StringBuilder = std.ArrayListUnmanaged(u8);
|
const StringBuilder = std.ArrayListUnmanaged(u8);
|
||||||
@ -12,7 +12,7 @@ const log = std.log.scoped(.@"zml/aio");
|
|||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
std.testing.refAllDecls(eval);
|
std.testing.refAllDecls(eval);
|
||||||
std.testing.refAllDecls(py);
|
std.testing.refAllDecls(@import("torch/py.zig"));
|
||||||
std.testing.refAllDecls(File);
|
std.testing.refAllDecls(File);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -30,13 +30,13 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
|
|||||||
const tmp_alloc = arena.allocator();
|
const tmp_alloc = arena.allocator();
|
||||||
|
|
||||||
const mmap_file = try zml.aio.MemoryMappedFile.init(file);
|
const mmap_file = try zml.aio.MemoryMappedFile.init(file);
|
||||||
var torch_file = try File.init(tmp_alloc, mmap_file);
|
var torch_file = try asynk.callBlocking(File.init, .{ tmp_alloc, mmap_file });
|
||||||
|
|
||||||
const ops = try torch_file.parsePickle(tmp_alloc);
|
const ops = try torch_file.parsePickle(tmp_alloc);
|
||||||
const py_values = try eval.evaluate(tmp_alloc, ops, true);
|
const py_values = try eval.evaluate(tmp_alloc, ops, true);
|
||||||
|
|
||||||
// file ownership is transferred to the BufferStore
|
// file ownership is transferred to the BufferStore
|
||||||
var res = try zml.aio.BufferStore.init(allocator, &.{torch_file.buffer_file});
|
var res = try zml.aio.BufferStore.init(allocator, &.{torch_file.mmap_file});
|
||||||
try torch_file.parseModel(py_values, &res);
|
try torch_file.parseModel(py_values, &res);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -151,23 +151,23 @@ pub const PickleMemo = struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) ![]const py.Any {
|
pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) ![]const py.Any {
|
||||||
var stack = std.ArrayList(py.Any).init(arena);
|
var stack: std.ArrayList(py.Any) = .{};
|
||||||
var memo = PickleMemo.init(arena);
|
var memo = PickleMemo.init(arena);
|
||||||
|
|
||||||
for (x) |op| {
|
for (x) |op| {
|
||||||
switch (op) {
|
switch (op) {
|
||||||
.mark => try stack.append(.{ .raw = op }),
|
.mark => try stack.append(arena, .{ .raw = op }),
|
||||||
.frame => {},
|
.frame => {},
|
||||||
.stop => break,
|
.stop => break,
|
||||||
.pop => _ = try pop(&stack),
|
.pop => _ = try pop(&stack),
|
||||||
.pop_mark => _ = try popMark(&stack),
|
.pop_mark => _ = try popMark(&stack),
|
||||||
.dup => if (stack.getLastOrNull()) |item|
|
.dup => if (stack.getLastOrNull()) |item|
|
||||||
try stack.append(try item.clone(arena))
|
try stack.append(arena, try item.clone(arena))
|
||||||
else
|
else
|
||||||
return error.CannotDupEmptyStack,
|
return error.CannotDupEmptyStack,
|
||||||
.persid => |v| try stack.append(.{ .pers_id = try py.PersId.init(arena, .{ .string = try arena.dupe(u8, v) }) }),
|
.persid => |v| try stack.append(arena, .{ .pers_id = try py.PersId.init(arena, .{ .string = try arena.dupe(u8, v) }) }),
|
||||||
.binpersid => try stack.append(.{ .pers_id = try py.PersId.init(arena, try pop(&stack)) }),
|
.binpersid => try stack.append(arena, .{ .pers_id = try py.PersId.init(arena, try pop(&stack)) }),
|
||||||
.reduce => try stack.append(.{ .global = blk: {
|
.reduce => try stack.append(arena, .{ .global = blk: {
|
||||||
var args = try pop(&stack);
|
var args = try pop(&stack);
|
||||||
args = try memo.resolve(arena, args, true);
|
args = try memo.resolve(arena, args, true);
|
||||||
if (args != .seq) return error.InvalidInput;
|
if (args != .seq) return error.InvalidInput;
|
||||||
@ -175,23 +175,23 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
|
|||||||
func = try memo.resolve(arena, func, true);
|
func = try memo.resolve(arena, func, true);
|
||||||
break :blk try py.Object.init(arena, func, args.seq.values, &.{});
|
break :blk try py.Object.init(arena, func, args.seq.values, &.{});
|
||||||
} }),
|
} }),
|
||||||
.build => try stack.append(blk: {
|
.build => try stack.append(arena, blk: {
|
||||||
const args = try memo.resolve(arena, try pop(&stack), true);
|
const args = try memo.resolve(arena, try pop(&stack), true);
|
||||||
const member = try memo.resolve(arena, try pop(&stack), true);
|
const member = try memo.resolve(arena, try pop(&stack), true);
|
||||||
break :blk .{ .set_state = try py.SetState.init(arena, member, args) };
|
break :blk .{ .set_state = try py.SetState.init(arena, member, args) };
|
||||||
}),
|
}),
|
||||||
.empty_dict => try stack.append(.{ .seq = .{ .type = .dict, .values = &[_]py.Any{} } }),
|
.empty_dict => try stack.append(arena, .{ .seq = .{ .type = .dict, .values = &[_]py.Any{} } }),
|
||||||
.get => |v| try stack.append(.{ .ref = v }),
|
.get => |v| try stack.append(arena, .{ .ref = v }),
|
||||||
.empty_list => try stack.append(.{ .seq = .{ .type = .list, .values = &[_]py.Any{} } }),
|
.empty_list => try stack.append(arena, .{ .seq = .{ .type = .list, .values = &[_]py.Any{} } }),
|
||||||
.put => |v| {
|
.put => |v| {
|
||||||
try memo.insert(v, try pop(&stack));
|
try memo.insert(v, try pop(&stack));
|
||||||
try stack.append(.{ .ref = v });
|
try stack.append(arena, .{ .ref = v });
|
||||||
},
|
},
|
||||||
.tuple => try stack.append(blk: {
|
.tuple => try stack.append(arena, blk: {
|
||||||
const popped = try popMark(&stack);
|
const popped = try popMark(&stack);
|
||||||
break :blk .{ .seq = .{ .type = .tuple, .values = try arena.dupe(py.Any, popped) } };
|
break :blk .{ .seq = .{ .type = .tuple, .values = try arena.dupe(py.Any, popped) } };
|
||||||
}),
|
}),
|
||||||
.empty_tuple => try stack.append(.{ .seq = .{ .type = .tuple, .values = &[_]py.Any{} } }),
|
.empty_tuple => try stack.append(arena, .{ .seq = .{ .type = .tuple, .values = &[_]py.Any{} } }),
|
||||||
.setitem => {
|
.setitem => {
|
||||||
const v = try memo.resolve(arena, try pop(&stack), true);
|
const v = try memo.resolve(arena, try pop(&stack), true);
|
||||||
const k = try memo.resolve(arena, try pop(&stack), true);
|
const k = try memo.resolve(arena, try pop(&stack), true);
|
||||||
@ -228,17 +228,17 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
.proto => |proto| stdx.debug.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}),
|
.proto => |proto| stdx.debug.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}),
|
||||||
.tuple1 => try stack.append(blk: {
|
.tuple1 => try stack.append(arena, blk: {
|
||||||
const tup_values = try arena.alloc(py.Any, 1);
|
const tup_values = try arena.alloc(py.Any, 1);
|
||||||
tup_values[0] = try pop(&stack);
|
tup_values[0] = try pop(&stack);
|
||||||
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
|
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
|
||||||
}),
|
}),
|
||||||
.tuple2 => try stack.append(blk: {
|
.tuple2 => try stack.append(arena, blk: {
|
||||||
const tup_values = try arena.alloc(py.Any, 2);
|
const tup_values = try arena.alloc(py.Any, 2);
|
||||||
inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack);
|
inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack);
|
||||||
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
|
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
|
||||||
}),
|
}),
|
||||||
.tuple3 => try stack.append(blk: {
|
.tuple3 => try stack.append(arena, blk: {
|
||||||
const tup_values = try arena.alloc(py.Any, 3);
|
const tup_values = try arena.alloc(py.Any, 3);
|
||||||
inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack);
|
inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack);
|
||||||
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
|
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
|
||||||
@ -276,32 +276,32 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
.dict => try stack.append(.{ .seq = .{
|
.dict => try stack.append(arena, .{ .seq = .{
|
||||||
.type = .dict,
|
.type = .dict,
|
||||||
.values = try arena.dupe(py.Any, try popMark(&stack)),
|
.values = try arena.dupe(py.Any, try popMark(&stack)),
|
||||||
} }),
|
} }),
|
||||||
.list => try stack.append(.{ .seq = .{
|
.list => try stack.append(arena, .{ .seq = .{
|
||||||
.type = .list,
|
.type = .list,
|
||||||
.values = try arena.dupe(py.Any, try popMark(&stack)),
|
.values = try arena.dupe(py.Any, try popMark(&stack)),
|
||||||
} }),
|
} }),
|
||||||
.inst => |v| try stack.append(.{ .object = try py.Object.init(
|
.inst => |v| try stack.append(arena, .{ .object = try py.Object.init(
|
||||||
arena,
|
arena,
|
||||||
try py.tuple(&.{ .{ .string = v.module }, .{ .string = v.class } }).clone(arena),
|
try py.tuple(&.{ .{ .string = v.module }, .{ .string = v.class } }).clone(arena),
|
||||||
try arena.dupe(py.Any, try popMark(&stack)),
|
try arena.dupe(py.Any, try popMark(&stack)),
|
||||||
&.{},
|
&.{},
|
||||||
) }),
|
) }),
|
||||||
.obj => try stack.append(blk: {
|
.obj => try stack.append(arena, blk: {
|
||||||
const mark = try findMark(&stack);
|
const mark = try findMark(&stack);
|
||||||
const args = try arena.dupe(py.Any, stack.items[mark + 2 ..]);
|
const args = try arena.dupe(py.Any, stack.items[mark + 2 ..]);
|
||||||
const member = stack.items[mark + 1];
|
const member = stack.items[mark + 1];
|
||||||
break :blk .{ .object = try py.Object.init(arena, member, args, &.{}) };
|
break :blk .{ .object = try py.Object.init(arena, member, args, &.{}) };
|
||||||
}),
|
}),
|
||||||
.newobj => try stack.append(blk: {
|
.newobj => try stack.append(arena, blk: {
|
||||||
const args = try arena.alloc(py.Any, 1);
|
const args = try arena.alloc(py.Any, 1);
|
||||||
args[0] = try pop(&stack);
|
args[0] = try pop(&stack);
|
||||||
break :blk .{ .object = try py.Object.init(arena, try pop(&stack), args, &.{}) };
|
break :blk .{ .object = try py.Object.init(arena, try pop(&stack), args, &.{}) };
|
||||||
}),
|
}),
|
||||||
.empty_set => try stack.append(.{ .seq = .{ .type = .set, .values = &[_]py.Any{} } }),
|
.empty_set => try stack.append(arena, .{ .seq = .{ .type = .set, .values = &[_]py.Any{} } }),
|
||||||
.additems => {
|
.additems => {
|
||||||
const postmark = try popMark(&stack);
|
const postmark = try popMark(&stack);
|
||||||
const top = try lastMut(&stack);
|
const top = try lastMut(&stack);
|
||||||
@ -316,15 +316,15 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
.frozenset => try stack.append(.{ .seq = .{
|
.frozenset => try stack.append(arena, .{ .seq = .{
|
||||||
.type = .frozen_set,
|
.type = .frozen_set,
|
||||||
.values = try arena.dupe(py.Any, try popMark(&stack)),
|
.values = try arena.dupe(py.Any, try popMark(&stack)),
|
||||||
} }),
|
} }),
|
||||||
.newobj_ex => try stack.append(blk: {
|
.newobj_ex => try stack.append(arena, blk: {
|
||||||
const kwargs, const args, const cls = .{ try pop(&stack), try pop(&stack), try pop(&stack) };
|
const kwargs, const args, const cls = .{ try pop(&stack), try pop(&stack), try pop(&stack) };
|
||||||
break :blk .{ .object = try py.Object.init(arena, cls, args.seq.values, kwargs.seq.values) };
|
break :blk .{ .object = try py.Object.init(arena, cls, args.seq.values, kwargs.seq.values) };
|
||||||
}),
|
}),
|
||||||
.stack_global => try stack.append(blk: {
|
.stack_global => try stack.append(arena, blk: {
|
||||||
const gn, const mn = .{
|
const gn, const mn = .{
|
||||||
try memo.resolve(arena, try pop(&stack), true),
|
try memo.resolve(arena, try pop(&stack), true),
|
||||||
try memo.resolve(arena, try pop(&stack), true),
|
try memo.resolve(arena, try pop(&stack), true),
|
||||||
@ -338,13 +338,13 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
|
|||||||
};
|
};
|
||||||
try memo.insert(@intCast(memo.map.count()), try item.clone(arena));
|
try memo.insert(@intCast(memo.map.count()), try item.clone(arena));
|
||||||
},
|
},
|
||||||
else => try stack.append(.{ .raw = try op.clone(arena) }),
|
else => try stack.append(arena, .{ .raw = try op.clone(arena) }),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (resolve_refs) {
|
if (resolve_refs) {
|
||||||
return try memo.resolveAllRefsIter(arena, 0, stack.items, true);
|
return try memo.resolveAllRefsIter(arena, 0, stack.items, true);
|
||||||
}
|
}
|
||||||
return stack.toOwnedSlice();
|
return stack.toOwnedSlice(arena);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn append(allocator: std.mem.Allocator, current: *[]py.Any, values: []const py.Any) !void {
|
fn append(allocator: std.mem.Allocator, current: *[]py.Any, values: []const py.Any) !void {
|
||||||
@ -358,8 +358,9 @@ test evaluate {
|
|||||||
defer arena.deinit();
|
defer arena.deinit();
|
||||||
const allocator = arena.allocator();
|
const allocator = arena.allocator();
|
||||||
const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only });
|
const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only });
|
||||||
var buffered_reader = std.io.bufferedReader(file.reader());
|
var reader_buffer: [1024]u8 = undefined;
|
||||||
const ops = try pickle.parse(allocator, buffered_reader.reader(), 4096);
|
var reader = file.reader(&reader_buffer);
|
||||||
|
const ops = try pickle.parse(allocator, &reader.interface);
|
||||||
|
|
||||||
const vals = try evaluate(allocator, ops, true);
|
const vals = try evaluate(allocator, ops, true);
|
||||||
defer allocator.free(vals);
|
defer allocator.free(vals);
|
||||||
|
|||||||
@ -1,14 +1,15 @@
|
|||||||
const asynk = @import("async");
|
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
const testing = std.testing;
|
||||||
|
|
||||||
|
const asynk = @import("async");
|
||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const zml = @import("../../zml.zig");
|
const zml = @import("../../zml.zig");
|
||||||
|
const HostBuffer = zml.HostBuffer;
|
||||||
|
const eval = @import("eval.zig");
|
||||||
const pickle = @import("pickle.zig");
|
const pickle = @import("pickle.zig");
|
||||||
const py = @import("py.zig");
|
const py = @import("py.zig");
|
||||||
const eval = @import("eval.zig");
|
|
||||||
const HostBuffer = zml.HostBuffer;
|
|
||||||
|
|
||||||
const testing = std.testing;
|
|
||||||
const log = std.log.scoped(.@"zml/aio");
|
const log = std.log.scoped(.@"zml/aio");
|
||||||
|
|
||||||
// TODO(cryptodeal): use zml.aio.PrefixBuilder instead
|
// TODO(cryptodeal): use zml.aio.PrefixBuilder instead
|
||||||
@ -20,167 +21,82 @@ test {
}

pub const File = struct {
-buffer_file: zml.aio.MemoryMappedFile,
+mmap_file: zml.aio.MemoryMappedFile,
/// Map names to sub file
-file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{},
-tar_file: ?TarStream = null,
-is_zip_file: bool,
-zip_prefix: []const u8 = &.{},
-pickle_subfile: struct { start: u64 = 0, len: usize },
+file_map: std.StringArrayHashMapUnmanaged([]const u8) = .{},
+zip_prefix: []const u8,
+pickle_subfile: []const u8,

-pub const FileEntry = struct {
-version_needed_to_extract: u16,
-flags: u16,
-compression_method: std.zip.CompressionMethod,
-last_modification_time: u16,
-last_modification_date: u16,
-header_zip_offset: u64,
-crc32: u32,
-filename_len: u32,
-compressed_size: u64,
-uncompressed_size: u64,
-file_offset: u64,

-pub fn init(entry: anytype) FileEntry {
-return .{
-.version_needed_to_extract = entry.version_needed_to_extract,
-.flags = @as(u16, @bitCast(entry.flags)),
-.compression_method = entry.compression_method,
-.last_modification_time = entry.last_modification_time,
-.last_modification_date = entry.last_modification_date,
-.header_zip_offset = entry.header_zip_offset,
-.crc32 = entry.crc32,
-.filename_len = entry.filename_len,
-.compressed_size = entry.compressed_size,
-.uncompressed_size = entry.uncompressed_size,
-.file_offset = entry.file_offset,
-};
-}
-};

const magic = "PK\x03\x04";

-pub fn fromTarFile(allocator: std.mem.Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !File {
-const tar_file = try TarStream.init(file);
-const file_magic = try tar_file.reader().readBytesNoEof(magic.len);
-try tar_file.seekTo(0);
-var res: File = .{
-.buffer_file = mapped,
-.tar_file = tar_file,
-.is_zip_file = std.mem.eql(u8, &file_magic, magic),
-.pickle_subfile = .{ .len = try tar_file.getEndPos() },
-};
-if (res.is_zip_file) {
-try res.parseZipHeaders(allocator, tar_file.seekableStream());
-}
-return res;
-}

pub fn init(allocator: std.mem.Allocator, mmap_file: zml.aio.MemoryMappedFile) !File {
-const file_magic = try mmap_file.file.reader().readBytesNoEof(magic.len);
-try mmap_file.file.seekTo(0);
-var res: File = .{
-.buffer_file = mmap_file,
-.is_zip_file = std.mem.eql(u8, &file_magic, magic),
-.pickle_subfile = .{ .len = mmap_file.data.len },
-};
+var pkl: []const u8 = mmap_file.data;
+var zip_prefix: []const u8 = &.{};
+var file_map: std.StringArrayHashMapUnmanaged([]const u8) = .{};
+if (std.mem.eql(u8, mmap_file.data[0..magic.len], magic)) {
+// We are dealing with a zip file.
+// Let's look for the `data.pkl` file and keep a map of all other files.
+// The other files will be the tensor storage and will be referenced from `data.pkl`.
+var header_parsing_buffer: [4096]u8 = undefined;

-if (res.is_zip_file) {
-try res.parseZipHeaders(allocator, mmap_file.file.seekableStream());
+// std.zip requires a std.fs.File and doesn't leverage std.Io.Reader directly.
+// So we use the synchronous API to parse the headers,
+// then we rely only on the memory map data to parse the pickle and load the buffers.
+// To mitigate this we use `async.launchBlocking` in `torch.open`.
+const raw_file: std.fs.File = .{ .handle = mmap_file.file._handle };
+var reader = raw_file.reader(&header_parsing_buffer);
+var it: std.zip.Iterator = try .init(&reader);

+while (try it.next()) |header| {
+if (header.filename_len == 0) {
+continue;
}
-return res;
+if (header.compression_method != .store) {
+return error.Unsupported;
+}

+const filename = mmap_file.data[header.header_zip_offset + @sizeOf(std.zip.CentralDirectoryFileHeader) ..][0..header.filename_len];

+var local_reader: std.Io.Reader = .fixed(mmap_file.data);
+local_reader.discardAll(header.file_offset) catch return error.InvalidZipFile;
+const local_header = local_reader.takeStruct(std.zip.LocalFileHeader, .little) catch return error.InvalidZipFile;
+local_reader.discardAll(local_header.filename_len) catch return error.InvalidZipFile;
+local_reader.discardAll(local_header.extra_len) catch return error.InvalidZipFile;

+// normalize path separators
+const file_content = mmap_file.data[local_reader.seek..][0..header.compressed_size];
+const my_filename: []u8 = try allocator.dupe(u8, filename);
+std.mem.replaceScalar(u8, my_filename, '\\', '/');

+try file_map.put(allocator, my_filename, file_content);

+if (std.mem.endsWith(u8, filename, "data.pkl")) {
+pkl = file_content;
+zip_prefix = filename[0 .. filename.len - "data.pkl".len];
+}
+}

+if (pkl.len == 0) {
+log.err("Could not find file ending in `data.pkl` in archive", .{});
+return error.PickleNotFound;
+}
+}

+return .{
+.mmap_file = mmap_file,
+.file_map = file_map,
+.pickle_subfile = pkl,
+.zip_prefix = zip_prefix,
+};
}

pub fn close(self: *File) void {
-self.buffer_file.deinit();
+self.mmap_file.deinit();
}

pub fn parsePickle(self: *File, allocator: std.mem.Allocator) ![]const pickle.Op {
-return if (self.tar_file) |tar_file| {
-try tar_file.seekTo(self.pickle_subfile.start);
-var buffered = std.io.bufferedReader(tar_file.reader());
-return try pickle.parse(allocator, buffered.reader(), self.pickle_subfile.len);
-} else {
-const file = self.buffer_file.file;
-try file.seekTo(self.pickle_subfile.start);
-var buffered = std.io.bufferedReader(file.reader());
-return try pickle.parse(allocator, buffered.reader(), self.pickle_subfile.len);
-};
-}
+var reader: std.Io.Reader = .fixed(self.pickle_subfile);
+return try pickle.parse(allocator, &reader);

-fn parseZipHeaders(self: *File, allocator: std.mem.Allocator, seekable_stream: anytype) !void {
-var file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{};

-var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream);
-var filename_buf: [std.fs.max_path_bytes]u8 = undefined;
-while (try iter.next()) |entry| {
-const filename = filename_buf[0..entry.filename_len];
-try seekable_stream.seekTo(entry.header_zip_offset + @sizeOf(std.zip.CentralDirectoryFileHeader));
-const len = try seekable_stream.context.reader().readAll(filename);
-if (len != filename.len) return error.ZipBadFileOffset;
-if (isBadFilename(filename)) return error.ZipBadFilename;
-std.mem.replaceScalar(u8, filename, '\\', '/'); // normalize path separators
-try file_map.put(allocator, try allocator.dupe(u8, filename), FileEntry.init(entry));
-}

-self.file_map = file_map;
-var file_iter = file_map.iterator();
-while (file_iter.next()) |e| {
-const entry = e.value_ptr.*;
-const filename = e.key_ptr.*;
-if (!std.mem.endsWith(u8, filename, "data.pkl")) continue;

-self.zip_prefix = filename[0 .. filename.len - "data.pkl".len];

-const local_data_header_offset: u64 = local_data_header_offset: {
-switch (entry.compression_method) {
-.store => {},
-.deflate => {
-// TODO(cryptodeal): handle decompress
-@panic("TODO support use of `deflate`");
-},
-else => @panic("TODO support other modes of compression"),
-}
-const local_header = blk: {
-try seekable_stream.seekTo(entry.file_offset);
-break :blk try seekable_stream.context.reader().readStructEndian(std.zip.LocalFileHeader, .little);
-};
-if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig))
-return error.ZipBadFileOffset;
-if (local_header.version_needed_to_extract != entry.version_needed_to_extract)
-return error.ZipMismatchVersionNeeded;
-if (local_header.last_modification_time != entry.last_modification_time)
-return error.ZipMismatchModTime;
-if (local_header.last_modification_date != entry.last_modification_date)
-return error.ZipMismatchModDate;

-if (@as(u16, @bitCast(local_header.flags)) != entry.flags)
-return error.ZipMismatchFlags;
-if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32)
-return error.ZipMismatchCrc32;
-if (local_header.compressed_size != 0 and
-local_header.compressed_size != entry.compressed_size)
-return error.ZipMismatchCompLen;
-if (local_header.uncompressed_size != 0 and
-local_header.uncompressed_size != entry.uncompressed_size)
-return error.ZipMismatchUncompLen;
-if (local_header.filename_len != entry.filename_len)
-return error.ZipMismatchFilenameLen;

-break :local_data_header_offset @as(u64, local_header.filename_len) +
-@as(u64, local_header.extra_len);
-};

-const local_data_file_offset: u64 =
-@as(u64, entry.file_offset) +
-@as(u64, @sizeOf(std.zip.LocalFileHeader)) +
-local_data_header_offset;
-self.pickle_subfile = .{ .start = local_data_file_offset, .len = entry.uncompressed_size };
-return;
-}

-log.err("Could not find file ending in `data.pkl` in archive", .{});
-return error.PickleNotFound;
}

fn basicTypeCheck(object: *const py.Object, module: []const u8, class: []const u8) bool {
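As a usage sketch of the reworked API (assumptions: `mmap_file` is a zml.aio.MemoryMappedFile produced by the surrounding zml.aio machinery, and the allocator is an arena, as in the test further down):

fn loadTorchOps(arena: std.mem.Allocator, mmap_file: zml.aio.MemoryMappedFile) ![]const pickle.Op {
    // init scans the zip central directory once and remembers where `data.pkl`
    // and the tensor storage files live inside the memory mapping.
    var torch_file = try File.init(arena, mmap_file);
    // parsePickle then decodes `data.pkl` straight out of the mapped bytes.
    return try torch_file.parsePickle(arena);
}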
@ -286,7 +202,7 @@ pub const File = struct {
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
-new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
+new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
try self.parseValue(allocator, store, new_prefix, val);
}
}
@ -303,7 +219,7 @@ pub const File = struct {
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
-new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
+new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
const new_tag = switch (tag) {
.int64 => "int",
.float64 => "float",
@ -321,7 +237,7 @@ pub const File = struct {
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
-new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
+new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
}
try self.parseValue(allocator, store, new_prefix, item);
}
@ -353,7 +269,7 @@ pub const File = struct {
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
-new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{});
+new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{});
try self.parseValue(allocator, store, new_prefix, val);
},
inline else => |_, tag| {
@ -504,34 +420,10 @@ pub const File = struct {
/// Given the name of one of the files in the .pt tarball,
/// return the slice of the memory-mapped .pt corresponding to it.
fn getStorage(self: File, filename: []const u8) ![]const u8 {
-const maybe_entry = self.file_map.get(filename);
-if (maybe_entry == null) {
+return self.file_map.get(filename) orelse {
std.log.err("Could not find file ending in `{s}` in archive", .{filename});
return error.TensorNotFound;
-}
-const entry = maybe_entry.?;
-const base_offset: u64 = if (self.tar_file) |t| t.start else 0;
-const file_offset: u64 = base_offset + entry.file_offset;
-const file = self.buffer_file.file;
-try file.seekTo(entry.file_offset);
-const local_header = try file.reader().readStructEndian(std.zip.LocalFileHeader, .little);

-if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig))
-return error.ZipBadFileOffset;
-if (local_header.compressed_size != 0 and
-local_header.compressed_size != entry.compressed_size)
-return error.ZipMismatchCompLen;
-if (local_header.uncompressed_size != 0 and
-local_header.uncompressed_size != entry.uncompressed_size)
-return error.ZipMismatchUncompLen;
-if (local_header.filename_len != entry.filename_len)
-return error.ZipMismatchFilenameLen;

-const start = file_offset +
-@sizeOf(std.zip.LocalFileHeader) +
-@as(u64, local_header.filename_len) +
-@as(u64, local_header.extra_len);
-return self.buffer_file.mappedSlice(start, entry.uncompressed_size);
+};
}

fn parseDims(values: []py.Any) error{InvalidInput}!zml.Shape.DimsArray {
@ -578,52 +470,6 @@ fn storageToDtype(storage_type: []const u8) !zml.DataType {
};
}

-const TarStream = struct {
-pub const SeekableStream = std.io.SeekableStream(
-TarStream,
-asynk.File.SeekError,
-asynk.File.GetSeekPosError,
-TarStream.seekTo,
-TarStream.seekBy,
-TarStream.getPos,
-TarStream.getEndPos,
-);

-file: std.tar.Iterator(asynk.File.Reader).File,
-start: usize,

-pub fn init(file: std.tar.Iterator(asynk.File.Reader).File) !TarStream {
-return .{
-.file = file,
-.start = try file.parent_reader.context.getPos(),
-};
-}

-pub fn reader(file: TarStream) std.tar.Iterator(asynk.File.Reader).File.Reader {
-return file.file.reader();
-}

-pub fn seekTo(self: TarStream, offset: u64) !void {
-return self.file.parent_reader.context.seekTo(self.start + offset);
-}

-pub fn seekBy(self: TarStream, offset: i64) !void {
-return self.file.parent_reader.context.seekBy(offset);
-}

-pub fn getPos(self: TarStream) !u64 {
-return try self.file.parent_reader.context.getPos() - self.start;
-}

-pub fn getEndPos(self: TarStream) !u64 {
-return self.file.size;
-}

-pub fn seekableStream(self: TarStream) TarStream.SeekableStream {
-return .{ .context = self };
-}
-};

test "Read pickle (zipped)" {
// test file created with following python snippet:
//
@ -638,20 +484,19 @@ test "Read pickle (zipped)" {
defer store.deinit();

{
-var tmp_arena = std.heap.ArenaAllocator.init(testing.allocator);
-defer tmp_arena.deinit();
-const tmp_alloc = tmp_arena.allocator();
-var torch_file = try File.init(tmp_alloc, mmap_file);
+var arena = std.heap.ArenaAllocator.init(testing.allocator);
+defer arena.deinit();
+var torch_file = try File.init(arena.allocator(), mmap_file);
// We don't close the file directly, it will be closed by the store.

-const ops = try torch_file.parsePickle(tmp_alloc);
+const ops = try torch_file.parsePickle(arena.allocator());
try std.testing.expectEqual(302, ops.len);

-const py_values = try eval.evaluate(tmp_alloc, ops, true);
+const py_values = try eval.evaluate(arena.allocator(), ops, true);
try torch_file.parseModel(py_values, &store);
}

-// now we have freed the tmp_arena.
+// now we have freed the arena.
// all data needed should have been copied into the store arena.
try zml.testing.expectEqualShapes(
zml.Shape.init(.{ 1, 4 }, .u8),
@ -763,54 +763,54 @@ pub const Op = union(enum) {
};

/// Read a stream of bytes, and interpret it as a stream of Pickle operators.
-pub fn parse(allocator: std.mem.Allocator, reader: anytype, max_line_len: usize) ![]const Op {
-var results = std.ArrayList(Op).init(allocator);
-errdefer results.deinit();
-const len = max_line_len;
-var _buf: std.BoundedArray(u8, 12) = .{};
+/// The given allocator needs to be an arena, because we are not aligning allocations to avoid copies.
+pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op {
+// It's not very efficient to interleave the results with the data copied from the stream,
+// because growth events in the results ArrayList will lead to fragmentation.
+// Trying to mitigate that by using a generous default size.
+var results: std.ArrayListUnmanaged(Op) = try .initCapacity(arena, 512);
+errdefer results.deinit(arena);
+var alloc_writer = try std.Io.Writer.Allocating.initCapacity(arena, 512);

while (true) {
-const b = try reader.readByte();
-const code: OpCode = @enumFromInt(b);
+const code: OpCode = @enumFromInt(try reader.takeByte());
const op: Op = switch (code) {
-.int => blk: {
-_buf.len = 0;
-try reader.streamUntilDelimiter(_buf.writer(), '\n', _buf.capacity() + 1);
-const buf = _buf.constSlice();
+.int => int: {
+const bytes = try reader.takeDelimiterExclusive('\n');
// Legacy hack, see OpCode.int documentation
// We do this parsing right away to simplify downstream code.
-break :blk if (std.mem.eql(u8, "00", buf))
+break :int if (bytes.len == 2 and bytes[0] == '0' and bytes[1] == '0')
.{ .bool = false }
-else if (std.mem.eql(u8, "01", buf))
+else if (bytes.len == 2 and bytes[0] == '0' and bytes[1] == '1')
.{ .bool = true }
else
-.{ .int = try std.fmt.parseInt(i32, buf, 10) };
+.{ .int = try std.fmt.parseInt(i32, bytes, 10) };
},
-.binint => .{ .int = try reader.readInt(i32, .little) },
-.binint1 => .{ .int = try reader.readByte() },
-.binint2 => .{ .int = try reader.readInt(u16, .little) },
+.binint => .{ .int = try reader.takeInt(i32, .little) },
+.binint1 => .{ .int = try reader.takeByte() },
+.binint2 => .{ .int = try reader.takeInt(u16, .little) },
// TODO: long should handle the trailing 'L' -> add a test.
-.long => .{ .long = try reader.readUntilDelimiterAlloc(allocator, '\n', len) },
-.long1 => .{ .binlong = try _readSlice(reader, allocator, 1) },
-.long4 => .{ .binlong = try _readSlice(reader, allocator, 4) },
-.string => .{ .string = try reader.readUntilDelimiterAlloc(allocator, '\n', len) },
-.binstring => .{ .string = try _readSlice(reader, allocator, 4) },
-.short_binstring => .{ .string = try _readSlice(reader, allocator, 1) },
-.binbytes => .{ .bytes = try _readSlice(reader, allocator, 4) },
-.binbytes8 => .{ .bytes = try _readSlice(reader, allocator, 8) },
-.short_binbytes => .{ .bytes = try _readSlice(reader, allocator, 1) },
-.bytearray8 => .{ .bytearray = try _readSlice(reader, allocator, 8) },
+.long => .{ .long = try readLine(reader, &alloc_writer) },
+.long1 => .{ .binlong = try _readSlice(reader, arena, 1) },
+.long4 => .{ .binlong = try _readSlice(reader, arena, 4) },
+.string => .{ .string = try readLine(reader, &alloc_writer) },
+.binstring => .{ .string = try _readSlice(reader, arena, 4) },
+.short_binstring => .{ .string = try _readSlice(reader, arena, 1) },
+.binbytes => .{ .bytes = try _readSlice(reader, arena, 4) },
+.binbytes8 => .{ .bytes = try _readSlice(reader, arena, 8) },
+.short_binbytes => .{ .bytes = try _readSlice(reader, arena, 1) },
+.bytearray8 => .{ .bytearray = try _readSlice(reader, arena, 8) },
.next_buffer => .next_buffer,
.readonly_buffer => .readonly_buffer,
.none => .none,
.newtrue => .{ .bool = true },
.newfalse => .{ .bool = false },
-.unicode => .{ .unicode = try reader.readUntilDelimiterAlloc(allocator, '\n', len) },
-.short_binunicode => .{ .unicode = try _readSlice(reader, allocator, 1) },
-.binunicode => .{ .unicode = try _readSlice(reader, allocator, 4) },
-.binunicode8 => .{ .unicode = try _readSlice(reader, allocator, 8) },
-.float => .{ .float = try reader.readUntilDelimiterAlloc(allocator, '\n', len) },
-.binfloat => .{ .binfloat = @bitCast(try reader.readInt(u64, .big)) },
+.unicode => .{ .unicode = try readLine(reader, &alloc_writer) },
+.short_binunicode => .{ .unicode = try _readSlice(reader, arena, 1) },
+.binunicode => .{ .unicode = try _readSlice(reader, arena, 4) },
+.binunicode8 => .{ .unicode = try _readSlice(reader, arena, 8) },
+.float => .{ .float = try readLine(reader, &alloc_writer) },
+.binfloat => .{ .binfloat = @bitCast(try reader.takeInt(u64, .big)) },
.empty_list => .empty_list,
.append => .append,
.appends => .appends,
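A small sketch of the new call shape (hypothetical input bytes, mirroring the tests further down): the caller owns an arena and hands parse() a *std.Io.Reader, so individual ops never need to be freed one by one.

test "parse from an in-memory buffer (sketch)" {
    var arena: std.heap.ArenaAllocator = .init(std.testing.allocator);
    defer arena.deinit();
    // `.fixed` wraps any byte slice as a std.Io.Reader; "N." is pickle for None + STOP.
    var reader: std.Io.Reader = .fixed("N.");
    const ops = try parse(arena.allocator(), &reader);
    try std.testing.expect(ops.len >= 1);
}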
@ -832,74 +832,74 @@ pub fn parse(allocator: std.mem.Allocator, reader: anytype, max_line_len: usize)
.mark => .mark,
.pop_mark => .pop_mark,
// If we fail to parse delay the error to the evaluation.
-.get => .{
-.get = _readDigits(u32, reader, &_buf) catch std.math.maxInt(u32),
+.get => get: {
+const digits = try reader.takeDelimiterExclusive('\n');
+break :get .{ .get = std.fmt.parseInt(u32, digits, 10) catch std.math.maxInt(u32) };
},
-.binget => .{ .get = try reader.readByte() },
-.long_binget => .{ .get = try reader.readInt(u32, .little) },
-.put => blk: {
-const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
-defer allocator.free(buf);
-const n = std.fmt.parseInt(u32, buf, 10) catch std.math.maxInt(u32);
-break :blk .{ .put = n };
+.binget => .{ .get = try reader.takeByte() },
+.long_binget => .{ .get = try reader.takeInt(u32, .little) },
+.put => put: {
+const digits = try reader.takeDelimiterExclusive('\n');
+break :put .{ .put = std.fmt.parseInt(u32, digits, 10) catch std.math.maxInt(u32) };
},
-.binput => .{ .put = try reader.readByte() },
-.long_binput => .{ .put = try reader.readInt(u32, .little) },
+.binput => .{ .put = try reader.takeByte() },
+.long_binput => .{ .put = try reader.takeInt(u32, .little) },
.memoize => .memoize,
-.ext1 => .{ .ext1 = try reader.readByte() },
-.ext2 => .{ .ext2 = try reader.readInt(i16, .little) },
-.ext4 => .{ .ext4 = try reader.readInt(i32, .little) },
+.ext1 => .{ .ext1 = try reader.takeByte() },
+.ext2 => .{ .ext2 = try reader.takeInt(i16, .little) },
+.ext4 => .{ .ext4 = try reader.takeInt(i32, .little) },
.global => .{ .global = .{
-.module = try reader.readUntilDelimiterAlloc(allocator, '\n', len),
-.class = try reader.readUntilDelimiterAlloc(allocator, '\n', len),
+.module = try readLine(reader, &alloc_writer),
+.class = try readLine(reader, &alloc_writer),
} },
.stack_global => .stack_global,
.reduce => .reduce,
.build => .build,
.inst => .{ .inst = .{
-.module = try reader.readUntilDelimiterAlloc(allocator, '\n', len),
-.class = try reader.readUntilDelimiterAlloc(allocator, '\n', len),
+.module = try readLine(reader, &alloc_writer),
+.class = try readLine(reader, &alloc_writer),
} },
.obj => .obj,
.newobj => .newobj,
.newobj_ex => .newobj_ex,
.proto => blk: {
-const version = try reader.readByte();
+const version = try reader.takeByte();
if (version > 5) log.warn("zml.aio.torch.pickle.parse expects a Python pickle object of version <=5, got version {}. Will try to interpret anyway, but this may lead to more errors.", .{version});
break :blk .{ .proto = version };
},
.stop => .stop,
+.frame => frame: {
// This is not documented in pickletools but in https://peps.python.org/pep-3154/
-// The frame size is stored right after the frame header.
// The loader is allowed to prefetch framesize from the underlying reader,
// and ops are not allowed to cross a frame boundary.
-// We don't prefetch because we assume the reader is going to use some kind of buffered reader.
-// We could try to enforce frame boundaries, but we would need to track
-// how many bytes we are reading from the stream.
-.frame => .{ .frame = try reader.readInt(u64, .little) },
-.persid => .{ .persid = try reader.readUntilDelimiterAlloc(allocator, '\n', len) },
+const frame_size = try reader.takeInt(u64, .little);
+reader.fill(@min(frame_size, reader.buffer.len)) catch |err| switch (err) {
+error.EndOfStream => {},
+else => return err,
+};
+break :frame .{ .frame = frame_size };
+},
+.persid => .{ .persid = try readLine(reader, &alloc_writer) },
.binpersid => .binpersid,
_ => |unk_tag| {
log.err("Unknow pickle operator {}, note we are only supporting pickle protocol up to version 5.", .{unk_tag});
return error.NotSupported;
},
};
-try results.append(op);
+try results.append(arena, op);
if (op == .stop) break;
}
-return results.toOwnedSlice();
+return results.items;
}

test "parse protocol 4" {
-const allocator = std.testing.allocator;
+var arena: std.heap.ArenaAllocator = .init(std.testing.allocator);
+defer arena.deinit();

const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only });
-var buffered_reader = std.io.bufferedReader(file.reader());
-const ops = try parse(allocator, buffered_reader.reader(), 4096);
-defer {
-// Test we are correctly freeing every allocation.
-for (ops) |op| op.deinit(allocator);
-allocator.free(ops);
-}
+var read_buffer: [1024]u8 = undefined;
+var reader = file.reader(&read_buffer);
+const ops = try parse(arena.allocator(), &reader.interface);

// this can be obtained by running: `python -m pickletools simple_test_4.pickle`
var expected = [_]Op{
@ -948,7 +948,9 @@ test "parse protocol 4" {

test "parse protocol 0" {
// We also test protocol 0, because it's more text oriented.
-const allocator = std.testing.allocator;
+var arena: std.heap.ArenaAllocator = .init(std.testing.allocator);
+defer arena.deinit();

const pickle_0 =
\\(dp0
\\Vhello
@ -982,13 +984,8 @@ test "parse protocol 0" {
\\s.
;

-var stream = std.io.fixedBufferStream(pickle_0);
-const ops = try parse(allocator, stream.reader(), 4096);
-defer {
-// Test we are correctly freeing every allocation.
-for (ops) |op| op.deinit(allocator);
-allocator.free(ops);
-}
+var reader: std.Io.Reader = .fixed(pickle_0);
+const ops = try parse(arena.allocator(), &reader);

var expected = [_]Op{
.mark,
@ -1043,18 +1040,11 @@ test "parse protocol 0" {
try std.testing.expectEqualDeep(&expected, ops);
}

-fn _readDigits(comptime T: type, reader: anytype, buffer: *std.BoundedArray(u8, 12)) !T {
-buffer.len = 0;
-try reader.streamUntilDelimiter(buffer.writer(), '\n', 13);
-return std.fmt.parseInt(T, buffer.constSlice(), 10);
-}

fn _readSlice(reader: anytype, allocator: std.mem.Allocator, comptime len_bytes: u8) ![]u8 {
const T = std.meta.Int(.unsigned, 8 * len_bytes);
-const str_len: u64 = try reader.readInt(T, .little);
+const str_len: u64 = try reader.takeInt(T, .little);
const buf = try allocator.alloc(u8, str_len);
-errdefer allocator.free(buf);
-_ = try reader.read(buf);
+_ = try reader.readSliceAll(buf);
return buf;
}

@ -1063,3 +1053,14 @@ fn writeIntBuff(comptime T: type, value: T) [@divExact(@typeInfo(T).int.bits, 8)
std.mem.writeInt(T, &res, value, .little);
return res;
}

+fn readLine(reader: *std.Io.Reader, alloc_writer: *std.Io.Writer.Allocating) ![]const u8 {
+const n = try reader.streamDelimiter(&alloc_writer.writer, '\n');
+std.debug.assert(try reader.takeByte() == '\n');
+const w = &alloc_writer.writer;
+std.debug.assert(w.end == n);
+const items = w.buffer[0..n];
+w.buffer = w.buffer[n + 1 ..];
+w.end = 0;
+return items;
+}
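An illustration of how the new readLine helper is meant to behave (a sketch inferred from the code above, not part of the commit): each call streams the bytes up to the next '\n' into the Allocating writer, returns them, and re-bases the writer so the next line is appended after the previous one instead of overwriting it. Because the writer's buffer is re-sliced and never deinitialized, the allocator must be an arena, as the parse() doc comment states.

test "readLine sketch" {
    var arena: std.heap.ArenaAllocator = .init(std.testing.allocator);
    defer arena.deinit();
    // Two consecutive lines, as produced by e.g. a GLOBAL opcode (module then class).
    var reader: std.Io.Reader = .fixed("collections\nOrderedDict\n");
    var alloc_writer = try std.Io.Writer.Allocating.initCapacity(arena.allocator(), 512);
    const module = try readLine(&reader, &alloc_writer);
    const class = try readLine(&reader, &alloc_writer);
    try std.testing.expectEqualStrings("collections", module);
    try std.testing.expectEqualStrings("OrderedDict", class);
}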
@ -172,7 +172,7 @@ pub const HostBuffer = struct {
// TODO we should allow interpreting the output as @Vector(8, f32) when the tensor is f32.
stdx.debug.assert(DataType.fromZigType(T) == self.dtype(), "Can't reinterpret {f} as {s}", .{ self, @typeName(T) });
stdx.debug.assert(self.isContiguous(), "{f} isn't contiguous, can't interpret as []const u8", .{self});
-const ptr: [*]const T = @alignCast(@ptrCast(self._data));
+const ptr: [*]const T = @ptrCast(@alignCast(self._data));
return ptr[0..self._shape.count()];
}

@ -664,7 +664,7 @@ pub const CompilationContext = struct {
// Create the result tensor object by combining the operand results,
// as well as the registered shapes and donations.
// Note: this assume res can be stack-allocated.
-var res = @as(*const stdx.meta.FnResult(func), @alignCast(@ptrCast(function.res_tensors))).*;
+var res = @as(*const stdx.meta.FnResult(func), @ptrCast(@alignCast(function.res_tensors))).*;
const LocalContext = struct { index: usize = 0, op: mlir.Operation, function: MlirFn, donations: []Tensor._Donation };
var context: LocalContext = .{ .op = op, .function = function, .donations = donations };
meta.visit((struct {
@ -113,7 +113,7 @@ const _CreateOptions = struct {
/// "Best-Fit with Coalescing" algorithm
bfc: Options,
/// use cudaMallocAsync
-@"async": Options,
+async: Options,
/// use raw cuMalloc
platform,

@ -129,7 +129,7 @@ const _CreateOptions = struct {
.platform => {
values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform"));
},
-.bfc, .@"async" => |opt| {
+.bfc, .async => |opt| {
values.appendAssumeCapacity(pjrt.NamedValue.from("allocator", self.allocator));
values.appendAssumeCapacity(pjrt.NamedValue.from("preallocate", opt.preallocate));
if (opt.memory_fraction > 0) {
@ -1,4 +1,5 @@
const builtin = @import("builtin");

const c = @import("c");

pub const Tracer = switch (builtin.os.tag) {
@ -11,15 +12,15 @@ const CudaTracer = struct {

// Those symbols are defined in cudaProfiler.h but their implementation is in libcuda.so
// They will be bound at call time after libcuda.so is loaded (as a needed dependency of libpjrt_cuda.so).
-const cuProfilerStart = @extern(*const fn () callconv(.C) c_int, .{ .name = "cuProfilerStart", .linkage = .weak }) orelse unreachable;
-const cuProfilerStop = @extern(*const fn () callconv(.C) c_int, .{ .name = "cuProfilerStop", .linkage = .weak }) orelse unreachable;
+const cuProfilerStart = @extern(*const fn () callconv(.c) c_int, .{ .name = "cuProfilerStart", .linkage = .weak }) orelse unreachable;
+const cuProfilerStop = @extern(*const fn () callconv(.c) c_int, .{ .name = "cuProfilerStop", .linkage = .weak }) orelse unreachable;

// Those symbols are defined in nvToolsExt.h which we don't want to provide.
// However, we link with libnvToolsExt.so which provides them.
// They will be bound at call time after libnvToolsExt.so is loaded (manually dlopen'ed by us).
-const nvtxMarkA = @extern(*const fn ([*:0]const u8) callconv(.C) void, .{ .name = "nvtxMarkA", .linkage = .weak }) orelse unreachable;
-const nvtxRangeStartA = @extern(*const fn ([*:0]const u8) callconv(.C) c_int, .{ .name = "nvtxRangeStartA", .linkage = .weak }) orelse unreachable;
-const nvtxRangeEnd = @extern(*const fn (c_int) callconv(.C) void, .{ .name = "nvtxRangeEnd", .linkage = .weak }) orelse unreachable;
+const nvtxMarkA = @extern(*const fn ([*:0]const u8) callconv(.c) void, .{ .name = "nvtxMarkA", .linkage = .weak }) orelse unreachable;
+const nvtxRangeStartA = @extern(*const fn ([*:0]const u8) callconv(.c) c_int, .{ .name = "nvtxRangeStartA", .linkage = .weak }) orelse unreachable;
+const nvtxRangeEnd = @extern(*const fn (c_int) callconv(.c) void, .{ .name = "nvtxRangeEnd", .linkage = .weak }) orelse unreachable;

pub fn init(name: [:0]const u8) CudaTracer {
_ = name;