2023-01-02 14:28:25 +00:00
const std = @import ( " std " ) ;
const testing = std . testing ;
const meta = @import ( " meta.zig " ) ;
const pjrt = @import ( " pjrt " ) ;
const pjrtx = @import ( " pjrtx.zig " ) ;
const Context = @import ( " context.zig " ) . Context ;
const Data = @import ( " dtype.zig " ) . Data ;
const DataType = @import ( " dtype.zig " ) . DataType ;
2023-01-27 14:35:11 +00:00
const HostBuffer = @import ( " hostbuffer.zig " ) . HostBuffer ;
const Platform = @import ( " platform.zig " ) . Platform ;
const Shape = @import ( " shape.zig " ) . Shape ;
// Force semantic analysis of every public declaration of `Buffer`,
// so that nested `test` blocks are discovered and compile errors surface.
test {
    testing.refAllDecls(Buffer);
}
/// Buffer is a multi-dimension array, whose memory is allocated on an accelerator.
///
/// * contains a handle that the ZML runtime can use to convert into a physical address, but there is no guarantee this address is visible from the CPU.
/// * can be loaded from disk directly onto the device with `zml.aio.loadBuffers`.
/// * can be created by calling `HostBuffer.toDevice(platform)`.
pub const Buffer = struct {
    _shape: Shape,
    _shards: Shape = undefined,
    _platform: Platform,
    _data: *pjrtx.Buffer,

    /// Copies the content of the given buffer from host memory to the accelerator memory.
    ///
    /// The copy uses PJRT `ImmutableUntilTransferCompletes` semantics: the transfer may
    /// still be in flight when this function returns, so the memory behind `buf.data`
    /// must stay valid and unmodified until the transfer has completed.
    pub fn from(platform: Platform, buf: HostBuffer) !Buffer {
        return fromWithSemantics(platform, buf, .ImmutableUntilTransferCompletes);
    }

    /// Shared implementation of the host-to-device constructors.
    ///
    /// `semantics` is forwarded to PJRT as `host_buffer_semantics` and decides how long
    /// PJRT may keep reading `buf.data` after this call returns. Constructors whose host
    /// data lives in their own stack frame must pass `.ImmutableOnlyDuringCall`.
    fn fromWithSemantics(platform: Platform, buf: HostBuffer, comptime semantics: anytype) !Buffer {
        const pjrt_buffer = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, .{
            .data = buf.data,
            .buffer_type = pjrtx.Buffer.BufferTypeFromDType(buf.shape().dtype()),
            .dims = buf.shape().dims(),
            .byte_strides = buf.strides(),
            .device = platform.getDevices()[0],
            .host_buffer_semantics = semantics,
        });
        return .{
            ._platform = platform,
            ._shape = buf.shape(),
            ._data = pjrt_buffer,
        };
    }

    /// Wraps a pre-existing `pjrt.Buffer` into a `zml.Buffer`.
    /// The shape (dims and element type) is queried from the PJRT buffer itself.
    pub fn fromPjrtBuffer(platform: Platform, pjrt_buffer: *pjrtx.Buffer) Buffer {
        return .{
            ._platform = platform,
            ._shape = _shapeFromPjrtBuffer(platform, pjrt_buffer),
            ._data = pjrt_buffer,
        };
    }

    /// Copies the given Zig slice to the accelerator memory and
    /// returns a Buffer with the given dimensions.
    ///
    /// The element type is deduced from `s`. `s` must hold exactly the number of bytes
    /// required by the shape built from `dimz`. As with `from`, the transfer may still
    /// be in flight on return, so `s` must outlive it.
    pub fn fromSlice(platform: Platform, dimz: anytype, s: anytype) !Buffer {
        const sh = Shape.init(dimz, DataType.fromSliceElementType(s));
        const bytes = std.mem.sliceAsBytes(s);
        // A size mismatch would make PJRT read past the end of `s`.
        meta.assert(sh.byteSize() == bytes.len, "fromSlice: shape {} needs {d} bytes, but the slice has {d} bytes", .{ sh, sh.byteSize(), bytes.len });
        return from(platform, HostBuffer.fromBytes(sh, bytes));
    }

    /// Copies the given Zig array to the accelerator memory and
    /// returns a Buffer using the array shape.
    pub fn fromArray(platform: Platform, arr: anytype) !Buffer {
        const host_buffer = HostBuffer.fromArray(&arr);
        // `arr` lives in this function's frame: PJRT must be done reading it before we return.
        return try fromWithSemantics(platform, host_buffer, .ImmutableOnlyDuringCall);
    }

    /// Creates a Buffer with a single element (rank 0) of type `dtype_`.
    pub fn scalar(platform: Platform, val: anytype, dtype_: DataType) !Buffer {
        const x = dtype_.constant(val);
        const host_buffer = HostBuffer.fromBytes(Shape.init(.{}, dtype_), x.constSlice());
        // `x` is a local: PJRT must be done reading it before we return.
        return try fromWithSemantics(platform, host_buffer, .ImmutableOnlyDuringCall);
    }

    /// Creates a Buffer of the given shape where every element is `val`.
    pub fn constant(platform: Platform, shape_: Shape, val: anytype) !Buffer {
        const x = shape_.dtype().constant(val);
        const host_buffer: HostBuffer = .{
            ._shape = shape_,
            // All-zero strides: every index reads the same single element.
            ._strides = [1]i64{0} ** Shape.MAX_RANK,
            .data = x.constSlice(),
        };
        // `x` is a local: PJRT must be done reading it before we return.
        return try fromWithSemantics(platform, host_buffer, .ImmutableOnlyDuringCall);
    }

    test constant {
        const zml = @import("zml.zig");
        const platform = zml.testing.env();

        const x = try constant(platform, Shape.init(.{ 4, 3, 2 }, .u16), 42);
        defer x.deinit();
        const y = try x.getValue([4 * 3 * 2]u16);
        try std.testing.expectEqual([_]u16{42} ** (4 * 3 * 2), y);
    }

    /// Creates a Buffer as a view of memory visible from the device,
    /// thus avoiding a copy.
    ///
    /// On CUDA, `buf` may also be a host-allocated slice, as such memory seems to be
    /// accessible from the device. Be careful though, as it requires a specific alignment.
    /// Also note that it might not work on all platforms,
    /// could lead to crashes and is considerably slower.
    pub fn asViewOf(platform: Platform, buf: HostBuffer) !Buffer {
        // { MAX_RANK-1, ..., 1, 0 }: the last `rank` entries describe a row-major layout
        // (last dimension is the most minor).
        const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
            var res: [Shape.MAX_RANK]i64 = undefined;
            for (0..Shape.MAX_RANK) |i| {
                res[i] = @intCast(Shape.MAX_RANK - i - 1);
            }
            break :blk res;
        };

        const pjrt_buffer = try platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{
            .data = buf.data,
            .element_type = pjrtx.Buffer.BufferTypeFromDType(buf.shape().dtype()),
            .dims = buf.shape().dims(),
            .device = platform.getDevices()[0],
            .layout = .{
                .Tiled = .{
                    .minor_to_major = minor_to_major[Shape.MAX_RANK - buf.shape().rank() ..],
                    .tile_dims = &.{},
                    .tile_dims_sizes = &.{},
                },
            },
        });
        return .{
            ._platform = platform,
            ._shape = buf.shape(),
            ._data = pjrt_buffer,
        };
    }

    /// Fetches the content of the given buffer into a stack variable of the given type.
    /// Asserts that the buffer byte size equals `@sizeOf(T)`.
    pub fn getValue(self: Buffer, T: type) !T {
        meta.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) });
        var res: T = undefined;
        try self._data.toHostBuffer(self._platform.pjrt_api, std.mem.asBytes(&res));
        return res;
    }

    /// Copies the content of the Buffer back to host, in the given buffer,
    /// and returns a new `HostBuffer` object with the same shape.
    /// The returned `HostBuffer` doesn't own the memory.
    pub fn toHost(self: Buffer, output: []u8) !HostBuffer {
        try self._data.toHostBuffer(self._platform.pjrt_api, output);
        return HostBuffer.fromBytes(self.shape(), output);
    }

    /// Copies the content of the Buffer to the host.
    /// The returned `HostBuffer` owns its memory, which is allocated with `allocator`.
    pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer {
        const output = try HostBuffer.empty(allocator, self.shape());
        // Don't leak the host allocation if the device-to-host copy fails.
        errdefer output.deinit(allocator);
        try self._data.toHostBuffer(self._platform.pjrt_api, @constCast(output.data));
        return output;
    }

    /// Frees the accelerator memory.
    /// Depending on the platform, the memory is typically not released to the OS
    /// but just marked as available in the memory pool.
    pub fn deinit(self: *const Buffer) void {
        self._data.deinit(self._platform.pjrt_api);
    }

    /// This Buffer shape.
    pub fn shape(self: Buffer) Shape {
        return self._shape;
    }

    /// This Buffer shape as a slice of dims.
    /// Takes `self` by pointer because the returned slice points into `self._shape`.
    pub fn dims(self: *const Buffer) []const i64 {
        return self._shape.dims();
    }

    /// This Buffer element type.
    pub fn dtype(self: Buffer) DataType {
        return self._shape.dtype();
    }

    /// This Buffer rank.
    pub fn rank(self: Buffer) u4 {
        return self._shape.rank();
    }

    /// Test helper: returns a new Buffer with the given tags.
    /// Allows to call `zml.testing.compileAndCall` when the tested
    /// functions requires tagged tensors.
    /// The returned Buffer shares the same device memory as `self`.
    pub fn withTags(self: Buffer, tags_: anytype) Buffer {
        var res = self;
        res._shape = self._shape.withTags(tags_);
        return res;
    }

    pub fn format(
        self: Buffer,
        comptime fmt: []const u8,
        options: std.fmt.FormatOptions,
        writer: anytype,
    ) !void {
        _ = fmt;
        _ = options;
        try writer.print("Buffer({_})", .{self._shape});
    }

    /// Builds the zml Shape matching the dims and element type reported by PJRT.
    /// Panics on `INVALID` element types.
    fn _shapeFromPjrtBuffer(platform: Platform, buf: *pjrtx.Buffer) Shape {
        const dt: DataType = switch (buf.getElementType(platform.pjrt_api)) {
            // Please keep the list exhaustive and in the same order as in DataType.
            .PRED => .bool,
            .F8E4M3B11FNUZ => .f8e4m3b11fnuz,
            .F8E4M3FN => .f8e4m3fn,
            .F8E4M3FNUZ => .f8e4m3fnuz,
            .F8E5M2 => .f8e5m2,
            .F8E5M2FNUZ => .f8e5m2fnuz,
            .BF16 => .bf16,
            .F16 => .f16,
            .F32 => .f32,
            .F64 => .f64,
            .S4 => .i4,
            .S8 => .i8,
            .S16 => .i16,
            .S32 => .i32,
            .S64 => .i64,
            .U4 => .u4,
            .U8 => .u8,
            .U16 => .u16,
            .U32 => .u32,
            .U64 => .u64,
            .C64 => .c64,
            .C128 => .c128,
            .INVALID => @panic("Can't handle INVALID Pjrt buffers."),
        };

        return Shape.init(buf.getDimensions(platform.pjrt_api), dt);
    }
};