const clap = @import("clap"); const std = @import("std"); const zml = @import("zml"); const asynk = @import("async"); const log = std.log; const Tensor = zml.Tensor; const modernbert_module = @import("modernbert.zig"); const ModernBertOptions = modernbert_module.ModernBertOptions; const params = clap.parseParamsComptime( \\--help print this help \\--model model weights path \\--activations model activations path ); fn printUsageAndExit(stderr: anytype) noreturn { stderr.print("usage: ", .{}) catch {}; clap.usage(stderr, clap.Help, ¶ms) catch {}; stderr.print("\n", .{}) catch {}; std.process.exit(0); } pub fn main() !void { try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain); } pub fn asyncMain() !void { // Short lived allocations var gpa = std.heap.GeneralPurposeAllocator(.{}){}; defer _ = gpa.deinit(); const allocator = gpa.allocator(); const stderr = std.io.getStdErr().writer(); // Read CLI arguments const parsers = comptime .{ .PATH = clap.parsers.string, }; var diag: clap.Diagnostic = .{}; var res = clap.parse(clap.Help, ¶ms, parsers, .{ .diagnostic = &diag, .allocator = allocator, }) catch |err| { try diag.report(stderr, err); try printUsageAndExit(stderr); }; defer res.deinit(); if (res.args.help != 0) { try clap.help(stderr, clap.Help, ¶ms, .{}); return; } const model_file = res.args.model orelse { stderr.print("Error: missing --model=...\n\n", .{}) catch {}; printUsageAndExit(stderr); unreachable; }; const activations_file = res.args.activations orelse { stderr.print("Error: missing --activations=...\n\n", .{}) catch {}; printUsageAndExit(stderr); unreachable; }; // Initialize the ZML context var context = try zml.Context.init(); defer context.deinit(); // Auto-select platform const compute_platform = context.autoPlatform(.{}); log.info("Selected platform: {s}", .{@tagName(compute_platform.target)}); // Create a dedicated memory arena for model-related allocations (dedicated to model shapes and weights) var arena_state = std.heap.ArenaAllocator.init(allocator); defer arena_state.deinit(); const model_arena = arena_state.allocator(); // Load the model weights file and parse its structure (shape) var weights_file = try zml.aio.detectFormatAndOpen(allocator, model_file); defer weights_file.deinit(); log.info("Model contains {d} layers. Loaded from: {s}", .{ weights_file.buffers.count(), model_file }); // Load the activation data file const activations = try zml.aio.torch.open(model_arena, activations_file); defer activations.deinit(); log.info("Found {} activations in {s}", .{ activations.buffers.count(), activations_file }); // Initialize model var model = try zml.aio.populateModel( modernbert_module.ModernBertForMaskedLM, model_arena, weights_file, ); const modernbert_base_options: modernbert_module.ModernBertOptions = .{ .num_attention_heads = 12, .tie_word_embeddings = true, .pad_token = 50283, .local_attention = 128, }; model.init(modernbert_base_options); // Load model weights const model_weights = try zml.aio.loadModelBuffers( modernbert_module.ModernBertForMaskedLM, model, weights_file, model_arena, compute_platform, ); // Test implementation try testImplementation(compute_platform, model, model_weights, activations); } fn testImplementation( compute_platform: zml.Platform, model: modernbert_module.ModernBertForMaskedLM, model_weights: zml.Bufferized(modernbert_module.ModernBertForMaskedLM), activations: zml.aio.BufferStore, ) !void { try zml.testing.testLayer( compute_platform, activations, "model.model.embeddings.tok_embeddings", model.model.embeddings.tok_embeddings, model_weights.model.embeddings.tok_embeddings, 1e-6, ); try zml.testing.testLayer( compute_platform, activations, "model.model.embeddings.norm", model.model.embeddings.norm, model_weights.model.embeddings.norm, 1e-3, ); try zml.testing.testLayer( compute_platform, activations, "model.model.embeddings", model.model.embeddings, model_weights.model.embeddings, 1e-3, ); try zml.testing.testLayer( compute_platform, activations, "model.model.final_norm", model.model.final_norm, model_weights.model.final_norm, 1e-5, ); try zml.testing.testLayer( compute_platform, activations, "model.model.layers.2.mlp", model.model.layers[2].mlp, model_weights.model.layers[2].mlp, 2e-3, ); try zml.testing.testLayer( compute_platform, activations, "model.model.layers.2.mlp_norm", model.model.layers[2].mlp_norm, model_weights.model.layers[2].mlp_norm, 1e-4, ); try zml.testing.testLayer( compute_platform, activations, "model.model.layers.2.attn", model.model.layers[2].attn, model_weights.model.layers[2].attn, 1e-6, ); try zml.testing.testLayer( compute_platform, activations, "model.model.layers.2", model.model.layers[2], model_weights.model.layers[2], 2e-3, ); try zml.testing.testLayer( compute_platform, activations, "model.model.layers.3.attn", model.model.layers[3].attn, model_weights.model.layers[3].attn, 1e-5, ); try zml.testing.testLayer( compute_platform, activations, "model.model", model.model, model_weights.model, 1e-2, ); const TiedDecoder = struct { weight: Tensor, bias: Tensor, pub fn forward(self: @This(), head_outputs: Tensor) Tensor { const results = head_outputs.withTags(.{ .b, .s, .d }).dot(self.weight.withTags(.{ .voc, .d }), .{.d}); return results.add(self.bias.withTags(.{.voc}).broad(results.shape())); } }; try zml.testing.testLayer( compute_platform, activations, "model.decoder", TiedDecoder{ .weight = model.decoder.weight orelse model.model.embeddings.tok_embeddings.weight, .bias = model.decoder.bias }, .{ .weight = model_weights.model.embeddings.tok_embeddings.weight, .bias = model_weights.decoder.bias }, 1e-3, ); try zml.testing.testLayer( compute_platform, activations, "model.head", model.head, model_weights.head, 0.1, // TODO: too high tolerance ); }