Add example implementation and Bazel build for OpenAI gpt-oss models (GptOss.zig, main.zig, and BUILD.bazel).

This commit is contained in:
Foke Singh 2025-09-30 16:17:00 +00:00
parent d45a667ee5
commit e1b7fc5781
3 changed files with 1133 additions and 0 deletions

16
examples/gpt_oss/BUILD.bazel Normal file

@@ -0,0 +1,16 @@
load("@rules_zig//zig:defs.bzl", "zig_binary", "zig_test")
zig_binary(
name = "gpt_oss",
srcs = [
"GptOss.zig",
],
main = "main.zig",
deps = [
"@com_github_hejsil_clap//:clap",
"@zml//async",
"@zml//stdx",
"@zml//zml",
],
visibility = ["//visibility:public"],
)
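# Typical invocation (assuming this package is checked in under examples/gpt_oss
# and that a gpt-oss checkpoint has been downloaded locally):
#   bazel run //examples/gpt_oss:gpt_oss -- --hf-model-path=/path/to/gpt-oss-20b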

741
examples/gpt_oss/GptOss.zig Normal file

@@ -0,0 +1,741 @@
//! GptOss architecture, using huggingface transformers naming.
//! Dimensions of activations: {.b, .s, .d}
const std = @import("std");
const stdx = @import("stdx");
const zml = @import("zml");
const GptOss = @This();
const log = std.log.scoped(.GptOss);
pub const Config = struct {
bos_token_id: u32 = 199998,
eos_token_id: stdx.json.Union(union(enum) {
int: u32,
ints: []const u32,
}),
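// config.json may store eos_token_id either as a single id or as a list of ids;
// main.zig overrides it with the list taken from generation_config.json.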
head_dim: u32,
num_hidden_layers: u32,
num_attention_heads: u32,
num_key_value_heads: u32,
experts_per_token: u32,
rope_theta: f32,
max_position_embeddings: u32,
rms_norm_eps: f32,
sliding_window: u32,
hf_rope_impl: bool = true,
rope_scaling: zml.nn.RopeOpts.Scaling = .{ .default = {} },
};
pub const Options = struct {
sampling_strategy: zml.nn.SamplingStrategy,
max_seq_len: u32,
max_prompt_len: u32,
tokens_per_expert_ratio: f32,
};
pub const Mode = union(enum) {
/// In prefill mode we pass the actual len of the prompt
prefill: zml.Tensor,
/// In gen mode we pass the position of the next token
gen: zml.Tensor,
};
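// Illustrative usage (mirroring main.zig): the compiled prefill module is called with
// .{ .prefill = prompt_len } and the generation module with .{ .gen = token_pos }.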
lm_head: ?zml.nn.Linear,
model: Model,
config: Config,
options: Options,
pub fn init(allocator: std.mem.Allocator, store: zml.aio.BufferStore, config: Config, options: Options) !GptOss {
var self: GptOss = .{
.config = config,
.options = options,
.model = .{
.max_seq_len = @intCast(options.max_seq_len),
.num_heads = @intCast(config.num_attention_heads),
.num_kv_heads = @intCast(config.num_key_value_heads),
.rope_opts = .{
.layout = if (config.hf_rope_impl) .sequential else .interleaved,
.freq_base = config.rope_theta,
.scaling = config.rope_scaling,
},
.embed_tokens = .{
.weight = store.getTensor("model.embed_tokens.weight").withSharding(.{1}),
},
.layers = try allocator.alloc(TransformerLayer, config.num_hidden_layers),
.norm = .{
.weight = store.getTensor("model.norm.weight"),
.eps = config.rms_norm_eps,
},
},
.lm_head = .{ .weight = store.getTensor("lm_head.weight").withSharding(.{0}) },
};
var prefix: zml.aio.PrefixBuilder = try .initCapacity(allocator, 1024);
try prefix.push(stdx.noalloc, "model.layers");
for (self.model.layers, 0..) |*layer, i| {
try prefix.pushDigit(stdx.noalloc, i);
defer prefix.pop();
var self_attn: SelfAttn = .{
.sinks = store.getTensor(prefix.concat("self_attn.sinks")),
.q_proj = try zml.aio.populateModelWithPrefix(zml.nn.Linear, allocator, store, prefix.concat("self_attn.q_proj")),
.k_proj = try zml.aio.populateModelWithPrefix(zml.nn.Linear, allocator, store, prefix.concat("self_attn.k_proj")),
.v_proj = try zml.aio.populateModelWithPrefix(zml.nn.Linear, allocator, store, prefix.concat("self_attn.v_proj")),
.o_proj = try zml.aio.populateModelWithPrefix(zml.nn.Linear, allocator, store, prefix.concat("self_attn.o_proj")),
.sliding_window = if (i % 2 == 0) config.sliding_window else null,
.num_heads = self.model.num_heads,
.num_kv_heads = self.model.num_kv_heads,
.rope_opts = self.model.rope_opts,
};
self_attn.q_proj.weight = self_attn.q_proj.weight.withSharding(.{0});
self_attn.k_proj.weight = self_attn.k_proj.weight.withSharding(.{0});
self_attn.v_proj.weight = self_attn.v_proj.weight.withSharding(.{0});
self_attn.o_proj.weight = self_attn.o_proj.weight.withSharding(.{1});
const on_disk_moe = try zml.aio.populateModelWithPrefix(MoE.OnDisk, allocator, store, prefix.concat("mlp"));
var moe = on_disk_moe.rewrite(config.experts_per_token, options);
{
moe.experts.gate_up_proj.blocks = moe.experts.gate_up_proj.blocks.withSharding(.{.expert});
moe.experts.down_proj.blocks = moe.experts.down_proj.blocks.withSharding(.{.expert});
}
layer.* = .{
.input_layernorm = .{
.weight = store.getTensor(prefix.concat("input_layernorm.weight")),
.eps = config.rms_norm_eps,
},
.post_attention_layernorm = .{
.weight = store.getTensor(prefix.concat("post_attention_layernorm.weight")),
.eps = config.rms_norm_eps,
},
.self_attn = self_attn,
.mlp = moe,
};
}
// TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled.
// It currently crashes or fails to compile.
if (self.options.sampling_strategy.topk == 1 and self.lm_head != null) {
self.lm_head.?.weight = self.lm_head.?.weight.withSharding(.{0});
}
return self;
}
/// Predicts the token at the `token_index` position.
/// Returns:
/// - the updated `tokens`,
/// - the updated KV cache,
/// - an Rng state to allow for probabilistic generation.
pub fn forward(
self: GptOss,
tokens_: zml.Tensor,
mode: Mode,
kv_cache: KvCache,
rng: zml.Tensor.Rng,
) struct { zml.Tensor, KvCache, zml.Tensor.Rng } {
const tokens = tokens_.withPartialTags(.{.s});
// token_index is the position in the kv cache at which results are written.
const token_index: zml.Tensor = switch (mode) {
.gen => |token_index| token_index,
.prefill => .scalar(0, .u32),
};
var out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, token_index, kv_cache });
switch (mode) {
// In prefill we only pass the last token to the lm head.
.prefill => |prompt_len| out = out.gather(.{ .s = prompt_len.convert(.i32).addConstant(-1) }, .{ .indices_are_sorted = true }),
.gen => {},
}
var new_token, const new_rng = self.sampleTokens(self.lm_head, out, rng, self.options.sampling_strategy);
new_token = new_token.convert(.u32);
new_token = switch (mode) {
.gen => new_token.reuseBuffer(tokens),
.prefill => new_token.appendAxes(.{.s}),
};
return .{ new_token, updated_kv_cache, new_rng };
}
fn sampleTokens(
self: GptOss,
lm_head_: ?zml.nn.Linear,
out_: zml.Tensor,
rng: zml.Tensor.Rng,
opts: zml.nn.SamplingStrategy,
) struct { zml.Tensor, zml.Tensor.Rng } {
const out = out_.withPartialTags(.{.d});
var logits = blk: {
if (lm_head_) |lm_head| {
break :blk zml.call(lm_head, .forward, .{out});
} else {
break :blk self.model.embed_tokens.weight.withTags(.{ .voc, .d }).dot(out, .{.d});
}
};
if (logits.shape().hasTag(.voc) == null)
logits = logits.rename(.{ .d = .voc });
const next_tokens, const new_rng = zml.nn.sampleTokens(logits, opts, rng);
return .{ next_tokens, new_rng };
}
pub fn loadBuffers(self: GptOss, allocator: std.mem.Allocator, store: zml.aio.BufferStore, platform: zml.Platform) !zml.Bufferized(GptOss) {
var prefix: zml.aio.PrefixBuilder = try .initCapacity(allocator, 256);
defer prefix.deinit(allocator);
const noalloc = stdx.noalloc;
const loaded: zml.Bufferized(GptOss) = .{
.model = .{
.embed_tokens = try store.loadModelById(zml.nn.TokenEmbedding, noalloc, self.model.embed_tokens, platform),
.layers = try allocator.alloc(zml.Bufferized(TransformerLayer), self.model.layers.len),
.norm = try store.loadModelById(RmsNorm, noalloc, self.model.norm, platform),
},
.lm_head = try store.loadModelById(?zml.nn.Linear, noalloc, self.lm_head, platform),
};
prefix.push(noalloc, "model.layers") catch unreachable;
for (loaded.model.layers, self.model.layers, 0..) |*d_layer, layer, layer_id| {
const ckpt = prefix.checkpoint();
defer prefix.restore(ckpt);
prefix.pushDigit(noalloc, layer_id) catch unreachable;
d_layer.* = .{
.input_layernorm = try store.loadModelById(RmsNorm, noalloc, layer.input_layernorm, platform),
.self_attn = try store.loadModelById(SelfAttn, noalloc, layer.self_attn, platform),
.post_attention_layernorm = try store.loadModelById(RmsNorm, noalloc, layer.post_attention_layernorm, platform),
.mlp = try store.loadModelById(MoE, noalloc, layer.mlp, platform),
};
}
return loaded;
}
pub const Model = struct {
embed_tokens: zml.nn.TokenEmbedding,
norm: RmsNorm,
layers: []TransformerLayer,
max_seq_len: u32 = 0,
num_heads: i64 = 32,
num_kv_heads: i64 = 32,
rope_opts: zml.nn.RopeOpts = .{
.layout = .interleaved,
.freq_base = 10_000,
},
/// Forward one token, using KV cache for previous tokens.
/// Returns result and updated KV cache.
pub fn forward(self: Model, tokens: zml.Tensor, token_index: zml.Tensor, kv_cache: KvCache) struct { zml.Tensor, KvCache } {
const embeds = embed(self.embed_tokens, tokens);
var hidden = embeds;
var updated_kv_cache = kv_cache;
for (self.layers, 0..) |layer, i| {
hidden, updated_kv_cache = zml.call(layer, .forward, .{ hidden, token_index, updated_kv_cache.atLayer(i) });
}
const output = zml.call(self.norm, .forward, .{hidden});
return .{ output, updated_kv_cache.reuseBuffer(kv_cache) };
}
pub fn embed(embed_tokens_: zml.nn.TokenEmbedding, tokens_: zml.Tensor) zml.Tensor {
return zml.call(embed_tokens_, .forward, .{tokens_}).withPartialTags(.{.d});
}
};
pub const TransformerLayer = struct {
input_layernorm: RmsNorm,
self_attn: SelfAttn,
post_attention_layernorm: RmsNorm,
mlp: MoE,
pub fn forward(
self: TransformerLayer,
x0: zml.Tensor,
token_index: zml.Tensor,
kv_cache: KvCache,
) struct { zml.Tensor, KvCache } {
// Self Attention
//log.debug("TransformerLayer({}) -> {}", .{ x0, self.input_layernorm.forward(x0) });
stdx.debug.assert(x0.rank() >= 2 and x0.shape().hasTags(.{ .s, .d }), "TransformerLayer expected input shape: {{..., .s, .d}}, received: {f}", .{x0});
const x0_normalized = zml.call(self.input_layernorm, .forward, .{x0});
const delta0, const updated_kv_cache = zml.call(self.self_attn, .forward, .{ x0_normalized, token_index, kv_cache });
const x1 = x0.add(delta0);
// Fully Connected
const x1_normalized = zml.call(self.post_attention_layernorm, .forward, .{x1});
const x2 = zml.call(self.mlp, .forward, .{x1_normalized}).add(x1);
return .{ x2.reuseBuffer(x0), updated_kv_cache };
}
};
const RmsNorm = struct {
weight: zml.Tensor,
eps: f32 = 1e-6,
/// L2 normalization of input tensor along `.d` axis.
pub fn forward(self: RmsNorm, input: zml.Tensor) zml.Tensor {
const x = if (input.shape().isFullyTagged()) input else input.withPartialTags(.{.d});
// Note: unlike Llama, the whole layer is computed in .f32 here, not just the variance computation.
const normalized = zml.nn.rmsNorm(x.convert(.f32), .d, self.eps);
return normalized.mul(self.weight.convert(.f32).withTags(.{.d}).broad(x.shape())).convert(input.dtype());
}
};
const MoE = struct {
experts: Mlp,
router: zml.nn.Linear,
moe_opts: MoeOpts,
pub fn forward(self: MoE, input: zml.Tensor) zml.Tensor {
log.warn("compiling moe with {f}", .{input});
// Note: GptOss applies softmax on the routing score.
// We delay the softmax to mixtureOfExperts where the actual routing is done.
// This allows re-routing without introducing NaNs.
const gating = self.router.forward(input);
return mixtureOfExperts(Mlp, self.experts, input, gating, self.moe_opts);
}
pub const OnDisk = struct {
router: zml.nn.Linear,
experts: struct {
down_proj_bias: zml.Tensor,
down_proj_blocks: zml.Tensor,
down_proj_scales: zml.Tensor,
gate_up_proj_bias: zml.Tensor,
gate_up_proj_blocks: zml.Tensor,
gate_up_proj_scales: zml.Tensor,
},
pub fn rewrite(on_disk: OnDisk, experts_per_token: u32, options: Options) MoE {
const e = on_disk.experts;
return .{
.experts = .{
.gate_up_proj = .{
// We need to bitcast the scale because safetensors doesn't encode f8 types correctly
.scale = e.gate_up_proj_scales.withTags(.{ .expert, .out, .d }),
// We don't bitcast here because PJRT doesn't handle packed host buffers
.blocks = e.gate_up_proj_blocks.withTags(.{ .expert, .out, .d, .d_block }),
.blocks_dtype = .f4e2m1,
.bias = e.gate_up_proj_bias.withTags(.{ .expert, .d }),
},
.down_proj = .{
.blocks = e.down_proj_blocks.withTags(.{ .expert, .out, .d, .d_block }),
.blocks_dtype = .f4e2m1,
.scale = e.down_proj_scales.withTags(.{ .expert, .out, .d }),
.bias = e.down_proj_bias.withTags(.{ .expert, .d }),
},
},
.router = .{
.weight = on_disk.router.weight.withTags(.{ .expert, .d }),
.bias = on_disk.router.bias.?.withTags(.{.expert}),
},
.moe_opts = .{
.experts_per_token = experts_per_token,
.tokens_per_expert_ratio = options.tokens_per_expert_ratio,
.normalization = .softmax,
},
};
}
};
};
pub const Mlp = struct {
gate_up_proj: BlockScaledLinear, // {.out = intermediate_size * 2, .d = hidden_size / block_size, .d_block = block_size }
down_proj: BlockScaledLinear, // {.out = hidden_size, .d = intermediate_size / block_size, .d_block = block_size }
pub fn forward(self: Mlp, x: zml.Tensor) zml.Tensor {
const dt = x.dtype();
var gate, var up = zml.nn.splitRealImg(self.gate_up_proj.forward(x), .interleaved);
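// The lines below implement the gpt-oss gated activation on the de-interleaved
// gate/up halves: gate is capped at 7, up is clamped to [-7, 7], and then
// out = quickGelu(gate) * (up + 1).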
gate = .minimum(gate, .scalar(7, dt));
up = .clamp(up, .scalar(-7, dt), .scalar(7, dt));
const out = gate.quickGelu().mul(up.addConstant(1));
return zml.call(self.down_proj, .forward, .{out});
}
pub fn format(self: Mlp, writer: *std.Io.Writer) std.Io.Writer.Error!void {
try writer.print("Mlp(gate_up_proj=.{f}, down_proj=.{f})", .{ self.gate_up_proj, self.down_proj });
}
};
pub const SelfAttn = struct {
q_proj: zml.nn.Linear,
k_proj: zml.nn.Linear,
v_proj: zml.nn.Linear,
sinks: zml.Tensor,
o_proj: zml.nn.Linear,
sliding_window: ?u32,
num_heads: i64,
num_kv_heads: i64,
rope_opts: zml.nn.RopeOpts,
/// Self Attention.
/// - If token_index is set, x is assumed to be the representation of one new token,
/// and kv_cache will be read for the previous tokens.
/// - If token_index is not set, x is assumed to be the representation of all tokens
/// since the beginning of the sequence, and kv_cache won't be read.
/// In both cases, kv_cache will be updated with the computed key and value.
/// x: {.b, .s, .d } -> .{.b, .s, .d}
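/// Shape example (illustrative dims, e.g. num_heads=64, num_kv_heads=8, head_dim=64):
/// q is split to {.s, .h = 64, .hd = 64}, k/v to {.s, .h = 8, .hd = 64}, and the keys/values
/// are written into the {.layer, .k, .h, .hd} KV cache at `token_index`.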
pub fn forward(
self: SelfAttn,
x: zml.Tensor,
token_index: zml.Tensor,
kv_cache: KvCache,
) struct { zml.Tensor, KvCache } {
const num_kv_heads = self.num_kv_heads;
var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto }).withSharding(.{.h});
var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
var v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
// Generate the attention mask.
const seq_len = kv_cache.k.dim(.k);
var attn_mask = zml.nn.causalAttnMask(.{ .q = seq_len, .k = seq_len }, x.dtype(), self.sliding_window);
// Note: in PyTorch it would be very inefficient to generate the full attn_mask,
// then slice into it, but XLA is able to optimize this correctly.
attn_mask = attn_mask.gatherSlices(zml.Shape.init(.{ .q = x.dim(.s) }, attn_mask.dtype()), token_index.reshape(.{ .coord = 1 }), .{});
// In self-attention, .s axis is used both for keys and queries.
const pos_index = b: {
const temp = zml.Tensor.arange(.{ .end = x.dim(.s) }, token_index.dtype()).withTags(.{.s}).broad(zml.Shape.init(.{ .s = x.dim(.s) }, token_index.dtype()));
break :b temp.add(token_index.broad(temp.shape()));
};
q = zml.nn.rope(q, pos_index, self.rope_opts);
k = zml.nn.rope(k, pos_index, self.rope_opts);
q = q.rename(.{ .s = .q });
k = k.rename(.{ .s = .k });
v = v.rename(.{ .s = .k });
const dtype = q.dtype();
const new_kv_cache = kv_cache.update(k, v, token_index);
k = new_kv_cache.keys().convert(dtype);
v = new_kv_cache.values().convert(dtype);
// TODO ringbuffer kv cache.
const softmax_bias = self.sinks.withTags(.{.h});
const attn_output = zml.nn.sdpa(q, k, v, .{ .attn_mask = attn_mask, .softmax_bias = softmax_bias });
const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s });
return .{ zml.call(self.o_proj, .forward, .{attn}), new_kv_cache };
}
};
pub const KvCache = struct {
k: zml.Tensor,
v: zml.Tensor,
layer_index: zml.Tensor,
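// k and v are expected to be {.layer, .k = max_seq_len, .h = num_kv_heads, .hd = head_dim}
// (see the kv_shape built in main.zig); layer_index selects the layer slice read/written
// by keys/values/update and is set per layer with atLayer.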
pub fn init(kv_shape: zml.Shape) KvCache {
// The KV-cache is initialized with ones to detect reads of uninitialized memory.
return .{
.k = .constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
.v = .constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
.layer_index = .scalar(-1, .u32),
};
}
pub fn initShape(kv_shape: zml.Shape) zml.ShapeOf(KvCache) {
return .{
.k = kv_shape,
.v = kv_shape,
.layer_index = zml.Shape.init(.{}, .u32),
};
}
pub fn initBuffer(kv_shape: zml.Shape, platform: zml.Platform) !zml.Bufferized(KvCache) {
return .{
.k = try zml.Buffer.uninitialized(platform, kv_shape, .{}),
.v = try zml.Buffer.uninitialized(platform, kv_shape, .{}),
.layer_index = try zml.Buffer.uninitialized(platform, .scalar(.u32), .{}),
};
}
pub fn keys(self: KvCache) zml.Tensor {
return self.k.dynamicSlice(.{ .layer = zml.Tensor.DynSlice{ .start = self.layer_index, .len = 1 } }).squeeze(.layer);
}
pub fn values(self: KvCache) zml.Tensor {
return self.v.dynamicSlice(.{ .layer = zml.Tensor.DynSlice{ .start = self.layer_index, .len = 1 } }).squeeze(.layer);
}
pub fn update(self: KvCache, new_k: zml.Tensor, new_v: zml.Tensor, token_index: ?zml.Tensor) KvCache {
const idx = if (token_index) |idx| idx else zml.Tensor.arange(.{ .end = new_k.dim(.k) }, .u32);
return .{
.k = self.k.scatterSlices(
.{ .k = idx, .layer = self.layer_index },
new_k.convert(self.k.dtype()),
.{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
).reuseBuffer(self.k),
.v = self.v.scatterSlices(
.{ .k = idx, .layer = self.layer_index },
new_v.convert(self.v.dtype()),
.{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
).reuseBuffer(self.v),
.layer_index = self.layer_index,
};
}
pub fn atLayer(self: KvCache, layer_index: usize) KvCache {
return .{
.k = self.k,
.v = self.v,
.layer_index = .scalar(layer_index, .u32),
};
}
pub fn reuseBuffer(self: KvCache, other: KvCache) KvCache {
return .{
.k = self.k.reuseBuffer(other.k),
.v = self.v.reuseBuffer(other.v),
.layer_index = self.layer_index.reuseBuffer(other.layer_index),
};
}
};
pub const BlockScaledLinear = struct {
blocks: zml.Tensor,
scale: zml.Tensor,
bias: ?zml.Tensor = null,
blocks_dtype: zml.DataType,
pub fn forward(self: BlockScaledLinear, x: zml.Tensor) zml.Tensor {
const ctx = x.getContext();
const res_shape = x.shape().setDim(-1, self.blocks.dim(-3));
// Bitcast to our actual type. This allows to load weights in a packed layout.
const blocks_0 = self.blocks.bitCast(self.blocks_dtype);
const blocks = blocks_0.merge(.{ .d_block = .{ .d_block, .bitcast } });
const scale = self.scale.bitCast(.f8e8m0);
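// MXFP4-style dequantization sketch (editor's note, numbers illustrative):
// each .d_block group of f4e2m1 elements shares one f8e8m0 scale, which encodes a
// power of two. The dequantized weight below is simply element * scale, e.g. an
// f4e2m1 element of 1.5 with a scale of 2^-3 yields a weight of 0.1875.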
// log.warn("BlockScaledLinear({}): {f} -> {f}", .{ self, x, res_shape });
const y = switch (ctx._platform.target) {
else => y: {
var dequantized_weight: zml.Tensor = .mul(
blocks.convert(x.dtype()),
scale.convert(x.dtype()).appendAxes(.{.d_block}),
);
var y = x.dot(dequantized_weight.merge(.{ .d = .{ .d, .d_block } }), .{.d});
// std.log.warn("output shape: {f}", .{y});
std.debug.assert(y.shape().eql(res_shape));
y._shape = res_shape;
break :y y;
},
};
return if (self.bias) |bias| y.add(bias.broad(y.shape())) else y;
}
pub fn format(self: BlockScaledLinear, writer: *std.Io.Writer) !void {
try writer.print("BlockScaledLinear(blocks={f}, scale={f}, bias={?f}, dt={t})", .{ self.blocks, self.scale, self.bias, self.blocks_dtype });
}
};
const MoeOpts = struct {
experts_per_token: u32,
tokens_per_expert_ratio: ?f32 = 0.0,
normalization: Normalization,
pub const Normalization = enum { linear, softmax };
};
/// We have three algorithms:
/// * one for single-stream inference (naive),
/// * one for small batch sizes, with exact precision, that sends all tokens to all experts.
/// This isn't too costly as long as the batch size is small and the experts are IO bound.
/// * one for big batch sizes that assigns a fixed compute budget per expert and lets
/// experts choose the tokens they want to handle. This introduces noise since it's possible
/// a token doesn't get its requested expert.
/// The parameter `tokens_per_expert_ratio` controls how much compute budget is granted:
/// expert_budget = ratio * (num_tokens * experts_per_token / num_experts).
/// Bigger values of the ratio make it rarer that a token doesn't get its top experts.
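/// Worked example (hypothetical numbers): with 128 tokens, experts_per_token = 4,
/// 32 experts and ratio = 4.0, the budget is 4.0 * 128 * 4 / 32 = 64 tokens per expert
/// (already a multiple of 8). Since 3 * 64 <= 2 * 128, the fixed-budget path is taken.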
///
/// The preferred algorithm is the batched one,
/// it is selected as soon as there are enough tokens to guarantee that experts will be active most of the time.
///
/// - input: .{ .s, .d } per-entry vector
/// - gating: .{ .s, .expert } per-entry expert-affinity
/// - experts: .{ .expert, .d_out, .d } expert layer (need to have a .forward method).
/// -> output: .{ .s, .d_out }
pub fn mixtureOfExperts(Expert: type, experts: Expert, input: zml.Tensor, gating: zml.Tensor, opts: MoeOpts) zml.Tensor {
log.warn("mixtureOfExperts({s}, {f}, {f}, {})", .{ @typeName(Expert), input, gating, opts });
const num_tokens: u32 = @intCast(input.dim(.s));
const num_experts = gating.dim(.expert);
stdx.debug.assert(opts.experts_per_token > 0, "mixtureOfExperts expects opts.experts_per_token > 0, got {}", .{opts});
if (num_tokens == 1) {
return moePerTokenRouting(Expert, experts, input, gating, opts);
}
const tokens_per_expert: u32 = if (opts.tokens_per_expert_ratio) |ratio| tpe: {
const compute_budget = ratio * @as(f32, @floatFromInt(num_tokens * opts.experts_per_token));
var tpe: u32 = @intFromFloat(stdx.math.divFloat(f32, compute_budget, num_experts));
// Round to next multiple of 8 to avoid weird shapes.
if (tpe % 8 != 0) tpe += 8 - (tpe % 8);
break :tpe tpe;
} else num_tokens;
if (3 * tokens_per_expert <= 2 * num_tokens) {
const routing, const tokens_ids_per_expert = dispatchTokens(gating, .{
.tokens_per_expert = tokens_per_expert,
.experts_per_token = opts.experts_per_token,
.normalization = opts.normalization,
});
const scores_per_expert = routing.transpose(.{ .expert, .s }).gather(.{ .s = tokens_ids_per_expert }, .{});
const input_per_expert = input.gather(.{ .s = tokens_ids_per_expert }, .{});
var output_per_expert = experts.forward(input_per_expert);
output_per_expert = output_per_expert.mul(scores_per_expert.convert(output_per_expert.dtype()).broad(output_per_expert.shape()));
// Reverse engineer the normal output shape that one expert would have produced for all tokens.
// If this falls short, we could use the "sliced_expert" strategy and call forward ourselves.
const output_shape = output_per_expert.shape().drop(.expert).rename(.{ .top_token = .s }).setDim(.s, num_tokens);
const output = zml.Tensor.scatterSlices(
.constant(output_shape, output_shape.dtype().zero()),
.{ .s = tokens_ids_per_expert },
output_per_expert,
.{ .update_fn = zml.Tensor.ScatterOpts.increment },
);
log.warn("mixtureOfExperts({s}, {f}, {f}) -> fixed budget impl tpe: {d}, tokens: {d}", .{ @typeName(Expert), input, gating, tokens_per_expert, num_tokens });
return output;
} else {
return mixtureOfExpertsAllToAll(Expert, experts, input, gating, opts);
}
}
/// Few tokens: most experts are unused and each expert gets at most one token.
/// Select the active experts and compute only with those.
pub fn moePerTokenRouting(Expert: type, experts: Expert, input: zml.Tensor, gating: zml.Tensor, opts: MoeOpts) zml.Tensor {
const num_tokens: u32 = @intCast(input.dim(.s));
stdx.debug.assert(num_tokens < 32, "Trying to unroll a lot of tokens!", .{});
const per_token_outputs = input.getContext().allocator().alloc(zml.Tensor, num_tokens) catch @panic("OOM");
const routing = gating.topK(.{ .top_expert = .expert }, opts.experts_per_token, .{});
const per_token_score = switch (opts.normalization) {
.linear => routing.values.div(routing.values.sum(.top_expert)),
.softmax => routing.values.softmax(.top_expert),
};
for (per_token_outputs, 0..num_tokens) |*output, tok_id| {
for (0..opts.experts_per_token) |expert_rank| {
const expert_id = routing.indices.choose(.{ .s = tok_id, .top_expert = expert_rank }).asScalar();
const expert_score = per_token_score.choose(.{ .s = tok_id, .top_expert = expert_rank }).asScalar();
var sliced_expert: Expert = undefined;
zml.meta.mapAlloc(struct {
pub fn cb(expert_id_: zml.Tensor, expert_weight: zml.Tensor) zml.Tensor {
return expert_weight.gather(.{ .expert = expert_id_ }, .{});
}
}.cb, stdx.noalloc, expert_id, experts, &sliced_expert) catch unreachable;
// TODO how does this work when the two experts are on different gpus?
// does the compute overlap ?
var expert_output = sliced_expert.forward(input.choose(.{ .s = tok_id }));
expert_output = .mul(
expert_output,
expert_score.convert(input.dtype()).broad(expert_output.shape()),
);
output.* = if (expert_rank > 0) output.add(expert_output) else expert_output;
}
}
log.warn("mixtureOfExperts({s}, {f}, {f}) -> single-stream impl", .{ @typeName(Expert), input, gating });
return .stack(per_token_outputs, 0, .s);
}
/// Send all tokens to all experts, and apply gating.
pub fn mixtureOfExpertsAllToAll(Expert: type, experts: Expert, input: zml.Tensor, gating: zml.Tensor, opts: MoeOpts) zml.Tensor {
log.warn("mixtureOfExperts({s}, {f}, {f}) -> all to all impl", .{ @typeName(Expert), input, gating });
const num_experts = gating.dim(.expert);
const hard_gating = hardGating(gating, opts).print();
// TODO: `input.insertAxes(0, .{.expert}).repeat1d(.expert, num_experts)` is too verbose for just broadcasting along a new axis.
const output_per_expert = experts.forward(input.insertAxes(0, .{.expert}).repeat1d(.expert, @intCast(num_experts)));
return output_per_expert.dot(hard_gating.convert(input.dtype()), .expert);
}
/// Given `(token, expert) -> scores`,
/// keeps only the top-k experts per token and normalizes the scores accordingly.
/// Non-selected experts get a score of 0.
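/// Worked example (hypothetical scores, softmax normalization, experts_per_token = 2):
/// a token with gating row [0.1, 0.5, 0.3, 0.1] keeps experts 1 and 2 with scores
/// softmax([0.5, 0.3]) ≈ [0.55, 0.45], producing the row [0, 0.55, 0.45, 0].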
pub fn hardGating(gating: zml.Tensor, opts: MoeOpts) zml.Tensor {
const routing = gating.topK(.{ .top_expert = .expert }, opts.experts_per_token, .{});
const per_token_score = switch (opts.normalization) {
.linear => routing.values.div(routing.values.sum(.top_expert)),
.softmax => routing.values.softmax(.top_expert),
};
return zml.Tensor.scatterSlices(
.zeroes(gating.shape()),
.{ .expert = routing.indices },
per_token_score,
.{ .indices_are_unique = true },
);
}
/// Many tokens: each expert chooses its tokens.
/// This means that some tokens may have only one expert assigned.
/// Each token gets assigned to at least one expert iff the input gating sums up to 1 (typically a softmax output).
/// Returns the actual `(token, expert) -> scores` used.
pub fn dispatchTokens(
gating: zml.Tensor,
opts: struct {
tokens_per_expert: u32,
experts_per_token: u32,
normalization: MoeOpts.Normalization,
},
) [2]zml.Tensor {
const num_experts = gating.dim(.expert);
const token_pref = gating.argsort(.expert, .{ .descending = true });
var expert_rank: zml.Tensor = .scatterSlices(
.zeroes(gating.shape().withDtype(.i32)),
.{ .expert = token_pref },
.addConstant(.iota(gating.shape(), .expert), 1),
.{ .indices_are_unique = true },
);
// The pow(expert_rank) here means that we strongly favor top 1 over top 2 and top 2 over top 3.
// expert_routing: (expert, top_token) -> token
const expert_routing = gating.pow(expert_rank.convert(gating.dtype())).topK(.{ .top_token = .s }, opts.tokens_per_expert, .{});
const scores_per_expert = gating.gather(.{ .s = expert_routing.indices }, .{});
// Update the gating coefficient to account for the expert routing.
// Each (token, expert) which can't be computed within the given budget is left to 0.
const gating_v2: zml.Tensor = .scatterSlices(
.zeroes(gating.shape()),
.{ .s = expert_routing.indices },
scores_per_expert,
.{ .indices_are_unique = true, .update_fn = zml.Tensor.ScatterOpts.override },
);
// Now set to zero the scores (token, expert) for tokens that have been assigned more than experts_per_token.
const lowest_experts = gating_v2.topK(.{ .top_expert = .expert }, @intCast(num_experts - opts.experts_per_token), .{ .descending = false });
var gating_v3: zml.Tensor = .scatterSlices(
gating_v2,
.{ .expert = lowest_experts.indices },
.zeroes(lowest_experts.values.shape()),
.{ .indices_are_unique = true, .update_fn = zml.Tensor.ScatterOpts.override },
);
// Then normalize so the sum of experts scores for one token sums up to 1.
gating_v3 = switch (opts.normalization) {
.linear => gating_v3.div(gating_v3.sum(.expert)),
.softmax => gating_v3.softmax(.expert),
};
const tokens_ids_per_expert = expert_routing.indices.transpose(.{ .expert, .top_token });
return .{ gating_v3, tokens_ids_per_expert };
}

376
examples/gpt_oss/main.zig Normal file

@@ -0,0 +1,376 @@
const std = @import("std");
const builtin = @import("builtin");
const async = @import("async");
const clap = @import("clap");
const stdx = @import("stdx");
const zml = @import("zml");
const Buffer = zml.Buffer;
const Tensor = zml.Tensor;
const ShapeOf = zml.ShapeOf;
const GptOss = @import("GptOss.zig");
const log = std.log.scoped(.GptOss);
pub const std_options: std.Options = .{
.log_level = .info,
.logFn = async.logFn(std.log.defaultLog),
};
const cli_params = clap.parseParamsComptime(
\\--help print this help
\\--prompt <STRING> the prompt
\\--hf-model-path <STRING> path to the directory containing model weights, config and tokenizer
\\--seed <UINT> random seed (optional)
\\--seq-len <UINT> max sequence length
\\--prompt-len <UINT> max prompt length
\\--temperature <FLOAT> temperature (default 1.0)
\\--topk <UINT> topk (default 10)
\\--expert-budget <FLOAT> token budget per expert
\\--platform-options <STRING> platform options, using Zon syntax, eg '.{.cuda=.{.allocator=.{.async=.{.memory_fraction=0.95}}}}'
\\--nochat <BOOL> skip prompt template
\\--sharding <BOOL> default: true: sharding on or off
);
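// Example command line (paths and values are illustrative):
//   gpt_oss --hf-model-path=/models/gpt-oss-20b --prompt="Tell me about llamas" --seq-len=4096 --topk=2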
pub fn tokenizePrompt(tokenizer: zml.tokenizer.Tokenizer, prompt: []const u8, no_chat: bool, out: []u32) ![]u32 {
var encoder = try tokenizer.encoder();
defer encoder.deinit();
if (no_chat) {
const tokens = try encoder.encode(prompt);
if (tokens.len > out.len) return error.PromptTooLong;
@memcpy(out[0..tokens.len], tokens);
return out[0..tokens.len];
}
const start_header = tokenizer.tokenToId("<|start|>") orelse return error.NoSuchToken;
const end_header_start_message = tokenizer.tokenToId("<|message|>") orelse return error.NoSuchToken;
const end_message = tokenizer.tokenToId("<|end|>") orelse return error.NoSuchToken;
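// The tokens built below follow the gpt-oss chat template, roughly:
//   <|start|>system<|message|>{system prompt}<|end|>
//   <|start|>user<|message|>{user prompt}<|end|>
//   <|start|>assistant<|channel|>analysis<|message|>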
var tokens: std.ArrayList(u32) = .initBuffer(out);
const system_prompt = try encoder.encode("You are ChatGPT, a large language model trained by OpenAI.\n");
if (system_prompt.len + 4 > tokens.unusedCapacitySlice().len) return error.PromptTooLong;
tokens.appendSliceAssumeCapacity(&.{ start_header, tokenizer.tokenToId("system").?, end_header_start_message });
tokens.appendSliceAssumeCapacity(system_prompt);
tokens.appendAssumeCapacity(end_message);
const user_prompt = try encoder.encode(prompt);
if (user_prompt.len + 9 > tokens.unusedCapacitySlice().len) return error.PromptTooLong;
tokens.appendSliceAssumeCapacity(&.{ start_header, tokenizer.tokenToId("user").?, end_header_start_message });
tokens.appendSliceAssumeCapacity(user_prompt);
tokens.appendSliceAssumeCapacity(&.{
end_message,
start_header,
tokenizer.tokenToId("assistant").?,
tokenizer.tokenToId("<|channel|>") orelse return error.NoSuchToken,
tokenizer.tokenToId("analysis") orelse return error.NoSuchToken,
end_header_start_message,
});
return tokens.items;
}
pub fn generateText(
config: GptOss.Config,
options: GptOss.Options,
mod_prefill: zml.ModuleExe(GptOss.forward),
mod_generate: zml.ModuleExe(GptOss.forward),
kv_cache_: zml.Bufferized(GptOss.KvCache),
tokenizer: zml.tokenizer.Tokenizer,
allocator: std.mem.Allocator,
seed: u128,
prompt_tok: []const u32,
output: *std.Io.Writer,
) !void {
var tokenizer_decoder = try tokenizer.decoder();
defer tokenizer_decoder.deinit();
const platform = mod_generate.platform();
// init RNG and buffers
var rng = try zml.Tensor.Rng.init(platform, seed);
var generated_token_buffer = [_]u32{undefined};
var current_token, var kv_cache = prefill: {
// prepare device buffers for the prefill tokens and their positions
const prefill_buffer = try allocator.alloc(u32, options.max_prompt_len);
@memcpy(prefill_buffer[0..prompt_tok.len], prompt_tok);
var prefill_tokens = try zml.Buffer.fromSlice(platform, .{options.max_prompt_len}, prefill_buffer);
defer prefill_tokens.deinit();
var prefill_token_pos = try zml.Buffer.scalar(platform, prompt_tok.len, .u32);
defer prefill_token_pos.deinit();
const first_token, const kv_cache, rng = mod_prefill.call(.{ prefill_tokens, .{ .prefill = prefill_token_pos }, kv_cache_, rng });
// extract the first generated token
_ = try first_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer));
log.warn("first_token: {d}", .{generated_token_buffer[0]});
break :prefill .{ first_token, kv_cache };
};
defer zml.aio.unloadBuffers(&kv_cache);
defer current_token.deinit();
const output_tokens_len = options.max_seq_len - prompt_tok.len - 1;
const start = std.time.microTimestamp();
// One token has already been generated by the prefill.
var num_tokens_generated: usize = 1;
generation: for (0..output_tokens_len + 1) |i| {
// collect and print generated sequence
num_tokens_generated += 1;
const generated_token = generated_token_buffer[0];
if (try tokenizer_decoder.next(generated_token)) |chunk| {
try output.writeAll(chunk);
}
// check for eos
if (i == output_tokens_len) break :generation;
switch (config.eos_token_id.value) {
.int => |eos| if (generated_token == @as(u32, @intCast(eos))) break :generation,
.ints => |eos_list| {
for (eos_list) |eos| {
if (generated_token == @as(u32, @intCast(eos))) break :generation;
}
},
}
// current token pos needs to go into a zml.Buffer
const token_pos_buffer = &[_]u32{@intCast(prompt_tok.len + i)};
const token_pos = try zml.Buffer.fromSlice(platform, .{}, token_pos_buffer);
defer token_pos.deinit();
// call to generate the next token
current_token, kv_cache, rng = mod_generate.call(.{ current_token, .{ .gen = token_pos }, kv_cache, rng });
// extract the generated token from the buffer
_ = try current_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer));
}
const end = std.time.microTimestamp();
const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s);
const speed = @as(f64, @floatFromInt(num_tokens_generated)) / duration;
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ num_tokens_generated, duration, speed });
}
pub fn main() !void {
try async.AsyncThread.main(std.heap.smp_allocator, asyncMain);
}
pub fn asyncMain() !void {
log.info(" GptOss was compiled with {}", .{@import("builtin").mode});
// Note: the DebugAllocator must outlive the allocator selection below,
// otherwise the returned interface would point at a dead local.
var dbg_alloc: std.heap.DebugAllocator(.{
.never_unmap = true,
.retain_metadata = true,
}) = .init;
const allocator: std.mem.Allocator = if (builtin.mode == .Debug) dbg_alloc.allocator() else std.heap.smp_allocator;
const cli = ClapBoilerplate.parseCli(allocator);
defer cli.deinit();
const hf_model_path = cli.args.@"hf-model-path" orelse {
log.err("Missing --hf-model-path", .{});
return;
};
const config = config: {
var arena: std.heap.ArenaAllocator = .init(allocator);
defer arena.deinit();
const model_config_path = try std.fs.path.join(arena.allocator(), &.{ hf_model_path, "config.json" });
var config_json_file = try async.File.open(model_config_path, .{ .mode = .read_only });
defer config_json_file.close() catch unreachable;
var config_reader = config_json_file.reader(try arena.allocator().alloc(u8, 256));
var reader = std.json.Reader.init(allocator, &config_reader.interface);
defer reader.deinit();
var config = try std.json.parseFromTokenSourceLeaky(GptOss.Config, arena.allocator(), &reader, .{ .ignore_unknown_fields = true });
// From generation_config.json
config.eos_token_id = .{ .value = .{ .ints = &.{ 200002, 199999, 200012 } } };
break :config config;
};
var context = try zml.Context.init();
defer context.deinit();
// initialize ZML platform
const platform: zml.Platform = platform: {
var arena: std.heap.ArenaAllocator = .init(allocator);
defer arena.deinit();
// eg: --platform-options='.{.cuda=.{.allocator=.{.bfc=.{.memory_fraction=0.99}}}}'
// eg: --platform-options='.{.cpu=.{.device_count=8}}'
const platform_opts = std.zon.parse.fromSlice(zml.Platform.CreateOptions, allocator, @ptrCast(cli.args.@"platform-options" orelse ".{}"), null, .{ .free_on_error = false }) catch |err| {
log.err("Failed to parse --platform-options as json ({}): {s}", .{ err, cli.args.@"platform-options".? });
return err;
};
const compilation_options = zml.CompilationOptions{
.xla_dump_to = "/tmp/zml/gpt_oss",
.sharding_enabled = cli.args.sharding orelse true,
};
const platform = context
.autoPlatform(platform_opts)
.withCompilationOptions(compilation_options);
context.printAvailablePlatforms(platform);
break :platform platform;
};
const options: GptOss.Options = .{
.max_seq_len = cli.args.@"seq-len" orelse 8192,
.max_prompt_len = cli.args.@"prompt-len" orelse 256,
.tokens_per_expert_ratio = cli.args.@"expert-budget" orelse 4.0,
.sampling_strategy = .{
.topk = cli.args.topk orelse 10,
.temperature = 1.0,
},
};
var compiler_arena = std.heap.ArenaAllocator.init(allocator);
defer compiler_arena.deinit();
const model_weights_path = try std.fs.path.join(allocator, &.{ hf_model_path, "model.safetensors.index.json" });
defer allocator.free(model_weights_path);
var store = try zml.aio.detectFormatAndOpen(allocator, model_weights_path);
defer store.deinit();
const model: GptOss = try GptOss.init(compiler_arena.allocator(), store, config, options);
const tokens_shape_prefill = zml.Shape.init(.{ .s = options.max_prompt_len }, .u32);
const tokens_shape = zml.Shape.init(.{ .s = 1 }, .u32);
const dtype = model.model.embed_tokens.weight.dtype();
const kv_shape = zml.Shape.init(.{
.layer = model.model.layers.len,
.k = options.max_seq_len,
.h = config.num_key_value_heads,
.hd = config.head_dim,
}, dtype).withSharding(.{.h});
const kv_cache_shape: zml.ShapeOf(GptOss.KvCache) = GptOss.KvCache.initShape(kv_shape);
const rng_shape = zml.Tensor.Rng.shape();
var start = try std.time.Timer.start();
var fut_mod_prefill = try async.async(zml.compileModel, .{
allocator, GptOss.forward, model,
.{
tokens_shape_prefill,
zml.ShapeOf(GptOss.Mode){ .prefill = .scalar(.u32) },
kv_cache_shape,
rng_shape,
},
platform,
});
var fut_mod = try async.async(zml.compileModel, .{
allocator, GptOss.forward, model,
.{
tokens_shape,
zml.ShapeOf(GptOss.Mode){ .gen = .scalar(.u32) },
kv_cache_shape,
rng_shape,
},
platform,
});
log.info("\tLoading GptOss weights from {s}...", .{model_weights_path});
var gpt_oss_weights = try model.loadBuffers(compiler_arena.allocator(), store, platform);
defer zml.aio.unloadBuffers(&gpt_oss_weights);
log.info("\tLoaded weights in {D}", .{start.read()});
var module_prefill = (try fut_mod_prefill.await()).prepare(gpt_oss_weights);
defer module_prefill.deinit();
var module_gen = (try fut_mod.await()).prepare(gpt_oss_weights);
defer module_gen.deinit();
log.info("\tCompiled model in {D}", .{start.read()});
log.info("Creating KvCache", .{});
const kv_cache = try GptOss.KvCache.initBuffer(kv_shape, platform);
var tokenizer = blk: {
const model_tokenizer_path = try std.fs.path.join(allocator, &.{ hf_model_path, "tokenizer.json" });
defer allocator.free(model_tokenizer_path);
log.info("Loading tokenizer from {s}", .{model_tokenizer_path});
var timer = try stdx.time.Timer.start();
defer log.info("Loaded tokenizer from {s} [{f}]", .{ model_tokenizer_path, timer.read() });
break :blk try zml.tokenizer.Tokenizer.fromFile(allocator, model_tokenizer_path);
};
errdefer tokenizer.deinit();
const prompt = cli.args.prompt orelse "What are some fun facts about animals?";
log.info("\tPrompt: {s}", .{prompt});
const no_chat = cli.args.nochat orelse false;
const prompt_tok_buf = try allocator.alloc(u32, options.max_prompt_len);
defer allocator.free(prompt_tok_buf);
const prompt_tok = tokenizePrompt(tokenizer, prompt, no_chat, prompt_tok_buf) catch |err| switch (err) {
error.PromptTooLong => std.debug.panic("Prompt too long, expected at most {d} tokens. Consider increasing --prompt-len", .{prompt_tok_buf.len}),
else => |e| return e,
};
log.info("\t Tokenized prompt: {any} ({d} tokens)", .{ prompt_tok, prompt_tok.len });
const seed = cli.args.seed orelse @as(u128, @bitCast(std.time.nanoTimestamp()));
// Unbuffered writing of the tokens to stdout.
// generated text will be printed token by token.
var output = std.fs.File.stdout().writer(&.{});
try generateText(config, options, module_prefill, module_gen, kv_cache, tokenizer, allocator, seed, prompt_tok, &output.interface);
}
const ClapBoilerplate = struct {
pub const Cli = clap.Result(clap.Help, &cli_params, parsers);
fn bool_parser(in: []const u8) error{}!bool {
return std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
}
const parsers = .{
.BOOL = bool_parser,
.UINT = clap.parsers.int(u32, 0),
.FLOAT = clap.parsers.float(f32),
.STRING = clap.parsers.string,
.PATH = clap.parsers.string,
};
pub fn parseCli(allocator: std.mem.Allocator) Cli {
var diag: clap.Diagnostic = .{};
var stderr_buffer: [1024]u8 = undefined;
var stderr = std.fs.File.stderr().writer(&stderr_buffer);
const cli = clap.parse(clap.Help, &cli_params, parsers, .{
.diagnostic = &diag,
.allocator = allocator,
}) catch |err| {
diag.report(&stderr.interface, err) catch {};
stderr.interface.print("usage: ", .{}) catch {};
clap.usage(&stderr.interface, clap.Help, &cli_params) catch {};
stderr.interface.print("\n", .{}) catch {};
stderr.interface.flush() catch {};
std.process.exit(1);
};
if (cli.args.help != 0) {
clap.help(&stderr.interface, clap.Help, &cli_params, .{}) catch {};
stderr.interface.flush() catch {};
std.process.exit(0);
}
return cli;
}
};