const builtin = @import("builtin");
const std = @import("std");
const stdx = @import("stdx");
const meta = @import("meta.zig");
const mlir = @import("mlir.zig");
const ops = @import("ops.zig");
const module = @import("module.zig");
const Location = mlir.Location;
const CompilationContext = module.CompilationContext;
const Shape = @import("shape.zig").Shape;
const Buffer = @import("buffer.zig").Buffer;
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const Data = @import("dtype.zig").Data;
const DataType = @import("dtype.zig").DataType;
const Platform = @import("platform.zig").Platform;
const EnumLiteral = @TypeOf(.enum_literal);
const dialect = struct {
    const stablehlo = @import("mlir/dialects").stablehlo;
};
const assert = std.debug.assert;
const testing = std.testing;
const scoped_log = std.log.scoped(.@"zml/tensor");
test {
    std.testing.refAllDecls(Tensor);
}
/// Represents an abstract Tensor object, which can be the input,
/// output, weight or activations of a neural network.
/// Tensors are abstract in the sense that they only represent a computation,
/// not a specific memory buffer.
/// The Tensor namespace contains most of the linear algebra needed to
/// represent mathematical operations.
/// More operations are available in the `zml.nn` and `zml.torch` namespaces.
pub const Tensor = struct {
    _shape: Shape,
    _id: _Id,
    _donation: _Donation = .no_buffer,
    pub const _Donation = union(enum) { no_buffer, input_buffer, arg: u16 };
    pub const _Id = union(enum) { mlir: mlir.Value, buffer_id: u64, arg_id: u64 };
    pub const MAX_RANK = Shape.MAX_RANK;
    /// Returns the current compilation context.
    pub fn getContext(self: Tensor) *CompilationContext {
        _ = self;
        return CompilationContext.current();
    }
    pub fn format(
        self: Tensor,
        comptime fmt: []const u8,
        options: std.fmt.FormatOptions,
        writer: anytype,
    ) !void {
        _ = options;
        const bare_fmt = fmt.len == 1 and fmt[0] == '_';
        try writer.print(if (bare_fmt) "{_}" else "Tensor({_})", .{self._shape});
    }
    /// Returns the shape of a Tensor.
    pub fn shape(self: Tensor) Shape {
        return self._shape;
    }
    /// Returns the datatype of a Tensor.
    pub fn dtype(self: Tensor) DataType {
        return self._shape.dtype();
    }
    /// Returns the rank of a Tensor.
    pub inline fn rank(self: Tensor) u4 {
        return self._shape.rank();
    }
    /// Returns the number of elements of a Tensor.
    pub fn count(self: Tensor) usize {
        return self._shape.count();
    }
    /// Returns the size in bytes of a Tensor.
    pub fn byteSize(self: Tensor) usize {
        return self._shape.byteSize();
    }
    /// Internal use
    ///
    /// Creates a tensor from a Shape and an mlir.Value.
    pub fn _result(sh: Shape, val: mlir.Value) Tensor {
        const res: Tensor = .{
            ._shape = sh,
            ._id = .{ .mlir = val },
        };
        if (builtin.mode == .Debug) {
            // Check that the MLIR value actually has the same shape.
            const other = fromMlirValue(val);
            stdx.debug.internalAssert(sh.eql(other._shape), "Created a {} from Mlir value but expected {}", .{ other._shape, res._shape });
        }
        return res;
    }
    /// Creates a Tensor from an mlir.Value.
    ///
    /// The shape is derived from the type of the mlir.Value.
    pub fn fromMlirValue(val: mlir.Value) Tensor {
        const ranked_tensor = val.getType().as(mlir.RankedTensorType);
        const n = ranked_tensor.getRank();
        stdx.debug.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK });
        var sh: Shape = .{ ._dtype = mlir.ext.Type.toDType(ranked_tensor.getElementType()) };
        for (0..n) |i| {
            sh._dims.appendAssumeCapacity(ranked_tensor.getDimension(i));
        }
        sh._tags.resize(n) catch unreachable;
        return .{ ._shape = sh, ._id = .{ .mlir = val } };
    }
    /// Returns the dimension of axis 'axis_'.
    ///
    /// 'axis_' can be an integer or a tag.
    pub fn dim(self: Tensor, axis_: anytype) i64 {
        return self._shape.dim(axis_);
    }
    /// Returns the dimensions of a Tensor as a slice.
    pub fn dims(self: *const Tensor) []const i64 {
        return self._shape.dims();
    }
    /// Returns the index of axis 'axis_'.
    ///
    /// 'axis_' can be an integer or a tag.
    pub fn axis(self: Tensor, axis_: anytype) u3 {
        return self._shape.axis(axis_);
    }
    /// Returns the indices of each of the given axes.
    ///
    /// 'axes_' can be integers or tags.
    pub fn axes(self: Tensor, axes_: anytype) std.BoundedArray(u3, Tensor.MAX_RANK) {
        return self._shape.axes(axes_);
    }
    /// Returns a Tensor tagged with the tags in 'tagz'.
    pub fn withTags(self: Tensor, tagz: anytype) Tensor {
        var res = self;
        res._shape = self._shape.withTags(tagz);
        return res;
    }
    /// Returns a Tensor partially tagged with the tags in 'tagz'.
    ///
    /// If 'tagz' is of length n, the n last dimensions of the Tensor will be tagged.
    pub fn withPartialTags(self: Tensor, tagz: anytype) Tensor {
        var res = self;
        res._shape = self._shape.withPartialTags(tagz);
        return res;
    }
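    // Usage sketch (illustrative; the tag names are hypothetical and an active
    // CompilationContext is assumed, as in the tests in this file):
    //     const x = Tensor.constant(.{ 2, 3, 5 }, .{ .f32 = 0.0 });
    //     const tagged = x.withTags(.{ .b, .h, .w }); // tags all three axes
    //     const partial = x.withPartialTags(.{ .h, .w }); // tags only the last two
    //     _ = tagged.dim(.w); // axes can now be addressed by tag -> 5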
    pub fn withSharding(self: Tensor, axes_: anytype) Tensor {
        return switch (self._id) {
            .arg_id, .mlir => {
                const ctx = self.getContext();
                var res = self;
                res._shape = self._shape.withSharding(axes_);
                const sharding = ctx.getShardingAttr(res._shape);
                const op = dialect.stablehlo.custom_call(
                    ctx.mlirCtx(),
                    &.{self.value()},
                    .{
                        .call_target_name = "Sharding",
                        .has_side_effect = false,
                        .addional_attributes = &.{.{ "mhlo.sharding", sharding.asAttr() }},
                        .api_version = .original,
                    },
                    &.{self.value().getType()},
                    ctx.mlirCtx().location(@src()),
                );
                return _result(res._shape, op.result(0));
            },
            .buffer_id => {
                var res = self;
                res._shape = self._shape.withSharding(axes_);
                return res;
            },
        };
    }
    /// Returns a Tensor with new tag names.
    pub fn rename(self: Tensor, renames: anytype) Tensor {
        var res = self;
        res._shape = self._shape.rename(renames);
        return res;
    }
    pub fn renameAxis(self: Tensor, ax: i8, name: EnumLiteral) Tensor {
        var res = self;
        res._shape._tags.set(self.axis(ax), @tagName(name).ptr);
        return res;
    }
    /// Returns the mlir.Value associated with the Tensor.
    ///
    /// This will fail if used outside of a compilation context.
    pub fn value(self: Tensor) mlir.Value {
        return self.getContext().getValueAndDonation(self)[0];
    }
    /// Tells the PJRT compiler that memory should be reused between the two tensors.
    /// The compiler already aggressively reuses buffers for intermediate results,
    /// but this API allows reusing a buffer between the input and output arguments
    /// of a given function.
    /// Note this is visible from the outside: the caller of a function with donations
    /// is not allowed to reuse the donated input buffer after the call.
    /// For `reuseBuffer` to be effective, it needs to propagate all the way through the output.
    pub fn reuseBuffer(self: Tensor, origin: Tensor) Tensor {
        // Note: check donation docs, this may be too permissive.
        stdx.debug.assert(self.byteSize() == origin.byteSize(), "Can't reuse buffers between tensors of different size: {} and {}", .{ self, origin });
        // TODO: should we store all donations inside the context?
        var res = self;
        res._donation = self.getContext().getValueAndDonation(origin)[1];
        return res;
    }
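    // Donation sketch (hedged; `step`, `state` and `delta` are hypothetical names).
    // Returning `next_state.reuseBuffer(state)` lets PJRT write the result into the
    // buffer that held `state`, so the caller must not touch `state`'s buffer after
    // the call. `Rng.update` below uses this exact pattern for its `_state` tensor.
    //     pub fn step(state: Tensor, delta: Tensor) Tensor {
    //         const next_state = state.add(delta);
    //         return next_state.reuseBuffer(state); // output may alias the input buffer
    //     }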
    var _global_tensor_counter: u64 = 0;
    /// Internal use
    pub fn _reserveIdRange(len: u32) u64 {
        return @atomicRmw(u64, &_global_tensor_counter, .Add, len, .seq_cst);
    }
    /// Internal use
    pub fn setUniqueId(self: *Tensor) void {
        self._id = .{ .buffer_id = _reserveIdRange(1) };
    }
    /// Returns a Tensor containing the absolute value of each element of the input Tensor.
    pub fn abs(self: Tensor) Tensor {
        const loc = self.getContext().mlirCtx().location(@src());
        const op = dialect.stablehlo.abs(self.getContext().mlirCtx(), self.value(), loc);
        // The absolute value of a complex number is real.
        const dt = switch (self.dtype()) {
            .c64 => .f32,
            .c128 => .f64,
            else => self.dtype(),
        };
        return _result(self._shape.withDtype(dt), op.result(0));
    }
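    // A minimal example test (sketch, modeled on `test fmod` below; input values
    // are illustrative):
    test abs {
        const zml = @import("zml.zig");
        const platform = zml.testing.env();
        const input = try zml.Buffer.fromSlice(platform, .{4}, &[_]f32{ -2.0, -1.0, 0.0, 3.0 });
        const output = try zml.testing.compileAndCall(platform, Tensor.abs, .{input});
        try zml.testing.expectClose(zml.HostBuffer.fromSlice(.{4}, &[_]f32{ 2.0, 1.0, 0.0, 3.0 }), output, 1e-6);
    }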
    /// Returns a Tensor whose elements have been bitcast to a target datatype.
    ///
    /// The Tensor shape needs to be compatible with the target datatype.
    pub fn bitCast(self: Tensor, dt: DataType) Tensor {
        const src_bit_size = self.dtype().bitSizeOf();
        const tgt_bit_size = dt.bitSizeOf();
        var res_shape = if (src_bit_size == tgt_bit_size)
            self._shape
        else if (src_bit_size > tgt_bit_size) gt: {
            // One element of self maps to several contiguous elements of the result.
            const new_dim = std.math.divExact(u16, src_bit_size, tgt_bit_size) catch std.debug.panic("bitcast expects the source datatype bit size to be a multiple of the target datatype bit size, got {} (bitsize of {}) and {} (bitsize of {})", .{ self.dtype(), src_bit_size, dt, tgt_bit_size });
            var res = self._shape;
            res = res.append(.{ .bitcast = new_dim });
            break :gt res;
        } else lt: {
            // Several contiguous elements of self map to one element of the result.
            stdx.debug.assert(self.dim(-1) * src_bit_size == tgt_bit_size, "bitcast expects the elements of the input tensor's last dimension to map to one element of the target datatype, got {0} elements (bitsize of {0}x{1}={2}) and {3} (bitsize of {4})", .{ self.dim(-1), src_bit_size, self.dim(-1) * src_bit_size, dt, tgt_bit_size });
            break :lt self._shape.remove(-1);
        };
        res_shape = res_shape.withDtype(dt);
        const loc = self.getContext().location(@src(), "bitCast({s})", .{@tagName(dt)});
        const op = dialect.stablehlo.bitcast_convert(
            self.getContext().mlirCtx(),
            self.value(),
            mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type),
            loc,
        );
        return _result(res_shape, op.result(0));
    }
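    // Shape bookkeeping sketch for `bitCast` (shapes illustrative):
    //   - same width: (4:u32).bitCast(.f32)   -> (4:f32)    shape unchanged
    //   - splitting:  (4:u32).bitCast(.u8)    -> (4, 4:u8)  appends a dim of 32/8 = 4
    //   - merging:    (4, 4:u8).bitCast(.u32) -> (4:u32)    last dim must cover exactly one target element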
    /// Returns a Tensor containing the element-wise number of leading 0 bits in the input Tensor.
    pub fn countLeadingZeros(self: Tensor) Tensor {
        const loc = self.getContext().mlirCtx().location(@src());
        const op = dialect.stablehlo.count_leading_zeros(self.getContext().mlirCtx(), self.value(), loc);
        return _result(self._shape, op.result(0));
    }
    /// Returns a Tensor containing booleans indicating if each element of the input Tensor is finite.
    pub fn isFinite(self: Tensor) Tensor {
        const loc = self.getContext().mlirCtx().location(@src());
        const op = dialect.stablehlo.is_finite(self.getContext().mlirCtx(), self.value(), loc);
        return _result(self._shape.withDtype(.bool), op.result(0));
    }
    /// Returns a Tensor containing the element-wise number of bits set in the input Tensor.
    pub fn popcnt(self: Tensor) Tensor {
        stdx.debug.assert(self.dtype().isInteger(), "popcnt expects tensor type to be an integer, got {}", .{self.dtype()});
        const loc = self.getContext().mlirCtx().location(@src());
        const op = dialect.stablehlo.popcnt(self.getContext().mlirCtx(), self.value(), loc);
        return _result(self._shape, op.result(0));
    }
    /// Returns a Tensor containing the sign of the input Tensor element-wise.
    pub fn sign(self: Tensor) Tensor {
        const loc = self.getContext().mlirCtx().location(@src());
        const op = dialect.stablehlo.sign(self.getContext().mlirCtx(), self.value(), loc);
        return _result(self._shape, op.result(0));
    }
    /// Returns a Tensor containing the element-wise remainder of dividend 'self' and divisor 'other'.
    ///
    /// See https://pytorch.org/docs/stable/generated/torch.fmod.html for more details.
    pub fn fmod(self: Tensor, divisor: f32) Tensor {
        return self.remainder(Tensor.scalar(divisor, .f32).broadcast(self._shape, &.{}));
    }
    test fmod {
        const zml = @import("zml.zig");
        const platform = zml.testing.env();
        const inputs: [2][6]f32 = .{ .{ -3.0, -2, -1, 1, 2, 3 }, .{ 1, 2, 3, 4, 5, -5 } };
        const expectations: [2][6]f32 = .{ .{ -1.0, -0.0, -1.0, 1.0, 0.0, 1.0 }, .{ 1.0000, 0.5000, 0.0000, 1.0000, 0.5000, -0.5000 } };
        const divisors: [2]f32 = .{ 2, -1.5 };
        inline for (inputs, expectations, divisors) |i, e, d| {
            const input = try zml.Buffer.fromSlice(platform, .{6}, &i);
            const output = try zml.testing.compileAndCall(platform, Tensor.fmod, .{ input, d });
            try zml.testing.expectClose(zml.HostBuffer.fromSlice(.{6}, &e), output, 1e-4);
        }
    }
    /// Returns a Tensor containing the element-wise left-shift operation of 'self' by 'other'.
    pub fn shiftLeft(self: Tensor, other: Tensor) Tensor {
        return binaryOp(@src(), "shiftLeft", dialect.stablehlo.shift_left)(self, other);
    }
    /// Returns a Tensor containing the element-wise arithmetic right-shift operation of 'self' by 'other'.
    pub fn shiftRightArithmetic(self: Tensor, other: Tensor) Tensor {
        return binaryOp(@src(), "shiftRightArithmetic", dialect.stablehlo.shift_right_arithmetic)(self, other);
    }
    /// Returns a Tensor containing the element-wise logical right-shift operation of 'self' by 'other'.
    pub fn shiftRightLogical(self: Tensor, other: Tensor) Tensor {
        return binaryOp(@src(), "shiftRightLogical", dialect.stablehlo.shift_right_logical)(self, other);
    }
    /// Returns the Cholesky decomposition of the input Tensor.
    ///
    /// 'lower' controls the form of the output Tensor. The output will be lower-triangular if 'lower' is true
    /// and upper-triangular otherwise.
    pub fn cholesky(self: Tensor, lower: bool) Tensor {
        stdx.debug.assert(self.rank() <= 2, "cholesky expects tensor rank to be <= 2, got {}", .{self.rank()});
        const loc = self.getContext().location(@src(), "lower={}", .{lower});
        const op = dialect.stablehlo.cholesky(self.getContext().mlirCtx(), self.value(), lower, loc);
        return _result(self._shape, op.result(0));
    }
    /// Solves the system of linear equations formed by the input tensors.
    pub fn triangularSolve(self: Tensor, other: Tensor, opts: dialect.stablehlo.TriangularSolveOpts) Tensor {
        stdx.debug.assert(self.dtype() == other.dtype(), "triangularSolve expects tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() });
        stdx.debug.assert(self.rank() <= 2 and self.rank() == other.rank(), "triangularSolve expects tensors to have the same rank and be <= 2, got {} and {}", .{ self.rank(), other.rank() });
        const loc = self.getContext().location(@src(), "triangularSolve({_}, {})", .{ self, opts });
        const op = dialect.stablehlo.triangular_solve(self.getContext().mlirCtx(), self.value(), other.value(), loc, opts);
        return _result(self._shape, op.result(0));
    }
    /// Returns a Tensor containing the element-wise rounding towards the nearest integer, breaking ties away from zero, of the input Tensor.
    pub fn roundNearestAfz(self: Tensor) Tensor {
        stdx.debug.assert(self.dtype().isFloat(), "roundNearestAfz expects tensor type to be a float, got {}", .{self.dtype()});
        const loc = self.getContext().mlirCtx().location(@src());
        const op = dialect.stablehlo.round_nearest_afz(self.getContext().mlirCtx(), self.value(), loc);
        return _result(self._shape, op.result(0));
    }
    /// Returns a Tensor containing the element-wise rounding towards the nearest integer, breaking ties towards the even integer, of the input Tensor.
    pub fn roundNearestEven(self: Tensor) Tensor {
        stdx.debug.assert(self.dtype().isFloat(), "roundNearestEven expects tensor type to be a float, got {}", .{self.dtype()});
        const loc = self.getContext().mlirCtx().location(@src());
        const op = dialect.stablehlo.round_nearest_even(self.getContext().mlirCtx(), self.value(), loc);
        return _result(self._shape, op.result(0));
    }
    /// Returns a Tensor of complex numbers converted from a pair of real and imaginary Tensors.
    pub fn complex(re: Tensor, im: Tensor) Tensor {
        stdx.debug.assert(re._shape.eql(im._shape), "complex expects tensor shapes to match, got {} and {}", .{ re._shape, im._shape });
        stdx.debug.assert(re.dtype() == .f32 or re.dtype() == .f64, "complex expects tensors type to be f32 or f64, got {}", .{re.dtype()});
        const loc = re.getContext().mlirCtx().location(@src());
        const op = dialect.stablehlo.complex(re.getContext().mlirCtx(), re.value(), im.value(), loc);
        const dt: DataType = if (re.dtype() == .f32) .c64 else .c128;
        return _result(re._shape.withDtype(dt), op.result(0));
    }
    /// Returns a Tensor containing the element-wise real part of the input Tensor.
    ///
    /// The tensor type can be float or complex.
    pub fn real(self: Tensor) Tensor {
        stdx.debug.assert(self.dtype().isComplex() or self.dtype().isFloat(), "real expects tensor type to be a float or a complex, got {}", .{self.dtype()});
        if (self.dtype().isFloat()) {
            return self;
        }
        const dt: DataType = switch (self.dtype()) {
            .c64 => .f32,
            .c128 => .f64,
            else => unreachable,
        };
        const loc = self.getContext().mlirCtx().location(@src());
        const op = dialect.stablehlo.real(self.getContext().mlirCtx(), self.value(), loc);
        return _result(self._shape.withDtype(dt), op.result(0));
    }
    /// Returns a Tensor containing the element-wise imaginary part of the input Tensor.
    ///
    /// The tensor type can be float or complex.
    pub fn imag(self: Tensor) Tensor {
        stdx.debug.assert(self.dtype().isFloat() or self.dtype().isComplex(), "imag expects tensor type to be a float or a complex, got {}", .{self.dtype()});
        // Real tensors don't have an imaginary part.
        if (self.dtype().isFloat()) {
            return Tensor.constant(self._shape, self.dtype().zero());
        }
        const dt: DataType = switch (self.dtype()) {
            .bf16, .f16, .f32, .f64 => self.dtype(),
            .c64 => .f32,
            .c128 => .f64,
            else => unreachable,
        };
        const loc = self.getContext().mlirCtx().location(@src());
        const op = dialect.stablehlo.imag(self.getContext().mlirCtx(), self.value(), loc);
        return _result(self._shape.withDtype(dt), op.result(0));
    }
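    // Round-trip sketch (hedged; assumes f32 inputs of matching shape):
    //     const z = re.complex(im); // c64
    //     _ = z.real(); // back to re, as f32
    //     _ = z.imag(); // back to im, as f32
    //     _ = z.abs();  // magnitude, also f32 (see `abs` above)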
    /// Returns the Fast Fourier Transform (FFT) of the input Tensor.
    pub fn fft(self: Tensor, opts: dialect.stablehlo.FftOpts) Tensor {
        // TODO: support tagged API.
        stdx.debug.assert(1 <= opts.length.len and opts.length.len <= 3, "fft expects 'opts.length' length to be between 1 and 3 (inclusive), got {}", .{opts.length.len});
        stdx.debug.assert(opts.length.len <= self.rank(), "fft expects 'opts.length' length to be at most the tensor rank, got {} and {}", .{ opts.length.len, self.rank() });
        const sh = switch (opts.kind) {
            .FFT, .IFFT => blk: {
                stdx.debug.assert(self.dtype().isComplex(), "fft({any}) expects tensor type to be complex, got {}", .{ opts, self.dtype() });
                break :blk self._shape;
            },
            .RFFT => blk: {
                stdx.debug.assert(self.dtype() == .f32 or self.dtype() == .f64, "fft({}) expects tensor type to be f32 or f64, got {}", .{ opts, self.dtype() });
                stdx.debug.assert(std.mem.eql(i64, self.dims()[self.rank() - opts.length.len ..], opts.length), "fft({}) expects tensor last dimensions to match given lengths, got {} and {}", .{ opts, self.dims()[self.rank() - opts.length.len ..].len, opts.length.len });
                const dt: DataType = switch (self.dtype()) {
                    .f32 => .c64,
                    else => .c128,
                };
                const shape_ = self._shape.setDim(-1, @divExact(self.dim(-1), 2) + 1);
                break :blk shape_.withDtype(dt);
            },
            .IRFFT => blk: {
                stdx.debug.assert(self.dtype().isComplex(), "fft({any}) expects tensor type to be complex, got {}", .{ opts, self.dtype() });
                stdx.debug.assert(std.mem.eql(i64, self.dims()[self.rank() - opts.length.len ..], opts.length), "fft({any}) expects tensor last dimensions to match given lengths, got {} and {}", .{ opts, self.dims()[self.rank() - opts.length.len ..].len, opts.length.len });
                const dt: DataType = switch (self.dtype()) {
                    .c64 => .f32,
                    else => .f64,
                };
                const shape_ = self._shape.setDim(-1, @divExact(self.dim(-1) - 1, 2));
                break :blk shape_.withDtype(dt);
            },
        };
        const loc = self.getContext().location(@src(), "fft({_},{})", .{ self, opts });
        const op = dialect.stablehlo.fft(self.getContext().mlirCtx(), self.value(), loc, opts);
        return _result(sh, op.result(0));
    }
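    // Last-dimension bookkeeping sketch (n illustrative, following the branches above):
    //   - FFT/IFFT: complex -> complex, shape unchanged
    //   - RFFT:     (.., n : f32) -> (.., n/2 + 1 : c64)
    //   - IRFFT:    complex -> real, last dim recomputed as in the IRFFT branch above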
    pub const Rng = struct {
        _state: Tensor,
        algorithm: dialect.stablehlo.RngAlgorithm.Type = .DEFAULT,
        pub fn shape() ShapeOf(Rng) {
            return .{
                ._state = Shape.init(.{2}, .u64),
            };
        }
        pub fn init(platform: Platform, seed: u128) !Bufferized(Rng) {
            const bits: [2]u64 = @bitCast(seed);
            return .{
                ._state = try Buffer.fromSlice(platform, Shape.init(.{2}, .u64), &bits),
                .algorithm = undefined,
            };
        }
        /// Returns a Tensor of the given shape, filled with uniform random bits, and a new Rng state.
        ///
        /// The given Rng state should not be used anymore (or you'll get the same numbers again).
        /// The output is guaranteed to be a deterministic function of the `self` Rng state,
        /// but it is not guaranteed to be deterministic between implementations.
        pub fn bitGenerator(self: Rng, sh: Shape) struct { Rng, Tensor } {
            const ctx = CompilationContext.current();
            const loc = ctx.location(@src(), "rand.bitGen({_})", .{sh});
            const op = dialect.stablehlo.rng_bit_generator(
                ctx.mlirCtx(),
                self.algorithm,
                self._state.value(),
                mlir.ext.mlirType(ctx.mlirCtx(), self._state._shape),
                mlir.ext.mlirType(ctx.mlirCtx(), sh),
                loc,
            );
            return .{ self.update(op.result(0)), _result(sh, op.result(1)) };
        }
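        // Usage sketch (hedged): always thread the returned Rng state into the
        // next call; reusing the old state replays the same bits.
        //     const rng1, const bits1 = rng0.bitGenerator(Shape.init(.{16}, .u32));
        //     const rng2, const bits2 = rng1.bitGenerator(Shape.init(.{16}, .u32));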
        fn update(self: Rng, new_state: mlir.Value) Rng {
            return .{
                ._state = _result(self._state._shape, new_state).reuseBuffer(self._state),
                .algorithm = self.algorithm,
            };
        }
        /// Returns a Tensor of the given shape, filled with uniformly sampled floating point numbers from an interval,
        /// and a new Rng state.
        ///
        /// https://en.wikipedia.org/wiki/Continuous_uniform_distribution
        pub fn uniform(
            self: Rng,
            shape_: Shape,
            opts: struct { min: f64 = 0, max: f64 = 1 },
        ) struct { Rng, Tensor } {
            const dt = if (shape_.dtype().isFloat()) shape_.dtype() else .f32;
            const mantissa_bit_count = @import("dtype.zig").mantissaSize(dt);
            const bit_count: usize = dt.bitSizeOf();
            const rng_bit_count = if (mantissa_bit_count < 8) 8 else bit_count;
            const uint_dtype: DataType = switch (bit_count) {
                8 => .u8,
                16 => .u16,
                32 => .u32,
                64 => .u64,
                else => stdx.debug.panic("uniform doesn't support non-byte-aligned dtypes. Got: {}", .{shape_}),
            };
            const rng, const bits = self.bitGenerator(shape_.withDtype(uint_dtype));
            // Erase bits outside of the mantissa.
            var float_bits = bits.shiftRightLogical(scalar(rng_bit_count - mantissa_bit_count, uint_dtype));
            // Set exponent bits to represent 2^0, i.e. 1.0 (eg a biased exponent of 127 for f32).
            float_bits = float_bits.logical(.OR, scalar(1, dt).bitCast(uint_dtype));
            // float_bits now uniformly represents numbers in the [1, 2[ range.
            // Let's convert to floats, and subtract one to go to the [0, 1[ range.
            var floats = float_bits.bitCast(dt).sub(scalar(1, dt));
            floats = floats.mul(scalar(opts.max - opts.min, dt)).addConstant(opts.min);
            // Convert back to integer if needed.
            return .{ rng, floats.convert(shape_.dtype()) };
        }
        test uniform {
            const zml = @import("zml.zig");
            const Stats = struct {
                const Stats = @This();
                mean: Tensor,
                variance: Tensor,
                min: Tensor,
                max: Tensor,
                pub fn uniformStats(
                    rand: Rng,
                    shape_: Shape,
                    opts: struct { min: f64, max: f64 },
                ) struct { Rng, Stats } {
                    const rng, const data = rand.uniform(shape_, .{ .min = opts.min, .max = opts.max });
                    const mean_ = data.mean(0);
                    const variance = data.sub(mean_.broad(data.shape())).pow(Tensor.scalar(2, .f32)).mean(0);
                    return .{ rng, .{
                        .mean = mean_,
                        .variance = variance,
                        .min = data.min(0),
                        .max = data.max(0),
                    } };
                }
            };
            const platform = zml.testing.env();
            // Compute stats over a uniform distribution on [-2, 10].
            const rand, const stats = try zml.testing.compileAndCall(
                platform,
                Stats.uniformStats,
                .{
                    try Rng.init(platform, 1234),
                    Shape.init(.{1024}, .f32),
                    .{ .min = -2, .max = 10 },
                },
            );
            // Check the Rng state has been modified.
            try std.testing.expect(try rand._state.getValue(i128) != 1234);
            // Check the mean and variance are close to theoretical values.
            const mean_ = try stats.mean.getValue(f32);
            try std.testing.expectApproxEqAbs(4, mean_, 0.03);
            const variance = try stats.variance.getValue(f32);
            try std.testing.expectApproxEqAbs(12.0 * 12.0 / 12.0, variance, 0.01);
            // Check that no value is outside of the interval
            // and that we have samples close to the edges.
            const min_ = try stats.min.getValue(f32);
            try std.testing.expect(min_ >= -2);
            try std.testing.expectApproxEqAbs(-2, min_, 0.05);
            const max_ = try stats.max.getValue(f32);
            try std.testing.expect(max_ < 10);
            try std.testing.expectApproxEqAbs(10, max_, 0.05);
        }
        /// Returns a Tensor of the given shape, filled with floating point numbers sampled from a normal distribution.
        ///
        /// Note: this uses stablehlo.rng which is deprecated.
        /// https://github.com/openxla/stablehlo/blob/main/rfcs/20240503-opset-deprecations.md
        pub fn normal(sh: Shape, opts: struct { mean: f64 = 0, stddev: f64 = 1 }) Tensor {
            stdx.debug.assert(sh.dtype().isFloat(), "normal expects tensor type to be a float, got {}", .{sh.dtype()});
            const ctx = CompilationContext.current();
            const loc = ctx.location(@src(), "rand.normal({_}, mean={},stddev={})", .{ sh, opts.mean, opts.stddev });
            const a = Tensor.constant(.{}, Data.init(sh.dtype(), opts.mean));
            const b = Tensor.constant(.{}, Data.init(sh.dtype(), opts.stddev));
            const res_shape = Tensor.constantTensor(HostBuffer.fromSlice(.{sh.rank()}, sh.dims()));
            const op = dialect.stablehlo.rng(ctx.mlirCtx(), a.value(), b.value(), res_shape.value(), .NORMAL, loc);
            return _result(sh, op.result(0));
        }
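        // Usage sketch (hedged): unlike `uniform`, `normal` takes no Rng state,
        // since it lowers to the deprecated stablehlo.rng op rather than rng_bit_generator.
        //     const x = Rng.normal(Shape.init(.{1024}, .f32), .{ .mean = 0, .stddev = 1 });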
        /// Returns a Tensor of the given shape, filled with floating point numbers sampled from a Gumbel distribution, and a new Rng state.
        ///
        /// Often used in ML because of the reparametrization trick.
        /// Sampling from a Gumbel distribution is equivalent to sampling
        /// from a softmax distribution, but doesn't require computing the sum of exponentials.
        /// https://en.wikipedia.org/wiki/Gumbel_distribution#Gumbel_reparametrization_tricks
        /// See `sampleTokens` for a practical use case.
        /// Note: we only implement the μ=0, β=1 version.
        pub fn gumbel(self: Rng, shape_: Shape) struct { Rng, Tensor } {
            const rand, const u = self.uniform(
                // Always use .f32 to have a big enough mantissa.
                shape_.withDtype(.f32),
                // We don't want 0 to be sampled otherwise `log` will return -inf.
                .{ .min = std.math.floatEps(f32), .max = 1 },
            );
            // Gumbel(0, 1) samples are -log(-log(U)) for U ~ Uniform(0, 1).
            return .{ rand, u.log().scale(-1).log().scale(-1).convert(shape_.dtype()) };
        }
        test gumbel {
            const zml = @import("zml.zig");
            const Stats = struct {
                const Stats = @This();
                mean: Tensor,
                variance: Tensor,
                actual_dist: Tensor,
                pub fn gumbelStats(rand: Rng, target_dist: Tensor) struct { Rng, Stats } {
                    const s = Shape.init(.{ .n = 1024, .d = 4 }, .f32);
                    const rng, const data = rand.gumbel(s);
                    const flat = data.flattenAll();
                    const mean_ = flat.mean(0);
                    const variance = flat.sub(mean_.broad(flat.shape())).pow(Tensor.scalar(2, .f32)).mean(0);
                    // Test out the Gumbel reparametrization trick.
                    var x = target_dist.log().withTags(.{.d}).broad(s);
                    x = x.add(data);
                    const samples = x.argMax(.d).indices.squeeze(.d);
                    // Count 0, 1, 2 and 3 in samples:
                    // - map 0 to 1, 1 to 2**16, 2 to 2**32, 3 to 2**48
                    // - sum in u64
                    // - split to [4]u16
                    const powers = blk: {
                        var powers: [4]u64 = undefined;
                        for (&powers, 0..) |*p, i| p.* = std.math.pow(u64, 2, i * 16);
                        break :blk powers;
                    };
                    const values = Tensor.constantTensor(HostBuffer.fromArray(&powers)).withTags(.{.d});
                    const counts = values.gatherValues(.d, samples, .{}).sum(.n).bitCast(.u16);
                    const actual_dist = counts.reshape(target_dist.shape()).convert(target_dist.dtype()).divByConst(s.dim(.n));
                    return .{ rng, .{ .mean = mean_, .variance = variance, .actual_dist = actual_dist } };
                }
            };
            const platform = zml.testing.env();
            const tgt_dist = [_]f32{ 2.0, 1.0, 4.0, 3.0 };
            const rand, const stats = try zml.testing.compileAndCall(platform, Stats.gumbelStats, .{
                try Rng.init(platform, 1234), try HostBuffer.fromArray(&tgt_dist).toDevice(platform),
            });
            // Check the Rng state has been modified.
            try std.testing.expect(try rand._state.getValue(i128) != 1234);
            // Check the mean and variance are close to theoretical values.
            const mean_ = try stats.mean.getValue(f32);
            try std.testing.expectApproxEqAbs(0.5772, mean_, 0.02);
            const variance = try stats.variance.getValue(f32);
            const pi = std.math.pi;
            try std.testing.expectApproxEqAbs(pi * pi / 6.0, variance, 0.03);
            // Check the distribution obtained with the Gumbel trick matches the target distribution.
            const actual_dist = try stats.actual_dist.getValue([4]f32);
            scoped_log.debug("tgt_dist: {d}, actual_dist: {d}", .{ tgt_dist, actual_dist });
            for (tgt_dist, actual_dist) |tgt, actual| {
                // We normalize tgt_dist to make it a well-formed distribution.
                // We didn't do it before calling gumbel, because the Gumbel trick
                // doesn't require normalized distributions as input.
                try std.testing.expectApproxEqAbs(tgt / 10.0, actual, 0.05);
            }
        }
    };
    /// Returns a Tensor containing the element-wise conversion to another floating point type.
    pub fn reducePrecision(self: Tensor, exponent_bits: i32, mantissa_bits: i32) Tensor {
        stdx.debug.assert(self.dtype().isFloat(), "reducePrecision expects tensor type to be a float, got {}", .{self.dtype()});
        stdx.debug.assert(1 <= exponent_bits, "reducePrecision expects 'exponent_bits' to be >= 1, got {}", .{exponent_bits});
        stdx.debug.assert(0 <= mantissa_bits, "reducePrecision expects 'mantissa_bits' to be positive, got {}", .{mantissa_bits});
        const loc = self.getContext().location(@src(), "reducePrecision(exponent_bits={}, mantissa_bits={})", .{ exponent_bits, mantissa_bits });
        const op = dialect.stablehlo.reduce_precision(self.getContext().mlirCtx(), self.value(), exponent_bits, mantissa_bits, loc);
        return _result(self._shape, op.result(0));
    }
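    // Sketch: emulate the rounding of narrower float formats while keeping the
    // original dtype (bit counts follow IEEE 754; calls illustrative):
    //     _ = x.reducePrecision(5, 10); // f16-like rounding of an f32 tensor
    //     _ = x.reducePrecision(8, 7); // bf16-like rounding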
    inline fn convolution(self: Tensor, other: Tensor, opts: dialect.stablehlo.ConvolutionOpts, loc: mlir.Location) Tensor {
        stdx.debug.assert(self.rank() == other.rank(), "convolution expects tensor ranks to match, got {} and {}", .{ self.rank(), other.rank() });
        const N = self.rank();
        stdx.debug.guard(opts.window_strides.len == N - 2, @src());
        for (opts.window_strides) |s| stdx.debug.guard(0 < s, @src());
        stdx.debug.guard(opts.lhs_dilation.len == N - 2, @src());
        for (opts.lhs_dilation) |d| stdx.debug.guard(0 < d, @src());
        stdx.debug.guard(opts.rhs_dilation.len == N - 2, @src());
        for (opts.rhs_dilation) |d| stdx.debug.guard(0 < d, @src());
        stdx.debug.guard(opts.window_reversal.len == N - 2, @src());
        stdx.debug.guard(@rem(self.dim(opts.input_batch_dimension), opts.batch_group_count) == 0, @src());
        stdx.debug.guard(@rem(self.dim(opts.input_feature_dimension), opts.feature_group_count) == 0, @src());
        stdx.debug.guard(opts.input_spatial_dimensions.len == N - 2, @src());
        stdx.debug.guard(opts.input_batch_dimension != opts.input_feature_dimension, @src());
        stdx.debug.guard(0 <= opts.input_batch_dimension and opts.input_batch_dimension < N, @src());
        stdx.debug.guard(0 <= opts.input_feature_dimension and opts.input_feature_dimension < N, @src());
        for (opts.input_spatial_dimensions, 0..) |d, i| {
            stdx.debug.guard(d != opts.input_batch_dimension, @src());
            stdx.debug.guard(d != opts.input_feature_dimension, @src());
            stdx.debug.guard(0 <= d and d < N, @src());
            // Check for duplicates among the remaining spatial dimensions.
            if (i == opts.input_spatial_dimensions.len - 1) continue;
            stdx.debug.guard(std.mem.indexOfScalar(i64, opts.input_spatial_dimensions[i + 1 ..], d) == null, @src());
        }
        stdx.debug.guard(other.dim(opts.kernel_input_feature_dimension) == @divTrunc(self.dim(opts.input_feature_dimension), opts.feature_group_count), @src());
        stdx.debug.guard(@rem(other.dim(opts.kernel_output_feature_dimension), opts.batch_group_count) == 0, @src());
        stdx.debug.guard(@rem(other.dim(opts.kernel_output_feature_dimension), opts.feature_group_count) == 0, @src());
        stdx.debug.guard(opts.kernel_spatial_dimensions.len == N - 2, @src());
        stdx.debug.guard(opts.kernel_input_feature_dimension != opts.kernel_output_feature_dimension, @src());
        stdx.debug.guard(0 <= opts.kernel_input_feature_dimension and opts.kernel_input_feature_dimension < N, @src());
        stdx.debug.guard(0 <= opts.kernel_output_feature_dimension and opts.kernel_output_feature_dimension < N, @src());
        for (opts.kernel_spatial_dimensions, 0..) |d, i| {
            stdx.debug.guard(d != opts.kernel_input_feature_dimension, @src());
            stdx.debug.guard(d != opts.kernel_output_feature_dimension, @src());
            stdx.debug.guard(0 <= d and d < N, @src());
            if (i == opts.kernel_spatial_dimensions.len - 1) continue;
            stdx.debug.guard(std.mem.indexOfScalar(i64, opts.kernel_spatial_dimensions[i + 1 ..], d) == null, @src());
        }
        stdx.debug.guard(opts.output_spatial_dimensions.len == N - 2, @src());
        stdx.debug.guard(opts.output_batch_dimension != opts.output_feature_dimension, @src());
        stdx.debug.guard(0 <= opts.output_batch_dimension and opts.output_batch_dimension < N, @src());
        stdx.debug.guard(0 <= opts.output_feature_dimension and opts.output_feature_dimension < N, @src());
        for (opts.output_spatial_dimensions, 0..) |d, i| {
            stdx.debug.guard(d != opts.output_batch_dimension, @src());
            stdx.debug.guard(d != opts.output_feature_dimension, @src());
            stdx.debug.guard(0 <= d and d < N, @src());
            if (i == opts.output_spatial_dimensions.len - 1) continue;
            stdx.debug.guard(std.mem.indexOfScalar(i64, opts.output_spatial_dimensions[i + 1 ..], d) == null, @src());
        }
        stdx.debug.guard(0 < opts.feature_group_count, @src());
        stdx.debug.guard(0 < opts.batch_group_count, @src());
        stdx.debug.guard(opts.feature_group_count == 1 or opts.batch_group_count == 1, @src());
        var used_opts = opts;
        used_opts.pad_shape = &.{ @intCast(N - 2), 2 };
        used_opts.precision_config = &.{ .DEFAULT, .DEFAULT };
        var new_shape = self._shape;
        var res_dim: i64 = undefined;
        for (0..N) |i| {
            if (i == @as(usize, @intCast(opts.output_batch_dimension))) {
                res_dim = @divTrunc(self.dim(opts.input_batch_dimension), opts.batch_group_count);
            } else if (i == @as(usize, @intCast(opts.output_feature_dimension))) {
                res_dim = other.dim(opts.kernel_output_feature_dimension);
            } else {
                // Calculate the spatial dimension value.
                const spatial_dim: usize = std.mem.indexOfScalar(i64, opts.output_spatial_dimensions, @as(i64, @intCast(i))).?;
                const lhs_dim = opts.input_spatial_dimensions[spatial_dim];
                const rhs_dim = opts.kernel_spatial_dimensions[spatial_dim];
                const dilated_input_shape_lhs_dim: i64 = if (self.dim(lhs_dim) == 0) 0 else (self.dim(lhs_dim) - 1) * opts.lhs_dilation[spatial_dim] + 1;
                const left_pad_value, const right_pad_value = if (opts.pad_value.len == 1)
                    .{ opts.pad_value[0], opts.pad_value[0] }
                else
                    .{ opts.pad_value[2 * spatial_dim], opts.pad_value[2 * spatial_dim + 1] };
                const padded_input_shape_lhs_dim = left_pad_value + dilated_input_shape_lhs_dim + right_pad_value;
                const dilated_window_shape_lhs_dim: i64 = if (other.dim(rhs_dim) == 0) 0 else (other.dim(rhs_dim) - 1) * opts.rhs_dilation[spatial_dim] + 1;
                const is_empty_window_lhs_dim = padded_input_shape_lhs_dim == 0 or dilated_window_shape_lhs_dim > padded_input_shape_lhs_dim;
                res_dim = if (is_empty_window_lhs_dim) 0 else @divTrunc(padded_input_shape_lhs_dim - dilated_window_shape_lhs_dim, opts.window_strides[spatial_dim]) + 1;
            }
            new_shape = new_shape.set(i, res_dim);
        }
        const op = dialect.stablehlo.convolution(
            self.getContext().mlirCtx(),
            self.value(),
            other.value(),
            used_opts,
            mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape).as(mlir.Type),
            loc,
        );
        return _result(new_shape, op.result(0));
    }
    /// Returns a Tensor containing the result of the 1D convolution of 'input' by 'kernel'.
    pub fn conv1d(
        input: Tensor,
        kernel: Tensor,
        opts: struct {
            window_strides: i64 = 1,
            padding: []const i64 = &.{ 0, 0 },
            lhs_dilation: i64 = 1,
            rhs_dilation: i64 = 1,
            window_reversal: bool = false,
            input_batch_dimension: i64 = 0,
            input_feature_dimension: i64 = 1,
            input_spatial_dimensions: i64 = 2,
            kernel_output_feature_dimension: i64 = 0,
            kernel_input_feature_dimension: i64 = 1,
            kernel_spatial_dimensions: i64 = 2,
            output_batch_dimension: i64 = 0,
            output_feature_dimension: i64 = 1,
            output_spatial_dimensions: i64 = 2,
            feature_group_count: i64 = 1,
            batch_group_count: i64 = 1,
        },
    ) Tensor {
        const loc = input.getContext().location(@src(), "opts={}", .{opts});
        return input.convolution(kernel, .{
            .window_strides = &.{opts.window_strides},
            .pad_value = opts.padding,
            .lhs_dilation = &.{opts.lhs_dilation},
            .rhs_dilation = &.{opts.rhs_dilation},
            .window_reversal = &.{opts.window_reversal},
            .input_batch_dimension = opts.input_batch_dimension,
            .input_feature_dimension = opts.input_feature_dimension,
            .input_spatial_dimensions = &.{opts.input_spatial_dimensions},
            .kernel_input_feature_dimension = opts.kernel_input_feature_dimension,
            .kernel_output_feature_dimension = opts.kernel_output_feature_dimension,
            .kernel_spatial_dimensions = &.{opts.kernel_spatial_dimensions},
            .output_batch_dimension = opts.output_batch_dimension,
            .output_feature_dimension = opts.output_feature_dimension,
            .output_spatial_dimensions = &.{opts.output_spatial_dimensions},
            .feature_group_count = opts.feature_group_count,
            .batch_group_count = opts.batch_group_count,
        }, loc);
    }
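    // Usage sketch (hedged; dims illustrative): with the default (B, C_in, W)
    // layout, a (2, 3, 10) input and a (4, 3, 3) kernel yield (2, 4, 8),
    // since W' = (W + pads - dilated_kernel) / stride + 1 = (10 - 3) / 1 + 1.
    //     const y = input.conv1d(kernel, .{});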
    /// Returns a Tensor containing the result of the 2D convolution of 'input' by 'kernel'.
    ///
    /// Default values correspond to a (B, C_in, W, H) image, (C_out, C_in, W, H) kernel weights and a (B, C_out, W, H) output.
    pub fn conv2d(
        input: Tensor,
        kernel: Tensor,
        opts: struct {
            window_strides: []const i64 = &.{ 1, 1 },
            padding: []const i64 = &.{ 0, 0, 0, 0 },
            lhs_dilation: []const i64 = &.{ 1, 1 },
            rhs_dilation: []const i64 = &.{ 1, 1 },
            window_reversal: []const bool = &.{ false, false },
            input_batch_dimension: i64 = 0,
            input_feature_dimension: i64 = 1,
            input_spatial_dimensions: []const i64 = &.{ 2, 3 },
            kernel_input_feature_dimension: i64 = 1,
            kernel_output_feature_dimension: i64 = 0,
            kernel_spatial_dimensions: []const i64 = &.{ 2, 3 },
            output_batch_dimension: i64 = 0,
            output_feature_dimension: i64 = 1,
            output_spatial_dimensions: []const i64 = &.{ 2, 3 },
            feature_group_count: i64 = 1,
            batch_group_count: i64 = 1,
        },
    ) Tensor {
        const loc = input.getContext().location(@src(), "opts={}", .{opts});
        return input.convolution(kernel, .{
            .window_strides = opts.window_strides,
            .pad_value = opts.padding,
            .lhs_dilation = opts.lhs_dilation,
            .rhs_dilation = opts.rhs_dilation,
            .window_reversal = opts.window_reversal,
            .input_batch_dimension = opts.input_batch_dimension,
            .input_feature_dimension = opts.input_feature_dimension,
            .input_spatial_dimensions = opts.input_spatial_dimensions,
            .kernel_input_feature_dimension = opts.kernel_input_feature_dimension,
            .kernel_output_feature_dimension = opts.kernel_output_feature_dimension,
            .kernel_spatial_dimensions = opts.kernel_spatial_dimensions,
            .output_batch_dimension = opts.output_batch_dimension,
            .output_feature_dimension = opts.output_feature_dimension,
            .output_spatial_dimensions = opts.output_spatial_dimensions,
            .feature_group_count = opts.feature_group_count,
            .batch_group_count = opts.batch_group_count,
        }, loc);
    }
    /// Returns a Tensor containing the element-wise addition of the input Tensors.
    pub fn add(self: Tensor, other: Tensor) Tensor {
        return binaryOp(@src(), "add", dialect.stablehlo.add)(self, other);
    }
    /// Returns a Tensor containing the element-wise subtraction of the input Tensors.
    pub fn sub(self: Tensor, other: Tensor) Tensor {
        return binaryOp(@src(), "subtract", dialect.stablehlo.subtract)(self, other);
    }
    /// Returns a Tensor containing the element-wise multiplication of the input Tensors.
    pub fn mul(self: Tensor, other: Tensor) Tensor {
        return binaryOp(@src(), "mul", dialect.stablehlo.multiply)(self, other);
    }
    /// Returns a Tensor containing the element-wise division of the input Tensors.
    pub fn div(self: Tensor, other: Tensor) Tensor {
        return binaryOp(@src(), "div", dialect.stablehlo.divide)(self, other);
    }
    /// Returns a Tensor containing the element-wise exponentiation of the input Tensors.
    pub fn pow(self: Tensor, other: Tensor) Tensor {
        return binaryOp(@src(), "pow", dialect.stablehlo.power)(self, other);
    }
    /// Returns a Tensor containing the element-wise maximum operation of the input Tensors.
    pub fn maximum(self: Tensor, other: Tensor) Tensor {
        return binaryOp(@src(), "maximum", dialect.stablehlo.maximum)(self, other);
    }
    /// Returns a Tensor containing the element-wise minimum operation of the input Tensors.
    pub fn minimum(self: Tensor, other: Tensor) Tensor {
        return binaryOp(@src(), "minimum", dialect.stablehlo.minimum)(self, other);
    }
    /// Returns a Tensor containing the element-wise remainder of dividend 'self' and divisor 'other'.
    pub fn remainder(self: Tensor, other: Tensor) Tensor {
        return binaryOp(@src(), "remainder", dialect.stablehlo.remainder)(self, other);
    }
    /// Returns a Tensor containing the element-wise addition of the input Tensor with a constant.
    pub fn addConstant(self: Tensor, b: anytype) Tensor {
        return self.add(Tensor.scalar(b, self.dtype()));
    }
    /// Returns a Tensor containing the element-wise division of the input Tensor by a constant.
    pub fn divByConst(self: Tensor, b: anytype) Tensor {
        return self.div(Tensor.scalar(b, self.dtype()));
    }
    /// Returns a Tensor containing the element-wise multiplication of the input Tensor by a constant.
    pub inline fn scale(self: Tensor, val: anytype) Tensor {
        return self.mul(Tensor.scalar(val, self.dtype()));
    }
    pub const LogicalOp = enum { OR, XOR, AND };
    /// Returns a Tensor containing the element-wise logical operation of the input Tensors.
    pub fn logical(self: Tensor, comptime logical_op: LogicalOp, other: Tensor) Tensor {
        return switch (logical_op) {
            .OR => binaryOp(@src(), "or", dialect.stablehlo.or_)(self, other),
            .XOR => binaryOp(@src(), "xor", dialect.stablehlo.xor)(self, other),
            .AND => binaryOp(@src(), "and", dialect.stablehlo.and_)(self, other),
        };
    }
    /// Returns a Tensor containing the element-wise floor operation of the input Tensor.
    pub fn floor(self: Tensor) Tensor {
        const loc = self.getContext().mlirCtx().location(@src());
        return _result(self._shape, dialect.stablehlo.floor(self.getContext().mlirCtx(), self.value(), loc).result(0));
    }
    /// Returns a Tensor containing the element-wise ceil operation of the input Tensor.
    pub fn ceil(self: Tensor) Tensor {
        const loc = self.getContext().mlirCtx().location(@src());
        return _result(self._shape, dialect.stablehlo.ceil(self.getContext().mlirCtx(), self.value(), loc).result(0));
    }
    /// Returns a Tensor containing the element-wise conversion to another type.
    pub fn convert(self: Tensor, to: DataType) Tensor {
        if (to == self.dtype()) {
            return self;
        }
        const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), to)).as(mlir.Type);
        const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) });
        const op = dialect.stablehlo.convert(self.getContext().mlirCtx(), self.value(), res_type, loc);
        return _result(self._shape.withDtype(to), op.result(0));
    }
    /// Returns a Tensor containing the element-wise rounding operation of the input Tensor (ties to even).
    pub fn round(self: Tensor) Tensor {
        const loc = self.getContext().mlirCtx().location(@src());
        const op = dialect.stablehlo.round_nearest_even(self.getContext().mlirCtx(), self.value(), loc);
        return _result(self._shape, op.result(0));
    }
    /// Returns a Tensor containing the element-wise clamping operation of the input Tensor.
    pub fn clamp(self: Tensor, min_: Tensor, max_: Tensor) Tensor {
        const loc = self.getContext().mlirCtx().location(@src());
        const op = dialect.stablehlo.clamp(self.getContext().mlirCtx(), min_.value(), self.value(), max_.value(), loc);
        return _result(self._shape, op.result(0));
    }
    /// See torch.matmul
    pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor {
        return @import("torch.zig").matmul(lhs, rhs);
    }
    /// Matrix multiplication, where contracting axes are specified using their tags,
    /// eg dot(.{ .a, .b, .c }, .{ .a, .c, .d }, .{ .c }) -> .{ .a, .b, .d }.
    /// Axes with the same tag on both sides, and which aren't contracting,
    /// are considered "batching axes".
    pub fn dot(lhs: Tensor, rhs: Tensor, comptime contracting: anytype) Tensor {
        var contracting_axes: [contracting.len][2]i8 = undefined;
        inline for (contracting, 0..) |c, i| {
            contracting_axes[i] = .{ lhs.axis(c), rhs.axis(c) };
        }
        var batching_axes: [MAX_RANK][2]i8 = undefined;
        var n_batching: u8 = 0;
        for (lhs._shape.tags(), 0..) |l, li| {
            stdx.debug.assert(l != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it needs to be explicitly tagged.", .{ contracting, lhs });
            for (rhs._shape.tags(), 0..) |r, ri| {
                stdx.debug.assert(r != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it needs to be explicitly tagged.", .{ contracting, rhs });
                if (l == r) {
                    for (contracting_axes) |ct| {
                        if (l == lhs._shape.tag(ct[0])) {
                            break;
                        }
                    } else {
                        // The tag is both in lhs and rhs but not in contracting -> it's a batching dim.
                        batching_axes[n_batching] = .{ @intCast(li), @intCast(ri) };
                        n_batching += 1;
                    }
                }
            }
        }
        return dotGeneral(lhs, rhs, contracting_axes[0..], batching_axes[0..n_batching]);
    }
    test dot {
        const zml = @import("zml.zig");
        const platform = zml.testing.env();
        var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
        defer comp.deinit();
        comp.activate();
        defer comp.deactivate();
        inline for (.{
            .{ .{ .c = 20 }, .{ .c = 20 }, .{.c}, .{} },
            .{
                .{ .a = 20, .b = 21, .c = 22 },
                .{ .a = 20, .d = 23, .c = 22 },
                .{.c},
                .{ .a = 20, .b = 21, .d = 23 },
            },
            .{
                .{ .a = 20, .b = 21, .c = 22 },
                .{ .c = 22, .d = 23, .e = 24 },
                .{.c},
                .{ .a = 20, .b = 21, .d = 23, .e = 24 },
            },
            .{
                .{ .a = 20, .b = 21, .c = 22 },
                .{ .c = 22, .d = 23, .a = 20 },
                .{ .c, .a },
                .{ .b = 21, .d = 23 },
            },
        }) |testcase| {
            const x_shape, const y_shape, const ctr, const z_shape = testcase;
            const x = Tensor.constant(x_shape, .{ .f32 = 0.0 });
            const y = Tensor.constant(y_shape, .{ .f32 = 0.0 });
            const z = x.dot(y, ctr);
            try zml.testing.expectEqualShapes(Shape.init(z_shape, .f32), z.shape());
        }
    }
/// Generalized matrix multiplication of two tensors along the specified axes.
/// In this version batching dimensions need to be explicitly specified.
/// The result shape is made of (batching_axes ++ lhs_result_axes ++ rhs_result_axes.
/// Where "result axes" are non-contracting, non-batching axes of each input tensor.
2023-01-27 14:35:11 +00:00
pub fn dotGeneral (
lhs : Tensor ,
rhs : Tensor ,
contracting_axes : [ ] const [ 2 ] i8 ,
batching_axes : [ ] const [ 2 ] i8 ,
) Tensor {
2023-06-21 14:45:14 +00:00
stdx . debug . assert ( lhs . dtype ( ) = = rhs . dtype ( ) , " dotGeneral expects tensors to be of the same type, got {} and {} " , . { lhs . dtype ( ) , rhs . dtype ( ) } ) ;
2023-01-02 14:28:25 +00:00
const Axes = std . BoundedArray ( i64 , MAX_RANK ) ;
var res_shape : Shape = . { . _dtype = lhs . dtype ( ) } ;
// Validate batching axes
var lhs_batching_axes : Axes = . { } ;
var rhs_batching_axes : Axes = . { } ;
for ( batching_axes ) | b_axes | {
const l , const r = b_axes ;
2023-06-21 14:45:14 +00:00
stdx . debug . assert ( lhs . _shape . dim ( l ) = = rhs . _shape . dim ( r ) , " dotGeneral expects batching dimensions to be equal, got {} and {} in {} and {} " , . { l , r , lhs , rhs } ) ;
var t = lhs . _shape . tag ( l ) ;
if ( t = = Shape . TagUnknown ) t = rhs . _shape . tag ( r ) ;
res_shape = res_shape . appendDim ( lhs . _shape . dim ( l ) , t ) ;
lhs_batching_axes . appendAssumeCapacity ( lhs . _shape . axis ( l ) ) ;
rhs_batching_axes . appendAssumeCapacity ( rhs . _shape . axis ( r ) ) ;
}
// Validate contracting axes
var lhs_contracting_axes : Axes = . { } ;
var rhs_contracting_axes : Axes = . { } ;
for ( contracting_axes ) | c_axes | {
const l , const r = c_axes ;
stdx . debug . assert ( lhs . _shape . dim ( l ) = = rhs . _shape . dim ( r ) , " dotGeneral expects contracting dimensions to be equal, got {} and {} in {} and {} " , . { l , r , lhs , rhs } ) ;
lhs_contracting_axes . appendAssumeCapacity ( lhs . _shape . axis ( l ) ) ;
rhs_contracting_axes . appendAssumeCapacity ( rhs . _shape . axis ( r ) ) ;
}
// Result shape is obtained by concatenating batching dimensions, (already done)
// then dimensions from lhs axes that aren't contracting nor batching,
// then dimensions from rhs axes that aren't contracting nor batching.
for ( 0 . . lhs . rank ( ) ) | l | {
if ( std . mem . indexOfScalar ( i64 , lhs_contracting_axes . constSlice ( ) , @intCast ( l ) ) ) | _ | {
continue ;
}
if ( std . mem . indexOfScalar ( i64 , lhs_batching_axes . constSlice ( ) , @intCast ( l ) ) ) | _ | {
continue ;
}
res_shape = res_shape . appendDim ( lhs . _shape . dim ( l ) , lhs . _shape . tag ( l ) ) ;
}
for ( 0 . . rhs . rank ( ) ) | r | {
if ( std . mem . indexOfScalar ( i64 , rhs_contracting_axes . constSlice ( ) , @intCast ( r ) ) ) | _ | {
continue ;
}
if ( std . mem . indexOfScalar ( i64 , rhs_batching_axes . constSlice ( ) , @intCast ( r ) ) ) | _ | {
continue ;
}
res_shape = res_shape . appendDim ( rhs . _shape . dim ( r ) , rhs . _shape . tag ( r ) ) ;
}
const mlir_ctx = lhs . getContext ( ) . mlirCtx ( ) ;
const loc = lhs.getContext().location(@src(), "dot({_},{_},contracting={any},batching={any})", .{ lhs, rhs, contracting_axes, batching_axes });
const op = dialect . stablehlo . dot_general (
mlir_ctx ,
lhs . value ( ) ,
rhs . value ( ) ,
mlir . ext . mlirType ( mlir_ctx , res_shape ) ,
loc ,
. {
. lhs_batching_dimensions = lhs_batching_axes . constSlice ( ) ,
. rhs_batching_dimensions = rhs_batching_axes . constSlice ( ) ,
. lhs_contracting_dimensions = lhs_contracting_axes . constSlice ( ) ,
. rhs_contracting_dimensions = rhs_contracting_axes . constSlice ( ) ,
. precision = . fast ,
} ,
) ;
return _result ( res_shape , op . result ( 0 ) ) ;
}
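// Added example (not in the original source): a shape-level sketch for
// `dotGeneral`, reusing the harness of `test dot` above. Axis pairs are
// positional: here axis 2 of lhs contracts with axis 1 of rhs, and axis 0
// of both is a batching pair.
test dotGeneral {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
    defer comp.deinit();
    comp.activate();
    defer comp.deactivate();

    const lhs = Tensor.constant(.{ .b = 2, .m = 5, .k = 3 }, .{ .f32 = 0.0 });
    const rhs = Tensor.constant(.{ .b = 2, .k = 3, .n = 7 }, .{ .f32 = 0.0 });
    // (b=2, m=5, k=3) · (b=2, k=3, n=7) -> (b=2, m=5, n=7)
    const z = lhs.dotGeneral(rhs, &.{.{ 2, 1 }}, &.{.{ 0, 0 }});
    try zml.testing.expectEqualShapes(Shape.init(.{ .b = 2, .m = 5, .n = 7 }, .f32), z.shape());
}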
/// Returns a Tensor containing the sigmoid function applied to each element of the input Tensor.
pub fn sigmoid ( self : Tensor ) Tensor {
const loc = self . getContext ( ) . mlirCtx ( ) . location ( @src ( ) ) ;
const op = dialect . stablehlo . logistic ( self . getContext ( ) . mlirCtx ( ) , self . value ( ) , loc ) ;
return _result ( self . _shape , op . result ( 0 ) ) ;
}
pub const logistic = sigmoid ;
/// Returns a Tensor containing the ReLU activation function applied to each element of the input Tensor.
pub fn relu ( self : Tensor ) Tensor {
return self . maximum ( Tensor . constant ( self . dims ( ) , self . dtype ( ) . zero ( ) ) ) ;
}
/// Returns a Tensor containing the leaky-ReLU activation function applied to each element of the input Tensor.
///
/// LeakyReLU(x) = max(0,x) + negative_slope * min(0,x)
/// ref: https://paperswithcode.com/method/leaky-relu
pub fn leakyReLU ( self : Tensor , negative_slope : f32 ) Tensor {
const below_zero = self . scale ( negative_slope ) . minimum ( Tensor . scalar ( 0 , self . dtype ( ) ) ) ;
return self . relu ( ) . add ( below_zero ) ;
}
test leakyReLU {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
const input = try zml . Buffer . fromSlice ( platform , . { 2 } , & [ _ ] f32 { - 0.6884 , 1.6795 } ) ;
const res = try zml . testing . compileAndCall ( platform , leakyReLU , . { input , 0.1 } ) ;
const expectation = zml . HostBuffer . fromArray ( & [ 2 ] f32 { - 0.0688 , 1.6795 } ) ;
try zml . testing . expectClose ( expectation , res , 1e-4 ) ;
}
/// Returns a Tensor containing the SwiGLU activation function applied to the input Tensor.
pub fn swiglu ( self : Tensor , beta : f32 , w : Tensor , b : Tensor ) Tensor {
const sigmoid_tensor = self . mul ( Tensor . constant ( self . _shape , Data . init ( self . dtype ( ) , beta ) ) ) . sigmoid ( ) ;
const one_minus_sigmoid_tensor = Tensor . constant ( self . _shape , Data . init ( self . dtype ( ) , 1 ) ) . sub ( sigmoid_tensor ) ;
return self . mul ( sigmoid_tensor ) . add ( one_minus_sigmoid_tensor . mul ( w . matmul ( self ) . add ( b ) ) ) ;
}
/// Returns a Tensor containing the Gaussian Error Linear Units (GeLU) activation function applied to each element of the input Tensor.
///
/// We use an approximation of the erf function using tanh:
/// gelu(x) ≃ 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
/// see: https://paperswithcode.com/method/gelu
pub fn gelu ( x : Tensor ) Tensor {
const scaled_x_cube = x . mul ( x ) . mul ( x ) . scale ( 0.044715 ) ;
const one = Tensor . constant ( x . _shape , x . dtype ( ) . one ( ) ) ;
const one_plus_tanh = Tensor . add ( x , scaled_x_cube ) . scale ( std . math . sqrt ( 2.0 / std . math . pi ) ) . tanh ( ) . add ( one ) ;
return one_plus_tanh . mul ( x ) . scale ( 0.5 ) ;
}
/// Returns a Tensor containing an approximation of the Gaussian Error Linear Units (GeLU) activation function applied to each element of the input Tensor.
///
/// It's an even more crude approximation than gelu.
pub fn quickGelu ( x : Tensor ) Tensor {
return x . scale ( 1.702 ) . sigmoid ( ) . mul ( x ) ;
}
/// Returns a Tensor containing the Sigmoid Linear Unit (SiLU) activation function applied to each element of the input Tensor.
///
/// silu(x) = x * σ(x)
/// https://paperswithcode.com/method/silu
pub fn silu ( x : Tensor ) Tensor {
return x . mul ( x . sigmoid ( ) ) ;
}
/// Returns a Tensor containing the softmax function applied to each element of the input Tensor.
pub fn softmax ( self : Tensor , axis_ : anytype ) Tensor {
const a = self . axis ( axis_ ) ;
const max_val = self . max ( a ) ;
const row_mask = max_val . cmp ( . GT , Tensor . scalar ( - std . math . inf ( f64 ) , self . dtype ( ) ) ) ;
const exp_diff_max = self.sub(max_val.broad(self._shape)).exp();
const res = exp_diff_max . div ( exp_diff_max . sum ( a ) . broad ( self . _shape ) ) ;
// If a row is all -inf, return all 0 instead of all NaN;
// this fixes attention when the mask hides a full row.
return row_mask . broad ( self . shape ( ) ) . select ( res , Tensor . scalar ( 0 , self . dtype ( ) ) ) ;
}
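// Added example (not in the original source): a numeric sketch for `softmax`,
// assuming the same test helpers as `test leakyReLU` above. Two equal logits
// normalize to 0.5 each.
test softmax {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    const Local = struct {
        pub fn _fwd(x: Tensor) Tensor {
            return x.softmax(0);
        }
    };
    const input = try zml.Buffer.fromSlice(platform, .{2}, &[_]f32{ 0.0, 0.0 });
    const res = try zml.testing.compileAndCall(platform, Local._fwd, .{input});
    const expectation = zml.HostBuffer.fromArray(&[2]f32{ 0.5, 0.5 });
    try zml.testing.expectClose(expectation, res, 1e-6);
}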
/// Returns a Tensor containing the log of the sum of exponential over the given axis.
pub fn logSumExp ( self : Tensor , axis_ : anytype ) Tensor {
const a = self . axis ( axis_ ) ;
// stabilization: shift `self` by its max value before taking the exponential.
const M = self . max ( a ) ;
const log_sum_exp = log ( sum ( exp ( self . sub ( M . broad ( self . _shape ) ) ) , a ) ) ;
// restore the shift again
return M . add ( log_sum_exp ) ;
}
/// Returns a Tensor containing the sum of elements over the given axis.
/// Output shape is the input shape with the axis_ dim set to 1.
pub fn sum ( self : Tensor , axis_ : anytype ) Tensor {
const a = self . axis ( axis_ ) ;
return ops . reduce (
struct {
pub fn acc ( x : Tensor , res : Tensor ) Tensor {
return res . add ( x . convert ( res . dtype ( ) ) ) ;
}
} . acc ,
self ,
Tensor . scalar ( 0 , self . dtype ( ) ) ,
& . { a } ,
) ;
}
/// Returns a Tensor containing the mean of elements over the given axis.
/// Output shape is the input shape with the axis_ dim set to 1.
pub fn mean ( self : Tensor , axis_ : anytype ) Tensor {
return self . sum ( axis_ ) . divByConst ( self . dim ( axis_ ) ) ;
}
/// Returns a Tensor containing the cumulative sum of elements over the given axis.
/// Output shape is the same as input shape.
/// [0, 1, 0, 1, 0, 0, 1, 1].cumulativeSum(0) -> [0, 1, 1, 2, 2, 2, 3, 4]
/// The last value contains the sum of all elements in the array.
pub fn cumulativeSum ( self : Tensor , axis_ : anytype ) Tensor {
const rk = self . rank ( ) ;
const a = self . axis ( axis_ ) ;
const ones = [ _ ] i64 { 1 } * * MAX_RANK ;
var window_dimensions = ones ;
window_dimensions [ a ] = self . dim ( a ) ;
var padding = [ _ ] [ 2 ] i64 { . { 0 , 0 } } * * MAX_RANK ;
padding [ a ] = . { self . dim ( a ) - 1 , 0 } ;
return ops . reduceWindow (
Tensor . add ,
self ,
Tensor . scalar ( 0 , self . dtype ( ) ) ,
. {
. base_dilations = ones [ 0 . . rk ] ,
. window_dilations = ones [ 0 . . rk ] ,
. window_strides = ones [ 0 . . rk ] ,
. window_dimensions = window_dimensions [ 0 . . rk ] ,
. padding = padding [ 0 . . rk ] ,
} ,
) ;
}
test cumulativeSum {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
const Local = struct {
pub fn _cumsum ( input : Tensor ) Tensor {
const x = input . withPartialTags ( . { . n } ) ;
const y = x . cumulativeSum ( . n ) ;
// Check that tags are propagated
std . debug . assert ( y . shape ( ) . eqlWithTags ( x . shape ( ) ) ) ;
return y ;
}
} ;
const x = try zml . Buffer . fromArray (
platform ,
[ 2 ] [ 5 ] f32 { . { 0 , 1 , 1 , 0 , 1 } , . { 3 , 1 , 0 , 2 , 1 } } ,
) ;
const res = try zml . testing . compileAndCall ( platform , Local . _cumsum , . { x } ) ;
try testing . expectEqual (
[ 2 ] [ 5 ] f32 { . { 0 , 1 , 2 , 2 , 3 } , . { 3 , 4 , 4 , 6 , 7 } } ,
try res . getValue ( [ 2 ] [ 5 ] f32 ) ,
) ;
}
/// Returns a transposed Tensor computed using the given axes.
pub fn transpose ( self : Tensor , axes_ : anytype ) Tensor {
const axes__ = self . axes ( axes_ ) . constSlice ( ) ;
const default_perm = [ MAX_RANK ] i64 { 7 , 6 , 5 , 4 , 3 , 2 , 1 , 0 } ;
const no_op = [ MAX_RANK ] i64 { 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 } ;
const permutation : [ ] const i64 = if ( axes__ . len = = 0 )
default_perm [ MAX_RANK - self . rank ( ) . . ]
else
toI64 ( axes__ ) ;
stdx . debug . assert ( permutation . len = = self . rank ( ) , " transpose expects input tensor rank and 'axes_' length to be equal, got {_} and {d} " , . { self , permutation [ 0 . . @min ( permutation . len , MAX_RANK + 2 ) ] } ) ;
if ( std . mem . eql ( i64 , permutation , no_op [ 0 . . self . rank ( ) ] ) ) {
return self ;
}
const res_shape = self . _shape . transpose ( permutation ) ;
if ( transposeIsJustAReshape ( self . shape ( ) , permutation ) ) {
return self . reshape ( res_shape ) ;
}
const loc = self . getContext ( ) . location ( @src ( ) , " transpose({_}, {d}) " , . { self , permutation } ) ;
const op = dialect . stablehlo . transpose (
self . getContext ( ) . mlirCtx ( ) ,
self . value ( ) ,
mlir . ext . mlirType ( self . getContext ( ) . mlirCtx ( ) , res_shape ) ,
loc ,
. { . permutation = toI64 ( permutation ) } ,
) ;
return _result ( res_shape , op . result ( 0 ) ) ;
}
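// Added example (not in the original source): a shape-level sketch for
// `transpose`; `.{}` reverses all axes, while a tuple of tags gives an
// explicit permutation.
test transpose {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
    defer comp.deinit();
    comp.activate();
    defer comp.deactivate();

    const x = Tensor.constant(.{ .a = 2, .b = 3, .c = 4 }, .{ .f32 = 0.0 });
    try zml.testing.expectEqualShapes(Shape.init(.{ .c = 4, .b = 3, .a = 2 }, .f32), x.transpose(.{}).shape());
    try zml.testing.expectEqualShapes(Shape.init(.{ .b = 3, .a = 2, .c = 4 }, .f32), x.transpose(.{ .b, .a, .c }).shape());
}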
pub fn swapAxes ( self : Tensor , a : anytype , b : anytype ) Tensor {
if ( self . axis ( a ) = = self . axis ( b ) ) return self ;
var perm : Shape . AxesArray = . { } ;
for ( 0 . . self . rank ( ) ) | i | {
perm . appendAssumeCapacity ( @intCast ( i ) ) ;
}
perm . set ( self . axis ( a ) , self . axis ( b ) ) ;
perm . set ( self . axis ( b ) , self . axis ( a ) ) ;
return self . transpose ( perm . constSlice ( ) ) ;
}
/// Returns a Tensor with the given axis unflattened.
///
/// unflatten((d0, d1, d2, d3), 2, n) -> (d0, d1, n, d2 / n, d3)
pub fn unflatten ( self : Tensor , axis_ : i8 , n : i64 ) Tensor {
// TODO: move to torch.zig, this is equivalent to `splitAxis`
stdx . debug . assert ( self . rank ( ) < Tensor . MAX_RANK , " unflatten expects input tensor rank to be less than {}, got {} " , . { Tensor . MAX_RANK , self . rank ( ) } ) ;
const a = if ( axis_ > = 0 ) self . axis ( axis_ ) else self . axis ( axis_ ) + 1 ;
const new_dim = std . math . divExact ( i64 , self . dim ( a ) , n ) catch std . debug . panic ( " unflatten expects chosen dimension to be divisible by 'n' but {} is not divisible by {} " , . { self . dim ( a ) , n } ) ;
const new_shape = self . _shape . set ( a , n ) . insert ( a + 1 , . { . _ = new_dim } ) ;
const loc = self . getContext ( ) . location ( @src ( ) , " axis={}, n={} " , . { axis_ , n } ) ;
const reshaped_val = dialect . stablehlo . reshape (
self . getContext ( ) . mlirCtx ( ) ,
self . value ( ) ,
mlir . ext . RankedTensorType . fromShape ( self . getContext ( ) . mlirCtx ( ) , new_shape ) ,
loc ,
) ;
return _result ( new_shape , reshaped_val . result ( 0 ) ) ;
}
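// Added example (not in the original source): a shape-level sketch for
// `unflatten`, matching the doc above: axis 1 of (2, 6) split with n=3
// becomes (2, 3, 2).
test unflatten {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
    defer comp.deinit();
    comp.activate();
    defer comp.deactivate();

    const x = Tensor.constant(.{ 2, 6 }, .{ .f32 = 0.0 });
    try zml.testing.expectEqualShapes(Shape.init(.{ 2, 3, 2 }, .f32), x.unflatten(1, 3).shape());
}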
/// Splits the given axis in several axes.
/// eg: `Tensor.init(.{ .a = 10, .b = 3 }).split(.a, .{.a1 = 5, .a2 = 2});`
/// The number of elements in the split shape must match the number of elements
/// in the target axis.
pub fn splitAxis ( self : Tensor , ax : anytype , split_shape : anytype ) Tensor {
const new_shape = self . _shape . splitAxis ( ax , split_shape ) ;
const loc = self . getContext ( ) . location ( @src ( ) , " splitAxis({}, {any}) " , . { ax , split_shape } ) ;
const reshaped_val = dialect . stablehlo . reshape (
self . getContext ( ) . mlirCtx ( ) ,
self . value ( ) ,
mlir . ext . RankedTensorType . fromShape ( self . getContext ( ) . mlirCtx ( ) , new_shape ) ,
loc ,
) ;
return _result ( new_shape , reshaped_val . result ( 0 ) ) ;
}
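// Added example (not in the original source): a shape-level sketch following
// the `splitAxis` doc example above.
test splitAxis {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
    defer comp.deinit();
    comp.activate();
    defer comp.deactivate();

    const x = Tensor.constant(.{ .a = 10, .b = 3 }, .{ .f32 = 0.0 });
    const y = x.splitAxis(.a, .{ .a1 = 5, .a2 = 2 });
    try zml.testing.expectEqualShapes(Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32), y.shape());
}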
/// Merges two or more contiguous axes into one axis.
pub fn merge ( self : Tensor , merges_ : anytype ) Tensor {
return self . reshape ( self . _shape . mergeAxes ( merges_ ) ) ;
}
/// Merges two or more non-contiguous axes into one axis.
/// Will make a transpose if needed.
/// .{ .a, .b, .c }.mergeTranspose(.{ .a, .c }, .ac) -> .{ .b, .ac }
pub fn mergeTranspose ( self : Tensor , axes_ : anytype , merged : EnumLiteral ) Tensor {
const cont = self . contiguous ( axes_ ) ;
return cont . reshape ( cont . _shape . mergeAxis ( axes_ , merged ) ) ;
}
/// Transposes the input Tensor, such that the given axes end up in contiguous positions.
/// .{ .a, .b, .c, .d }.contiguous(.{ .c, .a }) -> .{ .b, .d, .c, .a }
pub fn contiguous ( self : Tensor , axes_ : anytype ) Tensor {
const perm = self . _shape . contiguousPerm ( axes_ ) ;
return self . transpose ( perm . constSlice ( ) ) ;
}
/// Flattens the given axis and the next one, into one new axis.
pub fn flatten ( self : Tensor , axis_ : anytype ) Tensor {
// TODO: move to torch.zig, this is equivalent to merge
const old_shape = self . _shape ;
const a = self . axis ( axis_ ) ;
// stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis });
const new_shape = old_shape . remove ( a + 1 ) . set ( a , old_shape . dim ( a ) * old_shape . dim ( a + 1 ) ) ;
const loc = self . getContext ( ) . location ( @src ( ) , " flatten({_},{}) " , . { self , axis_ } ) ;
const reshaped_val = dialect . stablehlo . reshape (
self . getContext ( ) . mlirCtx ( ) ,
self . value ( ) ,
mlir . ext . RankedTensorType . fromShape ( self . getContext ( ) . mlirCtx ( ) , new_shape ) ,
loc ,
) ;
// log.debug("flatten({d}, {d}) -> {d}", .{ self.dims(), axis_, new_shape[0 .. self.rank() - 1] });
return _result ( new_shape , reshaped_val . result ( 0 ) ) ;
}
pub inline fn flattenAll ( self : Tensor ) Tensor {
// TODO: rename to just flatten, once flatten is moved to torch
return self . reshape ( . { self . count ( ) } ) ;
}
pub const Slice = struct {
start : i64 = 0 ,
end : ? i64 = null ,
step : i64 = 1 ,
} ;
/// Slices the input Tensor over the given axis using the given parameters.
pub fn slice1d ( self : Tensor , axis_ : anytype , s : Slice ) Tensor {
var slices = [ _ ] Slice { . { } } * * MAX_RANK ;
slices [ self . axis ( axis_ ) ] = s ;
return self . slice ( slices [ 0 . . self . rank ( ) ] ) ;
}
/// Slices the input Tensor using the given parameters.
pub fn slice ( self : Tensor , slices : [ ] const Slice ) Tensor {
var start_indices : [ MAX_RANK ] i64 = undefined ;
var strides : [ MAX_RANK ] i64 = undefined ;
var limit_indices : [ MAX_RANK ] i64 = undefined ;
var res_shape : Shape = self . _shape ;
for ( slices , 0 . . ) | s , a | {
stdx . debug . assert ( s . step > 0 , " slice expects 'step' to be positive, got {} at index {} " , . { s . step , a } ) ;
const args : Slice = . {
. start = self . wrapIndex ( a , s . start ) ,
. end = if ( s . end ) | end | self . wrapIndex ( a , end ) else self . dim ( a ) ,
. step = s . step ,
} ;
start_indices [ a ] = args . start ;
limit_indices [ a ] = args . end . ? ;
strides [ a ] = args . step ;
res_shape = res_shape . setDim ( a , std . math . divCeil ( i64 , args . end . ? - args . start , args . step ) catch unreachable ) ;
}
const mlir_ctx = self . getContext ( ) . mlirCtx ( ) ;
const loc = mlir_ctx . location ( @src ( ) ) . namedFmt ( mlir_ctx , " slices={any} " , . { slices } ) ;
const result_type = mlir . ext . RankedTensorType . fromShape ( mlir_ctx , res_shape ) . as ( mlir . Type ) ;
const slice_op = dialect . stablehlo . slice (
mlir_ctx ,
self . value ( ) ,
start_indices [ 0 . . self . rank ( ) ] ,
limit_indices [ 0 . . self . rank ( ) ] ,
strides [ 0 . . self . rank ( ) ] ,
result_type ,
loc ,
) ;
return _result ( res_shape , slice_op . result ( 0 ) ) ;
}
test slice {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
const x = try zml . Buffer . fromSlice ( platform , . { 2 , 5 } , & [ _ ] f32 { 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 } ) ;
// Wrap slice1d to hide the anytype in the signature.
const Local = struct {
pub fn _slice1dAxis ( input : Tensor , ax : i8 , slice_ : Tensor . Slice ) Tensor {
return input . slice1d ( ax , slice_ ) ;
}
} ;
{
const res = try zml . testing . compileAndCallWithTensors ( platform , Local . _slice1dAxis , . { x . shape ( ) , 0 , . { . end = 1 } } , . { x , 0 , . { . end = 1 } } ) ;
try testing . expectEqual ( [ 5 ] f32 { 0 , 1 , 2 , 3 , 4 } , try res . getValue ( [ 5 ] f32 ) ) ;
}
{
const res = try zml . testing . compileAndCallWithTensors ( platform , Local . _slice1dAxis , . { x . shape ( ) , 1 , . { . start = 1 , . step = 2 } } , . { x , 0 , . { . start = 1 , . step = 2 } } ) ;
try testing . expectEqual ( [ 4 ] f32 { 1 , 3 , 6 , 8 } , try res . getValue ( [ 4 ] f32 ) ) ;
}
{
const res = try zml . testing . compileAndCallWithTensors ( platform , Local . _slice1dAxis , . { x . shape ( ) , - 1 , . { . start = - 2 } } , . { x , 0 , . { . start = - 2 } } ) ;
try testing . expectEqual ( [ 4 ] f32 { 3 , 4 , 8 , 9 } , try res . getValue ( [ 4 ] f32 ) ) ;
}
}
inline fn wrapIndex ( self : Tensor , axis_ : usize , idx : i64 ) i64 {
return if ( idx < 0 ) self . dim ( axis_ ) + idx else idx ;
}
pub fn choose1d ( self : Tensor , axis_ : anytype , i : i64 ) Tensor {
// TODO: this use case could be handled directly by slice if we added a .single field
return self . slice1d ( axis_ , . { . start = i , . end = i + 1 } ) . squeeze ( axis_ ) ;
}
/// Concatenates the input Tensors along the given axis.
pub fn concatenate ( tensors : [ ] const Tensor , axis_ : anytype ) Tensor {
if ( tensors . len = = 1 ) return tensors [ 0 ] ;
stdx . debug . assert ( tensors . len < = 32 , " concatenate only supports up to 32 tensors, got {} " , . { tensors . len } ) ;
var buffer : [ 32 ] mlir . Value = undefined ;
std . debug . assert ( tensors . len < = buffer . len ) ;
std . debug . assert ( tensors . len > 0 ) ;
const a = tensors [ 0 ] . axis ( axis_ ) ;
// TODO(Corendos): Check that tensor axes match.
var concatenated_dim : i64 = 0 ;
for ( tensors , 0 . . ) | t , i | {
buffer [ i ] = t . value ( ) ;
concatenated_dim + = t . dim ( a ) ;
}
const res_shape = tensors [ 0 ] . _shape . set ( a , concatenated_dim ) ;
const ctx = tensors [ 0 ] . getContext ( ) ;
const loc = ctx . location ( @src ( ) , " axis={} " , . { axis_ } ) ;
const op = dialect . stablehlo . concatenate ( ctx . mlirCtx ( ) , buffer [ 0 . . tensors . len ] , a , loc ) ;
// log.debug("concatenate({}, {}, {d}) -> {d}", .{ tensors[0], tensors[1], a, res_shape });
return _result ( res_shape , op . result ( 0 ) ) ;
}
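// Added example (not in the original source): a shape-level sketch for
// `concatenate`; two (a=2, b=3) tensors concatenated along .a give (a=4, b=3).
test concatenate {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
    defer comp.deinit();
    comp.activate();
    defer comp.deactivate();

    const x = Tensor.constant(.{ .a = 2, .b = 3 }, .{ .f32 = 0.0 });
    const z = Tensor.concatenate(&.{ x, x }, .a);
    try zml.testing.expectEqualShapes(Shape.init(.{ .a = 4, .b = 3 }, .f32), z.shape());
}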
/// Concatenates the input Tensors along a new axis. The Tensors must have the same shape.
/// For x, y, z of shape .{ .a = 10, .b = 11, .c = 12 }:
/// - Tensor.stack(&.{x, y, z}, .b, .layers) -> .{ .a, .layers, .b, .c }
/// - Tensor.stack(&.{x, y, z}, 1, .layers) -> .{ .a, .layers, .b, .c }
/// - Tensor.stack(&.{x, y, z}, .last, .layers) -> .{ .a, .b, .c, .layers }
pub fn stack ( tensors : [ ] const Tensor , axis_ : anytype , tag : anytype ) Tensor {
// Note: we could ask the compilation context for some memory instead of stack allocating
stdx . debug . assert ( tensors . len < = 32 , " stack only supports up to 32 tensors, got {} " , . { tensors . len } ) ;
const shape0 = tensors [ 0 ] . _shape ;
const res_shape = shape0 . insertTag ( axis_ , 1 , tag ) ;
for ( tensors [ 1 . . ] ) | tensor | {
stdx . debug . assert ( shape0 . eqlWithTags ( tensor . _shape ) , " stack expects tensor shapes to match, got {} and {} " , . { shape0 , tensor . _shape } ) ;
}
var reshaped : [ 32 ] Tensor = undefined ;
for ( tensors , 0 . . ) | tensor , i | {
reshaped [ i ] = tensor . reshape ( res_shape ) ;
}
// Be careful here: we need to resolve ax before calling concatenate,
// because we added an axis, so axes at or after the insertion point have shifted by one.
const ax = if ( @TypeOf ( axis_ ) = = EnumLiteral and axis_ = = . last )
shape0 . rank ( )
else
shape0 . axis ( axis_ ) ;
return Tensor . concatenate ( reshaped [ 0 . . tensors . len ] , ax ) ;
}
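// Added example (not in the original source): a shape-level sketch matching
// the `stack` doc above.
test stack {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
    defer comp.deinit();
    comp.activate();
    defer comp.deactivate();

    const x = Tensor.constant(.{ .a = 10, .b = 11, .c = 12 }, .{ .f32 = 0.0 });
    const z = Tensor.stack(&.{ x, x, x }, .b, .layers);
    try zml.testing.expectEqualShapes(Shape.init(.{ .a = 10, .layers = 3, .b = 11, .c = 12 }, .f32), z.shape());
}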
/// Repeats a Tensor several times along the given axis.
///
/// * repeat1d(x, axis, 4) = concat(&.{x, x, x, x}, axis);
/// * repeat1d([0, 1, 2, 3], 0, 2) = [0, 1, 2, 3, 0, 1, 2, 3]
pub fn repeat1d ( self : Tensor , axis_ : anytype , n_rep : u63 ) Tensor {
if ( n_rep = = 1 ) {
return self ;
}
const a = self . axis ( axis_ ) ;
// Insert the repetition axis *before* `a` so whole blocks are repeated;
// compare stutter1d below, which inserts it after `a` to repeat each element in place.
const broadshape = self._shape.insert(a, .{n_rep});
const repeat_dims = Shape.range(self.rank() + 1, self.dtype()).remove(a);
var res = self . broadcast ( broadshape , repeat_dims . dims ( ) ) . flatten ( a ) ;
// Restore the tag that has been lost by flatten.
res . _shape . _tags . set ( a , self . _shape . tag ( a ) ) ;
return res ;
}
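// Added example (not in the original source): a numeric sketch matching the
// repeat1d doc above (and assuming the block-repeat fix in its body).
test repeat1d {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    const Local = struct {
        pub fn _repeat(input: Tensor) Tensor {
            return input.repeat1d(0, 2);
        }
    };
    const x = try zml.Buffer.fromSlice(platform, .{4}, &[_]f32{ 0, 1, 2, 3 });
    const res = try zml.testing.compileAndCall(platform, Local._repeat, .{x});
    try testing.expectEqual([8]f32{ 0, 1, 2, 3, 0, 1, 2, 3 }, try res.getValue([8]f32));
}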
/// Repeats a Tensor several times along the given axes.
pub fn repeat ( self : Tensor , n_reps : [ ] const u63 ) Tensor {
// TODO: this should support the tagged syntax: x.repeat(.{ .a = 3, .b = 2});
stdx . debug . assert ( n_reps . len = = self . rank ( ) , " repeat expects tensor rank and 'n_reps' length to be equal, got {} and {} " , . { self . rank ( ) , n_reps . len } ) ;
var res = self ;
for ( n_reps , 0 . . ) | n_rep , a | {
if ( n_rep = = 1 ) continue ;
res = res . repeat1d ( a , n_rep ) ;
}
return res ;
}
/// Repeats in line each value along the given axis.
///
/// * stutter1d([0, 1, 2, 3], 0, 2) = [0, 0, 1, 1, 2, 2, 3, 3]
pub fn stutter1d ( self : Tensor , axis_ : i64 , n_rep : u63 ) Tensor {
const a = self . axis ( axis_ ) ;
const broadshape = self . _shape . insert ( a + 1 , . { n_rep } ) ;
const stutter_dims = Shape . range ( self . rank ( ) + 1 , self . dtype ( ) ) . remove ( a + 1 ) ;
return self . broadcast ( broadshape , stutter_dims . dims ( ) ) . flatten ( a ) ;
}
/// Repeats in line each value along the given axes.
pub fn stutter ( self : Tensor , n_reps : [ ] const u63 ) Tensor {
// TODO: this should support the tagged syntax: x.stutter(.{ .a = 3, .b = 2 });
stdx . debug . assert ( n_reps . len = = self . rank ( ) , " stutter expects tensor rank and 'n_reps' length to be equal, got {} and {} " , . { self . rank ( ) , n_reps . len } ) ;
var res = self ;
for ( n_reps , 0 . . ) | n_rep , a | {
if ( n_rep = = 1 ) continue ;
res = res . stutter1d ( @intCast ( a ) , n_rep ) ;
}
return res ;
}
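// Added example (not in the original source): a numeric sketch matching the
// stutter1d doc above.
test stutter1d {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    const Local = struct {
        pub fn _stutter(input: Tensor) Tensor {
            return input.stutter1d(0, 2);
        }
    };
    const x = try zml.Buffer.fromSlice(platform, .{4}, &[_]f32{ 0, 1, 2, 3 });
    const res = try zml.testing.compileAndCall(platform, Local._stutter, .{x});
    try testing.expectEqual([8]f32{ 0, 0, 1, 1, 2, 2, 3, 3 }, try res.getValue([8]f32));
}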
/// Returns a Tensor containing the element-wise negation of the input Tensor.
pub fn negate ( self : Tensor ) Tensor {
const loc = self . getContext ( ) . mlirCtx ( ) . location ( @src ( ) ) ;
const negate_op = dialect . stablehlo . negate ( self . getContext ( ) . mlirCtx ( ) , self . value ( ) , loc ) ;
return _result ( self . _shape , negate_op . result ( 0 ) ) ;
}
/// Returns a Tensor containing the element-wise cosine of the input Tensor.
pub fn cos ( self : Tensor ) Tensor {
const loc = self . getContext ( ) . mlirCtx ( ) . location ( @src ( ) ) ;
const cosine_op = dialect . stablehlo . cosine ( self . getContext ( ) . mlirCtx ( ) , self . value ( ) , loc ) ;
return _result ( self . _shape , cosine_op . result ( 0 ) ) ;
}
/// Returns a Tensor containing the element-wise sine of the input Tensor.
pub fn sin ( self : Tensor ) Tensor {
const loc = self . getContext ( ) . mlirCtx ( ) . location ( @src ( ) ) ;
const sine_op = dialect . stablehlo . sine ( self . getContext ( ) . mlirCtx ( ) , self . value ( ) , loc ) ;
return _result ( self . _shape , sine_op . result ( 0 ) ) ;
}
/// Returns a Tensor containing the element-wise exponential operation of the input Tensor.
pub fn exp ( self : Tensor ) Tensor {
const loc = self . getContext ( ) . mlirCtx ( ) . location ( @src ( ) ) ;
const op = dialect . stablehlo . exponential ( self . getContext ( ) . mlirCtx ( ) , self . value ( ) , loc ) ;
return _result ( self . _shape , op . result ( 0 ) ) ;
}
/// Returns a Tensor containing the element-wise logarithm operation of the input Tensor.
pub fn log ( self : Tensor ) Tensor {
const loc = self . getContext ( ) . mlirCtx ( ) . location ( @src ( ) ) ;
const op = dialect . stablehlo . log ( self . getContext ( ) . mlirCtx ( ) , self . value ( ) , loc ) ;
return _result ( self . _shape , op . result ( 0 ) ) ;
}
/// Returns a Tensor containing the element-wise square-root of the input Tensor.
pub fn sqrt ( self : Tensor ) Tensor {
const loc = self . getContext ( ) . mlirCtx ( ) . location ( @src ( ) ) ;
const sqrt_op = dialect . stablehlo . sqrt ( self . getContext ( ) . mlirCtx ( ) , self . value ( ) , loc ) ;
return _result ( self . _shape , sqrt_op . result ( 0 ) ) ;
}
/// Returns a Tensor containing the element-wise reverse square-root of the input Tensor.
pub fn rsqrt ( self : Tensor ) Tensor {
const loc = self . getContext ( ) . mlirCtx ( ) . location ( @src ( ) ) ;
const rsqrt_op = dialect . stablehlo . rsqrt ( self . getContext ( ) . mlirCtx ( ) , self . value ( ) , loc ) ;
return _result ( self . _shape , rsqrt_op . result ( 0 ) ) ;
}
/// Returns a Tensor containing the element-wise hyperbolic tangent of the input Tensor.
pub fn tanh ( self : Tensor ) Tensor {
const loc = self . getContext ( ) . mlirCtx ( ) . location ( @src ( ) ) ;
const tanh_op = dialect . stablehlo . tanh ( self . getContext ( ) . mlirCtx ( ) , self . value ( ) , loc ) ;
return _result ( self . _shape , tanh_op . result ( 0 ) ) ;
}
/// Returns a Tensor containing the element-wise exponential minus one operation of the input Tensor.
pub fn exponentialMinusOne ( self : Tensor ) Tensor {
const loc = self . getContext ( ) . mlirCtx ( ) . location ( @src ( ) ) ;
const expm1_op = dialect . stablehlo . exponential_minus_one ( self . getContext ( ) . mlirCtx ( ) , self . value ( ) , loc ) ;
return _result ( self . _shape , expm1_op . result ( 0 ) ) ;
}
pub const ArangeArgs = HostBuffer . ArangeArgs ;
/// Returns a Tensor containing evenly spaced values within a given interval.
pub fn arange ( args : ArangeArgs , dt : DataType ) Tensor {
stdx.debug.assert(args.start <= args.end, "arange expects 'args.start' to be less than or equal to 'args.end', got {} and {}", .{ args.start, args.end });
stdx . debug . assert ( args . step > 0 , " arange expects 'args.step' to be positive, got {} " , . { args . step } ) ;
const ctx = CompilationContext . current ( ) ;
const loc = ctx . location ( @src ( ) , " arange({}, dtype={}) " , . { args , dt } ) ;
const n_steps = std . math . divCeil ( i64 , args . end - args . start , args . step ) catch unreachable ;
const sh = Shape . init ( . { n_steps } , dt ) ;
var op = dialect . stablehlo . iota ( ctx . mlirCtx ( ) , 0 , mlir . ext . mlirType ( ctx . mlirCtx ( ) , sh ) , loc ) ;
var res = _result ( sh , op . result ( 0 ) ) ;
if ( args . step ! = 1 ) {
res = res . scale ( args . step ) ;
}
if ( args . start ! = 0 ) {
res = res . addConstant ( args . start ) ;
}
return res ;
}
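// Added example (not in the original source): a numeric sketch for `arange`.
// The field names follow the asserts above; the exact ArangeArgs layout lives
// in hostbuffer.zig.
test arange {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    const Local = struct {
        pub fn _fwd() Tensor {
            return Tensor.arange(.{ .start = 0, .end = 8, .step = 2 }, .i32);
        }
    };
    const res = try zml.testing.compileAndCall(platform, Local._fwd, .{});
    try testing.expectEqual([4]i32{ 0, 2, 4, 6 }, try res.getValue([4]i32));
}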
/// Returns a Tensor containing values in increasing order starting from 0 along the given axis.
///
/// The output dtype will be `.i32`, unless the given axis has too big a dimension, in which case we use `.i64`.
/// In most programs this shouldn't matter, because typically the result is used in a comparison,
/// or explicitly converted by the user to do floating point arithmetic.
pub fn iota ( sh : Shape , axis_ : anytype ) Tensor {
const a = sh . axis ( axis_ ) ;
const dt : DataType = if ( sh . dim ( a ) < = std . math . maxInt ( i32 ) ) . i32 else . i64 ;
const res_shape = sh . withDtype ( dt ) ;
const ctx = CompilationContext . current ( ) ;
const loc = ctx . location ( @src ( ) , " iota({_}, {}) " , . { res_shape , a } ) ;
const mlir_ctx = ctx . mlirCtx ( ) ;
var op = dialect . stablehlo . iota (
mlir_ctx ,
a ,
mlir . ext . RankedTensorType . fromShape ( mlir_ctx , res_shape ) . as ( mlir . Type ) ,
loc ,
) ;
return _result ( res_shape , op . result ( 0 ) ) ;
}
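// Added example (not in the original source): a shape-level sketch for `iota`;
// the result keeps the given dims but switches to `.i32`, per the doc above.
test iota {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
    defer comp.deinit();
    comp.activate();
    defer comp.deactivate();

    const y = Tensor.iota(Shape.init(.{ .a = 2, .b = 3 }, .f32), .b);
    try zml.testing.expectEqualShapes(Shape.init(.{ .a = 2, .b = 3 }, .i32), y.shape());
}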
pub const LinspaceArgs = struct {
start : f64 ,
end : f64 ,
steps : i64 ,
} ;
/// Returns a Tensor containing 'args.steps' values evenly spaced from 'args.start' to 'args.end', inclusive.
pub fn linspace ( args : LinspaceArgs , dt : DataType ) Tensor {
stdx . debug . assert ( args . start < args . end , " linspace expects 'args.start' to be less than 'args.end', got {} and {} " , . { args . start , args . end } ) ;
stdx . debug . assert ( args . steps > 0 , " linspace expects 'args.steps' to be positive, got {} " , . { args . steps } ) ;
stdx . debug . assert ( dt . isFloat ( ) , " linspace expects type to be a float, got {} (hint: use arange instead) " , . { dt } ) ;
const ctx = CompilationContext . current ( ) ;
const loc = ctx . location ( @src ( ) , " linspace({}, dtype={}) " , . { args , dt } ) ;
const sh = Shape . init ( . { args . steps } , dt ) ;
var iota_op = dialect . stablehlo . iota ( ctx . mlirCtx ( ) , 0 , mlir . ext . mlirType ( ctx . mlirCtx ( ) , sh ) , loc ) ;
var res = _result ( sh , iota_op . result ( 0 ) ) ;
if (args.steps > 1) {
// iota yields 0, 1, ..., steps - 1; rescale so the last value lands exactly on `end`.
const step_size = (args.end - args.start) / @as(f64, @floatFromInt(args.steps - 1));
res = res.scale(step_size);
}
if ( args . start ! = 0 ) {
res = res . addConstant ( args . start ) ;
}
return res ;
}
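// Added example (not in the original source, and assuming the rescaling fix
// above): 5 evenly spaced values from 0 to 1 inclusive.
test linspace {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    const Local = struct {
        pub fn _fwd() Tensor {
            return Tensor.linspace(.{ .start = 0, .end = 1, .steps = 5 }, .f32);
        }
    };
    const res = try zml.testing.compileAndCall(platform, Local._fwd, .{});
    const expectation = zml.HostBuffer.fromArray(&[5]f32{ 0, 0.25, 0.5, 0.75, 1 });
    try zml.testing.expectClose(expectation, res, 1e-6);
}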
/// Returns a 0-rank Tensor with the given value.
pub fn scalar ( val : anytype , dt : DataType ) Tensor {
return Tensor . constant ( . { } , Data . init ( dt , val ) ) ;
}
test scalar {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
const Local = struct {
pub fn _fwd ( ) [ 6 ] Tensor {
var res : [ 6 ] Tensor = undefined ;
const dtypes = . { . bool , . u8 , . i32 , . f32 , . bf16 , . u64 } ;
inline for ( 0 . . 6 ) | i | res [ i ] = scalar ( 0 , dtypes [ i ] ) ;
return res ;
}
} ;
_ = try zml . testing . compileAndCall ( platform , Local . _fwd , . { } ) ;
}
/// Returns a constant Tensor with the given value.
pub fn constant ( dimz : anytype , val : Data ) Tensor {
const sh = Shape . init ( dimz , val . dtype ( ) ) ;
const singleton_sh = Shape . init ( . { } , val . dtype ( ) ) ;
const ctx = CompilationContext . current ( ) . mlirCtx ( ) ;
const loc = CompilationContext . current ( ) . location ( @src ( ) , " dims={d}, value={} " , . { sh , val } ) ;
const res_type = mlir . ext . RankedTensorType . fromShape ( ctx , singleton_sh ) ;
var constant_op = if ( mlir . ext . denseElementAttrType ( val . dtype ( ) ) ) | elem_type |
dialect . stablehlo . constant ( ctx , res_type , elem_type , val . constSlice ( ) , loc )
else blk : {
// Not all dtypes can be serialized in the IR. When that's not possible, use f32.
const val_f32 = val . as ( f32 ) ;
break : blk dialect . stablehlo . constant ( ctx , res_type , . f32 , std . mem . asBytes ( & val_f32 ) , loc ) ;
} ;
if ( sh . rank ( ) > 0 ) {
constant_op = dialect . stablehlo . broadcast_in_dim ( ctx , constant_op . result ( 0 ) , & . { } , mlir . ext . RankedTensorType . fromShape ( ctx , sh ) . as ( mlir . Type ) , loc ) ;
}
return _result ( sh , constant_op . result ( 0 ) ) . convert ( val . dtype ( ) ) ;
}
/// Embeds a buffer with concrete values into an Mlir program.
pub fn constantTensor ( val : HostBuffer ) Tensor {
const ctx = CompilationContext . current ( ) . mlirCtx ( ) ;
const result_type = mlir . ext . RankedTensorType . fromShape ( ctx , val . shape ( ) ) ;
const loc = ctx . location ( @src ( ) ) ;
const elem_type = mlir . ext . denseElementAttrType ( val . dtype ( ) ) orelse std . debug . panic ( " constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {} " , . { val . shape ( ) } ) ;
const constant_op = dialect . stablehlo . constant ( ctx , result_type , elem_type , val . data , loc ) ;
return _result ( val . shape ( ) , constant_op . result ( 0 ) ) ;
}
/// Returns a Tensor containing the result of the outer product between the input Tensors.
pub fn outer ( self : Tensor , other : Tensor ) Tensor {
if ( self . rank ( ) + other . rank ( ) = = 1 ) {
return self . mul ( other ) ;
}
const res_shape = self . shape ( ) . outer ( other . shape ( ) ) ;
return self . broad ( res_shape ) . mul ( other . broad ( res_shape ) ) ;
}
/// Given a tensor and a shape of the same rank,
/// will "broadcast" the given axes, so that `self` has the given shape.
/// This happens by virtually repeating the data several times along each given axis.
/// Note: most of the time the optimizer will make it so that the broadcast doesn't trigger a copy.
/// Note: the tags of the returned tensor will be from the `output_shape`.
/// This means if you use an un-tagged broadcast on a tagged tensor,
/// you will lose the tags.
/// To avoid this, favor `.broad(shape)` when working with tagged tensors.
pub fn broadcast ( self : Tensor , output_shape : Shape , axes_ : [ ] const i64 ) Tensor {
stdx . debug . assert ( axes_ . len = = self . rank ( ) , " broadcast expects axes_ to map all axes from self to axes of the output shape, got broadcast({}, {}, {d}) " , . { self , output_shape , axes_ } ) ;
for ( 0 . . , axes_ ) | self_ax , other_ax | {
const d = self . dim ( self_ax ) ;
stdx . debug . assert ( d = = 1 or d = = output_shape . dim ( other_ax ) , " broadcast expects shape axes to either be 1-sized or to match the target size. got broadcast({}, {}, {d}), error on self axis {} mapping to other axis {} " , . { self , output_shape , axes_ , self_ax , other_ax } ) ;
}
const res_shape = output_shape . withDtype ( self . dtype ( ) ) ;
if ( std . mem . eql ( i64 , self . dims ( ) , output_shape . dims ( ) ) ) {
// No broadcast needed. We don't emit a new stablehlo value
// but we propagate output_shape tags.
return _result ( res_shape , self . value ( ) ) ;
}
const ctx = self . getContext ( ) ;
const result_type = mlir . ext . RankedTensorType . fromShape ( ctx . mlirCtx ( ) , res_shape ) . as ( mlir . Type ) ;
const loc = ctx . location ( @src ( ) , " broadcast({_}, {_}, axes={d}) " , . { self , res_shape , axes_ } ) ;
const broadcast_op = dialect . stablehlo . broadcast_in_dim ( ctx . mlirCtx ( ) , self . value ( ) , axes_ , result_type , loc ) ;
return _result ( res_shape , broadcast_op . result ( 0 ) ) ;
}
/// Broadcasts a Tensor to the given shape, adding axes at the beginning.
pub fn broadcastLeft ( self : Tensor , output_shape : Shape ) Tensor {
stdx . debug . assert ( self . rank ( ) < = output_shape . rank ( ) , " broadcastLeft expects tensor rank to be less than output tensor rank, got {} and {} " , . { self . rank ( ) , output_shape . rank ( ) } ) ;
const a = output_shape . rank ( ) - self . rank ( ) ;
if ( self . rank ( ) = = output_shape . rank ( ) and std . mem . eql ( i64 , self . dims ( ) , output_shape . dims ( ) ) ) {
return self ;
}
return self . broadcast ( output_shape , Shape . range ( output_shape . rank ( ) , output_shape . dtype ( ) ) . dims ( ) [ a . . ] ) ;
}
/// Broadcasts a Tensor to the given shape, adding axes at the end.
pub fn broadcastRight ( self : Tensor , output_shape : Shape ) Tensor {
stdx . debug . assert ( self . rank ( ) < = output_shape . rank ( ) , " broadcastRight expects tensor rank to be less than output tensor rank, got {} and {} " , . { self . rank ( ) , output_shape . rank ( ) } ) ;
if ( self . rank ( ) = = output_shape . rank ( ) and self . _shape . eql ( output_shape ) ) {
return self ;
}
return self . broadcast ( output_shape , Shape . range ( self . rank ( ) , output_shape . dtype ( ) ) . dims ( ) ) ;
}
/// Broadcasts a Tensor to the given shape, extending dimensions if needed.
pub fn broad ( self : Tensor , other : Shape ) Tensor {
// TODO: broad is too restrictive because sometime you only want to specify one specific axis
// Note: if you code below, make sure to update Shape.canBroadcastTo.
stdx . debug . assert ( self . _shape . canBroadcastTo ( other ) , " Can't broadcast {} to {} " , . { self , other } ) ;
// Already the right shape
if ( std . mem . eql ( i64 , self . dims ( ) , other . dims ( ) ) ) return self ;
// Non ambiguous broadcasting
// TODO: broad is error prone because of this:
// it will happily broadcast .{ .a = 10, .b = 1 } to .{ .b = 10, .a = 5 }
if ( self . _shape . rank ( ) = = 0 or self . _shape . rank ( ) = = other . rank ( ) ) {
const all_axes = [ MAX_RANK ] i64 { 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 } ;
return self . broadcast ( other , all_axes [ 0 . . self . rank ( ) ] ) ;
}
// check that each axis of self maps to an axis of other
var axes_ : std . BoundedArray ( i64 , MAX_RANK ) = . { } ;
for ( self . _shape . tags ( ) ) | t | {
axes_ . appendAssumeCapacity ( @intCast ( other . axis ( t ) ) ) ;
}
return self . broadcast ( other , axes_ . constSlice ( ) ) ;
}
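// Added example (not in the original source): a shape-level sketch for
// `broad`; with tags, a per-row bias (.a) lines up by name against an
// (.a, .b) target shape.
test broad {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
    defer comp.deinit();
    comp.activate();
    defer comp.deactivate();

    const bias = Tensor.constant(.{ .a = 4 }, .{ .f32 = 0.0 });
    const target = Shape.init(.{ .a = 4, .b = 8 }, .f32);
    try zml.testing.expectEqualShapes(target, bias.broad(target).shape());
}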
/// Reshapes the input Tensor with the given shape.
pub fn reshape ( self : Tensor , output_shape_ : anytype ) Tensor {
const output_shape = self . _shape . reshape ( output_shape_ ) ;
const tensor_type = mlir . ext . RankedTensorType . fromShape ( self . getContext ( ) . mlirCtx ( ) , output_shape ) ;
const loc = self . getContext ( ) . location ( @src ( ) , " reshape({any}) " , . { output_shape } ) ;
const reshape_value = dialect . stablehlo . reshape ( self . getContext ( ) . mlirCtx ( ) , self . value ( ) , tensor_type , loc ) ;
return _result ( output_shape , reshape_value . result ( 0 ) ) ;
}
/// Converts the given 1 element Tensor into a 0-rank Tensor.
pub fn asScalar ( self : Tensor ) Tensor {
stdx.debug.assert(self.count() == 1, "Tensor.asScalar expects an input with exactly 1 element, got {}", .{self});
return self . reshape ( . { } ) ;
}
pub const Pad = struct {
low : i64 = 0 ,
high : i64 = 0 ,
interior : i64 = 0 ,
} ;
/// Pads the input Tensor with the given values.
/// Usage: x.pad(0, .{ .a = .{ .low = 1, .high = 1 }});
pub fn pad ( self : Tensor , padding_value : anytype , paddings : anytype ) Tensor {
const _paddings = self . shape ( ) . parseAxesOptions ( Pad , paddings , . { } ) ;
const ZEROS = [ _ ] i64 { 0 } * * MAX_RANK ;
var low = ZEROS ;
var high = ZEROS ;
var interior = ZEROS ;
var res_shape = self . _shape ;
for ( _paddings . constSlice ( ) , 0 . . ) | padding , i | {
low [ i ] = padding . low ;
high [ i ] = padding . high ;
interior [ i ] = padding . interior ;
var d : i64 = self . dim ( i ) ;
d + = low [ i ] + ( @max ( d - 1 , 0 ) * interior [ i ] ) + high [ i ] ;
res_shape . _dims . set ( i , d ) ;
}
const rk = self . rank ( ) ;
const mlir_ctx = self . getContext ( ) . mlirCtx ( ) ;
const loc = mlir_ctx . location ( @src ( ) ) . namedFmt ( mlir_ctx , " pad({},{}) " , . { padding_value , _paddings } ) ;
const pad_op = dialect . stablehlo . pad (
mlir_ctx ,
self . value ( ) ,
Tensor . scalar ( padding_value , self . dtype ( ) ) . value ( ) ,
. { . low = low [ 0 . . rk ] , . high = high [ 0 . . rk ] , . interior = interior [ 0 . . rk ] } ,
loc ,
) ;
return _result ( res_shape , pad_op . result ( 0 ) ) ;
}
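// Added example (not in the original source): a shape-level sketch following
// the `pad` usage comment above; one unit of padding on each side of .a grows
// it from 4 to 6.
test pad {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
    defer comp.deinit();
    comp.activate();
    defer comp.deactivate();

    const x = Tensor.constant(.{ .a = 4, .b = 2 }, .{ .f32 = 0.0 });
    const y = x.pad(0, .{ .a = .{ .low = 1, .high = 1 } });
    try zml.testing.expectEqualShapes(Shape.init(.{ .a = 6, .b = 2 }, .f32), y.shape());
}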
/// Inserts 1-dim axes at the given position, with the given tags.
/// `.{ .a = 5, .b = 4 }.insertAxes(.b, .{ .c, .d }) -> .{ .a = 5, .c = 1, .d = 1, .b = 4 }`
pub fn insertAxes ( self : Tensor , axis_ : anytype , tags : anytype ) Tensor {
const tags_ = Shape . parseTags ( tags ) ;
const ax = if ( @TypeOf ( axis_ ) = = EnumLiteral and axis_ = = . last )
self . rank ( )
else
self . axis ( axis_ ) ;
var res_shape = self . _shape ;
const ones = [ _ ] i64 { 1 } * * MAX_RANK ;
res_shape . _dims . insertSlice ( ax , ones [ 0 . . tags_ . len ] ) catch unreachable ;
res_shape . _tags . insertSlice ( ax , tags_ . constSlice ( ) ) catch unreachable ;
return self . reshape ( res_shape ) ;
}
/// Appends a 1-dim axis, with the given tag.
pub fn appendAxes ( self : Tensor , t : anytype ) Tensor {
// stdx.debug.assert(self.rank() < Tensor.MAX_RANK - t.len, "appendAxis expects tensor rank to be small enough in order to extend it, got {} and {} (max is {})", .{ self.rank(), t.len, Tensor.MAX_RANK });
return self . insertAxes ( . last , t ) ;
}
/// Drops a 1-dim axis at the given index
pub fn squeeze ( self : Tensor , axis_ : anytype ) Tensor {
const a = self . axis ( axis_ ) ;
stdx . debug . assert ( self . dim ( a ) = = 1 , " squeeze expects axis to be squeezed to have a dimension of 1, got {} " , . { self . dim ( a ) } ) ;
const new_shape = self . _shape . remove ( a ) ;
// log.debug("squeeze({}, {d}={d}) -> ({})", .{ self, axis, a, new_shape });
return _result ( new_shape , self . reshape ( new_shape ) . value ( ) ) ;
}
/// Returns a Tensor with the given axes reversed.
pub fn reverse ( self : Tensor , axes_ : anytype ) Tensor {
const actual_axes = self . _shape . axes ( axes_ ) ;
const loc = self . getContext ( ) . location ( @src ( ) , " reverse({any}) " , . { axes_ } ) ;
const reverse_op = dialect . stablehlo . reverse ( self . getContext ( ) . mlirCtx ( ) , self . value ( ) , toI64 ( actual_axes . constSlice ( ) ) , loc ) ;
return _result ( self . _shape , reverse_op . result ( 0 ) ) ;
}
pub const GatherOpts = struct { indices_are_sorted : bool = false } ;
/// For each coordinate in `indices`,
/// `gatherValues` extracts a single value of the given tensor.
///
/// * axes_ is a single axis, or a tuple of axis: .b, or .{ .b, .c }
/// * indices is an integer tensor
/// * result is a tensor whose shape is similar to the input shape
/// where the gathered axes have been replaced by axes from 'indices'.
///
/// Some example input for the base case where we work on one axis:
/// - gatherValues(f:[a]->float, .a, ind:[n]->int)[n] == f[ind[n]]
/// - gatherValues(f:[a, b], .a, ind:[n])[n, b] == f[ind[n], b]
/// - gatherValues(f: [a,b,c], .{.b}, ind: [n,m])[a, n, m, c] == f[a, ind[n, m], c]
///
/// If an axis is common to both `self` and `indices`,
/// it is treated as a "batching" axis, meaning that semantically
/// the operator is doing a gatherValues one time per dimension of this axis:
/// - gatherValues(f: [a,b,c], .{.b}, ind: [a,n])[a, n] == f[a, ind[a, n]]
///
/// It is an error to have an axis present in `self`, `axes_` and `indices`.
///
/// If several axes are passed, then the last axis of indices is treated as coordinates:
/// - gatherValues(f: [a,b,c], .{.b, .c}, ind: [n,2])[a, n] == f[a, ind[n][0], ind[n][1]]
/// - gatherValues(f: [a,b,c,d], .{.b, .c}, ind: [a, n,2])[a, n, d] == f[a, ind[a, n][0], ind[a, n][1], d]
///
/// It is possible to use gatherValues without tags, but batching won't be available.
pub fn gatherValues ( self : Tensor , coord_axes : anytype , indices : Tensor , opts : GatherOpts ) Tensor {
// scoped_log.debug("gatherValues({}, {any}, {})", .{ self, coord_axes, indices });
const single_coord , const coord_axes_ = _parseGatherCoord ( self , coord_axes ) ;
stdx . debug . assert ( coord_axes_ . len > 0 , " gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})` " , . { } ) ;
for ( coord_axes_ . constSlice ( ) , 0 . . ) | a , i | {
if ( i > 0 ) {
stdx . debug . assert ( a = = coord_axes_ . get ( i - 1 ) + 1 , " gatherValues expects 'coord_axes' to be sequential. But {any} aren't sequential in {} " , . { coord_axes , self } ) ;
}
}
const AxisKind = enum { batching , offset , collapsed , indices } ;
var self_kind : std . BoundedArray ( AxisKind , MAX_RANK ) = . { } ;
var indices_batch_axes : Shape . DimsArray = . { } ;
for ( self . _shape . tags ( ) , 0 . . self . rank ( ) ) | t , self_ax | {
const maybe_coord_ax = std . mem . indexOfScalar ( u3 , coord_axes_ . constSlice ( ) , @intCast ( self_ax ) ) ;
if ( indices . _shape . hasTag ( t ) ) | id_ax | {
// tag is both in self and indices -> it's a batching dim
// Note: tags are required for batching.
self_kind . appendAssumeCapacity ( . batching ) ;
indices_batch_axes . appendAssumeCapacity ( id_ax ) ;
stdx.debug.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. Axis {s} has been found in 'self={any}', in 'coord_axes={any}' and in 'indices={}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices });
} else if ( maybe_coord_ax ) | _ | {
// for gatherValues we collapsed all gathered axes
// (contrary to gatherSlices where we collapse none)
self_kind . appendAssumeCapacity ( . collapsed ) ;
} else {
self_kind . appendAssumeCapacity ( . offset ) ;
}
}
// When we receive several coord_axes, we need an extra dimension to store
// one index per axis; together they make up the coordinates of one value.
// Otherwise stablehlo uses the "indices.rank()" default value.
const index_coord_axis = if ( single_coord )
indices . rank ( )
else blk : {
const ax = indices . _shape . hasTag ( . coord ) orelse indices . _shape . axis ( - 1 ) ;
stdx . debug . assert ( indices . dim ( ax ) = = coord_axes_ . len , " gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {} " , . { coord_axes , coord_axes_ . len , indices } ) ;
break : blk ax ;
} ;
// compute res shape
var res_shape = Shape . init ( . { } , self . dtype ( ) ) ;
var res_kind : std . BoundedArray ( AxisKind , MAX_RANK ) = . { } ;
for ( self_kind . constSlice ( ) , 0 . . ) | kind , ax_usize | {
const ax : u3 = @intCast ( ax_usize ) ;
if ( ax = = coord_axes_ . get ( 0 ) ) {
// The first coord axis is special, because this is where we insert the indices axes.
for ( indices . _shape . tags ( ) , 0 . . indices . rank ( ) ) | t , id_ax | {
if ( id_ax = = index_coord_axis ) continue ;
if ( std . mem . indexOfScalar ( i64 , indices_batch_axes . constSlice ( ) , @intCast ( id_ax ) ) ) | _ | {
// batching dim are already in res
continue ;
}
res_shape = res_shape . appendDim ( indices . dim ( id_ax ) , t ) ;
res_kind . appendAssumeCapacity ( . indices ) ;
}
}
switch ( kind ) {
. collapsed = > continue ,
else = > {
res_shape = res_shape . appendDim ( self . dim ( ax ) , self . _shape . tag ( ax ) ) ;
res_kind . appendAssumeCapacity ( kind ) ;
} ,
}
}
// This is not a gather, but a dynamicSlice.
// Sometimes the backend recognizes this pattern, but not always.
// So let us handle that.
if ( indices . count ( ) = = 1 ) {
return self . dynamicSlice1d ( coord_axes_ . get ( 0 ) , . { . start = indices . flattenAll ( ) . squeeze ( 0 ) , . len = 1 } ) . reshape ( res_shape ) ;
}
var slice_dims : Shape . DimsArray = . { } ;
for ( self_kind . constSlice ( ) , self . dims ( ) ) | k , d | {
slice_dims . appendAssumeCapacity ( switch ( k ) {
. batching , . collapsed = > 1 ,
. offset = > d ,
. indices = > unreachable ,
} ) ;
}
// scoped_log.debug("gatherValues --> {} {any}", .{ res_shape, res_kind.constSlice() });
const loc = self . getContext ( ) . mlirCtx ( ) . location ( @src ( ) ) ;
const gather_op = dialect . stablehlo . gather (
self . getContext ( ) . mlirCtx ( ) ,
self . value ( ) ,
indices . value ( ) ,
slice_dims . constSlice ( ) ,
loc ,
. {
. offset_dims = _collectAxes ( AxisKind , res_kind , . offset ) . constSlice ( ) ,
. collapsed_slice_dims = _collectAxes ( AxisKind , self_kind , . collapsed ) . constSlice ( ) ,
. operand_batching_dims = _collectAxes ( AxisKind , self_kind , . batching ) . constSlice ( ) ,
. start_indices_batching_dims = indices_batch_axes . constSlice ( ) ,
. start_index_map = _collectAxes ( AxisKind , self_kind , . collapsed ) . constSlice ( ) ,
. index_vector_dim = index_coord_axis ,
. indices_are_sorted = opts . indices_are_sorted ,
} ,
) ;
const mlir_shape = fromMlirValue ( gather_op . result ( 0 ) ) . shape ( ) ;
stdx . debug . assert ( mlir_shape . eql ( res_shape ) , " gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={}, indices={}. You should transpose one or the other. " , . { self , indices } ) ;
return _result ( res_shape , gather_op . result ( 0 ) ) ;
}
test gatherValues {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
{
// Only test shapes
var comp = try zml . module . CompilationContext . init ( std . testing . allocator , " test " , platform ) ;
defer comp . deinit ( ) ;
comp . activate ( ) ;
defer comp . deactivate ( ) ;
inline for ( . {
. { . { . a = 10 } , . a , . { } , . { } } ,
. { . { . a = 10 } , . a , . { . n = 8 } , . { . n = 8 } } ,
. { . { . a = 10 , . b = 20 } , . a , . { } , . { . b = 20 } } ,
. { . { . a = 10 , . b = 20 } , . a , . { . n = 8 } , . { . n = 8 , . b = 20 } } ,
. { . { . a = 10 , . b = 20 } , 0 , . { . n = 8 } , . { . n = 8 , . b = 20 } } ,
// Favor val shape, instead of indices shape.
. { . { . a = 10 , . b = 20 } , . b , . { . n = 8 } , . { . a = 10 , . n = 8 } } ,
. { . { . a = 10 , . b = 20 , . c = 30 } , . b , . { . n = 8 } , . { . a = 10 , . n = 8 , . c = 30 } } ,
// batching axes are implicits.
. { . { . a = 10 , . b = 20 } , . b , . { . a = 10 } , . { . a = 10 } } ,
. { . { . a = 10 , . b = 20 } , . a , . { . b = 20 } , . { . b = 20 } } ,
. { . { . a = 10 , . b = 20 } , . b , . { . a = 10 , . n = 8 } , . { . a = 10 , . n = 8 } } ,
// stablehlo.gather is biased toward indices shape (like gatherSlice).
// This makes it awkward to use when you have both batching dimensions and new indices dimensions.
// For now we reject those, and let user explicitly transpose self or indices if needed.
// .{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8, .a = 10 }, .{ .a = 10, .n = 8 } },
// Also handle tuples
. { . { . a = 10 , . b = 20 } , . { . a , . b } , . { . n = 8 , . _ = 2 } , . { . n = 8 } } ,
. { . { 10 , 20 } , . { - 2 , - 1 } , . { 8 , 2 } , . { 8 } } ,
// and 1-tuple
. { . { . a = 10 , . b = 20 } , . { . b } , . { . n = 8 , . _ = 1 } , . { . a = 10 , . n = 8 } } ,
} ) | testcase | {
const x_shape , const tag , const idx_shape , const res_shape = testcase ;
const x = Tensor . constant ( x_shape , . { . f16 = 0 } ) ;
const idx = Tensor . constant ( idx_shape , . { . i32 = 0 } ) ;
const y = gatherValues ( x , tag , idx , . { } ) ;
try zml . testing . expectEqualShapes ( Shape . init ( res_shape , . f16 ) , y . shape ( ) ) ;
try std . testing . expect ( y . value ( ) . owner ( ) . verify ( ) ) ;
}
}
}
/// Gathers slices along the given axes with runtime indices.
/// * slice_shape represents the shape of the slices to extract,
/// it must be smaller than original shape.
/// It must use a subset of self axes.
/// If slice_shape is **not** tagged, then it must have the same rank as self.
/// * `indices` represents a set of coordinates.
/// The coordinates are read from the `.coord` axis, or last axis if `.coord` is not found.
/// The coordinate axis must have `slice_shape.rank()` dims.
/// The coordinates represent the "top-left" corner of the slice to extract.
/// * the output tensor starts with axes from `indices`.
/// * if the input tensor has tagged axes, matching `indices` axes,
/// they will be considered "batching" axes.
///
/// Sample input/output shapes:
/// * gatherSlices([A, B, C, D], .{.b=B', .c=C'}, [N, 2]) -> [N, A, B', C', D]
/// * gatherSlices(x(a,b,c,d), .{.b=B', .c=C'}, g(n,m)) = z(n, a, b', c', d) = x(a, g(n, 0) + b', g(n, 1) + c', d)
///
2023-01-18 12:03:48 +00:00
/// Note: the axis order of the result is different from gatherValues.
2023-01-02 14:28:25 +00:00
/// This is because gatherSlices, favorizes contiguous copy of the extracted slices,
2023-01-18 12:03:48 +00:00
/// while gatherValues, always copy values one by one, and as such don't have the same issues.
2023-01-02 14:28:25 +00:00
/// In our example the contiguous dimension .d is not sliced
/// and gatherSlices can copy data by group of C'*D elements.
2024-01-01 15:31:41 +00:00
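///
/// A minimal usage sketch of the tagged API (the sizes and tag names .a, .b, .n here are illustrative, mirroring the shape tests below):
///
/// ```
/// // x: .{ .a = 10, .b = 20 } tensor; idx: .{ .n = 8, ._ = 1 } i32 tensor of start offsets into .b.
/// const slices = x.gatherSlices(.{ .b = 4 }, idx, .{});
/// // slices: .{ .n = 8, .a = 10, .b = 4 }
/// ```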
pub fn gatherSlices ( self : Tensor , slice_shape_ : anytype , indices : Tensor , opts : GatherOpts ) Tensor {
const slice_shape = if ( @TypeOf ( slice_shape_ ) = = Shape ) slice_shape_ else Shape . init ( slice_shape_ , . i32 ) ;
// scoped_log.debug("gatherSlice({}, {_}, {})", .{ self, slice_shape, indices });
const tagged_api = slice_shape . isFullyTagged ( ) ;
if ( tagged_api ) {
for ( slice_shape . tags ( ) ) | t | {
stdx . debug . assert ( self . _shape . hasTag ( t ) ! = null , " gatherSlices expects `slice_shape` to only use tags from `self`. But {s} wasn't found in {} " , . { t , self } ) ;
}
} else {
// For untagged api, we require all slices to be specified.
// Note: we could relax this and right align the slice.
stdx . debug . assert ( slice_shape . rank ( ) = = self . rank ( ) , " gatherSlices expects `slice_shape.rank()` to match `self.rank()`. Got: gatherSlices({}, slice={_}). To avoid specifying all axes in `slice_shape`, you can use tags. " , . { self , slice_shape } ) ;
}
const index_coord_axis = indices . _shape . hasTag ( . coord ) orelse indices . _shape . axis ( - 1 ) ;
stdx . debug . assert ( indices . dim ( index_coord_axis ) = = slice_shape . rank ( ) , " gatherSlices({}, slice={_}, indices) expects 'indices' to be a tensor [..., {}], got {} " , . { self , slice_shape , slice_shape . rank ( ) , indices } ) ;
// Compute result shape
var res_shape = indices . _shape . remove ( index_coord_axis ) . withDtype ( self . dtype ( ) ) ;
var slice_dims = self . _shape . _dims ;
var self_batch_axes : std . BoundedArray ( i64 , MAX_RANK ) = . { } ;
var indices_batch_axes : std . BoundedArray ( i64 , MAX_RANK ) = . { } ;
var start_index_map : std . BoundedArray ( i64 , MAX_RANK ) = . { } ;
var self_offset_axes : std . BoundedArray ( i64 , MAX_RANK ) = . { } ;
for ( self . _shape . tags ( ) , 0 . . self . rank ( ) ) | t , self_ax | {
const maybe_slice_ax : ? u3 = if ( tagged_api ) slice_shape . hasTag ( t ) else @intCast ( self_ax ) ;
if ( tagged_api and indices . _shape . hasTag ( t ) ! = null ) {
// tag is both in self and indices -> it's a batching dim
// Note: tags are required for batching.
self_batch_axes . appendAssumeCapacity ( @intCast ( self_ax ) ) ;
indices_batch_axes . appendAssumeCapacity ( indices . _shape . axis ( t ) ) ;
slice_dims . set ( self_ax , 1 ) ;
stdx . debug . assert ( slice_shape . hasTag ( t ) = = null , " gatherSlices expects axes to be either batch or slice axes. Axis {s} has been found both in `slices={_}` and `indices={}` " , . { t , slice_shape , indices } ) ;
} else if ( maybe_slice_ax ) | slice_ax | {
// Specified axes contains the start offset of the slices,
// and are collected in `start_index_map`.
const slice_dim = slice_shape . dim ( slice_ax ) ;
stdx . debug . assert ( slice_dim < = self . _shape . dim ( self_ax ) , " gatherSlices expects `slice_shape` to be smaller than `self.shape()`. On axis {s}, got {} > {}. " , . { t , slice_shape , self . _shape } ) ;
slice_dims . set ( self_ax , slice_dim ) ;
res_shape = res_shape . appendDim ( slice_dim , t ) ;
start_index_map . appendAssumeCapacity ( @intCast ( self_ax ) ) ;
self_offset_axes . appendAssumeCapacity ( res_shape . rank ( ) - 1 ) ;
} else {
// non-batching, non-indexed axes
res_shape = res_shape . appendDim ( self . dim ( self_ax ) , t ) ;
self_offset_axes . appendAssumeCapacity ( res_shape . rank ( ) - 1 ) ;
}
}
const loc = self . getContext ( ) . location ( @src ( ) , " gatherSlices({_}, slice_shape={_}, idx={_}) " , . { self , slice_shape , indices } ) ;
const gather_op = dialect . stablehlo . gather (
self . getContext ( ) . mlirCtx ( ) ,
self . value ( ) ,
indices . value ( ) ,
slice_dims . constSlice ( ) ,
loc ,
. {
. offset_dims = self_offset_axes . constSlice ( ) ,
. collapsed_slice_dims = & . { } ,
. operand_batching_dims = self_batch_axes . constSlice ( ) ,
. start_indices_batching_dims = indices_batch_axes . constSlice ( ) ,
. start_index_map = start_index_map . constSlice ( ) ,
. index_vector_dim = index_coord_axis ,
. indices_are_sorted = opts . indices_are_sorted ,
} ,
) ;
return _result ( res_shape , gather_op . result ( 0 ) ) ;
}
test gatherSlices {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
const Local = struct {
pub fn _gatherSlices ( self : Tensor , slice_shape : Shape , indices : Tensor , opts : GatherOpts ) Tensor {
return self . gatherSlices ( slice_shape , indices , opts ) ;
}
} ;
{
// Only test shapes
var comp = try zml . module . CompilationContext . init ( std . testing . allocator , " test " , platform ) ;
defer comp . deinit ( ) ;
comp . activate ( ) ;
defer comp . deactivate ( ) ;
inline for ( . {
. { . { . a = 10 } , . { } , . { . _ = 0 } , . { . a = 10 } } ,
. { . { . a = 10 } , . { . a = 7 } , . { . _ = 1 } , . { . a = 7 } } ,
. { . { . a = 10 } , . { . a = 7 } , . { . n = 8 , . _ = 1 } , . { . n = 8 , . a = 7 } } ,
. { . { . a = 10 } , . { . a = 7 } , . { . coord = 1 , . n = 8 } , . { . n = 8 , . a = 7 } } ,
// tags aren't required.
. { . { 10 } , . { 7 } , . { . n = 8 , . _ = 1 } , . { . n = 8 , . _ = 7 } } ,
. { . { . a = 10 , . b = 20 } , . { . a = 7 } , . { . _ = 1 } , . { . a = 7 , . b = 20 } } ,
. { . { . a = 10 , . b = 20 } , . { . a = 7 } , . { . n = 8 , . _ = 1 } , . { . n = 8 , . a = 7 , . b = 20 } } ,
. { . { . a = 10 , . b = 20 } , . { . a = 7 } , . { . n = 8 , . coord = 1 , . m = 9 } , . { . n = 8 , . m = 9 , . a = 7 , . b = 20 } } ,
. { . { . a = 10 , . b = 20 } , . { . b = 17 } , . { . n = 8 , . _ = 1 } , . { . n = 8 , . a = 10 , . b = 17 } } ,
. { . { . a = 10 , . b = 20 } , . { . a = 7 , . b = 17 } , . { . n = 8 , . _ = 2 } , . { . n = 8 , . a = 7 , . b = 17 } } ,
// Note: currently the order of the axes in the slice is not used.
. { . { . a = 10 , . b = 20 } , . { . b = 17 , . a = 7 } , . { . n = 8 , . _ = 2 } , . { . n = 8 , . a = 7 , . b = 17 } } ,
. { . { . a = 10 , . b = 20 , . c = 20 } , . { . b = 17 } , . { . n = 8 , . _ = 1 } , . { . n = 8 , . a = 10 , . b = 17 , . c = 20 } } ,
// batching dims
. { . { . a = 10 , . b = 20 } , . { . b = 17 } , . { . a = 10 , . _ = 1 } , . { . a = 10 , . b = 17 } } ,
. { . { . b = 200 , . a = 100 , . c = 300 } , . { . c = 300 } , . { . a = 100 , . b = 200 , . _ = 1 } , . { . a = 100 , . b = 200 , . c = 300 } } ,
} ) | testcase | {
const x_shape , const slice_dims , const idx_shape , const res_shape = testcase ;
const x = Tensor . constant ( x_shape , . { . f16 = 0 } ) ;
const slice_shape = Shape . init ( slice_dims , . u16 ) ;
const idx = Tensor . constant ( idx_shape , . { . i32 = 0 } ) ;
const y = gatherSlices ( x , slice_shape , idx , . { } ) ;
try zml . testing . expectEqualShapes ( Shape . init ( res_shape , . f16 ) , y . shape ( ) ) ;
try std . testing . expect ( y . value ( ) . owner ( ) . verify ( ) ) ;
const mod = try zml . compileFn (
std . testing . allocator ,
Local . _gatherSlices ,
. { x . shape ( ) , slice_shape , idx . shape ( ) , . { . indices_are_sorted = true } } ,
platform ,
) ;
defer mod . deinit ( ) ;
}
}
// Test with actual values.
const range = try zml . HostBuffer . arange ( std . testing . allocator , . { . end = 2 * 4 * 6 } , . u16 ) ;
defer range . deinit ( std . testing . allocator ) ;
const operand = try range . reshape ( . { . a = 2 , . b = 4 , . c = 6 } ) . toDevice ( platform ) ;
defer operand . deinit ( ) ;
const start_indices = ( try zml . Buffer . fromArray ( platform , [ 2 ] [ 2 ] i32 { . { 2 , 1 } , . { 0 , 3 } } ) ) . withTags ( . { . n , . _ } ) ;
defer start_indices . deinit ( ) ;
const result = try zml . testing . compileAndCall ( platform , Local . _gatherSlices , . { operand , Shape . init ( . { . b = 2 , . c = 3 } , . u16 ) , start_indices , . { } } ) ;
const expected = zml . HostBuffer . fromArray ( & [ 2 ] [ 2 ] [ 2 ] [ 3 ] u16 {
. {
. { . { 13 , 14 , 15 } , . { 19 , 20 , 21 } } ,
. { . { 37 , 38 , 39 } , . { 43 , 44 , 45 } } ,
} ,
. {
. { . { 3 , 4 , 5 } , . { 9 , 10 , 11 } } ,
. { . { 27 , 28 , 29 } , . { 33 , 34 , 35 } } ,
} ,
} ) ;
try zml . testing . expectClose ( expected , result , 0 ) ;
}
pub const ScatterOpts = struct {
/// Promises to the scatter op that all coordinates in `indices` are sorted, with respect to the final offset in `self`.
/// Result is undefined if the promise is violated.
indices_are_sorted : bool = false ,
/// Promises to the scatter op that slices don't overlap.
/// Result is undefined if the promise is violated.
/// This allows for better code generation, because it means that updates can be applied in parallel.
indices_are_unique : bool = false ,
/// Function used to update the previous value in `self` with values from `updates`.
/// If `update_fn` is not associative (i.e. the order of execution matters),
/// then you should make sure the slices don't overlap,
/// otherwise the result will depend on the runtime scheduling
/// of the operator, which is backend specific.
update_fn : * const fn ( Tensor , Tensor ) Tensor = increment ,

pub fn increment ( old_value : Tensor , new_value : Tensor ) Tensor {
return old_value . add ( new_value ) ;
}

pub fn override ( old_value : Tensor , new_value : Tensor ) Tensor {
_ = old_value ;
return new_value ;
}
} ;
/// Updates the given tensor, by copying `updates` slice by slice into `self`.
/// The slices are chosen at runtime by interpreting indices as coordinates into `self`.
/// This is a generalized version of `dynamicUpdateSlice` where more than one offset can be specified at a time.
///
/// ### Arguments
///
/// - Returns a tensor with the same shape as `self`, with updated content.
/// - `indices` is a set of Tensors (typically rank 1), representing coordinates into `self`.
/// All indices must have the same shape, but scalars are accepted.
/// - each `indices` entry contains offsets along an axis into `self`.
/// Typically axes are identified by their tags, but in the absence of tags on `indices`,
/// the entries in indices will be assigned to axes of `self` from major to minor axis.
/// It is recommended to have indices referencing only major axes of `self` for better performance.
/// - `updates` shape is obtained by concatenating the shape of `indices` with the shape of the slices to be extracted.
/// - `opts`: `zml.Tensor.ScatterOpts` describing the scatter behavior.
///
/// ### Sample input/output shapes with corresponding pseudo-code.
///
/// Basic `scatterSlices` with the first two axes (.a, .b) being indexed, and full (.c, .d) slice copies:
///
/// ```
/// fn scatterSlices(x[A, B, C, D], .{.a=off_a[N], .b=off_b[N]}, y[N, C, D]) [A, B, C, D] {
/// var z = x;
/// for (0..N) |n| {
/// for (0..C) |c| for (0..D) |d| {{
/// z[off_a[n],off_b[n],c,d] += y[n, c, d];
/// }}
/// }
/// return z;
/// }
/// ```
///
/// `scatterSlices` with the first three axes (.a, .b, .c) being indexed, and a partial copy of (.c, .d).
/// Note that the .c axis is present both in the indices and updates, and `updates.dim(.c) < self.dim(.c)`.
///
/// ```
/// fn scatterSlices(x[A, B, C, D], .{.a=off_a[N], .b=off_b[N], .c=off_c[N]}, y[N, C', D]) [A, B, C, D] {
/// var z = x;
/// for (0..N) |n| {
/// for (0..C') |c| for (0..D) |d| {{
/// z[off_a[n],off_b[n],off_c[n]+c,d] += y[n, c, d];
/// }}
/// }
/// return z;
/// }
/// ```
///
/// `scatterSlices` with the first axis .a being indexed, and where .b is used as a batching axis.
/// Note that here the .b axis is present in `self`, `off_a`, and `updates`,
/// and is not mentioned in the axes of indices.
///
/// ```
/// fn scatterSlices(x[A, B, C, D], .{.a=off_a[B,N]}, y[N, B, C, D]) [A, B, C, D] {
/// var z = x;
/// for (0..B) |b| {
/// for (0..N) |n| {
/// for (0..C) |c| for (0..D) |d| {{
/// z[off_a[b,n],b,c,d] += y[n, b, c, d];
/// }}
/// }
/// }
/// return z;
/// }
/// ```
///
/// ### Warnings
///
/// - if `opts.update_fn` is not associative, not all calls to `scatterSlices` are sound.
/// In particular if you scatter overlapping slices, with `zml.Tensor.ScatterOpts.override`,
/// then the result will depend on the execution order, which you don't control.
/// - `scatterSlices` is a very expressive operator, and can lead to complicated code generation
/// that requires host<->device synchronization.
/// ZML tries to generate the easiest-to-optimize IR, and will warn you if it generates known problematic IR.
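///
/// A minimal usage sketch (the tags .a, .n and the sizes are illustrative, mirroring the shape tests below):
///
/// ```
/// // x: .{ .a = 10, .b = 20 } tensor; off: .{ .n = 8 } i32 tensor of start offsets into .a;
/// // updates: .{ .n = 8, .a = 2 } tensor: 8 updates of 2 rows each.
/// const y = x.scatterSlices(.{ .a = off }, updates, .{});
/// // y has the same shape as x.
/// ```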
pub fn scatterSlices ( self : Tensor , indices : anytype , updates : Tensor , opts : ScatterOpts ) Tensor {
scoped_log . debug ( " scatterSlices({}, {any}, {}) " , . { self , indices , updates } ) ;
const UpdateType = @TypeOf ( ScatterOpts . increment ) ;
const Custom = struct {
pub fn inc ( custom : * const UpdateType , old_value : Tensor , new_value : Tensor ) Tensor {
return @call ( . auto , custom , . { old_value , new_value } ) ;
2023-02-14 13:52:49 +00:00
}
} ;
return ops . scatter ( Tensor , * const UpdateType , Custom . inc , self , opts . update_fn , indices , updates , opts ) ;
}
test scatterSlices {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
const Local = struct {
pub fn _scatter ( self : Tensor , indices : [ ] const Tensor , updates : Tensor ) Tensor {
return self . scatterSlices (
indices ,
updates ,
. { . update_fn = ScatterOpts . increment } ,
) ;
}
pub fn _scatterCB ( self : Tensor , coords : Tensor , updates : Tensor ) Tensor {
return self . scatterSlices (
. { . c = coords . choose1d ( . coord , 0 ) , . b = coords . choose1d ( . coord , 1 ) } ,
updates ,
. { . update_fn = ScatterOpts . increment } ,
) ;
}
pub fn _idx ( idx_shape : anytype ) Tensor {
return Tensor . constant ( idx_shape , . { . i32 = 0 } ) ;
}
} ;
{
// Only test shapes
var comp = try zml . module . CompilationContext . init ( std . testing . allocator , " test " , platform ) ;
defer comp . deinit ( ) ;
comp . activate ( ) ;
defer comp . deactivate ( ) ;
const idx = Local . _idx ;
inline for ( . {
// This is equivalent to a dynamic update slice, update 3 values at given offset of axis .a:
. { . { . a = 10 } , . { . a = idx ( . { } ) } , . { . a = 3 } } ,
// Use .a as a batching axis with .a=10 x .n=8 updates of 2 elements of .b
. { . { . a = 10 , . b = 20 } , . { . b = idx ( . { . a = 10 , . n = 8 } ) } , . { . a = 10 , . n = 8 , . b = 2 } } ,
// Same but with update transposed
. { . { . a = 10 , . b = 20 } , . { . b = idx ( . { . a = 10 , . n = 8 } ) } , . { . a = 10 , . b = 2 , . n = 8 } } ,
// similar, but use the normalized form where a is no longer an explicit batching axis.
. { . { . a = 10 , . b = 20 } , . { . a = idx ( . { . a2 = 10 , . n = 8 } ) , . b = idx ( . { . a2 = 10 , . n = 8 } ) } , . { . a2 = 10 , . n = 8 , . b = 2 } } ,
. { . { . a = 10 , . b = 20 } , . { . a = idx ( . { . a = 10 , . n = 8 } ) , . b = idx ( . { . a = 10 , . n = 8 } ) } , . { . a = 10 , . n = 8 , . b = 2 } } ,
. { . { . a = 10 , . b = 20 } , . { . a = idx ( . { . n = 8 } ) } , . { . n = 8 , . a = 2 } } ,
. { . { . a = 10 , . b = 20 } , . { . b = idx ( . { . n = 8 } ) , . a = idx ( . { . n = 8 } ) } , . { . n = 8 , . a = 3 , . b = 2 } } ,
. { . { . a = 10 , . b = 20 } , . { . a = idx ( . { . n = 8 } ) , . b = idx ( . { . n = 8 } ) } , . { . a = 3 , . n = 8 , . b = 2 } } ,
} ) | testcase | {
const x_shape , const indices , const updates_shapes = testcase ;
const x = Tensor . constant ( x_shape , . { . f16 = 0 } ) ;
const updates = Tensor . constant ( updates_shapes , . { . f16 = 0 } ) ;
const y = scatterSlices ( x , indices , updates , . { } ) ;
// Shape doesn't change with scatterSlices
try zml . testing . expectEqualShapes ( x . shape ( ) , y . shape ( ) ) ;
try std . testing . expect ( y . value ( ) . owner ( ) . verify ( ) ) ;
}
}
// Test with actual values, no batching.
{
const a_host = try zml . HostBuffer . arange ( std . testing . allocator , . { . end = 9 } , . i32 ) ;
const a = ( try zml . Buffer . from ( platform , a_host . reshape ( . { 3 , 3 } ) ) ) . withTags ( . { . a , . b } ) ;
defer a . deinit ( ) ;
a_host . deinit ( std . testing . allocator ) ;
const scatter_indices = try zml . Buffer . fromArray ( platform , [ 2 ] i32 { 0 , 2 } ) ;
const updates = try zml . Buffer . fromArray ( platform , [ 2 ] [ 3 ] i32 { . { 10 , 20 , 30 } , . { 70 , 80 , 90 } } ) ;
const expected = [ 3 ] [ 3 ] i32 { . { 10 , 21 , 32 } , . { 3 , 4 , 5 } , . { 76 , 87 , 98 } } ;
const result = try zml . testing . compileAndCall ( platform , Local . _scatter , . {
a ,
& . { scatter_indices . withTags ( . { . n } ) } ,
updates . withTags ( . { . n , . b } ) ,
} ) ;
try std . testing . expect ( a . shape ( ) . eql ( result . shape ( ) ) ) ;
try std . testing . expectEqual ( expected , result . getValue ( @TypeOf ( expected ) ) ) ;
}
// Test with setting individual values (no batching)
{
const a_host = try zml . HostBuffer . arange ( std . testing . allocator , . { . end = 9 } , . i32 ) ;
const a = try zml . Buffer . from ( platform , a_host ) ;
defer a . deinit ( ) ;
a_host . deinit ( std . testing . allocator ) ;
const scatter_indices = try zml . Buffer . fromArray ( platform , [ 2 ] i32 { 2 , 7 } ) ;
const updates = try zml . Buffer . fromArray ( platform , [ 2 ] i32 { 20 , 70 } ) ;
const expected = [ 9 ] i32 { 0 , 1 , 22 , 3 , 4 , 5 , 6 , 77 , 8 } ;
const result = try zml . testing . compileAndCall ( platform , Local . _scatter , . {
a ,
& . { scatter_indices . withTags ( . { . n } ) } ,
updates . withTags ( . { . n } ) ,
} ) ;
try std . testing . expect ( a . shape ( ) . eql ( result . shape ( ) ) ) ;
try std . testing . expectEqual ( expected , result . getValue ( @TypeOf ( expected ) ) ) ;
}
{
// Test with actual values and batching along axis .a
const operand = try zml . Buffer . constant ( platform , Shape . init ( . { . a = 2 , . b = 3 , . c = 4 , . d = 2 } , . u16 ) , 0 ) ;
defer operand . deinit ( ) ;
const start_indices = ( try zml . Buffer . fromArray (
platform ,
[ 2 ] [ 2 ] [ 3 ] [ 2 ] i32 {
. {
. { . { 0 , 0 } , . { 1 , 0 } , . { 2 , 1 } } ,
. { . { 0 , 1 } , . { 1 , 1 } , . { 0 , 9 } } ,
} ,
. {
. { . { 0 , 0 } , . { 2 , 1 } , . { 2 , 2 } } ,
. { . { 1 , 2 } , . { 0 , 1 } , . { 1 , 0 } } ,
} ,
} ,
) ) . withTags ( . { . n , . a , . m , . coord } ) ;
defer start_indices . deinit ( ) ;
const values = try zml . Buffer . constant (
platform ,
Shape . init ( . { . n = 2 , . a = 2 , . m = 3 , . c = 2 , . d = 2 } , . u16 ) ,
1 ,
) ;
defer values . deinit ( ) ;
const result = try zml . testing . compileAndCall ( platform , Local . _scatterCB , . { operand , start_indices , values } ) ;
const expected = [ 2 ] [ 3 ] [ 4 ] [ 2 ] u16 {
. {
. { . { 2 , 2 } , . { 3 , 3 } , . { 1 , 1 } , . { 0 , 0 } } ,
. { . { 0 , 0 } , . { 0 , 0 } , . { 2 , 2 } , . { 2 , 2 } } ,
. { . { 0 , 0 } , . { 0 , 0 } , . { 1 , 1 } , . { 1 , 1 } } ,
} ,
. {
. { . { 0 , 0 } , . { 1 , 1 } , . { 1 , 1 } , . { 0 , 0 } } ,
. { . { 2 , 2 } , . { 3 , 3 } , . { 1 , 1 } , . { 0 , 0 } } ,
. { . { 0 , 0 } , . { 1 , 1 } , . { 1 , 1 } , . { 0 , 0 } } ,
} ,
} ;
try std . testing . expect ( operand . shape ( ) . eql ( result . shape ( ) ) ) ;
try std . testing . expectEqual ( expected , result . getValue ( @TypeOf ( expected ) ) ) ;
}
}
/// Returns a Tensor containing the maximum over a given axis.
pub fn max ( self : Tensor , axis_ : anytype ) Tensor {
const a = self . axis ( axis_ ) ;
return ops . reduce (
struct {
pub fn cmp ( x : Tensor , res : Tensor ) Tensor {
return res . maximum ( x . convert ( res . dtype ( ) ) ) ;
}
} . cmp ,
self ,
Tensor . constant ( & . { } , self . dtype ( ) . minValue ( ) ) ,
& . { a } ,
) ;
}
/// Returns a Tensor containing the minimum over a given axis.
pub fn min ( self : Tensor , axis_ : anytype ) Tensor {
const a = self . axis ( axis_ ) ;
return ops . reduce (
struct {
pub fn cmp ( x : Tensor , res : Tensor ) Tensor {
return res . minimum ( x . convert ( res . dtype ( ) ) ) ;
}
} . cmp ,
self ,
Tensor . constant ( & . { } , self . dtype ( ) . maxValue ( ) ) ,
& . { a } ,
) ;
}
pub const ArgMaxRes = struct {
values : Tensor ,
indices : Tensor ,
fn cmp ( left : ArgMaxRes , right : ArgMaxRes ) ArgMaxRes {
const left_val = left . values ;
const right_val = right . values ;
const left_idx = left . indices ;
const right_idx = right . indices ;
const left_gt_right = left_val . cmp ( . GT , right_val ) ;
const is_nan = left_val . cmp ( . NE , left_val ) ;
const left_gt_or_nan = left_gt_right . logical ( . OR , is_nan ) ;
// we are bubbling up NaN.
const max_val = left_gt_or_nan . select ( left_val , right_val ) ;
// If left_val == right_val: keep the smallest idx.
const is_same = left_val . cmp ( . EQ , right_val ) ;
const is_first = left_idx . cmp ( . LT , right_idx ) ;
const is_same_but_first = is_same . logical ( . AND , is_first ) ;
const keep_left_idx = left_gt_or_nan . logical ( . OR , is_same_but_first ) ;
const max_idx = keep_left_idx . select ( left_idx , right_idx ) ;
return . { . values = max_val , . indices = max_idx } ;
}
} ;
/// Returns two Tensors containing the maximum and the index of this maximum over a given axis.
///
/// Stable argmax:
/// * bubbles up NaN
/// * in case of equality, returns the smallest index matching the maximum
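///
/// For example (a sketch; the shapes are illustrative, mirroring the test below):
///
/// ```
/// // x: a [1][5] f32 tensor.
/// const res = x.argMax(1);
/// // res.values: the maximum along axis 1; res.indices: its first position, as .i32.
/// ```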
pub fn argMax ( x : Tensor , axis_ : anytype ) ArgMaxRes {
const a = x . axis ( axis_ ) ;
const dt : DataType = if ( x . dim ( a ) < = std . math . maxInt ( i32 ) ) . i32 else . i64 ;
return ops . reduce (
ArgMaxRes . cmp ,
. { . values = x , . indices = Tensor . arange ( . { . end = x . dim ( a ) } , dt ) . broadcast ( x . shape ( ) , & . { a } ) } ,
. { . values = Tensor . constant ( & . { } , x . dtype ( ) . minValue ( ) ) , . indices = Tensor . scalar ( 0 , dt ) } ,
& . { a } ,
) ;
}
test argMax {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
const allocator = std . testing . allocator ;
const ArgMaxTest = struct {
pub fn _fwd ( x : Tensor ) Tensor . ArgMaxRes {
return x . argMax ( 1 ) ;
}
} ;
const argmax = try zml . compileFn ( allocator , ArgMaxTest . _fwd , . { Shape . init ( . { 1 , 5 } , . f32 ) } , platform ) ;
defer argmax . deinit ( ) ;
// Test with tie
{
const x = try zml . Buffer . fromArray ( platform , [ 1 ] [ 5 ] f32 { . { 5.0 , 4.1 , 7.9 , 0 , 7.9 } } ) ;
const res = argmax . call ( . { x } ) ;
const max_ = try res . values . getValue ( f32 ) ;
const max_idx = try res . indices . getValue ( i32 ) ;
try testing . expectEqual ( max_ , 7.9 ) ;
// We should always return the first max found.
try testing . expectEqual ( max_idx , 2 ) ;
}
// Test with Nan
{
const x = try zml . Buffer . fromArray ( platform , [ 1 ] [ 5 ] f32 { . { 5.0 , std . math . nan ( f32 ) , 7.9 , 0 , 7.9 } } ) ;
const res = argmax . call ( . { x } ) ;
const max_ = try res . values . getValue ( f32 ) ;
const max_idx = try res . indices . getValue ( i32 ) ;
try testing . expect ( std . math . isNan ( max_ ) ) ;
try testing . expectEqual ( max_idx , 1 ) ;
}
}
pub const SortRes = ArgMaxRes ;
/// Returns two Tensors. The first contains the sorted values and the second one contains the sorted indices.
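///
/// For example (a sketch; the axis tag .b is illustrative):
///
/// ```
/// const sorted = x.sort(.b, .{ .descending = true });
/// // sorted.values: x reordered along .b; sorted.indices: the matching i32 permutation.
/// ```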
pub fn sort ( self : Tensor , axis_ : anytype , opts : struct { descending : bool = false } ) SortRes {
const a = self . axis ( axis_ ) ;
const indices = Tensor . arange ( . { . end = self . dim ( a ) } , . i32 ) . broadcast ( self . _shape , & . { a } ) ;
const res = ops . sort (
struct {
fn call ( direction : dialect . stablehlo . ComparisonDirection . Direction , lhs : Tensor , rhs : Tensor , _ : Tensor , _ : Tensor ) Tensor {
return lhs . cmp ( direction , rhs ) ;
}
} . call ,
if ( opts . descending ) . GT else . LT ,
. { self , indices } ,
self . axis ( axis_ ) ,
true ,
) ;
return . { . values = res [ 0 ] , . indices = res [ 1 ] } ;
}
pub const ArgSortOpts = struct { descending : bool = false } ;
/// Returns a Tensor containing the indices corresponding to the sorted values over the given axis.
pub fn argsort ( self : Tensor , axis_ : anytype , opts : ArgSortOpts ) Tensor {
return self . sort ( axis_ , . { . descending = opts . descending } ) . indices ;
}
test argsort {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
const Local = struct {
pub fn _argsort ( x : Tensor , axis_ : u3 , opts : ArgSortOpts ) Tensor {
return x . argsort ( axis_ , opts ) ;
}
} ;
var arena_state = std . heap . ArenaAllocator . init ( std . testing . allocator ) ;
defer arena_state . deinit ( ) ;
const allocator = arena_state . allocator ( ) ;
// 2D Tensor - dim = 1, ascending
{
const x = try zml . Buffer . fromSlice ( platform , . { 2 , 5 } , & [ _ ] f32 { - 0.9264 , 0.7156 , 1.0202 , 0.3992 , 1.2349 , 1.0003 , - 0.1932 , 1.3935 , 0.7316 , 0.0851 } ) ;
const res = try zml . testing . compileAndCall ( platform , Local . _argsort , . { x , 1 , . { } } ) ;
const res_cpu = try res . toHostAlloc ( allocator ) ;
try testing . expectEqualSlices ( i32 , & . { 0 , 3 , 1 , 2 , 4 , 1 , 4 , 3 , 0 , 2 } , res_cpu . items ( i32 ) ) ;
}
// 3D Tensor, dim = 1, descending
{
const x = try zml . Buffer . fromSlice ( platform , . { 1 , 5 , 10 } , & [ _ ] f16 {
- 0.2505 , 1.2520 , - 0.7041 , 0.1066 , 1.2773 , - 1.7246 , 0.8389 , 1.1094 , 0.0601 , 1.0684 ,
0.9619 , 1.3916 , 1.2246 , - 0.1406 , 0.3674 , - 1.2480 , - 1.7051 , - 0.0934 , 0.3435 , 0.4373 ,
1.3809 , 0.5444 , - 0.6079 , 1.2031 , - 0.6880 , 1.2979 , - 0.1869 , 0.2991 , 0.0156 , 0.1847 ,
0.6626 , - 0.3040 , - 0.8726 , - 1.4805 , - 1.6943 , 1.1055 , - 2.0078 , - 0.5288 , 0.8813 , 0.8008 ,
2.0527 , 1.1230 , 0.5430 , 0.2494 , - 0.9434 , 0.7876 , 0.1818 , 0.9258 , - 2.4902 , 1.5918 ,
} ) ;
const res_dev = try zml . testing . compileAndCall ( platform , Local . _argsort , . { x , 1 , . { . descending = true } } ) ;
const res = try res_dev . toHostAlloc ( allocator ) ;
try testing . expectEqualSlices ( i32 , & . {
4 , 1 , 1 , 2 , 0 , 2 , 0 , 0 , 3 , 4 ,
2 , 0 , 4 , 4 , 1 , 3 , 4 , 4 , 1 , 0 ,
1 , 4 , 2 , 0 , 2 , 4 , 2 , 2 , 0 , 3 ,
3 , 2 , 0 , 1 , 4 , 1 , 1 , 1 , 2 , 1 ,
0 , 3 , 3 , 3 , 3 , 0 , 3 , 3 , 4 , 2 ,
} , res . items ( i32 ) ) ;
}
// 4D Tensor, dim = 3, ascending
{
const x = try zml . Buffer . fromSlice ( platform , . { 4 , 2 , 1 , 4 } , & [ _ ] i32 {
89 , 31 , 22 , 42 ,
64 , 39 , 0 , 30 ,
64 , 71 , 46 , 31 ,
89 , 82 , 78 , 86 ,
55 , 32 , 43 , 19 ,
93 , 24 , 45 , 72 ,
64 , 86 , 62 , 88 ,
57 , 21 , 19 , 12 ,
} ) ;
const res_dev = try zml . testing . compileAndCall ( platform , Local . _argsort , . { x , 3 , . { } } ) ;
const res = try res_dev . toHostAlloc ( allocator ) ;
try testing . expectEqualSlices ( i32 , & . {
2 , 1 , 3 , 0 ,
2 , 3 , 1 , 0 ,
3 , 2 , 0 , 1 ,
2 , 1 , 3 , 0 ,
3 , 1 , 2 , 0 ,
1 , 2 , 3 , 0 ,
2 , 0 , 1 , 3 ,
3 , 2 , 1 , 0 ,
} , res . items ( i32 ) ) ;
}
}
/// Returns the values and indices of the Top-K elements over the given axis.
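///
/// For example (a sketch; the tag .v and k=5 are illustrative):
///
/// ```
/// const top5 = logits.topK(5, .v, .{});
/// // top5.values: the 5 largest entries along .v; top5.indices: their positions.
/// ```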
pub fn topK ( self : Tensor , k : u32 , axis_ : anytype , opts : struct { descending : bool = true } ) SortRes {
const a = self . axis ( axis_ ) ;
const result = self . sort ( a , . { . descending = opts . descending } ) ;
return . {
. values = result . values . slice1d ( a , . { . end = k } ) ,
. indices = result . indices . slice1d ( a , . { . end = k } ) ,
} ;
}
pub const MaxPoolRes = ArgMaxRes ;
/// Computes the 1d maxPool operation on the input Tensor.
pub fn maxPool1d ( self : Tensor , opts : struct {
window_dimensions : i64 ,
window_strides : ? i64 ,
base_dilations : i64 = 1 ,
window_dilations : i64 = 1 ,
padding : [ 2 ] i64 = . { 0 , 0 } ,
} ) MaxPoolRes {
// TODO migrate to the following syntax.
// maxPool(.{.a = .{ .stride = 5, .dilation = 2, .padding = .{0, 1} },
// .b = .{ .stride = 8, .dilation = 2, .padding = .{0, 1} }),
// maxPool(.{
// .stride = .{ .a = 5, .b = 8 },
// .dilation = .{ .a = 2, .b = 2 },
// .padding = .{ .a = .{ 0, 2 }, .b = .{0, 2}
// })
// TODO: support maxPool on non last axis
const a = self . axis ( - 1 ) ;
const ones = [ _ ] i64 { 1 } * * Tensor . MAX_RANK ;
var window_dimensions = ones ;
window_dimensions [ a ] = opts . window_dimensions ;
var window_strides = window_dimensions ;
if ( opts . window_strides ) | stride | window_strides [ a ] = stride ;
var base_dilations = ones ;
base_dilations [ a ] = opts . base_dilations ;
var window_dilations = ones ;
window_dilations [ a ] = opts . window_dilations ;
var padding = [ _ ] [ 2 ] i64 { . { 0 , 0 } } * * Tensor . MAX_RANK ;
padding [ a ] = opts . padding ;
return ops . reduceWindow (
MaxPoolRes . cmp ,
. { . values = self , . indices = iota ( self . _shape , a ) } ,
. { . values = Tensor . constant ( . { } , self . dtype ( ) . minValue ( ) ) , . indices = Tensor . scalar ( 0 , . i32 ) } ,
. {
. window_dimensions = window_dimensions [ 0 . . self . rank ( ) ] ,
. window_strides = window_strides [ 0 . . self . rank ( ) ] ,
. base_dilations = base_dilations [ 0 . . self . rank ( ) ] ,
. window_dilations = window_dilations [ 0 . . self . rank ( ) ] ,
. padding = padding [ 0 . . self . rank ( ) ] ,
} ,
) ;
}
/// Computes the 2d maxPool operation on the input Tensor.
pub fn maxPool2d ( self : Tensor , opts : struct {
window_dimensions : [ 2 ] i64 ,
window_strides : ? [ 2 ] i64 = null ,
base_dilations : [ 2 ] i64 = . { 1 , 1 } ,
window_dilations : [ 2 ] i64 = . { 1 , 1 } ,
padding : [ 2 ] [ 2 ] i64 = . { . { 0 , 0 } , . { 0 , 0 } } ,
} ) MaxPoolRes {
// TODO: rewrite using modern ZML
stdx . debug . guard ( self . rank ( ) = = 3 or self . rank ( ) = = 4 , @src ( ) ) ;
// TODO: support maxPool on non last axis
// Note: the problem is initPoolArg assuming last axis
const a = self . axis ( - 1 ) ;
const window_dimensions = initPoolArg ( self . rank ( ) , & opts . window_dimensions ) ;
const window_strides = if ( opts . window_strides ) | ws | initPoolArg ( self . rank ( ) , & ws ) else window_dimensions ;
const base_dilation = initPoolArg ( self . rank ( ) , & opts . base_dilations ) ;
const window_dilations = initPoolArg ( self . rank ( ) , & opts . window_dilations ) ;
var padding = [ _ ] [ 2 ] i64 { . { 0 , 0 } } * * Tensor . MAX_RANK ;
padding [ a - 1 ] = opts . padding [ 0 ] ;
padding [ a ] = opts . padding [ 1 ] ;
return ops . reduceWindow (
MaxPoolRes . cmp ,
. { . values = self , . indices = iota ( self . _shape , a ) } ,
. { . values = Tensor . constant ( . { } , self . dtype ( ) . minValue ( ) ) , . indices = Tensor . scalar ( 0 , . i32 ) } ,
. {
. window_dimensions = window_dimensions [ 0 . . self . rank ( ) ] ,
. window_strides = window_strides [ 0 . . self . rank ( ) ] ,
. base_dilations = base_dilation [ 0 . . self . rank ( ) ] ,
. window_dilations = window_dilations [ 0 . . self . rank ( ) ] ,
. padding = padding [ 0 . . self . rank ( ) ] ,
} ,
) ;
}
/// Chunk a given tensor into exactly n parts of equal shape.
/// `self.dim(axis_)` must be divisible by n_chunks.
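///
/// For example (a sketch; the tag .d and the sizes are illustrative, mirroring the shape tests below):
///
/// ```
/// // x: .{ .d = 12 } tensor.
/// const q, const k, const v = x.chunkExact(.d, 3);
/// // each chunk: .{ .d = 4 }
/// ```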
pub fn chunkExact ( self : Tensor , axis_ : anytype , n_chunks : comptime_int ) [ n_chunks ] Tensor {
const a = self . axis ( axis_ ) ;
const d = self . dim ( a ) ;
const chunk_size = @divExact ( d , n_chunks ) ;
var chunks : [ n_chunks ] Tensor = undefined ;
for ( 0 . . n_chunks ) | i | {
const start : i64 = @as ( i64 , @intCast ( i ) ) * chunk_size ;
chunks [ i ] = self . slice1d ( a , . { . start = start , . end = start + chunk_size } ) ;
}
return chunks ;
}
test chunkExact {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
// Only test shapes
var comp = try zml . module . CompilationContext . init ( std . testing . allocator , " test " , platform ) ;
defer comp . deinit ( ) ;
comp . activate ( ) ;
defer comp . deactivate ( ) ;
inline for ( . {
. { . { . a = 12 } , . a , 3 , . { . a = 4 } } ,
. { . { . a = 12 , . b = 2 } , . a , 3 , . { . a = 4 , . b = 2 } } ,
. { . { 12 , 2 } , 0 , 3 , . { 4 , 2 } } ,
} ) | testcase | {
const x_shape , const ax , const n_chunks , const res = testcase ;
const x = Tensor . constant ( x_shape , . { . f16 = 0 } ) ;
const chunks = x . chunkExact ( ax , n_chunks ) ;
const res_shape = Shape . init ( res , . f16 ) ;
for ( chunks ) | chk | {
try zml . testing . expectEqualShapes ( res_shape , chk . shape ( ) ) ;
}
}
}
/// Chunk a given tensor into n parts of equal shape, plus one part with the remaining items.
/// When `self.dim(axis_)` is divisible by `n_chunks`, it returns exactly `n_chunks` parts.
pub fn chunkAllowTrailing (
self : Tensor ,
axis_ : i64 ,
n_chunks : comptime_int ,
) [ ] Tensor {
const a = self . axis ( axis_ ) ;
const d = self . dim ( a ) ;
const chunk_size : i64 = @divFloor ( d , n_chunks ) ;
// The trailing chunk holds whatever is left after n_chunks full chunks.
const tail_chunk_size : i64 = d - n_chunks * chunk_size ;
const allocator = self . getContext ( ) . allocator ( ) ;
var chunks = std . ArrayListUnmanaged ( Tensor ) . initCapacity ( allocator , n_chunks + 1 ) catch @panic ( " OOM " ) ;
for ( 0 . . n_chunks ) | i | {
const start : i64 = @as ( i64 , @intCast ( i ) ) * chunk_size ;
chunks . appendAssumeCapacity (
self . slice1d ( a , . { . start = start , . end = start + chunk_size } ) ,
) ;
}
if ( tail_chunk_size ! = 0 ) {
const start : i64 = n_chunks * chunk_size ;
chunks . appendAssumeCapacity ( self . slice1d ( a , . { . start = start } ) ) ;
}
return chunks . items ;
}
test chunkAllowTrailing {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
// Only test shapes
var comp = try zml . module . CompilationContext . init ( std . testing . allocator , " test " , platform ) ;
defer comp . deinit ( ) ;
comp . activate ( ) ;
defer comp . deactivate ( ) ;
inline for ( . {
. { . { . a = 10 } , . a , 3 , . { . a = 3 } , . { . a = 1 } } ,
. { . { . a = 10 , . b = 2 } , . a , 3 , . { . a = 3 , . b = 2 } , . { . a = 1 , . b = 2 } } ,
. { . { 10 , 2 } , 0 , 3 , . { 3 , 2 } , . { 1 , 2 } } ,
. { . { 12 , 2 } , 0 , 3 , . { 4 , 2 } , . { } } ,
} ) | testcase | {
const x_shape , const ax , const n_chunks , const res , const trailing = testcase ;
const x = Tensor . constant ( x_shape , . { . f16 = 0 } ) ;
const chunks = x . chunkAllowTrailing ( x . axis ( ax ) , n_chunks ) ;
const res_shape = Shape . init ( res , . f16 ) ;
for ( chunks [ 0 . . n_chunks ] ) | chk | {
try zml . testing . expectEqualShapes ( res_shape , chk . shape ( ) ) ;
}
const trailing_shape = Shape . init ( trailing , . f16 ) ;
if ( trailing_shape . rank ( ) > 0 ) {
try std . testing . expectEqual ( n_chunks + 1 , chunks . len ) ;
try zml . testing . expectEqualShapes ( trailing_shape , chunks [ n_chunks ] . shape ( ) ) ;
} else {
try std . testing . expectEqual ( n_chunks , chunks . len ) ;
}
}
}
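/// Splits the tensor along the given axis into parts with the given sizes.
/// The values of `split_sizes` must sum to `self.dim(axis_)`.
///
/// For example (a sketch; the tag .d and the sizes are illustrative):
///
/// ```
/// // x: .{ .d = 10 } tensor.
/// const parts = x.split(.d, &.{ 3, 7 });
/// // parts[0]: .{ .d = 3 }, parts[1]: .{ .d = 7 }
/// ```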
pub fn split ( self : Tensor , axis_ : anytype , split_sizes : [ ] const i64 ) [ ] Tensor {
stdx . debug . assert ( split_sizes . len > 0 , " split expects at least one 'split_sizes', got 0 " , . { } ) ;
const a = self . axis ( axis_ ) ;
const d = self . dim ( a ) ;
var split_sum : i64 = 0 ;
for ( split_sizes ) | n | split_sum + = n ;
stdx . debug . assert ( split_sum = = d , " split expects sum of 'split_sizes' values and axis dimension to be equal, got {} and {} " , . { split_sum , d } ) ;
const allocator = self . getContext ( ) . allocator ( ) ;
const res = allocator . alloc ( Tensor , split_sizes . len ) catch @panic ( " OOM " ) ;
errdefer allocator . free ( res ) ;
var start : i64 = 0 ;
for ( split_sizes , 0 . . ) | n , i | {
res [ i ] = self . slice1d ( a , . { . start = start , . end = start + n } ) ;
start + = n ;
}
return res ;
}
/// Slices the input Tensor along a specific axis, with a start offset known at runtime.
/// Note: this doesn't support tagging; if you have tags,
/// you should use `dynamicSlice` directly.
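///
/// For example (a sketch; the values are illustrative, mirroring the test below):
///
/// ```
/// // x: a rank-1 tensor of 10 elements; start: an i32 scalar tensor.
/// const win = x.dynamicSlice1d(0, .{ .start = start, .len = 2 });
/// // win: 2 consecutive elements of x, beginning at the runtime offset.
/// ```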
pub fn dynamicSlice1d ( self : Tensor , axis_ : i8 , slice_ : DynSlice ) Tensor {
stdx . debug . assert ( slice_ . start . rank ( ) = = 0 , " dynamicSlice1d expects 'slice_.start' tensor rank to be a scalar, got {} " , . { slice_ . start } ) ;
const a = self . axis ( axis_ ) ;
const new_shape = self . _shape . set ( a , slice_ . len ) ;
const loc = self . getContext ( ) . location ( @src ( ) , " dynSlice({}, len={}) " , . { axis_ , slice_ . len } ) ;
var start_indices = [ _ ] mlir . Value { constant ( . { } , slice_ . start . dtype ( ) . zero ( ) ) . value ( ) } * * MAX_RANK ;
start_indices [ a ] = slice_ . start . value ( ) ;
const op = dialect . stablehlo . dynamicSlice (
self . getContext ( ) . mlirCtx ( ) ,
self . value ( ) ,
new_shape . dims ( ) ,
start_indices [ 0 . . self . rank ( ) ] ,
loc ,
) ;
return _result ( new_shape , op . result ( 0 ) ) ;
}
pub const DynSlice = struct { start : Tensor , len : i64 } ;
/// Slices a Tensor across many axes, with runtime known offsets.
///
/// Due to the nature of stablehlo, the lengths of the slices need to be known when compiling the IR.
/// When using the tagged API it is allowed to not specify some axes.
/// But with the non-tagged API all slices need to be specified.
/// Examples:
/// ```
/// Tensor(.{.a=20,.b=30,.c=40 }).dynamicSlice(.{ .a = .{ .start = a_off, .len = 11});
/// Tensor(.{.a=20,.b=30,.c=40 }).dynamicSlice(.{
/// .a = .{ .start = a_off, .len = 11 },
/// .b = .{ .start = b_off, .len = 12 },
/// });
/// Tensor(.{ 20,30,40}).dynamicSlice(.{.{ .start = scalar(0, .i32), .len = 20 }, .{ .start = b_off, .len = 12 }, .{ .start = scalar(0, .i32), .len = 40 }});
/// ```
pub fn dynamicSlice ( self : Tensor , slices_ : anytype ) Tensor {
// TODO: the untagged api is a bit verbose. Should I allow: `Tensor(.{ 20,30,40}).dynamicSlice(.{.{}, .{ .start = b_off, .len = 12 }, .{}});` ??
//
const slices , const slices_tags = Shape . parseStruct ( DynSlice , slices_ ) ;
// TODO use slices and slices_tags for the format.
// Currently this prints: "dynSlice(struct{q: struct{start: tensor.Tensor, comptime len: comptime_int = 1}}{ .q = struct{start: tensor.Tensor, comptime len: comptime_int = 1}{ .start = Tensor({1,10}, dtype=.i64), .len = 1 } })"
// which is kinda ugly.
const loc = self . getContext ( ) . location ( @src ( ) , " dynSlice({any}) " , . { slices_ } ) ;
const idx_dtype = if ( slices . len > 0 ) slices . get ( 0 ) . start . dtype ( ) else . i32 ;
const zero = Tensor . scalar ( 0 , idx_dtype ) . value ( ) ;
var offset_values = [ _ ] mlir . Value { zero } * * MAX_RANK ;
var res_shape = self . _shape ;
for ( slices . constSlice ( ) , 0 . . ) | slice_ , i | {
const offset = slice_ . start ;
const len = slice_ . len ;
if ( slices_tags . len = = 0 ) {
stdx . debug . assert ( self . rank ( ) = = slices . len , " dynamicSlice expects tensor rank and 'slices_' length to be equal, got {} and {} " , . { self . rank ( ) , slices . len } ) ;
offset_values [ i ] = offset . value ( ) ;
res_shape . _dims . set ( i , len ) ;
stdx . debug . assert ( len < = self . dim ( i ) , " dynamicSlice expects slices 'len' to be less than or equal to their corresponding dimension in input tensor, got {} and {} for index {} " , . { len , self . dim ( i ) , i } ) ;
} else {
const t = slices_tags . get ( i ) ;
const a = res_shape . hasTag ( t ) orelse stdx . debug . panic ( " dynamicSlice expects input tensor to have tags used in 'slices_' but {s} is missing (input shape is {}) " , . { t , self . _shape } ) ;
stdx . debug . assert ( len < = self . dim ( a ) , " dynamicSlice expects slices 'len' to be less than or equal to their corresponding dimension in input tensor, got {} and {} for axis {s} " , . { len , self . dim ( a ) , t } ) ;
offset_values [ a ] = offset . value ( ) ;
res_shape . _dims . set ( a , len ) ;
}
}
const op = dialect . stablehlo . dynamicSlice ( self . getContext ( ) . mlirCtx ( ) , self . value ( ) , res_shape . dims ( ) , offset_values [ 0 . . self . rank ( ) ] , loc ) ;
return _result ( res_shape , op . result ( 0 ) ) ;
}
test dynamicSlice {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
const T = f32 ;
{
const x = try zml . Buffer . fromArray ( platform , [ 10 ] T { 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 } ) ;
const z = try zml . Buffer . scalar ( platform , 4 , . i32 ) ;
const res = try zml . testing . compileAndCall ( platform , Tensor . dynamicSlice1d , . { x , 0 , . { . len = 2 , . start = z } } ) ;
try testing . expectEqual ( [ 2 ] T { 4 , 5 } , try res . getValue ( [ 2 ] T ) ) ;
}
{
// Strided
const x = try zml . Buffer . fromArray ( platform , [ 2 ] [ 5 ] T { . { 0 , 1 , 2 , 3 , 4 } , . { 5 , 6 , 7 , 8 , 9 } } ) ;
const z = try zml . Buffer . scalar ( platform , 3 , . i32 ) ;
const res = try zml . testing . compileAndCall ( platform , Tensor . dynamicSlice1d , . { x , 1 , . { . len = 2 , . start = z } } ) ;
try testing . expectEqual ( [ 4 ] T { 3 , 4 , 8 , 9 } , try res . getValue ( [ 4 ] T ) ) ;
}
}
/// Updates a slice of the input Tensor along a specific axis using the given 'update' Tensor, with a start offset known at runtime.
/// Note: this is the untagged api; if you have tags, you should use `dynamicUpdateSlice` directly.
pub fn dynamicUpdateSlice1d ( self : Tensor , update : Tensor , axis_ : i64 , offset : Tensor ) Tensor {
const placeholder = Tensor . scalar ( 0 , . i32 ) ;
var start_indices = [ _ ] Tensor { placeholder } * * MAX_RANK ;
start_indices [ self . axis ( axis_ ) ] = offset ;
return self . dynamicUpdateSlice ( start_indices [ 0 . . self . rank ( ) ] , update ) ;
}
/// Updates a part of the input Tensor using the given 'update' Tensor, with runtime known offsets.
///
/// The offsets are specified similarly to the dynamicSlice api.
/// It's semantically equivalent to:
/// self.dynamicSlice(offsets_) := update
/// Examples:
/// ```
/// Tensor(.{ .a = 2, .b = 5 }).dynamicUpdateSlice(.{ .a = scalar(1, .i32) }, Tensor(.{ .b = 5 }));
/// ```
pub fn dynamicUpdateSlice ( self : Tensor , offset_ : anytype , update_ : Tensor ) Tensor {
// TODO: add updateSlice for when the offset isn't dynamic
stdx . debug . assert ( self . dtype ( ) = = update_ . dtype ( ) , " dynamicUpdateSlice expects input and 'update_' tensors to be of the same type, got {} and {} " , . { self . dtype ( ) , update_ . dtype ( ) } ) ;
const offset , const offset_tags = Shape . parseStruct ( Tensor , offset_ ) ;
// log.debug("offset: {any}, offset_tags: {any}", .{ offset, offset_tags });
for ( offset . constSlice ( ) , 0 . . ) | start_idx , i | {
stdx . debug . assert ( start_idx . rank ( ) = = 0 , " dynamicUpdateSlice expects 'offset_' tensor ranks to be equal to 0, got {} at index {} " , . { start_idx . rank ( ) , i } ) ;
}
const tagged_api = update_ . _shape . isFullyTagged ( ) and self . _shape . isFullyTagged ( ) and offset_tags . len > 0 ;
// When using tags, we can safely insert axis with a 1-dim.
// the offset into the inserted axis will need to be specified through indices.
var update = update_ ;
if ( tagged_api ) {
// Check that all update tags are known.
for ( update . _shape . _tags . constSlice ( ) ) | t | {
stdx . debug . assert ( self . _shape . hasTag ( t ) ! = null , " dynamicUpdateSlice expects 'update_' tensor tags to be a subset of input tensor tags but {s} is missing (input shape is {}) " , . { t , self . _shape } ) ;
}
var update_shape = self . _shape ;
var prev_ax : i8 = - 1 ;
for ( self . _shape . tags ( ) , 0 . . ) | t , self_ax | {
if ( update . _shape . hasTag ( t ) ) | up_ax | {
stdx . debug . assert ( up_ax = = prev_ax + 1 , " dynamicUpdateSlice expects 'update_' and input tensor axis to have the same order, got {} and {}. (hint: you need to explicitly transpose 'update_') " , . { update_ , self } ) ;
update_shape . _dims . set ( self_ax , update . dim ( up_ax ) ) ;
prev_ax = up_ax ;
} else {
update_shape . _dims . set ( self_ax , 1 ) ;
}
}
update = update . reshape ( update_shape ) ;
}
stdx . debug . assert ( self . rank ( ) = = update . rank ( ) , " dynamicUpdateSlice expects input and computed update tensors to have the same rank, got {} and {} " , . { self , update } ) ;
for ( self . dims ( ) , update . dims ( ) , 0 . . ) | self_d , up_d , ax | {
const t = self . _shape . debugTag ( ax ) ;
stdx . debug . assert ( up_d < = self_d , " dynamicUpdateSlice expects 'update_' dimensions to be less than or equal to their corresponding dimension in input tensor, got {} and {} for axis .{s} " , . { up_d , self_d , t } ) ;
if ( tagged_api and up_d < self_d ) {
const axis_has_offset = std . mem . indexOfScalar ( Shape . Tag , offset_tags . constSlice ( ) , self . _shape . _tags . get ( ax ) ) ! = null ;
stdx . debug . assert ( axis_has_offset , " dynamicUpdateSlice expects 'update_' dimensions to be equal to their corresponding dimension in input tensor, got {} and {} for axis .{s} (hint: you need to provide an offset) " , . { up_d , self_d , t } ) ;
}
}
const idx_dtype = if ( offset . len > 0 ) offset . get ( 0 ) . dtype ( ) else . i32 ;
const zero = Tensor . scalar ( 0 , idx_dtype ) . value ( ) ;
var offset_values : [ MAX_RANK ] mlir . Value = undefined ;
if ( offset_tags . len = = 0 ) {
// Without offset tags we need the same number of offset than rank.
stdx . debug . assert ( self . rank ( ) = = offset . len , " dynamicUpdateSlice expects input tensor rank and 'offset_' length to be equal, got {} and {} " , . { self . rank ( ) , offset . len } ) ;
for ( offset . constSlice ( ) , 0 . . ) | idx , i | {
offset_values [ i ] = idx . value ( ) ;
}
} else {
// If an axis isn't specified, update the full slice.
// This is only allowed when using tagged sliced.
offset_values = . { zero } * * MAX_RANK ;
for ( offset . constSlice ( ) , offset_tags . constSlice ( ) ) | start , t | {
const a = self . _shape . hasTag ( t ) orelse stdx . debug . panic ( " dynamicUpdateSlice expects input tensor to have tags used in 'offset_' but {s} is missing (input shape is {}) " , . { t , self . _shape } ) ;
offset_values [ a ] = start . value ( ) ;
}
}
const loc = self . getContext ( ) . mlirCtx ( ) . location ( @src ( ) ) ;
const op = dialect . stablehlo . dynamic_update_slice (
self . getContext ( ) . mlirCtx ( ) ,
self . value ( ) ,
update . value ( ) ,
offset_values [ 0 . . self . rank ( ) ] ,
loc ,
) ;
return _result ( self . _shape , op . result ( 0 ) ) ;
}
test dynamicUpdateSlice {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
{
const x = try zml . Buffer . fromArray ( platform , [ 10 ] f32 { 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 } ) ;
const y = try zml . Buffer . fromArray ( platform , [ 2 ] f32 { - 1 , - 1 } ) ;
const idx = try zml . Buffer . scalar ( platform , 4 , . i32 ) ;
const res = try zml . testing . compileAndCall (
platform ,
struct {
pub fn _fwd ( x_ : Tensor , idx_ : struct { a : Tensor } , y_ : Tensor ) Tensor {
return x_ . dynamicUpdateSlice ( idx_ , y_ ) ;
}
} . _fwd ,
. { x . withTags ( . { . a } ) , . { . a = idx } , y . withTags ( . { . a } ) } ,
) ;
try testing . expectEqual ( [ 10 ] f32 { 0 , 1 , 2 , 3 , - 1 , - 1 , 6 , 7 , 8 , 9 } , try res . getValue ( [ 10 ] f32 ) ) ;
}
{
// Updates 2D, tagged api
const x = try zml . Buffer . fromArray ( platform , [ 2 ] [ 5 ] f32 { . { 0 , 1 , 2 , 3 , 4 } , . { 5 , 6 , 7 , 8 , 9 } } ) ;
const y = try zml . Buffer . fromArray ( platform , [ 2 ] f32 { - 1 , - 1 } ) ;
const idx = try zml . Buffer . scalar ( platform , 3 , . i32 ) ;
const res = try zml . testing . compileAndCall (
platform ,
struct {
pub fn _fwd ( x_ : Tensor , idx_ : Tensor , y_ : Tensor ) Tensor {
return x_ . dynamicUpdateSlice ( . { . b = idx_ } , y_ ) ;
}
} . _fwd ,
. { x . withTags ( . { . a , . b } ) , idx , y . withTags ( . { . a } ) } ,
) ;
try testing . expectEqualDeep (
[ 2 ] [ 5 ] f32 { . { 0 , 1 , 2 , - 1 , 4 } , . { 5 , 6 , 7 , - 1 , 9 } } ,
try res . getValue ( [ 2 ] [ 5 ] f32 ) ,
) ;
}
{
// Updates 2D slice, un-tagged api. Note that `y` needs to have a 1 dimension axis.
const x = try zml . Buffer . fromArray ( platform , [ 2 ] [ 5 ] f32 { . { 0 , 1 , 2 , 3 , 4 } , . { 5 , 6 , 7 , 8 , 9 } } ) ;
const y = try zml . Buffer . fromArray ( platform , [ 2 ] [ 1 ] f32 { . { - 1 } , . { - 1 } } ) ;
const idx = try zml . Buffer . scalar ( platform , 3 , . i32 ) ;
const res = try zml . testing . compileAndCall (
platform ,
struct {
pub fn _fwd ( x_ : Tensor , idx_ : Tensor , y_ : Tensor ) Tensor {
return x_ . dynamicUpdateSlice ( . { zml . Tensor . scalar ( 0 , . i32 ) , idx_ } , y_ ) ;
}
} . _fwd ,
. { x , idx , y } ,
) ;
try testing . expectEqualDeep (
[ 2 ] [ 5 ] f32 { . { 0 , 1 , 2 , - 1 , 4 } , . { 5 , 6 , 7 , - 1 , 9 } } ,
res . getValue ( [ 2 ] [ 5 ] f32 ) ,
) ;
}
{
// Updates 2D, partial update
const x = try zml . Buffer . fromArray ( platform , [ 2 ] [ 5 ] f32 { . { 0 , 1 , 2 , 3 , 4 } , . { 5 , 6 , 7 , 8 , 9 } } ) ;
const y = try zml . Buffer . fromArray ( platform , [ 1 ] f32 { - 1 } ) ;
const idx_a = try zml . Buffer . scalar ( platform , 1 , . i32 ) ;
const idx_b = try zml . Buffer . scalar ( platform , 3 , . i32 ) ;
const res = try zml . testing . compileAndCall (
platform ,
struct {
pub fn _fwd ( x_ : Tensor , idx_ : struct { a : Tensor , b : Tensor } , y_ : Tensor ) Tensor {
return x_ . dynamicUpdateSlice ( idx_ , y_ ) ;
}
} . _fwd ,
. { x . withTags ( . { . a , . b } ) , . { . a = idx_a , . b = idx_b } , y . withTags ( . { . a } ) } ,
) ;
try testing . expectEqualDeep (
[ 2 ] [ 5 ] f32 { . { 0 , 1 , 2 , 3 , 4 } , . { 5 , 6 , 7 , - 1 , 9 } } ,
res . getValue ( [ 2 ] [ 5 ] f32 ) ,
) ;
}
{
// Updates 2D, partial update, un-tagged api.
const x = try zml . Buffer . fromArray ( platform , [ 2 ] [ 5 ] f32 { . { 0 , 1 , 2 , 3 , 4 } , . { 5 , 6 , 7 , 8 , 9 } } ) ;
const y = try zml . Buffer . fromArray ( platform , [ 1 ] [ 1 ] f32 { . { - 1 } } ) ;
const idx_a = try zml . Buffer . scalar ( platform , 1 , . i32 ) ;
const idx_b = try zml . Buffer . scalar ( platform , 3 , . i32 ) ;
const A = struct {
pub fn _fwd ( x_ : Tensor , idx_ : [ 2 ] Tensor , y_ : Tensor ) Tensor {
return x_ . dynamicUpdateSlice ( & idx_ , y_ ) ;
}
} ;
const res = try zml . testing . compileAndCall ( platform , A . _fwd , . { x , . { idx_a , idx_b } , y } ) ;
try testing . expectEqualDeep (
[ 2 ] [ 5 ] f32 { . { 0 , 1 , 2 , 3 , 4 } , . { 5 , 6 , 7 , - 1 , 9 } } ,
res . getValue ( [ 2 ] [ 5 ] f32 ) ,
) ;
}
}
/// Returns a Tensor containing the element-wise result of the given 'cmp' comparison between the two input Tensors.
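///
/// For example (a sketch):
///
/// ```
/// const mask = a.cmp(.GT, b);
/// // mask: a .bool tensor with the common shape of a and b.
/// ```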
pub fn cmp ( self : Tensor , direction : dialect . stablehlo . ComparisonDirection . Direction , other : Tensor ) Tensor {
stdx . debug . assert ( self . dtype ( ) = = other . dtype ( ) , " cmp expects input tensors to be of the same type, got {} and {} " , . { self . dtype ( ) , other . dtype ( ) } ) ;
if ( self . rank ( ) = = 0 and other . rank ( ) ! = 0 ) return self . broadcast ( other . _shape , & . { } ) . cmp ( direction , other ) ;
if ( self . rank ( ) ! = 0 and other . rank ( ) = = 0 ) return self . cmp ( direction , other . broadcast ( self . _shape , & . { } ) ) ;
stdx . debug . assert ( self . _shape . eql ( other . _shape ) , " cmp expects input tensor shapes to match, got {} and {} " , . { self . _shape , other . _shape } ) ;
const loc = self . getContext ( ) . location ( @src ( ) , " cmp(.{s}) " , . { @tagName ( direction ) } ) ;
const op = dialect . stablehlo . compare (
self . getContext ( ) . mlirCtx ( ) ,
self . value ( ) ,
other . value ( ) ,
dialect . stablehlo . ComparisonDirection . init ( self . getContext ( ) . mlirCtx ( ) , direction ) ,
getComparisonType ( self . getContext ( ) . mlirCtx ( ) , self . dtype ( ) ) ,
loc ,
) ;
return _result ( self . _shape . withDtype ( . bool ) , op . result ( 0 ) ) ;
}
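// A minimal usage sketch (illustrative only; `scores` and `limit` are
// hypothetical f32 tensors of identical shape):
//
//     const mask = scores.cmp(.GT, limit);
//     // `mask` keeps the dims of `scores`, but its dtype becomes .bool,
//     // making it a suitable input for `select`.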
/// For each vector in the input tensor,
/// creates a diagonal matrix where the diagonal values are set to the vector values.
pub fn toDiagonal ( self : Tensor , axis_ : anytype , new_tags : [ 2 ] EnumLiteral ) Tensor {
stdx . debug . assert ( self . rank ( ) < MAX_RANK , " toDiagonal expects an input of rank at most {}, got {} " , . { MAX_RANK - 1 , self } ) ;
const a = self . axis ( axis_ ) ;
const d = self . dim ( a ) ;
var res_shape = self . _shape ;
res_shape . _dims . replaceRange ( a , 1 , & . { d , d } ) catch unreachable ;
res_shape . _tags . replaceRange ( a , 1 , & . { @tagName ( new_tags [ 0 ] ) , @tagName ( new_tags [ 1 ] ) } ) catch unreachable ;
const values = self . insertAxes ( a + 1 , . { new_tags [ 1 ] } ) . broad ( res_shape ) ;
const zeros = Tensor . constant ( res_shape , self . dtype ( ) . zero ( ) ) ;
const x = Tensor . iota ( res_shape , a ) ;
const y = Tensor . iota ( res_shape , a + 1 ) ;
var res = x . cmp ( . EQ , y ) . select ( values , zeros ) ;
res . _shape = res_shape ;
return res ;
}
test toDiagonal {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
const Local = struct {
pub fn _toDiag ( input : Tensor ) Tensor {
const res = input . toDiagonal ( - 1 , . { . x , . y } ) ;
std . debug . assert ( res . dim ( . x ) = = input . dim ( - 1 ) ) ;
std . debug . assert ( res . dim ( . y ) = = input . dim ( - 1 ) ) ;
return res ;
}
} ;
const x = try zml . Buffer . fromArray ( platform , [ 2 ] [ 2 ] u8 { . { 1 , 2 } , . { 3 , 4 } } ) ;
{
const res = try zml . testing . compileAndCall ( platform , Local . _toDiag , . { x } ) ;
try testing . expectEqual (
[ 2 ] [ 2 ] [ 2 ] u8 { . {
. { 1 , 0 } ,
. { 0 , 2 } ,
} , . {
. { 3 , 0 } ,
. { 0 , 4 } ,
} } ,
try res . getValue ( [ 2 ] [ 2 ] [ 2 ] u8 ) ,
) ;
}
}
/// For each matrix specified by the two given axes, returns its lower triangular part.
/// The other elements are set to 0.
/// Usage: `x.triangular(.{ .w, .h }, 0);` for a tensor `x` of shape `.{ .b = 32, .w = 20, .h = 20 }`.
///
/// * if `num_diagonals` is set to 0, the main diagonal is kept as is.
/// * if set to -1, the main diagonal is also set to 0.
/// * if set to n, the n super-diagonals above the main diagonal are kept as well.
///
/// To get the upper triangular part, swap the order of the axes:
/// `x.triangular(.{ .h, .w }, 0);`
pub fn triangular ( self : Tensor , axes_ : anytype , num_diagonals : i32 ) Tensor {
stdx . debug . assertComptime ( stdx . meta . isTuple ( @TypeOf ( axes_ ) ) and axes_ . len = = 2 , " triangular expects exactly two axes to work on. " , . { } ) ;
const _axes = self . axes ( axes_ ) ;
const x = Tensor . iota ( self . shape ( ) , _axes . get ( 0 ) ) ;
const y = Tensor . iota ( self . shape ( ) , _axes . get ( 1 ) ) ;
const zeros = Tensor . constant ( self . shape ( ) , self . dtype ( ) . zero ( ) ) ;
return x . addConstant ( num_diagonals ) . cmp ( . GE , y ) . select ( self , zeros ) ;
}
test triangular {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
const Local = struct {
pub fn _tri ( input : Tensor , num_diagonals : i32 ) Tensor {
return input . triangular ( . { - 2 , - 1 } , num_diagonals ) ;
}
} ;
const x = try zml . Buffer . fromArray ( platform , [ 3 ] [ 3 ] u8 {
. { 1 , 1 , 1 } ,
. { 1 , 1 , 1 } ,
. { 1 , 1 , 1 } ,
} ) ;
{
const res = try zml . testing . compileAndCall ( platform , Local . _tri , . { x , 0 } ) ;
try testing . expectEqual (
[ 3 ] [ 3 ] u8 {
. { 1 , 0 , 0 } ,
. { 1 , 1 , 0 } ,
. { 1 , 1 , 1 } ,
} ,
try res . getValue ( [ 3 ] [ 3 ] u8 ) ,
) ;
}
{
const res = try zml . testing . compileAndCall ( platform , Local . _tri , . { x , 1 } ) ;
try testing . expectEqual (
[ 3 ] [ 3 ] u8 {
. { 1 , 1 , 0 } ,
. { 1 , 1 , 1 } ,
. { 1 , 1 , 1 } ,
} ,
try res . getValue ( [ 3 ] [ 3 ] u8 ) ,
) ;
}
{
const res = try zml . testing . compileAndCall ( platform , Local . _tri , . { x , - 1 } ) ;
try testing . expectEqual (
[ 3 ] [ 3 ] u8 {
. { 0 , 0 , 0 } ,
. { 1 , 0 , 0 } ,
. { 1 , 1 , 0 } ,
} ,
try res . getValue ( [ 3 ] [ 3 ] u8 ) ,
) ;
}
}
/// For each element at index `i`, if `bool_tensor[i] == true`, then `output[i] = on_true[i]`;
/// otherwise, `output[i] = on_false[i]`.
pub fn select ( bool_tensor : Tensor , on_true : Tensor , on_false : Tensor ) Tensor {
stdx . debug . assert ( bool_tensor . dtype ( ) = = . bool , " select expects input tensor type to be a boolean, got {} " , . { bool_tensor . dtype ( ) } ) ;
stdx . debug . assert ( on_true . dtype ( ) = = on_false . dtype ( ) , " select expects 'on_true' and 'on_false' tensor types to be equal, got {} and {} " , . { on_true . dtype ( ) , on_false . dtype ( ) } ) ;
if ( bool_tensor . rank ( ) ! = 0 and on_true . rank ( ) = = 0 ) {
return bool_tensor . select ( on_true . broad ( bool_tensor . shape ( ) ) , on_false ) ;
}
if ( bool_tensor . rank ( ) ! = 0 and on_false . rank ( ) = = 0 ) {
return bool_tensor . select ( on_true , on_false . broad ( bool_tensor . shape ( ) ) ) ;
}
stdx . debug . assert ( bool_tensor . _shape . eqlDims ( on_true . _shape ) , " select expects input tensor and 'on_true' tensor dimensions to match, got {} and {} " , . { bool_tensor . _shape , on_true . _shape } ) ;
stdx . debug . assert ( bool_tensor . _shape . eqlDims ( on_false . _shape ) , " select expects input tensor and 'on_false' tensor dimensions to match, got {} and {} " , . { bool_tensor . _shape , on_false . _shape } ) ;
const loc = bool_tensor . getContext ( ) . mlirCtx ( ) . location ( @src ( ) ) ;
const op = dialect . stablehlo . select (
bool_tensor . getContext ( ) . mlirCtx ( ) ,
bool_tensor . value ( ) ,
on_true . value ( ) ,
on_false . value ( ) ,
loc ,
) ;
return _result ( on_true . _shape , op . result ( 0 ) ) ;
}
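// A minimal sketch of a ReLU built from `cmp` and `select` (illustrative only;
// `x` is a hypothetical float tensor, and the rank-0 scalar is broadcast by
// `cmp` and `select` as documented above):
//
//     const zero = Tensor.scalar(0, x.dtype());
//     const relu = x.cmp(.GE, zero).select(x, zero);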
/// Returns a Tensor containing the element-wise logical NOT of the input Tensor.
pub fn not ( self : Tensor ) Tensor {
const loc = self . getContext ( ) . mlirCtx ( ) . location ( @src ( ) ) ;
const op = dialect . stablehlo . not ( self . getContext ( ) . mlirCtx ( ) , self . value ( ) , loc ) ;
return _result ( self . _shape , op . result ( 0 ) ) ;
}
/// Returns a boolean Tensor indicating whether any value along the given axis is non-zero.
pub fn any ( self : Tensor , axis_ : anytype ) Tensor {
const pred = self . cmp ( . NE , Tensor . constant ( self . dims ( ) , self . dtype ( ) . zero ( ) ) ) ;
return ops . reduce (
struct {
pub fn acc ( x : Tensor , res : Tensor ) Tensor {
return res . logical ( . OR , x ) ;
}
} . acc ,
pred ,
Tensor . scalar ( false , . bool ) ,
& . { self . axis ( axis_ ) } ,
) ;
}
/// Returns a boolean Tensor indicating whether all values along the given axis are non-zero.
pub fn all ( self : Tensor , axis_ : anytype ) Tensor {
const pred = if ( self . dtype ( ) = = . bool ) self else self . cmp ( . NE , Tensor . scalar ( 0 , self . dtype ( ) ) ) ;
return ops . reduce (
struct {
pub fn acc ( x : Tensor , res : Tensor ) Tensor {
return res . logical ( . AND , x ) ;
}
} . acc ,
pred ,
Tensor . scalar ( true , . bool ) ,
& . { self . axis ( axis_ ) } ,
) ;
}
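// Usage sketch contrasting the two reductions (illustrative only; `x` is a
// hypothetical tensor with a tagged .seq axis):
//
//     const has_nonzero = x.any(.seq); // true where at least one element along .seq is non-zero
//     const all_nonzero = x.all(.seq); // true only where every element along .seq is non-zero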
/// Given a set of N vectors of lengths A, B, C, D,
/// returns N tensors of rank N, and shape (A, B, C, D).
/// For any coordinate (a, b, c, d), we have:
///
/// - res[0][a, b, c, d] == A[a]
/// - res[1][a, b, c, d] == B[b]
/// - res[2][a, b, c, d] == C[c]
/// - res[3][a, b, c, d] == D[d]
///
/// This is implemented with broadcasting, so typically it won't copy.
/// In PyTorch/NumPy this is known as `meshgrid` with "ij" mode.
/// See `zml.torch.meshgrid` for the "xy" mode.
pub fn cartesianProduct ( comptime N : u3 , vectors : [ N ] Tensor ) [ N ] Tensor {
var out : @TypeOf ( vectors ) = undefined ;
_cartesianProduct ( & vectors , & out ) ;
return out ;
}
fn _cartesianProduct ( vectors : [ ] const Tensor , out : [ ] Tensor ) void {
stdx . debug . assert ( vectors . len > = 1 , " cartesianProduct expects at least one input. " , . { } ) ;
stdx . debug . assert ( vectors . len < Tensor . MAX_RANK , " cartesianProduct expects at most {} input vectors, received {} ! " , . { Tensor . MAX_RANK - 1 , vectors . len } ) ;
for ( vectors ) | x | {
stdx . debug . assert ( x . rank ( ) < = 1 , " cartesianProduct expects input vectors of rank 0 or 1. Got: {any} " , . { vectors } ) ;
stdx . debug . assert ( vectors [ 0 ] . dtype ( ) = = x . dtype ( ) , " cartesianProduct expects all input vectors to have the same dtype. Got: {any} " , . { vectors } ) ;
}
var res_shape = Shape . init ( . { } , vectors [ 0 ] . dtype ( ) ) ;
for ( vectors ) | x | {
if ( x . rank ( ) = = 0 ) {
res_shape = res_shape . appendDim ( 1 , null ) ;
} else {
res_shape = res_shape . appendDim ( x . dim ( 0 ) , x . shape ( ) . tag ( 0 ) ) ;
}
}
for ( out , vectors , 0 . . ) | * o , x , i | {
o . * = x . broadcast ( res_shape , & [ 1 ] i64 { @intCast ( i ) } ) ;
}
}
test cartesianProduct {
const zml = @import ( " zml.zig " ) ;
const client = zml . testing . env ( ) ;
const x = try zml . Buffer . fromSlice ( client , . { 6 } , & [ _ ] i32 { 0 , 1 , 2 , 3 , 4 , 5 } ) ;
const y = try zml . Buffer . fromSlice ( client , . { 4 } , & [ _ ] i32 { 0 , 1 , 2 , 3 } ) ;
const Local = struct {
pub fn _cartesianProduct2 ( a : Tensor , b : Tensor ) [ 2 ] Tensor {
return cartesianProduct ( 2 , . { a , b } ) ;
}
} ;
{
const xs , const ys = try zml . testing . compileAndCall ( client , Local . _cartesianProduct2 , . { x , y } ) ;
try std . testing . expectEqualSlices ( i64 , & . { 6 , 4 } , xs . shape ( ) . dims ( ) ) ;
try std . testing . expectEqualSlices ( i64 , & . { 6 , 4 } , ys . shape ( ) . dims ( ) ) ;
try std . testing . expectEqualDeep (
[ 6 ] [ 4 ] i32 {
. { 0 , 0 , 0 , 0 } ,
. { 1 , 1 , 1 , 1 } ,
. { 2 , 2 , 2 , 2 } ,
. { 3 , 3 , 3 , 3 } ,
. { 4 , 4 , 4 , 4 } ,
. { 5 , 5 , 5 , 5 } ,
} ,
try xs . getValue ( [ 6 ] [ 4 ] i32 ) ,
) ;
try std . testing . expectEqualDeep (
[ 6 ] [ 4 ] i32 {
. { 0 , 1 , 2 , 3 } ,
. { 0 , 1 , 2 , 3 } ,
. { 0 , 1 , 2 , 3 } ,
. { 0 , 1 , 2 , 3 } ,
. { 0 , 1 , 2 , 3 } ,
. { 0 , 1 , 2 , 3 } ,
} ,
try ys . getValue ( [ 6 ] [ 4 ] i32 ) ,
) ;
}
}
/// Given a set of N vectors of lengths A, B, C, D,
/// returns one tensor of rank N+1 and shape (A, B, C, D, N).
/// For any coordinate (a, b, c, d), we have:
///
/// - res[a, b, c, d] == (A[a], B[b], C[c], D[d])
pub fn cartesianProductStacked ( vectors : [ ] const Tensor ) Tensor {
var out = std . BoundedArray ( Tensor , Tensor . MAX_RANK ) . init ( vectors . len ) catch unreachable ;
_cartesianProduct ( vectors , out . slice ( ) ) ;
return Tensor . stack ( out . constSlice ( ) , . last , . coord ) ;
}
test cartesianProductStacked {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
const x = try zml . Buffer . fromSlice ( platform , . { 6 } , & [ _ ] i32 { 0 , 1 , 2 , 3 , 4 , 5 } ) ;
const y = try zml . Buffer . fromSlice ( platform , . { 4 } , & [ _ ] i32 { 0 , 1 , 2 , 3 } ) ;
const Local = struct {
pub fn _fwd ( a : Tensor , b : Tensor ) Tensor {
return cartesianProductStacked ( & . { a , b } ) ;
}
} ;
const z = try zml . testing . compileAndCall ( platform , Local . _fwd , . { x , y } ) ;
try std . testing . expectEqualDeep (
[ 6 ] [ 4 ] [ 2 ] i32 {
. { . { 0 , 0 } , . { 0 , 1 } , . { 0 , 2 } , . { 0 , 3 } } ,
. { . { 1 , 0 } , . { 1 , 1 } , . { 1 , 2 } , . { 1 , 3 } } ,
. { . { 2 , 0 } , . { 2 , 1 } , . { 2 , 2 } , . { 2 , 3 } } ,
. { . { 3 , 0 } , . { 3 , 1 } , . { 3 , 2 } , . { 3 , 3 } } ,
. { . { 4 , 0 } , . { 4 , 1 } , . { 4 , 2 } , . { 4 , 3 } } ,
. { . { 5 , 0 } , . { 5 , 1 } , . { 5 , 2 } , . { 5 , 3 } } ,
} ,
try z . getValue ( [ 6 ] [ 4 ] [ 2 ] i32 ) ,
) ;
}
fn binaryOp (
src : std . builtin . SourceLocation ,
op_name : [ ] const u8 ,
op_fn : fn ( mlir . Context , mlir . Value , mlir . Value , mlir . Location ) mlir . Operation ,
) fn ( Tensor , Tensor ) Tensor {
return struct {
pub fn binaryOpHelper ( self : Tensor , other : Tensor ) Tensor {
stdx . debug . assert ( self . dtype ( ) = = other . dtype ( ) , " {s} expects tensors to be of the same type, got {} and {} " , . { op_name , self , other } ) ;
if ( self . rank ( ) = = 0 and other . rank ( ) ! = 0 ) {
return binaryOpHelper ( self . broad ( other . _shape ) , other ) ;
}
if ( self . rank ( ) ! = 0 and other . rank ( ) = = 0 ) {
return binaryOpHelper ( self , other . broad ( self . _shape ) ) ;
}
stdx . debug . assert ( self . _shape . eql ( other . _shape ) , " {s} expects tensor shapes to match, got {} and {} " , . { op_name , self . _shape , other . _shape } ) ;
const ctx = self . getContext ( ) ;
const location = ctx . location ( src , " {s}({_}, {_}) " , . { op_name , self , other } ) ;
const ret = @call ( . auto , op_fn , . { ctx . mlirCtx ( ) , self . value ( ) , other . value ( ) , location } ) ;
return _result ( self . _shape , ret . result ( 0 ) ) ;
}
} . binaryOpHelper ;
}
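// A sketch of how this factory is meant to be consumed, assuming
// `dialect.stablehlo.add` matches the `op_fn` signature above (the actual
// wrappers are declared elsewhere in this struct):
//
//     pub const add = binaryOp(@src(), "add", dialect.stablehlo.add);
//     // `add` is then a `fn (Tensor, Tensor) Tensor` with dtype checks and
//     // rank-0 broadcasting built in.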
/// Inserts code that will print the content of the given buffer at runtime.
/// Only for debugging purposes. It has the following limitations:
/// * only supported on CUDA at the moment
/// * only prints the first 1024 values
/// * pre-allocates a buffer on the host to copy the content of the device buffer;
/// this buffer won't be freed, and you get one host buffer per "print" call in the IR.
/// * performs a device-to-host synchronization, so it will slow down program execution.
pub fn print ( input : Tensor ) Tensor {
return ops . addHostCallback ( & printCallback , input ) ;
}
fn printCallback ( host_buffer : HostBuffer ) void {
std . debug . print ( " Device buffer: {}: {} " , . { host_buffer . shape ( ) , host_buffer . pretty ( ) } ) ;
}
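// Usage sketch (illustrative only; `attn_weights` is a hypothetical tensor):
// `print` returns its input unchanged, so it can be chained in the middle of a
// computation to inspect an intermediate value.
//
//     const watched = attn_weights.print(); // same values, plus a host-side dump at runtime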
} ;
fn initPoolArg ( rank : usize , data : [ ] const i64 ) [ Tensor . MAX_RANK ] i64 {
// TODO use shape
var result = [ _ ] i64 { 1 } * * Tensor . MAX_RANK ;
const start = rank - data . len ;
@memcpy ( result [ start . . start + data . len ] , data ) ;
return result ;
}
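// Worked example (illustrative, assuming MAX_RANK >= 3): initPoolArg(3, &.{ 3, 2 })
// right-aligns the data into a rank-MAX_RANK buffer, yielding { 1, 3, 2 }
// followed by trailing ones.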
fn getPoolResDims ( dt : DataType , in_dims : [ ] const i64 , base_dilations : @Vector ( Tensor . MAX_RANK , i64 ) , padding : [ ] const i64 , window_dimensions : @Vector ( Tensor . MAX_RANK , i64 ) , window_dilations : @Vector ( Tensor . MAX_RANK , i64 ) , window_strides : @Vector ( Tensor . MAX_RANK , i64 ) ) Shape {
// TODO use shape
var input_dims = [ _ ] i64 { 1 } * * Tensor . MAX_RANK ;
@memcpy ( input_dims [ 0 . . in_dims . len ] , in_dims ) ;
const input_dims_ : @Vector ( Tensor . MAX_RANK , i64 ) = input_dims ;
const splat_one : @Vector ( Tensor . MAX_RANK , i64 ) = @splat ( 1 ) ;
const dilated_input_shape : @Vector ( Tensor . MAX_RANK , i64 ) = ( input_dims_ - splat_one ) * base_dilations + splat_one ;
var pad_slice0 : @Vector ( Tensor . MAX_RANK , i64 ) = @splat ( padding [ 0 ] ) ;
var pad_slice1 : @Vector ( Tensor . MAX_RANK , i64 ) = @splat ( padding [ 0 ] ) ;
if ( padding . len > 1 ) {
var idx : usize = 0 ;
while ( idx < in_dims . len * 2 ) : ( idx + = 2 ) {
pad_slice0 [ idx / 2 ] = padding [ idx ] ;
pad_slice1 [ idx / 2 ] = padding [ idx + 1 ] ;
}
}
const padded_input_shape : @Vector ( Tensor . MAX_RANK , i64 ) = pad_slice0 + dilated_input_shape + pad_slice1 ;
const dilated_window_shape = ( window_dimensions - splat_one ) * window_dilations + splat_one ;
const dims = @divFloor ( padded_input_shape - dilated_window_shape , window_strides ) + splat_one ;
const dims_arr : [ Tensor . MAX_RANK ] i64 = @bitCast ( dims ) ;
return Shape . init ( dims_arr [ 0 . . in_dims . len ] , dt ) ;
}
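// Worked example (matching the maxPool1d test below): for an input dim of 5,
// a window of 3, a stride of 2, no padding, and all dilations equal to 1:
//   dilated input  = (5 - 1) * 1 + 1 = 5
//   padded input   = 0 + 5 + 0 = 5
//   dilated window = (3 - 1) * 1 + 1 = 3
//   output dim     = (5 - 3) / 2 + 1 = 2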
fn getComparisonType ( ctx : mlir . Context , dtype : DataType ) dialect . stablehlo . CompareType {
return dialect . stablehlo . CompareType . init ( ctx , switch ( dtype ) {
. i4 , . i8 , . i16 , . i32 , . i64 = > . SIGNED ,
. bool , . u4 , . u8 , . u16 , . u32 , . u64 = > . UNSIGNED ,
. f8e4m3b11fnuz , . f8e4m3fn , . f8e4m3fnuz , . f8e5m2 , . f8e5m2fnuz , . bf16 , . f16 , . f32 , . f64 = > . FLOAT ,
. c64 , . c128 = > @panic ( " Can't compare complex numbers " ) ,
} ) ;
}
test " Tensor.maxPool1d " {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
const MaxPool = struct {
pub fn _fwd ( x : Tensor ) Tensor . ArgMaxRes {
return x . maxPool1d ( . {
. window_dimensions = 3 ,
. window_strides = 2 ,
} ) ;
}
} ;
var data : [ 20 ] f32 = undefined ;
for ( & data , 0 . . ) | * v , i | v . * = @floatFromInt ( i ) ;
const x = try zml . Buffer . fromSlice ( platform , . { 2 , 2 , 5 } , & data ) ;
const result = try zml . testing . compileAndCall ( platform , MaxPool . _fwd , . { x } ) ;
try zml . testing . expectEqualShapes ( Shape . init ( . { 2 , 2 , 2 } , . f32 ) , result . values . shape ( ) ) ;
try zml . testing . expectEqualShapes ( Shape . init ( . { 2 , 2 , 2 } , . i32 ) , result . indices . shape ( ) ) ;
const buffer = try result . values . getValue ( [ 2 ] [ 2 ] [ 2 ] f32 ) ;
try std . testing . expectEqualDeep (
[ 2 ] [ 2 ] [ 2 ] f32 {
[ 2 ] [ 2 ] f32 {
[ 2 ] f32 { 2 , 4 } ,
[ 2 ] f32 { 7 , 9 } ,
} ,
[ 2 ] [ 2 ] f32 {
[ 2 ] f32 { 12 , 14 } ,
[ 2 ] f32 { 17 , 19 } ,
} ,
} ,
buffer ,
) ;
}
test " Tensor.maxPool2d " {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
const MaxPool = struct {
pub fn _fwd ( x : Tensor ) Tensor . ArgMaxRes {
return x . maxPool2d ( . {
. window_dimensions = . { 3 , 2 } ,
. window_strides = . { 2 , 1 } ,
} ) ;
}
} ;
var data : [ 100 ] f32 = undefined ;
for ( & data , 0 . . ) | * v , i | v . * = @floatFromInt ( i ) ;
const x = try zml . Buffer . fromSlice ( platform , . { 2 , 2 , 5 , 5 } , & data ) ;
const result = try zml . testing . compileAndCall ( platform , MaxPool . _fwd , . { x } ) ;
try zml . testing . expectEqualShapes ( Shape . init ( . { 2 , 2 , 2 , 4 } , . f32 ) , result . values . shape ( ) ) ;
try zml . testing . expectEqualShapes ( Shape . init ( . { 2 , 2 , 2 , 4 } , . i32 ) , result . indices . shape ( ) ) ;
var buffer : [ 2 ] [ 2 ] [ 2 ] [ 4 ] f32 = undefined ;
_ = try result . values . toHost ( std . mem . asBytes ( & buffer ) ) ;
try std . testing . expectEqualDeep (
[ 2 ] [ 2 ] [ 2 ] [ 4 ] f32 {
. {
. { . { 11 , 12 , 13 , 14 } , . { 21 , 22 , 23 , 24 } } ,
. { . { 36 , 37 , 38 , 39 } , . { 46 , 47 , 48 , 49 } } ,
} ,
. {
. { . { 61 , 62 , 63 , 64 } , . { 71 , 72 , 73 , 74 } } ,
. { . { 86 , 87 , 88 , 89 } , . { 96 , 97 , 98 , 99 } } ,
} ,
} ,
buffer ,
) ;
}
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
pub fn Bufferized ( comptime T : type ) type {
return meta . MapType ( Tensor , Buffer ) . map ( T ) ;
}
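// For instance, given a hypothetical layer type:
//
//     const Layer = struct { weight: Tensor, bias: ?Tensor };
//     // Bufferized(Layer) == struct { weight: Buffer, bias: ?Buffer }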
/// Return a clone of a type with Tensors replaced by Shapes.
/// Recursively descends into the type.
/// See also: shapesOf() and its tests, and meta.MapType().
pub fn ShapeOf ( comptime T : type ) type {
const M = meta . MapType ( Tensor , Shape ) ;
return M . map ( T ) ;
}
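// For instance (hypothetical type): ShapeOf(struct { w: Tensor }) yields
// struct { w: Shape }; nested structs are mapped recursively.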
/// Return a clone of the argument where each instance of a Tensor is replaced
/// by its Shape. This is similar to ShapeOf(), but with runtime values.
/// See also: meta.mapAlloc().
pub fn shapesOf ( model : anytype , allocator : std . mem . Allocator ) ! ShapeOf ( @TypeOf ( model ) ) {
var shapes : ShapeOf ( @TypeOf ( model ) ) = undefined ;
try meta . mapAlloc ( struct {
fn shapeFromTensorCallback ( _ : void , tensor : Tensor ) Shape {
return tensor . shape ( ) ;
}
} . shapeFromTensorCallback , allocator , { } , model , & shapes ) ;
return shapes ;
}
test shapesOf {
const alloc = std . testing . allocator ;
// Tensor in struct
{
const S = struct {
a : Tensor ,
} ;
const shape = Shape . init ( . { 28 , 28 } , . f32 ) ;
const s : S = . {
. a = Tensor { . _shape = shape , . _id = undefined } ,
} ;
const shapes = try shapesOf ( s , alloc ) ;
try std . testing . expectEqual ( shape , shapes . a ) ;
}
// single Tensor
{
const shape = Shape . init ( . { 28 , 28 } , . f32 ) ;
const tensor = Tensor { . _shape = shape , . _id = undefined } ;
const shapes = try shapesOf ( tensor , alloc ) ;
try std . testing . expectEqual ( shape , shapes ) ;
}
// nn linear layer, no bias
{
const nn = @import ( " nn.zig " ) ;
const shape = Shape . init ( . { 28 , 28 } , . f32 ) ;
const layer : nn . Linear = . {
. weight = Tensor { . _shape = shape , . _id = undefined } ,
. bias = null ,
} ;
const shapes = try shapesOf ( layer , alloc ) ;
try std . testing . expectEqual ( shape , shapes . weight ) ;
try std . testing . expectEqual ( null , shapes . bias ) ;
}
// model
{
const Mnist = struct {
fc1 : Layer ,
fc2 : Layer ,
const Layer = struct {
weight : Tensor ,
bias : Tensor ,
} ;
} ;
const fc1_weight_shape = Shape . init ( . { 500 , 784 } , . f32 ) ;
const fc1_bias_shape = Shape . init ( . { 500 } , . f32 ) ;
const fc2_weight_shape = Shape . init ( . { 10 , 500 } , . f32 ) ;
const fc2_bias_shape = Shape . init ( . { 10 } , . f32 ) ;
const mnist : Mnist = . {
. fc1 = . {
. weight = Tensor { . _shape = fc1_weight_shape , . _id = undefined } ,
. bias = Tensor { . _shape = fc1_bias_shape , . _id = undefined } ,
} ,
. fc2 = . {
. weight = Tensor { . _shape = fc2_weight_shape , . _id = undefined } ,
. bias = Tensor { . _shape = fc2_bias_shape , . _id = undefined } ,
} ,
} ;
const shapes = try shapesOf ( mnist , alloc ) ;
try std . testing . expectEqual ( fc1_weight_shape , shapes . fc1 . weight ) ;
try std . testing . expectEqual ( fc1_bias_shape , shapes . fc1 . bias ) ;
try std . testing . expectEqual ( fc2_weight_shape , shapes . fc2 . weight ) ;
try std . testing . expectEqual ( fc2_bias_shape , shapes . fc2 . bias ) ;
}
}
pub fn _collectAxes ( T : type , bounded_array : std . BoundedArray ( T , Tensor . MAX_RANK ) , value : T ) std . BoundedArray ( i64 , Tensor . MAX_RANK ) {
var res : std . BoundedArray ( i64 , Tensor . MAX_RANK ) = . { } ;
for ( bounded_array . constSlice ( ) , 0 . . ) | v , ax | {
if ( v = = value ) {
res . appendAssumeCapacity ( @intCast ( ax ) ) ;
}
}
return res ;
}
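// Worked example (illustrative): with bounded_array holding { 0, 2, 0, 2 } and
// value == 2, the returned indices are { 1, 3 }.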
fn _parseGatherCoord ( self : Tensor , axes_ : anytype ) struct { bool , std . BoundedArray ( u3 , Tensor . MAX_RANK ) } {
const AxesT = @TypeOf ( axes_ ) ;
const axes_is_scalar = AxesT = = EnumLiteral or AxesT = = comptime_int or @typeInfo ( AxesT ) = = . int ;
const coord_axes = if ( axes_is_scalar )
std . BoundedArray ( u3 , Tensor . MAX_RANK ) . fromSlice ( & . { self . axis ( axes_ ) } ) catch unreachable
else
self . axes ( axes_ ) ;
return . { axes_is_scalar , coord_axes } ;
}
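// Illustrative: `self._parseGatherCoord(.a)` takes the scalar path and returns
// `.{ true, <axis of .a> }`, while `self._parseGatherCoord(.{ .a, .b })` returns
// `.{ false, <both axes> }`.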
fn parseArrayInfo ( T : type ) Shape {
return switch ( @typeInfo ( T ) ) {
. array = > | arr | {
const s = parseArrayInfo ( arr . child ) ;
return s . insert ( 0 , . { arr . len } ) ;
} ,
else = > . { . _dtype = DataType . fromZigType ( T ) } ,
} ;
}
inline fn toI64 ( values : anytype ) [ ] i64 {
var res : [ Tensor . MAX_RANK ] i64 = undefined ;
for ( values , 0 . . ) | val , i | res [ i ] = @intCast ( val ) ;
return res [ 0 . . values . len ] ;
}
fn transposeIsJustAReshape ( x : Shape , permutation : [ ] const i64 ) bool {
var perm : std . BoundedArray ( struct { u8 , bool } , Tensor . MAX_RANK ) = . { } ;
// Don't rewrite on invalid inputs.
if ( permutation . len > x . rank ( ) ) return false ;
for ( permutation ) | ax | {
const squeezable = x . dim ( ax ) = = 1 ;
perm . appendAssumeCapacity ( . { @intCast ( ax ) , squeezable } ) ;
}
var effective_ax : u8 = 0 ;
for ( 0 . . perm . len ) | i | {
const ax , const squeezable = perm . get ( i ) ;
if ( squeezable ) {
// Effectively squeeze this axis by decrementing axes coming after by 1.
for ( i . . perm . len ) | j | {
if ( perm . buffer [ j ] [ 0 ] > ax ) {
perm . buffer [ j ] [ 0 ] - = 1 ;
}
}
continue ;
}
if ( ax ! = effective_ax ) return false ;
effective_ax + = 1 ;
}
return true ;
}
test transposeIsJustAReshape {
try std . testing . expect ( transposeIsJustAReshape ( Shape . init ( . { 5 , 1 , 3 } , . i32 ) , & . { 0 , 1 , 2 } ) ) ;
try std . testing . expect ( transposeIsJustAReshape ( Shape . init ( . { 5 , 1 , 3 } , . i32 ) , & . { 1 , 0 , 2 } ) ) ;
try std . testing . expect ( ! transposeIsJustAReshape ( Shape . init ( . { 5 , 1 , 3 } , . i32 ) , & . { 2 , 1 , 0 } ) ) ;
try std . testing . expect ( transposeIsJustAReshape ( Shape . init ( . { 64 , 8 , 1 , 128 } , . bf16 ) , & . { 0 , 2 , 1 , 3 } ) ) ;
try std . testing . expect ( ! transposeIsJustAReshape ( Shape . init ( . { 64 , 8 , 155 , 128 } , . bf16 ) , & . { 0 , 2 , 1 , 3 } ) ) ;
try std . testing . expect ( transposeIsJustAReshape ( Shape . init ( . { 64 , 1 , 1 , 128 } , . bf16 ) , & . { 1 , 2 , 0 , 3 } ) ) ;
try std . testing . expect ( ! transposeIsJustAReshape ( Shape . init ( . { . b = 1 , . h = 10 , . q = 155 , . hd = 1 } , . f32 ) , & . { 0 , 2 , 1 , 3 } ) ) ;
try std . testing . expect ( ! transposeIsJustAReshape ( Shape . init ( . { 1 , 10 , 155 , 1 } , . f32 ) , & . { 0 , 2 , 3 , 1 } ) ) ;
try std . testing . expect ( transposeIsJustAReshape ( Shape . init ( . { 1 , 10 , 155 , 1 } , . f32 ) , & . { 0 , 1 , 3 , 2 } ) ) ;
}
test " unused tensor " {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
const Local = struct {
pub fn _fwd ( x : Tensor ) Tensor {
const y = x . addConstant ( 1 ) ;
_ = y ;
return x ;
}
} ;
const mod = try zml . compileFn ( std . testing . allocator , Local . _fwd , . { Shape . init ( . { 10 } , . f32 ) } , platform ) ;
defer mod . deinit ( ) ;
}