const std = @import("std");
const zml = @import("zml");
const meta = zml.meta;
const asynk = @import("async");
const flags = @import("tigerbeetle/flags");
const llama_mod = @import("llama.zig");
const LlamaLM = llama_mod.LlamaLM;
const Llama = llama_mod.Llama;
const KvCache = llama_mod.KvCache;
const TransformerLayer = llama_mod.TransformerLayer;
const SelfAttn = llama_mod.SelfAttn;
const Buffer = zml.Buffer;
const Tensor = zml.Tensor;
const ShapeOf = zml.ShapeOf;

const log = std.log.scoped(.llama);

// set this to false to disable the verbose logging
const show_mlir = true;

pub const std_options = .{
    .log_level = .err,
    .log_scope_levels = &[_]std.log.ScopeLevel{
        .{ .scope = .pjrt, .level = if (show_mlir) .debug else .err },
        .{ .scope = .zml_module, .level = if (show_mlir) .debug else .err },
        .{ .scope = .zml, .level = if (show_mlir) .debug else .err },
        .{ .scope = .llama, .level = if (show_mlir) .debug else .info },
    },
};

/// Token id scanned for to find the end of the generated text.
/// NOTE(review): 128001 is the Llama-3 `<|end_of_text|>` id — confirm it matches
/// the tokenizer actually loaded, older/other models use a different id.
const end_of_text_token: i32 = 128001;

/// Runs prefill + token-by-token decoding and returns the decoded text
/// up to (not including) the first end-of-text token.
///
/// - `mod_prefill` is the compiled forward pass without a kv-cache input
///   (it creates one); `mod` is the compiled forward pass that reuses it.
/// - Caller owns the returned slice (allocated with `allocator`).
/// - Tokens are also streamed to stderr as they are decoded.
pub fn generateText(
    llama: LlamaLM,
    mod_prefill: zml.module.ExeWithWeights(LlamaLM.forward),
    mod: zml.module.ExeWithWeights(LlamaLM.forward),
    tokenizer: zml.tokenizer.Tokenizer,
    allocator: std.mem.Allocator,
    seed: u128,
    prompt: []const u8,
) ![]const u8 {
    // Encoding allocates and can fail (e.g. OOM): propagate instead of
    // asserting unreachable, which would be UB in release builds.
    const prompt_tok = try tokenizer.encode(allocator, prompt, .{});
    defer allocator.free(prompt_tok);
    log.debug("Tokenized Prompt {d}", .{prompt_tok});

    const dims = llama.model.shape();
    const max_seq_len = dims.s;

    // Host-side staging buffer for the full token sequence; the prompt is
    // copied in at the front, the rest stays zero until generated.
    const token_buffer = try allocator.alloc(i32, @intCast(max_seq_len));
    // Pair each allocation with its `defer` immediately, so an allocation
    // failure below cannot leak the earlier buffers.
    defer allocator.free(token_buffer);
    @memset(token_buffer, 0);
    for (0..prompt_tok.len) |i| {
        token_buffer[i] = @intCast(prompt_tok[i]);
    }

    // Scratch buffer for tracer frame labels.
    const tracer_buffer = try allocator.alloc(u8, @intCast(max_seq_len));
    defer allocator.free(tracer_buffer);

    var output = std.ArrayList(u8).init(allocator);
    defer output.deinit();

    // Device buffers: the token sequence, the index of the last prompt token,
    // and the RNG state used for sampling.
    var tokens = try zml.Buffer.fromSlice(mod.platform(), .{max_seq_len}, token_buffer);
    var token_index = try zml.Buffer.fromSlice(mod.platform(), .{}, &[_]i32{@intCast(prompt_tok.len - 1)});
    var rng = try zml.Tensor.Rng.init(mod.platform(), seed);

    // Prefill: processes the whole prompt at once and materializes the kv-cache.
    tokens, token_index, var kv_cache, rng = mod_prefill.call(.{ tokens, token_index, null, rng });
    defer kv_cache.k.deinit();
    defer kv_cache.v.deinit();
    defer kv_cache.layer_index.deinit();

    const tracer = zml.tools.Tracer.init("ai.zml.models.llama");
    var decode_progress = prompt_tok.len;
    // Remaining budget after the prompt (one slot is consumed by the prefill step).
    const output_tokens_len = max_seq_len - prompt_tok.len - 1;

    const start = std.time.microTimestamp();
    // How many tokens to accumulate between host read-backs / decodes.
    const output_freq: u8 = 1;
    for (0..output_tokens_len) |i| {
        const frame_id = tracer.frameStart(try std.fmt.bufPrintZ(tracer_buffer, "Generate token {}/{}", .{ i + 1, output_tokens_len }));
        // One decode step: consumes and updates the kv-cache in place.
        tokens, token_index, kv_cache, rng = mod.call(.{ tokens, token_index, kv_cache, rng });
        if ((i + 1) % output_freq == 0) {
            const n = output.items.len;
            // Copy the device-side token sequence back to the host and decode
            // only the newly generated slice.
            _ = try tokens.toHost(std.mem.sliceAsBytes(token_buffer));
            try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[decode_progress..][0..output_freq]), .{});
            decode_progress += output_freq;
            std.debug.print("{s}", .{output.items[n..]});
            tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Decoded token {}/{} : {s}", .{ i + 1, output_tokens_len, output.items[n..] }));
        } else {
            tracer.frameEnd(frame_id, try std.fmt.bufPrintZ(tracer_buffer, "Generated token {}/{}", .{ i + 1, output_tokens_len }));
        }
    }
    std.debug.print("\n", .{});

    // Decode whatever remains past the last periodic decode.
    const n = output.items.len;
    try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[decode_progress..]), .{});
    std.debug.print("{s}\n", .{output.items[n..]});

    const end = std.time.microTimestamp();
    const duration = zml.meta.divFloat(f64, end - start, std.time.us_per_s);
    // Report the tokens actually generated inside the timed loop, not the
    // full sequence length (which includes the prompt and padding).
    const speed = @as(f64, @floatFromInt(output_tokens_len)) / duration;
    log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ output_tokens_len, duration, speed });

    // Re-decode the whole sequence, truncated at the first end-of-text token,
    // to build the returned string.
    _ = try tokens.toHost(std.mem.sliceAsBytes(token_buffer));
    const end_index = std.mem.indexOfScalar(i32, token_buffer, end_of_text_token) orelse max_seq_len;
    output.clearRetainingCapacity();
    try tokenizer.decodeWithOpts(&output, @ptrCast(token_buffer[0..end_index]), .{});
    return output.toOwnedSlice();
}

