diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index 2d28461..294243a 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -715,29 +715,29 @@ pub const LoadedExecutable = opaque { }; pub const BufferType = enum(c.PJRT_Buffer_Type) { - INVALID = c.PJRT_Buffer_Type_INVALID, - PRED = c.PJRT_Buffer_Type_PRED, - S8 = c.PJRT_Buffer_Type_S8, - S16 = c.PJRT_Buffer_Type_S16, - S32 = c.PJRT_Buffer_Type_S32, - S64 = c.PJRT_Buffer_Type_S64, - U8 = c.PJRT_Buffer_Type_U8, - U16 = c.PJRT_Buffer_Type_U16, - U32 = c.PJRT_Buffer_Type_U32, - U64 = c.PJRT_Buffer_Type_U64, - F16 = c.PJRT_Buffer_Type_F16, - F32 = c.PJRT_Buffer_Type_F32, - F64 = c.PJRT_Buffer_Type_F64, - BF16 = c.PJRT_Buffer_Type_BF16, - C64 = c.PJRT_Buffer_Type_C64, - C128 = c.PJRT_Buffer_Type_C128, - F8E5M2 = c.PJRT_Buffer_Type_F8E5M2, - F8E4M3FN = c.PJRT_Buffer_Type_F8E4M3FN, - F8E4M3B11FNUZ = c.PJRT_Buffer_Type_F8E4M3B11FNUZ, - F8E5M2FNUZ = c.PJRT_Buffer_Type_F8E5M2FNUZ, - F8E4M3FNUZ = c.PJRT_Buffer_Type_F8E4M3FNUZ, - S4 = c.PJRT_Buffer_Type_S4, - U4 = c.PJRT_Buffer_Type_U4, + invalid = c.PJRT_Buffer_Type_INVALID, + bool = c.PJRT_Buffer_Type_PRED, + i4 = c.PJRT_Buffer_Type_S4, + i8 = c.PJRT_Buffer_Type_S8, + i16 = c.PJRT_Buffer_Type_S16, + i32 = c.PJRT_Buffer_Type_S32, + i64 = c.PJRT_Buffer_Type_S64, + u4 = c.PJRT_Buffer_Type_U4, + u8 = c.PJRT_Buffer_Type_U8, + u16 = c.PJRT_Buffer_Type_U16, + u32 = c.PJRT_Buffer_Type_U32, + u64 = c.PJRT_Buffer_Type_U64, + f16 = c.PJRT_Buffer_Type_F16, + f32 = c.PJRT_Buffer_Type_F32, + f64 = c.PJRT_Buffer_Type_F64, + bf16 = c.PJRT_Buffer_Type_BF16, + c64 = c.PJRT_Buffer_Type_C64, + c128 = c.PJRT_Buffer_Type_C128, + f8e5m2 = c.PJRT_Buffer_Type_F8E5M2, + f8e4m3fn = c.PJRT_Buffer_Type_F8E4M3FN, + f8e4m3b11fnuz = c.PJRT_Buffer_Type_F8E4M3B11FNUZ, + f8e5m2fnuz = c.PJRT_Buffer_Type_F8E5M2FNUZ, + f8e4m3fnuz = c.PJRT_Buffer_Type_F8E4M3FNUZ, }; pub const MemoryLayoutType = enum(c.PJRT_Buffer_MemoryLayout_Type) { diff --git a/zml/buffer.zig b/zml/buffer.zig index 9d79c0b..3ce95fb 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -411,56 +411,14 @@ pub const Buffer = struct { pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType { return switch (dt) { - .bool => .PRED, - .f8e4m3b11fnuz => .F8E4M3B11FNUZ, - .f8e4m3fn => .F8E4M3FN, - .f8e4m3fnuz => .F8E4M3FNUZ, - .f8e5m2 => .F8E5M2, - .f8e5m2fnuz => .F8E5M2FNUZ, - .bf16 => .BF16, - .f16 => .F16, - .f32 => .F32, - .f64 => .F64, - .i8 => .S8, - .i4 => .S4, - .i16 => .S16, - .i32 => .S32, - .i64 => .S64, - .u4 => .U4, - .u8 => .U8, - .u16 => .U16, - .u32 => .U32, - .u64 => .U64, - .c64 => .C64, - .c128 => .C128, + inline else => |tag| @field(pjrt.BufferType, @tagName(tag)), }; } pub fn dtypeFromBufferType(pjrt_type: pjrt.BufferType) DataType { return switch (pjrt_type) { - .PRED => .bool, - .F8E4M3B11FNUZ => .f8e4m3b11fnuz, - .F8E4M3FN => .f8e4m3fn, - .F8E4M3FNUZ => .f8e4m3fnuz, - .F8E5M2 => .f8e5m2, - .F8E5M2FNUZ => .f8e5m2fnuz, - .BF16 => .bf16, - .F16 => .f16, - .F32 => .f32, - .F64 => .f64, - .S8 => .i8, - .S4 => .i4, - .S16 => .i16, - .S32 => .i32, - .S64 => .i64, - .U4 => .u4, - .U8 => .u8, - .U16 => .u16, - .U32 => .u32, - .U64 => .u64, - .C64 => .c64, - .C128 => .c128, - .INVALID => @panic("Found an invalid pjrt buffer"), + .invalid => @panic("Found an invalid pjrt buffer"), + inline else => |tag| @field(DataType, @tagName(tag)), }; } @@ -472,7 +430,7 @@ test bufferTypeFromDtype { inline for (@typeInfo(pjrt.BufferType).@"enum".fields) |field| { const dt: pjrt.BufferType = @enumFromInt(field.value); - if (dt == .INVALID) continue; + if (dt == .invalid) continue; try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt))); } }