diff --git a/zml/testing.zig b/zml/testing.zig index 5a3a3bc..2afecee 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -70,15 +70,33 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void { return error.TestUnexpectedResult; } switch (left.dtype()) { - inline .f16, .f32, .f64 => |t| { + inline .bf16, + .f16, + .f32, + .f64, + .f8e4m3b11fnuz, + .f8e4m3fn, + .f8e4m3fnuz, + .f8e5m2, + .f8e5m2fnuz, + => |t| { const L = t.toZigType(); const left_data = left.items(L); switch (right.dtype()) { - inline .f16, .bf16, .f32, .f64, .f8e4m3fn => |rt| { + inline .bf16, + .f16, + .f32, + .f64, + .f8e4m3b11fnuz, + .f8e4m3fn, + .f8e4m3fnuz, + .f8e5m2, + .f8e5m2fnuz, + => |rt| { const R = rt.toZigType(); const right_data = right.items(R); for (left_data, right_data, 0..) |l, r, i| { - if (!approxEq(L, l, zml.floats.floatCast(L, r), @floatCast(tolerance))) { + if (!approxEq(f32, zml.floats.floatCast(f32, l), zml.floats.floatCast(f32, r), tolerance)) { log.err("left.data != right_data.\n < {d:.3} \n > {d:.3}\n error at idx {d}: {d:.3} != {d:.3}", .{ center(left_data, i), center(right_data, i), i, left_data[i], right_data[i] }); return error.TestUnexpectedResult; } @@ -87,16 +105,16 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void { else => unreachable, } }, - inline .u8, .u16, .u32, .i16, .i32, .i64 => |t| { + inline .bool, .u4, .u8, .u16, .u32, .u64, .i4, .i8, .i16, .i32, .i64 => |t| { const T = t.toZigType(); const left_data = left.items(T); const right_data = right.items(T); if (!std.mem.eql(T, left_data, right_data)) { - log.err("left.data ({d}) != right.data ({d})", .{ left_data[0..10], right_data[0..10] }); + log.err("left.data ({any}) != right.data ({any})", .{ left_data[0..10], right_data[0..10] }); return error.TestUnexpectedResult; } }, - else => unreachable, + .c64, .c128 => @panic("TODO: support comparison of complex"), } }