Rename PJRT BufferType to follow Zig and ZML naming conventions.
This commit is contained in:
parent
7324a49da3
commit
99a2001e63
@ -715,29 +715,29 @@ pub const LoadedExecutable = opaque {
|
||||
};
|
||||
|
||||
pub const BufferType = enum(c.PJRT_Buffer_Type) {
|
||||
INVALID = c.PJRT_Buffer_Type_INVALID,
|
||||
PRED = c.PJRT_Buffer_Type_PRED,
|
||||
S8 = c.PJRT_Buffer_Type_S8,
|
||||
S16 = c.PJRT_Buffer_Type_S16,
|
||||
S32 = c.PJRT_Buffer_Type_S32,
|
||||
S64 = c.PJRT_Buffer_Type_S64,
|
||||
U8 = c.PJRT_Buffer_Type_U8,
|
||||
U16 = c.PJRT_Buffer_Type_U16,
|
||||
U32 = c.PJRT_Buffer_Type_U32,
|
||||
U64 = c.PJRT_Buffer_Type_U64,
|
||||
F16 = c.PJRT_Buffer_Type_F16,
|
||||
F32 = c.PJRT_Buffer_Type_F32,
|
||||
F64 = c.PJRT_Buffer_Type_F64,
|
||||
BF16 = c.PJRT_Buffer_Type_BF16,
|
||||
C64 = c.PJRT_Buffer_Type_C64,
|
||||
C128 = c.PJRT_Buffer_Type_C128,
|
||||
F8E5M2 = c.PJRT_Buffer_Type_F8E5M2,
|
||||
F8E4M3FN = c.PJRT_Buffer_Type_F8E4M3FN,
|
||||
F8E4M3B11FNUZ = c.PJRT_Buffer_Type_F8E4M3B11FNUZ,
|
||||
F8E5M2FNUZ = c.PJRT_Buffer_Type_F8E5M2FNUZ,
|
||||
F8E4M3FNUZ = c.PJRT_Buffer_Type_F8E4M3FNUZ,
|
||||
S4 = c.PJRT_Buffer_Type_S4,
|
||||
U4 = c.PJRT_Buffer_Type_U4,
|
||||
invalid = c.PJRT_Buffer_Type_INVALID,
|
||||
bool = c.PJRT_Buffer_Type_PRED,
|
||||
i4 = c.PJRT_Buffer_Type_S4,
|
||||
i8 = c.PJRT_Buffer_Type_S8,
|
||||
i16 = c.PJRT_Buffer_Type_S16,
|
||||
i32 = c.PJRT_Buffer_Type_S32,
|
||||
i64 = c.PJRT_Buffer_Type_S64,
|
||||
u4 = c.PJRT_Buffer_Type_U4,
|
||||
u8 = c.PJRT_Buffer_Type_U8,
|
||||
u16 = c.PJRT_Buffer_Type_U16,
|
||||
u32 = c.PJRT_Buffer_Type_U32,
|
||||
u64 = c.PJRT_Buffer_Type_U64,
|
||||
f16 = c.PJRT_Buffer_Type_F16,
|
||||
f32 = c.PJRT_Buffer_Type_F32,
|
||||
f64 = c.PJRT_Buffer_Type_F64,
|
||||
bf16 = c.PJRT_Buffer_Type_BF16,
|
||||
c64 = c.PJRT_Buffer_Type_C64,
|
||||
c128 = c.PJRT_Buffer_Type_C128,
|
||||
f8e5m2 = c.PJRT_Buffer_Type_F8E5M2,
|
||||
f8e4m3fn = c.PJRT_Buffer_Type_F8E4M3FN,
|
||||
f8e4m3b11fnuz = c.PJRT_Buffer_Type_F8E4M3B11FNUZ,
|
||||
f8e5m2fnuz = c.PJRT_Buffer_Type_F8E5M2FNUZ,
|
||||
f8e4m3fnuz = c.PJRT_Buffer_Type_F8E4M3FNUZ,
|
||||
};
|
||||
|
||||
pub const MemoryLayoutType = enum(c.PJRT_Buffer_MemoryLayout_Type) {
|
||||
|
||||
@ -411,56 +411,14 @@ pub const Buffer = struct {
|
||||
|
||||
pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType {
|
||||
return switch (dt) {
|
||||
.bool => .PRED,
|
||||
.f8e4m3b11fnuz => .F8E4M3B11FNUZ,
|
||||
.f8e4m3fn => .F8E4M3FN,
|
||||
.f8e4m3fnuz => .F8E4M3FNUZ,
|
||||
.f8e5m2 => .F8E5M2,
|
||||
.f8e5m2fnuz => .F8E5M2FNUZ,
|
||||
.bf16 => .BF16,
|
||||
.f16 => .F16,
|
||||
.f32 => .F32,
|
||||
.f64 => .F64,
|
||||
.i8 => .S8,
|
||||
.i4 => .S4,
|
||||
.i16 => .S16,
|
||||
.i32 => .S32,
|
||||
.i64 => .S64,
|
||||
.u4 => .U4,
|
||||
.u8 => .U8,
|
||||
.u16 => .U16,
|
||||
.u32 => .U32,
|
||||
.u64 => .U64,
|
||||
.c64 => .C64,
|
||||
.c128 => .C128,
|
||||
inline else => |tag| @field(pjrt.BufferType, @tagName(tag)),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn dtypeFromBufferType(pjrt_type: pjrt.BufferType) DataType {
|
||||
return switch (pjrt_type) {
|
||||
.PRED => .bool,
|
||||
.F8E4M3B11FNUZ => .f8e4m3b11fnuz,
|
||||
.F8E4M3FN => .f8e4m3fn,
|
||||
.F8E4M3FNUZ => .f8e4m3fnuz,
|
||||
.F8E5M2 => .f8e5m2,
|
||||
.F8E5M2FNUZ => .f8e5m2fnuz,
|
||||
.BF16 => .bf16,
|
||||
.F16 => .f16,
|
||||
.F32 => .f32,
|
||||
.F64 => .f64,
|
||||
.S8 => .i8,
|
||||
.S4 => .i4,
|
||||
.S16 => .i16,
|
||||
.S32 => .i32,
|
||||
.S64 => .i64,
|
||||
.U4 => .u4,
|
||||
.U8 => .u8,
|
||||
.U16 => .u16,
|
||||
.U32 => .u32,
|
||||
.U64 => .u64,
|
||||
.C64 => .c64,
|
||||
.C128 => .c128,
|
||||
.INVALID => @panic("Found an invalid pjrt buffer"),
|
||||
.invalid => @panic("Found an invalid pjrt buffer"),
|
||||
inline else => |tag| @field(DataType, @tagName(tag)),
|
||||
};
|
||||
}
|
||||
|
||||
@ -472,7 +430,7 @@ test bufferTypeFromDtype {
|
||||
|
||||
inline for (@typeInfo(pjrt.BufferType).@"enum".fields) |field| {
|
||||
const dt: pjrt.BufferType = @enumFromInt(field.value);
|
||||
if (dt == .INVALID) continue;
|
||||
if (dt == .invalid) continue;
|
||||
try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt)));
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user