Radix/examples/modernbert/modernbert.zig

270 lines
10 KiB
Zig
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

const std = @import("std");
const async = @import("async");
const stdx = @import("stdx");
const zml = @import("zml");
const Tensor = zml.Tensor;
const log = std.log.scoped(.modernbert);
/// Configuration values consumed by the ModernBERT modules in this file.
pub const ModernBertOptions = struct {
/// Number of attention heads; copied into every layer's attention module by `ModernBertModel.init`.
num_attention_heads: i64,
/// Token id compared against the input ids to find the positions that get masked out of attention.
pad_token: u32,
/// Sliding-window size handed to `ModernBertModel.localAttnMask`, which halves it to get the
/// allowed distance on each side of the query position.
local_attention: u32,
/// When true, `ModernBertForMaskedLM.init` clears the decoder weight and `forward` falls back
/// to the token embedding weight for the output projection.
tie_word_embeddings: bool = false,
};
/// ModernBERT encoder followed by the masked-language-model prediction head and a
/// vocabulary projection (either a dedicated decoder weight or the tied token embeddings).
pub const ModernBertForMaskedLM = struct {
    model: ModernBertModel,
    head: ModernBertPredictionHead,
    decoder: struct { weight: ?zml.Tensor, bias: zml.Tensor },

    /// Finishes setting up the loaded weights: initialises the encoder, sets the head's
    /// layer-norm epsilon, attaches sharding annotations, and drops the decoder weight
    /// when `options.tie_word_embeddings` is set.
    pub fn init(self: *ModernBertForMaskedLM, options: ModernBertOptions) void {
        self.model.init(options);
        self.head.norm.eps = 1e-5;
        self.head.dense.weight = self.head.dense.weight.withSharding(.{0});

        // Tied embeddings: forget any decoder weight so `forward` uses the token embeddings.
        if (options.tie_word_embeddings) {
            self.decoder.weight = null;
            return;
        }

        // Untied: annotate the decoder weight, if one is present.
        if (self.decoder.weight) |untied_weight| {
            self.decoder.weight = untied_weight.withSharding(.{1});
        }
    }

    /// Runs the encoder and prediction head on `input_ids`, projects onto the vocabulary,
    /// applies a softmax over `.voc`, and returns the result of a descending top-5 over `.voc`.
    pub fn forward(self: ModernBertForMaskedLM, input_ids: Tensor) zml.Tensor.ArgMaxRes {
        const encoded: Tensor = zml.call(self.model, .forward, .{input_ids});
        const projected: Tensor = zml.call(self.head, .forward, .{encoded});

        // Use the dedicated decoder weight when there is one, the tied token embeddings otherwise.
        const vocab_weights = if (self.decoder.weight) |weight|
            weight
        else
            self.model.embeddings.tok_embeddings.weight;

        const logits = projected
            .withTags(.{ .b, .s, .d })
            .dot(vocab_weights.withTags(.{ .voc, .d }), .{.d});
        const bias = self.decoder.bias.withTags(.{.voc}).broad(logits.shape());

        return logits
            .add(bias)
            .softmax(.voc)
            .topK(5, .voc, .{ .descending = true });
    }
};
/// ModernBERT encoder: token embeddings, a stack of encoder layers and a final layer norm.
pub const ModernBertModel = struct {
    options: ModernBertOptions,
    embeddings: ModernBertEmbeddings,
    layers: []ModernBertEncoderLayer,
    final_norm: zml.nn.LayerNorm,

    /// Stores `options` and completes the per-layer configuration of the loaded weights:
    /// layer-norm epsilons, sharding annotations on the projection weights, the number of
    /// attention heads, and which layers use global attention.
    pub fn init(self: *ModernBertModel, options: ModernBertOptions) void {
        self.options = options;
        self.final_norm.eps = 1e-5;
        self.embeddings.norm.eps = 1e-5;

        for (self.layers, 0..) |*encoder_layer, layer_idx| {
            encoder_layer.attn.Wqkv.weight = encoder_layer.attn.Wqkv.weight.withSharding(.{0});
            encoder_layer.attn.Wo.weight = encoder_layer.attn.Wo.weight.withSharding(.{1});
            encoder_layer.mlp.Wi.weight = encoder_layer.mlp.Wi.weight.withSharding(.{0});
            encoder_layer.mlp.Wo.weight = encoder_layer.mlp.Wo.weight.withSharding(.{1});

            // `attn_norm` is optional; only set its epsilon when the layer has one.
            if (encoder_layer.attn_norm) |*norm| norm.eps = 1e-5;
            encoder_layer.mlp_norm.eps = 1e-5;

            // Every third layer (0, 3, 6, ...) attends globally; the others use the local mask.
            encoder_layer.attn.is_global_attention = (layer_idx % 3 == 0);
            encoder_layer.attn.num_heads = options.num_attention_heads;
        }
    }

    /// Embeds `input_ids`, runs every encoder layer with both attention masks, and applies
    /// the final layer norm.
    pub fn forward(self: ModernBertModel, input_ids: Tensor) Tensor {
        var hidden_states: Tensor = zml.call(self.embeddings, .forward, .{input_ids}).withTags(.{ .b, .src, .d });

        // Both masks are built once and shared by all layers; each layer picks the one it needs.
        const global_mask = globalAttnMask(input_ids, hidden_states.dtype(), self.options.pad_token);
        const local_mask = localAttnMask(global_mask, self.options.local_attention);

        // Process through all encoder layers
        for (self.layers) |encoder_layer| {
            hidden_states = zml.call(encoder_layer, .forward, .{
                hidden_states,
                global_mask,
                local_mask,
            });
        }

        // Final layer normalization
        return zml.call(self.final_norm, .forward, .{hidden_states});
    }

    /// Finds [PAD] tokens in `input_ids` and gives the corresponding keys the minimum value
    /// of `dt` as attention mask; every other key gets zero.
    /// Output shape follows the zml.nn.sdpa convention: .{ .b, .q, .k }
    pub fn globalAttnMask(input_ids: Tensor, dt: zml.DataType, pad_token: u32) Tensor {
        const ids = input_ids.withTags(.{ .b, .k });

        // Mask keys where corresponding token is [PAD]
        const padding = ids.cmp(.EQ, Tensor.scalar(pad_token, ids.dtype()));
        const pad_mask = padding.select(Tensor.constant(.{}, dt.minValue()), Tensor.constant(.{}, dt.zero()));

        // Broadcast to the desired output shape: the same key mask is repeated for every query.
        const seq_len = ids.dim(.k);
        const pad_mask_shape = zml.Shape.init(.{ .b = pad_mask.dim(.b), .q = seq_len, .k = seq_len }, dt);
        return pad_mask.broad(pad_mask_shape);
    }

    /// Restricts the global attention mask to a sliding window: keys farther than
    /// `window_size / 2` positions from the query get the minimum value of the mask dtype,
    /// keys inside the window keep their value from `global_mask`.
    /// An odd `window_size` is rounded down when halved.
    /// Output shape follows the zml.nn.sdpa convention: .{ .b, .q, .k }
    pub fn localAttnMask(global_mask: Tensor, window_size: u32) Tensor {
        const mask_shape = global_mask.shape();

        // Calculate distance between query and key positions
        const rows = Tensor.iota(mask_shape, .q);
        const cols = Tensor.iota(mask_shape, .k);
        const distance = rows.sub(cols).abs();

        // Note: we divide by two because the BERT local attention is symmetric around the query token.
        // Floor division is used on purpose: `@divExact` would trip a safety check on an odd
        // window size, which is a legal configuration value.
        const half_window = window_size / 2;

        // Create sliding window mask (true for positions within window, false outside)
        const window_mask = distance.cmp(.LE, Tensor.scalar(half_window, .i32));
        const minus_inf = Tensor.constant(mask_shape, mask_shape.dtype().minValue());
        return window_mask.select(global_mask, minus_inf);
    }
};
/// Masked-LM prediction head: linear projection, GELU, then layer norm.
pub const ModernBertPredictionHead = struct {
    dense: zml.nn.Linear,
    norm: zml.nn.LayerNorm,

    /// Applies `dense`, a GELU activation and `norm`, in that order, to `hidden_states`.
    pub fn forward(self: ModernBertPredictionHead, hidden_states: Tensor) Tensor {
        const projected: Tensor = zml.call(self.dense, .forward, .{hidden_states});
        return zml.call(self.norm, .forward, .{projected.gelu()});
    }
};
/// Token embedding lookup followed by a layer norm.
pub const ModernBertEmbeddings = struct {
    tok_embeddings: zml.nn.TokenEmbedding,
    norm: zml.nn.LayerNorm,

    /// Looks up the embeddings of `input_ids` and normalises them with `norm`.
    pub fn forward(self: ModernBertEmbeddings, input_ids: Tensor) Tensor {
        const embedded = zml.call(self.tok_embeddings, .forward, .{input_ids});
        return zml.call(self.norm, .forward, .{embedded});
    }
};
/// One encoder layer: (optionally normalised) self-attention and an MLP, each wrapped in a
/// residual connection.
pub const ModernBertEncoderLayer = struct {
    attn_norm: ?zml.nn.LayerNorm = null,
    attn: ModernBertAttention,
    mlp_norm: zml.nn.LayerNorm,
    mlp: ModernBertMLP,

    /// Runs the attention and MLP sub-layers on `hidden_states`.
    /// Both masks are forwarded to the attention module, which chooses between them.
    pub fn forward(
        self: ModernBertEncoderLayer,
        hidden_states: Tensor,
        global_mask: Tensor,
        local_mask: Tensor,
    ) Tensor {
        // Layers without an attention norm feed the raw hidden states to attention.
        const attn_input: Tensor = blk: {
            const norm = self.attn_norm orelse break :blk hidden_states;
            break :blk zml.call(norm, .forward, .{hidden_states});
        };

        const attn_delta: Tensor = zml.call(self.attn, .forward, .{
            attn_input,
            global_mask,
            local_mask,
        });
        // First residual: add attention output to the un-normalised input.
        const after_attn = hidden_states.add(attn_delta);

        const mlp_input: Tensor = zml.call(self.mlp_norm, .forward, .{after_attn});
        const mlp_delta = zml.call(self.mlp, .forward, .{mlp_input});
        // Second residual: add MLP output.
        return after_attn.add(mlp_delta);
    }
};
/// Multi-headed self attention with a fused query/key/value projection.
///
/// Query and key are rotated with RoPE, attention is computed with `zml.nn.sdpa` using
/// either the global or the local mask depending on `is_global_attention`, and the heads
/// are merged back before the output projection `Wo`.
pub const ModernBertAttention = struct {
    Wqkv: zml.nn.Linear,
    Wo: zml.nn.Linear,
    is_global_attention: bool = false,
    num_heads: i64 = undefined,

    /// Computes self attention over `hidden_states`, read here as a rank-3 tensor
    /// (batch, sequence, hidden); the hidden size must be divisible by `num_heads`.
    pub fn forward(
        self: ModernBertAttention,
        hidden_states: Tensor,
        global_mask: Tensor,
        local_mask: Tensor,
    ) Tensor {
        const in_shape = hidden_states.shape();
        const batch = in_shape.dim(0);
        const seq = in_shape.dim(1);
        const heads = self.num_heads;
        const per_head = @divExact(in_shape.dim(2), heads);

        // Fused projection, then expose the q/k/v split and the heads as separate axes:
        // { batch, seq, 3 * heads * per_head } -> { batch, seq, 3, heads, per_head }
        const fused: Tensor = zml.call(self.Wqkv, .forward, .{hidden_states});
        const qkv = fused
            .reshape(.{ batch, seq, 3, heads, per_head })
            .withTags(.{ .b, .s, .chunk, .h, .hd });

        // One slice per projection, each { batch, seq, heads, per_head } once `.chunk` is squeezed.
        const parts = qkv.chunkExact(.chunk, 3);
        const q_raw = parts[0].squeeze(.chunk);
        const k_raw = parts[1].squeeze(.chunk);
        const v_raw = parts[2].squeeze(.chunk);

        // Rotary position embeddings (RoPE) on query and key only.
        // Global layers (0, 3, 6, ...) use the larger frequency base, local layers the smaller one.
        const rope_opts = zml.nn.RopeOpts{
            .layout = .sequential,
            .freq_base = if (self.is_global_attention) 160_000 else 10_000,
        };
        const q_rot = zml.nn.rope(q_raw, null, rope_opts);
        const k_rot = zml.nn.rope(k_raw, null, rope_opts);

        // sdpa expects the sequence axis to be named .q on the query and .k on key/value.
        const mask = if (self.is_global_attention) global_mask else local_mask;
        const attended = zml.nn.sdpa(
            q_rot.rename(.{ .s = .q }),
            k_rot.rename(.{ .s = .k }),
            v_raw.rename(.{ .s = .k }),
            .{ .attn_mask = mask },
        );

        // Merge heads back into a single feature axis and restore the sequence tag.
        const merged = attended.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s });

        // Final projection
        return zml.call(self.Wo, .forward, .{merged});
    }
};
/// Gated feed-forward block (GeGLU), used in place of the original BERT's plain GeLU MLP.
///
/// `Wi` produces two halves along the last axis; the first half goes through GeLU and is
/// multiplied element-wise by the second half (the gate) before the output projection `Wo`.
///
/// see: https://paperswithcode.com/method/geglu
pub const ModernBertMLP = struct {
    Wi: zml.nn.Linear,
    Wo: zml.nn.Linear,

    /// Applies the gated feed-forward transformation to `hidden_states`.
    pub fn forward(self: ModernBertMLP, hidden_states: Tensor) Tensor {
        const expanded: Tensor = zml.call(self.Wi, .forward, .{hidden_states});

        // halves[0] is the activated input, halves[1] the gate.
        const halves = expanded.chunkExact(-1, 2);
        const gated = halves[0].gelu().mul(halves[1]);

        return zml.call(self.Wo, .forward, .{gated});
    }
};