Introduce sharding attributes to Llama weights to enable Tensor Parallelism.

This commit is contained in:
Foke Singh 2023-04-13 12:35:27 +00:00
parent 833ff5f28d
commit fdb7da5c9b
2 changed files with 22 additions and 7 deletions

View File

@ -42,6 +42,20 @@ pub const LlamaLM = struct {
layer.self_attn.rope_opts = options.rope_opts;
layer.input_layernorm.eps = options.rms_norm_eps;
layer.post_attention_layernorm.eps = options.rms_norm_eps;
layer.mlp.up_proj.weight = layer.mlp.up_proj.weight.withSharding(.{0});
layer.mlp.gate_proj.weight = layer.mlp.gate_proj.weight.withSharding(.{0});
layer.mlp.down_proj.weight = layer.mlp.down_proj.weight.withSharding(.{1});
layer.self_attn.q_proj.weight = layer.self_attn.q_proj.weight.withSharding(.{0});
layer.self_attn.k_proj.weight = layer.self_attn.k_proj.weight.withSharding(.{0});
layer.self_attn.v_proj.weight = layer.self_attn.v_proj.weight.withSharding(.{0});
layer.self_attn.o_proj.weight = layer.self_attn.o_proj.weight.withSharding(.{1});
}
// TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled.
// It currently crashes or fails to compile, so lm_head is only sharded when topk == 1.
if (options.gen_opts.topk == 1) {
self.lm_head.weight = self.lm_head.weight.withSharding(.{0});
}
}
@ -274,9 +288,9 @@ pub const SelfAttn = struct {
) struct { Tensor, KvCache } {
// log.debug("x.shape: {}", .{x.shape()});
const num_kv_heads = if (self.num_kv_heads > 0) self.num_kv_heads else self.num_heads;
var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto });
var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto });
var v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto });
var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto }).withSharding(.{.h});
var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
var v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
// Generate the attention mask.
const kv_cache = kv_cache_ orelse initKvCache(k.shape());
const seq_len = kv_cache.k.dim(.k);
@ -299,7 +313,7 @@ pub const SelfAttn = struct {
const new_kv_cache = kv_cache.update(k, v, token_index orelse Tensor.scalar(0, .i32));
if (token_index) |_| {
std.debug.assert(q.dim(.q) == 1);
meta.assert(q.dim(.q) == 1, "Expected dimension .q to be 1, got {}", .{q.dim(.q)});
k = new_kv_cache.keys();
v = new_kv_cache.values();
}
@ -328,8 +342,8 @@ pub const KvCache = struct {
/// Builds a KvCache from `kv_shape`: `k` and `v` are constant tensors of that shape
/// filled with the dtype's `one()` value, and `layer_index` is the i32 scalar -1.
pub fn init(kv_shape: zml.Shape) KvCache {
// The KV-cache is initialized with ones to detect reads of uninitialized memory.
return .{
// NOTE(review): `.k`/`.v` appear twice because this is a diff view whose +/- markers
// were stripped. The first pair looks like the removed lines and the second pair
// (sharded along `.h`) like the added ones — confirm against the real file.
.k = Tensor.constant(kv_shape, kv_shape.dtype().one()),
.v = Tensor.constant(kv_shape, kv_shape.dtype().one()),
.k = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
.v = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
.layer_index = Tensor.scalar(-1, .i32),
};
}

View File

@ -145,6 +145,7 @@ pub fn asyncMain() !void {
const compilation_options = zml.CompilationOptions{
.cache_location = "/tmp/zml/llama/cache",
.xla_dump_to = "/tmp/zml/llama",
.sharding_enabled = true,
};
const platform = context.autoPlatform().withCompilationOptions(compilation_options);
@ -234,7 +235,7 @@ pub fn asyncMain() !void {
// To do so, we would just need to add `.b = batch_size` to `tokens_shape` and `kv_shape`.
const tokens_shape = zml.Shape.init(.{ .s = dims.s }, .i32);
const token_idx_shape = zml.Shape.init(.{}, .i32);
const kv_shape = zml.Shape.init(.{ .layer = llama.model.layers.len, .h = dims.nkvh, .k = dims.s, .hd = dims.hd }, dtype);
const kv_shape = zml.Shape.init(.{ .layer = llama.model.layers.len, .h = dims.nkvh, .k = dims.s, .hd = dims.hd }, dtype).withSharding(.{.h});
// needs to be optional
const kv_cache_shape: ?ShapeOf(KvCache) = KvCache.initShape(kv_shape);
const rng_shape = Tensor.Rng.shape();