Rename PJRT BufferType to follow Zig and ZML naming conventions.

This commit is contained in:
Tarry Singh 2025-01-16 13:00:47 +00:00
parent 7324a49da3
commit 99a2001e63
2 changed files with 27 additions and 69 deletions

View File

@ -715,29 +715,29 @@ pub const LoadedExecutable = opaque {
}; };
pub const BufferType = enum(c.PJRT_Buffer_Type) { pub const BufferType = enum(c.PJRT_Buffer_Type) {
INVALID = c.PJRT_Buffer_Type_INVALID, invalid = c.PJRT_Buffer_Type_INVALID,
PRED = c.PJRT_Buffer_Type_PRED, bool = c.PJRT_Buffer_Type_PRED,
S8 = c.PJRT_Buffer_Type_S8, i4 = c.PJRT_Buffer_Type_S4,
S16 = c.PJRT_Buffer_Type_S16, i8 = c.PJRT_Buffer_Type_S8,
S32 = c.PJRT_Buffer_Type_S32, i16 = c.PJRT_Buffer_Type_S16,
S64 = c.PJRT_Buffer_Type_S64, i32 = c.PJRT_Buffer_Type_S32,
U8 = c.PJRT_Buffer_Type_U8, i64 = c.PJRT_Buffer_Type_S64,
U16 = c.PJRT_Buffer_Type_U16, u4 = c.PJRT_Buffer_Type_U4,
U32 = c.PJRT_Buffer_Type_U32, u8 = c.PJRT_Buffer_Type_U8,
U64 = c.PJRT_Buffer_Type_U64, u16 = c.PJRT_Buffer_Type_U16,
F16 = c.PJRT_Buffer_Type_F16, u32 = c.PJRT_Buffer_Type_U32,
F32 = c.PJRT_Buffer_Type_F32, u64 = c.PJRT_Buffer_Type_U64,
F64 = c.PJRT_Buffer_Type_F64, f16 = c.PJRT_Buffer_Type_F16,
BF16 = c.PJRT_Buffer_Type_BF16, f32 = c.PJRT_Buffer_Type_F32,
C64 = c.PJRT_Buffer_Type_C64, f64 = c.PJRT_Buffer_Type_F64,
C128 = c.PJRT_Buffer_Type_C128, bf16 = c.PJRT_Buffer_Type_BF16,
F8E5M2 = c.PJRT_Buffer_Type_F8E5M2, c64 = c.PJRT_Buffer_Type_C64,
F8E4M3FN = c.PJRT_Buffer_Type_F8E4M3FN, c128 = c.PJRT_Buffer_Type_C128,
F8E4M3B11FNUZ = c.PJRT_Buffer_Type_F8E4M3B11FNUZ, f8e5m2 = c.PJRT_Buffer_Type_F8E5M2,
F8E5M2FNUZ = c.PJRT_Buffer_Type_F8E5M2FNUZ, f8e4m3fn = c.PJRT_Buffer_Type_F8E4M3FN,
F8E4M3FNUZ = c.PJRT_Buffer_Type_F8E4M3FNUZ, f8e4m3b11fnuz = c.PJRT_Buffer_Type_F8E4M3B11FNUZ,
S4 = c.PJRT_Buffer_Type_S4, f8e5m2fnuz = c.PJRT_Buffer_Type_F8E5M2FNUZ,
U4 = c.PJRT_Buffer_Type_U4, f8e4m3fnuz = c.PJRT_Buffer_Type_F8E4M3FNUZ,
}; };
pub const MemoryLayoutType = enum(c.PJRT_Buffer_MemoryLayout_Type) { pub const MemoryLayoutType = enum(c.PJRT_Buffer_MemoryLayout_Type) {

View File

@ -411,56 +411,14 @@ pub const Buffer = struct {
pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType { pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType {
return switch (dt) { return switch (dt) {
.bool => .PRED, inline else => |tag| @field(pjrt.BufferType, @tagName(tag)),
.f8e4m3b11fnuz => .F8E4M3B11FNUZ,
.f8e4m3fn => .F8E4M3FN,
.f8e4m3fnuz => .F8E4M3FNUZ,
.f8e5m2 => .F8E5M2,
.f8e5m2fnuz => .F8E5M2FNUZ,
.bf16 => .BF16,
.f16 => .F16,
.f32 => .F32,
.f64 => .F64,
.i8 => .S8,
.i4 => .S4,
.i16 => .S16,
.i32 => .S32,
.i64 => .S64,
.u4 => .U4,
.u8 => .U8,
.u16 => .U16,
.u32 => .U32,
.u64 => .U64,
.c64 => .C64,
.c128 => .C128,
}; };
} }
pub fn dtypeFromBufferType(pjrt_type: pjrt.BufferType) DataType { pub fn dtypeFromBufferType(pjrt_type: pjrt.BufferType) DataType {
return switch (pjrt_type) { return switch (pjrt_type) {
.PRED => .bool, .invalid => @panic("Found an invalid pjrt buffer"),
.F8E4M3B11FNUZ => .f8e4m3b11fnuz, inline else => |tag| @field(DataType, @tagName(tag)),
.F8E4M3FN => .f8e4m3fn,
.F8E4M3FNUZ => .f8e4m3fnuz,
.F8E5M2 => .f8e5m2,
.F8E5M2FNUZ => .f8e5m2fnuz,
.BF16 => .bf16,
.F16 => .f16,
.F32 => .f32,
.F64 => .f64,
.S8 => .i8,
.S4 => .i4,
.S16 => .i16,
.S32 => .i32,
.S64 => .i64,
.U4 => .u4,
.U8 => .u8,
.U16 => .u16,
.U32 => .u32,
.U64 => .u64,
.C64 => .c64,
.C128 => .c128,
.INVALID => @panic("Found an invalid pjrt buffer"),
}; };
} }
@ -472,7 +430,7 @@ test bufferTypeFromDtype {
inline for (@typeInfo(pjrt.BufferType).@"enum".fields) |field| { inline for (@typeInfo(pjrt.BufferType).@"enum".fields) |field| {
const dt: pjrt.BufferType = @enumFromInt(field.value); const dt: pjrt.BufferType = @enumFromInt(field.value);
if (dt == .INVALID) continue; if (dt == .invalid) continue;
try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt))); try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt)));
} }
} }