zml: cleanup some todos

This commit is contained in:
Tarry Singh 2025-10-06 15:29:57 +00:00
parent 77cd21d2b2
commit d056fd3511
5 changed files with 10 additions and 27 deletions

View File

@ -12,9 +12,6 @@ const py = @import("py.zig");
const log = std.log.scoped(.@"zml/aio");
// TODO(cryptodeal): use zml.aio.PrefixBuilder instead
const StringBuilder = std.ArrayList(u8);
test {
std.testing.refAllDecls(@This());
std.testing.refAllDecls(File);
@ -111,11 +108,11 @@ pub const File = struct {
var prefix_buf: [1024]u8 = undefined;
const allocator = store.arena.allocator();
for (values) |item| {
try self.parseValue(allocator, store, StringBuilder.initBuffer(&prefix_buf), item);
try self.parseValue(allocator, store, .initBuffer(&prefix_buf), item);
}
}
pub fn parseValue(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: py.Any) !void {
pub fn parseValue(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: std.ArrayList(u8), v: py.Any) !void {
// log.warn("Parsing {}", .{v});
switch (v) {
.app, .object, .global => |object| {
@ -303,7 +300,7 @@ pub const File = struct {
}
}
fn parseTorchGlobal(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: py.Any) !bool {
fn parseTorchGlobal(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: std.ArrayList(u8), v: py.Any) !bool {
return switch (v) {
.global => |object| {
if (try self.parseTensor(allocator, object)) |host_buffer| {

View File

@ -789,7 +789,6 @@ pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op {
.binint => .{ .int = try reader.takeInt(i32, .little) },
.binint1 => .{ .int = try reader.takeByte() },
.binint2 => .{ .int = try reader.takeInt(u16, .little) },
// TODO: long should handle the trailing 'L' -> add a test.
.long => .{ .long = try readLine(reader, &alloc_writer) },
.long1 => .{ .binlong = try _readSlice(reader, arena, 1) },
.long4 => .{ .binlong = try _readSlice(reader, arena, 4) },
@ -902,7 +901,7 @@ test "parse protocol 4" {
const ops = try parse(arena.allocator(), &reader.interface);
// this can be obtained by running: `python -m pickletools simple_test_4.pickle`
var expected = [_]Op{
const expected: []const Op = &.{
.{ .proto = 4 },
.{ .frame = 119 },
.empty_dict,
@ -943,7 +942,7 @@ test "parse protocol 4" {
.setitems,
.stop,
};
try std.testing.expectEqualDeep(&expected, ops);
try std.testing.expectEqualDeep(expected, ops);
}
test "parse protocol 0" {

View File

@ -442,8 +442,7 @@ pub const Float4E2M1 = packed struct(u4) {
pub const values = [_]f32{ 0.0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6 };
pub fn toF32(x: Float4E2M1) f32 {
// the baseline toF32 doesn't work correctly:
// 0b0001 and 0b1001 should map to ±0.5, but are mapped to ±epsilon
// faster implementation
return values[@as(u4, @bitCast(x))];
}
@ -457,9 +456,6 @@ pub const Float4E2M1 = packed struct(u4) {
}
test fromF32 {
// the baseline fromF32 doesn't work correctly:
// ±0.5 should map to 0b0001/0b1001 but are mapped to ±0.0 instead.
// TODO: it probably affects other types.
var from_f32_res: [16]Float4E2M1 = undefined;
for (&from_f32_res, 0..) |*r, i| {
r.* = .fromF32(Float4E2M1.values[i]);

View File

@ -202,11 +202,6 @@ pub const HostBuffer = struct {
return self._strides[0..self._shape.rank()];
}
// TODO: rename .data into ._data and make it a [*]u8
// pub fn data(self: HostBuffer) []const u8 {
// return self.data;
// }
pub inline fn rank(self: HostBuffer) u4 {
return self._shape.rank();
}

View File

@ -53,10 +53,7 @@ pub const Tensor = struct {
}
pub fn format(self: Tensor, writer: *std.Io.Writer) !void {
// TODO(0.15.0) handle format
// const bare_fmt = fmt.len == 1 and fmt[0] == '_';
const bare_fmt = false;
try writer.print(if (bare_fmt) "{f}" else "Tensor({f})", .{self._shape});
try writer.print("Tensor({f})", .{self._shape});
}
/// Returns the shape of a Tensor.
@ -4081,11 +4078,10 @@ test "Tensor.maxPool2d" {
);
}
/// Return a clone of a type with Tensors replaced by Buffer.
/// Non-Tensor metadata is stripped out of the resulting struct.
/// Recursively descends into the type.
pub fn Bufferized(comptime T: type) type {
// TODO: we should strip out the non-buffer fields.
// Currently it's confusing cause the Bufferized struct contains field that are never read.
// Also it will simplify the layout of the Bufferized struct.
// accelerating the calls to execute.
return meta.MapRestrict(Tensor, Buffer).map(T);
}