377 lines
15 KiB
Zig
377 lines
15 KiB
Zig
|
|
const std = @import("std");
|
||
|
|
const builtin = @import("builtin");
|
||
|
|
|
||
|
|
const async = @import("async");
|
||
|
|
const clap = @import("clap");
|
||
|
|
const stdx = @import("stdx");
|
||
|
|
const zml = @import("zml");
|
||
|
|
const Buffer = zml.Buffer;
|
||
|
|
const Tensor = zml.Tensor;
|
||
|
|
const ShapeOf = zml.ShapeOf;
|
||
|
|
|
||
|
|
const GptOss = @import("GptOss.zig");
|
||
|
|
|
||
|
|
// Scoped logger: messages from this file are tagged with `.GptOss`.
const log = std.log.scoped(.GptOss);

// Global logging configuration picked up by std.log at comptime.
pub const std_options: std.Options = .{
    .log_level = .info,
    // Route log output through the async runtime's log wrapper.
    // NOTE(review): presumably this makes logging safe/non-blocking from async
    // tasks — confirm against the `async` package's documentation.
    .logFn = async.logFn(std.log.defaultLog),
};
|
||
|
|
|
||
|
|
// Command-line interface specification, parsed at comptime by zig-clap.
// Each line declares one flag: name, <TYPE> value placeholder, and help text.
// The <TYPE> placeholders are resolved by `ClapBoilerplate.parsers` below.
const cli_params = clap.parseParamsComptime(
    \\--help print this help
    \\--prompt <STRING> the prompt
    \\--hf-model-path <STRING> path to the directory containing model weights, config and tokenizer
    \\--seed <UINT> random seed (optional)
    \\--seq-len <UINT> max sequence length
    \\--prompt-len <UINT> max prompt length
    \\--temperature <FLOAT> temperature (default 1.0)
    \\--topk <UINT> topk (default 10)
    \\--expert-budget <FLOAT> token budget per expert
    \\--platform-options <STRING> platform options, using Zon syntax, eg '.{.cuda=.{.allocator=.{.async=.{.memory_fraction=0.95}}}}'
    \\--nochat <BOOL> skip prompt template
    \\--sharding <BOOL> default: true: sharding on or off
);
|
||
|
|
|
||
|
|
/// Tokenizes `prompt` into the caller-provided buffer `out` and returns the
/// filled prefix of `out`.
///
/// With `no_chat` set, the prompt is encoded verbatim. Otherwise it is wrapped
/// in the GptOss/harmony chat template: a fixed system message, the user
/// message, and an opening assistant "analysis" channel header.
///
/// Errors: `error.PromptTooLong` when the tokens don't fit in `out`,
/// `error.NoSuchToken` when a required special/role token is missing from the
/// tokenizer vocabulary.
pub fn tokenizePrompt(tokenizer: zml.tokenizer.Tokenizer, prompt: []const u8, no_chat: bool, out: []u32) ![]u32 {
    var encoder = try tokenizer.encoder();
    defer encoder.deinit();

    if (no_chat) {
        const tokens = try encoder.encode(prompt);
        if (tokens.len > out.len) return error.PromptTooLong;
        @memcpy(out[0..tokens.len], tokens);
        return out[0..tokens.len];
    }

    // Resolve all special and role tokens up front.
    // NOTE(fix): the role tokens ("system", "user", "assistant") previously
    // used `.?`, which panics on a missing token; now they consistently return
    // `error.NoSuchToken` like the special tokens.
    const start_header = tokenizer.tokenToId("<|start|>") orelse return error.NoSuchToken;
    const end_header_start_message = tokenizer.tokenToId("<|message|>") orelse return error.NoSuchToken;
    const end_message = tokenizer.tokenToId("<|end|>") orelse return error.NoSuchToken;
    const system_role = tokenizer.tokenToId("system") orelse return error.NoSuchToken;
    const user_role = tokenizer.tokenToId("user") orelse return error.NoSuchToken;
    const assistant_role = tokenizer.tokenToId("assistant") orelse return error.NoSuchToken;
    const channel = tokenizer.tokenToId("<|channel|>") orelse return error.NoSuchToken;
    const analysis = tokenizer.tokenToId("analysis") orelse return error.NoSuchToken;

    // Fixed-capacity list backed by `out`; appends below are capacity-checked
    // explicitly before each AssumeCapacity call.
    var tokens: std.ArrayList(u32) = .initBuffer(out);

    // System message: <|start|>system<|message|>...<|end|>  (3 tokens + text + 1)
    const system_prompt = try encoder.encode("You are ChatGPT, a large language model trained by OpenAI.\n");
    if (system_prompt.len + 4 > tokens.unusedCapacitySlice().len) return error.PromptTooLong;
    tokens.appendSliceAssumeCapacity(&.{ start_header, system_role, end_header_start_message });
    tokens.appendSliceAssumeCapacity(system_prompt);
    tokens.appendAssumeCapacity(end_message);

    // User message plus the opening assistant header: 3 tokens + text + 6.
    const user_prompt = try encoder.encode(prompt);
    if (user_prompt.len + 9 > tokens.unusedCapacitySlice().len) return error.PromptTooLong;
    tokens.appendSliceAssumeCapacity(&.{ start_header, user_role, end_header_start_message });
    tokens.appendSliceAssumeCapacity(user_prompt);
    tokens.appendSliceAssumeCapacity(&.{
        end_message,
        start_header,
        assistant_role,
        channel,
        analysis,
        end_header_start_message,
    });

    return tokens.items;
}
|
||
|
|
|
||
|
|
/// Runs prefill over `prompt_tok`, then autoregressively generates tokens,
/// streaming the decoded text chunk-by-chunk to `output` until an EOS token is
/// produced or the `options.max_seq_len` budget is exhausted.
///
/// `kv_cache_` is donated to the prefill call; the cache returned by prefill
/// is unloaded before this function returns. `seed` initializes the on-device
/// RNG used for sampling.
pub fn generateText(
    config: GptOss.Config,
    options: GptOss.Options,
    mod_prefill: zml.ModuleExe(GptOss.forward),
    mod_generate: zml.ModuleExe(GptOss.forward),
    kv_cache_: zml.Bufferized(GptOss.KvCache),
    tokenizer: zml.tokenizer.Tokenizer,
    allocator: std.mem.Allocator,
    seed: u128,
    prompt_tok: []const u32,
    output: *std.Io.Writer,
) !void {
    // NOTE(fix): guard against underflow in `output_tokens_len` below when the
    // prompt already fills (or exceeds) the sequence budget.
    if (options.max_seq_len < prompt_tok.len + 1) return error.PromptTooLong;

    var tokenizer_decoder = try tokenizer.decoder();
    defer tokenizer_decoder.deinit();

    const platform = mod_generate.platform();

    // init RNG and the host-side landing slot for each generated token
    var rng = try zml.Tensor.Rng.init(platform, seed);
    var generated_token_buffer = [_]u32{undefined};

    var current_token, var kv_cache = prefill: {
        // prepare device buffers for the prefill tokens and their positions
        const prefill_buffer = try allocator.alloc(u32, options.max_prompt_len);
        // NOTE(fix): this staging buffer was leaked; free it when the block
        // exits (after `fromSlice` has uploaded it to the device).
        defer allocator.free(prefill_buffer);
        @memcpy(prefill_buffer[0..prompt_tok.len], prompt_tok);
        // NOTE(fix): zero the padding so we never upload undefined memory.
        @memset(prefill_buffer[prompt_tok.len..], 0);

        var prefill_tokens = try zml.Buffer.fromSlice(platform, .{options.max_prompt_len}, prefill_buffer);
        defer prefill_tokens.deinit();
        var prefill_token_pos = try zml.Buffer.scalar(platform, prompt_tok.len, .u32);
        defer prefill_token_pos.deinit();

        const first_token, const new_kv_cache, rng = mod_prefill.call(.{ prefill_tokens, .{ .prefill = prefill_token_pos }, kv_cache_, rng });

        // extract the first generated token
        _ = try first_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer));
        log.warn("first_token: {d}", .{generated_token_buffer[0]});
        break :prefill .{ first_token, new_kv_cache };
    };
    defer zml.aio.unloadBuffers(&kv_cache);
    defer current_token.deinit();

    const output_tokens_len = options.max_seq_len - prompt_tok.len - 1;
    const start = std.time.microTimestamp();

    // NOTE(fix): was initialized to 1, but the loop below also increments once
    // for the prefill-produced token, double-counting it in the tok/s stat.
    // Every generated token flows through the loop body exactly once, so the
    // counter starts at zero.
    var num_tokens_generated: usize = 0;

    generation: for (0..output_tokens_len + 1) |i| {
        // collect and print the token produced by the previous device call
        num_tokens_generated += 1;
        const generated_token = generated_token_buffer[0];
        if (try tokenizer_decoder.next(generated_token)) |chunk| {
            try output.writeAll(chunk);
        }

        // stop on budget exhaustion or EOS (after printing the token)
        if (i == output_tokens_len) break :generation;
        switch (config.eos_token_id.value) {
            .int => |eos| if (generated_token == @as(u32, @intCast(eos))) break :generation,
            .ints => |eos_list| for (eos_list) |eos| {
                if (generated_token == @as(u32, @intCast(eos))) break :generation;
            },
        }

        // the current token position needs to go into a zml.Buffer
        const token_pos_buffer = &[_]u32{@intCast(prompt_tok.len + i)};
        const token_pos = try zml.Buffer.fromSlice(platform, .{}, token_pos_buffer);
        defer token_pos.deinit();

        // call to generate the next token (buffers are donated and replaced)
        current_token, kv_cache, rng = mod_generate.call(.{ current_token, .{ .gen = token_pos }, kv_cache, rng });

        // extract the generated token from the device buffer
        _ = try current_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer));
    }
    const end = std.time.microTimestamp();
    const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s);
    const speed = @as(f64, @floatFromInt(num_tokens_generated)) / duration;

    log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ num_tokens_generated, duration, speed });
}
|
||
|
|
|
||
|
|
/// Process entry point: hands control to the async runtime, which runs
/// `asyncMain` on its thread pool.
pub fn main() !void {
    const base_allocator = std.heap.smp_allocator;
    return async.AsyncThread.main(base_allocator, asyncMain);
}
|
||
|
|
|
||
|
|
/// Main program logic (runs on the async runtime): parses CLI arguments,
/// loads the model config, weights and tokenizer, compiles the prefill and
/// generate executables concurrently, then runs text generation to stdout.
pub fn asyncMain() !void {
    log.info(" GptOss was compiled with {}", .{@import("builtin").mode});

    // Choose the allocator once for the whole program.
    // NOTE(fix): the debug allocator state must live at container scope. The
    // previous version kept it in a block-local `var`, so the returned
    // `Allocator` held a dangling pointer into a dead stack frame.
    const allocator: std.mem.Allocator = alloc: {
        if (builtin.mode == .Debug) {
            const Dbg = struct {
                var instance: std.heap.DebugAllocator(.{
                    .never_unmap = true,
                    .retain_metadata = true,
                }) = .init;
            };
            break :alloc Dbg.instance.allocator();
        }
        break :alloc std.heap.smp_allocator;
    };

    const cli = ClapBoilerplate.parseCli(allocator);
    defer cli.deinit();

    const hf_model_path = cli.args.@"hf-model-path" orelse {
        log.err("Missing --hf-model-path", .{});
        return;
    };

    // Parse <hf-model-path>/config.json into a GptOss.Config.
    const config = config: {
        var arena: std.heap.ArenaAllocator = .init(allocator);
        defer arena.deinit();

        const model_config_path = try std.fs.path.join(arena.allocator(), &.{ hf_model_path, "config.json" });

        var config_json_file = try async.File.open(model_config_path, .{ .mode = .read_only });
        // NOTE(fix): was `catch unreachable` — close can legitimately fail;
        // deallocation must not crash, so log and continue instead.
        defer config_json_file.close() catch |err| log.warn("Failed to close config.json: {}", .{err});

        var config_reader = config_json_file.reader(try arena.allocator().alloc(u8, 256));
        var reader = std.json.Reader.init(allocator, &config_reader.interface);
        defer reader.deinit();
        // NOTE(review): parsed into the arena, which is freed at block exit —
        // any slice-typed fields of Config would dangle past `break :config`.
        // Assumed safe because only scalar fields are used afterwards; confirm
        // against GptOss.Config's definition.
        var config = try std.json.parseFromTokenSourceLeaky(GptOss.Config, arena.allocator(), &reader, .{ .ignore_unknown_fields = true });

        // From generation_config.json
        config.eos_token_id = .{ .value = .{ .ints = &.{ 200002, 199999, 200012 } } };
        break :config config;
    };

    var context = try zml.Context.init();
    defer context.deinit();

    // initialize ZML platform
    const platform: zml.Platform = platform: {
        const arena: std.heap.ArenaAllocator = .init(allocator);
        defer arena.deinit();

        // eg: --platform-options='.{.cuda=.{.allocator=.{.bfc=.{.memory_fraction=0.99}}}}'
        // eg: --platform-options='.{.cpu=.{.device_count=8}}'
        const platform_options_src = cli.args.@"platform-options" orelse ".{}";
        const platform_opts = std.zon.parse.fromSlice(zml.Platform.CreateOptions, allocator, @ptrCast(platform_options_src), null, .{ .free_on_error = false }) catch |err| {
            // NOTE(fix): the message claimed "as json" for Zon syntax, and the
            // old `.?` here could panic when the option was not provided.
            log.err("Failed to parse --platform-options as zon ({}): {s}", .{ err, platform_options_src });
            return err;
        };

        const compilation_options = zml.CompilationOptions{
            .xla_dump_to = "/tmp/zml/gpt_oss",
            .sharding_enabled = cli.args.sharding orelse true,
        };

        const platform = context
            .autoPlatform(platform_opts)
            .withCompilationOptions(compilation_options);
        context.printAvailablePlatforms(platform);

        break :platform platform;
    };

    const options: GptOss.Options = .{
        .max_seq_len = cli.args.@"seq-len" orelse 8192,
        .max_prompt_len = cli.args.@"prompt-len" orelse 256,
        .tokens_per_expert_ratio = cli.args.@"expert-budget" orelse 4.0,
        .sampling_strategy = .{
            .topk = cli.args.topk orelse 10,
            // NOTE(fix): --temperature was parsed but silently ignored; the
            // value was hard-coded to 1.0. The CLI default remains 1.0.
            .temperature = cli.args.temperature orelse 1.0,
        },
    };

    // Arena holding the model graph + loaded buffers for the program lifetime.
    var compiler_arena = std.heap.ArenaAllocator.init(allocator);
    defer compiler_arena.deinit();

    const model_weights_path = try std.fs.path.join(allocator, &.{ hf_model_path, "model.safetensors.index.json" });
    defer allocator.free(model_weights_path);

    var store = try zml.aio.detectFormatAndOpen(allocator, model_weights_path);
    defer store.deinit();

    const model: GptOss = try GptOss.init(compiler_arena.allocator(), store, config, options);

    // Input shapes for the two compiled entry points: prefill consumes the
    // whole padded prompt, generate consumes a single token.
    const tokens_shape_prefill = zml.Shape.init(.{ .s = options.max_prompt_len }, .u32);
    const tokens_shape = zml.Shape.init(.{ .s = 1 }, .u32);

    const dtype = model.model.embed_tokens.weight.dtype();

    const kv_shape = zml.Shape.init(.{
        .layer = model.model.layers.len,
        .k = options.max_seq_len,
        .h = config.num_key_value_heads,
        .hd = config.head_dim,
    }, dtype).withSharding(.{.h});

    const kv_cache_shape: zml.ShapeOf(GptOss.KvCache) = GptOss.KvCache.initShape(kv_shape);
    const rng_shape = zml.Tensor.Rng.shape();

    var start = try std.time.Timer.start();
    // Kick off both compilations concurrently; weight loading overlaps below.
    var fut_mod_prefill = try async.async(zml.compileModel, .{
        allocator, GptOss.forward, model,
        .{
            tokens_shape_prefill,
            zml.ShapeOf(GptOss.Mode){ .prefill = .scalar(.u32) },
            kv_cache_shape,
            rng_shape,
        },
        platform,
    });

    var fut_mod = try async.async(zml.compileModel, .{
        allocator, GptOss.forward, model,
        .{
            tokens_shape,
            zml.ShapeOf(GptOss.Mode){ .gen = .scalar(.u32) },
            kv_cache_shape,
            rng_shape,
        },
        platform,
    });

    log.info("\tLoading GptOss weights from {s}...", .{model_weights_path});
    var gpt_oss_weights = try model.loadBuffers(compiler_arena.allocator(), store, platform);
    defer zml.aio.unloadBuffers(&gpt_oss_weights);
    log.info("✅\tLoaded weights in {D}", .{start.read()});

    var module_prefill = (try fut_mod_prefill.await()).prepare(gpt_oss_weights);
    defer module_prefill.deinit();
    var module_gen = (try fut_mod.await()).prepare(gpt_oss_weights);
    defer module_gen.deinit();
    log.info("✅\tCompiled model in {D}", .{start.read()});

    log.info("Creating KvCache", .{});
    const kv_cache = try GptOss.KvCache.initBuffer(kv_shape, platform);

    var tokenizer = blk: {
        const model_tokenizer_path = try std.fs.path.join(allocator, &.{ hf_model_path, "tokenizer.json" });
        defer allocator.free(model_tokenizer_path);

        log.info("Loading tokenizer from {s}", .{model_tokenizer_path});
        var timer = try stdx.time.Timer.start();
        defer log.info("Loaded tokenizer from {s} [{f}]", .{ model_tokenizer_path, timer.read() });

        break :blk try zml.tokenizer.Tokenizer.fromFile(allocator, model_tokenizer_path);
    };
    // NOTE(fix): was `errdefer`, which leaked the tokenizer on the success path.
    defer tokenizer.deinit();

    const prompt = cli.args.prompt orelse "What are some fun facts about animals?";
    log.info("✅\tPrompt: {s}", .{prompt});

    const no_chat = cli.args.nochat orelse false;
    const prompt_tok_buf = try allocator.alloc(u32, options.max_prompt_len);
    defer allocator.free(prompt_tok_buf);

    const prompt_tok = tokenizePrompt(tokenizer, prompt, no_chat, prompt_tok_buf) catch |err| switch (err) {
        error.PromptTooLong => std.debug.panic("Prompt too long, expected at most {d} tokens. Consider increasing --max-prompt-len", .{prompt_tok_buf.len}),
        else => |e| return e,
    };
    log.info("\t Tokenized prompt: {any} ({d} tokens)", .{ prompt_tok, prompt_tok.len });

    const seed = cli.args.seed orelse @as(u128, @bitCast(std.time.nanoTimestamp()));

    // Unbuffered writing of the tokens to stdout.
    // generated text will be printed token by token.
    var output = std.fs.File.stdout().writer(&.{});

    try generateText(config, options, module_prefill, module_gen, kv_cache, tokenizer, allocator, seed, prompt_tok, &output.interface);
}
|
||
|
|
|
||
|
|
/// CLI parsing glue around zig-clap: value parsers, the result type, and a
/// parse helper that prints usage/help and exits on error or --help.
const ClapBoilerplate = struct {
    /// Result type of a successful parse of `cli_params`.
    pub const Cli = clap.Result(clap.Help, &cli_params, parsers);

    /// Parses a boolean flag value: any string starting with 't', 'T', 'y',
    /// 'Y' or '1' is true; everything else (including "") is false.
    fn bool_parser(in: []const u8) error{}!bool {
        // NOTE(fix): guard the first-byte access — `in[0]` used to panic on an
        // empty value (e.g. `--nochat=""`).
        if (in.len == 0) return false;
        return std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
    }

    /// Maps the <TYPE> placeholders in `cli_params` to value parsers.
    const parsers = .{
        .BOOL = bool_parser,
        .UINT = clap.parsers.int(u32, 0),
        .FLOAT = clap.parsers.float(f32),
        .STRING = clap.parsers.string,
        .PATH = clap.parsers.string,
    };

    /// Parses argv. On a parse error: reports the diagnostic plus usage to
    /// stderr and exits with status 1. On --help: prints help and exits 0.
    pub fn parseCli(allocator: std.mem.Allocator) Cli {
        var diag: clap.Diagnostic = .{};
        var stderr_buffer: [1024]u8 = undefined;
        var stderr = std.fs.File.stderr().writer(&stderr_buffer);
        const cli = clap.parse(clap.Help, &cli_params, parsers, .{
            .diagnostic = &diag,
            .allocator = allocator,
        }) catch |err| {
            // Best-effort reporting: we are exiting anyway, so swallow write errors.
            diag.report(&stderr.interface, err) catch {};
            stderr.interface.print("usage: ", .{}) catch {};
            clap.usage(&stderr.interface, clap.Help, &cli_params) catch {};
            stderr.interface.print("\n", .{}) catch {};
            stderr.interface.flush() catch {};
            std.process.exit(1);
        };
        if (cli.args.help != 0) {
            clap.help(&stderr.interface, clap.Help, &cli_params, .{}) catch {};
            stderr.interface.flush() catch {};
            std.process.exit(0);
        }
        return cli;
    }
};
|