From a811b2e1e393496142850157101ff6f5914e5c20 Mon Sep 17 00:00:00 2001 From: Foke Singh Date: Wed, 20 Mar 2024 13:37:19 +0000 Subject: [PATCH] llama: fix dimensions and data types Removed unnecessary batching dimension introduced by recent changes. Converted index outputs from i32 to u32 for token indices. Ensures Llama runs on CUDA and RoCM. Tested on CUDA. --- examples/llama/llama.zig | 10 ++++------ examples/llama/main.zig | 16 ++++++++-------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/examples/llama/llama.zig b/examples/llama/llama.zig index 0f57d87..6f6594f 100644 --- a/examples/llama/llama.zig +++ b/examples/llama/llama.zig @@ -86,7 +86,6 @@ pub const LlamaLM = struct { rng: Tensor.Rng, ) struct { Tensor, KvCache, Tensor.Rng } { stdx.debug.assert(tokens_.dtype() == .u32 and tokens_.rank() >= 1 and token_index.dtype() == .u32 and token_index.rank() <= 1, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {} and {}", .{ tokens_, token_index }); - var tokens = tokens_.withPartialTags(.{.s}); const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, token_index, kv_cache }); tokens, const new_rng = self.sampleTokens(self.lm_head, tokens, out, rng, self.gen_opts); @@ -115,7 +114,7 @@ pub const LlamaLM = struct { logits = logits.rename(.{ .d = .voc }); const next_tokens, const new_rng = zml.nn.sampleTokens(logits, opts, rng); - return .{ next_tokens.reuseBuffer(tokens_), new_rng }; + return .{ next_tokens.convert(tokens_.dtype()).reuseBuffer(tokens_), new_rng }; } pub fn increment(_: u8, token_index: Tensor) Tensor { @@ -163,7 +162,6 @@ pub const Llama = struct { /// Returns result and updated KV cache. pub fn forward(self: Llama, tokens: Tensor, token_index: Tensor, kv_cache: KvCache) struct { Tensor, KvCache } { const embeds = embed(self.embed_tokens, tokens); - var hidden = embeds; var updated_kv_cache = kv_cache; @@ -280,12 +278,12 @@ pub const SelfAttn = struct { // Note: in Pytorch it would be very inefficient to generate the full attn_mask, // then slice into it, but XLA is able to optimize this correctly. - attn_mask = attn_mask.gatherSlices(zml.Shape.init(.{ .q = x.dim(.s) }, attn_mask.dtype()), token_index.reshape(.{ .b = token_index.shape().dim(0), .coord = 1 }), .{}); + attn_mask = attn_mask.gatherSlices(zml.Shape.init(.{ .q = x.dim(.s) }, attn_mask.dtype()), token_index.reshape(.{ .coord = 1 }), .{}); // In self-attention, .s axis is used both for keys and queries. const pos_index = b: { - const temp = Tensor.arange(.{ .end = x.dim(.s) }, token_index.dtype()).withTags(.{.s}).broad(zml.Shape.init(.{ .b = token_index.shape().dim(0), .s = x.dim(.s) }, token_index.dtype())); - break :b temp.add(token_index.withTags(.{.b}).broad(temp.shape())); + const temp = Tensor.arange(.{ .end = x.dim(.s) }, token_index.dtype()).withTags(.{.s}).broad(zml.Shape.init(.{ .s = x.dim(.s) }, token_index.dtype())); + break :b temp.add(token_index.broad(temp.shape())); }; q = zml.nn.rope(q, pos_index, self.rope_opts); diff --git a/examples/llama/main.zig b/examples/llama/main.zig index 6c2a9df..3cf2670 100644 --- a/examples/llama/main.zig +++ b/examples/llama/main.zig @@ -85,7 +85,8 @@ pub fn generateText( // prepare device buffers for the prefill tokens and the index var prefill_tokens = try zml.Buffer.fromSlice(platform, .{max_seq_len}, prefill_buffer); defer prefill_tokens.deinit(); - var prefill_token_index = try zml.Buffer.fromSlice(platform, .{}, &[_]u32{0}); + var prefill_token_index = try zml.Buffer.constant(platform, zml.Shape.init(.{}, .u32), 0); + defer prefill_token_index.deinit(); // init RNG and prefill @@ -97,7 +98,7 @@ pub fn generateText( // Prepare for token-by-token generation var first_token_hostbuffer = [_]u32{prompt_tok[prompt_tok.len - 1]}; // start with the prompt's last token - var current_token = try zml.Buffer.fromSlice(platform, .{}, &first_token_hostbuffer); + var current_token = try zml.Buffer.fromSlice(platform, .{1}, &first_token_hostbuffer); defer current_token.deinit(); // Here we will copy the generated token from device @@ -121,6 +122,7 @@ pub fn generateText( // current token index needs to go into a zml.Buffer const token_index_buffer = &[_]u32{@intCast(prompt_tok.len + i)}; const token_index = try zml.Buffer.fromSlice(platform, .{}, token_index_buffer); + defer token_index.deinit(); // call to generate the next token @@ -256,13 +258,11 @@ pub fn asyncMain() !void { const dims = model_instance.model.shape(); const dtype = model_instance.model.embed_tokens.weight.dtype(); - const batch_size = 1; + const tokens_shape_prefill = zml.Shape.init(.{ .s = llama_options.max_seq_len }, .u32); + const tokens_shape = zml.Shape.init(.{ .s = 1 }, .u32); + const token_idx_shape = zml.Shape.init(.{}, .u32); - const tokens_shape_prefill = zml.Shape.init(.{ .b = batch_size, .s = llama_options.max_seq_len }, .u32); - const tokens_shape = zml.Shape.init(.{ .b = batch_size, .s = 1 }, .u32); - const token_idx_shape = zml.Shape.init(.{ .b = batch_size }, .u32); - - const kv_shape = zml.Shape.init(.{ .layer = model_instance.model.layers.len, .b = batch_size, .k = dims.s, .h = dims.nkvh, .hd = dims.hd }, dtype).withSharding(.{.h}); + const kv_shape = zml.Shape.init(.{ .layer = model_instance.model.layers.len, .k = dims.s, .h = dims.nkvh, .hd = dims.hd }, dtype).withSharding(.{.h}); const kv_cache_shape: zml.ShapeOf(llama.KvCache) = llama.KvCache.initShape(kv_shape); const rng_shape = zml.Tensor.Rng.shape();