zml: Add support for Llama 3.2 text-only models. These checkpoints ship without an lm_head, so make lm_head optional for compatibility and, when it is missing, compute the logits from the (transposed) embed_tokens weight instead. Add the model repositories and executable targets to Bazel and update the README.
This commit is contained in:
parent
1c9749c25e
commit
237a877a29
@ -86,6 +86,56 @@ http_file(
|
|||||||
url = "https://github.com/karpathy/llama2.c/raw/c02865df300f3bd9e567ce061000dc23bf785a17/tokenizer.bin",
|
url = "https://github.com/karpathy/llama2.c/raw/c02865df300f3bd9e567ce061000dc23bf785a17/tokenizer.bin",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Llama 3.2
|
||||||
|
huggingface.model(
|
||||||
|
name = "Meta-Llama-3.2-1B-Instruct",
|
||||||
|
build_file_content = """\
|
||||||
|
package(default_visibility = ["//visibility:public"])
|
||||||
|
filegroup(
|
||||||
|
name = "model",
|
||||||
|
srcs = ["model.safetensors"],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "tokenizer",
|
||||||
|
srcs = ["tokenizer.json"],
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
commit = "9213176726f574b556790deb65791e0c5aa438b6",
|
||||||
|
includes = [
|
||||||
|
"model.safetensors",
|
||||||
|
"tokenizer.json",
|
||||||
|
],
|
||||||
|
model = "meta-llama/Llama-3.2-1B-Instruct",
|
||||||
|
)
|
||||||
|
use_repo(huggingface, "Meta-Llama-3.2-1B-Instruct")
|
||||||
|
|
||||||
|
huggingface.model(
|
||||||
|
name = "Meta-Llama-3.2-3B-Instruct",
|
||||||
|
build_file_content = """\
|
||||||
|
package(default_visibility = ["//visibility:public"])
|
||||||
|
filegroup(
|
||||||
|
name = "model",
|
||||||
|
srcs = glob(["*.safetensors"]) + ["model.safetensors.index.json"],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "tokenizer",
|
||||||
|
srcs = ["tokenizer.json"],
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
commit = "0cb88a4f764b7a12671c53f0838cd831a0843b95",
|
||||||
|
includes = [
|
||||||
|
"*.safetensors",
|
||||||
|
"model.safetensors.index.json",
|
||||||
|
"tokenizer.json",
|
||||||
|
],
|
||||||
|
model = "meta-llama/Llama-3.2-3B-Instruct",
|
||||||
|
)
|
||||||
|
use_repo(huggingface, "Meta-Llama-3.2-3B-Instruct")
|
||||||
|
|
||||||
|
|
||||||
|
# Llama 3.1
|
||||||
huggingface.model(
|
huggingface.model(
|
||||||
name = "Meta-Llama-3.1-8B-Instruct",
|
name = "Meta-Llama-3.1-8B-Instruct",
|
||||||
build_file_content = """\
|
build_file_content = """\
|
||||||
@ -155,6 +205,8 @@ filegroup(
|
|||||||
model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
)
|
)
|
||||||
use_repo(huggingface, "TinyLlama-1.1B-Chat-v1.0")
|
use_repo(huggingface, "TinyLlama-1.1B-Chat-v1.0")
|
||||||
|
|
||||||
|
# OpenLLaMA
|
||||||
huggingface.model(
|
huggingface.model(
|
||||||
name = "OpenLM-Research-OpenLLaMA-3B",
|
name = "OpenLM-Research-OpenLLaMA-3B",
|
||||||
build_file_content = """\
|
build_file_content = """\
|
||||||
|
|||||||
@ -51,6 +51,39 @@ cc_binary(
|
|||||||
deps = [":llama_lib"],
|
deps = [":llama_lib"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "Llama-3.2-1B-Instruct",
|
||||||
|
args = [
|
||||||
|
"--model=$(location @Meta-Llama-3.2-1B-Instruct//:model.safetensors)",
|
||||||
|
"--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer)",
|
||||||
|
"--num-heads=32",
|
||||||
|
"--num-kv-heads=8",
|
||||||
|
"--rope-freq-base=500000",
|
||||||
|
],
|
||||||
|
data = [
|
||||||
|
"@Meta-Llama-3.2-1B-Instruct//:model.safetensors",
|
||||||
|
"@Meta-Llama-3.2-1B-Instruct//:tokenizer",
|
||||||
|
],
|
||||||
|
deps = [":llama_lib"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "Llama-3.2-3B-Instruct",
|
||||||
|
args = [
|
||||||
|
"--model=$(location @Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json)",
|
||||||
|
"--tokenizer=$(location @Meta-Llama-3.2-3B-Instruct//:tokenizer)",
|
||||||
|
"--num-heads=24",
|
||||||
|
"--num-kv-heads=8",
|
||||||
|
"--rope-freq-base=500000",
|
||||||
|
],
|
||||||
|
data = [
|
||||||
|
"@Meta-Llama-3.2-3B-Instruct//:model",
|
||||||
|
"@Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json",
|
||||||
|
"@Meta-Llama-3.2-3B-Instruct//:tokenizer",
|
||||||
|
],
|
||||||
|
deps = [":llama_lib"],
|
||||||
|
)
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "OpenLLaMA-3B",
|
name = "OpenLLaMA-3B",
|
||||||
args = [
|
args = [
|
||||||
|
|||||||
@ -24,7 +24,7 @@ pub const LlamaOptions = struct {
|
|||||||
/// Llama architecture, using huggingface transformers naming.
|
/// Llama architecture, using huggingface transformers naming.
|
||||||
/// Dimensions of activations: {.b, .s, .d}
|
/// Dimensions of activations: {.b, .s, .d}
|
||||||
pub const LlamaLM = struct {
|
pub const LlamaLM = struct {
|
||||||
lm_head: zml.nn.Linear,
|
lm_head: ?zml.nn.Linear = null,
|
||||||
model: Llama,
|
model: Llama,
|
||||||
|
|
||||||
// Options controlling generation
|
// Options controlling generation
|
||||||
@ -55,7 +55,9 @@ pub const LlamaLM = struct {
|
|||||||
// TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled.
|
// TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled.
|
||||||
// It currently crashes/compilation fails
|
// It currently crashes/compilation fails
|
||||||
if (options.gen_opts.topk == 1) {
|
if (options.gen_opts.topk == 1) {
|
||||||
self.lm_head.weight = self.lm_head.weight.withSharding(.{0});
|
if (self.lm_head) |lm_head| {
|
||||||
|
self.lm_head.?.weight = lm_head.weight.withSharding(.{0});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -76,12 +78,12 @@ pub const LlamaLM = struct {
|
|||||||
|
|
||||||
var tokens = tokens_.withPartialTags(.{.s});
|
var tokens = tokens_.withPartialTags(.{.s});
|
||||||
const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, if (kv_cache == null) null else token_index, kv_cache });
|
const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, if (kv_cache == null) null else token_index, kv_cache });
|
||||||
tokens, const new_rng = updateTokens(self.lm_head, tokens, token_index, out, rng, self.gen_opts);
|
tokens, const new_rng = self.updateTokens(tokens, token_index, out, rng, self.gen_opts);
|
||||||
return .{ tokens, increment(0, token_index), updated_kv_cache, new_rng };
|
return .{ tokens, increment(0, token_index), updated_kv_cache, new_rng };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn updateTokens(
|
pub fn updateTokens(
|
||||||
lm_head: zml.nn.Linear,
|
self: LlamaLM,
|
||||||
tokens_: Tensor,
|
tokens_: Tensor,
|
||||||
token_index: Tensor,
|
token_index: Tensor,
|
||||||
out_: Tensor,
|
out_: Tensor,
|
||||||
@ -92,7 +94,11 @@ pub const LlamaLM = struct {
|
|||||||
const out = out_.withPartialTags(.{ .s, .d });
|
const out = out_.withPartialTags(.{ .s, .d });
|
||||||
|
|
||||||
const next_token_pred = out.gatherValues(.s, token_index, .{});
|
const next_token_pred = out.gatherValues(.s, token_index, .{});
|
||||||
var logits = zml.call(lm_head, .forward, .{next_token_pred});
|
var logits = if (self.lm_head) |lm_head|
|
||||||
|
zml.call(lm_head, .forward, .{next_token_pred})
|
||||||
|
else
|
||||||
|
self.model.embed_tokens.weight.withTags(.{ .voc, .d }).dot(next_token_pred, .{.d});
|
||||||
|
|
||||||
if (logits.shape().hasTag(.voc) == null)
|
if (logits.shape().hasTag(.voc) == null)
|
||||||
logits = logits.rename(.{ .d = .voc });
|
logits = logits.rename(.{ .d = .voc });
|
||||||
|
|
||||||
|
|||||||
@ -17,6 +17,8 @@ const ShapeOf = zml.ShapeOf;
|
|||||||
|
|
||||||
const log = std.log.scoped(.llama);
|
const log = std.log.scoped(.llama);
|
||||||
|
|
||||||
|
const eos_tokens: [3]i32 = .{ 128001, 128008, 128009 };
|
||||||
|
|
||||||
// set this to false to disable the verbose logging
|
// set this to false to disable the verbose logging
|
||||||
const show_mlir = true;
|
const show_mlir = true;
|
||||||
|
|
||||||
@ -71,6 +73,7 @@ pub fn generateText(
|
|||||||
|
|
||||||
const start = std.time.microTimestamp();
|
const start = std.time.microTimestamp();
|
||||||
const output_freq: u8 = 1;
|
const output_freq: u8 = 1;
|
||||||
|
var eos_index: ?usize = null;
|
||||||
for (0..output_tokens_len) |i| {
|
for (0..output_tokens_len) |i| {
|
||||||
//_ = i;
|
//_ = i;
|
||||||
const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len }));
|
const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len }));
|
||||||
@ -84,26 +87,34 @@ pub fn generateText(
|
|||||||
decode_progress += output_freq;
|
decode_progress += output_freq;
|
||||||
std.debug.print("{s}", .{output.items[n..]});
|
std.debug.print("{s}", .{output.items[n..]});
|
||||||
tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Decoded token {}/{} : {s}", .{ i + 1, output_tokens_len, output.items[n..] }));
|
tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Decoded token {}/{} : {s}", .{ i + 1, output_tokens_len, output.items[n..] }));
|
||||||
|
if (std.mem.indexOfAny(i32, token_buffer[decode_progress - output_freq ..], &eos_tokens)) |index| {
|
||||||
|
// Handle strange scenarios when eos id isn't the very next token after decode_progress
|
||||||
|
eos_index = decode_progress - output_freq + index;
|
||||||
|
break;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Generated token {}/{}", .{ i + 1, output_tokens_len }));
|
tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Generated token {}/{}", .{ i + 1, output_tokens_len }));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std.debug.print("\n", .{});
|
var total_token_count: usize = max_seq_len;
|
||||||
|
|
||||||
const n = output.items.len;
|
const n = output.items.len;
|
||||||
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[decode_progress..]), .{});
|
if (eos_index) |end_idx| {
|
||||||
|
// count = eos index + 1
|
||||||
|
total_token_count = end_idx + 1;
|
||||||
|
}
|
||||||
|
const generated_token_count = total_token_count - prompt_tok.len;
|
||||||
|
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[decode_progress..total_token_count]), .{});
|
||||||
std.debug.print("{s}\n", .{output.items[n..]});
|
std.debug.print("{s}\n", .{output.items[n..]});
|
||||||
const end = std.time.microTimestamp();
|
const end = std.time.microTimestamp();
|
||||||
|
|
||||||
const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s);
|
const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s);
|
||||||
const speed = @as(f64, @floatFromInt(max_seq_len)) / duration;
|
const speed = @as(f64, @floatFromInt(generated_token_count)) / duration;
|
||||||
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ max_seq_len, duration, speed });
|
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ generated_token_count, duration, speed });
|
||||||
|
|
||||||
_ = try tokens.toHost(std.mem.sliceAsBytes(token_buffer));
|
_ = try tokens.toHost(std.mem.sliceAsBytes(token_buffer));
|
||||||
const end_index = std.mem.indexOfScalar(i32, token_buffer, 128001) orelse max_seq_len;
|
|
||||||
output.clearRetainingCapacity();
|
output.clearRetainingCapacity();
|
||||||
|
|
||||||
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[0..end_index]), .{});
|
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[0..total_token_count]), .{});
|
||||||
return output.toOwnedSlice();
|
return output.toOwnedSlice();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -199,7 +210,7 @@ pub fn asyncMain() !void {
|
|||||||
defer tokenizer.deinit();
|
defer tokenizer.deinit();
|
||||||
|
|
||||||
const dims = llama.model.shape();
|
const dims = llama.model.shape();
|
||||||
const dtype = llama.lm_head.weight.dtype();
|
const dtype = llama.model.embed_tokens.weight.dtype();
|
||||||
|
|
||||||
// Note: we compile the model without a batching dimension.
|
// Note: we compile the model without a batching dimension.
|
||||||
// To do so, we would just need to add `.b = batch_size` to `token_shape` and `kv_shape`.
|
// To do so, we would just need to add `.b = batch_size` to `token_shape` and `kv_shape`.
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user