zml: cleanup some todos
This commit is contained in:
parent
77cd21d2b2
commit
d056fd3511
@ -12,9 +12,6 @@ const py = @import("py.zig");
|
|||||||
|
|
||||||
const log = std.log.scoped(.@"zml/aio");
|
const log = std.log.scoped(.@"zml/aio");
|
||||||
|
|
||||||
// TODO(cryptodeal): use zml.aio.PrefixBuilder instead
|
|
||||||
const StringBuilder = std.ArrayList(u8);
|
|
||||||
|
|
||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
std.testing.refAllDecls(File);
|
std.testing.refAllDecls(File);
|
||||||
@ -111,11 +108,11 @@ pub const File = struct {
|
|||||||
var prefix_buf: [1024]u8 = undefined;
|
var prefix_buf: [1024]u8 = undefined;
|
||||||
const allocator = store.arena.allocator();
|
const allocator = store.arena.allocator();
|
||||||
for (values) |item| {
|
for (values) |item| {
|
||||||
try self.parseValue(allocator, store, StringBuilder.initBuffer(&prefix_buf), item);
|
try self.parseValue(allocator, store, .initBuffer(&prefix_buf), item);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn parseValue(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: py.Any) !void {
|
pub fn parseValue(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: std.ArrayList(u8), v: py.Any) !void {
|
||||||
// log.warn("Parsing {}", .{v});
|
// log.warn("Parsing {}", .{v});
|
||||||
switch (v) {
|
switch (v) {
|
||||||
.app, .object, .global => |object| {
|
.app, .object, .global => |object| {
|
||||||
@ -303,7 +300,7 @@ pub const File = struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parseTorchGlobal(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, v: py.Any) !bool {
|
fn parseTorchGlobal(self: File, allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: std.ArrayList(u8), v: py.Any) !bool {
|
||||||
return switch (v) {
|
return switch (v) {
|
||||||
.global => |object| {
|
.global => |object| {
|
||||||
if (try self.parseTensor(allocator, object)) |host_buffer| {
|
if (try self.parseTensor(allocator, object)) |host_buffer| {
|
||||||
|
|||||||
@ -789,7 +789,6 @@ pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op {
|
|||||||
.binint => .{ .int = try reader.takeInt(i32, .little) },
|
.binint => .{ .int = try reader.takeInt(i32, .little) },
|
||||||
.binint1 => .{ .int = try reader.takeByte() },
|
.binint1 => .{ .int = try reader.takeByte() },
|
||||||
.binint2 => .{ .int = try reader.takeInt(u16, .little) },
|
.binint2 => .{ .int = try reader.takeInt(u16, .little) },
|
||||||
// TODO: long should handle the trailing 'L' -> add a test.
|
|
||||||
.long => .{ .long = try readLine(reader, &alloc_writer) },
|
.long => .{ .long = try readLine(reader, &alloc_writer) },
|
||||||
.long1 => .{ .binlong = try _readSlice(reader, arena, 1) },
|
.long1 => .{ .binlong = try _readSlice(reader, arena, 1) },
|
||||||
.long4 => .{ .binlong = try _readSlice(reader, arena, 4) },
|
.long4 => .{ .binlong = try _readSlice(reader, arena, 4) },
|
||||||
@ -902,7 +901,7 @@ test "parse protocol 4" {
|
|||||||
const ops = try parse(arena.allocator(), &reader.interface);
|
const ops = try parse(arena.allocator(), &reader.interface);
|
||||||
|
|
||||||
// this can be obtained by running: `python -m pickletools simple_test_4.pickle`
|
// this can be obtained by running: `python -m pickletools simple_test_4.pickle`
|
||||||
var expected = [_]Op{
|
const expected: []const Op = &.{
|
||||||
.{ .proto = 4 },
|
.{ .proto = 4 },
|
||||||
.{ .frame = 119 },
|
.{ .frame = 119 },
|
||||||
.empty_dict,
|
.empty_dict,
|
||||||
@ -943,7 +942,7 @@ test "parse protocol 4" {
|
|||||||
.setitems,
|
.setitems,
|
||||||
.stop,
|
.stop,
|
||||||
};
|
};
|
||||||
try std.testing.expectEqualDeep(&expected, ops);
|
try std.testing.expectEqualDeep(expected, ops);
|
||||||
}
|
}
|
||||||
|
|
||||||
test "parse protocol 0" {
|
test "parse protocol 0" {
|
||||||
|
|||||||
@ -442,8 +442,7 @@ pub const Float4E2M1 = packed struct(u4) {
|
|||||||
pub const values = [_]f32{ 0.0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6 };
|
pub const values = [_]f32{ 0.0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6 };
|
||||||
|
|
||||||
pub fn toF32(x: Float4E2M1) f32 {
|
pub fn toF32(x: Float4E2M1) f32 {
|
||||||
// the baseline toF32 doesn't work correctly:
|
// faster implementation
|
||||||
// 0b0001 and 0b1001 shoud map to ±0.5, but are mapped to ±epsilon
|
|
||||||
return values[@as(u4, @bitCast(x))];
|
return values[@as(u4, @bitCast(x))];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -457,9 +456,6 @@ pub const Float4E2M1 = packed struct(u4) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test fromF32 {
|
test fromF32 {
|
||||||
// the baseline fromF32 doesn't work correctly:
|
|
||||||
// ±0.5 should map to 0b0001/0b1001 but are map to ±0.0 instead.
|
|
||||||
// TODO: it probably affects other types.
|
|
||||||
var from_f32_res: [16]Float4E2M1 = undefined;
|
var from_f32_res: [16]Float4E2M1 = undefined;
|
||||||
for (&from_f32_res, 0..) |*r, i| {
|
for (&from_f32_res, 0..) |*r, i| {
|
||||||
r.* = .fromF32(Float4E2M1.values[i]);
|
r.* = .fromF32(Float4E2M1.values[i]);
|
||||||
|
|||||||
@ -202,11 +202,6 @@ pub const HostBuffer = struct {
|
|||||||
return self._strides[0..self._shape.rank()];
|
return self._strides[0..self._shape.rank()];
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: rename .data into ._data and make it a [*]u8
|
|
||||||
// pub fn data(self: HostBuffer) []const u8 {
|
|
||||||
// return self.data;
|
|
||||||
// }
|
|
||||||
|
|
||||||
pub inline fn rank(self: HostBuffer) u4 {
|
pub inline fn rank(self: HostBuffer) u4 {
|
||||||
return self._shape.rank();
|
return self._shape.rank();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -53,10 +53,7 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(self: Tensor, writer: *std.Io.Writer) !void {
|
pub fn format(self: Tensor, writer: *std.Io.Writer) !void {
|
||||||
// TODO(0.15.0) handle format
|
try writer.print("Tensor({f})", .{self._shape});
|
||||||
// const bare_fmt = fmt.len == 1 and fmt[0] == '_';
|
|
||||||
const bare_fmt = false;
|
|
||||||
try writer.print(if (bare_fmt) "{f}" else "Tensor({f})", .{self._shape});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the shape of a Tensor.
|
/// Returns the shape of a Tensor.
|
||||||
@ -4081,11 +4078,10 @@ test "Tensor.maxPool2d" {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return a clone of a type with Tensors replaced by Buffer.
|
||||||
|
/// Non-Tensor metadata is stripped out of the resulting struct.
|
||||||
|
/// Recursively descends into the type.
|
||||||
pub fn Bufferized(comptime T: type) type {
|
pub fn Bufferized(comptime T: type) type {
|
||||||
// TODO: we should strip out the non-buffer fields.
|
|
||||||
// Currently it's confusing cause the Bufferized struct contains field that are never read.
|
|
||||||
// Also it will simplify the layout of the Bufferized struct.
|
|
||||||
// accelerating the calls to execute.
|
|
||||||
return meta.MapRestrict(Tensor, Buffer).map(T);
|
return meta.MapRestrict(Tensor, Buffer).map(T);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user