const std = @import("std");
const asynk = @import("async");
const stdx = @import("stdx");
const zml = @import("zml");
const Tensor = zml.Tensor;

const log = std.log.scoped(.modernbert);

pub const ModernBertOptions = struct {
    num_attention_heads: i64,
    pad_token: u32,
    local_attention: u32,
    tie_word_embeddings: bool = false,
};

pub const ModernBertForMaskedLM = struct {
    model: ModernBertModel,
    head: ModernBertPredictionHead,
    decoder: struct { weight: ?zml.Tensor, bias: zml.Tensor },

    pub fn init(self: *ModernBertForMaskedLM, options: ModernBertOptions) void {
        self.model.init(options);
        self.head.norm.eps = 1e-5;
        self.head.dense.weight = self.head.dense.weight.withSharding(.{0});
        if (options.tie_word_embeddings) {
            self.decoder.weight = null;
        } else if (self.decoder.weight) |decoder_weight| {
            self.decoder.weight = decoder_weight.withSharding(.{1});
        }
    }

    pub fn forward(self: ModernBertForMaskedLM, input_ids: Tensor) zml.Tensor.ArgMaxRes {
        const outputs: Tensor = zml.call(self.model, .forward, .{input_ids});
        const head_outputs: Tensor = zml.call(self.head, .forward, .{outputs});

        // Use the dedicated decoder weights, or fall back to the tied token embeddings.
        const decoder_weights = self.decoder.weight orelse self.model.embeddings.tok_embeddings.weight;
        const logits = head_outputs.withTags(.{ .b, .s, .d }).dot(decoder_weights.withTags(.{ .voc, .d }), .{.d});
        const biased_logits = logits.add(self.decoder.bias.withTags(.{.voc}).broad(logits.shape()));
        const probabilities = biased_logits.softmax(.voc);
        return probabilities.topK(5, .voc, .{ .descending = true });
    }
};

pub const ModernBertModel = struct {
    options: ModernBertOptions,
    embeddings: ModernBertEmbeddings,
    layers: []ModernBertEncoderLayer,
    final_norm: zml.nn.LayerNorm,

    pub fn init(self: *ModernBertModel, options: ModernBertOptions) void {
        self.options = options;
        self.final_norm.eps = 1e-5;
        for (self.layers, 0..) |*encoder_layer, layer_idx| {
            encoder_layer.attn.Wqkv.weight = encoder_layer.attn.Wqkv.weight.withSharding(.{0});
            encoder_layer.attn.Wo.weight = encoder_layer.attn.Wo.weight.withSharding(.{1});
            encoder_layer.mlp.Wi.weight = encoder_layer.mlp.Wi.weight.withSharding(.{0});
            encoder_layer.mlp.Wo.weight = encoder_layer.mlp.Wo.weight.withSharding(.{1});
            if (encoder_layer.attn_norm) |*norm| norm.eps = 1e-5;
            encoder_layer.mlp_norm.eps = 1e-5;
            // Every third layer (0, 3, 6, ...) attends globally; the others use local attention.
            encoder_layer.attn.is_global_attention = (layer_idx % 3 == 0);
            encoder_layer.attn.num_heads = options.num_attention_heads;
        }
    }

    pub fn forward(self: ModernBertModel, input_ids: Tensor) Tensor {
        var hidden_states: Tensor = zml.call(self.embeddings, .forward, .{input_ids}).withTags(.{ .b, .src, .d });

        const global_mask = globalAttnMask(input_ids, hidden_states.dtype(), self.options.pad_token);
        const local_mask = localAttnMask(global_mask, self.options.local_attention);

        // Process through all encoder layers.
        for (self.layers) |encoder_layer| {
            hidden_states = zml.call(encoder_layer, .forward, .{
                hidden_states,
                global_mask,
                local_mask,
            });
        }

        // Final layer normalization.
        hidden_states = zml.call(self.final_norm, .forward, .{hidden_states});
        return hidden_states;
    }
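
    // A worked example of the two masks built below (illustrative values only, not
    // part of the model): for one sequence of length 5 whose last token is [PAD],
    // with local_attention = 4 (so a half-window of 2), the (q, k) masks look like:
    //
    //   globalAttnMask                     localAttnMask
    //   q\k   0    1    2    3   PAD       q\k   0    1    2    3   PAD
    //    0    0    0    0    0  -inf        0    0    0    0  -inf -inf
    //    1    0    0    0    0  -inf        1    0    0    0    0  -inf
    //    2    0    0    0    0  -inf        2    0    0    0    0  -inf
    //    3    0    0    0    0  -inf        3  -inf   0    0    0  -inf
    //    4    0    0    0    0  -inf        4  -inf -inf   0    0  -inf
    //
    // The global mask only silences [PAD] keys; the local mask additionally
    // restricts each query to keys within |q - k| <= local_attention / 2.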

    /// Find [PAD] tokens in the inputs and assign them a -inf attention mask.
    /// Output shape follows the zml.nn.sdpa convention: .{ .b, .q, .k }
    pub fn globalAttnMask(input_ids: Tensor, dt: zml.DataType, pad_token: u32) Tensor {
        const ids = input_ids.withTags(.{ .b, .k });

        // Mask keys whose corresponding token is [PAD].
        const padding = ids.cmp(.EQ, Tensor.scalar(pad_token, ids.dtype()));
        const pad_mask = padding.select(Tensor.constant(.{}, dt.minValue()), Tensor.constant(.{}, dt.zero()));

        // Broadcast to the desired output shape.
        const seq_len = ids.dim(.k);
        const pad_mask_shape = zml.Shape.init(.{ .b = pad_mask.dim(.b), .q = seq_len, .k = seq_len }, dt);
        return pad_mask.broad(pad_mask_shape);
    }

    /// Restrict the global attention mask to a sliding window.
    /// Output shape follows the zml.nn.sdpa convention: .{ .b, .q, .k }
    pub fn localAttnMask(global_mask: Tensor, window_size: u32) Tensor {
        const mask_shape = global_mask.shape();

        // Calculate the distance between query and key positions.
        const rows = Tensor.iota(mask_shape, .q);
        const cols = Tensor.iota(mask_shape, .k);
        const distance = rows.sub(cols).abs();

        // Create the sliding window mask: keep positions within the window, -inf outside.
        // Note: we divide by two because the BERT local attention is symmetric around the query token.
        const window_mask = distance.cmp(.LE, Tensor.scalar(@divExact(window_size, 2), .i32));
        const minus_inf = Tensor.constant(mask_shape, mask_shape.dtype().minValue());
        return window_mask.select(global_mask, minus_inf);
    }
};

pub const ModernBertPredictionHead = struct {
    dense: zml.nn.Linear,
    norm: zml.nn.LayerNorm,

    pub fn forward(self: ModernBertPredictionHead, hidden_states: Tensor) Tensor {
        const dense_output: Tensor = zml.call(self.dense, .forward, .{hidden_states});
        const activated_output = dense_output.gelu();
        return zml.call(self.norm, .forward, .{activated_output});
    }
};

pub const ModernBertEmbeddings = struct {
    tok_embeddings: zml.nn.TokenEmbedding,
    norm: zml.nn.LayerNorm,

    pub fn forward(self: ModernBertEmbeddings, input_ids: Tensor) Tensor {
        // Look up the token embeddings.
        const hidden_states = zml.call(self.tok_embeddings, .forward, .{input_ids});
        // Normalize them.
        return zml.call(self.norm, .forward, .{hidden_states});
    }
};

pub const ModernBertEncoderLayer = struct {
    attn_norm: ?zml.nn.LayerNorm = null,
    attn: ModernBertAttention,
    mlp_norm: zml.nn.LayerNorm,
    mlp: ModernBertMLP,

    pub fn forward(
        self: ModernBertEncoderLayer,
        hidden_states: Tensor,
        global_mask: Tensor,
        local_mask: Tensor,
    ) Tensor {
        const attn_norm_output = if (self.attn_norm) |attn_norm|
            zml.call(attn_norm, .forward, .{hidden_states})
        else
            hidden_states;

        const attn_output: Tensor = zml.call(self.attn, .forward, .{
            attn_norm_output,
            global_mask,
            local_mask,
        });
        var output = hidden_states.add(attn_output);

        const mlp_norm_output: Tensor = zml.call(self.mlp_norm, .forward, .{output});
        const mlp_output = zml.call(self.mlp, .forward, .{mlp_norm_output});
        output = output.add(mlp_output);

        return output;
    }
};
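
// Pseudocode view of ModernBertEncoderLayer.forward above (a summary, not extra logic):
//
//   x   = hidden_states
//   x   = x + attn(attn_norm(x))   // attn_norm is skipped when null
//   out = x + mlp(mlp_norm(x))
//
// This is the usual pre-norm residual layout. attn_norm is optional because in
// ModernBERT the first encoder layer relies on the embedding LayerNorm instead of
// carrying its own attention norm.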

/// Performs multi-headed self attention on a batch of unpadded sequences.
///
/// If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
/// If Flash Attention 2 is not installed, the implementation will use SDPA.
pub const ModernBertAttention = struct {
    Wqkv: zml.nn.Linear,
    Wo: zml.nn.Linear,
    is_global_attention: bool = false,
    num_heads: i64 = undefined,

    /// sdpa_attention_forward
    pub fn forward(
        self: ModernBertAttention,
        hidden_states: Tensor,
        global_mask: Tensor,
        local_mask: Tensor,
    ) Tensor {
        const batch_size = hidden_states.shape().dim(0);
        const seq_length = hidden_states.shape().dim(1);
        const hidden_size = hidden_states.shape().dim(2);
        const num_heads = self.num_heads;
        const head_dim = @divExact(hidden_size, num_heads);

        // Project to query, key, value: { batch_size, seq_length, 3 * num_heads * head_dim }
        var qkv: Tensor = zml.call(self.Wqkv, .forward, .{hidden_states});

        // Reshape to { batch_size, seq_length, 3, num_heads, head_dim }
        qkv = qkv.reshape(.{ batch_size, seq_length, 3, num_heads, head_dim }).withTags(.{ .b, .s, .chunk, .h, .hd });

        // Split into query, key, value tensors, each { batch_size, seq_length, num_heads, head_dim }
        var q, var k, var v = qkv.chunkExact(.chunk, 3);
        q = q.squeeze(.chunk);
        k = k.squeeze(.chunk);
        v = v.squeeze(.chunk);

        // Apply rotary position embeddings (RoPE):
        // layers 0, 3, 6, 9, 12, ... use the global RoPE base,
        // layers 1, 2, 4, 5, 7, 8, 10, 11, ... use the local RoPE base.
        const rope_opts = zml.nn.RopeOpts{
            .layout = .sequential,
            .freq_base = if (self.is_global_attention) 160_000 else 10_000,
        };
        q = zml.nn.rope(q, null, rope_opts);
        k = zml.nn.rope(k, null, rope_opts);

        // Rename dimensions for sdpa.
        q = q.rename(.{ .s = .q });
        k = k.rename(.{ .s = .k });
        v = v.rename(.{ .s = .k });

        // Scaled dot product attention, with the mask matching this layer's attention kind.
        const attn_output = zml.nn.sdpa(q, k, v, .{
            .attn_mask = if (self.is_global_attention) global_mask else local_mask,
        });
        const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s });

        // Final projection.
        return zml.call(self.Wo, .forward, .{attn});
    }
};

/// Switches out the old MLP layers for GeGLU layers, improving on the original BERT's
/// GeLU activation function.
///
/// The GeGLU activation function is a combination of the Gated Linear Unit (GLU) and
/// the Gaussian Error Linear Unit (GeLU).
///
/// See: https://paperswithcode.com/method/geglu
pub const ModernBertMLP = struct {
    Wi: zml.nn.Linear,
    Wo: zml.nn.Linear,

    pub fn forward(self: ModernBertMLP, hidden_states: Tensor) Tensor {
        const wi_output: Tensor = zml.call(self.Wi, .forward, .{hidden_states});

        // Split into input and gate tensors along the last dimension.
        const input, const gate = wi_output.chunkExact(-1, 2);

        // Apply the gated activation.
        const activated_input = input.gelu().mul(gate);

        return zml.call(self.Wo, .forward, .{activated_input});
    }
};
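
// Shape walkthrough for ModernBertMLP.forward above (a sketch; `intermediate` is a
// placeholder for the checkpoint's intermediate size, not a name used in this file):
//
//   hidden_states       : { batch, seq, d }
//   Wi(hidden_states)   : { batch, seq, 2 * intermediate }
//   chunkExact(-1, 2)   : input and gate, each { batch, seq, intermediate }
//   gelu(input) * gate  : { batch, seq, intermediate }
//   Wo(...)             : { batch, seq, d }
//
// In other words, GeGLU(x) = GELU(x * W_in) * (x * W_gate), where W_in and W_gate
// are the two halves of the fused Wi projection.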