mlir, pjrt, zml: expose missing data types (u2, i1, i2, f4e2m1fn, f8e3m4, f8e4m3, f8e8m0fnu); fix Float32 conversion bug that truncated values

This commit is contained in:
Tarry Singh 2025-09-19 12:13:32 +00:00
parent 29bd1242ba
commit e641d05dd2
7 changed files with 355 additions and 94 deletions

View File

@ -414,9 +414,9 @@ pub const ArrayAttribute = struct {
pub fn IntegerAttribute(comptime it: IntegerTypes) type { pub fn IntegerAttribute(comptime it: IntegerTypes) type {
const ZigType, const getter = comptime switch (it) { const ZigType, const getter = comptime switch (it) {
.i1, .i4, .i8, .i16, .i32, .i64 => .{ i64, c.mlirIntegerAttrGetValueInt }, .i1, .i2, .i4, .i8, .i16, .i32, .i64 => .{ i64, c.mlirIntegerAttrGetValueInt },
.si4, .si8, .si16, .si32, .si64 => .{ i64, c.mlirIntegerAttrGetValueSInt }, .si4, .si8, .si16, .si32, .si64 => .{ i64, c.mlirIntegerAttrGetValueSInt },
.u4, .u8, .u16, .u32, .u64 => .{ u64, c.mlirIntegerAttrGetValueUInt }, .u2, .u4, .u8, .u16, .u32, .u64 => .{ u64, c.mlirIntegerAttrGetValueUInt },
.unknown => @compileError("IntegerAttribute(unknown)"), .unknown => @compileError("IntegerAttribute(unknown)"),
}; };
@ -1249,6 +1249,7 @@ pub const IndexType = struct {
pub const IntegerTypes = enum { pub const IntegerTypes = enum {
i1, i1,
i2,
i4, i4,
i8, i8,
i16, i16,
@ -1259,6 +1260,7 @@ pub const IntegerTypes = enum {
si16, si16,
si32, si32,
si64, si64,
u2,
u4, u4,
u8, u8,
u16, u16,
@ -1271,6 +1273,7 @@ pub const IntegerTypes = enum {
pub fn IntegerType(comptime it: IntegerTypes) type { pub fn IntegerType(comptime it: IntegerTypes) type {
const Config = switch (it) { const Config = switch (it) {
.i1 => .{ 1, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless }, .i1 => .{ 1, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
.i2 => .{ 2, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
.i4 => .{ 4, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless }, .i4 => .{ 4, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
.i8 => .{ 8, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless }, .i8 => .{ 8, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
.i16 => .{ 16, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless }, .i16 => .{ 16, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
@ -1281,6 +1284,7 @@ pub fn IntegerType(comptime it: IntegerTypes) type {
.si16 => .{ 16, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned }, .si16 => .{ 16, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
.si32 => .{ 32, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned }, .si32 => .{ 32, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
.si64 => .{ 64, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned }, .si64 => .{ 64, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
.u2 => .{ 2, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
.u4 => .{ 4, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned }, .u4 => .{ 4, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
.u8 => .{ 8, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned }, .u8 => .{ 8, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
.u16 => .{ 16, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned }, .u16 => .{ 16, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
@ -1320,11 +1324,15 @@ pub fn IntegerType(comptime it: IntegerTypes) type {
} }
pub const FloatTypes = enum { pub const FloatTypes = enum {
f4e2m1fn,
f8e3m4,
f8e4m3,
f8e4m3b11fnuz, f8e4m3b11fnuz,
f8e4m3fn, f8e4m3fn,
f8e4m3fnuz, f8e4m3fnuz,
f8e5m2, f8e5m2,
f8e5m2fnuz, f8e5m2fnuz,
f8e8m0fnu,
bf16, bf16,
f16, f16,
f32, f32,
@ -1338,12 +1346,19 @@ pub const FloatTypes = enum {
}; };
pub fn FloatType(comptime ft: FloatTypes) type { pub fn FloatType(comptime ft: FloatTypes) type {
const Config = switch (ft) { const Config: struct {
*const fn (c.MlirType) callconv(.c) bool,
*const fn (c.MlirContext) callconv(.c) c.MlirType,
} = switch (ft) {
.f4e2m1fn => .{ c.mlirTypeIsAFloat4E2M1FN, c.mlirFloat4E2M1FNTypeGet },
.f8e3m4 => .{ c.mlirTypeIsAFloat8E3M4, c.mlirFloat8E3M4TypeGet },
.f8e4m3 => .{ c.mlirTypeIsAFloat8E4M3, c.mlirFloat8E4M3TypeGet },
.f8e4m3b11fnuz => .{ c.mlirTypeIsAFloat8E4M3B11FNUZ, c.mlirFloat8E4M3B11FNUZTypeGet }, .f8e4m3b11fnuz => .{ c.mlirTypeIsAFloat8E4M3B11FNUZ, c.mlirFloat8E4M3B11FNUZTypeGet },
.f8e4m3fn => .{ c.mlirTypeIsAFloat8E4M3FN, c.mlirFloat8E4M3FNTypeGet }, .f8e4m3fn => .{ c.mlirTypeIsAFloat8E4M3FN, c.mlirFloat8E4M3FNTypeGet },
.f8e4m3fnuz => .{ c.mlirTypeIsAFloat8E4M3FNUZ, c.mlirFloat8E4M3FNUZTypeGet }, .f8e4m3fnuz => .{ c.mlirTypeIsAFloat8E4M3FNUZ, c.mlirFloat8E4M3FNUZTypeGet },
.f8e5m2 => .{ c.mlirTypeIsAFloat8E5M2, c.mlirFloat8E5M2TypeGet }, .f8e5m2 => .{ c.mlirTypeIsAFloat8E5M2, c.mlirFloat8E5M2TypeGet },
.f8e5m2fnuz => .{ c.mlirTypeIsAFloat8E5M2FNUZ, c.mlirFloat8E5M2FNUZTypeGet }, .f8e5m2fnuz => .{ c.mlirTypeIsAFloat8E5M2FNUZ, c.mlirFloat8E5M2FNUZTypeGet },
.f8e8m0fnu => .{ c.mlirTypeIsAFloat8E8M0FNU, c.mlirFloat8E8M0FNUTypeGet },
.bf16 => .{ c.mlirTypeIsABF16, c.mlirBF16TypeGet }, .bf16 => .{ c.mlirTypeIsABF16, c.mlirBF16TypeGet },
.f16 => .{ c.mlirTypeIsAF16, c.mlirF16TypeGet }, .f16 => .{ c.mlirTypeIsAF16, c.mlirF16TypeGet },
.f32 => .{ c.mlirTypeIsAF32, c.mlirF32TypeGet }, .f32 => .{ c.mlirTypeIsAF32, c.mlirF32TypeGet },

View File

@ -800,11 +800,13 @@ pub const LoadedExecutable = opaque {
pub const BufferType = enum(c.PJRT_Buffer_Type) { pub const BufferType = enum(c.PJRT_Buffer_Type) {
invalid = c.PJRT_Buffer_Type_INVALID, invalid = c.PJRT_Buffer_Type_INVALID,
bool = c.PJRT_Buffer_Type_PRED, bool = c.PJRT_Buffer_Type_PRED,
i2 = c.PJRT_Buffer_Type_S2,
i4 = c.PJRT_Buffer_Type_S4, i4 = c.PJRT_Buffer_Type_S4,
i8 = c.PJRT_Buffer_Type_S8, i8 = c.PJRT_Buffer_Type_S8,
i16 = c.PJRT_Buffer_Type_S16, i16 = c.PJRT_Buffer_Type_S16,
i32 = c.PJRT_Buffer_Type_S32, i32 = c.PJRT_Buffer_Type_S32,
i64 = c.PJRT_Buffer_Type_S64, i64 = c.PJRT_Buffer_Type_S64,
u2 = c.PJRT_Buffer_Type_U2,
u4 = c.PJRT_Buffer_Type_U4, u4 = c.PJRT_Buffer_Type_U4,
u8 = c.PJRT_Buffer_Type_U8, u8 = c.PJRT_Buffer_Type_U8,
u16 = c.PJRT_Buffer_Type_U16, u16 = c.PJRT_Buffer_Type_U16,
@ -821,6 +823,10 @@ pub const BufferType = enum(c.PJRT_Buffer_Type) {
f8e4m3b11fnuz = c.PJRT_Buffer_Type_F8E4M3B11FNUZ, f8e4m3b11fnuz = c.PJRT_Buffer_Type_F8E4M3B11FNUZ,
f8e5m2fnuz = c.PJRT_Buffer_Type_F8E5M2FNUZ, f8e5m2fnuz = c.PJRT_Buffer_Type_F8E5M2FNUZ,
f8e4m3fnuz = c.PJRT_Buffer_Type_F8E4M3FNUZ, f8e4m3fnuz = c.PJRT_Buffer_Type_F8E4M3FNUZ,
f8e4m3 = c.PJRT_Buffer_Type_F8E4M3,
f8e3m4 = c.PJRT_Buffer_Type_F8E3M4,
f8e8m0 = c.PJRT_Buffer_Type_F8E8M0FNU,
f4e2m1 = c.PJRT_Buffer_Type_F4E2M1FN,
}; };
pub const MemoryLayoutType = enum(c.PJRT_Buffer_MemoryLayout_Type) { pub const MemoryLayoutType = enum(c.PJRT_Buffer_MemoryLayout_Type) {

View File

@ -12,20 +12,26 @@ test {
pub const DataType = enum(u8) { pub const DataType = enum(u8) {
bool, bool,
// Note: the support of the float8 is a bit spotty, f8e4m3b11fnuz seems to be the most supported one on Cuda. // Note: the support of the float8 is a bit spotty, f8e4m3b11fnuz seems to be the most supported one on Cuda.
f4e2m1,
f8e3m4,
f8e4m3,
f8e4m3b11fnuz, f8e4m3b11fnuz,
f8e4m3fn, f8e4m3fn,
f8e4m3fnuz, f8e4m3fnuz,
f8e5m2, f8e5m2,
f8e5m2fnuz, f8e5m2fnuz,
f8e8m0,
bf16, bf16,
f16, f16,
f32, f32,
f64, f64,
i2,
i4, i4,
i8, i8,
i16, i16,
i32, i32,
i64, i64,
u2,
u4, u4,
u8, u8,
u16, u16,
@ -50,8 +56,21 @@ pub const DataType = enum(u8) {
pub fn class(self: DataType) Class { pub fn class(self: DataType) Class {
return switch (self) { return switch (self) {
.bool => .bool, .bool => .bool,
.f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16, .f16, .f32, .f64 => .float, .f4e2m1,
.i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => .integer, .f8e3m4,
.f8e4m3,
.f8e4m3b11fnuz,
.f8e4m3fn,
.f8e4m3fnuz,
.f8e5m2,
.f8e5m2fnuz,
.f8e8m0,
.bf16,
.f16,
.f32,
.f64,
=> .float,
.i2, .i4, .i8, .i16, .i32, .i64, .u2, .u4, .u8, .u16, .u32, .u64 => .integer,
.c64, .c128 => .complex, .c64, .c128 => .complex,
}; };
} }
@ -70,21 +89,27 @@ pub const DataType = enum(u8) {
pub fn fromZigType(comptime T: type) DataType { pub fn fromZigType(comptime T: type) DataType {
return switch (T) { return switch (T) {
floats.Float4E2M1 => .f4e2m1,
floats.Float8E3M4 => .f8e3m4,
floats.Float8E4M3 => .f8e4m3,
floats.Float8E4M3B11FNUZ => .f8e4m3b11fnuz, floats.Float8E4M3B11FNUZ => .f8e4m3b11fnuz,
floats.Float8E4M3FN => .f8e4m3fn, floats.Float8E4M3FN => .f8e4m3fn,
floats.Float8E4M3FNUZ => .f8e4m3fnuz, floats.Float8E4M3FNUZ => .f8e4m3fnuz,
floats.Float8E5M2 => .f8e5m2, floats.Float8E5M2 => .f8e5m2,
floats.Float8E5M2FNUZ => .f8e5m2fnuz, floats.Float8E5M2FNUZ => .f8e5m2fnuz,
floats.Float8E8M0 => .f8e8m0,
floats.BFloat16 => .bf16, floats.BFloat16 => .bf16,
f16 => .f16, f16 => .f16,
f32 => .f32, f32 => .f32,
f64 => .f64, f64 => .f64,
bool => .bool, bool => .bool,
i2 => .i2,
i4 => .i4, i4 => .i4,
i8 => .i8, i8 => .i8,
i16 => .i16, i16 => .i16,
i32 => .i32, i32 => .i32,
i64 => .i64, i64 => .i64,
u2 => .u2,
u4 => .u4, u4 => .u4,
u8 => .u8, u8 => .u8,
u16 => .u16, u16 => .u16,
@ -192,10 +217,10 @@ pub const DataType = enum(u8) {
pub fn maxValue(dtype: DataType) Data { pub fn maxValue(dtype: DataType) Data {
return switch (dtype) { return switch (dtype) {
.bool => .{ .bool = true }, .bool => .{ .bool = true },
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)), inline .f4e2m1, .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz, .f8e8m0 => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)),
inline .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).inf), inline .f8e3m4, .f8e4m3, .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).inf),
inline .f16, .f32, .f64 => |tag| @unionInit(Data, @tagName(tag), std.math.inf(@FieldType(Data, @tagName(tag)))), inline .f16, .f32, .f64 => |tag| @unionInit(Data, @tagName(tag), std.math.inf(@FieldType(Data, @tagName(tag)))),
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| @unionInit(Data, @tagName(tag), std.math.maxInt(@FieldType(Data, @tagName(tag)))), inline .i2, .i4, .i8, .i16, .i32, .i64, .u2, .u4, .u8, .u16, .u32, .u64 => |tag| @unionInit(Data, @tagName(tag), std.math.maxInt(@FieldType(Data, @tagName(tag)))),
inline .c64, .c128 => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)), inline .c64, .c128 => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)),
}; };
} }
@ -207,20 +232,26 @@ pub const DataType = enum(u8) {
pub const Data = union(DataType) { pub const Data = union(DataType) {
bool: bool, bool: bool,
f4e2m1: floats.Float4E2M1,
f8e3m4: floats.Float8E3M4,
f8e4m3: floats.Float8E4M3,
f8e4m3b11fnuz: floats.Float8E4M3B11FNUZ, f8e4m3b11fnuz: floats.Float8E4M3B11FNUZ,
f8e4m3fn: floats.Float8E4M3FN, f8e4m3fn: floats.Float8E4M3FN,
f8e4m3fnuz: floats.Float8E4M3FNUZ, f8e4m3fnuz: floats.Float8E4M3FNUZ,
f8e5m2: floats.Float8E5M2, f8e5m2: floats.Float8E5M2,
f8e5m2fnuz: floats.Float8E5M2FNUZ, f8e5m2fnuz: floats.Float8E5M2FNUZ,
f8e8m0: floats.Float8E8M0,
bf16: floats.BFloat16, bf16: floats.BFloat16,
f16: f16, f16: f16,
f32: f32, f32: f32,
f64: f64, f64: f64,
i2: i2,
i4: i4, i4: i4,
i8: i8, i8: i8,
i16: i16, i16: i16,
i32: i32, i32: i32,
i64: i64, i64: i64,
u2: u2,
u4: u4, u4: u4,
u8: u8, u8: u8,
u16: u16, u16: u16,
@ -244,7 +275,7 @@ pub const Data = union(DataType) {
.comptime_int, .int, .comptime_float, .float => .{ .bool = value != 0 }, .comptime_int, .int, .comptime_float, .float => .{ .bool = value != 0 },
else => @panic("Could not create Data of type bool from value of type " ++ @typeName(T)), else => @panic("Could not create Data of type bool from value of type " ++ @typeName(T)),
}, },
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16 => |tag| switch (Ti) { inline .f4e2m1, .f8e3m4, .f8e4m3, .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .f8e8m0, .bf16 => |tag| switch (Ti) {
.comptime_int, .int => @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).fromF32(@floatFromInt(value))), .comptime_int, .int => @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).fromF32(@floatFromInt(value))),
.comptime_float, .float => @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).fromF32(@floatCast(value))), .comptime_float, .float => @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).fromF32(@floatCast(value))),
else => @panic("Could not create Data of type bf16 from value of type " ++ @typeName(T)), else => @panic("Could not create Data of type bf16 from value of type " ++ @typeName(T)),
@ -254,7 +285,7 @@ pub const Data = union(DataType) {
.comptime_float, .float => @unionInit(Data, @tagName(tag), @floatCast(value)), .comptime_float, .float => @unionInit(Data, @tagName(tag), @floatCast(value)),
else => @panic("Could not create Data of type " ++ @tagName(tag) ++ " from value of type " ++ @typeName(T)), else => @panic("Could not create Data of type " ++ @tagName(tag) ++ " from value of type " ++ @typeName(T)),
}, },
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| switch (Ti) { inline .i2, .i4, .i8, .i16, .i32, .i64, .u2, .u4, .u8, .u16, .u32, .u64 => |tag| switch (Ti) {
.comptime_int => blk: { .comptime_int => blk: {
const OutT = @FieldType(Data, @tagName(tag)); const OutT = @FieldType(Data, @tagName(tag));
if (value >= std.math.minInt(OutT) and value <= std.math.maxInt(OutT)) { if (value >= std.math.minInt(OutT) and value <= std.math.maxInt(OutT)) {

View File

@ -5,10 +5,6 @@ test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(@This());
} }
fn allBitsOne(v: anytype) bool {
return v == std.math.maxInt(@TypeOf(v));
}
fn FloatHelpers(Float: type) type { fn FloatHelpers(Float: type) type {
const info = @typeInfo(Float); const info = @typeInfo(Float);
const err_msg = "FloatHelpers expect a packed struct { mantissa: uXX, exponent: uXX, sign: u1}"; const err_msg = "FloatHelpers expect a packed struct { mantissa: uXX, exponent: uXX, sign: u1}";
@ -23,9 +19,11 @@ fn FloatHelpers(Float: type) type {
} }
return struct { return struct {
const sign_bits: u8 = @typeInfo(@FieldType(Float, "sign")).int.bits;
const mantissa_bits: u8 = @typeInfo(@FieldType(Float, "mantissa")).int.bits; const mantissa_bits: u8 = @typeInfo(@FieldType(Float, "mantissa")).int.bits;
const exponent_bits: u8 = @typeInfo(@FieldType(Float, "exponent")).int.bits; const exponent_bits: u8 = @typeInfo(@FieldType(Float, "exponent")).int.bits;
const f32_mantissa_bits: u8 = @typeInfo(@FieldType(Float32, "mantissa")).int.bits;
const exp_bias: i16 = std.math.maxInt(std.meta.Int(.unsigned, exponent_bits - 1));
const exp_off: u8 = FloatHelpers(Float32).exp_bias - exp_bias;
pub const zero: Float = .{ .sign = 0, .exponent = 0, .mantissa = 0 }; pub const zero: Float = .{ .sign = 0, .exponent = 0, .mantissa = 0 };
@ -39,52 +37,74 @@ fn FloatHelpers(Float: type) type {
/// Lossy conversion from f32, similar to @floatCast /// Lossy conversion from f32, similar to @floatCast
pub fn fromF32(f: f32) Float { pub fn fromF32(f: f32) Float {
@setRuntimeSafety(false);
const vf32: Float32 = @bitCast(f); const vf32: Float32 = @bitCast(f);
const exp_bias = comptime expBias(); const exponent: i16 = @as(i16, vf32.exponent) - exp_off;
const exponent = @as(u16, vf32.exponent) + exp_bias -| FloatHelpers(Float32).expBias();
const overflow = exponent > std.math.maxInt(@FieldType(Float, "exponent")); const overflow = exponent > std.math.maxInt(@FieldType(Float, "exponent"));
if (overflow) { if (overflow) {
return if (@hasDecl(Float, "inf")) { @branchHint(.unlikely);
return if (vf32.sign == 0) Float.inf else Float.minus_inf; return if (@hasDecl(Float, "inf"))
} else Float.nan; if (vf32.sign == 0) Float.inf else Float.minus_inf
else
Float.nan;
} }
return .{
.sign = vf32.sign, return if (exponent <= 0)
.exponent = @intCast(exponent), .{
.mantissa = truncMantissa(vf32.mantissa), .sign = vf32.sign,
}; .exponent = 0,
.mantissa = shiftMantissa(vf32.mantissa, @intCast(-exponent)),
}
else
.{
.sign = vf32.sign,
.exponent = @intCast(exponent),
.mantissa = truncMantissa(vf32.mantissa),
};
} }
/// Lossless conversion to f32. /// Lossless conversion to f32.
pub fn toF32(x: Float) f32 { pub fn toF32(x: Float) f32 {
var vf32: Float32 = undefined; @setRuntimeSafety(false);
if (@hasDecl(Float, "isInf") and x.isInf()) {
if (x == zero) return 0.0;
if (isInf(x)) {
@branchHint(.unlikely);
return if (x.sign == 0) std.math.inf(f32) else -std.math.inf(f32); return if (x.sign == 0) std.math.inf(f32) else -std.math.inf(f32);
} }
vf32 = .{
.sign = x.sign, const vf32: Float32 = if (x.exponent > 0)
.exponent = if (x.exponent == 0) 0 else @intCast(@as(i16, x.exponent) + f32_exp_bias - expBias()), .{
.mantissa = f32Mantissa(x), .sign = x.sign,
}; .exponent = @as(u8, x.exponent) + exp_off,
.mantissa = f32Mantissa(x),
}
else
.{
.sign = x.sign,
.exponent = exp_off - @clz(x.mantissa),
.mantissa = @as(u23, x.mantissa) << @clz(x.mantissa),
};
return @bitCast(vf32); return @bitCast(vf32);
} }
fn truncMantissa(x: anytype) @FieldType(Float, "mantissa") { fn truncMantissa(f32_mantissa: u32) @FieldType(Float, "mantissa") {
@setRuntimeSafety(false); const rounding_val: u32 = @as(u32, 1) << (f32_mantissa_bits - mantissa_bits - 1);
const off = @bitSizeOf(@TypeOf(x)) - mantissa_bits; return @truncate((f32_mantissa + rounding_val) >> (f32_mantissa_bits - mantissa_bits));
return @intCast(x >> off); }
fn shiftMantissa(f32_mantissa: u32, underflow: u8) @FieldType(Float, "mantissa") {
const upper_bit: u32 = @as(u32, 1) << f32_mantissa_bits;
const full_mant32: u32 = f32_mantissa | upper_bit;
// divide the mantissa proportionally to the exponent underflow
const shifted_mant: u32 = full_mant32 >> @truncate(underflow + 1);
return truncMantissa(shifted_mant);
} }
fn f32Mantissa(x: Float) @FieldType(Float32, "mantissa") { fn f32Mantissa(x: Float) @FieldType(Float32, "mantissa") {
@setRuntimeSafety(false); const T = @FieldType(Float32, "mantissa");
const Res = @FieldType(Float32, "mantissa"); return @as(T, x.mantissa) << f32_mantissa_bits - mantissa_bits;
const f32_mantissa_bits = @bitSizeOf(Res);
return @shlExact(@as(Res, x.mantissa), f32_mantissa_bits - mantissa_bits);
}
fn expBias() u8 {
return std.math.maxInt(std.meta.Int(.unsigned, exponent_bits - 1));
} }
pub fn formatNumber(x: Float, writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void { pub fn formatNumber(x: Float, writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
@ -101,6 +121,9 @@ pub const Float32 = packed struct(u32) {
exponent: u8, exponent: u8,
sign: u1, sign: u1,
pub const inf: Float32 = .{ .sign = 0, .exponent = std.math.maxInt(u8), .mantissa = 0 };
pub const minus_inf = neg(inf);
const Helpers = FloatHelpers(@This()); const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero; pub const zero = Helpers.zero;
pub const neg = Helpers.neg; pub const neg = Helpers.neg;
@ -129,11 +152,7 @@ pub const Float8E4M3B11FNUZ = packed struct(u8) {
exponent: u4, exponent: u4,
sign: u1, sign: u1,
pub const nan: Float8E4M3B11FNUZ = .{ pub const nan: Float8E4M3B11FNUZ = .{ .sign = 1, .exponent = 0, .mantissa = 0 };
.sign = 1,
.exponent = 0,
.mantissa = 0,
};
pub fn isNan(self: Float8E4M3B11FNUZ) bool { pub fn isNan(self: Float8E4M3B11FNUZ) bool {
return self.sign == 1 and self.exponent == 0 and self.mantissa == 0; return self.sign == 1 and self.exponent == 0 and self.mantissa == 0;
@ -155,7 +174,7 @@ pub const Float8E4M3FN = packed struct(u8) {
pub const nan: Float8E4M3FN = .{ .sign = 0, .exponent = std.math.maxInt(u4), .mantissa = std.math.maxInt(u3) }; pub const nan: Float8E4M3FN = .{ .sign = 0, .exponent = std.math.maxInt(u4), .mantissa = std.math.maxInt(u3) };
pub fn isNan(self: Float8E4M3FN) bool { pub fn isNan(self: Float8E4M3FN) bool {
return allBitsOne(self.exponent) and allBitsOne(self.mantissa); return self.exponent == nan.exponent and self.mantissa == nan.mantissa;
} }
const Helpers = FloatHelpers(@This()); const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero; pub const zero = Helpers.zero;
@ -170,11 +189,7 @@ pub const Float8E4M3FNUZ = packed struct(u8) {
exponent: u4, exponent: u4,
sign: u1, sign: u1,
pub const nan: Float8E4M3FNUZ = .{ pub const nan: Float8E4M3FNUZ = .{ .sign = 1, .exponent = 0, .mantissa = 0 };
.sign = 1,
.exponent = 0,
.mantissa = 0,
};
pub fn isNan(self: Float8E4M3FNUZ) bool { pub fn isNan(self: Float8E4M3FNUZ) bool {
return self.sign == 1 and self.exponent == 0 and self.mantissa == 0; return self.sign == 1 and self.exponent == 0 and self.mantissa == 0;
@ -189,9 +204,10 @@ pub const Float8E4M3FNUZ = packed struct(u8) {
}; };
test "Float8E4" { test "Float8E4" {
// With 4 bits of exponents power of two can be represented exactly up to 64.
const test_case_e4: TestCase = .{ const test_case_e4: TestCase = .{
.lossless = &[_]f32{ 0, 1.0, -2, 1.0 / 64.0, -128 }, .lossless = &[_]f32{ 0, 1.0, -2, 1.0 / 64.0, -128, -1.125 / 64.0 },
.lossy = &[_]f32{3.02344107628}, .lossy = &[_]f32{ 3.02344107628, 1.0 / 128.0, 1.0 / 512.0 },
}; };
inline for (.{ inline for (.{
@ -200,7 +216,11 @@ test "Float8E4" {
Float8E4M3FNUZ, Float8E4M3FNUZ,
}) |Float8T| { }) |Float8T| {
try testCustomFloat(Float8T, test_case_e4); try testCustomFloat(Float8T, test_case_e4);
try std.testing.expectEqual(0.0, Float8T.fromF32(1.0 / 128.0).toF32()); try std.testing.expectEqual(0.0, Float8T.fromF32(1.0 / 2048.0).toF32());
if (@hasDecl(Float8T, "inf")) {
try std.testing.expectEqual(Float8T.inf, Float8T.fromF32(128.0));
try std.testing.expectEqual(Float8T.inf.neg(), Float8T.fromF32(-128.0));
}
} }
} }
@ -209,31 +229,19 @@ pub const Float8E5M2 = packed struct(u8) {
exponent: u5, exponent: u5,
sign: u1, sign: u1,
pub const nan: Float8E5M2 = .{ pub const nan: Float8E5M2 = .{ .sign = 0, .exponent = std.math.maxInt(u5), .mantissa = 1 };
.sign = 0,
.exponent = std.math.maxInt(u5),
.mantissa = 1,
};
pub fn isNan(self: Float8E5M2) bool { pub fn isNan(self: Float8E5M2) bool {
return allBitsOne(self.exponent) and self.mantissa != 0; return self.exponent == nan.exponent and self.mantissa != 0;
} }
pub const minus_inf: Float8E5M2 = .{
.sign = 1,
.exponent = std.math.maxInt(u5),
.mantissa = 0,
};
pub const inf: Float8E5M2 = .{ pub const inf: Float8E5M2 = .{
.sign = 0, .sign = 0,
.exponent = std.math.maxInt(u5), .exponent = std.math.maxInt(u5),
.mantissa = 0, .mantissa = 0,
}; };
pub fn isInf(self: Float8E5M2) bool { pub const minus_inf: Float8E5M2 = .neg(inf);
return allBitsOne(self.exponent) and self.mantissa == 0;
}
const Helpers = FloatHelpers(@This()); const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero; pub const zero = Helpers.zero;
@ -283,22 +291,16 @@ pub const BFloat16 = packed struct(u16) {
return allBitsOne(self.exponent) and self.mantissa != 0; return allBitsOne(self.exponent) and self.mantissa != 0;
} }
pub const minus_inf: BFloat16 = .{
.sign = 1,
.exponent = std.math.maxInt(u8),
.mantissa = 0,
};
pub const inf: BFloat16 = .{ pub const inf: BFloat16 = .{
.sign = 0, .sign = 0,
.exponent = std.math.maxInt(u8), .exponent = std.math.maxInt(u8),
.mantissa = 0, .mantissa = 0,
}; };
pub fn isInf(self: BFloat16) bool { pub const minus_inf: BFloat16 = .neg(inf);
return allBitsOne(self.exponent) and self.mantissa == 0;
}
// Specialized versions of to/from F32. Since BFloat16 has the same exponent range than F32,
// no overflow/underflow can happen, simplifiying conversion logic.
pub fn toF32(self: BFloat16) f32 { pub fn toF32(self: BFloat16) f32 {
// Pad the BF16 with zeros 0 // Pad the BF16 with zeros 0
return @bitCast([2]u16{ 0, @bitCast(self) }); return @bitCast([2]u16{ 0, @bitCast(self) });
@ -333,11 +335,167 @@ test BFloat16 {
}); });
} }
pub fn floatCast(T: type, x: anytype) T { pub const Float8E4M3 = packed struct(u8) {
return switch (@TypeOf(x)) { mantissa: u3,
f64, f32, f16 => @floatCast(x), exponent: u4,
else => @floatCast(x.toF32()), sign: u1,
pub const nan: Float8E4M3 = @bitCast(0xFF);
pub fn isNan(self: Float8E4M3) bool {
return self == nan or self == comptime nan.neg();
}
pub const inf: Float8E4M3 = .{
.sign = 0,
.exponent = std.math.maxInt(u4),
.mantissa = 0,
}; };
pub const minus_inf = neg(inf);
const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero;
pub const neg = Helpers.neg;
pub const fromF32 = Helpers.fromF32;
pub const toF32 = Helpers.toF32;
pub const formatNumber = Helpers.formatNumber;
};
pub const Float8E3M4 = packed struct(u8) {
mantissa: u4,
exponent: u3,
sign: u1,
pub const nan: Float8E3M4 = @bitCast(0xFF);
pub fn isNan(self: Float8E3M4) bool {
return self == nan or self == comptime nan.neg();
}
pub const inf: Float8E3M4 = .{
.sign = 0,
.exponent = std.math.maxInt(u3),
.mantissa = 0,
};
pub const minus_inf = neg(inf);
const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero;
pub const neg = Helpers.neg;
pub const fromF32 = Helpers.fromF32;
pub const toF32 = Helpers.toF32;
pub const formatNumber = Helpers.formatNumber;
};
pub const Float8E8M0 = packed struct(u8) {
mantissa: u0 = 0,
exponent: u8,
sign: u0 = 0,
pub const min_scale: f32 = @bitCast(Float32{ .sign = 0, .exponent = 0, .mantissa = 0b1 << 22 });
/// Lossy conversion from f32, similar to @floatCast
pub fn fromF32(f: f32) Float8E8M0 {
const vf32: Float32 = @bitCast(f);
return .{ .exponent = @intCast(vf32.exponent) };
}
/// Lossless conversion to f32.
pub fn toF32(x: Float8E8M0) f32 {
if (x.exponent == 0) return min_scale;
const vf32: Float32 = .{
.sign = 0,
.exponent = x.exponent,
.mantissa = 0,
};
return @bitCast(vf32);
}
const Helpers = FloatHelpers(@This());
pub const formatNumber = Helpers.formatNumber;
};
test Float8E8M0 {
try std.testing.expectEqual(Float8E8M0{ .exponent = 127 }, Float8E8M0.fromF32(1.0));
// try std.testing.expectEqual(5.877472e-39, Float8E8M0.toF32(.{ .exponent = 0}));
try testCustomFloat(Float8E8M0, .{
.lossless = &[_]f32{ Float8E8M0.min_scale, 1.0, 64.0, 1.0 / 128.0, std.math.pow(f32, 2.0, 127) },
.lossy = &[_]f32{1.00001},
});
}
pub const Float4E2M1 = packed struct(u4) {
mantissa: u1,
exponent: u2,
sign: u1,
pub const nan: Float4E2M1 = @bitCast(@as(u4, 0xF));
const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero;
pub const neg = Helpers.neg;
pub const fromF32 = Helpers.fromF32;
pub const formatNumber = Helpers.formatNumber;
pub const values = [_]f32{ 0.0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6 };
pub fn toF32(x: Float4E2M1) f32 {
// the baseline toF32 doesn't work correctly:
// 0b0001 and 0b1001 shoud map to ±0.5, but are mapped to ±epsilon
return values[@as(u4, @bitCast(x))];
}
test toF32 {
var to_f32_res: [16]f32 = undefined;
for (&to_f32_res, 0..) |*r, i| {
const x_f4: Float4E2M1 = @bitCast(@as(u4, @intCast(i)));
r.* = x_f4.toF32();
}
try std.testing.expectEqualSlices(f32, &Float4E2M1.values, &to_f32_res);
}
test fromF32 {
// the baseline fromF32 doesn't work correctly:
// ±0.5 should map to 0b0001/0b1001 but are map to ±0.0 instead.
// TODO: it probably affects other types.
var from_f32_res: [16]Float4E2M1 = undefined;
for (&from_f32_res, 0..) |*r, i| {
r.* = .fromF32(Float4E2M1.values[i]);
}
try std.testing.expectEqualSlices(u4, &.{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 }, @ptrCast(&from_f32_res));
}
};
pub fn floatCast(T: type, x: anytype) T {
return switch (T) {
f64, f32, f16 => switch (@TypeOf(x)) {
f64, f32, f16 => @floatCast(x),
else => @floatCast(x.toF32()),
},
else => switch (@TypeOf(x)) {
f64, f32, f16 => .fromF32(x),
else => .fromF32(x.toF32()),
},
};
}
pub fn isInf(x: anytype) bool {
const Float = @TypeOf(x);
switch (Float) {
f64, f32, f16 => return std.math.isInf(x),
else => {},
}
if (!@hasDecl(Float, "inf")) return false;
const FBits = std.meta.Int(.unsigned, @bitSizeOf(Float));
const remove_sign = ~@as(FBits, 0) >> 1;
return @as(FBits, @bitCast(x)) & remove_sign == @as(FBits, @bitCast(Float.inf));
}
fn allBitsOne(v: anytype) bool {
return v == std.math.maxInt(@TypeOf(v));
} }
const TestCase = struct { const TestCase = struct {

View File

@ -35,20 +35,26 @@ pub const Type = struct {
pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type { pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type {
return switch (dt) { return switch (dt) {
.bool => .int(ctx, .i1), .bool => .int(ctx, .i1),
.f4e2m1 => .float(ctx, .f4e2m1fn),
.f8e3m4 => .float(ctx, .f8e3m4),
.f8e4m3 => .float(ctx, .f8e4m3),
.f8e4m3b11fnuz => .float(ctx, .f8e4m3b11fnuz), .f8e4m3b11fnuz => .float(ctx, .f8e4m3b11fnuz),
.f8e4m3fn => .float(ctx, .f8e4m3fn), .f8e4m3fn => .float(ctx, .f8e4m3fn),
.f8e4m3fnuz => .float(ctx, .f8e4m3fnuz), .f8e4m3fnuz => .float(ctx, .f8e4m3fnuz),
.f8e5m2 => .float(ctx, .f8e5m2), .f8e5m2 => .float(ctx, .f8e5m2),
.f8e5m2fnuz => .float(ctx, .f8e5m2fnuz), .f8e5m2fnuz => .float(ctx, .f8e5m2fnuz),
.f8e8m0 => .float(ctx, .f8e8m0fnu),
.bf16 => .float(ctx, .bf16), .bf16 => .float(ctx, .bf16),
.f16 => .float(ctx, .f16), .f16 => .float(ctx, .f16),
.f32 => .float(ctx, .f32), .f32 => .float(ctx, .f32),
.f64 => .float(ctx, .f64), .f64 => .float(ctx, .f64),
.i2 => .int(ctx, .i2),
.i4 => .int(ctx, .i4), .i4 => .int(ctx, .i4),
.i8 => .int(ctx, .i8), .i8 => .int(ctx, .i8),
.i16 => .int(ctx, .i16), .i16 => .int(ctx, .i16),
.i32 => .int(ctx, .i32), .i32 => .int(ctx, .i32),
.i64 => .int(ctx, .i64), .i64 => .int(ctx, .i64),
.u2 => .int(ctx, .u2),
.u4 => .int(ctx, .u4), .u4 => .int(ctx, .u4),
.u8 => .int(ctx, .u8), .u8 => .int(ctx, .u8),
.u16 => .int(ctx, .u16), .u16 => .int(ctx, .u16),
@ -62,23 +68,28 @@ pub const Type = struct {
pub fn toDType(mlir_type: mlir.Type) dtype.DataType { pub fn toDType(mlir_type: mlir.Type) dtype.DataType {
const mapping = .{ const mapping = .{
.{ .bool, mlir.IntegerType(.i1) }, .{ .bool, mlir.IntegerType(.i1) },
.{ .f4e2m1, mlir.FloatType(.f4e2m1fn) },
.{ .f8e3m4, mlir.FloatType(.f8e3m4) },
.{ .f8e4m3, mlir.FloatType(.f8e4m3) },
.{ .f8e4m3b11fnuz, mlir.FloatType(.f8e4m3b11fnuz) }, .{ .f8e4m3b11fnuz, mlir.FloatType(.f8e4m3b11fnuz) },
.{ .f8e4m3fn, mlir.FloatType(.f8e4m3fn) }, .{ .f8e4m3fn, mlir.FloatType(.f8e4m3fn) },
.{ .f8e4m3fnuz, mlir.FloatType(.f8e4m3fnuz) }, .{ .f8e4m3fnuz, mlir.FloatType(.f8e4m3fnuz) },
.{ .f8e5m2, mlir.FloatType(.f8e5m2) }, .{ .f8e5m2, mlir.FloatType(.f8e5m2) },
.{ .f8e5m2fnuz, mlir.FloatType(.f8e5m2fnuz) }, .{ .f8e5m2fnuz, mlir.FloatType(.f8e5m2fnuz) },
.{ .f8e8m0, mlir.FloatType(.f8e8m0fnu) },
.{ .bf16, mlir.FloatType(.bf16) }, .{ .bf16, mlir.FloatType(.bf16) },
.{ .f16, mlir.FloatType(.f16) }, .{ .f16, mlir.FloatType(.f16) },
.{ .f32, mlir.FloatType(.f32) }, .{ .f32, mlir.FloatType(.f32) },
.{ .f64, mlir.FloatType(.f64) }, .{ .f64, mlir.FloatType(.f64) },
.{ .i2, mlir.IntegerType(.i2) },
.{ .i4, mlir.IntegerType(.i4) }, .{ .i4, mlir.IntegerType(.i4) },
.{ .i8, mlir.IntegerType(.i8) }, .{ .i8, mlir.IntegerType(.i8) },
.{ .i16, mlir.IntegerType(.i16) }, .{ .i16, mlir.IntegerType(.i16) },
.{ .i32, mlir.IntegerType(.i32) }, .{ .i32, mlir.IntegerType(.i32) },
.{ .i64, mlir.IntegerType(.i64) }, .{ .i64, mlir.IntegerType(.i64) },
.{ .u2, mlir.IntegerType(.u2) },
.{ .u4, mlir.IntegerType(.u4) }, .{ .u4, mlir.IntegerType(.u4) },
.{ .u8, mlir.IntegerType(.u8) }, .{ .u8, mlir.IntegerType(.u8) },
.{ .u16, mlir.IntegerType(.u16) }, .{ .u16, mlir.IntegerType(.u16) },
@ -89,6 +100,8 @@ pub const Type = struct {
.{ .c128, mlir.ComplexType(.c128) }, .{ .c128, mlir.ComplexType(.c128) },
}; };
// TODO: this seems quite slow to have all of those functions calls.
// Maybe we should memoize the ptr of a set of mlir types when creating the context.
inline for (mapping) |entry| { inline for (mapping) |entry| {
const dt, const mlirT = entry; const dt, const mlirT = entry;
if (mlirT.is_a_fn(mlir_type._inner)) { if (mlirT.is_a_fn(mlir_type._inner)) {

View File

@ -1042,6 +1042,40 @@ pub const Tensor = struct {
return _result(self._shape.withDtype(to), op.result(0)); return _result(self._shape.withDtype(to), op.result(0));
} }
test convert {
const floats = @import("floats.zig");
const zml = @import("zml.zig");
const platform = zml.testing.env();
// f4e2m1
{
const x = [_]f32{ 0.0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6 };
var x_f4: [x.len]floats.Float4E2M1 = undefined;
for (&x_f4, &x) |*xi_f4, xi| xi_f4.* = .fromF32(xi);
const x_d = try zml.Buffer.fromArray(platform, x);
const x_f4_xla_d = try zml.testing.compileAndCall(platform, Tensor.convert, .{ x_d, .f4e2m1 });
const x_f4_xla = x_f4_xla_d.getValue(@TypeOf(x_f4));
errdefer std.log.warn("convert(.f4e2m1) failed !\ninput f32:\n{e}\nzml.floats computed:\n{any}\nxla computed:\n{any}", .{ stdx.fmt.slice(&x), x_f4, x_f4_xla });
try std.testing.expectEqualDeep(x_f4, x_f4_xla);
}
// f8e3m4
{
const x = [_]f32{ 1.1 / 4.0, 1.1 / 8.0, 1.1 / 16.0, 1.1 / 32.0, 1.1 / 64.0, 1.1 / 128.0 };
var x_f8e3: [x.len]floats.Float8E3M4 = undefined;
for (&x_f8e3, &x) |*xi_f8e3, xi| xi_f8e3.* = .fromF32(xi);
const x_d = try zml.Buffer.fromArray(platform, x);
const x_f8e3_xla_d = try zml.testing.compileAndCall(platform, Tensor.convert, .{ x_d, .f8e3m4 });
const x_f8e3_xla = x_f8e3_xla_d.getValue(@TypeOf(x_f8e3));
errdefer std.log.warn("convert(.f8e3m4) failed !\ninput f32:\n{e}\nzml.floats computed:\n{any}\nxla computed:\n{any}", .{ stdx.fmt.slice(&x), x_f8e3, x_f8e3_xla });
try std.testing.expectEqualDeep(x_f8e3, x_f8e3_xla);
}
}
/// Returns a Tensor containing the element-wise rounding operation of the input Tensor. /// Returns a Tensor containing the element-wise rounding operation of the input Tensor.
pub fn round(self: Tensor) Tensor { pub fn round(self: Tensor) Tensor {
const loc = self.getContext().mlirCtx().location(@src()); const loc = self.getContext().mlirCtx().location(@src());
@ -3844,11 +3878,11 @@ fn getPoolResDims(dt: DataType, in_dims: []const i64, base_dilations: @Vector(Te
} }
fn getComparisonType(ctx: mlir.Context, dtype: DataType) dialect.stablehlo.CompareType { fn getComparisonType(ctx: mlir.Context, dtype: DataType) dialect.stablehlo.CompareType {
return dialect.stablehlo.CompareType.init(ctx, switch (dtype) { return dialect.stablehlo.CompareType.init(ctx, switch (dtype.class()) {
.i4, .i8, .i16, .i32, .i64 => .SIGNED, .bool => .UNSIGNED,
.bool, .u4, .u8, .u16, .u32, .u64 => .UNSIGNED, .integer => if (dtype.isSignedInt()) .SIGNED else .UNSIGNED,
.f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16, .f16, .f32, .f64 => .FLOAT, .float => .FLOAT,
.c64, .c128 => @panic("Can't compare complex numbers"), .complex => @panic("Can't compare complex numbers"),
}); });
} }

View File

@ -65,11 +65,15 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
.f16, .f16,
.f32, .f32,
.f64, .f64,
.f4e2m1,
.f8e3m4,
.f8e4m3,
.f8e4m3b11fnuz, .f8e4m3b11fnuz,
.f8e4m3fn, .f8e4m3fn,
.f8e4m3fnuz, .f8e4m3fnuz,
.f8e5m2, .f8e5m2,
.f8e5m2fnuz, .f8e5m2fnuz,
.f8e8m0,
=> |t| { => |t| {
const L = t.toZigType(); const L = t.toZigType();
const left_data = left.items(L); const left_data = left.items(L);
@ -96,7 +100,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
else => unreachable, else => unreachable,
} }
}, },
inline .bool, .u4, .u8, .u16, .u32, .u64, .i4, .i8, .i16, .i32, .i64 => |t| { inline .bool, .u2, .u4, .u8, .u16, .u32, .u64, .i2, .i4, .i8, .i16, .i32, .i64 => |t| {
const T = t.toZigType(); const T = t.toZigType();
return std.testing.expectEqualSlices(T, left.items(T), right.items(T)); return std.testing.expectEqualSlices(T, left.items(T), right.items(T));
}, },