const std = @import("std");

const stdx = @import("stdx");

const zml = @import("zml.zig");
const Tensor = zml.Tensor;
/// Multiplies a matrix or a vector with a tensor,
/// following the semantic of pytorch `@` operator.
/// When both sides are matrices, it's the textbook matrix multiplication:
/// `matmul(.{ 8, 9 }, .{ 9, 10 }) -> .{ 8, 10 }`
/// When one of the input is a tensor, it assumes the first dimensions are batches,
/// and the last two ones are used for the regular matmul.
/// * `matmul(.{10}, .{10}) -> .{}`
/// * `matmul(.{ 12, 20 }, .{ 11, 20, 15 }) -> .{ 11, 12, 15 }`
pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor {
    stdx.debug.assert(lhs.rank() >= 1 and rhs.rank() >= 1, "Can't matmul({f}, {f}) ! The two tensors need to have at least rank 1.", .{ lhs, rhs });

    // Contract lhs' last axis with rhs' second-to-last axis
    // (or rhs' only axis when rhs is a vector).
    const contracting = [_][2]i8{.{ -1, if (rhs.rank() >= 2) rhs.rank() - 2 else 0 }};
    if (lhs.rank() == 1 or rhs.rank() <= 2) {
        // When lhs is a vector or rhs is small the torch semantics match the dot_general semantics and life is easy.
        return lhs.dotGeneral(rhs, &contracting, &.{});
    }

    stdx.debug.assert(lhs.rank() == 2, "Can't matmul({f}, {f}) ! One of the two tensors need to have a rank less than 2.", .{ lhs, rhs });

    // Pytorch treats the extra dimensions of rhs has batching dimensions,
    // and implicitly broadcast lhs along those.
    // We make this broadcasting explicit.
    var left_shape = rhs.shape();
    left_shape._dims.set(left_shape.axis(-2), lhs.dim(-2));
    left_shape._tags.set(left_shape.axis(-2), lhs.shape().tag(-2));
    left_shape._dims.set(left_shape.axis(-1), lhs.dim(-1));
    left_shape._tags.set(left_shape.axis(-1), lhs.shape().tag(-1));
    const lhs_broad = lhs.broadcastLeft(left_shape);
    const n_batching_axes = rhs.rank() - lhs.rank();
    var batching: [Tensor.MAX_RANK][2]i8 = undefined;
    for (0..n_batching_axes) |i| {
        batching[i] = .{ @intCast(i), @intCast(i) };
    }
    return lhs_broad.dotGeneral(rhs, &contracting, batching[0..n_batching_axes]);
}
test matmul {
    const platform = zml.testing.env();
    var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
    defer comp.deinit();
    comp.activate();
    defer comp.deactivate();

    // (lhs shape, rhs shape, expected result shape) — generated with pytorch.
    inline for (.{
        .{ .{20}, .{20}, .{} },
        .{ .{20}, .{ 20, 15 }, .{15} },
        .{ .{20}, .{ 11, 20, 15 }, .{ 11, 15 } },
        .{ .{20}, .{ 9, 11, 20, 15 }, .{ 9, 11, 15 } },
        .{ .{20}, .{ 7, 9, 11, 20, 15 }, .{ 7, 9, 11, 15 } },
        .{ .{20}, .{ 5, 7, 9, 11, 20, 15 }, .{ 5, 7, 9, 11, 15 } },
        .{ .{ 12, 20 }, .{20}, .{12} },
        .{ .{ 12, 20 }, .{ 20, 15 }, .{ 12, 15 } },
        .{ .{ 12, 20 }, .{ 11, 20, 15 }, .{ 11, 12, 15 } },
        .{ .{ 12, 20 }, .{ 9, 11, 20, 15 }, .{ 9, 11, 12, 15 } },
        .{ .{ 12, 20 }, .{ 7, 9, 11, 20, 15 }, .{ 7, 9, 11, 12, 15 } },
        .{ .{ 12, 20 }, .{ 5, 7, 9, 11, 20, 15 }, .{ 5, 7, 9, 11, 12, 15 } },
        .{ .{ 10, 12, 20 }, .{20}, .{ 10, 12 } },
        .{ .{ 10, 12, 20 }, .{ 20, 15 }, .{ 10, 12, 15 } },
        .{ .{ 8, 10, 12, 20 }, .{20}, .{ 8, 10, 12 } },
        .{ .{ 8, 10, 12, 20 }, .{ 20, 15 }, .{ 8, 10, 12, 15 } },
        .{ .{ 6, 8, 10, 12, 20 }, .{20}, .{ 6, 8, 10, 12 } },
        .{ .{ 6, 8, 10, 12, 20 }, .{ 20, 15 }, .{ 6, 8, 10, 12, 15 } },
        .{ .{ 4, 6, 8, 10, 12, 20 }, .{20}, .{ 4, 6, 8, 10, 12 } },
        .{ .{ 4, 6, 8, 10, 12, 20 }, .{ 20, 15 }, .{ 4, 6, 8, 10, 12, 15 } },
    }) |case| {
        const lhs_shape, const rhs_shape, const expected_shape = case;
        const lhs = Tensor.constant(lhs_shape, .{ .f32 = 0.0 });
        const rhs = Tensor.constant(rhs_shape, .{ .f32 = 0.0 });
        const out = matmul(lhs, rhs);
        try std.testing.expectEqualSlices(i64, &expected_shape, out.dims());
    }
}
/// Inserts a 1-dim axis at the given position.
/// Negative indexes are handled like pytorch, ie they are relative to the returned shape:
/// - `.{5, 4}.unsqueeze(1)` returns `.{5, 1, 4}`
/// - `.{5, 4}.unsqueeze(-1)` returns `.{5, 4, 1}`
pub fn unsqueeze(
    self: Tensor,
    axis_: anytype,
) Tensor {
    stdx.debug.assert(self.rank() < Tensor.MAX_RANK - 1, "Can't unsqueeze {f}, it's already at max rank.", .{self});

    // Integer axes may be negative and are interpreted relative to the
    // *result* rank (rank + 1), hence the custom handling; anything else
    // (e.g. a tag) goes through the regular axis resolution.
    const insert_at = switch (@typeInfo(@TypeOf(axis_))) {
        .int, .comptime_int => if (axis_ < 0)
            @as(i8, self.rank()) + 1 + axis_
        else
            self.axis(axis_),
        else => self.axis(axis_),
    };
    return self.insertAxes(insert_at, .{._});
}
test unsqueeze {
    const Local = struct {
        pub fn _fwd(x: Tensor) Tensor {
            // .{8} -> .{1, 8} -> .{1, 8, 1} -> .{1, 8, 1, 1}
            var out = x;
            out = unsqueeze(out, 0);
            out = unsqueeze(out, -1);
            out = unsqueeze(out, -1);
            return out;
        }
    };

    const platform = zml.testing.env();
    const x = try zml.Buffer.fromArray(platform, @as([8]f16, undefined));
    const res = try zml.testing.compileAndCall(platform, Local._fwd, .{x});
    try zml.testing.expectEqualShapes(zml.Shape.init(.{ 1, 8, 1, 1 }, .f16), res.shape());
}
/// Given an input images with .{ .c, .w, .h } tags,
/// shuffle values between the channel (.c), width (.w) and height (.h) axes.
/// pixelShuffle(.{ .c, .w, .h }, u) -> .{ .c / u / u, .w * u, .h * u}
/// ref: https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#pixelshuffle
pub fn pixelShuffle(tensor: Tensor, upscale_factor: u32) Tensor {
    const shape = tensor.shape();
    // Fixed typo "invalide" -> "invalid" in both assert messages.
    stdx.debug.assert(shape.hasTags(.{ .c, .w, .h }), "pixelShuffle({f}) is invalid. Missing tags {{.c, .w, .h}}", .{tensor});

    stdx.debug.assert(@mod(shape.dim(.c), upscale_factor * upscale_factor) == 0, "pixelShuffle({f}) is invalid. Number of channels {} isn't divisible by upscale factor {}**2", .{ tensor, shape.dim(.c), upscale_factor });

    // Split .c into (.c, .upscale_h, .upscale_w), interleave the upscale axes
    // with the spatial ones, then merge them into the upscaled .h and .w.
    const s = tensor.splitAxis(.c, .{ .c = -1, .upscale_h = upscale_factor, .upscale_w = upscale_factor });
    const perm = s.shape().contiguousPerm(.{ .h, .upscale_h, .w, .upscale_w });
    const cont = s.transpose(perm.constSlice());
    // Final transpose restores the caller's original axis order.
    return cont.merge(.{ .h = .{ .h, .upscale_h }, .w = .{ .w, .upscale_w } }).transpose(tensor.shape());
}
test pixelShuffle {
    const platform = zml.testing.env();
    const upscale_factor = 3;

    // Input: .{ .b = 1, .c = 9, .h = 4, .w = 4 } filled with 0..143.
    const shape = zml.Shape.init(.{ .b = 1, .c = 9, .h = 4, .w = 4 }, .i32);
    const input = input: {
        var digits: [9 * 4 * 4]i32 = undefined;
        for (&digits, 0..) |*d, i| d.* = @intCast(i);
        break :input try zml.Buffer.fromSlice(platform, shape, &digits);
    };

    const output = try zml.testing.compileAndCall(platform, pixelShuffle, .{ input, upscale_factor });

    // Expected: .{ .b = 1, .c = 1, .h = 12, .w = 12 }.
    const exp = zml.HostBuffer.fromArray(&[1][1][12][12]i32{.{.{
        .{ 0, 16, 32, 1, 17, 33, 2, 18, 34, 3, 19, 35 },
        .{ 48, 64, 80, 49, 65, 81, 50, 66, 82, 51, 67, 83 },
        .{ 96, 112, 128, 97, 113, 129, 98, 114, 130, 99, 115, 131 },
        .{ 4, 20, 36, 5, 21, 37, 6, 22, 38, 7, 23, 39 },
        .{ 52, 68, 84, 53, 69, 85, 54, 70, 86, 55, 71, 87 },
        .{ 100, 116, 132, 101, 117, 133, 102, 118, 134, 103, 119, 135 },
        .{ 8, 24, 40, 9, 25, 41, 10, 26, 42, 11, 27, 43 },
        .{ 56, 72, 88, 57, 73, 89, 58, 74, 90, 59, 75, 91 },
        .{ 104, 120, 136, 105, 121, 137, 106, 122, 138, 107, 123, 139 },
        .{ 12, 28, 44, 13, 29, 45, 14, 30, 46, 15, 31, 47 },
        .{ 60, 76, 92, 61, 77, 93, 62, 78, 94, 63, 79, 95 },
        .{ 108, 124, 140, 109, 125, 141, 110, 126, 142, 111, 127, 143 },
    }}});
    try zml.testing.expectClose(exp, output, 0);
}
/// Implementation of `torch.roll`.
///
/// Note: at the difference of Pytorch, shifts need to be explicitly repeated, even if they are the same for all axes.
/// ref: https://pytorch.org/docs/stable/generated/torch.roll.html
pub fn roll(self: Tensor, shifts: []const i64, axes_: []const i8) Tensor {
    // TODO(hugo) accept following syntax: x.roll(.{ .a = 5, .b = 8 })
    stdx.debug.assert(self.rank() > 0 and shifts.len == axes_.len, "Shifts length ({d}) and dims length ({d}) are not equal, we expect the same length.", .{ shifts.len, axes_.len });

    if (shifts.len != 1 or axes_.len != 1) {
        // Peel off the first (shift, axis) pair, then recurse on the rest.
        const head_rolled = roll(self, shifts[0..1], axes_[0..1]);
        return roll(head_rolled, shifts[1..], axes_[1..]);
    }

    // Single-axis roll: gather with rotated indices
    // [start, start + 1, ...] mod dim, where start = (dim - shift) mod dim.
    const a = self.axis(axes_[0]);
    const start = @mod(self.dim(a) - shifts[0], self.dim(a));
    const idx = Tensor.arange(.{ .start = start, .end = start + self.dim(a) }, .f32);
    const divisor: f32 = @floatFromInt(self.dim(a));
    return self.gather_(&.{a}, &.{idx.fmod(divisor).convert(.i32)}, .{});
}
test roll {
    const platform = zml.testing.env();
    // Fixed input typo: first element was 2, but the expectation contains a 1
    // and roll only permutes values, so the input must be 1..8.
    // Rows: .{1,2}, .{3,4}, .{5,6}, .{7,8}.
    const input = try zml.Buffer.fromSlice(platform, .{ 4, 2 }, &[_]f32{ 1, 2, 3, 4, 5, 6, 7, 8 });
    const res = try zml.testing.compileAndCall(
        platform,
        roll,
        .{ input, &[_]i64{ 2, 1 }, &[_]i8{ 0, 1 } },
    );
    // Roll by 2 along axis 0 -> rows {5,6},{7,8},{1,2},{3,4};
    // then by 1 along axis 1 -> {6,5},{8,7},{2,1},{4,3}.
    const expectation = zml.HostBuffer.fromSlice(.{ 4, 2 }, &[_]f32{ 6, 5, 8, 7, 2, 1, 4, 3 });
    // Values are exact small integers: use zero tolerance (the former 1e0
    // tolerance was masking the input typo).
    try zml.testing.expectClose(expectation, res, 0);
}
pub const MeshgridIndexing = enum { xy, ij };

