2023-01-02 14:28:25 +00:00
const std = @import ( " std " ) ;
2024-09-10 09:14:28 +00:00
const asynk = @import ( " async " ) ;
const stdx = @import ( " stdx " ) ;
2023-01-02 14:28:25 +00:00
const DataType = @import ( " dtype.zig " ) . DataType ;
2023-01-27 14:35:11 +00:00
const HostBuffer = @import ( " hostbuffer.zig " ) . HostBuffer ;
2024-09-10 09:14:28 +00:00
const pjrt = @import ( " pjrtx.zig " ) ;
2023-01-27 14:35:11 +00:00
const Platform = @import ( " platform.zig " ) . Platform ;
const Shape = @import ( " shape.zig " ) . Shape ;
test {
    // Reference all declarations (including Buffer's nested ones) so their tests run.
    std.testing.refAllDecls(@This());
    std.testing.refAllDecls(Buffer);
}
/// Scoped logger: all log output from this file is tagged with the `.zml` scope.
const log = std.log.scoped(.zml);
/// Buffer is a multi-dimension array, whose memory is allocated on an accelerator.
///
/// * contains a handle that the ZML runtime can use to convert into a physical address, but there is no guarantee this address is visible from the CPU.
/// * can be created by loading weights from disk directly to the device with `zml.aio.loadBuffers`
/// * can be created by calling `HostBuffer.toDevice(platform)`.
pub const Buffer = struct {
    /// The kind of memory a Buffer can live in.
    pub const Memory = enum {
        host,
        host_pinned,
        device,

        /// Maps this memory kind to the corresponding PJRT memory kind.
        pub fn toPjrtMemory(self: Memory) pjrt.Memory.Kind {
            return switch (self) {
                .host => .unpinned_host,
                .host_pinned => .pinned_host,
                .device => .device,
            };
        }

        /// Returns the PJRT name of this memory kind (e.g. "pinned_host").
        pub fn pjrtName(self: Memory) []const u8 {
            return @tagName(self.toPjrtMemory());
        }
    };
    // Logical shape (dims, dtype, tags) of the whole buffer, across all shards.
    _shape: Shape,
    // Handle to the PJRT API owning the underlying device buffers.
    _api: *const pjrt.Api,
    // Per-device PJRT buffers: one entry per partition. A replicated buffer
    // holds the same content on every device.
    _shards: Shards,

    pub const MAX_NUM_SHARDS: u8 = Platform.MAX_NUM_DEVICES;
    pub const Shards = std.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS);

    /// Options for the `from` family of constructors.
    pub const FromOptions = struct {
        // When true, block until the host -> device transfer completes.
        wait: bool = true,
        // Target memory kind; null means the device's default memory.
        memory: ?pjrt.Memory.Kind = null,
    };
    /// Copies the content of the given buffer from host memory to the accelerator memory.
    /// `opts` controls the destination memory kind and whether to wait for the transfer.
    pub fn from(platform: Platform, host_buffer: HostBuffer, opts: FromOptions) !Buffer {
        var res: Buffer = .{
            ._api = platform.pjrt_api,
            ._shape = host_buffer.shape(),
            ._shards = .{},
        };

        // We shard only on the first axis so that the chunks are still contiguous.
        // TODO: support more advanced sharding specs
        stdx.debug.assert(platform.sharding().num_replicas == 1, "ZML doesn't support num_replicas > 1 for now, got: {}", .{platform.sharding()});

        const sharding_ax: ?u3 = std.simd.firstTrue(host_buffer.shape()._sharding_info);
        const n_partitions = platform.sharding().num_partitions;
        // Size (in elements) of each partition along the sharded axis; 0 when unsharded.
        const chunk_size = if (sharding_ax) |ax| cs: {
            // This kind of sharding error should be detected earlier on.
            stdx.debug.assert(@rem(host_buffer.dim(ax), n_partitions) == 0, "Buffer.from({}) expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ host_buffer, ax, n_partitions });
            break :cs @divExact(host_buffer.dim(ax), n_partitions);
        } else 0;
        const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
        const byte_strides = host_buffer.strides();

        const devices = platform.getDevices();
        for (0..n_partitions) |i| {
            // If no sharding is found, the given buffer is replicated on all devices.
            const buf = if (sharding_ax) |ax| buf: {
                const start: i64 = @as(i64, @intCast(i)) * chunk_size;
                break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size });
            } else host_buffer;

            var args = pjrt.Client.BufferFromHostBufferArgs{
                .data = buf._data,
                .buffer_type = buffer_type,
                .dims = buf.shape().dims(),
                .byte_strides = byte_strides,
                .host_buffer_semantics = .ImmutableUntilTransferCompletes,
            };
            if (opts.memory) |memory_kind| {
                // Look up the requested memory kind among this device's addressable memories.
                const memories = try devices[i].addressableMemories(platform.pjrt_api);
                const memory = for (memories) |m| {
                    const kind = m.kind(platform.pjrt_api);
                    if (kind == memory_kind) break m;
                } else return error.NotFound;
                args.memory = memory;
            } else {
                args.device = devices[i];
            }

            // The "done with host buffer" event is released immediately; readiness is
            // tracked per-shard via awaitt() below when opts.wait is set.
            const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, args);
            if (event) |ev| {
                ev.deinit(platform.pjrt_api);
            }

            res._shards.appendAssumeCapacity(pjrt_buffer);
        }

        if (opts.wait) {
            res = try res.awaitt();
        }
        return res;
    }
