Add tests for the ModernBERT example, covering activation utilities, build setup, and example Zig code.

This commit is contained in:
Foke Singh 2024-06-11 17:33:22 +00:00
parent ab5ad874c3
commit 17d02621e7
7 changed files with 950 additions and 0 deletions

View File

@ -139,6 +139,53 @@ http_file(
url = "https://github.com/karpathy/llama2.c/raw/c02865df300f3bd9e567ce061000dc23bf785a17/tokenizer.bin",
)
# ModernBERT
# Fetch the ModernBERT-base weights and tokenizer from the Hugging Face Hub,
# pinned to a specific commit so builds stay reproducible. The generated
# repository exposes two filegroups: ":model" and ":tokenizer".
huggingface.model(
    name = "ModernBERT-base",
    build_file_content = """\
package(default_visibility = ["//visibility:public"])
filegroup(
    name = "model",
    srcs = ["model.safetensors"],
)
filegroup(
    name = "tokenizer",
    srcs = ["tokenizer.json"],
)
""",
    commit = "94032bb66234a691cf6248265170006a7ced4970",
    # Only the weights and tokenizer are needed; skip the rest of the repo.
    includes = [
        "model.safetensors",
        "tokenizer.json",
    ],
    model = "answerdotai/ModernBERT-base",
)
use_repo(huggingface, "ModernBERT-base")
# Same setup as ModernBERT-base above, for the large checkpoint.
huggingface.model(
    name = "ModernBERT-large",
    build_file_content = """\
package(default_visibility = ["//visibility:public"])
filegroup(
    name = "model",
    srcs = ["model.safetensors"],
)
filegroup(
    name = "tokenizer",
    srcs = ["tokenizer.json"],
)
""",
    commit = "4bbcbf40bed02ce487125bcb3c897ea9bdc88340",
    # Only the weights and tokenizer are needed; skip the rest of the repo.
    includes = [
        "model.safetensors",
        "tokenizer.json",
    ],
    model = "answerdotai/ModernBERT-large",
)
use_repo(huggingface, "ModernBERT-large")
bazel_dep(name = "rules_rust", version = "0.57.1")
rust = use_extension("@rules_rust//rust:extensions.bzl", "rust")
rust.toolchain(

View File

@ -0,0 +1,63 @@
load("@zml//bazel:zig.bzl", "zig_cc_binary")

# Main example binary: fill-mask inference with ModernBERT.
zig_cc_binary(
    name = "modernbert",
    srcs = ["modernbert.zig"],
    main = "main.zig",
    deps = [
        "@com_github_hejsil_clap//:clap",
        "@zml//async",
        "@zml//stdx",
        "@zml//zml",
    ],
)

# Convenience runner for the base checkpoint: wires model/tokenizer paths
# and model hyper-parameters as CLI args.
# NOTE(review): deps reference ":modernbert_lib" but the zig_cc_binary above
# is named "modernbert" — confirm zig_cc_binary emits a "<name>_lib" target.
cc_binary(
    name = "ModernBERT-base",
    args = [
        "--model=$(location @ModernBERT-base//:model.safetensors)",
        "--tokenizer=$(location @ModernBERT-base//:tokenizer)",
        "--num-attention-heads=12",
        "--tie-word-embeddings=true",
    ],
    data = [
        "@ModernBERT-base//:model.safetensors",
        "@ModernBERT-base//:tokenizer",
    ],
    deps = [":modernbert_lib"],
)

# Convenience runner for the large checkpoint (16 heads instead of 12).
cc_binary(
    name = "ModernBERT-large",
    args = [
        "--model=$(location @ModernBERT-large//:model.safetensors)",
        "--tokenizer=$(location @ModernBERT-large//:tokenizer)",
        "--num-attention-heads=16",
        "--tie-word-embeddings=true",
    ],
    data = [
        "@ModernBERT-large//:model.safetensors",
        "@ModernBERT-large//:tokenizer",
    ],
    deps = [":modernbert_lib"],
)

# Compares layer outputs against reference activations captured from the
# HF implementation (see test.zig). Not run in CI (needs the weights).
zig_cc_binary(
    name = "test-implementation",
    srcs = ["modernbert.zig"],
    args = [
        "--model=$(location @ModernBERT-base//:model.safetensors)",
    ],
    data = [
        "@ModernBERT-base//:model.safetensors",
    ],
    main = "test.zig",
    tags = [
        "no_ci",
    ],
    deps = [
        "@com_github_hejsil_clap//:clap",
        "@zml//async",
        "@zml//zml",
    ],
)

View File

@ -0,0 +1,55 @@
import logging
import torch
from transformers import pipeline
from tools.zml_utils import ActivationCollector
# Configure root logging once at import time so the pipeline's own loggers
# share the same timestamped format.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
)
log = logging.getLogger(__name__)

