From 17d02621e7a2c59f031af9670291227c4f716f32 Mon Sep 17 00:00:00 2001 From: Foke Singh Date: Tue, 11 Jun 2024 17:33:22 +0000 Subject: [PATCH] Add tests for the ModernBERT example, covering activation utilities, build setup, and example Zig code. --- examples/MODULE.bazel | 47 +++ examples/modernbert/BUILD.bazel | 63 ++++ .../modernbert/activations/activations.py | 55 ++++ .../modernbert/activations/requirements.in | 4 + examples/modernbert/main.zig | 275 ++++++++++++++++++ examples/modernbert/modernbert.zig | 268 +++++++++++++++++ examples/modernbert/test.zig | 238 +++++++++++++++ 7 files changed, 950 insertions(+) create mode 100644 examples/modernbert/BUILD.bazel create mode 100644 examples/modernbert/activations/activations.py create mode 100644 examples/modernbert/activations/requirements.in create mode 100644 examples/modernbert/main.zig create mode 100644 examples/modernbert/modernbert.zig create mode 100644 examples/modernbert/test.zig diff --git a/examples/MODULE.bazel b/examples/MODULE.bazel index 9a20ad9..0c4b4c8 100644 --- a/examples/MODULE.bazel +++ b/examples/MODULE.bazel @@ -139,6 +139,53 @@ http_file( url = "https://github.com/karpathy/llama2.c/raw/c02865df300f3bd9e567ce061000dc23bf785a17/tokenizer.bin", ) +# ModernBERT +huggingface.model( + name = "ModernBERT-base", + build_file_content = """\ +package(default_visibility = ["//visibility:public"]) +filegroup( + name = "model", + srcs = ["model.safetensors"], +) + +filegroup( + name = "tokenizer", + srcs = ["tokenizer.json"], +) +""", + commit = "94032bb66234a691cf6248265170006a7ced4970", + includes = [ + "model.safetensors", + "tokenizer.json", + ], + model = "answerdotai/ModernBERT-base", +) +use_repo(huggingface, "ModernBERT-base") + +huggingface.model( + name = "ModernBERT-large", + build_file_content = """\ +package(default_visibility = ["//visibility:public"]) +filegroup( + name = "model", + srcs = ["model.safetensors"], +) + +filegroup( + name = "tokenizer", + srcs = ["tokenizer.json"], +) 
+""", + commit = "4bbcbf40bed02ce487125bcb3c897ea9bdc88340", + includes = [ + "model.safetensors", + "tokenizer.json", + ], + model = "answerdotai/ModernBERT-large", +) +use_repo(huggingface, "ModernBERT-large") + bazel_dep(name = "rules_rust", version = "0.57.1") rust = use_extension("@rules_rust//rust:extensions.bzl", "rust") rust.toolchain( diff --git a/examples/modernbert/BUILD.bazel b/examples/modernbert/BUILD.bazel new file mode 100644 index 0000000..eee6e51 --- /dev/null +++ b/examples/modernbert/BUILD.bazel @@ -0,0 +1,63 @@ +load("@zml//bazel:zig.bzl", "zig_cc_binary") + +zig_cc_binary( + name = "modernbert", + srcs = ["modernbert.zig"], + main = "main.zig", + deps = [ + "@com_github_hejsil_clap//:clap", + "@zml//async", + "@zml//stdx", + "@zml//zml", + ], +) + +cc_binary( + name = "ModernBERT-base", + args = [ + "--model=$(location @ModernBERT-base//:model.safetensors)", + "--tokenizer=$(location @ModernBERT-base//:tokenizer)", + "--num-attention-heads=12", + "--tie-word-embeddings=true", + ], + data = [ + "@ModernBERT-base//:model.safetensors", + "@ModernBERT-base//:tokenizer", + ], + deps = [":modernbert_lib"], +) + +cc_binary( + name = "ModernBERT-large", + args = [ + "--model=$(location @ModernBERT-large//:model.safetensors)", + "--tokenizer=$(location @ModernBERT-large//:tokenizer)", + "--num-attention-heads=16", + "--tie-word-embeddings=true", + ], + data = [ + "@ModernBERT-large//:model.safetensors", + "@ModernBERT-large//:tokenizer", + ], + deps = [":modernbert_lib"], +) + +zig_cc_binary( + name = "test-implementation", + srcs = ["modernbert.zig"], + args = [ + "--model=$(location @ModernBERT-base//:model.safetensors)", + ], + data = [ + "@ModernBERT-base//:model.safetensors", + ], + main = "test.zig", + tags = [ + "no_ci", + ], + deps = [ + "@com_github_hejsil_clap//:clap", + "@zml//async", + "@zml//zml", + ], +) diff --git a/examples/modernbert/activations/activations.py b/examples/modernbert/activations/activations.py new file mode 100644 index 
0000000..b82416d --- /dev/null +++ b/examples/modernbert/activations/activations.py @@ -0,0 +1,55 @@ +import logging +import torch +from transformers import pipeline +from tools.zml_utils import ActivationCollector + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s" +) +log = logging.getLogger(__name__) + +MODEL_NAME: str = "answerdotai/ModernBERT-base" + + +def main() -> None: + try: + log.info("Start running main()") + + log.info(f"CPU capability : `{torch.backends.cpu.get_cpu_capability()}`") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + log.info(f"Loading model : `{MODEL_NAME}`") + + fill_mask_pipeline = pipeline( + "fill-mask", + model=MODEL_NAME, + device_map=device, + ) + model, tokenizer = fill_mask_pipeline.model, fill_mask_pipeline.tokenizer + log.info( + f"Model loaded successfully {model.config.architectures} - `{model.config.torch_dtype}` - {tokenizer.model_max_length} max tokens" # noqa: E501 + ) + + # Wrap the pipeline, and extract activations. + # Activations files can be huge for big models, + # so let's stop collecting after 1000 layers. + zml_pipeline = ActivationCollector( + fill_mask_pipeline, max_layers=1000, stop_after_first_step=True + ) + + input_text = "Paris is the [MASK] of France." 
+ outputs, activations = zml_pipeline(input_text) + log.info(f"outputs : {outputs}") + + filename = MODEL_NAME.split("/")[-1] + ".activations.pt" + torch.save(activations, filename) + log.info(f"Saved {len(activations)} activations to {filename}") + + log.info("End running main()") + except Exception as exception: + log.error(exception) + raise + + + if __name__ == "__main__": + main() diff --git a/examples/modernbert/activations/requirements.in b/examples/modernbert/activations/requirements.in new file mode 100644 index 0000000..415d610 --- /dev/null +++ b/examples/modernbert/activations/requirements.in @@ -0,0 +1,4 @@ +torch +transformers==4.48.1 +accelerate +numpy==1.26.4 \ No newline at end of file diff --git a/examples/modernbert/main.zig b/examples/modernbert/main.zig new file mode 100644 index 0000000..c0156d0 --- /dev/null +++ b/examples/modernbert/main.zig @@ -0,0 +1,275 @@ +const std = @import("std"); +const log = std.log.scoped(.modernbert); + +const modernbert = @import("modernbert.zig"); + +const asynk = @import("async"); +const clap = @import("clap"); +const stdx = @import("stdx"); +const zml = @import("zml"); +const Tensor = zml.Tensor; + +pub const std_options = .{ + .log_level = .info, + .log_scope_levels = &[_]std.log.ScopeLevel{ + .{ .scope = .modernbert, .level = .info }, + }, + .logFn = asynk.logFn(std.log.defaultLog), +}; + +const params = clap.parseParamsComptime( + \\--help print this help + \\--text the prompt + \\--model model path + \\--tokenizer tokenizer path + \\--seq-len sequence length + \\--num-attention-heads number of attention heads + \\--tie-word-embeddings default: false: tied weights + \\--create-options platform creation options JSON, defaults to {} + \\--sharding default: true: sharding on or off +); + +const clap_parsers = .{ + .BOOL = bool_parser, + .UINT = clap.parsers.int(usize, 0), + .STRING = clap.parsers.string, + .PATH = clap.parsers.string, +}; + +pub fn main() !void { + try 
asynk.AsyncThread.main(std.heap.c_allocator, asyncMain); +} + +pub fn asyncMain() !void { + const allocator = std.heap.c_allocator; + const stderr = std.io.getStdErr().writer(); + + var diag: clap.Diagnostic = .{}; + var cli = clap.parse(clap.Help, &params, clap_parsers, .{ + .diagnostic = &diag, + .allocator = allocator, + }) catch |err| { + try diag.report(stderr, err); + try printUsageAndExit(stderr); + }; + defer cli.deinit(); + + if (cli.args.help != 0) { + try clap.help(stderr, clap.Help, &params, .{}); + return; + } + + const tmp = try std.fs.openDirAbsolute("/tmp", .{}); + try tmp.makePath("zml/modernbert/cache"); + + // Create ZML context + var context = try zml.Context.init(); + defer context.deinit(); + + // Platform and compilation options + const create_opts_json = cli.args.@"create-options" orelse "{}"; + const create_opts = try std.json.parseFromSliceLeaky(zml.Platform.CreateOptions, allocator, create_opts_json, .{}); + const compilation_options = zml.CompilationOptions{ + .xla_dump_to = "/tmp/zml/modernbert", + .sharding_enabled = cli.args.sharding orelse true, + }; + + // Auto-select platform + const platform = context.autoPlatform(create_opts).withCompilationOptions(compilation_options); + context.printAvailablePlatforms(platform); + + // Detects the format of the model file (based on filename) and open it. 
+ const model_file = cli.args.model orelse { + stderr.print("Error: missing --model=...\n\n", .{}) catch {}; + printUsageAndExit(stderr); + unreachable; + }; + var tensor_store = try zml.aio.detectFormatAndOpen(allocator, model_file); + defer tensor_store.deinit(); + + // Memory arena dedicated to model shapes and weights + var arena_state = std.heap.ArenaAllocator.init(allocator); + defer arena_state.deinit(); + const model_arena = arena_state.allocator(); + + var tokenizer = blk: { + if (cli.args.tokenizer) |tok| { + log.info("\tLoading tokenizer from {s}", .{tok}); + var timer = try stdx.time.Timer.start(); + defer log.info("✅\tLoaded tokenizer from {s} [{}]", .{ tok, timer.read() }); + + break :blk try zml.tokenizer.Tokenizer.fromFile(model_arena, tok); + } else { + log.err("Error: missing --tokenizer", .{}); + return; + } + }; + defer tokenizer.deinit(); + + // Create the model struct, with tensor shapes extracted from the tensor_store + // TODO: read from config.json + const modernbert_options = modernbert.ModernBertOptions{ + .pad_token = tokenizer.tokenToId("[PAD]") orelse return error.NoSuchToken, + .num_attention_heads = @intCast(cli.args.@"num-attention-heads" orelse 12), + .tie_word_embeddings = cli.args.@"tie-word-embeddings" orelse false, + .local_attention = 128, + }; + var modern_bert_for_masked_lm = try zml.aio.populateModel(modernbert.ModernBertForMaskedLM, model_arena, tensor_store); + modern_bert_for_masked_lm.init(modernbert_options); + + log.info("\tModernBERT options: {}", .{modernbert_options}); + + // Prepare shapes for compilation + const seq_len = @as(i64, @intCast(cli.args.@"seq-len" orelse 256)); + const input_shape = zml.Shape.init(.{ .b = 1, .s = seq_len }, .u32); + + var start = try std.time.Timer.start(); + + // Load weights + log.info("\tLoading ModernBERT weights from {?s}...", .{model_file}); + var bert_weights = try zml.aio.loadBuffers(modernbert.ModernBertForMaskedLM, .{modernbert_options}, tensor_store, model_arena, platform); 
+ defer zml.aio.unloadBuffers(&bert_weights); + log.info("✅\tLoaded weights in {d}ms", .{start.read() / std.time.ns_per_ms}); + + // Compile the model + log.info("\tCompiling ModernBERT model...", .{}); + var fut_mod = try asynk.asyncc(zml.compile, .{ + allocator, + modernbert.ModernBertForMaskedLM.forward, + .{modernbert_options}, + .{input_shape}, + tensor_store, + platform, + }); + var bert_module = (try fut_mod.awaitt()).prepare(bert_weights); + defer bert_module.deinit(); + log.info("✅\tLoaded weights and compiled model in {d}ms", .{start.read() / std.time.ns_per_ms}); + + const text = cli.args.text orelse "Paris is the [MASK] of France."; + log.info("\tInput text: {s}", .{text}); + + try unmask(allocator, bert_module, tokenizer, seq_len, text); +} + +/// fill-mask pipeline +/// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/fill_mask.py +pub fn unmask( + allocator: std.mem.Allocator, + mod: zml.ModuleExe(modernbert.ModernBertForMaskedLM.forward), + tokenizer: zml.tokenizer.Tokenizer, + seq_len: i64, + text: []const u8, +) !void { + var tokenizer_decoder = try tokenizer.decoder(); + defer tokenizer_decoder.deinit(); + + const pad_token = tokenizer.tokenToId("[PAD]") orelse return error.NoSuchToken; + const mask_token = tokenizer.tokenToId("[MASK]") orelse return error.NoSuchToken; + + // Tokenize input text + const tokens: []const u32 = try tokenize(allocator, tokenizer, text); + defer allocator.free(tokens); + + // Find "[MASK]" positions + const mask_positions = try findMaskPositions(allocator, tokens, mask_token); + defer allocator.free(mask_positions); + + // Prepare input tensors + const inputs = try prepareTensorInputs(allocator, tokens, seq_len, pad_token); + defer allocator.free(inputs); + + // Create input tensors (on the accelerator) + const input_shape = zml.Shape.init(.{ .b = 1, .s = seq_len }, .i64); + const input_ids_tensor = try zml.Buffer.fromSlice(mod.platform(), input_shape.dims(), inputs); + defer 
input_ids_tensor.deinit(); + + // Model inference (retrieve indices) + var inference_timer = try std.time.Timer.start(); + var topk = mod.call(.{input_ids_tensor}); + defer zml.aio.unloadBuffers(&topk); + const inference_time = inference_timer.read(); + + // Transfer the result to host memory (CPU) + var indices_host_buffer = try topk.indices.toHostAlloc(allocator); + defer indices_host_buffer.deinit(allocator); + var values_host_buffer = try topk.values.toHostAlloc(allocator); + defer values_host_buffer.deinit(allocator); + + // We consider only the first occurrence of [MASK], which has five predictions + const pred_offset = mask_positions[0] * 5; + const predictions = indices_host_buffer.items(i32)[pred_offset..][0..5]; + const scores = values_host_buffer.items(f32)[pred_offset..][0..5]; + + // Log timing information + log.info("⏱️\tModel inference in {d}ms", .{inference_time / std.time.ns_per_ms}); + + log.info("✅\tTop 5 predictions:", .{}); + for (predictions, scores) |token_id, score| { + const token_text = try tokenizer_decoder.next(@intCast(token_id)); + if (token_text) |word| { + log.info("\t • score: {d:.4} word: '{s}' token: {}", .{ score, word, token_id }); + } + } +} + +pub fn tokenize(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, prompt: []const u8) ![]const u32 { + var tokens = std.ArrayList(u32).init(allocator); + var encoder = try tokenizer.encoder(); + defer encoder.deinit(); + + const bos = tokenizer.tokenToId("[CLS]") orelse return error.NoSuchToken; + const eos = tokenizer.tokenToId("[SEP]") orelse return error.NoSuchToken; + + try tokens.append(bos); + try tokens.appendSlice(try encoder.encode(prompt)); + try tokens.append(eos); + + return tokens.toOwnedSlice(); +} + +fn findMaskPositions(allocator: std.mem.Allocator, tokens: []const u32, mask_token: u32) ![]usize { + var mask_positions = std.ArrayList(usize).init(allocator); + defer mask_positions.deinit(); + + for (tokens, 0..) 
|token, i| { + if (token == mask_token) { + try mask_positions.append(i); + } + } + + if (mask_positions.items.len == 0) { + log.err("Input text must contain `[MASK]`", .{}); + return error.InvalidInput; + } + + if (mask_positions.items.len > 1) log.warn("Currently only supporting one [MASK] per input", .{}); + + return mask_positions.toOwnedSlice(); +} + +fn prepareTensorInputs( + allocator: std.mem.Allocator, + tokens: []const u32, + seq_len: i64, + pad_token: u32, +) ![]u32 { + const input_ids = try allocator.alloc(u32, @intCast(seq_len)); + + @memset(input_ids, pad_token); + for (tokens, 0..) |token, i| { + input_ids[i] = @intCast(token); + } + + return input_ids; +} + +fn bool_parser(in: []const u8) error{}!bool { + return std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null; +} + +fn printUsageAndExit(stderr: anytype) noreturn { + stderr.print("usage: ", .{}) catch {}; + clap.usage(stderr, clap.Help, &params) catch {}; + stderr.print("\n", .{}) catch {}; + std.process.exit(0); +} diff --git a/examples/modernbert/modernbert.zig b/examples/modernbert/modernbert.zig new file mode 100644 index 0000000..4a1182b --- /dev/null +++ b/examples/modernbert/modernbert.zig @@ -0,0 +1,268 @@ +const std = @import("std"); +const log = std.log.scoped(.modernbert); + +const asynk = @import("async"); +const stdx = @import("stdx"); +const zml = @import("zml"); + +const Tensor = zml.Tensor; + +pub const ModernBertOptions = struct { + num_attention_heads: i64, + pad_token: u32, + local_attention: u32, + tie_word_embeddings: bool = false, +}; + +pub const ModernBertForMaskedLM = struct { + model: ModernBertModel, + head: ModernBertPredictionHead, + decoder: struct { weight: ?zml.Tensor, bias: zml.Tensor }, + + pub fn init(self: *ModernBertForMaskedLM, options: ModernBertOptions) void { + self.model.init(options); + self.head.norm.eps = 1e-5; + + self.head.dense.weight = self.head.dense.weight.withSharding(.{0}); + + if (options.tie_word_embeddings == true) { + self.decoder.weight = null; 
+ } else if (self.decoder.weight) |decoder_weight| { + self.decoder.weight = decoder_weight.withSharding(.{1}); + } + } + + pub fn forward(self: ModernBertForMaskedLM, input_ids: Tensor) zml.Tensor.ArgMaxRes { + const outputs: Tensor = zml.call(self.model, .forward, .{input_ids}); + const head_outputs: Tensor = zml.call(self.head, .forward, .{outputs}); + + // either use decoder or tied weights + const decoder_weights = self.decoder.weight orelse self.model.embeddings.tok_embeddings.weight; + + const logits = head_outputs.withTags(.{ .b, .s, .d }).dot(decoder_weights.withTags(.{ .voc, .d }), .{.d}); + const biased_logits = logits.add(self.decoder.bias.withTags(.{.voc}).broad(logits.shape())); + + const probabilities = biased_logits.softmax(.voc); + return probabilities.topK(5, .voc, .{ .descending = true }); + } +}; + +pub const ModernBertModel = struct { + options: ModernBertOptions, + embeddings: ModernBertEmbeddings, + layers: []ModernBertEncoderLayer, + final_norm: zml.nn.LayerNorm, + + pub fn init(self: *ModernBertModel, options: ModernBertOptions) void { + self.options = options; + self.final_norm.eps = 1e-5; + for (self.layers, 0..) 
|*encoder_layer, layer_idx| { + encoder_layer.attn.Wqkv.weight = encoder_layer.attn.Wqkv.weight.withSharding(.{0}); + encoder_layer.attn.Wo.weight = encoder_layer.attn.Wo.weight.withSharding(.{1}); + + encoder_layer.mlp.Wi.weight = encoder_layer.mlp.Wi.weight.withSharding(.{0}); + encoder_layer.mlp.Wo.weight = encoder_layer.mlp.Wo.weight.withSharding(.{1}); + + if (encoder_layer.attn_norm) |*norm| norm.eps = 1e-5; + encoder_layer.mlp_norm.eps = 1e-5; + encoder_layer.attn.is_global_attention = (layer_idx % 3 == 0); + encoder_layer.attn.num_heads = options.num_attention_heads; + } + } + + pub fn forward(self: ModernBertModel, input_ids: Tensor) Tensor { + var hidden_states: Tensor = zml.call(self.embeddings, .forward, .{input_ids}).withTags(.{ .b, .src, .d }); + + const global_mask = globalAttnMask(input_ids, hidden_states.dtype(), self.options.pad_token); + const local_mask = localAttnMask(global_mask, self.options.local_attention); + + // Process through all encoder layers + for (self.layers) |encoder_layer| { + hidden_states = zml.call(encoder_layer, .forward, .{ + hidden_states, + global_mask, + local_mask, + }); + } + + // Final layer normalization + hidden_states = zml.call(self.final_norm, .forward, .{hidden_states}); + + return hidden_states; + } + + /// Find [PAD] tokens in inputs, and assign them a -inf attention mask. + /// Output shapes follows zml.nn.sdpa convention: .{ .b, .q, .k } + pub fn globalAttnMask(input_ids: Tensor, dt: zml.DataType, pad_token: u32) Tensor { + const ids = input_ids.withTags(.{ .b, .k }); + + // Mask keys where corresponding token is [PAD] + const padding = ids.cmp(.EQ, Tensor.scalar(pad_token, ids.dtype())); + const pad_mask = padding.select(Tensor.constant(.{}, dt.minValue()), Tensor.constant(.{}, dt.zero())); + + // Broadcast to the desired output shape. 
+ const seq_len = ids.dim(.k); + const pad_mask_shape = zml.Shape.init(.{ .b = pad_mask.dim(.b), .q = seq_len, .k = seq_len }, dt); + return pad_mask.broad(pad_mask_shape).print(); + } + + /// Restrict global attn mask to a sliding window. + /// Output shapes follows zml.nn.sdpa convention: .{ .b, .q, .k } + pub fn localAttnMask(global_mask: Tensor, window_size: u32) Tensor { + const mask_shape = global_mask.shape(); + + // Calculate distance between positions + const rows = Tensor.iota(mask_shape, .q); + const cols = Tensor.iota(mask_shape, .k); + const distance = rows.sub(cols).abs(); + + // Note: we divide by two because the BERT local attention is symetric around the query token. + // Create sliding window mask (1 for positions within window, 0 outside) + const window_mask = distance.cmp(.LE, Tensor.scalar(@divExact(window_size, 2), .i32)); + const minus_inf = Tensor.constant(mask_shape, mask_shape.dtype().minValue()); + return window_mask.select(global_mask, minus_inf).print(); + } +}; + +pub const ModernBertPredictionHead = struct { + dense: zml.nn.Linear, + norm: zml.nn.LayerNorm, + + pub fn forward(self: ModernBertPredictionHead, hidden_states: Tensor) Tensor { + const dense_output: Tensor = zml.call(self.dense, .forward, .{hidden_states}); + + const activated_output = dense_output.gelu(); + + return zml.call(self.norm, .forward, .{activated_output}); + } +}; + +pub const ModernBertEmbeddings = struct { + tok_embeddings: zml.nn.TokenEmbedding, + norm: zml.nn.LayerNorm, + + pub fn forward(self: ModernBertEmbeddings, input_ids: Tensor) Tensor { + // Perform tok_embeddings + const hidden_states = zml.call(self.tok_embeddings, .forward, .{input_ids}); + + // Perform norm + return zml.call(self.norm, .forward, .{hidden_states}); + } +}; + +pub const ModernBertEncoderLayer = struct { + attn_norm: ?zml.nn.LayerNorm = null, + attn: ModernBertAttention, + mlp_norm: zml.nn.LayerNorm, + mlp: ModernBertMLP, + + pub fn forward( + self: ModernBertEncoderLayer, + 
hidden_states: Tensor, + global_mask: Tensor, + local_mask: Tensor, + ) Tensor { + const attn_norm_output = if (self.attn_norm) |attn_norm| + zml.call(attn_norm, .forward, .{hidden_states}) + else + hidden_states; + + const attn_output: Tensor = zml.call(self.attn, .forward, .{ + attn_norm_output, + global_mask, + local_mask, + }); + + var output = hidden_states.add(attn_output); + + const mlp_norm_output: Tensor = zml.call(self.mlp_norm, .forward, .{output}); + const mlp_output = zml.call(self.mlp, .forward, .{mlp_norm_output}); + output = output.add(mlp_output); + + return output; + } +}; + +/// Performs multi-headed self attention on a batch of unpadded sequences. +/// +/// If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput. +/// If Flash Attention 2 is not installed, the implementation will use SDPA, +pub const ModernBertAttention = struct { + Wqkv: zml.nn.Linear, + Wo: zml.nn.Linear, + is_global_attention: bool = false, + num_heads: i64 = undefined, + + /// sdpa_attention_forward + pub fn forward( + self: ModernBertAttention, + hidden_states: Tensor, + global_mask: Tensor, + local_mask: Tensor, + ) Tensor { + const batch_size = hidden_states.shape().dim(0); + const seq_length = hidden_states.shape().dim(1); + const hidden_size = hidden_states.shape().dim(2); + const num_heads = self.num_heads; + const head_dim = @divExact(hidden_size, num_heads); + + // Project to query, key, value - { batch_size, seq_len, 3 * num_heads * head_dim } + var qkv: Tensor = zml.call(self.Wqkv, .forward, .{hidden_states}); + + // Reshape to { batch_size, seq_len, 3, num_heads, head_dim } + qkv = qkv.reshape(.{ batch_size, seq_length, 3, num_heads, head_dim }).withTags(.{ .b, .s, .chunk, .h, .hd }); + + // Split into query, key, value tensors - each { batch_size, seq_length, num_heads, head_dim } + var q, var k, var v = qkv.chunkExact(.chunk, 3); + q = q.squeeze(.chunk); + k = k.squeeze(.chunk); + v = v.squeeze(.chunk); + + // Apply rotary 
position embeddings (RoPE) + // Layer 0, 3, 6, 9, 12 ... use global RoPE + // Layer 1, 2, 4, 5, 7, 8, 10, 11 ... use local RoPE + const rope_opts = zml.nn.RopeOpts{ + .impl = .sequential, + .freq_base = if (self.is_global_attention) 160_000 else 10_000, + }; + + q = zml.nn.rope(q, null, rope_opts); + k = zml.nn.rope(k, null, rope_opts); + + // rename dimensions for sdpa + q = q.rename(.{ .s = .q }); + k = k.rename(.{ .s = .k }); + v = v.rename(.{ .s = .k }); + + // Scaled dot product attention + const attn_output = zml.nn.sdpa(q, k, v, .{ .attn_mask = if (self.is_global_attention) global_mask else local_mask }); + const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s }); + + // Final projection + return zml.call(self.Wo, .forward, .{attn}); + } +}; + +/// Switch out the old MLP layers for GeGLU layers, improving on the original BERT’s GeLU activation function. +/// +/// The GeGLU activation function is a combination of the Gated Linear Unit (GLU) and the Gaussian Error Linear Unit (GeLU). 
+/// +/// see: https://paperswithcode.com/method/geglu +pub const ModernBertMLP = struct { + Wi: zml.nn.Linear, + Wo: zml.nn.Linear, + + pub fn forward(self: ModernBertMLP, hidden_states: Tensor) Tensor { + // Perform Wi + const wi_output: Tensor = zml.call(self.Wi, .forward, .{hidden_states}); + + // Split into input and gate tensors along the last dimension + const input, const gate = wi_output.chunkExact(-1, 2); + + // Apply activation + const activated_input = input.gelu().mul(gate); + + // Perform Wo + return zml.call(self.Wo, .forward, .{activated_input}); + } +}; diff --git a/examples/modernbert/test.zig b/examples/modernbert/test.zig new file mode 100644 index 0000000..8dd8bbb --- /dev/null +++ b/examples/modernbert/test.zig @@ -0,0 +1,238 @@ +const clap = @import("clap"); +const std = @import("std"); +const zml = @import("zml"); +const asynk = @import("async"); +const log = std.log; +const Tensor = zml.Tensor; +const modernbert_module = @import("modernbert.zig"); +const ModernBertOptions = modernbert_module.ModernBertOptions; + +const params = clap.parseParamsComptime( + \\--help print this help + \\--model model weights path + \\--activations model activations path +); + +fn printUsageAndExit(stderr: anytype) noreturn { + stderr.print("usage: ", .{}) catch {}; + clap.usage(stderr, clap.Help, &params) catch {}; + stderr.print("\n", .{}) catch {}; + std.process.exit(0); +} +pub fn main() !void { + try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain); +} + +pub fn asyncMain() !void { + // Short lived allocations + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + defer _ = gpa.deinit(); + + const allocator = gpa.allocator(); + const stderr = std.io.getStdErr().writer(); + + // Read CLI arguments + const parsers = comptime .{ + .PATH = clap.parsers.string, + }; + var diag: clap.Diagnostic = .{}; + var res = clap.parse(clap.Help, &params, parsers, .{ + .diagnostic = &diag, + .allocator = allocator, + }) catch |err| { + try diag.report(stderr, err); + try 
printUsageAndExit(stderr); + }; + defer res.deinit(); + + if (res.args.help != 0) { + try clap.help(stderr, clap.Help, &params, .{}); + return; + } + + const model_file = res.args.model orelse { + stderr.print("Error: missing --model=...\n\n", .{}) catch {}; + printUsageAndExit(stderr); + unreachable; + }; + const activations_file = res.args.activations orelse { + stderr.print("Error: missing --activations=...\n\n", .{}) catch {}; + printUsageAndExit(stderr); + unreachable; + }; + + // Initialize the ZML context + var context = try zml.Context.init(); + defer context.deinit(); + + // Auto-select platform + const compute_platform = context.autoPlatform(.{}); + log.info("Selected platform: {s}", .{@tagName(compute_platform.target)}); + + // Create a dedicated memory arena for model-related allocations (dedicated to model shapes and weights) + var arena_state = std.heap.ArenaAllocator.init(allocator); + defer arena_state.deinit(); + const model_arena = arena_state.allocator(); + + // Load the model weights file and parse its structure (shape) + var weights_file = try zml.aio.detectFormatAndOpen(allocator, model_file); + defer weights_file.deinit(); + log.info("Model contains {d} layers. 
Loaded from: {s}", .{ weights_file.buffers.count(), model_file }); + + // Load the activation data file + const activations = try zml.aio.torch.open(model_arena, activations_file); + defer activations.deinit(); + log.info("Found {} activations in {s}", .{ activations.buffers.count(), activations_file }); + + // Initialize model + var model = try zml.aio.populateModel( + modernbert_module.ModernBertForMaskedLM, + model_arena, + weights_file, + ); + + const modernbert_base_options: modernbert_module.ModernBertOptions = .{ + .num_attention_heads = 12, + .tie_word_embeddings = true, + .pad_token = 50283, + .local_attention = 128, + }; + model.init(modernbert_base_options); + + // Load model weights + const model_weights = try zml.aio.loadModelBuffers( + modernbert_module.ModernBertForMaskedLM, + model, + weights_file, + model_arena, + compute_platform, + ); + + // Test implementation + try testImplementation(compute_platform, model, model_weights, activations); +} + +fn testImplementation( + compute_platform: zml.Platform, + model: modernbert_module.ModernBertForMaskedLM, + model_weights: zml.Bufferized(modernbert_module.ModernBertForMaskedLM), + activations: zml.aio.BufferStore, +) !void { + try zml.testing.testLayer( + compute_platform, + activations, + "model.model.embeddings.tok_embeddings", + model.model.embeddings.tok_embeddings, + model_weights.model.embeddings.tok_embeddings, + 1e-6, + ); + + try zml.testing.testLayer( + compute_platform, + activations, + "model.model.embeddings.norm", + model.model.embeddings.norm, + model_weights.model.embeddings.norm, + 1e-3, + ); + + try zml.testing.testLayer( + compute_platform, + activations, + "model.model.embeddings", + model.model.embeddings, + model_weights.model.embeddings, + 1e-3, + ); + + try zml.testing.testLayer( + compute_platform, + activations, + "model.model.final_norm", + model.model.final_norm, + model_weights.model.final_norm, + 1e-5, + ); + + try zml.testing.testLayer( + compute_platform, + activations, + 
"model.model.layers.2.mlp", + model.model.layers[2].mlp, + model_weights.model.layers[2].mlp, + 2e-3, + ); + + try zml.testing.testLayer( + compute_platform, + activations, + "model.model.layers.2.mlp_norm", + model.model.layers[2].mlp_norm, + model_weights.model.layers[2].mlp_norm, + 1e-4, + ); + + try zml.testing.testLayer( + compute_platform, + activations, + "model.model.layers.2.attn", + model.model.layers[2].attn, + model_weights.model.layers[2].attn, + 1e-6, + ); + + try zml.testing.testLayer( + compute_platform, + activations, + "model.model.layers.2", + model.model.layers[2], + model_weights.model.layers[2], + 2e-3, + ); + + try zml.testing.testLayer( + compute_platform, + activations, + "model.model.layers.3.attn", + model.model.layers[3].attn, + model_weights.model.layers[3].attn, + 1e-5, + ); + + try zml.testing.testLayer( + compute_platform, + activations, + "model.model", + model.model, + model_weights.model, + 1e-2, + ); + + const TiedDecoder = struct { + weight: Tensor, + bias: Tensor, + + pub fn forward(self: @This(), head_outputs: Tensor) Tensor { + const results = head_outputs.withTags(.{ .b, .s, .d }).dot(self.weight.withTags(.{ .voc, .d }), .{.d}); + return results.add(self.bias.withTags(.{.voc}).broad(results.shape())); + } + }; + + try zml.testing.testLayer( + compute_platform, + activations, + "model.decoder", + TiedDecoder{ .weight = model.decoder.weight orelse model.model.embeddings.tok_embeddings.weight, .bias = model.decoder.bias }, + .{ .weight = model_weights.model.embeddings.tok_embeddings.weight, .bias = model_weights.decoder.bias }, + 1e-3, + ); + + try zml.testing.testLayer( + compute_platform, + activations, + "model.head", + model.head, + model_weights.head, + 0.1, // TODO: too high tolerance + ); +}