Rename PJRT BufferType to follow Zig and ZML naming conventions.
This commit is contained in:
parent
7324a49da3
commit
99a2001e63
@ -715,29 +715,29 @@ pub const LoadedExecutable = opaque {
|
|||||||
};
|
};
|
||||||
|
|
||||||
pub const BufferType = enum(c.PJRT_Buffer_Type) {
|
pub const BufferType = enum(c.PJRT_Buffer_Type) {
|
||||||
INVALID = c.PJRT_Buffer_Type_INVALID,
|
invalid = c.PJRT_Buffer_Type_INVALID,
|
||||||
PRED = c.PJRT_Buffer_Type_PRED,
|
bool = c.PJRT_Buffer_Type_PRED,
|
||||||
S8 = c.PJRT_Buffer_Type_S8,
|
i4 = c.PJRT_Buffer_Type_S4,
|
||||||
S16 = c.PJRT_Buffer_Type_S16,
|
i8 = c.PJRT_Buffer_Type_S8,
|
||||||
S32 = c.PJRT_Buffer_Type_S32,
|
i16 = c.PJRT_Buffer_Type_S16,
|
||||||
S64 = c.PJRT_Buffer_Type_S64,
|
i32 = c.PJRT_Buffer_Type_S32,
|
||||||
U8 = c.PJRT_Buffer_Type_U8,
|
i64 = c.PJRT_Buffer_Type_S64,
|
||||||
U16 = c.PJRT_Buffer_Type_U16,
|
u4 = c.PJRT_Buffer_Type_U4,
|
||||||
U32 = c.PJRT_Buffer_Type_U32,
|
u8 = c.PJRT_Buffer_Type_U8,
|
||||||
U64 = c.PJRT_Buffer_Type_U64,
|
u16 = c.PJRT_Buffer_Type_U16,
|
||||||
F16 = c.PJRT_Buffer_Type_F16,
|
u32 = c.PJRT_Buffer_Type_U32,
|
||||||
F32 = c.PJRT_Buffer_Type_F32,
|
u64 = c.PJRT_Buffer_Type_U64,
|
||||||
F64 = c.PJRT_Buffer_Type_F64,
|
f16 = c.PJRT_Buffer_Type_F16,
|
||||||
BF16 = c.PJRT_Buffer_Type_BF16,
|
f32 = c.PJRT_Buffer_Type_F32,
|
||||||
C64 = c.PJRT_Buffer_Type_C64,
|
f64 = c.PJRT_Buffer_Type_F64,
|
||||||
C128 = c.PJRT_Buffer_Type_C128,
|
bf16 = c.PJRT_Buffer_Type_BF16,
|
||||||
F8E5M2 = c.PJRT_Buffer_Type_F8E5M2,
|
c64 = c.PJRT_Buffer_Type_C64,
|
||||||
F8E4M3FN = c.PJRT_Buffer_Type_F8E4M3FN,
|
c128 = c.PJRT_Buffer_Type_C128,
|
||||||
F8E4M3B11FNUZ = c.PJRT_Buffer_Type_F8E4M3B11FNUZ,
|
f8e5m2 = c.PJRT_Buffer_Type_F8E5M2,
|
||||||
F8E5M2FNUZ = c.PJRT_Buffer_Type_F8E5M2FNUZ,
|
f8e4m3fn = c.PJRT_Buffer_Type_F8E4M3FN,
|
||||||
F8E4M3FNUZ = c.PJRT_Buffer_Type_F8E4M3FNUZ,
|
f8e4m3b11fnuz = c.PJRT_Buffer_Type_F8E4M3B11FNUZ,
|
||||||
S4 = c.PJRT_Buffer_Type_S4,
|
f8e5m2fnuz = c.PJRT_Buffer_Type_F8E5M2FNUZ,
|
||||||
U4 = c.PJRT_Buffer_Type_U4,
|
f8e4m3fnuz = c.PJRT_Buffer_Type_F8E4M3FNUZ,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const MemoryLayoutType = enum(c.PJRT_Buffer_MemoryLayout_Type) {
|
pub const MemoryLayoutType = enum(c.PJRT_Buffer_MemoryLayout_Type) {
|
||||||
|
|||||||
@ -411,56 +411,14 @@ pub const Buffer = struct {
|
|||||||
|
|
||||||
pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType {
|
pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType {
|
||||||
return switch (dt) {
|
return switch (dt) {
|
||||||
.bool => .PRED,
|
inline else => |tag| @field(pjrt.BufferType, @tagName(tag)),
|
||||||
.f8e4m3b11fnuz => .F8E4M3B11FNUZ,
|
|
||||||
.f8e4m3fn => .F8E4M3FN,
|
|
||||||
.f8e4m3fnuz => .F8E4M3FNUZ,
|
|
||||||
.f8e5m2 => .F8E5M2,
|
|
||||||
.f8e5m2fnuz => .F8E5M2FNUZ,
|
|
||||||
.bf16 => .BF16,
|
|
||||||
.f16 => .F16,
|
|
||||||
.f32 => .F32,
|
|
||||||
.f64 => .F64,
|
|
||||||
.i8 => .S8,
|
|
||||||
.i4 => .S4,
|
|
||||||
.i16 => .S16,
|
|
||||||
.i32 => .S32,
|
|
||||||
.i64 => .S64,
|
|
||||||
.u4 => .U4,
|
|
||||||
.u8 => .U8,
|
|
||||||
.u16 => .U16,
|
|
||||||
.u32 => .U32,
|
|
||||||
.u64 => .U64,
|
|
||||||
.c64 => .C64,
|
|
||||||
.c128 => .C128,
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dtypeFromBufferType(pjrt_type: pjrt.BufferType) DataType {
|
pub fn dtypeFromBufferType(pjrt_type: pjrt.BufferType) DataType {
|
||||||
return switch (pjrt_type) {
|
return switch (pjrt_type) {
|
||||||
.PRED => .bool,
|
.invalid => @panic("Found an invalid pjrt buffer"),
|
||||||
.F8E4M3B11FNUZ => .f8e4m3b11fnuz,
|
inline else => |tag| @field(DataType, @tagName(tag)),
|
||||||
.F8E4M3FN => .f8e4m3fn,
|
|
||||||
.F8E4M3FNUZ => .f8e4m3fnuz,
|
|
||||||
.F8E5M2 => .f8e5m2,
|
|
||||||
.F8E5M2FNUZ => .f8e5m2fnuz,
|
|
||||||
.BF16 => .bf16,
|
|
||||||
.F16 => .f16,
|
|
||||||
.F32 => .f32,
|
|
||||||
.F64 => .f64,
|
|
||||||
.S8 => .i8,
|
|
||||||
.S4 => .i4,
|
|
||||||
.S16 => .i16,
|
|
||||||
.S32 => .i32,
|
|
||||||
.S64 => .i64,
|
|
||||||
.U4 => .u4,
|
|
||||||
.U8 => .u8,
|
|
||||||
.U16 => .u16,
|
|
||||||
.U32 => .u32,
|
|
||||||
.U64 => .u64,
|
|
||||||
.C64 => .c64,
|
|
||||||
.C128 => .c128,
|
|
||||||
.INVALID => @panic("Found an invalid pjrt buffer"),
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -472,7 +430,7 @@ test bufferTypeFromDtype {
|
|||||||
|
|
||||||
inline for (@typeInfo(pjrt.BufferType).@"enum".fields) |field| {
|
inline for (@typeInfo(pjrt.BufferType).@"enum".fields) |field| {
|
||||||
const dt: pjrt.BufferType = @enumFromInt(field.value);
|
const dt: pjrt.BufferType = @enumFromInt(field.value);
|
||||||
if (dt == .INVALID) continue;
|
if (dt == .invalid) continue;
|
||||||
try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt)));
|
try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user