diff --git a/mlir/mlir.zig b/mlir/mlir.zig index 15dbdab..a3b9114 100755 --- a/mlir/mlir.zig +++ b/mlir/mlir.zig @@ -414,9 +414,9 @@ pub const ArrayAttribute = struct { pub fn IntegerAttribute(comptime it: IntegerTypes) type { const ZigType, const getter = comptime switch (it) { - .i1, .i4, .i8, .i16, .i32, .i64 => .{ i64, c.mlirIntegerAttrGetValueInt }, + .i1, .i2, .i4, .i8, .i16, .i32, .i64 => .{ i64, c.mlirIntegerAttrGetValueInt }, .si4, .si8, .si16, .si32, .si64 => .{ i64, c.mlirIntegerAttrGetValueSInt }, - .u4, .u8, .u16, .u32, .u64 => .{ u64, c.mlirIntegerAttrGetValueUInt }, + .u2, .u4, .u8, .u16, .u32, .u64 => .{ u64, c.mlirIntegerAttrGetValueUInt }, .unknown => @compileError("IntegerAttribute(unknown)"), }; @@ -1249,6 +1249,7 @@ pub const IndexType = struct { pub const IntegerTypes = enum { i1, + i2, i4, i8, i16, @@ -1259,6 +1260,7 @@ pub const IntegerTypes = enum { si16, si32, si64, + u2, u4, u8, u16, @@ -1271,6 +1273,7 @@ pub const IntegerTypes = enum { pub fn IntegerType(comptime it: IntegerTypes) type { const Config = switch (it) { .i1 => .{ 1, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless }, + .i2 => .{ 2, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless }, .i4 => .{ 4, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless }, .i8 => .{ 8, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless }, .i16 => .{ 16, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless }, @@ -1281,6 +1284,7 @@ pub fn IntegerType(comptime it: IntegerTypes) type { .si16 => .{ 16, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned }, .si32 => .{ 32, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned }, .si64 => .{ 64, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned }, + .u2 => .{ 2, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned }, .u4 => .{ 4, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned }, .u8 => .{ 8, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned }, .u16 => .{ 16, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned }, @@ -1320,11 +1324,15 @@ pub fn IntegerType(comptime it: IntegerTypes) type { } pub const FloatTypes = enum { + f4e2m1fn, + f8e3m4, + f8e4m3, f8e4m3b11fnuz, f8e4m3fn, f8e4m3fnuz, f8e5m2, f8e5m2fnuz, + f8e8m0fnu, bf16, f16, f32, @@ -1338,12 +1346,19 @@ pub const FloatTypes = enum { }; pub fn FloatType(comptime ft: FloatTypes) type { - const Config = switch (ft) { + const Config: struct { + *const fn (c.MlirType) callconv(.c) bool, + *const fn (c.MlirContext) callconv(.c) c.MlirType, + } = switch (ft) { + .f4e2m1fn => .{ c.mlirTypeIsAFloat4E2M1FN, c.mlirFloat4E2M1FNTypeGet }, + .f8e3m4 => .{ c.mlirTypeIsAFloat8E3M4, c.mlirFloat8E3M4TypeGet }, + .f8e4m3 => .{ c.mlirTypeIsAFloat8E4M3, c.mlirFloat8E4M3TypeGet }, .f8e4m3b11fnuz => .{ c.mlirTypeIsAFloat8E4M3B11FNUZ, c.mlirFloat8E4M3B11FNUZTypeGet }, .f8e4m3fn => .{ c.mlirTypeIsAFloat8E4M3FN, c.mlirFloat8E4M3FNTypeGet }, .f8e4m3fnuz => .{ c.mlirTypeIsAFloat8E4M3FNUZ, c.mlirFloat8E4M3FNUZTypeGet }, .f8e5m2 => .{ c.mlirTypeIsAFloat8E5M2, c.mlirFloat8E5M2TypeGet }, .f8e5m2fnuz => .{ c.mlirTypeIsAFloat8E5M2FNUZ, c.mlirFloat8E5M2FNUZTypeGet }, + .f8e8m0fnu => .{ c.mlirTypeIsAFloat8E8M0FNU, c.mlirFloat8E8M0FNUTypeGet }, .bf16 => .{ c.mlirTypeIsABF16, c.mlirBF16TypeGet }, .f16 => .{ c.mlirTypeIsAF16, c.mlirF16TypeGet }, .f32 => .{ c.mlirTypeIsAF32, c.mlirF32TypeGet }, diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index a3ef83d..cad1ce6 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -800,11 +800,13 @@ pub const LoadedExecutable = opaque { pub const BufferType = enum(c.PJRT_Buffer_Type) { invalid = c.PJRT_Buffer_Type_INVALID, bool = c.PJRT_Buffer_Type_PRED, + i2 = c.PJRT_Buffer_Type_S2, i4 = c.PJRT_Buffer_Type_S4, i8 = c.PJRT_Buffer_Type_S8, i16 = c.PJRT_Buffer_Type_S16, i32 = c.PJRT_Buffer_Type_S32, i64 = c.PJRT_Buffer_Type_S64, + u2 = c.PJRT_Buffer_Type_U2, u4 = c.PJRT_Buffer_Type_U4, u8 = c.PJRT_Buffer_Type_U8, u16 = c.PJRT_Buffer_Type_U16, @@ -821,6 +823,10 @@ pub const BufferType = enum(c.PJRT_Buffer_Type) { f8e4m3b11fnuz = c.PJRT_Buffer_Type_F8E4M3B11FNUZ, f8e5m2fnuz = c.PJRT_Buffer_Type_F8E5M2FNUZ, f8e4m3fnuz = c.PJRT_Buffer_Type_F8E4M3FNUZ, + f8e4m3 = c.PJRT_Buffer_Type_F8E4M3, + f8e3m4 = c.PJRT_Buffer_Type_F8E3M4, + f8e8m0 = c.PJRT_Buffer_Type_F8E8M0FNU, + f4e2m1 = c.PJRT_Buffer_Type_F4E2M1FN, }; pub const MemoryLayoutType = enum(c.PJRT_Buffer_MemoryLayout_Type) { diff --git a/zml/dtype.zig b/zml/dtype.zig index b0deb8f..b14f483 100644 --- a/zml/dtype.zig +++ b/zml/dtype.zig @@ -12,20 +12,26 @@ test { pub const DataType = enum(u8) { bool, // Note: the support of the float8 is a bit spotty, f8e4m3b11fnuz seems to be the most supported one on Cuda. + f4e2m1, + f8e3m4, + f8e4m3, f8e4m3b11fnuz, f8e4m3fn, f8e4m3fnuz, f8e5m2, f8e5m2fnuz, + f8e8m0, bf16, f16, f32, f64, + i2, i4, i8, i16, i32, i64, + u2, u4, u8, u16, @@ -50,8 +56,21 @@ pub const DataType = enum(u8) { pub fn class(self: DataType) Class { return switch (self) { .bool => .bool, - .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16, .f16, .f32, .f64 => .float, - .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => .integer, + .f4e2m1, + .f8e3m4, + .f8e4m3, + .f8e4m3b11fnuz, + .f8e4m3fn, + .f8e4m3fnuz, + .f8e5m2, + .f8e5m2fnuz, + .f8e8m0, + .bf16, + .f16, + .f32, + .f64, + => .float, + .i2, .i4, .i8, .i16, .i32, .i64, .u2, .u4, .u8, .u16, .u32, .u64 => .integer, .c64, .c128 => .complex, }; } @@ -70,21 +89,27 @@ pub const DataType = enum(u8) { pub fn fromZigType(comptime T: type) DataType { return switch (T) { + floats.Float4E2M1 => .f4e2m1, + floats.Float8E3M4 => .f8e3m4, + floats.Float8E4M3 => .f8e4m3, floats.Float8E4M3B11FNUZ => .f8e4m3b11fnuz, floats.Float8E4M3FN => .f8e4m3fn, floats.Float8E4M3FNUZ => .f8e4m3fnuz, floats.Float8E5M2 => .f8e5m2, floats.Float8E5M2FNUZ => .f8e5m2fnuz, + floats.Float8E8M0 => .f8e8m0, floats.BFloat16 => .bf16, f16 => .f16, f32 => .f32, f64 => .f64, bool => .bool, + i2 => .i2, i4 => .i4, i8 => .i8, i16 => .i16, i32 => .i32, i64 => .i64, + u2 => .u2, u4 => .u4, u8 => .u8, u16 => .u16, @@ -192,10 +217,10 @@ pub const DataType = enum(u8) { pub fn maxValue(dtype: DataType) Data { return switch (dtype) { .bool => .{ .bool = true }, - inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)), - inline .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).inf), + inline .f4e2m1, .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz, .f8e8m0 => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)), + inline .f8e3m4, .f8e4m3, .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).inf), inline .f16, .f32, .f64 => |tag| @unionInit(Data, @tagName(tag), std.math.inf(@FieldType(Data, @tagName(tag)))), - inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| @unionInit(Data, @tagName(tag), std.math.maxInt(@FieldType(Data, @tagName(tag)))), + inline .i2, .i4, .i8, .i16, .i32, .i64, .u2, .u4, .u8, .u16, .u32, .u64 => |tag| @unionInit(Data, @tagName(tag), std.math.maxInt(@FieldType(Data, @tagName(tag)))), inline .c64, .c128 => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)), }; } @@ -207,20 +232,26 @@ pub const DataType = enum(u8) { pub const Data = union(DataType) { bool: bool, + f4e2m1: floats.Float4E2M1, + f8e3m4: floats.Float8E3M4, + f8e4m3: floats.Float8E4M3, f8e4m3b11fnuz: floats.Float8E4M3B11FNUZ, f8e4m3fn: floats.Float8E4M3FN, f8e4m3fnuz: floats.Float8E4M3FNUZ, f8e5m2: floats.Float8E5M2, f8e5m2fnuz: floats.Float8E5M2FNUZ, + f8e8m0: floats.Float8E8M0, bf16: floats.BFloat16, f16: f16, f32: f32, f64: f64, + i2: i2, i4: i4, i8: i8, i16: i16, i32: i32, i64: i64, + u2: u2, u4: u4, u8: u8, u16: u16, @@ -244,7 +275,7 @@ pub const Data = union(DataType) { .comptime_int, .int, .comptime_float, .float => .{ .bool = value != 0 }, else => @panic("Could not create Data of type bool from value of type " ++ @typeName(T)), }, - inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16 => |tag| switch (Ti) { + inline .f4e2m1, .f8e3m4, .f8e4m3, .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .f8e8m0, .bf16 => |tag| switch (Ti) { .comptime_int, .int => @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).fromF32(@floatFromInt(value))), .comptime_float, .float => @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).fromF32(@floatCast(value))), else => @panic("Could not create Data of type bf16 from value of type " ++ @typeName(T)), @@ -254,7 +285,7 @@ pub const Data = union(DataType) { .comptime_float, .float => @unionInit(Data, @tagName(tag), @floatCast(value)), else => @panic("Could not create Data of type " ++ @tagName(tag) ++ " from value of type " ++ @typeName(T)), }, - inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| switch (Ti) { + inline .i2, .i4, .i8, .i16, .i32, .i64, .u2, .u4, .u8, .u16, .u32, .u64 => |tag| switch (Ti) { .comptime_int => blk: { const OutT = @FieldType(Data, @tagName(tag)); if (value >= std.math.minInt(OutT) and value <= std.math.maxInt(OutT)) { diff --git a/zml/floats.zig b/zml/floats.zig index 57bf515..7d995c4 100644 --- a/zml/floats.zig +++ b/zml/floats.zig @@ -5,10 +5,6 @@ test { std.testing.refAllDecls(@This()); } -fn allBitsOne(v: anytype) bool { - return v == std.math.maxInt(@TypeOf(v)); -} - fn FloatHelpers(Float: type) type { const info = @typeInfo(Float); const err_msg = "FloatHelpers expect a packed struct { mantissa: uXX, exponent: uXX, sign: u1}"; @@ -23,9 +19,11 @@ fn FloatHelpers(Float: type) type { } return struct { - const sign_bits: u8 = @typeInfo(@FieldType(Float, "sign")).int.bits; const mantissa_bits: u8 = @typeInfo(@FieldType(Float, "mantissa")).int.bits; const exponent_bits: u8 = @typeInfo(@FieldType(Float, "exponent")).int.bits; + const f32_mantissa_bits: u8 = @typeInfo(@FieldType(Float32, "mantissa")).int.bits; + const exp_bias: i16 = std.math.maxInt(std.meta.Int(.unsigned, exponent_bits - 1)); + const exp_off: u8 = FloatHelpers(Float32).exp_bias - exp_bias; pub const zero: Float = .{ .sign = 0, .exponent = 0, .mantissa = 0 }; @@ -39,52 +37,74 @@ fn FloatHelpers(Float: type) type { /// Lossy conversion from f32, similar to @floatCast pub fn fromF32(f: f32) Float { + @setRuntimeSafety(false); + const vf32: Float32 = @bitCast(f); - const exp_bias = comptime expBias(); - const exponent = @as(u16, vf32.exponent) + exp_bias -| FloatHelpers(Float32).expBias(); + const exponent: i16 = @as(i16, vf32.exponent) - exp_off; const overflow = exponent > std.math.maxInt(@FieldType(Float, "exponent")); if (overflow) { - return if (@hasDecl(Float, "inf")) { - return if (vf32.sign == 0) Float.inf else Float.minus_inf; - } else Float.nan; + @branchHint(.unlikely); + return if (@hasDecl(Float, "inf")) + if (vf32.sign == 0) Float.inf else Float.minus_inf + else + Float.nan; } - return .{ - .sign = vf32.sign, - .exponent = @intCast(exponent), - .mantissa = truncMantissa(vf32.mantissa), - }; + + return if (exponent <= 0) + .{ + .sign = vf32.sign, + .exponent = 0, + .mantissa = shiftMantissa(vf32.mantissa, @intCast(-exponent)), + } + else + .{ + .sign = vf32.sign, + .exponent = @intCast(exponent), + .mantissa = truncMantissa(vf32.mantissa), + }; } /// Lossless conversion to f32. pub fn toF32(x: Float) f32 { - var vf32: Float32 = undefined; - if (@hasDecl(Float, "isInf") and x.isInf()) { + @setRuntimeSafety(false); + + if (x == zero) return 0.0; + if (isInf(x)) { + @branchHint(.unlikely); return if (x.sign == 0) std.math.inf(f32) else -std.math.inf(f32); } - vf32 = .{ - .sign = x.sign, - .exponent = if (x.exponent == 0) 0 else @intCast(@as(i16, x.exponent) + f32_exp_bias - expBias()), - .mantissa = f32Mantissa(x), - }; + + const vf32: Float32 = if (x.exponent > 0) + .{ + .sign = x.sign, + .exponent = @as(u8, x.exponent) + exp_off, + .mantissa = f32Mantissa(x), + } + else + .{ + .sign = x.sign, + .exponent = exp_off - @clz(x.mantissa), + .mantissa = @as(u23, x.mantissa) << @clz(x.mantissa), + }; return @bitCast(vf32); } - fn truncMantissa(x: anytype) @FieldType(Float, "mantissa") { - @setRuntimeSafety(false); - const off = @bitSizeOf(@TypeOf(x)) - mantissa_bits; - return @intCast(x >> off); + fn truncMantissa(f32_mantissa: u32) @FieldType(Float, "mantissa") { + const rounding_val: u32 = @as(u32, 1) << (f32_mantissa_bits - mantissa_bits - 1); + return @truncate((f32_mantissa + rounding_val) >> (f32_mantissa_bits - mantissa_bits)); + } + + fn shiftMantissa(f32_mantissa: u32, underflow: u8) @FieldType(Float, "mantissa") { + const upper_bit: u32 = @as(u32, 1) << f32_mantissa_bits; + const full_mant32: u32 = f32_mantissa | upper_bit; + // divide the mantissa proportionally to the exponent underflow + const shifted_mant: u32 = full_mant32 >> @truncate(underflow + 1); + return truncMantissa(shifted_mant); } fn f32Mantissa(x: Float) @FieldType(Float32, "mantissa") { - @setRuntimeSafety(false); - const Res = @FieldType(Float32, "mantissa"); - const f32_mantissa_bits = @bitSizeOf(Res); - - return @shlExact(@as(Res, x.mantissa), f32_mantissa_bits - mantissa_bits); - } - - fn expBias() u8 { - return std.math.maxInt(std.meta.Int(.unsigned, exponent_bits - 1)); + const T = @FieldType(Float32, "mantissa"); + return @as(T, x.mantissa) << f32_mantissa_bits - mantissa_bits; } pub fn formatNumber(x: Float, writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void { @@ -101,6 +121,9 @@ pub const Float32 = packed struct(u32) { exponent: u8, sign: u1, + pub const inf: Float32 = .{ .sign = 0, .exponent = std.math.maxInt(u8), .mantissa = 0 }; + pub const minus_inf = neg(inf); + const Helpers = FloatHelpers(@This()); pub const zero = Helpers.zero; pub const neg = Helpers.neg; @@ -129,11 +152,7 @@ pub const Float8E4M3B11FNUZ = packed struct(u8) { exponent: u4, sign: u1, - pub const nan: Float8E4M3B11FNUZ = .{ - .sign = 1, - .exponent = 0, - .mantissa = 0, - }; + pub const nan: Float8E4M3B11FNUZ = .{ .sign = 1, .exponent = 0, .mantissa = 0 }; pub fn isNan(self: Float8E4M3B11FNUZ) bool { return self.sign == 1 and self.exponent == 0 and self.mantissa == 0; @@ -155,7 +174,7 @@ pub const Float8E4M3FN = packed struct(u8) { pub const nan: Float8E4M3FN = .{ .sign = 0, .exponent = std.math.maxInt(u4), .mantissa = std.math.maxInt(u3) }; pub fn isNan(self: Float8E4M3FN) bool { - return allBitsOne(self.exponent) and allBitsOne(self.mantissa); + return self.exponent == nan.exponent and self.mantissa == nan.mantissa; } const Helpers = FloatHelpers(@This()); pub const zero = Helpers.zero; @@ -170,11 +189,7 @@ pub const Float8E4M3FNUZ = packed struct(u8) { exponent: u4, sign: u1, - pub const nan: Float8E4M3FNUZ = .{ - .sign = 1, - .exponent = 0, - .mantissa = 0, - }; + pub const nan: Float8E4M3FNUZ = .{ .sign = 1, .exponent = 0, .mantissa = 0 }; pub fn isNan(self: Float8E4M3FNUZ) bool { return self.sign == 1 and self.exponent == 0 and self.mantissa == 0; @@ -189,9 +204,10 @@ pub const Float8E4M3FNUZ = packed struct(u8) { }; test "Float8E4" { + // With 4 bits of exponents power of two can be represented exactly up to 64. const test_case_e4: TestCase = .{ - .lossless = &[_]f32{ 0, 1.0, -2, 1.0 / 64.0, -128 }, - .lossy = &[_]f32{3.02344107628}, + .lossless = &[_]f32{ 0, 1.0, -2, 1.0 / 64.0, -128, -1.125 / 64.0 }, + .lossy = &[_]f32{ 3.02344107628, 1.0 / 128.0, 1.0 / 512.0 }, }; inline for (.{ @@ -200,7 +216,11 @@ test "Float8E4" { Float8E4M3FNUZ, }) |Float8T| { try testCustomFloat(Float8T, test_case_e4); - try std.testing.expectEqual(0.0, Float8T.fromF32(1.0 / 128.0).toF32()); + try std.testing.expectEqual(0.0, Float8T.fromF32(1.0 / 2048.0).toF32()); + if (@hasDecl(Float8T, "inf")) { + try std.testing.expectEqual(Float8T.inf, Float8T.fromF32(128.0)); + try std.testing.expectEqual(Float8T.inf.neg(), Float8T.fromF32(-128.0)); + } } } @@ -209,31 +229,19 @@ pub const Float8E5M2 = packed struct(u8) { exponent: u5, sign: u1, - pub const nan: Float8E5M2 = .{ - .sign = 0, - .exponent = std.math.maxInt(u5), - .mantissa = 1, - }; + pub const nan: Float8E5M2 = .{ .sign = 0, .exponent = std.math.maxInt(u5), .mantissa = 1 }; pub fn isNan(self: Float8E5M2) bool { - return allBitsOne(self.exponent) and self.mantissa != 0; + return self.exponent == nan.exponent and self.mantissa != 0; } - pub const minus_inf: Float8E5M2 = .{ - .sign = 1, - .exponent = std.math.maxInt(u5), - .mantissa = 0, - }; - pub const inf: Float8E5M2 = .{ .sign = 0, .exponent = std.math.maxInt(u5), .mantissa = 0, }; - pub fn isInf(self: Float8E5M2) bool { - return allBitsOne(self.exponent) and self.mantissa == 0; - } + pub const minus_inf: Float8E5M2 = .neg(inf); const Helpers = FloatHelpers(@This()); pub const zero = Helpers.zero; @@ -283,22 +291,16 @@ pub const BFloat16 = packed struct(u16) { return allBitsOne(self.exponent) and self.mantissa != 0; } - pub const minus_inf: BFloat16 = .{ - .sign = 1, - .exponent = std.math.maxInt(u8), - .mantissa = 0, - }; - pub const inf: BFloat16 = .{ .sign = 0, .exponent = std.math.maxInt(u8), .mantissa = 0, }; - pub fn isInf(self: BFloat16) bool { - return allBitsOne(self.exponent) and self.mantissa == 0; - } + pub const minus_inf: BFloat16 = .neg(inf); + // Specialized versions of to/from F32. Since BFloat16 has the same exponent range than F32, + // no overflow/underflow can happen, simplifiying conversion logic. pub fn toF32(self: BFloat16) f32 { // Pad the BF16 with zeros 0 return @bitCast([2]u16{ 0, @bitCast(self) }); @@ -333,11 +335,167 @@ test BFloat16 { }); } -pub fn floatCast(T: type, x: anytype) T { - return switch (@TypeOf(x)) { - f64, f32, f16 => @floatCast(x), - else => @floatCast(x.toF32()), +pub const Float8E4M3 = packed struct(u8) { + mantissa: u3, + exponent: u4, + sign: u1, + + pub const nan: Float8E4M3 = @bitCast(0xFF); + + pub fn isNan(self: Float8E4M3) bool { + return self == nan or self == comptime nan.neg(); + } + + pub const inf: Float8E4M3 = .{ + .sign = 0, + .exponent = std.math.maxInt(u4), + .mantissa = 0, }; + + pub const minus_inf = neg(inf); + + const Helpers = FloatHelpers(@This()); + pub const zero = Helpers.zero; + pub const neg = Helpers.neg; + pub const fromF32 = Helpers.fromF32; + pub const toF32 = Helpers.toF32; + pub const formatNumber = Helpers.formatNumber; +}; + +pub const Float8E3M4 = packed struct(u8) { + mantissa: u4, + exponent: u3, + sign: u1, + + pub const nan: Float8E3M4 = @bitCast(0xFF); + + pub fn isNan(self: Float8E3M4) bool { + return self == nan or self == comptime nan.neg(); + } + + pub const inf: Float8E3M4 = .{ + .sign = 0, + .exponent = std.math.maxInt(u3), + .mantissa = 0, + }; + + pub const minus_inf = neg(inf); + + const Helpers = FloatHelpers(@This()); + pub const zero = Helpers.zero; + pub const neg = Helpers.neg; + pub const fromF32 = Helpers.fromF32; + pub const toF32 = Helpers.toF32; + pub const formatNumber = Helpers.formatNumber; +}; + +pub const Float8E8M0 = packed struct(u8) { + mantissa: u0 = 0, + exponent: u8, + sign: u0 = 0, + + pub const min_scale: f32 = @bitCast(Float32{ .sign = 0, .exponent = 0, .mantissa = 0b1 << 22 }); + + /// Lossy conversion from f32, similar to @floatCast + pub fn fromF32(f: f32) Float8E8M0 { + const vf32: Float32 = @bitCast(f); + return .{ .exponent = @intCast(vf32.exponent) }; + } + + /// Lossless conversion to f32. + pub fn toF32(x: Float8E8M0) f32 { + if (x.exponent == 0) return min_scale; + const vf32: Float32 = .{ + .sign = 0, + .exponent = x.exponent, + .mantissa = 0, + }; + return @bitCast(vf32); + } + + const Helpers = FloatHelpers(@This()); + pub const formatNumber = Helpers.formatNumber; +}; + +test Float8E8M0 { + try std.testing.expectEqual(Float8E8M0{ .exponent = 127 }, Float8E8M0.fromF32(1.0)); + // try std.testing.expectEqual(5.877472e-39, Float8E8M0.toF32(.{ .exponent = 0})); + + try testCustomFloat(Float8E8M0, .{ + .lossless = &[_]f32{ Float8E8M0.min_scale, 1.0, 64.0, 1.0 / 128.0, std.math.pow(f32, 2.0, 127) }, + .lossy = &[_]f32{1.00001}, + }); +} + +pub const Float4E2M1 = packed struct(u4) { + mantissa: u1, + exponent: u2, + sign: u1, + + pub const nan: Float4E2M1 = @bitCast(@as(u4, 0xF)); + const Helpers = FloatHelpers(@This()); + pub const zero = Helpers.zero; + pub const neg = Helpers.neg; + pub const fromF32 = Helpers.fromF32; + pub const formatNumber = Helpers.formatNumber; + + pub const values = [_]f32{ 0.0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6 }; + + pub fn toF32(x: Float4E2M1) f32 { + // the baseline toF32 doesn't work correctly: + // 0b0001 and 0b1001 shoud map to ±0.5, but are mapped to ±epsilon + return values[@as(u4, @bitCast(x))]; + } + + test toF32 { + var to_f32_res: [16]f32 = undefined; + for (&to_f32_res, 0..) |*r, i| { + const x_f4: Float4E2M1 = @bitCast(@as(u4, @intCast(i))); + r.* = x_f4.toF32(); + } + try std.testing.expectEqualSlices(f32, &Float4E2M1.values, &to_f32_res); + } + + test fromF32 { + // the baseline fromF32 doesn't work correctly: + // ±0.5 should map to 0b0001/0b1001 but are map to ±0.0 instead. + // TODO: it probably affects other types. + var from_f32_res: [16]Float4E2M1 = undefined; + for (&from_f32_res, 0..) |*r, i| { + r.* = .fromF32(Float4E2M1.values[i]); + } + try std.testing.expectEqualSlices(u4, &.{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 }, @ptrCast(&from_f32_res)); + } +}; + +pub fn floatCast(T: type, x: anytype) T { + return switch (T) { + f64, f32, f16 => switch (@TypeOf(x)) { + f64, f32, f16 => @floatCast(x), + else => @floatCast(x.toF32()), + }, + else => switch (@TypeOf(x)) { + f64, f32, f16 => .fromF32(x), + else => .fromF32(x.toF32()), + }, + }; +} + +pub fn isInf(x: anytype) bool { + const Float = @TypeOf(x); + switch (Float) { + f64, f32, f16 => return std.math.isInf(x), + else => {}, + } + if (!@hasDecl(Float, "inf")) return false; + + const FBits = std.meta.Int(.unsigned, @bitSizeOf(Float)); + const remove_sign = ~@as(FBits, 0) >> 1; + return @as(FBits, @bitCast(x)) & remove_sign == @as(FBits, @bitCast(Float.inf)); +} + +fn allBitsOne(v: anytype) bool { + return v == std.math.maxInt(@TypeOf(v)); } const TestCase = struct { diff --git a/zml/mlirx.zig b/zml/mlirx.zig index f50a2ef..286166a 100644 --- a/zml/mlirx.zig +++ b/zml/mlirx.zig @@ -35,20 +35,26 @@ pub const Type = struct { pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type { return switch (dt) { .bool => .int(ctx, .i1), + .f4e2m1 => .float(ctx, .f4e2m1fn), + .f8e3m4 => .float(ctx, .f8e3m4), + .f8e4m3 => .float(ctx, .f8e4m3), .f8e4m3b11fnuz => .float(ctx, .f8e4m3b11fnuz), .f8e4m3fn => .float(ctx, .f8e4m3fn), .f8e4m3fnuz => .float(ctx, .f8e4m3fnuz), .f8e5m2 => .float(ctx, .f8e5m2), .f8e5m2fnuz => .float(ctx, .f8e5m2fnuz), + .f8e8m0 => .float(ctx, .f8e8m0fnu), .bf16 => .float(ctx, .bf16), .f16 => .float(ctx, .f16), .f32 => .float(ctx, .f32), .f64 => .float(ctx, .f64), + .i2 => .int(ctx, .i2), .i4 => .int(ctx, .i4), .i8 => .int(ctx, .i8), .i16 => .int(ctx, .i16), .i32 => .int(ctx, .i32), .i64 => .int(ctx, .i64), + .u2 => .int(ctx, .u2), .u4 => .int(ctx, .u4), .u8 => .int(ctx, .u8), .u16 => .int(ctx, .u16), @@ -62,23 +68,28 @@ pub const Type = struct { pub fn toDType(mlir_type: mlir.Type) dtype.DataType { const mapping = .{ .{ .bool, mlir.IntegerType(.i1) }, - + .{ .f4e2m1, mlir.FloatType(.f4e2m1fn) }, + .{ .f8e3m4, mlir.FloatType(.f8e3m4) }, + .{ .f8e4m3, mlir.FloatType(.f8e4m3) }, .{ .f8e4m3b11fnuz, mlir.FloatType(.f8e4m3b11fnuz) }, .{ .f8e4m3fn, mlir.FloatType(.f8e4m3fn) }, .{ .f8e4m3fnuz, mlir.FloatType(.f8e4m3fnuz) }, .{ .f8e5m2, mlir.FloatType(.f8e5m2) }, .{ .f8e5m2fnuz, mlir.FloatType(.f8e5m2fnuz) }, + .{ .f8e8m0, mlir.FloatType(.f8e8m0fnu) }, .{ .bf16, mlir.FloatType(.bf16) }, .{ .f16, mlir.FloatType(.f16) }, .{ .f32, mlir.FloatType(.f32) }, .{ .f64, mlir.FloatType(.f64) }, + .{ .i2, mlir.IntegerType(.i2) }, .{ .i4, mlir.IntegerType(.i4) }, .{ .i8, mlir.IntegerType(.i8) }, .{ .i16, mlir.IntegerType(.i16) }, .{ .i32, mlir.IntegerType(.i32) }, .{ .i64, mlir.IntegerType(.i64) }, + .{ .u2, mlir.IntegerType(.u2) }, .{ .u4, mlir.IntegerType(.u4) }, .{ .u8, mlir.IntegerType(.u8) }, .{ .u16, mlir.IntegerType(.u16) }, @@ -89,6 +100,8 @@ pub const Type = struct { .{ .c128, mlir.ComplexType(.c128) }, }; + // TODO: this seems quite slow to have all of those functions calls. + // Maybe we should memoize the ptr of a set of mlir types when creating the context. inline for (mapping) |entry| { const dt, const mlirT = entry; if (mlirT.is_a_fn(mlir_type._inner)) { diff --git a/zml/tensor.zig b/zml/tensor.zig index ff12d91..373d5ef 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -1042,6 +1042,40 @@ pub const Tensor = struct { return _result(self._shape.withDtype(to), op.result(0)); } + test convert { + const floats = @import("floats.zig"); + const zml = @import("zml.zig"); + const platform = zml.testing.env(); + + // f4e2m1 + { + const x = [_]f32{ 0.0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6 }; + var x_f4: [x.len]floats.Float4E2M1 = undefined; + for (&x_f4, &x) |*xi_f4, xi| xi_f4.* = .fromF32(xi); + + const x_d = try zml.Buffer.fromArray(platform, x); + const x_f4_xla_d = try zml.testing.compileAndCall(platform, Tensor.convert, .{ x_d, .f4e2m1 }); + + const x_f4_xla = x_f4_xla_d.getValue(@TypeOf(x_f4)); + errdefer std.log.warn("convert(.f4e2m1) failed !\ninput f32:\n{e}\nzml.floats computed:\n{any}\nxla computed:\n{any}", .{ stdx.fmt.slice(&x), x_f4, x_f4_xla }); + try std.testing.expectEqualDeep(x_f4, x_f4_xla); + } + + // f8e3m4 + { + const x = [_]f32{ 1.1 / 4.0, 1.1 / 8.0, 1.1 / 16.0, 1.1 / 32.0, 1.1 / 64.0, 1.1 / 128.0 }; + var x_f8e3: [x.len]floats.Float8E3M4 = undefined; + for (&x_f8e3, &x) |*xi_f8e3, xi| xi_f8e3.* = .fromF32(xi); + + const x_d = try zml.Buffer.fromArray(platform, x); + const x_f8e3_xla_d = try zml.testing.compileAndCall(platform, Tensor.convert, .{ x_d, .f8e3m4 }); + + const x_f8e3_xla = x_f8e3_xla_d.getValue(@TypeOf(x_f8e3)); + errdefer std.log.warn("convert(.f8e3m4) failed !\ninput f32:\n{e}\nzml.floats computed:\n{any}\nxla computed:\n{any}", .{ stdx.fmt.slice(&x), x_f8e3, x_f8e3_xla }); + try std.testing.expectEqualDeep(x_f8e3, x_f8e3_xla); + } + } + /// Returns a Tensor containing the element-wise rounding operation of the input Tensor. pub fn round(self: Tensor) Tensor { const loc = self.getContext().mlirCtx().location(@src()); @@ -3844,11 +3878,11 @@ fn getPoolResDims(dt: DataType, in_dims: []const i64, base_dilations: @Vector(Te } fn getComparisonType(ctx: mlir.Context, dtype: DataType) dialect.stablehlo.CompareType { - return dialect.stablehlo.CompareType.init(ctx, switch (dtype) { - .i4, .i8, .i16, .i32, .i64 => .SIGNED, - .bool, .u4, .u8, .u16, .u32, .u64 => .UNSIGNED, - .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16, .f16, .f32, .f64 => .FLOAT, - .c64, .c128 => @panic("Can't compare complex numbers"), + return dialect.stablehlo.CompareType.init(ctx, switch (dtype.class()) { + .bool => .UNSIGNED, + .integer => if (dtype.isSignedInt()) .SIGNED else .UNSIGNED, + .float => .FLOAT, + .complex => @panic("Can't compare complex numbers"), }); } diff --git a/zml/testing.zig b/zml/testing.zig index c9ad9ef..851bd8c 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -65,11 +65,15 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void { .f16, .f32, .f64, + .f4e2m1, + .f8e3m4, + .f8e4m3, .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, + .f8e8m0, => |t| { const L = t.toZigType(); const left_data = left.items(L); @@ -96,7 +100,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void { else => unreachable, } }, - inline .bool, .u4, .u8, .u16, .u32, .u64, .i4, .i8, .i16, .i32, .i64 => |t| { + inline .bool, .u2, .u4, .u8, .u16, .u32, .u64, .i2, .i4, .i8, .i16, .i32, .i64 => |t| { const T = t.toZigType(); return std.testing.expectEqualSlices(T, left.items(T), right.items(T)); },