2024-12-25 17:14:44 +00:00
pub fn awaitt ( self : Buffer ) ! Buffer {
for ( self . _shards . constSlice ( ) ) | buffer | {
if ( buffer . getReadyEvent ( self . _api ) ) | ev | {
try ev . await_ ( self . _api ) ;
}
}
return self ;
}
2023-02-24 17:33:14 +00:00
/// Wraps pre-exisiting `pjrt.Buffer` shards into one `zml.Buffer`.
2023-03-21 10:50:39 +00:00
pub fn fromPjrtBuffers ( platform : Platform , shape_ : Shape , pjrt_buffers : [ ] const * pjrt . Buffer ) Buffer {
2023-06-21 14:45:14 +00:00
stdx . debug . assert ( pjrt_buffers . len < = MAX_NUM_SHARDS , " ZML doesn't support having more than {} shards. Received {} shards for one buffer. " , . { MAX_NUM_SHARDS , pjrt_buffers . len } ) ;
stdx . debug . assert ( pjrt_buffers . len > 0 , " fromPjrtBuffers expects at least one buffer, got 0. " , . { } ) ;
2023-02-24 17:33:14 +00:00
var shards : Shards = . { } ;
shards . appendSliceAssumeCapacity ( pjrt_buffers ) ;
2023-01-02 14:28:25 +00:00
return . {
2023-02-24 17:33:14 +00:00
. _api = platform . pjrt_api ,
2023-03-21 10:50:39 +00:00
. _shape = shape_ ,
2023-02-24 17:33:14 +00:00
. _shards = shards ,
2023-01-02 14:28:25 +00:00
} ;
}
/// Copies the given Zig slice to the accelerator memory and
/// return a Buffer with the given dimensions.
2023-01-27 14:35:11 +00:00
pub fn fromSlice ( platform : Platform , dimz : anytype , s : anytype ) ! Buffer {
2023-01-02 14:28:25 +00:00
const sh = Shape . init ( dimz , DataType . fromSliceElementType ( s ) ) ;
2024-12-25 17:14:44 +00:00
return from ( platform , HostBuffer . fromBytes ( sh , std . mem . sliceAsBytes ( s ) ) , . { } ) ;
2023-01-02 14:28:25 +00:00
}
2024-05-15 17:54:52 +00:00
/// Copies the given Zig slice to the accelerator memory and
/// return a Buffer with the given dimensions.
pub fn fromBytes ( platform : Platform , sh : Shape , data : [ ] const u8 ) ! Buffer {
2024-12-25 17:14:44 +00:00
return from ( platform , HostBuffer . fromBytes ( sh , data ) , . { } ) ;
2024-05-15 17:54:52 +00:00
}
2023-01-02 14:28:25 +00:00
/// Copies the given Zig array to the accelerator memory and
/// return a Buffer using the array shape.
2023-01-27 14:35:11 +00:00
pub fn fromArray ( platform : Platform , arr : anytype ) ! Buffer {
2023-01-02 14:28:25 +00:00
const host_buffer = HostBuffer . fromArray ( & arr ) ;
2024-12-25 17:14:44 +00:00
return try from ( platform , host_buffer , . { . wait = true } ) ;
}
/// Copies the given Zig slice to the accelerator memory and
/// return a Buffer with the given dimensions.
pub fn fromSliceOpts ( platform : Platform , dimz : anytype , s : anytype , opts : FromOptions ) ! Buffer {
const sh = Shape . init ( dimz , DataType . fromSliceElementType ( s ) ) ;
return from ( platform , HostBuffer . fromBytes ( sh , std . mem . sliceAsBytes ( s ) ) , opts ) ;
}
    /// Copies the given bytes to the accelerator memory and
    /// return a Buffer with the given shape.
    pub fn fromBytesOpts(platform: Platform, sh: Shape, data: []const u8, opts: FromOptions) !Buffer {
        return from(platform, HostBuffer.fromBytes(sh, data), opts);
    }
    /// Copies the given Zig array to the accelerator memory and
    /// return a Buffer using the array shape.
    pub fn fromArrayOpts(platform: Platform, arr: anytype, opts: FromOptions) !Buffer {
        // `arr` is a by-value copy living on this frame's stack.
        // NOTE(review): if opts.wait is false, PJRT may still read from this stack
        // memory after we return — confirm callers always wait in that case.
        const host_buffer = HostBuffer.fromArray(&arr);
        return try from(platform, host_buffer, opts);
    }
    /// Reinterprets a buffer living in host-visible memory as a `HostBuffer`, without copying.
    /// The returned HostBuffer aliases the device allocation and does not own it.
    pub fn asPinnedHostBuffer(self: Buffer) HostBuffer {
        // TODO restore assert
        // const memory = self.getMemory().kind(self._api);
        // stdx.debug.assert(memory == .pinned_host, "asPinnedHostBuffer({}) expects a buffer allocated on host memory, got {}. see `toMemory`", .{ self, memory });
        const ptr: [*]u8 = @ptrCast(self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable);
        return HostBuffer.fromBytes(self._shape, ptr[0..self._shape.byteSize()]);
    }
