Introduce sharding attributes to Llama weights to enable Tensor Parallelism.
This commit is contained in:
parent
833ff5f28d
commit
fdb7da5c9b
@ -42,6 +42,20 @@ pub const LlamaLM = struct {
|
|||||||
layer.self_attn.rope_opts = options.rope_opts;
|
layer.self_attn.rope_opts = options.rope_opts;
|
||||||
layer.input_layernorm.eps = options.rms_norm_eps;
|
layer.input_layernorm.eps = options.rms_norm_eps;
|
||||||
layer.post_attention_layernorm.eps = options.rms_norm_eps;
|
layer.post_attention_layernorm.eps = options.rms_norm_eps;
|
||||||
|
layer.mlp.up_proj.weight = layer.mlp.up_proj.weight.withSharding(.{0});
|
||||||
|
layer.mlp.gate_proj.weight = layer.mlp.gate_proj.weight.withSharding(.{0});
|
||||||
|
layer.mlp.down_proj.weight = layer.mlp.down_proj.weight.withSharding(.{1});
|
||||||
|
|
||||||
|
layer.self_attn.q_proj.weight = layer.self_attn.q_proj.weight.withSharding(.{0});
|
||||||
|
layer.self_attn.k_proj.weight = layer.self_attn.k_proj.weight.withSharding(.{0});
|
||||||
|
layer.self_attn.v_proj.weight = layer.self_attn.v_proj.weight.withSharding(.{0});
|
||||||
|
layer.self_attn.o_proj.weight = layer.self_attn.o_proj.weight.withSharding(.{1});
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled.
|
||||||
|
// Sharding lm_head currently crashes or fails to compile in that case.
|
||||||
|
if (options.gen_opts.topk == 1) {
|
||||||
|
self.lm_head.weight = self.lm_head.weight.withSharding(.{0});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -274,9 +288,9 @@ pub const SelfAttn = struct {
|
|||||||
) struct { Tensor, KvCache } {
|
) struct { Tensor, KvCache } {
|
||||||
// log.debug("x.shape: {}", .{x.shape()});
|
// log.debug("x.shape: {}", .{x.shape()});
|
||||||
const num_kv_heads = if (self.num_kv_heads > 0) self.num_kv_heads else self.num_heads;
|
const num_kv_heads = if (self.num_kv_heads > 0) self.num_kv_heads else self.num_heads;
|
||||||
var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto });
|
var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = self.num_heads, .hd = .auto }).withSharding(.{.h});
|
||||||
var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto });
|
var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
|
||||||
var v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto });
|
var v = zml.call(self.v_proj, .forward, .{x}).splitAxis(-1, .{ .h = num_kv_heads, .hd = .auto }).withSharding(.{.h});
|
||||||
// Generate the attention mask.
|
// Generate the attention mask.
|
||||||
const kv_cache = kv_cache_ orelse initKvCache(k.shape());
|
const kv_cache = kv_cache_ orelse initKvCache(k.shape());
|
||||||
const seq_len = kv_cache.k.dim(.k);
|
const seq_len = kv_cache.k.dim(.k);
|
||||||
@ -299,7 +313,7 @@ pub const SelfAttn = struct {
|
|||||||
|
|
||||||
const new_kv_cache = kv_cache.update(k, v, token_index orelse Tensor.scalar(0, .i32));
|
const new_kv_cache = kv_cache.update(k, v, token_index orelse Tensor.scalar(0, .i32));
|
||||||
if (token_index) |_| {
|
if (token_index) |_| {
|
||||||
std.debug.assert(q.dim(.q) == 1);
|
meta.assert(q.dim(.q) == 1, "Expected dimension .q to be 1, got {}", .{q.dim(.q)});
|
||||||
k = new_kv_cache.keys();
|
k = new_kv_cache.keys();
|
||||||
v = new_kv_cache.values();
|
v = new_kv_cache.values();
|
||||||
}
|
}
|
||||||
@ -328,8 +342,8 @@ pub const KvCache = struct {
|
|||||||
pub fn init(kv_shape: zml.Shape) KvCache {
|
pub fn init(kv_shape: zml.Shape) KvCache {
|
||||||
// The KV-cache is initialized with ones to detect reads of uninitialized memory.
|
// The KV-cache is initialized with ones to detect reads of uninitialized memory.
|
||||||
return .{
|
return .{
|
||||||
.k = Tensor.constant(kv_shape, kv_shape.dtype().one()),
|
.k = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
|
||||||
.v = Tensor.constant(kv_shape, kv_shape.dtype().one()),
|
.v = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}),
|
||||||
.layer_index = Tensor.scalar(-1, .i32),
|
.layer_index = Tensor.scalar(-1, .i32),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@ -145,6 +145,7 @@ pub fn asyncMain() !void {
|
|||||||
const compilation_options = zml.CompilationOptions{
|
const compilation_options = zml.CompilationOptions{
|
||||||
.cache_location = "/tmp/zml/llama/cache",
|
.cache_location = "/tmp/zml/llama/cache",
|
||||||
.xla_dump_to = "/tmp/zml/llama",
|
.xla_dump_to = "/tmp/zml/llama",
|
||||||
|
.sharding_enabled = true,
|
||||||
};
|
};
|
||||||
|
|
||||||
const platform = context.autoPlatform().withCompilationOptions(compilation_options);
|
const platform = context.autoPlatform().withCompilationOptions(compilation_options);
|
||||||
@ -234,7 +235,7 @@ pub fn asyncMain() !void {
|
|||||||
// To do so, we would just need to add `.b = batch_size` to `token_shape` and `kv_shape`.
|
// To do so, we would just need to add `.b = batch_size` to `token_shape` and `kv_shape`.
|
||||||
const tokens_shape = zml.Shape.init(.{ .s = dims.s }, .i32);
|
const tokens_shape = zml.Shape.init(.{ .s = dims.s }, .i32);
|
||||||
const token_idx_shape = zml.Shape.init(.{}, .i32);
|
const token_idx_shape = zml.Shape.init(.{}, .i32);
|
||||||
const kv_shape = zml.Shape.init(.{ .layer = llama.model.layers.len, .h = dims.nkvh, .k = dims.s, .hd = dims.hd }, dtype);
|
const kv_shape = zml.Shape.init(.{ .layer = llama.model.layers.len, .h = dims.nkvh, .k = dims.s, .hd = dims.hd }, dtype).withSharding(.{.h});
|
||||||
// needs to be optional
|
// needs to be optional
|
||||||
const kv_cache_shape: ?ShapeOf(KvCache) = KvCache.initShape(kv_shape);
|
const kv_cache_shape: ?ShapeOf(KvCache) = KvCache.initShape(kv_shape);
|
||||||
const rng_shape = Tensor.Rng.shape();
|
const rng_shape = Tensor.Rng.shape();
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user