Update llama and simple_layer examples to use BufferStore tensor IDs, new CPU device count API, and fix zml.call tag hashing.

This commit is contained in:
Foke Singh 2025-08-22 17:55:03 +00:00
parent cc969bd532
commit 6e7617918d
3 changed files with 188 additions and 171 deletions

View File

@@ -18,19 +18,22 @@ pub const LlamaLM = struct {
int: u32,
ints: []u32,
}),
num_hidden_layers: usize,
num_attention_heads: usize,
num_key_value_heads: usize,
head_dim: ?u32,
hidden_size: u32,
num_hidden_layers: u32,
num_attention_heads: u32,
num_key_value_heads: u32,
rope_theta: f32,
max_position_embeddings: usize,
max_position_embeddings: u32,
rms_norm_eps: f32,
hf_rope_impl: bool = true,
tie_word_embeddings: bool = false,
rope_scaling: zml.nn.RopeOpts.Scaling = .{ .default = {} },
};
pub const Options = struct {
sampling_strategy: ?zml.nn.SamplingStrategy,
max_seq_len: usize,
max_seq_len: u32,
};
lm_head: ?zml.nn.Linear,
@@ -40,41 +43,81 @@ pub const LlamaLM = struct {
gen_opts: zml.nn.SamplingStrategy = .{},
config: Config,
pub fn init(self: *LlamaLM, config: Config, options: Options) void {
self.config = config;
self.gen_opts = options.sampling_strategy orelse .{};
self.model.max_seq_len = @intCast(options.max_seq_len);
self.model.num_heads = @intCast(config.num_attention_heads);
self.model.num_kv_heads = @intCast(config.num_key_value_heads);
self.model.rope_opts = .{
pub fn init(allocator: std.mem.Allocator, config: Config, options: Options, store: zml.aio.BufferStore) !LlamaLM {
const rope_opts: zml.nn.RopeOpts = .{
.layout = if (config.hf_rope_impl) .sequential else .interleaved,
.freq_base = config.rope_theta,
.scaling = config.rope_scaling,
};
self.model.norm.eps = config.rms_norm_eps;
for (self.model.layers) |*layer| {
layer.self_attn.num_heads = self.model.num_heads;
layer.self_attn.num_kv_heads = self.model.num_kv_heads;
layer.self_attn.rope_opts = self.model.rope_opts;
layer.input_layernorm.eps = config.rms_norm_eps;
layer.post_attention_layernorm.eps = config.rms_norm_eps;
layer.mlp.up_proj.weight = layer.mlp.up_proj.weight.withSharding(.{0});
layer.mlp.gate_proj.weight = layer.mlp.gate_proj.weight.withSharding(.{0});
layer.mlp.down_proj.weight = layer.mlp.down_proj.weight.withSharding(.{1});
layer.self_attn.q_proj.weight = layer.self_attn.q_proj.weight.withSharding(.{0});
layer.self_attn.k_proj.weight = layer.self_attn.k_proj.weight.withSharding(.{0});
layer.self_attn.v_proj.weight = layer.self_attn.v_proj.weight.withSharding(.{0});
layer.self_attn.o_proj.weight = layer.self_attn.o_proj.weight.withSharding(.{1});
const layers = try allocator.alloc(TransformerLayer, config.num_hidden_layers);
var prefix = try zml.aio.PrefixBuilder.initCapacity(allocator, 1024);
try prefix.push(stdx.noalloc, "model.layers");
for (0.., layers) |i, *layer| {
try prefix.pushDigit(stdx.noalloc, i);
defer prefix.pop();
var self_attn = try zml.aio.populateModelWithPrefix(SelfAttn, allocator, store, prefix.concat("self_attn"));
self_attn.num_heads = config.num_attention_heads;
self_attn.num_kv_heads = config.num_key_value_heads;
self_attn.rope_opts = rope_opts;
self_attn.q_proj.weight = self_attn.q_proj.weight.withSharding(.{0});
self_attn.k_proj.weight = self_attn.k_proj.weight.withSharding(.{0});
self_attn.v_proj.weight = self_attn.v_proj.weight.withSharding(.{0});
self_attn.o_proj.weight = self_attn.o_proj.weight.withSharding(.{1});
var input_layernorm = try zml.aio.populateModelWithPrefix(RmsNorm, allocator, store, prefix.concat("input_layernorm"));
input_layernorm.eps = config.rms_norm_eps;
var post_attention_layernorm = try zml.aio.populateModelWithPrefix(RmsNorm, allocator, store, prefix.concat("post_attention_layernorm"));
post_attention_layernorm.eps = config.rms_norm_eps;
var mlp = try zml.aio.populateModelWithPrefix(Mlp, allocator, store, prefix.concat("mlp"));
mlp.up_proj.weight = mlp.up_proj.weight.withSharding(.{0});
mlp.gate_proj.weight = mlp.gate_proj.weight.withSharding(.{0});
mlp.down_proj.weight = mlp.down_proj.weight.withSharding(.{1});
layer.* = .{
.self_attn = self_attn,
.input_layernorm = input_layernorm,
.post_attention_layernorm = post_attention_layernorm,
.mlp = mlp,
};
}
// TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled.
// It currently crashes/compilation fails
if (self.gen_opts.topk == 1 and self.lm_head != null) {
self.lm_head.?.weight = self.lm_head.?.weight.withSharding(.{0});
var lm_head: ?zml.nn.Linear = null;
if (!config.tie_word_embeddings) {
lm_head = .{ .weight = store.getTensor("lm_head.weight") };
if (options.sampling_strategy) |gen_opts| {
if (gen_opts.topk == 1)
lm_head.?.weight = lm_head.?.weight.withSharding(.{0});
}
}
return .{
.config = config,
.gen_opts = options.sampling_strategy orelse .{},
.model = .{
// Weights
.layers = layers,
.embed_tokens = .{ .weight = store.getTensor("model.embed_tokens.weight") },
.norm = .{
.weight = store.getTensor("model.norm.weight"),
.eps = config.rms_norm_eps,
},
// Push down some configs
.max_seq_len = options.max_seq_len,
.num_heads = config.num_attention_heads,
.num_kv_heads = config.num_key_value_heads,
.rope_opts = .{
.layout = if (config.hf_rope_impl) .sequential else .interleaved,
.freq_base = config.rope_theta,
.scaling = config.rope_scaling,
},
},
.lm_head = lm_head,
};
}
/// Predicts the token at `token_index` position.
/// Returns:
/// - updated `tokens`,
@@ -129,36 +172,13 @@ pub const Llama = struct {
layers: []TransformerLayer,
max_seq_len: u32 = 0,
num_heads: i64 = 32,
num_kv_heads: i64 = 32,
num_heads: u32 = 32,
num_kv_heads: u32 = 32,
rope_opts: zml.nn.RopeOpts = .{
.layout = .interleaved,
.freq_base = 10_000,
},
/// Compact geometry descriptor for the model, used to size the KV cache
/// and attention tensors.
const Shape = struct {
// s: max sequence length; layer: number of transformer layers;
// hd: per-head dimension; nh: query heads; nkvh: key/value heads.
s: u32,
layer: u16,
hd: u16,
nh: u16,
nkvh: u16,
dtype: zml.DataType,
};
/// Derives the model geometry from the loaded weights and config fields.
/// NOTE(review): reads layers[0] — assumes at least one transformer layer exists.
pub fn shape(self: Llama) Shape {
const key_dim = self.layers[0].self_attn.k_proj.weight.dim(0);
// num_kv_heads == 0 means "no grouped-query attention": fall back to num_heads.
const num_kv_heads = if (self.num_kv_heads > 0) self.num_kv_heads else self.num_heads;
return .{
.s = self.max_seq_len,
.layer = @intCast(self.layers.len),
// Per-head dim inferred from the K projection; @divExact traps if not divisible.
.hd = @intCast(@divExact(key_dim, num_kv_heads)),
.nh = @intCast(self.num_heads),
.nkvh = @intCast(num_kv_heads),
.dtype = self.embed_tokens.weight.dtype(),
};
}
/// Forward one token, using KV cache for previous tokens.
/// Returns result and updated KV cache.
pub fn forward(self: Llama, tokens: Tensor, token_index: Tensor, kv_cache: KvCache) struct { Tensor, KvCache } {
@@ -177,14 +197,6 @@ pub const Llama = struct {
/// Looks up token embeddings via the embedding module's forward pass and
/// tags the feature axis as `.d` for downstream attention/MLP ops.
/// NOTE(review): withPartialTags suggests `tokens_` may carry untagged leading
/// axes (batch/sequence) — confirm against callers.
pub fn embed(embed_tokens_: zml.nn.TokenEmbedding, tokens_: Tensor) Tensor {
return zml.call(embed_tokens_, .forward, .{tokens_}).withPartialTags(.{.d});
}
/// Builds the KV-cache spec from the embedding output shape: prepends a
/// `.layer` axis, renames `.s` (sequence) to `.k` (keys), and splits the
/// feature axis `.d` into per-head axes (.h = kv heads, .hd = head dim).
fn initKvCache(self: Llama, embed_shape: zml.Shape) KvCache {
const dims = self.shape();
var kv_shape = embed_shape.insert(0, .{ .layer = dims.layer }).rename(.{ .s = .k }).splitAxes(.{ .d = .{ .h = dims.nkvh, .hd = dims.hd } });
// Transpose to a (.k, .h, .hd)-contiguous layout before creating the cache.
const perm = kv_shape.contiguousPerm(.{ .k, .h, .hd });
kv_shape = kv_shape.transpose(perm.constSlice());
return KvCache.init(kv_shape);
}
};
pub const TransformerLayer = struct {
@@ -245,6 +257,9 @@ pub const SelfAttn = struct {
k_proj: zml.nn.Linear,
v_proj: zml.nn.Linear,
q_norm: ?RmsNorm,
k_norm: ?RmsNorm,
o_proj: zml.nn.Linear,
num_heads: i64 = undefined,
num_kv_heads: i64 = 0,
@@ -282,6 +297,8 @@ pub const SelfAttn = struct {
break :b temp.add(token_index.broad(temp.shape()));
};
if (self.q_norm) |norm| q = norm.forward(q.rename(.{ .hd = .d })).rename(.{ .d = .hd });
if (self.k_norm) |norm| k = norm.forward(k.rename(.{ .hd = .d })).rename(.{ .d = .hd });
q = zml.nn.rope(q, pos_index, self.rope_opts);
k = zml.nn.rope(k, pos_index, self.rope_opts);
q = q.rename(.{ .s = .q });
@@ -298,16 +315,6 @@ pub const SelfAttn = struct {
const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s });
return .{ zml.call(self.o_proj, .forward, .{attn}), new_kv_cache };
}
/// Builds a single-layer KV cache from the key tensor's shape, laid out
/// (.h, .k, .hd)-contiguous, with layer_index pinned to 0.
/// NOTE(review): layout differs from Llama.initKvCache's (.k, .h, .hd) — confirm intentional.
fn initKvCache(key_shape: zml.Shape) KvCache {
// When we call initKvCache, we haven't renamed .s to .k yet.
var kv_shape = key_shape.insert(0, .{ .layer = 1 }).rename(.{ .s = .k });
const perm = kv_shape.contiguousPerm(.{ .h, .k, .hd });
kv_shape = kv_shape.transpose(perm.constSlice());
var res = KvCache.init(kv_shape);
// Single-layer cache: always read/write layer slot 0.
res.layer_index = Tensor.scalar(0, .u32);
return res;
}
};
pub const KvCache = struct {
@@ -334,8 +341,8 @@ pub const KvCache = struct {
pub fn initBuffer(kv_shape: zml.Shape, platform: zml.Platform) !zml.Bufferized(KvCache) {
return .{
.k = try zml.Buffer.constant(platform, kv_shape, 1),
.v = try zml.Buffer.constant(platform, kv_shape, 1),
.k = try zml.Buffer.uninitialized(platform, kv_shape, .{}),
.v = try zml.Buffer.uninitialized(platform, kv_shape, .{}),
.layer_index = try zml.Buffer.scalar(platform, 0, .u32),
};
}

View File

@@ -22,35 +22,42 @@ pub const std_options: std.Options = .{
.logFn = asynk.logFn(std.log.defaultLog),
};
const params = clap.parseParamsComptime(
\\--help print this help
\\--prompt <STRING> the prompt
\\--hf-model-path <STRING> path to the directory containing model weights, config and tokenizer
\\--seed <UINT> random seed (optional)
\\--seq-len <UINT> sequence length
\\--create-options <STRING> platform creation options JSON, defaults to {}
\\--no-llama3 <BOOL> skip prompt template
\\--sharding <BOOL> default: true: sharding on or off
);
pub fn tokenizePrompt(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, config: LlamaLM.Config, prompt: []const u8, skip_llama3_encoding: bool) ![]u32 {
var tokens = std.array_list.Managed(u32).init(allocator);
var encoder = try tokenizer.encoder();
defer encoder.deinit();
if (skip_llama3_encoding) {
// Copy to the arraylist so the ownership is the same in both branches.
try tokens.appendSlice(try encoder.encode(prompt));
return tokens.toOwnedSlice();
// Copy so the ownership is the same in both branches.
return try allocator.dupe(u32, try encoder.encode(prompt));
}
const start_header_id = tokenizer.tokenToId("<|start_header_id|>") orelse return error.NoSuchToken;
const end_header_id = tokenizer.tokenToId("<|end_header_id|>") orelse return error.NoSuchToken;
const eot_id = tokenizer.tokenToId("<|eot_id|>") orelse return error.NoSuchToken;
const newline_id = (try encoder.encode("\n"))[0];
const start_header = tokenizer.tokenToId("<|start_header_id|>") orelse return error.NoSuchToken;
const end_header = tokenizer.tokenToId("<|end_header_id|>") orelse return error.NoSuchToken;
const user = tokenizer.tokenToId("user") orelse return error.NoSuchToken;
const assistant = tokenizer.tokenToId("assistant") orelse return error.NoSuchToken;
const eot = tokenizer.tokenToId("<|eot_id|>") orelse return error.NoSuchToken;
const newline = (try encoder.encode("\n"))[0];
try tokens.append(config.bos_token_id);
var tokens: std.ArrayList(u32) = try .initCapacity(allocator, prompt.len);
try tokens.appendSlice(allocator, &.{ config.bos_token_id, start_header, user, end_header, newline });
try tokens.append(start_header_id);
try tokens.appendSlice(try encoder.encode("user"));
try tokens.appendSlice(&.{ end_header_id, newline_id });
try tokens.appendSlice(allocator, try encoder.encode(prompt));
try tokens.appendSlice(allocator, &.{ eot, newline });
try tokens.appendSlice(try encoder.encode(prompt));
try tokens.appendSlice(&.{ eot_id, newline_id });
try tokens.append(start_header_id);
try tokens.appendSlice(try encoder.encode("assistant"));
try tokens.appendSlice(&.{ end_header_id, newline_id });
try tokens.appendSlice(allocator, &.{ start_header, assistant, end_header, newline });
return tokens.toOwnedSlice();
return tokens.toOwnedSlice(allocator);
}
pub fn generateText(
@@ -64,7 +71,8 @@ pub fn generateText(
seed: u128,
prompt: []const u8,
skip_llama3_encoding: bool,
) ![]const u8 {
writer: *std.Io.Writer,
) !void {
const prompt_tok: []const u32 = try tokenizePrompt(allocator, tokenizer, config, prompt, skip_llama3_encoding);
defer allocator.free(prompt_tok);
@@ -72,7 +80,7 @@ pub fn generateText(
defer tokenizer_decoder.deinit();
const platform = mod_generate.platform();
const max_seq_len = llama_.model.shape().s;
const max_seq_len = llama_.model.max_seq_len;
// init RNG and buffers
var rng = try zml.Tensor.Rng.init(platform, seed);
@@ -100,10 +108,6 @@ pub fn generateText(
var current_token = try zml.Buffer.fromSlice(platform, .{1}, &generated_token_buffer);
defer current_token.deinit();
// Here we collect the generated text
var output = std.array_list.Managed(u8).init(allocator);
defer output.deinit();
const output_tokens_len = max_seq_len - prompt_tok.len - 1;
const start = std.time.microTimestamp();
@@ -114,9 +118,9 @@ pub fn generateText(
// collect and print generated sequence
num_tokens_generated += 1;
const generated_token = generated_token_buffer[0];
const chunk = try tokenizer_decoder.next(generated_token) orelse unreachable;
try output.appendSlice(chunk);
std.debug.print("{s}", .{chunk});
if (try tokenizer_decoder.next(generated_token)) |chunk| {
try writer.writeAll(chunk);
}
// check for eos
if (i == output_tokens_len) break :generation;
@@ -145,22 +149,6 @@ pub fn generateText(
const speed = @as(f64, @floatFromInt(num_tokens_generated)) / duration;
std.debug.print("\n", .{});
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ num_tokens_generated, duration, speed });
return output.toOwnedSlice();
}
const params = clap.parseParamsComptime(
\\--help print this help
\\--prompt <STRING> the prompt
\\--hf-model-path <STRING> path to the directory containing model weights, config and tokenizer
\\--seed <UINT> random seed (optional)
\\--seq-len <UINT> sequence length
\\--create-options <STRING> platform creation options JSON, defaults to {}
\\--no-llama3 <BOOL> skip prompt template
\\--sharding <BOOL> default: true: sharding on or off
);
/// Lenient boolean parser for CLI flags: any value starting with 't', 'T',
/// 'y', 'Y' or '1' is true; everything else — including the empty string — is false.
/// The error set is empty; the `!bool` return only satisfies clap's parser signature.
pub fn bool_parser(in: []const u8) error{}!bool {
    // Guard the empty string: indexing in[0] on a zero-length slice would panic.
    return in.len > 0 and std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
}
pub fn main() !void {
@@ -174,31 +162,33 @@ pub fn asyncMain() !void {
const parsers = comptime .{
.BOOL = bool_parser,
.UINT = clap.parsers.int(usize, 0),
.UINT = clap.parsers.int(u32, 0),
.STRING = clap.parsers.string,
.PATH = clap.parsers.string,
};
var diag: clap.Diagnostic = .{};
var stderr_buffer: [1024]u8 = undefined;
var stderr = std.fs.File.stderr().writer(&stderr_buffer);
var res = clap.parse(clap.Help, &params, parsers, .{
defer stderr.interface.flush() catch {};
var cli = clap.parse(clap.Help, &params, parsers, .{
.diagnostic = &diag,
.allocator = allocator,
}) catch |err| {
diag.report(&stderr.interface, err) catch {};
stderr.interface.print("usage: ", .{}) catch {};
stderr.interface.writeAll("usage: ") catch {};
clap.usage(&stderr.interface, clap.Help, &params) catch {};
stderr.interface.print("\n", .{}) catch {};
stderr.interface.writeAll("\n") catch {};
return;
};
defer res.deinit();
defer cli.deinit();
if (res.args.help != 0) {
if (cli.args.help != 0) {
clap.help(&stderr.interface, clap.Help, &params, .{}) catch {};
return;
}
const hf_model_path = res.args.@"hf-model-path" orelse {
const hf_model_path = cli.args.@"hf-model-path" orelse {
log.err("Missing --hf-model-path", .{});
return;
};
@@ -238,80 +228,87 @@ pub fn asyncMain() !void {
const compilation_options = zml.CompilationOptions{
.xla_dump_to = "/tmp/zml/llama",
.sharding_enabled = res.args.sharding orelse true,
.sharding_enabled = cli.args.sharding orelse true,
};
// initialize ZML platform with optional create options
// eg: --create-options='{"cuda":{"allocator":{"bfc":{"memory_fraction": 0.99}}}}'
const create_opts_json = res.args.@"create-options" orelse "{}";
const create_opts_json = cli.args.@"create-options" orelse "{}";
const create_opts = try std.json.parseFromSlice(zml.Platform.CreateOptions, allocator, create_opts_json, .{});
const platform = context.autoPlatform(create_opts.value).withCompilationOptions(compilation_options);
create_opts.deinit();
context.printAvailablePlatforms(platform);
var ts = try zml.aio.detectFormatAndOpen(allocator, model_weights_path);
defer ts.deinit();
var model_arena = std.heap.ArenaAllocator.init(allocator);
var model_instance = try zml.aio.populateModel(llama.LlamaLM, model_arena.allocator(), ts);
var store = try zml.aio.detectFormatAndOpen(allocator, model_weights_path);
defer store.deinit();
// Write metadata from the config file into the LlamaLm struct.
const seq_len: u32 = cli.args.@"seq-len" orelse 256;
const llama_options: llama.LlamaLM.Options = .{
.max_seq_len = @intCast(res.args.@"seq-len" orelse 256),
.max_seq_len = seq_len,
.sampling_strategy = .{
.topk = 1,
.temperature = 1.0,
},
};
model_instance.init(config, llama_options);
const dims = model_instance.model.shape();
const dtype = model_instance.model.embed_tokens.weight.dtype();
// Contains memory for llama_tensors and llama_buffers.
var compiler_arena = std.heap.ArenaAllocator.init(allocator);
defer compiler_arena.deinit();
const tokens_shape_prefill = zml.Shape.init(.{ .s = llama_options.max_seq_len }, .u32);
const tokens_shape = zml.Shape.init(.{ .s = 1 }, .u32);
// Initialize the Llama struct and map the content of the .safetensors to the model tensors.
const llama_tensors: llama.LlamaLM = try .init(compiler_arena.allocator(), config, llama_options, store);
// Specify shapes of input arguments
const prefill_tokens_shape = zml.Shape.init(.{ .s = llama_options.max_seq_len }, .u32);
const gen_tokens_shape = zml.Shape.init(.{ .s = 1 }, .u32);
const token_idx_shape = zml.Shape.init(.{}, .u32);
const kv_shape = zml.Shape.init(.{ .layer = model_instance.model.layers.len, .k = dims.s, .h = dims.nkvh, .hd = dims.hd }, dtype).withSharding(.{.h});
const dtype = llama_tensors.model.embed_tokens.weight.dtype();
const kv_shape = zml.Shape.init(.{
.layer = llama_tensors.model.layers.len,
.k = seq_len,
.h = config.num_key_value_heads,
.hd = config.head_dim orelse @divExact(config.hidden_size, config.num_attention_heads),
}, dtype).withSharding(.{.h});
const kv_cache_shape: zml.ShapeOf(llama.KvCache) = llama.KvCache.initShape(kv_shape);
const rng_shape = zml.Tensor.Rng.shape();
// Compile the model twice, one for prefill, one for generation.
var start = try std.time.Timer.start();
var fut_mod_prefill = try asynk.asyncc(zml.compile, .{
allocator, llama.LlamaLM.forward, .{ config, llama_options },
var fut_mod_prefill = try asynk.asyncc(zml.compileModel, .{
allocator, llama.LlamaLM.forward, llama_tensors,
.{
tokens_shape_prefill,
prefill_tokens_shape,
token_idx_shape,
kv_cache_shape,
rng_shape,
},
ts,
platform,
});
var fut_mod = try asynk.asyncc(zml.compile, .{
allocator, llama.LlamaLM.forward, .{ config, llama_options },
var fut_mod = try asynk.asyncc(zml.compileModel, .{
allocator, llama.LlamaLM.forward, llama_tensors,
.{
tokens_shape,
gen_tokens_shape,
token_idx_shape,
kv_cache_shape,
rng_shape,
},
ts,
platform,
});
// While we are still compiling load the weights to the device.
log.info("\tLoading Llama weights from {s}...", .{model_weights_path});
var llama_weights = try zml.aio.loadBuffers(llama.LlamaLM, .{ config, llama_options }, ts, model_arena.allocator(), platform);
defer zml.aio.unloadBuffers(&llama_weights);
var llama_buffers = try store.loadModelById(llama.LlamaLM, compiler_arena.allocator(), llama_tensors, platform);
defer zml.aio.unloadBuffers(&llama_buffers);
log.info("\tLoaded weights in {D}", .{start.read()});
var llama_module_prefill = (try fut_mod_prefill.awaitt()).prepare(llama_weights);
var llama_module_prefill = (try fut_mod_prefill.awaitt()).prepare(llama_buffers);
defer llama_module_prefill.deinit();
var llama_module = (try fut_mod.awaitt()).prepare(llama_weights);
var llama_module = (try fut_mod.awaitt()).prepare(llama_buffers);
defer llama_module.deinit();
log.info("\tCompiled model in {D}", .{start.read()});
log.info("Creating KvCache", .{});
const kv_cache = try llama.KvCache.initBuffer(kv_shape, platform);
@@ -320,16 +317,34 @@ pub fn asyncMain() !void {
var timer = try stdx.time.Timer.start();
defer log.info("Loaded tokenizer from {s} [{D}]", .{ model_tokenizer_path, timer.read() });
break :blk try zml.tokenizer.Tokenizer.fromFile(model_arena.allocator(), model_tokenizer_path);
break :blk try zml.tokenizer.Tokenizer.fromFile(allocator, model_tokenizer_path);
};
errdefer tokenizer.deinit();
const prompt = res.args.prompt orelse "What is the capital of France?";
const prompt = cli.args.prompt orelse "What is the capital of France?";
log.info("\tPrompt: {s}", .{prompt});
const seed = res.args.seed orelse @as(u128, @bitCast(std.time.nanoTimestamp()));
const skip_llama3_encoding = res.args.@"no-llama3" orelse false;
const generated_text = try generateText(config, model_instance, llama_module_prefill, llama_module, kv_cache, tokenizer, allocator, seed, prompt[0..], skip_llama3_encoding);
// generated text will be printed token by token.
defer allocator.free(generated_text);
// Unbuffered writing of the tokens to stdout.
var stdout = std.fs.File.stdout().writer(&.{});
const seed: u128 = cli.args.seed orelse @bitCast(std.time.nanoTimestamp());
const skip_llama3_encoding = cli.args.@"no-llama3" orelse false;
try generateText(
config,
llama_tensors,
llama_module_prefill,
llama_module,
kv_cache,
tokenizer,
allocator,
seed,
prompt[0..],
skip_llama3_encoding,
&stdout.interface,
);
}
/// Lenient boolean parser for CLI flags: any value starting with 't', 'T',
/// 'y', 'Y' or '1' is true; everything else — including the empty string — is false.
/// The error set is empty; the `!bool` return only satisfies clap's parser signature.
fn bool_parser(in: []const u8) error{}!bool {
    // Guard the empty string: indexing in[0] on a zero-length slice would panic.
    return in.len > 0 and std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
}

View File

@@ -45,19 +45,14 @@ pub fn asyncMain() !void {
// We manually produce a BufferStore. You would not normally do that.
// A BufferStore is usually created by loading model data from a file.
var buffers: zml.aio.BufferStore.Buffers = .{};
try buffers.put(arena, "weight", zml.HostBuffer.fromArray(&weights));
try buffers.put(arena, "bias", zml.HostBuffer.fromArray(&bias));
// the actual BufferStore
const buffer_store: zml.aio.BufferStore = .{
.arena = arena_state,
.buffers = buffers,
};
var store: zml.aio.BufferStore = .init(allocator);
defer store.deinit();
try store.buffers.put(store.arena.allocator(), "weight", zml.HostBuffer.fromArray(&weights));
try store.buffers.put(store.arena.allocator(), "bias", zml.HostBuffer.fromArray(&bias));
// A clone of our model, consisting of shapes. We only need shapes for compiling.
// We use the BufferStore to infer the shapes.
var model_shapes = try zml.aio.populateModel(Layer, allocator, buffer_store);
var model_shapes = try zml.aio.populateModel(Layer, allocator, store);
model_shapes.weight = model_shapes.weight.withSharding(.{-1});
model_shapes.bias = model_shapes.bias.?.withSharding(.{-1});
@@ -68,7 +63,7 @@ pub fn asyncMain() !void {
// Produce a bufferized weights struct from the fake BufferStore.
// This is like the inferred shapes, but with actual values.
// We will need to send those to the computation device later.
var model_weights = try zml.aio.loadModelBuffers(Layer, model_shapes, buffer_store, arena, platform);
var model_weights = try zml.aio.loadModelBuffers(Layer, model_shapes, store, arena, platform);
defer zml.aio.unloadBuffers(&model_weights); // for good practice
// Wait for compilation to finish