2023-01-02 14:28:25 +00:00
/// Creates a Buffer with a single element.
2023-01-27 14:35:11 +00:00
pub fn scalar ( platform : Platform , val : anytype , dtype_ : DataType ) ! Buffer {
2023-01-02 14:28:25 +00:00
const x = dtype_ . constant ( val ) ;
const host_buffer = HostBuffer . fromBytes ( Shape . init ( . { } , dtype_ ) , x . constSlice ( ) ) ;
2024-12-25 17:14:44 +00:00
return try from ( platform , host_buffer , . { . wait = true } ) ;
2023-01-27 14:35:11 +00:00
}
    /// Creates a Buffer with a single element repeated manytime.
    pub fn constant(platform: Platform, shape_: Shape, val: anytype) !Buffer {
        // Time the upload and log transfers slower than 100ms.
        var start = try std.time.Timer.start();
        defer {
            const duration_ms = stdx.math.divFloat(f32, start.read(), std.time.ns_per_ms);
            if (duration_ms > 100) {
                const size_gb = stdx.math.divFloat(f32, shape_.byteSize(), 1024 * 1024 * 1024);
                log.info("Wrote constant({_}) to device ({d:.2}Gb) in {d:.0}ms: {d:.2}Gb/s", .{ shape_, size_gb, duration_ms, size_gb / duration_ms * 1000 });
            }
        }

        // Convert val to the requested dtype.
        const x = shape_.dtype().constant(val);

        const byte_size = shape_.dtype().sizeOf();
        const max_bytes = 1024;
        // Naive version for scalars and buffers with long last axis:
        // all-zero strides make PJRT read the same element for every position.
        if (shape_.rank() < 1 or byte_size * shape_.dim(-1) > max_bytes) {
            const host_buffer: HostBuffer = .{
                ._shape = shape_,
                ._strides = @splat(0),
                ._data = x.constSlice().ptr,
            };
            return try from(platform, host_buffer, .{ .wait = true });
        }

        // To speed up copies, duplicate the scalar value into a vector,
        // so that PJRT can copy row by row.
        // Because this is respecting the shape, it won't work if the last axis is too big.
        // If this becomes an issue, we should create a new intermediary Buffer by splitting last axis into { n, max_bytes }
        // so that the trick works, and then reshape it.
        // We could also handle sharded constant directly in this function to avoid having to create too big arrays.
        var bytes: [max_bytes]u8 align(64) = undefined;
        // Every stride is 0 except the innermost axis: PJRT re-reads the same row for all outer indices.
        var strides = [1]i64{0} ** Shape.MAX_RANK;
        strides[shape_.rank() - 1] = byte_size;
        switch (byte_size) {
            inline 1, 2, 4, 8, 16 => |b| {
                // Fill one full row of the last axis with the encoded scalar.
                const Int = std.meta.Int(.unsigned, b * 8);
                const x_as_int: Int = @bitCast(x.constSlice()[0..b].*);
                const bytes_as_int: [*]Int = @ptrCast(&bytes);
                @memset(bytes_as_int[0..@intCast(shape_.dim(-1))], x_as_int);
            },
            else => unreachable,
        }
        const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, ._data = &bytes };
        return try from(platform, host_buffer, .{ .wait = true });
    }
    test constant {
        const zml = @import("zml.zig");
        const platform = zml.testing.env();
        // A u16 constant buffer must round-trip to host with every element == 42.
        const x = try constant(platform, Shape.init(.{ 4, 3, 2 }, .u16), 42);
        const y = try x.getValue([4 * 3 * 2]u16);
        try std.testing.expectEqual([_]u16{42} ** (4 * 3 * 2), y);
    }
    /// Creates a Buffer as a view of host memory visible from the device,
    /// thus avoiding a copy.
    ///
    /// Be careful though, as it requires a specific alignment
    /// and it might not work on all platforms,
    /// could lead to crashes and operations on the buffer will be slower.
    /// Tested on Cuda 12.4.
    pub fn asViewOfHostBuffer(platform: Platform, buf: HostBuffer) Buffer {
        // NOTE(review): @constCast assumes PJRT never writes through this view — confirm.
        return asViewOfDeviceBuffer(platform, buf.shape(), null, @constCast(buf._data));
    }
    /// Creates a Buffer from a pointer into device memory.
    /// This allows to interface with other libraries producing buffers.
    pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const pjrt.Stream, device_data: *anyopaque) Buffer {
        // Default row-major layout: axes listed from most-major to most-minor.
        const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
            var res: [Shape.MAX_RANK]i64 = undefined;
            for (0..Shape.MAX_RANK) |i| {
                res[i] = @intCast(Shape.MAX_RANK - i - 1);
            }
            break :blk res;
        };
        const pjrt_buffer = platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{
            .data = device_data,
            .element_type = bufferTypeFromDtype(shape_.dtype()),
            .dims = shape_.dims(),
            // TODO: exposes sharding in the API.
            .device = platform.getDevices()[0],
            .layout = .{
                .tiled = .{
                    // Keep only the entries relevant for this rank.
                    .minor_to_major = minor_to_major[Shape.MAX_RANK - shape_.rank()..],
                    .tile_dims = &.{},
                    .tile_dims_sizes = &.{},
                },
            },
            .stream = stream,
        }) catch @panic("failed to createViewOfDeviceBuffer");

        var shards: Shards = .{};
        shards.appendAssumeCapacity(pjrt_buffer);
        return .{
            ._api = platform.pjrt_api,
            ._shape = shape_,
            ._shards = shards,
        };
    }
