const std = @import("std");
const asynk = @import("async");
const clap = @import("clap");
const stdx = @import("stdx");
const zml = @import("zml");
const Tensor = zml.Tensor;

const modernbert = @import("modernbert.zig");

const log = std.log.scoped(.modernbert);

pub const std_options: std.Options = .{
    .log_level = .info,
    .log_scope_levels = &[_]std.log.ScopeLevel{
        .{ .scope = .modernbert, .level = .info },
    },
    .logFn = asynk.logFn(std.log.defaultLog),
};

const params = clap.parseParamsComptime(
    \\--help                            print this help
    \\--text                  <STRING>  the prompt
    \\--model                 <PATH>    model path
    \\--tokenizer             <PATH>    tokenizer path
    \\--seq-len               <UINT>    sequence length
    \\--num-attention-heads   <UINT>    number of attention heads
    \\--tie-word-embeddings   <BOOL>    default: false: tied weights
    \\--create-options        <STRING>  platform creation options JSON, defaults to {}
    \\--sharding              <BOOL>    default: true: sharding on or off
);

const clap_parsers = .{
    .BOOL = bool_parser,
    .UINT = clap.parsers.int(usize, 0),
    .STRING = clap.parsers.string,
    .PATH = clap.parsers.string,
};

pub fn main() !void {
    try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
}

pub fn asyncMain() !void {
    const allocator = std.heap.c_allocator;
    const stderr = std.fs.File.stderr();

    var diag: clap.Diagnostic = .{};
    var cli = clap.parse(clap.Help, &params, clap_parsers, .{
        .diagnostic = &diag,
        .allocator = allocator,
    }) catch |err| {
        try diag.reportToFile(stderr, err);
        printUsageAndExit(stderr);
    };
    defer cli.deinit();

    if (cli.args.help != 0) {
        try clap.helpToFile(stderr, clap.Help, &params, .{});
        return;
    }

    const tmp = try std.fs.openDirAbsolute("/tmp", .{});
    try tmp.makePath("zml/modernbert/cache");

    // Create the ZML context
    var context = try zml.Context.init();
    defer context.deinit();

    // Platform and compilation options
    const create_opts_json = cli.args.@"create-options" orelse "{}";
    const create_opts = try std.json.parseFromSliceLeaky(zml.Platform.CreateOptions, allocator, create_opts_json, .{});
    const compilation_options = zml.CompilationOptions{
        .xla_dump_to = "/tmp/zml/modernbert",
        .sharding_enabled = cli.args.sharding orelse true,
    };

    // Auto-select the platform
    const platform = context.autoPlatform(create_opts).withCompilationOptions(compilation_options);
    context.printAvailablePlatforms(platform);

    // Detect the format of the model file (based on its filename) and open it.
    const model_file = cli.args.model orelse {
        var buf: [256]u8 = undefined;
        var writer = stderr.writer(&buf);
        writer.interface.print("Error: missing --model=...\n\n", .{}) catch {};
        writer.interface.flush() catch {};
        printUsageAndExit(stderr);
    };
    var tensor_store = try zml.aio.detectFormatAndOpen(allocator, model_file);
    defer tensor_store.deinit();

    // Memory arena dedicated to model shapes and weights
    var arena_state = std.heap.ArenaAllocator.init(allocator);
    defer arena_state.deinit();
    const model_arena = arena_state.allocator();

    var tokenizer = blk: {
        if (cli.args.tokenizer) |tok| {
            log.info("\tLoading tokenizer from {s}", .{tok});
            var timer = try stdx.time.Timer.start();
            defer log.info("✅\tLoaded tokenizer from {s} [{D}]", .{ tok, timer.read() });
            break :blk try zml.tokenizer.Tokenizer.fromFile(model_arena, tok);
        } else {
            log.err("Error: missing --tokenizer", .{});
            return;
        }
    };
    defer tokenizer.deinit();

    // Create the model struct, with tensor shapes extracted from the tensor_store.
    // TODO: read these from config.json
    const modernbert_options = modernbert.ModernBertOptions{
        .pad_token = tokenizer.tokenToId("[PAD]") orelse return error.NoSuchToken,
        .num_attention_heads = @intCast(cli.args.@"num-attention-heads" orelse 12),
        .tie_word_embeddings = cli.args.@"tie-word-embeddings" orelse false,
        .local_attention = 128,
    };
    var modern_bert_for_masked_lm = try zml.aio.populateModel(modernbert.ModernBertForMaskedLM, model_arena, tensor_store);
    modern_bert_for_masked_lm.init(modernbert_options);
    log.info("\tModernBERT options: {}", .{modernbert_options});

    // Prepare the input shape for compilation
    const seq_len = @as(i64, @intCast(cli.args.@"seq-len" orelse 256));
    const input_shape = zml.Shape.init(.{ .b = 1, .s = seq_len }, .u32);

    var start = try stdx.time.Timer.start();

    // Load the weights
    log.info("\tLoading ModernBERT weights from {s}...", .{model_file});
    var bert_weights = try zml.aio.loadBuffers(modernbert.ModernBertForMaskedLM, .{modernbert_options}, tensor_store, model_arena, platform);
    defer zml.aio.unloadBuffers(&bert_weights);
    log.info("✅\tLoaded weights in {D}", .{start.read()});

    // Compile the model
    log.info("\tCompiling ModernBERT model...", .{});
    var fut_mod = try asynk.asyncc(zml.compile, .{
        allocator,
        modernbert.ModernBertForMaskedLM.forward,
        .{modernbert_options},
        .{input_shape},
        tensor_store,
        platform,
    });

    var bert_module = (try fut_mod.awaitt()).prepare(bert_weights);
    defer bert_module.deinit();
    log.info("✅\tLoaded weights and compiled model in {D}", .{start.read()});

    const text = cli.args.text orelse "Paris is the [MASK] of France.";
    log.info("\tInput text: {s}", .{text});

    try unmask(allocator, bert_module, tokenizer, seq_len, text);
}

/// Fill-mask pipeline
/// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/fill_mask.py
pub fn unmask(
    allocator: std.mem.Allocator,
    mod: zml.ModuleExe(modernbert.ModernBertForMaskedLM.forward),
    tokenizer: zml.tokenizer.Tokenizer,
    seq_len: i64,
    text: []const u8,
) !void {
    var tokenizer_decoder = try tokenizer.decoder();
    defer tokenizer_decoder.deinit();

    const pad_token = tokenizer.tokenToId("[PAD]") orelse return error.NoSuchToken;
    const mask_token = tokenizer.tokenToId("[MASK]") orelse return error.NoSuchToken;

    // Tokenize the input text
    const tokens: []const u32 = try tokenize(allocator, tokenizer, text);
    defer allocator.free(tokens);

    // Find the "[MASK]" positions
    const mask_positions = try findMaskPositions(allocator, tokens, mask_token);
    defer allocator.free(mask_positions);

    // Prepare the host-side input buffer (prompt tokens followed by padding)
    const inputs = try prepareTensorInputs(
        allocator,
        tokens,
        seq_len,
        pad_token,
    );
    defer allocator.free(inputs);

    // Create the input tensor (on the accelerator)
    const input_shape = zml.Shape.init(.{ .b = 1, .s = seq_len }, .u32);
    const input_ids_tensor = try zml.Buffer.fromSlice(mod.platform(), input_shape.dims(), inputs);
    defer input_ids_tensor.deinit();

    // Model inference (returns the top-k indices and values)
    var inference_timer = try std.time.Timer.start();
    var topk = mod.call(.{input_ids_tensor});
    defer zml.aio.unloadBuffers(&topk);
    const inference_time = inference_timer.read();

    // Transfer the results to host memory (CPU)
    var indices_host_buffer = try topk.indices.toHostAlloc(allocator);
    defer indices_host_buffer.deinit(allocator);
    var values_host_buffer = try topk.values.toHostAlloc(allocator);
    defer values_host_buffer.deinit(allocator);

    // We consider only the first occurrence of [MASK], which has five predictions
    const pred_offset = mask_positions[0] * 5;
    const predictions = indices_host_buffer.items(i32)[pred_offset..][0..5];
    const scores = values_host_buffer.items(f32)[pred_offset..][0..5];

    // Log timing information
    log.info("⏱️\tModel inference in {d}ms", .{inference_time / std.time.ns_per_ms});

    log.info("✅\tTop 5 predictions:", .{});
    for (predictions, scores) |token_id, score| {
        const token_text = try tokenizer_decoder.next(@intCast(token_id));
        if (token_text) |word| {
            log.info("\t • score: {d:.4} word: '{s}' token: {}", .{ score, word, token_id });
        }
    }
}

pub fn tokenize(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, prompt: []const u8) ![]const u32 {
    var tokens = std.array_list.Managed(u32).init(allocator);
    defer tokens.deinit();

    var encoder = try tokenizer.encoder();
    defer encoder.deinit();

    const bos = tokenizer.tokenToId("[CLS]") orelse return error.NoSuchToken;
    const eos = tokenizer.tokenToId("[SEP]") orelse return error.NoSuchToken;

    try tokens.append(bos);
    try tokens.appendSlice(try encoder.encode(prompt));
    try tokens.append(eos);

    return tokens.toOwnedSlice();
}

fn findMaskPositions(allocator: std.mem.Allocator, tokens: []const u32, mask_token: u32) ![]usize {
    var mask_positions = std.array_list.Managed(usize).init(allocator);
    defer mask_positions.deinit();

    for (tokens, 0..) |token, i| {
        if (token == mask_token) {
            try mask_positions.append(i);
        }
    }

    if (mask_positions.items.len == 0) {
        log.err("Input text must contain `[MASK]`", .{});
        return error.InvalidInput;
    }
    if (mask_positions.items.len > 1) log.warn("Currently only one [MASK] per input is supported", .{});

    return mask_positions.toOwnedSlice();
}

/// Fills a fixed-size buffer of seq_len token ids: the prompt tokens first, padding after.
fn prepareTensorInputs(
    allocator: std.mem.Allocator,
    tokens: []const u32,
    seq_len: i64,
    pad_token: u32,
) ![]u32 {
    const input_ids = try allocator.alloc(u32, @intCast(seq_len));
    errdefer allocator.free(input_ids);
    if (tokens.len > input_ids.len) return error.InputTooLong;

    @memset(input_ids, pad_token);
    @memcpy(input_ids[0..tokens.len], tokens);

    return input_ids;
}

fn bool_parser(in: []const u8) error{}!bool {
    return in.len > 0 and std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
}

fn printUsageAndExit(stderr: std.fs.File) noreturn {
    clap.usageToFile(stderr, clap.Help, &params) catch {};
    std.process.exit(1);
}
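
// Example invocation (a sketch, not part of the original source): the binary name
// and file paths below are placeholders, while the flags are the ones declared in
// `params` above. With the default prompt, the program logs the five
// highest-scoring replacements for the first [MASK] token.
//
//   modernbert --model=model.safetensors \
//       --tokenizer=tokenizer.json \
//       --seq-len=256 \
//       --text="Paris is the [MASK] of France."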