zml: remove usingnamespace from floats.zig and related dependencies; note that incremental compilation does not improve overall build time due to linking overhead
This commit is contained in:
parent
42dee5d0e0
commit
3f36506f1c
@ -182,8 +182,8 @@ pub const DataType = enum(u8) {
|
||||
pub fn minValue(dtype: DataType) Data {
|
||||
return switch (dtype) {
|
||||
.bool => .{ .bool = false },
|
||||
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).zero()),
|
||||
inline .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).minusInf()),
|
||||
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).zero),
|
||||
inline .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).minus_inf),
|
||||
inline .f16, .f32, .f64 => |tag| @unionInit(Data, @tagName(tag), -std.math.inf(@FieldType(Data, @tagName(tag)))),
|
||||
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| @unionInit(Data, @tagName(tag), std.math.minInt(@FieldType(Data, @tagName(tag)))),
|
||||
inline else => |tag| @panic("Unsupported type: " ++ @tagName(tag)),
|
||||
@ -194,7 +194,7 @@ pub const DataType = enum(u8) {
|
||||
return switch (dtype) {
|
||||
.bool => .{ .bool = true },
|
||||
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)),
|
||||
inline .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).inf()),
|
||||
inline .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).inf),
|
||||
inline .f16, .f32, .f64 => |tag| @unionInit(Data, @tagName(tag), std.math.inf(@FieldType(Data, @tagName(tag)))),
|
||||
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| @unionInit(Data, @tagName(tag), std.math.maxInt(@FieldType(Data, @tagName(tag)))),
|
||||
inline .c64, .c128 => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)),
|
||||
|
||||
316
zml/floats.zig
316
zml/floats.zig
@ -9,43 +9,44 @@ fn allBitsOne(v: anytype) bool {
|
||||
return v == std.math.maxInt(@TypeOf(v));
|
||||
}
|
||||
|
||||
fn FloatType(sign_bits: u1, exponent_bits: u8, mantissa_bits: u8, innerT: type) type {
|
||||
const bit_size = sign_bits + exponent_bits + mantissa_bits;
|
||||
if (bit_size % 8 != 0) @compileError("FloatType should have a number of bits divisible by 8");
|
||||
|
||||
return packed struct(std.meta.Int(.unsigned, bit_size)) {
|
||||
const Self = @This();
|
||||
|
||||
mantissa: std.meta.Int(.unsigned, mantissa_bits),
|
||||
exponent: std.meta.Int(.unsigned, exponent_bits),
|
||||
sign: std.meta.Int(.unsigned, sign_bits),
|
||||
|
||||
pub fn zero() Self {
|
||||
return .{
|
||||
.sign = 0,
|
||||
.exponent = 0,
|
||||
.mantissa = 0,
|
||||
};
|
||||
fn FloatHelpers(Float: type) type {
|
||||
const info = @typeInfo(Float);
|
||||
const err_msg = "FloatHelpers expect a packed struct { mantissa: uXX, exponent: uXX, sign: u1}";
|
||||
if (info != .@"struct" or info.@"struct".backing_integer == null) {
|
||||
@compileError(err_msg);
|
||||
}
|
||||
comptime {
|
||||
for (info.@"struct".fields, &.{ "mantissa", "exponent", "sign" }) |field, expected_name| {
|
||||
if (!std.mem.eql(u8, field.name, expected_name))
|
||||
@compileError(err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn neg(self: Self) Self {
|
||||
return struct {
|
||||
const sign_bits: u8 = @typeInfo(@FieldType(Float, "sign")).int.bits;
|
||||
const mantissa_bits: u8 = @typeInfo(@FieldType(Float, "mantissa")).int.bits;
|
||||
const exponent_bits: u8 = @typeInfo(@FieldType(Float, "exponent")).int.bits;
|
||||
|
||||
pub const zero: Float = .{ .sign = 0, .exponent = 0, .mantissa = 0 };
|
||||
|
||||
pub fn neg(x: Float) Float {
|
||||
return .{
|
||||
.sign = self.sign ^ 1,
|
||||
.exponent = self.exponent,
|
||||
.mantissa = self.mantissa,
|
||||
.sign = x.sign ^ 1,
|
||||
.exponent = x.exponent,
|
||||
.mantissa = x.mantissa,
|
||||
};
|
||||
}
|
||||
|
||||
/// Lossy conversion from f32, similar to @floatCast
|
||||
pub fn fromF32(f: f32) Self {
|
||||
pub fn fromF32(f: f32) Float {
|
||||
const vf32: Float32 = @bitCast(f);
|
||||
const exp_bias = comptime Self.expBias();
|
||||
const exponent = @as(u16, vf32.exponent) + exp_bias -| Float32.expBias();
|
||||
const overflow = exponent > std.math.maxInt(std.meta.Int(.unsigned, exponent_bits));
|
||||
const exp_bias = comptime expBias();
|
||||
const exponent = @as(u16, vf32.exponent) + exp_bias -| FloatHelpers(Float32).expBias();
|
||||
const overflow = exponent > std.math.maxInt(@FieldType(Float, "exponent"));
|
||||
if (overflow) {
|
||||
return if (@hasDecl(Self, "inf")) {
|
||||
return if (vf32.sign == 0) Self.inf() else Self.minusInf();
|
||||
} else Self.nan();
|
||||
return if (@hasDecl(Float, "inf")) {
|
||||
return if (vf32.sign == 0) Float.inf else Float.minus_inf;
|
||||
} else Float.nan;
|
||||
}
|
||||
return .{
|
||||
.sign = vf32.sign,
|
||||
@ -55,31 +56,31 @@ fn FloatType(sign_bits: u1, exponent_bits: u8, mantissa_bits: u8, innerT: type)
|
||||
}
|
||||
|
||||
/// Lossless conversion to f32.
|
||||
pub fn toF32(self: Self) f32 {
|
||||
pub fn toF32(x: Float) f32 {
|
||||
var vf32: Float32 = undefined;
|
||||
if (@hasDecl(Self, "isInf") and self.isInf()) {
|
||||
return if (self.sign == 0) std.math.inf(f32) else -std.math.inf(f32);
|
||||
if (@hasDecl(Float, "isInf") and x.isInf()) {
|
||||
return if (x.sign == 0) std.math.inf(f32) else -std.math.inf(f32);
|
||||
}
|
||||
vf32 = .{
|
||||
.sign = self.sign,
|
||||
.exponent = if (self.exponent == 0) 0 else @intCast(@as(i16, self.exponent) + Float32.expBias() - Self.expBias()),
|
||||
.mantissa = self.f32Mantissa(),
|
||||
.sign = x.sign,
|
||||
.exponent = if (x.exponent == 0) 0 else @intCast(@as(i16, x.exponent) + f32_exp_bias - expBias()),
|
||||
.mantissa = f32Mantissa(x),
|
||||
};
|
||||
return @bitCast(vf32);
|
||||
}
|
||||
|
||||
fn truncMantissa(x: anytype) std.meta.FieldType(Self, .mantissa) {
|
||||
fn truncMantissa(x: anytype) @FieldType(Float, "mantissa") {
|
||||
@setRuntimeSafety(false);
|
||||
const off = @bitSizeOf(@TypeOf(x)) - mantissa_bits;
|
||||
return @intCast(x >> off);
|
||||
}
|
||||
|
||||
fn f32Mantissa(self: Self) std.meta.FieldType(Float32, .mantissa) {
|
||||
fn f32Mantissa(x: Float) @FieldType(Float32, "mantissa") {
|
||||
@setRuntimeSafety(false);
|
||||
const f32_mantissa_bits = @bitSizeOf(std.meta.FieldType(Float32, .mantissa));
|
||||
const Res = @FieldType(Float32, "mantissa");
|
||||
const f32_mantissa_bits = @bitSizeOf(Res);
|
||||
|
||||
const Res = std.meta.FieldType(Float32, .mantissa);
|
||||
return @shlExact(@as(Res, self.mantissa), f32_mantissa_bits - mantissa_bits);
|
||||
return @shlExact(@as(Res, x.mantissa), f32_mantissa_bits - mantissa_bits);
|
||||
}
|
||||
|
||||
fn expBias() u8 {
|
||||
@ -87,67 +88,112 @@ fn FloatType(sign_bits: u1, exponent_bits: u8, mantissa_bits: u8, innerT: type)
|
||||
}
|
||||
|
||||
pub fn format(
|
||||
self: @This(),
|
||||
float: Float,
|
||||
comptime fmt: []const u8,
|
||||
options: std.fmt.FormatOptions,
|
||||
writer: anytype,
|
||||
) !void {
|
||||
_ = options;
|
||||
if (fmt.len == 1 and fmt[0] == '_') {
|
||||
try writer.print("{{ .sign={}, .exp={}, .mantissa={} }}", .{ self.sign, self.exponent, self.mantissa });
|
||||
try writer.print("{{ .sign={}, .exp={}, .mantissa={} }}", .{ float.sign, float.exponent, float.mantissa });
|
||||
} else {
|
||||
try writer.print("{" ++ fmt ++ "}", .{self.toF32()});
|
||||
try writer.print("{" ++ fmt ++ "}", .{float.toF32()});
|
||||
}
|
||||
}
|
||||
|
||||
pub usingnamespace innerT;
|
||||
};
|
||||
}
|
||||
|
||||
const Float32 = FloatType(1, 8, 23, struct {});
|
||||
const Float64 = FloatType(1, 11, 52, struct {});
|
||||
pub const Float32 = packed struct(u32) {
|
||||
mantissa: u23,
|
||||
exponent: u8,
|
||||
sign: u1,
|
||||
|
||||
pub const Float8E4M3B11FNUZ = FloatType(1, 4, 3, struct {
|
||||
pub fn nan() Float8E4M3B11FNUZ {
|
||||
return .{
|
||||
.sign = 1,
|
||||
.exponent = 0,
|
||||
.mantissa = 0,
|
||||
};
|
||||
}
|
||||
const Helpers = FloatHelpers(@This());
|
||||
pub const zero = Helpers.zero;
|
||||
pub const neg = Helpers.neg;
|
||||
pub const fromF32 = Helpers.fromF32;
|
||||
pub const toF32 = Helpers.toF32;
|
||||
pub const format = Helpers.format;
|
||||
};
|
||||
|
||||
const f32_exp_bias = FloatHelpers(Float32).expBias();
|
||||
|
||||
pub const Float64 = packed struct(u64) {
|
||||
mantissa: u52,
|
||||
exponent: u11,
|
||||
sign: u1,
|
||||
|
||||
const Helpers = FloatHelpers(@This());
|
||||
pub const zero = Helpers.zero;
|
||||
pub const neg = Helpers.neg;
|
||||
pub const fromF32 = Helpers.fromF32;
|
||||
pub const toF32 = Helpers.toF32;
|
||||
pub const format = Helpers.format;
|
||||
};
|
||||
|
||||
pub const Float8E4M3B11FNUZ = packed struct(u8) {
|
||||
mantissa: u3,
|
||||
exponent: u4,
|
||||
sign: u1,
|
||||
|
||||
pub const nan: Float8E4M3B11FNUZ = .{
|
||||
.sign = 1,
|
||||
.exponent = 0,
|
||||
.mantissa = 0,
|
||||
};
|
||||
|
||||
pub fn isNan(self: Float8E4M3B11FNUZ) bool {
|
||||
return self.sign == 1 and self.exponent == 0 and self.mantissa == 0;
|
||||
}
|
||||
});
|
||||
|
||||
pub const Float8E4M3FN = FloatType(1, 4, 3, struct {
|
||||
pub fn nan() Float8E4M3FN {
|
||||
return .{
|
||||
.sign = 0,
|
||||
.exponent = std.math.maxInt(u4),
|
||||
.mantissa = std.math.maxInt(u3),
|
||||
};
|
||||
}
|
||||
const Helpers = FloatHelpers(@This());
|
||||
pub const zero = Helpers.zero;
|
||||
pub const neg = Helpers.neg;
|
||||
pub const fromF32 = Helpers.fromF32;
|
||||
pub const toF32 = Helpers.toF32;
|
||||
pub const format = Helpers.format;
|
||||
};
|
||||
|
||||
pub const Float8E4M3FN = packed struct(u8) {
|
||||
mantissa: u3,
|
||||
exponent: u4,
|
||||
sign: u1,
|
||||
|
||||
pub const nan: Float8E4M3FN = .{ .sign = 0, .exponent = std.math.maxInt(u4), .mantissa = std.math.maxInt(u3) };
|
||||
|
||||
pub fn isNan(self: Float8E4M3FN) bool {
|
||||
return allBitsOne(self.exponent) and allBitsOne(self.mantissa);
|
||||
}
|
||||
});
|
||||
const Helpers = FloatHelpers(@This());
|
||||
pub const zero = Helpers.zero;
|
||||
pub const neg = Helpers.neg;
|
||||
pub const fromF32 = Helpers.fromF32;
|
||||
pub const toF32 = Helpers.toF32;
|
||||
pub const format = Helpers.format;
|
||||
};
|
||||
|
||||
pub const Float8E4M3FNUZ = FloatType(1, 4, 3, struct {
|
||||
pub fn nan() Float8E4M3FNUZ {
|
||||
return .{
|
||||
.sign = 1,
|
||||
.exponent = 0,
|
||||
.mantissa = 0,
|
||||
};
|
||||
}
|
||||
pub const Float8E4M3FNUZ = packed struct(u8) {
|
||||
mantissa: u3,
|
||||
exponent: u4,
|
||||
sign: u1,
|
||||
|
||||
pub const nan: Float8E4M3FNUZ = .{
|
||||
.sign = 1,
|
||||
.exponent = 0,
|
||||
.mantissa = 0,
|
||||
};
|
||||
|
||||
pub fn isNan(self: Float8E4M3FNUZ) bool {
|
||||
return self.sign == 1 and self.exponent == 0 and self.mantissa == 0;
|
||||
}
|
||||
});
|
||||
|
||||
const Helpers = FloatHelpers(@This());
|
||||
pub const zero = Helpers.zero;
|
||||
pub const neg = Helpers.neg;
|
||||
pub const fromF32 = Helpers.fromF32;
|
||||
pub const toF32 = Helpers.toF32;
|
||||
pub const format = Helpers.format;
|
||||
};
|
||||
|
||||
test "Float8E4" {
|
||||
const test_case_e4: TestCase = .{
|
||||
@ -165,53 +211,63 @@ test "Float8E4" {
|
||||
}
|
||||
}
|
||||
|
||||
pub const Float8E5M2 = FloatType(1, 5, 2, struct {
|
||||
pub fn nan() Float8E5M2 {
|
||||
return .{
|
||||
.sign = 0,
|
||||
.exponent = std.math.maxInt(u5),
|
||||
.mantissa = 1,
|
||||
};
|
||||
}
|
||||
pub const Float8E5M2 = packed struct(u8) {
|
||||
mantissa: u2,
|
||||
exponent: u5,
|
||||
sign: u1,
|
||||
|
||||
pub const nan: Float8E5M2 = .{
|
||||
.sign = 0,
|
||||
.exponent = std.math.maxInt(u5),
|
||||
.mantissa = 1,
|
||||
};
|
||||
|
||||
pub fn isNan(self: Float8E5M2) bool {
|
||||
return allBitsOne(self.exponent) and self.mantissa != 0;
|
||||
}
|
||||
|
||||
pub fn minusInf() Float8E5M2 {
|
||||
return .{
|
||||
.sign = 1,
|
||||
.exponent = std.math.maxInt(u5),
|
||||
.mantissa = 0,
|
||||
};
|
||||
}
|
||||
pub const minus_inf: Float8E5M2 = .{
|
||||
.sign = 1,
|
||||
.exponent = std.math.maxInt(u5),
|
||||
.mantissa = 0,
|
||||
};
|
||||
|
||||
pub fn inf() Float8E5M2 {
|
||||
return .{
|
||||
.sign = 0,
|
||||
.exponent = std.math.maxInt(u5),
|
||||
.mantissa = 0,
|
||||
};
|
||||
}
|
||||
pub const inf: Float8E5M2 = .{
|
||||
.sign = 0,
|
||||
.exponent = std.math.maxInt(u5),
|
||||
.mantissa = 0,
|
||||
};
|
||||
|
||||
pub fn isInf(self: Float8E5M2) bool {
|
||||
return allBitsOne(self.exponent) and self.mantissa == 0;
|
||||
}
|
||||
});
|
||||
|
||||
pub const Float8E5M2FNUZ = FloatType(1, 5, 2, struct {
|
||||
pub fn nan() Float8E5M2FNUZ {
|
||||
return .{
|
||||
.sign = 1,
|
||||
.exponent = 0,
|
||||
.mantissa = 0,
|
||||
};
|
||||
}
|
||||
const Helpers = FloatHelpers(@This());
|
||||
pub const zero = Helpers.zero;
|
||||
pub const neg = Helpers.neg;
|
||||
pub const fromF32 = Helpers.fromF32;
|
||||
pub const toF32 = Helpers.toF32;
|
||||
pub const format = Helpers.format;
|
||||
};
|
||||
|
||||
pub const Float8E5M2FNUZ = packed struct(u8) {
|
||||
mantissa: u2,
|
||||
exponent: u5,
|
||||
sign: u1,
|
||||
|
||||
pub const nan: Float8E5M2FNUZ = .{ .sign = 1, .exponent = 0, .mantissa = 0 };
|
||||
|
||||
pub fn isNan(self: Float8E5M2FNUZ) bool {
|
||||
return self.sign == 1 and self.exponent == 0 and self.mantissa == 0;
|
||||
}
|
||||
});
|
||||
|
||||
const Helpers = FloatHelpers(@This());
|
||||
pub const zero = Helpers.zero;
|
||||
pub const neg = Helpers.neg;
|
||||
pub const fromF32 = Helpers.fromF32;
|
||||
pub const toF32 = Helpers.toF32;
|
||||
pub const format = Helpers.format;
|
||||
};
|
||||
|
||||
test "Float8E5" {
|
||||
const test_case_e5: TestCase = .{
|
||||
@ -223,39 +279,39 @@ test "Float8E5" {
|
||||
}
|
||||
}
|
||||
|
||||
pub const BFloat16 = FloatType(1, 8, 7, struct {
|
||||
pub fn nan() BFloat16 {
|
||||
return .{
|
||||
.sign = 0,
|
||||
.exponent = std.math.maxInt(u8),
|
||||
.mantissa = 1,
|
||||
};
|
||||
}
|
||||
pub const BFloat16 = packed struct(u16) {
|
||||
mantissa: u7,
|
||||
exponent: u8,
|
||||
sign: u1,
|
||||
|
||||
pub const nan: BFloat16 = .{ .sign = 0, .exponent = std.math.maxInt(u8), .mantissa = 1 };
|
||||
|
||||
pub fn isNan(self: BFloat16) bool {
|
||||
return allBitsOne(self.exponent) and self.mantissa != 0;
|
||||
}
|
||||
|
||||
pub fn minusInf() BFloat16 {
|
||||
return .{
|
||||
.sign = 1,
|
||||
.exponent = std.math.maxInt(u8),
|
||||
.mantissa = 0,
|
||||
};
|
||||
}
|
||||
pub const minus_inf: BFloat16 = .{
|
||||
.sign = 1,
|
||||
.exponent = std.math.maxInt(u8),
|
||||
.mantissa = 0,
|
||||
};
|
||||
|
||||
pub fn inf() BFloat16 {
|
||||
return .{
|
||||
.sign = 0,
|
||||
.exponent = std.math.maxInt(u8),
|
||||
.mantissa = 0,
|
||||
};
|
||||
}
|
||||
pub const inf: BFloat16 = .{
|
||||
.sign = 0,
|
||||
.exponent = std.math.maxInt(u8),
|
||||
.mantissa = 0,
|
||||
};
|
||||
|
||||
pub fn isInf(self: BFloat16) bool {
|
||||
return allBitsOne(self.exponent) and self.mantissa == 0;
|
||||
}
|
||||
});
|
||||
const Helpers = FloatHelpers(@This());
|
||||
pub const zero = Helpers.zero;
|
||||
pub const neg = Helpers.neg;
|
||||
pub const fromF32 = Helpers.fromF32;
|
||||
pub const toF32 = Helpers.toF32;
|
||||
pub const format = Helpers.format;
|
||||
};
|
||||
|
||||
test BFloat16 {
|
||||
// From https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Examples
|
||||
@ -263,8 +319,8 @@ test BFloat16 {
|
||||
try std.testing.expectEqual(BFloat16.fromF32(-2), BFloat16{ .sign = 1, .exponent = 127 + 1, .mantissa = 0 });
|
||||
try std.testing.expectEqual(BFloat16.fromF32(3.02344107628), BFloat16{ .sign = 0, .exponent = 127 + 1, .mantissa = 65 });
|
||||
try std.testing.expectEqual(BFloat16.fromF32(1.0 / 128.0), BFloat16{ .sign = 0, .exponent = 127 - 7, .mantissa = 0 });
|
||||
try std.testing.expectEqual(std.mem.toBytes(BFloat16.inf().neg()), [_]u8{ 0x80, 0xff });
|
||||
try std.testing.expectEqual(BFloat16.inf(), BFloat16.fromF32(std.math.inf(f32)));
|
||||
try std.testing.expectEqual(std.mem.toBytes(BFloat16.inf.neg()), [_]u8{ 0x80, 0xff });
|
||||
try std.testing.expectEqual(BFloat16.inf, BFloat16.fromF32(std.math.inf(f32)));
|
||||
|
||||
try testCustomFloat(BFloat16, .{
|
||||
.lossless = &[_]f32{ 0, -2, 1.0 / 128.0, -1e64, std.math.inf(f32) },
|
||||
|
||||
@ -1303,8 +1303,8 @@ test sampleTokensDynamic {
|
||||
|
||||
const mod_bf16 = try zml.compileFn(allocator, fixupLogits, .{ Shape.init(.{ .voc = logits.len }, .bf16), DynamicSamplingStrategy.shapes(.bf16, 0) }, platform);
|
||||
defer mod_bf16.deinit();
|
||||
const boost = bf16.inf();
|
||||
const nerf = bf16.minusInf();
|
||||
const boost = bf16.inf;
|
||||
const nerf = bf16.minus_inf;
|
||||
const logits_buff_2 = try zml.Buffer.fromArray(platform, [4]bf16{ boost, boost, bf16.fromF32(2), nerf });
|
||||
const new_logits, const indices = mod_bf16.call(.{ logits_buff_2, try DynamicSamplingStrategy.makeBuffers(platform, .bf16, .{ .top_k = 4, .top_p = 0.9, .min_p = 0.1 }) });
|
||||
try std.testing.expectEqual([_]i32{ 0, 1, 2, 3 }, try indices.getValue([4]i32));
|
||||
|
||||
Loading…
Reference in New Issue
Block a user