2023-01-02 14:28:25 +00:00
//! Common layer definition and functions for Neural Networks (NN)
const std = @import ( " std " ) ;
const assert = std . debug . assert ;
const testing = std . testing ;
const zml = @import ( " zml.zig " ) ;
const meta = @import ( " meta.zig " ) ;
const helpers = @import ( " helpers.zig " ) ;
const ops = @import ( " ops.zig " ) ;
const DataType = @import ( " dtype.zig " ) . DataType ;
const Shape = @import ( " shape.zig " ) . Shape ;
const Tensor = @import ( " tensor.zig " ) . Tensor ;
const log = std . log . scoped ( . zml_tensor ) ;
const cuda = @import ( " nn/cuda.zig " ) ;
2023-01-27 14:35:11 +00:00
test {
    // Reference every declaration of this file (and the `cuda` submodule)
    // so that their tests are picked up by the test runner.
    std.testing.refAllDecls(@This());
    _ = cuda;
}
2023-01-02 14:28:25 +00:00
/// Linear layer: `y = x . weight (+ bias)`.
pub const Linear = struct {
    weight: Tensor,
    bias: ?Tensor = null,

    /// Applies the layer to `x`.
    /// `weight` is converted to the dtype of `x`, then combined with `x`
    /// through `dotGeneral` using the axis pair `{-1, -1}`.
    /// When present, `bias` is left-broadcast to the output shape and added.
    pub fn forward(self: Linear, x: Tensor) Tensor {
        const weight = self.weight.convert(x.dtype());
        var y = x.dotGeneral(weight, &.{.{ -1, -1 }}, &.{});

        // An untagged last output axis means the weight carried no tag:
        // fall back to the tag of the last axis of x.
        if (y.shape().tag(-1) == Shape.TagUnknown) {
            y._shape._tags.set(y.rank() - 1, x.shape().tag(-1));
        }

        const bias = self.bias orelse return y;
        return y.add(bias.broadcastLeft(y.shape()));
    }
};
/// Embedding table lookup.
pub const TokenEmbedding = struct {
    weight: Tensor,

    /// Gathers values of `weight` along axis 0 at the positions given by `idx`.
    /// Asserts that `idx` has an integer dtype and that `weight` has rank 2.
    pub fn forward(self: TokenEmbedding, idx: Tensor) Tensor {
        const table = self.weight;
        meta.assert(idx.dtype().isInteger(), " TokenEmbedding expects an integer input, received: {} ", .{idx});
        meta.assert(table.rank() == 2, " TokenEmbedding expects it's weight Tensor to be a 2D matrix, got {} ", .{table});
        return table.gatherValues(0, idx, .{});
    }
};
/// Element-wise activation function, selected by the active tag.
/// `leakyReLU` carries its slope and `elu` carries its alpha.
pub const Activation = union(enum) {
    sigmoid,
    tanh,
    relu,
    leakyReLU: f32,
    elu: f32,
    silu,
    gelu,
    quick_gelu,

    /// Applies the selected activation to `x`.
    pub fn forward(self: Activation, x: Tensor) Tensor {
        switch (self) {
            .sigmoid => return x.sigmoid(),
            .tanh => return x.tanh(),
            .relu => return x.relu(),
            .leakyReLU => |slope| return x.leakyReLU(slope),
            .elu => |alpha| return elu(x, alpha),
            .silu => return x.silu(),
            .gelu => return x.gelu(),
            .quick_gelu => return x.quickGelu(),
        }
    }

    /// ELU: `x` where `x >= 0`, `alpha * (exp(x) - 1)` elsewhere.
    pub fn elu(x: Tensor, alpha: f32) Tensor {
        const is_non_negative = x.cmp(.GE, Tensor.scalar(0, x.dtype()));
        const negative_side = x.exp().addConstant(-1).scale(alpha);
        return is_non_negative.select(x, negative_side);
    }
};
/// Feeds `input` through every field of `module_list`, in declaration order:
/// each field's `forward` receives the output of the previous one.
/// `module_list` must be a struct, anything else is a compile error.
pub fn chainModules(module_list: anytype, input: Tensor) Tensor {
    const info = switch (@typeInfo(@TypeOf(module_list))) {
        .Struct => |struct_info| struct_info,
        else => @compileError(" chainModules only works on a struct with only containing 'module' struct. "),
    };

    var out = input;
    inline for (info.fields) |field| {
        out = @field(module_list, field.name).forward(out);
    }
    return out;
}
/// Layer normalization over the last axis, followed by a learned scale
/// and an optional learned offset.
pub const LayerNorm = struct {
    weight: Tensor,
    bias: ?Tensor = null,
    eps: f32 = 1e-5,

    /// Normalizes `x` with `normalizeVariance`, multiplies by `weight`,
    /// then adds `bias` when there is one.
    /// `weight` and `bias` are broadcast to the shape of `x` along its last axis.
    pub fn forward(self: LayerNorm, x: Tensor) Tensor {
        const normed = normalizeVariance(x, self.eps);
        const last_axis = x.axis(-1);
        const scaled = normed.mul(self.weight.broadcast(x.shape(), &.{last_axis}));

        const bias = self.bias orelse return scaled;
        return scaled.add(bias.broadcast(x.shape(), &.{last_axis}));
    }
};
/// Centers `x` and scales it by its standard deviation, along the last axis:
/// `(x - mean(x)) / sqrt(var(x) + eps)`.
/// The computation is done in f32, the result is converted back to the dtype of `x`.
pub fn normalizeVariance(x: Tensor, eps: f32) Tensor {
    const N: f32 = @floatFromInt(x.dim(-1));
    const inv_n = 1.0 / N;

    // Work in f32 to limit the precision loss of the reductions.
    const wide = x.convert(.f32);
    const mean = wide.sum(-1).scale(inv_n);
    const centered = wide.sub(mean.broadcastRight(wide.shape()));
    const variance = centered.mul(centered).sum(-1).scale(inv_n);
    const inv_std = Tensor.rsqrt(variance.addConstant(eps));

    const normalized = centered.mul(inv_std.broadcastRight(centered.shape()));
    return normalized.convert(x.dtype());
}
// ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
// Counterpart of a `nn.functional.normalize(tensor, dim=-1)` call:
// `input * rsqrt(sum(input^2, axis=-1) + eps)`.
// Note that `eps` is added to the squared norm before taking the inverse square root.
pub fn normalizeL2(input: Tensor, eps: f32) Tensor {
    const squared = input.pow(Tensor.scalar(2, input.dtype()));
    const inv_norm = squared.sum(-1).addConstant(eps).rsqrt();
    return input.mul(inv_norm.broad(input.shape()));
}
test normalizeL2 {
    const platform = zml.testing.env();

    // Two rows of two values, each row is normalized independently.
    const values = [_]f32{ -0.9686, -1.0058, -1.7808, 0.6698 };
    const input = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &values);
    const actual = try zml.testing.compileAndCall(platform, zml.nn.normalizeL2, .{ input, 1e-12 });

    const expected_values = [_]f32{ -0.6937, -0.7203, -0.9360, 0.3520 };
    const expected = zml.HostBuffer.fromSlice(.{ 2, 2 }, &expected_values);
    try zml.testing.expectClose(expected, actual, 1e-4);
}
/// Options of the rotary position embedding functions.
pub const RopeOpts = struct {
    /// How the last axis of `x` is split into real and imaginary parts:
    /// * `interleaved`: the real and imaginary part of each scalar are next to each other.
    /// * `sequential`: all real values come first, followed by all imaginary values.
    pub const Implementation = enum { interleaved, sequential };

    /// Layout of the real/imaginary parts, no default.
    impl: Implementation,
    /// Base of the frequencies, forwarded to `invFreq` as `theta`.
    freq_base: f32 = 10_000,
};
/// Pair of tensors `.{ cos, sin }`, in that order:
/// this is what `ropeCosSin` returns and what `rope` destructures.
pub const CosSin = [ 2 ] Tensor ;
/// Rotary position embedding: rotates queries or keys before computing Q * K in self attention,
/// which biases a token toward the tokens near it.
/// The rotated queries and keys can be cached directly.
/// See: https://paperswithcode.com/method/rope
///
/// `x` is split into real/imaginary halves according to `opts.impl`,
/// each pair is rotated using `cos_sin_cache`, then the halves are merged back.
/// The last axis of `x` must be twice as long as the last axis of the cos tensor.
pub fn rope(x: Tensor, cos_sin_cache: CosSin, opts: RopeOpts) Tensor {
    const cos, const sin = cos_sin_cache;
    meta.assert(x.dim(-1) == 2 * cos.dim(-1), " Couldn't compute rope({}, {}, {}) ", .{ x, cos, sin });

    const x_real, const x_imag = splitRealImg(x, opts.impl);
    const half_shape = x_real.shape();

    // Broadcast cos / sin to the shape of one half:
    // by tags when the cache is tagged, by aligning trailing axes otherwise.
    const tagged = cos.shape().tag(0) != Shape.TagUnknown;
    const b_cos = if (tagged) cos.broad(half_shape) else cos.broadcastLeft(half_shape);
    const b_sin = if (tagged) sin.broad(half_shape) else sin.broadcastLeft(half_shape);

    // Complex multiplication by (cos + i * sin).
    const y_real = x_real.mul(b_cos).sub(x_imag.mul(b_sin));
    const y_imag = x_real.mul(b_sin).add(x_imag.mul(b_cos));

    // Put both halves back on the last axis.
    return mergeRealImg(y_real, y_imag, opts.impl);
}
/// Builds the `.{ cos, sin }` pair consumed by `rope`.
/// `sh` must describe exactly two axes: the sequence length, then an even head dimension.
/// The angles are computed in f32, converted to `dtype`, and receive the tags of `sh`.
pub fn ropeCosSin(sh: anytype, dtype: DataType, opts: RopeOpts) CosSin {
    const shape = Shape.init(sh, dtype);
    meta.assert(shape.rank() == 2, " ropeCosSin({}) shape need to exactly have 2 axes ", .{shape});

    const seq_len = shape.dim(0);
    const head_dim = shape.dim(1);
    meta.assert(@mod(head_dim, 2) == 0, " ropeCosSin requires an even head_dim, got {} ", .{head_dim});

    // Angles are the outer product "position x inverse frequency",
    // computed in f32 before the conversion to the requested dtype.
    const inv_freq = invFreq(head_dim, opts.freq_base, .f32);
    const positions = Tensor.arange(.{ .end = seq_len }, .f32);
    var angles = Tensor.outer(positions, inv_freq).convert(shape.dtype());
    angles._shape._tags = shape._tags;

    return .{ angles.cos(), angles.sin() };
}
/// Splits the last axis of `x` into `.{ real, imag }` according to `impl`:
/// * `sequential`: first half / second half (the length must be even).
/// * `interleaved`: even positions / odd positions.
pub fn splitRealImg(x: Tensor, impl: RopeOpts.Implementation) [2]Tensor {
    const n = x.dim(-1);
    switch (impl) {
        .sequential => {
            const real = x.slice1d(-1, .{ .end = @divExact(n, 2) });
            const imag = x.slice1d(-1, .{ .start = @divExact(n, 2), .end = n });
            return .{ real, imag };
        },
        .interleaved => {
            const real = x.slice1d(-1, .{ .start = 0, .step = 2 });
            const imag = x.slice1d(-1, .{ .start = 1, .step = 2 });
            return .{ real, imag };
        },
    }
}
/// Inverse of `splitRealImg`: rebuilds a single tensor from its real and imaginary parts.
/// * `sequential`: both parts are concatenated on the last axis.
/// * `interleaved`: each part gets an extra trailing axis, the parts are concatenated
///   on that axis, then the two last axes are flattened together.
pub fn mergeRealImg(x_real: Tensor, x_imag: Tensor, impl: RopeOpts.Implementation) Tensor {
    switch (impl) {
        .sequential => return Tensor.concatenate(&.{ x_real, x_imag }, -1),
        .interleaved => {
            const real = x_real.appendAxes(.{.interleaved_real_img});
            const imag = x_imag.appendAxes(.{.interleaved_real_img});
            return Tensor.concatenate(&.{ real, imag }, -1).flatten(-2);
        },
    }
}
/// Inverse frequencies: `{ exp(-n * ln(theta) / N) | n = 0, 2, 4, ... < N }`,
/// returned as a tensor of type `dtype`.
pub fn invFreq(N: i64, theta: f32, dtype: DataType) Tensor {
    const n_float: f32 = @floatFromInt(N);
    const log_step = -@log(theta) / n_float;
    const even_positions = Tensor.arange(.{ .start = 0, .end = N, .step = 2 }, dtype);
    return even_positions.scale(log_step).exp();
}
test " real/img " {
    const platform = zml.testing.env();

    const Fns = struct {
        /// Counts the positions where `real` equals `want_real`,
        /// plus the positions where `imag` equals `want_imag`.
        fn countEqual(real: Tensor, imag: Tensor, want_real: Tensor, want_imag: Tensor) Tensor {
            const real_hits = real.cmp(.EQ, want_real).flatten(0).convert(.i32).sum(-1);
            const imag_hits = imag.cmp(.EQ, want_imag).flatten(0).convert(.i32).sum(-1);
            return real_hits.add(imag_hits);
        }

        /// Splitting, merging and splitting again must give back the first split.
        fn testSplitMergeIsId(impl: RopeOpts.Implementation) Tensor {
            const x = Tensor.arange(.{ .end = 20 }, .f32).reshape(.{ 5, 4 });
            const real, const imag = splitRealImg(x, impl);
            const merged = mergeRealImg(real, imag, impl);
            const real_again, const imag_again = splitRealImg(merged, impl);
            return countEqual(real, imag, real_again, imag_again);
        }

        /// Same check as `testSplitSeq`, but through a function taking one `void` argument.
        fn testSplitSeqVoid(_: void) Tensor {
            return testSplitSeq();
        }

        /// The sequential split of a 5x4 matrix yields columns {0, 1} and columns {2, 3}.
        fn testSplitSeq() Tensor {
            const x = Tensor.arange(.{ .end = 20 }, .f32).reshape(.{ 5, 4 });
            const real, const imag = splitRealImg(x, .sequential);

            const want_real = Tensor.concatenate(&.{
                Tensor.arange(.{ .start = 0, .end = 20, .step = 4 }, .f32).reshape(.{ 5, 1 }),
                Tensor.arange(.{ .start = 1, .end = 20, .step = 4 }, .f32).reshape(.{ 5, 1 }),
            }, 1);
            const want_imag = Tensor.concatenate(&.{
                Tensor.arange(.{ .start = 2, .end = 20, .step = 4 }, .f32).reshape(.{ 5, 1 }),
                Tensor.arange(.{ .start = 3, .end = 20, .step = 4 }, .f32).reshape(.{ 5, 1 }),
            }, 1);

            return countEqual(real, imag, want_real, want_imag);
        }

        /// The interleaved split yields the even values and the odd values.
        fn testSplitInterleaved() Tensor {
            const x = Tensor.arange(.{ .end = 20 }, .f32).reshape(.{ 5, 4 });
            const real, const imag = splitRealImg(x, .interleaved);

            const want_real = Tensor.arange(.{ .start = 0, .end = 20, .step = 2 }, .f32).reshape(.{ 5, 2 });
            const want_imag = Tensor.arange(.{ .start = 1, .end = 20, .step = 2 }, .f32).reshape(.{ 5, 2 });

            return countEqual(real, imag, want_real, want_imag);
        }
    };

    // In every case all 20 compared positions (10 real + 10 imaginary) must match.
    const d_interleaved = try zml.testing.compileAndCall(platform, Fns.testSplitMergeIsId, .{.interleaved});
    try testing.expectEqual(20, d_interleaved.getValue(i32));

    const d_sequential = try zml.testing.compileAndCall(platform, Fns.testSplitMergeIsId, .{.sequential});
    try testing.expectEqual(20, d_sequential.getValue(i32));

    // A function that accepts a single void argument.
    const d_split_seq_void = try zml.testing.compileAndCall(platform, Fns.testSplitSeqVoid, .{{}});
    try testing.expectEqual(20, d_split_seq_void.getValue(i32));

    // A function that takes no argument at all.
    const d_split_seq = try zml.testing.compileAndCall(platform, Fns.testSplitSeq, .{});
    try testing.expectEqual(20, d_split_seq.getValue(i32));

    // Same function, compiled and called without the testing helper.
    {
        const mod = try zml.compileFn(std.testing.allocator, Fns.testSplitSeq, .{}, platform);
        defer mod.deinit();

        const ret = mod.call(.{});
        try testing.expectEqual(20, ret.getValue(i32));
    }

    const d_split_interleaved = try zml.testing.compileAndCall(platform, Fns.testSplitInterleaved, .{});
    try testing.expectEqual(20, d_split_interleaved.getValue(i32));
}
test " rope " {
    const platform = zml.testing.env();

    const TestRope = struct {
        /// Applies `rope` on an input given in sequential layout, using the layout
        /// requested by `opts` internally, and returns the result in sequential layout.
        fn forward(x: Tensor, opts: RopeOpts) Tensor {
            // Sequential layout -> layout under test.
            const in_real, const in_imag = splitRealImg(x, .sequential);
            const input = mergeRealImg(in_real, in_imag, opts.impl);

            const cos_sin = ropeCosSin(.{ input.dim(-2), input.dim(-1) }, input.dtype(), opts);
            const rotated = rope(input, cos_sin, opts).squeeze(0);

            // Layout under test -> sequential layout.
            const out_real, const out_imag = splitRealImg(rotated, opts.impl);
            return mergeRealImg(out_real, out_imag, .sequential);
        }
    };

    // x is built so that its interleaved and sequential representations are the same,
    // therefore both implementations are expected to agree.
    const x = try zml.Buffer.fromSlice(platform, .{ 1, 5, 4 }, &[_]f32{ 1.0, 0.1, -1.0, -0.5 } ** 5);
    const from_interleaved = try zml.testing.compileAndCall(platform, TestRope.forward, .{ x, RopeOpts{ .impl = .interleaved } });
    const from_sequential = try zml.testing.compileAndCall(platform, TestRope.forward, .{ x, RopeOpts{ .impl = .sequential } });

    try zml.testing.expectClose(from_interleaved, from_sequential, 1e-4);
}
/// Approximate float comparison suited to neural network outputs:
/// what usually matters is the relative precision, but when the values are close to 0
/// an absolute check is enough.
/// Returns true when `l` and `r` are equal within `tolerance` relatively,
/// or within `tolerance / 2` in absolute value.
fn approxEq(comptime Float: type, l: Float, r: Float, tolerance: Float) bool {
    if (std.math.approxEqRel(Float, l, r, tolerance)) return true;
    return std.math.approxEqAbs(Float, l, r, tolerance / 2);
}
/// Interpolation mode of `upsample`.
pub const UpsampleMode = enum {
    /// Nearest neighbor, implemented by `nearest`.
    nearest,
    // TODO: the following modes are not implemented yet:
    // Linear, Bilinear, Bicubic, Trilinear.
};
/// Upsamples the axes of `input` located after the two first ones.
///
/// * `input`: tensor of rank 3, 4 or 5.
/// * `opts.mode`: interpolation mode, only `.nearest` exists for now.
/// * `opts.scale_factor`: either a single factor applied to every upsampled axis,
///   or exactly one factor per upsampled axis (`rank - 2` values).
///
/// Panics through `meta.assert` when the rank or the number of scale factors is invalid.
pub fn upsample(
    input: Tensor,
    opts: struct { mode: UpsampleMode, scale_factor: []const f64 },
) Tensor {
    // TODO(james): make `nearest` compatible with resizeBilinear and resizeBicubic, and wrap them here.
    // resize* have API which are more explicit, this assume you want to scale the N-2 last axes.
    meta.assert(3 <= input.rank() and input.rank() <= 5, " upsample is only implemented for (3,4,5)-D tensors, received {} ", .{input});

    const n_spatial = input.rank() - 2;
    meta.assert(
        opts.scale_factor.len == 1 or opts.scale_factor.len == n_spatial,
        "upsample expects either 1 or {} scale factors for {}, received {}",
        .{ n_spatial, input, opts.scale_factor.len },
    );

    switch (opts.mode) {
        .nearest => {
            // At most 3 upsampled axes, since the rank is at most 5.
            var scale_factors: [3]f64 = undefined;
            const used = scale_factors[0..n_spatial];
            if (opts.scale_factor.len == 1) {
                // A single factor applies to every upsampled axis.
                @memset(used, opts.scale_factor[0]);
            } else {
                @memcpy(used, opts.scale_factor);
            }
            return nearest(input, used);
        },
    }
}
/// Nearest neighbor upsampling.
///
/// `scale_factor[i]` applies to axis `i + 2` of `input`: the new length of that axis is
/// `floor(len * scale_factor[i])`. Axes whose length is unchanged are left untouched.
/// For the other ones, output position `j` reads the input at `floor((j + 0.5) * len / new_len)`.
///
/// Panics through `meta.assert` when there are more scale factors than axes after the two first ones.
pub fn nearest(input: Tensor, scale_factor: []const f64) Tensor {
    // TODO(james): remove this implicit two batching dims
    meta.assert(
        scale_factor.len + 2 <= input.rank(),
        "nearest({}) expects at most (rank - 2) scale factors, received {}",
        .{ input, scale_factor.len },
    );

    var out_shape = input.shape();
    for (scale_factor, 0..) |sf, i| {
        const in_len: f64 = @floatFromInt(out_shape.dim(i + 2));
        out_shape._dims.set(i + 2, @intFromFloat(@floor(in_len * sf)));
    }

    // Resize one axis at a time. Iterating directly over the axes avoids
    // any fixed-size scratch buffer, so there is no implicit limit on the rank.
    var res = input;
    for (2..input.rank()) |d| {
        const n = out_shape.dim(d);
        if (input.dim(d) == n) continue;

        const ratio = meta.divFloat(f32, input.dim(d), n);
        const offsets = Tensor.arange(.{ .end = n }, .f32).addConstant(0.5).scale(ratio).floor().convert(.i32);
        res = res.gatherValues(d, offsets, .{ .indices_are_sorted = true });
    }
    return res;
}
test nearest {
    const platform = zml.testing.env();

    // 3D tensor, one axis scaled by 3.
    {
        const src = try zml.Buffer.fromArray(platform, [1][1][2]i32{.{.{ 1, 2 }}});
        const out = try zml.testing.compileAndCall(platform, upsample, .{ src, .{ .scale_factor = &.{3}, .mode = .nearest } });
        try std.testing.expectEqualSlices(i64, &.{ 1, 1, 6 }, out.dims());

        const want: [1][1][6]i32 = .{.{.{ 1, 1, 1, 2, 2, 2 }}};
        try zml.testing.expectClose(zml.HostBuffer.fromArray(&want), out, 0);
    }

    // 3D tensor with several batches and channels, one axis scaled by 2.
    {
        const src = try zml.Buffer.fromArray(platform, [2][3][4]i32{
            .{ .{ 1, 2, 3, 4 }, .{ 5, 6, 7, 8 }, .{ 9, 10, 11, 12 } },
            .{ .{ 13, 14, 15, 16 }, .{ 17, 18, 19, 20 }, .{ 21, 22, 23, 24 } },
        });
        const out = try zml.testing.compileAndCall(platform, upsample, .{ src, .{ .scale_factor = &.{2}, .mode = .nearest } });
        try std.testing.expectEqualSlices(i64, &.{ 2, 3, 8 }, out.dims());

        const want: [2][3][8]i32 = .{
            .{
                .{ 1, 1, 2, 2, 3, 3, 4, 4 },
                .{ 5, 5, 6, 6, 7, 7, 8, 8 },
                .{ 9, 9, 10, 10, 11, 11, 12, 12 },
            },
            .{
                .{ 13, 13, 14, 14, 15, 15, 16, 16 },
                .{ 17, 17, 18, 18, 19, 19, 20, 20 },
                .{ 21, 21, 22, 22, 23, 23, 24, 24 },
            },
        };
        try zml.testing.expectClose(zml.HostBuffer.fromArray(&want), out, 0);
    }

    // 4D tensor, two axes scaled by 3.
    {
        const src = try zml.Buffer.fromSlice(platform, .{ 1, 1, 2, 2 }, &[_]i32{ 1, 2, 3, 4 });
        const out = try zml.testing.compileAndCall(platform, upsample, .{ src, .{ .scale_factor = &.{ 3, 3 }, .mode = .nearest } });
        try std.testing.expectEqualSlices(i64, &.{ 1, 1, 6, 6 }, out.dims());

        const want: [1][1][6][6]i32 = .{.{.{
            .{ 1, 1, 1, 2, 2, 2 },
            .{ 1, 1, 1, 2, 2, 2 },
            .{ 1, 1, 1, 2, 2, 2 },
            .{ 3, 3, 3, 4, 4, 4 },
            .{ 3, 3, 3, 4, 4, 4 },
            .{ 3, 3, 3, 4, 4, 4 },
        }}};
        try std.testing.expectEqual(want, out.getValue([1][1][6][6]i32));
    }

    // 4D tensor with several batches and channels, two axes scaled by 2.
    {
        const src = try zml.Buffer.fromArray(platform, [2][2][2][2]i32{
            .{
                .{ .{ 1, 2 }, .{ 3, 4 } },
                .{ .{ 5, 6 }, .{ 7, 8 } },
            },
            .{
                .{ .{ 9, 10 }, .{ 11, 12 } },
                .{ .{ 13, 14 }, .{ 15, 16 } },
            },
        });
        const out = try zml.testing.compileAndCall(platform, upsample, .{ src, .{ .scale_factor = &.{ 2, 2 }, .mode = .nearest } });
        try std.testing.expectEqualSlices(i64, &.{ 2, 2, 4, 4 }, out.dims());

        const want: [2][2][4][4]i32 = .{
            .{
                .{
                    .{ 1, 1, 2, 2 },
                    .{ 1, 1, 2, 2 },
                    .{ 3, 3, 4, 4 },
                    .{ 3, 3, 4, 4 },
                },
                .{
                    .{ 5, 5, 6, 6 },
                    .{ 5, 5, 6, 6 },
                    .{ 7, 7, 8, 8 },
                    .{ 7, 7, 8, 8 },
                },
            },
            .{
                .{
                    .{ 9, 9, 10, 10 },
                    .{ 9, 9, 10, 10 },
                    .{ 11, 11, 12, 12 },
                    .{ 11, 11, 12, 12 },
                },
                .{
                    .{ 13, 13, 14, 14 },
                    .{ 13, 13, 14, 14 },
                    .{ 15, 15, 16, 16 },
                    .{ 15, 15, 16, 16 },
                },
            },
        };
        try zml.testing.expectClose(zml.HostBuffer.fromArray(&want), out, 0);
    }

    // 5D tensor, a single factor of 2 applied to the three last axes.
    {
        const src = try zml.Buffer.fromSlice(platform, .{ 1, 1, 1, 2, 2 }, &[_]i32{ 1, 2, 3, 4 });
        const out = try zml.testing.compileAndCall(platform, upsample, .{ src, .{ .scale_factor = &.{2}, .mode = .nearest } });
        try std.testing.expectEqualSlices(i64, &.{ 1, 1, 2, 4, 4 }, out.dims());

        const want: [1][1][2][4][4]i32 = .{.{.{
            .{
                .{ 1, 1, 2, 2 },
                .{ 1, 1, 2, 2 },
                .{ 3, 3, 4, 4 },
                .{ 3, 3, 4, 4 },
            },
            .{
                .{ 1, 1, 2, 2 },
                .{ 1, 1, 2, 2 },
                .{ 3, 3, 4, 4 },
                .{ 3, 3, 4, 4 },
            },
        }}};
        try zml.testing.expectClose(zml.HostBuffer.fromArray(&want), out, 0);
    }
}
2023-01-27 14:35:11 +00:00
/// Options shared by the resize functions.
pub const ResizeOpts = struct {
    /// Tensor holding the original dimension of the image.
    /// It can differ from the image shape when the image has been padded,
    /// which allows compiling a single module for several input image sizes.
    /// When null, the dimension is read from the image shape.
    original_len: ?Tensor = null,

    /// Dtype used internally for the interpolation.
    /// The result always has the dtype of the input image.
    /// When null, the image dtype is used, or f32 if the image has an integer dtype.
    precision: ?zml.DataType = null,
};
2023-01-02 14:28:25 +00:00
2023-01-27 14:35:11 +00:00
/// Resizes `image` with linear interpolation, one axis after the other.
///
/// * `resized_axes`: struct mapping each axis tag to resize to its new length,
///   e.g. `.{ .a = 20, .b = 5 }`. Other axes are untouched.
/// * `opt.original_len`: when set, the original length of each axis is picked from it
///   with `choose1d(0, axis)`.
/// * `opt.precision`: forwarded as is to every 1d resize.
///
/// The result has the dtype of `image`.
pub fn resizeBilinear(image: Tensor, resized_axes: anytype, opt: ResizeOpts) Tensor {
    const new_size, const tags_ = Shape.parseStruct(u63, resized_axes);
    var out = image;
    for (new_size.constSlice(), tags_.constSlice()) |d, t| {
        const ax = image.shape().axis(t);
        const child_opt: ResizeOpts = .{
            .original_len = if (opt.original_len) |o| o.choose1d(0, ax) else null,
            // Keep the precision requested by the caller instead of silently
            // falling back to the default one.
            .precision = opt.precision,
        };
        out = resizeLinear1d(out, ax, d, child_opt);
    }
    return out;
}
2023-01-27 14:35:11 +00:00
test resizeBilinear {
    const platform = zml.testing.env();

    // Only the shapes are checked.
    var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, " test ", platform);
    defer comp.deinit();
    comp.activate();
    defer comp.deactivate();

    // Each case is: input shape, requested resizing, expected output shape.
    const cases = .{
        .{ .{ .a = 10, .b = 10 }, .{ .a = 20 }, .{ .a = 20, .b = 10 } },
        .{ .{ .a = 10, .b = 10 }, .{ .b = 5 }, .{ .a = 10, .b = 5 } },
        .{ .{ .a = 10, .b = 10 }, .{ .a = 20, .b = 5 }, .{ .a = 20, .b = 5 } },
    };
    inline for (cases) |case| {
        const x = Tensor.constant(case[0], .{ .f16 = 0 });
        const y = resizeBilinear(x, case[1], .{});

        try zml.testing.expectEqualShapes(Shape.init(case[2], .f16), y.shape());
        try std.testing.expect(y.value().owner().verify());
    }
}
2023-01-02 14:28:25 +00:00
/// Resizes `image` along `axis` to `new_len` using linear interpolation.
///
/// Output position `i` maps to the source position `i * original_len / new_len`;
/// the result is the average of the two surrounding source values,
/// weighted by their distance to that position.
/// The interpolation uses `opt.precision` (by default the image dtype, or f32 for integer images),
/// and the result is converted back to the dtype of `image` and keeps its tags.
///
/// NOTE(review): for the last output positions the upper neighbor index can be equal to the
/// source length; the result then depends on how `gatherValues` handles such indices — confirm.
pub fn resizeLinear1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Tensor {
    const out_shape = image.shape().set(axis, new_len);
    const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype();
    const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype);

    // Source position of every output element.
    const step = og_len.convert(dtype).scale(meta.divFloat(f32, 1, new_len));
    const src_pos = Tensor.arange(.{ .end = new_len }, dtype).mul(step);
    const lower = src_pos.floor();
    const upper = lower.addConstant(1);

    // TODO: check that two gather isn't too bad perf wise.
    // Normally we should use gatherSlices to collect the values 2 by 2,
    // but gatherSlices messes up with the order of axes.
    const lower_val = image.gatherValues(axis, lower.convert(.i32), .{ .indices_are_sorted = true }).convert(dtype);
    const upper_val = image.gatherValues(axis, upper.convert(.i32), .{ .indices_are_sorted = true }).convert(dtype);

    // Each neighbor is weighted by the distance to the opposite neighbor.
    const lower_weight = upper.sub(src_pos).broadcast(out_shape, &.{axis});
    const upper_weight = src_pos.sub(lower).broadcast(out_shape, &.{axis});

    const blended = lower_val.mul(lower_weight).add(upper_val.mul(upper_weight));
    return blended.convert(image.dtype()).withTags(image.shape().tags());
}
/// Bicubic interpolation of the given image, done one axis after the other.
/// Warning as of May 2024 the cpu backend don't optimize this very well
/// and is not able to merge the weighting with the gather,
/// leading to 20x slow down compared to STB implementation.
///
/// * `resized_axes`: struct mapping each axis tag to resize to its new length,
///   e.g. `.{ .a = 20, .b = 5 }`. Other axes are untouched.
/// * `opt.original_len`: when set, the original length of each axis is picked from it
///   with `choose1d(0, axis)`.
/// * `opt.precision`: forwarded as is to every 1d resize.
///
/// The result has the dtype of `image`.
pub fn resizeBicubic(image: Tensor, resized_axes: anytype, opt: ResizeOpts) Tensor {
    const new_size, const tags_ = Shape.parseStruct(u63, resized_axes);
    var out = image;
    for (new_size.constSlice(), tags_.constSlice()) |d, t| {
        const ax = image.shape().axis(t);
        const child_opt: ResizeOpts = .{
            .original_len = if (opt.original_len) |o| o.choose1d(0, ax) else null,
            // Keep the precision requested by the caller instead of silently
            // falling back to the default one.
            .precision = opt.precision,
        };
        out = resizeCubic1d(out, ax, d, child_opt);
    }
    return out;
}
2023-01-27 14:35:11 +00:00
test resizeBicubic {
    const platform = zml.testing.env();

    // Only the shapes are checked.
    var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, " test ", platform);
    defer comp.deinit();
    comp.activate();
    defer comp.deactivate();

    // Each case is: input shape, requested resizing, expected output shape.
    const cases = .{
        .{ .{ .a = 10, .b = 10 }, .{ .a = 20 }, .{ .a = 20, .b = 10 } },
        .{ .{ .a = 10, .b = 10 }, .{ .b = 5 }, .{ .a = 10, .b = 5 } },
        .{ .{ .a = 10, .b = 10 }, .{ .a = 20, .b = 5 }, .{ .a = 20, .b = 5 } },
    };
    inline for (cases) |case| {
        const x = Tensor.constant(case[0], .{ .f16 = 0 });
        const y = resizeBicubic(x, case[1], .{});

        try zml.testing.expectEqualShapes(Shape.init(case[2], .f16), y.shape());
        try std.testing.expect(y.value().owner().verify());
    }
}
2023-01-02 14:28:25 +00:00
/// Resizes `image` along `axis` to `new_len` using cubic interpolation.
///
/// Output position `i` maps to the source position `s = i * original_len / new_len`.
/// The 4 source values starting at `max(floor(s) - 1, 0)` are combined with the
/// Catmull-Rom coefficients and the powers `1, t, t^2, t^3` of `t = s - floor(s)`.
/// The interpolation uses `opt.precision` (by default the image dtype, or f32 for integer images),
/// and the result is converted back to the dtype of `image`.
pub fn resizeCubic1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Tensor {
    const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype();
    const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype);

    // Source position of every output element, and its fractional part.
    const step = og_len.convert(dtype).scale(meta.divFloat(f32, 1, new_len));
    const src_pos = Tensor.arange(.{ .end = new_len }, dtype).mul(step);
    const t = src_pos.sub(src_pos.floor());

    // Powers of t, stacked on a new last axis.
    const powers = Tensor.stack(&.{
        Tensor.constant(t.shape(), dtype.one()),
        t,
        t.mul(t),
        t.pow(Tensor.scalar(3, dtype)),
    }, .last, ._interpolated);
    std.debug.assert(powers.dim(0) == new_len);
    std.debug.assert(powers.dim(1) == 4);

    // Extract the 4 neighboring pixels of every output element.
    const first_neighbor = src_pos.floor().addConstant(-1).convert(.i32).maximum(Tensor.scalar(0, .i32));
    const neighborhood = image.renameAxis(axis, ._neighbors).gatherSlices(
        Shape.init(.{ ._neighbors = 4 }, image.dtype()),
        first_neighbor.appendAxes(.{.coord}),
        .{ .indices_are_sorted = true },
    ).convert(dtype);

    // One row per power of t, one column per neighbor.
    const coefficients: [4][4]f32 = .{
        .{ 0, 1, 0, 0 },
        .{ -0.5, 0, 0.5, 0 },
        .{ 1, -2.5, 2, -0.5 },
        .{ -0.5, 1.5, -1.5, 0.5 },
    };
    const weights = zml.Tensor.constantTensor(zml.HostBuffer.fromArray(&coefficients)).convert(dtype).withTags(.{ ._interpolated, ._neighbors });

    // Actual interpolation.
    // Note: ideally this matmul should be inlined with the gather, but that's currently not the case.
    // TODO: not being able to use dot here is a bit annoying.
    const per_power = neighborhood.dotGeneral(
        weights,
        &.{.{ neighborhood.axis(._neighbors), weights.axis(._neighbors) }},
        &.{},
    );
    var res = powers.dotGeneral(
        per_power,
        &.{.{ powers.axis(._interpolated), per_power.axis(._interpolated) }},
        &.{.{ 0, 0 }},
    );

    // The resized axis comes out in first position because it is a batching dim, put it back in place.
    if (axis != 0) res = res.swapAxes(0, axis);

    // Sanity check of the output dimensions.
    const expected_shape = image.shape().set(axis, new_len);
    std.debug.assert(std.mem.eql(i64, expected_shape.dims(), res.dims()));

    return res.convert(image.dtype()).withTags(image.shape());
}
/// Returns a causal attention mask of the given shape.
/// `attn_shape_` must have exactly 2 axes: the queries, then the keys.
///
/// Key `k` is visible from query `q` when `k <= q`.
/// When `attn_window_len` is set, and the number of queries or keys reaches it,
/// the key must also satisfy `q < k + attn_window_len`.
///
/// For a float `dtype` the mask holds `dtype.zero()` at visible positions and `dtype.minValue()`
/// elsewhere; for the other dtypes the boolean mask is converted to `dtype`.
pub fn causalAttnMask(
    attn_shape_: anytype,
    dtype: DataType,
    attn_window_len: ?u32,
) Tensor {
    const attn_shape = Shape.init(attn_shape_, dtype);
    meta.assert(attn_shape.rank() == 2, " causalAttnMask({}) shape need to be exactly 2 axes ", .{attn_shape});

    const qlen = attn_shape.dim(-2);
    const q_idx = Tensor.iota(attn_shape, .i32, -2);
    const klen = attn_shape.dim(-1);
    const k_idx = Tensor.iota(attn_shape, .i32, -1);

    // Everything above the main diagonal is masked out:
    // (q_idx - window_len < k_idx <= q_idx)
    var visible = k_idx.cmp(.LE, q_idx);
    if (attn_window_len) |window_len| {
        const window_applies = qlen >= window_len or klen >= window_len;
        if (window_applies) {
            const in_window = q_idx.cmp(.LT, k_idx.addConstant(window_len));
            visible = visible.logical(.AND, in_window);
        }
    }

    if (!dtype.isFloat()) return visible.convert(dtype);

    const zeros = Tensor.constant(visible.shape(), dtype.zero());
    const minus_inf = Tensor.constant(visible.shape(), dtype.minValue());
    return visible.select(zeros, minus_inf);
}
/// Options of `sdpa`.
pub const SdpaOpts = struct {
    /// Optional mask, added to the attention weights.
    attn_mask: ?Tensor = null,
    /// Optional scale applied to the keys.
    /// When null, `1 / sqrt(head_dim)` is used.
    scale: ?Tensor = null,
    /// Optional bias tensor.
    // NOTE(review): how it is applied is not documented here — confirm.
    bias: ?Tensor = null,
    /// Allows dispatching to the cuDNN implementation when it is usable.
    allow_cudnn: bool = true,
    // TODO: consider replacing these fields with a callback.
};
/// Scaled dot product attention.
///
/// **Shapes**:
///   - q, result: .{ .h, .q, .hd }
///   - k, v:      .{ .h, .k, .hd }
///
/// Where:
///   - .h is the number of head
///   - .q is the number of queries
///   - .k is the number of keys
///   - .hd is the head dimension
///
/// .h is allowed to differ from queries and keys as long as the key heads
/// can be repeated to match query heads.
///
/// **Options** (see `SdpaOpts`):
///   - `attn_mask` is broadcast to the attention weights and added to them.
///   - `scale` multiplies the keys; it defaults to `1 / sqrt(hd)`.
///   - `bias` is added to the f32 attention weights, right before the softmax.
///   - `allow_cudnn` allows returning `cuda.sdpa(q, k, v, opts)` directly when
///     `cuda.canUseCudnnSdpa` accepts the head dimension and dtype of `q`.
///
/// Panics (through `meta.assert` / `meta.panic`) when tags are missing, when the
/// key heads can't be repeated to match the query heads, or when the shapes of
/// the inputs are incompatible.
pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
    // Only keys and values are rebound below; queries are used as given.
    const q = q_;
    var k, var v = .{ k_, v_ };

    const err_template = " sdpa(q: {}, k: {}, v: {}, attn: {?}) is invalid ! ";
    const err_args = .{ q, k, v, opts.attn_mask };

    meta.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ " q is missing tags {{.h, .q, .hd}} ", err_args);
    meta.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ " k is missing tags {{.h, .k, .hd}} ", err_args);
    meta.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ " v is missing tags {{.h, .k, .hd}} ", err_args);

    if (opts.allow_cudnn and cuda.canUseCudnnSdpa(q.dim(.hd), q.dtype())) {
        return cuda.sdpa(q, k, v, opts);
    }

    if (q.dim(.h) != k.dim(.h)) {
        meta.assert(@mod(q.dim(.h), k.dim(.h)) == 0, err_template ++ " Different number of heads for keys and queries, but can't repeat keys. ", err_args);
        // Note: we don't try to repeat queries.
        // Repeating keys is the interesting optimisation cause it reduces KV cache memory usage.
        const num_rep: u63 = @intCast(@divExact(q.dim(.h), k.dim(.h)));
        k, v = .{ k.repeat1d(.h, num_rep), v.repeat1d(.h, num_rep) };
    }

    const attn_mask = opts.attn_mask;
    const dims = helpers.collectDims(.{ .h, .q, .k, .hd }, &.{ q, k, v, attn_mask }, .strict) catch {
        meta.panic(err_template ++ " Inputs have incompatible shapes. ", err_args);
    };

    // Default scaling of the logits: 1 / sqrt(head_dim).
    const inv_sqrt_head_dim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(dims.hd)));
    const scale_logit = if (opts.scale) |s| s else Tensor.scalar(inv_sqrt_head_dim, k.dtype());
    // The scale is folded into the keys rather than applied to the (larger) attention weights.
    k = k.mul(scale_logit.convert(k.dtype()));

    var attn_weights = q.dot(k, .{.hd});
    // log.debug("attn_weights : {}", .{attn_weights});
    // log.debug("attn_mask : {?}", .{attn_mask});
    if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broadcastLeft(attn_weights.shape()));

    // Bias and softmax are computed in f32, then converted back to the dtype of the queries.
    attn_weights = attn_weights.convert(.f32);
    if (opts.bias) |bias| {
        attn_weights = attn_weights.add(bias);
    }
    attn_weights = attn_weights.softmax(.k).convert(q.dtype());

    const attn = attn_weights.dot(v, .{.k});
    // Put the axes back in the order used by `q`.
    return attn.transpose(q.shape());
}
/// Options of `sdpaMemEfficient`.
pub const MemEfficientOps = struct {
    /// NOTE(review): `sdpaMemEfficient` forwards this value, but the chunked
    /// implementation in this file never reads it: `sdpaChunk` always scales
    /// by `1 / sqrt(head_dim)`.
    scale: ?f32 = null,
    /// Number of queries handled per chunk; `sdpaMemEfficient` clamps it to the `.sq` dimension.
    query_chunk_size: u32,
    /// Number of keys handled per chunk; `sdpaMemEfficient` clamps it to the `.sk` dimension.
    key_chunk_size: u32,
    /// Regular sdpa options. Only `attn_mask` reaches `sdpaChunk` in this file.
    opts: SdpaOpts = .{},
};
/// Attention computed chunk by chunk over queries and keys.
///
/// Inputs are tagged positionally: `q_` as `.{ .b, .hq, .sq, .hd }`, `k_` and `v_`
/// as `.{ .b, .hk, .sk, .hd }`, and the optional attention mask as `.{ .sq, .sk }`.
/// Chunk sizes are clamped to the `.sq` / `.sk` dimensions.
/// The returned tensor is given the shape of `q_`.
pub fn sdpaMemEfficient(q_: Tensor, k_: Tensor, v_: Tensor, opts: MemEfficientOps) Tensor {
    const queries = q_.withTags(.{ .b, .hq, .sq, .hd });
    const keys = k_.withTags(.{ .b, .hk, .sk, .hd });
    const values = v_.withTags(.{ .b, .hk, .sk, .hd });

    // Start from the caller options, then patch what the chunked implementation needs.
    var chunked_opts: MemEfficientOps = opts;
    if (opts.opts.attn_mask) |mask| {
        chunked_opts.opts.attn_mask = mask.withTags(.{ .sq, .sk });
    }
    chunked_opts.query_chunk_size = @intCast(@min(queries.dim(.sq), opts.query_chunk_size));
    chunked_opts.key_chunk_size = @intCast(@min(keys.dim(.sk), opts.key_chunk_size));

    const impl: SdpaMemEfficient = .{
        .q = queries,
        .k = keys,
        .v = values,
        .opt = chunked_opts,
    };

    // TODO(Corentin): Maybe `withTags` could take a Shape to copy from.
    var output = impl.forward();
    output._shape = q_.shape();
    return output;
}
/// Chunked attention implementation behind `sdpaMemEfficient`.
///
/// `sdpaMemEfficient` tags `q` as `.{ .b, .hq, .sq, .hd }`, `k` and `v` as
/// `.{ .b, .hk, .sk, .hd }` and the optional mask as `.{ .sq, .sk }`.
const SdpaMemEfficient = struct {
    q: Tensor,
    k: Tensor,
    v: Tensor,
    opt: MemEfficientOps,

    /// Loop over chunks of queries and merge the per-chunk results.
    /// `@divExact` requires `.sq` to be a multiple of `opt.query_chunk_size`.
    fn forward(self: SdpaMemEfficient) Tensor {
        const chunk_count = @divExact(self.q.dim(.sq), self.opt.query_chunk_size);
        const per_chunk = ops.for_(SdpaMemEfficient.nextQueriesChunk, self, .{ .nq = chunk_count });
        // TODO: should "for_" operate on an axis ?
        // return res.transpose(.{ .b, .hq, .nq, .sq, .hd }).merge(.{ .nq, .sq }, .sq);
        // per_chunk: (nq, b, nh, qlen / nq, dim) -> (b, nh, qlen, dim)
        const chunk_axis_moved = per_chunk.transpose(.{ 1, 2, 0, 3, 4 });
        return chunk_axis_moved.flatten(2);
    }

    /// Body of the loop over queries: slice the `idx`-th chunk of queries
    /// (and the matching rows of the mask), then scan all the keys for it.
    fn nextQueriesChunk(self: SdpaMemEfficient, idx: Tensor) Tensor {
        const chunk_len = self.opt.query_chunk_size;
        const start = idx.scale(chunk_len);

        var sub_problem: SdpaMemEfficient = self;
        sub_problem.q = self.q.dynamicSlice(.{ .sq = .{ .start = start, .len = chunk_len } });
        // Axis 0 of the mask is `.sq`. Without a mask the copied field is already null.
        if (self.opt.opts.attn_mask) |mask| {
            sub_problem.opt.opts.attn_mask = mask.dynamicSlice1d(0, chunk_len, start);
        }
        return sub_problem.scanKeyVal();
    }

    /// Loop over chunks of keys/values for the current queries and combine the
    /// partial softmax results of every chunk into the final attention.
    /// `@divExact` requires `.sk` to be a multiple of `opt.key_chunk_size`.
    fn scanKeyVal(self: SdpaMemEfficient) Tensor {
        const chunk_count = @divExact(self.k.dim(.sk), self.opt.key_chunk_size);
        const partials = ops.for_(SdpaMemEfficient.nextKeyValChunk, self, .{ .k_chunk = chunk_count });

        // Each chunk was normalized with its own max: rescale all of them to the max over chunks.
        const overall_max = partials.max_value.max(.k_chunk).broad(partials.max_value.shape());
        const rescale = partials.max_value.sub(overall_max).exp();

        const weighted_attn = partials.attn.mul(rescale.broad(partials.attn.shape()));
        const numerator = weighted_attn.sum(.k_chunk).squeeze(.k_chunk);

        // Sums of exponentials are accumulated in f32, then converted to the attention dtype.
        const weighted_sums = partials.exp_sum.mul(rescale.convert(.f32));
        const denominator = weighted_sums.sum(.k_chunk).squeeze(.k_chunk).convert(numerator.dtype());

        return numerator.div(denominator.broad(self.q.shape()));
    }

    /// Body of the loop over keys: slice the `idx`-th chunk of keys and values
    /// (and the matching columns of the mask), then run `sdpaChunk` on it.
    /// Only the mask is forwarded to `sdpaChunk`.
    fn nextKeyValChunk(self: SdpaMemEfficient, idx: Tensor) PartialAttn {
        const chunk_len = self.opt.key_chunk_size;
        const start = idx.scale(chunk_len);

        const keys = self.k.dynamicSlice(.{ .sk = .{ .start = start, .len = chunk_len } });
        const values = self.v.dynamicSlice(.{ .sk = .{ .start = start, .len = chunk_len } });
        // Axis 1 of the mask is `.sk`.
        const mask_chunk = if (self.opt.opts.attn_mask) |mask|
            mask.dynamicSlice1d(1, chunk_len, start)
        else
            null;

        return sdpaChunk(self.q, keys, values, .{ .attn_mask = mask_chunk });
    }
};
/// Intermediate results of a softmax computed on a chunk of the input,
/// as produced by `partialSoftmax` and `sdpaChunk`.
pub const PartialAttn = struct {
    /// `exp(x - max)` in `partialSoftmax`; `sdpaChunk` stores these weights already multiplied with the values.
    attn: Tensor,
    /// Sum of `exp(x - max)` over the reduced axis, accumulated in f32.
    exp_sum: Tensor,
    /// Maximum over the reduced axis, used as the reference of the exponentials.
    max_value: Tensor,
};
/// Compute the un-normalized softmax of `self` along `axis`.
///
/// The result is not divided by the sum of exponentials. Instead the sum
/// (in f32) and the max used as reference are returned next to the
/// exponentials, so that several chunks can be aggregated later.
pub fn partialSoftmax(self: Tensor, axis: anytype) PartialAttn {
    const reduced_axis = self.axis(axis);

    // Subtract the max before the exponential.
    const chunk_max = self.max(reduced_axis);
    const centered = self.sub(chunk_max.broad(self.shape()));
    const exponentials = centered.exp();

    const sum_f32 = exponentials.convert(.f32).sum(reduced_axis);
    return .{
        .attn = exponentials,
        .exp_sum = sum_f32.squeeze(reduced_axis),
        .max_value = chunk_max.squeeze(reduced_axis),
    };
}
/// Compute sdpa on a chunk, and computes a partial softmax.
/// q: (B, H, Sq, H_dim) ⊙ k: (B, H, Sk, H_dim) -> qk: (B, H, Sq, Sk)
///
/// `q` must carry the tags `.hq` and `.hd`; `k` and `v` the tags `.hk`, `.sk` and `.hd`.
/// When `q` has more heads than `k`, keys and values are repeated along `.hk`
/// (`@divExact` requires the number of query heads to be a multiple of the key heads).
///
/// Only `opts.attn_mask` is read here: the logits are always scaled by
/// `1 / sqrt(hd)`, and `opts.scale`, `opts.bias` and `opts.allow_cudnn` are ignored.
///
/// Returns the un-normalized attention for this chunk, together with the sum of
/// exponentials and the max needed to merge it with other chunks.
fn sdpaChunk(q: Tensor, k: Tensor, v: Tensor, opts: SdpaOpts) PartialAttn {
    // const bs, const num_head, const sk, const h_dim = q.dims[0..4];
    // TODO: rewrite using modern ZML

    // If we have more query heads (hq) than key heads (hk), repeat keys.
    // The repetition must happen on the head axis `.hk`: axis 0 is the batch axis
    // for the `.{ .b, .hk, .sk, .hd }` tensors produced by `sdpaMemEfficient`.
    const k_rep, const v_rep = if (q.dim(.hq) != k.dim(.hk)) blk: {
        const num_rep: u63 = @intCast(@divExact(q.dim(.hq), k.dim(.hk)));
        break :blk .{
            k.repeat1d(.hk, num_rep).rename(.{ .hk = .hq }),
            v.repeat1d(.hk, num_rep).rename(.{ .hk = .hq }),
        };
    } else .{ k.rename(.{ .hk = .hq }), v.rename(.{ .hk = .hq }) };

    var qk = q.dot(k_rep, .{.hd});

    const inv_sqrt_head_dim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(q.dim(.hd))));
    qk = qk.scale(inv_sqrt_head_dim);

    std.debug.assert(qk.rank() == q.rank());
    if (opts.attn_mask) |mask| {
        qk = qk.add(mask.broad(qk.shape()));
    }

    // Softmax over the last axis of qk, without the final normalization.
    const partial = partialSoftmax(qk, -1);
    const attn = partial.attn.dot(v_rep, .{.sk});

    return .{
        .attn = attn,
        .exp_sum = partial.exp_sum,
        .max_value = partial.max_value,
    };
}
2023-06-16 14:34:18 +00:00
2023-01-02 14:28:25 +00:00
test " sdpaMemEfficient without mask " {
    const platform = zml.testing.env();
    const allocator = testing.allocator;

    // Inputs are kept small so that the test runs reasonably fast,
    // but don't expect speed ups from chunking with such small sizes.
    const qkv_shape = Shape.init(.{ 1, 10, 512, 64 }, .f32);
    const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ qkv_shape, .{ .mean = 0, .stddev = 1 } }, platform);
    defer rng.deinit();

    // Note: it's fine to pass undefined here, because the arguments have already been baked into the executable.
    const q = rng.call(undefined);
    const k = rng.call(undefined);
    const v = rng.call(undefined);

    // Reference: regular sdpa on the same inputs.
    const expected = try zml.testing.compileAndCallWithTensors(
        platform,
        sdpa,
        .{
            q.shape().withTags(.{ .b, .h, .q, .hd }),
            k.shape().withTags(.{ .b, .h, .k, .hd }),
            v.shape().withTags(.{ .b, .h, .k, .hd }),
            .{ .attn_mask = null, .scale = null, .bias = null },
        },
        .{ q, k, v, undefined },
    );
    try testing.expectEqualSlices(i64, q.shape().dims(), expected.shape().dims());

    const chunked_opts: zml.ShapeOf(MemEfficientOps) = .{
        .query_chunk_size = 256,
        .key_chunk_size = 128,
        .opts = .{ .attn_mask = null, .scale = null, .bias = null },
    };
    const actual = try zml.testing.compileAndCallWithTensors(
        platform,
        sdpaMemEfficient,
        .{ q.shape(), k.shape(), v.shape(), chunked_opts },
        .{ q, k, v, undefined },
    );

    try zml.testing.expectClose(expected, actual, 2e-3);
}
test " sdpaMemEfficient with mask " {
    const platform = zml.testing.env();
    const allocator = testing.allocator;

    // Inputs are kept small so that the test runs reasonably fast,
    // but don't expect speed ups from chunking with such small sizes.
    const qkv_shape = Shape.init(.{ 1, 10, 512, 64 }, .f32);
    const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ qkv_shape, .{ .mean = 0, .stddev = 1 } }, platform);
    defer rng.deinit();

    const mask_shape = Shape.init(.{ 512, 512 }, .f32);
    const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ mask_shape, .{ .mean = 0, .stddev = 1 } }, platform);
    defer rng_mask.deinit();

    // Note: it's fine to pass undefined here, because the arguments have already been baked into the executable.
    const q = rng.call(undefined);
    const k = rng.call(undefined);
    const v = rng.call(undefined);
    const mask = rng_mask.call(undefined);

    // Reference: regular sdpa on the same inputs.
    const expected = try zml.testing.compileAndCall(platform, sdpa, .{
        q.withTags(.{ .b, .h, .q, .hd }),
        k.withTags(.{ .b, .h, .k, .hd }),
        v.withTags(.{ .b, .h, .k, .hd }),
        .{ .attn_mask = mask.withTags(.{ .q, .k }), .scale = null, .bias = null },
    });
    try testing.expectEqualSlices(i64, q.shape().dims(), expected.shape().dims());

    const actual = try zml.testing.compileAndCall(platform, sdpaMemEfficient, .{
        q,
        k,
        v,
        .{
            .query_chunk_size = 256,
            .key_chunk_size = 128,
            .scale = null,
            .opts = .{ .attn_mask = mask, .scale = null, .bias = null, .allow_cudnn = false },
        },
    });

    try zml.testing.expectClose(expected, actual, 2e-3);
}
/// Options controlling generation. The default values correspond to greedy decoding.
pub const SamplingStrategy = struct {
    /// Number of top scoring tokens to sample from.
    /// With `topk <= 1`, `sampleTokens` returns the argmax and doesn't use the rng.
    topk: u32 = 1,
    /// `sampleTokens` scales the top-k activations by `1 / temperature` when it differs from 1.0.
    temperature: f32 = 1.0,
};
/// Pick the next tokens from the output of the last layer of a LM.
///
/// `activations` must have a `.voc` axis. The result is an integer (i32) tensor
/// shaped like the input without the `.voc` axis, returned together with the
/// rng to use for the next call.
///
/// With `opts.topk <= 1` this is a plain argmax and `rng` is returned untouched.
/// Otherwise one token is sampled among the `opts.topk` best scoring ones.
pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng) struct { Tensor, Tensor.Rng } {
    const greedy = opts.topk <= 1;
    if (greedy) {
        return .{ activations.argMax(.voc, .i32).indices.squeeze(.voc), rng };
    }

    const best = activations.topK(opts.topk, .voc, .{});
    // After the topk, the values are no longer indexed by .voc, only by rank in the topk.
    const best_values = best.values.rename(.{ .voc = .topk });
    const scaled = if (opts.temperature == 1.0)
        best_values
    else
        best_values.scale(1 / opts.temperature);

    // Gumbel reparametrization trick:
    // Adding gumbel noise and taking the argmax is equivalent
    // to sampling from the categorical distribution produced by the softmax.
    // https://en.wikipedia.org/wiki/Gumbel_distribution#Gumbel_reparametrization_tricks
    const next_rng, const gumbel_noise = rng.gumbel(scaled.shape());
    const noisy = scaled.add(gumbel_noise);

    // `winner` indexes into best.values, so it is in the range [0, topk].
    // Convert it back to an index of the full [0, voc] range.
    const winner = noisy.argMax(.topk, .i32).indices;
    const next_tokens = best.indices.gatherValues(.voc, winner.squeeze(.topk), .{});
    // log.debug("sampleTokens({}) -> {} -> {} -> {}", .{ activations, topk.indices, topk_idx, next_tokens });
    return .{ next_tokens, next_rng };
}
2023-06-16 14:34:18 +00:00
test sampleTokens {
    const platform = zml.testing.env();
    const allocator = testing.allocator;

    const inf = std.math.inf(f32);

    var rng_buff = try zml.Tensor.Rng.init(platform, 0xdeadbeef);
    defer rng_buff._state.deinit();

    const mod = try zml.compileFn(
        allocator,
        sampleTokens,
        .{ Shape.init(.{ .voc = 4 }, .f32), .{ .topk = 4, .temperature = 2.0 }, zml.Tensor.Rng.shape() },
        platform,
    );
    defer mod.deinit();

    // Each case is: logits, expected token.
    // Infinite logits make the outcome independent from the gumbel noise.
    const cases = .{
        .{ [_]f32{ inf, 3.0, 2.0, 1.0 }, 0 },
        .{ [_]f32{ -inf, 3.0, -inf, -inf }, 1 },
        .{ [_]f32{ 3.0, 2, inf, inf }, 2 },
    };

    inline for (cases) |case| {
        const logits, const expected: i32 = case;

        var logits_buff = try zml.Buffer.fromArray(platform, logits);
        defer logits_buff.deinit();

        // The rng returned by a call is fed to the next one.
        var sampled, rng_buff = mod.call(.{ logits_buff, undefined, rng_buff });
        defer sampled.deinit();

        try zml.testing.expectEqual(expected, try sampled.getValue(i32));
    }
}
/// Sampling options of `sampleTokensDynamic`.
/// Apart from `max_top_k`, every option is a scalar tensor.
pub const DynamicSamplingStrategy = struct {
    /// When > 0, only the `max_top_k` best tokens are extracted with `topK`;
    /// when 0 the full vocabulary is sorted instead.
    max_top_k: u32,
    /// Candidates whose rank is >= `top_k` are masked out.
    top_k: Tensor,
    /// Multiplier applied to the logits of the candidates.
    temperature: Tensor,
    /// Upper bound on the cumulated probability of the accepted candidates.
    top_p: Tensor,
    /// Candidates less likely than `min_p` times the probability of the best one are dropped.
    min_p: Tensor,

    /// Shapes of the options: an i32 scalar for `top_k`, scalars of `dtype` for the other tensors.
    pub fn shapes(dtype: DataType, max_top_k: u32) zml.ShapeOf(DynamicSamplingStrategy) {
        const float_scalar = Shape.init(.{}, dtype);
        return .{
            .max_top_k = max_top_k,
            .top_k = Shape.init(.{}, .i32),
            .temperature = float_scalar,
            .top_p = float_scalar,
            .min_p = float_scalar,
        };
    }

    /// Create the scalar buffers holding the given options on `platform`.
    /// `max_top_k` is set to 0 in the result
    /// (presumably unused once the function is compiled - TODO confirm).
    pub fn makeBuffers(
        platform: zml.Platform,
        dtype: zml.DataType,
        args: struct {
            top_k: u32,
            temperature: f32 = 1.0,
            top_p: f32 = 1.0,
            min_p: f32 = 0.0,
        },
    ) !zml.Bufferized(DynamicSamplingStrategy) {
        const top_k = try zml.Buffer.scalar(platform, args.top_k, .i32);
        const temperature = try zml.Buffer.scalar(platform, args.temperature, dtype);
        const top_p = try zml.Buffer.scalar(platform, args.top_p, dtype);
        const min_p = try zml.Buffer.scalar(platform, args.min_p, dtype);

        return .{
            .max_top_k = 0,
            .top_k = top_k,
            .temperature = temperature,
            .top_p = top_p,
            .min_p = min_p,
        };
    }
};
/// Given the output of the last layer of a LM with a `.voc` axis,
/// compute indices for the next tokens, following the given sampling strategy.
/// The dynamic sampling strategy is more expressive but top_p requires computing the softmax.
///
/// Options are:
///
/// * top_k: only sample among the k top scoring tokens,
/// * max_top_k: limit at compilation time the max possible runtime value for top_k,
///   saving memory and compute by not having to fully sort the tokens.
/// * temperature: the logits of the candidates are multiplied by this value
/// * top_p: only sample among top scoring tokens whose probabilities sum up to top_p
/// * min_p: drop tokens whose probabilities are lower than a ratio of the most likely token
///
/// Returns the sampled tokens and the rng to use for the next call.
pub fn sampleTokensDynamic(logits: Tensor, opts: DynamicSamplingStrategy, rng: Tensor.Rng) struct { Tensor, Tensor.Rng } {
    const filtered, const candidate_ids = fixupLogits(logits, opts);

    // The rest is similar to sampleTokens: add gumbel noise and take the argmax.
    const next_rng, const gumbel_noise = rng.gumbel(filtered.shape());
    const noisy = filtered.add(gumbel_noise);

    // `winner` is a rank among the candidates, map it back to a token index.
    const winner = noisy.argMax(.topk, .i32).indices;
    const next_tokens = candidate_ids.gatherValues(.voc, winner.squeeze(.topk), .{});

    return .{ next_tokens, next_rng };
}
/// Apply the dynamic sampling options to `logits`.
///
/// Returns `.{ values, indices }`:
/// * `values` are the logits sorted in descending order along a `.topk` axis,
///   multiplied by `opts.temperature`, where every rejected candidate is replaced
///   by the min value of the dtype. The first candidate is always kept.
/// * `indices` are the indices returned by the sort / topK of the input.
///
/// Note: unlike `sampleTokens`, which scales by `1 / temperature`,
/// the logits are multiplied by `opts.temperature` here.
fn fixupLogits(logits: Tensor, opts: DynamicSamplingStrategy) [2]Tensor {
    const min_inf = Tensor.constant(.{}, logits.dtype().minValue());

    // First reduce the vocab size to a reasonable subset of candidates.
    const sorted = if (opts.max_top_k > 0)
        logits.topK(opts.max_top_k, .voc, .{ .descending = true })
    else
        logits.sort(.voc, .{ .descending = true });

    // After the topk, the values are no longer indexed by .voc, only by rank in the topk.
    const ranked = sorted.values.rename(.{ .voc = .topk });

    // Mask the candidates ranked at or after the dynamic top_k, then apply the temperature.
    const beyond_top_k = Tensor.iota(ranked.shape(), .topk).cmp(.GE, opts.top_k);
    const scaled = beyond_top_k.select(min_inf, ranked).mul(opts.temperature);

    // If there are high values in the logits, softmax can overflow and will create nans in probs,
    // which propagate to probs_sum and probs_max.
    const probs = scaled.softmax(.topk);
    const probs_sum = probs.cumulativeSum(.topk);
    const probs_max = probs.slice1d(.topk, .{ .start = 0, .end = 1 });

    const top_p = opts.top_p.broad(scaled.shape());
    const min_p = probs_max.mul(opts.min_p).broad(scaled.shape());

    // * if the first candidate has a very high prob, then probs_sum is always greater than top_p
    //   and `candidate` is false everywhere.
    // * if the first candidate score is even bigger, the probs become nan because of the softmax,
    //   then both comparisons are false everywhere, and so is `candidate`.
    const within_top_p = probs_sum.cmp(.LE, top_p);
    const above_min_p = probs.cmp(.GE, min_p);
    const candidate = within_top_p.logical(.AND, above_min_p);

    // * so we explicitly always accept the first candidate.
    const is_first = Tensor.iota(scaled.shape(), .topk).cmp(.EQ, Tensor.scalar(0, .i32));
    const kept = candidate.logical(.OR, is_first).select(scaled, min_inf);

    return .{ kept, sorted.indices };
}
test sampleTokensDynamic {
    const platform = zml.testing.env();
    const allocator = testing.allocator;

    // Short name to keep the table of expected values readable.
    const ___ = -std.math.inf(f32);
    const logits = [_]f32{ @log(2.0), @log(1.0), @log(4.0), @log(3.0) };
    // Indices of `logits` sorted by descending value.
    const top_k_indices = [_]i32{ 2, 3, 0, 1 };
    const logits_buff = try zml.Buffer.fromArray(platform, logits);

    const mod = try zml.compileFn(
        allocator,
        fixupLogits,
        .{ Shape.init(.{ .voc = logits.len }, .f32), DynamicSamplingStrategy.shapes(.f32, 0) },
        platform,
    );
    defer mod.deinit();

    // Each case is: sampling options, expected logits after fixup.
    const cases = .{
        // top_k == logits.len -> just sort the input
        .{ .{ .top_k = 4 }, [_]f32{ @log(4.0), @log(3.0), @log(2.0), @log(1.0) } },
        .{ .{ .top_k = 2 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
        .{ .{ .top_k = 2, .temperature = 0.1 }, [_]f32{ @log(4.0) * 0.1, @log(3.0) * 0.1, ___, ___ } },
        // top_k == logits.len and small top_p -> make sure at least one is returned
        .{ .{ .top_k = 4, .top_p = 0.1 }, [_]f32{ @log(4.0), ___, ___, ___ } },
        .{ .{ .top_k = 4, .top_p = 0.701 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
        .{ .{ .top_k = 4, .top_p = 0.901 }, [_]f32{ @log(4.0), @log(3.0), @log(2.0), ___ } },
        // Here top_p is computed on the top 3 items, so 0.701 isn't enough anymore to allow @log(3.0)
        .{ .{ .top_k = 3, .top_p = 0.701 }, [_]f32{ @log(4.0), ___, ___, ___ } },
        // Here top_p allows the first 3 results, but min_p only accepts the first two.
        .{ .{ .top_k = 4, .top_p = 0.901, .min_p = 0.6 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
    };

    inline for (cases) |case| {
        const args, const expected = case;
        const strategy = try DynamicSamplingStrategy.makeBuffers(platform, .f32, args);
        const new_logits, const indices = mod.call(.{ logits_buff, strategy });
        try testing.expectEqual(top_k_indices, try indices.getValue(@TypeOf(top_k_indices)));
        try zml.testing.expectEqual(expected, try new_logits.getValue(@TypeOf(expected)));
    }

    {
        // Similar, but uses bf16 and infinity to trigger nans after the softmax.
        const bf16 = zml.floats.BFloat16;

        const mod_bf16 = try zml.compileFn(
            allocator,
            fixupLogits,
            .{ Shape.init(.{ .voc = logits.len }, .bf16), DynamicSamplingStrategy.shapes(.bf16, 0) },
            platform,
        );
        defer mod_bf16.deinit();

        const boost = bf16.inf();
        const nerf = bf16.minusInf();
        const logits_buff_2 = try zml.Buffer.fromArray(platform, [4]bf16{ boost, boost, bf16.fromF32(2), nerf });
        const strategy = try DynamicSamplingStrategy.makeBuffers(platform, .bf16, .{ .top_k = 4, .top_p = 0.9, .min_p = 0.1 });

        const new_logits, const indices = mod_bf16.call(.{ logits_buff_2, strategy });
        try testing.expectEqual([_]i32{ 0, 1, 2, 3 }, try indices.getValue([4]i32));
        try zml.testing.expectEqual([_]bf16{ boost, nerf, nerf, nerf }, try new_logits.getValue([4]bf16));
    }
}