mlir, pjrt, zml: expose missing data types (u2, i1, i2, f4e2m1fn, f8e3m4, f8e4m3, f8e8m0fnu); fix Float32 conversion bug that truncated values

This commit is contained in:
Tarry Singh 2025-09-19 12:13:32 +00:00
parent 29bd1242ba
commit e641d05dd2
7 changed files with 355 additions and 94 deletions

View File

@ -414,9 +414,9 @@ pub const ArrayAttribute = struct {
pub fn IntegerAttribute(comptime it: IntegerTypes) type {
const ZigType, const getter = comptime switch (it) {
.i1, .i4, .i8, .i16, .i32, .i64 => .{ i64, c.mlirIntegerAttrGetValueInt },
.i1, .i2, .i4, .i8, .i16, .i32, .i64 => .{ i64, c.mlirIntegerAttrGetValueInt },
.si4, .si8, .si16, .si32, .si64 => .{ i64, c.mlirIntegerAttrGetValueSInt },
.u4, .u8, .u16, .u32, .u64 => .{ u64, c.mlirIntegerAttrGetValueUInt },
.u2, .u4, .u8, .u16, .u32, .u64 => .{ u64, c.mlirIntegerAttrGetValueUInt },
.unknown => @compileError("IntegerAttribute(unknown)"),
};
@ -1249,6 +1249,7 @@ pub const IndexType = struct {
pub const IntegerTypes = enum {
i1,
i2,
i4,
i8,
i16,
@ -1259,6 +1260,7 @@ pub const IntegerTypes = enum {
si16,
si32,
si64,
u2,
u4,
u8,
u16,
@ -1271,6 +1273,7 @@ pub const IntegerTypes = enum {
pub fn IntegerType(comptime it: IntegerTypes) type {
const Config = switch (it) {
.i1 => .{ 1, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
.i2 => .{ 2, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
.i4 => .{ 4, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
.i8 => .{ 8, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
.i16 => .{ 16, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
@ -1281,6 +1284,7 @@ pub fn IntegerType(comptime it: IntegerTypes) type {
.si16 => .{ 16, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
.si32 => .{ 32, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
.si64 => .{ 64, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
.u2 => .{ 2, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
.u4 => .{ 4, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
.u8 => .{ 8, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
.u16 => .{ 16, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
@ -1320,11 +1324,15 @@ pub fn IntegerType(comptime it: IntegerTypes) type {
}
pub const FloatTypes = enum {
f4e2m1fn,
f8e3m4,
f8e4m3,
f8e4m3b11fnuz,
f8e4m3fn,
f8e4m3fnuz,
f8e5m2,
f8e5m2fnuz,
f8e8m0fnu,
bf16,
f16,
f32,
@ -1338,12 +1346,19 @@ pub const FloatTypes = enum {
};
pub fn FloatType(comptime ft: FloatTypes) type {
const Config = switch (ft) {
const Config: struct {
*const fn (c.MlirType) callconv(.c) bool,
*const fn (c.MlirContext) callconv(.c) c.MlirType,
} = switch (ft) {
.f4e2m1fn => .{ c.mlirTypeIsAFloat4E2M1FN, c.mlirFloat4E2M1FNTypeGet },
.f8e3m4 => .{ c.mlirTypeIsAFloat8E3M4, c.mlirFloat8E3M4TypeGet },
.f8e4m3 => .{ c.mlirTypeIsAFloat8E4M3, c.mlirFloat8E4M3TypeGet },
.f8e4m3b11fnuz => .{ c.mlirTypeIsAFloat8E4M3B11FNUZ, c.mlirFloat8E4M3B11FNUZTypeGet },
.f8e4m3fn => .{ c.mlirTypeIsAFloat8E4M3FN, c.mlirFloat8E4M3FNTypeGet },
.f8e4m3fnuz => .{ c.mlirTypeIsAFloat8E4M3FNUZ, c.mlirFloat8E4M3FNUZTypeGet },
.f8e5m2 => .{ c.mlirTypeIsAFloat8E5M2, c.mlirFloat8E5M2TypeGet },
.f8e5m2fnuz => .{ c.mlirTypeIsAFloat8E5M2FNUZ, c.mlirFloat8E5M2FNUZTypeGet },
.f8e8m0fnu => .{ c.mlirTypeIsAFloat8E8M0FNU, c.mlirFloat8E8M0FNUTypeGet },
.bf16 => .{ c.mlirTypeIsABF16, c.mlirBF16TypeGet },
.f16 => .{ c.mlirTypeIsAF16, c.mlirF16TypeGet },
.f32 => .{ c.mlirTypeIsAF32, c.mlirF32TypeGet },

View File

@ -800,11 +800,13 @@ pub const LoadedExecutable = opaque {
pub const BufferType = enum(c.PJRT_Buffer_Type) {
invalid = c.PJRT_Buffer_Type_INVALID,
bool = c.PJRT_Buffer_Type_PRED,
i2 = c.PJRT_Buffer_Type_S2,
i4 = c.PJRT_Buffer_Type_S4,
i8 = c.PJRT_Buffer_Type_S8,
i16 = c.PJRT_Buffer_Type_S16,
i32 = c.PJRT_Buffer_Type_S32,
i64 = c.PJRT_Buffer_Type_S64,
u2 = c.PJRT_Buffer_Type_U2,
u4 = c.PJRT_Buffer_Type_U4,
u8 = c.PJRT_Buffer_Type_U8,
u16 = c.PJRT_Buffer_Type_U16,
@ -821,6 +823,10 @@ pub const BufferType = enum(c.PJRT_Buffer_Type) {
f8e4m3b11fnuz = c.PJRT_Buffer_Type_F8E4M3B11FNUZ,
f8e5m2fnuz = c.PJRT_Buffer_Type_F8E5M2FNUZ,
f8e4m3fnuz = c.PJRT_Buffer_Type_F8E4M3FNUZ,
f8e4m3 = c.PJRT_Buffer_Type_F8E4M3,
f8e3m4 = c.PJRT_Buffer_Type_F8E3M4,
f8e8m0 = c.PJRT_Buffer_Type_F8E8M0FNU,
f4e2m1 = c.PJRT_Buffer_Type_F4E2M1FN,
};
pub const MemoryLayoutType = enum(c.PJRT_Buffer_MemoryLayout_Type) {

View File

@ -12,20 +12,26 @@ test {
pub const DataType = enum(u8) {
bool,
// Note: the support of the float8 is a bit spotty, f8e4m3b11fnuz seems to be the most supported one on Cuda.
f4e2m1,
f8e3m4,
f8e4m3,
f8e4m3b11fnuz,
f8e4m3fn,
f8e4m3fnuz,
f8e5m2,
f8e5m2fnuz,
f8e8m0,
bf16,
f16,
f32,
f64,
i2,
i4,
i8,
i16,
i32,
i64,
u2,
u4,
u8,
u16,
@ -50,8 +56,21 @@ pub const DataType = enum(u8) {
pub fn class(self: DataType) Class {
return switch (self) {
.bool => .bool,
.f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16, .f16, .f32, .f64 => .float,
.i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => .integer,
.f4e2m1,
.f8e3m4,
.f8e4m3,
.f8e4m3b11fnuz,
.f8e4m3fn,
.f8e4m3fnuz,
.f8e5m2,
.f8e5m2fnuz,
.f8e8m0,
.bf16,
.f16,
.f32,
.f64,
=> .float,
.i2, .i4, .i8, .i16, .i32, .i64, .u2, .u4, .u8, .u16, .u32, .u64 => .integer,
.c64, .c128 => .complex,
};
}
@ -70,21 +89,27 @@ pub const DataType = enum(u8) {
pub fn fromZigType(comptime T: type) DataType {
return switch (T) {
floats.Float4E2M1 => .f4e2m1,
floats.Float8E3M4 => .f8e3m4,
floats.Float8E4M3 => .f8e4m3,
floats.Float8E4M3B11FNUZ => .f8e4m3b11fnuz,
floats.Float8E4M3FN => .f8e4m3fn,
floats.Float8E4M3FNUZ => .f8e4m3fnuz,
floats.Float8E5M2 => .f8e5m2,
floats.Float8E5M2FNUZ => .f8e5m2fnuz,
floats.Float8E8M0 => .f8e8m0,
floats.BFloat16 => .bf16,
f16 => .f16,
f32 => .f32,
f64 => .f64,
bool => .bool,
i2 => .i2,
i4 => .i4,
i8 => .i8,
i16 => .i16,
i32 => .i32,
i64 => .i64,
u2 => .u2,
u4 => .u4,
u8 => .u8,
u16 => .u16,
@ -192,10 +217,10 @@ pub const DataType = enum(u8) {
pub fn maxValue(dtype: DataType) Data {
return switch (dtype) {
.bool => .{ .bool = true },
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)),
inline .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).inf),
inline .f4e2m1, .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz, .f8e8m0 => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)),
inline .f8e3m4, .f8e4m3, .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).inf),
inline .f16, .f32, .f64 => |tag| @unionInit(Data, @tagName(tag), std.math.inf(@FieldType(Data, @tagName(tag)))),
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| @unionInit(Data, @tagName(tag), std.math.maxInt(@FieldType(Data, @tagName(tag)))),
inline .i2, .i4, .i8, .i16, .i32, .i64, .u2, .u4, .u8, .u16, .u32, .u64 => |tag| @unionInit(Data, @tagName(tag), std.math.maxInt(@FieldType(Data, @tagName(tag)))),
inline .c64, .c128 => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)),
};
}
@ -207,20 +232,26 @@ pub const DataType = enum(u8) {
pub const Data = union(DataType) {
bool: bool,
f4e2m1: floats.Float4E2M1,
f8e3m4: floats.Float8E3M4,
f8e4m3: floats.Float8E4M3,
f8e4m3b11fnuz: floats.Float8E4M3B11FNUZ,
f8e4m3fn: floats.Float8E4M3FN,
f8e4m3fnuz: floats.Float8E4M3FNUZ,
f8e5m2: floats.Float8E5M2,
f8e5m2fnuz: floats.Float8E5M2FNUZ,
f8e8m0: floats.Float8E8M0,
bf16: floats.BFloat16,
f16: f16,
f32: f32,
f64: f64,
i2: i2,
i4: i4,
i8: i8,
i16: i16,
i32: i32,
i64: i64,
u2: u2,
u4: u4,
u8: u8,
u16: u16,
@ -244,7 +275,7 @@ pub const Data = union(DataType) {
.comptime_int, .int, .comptime_float, .float => .{ .bool = value != 0 },
else => @panic("Could not create Data of type bool from value of type " ++ @typeName(T)),
},
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16 => |tag| switch (Ti) {
inline .f4e2m1, .f8e3m4, .f8e4m3, .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .f8e8m0, .bf16 => |tag| switch (Ti) {
.comptime_int, .int => @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).fromF32(@floatFromInt(value))),
.comptime_float, .float => @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).fromF32(@floatCast(value))),
else => @panic("Could not create Data of type bf16 from value of type " ++ @typeName(T)),
@ -254,7 +285,7 @@ pub const Data = union(DataType) {
.comptime_float, .float => @unionInit(Data, @tagName(tag), @floatCast(value)),
else => @panic("Could not create Data of type " ++ @tagName(tag) ++ " from value of type " ++ @typeName(T)),
},
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| switch (Ti) {
inline .i2, .i4, .i8, .i16, .i32, .i64, .u2, .u4, .u8, .u16, .u32, .u64 => |tag| switch (Ti) {
.comptime_int => blk: {
const OutT = @FieldType(Data, @tagName(tag));
if (value >= std.math.minInt(OutT) and value <= std.math.maxInt(OutT)) {

View File

@ -5,10 +5,6 @@ test {
std.testing.refAllDecls(@This());
}
/// Returns true when `v` holds the largest value of its integer type.
/// For the unsigned field types used in this file that means every bit is set.
fn allBitsOne(v: anytype) bool {
    const V = @TypeOf(v);
    const all_ones: V = std.math.maxInt(V);
    return v == all_ones;
}
fn FloatHelpers(Float: type) type {
const info = @typeInfo(Float);
const err_msg = "FloatHelpers expect a packed struct { mantissa: uXX, exponent: uXX, sign: u1}";
@ -23,9 +19,11 @@ fn FloatHelpers(Float: type) type {
}
return struct {
const sign_bits: u8 = @typeInfo(@FieldType(Float, "sign")).int.bits;
const mantissa_bits: u8 = @typeInfo(@FieldType(Float, "mantissa")).int.bits;
const exponent_bits: u8 = @typeInfo(@FieldType(Float, "exponent")).int.bits;
const f32_mantissa_bits: u8 = @typeInfo(@FieldType(Float32, "mantissa")).int.bits;
const exp_bias: i16 = std.math.maxInt(std.meta.Int(.unsigned, exponent_bits - 1));
const exp_off: u8 = FloatHelpers(Float32).exp_bias - exp_bias;
pub const zero: Float = .{ .sign = 0, .exponent = 0, .mantissa = 0 };
@ -39,52 +37,74 @@ fn FloatHelpers(Float: type) type {
/// Lossy conversion from f32, similar to @floatCast
pub fn fromF32(f: f32) Float {
@setRuntimeSafety(false);
const vf32: Float32 = @bitCast(f);
const exp_bias = comptime expBias();
const exponent = @as(u16, vf32.exponent) + exp_bias -| FloatHelpers(Float32).expBias();
const exponent: i16 = @as(i16, vf32.exponent) - exp_off;
const overflow = exponent > std.math.maxInt(@FieldType(Float, "exponent"));
if (overflow) {
return if (@hasDecl(Float, "inf")) {
return if (vf32.sign == 0) Float.inf else Float.minus_inf;
} else Float.nan;
@branchHint(.unlikely);
return if (@hasDecl(Float, "inf"))
if (vf32.sign == 0) Float.inf else Float.minus_inf
else
Float.nan;
}
return .{
.sign = vf32.sign,
.exponent = @intCast(exponent),
.mantissa = truncMantissa(vf32.mantissa),
};
return if (exponent <= 0)
.{
.sign = vf32.sign,
.exponent = 0,
.mantissa = shiftMantissa(vf32.mantissa, @intCast(-exponent)),
}
else
.{
.sign = vf32.sign,
.exponent = @intCast(exponent),
.mantissa = truncMantissa(vf32.mantissa),
};
}
/// Lossless conversion to f32.
pub fn toF32(x: Float) f32 {
var vf32: Float32 = undefined;
if (@hasDecl(Float, "isInf") and x.isInf()) {
@setRuntimeSafety(false);
if (x == zero) return 0.0;
if (isInf(x)) {
@branchHint(.unlikely);
return if (x.sign == 0) std.math.inf(f32) else -std.math.inf(f32);
}
vf32 = .{
.sign = x.sign,
.exponent = if (x.exponent == 0) 0 else @intCast(@as(i16, x.exponent) + f32_exp_bias - expBias()),
.mantissa = f32Mantissa(x),
};
const vf32: Float32 = if (x.exponent > 0)
.{
.sign = x.sign,
.exponent = @as(u8, x.exponent) + exp_off,
.mantissa = f32Mantissa(x),
}
else
.{
.sign = x.sign,
.exponent = exp_off - @clz(x.mantissa),
.mantissa = @as(u23, x.mantissa) << @clz(x.mantissa),
};
return @bitCast(vf32);
}
fn truncMantissa(x: anytype) @FieldType(Float, "mantissa") {
@setRuntimeSafety(false);
const off = @bitSizeOf(@TypeOf(x)) - mantissa_bits;
return @intCast(x >> off);
fn truncMantissa(f32_mantissa: u32) @FieldType(Float, "mantissa") {
const rounding_val: u32 = @as(u32, 1) << (f32_mantissa_bits - mantissa_bits - 1);
return @truncate((f32_mantissa + rounding_val) >> (f32_mantissa_bits - mantissa_bits));
}
/// Computes the mantissa to store when the rebased exponent is <= 0.
/// `f32_mantissa` is the raw mantissa field of the f32 and `underflow` is how far the
/// rebased exponent fell below zero (0 when it is exactly zero).
/// Returns 0 when the value is too small to leave any bit in the mantissa.
fn shiftMantissa(f32_mantissa: u32, underflow: u8) @FieldType(Float, "mantissa") {
    // Set the bit just above the f32 mantissa field (the implicit leading one of a normal f32).
    const upper_bit: u32 = @as(u32, 1) << f32_mantissa_bits;
    const full_mant32: u32 = f32_mantissa | upper_bit;
    // Divide the mantissa proportionally to the exponent underflow.
    // The shift is widened and bounds-checked: truncating it to the 5-bit shift type
    // wraps modulo 32, which left a non-zero mantissa for very small inputs.
    const shift: u32 = @as(u32, underflow) + 1;
    if (shift >= @bitSizeOf(u32)) return 0;
    const shifted_mant: u32 = full_mant32 >> @intCast(shift);
    return truncMantissa(shifted_mant);
}
fn f32Mantissa(x: Float) @FieldType(Float32, "mantissa") {
@setRuntimeSafety(false);
const Res = @FieldType(Float32, "mantissa");
const f32_mantissa_bits = @bitSizeOf(Res);
return @shlExact(@as(Res, x.mantissa), f32_mantissa_bits - mantissa_bits);
}
fn expBias() u8 {
return std.math.maxInt(std.meta.Int(.unsigned, exponent_bits - 1));
const T = @FieldType(Float32, "mantissa");
return @as(T, x.mantissa) << f32_mantissa_bits - mantissa_bits;
}
pub fn formatNumber(x: Float, writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
@ -101,6 +121,9 @@ pub const Float32 = packed struct(u32) {
exponent: u8,
sign: u1,
pub const inf: Float32 = .{ .sign = 0, .exponent = std.math.maxInt(u8), .mantissa = 0 };
pub const minus_inf = neg(inf);
const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero;
pub const neg = Helpers.neg;
@ -129,11 +152,7 @@ pub const Float8E4M3B11FNUZ = packed struct(u8) {
exponent: u4,
sign: u1,
pub const nan: Float8E4M3B11FNUZ = .{
.sign = 1,
.exponent = 0,
.mantissa = 0,
};
pub const nan: Float8E4M3B11FNUZ = .{ .sign = 1, .exponent = 0, .mantissa = 0 };
pub fn isNan(self: Float8E4M3B11FNUZ) bool {
return self.sign == 1 and self.exponent == 0 and self.mantissa == 0;
@ -155,7 +174,7 @@ pub const Float8E4M3FN = packed struct(u8) {
pub const nan: Float8E4M3FN = .{ .sign = 0, .exponent = std.math.maxInt(u4), .mantissa = std.math.maxInt(u3) };
pub fn isNan(self: Float8E4M3FN) bool {
return allBitsOne(self.exponent) and allBitsOne(self.mantissa);
return self.exponent == nan.exponent and self.mantissa == nan.mantissa;
}
const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero;
@ -170,11 +189,7 @@ pub const Float8E4M3FNUZ = packed struct(u8) {
exponent: u4,
sign: u1,
pub const nan: Float8E4M3FNUZ = .{
.sign = 1,
.exponent = 0,
.mantissa = 0,
};
pub const nan: Float8E4M3FNUZ = .{ .sign = 1, .exponent = 0, .mantissa = 0 };
pub fn isNan(self: Float8E4M3FNUZ) bool {
return self.sign == 1 and self.exponent == 0 and self.mantissa == 0;
@ -189,9 +204,10 @@ pub const Float8E4M3FNUZ = packed struct(u8) {
};
test "Float8E4" {
// With 4 exponent bits, powers of two can be represented exactly up to 64.
const test_case_e4: TestCase = .{
.lossless = &[_]f32{ 0, 1.0, -2, 1.0 / 64.0, -128 },
.lossy = &[_]f32{3.02344107628},
.lossless = &[_]f32{ 0, 1.0, -2, 1.0 / 64.0, -128, -1.125 / 64.0 },
.lossy = &[_]f32{ 3.02344107628, 1.0 / 128.0, 1.0 / 512.0 },
};
inline for (.{
@ -200,7 +216,11 @@ test "Float8E4" {
Float8E4M3FNUZ,
}) |Float8T| {
try testCustomFloat(Float8T, test_case_e4);
try std.testing.expectEqual(0.0, Float8T.fromF32(1.0 / 128.0).toF32());
try std.testing.expectEqual(0.0, Float8T.fromF32(1.0 / 2048.0).toF32());
if (@hasDecl(Float8T, "inf")) {
try std.testing.expectEqual(Float8T.inf, Float8T.fromF32(128.0));
try std.testing.expectEqual(Float8T.inf.neg(), Float8T.fromF32(-128.0));
}
}
}
@ -209,31 +229,19 @@ pub const Float8E5M2 = packed struct(u8) {
exponent: u5,
sign: u1,
pub const nan: Float8E5M2 = .{
.sign = 0,
.exponent = std.math.maxInt(u5),
.mantissa = 1,
};
pub const nan: Float8E5M2 = .{ .sign = 0, .exponent = std.math.maxInt(u5), .mantissa = 1 };
pub fn isNan(self: Float8E5M2) bool {
return allBitsOne(self.exponent) and self.mantissa != 0;
return self.exponent == nan.exponent and self.mantissa != 0;
}
pub const minus_inf: Float8E5M2 = .{
.sign = 1,
.exponent = std.math.maxInt(u5),
.mantissa = 0,
};
pub const inf: Float8E5M2 = .{
.sign = 0,
.exponent = std.math.maxInt(u5),
.mantissa = 0,
};
pub fn isInf(self: Float8E5M2) bool {
return allBitsOne(self.exponent) and self.mantissa == 0;
}
pub const minus_inf: Float8E5M2 = .neg(inf);
const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero;
@ -283,22 +291,16 @@ pub const BFloat16 = packed struct(u16) {
return allBitsOne(self.exponent) and self.mantissa != 0;
}
pub const minus_inf: BFloat16 = .{
.sign = 1,
.exponent = std.math.maxInt(u8),
.mantissa = 0,
};
pub const inf: BFloat16 = .{
.sign = 0,
.exponent = std.math.maxInt(u8),
.mantissa = 0,
};
pub fn isInf(self: BFloat16) bool {
return allBitsOne(self.exponent) and self.mantissa == 0;
}
pub const minus_inf: BFloat16 = .neg(inf);
// Specialized versions of to/from F32. Since BFloat16 has the same exponent range as F32,
// no overflow/underflow can happen, simplifying the conversion logic.
pub fn toF32(self: BFloat16) f32 {
// Pad the BF16 with a zero u16.
return @bitCast([2]u16{ 0, @bitCast(self) });
@ -333,11 +335,167 @@ test BFloat16 {
});
}
pub fn floatCast(T: type, x: anytype) T {
return switch (@TypeOf(x)) {
f64, f32, f16 => @floatCast(x),
else => @floatCast(x.toF32()),
/// 8-bit float: 1 sign bit, 4 exponent bits, 3 mantissa bits.
/// The all-ones exponent is reserved: mantissa == 0 encodes ±infinity (see `inf`),
/// any other mantissa encodes a NaN.
pub const Float8E4M3 = packed struct(u8) {
    mantissa: u3,
    exponent: u4,
    sign: u1,

    /// Canonical NaN: every bit set.
    pub const nan: Float8E4M3 = @bitCast(@as(u8, 0xFF));

    /// True for every NaN encoding (all-ones exponent, non-zero mantissa, either sign),
    /// not only for the canonical `nan` and its negation.
    pub fn isNan(self: Float8E4M3) bool {
        return self.exponent == nan.exponent and self.mantissa != 0;
    }

    pub const inf: Float8E4M3 = .{
        .sign = 0,
        .exponent = std.math.maxInt(u4),
        .mantissa = 0,
    };

    pub const minus_inf = neg(inf);

    const Helpers = FloatHelpers(@This());
    pub const zero = Helpers.zero;
    pub const neg = Helpers.neg;
    pub const fromF32 = Helpers.fromF32;
    pub const toF32 = Helpers.toF32;
    pub const formatNumber = Helpers.formatNumber;
};
/// 8-bit float: 1 sign bit, 3 exponent bits, 4 mantissa bits.
/// The all-ones exponent is reserved: mantissa == 0 encodes ±infinity (see `inf`),
/// any other mantissa encodes a NaN.
pub const Float8E3M4 = packed struct(u8) {
    mantissa: u4,
    exponent: u3,
    sign: u1,

    /// Canonical NaN: every bit set.
    pub const nan: Float8E3M4 = @bitCast(@as(u8, 0xFF));

    /// True for every NaN encoding (all-ones exponent, non-zero mantissa, either sign),
    /// not only for the canonical `nan` and its negation.
    pub fn isNan(self: Float8E3M4) bool {
        return self.exponent == nan.exponent and self.mantissa != 0;
    }

    pub const inf: Float8E3M4 = .{
        .sign = 0,
        .exponent = std.math.maxInt(u3),
        .mantissa = 0,
    };

    pub const minus_inf = neg(inf);

    const Helpers = FloatHelpers(@This());
    pub const zero = Helpers.zero;
    pub const neg = Helpers.neg;
    pub const fromF32 = Helpers.fromF32;
    pub const toF32 = Helpers.toF32;
    pub const formatNumber = Helpers.formatNumber;
};
/// Exponent-only 8-bit float: 8 exponent bits, no sign bit and no mantissa bits.
/// The stored exponent uses the same encoding as the f32 exponent field.
pub const Float8E8M0 = packed struct(u8) {
    mantissa: u0 = 0,
    exponent: u8,
    sign: u0 = 0,

    /// Value decoded for `exponent == 0`: the f32 with a zero exponent field and only
    /// the top mantissa bit set.
    pub const min_scale: f32 = @bitCast(Float32{ .sign = 0, .exponent = 0, .mantissa = 0b1 << 22 });

    /// Lossy conversion from f32, similar to @floatCast.
    /// Only the exponent field of `f` is kept: its sign and mantissa are discarded
    /// (the mantissa is truncated, not rounded).
    pub fn fromF32(f: f32) Float8E8M0 {
        const vf32: Float32 = @bitCast(f);
        return .{ .exponent = vf32.exponent };
    }

    /// Lossless conversion to f32.
    /// The all-ones exponent decodes to NaN: this format has no infinity, so it must
    /// not be turned into the f32 infinity that shares the same exponent field.
    pub fn toF32(x: Float8E8M0) f32 {
        return switch (x.exponent) {
            0 => min_scale,
            std.math.maxInt(u8) => std.math.nan(f32),
            else => |exp| @bitCast(Float32{ .sign = 0, .exponent = exp, .mantissa = 0 }),
        };
    }

    const Helpers = FloatHelpers(@This());
    pub const formatNumber = Helpers.formatNumber;
};
test Float8E8M0 {
    // 1.0 has an f32 exponent field of 127, which is stored unchanged.
    const one: Float8E8M0 = .{ .exponent = 127 };
    try std.testing.expectEqual(one, Float8E8M0.fromF32(1.0));

    // Disabled check, kept for reference:
    // try std.testing.expectEqual(5.877472e-39, Float8E8M0.toF32(.{ .exponent = 0}));

    const exact = [_]f32{ Float8E8M0.min_scale, 1.0, 64.0, 1.0 / 128.0, std.math.pow(f32, 2.0, 127) };
    const inexact = [_]f32{1.00001};
    try testCustomFloat(Float8E8M0, .{
        .lossless = &exact,
        .lossy = &inexact,
    });
}
/// 4-bit float: 1 sign bit, 2 exponent bits, 1 mantissa bit.
pub const Float4E2M1 = packed struct(u4) {
    mantissa: u1,
    exponent: u2,
    sign: u1,

    pub const nan: Float4E2M1 = @bitCast(@as(u4, 0xF));

    const Helpers = FloatHelpers(@This());
    pub const zero = Helpers.zero;
    pub const neg = Helpers.neg;
    pub const fromF32 = Helpers.fromF32;
    pub const formatNumber = Helpers.formatNumber;

    /// f32 value of each of the 16 encodings, indexed by the raw u4 bit pattern.
    pub const values = [_]f32{ 0.0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6 };

    /// Decodes through the `values` lookup table.
    pub fn toF32(x: Float4E2M1) f32 {
        // The baseline toF32 doesn't work correctly:
        // 0b0001 and 0b1001 should map to ±0.5, but are mapped to ±epsilon.
        const raw: u4 = @bitCast(x);
        return values[raw];
    }

    test toF32 {
        var decoded: [16]f32 = undefined;
        for (&decoded, 0..) |*slot, pattern| {
            const as_f4: Float4E2M1 = @bitCast(@as(u4, @intCast(pattern)));
            slot.* = as_f4.toF32();
        }
        try std.testing.expectEqualSlices(f32, &Float4E2M1.values, &decoded);
    }

    test fromF32 {
        // The baseline fromF32 doesn't work correctly:
        // ±0.5 should map to 0b0001/0b1001 but is mapped to ±0.0 instead.
        // TODO: it probably affects other types.
        const expected = [_]u4{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 };
        var encoded: [16]u4 = undefined;
        for (&encoded, Float4E2M1.values) |*slot, value| {
            slot.* = @bitCast(Float4E2M1.fromF32(value));
        }
        try std.testing.expectEqualSlices(u4, &expected, &encoded);
    }
};
/// Converts `x` to the float type `T`.
/// Both `T` and `@TypeOf(x)` may be a builtin float or a custom float type.
/// A custom source type must provide `toF32`, a custom target type must provide `fromF32`.
/// Conversions involving a custom type go through f32, so they are as lossy as
/// `fromF32` / `@floatCast` to f32.
pub fn floatCast(T: type, x: anytype) T {
    const src_is_builtin = switch (@typeInfo(@TypeOf(x))) {
        .float, .comptime_float => true,
        else => false,
    };
    const dst_is_builtin = @typeInfo(T) == .float;

    if (dst_is_builtin) {
        return if (src_is_builtin) @floatCast(x) else @floatCast(x.toF32());
    }

    // `fromF32` takes an f32: wider sources such as f64 need an explicit narrowing cast,
    // they do not coerce implicitly.
    const as_f32: f32 = if (src_is_builtin) @floatCast(x) else x.toF32();
    return T.fromF32(as_f32);
}
/// Returns true when `x` is positive or negative infinity.
/// f16/f32/f64 are forwarded to `std.math.isInf`. Any other type is compared bitwise,
/// with its top bit cleared, against its `inf` declaration; a type without an `inf`
/// declaration is never infinite.
pub fn isInf(x: anytype) bool {
    const F = @TypeOf(x);
    if (F == f64 or F == f32 or F == f16) return std.math.isInf(x);
    if (!@hasDecl(F, "inf")) return false;

    const Bits = std.meta.Int(.unsigned, @bitSizeOf(F));
    // Every bit except the top one, so both signs compare equal to `inf`.
    const magnitude_mask: Bits = std.math.maxInt(Bits) >> 1;
    const raw: Bits = @bitCast(x);
    const inf_raw: Bits = @bitCast(F.inf);
    return (raw & magnitude_mask) == inf_raw;
}
/// Returns true when `v` holds the largest value of its integer type.
/// For the unsigned field types used in this file that means every bit is set.
fn allBitsOne(v: anytype) bool {
    const V = @TypeOf(v);
    const all_ones: V = std.math.maxInt(V);
    return v == all_ones;
}
const TestCase = struct {

View File

@ -35,20 +35,26 @@ pub const Type = struct {
pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type {
return switch (dt) {
.bool => .int(ctx, .i1),
.f4e2m1 => .float(ctx, .f4e2m1fn),
.f8e3m4 => .float(ctx, .f8e3m4),
.f8e4m3 => .float(ctx, .f8e4m3),
.f8e4m3b11fnuz => .float(ctx, .f8e4m3b11fnuz),
.f8e4m3fn => .float(ctx, .f8e4m3fn),
.f8e4m3fnuz => .float(ctx, .f8e4m3fnuz),
.f8e5m2 => .float(ctx, .f8e5m2),
.f8e5m2fnuz => .float(ctx, .f8e5m2fnuz),
.f8e8m0 => .float(ctx, .f8e8m0fnu),
.bf16 => .float(ctx, .bf16),
.f16 => .float(ctx, .f16),
.f32 => .float(ctx, .f32),
.f64 => .float(ctx, .f64),
.i2 => .int(ctx, .i2),
.i4 => .int(ctx, .i4),
.i8 => .int(ctx, .i8),
.i16 => .int(ctx, .i16),
.i32 => .int(ctx, .i32),
.i64 => .int(ctx, .i64),
.u2 => .int(ctx, .u2),
.u4 => .int(ctx, .u4),
.u8 => .int(ctx, .u8),
.u16 => .int(ctx, .u16),
@ -62,23 +68,28 @@ pub const Type = struct {
pub fn toDType(mlir_type: mlir.Type) dtype.DataType {
const mapping = .{
.{ .bool, mlir.IntegerType(.i1) },
.{ .f4e2m1, mlir.FloatType(.f4e2m1fn) },
.{ .f8e3m4, mlir.FloatType(.f8e3m4) },
.{ .f8e4m3, mlir.FloatType(.f8e4m3) },
.{ .f8e4m3b11fnuz, mlir.FloatType(.f8e4m3b11fnuz) },
.{ .f8e4m3fn, mlir.FloatType(.f8e4m3fn) },
.{ .f8e4m3fnuz, mlir.FloatType(.f8e4m3fnuz) },
.{ .f8e5m2, mlir.FloatType(.f8e5m2) },
.{ .f8e5m2fnuz, mlir.FloatType(.f8e5m2fnuz) },
.{ .f8e8m0, mlir.FloatType(.f8e8m0fnu) },
.{ .bf16, mlir.FloatType(.bf16) },
.{ .f16, mlir.FloatType(.f16) },
.{ .f32, mlir.FloatType(.f32) },
.{ .f64, mlir.FloatType(.f64) },
.{ .i2, mlir.IntegerType(.i2) },
.{ .i4, mlir.IntegerType(.i4) },
.{ .i8, mlir.IntegerType(.i8) },
.{ .i16, mlir.IntegerType(.i16) },
.{ .i32, mlir.IntegerType(.i32) },
.{ .i64, mlir.IntegerType(.i64) },
.{ .u2, mlir.IntegerType(.u2) },
.{ .u4, mlir.IntegerType(.u4) },
.{ .u8, mlir.IntegerType(.u8) },
.{ .u16, mlir.IntegerType(.u16) },
@ -89,6 +100,8 @@ pub const Type = struct {
.{ .c128, mlir.ComplexType(.c128) },
};
// TODO: it seems quite slow to make all of those function calls.
// Maybe we should memoize the ptr of a set of mlir types when creating the context.
inline for (mapping) |entry| {
const dt, const mlirT = entry;
if (mlirT.is_a_fn(mlir_type._inner)) {

View File

@ -1042,6 +1042,40 @@ pub const Tensor = struct {
return _result(self._shape.withDtype(to), op.result(0));
}
test convert {
    const floats = @import("floats.zig");
    const zml = @import("zml.zig");
    const platform = zml.testing.env();

    // f4e2m1: the conversion done by `floats.Float4E2M1.fromF32` must match the compiled one.
    {
        const inputs = [_]f32{ 0.0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6 };
        const Converted = [inputs.len]floats.Float4E2M1;
        var expected: Converted = undefined;
        for (&expected, inputs) |*out, in| out.* = floats.Float4E2M1.fromF32(in);

        const inputs_d = try zml.Buffer.fromArray(platform, inputs);
        const actual_d = try zml.testing.compileAndCall(platform, Tensor.convert, .{ inputs_d, .f4e2m1 });
        const actual = actual_d.getValue(Converted);
        errdefer std.log.warn("convert(.f4e2m1) failed !\ninput f32:\n{e}\nzml.floats computed:\n{any}\nxla computed:\n{any}", .{ stdx.fmt.slice(&inputs), expected, actual });
        try std.testing.expectEqualDeep(expected, actual);
    }

    // f8e3m4: same comparison on a few small inputs.
    {
        const inputs = [_]f32{ 1.1 / 4.0, 1.1 / 8.0, 1.1 / 16.0, 1.1 / 32.0, 1.1 / 64.0, 1.1 / 128.0 };
        const Converted = [inputs.len]floats.Float8E3M4;
        var expected: Converted = undefined;
        for (&expected, inputs) |*out, in| out.* = floats.Float8E3M4.fromF32(in);

        const inputs_d = try zml.Buffer.fromArray(platform, inputs);
        const actual_d = try zml.testing.compileAndCall(platform, Tensor.convert, .{ inputs_d, .f8e3m4 });
        const actual = actual_d.getValue(Converted);
        errdefer std.log.warn("convert(.f8e3m4) failed !\ninput f32:\n{e}\nzml.floats computed:\n{any}\nxla computed:\n{any}", .{ stdx.fmt.slice(&inputs), expected, actual });
        try std.testing.expectEqualDeep(expected, actual);
    }
}
/// Returns a Tensor containing the element-wise rounding operation of the input Tensor.
pub fn round(self: Tensor) Tensor {
const loc = self.getContext().mlirCtx().location(@src());
@ -3844,11 +3878,11 @@ fn getPoolResDims(dt: DataType, in_dims: []const i64, base_dilations: @Vector(Te
}
fn getComparisonType(ctx: mlir.Context, dtype: DataType) dialect.stablehlo.CompareType {
return dialect.stablehlo.CompareType.init(ctx, switch (dtype) {
.i4, .i8, .i16, .i32, .i64 => .SIGNED,
.bool, .u4, .u8, .u16, .u32, .u64 => .UNSIGNED,
.f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16, .f16, .f32, .f64 => .FLOAT,
.c64, .c128 => @panic("Can't compare complex numbers"),
return dialect.stablehlo.CompareType.init(ctx, switch (dtype.class()) {
.bool => .UNSIGNED,
.integer => if (dtype.isSignedInt()) .SIGNED else .UNSIGNED,
.float => .FLOAT,
.complex => @panic("Can't compare complex numbers"),
});
}

View File

@ -65,11 +65,15 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
.f16,
.f32,
.f64,
.f4e2m1,
.f8e3m4,
.f8e4m3,
.f8e4m3b11fnuz,
.f8e4m3fn,
.f8e4m3fnuz,
.f8e5m2,
.f8e5m2fnuz,
.f8e8m0,
=> |t| {
const L = t.toZigType();
const left_data = left.items(L);
@ -96,7 +100,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
else => unreachable,
}
},
inline .bool, .u4, .u8, .u16, .u32, .u64, .i4, .i8, .i16, .i32, .i64 => |t| {
inline .bool, .u2, .u4, .u8, .u16, .u32, .u64, .i2, .i4, .i8, .i16, .i32, .i64 => |t| {
const T = t.toZigType();
return std.testing.expectEqualSlices(T, left.items(T), right.items(T));
},