mlir, pjrt, zml: expose missing data types (u2, i1, i2, f4e2m1fn, f8e3m4, f8e4m3, f8e8m0fnu); fix Float32 conversion bug that truncated values
This commit is contained in:
parent
29bd1242ba
commit
e641d05dd2
@ -414,9 +414,9 @@ pub const ArrayAttribute = struct {
|
||||
|
||||
pub fn IntegerAttribute(comptime it: IntegerTypes) type {
|
||||
const ZigType, const getter = comptime switch (it) {
|
||||
.i1, .i4, .i8, .i16, .i32, .i64 => .{ i64, c.mlirIntegerAttrGetValueInt },
|
||||
.i1, .i2, .i4, .i8, .i16, .i32, .i64 => .{ i64, c.mlirIntegerAttrGetValueInt },
|
||||
.si4, .si8, .si16, .si32, .si64 => .{ i64, c.mlirIntegerAttrGetValueSInt },
|
||||
.u4, .u8, .u16, .u32, .u64 => .{ u64, c.mlirIntegerAttrGetValueUInt },
|
||||
.u2, .u4, .u8, .u16, .u32, .u64 => .{ u64, c.mlirIntegerAttrGetValueUInt },
|
||||
.unknown => @compileError("IntegerAttribute(unknown)"),
|
||||
};
|
||||
|
||||
@ -1249,6 +1249,7 @@ pub const IndexType = struct {
|
||||
|
||||
pub const IntegerTypes = enum {
|
||||
i1,
|
||||
i2,
|
||||
i4,
|
||||
i8,
|
||||
i16,
|
||||
@ -1259,6 +1260,7 @@ pub const IntegerTypes = enum {
|
||||
si16,
|
||||
si32,
|
||||
si64,
|
||||
u2,
|
||||
u4,
|
||||
u8,
|
||||
u16,
|
||||
@ -1271,6 +1273,7 @@ pub const IntegerTypes = enum {
|
||||
pub fn IntegerType(comptime it: IntegerTypes) type {
|
||||
const Config = switch (it) {
|
||||
.i1 => .{ 1, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
|
||||
.i2 => .{ 2, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
|
||||
.i4 => .{ 4, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
|
||||
.i8 => .{ 8, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
|
||||
.i16 => .{ 16, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
|
||||
@ -1281,6 +1284,7 @@ pub fn IntegerType(comptime it: IntegerTypes) type {
|
||||
.si16 => .{ 16, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
|
||||
.si32 => .{ 32, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
|
||||
.si64 => .{ 64, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
|
||||
.u2 => .{ 2, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
|
||||
.u4 => .{ 4, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
|
||||
.u8 => .{ 8, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
|
||||
.u16 => .{ 16, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
|
||||
@ -1320,11 +1324,15 @@ pub fn IntegerType(comptime it: IntegerTypes) type {
|
||||
}
|
||||
|
||||
pub const FloatTypes = enum {
|
||||
f4e2m1fn,
|
||||
f8e3m4,
|
||||
f8e4m3,
|
||||
f8e4m3b11fnuz,
|
||||
f8e4m3fn,
|
||||
f8e4m3fnuz,
|
||||
f8e5m2,
|
||||
f8e5m2fnuz,
|
||||
f8e8m0fnu,
|
||||
bf16,
|
||||
f16,
|
||||
f32,
|
||||
@ -1338,12 +1346,19 @@ pub const FloatTypes = enum {
|
||||
};
|
||||
|
||||
pub fn FloatType(comptime ft: FloatTypes) type {
|
||||
const Config = switch (ft) {
|
||||
const Config: struct {
|
||||
*const fn (c.MlirType) callconv(.c) bool,
|
||||
*const fn (c.MlirContext) callconv(.c) c.MlirType,
|
||||
} = switch (ft) {
|
||||
.f4e2m1fn => .{ c.mlirTypeIsAFloat4E2M1FN, c.mlirFloat4E2M1FNTypeGet },
|
||||
.f8e3m4 => .{ c.mlirTypeIsAFloat8E3M4, c.mlirFloat8E3M4TypeGet },
|
||||
.f8e4m3 => .{ c.mlirTypeIsAFloat8E4M3, c.mlirFloat8E4M3TypeGet },
|
||||
.f8e4m3b11fnuz => .{ c.mlirTypeIsAFloat8E4M3B11FNUZ, c.mlirFloat8E4M3B11FNUZTypeGet },
|
||||
.f8e4m3fn => .{ c.mlirTypeIsAFloat8E4M3FN, c.mlirFloat8E4M3FNTypeGet },
|
||||
.f8e4m3fnuz => .{ c.mlirTypeIsAFloat8E4M3FNUZ, c.mlirFloat8E4M3FNUZTypeGet },
|
||||
.f8e5m2 => .{ c.mlirTypeIsAFloat8E5M2, c.mlirFloat8E5M2TypeGet },
|
||||
.f8e5m2fnuz => .{ c.mlirTypeIsAFloat8E5M2FNUZ, c.mlirFloat8E5M2FNUZTypeGet },
|
||||
.f8e8m0fnu => .{ c.mlirTypeIsAFloat8E8M0FNU, c.mlirFloat8E8M0FNUTypeGet },
|
||||
.bf16 => .{ c.mlirTypeIsABF16, c.mlirBF16TypeGet },
|
||||
.f16 => .{ c.mlirTypeIsAF16, c.mlirF16TypeGet },
|
||||
.f32 => .{ c.mlirTypeIsAF32, c.mlirF32TypeGet },
|
||||
|
||||
@ -800,11 +800,13 @@ pub const LoadedExecutable = opaque {
|
||||
pub const BufferType = enum(c.PJRT_Buffer_Type) {
|
||||
invalid = c.PJRT_Buffer_Type_INVALID,
|
||||
bool = c.PJRT_Buffer_Type_PRED,
|
||||
i2 = c.PJRT_Buffer_Type_S2,
|
||||
i4 = c.PJRT_Buffer_Type_S4,
|
||||
i8 = c.PJRT_Buffer_Type_S8,
|
||||
i16 = c.PJRT_Buffer_Type_S16,
|
||||
i32 = c.PJRT_Buffer_Type_S32,
|
||||
i64 = c.PJRT_Buffer_Type_S64,
|
||||
u2 = c.PJRT_Buffer_Type_U2,
|
||||
u4 = c.PJRT_Buffer_Type_U4,
|
||||
u8 = c.PJRT_Buffer_Type_U8,
|
||||
u16 = c.PJRT_Buffer_Type_U16,
|
||||
@ -821,6 +823,10 @@ pub const BufferType = enum(c.PJRT_Buffer_Type) {
|
||||
f8e4m3b11fnuz = c.PJRT_Buffer_Type_F8E4M3B11FNUZ,
|
||||
f8e5m2fnuz = c.PJRT_Buffer_Type_F8E5M2FNUZ,
|
||||
f8e4m3fnuz = c.PJRT_Buffer_Type_F8E4M3FNUZ,
|
||||
f8e4m3 = c.PJRT_Buffer_Type_F8E4M3,
|
||||
f8e3m4 = c.PJRT_Buffer_Type_F8E3M4,
|
||||
f8e8m0 = c.PJRT_Buffer_Type_F8E8M0FNU,
|
||||
f4e2m1 = c.PJRT_Buffer_Type_F4E2M1FN,
|
||||
};
|
||||
|
||||
pub const MemoryLayoutType = enum(c.PJRT_Buffer_MemoryLayout_Type) {
|
||||
|
||||
@ -12,20 +12,26 @@ test {
|
||||
pub const DataType = enum(u8) {
|
||||
bool,
|
||||
// Note: the support of the float8 is a bit spotty, f8e4m3b11fnuz seems to be the most supported one on Cuda.
|
||||
f4e2m1,
|
||||
f8e3m4,
|
||||
f8e4m3,
|
||||
f8e4m3b11fnuz,
|
||||
f8e4m3fn,
|
||||
f8e4m3fnuz,
|
||||
f8e5m2,
|
||||
f8e5m2fnuz,
|
||||
f8e8m0,
|
||||
bf16,
|
||||
f16,
|
||||
f32,
|
||||
f64,
|
||||
i2,
|
||||
i4,
|
||||
i8,
|
||||
i16,
|
||||
i32,
|
||||
i64,
|
||||
u2,
|
||||
u4,
|
||||
u8,
|
||||
u16,
|
||||
@ -50,8 +56,21 @@ pub const DataType = enum(u8) {
|
||||
pub fn class(self: DataType) Class {
|
||||
return switch (self) {
|
||||
.bool => .bool,
|
||||
.f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16, .f16, .f32, .f64 => .float,
|
||||
.i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => .integer,
|
||||
.f4e2m1,
|
||||
.f8e3m4,
|
||||
.f8e4m3,
|
||||
.f8e4m3b11fnuz,
|
||||
.f8e4m3fn,
|
||||
.f8e4m3fnuz,
|
||||
.f8e5m2,
|
||||
.f8e5m2fnuz,
|
||||
.f8e8m0,
|
||||
.bf16,
|
||||
.f16,
|
||||
.f32,
|
||||
.f64,
|
||||
=> .float,
|
||||
.i2, .i4, .i8, .i16, .i32, .i64, .u2, .u4, .u8, .u16, .u32, .u64 => .integer,
|
||||
.c64, .c128 => .complex,
|
||||
};
|
||||
}
|
||||
@ -70,21 +89,27 @@ pub const DataType = enum(u8) {
|
||||
|
||||
pub fn fromZigType(comptime T: type) DataType {
|
||||
return switch (T) {
|
||||
floats.Float4E2M1 => .f4e2m1,
|
||||
floats.Float8E3M4 => .f8e3m4,
|
||||
floats.Float8E4M3 => .f8e4m3,
|
||||
floats.Float8E4M3B11FNUZ => .f8e4m3b11fnuz,
|
||||
floats.Float8E4M3FN => .f8e4m3fn,
|
||||
floats.Float8E4M3FNUZ => .f8e4m3fnuz,
|
||||
floats.Float8E5M2 => .f8e5m2,
|
||||
floats.Float8E5M2FNUZ => .f8e5m2fnuz,
|
||||
floats.Float8E8M0 => .f8e8m0,
|
||||
floats.BFloat16 => .bf16,
|
||||
f16 => .f16,
|
||||
f32 => .f32,
|
||||
f64 => .f64,
|
||||
bool => .bool,
|
||||
i2 => .i2,
|
||||
i4 => .i4,
|
||||
i8 => .i8,
|
||||
i16 => .i16,
|
||||
i32 => .i32,
|
||||
i64 => .i64,
|
||||
u2 => .u2,
|
||||
u4 => .u4,
|
||||
u8 => .u8,
|
||||
u16 => .u16,
|
||||
@ -192,10 +217,10 @@ pub const DataType = enum(u8) {
|
||||
pub fn maxValue(dtype: DataType) Data {
|
||||
return switch (dtype) {
|
||||
.bool => .{ .bool = true },
|
||||
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)),
|
||||
inline .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).inf),
|
||||
inline .f4e2m1, .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz, .f8e8m0 => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)),
|
||||
inline .f8e3m4, .f8e4m3, .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).inf),
|
||||
inline .f16, .f32, .f64 => |tag| @unionInit(Data, @tagName(tag), std.math.inf(@FieldType(Data, @tagName(tag)))),
|
||||
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| @unionInit(Data, @tagName(tag), std.math.maxInt(@FieldType(Data, @tagName(tag)))),
|
||||
inline .i2, .i4, .i8, .i16, .i32, .i64, .u2, .u4, .u8, .u16, .u32, .u64 => |tag| @unionInit(Data, @tagName(tag), std.math.maxInt(@FieldType(Data, @tagName(tag)))),
|
||||
inline .c64, .c128 => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)),
|
||||
};
|
||||
}
|
||||
@ -207,20 +232,26 @@ pub const DataType = enum(u8) {
|
||||
|
||||
pub const Data = union(DataType) {
|
||||
bool: bool,
|
||||
f4e2m1: floats.Float4E2M1,
|
||||
f8e3m4: floats.Float8E3M4,
|
||||
f8e4m3: floats.Float8E4M3,
|
||||
f8e4m3b11fnuz: floats.Float8E4M3B11FNUZ,
|
||||
f8e4m3fn: floats.Float8E4M3FN,
|
||||
f8e4m3fnuz: floats.Float8E4M3FNUZ,
|
||||
f8e5m2: floats.Float8E5M2,
|
||||
f8e5m2fnuz: floats.Float8E5M2FNUZ,
|
||||
f8e8m0: floats.Float8E8M0,
|
||||
bf16: floats.BFloat16,
|
||||
f16: f16,
|
||||
f32: f32,
|
||||
f64: f64,
|
||||
i2: i2,
|
||||
i4: i4,
|
||||
i8: i8,
|
||||
i16: i16,
|
||||
i32: i32,
|
||||
i64: i64,
|
||||
u2: u2,
|
||||
u4: u4,
|
||||
u8: u8,
|
||||
u16: u16,
|
||||
@ -244,7 +275,7 @@ pub const Data = union(DataType) {
|
||||
.comptime_int, .int, .comptime_float, .float => .{ .bool = value != 0 },
|
||||
else => @panic("Could not create Data of type bool from value of type " ++ @typeName(T)),
|
||||
},
|
||||
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16 => |tag| switch (Ti) {
|
||||
inline .f4e2m1, .f8e3m4, .f8e4m3, .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .f8e8m0, .bf16 => |tag| switch (Ti) {
|
||||
.comptime_int, .int => @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).fromF32(@floatFromInt(value))),
|
||||
.comptime_float, .float => @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).fromF32(@floatCast(value))),
|
||||
else => @panic("Could not create Data of type bf16 from value of type " ++ @typeName(T)),
|
||||
@ -254,7 +285,7 @@ pub const Data = union(DataType) {
|
||||
.comptime_float, .float => @unionInit(Data, @tagName(tag), @floatCast(value)),
|
||||
else => @panic("Could not create Data of type " ++ @tagName(tag) ++ " from value of type " ++ @typeName(T)),
|
||||
},
|
||||
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| switch (Ti) {
|
||||
inline .i2, .i4, .i8, .i16, .i32, .i64, .u2, .u4, .u8, .u16, .u32, .u64 => |tag| switch (Ti) {
|
||||
.comptime_int => blk: {
|
||||
const OutT = @FieldType(Data, @tagName(tag));
|
||||
if (value >= std.math.minInt(OutT) and value <= std.math.maxInt(OutT)) {
|
||||
|
||||
312
zml/floats.zig
312
zml/floats.zig
@ -5,10 +5,6 @@ test {
|
||||
std.testing.refAllDecls(@This());
|
||||
}
|
||||
|
||||
fn allBitsOne(v: anytype) bool {
|
||||
return v == std.math.maxInt(@TypeOf(v));
|
||||
}
|
||||
|
||||
fn FloatHelpers(Float: type) type {
|
||||
const info = @typeInfo(Float);
|
||||
const err_msg = "FloatHelpers expect a packed struct { mantissa: uXX, exponent: uXX, sign: u1}";
|
||||
@ -23,9 +19,11 @@ fn FloatHelpers(Float: type) type {
|
||||
}
|
||||
|
||||
return struct {
|
||||
const sign_bits: u8 = @typeInfo(@FieldType(Float, "sign")).int.bits;
|
||||
const mantissa_bits: u8 = @typeInfo(@FieldType(Float, "mantissa")).int.bits;
|
||||
const exponent_bits: u8 = @typeInfo(@FieldType(Float, "exponent")).int.bits;
|
||||
const f32_mantissa_bits: u8 = @typeInfo(@FieldType(Float32, "mantissa")).int.bits;
|
||||
const exp_bias: i16 = std.math.maxInt(std.meta.Int(.unsigned, exponent_bits - 1));
|
||||
const exp_off: u8 = FloatHelpers(Float32).exp_bias - exp_bias;
|
||||
|
||||
pub const zero: Float = .{ .sign = 0, .exponent = 0, .mantissa = 0 };
|
||||
|
||||
@ -39,52 +37,74 @@ fn FloatHelpers(Float: type) type {
|
||||
|
||||
/// Lossy conversion from f32, similar to @floatCast
|
||||
pub fn fromF32(f: f32) Float {
|
||||
@setRuntimeSafety(false);
|
||||
|
||||
const vf32: Float32 = @bitCast(f);
|
||||
const exp_bias = comptime expBias();
|
||||
const exponent = @as(u16, vf32.exponent) + exp_bias -| FloatHelpers(Float32).expBias();
|
||||
const exponent: i16 = @as(i16, vf32.exponent) - exp_off;
|
||||
const overflow = exponent > std.math.maxInt(@FieldType(Float, "exponent"));
|
||||
if (overflow) {
|
||||
return if (@hasDecl(Float, "inf")) {
|
||||
return if (vf32.sign == 0) Float.inf else Float.minus_inf;
|
||||
} else Float.nan;
|
||||
@branchHint(.unlikely);
|
||||
return if (@hasDecl(Float, "inf"))
|
||||
if (vf32.sign == 0) Float.inf else Float.minus_inf
|
||||
else
|
||||
Float.nan;
|
||||
}
|
||||
return .{
|
||||
.sign = vf32.sign,
|
||||
.exponent = @intCast(exponent),
|
||||
.mantissa = truncMantissa(vf32.mantissa),
|
||||
};
|
||||
|
||||
return if (exponent <= 0)
|
||||
.{
|
||||
.sign = vf32.sign,
|
||||
.exponent = 0,
|
||||
.mantissa = shiftMantissa(vf32.mantissa, @intCast(-exponent)),
|
||||
}
|
||||
else
|
||||
.{
|
||||
.sign = vf32.sign,
|
||||
.exponent = @intCast(exponent),
|
||||
.mantissa = truncMantissa(vf32.mantissa),
|
||||
};
|
||||
}
|
||||
|
||||
/// Lossless conversion to f32.
|
||||
pub fn toF32(x: Float) f32 {
|
||||
var vf32: Float32 = undefined;
|
||||
if (@hasDecl(Float, "isInf") and x.isInf()) {
|
||||
@setRuntimeSafety(false);
|
||||
|
||||
if (x == zero) return 0.0;
|
||||
if (isInf(x)) {
|
||||
@branchHint(.unlikely);
|
||||
return if (x.sign == 0) std.math.inf(f32) else -std.math.inf(f32);
|
||||
}
|
||||
vf32 = .{
|
||||
.sign = x.sign,
|
||||
.exponent = if (x.exponent == 0) 0 else @intCast(@as(i16, x.exponent) + f32_exp_bias - expBias()),
|
||||
.mantissa = f32Mantissa(x),
|
||||
};
|
||||
|
||||
const vf32: Float32 = if (x.exponent > 0)
|
||||
.{
|
||||
.sign = x.sign,
|
||||
.exponent = @as(u8, x.exponent) + exp_off,
|
||||
.mantissa = f32Mantissa(x),
|
||||
}
|
||||
else
|
||||
.{
|
||||
.sign = x.sign,
|
||||
.exponent = exp_off - @clz(x.mantissa),
|
||||
.mantissa = @as(u23, x.mantissa) << @clz(x.mantissa),
|
||||
};
|
||||
return @bitCast(vf32);
|
||||
}
|
||||
|
||||
fn truncMantissa(x: anytype) @FieldType(Float, "mantissa") {
|
||||
@setRuntimeSafety(false);
|
||||
const off = @bitSizeOf(@TypeOf(x)) - mantissa_bits;
|
||||
return @intCast(x >> off);
|
||||
fn truncMantissa(f32_mantissa: u32) @FieldType(Float, "mantissa") {
|
||||
const rounding_val: u32 = @as(u32, 1) << (f32_mantissa_bits - mantissa_bits - 1);
|
||||
return @truncate((f32_mantissa + rounding_val) >> (f32_mantissa_bits - mantissa_bits));
|
||||
}
|
||||
|
||||
fn shiftMantissa(f32_mantissa: u32, underflow: u8) @FieldType(Float, "mantissa") {
|
||||
const upper_bit: u32 = @as(u32, 1) << f32_mantissa_bits;
|
||||
const full_mant32: u32 = f32_mantissa | upper_bit;
|
||||
// divide the mantissa proportionally to the exponent underflow
|
||||
const shifted_mant: u32 = full_mant32 >> @truncate(underflow + 1);
|
||||
return truncMantissa(shifted_mant);
|
||||
}
|
||||
|
||||
fn f32Mantissa(x: Float) @FieldType(Float32, "mantissa") {
|
||||
@setRuntimeSafety(false);
|
||||
const Res = @FieldType(Float32, "mantissa");
|
||||
const f32_mantissa_bits = @bitSizeOf(Res);
|
||||
|
||||
return @shlExact(@as(Res, x.mantissa), f32_mantissa_bits - mantissa_bits);
|
||||
}
|
||||
|
||||
fn expBias() u8 {
|
||||
return std.math.maxInt(std.meta.Int(.unsigned, exponent_bits - 1));
|
||||
const T = @FieldType(Float32, "mantissa");
|
||||
return @as(T, x.mantissa) << f32_mantissa_bits - mantissa_bits;
|
||||
}
|
||||
|
||||
pub fn formatNumber(x: Float, writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
|
||||
@ -101,6 +121,9 @@ pub const Float32 = packed struct(u32) {
|
||||
exponent: u8,
|
||||
sign: u1,
|
||||
|
||||
pub const inf: Float32 = .{ .sign = 0, .exponent = std.math.maxInt(u8), .mantissa = 0 };
|
||||
pub const minus_inf = neg(inf);
|
||||
|
||||
const Helpers = FloatHelpers(@This());
|
||||
pub const zero = Helpers.zero;
|
||||
pub const neg = Helpers.neg;
|
||||
@ -129,11 +152,7 @@ pub const Float8E4M3B11FNUZ = packed struct(u8) {
|
||||
exponent: u4,
|
||||
sign: u1,
|
||||
|
||||
pub const nan: Float8E4M3B11FNUZ = .{
|
||||
.sign = 1,
|
||||
.exponent = 0,
|
||||
.mantissa = 0,
|
||||
};
|
||||
pub const nan: Float8E4M3B11FNUZ = .{ .sign = 1, .exponent = 0, .mantissa = 0 };
|
||||
|
||||
pub fn isNan(self: Float8E4M3B11FNUZ) bool {
|
||||
return self.sign == 1 and self.exponent == 0 and self.mantissa == 0;
|
||||
@ -155,7 +174,7 @@ pub const Float8E4M3FN = packed struct(u8) {
|
||||
pub const nan: Float8E4M3FN = .{ .sign = 0, .exponent = std.math.maxInt(u4), .mantissa = std.math.maxInt(u3) };
|
||||
|
||||
pub fn isNan(self: Float8E4M3FN) bool {
|
||||
return allBitsOne(self.exponent) and allBitsOne(self.mantissa);
|
||||
return self.exponent == nan.exponent and self.mantissa == nan.mantissa;
|
||||
}
|
||||
const Helpers = FloatHelpers(@This());
|
||||
pub const zero = Helpers.zero;
|
||||
@ -170,11 +189,7 @@ pub const Float8E4M3FNUZ = packed struct(u8) {
|
||||
exponent: u4,
|
||||
sign: u1,
|
||||
|
||||
pub const nan: Float8E4M3FNUZ = .{
|
||||
.sign = 1,
|
||||
.exponent = 0,
|
||||
.mantissa = 0,
|
||||
};
|
||||
pub const nan: Float8E4M3FNUZ = .{ .sign = 1, .exponent = 0, .mantissa = 0 };
|
||||
|
||||
pub fn isNan(self: Float8E4M3FNUZ) bool {
|
||||
return self.sign == 1 and self.exponent == 0 and self.mantissa == 0;
|
||||
@ -189,9 +204,10 @@ pub const Float8E4M3FNUZ = packed struct(u8) {
|
||||
};
|
||||
|
||||
test "Float8E4" {
|
||||
// With 4 bits of exponents power of two can be represented exactly up to 64.
|
||||
const test_case_e4: TestCase = .{
|
||||
.lossless = &[_]f32{ 0, 1.0, -2, 1.0 / 64.0, -128 },
|
||||
.lossy = &[_]f32{3.02344107628},
|
||||
.lossless = &[_]f32{ 0, 1.0, -2, 1.0 / 64.0, -128, -1.125 / 64.0 },
|
||||
.lossy = &[_]f32{ 3.02344107628, 1.0 / 128.0, 1.0 / 512.0 },
|
||||
};
|
||||
|
||||
inline for (.{
|
||||
@ -200,7 +216,11 @@ test "Float8E4" {
|
||||
Float8E4M3FNUZ,
|
||||
}) |Float8T| {
|
||||
try testCustomFloat(Float8T, test_case_e4);
|
||||
try std.testing.expectEqual(0.0, Float8T.fromF32(1.0 / 128.0).toF32());
|
||||
try std.testing.expectEqual(0.0, Float8T.fromF32(1.0 / 2048.0).toF32());
|
||||
if (@hasDecl(Float8T, "inf")) {
|
||||
try std.testing.expectEqual(Float8T.inf, Float8T.fromF32(128.0));
|
||||
try std.testing.expectEqual(Float8T.inf.neg(), Float8T.fromF32(-128.0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -209,31 +229,19 @@ pub const Float8E5M2 = packed struct(u8) {
|
||||
exponent: u5,
|
||||
sign: u1,
|
||||
|
||||
pub const nan: Float8E5M2 = .{
|
||||
.sign = 0,
|
||||
.exponent = std.math.maxInt(u5),
|
||||
.mantissa = 1,
|
||||
};
|
||||
pub const nan: Float8E5M2 = .{ .sign = 0, .exponent = std.math.maxInt(u5), .mantissa = 1 };
|
||||
|
||||
pub fn isNan(self: Float8E5M2) bool {
|
||||
return allBitsOne(self.exponent) and self.mantissa != 0;
|
||||
return self.exponent == nan.exponent and self.mantissa != 0;
|
||||
}
|
||||
|
||||
pub const minus_inf: Float8E5M2 = .{
|
||||
.sign = 1,
|
||||
.exponent = std.math.maxInt(u5),
|
||||
.mantissa = 0,
|
||||
};
|
||||
|
||||
pub const inf: Float8E5M2 = .{
|
||||
.sign = 0,
|
||||
.exponent = std.math.maxInt(u5),
|
||||
.mantissa = 0,
|
||||
};
|
||||
|
||||
pub fn isInf(self: Float8E5M2) bool {
|
||||
return allBitsOne(self.exponent) and self.mantissa == 0;
|
||||
}
|
||||
pub const minus_inf: Float8E5M2 = .neg(inf);
|
||||
|
||||
const Helpers = FloatHelpers(@This());
|
||||
pub const zero = Helpers.zero;
|
||||
@ -283,22 +291,16 @@ pub const BFloat16 = packed struct(u16) {
|
||||
return allBitsOne(self.exponent) and self.mantissa != 0;
|
||||
}
|
||||
|
||||
pub const minus_inf: BFloat16 = .{
|
||||
.sign = 1,
|
||||
.exponent = std.math.maxInt(u8),
|
||||
.mantissa = 0,
|
||||
};
|
||||
|
||||
pub const inf: BFloat16 = .{
|
||||
.sign = 0,
|
||||
.exponent = std.math.maxInt(u8),
|
||||
.mantissa = 0,
|
||||
};
|
||||
|
||||
pub fn isInf(self: BFloat16) bool {
|
||||
return allBitsOne(self.exponent) and self.mantissa == 0;
|
||||
}
|
||||
pub const minus_inf: BFloat16 = .neg(inf);
|
||||
|
||||
// Specialized versions of to/from F32. Since BFloat16 has the same exponent range than F32,
|
||||
// no overflow/underflow can happen, simplifiying conversion logic.
|
||||
pub fn toF32(self: BFloat16) f32 {
|
||||
// Pad the BF16 with zeros 0
|
||||
return @bitCast([2]u16{ 0, @bitCast(self) });
|
||||
@ -333,11 +335,167 @@ test BFloat16 {
|
||||
});
|
||||
}
|
||||
|
||||
pub fn floatCast(T: type, x: anytype) T {
|
||||
return switch (@TypeOf(x)) {
|
||||
f64, f32, f16 => @floatCast(x),
|
||||
else => @floatCast(x.toF32()),
|
||||
pub const Float8E4M3 = packed struct(u8) {
|
||||
mantissa: u3,
|
||||
exponent: u4,
|
||||
sign: u1,
|
||||
|
||||
pub const nan: Float8E4M3 = @bitCast(0xFF);
|
||||
|
||||
pub fn isNan(self: Float8E4M3) bool {
|
||||
return self == nan or self == comptime nan.neg();
|
||||
}
|
||||
|
||||
pub const inf: Float8E4M3 = .{
|
||||
.sign = 0,
|
||||
.exponent = std.math.maxInt(u4),
|
||||
.mantissa = 0,
|
||||
};
|
||||
|
||||
pub const minus_inf = neg(inf);
|
||||
|
||||
const Helpers = FloatHelpers(@This());
|
||||
pub const zero = Helpers.zero;
|
||||
pub const neg = Helpers.neg;
|
||||
pub const fromF32 = Helpers.fromF32;
|
||||
pub const toF32 = Helpers.toF32;
|
||||
pub const formatNumber = Helpers.formatNumber;
|
||||
};
|
||||
|
||||
pub const Float8E3M4 = packed struct(u8) {
|
||||
mantissa: u4,
|
||||
exponent: u3,
|
||||
sign: u1,
|
||||
|
||||
pub const nan: Float8E3M4 = @bitCast(0xFF);
|
||||
|
||||
pub fn isNan(self: Float8E3M4) bool {
|
||||
return self == nan or self == comptime nan.neg();
|
||||
}
|
||||
|
||||
pub const inf: Float8E3M4 = .{
|
||||
.sign = 0,
|
||||
.exponent = std.math.maxInt(u3),
|
||||
.mantissa = 0,
|
||||
};
|
||||
|
||||
pub const minus_inf = neg(inf);
|
||||
|
||||
const Helpers = FloatHelpers(@This());
|
||||
pub const zero = Helpers.zero;
|
||||
pub const neg = Helpers.neg;
|
||||
pub const fromF32 = Helpers.fromF32;
|
||||
pub const toF32 = Helpers.toF32;
|
||||
pub const formatNumber = Helpers.formatNumber;
|
||||
};
|
||||
|
||||
pub const Float8E8M0 = packed struct(u8) {
|
||||
mantissa: u0 = 0,
|
||||
exponent: u8,
|
||||
sign: u0 = 0,
|
||||
|
||||
pub const min_scale: f32 = @bitCast(Float32{ .sign = 0, .exponent = 0, .mantissa = 0b1 << 22 });
|
||||
|
||||
/// Lossy conversion from f32, similar to @floatCast
|
||||
pub fn fromF32(f: f32) Float8E8M0 {
|
||||
const vf32: Float32 = @bitCast(f);
|
||||
return .{ .exponent = @intCast(vf32.exponent) };
|
||||
}
|
||||
|
||||
/// Lossless conversion to f32.
|
||||
pub fn toF32(x: Float8E8M0) f32 {
|
||||
if (x.exponent == 0) return min_scale;
|
||||
const vf32: Float32 = .{
|
||||
.sign = 0,
|
||||
.exponent = x.exponent,
|
||||
.mantissa = 0,
|
||||
};
|
||||
return @bitCast(vf32);
|
||||
}
|
||||
|
||||
const Helpers = FloatHelpers(@This());
|
||||
pub const formatNumber = Helpers.formatNumber;
|
||||
};
|
||||
|
||||
test Float8E8M0 {
|
||||
try std.testing.expectEqual(Float8E8M0{ .exponent = 127 }, Float8E8M0.fromF32(1.0));
|
||||
// try std.testing.expectEqual(5.877472e-39, Float8E8M0.toF32(.{ .exponent = 0}));
|
||||
|
||||
try testCustomFloat(Float8E8M0, .{
|
||||
.lossless = &[_]f32{ Float8E8M0.min_scale, 1.0, 64.0, 1.0 / 128.0, std.math.pow(f32, 2.0, 127) },
|
||||
.lossy = &[_]f32{1.00001},
|
||||
});
|
||||
}
|
||||
|
||||
pub const Float4E2M1 = packed struct(u4) {
|
||||
mantissa: u1,
|
||||
exponent: u2,
|
||||
sign: u1,
|
||||
|
||||
pub const nan: Float4E2M1 = @bitCast(@as(u4, 0xF));
|
||||
const Helpers = FloatHelpers(@This());
|
||||
pub const zero = Helpers.zero;
|
||||
pub const neg = Helpers.neg;
|
||||
pub const fromF32 = Helpers.fromF32;
|
||||
pub const formatNumber = Helpers.formatNumber;
|
||||
|
||||
pub const values = [_]f32{ 0.0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6 };
|
||||
|
||||
pub fn toF32(x: Float4E2M1) f32 {
|
||||
// the baseline toF32 doesn't work correctly:
|
||||
// 0b0001 and 0b1001 shoud map to ±0.5, but are mapped to ±epsilon
|
||||
return values[@as(u4, @bitCast(x))];
|
||||
}
|
||||
|
||||
test toF32 {
|
||||
var to_f32_res: [16]f32 = undefined;
|
||||
for (&to_f32_res, 0..) |*r, i| {
|
||||
const x_f4: Float4E2M1 = @bitCast(@as(u4, @intCast(i)));
|
||||
r.* = x_f4.toF32();
|
||||
}
|
||||
try std.testing.expectEqualSlices(f32, &Float4E2M1.values, &to_f32_res);
|
||||
}
|
||||
|
||||
test fromF32 {
|
||||
// the baseline fromF32 doesn't work correctly:
|
||||
// ±0.5 should map to 0b0001/0b1001 but are map to ±0.0 instead.
|
||||
// TODO: it probably affects other types.
|
||||
var from_f32_res: [16]Float4E2M1 = undefined;
|
||||
for (&from_f32_res, 0..) |*r, i| {
|
||||
r.* = .fromF32(Float4E2M1.values[i]);
|
||||
}
|
||||
try std.testing.expectEqualSlices(u4, &.{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 }, @ptrCast(&from_f32_res));
|
||||
}
|
||||
};
|
||||
|
||||
pub fn floatCast(T: type, x: anytype) T {
|
||||
return switch (T) {
|
||||
f64, f32, f16 => switch (@TypeOf(x)) {
|
||||
f64, f32, f16 => @floatCast(x),
|
||||
else => @floatCast(x.toF32()),
|
||||
},
|
||||
else => switch (@TypeOf(x)) {
|
||||
f64, f32, f16 => .fromF32(x),
|
||||
else => .fromF32(x.toF32()),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
pub fn isInf(x: anytype) bool {
|
||||
const Float = @TypeOf(x);
|
||||
switch (Float) {
|
||||
f64, f32, f16 => return std.math.isInf(x),
|
||||
else => {},
|
||||
}
|
||||
if (!@hasDecl(Float, "inf")) return false;
|
||||
|
||||
const FBits = std.meta.Int(.unsigned, @bitSizeOf(Float));
|
||||
const remove_sign = ~@as(FBits, 0) >> 1;
|
||||
return @as(FBits, @bitCast(x)) & remove_sign == @as(FBits, @bitCast(Float.inf));
|
||||
}
|
||||
|
||||
fn allBitsOne(v: anytype) bool {
|
||||
return v == std.math.maxInt(@TypeOf(v));
|
||||
}
|
||||
|
||||
const TestCase = struct {
|
||||
|
||||
@ -35,20 +35,26 @@ pub const Type = struct {
|
||||
pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type {
|
||||
return switch (dt) {
|
||||
.bool => .int(ctx, .i1),
|
||||
.f4e2m1 => .float(ctx, .f4e2m1fn),
|
||||
.f8e3m4 => .float(ctx, .f8e3m4),
|
||||
.f8e4m3 => .float(ctx, .f8e4m3),
|
||||
.f8e4m3b11fnuz => .float(ctx, .f8e4m3b11fnuz),
|
||||
.f8e4m3fn => .float(ctx, .f8e4m3fn),
|
||||
.f8e4m3fnuz => .float(ctx, .f8e4m3fnuz),
|
||||
.f8e5m2 => .float(ctx, .f8e5m2),
|
||||
.f8e5m2fnuz => .float(ctx, .f8e5m2fnuz),
|
||||
.f8e8m0 => .float(ctx, .f8e8m0fnu),
|
||||
.bf16 => .float(ctx, .bf16),
|
||||
.f16 => .float(ctx, .f16),
|
||||
.f32 => .float(ctx, .f32),
|
||||
.f64 => .float(ctx, .f64),
|
||||
.i2 => .int(ctx, .i2),
|
||||
.i4 => .int(ctx, .i4),
|
||||
.i8 => .int(ctx, .i8),
|
||||
.i16 => .int(ctx, .i16),
|
||||
.i32 => .int(ctx, .i32),
|
||||
.i64 => .int(ctx, .i64),
|
||||
.u2 => .int(ctx, .u2),
|
||||
.u4 => .int(ctx, .u4),
|
||||
.u8 => .int(ctx, .u8),
|
||||
.u16 => .int(ctx, .u16),
|
||||
@ -62,23 +68,28 @@ pub const Type = struct {
|
||||
pub fn toDType(mlir_type: mlir.Type) dtype.DataType {
|
||||
const mapping = .{
|
||||
.{ .bool, mlir.IntegerType(.i1) },
|
||||
|
||||
.{ .f4e2m1, mlir.FloatType(.f4e2m1fn) },
|
||||
.{ .f8e3m4, mlir.FloatType(.f8e3m4) },
|
||||
.{ .f8e4m3, mlir.FloatType(.f8e4m3) },
|
||||
.{ .f8e4m3b11fnuz, mlir.FloatType(.f8e4m3b11fnuz) },
|
||||
.{ .f8e4m3fn, mlir.FloatType(.f8e4m3fn) },
|
||||
.{ .f8e4m3fnuz, mlir.FloatType(.f8e4m3fnuz) },
|
||||
.{ .f8e5m2, mlir.FloatType(.f8e5m2) },
|
||||
.{ .f8e5m2fnuz, mlir.FloatType(.f8e5m2fnuz) },
|
||||
.{ .f8e8m0, mlir.FloatType(.f8e8m0fnu) },
|
||||
.{ .bf16, mlir.FloatType(.bf16) },
|
||||
.{ .f16, mlir.FloatType(.f16) },
|
||||
.{ .f32, mlir.FloatType(.f32) },
|
||||
.{ .f64, mlir.FloatType(.f64) },
|
||||
|
||||
.{ .i2, mlir.IntegerType(.i2) },
|
||||
.{ .i4, mlir.IntegerType(.i4) },
|
||||
.{ .i8, mlir.IntegerType(.i8) },
|
||||
.{ .i16, mlir.IntegerType(.i16) },
|
||||
.{ .i32, mlir.IntegerType(.i32) },
|
||||
.{ .i64, mlir.IntegerType(.i64) },
|
||||
|
||||
.{ .u2, mlir.IntegerType(.u2) },
|
||||
.{ .u4, mlir.IntegerType(.u4) },
|
||||
.{ .u8, mlir.IntegerType(.u8) },
|
||||
.{ .u16, mlir.IntegerType(.u16) },
|
||||
@ -89,6 +100,8 @@ pub const Type = struct {
|
||||
.{ .c128, mlir.ComplexType(.c128) },
|
||||
};
|
||||
|
||||
// TODO: this seems quite slow to have all of those functions calls.
|
||||
// Maybe we should memoize the ptr of a set of mlir types when creating the context.
|
||||
inline for (mapping) |entry| {
|
||||
const dt, const mlirT = entry;
|
||||
if (mlirT.is_a_fn(mlir_type._inner)) {
|
||||
|
||||
@ -1042,6 +1042,40 @@ pub const Tensor = struct {
|
||||
return _result(self._shape.withDtype(to), op.result(0));
|
||||
}
|
||||
|
||||
test convert {
|
||||
const floats = @import("floats.zig");
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
// f4e2m1
|
||||
{
|
||||
const x = [_]f32{ 0.0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6 };
|
||||
var x_f4: [x.len]floats.Float4E2M1 = undefined;
|
||||
for (&x_f4, &x) |*xi_f4, xi| xi_f4.* = .fromF32(xi);
|
||||
|
||||
const x_d = try zml.Buffer.fromArray(platform, x);
|
||||
const x_f4_xla_d = try zml.testing.compileAndCall(platform, Tensor.convert, .{ x_d, .f4e2m1 });
|
||||
|
||||
const x_f4_xla = x_f4_xla_d.getValue(@TypeOf(x_f4));
|
||||
errdefer std.log.warn("convert(.f4e2m1) failed !\ninput f32:\n{e}\nzml.floats computed:\n{any}\nxla computed:\n{any}", .{ stdx.fmt.slice(&x), x_f4, x_f4_xla });
|
||||
try std.testing.expectEqualDeep(x_f4, x_f4_xla);
|
||||
}
|
||||
|
||||
// f8e3m4
|
||||
{
|
||||
const x = [_]f32{ 1.1 / 4.0, 1.1 / 8.0, 1.1 / 16.0, 1.1 / 32.0, 1.1 / 64.0, 1.1 / 128.0 };
|
||||
var x_f8e3: [x.len]floats.Float8E3M4 = undefined;
|
||||
for (&x_f8e3, &x) |*xi_f8e3, xi| xi_f8e3.* = .fromF32(xi);
|
||||
|
||||
const x_d = try zml.Buffer.fromArray(platform, x);
|
||||
const x_f8e3_xla_d = try zml.testing.compileAndCall(platform, Tensor.convert, .{ x_d, .f8e3m4 });
|
||||
|
||||
const x_f8e3_xla = x_f8e3_xla_d.getValue(@TypeOf(x_f8e3));
|
||||
errdefer std.log.warn("convert(.f8e3m4) failed !\ninput f32:\n{e}\nzml.floats computed:\n{any}\nxla computed:\n{any}", .{ stdx.fmt.slice(&x), x_f8e3, x_f8e3_xla });
|
||||
try std.testing.expectEqualDeep(x_f8e3, x_f8e3_xla);
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing the element-wise rounding operation of the input Tensor.
|
||||
pub fn round(self: Tensor) Tensor {
|
||||
const loc = self.getContext().mlirCtx().location(@src());
|
||||
@ -3844,11 +3878,11 @@ fn getPoolResDims(dt: DataType, in_dims: []const i64, base_dilations: @Vector(Te
|
||||
}
|
||||
|
||||
fn getComparisonType(ctx: mlir.Context, dtype: DataType) dialect.stablehlo.CompareType {
|
||||
return dialect.stablehlo.CompareType.init(ctx, switch (dtype) {
|
||||
.i4, .i8, .i16, .i32, .i64 => .SIGNED,
|
||||
.bool, .u4, .u8, .u16, .u32, .u64 => .UNSIGNED,
|
||||
.f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16, .f16, .f32, .f64 => .FLOAT,
|
||||
.c64, .c128 => @panic("Can't compare complex numbers"),
|
||||
return dialect.stablehlo.CompareType.init(ctx, switch (dtype.class()) {
|
||||
.bool => .UNSIGNED,
|
||||
.integer => if (dtype.isSignedInt()) .SIGNED else .UNSIGNED,
|
||||
.float => .FLOAT,
|
||||
.complex => @panic("Can't compare complex numbers"),
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@ -65,11 +65,15 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
|
||||
.f16,
|
||||
.f32,
|
||||
.f64,
|
||||
.f4e2m1,
|
||||
.f8e3m4,
|
||||
.f8e4m3,
|
||||
.f8e4m3b11fnuz,
|
||||
.f8e4m3fn,
|
||||
.f8e4m3fnuz,
|
||||
.f8e5m2,
|
||||
.f8e5m2fnuz,
|
||||
.f8e8m0,
|
||||
=> |t| {
|
||||
const L = t.toZigType();
|
||||
const left_data = left.items(L);
|
||||
@ -96,7 +100,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
|
||||
else => unreachable,
|
||||
}
|
||||
},
|
||||
inline .bool, .u4, .u8, .u16, .u32, .u64, .i4, .i8, .i16, .i32, .i64 => |t| {
|
||||
inline .bool, .u2, .u4, .u8, .u16, .u32, .u64, .i2, .i4, .i8, .i16, .i32, .i64 => |t| {
|
||||
const T = t.toZigType();
|
||||
return std.testing.expectEqualSlices(T, left.items(T), right.items(T));
|
||||
},
|
||||
|
||||
Loading…
Reference in New Issue
Block a user