//! GptOss architecture, using huggingface transformers naming.
//! Dimensions of activations: {.b, .s, .d}
const std = @import("std");
const stdx = @import("stdx");
const zml = @import("zml");

const GptOss = @This();

const log = std.log.scoped(.GptOss);

pub const Config = struct {
    bos_token_id: u32 = 199998,
    eos_token_id: stdx.json.Union(union(enum) {
        int: u32,
        ints: []const u32,
    }),
    head_dim: u32,
    num_hidden_layers: u32,
    num_attention_heads: u32,
    num_key_value_heads: u32,
    experts_per_token: u32,
    rope_theta: f32,
    max_position_embeddings: u32,
    rms_norm_eps: f32,
    sliding_window: u32,
    hf_rope_impl: bool = true,
    rope_scaling: zml.nn.RopeOpts.Scaling = .{ .default = {} },
};

pub const Options = struct {
    sampling_strategy: zml.nn.SamplingStrategy,
    max_seq_len: u32,
    max_prompt_len: u32,
    tokens_per_expert_ratio: f32,
};

pub const Mode = union(enum) {
    /// In prefill mode we pass the actual length of the prompt.
    prefill: zml.Tensor,
    /// In gen mode we pass the position of the next token.
    gen: zml.Tensor,
};

lm_head: ?zml.nn.Linear,
model: Model,
config: Config,
options: Options,

pub fn init(allocator: std.mem.Allocator, store: zml.aio.BufferStore, config: Config, options: Options) !GptOss {
    var self: GptOss = .{
        .config = config,
        .options = options,
        .model = .{
            .max_seq_len = @intCast(options.max_seq_len),
            .num_heads = @intCast(config.num_attention_heads),
            .num_kv_heads = @intCast(config.num_key_value_heads),
            .rope_opts = .{
                .layout = if (config.hf_rope_impl) .sequential else .interleaved,
                .freq_base = config.rope_theta,
                .scaling = config.rope_scaling,
            },
            .embed_tokens = .{
                .weight = store.getTensor("model.embed_tokens.weight").withSharding(.{1}),
            },
            .layers = try allocator.alloc(TransformerLayer, config.num_hidden_layers),
            .norm = .{
                .weight = store.getTensor("model.norm.weight"),
                .eps = config.rms_norm_eps,
            },
        },
        .lm_head = .{ .weight = store.getTensor("lm_head.weight").withSharding(.{0}) },
    };

    var prefix: zml.aio.PrefixBuilder = try .initCapacity(allocator, 1024);
    try prefix.push(stdx.noalloc, "model.layers");
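    // Even-indexed layers use sliding-window attention; odd-indexed layers
    // attend to the full context (see `sliding_window` in the loop below).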
    for (self.model.layers, 0..) |*layer, i| {
        try prefix.pushDigit(stdx.noalloc, i);
        defer prefix.pop();

        var self_attn: SelfAttn = .{
            .sinks = store.getTensor(prefix.concat("self_attn.sinks")),
            .q_proj = try zml.aio.populateModelWithPrefix(zml.nn.Linear, allocator, store, prefix.concat("self_attn.q_proj")),
            .k_proj = try zml.aio.populateModelWithPrefix(zml.nn.Linear, allocator, store, prefix.concat("self_attn.k_proj")),
            .v_proj = try zml.aio.populateModelWithPrefix(zml.nn.Linear, allocator, store, prefix.concat("self_attn.v_proj")),
            .o_proj = try zml.aio.populateModelWithPrefix(zml.nn.Linear, allocator, store, prefix.concat("self_attn.o_proj")),
            .sliding_window = if (i % 2 == 0) config.sliding_window else null,
            .num_heads = self.model.num_heads,
            .num_kv_heads = self.model.num_kv_heads,
            .rope_opts = self.model.rope_opts,
        };
        self_attn.q_proj.weight = self_attn.q_proj.weight.withSharding(.{0});
        self_attn.k_proj.weight = self_attn.k_proj.weight.withSharding(.{0});
        self_attn.v_proj.weight = self_attn.v_proj.weight.withSharding(.{0});
        self_attn.o_proj.weight = self_attn.o_proj.weight.withSharding(.{1});

        const on_disk_moe = try zml.aio.populateModelWithPrefix(MoE.OnDisk, allocator, store, prefix.concat("mlp"));
        var moe = on_disk_moe.rewrite(config.experts_per_token, options);
        {
            moe.experts.gate_up_proj.blocks = moe.experts.gate_up_proj.blocks.withSharding(.{.expert});
            moe.experts.down_proj.blocks = moe.experts.down_proj.blocks.withSharding(.{.expert});
        }

        layer.* = .{
            .input_layernorm = .{
                .weight = store.getTensor(prefix.concat("input_layernorm.weight")),
                .eps = config.rms_norm_eps,
            },
            .post_attention_layernorm = .{
                .weight = store.getTensor(prefix.concat("post_attention_layernorm.weight")),
                .eps = config.rms_norm_eps,
            },
            .self_attn = self_attn,
            .mlp = moe,
        };
    }

    // TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled.
    // It currently crashes / fails to compile.
    if (self.options.sampling_strategy.topk == 1 and self.lm_head != null) {
        self.lm_head.?.weight = self.lm_head.?.weight.withSharding(.{0});
    }

    return self;
}

/// Predicts the token at `token_index` position.
/// Returns:
/// - updated `tokens`,
/// - updated KV cache,
/// - an Rng state to allow for probabilistic generation.
pub fn forward(
    self: GptOss,
    tokens_: zml.Tensor,
    mode: Mode,
    kv_cache: KvCache,
    rng: zml.Tensor.Rng,
) struct { zml.Tensor, KvCache, zml.Tensor.Rng } {
    const tokens = tokens_.withPartialTags(.{.s});
    // token_index is the position in the kv cache where to write results.
    const token_index: zml.Tensor = switch (mode) {
        .gen => |token_index| token_index,
        .prefill => .scalar(0, .u32),
    };
    var out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, token_index, kv_cache });
    switch (mode) {
        // In prefill we only pass the last token to the lm head.
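        // The last prompt token lives at index `prompt_len - 1`, hence the
        // `addConstant(-1)` below.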
        .prefill => |prompt_len| out = out.gather(.{ .s = prompt_len.convert(.i32).addConstant(-1) }, .{ .indices_are_sorted = true }),
        .gen => {},
    }

    var new_token, const new_rng = self.sampleTokens(self.lm_head, out, rng, self.options.sampling_strategy);
    new_token = new_token.convert(.u32);
    new_token = switch (mode) {
        .gen => new_token.reuseBuffer(tokens),
        .prefill => new_token.appendAxes(.{.s}),
    };
    return .{ new_token, updated_kv_cache, new_rng };
}

fn sampleTokens(
    self: GptOss,
    lm_head_: ?zml.nn.Linear,
    out_: zml.Tensor,
    rng: zml.Tensor.Rng,
    opts: zml.nn.SamplingStrategy,
) struct { zml.Tensor, zml.Tensor.Rng } {
    const out = out_.withPartialTags(.{.d});

    var logits = blk: {
        if (lm_head_) |lm_head| {
            break :blk zml.call(lm_head, .forward, .{out});
        } else {
            break :blk self.model.embed_tokens.weight.withTags(.{ .voc, .d }).dot(out, .{.d});
        }
    };
    if (logits.shape().hasTag(.voc) == null) logits = logits.rename(.{ .d = .voc });

    const next_tokens, const new_rng = zml.nn.sampleTokens(logits, opts, rng);
    return .{ next_tokens, new_rng };
}

pub fn loadBuffers(self: GptOss, allocator: std.mem.Allocator, store: zml.aio.BufferStore, platform: zml.Platform) !zml.Bufferized(GptOss) {
    var prefix: zml.aio.PrefixBuilder = try .initCapacity(allocator, 256);
    defer prefix.deinit(allocator);
    const noalloc = stdx.noalloc;

    const loaded: zml.Bufferized(GptOss) = .{
        .model = .{
            .embed_tokens = try store.loadModelById(zml.nn.TokenEmbedding, noalloc, self.model.embed_tokens, platform),
            .layers = try allocator.alloc(zml.Bufferized(TransformerLayer), self.model.layers.len),
            .norm = try store.loadModelById(RmsNorm, noalloc, self.model.norm, platform),
        },
        .lm_head = try store.loadModelById(?zml.nn.Linear, noalloc, self.lm_head, platform),
    };

    prefix.push(noalloc, "model.layers") catch unreachable;
    for (loaded.model.layers, self.model.layers, 0..) |*d_layer, layer, layer_id| {
        const ckpt = prefix.checkpoint();
        defer prefix.restore(ckpt);
        prefix.pushDigit(noalloc, layer_id) catch unreachable;

        d_layer.* = .{
            .input_layernorm = try store.loadModelById(RmsNorm, noalloc, layer.input_layernorm, platform),
            .self_attn = try store.loadModelById(SelfAttn, noalloc, layer.self_attn, platform),
            .post_attention_layernorm = try store.loadModelById(RmsNorm, noalloc, layer.post_attention_layernorm, platform),
            .mlp = try store.loadModelById(MoE, noalloc, layer.mlp, platform),
        };
    }

    return loaded;
}

pub const Model = struct {
    embed_tokens: zml.nn.TokenEmbedding,
    norm: RmsNorm,
    layers: []TransformerLayer,

    max_seq_len: u32 = 0,
    num_heads: i64 = 32,
    num_kv_heads: i64 = 32,
    rope_opts: zml.nn.RopeOpts = .{
        .layout = .interleaved,
        .freq_base = 10_000,
    },

    /// Forward one token, using KV cache for previous tokens.
    /// Returns result and updated KV cache.
    pub fn forward(self: Model, tokens: zml.Tensor, token_index: zml.Tensor, kv_cache: KvCache) struct { zml.Tensor, KvCache } {
        const embeds = embed(self.embed_tokens, tokens);

        var hidden = embeds;
        var updated_kv_cache = kv_cache;
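        // Each layer reads and writes its own slice of the stacked KV cache,
        // selected with `atLayer(i)`.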
        for (self.layers, 0..) |layer, i| {
            hidden, updated_kv_cache = zml.call(layer, .forward, .{ hidden, token_index, updated_kv_cache.atLayer(i) });
        }
        const output = zml.call(self.norm, .forward, .{hidden});

        return .{ output, updated_kv_cache.reuseBuffer(kv_cache) };
    }

    pub fn embed(embed_tokens_: zml.nn.TokenEmbedding, tokens_: zml.Tensor) zml.Tensor {
        return zml.call(embed_tokens_, .forward, .{tokens_}).withPartialTags(.{.d});
    }
};

pub const TransformerLayer = struct {
    input_layernorm: RmsNorm,
    self_attn: SelfAttn,
    post_attention_layernorm: RmsNorm,
    mlp: MoE,

    pub fn forward(
        self: TransformerLayer,
        x0: zml.Tensor,
        token_index: zml.Tensor,
        kv_cache: KvCache,
    ) struct { zml.Tensor, KvCache } {
        // Self Attention
        //log.debug("TransformerLayer({}) -> {}", .{ x0, self.input_layernorm.forward(x0) });
        stdx.debug.assert(x0.rank() >= 2 and x0.shape().hasTags(.{ .s, .d }), "TransformerLayer expected input shape: {{..., .s, .d}}, received: {f}", .{x0});

        const x0_normalized = zml.call(self.input_layernorm, .forward, .{x0});
        const delta0, const updated_kv_cache = zml.call(self.self_attn, .forward, .{ x0_normalized, token_index, kv_cache });
        const x1 = x0.add(delta0);

        // Fully Connected
        const x1_normalized = zml.call(self.post_attention_layernorm, .forward, .{x1});
        const x2 = zml.call(self.mlp, .forward, .{x1_normalized}).add(x1);

        return .{ x2.reuseBuffer(x0), updated_kv_cache };
    }
};

const RmsNorm = struct {
    weight: zml.Tensor,
    eps: f32 = 1e-6,

    /// RMS normalization of the input tensor along the `.d` axis.
    pub fn forward(self: RmsNorm, input: zml.Tensor) zml.Tensor {
        const x = if (input.shape().isFullyTagged()) input else input.withPartialTags(.{.d});

        // Note: contrary to Llama, here the full layer is done in .f32, not just the variance computation.
        const normalized = zml.nn.rmsNorm(x.convert(.f32), .d, self.eps);
        return normalized.mul(self.weight.convert(.f32).withTags(.{.d}).broad(x.shape())).convert(input.dtype());
    }
};

const MoE = struct {
    experts: Mlp,
    router: zml.nn.Linear,
    moe_opts: MoeOpts,

    pub fn forward(self: MoE, input: zml.Tensor) zml.Tensor {
        log.warn("compiling moe with {f}", .{input});
        // Note: GptOss applies softmax on the routing score.
        // We delay the softmax to mixtureOfExperts where the actual routing is done.
        // This allows re-routing without introducing NaNs.
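        // Shapes (informal): the router produces one logit per (token, expert),
        // so `gating` below is tagged .{ .s, .expert }; see `mixtureOfExperts`.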
        const gating = self.router.forward(input);
        return mixtureOfExperts(Mlp, self.experts, input, gating, self.moe_opts);
    }

    pub const OnDisk = struct {
        router: zml.nn.Linear,
        experts: struct {
            down_proj_bias: zml.Tensor,
            down_proj_blocks: zml.Tensor,
            down_proj_scales: zml.Tensor,
            gate_up_proj_bias: zml.Tensor,
            gate_up_proj_blocks: zml.Tensor,
            gate_up_proj_scales: zml.Tensor,
        },

        pub fn rewrite(on_disk: OnDisk, experts_per_token: u32, options: Options) MoE {
            const e = on_disk.experts;
            return .{
                .experts = .{
                    .gate_up_proj = .{
                        // We need to bitcast the scale because safetensors doesn't encode f8 types correctly.
                        .scale = e.gate_up_proj_scales.withTags(.{ .expert, .out, .d }),
                        // We don't bitcast here because PJRT doesn't handle packed host buffers.
                        .blocks = e.gate_up_proj_blocks.withTags(.{ .expert, .out, .d, .d_block }),
                        .blocks_dtype = .f4e2m1,
                        .bias = e.gate_up_proj_bias.withTags(.{ .expert, .d }),
                    },
                    .down_proj = .{
                        .blocks = e.down_proj_blocks.withTags(.{ .expert, .out, .d, .d_block }),
                        .blocks_dtype = .f4e2m1,
                        .scale = e.down_proj_scales.withTags(.{ .expert, .out, .d }),
                        .bias = e.down_proj_bias.withTags(.{ .expert, .d }),
                    },
                },
                .router = .{
                    .weight = on_disk.router.weight.withTags(.{ .expert, .d }),
                    .bias = on_disk.router.bias.?.withTags(.{.expert}),
                },
                .moe_opts = .{
                    .experts_per_token = experts_per_token,
                    .tokens_per_expert_ratio = options.tokens_per_expert_ratio,
                    .normalization = .softmax,
                },
            };
        }
    };
};

pub const Mlp = struct {
    gate_up_proj: BlockScaledLinear, // {.out = intermediate_size * 2, .d = hidden_size / block_size, .d_block = block_size }
    down_proj: BlockScaledLinear, // {.out = hidden_size, .d = intermediate_size / block_size, .d_block = block_size }

    pub fn forward(self: Mlp, x: zml.Tensor) zml.Tensor {
        const dt = x.dtype();
        var gate, var up = zml.nn.splitRealImg(self.gate_up_proj.forward(x), .interleaved);
        gate = .minimum(gate, .scalar(7, dt));
        up = .clamp(up, .scalar(-7, dt), .scalar(7, dt));
        const out = gate.quickGelu().mul(up.addConstant(1));
        return zml.call(self.down_proj, .forward, .{out});
    }

    pub fn format(self: Mlp, writer: *std.Io.Writer) std.Io.Writer.Error!void {
        try writer.print("Mlp(gate_up_proj=.{f}, down_proj=.{f})", .{ self.gate_up_proj, self.down_proj });
    }
};

pub const SelfAttn = struct {
    q_proj: zml.nn.Linear,
    k_proj: zml.nn.Linear,
    v_proj: zml.nn.Linear,
    sinks: zml.Tensor,
    o_proj: zml.nn.Linear,

    sliding_window: ?u32,
    num_heads: i64,
    num_kv_heads: i64,
    rope_opts: zml.nn.RopeOpts,

    /// Self Attention.
    /// - If token_index is set, x is assumed to be the representation of one new token,
    /// and kv_cache will be read for the previous tokens.
    /// - If token_index is not set, x is assumed to be the representation of all tokens
    /// since the beginning of the sequence, and kv_cache won't be read.
    /// In both cases, kv_cache will be updated with the computed key and value.
    /// x: {.b, .s, .d } -> .{.b, .s, .d}
    pub fn forward(
        self: SelfAttn,
        x: zml.Tensor,
        token_index: zml.Tensor,
        kv_cache: KvCache,
    ) struct { zml.Tensor, KvCache } {
        const num_kv_heads = self.num_kv_heads;
        var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto }).withSharding(.{.h});
        var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
        var v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});

        // Generate the attention mask.
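        // We build the full causal (optionally sliding-window) mask over the
        // whole cache, then gather the rows for the current query positions.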
        const seq_len = kv_cache.k.dim(.k);
        var attn_mask = zml.nn.causalAttnMask(.{ .q = seq_len, .k = seq_len }, x.dtype(), self.sliding_window);
        // Note: in Pytorch it would be very inefficient to generate the full attn_mask,
        // then slice into it, but XLA is able to optimize this correctly.
        attn_mask = attn_mask.gatherSlices(zml.Shape.init(.{ .q = x.dim(.s) }, attn_mask.dtype()), token_index.reshape(.{ .coord = 1 }), .{});

        // In self-attention, .s axis is used both for keys and queries.
        const pos_index = b: {
            const temp = zml.Tensor.arange(.{ .end = x.dim(.s) }, token_index.dtype()).withTags(.{.s}).broad(zml.Shape.init(.{ .s = x.dim(.s) }, token_index.dtype()));
            break :b temp.add(token_index.broad(temp.shape()));
        };
        q = zml.nn.rope(q, pos_index, self.rope_opts);
        k = zml.nn.rope(k, pos_index, self.rope_opts);
        q = q.rename(.{ .s = .q });
        k = k.rename(.{ .s = .k });
        v = v.rename(.{ .s = .k });

        const dtype = q.dtype();
        const new_kv_cache = kv_cache.update(k, v, token_index);
        k = new_kv_cache.keys().convert(dtype);
        v = new_kv_cache.values().convert(dtype);

        // TODO: ringbuffer kv cache.
        const softmax_bias = self.sinks.withTags(.{.h});
        const attn_output = zml.nn.sdpa(q, k, v, .{ .attn_mask = attn_mask, .softmax_bias = softmax_bias });
        const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s });
        return .{ zml.call(self.o_proj, .forward, .{attn}), new_kv_cache };
    }
};

pub const KvCache = struct {
    k: zml.Tensor,
    v: zml.Tensor,
    layer_index: zml.Tensor,

    pub fn init(kv_shape: zml.Shape) KvCache {
        // The KV-cache is initialized with ones to detect reads of uninitialized memory.
        return .{
            .k = .constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
            .v = .constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
            .layer_index = .scalar(-1, .u32),
        };
    }

    pub fn initShape(kv_shape: zml.Shape) zml.ShapeOf(KvCache) {
        return .{
            .k = kv_shape,
            .v = kv_shape,
            .layer_index = zml.Shape.init(.{}, .u32),
        };
    }

    pub fn initBuffer(kv_shape: zml.Shape, platform: zml.Platform) !zml.Bufferized(KvCache) {
        return .{
            .k = try zml.Buffer.uninitialized(platform, kv_shape, .{}),
            .v = try zml.Buffer.uninitialized(platform, kv_shape, .{}),
            .layer_index = try zml.Buffer.uninitialized(platform, .scalar(.u32), .{}),
        };
    }

    pub fn keys(self: KvCache) zml.Tensor {
        return self.k.dynamicSlice(.{ .layer = zml.Tensor.DynSlice{ .start = self.layer_index, .len = 1 } }).squeeze(.layer);
    }

    pub fn values(self: KvCache) zml.Tensor {
        return self.v.dynamicSlice(.{ .layer = zml.Tensor.DynSlice{ .start = self.layer_index, .len = 1 } }).squeeze(.layer);
    }

    pub fn update(self: KvCache, new_k: zml.Tensor, new_v: zml.Tensor, token_index: ?zml.Tensor) KvCache {
        const idx = if (token_index) |idx| idx else zml.Tensor.arange(.{ .end = new_k.dim(.k) }, .u32);
        return .{
            .k = self.k.scatterSlices(
                .{ .k = idx, .layer = self.layer_index },
                new_k.convert(self.k.dtype()),
                .{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
            ).reuseBuffer(self.k),
            .v = self.v.scatterSlices(
                .{ .k = idx, .layer = self.layer_index },
                new_v.convert(self.v.dtype()),
                .{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
            ).reuseBuffer(self.v),
            .layer_index = self.layer_index,
        };
    }

    pub fn atLayer(self: KvCache, layer_index: usize) KvCache {
        return .{
            .k = self.k,
            .v = self.v,
            .layer_index = .scalar(layer_index, .u32),
        };
    }

    pub fn reuseBuffer(self: KvCache, other: KvCache) KvCache {
        return .{
            .k = self.k.reuseBuffer(other.k),
            .v = self.v.reuseBuffer(other.v),
            .layer_index = self.layer_index.reuseBuffer(other.layer_index),
        };
    }
};
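
/// Linear layer over block-scaled quantized weights: each `.d_block` group of
/// low-precision values (here f4e2m1) shares one f8e8m0 scale, an MXFP4-style
/// layout. The effective weight is `blocks * scale`, dequantized on the fly
/// in `forward`.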
pub const BlockScaledLinear = struct {
    blocks: zml.Tensor,
    scale: zml.Tensor,
    bias: ?zml.Tensor = null,
    blocks_dtype: zml.DataType,

    pub fn forward(self: BlockScaledLinear, x: zml.Tensor) zml.Tensor {
        const ctx = x.getContext();
        const res_shape = x.shape().setDim(-1, self.blocks.dim(-3));
        // Bitcast to our actual type. This allows loading the weights in a packed layout.
        const blocks_0 = self.blocks.bitCast(self.blocks_dtype);
        const blocks = blocks_0.merge(.{ .d_block = .{ .d_block, .bitcast } });
        const scale = self.scale.bitCast(.f8e8m0);
        // log.warn("BlockScaledLinear({}): {f} -> {f}", .{ self, x, res_shape });
        const y = switch (ctx._platform.target) {
            else => y: {
                var dequantized_weight: zml.Tensor = .mul(
                    blocks.convert(x.dtype()),
                    scale.convert(x.dtype()).appendAxes(.{.d_block}),
                );
                var y = x.dot(dequantized_weight.merge(.{ .d = .{ .d, .d_block } }), .{.d});
                // std.log.warn("output shape: {f}", .{y});
                std.debug.assert(y.shape().eql(res_shape));
                y._shape = res_shape;
                break :y y;
            },
        };
        return if (self.bias) |bias| y.add(bias.broad(y.shape())) else y;
    }

    pub fn format(self: BlockScaledLinear, writer: *std.Io.Writer) !void {
        try writer.print("BlockScaledLinear(blocks={f}, scale={f}, bias={?f}, dt={t})", .{ self.blocks, self.scale, self.bias, self.blocks_dtype });
    }
};

const MoeOpts = struct {
    experts_per_token: u32,
    tokens_per_expert_ratio: ?f32 = 0.0,
    normalization: Normalization,

    pub const Normalization = enum { linear, softmax };
};

/// We have three algorithms:
/// * one for single-stream inference (naive),
/// * one for small batch sizes, with exact precision, that sends all tokens to all experts.
///   This isn't too costly as long as the batch size is small and the experts are IO bound.
/// * one for big batch sizes that assigns a fixed compute budget per expert;
///   experts choose the tokens they want to handle. This introduces noise since it's possible
///   a token doesn't get its requested expert.
///   The parameter `tokens_per_expert_ratio` controls how much compute budget is granted:
///   expert_budget = ratio * (num_tokens * experts_per_token / num_experts).
///   Bigger values of ratio make it rarer that a token doesn't get its top-2 experts.
///
/// The preferred algorithm is the batched one;
/// it is selected as soon as there are enough tokens to guarantee that experts will be active most of the time.
///
/// - input: .{ .s, .d } per-entry vector
/// - gating: .{ .s, .expert } per-entry expert-affinity
/// - experts: .{ .expert, .d_out, .d } expert layer (needs to have a .forward method).
/// -> output: .{ .s, .d_out }
pub fn mixtureOfExperts(Expert: type, experts: Expert, input: zml.Tensor, gating: zml.Tensor, opts: MoeOpts) zml.Tensor {
    log.warn("mixtureOfExperts({s}, {f}, {f}, {})", .{ @typeName(Expert), input, gating, opts });
    const num_tokens: u32 = @intCast(input.dim(.s));
    const num_experts = gating.dim(.expert);
    stdx.debug.assert(opts.experts_per_token > 0, "mixtureOfExperts expects opts.experts_per_token > 0, got {}", .{opts});

    if (num_tokens == 1) {
        return moePerTokenRouting(Expert, experts, input, gating, opts);
    }

    const tokens_per_expert: u32 = if (opts.tokens_per_expert_ratio) |ratio| tpe: {
        const compute_budget = ratio * @as(f32, @floatFromInt(num_tokens * opts.experts_per_token));
        var tpe: u32 = @intFromFloat(stdx.math.divFloat(f32, compute_budget, num_experts));
        // Round to next multiple of 8 to avoid weird shapes.
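        // Illustrative numbers (not from any specific checkpoint): with
        // num_tokens = 100, experts_per_token = 4, num_experts = 32 and
        // ratio = 1.0, the budget is 1.0 * 100 * 4 = 400, so
        // tpe = floor(400 / 32) = 12, rounded up to 16 below.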
        if (tpe % 8 != 0) tpe += 8 - (tpe % 8);
        break :tpe tpe;
    } else num_tokens;

    if (3 * tokens_per_expert <= 2 * num_tokens) {
        const routing, const tokens_ids_per_expert = dispatchTokens(gating, .{
            .tokens_per_expert = tokens_per_expert,
            .experts_per_token = opts.experts_per_token,
            .normalization = opts.normalization,
        });
        const scores_per_expert = routing.transpose(.{ .expert, .s }).gather(.{ .s = tokens_ids_per_expert }, .{});
        const input_per_expert = input.gather(.{ .s = tokens_ids_per_expert }, .{});
        var output_per_expert = experts.forward(input_per_expert);
        output_per_expert = output_per_expert.mul(scores_per_expert.convert(output_per_expert.dtype()).broad(output_per_expert.shape()));

        // Reverse engineer the normal output shape that one expert would have produced for all tokens.
        // If this falls short, we could use the "sliced_expert" strategy and call forward ourselves.
        const output_shape = output_per_expert.shape().drop(.expert).rename(.{ .top_token = .s }).setDim(.s, num_tokens);
        const output = zml.Tensor.scatterSlices(
            .constant(output_shape, output_shape.dtype().zero()),
            .{ .s = tokens_ids_per_expert },
            output_per_expert,
            .{ .update_fn = zml.Tensor.ScatterOpts.increment },
        );
        log.warn("mixtureOfExperts({s}, {f}, {f}) -> fixed budget impl tpe: {d}, tokens: {d}", .{ @typeName(Expert), input, gating, tokens_per_expert, num_tokens });
        return output;
    } else {
        return mixtureOfExpertsAllToAll(Expert, experts, input, gating, opts);
    }
}

/// Few tokens: most experts are unused, experts have at most one token.
/// Select active experts and compute with that.
pub fn moePerTokenRouting(Expert: type, experts: Expert, input: zml.Tensor, gating: zml.Tensor, opts: MoeOpts) zml.Tensor {
    const num_tokens: u32 = @intCast(input.dim(.s));
    stdx.debug.assert(num_tokens < 32, "Trying to unroll a lot of tokens!", .{});
    const per_token_outputs = input.getContext().allocator().alloc(zml.Tensor, num_tokens) catch @panic("OOM");

    const routing = gating.topK(.{ .top_expert = .expert }, opts.experts_per_token, .{});
    const per_token_score = switch (opts.normalization) {
        .linear => routing.values.div(routing.values.sum(.top_expert)),
        .softmax => routing.values.softmax(.top_expert),
    };

    for (per_token_outputs, 0..num_tokens) |*output, tok_id| {
        for (0..opts.experts_per_token) |expert_rank| {
            const expert_id = routing.indices.choose(.{ .s = tok_id, .top_expert = expert_rank }).asScalar();
            const expert_score = per_token_score.choose(.{ .s = tok_id, .top_expert = expert_rank }).asScalar();

            var sliced_expert: Expert = undefined;
            zml.meta.mapAlloc(struct {
                pub fn cb(expert_id_: zml.Tensor, expert_weight: zml.Tensor) zml.Tensor {
                    return expert_weight.gather(.{ .expert = expert_id_ }, .{});
                }
            }.cb, stdx.noalloc, expert_id, experts, &sliced_expert) catch unreachable;

            // TODO: how does this work when the two experts are on different GPUs?
            // Does the compute overlap?
            var expert_output = sliced_expert.forward(input.choose(.{ .s = tok_id }));
            expert_output = .mul(
                expert_output,
                expert_score.convert(input.dtype()).broad(expert_output.shape()),
            );
            output.* = if (expert_rank > 0) output.add(expert_output) else expert_output;
        }
    }
    log.warn("mixtureOfExperts({s}, {f}, {f}) -> single-stream impl", .{ @typeName(Expert), input, gating });
    return .stack(per_token_outputs, 0, .s);
}

/// Send all tokens to all experts, and apply gating.
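/// Chosen when the per-expert budget would cover more than ~2/3 of all tokens
/// (the `3 * tokens_per_expert <= 2 * num_tokens` check above), at which point
/// the routing machinery saves little compute.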
pub fn mixtureOfExpertsAllToAll(Expert: type, experts: Expert, input: zml.Tensor, gating: zml.Tensor, opts: MoeOpts) zml.Tensor {
    log.warn("mixtureOfExperts({s}, {f}, {f}) -> all to all impl", .{ @typeName(Expert), input, gating });
    const num_experts = gating.dim(.expert);
    const hard_gating = hardGating(gating, opts).print();
    // TODO: `input.insertAxes(0, .{.expert}).repeat1d(.expert, num_experts)` is too verbose for just broadcasting along a new axis.
    const output_per_expert = experts.forward(input.insertAxes(0, .{.expert}).repeat1d(.expert, @intCast(num_experts)));
    return output_per_expert.dot(hard_gating.convert(input.dtype()), .expert);
}

/// Given `(token, expert) -> scores`,
/// keeps only the top-k experts per token, and normalizes the scores accordingly.
/// Non-selected experts will have a 0 score.
pub fn hardGating(gating: zml.Tensor, opts: MoeOpts) zml.Tensor {
    const routing = gating.topK(.{ .top_expert = .expert }, opts.experts_per_token, .{});
    const per_token_score = switch (opts.normalization) {
        .linear => routing.values.div(routing.values.sum(.top_expert)),
        .softmax => routing.values.softmax(.top_expert),
    };
    return zml.Tensor.scatterSlices(
        .zeroes(gating.shape()),
        .{ .expert = routing.indices },
        per_token_score,
        .{ .indices_are_unique = true },
    );
}

/// Many tokens: each expert chooses its tokens.
/// This means some tokens may have only one expert assigned.
/// Each token will get assigned to at least one expert iff the input gating sums up to 1 (typically softmax output).
/// Returns the actual `(token, expert) -> scores` used.
pub fn dispatchTokens(
    gating: zml.Tensor,
    opts: struct {
        tokens_per_expert: u32,
        experts_per_token: u32,
        normalization: MoeOpts.Normalization,
    },
) [2]zml.Tensor {
    const num_experts = gating.dim(.expert);
    const token_pref = gating.argsort(.expert, .{ .descending = true });
    const expert_rank: zml.Tensor = .scatterSlices(
        .zeroes(gating.shape().withDtype(.i32)),
        .{ .expert = token_pref },
        .addConstant(.iota(gating.shape(), .expert), 1),
        .{ .indices_are_unique = true },
    );
    // The pow(expert_rank) here means that we strongly favor top 1 over top 2, and top 2 over top 3.
    // expert_routing: (expert, top_token) -> token
    const expert_routing = gating.pow(expert_rank.convert(gating.dtype())).topK(.{ .top_token = .s }, opts.tokens_per_expert, .{});
    const scores_per_expert = gating.gather(.{ .s = expert_routing.indices }, .{});

    // Update the gating coefficient to account for the expert routing.
    // Each (token, expert) which can't be computed within the given budget is left at 0.
    const gating_v2: zml.Tensor = .scatterSlices(
        .zeroes(gating.shape()),
        .{ .s = expert_routing.indices },
        scores_per_expert,
        .{ .indices_are_unique = true, .update_fn = zml.Tensor.ScatterOpts.override },
    );

    // Now set to zero the scores (token, expert) for tokens that have been assigned more than experts_per_token.
    const lowest_experts = gating_v2.topK(.{ .top_expert = .expert }, @intCast(num_experts - opts.experts_per_token), .{ .descending = false });
    var gating_v3: zml.Tensor = .scatterSlices(
        gating_v2,
        .{ .expert = lowest_experts.indices },
        .zeroes(lowest_experts.values.shape()),
        .{ .indices_are_unique = true, .update_fn = zml.Tensor.ScatterOpts.override },
    );

    // Then normalize so the sum of expert scores for one token sums up to 1.
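    // E.g. with .linear normalization, a token left with scores {0, 0.3, 0.1}
    // ends up with {0, 0.75, 0.25} (illustrative numbers).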
    gating_v3 = switch (opts.normalization) {
        .linear => gating_v3.div(gating_v3.sum(.expert)),
        .softmax => gating_v3.softmax(.expert),
    };
    const tokens_ids_per_expert = expert_routing.indices.transpose(.{ .expert, .top_token });
    return .{ gating_v3, tokens_ids_per_expert };
}
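
// Worked example for `hardGating` (illustrative numbers, assuming 4 experts,
// experts_per_token = 2, and .linear normalization):
//   gating row for one token: { 0.1, 0.5, 0.3, 0.1 }
//   top-2 experts: indices {1, 2} with values {0.5, 0.3}
//   linear normalization: {0.5, 0.3} / 0.8 -> {0.625, 0.375}
//   scattered back: { 0, 0.625, 0.375, 0 }
// Only the selected experts keep a non-zero weight in the final mixture.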