Update llama and simple_layer examples to use BufferStore tensor IDs, new CPU device count API, and fix zml.call tag hashing.
This commit is contained in:
parent
cc969bd532
commit
6e7617918d
@ -18,19 +18,22 @@ pub const LlamaLM = struct {
|
|||||||
int: u32,
|
int: u32,
|
||||||
ints: []u32,
|
ints: []u32,
|
||||||
}),
|
}),
|
||||||
num_hidden_layers: usize,
|
head_dim: ?u32,
|
||||||
num_attention_heads: usize,
|
hidden_size: u32,
|
||||||
num_key_value_heads: usize,
|
num_hidden_layers: u32,
|
||||||
|
num_attention_heads: u32,
|
||||||
|
num_key_value_heads: u32,
|
||||||
rope_theta: f32,
|
rope_theta: f32,
|
||||||
max_position_embeddings: usize,
|
max_position_embeddings: u32,
|
||||||
rms_norm_eps: f32,
|
rms_norm_eps: f32,
|
||||||
hf_rope_impl: bool = true,
|
hf_rope_impl: bool = true,
|
||||||
|
tie_word_embeddings: bool = false,
|
||||||
rope_scaling: zml.nn.RopeOpts.Scaling = .{ .default = {} },
|
rope_scaling: zml.nn.RopeOpts.Scaling = .{ .default = {} },
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Options = struct {
|
pub const Options = struct {
|
||||||
sampling_strategy: ?zml.nn.SamplingStrategy,
|
sampling_strategy: ?zml.nn.SamplingStrategy,
|
||||||
max_seq_len: usize,
|
max_seq_len: u32,
|
||||||
};
|
};
|
||||||
|
|
||||||
lm_head: ?zml.nn.Linear,
|
lm_head: ?zml.nn.Linear,
|
||||||
@ -40,39 +43,79 @@ pub const LlamaLM = struct {
|
|||||||
gen_opts: zml.nn.SamplingStrategy = .{},
|
gen_opts: zml.nn.SamplingStrategy = .{},
|
||||||
config: Config,
|
config: Config,
|
||||||
|
|
||||||
pub fn init(self: *LlamaLM, config: Config, options: Options) void {
|
pub fn init(allocator: std.mem.Allocator, config: Config, options: Options, store: zml.aio.BufferStore) !LlamaLM {
|
||||||
self.config = config;
|
const rope_opts: zml.nn.RopeOpts = .{
|
||||||
self.gen_opts = options.sampling_strategy orelse .{};
|
|
||||||
self.model.max_seq_len = @intCast(options.max_seq_len);
|
|
||||||
self.model.num_heads = @intCast(config.num_attention_heads);
|
|
||||||
self.model.num_kv_heads = @intCast(config.num_key_value_heads);
|
|
||||||
self.model.rope_opts = .{
|
|
||||||
.layout = if (config.hf_rope_impl) .sequential else .interleaved,
|
.layout = if (config.hf_rope_impl) .sequential else .interleaved,
|
||||||
.freq_base = config.rope_theta,
|
.freq_base = config.rope_theta,
|
||||||
.scaling = config.rope_scaling,
|
.scaling = config.rope_scaling,
|
||||||
};
|
};
|
||||||
self.model.norm.eps = config.rms_norm_eps;
|
|
||||||
for (self.model.layers) |*layer| {
|
|
||||||
layer.self_attn.num_heads = self.model.num_heads;
|
|
||||||
layer.self_attn.num_kv_heads = self.model.num_kv_heads;
|
|
||||||
layer.self_attn.rope_opts = self.model.rope_opts;
|
|
||||||
layer.input_layernorm.eps = config.rms_norm_eps;
|
|
||||||
layer.post_attention_layernorm.eps = config.rms_norm_eps;
|
|
||||||
layer.mlp.up_proj.weight = layer.mlp.up_proj.weight.withSharding(.{0});
|
|
||||||
layer.mlp.gate_proj.weight = layer.mlp.gate_proj.weight.withSharding(.{0});
|
|
||||||
layer.mlp.down_proj.weight = layer.mlp.down_proj.weight.withSharding(.{1});
|
|
||||||
|
|
||||||
layer.self_attn.q_proj.weight = layer.self_attn.q_proj.weight.withSharding(.{0});
|
const layers = try allocator.alloc(TransformerLayer, config.num_hidden_layers);
|
||||||
layer.self_attn.k_proj.weight = layer.self_attn.k_proj.weight.withSharding(.{0});
|
var prefix = try zml.aio.PrefixBuilder.initCapacity(allocator, 1024);
|
||||||
layer.self_attn.v_proj.weight = layer.self_attn.v_proj.weight.withSharding(.{0});
|
try prefix.push(stdx.noalloc, "model.layers");
|
||||||
layer.self_attn.o_proj.weight = layer.self_attn.o_proj.weight.withSharding(.{1});
|
for (0.., layers) |i, *layer| {
|
||||||
|
try prefix.pushDigit(stdx.noalloc, i);
|
||||||
|
defer prefix.pop();
|
||||||
|
var self_attn = try zml.aio.populateModelWithPrefix(SelfAttn, allocator, store, prefix.concat("self_attn"));
|
||||||
|
self_attn.num_heads = config.num_attention_heads;
|
||||||
|
self_attn.num_kv_heads = config.num_key_value_heads;
|
||||||
|
self_attn.rope_opts = rope_opts;
|
||||||
|
self_attn.q_proj.weight = self_attn.q_proj.weight.withSharding(.{0});
|
||||||
|
self_attn.k_proj.weight = self_attn.k_proj.weight.withSharding(.{0});
|
||||||
|
self_attn.v_proj.weight = self_attn.v_proj.weight.withSharding(.{0});
|
||||||
|
self_attn.o_proj.weight = self_attn.o_proj.weight.withSharding(.{1});
|
||||||
|
|
||||||
|
var input_layernorm = try zml.aio.populateModelWithPrefix(RmsNorm, allocator, store, prefix.concat("input_layernorm"));
|
||||||
|
input_layernorm.eps = config.rms_norm_eps;
|
||||||
|
|
||||||
|
var post_attention_layernorm = try zml.aio.populateModelWithPrefix(RmsNorm, allocator, store, prefix.concat("post_attention_layernorm"));
|
||||||
|
post_attention_layernorm.eps = config.rms_norm_eps;
|
||||||
|
|
||||||
|
var mlp = try zml.aio.populateModelWithPrefix(Mlp, allocator, store, prefix.concat("mlp"));
|
||||||
|
mlp.up_proj.weight = mlp.up_proj.weight.withSharding(.{0});
|
||||||
|
mlp.gate_proj.weight = mlp.gate_proj.weight.withSharding(.{0});
|
||||||
|
mlp.down_proj.weight = mlp.down_proj.weight.withSharding(.{1});
|
||||||
|
|
||||||
|
layer.* = .{
|
||||||
|
.self_attn = self_attn,
|
||||||
|
.input_layernorm = input_layernorm,
|
||||||
|
.post_attention_layernorm = post_attention_layernorm,
|
||||||
|
.mlp = mlp,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled.
|
var lm_head: ?zml.nn.Linear = null;
|
||||||
// It currently crashes/compilation fails
|
if (!config.tie_word_embeddings) {
|
||||||
if (self.gen_opts.topk == 1 and self.lm_head != null) {
|
lm_head = .{ .weight = store.getTensor("lm_head.weight") };
|
||||||
self.lm_head.?.weight = self.lm_head.?.weight.withSharding(.{0});
|
if (options.sampling_strategy) |gen_opts| {
|
||||||
|
if (gen_opts.topk == 1)
|
||||||
|
lm_head.?.weight = lm_head.?.weight.withSharding(.{0});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return .{
|
||||||
|
.config = config,
|
||||||
|
.gen_opts = options.sampling_strategy orelse .{},
|
||||||
|
.model = .{
|
||||||
|
// Weights
|
||||||
|
.layers = layers,
|
||||||
|
.embed_tokens = .{ .weight = store.getTensor("model.embed_tokens.weight") },
|
||||||
|
.norm = .{
|
||||||
|
.weight = store.getTensor("model.norm.weight"),
|
||||||
|
.eps = config.rms_norm_eps,
|
||||||
|
},
|
||||||
|
// Push down some configs
|
||||||
|
.max_seq_len = options.max_seq_len,
|
||||||
|
.num_heads = config.num_attention_heads,
|
||||||
|
.num_kv_heads = config.num_key_value_heads,
|
||||||
|
.rope_opts = .{
|
||||||
|
.layout = if (config.hf_rope_impl) .sequential else .interleaved,
|
||||||
|
.freq_base = config.rope_theta,
|
||||||
|
.scaling = config.rope_scaling,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
.lm_head = lm_head,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predicts the token at `token_index` position.
|
/// Predicts the token at `token_index` position.
|
||||||
@ -129,36 +172,13 @@ pub const Llama = struct {
|
|||||||
layers: []TransformerLayer,
|
layers: []TransformerLayer,
|
||||||
|
|
||||||
max_seq_len: u32 = 0,
|
max_seq_len: u32 = 0,
|
||||||
num_heads: i64 = 32,
|
num_heads: u32 = 32,
|
||||||
num_kv_heads: i64 = 32,
|
num_kv_heads: u32 = 32,
|
||||||
rope_opts: zml.nn.RopeOpts = .{
|
rope_opts: zml.nn.RopeOpts = .{
|
||||||
.layout = .interleaved,
|
.layout = .interleaved,
|
||||||
.freq_base = 10_000,
|
.freq_base = 10_000,
|
||||||
},
|
},
|
||||||
|
|
||||||
const Shape = struct {
|
|
||||||
s: u32,
|
|
||||||
layer: u16,
|
|
||||||
hd: u16,
|
|
||||||
nh: u16,
|
|
||||||
nkvh: u16,
|
|
||||||
dtype: zml.DataType,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub fn shape(self: Llama) Shape {
|
|
||||||
const key_dim = self.layers[0].self_attn.k_proj.weight.dim(0);
|
|
||||||
const num_kv_heads = if (self.num_kv_heads > 0) self.num_kv_heads else self.num_heads;
|
|
||||||
|
|
||||||
return .{
|
|
||||||
.s = self.max_seq_len,
|
|
||||||
.layer = @intCast(self.layers.len),
|
|
||||||
.hd = @intCast(@divExact(key_dim, num_kv_heads)),
|
|
||||||
.nh = @intCast(self.num_heads),
|
|
||||||
.nkvh = @intCast(num_kv_heads),
|
|
||||||
.dtype = self.embed_tokens.weight.dtype(),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Forward one token, using KV cache for previous tokens.
|
/// Forward one token, using KV cache for previous tokens.
|
||||||
/// Returns result and updated KV cache.
|
/// Returns result and updated KV cache.
|
||||||
pub fn forward(self: Llama, tokens: Tensor, token_index: Tensor, kv_cache: KvCache) struct { Tensor, KvCache } {
|
pub fn forward(self: Llama, tokens: Tensor, token_index: Tensor, kv_cache: KvCache) struct { Tensor, KvCache } {
|
||||||
@ -177,14 +197,6 @@ pub const Llama = struct {
|
|||||||
pub fn embed(embed_tokens_: zml.nn.TokenEmbedding, tokens_: Tensor) Tensor {
|
pub fn embed(embed_tokens_: zml.nn.TokenEmbedding, tokens_: Tensor) Tensor {
|
||||||
return zml.call(embed_tokens_, .forward, .{tokens_}).withPartialTags(.{.d});
|
return zml.call(embed_tokens_, .forward, .{tokens_}).withPartialTags(.{.d});
|
||||||
}
|
}
|
||||||
|
|
||||||
fn initKvCache(self: Llama, embed_shape: zml.Shape) KvCache {
|
|
||||||
const dims = self.shape();
|
|
||||||
var kv_shape = embed_shape.insert(0, .{ .layer = dims.layer }).rename(.{ .s = .k }).splitAxes(.{ .d = .{ .h = dims.nkvh, .hd = dims.hd } });
|
|
||||||
const perm = kv_shape.contiguousPerm(.{ .k, .h, .hd });
|
|
||||||
kv_shape = kv_shape.transpose(perm.constSlice());
|
|
||||||
return KvCache.init(kv_shape);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const TransformerLayer = struct {
|
pub const TransformerLayer = struct {
|
||||||
@ -245,6 +257,9 @@ pub const SelfAttn = struct {
|
|||||||
k_proj: zml.nn.Linear,
|
k_proj: zml.nn.Linear,
|
||||||
v_proj: zml.nn.Linear,
|
v_proj: zml.nn.Linear,
|
||||||
|
|
||||||
|
q_norm: ?RmsNorm,
|
||||||
|
k_norm: ?RmsNorm,
|
||||||
|
|
||||||
o_proj: zml.nn.Linear,
|
o_proj: zml.nn.Linear,
|
||||||
num_heads: i64 = undefined,
|
num_heads: i64 = undefined,
|
||||||
num_kv_heads: i64 = 0,
|
num_kv_heads: i64 = 0,
|
||||||
@ -282,6 +297,8 @@ pub const SelfAttn = struct {
|
|||||||
break :b temp.add(token_index.broad(temp.shape()));
|
break :b temp.add(token_index.broad(temp.shape()));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (self.q_norm) |norm| q = norm.forward(q.rename(.{ .hd = .d })).rename(.{ .d = .hd });
|
||||||
|
if (self.k_norm) |norm| k = norm.forward(k.rename(.{ .hd = .d })).rename(.{ .d = .hd });
|
||||||
q = zml.nn.rope(q, pos_index, self.rope_opts);
|
q = zml.nn.rope(q, pos_index, self.rope_opts);
|
||||||
k = zml.nn.rope(k, pos_index, self.rope_opts);
|
k = zml.nn.rope(k, pos_index, self.rope_opts);
|
||||||
q = q.rename(.{ .s = .q });
|
q = q.rename(.{ .s = .q });
|
||||||
@ -298,16 +315,6 @@ pub const SelfAttn = struct {
|
|||||||
const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s });
|
const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s });
|
||||||
return .{ zml.call(self.o_proj, .forward, .{attn}), new_kv_cache };
|
return .{ zml.call(self.o_proj, .forward, .{attn}), new_kv_cache };
|
||||||
}
|
}
|
||||||
|
|
||||||
fn initKvCache(key_shape: zml.Shape) KvCache {
|
|
||||||
// When we call initKvCache, we haven't renamed .s to .k yet.
|
|
||||||
var kv_shape = key_shape.insert(0, .{ .layer = 1 }).rename(.{ .s = .k });
|
|
||||||
const perm = kv_shape.contiguousPerm(.{ .h, .k, .hd });
|
|
||||||
kv_shape = kv_shape.transpose(perm.constSlice());
|
|
||||||
var res = KvCache.init(kv_shape);
|
|
||||||
res.layer_index = Tensor.scalar(0, .u32);
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const KvCache = struct {
|
pub const KvCache = struct {
|
||||||
@ -334,8 +341,8 @@ pub const KvCache = struct {
|
|||||||
|
|
||||||
pub fn initBuffer(kv_shape: zml.Shape, platform: zml.Platform) !zml.Bufferized(KvCache) {
|
pub fn initBuffer(kv_shape: zml.Shape, platform: zml.Platform) !zml.Bufferized(KvCache) {
|
||||||
return .{
|
return .{
|
||||||
.k = try zml.Buffer.constant(platform, kv_shape, 1),
|
.k = try zml.Buffer.uninitialized(platform, kv_shape, .{}),
|
||||||
.v = try zml.Buffer.constant(platform, kv_shape, 1),
|
.v = try zml.Buffer.uninitialized(platform, kv_shape, .{}),
|
||||||
.layer_index = try zml.Buffer.scalar(platform, 0, .u32),
|
.layer_index = try zml.Buffer.scalar(platform, 0, .u32),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@ -22,35 +22,42 @@ pub const std_options: std.Options = .{
|
|||||||
.logFn = asynk.logFn(std.log.defaultLog),
|
.logFn = asynk.logFn(std.log.defaultLog),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const params = clap.parseParamsComptime(
|
||||||
|
\\--help print this help
|
||||||
|
\\--prompt <STRING> the prompt
|
||||||
|
\\--hf-model-path <STRING> path to the directory containing model weights, config and tokenizer
|
||||||
|
\\--seed <UINT> random seed (optional)
|
||||||
|
\\--seq-len <UINT> sequence length
|
||||||
|
\\--create-options <STRING> platform creation options JSON, defaults to {}
|
||||||
|
\\--no-llama3 <BOOL> skip prompt template
|
||||||
|
\\--sharding <BOOL> default: true: sharding on or off
|
||||||
|
);
|
||||||
|
|
||||||
pub fn tokenizePrompt(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, config: LlamaLM.Config, prompt: []const u8, skip_llama3_encoding: bool) ![]u32 {
|
pub fn tokenizePrompt(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, config: LlamaLM.Config, prompt: []const u8, skip_llama3_encoding: bool) ![]u32 {
|
||||||
var tokens = std.array_list.Managed(u32).init(allocator);
|
|
||||||
var encoder = try tokenizer.encoder();
|
var encoder = try tokenizer.encoder();
|
||||||
defer encoder.deinit();
|
defer encoder.deinit();
|
||||||
|
|
||||||
if (skip_llama3_encoding) {
|
if (skip_llama3_encoding) {
|
||||||
// Copy to the arraylist so the ownership is the same in both branches.
|
// Copy so the ownership is the same in both branches.
|
||||||
try tokens.appendSlice(try encoder.encode(prompt));
|
return try allocator.dupe(u32, try encoder.encode(prompt));
|
||||||
return tokens.toOwnedSlice();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const start_header_id = tokenizer.tokenToId("<|start_header_id|>") orelse return error.NoSuchToken;
|
const start_header = tokenizer.tokenToId("<|start_header_id|>") orelse return error.NoSuchToken;
|
||||||
const end_header_id = tokenizer.tokenToId("<|end_header_id|>") orelse return error.NoSuchToken;
|
const end_header = tokenizer.tokenToId("<|end_header_id|>") orelse return error.NoSuchToken;
|
||||||
const eot_id = tokenizer.tokenToId("<|eot_id|>") orelse return error.NoSuchToken;
|
const user = tokenizer.tokenToId("user") orelse return error.NoSuchToken;
|
||||||
const newline_id = (try encoder.encode("\n"))[0];
|
const assistant = tokenizer.tokenToId("assistant") orelse return error.NoSuchToken;
|
||||||
|
const eot = tokenizer.tokenToId("<|eot_id|>") orelse return error.NoSuchToken;
|
||||||
|
const newline = (try encoder.encode("\n"))[0];
|
||||||
|
|
||||||
try tokens.append(config.bos_token_id);
|
var tokens: std.ArrayList(u32) = try .initCapacity(allocator, prompt.len);
|
||||||
|
try tokens.appendSlice(allocator, &.{ config.bos_token_id, start_header, user, end_header, newline });
|
||||||
|
|
||||||
try tokens.append(start_header_id);
|
try tokens.appendSlice(allocator, try encoder.encode(prompt));
|
||||||
try tokens.appendSlice(try encoder.encode("user"));
|
try tokens.appendSlice(allocator, &.{ eot, newline });
|
||||||
try tokens.appendSlice(&.{ end_header_id, newline_id });
|
|
||||||
|
|
||||||
try tokens.appendSlice(try encoder.encode(prompt));
|
try tokens.appendSlice(allocator, &.{ start_header, assistant, end_header, newline });
|
||||||
try tokens.appendSlice(&.{ eot_id, newline_id });
|
|
||||||
try tokens.append(start_header_id);
|
|
||||||
try tokens.appendSlice(try encoder.encode("assistant"));
|
|
||||||
try tokens.appendSlice(&.{ end_header_id, newline_id });
|
|
||||||
|
|
||||||
return tokens.toOwnedSlice();
|
return tokens.toOwnedSlice(allocator);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn generateText(
|
pub fn generateText(
|
||||||
@ -64,7 +71,8 @@ pub fn generateText(
|
|||||||
seed: u128,
|
seed: u128,
|
||||||
prompt: []const u8,
|
prompt: []const u8,
|
||||||
skip_llama3_encoding: bool,
|
skip_llama3_encoding: bool,
|
||||||
) ![]const u8 {
|
writer: *std.Io.Writer,
|
||||||
|
) !void {
|
||||||
const prompt_tok: []const u32 = try tokenizePrompt(allocator, tokenizer, config, prompt, skip_llama3_encoding);
|
const prompt_tok: []const u32 = try tokenizePrompt(allocator, tokenizer, config, prompt, skip_llama3_encoding);
|
||||||
defer allocator.free(prompt_tok);
|
defer allocator.free(prompt_tok);
|
||||||
|
|
||||||
@ -72,7 +80,7 @@ pub fn generateText(
|
|||||||
defer tokenizer_decoder.deinit();
|
defer tokenizer_decoder.deinit();
|
||||||
|
|
||||||
const platform = mod_generate.platform();
|
const platform = mod_generate.platform();
|
||||||
const max_seq_len = llama_.model.shape().s;
|
const max_seq_len = llama_.model.max_seq_len;
|
||||||
|
|
||||||
// init RNG and buffers
|
// init RNG and buffers
|
||||||
var rng = try zml.Tensor.Rng.init(platform, seed);
|
var rng = try zml.Tensor.Rng.init(platform, seed);
|
||||||
@ -100,10 +108,6 @@ pub fn generateText(
|
|||||||
var current_token = try zml.Buffer.fromSlice(platform, .{1}, &generated_token_buffer);
|
var current_token = try zml.Buffer.fromSlice(platform, .{1}, &generated_token_buffer);
|
||||||
defer current_token.deinit();
|
defer current_token.deinit();
|
||||||
|
|
||||||
// Here we collect the generated text
|
|
||||||
var output = std.array_list.Managed(u8).init(allocator);
|
|
||||||
defer output.deinit();
|
|
||||||
|
|
||||||
const output_tokens_len = max_seq_len - prompt_tok.len - 1;
|
const output_tokens_len = max_seq_len - prompt_tok.len - 1;
|
||||||
const start = std.time.microTimestamp();
|
const start = std.time.microTimestamp();
|
||||||
|
|
||||||
@ -114,9 +118,9 @@ pub fn generateText(
|
|||||||
// collect and print generated sequence
|
// collect and print generated sequence
|
||||||
num_tokens_generated += 1;
|
num_tokens_generated += 1;
|
||||||
const generated_token = generated_token_buffer[0];
|
const generated_token = generated_token_buffer[0];
|
||||||
const chunk = try tokenizer_decoder.next(generated_token) orelse unreachable;
|
if (try tokenizer_decoder.next(generated_token)) |chunk| {
|
||||||
try output.appendSlice(chunk);
|
try writer.writeAll(chunk);
|
||||||
std.debug.print("{s}", .{chunk});
|
}
|
||||||
|
|
||||||
// check for eos
|
// check for eos
|
||||||
if (i == output_tokens_len) break :generation;
|
if (i == output_tokens_len) break :generation;
|
||||||
@ -145,22 +149,6 @@ pub fn generateText(
|
|||||||
const speed = @as(f64, @floatFromInt(num_tokens_generated)) / duration;
|
const speed = @as(f64, @floatFromInt(num_tokens_generated)) / duration;
|
||||||
std.debug.print("\n", .{});
|
std.debug.print("\n", .{});
|
||||||
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ num_tokens_generated, duration, speed });
|
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ num_tokens_generated, duration, speed });
|
||||||
return output.toOwnedSlice();
|
|
||||||
}
|
|
||||||
|
|
||||||
const params = clap.parseParamsComptime(
|
|
||||||
\\--help print this help
|
|
||||||
\\--prompt <STRING> the prompt
|
|
||||||
\\--hf-model-path <STRING> path to the directory containing model weights, config and tokenizer
|
|
||||||
\\--seed <UINT> random seed (optional)
|
|
||||||
\\--seq-len <UINT> sequence length
|
|
||||||
\\--create-options <STRING> platform creation options JSON, defaults to {}
|
|
||||||
\\--no-llama3 <BOOL> skip prompt template
|
|
||||||
\\--sharding <BOOL> default: true: sharding on or off
|
|
||||||
);
|
|
||||||
|
|
||||||
pub fn bool_parser(in: []const u8) error{}!bool {
|
|
||||||
return std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn main() !void {
|
pub fn main() !void {
|
||||||
@ -174,31 +162,33 @@ pub fn asyncMain() !void {
|
|||||||
|
|
||||||
const parsers = comptime .{
|
const parsers = comptime .{
|
||||||
.BOOL = bool_parser,
|
.BOOL = bool_parser,
|
||||||
.UINT = clap.parsers.int(usize, 0),
|
.UINT = clap.parsers.int(u32, 0),
|
||||||
.STRING = clap.parsers.string,
|
.STRING = clap.parsers.string,
|
||||||
.PATH = clap.parsers.string,
|
.PATH = clap.parsers.string,
|
||||||
};
|
};
|
||||||
var diag: clap.Diagnostic = .{};
|
var diag: clap.Diagnostic = .{};
|
||||||
var stderr_buffer: [1024]u8 = undefined;
|
var stderr_buffer: [1024]u8 = undefined;
|
||||||
var stderr = std.fs.File.stderr().writer(&stderr_buffer);
|
var stderr = std.fs.File.stderr().writer(&stderr_buffer);
|
||||||
var res = clap.parse(clap.Help, &params, parsers, .{
|
defer stderr.interface.flush() catch {};
|
||||||
|
|
||||||
|
var cli = clap.parse(clap.Help, &params, parsers, .{
|
||||||
.diagnostic = &diag,
|
.diagnostic = &diag,
|
||||||
.allocator = allocator,
|
.allocator = allocator,
|
||||||
}) catch |err| {
|
}) catch |err| {
|
||||||
diag.report(&stderr.interface, err) catch {};
|
diag.report(&stderr.interface, err) catch {};
|
||||||
stderr.interface.print("usage: ", .{}) catch {};
|
stderr.interface.writeAll("usage: ") catch {};
|
||||||
clap.usage(&stderr.interface, clap.Help, &params) catch {};
|
clap.usage(&stderr.interface, clap.Help, &params) catch {};
|
||||||
stderr.interface.print("\n", .{}) catch {};
|
stderr.interface.writeAll("\n") catch {};
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
defer res.deinit();
|
defer cli.deinit();
|
||||||
|
|
||||||
if (res.args.help != 0) {
|
if (cli.args.help != 0) {
|
||||||
clap.help(&stderr.interface, clap.Help, &params, .{}) catch {};
|
clap.help(&stderr.interface, clap.Help, &params, .{}) catch {};
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const hf_model_path = res.args.@"hf-model-path" orelse {
|
const hf_model_path = cli.args.@"hf-model-path" orelse {
|
||||||
log.err("Missing --hf-model-path", .{});
|
log.err("Missing --hf-model-path", .{});
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
@ -238,80 +228,87 @@ pub fn asyncMain() !void {
|
|||||||
|
|
||||||
const compilation_options = zml.CompilationOptions{
|
const compilation_options = zml.CompilationOptions{
|
||||||
.xla_dump_to = "/tmp/zml/llama",
|
.xla_dump_to = "/tmp/zml/llama",
|
||||||
.sharding_enabled = res.args.sharding orelse true,
|
.sharding_enabled = cli.args.sharding orelse true,
|
||||||
};
|
};
|
||||||
|
|
||||||
// initialize ZML platform with optional create options
|
// initialize ZML platform with optional create options
|
||||||
// eg: --create-options='{"cuda":{"allocator":{"bfc":{"memory_fraction": 0.99}}}}'
|
// eg: --create-options='{"cuda":{"allocator":{"bfc":{"memory_fraction": 0.99}}}}'
|
||||||
const create_opts_json = res.args.@"create-options" orelse "{}";
|
const create_opts_json = cli.args.@"create-options" orelse "{}";
|
||||||
const create_opts = try std.json.parseFromSlice(zml.Platform.CreateOptions, allocator, create_opts_json, .{});
|
const create_opts = try std.json.parseFromSlice(zml.Platform.CreateOptions, allocator, create_opts_json, .{});
|
||||||
const platform = context.autoPlatform(create_opts.value).withCompilationOptions(compilation_options);
|
const platform = context.autoPlatform(create_opts.value).withCompilationOptions(compilation_options);
|
||||||
create_opts.deinit();
|
create_opts.deinit();
|
||||||
context.printAvailablePlatforms(platform);
|
context.printAvailablePlatforms(platform);
|
||||||
|
|
||||||
var ts = try zml.aio.detectFormatAndOpen(allocator, model_weights_path);
|
var store = try zml.aio.detectFormatAndOpen(allocator, model_weights_path);
|
||||||
defer ts.deinit();
|
defer store.deinit();
|
||||||
|
|
||||||
var model_arena = std.heap.ArenaAllocator.init(allocator);
|
|
||||||
var model_instance = try zml.aio.populateModel(llama.LlamaLM, model_arena.allocator(), ts);
|
|
||||||
|
|
||||||
|
// Write metadata from the config file into the LlamaLm struct.
|
||||||
|
const seq_len: u32 = cli.args.@"seq-len" orelse 256;
|
||||||
const llama_options: llama.LlamaLM.Options = .{
|
const llama_options: llama.LlamaLM.Options = .{
|
||||||
.max_seq_len = @intCast(res.args.@"seq-len" orelse 256),
|
.max_seq_len = seq_len,
|
||||||
.sampling_strategy = .{
|
.sampling_strategy = .{
|
||||||
.topk = 1,
|
.topk = 1,
|
||||||
.temperature = 1.0,
|
.temperature = 1.0,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
model_instance.init(config, llama_options);
|
|
||||||
|
|
||||||
const dims = model_instance.model.shape();
|
// Contains memory for llama_tensors and llama_buffers.
|
||||||
const dtype = model_instance.model.embed_tokens.weight.dtype();
|
var compiler_arena = std.heap.ArenaAllocator.init(allocator);
|
||||||
|
defer compiler_arena.deinit();
|
||||||
|
|
||||||
const tokens_shape_prefill = zml.Shape.init(.{ .s = llama_options.max_seq_len }, .u32);
|
// Initialize the Llama struct and map the content of the .safetensors to the model tensors.
|
||||||
const tokens_shape = zml.Shape.init(.{ .s = 1 }, .u32);
|
const llama_tensors: llama.LlamaLM = try .init(compiler_arena.allocator(), config, llama_options, store);
|
||||||
|
|
||||||
|
// Specify shapes of input arguments
|
||||||
|
const prefill_tokens_shape = zml.Shape.init(.{ .s = llama_options.max_seq_len }, .u32);
|
||||||
|
const gen_tokens_shape = zml.Shape.init(.{ .s = 1 }, .u32);
|
||||||
const token_idx_shape = zml.Shape.init(.{}, .u32);
|
const token_idx_shape = zml.Shape.init(.{}, .u32);
|
||||||
|
|
||||||
const kv_shape = zml.Shape.init(.{ .layer = model_instance.model.layers.len, .k = dims.s, .h = dims.nkvh, .hd = dims.hd }, dtype).withSharding(.{.h});
|
const dtype = llama_tensors.model.embed_tokens.weight.dtype();
|
||||||
|
const kv_shape = zml.Shape.init(.{
|
||||||
|
.layer = llama_tensors.model.layers.len,
|
||||||
|
.k = seq_len,
|
||||||
|
.h = config.num_key_value_heads,
|
||||||
|
.hd = config.head_dim orelse @divExact(config.hidden_size, config.num_attention_heads),
|
||||||
|
}, dtype).withSharding(.{.h});
|
||||||
const kv_cache_shape: zml.ShapeOf(llama.KvCache) = llama.KvCache.initShape(kv_shape);
|
const kv_cache_shape: zml.ShapeOf(llama.KvCache) = llama.KvCache.initShape(kv_shape);
|
||||||
const rng_shape = zml.Tensor.Rng.shape();
|
const rng_shape = zml.Tensor.Rng.shape();
|
||||||
|
|
||||||
|
// Compile the model twice, one for prefill, one for generation.
|
||||||
var start = try std.time.Timer.start();
|
var start = try std.time.Timer.start();
|
||||||
var fut_mod_prefill = try asynk.asyncc(zml.compile, .{
|
var fut_mod_prefill = try asynk.asyncc(zml.compileModel, .{
|
||||||
allocator, llama.LlamaLM.forward, .{ config, llama_options },
|
allocator, llama.LlamaLM.forward, llama_tensors,
|
||||||
.{
|
.{
|
||||||
tokens_shape_prefill,
|
prefill_tokens_shape,
|
||||||
token_idx_shape,
|
token_idx_shape,
|
||||||
kv_cache_shape,
|
kv_cache_shape,
|
||||||
rng_shape,
|
rng_shape,
|
||||||
},
|
},
|
||||||
ts,
|
|
||||||
platform,
|
platform,
|
||||||
});
|
});
|
||||||
|
|
||||||
var fut_mod = try asynk.asyncc(zml.compile, .{
|
var fut_mod = try asynk.asyncc(zml.compileModel, .{
|
||||||
allocator, llama.LlamaLM.forward, .{ config, llama_options },
|
allocator, llama.LlamaLM.forward, llama_tensors,
|
||||||
.{
|
.{
|
||||||
tokens_shape,
|
gen_tokens_shape,
|
||||||
token_idx_shape,
|
token_idx_shape,
|
||||||
kv_cache_shape,
|
kv_cache_shape,
|
||||||
rng_shape,
|
rng_shape,
|
||||||
},
|
},
|
||||||
ts,
|
|
||||||
platform,
|
platform,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// While we are still compiling load the weights to the device.
|
||||||
log.info("\tLoading Llama weights from {s}...", .{model_weights_path});
|
log.info("\tLoading Llama weights from {s}...", .{model_weights_path});
|
||||||
var llama_weights = try zml.aio.loadBuffers(llama.LlamaLM, .{ config, llama_options }, ts, model_arena.allocator(), platform);
|
var llama_buffers = try store.loadModelById(llama.LlamaLM, compiler_arena.allocator(), llama_tensors, platform);
|
||||||
defer zml.aio.unloadBuffers(&llama_weights);
|
defer zml.aio.unloadBuffers(&llama_buffers);
|
||||||
log.info("✅\tLoaded weights in {D}", .{start.read()});
|
log.info("✅\tLoaded weights in {D}", .{start.read()});
|
||||||
|
|
||||||
var llama_module_prefill = (try fut_mod_prefill.awaitt()).prepare(llama_weights);
|
var llama_module_prefill = (try fut_mod_prefill.awaitt()).prepare(llama_buffers);
|
||||||
defer llama_module_prefill.deinit();
|
defer llama_module_prefill.deinit();
|
||||||
var llama_module = (try fut_mod.awaitt()).prepare(llama_weights);
|
var llama_module = (try fut_mod.awaitt()).prepare(llama_buffers);
|
||||||
defer llama_module.deinit();
|
defer llama_module.deinit();
|
||||||
log.info("✅\tCompiled model in {D}", .{start.read()});
|
log.info("✅\tCompiled model in {D}", .{start.read()});
|
||||||
|
|
||||||
log.info("Creating KvCache", .{});
|
log.info("Creating KvCache", .{});
|
||||||
const kv_cache = try llama.KvCache.initBuffer(kv_shape, platform);
|
const kv_cache = try llama.KvCache.initBuffer(kv_shape, platform);
|
||||||
|
|
||||||
@ -320,16 +317,34 @@ pub fn asyncMain() !void {
|
|||||||
var timer = try stdx.time.Timer.start();
|
var timer = try stdx.time.Timer.start();
|
||||||
defer log.info("Loaded tokenizer from {s} [{D}]", .{ model_tokenizer_path, timer.read() });
|
defer log.info("Loaded tokenizer from {s} [{D}]", .{ model_tokenizer_path, timer.read() });
|
||||||
|
|
||||||
break :blk try zml.tokenizer.Tokenizer.fromFile(model_arena.allocator(), model_tokenizer_path);
|
break :blk try zml.tokenizer.Tokenizer.fromFile(allocator, model_tokenizer_path);
|
||||||
};
|
};
|
||||||
errdefer tokenizer.deinit();
|
errdefer tokenizer.deinit();
|
||||||
|
|
||||||
const prompt = res.args.prompt orelse "What is the capital of France?";
|
const prompt = cli.args.prompt orelse "What is the capital of France?";
|
||||||
log.info("✅\tPrompt: {s}", .{prompt});
|
log.info("✅\tPrompt: {s}", .{prompt});
|
||||||
|
|
||||||
const seed = res.args.seed orelse @as(u128, @bitCast(std.time.nanoTimestamp()));
|
// Unbuffered writing of the tokens to stdout.
|
||||||
const skip_llama3_encoding = res.args.@"no-llama3" orelse false;
|
var stdout = std.fs.File.stdout().writer(&.{});
|
||||||
const generated_text = try generateText(config, model_instance, llama_module_prefill, llama_module, kv_cache, tokenizer, allocator, seed, prompt[0..], skip_llama3_encoding);
|
|
||||||
// generated text will be printed token by token.
|
const seed: u128 = cli.args.seed orelse @bitCast(std.time.nanoTimestamp());
|
||||||
defer allocator.free(generated_text);
|
const skip_llama3_encoding = cli.args.@"no-llama3" orelse false;
|
||||||
|
|
||||||
|
try generateText(
|
||||||
|
config,
|
||||||
|
llama_tensors,
|
||||||
|
llama_module_prefill,
|
||||||
|
llama_module,
|
||||||
|
kv_cache,
|
||||||
|
tokenizer,
|
||||||
|
allocator,
|
||||||
|
seed,
|
||||||
|
prompt[0..],
|
||||||
|
skip_llama3_encoding,
|
||||||
|
&stdout.interface,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bool_parser(in: []const u8) error{}!bool {
|
||||||
|
return std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -45,19 +45,14 @@ pub fn asyncMain() !void {
|
|||||||
|
|
||||||
// We manually produce a BufferStore. You would not normally do that.
|
// We manually produce a BufferStore. You would not normally do that.
|
||||||
// A BufferStore is usually created by loading model data from a file.
|
// A BufferStore is usually created by loading model data from a file.
|
||||||
var buffers: zml.aio.BufferStore.Buffers = .{};
|
var store: zml.aio.BufferStore = .init(allocator);
|
||||||
try buffers.put(arena, "weight", zml.HostBuffer.fromArray(&weights));
|
defer store.deinit();
|
||||||
try buffers.put(arena, "bias", zml.HostBuffer.fromArray(&bias));
|
try store.buffers.put(store.arena.allocator(), "weight", zml.HostBuffer.fromArray(&weights));
|
||||||
|
try store.buffers.put(store.arena.allocator(), "bias", zml.HostBuffer.fromArray(&bias));
|
||||||
// the actual BufferStore
|
|
||||||
const buffer_store: zml.aio.BufferStore = .{
|
|
||||||
.arena = arena_state,
|
|
||||||
.buffers = buffers,
|
|
||||||
};
|
|
||||||
|
|
||||||
// A clone of our model, consisting of shapes. We only need shapes for compiling.
|
// A clone of our model, consisting of shapes. We only need shapes for compiling.
|
||||||
// We use the BufferStore to infer the shapes.
|
// We use the BufferStore to infer the shapes.
|
||||||
var model_shapes = try zml.aio.populateModel(Layer, allocator, buffer_store);
|
var model_shapes = try zml.aio.populateModel(Layer, allocator, store);
|
||||||
model_shapes.weight = model_shapes.weight.withSharding(.{-1});
|
model_shapes.weight = model_shapes.weight.withSharding(.{-1});
|
||||||
model_shapes.bias = model_shapes.bias.?.withSharding(.{-1});
|
model_shapes.bias = model_shapes.bias.?.withSharding(.{-1});
|
||||||
|
|
||||||
@ -68,7 +63,7 @@ pub fn asyncMain() !void {
|
|||||||
// Produce a bufferized weights struct from the fake BufferStore.
|
// Produce a bufferized weights struct from the fake BufferStore.
|
||||||
// This is like the inferred shapes, but with actual values.
|
// This is like the inferred shapes, but with actual values.
|
||||||
// We will need to send those to the computation device later.
|
// We will need to send those to the computation device later.
|
||||||
var model_weights = try zml.aio.loadModelBuffers(Layer, model_shapes, buffer_store, arena, platform);
|
var model_weights = try zml.aio.loadModelBuffers(Layer, model_shapes, store, arena, platform);
|
||||||
defer zml.aio.unloadBuffers(&model_weights); // for good practice
|
defer zml.aio.unloadBuffers(&model_weights); // for good practice
|
||||||
|
|
||||||
// Wait for compilation to finish
|
// Wait for compilation to finish
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user