diff --git a/mlir/mlir.zig b/mlir/mlir.zig index a222c40..a672191 100755 --- a/mlir/mlir.zig +++ b/mlir/mlir.zig @@ -605,7 +605,7 @@ pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type { pub fn items(self: Attr) []const dt.ZigType() { const raw_bytes: [*]const u8 = c.mlirDenseElementsAttrGetRawData(self._inner) orelse unreachable; - const ptr: [*]const dt.ZigType() = @alignCast(@ptrCast(raw_bytes)); + const ptr: [*]const dt.ZigType() = @ptrCast(@alignCast(raw_bytes)); // Note the mlir API returns us the number of elements, not the number of bytes, // that's why we track the element type at comptime to allow items to work. return ptr[0..self.len()]; @@ -1743,7 +1743,7 @@ pub const helpers = struct { writer: *std.Io.Writer, err: ?std.Io.Writer.Error = null, fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.c) void { - var ctx: *@This() = @alignCast(@ptrCast(opaque_ctx)); + var ctx: *@This() = @ptrCast(@alignCast(opaque_ctx)); if (ctx.err) |_| return; _ = ctx.writer.write(mlir_str.data[0..mlir_str.length]) catch |err| { ctx.err = err; diff --git a/pjrt/ffi.zig b/pjrt/ffi.zig index b79a789..d2517a0 100644 --- a/pjrt/ffi.zig +++ b/pjrt/ffi.zig @@ -359,7 +359,7 @@ pub const Attrs = extern struct { value: *const anyopaque, pub fn get(self: Scalar, T: type) T { - const ptr: *const T = @alignCast(@ptrCast(self.value)); + const ptr: *const T = @ptrCast(@alignCast(self.value)); return ptr.*; } }; @@ -370,13 +370,13 @@ pub const Attrs = extern struct { data: [*]const u8, pub fn slice(self: Array, T: type) []const T { - const ptr: [*]const T = @alignCast(@ptrCast(self.data)); + const ptr: [*]const T = @ptrCast(@alignCast(self.data)); return ptr[0..self.len]; } }; pub fn slice(self: Array, T: type) []const T { - const ptr: [*]const T = @alignCast(@ptrCast(self.data)); + const ptr: [*]const T = @ptrCast(@alignCast(self.data)); return ptr[0..self.len]; } diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index 87f3d17..2458149 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -58,7 +58,7 @@ pub const ApiError = error{ fn InnerMixin(comptime innerT: type) type { return struct { fn inner(self: anytype) *innerT { - return @ptrCast(@constCast(@alignCast(self))); + return @ptrCast(@alignCast(@constCast(self))); } }; } @@ -125,10 +125,10 @@ pub const Api = struct { } pub fn lookupExtension(self: *const Api, comptime ExtensionT: type, ext_id: c_int) ?*const ExtensionT { - var cur: [*c]const c.PJRT_Extension_Base = @alignCast(@ptrCast(self.inner.extension_start)); + var cur: [*c]const c.PJRT_Extension_Base = @ptrCast(@alignCast(self.inner.extension_start)); while (cur != null) : (cur = cur.*.next) { if (cur.*.type == ext_id) { - return @alignCast(@ptrCast(cur)); + return @ptrCast(@alignCast(cur)); } } @@ -432,7 +432,7 @@ pub const Client = opaque { .client = self.inner(), }) catch unreachable; if (ret.addressable_memories) |memories| { - return @constCast(@ptrCast(memories[0..ret.num_addressable_memories])); + return @ptrCast(@constCast(memories[0..ret.num_addressable_memories])); } return &.{}; } diff --git a/runtimes/cuda/cuda.zig b/runtimes/cuda/cuda.zig index c07c0e9..42b3b5f 100644 --- a/runtimes/cuda/cuda.zig +++ b/runtimes/cuda/cuda.zig @@ -28,7 +28,7 @@ fn hasCudaPathInLDPath() bool { fn setupXlaGpuCudaDirFlag(allocator: std.mem.Allocator, sandbox: []const u8) !void { const xla_flags = std.process.getEnvVarOwned(allocator, "XLA_FLAGS") catch ""; - const new_xla_flagsZ = try std.fmt.allocPrintZ(allocator, "{s} --xla_gpu_cuda_data_dir={s}", .{ xla_flags, sandbox }); + const new_xla_flagsZ = try std.fmt.allocPrintSentinel(allocator, "{s} --xla_gpu_cuda_data_dir={s}", .{ xla_flags, sandbox }, 0); _ = c.setenv("XLA_FLAGS", new_xla_flagsZ, 1); } diff --git a/runtimes/neuron/libneuronxla.zig b/runtimes/neuron/libneuronxla.zig index ae5710e..0010d28 100644 --- a/runtimes/neuron/libneuronxla.zig +++ b/runtimes/neuron/libneuronxla.zig @@ -38,7 +38,7 @@ var module_def: c.PyModuleDef = .{ .{}, }), .m_slots = @constCast(&[_]c.PyModuleDef_Slot{ - .{ .slot = c.Py_mod_exec, .value = @constCast(@ptrCast(&module_exec)) }, + .{ .slot = c.Py_mod_exec, .value = @ptrCast(@constCast(&module_exec)) }, .{}, }), .m_traverse = null, diff --git a/runtimes/neuron/neuron.zig b/runtimes/neuron/neuron.zig index 09d6a21..4d2efa1 100644 --- a/runtimes/neuron/neuron.zig +++ b/runtimes/neuron/neuron.zig @@ -25,10 +25,10 @@ fn isRunningOnEC2() !bool { var f = try asynk.File.open("/sys/devices/virtual/dmi/id/sys_vendor", .{ .mode = .read_only }); defer f.close() catch {}; - var buf: [AmazonEC2.len]u8 = undefined; - _ = try f.reader().readAll(&buf); + var content: [AmazonEC2.len]u8 = undefined; + const n_read = try f.pread(&content, 0); - return std.mem.eql(u8, &buf, AmazonEC2); + return std.mem.eql(u8, content[0..n_read], AmazonEC2); } pub fn load() !*const pjrt.Api { @@ -45,7 +45,7 @@ pub fn load() !*const pjrt.Api { return error.Unavailable; } - var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator); + var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator); defer arena.deinit(); var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse { diff --git a/runtimes/rocm/rocm.zig b/runtimes/rocm/rocm.zig index 0861f92..9be8b07 100644 --- a/runtimes/rocm/rocm.zig +++ b/runtimes/rocm/rocm.zig @@ -37,7 +37,7 @@ pub fn load() !*const pjrt.Api { return error.Unavailable; } - var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator); + var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator); defer arena.deinit(); var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse { diff --git a/runtimes/tpu/tpu.zig b/runtimes/tpu/tpu.zig index dc12b43..96af0be 100644 --- a/runtimes/tpu/tpu.zig +++ b/runtimes/tpu/tpu.zig @@ -1,12 +1,12 @@ -const builtin = @import("builtin"); const std = @import("std"); +const builtin = @import("builtin"); const asynk = @import("async"); -const pjrt = @import("pjrt"); -const c = @import("c"); -const stdx = @import("stdx"); const bazel_builtin = @import("bazel_builtin"); +const c = @import("c"); +const pjrt = @import("pjrt"); const runfiles = @import("runfiles"); +const stdx = @import("stdx"); const log = std.log.scoped(.@"zml/runtime/tpu"); @@ -25,10 +25,10 @@ fn isOnGCP() !bool { var f = try asynk.File.open("/sys/devices/virtual/dmi/id/product_name", .{ .mode = .read_only }); defer f.close() catch {}; - var buf = [_]u8{0} ** GoogleComputeEngine.len; - _ = try f.reader().readAll(&buf); + var content: [GoogleComputeEngine.len]u8 = undefined; + const n_read = try f.pread(&content, 0); - return std.mem.eql(u8, &buf, GoogleComputeEngine); + return std.mem.eql(u8, content[0..n_read], GoogleComputeEngine); } pub fn load() !*const pjrt.Api { @@ -42,7 +42,7 @@ pub fn load() !*const pjrt.Api { return error.Unavailable; } - var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator); + var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator); defer arena.deinit(); var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse { diff --git a/stdx/BUILD.bazel b/stdx/BUILD.bazel index 90456d1..e44bda2 100644 --- a/stdx/BUILD.bazel +++ b/stdx/BUILD.bazel @@ -8,7 +8,6 @@ zig_library( "flags.zig", "fmt.zig", "fs.zig", - "io.zig", "json.zig", "math.zig", "meta.zig", diff --git a/stdx/io.zig b/stdx/io.zig deleted file mode 100644 index 390647b..0000000 --- a/stdx/io.zig +++ /dev/null @@ -1,4 +0,0 @@ -const std = @import("std"); - -pub const BufferedAnyWriter = std.io.BufferedWriter(4096, std.io.AnyWriter); -pub const BufferedAnyReader = std.io.BufferedReader(4096, std.io.AnyReader); diff --git a/stdx/stdx.zig b/stdx/stdx.zig index 661dea6..9bfa532 100644 --- a/stdx/stdx.zig +++ b/stdx/stdx.zig @@ -4,7 +4,6 @@ pub const debug = @import("debug.zig"); pub const flags = @import("flags.zig"); pub const fmt = @import("fmt.zig"); pub const fs = @import("fs.zig"); -pub const io = @import("io.zig"); pub const json = @import("json.zig"); pub const math = @import("math.zig"); pub const meta = @import("meta.zig"); diff --git a/upb/upb.zig b/upb/upb.zig index 1a233ec..2a9b1d5 100644 --- a/upb/upb.zig +++ b/upb/upb.zig @@ -95,7 +95,7 @@ pub fn serialize(ptr: anytype, arena: *c.upb_Arena) SerializeError![]const u8 { pub fn parseEx(comptime UpbType: type, arena: *c.upb_Arena, data: []const u8, opts: ParseOptions) ParseError!*UpbType { const obj = try new(UpbType, arena); - return switch (c.upb_Decode(@ptrCast(@constCast(data)), data.len, @alignCast(@ptrCast(obj)), Minitable(UpbType), null, @bitCast(opts), arena)) { + return switch (c.upb_Decode(@ptrCast(@constCast(data)), data.len, @ptrCast(@alignCast(obj)), Minitable(UpbType), null, @bitCast(opts), arena)) { c.kUpb_DecodeStatus_Ok => obj, c.kUpb_DecodeStatus_Malformed => ParseError.Malformed, c.kUpb_DecodeStatus_OutOfMemory => std.mem.Allocator.Error.OutOfMemory, diff --git a/zml/BUILD.bazel b/zml/BUILD.bazel index e0722f7..aaa44f0 100644 --- a/zml/BUILD.bazel +++ b/zml/BUILD.bazel @@ -24,6 +24,11 @@ zig_library( "aio/json.zig", "aio/safetensors.zig", "aio/tinyllama.zig", + "aio/torch.zig", + "aio/torch/eval.zig", + "aio/torch/file.zig", + "aio/torch/pickle.zig", + "aio/torch/py.zig", "buffer.zig", "context.zig", "dtype.zig", @@ -72,6 +77,10 @@ zig_library( zig_test( name = "test", + data = [ + "aio/torch/simple.pt", + "aio/torch/simple_test_4.pickle", + ], test_runner = ":test_runner", deps = [":zml"], ) diff --git a/zml/aio.zig b/zml/aio.zig index 67e1876..1b926a9 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -5,16 +5,16 @@ const c = @import("c"); const stdx = @import("stdx"); pub const safetensors = @import("aio/safetensors.zig"); -pub const tinyllama = @import("aio/tinyllama.zig"); +pub const torch = @import("aio/torch.zig"); const HostBuffer = @import("hostbuffer.zig").HostBuffer; const posix = @import("posix.zig"); const zml = @import("zml.zig"); pub const log = std.log.scoped(.@"zml/aio"); - test { std.testing.refAllDecls(@This()); std.testing.refAllDecls(safetensors); + std.testing.refAllDecls(torch); } // TODO error set for weight loading @@ -25,6 +25,12 @@ pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8) try safetensors.open(allocator, model_path) else if (std.mem.endsWith(u8, model_path, ".safetensors.index.json")) try safetensors.open(allocator, model_path) + else if (std.mem.endsWith(u8, model_path, ".pt")) + try torch.open(allocator, model_path) + // else if (std.mem.endsWith(u8, model_path, ".gguf")) + // try gguf.open(allocator, model_path) + // else if (std.mem.endsWith(u8, model_path, ".tinyllama")) + // try tinyllama.open(allocator, model_path) else { std.debug.panic("File extension not recognized: {s}", .{model_path}); }; @@ -384,7 +390,7 @@ fn _populateStruct( partial_struct = partial_struct or field_found; if (!field_found) { if (field.default_value_ptr) |v| { - @field(obj, field.name) = @as(*const field.type, @alignCast(@ptrCast(v))).*; + @field(obj, field.name) = @as(*const field.type, @ptrCast(@alignCast(v))).*; } else { if (partial_struct) { log.warn("Incomplete metadata '{0s}': {1s}. Missing field: '{2s}'. '{0s}' will be ignored.", .{ prefix, @typeName(T), field.name }); diff --git a/zml/aio/gguf.zig b/zml/aio/gguf.zig deleted file mode 100644 index 3134f96..0000000 --- a/zml/aio/gguf.zig +++ /dev/null @@ -1,83 +0,0 @@ -const asynk = @import("async"); -const core = @import("gguf/core.zig"); -const std = @import("std"); -const zml = @import("../zml.zig"); - -const HostBuffer = @import("../hostbuffer.zig").HostBuffer; - -const Allocator = std.mem.Allocator; -const assert = std.debug.assert; - -const log = std.log.scoped(.@"zml/io"); - -pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore { - var file = try core.GgufFile.open(path); - errdefer file.close(); - - var res: zml.aio.BufferStore = .{ - .arena = std.heap.ArenaAllocator.init(allocator), - }; - errdefer res.arena.deinit(); - const arena = res.arena.allocator(); - - res.files = try arena.dupe(zml.aio.MemoryMappedFile, &.{file.file}); - - // metadata must be read in order to read tensors - try loadMetadata(arena, &res, &file); - try loadBuffers(arena, &res, &file); - if (res.buffers.count() != file.header.tensor_count) { - log.warn("Expected to find {d} tensors in {s}, only found {d}", .{ file.header.tensor_count, path, res.buffers.count() }); - } - return res; -} - -fn loadMetadata(allocator: Allocator, store: *zml.aio.BufferStore, file: *core.GgufFile) !void { - try store._metadata.ensureTotalCapacity(allocator, @intCast(file.header.metadata_kv_count)); - - while (file.readMetadata(allocator)) |entry| { - log.info("Loading MetaData: {s}", .{entry.name}); - const res = store._metadata.getOrPutAssumeCapacity(entry.name); - if (res.found_existing) { - // This file seems invalid. Since most metadatas aren't required, continue ahead. - log.warn("Found duplicated metadata key: {s}", .{entry.name}); - continue; - } - res.value_ptr.* = switch (entry.val) { - .array => |arr| switch (arr.child) { - inline .uint8, .int8, .uint16, .int16, .uint32, .int32, .float32, .bool, .string, .uint64, .int64, .float64 => |tag| blk: { - const T = @FieldType(core.GgufValue, @tagName(tag)); - break :blk try zml.aio.Metadata.copySlice(allocator, std.mem.bytesAsSlice(T, arr.data)); - }, - else => blk: { - log.warn("ignoring array metadata", .{}); - break :blk .null; - }, - }, - inline else => |v| zml.aio.Metadata.wrap(v), - }; - } else |err| switch (err) { - error.EndOfMetadata => {}, - else => return err, - } -} - -fn loadBuffers(allocator: Allocator, store: *zml.aio.BufferStore, file: *core.GgufFile) !void { - try store.buffers.ensureTotalCapacity(allocator, @intCast(file.header.tensor_count)); - while (file.readTensorInfo(allocator)) |info| { - const res = store.buffers.getOrPutAssumeCapacity(info.name); - if (res.found_existing) { - // This file seems invalid. Try to continue anyway. - log.warn("Found duplicated tensor: {s}", .{info.name}); - continue; - } - - // TODO: handle quantized types - const dtype: zml.DataType = info.t.toDtype() orelse return error.UnsupportedGgufType; - const buffer = HostBuffer.fromBytes(zml.Shape.init(info.shape(), dtype), file.file.mappedSlice(info.start, info.byte_len)); - res.value_ptr.* = buffer; - // store the info index. - } else |err| switch (err) { - error.EndOfMetadata => {}, - else => return err, - } -} diff --git a/zml/aio/gguf/core.zig b/zml/aio/gguf/core.zig deleted file mode 100644 index 27102a2..0000000 --- a/zml/aio/gguf/core.zig +++ /dev/null @@ -1,505 +0,0 @@ -const asynk = @import("async"); -const std = @import("std"); -const zml = @import("../../zml.zig"); - -const assert = std.debug.assert; -const log = std.log.scoped(.@"zml/io"); - -pub const GgufErrors = error{ - ValueTypeMismatch, - InvalidGguf, - UnsupportedGgufType, - EndOfMetadata, - OutOfMemory, -}; - -// Enums and structures -pub const TensorType = enum(u32) { - f32 = 0, - f16 = 1, - q4_0 = 2, - q4_1 = 3, - deprecated_q4_2 = 4, - deprecated_q4_3 = 5, - q5_0 = 6, - q5_1 = 7, - q8_0 = 8, - q8_1 = 9, - // k-quantizations - q2_k = 10, - q3_k = 11, - q4_k = 12, - q5_k = 13, - q6_k = 14, - q8_k = 15, - i8 = 16, - i16 = 17, - i32 = 18, - - const MAX_KNOWN_ENUM = 18; - - pub fn canConvertQuant(self: TensorType) bool { - return switch (self) { - .q8_0, .q4_k, .q6_k, .q2_k, .q4_0, .q4_1 => true, - else => false, - }; - } - - pub fn toDtype(self: TensorType) ?zml.DataType { - return switch (self) { - .f32 => .f32, - .f16 => .f16, - .i8 => .i8, - .i16 => .i16, - .i32 => .i32, - else => null, - }; - } - - pub fn sizeOf(self: TensorType) usize { - return self.toDtype().?.sizeOf(); - } - - /// Return the tensor type features - pub fn getFeatures(t: TensorType) TensorTypeFeatures { - return switch (t) { - inline else => |val| @field(TENSOR_TYPE_FEATURES, @tagName(val)), - }; - } -}; - -/// GGUF tensor type to features lookup table. -pub const TensorTypeFeatures = struct { - items_per_block: u29, - bytes_per_block: u29, - - pub fn alignment(features: TensorTypeFeatures) u8 { - return std.math.log2_int(u29, features.bytes_per_block); - } -}; - -pub const TENSOR_TYPE_FEATURES: std.enums.EnumFieldStruct(TensorType, TensorTypeFeatures, null) = .{ - .f32 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(f32) }, - .f16 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(f16) }, - .q4_0 = .{ .items_per_block = 32, .bytes_per_block = 18 }, - .q4_1 = .{ .items_per_block = 32, .bytes_per_block = 20 }, - .deprecated_q4_2 = .{ .items_per_block = 0, .bytes_per_block = 0 }, - .deprecated_q4_3 = .{ .items_per_block = 0, .bytes_per_block = 0 }, - .q5_0 = .{ .items_per_block = 32, .bytes_per_block = 22 }, - .q5_1 = .{ .items_per_block = 32, .bytes_per_block = 24 }, - .q8_0 = .{ .items_per_block = 32, .bytes_per_block = 34 }, - .q8_1 = .{ .items_per_block = 32, .bytes_per_block = 40 }, - .q2_k = .{ .items_per_block = 256, .bytes_per_block = 82 }, - .q3_k = .{ .items_per_block = 256, .bytes_per_block = 110 }, - .q4_k = .{ .items_per_block = 256, .bytes_per_block = 144 }, - .q5_k = .{ .items_per_block = 256, .bytes_per_block = 176 }, - .q6_k = .{ .items_per_block = 256, .bytes_per_block = 210 }, - .q8_k = .{ .items_per_block = 256, .bytes_per_block = 292 }, - .i8 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(i8) }, - .i16 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(i16) }, - .i32 = .{ .items_per_block = 1, .bytes_per_block = @sizeOf(i32) }, -}; - -pub const GgufValueType = enum(u32) { - // The value is a 8-bit unsigned integer. - uint8 = 0, - // The value is a 8-bit signed integer. - int8 = 1, - // The value is a 16-bit unsigned little-endian integer. - uint16 = 2, - // The value is a 16-bit signed little-endian integer. - int16 = 3, - // The value is a 32-bit unsigned little-endian integer. - uint32 = 4, - // The value is a 32-bit signed little-endian integer. - int32 = 5, - // The value is a 32-bit IEEE754 floating point number. - float32 = 6, - // The value is a boolean. - // 1-byte value where 0 is false and 1 is true. - // Anything else is invalid, and should be treated as either the model - // being invalid or the reader being buggy. - bool = 7, - // The value is a UTF-8 non-null-terminated string, with length prepended. - string = 8, - // The value is an array of other values, with the length and type - // prepended. Arrays can be nested, and the length of the array is the - // number of elements in the array, not the number of bytes. - array = 9, - // The value is a 64-bit unsigned little-endian integer. - uint64 = 10, - // The value is a 64-bit signed little-endian integer. - int64 = 11, - // The value is a 64-bit IEEE754 floating point number. - float64 = 12, - // Special values used by the callbacks of gguf_do_with_value(). - array_start = 100, - array_end = 101, - - // Allow other values in case GGUF add more types without us noticing - _, - - pub fn sizeOf(self: GgufValueType) usize { - return switch (self) { - .uint8 => @sizeOf(u8), - .int8 => @sizeOf(i8), - .uint16 => @sizeOf(u16), - .int16 => @sizeOf(i16), - .uint32 => @sizeOf(u32), - .int32 => @sizeOf(i32), - .float32 => @sizeOf(f32), - .bool => @sizeOf(bool), - .uint64 => @sizeOf(u64), - .int64 => @sizeOf(i64), - .float64 => @sizeOf(f64), - .string => @sizeOf([]u8), - else => unreachable, - }; - } - - pub fn arrayTypeCheck(self: GgufValueType, comptime T: type) !void { - switch (self) { - .string => if (T != []u8 and T != []const u8) return error.ValueTypeMismatch, - .uint8 => if (T != u8) return error.ValueTypeMismatch, - .int8 => if (T != i8) return error.ValueTypeMismatch, - .uint16 => if (T != u16) return error.ValueTypeMismatch, - .int16 => if (T != i16) return error.ValueTypeMismatch, - .uint32 => if (T != u32) return error.ValueTypeMismatch, - .int32 => if (T != i32) return error.ValueTypeMismatch, - .float32 => if (T != f32) return error.ValueTypeMismatch, - .bool => if (T != bool) return error.ValueTypeMismatch, - .uint64 => if (T != u64) return error.ValueTypeMismatch, - .int64 => if (T != i64) return error.ValueTypeMismatch, - .float64 => if (T != f64) return error.ValueTypeMismatch, - else => {}, - } - } -}; - -pub const ValueType = enum(u8) { - uint8 = 0, - int8 = 1, - uint16 = 2, - int16 = 3, - uint32 = 4, - int32 = 5, - float32 = 6, - bool = 7, - string = 8, - array = 9, - uint64 = 10, - int64 = 11, - float64 = 12, -}; - -// Union of possible values. -pub const GgufValue = union(ValueType) { - uint8: u8, - int8: i8, - uint16: u16, - int16: i16, - uint32: u32, - int32: i32, - float32: f32, - bool: bool, - string: []const u8, - array: Array, - uint64: u64, - int64: i64, - float64: f64, - - pub const Array = struct { - // Any value type is valid, including arrays. - child: ValueType, - // Number of elements, not bytes - len: usize, - data: []u8, - }; -}; - -// Header -const GgufHeader = extern struct { - // Magic number to announce that this is a GGUF file. Must be `GUFF`. - magic: [4]u8, - // The version of the format implemented. - // Must be `3` for version described in this spec. - version: u32, - // The number of tensors in the file. - // This is explicit, instead of being included in the metadata, to ensure - // it is always present for loading the tensors. - tensor_count: usize, - // The number of metadata key-value pairs. - metadata_kv_count: usize, - - pub fn validate(self: GgufHeader) !void { - if (!std.mem.eql(u8, &self.magic, "GGUF")) { - log.err("Invalid GGUF file: wrong header {s}", .{self.magic}); - return error.InvalidGguf; - } - } -}; - -// Key representation in this library API. -pub const GgufMetadataKv = struct { - name: []const u8, - type_: GgufValueType, - val: GgufValue, -}; - -// Tensor representation in this library API. -const GGUF_TENSOR_MAX_DIM: usize = 8; // Future-proof: actual limit is 4. -pub const GgufTensorInfo = struct { - name: []const u8, - t: TensorType, // Tensor type (enum TensorType). - rank: usize, // Number of dimensions of the tensor. - dims: [GGUF_TENSOR_MAX_DIM]i64, // Dimensions (Eg. [512, 1024, 1, 1]). - start: usize, // Offset from start of data section. - byte_len: usize, // Total size in bytes. - num_weights: usize, // Total number of parameters. - - pub inline fn shape(info: GgufTensorInfo) []const i64 { - return info.dims[0..info.rank]; - } -}; - -// Return the value type name given the type ID. -fn getValueTypeName(t: u32) []const u8 { - if (@as(usize, @intCast(t)) >= GGUF_VALUE_NAME.len) return "unknown"; - return GGUF_VALUE_NAME[@intCast(t)]; -} - -const GGUF_VALUE_NAME = [_][]const u8{ - "uint8", "int8", "uint16", "int16", "uint32", "int32", - "float32", "bool", "string", "array", "uint64", "int64", - "float64", -}; - -/// GGUF file API -/// A memory-mapped view of a .gguf file. -/// Format used by GGML models: https://github.com/ggerganov/ggml/ -pub const GgufFile = struct { - header: GgufHeader, // GUFF file header info. - size: usize, // Total file size. - file: zml.aio.MemoryMappedFile, - left_kv: usize, // Number of key-value pairs yet to read. - left_tensors: usize, // Number of tensors yet to read. - off: usize, // Offset of the next item to parse. - alignment: usize = 32, // File data alignment. Default: 32 bytes. - - /// Open and memmap the given file. - pub fn open(path: []const u8) !GgufFile { - const file = try asynk.File.open(path, .{}); - const header = try file.reader().readStruct(GgufHeader); - try header.validate(); - return .{ - .header = header, - .size = (try file.stat()).size, - .file = try zml.aio.MemoryMappedFile.init(file), - .off = @sizeOf(GgufHeader), - .left_kv = header.metadata_kv_count, - .left_tensors = header.tensor_count, - }; - } - - /// Unmap the file memory and close the file handle. - pub fn close(self: *GgufFile) void { - self.file.deinit(); - } - - /// Set the context to read the first key-value entry in the GGUF - /// file and then all the rest. Is used when creating a new context - /// and also when you want to restart scanning the key-value - /// items in the file. - fn rewind(ctx: *GgufFile) void { - ctx.off = @sizeOf(GgufHeader); - ctx.left_kv = ctx.header.metadata_kv_count; - ctx.left_tensors = ctx.header.tensor_count; - } - - pub fn seek(self: *GgufFile, pos: usize) void { - assert(pos < self.size); - self.off = pos; - } - - fn readInt(self: *GgufFile, comptime T: type) !T { - if (self.off + @sizeOf(T) >= self.size) return error.InvalidGguf; - const res = self.file.file.reader().readInt(T, .little); - self.off += @sizeOf(T); - return res; - } - - fn readTensorType(self: *GgufFile) !TensorType { - const raw = try self.readInt(u32); - if (raw > TensorType.MAX_KNOWN_ENUM) { - log.err("Unsupported GGUF tensor type: {d}", .{raw}); - return error.UnsupportedGgufType; - } - return @enumFromInt(raw); - } - - fn readValueType(self: *GgufFile) !GgufValueType { - const raw = try self.readInt(u32); - const t: GgufValueType = @enumFromInt(raw); - switch (t) { - .uint8, .int8, .uint16, .int16, .uint32, .int32, .float32, .bool, .string, .array, .uint64, .int64, .float64, .array_start, .array_end => {}, - else => { - log.err("Unsupported GGUF value type: {s}", .{@tagName(t)}); - return error.UnsupportedGgufType; - }, - } - return t; - } - - pub fn readAlloc(self: *GgufFile, allocator: std.mem.Allocator, len: usize) ![]u8 { - const data = try allocator.alloc(u8, len); - const read = try self.file.file.reader().readAll(data); - if (read != data.len) return error.InvalidGguf; - self.off += len; - return data; - } - - pub fn skipBytes(self: *GgufFile, len: usize) !void { - try self.file.file.seekBy(@intCast(len)); - self.off += len; - } - - /// Read the len then the actual bytes. - pub fn readString(self: *GgufFile, allocator: std.mem.Allocator) ![]u8 { - const len: usize = try self.readInt(u64); - return self.readAlloc(allocator, len); - } - - pub fn skipString(self: *GgufFile) !void { - const len: usize = try self.readInt(u64); - return self.skipBytes(len); - } - - fn readArrayHeader(self: *GgufFile, allocator: std.mem.Allocator) !GgufValue.Array { - const child = try self.readValueType(); - if (@intFromEnum(child) > @intFromEnum(ValueType.float64)) { - return error.UnsupportedGgufType; - } - const len: usize = try self.readInt(u64); - const data = switch (child) { - // Since strings have variable lenghts, we need to read them one by one - .string => str: { - var data = try allocator.alloc([]u8, len); - for (0..len) |i| data[i] = try self.readString(allocator); - break :str std.mem.sliceAsBytes(data); - }, - else => try self.readAlloc(allocator, len * child.sizeOf()), - }; - return .{ - .child = @enumFromInt(@intFromEnum(child)), - .len = len, - .data = data, - }; - } - - fn readTypedValue(self: *GgufFile, allocator: std.mem.Allocator, t: GgufValueType) !GgufValue { - return switch (t) { - .uint8 => .{ .uint8 = try self.readInt(u8) }, - .int8 => .{ .int8 = try self.readInt(i8) }, - .uint16 => .{ .uint16 = try self.readInt(u16) }, - .int16 => .{ .int16 = try self.readInt(i16) }, - .uint32 => .{ .uint32 = try self.readInt(u32) }, - .int32 => .{ .int32 = try self.readInt(i32) }, - .float32 => .{ .float32 = @bitCast(try self.readInt(u32)) }, - .bool => .{ .bool = try self.readInt(u8) != 0 }, - .string => .{ .string = try self.readString(allocator) }, - .array => .{ .array = try self.readArrayHeader(allocator) }, - .uint64 => .{ .uint64 = try self.readInt(u64) }, - .int64 => .{ .int64 = try self.readInt(i64) }, - .float64 => .{ .float64 = @bitCast(try self.readInt(u64)) }, - else => error.UnsupportedGgufType, - }; - } - - /// Parses the next metadata entry. - /// Returns error.EndOfMetadata if there are no longer metadata to process in this GGUF file. - pub fn readMetadata(self: *GgufFile, allocator: std.mem.Allocator) !GgufMetadataKv { - if (self.left_kv == 0) return error.EndOfMetadata; - self.left_kv -= 1; - const name = try self.readString(allocator); - const type_ = try self.readValueType(); - const val: GgufValue = try self.readTypedValue(allocator, type_); - return .{ .name = name, .type_ = type_, .val = val }; - } - - // Set the data section offset. This function must be called exactly when - // all the key-values are consumed, in the context of the first call of - // ctx.getTensor(): this way we will be able to return tensor offsets - // as absolute positions and pointers to the mmapped file. - fn setDataOffset(self: *GgufFile) !void { - const base_off = self.off; - - assert(self.left_kv == 0 and self.left_tensors == self.header.tensor_count); - - for (0..self.left_tensors) |_| try self.skipTensor(); - const padding: usize = getAlignmentPadding(self.alignment, self.off); - self.file.data_offset = self.off + padding; - - try self.file.file.seekTo(base_off); - self.off = base_off; - } - - pub fn skipTensor(self: *GgufFile) !void { - try self.skipString(); // Skip name - const num_dim: u32 = try self.readInt(u32); - // dimensions, type, and offset. - try self.skipBytes(8 * num_dim + 4 + 8); - } - - /// Parses the next tensor entry. - /// Returns error.EndOfMetadata if there are no longer tensor metadata to process in this GGUF file. - pub fn readTensorInfo(self: *GgufFile, allocator: std.mem.Allocator) !GgufTensorInfo { - if (self.left_tensors == 0 or self.left_kv != 0) { - return error.EndOfMetadata; - } - - // We want to return tensor data with offsets relative to the start - // of the file, so that the user of the API is able to access tensors - // as it iterates over them. To do so, we need to perform a full - // scan if this is the first tensor info we are reading. - // TODO: explicitly set the data offset in - if (self.file.data_offset == 0) try self.setDataOffset(); - self.left_tensors -= 1; - const name = try self.readString(allocator); - const num_dim = try self.readInt(u32); - assert(@as(usize, @intCast(num_dim)) <= GGUF_TENSOR_MAX_DIM); - // Read the dimentions; unused dimensions are left `undefined`. - // Note: we reverse the order of the dimensions to match zml convention. - var dims: [GGUF_TENSOR_MAX_DIM]i64 = undefined; - var num_weights: usize = 1; - for (0..num_dim) |j| { - const d = try self.readInt(u64); - dims[num_dim - 1 - j] = @intCast(d); - num_weights *= d; - } - const t: TensorType = try self.readTensorType(); - const start = try self.readInt(u64); - // To accurately calculate the bytes used by this tensor on the GGUF - // file, we need to take into account that quantization methods store - // tensors as block of N weights. So first of all we need to understand - // the number of padding weights (since the last block may have just - // fewer weights stored inside, but still requires to be stored to its full - // length). Then we can do the math to see how many blocks we need, and - // multiply by the block size to obtain the final total size. - const tf = t.getFeatures(); - const byte_len: usize = (std.math.divCeil(usize, num_weights, tf.items_per_block) catch unreachable) * tf.bytes_per_block; - return .{ - .name = name, - .t = t, - .rank = num_dim, - .dims = dims, - .start = start, - .byte_len = byte_len, - .num_weights = num_weights, - }; - } -}; - -/// Given an offset or a length, returns the padding needed to align it to alignment. -fn getAlignmentPadding(alignment: usize, offset: usize) usize { - return @rem((alignment - @rem(offset, alignment)), alignment); -} diff --git a/zml/aio/torch.zig b/zml/aio/torch.zig index ff0d116..79487f6 100644 --- a/zml/aio/torch.zig +++ b/zml/aio/torch.zig @@ -1,9 +1,9 @@ -const asynk = @import("async"); const std = @import("std"); -const zml = @import("../zml.zig"); +const asynk = @import("async"); + +const zml = @import("../zml.zig"); const eval = @import("torch/eval.zig"); -const py = @import("torch/py.zig"); const File = @import("torch/file.zig").File; const StringBuilder = std.ArrayListUnmanaged(u8); @@ -12,7 +12,7 @@ const log = std.log.scoped(.@"zml/aio"); test { std.testing.refAllDecls(@This()); std.testing.refAllDecls(eval); - std.testing.refAllDecls(py); + std.testing.refAllDecls(@import("torch/py.zig")); std.testing.refAllDecls(File); } @@ -30,13 +30,13 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore const tmp_alloc = arena.allocator(); const mmap_file = try zml.aio.MemoryMappedFile.init(file); - var torch_file = try File.init(tmp_alloc, mmap_file); + var torch_file = try asynk.callBlocking(File.init, .{ tmp_alloc, mmap_file }); const ops = try torch_file.parsePickle(tmp_alloc); const py_values = try eval.evaluate(tmp_alloc, ops, true); // file ownership is transferred to the BufferStore - var res = try zml.aio.BufferStore.init(allocator, &.{torch_file.buffer_file}); + var res = try zml.aio.BufferStore.init(allocator, &.{torch_file.mmap_file}); try torch_file.parseModel(py_values, &res); return res; } diff --git a/zml/aio/torch/eval.zig b/zml/aio/torch/eval.zig index c7085a1..42565ae 100644 --- a/zml/aio/torch/eval.zig +++ b/zml/aio/torch/eval.zig @@ -151,23 +151,23 @@ pub const PickleMemo = struct { }; pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) ![]const py.Any { - var stack = std.ArrayList(py.Any).init(arena); + var stack: std.ArrayList(py.Any) = .{}; var memo = PickleMemo.init(arena); for (x) |op| { switch (op) { - .mark => try stack.append(.{ .raw = op }), + .mark => try stack.append(arena, .{ .raw = op }), .frame => {}, .stop => break, .pop => _ = try pop(&stack), .pop_mark => _ = try popMark(&stack), .dup => if (stack.getLastOrNull()) |item| - try stack.append(try item.clone(arena)) + try stack.append(arena, try item.clone(arena)) else return error.CannotDupEmptyStack, - .persid => |v| try stack.append(.{ .pers_id = try py.PersId.init(arena, .{ .string = try arena.dupe(u8, v) }) }), - .binpersid => try stack.append(.{ .pers_id = try py.PersId.init(arena, try pop(&stack)) }), - .reduce => try stack.append(.{ .global = blk: { + .persid => |v| try stack.append(arena, .{ .pers_id = try py.PersId.init(arena, .{ .string = try arena.dupe(u8, v) }) }), + .binpersid => try stack.append(arena, .{ .pers_id = try py.PersId.init(arena, try pop(&stack)) }), + .reduce => try stack.append(arena, .{ .global = blk: { var args = try pop(&stack); args = try memo.resolve(arena, args, true); if (args != .seq) return error.InvalidInput; @@ -175,23 +175,23 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo func = try memo.resolve(arena, func, true); break :blk try py.Object.init(arena, func, args.seq.values, &.{}); } }), - .build => try stack.append(blk: { + .build => try stack.append(arena, blk: { const args = try memo.resolve(arena, try pop(&stack), true); const member = try memo.resolve(arena, try pop(&stack), true); break :blk .{ .set_state = try py.SetState.init(arena, member, args) }; }), - .empty_dict => try stack.append(.{ .seq = .{ .type = .dict, .values = &[_]py.Any{} } }), - .get => |v| try stack.append(.{ .ref = v }), - .empty_list => try stack.append(.{ .seq = .{ .type = .list, .values = &[_]py.Any{} } }), + .empty_dict => try stack.append(arena, .{ .seq = .{ .type = .dict, .values = &[_]py.Any{} } }), + .get => |v| try stack.append(arena, .{ .ref = v }), + .empty_list => try stack.append(arena, .{ .seq = .{ .type = .list, .values = &[_]py.Any{} } }), .put => |v| { try memo.insert(v, try pop(&stack)); - try stack.append(.{ .ref = v }); + try stack.append(arena, .{ .ref = v }); }, - .tuple => try stack.append(blk: { + .tuple => try stack.append(arena, blk: { const popped = try popMark(&stack); break :blk .{ .seq = .{ .type = .tuple, .values = try arena.dupe(py.Any, popped) } }; }), - .empty_tuple => try stack.append(.{ .seq = .{ .type = .tuple, .values = &[_]py.Any{} } }), + .empty_tuple => try stack.append(arena, .{ .seq = .{ .type = .tuple, .values = &[_]py.Any{} } }), .setitem => { const v = try memo.resolve(arena, try pop(&stack), true); const k = try memo.resolve(arena, try pop(&stack), true); @@ -228,17 +228,17 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo } }, .proto => |proto| stdx.debug.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}), - .tuple1 => try stack.append(blk: { + .tuple1 => try stack.append(arena, blk: { const tup_values = try arena.alloc(py.Any, 1); tup_values[0] = try pop(&stack); break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; }), - .tuple2 => try stack.append(blk: { + .tuple2 => try stack.append(arena, blk: { const tup_values = try arena.alloc(py.Any, 2); inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack); break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; }), - .tuple3 => try stack.append(blk: { + .tuple3 => try stack.append(arena, blk: { const tup_values = try arena.alloc(py.Any, 3); inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack); break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; @@ -276,32 +276,32 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo }, } }, - .dict => try stack.append(.{ .seq = .{ + .dict => try stack.append(arena, .{ .seq = .{ .type = .dict, .values = try arena.dupe(py.Any, try popMark(&stack)), } }), - .list => try stack.append(.{ .seq = .{ + .list => try stack.append(arena, .{ .seq = .{ .type = .list, .values = try arena.dupe(py.Any, try popMark(&stack)), } }), - .inst => |v| try stack.append(.{ .object = try py.Object.init( + .inst => |v| try stack.append(arena, .{ .object = try py.Object.init( arena, try py.tuple(&.{ .{ .string = v.module }, .{ .string = v.class } }).clone(arena), try arena.dupe(py.Any, try popMark(&stack)), &.{}, ) }), - .obj => try stack.append(blk: { + .obj => try stack.append(arena, blk: { const mark = try findMark(&stack); const args = try arena.dupe(py.Any, stack.items[mark + 2 ..]); const member = stack.items[mark + 1]; break :blk .{ .object = try py.Object.init(arena, member, args, &.{}) }; }), - .newobj => try stack.append(blk: { + .newobj => try stack.append(arena, blk: { const args = try arena.alloc(py.Any, 1); args[0] = try pop(&stack); break :blk .{ .object = try py.Object.init(arena, try pop(&stack), args, &.{}) }; }), - .empty_set => try stack.append(.{ .seq = .{ .type = .set, .values = &[_]py.Any{} } }), + .empty_set => try stack.append(arena, .{ .seq = .{ .type = .set, .values = &[_]py.Any{} } }), .additems => { const postmark = try popMark(&stack); const top = try lastMut(&stack); @@ -316,15 +316,15 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo }, } }, - .frozenset => try stack.append(.{ .seq = .{ + .frozenset => try stack.append(arena, .{ .seq = .{ .type = .frozen_set, .values = try arena.dupe(py.Any, try popMark(&stack)), } }), - .newobj_ex => try stack.append(blk: { + .newobj_ex => try stack.append(arena, blk: { const kwargs, const args, const cls = .{ try pop(&stack), try pop(&stack), try pop(&stack) }; break :blk .{ .object = try py.Object.init(arena, cls, args.seq.values, kwargs.seq.values) }; }), - .stack_global => try stack.append(blk: { + .stack_global => try stack.append(arena, blk: { const gn, const mn = .{ try memo.resolve(arena, try pop(&stack), true), try memo.resolve(arena, try pop(&stack), true), @@ -338,13 +338,13 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo }; try memo.insert(@intCast(memo.map.count()), try item.clone(arena)); }, - else => try stack.append(.{ .raw = try op.clone(arena) }), + else => try stack.append(arena, .{ .raw = try op.clone(arena) }), } } if (resolve_refs) { return try memo.resolveAllRefsIter(arena, 0, stack.items, true); } - return stack.toOwnedSlice(); + return stack.toOwnedSlice(arena); } fn append(allocator: std.mem.Allocator, current: *[]py.Any, values: []const py.Any) !void { @@ -358,8 +358,9 @@ test evaluate { defer arena.deinit(); const allocator = arena.allocator(); const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only }); - var buffered_reader = std.io.bufferedReader(file.reader()); - const ops = try pickle.parse(allocator, buffered_reader.reader(), 4096); + var reader_buffer: [1024]u8 = undefined; + var reader = file.reader(&reader_buffer); + const ops = try pickle.parse(allocator, &reader.interface); const vals = try evaluate(allocator, ops, true); defer allocator.free(vals); diff --git a/zml/aio/torch/file.zig b/zml/aio/torch/file.zig index b3794bb..0307d51 100644 --- a/zml/aio/torch/file.zig +++ b/zml/aio/torch/file.zig @@ -1,14 +1,15 @@ -const asynk = @import("async"); const std = @import("std"); +const testing = std.testing; + +const asynk = @import("async"); const stdx = @import("stdx"); const zml = @import("../../zml.zig"); +const HostBuffer = zml.HostBuffer; +const eval = @import("eval.zig"); const pickle = @import("pickle.zig"); const py = @import("py.zig"); -const eval = @import("eval.zig"); -const HostBuffer = zml.HostBuffer; -const testing = std.testing; const log = std.log.scoped(.@"zml/aio"); // TODO(cryptodeal): use zml.aio.PrefixBuilder instead @@ -20,167 +21,82 @@ test { } pub const File = struct { - buffer_file: zml.aio.MemoryMappedFile, + mmap_file: zml.aio.MemoryMappedFile, /// Map names to sub file - file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{}, - tar_file: ?TarStream = null, - is_zip_file: bool, - zip_prefix: []const u8 = &.{}, - pickle_subfile: struct { start: u64 = 0, len: usize }, - - pub const FileEntry = struct { - version_needed_to_extract: u16, - flags: u16, - compression_method: std.zip.CompressionMethod, - last_modification_time: u16, - last_modification_date: u16, - header_zip_offset: u64, - crc32: u32, - filename_len: u32, - compressed_size: u64, - uncompressed_size: u64, - file_offset: u64, - - pub fn init(entry: anytype) FileEntry { - return .{ - .version_needed_to_extract = entry.version_needed_to_extract, - .flags = @as(u16, @bitCast(entry.flags)), - .compression_method = entry.compression_method, - .last_modification_time = entry.last_modification_time, - .last_modification_date = entry.last_modification_date, - .header_zip_offset = entry.header_zip_offset, - .crc32 = entry.crc32, - .filename_len = entry.filename_len, - .compressed_size = entry.compressed_size, - .uncompressed_size = entry.uncompressed_size, - .file_offset = entry.file_offset, - }; - } - }; + file_map: std.StringArrayHashMapUnmanaged([]const u8) = .{}, + zip_prefix: []const u8, + pickle_subfile: []const u8, const magic = "PK\x03\x04"; - pub fn fromTarFile(allocator: std.mem.Allocator, mapped: zml.aio.MemoryMappedFile, file: std.tar.Iterator(asynk.File.Reader).File) !File { - const tar_file = try TarStream.init(file); - const file_magic = try tar_file.reader().readBytesNoEof(magic.len); - try tar_file.seekTo(0); - var res: File = .{ - .buffer_file = mapped, - .tar_file = tar_file, - .is_zip_file = std.mem.eql(u8, &file_magic, magic), - .pickle_subfile = .{ .len = try tar_file.getEndPos() }, - }; - if (res.is_zip_file) { - try res.parseZipHeaders(allocator, tar_file.seekableStream()); - } - return res; - } - pub fn init(allocator: std.mem.Allocator, mmap_file: zml.aio.MemoryMappedFile) !File { - const file_magic = try mmap_file.file.reader().readBytesNoEof(magic.len); - try mmap_file.file.seekTo(0); - var res: File = .{ - .buffer_file = mmap_file, - .is_zip_file = std.mem.eql(u8, &file_magic, magic), - .pickle_subfile = .{ .len = mmap_file.data.len }, - }; + var pkl: []const u8 = mmap_file.data; + var zip_prefix: []const u8 = &.{}; + var file_map: std.StringArrayHashMapUnmanaged([]const u8) = .{}; + if (std.mem.eql(u8, mmap_file.data[0..magic.len], magic)) { + // We are dealing with a zip file. + // Let's look for the `data.pkl` file and keep a map of all other files. + // The other files will be the tensor storage and will be reference from `data.pkl`. + var header_parsing_buffer: [4096]u8 = undefined; - if (res.is_zip_file) { - try res.parseZipHeaders(allocator, mmap_file.file.seekableStream()); + // std.zip requires on a std.fs.File and don't leverage std.Io.Reader directly. + // So we use the synchronous API to parse the headers, + // then we rely only on the memory map data to parse the pickle and load the buffers. + // To mitigate this we use `async.launchBlocking` in `torch.open`. + const raw_file: std.fs.File = .{ .handle = mmap_file.file._handle }; + var reader = raw_file.reader(&header_parsing_buffer); + var it: std.zip.Iterator = try .init(&reader); + + while (try it.next()) |header| { + if (header.filename_len == 0) { + continue; + } + if (header.compression_method != .store) { + return error.Unsupported; + } + + const filename = mmap_file.data[header.header_zip_offset + @sizeOf(std.zip.CentralDirectoryFileHeader) ..][0..header.filename_len]; + + var local_reader: std.Io.Reader = .fixed(mmap_file.data); + local_reader.discardAll(header.file_offset) catch return error.InvalidZipFile; + const local_header = local_reader.takeStruct(std.zip.LocalFileHeader, .little) catch return error.InvalidZipFile; + local_reader.discardAll(local_header.filename_len) catch return error.InvalidZipFile; + local_reader.discardAll(local_header.extra_len) catch return error.InvalidZipFile; + + // normalize path separators + const file_content = mmap_file.data[local_reader.seek..][0..header.compressed_size]; + const my_filename: []u8 = try allocator.dupe(u8, filename); + std.mem.replaceScalar(u8, my_filename, '\\', '/'); + + try file_map.put(allocator, my_filename, file_content); + + if (std.mem.endsWith(u8, filename, "data.pkl")) { + pkl = file_content; + zip_prefix = filename[0 .. filename.len - "data.pkl".len]; + } + } + + if (pkl.len == 0) { + log.err("Could not find file ending in `data.pkl` in archive", .{}); + return error.PickleNotFound; + } } - return res; + + return .{ + .mmap_file = mmap_file, + .file_map = file_map, + .pickle_subfile = pkl, + .zip_prefix = zip_prefix, + }; } pub fn close(self: *File) void { - self.buffer_file.deinit(); + self.mmap_file.deinit(); } pub fn parsePickle(self: *File, allocator: std.mem.Allocator) ![]const pickle.Op { - return if (self.tar_file) |tar_file| { - try tar_file.seekTo(self.pickle_subfile.start); - var buffered = std.io.bufferedReader(tar_file.reader()); - return try pickle.parse(allocator, buffered.reader(), self.pickle_subfile.len); - } else { - const file = self.buffer_file.file; - try file.seekTo(self.pickle_subfile.start); - var buffered = std.io.bufferedReader(file.reader()); - return try pickle.parse(allocator, buffered.reader(), self.pickle_subfile.len); - }; - } - - fn parseZipHeaders(self: *File, allocator: std.mem.Allocator, seekable_stream: anytype) !void { - var file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{}; - - var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream); - var filename_buf: [std.fs.max_path_bytes]u8 = undefined; - while (try iter.next()) |entry| { - const filename = filename_buf[0..entry.filename_len]; - try seekable_stream.seekTo(entry.header_zip_offset + @sizeOf(std.zip.CentralDirectoryFileHeader)); - const len = try seekable_stream.context.reader().readAll(filename); - if (len != filename.len) return error.ZipBadFileOffset; - if (isBadFilename(filename)) return error.ZipBadFilename; - std.mem.replaceScalar(u8, filename, '\\', '/'); // normalize path separators - try file_map.put(allocator, try allocator.dupe(u8, filename), FileEntry.init(entry)); - } - - self.file_map = file_map; - var file_iter = file_map.iterator(); - while (file_iter.next()) |e| { - const entry = e.value_ptr.*; - const filename = e.key_ptr.*; - if (!std.mem.endsWith(u8, filename, "data.pkl")) continue; - - self.zip_prefix = filename[0 .. filename.len - "data.pkl".len]; - - const local_data_header_offset: u64 = local_data_header_offset: { - switch (entry.compression_method) { - .store => {}, - .deflate => { - // TODO(cryptodeal): handle decompress - @panic("TODO support use of `deflate`"); - }, - else => @panic("TODO support other modes of compression"), - } - const local_header = blk: { - try seekable_stream.seekTo(entry.file_offset); - break :blk try seekable_stream.context.reader().readStructEndian(std.zip.LocalFileHeader, .little); - }; - if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig)) - return error.ZipBadFileOffset; - if (local_header.version_needed_to_extract != entry.version_needed_to_extract) - return error.ZipMismatchVersionNeeded; - if (local_header.last_modification_time != entry.last_modification_time) - return error.ZipMismatchModTime; - if (local_header.last_modification_date != entry.last_modification_date) - return error.ZipMismatchModDate; - - if (@as(u16, @bitCast(local_header.flags)) != entry.flags) - return error.ZipMismatchFlags; - if (local_header.crc32 != 0 and local_header.crc32 != entry.crc32) - return error.ZipMismatchCrc32; - if (local_header.compressed_size != 0 and - local_header.compressed_size != entry.compressed_size) - return error.ZipMismatchCompLen; - if (local_header.uncompressed_size != 0 and - local_header.uncompressed_size != entry.uncompressed_size) - return error.ZipMismatchUncompLen; - if (local_header.filename_len != entry.filename_len) - return error.ZipMismatchFilenameLen; - - break :local_data_header_offset @as(u64, local_header.filename_len) + - @as(u64, local_header.extra_len); - }; - - const local_data_file_offset: u64 = - @as(u64, entry.file_offset) + - @as(u64, @sizeOf(std.zip.LocalFileHeader)) + - local_data_header_offset; - self.pickle_subfile = .{ .start = local_data_file_offset, .len = entry.uncompressed_size }; - return; - } - - log.err("Could not find file ending in `data.pkl` in archive", .{}); - return error.PickleNotFound; + var reader: std.Io.Reader = .fixed(self.pickle_subfile); + return try pickle.parse(allocator, &reader); } fn basicTypeCheck(object: *const py.Object, module: []const u8, class: []const u8) bool { @@ -286,7 +202,7 @@ pub const File = struct { if (prefix.items.len > 0) { new_prefix.appendAssumeCapacity('.'); } - new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); + new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); try self.parseValue(allocator, store, new_prefix, val); } } @@ -303,7 +219,7 @@ pub const File = struct { if (prefix.items.len > 0) { new_prefix.appendAssumeCapacity('.'); } - new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); + new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); const new_tag = switch (tag) { .int64 => "int", .float64 => "float", @@ -321,7 +237,7 @@ pub const File = struct { if (prefix.items.len > 0) { new_prefix.appendAssumeCapacity('.'); } - new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); + new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); } try self.parseValue(allocator, store, new_prefix, item); } @@ -353,7 +269,7 @@ pub const File = struct { if (prefix.items.len > 0) { new_prefix.appendAssumeCapacity('.'); } - new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{}); + new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), int, 10, .lower, .{}); try self.parseValue(allocator, store, new_prefix, val); }, inline else => |_, tag| { @@ -504,34 +420,10 @@ pub const File = struct { /// Given the name of one of the files in the .pt tarball, /// return the slice of the memory-mapped .pt corresponding to it. fn getStorage(self: File, filename: []const u8) ![]const u8 { - const maybe_entry = self.file_map.get(filename); - if (maybe_entry == null) { + return self.file_map.get(filename) orelse { std.log.err("Could not find file ending in `{s}` in archive", .{filename}); return error.TensorNotFound; - } - const entry = maybe_entry.?; - const base_offset: u64 = if (self.tar_file) |t| t.start else 0; - const file_offset: u64 = base_offset + entry.file_offset; - const file = self.buffer_file.file; - try file.seekTo(entry.file_offset); - const local_header = try file.reader().readStructEndian(std.zip.LocalFileHeader, .little); - - if (!std.mem.eql(u8, &local_header.signature, &std.zip.local_file_header_sig)) - return error.ZipBadFileOffset; - if (local_header.compressed_size != 0 and - local_header.compressed_size != entry.compressed_size) - return error.ZipMismatchCompLen; - if (local_header.uncompressed_size != 0 and - local_header.uncompressed_size != entry.uncompressed_size) - return error.ZipMismatchUncompLen; - if (local_header.filename_len != entry.filename_len) - return error.ZipMismatchFilenameLen; - - const start = file_offset + - @sizeOf(std.zip.LocalFileHeader) + - @as(u64, local_header.filename_len) + - @as(u64, local_header.extra_len); - return self.buffer_file.mappedSlice(start, entry.uncompressed_size); + }; } fn parseDims(values: []py.Any) error{InvalidInput}!zml.Shape.DimsArray { @@ -578,52 +470,6 @@ fn storageToDtype(storage_type: []const u8) !zml.DataType { }; } -const TarStream = struct { - pub const SeekableStream = std.io.SeekableStream( - TarStream, - asynk.File.SeekError, - asynk.File.GetSeekPosError, - TarStream.seekTo, - TarStream.seekBy, - TarStream.getPos, - TarStream.getEndPos, - ); - - file: std.tar.Iterator(asynk.File.Reader).File, - start: usize, - - pub fn init(file: std.tar.Iterator(asynk.File.Reader).File) !TarStream { - return .{ - .file = file, - .start = try file.parent_reader.context.getPos(), - }; - } - - pub fn reader(file: TarStream) std.tar.Iterator(asynk.File.Reader).File.Reader { - return file.file.reader(); - } - - pub fn seekTo(self: TarStream, offset: u64) !void { - return self.file.parent_reader.context.seekTo(self.start + offset); - } - - pub fn seekBy(self: TarStream, offset: i64) !void { - return self.file.parent_reader.context.seekBy(offset); - } - - pub fn getPos(self: TarStream) !u64 { - return try self.file.parent_reader.context.getPos() - self.start; - } - - pub fn getEndPos(self: TarStream) !u64 { - return self.file.size; - } - - pub fn seekableStream(self: TarStream) TarStream.SeekableStream { - return .{ .context = self }; - } -}; - test "Read pickle (zipped)" { // test file created with following python snippet: // @@ -638,20 +484,19 @@ test "Read pickle (zipped)" { defer store.deinit(); { - var tmp_arena = std.heap.ArenaAllocator.init(testing.allocator); - defer tmp_arena.deinit(); - const tmp_alloc = tmp_arena.allocator(); - var torch_file = try File.init(tmp_alloc, mmap_file); + var arena = std.heap.ArenaAllocator.init(testing.allocator); + defer arena.deinit(); + var torch_file = try File.init(arena.allocator(), mmap_file); // We don't close the file directly, it will be closed by the store. - const ops = try torch_file.parsePickle(tmp_alloc); + const ops = try torch_file.parsePickle(arena.allocator()); try std.testing.expectEqual(302, ops.len); - const py_values = try eval.evaluate(tmp_alloc, ops, true); + const py_values = try eval.evaluate(arena.allocator(), ops, true); try torch_file.parseModel(py_values, &store); } - // now we have freed the tmp_arena. + // now we have freed the arena. // all data needed should have been copied into the store arena. try zml.testing.expectEqualShapes( zml.Shape.init(.{ 1, 4 }, .u8), diff --git a/zml/aio/torch/pickle.zig b/zml/aio/torch/pickle.zig index d9d2bcd..9174eb1 100644 --- a/zml/aio/torch/pickle.zig +++ b/zml/aio/torch/pickle.zig @@ -763,54 +763,54 @@ pub const Op = union(enum) { }; /// Read a stream of bytes, and interpret it as a stream of Pickle operators. -pub fn parse(allocator: std.mem.Allocator, reader: anytype, max_line_len: usize) ![]const Op { - var results = std.ArrayList(Op).init(allocator); - errdefer results.deinit(); - const len = max_line_len; - var _buf: std.BoundedArray(u8, 12) = .{}; +/// The given allocator needs to be an arena cause we are not aligning allocations to avoid copies. +pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op { + // It's not very efficient to interleave the results with the data copied from the stream, + // because growth event in the results ArrayList will lead to fragmentation. + // Trying to mitigate that by using a generous default size. + var results: std.ArrayListUnmanaged(Op) = try .initCapacity(arena, 512); + errdefer results.deinit(arena); + var alloc_writer = try std.Io.Writer.Allocating.initCapacity(arena, 512); while (true) { - const b = try reader.readByte(); - const code: OpCode = @enumFromInt(b); + const code: OpCode = @enumFromInt(try reader.takeByte()); const op: Op = switch (code) { - .int => blk: { - _buf.len = 0; - try reader.streamUntilDelimiter(_buf.writer(), '\n', _buf.capacity() + 1); - const buf = _buf.constSlice(); + .int => int: { + const bytes = try reader.takeDelimiterExclusive('\n'); // Legacy hack, see OpCode.int documentation // We do this parsing right away to simplify downstream code. - break :blk if (std.mem.eql(u8, "00", buf)) + break :int if (bytes.len == 2 and bytes[0] == '0' and bytes[1] == '0') .{ .bool = false } - else if (std.mem.eql(u8, "01", buf)) + else if (bytes.len == 2 and bytes[0] == '0' and bytes[1] == '1') .{ .bool = true } else - .{ .int = try std.fmt.parseInt(i32, buf, 10) }; + .{ .int = try std.fmt.parseInt(i32, bytes, 10) }; }, - .binint => .{ .int = try reader.readInt(i32, .little) }, - .binint1 => .{ .int = try reader.readByte() }, - .binint2 => .{ .int = try reader.readInt(u16, .little) }, + .binint => .{ .int = try reader.takeInt(i32, .little) }, + .binint1 => .{ .int = try reader.takeByte() }, + .binint2 => .{ .int = try reader.takeInt(u16, .little) }, // TODO: long should handle the trailing 'L' -> add a test. - .long => .{ .long = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, - .long1 => .{ .binlong = try _readSlice(reader, allocator, 1) }, - .long4 => .{ .binlong = try _readSlice(reader, allocator, 4) }, - .string => .{ .string = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, - .binstring => .{ .string = try _readSlice(reader, allocator, 4) }, - .short_binstring => .{ .string = try _readSlice(reader, allocator, 1) }, - .binbytes => .{ .bytes = try _readSlice(reader, allocator, 4) }, - .binbytes8 => .{ .bytes = try _readSlice(reader, allocator, 8) }, - .short_binbytes => .{ .bytes = try _readSlice(reader, allocator, 1) }, - .bytearray8 => .{ .bytearray = try _readSlice(reader, allocator, 8) }, + .long => .{ .long = try readLine(reader, &alloc_writer) }, + .long1 => .{ .binlong = try _readSlice(reader, arena, 1) }, + .long4 => .{ .binlong = try _readSlice(reader, arena, 4) }, + .string => .{ .string = try readLine(reader, &alloc_writer) }, + .binstring => .{ .string = try _readSlice(reader, arena, 4) }, + .short_binstring => .{ .string = try _readSlice(reader, arena, 1) }, + .binbytes => .{ .bytes = try _readSlice(reader, arena, 4) }, + .binbytes8 => .{ .bytes = try _readSlice(reader, arena, 8) }, + .short_binbytes => .{ .bytes = try _readSlice(reader, arena, 1) }, + .bytearray8 => .{ .bytearray = try _readSlice(reader, arena, 8) }, .next_buffer => .next_buffer, .readonly_buffer => .readonly_buffer, .none => .none, .newtrue => .{ .bool = true }, .newfalse => .{ .bool = false }, - .unicode => .{ .unicode = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, - .short_binunicode => .{ .unicode = try _readSlice(reader, allocator, 1) }, - .binunicode => .{ .unicode = try _readSlice(reader, allocator, 4) }, - .binunicode8 => .{ .unicode = try _readSlice(reader, allocator, 8) }, - .float => .{ .float = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, - .binfloat => .{ .binfloat = @bitCast(try reader.readInt(u64, .big)) }, + .unicode => .{ .unicode = try readLine(reader, &alloc_writer) }, + .short_binunicode => .{ .unicode = try _readSlice(reader, arena, 1) }, + .binunicode => .{ .unicode = try _readSlice(reader, arena, 4) }, + .binunicode8 => .{ .unicode = try _readSlice(reader, arena, 8) }, + .float => .{ .float = try readLine(reader, &alloc_writer) }, + .binfloat => .{ .binfloat = @bitCast(try reader.takeInt(u64, .big)) }, .empty_list => .empty_list, .append => .append, .appends => .appends, @@ -832,74 +832,74 @@ pub fn parse(allocator: std.mem.Allocator, reader: anytype, max_line_len: usize) .mark => .mark, .pop_mark => .pop_mark, // If we fail to parse delay the error to the evaluation. - .get => .{ - .get = _readDigits(u32, reader, &_buf) catch std.math.maxInt(u32), + .get => get: { + const digits = try reader.takeDelimiterExclusive('\n'); + break :get .{ .get = std.fmt.parseInt(u32, digits, 10) catch std.math.maxInt(u32) }; }, - .binget => .{ .get = try reader.readByte() }, - .long_binget => .{ .get = try reader.readInt(u32, .little) }, - .put => blk: { - const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len); - defer allocator.free(buf); - const n = std.fmt.parseInt(u32, buf, 10) catch std.math.maxInt(u32); - break :blk .{ .put = n }; + .binget => .{ .get = try reader.takeByte() }, + .long_binget => .{ .get = try reader.takeInt(u32, .little) }, + .put => put: { + const digits = try reader.takeDelimiterExclusive('\n'); + break :put .{ .put = std.fmt.parseInt(u32, digits, 10) catch std.math.maxInt(u32) }; }, - .binput => .{ .put = try reader.readByte() }, - .long_binput => .{ .put = try reader.readInt(u32, .little) }, + .binput => .{ .put = try reader.takeByte() }, + .long_binput => .{ .put = try reader.takeInt(u32, .little) }, .memoize => .memoize, - .ext1 => .{ .ext1 = try reader.readByte() }, - .ext2 => .{ .ext2 = try reader.readInt(i16, .little) }, - .ext4 => .{ .ext4 = try reader.readInt(i32, .little) }, + .ext1 => .{ .ext1 = try reader.takeByte() }, + .ext2 => .{ .ext2 = try reader.takeInt(i16, .little) }, + .ext4 => .{ .ext4 = try reader.takeInt(i32, .little) }, .global => .{ .global = .{ - .module = try reader.readUntilDelimiterAlloc(allocator, '\n', len), - .class = try reader.readUntilDelimiterAlloc(allocator, '\n', len), + .module = try readLine(reader, &alloc_writer), + .class = try readLine(reader, &alloc_writer), } }, .stack_global => .stack_global, .reduce => .reduce, .build => .build, .inst => .{ .inst = .{ - .module = try reader.readUntilDelimiterAlloc(allocator, '\n', len), - .class = try reader.readUntilDelimiterAlloc(allocator, '\n', len), + .module = try readLine(reader, &alloc_writer), + .class = try readLine(reader, &alloc_writer), } }, .obj => .obj, .newobj => .newobj, .newobj_ex => .newobj_ex, .proto => blk: { - const version = try reader.readByte(); + const version = try reader.takeByte(); if (version > 5) log.warn("zml.aio.torch.pickle.parse expects a Python pickle object of version <=5, got version {}. Will try to interpret anyway, but this may lead to more errors.", .{version}); break :blk .{ .proto = version }; }, .stop => .stop, - // This is not documented in pickletools but in https://peps.python.org/pep-3154/ - // The frame size is stored right after the frame header. - // The loader is allowed to prefetch framesize from the underlying reader, - // and ops are not allowed to cross a frame boundary. - // We don't prefetch because we assume the reader is going to use some kind of buffered reader. - // We could try to enforce frame boundaries, but we would need to track - // how many bytes we are reading from the stream. - .frame => .{ .frame = try reader.readInt(u64, .little) }, - .persid => .{ .persid = try reader.readUntilDelimiterAlloc(allocator, '\n', len) }, + .frame => frame: { + // This is not documented in pickletools but in https://peps.python.org/pep-3154/ + // The loader is allowed to prefetch framesize from the underlying reader, + // and ops are not allowed to cross a frame boundary. + const frame_size = try reader.takeInt(u64, .little); + reader.fill(@min(frame_size, reader.buffer.len)) catch |err| switch (err) { + error.EndOfStream => {}, + else => return err, + }; + break :frame .{ .frame = frame_size }; + }, + .persid => .{ .persid = try readLine(reader, &alloc_writer) }, .binpersid => .binpersid, _ => |unk_tag| { log.err("Unknow pickle operator {}, note we are only supporting pickle protocol up to version 5.", .{unk_tag}); return error.NotSupported; }, }; - try results.append(op); + try results.append(arena, op); if (op == .stop) break; } - return results.toOwnedSlice(); + return results.items; } test "parse protocol 4" { - const allocator = std.testing.allocator; + var arena: std.heap.ArenaAllocator = .init(std.testing.allocator); + defer arena.deinit(); + const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only }); - var buffered_reader = std.io.bufferedReader(file.reader()); - const ops = try parse(allocator, buffered_reader.reader(), 4096); - defer { - // Test we are correctly freeing every allocation. - for (ops) |op| op.deinit(allocator); - allocator.free(ops); - } + var read_buffer: [1024]u8 = undefined; + var reader = file.reader(&read_buffer); + const ops = try parse(arena.allocator(), &reader.interface); // this can be obtained by running: `python -m pickletools simple_test_4.pickle` var expected = [_]Op{ @@ -948,7 +948,9 @@ test "parse protocol 4" { test "parse protocol 0" { // We also test protocol 0, cause it's more text oriented. - const allocator = std.testing.allocator; + var arena: std.heap.ArenaAllocator = .init(std.testing.allocator); + defer arena.deinit(); + const pickle_0 = \\(dp0 \\Vhello @@ -982,13 +984,8 @@ test "parse protocol 0" { \\s. ; - var stream = std.io.fixedBufferStream(pickle_0); - const ops = try parse(allocator, stream.reader(), 4096); - defer { - // Test we are correctly freeing every allocation. - for (ops) |op| op.deinit(allocator); - allocator.free(ops); - } + var reader: std.Io.Reader = .fixed(pickle_0); + const ops = try parse(arena.allocator(), &reader); var expected = [_]Op{ .mark, @@ -1043,18 +1040,11 @@ test "parse protocol 0" { try std.testing.expectEqualDeep(&expected, ops); } -fn _readDigits(comptime T: type, reader: anytype, buffer: *std.BoundedArray(u8, 12)) !T { - buffer.len = 0; - try reader.streamUntilDelimiter(buffer.writer(), '\n', 13); - return std.fmt.parseInt(T, buffer.constSlice(), 10); -} - fn _readSlice(reader: anytype, allocator: std.mem.Allocator, comptime len_bytes: u8) ![]u8 { const T = std.meta.Int(.unsigned, 8 * len_bytes); - const str_len: u64 = try reader.readInt(T, .little); + const str_len: u64 = try reader.takeInt(T, .little); const buf = try allocator.alloc(u8, str_len); - errdefer allocator.free(buf); - _ = try reader.read(buf); + _ = try reader.readSliceAll(buf); return buf; } @@ -1063,3 +1053,14 @@ fn writeIntBuff(comptime T: type, value: T) [@divExact(@typeInfo(T).int.bits, 8) std.mem.writeInt(T, &res, value, .little); return res; } + +fn readLine(reader: *std.Io.Reader, alloc_writer: *std.Io.Writer.Allocating) ![]const u8 { + const n = try reader.streamDelimiter(&alloc_writer.writer, '\n'); + std.debug.assert(try reader.takeByte() == '\n'); + const w = &alloc_writer.writer; + std.debug.assert(w.end == n); + const items = w.buffer[0..n]; + w.buffer = w.buffer[n + 1 ..]; + w.end = 0; + return items; +} diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig index 117cb1e..c2099b0 100644 --- a/zml/hostbuffer.zig +++ b/zml/hostbuffer.zig @@ -172,7 +172,7 @@ pub const HostBuffer = struct { // TODO we should allow interpreting the output as @Vector(8, f32) when the tensor is f32. stdx.debug.assert(DataType.fromZigType(T) == self.dtype(), "Can't reinterpret {f} as {s}", .{ self, @typeName(T) }); stdx.debug.assert(self.isContiguous(), "{f} isn't contiguous, can't interpret as []const u8", .{self}); - const ptr: [*]const T = @alignCast(@ptrCast(self._data)); + const ptr: [*]const T = @ptrCast(@alignCast(self._data)); return ptr[0..self._shape.count()]; } diff --git a/zml/module.zig b/zml/module.zig index 4e40882..e57f942 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -664,7 +664,7 @@ pub const CompilationContext = struct { // Create the result tensor object by combining the operand results, // as well as the registered shapes and donations. // Note: this assume res can be stack-allocated. - var res = @as(*const stdx.meta.FnResult(func), @alignCast(@ptrCast(function.res_tensors))).*; + var res = @as(*const stdx.meta.FnResult(func), @ptrCast(@alignCast(function.res_tensors))).*; const LocalContext = struct { index: usize = 0, op: mlir.Operation, function: MlirFn, donations: []Tensor._Donation }; var context: LocalContext = .{ .op = op, .function = function, .donations = donations }; meta.visit((struct { diff --git a/zml/platform.zig b/zml/platform.zig index 0bd8a2a..838cac4 100644 --- a/zml/platform.zig +++ b/zml/platform.zig @@ -113,7 +113,7 @@ const _CreateOptions = struct { /// "Best-Fit with Coalescing" algorithm bfc: Options, /// use cudaMallocAsync - @"async": Options, + async: Options, /// use raw cuMalloc platform, @@ -129,7 +129,7 @@ const _CreateOptions = struct { .platform => { values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform")); }, - .bfc, .@"async" => |opt| { + .bfc, .async => |opt| { values.appendAssumeCapacity(pjrt.NamedValue.from("allocator", self.allocator)); values.appendAssumeCapacity(pjrt.NamedValue.from("preallocate", opt.preallocate)); if (opt.memory_fraction > 0) { diff --git a/zml/tools/tracer.zig b/zml/tools/tracer.zig index 06bcd6f..d369b1e 100644 --- a/zml/tools/tracer.zig +++ b/zml/tools/tracer.zig @@ -1,4 +1,5 @@ const builtin = @import("builtin"); + const c = @import("c"); pub const Tracer = switch (builtin.os.tag) { @@ -11,15 +12,15 @@ const CudaTracer = struct { // Those symbols are defined in cudaProfiler.h but their implementation is in libcuda.so // They will be bound at call time after libcuda.so is loaded (as a needed dependency of libpjrt_cuda.so). - const cuProfilerStart = @extern(*const fn () callconv(.C) c_int, .{ .name = "cuProfilerStart", .linkage = .weak }) orelse unreachable; - const cuProfilerStop = @extern(*const fn () callconv(.C) c_int, .{ .name = "cuProfilerStop", .linkage = .weak }) orelse unreachable; + const cuProfilerStart = @extern(*const fn () callconv(.c) c_int, .{ .name = "cuProfilerStart", .linkage = .weak }) orelse unreachable; + const cuProfilerStop = @extern(*const fn () callconv(.c) c_int, .{ .name = "cuProfilerStop", .linkage = .weak }) orelse unreachable; // Those symbols are defined in nvToolsExt.h which we don't want to provide. // However, we link with libnvToolsExt.so which provides them. // They will be bound at call time after libnvToolsExt.so is loaded (manually dlopen'ed by us). - const nvtxMarkA = @extern(*const fn ([*:0]const u8) callconv(.C) void, .{ .name = "nvtxMarkA", .linkage = .weak }) orelse unreachable; - const nvtxRangeStartA = @extern(*const fn ([*:0]const u8) callconv(.C) c_int, .{ .name = "nvtxRangeStartA", .linkage = .weak }) orelse unreachable; - const nvtxRangeEnd = @extern(*const fn (c_int) callconv(.C) void, .{ .name = "nvtxRangeEnd", .linkage = .weak }) orelse unreachable; + const nvtxMarkA = @extern(*const fn ([*:0]const u8) callconv(.c) void, .{ .name = "nvtxMarkA", .linkage = .weak }) orelse unreachable; + const nvtxRangeStartA = @extern(*const fn ([*:0]const u8) callconv(.c) c_int, .{ .name = "nvtxRangeStartA", .linkage = .weak }) orelse unreachable; + const nvtxRangeEnd = @extern(*const fn (c_int) callconv(.c) void, .{ .name = "nvtxRangeEnd", .linkage = .weak }) orelse unreachable; pub fn init(name: [:0]const u8) CudaTracer { _ = name;