const std = @import("std");
const builtin = @import("builtin");
const asynk = @import("async");
const clap = @import("clap");
const stdx = @import("stdx");
const zml = @import("zml");
const Buffer = zml.Buffer;
const Tensor = zml.Tensor;
const ShapeOf = zml.ShapeOf;

const GptOss = @import("GptOss.zig");

const log = std.log.scoped(.GptOss);

pub const std_options: std.Options = .{
    .log_level = .info,
    .logFn = asynk.logFn(std.log.defaultLog),
};

const cli_params = clap.parseParamsComptime(
    \\--help                       print this help
    \\--prompt <STRING>            the prompt
    \\--hf-model-path <PATH>       path to the directory containing model weights, config and tokenizer
    \\--seed <UINT>                random seed (optional)
    \\--seq-len <UINT>             max sequence length
    \\--prompt-len <UINT>          max prompt length
    \\--temperature <FLOAT>        temperature (default 1.0)
    \\--topk <UINT>                topk (default 10)
    \\--expert-budget <FLOAT>      token budget per expert
    \\--platform-options <STRING>  platform options, using ZON syntax, eg '.{.cuda=.{.allocator=.{.async=.{.memory_fraction=0.95}}}}'
    \\--nochat <BOOL>              skip prompt template
    \\--sharding <BOOL>            default: true: sharding on or off
);

pub fn tokenizePrompt(tokenizer: zml.tokenizer.Tokenizer, prompt: []const u8, no_chat: bool, out: []u32) ![]u32 {
    var encoder = try tokenizer.encoder();
    defer encoder.deinit();

    if (no_chat) {
        const tokens = try encoder.encode(prompt);
        if (tokens.len > out.len) return error.PromptTooLong;
        @memcpy(out[0..tokens.len], tokens);
        return out[0..tokens.len];
    }

    const start_header = tokenizer.tokenToId("<|start|>") orelse return error.NoSuchToken;
    const end_header_start_message = tokenizer.tokenToId("<|message|>") orelse return error.NoSuchToken;
    const end_message = tokenizer.tokenToId("<|end|>") orelse return error.NoSuchToken;

    var tokens: std.ArrayList(u32) = .initBuffer(out);

    const system_prompt = try encoder.encode("You are ChatGPT, a large language model trained by OpenAI.\n");
    if (system_prompt.len + 4 > tokens.unusedCapacitySlice().len) return error.PromptTooLong;
    tokens.appendSliceAssumeCapacity(&.{ start_header, tokenizer.tokenToId("system").?, end_header_start_message });
    tokens.appendSliceAssumeCapacity(system_prompt);
    tokens.appendAssumeCapacity(end_message);

    const user_prompt = try encoder.encode(prompt);
    if (user_prompt.len + 9 > tokens.unusedCapacitySlice().len) return error.PromptTooLong;
    tokens.appendSliceAssumeCapacity(&.{ start_header, tokenizer.tokenToId("user").?, end_header_start_message });
    tokens.appendSliceAssumeCapacity(user_prompt);
    tokens.appendSliceAssumeCapacity(&.{
        end_message,
        start_header,
        tokenizer.tokenToId("assistant").?,
        tokenizer.tokenToId("<|channel|>") orelse return error.NoSuchToken,
        tokenizer.tokenToId("analysis") orelse return error.NoSuchToken,
        end_header_start_message,
    });
    return tokens.items;
}
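/// Runs the two-phase generation loop:
///   1. prefill: run the whole prompt through the model once, producing the
///      first generated token and a populated KV cache;
///   2. decode: feed back one token at a time, reusing the KV cache, until an
///      EOS token appears or the sequence-length budget is exhausted.
/// Decoded text is streamed to `output` as it is produced.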
pub fn generateText(
    config: GptOss.Config,
    options: GptOss.Options,
    mod_prefill: zml.ModuleExe(GptOss.forward),
    mod_generate: zml.ModuleExe(GptOss.forward),
    kv_cache_: zml.Bufferized(GptOss.KvCache),
    tokenizer: zml.tokenizer.Tokenizer,
    allocator: std.mem.Allocator,
    seed: u128,
    prompt_tok: []const u32,
    output: *std.Io.Writer,
) !void {
    var tokenizer_decoder = try tokenizer.decoder();
    defer tokenizer_decoder.deinit();

    const platform = mod_generate.platform();

    // init RNG and buffers
    var rng = try zml.Tensor.Rng.init(platform, seed);
    var generated_token_buffer = [_]u32{undefined};

    var current_token, var kv_cache = prefill: {
        // prepare device buffers for the prefill tokens and their positions
        const prefill_buffer = try allocator.alloc(u32, options.max_prompt_len);
        defer allocator.free(prefill_buffer);
        @memcpy(prefill_buffer[0..prompt_tok.len], prompt_tok);
        var prefill_tokens = try zml.Buffer.fromSlice(platform, .{options.max_prompt_len}, prefill_buffer);
        defer prefill_tokens.deinit();
        var prefill_token_pos = try zml.Buffer.scalar(platform, prompt_tok.len, .u32);
        defer prefill_token_pos.deinit();

        const first_token, const kv_cache, rng = mod_prefill.call(.{ prefill_tokens, .{ .prefill = prefill_token_pos }, kv_cache_, rng });

        // extract the first generated token
        _ = try first_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer));
        log.debug("first_token: {d}", .{generated_token_buffer[0]});
        break :prefill .{ first_token, kv_cache };
    };
    defer zml.aio.unloadBuffers(&kv_cache);
    defer current_token.deinit();

    const output_tokens_len = options.max_seq_len - prompt_tok.len - 1;
    const start = std.time.microTimestamp();
    // One token has already been generated by the prefill.
    var num_tokens_generated: usize = 1;
    generation: for (0..output_tokens_len + 1) |i| {
        // collect and print the token generated on the previous step
        const generated_token = generated_token_buffer[0];
        if (try tokenizer_decoder.next(generated_token)) |chunk| {
            try output.writeAll(chunk);
        }

        // stop once the sequence-length budget is exhausted
        if (i == output_tokens_len) break :generation;

        // check for eos
        switch (config.eos_token_id.value) {
            .int => |eos| if (generated_token == @as(u32, @intCast(eos))) break :generation,
            .ints => |eos_list| for (eos_list) |eos| {
                if (generated_token == @as(u32, @intCast(eos))) break :generation;
            },
        }

        // current token pos needs to go into a zml.Buffer
        const token_pos_buffer = &[_]u32{@intCast(prompt_tok.len + i)};
        const token_pos = try zml.Buffer.fromSlice(platform, .{}, token_pos_buffer);
        defer token_pos.deinit();

        // call to generate the next token
        current_token, kv_cache, rng = mod_generate.call(.{ current_token, .{ .gen = token_pos }, kv_cache, rng });
        num_tokens_generated += 1;

        // extract the generated token from the buffer
        _ = try current_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer));
    }
    const end = std.time.microTimestamp();

    const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s);
    const speed = @as(f64, @floatFromInt(num_tokens_generated)) / duration;
    log.info("✅ Generated {d} tokens in {d:.3}s: {d:.3}tok/s", .{ num_tokens_generated, duration, speed });
}

pub fn main() !void {
    try asynk.AsyncThread.main(std.heap.smp_allocator, asyncMain);
}
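/// Async entry point: parses CLI flags, reads config.json, selects a platform,
/// compiles the prefill and decode variants of the model while the weights
/// load, then hands everything over to `generateText`.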
pub fn asyncMain() !void {
    log.info("GptOss was compiled with {}", .{builtin.mode});

    // The DebugAllocator instance must outlive the allocator handle taken from
    // it, so it lives in the function frame rather than inside a block.
    var dbg_alloc: std.heap.DebugAllocator(.{
        .never_unmap = true,
        .retain_metadata = true,
    }) = .init;
    const allocator: std.mem.Allocator = if (builtin.mode == .Debug)
        dbg_alloc.allocator()
    else
        std.heap.smp_allocator;

    const cli = ClapBoilerplate.parseCli(allocator);
    defer cli.deinit();

    const hf_model_path = cli.args.@"hf-model-path" orelse {
        log.err("Missing --hf-model-path", .{});
        return;
    };

    const config = config: {
        var arena: std.heap.ArenaAllocator = .init(allocator);
        defer arena.deinit();
        const model_config_path = try std.fs.path.join(arena.allocator(), &.{ hf_model_path, "config.json" });
        var config_json_file = try asynk.File.open(model_config_path, .{ .mode = .read_only });
        defer config_json_file.close() catch unreachable;
        var config_reader = config_json_file.reader(try arena.allocator().alloc(u8, 256));
        var reader = std.json.Reader.init(allocator, &config_reader.interface);
        defer reader.deinit();
        var config = try std.json.parseFromTokenSourceLeaky(GptOss.Config, arena.allocator(), &reader, .{ .ignore_unknown_fields = true });
        // From generation_config.json
        config.eos_token_id = .{ .value = .{ .ints = &.{ 200002, 199999, 200012 } } };
        break :config config;
    };

    var context = try zml.Context.init();
    defer context.deinit();

    // initialize ZML platform
    const platform: zml.Platform = platform: {
        // eg: --platform-options='.{.cuda=.{.allocator=.{.bfc=.{.memory_fraction=0.99}}}}'
        // eg: --platform-options='.{.cpu=.{.device_count=8}}'
        const platform_opts = std.zon.parse.fromSlice(zml.Platform.CreateOptions, allocator, @ptrCast(cli.args.@"platform-options" orelse ".{}"), null, .{ .free_on_error = false }) catch |err| {
            log.err("Failed to parse --platform-options as ZON ({}): {s}", .{ err, cli.args.@"platform-options".? });
            return err;
        };
        const compilation_options = zml.CompilationOptions{
            .xla_dump_to = "/tmp/zml/gpt_oss",
            .sharding_enabled = cli.args.sharding orelse true,
        };
        const platform = context
            .autoPlatform(platform_opts)
            .withCompilationOptions(compilation_options);
        context.printAvailablePlatforms(platform);
        break :platform platform;
    };

    const options: GptOss.Options = .{
        .max_seq_len = cli.args.@"seq-len" orelse 8192,
        .max_prompt_len = cli.args.@"prompt-len" orelse 256,
        .tokens_per_expert_ratio = cli.args.@"expert-budget" orelse 4.0,
        .sampling_strategy = .{
            .topk = cli.args.topk orelse 10,
            .temperature = cli.args.temperature orelse 1.0,
        },
    };

    var compiler_arena = std.heap.ArenaAllocator.init(allocator);
    defer compiler_arena.deinit();

    const model_weights_path = try std.fs.path.join(allocator, &.{ hf_model_path, "model.safetensors.index.json" });
    defer allocator.free(model_weights_path);
    var store = try zml.aio.detectFormatAndOpen(allocator, model_weights_path);
    defer store.deinit();

    const model: GptOss = try GptOss.init(compiler_arena.allocator(), store, config, options);

    const tokens_shape_prefill = zml.Shape.init(.{ .s = options.max_prompt_len }, .u32);
    const tokens_shape = zml.Shape.init(.{ .s = 1 }, .u32);
    const dtype = model.model.embed_tokens.weight.dtype();
    const kv_shape = zml.Shape.init(.{
        .layer = model.model.layers.len,
        .k = options.max_seq_len,
        .h = config.num_key_value_heads,
        .hd = config.head_dim,
    }, dtype).withSharding(.{.h});
    const kv_cache_shape: zml.ShapeOf(GptOss.KvCache) = GptOss.KvCache.initShape(kv_shape);
    const rng_shape = zml.Tensor.Rng.shape();

    var start = try std.time.Timer.start();
    var fut_mod_prefill = try asynk.asyncc(zml.compileModel, .{
        allocator,
        GptOss.forward,
        model,
        .{
            tokens_shape_prefill,
            zml.ShapeOf(GptOss.Mode){ .prefill = .scalar(.u32) },
            kv_cache_shape,
            rng_shape,
        },
        platform,
    });
    var fut_mod = try asynk.asyncc(zml.compileModel, .{
        allocator,
        GptOss.forward,
        model,
        .{
            tokens_shape,
            zml.ShapeOf(GptOss.Mode){ .gen = .scalar(.u32) },
            kv_cache_shape,
            rng_shape,
        },
        platform,
    });
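    // Both compilations run in the background; the weight upload below
    // proceeds concurrently, and the futures are awaited once the weights
    // are on device.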
    log.info("\tLoading GptOss weights from {s}...", .{model_weights_path});
    var gpt_oss_weights = try model.loadBuffers(compiler_arena.allocator(), store, platform);
    defer zml.aio.unloadBuffers(&gpt_oss_weights);
    log.info("✅\tLoaded weights in {D}", .{start.read()});

    var module_prefill = (try fut_mod_prefill.awaitt()).prepare(gpt_oss_weights);
    defer module_prefill.deinit();
    var module_gen = (try fut_mod.awaitt()).prepare(gpt_oss_weights);
    defer module_gen.deinit();
    log.info("✅\tCompiled model in {D}", .{start.read()});

    log.info("Creating KvCache", .{});
    const kv_cache = try GptOss.KvCache.initBuffer(kv_shape, platform);

    var tokenizer = blk: {
        const model_tokenizer_path = try std.fs.path.join(allocator, &.{ hf_model_path, "tokenizer.json" });
        defer allocator.free(model_tokenizer_path);
        log.info("Loading tokenizer from {s}", .{model_tokenizer_path});
        var timer = try stdx.time.Timer.start();
        defer log.info("Loaded tokenizer from {s} [{f}]", .{ model_tokenizer_path, timer.read() });
        break :blk try zml.tokenizer.Tokenizer.fromFile(allocator, model_tokenizer_path);
    };
    defer tokenizer.deinit();

    const prompt = cli.args.prompt orelse "What are some fun facts about animals?";
    log.info("✅\tPrompt: {s}", .{prompt});
    const no_chat = cli.args.nochat orelse false;

    const prompt_tok_buf = try allocator.alloc(u32, options.max_prompt_len);
    defer allocator.free(prompt_tok_buf);
    const prompt_tok = tokenizePrompt(tokenizer, prompt, no_chat, prompt_tok_buf) catch |err| switch (err) {
        error.PromptTooLong => std.debug.panic("Prompt too long, expected at most {d} tokens. Consider increasing --prompt-len", .{prompt_tok_buf.len}),
        else => |e| return e,
    };
    log.info("\t Tokenized prompt: {any} ({d} tokens)", .{ prompt_tok, prompt_tok.len });

    const seed = cli.args.seed orelse @as(u128, @bitCast(std.time.nanoTimestamp()));

    // Unbuffered writing of the tokens to stdout:
    // generated text will be printed token by token.
    var output = std.fs.File.stdout().writer(&.{});
    try generateText(config, options, module_prefill, module_gen, kv_cache, tokenizer, allocator, seed, prompt_tok, &output.interface);
}

const ClapBoilerplate = struct {
    pub const Cli = clap.Result(clap.Help, &cli_params, parsers);

    fn bool_parser(in: []const u8) error{}!bool {
        return in.len > 0 and std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
    }

    const parsers = .{
        .BOOL = bool_parser,
        .UINT = clap.parsers.int(u32, 0),
        .FLOAT = clap.parsers.float(f32),
        .STRING = clap.parsers.string,
        .PATH = clap.parsers.string,
    };

    pub fn parseCli(allocator: std.mem.Allocator) Cli {
        var diag: clap.Diagnostic = .{};
        var stderr_buffer: [1024]u8 = undefined;
        var stderr = std.fs.File.stderr().writer(&stderr_buffer);
        const cli = clap.parse(clap.Help, &cli_params, parsers, .{
            .diagnostic = &diag,
            .allocator = allocator,
        }) catch |err| {
            diag.report(&stderr.interface, err) catch {};
            stderr.interface.print("usage: ", .{}) catch {};
            clap.usage(&stderr.interface, clap.Help, &cli_params) catch {};
            stderr.interface.print("\n", .{}) catch {};
            stderr.interface.flush() catch {};
            std.process.exit(1);
        };
        if (cli.args.help != 0) {
            clap.help(&stderr.interface, clap.Help, &cli_params, .{}) catch {};
            stderr.interface.flush() catch {};
            std.process.exit(0);
        }
        return cli;
    }
};
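// Example invocation (binary name and flag values are illustrative):
//
//   gpt_oss --hf-model-path=/models/gpt-oss-20b \
//       --prompt="What are some fun facts about animals?" \
//       --seq-len=4096 --topk=40 --temperature=0.8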