diff --git a/examples/llama/llama.zig b/examples/llama/llama.zig
index 0f57d87..6f6594f 100644
--- a/examples/llama/llama.zig
+++ b/examples/llama/llama.zig
@@ -115,7 +115,7 @@ pub const LlamaLM = struct {
         logits = logits.rename(.{ .d = .voc });
 
         const next_tokens, const new_rng = zml.nn.sampleTokens(logits, opts, rng);
-        return .{ next_tokens.reuseBuffer(tokens_), new_rng };
+        return .{ next_tokens.convert(tokens_.dtype()).reuseBuffer(tokens_), new_rng };
     }
 
     pub fn increment(_: u8, token_index: Tensor) Tensor {
@@ -280,12 +280,12 @@ pub const SelfAttn = struct {
 
         // Note: in Pytorch it would be very inefficient to generate the full attn_mask,
         // then slice into it, but XLA is able to optimize this correctly.
-        attn_mask = attn_mask.gatherSlices(zml.Shape.init(.{ .q = x.dim(.s) }, attn_mask.dtype()), token_index.reshape(.{ .b = token_index.shape().dim(0), .coord = 1 }), .{});
+        attn_mask = attn_mask.gatherSlices(zml.Shape.init(.{ .q = x.dim(.s) }, attn_mask.dtype()), token_index.reshape(.{ .coord = 1 }), .{});
 
         // In self-attention, .s axis is used both for keys and queries.
         const pos_index = b: {
-            const temp = Tensor.arange(.{ .end = x.dim(.s) }, token_index.dtype()).withTags(.{.s}).broad(zml.Shape.init(.{ .b = token_index.shape().dim(0), .s = x.dim(.s) }, token_index.dtype()));
-            break :b temp.add(token_index.withTags(.{.b}).broad(temp.shape()));
+            const temp = Tensor.arange(.{ .end = x.dim(.s) }, token_index.dtype()).withTags(.{.s}).broad(zml.Shape.init(.{ .s = x.dim(.s) }, token_index.dtype()));
+            break :b temp.add(token_index.broad(temp.shape()));
         };
 
         q = zml.nn.rope(q, pos_index, self.rope_opts);
diff --git a/examples/llama/main.zig b/examples/llama/main.zig
index 6c2a9df..3cf2670 100644
--- a/examples/llama/main.zig
+++ b/examples/llama/main.zig
@@ -85,7 +85,8 @@ pub fn generateText(
     // prepare device buffers for the prefill tokens and the index
     var prefill_tokens = try zml.Buffer.fromSlice(platform, .{max_seq_len}, prefill_buffer);
     defer prefill_tokens.deinit();
-    var prefill_token_index = try zml.Buffer.fromSlice(platform, .{}, &[_]u32{0});
+    var prefill_token_index = try zml.Buffer.constant(platform, zml.Shape.init(.{}, .u32), 0);
+    defer prefill_token_index.deinit();
 
     // init RNG and prefill
 
@@ -97,7 +98,7 @@ pub fn generateText(
 
     // Prepare for token-by-token generation
     var first_token_hostbuffer = [_]u32{prompt_tok[prompt_tok.len - 1]}; // start with the prompt's last token
-    var current_token = try zml.Buffer.fromSlice(platform, .{}, &first_token_hostbuffer);
+    var current_token = try zml.Buffer.fromSlice(platform, .{1}, &first_token_hostbuffer);
     defer current_token.deinit();
 
     // Here we will copy the generated token from device
@@ -121,6 +122,7 @@ pub fn generateText(
         // current token index needs to go into a zml.Buffer
         const token_index_buffer = &[_]u32{@intCast(prompt_tok.len + i)};
         const token_index = try zml.Buffer.fromSlice(platform, .{}, token_index_buffer);
+        defer token_index.deinit();
 
         // call to generate the next token
 
@@ -256,13 +258,11 @@ pub fn asyncMain() !void {
    const dims = model_instance.model.shape();
    const dtype = model_instance.model.embed_tokens.weight.dtype();
 
-    const batch_size = 1;
+    const tokens_shape_prefill = zml.Shape.init(.{ .s = llama_options.max_seq_len }, .u32);
+    const tokens_shape = zml.Shape.init(.{ .s = 1 }, .u32);
+    const token_idx_shape = zml.Shape.init(.{}, .u32);
 
-    const tokens_shape_prefill = zml.Shape.init(.{ .b = batch_size, .s = llama_options.max_seq_len }, .u32);
-    const tokens_shape = zml.Shape.init(.{ .b = batch_size, .s = 1 }, .u32);
-    const token_idx_shape = zml.Shape.init(.{ .b = batch_size }, .u32);
-
-    const kv_shape = zml.Shape.init(.{ .layer = model_instance.model.layers.len, .b = batch_size, .k = dims.s, .h = dims.nkvh, .hd = dims.hd }, dtype).withSharding(.{.h});
+    const kv_shape = zml.Shape.init(.{ .layer = model_instance.model.layers.len, .k = dims.s, .h = dims.nkvh, .hd = dims.hd }, dtype).withSharding(.{.h});
 
    const kv_cache_shape: zml.ShapeOf(llama.KvCache) = llama.KvCache.initShape(kv_shape);
    const rng_shape = zml.Tensor.Rng.shape();