/// Fetches the content of the given buffer into a stack variable of the given type.
pub fn getValue ( self : Buffer , T : type ) ! T {
2023-06-21 14:45:14 +00:00
stdx . debug . assert ( self . _shape . byteSize ( ) = = @sizeOf ( T ) , " Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes " , . { self , self . _shape . byteSize ( ) , @typeName ( T ) , @sizeOf ( T ) } ) ;
2023-01-02 14:28:25 +00:00
var res : T = undefined ;
2023-06-21 14:45:14 +00:00
stdx . debug . internalAssert ( ! self . hasShardedAxis ( ) , " TODO: support sharded Buffer -> Host transfer " , . { } ) ;
2023-05-09 12:44:56 +00:00
const maybe_event = try self . _shards . get ( 0 ) . toHostBuffer ( self . _api , std . mem . asBytes ( & res ) ) ;
if ( maybe_event ) | event | {
try event . await_ ( self . _api ) ;
}
2023-01-02 14:28:25 +00:00
return res ;
}
/// Copies the content of the Buffer back to host, in the given buffer,
/// and return a new `HostBuffer` object with the same shape.
/// The returned `HostBuffer` doesn't own the memory.
pub fn toHost ( self : Buffer , output : [ ] u8 ) ! HostBuffer {
2023-06-21 14:45:14 +00:00
stdx . debug . internalAssert ( ! self . hasShardedAxis ( ) , " TODO: support sharded Buffer -> Host transfer " , . { } ) ;
2023-05-09 12:44:56 +00:00
const maybe_event = try self . _shards . get ( 0 ) . toHostBuffer ( self . _api , output ) ;
if ( maybe_event ) | event | {
try event . await_ ( self . _api ) ;
}
2023-01-02 14:28:25 +00:00
return HostBuffer . fromBytes ( self . shape ( ) , output ) ;
}
    /// Copies the content of the Buffer to the host.
    /// The returned `HostBuffer` does own the memory.
    pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer {
        const output = try HostBuffer.empty(allocator, self.shape());
        // NOTE(review): if toHostBuffer or await_ below fails, `output` leaks —
        // consider an errdefer releasing it.
        stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
        const maybe_event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.bytes()));
        if (maybe_event) |event| {
            try event.await_(self._api);
        }
        return output;
    }
