bump runtimes/* code to Zig 0.15.1, restore PyTorch loader using std.fs.File, update CI zig fmt, remove stdx.io, note remaining issues with Neuron and CUDA debug builds

This commit is contained in:
Tarry Singh 2025-08-07 15:09:27 +00:00
parent 0ed7f5c907
commit 9e3cd6d616
24 changed files with 258 additions and 989 deletions

View File

@ -605,7 +605,7 @@ pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
pub fn items(self: Attr) []const dt.ZigType() { pub fn items(self: Attr) []const dt.ZigType() {
const raw_bytes: [*]const u8 = c.mlirDenseElementsAttrGetRawData(self._inner) orelse unreachable; const raw_bytes: [*]const u8 = c.mlirDenseElementsAttrGetRawData(self._inner) orelse unreachable;
const ptr: [*]const dt.ZigType() = @alignCast(@ptrCast(raw_bytes)); const ptr: [*]const dt.ZigType() = @ptrCast(@alignCast(raw_bytes));
// Note the mlir API returns us the number of elements, not the number of bytes, // Note the mlir API returns us the number of elements, not the number of bytes,
// that's why we track the element type at comptime to allow items to work. // that's why we track the element type at comptime to allow items to work.
return ptr[0..self.len()]; return ptr[0..self.len()];
@ -1743,7 +1743,7 @@ pub const helpers = struct {
writer: *std.Io.Writer, writer: *std.Io.Writer,
err: ?std.Io.Writer.Error = null, err: ?std.Io.Writer.Error = null,
fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.c) void { fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.c) void {
var ctx: *@This() = @alignCast(@ptrCast(opaque_ctx)); var ctx: *@This() = @ptrCast(@alignCast(opaque_ctx));
if (ctx.err) |_| return; if (ctx.err) |_| return;
_ = ctx.writer.write(mlir_str.data[0..mlir_str.length]) catch |err| { _ = ctx.writer.write(mlir_str.data[0..mlir_str.length]) catch |err| {
ctx.err = err; ctx.err = err;

View File

@ -359,7 +359,7 @@ pub const Attrs = extern struct {
value: *const anyopaque, value: *const anyopaque,
pub fn get(self: Scalar, T: type) T { pub fn get(self: Scalar, T: type) T {
const ptr: *const T = @alignCast(@ptrCast(self.value)); const ptr: *const T = @ptrCast(@alignCast(self.value));
return ptr.*; return ptr.*;
} }
}; };
@ -370,13 +370,13 @@ pub const Attrs = extern struct {
data: [*]const u8, data: [*]const u8,
pub fn slice(self: Array, T: type) []const T { pub fn slice(self: Array, T: type) []const T {
const ptr: [*]const T = @alignCast(@ptrCast(self.data)); const ptr: [*]const T = @ptrCast(@alignCast(self.data));
return ptr[0..self.len]; return ptr[0..self.len];
} }
}; };
pub fn slice(self: Array, T: type) []const T { pub fn slice(self: Array, T: type) []const T {
const ptr: [*]const T = @alignCast(@ptrCast(self.data)); const ptr: [*]const T = @ptrCast(@alignCast(self.data));
return ptr[0..self.len]; return ptr[0..self.len];
} }

View File

@ -58,7 +58,7 @@ pub const ApiError = error{
fn InnerMixin(comptime innerT: type) type { fn InnerMixin(comptime innerT: type) type {
return struct { return struct {
fn inner(self: anytype) *innerT { fn inner(self: anytype) *innerT {
return @ptrCast(@constCast(@alignCast(self))); return @ptrCast(@alignCast(@constCast(self)));
} }
}; };
} }
@ -125,10 +125,10 @@ pub const Api = struct {
} }
pub fn lookupExtension(self: *const Api, comptime ExtensionT: type, ext_id: c_int) ?*const ExtensionT { pub fn lookupExtension(self: *const Api, comptime ExtensionT: type, ext_id: c_int) ?*const ExtensionT {
var cur: [*c]const c.PJRT_Extension_Base = @alignCast(@ptrCast(self.inner.extension_start)); var cur: [*c]const c.PJRT_Extension_Base = @ptrCast(@alignCast(self.inner.extension_start));
while (cur != null) : (cur = cur.*.next) { while (cur != null) : (cur = cur.*.next) {
if (cur.*.type == ext_id) { if (cur.*.type == ext_id) {
return @alignCast(@ptrCast(cur)); return @ptrCast(@alignCast(cur));
} }
} }
@ -432,7 +432,7 @@ pub const Client = opaque {
.client = self.inner(), .client = self.inner(),
}) catch unreachable; }) catch unreachable;
if (ret.addressable_memories) |memories| { if (ret.addressable_memories) |memories| {
return @constCast(@ptrCast(memories[0..ret.num_addressable_memories])); return @ptrCast(@constCast(memories[0..ret.num_addressable_memories]));
} }
return &.{}; return &.{};
} }

View File

@ -28,7 +28,7 @@ fn hasCudaPathInLDPath() bool {
fn setupXlaGpuCudaDirFlag(allocator: std.mem.Allocator, sandbox: []const u8) !void { fn setupXlaGpuCudaDirFlag(allocator: std.mem.Allocator, sandbox: []const u8) !void {
const xla_flags = std.process.getEnvVarOwned(allocator, "XLA_FLAGS") catch ""; const xla_flags = std.process.getEnvVarOwned(allocator, "XLA_FLAGS") catch "";
const new_xla_flagsZ = try std.fmt.allocPrintZ(allocator, "{s} --xla_gpu_cuda_data_dir={s}", .{ xla_flags, sandbox }); const new_xla_flagsZ = try std.fmt.allocPrintSentinel(allocator, "{s} --xla_gpu_cuda_data_dir={s}", .{ xla_flags, sandbox }, 0);
_ = c.setenv("XLA_FLAGS", new_xla_flagsZ, 1); _ = c.setenv("XLA_FLAGS", new_xla_flagsZ, 1);
} }

View File

@ -38,7 +38,7 @@ var module_def: c.PyModuleDef = .{
.{}, .{},
}), }),
.m_slots = @constCast(&[_]c.PyModuleDef_Slot{ .m_slots = @constCast(&[_]c.PyModuleDef_Slot{
.{ .slot = c.Py_mod_exec, .value = @constCast(@ptrCast(&module_exec)) }, .{ .slot = c.Py_mod_exec, .value = @ptrCast(@constCast(&module_exec)) },
.{}, .{},
}), }),
.m_traverse = null, .m_traverse = null,

View File

@ -25,10 +25,10 @@ fn isRunningOnEC2() !bool {
var f = try asynk.File.open("/sys/devices/virtual/dmi/id/sys_vendor", .{ .mode = .read_only }); var f = try asynk.File.open("/sys/devices/virtual/dmi/id/sys_vendor", .{ .mode = .read_only });
defer f.close() catch {}; defer f.close() catch {};
var buf: [AmazonEC2.len]u8 = undefined; var content: [AmazonEC2.len]u8 = undefined;
_ = try f.reader().readAll(&buf); const n_read = try f.pread(&content, 0);
return std.mem.eql(u8, &buf, AmazonEC2); return std.mem.eql(u8, content[0..n_read], AmazonEC2);
} }
pub fn load() !*const pjrt.Api { pub fn load() !*const pjrt.Api {
@ -45,7 +45,7 @@ pub fn load() !*const pjrt.Api {
return error.Unavailable; return error.Unavailable;
} }
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator); var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator);
defer arena.deinit(); defer arena.deinit();
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse { var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {

View File

@ -37,7 +37,7 @@ pub fn load() !*const pjrt.Api {
return error.Unavailable; return error.Unavailable;
} }
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator); var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator);
defer arena.deinit(); defer arena.deinit();
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse { var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {

View File

@ -1,12 +1,12 @@
const builtin = @import("builtin");
const std = @import("std"); const std = @import("std");
const builtin = @import("builtin");
const asynk = @import("async"); const asynk = @import("async");
const pjrt = @import("pjrt");
const c = @import("c");
const stdx = @import("stdx");
const bazel_builtin = @import("bazel_builtin"); const bazel_builtin = @import("bazel_builtin");
const c = @import("c");
const pjrt = @import("pjrt");
const runfiles = @import("runfiles"); const runfiles = @import("runfiles");
const stdx = @import("stdx");
const log = std.log.scoped(.@"zml/runtime/tpu"); const log = std.log.scoped(.@"zml/runtime/tpu");
@ -25,10 +25,10 @@ fn isOnGCP() !bool {
var f = try asynk.File.open("/sys/devices/virtual/dmi/id/product_name", .{ .mode = .read_only }); var f = try asynk.File.open("/sys/devices/virtual/dmi/id/product_name", .{ .mode = .read_only });
defer f.close() catch {}; defer f.close() catch {};
var buf = [_]u8{0} ** GoogleComputeEngine.len; var content: [GoogleComputeEngine.len]u8 = undefined;
_ = try f.reader().readAll(&buf); const n_read = try f.pread(&content, 0);
return std.mem.eql(u8, &buf, GoogleComputeEngine); return std.mem.eql(u8, content[0..n_read], GoogleComputeEngine);
} }
pub fn load() !*const pjrt.Api { pub fn load() !*const pjrt.Api {
@ -42,7 +42,7 @@ pub fn load() !*const pjrt.Api {
return error.Unavailable; return error.Unavailable;
} }
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator); var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator);
defer arena.deinit(); defer arena.deinit();
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse { var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {

View File

@ -8,7 +8,6 @@ zig_library(
"flags.zig", "flags.zig",
"fmt.zig", "fmt.zig",
"fs.zig", "fs.zig",
"io.zig",
"json.zig", "json.zig",
"math.zig", "math.zig",
"meta.zig", "meta.zig",

View File

@ -1,4 +0,0 @@
const std = @import("std");
pub const BufferedAnyWriter = std.io.BufferedWriter(4096, std.io.AnyWriter);
pub const BufferedAnyReader = std.io.BufferedReader(4096, std.io.AnyReader);

View File

@ -4,7 +4,6 @@ pub const debug = @import("debug.zig");
pub const flags = @import("flags.zig"); pub const flags = @import("flags.zig");
pub const fmt = @import("fmt.zig"); pub const fmt = @import("fmt.zig");
pub const fs = @import("fs.zig"); pub const fs = @import("fs.zig");
pub const io = @import("io.zig");
pub const json = @import("json.zig"); pub const json = @import("json.zig");
pub const math = @import("math.zig"); pub const math = @import("math.zig");
pub const meta = @import("meta.zig"); pub const meta = @import("meta.zig");

View File

@ -95,7 +95,7 @@ pub fn serialize(ptr: anytype, arena: *c.upb_Arena) SerializeError![]const u8 {
pub fn parseEx(comptime UpbType: type, arena: *c.upb_Arena, data: []const u8, opts: ParseOptions) ParseError!*UpbType { pub fn parseEx(comptime UpbType: type, arena: *c.upb_Arena, data: []const u8, opts: ParseOptions) ParseError!*UpbType {
const obj = try new(UpbType, arena); const obj = try new(UpbType, arena);
return switch (c.upb_Decode(@ptrCast(@constCast(data)), data.len, @alignCast(@ptrCast(obj)), Minitable(UpbType), null, @bitCast(opts), arena)) { return switch (c.upb_Decode(@ptrCast(@constCast(data)), data.len, @ptrCast(@alignCast(obj)), Minitable(UpbType), null, @bitCast(opts), arena)) {
c.kUpb_DecodeStatus_Ok => obj, c.kUpb_DecodeStatus_Ok => obj,
c.kUpb_DecodeStatus_Malformed => ParseError.Malformed, c.kUpb_DecodeStatus_Malformed => ParseError.Malformed,
c.kUpb_DecodeStatus_OutOfMemory => std.mem.Allocator.Error.OutOfMemory, c.kUpb_DecodeStatus_OutOfMemory => std.mem.Allocator.Error.OutOfMemory,

View File

@ -24,6 +24,11 @@ zig_library(
"aio/json.zig", "aio/json.zig",
"aio/safetensors.zig", "aio/safetensors.zig",
"aio/tinyllama.zig", "aio/tinyllama.zig",
"aio/torch.zig",
"aio/torch/eval.zig",
"aio/torch/file.zig",
"aio/torch/pickle.zig",
"aio/torch/py.zig",
"buffer.zig", "buffer.zig",
"context.zig", "context.zig",
"dtype.zig", "dtype.zig",
@ -72,6 +77,10 @@ zig_library(
zig_test( zig_test(
name = "test", name = "test",
data = [
"aio/torch/simple.pt",
"aio/torch/simple_test_4.pickle",
],
test_runner = ":test_runner", test_runner = ":test_runner",
deps = [":zml"], deps = [":zml"],
) )

View File

@ -5,16 +5,16 @@ const c = @import("c");
const stdx = @import("stdx"); const stdx = @import("stdx");
pub const safetensors = @import("aio/safetensors.zig"); pub const safetensors = @import("aio/safetensors.zig");
pub const tinyllama = @import("aio/tinyllama.zig"); pub const torch = @import("aio/torch.zig");
const HostBuffer = @import("hostbuffer.zig").HostBuffer; const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const posix = @import("posix.zig"); const posix = @import("posix.zig");
const zml = @import("zml.zig"); const zml = @import("zml.zig");
pub const log = std.log.scoped(.@"zml/aio"); pub const log = std.log.scoped(.@"zml/aio");
test { test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(@This());
std.testing.refAllDecls(safetensors); std.testing.refAllDecls(safetensors);
std.testing.refAllDecls(torch);
} }
// TODO error set for weight loading // TODO error set for weight loading
@ -25,6 +25,12 @@ pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8)
try safetensors.open(allocator, model_path) try safetensors.open(allocator, model_path)
else if (std.mem.endsWith(u8, model_path, ".safetensors.index.json")) else if (std.mem.endsWith(u8, model_path, ".safetensors.index.json"))
try safetensors.open(allocator, model_path) try safetensors.open(allocator, model_path)
else if (std.mem.endsWith(u8, model_path, ".pt"))
try torch.open(allocator, model_path)
// else if (std.mem.endsWith(u8, model_path, ".gguf"))
// try gguf.open(allocator, model_path)
// else if (std.mem.endsWith(u8, model_path, ".tinyllama"))
// try tinyllama.open(allocator, model_path)
else { else {
std.debug.panic("File extension not recognized: {s}", .{model_path}); std.debug.panic("File extension not recognized: {s}", .{model_path});
}; };
@ -384,7 +390,7 @@ fn _populateStruct(
partial_struct = partial_struct or field_found; partial_struct = partial_struct or field_found;
if (!field_found) { if (!field_found) {
if (field.default_value_ptr) |v| { if (field.default_value_ptr) |v| {
@field(obj, field.name) = @as(*const field.type, @alignCast(@ptrCast(v))).*; @field(obj, field.name) = @as(*const field.type, @ptrCast(@alignCast(v))).*;
} else { } else {
if (partial_struct) { if (partial_struct) {
log.warn("Incomplete metadata '{0s}': {1s}. Missing field: '{2s}'. '{0s}' will be ignored.", .{ prefix, @typeName(T), field.name }); log.warn("Incomplete metadata '{0s}': {1s}. Missing field: '{2s}'. '{0s}' will be ignored.", .{ prefix, @typeName(T), field.name });

View File

@ -1,83 +0,0 @@
const asynk = @import("async");
const core = @import("gguf/core.zig");
const std = @import("std");
const zml = @import("../zml.zig");
const HostBuffer = @import("../hostbuffer.zig").HostBuffer;
const Allocator = std.mem.Allocator;
const assert = std.debug.assert;
const log = std.log.scoped(.@"zml/io");
/// Opens the GGUF file at `path` and returns a `BufferStore` describing its
/// metadata and tensors.
///
/// The store owns an arena created from `allocator`; everything loaded here is
/// allocated from that arena. On failure both the arena and the opened file
/// are released before the error is returned.
pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
    var gguf = try core.GgufFile.open(path);
    errdefer gguf.close();

    var store: zml.aio.BufferStore = .{
        .arena = std.heap.ArenaAllocator.init(allocator),
    };
    errdefer store.arena.deinit();

    const store_allocator = store.arena.allocator();
    const mapped_files = [_]zml.aio.MemoryMappedFile{gguf.file};
    store.files = try store_allocator.dupe(zml.aio.MemoryMappedFile, &mapped_files);

    // Tensor entries come after the metadata section, so metadata is consumed first.
    try loadMetadata(store_allocator, &store, &gguf);
    try loadBuffers(store_allocator, &store, &gguf);

    const expected_tensors = gguf.header.tensor_count;
    const loaded_tensors = store.buffers.count();
    if (loaded_tensors != expected_tensors) {
        log.warn("Expected to find {d} tensors in {s}, only found {d}", .{ expected_tensors, path, loaded_tensors });
    }

    return store;
}
/// Reads every metadata key/value entry of `file` into `store._metadata`.
///
/// Capacity is reserved up front from the header's `metadata_kv_count`, so the
/// insertions below use the `AssumeCapacity` variant. Iteration stops when
/// `readMetadata` returns `error.EndOfMetadata`; any other error is propagated.
fn loadMetadata(allocator: Allocator, store: *zml.aio.BufferStore, file: *core.GgufFile) !void {
try store._metadata.ensureTotalCapacity(allocator, @intCast(file.header.metadata_kv_count));
while (file.readMetadata(allocator)) |entry| {
log.info("Loading MetaData: {s}", .{entry.name});
const res = store._metadata.getOrPutAssumeCapacity(entry.name);
if (res.found_existing) {
// This file seems invalid. Since most metadata entries aren't required, keep the
// first value seen for this key and continue.
log.warn("Found duplicated metadata key: {s}", .{entry.name});
continue;
}
res.value_ptr.* = switch (entry.val) {
// Arrays arrive as raw bytes plus an element tag; reinterpret the bytes as a
// slice of the matching Zig type and copy it into the store.
.array => |arr| switch (arr.child) {
inline .uint8, .int8, .uint16, .int16, .uint32, .int32, .float32, .bool, .string, .uint64, .int64, .float64 => |tag| blk: {
// `tag` is comptime-known here, so the element type can be looked up by field name.
const T = @FieldType(core.GgufValue, @tagName(tag));
break :blk try zml.aio.Metadata.copySlice(allocator, std.mem.bytesAsSlice(T, arr.data));
},
// Remaining child tag (an array of arrays): stored as `.null`.
else => blk: {
log.warn("ignoring array metadata", .{});
break :blk .null;
},
},
// Scalars and strings are wrapped directly.
inline else => |v| zml.aio.Metadata.wrap(v),
};
} else |err| switch (err) {
// Normal termination: all key/value pairs have been consumed.
error.EndOfMetadata => {},
else => return err,
}
}
/// Registers one `HostBuffer` per tensor entry of `file` in `store.buffers`.
///
/// Each buffer is built from a slice of the memory-mapped file rather than a
/// copy. Duplicated tensor names are reported and skipped. A tensor whose type
/// has no `zml.DataType` equivalent aborts the load with
/// `error.UnsupportedGgufType`.
fn loadBuffers(allocator: Allocator, store: *zml.aio.BufferStore, file: *core.GgufFile) !void {
    try store.buffers.ensureTotalCapacity(allocator, @intCast(file.header.tensor_count));

    while (file.readTensorInfo(allocator)) |tensor| {
        const slot = store.buffers.getOrPutAssumeCapacity(tensor.name);
        if (slot.found_existing) {
            // This file seems invalid. Try to continue anyway.
            log.warn("Found duplicated tensor: {s}", .{tensor.name});
            continue;
        }

        // TODO: handle quantized types
        const dtype: zml.DataType = tensor.t.toDtype() orelse return error.UnsupportedGgufType;
        const tensor_shape = zml.Shape.init(tensor.shape(), dtype);
        const tensor_bytes = file.file.mappedSlice(tensor.start, tensor.byte_len);
        slot.value_ptr.* = HostBuffer.fromBytes(tensor_shape, tensor_bytes);
    } else |err| switch (err) {
        // All tensor entries have been consumed.
        error.EndOfMetadata => {},
        else => return err,
    }
}

View File

@ -1,505 +0,0 @@
const asynk = @import("async");
const std = @import("std");
const zml = @import("../../zml.zig");
const assert = std.debug.assert;
const log = std.log.scoped(.@"zml/io");
/// Errors that the GGUF reader in this file can return.
pub const GgufErrors = error{
// A requested Zig type does not match the stored value type (see `arrayTypeCheck`).
ValueTypeMismatch,
// The file is malformed: wrong magic, truncated data, or a short read.
InvalidGguf,
// A tensor or value type id is not handled by this reader.
UnsupportedGgufType,
// Sentinel returned by the iterators once every entry has been read.
EndOfMetadata,
OutOfMemory,
};
// Enums and structures
/// Tensor type id as stored (as a `u32`) in each tensor entry of a GGUF file.
pub const TensorType = enum(u32) {
f32 = 0,
f16 = 1,
q4_0 = 2,
q4_1 = 3,
deprecated_q4_2 = 4,
deprecated_q4_3 = 5,
q5_0 = 6,
q5_1 = 7,
q8_0 = 8,
q8_1 = 9,
// k-quantizations
q2_k = 10,
q3_k = 11,
q4_k = 12,
q5_k = 13,
q6_k = 14,
q8_k = 15,
i8 = 16,
i16 = 17,
i32 = 18,
// Highest id declared above; `GgufFile.readTensorType` rejects anything larger.
const MAX_KNOWN_ENUM = 18;
/// Returns true for `q8_0`, `q4_k`, `q6_k`, `q2_k`, `q4_0` and `q4_1`, false otherwise.
/// NOTE(review): presumably the quantized formats a converter exists for — confirm with callers.
pub fn canConvertQuant(self: TensorType) bool {
return switch (self) {
.q8_0, .q4_k, .q6_k, .q2_k, .q4_0, .q4_1 => true,
else => false,
};
}
/// Maps the plain (non-quantized) types to the matching `zml.DataType`.
/// Returns null for every quantized or deprecated type.
pub fn toDtype(self: TensorType) ?zml.DataType {
return switch (self) {
.f32 => .f32,
.f16 => .f16,
.i8 => .i8,
.i16 => .i16,
.i32 => .i32,
else => null,
};
}
/// Size of the matching `zml.DataType`.
/// Only valid for types where `toDtype` is non-null: the `.?` unwrap fails otherwise.
pub fn sizeOf(self: TensorType) usize {
return self.toDtype().?.sizeOf();
}
/// Return the tensor type features
/// (the `TENSOR_TYPE_FEATURES` entry whose field name equals this tag's name).
pub fn getFeatures(t: TensorType) TensorTypeFeatures {
return switch (t) {
inline else => |val| @field(TENSOR_TYPE_FEATURES, @tagName(val)),
};
}
};
/// Block layout of a GGUF tensor type: how many weights are packed in one
/// block and how many bytes that block occupies in the file.
/// One value per `TensorType` is stored in `TENSOR_TYPE_FEATURES`.
pub const TensorTypeFeatures = struct {
items_per_block: u29,
bytes_per_block: u29,
/// Returns the base-2 logarithm (rounded down) of `bytes_per_block`.
/// NOTE(review): `std.math.log2_int` requires a non-zero argument, so this must not
/// be called for the deprecated types whose `bytes_per_block` is 0 — confirm callers.
pub fn alignment(features: TensorTypeFeatures) u8 {
return std.math.log2_int(u29, features.bytes_per_block);
}
};
/// GGUF tensor type to features lookup table.
/// Holds one field per `TensorType` tag; `TensorType.getFeatures` reads it by tag name.
pub const TENSOR_TYPE_FEATURES: std.enums.EnumFieldStruct(TensorType, TensorTypeFeatures, null) = .{
// Plain types: one item per block, block size is the element size.
.f32 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(f32) },
.f16 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(f16) },
.q4_0 = .{ .items_per_block = 32, .bytes_per_block = 18 },
.q4_1 = .{ .items_per_block = 32, .bytes_per_block = 20 },
// Deprecated types carry no layout: both values are 0, so they cannot be used to size a tensor.
.deprecated_q4_2 = .{ .items_per_block = 0, .bytes_per_block = 0 },
.deprecated_q4_3 = .{ .items_per_block = 0, .bytes_per_block = 0 },
.q5_0 = .{ .items_per_block = 32, .bytes_per_block = 22 },
.q5_1 = .{ .items_per_block = 32, .bytes_per_block = 24 },
.q8_0 = .{ .items_per_block = 32, .bytes_per_block = 34 },
.q8_1 = .{ .items_per_block = 32, .bytes_per_block = 40 },
// k-quantizations: 256 items per block.
.q2_k = .{ .items_per_block = 256, .bytes_per_block = 82 },
.q3_k = .{ .items_per_block = 256, .bytes_per_block = 110 },
.q4_k = .{ .items_per_block = 256, .bytes_per_block = 144 },
.q5_k = .{ .items_per_block = 256, .bytes_per_block = 176 },
.q6_k = .{ .items_per_block = 256, .bytes_per_block = 210 },
.q8_k = .{ .items_per_block = 256, .bytes_per_block = 292 },
.i8 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(i8) },
.i16 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(i16) },
.i32 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(i32) },
};
/// Metadata value type id as read (as a `u32`) from the file.
/// Non-exhaustive: ids that are not declared below can still be represented.
pub const GgufValueType = enum(u32) {
// The value is an 8-bit unsigned integer.
uint8 = 0,
// The value is an 8-bit signed integer.
int8 = 1,
// The value is a 16-bit unsigned little-endian integer.
uint16 = 2,
// The value is a 16-bit signed little-endian integer.
int16 = 3,
// The value is a 32-bit unsigned little-endian integer.
uint32 = 4,
// The value is a 32-bit signed little-endian integer.
int32 = 5,
// The value is a 32-bit IEEE754 floating point number.
float32 = 6,
// The value is a boolean.
// 1-byte value where 0 is false and 1 is true.
// Anything else is invalid, and should be treated as either the model
// being invalid or the reader being buggy.
bool = 7,
// The value is a UTF-8 non-null-terminated string, with length prepended.
string = 8,
// The value is an array of other values, with the length and type
// prepended. Arrays can be nested, and the length of the array is the
// number of elements in the array, not the number of bytes.
array = 9,
// The value is a 64-bit unsigned little-endian integer.
uint64 = 10,
// The value is a 64-bit signed little-endian integer.
int64 = 11,
// The value is a 64-bit IEEE754 floating point number.
float64 = 12,
// Special values used by the callbacks of gguf_do_with_value().
array_start = 100,
array_end = 101,
// Allow other values in case GGUF adds more types without us noticing.
_,
/// Size in bytes of the Zig type used to hold one value of this type.
/// For `.string` this is the size of a `[]u8` slice, not the string length.
/// Every other tag (`.array`, `.array_start`, `.array_end`, unnamed ids) is `unreachable`.
pub fn sizeOf(self: GgufValueType) usize {
return switch (self) {
.uint8 => @sizeOf(u8),
.int8 => @sizeOf(i8),
.uint16 => @sizeOf(u16),
.int16 => @sizeOf(i16),
.uint32 => @sizeOf(u32),
.int32 => @sizeOf(i32),
.float32 => @sizeOf(f32),
.bool => @sizeOf(bool),
.uint64 => @sizeOf(u64),
.int64 => @sizeOf(i64),
.float64 => @sizeOf(f64),
.string => @sizeOf([]u8),
else => unreachable,
};
}
/// Returns `error.ValueTypeMismatch` when `T` is not the Zig type that holds
/// values of this type. Tags without a case below accept any `T`.
pub fn arrayTypeCheck(self: GgufValueType, comptime T: type) !void {
switch (self) {
.string => if (T != []u8 and T != []const u8) return error.ValueTypeMismatch,
.uint8 => if (T != u8) return error.ValueTypeMismatch,
.int8 => if (T != i8) return error.ValueTypeMismatch,
.uint16 => if (T != u16) return error.ValueTypeMismatch,
.int16 => if (T != i16) return error.ValueTypeMismatch,
.uint32 => if (T != u32) return error.ValueTypeMismatch,
.int32 => if (T != i32) return error.ValueTypeMismatch,
.float32 => if (T != f32) return error.ValueTypeMismatch,
.bool => if (T != bool) return error.ValueTypeMismatch,
.uint64 => if (T != u64) return error.ValueTypeMismatch,
.int64 => if (T != i64) return error.ValueTypeMismatch,
.float64 => if (T != f64) return error.ValueTypeMismatch,
else => {},
}
}
};
/// Exhaustive tag for `GgufValue`.
/// Uses the same numeric ids as the first thirteen `GgufValueType` tags (0 through 12),
/// which is what lets `GgufFile.readArrayHeader` convert between the two by integer value.
pub const ValueType = enum(u8) {
uint8 = 0,
int8 = 1,
uint16 = 2,
int16 = 3,
uint32 = 4,
int32 = 5,
float32 = 6,
bool = 7,
string = 8,
array = 9,
uint64 = 10,
int64 = 11,
float64 = 12,
};
// Union of possible values.
/// A decoded metadata value, tagged by `ValueType`.
pub const GgufValue = union(ValueType) {
uint8: u8,
int8: i8,
uint16: u16,
int16: i16,
uint32: u32,
int32: i32,
float32: f32,
bool: bool,
string: []const u8,
array: Array,
uint64: u64,
int64: i64,
float64: f64,
/// An array value kept as raw bytes together with its element tag.
pub const Array = struct {
// Any value type is valid, including arrays.
child: ValueType,
// Number of elements, not bytes
len: usize,
// Backing bytes of the elements; for `.string` children these are the bytes of a
// slice of `[]u8` (see `GgufFile.readArrayHeader`).
data: []u8,
};
};
// Header
/// Fixed-size header at the start of a GGUF file.
/// Declared `extern` so it can be read from the file as a raw struct (see `GgufFile.open`).
const GgufHeader = extern struct {
// Magic number to announce that this is a GGUF file. Must be `GGUF`.
magic: [4]u8,
// The version of the format implemented.
// Must be `3` for version described in this spec.
// NOTE(review): `validate` below does not check this field.
version: u32,
// The number of tensors in the file.
// This is explicit, instead of being included in the metadata, to ensure
// it is always present for loading the tensors.
tensor_count: usize,
// The number of metadata key-value pairs.
metadata_kv_count: usize,
/// Returns `error.InvalidGguf` (after logging the bytes found) unless `magic` is "GGUF".
pub fn validate(self: GgufHeader) !void {
if (!std.mem.eql(u8, &self.magic, "GGUF")) {
log.err("Invalid GGUF file: wrong header {s}", .{self.magic});
return error.InvalidGguf;
}
}
};
// Key representation in this library API.
/// One metadata entry as returned by `GgufFile.readMetadata`.
pub const GgufMetadataKv = struct {
// Key of the entry.
name: []const u8,
// Value type id exactly as read from the file.
type_: GgufValueType,
// Decoded value.
val: GgufValue,
};
// Tensor representation in this library API.
const GGUF_TENSOR_MAX_DIM: usize = 8; // Future-proof: actual limit is 4.
/// Description of one tensor as returned by `GgufFile.readTensorInfo`.
pub const GgufTensorInfo = struct {
name: []const u8,
t: TensorType, // Tensor type (enum TensorType).
rank: usize, // Number of dimensions of the tensor.
// Only the first `rank` entries are meaningful; the rest are left undefined by the reader.
dims: [GGUF_TENSOR_MAX_DIM]i64, // Dimensions (Eg. [512, 1024, 1, 1]).
start: usize, // Offset from start of data section.
byte_len: usize, // Total size in bytes.
num_weights: usize, // Total number of parameters.
/// The valid prefix of `dims`, i.e. the first `rank` entries.
pub inline fn shape(info: GgufTensorInfo) []const i64 {
return info.dims[0..info.rank];
}
};
/// Returns the human-readable name of the value type with id `t`,
/// or "unknown" when `t` has no entry in `GGUF_VALUE_NAME`.
fn getValueTypeName(t: u32) []const u8 {
    const index: usize = t;
    return if (index < GGUF_VALUE_NAME.len) GGUF_VALUE_NAME[index] else "unknown";
}

/// Value type names, indexed by value type id (0 through 12).
const GGUF_VALUE_NAME = [_][]const u8{
    "uint8",
    "int8",
    "uint16",
    "int16",
    "uint32",
    "int32",
    "float32",
    "bool",
    "string",
    "array",
    "uint64",
    "int64",
    "float64",
};
/// GGUF file API
/// A memory-mapped view of a .gguf file.
/// Format used by GGML models: https://github.com/ggerganov/ggml/
///
/// The reader is sequential: metadata entries are read first with
/// `readMetadata`, then tensor entries with `readTensorInfo`. Both return
/// `error.EndOfMetadata` once their section is exhausted.
///
/// File contents are untrusted input: malformed sizes and counts are reported
/// as `error.InvalidGguf` / `error.UnsupportedGgufType` instead of asserting.
///
/// NOTE(review): the readers allocate names and values from the given
/// allocator and do not free them one by one on later errors; the loader in
/// this project passes an arena — confirm before using another allocator.
pub const GgufFile = struct {
    header: GgufHeader, // GGUF file header info.
    size: usize, // Total file size.
    file: zml.aio.MemoryMappedFile,
    left_kv: usize, // Number of key-value pairs yet to read.
    left_tensors: usize, // Number of tensors yet to read.
    off: usize, // Offset of the next item to parse.
    alignment: usize = 32, // File data alignment. Default: 32 bytes.

    /// Open and memmap the given file.
    /// Returns `error.InvalidGguf` when the magic number is wrong.
    pub fn open(path: []const u8) !GgufFile {
        const file = try asynk.File.open(path, .{});
        // Nothing owns the handle until the GgufFile is returned, so release it
        // if reading or validating the header fails. A close failure is ignored
        // here because the original error is the one worth reporting.
        errdefer file.close() catch {};

        const header = try file.reader().readStruct(GgufHeader);
        try header.validate();
        return .{
            .header = header,
            .size = (try file.stat()).size,
            .file = try zml.aio.MemoryMappedFile.init(file),
            .off = @sizeOf(GgufHeader),
            .left_kv = header.metadata_kv_count,
            .left_tensors = header.tensor_count,
        };
    }

    /// Unmap the file memory and close the file handle.
    pub fn close(self: *GgufFile) void {
        self.file.deinit();
    }

    /// Set the context to read the first key-value entry in the GGUF
    /// file and then all the rest. Is used when creating a new context
    /// and also when you want to restart scanning the key-value
    /// items in the file.
    fn rewind(ctx: *GgufFile) void {
        ctx.off = @sizeOf(GgufHeader);
        ctx.left_kv = ctx.header.metadata_kv_count;
        ctx.left_tensors = ctx.header.tensor_count;
    }

    /// Sets the logical parse offset. Only updates `off`; the underlying
    /// file position is not moved here.
    pub fn seek(self: *GgufFile, pos: usize) void {
        assert(pos < self.size);
        self.off = pos;
    }

    /// Reads one little-endian integer of type `T` and advances `off`.
    /// Returns `error.InvalidGguf` when the value would extend past the end of the file.
    fn readInt(self: *GgufFile, comptime T: type) !T {
        // A value may end exactly on the last byte of the file, so only a read
        // that goes *past* the end is rejected.
        if (self.off + @sizeOf(T) > self.size) return error.InvalidGguf;
        // Advance the offset only once the read has succeeded.
        const res = try self.file.file.reader().readInt(T, .little);
        self.off += @sizeOf(T);
        return res;
    }

    /// Reads a tensor type id; ids above `TensorType.MAX_KNOWN_ENUM` are rejected.
    fn readTensorType(self: *GgufFile) !TensorType {
        const raw = try self.readInt(u32);
        if (raw > TensorType.MAX_KNOWN_ENUM) {
            log.err("Unsupported GGUF tensor type: {d}", .{raw});
            return error.UnsupportedGgufType;
        }
        return @enumFromInt(raw);
    }

    /// Reads a value type id; ids that are not declared in `GgufValueType` are rejected.
    fn readValueType(self: *GgufFile) !GgufValueType {
        const raw = try self.readInt(u32);
        const t: GgufValueType = @enumFromInt(raw);
        switch (t) {
            .uint8, .int8, .uint16, .int16, .uint32, .int32, .float32, .bool, .string, .array, .uint64, .int64, .float64, .array_start, .array_end => {},
            else => {
                // `t` has no declared name on this path, and calling @tagName on an
                // unnamed tag is illegal behavior, so report the raw id instead.
                log.err("Unsupported GGUF value type: {d}", .{raw});
                return error.UnsupportedGgufType;
            },
        }
        return t;
    }

    /// Allocates `len` bytes and fills them from the file.
    /// Caller owns the returned slice. A short read yields `error.InvalidGguf`.
    pub fn readAlloc(self: *GgufFile, allocator: std.mem.Allocator, len: usize) ![]u8 {
        const data = try allocator.alloc(u8, len);
        errdefer allocator.free(data);
        const read = try self.file.file.reader().readAll(data);
        if (read != data.len) return error.InvalidGguf;
        self.off += len;
        return data;
    }

    /// Moves the file position forward by `len` bytes without reading them.
    pub fn skipBytes(self: *GgufFile, len: usize) !void {
        try self.file.file.seekBy(@intCast(len));
        self.off += len;
    }

    /// Read the len then the actual bytes.
    pub fn readString(self: *GgufFile, allocator: std.mem.Allocator) ![]u8 {
        const len: usize = try self.readInt(u64);
        return self.readAlloc(allocator, len);
    }

    /// Read the len of a string, then skip over its bytes.
    pub fn skipString(self: *GgufFile) !void {
        const len: usize = try self.readInt(u64);
        return self.skipBytes(len);
    }

    /// Reads an array value: element type, element count, then the elements.
    /// Arrays of arrays are not supported and yield `error.UnsupportedGgufType`.
    fn readArrayHeader(self: *GgufFile, allocator: std.mem.Allocator) !GgufValue.Array {
        const child = try self.readValueType();
        if (@intFromEnum(child) > @intFromEnum(ValueType.float64)) {
            return error.UnsupportedGgufType;
        }
        const len: usize = try self.readInt(u64);
        const data = switch (child) {
            // Since strings have variable lengths, we need to read them one by one
            .string => str: {
                const strings = try allocator.alloc([]u8, len);
                for (strings) |*s| s.* = try self.readString(allocator);
                break :str std.mem.sliceAsBytes(strings);
            },
            // A nested array has no fixed element size, so it cannot be read as a
            // flat byte run; reject it instead of reaching `unreachable` in `sizeOf`.
            .array => return error.UnsupportedGgufType,
            else => fixed: {
                // `len` comes from the file: guard the byte count against overflow.
                const byte_len = std.math.mul(usize, len, child.sizeOf()) catch return error.InvalidGguf;
                break :fixed try self.readAlloc(allocator, byte_len);
            },
        };
        return .{
            .child = @enumFromInt(@intFromEnum(child)),
            .len = len,
            .data = data,
        };
    }

    /// Reads one value of type `t`. Floats are read as same-width integers and bit-cast.
    fn readTypedValue(self: *GgufFile, allocator: std.mem.Allocator, t: GgufValueType) !GgufValue {
        return switch (t) {
            .uint8 => .{ .uint8 = try self.readInt(u8) },
            .int8 => .{ .int8 = try self.readInt(i8) },
            .uint16 => .{ .uint16 = try self.readInt(u16) },
            .int16 => .{ .int16 = try self.readInt(i16) },
            .uint32 => .{ .uint32 = try self.readInt(u32) },
            .int32 => .{ .int32 = try self.readInt(i32) },
            .float32 => .{ .float32 = @bitCast(try self.readInt(u32)) },
            .bool => .{ .bool = try self.readInt(u8) != 0 },
            .string => .{ .string = try self.readString(allocator) },
            .array => .{ .array = try self.readArrayHeader(allocator) },
            .uint64 => .{ .uint64 = try self.readInt(u64) },
            .int64 => .{ .int64 = try self.readInt(i64) },
            .float64 => .{ .float64 = @bitCast(try self.readInt(u64)) },
            else => error.UnsupportedGgufType,
        };
    }

    /// Parses the next metadata entry.
    /// Returns error.EndOfMetadata if there are no longer metadata to process in this GGUF file.
    pub fn readMetadata(self: *GgufFile, allocator: std.mem.Allocator) !GgufMetadataKv {
        if (self.left_kv == 0) return error.EndOfMetadata;
        self.left_kv -= 1;
        const name = try self.readString(allocator);
        const type_ = try self.readValueType();
        const val: GgufValue = try self.readTypedValue(allocator, type_);
        return .{ .name = name, .type_ = type_, .val = val };
    }

    // Set the data section offset. This function must be called exactly when
    // all the key-values are consumed, in the context of the first call of
    // ctx.getTensor(): this way we will be able to return tensor offsets
    // as absolute positions and pointers to the mmapped file.
    fn setDataOffset(self: *GgufFile) !void {
        const base_off = self.off;
        assert(self.left_kv == 0 and self.left_tensors == self.header.tensor_count);
        // Walk over every tensor entry to find where the data section begins,
        // then restore the position so the entries can be parsed for real.
        for (0..self.left_tensors) |_| try self.skipTensor();
        const padding: usize = getAlignmentPadding(self.alignment, self.off);
        self.file.data_offset = self.off + padding;
        try self.file.file.seekTo(base_off);
        self.off = base_off;
    }

    /// Skips one tensor entry without decoding it.
    pub fn skipTensor(self: *GgufFile) !void {
        try self.skipString(); // Skip name
        // Widen before doing arithmetic so a bogus rank cannot overflow a u32.
        const num_dim: usize = try self.readInt(u32);
        if (num_dim > GGUF_TENSOR_MAX_DIM) return error.InvalidGguf;
        // dimensions, type, and offset.
        try self.skipBytes(8 * num_dim + 4 + 8);
    }

    /// Parses the next tensor entry.
    /// Returns error.EndOfMetadata if there are no longer tensor metadata to process in this GGUF file.
    /// Returns error.InvalidGguf for an out-of-range rank or sizes that overflow,
    /// and error.UnsupportedGgufType for tensor types without a block layout.
    pub fn readTensorInfo(self: *GgufFile, allocator: std.mem.Allocator) !GgufTensorInfo {
        if (self.left_tensors == 0 or self.left_kv != 0) {
            return error.EndOfMetadata;
        }

        // We want to return tensor data with offsets relative to the start
        // of the file, so that the user of the API is able to access tensors
        // as it iterates over them. To do so, we need to perform a full
        // scan if this is the first tensor info we are reading.
        // TODO: explicitly set the data offset in
        if (self.file.data_offset == 0) try self.setDataOffset();

        self.left_tensors -= 1;
        const name = try self.readString(allocator);

        // The rank comes from the file; an assert would vanish in release builds
        // and let the loop below write outside `dims`.
        const rank: usize = try self.readInt(u32);
        if (rank > GGUF_TENSOR_MAX_DIM) return error.InvalidGguf;

        // Read the dimensions; unused dimensions are left `undefined`.
        // Note: we reverse the order of the dimensions to match zml convention.
        var dims: [GGUF_TENSOR_MAX_DIM]i64 = undefined;
        var num_weights: usize = 1;
        for (0..rank) |j| {
            const d = try self.readInt(u64);
            dims[rank - 1 - j] = std.math.cast(i64, d) orelse return error.InvalidGguf;
            const d_usize = std.math.cast(usize, d) orelse return error.InvalidGguf;
            num_weights = std.math.mul(usize, num_weights, d_usize) catch return error.InvalidGguf;
        }

        const t: TensorType = try self.readTensorType();
        const start = try self.readInt(u64);

        // To accurately calculate the bytes used by this tensor on the GGUF
        // file, we need to take into account that quantization methods store
        // tensors as block of N weights. So first of all we need to understand
        // the number of padding weights (since the last block may have just
        // fewer weights stored inside, but still requires to be stored to its full
        // length). Then we can do the math to see how many blocks we need, and
        // multiply by the block size to obtain the final total size.
        const tf = t.getFeatures();
        // Deprecated types have `items_per_block == 0`: they cannot be sized.
        const num_blocks = std.math.divCeil(usize, num_weights, tf.items_per_block) catch return error.UnsupportedGgufType;
        const byte_len: usize = std.math.mul(usize, num_blocks, tf.bytes_per_block) catch return error.InvalidGguf;

        return .{
            .name = name,
            .t = t,
            .rank = rank,
            .dims = dims,
            .start = start,
            .byte_len = byte_len,
            .num_weights = num_weights,
        };
    }
};
/// Returns how many padding bytes must follow `offset` (an offset or a length)
/// so that the result is a multiple of `alignment`. Already-aligned values need none.
fn getAlignmentPadding(alignment: usize, offset: usize) usize {
    const misalignment = offset % alignment;
    if (misalignment == 0) return 0;
    return alignment - misalignment;
}

View File

@ -1,9 +1,9 @@
const asynk = @import("async");
const std = @import("std"); const std = @import("std");
const zml = @import("../zml.zig");
const asynk = @import("async");
const zml = @import("../zml.zig");
const eval = @import("torch/eval.zig"); const eval = @import("torch/eval.zig");
const py = @import("torch/py.zig");
const File = @import("torch/file.zig").File; const File = @import("torch/file.zig").File;
const StringBuilder = std.ArrayListUnmanaged(u8); const StringBuilder = std.ArrayListUnmanaged(u8);
@ -12,7 +12,7 @@ const log = std.log.scoped(.@"zml/aio");
test { test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(@This());
std.testing.refAllDecls(eval); std.testing.refAllDecls(eval);
std.testing.refAllDecls(py); std.testing.refAllDecls(@import("torch/py.zig"));
std.testing.refAllDecls(File); std.testing.refAllDecls(File);
} }
@ -30,13 +30,13 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
const tmp_alloc = arena.allocator(); const tmp_alloc = arena.allocator();
const mmap_file = try zml.aio.MemoryMappedFile.init(file); const mmap_file = try zml.aio.MemoryMappedFile.init(file);
var torch_file = try File.init(tmp_alloc, mmap_file); var torch_file = try asynk.callBlocking(File.init, .{ tmp_alloc, mmap_file });
const ops = try torch_file.parsePickle(tmp_alloc); const ops = try torch_file.parsePickle(tmp_alloc);
const py_values = try eval.evaluate(tmp_alloc, ops, true); const py_values = try eval.evaluate(tmp_alloc, ops, true);
// file ownership is transferred to the BufferStore // file ownership is transferred to the BufferStore
var res = try zml.aio.BufferStore.init(allocator, &.{torch_file.buffer_file}); var res = try zml.aio.BufferStore.init(allocator, &.{torch_file.mmap_file});
try torch_file.parseModel(py_values, &res); try torch_file.parseModel(py_values, &res);
return res; return res;
} }

View File

@ -151,23 +151,23 @@ pub const PickleMemo = struct {
}; };
pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) ![]const py.Any { pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) ![]const py.Any {
var stack = std.ArrayList(py.Any).init(arena); var stack: std.ArrayList(py.Any) = .{};
var memo = PickleMemo.init(arena); var memo = PickleMemo.init(arena);
for (x) |op| { for (x) |op| {
switch (op) { switch (op) {
.mark => try stack.append(.{ .raw = op }), .mark => try stack.append(arena, .{ .raw = op }),
.frame => {}, .frame => {},
.stop => break, .stop => break,
.pop => _ = try pop(&stack), .pop => _ = try pop(&stack),
.pop_mark => _ = try popMark(&stack), .pop_mark => _ = try popMark(&stack),
.dup => if (stack.getLastOrNull()) |item| .dup => if (stack.getLastOrNull()) |item|
try stack.append(try item.clone(arena)) try stack.append(arena, try item.clone(arena))
else else
return error.CannotDupEmptyStack, return error.CannotDupEmptyStack,
.persid => |v| try stack.append(.{ .pers_id = try py.PersId.init(arena, .{ .string = try arena.dupe(u8, v) }) }), .persid => |v| try stack.append(arena, .{ .pers_id = try py.PersId.init(arena, .{ .string = try arena.dupe(u8, v) }) }),
.binpersid => try stack.append(.{ .pers_id = try py.PersId.init(arena, try pop(&stack)) }), .binpersid => try stack.append(arena, .{ .pers_id = try py.PersId.init(arena, try pop(&stack)) }),
.reduce => try stack.append(.{ .global = blk: { .reduce => try stack.append(arena, .{ .global = blk: {
var args = try pop(&stack); var args = try pop(&stack);
args = try memo.resolve(arena, args, true); args = try memo.resolve(arena, args, true);
if (args != .seq) return error.InvalidInput; if (args != .seq) return error.InvalidInput;
@ -175,23 +175,23 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
func = try memo.resolve(arena, func, true); func = try memo.resolve(arena, func, true);
break :blk try py.Object.init(arena, func, args.seq.values, &.{}); break :blk try py.Object.init(arena, func, args.seq.values, &.{});
} }), } }),
.build => try stack.append(blk: { .build => try stack.append(arena, blk: {
const args = try memo.resolve(arena, try pop(&stack), true); const args = try memo.resolve(arena, try pop(&stack), true);
const member = try memo.resolve(arena, try pop(&stack), true); const member = try memo.resolve(arena, try pop(&stack), true);
break :blk .{ .set_state = try py.SetState.init(arena, member, args) }; break :blk .{ .set_state = try py.SetState.init(arena, member, args) };
}), }),
.empty_dict => try stack.append(.{ .seq = .{ .type = .dict, .values = &[_]py.Any{} } }), .empty_dict => try stack.append(arena, .{ .seq = .{ .type = .dict, .values = &[_]py.Any{} } }),
.get => |v| try stack.append(.{ .ref = v }), .get => |v| try stack.append(arena, .{ .ref = v }),
.empty_list => try stack.append(.{ .seq = .{ .type = .list, .values = &[_]py.Any{} } }), .empty_list => try stack.append(arena, .{ .seq = .{ .type = .list, .values = &[_]py.Any{} } }),
.put => |v| { .put => |v| {
try memo.insert(v, try pop(&stack)); try memo.insert(v, try pop(&stack));
try stack.append(.{ .ref = v }); try stack.append(arena, .{ .ref = v });
}, },
.tuple => try stack.append(blk: { .tuple => try stack.append(arena, blk: {
const popped = try popMark(&stack); const popped = try popMark(&stack);
break :blk .{ .seq = .{ .type = .tuple, .values = try arena.dupe(py.Any, popped) } }; break :blk .{ .seq = .{ .type = .tuple, .values = try arena.dupe(py.Any, popped) } };
}), }),
.empty_tuple => try stack.append(.{ .seq = .{ .type = .tuple, .values = &[_]py.Any{} } }), .empty_tuple => try stack.append(arena, .{ .seq = .{ .type = .tuple, .values = &[_]py.Any{} } }),
.setitem => { .setitem => {
const v = try memo.resolve(arena, try pop(&stack), true); const v = try memo.resolve(arena, try pop(&stack), true);
const k = try memo.resolve(arena, try pop(&stack), true); const k = try memo.resolve(arena, try pop(&stack), true);
@ -228,17 +228,17 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
} }
}, },
.proto => |proto| stdx.debug.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}), .proto => |proto| stdx.debug.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}),
.tuple1 => try stack.append(blk: { .tuple1 => try stack.append(arena, blk: {
const tup_values = try arena.alloc(py.Any, 1); const tup_values = try arena.alloc(py.Any, 1);
tup_values[0] = try pop(&stack); tup_values[0] = try pop(&stack);
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}), }),
.tuple2 => try stack.append(blk: { .tuple2 => try stack.append(arena, blk: {
const tup_values = try arena.alloc(py.Any, 2); const tup_values = try arena.alloc(py.Any, 2);
inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack); inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack);
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}), }),
.tuple3 => try stack.append(blk: { .tuple3 => try stack.append(arena, blk: {
const tup_values = try arena.alloc(py.Any, 3); const tup_values = try arena.alloc(py.Any, 3);
inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack); inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack);
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
@ -276,32 +276,32 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
}, },
} }
}, },
.dict => try stack.append(.{ .seq = .{ .dict => try stack.append(arena, .{ .seq = .{
.type = .dict, .type = .dict,
.values = try arena.dupe(py.Any, try popMark(&stack)), .values = try arena.dupe(py.Any, try popMark(&stack)),
} }), } }),
.list => try stack.append(.{ .seq = .{ .list => try stack.append(arena, .{ .seq = .{
.type = .list, .type = .list,
.values = try arena.dupe(py.Any, try popMark(&stack)), .values = try arena.dupe(py.Any, try popMark(&stack)),
} }), } }),
.inst => |v| try stack.append(.{ .object = try py.Object.init( .inst => |v| try stack.append(arena, .{ .object = try py.Object.init(
arena, arena,
try py.tuple(&.{ .{ .string = v.module }, .{ .string = v.class } }).clone(arena), try py.tuple(&.{ .{ .string = v.module }, .{ .string = v.class } }).clone(arena),
try arena.dupe(py.Any, try popMark(&stack)), try arena.dupe(py.Any, try popMark(&stack)),
&.{}, &.{},
) }), ) }),
.obj => try stack.append(blk: { .obj => try stack.append(arena, blk: {
const mark = try findMark(&stack); const mark = try findMark(&stack);
const args = try arena.dupe(py.Any, stack.items[mark + 2 ..]); const args = try arena.dupe(py.Any, stack.items[mark + 2 ..]);
const member = stack.items[mark + 1]; const member = stack.items[mark + 1];
break :blk .{ .object = try py.Object.init(arena, member, args, &.{}) }; break :blk .{ .object = try py.Object.init(arena, member, args, &.{}) };
}), }),
.newobj => try stack.append(blk: { .newobj => try stack.append(arena, blk: {
const args = try arena.alloc(py.Any, 1); const args = try arena.alloc(py.Any, 1);
args[0] = try pop(&stack); args[0] = try pop(&stack);
break :blk .{ .object = try py.Object.init(arena, try pop(&stack), args, &.{}) }; break :blk .{ .object = try py.Object.init(arena, try pop(&stack), args, &.{}) };
}), }),
.empty_set => try stack.append(.{ .seq = .{ .type = .set, .values = &[_]py.Any{} } }), .empty_set => try stack.append(arena, .{ .seq = .{ .type = .set, .values = &[_]py.Any{} } }),
.additems => { .additems => {
const postmark = try popMark(&stack); const postmark = try popMark(&stack);
const top = try lastMut(&stack); const top = try lastMut(&stack);
@ -316,15 +316,15 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
}, },
} }
}, },
.frozenset => try stack.append(.{ .seq = .{ .frozenset => try stack.append(arena, .{ .seq = .{
.type = .frozen_set, .type = .frozen_set,
.values = try arena.dupe(py.Any, try popMark(&stack)), .values = try arena.dupe(py.Any, try popMark(&stack)),
} }), } }),
.newobj_ex => try stack.append(blk: { .newobj_ex => try stack.append(arena, blk: {
const kwargs, const args, const cls = .{ try pop(&stack), try pop(&stack), try pop(&stack) }; const kwargs, const args, const cls = .{ try pop(&stack), try pop(&stack), try pop(&stack) };
break :blk .{ .object = try py.Object.init(arena, cls, args.seq.values, kwargs.seq.values) }; break :blk .{ .object = try py.Object.init(arena, cls, args.seq.values, kwargs.seq.values) };
}), }),
.stack_global => try stack.append(blk: { .stack_global => try stack.append(arena, blk: {
const gn, const mn = .{ const gn, const mn = .{
try memo.resolve(arena, try pop(&stack), true), try memo.resolve(arena, try pop(&stack), true),
try memo.resolve(arena, try pop(&stack), true), try memo.resolve(arena, try pop(&stack), true),
@ -338,13 +338,13 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
}; };
try memo.insert(@intCast(memo.map.count()), try item.clone(arena)); try memo.insert(@intCast(memo.map.count()), try item.clone(arena));
}, },
else => try stack.append(.{ .raw = try op.clone(arena) }), else => try stack.append(arena, .{ .raw = try op.clone(arena) }),
} }
} }
if (resolve_refs) { if (resolve_refs) {
return try memo.resolveAllRefsIter(arena, 0, stack.items, true); return try memo.resolveAllRefsIter(arena, 0, stack.items, true);
} }
return stack.toOwnedSlice(); return stack.toOwnedSlice(arena);
} }
fn append(allocator: std.mem.Allocator, current: *[]py.Any, values: []const py.Any) !void { fn append(allocator: std.mem.Allocator, current: *[]py.Any, values: []const py.Any) !void {
@ -358,8 +358,9 @@ test evaluate {
defer arena.deinit(); defer arena.deinit();
const allocator = arena.allocator(); const allocator = arena.allocator();
const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only }); const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only });
var buffered_reader = std.io.bufferedReader(file.reader()); var reader_buffer: [1024]u8 = undefined;
const ops = try pickle.parse(allocator, buffered_reader.reader(), 4096); var reader = file.reader(&reader_buffer);
const ops = try pickle.parse(allocator, &reader.interface);
const vals = try evaluate(allocator, ops, true); const vals = try evaluate(allocator, ops, true);
defer allocator.free(vals); defer allocator.free(vals);

View File

@ -1,14 +1,15 @@
const asynk = @import("async");
const std = @import("std"); const std = @import("std");
const testing = std.testing;
const asynk = @import("async");
const stdx = @import("stdx"); const stdx = @import("stdx");
const zml = @import("../../zml.zig"); const zml = @import("../../zml.zig");
const HostBuffer = zml.HostBuffer;
const eval = @import("eval.zig");
const pickle = @import("pickle.zig"); const pickle = @import("pickle.zig");
const py = @import("py.zig"); const py = @import("py.zig");
const eval = @import("eval.zig");
const HostBuffer = zml.HostBuffer;
const testing = std.testing;
const log = std.log.scoped(.@"zml/aio"); const log = std.log.scoped(.@"zml/aio");
// TODO(cryptodeal): use zml.aio.PrefixBuilder instead // TODO(cryptodeal): use zml.aio.PrefixBuilder instead
@ -20,167 +21,82 @@ test {
} }
pub const File = struct { pub const File = struct {
buffer_file: zml.aio.MemoryMappedFile, mmap_file: zml.aio.MemoryMappedFile,
/// Map names to sub file /// Map names to sub file
file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{}, file_map: std.StringArrayHashMapUnmanaged([]const u8) = .{},
tar_file: ?TarStream = null, zip_prefix: []const u8,
is_zip_file: bool, pickle_subfile: []const u8,
zip_prefix: []const u8 = &.{},
pickle_subfile: struct { start: u64 = 0, len: usize },
pub const FileEntry = struct {
version_needed_to_extract: u16,
flags: u16,
compression_method: std.zip.CompressionMethod,
last_modification_time: u16,
last_modification_date: u16,
header_zip_offset: u64,
crc32: u32,
filename_len: u32,
compressed_size: u64,
uncompressed_size: u64,
file_offset: u64,
pub fn init(entry: anytype) FileEntry {
return .{
.version_needed_to_extract = entry.version_needed_to_extract,
.flags = @as(u16, @bitCast(entry.flags)),
.compression_method = entry.compression_method,
.last_modification_time = entry.last_modification_time,
.last_modification_date = entry.last_modification_date,
.header_zip_offset = entry.header_zip_offset,
.crc32 = entry.crc32,
.filename_len = entry.filename_len,
.compressed_size = entry.compressed_size,
.uncompressed_size = entry.uncompressed_size,
.file_offset = entry.file_offset,
};
}
};
const magic = "PK\x03\x04"; const magic = "PK\x03\x04";
pub fn fromTarFile(allocator: std.mem.Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !File {
const tar_file = try TarStream.init(file);
const file_magic = try tar_file.reader().readBytesNoEof(magic.len);
try tar_file.seekTo(0);
var res: File = .{
.buffer_file = mapped,
.tar_file = tar_file,
.is_zip_file = std.mem.eql(u8, &file_magic, magic),
.pickle_subfile = .{ .len = try tar_file.getEndPos() },
};
if (res.is_zip_file) {
try res.parseZipHeaders(allocator, tar_file.seekableStream());
}
return res;
}
pub fn init(allocator: std.mem.Allocator, mmap_file: zml.aio.MemoryMappedFile) !File { pub fn init(allocator: std.mem.Allocator, mmap_file: zml.aio.MemoryMappedFile) !File {
const file_magic = try mmap_file.file.reader().readBytesNoEof(magic.len); var pkl: []const u8 = mmap_file.data;
try mmap_file.file.seekTo(0); var zip_prefix: []const u8 = &.{};
var res: File = .{ var file_map: std.StringArrayHashMapUnmanaged([]const u8) = .{};
.buffer_file = mmap_file, if (std.mem.eql(u8, mmap_file.data[0..magic.len], magic)) {
.is_zip_file = std.mem.eql(u8, &file_magic, magic), // We are dealing with a zip file.
.pickle_subfile = .{ .len = mmap_file.data.len }, // Let's look for the `data.pkl` file and keep a map of all other files.
}; // The other files will be the tensor storage and will be referenced from `data.pkl`.
var header_parsing_buffer: [4096]u8 = undefined;
if (res.is_zip_file) { // std.zip relies on a std.fs.File and doesn't leverage std.Io.Reader directly.
try res.parseZipHeaders(allocator, mmap_file.file.seekableStream()); // So we use the synchronous API to parse the headers,
// then we rely only on the memory map data to parse the pickle and load the buffers.
// To mitigate this we use `asynk.callBlocking` in `torch.open`.
const raw_file: std.fs.File = .{ .handle = mmap_file.file._handle };
var reader = raw_file.reader(&header_parsing_buffer);
var it: std.zip.Iterator = try .init(&reader);
while (try it.next()) |header| {
if (header.filename_len == 0) {
continue;
} }
return res; if (header.compression_method != .store) {
return error.Unsupported;
}
const filename = mmap_file.data[header.header_zip_offset + @sizeOf(std.zip.CentralDirectoryFileHeader) ..][0..header.filename_len];
var local_reader: std.Io.Reader = .fixed(mmap_file.data);
local_reader.discardAll(header.file_offset) catch return error.InvalidZipFile;
const local_header = local_reader.takeStruct(std.zip.LocalFileHeader, .little) catch return error.InvalidZipFile;
local_reader.discardAll(local_header.filename_len) catch return error.InvalidZipFile;
local_reader.discardAll(local_header.extra_len) catch return error.InvalidZipFile;
// normalize path separators
const file_content = mmap_file.data[local_reader.seek..][0..header.compressed_size];
const my_filename: []u8 = try allocator.dupe(u8, filename);
std.mem.replaceScalar(u8, my_filename, '\\', '/');
try file_map.put(allocator, my_filename, file_content);
if (std.mem.endsWith(u8, filename, "data.pkl")) {
pkl = file_content;
zip_prefix = filename[0 .. filename.len - "data.pkl".len];
}
}
if (pkl.len == 0) {
log.err("Could not find file ending in `data.pkl` in archive", .{});
return error.PickleNotFound;
}
}
return .{
.mmap_file = mmap_file,
.file_map = file_map,
.pickle_subfile = pkl,
.zip_prefix = zip_prefix,
};
} }
pub fn close(self: *File) void { pub fn close(self: *File) void {
self.buffer_file.deinit(); self.mmap_file.deinit();
} }
pub fn parsePickle(self: *File, allocator: std.mem.Allocator) ![]const pickle.Op { pub fn parsePickle(self: *File, allocator: std.mem.Allocator) ![]const pickle.Op {
return if (self.tar_file) |tar_file| { var reader: std.Io.Reader = .fixed(self.pickle_subfile);
try tar_file.seekTo(self.pickle_subfile.start); return try pickle.parse(allocator, &reader);
var buffered = std.io.bufferedReader(tar_file.reader());
return try pickle.parse(allocator, buffered.reader(), self.pickle_subfile.len);
} else {
const file = self.buffer_file.file;
try file.seekTo(self.pickle_subfile.start);
var buffered = std.io.bufferedReader(file.reader());
return try pickle.parse(allocator, buffered.reader(), self.pickle_subfile.len);
};
}
fn parseZipHeaders(self: *File, allocator: std.mem.Allocator, seekable_stream: anytype) !void {
var file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{};
var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream);
var filename_buf: [std.fs.max_path_bytes]u8 = undefined;
while (try iter.next()) |entry| {
const filename = filename_buf[0..entry.filename_len];
try seekable_stream.seekTo(entry.header_zip_offset + @sizeOf(std.zip.CentralDirectoryFileHeader));
const len = try seekable_stream.context.reader().readAll(filename);
if (len != filename.len) return error.ZipBadFileOffset;
if (isBadFilename(filename)) return error.ZipBadFilename;
std.mem.replaceScalar(u8, filename, '\\', '/'); // normalize path separators
try file_map.put(allocator, try allocator.dupe(u8, filename), FileEntry.init(entry));
}
self.file_map = file_map;
var file_iter = file_map.iterator();
while (file_iter.next()) |e| {
const entry = e.value_ptr.*;
const filename = e.key_ptr.*;
if (!std.mem.endsWith(u8, filename, "data.pkl")) continue;
self.zip_prefix = filename[0 .. filename.len - "data.pkl".len];
const local_data_header_offset: u64 = local_data_header_offset: {
switch (entry.compression_method) {
.store => {},
.deflate => {
// TODO(cryptodeal): handle decompress
@panic("TODO support use of `deflate`");
},
else => @panic("TODO support other modes of compression"),
}
const local_header = blk: {
try seekable_stream.seekTo(entry.file_offset);
break :blk try seekable_stream.context.reader().readStructEndian(std.zip.LocalFileHeader, .little);
};
if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig))
return error.ZipBadFileOffset;
if (local_header.version_needed_to_extract != entry.version_needed_to_extract)
return error.ZipMismatchVersionNeeded;
if (local_header.last_modification_time != entry.last_modification_time)
return error.ZipMismatchModTime;
if (local_header.last_modification_date != entry.last_modification_date)
return error.ZipMismatchModDate;
if (@as(u16, @bitCast(local_header.flags)) != entry.flags)
return error.ZipMismatchFlags;
if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32)
return error.ZipMismatchCrc32;
if (local_header.compressed_size != 0 and
local_header.compressed_size != entry.compressed_size)
return error.ZipMismatchCompLen;
if (local_header.uncompressed_size != 0 and
local_header.uncompressed_size != entry.uncompressed_size)
return error.ZipMismatchUncompLen;
if (local_header.filename_len != entry.filename_len)
return error.ZipMismatchFilenameLen;
break :local_data_header_offset @as(u64, local_header.filename_len) +
@as(u64, local_header.extra_len);
};
const local_data_file_offset: u64 =
@as(u64, entry.file_offset) +
@as(u64, @sizeOf(std.zip.LocalFileHeader)) +
local_data_header_offset;
self.pickle_subfile = .{ .start = local_data_file_offset, .len = entry.uncompressed_size };
return;
}
log.err("Could not find file ending in `data.pkl` in archive", .{});
return error.PickleNotFound;
} }
fn basicTypeCheck(object: *const py.Object, module: []const u8, class: []const u8) bool { fn basicTypeCheck(object: *const py.Object, module: []const u8, class: []const u8) bool {
@ -286,7 +202,7 @@ pub const File = struct {
if (prefix.items.len > 0) { if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.'); new_prefix.appendAssumeCapacity('.');
} }
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
try self.parseValue(allocator, store, new_prefix, val); try self.parseValue(allocator, store, new_prefix, val);
} }
} }
@ -303,7 +219,7 @@ pub const File = struct {
if (prefix.items.len > 0) { if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.'); new_prefix.appendAssumeCapacity('.');
} }
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
const new_tag = switch (tag) { const new_tag = switch (tag) {
.int64 => "int", .int64 => "int",
.float64 => "float", .float64 => "float",
@ -321,7 +237,7 @@ pub const File = struct {
if (prefix.items.len > 0) { if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.'); new_prefix.appendAssumeCapacity('.');
} }
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
} }
try self.parseValue(allocator, store, new_prefix, item); try self.parseValue(allocator, store, new_prefix, item);
} }
@ -353,7 +269,7 @@ pub const File = struct {
if (prefix.items.len > 0) { if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.'); new_prefix.appendAssumeCapacity('.');
} }
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{}); new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{});
try self.parseValue(allocator, store, new_prefix, val); try self.parseValue(allocator, store, new_prefix, val);
}, },
inline else => |_, tag| { inline else => |_, tag| {
@ -504,34 +420,10 @@ pub const File = struct {
/// Given the name of one of the files in the .pt tarball, /// Given the name of one of the files in the .pt tarball,
/// return the slice of the memory-mapped .pt corresponding to it. /// return the slice of the memory-mapped .pt corresponding to it.
fn getStorage(self: File, filename: []const u8) ![]const u8 { fn getStorage(self: File, filename: []const u8) ![]const u8 {
const maybe_entry = self.file_map.get(filename); return self.file_map.get(filename) orelse {
if (maybe_entry == null) {
std.log.err("Could not find file ending in `{s}` in archive", .{filename}); std.log.err("Could not find file ending in `{s}` in archive", .{filename});
return error.TensorNotFound; return error.TensorNotFound;
} };
const entry = maybe_entry.?;
const base_offset: u64 = if (self.tar_file) |t| t.start else 0;
const file_offset: u64 = base_offset + entry.file_offset;
const file = self.buffer_file.file;
try file.seekTo(entry.file_offset);
const local_header = try file.reader().readStructEndian(std.zip.LocalFileHeader, .little);
if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig))
return error.ZipBadFileOffset;
if (local_header.compressed_size != 0 and
local_header.compressed_size != entry.compressed_size)
return error.ZipMismatchCompLen;
if (local_header.uncompressed_size != 0 and
local_header.uncompressed_size != entry.uncompressed_size)
return error.ZipMismatchUncompLen;
if (local_header.filename_len != entry.filename_len)
return error.ZipMismatchFilenameLen;
const start = file_offset +
@sizeOf(std.zip.LocalFileHeader) +
@as(u64, local_header.filename_len) +
@as(u64, local_header.extra_len);
return self.buffer_file.mappedSlice(start, entry.uncompressed_size);
} }
fn parseDims(values: []py.Any) error{InvalidInput}!zml.Shape.DimsArray { fn parseDims(values: []py.Any) error{InvalidInput}!zml.Shape.DimsArray {
@ -578,52 +470,6 @@ fn storageToDtype(storage_type: []const u8) !zml.DataType {
}; };
} }
const TarStream = struct {
pub const SeekableStream = std.io.SeekableStream(
TarStream,
asynk.File.SeekError,
asynk.File.GetSeekPosError,
TarStream.seekTo,
TarStream.seekBy,
TarStream.getPos,
TarStream.getEndPos,
);
file: std.tar.Iterator(asynk.File.Reader).File,
start: usize,
pub fn init(file: std.tar.Iterator(asynk.File.Reader).File) !TarStream {
return .{
.file = file,
.start = try file.parent_reader.context.getPos(),
};
}
pub fn reader(file: TarStream) std.tar.Iterator(asynk.File.Reader).File.Reader {
return file.file.reader();
}
pub fn seekTo(self: TarStream, offset: u64) !void {
return self.file.parent_reader.context.seekTo(self.start + offset);
}
pub fn seekBy(self: TarStream, offset: i64) !void {
return self.file.parent_reader.context.seekBy(offset);
}
pub fn getPos(self: TarStream) !u64 {
return try self.file.parent_reader.context.getPos() - self.start;
}
pub fn getEndPos(self: TarStream) !u64 {
return self.file.size;
}
pub fn seekableStream(self: TarStream) TarStream.SeekableStream {
return .{ .context = self };
}
};
test "Read pickle (zipped)" { test "Read pickle (zipped)" {
// test file created with following python snippet: // test file created with following python snippet:
// //
@ -638,20 +484,19 @@ test "Read pickle (zipped)" {
defer store.deinit(); defer store.deinit();
{ {
var tmp_arena = std.heap.ArenaAllocator.init(testing.allocator); var arena = std.heap.ArenaAllocator.init(testing.allocator);
defer tmp_arena.deinit(); defer arena.deinit();
const tmp_alloc = tmp_arena.allocator(); var torch_file = try File.init(arena.allocator(), mmap_file);
var torch_file = try File.init(tmp_alloc, mmap_file);
// We don't close the file directly, it will be closed by the store. // We don't close the file directly, it will be closed by the store.
const ops = try torch_file.parsePickle(tmp_alloc); const ops = try torch_file.parsePickle(arena.allocator());
try std.testing.expectEqual(302, ops.len); try std.testing.expectEqual(302, ops.len);
const py_values = try eval.evaluate(tmp_alloc, ops, true); const py_values = try eval.evaluate(arena.allocator(), ops, true);
try torch_file.parseModel(py_values, &store); try torch_file.parseModel(py_values, &store);
} }
// now we have freed the tmp_arena. // now we have freed the arena.
// all data needed should have been copied into the store arena. // all data needed should have been copied into the store arena.
try zml.testing.expectEqualShapes( try zml.testing.expectEqualShapes(
zml.Shape.init(.{ 1, 4 }, .u8), zml.Shape.init(.{ 1, 4 }, .u8),

View File

@ -763,54 +763,54 @@ pub const Op = union(enum) {
}; };
/// Read a stream of bytes, and interpret it as a stream of Pickle operators. /// Read a stream of bytes, and interpret it as a stream of Pickle operators.
pub fn parse(allocator: std.mem.Allocator, reader: anytype, max_line_len: usize) ![]const Op { /// The given allocator needs to be an arena cause we are not aligning allocations to avoid copies.
var results = std.ArrayList(Op).init(allocator); pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op {
errdefer results.deinit(); // It's not very efficient to interleave the results with the data copied from the stream,
const len = max_line_len; // because growth event in the results ArrayList will lead to fragmentation.
var _buf: std.BoundedArray(u8, 12) = .{}; // Trying to mitigate that by using a generous default size.
var results: std.ArrayListUnmanaged(Op) = try .initCapacity(arena, 512);
errdefer results.deinit(arena);
var alloc_writer = try std.Io.Writer.Allocating.initCapacity(arena, 512);
while (true) { while (true) {
const b = try reader.readByte(); const code: OpCode = @enumFromInt(try reader.takeByte());
const code: OpCode = @enumFromInt(b);
const op: Op = switch (code) { const op: Op = switch (code) {
.int => blk: { .int => int: {
_buf.len = 0; const bytes = try reader.takeDelimiterExclusive('\n');
try reader.streamUntilDelimiter(_buf.writer(), '\n', _buf.capacity() + 1);
const buf = _buf.constSlice();
// Legacy hack, see OpCode.int documentation // Legacy hack, see OpCode.int documentation
// We do this parsing right away to simplify downstream code. // We do this parsing right away to simplify downstream code.
break :blk if (std.mem.eql(u8, "00", buf)) break :int if (bytes.len == 2 and bytes[0] == '0' and bytes[1] == '0')
.{ .bool = false } .{ .bool = false }
else if (std.mem.eql(u8, "01", buf)) else if (bytes.len == 2 and bytes[0] == '0' and bytes[1] == '1')
.{ .bool = true } .{ .bool = true }
else else
.{ .int = try std.fmt.parseInt(i32, buf, 10) }; .{ .int = try std.fmt.parseInt(i32, bytes, 10) };
}, },
.binint => .{ .int = try reader.readInt(i32, .little) }, .binint => .{ .int = try reader.takeInt(i32, .little) },
.binint1 => .{ .int = try reader.readByte() }, .binint1 => .{ .int = try reader.takeByte() },
.binint2 => .{ .int = try reader.readInt(u16, .little) }, .binint2 => .{ .int = try reader.takeInt(u16, .little) },
// TODO: long should handle the trailing 'L' -> add a test. // TODO: long should handle the trailing 'L' -> add a test.
.long => .{ .long = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, .long => .{ .long = try readLine(reader, &alloc_writer) },
.long1 => .{ .binlong = try _readSlice(reader, allocator, 1) }, .long1 => .{ .binlong = try _readSlice(reader, arena, 1) },
.long4 => .{ .binlong = try _readSlice(reader, allocator, 4) }, .long4 => .{ .binlong = try _readSlice(reader, arena, 4) },
.string => .{ .string = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, .string => .{ .string = try readLine(reader, &alloc_writer) },
.binstring => .{ .string = try _readSlice(reader, allocator, 4) }, .binstring => .{ .string = try _readSlice(reader, arena, 4) },
.short_binstring => .{ .string = try _readSlice(reader, allocator, 1) }, .short_binstring => .{ .string = try _readSlice(reader, arena, 1) },
.binbytes => .{ .bytes = try _readSlice(reader, allocator, 4) }, .binbytes => .{ .bytes = try _readSlice(reader, arena, 4) },
.binbytes8 => .{ .bytes = try _readSlice(reader, allocator, 8) }, .binbytes8 => .{ .bytes = try _readSlice(reader, arena, 8) },
.short_binbytes => .{ .bytes = try _readSlice(reader, allocator, 1) }, .short_binbytes => .{ .bytes = try _readSlice(reader, arena, 1) },
.bytearray8 => .{ .bytearray = try _readSlice(reader, allocator, 8) }, .bytearray8 => .{ .bytearray = try _readSlice(reader, arena, 8) },
.next_buffer => .next_buffer, .next_buffer => .next_buffer,
.readonly_buffer => .readonly_buffer, .readonly_buffer => .readonly_buffer,
.none => .none, .none => .none,
.newtrue => .{ .bool = true }, .newtrue => .{ .bool = true },
.newfalse => .{ .bool = false }, .newfalse => .{ .bool = false },
.unicode => .{ .unicode = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, .unicode => .{ .unicode = try readLine(reader, &alloc_writer) },
.short_binunicode => .{ .unicode = try _readSlice(reader, allocator, 1) }, .short_binunicode => .{ .unicode = try _readSlice(reader, arena, 1) },
.binunicode => .{ .unicode = try _readSlice(reader, allocator, 4) }, .binunicode => .{ .unicode = try _readSlice(reader, arena, 4) },
.binunicode8 => .{ .unicode = try _readSlice(reader, allocator, 8) }, .binunicode8 => .{ .unicode = try _readSlice(reader, arena, 8) },
.float => .{ .float = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, .float => .{ .float = try readLine(reader, &alloc_writer) },
.binfloat => .{ .binfloat = @bitCast(try reader.readInt(u64, .big)) }, .binfloat => .{ .binfloat = @bitCast(try reader.takeInt(u64, .big)) },
.empty_list => .empty_list, .empty_list => .empty_list,
.append => .append, .append => .append,
.appends => .appends, .appends => .appends,
@ -832,74 +832,74 @@ pub fn parse(allocator: std.mem.Allocator, reader: anytype, max_line_len: usize)
.mark => .mark, .mark => .mark,
.pop_mark => .pop_mark, .pop_mark => .pop_mark,
// If we fail to parse delay the error to the evaluation. // If we fail to parse delay the error to the evaluation.
.get => .{ .get => get: {
.get = _readDigits(u32, reader, &_buf) catch std.math.maxInt(u32), const digits = try reader.takeDelimiterExclusive('\n');
break :get .{ .get = std.fmt.parseInt(u32, digits, 10) catch std.math.maxInt(u32) };
}, },
.binget => .{ .get = try reader.readByte() }, .binget => .{ .get = try reader.takeByte() },
.long_binget => .{ .get = try reader.readInt(u32, .little) }, .long_binget => .{ .get = try reader.takeInt(u32, .little) },
.put => blk: { .put => put: {
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len); const digits = try reader.takeDelimiterExclusive('\n');
defer allocator.free(buf); break :put .{ .put = std.fmt.parseInt(u32, digits, 10) catch std.math.maxInt(u32) };
const n = std.fmt.parseInt(u32, buf, 10) catch std.math.maxInt(u32);
break :blk .{ .put = n };
}, },
.binput => .{ .put = try reader.readByte() }, .binput => .{ .put = try reader.takeByte() },
.long_binput => .{ .put = try reader.readInt(u32, .little) }, .long_binput => .{ .put = try reader.takeInt(u32, .little) },
.memoize => .memoize, .memoize => .memoize,
.ext1 => .{ .ext1 = try reader.readByte() }, .ext1 => .{ .ext1 = try reader.takeByte() },
.ext2 => .{ .ext2 = try reader.readInt(i16, .little) }, .ext2 => .{ .ext2 = try reader.takeInt(i16, .little) },
.ext4 => .{ .ext4 = try reader.readInt(i32, .little) }, .ext4 => .{ .ext4 = try reader.takeInt(i32, .little) },
.global => .{ .global = .{ .global => .{ .global = .{
.module = try reader.readUntilDelimiterAlloc(allocator, '\n', len), .module = try readLine(reader, &alloc_writer),
.class = try reader.readUntilDelimiterAlloc(allocator, '\n', len), .class = try readLine(reader, &alloc_writer),
} }, } },
.stack_global => .stack_global, .stack_global => .stack_global,
.reduce => .reduce, .reduce => .reduce,
.build => .build, .build => .build,
.inst => .{ .inst = .{ .inst => .{ .inst = .{
.module = try reader.readUntilDelimiterAlloc(allocator, '\n', len), .module = try readLine(reader, &alloc_writer),
.class = try reader.readUntilDelimiterAlloc(allocator, '\n', len), .class = try readLine(reader, &alloc_writer),
} }, } },
.obj => .obj, .obj => .obj,
.newobj => .newobj, .newobj => .newobj,
.newobj_ex => .newobj_ex, .newobj_ex => .newobj_ex,
.proto => blk: { .proto => blk: {
const version = try reader.readByte(); const version = try reader.takeByte();
if (version > 5) log.warn("zml.aio.torch.pickle.parse expects a Python pickle object of version <=5, got version {}. Will try to interpret anyway, but this may lead to more errors.", .{version}); if (version > 5) log.warn("zml.aio.torch.pickle.parse expects a Python pickle object of version <=5, got version {}. Will try to interpret anyway, but this may lead to more errors.", .{version});
break :blk .{ .proto = version }; break :blk .{ .proto = version };
}, },
.stop => .stop, .stop => .stop,
.frame => frame: {
// This is not documented in pickletools but in https://peps.python.org/pep-3154/ // This is not documented in pickletools but in https://peps.python.org/pep-3154/
// The frame size is stored right after the frame header.
// The loader is allowed to prefetch framesize from the underlying reader, // The loader is allowed to prefetch framesize from the underlying reader,
// and ops are not allowed to cross a frame boundary. // and ops are not allowed to cross a frame boundary.
// We don't prefetch because we assume the reader is going to use some kind of buffered reader. const frame_size = try reader.takeInt(u64, .little);
// We could try to enforce frame boundaries, but we would need to track reader.fill(@min(frame_size, reader.buffer.len)) catch |err| switch (err) {
// how many bytes we are reading from the stream. error.EndOfStream => {},
.frame => .{ .frame = try reader.readInt(u64, .little) }, else => return err,
.persid => .{ .persid = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, };
break :frame .{ .frame = frame_size };
},
.persid => .{ .persid = try readLine(reader, &alloc_writer) },
.binpersid => .binpersid, .binpersid => .binpersid,
_ => |unk_tag| { _ => |unk_tag| {
log.err("Unknow pickle operator {}, note we are only supporting pickle protocol up to version 5.", .{unk_tag}); log.err("Unknow pickle operator {}, note we are only supporting pickle protocol up to version 5.", .{unk_tag});
return error.NotSupported; return error.NotSupported;
}, },
}; };
try results.append(op); try results.append(arena, op);
if (op == .stop) break; if (op == .stop) break;
} }
return results.toOwnedSlice(); return results.items;
} }
test "parse protocol 4" { test "parse protocol 4" {
const allocator = std.testing.allocator; var arena: std.heap.ArenaAllocator = .init(std.testing.allocator);
defer arena.deinit();
const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only }); const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only });
var buffered_reader = std.io.bufferedReader(file.reader()); var read_buffer: [1024]u8 = undefined;
const ops = try parse(allocator, buffered_reader.reader(), 4096); var reader = file.reader(&read_buffer);
defer { const ops = try parse(arena.allocator(), &reader.interface);
// Test we are correctly freeing every allocation.
for (ops) |op| op.deinit(allocator);
allocator.free(ops);
}
// this can be obtained by running: `python -m pickletools simple_test_4.pickle` // this can be obtained by running: `python -m pickletools simple_test_4.pickle`
var expected = [_]Op{ var expected = [_]Op{
@ -948,7 +948,9 @@ test "parse protocol 4" {
test "parse protocol 0" { test "parse protocol 0" {
// We also test protocol 0, cause it's more text oriented. // We also test protocol 0, cause it's more text oriented.
const allocator = std.testing.allocator; var arena: std.heap.ArenaAllocator = .init(std.testing.allocator);
defer arena.deinit();
const pickle_0 = const pickle_0 =
\\(dp0 \\(dp0
\\Vhello \\Vhello
@ -982,13 +984,8 @@ test "parse protocol 0" {
\\s. \\s.
; ;
var stream = std.io.fixedBufferStream(pickle_0); var reader: std.Io.Reader = .fixed(pickle_0);
const ops = try parse(allocator, stream.reader(), 4096); const ops = try parse(arena.allocator(), &reader);
defer {
// Test we are correctly freeing every allocation.
for (ops) |op| op.deinit(allocator);
allocator.free(ops);
}
var expected = [_]Op{ var expected = [_]Op{
.mark, .mark,
@ -1043,18 +1040,11 @@ test "parse protocol 0" {
try std.testing.expectEqualDeep(&expected, ops); try std.testing.expectEqualDeep(&expected, ops);
} }
fn _readDigits(comptime T: type, reader: anytype, buffer: *std.BoundedArray(u8, 12)) !T {
buffer.len = 0;
try reader.streamUntilDelimiter(buffer.writer(), '\n', 13);
return std.fmt.parseInt(T, buffer.constSlice(), 10);
}
fn _readSlice(reader: anytype, allocator: std.mem.Allocator, comptime len_bytes: u8) ![]u8 { fn _readSlice(reader: anytype, allocator: std.mem.Allocator, comptime len_bytes: u8) ![]u8 {
const T = std.meta.Int(.unsigned, 8 * len_bytes); const T = std.meta.Int(.unsigned, 8 * len_bytes);
const str_len: u64 = try reader.readInt(T, .little); const str_len: u64 = try reader.takeInt(T, .little);
const buf = try allocator.alloc(u8, str_len); const buf = try allocator.alloc(u8, str_len);
errdefer allocator.free(buf); _ = try reader.readSliceAll(buf);
_ = try reader.read(buf);
return buf; return buf;
} }
@ -1063,3 +1053,14 @@ fn writeIntBuff(comptime T: type, value: T) [@divExact(@typeInfo(T).int.bits, 8)
std.mem.writeInt(T, &res, value, .little); std.mem.writeInt(T, &res, value, .little);
return res; return res;
} }
/// Reads one '\n'-terminated line from `reader` and returns it without the delimiter.
///
/// The line is streamed into the buffer of `alloc_writer`, then the returned slice is carved
/// out of the front of that buffer: the writer is re-pointed at the unused tail, so the
/// next line written through it does not overwrite the bytes returned here.
/// NOTE(review): handing out sub-slices of the writer buffer presumably requires the
/// writer to be backed by an arena (nothing frees the returned lines one by one) — confirm
/// against the callers.
///
/// `alloc_writer` must be empty (`writer.end == 0`) on entry; this is asserted.
/// Any error coming from the reader or from the writer is propagated.
fn readLine(reader: *std.Io.Reader, alloc_writer: *std.Io.Writer.Allocating) ![]const u8 {
    const w = &alloc_writer.writer;
    const n = try reader.streamDelimiter(w, '\n');

    // The delimiter itself is not part of the streamed bytes, consume it explicitly.
    // The read is kept outside of the assertion so that the assert has no side effect.
    const delimiter = try reader.takeByte();
    std.debug.assert(delimiter == '\n');
    std.debug.assert(w.end == n);

    const line = w.buffer[0..n];
    // Skip exactly the `n` bytes we are handing out. Slicing from `n + 1` would be
    // out of bounds when the line fills the buffer exactly (`n == w.buffer.len`),
    // and would waste one byte per line otherwise.
    w.buffer = w.buffer[n..];
    w.end = 0;
    return line;
}

View File

@ -172,7 +172,7 @@ pub const HostBuffer = struct {
// TODO we should allow interpreting the output as @Vector(8, f32) when the tensor is f32. // TODO we should allow interpreting the output as @Vector(8, f32) when the tensor is f32.
stdx.debug.assert(DataType.fromZigType(T) == self.dtype(), "Can't reinterpret {f} as {s}", .{ self, @typeName(T) }); stdx.debug.assert(DataType.fromZigType(T) == self.dtype(), "Can't reinterpret {f} as {s}", .{ self, @typeName(T) });
stdx.debug.assert(self.isContiguous(), "{f} isn't contiguous, can't interpret as []const u8", .{self}); stdx.debug.assert(self.isContiguous(), "{f} isn't contiguous, can't interpret as []const u8", .{self});
const ptr: [*]const T = @alignCast(@ptrCast(self._data)); const ptr: [*]const T = @ptrCast(@alignCast(self._data));
return ptr[0..self._shape.count()]; return ptr[0..self._shape.count()];
} }

View File

@ -664,7 +664,7 @@ pub const CompilationContext = struct {
// Create the result tensor object by combining the operand results, // Create the result tensor object by combining the operand results,
// as well as the registered shapes and donations. // as well as the registered shapes and donations.
// Note: this assume res can be stack-allocated. // Note: this assume res can be stack-allocated.
var res = @as(*const stdx.meta.FnResult(func), @alignCast(@ptrCast(function.res_tensors))).*; var res = @as(*const stdx.meta.FnResult(func), @ptrCast(@alignCast(function.res_tensors))).*;
const LocalContext = struct { index: usize = 0, op: mlir.Operation, function: MlirFn, donations: []Tensor._Donation }; const LocalContext = struct { index: usize = 0, op: mlir.Operation, function: MlirFn, donations: []Tensor._Donation };
var context: LocalContext = .{ .op = op, .function = function, .donations = donations }; var context: LocalContext = .{ .op = op, .function = function, .donations = donations };
meta.visit((struct { meta.visit((struct {

View File

@ -113,7 +113,7 @@ const _CreateOptions = struct {
/// "Best-Fit with Coalescing" algorithm /// "Best-Fit with Coalescing" algorithm
bfc: Options, bfc: Options,
/// use cudaMallocAsync /// use cudaMallocAsync
@"async": Options, async: Options,
/// use raw cuMalloc /// use raw cuMalloc
platform, platform,
@ -129,7 +129,7 @@ const _CreateOptions = struct {
.platform => { .platform => {
values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform")); values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform"));
}, },
.bfc, .@"async" => |opt| { .bfc, .async => |opt| {
values.appendAssumeCapacity(pjrt.NamedValue.from("allocator", self.allocator)); values.appendAssumeCapacity(pjrt.NamedValue.from("allocator", self.allocator));
values.appendAssumeCapacity(pjrt.NamedValue.from("preallocate", opt.preallocate)); values.appendAssumeCapacity(pjrt.NamedValue.from("preallocate", opt.preallocate));
if (opt.memory_fraction > 0) { if (opt.memory_fraction > 0) {

View File

@ -1,4 +1,5 @@
const builtin = @import("builtin"); const builtin = @import("builtin");
const c = @import("c"); const c = @import("c");
pub const Tracer = switch (builtin.os.tag) { pub const Tracer = switch (builtin.os.tag) {
@ -11,15 +12,15 @@ const CudaTracer = struct {
// Those symbols are defined in cudaProfiler.h but their implementation is in libcuda.so // Those symbols are defined in cudaProfiler.h but their implementation is in libcuda.so
// They will be bound at call time after libcuda.so is loaded (as a needed dependency of libpjrt_cuda.so). // They will be bound at call time after libcuda.so is loaded (as a needed dependency of libpjrt_cuda.so).
const cuProfilerStart = @extern(*const fn () callconv(.C) c_int, .{ .name = "cuProfilerStart", .linkage = .weak }) orelse unreachable; const cuProfilerStart = @extern(*const fn () callconv(.c) c_int, .{ .name = "cuProfilerStart", .linkage = .weak }) orelse unreachable;
const cuProfilerStop = @extern(*const fn () callconv(.C) c_int, .{ .name = "cuProfilerStop", .linkage = .weak }) orelse unreachable; const cuProfilerStop = @extern(*const fn () callconv(.c) c_int, .{ .name = "cuProfilerStop", .linkage = .weak }) orelse unreachable;
// Those symbols are defined in nvToolsExt.h which we don't want to provide. // Those symbols are defined in nvToolsExt.h which we don't want to provide.
// However, we link with libnvToolsExt.so which provides them. // However, we link with libnvToolsExt.so which provides them.
// They will be bound at call time after libnvToolsExt.so is loaded (manually dlopen'ed by us). // They will be bound at call time after libnvToolsExt.so is loaded (manually dlopen'ed by us).
const nvtxMarkA = @extern(*const fn ([*:0]const u8) callconv(.C) void, .{ .name = "nvtxMarkA", .linkage = .weak }) orelse unreachable; const nvtxMarkA = @extern(*const fn ([*:0]const u8) callconv(.c) void, .{ .name = "nvtxMarkA", .linkage = .weak }) orelse unreachable;
const nvtxRangeStartA = @extern(*const fn ([*:0]const u8) callconv(.C) c_int, .{ .name = "nvtxRangeStartA", .linkage = .weak }) orelse unreachable; const nvtxRangeStartA = @extern(*const fn ([*:0]const u8) callconv(.c) c_int, .{ .name = "nvtxRangeStartA", .linkage = .weak }) orelse unreachable;
const nvtxRangeEnd = @extern(*const fn (c_int) callconv(.C) void, .{ .name = "nvtxRangeEnd", .linkage = .weak }) orelse unreachable; const nvtxRangeEnd = @extern(*const fn (c_int) callconv(.c) void, .{ .name = "nvtxRangeEnd", .linkage = .weak }) orelse unreachable;
pub fn init(name: [:0]const u8) CudaTracer { pub fn init(name: [:0]const u8) CudaTracer {
_ = name; _ = name;