# Hugging Face model id whose activations we want to capture.
MODEL_NAME: str = "answerdotai/ModernBERT-base"
def main() -> None:
    """Run the ModernBERT fill-mask pipeline and dump per-layer activations.

    Loads the model named by ``MODEL_NAME``, wraps the pipeline in an
    ``ActivationCollector``, runs one masked-sentence example, and saves the
    collected activations to ``<model>.activations.pt`` in the working
    directory. Exceptions are logged and re-raised.
    """
    try:
        log.info("Start running main()")
        log.info(f"CPU capability : `{torch.backends.cpu.get_cpu_capability()}`")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        log.info(f"Loading model : `{MODEL_NAME}`")
        fill_mask_pipeline = pipeline(
            "fill-mask",
            model=MODEL_NAME,
            device_map=device,
        )
        model, tokenizer = fill_mask_pipeline.model, fill_mask_pipeline.tokenizer
        log.info(
            f"Model loaded successfully {model.config.architectures} - `{model.config.torch_dtype}` - {tokenizer.model_max_length} max tokens"  # noqa: E501
        )

        # Wrap the pipeline, and extract activations.
        # Activations files can be huge for big models,
        # so let's stop collecting after 1000 layers.
        zml_pipeline = ActivationCollector(
            fill_mask_pipeline, max_layers=1000, stop_after_first_step=True
        )

        input_text = "Paris is the [MASK] of France."
        outputs, activations = zml_pipeline(input_text)
        # Fixed: "ouputs" typo in the log message.
        log.info(f"outputs : {outputs}")

        filename = MODEL_NAME.split("/")[-1] + ".activations.pt"
        torch.save(activations, filename)
        # Fixed: the destination filename was missing from this message.
        log.info(f"Saved {len(activations)} activations to {filename}")
        log.info("End running main()")
    except Exception as exception:
        log.error(exception)
        raise


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,4 @@
torch
transformers==4.48.1
accelerate
numpy==1.26.4

View File

@ -0,0 +1,275 @@
const std = @import("std");
const log = std.log.scoped(.modernbert);
const modernbert = @import("modernbert.zig");
const asynk = @import("async");
const clap = @import("clap");
const stdx = @import("stdx");
const zml = @import("zml");
const Tensor = zml.Tensor;
/// Route std.log through the async runtime's logger and default to .info
/// for both the global and the .modernbert scope.
pub const std_options = .{
    .log_level = .info,
    .log_scope_levels = &[_]std.log.ScopeLevel{
        .{ .scope = .modernbert, .level = .info },
    },
    .logFn = asynk.logFn(std.log.defaultLog),
};

/// CLI flag grammar, parsed at comptime by zig-clap.
const params = clap.parseParamsComptime(
    \\--help print this help
    \\--text <STRING> the prompt
    \\--model <PATH> model path
    \\--tokenizer <PATH> tokenizer path
    \\--seq-len <UINT> sequence length
    \\--num-attention-heads <UINT> number of attention heads
    \\--tie-word-embeddings <BOOL> default: false: tied weights
    \\--create-options <STRING> platform creation options JSON, defaults to {}
    \\--sharding <BOOL> default: true: sharding on or off
);

