zml: cleanup some todos
This commit is contained in:
parent
77cd21d2b2
commit
d056fd3511
@ -12,9 +12,6 @@ const py = @import("py.zig");
|
||||
|
||||
const log = std.log.scoped(.@"zml/aio");
|
||||
|
||||
// TODO(cryptodeal): use zml.aio.PrefixBuilder instead
|
||||
const StringBuilder = std.ArrayList(u8);
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(@This());
|
||||
std.testing.refAllDecls(File);
|
||||
@ -111,11 +108,11 @@ pub const File = struct {
|
||||
var prefix_buf: [1024]u8 = undefined;
|
||||
const allocator = store.arena.allocator();
|
||||
for (values) |item| {
|
||||
try self.parseValue(allocator, store, StringBuilder.initBuffer(&prefix_buf), item);
|
||||
try self.parseValue(allocator, store, .initBuffer(&prefix_buf), item);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parseValue(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: py.Any) !void {
|
||||
pub fn parseValue(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: std.ArrayList(u8), v: py.Any) !void {
|
||||
// log.warn("Parsing {}", .{v});
|
||||
switch (v) {
|
||||
.app, .object, .global => |object| {
|
||||
@ -303,7 +300,7 @@ pub const File = struct {
|
||||
}
|
||||
}
|
||||
|
||||
fn parseTorchGlobal(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: py.Any) !bool {
|
||||
fn parseTorchGlobal(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: std.ArrayList(u8), v: py.Any) !bool {
|
||||
return switch (v) {
|
||||
.global => |object| {
|
||||
if (try self.parseTensor(allocator, object)) |host_buffer| {
|
||||
|
||||
@ -789,7 +789,6 @@ pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op {
|
||||
.binint => .{ .int = try reader.takeInt(i32, .little) },
|
||||
.binint1 => .{ .int = try reader.takeByte() },
|
||||
.binint2 => .{ .int = try reader.takeInt(u16, .little) },
|
||||
// TODO: long should handle the trailing 'L' -> add a test.
|
||||
.long => .{ .long = try readLine(reader, &alloc_writer) },
|
||||
.long1 => .{ .binlong = try _readSlice(reader, arena, 1) },
|
||||
.long4 => .{ .binlong = try _readSlice(reader, arena, 4) },
|
||||
@ -902,7 +901,7 @@ test "parse protocol 4" {
|
||||
const ops = try parse(arena.allocator(), &reader.interface);
|
||||
|
||||
// this can be obtained by running: `python -m pickletools simple_test_4.pickle`
|
||||
var expected = [_]Op{
|
||||
const expected: []const Op = &.{
|
||||
.{ .proto = 4 },
|
||||
.{ .frame = 119 },
|
||||
.empty_dict,
|
||||
@ -943,7 +942,7 @@ test "parse protocol 4" {
|
||||
.setitems,
|
||||
.stop,
|
||||
};
|
||||
try std.testing.expectEqualDeep(&expected, ops);
|
||||
try std.testing.expectEqualDeep(expected, ops);
|
||||
}
|
||||
|
||||
test "parse protocol 0" {
|
||||
|
||||
@ -442,8 +442,7 @@ pub const Float4E2M1 = packed struct(u4) {
|
||||
pub const values = [_]f32{ 0.0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6 };
|
||||
|
||||
pub fn toF32(x: Float4E2M1) f32 {
|
||||
// the baseline toF32 doesn't work correctly:
|
||||
// 0b0001 and 0b1001 shoud map to ±0.5, but are mapped to ±epsilon
|
||||
// faster implementation
|
||||
return values[@as(u4, @bitCast(x))];
|
||||
}
|
||||
|
||||
@ -457,9 +456,6 @@ pub const Float4E2M1 = packed struct(u4) {
|
||||
}
|
||||
|
||||
test fromF32 {
|
||||
// the baseline fromF32 doesn't work correctly:
|
||||
// ±0.5 should map to 0b0001/0b1001 but are map to ±0.0 instead.
|
||||
// TODO: it probably affects other types.
|
||||
var from_f32_res: [16]Float4E2M1 = undefined;
|
||||
for (&from_f32_res, 0..) |*r, i| {
|
||||
r.* = .fromF32(Float4E2M1.values[i]);
|
||||
|
||||
@ -202,11 +202,6 @@ pub const HostBuffer = struct {
|
||||
return self._strides[0..self._shape.rank()];
|
||||
}
|
||||
|
||||
// TODO: rename .data into ._data and make it a [*]u8
|
||||
// pub fn data(self: HostBuffer) []const u8 {
|
||||
// return self.data;
|
||||
// }
|
||||
|
||||
pub inline fn rank(self: HostBuffer) u4 {
|
||||
return self._shape.rank();
|
||||
}
|
||||
|
||||
@ -53,10 +53,7 @@ pub const Tensor = struct {
|
||||
}
|
||||
|
||||
pub fn format(self: Tensor, writer: *std.Io.Writer) !void {
|
||||
// TODO(0.15.0) handle format
|
||||
// const bare_fmt = fmt.len == 1 and fmt[0] == '_';
|
||||
const bare_fmt = false;
|
||||
try writer.print(if (bare_fmt) "{f}" else "Tensor({f})", .{self._shape});
|
||||
try writer.print("Tensor({f})", .{self._shape});
|
||||
}
|
||||
|
||||
/// Returns the shape of a Tensor.
|
||||
@ -4081,11 +4078,10 @@ test "Tensor.maxPool2d" {
|
||||
);
|
||||
}
|
||||
|
||||
/// Return a clone of a type with Tensors replaced by Buffer.
|
||||
/// Non-Tensor metadata is stripped out of the resulting struct.
|
||||
/// Recursively descends into the type.
|
||||
pub fn Bufferized(comptime T: type) type {
|
||||
// TODO: we should strip out the non-buffer fields.
|
||||
// Currently it's confusing cause the Bufferized struct contains field that are never read.
|
||||
// Also it will simplify the layout of the Bufferized struct.
|
||||
// accelerating the calls to execute.
|
||||
return meta.MapRestrict(Tensor, Buffer).map(T);
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user