Introduce sharding attributes to Llama weights to enable Tensor Parallelism.

This commit is contained in:
Foke Singh 2023-04-13 12:35:27 +00:00
parent 833ff5f28d
commit fdb7da5c9b
2 changed files with 22 additions and 7 deletions

View File

@ -42,6 +42,20 @@ pub const LlamaLM = struct {
layer.self_attn.rope_opts = options.rope_opts; layer.self_attn.rope_opts = options.rope_opts;
layer.input_layernorm.eps = options.rms_norm_eps; layer.input_layernorm.eps = options.rms_norm_eps;
layer.post_attention_layernorm.eps = options.rms_norm_eps; layer.post_attention_layernorm.eps = options.rms_norm_eps;
layer.mlp.up_proj.weight = layer.mlp.up_proj.weight.withSharding(.{0});
layer.mlp.gate_proj.weight = layer.mlp.gate_proj.weight.withSharding(.{0});
layer.mlp.down_proj.weight = layer.mlp.down_proj.weight.withSharding(.{1});
layer.self_attn.q_proj.weight = layer.self_attn.q_proj.weight.withSharding(.{0});
layer.self_attn.k_proj.weight = layer.self_attn.k_proj.weight.withSharding(.{0});
layer.self_attn.v_proj.weight = layer.self_attn.v_proj.weight.withSharding(.{0});
layer.self_attn.o_proj.weight = layer.self_attn.o_proj.weight.withSharding(.{1});
}
// TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled.
// It currently crashes or fails to compile.
if (options.gen_opts.topk == 1) {
self.lm_head.weight = self.lm_head.weight.withSharding(.{0});
} }
} }
@ -274,9 +288,9 @@ pub const SelfAttn = struct {
) struct { Tensor, KvCache } { ) struct { Tensor, KvCache } {
// log.debug("x.shape: {}", .{x.shape()}); // log.debug("x.shape: {}", .{x.shape()});
const num_kv_heads = if (self.num_kv_heads > 0) self.num_kv_heads else self.num_heads; const num_kv_heads = if (self.num_kv_heads > 0) self.num_kv_heads else self.num_heads;
var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto }); var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto }).withSharding(.{.h});
var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }); var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
var v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }); var v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
// Generate the attention mask. // Generate the attention mask.
const kv_cache = kv_cache_ orelse initKvCache(k.shape()); const kv_cache = kv_cache_ orelse initKvCache(k.shape());
const seq_len = kv_cache.k.dim(.k); const seq_len = kv_cache.k.dim(.k);
@ -299,7 +313,7 @@ pub const SelfAttn = struct {
const new_kv_cache = kv_cache.update(k, v, token_index orelse Tensor.scalar(0, .i32)); const new_kv_cache = kv_cache.update(k, v, token_index orelse Tensor.scalar(0, .i32));
if (token_index) |_| { if (token_index) |_| {
std.debug.assert(q.dim(.q) == 1); meta.assert(q.dim(.q) == 1, "Expected dimension .q to be 1, got {}", .{q.dim(.q)});
k = new_kv_cache.keys(); k = new_kv_cache.keys();
v = new_kv_cache.values(); v = new_kv_cache.values();
} }
@ -328,8 +342,8 @@ pub const KvCache = struct {
pub fn init(kv_shape: zml.Shape) KvCache { pub fn init(kv_shape: zml.Shape) KvCache {
// The KV-cache is initialized with ones to detect reads of uninitialized memory. // The KV-cache is initialized with ones to detect reads of uninitialized memory.
return .{ return .{
.k = Tensor.constant(kv_shape, kv_shape.dtype().one()), .k = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
.v = Tensor.constant(kv_shape, kv_shape.dtype().one()), .v = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
.layer_index = Tensor.scalar(-1, .i32), .layer_index = Tensor.scalar(-1, .i32),
}; };
} }

View File

@ -145,6 +145,7 @@ pub fn asyncMain() !void {
const compilation_options = zml.CompilationOptions{ const compilation_options = zml.CompilationOptions{
.cache_location = "/tmp/zml/llama/cache", .cache_location = "/tmp/zml/llama/cache",
.xla_dump_to = "/tmp/zml/llama", .xla_dump_to = "/tmp/zml/llama",
.sharding_enabled = true,
}; };
const platform = context.autoPlatform().withCompilationOptions(compilation_options); const platform = context.autoPlatform().withCompilationOptions(compilation_options);
@ -234,7 +235,7 @@ pub fn asyncMain() !void {
// To do so, we would just need to add `.b = batch_size` to `token_shape` and `kv_shape`. // To do so, we would just need to add `.b = batch_size` to `token_shape` and `kv_shape`.
const tokens_shape = zml.Shape.init(.{ .s = dims.s }, .i32); const tokens_shape = zml.Shape.init(.{ .s = dims.s }, .i32);
const token_idx_shape = zml.Shape.init(.{}, .i32); const token_idx_shape = zml.Shape.init(.{}, .i32);
const kv_shape = zml.Shape.init(.{ .layer = llama.model.layers.len, .h = dims.nkvh, .k = dims.s, .hd = dims.hd }, dtype); const kv_shape = zml.Shape.init(.{ .layer = llama.model.layers.len, .h = dims.nkvh, .k = dims.s, .hd = dims.hd }, dtype).withSharding(.{.h});
// needs to be optional // needs to be optional
const kv_cache_shape: ?ShapeOf(KvCache) = KvCache.initShape(kv_shape); const kv_cache_shape: ?ShapeOf(KvCache) = KvCache.initShape(kv_shape);
const rng_shape = Tensor.Rng.shape(); const rng_shape = Tensor.Rng.shape();