2023-01-03 10:21:07 +00:00
const std = @import ( " std " ) ;
2023-06-27 14:23:22 +00:00
const stdx = @import ( " stdx " ) ;
2023-01-03 10:21:07 +00:00
const zml = @import ( " zml " ) ;
2023-06-27 14:23:22 +00:00
const testing = std . testing ;
2023-01-03 10:21:07 +00:00
const Buffer = zml . Buffer ;
const Tensor = zml . Tensor ;
const ShapeOf = zml . ShapeOf ;
2023-06-27 14:23:22 +00:00
const gguf = zml . io . gguf ;
2023-01-03 10:21:07 +00:00
const expectClose = zml . testing . expectClose ;
2023-06-27 14:23:22 +00:00
const log = std . log . scoped ( . llama ) ;
2023-01-03 10:21:07 +00:00
/// Llama architecture, using huggingface transformers naming.
/// Dimensions of activations: {.b, .s, .d}
pub const LlamaLM = struct {
    /// Hyper-parameters read from the checkpoint's configuration.
    pub const Config = struct {
        bos_token_id: u32,
        /// Either a single end-of-sequence id or a list of them.
        eos_token_id: stdx.json.Union(union(enum) {
            int: u32,
            ints: []u32,
        }),
        num_hidden_layers: usize,
        num_attention_heads: usize,
        num_key_value_heads: usize,
        rope_theta: f32,
        max_position_embeddings: usize,
        rms_norm_eps: f32,
        /// When true, rope uses the `.sequential` layout, otherwise `.interleaved`.
        hf_rope_impl: bool = true,
    };

    /// Runtime options that are not stored in the checkpoint.
    pub const Options = struct {
        /// Sampling strategy; `null` selects the `SamplingStrategy` defaults.
        sampling_strategy: ?zml.nn.SamplingStrategy,
        /// Copied into `Llama.max_seq_len`.
        max_seq_len: usize,
    };

    /// Output projection; when null, the token embedding matrix is used instead.
    lm_head: ?zml.nn.Linear,
    model: Llama,
    // Options controlling generation
    gen_opts: zml.nn.SamplingStrategy = .{},
    config: Config,

    /// Propagates `config` and `options` into `model` and each of its layers,
    /// and annotates the projection weights with their sharding axis.
    pub fn init(self: *LlamaLM, config: Config, options: Options) void {
        self.config = config;
        self.gen_opts = options.sampling_strategy orelse .{};
        self.model.max_seq_len = @intCast(options.max_seq_len);
        self.model.num_heads = @intCast(config.num_attention_heads);
        self.model.num_kv_heads = @intCast(config.num_key_value_heads);
        self.model.rope_opts = .{
            .impl = if (config.hf_rope_impl) .sequential else .interleaved,
            .freq_base = config.rope_theta,
        };
        for (self.model.layers) |*layer| {
            layer.self_attn.num_heads = self.model.num_heads;
            layer.self_attn.num_kv_heads = self.model.num_kv_heads;
            layer.self_attn.rope_opts = self.model.rope_opts;
            layer.input_layernorm.eps = config.rms_norm_eps;
            layer.post_attention_layernorm.eps = config.rms_norm_eps;
            // up/gate/q/k/v weights are sharded on axis 0, down/o weights on axis 1.
            layer.mlp.up_proj.weight = layer.mlp.up_proj.weight.withSharding(.{0});
            layer.mlp.gate_proj.weight = layer.mlp.gate_proj.weight.withSharding(.{0});
            layer.mlp.down_proj.weight = layer.mlp.down_proj.weight.withSharding(.{1});
            layer.self_attn.q_proj.weight = layer.self_attn.q_proj.weight.withSharding(.{0});
            layer.self_attn.k_proj.weight = layer.self_attn.k_proj.weight.withSharding(.{0});
            layer.self_attn.v_proj.weight = layer.self_attn.v_proj.weight.withSharding(.{0});
            layer.self_attn.o_proj.weight = layer.self_attn.o_proj.weight.withSharding(.{1});
        }
        // TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled.
        // It currently crashes/compilation fails
        if (self.gen_opts.topk == 1) {
            if (self.lm_head) |*lm_head| {
                lm_head.weight = lm_head.weight.withSharding(.{0});
            }
        }
    }

    /// Predicts the token at `token_index` position.
    /// `tokens_` must be a u32 tensor of rank >= 1; `token_index` a u32 tensor of rank <= 1.
    /// Returns:
    /// - updated `tokens` (the sampled tokens, written into the `tokens_` buffer),
    /// - updated KV cache
    /// - a Rng state to allow for probabilistic generation
    pub fn forward(
        self: LlamaLM,
        tokens_: Tensor,
        token_index: Tensor,
        kv_cache: KvCache,
        rng: Tensor.Rng,
    ) struct { Tensor, KvCache, Tensor.Rng } {
        // The message mirrors the checks exactly (token_index may be 0d or 1d).
        stdx.debug.assert(
            tokens_.dtype() == .u32 and tokens_.rank() >= 1 and token_index.dtype() == .u32 and token_index.rank() <= 1,
            "Can't run Llama ! Expected >=1d u32 tokens and <=1d u32 token_index, got: {} and {}",
            .{ tokens_, token_index },
        );
        const tokens = tokens_.withPartialTags(.{.s});
        const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, token_index, kv_cache });
        const new_tokens, const new_rng = self.sampleTokens(self.lm_head, out, rng, self.gen_opts);
        return .{ new_tokens.convert(tokens.dtype()).reuseBuffer(tokens), updated_kv_cache, new_rng };
    }

    /// Projects hidden states `out_` to vocabulary logits and samples next tokens with `opts`.
    /// Uses `lm_head_` when present, otherwise the token embedding matrix (tagged `.voc, .d`).
    /// Returns the sampled tokens and the advanced Rng state.
    pub fn sampleTokens(
        self: LlamaLM,
        lm_head_: ?zml.nn.Linear,
        out_: Tensor,
        rng: Tensor.Rng,
        opts: zml.nn.SamplingStrategy,
    ) struct { Tensor, Tensor.Rng } {
        const out = out_.withPartialTags(.{ .s, .d });

        var logits = if (lm_head_) |lm_head|
            zml.call(lm_head, .forward, .{out})
        else
            self.model.embed_tokens.weight.withTags(.{ .voc, .d }).dot(out, .{.d});

        // A Linear head leaves its output axis tagged `.d`; expose it as the vocabulary axis.
        if (logits.shape().hasTag(.voc) == null)
            logits = logits.rename(.{ .d = .voc });

        const next_tokens, const new_rng = zml.nn.sampleTokens(logits, opts, rng);
        return .{ next_tokens, new_rng };
    }

    /// Returns `token_index + 1`, reusing the `token_index` buffer. The first argument is ignored.
    pub fn increment(_: u8, token_index: Tensor) Tensor {
        return token_index.addConstant(1).reuseBuffer(token_index);
    }
};
/// Decoder-only transformer stack: token embedding, `layers`, then a final RmsNorm.
pub const Llama = struct {
    embed_tokens: zml.nn.TokenEmbedding,
    norm: RmsNorm,
    layers: []TransformerLayer,

    max_seq_len: u32 = 0,
    num_heads: i64 = 32,
    num_kv_heads: i64 = 32,
    rope_opts: zml.nn.RopeOpts = .{
        .impl = .interleaved,
        .freq_base = 10_000,
    },

    /// Model dimensions as small integers.
    const Shape = struct {
        s: u32,
        layer: u16,
        hd: u16,
        nh: u16,
        nkvh: u16,
        dtype: zml.DataType,
    };

    /// Summarizes sequence length, layer count, head sizes and dtype.
    /// The head dimension is derived from the first layer's `k_proj` weight (axis 0).
    pub fn shape(self: Llama) Shape {
        // A non-positive kv head count means "as many kv heads as query heads".
        const kv_heads = if (self.num_kv_heads <= 0) self.num_heads else self.num_kv_heads;
        const kv_proj_rows = self.layers[0].self_attn.k_proj.weight.dim(0);
        return .{
            .s = self.max_seq_len,
            .layer = @intCast(self.layers.len),
            .hd = @intCast(@divExact(kv_proj_rows, kv_heads)),
            .nh = @intCast(self.num_heads),
            .nkvh = @intCast(kv_heads),
            .dtype = self.embed_tokens.weight.dtype(),
        };
    }

    /// Runs the tokens along `.s`, which start at position `token_index`, through every layer.
    /// Each layer reads and updates its slice of `kv_cache`.
    /// Returns the normalized hidden states and the updated KV cache (reusing `kv_cache` buffers).
    pub fn forward(self: Llama, tokens: Tensor, token_index: Tensor, kv_cache: KvCache) struct { Tensor, KvCache } {
        var hidden = embed(self.embed_tokens, tokens);
        var cache = kv_cache;
        for (self.layers, 0..) |layer, layer_idx| {
            hidden, cache = zml.call(layer, .forward, .{ hidden, token_index, cache.atLayer(layer_idx) });
        }
        const normalized = zml.call(self.norm, .forward, .{hidden});
        return .{ normalized, cache.reuseBuffer(kv_cache) };
    }

    /// Looks up embeddings for `tokens_` and tags the trailing feature axis `.d`.
    pub fn embed(embed_tokens_: zml.nn.TokenEmbedding, tokens_: Tensor) Tensor {
        const embedded = zml.call(embed_tokens_, .forward, .{tokens_});
        return embedded.withPartialTags(.{.d});
    }

    /// Builds a KV cache covering all layers, derived from the embedding shape:
    /// `.s` becomes `.k` and `.d` is split into `.h` x `.hd`.
    fn initKvCache(self: Llama, embed_shape: zml.Shape) KvCache {
        const dims = self.shape();
        const unordered = embed_shape
            .insert(0, .{ .layer = dims.layer })
            .rename(.{ .s = .k })
            .splitAxes(.{ .d = .{ .h = dims.nkvh, .hd = dims.hd } });
        const perm = unordered.contiguousPerm(.{ .k, .h, .hd });
        return KvCache.init(unordered.transpose(perm.constSlice()));
    }
};
/// One decoder block: normalized self-attention and normalized MLP, each added back
/// to its input as a residual.
pub const TransformerLayer = struct {
    input_layernorm: RmsNorm,
    self_attn: SelfAttn,
    post_attention_layernorm: RmsNorm,
    mlp: Mlp,

    /// Applies the block to `x0` (which must be tagged with `.s` and `.d`).
    /// Returns the block output (reusing `x0`'s buffer) and the KV cache updated by attention.
    pub fn forward(
        self: TransformerLayer,
        x0: Tensor,
        token_index: Tensor,
        kv_cache: KvCache,
    ) struct { Tensor, KvCache } {
        stdx.debug.assert(x0.rank() >= 2 and x0.shape().hasTags(.{ .s, .d }), "TransformerLayer expected input shape: {{..., .s, .d}}, received: {}", .{x0});

        // Attention sub-block + residual.
        const attn_in = zml.call(self.input_layernorm, .forward, .{x0});
        const attn_delta, const new_cache = zml.call(self.self_attn, .forward, .{ attn_in, token_index, kv_cache });
        const x1 = x0.add(attn_delta);

        // Feed-forward sub-block + residual.
        const mlp_in = zml.call(self.post_attention_layernorm, .forward, .{x1});
        const mlp_out = zml.call(self.mlp, .forward, .{mlp_in});
        const x2 = mlp_out.add(x1);

        return .{ x2.reuseBuffer(x0), new_cache };
    }
};
/// Root-mean-square normalization with a learned per-feature `weight`.
const RmsNorm = struct {
    weight: Tensor,
    eps: f32 = 1e-5,

    /// RMS normalization of `input` along the `.d` axis, scaled by `weight`.
    /// The mean of squares is computed in f32; the resulting inverse RMS is converted
    /// back to the input dtype before it is multiplied in.
    pub fn forward(self: RmsNorm, input: Tensor) Tensor {
        // Inputs that are not fully tagged get a `.d` tag via `withPartialTags`.
        const x = if (!input.shape().isFullyTagged()) input.withPartialTags(.{.d}) else input;
        const in_dtype = x.dtype();

        const x_f32 = x.convert(.f32);
        const mean_sq = x_f32.mul(x_f32).mean(.d);
        const inv_rms = Tensor.rsqrt(mean_sq.addConstant(self.eps)).convert(in_dtype);

        const normalized = x.mul(inv_rms.broad(x.shape()));
        const gain = self.weight.convert(in_dtype).withTags(.{.d}).broad(x.shape());
        return normalized.mul(gain);
    }
};
/// Gated feed-forward block: `down_proj(silu(gate_proj(x)) * up_proj(x))`.
const Mlp = struct {
    up_proj: zml.nn.Linear, // (dim -> hidden_dim)
    gate_proj: zml.nn.Linear, // (dim -> hidden_dim)
    down_proj: zml.nn.Linear, // (hidden_dim -> dim)

    pub fn forward(self: Mlp, x: Tensor) Tensor {
        const up = zml.call(self.up_proj, .forward, .{x});
        const gate = zml.call(self.gate_proj, .forward, .{x});
        const hidden = gate.silu().mul(up);
        return zml.call(self.down_proj, .forward, .{hidden});
    }
};
/// Multi-head self-attention with rotary position embeddings and a KV cache.
/// Queries use `num_heads` heads. Keys and values use `num_kv_heads` heads, falling back
/// to `num_heads` when `num_kv_heads` is 0 or negative.
pub const SelfAttn = struct {
    q_proj: zml.nn.Linear,
    k_proj: zml.nn.Linear,
    v_proj: zml.nn.Linear,
    o_proj: zml.nn.Linear,

    // `num_heads`, `num_kv_heads` and `rope_opts` are assigned by `LlamaLM.init`.
    num_heads: i64 = undefined,
    num_kv_heads: i64 = 0,
    rope_opts: zml.nn.RopeOpts = undefined,

    /// Self Attention over the tokens in `x`, which occupy positions
    /// `token_index .. token_index + x.dim(.s)` of the sequence.
    /// - The keys and values computed from `x` are first written into `kv_cache`
    ///   at `token_index` (in the layer selected by `kv_cache.layer_index`).
    /// - Attention then reads the whole cached sequence of that layer. A causal mask,
    ///   sliced starting at row `token_index`, hides positions after each query's own.
    /// `token_index` must hold a single element (it is reshaped to `.{ .coord = 1 }`).
    /// Returns the output projection of the attention and the updated KV cache.
    /// x: {.b, .s, .d} -> {.b, .s, .d}
    pub fn forward(
        self: SelfAttn,
        x: Tensor,
        token_index: Tensor,
        kv_cache: KvCache,
    ) struct { Tensor, KvCache } {
        const num_kv_heads = if (self.num_kv_heads > 0) self.num_kv_heads else self.num_heads;
        var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto }).withSharding(.{.h});
        var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
        var v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});

        // Generate the attention mask over the full cache length.
        const seq_len = kv_cache.k.dim(.k);
        var attn_mask = zml.nn.causalAttnMask(.{ .q = seq_len, .k = seq_len }, x.dtype(), null);
        // Keep only the rows of the current queries, starting at `token_index`.
        // Note: in Pytorch it would be very inefficient to generate the full attn_mask,
        // then slice into it, but XLA is able to optimize this correctly.
        attn_mask = attn_mask.gatherSlices(zml.Shape.init(.{ .q = x.dim(.s) }, attn_mask.dtype()), token_index.reshape(.{ .coord = 1 }), .{});

        // In self-attention, .s axis is used both for keys and queries.
        // Absolute positions of the new tokens: token_index + [0, x.dim(.s)).
        const pos_index = b: {
            const temp = Tensor.arange(.{ .end = x.dim(.s) }, token_index.dtype()).withTags(.{.s}).broad(zml.Shape.init(.{ .s = x.dim(.s) }, token_index.dtype()));
            break :b temp.add(token_index.broad(temp.shape()));
        };
        q = zml.nn.rope(q, pos_index, self.rope_opts);
        k = zml.nn.rope(k, pos_index, self.rope_opts);
        q = q.rename(.{ .s = .q });
        k = k.rename(.{ .s = .k });
        v = v.rename(.{ .s = .k });

        const dtype = q.dtype();
        const new_kv_cache = kv_cache.update(k, v, token_index);
        // Attend over everything cached for this layer, converted to the query dtype.
        k = new_kv_cache.keys().convert(dtype);
        v = new_kv_cache.values().convert(dtype);

        // NOTE: zml.nn.sdpaMemEfficient (with .q_chunk_size / .k_chunk_size) can replace sdpa here.
        const attn_output = zml.nn.sdpa(q, k, v, .{ .attn_mask = attn_mask, .allow_cudnn = true });
        const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s });
        return .{ zml.call(self.o_proj, .forward, .{attn}), new_kv_cache };
    }

    /// Builds a single-layer KV cache shaped like `key_shape`, with `.k`, `.h`, `.hd` laid out contiguously.
    fn initKvCache(key_shape: zml.Shape) KvCache {
        // When we call initKvCache, we haven't renamed .s to .k yet.
        var kv_shape = key_shape.insert(0, .{ .layer = 1 }).rename(.{ .s = .k });
        const perm = kv_shape.contiguousPerm(.{ .h, .k, .hd });
        kv_shape = kv_shape.transpose(perm.constSlice());
        var res = KvCache.init(kv_shape);
        // Only one layer exists, so point at it.
        res.layer_index = Tensor.scalar(0, .u32);
        return res;
    }
};
/// Key/value cache shared by all layers. `k` and `v` carry a `.layer` axis;
/// `layer_index` selects which layer `keys`, `values` and `update` operate on.
pub const KvCache = struct {
    k: Tensor,
    v: Tensor,
    layer_index: Tensor,

    /// Creates cache tensors of shape `kv_shape`, sharded along `.h`.
    pub fn init(kv_shape: zml.Shape) KvCache {
        // The KV-cache is initialized with ones to detect reads of uninitialized memory.
        const fill = kv_shape.dtype().one();
        return .{
            .k = Tensor.constant(kv_shape, fill).withSharding(.{.h}),
            .v = Tensor.constant(kv_shape, fill).withSharding(.{.h}),
            // NOTE(review): -1 is used as a u32 layer index here, while `initBuffer` uses 0 — confirm the intended sentinel.
            .layer_index = Tensor.scalar(-1, .u32),
        };
    }

    /// Shapes of the cache fields, for compiling against a cache of `kv_shape`.
    pub fn initShape(kv_shape: zml.Shape) ShapeOf(KvCache) {
        const scalar_u32 = zml.Shape.init(.{}, .u32);
        return .{ .k = kv_shape, .v = kv_shape, .layer_index = scalar_u32 };
    }

    /// Allocates device buffers for the cache on `platform`: k/v filled with 1, layer index 0.
    pub fn initBuffer(kv_shape: zml.Shape, platform: zml.Platform) !zml.Bufferized(KvCache) {
        // NOTE(review): buffers created before a failing allocation are not released here.
        const k_buf = try zml.Buffer.constant(platform, kv_shape, 1);
        const v_buf = try zml.Buffer.constant(platform, kv_shape, 1);
        const index_buf = try zml.Buffer.constant(platform, zml.Shape.init(.{}, .u32), 0);
        return .{ .k = k_buf, .v = v_buf, .layer_index = index_buf };
    }

    /// Keys of the layer selected by `layer_index` (the `.layer` axis removed).
    pub fn keys(self: KvCache) Tensor {
        return self.sliceLayer(self.k);
    }

    /// Values of the layer selected by `layer_index` (the `.layer` axis removed).
    pub fn values(self: KvCache) Tensor {
        return self.sliceLayer(self.v);
    }

    // Extracts the `layer_index` slice of `cache` and drops the `.layer` axis.
    fn sliceLayer(self: KvCache, cache: Tensor) Tensor {
        return cache.dynamicSlice(.{ .layer = .{ .start = self.layer_index, .len = 1 } }).squeeze(.layer);
    }

    /// Writes `new_k` / `new_v` into the layer `layer_index`. When `token_index` is given,
    /// the write starts at that position along `.k`. The inputs are converted to the cache
    /// dtype and transposed to the cache layout first.
    pub fn update(self: KvCache, new_k: Tensor, new_v: Tensor, token_index: ?Tensor) KvCache {
        const k_shape = self.k.shape().drop(.layer);
        if (token_index) |idx| {
            // The layer index is broadcast to the shape of the token index so both can be used as scatter coordinates.
            const layer = self.layer_index.broad(idx.shape());
            return self.scatterLayer(.{ .layer = layer, .k = idx }, new_k, new_v, k_shape);
        }
        return self.scatterLayer(.{ .layer = self.layer_index }, new_k, new_v, k_shape);
    }

    // Overwrites the slices of `k` and `v` addressed by `indices`, reusing the existing buffers.
    fn scatterLayer(self: KvCache, indices: anytype, new_k: Tensor, new_v: Tensor, k_shape: zml.Shape) KvCache {
        return .{
            .k = self.k.scatterSlices(
                indices,
                new_k.convert(self.k.dtype()).transpose(k_shape),
                .{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
            ).reuseBuffer(self.k),
            .v = self.v.scatterSlices(
                indices,
                new_v.convert(self.v.dtype()).transpose(k_shape),
                .{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
            ).reuseBuffer(self.v),
            .layer_index = self.layer_index,
        };
    }

    /// Same cache tensors, pointed at layer `layer_index`.
    pub fn atLayer(self: KvCache, layer_index: usize) KvCache {
        var selected = self;
        selected.layer_index = Tensor.scalar(layer_index, .u32);
        return selected;
    }

    /// Marks each field of `self` as reusing the corresponding buffer of `other`.
    pub fn reuseBuffer(self: KvCache, other: KvCache) KvCache {
        const k = self.k.reuseBuffer(other.k);
        const v = self.v.reuseBuffer(other.v);
        const layer_index = self.layer_index.reuseBuffer(other.layer_index);
        return .{ .k = k, .v = v, .layer_index = layer_index };
    }
};