From 3f36506f1ce65f990506ad8c554c471036e9aabc Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 23 Jul 2024 17:43:43 +0000 Subject: [PATCH] zml: remove `usingnamespace` from floats.zig and related dependencies; note that incremental compilation does not improve overall build time due to linking overhead --- zml/dtype.zig | 6 +- zml/floats.zig | 316 +++++++++++++++++++++++++++++-------------------- zml/nn.zig | 4 +- 3 files changed, 191 insertions(+), 135 deletions(-) diff --git a/zml/dtype.zig b/zml/dtype.zig index 5fdcb5e..1f1c046 100644 --- a/zml/dtype.zig +++ b/zml/dtype.zig @@ -182,8 +182,8 @@ pub const DataType = enum(u8) { pub fn minValue(dtype: DataType) Data { return switch (dtype) { .bool => .{ .bool = false }, - inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).zero()), - inline .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).minusInf()), + inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).zero), + inline .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).minus_inf), inline .f16, .f32, .f64 => |tag| @unionInit(Data, @tagName(tag), -std.math.inf(@FieldType(Data, @tagName(tag)))), inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| @unionInit(Data, @tagName(tag), std.math.minInt(@FieldType(Data, @tagName(tag)))), inline else => |tag| @panic("Unsupported type: " ++ @tagName(tag)), @@ -194,7 +194,7 @@ pub const DataType = enum(u8) { return switch (dtype) { .bool => .{ .bool = true }, inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)), - inline .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).inf()), + inline .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).inf), inline .f16, .f32, .f64 => |tag| @unionInit(Data, @tagName(tag), std.math.inf(@FieldType(Data, @tagName(tag)))), inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| @unionInit(Data, @tagName(tag), std.math.maxInt(@FieldType(Data, @tagName(tag)))), inline .c64, .c128 => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)), diff --git a/zml/floats.zig b/zml/floats.zig index 230ae97..5fa7cde 100644 --- a/zml/floats.zig +++ b/zml/floats.zig @@ -9,43 +9,44 @@ fn allBitsOne(v: anytype) bool { return v == std.math.maxInt(@TypeOf(v)); } -fn FloatType(sign_bits: u1, exponent_bits: u8, mantissa_bits: u8, innerT: type) type { - const bit_size = sign_bits + exponent_bits + mantissa_bits; - if (bit_size % 8 != 0) @compileError("FloatType should have a number of bits divisible by 8"); - - return packed struct(std.meta.Int(.unsigned, bit_size)) { - const Self = @This(); - - mantissa: std.meta.Int(.unsigned, mantissa_bits), - exponent: std.meta.Int(.unsigned, exponent_bits), - sign: std.meta.Int(.unsigned, sign_bits), - - pub fn zero() Self { - return .{ - .sign = 0, - .exponent = 0, - .mantissa = 0, - }; +fn FloatHelpers(Float: type) type { + const info = @typeInfo(Float); + const err_msg = "FloatHelpers expect a packed struct { mantissa: uXX, exponent: uXX, sign: u1}"; + if (info != .@"struct" or info.@"struct".backing_integer == null) { + @compileError(err_msg); + } + comptime { + for (info.@"struct".fields, &.{ "mantissa", "exponent", "sign" }) |field, expected_name| { + if (!std.mem.eql(u8, field.name, expected_name)) + @compileError(err_msg); } + } - pub fn neg(self: Self) Self { + return struct { + const sign_bits: u8 = @typeInfo(@FieldType(Float, "sign")).int.bits; + const mantissa_bits: u8 = @typeInfo(@FieldType(Float, "mantissa")).int.bits; + const exponent_bits: u8 = @typeInfo(@FieldType(Float, "exponent")).int.bits; + + pub const zero: Float = .{ .sign = 0, .exponent = 0, .mantissa = 0 }; + + pub fn neg(x: Float) Float { return .{ - .sign = self.sign ^ 1, - .exponent = self.exponent, - .mantissa = self.mantissa, + .sign = x.sign ^ 1, + .exponent = x.exponent, + .mantissa = x.mantissa, }; } /// Lossy conversion from f32, similar to @floatCast - pub fn fromF32(f: f32) Self { + pub fn fromF32(f: f32) Float { const vf32: Float32 = @bitCast(f); - const exp_bias = comptime Self.expBias(); - const exponent = @as(u16, vf32.exponent) + exp_bias -| Float32.expBias(); - const overflow = exponent > std.math.maxInt(std.meta.Int(.unsigned, exponent_bits)); + const exp_bias = comptime expBias(); + const exponent = @as(u16, vf32.exponent) + exp_bias -| FloatHelpers(Float32).expBias(); + const overflow = exponent > std.math.maxInt(@FieldType(Float, "exponent")); if (overflow) { - return if (@hasDecl(Self, "inf")) { - return if (vf32.sign == 0) Self.inf() else Self.minusInf(); - } else Self.nan(); + return if (@hasDecl(Float, "inf")) { + return if (vf32.sign == 0) Float.inf else Float.minus_inf; + } else Float.nan; } return .{ .sign = vf32.sign, @@ -55,31 +56,31 @@ fn FloatType(sign_bits: u1, exponent_bits: u8, mantissa_bits: u8, innerT: type) } /// Lossless conversion to f32. - pub fn toF32(self: Self) f32 { + pub fn toF32(x: Float) f32 { var vf32: Float32 = undefined; - if (@hasDecl(Self, "isInf") and self.isInf()) { - return if (self.sign == 0) std.math.inf(f32) else -std.math.inf(f32); + if (@hasDecl(Float, "isInf") and x.isInf()) { + return if (x.sign == 0) std.math.inf(f32) else -std.math.inf(f32); } vf32 = .{ - .sign = self.sign, - .exponent = if (self.exponent == 0) 0 else @intCast(@as(i16, self.exponent) + Float32.expBias() - Self.expBias()), - .mantissa = self.f32Mantissa(), + .sign = x.sign, + .exponent = if (x.exponent == 0) 0 else @intCast(@as(i16, x.exponent) + f32_exp_bias - expBias()), + .mantissa = f32Mantissa(x), }; return @bitCast(vf32); } - fn truncMantissa(x: anytype) std.meta.FieldType(Self, .mantissa) { + fn truncMantissa(x: anytype) @FieldType(Float, "mantissa") { @setRuntimeSafety(false); const off = @bitSizeOf(@TypeOf(x)) - mantissa_bits; return @intCast(x >> off); } - fn f32Mantissa(self: Self) std.meta.FieldType(Float32, .mantissa) { + fn f32Mantissa(x: Float) @FieldType(Float32, "mantissa") { @setRuntimeSafety(false); - const f32_mantissa_bits = @bitSizeOf(std.meta.FieldType(Float32, .mantissa)); + const Res = @FieldType(Float32, "mantissa"); + const f32_mantissa_bits = @bitSizeOf(Res); - const Res = std.meta.FieldType(Float32, .mantissa); - return @shlExact(@as(Res, self.mantissa), f32_mantissa_bits - mantissa_bits); + return @shlExact(@as(Res, x.mantissa), f32_mantissa_bits - mantissa_bits); } fn expBias() u8 { @@ -87,67 +88,112 @@ fn FloatType(sign_bits: u1, exponent_bits: u8, mantissa_bits: u8, innerT: type) } pub fn format( - self: @This(), + float: Float, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype, ) !void { _ = options; if (fmt.len == 1 and fmt[0] == '_') { - try writer.print("{{ .sign={}, .exp={}, .mantissa={} }}", .{ self.sign, self.exponent, self.mantissa }); + try writer.print("{{ .sign={}, .exp={}, .mantissa={} }}", .{ float.sign, float.exponent, float.mantissa }); } else { - try writer.print("{" ++ fmt ++ "}", .{self.toF32()}); + try writer.print("{" ++ fmt ++ "}", .{float.toF32()}); } } - - pub usingnamespace innerT; }; } -const Float32 = FloatType(1, 8, 23, struct {}); -const Float64 = FloatType(1, 11, 52, struct {}); +pub const Float32 = packed struct(u32) { + mantissa: u23, + exponent: u8, + sign: u1, -pub const Float8E4M3B11FNUZ = FloatType(1, 4, 3, struct { - pub fn nan() Float8E4M3B11FNUZ { - return .{ - .sign = 1, - .exponent = 0, - .mantissa = 0, - }; - } + const Helpers = FloatHelpers(@This()); + pub const zero = Helpers.zero; + pub const neg = Helpers.neg; + pub const fromF32 = Helpers.fromF32; + pub const toF32 = Helpers.toF32; + pub const format = Helpers.format; +}; + +const f32_exp_bias = FloatHelpers(Float32).expBias(); + +pub const Float64 = packed struct(u64) { + mantissa: u52, + exponent: u11, + sign: u1, + + const Helpers = FloatHelpers(@This()); + pub const zero = Helpers.zero; + pub const neg = Helpers.neg; + pub const fromF32 = Helpers.fromF32; + pub const toF32 = Helpers.toF32; + pub const format = Helpers.format; +}; + +pub const Float8E4M3B11FNUZ = packed struct(u8) { + mantissa: u3, + exponent: u4, + sign: u1, + + pub const nan: Float8E4M3B11FNUZ = .{ + .sign = 1, + .exponent = 0, + .mantissa = 0, + }; pub fn isNan(self: Float8E4M3B11FNUZ) bool { return self.sign == 1 and self.exponent == 0 and self.mantissa == 0; } -}); -pub const Float8E4M3FN = FloatType(1, 4, 3, struct { - pub fn nan() Float8E4M3FN { - return .{ - .sign = 0, - .exponent = std.math.maxInt(u4), - .mantissa = std.math.maxInt(u3), - }; - } + const Helpers = FloatHelpers(@This()); + pub const zero = Helpers.zero; + pub const neg = Helpers.neg; + pub const fromF32 = Helpers.fromF32; + pub const toF32 = Helpers.toF32; + pub const format = Helpers.format; +}; + +pub const Float8E4M3FN = packed struct(u8) { + mantissa: u3, + exponent: u4, + sign: u1, + + pub const nan: Float8E4M3FN = .{ .sign = 0, .exponent = std.math.maxInt(u4), .mantissa = std.math.maxInt(u3) }; pub fn isNan(self: Float8E4M3FN) bool { return allBitsOne(self.exponent) and allBitsOne(self.mantissa); } -}); + const Helpers = FloatHelpers(@This()); + pub const zero = Helpers.zero; + pub const neg = Helpers.neg; + pub const fromF32 = Helpers.fromF32; + pub const toF32 = Helpers.toF32; + pub const format = Helpers.format; +}; -pub const Float8E4M3FNUZ = FloatType(1, 4, 3, struct { - pub fn nan() Float8E4M3FNUZ { - return .{ - .sign = 1, - .exponent = 0, - .mantissa = 0, - }; - } +pub const Float8E4M3FNUZ = packed struct(u8) { + mantissa: u3, + exponent: u4, + sign: u1, + + pub const nan: Float8E4M3FNUZ = .{ + .sign = 1, + .exponent = 0, + .mantissa = 0, + }; pub fn isNan(self: Float8E4M3FNUZ) bool { return self.sign == 1 and self.exponent == 0 and self.mantissa == 0; } -}); + + const Helpers = FloatHelpers(@This()); + pub const zero = Helpers.zero; + pub const neg = Helpers.neg; + pub const fromF32 = Helpers.fromF32; + pub const toF32 = Helpers.toF32; + pub const format = Helpers.format; +}; test "Float8E4" { const test_case_e4: TestCase = .{ @@ -165,53 +211,63 @@ test "Float8E4" { } } -pub const Float8E5M2 = FloatType(1, 5, 2, struct { - pub fn nan() Float8E5M2 { - return .{ - .sign = 0, - .exponent = std.math.maxInt(u5), - .mantissa = 1, - }; - } +pub const Float8E5M2 = packed struct(u8) { + mantissa: u2, + exponent: u5, + sign: u1, + + pub const nan: Float8E5M2 = .{ + .sign = 0, + .exponent = std.math.maxInt(u5), + .mantissa = 1, + }; pub fn isNan(self: Float8E5M2) bool { return allBitsOne(self.exponent) and self.mantissa != 0; } - pub fn minusInf() Float8E5M2 { - return .{ - .sign = 1, - .exponent = std.math.maxInt(u5), - .mantissa = 0, - }; - } + pub const minus_inf: Float8E5M2 = .{ + .sign = 1, + .exponent = std.math.maxInt(u5), + .mantissa = 0, + }; - pub fn inf() Float8E5M2 { - return .{ - .sign = 0, - .exponent = std.math.maxInt(u5), - .mantissa = 0, - }; - } + pub const inf: Float8E5M2 = .{ + .sign = 0, + .exponent = std.math.maxInt(u5), + .mantissa = 0, + }; pub fn isInf(self: Float8E5M2) bool { return allBitsOne(self.exponent) and self.mantissa == 0; } -}); -pub const Float8E5M2FNUZ = FloatType(1, 5, 2, struct { - pub fn nan() Float8E5M2FNUZ { - return .{ - .sign = 1, - .exponent = 0, - .mantissa = 0, - }; - } + const Helpers = FloatHelpers(@This()); + pub const zero = Helpers.zero; + pub const neg = Helpers.neg; + pub const fromF32 = Helpers.fromF32; + pub const toF32 = Helpers.toF32; + pub const format = Helpers.format; +}; + +pub const Float8E5M2FNUZ = packed struct(u8) { + mantissa: u2, + exponent: u5, + sign: u1, + + pub const nan: Float8E5M2FNUZ = .{ .sign = 1, .exponent = 0, .mantissa = 0 }; pub fn isNan(self: Float8E5M2FNUZ) bool { return self.sign == 1 and self.exponent == 0 and self.mantissa == 0; } -}); + + const Helpers = FloatHelpers(@This()); + pub const zero = Helpers.zero; + pub const neg = Helpers.neg; + pub const fromF32 = Helpers.fromF32; + pub const toF32 = Helpers.toF32; + pub const format = Helpers.format; +}; test "Float8E5" { const test_case_e5: TestCase = .{ @@ -223,39 +279,39 @@ test "Float8E5" { } } -pub const BFloat16 = FloatType(1, 8, 7, struct { - pub fn nan() BFloat16 { - return .{ - .sign = 0, - .exponent = std.math.maxInt(u8), - .mantissa = 1, - }; - } +pub const BFloat16 = packed struct(u16) { + mantissa: u7, + exponent: u8, + sign: u1, + + pub const nan: BFloat16 = .{ .sign = 0, .exponent = std.math.maxInt(u8), .mantissa = 1 }; pub fn isNan(self: BFloat16) bool { return allBitsOne(self.exponent) and self.mantissa != 0; } - pub fn minusInf() BFloat16 { - return .{ - .sign = 1, - .exponent = std.math.maxInt(u8), - .mantissa = 0, - }; - } + pub const minus_inf: BFloat16 = .{ + .sign = 1, + .exponent = std.math.maxInt(u8), + .mantissa = 0, + }; - pub fn inf() BFloat16 { - return .{ - .sign = 0, - .exponent = std.math.maxInt(u8), - .mantissa = 0, - }; - } + pub const inf: BFloat16 = .{ + .sign = 0, + .exponent = std.math.maxInt(u8), + .mantissa = 0, + }; pub fn isInf(self: BFloat16) bool { return allBitsOne(self.exponent) and self.mantissa == 0; } -}); + const Helpers = FloatHelpers(@This()); + pub const zero = Helpers.zero; + pub const neg = Helpers.neg; + pub const fromF32 = Helpers.fromF32; + pub const toF32 = Helpers.toF32; + pub const format = Helpers.format; +}; test BFloat16 { // From https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Examples @@ -263,8 +319,8 @@ test BFloat16 { try std.testing.expectEqual(BFloat16.fromF32(-2), BFloat16{ .sign = 1, .exponent = 127 + 1, .mantissa = 0 }); try std.testing.expectEqual(BFloat16.fromF32(3.02344107628), BFloat16{ .sign = 0, .exponent = 127 + 1, .mantissa = 65 }); try std.testing.expectEqual(BFloat16.fromF32(1.0 / 128.0), BFloat16{ .sign = 0, .exponent = 127 - 7, .mantissa = 0 }); - try std.testing.expectEqual(std.mem.toBytes(BFloat16.inf().neg()), [_]u8{ 0x80, 0xff }); - try std.testing.expectEqual(BFloat16.inf(), BFloat16.fromF32(std.math.inf(f32))); + try std.testing.expectEqual(std.mem.toBytes(BFloat16.inf.neg()), [_]u8{ 0x80, 0xff }); + try std.testing.expectEqual(BFloat16.inf, BFloat16.fromF32(std.math.inf(f32))); try testCustomFloat(BFloat16, .{ .lossless = &[_]f32{ 0, -2, 1.0 / 128.0, -1e64, std.math.inf(f32) }, diff --git a/zml/nn.zig b/zml/nn.zig index 07f4c9f..c3120d2 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -1303,8 +1303,8 @@ test sampleTokensDynamic { const mod_bf16 = try zml.compileFn(allocator, fixupLogits, .{ Shape.init(.{ .voc = logits.len }, .bf16), DynamicSamplingStrategy.shapes(.bf16, 0) }, platform); defer mod_bf16.deinit(); - const boost = bf16.inf(); - const nerf = bf16.minusInf(); + const boost = bf16.inf; + const nerf = bf16.minus_inf; const logits_buff_2 = try zml.Buffer.fromArray(platform, [4]bf16{ boost, boost, bf16.fromF32(2), nerf }); const new_logits, const indices = mod_bf16.call(.{ logits_buff_2, try DynamicSamplingStrategy.makeBuffers(platform, .bf16, .{ .top_k = 4, .top_p = 0.9, .min_p = 0.1 }) }); try std.testing.expectEqual([_]i32{ 0, 1, 2, 3 }, try indices.getValue([4]i32));