zml: Add support for Llama 3.2 text-only models. Implement transpose over embed_tokens as a replacement for missing lm_head and make lm_head optional for compatibility. Add repositories and executions to Bazel and update README.
This commit is contained in:
parent
1c9749c25e
commit
237a877a29
@ -86,6 +86,56 @@ http_file(
|
||||
url = "https://github.com/karpathy/llama2.c/raw/c02865df300f3bd9e567ce061000dc23bf785a17/tokenizer.bin",
|
||||
)
|
||||
|
||||
# Llama 3.2
|
||||
huggingface.model(
|
||||
name = "Meta-Llama-3.2-1B-Instruct",
|
||||
build_file_content = """\
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
filegroup(
|
||||
name = "model",
|
||||
srcs = ["model.safetensors"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "tokenizer",
|
||||
srcs = ["tokenizer.json"],
|
||||
)
|
||||
""",
|
||||
commit = "9213176726f574b556790deb65791e0c5aa438b6",
|
||||
includes = [
|
||||
"model.safetensors",
|
||||
"tokenizer.json",
|
||||
],
|
||||
model = "meta-llama/Llama-3.2-1B-Instruct",
|
||||
)
|
||||
use_repo(huggingface, "Meta-Llama-3.2-1B-Instruct")
|
||||
|
||||
huggingface.model(
|
||||
name = "Meta-Llama-3.2-3B-Instruct",
|
||||
build_file_content = """\
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
filegroup(
|
||||
name = "model",
|
||||
srcs = glob(["*.safetensors"]) + ["model.safetensors.index.json"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "tokenizer",
|
||||
srcs = ["tokenizer.json"],
|
||||
)
|
||||
""",
|
||||
commit = "0cb88a4f764b7a12671c53f0838cd831a0843b95",
|
||||
includes = [
|
||||
"*.safetensors",
|
||||
"model.safetensors.index.json",
|
||||
"tokenizer.json",
|
||||
],
|
||||
model = "meta-llama/Llama-3.2-3B-Instruct",
|
||||
)
|
||||
use_repo(huggingface, "Meta-Llama-3.2-3B-Instruct")
|
||||
|
||||
|
||||
# Llama 3.1
|
||||
huggingface.model(
|
||||
name = "Meta-Llama-3.1-8B-Instruct",
|
||||
build_file_content = """\
|
||||
@ -155,6 +205,8 @@ filegroup(
|
||||
model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
)
|
||||
use_repo(huggingface, "TinyLlama-1.1B-Chat-v1.0")
|
||||
|
||||
#OpenLLaMa
|
||||
huggingface.model(
|
||||
name = "OpenLM-Research-OpenLLaMA-3B",
|
||||
build_file_content = """\
|
||||
|
||||
@ -51,6 +51,39 @@ cc_binary(
|
||||
deps = [":llama_lib"],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "Llama-3.2-1B-Instruct",
|
||||
args = [
|
||||
"--model=$(location @Meta-Llama-3.2-1B-Instruct//:model.safetensors)",
|
||||
"--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer)",
|
||||
"--num-heads=32",
|
||||
"--num-kv-heads=8",
|
||||
"--rope-freq-base=500000",
|
||||
],
|
||||
data = [
|
||||
"@Meta-Llama-3.2-1B-Instruct//:model.safetensors",
|
||||
"@Meta-Llama-3.2-1B-Instruct//:tokenizer",
|
||||
],
|
||||
deps = [":llama_lib"],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "Llama-3.2-3B-Instruct",
|
||||
args = [
|
||||
"--model=$(location @Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json)",
|
||||
"--tokenizer=$(location @Meta-Llama-3.2-3B-Instruct//:tokenizer)",
|
||||
"--num-heads=24",
|
||||
"--num-kv-heads=8",
|
||||
"--rope-freq-base=500000",
|
||||
],
|
||||
data = [
|
||||
"@Meta-Llama-3.2-3B-Instruct//:model",
|
||||
"@Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json",
|
||||
"@Meta-Llama-3.2-3B-Instruct//:tokenizer",
|
||||
],
|
||||
deps = [":llama_lib"],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "OpenLLaMA-3B",
|
||||
args = [
|
||||
|
||||
@ -24,7 +24,7 @@ pub const LlamaOptions = struct {
|
||||
/// Llama architecture, using huggingface transformers naming.
|
||||
/// Dimensions of activations: {.b, .s, .d}
|
||||
pub const LlamaLM = struct {
|
||||
lm_head: zml.nn.Linear,
|
||||
lm_head: ?zml.nn.Linear = null,
|
||||
model: Llama,
|
||||
|
||||
// Options controlling generation
|
||||
@ -55,7 +55,9 @@ pub const LlamaLM = struct {
|
||||
// TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled.
|
||||
// It currently crashes/compilation fails
|
||||
if (options.gen_opts.topk == 1) {
|
||||
self.lm_head.weight = self.lm_head.weight.withSharding(.{0});
|
||||
if (self.lm_head) |lm_head| {
|
||||
self.lm_head.?.weight = lm_head.weight.withSharding(.{0});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -76,12 +78,12 @@ pub const LlamaLM = struct {
|
||||
|
||||
var tokens = tokens_.withPartialTags(.{.s});
|
||||
const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, if (kv_cache == null) null else token_index, kv_cache });
|
||||
tokens, const new_rng = updateTokens(self.lm_head, tokens, token_index, out, rng, self.gen_opts);
|
||||
tokens, const new_rng = self.updateTokens(tokens, token_index, out, rng, self.gen_opts);
|
||||
return .{ tokens, increment(0, token_index), updated_kv_cache, new_rng };
|
||||
}
|
||||
|
||||
pub fn updateTokens(
|
||||
lm_head: zml.nn.Linear,
|
||||
self: LlamaLM,
|
||||
tokens_: Tensor,
|
||||
token_index: Tensor,
|
||||
out_: Tensor,
|
||||
@ -92,7 +94,11 @@ pub const LlamaLM = struct {
|
||||
const out = out_.withPartialTags(.{ .s, .d });
|
||||
|
||||
const next_token_pred = out.gatherValues(.s, token_index, .{});
|
||||
var logits = zml.call(lm_head, .forward, .{next_token_pred});
|
||||
var logits = if (self.lm_head) |lm_head|
|
||||
zml.call(lm_head, .forward, .{next_token_pred})
|
||||
else
|
||||
self.model.embed_tokens.weight.withTags(.{ .voc, .d }).dot(next_token_pred, .{.d});
|
||||
|
||||
if (logits.shape().hasTag(.voc) == null)
|
||||
logits = logits.rename(.{ .d = .voc });
|
||||
|
||||
|
||||
@ -17,6 +17,8 @@ const ShapeOf = zml.ShapeOf;
|
||||
|
||||
const log = std.log.scoped(.llama);
|
||||
|
||||
const eos_tokens: [3]i32 = .{ 128001, 128008, 128009 };
|
||||
|
||||
// set this to false to disable the verbose logging
|
||||
const show_mlir = true;
|
||||
|
||||
@ -71,6 +73,7 @@ pub fn generateText(
|
||||
|
||||
const start = std.time.microTimestamp();
|
||||
const output_freq: u8 = 1;
|
||||
var eos_index: ?usize = null;
|
||||
for (0..output_tokens_len) |i| {
|
||||
//_ = i;
|
||||
const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len }));
|
||||
@ -84,26 +87,34 @@ pub fn generateText(
|
||||
decode_progress += output_freq;
|
||||
std.debug.print("{s}", .{output.items[n..]});
|
||||
tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Decoded token {}/{} : {s}", .{ i + 1, output_tokens_len, output.items[n..] }));
|
||||
if (std.mem.indexOfAny(i32, token_buffer[decode_progress - output_freq ..], &eos_tokens)) |index| {
|
||||
// Handle strange scenarios when eos id isn't the very next token after decode_progress
|
||||
eos_index = decode_progress - output_freq + index;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Generated token {}/{}", .{ i + 1, output_tokens_len }));
|
||||
}
|
||||
}
|
||||
std.debug.print("\n", .{});
|
||||
|
||||
var total_token_count: usize = max_seq_len;
|
||||
const n = output.items.len;
|
||||
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[decode_progress..]), .{});
|
||||
if (eos_index) |end_idx| {
|
||||
// count = eos index + 1
|
||||
total_token_count = end_idx + 1;
|
||||
}
|
||||
const generated_token_count = total_token_count - prompt_tok.len;
|
||||
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[decode_progress..total_token_count]), .{});
|
||||
std.debug.print("{s}\n", .{output.items[n..]});
|
||||
const end = std.time.microTimestamp();
|
||||
|
||||
const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s);
|
||||
const speed = @as(f64, @floatFromInt(max_seq_len)) / duration;
|
||||
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ max_seq_len, duration, speed });
|
||||
const speed = @as(f64, @floatFromInt(generated_token_count)) / duration;
|
||||
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ generated_token_count, duration, speed });
|
||||
|
||||
_ = try tokens.toHost(std.mem.sliceAsBytes(token_buffer));
|
||||
const end_index = std.mem.indexOfScalar(i32, token_buffer, 128001) orelse max_seq_len;
|
||||
output.clearRetainingCapacity();
|
||||
|
||||
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[0..end_index]), .{});
|
||||
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[0..total_token_count]), .{});
|
||||
return output.toOwnedSlice();
|
||||
}
|
||||
|
||||
@ -199,7 +210,7 @@ pub fn asyncMain() !void {
|
||||
defer tokenizer.deinit();
|
||||
|
||||
const dims = llama.model.shape();
|
||||
const dtype = llama.lm_head.weight.dtype();
|
||||
const dtype = llama.model.embed_tokens.weight.dtype();
|
||||
|
||||
// Note: we compile the model without a batching dimension.
|
||||
// To do so, we would just need to add `.b = batch_size` to `token_shape` and `kv_shape`.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user