Add example implementation and Bazel build for OpenAI gpt-oss models (GptOss.zig, main.zig, and BUILD.bazel).
This commit is contained in:
parent
d45a667ee5
commit
e1b7fc5781
16
examples/gpt_oss/BUILD.bazel
Normal file
16
examples/gpt_oss/BUILD.bazel
Normal file
@ -0,0 +1,16 @@
|
||||
load("@rules_zig//zig:defs.bzl", "zig_binary", "zig_test")
|
||||
|
||||
zig_binary(
|
||||
name = "gpt_oss",
|
||||
srcs = [
|
||||
"GptOss.zig",
|
||||
],
|
||||
main = "main.zig",
|
||||
deps = [
|
||||
"@com_github_hejsil_clap//:clap",
|
||||
"@zml//async",
|
||||
"@zml//stdx",
|
||||
"@zml//zml",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
741
examples/gpt_oss/GptOss.zig
Normal file
741
examples/gpt_oss/GptOss.zig
Normal file
@ -0,0 +1,741 @@
|
||||
///! GptOss architecture, using huggingface transformers naming.
|
||||
///! Dimensions of activations: {.b, .s, .d}
|
||||
const std = @import("std");
|
||||
|
||||
const stdx = @import("stdx");
|
||||
const zml = @import("zml");
|
||||
|
||||
const GptOss = @This();
|
||||
const log = std.log.scoped(.GptOss);
|
||||
|
||||
pub const Config = struct {
|
||||
bos_token_id: u32 = 199998,
|
||||
eos_token_id: stdx.json.Union(union(enum) {
|
||||
int: u32,
|
||||
ints: []const u32,
|
||||
}),
|
||||
head_dim: u32,
|
||||
num_hidden_layers: u32,
|
||||
num_attention_heads: u32,
|
||||
num_key_value_heads: u32,
|
||||
experts_per_token: u32,
|
||||
rope_theta: f32,
|
||||
max_position_embeddings: u32,
|
||||
rms_norm_eps: f32,
|
||||
sliding_window: u32,
|
||||
hf_rope_impl: bool = true,
|
||||
rope_scaling: zml.nn.RopeOpts.Scaling = .{ .default = {} },
|
||||
};
|
||||
|
||||
pub const Options = struct {
|
||||
sampling_strategy: zml.nn.SamplingStrategy,
|
||||
max_seq_len: u32,
|
||||
max_prompt_len: u32,
|
||||
tokens_per_expert_ratio: f32,
|
||||
};
|
||||
|
||||
pub const Mode = union(enum) {
|
||||
/// In prefill mode we pass the actual len of the prompt
|
||||
prefill: zml.Tensor,
|
||||
/// In gen mode we pass the position of the next token
|
||||
gen: zml.Tensor,
|
||||
};
|
||||
|
||||
lm_head: ?zml.nn.Linear,
|
||||
model: Model,
|
||||
|
||||
config: Config,
|
||||
options: Options,
|
||||
|
||||
pub fn init(allocator: std.mem.Allocator, store: zml.aio.BufferStore, config: Config, options: Options) !GptOss {
|
||||
var self: GptOss = .{
|
||||
.config = config,
|
||||
.options = options,
|
||||
.model = .{
|
||||
.max_seq_len = @intCast(options.max_seq_len),
|
||||
.num_heads = @intCast(config.num_attention_heads),
|
||||
.num_kv_heads = @intCast(config.num_key_value_heads),
|
||||
.rope_opts = .{
|
||||
.layout = if (config.hf_rope_impl) .sequential else .interleaved,
|
||||
.freq_base = config.rope_theta,
|
||||
.scaling = config.rope_scaling,
|
||||
},
|
||||
|
||||
.embed_tokens = .{
|
||||
.weight = store.getTensor("model.embed_tokens.weight").withSharding(.{1}),
|
||||
},
|
||||
.layers = try allocator.alloc(TransformerLayer, config.num_hidden_layers),
|
||||
.norm = .{
|
||||
.weight = store.getTensor("model.norm.weight"),
|
||||
.eps = config.rms_norm_eps,
|
||||
},
|
||||
},
|
||||
.lm_head = .{ .weight = store.getTensor("lm_head.weight").withSharding(.{0}) },
|
||||
};
|
||||
|
||||
var prefix: zml.aio.PrefixBuilder = try .initCapacity(allocator, 1024);
|
||||
try prefix.push(stdx.noalloc, "model.layers");
|
||||
for (self.model.layers, 0..) |*layer, i| {
|
||||
try prefix.pushDigit(stdx.noalloc, i);
|
||||
defer prefix.pop();
|
||||
|
||||
var self_attn: SelfAttn = .{
|
||||
.sinks = store.getTensor(prefix.concat("self_attn.sinks")),
|
||||
.q_proj = try zml.aio.populateModelWithPrefix(zml.nn.Linear, allocator, store, prefix.concat("self_attn.q_proj")),
|
||||
.k_proj = try zml.aio.populateModelWithPrefix(zml.nn.Linear, allocator, store, prefix.concat("self_attn.k_proj")),
|
||||
.v_proj = try zml.aio.populateModelWithPrefix(zml.nn.Linear, allocator, store, prefix.concat("self_attn.v_proj")),
|
||||
.o_proj = try zml.aio.populateModelWithPrefix(zml.nn.Linear, allocator, store, prefix.concat("self_attn.o_proj")),
|
||||
|
||||
.sliding_window = if (i % 2 == 0) config.sliding_window else null,
|
||||
.num_heads = self.model.num_heads,
|
||||
.num_kv_heads = self.model.num_kv_heads,
|
||||
.rope_opts = self.model.rope_opts,
|
||||
};
|
||||
|
||||
self_attn.q_proj.weight = self_attn.q_proj.weight.withSharding(.{0});
|
||||
self_attn.k_proj.weight = self_attn.k_proj.weight.withSharding(.{0});
|
||||
self_attn.v_proj.weight = self_attn.v_proj.weight.withSharding(.{0});
|
||||
self_attn.o_proj.weight = self_attn.o_proj.weight.withSharding(.{1});
|
||||
|
||||
const on_disk_moe = try zml.aio.populateModelWithPrefix(MoE.OnDisk, allocator, store, prefix.concat("mlp"));
|
||||
var moe = on_disk_moe.rewrite(config.experts_per_token, options);
|
||||
{
|
||||
moe.experts.gate_up_proj.blocks = moe.experts.gate_up_proj.blocks.withSharding(.{.expert});
|
||||
moe.experts.down_proj.blocks = moe.experts.down_proj.blocks.withSharding(.{.expert});
|
||||
}
|
||||
|
||||
layer.* = .{
|
||||
.input_layernorm = .{
|
||||
.weight = store.getTensor(prefix.concat("input_layernorm.weight")),
|
||||
.eps = config.rms_norm_eps,
|
||||
},
|
||||
.post_attention_layernorm = .{
|
||||
.weight = store.getTensor(prefix.concat("post_attention_layernorm.weight")),
|
||||
.eps = config.rms_norm_eps,
|
||||
},
|
||||
.self_attn = self_attn,
|
||||
.mlp = moe,
|
||||
};
|
||||
}
|
||||
|
||||
// TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled.
|
||||
// It currently crashes/compilation fails
|
||||
if (self.options.sampling_strategy.topk == 1 and self.lm_head != null) {
|
||||
self.lm_head.?.weight = self.lm_head.?.weight.withSharding(.{0});
|
||||
}
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
/// Predicts the token at `token_index` position.
|
||||
/// Returns:
|
||||
/// - updated `tokens`,
|
||||
/// - updated KV cache
|
||||
/// - a Rng state to allow for probabilistic generation
|
||||
pub fn forward(
|
||||
self: GptOss,
|
||||
tokens_: zml.Tensor,
|
||||
mode: Mode,
|
||||
kv_cache: KvCache,
|
||||
rng: zml.Tensor.Rng,
|
||||
) struct { zml.Tensor, KvCache, zml.Tensor.Rng } {
|
||||
const tokens = tokens_.withPartialTags(.{.s});
|
||||
|
||||
// token index is the position in the kv cache where to write results.
|
||||
const token_index: zml.Tensor = switch (mode) {
|
||||
.gen => |token_index| token_index,
|
||||
.prefill => .scalar(0, .u32),
|
||||
};
|
||||
var out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, token_index, kv_cache });
|
||||
|
||||
switch (mode) {
|
||||
// In prefill we only pass the last token to the lm head.
|
||||
.prefill => |prompt_len| out = out.gather(.{ .s = prompt_len.convert(.i32).addConstant(-1) }, .{ .indices_are_sorted = true }),
|
||||
.gen => {},
|
||||
}
|
||||
|
||||
var new_token, const new_rng = self.sampleTokens(self.lm_head, out, rng, self.options.sampling_strategy);
|
||||
new_token = new_token.convert(.u32);
|
||||
new_token = switch (mode) {
|
||||
.gen => new_token.reuseBuffer(tokens),
|
||||
.prefill => new_token.appendAxes(.{.s}),
|
||||
};
|
||||
return .{ new_token, updated_kv_cache, new_rng };
|
||||
}
|
||||
|
||||
fn sampleTokens(
|
||||
self: GptOss,
|
||||
lm_head_: ?zml.nn.Linear,
|
||||
out_: zml.Tensor,
|
||||
rng: zml.Tensor.Rng,
|
||||
opts: zml.nn.SamplingStrategy,
|
||||
) struct { zml.Tensor, zml.Tensor.Rng } {
|
||||
const out = out_.withPartialTags(.{.d});
|
||||
|
||||
var logits = blk: {
|
||||
if (lm_head_) |lm_head| {
|
||||
break :blk zml.call(lm_head, .forward, .{out});
|
||||
} else {
|
||||
break :blk self.model.embed_tokens.weight.withTags(.{ .voc, .d }).dot(out, .{.d});
|
||||
}
|
||||
};
|
||||
|
||||
if (logits.shape().hasTag(.voc) == null)
|
||||
logits = logits.rename(.{ .d = .voc });
|
||||
|
||||
const next_tokens, const new_rng = zml.nn.sampleTokens(logits, opts, rng);
|
||||
return .{ next_tokens, new_rng };
|
||||
}
|
||||
|
||||
pub fn loadBuffers(self: GptOss, allocator: std.mem.Allocator, store: zml.aio.BufferStore, platform: zml.Platform) !zml.Bufferized(GptOss) {
|
||||
var prefix: zml.aio.PrefixBuilder = try .initCapacity(allocator, 256);
|
||||
defer prefix.deinit(allocator);
|
||||
|
||||
const noalloc = stdx.noalloc;
|
||||
const loaded: zml.Bufferized(GptOss) = .{
|
||||
.model = .{
|
||||
.embed_tokens = try store.loadModelById(zml.nn.TokenEmbedding, noalloc, self.model.embed_tokens, platform),
|
||||
.layers = try allocator.alloc(zml.Bufferized(TransformerLayer), self.model.layers.len),
|
||||
.norm = try store.loadModelById(RmsNorm, noalloc, self.model.norm, platform),
|
||||
},
|
||||
.lm_head = try store.loadModelById(?zml.nn.Linear, noalloc, self.lm_head, platform),
|
||||
};
|
||||
|
||||
prefix.push(noalloc, "model.layers") catch unreachable;
|
||||
for (loaded.model.layers, self.model.layers, 0..) |*d_layer, layer, layer_id| {
|
||||
const ckpt = prefix.checkpoint();
|
||||
defer prefix.restore(ckpt);
|
||||
|
||||
prefix.pushDigit(noalloc, layer_id) catch unreachable;
|
||||
d_layer.* = .{
|
||||
.input_layernorm = try store.loadModelById(RmsNorm, noalloc, layer.input_layernorm, platform),
|
||||
.self_attn = try store.loadModelById(SelfAttn, noalloc, layer.self_attn, platform),
|
||||
.post_attention_layernorm = try store.loadModelById(RmsNorm, noalloc, layer.post_attention_layernorm, platform),
|
||||
.mlp = try store.loadModelById(MoE, noalloc, layer.mlp, platform),
|
||||
};
|
||||
}
|
||||
|
||||
return loaded;
|
||||
}
|
||||
|
||||
pub const Model = struct {
|
||||
embed_tokens: zml.nn.TokenEmbedding,
|
||||
norm: RmsNorm,
|
||||
layers: []TransformerLayer,
|
||||
|
||||
max_seq_len: u32 = 0,
|
||||
num_heads: i64 = 32,
|
||||
num_kv_heads: i64 = 32,
|
||||
rope_opts: zml.nn.RopeOpts = .{
|
||||
.layout = .interleaved,
|
||||
.freq_base = 10_000,
|
||||
},
|
||||
|
||||
/// Forward one token, using KV cache for previous tokens.
|
||||
/// Returns result and updated KV cache.
|
||||
pub fn forward(self: Model, tokens: zml.Tensor, token_index: zml.Tensor, kv_cache: KvCache) struct { zml.Tensor, KvCache } {
|
||||
const embeds = embed(self.embed_tokens, tokens);
|
||||
var hidden = embeds;
|
||||
|
||||
var updated_kv_cache = kv_cache;
|
||||
for (self.layers, 0..) |layer, i| {
|
||||
hidden, updated_kv_cache = zml.call(layer, .forward, .{ hidden, token_index, updated_kv_cache.atLayer(i) });
|
||||
}
|
||||
const output = zml.call(self.norm, .forward, .{hidden});
|
||||
|
||||
return .{ output, updated_kv_cache.reuseBuffer(kv_cache) };
|
||||
}
|
||||
|
||||
pub fn embed(embed_tokens_: zml.nn.TokenEmbedding, tokens_: zml.Tensor) zml.Tensor {
|
||||
return zml.call(embed_tokens_, .forward, .{tokens_}).withPartialTags(.{.d});
|
||||
}
|
||||
};
|
||||
|
||||
pub const TransformerLayer = struct {
|
||||
input_layernorm: RmsNorm,
|
||||
self_attn: SelfAttn,
|
||||
post_attention_layernorm: RmsNorm,
|
||||
mlp: MoE,
|
||||
|
||||
pub fn forward(
|
||||
self: TransformerLayer,
|
||||
x0: zml.Tensor,
|
||||
token_index: zml.Tensor,
|
||||
kv_cache: KvCache,
|
||||
) struct { zml.Tensor, KvCache } {
|
||||
// Self Attention
|
||||
//log.debug("TransformerLayer({}) -> {}", .{ x0, self.input_layernorm.forward(x0) });
|
||||
stdx.debug.assert(x0.rank() >= 2 and x0.shape().hasTags(.{ .s, .d }), "TransformerLayer expected input shape: {{..., .s, .d}}, received: {f}", .{x0});
|
||||
|
||||
const x0_normalized = zml.call(self.input_layernorm, .forward, .{x0});
|
||||
const delta0, const updated_kv_cache = zml.call(self.self_attn, .forward, .{ x0_normalized, token_index, kv_cache });
|
||||
const x1 = x0.add(delta0);
|
||||
|
||||
// Fully Connected
|
||||
const x1_normalized = zml.call(self.post_attention_layernorm, .forward, .{x1});
|
||||
const x2 = zml.call(self.mlp, .forward, .{x1_normalized}).add(x1);
|
||||
|
||||
return .{ x2.reuseBuffer(x0), updated_kv_cache };
|
||||
}
|
||||
};
|
||||
|
||||
const RmsNorm = struct {
|
||||
weight: zml.Tensor,
|
||||
eps: f32 = 1e-6,
|
||||
|
||||
/// L2 normalization of input tensor along `.d` axis.
|
||||
pub fn forward(self: RmsNorm, input: zml.Tensor) zml.Tensor {
|
||||
const x = if (input.shape().isFullyTagged()) input else input.withPartialTags(.{.d});
|
||||
// Note: contrary to Llama here the full layer is done in .f32, not just the variance computation.
|
||||
const normalized = zml.nn.rmsNorm(x.convert(.f32), .d, self.eps);
|
||||
return normalized.mul(self.weight.convert(.f32).withTags(.{.d}).broad(x.shape())).convert(input.dtype());
|
||||
}
|
||||
};
|
||||
|
||||
const MoE = struct {
|
||||
experts: Mlp,
|
||||
router: zml.nn.Linear,
|
||||
moe_opts: MoeOpts,
|
||||
|
||||
pub fn forward(self: MoE, input: zml.Tensor) zml.Tensor {
|
||||
log.warn("compiling moe with {f}", .{input});
|
||||
// Note: GptOss applies softmax on the routing score.
|
||||
// We delay the softmax to mixtureOfExperts where the actual routing is done.
|
||||
// This allow to do re-routing without introducing nans.
|
||||
const gating = self.router.forward(input);
|
||||
return mixtureOfExperts(Mlp, self.experts, input, gating, self.moe_opts);
|
||||
}
|
||||
|
||||
pub const OnDisk = struct {
|
||||
router: zml.nn.Linear,
|
||||
experts: struct {
|
||||
down_proj_bias: zml.Tensor,
|
||||
down_proj_blocks: zml.Tensor,
|
||||
down_proj_scales: zml.Tensor,
|
||||
gate_up_proj_bias: zml.Tensor,
|
||||
gate_up_proj_blocks: zml.Tensor,
|
||||
gate_up_proj_scales: zml.Tensor,
|
||||
},
|
||||
|
||||
pub fn rewrite(on_disk: OnDisk, experts_per_token: u32, options: Options) MoE {
|
||||
const e = on_disk.experts;
|
||||
return .{
|
||||
.experts = .{
|
||||
.gate_up_proj = .{
|
||||
// We need to bitcast the scale cause safetensors doesn't encode f8 types correctly
|
||||
.scale = e.gate_up_proj_scales.withTags(.{ .expert, .out, .d }),
|
||||
// We don't bitcast here because PJRT doesn't handle packed host buffers
|
||||
.blocks = e.gate_up_proj_blocks.withTags(.{ .expert, .out, .d, .d_block }),
|
||||
.blocks_dtype = .f4e2m1,
|
||||
.bias = e.gate_up_proj_bias.withTags(.{ .expert, .d }),
|
||||
},
|
||||
.down_proj = .{
|
||||
.blocks = e.down_proj_blocks.withTags(.{ .expert, .out, .d, .d_block }),
|
||||
.blocks_dtype = .f4e2m1,
|
||||
.scale = e.down_proj_scales.withTags(.{ .expert, .out, .d }),
|
||||
.bias = e.down_proj_bias.withTags(.{ .expert, .d }),
|
||||
},
|
||||
},
|
||||
|
||||
.router = .{
|
||||
.weight = on_disk.router.weight.withTags(.{ .expert, .d }),
|
||||
.bias = on_disk.router.bias.?.withTags(.{.expert}),
|
||||
},
|
||||
|
||||
.moe_opts = .{
|
||||
.experts_per_token = experts_per_token,
|
||||
.tokens_per_expert_ratio = options.tokens_per_expert_ratio,
|
||||
.normalization = .softmax,
|
||||
},
|
||||
};
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
pub const Mlp = struct {
|
||||
gate_up_proj: BlockScaledLinear, // {.out = intermediate_size * 2, .d = hidden_size / block_size, .d_block = block_size }
|
||||
down_proj: BlockScaledLinear, // {.out = hidden_size * 2, .d = intermediate_size / block_size, .d_block = block_size }
|
||||
|
||||
pub fn forward(self: Mlp, x: zml.Tensor) zml.Tensor {
|
||||
const dt = x.dtype();
|
||||
var gate, var up = zml.nn.splitRealImg(self.gate_up_proj.forward(x), .interleaved);
|
||||
gate = .minimum(gate, .scalar(7, dt));
|
||||
up = .clamp(up, .scalar(-7, dt), .scalar(7, dt));
|
||||
|
||||
const out = gate.quickGelu().mul(up.addConstant(1));
|
||||
return zml.call(self.down_proj, .forward, .{out});
|
||||
}
|
||||
|
||||
pub fn format(self: Mlp, writer: *std.Io.Writer) std.Io.Writer.Error!void {
|
||||
try writer.print("Mlp(gate_up_proj=.{f}, down_proj=.{f})", .{ self.gate_up_proj, self.down_proj });
|
||||
}
|
||||
};
|
||||
|
||||
pub const SelfAttn = struct {
|
||||
q_proj: zml.nn.Linear,
|
||||
k_proj: zml.nn.Linear,
|
||||
v_proj: zml.nn.Linear,
|
||||
sinks: zml.Tensor,
|
||||
|
||||
o_proj: zml.nn.Linear,
|
||||
|
||||
sliding_window: ?u32,
|
||||
num_heads: i64,
|
||||
num_kv_heads: i64,
|
||||
rope_opts: zml.nn.RopeOpts,
|
||||
|
||||
/// Self Attention.
|
||||
/// - If token_index is set, x is assumed to be the representation of one new token,
|
||||
/// and kv_cache will be read for the previous tokens.
|
||||
/// - If token_index is not set, x is assumed to be the representation of all tokens
|
||||
/// since the beginning of the sequence, and kv_cache won't be read.
|
||||
/// In both case, kv_cache will be updated with the computed key and value.
|
||||
/// x: {.b, .s, .d } -> .{.b, .s, .d}
|
||||
pub fn forward(
|
||||
self: SelfAttn,
|
||||
x: zml.Tensor,
|
||||
token_index: zml.Tensor,
|
||||
kv_cache: KvCache,
|
||||
) struct { zml.Tensor, KvCache } {
|
||||
const num_kv_heads = self.num_kv_heads;
|
||||
var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto }).withSharding(.{.h});
|
||||
var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
|
||||
var v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
|
||||
|
||||
// Generate the attention mask.
|
||||
const seq_len = kv_cache.k.dim(.k);
|
||||
var attn_mask = zml.nn.causalAttnMask(.{ .q = seq_len, .k = seq_len }, x.dtype(), self.sliding_window);
|
||||
|
||||
// Note: in Pytorch it would be very inefficient to generate the full attn_mask,
|
||||
// then slice into it, but XLA is able to optimize this correctly.
|
||||
attn_mask = attn_mask.gatherSlices(zml.Shape.init(.{ .q = x.dim(.s) }, attn_mask.dtype()), token_index.reshape(.{ .coord = 1 }), .{});
|
||||
|
||||
// In self-attention, .s axis is used both for keys and queries.
|
||||
const pos_index = b: {
|
||||
const temp = zml.Tensor.arange(.{ .end = x.dim(.s) }, token_index.dtype()).withTags(.{.s}).broad(zml.Shape.init(.{ .s = x.dim(.s) }, token_index.dtype()));
|
||||
break :b temp.add(token_index.broad(temp.shape()));
|
||||
};
|
||||
|
||||
q = zml.nn.rope(q, pos_index, self.rope_opts);
|
||||
k = zml.nn.rope(k, pos_index, self.rope_opts);
|
||||
q = q.rename(.{ .s = .q });
|
||||
k = k.rename(.{ .s = .k });
|
||||
v = v.rename(.{ .s = .k });
|
||||
|
||||
const dtype = q.dtype();
|
||||
const new_kv_cache = kv_cache.update(k, v, token_index);
|
||||
k = new_kv_cache.keys().convert(dtype);
|
||||
v = new_kv_cache.values().convert(dtype);
|
||||
|
||||
// TODO ringbuffer kv cache.
|
||||
|
||||
const softmax_bias = self.sinks.withTags(.{.h});
|
||||
const attn_output = zml.nn.sdpa(q, k, v, .{ .attn_mask = attn_mask, .softmax_bias = softmax_bias });
|
||||
const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s });
|
||||
return .{ zml.call(self.o_proj, .forward, .{attn}), new_kv_cache };
|
||||
}
|
||||
};
|
||||
|
||||
pub const KvCache = struct {
|
||||
k: zml.Tensor,
|
||||
v: zml.Tensor,
|
||||
layer_index: zml.Tensor,
|
||||
|
||||
pub fn init(kv_shape: zml.Shape) KvCache {
|
||||
// The KV-cache is initialized with ones to detect reads of uninitialized memory.
|
||||
return .{
|
||||
.k = .constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
|
||||
.v = .constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
|
||||
.layer_index = .scalar(-1, .u32),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn initShape(kv_shape: zml.Shape) zml.ShapeOf(KvCache) {
|
||||
return .{
|
||||
.k = kv_shape,
|
||||
.v = kv_shape,
|
||||
.layer_index = zml.Shape.init(.{}, .u32),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn initBuffer(kv_shape: zml.Shape, platform: zml.Platform) !zml.Bufferized(KvCache) {
|
||||
return .{
|
||||
.k = try zml.Buffer.uninitialized(platform, kv_shape, .{}),
|
||||
.v = try zml.Buffer.uninitialized(platform, kv_shape, .{}),
|
||||
.layer_index = try zml.Buffer.uninitialized(platform, .scalar(.u32), .{}),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn keys(self: KvCache) zml.Tensor {
|
||||
return self.k.dynamicSlice(.{ .layer = zml.Tensor.DynSlice{ .start = self.layer_index, .len = 1 } }).squeeze(.layer);
|
||||
}
|
||||
|
||||
pub fn values(self: KvCache) zml.Tensor {
|
||||
return self.v.dynamicSlice(.{ .layer = zml.Tensor.DynSlice{ .start = self.layer_index, .len = 1 } }).squeeze(.layer);
|
||||
}
|
||||
|
||||
pub fn update(self: KvCache, new_k: zml.Tensor, new_v: zml.Tensor, token_index: ?zml.Tensor) KvCache {
|
||||
const idx = if (token_index) |idx| idx else zml.Tensor.arange(.{ .end = new_k.dim(.k) }, .u32);
|
||||
|
||||
return .{
|
||||
.k = self.k.scatterSlices(
|
||||
.{ .k = idx, .layer = self.layer_index },
|
||||
new_k.convert(self.k.dtype()),
|
||||
.{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
|
||||
).reuseBuffer(self.k),
|
||||
.v = self.v.scatterSlices(
|
||||
.{ .k = idx, .layer = self.layer_index },
|
||||
new_v.convert(self.v.dtype()),
|
||||
.{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override },
|
||||
).reuseBuffer(self.v),
|
||||
.layer_index = self.layer_index,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn atLayer(self: KvCache, layer_index: usize) KvCache {
|
||||
return .{
|
||||
.k = self.k,
|
||||
.v = self.v,
|
||||
.layer_index = .scalar(layer_index, .u32),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn reuseBuffer(self: KvCache, other: KvCache) KvCache {
|
||||
return .{
|
||||
.k = self.k.reuseBuffer(other.k),
|
||||
.v = self.v.reuseBuffer(other.v),
|
||||
.layer_index = self.layer_index.reuseBuffer(other.layer_index),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
pub const BlockScaledLinear = struct {
|
||||
blocks: zml.Tensor,
|
||||
scale: zml.Tensor,
|
||||
bias: ?zml.Tensor = null,
|
||||
blocks_dtype: zml.DataType,
|
||||
|
||||
pub fn forward(self: BlockScaledLinear, x: zml.Tensor) zml.Tensor {
|
||||
const ctx = x.getContext();
|
||||
const res_shape = x.shape().setDim(-1, self.blocks.dim(-3));
|
||||
|
||||
// Bitcast to our actual type. This allows to load weights in a packed layout.
|
||||
const blocks_0 = self.blocks.bitCast(self.blocks_dtype);
|
||||
const blocks = blocks_0.merge(.{ .d_block = .{ .d_block, .bitcast } });
|
||||
|
||||
const scale = self.scale.bitCast(.f8e8m0);
|
||||
|
||||
// log.warn("BlockScaledLinear({}): {f} -> {f}", .{ self, x, res_shape });
|
||||
const y = switch (ctx._platform.target) {
|
||||
else => y: {
|
||||
var dequantized_weight: zml.Tensor = .mul(
|
||||
blocks.convert(x.dtype()),
|
||||
scale.convert(x.dtype()).appendAxes(.{.d_block}),
|
||||
);
|
||||
var y = x.dot(dequantized_weight.merge(.{ .d = .{ .d, .d_block } }), .{.d});
|
||||
// std.log.warn("output shape: {f}", .{y});
|
||||
std.debug.assert(y.shape().eql(res_shape));
|
||||
y._shape = res_shape;
|
||||
break :y y;
|
||||
},
|
||||
};
|
||||
return if (self.bias) |bias| y.add(bias.broad(y.shape())) else y;
|
||||
}
|
||||
|
||||
pub fn format(self: BlockScaledLinear, writer: *std.Io.Writer) !void {
|
||||
try writer.print("BlockScaledLinear(blocks={f}, scale={f}, bias={?f}, dt={t})", .{ self.blocks, self.scale, self.bias, self.blocks_dtype });
|
||||
}
|
||||
};
|
||||
|
||||
const MoeOpts = struct {
|
||||
experts_per_token: u32,
|
||||
tokens_per_expert_ratio: ?f32 = 0.0,
|
||||
normalization: Normalization,
|
||||
|
||||
pub const Normalization = enum { linear, softmax };
|
||||
};
|
||||
|
||||
/// We have three algorithms,
|
||||
/// * one for single-stream inference (naive),
|
||||
/// * one for small batch sized with exact precision that sends all tokens to all experts.
|
||||
/// this isn't too costly as long as the batch size is small and the experts are IO bound.
|
||||
/// * one for big batch size that assign a fixed compute budget per expert and
|
||||
/// experts chose the tokens they want to handle. This introduces noise since it's possible
|
||||
/// a token doesn't get their requested expert.
|
||||
/// The parameter `tokens_per_expert_ratio` control how much compute budget is granted:
|
||||
/// expert_budget = ratio * (num_tokens * experts_per_token / num_experts).
|
||||
/// Bigger values of ratio will ensure it's rare a token doesn't get it's top 2 tokens.
|
||||
///
|
||||
/// The preferred algorithm is the batched one,
|
||||
/// it is selected as soon there is enough tokens to guarantee that experts will be active most of the time.
|
||||
///
|
||||
/// - input: .{ .s, .d } per-entry vector
|
||||
/// - gating: .{ .s, .expert } per-entry expert-affinity
|
||||
/// - experts: .{ .expert, .d_out, .d } expert layer (need to have a .forward method).
|
||||
/// -> output: .{ .s, .d_out }
|
||||
pub fn mixtureOfExperts(Expert: type, experts: Expert, input: zml.Tensor, gating: zml.Tensor, opts: MoeOpts) zml.Tensor {
|
||||
log.warn("mixtureOfExperts({s}, {f}, {f}, {})", .{ @typeName(Expert), input, gating, opts });
|
||||
const num_tokens: u32 = @intCast(input.dim(.s));
|
||||
const num_experts = gating.dim(.expert);
|
||||
stdx.debug.assert(opts.experts_per_token > 0, "mixtureOfExperts expects opts.experts_per_token > 0, got {}", .{opts});
|
||||
|
||||
if (num_tokens == 1) {
|
||||
return moePerTokenRouting(Expert, experts, input, gating, opts);
|
||||
}
|
||||
|
||||
const tokens_per_expert: u32 = if (opts.tokens_per_expert_ratio) |ratio| tpe: {
|
||||
const compute_budget = ratio * @as(f32, @floatFromInt(num_tokens * opts.experts_per_token));
|
||||
var tpe: u32 = @intFromFloat(stdx.math.divFloat(f32, compute_budget, num_experts));
|
||||
// Round to next multiple of 8 to avoid weird shapes.
|
||||
if (tpe % 8 != 0) tpe += 8 - (tpe % 8);
|
||||
break :tpe tpe;
|
||||
} else num_tokens;
|
||||
|
||||
if (3 * tokens_per_expert <= 2 * num_tokens) {
|
||||
const routing, const tokens_ids_per_expert = dispatchTokens(gating, .{
|
||||
.tokens_per_expert = tokens_per_expert,
|
||||
.experts_per_token = opts.experts_per_token,
|
||||
.normalization = opts.normalization,
|
||||
});
|
||||
const scores_per_expert = routing.transpose(.{ .expert, .s }).gather(.{ .s = tokens_ids_per_expert }, .{});
|
||||
const input_per_expert = input.gather(.{ .s = tokens_ids_per_expert }, .{});
|
||||
var output_per_expert = experts.forward(input_per_expert);
|
||||
output_per_expert = output_per_expert.mul(scores_per_expert.convert(output_per_expert.dtype()).broad(output_per_expert.shape()));
|
||||
|
||||
// Reverse engineer the normal output shape that one expert would have produced for all tokens.
|
||||
// If this fall short, we could use the "sliced_expert" strategy and call forward ourselves.
|
||||
const output_shape = output_per_expert.shape().drop(.expert).rename(.{ .top_token = .s }).setDim(.s, num_tokens);
|
||||
const output = zml.Tensor.scatterSlices(
|
||||
.constant(output_shape, output_shape.dtype().zero()),
|
||||
.{ .s = tokens_ids_per_expert },
|
||||
output_per_expert,
|
||||
.{ .update_fn = zml.Tensor.ScatterOpts.increment },
|
||||
);
|
||||
|
||||
log.warn("mixtureOfExperts({s}, {f}, {f}) -> fixed budget impl tpe: {d}, tokens: {d}", .{ @typeName(Expert), input, gating, tokens_per_expert, num_tokens });
|
||||
return output;
|
||||
} else {
|
||||
return mixtureOfExpertsAllToAll(Expert, experts, input, gating, opts);
|
||||
}
|
||||
}
|
||||
|
||||
/// Few tokens: most experts are unused, experts have at most one token.
|
||||
/// Select active experts and compute with that.
|
||||
pub fn moePerTokenRouting(Expert: type, experts: Expert, input: zml.Tensor, gating: zml.Tensor, opts: MoeOpts) zml.Tensor {
|
||||
const num_tokens: u32 = @intCast(input.dim(.s));
|
||||
stdx.debug.assert(num_tokens < 32, "Trying to unroll a lot of tokens !", .{});
|
||||
const per_token_outputs = input.getContext().allocator().alloc(zml.Tensor, num_tokens) catch @panic("OOM");
|
||||
|
||||
const routing = gating.topK(.{ .top_expert = .expert }, opts.experts_per_token, .{});
|
||||
const per_token_score = switch (opts.normalization) {
|
||||
.linear => routing.values.div(routing.values.sum(.top_expert)),
|
||||
.softmax => routing.values.softmax(.top_expert),
|
||||
};
|
||||
|
||||
for (per_token_outputs, 0..num_tokens) |*output, tok_id| {
|
||||
for (0..opts.experts_per_token) |expert_rank| {
|
||||
const expert_id = routing.indices.choose(.{ .s = tok_id, .top_expert = expert_rank }).asScalar();
|
||||
const expert_score = per_token_score.choose(.{ .s = tok_id, .top_expert = expert_rank }).asScalar();
|
||||
|
||||
var sliced_expert: Expert = undefined;
|
||||
zml.meta.mapAlloc(struct {
|
||||
pub fn cb(expert_id_: zml.Tensor, expert_weight: zml.Tensor) zml.Tensor {
|
||||
return expert_weight.gather(.{ .expert = expert_id_ }, .{});
|
||||
}
|
||||
}.cb, stdx.noalloc, expert_id, experts, &sliced_expert) catch unreachable;
|
||||
|
||||
// TODO how does this work when the two experts are on different gpus?
|
||||
// does the compute overlap ?
|
||||
var expert_output = sliced_expert.forward(input.choose(.{ .s = tok_id }));
|
||||
expert_output = .mul(
|
||||
expert_output,
|
||||
expert_score.convert(input.dtype()).broad(expert_output.shape()),
|
||||
);
|
||||
output.* = if (expert_rank > 0) output.add(expert_output) else expert_output;
|
||||
}
|
||||
}
|
||||
|
||||
log.warn("mixtureOfExperts({s}, {f}, {f}) -> single-stream impl", .{ @typeName(Expert), input, gating });
|
||||
return .stack(per_token_outputs, 0, .s);
|
||||
}
|
||||
|
||||
/// Send all tokens to all experts, and apply gating.
|
||||
pub fn mixtureOfExpertsAllToAll(Expert: type, experts: Expert, input: zml.Tensor, gating: zml.Tensor, opts: MoeOpts) zml.Tensor {
|
||||
log.warn("mixtureOfExperts({s}, {f}, {f}) -> all to all impl", .{ @typeName(Expert), input, gating });
|
||||
const num_experts = gating.dim(.expert);
|
||||
const hard_gating = hardGating(gating, opts).print();
|
||||
// TODO: `input.insertAxes(0, .{.expert}).repeat1d(.expert, num_experts)` is too verbose for just broadcasting along a new axis`
|
||||
const output_per_expert = experts.forward(input.insertAxes(0, .{.expert}).repeat1d(.expert, @intCast(num_experts)));
|
||||
return output_per_expert.dot(hard_gating.convert(input.dtype()), .expert);
|
||||
}
|
||||
|
||||
/// Given `(token, expert) -> scores`,
|
||||
/// keeps only the top-k expert per token, and normalize the scores accordingly.
|
||||
/// Non selected experts will have a 0 score.
|
||||
pub fn hardGating(gating: zml.Tensor, opts: MoeOpts) zml.Tensor {
|
||||
const routing = gating.topK(.{ .top_expert = .expert }, opts.experts_per_token, .{});
|
||||
|
||||
const per_token_score = switch (opts.normalization) {
|
||||
.linear => routing.values.div(routing.values.sum(.top_expert)),
|
||||
.softmax => routing.values.softmax(.top_expert),
|
||||
};
|
||||
|
||||
return zml.Tensor.scatterSlices(
|
||||
.zeroes(gating.shape()),
|
||||
.{ .expert = routing.indices },
|
||||
per_token_score,
|
||||
.{ .indices_are_unique = true },
|
||||
);
|
||||
}
|
||||
|
||||
/// Lot of tokens, each experts chose their tokens.
|
||||
/// It means that some tokens may have only one expert assigned.
|
||||
/// Each token will get assigned to at least one expert IIF the input gating is sums up to 1 (typically softmax output).
|
||||
/// Returns the actual `(token, expert) -> scores` used.
|
||||
pub fn dispatchTokens(
|
||||
gating: zml.Tensor,
|
||||
opts: struct {
|
||||
tokens_per_expert: u32,
|
||||
experts_per_token: u32,
|
||||
normalization: MoeOpts.Normalization,
|
||||
},
|
||||
) [2]zml.Tensor {
|
||||
const num_experts = gating.dim(.expert);
|
||||
|
||||
const token_pref = gating.argsort(.expert, .{ .descending = true });
|
||||
var expert_rank: zml.Tensor = .scatterSlices(
|
||||
.zeroes(gating.shape().withDtype(.i32)),
|
||||
.{ .expert = token_pref },
|
||||
.addConstant(.iota(gating.shape(), .expert), 1),
|
||||
.{ .indices_are_unique = true },
|
||||
);
|
||||
// The pow(expert_rank) here means that we strongly favor top 1 over top 2 and top 2 over top 3.
|
||||
// expert_routing: (expert, top_token) -> token
|
||||
const expert_routing = gating.pow(expert_rank.convert(gating.dtype())).topK(.{ .top_token = .s }, opts.tokens_per_expert, .{});
|
||||
const scores_per_expert = gating.gather(.{ .s = expert_routing.indices }, .{});
|
||||
|
||||
// Update the gating coefficient to account for the expert routing.
|
||||
// Each (token, expert) which can't be computed within the given budget is left to 0.
|
||||
const gating_v2: zml.Tensor = .scatterSlices(
|
||||
.zeroes(gating.shape()),
|
||||
.{ .s = expert_routing.indices },
|
||||
scores_per_expert,
|
||||
.{ .indices_are_unique = true, .update_fn = zml.Tensor.ScatterOpts.override },
|
||||
);
|
||||
// Now set to zero the scores (token, expert) for tokens that have been assigned more than experts_per_token.
|
||||
const lowest_experts = gating_v2.topK(.{ .top_expert = .expert }, @intCast(num_experts - opts.experts_per_token), .{ .descending = false });
|
||||
var gating_v3: zml.Tensor = .scatterSlices(
|
||||
gating_v2,
|
||||
.{ .expert = lowest_experts.indices },
|
||||
.zeroes(lowest_experts.values.shape()),
|
||||
.{ .indices_are_unique = true, .update_fn = zml.Tensor.ScatterOpts.override },
|
||||
);
|
||||
// Then normalize so the sum of experts scores for one token sums up to 1.
|
||||
gating_v3 = switch (opts.normalization) {
|
||||
.linear => gating_v3.div(gating_v3.sum(.expert)),
|
||||
.softmax => gating_v3.softmax(.expert),
|
||||
};
|
||||
const tokens_ids_per_expert = expert_routing.indices.transpose(.{ .expert, .top_token });
|
||||
|
||||
return .{ gating_v3, tokens_ids_per_expert };
|
||||
}
|
||||
376
examples/gpt_oss/main.zig
Normal file
376
examples/gpt_oss/main.zig
Normal file
@ -0,0 +1,376 @@
|
||||
const std = @import("std");
|
||||
const builtin = @import("builtin");
|
||||
|
||||
const async = @import("async");
|
||||
const clap = @import("clap");
|
||||
const stdx = @import("stdx");
|
||||
const zml = @import("zml");
|
||||
const Buffer = zml.Buffer;
|
||||
const Tensor = zml.Tensor;
|
||||
const ShapeOf = zml.ShapeOf;
|
||||
|
||||
const GptOss = @import("GptOss.zig");
|
||||
|
||||
const log = std.log.scoped(.GptOss);
|
||||
|
||||
pub const std_options: std.Options = .{
|
||||
.log_level = .info,
|
||||
.logFn = async.logFn(std.log.defaultLog),
|
||||
};
|
||||
|
||||
const cli_params = clap.parseParamsComptime(
|
||||
\\--help print this help
|
||||
\\--prompt <STRING> the prompt
|
||||
\\--hf-model-path <STRING> path to the directory containing model weights, config and tokenizer
|
||||
\\--seed <UINT> random seed (optional)
|
||||
\\--seq-len <UINT> max sequence length
|
||||
\\--prompt-len <UINT> max prompt length
|
||||
\\--temperature <FLOAT> temperature (default 1.0)
|
||||
\\--topk <UINT> topk (default 10)
|
||||
\\--expert-budget <FLOAT> token budget per expert
|
||||
\\--platform-options <STRING> platform options, using Zon syntax, eg '.{.cuda=.{.allocator=.{.async=.{.memory_fraction=0.95}}}}'
|
||||
\\--nochat <BOOL> skip prompt template
|
||||
\\--sharding <BOOL> default: true: sharding on or off
|
||||
);
|
||||
|
||||
pub fn tokenizePrompt(tokenizer: zml.tokenizer.Tokenizer, prompt: []const u8, no_chat: bool, out: []u32) ![]u32 {
|
||||
var encoder = try tokenizer.encoder();
|
||||
defer encoder.deinit();
|
||||
|
||||
if (no_chat) {
|
||||
const tokens = try encoder.encode(prompt);
|
||||
if (tokens.len > out.len) return error.PromptTooLong;
|
||||
@memcpy(out[0..tokens.len], tokens);
|
||||
return out[0..tokens.len];
|
||||
}
|
||||
|
||||
const start_header = tokenizer.tokenToId("<|start|>") orelse return error.NoSuchToken;
|
||||
const end_header_start_message = tokenizer.tokenToId("<|message|>") orelse return error.NoSuchToken;
|
||||
const end_message = tokenizer.tokenToId("<|end|>") orelse return error.NoSuchToken;
|
||||
|
||||
var tokens: std.ArrayList(u32) = .initBuffer(out);
|
||||
|
||||
const system_prompt = try encoder.encode("You are ChatGPT, a large language model trained by OpenAI.\n");
|
||||
if (system_prompt.len + 4 > tokens.unusedCapacitySlice().len) return error.PromptTooLong;
|
||||
tokens.appendSliceAssumeCapacity(&.{ start_header, tokenizer.tokenToId("system").?, end_header_start_message });
|
||||
tokens.appendSliceAssumeCapacity(system_prompt);
|
||||
tokens.appendAssumeCapacity(end_message);
|
||||
|
||||
const user_prompt = try encoder.encode(prompt);
|
||||
if (user_prompt.len + 9 > tokens.unusedCapacitySlice().len) return error.PromptTooLong;
|
||||
tokens.appendSliceAssumeCapacity(&.{ start_header, tokenizer.tokenToId("user").?, end_header_start_message });
|
||||
tokens.appendSliceAssumeCapacity(user_prompt);
|
||||
tokens.appendSliceAssumeCapacity(&.{
|
||||
end_message,
|
||||
start_header,
|
||||
tokenizer.tokenToId("assistant").?,
|
||||
tokenizer.tokenToId("<|channel|>") orelse return error.NoSuchToken,
|
||||
tokenizer.tokenToId("analysis") orelse return error.NoSuchToken,
|
||||
end_header_start_message,
|
||||
});
|
||||
|
||||
return tokens.items;
|
||||
}
|
||||
|
||||
pub fn generateText(
|
||||
config: GptOss.Config,
|
||||
options: GptOss.Options,
|
||||
mod_prefill: zml.ModuleExe(GptOss.forward),
|
||||
mod_generate: zml.ModuleExe(GptOss.forward),
|
||||
kv_cache_: zml.Bufferized(GptOss.KvCache),
|
||||
tokenizer: zml.tokenizer.Tokenizer,
|
||||
allocator: std.mem.Allocator,
|
||||
seed: u128,
|
||||
prompt_tok: []const u32,
|
||||
output: *std.Io.Writer,
|
||||
) !void {
|
||||
var tokenizer_decoder = try tokenizer.decoder();
|
||||
defer tokenizer_decoder.deinit();
|
||||
|
||||
const platform = mod_generate.platform();
|
||||
|
||||
// init RNG and buffers
|
||||
var rng = try zml.Tensor.Rng.init(platform, seed);
|
||||
var generated_token_buffer = [_]u32{undefined};
|
||||
|
||||
var current_token, var kv_cache = prefill: {
|
||||
// prepare device buffers for the prefill tokens and their positions
|
||||
const prefill_buffer = try allocator.alloc(u32, options.max_prompt_len);
|
||||
@memcpy(prefill_buffer[0..prompt_tok.len], prompt_tok);
|
||||
|
||||
var prefill_tokens = try zml.Buffer.fromSlice(platform, .{options.max_prompt_len}, prefill_buffer);
|
||||
defer prefill_tokens.deinit();
|
||||
var prefill_token_pos = try zml.Buffer.scalar(platform, prompt_tok.len, .u32);
|
||||
defer prefill_token_pos.deinit();
|
||||
|
||||
const first_token, const kv_cache, rng = mod_prefill.call(.{ prefill_tokens, .{ .prefill = prefill_token_pos }, kv_cache_, rng });
|
||||
|
||||
// extract the first generated token
|
||||
_ = try first_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer));
|
||||
log.warn("first_token: {d}", .{generated_token_buffer[0]});
|
||||
break :prefill .{ first_token, kv_cache };
|
||||
};
|
||||
defer zml.aio.unloadBuffers(&kv_cache);
|
||||
defer current_token.deinit();
|
||||
|
||||
const output_tokens_len = options.max_seq_len - prompt_tok.len - 1;
|
||||
const start = std.time.microTimestamp();
|
||||
|
||||
// One token has already been generated by the prefill.
|
||||
var num_tokens_generated: usize = 1;
|
||||
|
||||
generation: for (0..output_tokens_len + 1) |i| {
|
||||
// collect and print generated sequence
|
||||
num_tokens_generated += 1;
|
||||
const generated_token = generated_token_buffer[0];
|
||||
if (try tokenizer_decoder.next(generated_token)) |chunk| {
|
||||
try output.writeAll(chunk);
|
||||
}
|
||||
|
||||
// check for eos
|
||||
if (i == output_tokens_len) break :generation;
|
||||
switch (config.eos_token_id.value) {
|
||||
.int => |eos| if (generated_token == @as(u32, @intCast(eos))) break :generation,
|
||||
.ints => |eos_list| {
|
||||
for (eos_list) |eos| {
|
||||
if (generated_token == @as(u32, @intCast(eos))) break :generation;
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// current token pos needs to go into a zml.Buffer
|
||||
const token_pos_buffer = &[_]u32{@intCast(prompt_tok.len + i)};
|
||||
const token_pos = try zml.Buffer.fromSlice(platform, .{}, token_pos_buffer);
|
||||
defer token_pos.deinit();
|
||||
|
||||
// call to generate the next token
|
||||
current_token, kv_cache, rng = mod_generate.call(.{ current_token, .{ .gen = token_pos }, kv_cache, rng });
|
||||
|
||||
// extract the generated token from the buffer
|
||||
_ = try current_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer));
|
||||
}
|
||||
const end = std.time.microTimestamp();
|
||||
const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s);
|
||||
const speed = @as(f64, @floatFromInt(num_tokens_generated)) / duration;
|
||||
|
||||
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ num_tokens_generated, duration, speed });
|
||||
}
|
||||
|
||||
pub fn main() !void {
|
||||
try async.AsyncThread.main(std.heap.smp_allocator, asyncMain);
|
||||
}
|
||||
|
||||
pub fn asyncMain() !void {
|
||||
log.info(" GptOss was compiled with {}", .{@import("builtin").mode});
|
||||
|
||||
var allocator: std.mem.Allocator = alloc: {
|
||||
if (builtin.mode == .Debug) {
|
||||
var dbg_alloc: std.heap.DebugAllocator(.{
|
||||
.never_unmap = true,
|
||||
.retain_metadata = true,
|
||||
}) = .init;
|
||||
break :alloc dbg_alloc.allocator();
|
||||
}
|
||||
break :alloc std.heap.smp_allocator;
|
||||
};
|
||||
|
||||
const cli = ClapBoilerplate.parseCli(allocator);
|
||||
defer cli.deinit();
|
||||
|
||||
const hf_model_path = cli.args.@"hf-model-path" orelse {
|
||||
log.err("Missing --hf-model-path", .{});
|
||||
return;
|
||||
};
|
||||
|
||||
const config = config: {
|
||||
var arena: std.heap.ArenaAllocator = .init(allocator);
|
||||
defer arena.deinit();
|
||||
|
||||
const model_config_path = try std.fs.path.join(arena.allocator(), &.{ hf_model_path, "config.json" });
|
||||
|
||||
var config_json_file = try async.File.open(model_config_path, .{ .mode = .read_only });
|
||||
defer config_json_file.close() catch unreachable;
|
||||
|
||||
var config_reader = config_json_file.reader(try arena.allocator().alloc(u8, 256));
|
||||
var reader = std.json.Reader.init(allocator, &config_reader.interface);
|
||||
defer reader.deinit();
|
||||
var config = try std.json.parseFromTokenSourceLeaky(GptOss.Config, arena.allocator(), &reader, .{ .ignore_unknown_fields = true });
|
||||
|
||||
// From generation_config.json
|
||||
config.eos_token_id = .{ .value = .{ .ints = &.{ 200002, 199999, 200012 } } };
|
||||
break :config config;
|
||||
};
|
||||
|
||||
var context = try zml.Context.init();
|
||||
defer context.deinit();
|
||||
|
||||
// initialize ZML platform
|
||||
const platform: zml.Platform = platform: {
|
||||
const arena: std.heap.ArenaAllocator = .init(allocator);
|
||||
defer arena.deinit();
|
||||
|
||||
// eg: --platform-options='.{.cuda=.{.allocator=.{.bfc=.{.memory_fraction=0.99}}}}'
|
||||
// eg: --platform-options='.{.cpu=.{.device_count=8}}'
|
||||
const platform_opts = std.zon.parse.fromSlice(zml.Platform.CreateOptions, allocator, @ptrCast(cli.args.@"platform-options" orelse ".{}"), null, .{ .free_on_error = false }) catch |err| {
|
||||
log.err("Failed to parse --platform-options as json ({}): {s}", .{ err, cli.args.@"platform-options".? });
|
||||
return err;
|
||||
};
|
||||
|
||||
const compilation_options = zml.CompilationOptions{
|
||||
.xla_dump_to = "/tmp/zml/gpt_oss",
|
||||
.sharding_enabled = cli.args.sharding orelse true,
|
||||
};
|
||||
|
||||
const platform = context
|
||||
.autoPlatform(platform_opts)
|
||||
.withCompilationOptions(compilation_options);
|
||||
context.printAvailablePlatforms(platform);
|
||||
|
||||
break :platform platform;
|
||||
};
|
||||
|
||||
const options: GptOss.Options = .{
|
||||
.max_seq_len = cli.args.@"seq-len" orelse 8192,
|
||||
.max_prompt_len = cli.args.@"prompt-len" orelse 256,
|
||||
.tokens_per_expert_ratio = cli.args.@"expert-budget" orelse 4.0,
|
||||
.sampling_strategy = .{
|
||||
.topk = cli.args.topk orelse 10,
|
||||
.temperature = 1.0,
|
||||
},
|
||||
};
|
||||
|
||||
var compiler_arena = std.heap.ArenaAllocator.init(allocator);
|
||||
defer compiler_arena.deinit();
|
||||
|
||||
const model_weights_path = try std.fs.path.join(allocator, &.{ hf_model_path, "model.safetensors.index.json" });
|
||||
defer allocator.free(model_weights_path);
|
||||
|
||||
var store = try zml.aio.detectFormatAndOpen(allocator, model_weights_path);
|
||||
defer store.deinit();
|
||||
|
||||
const model: GptOss = try GptOss.init(compiler_arena.allocator(), store, config, options);
|
||||
|
||||
const tokens_shape_prefill = zml.Shape.init(.{ .s = options.max_prompt_len }, .u32);
|
||||
const tokens_shape = zml.Shape.init(.{ .s = 1 }, .u32);
|
||||
|
||||
const dtype = model.model.embed_tokens.weight.dtype();
|
||||
|
||||
const kv_shape = zml.Shape.init(.{
|
||||
.layer = model.model.layers.len,
|
||||
.k = options.max_seq_len,
|
||||
.h = config.num_key_value_heads,
|
||||
.hd = config.head_dim,
|
||||
}, dtype).withSharding(.{.h});
|
||||
|
||||
const kv_cache_shape: zml.ShapeOf(GptOss.KvCache) = GptOss.KvCache.initShape(kv_shape);
|
||||
const rng_shape = zml.Tensor.Rng.shape();
|
||||
|
||||
var start = try std.time.Timer.start();
|
||||
var fut_mod_prefill = try async.async(zml.compileModel, .{
|
||||
allocator, GptOss.forward, model,
|
||||
.{
|
||||
tokens_shape_prefill,
|
||||
zml.ShapeOf(GptOss.Mode){ .prefill = .scalar(.u32) },
|
||||
kv_cache_shape,
|
||||
rng_shape,
|
||||
},
|
||||
platform,
|
||||
});
|
||||
|
||||
var fut_mod = try async.async(zml.compileModel, .{
|
||||
allocator, GptOss.forward, model,
|
||||
.{
|
||||
tokens_shape,
|
||||
zml.ShapeOf(GptOss.Mode){ .gen = .scalar(.u32) },
|
||||
kv_cache_shape,
|
||||
rng_shape,
|
||||
},
|
||||
platform,
|
||||
});
|
||||
|
||||
log.info("\tLoading GptOss weights from {s}...", .{model_weights_path});
|
||||
var gpt_oss_weights = try model.loadBuffers(compiler_arena.allocator(), store, platform);
|
||||
defer zml.aio.unloadBuffers(&gpt_oss_weights);
|
||||
log.info("✅\tLoaded weights in {D}", .{start.read()});
|
||||
|
||||
var module_prefill = (try fut_mod_prefill.await()).prepare(gpt_oss_weights);
|
||||
defer module_prefill.deinit();
|
||||
var module_gen = (try fut_mod.await()).prepare(gpt_oss_weights);
|
||||
defer module_gen.deinit();
|
||||
log.info("✅\tCompiled model in {D}", .{start.read()});
|
||||
|
||||
log.info("Creating KvCache", .{});
|
||||
const kv_cache = try GptOss.KvCache.initBuffer(kv_shape, platform);
|
||||
|
||||
var tokenizer = blk: {
|
||||
const model_tokenizer_path = try std.fs.path.join(allocator, &.{ hf_model_path, "tokenizer.json" });
|
||||
defer allocator.free(model_tokenizer_path);
|
||||
|
||||
log.info("Loading tokenizer from {s}", .{model_tokenizer_path});
|
||||
var timer = try stdx.time.Timer.start();
|
||||
defer log.info("Loaded tokenizer from {s} [{f}]", .{ model_tokenizer_path, timer.read() });
|
||||
|
||||
break :blk try zml.tokenizer.Tokenizer.fromFile(allocator, model_tokenizer_path);
|
||||
};
|
||||
errdefer tokenizer.deinit();
|
||||
|
||||
const prompt = cli.args.prompt orelse "What are some fun facts about animals?";
|
||||
log.info("✅\tPrompt: {s}", .{prompt});
|
||||
|
||||
const no_chat = cli.args.nochat orelse false;
|
||||
const prompt_tok_buf = try allocator.alloc(u32, options.max_prompt_len);
|
||||
defer allocator.free(prompt_tok_buf);
|
||||
|
||||
const prompt_tok = tokenizePrompt(tokenizer, prompt, no_chat, prompt_tok_buf) catch |err| switch (err) {
|
||||
error.PromptTooLong => std.debug.panic("Prompt too long, expected at most {d} tokens. Consider increasing --max-prompt-len", .{prompt_tok_buf.len}),
|
||||
else => |e| return e,
|
||||
};
|
||||
log.info("\t Tokenized prompt: {any} ({d} tokens)", .{ prompt_tok, prompt_tok.len });
|
||||
|
||||
const seed = cli.args.seed orelse @as(u128, @bitCast(std.time.nanoTimestamp()));
|
||||
|
||||
// Unbuffered writing of the tokens to stdout.
|
||||
// generated text will be printed token by token.
|
||||
var output = std.fs.File.stdout().writer(&.{});
|
||||
|
||||
try generateText(config, options, module_prefill, module_gen, kv_cache, tokenizer, allocator, seed, prompt_tok, &output.interface);
|
||||
}
|
||||
|
||||
const ClapBoilerplate = struct {
|
||||
pub const Cli = clap.Result(clap.Help, &cli_params, parsers);
|
||||
|
||||
fn bool_parser(in: []const u8) error{}!bool {
|
||||
return std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
|
||||
}
|
||||
|
||||
const parsers = .{
|
||||
.BOOL = bool_parser,
|
||||
.UINT = clap.parsers.int(u32, 0),
|
||||
.FLOAT = clap.parsers.float(f32),
|
||||
.STRING = clap.parsers.string,
|
||||
.PATH = clap.parsers.string,
|
||||
};
|
||||
|
||||
pub fn parseCli(allocator: std.mem.Allocator) Cli {
|
||||
var diag: clap.Diagnostic = .{};
|
||||
var stderr_buffer: [1024]u8 = undefined;
|
||||
var stderr = std.fs.File.stderr().writer(&stderr_buffer);
|
||||
const cli = clap.parse(clap.Help, &cli_params, parsers, .{
|
||||
.diagnostic = &diag,
|
||||
.allocator = allocator,
|
||||
}) catch |err| {
|
||||
diag.report(&stderr.interface, err) catch {};
|
||||
stderr.interface.print("usage: ", .{}) catch {};
|
||||
clap.usage(&stderr.interface, clap.Help, &cli_params) catch {};
|
||||
stderr.interface.print("\n", .{}) catch {};
|
||||
stderr.interface.flush() catch {};
|
||||
std.process.exit(1);
|
||||
};
|
||||
if (cli.args.help != 0) {
|
||||
clap.help(&stderr.interface, clap.Help, &cli_params, .{}) catch {};
|
||||
stderr.interface.flush() catch {};
|
||||
std.process.exit(0);
|
||||
}
|
||||
return cli;
|
||||
}
|
||||
};
|
||||
Loading…
Reference in New Issue
Block a user