/// Maps the value-type placeholders above to concrete parsers.
const clap_parsers = .{
    .BOOL = bool_parser,
    .UINT = clap.parsers.int(usize, 0),
    .STRING = clap.parsers.string,
    .PATH = clap.parsers.string,
};
/// Process entry point: hands control to the async runtime, which runs
/// asyncMain() on its main thread.
pub fn main() !void {
    try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
}
/// Async entry point: parses CLI flags, loads the tokenizer and weights,
/// compiles the ModernBERT graph for the selected platform, then runs
/// the fill-mask pipeline once on the provided (or default) text.
pub fn asyncMain() !void {
    const allocator = std.heap.c_allocator;
    const stderr = std.io.getStdErr().writer();

    // Parse CLI flags; on malformed input, report the problem and show usage.
    var diag: clap.Diagnostic = .{};
    var cli = clap.parse(clap.Help, &params, clap_parsers, .{
        .diagnostic = &diag,
        .allocator = allocator,
    }) catch |err| {
        try diag.report(stderr, err);
        try printUsageAndExit(stderr);
    };
    defer cli.deinit();

    if (cli.args.help != 0) {
        try clap.help(stderr, clap.Help, &params, .{});
        return;
    }

    // Directory used by the XLA dump/cache path configured below.
    const tmp = try std.fs.openDirAbsolute("/tmp", .{});
    try tmp.makePath("zml/modernbert/cache");

    // Create ZML context
    var context = try zml.Context.init();
    defer context.deinit();

    // Platform and compilation options
    const create_opts_json = cli.args.@"create-options" orelse "{}";
    const create_opts = try std.json.parseFromSliceLeaky(zml.Platform.CreateOptions, allocator, create_opts_json, .{});
    const compilation_options = zml.CompilationOptions{
        .xla_dump_to = "/tmp/zml/modernbert",
        .sharding_enabled = cli.args.sharding orelse true,
    };

    // Auto-select platform
    const platform = context.autoPlatform(create_opts).withCompilationOptions(compilation_options);
    context.printAvailablePlatforms(platform);

    // Detects the format of the model file (base on filename) and open it.
    const model_file = cli.args.model orelse {
        stderr.print("Error: missing --model=...\n\n", .{}) catch {};
        printUsageAndExit(stderr);
        unreachable;
    };
    var tensor_store = try zml.aio.detectFormatAndOpen(allocator, model_file);
    defer tensor_store.deinit();

    // Memory arena dedicated to model shapes and weights
    var arena_state = std.heap.ArenaAllocator.init(allocator);
    defer arena_state.deinit();
    const model_arena = arena_state.allocator();

    // The tokenizer is mandatory: bail out early when missing.
    var tokenizer = blk: {
        if (cli.args.tokenizer) |tok| {
            log.info("\tLoading tokenizer from {s}", .{tok});
            var timer = try stdx.time.Timer.start();
            defer log.info("\tLoaded tokenizer from {s} [{}]", .{ tok, timer.read() });
            break :blk try zml.tokenizer.Tokenizer.fromFile(model_arena, tok);
        } else {
            log.err("Error: missing --tokenizer", .{});
            return;
        }
    };
    defer tokenizer.deinit();

    // Create the model struct, with tensor shapes extracted from the tensor_store
    // TODO: read from config.json
    const modernbert_options = modernbert.ModernBertOptions{
        .pad_token = tokenizer.tokenToId("[PAD]") orelse return error.NoSuchToken,
        .num_attention_heads = @intCast(cli.args.@"num-attention-heads" orelse 12),
        .tie_word_embeddings = cli.args.@"tie-word-embeddings" orelse false,
        .local_attention = 128,
    };
    var modern_bert_for_masked_lm = try zml.aio.populateModel(modernbert.ModernBertForMaskedLM, model_arena, tensor_store);
    modern_bert_for_masked_lm.init(modernbert_options);
    log.info("\tModernBERT options: {}", .{modernbert_options});

    // Prepare shapes for compilation
    // NOTE(review): the compiled input dtype here is .u32, while unmask()
    // builds its runtime input shape with .i64 — confirm this is intended.
    const seq_len = @as(i64, @intCast(cli.args.@"seq-len" orelse 256));
    const input_shape = zml.Shape.init(.{ .b = 1, .s = seq_len }, .u32);

    var start = try std.time.Timer.start();

    // Load weights
    log.info("\tLoading ModernBERT weights from {?s}...", .{model_file});
    var bert_weights = try zml.aio.loadBuffers(modernbert.ModernBertForMaskedLM, .{modernbert_options}, tensor_store, model_arena, platform);
    defer zml.aio.unloadBuffers(&bert_weights);
    log.info("\tLoaded weights in {d}ms", .{start.read() / std.time.ns_per_ms});

    // Compile the model
    log.info("\tCompiling ModernBERT model...", .{});
    var fut_mod = try asynk.asyncc(zml.compile, .{
        allocator,
        modernbert.ModernBertForMaskedLM.forward,
        .{modernbert_options},
        .{input_shape},
        tensor_store,
        platform,
    });
    // Bind the loaded weights to the compiled module.
    var bert_module = (try fut_mod.awaitt()).prepare(bert_weights);
    defer bert_module.deinit();
    log.info("\tLoaded weights and compiled model in {d}ms", .{start.read() / std.time.ns_per_ms});

    const text = cli.args.text orelse "Paris is the [MASK] of France.";
    log.info("\tInput text: {s}", .{text});

    try unmask(allocator, bert_module, tokenizer, seq_len, text);
}
/// fill-mask pipeline
/// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/fill_mask.py
///
/// Tokenizes `text`, runs the compiled model, and logs the top-5 predicted
/// tokens for the first `[MASK]` occurrence with their scores.
pub fn unmask(
    allocator: std.mem.Allocator,
    mod: zml.ModuleExe(modernbert.ModernBertForMaskedLM.forward),
    tokenizer: zml.tokenizer.Tokenizer,
    seq_len: i64,
    text: []const u8,
) !void {
    var tokenizer_decoder = try tokenizer.decoder();
    defer tokenizer_decoder.deinit();

    const pad_token = tokenizer.tokenToId("[PAD]") orelse return error.NoSuchToken;
    const mask_token = tokenizer.tokenToId("[MASK]") orelse return error.NoSuchToken;

    // Tokenize input text
    const tokens: []const u32 = try tokenize(allocator, tokenizer, text);
    defer allocator.free(tokens);

    // Find "[MASK]" positions
    const mask_positions = try findMaskPositions(allocator, tokens, mask_token);
    defer allocator.free(mask_positions);

    // Prepare input tensors (padded out to seq_len with [PAD])
    const inputs = try prepareTensorInputs(allocator, tokens, seq_len, pad_token);
    defer allocator.free(inputs);

    // Create input tensors (on the accelerator)
    const input_shape = zml.Shape.init(.{ .b = 1, .s = seq_len }, .i64);
    const input_ids_tensor = try zml.Buffer.fromSlice(mod.platform(), input_shape.dims(), inputs);
    defer input_ids_tensor.deinit();

    // Model inference (retrieve indices)
    var inference_timer = try std.time.Timer.start();
    var topk = mod.call(.{input_ids_tensor});
    defer zml.aio.unloadBuffers(&topk);
    const inference_time = inference_timer.read();

    // Transfer the result to host memory (CPU)
    var indices_host_buffer = try topk.indices.toHostAlloc(allocator);
    defer indices_host_buffer.deinit(allocator);
    var values_host_buffer = try topk.values.toHostAlloc(allocator);
    defer values_host_buffer.deinit(allocator);

    // We consider only the first occurrence of [MASK], which has five predictions
    // (topK(5) in the model flattens to 5 entries per sequence position).
    const pred_offset = mask_positions[0] * 5;
    const predictions = indices_host_buffer.items(i32)[pred_offset..][0..5];
    const scores = values_host_buffer.items(f32)[pred_offset..][0..5];

    // Log timing information
    log.info("⏱️\tModel inference in {d}ms", .{inference_time / std.time.ns_per_ms});

    log.info("\tTop 5 predictions:", .{});
    for (predictions, scores) |token_id, score| {
        const token_text = try tokenizer_decoder.next(@intCast(token_id));
        if (token_text) |word| {
            log.info("\t • score: {d:.4} word: '{s}' token: {}", .{ score, word, token_id });
        }
    }
}
/// Encodes `prompt` into token ids wrapped with [CLS] ... [SEP].
/// Caller owns the returned slice. Returns error.NoSuchToken when the
/// tokenizer lacks the special tokens.
pub fn tokenize(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, prompt: []const u8) ![]const u32 {
    var tokens = std.ArrayList(u32).init(allocator);
    // Fixed: the list leaked whenever encoder(), tokenToId() or an append
    // failed between init and toOwnedSlice.
    errdefer tokens.deinit();

    var encoder = try tokenizer.encoder();
    defer encoder.deinit();

    const bos = tokenizer.tokenToId("[CLS]") orelse return error.NoSuchToken;
    const eos = tokenizer.tokenToId("[SEP]") orelse return error.NoSuchToken;

    try tokens.append(bos);
    try tokens.appendSlice(try encoder.encode(prompt));
    try tokens.append(eos);

    return tokens.toOwnedSlice();
}
/// Collects the index of every `mask_token` occurrence in `tokens`.
/// Caller owns the returned slice. Fails with error.InvalidInput when no
/// mask is present; warns when several are found (only the first is used).
fn findMaskPositions(allocator: std.mem.Allocator, tokens: []const u32, mask_token: u32) ![]usize {
    var positions = std.ArrayList(usize).init(allocator);
    defer positions.deinit();

    var idx: usize = 0;
    while (idx < tokens.len) : (idx += 1) {
        if (tokens[idx] == mask_token) try positions.append(idx);
    }

    if (positions.items.len == 0) {
        log.err("Input text must contains `[MASK]`", .{});
        return error.InvalidInput;
    }
    if (positions.items.len > 1) log.warn("Currently only supporting one [MASK] per input", .{});

    return positions.toOwnedSlice();
}
/// Builds a fixed-length (`seq_len`) buffer of token ids: `tokens` copied
/// left-aligned, remainder filled with `pad_token`. Caller owns the slice.
/// Returns error.InputTooLong when the prompt exceeds `seq_len`.
fn prepareTensorInputs(
    allocator: std.mem.Allocator,
    tokens: []const u32,
    seq_len: i64,
    pad_token: u32,
) ![]u32 {
    const input_ids = try allocator.alloc(u32, @intCast(seq_len));
    errdefer allocator.free(input_ids);
    // Fixed: the original loop wrote past the end of `input_ids` whenever
    // the tokenized prompt was longer than the compiled sequence length.
    if (tokens.len > input_ids.len) return error.InputTooLong;
    @memset(input_ids, pad_token);
    @memcpy(input_ids[0..tokens.len], tokens);
    return input_ids;
}
/// Parses a boolean CLI flag value: any string starting with 't', 'T',
/// 'y', 'Y' or '1' is true; everything else (including "") is false.
fn bool_parser(in: []const u8) error{}!bool {
    // Fixed: `in[0]` panicked on an empty argument (e.g. `--sharding=`).
    if (in.len == 0) return false;
    return std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
}
/// Prints a one-line usage string to `stderr` and terminates the process.
/// Only reached on CLI misuse, so exit with a failure status.
/// Fixed: previously exited with 0, hiding the error from shell scripts.
fn printUsageAndExit(stderr: anytype) noreturn {
    stderr.print("usage: ", .{}) catch {};
    clap.usage(stderr, clap.Help, &params) catch {};
    stderr.print("\n", .{}) catch {};
    std.process.exit(1);
}

