From d056fd35112aac7a3dc5776f448e7489cad45435 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Mon, 6 Oct 2025 15:29:57 +0000 Subject: [PATCH] zml: cleanup some todos --- zml/aio/torch/file.zig | 9 +++------ zml/aio/torch/pickle.zig | 5 ++--- zml/floats.zig | 6 +----- zml/hostbuffer.zig | 5 ----- zml/tensor.zig | 12 ++++-------- 5 files changed, 10 insertions(+), 27 deletions(-) diff --git a/zml/aio/torch/file.zig b/zml/aio/torch/file.zig index c1588b4..42da119 100644 --- a/zml/aio/torch/file.zig +++ b/zml/aio/torch/file.zig @@ -12,9 +12,6 @@ const py = @import("py.zig"); const log = std.log.scoped(.@"zml/aio"); -// TODO(cryptodeal): use zml.aio.PrefixBuilder instead -const StringBuilder = std.ArrayList(u8); - test { std.testing.refAllDecls(@This()); std.testing.refAllDecls(File); @@ -111,11 +108,11 @@ pub const File = struct { var prefix_buf: [1024]u8 = undefined; const allocator = store.arena.allocator(); for (values) |item| { - try self.parseValue(allocator, store, StringBuilder.initBuffer(&prefix_buf), item); + try self.parseValue(allocator, store, .initBuffer(&prefix_buf), item); } } - pub fn parseValue(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: py.Any) !void { + pub fn parseValue(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: std.ArrayList(u8), v: py.Any) !void { // log.warn("Parsing {}", .{v}); switch (v) { .app, .object, .global => |object| { @@ -303,7 +300,7 @@ pub const File = struct { } } - fn parseTorchGlobal(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: py.Any) !bool { + fn parseTorchGlobal(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: std.ArrayList(u8), v: py.Any) !bool { return switch (v) { .global => |object| { if (try self.parseTensor(allocator, object)) |host_buffer| { diff --git a/zml/aio/torch/pickle.zig b/zml/aio/torch/pickle.zig index 5724ab0..6ed3a1b 100644 --- a/zml/aio/torch/pickle.zig +++ b/zml/aio/torch/pickle.zig @@ -789,7 +789,6 @@ pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op { .binint => .{ .int = try reader.takeInt(i32, .little) }, .binint1 => .{ .int = try reader.takeByte() }, .binint2 => .{ .int = try reader.takeInt(u16, .little) }, - // TODO: long should handle the trailing 'L' -> add a test. .long => .{ .long = try readLine(reader, &alloc_writer) }, .long1 => .{ .binlong = try _readSlice(reader, arena, 1) }, .long4 => .{ .binlong = try _readSlice(reader, arena, 4) }, @@ -902,7 +901,7 @@ test "parse protocol 4" { const ops = try parse(arena.allocator(), &reader.interface); // this can be obtained by running: `python -m pickletools simple_test_4.pickle` - var expected = [_]Op{ + const expected: []const Op = &.{ .{ .proto = 4 }, .{ .frame = 119 }, .empty_dict, @@ -943,7 +942,7 @@ test "parse protocol 4" { .setitems, .stop, }; - try std.testing.expectEqualDeep(&expected, ops); + try std.testing.expectEqualDeep(expected, ops); } test "parse protocol 0" { diff --git a/zml/floats.zig b/zml/floats.zig index 7d995c4..5424048 100644 --- a/zml/floats.zig +++ b/zml/floats.zig @@ -442,8 +442,7 @@ pub const Float4E2M1 = packed struct(u4) { pub const values = [_]f32{ 0.0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6 }; pub fn toF32(x: Float4E2M1) f32 { - // the baseline toF32 doesn't work correctly: - // 0b0001 and 0b1001 shoud map to ±0.5, but are mapped to ±epsilon + // faster implementation return values[@as(u4, @bitCast(x))]; } @@ -457,9 +456,6 @@ pub const Float4E2M1 = packed struct(u4) { } test fromF32 { - // the baseline fromF32 doesn't work correctly: - // ±0.5 should map to 0b0001/0b1001 but are map to ±0.0 instead. - // TODO: it probably affects other types. var from_f32_res: [16]Float4E2M1 = undefined; for (&from_f32_res, 0..) |*r, i| { r.* = .fromF32(Float4E2M1.values[i]); diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig index 2052970..0469f1e 100644 --- a/zml/hostbuffer.zig +++ b/zml/hostbuffer.zig @@ -202,11 +202,6 @@ pub const HostBuffer = struct { return self._strides[0..self._shape.rank()]; } - // TODO: rename .data into ._data and make it a [*]u8 - // pub fn data(self: HostBuffer) []const u8 { - // return self.data; - // } - pub inline fn rank(self: HostBuffer) u4 { return self._shape.rank(); } diff --git a/zml/tensor.zig b/zml/tensor.zig index 58798a0..d5acc54 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -53,10 +53,7 @@ pub const Tensor = struct { } pub fn format(self: Tensor, writer: *std.Io.Writer) !void { - // TODO(0.15.0) handle format - // const bare_fmt = fmt.len == 1 and fmt[0] == '_'; - const bare_fmt = false; - try writer.print(if (bare_fmt) "{f}" else "Tensor({f})", .{self._shape}); + try writer.print("Tensor({f})", .{self._shape}); } /// Returns the shape of a Tensor. @@ -4081,11 +4078,10 @@ test "Tensor.maxPool2d" { ); } +/// Return a clone of a type with Tensors replaced by Buffer. +/// Non-Tensor metadata is stripped out of the resulting struct. +/// Recursively descends into the type. pub fn Bufferized(comptime T: type) type { - // TODO: we should strip out the non-buffer fields. - // Currently it's confusing cause the Bufferized struct contains field that are never read. - // Also it will simplify the layout of the Bufferized struct. - // accelerating the calls to execute. return meta.MapRestrict(Tensor, Buffer).map(T); }