406 lines
16 KiB
Zig
406 lines
16 KiB
Zig
const std = @import("std");
|
|
const testing = std.testing;
|
|
|
|
const stdx = @import("stdx");
|
|
const zml = @import("zml");
|
|
const Buffer = zml.Buffer;
|
|
const Tensor = zml.Tensor;
|
|
const ShapeOf = zml.ShapeOf;
|
|
|
|
const log = std.log.scoped(.llama);
|
|
|
|
/// Llama architecture, using huggingface transformers naming.
/// Dimensions of activations: {.b, .s, .d}
pub const LlamaLM = struct {
    /// Model hyper-parameters, named after the huggingface transformers config fields.
    pub const Config = struct {
        bos_token_id: u32,
        eos_token_id: stdx.json.Union(union(enum) {
            int: u32,
            ints: []u32,
        }),
        head_dim: ?u32 = null,
        hidden_size: u32,
        num_hidden_layers: u32,
        num_attention_heads: u32,
        num_key_value_heads: u32,
        rope_theta: f32,
        max_position_embeddings: u32,
        rms_norm_eps: f32,
        hf_rope_impl: bool = true,
        tie_word_embeddings: bool = false,
        rope_scaling: zml.nn.RopeOpts.Scaling = .{ .default = {} },
    };

    /// Runtime options consumed by `init`.
    pub const Options = struct {
        /// Sampling strategy stored in `gen_opts`; `null` selects `SamplingStrategy`'s defaults.
        sampling_strategy: ?zml.nn.SamplingStrategy,
        /// Copied onto the `Llama` model's `max_seq_len` field.
        max_seq_len: u32,
    };

    /// Output projection; `null` when the embedding matrix is reused (tied embeddings).
    lm_head: ?zml.nn.Linear,
    model: Llama,

    // Options controlling generation
    gen_opts: zml.nn.SamplingStrategy = .{},
    config: Config,

    /// Builds a `LlamaLM` whose weights are looked up in `store` under huggingface names:
    /// per-layer tensors under the `model.layers` prefix (indexed by layer number), plus
    /// `model.embed_tokens.weight`, `model.norm.weight` and, when embeddings are not tied,
    /// `lm_head.weight`.
    ///
    /// `allocator` backs the returned `model.layers` slice (owned by the caller) and the
    /// temporary prefix builder. If building fails, the layers slice is freed.
    /// NOTE(review): `prefix` is never released here, and `populateModelWithPrefix` may
    /// also allocate from `allocator`; presumably callers pass an arena — confirm.
    ///
    /// Sharding: q/k/v/up/gate projection weights are sharded on axis 0, o/down projection
    /// weights on axis 1, and an untied `lm_head` on axis 0 when `topk == 1`.
    pub fn init(allocator: std.mem.Allocator, config: Config, options: Options, store: zml.aio.BufferStore) !LlamaLM {
        const rope_opts: zml.nn.RopeOpts = .{
            .layout = if (config.hf_rope_impl) .sequential else .interleaved,
            .freq_base = config.rope_theta,
            .scaling = config.rope_scaling,
        };

        const layers = try allocator.alloc(TransformerLayer, config.num_hidden_layers);
        // Any `try` below would otherwise leak the freshly allocated layers.
        errdefer allocator.free(layers);

        var prefix = try zml.aio.PrefixBuilder.initCapacity(allocator, 1024);
        try prefix.push(stdx.noalloc, "model.layers");
        for (0.., layers) |i, *layer| {
            try prefix.pushDigit(stdx.noalloc, i);
            defer prefix.pop();

            var self_attn = try zml.aio.populateModelWithPrefix(SelfAttn, allocator, store, prefix.concat("self_attn"));
            self_attn.num_heads = config.num_attention_heads;
            self_attn.num_kv_heads = config.num_key_value_heads;
            self_attn.rope_opts = rope_opts;
            // The optional per-head q/k norms must use the configured epsilon like every other
            // RmsNorm here; otherwise they silently keep RmsNorm's 1e-5 default.
            if (self_attn.q_norm) |*q_norm| q_norm.eps = config.rms_norm_eps;
            if (self_attn.k_norm) |*k_norm| k_norm.eps = config.rms_norm_eps;
            self_attn.q_proj.weight = self_attn.q_proj.weight.withSharding(.{0});
            self_attn.k_proj.weight = self_attn.k_proj.weight.withSharding(.{0});
            self_attn.v_proj.weight = self_attn.v_proj.weight.withSharding(.{0});
            self_attn.o_proj.weight = self_attn.o_proj.weight.withSharding(.{1});

            var input_layernorm = try zml.aio.populateModelWithPrefix(RmsNorm, allocator, store, prefix.concat("input_layernorm"));
            input_layernorm.eps = config.rms_norm_eps;

            var post_attention_layernorm = try zml.aio.populateModelWithPrefix(RmsNorm, allocator, store, prefix.concat("post_attention_layernorm"));
            post_attention_layernorm.eps = config.rms_norm_eps;

            var mlp = try zml.aio.populateModelWithPrefix(Mlp, allocator, store, prefix.concat("mlp"));
            mlp.up_proj.weight = mlp.up_proj.weight.withSharding(.{0});
            mlp.gate_proj.weight = mlp.gate_proj.weight.withSharding(.{0});
            mlp.down_proj.weight = mlp.down_proj.weight.withSharding(.{1});

            layer.* = .{
                .self_attn = self_attn,
                .input_layernorm = input_layernorm,
                .post_attention_layernorm = post_attention_layernorm,
                .mlp = mlp,
            };
        }

        var lm_head: ?zml.nn.Linear = null;
        if (!config.tie_word_embeddings) {
            lm_head = .{ .weight = store.getTensor("lm_head.weight") };
            if (options.sampling_strategy) |gen_opts| {
                if (gen_opts.topk == 1)
                    lm_head.?.weight = lm_head.?.weight.withSharding(.{0});
            }
        }

        return .{
            .config = config,
            .gen_opts = options.sampling_strategy orelse .{},
            .model = .{
                // Weights
                .layers = layers,
                .embed_tokens = .{ .weight = store.getTensor("model.embed_tokens.weight") },
                .norm = .{
                    .weight = store.getTensor("model.norm.weight"),
                    .eps = config.rms_norm_eps,
                },
                // Push down some configs
                .max_seq_len = options.max_seq_len,
                .num_heads = config.num_attention_heads,
                .num_kv_heads = config.num_key_value_heads,
                // Same options as the per-layer attention blocks.
                .rope_opts = rope_opts,
            },
            .lm_head = lm_head,
        };
    }

    /// Predicts the token at `token_index` position.
    /// `tokens_` must be u32 with rank >= 1 (its leading untagged axis is tagged `.s`),
    /// and `token_index` must be u32 with rank <= 1.
    /// Returns:
    /// - updated `tokens` (sampled tokens converted to `tokens`' dtype, reusing its buffer),
    /// - updated KV cache
    /// - a Rng state to allow for probabilistic generation
    pub fn forward(
        self: LlamaLM,
        tokens_: Tensor,
        token_index: Tensor,
        kv_cache: KvCache,
        rng: Tensor.Rng,
    ) struct { Tensor, KvCache, Tensor.Rng } {
        stdx.debug.assert(tokens_.dtype() == .u32 and tokens_.rank() >= 1 and token_index.dtype() == .u32 and token_index.rank() <= 1, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {f} and {f}", .{ tokens_, token_index });
        const tokens = tokens_.withPartialTags(.{.s});
        const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, token_index, kv_cache });
        const new_tokens, const new_rng = self.sampleTokens(self.lm_head, out, rng, self.gen_opts);
        return .{ new_tokens.convert(tokens.dtype()).reuseBuffer(tokens), updated_kv_cache, new_rng };
    }

    /// Projects hidden states `out_` onto the vocabulary and samples next tokens with `opts`.
    /// When `lm_head_` is null, the token embedding matrix is used as the projection
    /// (tied embeddings). The logits' vocabulary axis is tagged `.voc` before sampling.
    pub fn sampleTokens(
        self: LlamaLM,
        lm_head_: ?zml.nn.Linear,
        out_: Tensor,
        rng: Tensor.Rng,
        opts: zml.nn.SamplingStrategy,
    ) struct { Tensor, Tensor.Rng } {
        const out = out_.withPartialTags(.{ .s, .d });

        var logits = blk: {
            if (lm_head_) |lm_head| {
                break :blk zml.call(lm_head, .forward, .{out});
            } else {
                break :blk self.model.embed_tokens.weight.withTags(.{ .voc, .d }).dot(out, .{.d});
            }
        };

        // The Linear head leaves the projected axis tagged `.d`; the sampler expects `.voc`.
        if (logits.shape().hasTag(.voc) == null)
            logits = logits.rename(.{ .d = .voc });

        const next_tokens, const new_rng = zml.nn.sampleTokens(logits, opts, rng);
        return .{ next_tokens, new_rng };
    }

    /// Returns `token_index + 1`, reusing `token_index`'s buffer.
    /// The first argument is ignored (presumably a placeholder required by the caller's
    /// calling convention — confirm).
    pub fn increment(_: u8, token_index: Tensor) Tensor {
        return token_index.addConstant(1).reuseBuffer(token_index);
    }
};
|
|
|
|
pub const Llama = struct {
    embed_tokens: zml.nn.TokenEmbedding,
    norm: RmsNorm,
    layers: []TransformerLayer,

    // Configuration pushed down from `LlamaLM`; these are not read by the methods below.
    max_seq_len: u32 = 0,
    num_heads: u32 = 32,
    num_kv_heads: u32 = 32,
    rope_opts: zml.nn.RopeOpts = .{
        .layout = .interleaved,
        .freq_base = 10_000,
    },

    /// Embeds `tokens` and runs them through every transformer layer in order,
    /// pointing the KV cache at layer `i` before calling layer `i`, then applies the
    /// final RmsNorm.
    /// Returns the normalized hidden states and the KV cache, whose buffers alias `kv_cache`.
    pub fn forward(self: Llama, tokens: Tensor, token_index: Tensor, kv_cache: KvCache) struct { Tensor, KvCache } {
        var h = embed(self.embed_tokens, tokens);
        var cache = kv_cache;

        for (self.layers, 0..) |layer, layer_idx| {
            h, cache = zml.call(layer, .forward, .{ h, token_index, cache.atLayer(layer_idx) });
        }

        return .{ zml.call(self.norm, .forward, .{h}), cache.reuseBuffer(kv_cache) };
    }

    /// Looks up token embeddings and tags the resulting feature axis `.d`.
    pub fn embed(embed_tokens_: zml.nn.TokenEmbedding, tokens_: Tensor) Tensor {
        const embedded = zml.call(embed_tokens_, .forward, .{tokens_});
        return embedded.withPartialTags(.{.d});
    }
};
|
|
|
|
pub const TransformerLayer = struct {
    input_layernorm: RmsNorm,
    self_attn: SelfAttn,
    post_attention_layernorm: RmsNorm,
    mlp: Mlp,

    /// Pre-norm decoder block:
    ///   h   = x0 + self_attn(input_layernorm(x0))
    ///   out = mlp(post_attention_layernorm(h)) + h
    /// `x0` must have rank >= 2 and carry the `.s` and `.d` tags.
    /// Returns `out` (reusing `x0`'s buffer) and the KV cache updated by the attention.
    pub fn forward(
        self: TransformerLayer,
        x0: Tensor,
        token_index: Tensor,
        kv_cache: KvCache,
    ) struct { Tensor, KvCache } {
        stdx.debug.assert(x0.rank() >= 2 and x0.shape().hasTags(.{ .s, .d }), "TransformerLayer expected input shape: {{..., .s, .d}}, received: {f}", .{x0});

        // Attention sub-block with residual connection.
        const attn_in = zml.call(self.input_layernorm, .forward, .{x0});
        const attn_out, const new_cache = zml.call(self.self_attn, .forward, .{ attn_in, token_index, kv_cache });
        const h = x0.add(attn_out);

        // Feed-forward sub-block with residual connection.
        const mlp_in = zml.call(self.post_attention_layernorm, .forward, .{h});
        const mlp_out = zml.call(self.mlp, .forward, .{mlp_in});
        const out = mlp_out.add(h);

        return .{ out.reuseBuffer(x0), new_cache };
    }
};
|
|
|
|
const RmsNorm = struct {
    weight: Tensor,
    eps: f32 = 1e-5,

    /// Root-mean-square normalization of `input` along its `.d` axis (via `zml.nn.rmsNorm`
    /// with `eps`), followed by an element-wise scale by `weight`.
    /// When `input` is not fully tagged, `.d` is attached with `withPartialTags` first.
    /// `weight` is converted to the input dtype and broadcast to the input shape.
    pub fn forward(self: RmsNorm, input: Tensor) Tensor {
        const x = if (!input.shape().isFullyTagged()) input.withPartialTags(.{.d}) else input;
        const normalized = zml.nn.rmsNorm(x, .d, self.eps);
        const scale = self.weight.convert(x.dtype()).withTags(.{.d}).broad(x.shape());
        return normalized.mul(scale);
    }
};
|
|
|
|
const Mlp = struct {
    up_proj: zml.nn.Linear, // (dim -> hidden_dim)
    gate_proj: zml.nn.Linear, // (dim -> hidden_dim)
    down_proj: zml.nn.Linear, // (hidden_dim -> dim)

    /// Gated feed-forward network: down_proj(silu(gate_proj(x)) * up_proj(x)).
    pub fn forward(self: Mlp, x: Tensor) Tensor {
        const up = zml.call(self.up_proj, .forward, .{x});
        const gate = zml.call(self.gate_proj, .forward, .{x});
        const gated = gate.silu().mul(up);
        return zml.call(self.down_proj, .forward, .{gated});
    }
};
|
|
|
|
pub const SelfAttn = struct {
    q_proj: zml.nn.Linear,
    k_proj: zml.nn.Linear,
    v_proj: zml.nn.Linear,

    /// Optional per-head norms, applied to queries/keys over the head dimension before RoPE.
    q_norm: ?RmsNorm,
    k_norm: ?RmsNorm,

    o_proj: zml.nn.Linear,
    num_heads: i64 = undefined,
    /// Number of key/value heads; any non-positive value means "same as `num_heads`".
    num_kv_heads: i64 = 0,
    rope_opts: zml.nn.RopeOpts = undefined,

    /// Causal self attention for the tokens in `x`, the first of which sits at absolute
    /// position `token_index`.
    ///
    /// The new keys/values are written into `kv_cache` starting at `token_index`. Attention
    /// then runs against the whole cached key/value sequence of the current layer under a
    /// causal mask sliced to the rows of the new queries, so previously cached tokens are
    /// visible to them.
    /// x: {.b, .s, .d } -> .{.b, .s, .d} (per the original notes; only `.s`/`.d` are used here).
    /// Returns the `o_proj` output and the updated cache.
    pub fn forward(
        self: SelfAttn,
        x: Tensor,
        token_index: Tensor,
        kv_cache: KvCache,
    ) struct { Tensor, KvCache } {
        const kv_heads = if (self.num_kv_heads > 0) self.num_kv_heads else self.num_heads;

        // Project, then split the last axis into (heads, head_dim).
        var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto }).withSharding(.{.h});
        var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = kv_heads, .hd = .auto }).withSharding(.{.h});
        const v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = kv_heads, .hd = .auto }).withSharding(.{.h});

        // Causal mask covering the whole cache, then the rows of the current queries.
        // Materializing the full mask and slicing would be wasteful eagerly,
        // but XLA optimizes it away.
        const cache_len = kv_cache.k.dim(.k);
        const full_mask = zml.nn.causalAttnMask(.{ .q = cache_len, .k = cache_len }, x.dtype(), null);
        const attn_mask = full_mask.gatherSlices(zml.Shape.init(.{ .q = x.dim(.s) }, full_mask.dtype()), token_index.reshape(.{ .coord = 1 }), .{});

        // Absolute positions of the new tokens: token_index + [0, s).
        // The `.s` axis serves both keys and queries at this point.
        const pos_index = blk: {
            const seq = x.dim(.s);
            const offsets = Tensor.arange(.{ .end = seq }, token_index.dtype()).withTags(.{.s}).broad(zml.Shape.init(.{ .s = seq }, token_index.dtype()));
            break :blk offsets.add(token_index.broad(offsets.shape()));
        };

        if (self.q_norm) |norm| q = norm.forward(q.rename(.{ .hd = .d })).rename(.{ .d = .hd });
        if (self.k_norm) |norm| k = norm.forward(k.rename(.{ .hd = .d })).rename(.{ .d = .hd });

        // Rotary embeddings, then separate the query (.q) and key (.k) sequence axes.
        const q_rot = zml.nn.rope(q, pos_index, self.rope_opts).rename(.{ .s = .q });
        const k_rot = zml.nn.rope(k, pos_index, self.rope_opts).rename(.{ .s = .k });
        const v_new = v.rename(.{ .s = .k });

        const dtype = q_rot.dtype();
        const updated_cache = kv_cache.update(k_rot, v_new, token_index);
        const cached_k = updated_cache.keys().convert(dtype);
        const cached_v = updated_cache.values().convert(dtype);

        const attn_output = zml.nn.sdpa(q_rot, cached_k, cached_v, .{ .attn_mask = attn_mask, .allow_cudnn = true });
        const merged = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s });
        return .{ zml.call(self.o_proj, .forward, .{merged}), updated_cache };
    }
};
|
|
|
|
pub const KvCache = struct {
    /// Cached keys for all layers; indexed through a `.layer` axis.
    k: Tensor,
    /// Cached values for all layers; same layout as `k`.
    v: Tensor,
    /// u32 scalar selecting the layer that `keys`, `values` and `update` operate on.
    layer_index: Tensor,

    /// Builds the cache as constants filled with ones so that reads of never-written
    /// slots can be detected; `layer_index` starts as the scalar -1 with dtype u32.
    pub fn init(kv_shape: zml.Shape) KvCache {
        const one = kv_shape.dtype().one();
        return .{
            .k = Tensor.constant(kv_shape, one).withSharding(.{.h}),
            .v = Tensor.constant(kv_shape, one).withSharding(.{.h}),
            .layer_index = Tensor.scalar(-1, .u32),
        };
    }

    /// Shapes matching `init`/`initBuffer`: `k`/`v` use `kv_shape`, `layer_index` is a u32 scalar.
    pub fn initShape(kv_shape: zml.Shape) ShapeOf(KvCache) {
        return .{
            .k = kv_shape,
            .v = kv_shape,
            .layer_index = zml.Shape.init(.{}, .u32),
        };
    }

    /// Allocates device buffers on `platform`: uninitialized `k`/`v` and `layer_index = 0`.
    pub fn initBuffer(kv_shape: zml.Shape, platform: zml.Platform) !zml.Bufferized(KvCache) {
        return .{
            .k = try zml.Buffer.uninitialized(platform, kv_shape, .{}),
            .v = try zml.Buffer.uninitialized(platform, kv_shape, .{}),
            .layer_index = try zml.Buffer.scalar(platform, 0, .u32),
        };
    }

    /// Keys of the selected layer, with the `.layer` axis removed.
    pub fn keys(self: KvCache) Tensor {
        return sliceLayer(self.k, self.layer_index);
    }

    /// Values of the selected layer, with the `.layer` axis removed.
    pub fn values(self: KvCache) Tensor {
        return sliceLayer(self.v, self.layer_index);
    }

    /// Dynamically slices one entry of `full`'s `.layer` axis at `layer` and squeezes it away.
    fn sliceLayer(full: Tensor, layer: Tensor) Tensor {
        const window = Tensor.DynSlice{ .start = layer, .len = 1 };
        return full.dynamicSlice(.{ .layer = window }).squeeze(.layer);
    }

    /// Overwrites the selected layer's keys/values with `new_k`/`new_v`.
    /// New entries are converted to the cache dtype and transposed to the cache's
    /// per-layer axis order. With `token_index`, they are written starting at
    /// `.k = token_index`; without it, only the `.layer` index is given to the scatter.
    /// The returned tensors reuse the existing cache buffers.
    pub fn update(self: KvCache, new_k: Tensor, new_v: Tensor, token_index: ?Tensor) KvCache {
        const k_shape = self.k.shape().drop(.layer);

        if (token_index) |idx| {
            // The layer index is broadcast to match the token index shape.
            const layer = self.layer_index.broad(idx.shape());
            return .{
                .k = self.k.scatterSlices(
                    .{ .layer = layer, .k = idx },
                    new_k.convert(self.k.dtype()).transpose(k_shape),
                    .{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
                ).reuseBuffer(self.k),
                .v = self.v.scatterSlices(
                    .{ .layer = layer, .k = idx },
                    new_v.convert(self.v.dtype()).transpose(k_shape),
                    .{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
                ).reuseBuffer(self.v),
                .layer_index = self.layer_index,
            };
        }

        return .{
            .k = self.k.scatterSlices(
                .{ .layer = self.layer_index },
                new_k.convert(self.k.dtype()).transpose(k_shape),
                .{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
            ).reuseBuffer(self.k),
            .v = self.v.scatterSlices(
                .{ .layer = self.layer_index },
                new_v.convert(self.v.dtype()).transpose(k_shape),
                .{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
            ).reuseBuffer(self.v),
            .layer_index = self.layer_index,
        };
    }

    /// Same cache tensors, with `layer_index` set to the u32 scalar `layer_index`.
    pub fn atLayer(self: KvCache, layer_index: usize) KvCache {
        var view = self;
        view.layer_index = Tensor.scalar(layer_index, .u32);
        return view;
    }

    /// Marks each of `self`'s tensors as reusing the buffer of the matching tensor in `donor`.
    pub fn reuseBuffer(self: KvCache, donor: KvCache) KvCache {
        return .{
            .k = self.k.reuseBuffer(donor.k),
            .v = self.v.reuseBuffer(donor.v),
            .layer_index = self.layer_index.reuseBuffer(donor.layer_index),
        };
    }
};
|