/// Frees the accelerator memory.
/// Depending on the platform, the memory is typically not released to the OS
/// but just marked as available in the memory pool.
pub fn deinit ( self : * const Buffer ) void {
2023-02-24 17:33:14 +00:00
for ( self . _shards . constSlice ( ) ) | buffer | {
buffer . deinit ( self . _api ) ;
}
2023-01-02 14:28:25 +00:00
}
    /// This Buffer shape.
    pub fn shape(self: Buffer) Shape {
        return self._shape;
    }

    /// This Buffer shape as a slice of dims.
    /// Takes self by pointer: the returned slice aliases `self._shape`,
    /// so it must not outlive the Buffer it was taken from.
    pub fn dims(self: *const Buffer) []const i64 {
        return self._shape.dims();
    }

    /// This Buffer element type.
    pub fn dtype(self: Buffer) DataType {
        return self._shape.dtype();
    }

    /// This Buffer rank.
    pub fn rank(self: Buffer) u4 {
        return self._shape.rank();
    }
/// Test helper: returns a new Buffer with the given tags.
/// Allows to call `zml.testing.compileAndCall` when the tested
/// functions requires tagged tensors.
pub fn withTags ( self : Buffer , tags_ : anytype ) Buffer {
var res = self ;
res . _shape = self . _shape . withTags ( tags_ ) ;
return res ;
}
    /// Custom std.fmt formatter: prints this buffer as `Buffer(<shape>)`.
    pub fn format(
        self: Buffer,
        comptime fmt: []const u8,
        options: std.fmt.FormatOptions,
        writer: anytype,
    ) !void {
        _ = fmt;
        _ = options;
        try writer.print("Buffer({_})", .{self._shape});
    }
