Extend tests to handle all float types, preventing crashes with bfloat16 tensors.

This commit is contained in:
Tarry Singh 2023-04-27 10:34:27 +00:00
parent e0fd7f8d97
commit 021111d07d

View File

@ -70,15 +70,33 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
return error.TestUnexpectedResult; return error.TestUnexpectedResult;
} }
switch (left.dtype()) { switch (left.dtype()) {
inline .f16, .f32, .f64 => |t| { inline .bf16,
.f16,
.f32,
.f64,
.f8e4m3b11fnuz,
.f8e4m3fn,
.f8e4m3fnuz,
.f8e5m2,
.f8e5m2fnuz,
=> |t| {
const L = t.toZigType(); const L = t.toZigType();
const left_data = left.items(L); const left_data = left.items(L);
switch (right.dtype()) { switch (right.dtype()) {
inline .f16, .bf16, .f32, .f64, .f8e4m3fn => |rt| { inline .bf16,
.f16,
.f32,
.f64,
.f8e4m3b11fnuz,
.f8e4m3fn,
.f8e4m3fnuz,
.f8e5m2,
.f8e5m2fnuz,
=> |rt| {
const R = rt.toZigType(); const R = rt.toZigType();
const right_data = right.items(R); const right_data = right.items(R);
for (left_data, right_data, 0..) |l, r, i| { for (left_data, right_data, 0..) |l, r, i| {
if (!approxEq(L, l, zml.floats.floatCast(L, r), @floatCast(tolerance))) { if (!approxEq(f32, zml.floats.floatCast(f32, l), zml.floats.floatCast(f32, r), tolerance)) {
log.err("left.data != right_data.\n < {d:.3} \n > {d:.3}\n error at idx {d}: {d:.3} != {d:.3}", .{ center(left_data, i), center(right_data, i), i, left_data[i], right_data[i] }); log.err("left.data != right_data.\n < {d:.3} \n > {d:.3}\n error at idx {d}: {d:.3} != {d:.3}", .{ center(left_data, i), center(right_data, i), i, left_data[i], right_data[i] });
return error.TestUnexpectedResult; return error.TestUnexpectedResult;
} }
@ -87,16 +105,16 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
else => unreachable, else => unreachable,
} }
}, },
inline .u8, .u16, .u32, .i16, .i32, .i64 => |t| { inline .bool, .u4, .u8, .u16, .u32, .u64, .i4, .i8, .i16, .i32, .i64 => |t| {
const T = t.toZigType(); const T = t.toZigType();
const left_data = left.items(T); const left_data = left.items(T);
const right_data = right.items(T); const right_data = right.items(T);
if (!std.mem.eql(T, left_data, right_data)) { if (!std.mem.eql(T, left_data, right_data)) {
log.err("left.data ({d}) != right.data ({d})", .{ left_data[0..10], right_data[0..10] }); log.err("left.data ({any}) != right.data ({any})", .{ left_data[0..10], right_data[0..10] });
return error.TestUnexpectedResult; return error.TestUnexpectedResult;
} }
}, },
else => unreachable, .c64, .c128 => @panic("TODO: support comparison of complex"),
} }
} }