Add tests for the ModernBERT example, covering activation utilities, build setup, and example Zig code.
This commit is contained in:
parent
ab5ad874c3
commit
17d02621e7
@ -139,6 +139,53 @@ http_file(
|
||||
url = "https://github.com/karpathy/llama2.c/raw/c02865df300f3bd9e567ce061000dc23bf785a17/tokenizer.bin",
|
||||
)
|
||||
|
||||
# ModernBERT
|
||||
huggingface.model(
|
||||
name = "ModernBERT-base",
|
||||
build_file_content = """\
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
filegroup(
|
||||
name = "model",
|
||||
srcs = ["model.safetensors"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "tokenizer",
|
||||
srcs = ["tokenizer.json"],
|
||||
)
|
||||
""",
|
||||
commit = "94032bb66234a691cf6248265170006a7ced4970",
|
||||
includes = [
|
||||
"model.safetensors",
|
||||
"tokenizer.json",
|
||||
],
|
||||
model = "answerdotai/ModernBERT-base",
|
||||
)
|
||||
use_repo(huggingface, "ModernBERT-base")
|
||||
|
||||
huggingface.model(
|
||||
name = "ModernBERT-large",
|
||||
build_file_content = """\
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
filegroup(
|
||||
name = "model",
|
||||
srcs = ["model.safetensors"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "tokenizer",
|
||||
srcs = ["tokenizer.json"],
|
||||
)
|
||||
""",
|
||||
commit = "4bbcbf40bed02ce487125bcb3c897ea9bdc88340",
|
||||
includes = [
|
||||
"model.safetensors",
|
||||
"tokenizer.json",
|
||||
],
|
||||
model = "answerdotai/ModernBERT-large",
|
||||
)
|
||||
use_repo(huggingface, "ModernBERT-large")
|
||||
|
||||
bazel_dep(name = "rules_rust", version = "0.57.1")
|
||||
rust = use_extension("@rules_rust//rust:extensions.bzl", "rust")
|
||||
rust.toolchain(
|
||||
|
||||
63
examples/modernbert/BUILD.bazel
Normal file
63
examples/modernbert/BUILD.bazel
Normal file
@ -0,0 +1,63 @@
|
||||
load("@zml//bazel:zig.bzl", "zig_cc_binary")
|
||||
|
||||
zig_cc_binary(
|
||||
name = "modernbert",
|
||||
srcs = ["modernbert.zig"],
|
||||
main = "main.zig",
|
||||
deps = [
|
||||
"@com_github_hejsil_clap//:clap",
|
||||
"@zml//async",
|
||||
"@zml//stdx",
|
||||
"@zml//zml",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "ModernBERT-base",
|
||||
args = [
|
||||
"--model=$(location @ModernBERT-base//:model.safetensors)",
|
||||
"--tokenizer=$(location @ModernBERT-base//:tokenizer)",
|
||||
"--num-attention-heads=12",
|
||||
"--tie-word-embeddings=true",
|
||||
],
|
||||
data = [
|
||||
"@ModernBERT-base//:model.safetensors",
|
||||
"@ModernBERT-base//:tokenizer",
|
||||
],
|
||||
deps = [":modernbert_lib"],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "ModernBERT-large",
|
||||
args = [
|
||||
"--model=$(location @ModernBERT-large//:model.safetensors)",
|
||||
"--tokenizer=$(location @ModernBERT-large//:tokenizer)",
|
||||
"--num-attention-heads=16",
|
||||
"--tie-word-embeddings=true",
|
||||
],
|
||||
data = [
|
||||
"@ModernBERT-large//:model.safetensors",
|
||||
"@ModernBERT-large//:tokenizer",
|
||||
],
|
||||
deps = [":modernbert_lib"],
|
||||
)
|
||||
|
||||
zig_cc_binary(
|
||||
name = "test-implementation",
|
||||
srcs = ["modernbert.zig"],
|
||||
args = [
|
||||
"--model=$(location @ModernBERT-base//:model.safetensors)",
|
||||
],
|
||||
data = [
|
||||
"@ModernBERT-base//:model.safetensors",
|
||||
],
|
||||
main = "test.zig",
|
||||
tags = [
|
||||
"no_ci",
|
||||
],
|
||||
deps = [
|
||||
"@com_github_hejsil_clap//:clap",
|
||||
"@zml//async",
|
||||
"@zml//zml",
|
||||
],
|
||||
)
|
||||
55
examples/modernbert/activations/activations.py
Normal file
55
examples/modernbert/activations/activations.py
Normal file
@ -0,0 +1,55 @@
|
||||
import logging
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
from tools.zml_utils import ActivationCollector
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
|
||||
)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
MODEL_NAME: str = "answerdotai/ModernBERT-base"
|
||||
|
||||
|
||||
def main() -> None:
|
||||
try:
|
||||
log.info("Start running main()")
|
||||
|
||||
log.info(f"CPU capability : `{torch.backends.cpu.get_cpu_capability()}`")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
log.info(f"Loading model : `{MODEL_NAME}`")
|
||||
|
||||
fill_mask_pipeline = pipeline(
|
||||
"fill-mask",
|
||||
model=MODEL_NAME,
|
||||
device_map=device,
|
||||
)
|
||||
model, tokenizer = fill_mask_pipeline.model, fill_mask_pipeline.tokenizer
|
||||
log.info(
|
||||
f"Model loaded successfully {model.config.architectures} - `{model.config.torch_dtype}` - {tokenizer.model_max_length} max tokens" # noqa: E501
|
||||
)
|
||||
|
||||
# Wrap the pipeline, and extract activations.
|
||||
# Activations files can be huge for big models,
|
||||
# so let's stop collecting after 1000 layers.
|
||||
zml_pipeline = ActivationCollector(
|
||||
fill_mask_pipeline, max_layers=1000, stop_after_first_step=True
|
||||
)
|
||||
|
||||
input_text = "Paris is the [MASK] of France."
|
||||
outputs, activations = zml_pipeline(input_text)
|
||||
log.info(f"ouputs : {outputs}")
|
||||
|
||||
filename = MODEL_NAME.split("/")[-1] + ".activations.pt"
|
||||
torch.save(activations, filename)
|
||||
log.info(f"Saved {len(activations)} activations to {filename}")
|
||||
|
||||
log.info("End running main()")
|
||||
except Exception as exception:
|
||||
log.error(exception)
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
4
examples/modernbert/activations/requirements.in
Normal file
4
examples/modernbert/activations/requirements.in
Normal file
@ -0,0 +1,4 @@
|
||||
torch
|
||||
transformers==4.48.1
|
||||
accelerate
|
||||
numpy==1.26.4
|
||||
275
examples/modernbert/main.zig
Normal file
275
examples/modernbert/main.zig
Normal file
@ -0,0 +1,275 @@
|
||||
const std = @import("std");
|
||||
const log = std.log.scoped(.modernbert);
|
||||
|
||||
const modernbert = @import("modernbert.zig");
|
||||
|
||||
const asynk = @import("async");
|
||||
const clap = @import("clap");
|
||||
const stdx = @import("stdx");
|
||||
const zml = @import("zml");
|
||||
const Tensor = zml.Tensor;
|
||||
|
||||
pub const std_options = .{
|
||||
.log_level = .info,
|
||||
.log_scope_levels = &[_]std.log.ScopeLevel{
|
||||
.{ .scope = .modernbert, .level = .info },
|
||||
},
|
||||
.logFn = asynk.logFn(std.log.defaultLog),
|
||||
};
|
||||
|
||||
const params = clap.parseParamsComptime(
|
||||
\\--help print this help
|
||||
\\--text <STRING> the prompt
|
||||
\\--model <PATH> model path
|
||||
\\--tokenizer <PATH> tokenizer path
|
||||
\\--seq-len <UINT> sequence length
|
||||
\\--num-attention-heads <UINT> number of attention heads
|
||||
\\--tie-word-embeddings <BOOL> default: false: tied weights
|
||||
\\--create-options <STRING> platform creation options JSON, defaults to {}
|
||||
\\--sharding <BOOL> default: true: sharding on or off
|
||||
);
|
||||
|
||||
const clap_parsers = .{
|
||||
.BOOL = bool_parser,
|
||||
.UINT = clap.parsers.int(usize, 0),
|
||||
.STRING = clap.parsers.string,
|
||||
.PATH = clap.parsers.string,
|
||||
};
|
||||
|
||||
pub fn main() !void {
|
||||
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
|
||||
}
|
||||
|
||||
pub fn asyncMain() !void {
|
||||
const allocator = std.heap.c_allocator;
|
||||
const stderr = std.io.getStdErr().writer();
|
||||
|
||||
var diag: clap.Diagnostic = .{};
|
||||
var cli = clap.parse(clap.Help, ¶ms, clap_parsers, .{
|
||||
.diagnostic = &diag,
|
||||
.allocator = allocator,
|
||||
}) catch |err| {
|
||||
try diag.report(stderr, err);
|
||||
try printUsageAndExit(stderr);
|
||||
};
|
||||
defer cli.deinit();
|
||||
|
||||
if (cli.args.help != 0) {
|
||||
try clap.help(stderr, clap.Help, ¶ms, .{});
|
||||
return;
|
||||
}
|
||||
|
||||
const tmp = try std.fs.openDirAbsolute("/tmp", .{});
|
||||
try tmp.makePath("zml/modernbert/cache");
|
||||
|
||||
// Create ZML context
|
||||
var context = try zml.Context.init();
|
||||
defer context.deinit();
|
||||
|
||||
// Platform and compilation options
|
||||
const create_opts_json = cli.args.@"create-options" orelse "{}";
|
||||
const create_opts = try std.json.parseFromSliceLeaky(zml.Platform.CreateOptions, allocator, create_opts_json, .{});
|
||||
const compilation_options = zml.CompilationOptions{
|
||||
.xla_dump_to = "/tmp/zml/modernbert",
|
||||
.sharding_enabled = cli.args.sharding orelse true,
|
||||
};
|
||||
|
||||
// Auto-select platform
|
||||
const platform = context.autoPlatform(create_opts).withCompilationOptions(compilation_options);
|
||||
context.printAvailablePlatforms(platform);
|
||||
|
||||
// Detects the format of the model file (base on filename) and open it.
|
||||
const model_file = cli.args.model orelse {
|
||||
stderr.print("Error: missing --model=...\n\n", .{}) catch {};
|
||||
printUsageAndExit(stderr);
|
||||
unreachable;
|
||||
};
|
||||
var tensor_store = try zml.aio.detectFormatAndOpen(allocator, model_file);
|
||||
defer tensor_store.deinit();
|
||||
|
||||
// Memory arena dedicated to model shapes and weights
|
||||
var arena_state = std.heap.ArenaAllocator.init(allocator);
|
||||
defer arena_state.deinit();
|
||||
const model_arena = arena_state.allocator();
|
||||
|
||||
var tokenizer = blk: {
|
||||
if (cli.args.tokenizer) |tok| {
|
||||
log.info("\tLoading tokenizer from {s}", .{tok});
|
||||
var timer = try stdx.time.Timer.start();
|
||||
defer log.info("✅\tLoaded tokenizer from {s} [{}]", .{ tok, timer.read() });
|
||||
|
||||
break :blk try zml.tokenizer.Tokenizer.fromFile(model_arena, tok);
|
||||
} else {
|
||||
log.err("Error: missing --tokenizer", .{});
|
||||
return;
|
||||
}
|
||||
};
|
||||
defer tokenizer.deinit();
|
||||
|
||||
// Create the model struct, with tensor shapes extracted from the tensor_store
|
||||
// TODO: read from config.json
|
||||
const modernbert_options = modernbert.ModernBertOptions{
|
||||
.pad_token = tokenizer.tokenToId("[PAD]") orelse return error.NoSuchToken,
|
||||
.num_attention_heads = @intCast(cli.args.@"num-attention-heads" orelse 12),
|
||||
.tie_word_embeddings = cli.args.@"tie-word-embeddings" orelse false,
|
||||
.local_attention = 128,
|
||||
};
|
||||
var modern_bert_for_masked_lm = try zml.aio.populateModel(modernbert.ModernBertForMaskedLM, model_arena, tensor_store);
|
||||
modern_bert_for_masked_lm.init(modernbert_options);
|
||||
|
||||
log.info("\tModernBERT options: {}", .{modernbert_options});
|
||||
|
||||
// Prepare shapes for compilation
|
||||
const seq_len = @as(i64, @intCast(cli.args.@"seq-len" orelse 256));
|
||||
const input_shape = zml.Shape.init(.{ .b = 1, .s = seq_len }, .u32);
|
||||
|
||||
var start = try std.time.Timer.start();
|
||||
|
||||
// Load weights
|
||||
log.info("\tLoading ModernBERT weights from {?s}...", .{model_file});
|
||||
var bert_weights = try zml.aio.loadBuffers(modernbert.ModernBertForMaskedLM, .{modernbert_options}, tensor_store, model_arena, platform);
|
||||
defer zml.aio.unloadBuffers(&bert_weights);
|
||||
log.info("✅\tLoaded weights in {d}ms", .{start.read() / std.time.ns_per_ms});
|
||||
|
||||
// Compile the model
|
||||
log.info("\tCompiling ModernBERT model...", .{});
|
||||
var fut_mod = try asynk.asyncc(zml.compile, .{
|
||||
allocator,
|
||||
modernbert.ModernBertForMaskedLM.forward,
|
||||
.{modernbert_options},
|
||||
.{input_shape},
|
||||
tensor_store,
|
||||
platform,
|
||||
});
|
||||
var bert_module = (try fut_mod.awaitt()).prepare(bert_weights);
|
||||
defer bert_module.deinit();
|
||||
log.info("✅\tLoaded weights and compiled model in {d}ms", .{start.read() / std.time.ns_per_ms});
|
||||
|
||||
const text = cli.args.text orelse "Paris is the [MASK] of France.";
|
||||
log.info("\tInput text: {s}", .{text});
|
||||
|
||||
try unmask(allocator, bert_module, tokenizer, seq_len, text);
|
||||
}
|
||||
|
||||
/// fill-mask pipeline
|
||||
/// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/fill_mask.py
|
||||
pub fn unmask(
|
||||
allocator: std.mem.Allocator,
|
||||
mod: zml.ModuleExe(modernbert.ModernBertForMaskedLM.forward),
|
||||
tokenizer: zml.tokenizer.Tokenizer,
|
||||
seq_len: i64,
|
||||
text: []const u8,
|
||||
) !void {
|
||||
var tokenizer_decoder = try tokenizer.decoder();
|
||||
defer tokenizer_decoder.deinit();
|
||||
|
||||
const pad_token = tokenizer.tokenToId("[PAD]") orelse return error.NoSuchToken;
|
||||
const mask_token = tokenizer.tokenToId("[MASK]") orelse return error.NoSuchToken;
|
||||
|
||||
// Tokenize input text
|
||||
const tokens: []const u32 = try tokenize(allocator, tokenizer, text);
|
||||
defer allocator.free(tokens);
|
||||
|
||||
// Find "[MASK]" positions
|
||||
const mask_positions = try findMaskPositions(allocator, tokens, mask_token);
|
||||
defer allocator.free(mask_positions);
|
||||
|
||||
// Prepare input tensors
|
||||
const inputs = try prepareTensorInputs(allocator, tokens, seq_len, pad_token);
|
||||
defer allocator.free(inputs);
|
||||
|
||||
// Create input tensors (on the accelerator)
|
||||
const input_shape = zml.Shape.init(.{ .b = 1, .s = seq_len }, .i64);
|
||||
const input_ids_tensor = try zml.Buffer.fromSlice(mod.platform(), input_shape.dims(), inputs);
|
||||
defer input_ids_tensor.deinit();
|
||||
|
||||
// Model inference (retrieve indices)
|
||||
var inference_timer = try std.time.Timer.start();
|
||||
var topk = mod.call(.{input_ids_tensor});
|
||||
defer zml.aio.unloadBuffers(&topk);
|
||||
const inference_time = inference_timer.read();
|
||||
|
||||
// Transfer the result to host memory (CPU)
|
||||
var indices_host_buffer = try topk.indices.toHostAlloc(allocator);
|
||||
defer indices_host_buffer.deinit(allocator);
|
||||
var values_host_buffer = try topk.values.toHostAlloc(allocator);
|
||||
defer values_host_buffer.deinit(allocator);
|
||||
|
||||
// We consider only the first occurrence of [MASK], which has five predictions
|
||||
const pred_offset = mask_positions[0] * 5;
|
||||
const predictions = indices_host_buffer.items(i32)[pred_offset..][0..5];
|
||||
const scores = values_host_buffer.items(f32)[pred_offset..][0..5];
|
||||
|
||||
// Log timing information
|
||||
log.info("⏱️\tModel inference in {d}ms", .{inference_time / std.time.ns_per_ms});
|
||||
|
||||
log.info("✅\tTop 5 predictions:", .{});
|
||||
for (predictions, scores) |token_id, score| {
|
||||
const token_text = try tokenizer_decoder.next(@intCast(token_id));
|
||||
if (token_text) |word| {
|
||||
log.info("\t • score: {d:.4} word: '{s}' token: {}", .{ score, word, token_id });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tokenize(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, prompt: []const u8) ![]const u32 {
|
||||
var tokens = std.ArrayList(u32).init(allocator);
|
||||
var encoder = try tokenizer.encoder();
|
||||
defer encoder.deinit();
|
||||
|
||||
const bos = tokenizer.tokenToId("[CLS]") orelse return error.NoSuchToken;
|
||||
const eos = tokenizer.tokenToId("[SEP]") orelse return error.NoSuchToken;
|
||||
|
||||
try tokens.append(bos);
|
||||
try tokens.appendSlice(try encoder.encode(prompt));
|
||||
try tokens.append(eos);
|
||||
|
||||
return tokens.toOwnedSlice();
|
||||
}
|
||||
|
||||
fn findMaskPositions(allocator: std.mem.Allocator, tokens: []const u32, mask_token: u32) ![]usize {
|
||||
var mask_positions = std.ArrayList(usize).init(allocator);
|
||||
defer mask_positions.deinit();
|
||||
|
||||
for (tokens, 0..) |token, i| {
|
||||
if (token == mask_token) {
|
||||
try mask_positions.append(i);
|
||||
}
|
||||
}
|
||||
|
||||
if (mask_positions.items.len == 0) {
|
||||
log.err("Input text must contains `[MASK]`", .{});
|
||||
return error.InvalidInput;
|
||||
}
|
||||
|
||||
if (mask_positions.items.len > 1) log.warn("Currently only supporting one [MASK] per input", .{});
|
||||
|
||||
return mask_positions.toOwnedSlice();
|
||||
}
|
||||
|
||||
fn prepareTensorInputs(
|
||||
allocator: std.mem.Allocator,
|
||||
tokens: []const u32,
|
||||
seq_len: i64,
|
||||
pad_token: u32,
|
||||
) ![]u32 {
|
||||
const input_ids = try allocator.alloc(u32, @intCast(seq_len));
|
||||
|
||||
@memset(input_ids, pad_token);
|
||||
for (tokens, 0..) |token, i| {
|
||||
input_ids[i] = @intCast(token);
|
||||
}
|
||||
|
||||
return input_ids;
|
||||
}
|
||||
|
||||
fn bool_parser(in: []const u8) error{}!bool {
|
||||
return std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
|
||||
}
|
||||
|
||||
fn printUsageAndExit(stderr: anytype) noreturn {
|
||||
stderr.print("usage: ", .{}) catch {};
|
||||
clap.usage(stderr, clap.Help, ¶ms) catch {};
|
||||
stderr.print("\n", .{}) catch {};
|
||||
std.process.exit(0);
|
||||
}
|
||||
268
examples/modernbert/modernbert.zig
Normal file
268
examples/modernbert/modernbert.zig
Normal file
@ -0,0 +1,268 @@
|
||||
const std = @import("std");
|
||||
const log = std.log.scoped(.modernbert);
|
||||
|
||||
const asynk = @import("async");
|
||||
const stdx = @import("stdx");
|
||||
const zml = @import("zml");
|
||||
|
||||
const Tensor = zml.Tensor;
|
||||
|
||||
pub const ModernBertOptions = struct {
|
||||
num_attention_heads: i64,
|
||||
pad_token: u32,
|
||||
local_attention: u32,
|
||||
tie_word_embeddings: bool = false,
|
||||
};
|
||||
|
||||
pub const ModernBertForMaskedLM = struct {
|
||||
model: ModernBertModel,
|
||||
head: ModernBertPredictionHead,
|
||||
decoder: struct { weight: ?zml.Tensor, bias: zml.Tensor },
|
||||
|
||||
pub fn init(self: *ModernBertForMaskedLM, options: ModernBertOptions) void {
|
||||
self.model.init(options);
|
||||
self.head.norm.eps = 1e-5;
|
||||
|
||||
self.head.dense.weight = self.head.dense.weight.withSharding(.{0});
|
||||
|
||||
if (options.tie_word_embeddings == true) {
|
||||
self.decoder.weight = null;
|
||||
} else if (self.decoder.weight) |decoder_weight| {
|
||||
self.decoder.weight = decoder_weight.withSharding(.{1});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(self: ModernBertForMaskedLM, input_ids: Tensor) zml.Tensor.ArgMaxRes {
|
||||
const outputs: Tensor = zml.call(self.model, .forward, .{input_ids});
|
||||
const head_outputs: Tensor = zml.call(self.head, .forward, .{outputs});
|
||||
|
||||
// either use decoder or tied weights
|
||||
const decoder_weights = self.decoder.weight orelse self.model.embeddings.tok_embeddings.weight;
|
||||
|
||||
const logits = head_outputs.withTags(.{ .b, .s, .d }).dot(decoder_weights.withTags(.{ .voc, .d }), .{.d});
|
||||
const biased_logits = logits.add(self.decoder.bias.withTags(.{.voc}).broad(logits.shape()));
|
||||
|
||||
const probabilities = biased_logits.softmax(.voc);
|
||||
return probabilities.topK(5, .voc, .{ .descending = true });
|
||||
}
|
||||
};
|
||||
|
||||
pub const ModernBertModel = struct {
|
||||
options: ModernBertOptions,
|
||||
embeddings: ModernBertEmbeddings,
|
||||
layers: []ModernBertEncoderLayer,
|
||||
final_norm: zml.nn.LayerNorm,
|
||||
|
||||
pub fn init(self: *ModernBertModel, options: ModernBertOptions) void {
|
||||
self.options = options;
|
||||
self.final_norm.eps = 1e-5;
|
||||
for (self.layers, 0..) |*encoder_layer, layer_idx| {
|
||||
encoder_layer.attn.Wqkv.weight = encoder_layer.attn.Wqkv.weight.withSharding(.{0});
|
||||
encoder_layer.attn.Wo.weight = encoder_layer.attn.Wo.weight.withSharding(.{1});
|
||||
|
||||
encoder_layer.mlp.Wi.weight = encoder_layer.mlp.Wi.weight.withSharding(.{0});
|
||||
encoder_layer.mlp.Wo.weight = encoder_layer.mlp.Wo.weight.withSharding(.{1});
|
||||
|
||||
if (encoder_layer.attn_norm) |*norm| norm.eps = 1e-5;
|
||||
encoder_layer.mlp_norm.eps = 1e-5;
|
||||
encoder_layer.attn.is_global_attention = (layer_idx % 3 == 0);
|
||||
encoder_layer.attn.num_heads = options.num_attention_heads;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(self: ModernBertModel, input_ids: Tensor) Tensor {
|
||||
var hidden_states: Tensor = zml.call(self.embeddings, .forward, .{input_ids}).withTags(.{ .b, .src, .d });
|
||||
|
||||
const global_mask = globalAttnMask(input_ids, hidden_states.dtype(), self.options.pad_token);
|
||||
const local_mask = localAttnMask(global_mask, self.options.local_attention);
|
||||
|
||||
// Process through all encoder layers
|
||||
for (self.layers) |encoder_layer| {
|
||||
hidden_states = zml.call(encoder_layer, .forward, .{
|
||||
hidden_states,
|
||||
global_mask,
|
||||
local_mask,
|
||||
});
|
||||
}
|
||||
|
||||
// Final layer normalization
|
||||
hidden_states = zml.call(self.final_norm, .forward, .{hidden_states});
|
||||
|
||||
return hidden_states;
|
||||
}
|
||||
|
||||
/// Find [PAD] tokens in inputs, and assign them a -inf attention mask.
|
||||
/// Output shapes follows zml.nn.sdpa convention: .{ .b, .q, .k }
|
||||
pub fn globalAttnMask(input_ids: Tensor, dt: zml.DataType, pad_token: u32) Tensor {
|
||||
const ids = input_ids.withTags(.{ .b, .k });
|
||||
|
||||
// Mask keys where corresponding token is [PAD]
|
||||
const padding = ids.cmp(.EQ, Tensor.scalar(pad_token, ids.dtype()));
|
||||
const pad_mask = padding.select(Tensor.constant(.{}, dt.minValue()), Tensor.constant(.{}, dt.zero()));
|
||||
|
||||
// Broadcast to the desired output shape.
|
||||
const seq_len = ids.dim(.k);
|
||||
const pad_mask_shape = zml.Shape.init(.{ .b = pad_mask.dim(.b), .q = seq_len, .k = seq_len }, dt);
|
||||
return pad_mask.broad(pad_mask_shape).print();
|
||||
}
|
||||
|
||||
/// Restrict global attn mask to a sliding window.
|
||||
/// Output shapes follows zml.nn.sdpa convention: .{ .b, .q, .k }
|
||||
pub fn localAttnMask(global_mask: Tensor, window_size: u32) Tensor {
|
||||
const mask_shape = global_mask.shape();
|
||||
|
||||
// Calculate distance between positions
|
||||
const rows = Tensor.iota(mask_shape, .q);
|
||||
const cols = Tensor.iota(mask_shape, .k);
|
||||
const distance = rows.sub(cols).abs();
|
||||
|
||||
// Note: we divide by two because the BERT local attention is symetric around the query token.
|
||||
// Create sliding window mask (1 for positions within window, 0 outside)
|
||||
const window_mask = distance.cmp(.LE, Tensor.scalar(@divExact(window_size, 2), .i32));
|
||||
const minus_inf = Tensor.constant(mask_shape, mask_shape.dtype().minValue());
|
||||
return window_mask.select(global_mask, minus_inf).print();
|
||||
}
|
||||
};
|
||||
|
||||
pub const ModernBertPredictionHead = struct {
|
||||
dense: zml.nn.Linear,
|
||||
norm: zml.nn.LayerNorm,
|
||||
|
||||
pub fn forward(self: ModernBertPredictionHead, hidden_states: Tensor) Tensor {
|
||||
const dense_output: Tensor = zml.call(self.dense, .forward, .{hidden_states});
|
||||
|
||||
const activated_output = dense_output.gelu();
|
||||
|
||||
return zml.call(self.norm, .forward, .{activated_output});
|
||||
}
|
||||
};
|
||||
|
||||
pub const ModernBertEmbeddings = struct {
|
||||
tok_embeddings: zml.nn.TokenEmbedding,
|
||||
norm: zml.nn.LayerNorm,
|
||||
|
||||
pub fn forward(self: ModernBertEmbeddings, input_ids: Tensor) Tensor {
|
||||
// Perform tok_embeddings
|
||||
const hidden_states = zml.call(self.tok_embeddings, .forward, .{input_ids});
|
||||
|
||||
// Perform norm
|
||||
return zml.call(self.norm, .forward, .{hidden_states});
|
||||
}
|
||||
};
|
||||
|
||||
pub const ModernBertEncoderLayer = struct {
|
||||
attn_norm: ?zml.nn.LayerNorm = null,
|
||||
attn: ModernBertAttention,
|
||||
mlp_norm: zml.nn.LayerNorm,
|
||||
mlp: ModernBertMLP,
|
||||
|
||||
pub fn forward(
|
||||
self: ModernBertEncoderLayer,
|
||||
hidden_states: Tensor,
|
||||
global_mask: Tensor,
|
||||
local_mask: Tensor,
|
||||
) Tensor {
|
||||
const attn_norm_output = if (self.attn_norm) |attn_norm|
|
||||
zml.call(attn_norm, .forward, .{hidden_states})
|
||||
else
|
||||
hidden_states;
|
||||
|
||||
const attn_output: Tensor = zml.call(self.attn, .forward, .{
|
||||
attn_norm_output,
|
||||
global_mask,
|
||||
local_mask,
|
||||
});
|
||||
|
||||
var output = hidden_states.add(attn_output);
|
||||
|
||||
const mlp_norm_output: Tensor = zml.call(self.mlp_norm, .forward, .{output});
|
||||
const mlp_output = zml.call(self.mlp, .forward, .{mlp_norm_output});
|
||||
output = output.add(mlp_output);
|
||||
|
||||
return output;
|
||||
}
|
||||
};
|
||||
|
||||
/// Performs multi-headed self attention on a batch of unpadded sequences.
|
||||
///
|
||||
/// If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
|
||||
/// If Flash Attention 2 is not installed, the implementation will use SDPA,
|
||||
pub const ModernBertAttention = struct {
|
||||
Wqkv: zml.nn.Linear,
|
||||
Wo: zml.nn.Linear,
|
||||
is_global_attention: bool = false,
|
||||
num_heads: i64 = undefined,
|
||||
|
||||
/// sdpa_attention_forward
|
||||
pub fn forward(
|
||||
self: ModernBertAttention,
|
||||
hidden_states: Tensor,
|
||||
global_mask: Tensor,
|
||||
local_mask: Tensor,
|
||||
) Tensor {
|
||||
const batch_size = hidden_states.shape().dim(0);
|
||||
const seq_length = hidden_states.shape().dim(1);
|
||||
const hidden_size = hidden_states.shape().dim(2);
|
||||
const num_heads = self.num_heads;
|
||||
const head_dim = @divExact(hidden_size, num_heads);
|
||||
|
||||
// Project to query, key, value - { batch_size, seq_len, 3 * num_heads * head_dim }
|
||||
var qkv: Tensor = zml.call(self.Wqkv, .forward, .{hidden_states});
|
||||
|
||||
// Reshape to { batch_size, seq_len, 3, num_heads, head_dim }
|
||||
qkv = qkv.reshape(.{ batch_size, seq_length, 3, num_heads, head_dim }).withTags(.{ .b, .s, .chunk, .h, .hd });
|
||||
|
||||
// Split into query, key, value tensors - each { batch_size, seq_length, num_heads, head_dim }
|
||||
var q, var k, var v = qkv.chunkExact(.chunk, 3);
|
||||
q = q.squeeze(.chunk);
|
||||
k = k.squeeze(.chunk);
|
||||
v = v.squeeze(.chunk);
|
||||
|
||||
// Apply rotary position embeddings (RoPE)
|
||||
// Layer 0, 3, 6, 9, 12 ... use global RoPE
|
||||
// Layer 1, 2, 4, 5, 7, 8, 10, 11 ... use local RoPE
|
||||
const rope_opts = zml.nn.RopeOpts{
|
||||
.impl = .sequential,
|
||||
.freq_base = if (self.is_global_attention) 160_000 else 10_000,
|
||||
};
|
||||
|
||||
q = zml.nn.rope(q, null, rope_opts);
|
||||
k = zml.nn.rope(k, null, rope_opts);
|
||||
|
||||
// rename dimensions for sdpa
|
||||
q = q.rename(.{ .s = .q });
|
||||
k = k.rename(.{ .s = .k });
|
||||
v = v.rename(.{ .s = .k });
|
||||
|
||||
// Scaled dot product attention
|
||||
const attn_output = zml.nn.sdpa(q, k, v, .{ .attn_mask = if (self.is_global_attention) global_mask else local_mask });
|
||||
const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s });
|
||||
|
||||
// Final projection
|
||||
return zml.call(self.Wo, .forward, .{attn});
|
||||
}
|
||||
};
|
||||
|
||||
/// Switch out the old MLP layers for GeGLU layers, improving on the original BERT’s GeLU activation function.
|
||||
///
|
||||
/// The GeGLU activation function is a combination of the Gated Linear Unit (GLU) and the Gaussian Error Linear Unit (GeLU).
|
||||
///
|
||||
/// see: https://paperswithcode.com/method/geglu
|
||||
pub const ModernBertMLP = struct {
|
||||
Wi: zml.nn.Linear,
|
||||
Wo: zml.nn.Linear,
|
||||
|
||||
pub fn forward(self: ModernBertMLP, hidden_states: Tensor) Tensor {
|
||||
// Perform Wi
|
||||
const wi_output: Tensor = zml.call(self.Wi, .forward, .{hidden_states});
|
||||
|
||||
// Split into input and gate tensors along the last dimension
|
||||
const input, const gate = wi_output.chunkExact(-1, 2);
|
||||
|
||||
// Apply activation
|
||||
const activated_input = input.gelu().mul(gate);
|
||||
|
||||
// Perform Wo
|
||||
return zml.call(self.Wo, .forward, .{activated_input});
|
||||
}
|
||||
};
|
||||
238
examples/modernbert/test.zig
Normal file
238
examples/modernbert/test.zig
Normal file
@ -0,0 +1,238 @@
|
||||
const clap = @import("clap");
|
||||
const std = @import("std");
|
||||
const zml = @import("zml");
|
||||
const asynk = @import("async");
|
||||
const log = std.log;
|
||||
const Tensor = zml.Tensor;
|
||||
const modernbert_module = @import("modernbert.zig");
|
||||
const ModernBertOptions = modernbert_module.ModernBertOptions;
|
||||
|
||||
const params = clap.parseParamsComptime(
|
||||
\\--help print this help
|
||||
\\--model <PATH> model weights path
|
||||
\\--activations <PATH> model activations path
|
||||
);
|
||||
|
||||
fn printUsageAndExit(stderr: anytype) noreturn {
|
||||
stderr.print("usage: ", .{}) catch {};
|
||||
clap.usage(stderr, clap.Help, ¶ms) catch {};
|
||||
stderr.print("\n", .{}) catch {};
|
||||
std.process.exit(0);
|
||||
}
|
||||
pub fn main() !void {
|
||||
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
|
||||
}
|
||||
|
||||
pub fn asyncMain() !void {
|
||||
// Short lived allocations
|
||||
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
|
||||
defer _ = gpa.deinit();
|
||||
|
||||
const allocator = gpa.allocator();
|
||||
const stderr = std.io.getStdErr().writer();
|
||||
|
||||
// Read CLI arguments
|
||||
const parsers = comptime .{
|
||||
.PATH = clap.parsers.string,
|
||||
};
|
||||
var diag: clap.Diagnostic = .{};
|
||||
var res = clap.parse(clap.Help, ¶ms, parsers, .{
|
||||
.diagnostic = &diag,
|
||||
.allocator = allocator,
|
||||
}) catch |err| {
|
||||
try diag.report(stderr, err);
|
||||
try printUsageAndExit(stderr);
|
||||
};
|
||||
defer res.deinit();
|
||||
|
||||
if (res.args.help != 0) {
|
||||
try clap.help(stderr, clap.Help, ¶ms, .{});
|
||||
return;
|
||||
}
|
||||
|
||||
const model_file = res.args.model orelse {
|
||||
stderr.print("Error: missing --model=...\n\n", .{}) catch {};
|
||||
printUsageAndExit(stderr);
|
||||
unreachable;
|
||||
};
|
||||
const activations_file = res.args.activations orelse {
|
||||
stderr.print("Error: missing --activations=...\n\n", .{}) catch {};
|
||||
printUsageAndExit(stderr);
|
||||
unreachable;
|
||||
};
|
||||
|
||||
// Initialize the ZML context
|
||||
var context = try zml.Context.init();
|
||||
defer context.deinit();
|
||||
|
||||
// Auto-select platform
|
||||
const compute_platform = context.autoPlatform(.{});
|
||||
log.info("Selected platform: {s}", .{@tagName(compute_platform.target)});
|
||||
|
||||
// Create a dedicated memory arena for model-related allocations (dedicated to model shapes and weights)
|
||||
var arena_state = std.heap.ArenaAllocator.init(allocator);
|
||||
defer arena_state.deinit();
|
||||
const model_arena = arena_state.allocator();
|
||||
|
||||
// Load the model weights file and parse its structure (shape)
|
||||
var weights_file = try zml.aio.detectFormatAndOpen(allocator, model_file);
|
||||
defer weights_file.deinit();
|
||||
log.info("Model contains {d} layers. Loaded from: {s}", .{ weights_file.buffers.count(), model_file });
|
||||
|
||||
// Load the activation data file
|
||||
const activations = try zml.aio.torch.open(model_arena, activations_file);
|
||||
defer activations.deinit();
|
||||
log.info("Found {} activations in {s}", .{ activations.buffers.count(), activations_file });
|
||||
|
||||
// Initialize model
|
||||
var model = try zml.aio.populateModel(
|
||||
modernbert_module.ModernBertForMaskedLM,
|
||||
model_arena,
|
||||
weights_file,
|
||||
);
|
||||
|
||||
const modernbert_base_options: modernbert_module.ModernBertOptions = .{
|
||||
.num_attention_heads = 12,
|
||||
.tie_word_embeddings = true,
|
||||
.pad_token = 50283,
|
||||
.local_attention = 128,
|
||||
};
|
||||
model.init(modernbert_base_options);
|
||||
|
||||
// Load model weights
|
||||
const model_weights = try zml.aio.loadModelBuffers(
|
||||
modernbert_module.ModernBertForMaskedLM,
|
||||
model,
|
||||
weights_file,
|
||||
model_arena,
|
||||
compute_platform,
|
||||
);
|
||||
|
||||
// Test implementation
|
||||
try testImplementation(compute_platform, model, model_weights, activations);
|
||||
}
|
||||
|
||||
fn testImplementation(
|
||||
compute_platform: zml.Platform,
|
||||
model: modernbert_module.ModernBertForMaskedLM,
|
||||
model_weights: zml.Bufferized(modernbert_module.ModernBertForMaskedLM),
|
||||
activations: zml.aio.BufferStore,
|
||||
) !void {
|
||||
try zml.testing.testLayer(
|
||||
compute_platform,
|
||||
activations,
|
||||
"model.model.embeddings.tok_embeddings",
|
||||
model.model.embeddings.tok_embeddings,
|
||||
model_weights.model.embeddings.tok_embeddings,
|
||||
1e-6,
|
||||
);
|
||||
|
||||
try zml.testing.testLayer(
|
||||
compute_platform,
|
||||
activations,
|
||||
"model.model.embeddings.norm",
|
||||
model.model.embeddings.norm,
|
||||
model_weights.model.embeddings.norm,
|
||||
1e-3,
|
||||
);
|
||||
|
||||
try zml.testing.testLayer(
|
||||
compute_platform,
|
||||
activations,
|
||||
"model.model.embeddings",
|
||||
model.model.embeddings,
|
||||
model_weights.model.embeddings,
|
||||
1e-3,
|
||||
);
|
||||
|
||||
try zml.testing.testLayer(
|
||||
compute_platform,
|
||||
activations,
|
||||
"model.model.final_norm",
|
||||
model.model.final_norm,
|
||||
model_weights.model.final_norm,
|
||||
1e-5,
|
||||
);
|
||||
|
||||
try zml.testing.testLayer(
|
||||
compute_platform,
|
||||
activations,
|
||||
"model.model.layers.2.mlp",
|
||||
model.model.layers[2].mlp,
|
||||
model_weights.model.layers[2].mlp,
|
||||
2e-3,
|
||||
);
|
||||
|
||||
try zml.testing.testLayer(
|
||||
compute_platform,
|
||||
activations,
|
||||
"model.model.layers.2.mlp_norm",
|
||||
model.model.layers[2].mlp_norm,
|
||||
model_weights.model.layers[2].mlp_norm,
|
||||
1e-4,
|
||||
);
|
||||
|
||||
try zml.testing.testLayer(
|
||||
compute_platform,
|
||||
activations,
|
||||
"model.model.layers.2.attn",
|
||||
model.model.layers[2].attn,
|
||||
model_weights.model.layers[2].attn,
|
||||
1e-6,
|
||||
);
|
||||
|
||||
try zml.testing.testLayer(
|
||||
compute_platform,
|
||||
activations,
|
||||
"model.model.layers.2",
|
||||
model.model.layers[2],
|
||||
model_weights.model.layers[2],
|
||||
2e-3,
|
||||
);
|
||||
|
||||
try zml.testing.testLayer(
|
||||
compute_platform,
|
||||
activations,
|
||||
"model.model.layers.3.attn",
|
||||
model.model.layers[3].attn,
|
||||
model_weights.model.layers[3].attn,
|
||||
1e-5,
|
||||
);
|
||||
|
||||
try zml.testing.testLayer(
|
||||
compute_platform,
|
||||
activations,
|
||||
"model.model",
|
||||
model.model,
|
||||
model_weights.model,
|
||||
1e-2,
|
||||
);
|
||||
|
||||
const TiedDecoder = struct {
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
|
||||
pub fn forward(self: @This(), head_outputs: Tensor) Tensor {
|
||||
const results = head_outputs.withTags(.{ .b, .s, .d }).dot(self.weight.withTags(.{ .voc, .d }), .{.d});
|
||||
return results.add(self.bias.withTags(.{.voc}).broad(results.shape()));
|
||||
}
|
||||
};
|
||||
|
||||
try zml.testing.testLayer(
|
||||
compute_platform,
|
||||
activations,
|
||||
"model.decoder",
|
||||
TiedDecoder{ .weight = model.decoder.weight orelse model.model.embeddings.tok_embeddings.weight, .bias = model.decoder.bias },
|
||||
.{ .weight = model_weights.model.embeddings.tok_embeddings.weight, .bias = model_weights.decoder.bias },
|
||||
1e-3,
|
||||
);
|
||||
|
||||
try zml.testing.testLayer(
|
||||
compute_platform,
|
||||
activations,
|
||||
"model.head",
|
||||
model.head,
|
||||
model_weights.head,
|
||||
0.1, // TODO: too high tolerance
|
||||
);
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user