2024-09-10 09:14:28 +00:00
pub fn getMemory ( self : Buffer ) * const pjrt . Memory {
const shard = self . _shards . get ( 0 ) ;
return shard . memory ( self . _api ) ;
}
2023-02-24 17:33:14 +00:00
fn hasShardedAxis ( self : Buffer ) bool {
if ( self . _shards . len = = 1 ) return false ;
return @reduce ( . Or , self . _shape . _sharding_info ) ;
2023-01-02 14:28:25 +00:00
}
2024-12-25 17:14:44 +00:00
pub fn copyToMemory ( self : Buffer , memory : * const pjrt . Memory ) ! Buffer {
var new_shards : Buffer . Shards = . { } ;
for ( self . _shards . slice ( ) ) | shard | {
const new_shard = try shard . copyToMemory ( self . _api , memory ) ;
new_shards . appendAssumeCapacity ( new_shard ) ;
}
return Buffer { . _shape = self . _shape , . _shards = new_shards , . _api = self . _api } ;
}
    /// Options for `uninitialized`. (Name typo is kept: it is public API.)
    pub const UnitializedOptions = struct {
        // Target memory kind; null means the device's default memory.
        memory: ?pjrt.Memory.Kind = null,
    };

    /// Allocates a device buffer of the given shape without initializing its content.
    /// The caller owns the returned Buffer and must `deinit` it.
    pub fn uninitialized(platform: Platform, shape_: Shape, opts: UnitializedOptions) !Buffer {
        var res: Buffer = .{
            ._api = platform.pjrt_api,
            ._shape = shape_,
            ._shards = .{},
        };
        // Don't leak the shards allocated before a failure.
        errdefer for (res._shards.slice()) |shard| {
            shard.deinit(platform.pjrt_api);
        };
        // Default row-major layout: axes listed from most-major to most-minor.
        const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
            var minor_to_major: [Shape.MAX_RANK]i64 = undefined;
            for (0..Shape.MAX_RANK) |i| {
                minor_to_major[i] = @intCast(Shape.MAX_RANK - i - 1);
            }
            break :blk minor_to_major;
        };
        // TODO: support more advanced sharding specs
        stdx.debug.assert(platform.sharding().num_replicas == 1, "ZML doesn't support num_replicas > 1 for now, got: {}", .{platform.sharding()});
        const sharding_ax: ?u3 = std.simd.firstTrue(shape_._sharding_info);
        const n_partitions = platform.sharding().num_partitions;
        // Shape of each per-device shard (the sharded axis is divided across devices).
        const shard_shape = if (sharding_ax) |ax| s: {
            // This kind of sharding error should be detected earlier on.
            stdx.debug.assert(@rem(shape_.dim(ax), n_partitions) == 0, "Buffer.uninitialized() expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ ax, n_partitions });
            const shard_shape = shape_.set(ax, @divExact(shape_.dim(ax), n_partitions));
            break :s shard_shape;
        } else shape_;
        const buffer_type = bufferTypeFromDtype(shape_.dtype());
        const devices = platform.getDevices();
        for (0..n_partitions) |i| {
            var args = pjrt.Client.CreateUninitializedBufferArgs{
                .dims = shard_shape.dims(),
                .element_type = buffer_type,
                .layout = .{
                    .tiled = .{
                        .minor_to_major = minor_to_major[Shape.MAX_RANK - shape_.rank()..],
                        .tile_dims = &.{},
                        .tile_dims_sizes = &.{},
                    },
                },
            };
            if (opts.memory) |memory_kind| {
                // Look up the requested memory kind among this device's addressable memories.
                const memories = try devices[i].addressableMemories(platform.pjrt_api);
                const memory = for (memories) |m| {
                    const kind = m.kind(platform.pjrt_api);
                    if (kind == memory_kind) break m;
                } else return error.NotFound;
                args.memory = memory;
            } else {
                args.device = devices[i];
            }
            const pjrt_buffer = try platform.pjrt_client.createUnitializedBuffer(platform.pjrt_api, args);
            res._shards.appendAssumeCapacity(pjrt_buffer);
        }
        return res;
    }
2023-01-02 14:28:25 +00:00
} ;
/// Maps a zml DataType to the PJRT buffer type sharing the same tag name.
pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType {
    return switch (dt) {
        // DataType and pjrt.BufferType share tag names, so map by name at comptime.
        inline else => |tag| @field(pjrt.BufferType, @tagName(tag)),
    };
}
/// Maps a PJRT buffer type back to the zml DataType sharing the same tag name.
/// Panics on `.invalid`, which has no DataType counterpart.
pub fn dtypeFromBufferType(pjrt_type: pjrt.BufferType) DataType {
    return switch (pjrt_type) {
        .invalid => @panic("Found an invalid pjrt buffer"),
        inline else => |tag| @field(DataType, @tagName(tag)),
    };
}
test bufferTypeFromDtype {
    // The two enums must round-trip in both directions.
    inline for (@typeInfo(DataType).@"enum".fields) |field| {
        const dt: DataType = @enumFromInt(field.value);
        try std.testing.expectEqual(dt, dtypeFromBufferType(bufferTypeFromDtype(dt)));
    }
    inline for (@typeInfo(pjrt.BufferType).@"enum".fields) |field| {
        const dt: pjrt.BufferType = @enumFromInt(field.value);
        // .invalid has no DataType counterpart; skip it.
        if (dt == .invalid) continue;
        try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt)));
    }
}