View File

@ -0,0 +1,268 @@
const std = @import("std");
const log = std.log.scoped(.modernbert);
const asynk = @import("async");
const stdx = @import("stdx");
const zml = @import("zml");
const Tensor = zml.Tensor;
/// Runtime model configuration that is not stored in the checkpoint.
/// TODO upstream: read these from config.json (see main.zig).
pub const ModernBertOptions = struct {
    /// Attention heads per layer (12 for base, 16 for large — see BUILD args).
    num_attention_heads: i64,
    /// Token id of [PAD], used to mask padding positions out of attention.
    pad_token: u32,
    /// Sliding-window width (tokens) for local-attention layers.
    local_attention: u32,
    /// When true, the decoder reuses the token-embedding matrix.
    tie_word_embeddings: bool = false,
};
/// ModernBERT with a masked-language-modeling head: encoder stack, dense
/// prediction head, and a vocabulary decoder (optionally weight-tied with
/// the token embeddings).
pub const ModernBertForMaskedLM = struct {
    model: ModernBertModel,
    head: ModernBertPredictionHead,
    // `weight` is null when embeddings are tied (forward falls back to the
    // token-embedding matrix).
    decoder: struct { weight: ?zml.Tensor, bias: zml.Tensor },

    /// Propagates options into the encoder and applies sharding annotations.
    pub fn init(self: *ModernBertForMaskedLM, options: ModernBertOptions) void {
        self.model.init(options);
        self.head.norm.eps = 1e-5;
        self.head.dense.weight = self.head.dense.weight.withSharding(.{0});
        if (options.tie_word_embeddings == true) {
            // Drop the decoder weight so forward() uses the tied embeddings.
            self.decoder.weight = null;
        } else if (self.decoder.weight) |decoder_weight| {
            self.decoder.weight = decoder_weight.withSharding(.{1});
        }
    }

    /// Returns the top-5 (values, indices) over the vocabulary for every
    /// sequence position of the (batch, seq) input ids.
    pub fn forward(self: ModernBertForMaskedLM, input_ids: Tensor) zml.Tensor.ArgMaxRes {
        const outputs: Tensor = zml.call(self.model, .forward, .{input_ids});
        const head_outputs: Tensor = zml.call(self.head, .forward, .{outputs});

        // either use decoder or tied weights
        const decoder_weights = self.decoder.weight orelse self.model.embeddings.tok_embeddings.weight;
        const logits = head_outputs.withTags(.{ .b, .s, .d }).dot(decoder_weights.withTags(.{ .voc, .d }), .{.d});
        const biased_logits = logits.add(self.decoder.bias.withTags(.{.voc}).broad(logits.shape()));
        const probabilities = biased_logits.softmax(.voc);

        return probabilities.topK(5, .voc, .{ .descending = true });
    }
};
/// ModernBERT encoder stack: embeddings → encoder layers → final LayerNorm.
pub const ModernBertModel = struct {
    options: ModernBertOptions,
    embeddings: ModernBertEmbeddings,
    layers: []ModernBertEncoderLayer,
    final_norm: zml.nn.LayerNorm,

    /// Propagates runtime options into every layer and annotates the
    /// attention/MLP projection weights for sharding.
    pub fn init(self: *ModernBertModel, options: ModernBertOptions) void {
        self.options = options;
        self.final_norm.eps = 1e-5;
        for (self.layers, 0..) |*encoder_layer, layer_idx| {
            encoder_layer.attn.Wqkv.weight = encoder_layer.attn.Wqkv.weight.withSharding(.{0});
            encoder_layer.attn.Wo.weight = encoder_layer.attn.Wo.weight.withSharding(.{1});
            encoder_layer.mlp.Wi.weight = encoder_layer.mlp.Wi.weight.withSharding(.{0});
            encoder_layer.mlp.Wo.weight = encoder_layer.mlp.Wo.weight.withSharding(.{1});
            if (encoder_layer.attn_norm) |*norm| norm.eps = 1e-5;
            encoder_layer.mlp_norm.eps = 1e-5;
            // Every third layer (0, 3, 6, ...) uses global attention.
            encoder_layer.attn.is_global_attention = (layer_idx % 3 == 0);
            encoder_layer.attn.num_heads = options.num_attention_heads;
        }
    }

    pub fn forward(self: ModernBertModel, input_ids: Tensor) Tensor {
        var hidden_states: Tensor = zml.call(self.embeddings, .forward, .{input_ids}).withTags(.{ .b, .src, .d });
        // Both masks are built once and shared by all layers.
        const global_mask = globalAttnMask(input_ids, hidden_states.dtype(), self.options.pad_token);
        const local_mask = localAttnMask(global_mask, self.options.local_attention);

        // Process through all encoder layers
        for (self.layers) |encoder_layer| {
            hidden_states = zml.call(encoder_layer, .forward, .{
                hidden_states,
                global_mask,
                local_mask,
            });
        }

        // Final layer normalization
        hidden_states = zml.call(self.final_norm, .forward, .{hidden_states});

        return hidden_states;
    }

    /// Find [PAD] tokens in inputs, and assign them a -inf attention mask.
    /// Output shapes follows zml.nn.sdpa convention: .{ .b, .q, .k }
    pub fn globalAttnMask(input_ids: Tensor, dt: zml.DataType, pad_token: u32) Tensor {
        const ids = input_ids.withTags(.{ .b, .k });

        // Mask keys where corresponding token is [PAD]
        const padding = ids.cmp(.EQ, Tensor.scalar(pad_token, ids.dtype()));
        const pad_mask = padding.select(Tensor.constant(.{}, dt.minValue()), Tensor.constant(.{}, dt.zero()));

        // Broadcast to the desired output shape.
        const seq_len = ids.dim(.k);
        const pad_mask_shape = zml.Shape.init(.{ .b = pad_mask.dim(.b), .q = seq_len, .k = seq_len }, dt);
        // Fixed: dropped a leftover debug `.print()` that dumped the mask
        // tensor on every forward pass.
        return pad_mask.broad(pad_mask_shape);
    }

    /// Restrict global attn mask to a sliding window.
    /// Output shapes follows zml.nn.sdpa convention: .{ .b, .q, .k }
    pub fn localAttnMask(global_mask: Tensor, window_size: u32) Tensor {
        const mask_shape = global_mask.shape();

        // Calculate distance between positions
        const rows = Tensor.iota(mask_shape, .q);
        const cols = Tensor.iota(mask_shape, .k);
        const distance = rows.sub(cols).abs();

        // Note: we divide by two because the BERT local attention is symetric around the query token.
        // Create sliding window mask (1 for positions within window, 0 outside)
        const window_mask = distance.cmp(.LE, Tensor.scalar(@divExact(window_size, 2), .i32));
        const minus_inf = Tensor.constant(mask_shape, mask_shape.dtype().minValue());
        // Fixed: dropped a leftover debug `.print()` here as well.
        return window_mask.select(global_mask, minus_inf);
    }
};
/// Prediction head applied before the vocabulary decoder:
/// dense projection → GeLU → LayerNorm.
pub const ModernBertPredictionHead = struct {
    dense: zml.nn.Linear,
    norm: zml.nn.LayerNorm,

    pub fn forward(self: ModernBertPredictionHead, hidden_states: Tensor) Tensor {
        const projected: Tensor = zml.call(self.dense, .forward, .{hidden_states});
        return zml.call(self.norm, .forward, .{projected.gelu()});
    }
};
/// Token-embedding lookup followed by LayerNorm.
pub const ModernBertEmbeddings = struct {
    tok_embeddings: zml.nn.TokenEmbedding,
    norm: zml.nn.LayerNorm,

    pub fn forward(self: ModernBertEmbeddings, input_ids: Tensor) Tensor {
        // Look up embeddings, then normalize them.
        const embedded = zml.call(self.tok_embeddings, .forward, .{input_ids});
        return zml.call(self.norm, .forward, .{embedded});
    }
};
/// One transformer block: (optional) pre-attention norm → attention →
/// residual add → pre-MLP norm → MLP → residual add.
pub const ModernBertEncoderLayer = struct {
    // Layer 0 has no attention norm in the checkpoint, hence optional.
    attn_norm: ?zml.nn.LayerNorm = null,
    attn: ModernBertAttention,
    mlp_norm: zml.nn.LayerNorm,
    mlp: ModernBertMLP,

    pub fn forward(
        self: ModernBertEncoderLayer,
        hidden_states: Tensor,
        global_mask: Tensor,
        local_mask: Tensor,
    ) Tensor {
        // Pre-norm (identity when the layer has no attn_norm).
        const normed = if (self.attn_norm) |norm|
            zml.call(norm, .forward, .{hidden_states})
        else
            hidden_states;

        const attn_out: Tensor = zml.call(self.attn, .forward, .{
            normed,
            global_mask,
            local_mask,
        });

        // First residual connection.
        const residual = hidden_states.add(attn_out);

        // MLP sub-block with its own pre-norm and residual.
        const mlp_in: Tensor = zml.call(self.mlp_norm, .forward, .{residual});
        const mlp_out = zml.call(self.mlp, .forward, .{mlp_in});
        return residual.add(mlp_out);
    }
};
/// Performs multi-headed self attention on a batch of unpadded sequences.
///
/// If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
/// If Flash Attention 2 is not installed, the implementation will use SDPA,
pub const ModernBertAttention = struct {
    Wqkv: zml.nn.Linear,
    Wo: zml.nn.Linear,
    // Both fields are filled in by ModernBertModel.init().
    is_global_attention: bool = false,
    num_heads: i64 = undefined,

    /// sdpa_attention_forward
    pub fn forward(
        self: ModernBertAttention,
        hidden_states: Tensor,
        global_mask: Tensor,
        local_mask: Tensor,
    ) Tensor {
        const batch_size = hidden_states.shape().dim(0);
        const seq_length = hidden_states.shape().dim(1);
        const hidden_size = hidden_states.shape().dim(2);
        const num_heads = self.num_heads;
        // @divExact traps if hidden_size is not a multiple of num_heads
        // (i.e. a misconfigured --num-attention-heads).
        const head_dim = @divExact(hidden_size, num_heads);

        // Project to query, key, value - { batch_size, seq_len, 3 * num_heads * head_dim }
        var qkv: Tensor = zml.call(self.Wqkv, .forward, .{hidden_states});

        // Reshape to { batch_size, seq_len, 3, num_heads, head_dim }
        qkv = qkv.reshape(.{ batch_size, seq_length, 3, num_heads, head_dim }).withTags(.{ .b, .s, .chunk, .h, .hd });

        // Split into query, key, value tensors - each { batch_size, seq_length, num_heads, head_dim }
        var q, var k, var v = qkv.chunkExact(.chunk, 3);
        q = q.squeeze(.chunk);
        k = k.squeeze(.chunk);
        v = v.squeeze(.chunk);

        // Apply rotary position embeddings (RoPE)
        // Layer 0, 3, 6, 9, 12 ... use global RoPE
        // Layer 1, 2, 4, 5, 7, 8, 10, 11 ... use local RoPE
        const rope_opts = zml.nn.RopeOpts{
            .impl = .sequential,
            .freq_base = if (self.is_global_attention) 160_000 else 10_000,
        };
        q = zml.nn.rope(q, null, rope_opts);
        k = zml.nn.rope(k, null, rope_opts);

        // rename dimensions for sdpa
        q = q.rename(.{ .s = .q });
        k = k.rename(.{ .s = .k });
        v = v.rename(.{ .s = .k });

        // Scaled dot product attention; global layers see the padding mask,
        // local layers see the sliding-window mask.
        const attn_output = zml.nn.sdpa(q, k, v, .{ .attn_mask = if (self.is_global_attention) global_mask else local_mask });

        // Merge heads back into the hidden dimension and restore .s naming.
        const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s });

        // Final projection
        return zml.call(self.Wo, .forward, .{attn});
    }
};
/// Switch out the old MLP layers for GeGLU layers, improving on the original BERTs GeLU activation function.
///
/// The GeGLU activation function is a combination of the Gated Linear Unit (GLU) and the Gaussian Error Linear Unit (GeLU).
///
/// see: https://paperswithcode.com/method/geglu
pub const ModernBertMLP = struct {
    Wi: zml.nn.Linear,
    Wo: zml.nn.Linear,

    pub fn forward(self: ModernBertMLP, hidden_states: Tensor) Tensor {
        // Up-projection produces both halves of the GeGLU in one matmul.
        const projected: Tensor = zml.call(self.Wi, .forward, .{hidden_states});
        // Split the last dimension into the activation input and its gate.
        const input, const gate = projected.chunkExact(-1, 2);
        // GeGLU: gelu(input) * gate, then down-project.
        const gated = input.gelu().mul(gate);
        return zml.call(self.Wo, .forward, .{gated});
    }
};

