// Source: Radix/zml/dtype.zig (Zig, 345 lines, 12 KiB)

const std = @import("std");
const floats = @import("floats.zig");
const C64 = std.math.Complex(f32);
const C128 = std.math.Complex(f64);
test {
    // Reference every declaration of this file so that the tests nested
    // inside its containers are compiled and run as well.
    const This = @This();
    std.testing.refAllDecls(This);
}
/// Scalar element types known to this module.
///
/// Every tag has a payload of the same name in `Data`; the size, alignment
/// and Zig-type queries below are all derived from that payload type.
pub const DataType = enum(u8) {
    bool,
    f8e4m3b11fnuz,
    f8e4m3fn,
    f8e4m3fnuz,
    f8e5m2,
    f8e5m2fnuz,
    bf16,
    f16,
    f32,
    f64,
    i4,
    i8,
    i16,
    i32,
    i64,
    u4,
    u8,
    u16,
    u32,
    u64,
    c64,
    c128,

    /// Returns the tag name as a null-terminated string.
    pub fn str(self: DataType) [:0]const u8 {
        return @tagName(self);
    }

    /// Coarse category of a `DataType`.
    pub const Class = enum(u8) {
        bool,
        float,
        integer,
        complex,
    };

    /// Returns the category (`bool`, `float`, `integer`, `complex`) of `self`.
    pub fn class(self: DataType) Class {
        return switch (self) {
            .bool => .bool,
            .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16, .f16, .f32, .f64 => .float,
            .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => .integer,
            .c64, .c128 => .complex,
        };
    }

    /// True for every signed and unsigned integer type.
    pub fn isInteger(self: DataType) bool {
        return self.class() == .integer;
    }

    /// True for every floating point type, including the 8-bit and bf16 formats.
    pub fn isFloat(self: DataType) bool {
        return self.class() == .float;
    }

    /// True for `c64` and `c128`.
    pub fn isComplex(self: DataType) bool {
        return self.class() == .complex;
    }

    /// Maps a Zig type to its `DataType`.
    /// Any type that is not listed below is rejected with a compile error.
    pub fn fromZigType(comptime T: type) DataType {
        return switch (T) {
            floats.Float8E4M3B11FNUZ => .f8e4m3b11fnuz,
            floats.Float8E4M3FN => .f8e4m3fn,
            floats.Float8E4M3FNUZ => .f8e4m3fnuz,
            floats.Float8E5M2 => .f8e5m2,
            floats.Float8E5M2FNUZ => .f8e5m2fnuz,
            floats.BFloat16 => .bf16,
            f16 => .f16,
            f32 => .f32,
            f64 => .f64,
            bool => .bool,
            i4 => .i4,
            i8 => .i8,
            i16 => .i16,
            i32 => .i32,
            i64 => .i64,
            u4 => .u4,
            u8 => .u8,
            u16 => .u16,
            u32 => .u32,
            u64 => .u64,
            C64 => .c64,
            C128 => .c128,
            else => @compileError("Unsupported Zig type: " ++ @typeName(T)),
        };
    }

    /// Returns the `DataType` of the elements `slice` points to.
    ///
    /// Accepts slices, many-item pointers, C pointers and single-item pointers
    /// to arrays. Anything else is rejected with a compile error.
    pub fn fromSliceElementType(slice: anytype) DataType {
        const type_info = @typeInfo(@TypeOf(slice));
        if (type_info != .Pointer) {
            @compileError("`initFromSlice` expects a slice, got " ++ @tagName(type_info));
        }
        const pointer = type_info.Pointer;
        return switch (pointer.size) {
            .Slice, .C, .Many => DataType.fromZigType(pointer.child),
            // A single-item pointer is only meaningful here when it points to an
            // array (e.g. `&[_]f32{ 1, 2 }`). Report any other pointee explicitly
            // instead of failing on an inactive union field access.
            .One => switch (@typeInfo(pointer.child)) {
                .Array => |array| DataType.fromZigType(array.child),
                else => @compileError("`initFromSlice` expects a slice or a pointer to an array, got " ++ @typeName(@TypeOf(slice))),
            },
        };
    }

    /// Returns the Zig type used as payload of `Data` for `dtype`.
    pub fn toZigType(comptime dtype: DataType) type {
        return std.meta.TagPayload(Data, dtype);
    }

    /// True for the signed integer types only.
    pub fn isSignedInt(dtype: DataType) bool {
        return switch (dtype) {
            .i4, .i8, .i16, .i32, .i64 => true,
            else => false,
        };
    }

    /// `@sizeOf` the payload type, in bytes.
    pub fn sizeOf(self: DataType) u16 {
        return switch (self) {
            inline else => |tag| @sizeOf(std.meta.TagPayload(Data, tag)),
        };
    }

    /// `@bitSizeOf` the payload type, in bits.
    pub fn bitSizeOf(self: DataType) u16 {
        return switch (self) {
            inline else => |tag| @bitSizeOf(std.meta.TagPayload(Data, tag)),
        };
    }

    /// `@alignOf` the payload type, in bytes.
    pub fn alignOf(self: DataType) u29 {
        return switch (self) {
            inline else => |tag| @alignOf(std.meta.TagPayload(Data, tag)),
        };
    }

    /// Try to find a type compatible with both dtype.
    ///
    /// Returns `null` when `a` and `b` belong to different classes.
    /// Within a class the wider type wins; when both are equally wide `a` is returned.
    pub fn resolvePeerType(a: DataType, b: DataType) ?DataType {
        if (a == b) {
            return a;
        }

        // only resolve types in the same class
        if (a.class() != b.class()) {
            return null;
        }

        // Compare widths in bits, not bytes: `i4`/`u4` occupy a whole byte, so a
        // byte comparison would treat them as wide as `i8`/`u8` and could pick
        // the 4-bit type depending on argument order.
        return if (a.bitSizeOf() >= b.bitSizeOf()) a else b;
    }

    test resolvePeerType {
        try std.testing.expectEqual(DataType.f16.resolvePeerType(.f16), .f16);
        try std.testing.expectEqual(DataType.f32.resolvePeerType(.f32), .f32);
        try std.testing.expectEqual(DataType.f16.resolvePeerType(.f32), .f32);
        try std.testing.expectEqual(DataType.f32.resolvePeerType(.f16), .f32);
        try std.testing.expectEqual(DataType.f32.resolvePeerType(.f64), .f64);
        try std.testing.expectEqual(DataType.f64.resolvePeerType(.f32), .f64);
        try std.testing.expectEqual(DataType.f32.resolvePeerType(.i32), null);
        try std.testing.expectEqual(DataType.c64.resolvePeerType(.c128), .c128);
        try std.testing.expectEqual(DataType.c128.resolvePeerType(.i32), null);
        try std.testing.expectEqual(DataType.c64.resolvePeerType(.f32), null);
        // Sub-byte integers must lose against byte-sized ones in both argument orders.
        try std.testing.expectEqual(DataType.i4.resolvePeerType(.i8), .i8);
        try std.testing.expectEqual(DataType.i8.resolvePeerType(.i4), .i8);
        try std.testing.expectEqual(DataType.u4.resolvePeerType(.u8), .u8);
    }

    /// Returns the value `0` converted to `dtype` (see `Data.init`).
    pub fn zero(dtype: DataType) Data {
        return Data.init(dtype, 0);
    }

    /// Returns the value `1` converted to `dtype` (see `Data.init`).
    pub fn one(dtype: DataType) Data {
        return Data.init(dtype, 1);
    }

    /// Returns the smallest value of `dtype`.
    ///
    /// Floats that have an infinity return negative infinity, integers return
    /// `std.math.minInt`, and `bool` returns `false`.
    /// Panics for the complex types.
    pub fn minValue(dtype: DataType) Data {
        return switch (dtype) {
            .bool => .{ .bool = false },
            // NOTE(review): these formats yield `zero()` rather than a negative
            // value — confirm this is the intended "minimum".
            inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz => |tag| @unionInit(Data, @tagName(tag), std.meta.FieldType(Data, tag).zero()),
            inline .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), std.meta.FieldType(Data, tag).minusInf()),
            inline .f16, .f32, .f64 => |tag| @unionInit(Data, @tagName(tag), -std.math.inf(std.meta.FieldType(Data, tag))),
            inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| @unionInit(Data, @tagName(tag), std.math.minInt(std.meta.FieldType(Data, tag))),
            inline else => |tag| @panic("Unsupported type: " ++ @tagName(tag)),
        };
    }

    /// Returns the largest value of `dtype`.
    ///
    /// Floats that have an infinity return positive infinity, integers return
    /// `std.math.maxInt`, and `bool` returns `true`.
    /// Panics for the complex types and for the 8-bit float formats handled
    /// by the first panic branch below.
    pub fn maxValue(dtype: DataType) Data {
        return switch (dtype) {
            .bool => .{ .bool = true },
            inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)),
            inline .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), std.meta.FieldType(Data, tag).inf()),
            inline .f16, .f32, .f64 => |tag| @unionInit(Data, @tagName(tag), std.math.inf(std.meta.FieldType(Data, tag))),
            inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| @unionInit(Data, @tagName(tag), std.math.maxInt(std.meta.FieldType(Data, tag))),
            inline .c64, .c128 => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)),
        };
    }

    /// Converts `value` to a `Data` of type `dtype`; shorthand for `Data.init`.
    pub fn constant(dtype: DataType, value: anytype) Data {
        return Data.init(dtype, value);
    }
};
/// A single scalar value tagged with its `DataType`.
pub const Data = union(DataType) {
    bool: bool,
    f8e4m3b11fnuz: floats.Float8E4M3B11FNUZ,
    f8e4m3fn: floats.Float8E4M3FN,
    f8e4m3fnuz: floats.Float8E4M3FNUZ,
    f8e5m2: floats.Float8E5M2,
    f8e5m2fnuz: floats.Float8E5M2FNUZ,
    bf16: floats.BFloat16,
    f16: f16,
    f32: f32,
    f64: f64,
    i4: i4,
    i8: i8,
    i16: i16,
    i32: i32,
    i64: i64,
    u4: u4,
    u8: u8,
    u16: u16,
    u32: u32,
    u64: u64,
    c64: C64,
    c128: C128,

    /// Creates `Data` from a `value`.
    ///
    /// Conversion rules, by target `dtype`:
    /// - `bool`: accepts a bool, or any integer/float which is compared against zero.
    /// - floats: accept integers and floats; the value is converted with
    ///   `@floatFromInt` / `@floatCast` (8-bit floats and bf16 go through `fromF32`).
    /// - integers: accept integers only; a value that does not fit panics.
    /// - complex: accept `C64` or `C128`, converting each component with `@floatCast`.
    ///
    /// If the `dtype` and `@TypeOf(value)` are incompatible, a panic occurs.
    pub fn init(dtype: DataType, value: anytype) Data {
        const T = @TypeOf(value);
        const Ti = @typeInfo(T);
        return switch (dtype) {
            .bool => switch (Ti) {
                .Bool => .{ .bool = value },
                .ComptimeInt, .Int, .ComptimeFloat, .Float => .{ .bool = value != 0 },
                else => @panic("Could not create Data of type bool from value of type " ++ @typeName(T)),
            },
            inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16 => |tag| switch (Ti) {
                .ComptimeInt, .Int => @unionInit(Data, @tagName(tag), std.meta.FieldType(Data, tag).fromF32(@floatFromInt(value))),
                .ComptimeFloat, .Float => @unionInit(Data, @tagName(tag), std.meta.FieldType(Data, tag).fromF32(@floatCast(value))),
                // Report the dtype that was actually requested, not always "bf16".
                else => @panic("Could not create Data of type " ++ @tagName(tag) ++ " from value of type " ++ @typeName(T)),
            },
            inline .f16, .f32, .f64 => |tag| switch (Ti) {
                .ComptimeInt, .Int => @unionInit(Data, @tagName(tag), @floatFromInt(value)),
                .ComptimeFloat, .Float => @unionInit(Data, @tagName(tag), @floatCast(value)),
                else => @panic("Could not create Data of type " ++ @tagName(tag) ++ " from value of type " ++ @typeName(T)),
            },
            inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| switch (Ti) {
                .ComptimeInt => blk: {
                    const OutT = std.meta.FieldType(Data, tag);
                    if (value >= std.math.minInt(OutT) and value <= std.math.maxInt(OutT)) {
                        break :blk @unionInit(Data, @tagName(tag), @intCast(value));
                    } else {
                        @panic("Could not create Data of type " ++ @tagName(tag) ++ " from value of type " ++ @typeName(T));
                    }
                },
                // `@intCast` is unchecked in ReleaseFast; use a checked cast so an
                // out-of-range runtime value panics in every build mode.
                .Int => @unionInit(
                    Data,
                    @tagName(tag),
                    std.math.cast(std.meta.FieldType(Data, tag), value) orelse
                        @panic("Could not create Data of type " ++ @tagName(tag) ++ " from value of type " ++ @typeName(T)),
                ),
                else => @panic("Could not create Data of type " ++ @tagName(tag) ++ " from value of type " ++ @typeName(T)),
            },
            .c64 => switch (T) {
                C64 => .{ .c64 = value },
                C128 => .{ .c64 = .{ .re = @floatCast(value.re), .im = @floatCast(value.im) } },
                else => @panic("Could not create Data of type c64 from value of type " ++ @typeName(T)),
            },
            .c128 => switch (T) {
                C64 => .{ .c128 = .{ .re = @floatCast(value.re), .im = @floatCast(value.im) } },
                C128 => .{ .c128 = value },
                else => @panic("Could not create Data of type c128 from value of type " ++ @typeName(T)),
            },
        };
    }

    test init {
        try std.testing.expectEqual(20.0, Data.init(.f16, 20).f16);
        try std.testing.expectEqual(20.5, Data.init(.f16, 20.5).f16);
        try std.testing.expectEqual(20, Data.init(.f16, @as(u8, 20)).f16);
        try std.testing.expectEqual(-20, Data.init(.f16, @as(i8, -20)).f16);
        try std.testing.expectEqual(2000.5, Data.init(.f16, @as(f32, 2000.5)).f16);
        try std.testing.expectEqual(true, Data.init(.bool, true).bool);
        try std.testing.expectEqual(10, Data.init(.u8, 10).u8);
        try std.testing.expectEqual(10, Data.init(.u8, @as(u16, 10)).u8);
        try std.testing.expectEqual(10, Data.init(.i8, 10).i8);
        try std.testing.expectEqual(10, Data.init(.i8, @as(u16, 10)).i8);
        try std.testing.expectEqual(-10, Data.init(.i8, -10).i8);
        try std.testing.expectEqual(-10, Data.init(.i8, @as(i16, -10)).i8);
        try std.testing.expectEqual(C64.init(1, 2), Data.init(.c64, C64.init(1, 2)).c64);
        try std.testing.expectEqual(C64.init(1, 2), Data.init(.c64, C128.init(1, 2)).c64);
        try std.testing.expectEqual(C128.init(1, 2), Data.init(.c128, C128.init(1, 2)).c128);
        try std.testing.expectEqual(C128.init(1, 2), Data.init(.c128, C64.init(1, 2)).c128);
    }

    /// Returns the `DataType` of the value currently stored.
    pub fn dataType(self: Data) DataType {
        return std.meta.activeTag(self);
    }

    /// Returns the bytes of the stored value.
    /// The slice points into `data`, so it is only valid while `data` is.
    pub fn constSlice(data: *const Data) []const u8 {
        return switch (data.*) {
            inline else => |*value| std.mem.asBytes(value),
        };
    }

    /// Converts the stored value to `T`.
    ///
    /// Supported conversions:
    /// - `bool` payload to `bool`,
    /// - any float payload to any Zig float type (via `@floatCast`),
    /// - any integer payload to any Zig integer type able to hold the value.
    ///
    /// Every other combination panics, including an integer value that does
    /// not fit in `T`.
    pub fn as(self: Data, comptime T: type) T {
        // TODO allow more lossless conversions
        switch (@typeInfo(T)) {
            // Check the active tag first: reading `.bool` from a non-bool value
            // would be an inactive union field access rather than the panic below.
            .Bool => switch (self) {
                .bool => |v| return v,
                else => {},
            },
            .Float => switch (self) {
                inline .f16, .f32, .f64 => |v| return @floatCast(v),
                inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16 => |v| return @floatCast(v.toF32()),
                else => {},
            },
            .Int => switch (self) {
                // Checked cast: a value that does not fit in `T` falls through to
                // the panic below instead of hitting an unchecked `@intCast`.
                inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |v| if (std.math.cast(T, v)) |converted| return converted,
                else => {},
            },
            else => {},
        }
        std.debug.panic("Unsupported conversion {} -> {s}", .{ self.dataType(), @typeName(T) });
    }

    test as {
        try std.testing.expectEqual(true, Data.init(.bool, true).as(bool));
        try std.testing.expectEqual(@as(i32, 10), Data.init(.u8, 10).as(i32));
        try std.testing.expectEqual(@as(u8, 7), Data.init(.i64, 7).as(u8));
        try std.testing.expectEqual(@as(f32, 20.0), Data.init(.f16, 20).as(f32));
    }
};
/// Returns the number of mantissa bits of a floating point `dtype`.
/// Panics when `dtype` is not one of the float types listed below.
pub fn mantissaSize(dtype: DataType) usize {
    switch (dtype) {
        // 8-bit formats: the digit after `m` in the name is the mantissa width.
        .f8e5m2, .f8e5m2fnuz => return 2,
        .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz => return 3,
        // 16-bit formats
        .bf16 => return 7,
        .f16 => return 10,
        // 32 and 64-bit formats
        .f32 => return 23,
        .f64 => return 52,
        else => @panic("Can't get mantissa size for a non-float dtype"),
    }
}