Radix/zml/floats.zig

364 lines
11 KiB
Zig

///! Conversion utilities between different Floating point formats.
const std = @import("std");
test {
std.testing.refAllDecls(@This());
}
fn allBitsOne(v: anytype) bool {
return v == std.math.maxInt(@TypeOf(v));
}
fn FloatHelpers(Float: type) type {
const info = @typeInfo(Float);
const err_msg = "FloatHelpers expect a packed struct { mantissa: uXX, exponent: uXX, sign: u1}";
if (info != .@"struct" or info.@"struct".backing_integer == null) {
@compileError(err_msg);
}
comptime {
for (info.@"struct".fields, &.{ "mantissa", "exponent", "sign" }) |field, expected_name| {
if (!std.mem.eql(u8, field.name, expected_name))
@compileError(err_msg);
}
}
return struct {
const sign_bits: u8 = @typeInfo(@FieldType(Float, "sign")).int.bits;
const mantissa_bits: u8 = @typeInfo(@FieldType(Float, "mantissa")).int.bits;
const exponent_bits: u8 = @typeInfo(@FieldType(Float, "exponent")).int.bits;
pub const zero: Float = .{ .sign = 0, .exponent = 0, .mantissa = 0 };
pub fn neg(x: Float) Float {
return .{
.sign = x.sign ^ 1,
.exponent = x.exponent,
.mantissa = x.mantissa,
};
}
/// Lossy conversion from f32, similar to @floatCast
pub fn fromF32(f: f32) Float {
const vf32: Float32 = @bitCast(f);
const exp_bias = comptime expBias();
const exponent = @as(u16, vf32.exponent) + exp_bias -| FloatHelpers(Float32).expBias();
const overflow = exponent > std.math.maxInt(@FieldType(Float, "exponent"));
if (overflow) {
return if (@hasDecl(Float, "inf")) {
return if (vf32.sign == 0) Float.inf else Float.minus_inf;
} else Float.nan;
}
return .{
.sign = vf32.sign,
.exponent = @intCast(exponent),
.mantissa = truncMantissa(vf32.mantissa),
};
}
/// Lossless conversion to f32.
pub fn toF32(x: Float) f32 {
var vf32: Float32 = undefined;
if (@hasDecl(Float, "isInf") and x.isInf()) {
return if (x.sign == 0) std.math.inf(f32) else -std.math.inf(f32);
}
vf32 = .{
.sign = x.sign,
.exponent = if (x.exponent == 0) 0 else @intCast(@as(i16, x.exponent) + f32_exp_bias - expBias()),
.mantissa = f32Mantissa(x),
};
return @bitCast(vf32);
}
fn truncMantissa(x: anytype) @FieldType(Float, "mantissa") {
@setRuntimeSafety(false);
const off = @bitSizeOf(@TypeOf(x)) - mantissa_bits;
return @intCast(x >> off);
}
fn f32Mantissa(x: Float) @FieldType(Float32, "mantissa") {
@setRuntimeSafety(false);
const Res = @FieldType(Float32, "mantissa");
const f32_mantissa_bits = @bitSizeOf(Res);
return @shlExact(@as(Res, x.mantissa), f32_mantissa_bits - mantissa_bits);
}
fn expBias() u8 {
return std.math.maxInt(std.meta.Int(.unsigned, exponent_bits - 1));
}
pub fn formatNumber(x: Float, writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
switch (n.mode) {
.binary, .octal, .hex => try writer.print("{{ .sign={}, .exp={}, .mantissa={} }}", .{ x.sign, x.exponent, x.mantissa }),
else => try writer.printFloat(x.toF32(), n),
}
}
};
}
pub const Float32 = packed struct(u32) {
mantissa: u23,
exponent: u8,
sign: u1,
const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero;
pub const neg = Helpers.neg;
pub const fromF32 = Helpers.fromF32;
pub const toF32 = Helpers.toF32;
pub const formatNumber = Helpers.formatNumber;
};
const f32_exp_bias = FloatHelpers(Float32).expBias();
pub const Float64 = packed struct(u64) {
mantissa: u52,
exponent: u11,
sign: u1,
const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero;
pub const neg = Helpers.neg;
pub const fromF32 = Helpers.fromF32;
pub const toF32 = Helpers.toF32;
pub const formatNumber = Helpers.formatNumber;
};
pub const Float8E4M3B11FNUZ = packed struct(u8) {
mantissa: u3,
exponent: u4,
sign: u1,
pub const nan: Float8E4M3B11FNUZ = .{
.sign = 1,
.exponent = 0,
.mantissa = 0,
};
pub fn isNan(self: Float8E4M3B11FNUZ) bool {
return self.sign == 1 and self.exponent == 0 and self.mantissa == 0;
}
const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero;
pub const neg = Helpers.neg;
pub const fromF32 = Helpers.fromF32;
pub const toF32 = Helpers.toF32;
pub const formatNumber = Helpers.formatNumber;
};
pub const Float8E4M3FN = packed struct(u8) {
mantissa: u3,
exponent: u4,
sign: u1,
pub const nan: Float8E4M3FN = .{ .sign = 0, .exponent = std.math.maxInt(u4), .mantissa = std.math.maxInt(u3) };
pub fn isNan(self: Float8E4M3FN) bool {
return allBitsOne(self.exponent) and allBitsOne(self.mantissa);
}
const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero;
pub const neg = Helpers.neg;
pub const fromF32 = Helpers.fromF32;
pub const toF32 = Helpers.toF32;
pub const formatNumber = Helpers.formatNumber;
};
pub const Float8E4M3FNUZ = packed struct(u8) {
mantissa: u3,
exponent: u4,
sign: u1,
pub const nan: Float8E4M3FNUZ = .{
.sign = 1,
.exponent = 0,
.mantissa = 0,
};
pub fn isNan(self: Float8E4M3FNUZ) bool {
return self.sign == 1 and self.exponent == 0 and self.mantissa == 0;
}
const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero;
pub const neg = Helpers.neg;
pub const fromF32 = Helpers.fromF32;
pub const toF32 = Helpers.toF32;
pub const formatNumber = Helpers.formatNumber;
};
test "Float8E4" {
const test_case_e4: TestCase = .{
.lossless = &[_]f32{ 0, 1.0, -2, 1.0 / 64.0, -128 },
.lossy = &[_]f32{3.02344107628},
};
inline for (.{
Float8E4M3B11FNUZ,
Float8E4M3FN,
Float8E4M3FNUZ,
}) |Float8T| {
try testCustomFloat(Float8T, test_case_e4);
try std.testing.expectEqual(0.0, Float8T.fromF32(1.0 / 128.0).toF32());
}
}
pub const Float8E5M2 = packed struct(u8) {
mantissa: u2,
exponent: u5,
sign: u1,
pub const nan: Float8E5M2 = .{
.sign = 0,
.exponent = std.math.maxInt(u5),
.mantissa = 1,
};
pub fn isNan(self: Float8E5M2) bool {
return allBitsOne(self.exponent) and self.mantissa != 0;
}
pub const minus_inf: Float8E5M2 = .{
.sign = 1,
.exponent = std.math.maxInt(u5),
.mantissa = 0,
};
pub const inf: Float8E5M2 = .{
.sign = 0,
.exponent = std.math.maxInt(u5),
.mantissa = 0,
};
pub fn isInf(self: Float8E5M2) bool {
return allBitsOne(self.exponent) and self.mantissa == 0;
}
const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero;
pub const neg = Helpers.neg;
pub const fromF32 = Helpers.fromF32;
pub const toF32 = Helpers.toF32;
pub const formatNumber = Helpers.formatNumber;
};
pub const Float8E5M2FNUZ = packed struct(u8) {
mantissa: u2,
exponent: u5,
sign: u1,
pub const nan: Float8E5M2FNUZ = .{ .sign = 1, .exponent = 0, .mantissa = 0 };
pub fn isNan(self: Float8E5M2FNUZ) bool {
return self.sign == 1 and self.exponent == 0 and self.mantissa == 0;
}
const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero;
pub const neg = Helpers.neg;
pub const fromF32 = Helpers.fromF32;
pub const toF32 = Helpers.toF32;
pub const formatNumber = Helpers.formatNumber;
};
test "Float8E5" {
const test_case_e5: TestCase = .{
.lossless = &[_]f32{ 0, 1.0, -2, 1.0 / 128.0, -128 },
.lossy = &[_]f32{3.02344107628},
};
inline for (.{ Float8E5M2, Float8E5M2FNUZ }) |Float8T| {
try testCustomFloat(Float8T, test_case_e5);
}
}
pub const BFloat16 = packed struct(u16) {
mantissa: u7,
exponent: u8,
sign: u1,
pub const nan: BFloat16 = .{ .sign = 0, .exponent = std.math.maxInt(u8), .mantissa = 1 };
pub fn isNan(self: BFloat16) bool {
return allBitsOne(self.exponent) and self.mantissa != 0;
}
pub const minus_inf: BFloat16 = .{
.sign = 1,
.exponent = std.math.maxInt(u8),
.mantissa = 0,
};
pub const inf: BFloat16 = .{
.sign = 0,
.exponent = std.math.maxInt(u8),
.mantissa = 0,
};
pub fn isInf(self: BFloat16) bool {
return allBitsOne(self.exponent) and self.mantissa == 0;
}
pub fn toF32(self: BFloat16) f32 {
// Pad the BF16 with zeros 0
return @bitCast([2]u16{ 0, @bitCast(self) });
}
pub fn fromF32(float32: f32) BFloat16 {
var int: u32 = @bitCast(float32);
// Round up if needed.
int += 0x8000;
const parts: [2]u16 = @bitCast(int);
return @bitCast(parts[1]);
}
const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero;
pub const neg = Helpers.neg;
pub const formatNumber = Helpers.formatNumber;
};
test BFloat16 {
// From https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Examples
try std.testing.expectEqual(BFloat16.fromF32(0), BFloat16{ .sign = 0, .exponent = 0, .mantissa = 0 });
try std.testing.expectEqual(BFloat16.fromF32(-2), BFloat16{ .sign = 1, .exponent = 127 + 1, .mantissa = 0 });
try std.testing.expectEqual(BFloat16.fromF32(3.02344107628), BFloat16{ .sign = 0, .exponent = 127 + 1, .mantissa = 66 });
try std.testing.expectEqual(BFloat16.fromF32(1.0 / 128.0), BFloat16{ .sign = 0, .exponent = 127 - 7, .mantissa = 0 });
try std.testing.expectEqual(std.mem.toBytes(BFloat16.inf.neg()), [_]u8{ 0x80, 0xff });
try std.testing.expectEqual(BFloat16.inf, BFloat16.fromF32(std.math.inf(f32)));
try testCustomFloat(BFloat16, .{
.lossless = &[_]f32{ 0, -2, 1.0 / 128.0, -1e64, std.math.inf(f32) },
.lossy = &[_]f32{3.02344107628},
});
}
pub fn floatCast(T: type, x: anytype) T {
return switch (@TypeOf(x)) {
f64, f32, f16 => @floatCast(x),
else => @floatCast(x.toF32()),
};
}
const TestCase = struct {
lossless: []const f32,
lossy: []const f32,
tolerance: f32 = 1e-2,
};
fn testCustomFloat(FloatT: type, test_case: TestCase) !void {
for (test_case.lossless) |x| {
try std.testing.expectEqual(x, FloatT.fromF32(x).toF32());
}
for (test_case.lossy) |x| {
try expectApproxEqRel(f32, x, FloatT.fromF32(x).toF32(), test_case.tolerance);
}
}
fn expectApproxEqRel(FloatT: type, x: FloatT, y: FloatT, tolerance: FloatT) !void {
if (!std.math.approxEqRel(f32, x, y, tolerance)) {
std.log.err("expected ~{d}, got {d}", .{ x, y });
return error.TestUnexpectedResult;
}
}