pub fn main() !void {
    // All the work happens on the async runtime's thread.
    try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain, .{});
}

/// Entry point proper: parses CLI flags, compiles the prefill and decode
/// executables (concurrently with weight loading), then generates text.
pub fn asyncMain() !void {
    const CliArgs = struct {
        pub const help =
            \\ llama --model=llama3.7B.safetensors --tokenizer=vocab.json --num_layers=2
        ;
        model: []const u8,
        tokenizer: ?[]const u8 = null,
        layer_start: u8 = 0,
        num_layers: ?u8 = null,
        seq_len: u32 = 256,
        topk: u32 = 2,
        temperature: u32 = 1,
        num_heads: ?i64 = null,
        num_kv_heads: ?i64 = null,
        rope_freq_base: ?i64 = null,
        prompt: ?[]const u8 = null,
        test_activations: ?[]const u8 = null,
        seed: ?u128 = null,
    };
    log.info(" LLama was compiled with {}", .{@import("builtin").mode});

    var gpa = std.heap.GeneralPurposeAllocator(.{ .thread_safe = true }){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    // Make sure the compilation-cache directory exists before pointing XLA at it.
    const tmp = try std.fs.openDirAbsolute("/tmp", .{});
    try tmp.makePath("zml/llama/cache");

    var context = try zml.Context.init();
    defer context.deinit();

    const compilation_options = zml.CompilationOptions{
        .cache_location = "/tmp/zml/llama/cache",
        .xla_dump_to = "/tmp/zml/llama",
        .sharding_enabled = true,
    };

    const platform = context.autoPlatform().withCompilationOptions(compilation_options);

    {
        // List available targets
        std.debug.print("\nSupported Platforms:\n", .{});
        const selected_prefix = "✅";
        const not_selected_prefix = "• ";
        const selected_postfix = "(AUTO-SELECTED)\n";
        const not_selected_postfix = "\n";
        for (zml.platform.available_targets) |target| {
            std.debug.print(" {s} {s} {s}", .{
                if (target == platform.target) selected_prefix else not_selected_prefix,
                @tagName(target),
                if (target == platform.target) selected_postfix else not_selected_postfix,
            });
            // now the platform's devices
            if (context.platforms.get(target)) |pfm| {
                for (pfm.getDevices(), 0..) |device, index| {
                    const deviceKind = device.getDescription(platform.pjrt_api).getKind(platform.pjrt_api);
                    std.debug.print(" ◦ #{d}: {s}\n", .{
                        index,
                        deviceKind,
                    });
                    // we only list 1 CPU device
                    if (target == .cpu) break;
                }
            }
        }
        std.debug.print("\n", .{});
    }

    var args = std.process.args();
    const cli_args = flags.parse(&args, CliArgs);
    const model_file = cli_args.model;

    // Arena for model metadata/weights bookkeeping: freed all at once at exit.
    var arena_state = std.heap.ArenaAllocator.init(allocator);
    defer arena_state.deinit();
    const model_arena = arena_state.allocator();

    log.info("Model file: {s}", .{model_file});

    var ts = try zml.aio.detectFormatAndOpen(allocator, model_file);
    defer ts.deinit();

    var llama = try zml.aio.populateModel(LlamaLM, model_arena, ts);

    // Model hyper-parameters: CLI flags take precedence over file metadata.
    const num_heads = cli_args.num_heads orelse ts.metadata("num_heads", .int) orelse @panic("--num_heads is required for this model");
    const num_kv_heads = cli_args.num_kv_heads orelse ts.metadata("num_kv_heads", .int) orelse num_heads;
    const rope_impl = if (ts.metadata("rope_impl", .string)) |val|
        std.meta.stringToEnum(zml.nn.RopeOpts.Implementation, val).?
    else
        .sequential;

    const llama_options: llama_mod.LlamaOptions = .{
        .max_seq_len = cli_args.seq_len,
        .num_kv_heads = num_kv_heads,
        .num_heads = num_heads,
        .gen_opts = .{
            .topk = cli_args.topk,
            .temperature = @floatFromInt(cli_args.temperature),
        },
        .rms_norm_eps = @floatCast(ts.metadata("rms_norm_eps", .float) orelse 1e-5),
        .rope_opts = .{
            .impl = rope_impl,
            .freq_base = @floatCast(ts.metadata("rope_freq_base", .float) orelse @as(f32, @floatFromInt(cli_args.rope_freq_base orelse 10_000))),
        },
    };
    log.info("✅ Parsed llama config: {}", .{llama_options});
    llama.init(llama_options);

    // Only .gguf files carry an embedded tokenizer; anything else needs --tokenizer.
    if (cli_args.tokenizer == null and !std.mem.endsWith(u8, cli_args.model, ".gguf")) {
        log.err("Model doesn't have an embedded tokenizer, please provide a path to a tokenizer.", .{});
        @panic("No tokenizer provided");
    }
    const tokenizer_path = cli_args.tokenizer orelse cli_args.model;
    log.info(" Loading tokenizer from {s}", .{tokenizer_path});
    var tokenizer = try zml.aio.detectFormatAndLoadTokenizer(allocator, tokenizer_path);
    log.info("✅ Loaded tokenizer from {s}", .{tokenizer_path});
    defer tokenizer.deinit();

    const dims = llama.model.shape();
    const dtype = llama.lm_head.weight.dtype();

    // Note: we compile the model without a batching dimension.
    // To do so, we would just need to add `.b = batch_size` to `token_shape` and `kv_shape`.
    const tokens_shape = zml.Shape.init(.{ .s = dims.s }, .i32);
    const token_idx_shape = zml.Shape.init(.{}, .i32);
    const kv_shape = zml.Shape.init(.{ .layer = llama.model.layers.len, .h = dims.nkvh, .k = dims.s, .hd = dims.hd }, dtype).withSharding(.{.h});
    // needs to be optional
    const kv_cache_shape: ?ShapeOf(KvCache) = KvCache.initShape(kv_shape);
    const rng_shape = Tensor.Rng.shape();

    // Kick off both compilations asynchronously so they overlap with weight loading.
    const compile_start = std.time.milliTimestamp();
    var fut_mod_prefill = try asynk.asyncGeneric(zml.compile, .{ allocator, LlamaLM, .{llama_options}, .forward, .{ tokens_shape, token_idx_shape, null, rng_shape }, ts, platform });
    var fut_mod = try asynk.asyncGeneric(zml.compile, .{ allocator, LlamaLM, .{llama_options}, .forward, .{ tokens_shape, token_idx_shape, kv_cache_shape, rng_shape }, ts, platform });

    log.info("Starting loading weights", .{});
    var llama_weights = try zml.aio.loadBuffers(LlamaLM, .{llama_options}, ts, model_arena, platform);
    defer zml.aio.unloadBuffers(&llama_weights);
    log.info("✅ Done loading weights", .{});
    log.info("✅ Llama model loaded from {s}", .{cli_args.model});

    // Join the compilation futures and bind the weights to the executables.
    var llama_module_prefill = try (try fut_mod_prefill.await_()).prepare(allocator, llama_weights);
    defer llama_module_prefill.deinit();
    var llama_module = try (try fut_mod.await_()).prepare(allocator, llama_weights);
    defer llama_module.deinit();
    const compile_end = std.time.milliTimestamp();
    log.info("✅ Compiled model in {d} milliseconds! \n", .{compile_end - compile_start});

    const prompt = cli_args.prompt orelse "Once upon a time, there was a little girl named Lily.";
    log.info("✅ Prompt: {s}\n", .{prompt});

    const seed = cli_args.seed orelse @as(u128, @bitCast(std.time.nanoTimestamp()));
    // The generated text is already streamed to stderr inside generateText.
    const story = try generateText(llama, llama_module_prefill, llama_module, tokenizer, allocator, seed, prompt);
    defer allocator.free(story);
}