zml: Add support for Llama 3.2 text-only models. Implement transpose over embed_tokens as a replacement for missing lm_head and make lm_head optional for compatibility. Add repositories and executions to Bazel and update README.

This commit is contained in:
Foke Singh 2023-11-01 10:16:48 +00:00
parent 1c9749c25e
commit 237a877a29
4 changed files with 115 additions and 13 deletions

View File

@ -86,6 +86,56 @@ http_file(
url = "https://github.com/karpathy/llama2.c/raw/c02865df300f3bd9e567ce061000dc23bf785a17/tokenizer.bin", url = "https://github.com/karpathy/llama2.c/raw/c02865df300f3bd9e567ce061000dc23bf785a17/tokenizer.bin",
) )
# Llama 3.2
huggingface.model(
name = "Meta-Llama-3.2-1B-Instruct",
build_file_content = """\
package(default_visibility = ["//visibility:public"])
filegroup(
name = "model",
srcs = ["model.safetensors"],
)
filegroup(
name = "tokenizer",
srcs = ["tokenizer.json"],
)
""",
commit = "9213176726f574b556790deb65791e0c5aa438b6",
includes = [
"model.safetensors",
"tokenizer.json",
],
model = "meta-llama/Llama-3.2-1B-Instruct",
)
use_repo(huggingface, "Meta-Llama-3.2-1B-Instruct")
huggingface.model(
name = "Meta-Llama-3.2-3B-Instruct",
build_file_content = """\
package(default_visibility = ["//visibility:public"])
filegroup(
name = "model",
srcs = glob(["*.safetensors"]) + ["model.safetensors.index.json"],
)
filegroup(
name = "tokenizer",
srcs = ["tokenizer.json"],
)
""",
commit = "0cb88a4f764b7a12671c53f0838cd831a0843b95",
includes = [
"*.safetensors",
"model.safetensors.index.json",
"tokenizer.json",
],
model = "meta-llama/Llama-3.2-3B-Instruct",
)
use_repo(huggingface, "Meta-Llama-3.2-3B-Instruct")
# Llama 3.1
huggingface.model( huggingface.model(
name = "Meta-Llama-3.1-8B-Instruct", name = "Meta-Llama-3.1-8B-Instruct",
build_file_content = """\ build_file_content = """\
@ -155,6 +205,8 @@ filegroup(
model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0", model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
) )
use_repo(huggingface, "TinyLlama-1.1B-Chat-v1.0") use_repo(huggingface, "TinyLlama-1.1B-Chat-v1.0")
#OpenLLaMa
huggingface.model( huggingface.model(
name = "OpenLM-Research-OpenLLaMA-3B", name = "OpenLM-Research-OpenLLaMA-3B",
build_file_content = """\ build_file_content = """\

View File

@ -51,6 +51,39 @@ cc_binary(
deps = [":llama_lib"], deps = [":llama_lib"],
) )
cc_binary(
name = "Llama-3.2-1B-Instruct",
args = [
"--model=$(location @Meta-Llama-3.2-1B-Instruct//:model.safetensors)",
"--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer)",
"--num-heads=32",
"--num-kv-heads=8",
"--rope-freq-base=500000",
],
data = [
"@Meta-Llama-3.2-1B-Instruct//:model.safetensors",
"@Meta-Llama-3.2-1B-Instruct//:tokenizer",
],
deps = [":llama_lib"],
)
cc_binary(
name = "Llama-3.2-3B-Instruct",
args = [
"--model=$(location @Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json)",
"--tokenizer=$(location @Meta-Llama-3.2-3B-Instruct//:tokenizer)",
"--num-heads=24",
"--num-kv-heads=8",
"--rope-freq-base=500000",
],
data = [
"@Meta-Llama-3.2-3B-Instruct//:model",
"@Meta-Llama-3.2-3B-Instruct//:model.safetensors.index.json",
"@Meta-Llama-3.2-3B-Instruct//:tokenizer",
],
deps = [":llama_lib"],
)
cc_binary( cc_binary(
name = "OpenLLaMA-3B", name = "OpenLLaMA-3B",
args = [ args = [

View File

@ -24,7 +24,7 @@ pub const LlamaOptions = struct {
/// Llama architecture, using huggingface transformers naming. /// Llama architecture, using huggingface transformers naming.
/// Dimensions of activations: {.b, .s, .d} /// Dimensions of activations: {.b, .s, .d}
pub const LlamaLM = struct { pub const LlamaLM = struct {
lm_head: zml.nn.Linear, lm_head: ?zml.nn.Linear = null,
model: Llama, model: Llama,
// Options controlling generation // Options controlling generation
@ -55,7 +55,9 @@ pub const LlamaLM = struct {
// TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled. // TODO(Corentin): Fix lm_head sharding when top-k sampling is enabled.
// It currently crashes/compilation fails // It currently crashes/compilation fails
if (options.gen_opts.topk == 1) { if (options.gen_opts.topk == 1) {
self.lm_head.weight = self.lm_head.weight.withSharding(.{0}); if (self.lm_head) |lm_head| {
self.lm_head.?.weight = lm_head.weight.withSharding(.{0});
}
} }
} }
@ -76,12 +78,12 @@ pub const LlamaLM = struct {
var tokens = tokens_.withPartialTags(.{.s}); var tokens = tokens_.withPartialTags(.{.s});
const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, if (kv_cache == null) null else token_index, kv_cache }); const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, if (kv_cache == null) null else token_index, kv_cache });
tokens, const new_rng = updateTokens(self.lm_head, tokens, token_index, out, rng, self.gen_opts); tokens, const new_rng = self.updateTokens(tokens, token_index, out, rng, self.gen_opts);
return .{ tokens, increment(0, token_index), updated_kv_cache, new_rng }; return .{ tokens, increment(0, token_index), updated_kv_cache, new_rng };
} }
pub fn updateTokens( pub fn updateTokens(
lm_head: zml.nn.Linear, self: LlamaLM,
tokens_: Tensor, tokens_: Tensor,
token_index: Tensor, token_index: Tensor,
out_: Tensor, out_: Tensor,
@ -92,7 +94,11 @@ pub const LlamaLM = struct {
const out = out_.withPartialTags(.{ .s, .d }); const out = out_.withPartialTags(.{ .s, .d });
const next_token_pred = out.gatherValues(.s, token_index, .{}); const next_token_pred = out.gatherValues(.s, token_index, .{});
var logits = zml.call(lm_head, .forward, .{next_token_pred}); var logits = if (self.lm_head) |lm_head|
zml.call(lm_head, .forward, .{next_token_pred})
else
self.model.embed_tokens.weight.withTags(.{ .voc, .d }).dot(next_token_pred, .{.d});
if (logits.shape().hasTag(.voc) == null) if (logits.shape().hasTag(.voc) == null)
logits = logits.rename(.{ .d = .voc }); logits = logits.rename(.{ .d = .voc });

View File

@ -17,6 +17,8 @@ const ShapeOf = zml.ShapeOf;
const log = std.log.scoped(.llama); const log = std.log.scoped(.llama);
const eos_tokens: [3]i32 = .{ 128001, 128008, 128009 };
// set this to false to disable the verbose logging // set this to false to disable the verbose logging
const show_mlir = true; const show_mlir = true;
@ -71,6 +73,7 @@ pub fn generateText(
const start = std.time.microTimestamp(); const start = std.time.microTimestamp();
const output_freq: u8 = 1; const output_freq: u8 = 1;
var eos_index: ?usize = null;
for (0..output_tokens_len) |i| { for (0..output_tokens_len) |i| {
//_ = i; //_ = i;
const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len })); const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len }));
@ -84,26 +87,34 @@ pub fn generateText(
decode_progress += output_freq; decode_progress += output_freq;
std.debug.print("{s}", .{output.items[n..]}); std.debug.print("{s}", .{output.items[n..]});
tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Decoded token {}/{} : {s}", .{ i + 1, output_tokens_len, output.items[n..] })); tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Decoded token {}/{} : {s}", .{ i + 1, output_tokens_len, output.items[n..] }));
if (std.mem.indexOfAny(i32, token_buffer[decode_progress - output_freq ..], &eos_tokens)) |index| {
// Handle strange scenarios when eos id isn't the very next token after decode_progress
eos_index = decode_progress - output_freq + index;
break;
}
} else { } else {
tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Generated token {}/{}", .{ i + 1, output_tokens_len })); tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Generated token {}/{}", .{ i + 1, output_tokens_len }));
} }
} }
std.debug.print("\n", .{}); var total_token_count: usize = max_seq_len;
const n = output.items.len; const n = output.items.len;
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[decode_progress..]), .{}); if (eos_index) |end_idx| {
// count = eos index + 1
total_token_count = end_idx + 1;
}
const generated_token_count = total_token_count - prompt_tok.len;
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[decode_progress..total_token_count]), .{});
std.debug.print("{s}\n", .{output.items[n..]}); std.debug.print("{s}\n", .{output.items[n..]});
const end = std.time.microTimestamp(); const end = std.time.microTimestamp();
const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s); const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s);
const speed = @as(f64, @floatFromInt(max_seq_len)) / duration; const speed = @as(f64, @floatFromInt(generated_token_count)) / duration;
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ max_seq_len, duration, speed }); log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ generated_token_count, duration, speed });
_ = try tokens.toHost(std.mem.sliceAsBytes(token_buffer)); _ = try tokens.toHost(std.mem.sliceAsBytes(token_buffer));
const end_index = std.mem.indexOfScalar(i32, token_buffer, 128001) orelse max_seq_len;
output.clearRetainingCapacity(); output.clearRetainingCapacity();
try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[0..end_index]), .{}); try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[0..total_token_count]), .{});
return output.toOwnedSlice(); return output.toOwnedSlice();
} }
@ -199,7 +210,7 @@ pub fn asyncMain() !void {
defer tokenizer.deinit(); defer tokenizer.deinit();
const dims = llama.model.shape(); const dims = llama.model.shape();
const dtype = llama.lm_head.weight.dtype(); const dtype = llama.model.embed_tokens.weight.dtype();
// Note: we compile the model without a batching dimension. // Note: we compile the model without a batching dimension.
// To do so, we would just need to add `.b = batch_size` to `token_shape` and `kv_shape`. // To do so, we would just need to add `.b = batch_size` to `token_shape` and `kv_shape`.