bump runtimes/* code to Zig 0.15.1, restore PyTorch loader using std.fs.File, update CI zig fmt, remove stdx.io, note remaining issues with Neuron and CUDA debug builds

This commit is contained in:
Tarry Singh 2025-08-07 15:09:27 +00:00
parent 0ed7f5c907
commit 9e3cd6d616
24 changed files with 258 additions and 989 deletions

View File

@ -605,7 +605,7 @@ pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
pub fn items(self: Attr) []const dt.ZigType() {
const raw_bytes: [*]const u8 = c.mlirDenseElementsAttrGetRawData(self._inner) orelse unreachable;
const ptr: [*]const dt.ZigType() = @alignCast(@ptrCast(raw_bytes));
const ptr: [*]const dt.ZigType() = @ptrCast(@alignCast(raw_bytes));
// Note the mlir API returns us the number of elements, not the number of bytes,
// that's why we track the element type at comptime to allow items to work.
return ptr[0..self.len()];
@ -1743,7 +1743,7 @@ pub const helpers = struct {
writer: *std.Io.Writer,
err: ?std.Io.Writer.Error = null,
fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.c) void {
var ctx: *@This() = @alignCast(@ptrCast(opaque_ctx));
var ctx: *@This() = @ptrCast(@alignCast(opaque_ctx));
if (ctx.err) |_| return;
_ = ctx.writer.write(mlir_str.data[0..mlir_str.length]) catch |err| {
ctx.err = err;

View File

@ -359,7 +359,7 @@ pub const Attrs = extern struct {
value: *const anyopaque,
pub fn get(self: Scalar, T: type) T {
const ptr: *const T = @alignCast(@ptrCast(self.value));
const ptr: *const T = @ptrCast(@alignCast(self.value));
return ptr.*;
}
};
@ -370,13 +370,13 @@ pub const Attrs = extern struct {
data: [*]const u8,
pub fn slice(self: Array, T: type) []const T {
const ptr: [*]const T = @alignCast(@ptrCast(self.data));
const ptr: [*]const T = @ptrCast(@alignCast(self.data));
return ptr[0..self.len];
}
};
pub fn slice(self: Array, T: type) []const T {
const ptr: [*]const T = @alignCast(@ptrCast(self.data));
const ptr: [*]const T = @ptrCast(@alignCast(self.data));
return ptr[0..self.len];
}

View File

@ -58,7 +58,7 @@ pub const ApiError = error{
fn InnerMixin(comptime innerT: type) type {
return struct {
fn inner(self: anytype) *innerT {
return @ptrCast(@constCast(@alignCast(self)));
return @ptrCast(@alignCast(@constCast(self)));
}
};
}
@ -125,10 +125,10 @@ pub const Api = struct {
}
pub fn lookupExtension(self: *const Api, comptime ExtensionT: type, ext_id: c_int) ?*const ExtensionT {
var cur: [*c]const c.PJRT_Extension_Base = @alignCast(@ptrCast(self.inner.extension_start));
var cur: [*c]const c.PJRT_Extension_Base = @ptrCast(@alignCast(self.inner.extension_start));
while (cur != null) : (cur = cur.*.next) {
if (cur.*.type == ext_id) {
return @alignCast(@ptrCast(cur));
return @ptrCast(@alignCast(cur));
}
}
@ -432,7 +432,7 @@ pub const Client = opaque {
.client = self.inner(),
}) catch unreachable;
if (ret.addressable_memories) |memories| {
return @constCast(@ptrCast(memories[0..ret.num_addressable_memories]));
return @ptrCast(@constCast(memories[0..ret.num_addressable_memories]));
}
return &.{};
}

View File

@ -28,7 +28,7 @@ fn hasCudaPathInLDPath() bool {
fn setupXlaGpuCudaDirFlag(allocator: std.mem.Allocator, sandbox: []const u8) !void {
const xla_flags = std.process.getEnvVarOwned(allocator, "XLA_FLAGS") catch "";
const new_xla_flagsZ = try std.fmt.allocPrintZ(allocator, "{s} --xla_gpu_cuda_data_dir={s}", .{ xla_flags, sandbox });
const new_xla_flagsZ = try std.fmt.allocPrintSentinel(allocator, "{s} --xla_gpu_cuda_data_dir={s}", .{ xla_flags, sandbox }, 0);
_ = c.setenv("XLA_FLAGS", new_xla_flagsZ, 1);
}

View File

@ -38,7 +38,7 @@ var module_def: c.PyModuleDef = .{
.{},
}),
.m_slots = @constCast(&[_]c.PyModuleDef_Slot{
.{ .slot = c.Py_mod_exec, .value = @constCast(@ptrCast(&module_exec)) },
.{ .slot = c.Py_mod_exec, .value = @ptrCast(@constCast(&module_exec)) },
.{},
}),
.m_traverse = null,

View File

@ -25,10 +25,10 @@ fn isRunningOnEC2() !bool {
var f = try asynk.File.open("/sys/devices/virtual/dmi/id/sys_vendor", .{ .mode = .read_only });
defer f.close() catch {};
var buf: [AmazonEC2.len]u8 = undefined;
_ = try f.reader().readAll(&buf);
var content: [AmazonEC2.len]u8 = undefined;
const n_read = try f.pread(&content, 0);
return std.mem.eql(u8, &buf, AmazonEC2);
return std.mem.eql(u8, content[0..n_read], AmazonEC2);
}
pub fn load() !*const pjrt.Api {
@ -45,7 +45,7 @@ pub fn load() !*const pjrt.Api {
return error.Unavailable;
}
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator);
defer arena.deinit();
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {

View File

@ -37,7 +37,7 @@ pub fn load() !*const pjrt.Api {
return error.Unavailable;
}
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator);
defer arena.deinit();
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {

View File

@ -1,12 +1,12 @@
const builtin = @import("builtin");
const std = @import("std");
const builtin = @import("builtin");
const asynk = @import("async");
const pjrt = @import("pjrt");
const c = @import("c");
const stdx = @import("stdx");
const bazel_builtin = @import("bazel_builtin");
const c = @import("c");
const pjrt = @import("pjrt");
const runfiles = @import("runfiles");
const stdx = @import("stdx");
const log = std.log.scoped(.@"zml/runtime/tpu");
@ -25,10 +25,10 @@ fn isOnGCP() !bool {
var f = try asynk.File.open("/sys/devices/virtual/dmi/id/product_name", .{ .mode = .read_only });
defer f.close() catch {};
var buf = [_]u8{0} ** GoogleComputeEngine.len;
_ = try f.reader().readAll(&buf);
var content: [GoogleComputeEngine.len]u8 = undefined;
const n_read = try f.pread(&content, 0);
return std.mem.eql(u8, &buf, GoogleComputeEngine);
return std.mem.eql(u8, content[0..n_read], GoogleComputeEngine);
}
pub fn load() !*const pjrt.Api {
@ -42,7 +42,7 @@ pub fn load() !*const pjrt.Api {
return error.Unavailable;
}
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator);
defer arena.deinit();
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {

View File

@ -8,7 +8,6 @@ zig_library(
"flags.zig",
"fmt.zig",
"fs.zig",
"io.zig",
"json.zig",
"math.zig",
"meta.zig",

View File

@ -1,4 +0,0 @@
const std = @import("std");

// 4 KiB buffered wrappers around the type-erased std.io.AnyWriter / AnyReader.
// NOTE(review): the enclosing commit bumps the tree to Zig 0.15.1 and deletes
// this module — presumably because std.io.BufferedWriter/BufferedReader no
// longer exist there; confirm against the targeted std version.
pub const BufferedAnyWriter = std.io.BufferedWriter(4096, std.io.AnyWriter);
pub const BufferedAnyReader = std.io.BufferedReader(4096, std.io.AnyReader);

View File

@ -4,7 +4,6 @@ pub const debug = @import("debug.zig");
pub const flags = @import("flags.zig");
pub const fmt = @import("fmt.zig");
pub const fs = @import("fs.zig");
pub const io = @import("io.zig");
pub const json = @import("json.zig");
pub const math = @import("math.zig");
pub const meta = @import("meta.zig");

View File

@ -95,7 +95,7 @@ pub fn serialize(ptr: anytype, arena: *c.upb_Arena) SerializeError![]const u8 {
pub fn parseEx(comptime UpbType: type, arena: *c.upb_Arena, data: []const u8, opts: ParseOptions) ParseError!*UpbType {
const obj = try new(UpbType, arena);
return switch (c.upb_Decode(@ptrCast(@constCast(data)), data.len, @alignCast(@ptrCast(obj)), Minitable(UpbType), null, @bitCast(opts), arena)) {
return switch (c.upb_Decode(@ptrCast(@constCast(data)), data.len, @ptrCast(@alignCast(obj)), Minitable(UpbType), null, @bitCast(opts), arena)) {
c.kUpb_DecodeStatus_Ok => obj,
c.kUpb_DecodeStatus_Malformed => ParseError.Malformed,
c.kUpb_DecodeStatus_OutOfMemory => std.mem.Allocator.Error.OutOfMemory,

View File

@ -24,6 +24,11 @@ zig_library(
"aio/json.zig",
"aio/safetensors.zig",
"aio/tinyllama.zig",
"aio/torch.zig",
"aio/torch/eval.zig",
"aio/torch/file.zig",
"aio/torch/pickle.zig",
"aio/torch/py.zig",
"buffer.zig",
"context.zig",
"dtype.zig",
@ -72,6 +77,10 @@ zig_library(
zig_test(
name = "test",
data = [
"aio/torch/simple.pt",
"aio/torch/simple_test_4.pickle",
],
test_runner = ":test_runner",
deps = [":zml"],
)

View File

@ -5,16 +5,16 @@ const c = @import("c");
const stdx = @import("stdx");
pub const safetensors = @import("aio/safetensors.zig");
pub const tinyllama = @import("aio/tinyllama.zig");
pub const torch = @import("aio/torch.zig");
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const posix = @import("posix.zig");
const zml = @import("zml.zig");
pub const log = std.log.scoped(.@"zml/aio");
test {
std.testing.refAllDecls(@This());
std.testing.refAllDecls(safetensors);
std.testing.refAllDecls(torch);
}
// TODO error set for weight loading
@ -25,6 +25,12 @@ pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8)
try safetensors.open(allocator, model_path)
else if (std.mem.endsWith(u8, model_path, ".safetensors.index.json"))
try safetensors.open(allocator, model_path)
else if (std.mem.endsWith(u8, model_path, ".pt"))
try torch.open(allocator, model_path)
// else if (std.mem.endsWith(u8, model_path, ".gguf"))
// try gguf.open(allocator, model_path)
// else if (std.mem.endsWith(u8, model_path, ".tinyllama"))
// try tinyllama.open(allocator, model_path)
else {
std.debug.panic("File extension not recognized: {s}", .{model_path});
};
@ -384,7 +390,7 @@ fn _populateStruct(
partial_struct = partial_struct or field_found;
if (!field_found) {
if (field.default_value_ptr) |v| {
@field(obj, field.name) = @as(*const field.type, @alignCast(@ptrCast(v))).*;
@field(obj, field.name) = @as(*const field.type, @ptrCast(@alignCast(v))).*;
} else {
if (partial_struct) {
log.warn("Incomplete metadata '{0s}': {1s}. Missing field: '{2s}'. '{0s}' will be ignored.", .{ prefix, @typeName(T), field.name });

View File

@ -1,83 +0,0 @@
const asynk = @import("async");
const core = @import("gguf/core.zig");
const std = @import("std");
const zml = @import("../zml.zig");
const HostBuffer = @import("../hostbuffer.zig").HostBuffer;
const Allocator = std.mem.Allocator;
const assert = std.debug.assert;
const log = std.log.scoped(.@"zml/io");
/// Opens the .gguf file at `path` and loads all metadata entries and tensor
/// descriptors into a BufferStore backed by the memory-mapped file.
/// Caller owns the returned store; its arena frees everything on deinit.
pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
    var gguf_file = try core.GgufFile.open(path);
    errdefer gguf_file.close();

    var store: zml.aio.BufferStore = .{
        .arena = std.heap.ArenaAllocator.init(allocator),
    };
    errdefer store.arena.deinit();

    const arena = store.arena.allocator();
    store.files = try arena.dupe(zml.aio.MemoryMappedFile, &.{gguf_file.file});

    // metadata must be read in order to read tensors
    try loadMetadata(arena, &store, &gguf_file);
    try loadBuffers(arena, &store, &gguf_file);

    const found = store.buffers.count();
    if (found != gguf_file.header.tensor_count) {
        log.warn("Expected to find {d} tensors in {s}, only found {d}", .{ gguf_file.header.tensor_count, path, found });
    }
    return store;
}
/// Reads every key/value entry from `file` into `store._metadata`.
/// Arrays of supported child types are copied as typed slices; other array
/// children are logged and stored as `.null`. On duplicated keys the first
/// occurrence wins. Terminates cleanly on error.EndOfMetadata.
fn loadMetadata(allocator: Allocator, store: *zml.aio.BufferStore, file: *core.GgufFile) !void {
    try store._metadata.ensureTotalCapacity(allocator, @intCast(file.header.metadata_kv_count));
    // Error-union while loop: the `else |err|` clause below handles the
    // terminating error from readMetadata.
    while (file.readMetadata(allocator)) |entry| {
        log.info("Loading MetaData: {s}", .{entry.name});
        const res = store._metadata.getOrPutAssumeCapacity(entry.name);
        if (res.found_existing) {
            // This file seems invalid. Since most metadatas aren't required, continue ahead.
            log.warn("Found duplicated metadata key: {s}", .{entry.name});
            continue;
        }
        res.value_ptr.* = switch (entry.val) {
            .array => |arr| switch (arr.child) {
                // Comptime-expanded branch per supported child tag, so the raw
                // byte payload can be reinterpreted as the right element type.
                inline .uint8, .int8, .uint16, .int16, .uint32, .int32, .float32, .bool, .string, .uint64, .int64, .float64 => |tag| blk: {
                    const T = @FieldType(core.GgufValue, @tagName(tag));
                    break :blk try zml.aio.Metadata.copySlice(allocator, std.mem.bytesAsSlice(T, arr.data));
                },
                else => blk: {
                    log.warn("ignoring array metadata", .{});
                    break :blk .null;
                },
            },
            // Scalar / string values wrap directly.
            inline else => |v| zml.aio.Metadata.wrap(v),
        };
    } else |err| switch (err) {
        error.EndOfMetadata => {}, // normal end of the KV section
        else => return err,
    }
}
/// Reads every tensor-info record from `file` and registers one HostBuffer
/// per tensor in `store`, each viewing the memory-mapped data section.
/// Duplicated tensor names keep the first occurrence.
fn loadBuffers(allocator: Allocator, store: *zml.aio.BufferStore, file: *core.GgufFile) !void {
    try store.buffers.ensureTotalCapacity(allocator, @intCast(file.header.tensor_count));
    while (file.readTensorInfo(allocator)) |info| {
        const slot = store.buffers.getOrPutAssumeCapacity(info.name);
        if (slot.found_existing) {
            // This file seems invalid. Try to continue anyway.
            log.warn("Found duplicated tensor: {s}", .{info.name});
            continue;
        }
        // TODO: handle quantized types
        const dtype: zml.DataType = info.t.toDtype() orelse return error.UnsupportedGgufType;
        const shape = zml.Shape.init(info.shape(), dtype);
        const bytes = file.file.mappedSlice(info.start, info.byte_len);
        slot.value_ptr.* = HostBuffer.fromBytes(shape, bytes);
    } else |err| switch (err) {
        error.EndOfMetadata => {},
        else => return err,
    }
}

View File

@ -1,505 +0,0 @@
const asynk = @import("async");
const std = @import("std");
const zml = @import("../../zml.zig");
const assert = std.debug.assert;
const log = std.log.scoped(.@"zml/io");
/// Errors produced while parsing a GGUF file.
pub const GgufErrors = error{
    ValueTypeMismatch,
    InvalidGguf,
    UnsupportedGgufType,
    EndOfMetadata,
    OutOfMemory,
};
// Enums and structures
/// On-disk tensor element types defined by the GGUF format (wire IDs fixed).
pub const TensorType = enum(u32) {
    f32 = 0,
    f16 = 1,
    q4_0 = 2,
    q4_1 = 3,
    deprecated_q4_2 = 4,
    deprecated_q4_3 = 5,
    q5_0 = 6,
    q5_1 = 7,
    q8_0 = 8,
    q8_1 = 9,
    // k-quantizations
    q2_k = 10,
    q3_k = 11,
    q4_k = 12,
    q5_k = 13,
    q6_k = 14,
    q8_k = 15,
    i8 = 16,
    i16 = 17,
    i32 = 18,

    // Highest tag this reader understands; readTensorType rejects anything above.
    const MAX_KNOWN_ENUM = 18;

    /// Quantized formats flagged as convertible.
    /// NOTE(review): the conversion path itself is not in this file — confirm.
    pub fn canConvertQuant(self: TensorType) bool {
        return switch (self) {
            .q8_0, .q4_k, .q6_k, .q2_k, .q4_0, .q4_1 => true,
            else => false,
        };
    }

    /// Maps plain scalar types to their zml dtype; null for quantized types.
    pub fn toDtype(self: TensorType) ?zml.DataType {
        return switch (self) {
            .f32 => .f32,
            .f16 => .f16,
            .i8 => .i8,
            .i16 => .i16,
            .i32 => .i32,
            else => null,
        };
    }

    /// Element size in bytes. Only valid where toDtype() is non-null:
    /// the `.?` unwrap trips safety checks for quantized types.
    pub fn sizeOf(self: TensorType) usize {
        return self.toDtype().?.sizeOf();
    }

    /// Return the tensor type features
    pub fn getFeatures(t: TensorType) TensorTypeFeatures {
        return switch (t) {
            inline else => |val| @field(TENSOR_TYPE_FEATURES, @tagName(val)),
        };
    }
};
/// Block geometry of one GGUF tensor type (see TENSOR_TYPE_FEATURES below).
pub const TensorTypeFeatures = struct {
    // Number of weights packed into one block (1 for plain scalar types).
    items_per_block: u29,
    // On-disk size in bytes of one block.
    bytes_per_block: u29,

    /// log2 of the block byte size.
    /// NOTE(review): many bytes_per_block values (18, 34, 144, …) are not
    /// powers of two, so this is floor(log2) for those — confirm intended.
    pub fn alignment(features: TensorTypeFeatures) u8 {
        return std.math.log2_int(u29, features.bytes_per_block);
    }
};
/// GGUF tensor type to features lookup table (block sizes per format).
/// Deprecated entries are zeroed and must never be used for size math.
pub const TENSOR_TYPE_FEATURES: std.enums.EnumFieldStruct(TensorType, TensorTypeFeatures, null) = .{
    .f32 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(f32) },
    .f16 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(f16) },
    .q4_0 = .{ .items_per_block = 32, .bytes_per_block = 18 },
    .q4_1 = .{ .items_per_block = 32, .bytes_per_block = 20 },
    .deprecated_q4_2 = .{ .items_per_block = 0, .bytes_per_block = 0 },
    .deprecated_q4_3 = .{ .items_per_block = 0, .bytes_per_block = 0 },
    .q5_0 = .{ .items_per_block = 32, .bytes_per_block = 22 },
    .q5_1 = .{ .items_per_block = 32, .bytes_per_block = 24 },
    .q8_0 = .{ .items_per_block = 32, .bytes_per_block = 34 },
    .q8_1 = .{ .items_per_block = 32, .bytes_per_block = 40 },
    .q2_k = .{ .items_per_block = 256, .bytes_per_block = 82 },
    .q3_k = .{ .items_per_block = 256, .bytes_per_block = 110 },
    .q4_k = .{ .items_per_block = 256, .bytes_per_block = 144 },
    .q5_k = .{ .items_per_block = 256, .bytes_per_block = 176 },
    .q6_k = .{ .items_per_block = 256, .bytes_per_block = 210 },
    .q8_k = .{ .items_per_block = 256, .bytes_per_block = 292 },
    .i8 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(i8) },
    .i16 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(i16) },
    .i32 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(i32) },
};
/// Metadata value types as encoded on the wire (u32 IDs fixed by the format).
pub const GgufValueType = enum(u32) {
    // The value is an 8-bit unsigned integer.
    uint8 = 0,
    // The value is an 8-bit signed integer.
    int8 = 1,
    // The value is a 16-bit unsigned little-endian integer.
    uint16 = 2,
    // The value is a 16-bit signed little-endian integer.
    int16 = 3,
    // The value is a 32-bit unsigned little-endian integer.
    uint32 = 4,
    // The value is a 32-bit signed little-endian integer.
    int32 = 5,
    // The value is a 32-bit IEEE754 floating point number.
    float32 = 6,
    // The value is a boolean.
    // 1-byte value where 0 is false and 1 is true.
    // Anything else is invalid, and should be treated as either the model
    // being invalid or the reader being buggy.
    bool = 7,
    // The value is a UTF-8 non-null-terminated string, with length prepended.
    string = 8,
    // The value is an array of other values, with the length and type
    // prepended. Arrays can be nested, and the length of the array is the
    // number of elements in the array, not the number of bytes.
    array = 9,
    // The value is a 64-bit unsigned little-endian integer.
    uint64 = 10,
    // The value is a 64-bit signed little-endian integer.
    int64 = 11,
    // The value is a 64-bit IEEE754 floating point number.
    float64 = 12,
    // Special values used by the callbacks of gguf_do_with_value().
    array_start = 100,
    array_end = 101,
    // Allow other values in case GGUF adds more types without us noticing
    _,

    /// In-memory size of one element of this type.
    /// `.string` yields the size of a slice header (used when an array of
    /// strings is stored as a flattened [][]u8 — see GgufFile.readArrayHeader).
    /// Calling this on `.array` or the special/unknown tags is a programmer
    /// error (`unreachable`).
    pub fn sizeOf(self: GgufValueType) usize {
        return switch (self) {
            .uint8 => @sizeOf(u8),
            .int8 => @sizeOf(i8),
            .uint16 => @sizeOf(u16),
            .int16 => @sizeOf(i16),
            .uint32 => @sizeOf(u32),
            .int32 => @sizeOf(i32),
            .float32 => @sizeOf(f32),
            .bool => @sizeOf(bool),
            .uint64 => @sizeOf(u64),
            .int64 => @sizeOf(i64),
            .float64 => @sizeOf(f64),
            .string => @sizeOf([]u8),
            else => unreachable,
        };
    }

    /// Checks that Zig type `T` matches this wire tag's element type.
    /// Returns error.ValueTypeMismatch on a mismatch; tags without a scalar
    /// mapping (array, specials) are accepted silently.
    pub fn arrayTypeCheck(self: GgufValueType, comptime T: type) !void {
        switch (self) {
            .string => if (T != []u8 and T != []const u8) return error.ValueTypeMismatch,
            .uint8 => if (T != u8) return error.ValueTypeMismatch,
            .int8 => if (T != i8) return error.ValueTypeMismatch,
            .uint16 => if (T != u16) return error.ValueTypeMismatch,
            .int16 => if (T != i16) return error.ValueTypeMismatch,
            .uint32 => if (T != u32) return error.ValueTypeMismatch,
            .int32 => if (T != i32) return error.ValueTypeMismatch,
            .float32 => if (T != f32) return error.ValueTypeMismatch,
            .bool => if (T != bool) return error.ValueTypeMismatch,
            .uint64 => if (T != u64) return error.ValueTypeMismatch,
            .int64 => if (T != i64) return error.ValueTypeMismatch,
            .float64 => if (T != f64) return error.ValueTypeMismatch,
            else => {},
        }
    }
};
/// Tag type of GgufValue. Mirrors the 0..12 wire IDs of GgufValueType,
/// narrowed to u8 (no special/unknown values once a value has been parsed).
pub const ValueType = enum(u8) {
    uint8 = 0,
    int8 = 1,
    uint16 = 2,
    int16 = 3,
    uint32 = 4,
    int32 = 5,
    float32 = 6,
    bool = 7,
    string = 8,
    array = 9,
    uint64 = 10,
    int64 = 11,
    float64 = 12,
};
// Union of possible values.
pub const GgufValue = union(ValueType) {
    uint8: u8,
    int8: i8,
    uint16: u16,
    int16: i16,
    uint32: u32,
    int32: i32,
    float32: f32,
    bool: bool,
    string: []const u8,
    array: Array,
    uint64: u64,
    int64: i64,
    float64: f64,

    /// A homogeneous array value: element tag, length, and raw payload.
    pub const Array = struct {
        // Any value type is valid, including arrays.
        // NOTE(review): GgufFile.readArrayHeader rejects child tags above
        // float64, so nested arrays never actually reach here — confirm.
        child: ValueType,
        // Number of elements, not bytes
        len: usize,
        // Raw payload. For `.string` children this is the byte view of a
        // [][]u8 produced via std.mem.sliceAsBytes (see readArrayHeader).
        data: []u8,
    };
};
// Header
const GgufHeader = extern struct {
    // Magic number to announce that this is a GGUF file. Must be `GGUF`.
    magic: [4]u8,
    // The version of the format implemented.
    // Must be `3` for version described in this spec.
    version: u32,
    // The number of tensors in the file.
    // This is explicit, instead of being included in the metadata, to ensure
    // it is always present for loading the tensors.
    // NOTE(review): `usize` for an on-disk field assumes a 64-bit target so it
    // matches the format's u64 — confirm no 32-bit builds read this struct.
    tensor_count: usize,
    // The number of metadata key-value pairs.
    metadata_kv_count: usize,

    /// Checks the magic bytes; the version field is not validated here.
    pub fn validate(self: GgufHeader) !void {
        if (!std.mem.eql(u8, &self.magic, "GGUF")) {
            log.err("Invalid GGUF file: wrong header {s}", .{self.magic});
            return error.InvalidGguf;
        }
    }
};
// Key representation in this library API.
pub const GgufMetadataKv = struct {
    // Key name, allocated by the reader (see GgufFile.readMetadata).
    name: []const u8,
    // Wire type the value was read as.
    type_: GgufValueType,
    // Decoded value.
    val: GgufValue,
};
// Tensor representation in this library API.
const GGUF_TENSOR_MAX_DIM: usize = 8; // Future-proof: actual limit is 4.
pub const GgufTensorInfo = struct {
    name: []const u8,
    t: TensorType, // Tensor type (enum TensorType).
    rank: usize, // Number of dimensions of the tensor.
    dims: [GGUF_TENSOR_MAX_DIM]i64, // Dimensions (Eg. [512, 1024, 1, 1]); only the first `rank` entries are defined.
    start: usize, // Offset from start of data section.
    byte_len: usize, // Total size in bytes.
    num_weights: usize, // Total number of parameters.

    /// The first `rank` dimensions. Stored reversed relative to the on-disk
    /// order to match zml convention (see GgufFile.readTensorInfo).
    pub inline fn shape(info: GgufTensorInfo) []const i64 {
        return info.dims[0..info.rank];
    }
};
/// Human-readable name for a raw GGUF value-type ID; "unknown" when the ID is
/// outside the table of known types.
fn getValueTypeName(t: u32) []const u8 {
    const idx: usize = @intCast(t);
    return if (idx < GGUF_VALUE_NAME.len) GGUF_VALUE_NAME[idx] else "unknown";
}
// Display names indexed by wire-type ID (0..12), in GgufValueType order.
const GGUF_VALUE_NAME = [_][]const u8{
    "uint8",   "int8",  "uint16", "int16",  "uint32", "int32",
    "float32", "bool",  "string", "array",  "uint64", "int64",
    "float64",
};
/// GGUF file API
/// A memory-mapped view of a .gguf file.
/// Format used by GGML models: https://github.com/ggerganov/ggml/
///
/// File layout: GgufHeader, then `metadata_kv_count` key/value entries, then
/// `tensor_count` tensor-info records, then the aligned tensor data section.
/// The struct acts as a cursor: `off`, `left_kv` and `left_tensors` track how
/// far parsing has progressed; metadata must be fully consumed before tensors.
pub const GgufFile = struct {
    header: GgufHeader, // GGUF file header info.
    size: usize, // Total file size.
    file: zml.aio.MemoryMappedFile,
    left_kv: usize, // Number of key-value pairs yet to read.
    left_tensors: usize, // Number of tensors yet to read.
    off: usize, // Offset of the next item to parse.
    alignment: usize = 32, // File data alignment. Default: 32 bytes.

    /// Open and memmap the given file. Validates the magic bytes and leaves
    /// the cursor right after the header.
    pub fn open(path: []const u8) !GgufFile {
        const file = try asynk.File.open(path, .{});
        const header = try file.reader().readStruct(GgufHeader);
        try header.validate();
        return .{
            .header = header,
            .size = (try file.stat()).size,
            .file = try zml.aio.MemoryMappedFile.init(file),
            .off = @sizeOf(GgufHeader),
            .left_kv = header.metadata_kv_count,
            .left_tensors = header.tensor_count,
        };
    }

    /// Unmap the file memory and close the file handle.
    pub fn close(self: *GgufFile) void {
        self.file.deinit();
    }

    /// Set the context to read the first key-value entry in the GGUF
    /// file and then all the rest. Is used when creating a new context
    /// and also when you want to restart scanning the key-value
    /// items in the file.
    fn rewind(ctx: *GgufFile) void {
        ctx.off = @sizeOf(GgufHeader);
        ctx.left_kv = ctx.header.metadata_kv_count;
        ctx.left_tensors = ctx.header.tensor_count;
    }

    /// Moves the logical cursor only; the underlying file handle is not seeked.
    pub fn seek(self: *GgufFile, pos: usize) void {
        assert(pos < self.size);
        self.off = pos;
    }

    /// Reads a little-endian integer at the cursor and advances it.
    /// NOTE(review): the `>=` bound rejects a read that ends exactly at EOF
    /// (off + sizeOf(T) == size is a valid read); `>` looks intended — confirm.
    /// NOTE(review): `off` is advanced even if the underlying readInt returns
    /// an error (the error union is returned after the increment) — confirm
    /// callers never continue after a failure.
    fn readInt(self: *GgufFile, comptime T: type) !T {
        if (self.off + @sizeOf(T) >= self.size) return error.InvalidGguf;
        const res = self.file.file.reader().readInt(T, .little);
        self.off += @sizeOf(T);
        return res;
    }

    /// Reads a tensor-type ID, rejecting anything above the last known member.
    fn readTensorType(self: *GgufFile) !TensorType {
        const raw = try self.readInt(u32);
        if (raw > TensorType.MAX_KNOWN_ENUM) {
            log.err("Unsupported GGUF tensor type: {d}", .{raw});
            return error.UnsupportedGgufType;
        }
        return @enumFromInt(raw);
    }

    /// Reads a value-type ID, accepting the known wire types plus the
    /// internal array_start/array_end sentinels.
    fn readValueType(self: *GgufFile) !GgufValueType {
        const raw = try self.readInt(u32);
        const t: GgufValueType = @enumFromInt(raw);
        switch (t) {
            .uint8, .int8, .uint16, .int16, .uint32, .int32, .float32, .bool, .string, .array, .uint64, .int64, .float64, .array_start, .array_end => {},
            else => {
                log.err("Unsupported GGUF value type: {s}", .{@tagName(t)});
                return error.UnsupportedGgufType;
            },
        }
        return t;
    }

    /// Reads `len` bytes at the cursor into a fresh allocation. Caller owns it.
    pub fn readAlloc(self: *GgufFile, allocator: std.mem.Allocator, len: usize) ![]u8 {
        const data = try allocator.alloc(u8, len);
        const read = try self.file.file.reader().readAll(data);
        if (read != data.len) return error.InvalidGguf;
        self.off += len;
        return data;
    }

    /// Advances both the file handle and the logical cursor by `len` bytes.
    pub fn skipBytes(self: *GgufFile, len: usize) !void {
        try self.file.file.seekBy(@intCast(len));
        self.off += len;
    }

    /// Read the len then the actual bytes.
    pub fn readString(self: *GgufFile, allocator: std.mem.Allocator) ![]u8 {
        const len: usize = try self.readInt(u64);
        return self.readAlloc(allocator, len);
    }

    /// Skips a length-prefixed string without allocating.
    pub fn skipString(self: *GgufFile) !void {
        const len: usize = try self.readInt(u64);
        return self.skipBytes(len);
    }

    /// Reads an array value: child type, element count, then the payload.
    /// Nested arrays (child tag above float64) are rejected.
    fn readArrayHeader(self: *GgufFile, allocator: std.mem.Allocator) !GgufValue.Array {
        const child = try self.readValueType();
        if (@intFromEnum(child) > @intFromEnum(ValueType.float64)) {
            return error.UnsupportedGgufType;
        }
        const len: usize = try self.readInt(u64);
        const data = switch (child) {
            // Since strings have variable lengths, we need to read them one by one
            .string => str: {
                var data = try allocator.alloc([]u8, len);
                for (0..len) |i| data[i] = try self.readString(allocator);
                break :str std.mem.sliceAsBytes(data);
            },
            else => try self.readAlloc(allocator, len * child.sizeOf()),
        };
        return .{
            .child = @enumFromInt(@intFromEnum(child)),
            .len = len,
            .data = data,
        };
    }

    /// Reads one value of wire type `t`. Floats are read as raw little-endian
    /// bits and bit-cast; bool is any non-zero byte.
    fn readTypedValue(self: *GgufFile, allocator: std.mem.Allocator, t: GgufValueType) !GgufValue {
        return switch (t) {
            .uint8 => .{ .uint8 = try self.readInt(u8) },
            .int8 => .{ .int8 = try self.readInt(i8) },
            .uint16 => .{ .uint16 = try self.readInt(u16) },
            .int16 => .{ .int16 = try self.readInt(i16) },
            .uint32 => .{ .uint32 = try self.readInt(u32) },
            .int32 => .{ .int32 = try self.readInt(i32) },
            .float32 => .{ .float32 = @bitCast(try self.readInt(u32)) },
            .bool => .{ .bool = try self.readInt(u8) != 0 },
            .string => .{ .string = try self.readString(allocator) },
            .array => .{ .array = try self.readArrayHeader(allocator) },
            .uint64 => .{ .uint64 = try self.readInt(u64) },
            .int64 => .{ .int64 = try self.readInt(i64) },
            .float64 => .{ .float64 = @bitCast(try self.readInt(u64)) },
            else => error.UnsupportedGgufType,
        };
    }

    /// Parses the next metadata entry.
    /// Returns error.EndOfMetadata if there are no longer metadata to process in this GGUF file.
    pub fn readMetadata(self: *GgufFile, allocator: std.mem.Allocator) !GgufMetadataKv {
        if (self.left_kv == 0) return error.EndOfMetadata;
        self.left_kv -= 1;
        const name = try self.readString(allocator);
        const type_ = try self.readValueType();
        const val: GgufValue = try self.readTypedValue(allocator, type_);
        return .{ .name = name, .type_ = type_, .val = val };
    }

    // Set the data section offset. This function must be called exactly when
    // all the key-values are consumed, in the context of the first call of
    // ctx.getTensor(): this way we will be able to return tensor offsets
    // as absolute positions and pointers to the mmapped file.
    fn setDataOffset(self: *GgufFile) !void {
        const base_off = self.off;
        assert(self.left_kv == 0 and self.left_tensors == self.header.tensor_count);
        // Skim over every tensor-info record to locate the data section start…
        for (0..self.left_tensors) |_| try self.skipTensor();
        const padding: usize = getAlignmentPadding(self.alignment, self.off);
        self.file.data_offset = self.off + padding;
        // …then rewind so the caller can parse the records for real.
        try self.file.file.seekTo(base_off);
        self.off = base_off;
    }

    /// Skips one tensor-info record: length-prefixed name, then 8 bytes per
    /// dimension plus 4 (type id) plus 8 (data offset).
    pub fn skipTensor(self: *GgufFile) !void {
        try self.skipString(); // Skip name
        const num_dim: u32 = try self.readInt(u32);
        // dimensions, type, and offset.
        try self.skipBytes(8 * num_dim + 4 + 8);
    }

    /// Parses the next tensor entry.
    /// Returns error.EndOfMetadata if there are no longer tensor metadata to process in this GGUF file.
    pub fn readTensorInfo(self: *GgufFile, allocator: std.mem.Allocator) !GgufTensorInfo {
        if (self.left_tensors == 0 or self.left_kv != 0) {
            return error.EndOfMetadata;
        }
        // We want to return tensor data with offsets relative to the start
        // of the file, so that the user of the API is able to access tensors
        // as it iterates over them. To do so, we need to perform a full
        // scan if this is the first tensor info we are reading.
        // TODO: explicitly set the data offset in
        if (self.file.data_offset == 0) try self.setDataOffset();
        self.left_tensors -= 1;
        const name = try self.readString(allocator);
        const num_dim = try self.readInt(u32);
        assert(@as(usize, @intCast(num_dim)) <= GGUF_TENSOR_MAX_DIM);
        // Read the dimensions; unused dimensions are left `undefined`.
        // Note: we reverse the order of the dimensions to match zml convention.
        var dims: [GGUF_TENSOR_MAX_DIM]i64 = undefined;
        var num_weights: usize = 1;
        for (0..num_dim) |j| {
            const d = try self.readInt(u64);
            dims[num_dim - 1 - j] = @intCast(d);
            num_weights *= d;
        }
        const t: TensorType = try self.readTensorType();
        const start = try self.readInt(u64);
        // To accurately calculate the bytes used by this tensor on the GGUF
        // file, we need to take into account that quantization methods store
        // tensors as block of N weights. So first of all we need to understand
        // the number of padding weights (since the last block may have just
        // fewer weights stored inside, but still requires to be stored to its full
        // length). Then we can do the math to see how many blocks we need, and
        // multiply by the block size to obtain the final total size.
        const tf = t.getFeatures();
        const byte_len: usize = (std.math.divCeil(usize, num_weights, tf.items_per_block) catch unreachable) * tf.bytes_per_block;
        return .{
            .name = name,
            .t = t,
            .rank = num_dim,
            .dims = dims,
            .start = start,
            .byte_len = byte_len,
            .num_weights = num_weights,
        };
    }
};
/// Given an offset or a length, returns the padding needed to align it to alignment.
/// Given an offset or a length, returns the number of padding bytes needed to
/// advance it up to the next multiple of `alignment` (0 if already aligned).
fn getAlignmentPadding(alignment: usize, offset: usize) usize {
    const remainder = offset % alignment;
    return if (remainder == 0) 0 else alignment - remainder;
}

View File

@ -1,9 +1,9 @@
const asynk = @import("async");
const std = @import("std");
const zml = @import("../zml.zig");
const asynk = @import("async");
const zml = @import("../zml.zig");
const eval = @import("torch/eval.zig");
const py = @import("torch/py.zig");
const File = @import("torch/file.zig").File;
const StringBuilder = std.ArrayListUnmanaged(u8);
@ -12,7 +12,7 @@ const log = std.log.scoped(.@"zml/aio");
test {
std.testing.refAllDecls(@This());
std.testing.refAllDecls(eval);
std.testing.refAllDecls(py);
std.testing.refAllDecls(@import("torch/py.zig"));
std.testing.refAllDecls(File);
}
@ -30,13 +30,13 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
const tmp_alloc = arena.allocator();
const mmap_file = try zml.aio.MemoryMappedFile.init(file);
var torch_file = try File.init(tmp_alloc, mmap_file);
var torch_file = try asynk.callBlocking(File.init, .{ tmp_alloc, mmap_file });
const ops = try torch_file.parsePickle(tmp_alloc);
const py_values = try eval.evaluate(tmp_alloc, ops, true);
// file ownership is transferred to the BufferStore
var res = try zml.aio.BufferStore.init(allocator, &.{torch_file.buffer_file});
var res = try zml.aio.BufferStore.init(allocator, &.{torch_file.mmap_file});
try torch_file.parseModel(py_values, &res);
return res;
}

View File

@ -151,23 +151,23 @@ pub const PickleMemo = struct {
};
pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) ![]const py.Any {
var stack = std.ArrayList(py.Any).init(arena);
var stack: std.ArrayList(py.Any) = .{};
var memo = PickleMemo.init(arena);
for (x) |op| {
switch (op) {
.mark => try stack.append(.{ .raw = op }),
.mark => try stack.append(arena, .{ .raw = op }),
.frame => {},
.stop => break,
.pop => _ = try pop(&stack),
.pop_mark => _ = try popMark(&stack),
.dup => if (stack.getLastOrNull()) |item|
try stack.append(try item.clone(arena))
try stack.append(arena, try item.clone(arena))
else
return error.CannotDupEmptyStack,
.persid => |v| try stack.append(.{ .pers_id = try py.PersId.init(arena, .{ .string = try arena.dupe(u8, v) }) }),
.binpersid => try stack.append(.{ .pers_id = try py.PersId.init(arena, try pop(&stack)) }),
.reduce => try stack.append(.{ .global = blk: {
.persid => |v| try stack.append(arena, .{ .pers_id = try py.PersId.init(arena, .{ .string = try arena.dupe(u8, v) }) }),
.binpersid => try stack.append(arena, .{ .pers_id = try py.PersId.init(arena, try pop(&stack)) }),
.reduce => try stack.append(arena, .{ .global = blk: {
var args = try pop(&stack);
args = try memo.resolve(arena, args, true);
if (args != .seq) return error.InvalidInput;
@ -175,23 +175,23 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
func = try memo.resolve(arena, func, true);
break :blk try py.Object.init(arena, func, args.seq.values, &.{});
} }),
.build => try stack.append(blk: {
.build => try stack.append(arena, blk: {
const args = try memo.resolve(arena, try pop(&stack), true);
const member = try memo.resolve(arena, try pop(&stack), true);
break :blk .{ .set_state = try py.SetState.init(arena, member, args) };
}),
.empty_dict => try stack.append(.{ .seq = .{ .type = .dict, .values = &[_]py.Any{} } }),
.get => |v| try stack.append(.{ .ref = v }),
.empty_list => try stack.append(.{ .seq = .{ .type = .list, .values = &[_]py.Any{} } }),
.empty_dict => try stack.append(arena, .{ .seq = .{ .type = .dict, .values = &[_]py.Any{} } }),
.get => |v| try stack.append(arena, .{ .ref = v }),
.empty_list => try stack.append(arena, .{ .seq = .{ .type = .list, .values = &[_]py.Any{} } }),
.put => |v| {
try memo.insert(v, try pop(&stack));
try stack.append(.{ .ref = v });
try stack.append(arena, .{ .ref = v });
},
.tuple => try stack.append(blk: {
.tuple => try stack.append(arena, blk: {
const popped = try popMark(&stack);
break :blk .{ .seq = .{ .type = .tuple, .values = try arena.dupe(py.Any, popped) } };
}),
.empty_tuple => try stack.append(.{ .seq = .{ .type = .tuple, .values = &[_]py.Any{} } }),
.empty_tuple => try stack.append(arena, .{ .seq = .{ .type = .tuple, .values = &[_]py.Any{} } }),
.setitem => {
const v = try memo.resolve(arena, try pop(&stack), true);
const k = try memo.resolve(arena, try pop(&stack), true);
@ -228,17 +228,17 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
}
},
.proto => |proto| stdx.debug.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}),
.tuple1 => try stack.append(blk: {
.tuple1 => try stack.append(arena, blk: {
const tup_values = try arena.alloc(py.Any, 1);
tup_values[0] = try pop(&stack);
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}),
.tuple2 => try stack.append(blk: {
.tuple2 => try stack.append(arena, blk: {
const tup_values = try arena.alloc(py.Any, 2);
inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack);
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}),
.tuple3 => try stack.append(blk: {
.tuple3 => try stack.append(arena, blk: {
const tup_values = try arena.alloc(py.Any, 3);
inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack);
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
@ -276,32 +276,32 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
},
}
},
.dict => try stack.append(.{ .seq = .{
.dict => try stack.append(arena, .{ .seq = .{
.type = .dict,
.values = try arena.dupe(py.Any, try popMark(&stack)),
} }),
.list => try stack.append(.{ .seq = .{
.list => try stack.append(arena, .{ .seq = .{
.type = .list,
.values = try arena.dupe(py.Any, try popMark(&stack)),
} }),
.inst => |v| try stack.append(.{ .object = try py.Object.init(
.inst => |v| try stack.append(arena, .{ .object = try py.Object.init(
arena,
try py.tuple(&.{ .{ .string = v.module }, .{ .string = v.class } }).clone(arena),
try arena.dupe(py.Any, try popMark(&stack)),
&.{},
) }),
.obj => try stack.append(blk: {
.obj => try stack.append(arena, blk: {
const mark = try findMark(&stack);
const args = try arena.dupe(py.Any, stack.items[mark + 2 ..]);
const member = stack.items[mark + 1];
break :blk .{ .object = try py.Object.init(arena, member, args, &.{}) };
}),
.newobj => try stack.append(blk: {
.newobj => try stack.append(arena, blk: {
const args = try arena.alloc(py.Any, 1);
args[0] = try pop(&stack);
break :blk .{ .object = try py.Object.init(arena, try pop(&stack), args, &.{}) };
}),
.empty_set => try stack.append(.{ .seq = .{ .type = .set, .values = &[_]py.Any{} } }),
.empty_set => try stack.append(arena, .{ .seq = .{ .type = .set, .values = &[_]py.Any{} } }),
.additems => {
const postmark = try popMark(&stack);
const top = try lastMut(&stack);
@ -316,15 +316,15 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
},
}
},
.frozenset => try stack.append(.{ .seq = .{
.frozenset => try stack.append(arena, .{ .seq = .{
.type = .frozen_set,
.values = try arena.dupe(py.Any, try popMark(&stack)),
} }),
.newobj_ex => try stack.append(blk: {
.newobj_ex => try stack.append(arena, blk: {
const kwargs, const args, const cls = .{ try pop(&stack), try pop(&stack), try pop(&stack) };
break :blk .{ .object = try py.Object.init(arena, cls, args.seq.values, kwargs.seq.values) };
}),
.stack_global => try stack.append(blk: {
.stack_global => try stack.append(arena, blk: {
const gn, const mn = .{
try memo.resolve(arena, try pop(&stack), true),
try memo.resolve(arena, try pop(&stack), true),
@ -338,13 +338,13 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
};
try memo.insert(@intCast(memo.map.count()), try item.clone(arena));
},
else => try stack.append(.{ .raw = try op.clone(arena) }),
else => try stack.append(arena, .{ .raw = try op.clone(arena) }),
}
}
if (resolve_refs) {
return try memo.resolveAllRefsIter(arena, 0, stack.items, true);
}
return stack.toOwnedSlice();
return stack.toOwnedSlice(arena);
}
fn append(allocator: std.mem.Allocator, current: *[]py.Any, values: []const py.Any) !void {
@ -358,8 +358,9 @@ test evaluate {
defer arena.deinit();
const allocator = arena.allocator();
const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only });
var buffered_reader = std.io.bufferedReader(file.reader());
const ops = try pickle.parse(allocator, buffered_reader.reader(), 4096);
var reader_buffer: [1024]u8 = undefined;
var reader = file.reader(&reader_buffer);
const ops = try pickle.parse(allocator, &reader.interface);
const vals = try evaluate(allocator, ops, true);
defer allocator.free(vals);

View File

@ -1,14 +1,15 @@
const asynk = @import("async");
const std = @import("std");
const testing = std.testing;
const asynk = @import("async");
const stdx = @import("stdx");
const zml = @import("../../zml.zig");
const HostBuffer = zml.HostBuffer;
const eval = @import("eval.zig");
const pickle = @import("pickle.zig");
const py = @import("py.zig");
const eval = @import("eval.zig");
const HostBuffer = zml.HostBuffer;
const testing = std.testing;
const log = std.log.scoped(.@"zml/aio");
// TODO(cryptodeal): use zml.aio.PrefixBuilder instead
@ -20,167 +21,82 @@ test {
}
pub const File = struct {
buffer_file: zml.aio.MemoryMappedFile,
mmap_file: zml.aio.MemoryMappedFile,
/// Map names to sub file
file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{},
tar_file: ?TarStream = null,
is_zip_file: bool,
zip_prefix: []const u8 = &.{},
pickle_subfile: struct { start: u64 = 0, len: usize },
pub const FileEntry = struct {
version_needed_to_extract: u16,
flags: u16,
compression_method: std.zip.CompressionMethod,
last_modification_time: u16,
last_modification_date: u16,
header_zip_offset: u64,
crc32: u32,
filename_len: u32,
compressed_size: u64,
uncompressed_size: u64,
file_offset: u64,
pub fn init(entry: anytype) FileEntry {
return .{
.version_needed_to_extract = entry.version_needed_to_extract,
.flags = @as(u16, @bitCast(entry.flags)),
.compression_method = entry.compression_method,
.last_modification_time = entry.last_modification_time,
.last_modification_date = entry.last_modification_date,
.header_zip_offset = entry.header_zip_offset,
.crc32 = entry.crc32,
.filename_len = entry.filename_len,
.compressed_size = entry.compressed_size,
.uncompressed_size = entry.uncompressed_size,
.file_offset = entry.file_offset,
};
}
};
file_map: std.StringArrayHashMapUnmanaged([]const u8) = .{},
zip_prefix: []const u8,
pickle_subfile: []const u8,
const magic = "PK\x03\x04";
pub fn fromTarFile(allocator: std.mem.Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !File {
const tar_file = try TarStream.init(file);
const file_magic = try tar_file.reader().readBytesNoEof(magic.len);
try tar_file.seekTo(0);
var res: File = .{
.buffer_file = mapped,
.tar_file = tar_file,
.is_zip_file = std.mem.eql(u8, &file_magic, magic),
.pickle_subfile = .{ .len = try tar_file.getEndPos() },
};
if (res.is_zip_file) {
try res.parseZipHeaders(allocator, tar_file.seekableStream());
}
return res;
}
pub fn init(allocator: std.mem.Allocator, mmap_file: zml.aio.MemoryMappedFile) !File {
const file_magic = try mmap_file.file.reader().readBytesNoEof(magic.len);
try mmap_file.file.seekTo(0);
var res: File = .{
.buffer_file = mmap_file,
.is_zip_file = std.mem.eql(u8, &file_magic, magic),
.pickle_subfile = .{ .len = mmap_file.data.len },
};
var pkl: []const u8 = mmap_file.data;
var zip_prefix: []const u8 = &.{};
var file_map: std.StringArrayHashMapUnmanaged([]const u8) = .{};
if (std.mem.eql(u8, mmap_file.data[0..magic.len], magic)) {
// We are dealing with a zip file.
// Let's look for the `data.pkl` file and keep a map of all other files.
// The other files will be the tensor storage and will be reference from `data.pkl`.
var header_parsing_buffer: [4096]u8 = undefined;
if (res.is_zip_file) {
try res.parseZipHeaders(allocator, mmap_file.file.seekableStream());
// std.zip requires on a std.fs.File and don't leverage std.Io.Reader directly.
// So we use the synchronous API to parse the headers,
// then we rely only on the memory map data to parse the pickle and load the buffers.
// To mitigate this we use `async.launchBlocking` in `torch.open`.
const raw_file: std.fs.File = .{ .handle = mmap_file.file._handle };
var reader = raw_file.reader(&header_parsing_buffer);
var it: std.zip.Iterator = try .init(&reader);
while (try it.next()) |header| {
if (header.filename_len == 0) {
continue;
}
if (header.compression_method != .store) {
return error.Unsupported;
}
const filename = mmap_file.data[header.header_zip_offset + @sizeOf(std.zip.CentralDirectoryFileHeader) ..][0..header.filename_len];
var local_reader: std.Io.Reader = .fixed(mmap_file.data);
local_reader.discardAll(header.file_offset) catch return error.InvalidZipFile;
const local_header = local_reader.takeStruct(std.zip.LocalFileHeader, .little) catch return error.InvalidZipFile;
local_reader.discardAll(local_header.filename_len) catch return error.InvalidZipFile;
local_reader.discardAll(local_header.extra_len) catch return error.InvalidZipFile;
// normalize path separators
const file_content = mmap_file.data[local_reader.seek..][0..header.compressed_size];
const my_filename: []u8 = try allocator.dupe(u8, filename);
std.mem.replaceScalar(u8, my_filename, '\\', '/');
try file_map.put(allocator, my_filename, file_content);
if (std.mem.endsWith(u8, filename, "data.pkl")) {
pkl = file_content;
zip_prefix = filename[0 .. filename.len - "data.pkl".len];
}
}
if (pkl.len == 0) {
log.err("Could not find file ending in `data.pkl` in archive", .{});
return error.PickleNotFound;
}
}
return res;
return .{
.mmap_file = mmap_file,
.file_map = file_map,
.pickle_subfile = pkl,
.zip_prefix = zip_prefix,
};
}
pub fn close(self: *File) void {
self.buffer_file.deinit();
self.mmap_file.deinit();
}
pub fn parsePickle(self: *File, allocator: std.mem.Allocator) ![]const pickle.Op {
return if (self.tar_file) |tar_file| {
try tar_file.seekTo(self.pickle_subfile.start);
var buffered = std.io.bufferedReader(tar_file.reader());
return try pickle.parse(allocator, buffered.reader(), self.pickle_subfile.len);
} else {
const file = self.buffer_file.file;
try file.seekTo(self.pickle_subfile.start);
var buffered = std.io.bufferedReader(file.reader());
return try pickle.parse(allocator, buffered.reader(), self.pickle_subfile.len);
};
}
fn parseZipHeaders(self: *File, allocator: std.mem.Allocator, seekable_stream: anytype) !void {
var file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{};
var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream);
var filename_buf: [std.fs.max_path_bytes]u8 = undefined;
while (try iter.next()) |entry| {
const filename = filename_buf[0..entry.filename_len];
try seekable_stream.seekTo(entry.header_zip_offset + @sizeOf(std.zip.CentralDirectoryFileHeader));
const len = try seekable_stream.context.reader().readAll(filename);
if (len != filename.len) return error.ZipBadFileOffset;
if (isBadFilename(filename)) return error.ZipBadFilename;
std.mem.replaceScalar(u8, filename, '\\', '/'); // normalize path separators
try file_map.put(allocator, try allocator.dupe(u8, filename), FileEntry.init(entry));
}
self.file_map = file_map;
var file_iter = file_map.iterator();
while (file_iter.next()) |e| {
const entry = e.value_ptr.*;
const filename = e.key_ptr.*;
if (!std.mem.endsWith(u8, filename, "data.pkl")) continue;
self.zip_prefix = filename[0 .. filename.len - "data.pkl".len];
const local_data_header_offset: u64 = local_data_header_offset: {
switch (entry.compression_method) {
.store => {},
.deflate => {
// TODO(cryptodeal): handle decompress
@panic("TODO support use of `deflate`");
},
else => @panic("TODO support other modes of compression"),
}
const local_header = blk: {
try seekable_stream.seekTo(entry.file_offset);
break :blk try seekable_stream.context.reader().readStructEndian(std.zip.LocalFileHeader, .little);
};
if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig))
return error.ZipBadFileOffset;
if (local_header.version_needed_to_extract != entry.version_needed_to_extract)
return error.ZipMismatchVersionNeeded;
if (local_header.last_modification_time != entry.last_modification_time)
return error.ZipMismatchModTime;
if (local_header.last_modification_date != entry.last_modification_date)
return error.ZipMismatchModDate;
if (@as(u16, @bitCast(local_header.flags)) != entry.flags)
return error.ZipMismatchFlags;
if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32)
return error.ZipMismatchCrc32;
if (local_header.compressed_size != 0 and
local_header.compressed_size != entry.compressed_size)
return error.ZipMismatchCompLen;
if (local_header.uncompressed_size != 0 and
local_header.uncompressed_size != entry.uncompressed_size)
return error.ZipMismatchUncompLen;
if (local_header.filename_len != entry.filename_len)
return error.ZipMismatchFilenameLen;
break :local_data_header_offset @as(u64, local_header.filename_len) +
@as(u64, local_header.extra_len);
};
const local_data_file_offset: u64 =
@as(u64, entry.file_offset) +
@as(u64, @sizeOf(std.zip.LocalFileHeader)) +
local_data_header_offset;
self.pickle_subfile = .{ .start = local_data_file_offset, .len = entry.uncompressed_size };
return;
}
log.err("Could not find file ending in `data.pkl` in archive", .{});
return error.PickleNotFound;
var reader: std.Io.Reader = .fixed(self.pickle_subfile);
return try pickle.parse(allocator, &reader);
}
fn basicTypeCheck(object: *const py.Object, module: []const u8, class: []const u8) bool {
@ -286,7 +202,7 @@ pub const File = struct {
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
try self.parseValue(allocator, store, new_prefix, val);
}
}
@ -303,7 +219,7 @@ pub const File = struct {
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
const new_tag = switch (tag) {
.int64 => "int",
.float64 => "float",
@ -321,7 +237,7 @@ pub const File = struct {
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
}
try self.parseValue(allocator, store, new_prefix, item);
}
@ -353,7 +269,7 @@ pub const File = struct {
if (prefix.items.len > 0) {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{});
new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{});
try self.parseValue(allocator, store, new_prefix, val);
},
inline else => |_, tag| {
@ -504,34 +420,10 @@ pub const File = struct {
/// Given the name of one of the files in the .pt tarball,
/// return the slice of the memory-mapped .pt corresponding to it.
fn getStorage(self: File, filename: []const u8) ![]const u8 {
const maybe_entry = self.file_map.get(filename);
if (maybe_entry == null) {
return self.file_map.get(filename) orelse {
std.log.err("Could not find file ending in `{s}` in archive", .{filename});
return error.TensorNotFound;
}
const entry = maybe_entry.?;
const base_offset: u64 = if (self.tar_file) |t| t.start else 0;
const file_offset: u64 = base_offset + entry.file_offset;
const file = self.buffer_file.file;
try file.seekTo(entry.file_offset);
const local_header = try file.reader().readStructEndian(std.zip.LocalFileHeader, .little);
if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig))
return error.ZipBadFileOffset;
if (local_header.compressed_size != 0 and
local_header.compressed_size != entry.compressed_size)
return error.ZipMismatchCompLen;
if (local_header.uncompressed_size != 0 and
local_header.uncompressed_size != entry.uncompressed_size)
return error.ZipMismatchUncompLen;
if (local_header.filename_len != entry.filename_len)
return error.ZipMismatchFilenameLen;
const start = file_offset +
@sizeOf(std.zip.LocalFileHeader) +
@as(u64, local_header.filename_len) +
@as(u64, local_header.extra_len);
return self.buffer_file.mappedSlice(start, entry.uncompressed_size);
};
}
fn parseDims(values: []py.Any) error{InvalidInput}!zml.Shape.DimsArray {
@ -578,52 +470,6 @@ fn storageToDtype(storage_type: []const u8) !zml.DataType {
};
}
/// Adapter exposing a single file embedded inside a tar archive as a seekable
/// stream. Every seek is rebased by `start` (the entry's offset inside the
/// archive), so callers can address the inner file as if it began at byte 0.
const TarStream = struct {
    pub const SeekableStream = std.io.SeekableStream(
        TarStream,
        asynk.File.SeekError,
        asynk.File.GetSeekPosError,
        TarStream.seekTo,
        TarStream.seekBy,
        TarStream.getPos,
        TarStream.getEndPos,
    );

    // Handle to the entry inside the tar archive (std.tar's bounded sub-file).
    file: std.tar.Iterator(asynk.File.Reader).File,
    // Absolute position of the entry's first byte within the underlying archive.
    start: usize,

    /// Captures the current archive position as the logical start of the entry.
    /// NOTE(review): assumes the archive cursor sits at the entry's first byte
    /// when this is called — confirm against std.tar.Iterator usage.
    pub fn init(file: std.tar.Iterator(asynk.File.Reader).File) !TarStream {
        return .{
            .file = file,
            .start = try file.parent_reader.context.getPos(),
        };
    }

    /// Reader over the entry's bytes; delegates to std.tar's per-file reader.
    pub fn reader(file: TarStream) std.tar.Iterator(asynk.File.Reader).File.Reader {
        return file.file.reader();
    }

    /// Absolute seek, expressed relative to the embedded file (not the archive).
    pub fn seekTo(self: TarStream, offset: u64) !void {
        return self.file.parent_reader.context.seekTo(self.start + offset);
    }

    /// Relative seek; forwarded as-is since a delta needs no rebasing.
    pub fn seekBy(self: TarStream, offset: i64) !void {
        return self.file.parent_reader.context.seekBy(offset);
    }

    /// Current position within the embedded file (archive position minus start).
    pub fn getPos(self: TarStream) !u64 {
        return try self.file.parent_reader.context.getPos() - self.start;
    }

    /// Size of the embedded file, as recorded by the tar iterator.
    pub fn getEndPos(self: TarStream) !u64 {
        return self.file.size;
    }

    /// Wraps self in the SeekableStream interface declared above.
    pub fn seekableStream(self: TarStream) TarStream.SeekableStream {
        return .{ .context = self };
    }
};
test "Read pickle (zipped)" {
// test file created with following python snippet:
//
@ -638,20 +484,19 @@ test "Read pickle (zipped)" {
defer store.deinit();
{
var tmp_arena = std.heap.ArenaAllocator.init(testing.allocator);
defer tmp_arena.deinit();
const tmp_alloc = tmp_arena.allocator();
var torch_file = try File.init(tmp_alloc, mmap_file);
var arena = std.heap.ArenaAllocator.init(testing.allocator);
defer arena.deinit();
var torch_file = try File.init(arena.allocator(), mmap_file);
// We don't close the file directly, it will be closed by the store.
const ops = try torch_file.parsePickle(tmp_alloc);
const ops = try torch_file.parsePickle(arena.allocator());
try std.testing.expectEqual(302, ops.len);
const py_values = try eval.evaluate(tmp_alloc, ops, true);
const py_values = try eval.evaluate(arena.allocator(), ops, true);
try torch_file.parseModel(py_values, &store);
}
// now we have freed the tmp_arena.
// now we have freed the arena.
// all data needed should have been copied into the store arena.
try zml.testing.expectEqualShapes(
zml.Shape.init(.{ 1, 4 }, .u8),

View File

@ -763,54 +763,54 @@ pub const Op = union(enum) {
};
/// Read a stream of bytes, and interpret it as a stream of Pickle operators.
pub fn parse(allocator: std.mem.Allocator, reader: anytype, max_line_len: usize) ![]const Op {
var results = std.ArrayList(Op).init(allocator);
errdefer results.deinit();
const len = max_line_len;
var _buf: std.BoundedArray(u8, 12) = .{};
/// The given allocator needs to be an arena cause we are not aligning allocations to avoid copies.
pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op {
// It's not very efficient to interleave the results with the data copied from the stream,
// because growth event in the results ArrayList will lead to fragmentation.
// Trying to mitigate that by using a generous default size.
var results: std.ArrayListUnmanaged(Op) = try .initCapacity(arena, 512);
errdefer results.deinit(arena);
var alloc_writer = try std.Io.Writer.Allocating.initCapacity(arena, 512);
while (true) {
const b = try reader.readByte();
const code: OpCode = @enumFromInt(b);
const code: OpCode = @enumFromInt(try reader.takeByte());
const op: Op = switch (code) {
.int => blk: {
_buf.len = 0;
try reader.streamUntilDelimiter(_buf.writer(), '\n', _buf.capacity() + 1);
const buf = _buf.constSlice();
.int => int: {
const bytes = try reader.takeDelimiterExclusive('\n');
// Legacy hack, see OpCode.int documentation
// We do this parsing right away to simplify downstream code.
break :blk if (std.mem.eql(u8, "00", buf))
break :int if (bytes.len == 2 and bytes[0] == '0' and bytes[1] == '0')
.{ .bool = false }
else if (std.mem.eql(u8, "01", buf))
else if (bytes.len == 2 and bytes[0] == '0' and bytes[1] == '1')
.{ .bool = true }
else
.{ .int = try std.fmt.parseInt(i32, buf, 10) };
.{ .int = try std.fmt.parseInt(i32, bytes, 10) };
},
.binint => .{ .int = try reader.readInt(i32, .little) },
.binint1 => .{ .int = try reader.readByte() },
.binint2 => .{ .int = try reader.readInt(u16, .little) },
.binint => .{ .int = try reader.takeInt(i32, .little) },
.binint1 => .{ .int = try reader.takeByte() },
.binint2 => .{ .int = try reader.takeInt(u16, .little) },
// TODO: long should handle the trailing 'L' -> add a test.
.long => .{ .long = try reader.readUntilDelimiterAlloc(allocator, '\n', len) },
.long1 => .{ .binlong = try _readSlice(reader, allocator, 1) },
.long4 => .{ .binlong = try _readSlice(reader, allocator, 4) },
.string => .{ .string = try reader.readUntilDelimiterAlloc(allocator, '\n', len) },
.binstring => .{ .string = try _readSlice(reader, allocator, 4) },
.short_binstring => .{ .string = try _readSlice(reader, allocator, 1) },
.binbytes => .{ .bytes = try _readSlice(reader, allocator, 4) },
.binbytes8 => .{ .bytes = try _readSlice(reader, allocator, 8) },
.short_binbytes => .{ .bytes = try _readSlice(reader, allocator, 1) },
.bytearray8 => .{ .bytearray = try _readSlice(reader, allocator, 8) },
.long => .{ .long = try readLine(reader, &alloc_writer) },
.long1 => .{ .binlong = try _readSlice(reader, arena, 1) },
.long4 => .{ .binlong = try _readSlice(reader, arena, 4) },
.string => .{ .string = try readLine(reader, &alloc_writer) },
.binstring => .{ .string = try _readSlice(reader, arena, 4) },
.short_binstring => .{ .string = try _readSlice(reader, arena, 1) },
.binbytes => .{ .bytes = try _readSlice(reader, arena, 4) },
.binbytes8 => .{ .bytes = try _readSlice(reader, arena, 8) },
.short_binbytes => .{ .bytes = try _readSlice(reader, arena, 1) },
.bytearray8 => .{ .bytearray = try _readSlice(reader, arena, 8) },
.next_buffer => .next_buffer,
.readonly_buffer => .readonly_buffer,
.none => .none,
.newtrue => .{ .bool = true },
.newfalse => .{ .bool = false },
.unicode => .{ .unicode = try reader.readUntilDelimiterAlloc(allocator, '\n', len) },
.short_binunicode => .{ .unicode = try _readSlice(reader, allocator, 1) },
.binunicode => .{ .unicode = try _readSlice(reader, allocator, 4) },
.binunicode8 => .{ .unicode = try _readSlice(reader, allocator, 8) },
.float => .{ .float = try reader.readUntilDelimiterAlloc(allocator, '\n', len) },
.binfloat => .{ .binfloat = @bitCast(try reader.readInt(u64, .big)) },
.unicode => .{ .unicode = try readLine(reader, &alloc_writer) },
.short_binunicode => .{ .unicode = try _readSlice(reader, arena, 1) },
.binunicode => .{ .unicode = try _readSlice(reader, arena, 4) },
.binunicode8 => .{ .unicode = try _readSlice(reader, arena, 8) },
.float => .{ .float = try readLine(reader, &alloc_writer) },
.binfloat => .{ .binfloat = @bitCast(try reader.takeInt(u64, .big)) },
.empty_list => .empty_list,
.append => .append,
.appends => .appends,
@ -832,74 +832,74 @@ pub fn parse(allocator: std.mem.Allocator, reader: anytype, max_line_len: usize)
.mark => .mark,
.pop_mark => .pop_mark,
// If we fail to parse delay the error to the evaluation.
.get => .{
.get = _readDigits(u32, reader, &_buf) catch std.math.maxInt(u32),
.get => get: {
const digits = try reader.takeDelimiterExclusive('\n');
break :get .{ .get = std.fmt.parseInt(u32, digits, 10) catch std.math.maxInt(u32) };
},
.binget => .{ .get = try reader.readByte() },
.long_binget => .{ .get = try reader.readInt(u32, .little) },
.put => blk: {
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
defer allocator.free(buf);
const n = std.fmt.parseInt(u32, buf, 10) catch std.math.maxInt(u32);
break :blk .{ .put = n };
.binget => .{ .get = try reader.takeByte() },
.long_binget => .{ .get = try reader.takeInt(u32, .little) },
.put => put: {
const digits = try reader.takeDelimiterExclusive('\n');
break :put .{ .put = std.fmt.parseInt(u32, digits, 10) catch std.math.maxInt(u32) };
},
.binput => .{ .put = try reader.readByte() },
.long_binput => .{ .put = try reader.readInt(u32, .little) },
.binput => .{ .put = try reader.takeByte() },
.long_binput => .{ .put = try reader.takeInt(u32, .little) },
.memoize => .memoize,
.ext1 => .{ .ext1 = try reader.readByte() },
.ext2 => .{ .ext2 = try reader.readInt(i16, .little) },
.ext4 => .{ .ext4 = try reader.readInt(i32, .little) },
.ext1 => .{ .ext1 = try reader.takeByte() },
.ext2 => .{ .ext2 = try reader.takeInt(i16, .little) },
.ext4 => .{ .ext4 = try reader.takeInt(i32, .little) },
.global => .{ .global = .{
.module = try reader.readUntilDelimiterAlloc(allocator, '\n', len),
.class = try reader.readUntilDelimiterAlloc(allocator, '\n', len),
.module = try readLine(reader, &alloc_writer),
.class = try readLine(reader, &alloc_writer),
} },
.stack_global => .stack_global,
.reduce => .reduce,
.build => .build,
.inst => .{ .inst = .{
.module = try reader.readUntilDelimiterAlloc(allocator, '\n', len),
.class = try reader.readUntilDelimiterAlloc(allocator, '\n', len),
.module = try readLine(reader, &alloc_writer),
.class = try readLine(reader, &alloc_writer),
} },
.obj => .obj,
.newobj => .newobj,
.newobj_ex => .newobj_ex,
.proto => blk: {
const version = try reader.readByte();
const version = try reader.takeByte();
if (version > 5) log.warn("zml.aio.torch.pickle.parse expects a Python pickle object of version <=5, got version {}. Will try to interpret anyway, but this may lead to more errors.", .{version});
break :blk .{ .proto = version };
},
.stop => .stop,
// This is not documented in pickletools but in https://peps.python.org/pep-3154/
// The frame size is stored right after the frame header.
// The loader is allowed to prefetch framesize from the underlying reader,
// and ops are not allowed to cross a frame boundary.
// We don't prefetch because we assume the reader is going to use some kind of buffered reader.
// We could try to enforce frame boundaries, but we would need to track
// how many bytes we are reading from the stream.
.frame => .{ .frame = try reader.readInt(u64, .little) },
.persid => .{ .persid = try reader.readUntilDelimiterAlloc(allocator, '\n', len) },
.frame => frame: {
// This is not documented in pickletools but in https://peps.python.org/pep-3154/
// The loader is allowed to prefetch framesize from the underlying reader,
// and ops are not allowed to cross a frame boundary.
const frame_size = try reader.takeInt(u64, .little);
reader.fill(@min(frame_size, reader.buffer.len)) catch |err| switch (err) {
error.EndOfStream => {},
else => return err,
};
break :frame .{ .frame = frame_size };
},
.persid => .{ .persid = try readLine(reader, &alloc_writer) },
.binpersid => .binpersid,
_ => |unk_tag| {
log.err("Unknow pickle operator {}, note we are only supporting pickle protocol up to version 5.", .{unk_tag});
return error.NotSupported;
},
};
try results.append(op);
try results.append(arena, op);
if (op == .stop) break;
}
return results.toOwnedSlice();
return results.items;
}
test "parse protocol 4" {
const allocator = std.testing.allocator;
var arena: std.heap.ArenaAllocator = .init(std.testing.allocator);
defer arena.deinit();
const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only });
var buffered_reader = std.io.bufferedReader(file.reader());
const ops = try parse(allocator, buffered_reader.reader(), 4096);
defer {
// Test we are correctly freeing every allocation.
for (ops) |op| op.deinit(allocator);
allocator.free(ops);
}
var read_buffer: [1024]u8 = undefined;
var reader = file.reader(&read_buffer);
const ops = try parse(arena.allocator(), &reader.interface);
// this can be obtained by running: `python -m pickletools simple_test_4.pickle`
var expected = [_]Op{
@ -948,7 +948,9 @@ test "parse protocol 4" {
test "parse protocol 0" {
// We also test protocol 0, cause it's more text oriented.
const allocator = std.testing.allocator;
var arena: std.heap.ArenaAllocator = .init(std.testing.allocator);
defer arena.deinit();
const pickle_0 =
\\(dp0
\\Vhello
@ -982,13 +984,8 @@ test "parse protocol 0" {
\\s.
;
var stream = std.io.fixedBufferStream(pickle_0);
const ops = try parse(allocator, stream.reader(), 4096);
defer {
// Test we are correctly freeing every allocation.
for (ops) |op| op.deinit(allocator);
allocator.free(ops);
}
var reader: std.Io.Reader = .fixed(pickle_0);
const ops = try parse(arena.allocator(), &reader);
var expected = [_]Op{
.mark,
@ -1043,18 +1040,11 @@ test "parse protocol 0" {
try std.testing.expectEqualDeep(&expected, ops);
}
/// Reads ASCII digits up to the next '\n' into `buffer` and parses them as `T`
/// (base 10). The delimiter is consumed from the reader but not stored.
fn _readDigits(comptime T: type, reader: anytype, buffer: *std.BoundedArray(u8, 12)) !T {
    // Reuse the caller-provided scratch buffer; reset its length before use.
    buffer.len = 0;
    // Max line length 13 exceeds the 12-byte capacity on purpose: an overlong
    // line fails via the BoundedArray writer before StreamTooLong can trigger.
    try reader.streamUntilDelimiter(buffer.writer(), '\n', 13);
    return std.fmt.parseInt(T, buffer.constSlice(), 10);
}
fn _readSlice(reader: anytype, allocator: std.mem.Allocator, comptime len_bytes: u8) ![]u8 {
const T = std.meta.Int(.unsigned, 8 * len_bytes);
const str_len: u64 = try reader.readInt(T, .little);
const str_len: u64 = try reader.takeInt(T, .little);
const buf = try allocator.alloc(u8, str_len);
errdefer allocator.free(buf);
_ = try reader.read(buf);
_ = try reader.readSliceAll(buf);
return buf;
}
@ -1063,3 +1053,14 @@ fn writeIntBuff(comptime T: type, value: T) [@divExact(@typeInfo(T).int.bits, 8)
std.mem.writeInt(T, &res, value, .little);
return res;
}
/// Reads one '\n'-terminated line from `reader` and returns it as a stable
/// slice: the bytes are carved out of the Allocating writer's buffer, and the
/// writer is then advanced past them so a later call cannot overwrite them.
/// The trailing '\n' is consumed and not included in the result.
/// NOTE(review): after `w.buffer` is re-sliced below, the writer no longer
/// holds the allocation's base pointer, so this is only sound when the
/// writer's allocator is an arena (freeing is a no-op) — see `parse`'s doc.
fn readLine(reader: *std.Io.Reader, alloc_writer: *std.Io.Writer.Allocating) ![]const u8 {
    // Stream everything up to (but excluding) the '\n' into the writer.
    const n = try reader.streamDelimiter(&alloc_writer.writer, '\n');
    // streamDelimiter leaves the delimiter in the reader; consume and check it.
    std.debug.assert(try reader.takeByte() == '\n');
    const w = &alloc_writer.writer;
    // Exactly this line's n bytes must sit at the front of the writer's buffer,
    // i.e. the previous call fully retired its own bytes.
    std.debug.assert(w.end == n);
    const items = w.buffer[0..n];
    // Retire the returned bytes (plus one spare byte) from the writer's view
    // and reset its length so future writes land after `items`.
    // NOTE(review): if buffer.len == n, the n + 1 slice is out of bounds —
    // presumably the writer always grows with spare capacity; confirm.
    w.buffer = w.buffer[n + 1 ..];
    w.end = 0;
    return items;
}

View File

@ -172,7 +172,7 @@ pub const HostBuffer = struct {
// TODO we should allow interpreting the output as @Vector(8, f32) when the tensor is f32.
stdx.debug.assert(DataType.fromZigType(T) == self.dtype(), "Can't reinterpret {f} as {s}", .{ self, @typeName(T) });
stdx.debug.assert(self.isContiguous(), "{f} isn't contiguous, can't interpret as []const u8", .{self});
const ptr: [*]const T = @alignCast(@ptrCast(self._data));
const ptr: [*]const T = @ptrCast(@alignCast(self._data));
return ptr[0..self._shape.count()];
}

View File

@ -664,7 +664,7 @@ pub const CompilationContext = struct {
// Create the result tensor object by combining the operand results,
// as well as the registered shapes and donations.
// Note: this assume res can be stack-allocated.
var res = @as(*const stdx.meta.FnResult(func), @alignCast(@ptrCast(function.res_tensors))).*;
var res = @as(*const stdx.meta.FnResult(func), @ptrCast(@alignCast(function.res_tensors))).*;
const LocalContext = struct { index: usize = 0, op: mlir.Operation, function: MlirFn, donations: []Tensor._Donation };
var context: LocalContext = .{ .op = op, .function = function, .donations = donations };
meta.visit((struct {

View File

@ -113,7 +113,7 @@ const _CreateOptions = struct {
/// "Best-Fit with Coalescing" algorithm
bfc: Options,
/// use cudaMallocAsync
@"async": Options,
async: Options,
/// use raw cuMalloc
platform,
@ -129,7 +129,7 @@ const _CreateOptions = struct {
.platform => {
values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform"));
},
.bfc, .@"async" => |opt| {
.bfc, .async => |opt| {
values.appendAssumeCapacity(pjrt.NamedValue.from("allocator", self.allocator));
values.appendAssumeCapacity(pjrt.NamedValue.from("preallocate", opt.preallocate));
if (opt.memory_fraction > 0) {

View File

@ -1,4 +1,5 @@
const builtin = @import("builtin");
const c = @import("c");
pub const Tracer = switch (builtin.os.tag) {
@ -11,15 +12,15 @@ const CudaTracer = struct {
// Those symbols are defined in cudaProfiler.h but their implementation is in libcuda.so
// They will be bound at call time after libcuda.so is loaded (as a needed dependency of libpjrt_cuda.so).
const cuProfilerStart = @extern(*const fn () callconv(.C) c_int, .{ .name = "cuProfilerStart", .linkage = .weak }) orelse unreachable;
const cuProfilerStop = @extern(*const fn () callconv(.C) c_int, .{ .name = "cuProfilerStop", .linkage = .weak }) orelse unreachable;
const cuProfilerStart = @extern(*const fn () callconv(.c) c_int, .{ .name = "cuProfilerStart", .linkage = .weak }) orelse unreachable;
const cuProfilerStop = @extern(*const fn () callconv(.c) c_int, .{ .name = "cuProfilerStop", .linkage = .weak }) orelse unreachable;
// Those symbols are defined in nvToolsExt.h which we don't want to provide.
// However, we link with libnvToolsExt.so which provides them.
// They will be bound at call time after libnvToolsExt.so is loaded (manually dlopen'ed by us).
const nvtxMarkA = @extern(*const fn ([*:0]const u8) callconv(.C) void, .{ .name = "nvtxMarkA", .linkage = .weak }) orelse unreachable;
const nvtxRangeStartA = @extern(*const fn ([*:0]const u8) callconv(.C) c_int, .{ .name = "nvtxRangeStartA", .linkage = .weak }) orelse unreachable;
const nvtxRangeEnd = @extern(*const fn (c_int) callconv(.C) void, .{ .name = "nvtxRangeEnd", .linkage = .weak }) orelse unreachable;
const nvtxMarkA = @extern(*const fn ([*:0]const u8) callconv(.c) void, .{ .name = "nvtxMarkA", .linkage = .weak }) orelse unreachable;
const nvtxRangeStartA = @extern(*const fn ([*:0]const u8) callconv(.c) c_int, .{ .name = "nvtxRangeStartA", .linkage = .weak }) orelse unreachable;
const nvtxRangeEnd = @extern(*const fn (c_int) callconv(.c) void, .{ .name = "nvtxRangeEnd", .linkage = .weak }) orelse unreachable;
pub fn init(name: [:0]const u8) CudaTracer {
_ = name;