/// Mimic Pytorch and Numpy api.
/// The .ij mode is just calling to `zml.nn.cartesianProduct`
/// and has simple semantics.
/// The .xy mode swap the role of the first two vectors, it's generally best
/// to rewrite the calling code to use .ij mode if possible.
/// See Numpy docs:
/// https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html#numpy.meshgrid
/// - In the 2-D case with vectors of length M and N:
///   * for 'ij' indexing, outputs are of shape (M, N)
///   * for 'xy' indexing, outputs are of shape (N, M)
/// - In the 3-D case with vectors of length M, N and P:
///   * for 'ij' indexing, outputs are of shape (M, N, P)
///   * for 'xy' indexing, outputs are of shape (N, M, P)
pub fn meshgrid(comptime N: u3, vectors: [N]Tensor, indexing: MeshgridIndexing) [N]Tensor {
    stdx.debug.assertComptime(vectors.len >= 1, "Invalid meshgrid. No input.", .{});
    stdx.debug.assertComptime(vectors.len <= Tensor.MAX_RANK, "Invalid meshgrid(...). Too many inputs: {}", .{vectors.len});

    if (vectors.len == 1) return vectors;
    return switch (indexing) {
        .ij => zml.Tensor.cartesianProduct(N, vectors),
        .xy => {
            // Swap the first two inputs, compute the .ij grid,
            // then swap the first two outputs back.
            var swapped = vectors;
            std.mem.swap(Tensor, &swapped[0], &swapped[1]);
            var grids = zml.Tensor.cartesianProduct(N, swapped);
            std.mem.swap(Tensor, &grids[0], &grids[1]);
            return grids;
        },
    };
}
test meshgrid {
    const platform = zml.testing.env();
    const x = try zml.Buffer.fromSlice(platform, .{6}, &[_]i32{ 0, 1, 2, 3, 4, 5 });
    const y = try zml.Buffer.fromSlice(platform, .{4}, &[_]i32{ 0, 1, 2, 3 });

    const Local = struct {
        pub fn _meshgrid2(a: Tensor, b: Tensor, indexing: MeshgridIndexing) [2]Tensor {
            return meshgrid(2, .{ a, b }, indexing);
        }
    };

    // Only test .xy mode, since .ij is just calling cartesianProduct which
    // got its own tests.
    {
        const xs, const ys = try zml.testing.compileAndCall(platform, Local._meshgrid2, .{ x, y, .xy });
        try std.testing.expectEqualSlices(i64, &.{ 4, 6 }, xs.dims());
        try std.testing.expectEqualSlices(i64, &.{ 4, 6 }, ys.dims());
        // xs repeats x along rows, ys repeats y along columns.
        try std.testing.expectEqualDeep(
            [4][6]i32{
                .{ 0, 1, 2, 3, 4, 5 },
                .{ 0, 1, 2, 3, 4, 5 },
                .{ 0, 1, 2, 3, 4, 5 },
                .{ 0, 1, 2, 3, 4, 5 },
            },
            try xs.getValue([4][6]i32),
        );
        try std.testing.expectEqualDeep(
            [4][6]i32{
                .{ 0, 0, 0, 0, 0, 0 },
                .{ 1, 1, 1, 1, 1, 1 },
                .{ 2, 2, 2, 2, 2, 2 },
                .{ 3, 3, 3, 3, 3, 3 },
            },
            try ys.getValue([4][6]i32),
        );
    }
}
/// Flattens the given axis and the next one, into one new axis.
pub fn flatten(self: Tensor, axis_: anytype) Tensor {
    const a = self.axis(axis_);
    // Merging needs a successor axis, so the last axis can't be flattened.
    stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {f} on the last axis {}.", .{ self, axis_ });
    return self.reshape(self._shape.mergeAxis(a, .{ a, a + 1 }));
}