diff --git a/examples/llama/llama.zig b/examples/llama/llama.zig index da487bb..5d0e87d 100644 --- a/examples/llama/llama.zig +++ b/examples/llama/llama.zig @@ -18,19 +18,22 @@ pub const LlamaLM = struct { int: u32, ints: []u32, }), - num_hidden_layers: usize, - num_attention_heads: usize, - num_key_value_heads: usize, + head_dim: ?u32, + hidden_size: u32, + num_hidden_layers: u32, + num_attention_heads: u32, + num_key_value_heads: u32, rope_theta: f32, - max_position_embeddings: usize, + max_position_embeddings: u32, rms_norm_eps: f32, hf_rope_impl: bool = true, + tie_word_embeddings: bool = false, rope_scaling: zml.nn.RopeOpts.Scaling = .{ .default = {} }, }; pub const Options = struct { sampling_strategy: ?zml.nn.SamplingStrategy, - max_seq_len: usize, + max_seq_len: u32, }; lm_head: ?zml.nn.Linear, @@ -40,39 +43,79 @@ pub const LlamaLM = struct { gen_opts: zml.nn.SamplingStrategy = .{}, config: Config, - pub fn init(self: *LlamaLM, config: Config, options: Options) void { - self.config = config; - self.gen_opts = options.sampling_strategy orelse .{}; - self.model.max_seq_len = @intCast(options.max_seq_len); - self.model.num_heads = @intCast(config.num_attention_heads); - self.model.num_kv_heads = @intCast(config.num_key_value_heads); - self.model.rope_opts = .{ + pub fn init(allocator: std.mem.Allocator, config: Config, options: Options, store: zml.aio.BufferStore) !LlamaLM { + const rope_opts: zml.nn.RopeOpts = .{ .layout = if (config.hf_rope_impl) .sequential else .interleaved, .freq_base = config.rope_theta, .scaling = config.rope_scaling, }; - self.model.norm.eps = config.rms_norm_eps; - for (self.model.layers) |*layer| { - layer.self_attn.num_heads = self.model.num_heads; - layer.self_attn.num_kv_heads = self.model.num_kv_heads; - layer.self_attn.rope_opts = self.model.rope_opts; - layer.input_layernorm.eps = config.rms_norm_eps; - layer.post_attention_layernorm.eps = config.rms_norm_eps; - layer.mlp.up_proj.weight = 
layer.mlp.up_proj.weight.withSharding(.{0}); - layer.mlp.gate_proj.weight = layer.mlp.gate_proj.weight.withSharding(.{0}); - layer.mlp.down_proj.weight = layer.mlp.down_proj.weight.withSharding(.{1}); - layer.self_attn.q_proj.weight = layer.self_attn.q_proj.weight.withSharding(.{0}); - layer.self_attn.k_proj.weight = layer.self_attn.k_proj.weight.withSharding(.{0}); - layer.self_attn.v_proj.weight = layer.self_attn.v_proj.weight.withSharding(.{0}); - layer.self_attn.o_proj.weight = layer.self_attn.o_proj.weight.withSharding(.{1}); + const layers = try allocator.alloc(TransformerLayer, config.num_hidden_layers); + var prefix = try zml.aio.PrefixBuilder.initCapacity(allocator, 1024); + try prefix.push(stdx.noalloc, "model.layers"); + for (0.., layers) |i, *layer| { + try prefix.pushDigit(stdx.noalloc, i); + defer prefix.pop(); + var self_attn = try zml.aio.populateModelWithPrefix(SelfAttn, allocator, store, prefix.concat("self_attn")); + self_attn.num_heads = config.num_attention_heads; + self_attn.num_kv_heads = config.num_key_value_heads; + self_attn.rope_opts = rope_opts; + self_attn.q_proj.weight = self_attn.q_proj.weight.withSharding(.{0}); + self_attn.k_proj.weight = self_attn.k_proj.weight.withSharding(.{0}); + self_attn.v_proj.weight = self_attn.v_proj.weight.withSharding(.{0}); + self_attn.o_proj.weight = self_attn.o_proj.weight.withSharding(.{1}); + + var input_layernorm = try zml.aio.populateModelWithPrefix(RmsNorm, allocator, store, prefix.concat("input_layernorm")); + input_layernorm.eps = config.rms_norm_eps; + + var post_attention_layernorm = try zml.aio.populateModelWithPrefix(RmsNorm, allocator, store, prefix.concat("post_attention_layernorm")); + post_attention_layernorm.eps = config.rms_norm_eps; + + var mlp = try zml.aio.populateModelWithPrefix(Mlp, allocator, store, prefix.concat("mlp")); + mlp.up_proj.weight = mlp.up_proj.weight.withSharding(.{0}); + mlp.gate_proj.weight = mlp.gate_proj.weight.withSharding(.{0}); + mlp.down_proj.weight = 
mlp.down_proj.weight.withSharding(.{1}); + + layer.* = .{ + .self_attn = self_attn, + .input_layernorm = input_layernorm, + .post_attention_layernorm = post_attention_layernorm, + .mlp = mlp, + }; } - // TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled. - // It currently crashes/compilation fails - if (self.gen_opts.topk == 1 and self.lm_head != null) { - self.lm_head.?.weight = self.lm_head.?.weight.withSharding(.{0}); + var lm_head: ?zml.nn.Linear = null; + if (!config.tie_word_embeddings) { + lm_head = .{ .weight = store.getTensor("lm_head.weight") }; + if (options.sampling_strategy) |gen_opts| { + if (gen_opts.topk == 1) + lm_head.?.weight = lm_head.?.weight.withSharding(.{0}); + } } + + return .{ + .config = config, + .gen_opts = options.sampling_strategy orelse .{}, + .model = .{ + // Weights + .layers = layers, + .embed_tokens = .{ .weight = store.getTensor("model.embed_tokens.weight") }, + .norm = .{ + .weight = store.getTensor("model.norm.weight"), + .eps = config.rms_norm_eps, + }, + // Push down some configs + .max_seq_len = options.max_seq_len, + .num_heads = config.num_attention_heads, + .num_kv_heads = config.num_key_value_heads, + .rope_opts = .{ + .layout = if (config.hf_rope_impl) .sequential else .interleaved, + .freq_base = config.rope_theta, + .scaling = config.rope_scaling, + }, + }, + .lm_head = lm_head, + }; } /// Predicts the token at `token_index` position. 
@@ -129,36 +172,13 @@ pub const Llama = struct { layers: []TransformerLayer, max_seq_len: u32 = 0, - num_heads: i64 = 32, - num_kv_heads: i64 = 32, + num_heads: u32 = 32, + num_kv_heads: u32 = 32, rope_opts: zml.nn.RopeOpts = .{ .layout = .interleaved, .freq_base = 10_000, }, - const Shape = struct { - s: u32, - layer: u16, - hd: u16, - nh: u16, - nkvh: u16, - dtype: zml.DataType, - }; - - pub fn shape(self: Llama) Shape { - const key_dim = self.layers[0].self_attn.k_proj.weight.dim(0); - const num_kv_heads = if (self.num_kv_heads > 0) self.num_kv_heads else self.num_heads; - - return .{ - .s = self.max_seq_len, - .layer = @intCast(self.layers.len), - .hd = @intCast(@divExact(key_dim, num_kv_heads)), - .nh = @intCast(self.num_heads), - .nkvh = @intCast(num_kv_heads), - .dtype = self.embed_tokens.weight.dtype(), - }; - } - /// Forward one token, using KV cache for previous tokens. /// Returns result and updated KV cache. pub fn forward(self: Llama, tokens: Tensor, token_index: Tensor, kv_cache: KvCache) struct { Tensor, KvCache } { @@ -177,14 +197,6 @@ pub const Llama = struct { pub fn embed(embed_tokens_: zml.nn.TokenEmbedding, tokens_: Tensor) Tensor { return zml.call(embed_tokens_, .forward, .{tokens_}).withPartialTags(.{.d}); } - - fn initKvCache(self: Llama, embed_shape: zml.Shape) KvCache { - const dims = self.shape(); - var kv_shape = embed_shape.insert(0, .{ .layer = dims.layer }).rename(.{ .s = .k }).splitAxes(.{ .d = .{ .h = dims.nkvh, .hd = dims.hd } }); - const perm = kv_shape.contiguousPerm(.{ .k, .h, .hd }); - kv_shape = kv_shape.transpose(perm.constSlice()); - return KvCache.init(kv_shape); - } }; pub const TransformerLayer = struct { @@ -245,6 +257,9 @@ pub const SelfAttn = struct { k_proj: zml.nn.Linear, v_proj: zml.nn.Linear, + q_norm: ?RmsNorm, + k_norm: ?RmsNorm, + o_proj: zml.nn.Linear, num_heads: i64 = undefined, num_kv_heads: i64 = 0, @@ -282,6 +297,8 @@ pub const SelfAttn = struct { break :b temp.add(token_index.broad(temp.shape())); }; + if 
(self.q_norm) |norm| q = norm.forward(q.rename(.{ .hd = .d })).rename(.{ .d = .hd }); + if (self.k_norm) |norm| k = norm.forward(k.rename(.{ .hd = .d })).rename(.{ .d = .hd }); q = zml.nn.rope(q, pos_index, self.rope_opts); k = zml.nn.rope(k, pos_index, self.rope_opts); q = q.rename(.{ .s = .q }); @@ -298,16 +315,6 @@ pub const SelfAttn = struct { const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s }); return .{ zml.call(self.o_proj, .forward, .{attn}), new_kv_cache }; } - - fn initKvCache(key_shape: zml.Shape) KvCache { - // When we call initKvCache, we haven't renamed .s to .k yet. - var kv_shape = key_shape.insert(0, .{ .layer = 1 }).rename(.{ .s = .k }); - const perm = kv_shape.contiguousPerm(.{ .h, .k, .hd }); - kv_shape = kv_shape.transpose(perm.constSlice()); - var res = KvCache.init(kv_shape); - res.layer_index = Tensor.scalar(0, .u32); - return res; - } }; pub const KvCache = struct { @@ -334,8 +341,8 @@ pub const KvCache = struct { pub fn initBuffer(kv_shape: zml.Shape, platform: zml.Platform) !zml.Bufferized(KvCache) { return .{ - .k = try zml.Buffer.constant(platform, kv_shape, 1), - .v = try zml.Buffer.constant(platform, kv_shape, 1), + .k = try zml.Buffer.uninitialized(platform, kv_shape, .{}), + .v = try zml.Buffer.uninitialized(platform, kv_shape, .{}), .layer_index = try zml.Buffer.scalar(platform, 0, .u32), }; } diff --git a/examples/llama/main.zig b/examples/llama/main.zig index af94fd2..c1c8d80 100644 --- a/examples/llama/main.zig +++ b/examples/llama/main.zig @@ -22,35 +22,42 @@ pub const std_options: std.Options = .{ .logFn = asynk.logFn(std.log.defaultLog), }; +const params = clap.parseParamsComptime( + \\--help print this help + \\--prompt the prompt + \\--hf-model-path path to the directory containing model weights, config and tokenizer + \\--seed random seed (optional) + \\--seq-len sequence length + \\--create-options platform creation options JSON, defaults to {} + \\--no-llama3 skip prompt template + \\--sharding 
default: true: sharding on or off +); + pub fn tokenizePrompt(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, config: LlamaLM.Config, prompt: []const u8, skip_llama3_encoding: bool) ![]u32 { - var tokens = std.array_list.Managed(u32).init(allocator); var encoder = try tokenizer.encoder(); defer encoder.deinit(); if (skip_llama3_encoding) { - // Copy to the arraylist so the ownership is the same in both branches. - try tokens.appendSlice(try encoder.encode(prompt)); - return tokens.toOwnedSlice(); + // Copy so the ownership is the same in both branches. + return try allocator.dupe(u32, try encoder.encode(prompt)); } - const start_header_id = tokenizer.tokenToId("<|start_header_id|>") orelse return error.NoSuchToken; - const end_header_id = tokenizer.tokenToId("<|end_header_id|>") orelse return error.NoSuchToken; - const eot_id = tokenizer.tokenToId("<|eot_id|>") orelse return error.NoSuchToken; - const newline_id = (try encoder.encode("\n"))[0]; + const start_header = tokenizer.tokenToId("<|start_header_id|>") orelse return error.NoSuchToken; + const end_header = tokenizer.tokenToId("<|end_header_id|>") orelse return error.NoSuchToken; + const user = tokenizer.tokenToId("user") orelse return error.NoSuchToken; + const assistant = tokenizer.tokenToId("assistant") orelse return error.NoSuchToken; + const eot = tokenizer.tokenToId("<|eot_id|>") orelse return error.NoSuchToken; + const newline = (try encoder.encode("\n"))[0]; - try tokens.append(config.bos_token_id); + var tokens: std.ArrayList(u32) = try .initCapacity(allocator, prompt.len); + try tokens.appendSlice(allocator, &.{ config.bos_token_id, start_header, user, end_header, newline }); - try tokens.append(start_header_id); - try tokens.appendSlice(try encoder.encode("user")); - try tokens.appendSlice(&.{ end_header_id, newline_id }); + try tokens.appendSlice(allocator, try encoder.encode(prompt)); + try tokens.appendSlice(allocator, &.{ eot, newline }); - try tokens.appendSlice(try 
encoder.encode(prompt)); - try tokens.appendSlice(&.{ eot_id, newline_id }); - try tokens.append(start_header_id); - try tokens.appendSlice(try encoder.encode("assistant")); - try tokens.appendSlice(&.{ end_header_id, newline_id }); + try tokens.appendSlice(allocator, &.{ start_header, assistant, end_header, newline }); - return tokens.toOwnedSlice(); + return tokens.toOwnedSlice(allocator); } pub fn generateText( @@ -64,7 +71,8 @@ pub fn generateText( seed: u128, prompt: []const u8, skip_llama3_encoding: bool, -) ![]const u8 { + writer: *std.Io.Writer, +) !void { const prompt_tok: []const u32 = try tokenizePrompt(allocator, tokenizer, config, prompt, skip_llama3_encoding); defer allocator.free(prompt_tok); @@ -72,7 +80,7 @@ pub fn generateText( defer tokenizer_decoder.deinit(); const platform = mod_generate.platform(); - const max_seq_len = llama_.model.shape().s; + const max_seq_len = llama_.model.max_seq_len; // init RNG and buffers var rng = try zml.Tensor.Rng.init(platform, seed); @@ -100,10 +108,6 @@ pub fn generateText( var current_token = try zml.Buffer.fromSlice(platform, .{1}, &generated_token_buffer); defer current_token.deinit(); - // Here we collect the generated text - var output = std.array_list.Managed(u8).init(allocator); - defer output.deinit(); - const output_tokens_len = max_seq_len - prompt_tok.len - 1; const start = std.time.microTimestamp(); @@ -114,9 +118,9 @@ pub fn generateText( // collect and print generated sequence num_tokens_generated += 1; const generated_token = generated_token_buffer[0]; - const chunk = try tokenizer_decoder.next(generated_token) orelse unreachable; - try output.appendSlice(chunk); - std.debug.print("{s}", .{chunk}); + if (try tokenizer_decoder.next(generated_token)) |chunk| { + try writer.writeAll(chunk); + } // check for eos if (i == output_tokens_len) break :generation; @@ -145,22 +149,6 @@ pub fn generateText( const speed = @as(f64, @floatFromInt(num_tokens_generated)) / duration; std.debug.print("\n", .{}); 
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ num_tokens_generated, duration, speed }); - return output.toOwnedSlice(); -} - -const params = clap.parseParamsComptime( - \\--help print this help - \\--prompt the prompt - \\--hf-model-path path to the directory containing model weights, config and tokenizer - \\--seed random seed (optional) - \\--seq-len sequence length - \\--create-options platform creation options JSON, defaults to {} - \\--no-llama3 skip prompt template - \\--sharding default: true: sharding on or off -); - -pub fn bool_parser(in: []const u8) error{}!bool { - return std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null; } pub fn main() !void { @@ -174,31 +162,33 @@ pub fn asyncMain() !void { const parsers = comptime .{ .BOOL = bool_parser, - .UINT = clap.parsers.int(usize, 0), + .UINT = clap.parsers.int(u32, 0), .STRING = clap.parsers.string, .PATH = clap.parsers.string, }; var diag: clap.Diagnostic = .{}; var stderr_buffer: [1024]u8 = undefined; var stderr = std.fs.File.stderr().writer(&stderr_buffer); - var res = clap.parse(clap.Help, ¶ms, parsers, .{ + defer stderr.interface.flush() catch {}; + + var cli = clap.parse(clap.Help, ¶ms, parsers, .{ .diagnostic = &diag, .allocator = allocator, }) catch |err| { diag.report(&stderr.interface, err) catch {}; - stderr.interface.print("usage: ", .{}) catch {}; + stderr.interface.writeAll("usage: ") catch {}; clap.usage(&stderr.interface, clap.Help, ¶ms) catch {}; - stderr.interface.print("\n", .{}) catch {}; + stderr.interface.writeAll("\n") catch {}; return; }; - defer res.deinit(); + defer cli.deinit(); - if (res.args.help != 0) { + if (cli.args.help != 0) { clap.help(&stderr.interface, clap.Help, ¶ms, .{}) catch {}; return; } - const hf_model_path = res.args.@"hf-model-path" orelse { + const hf_model_path = cli.args.@"hf-model-path" orelse { log.err("Missing --hf-model-path", .{}); return; }; @@ -238,80 +228,87 @@ pub fn asyncMain() !void { const compilation_options = 
zml.CompilationOptions{ .xla_dump_to = "/tmp/zml/llama", - .sharding_enabled = res.args.sharding orelse true, + .sharding_enabled = cli.args.sharding orelse true, }; // initialize ZML platform with optional create options // eg: --create-options='{"cuda":{"allocator":{"bfc":{"memory_fraction": 0.99}}}}' - const create_opts_json = res.args.@"create-options" orelse "{}"; + const create_opts_json = cli.args.@"create-options" orelse "{}"; const create_opts = try std.json.parseFromSlice(zml.Platform.CreateOptions, allocator, create_opts_json, .{}); const platform = context.autoPlatform(create_opts.value).withCompilationOptions(compilation_options); create_opts.deinit(); context.printAvailablePlatforms(platform); - var ts = try zml.aio.detectFormatAndOpen(allocator, model_weights_path); - defer ts.deinit(); - - var model_arena = std.heap.ArenaAllocator.init(allocator); - var model_instance = try zml.aio.populateModel(llama.LlamaLM, model_arena.allocator(), ts); + var store = try zml.aio.detectFormatAndOpen(allocator, model_weights_path); + defer store.deinit(); + // Write metadata from the config file into the LlamaLM struct. + const seq_len: u32 = cli.args.@"seq-len" orelse 256; const llama_options: llama.LlamaLM.Options = .{ - .max_seq_len = @intCast(res.args.@"seq-len" orelse 256), + .max_seq_len = seq_len, .sampling_strategy = .{ .topk = 1, .temperature = 1.0, }, }; - model_instance.init(config, llama_options); - const dims = model_instance.model.shape(); - const dtype = model_instance.model.embed_tokens.weight.dtype(); + // Contains memory for llama_tensors and llama_buffers. + var compiler_arena = std.heap.ArenaAllocator.init(allocator); + defer compiler_arena.deinit(); - const tokens_shape_prefill = zml.Shape.init(.{ .s = llama_options.max_seq_len }, .u32); - const tokens_shape = zml.Shape.init(.{ .s = 1 }, .u32); + // Initialize the Llama struct and map the content of the .safetensors to the model tensors. 
+ const llama_tensors: llama.LlamaLM = try .init(compiler_arena.allocator(), config, llama_options, store); + + // Specify shapes of input arguments + const prefill_tokens_shape = zml.Shape.init(.{ .s = llama_options.max_seq_len }, .u32); + const gen_tokens_shape = zml.Shape.init(.{ .s = 1 }, .u32); const token_idx_shape = zml.Shape.init(.{}, .u32); - const kv_shape = zml.Shape.init(.{ .layer = model_instance.model.layers.len, .k = dims.s, .h = dims.nkvh, .hd = dims.hd }, dtype).withSharding(.{.h}); - + const dtype = llama_tensors.model.embed_tokens.weight.dtype(); + const kv_shape = zml.Shape.init(.{ + .layer = llama_tensors.model.layers.len, + .k = seq_len, + .h = config.num_key_value_heads, + .hd = config.head_dim orelse @divExact(config.hidden_size, config.num_attention_heads), + }, dtype).withSharding(.{.h}); const kv_cache_shape: zml.ShapeOf(llama.KvCache) = llama.KvCache.initShape(kv_shape); const rng_shape = zml.Tensor.Rng.shape(); + // Compile the model twice, one for prefill, one for generation. var start = try std.time.Timer.start(); - var fut_mod_prefill = try asynk.asyncc(zml.compile, .{ - allocator, llama.LlamaLM.forward, .{ config, llama_options }, + var fut_mod_prefill = try asynk.asyncc(zml.compileModel, .{ + allocator, llama.LlamaLM.forward, llama_tensors, .{ - tokens_shape_prefill, + prefill_tokens_shape, token_idx_shape, kv_cache_shape, rng_shape, }, - ts, platform, }); - var fut_mod = try asynk.asyncc(zml.compile, .{ - allocator, llama.LlamaLM.forward, .{ config, llama_options }, + var fut_mod = try asynk.asyncc(zml.compileModel, .{ + allocator, llama.LlamaLM.forward, llama_tensors, .{ - tokens_shape, + gen_tokens_shape, token_idx_shape, kv_cache_shape, rng_shape, }, - ts, platform, }); + // While we are still compiling load the weights to the device. 
log.info("\tLoading Llama weights from {s}...", .{model_weights_path}); - var llama_weights = try zml.aio.loadBuffers(llama.LlamaLM, .{ config, llama_options }, ts, model_arena.allocator(), platform); - defer zml.aio.unloadBuffers(&llama_weights); + var llama_buffers = try store.loadModelById(llama.LlamaLM, compiler_arena.allocator(), llama_tensors, platform); + defer zml.aio.unloadBuffers(&llama_buffers); log.info("✅\tLoaded weights in {D}", .{start.read()}); - var llama_module_prefill = (try fut_mod_prefill.awaitt()).prepare(llama_weights); + var llama_module_prefill = (try fut_mod_prefill.awaitt()).prepare(llama_buffers); defer llama_module_prefill.deinit(); - var llama_module = (try fut_mod.awaitt()).prepare(llama_weights); + var llama_module = (try fut_mod.awaitt()).prepare(llama_buffers); defer llama_module.deinit(); log.info("✅\tCompiled model in {D}", .{start.read()}); - log.info("Creating KvCache", .{}); const kv_cache = try llama.KvCache.initBuffer(kv_shape, platform); @@ -320,16 +317,34 @@ pub fn asyncMain() !void { var timer = try stdx.time.Timer.start(); defer log.info("Loaded tokenizer from {s} [{D}]", .{ model_tokenizer_path, timer.read() }); - break :blk try zml.tokenizer.Tokenizer.fromFile(model_arena.allocator(), model_tokenizer_path); + break :blk try zml.tokenizer.Tokenizer.fromFile(allocator, model_tokenizer_path); }; errdefer tokenizer.deinit(); - const prompt = res.args.prompt orelse "What is the capital of France?"; + const prompt = cli.args.prompt orelse "What is the capital of France?"; log.info("✅\tPrompt: {s}", .{prompt}); - const seed = res.args.seed orelse @as(u128, @bitCast(std.time.nanoTimestamp())); - const skip_llama3_encoding = res.args.@"no-llama3" orelse false; - const generated_text = try generateText(config, model_instance, llama_module_prefill, llama_module, kv_cache, tokenizer, allocator, seed, prompt[0..], skip_llama3_encoding); - // generated text will be printed token by token. 
- defer allocator.free(generated_text); + // Unbuffered writing of the tokens to stdout. + var stdout = std.fs.File.stdout().writer(&.{}); + + const seed: u128 = cli.args.seed orelse @bitCast(std.time.nanoTimestamp()); + const skip_llama3_encoding = cli.args.@"no-llama3" orelse false; + + try generateText( + config, + llama_tensors, + llama_module_prefill, + llama_module, + kv_cache, + tokenizer, + allocator, + seed, + prompt[0..], + skip_llama3_encoding, + &stdout.interface, + ); +} + +fn bool_parser(in: []const u8) error{}!bool { + return std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null; } diff --git a/examples/simple_layer/main.zig b/examples/simple_layer/main.zig index a84ab05..e0b89a9 100644 --- a/examples/simple_layer/main.zig +++ b/examples/simple_layer/main.zig @@ -45,19 +45,14 @@ pub fn asyncMain() !void { // We manually produce a BufferStore. You would not normally do that. // A BufferStore is usually created by loading model data from a file. - var buffers: zml.aio.BufferStore.Buffers = .{}; - try buffers.put(arena, "weight", zml.HostBuffer.fromArray(&weights)); - try buffers.put(arena, "bias", zml.HostBuffer.fromArray(&bias)); - - // the actual BufferStore - const buffer_store: zml.aio.BufferStore = .{ - .arena = arena_state, - .buffers = buffers, - }; + var store: zml.aio.BufferStore = .init(allocator); + defer store.deinit(); + try store.buffers.put(store.arena.allocator(), "weight", zml.HostBuffer.fromArray(&weights)); + try store.buffers.put(store.arena.allocator(), "bias", zml.HostBuffer.fromArray(&bias)); // A clone of our model, consisting of shapes. We only need shapes for compiling. // We use the BufferStore to infer the shapes. 
- var model_shapes = try zml.aio.populateModel(Layer, allocator, buffer_store); + var model_shapes = try zml.aio.populateModel(Layer, allocator, store); model_shapes.weight = model_shapes.weight.withSharding(.{-1}); model_shapes.bias = model_shapes.bias.?.withSharding(.{-1}); @@ -68,7 +63,7 @@ pub fn asyncMain() !void { // Produce a bufferized weights struct from the fake BufferStore. // This is like the inferred shapes, but with actual values. // We will need to send those to the computation device later. - var model_weights = try zml.aio.loadModelBuffers(Layer, model_shapes, buffer_store, arena, platform); + var model_weights = try zml.aio.loadModelBuffers(Layer, model_shapes, store, arena, platform); defer zml.aio.unloadBuffers(&model_weights); // for good practice // Wait for compilation to finish