View File

@ -0,0 +1,238 @@
const clap = @import("clap");
const std = @import("std");
const zml = @import("zml");
const asynk = @import("async");
const log = std.log;
const Tensor = zml.Tensor;
const modernbert_module = @import("modernbert.zig");
const ModernBertOptions = modernbert_module.ModernBertOptions;
/// CLI flags for the reference-activation test binary.
const params = clap.parseParamsComptime(
    \\--help print this help
    \\--model <PATH> model weights path
    \\--activations <PATH> model activations path
);
/// Prints a one-line usage string to `stderr` and terminates the process.
/// Only reached on CLI misuse, so exit with a failure status.
/// Fixed: previously exited with 0, hiding the error from shell scripts.
fn printUsageAndExit(stderr: anytype) noreturn {
    stderr.print("usage: ", .{}) catch {};
    clap.usage(stderr, clap.Help, &params) catch {};
    stderr.print("\n", .{}) catch {};
    std.process.exit(1);
}
/// Process entry point: hands control to the async runtime, which runs
/// asyncMain() on its main thread.
pub fn main() !void {
    try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
}
/// Async entry point for the implementation test: loads the model weights
/// and the reference activations captured by activations.py, then checks
/// each ZML layer against its PyTorch counterpart.
pub fn asyncMain() !void {
    // Short lived allocations
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();
    const stderr = std.io.getStdErr().writer();

    // Read CLI arguments
    const parsers = comptime .{
        .PATH = clap.parsers.string,
    };
    var diag: clap.Diagnostic = .{};
    var res = clap.parse(clap.Help, &params, parsers, .{
        .diagnostic = &diag,
        .allocator = allocator,
    }) catch |err| {
        try diag.report(stderr, err);
        try printUsageAndExit(stderr);
    };
    defer res.deinit();
    if (res.args.help != 0) {
        try clap.help(stderr, clap.Help, &params, .{});
        return;
    }

    // Both --model and --activations are mandatory.
    const model_file = res.args.model orelse {
        stderr.print("Error: missing --model=...\n\n", .{}) catch {};
        printUsageAndExit(stderr);
        unreachable;
    };
    const activations_file = res.args.activations orelse {
        stderr.print("Error: missing --activations=...\n\n", .{}) catch {};
        printUsageAndExit(stderr);
        unreachable;
    };

    // Initialize the ZML context
    var context = try zml.Context.init();
    defer context.deinit();

    // Auto-select platform
    const compute_platform = context.autoPlatform(.{});
    log.info("Selected platform: {s}", .{@tagName(compute_platform.target)});

    // Create a dedicated memory arena for model-related allocations (dedicated to model shapes and weights)
    var arena_state = std.heap.ArenaAllocator.init(allocator);
    defer arena_state.deinit();
    const model_arena = arena_state.allocator();

    // Load the model weights file and parse its structure (shape)
    var weights_file = try zml.aio.detectFormatAndOpen(allocator, model_file);
    defer weights_file.deinit();
    log.info("Model contains {d} layers. Loaded from: {s}", .{ weights_file.buffers.count(), model_file });

    // Load the activation data file (torch .pt produced by activations.py)
    const activations = try zml.aio.torch.open(model_arena, activations_file);
    defer activations.deinit();
    log.info("Found {} activations in {s}", .{ activations.buffers.count(), activations_file });

    // Initialize model
    var model = try zml.aio.populateModel(
        modernbert_module.ModernBertForMaskedLM,
        model_arena,
        weights_file,
    );
    // Hard-coded base-checkpoint options; must match the reference run.
    const modernbert_base_options: modernbert_module.ModernBertOptions = .{
        .num_attention_heads = 12,
        .tie_word_embeddings = true,
        .pad_token = 50283,
        .local_attention = 128,
    };
    model.init(modernbert_base_options);

    // Load model weights
    const model_weights = try zml.aio.loadModelBuffers(
        modernbert_module.ModernBertForMaskedLM,
        model,
        weights_file,
        model_arena,
        compute_platform,
    );

    // Test implementation
    try testImplementation(compute_platform, model, model_weights, activations);
}
/// Compares individual ZML layers against the captured reference
/// activations. Each testLayer call looks up the named activation pair
/// (input/output) and checks the ZML output within the given tolerance.
/// Tolerances grow with depth as numerical error accumulates.
fn testImplementation(
    compute_platform: zml.Platform,
    model: modernbert_module.ModernBertForMaskedLM,
    model_weights: zml.Bufferized(modernbert_module.ModernBertForMaskedLM),
    activations: zml.aio.BufferStore,
) !void {
    // Embedding sub-layers first, then the combined embedding module.
    try zml.testing.testLayer(
        compute_platform,
        activations,
        "model.model.embeddings.tok_embeddings",
        model.model.embeddings.tok_embeddings,
        model_weights.model.embeddings.tok_embeddings,
        1e-6,
    );
    try zml.testing.testLayer(
        compute_platform,
        activations,
        "model.model.embeddings.norm",
        model.model.embeddings.norm,
        model_weights.model.embeddings.norm,
        1e-3,
    );
    try zml.testing.testLayer(
        compute_platform,
        activations,
        "model.model.embeddings",
        model.model.embeddings,
        model_weights.model.embeddings,
        1e-3,
    );
    try zml.testing.testLayer(
        compute_platform,
        activations,
        "model.model.final_norm",
        model.model.final_norm,
        model_weights.model.final_norm,
        1e-5,
    );
    // Spot-check layer 2 (local attention) sub-modules and the full block.
    try zml.testing.testLayer(
        compute_platform,
        activations,
        "model.model.layers.2.mlp",
        model.model.layers[2].mlp,
        model_weights.model.layers[2].mlp,
        2e-3,
    );
    try zml.testing.testLayer(
        compute_platform,
        activations,
        "model.model.layers.2.mlp_norm",
        model.model.layers[2].mlp_norm,
        model_weights.model.layers[2].mlp_norm,
        1e-4,
    );
    try zml.testing.testLayer(
        compute_platform,
        activations,
        "model.model.layers.2.attn",
        model.model.layers[2].attn,
        model_weights.model.layers[2].attn,
        1e-6,
    );
    try zml.testing.testLayer(
        compute_platform,
        activations,
        "model.model.layers.2",
        model.model.layers[2],
        model_weights.model.layers[2],
        2e-3,
    );
    // Layer 3 uses global attention (layer_idx % 3 == 0).
    try zml.testing.testLayer(
        compute_platform,
        activations,
        "model.model.layers.3.attn",
        model.model.layers[3].attn,
        model_weights.model.layers[3].attn,
        1e-5,
    );
    // Whole encoder stack end-to-end.
    try zml.testing.testLayer(
        compute_platform,
        activations,
        "model.model",
        model.model,
        model_weights.model,
        1e-2,
    );
    // The reference run ties decoder weights with the token embeddings, so
    // emulate that here with a small wrapper module.
    const TiedDecoder = struct {
        weight: Tensor,
        bias: Tensor,

        pub fn forward(self: @This(), head_outputs: Tensor) Tensor {
            const results = head_outputs.withTags(.{ .b, .s, .d }).dot(self.weight.withTags(.{ .voc, .d }), .{.d});
            return results.add(self.bias.withTags(.{.voc}).broad(results.shape()));
        }
    };
    try zml.testing.testLayer(
        compute_platform,
        activations,
        "model.decoder",
        TiedDecoder{ .weight = model.decoder.weight orelse model.model.embeddings.tok_embeddings.weight, .bias = model.decoder.bias },
        .{ .weight = model_weights.model.embeddings.tok_embeddings.weight, .bias = model_weights.decoder.bias },
        1e-3,
    );
    try zml.testing.testLayer(
        compute_platform,
        activations,
        "model.head",
        model.head,
        model_weights.head,
        0.1, // TODO: too high tolerance
    );
}