const std = @import("std");
const testing = std.testing;

const stdx = @import("stdx");
const zml = @import("zml");
const Buffer = zml.Buffer;
const Tensor = zml.Tensor;
const ShapeOf = zml.ShapeOf;

const log = std.log.scoped(.llama);
/// Llama architecture, using huggingface transformers naming.
/// Dimensions of activations: {.b, .s, .d}
pub const LlamaLM = struct {
    /// Subset of the huggingface `config.json` fields read by this implementation.
    pub const Config = struct {
        bos_token_id: u32,
        // HF checkpoints store either a single eos id or a list of them.
        eos_token_id: stdx.json.Union(union(enum) {
            int: u32,
            ints: []u32,
        }),
        head_dim: ?u32,
        hidden_size: u32,
        num_hidden_layers: u32,
        num_attention_heads: u32,
        num_key_value_heads: u32,
        rope_theta: f32,
        max_position_embeddings: u32,
        rms_norm_eps: f32,
        // Huggingface lays out rope pairs sequentially; the original Llama code interleaves them.
        hf_rope_impl: bool = true,
        // When true, the output projection reuses the token-embedding weights and
        // no separate `lm_head.weight` tensor exists in the checkpoint.
        tie_word_embeddings: bool = false,
        rope_scaling: zml.nn.RopeOpts.Scaling = .{ .default = {} },
    };

    /// Runtime options that are not part of the checkpoint config.
    pub const Options = struct {
        sampling_strategy: ?zml.nn.SamplingStrategy,
        max_seq_len: u32,
    };

    // null when `tie_word_embeddings` is set (embeddings are reused instead).
    lm_head: ?zml.nn.Linear,
    model: Llama,
    // Options controlling generation
    gen_opts: zml.nn.SamplingStrategy = .{},
    config: Config,

    /// Loads all weights from `store` and pushes config values down into the submodules.
    /// `allocator` owns the returned `layers` slice; no deinit is provided here.
    pub fn init(allocator: std.mem.Allocator, config: Config, options: Options, store: zml.aio.BufferStore) !LlamaLM {
        const rope_opts: zml.nn.RopeOpts = .{
            .layout = if (config.hf_rope_impl) .sequential else .interleaved,
            .freq_base = config.rope_theta,
            .scaling = config.rope_scaling,
        };

        const layers = try allocator.alloc(TransformerLayer, config.num_hidden_layers);
        var prefix = try zml.aio.PrefixBuilder.initCapacity(allocator, 1024);
        try prefix.push(stdx.noalloc, "model.layers");
        for (0.., layers) |i, *layer| {
            try prefix.pushDigit(stdx.noalloc, i);
            defer prefix.pop();

            var self_attn = try zml.aio.populateModelWithPrefix(SelfAttn, allocator, store, prefix.concat("self_attn"));
            self_attn.num_heads = config.num_attention_heads;
            self_attn.num_kv_heads = config.num_key_value_heads;
            self_attn.rope_opts = rope_opts;
            // Shard attention projections: axis 0 for q/k/v, axis 1 for the output projection.
            self_attn.q_proj.weight = self_attn.q_proj.weight.withSharding(.{0});
            self_attn.k_proj.weight = self_attn.k_proj.weight.withSharding(.{0});
            self_attn.v_proj.weight = self_attn.v_proj.weight.withSharding(.{0});
            self_attn.o_proj.weight = self_attn.o_proj.weight.withSharding(.{1});

            var input_layernorm = try zml.aio.populateModelWithPrefix(RmsNorm, allocator, store, prefix.concat("input_layernorm"));
            input_layernorm.eps = config.rms_norm_eps;
            var post_attention_layernorm = try zml.aio.populateModelWithPrefix(RmsNorm, allocator, store, prefix.concat("post_attention_layernorm"));
            post_attention_layernorm.eps = config.rms_norm_eps;

            var mlp = try zml.aio.populateModelWithPrefix(Mlp, allocator, store, prefix.concat("mlp"));
            mlp.up_proj.weight = mlp.up_proj.weight.withSharding(.{0});
            mlp.gate_proj.weight = mlp.gate_proj.weight.withSharding(.{0});
            mlp.down_proj.weight = mlp.down_proj.weight.withSharding(.{1});

            layer.* = .{
                .self_attn = self_attn,
                .input_layernorm = input_layernorm,
                .post_attention_layernorm = post_attention_layernorm,
                .mlp = mlp,
            };
        }

        var lm_head: ?zml.nn.Linear = null;
        if (!config.tie_word_embeddings) {
            lm_head = .{ .weight = store.getTensor("lm_head.weight") };
            if (options.sampling_strategy) |gen_opts| {
                // Greedy decoding (topk == 1) benefits from sharding the vocab axis.
                if (gen_opts.topk == 1)
                    lm_head.?.weight = lm_head.?.weight.withSharding(.{0});
            }
        }

        return .{
            .config = config,
            .gen_opts = options.sampling_strategy orelse .{},
            .model = .{
                // Weights
                .layers = layers,
                .embed_tokens = .{ .weight = store.getTensor("model.embed_tokens.weight") },
                .norm = .{
                    .weight = store.getTensor("model.norm.weight"),
                    .eps = config.rms_norm_eps,
                },
                // Push down some configs
                .max_seq_len = options.max_seq_len,
                .num_heads = config.num_attention_heads,
                .num_kv_heads = config.num_key_value_heads,
                // Reuse the rope options built above instead of duplicating the literal.
                .rope_opts = rope_opts,
            },
            .lm_head = lm_head,
        };
    }

    /// Predicts the token at `token_index` position.
    /// Returns:
    ///  - updated `tokens`,
    ///  - updated KV cache
    ///  - a Rng state to allow for probabilistic generation
    pub fn forward(
        self: LlamaLM,
        tokens_: Tensor,
        token_index: Tensor,
        kv_cache: KvCache,
        rng: Tensor.Rng,
    ) struct { Tensor, KvCache, Tensor.Rng } {
        stdx.debug.assert(tokens_.dtype() == .u32 and tokens_.rank() >= 1 and token_index.dtype() == .u32 and token_index.rank() <= 1, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {f} and {f}", .{ tokens_, token_index });
        const tokens = tokens_.withPartialTags(.{.s});
        const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, token_index, kv_cache });
        const new_tokens, const new_rng = self.sampleTokens(self.lm_head, out, rng, self.gen_opts);
        return .{ new_tokens.convert(tokens.dtype()).reuseBuffer(tokens), updated_kv_cache, new_rng };
    }

    /// Projects hidden activations to vocab logits (through `lm_head_`, or the tied
    /// embedding matrix when it is null) and samples next tokens according to `opts`.
    pub fn sampleTokens(
        self: LlamaLM,
        lm_head_: ?zml.nn.Linear,
        out_: Tensor,
        rng: Tensor.Rng,
        opts: zml.nn.SamplingStrategy,
    ) struct { Tensor, Tensor.Rng } {
        const out = out_.withPartialTags(.{ .s, .d });

        var logits = blk: {
            if (lm_head_) |lm_head| {
                break :blk zml.call(lm_head, .forward, .{out});
            } else {
                // Tied embeddings: use the input embedding matrix as the output projection.
                break :blk self.model.embed_tokens.weight.withTags(.{ .voc, .d }).dot(out, .{.d});
            }
        };

        // Make sure the vocab axis is tagged `.voc` whichever branch produced the logits.
        if (logits.shape().hasTag(.voc) == null)
            logits = logits.rename(.{ .d = .voc });

        const next_tokens, const new_rng = zml.nn.sampleTokens(logits, opts, rng);
        return .{ next_tokens, new_rng };
    }

    /// Increments `token_index` in place (the device buffer is donated across steps).
    pub fn increment(_: u8, token_index: Tensor) Tensor {
        return token_index.addConstant(1).reuseBuffer(token_index);
    }
};
pub const Llama = struct {
    embed_tokens: zml.nn.TokenEmbedding,
    norm: RmsNorm,
    layers: []TransformerLayer,
    max_seq_len: u32 = 0,
    num_heads: u32 = 32,
    num_kv_heads: u32 = 32,
    rope_opts: zml.nn.RopeOpts = .{
        .layout = .interleaved,
        .freq_base = 10_000,
    },

    /// Forward one token, using KV cache for previous tokens.
    /// Returns result and updated KV cache.
    pub fn forward(self: Llama, tokens: Tensor, token_index: Tensor, kv_cache: KvCache) struct { Tensor, KvCache } {
        const embeds = embed(self.embed_tokens, tokens);

        var hidden = embeds;
        var updated_kv_cache = kv_cache;
        // Thread the per-layer KV cache view through every transformer layer.
        for (self.layers, 0..) |layer, i| {
            hidden, updated_kv_cache = zml.call(layer, .forward, .{ hidden, token_index, updated_kv_cache.atLayer(i) });
        }
        const output = zml.call(self.norm, .forward, .{hidden});

        return .{ output, updated_kv_cache.reuseBuffer(kv_cache) };
    }

    /// Looks up token embeddings and tags the hidden axis as `.d`.
    pub fn embed(embed_tokens_: zml.nn.TokenEmbedding, tokens_: Tensor) Tensor {
        return zml.call(embed_tokens_, .forward, .{tokens_}).withPartialTags(.{.d});
    }
};
pub const TransformerLayer = struct {
    input_layernorm: RmsNorm,
    self_attn: SelfAttn,
    post_attention_layernorm: RmsNorm,
    mlp: Mlp,

    /// Pre-norm transformer block: x + attn(norm(x)), then x + mlp(norm(x)).
    /// Returns the new hidden state (reusing `x0`'s buffer) and the updated KV cache.
    pub fn forward(
        self: TransformerLayer,
        x0: Tensor,
        token_index: Tensor,
        kv_cache: KvCache,
    ) struct { Tensor, KvCache } {
        // Self Attention
        //log.debug("TransformerLayer({f}) -> {f}", .{ x0, self.input_layernorm.forward(x0) });
        stdx.debug.assert(x0.rank() >= 2 and x0.shape().hasTags(.{ .s, .d }), "TransformerLayer expected input shape: {{..., .s, .d}}, received: {f}", .{x0});

        const x0_normalized = zml.call(self.input_layernorm, .forward, .{x0});
        const delta0, const updated_kv_cache = zml.call(self.self_attn, .forward, .{ x0_normalized, token_index, kv_cache });
        const x1 = x0.add(delta0);
        // Fully Connected
        const x1_normalized = zml.call(self.post_attention_layernorm, .forward, .{x1});
        const x2 = zml.call(self.mlp, .forward, .{x1_normalized}).add(x1);
        return .{ x2.reuseBuffer(x0), updated_kv_cache };
    }
};
const RmsNorm = struct {
    weight: Tensor,
    eps: f32 = 1e-5,

    /// L2 normalization of input tensor along `.d` axis.
    pub fn forward(self: RmsNorm, input: Tensor) Tensor {
        // Only add the `.d` tag when the input isn't already fully tagged.
        const x = if (input.shape().isFullyTagged()) input else input.withPartialTags(.{.d});

        const normalized = zml.nn.rmsNorm(x, .d, self.eps);
        // Scale by the learned weight, converted to the activation dtype and broadcast.
        return normalized.mul(self.weight.convert(x.dtype()).withTags(.{.d}).broad(x.shape()));
    }
};
const Mlp = struct {
    up_proj: zml.nn.Linear, // (dim -> hidden_dim)
    gate_proj: zml.nn.Linear, // (dim -> hidden_dim)
    down_proj: zml.nn.Linear, // (hidden_dim -> dim)

    /// SwiGLU feed-forward: down_proj(silu(gate_proj(x)) * up_proj(x)).
    pub fn forward(self: Mlp, x: Tensor) Tensor {
        const proj = zml.call(self.up_proj, .forward, .{x});
        var output = zml.call(self.gate_proj, .forward, .{x});
        output = output.silu().mul(proj);
        return zml.call(self.down_proj, .forward, .{output});
    }
};
pub const SelfAttn = struct {
    q_proj: zml.nn.Linear,
    k_proj: zml.nn.Linear,
    v_proj: zml.nn.Linear,
    // Optional per-head query/key norms — null for checkpoints that don't ship them.
    q_norm: ?RmsNorm,
    k_norm: ?RmsNorm,
    o_proj: zml.nn.Linear,

    num_heads: i64 = undefined,
    num_kv_heads: i64 = 0,
    rope_opts: zml.nn.RopeOpts = undefined,

    /// Self Attention.
    ///   - If token_index is set, x is assumed to be the representation of one new token,
    ///   and kv_cache will be read for the previous tokens.
    ///   - If token_index is not set, x is assumed to be the representation of all tokens
    ///   since the beginning of the sequence, and kv_cache won't be read.
    ///   In both case, kv_cache will be updated with the computed key and value.
    ///   x: {.b, .s, .d } -> .{.b, .s, .d}
    pub fn forward(
        self: SelfAttn,
        x: Tensor,
        token_index: Tensor,
        kv_cache: KvCache,
    ) struct { Tensor, KvCache } {
        // num_kv_heads == 0 means "no grouped-query attention": one KV head per query head.
        const num_kv_heads = if (self.num_kv_heads > 0) self.num_kv_heads else self.num_heads;
        var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto }).withSharding(.{.h});
        var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
        var v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});

        // Generate the attention mask.
        const seq_len = kv_cache.k.dim(.k);
        var attn_mask = zml.nn.causalAttnMask(.{ .q = seq_len, .k = seq_len }, x.dtype(), null);
        // Note: in Pytorch it would be very inefficient to generate the full attn_mask,
        // then slice into it, but XLA is able to optimize this correctly.
        attn_mask = attn_mask.gatherSlices(zml.Shape.init(.{ .q = x.dim(.s) }, attn_mask.dtype()), token_index.reshape(.{ .coord = 1 }), .{});

        // In self-attention, .s axis is used both for keys and queries.
        const pos_index = b: {
            const temp = Tensor.arange(.{ .end = x.dim(.s) }, token_index.dtype()).withTags(.{.s}).broad(zml.Shape.init(.{ .s = x.dim(.s) }, token_index.dtype()));
            break :b temp.add(token_index.broad(temp.shape()));
        };

        // Apply the optional q/k norms over the per-head dim (temporarily renamed to .d).
        if (self.q_norm) |norm| q = norm.forward(q.rename(.{ .hd = .d })).rename(.{ .d = .hd });
        if (self.k_norm) |norm| k = norm.forward(k.rename(.{ .hd = .d })).rename(.{ .d = .hd });

        q = zml.nn.rope(q, pos_index, self.rope_opts);
        k = zml.nn.rope(k, pos_index, self.rope_opts);
        q = q.rename(.{ .s = .q });
        k = k.rename(.{ .s = .k });
        v = v.rename(.{ .s = .k });

        // Update the cache, then attend over the full cached sequence.
        const dtype = q.dtype();
        const new_kv_cache = kv_cache.update(k, v, token_index);
        k = new_kv_cache.keys().convert(dtype);
        v = new_kv_cache.values().convert(dtype);

        const attn_output = zml.nn.sdpa(q, k, v, .{ .attn_mask = attn_mask, .allow_cudnn = true });
        // const attn_output = zml.nn.sdpaMemEfficient(q, k, v, .{ .attn_mask = attn_mask }, .{ .q_chunk_size = 4096, .k_chunk_size = 1024 });

        const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s });
        return .{ zml.call(self.o_proj, .forward, .{attn}), new_kv_cache };
    }
};
pub const KvCache = struct {
    // Both k and v carry a `.layer` axis; `layer_index` selects the active layer.
    k: Tensor,
    v: Tensor,
    layer_index: Tensor,

    /// Graph-time constructor for the cache tensors.
    pub fn init(kv_shape: zml.Shape) KvCache {
        // The KV-cache is initialized with ones to detect reads of uninitialized memory.
        return .{
            .k = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
            .v = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
            // NOTE(review): -1 written into a .u32 scalar — presumably wraps to mean
            // "no layer selected yet"; confirm against Tensor.scalar semantics.
            .layer_index = Tensor.scalar(-1, .u32),
        };
    }

    /// Shape-only counterpart of `init`, for compiling without real buffers.
    pub fn initShape(kv_shape: zml.Shape) ShapeOf(KvCache) {
        return .{
            .k = kv_shape,
            .v = kv_shape,
            .layer_index = zml.Shape.init(.{}, .u32),
        };
    }

    /// Allocates device buffers for the cache; k/v contents are left uninitialized.
    pub fn initBuffer(kv_shape: zml.Shape, platform: zml.Platform) !zml.Bufferized(KvCache) {
        return .{
            .k = try zml.Buffer.uninitialized(platform, kv_shape, .{}),
            .v = try zml.Buffer.uninitialized(platform, kv_shape, .{}),
            .layer_index = try zml.Buffer.scalar(platform, 0, .u32),
        };
    }

    /// Keys of the layer selected by `layer_index`, with the `.layer` axis squeezed out.
    pub fn keys(self: KvCache) Tensor {
        return self.k.dynamicSlice(.{ .layer = Tensor.DynSlice{ .start = self.layer_index, .len = 1 } }).squeeze(.layer);
    }

    /// Values of the layer selected by `layer_index`, with the `.layer` axis squeezed out.
    pub fn values(self: KvCache) Tensor {
        return self.v.dynamicSlice(.{ .layer = Tensor.DynSlice{ .start = self.layer_index, .len = 1 } }).squeeze(.layer);
    }

    /// Scatters `new_k`/`new_v` into the current layer.
    /// When `token_index` is set only those sequence positions are overwritten;
    /// otherwise the whole `.k` axis of the layer is rewritten.
    pub fn update(self: KvCache, new_k: Tensor, new_v: Tensor, token_index: ?Tensor) KvCache {
        const k_shape = self.k.shape().drop(.layer);
        var layer = self.layer_index;
        // Broadcast the layer coordinate to match the token indices when present.
        layer = if (token_index) |idx| layer.broad(idx.shape()) else layer;
        return if (token_index) |idx| .{
            .k = self.k.scatterSlices(
                .{ .layer = layer, .k = idx },
                new_k.convert(self.k.dtype()).transpose(k_shape),
                .{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
            ).reuseBuffer(self.k),
            .v = self.v.scatterSlices(
                .{ .layer = layer, .k = idx },
                new_v.convert(self.v.dtype()).transpose(k_shape),
                .{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
            ).reuseBuffer(self.v),
            .layer_index = self.layer_index,
        } else .{
            .k = self.k.scatterSlices(
                .{ .layer = layer },
                new_k.convert(self.k.dtype()).transpose(k_shape),
                .{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
            ).reuseBuffer(self.k),
            .v = self.v.scatterSlices(
                .{ .layer = layer },
                new_v.convert(self.v.dtype()).transpose(k_shape),
                .{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
            ).reuseBuffer(self.v),
            .layer_index = self.layer_index,
        };
    }

    /// Returns a view of the same cache tensors pointing at `layer_index`.
    pub fn atLayer(self: KvCache, layer_index: usize) KvCache {
        return .{
            .k = self.k,
            .v = self.v,
            .layer_index = Tensor.scalar(layer_index, .u32),
        };
    }

    /// Marks this cache's tensors as reusing `other`'s device buffers (donation).
    pub fn reuseBuffer(self: KvCache, other: KvCache) KvCache {
        return .{
            .k = self.k.reuseBuffer(other.k),
            .v = self.v.reuseBuffer(other.v),
            .layer_index = self.layer_index.reuseBuffer(other.layer_index),
        };
    }
};