2023-01-02 14:28:25 +00:00
//! Common layer definition and functions for Neural Networks (NN)
const std = @import ( " std " ) ;
2024-10-07 12:53:03 +00:00
const assert = std . debug . assert ;
const testing = std . testing ;
2023-06-21 14:45:14 +00:00
const stdx = @import ( " stdx " ) ;
2023-01-02 14:28:25 +00:00
2024-10-07 12:53:03 +00:00
const DataType = @import ( " dtype.zig " ) . DataType ;
2023-01-02 14:28:25 +00:00
const helpers = @import ( " helpers.zig " ) ;
2023-06-21 14:45:14 +00:00
const meta = @import ( " meta.zig " ) ;
2024-10-07 12:53:03 +00:00
const cuda = @import ( " nn/cuda.zig " ) ;
2023-01-02 14:28:25 +00:00
const ops = @import ( " ops.zig " ) ;
const Shape = @import ( " shape.zig " ) . Shape ;
const Tensor = @import ( " tensor.zig " ) . Tensor ;
2024-10-07 12:53:03 +00:00
const zml = @import ( " zml.zig " ) ;
2023-01-02 14:28:25 +00:00
2023-06-21 14:45:14 +00:00
const log = std . log . scoped ( . @ " zml/tensor " ) ;
2023-01-27 14:35:11 +00:00
test {
    // Reference `cuda` so its declarations (and tests) are semantically checked,
    // then pull in every test declared in this file.
    _ = cuda;
    std.testing.refAllDecls(@This());
}
2023-01-02 14:28:25 +00:00
/// Fully-connected layer: y = x · weight (+ bias).
pub const Linear = struct {
    weight: Tensor,
    bias: ?Tensor = null,

    /// Contracts the last axis of `x` with the last axis of `self.weight`
    /// (the weight is converted to the dtype of `x` first),
    /// then adds the broadcasted bias when present.
    pub fn forward(self: Linear, x: Tensor) Tensor {
        var y = x.dotGeneral(self.weight.convert(x.dtype()), &.{.{ -1, -1 }}, &.{});
        // If self.weight doesn't have tags, preserve tags from x.
        if (y.shape().tag(-1) == Shape.TagUnknown) {
            y._shape._tags.set(y.rank() - 1, x.shape().tag(-1));
        }
        return if (self.bias) |bias| y.add(bias.broadcastLeft(y.shape())) else y;
    }
};
/// Lookup-table embedding: maps integer token ids to rows of `weight`.
pub const TokenEmbedding = struct {
    weight: Tensor,

    /// Gathers `weight[idx]` along axis 0.
    /// `idx` must be an integer tensor; `weight` must be 2D (vocab, embedding dim).
    pub fn forward(self: TokenEmbedding, idx: Tensor) Tensor {
        stdx.debug.assert(idx.dtype().isInteger(), "TokenEmbedding expects an integer input, received: {f}", .{idx});
        stdx.debug.assert(self.weight.rank() == 2, "TokenEmbedding expects its weight Tensor to be a 2D matrix, got {f}", .{self.weight});
        return self.weight.gatherValues(0, idx, .{});
    }
};
/// Activation functions commonly used in neural networks.
/// Payload-carrying variants store their hyper-parameter
/// (negative slope for leakyReLU, alpha for elu).
pub const Activation = union(enum) {
    sigmoid,
    tanh,
    relu,
    leakyReLU: f32,
    elu: f32,
    silu,
    gelu,
    quick_gelu,

    /// Applies the selected activation element-wise to `x`.
    pub fn forward(self: Activation, x: Tensor) Tensor {
        switch (self) {
            .sigmoid => return x.sigmoid(),
            .tanh => return x.tanh(),
            .relu => return x.relu(),
            .leakyReLU => |negative_slope| return x.leakyReLU(negative_slope),
            .elu => |alpha| return elu(x, alpha),
            .silu => return x.silu(),
            .gelu => return x.gelu(),
            .quick_gelu => return x.quickGelu(),
        }
    }
};
2024-06-14 15:27:06 +00:00
/// Exponential Linear Unit:
///   elu(x) = x                    when x >= 0
///          = alpha * (exp(x) - 1) otherwise
pub fn elu(x: Tensor, alpha: f32) Tensor {
    const zero = Tensor.scalar(0, x.dtype());
    const negative_branch = x.exp().addConstant(-1).scale(alpha);
    return x.cmp(.GE, zero).select(x, negative_branch);
}
2023-01-02 14:28:25 +00:00
/// Applies each field of `module_list` in declaration order:
/// `x = fieldN.forward(... field1.forward(field0.forward(input)))`.
/// `module_list` must be a struct (or tuple) whose fields all expose
/// a `forward(self, Tensor) Tensor` method.
pub fn chainModules(module_list: anytype, input: Tensor) Tensor {
    const T = @TypeOf(module_list);
    switch (@typeInfo(T)) {
        .Struct => |struct_info| {
            var x = input;
            inline for (struct_info.fields) |field| {
                x = @field(module_list, field.name).forward(x);
            }
            return x;
        },
        else => @compileError("chainModules expects a struct whose fields are all modules with a `forward` method, got: " ++ @typeName(T)),
    }
}
/// Layer Normalization.
/// Normalizes the last axis of the input to zero mean / unit variance,
/// then applies a learned element-wise scale (and optional shift).
pub const LayerNorm = struct {
    weight: Tensor,
    bias: ?Tensor = null,
    eps: f32 = 1e-5,

    pub fn forward(self: LayerNorm, x: Tensor) Tensor {
        const last_axis = x.axis(-1);
        const normed = normalizeVariance(x, self.eps);
        const scaled = normed.mul(self.weight.broadcast(x.shape(), &.{last_axis}));
        if (self.bias) |bias| {
            return scaled.add(bias.broadcast(x.shape(), &.{last_axis}));
        }
        return scaled;
    }
};
2024-10-07 12:53:03 +00:00
/// Root-mean-square normalization along `axis`: x / sqrt(mean(x^2) + eps).
/// The statistics are computed in f32; the result keeps the dtype of `x`.
pub fn rmsNorm(x: Tensor, axis: anytype, eps: f32) Tensor {
    const ax = x.axis(axis);
    // Upcast to improve precision.
    const x_f32 = x.convert(.f32);
    const mean_sq = x_f32.mul(x_f32).mean(ax);
    const inv_rms = Tensor.rsqrt(mean_sq.addConstant(eps)).convert(x.dtype());
    return x.mul(inv_rms.broad(x.shape()));
}
2023-01-02 14:28:25 +00:00
/// Center and scale by the variance, along the last axis:
/// normalize(x, eps) = (x - mean(x)) / sqrt(var(x) + eps)
/// Statistics are computed in f32; the result keeps the dtype of `x`.
pub fn normalizeVariance(x: Tensor, eps: f32) Tensor {
    const inv_n = 1.0 / @as(f32, @floatFromInt(x.dim(-1)));
    // Upcast to improve precision.
    const x_f32 = x.convert(.f32);
    const mean = x_f32.sum(-1).scale(inv_n);
    const centered = x_f32.sub(mean.broadcastRight(x_f32.shape()));
    const variance = centered.mul(centered).sum(-1).scale(inv_n);
    const inv_std = Tensor.rsqrt(variance.addConstant(eps));
    return centered.mul(inv_std.broadcastRight(centered.shape())).convert(x.dtype());
}
// ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
// Implementation equivalent to `nn.functional.normalize(tensor, dim=-1)` call
pub fn normalizeL2(input: Tensor, eps: f32) Tensor {
    const two = Tensor.scalar(2, input.dtype());
    const squared_norm = input.pow(two).sum(-1);
    const inv_norm = squared_norm.addConstant(eps).rsqrt();
    return input.mul(inv_norm.broad(input.shape()));
}
test normalizeL2 {
    const platform = zml.testing.env();
    // 2x2 input, normalized along the last axis.
    // Expected values match `torch.nn.functional.normalize(input, dim=-1)`.
    const input = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f32{ -0.9686, -1.0058, -1.7808, 0.6698 });
    const res = try zml.testing.compileAndCall(platform, zml.nn.normalizeL2, .{ input, 1e-12 });
    const expectation = zml.HostBuffer.fromSlice(.{ 2, 2 }, &[_]f32{ -0.6937, -0.7203, -0.9360, 0.3520 });
    try zml.testing.expectClose(expectation, res, 1e-4);
}
/// Options for rotary position embeddings (`rope`).
pub const RopeOpts = struct {
    layout: Layout = .sequential,
    /// Base of the geometric progression of rotation frequencies.
    freq_base: f32 = 10_000,
    scaling: Scaling = .default,

    /// There are two layouts corresponding to how to split `x` in real/imag parts.
    /// * Interleaved means that the real/imag of each scalar is contiguous.
    /// * Sequential means that you first read all real values then all imag values.
    /// Typically HF models use sequential, while GGUF use interleaved.
    pub const Layout = enum { interleaved, sequential };

    /// There are several ways to init the scaling aka "inv_freq"
    pub const Scaling = union(enum) {
        default: void,
        custom: []const f32,
        llama3: Llama3,

        pub const Llama3 = struct {
            factor: f32,
            high_freq_factor: f32,
            low_freq_factor: f32,
            original_max_position_embeddings: u32,
        };

        /// Read a Rope scaling config from HF config.json format.
        pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) !Scaling {
            const content = try std.json.Value.jsonParse(allocator, source, options);
            // `"rope_scaling": null` in the config means: no scaling.
            if (content == .null) return .default;
            if (content != .object) return error.InvalidEnumTag;
            const obj = content.object;
            const impl = obj.get("rope_type") orelse return error.MissingField;
            if (impl != .string) return error.InvalidEnumTag;
            if (std.mem.eql(u8, impl.string, "llama3")) {
                // Note: leaky is fine here cause Llama3 struct don't need to allocate memory.
                // NOTE(review): the allocator argument is deliberately `undefined` since no
                // field of Llama3 allocates; adding an allocating field (slice/string) to
                // Llama3 would make this undefined behavior — confirm before extending it.
                return .{ .llama3 = try std.json.parseFromValueLeaky(Llama3, undefined, content, .{ .ignore_unknown_fields = true }) };
            } else {
                log.warn("Unsupported Rope implementation: {s}, will use the default one which will produce altered results", .{impl.string});
                return .{ .default = {} };
            }
        }
    };
};
/// Rotary position embedding: modifies queries and keys tensors before computing Q * K in self attention.
/// This biases a token to look at tokens near it.
/// The nice thing with rope is that you can cache the modified queries and keys directly.
/// See: https://paperswithcode.com/method/rope
///
/// Expected shapes of tensor:
/// - x: .{ .s, .hd } where .s is the sequence length and .hd the head dimension
/// - pos_idx: optional tensor which indicates which positions are needed.
///   When not set `rope` returns all positions from 0 to x.dim(.s) which is the max seq len.
pub fn rope(x: Tensor, pos_idx: ?Tensor, opts: RopeOpts) Tensor {
    // .hd is split into real/imag halves, so it must be even.
    stdx.debug.assert(@mod(x.dim(.hd), 2) == 0, "rope expects a even head dim (.hd), got {f}", .{x});
    const idx = if (pos_idx) |idx| blk: {
        stdx.debug.assert(x.shape().hasTags(.{.hd}), "rope expects x argument to have .hd axes got: rope(x={f}, idx={f})", .{ x, idx });
        break :blk idx;
    } else blk: {
        stdx.debug.assert(x.shape().hasTags(.{ .s, .hd }), "rope expects x argument to have both .s and .hd axes got: rope(x={f})", .{x});
        // Default: one rotation angle per position in the sequence.
        break :blk Tensor.arange(.{ .end = x.dim(.s) }, .f32).withTags(.{.s});
    };
    const x_real, const x_imag = splitRealImg(x, opts.layout);
    // compute sin and cos in f32 before downcasting to x type.
    const inv_freq = invFreq(x.dim(.hd), opts).withTags(.{.hd});
    // Outer product: angle[pos, freq] = pos * inv_freq[freq].
    const inv_freq_pos = Tensor.outer(idx.convert(.f32), inv_freq);
    const cos = inv_freq_pos.cos().convert(x.dtype()).broad(x_real.shape());
    const sin = inv_freq_pos.sin().convert(x.dtype()).broad(x_real.shape());
    // apply rotation: (real, imag) * (cos + i*sin), i.e. a 2D rotation per pair.
    const y_real = x_real.mul(cos).sub(x_imag.mul(sin));
    const y_imag = x_real.mul(sin).add(x_imag.mul(cos));
    // flatten last dimensions back into the requested layout
    return mergeRealImg(y_real, y_imag, opts.layout);
}
2024-10-07 12:53:03 +00:00
/// Splits the last axis of `x` into its real and imaginary halves,
/// according to the given rope layout. Inverse of `mergeRealImg`.
pub fn splitRealImg(x: Tensor, layout: RopeOpts.Layout) [2]Tensor {
    const last_dim = x.dim(-1);
    switch (layout) {
        .interleaved => {
            // Even positions are real, odd positions are imaginary.
            return .{
                x.slice1d(-1, .{ .start = 0, .step = 2 }),
                x.slice1d(-1, .{ .start = 1, .step = 2 }),
            };
        },
        .sequential => {
            // First half is real, second half is imaginary.
            const half = @divExact(last_dim, 2);
            return .{
                x.slice1d(-1, .{ .end = half }),
                x.slice1d(-1, .{ .start = half, .end = last_dim }),
            };
        },
    }
}
2024-10-07 12:53:03 +00:00
/// Recombines real and imaginary parts along the last axis,
/// according to the given rope layout. Inverse of `splitRealImg`.
pub fn mergeRealImg(x_real: Tensor, x_imag: Tensor, layout: RopeOpts.Layout) Tensor {
    switch (layout) {
        .sequential => return Tensor.concatenate(&.{ x_real, x_imag }, -1),
        .interleaved => {
            // Stack on a fresh innermost axis, then flatten it to interleave pairwise.
            const stacked = Tensor.concatenate(&.{
                x_real.appendAxes(.{.interleaved_real_img}),
                x_imag.appendAxes(.{.interleaved_real_img}),
            }, -1);
            return stacked.flatten(-2);
        },
    }
}
/// Inverse rope frequencies: { exp(-n * ln(freq_base) / N) | n in [0..N] }
/// with N = hd/2, possibly rescaled according to `opts.scaling`.
pub fn invFreq(N: i64, opts: RopeOpts) Tensor {
    // NOTE(review): the allocated slice is never freed here; presumably the
    // compilation-context allocator is arena-like — confirm.
    const allocator = zml.module.CompilationContext.current().allocator();
    const N_half: usize = @intCast(@divExact(N, 2));
    const inv_freq = allocator.alloc(f32, N_half) catch @panic("OOM");
    _invFreq(opts, inv_freq);
    return zml.Tensor.constantTensor(.fromSlice(.{@divExact(N, 2)}, inv_freq));
}
/// Fills `inv_freq` with the default rope inverse frequencies,
/// then applies the scaling strategy from `opts` in place.
fn _invFreq(opts: RopeOpts, inv_freq: []f32) void {
    const N = inv_freq.len;
    // Default frequencies: inv_freq[n] = freq_base^(-n/N).
    for (0.., inv_freq) |n, *f| {
        f.* = @exp(-@log(opts.freq_base) * stdx.math.divFloat(f32, n, N));
    }
    switch (opts.scaling) {
        .default => {},
        .custom => {
            // User-provided inverse frequencies fully replace the defaults.
            stdx.debug.assert(opts.scaling.custom.len == N, "rope expected custom inv_freq to match half head dimension {d}, got {d}", .{ N, opts.scaling.custom.len });
            @memcpy(inv_freq, opts.scaling.custom);
        },
        .llama3 => |s| {
            // https://arxiv.org/pdf/2309.16039
            // After Llama2 they observed that the rope frequencies where too sharp and hurting long distance attention.
            // In Llama3 they used a higher base freq and also downscaled low frequencies.
            std.debug.assert(s.low_freq_factor < s.high_freq_factor);
            const M: f64 = @floatFromInt(s.original_max_position_embeddings);
            // Thresholds (in inverse-frequency space) delimiting the three regimes.
            const f_high = s.high_freq_factor * (2 * std.math.pi) / M;
            const f_low = s.low_freq_factor * (2 * std.math.pi) / M;
            const downscaling = 1.0 / s.factor;
            for (0..N, inv_freq) |n, f| {
                if (f > f_high) {
                    // High freq match default implem
                } else if (f < f_low) {
                    // Downscaling for low freq
                    inv_freq[n] *= downscaling;
                } else {
                    // Linear interpolation for middle freq
                    const lerp: f64 = (inv_freq[n] - f_low) / (f_high - f_low);
                    inv_freq[n] *= @floatCast(lerp + (1 - lerp) * downscaling);
                }
            }
        },
    }
}
test invFreq {
    // Llama 3.2-1B config
    const llama_conf: RopeOpts = .{
        .freq_base = 500_000,
        .scaling = .{ .llama3 = .{
            .factor = 32,
            .high_freq_factor = 4,
            .low_freq_factor = 1,
            .original_max_position_embeddings = 8192,
        } },
    };
    // Reference values for the 32 inverse frequencies (head dim 64).
    const llama_freq = [_]f32{ 1.000000000000e+00, 6.636012792587e-01, 4.403666257858e-01, 2.922278344631e-01, 1.939227581024e-01, 1.286873817444e-01, 8.539710193872e-02, 5.666961893439e-02, 3.760603070259e-02, 2.495540864766e-02, 1.656044088304e-02, 1.098952908069e-02, 7.292665075511e-03, 4.839421249926e-03, 3.211446106434e-03, 1.290548010729e-03, 4.295567050576e-04, 9.708286233945e-05, 1.946163865796e-05, 1.291476746701e-05, 8.570255886298e-06, 5.687232260243e-06, 3.774054448513e-06, 2.504467147446e-06, 1.661967417022e-06, 1.102883629756e-06, 7.318749339902e-07, 4.856731266045e-07, 3.222932889457e-07, 2.138742303259e-07, 1.419272024350e-07, 9.418306490261e-08 };
    var inv_freq: @TypeOf(llama_freq) = undefined;
    _invFreq(llama_conf, &inv_freq);
    for (llama_freq, inv_freq, 0..) |expected, actual, i| {
        errdefer log.err("Mismatch at position {d}.\nExpected: {any}\nActual: {any}", .{ i, llama_freq, inv_freq });
        try std.testing.expectApproxEqRel(expected, actual, 1e-5);
    }
}
test "real/img" {
    const platform = zml.testing.env();
    const Fns = struct {
        // split → merge → split must round-trip exactly; the returned scalar
        // counts matching elements (20 when everything round-trips).
        fn testSplitMergeIsId(layout: RopeOpts.Layout) Tensor {
            const x = Tensor.arange(.{ .end = 20 }, .f32).reshape(.{ 5, 4 });
            const real, const imag = splitRealImg(x, layout);
            const y = mergeRealImg(real, imag, layout);
            const real2, const imag2 = splitRealImg(y, layout);
            return real.cmp(.EQ, real2).flatten(0).convert(.i32).sum(-1).add(
                imag.cmp(.EQ, imag2).flatten(0).convert(.i32).sum(-1),
            );
        }
        // Same check as `testSplitSeq` but with a single `void` parameter,
        // to exercise the void-argument calling convention.
        fn testSplitSeqVoid(_: void) Tensor {
            const x = Tensor.arange(.{ .end = 20 }, .f32).reshape(.{ 5, 4 });
            const real, const imag = splitRealImg(x, .sequential);
            const x_real = Tensor.concatenate(&.{
                Tensor.arange(.{ .start = 0, .end = 20, .step = 4 }, .f32).reshape(.{ 5, 1 }),
                Tensor.arange(.{ .start = 1, .end = 20, .step = 4 }, .f32).reshape(.{ 5, 1 }),
            }, 1);
            const x_imag = Tensor.concatenate(&.{
                Tensor.arange(.{ .start = 2, .end = 20, .step = 4 }, .f32).reshape(.{ 5, 1 }),
                Tensor.arange(.{ .start = 3, .end = 20, .step = 4 }, .f32).reshape(.{ 5, 1 }),
            }, 1);
            return real.cmp(.EQ, x_real).flatten(0).convert(.i32).sum(-1).add(
                imag.cmp(.EQ, x_imag).flatten(0).convert(.i32).sum(-1),
            );
        }
        // Sequential split: first half of the last axis is real, second half imag.
        fn testSplitSeq() Tensor {
            const x = Tensor.arange(.{ .end = 20 }, .f32).reshape(.{ 5, 4 });
            const real, const imag = splitRealImg(x, .sequential);
            const x_real = Tensor.concatenate(&.{
                Tensor.arange(.{ .start = 0, .end = 20, .step = 4 }, .f32).reshape(.{ 5, 1 }),
                Tensor.arange(.{ .start = 1, .end = 20, .step = 4 }, .f32).reshape(.{ 5, 1 }),
            }, 1);
            const x_imag = Tensor.concatenate(&.{
                Tensor.arange(.{ .start = 2, .end = 20, .step = 4 }, .f32).reshape(.{ 5, 1 }),
                Tensor.arange(.{ .start = 3, .end = 20, .step = 4 }, .f32).reshape(.{ 5, 1 }),
            }, 1);
            return real.cmp(.EQ, x_real).flatten(0).convert(.i32).sum(-1).add(
                imag.cmp(.EQ, x_imag).flatten(0).convert(.i32).sum(-1),
            );
        }
        // Interleaved split: even positions are real, odd positions imag.
        fn testSplitInterleaved() Tensor {
            const x = Tensor.arange(.{ .end = 20 }, .f32).reshape(.{ 5, 4 });
            const real, const imag = splitRealImg(x, .interleaved);
            const x_real = Tensor.arange(.{ .start = 0, .end = 20, .step = 2 }, .f32).reshape(.{ 5, 2 });
            const x_imag = Tensor.arange(.{ .start = 1, .end = 20, .step = 2 }, .f32).reshape(.{ 5, 2 });
            return real.cmp(.EQ, x_real).flatten(0).convert(.i32).sum(-1).add(
                imag.cmp(.EQ, x_imag).flatten(0).convert(.i32).sum(-1),
            );
        }
    };
    const d_interleaved = try zml.testing.compileAndCall(platform, Fns.testSplitMergeIsId, .{.interleaved});
    try testing.expectEqual(20, d_interleaved.getValue(i32));
    const d_sequential = try zml.testing.compileAndCall(platform, Fns.testSplitMergeIsId, .{.sequential});
    try testing.expectEqual(20, d_sequential.getValue(i32));
    // test the function that accepts 1 void argument
    const d_split_seq_void = try zml.testing.compileAndCall(platform, Fns.testSplitSeqVoid, .{{}});
    try testing.expectEqual(20, d_split_seq_void.getValue(i32));
    // test the function that takes NO arguments
    const d_split_seq = try zml.testing.compileAndCall(platform, Fns.testSplitSeq, .{});
    try testing.expectEqual(20, d_split_seq.getValue(i32));
    // now try compiling and calling ourselves
    {
        const mod = try zml.compileFn(std.testing.allocator, Fns.testSplitSeq, .{}, platform);
        defer mod.deinit();
        const ret = mod.call({});
        try testing.expectEqual(20, ret.getValue(i32));
    }
    const d_split_interleaved = try zml.testing.compileAndCall(platform, Fns.testSplitInterleaved, .{});
    try testing.expectEqual(20, d_split_interleaved.getValue(i32));
}
2023-09-27 11:45:33 +00:00
test rope {
    const platform = zml.testing.env();
    const Local = struct {
        // Runs rope with the given layout, converting the canonical sequential
        // input into that layout and the result back, so outputs are comparable.
        fn _fwd(x: Tensor, opts: RopeOpts) Tensor {
            var input = x;
            {
                // Convert input to the requested format
                const real, const imag = splitRealImg(input, .sequential);
                input = mergeRealImg(real, imag, opts.layout);
            }
            var res = rope(input, null, opts).squeeze(0);
            {
                // Convert back to sequential
                const real, const imag = splitRealImg(res, opts.layout);
                res = mergeRealImg(real, imag, .sequential);
            }
            return res;
        }
    };
    // x is made such as the interleaved and sequential reps are the same.
    // So the two implementations should give the same results.
    const x = try zml.Buffer.fromSlice(platform, .{ .b = 1, .s = 5, .hd = 4 }, &[_]f32{ 1.0, 0.1, -1.0, -0.5 } ** 5);
    const res1 = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, RopeOpts{ .layout = .interleaved } });
    const res2 = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, RopeOpts{ .layout = .sequential } });
    try zml.testing.expectClose(res1, res2, 1e-4);
}
/// In neural networks we generally care about relative precision,
/// but when a value is close to 0 its precision matters less,
/// so we also accept an absolute difference of `tolerance / 2`.
fn approxEq(comptime Float: type, l: Float, r: Float, tolerance: Float) bool {
    if (std.math.approxEqRel(Float, l, r, @floatCast(tolerance))) return true;
    // Near-zero fallback: absolute check with a tighter tolerance.
    return std.math.approxEqAbs(Float, l, r, @floatCast(tolerance / 2));
}
/// Interpolation modes supported by `upsample`.
pub const UpsampleMode = enum {
    nearest,
    // TODO: Linear,
    // TODO: Bilinear,
    // TODO: Bicubic,
    // TODO: Trilinear,
};
/// Upsamples the N-2 trailing (spatial) axes of `input`.
/// `scale_factor` is either a single factor applied to every spatial axis,
/// or one factor per spatial axis (rank - 2 of them).
pub fn upsample(
    input: Tensor,
    opts: struct { mode: UpsampleMode, scale_factor: []const f64 },
) Tensor {
    // TODO(james): make `nearest` compatible with resizeBilinear and resizeBicubic, and wrap them here.
    // resize* have API which are more explicit, this assume you want to scale the N-2 last axes.
    stdx.debug.assert(3 <= input.rank() and input.rank() <= 5, "upsample is only implemented for (3,4,5)-D tensors, received {f}", .{input});
    stdx.debug.assert(opts.scale_factor.len == 1 or opts.scale_factor.len == input.rank() - 2, "upsample expects 1 or {d} scale factors, got {d}", .{ input.rank() - 2, opts.scale_factor.len });
    return switch (opts.mode) {
        .nearest => {
            // Expand a single factor to all spatial axes; otherwise copy them as-is.
            var scale_factors: [3]f64 = undefined;
            switch (opts.scale_factor.len) {
                1 => {
                    for (0..input.rank() - 2) |i| scale_factors[i] = opts.scale_factor[0];
                },
                else => @memcpy(scale_factors[0..opts.scale_factor.len], opts.scale_factor),
            }
            return nearest(input, scale_factors[0 .. input.rank() - 2]);
        },
    };
}
/// Nearest-neighbor resize of the spatial axes of `input`
/// (all axes except the 2 leading ones) by the given per-axis scale factors.
pub fn nearest(input: Tensor, scale_factor: []const f64) Tensor {
    var out_shape = input.shape();
    for (scale_factor, 0..) |sf, i| {
        // Output length = floor(input length * scale factor).
        out_shape._dims.set(i + 2, @intFromFloat(@floor(@as(f64, @floatFromInt(out_shape.dim(i + 2))) * sf)));
    }
    // TODO(james): remove this implicit two batching dims
    // Only gather along axes whose size actually changes.
    var sd: [3]usize = undefined;
    var len_sd: usize = 0;
    for (2..input.rank()) |i| {
        if (input.dim(i) != out_shape.dim(i)) {
            sd[len_sd] = i;
            len_sd += 1;
        }
    }
    const spatial_dims = sd[0..len_sd];
    var res = input;
    for (spatial_dims) |d| {
        const n = out_shape.dim(d);
        const ratio = stdx.math.divFloat(f32, input.dim(d), n);
        // Source index for output sample i: floor((i + 0.5) * in_len / out_len).
        const offsets = Tensor.arange(.{ .end = n }, .f32).addConstant(0.5).scale(ratio).floor().convert(.i32);
        res = res.gatherValues(d, offsets, .{ .indices_are_sorted = true });
    }
    return res;
}
test nearest {
    const platform = zml.testing.env();
    // 3D Tensor (basic): single spatial axis scaled by 3.
    {
        const input_3d_basic = try zml.Buffer.fromArray(platform, [1][1][2]i32{.{.{ 1, 2 }}});
        const result = try zml.testing.compileAndCall(platform, upsample, .{ input_3d_basic, .{ .scale_factor = &.{3}, .mode = .nearest } });
        try std.testing.expectEqualSlices(i64, &.{ 1, 1, 6 }, result.dims());
        const expected: [1][1][6]i32 = .{.{.{ 1, 1, 1, 2, 2, 2 }}};
        try zml.testing.expectClose(zml.HostBuffer.fromArray(&expected), result, 0);
    }
    // 3D Tensor (advanced): non-trivial batch/channel dims, scale by 2.
    {
        const input_3d_advanced = try zml.Buffer.fromArray(platform, [2][3][4]i32{
            .{ .{ 1, 2, 3, 4 }, .{ 5, 6, 7, 8 }, .{ 9, 10, 11, 12 } },
            .{ .{ 13, 14, 15, 16 }, .{ 17, 18, 19, 20 }, .{ 21, 22, 23, 24 } },
        });
        const result = try zml.testing.compileAndCall(platform, upsample, .{ input_3d_advanced, .{ .scale_factor = &.{2}, .mode = .nearest } });
        try std.testing.expectEqualSlices(i64, &.{ 2, 3, 8 }, result.dims());
        const expected: [2][3][8]i32 = .{
            .{
                .{ 1, 1, 2, 2, 3, 3, 4, 4 },
                .{ 5, 5, 6, 6, 7, 7, 8, 8 },
                .{ 9, 9, 10, 10, 11, 11, 12, 12 },
            },
            .{
                .{ 13, 13, 14, 14, 15, 15, 16, 16 },
                .{ 17, 17, 18, 18, 19, 19, 20, 20 },
                .{ 21, 21, 22, 22, 23, 23, 24, 24 },
            },
        };
        try zml.testing.expectClose(zml.HostBuffer.fromArray(&expected), result, 0);
    }
    // 4D Tensor (basic): two spatial axes scaled by 3 each.
    {
        const input_4d_basic = try zml.Buffer.fromSlice(platform, .{ 1, 1, 2, 2 }, &[_]i32{ 1, 2, 3, 4 });
        const result = try zml.testing.compileAndCall(platform, upsample, .{ input_4d_basic, .{ .scale_factor = &.{ 3, 3 }, .mode = .nearest } });
        try std.testing.expectEqualSlices(i64, &.{ 1, 1, 6, 6 }, result.dims());
        const expected: [1][1][6][6]i32 = .{.{.{
            .{ 1, 1, 1, 2, 2, 2 },
            .{ 1, 1, 1, 2, 2, 2 },
            .{ 1, 1, 1, 2, 2, 2 },
            .{ 3, 3, 3, 4, 4, 4 },
            .{ 3, 3, 3, 4, 4, 4 },
            .{ 3, 3, 3, 4, 4, 4 },
        }}};
        try std.testing.expectEqual(expected, result.getValue([1][1][6][6]i32));
    }
    // 4D Tensor (advanced): two spatial axes scaled by 2 each, multiple batches.
    {
        const input_4d_advanced = try zml.Buffer.fromArray(platform, [2][2][2][2]i32{ .{
            .{ .{ 1, 2 }, .{ 3, 4 } },
            .{ .{ 5, 6 }, .{ 7, 8 } },
        }, .{
            .{ .{ 9, 10 }, .{ 11, 12 } },
            .{ .{ 13, 14 }, .{ 15, 16 } },
        } });
        const result = try zml.testing.compileAndCall(platform, upsample, .{ input_4d_advanced, .{ .scale_factor = &.{ 2, 2 }, .mode = .nearest } });
        try std.testing.expectEqualSlices(i64, &.{ 2, 2, 4, 4 }, result.dims());
        const expected: [2][2][4][4]i32 = .{
            .{
                .{
                    .{ 1, 1, 2, 2 },
                    .{ 1, 1, 2, 2 },
                    .{ 3, 3, 4, 4 },
                    .{ 3, 3, 4, 4 },
                },
                .{
                    .{ 5, 5, 6, 6 },
                    .{ 5, 5, 6, 6 },
                    .{ 7, 7, 8, 8 },
                    .{ 7, 7, 8, 8 },
                },
            },
            .{
                .{
                    .{ 9, 9, 10, 10 },
                    .{ 9, 9, 10, 10 },
                    .{ 11, 11, 12, 12 },
                    .{ 11, 11, 12, 12 },
                },
                .{
                    .{ 13, 13, 14, 14 },
                    .{ 13, 13, 14, 14 },
                    .{ 15, 15, 16, 16 },
                    .{ 15, 15, 16, 16 },
                },
            },
        };
        try zml.testing.expectClose(zml.HostBuffer.fromArray(&expected), result, 0);
    }
    // 5D Tensor (basic): one shared factor applied to all three spatial axes.
    {
        const input_5d = try zml.Buffer.fromSlice(platform, .{ 1, 1, 1, 2, 2 }, &[_]i32{ 1, 2, 3, 4 });
        const result = try zml.testing.compileAndCall(platform, upsample, .{ input_5d, .{ .scale_factor = &.{2}, .mode = .nearest } });
        try std.testing.expectEqualSlices(i64, &.{ 1, 1, 2, 4, 4 }, result.dims());
        const expected: [1][1][2][4][4]i32 = .{
            .{
                .{
                    .{
                        .{ 1, 1, 2, 2 },
                        .{ 1, 1, 2, 2 },
                        .{ 3, 3, 4, 4 },
                        .{ 3, 3, 4, 4 },
                    },
                    .{
                        .{ 1, 1, 2, 2 },
                        .{ 1, 1, 2, 2 },
                        .{ 3, 3, 4, 4 },
                        .{ 3, 3, 4, 4 },
                    },
                },
            },
        };
        try zml.testing.expectClose(zml.HostBuffer.fromArray(&expected), result, 0);
    }
}
2023-01-27 14:35:11 +00:00
/// Options shared by the resize* functions.
pub const ResizeOpts = struct {
    /// Scalar tensor containing the original dimension of the image.
    /// It can be different from the image shape,
    /// if the image has been padded.
    /// This allows to compile one module that handles different input image sizes.
    original_len: ?Tensor = null,
    /// Internal precision to do the interpolation.
    /// Result will always use the same dtype as the original.
    /// If not set, will use the image dtype, unless it's an integer type, in which case f32 will be used.
    precision: ?zml.DataType = null,
};
2023-01-02 14:28:25 +00:00
2023-01-27 14:35:11 +00:00
/// Bilinear resize of the tagged axes of `image`.
/// `resized_axes` maps axis tags to their new length, e.g. `.{ .a = 20, .b = 5 }`;
/// each listed axis is resized in turn with `resizeLinear1d`.
pub fn resizeBilinear(image: Tensor, resized_axes: anytype, opt: ResizeOpts) Tensor {
    const new_size, const tags_ = Shape.parseStruct(u63, resized_axes);
    var out = image;
    for (new_size.constSlice(), tags_.constSlice()) |new_len, tag| {
        const ax = image.shape().axis(tag);
        // Forward only the slice of `original_len` relevant to this axis.
        const child_opt: ResizeOpts = .{
            .original_len = if (opt.original_len) |o| o.choose1d(0, ax) else null,
        };
        out = resizeLinear1d(out, ax, new_len, child_opt);
    }
    return out;
}
2023-01-27 14:35:11 +00:00
test resizeBilinear {
    const platform = zml.testing.env();
    // Only test shapes
    var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
    defer comp.deinit();
    comp.activate();
    defer comp.deactivate();
    // Each testcase: { input shape, resize spec, expected output shape }.
    inline for (.{
        .{ .{ .a = 10, .b = 10 }, .{ .a = 20 }, .{ .a = 20, .b = 10 } },
        .{ .{ .a = 10, .b = 10 }, .{ .b = 5 }, .{ .a = 10, .b = 5 } },
        .{ .{ .a = 10, .b = 10 }, .{ .a = 20, .b = 5 }, .{ .a = 20, .b = 5 } },
    }) |testcase| {
        const x_shape, const resizing, const res_shape = testcase;
        const x = Tensor.constant(x_shape, .{ .f16 = 0 });
        const y = resizeBilinear(x, resizing, .{});
        try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape());
        try std.testing.expect(y.value().owner().verify());
    }
}
2023-01-02 14:28:25 +00:00
/// Linear (1-D) resize of `image` along `axis` to `new_len` samples.
/// Each output sample is a weighted average of its two nearest input samples.
pub fn resizeLinear1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Tensor {
    const res_shape = image.shape().set(axis, new_len);
    // Integer images are interpolated in f32 (unless overridden via opt.precision).
    const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype();
    const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype);
    const ratio = og_len.convert(dtype).scale(stdx.math.divFloat(f32, 1, new_len));
    // Fractional source coordinate for every output index.
    const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio);
    const left = scaled.floor();
    // NOTE(review): `right` can reach image.dim(axis) for the last sample —
    // presumably gatherValues clamps out-of-range indices; confirm.
    const right = left.addConstant(1);
    // TODO: check that two gather isn't too bad perf wise.
    // Normally we should use gatherSlices to collect the values 2 by 2,
    // but gatherSlices messes up with the order of axes.
    const left_val = image.gatherValues(axis, left.convert(.i32), .{ .indices_are_sorted = true }).convert(dtype);
    const right_val = image.gatherValues(axis, right.convert(.i32), .{ .indices_are_sorted = true }).convert(dtype);
    // Linear blend: each neighbor is weighted by the distance to the other neighbor.
    const left_weight = right.sub(scaled).broadcast(res_shape, &.{axis});
    const right_weight = scaled.sub(left).broadcast(res_shape, &.{axis});
    const res = left_val.mul(left_weight).add(right_val.mul(right_weight));
    return res.convert(image.dtype()).withTags(image.shape().tags());
}
/// Bicubic interpolation of the given image.
/// `resized_axes` maps axis tags to their new length; each listed axis is
/// resized in turn with `resizeCubic1d`.
/// Warning as of May 2024 the cpu backend don't optimize this very well
/// and is not able to merge the weighting with the gather,
/// leading to 20x slow down compared to STB implementation.
pub fn resizeBicubic(image: Tensor, resized_axes: anytype, opt: ResizeOpts) Tensor {
    const new_size, const tags_ = Shape.parseStruct(u63, resized_axes);
    var out = image;
    for (new_size.constSlice(), tags_.constSlice()) |new_len, tag| {
        const ax = image.shape().axis(tag);
        // Forward only the slice of `original_len` relevant to this axis.
        const child_opt: ResizeOpts = .{
            .original_len = if (opt.original_len) |o| o.choose1d(0, ax) else null,
        };
        out = resizeCubic1d(out, ax, new_len, child_opt);
    }
    return out;
}
2023-01-27 14:35:11 +00:00
test resizeBicubic {
    const platform = zml.testing.env();
    // Only test shapes
    var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
    defer comp.deinit();
    comp.activate();
    defer comp.deactivate();
    // Each testcase is .{ input shape, resize spec, expected output shape }.
    inline for (.{
        .{ .{ .a = 10, .b = 10 }, .{ .a = 20 }, .{ .a = 20, .b = 10 } },
        .{ .{ .a = 10, .b = 10 }, .{ .b = 5 }, .{ .a = 10, .b = 5 } },
        .{ .{ .a = 10, .b = 10 }, .{ .a = 20, .b = 5 }, .{ .a = 20, .b = 5 } },
    }) |testcase| {
        const x_shape, const resizing, const res_shape = testcase;
        const x = Tensor.constant(x_shape, .{ .f16 = 0 });
        const y = resizeBicubic(x, resizing, .{});
        try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape());
        // Also check the emitted IR value is well-formed.
        try std.testing.expect(y.value().owner().verify());
    }
}
2023-01-02 14:28:25 +00:00
/// Cubic (1d) resize of `image` along `axis` to `new_len` samples.
pub fn resizeCubic1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Tensor {
    // Integer images are interpolated in float; opt.precision can override.
    const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype();
    const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype);
    // Position of each output sample, expressed in input coordinates.
    const ratio = og_len.convert(dtype).scale(stdx.math.divFloat(f32, 1, new_len));
    const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio);

    // t is the fractional offset; build the polynomial basis (1, t, t^2, t^3)
    // that gets multiplied by the cubic weight matrix below.
    const t = scaled.sub(scaled.floor());
    const pos = Tensor.stack(&.{
        Tensor.constant(t.shape(), dtype.one()),
        t,
        t.mul(t),
        t.pow(Tensor.scalar(3, dtype)),
    }, .last, ._interpolated);
    std.debug.assert(pos.dim(0) == new_len);
    std.debug.assert(pos.dim(1) == 4);

    // Extract the 4 neighboring pixels from the image (index clamped at 0 on the left edge).
    const neighbors = scaled.floor().addConstant(-1).convert(.i32).maximum(Tensor.scalar(0, .i32));
    const values = image.renameAxis(axis, ._neighbors).gatherSlices(
        Shape.init(.{ ._neighbors = 4 }, image.dtype()),
        neighbors.appendAxes(.{.coord}),
        .{ .indices_are_sorted = true },
    ).convert(dtype);

    // Cubic convolution coefficients; each row weights the (1, t, t^2, t^3) basis.
    // NOTE(review): looks like the Catmull-Rom (a = -0.5) kernel — confirm.
    const weights_: [4][4]f32 = .{
        .{ 0, 1, 0, 0 },
        .{ -0.5, 0, 0.5, 0 },
        .{ 1, -2.5, 2, -0.5 },
        .{ -0.5, 1.5, -1.5, 0.5 },
    };
    const weights = zml.Tensor.constantTensor(zml.HostBuffer.fromArray(&weights_)).convert(dtype).withTags(.{ ._interpolated, ._neighbors });

    // actually do the interpolation.
    // Note: ideally this matmul should be inlined with the gather, but that's currently not the case.
    // TODO: not being able to use dot here is a bit annoying.
    var res = values.dotGeneral(weights, &.{.{ values.axis(._neighbors), weights.axis(._neighbors) }}, &.{});
    res = pos.dotGeneral(res, &.{.{ pos.axis(._interpolated), res.axis(._interpolated) }}, &.{.{ 0, 0 }});

    // the current axis is outputted in first position because it's a batching dim, put it back in place.
    if (axis != 0) {
        res = res.swapAxes(0, axis);
    }

    // verify the shape
    const res_shape = image.shape().set(axis, new_len);
    // log.debug("resizeCubic1d: ({}, {}, {}, {}) -> {}", .{ image, axis, new_len, opt, res });
    std.debug.assert(std.mem.eql(i64, res_shape.dims(), res.dims()));

    // NOTE(review): resizeLinear1d passes `image.shape().tags()` here, while this
    // passes the whole Shape — confirm `withTags` accepts a Shape, or align the two.
    return res.convert(image.dtype()).withTags(image.shape());
}
/// Return a causal attention mask for the given 2-axis shape.
/// Axis -2 indexes queries, axis -1 indexes keys; entry (q, k) is visible
/// iff q - attn_window_len < k <= q (plain causal when no window is given).
/// For float dtypes, visible entries are 0 and hidden ones are the dtype's
/// minimum value (additive mask); otherwise the boolean mask is converted.
pub fn causalAttnMask(
    attn_shape_: anytype,
    dtype: DataType,
    attn_window_len: ?u32,
) Tensor {
    const attn_shape = Shape.init(attn_shape_, dtype);
    stdx.debug.assert(attn_shape.rank() == 2, "causalAttnMask({f}) shape need to be exactly 2 axes", .{attn_shape});

    const q_idx = Tensor.iota(attn_shape, -2);
    const k_idx = Tensor.iota(attn_shape, -1);

    // Causal part: keys strictly in the future of the query are hidden.
    var visible = k_idx.cmp(.LE, q_idx);

    // Sliding-window part: keys too far in the past are hidden as well.
    if (attn_window_len) |window_len| {
        const qlen = attn_shape.dim(-2);
        const klen = attn_shape.dim(-1);
        // No-op when both axes already fit inside the window.
        if (qlen >= window_len or klen >= window_len) {
            visible = visible.logical(.AND, q_idx.cmp(.LT, k_idx.addConstant(window_len)));
        }
    }

    if (!dtype.isFloat()) return visible.convert(dtype);

    const zeros = Tensor.constant(visible.shape(), dtype.zero());
    const hidden = Tensor.constant(visible.shape(), dtype.minValue());
    return Tensor.select(visible, zeros, hidden);
}
/// Options for `sdpa` (scaled dot product attention).
pub const SdpaOpts = struct {
    // Additive mask broadcast onto the attention weights (e.g. causalAttnMask output).
    attn_mask: ?Tensor = null,
    // Scaling applied to the keys; defaults to 1/sqrt(head_dim) when null.
    scale: ?Tensor = null,
    // Allow dispatching to the fused cudnn kernel when the shapes permit it.
    allow_cudnn: bool = true,
    // TODO: replace these fields with a callback,
    // so that callers can customize the attention-weight post-processing.
};
/// Scaled dot product attention.
///
/// **Shapes**:
/// - q, result: .{ .h, .q, .hd }
/// - k, v: .{ .h, .k, .hd }
///
/// Where:
/// - .h is the number of head
/// - .q is the number of queries
/// - .k is the number of keys
/// - .hd is the head dimension
///
/// .h is allowed to differ from queries and keys as long as the key heads
/// can be repeated to match query heads.
pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
    var q, var k, var v = .{ q_, k_, v_ };

    const err_template = "sdpa(q: {f}, k: {f}, v: {f}, attn: {?f}) is invalid !";
    const err_args = .{ q, k, v, opts.attn_mask };
    stdx.debug.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ " q is missing tags {{.h, .q, .hd}}", err_args);
    stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ " k is missing tags {{.h, .k, .hd}}", err_args);
    stdx.debug.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ " v is missing tags {{.h, .k, .hd}}", err_args);

    // Fast path: fused cudnn kernel.
    if (opts.allow_cudnn and cuda.canUseCudnnSdpa(q.shape())) {
        return cuda.sdpa(q, k, v, opts);
    }

    if (q.dim(.h) != k.dim(.h)) {
        stdx.debug.assert(@mod(q.dim(.h), k.dim(.h)) == 0, err_template ++ " Different number of heads for keys and queries, but can't repeat keys.", err_args);
        // Note: we don't try to repeat queries.
        // Repeating keys is the interesting optimisation cause it reduces KV cache memory usage.
        const num_rep: u63 = @intCast(@divExact(q.dim(.h), k.dim(.h)));
        k, v = .{ k.repeat1d(.h, num_rep), v.repeat1d(.h, num_rep) };
    }

    const attn_mask = opts.attn_mask;
    const dims = helpers.collectDims(.{ .h, .q, .k, .hd }, &.{ q, k, v, attn_mask }, .strict) catch {
        stdx.debug.panic(err_template ++ " Inputs have incompatible shapes.", err_args);
    };

    // Fold the 1/sqrt(head_dim) scaling (or the caller-provided scale) into k.
    const inv_sqrt_head_dim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(dims.hd)));
    const head_scaling = opts.scale orelse Tensor.scalar(inv_sqrt_head_dim, k.dtype());
    k = k.mul(head_scaling.convert(k.dtype()));

    var attn_weights = q.dot(k, .{.hd});
    if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
    // Softmax is computed in f32 for numerical stability.
    attn_weights = attn_weights.convert(.f32).softmax(.k).convert(q.dtype());

    const attn = attn_weights.dot(v, .{.k});
    return attn.transpose(q.shape());
}
2023-08-14 14:24:11 +00:00
/// Chunk sizes used by `sdpaMemEfficient` to tile the attention computation.
pub const SdpaChunks = struct { q_chunk_size: u32, k_chunk_size: u32 };
2023-01-02 14:28:25 +00:00
2023-08-14 14:24:11 +00:00
/// Memory-efficient scaled dot product attention.
/// Same contract as `sdpa`, but the computation is tiled into
/// (q_chunk_size x k_chunk_size) chunks so the full attention matrix is
/// never materialized at once.
pub fn sdpaMemEfficient(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    sdpa_opts: SdpaOpts,
    chunking: SdpaChunks,
) Tensor {
    // Clamp the requested chunk sizes to the actual sequence lengths.
    const q_chunk: u32 = @intCast(@min(q.dim(.q), chunking.q_chunk_size));
    const k_chunk: u32 = @intCast(@min(k.dim(.k), chunking.k_chunk_size));
    const runner: SdpaMemEfficient = .{
        .q = q,
        .k = k,
        .v = v,
        .sdpa_opts = sdpa_opts,
        .chunking = .{ .q_chunk_size = q_chunk, .k_chunk_size = k_chunk },
    };
    return runner.forward();
}
// Implementation state for `sdpaMemEfficient`: inputs plus chunk sizes.
const SdpaMemEfficient = struct {
    q: Tensor,
    k: Tensor,
    v: Tensor,
    sdpa_opts: SdpaOpts,
    chunking: SdpaChunks,

    // Tiles the queries into q_chunk_size blocks; each block scans over the
    // key/value chunks, and the per-block results are concatenated back.
    fn forward(self: SdpaMemEfficient) Tensor {
        stdx.debug.assert(@mod(self.q.dim(.q), self.chunking.q_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({f}, {})", .{ self.q, self.chunking });
        stdx.debug.assert(@mod(self.k.dim(.k), self.chunking.k_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({f}, {})", .{ self.k, self.chunking });

        const n_q_chunks: u32 = @intCast(@divExact(self.q.dim(.q), self.chunking.q_chunk_size));
        const ctx = zml.module.CompilationContext.current();
        const q_chunks = ctx.allocator().alloc(zml.Tensor, n_q_chunks) catch unreachable;
        defer ctx.allocator().free(q_chunks);
        for (0..n_q_chunks) |i| {
            const idx: u32 = @intCast(i);
            // Slice the queries (and the matching mask rows) for this chunk.
            const q_slice: zml.Tensor.DynSlice = .{
                .start = Tensor.scalar(idx * self.chunking.q_chunk_size, .i32),
                .len = self.chunking.q_chunk_size,
            };
            const q_chunk = self.q.dynamicSlice(.{ .q = q_slice });
            const attn_chunk = if (self.sdpa_opts.attn_mask) |attn_mask| attn_mask.dynamicSlice(.{ .q = q_slice }) else null;
            var chunk: SdpaMemEfficient = self;
            chunk.q = q_chunk;
            chunk.sdpa_opts.attn_mask = attn_chunk;
            q_chunks[i] = chunk.scanKeyVal();
        }
        const res = zml.Tensor.concatenate(q_chunks, .q);
        return res.transpose(self.q.shape());
    }

    // Same as one iteration of `forward`'s loop, but with a tensor-valued
    // chunk index (usable from a compiled loop body).
    fn nextQueriesChunk(self: SdpaMemEfficient, idx: Tensor) Tensor {
        const q_slice: zml.Tensor.DynSlice = .{
            .start = idx.scale(self.chunking.q_chunk_size),
            .len = self.chunking.q_chunk_size,
        };
        const q_chunk = self.q.dynamicSlice(.{ .q = q_slice });
        const attn_chunk = if (self.sdpa_opts.attn_mask) |attn_mask| attn_mask.dynamicSlice(.{ .q = q_slice }) else null;
        var chunk: SdpaMemEfficient = self;
        chunk.q = q_chunk;
        chunk.sdpa_opts.attn_mask = attn_chunk;
        return chunk.scanKeyVal();
    }

    // Folds the partial softmax over every key/value chunk:
    // unrolled when there are few chunks, stablehlo.while otherwise.
    fn scanKeyVal(self: SdpaMemEfficient) Tensor {
        const n_chunks = @divExact(self.k.dim(.k), self.chunking.k_chunk_size);
        return if (n_chunks <= 4) {
            // Unrolled version
            var partial_softmax: ?PartialSoftmax = null;
            for (0..@intCast(n_chunks)) |idx| {
                const next = self.nextKeyValChunk(Tensor.scalar(idx, .i32));
                partial_softmax = if (partial_softmax) |prev| prev.merge(next) else next;
            }
            return partial_softmax.?.finalize();
        } else {
            // stablehlo.while version
            const partial_softmax, _ = zml.ops.while_(hasNextKeyValChunk, nextKeyValChunkMerge, self, .{ PartialSoftmax.zeros(self.q.shape(), .f32), Tensor.scalar(0, .i32) });
            return partial_softmax.finalize();
        };
    }

    // while_ body: compute the next chunk, merge it into the accumulator,
    // advance the index.
    fn nextKeyValChunkMerge(self: SdpaMemEfficient, prev: PartialSoftmax, idx: Tensor) struct { PartialSoftmax, Tensor } {
        const next = self.nextKeyValChunk(idx);
        return .{ prev.merge(next), idx.addConstant(1) };
    }

    // Computes the partial attention for the key/value chunk at `idx`.
    fn nextKeyValChunk(self: SdpaMemEfficient, idx: Tensor) PartialSoftmax {
        const k_slice: zml.Tensor.DynSlice = .{
            .start = idx.scale(self.chunking.k_chunk_size),
            .len = self.chunking.k_chunk_size,
        };
        const k_chunk = self.k.dynamicSlice(.{ .k = k_slice });
        const v_chunk = self.v.dynamicSlice(.{ .k = k_slice });
        const attn_chunk = if (self.sdpa_opts.attn_mask) |mask| mask.dynamicSlice(.{ .k = k_slice }) else null;
        return sdpaChunk(self.q, k_chunk, v_chunk, .{ .attn_mask = attn_chunk });
    }

    // while_ condition: keep looping while idx < number of key chunks.
    pub fn hasNextKeyValChunk(self: SdpaMemEfficient, _: PartialSoftmax, idx: Tensor) zml.Tensor {
        const n_chunks = @divExact(self.k.dim(.k), self.chunking.k_chunk_size);
        return idx.cmp(.LT, Tensor.scalar(n_chunks, idx.dtype()));
    }
};
2023-08-14 14:24:11 +00:00
/// Partial (online) softmax state: enough information to combine softmax
/// results computed over disjoint chunks of the reduction axis.
pub const PartialSoftmax = struct {
    values: Tensor,
    exp_sum: Tensor,
    max_value: Tensor,

    /// Neutral element for `merge`: zero values/sum and a minimal max.
    pub fn zeros(q_shape: Shape, exp_sum_precision: DataType) PartialSoftmax {
        const stats_shape = q_shape.setDim(.hd, 1);
        return .{
            .values = Tensor.constant(q_shape, q_shape.dtype().zero()),
            .exp_sum = Tensor.constant(stats_shape, exp_sum_precision.zero()),
            .max_value = Tensor.constant(stats_shape, q_shape.dtype().minValue()),
        };
    }

    /// Combine the partial softmax of two disjoint chunks.
    pub fn merge(self: PartialSoftmax, other: PartialSoftmax) PartialSoftmax {
        // Rescale both sides to the shared maximum; their pieces then add up.
        const joint_max = self.max_value.maximum(other.max_value);
        const lhs = self.rescale(joint_max);
        const rhs = other.rescale(joint_max);
        return .{
            .max_value = joint_max,
            .values = lhs.values.add(rhs.values),
            .exp_sum = lhs.exp_sum.add(rhs.exp_sum),
        };
    }

    /// Update max_value and rescale attn and exp_sum accordingly.
    pub fn rescale(self: PartialSoftmax, max_value: Tensor) PartialSoftmax {
        const correction = self.max_value.sub(max_value).exp();
        return .{
            .max_value = max_value,
            .values = self.values.mul(correction.broad(self.values.shape())),
            .exp_sum = self.exp_sum.mul(correction.convert(self.exp_sum.dtype())),
        };
    }

    /// Divides the intermediary results by the exp_sum to get the proper attention values.
    pub fn finalize(self: PartialSoftmax) Tensor {
        const norm = self.exp_sum.broad(self.values.shape()).convert(self.values.dtype());
        return self.values.div(norm);
    }
};
/// Compute softmax over a chunk.
/// Returns intermediary results to allow aggregating later.
pub fn partialSoftmax(self: Tensor, axis: anytype) PartialSoftmax {
    const ax = self.axis(axis);
    // Clamp the max from below — presumably to guard against degenerate
    // (-inf) maxima from fully masked chunks; TODO confirm.
    const max_val = self.max(ax).maximum(Tensor.scalar(-1e16, self.dtype()));
    const shifted_exp = self.sub(max_val.broad(self.shape())).exp();
    return .{
        .values = shifted_exp,
        // The sum is accumulated in f32.
        .exp_sum = shifted_exp.convert(.f32).sum(ax),
        .max_value = max_val,
    };
}
/// Compute sdpa on a chunk, and computes a partial softmax.
/// q: (B, H, Sq, H_dim) ⊙ k: (B, H, Sk, H_dim) -> qk: (B, H, Sq, Sk)
pub fn sdpaChunk(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) PartialSoftmax {
    // this is a dupe of sdpa, but return the PartialSoftmax instead of true Attn.
    // Consider implementing sdpa from sdpaChunk.
    var q, var k, var v = .{ q_, k_, v_ };

    const err_template = "sdpa(q: {f}, k: {f}, v: {f}, attn: {?f}) is invalid !";
    const err_args = .{ q, k, v, opts.attn_mask };
    stdx.debug.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ " q is missing tags {{.h, .q, .hd}}", err_args);
    stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ " k is missing tags {{.h, .k, .hd}}", err_args);
    stdx.debug.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ " v is missing tags {{.h, .k, .hd}}", err_args);

    if (q.dim(.h) != k.dim(.h)) {
        stdx.debug.assert(@mod(q.dim(.h), k.dim(.h)) == 0, err_template ++ " Different number of heads for keys and queries, but can't repeat keys.", err_args);
        // Note: we don't try to repeat queries.
        // Repeating keys is the interesting optimisation cause it reduces KV cache memory usage.
        const num_rep: u63 = @intCast(@divExact(q.dim(.h), k.dim(.h)));
        k, v = .{ k.repeat1d(.h, num_rep), v.repeat1d(.h, num_rep) };
    }
    const attn_mask = if (opts.attn_mask) |m| m else null;

    const dims = helpers.collectDims(.{ .h, .q, .k, .hd }, &.{ q, k, v, attn_mask }, .strict) catch {
        stdx.debug.panic(err_template ++ " Inputs have incompatible shapes.", err_args);
    };
    // Fold the 1/sqrt(head_dim) scaling (or the caller-provided scale) into k.
    const sqrtHeadDim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(dims.hd)));
    const head_scaling = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, k.dtype());
    k = k.mul(head_scaling.convert(k.dtype()));

    var attn_weights = q.dot(k, .{.hd});
    if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));

    // Partial softmax over this key chunk; chunks are merged by the caller.
    const partial = partialSoftmax(attn_weights, .k);
    const attn = partial.values.dot(v, .{.k}).transpose(q.shape());
    return .{
        .values = attn,
        // The renaming is because the above dot projected values.k into .hd,
        // do the same thing on the other tensors.
        // This work because dot is a linear operation, and commutes with `PartialSoftmax.finalize`
        .exp_sum = partial.exp_sum.rename(.{ .k = .hd }).transpose(attn.shape()),
        .max_value = partial.max_value.rename(.{ .k = .hd }).transpose(attn.shape()),
    };
}
2023-06-16 14:34:18 +00:00
2023-08-14 14:24:11 +00:00
test sdpaMemEfficient {
    const platform = zml.testing.env();
    const allocator = std.testing.allocator;
    // Note we use small input vectors to have the tests run reasonably fast,
    // but don't expect speed ups with this small sizes.
    const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 10, 512, 64 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
    defer rng.deinit();
    const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
    defer rng_mask.deinit();

    // Note: we pass void here, cause Rng.normal doesn't take any runtime inputs.
    const q = rng.call({}).withTags(.{ .b, .h, .q, .hd });
    const k = rng.call({}).withTags(.{ .b, .h, .k, .hd });
    const v = rng.call({}).withTags(.{ .b, .h, .k, .hd });
    const mask = rng_mask.call({}).withTags(.{ .q, .k });

    // Reference result: plain sdpa on the full tensors.
    const ref_res = try zml.testing.compileAndCall(
        platform,
        sdpa,
        .{ q, k, v, .{ .attn_mask = mask, .scale = null } },
    );
    try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims());
    {
        // 4 k_chunks: exercises the unrolled scan path of scanKeyVal.
        const res = try zml.testing.compileAndCall(
            platform,
            sdpaMemEfficient,
            .{
                q,
                k,
                v,
                .{ .attn_mask = mask, .scale = null },
                .{ .q_chunk_size = 256, .k_chunk_size = @divExact(512, 4) },
            },
        );
        try zml.testing.expectClose(ref_res, res, 2e-3);
    }
    {
        // 16 k_chunks: exercises the stablehlo.while scan path.
        const res = try zml.testing.compileAndCall(
            platform,
            sdpaMemEfficient,
            .{
                q,
                k,
                v,
                .{ .attn_mask = mask, .scale = null },
                .{ .q_chunk_size = 256, .k_chunk_size = @divExact(512, 16) },
            },
        );
        try zml.testing.expectClose(ref_res, res, 2e-3);
    }
}
2023-08-14 14:24:11 +00:00
test "sdpaMemEfficient transposed" {
    const platform = zml.testing.env();
    const allocator = std.testing.allocator;
    // Note we use small input vectors to have the tests run reasonably fast,
    // but don't expect speed ups with this small sizes.
    const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 512, 10, 64 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
    defer rng.deinit();
    const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
    defer rng_mask.deinit();

    // Note: we pass void here, cause Rng.normal doesn't take any runtime inputs.
    // Compared to the other test, the sequence and head axes are swapped.
    const q = rng.call({}).withTags(.{ .b, .q, .h, .hd });
    const k = rng.call({}).withTags(.{ .b, .k, .h, .hd });
    const v = rng.call({}).withTags(.{ .b, .k, .h, .hd });
    const mask = rng_mask.call({}).withTags(.{ .q, .k });

    const ref_res = try zml.testing.compileAndCall(
        platform,
        sdpa,
        .{ q, k, v, .{ .attn_mask = mask, .scale = null } },
    );
    try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims());
    {
        // Two query chunks.
        const res = try zml.testing.compileAndCall(
            platform,
            sdpaMemEfficient,
            .{
                q,
                k,
                v,
                .{ .attn_mask = mask, .scale = null },
                .{ .q_chunk_size = @divExact(512, 2), .k_chunk_size = @divExact(512, 4) },
            },
        );
        try zml.testing.expectClose(ref_res, res, 1e-3);
    }
    {
        // Single query chunk.
        const res = try zml.testing.compileAndCall(
            platform,
            sdpaMemEfficient,
            .{
                q,
                k,
                v,
                .{ .attn_mask = mask, .scale = null },
                .{ .q_chunk_size = 512, .k_chunk_size = @divExact(512, 4) },
            },
        );
        try zml.testing.expectClose(ref_res, res, 1e-3);
    }
}
/// Options controlling generation. The default values correspond to greedy decoding.
pub const SamplingStrategy = struct {
    // Sample among the `topk` highest-scoring tokens; <= 1 means greedy argmax.
    topk: u32 = 1,
    // Softmax temperature; logits are divided by it before sampling.
    temperature: f32 = 1.0,
};
/// Given the output of the last layer of a LM with a `.voc` axis,
/// Compute indices for the next tokens, following the given sampling strategy.
/// Returns an integer tensor with a shape similar to the input, but without the .voc axis.
pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng) struct { Tensor, Tensor.Rng } {
    // Greedy decoding: plain argmax, the rng is returned untouched.
    if (opts.topk <= 1) {
        return .{ activations.argMax(.voc).indices.squeeze(.voc), rng };
    }

    const topk = activations.topK(opts.topk, .voc, .{});
    // After the topk, we don't have .voc values, anymore, only topk.
    var scores = topk.values.rename(.{ .voc = .topk });
    if (opts.temperature != 1.0) {
        scores = scores.scale(1 / opts.temperature);
    }

    // Gumbel reparametrization trick:
    // Adding gumbel noise and taking the argmax is equivalent
    // to sampling from the categorical distribution produced by the softmax.
    // https://en.wikipedia.org/wiki/Gumbel_distribution#Gumbel_reparametrization_tricks
    const next_rng, const gumbel_noise = rng.gumbel(scores.shape());
    const noisy = scores.add(gumbel_noise);
    const topk_idx = noisy.argMax(.topk).indices;

    // topk_idx indexes into topk.values (range [0, topk)),
    // map it back to indices in the full [0, voc) range.
    const next_tokens = topk.indices.gatherValues(.voc, topk_idx.squeeze(.topk), .{});
    return .{ next_tokens, next_rng };
}
2023-06-16 14:34:18 +00:00
test sampleTokens {
    const platform = zml.testing.env();
    const allocator = std.testing.allocator;
    const inf = std.math.inf(f32);
    var rng_buff = try zml.Tensor.Rng.init(platform, 0xdeadbeef);
    defer rng_buff._state.deinit();
    const mod = try zml.compileFn(allocator, sampleTokens, .{ Shape.init(.{ .voc = 4 }, .f32), .{ .topk = 4, .temperature = 2.0 }, zml.Tensor.Rng.shape() }, platform);
    defer mod.deinit();
    // Logits built with +/-inf so the winner is deterministic despite the gumbel noise.
    inline for (.{
        .{ [_]f32{ inf, 3.0, 2.0, 1.0 }, 0 },
        .{ [_]f32{ -inf, 3.0, -inf, -inf }, 1 },
        .{ [_]f32{ 3.0, 2, inf, inf }, 2 },
    }) |logits_expected| {
        const logits, const expected: i32 = logits_expected;
        var logits_buff = try zml.Buffer.fromArray(platform, logits);
        defer logits_buff.deinit();
        // The rng state is threaded through the calls.
        var sampled, rng_buff = mod.call(.{ logits_buff, rng_buff });
        defer sampled.deinit();
        try zml.testing.expectEqual(expected, try sampled.getValue(i32));
    }
}
/// Sampling options resolved at runtime (device scalars) rather than at compile time.
pub const DynamicSamplingStrategy = struct {
    // Compile-time upper bound on top_k; 0 means "sort the full vocab".
    max_top_k: u32,
    // Runtime scalars; see `Opts` for their host-side meaning.
    top_k: Tensor,
    temperature: Tensor,
    top_p: Tensor,
    min_p: Tensor,

    /// Host-side values used to fill the runtime scalars.
    pub const Opts = struct {
        top_k: u32,
        temperature: f32 = 1.0,
        top_p: f32 = 1.0,
        min_p: f32 = 0.0,
    };

    /// Shapes of the runtime scalars, for compilation.
    pub fn shapes(dtype: DataType, max_top_k: u32) zml.ShapeOf(DynamicSamplingStrategy) {
        const scalar_float = Shape.init(.{}, dtype);
        const scalar_i32 = Shape.init(.{}, .i32);
        return .{
            .max_top_k = max_top_k,
            .top_k = scalar_i32,
            .temperature = scalar_float,
            .top_p = scalar_float,
            .min_p = scalar_float,
        };
    }

    /// Upload the sampling options as device scalar buffers.
    pub fn makeBuffers(
        platform: zml.Platform,
        dtype: zml.DataType,
        opts: Opts,
    ) !zml.Bufferized(DynamicSamplingStrategy) {
        return .{
            .top_k = try zml.Buffer.scalar(platform, opts.top_k, .i32),
            .temperature = try zml.Buffer.scalar(platform, opts.temperature, dtype),
            .top_p = try zml.Buffer.scalar(platform, opts.top_p, dtype),
            .min_p = try zml.Buffer.scalar(platform, opts.min_p, dtype),
        };
    }
};
/// Given the output of the last layer of a LM with a `.voc` axis,
/// Compute indices for the next tokens, following the given sampling strategy.
/// The dynamic sampling strategy is more expressive but top_p requires computing the softmax.
///
/// Options are:
///
/// * top_k: only sample among the k top scoring tokens,
/// * max_top_k: limit a compilation time what is the max possible runtime value for top_k, saving memory and compute by not having to fully sort the tokens.
/// * top_p: only sample among top scoring tokens whose probabilities sum up to top_p
/// * min_p: drop tokens whose probabilities are lower than a ratio of the most likely token
pub fn sampleTokensDynamic(logits: Tensor, opts: DynamicSamplingStrategy, rng: Tensor.Rng) struct { Tensor, Tensor.Rng } {
    // Mask the logits according to top_k/top_p/min_p, then sample with the
    // gumbel trick (same approach as `sampleTokens`).
    var masked, const topk_indices = fixupLogits(logits, opts);
    const next_rng, const gumbel_noise = rng.gumbel(masked.shape());
    masked = masked.add(gumbel_noise);
    const winner = masked.argMax(.topk).indices;
    const next_tokens = topk_indices.gatherValues(.voc, winner.squeeze(.topk), .{});
    return .{ next_tokens, next_rng };
}
// Applies top_k / temperature / top_p / min_p filtering to the logits.
// Returns the filtered logits (sorted descending, renamed to .topk) and the
// matching original .voc indices.
fn fixupLogits(logits: Tensor, opts: DynamicSamplingStrategy) [2]Tensor {
    const min_inf = Tensor.constant(.{}, logits.dtype().minValue());
    // First reduce the vocab size to a reasonable sub set of candidate.
    const full_topk = if (opts.max_top_k > 0)
        logits.topK(opts.max_top_k, .voc, .{ .descending = true })
    else
        logits.sort(.voc, .{ .descending = true });
    // After the topk, we don't have .voc indices, anymore, only topk.
    var x = full_topk.values.rename(.{ .voc = .topk });
    // mask values above the dynamic top_k
    x = Tensor.iota(x.shape(), .topk).cmp(.GE, opts.top_k).select(min_inf, x);
    // NOTE(review): this multiplies by temperature, while `sampleTokens`
    // divides by it. The test below pins the multiply — confirm it's intended.
    x = x.mul(opts.temperature);
    // if there are high values in x, softmax can overflow and will create nans in full probs
    // this propagate to probs_sum and probs_max.
    const probs = x.softmax(.topk);
    const probs_sum = probs.cumulativeSum(.topk);
    const probs_max = probs.slice1d(.topk, .{ .start = 0, .end = 1 });
    const top_p = opts.top_p.broad(x.shape());
    const min_p = probs_max.mul(opts.min_p).broad(x.shape());
    // * if first candidate has very high prob, then probs_sum is always greater than top_p and candidate is full false
    // * if first candidate score is even bigger, the probs become Nan because of the softmax,
    //   then cmp is full false, and candidate is full false too.
    const candidate = probs_sum.cmp(.LE, top_p).logical(.AND, probs.cmp(.GE, min_p));
    // * so we explicitly always accept first candidate.
    const first_token = Tensor.iota(x.shape(), .topk).cmp(.EQ, Tensor.scalar(0, .i32));
    x = candidate.logical(.OR, first_token).select(x, min_inf);
    return .{ x, full_topk.indices };
}
test sampleTokensDynamic {
    const platform = zml.testing.env();
    const allocator = std.testing.allocator;
    // Shorthand for "masked out" in the expected outputs below.
    const ___ = -std.math.inf(f32);
    const logits = [_]f32{ @log(2.0), @log(1.0), @log(4.0), @log(3.0) };
    const top_k_indices = [_]i32{ 2, 3, 0, 1 };
    const logits_buff = try zml.Buffer.fromArray(platform, logits);
    const mod = try zml.compileFn(allocator, fixupLogits, .{ Shape.init(.{ .voc = logits.len }, .f32), DynamicSamplingStrategy.shapes(.f32, 0) }, platform);
    defer mod.deinit();

    const Args = struct { DynamicSamplingStrategy.Opts, [4]f32 };
    inline for ([_]Args{
        // top_k == logits.len -> just sort the input
        .{ .{ .top_k = 4 }, [_]f32{ @log(4.0), @log(3.0), @log(2.0), @log(1.0) } },
        .{ .{ .top_k = 2 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
        .{ .{ .top_k = 2, .temperature = 0.1 }, [_]f32{ @log(4.0) * 0.1, @log(3.0) * 0.1, ___, ___ } },
        // top_k == logits.len and small top_p -> make sure at least one is returned
        .{ .{ .top_k = 4, .top_p = 0.1 }, [_]f32{ @log(4.0), ___, ___, ___ } },
        .{ .{ .top_k = 4, .top_p = 0.701 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
        .{ .{ .top_k = 4, .top_p = 0.901 }, [_]f32{ @log(4.0), @log(3.0), @log(2.0), ___ } },
        // Here top_p is computed on the top 3 items, so 0.701 isn't enough anymore to allow @log(3.0)
        .{ .{ .top_k = 3, .top_p = 0.701 }, [_]f32{ @log(4.0), ___, ___, ___ } },
        // Here top_p allows the first 3 results, but min_p only accepts the first two.
        .{ .{ .top_k = 4, .top_p = 0.901, .min_p = 0.6 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
    }) |args_expected| {
        const args, const expected = args_expected;
        const new_logits, const indices = mod.call(.{ logits_buff, try DynamicSamplingStrategy.makeBuffers(platform, .f32, args) });
        try std.testing.expectEqual(top_k_indices, try indices.getValue(@TypeOf(top_k_indices)));
        try zml.testing.expectEqual(expected, try new_logits.getValue(@TypeOf(expected)));
    }
    {
        // Similar but use bf16, and uses infinity to trigger nans after the softmax.
        const bf16 = zml.floats.BFloat16;
        const mod_bf16 = try zml.compileFn(allocator, fixupLogits, .{ Shape.init(.{ .voc = logits.len }, .bf16), DynamicSamplingStrategy.shapes(.bf16, 0) }, platform);
        defer mod_bf16.deinit();

        const boost = bf16.inf;
        const nerf = bf16.minus_inf;

        const logits_buff_2 = try zml.Buffer.fromArray(platform, [4]bf16{ boost, boost, bf16.fromF32(2), nerf });
        const new_logits, const indices = mod_bf16.call(.{ logits_buff_2, try DynamicSamplingStrategy.makeBuffers(platform, .bf16, .{ .top_k = 4, .top_p = 0.9, .min_p = 0.1 }) });
        try std.testing.expectEqual([_]i32{ 0, 1, 2, 3 }, try indices.getValue([4]i32));
        // Even with NaN-producing inputs, the first candidate must be kept.
        try zml.testing.expectEqual([_]bf16{ boost, nerf, nerf, nerf }, try new_logits.getValue([4]bf16));
    }
}