diff --git a/examples/llama/llama.zig b/examples/llama/llama.zig index f01ab91..09a11c6 100644 --- a/examples/llama/llama.zig +++ b/examples/llama/llama.zig @@ -42,6 +42,20 @@ pub const LlamaLM = struct { layer.self_attn.rope_opts = options.rope_opts; layer.input_layernorm.eps = options.rms_norm_eps; layer.post_attention_layernorm.eps = options.rms_norm_eps; + layer.mlp.up_proj.weight = layer.mlp.up_proj.weight.withSharding(.{0}); + layer.mlp.gate_proj.weight = layer.mlp.gate_proj.weight.withSharding(.{0}); + layer.mlp.down_proj.weight = layer.mlp.down_proj.weight.withSharding(.{1}); + + layer.self_attn.q_proj.weight = layer.self_attn.q_proj.weight.withSharding(.{0}); + layer.self_attn.k_proj.weight = layer.self_attn.k_proj.weight.withSharding(.{0}); + layer.self_attn.v_proj.weight = layer.self_attn.v_proj.weight.withSharding(.{0}); + layer.self_attn.o_proj.weight = layer.self_attn.o_proj.weight.withSharding(.{1}); + } + + // TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled. 
+ // It currently crashes or fails to compile. + if (options.gen_opts.topk == 1) { + self.lm_head.weight = self.lm_head.weight.withSharding(.{0}); } } @@ -274,9 +288,9 @@ pub const SelfAttn = struct { ) struct { Tensor, KvCache } { // log.debug("x.shape: {}", .{x.shape()}); const num_kv_heads = if (self.num_kv_heads > 0) self.num_kv_heads else self.num_heads; - var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto }); - var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }); - var v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }); + var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto }).withSharding(.{.h}); + var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h}); + var v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h}); // Generate the attention mask. const kv_cache = kv_cache_ orelse initKvCache(k.shape()); const seq_len = kv_cache.k.dim(.k); @@ -299,7 +313,7 @@ pub const SelfAttn = struct { const new_kv_cache = kv_cache.update(k, v, token_index orelse Tensor.scalar(0, .i32)); if (token_index) |_| { - std.debug.assert(q.dim(.q) == 1); + meta.assert(q.dim(.q) == 1, "Expected dimension .q to be 1, got {}", .{q.dim(.q)}); k = new_kv_cache.keys(); v = new_kv_cache.values(); } @@ -328,8 +342,8 @@ pub const KvCache = struct { pub fn init(kv_shape: zml.Shape) KvCache { // The KV-cache is initialized with ones to detect reads of uninitialized memory. 
return .{ - .k = Tensor.constant(kv_shape, kv_shape.dtype().one()), - .v = Tensor.constant(kv_shape, kv_shape.dtype().one()), + .k = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}), + .v = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}), .layer_index = Tensor.scalar(-1, .i32), }; } diff --git a/examples/llama/main.zig b/examples/llama/main.zig index 43a3d8f..7be0776 100644 --- a/examples/llama/main.zig +++ b/examples/llama/main.zig @@ -145,6 +145,7 @@ pub fn asyncMain() !void { const compilation_options = zml.CompilationOptions{ .cache_location = "/tmp/zml/llama/cache", .xla_dump_to = "/tmp/zml/llama", + .sharding_enabled = true, }; const platform = context.autoPlatform().withCompilationOptions(compilation_options); @@ -234,7 +235,7 @@ pub fn asyncMain() !void { // To do so, we would just need to add `.b = batch_size` to `token_shape` and `kv_shape`. const tokens_shape = zml.Shape.init(.{ .s = dims.s }, .i32); const token_idx_shape = zml.Shape.init(.{}, .i32); - const kv_shape = zml.Shape.init(.{ .layer = llama.model.layers.len, .h = dims.nkvh, .k = dims.s, .hd = dims.hd }, dtype); + const kv_shape = zml.Shape.init(.{ .layer = llama.model.layers.len, .h = dims.nkvh, .k = dims.s, .hd = dims.hd }, dtype).withSharding(.{.h}); // needs to be optional const kv_cache_shape: ?ShapeOf(KvCache) = KvCache.initShape(kv_shape); const rng_shape = Tensor.Rng.shape();