Add Zig example programs for LLaMA, ModernBERT, and SimpleLayer, including a Bazel BUILD file for the LLaMA example.

Foke Singh 2025-07-29 16:07:11 +00:00
parent 488a844a0f
commit 0ed7f5c907
5 changed files with 54 additions and 48 deletions

LLaMA example, Bazel BUILD file:

@@ -10,6 +10,7 @@ zig_binary(
         "llama.zig",
     ],
     main = "main.zig",
+    visibility = ["//visibility:public"],
     deps = [
         "//async",
         "//stdx",

LLaMA example, llama.zig:

@@ -87,7 +87,7 @@ pub const LlamaLM = struct {
         kv_cache: KvCache,
         rng: Tensor.Rng,
     ) struct { Tensor, KvCache, Tensor.Rng } {
-        stdx.debug.assert(tokens_.dtype() == .u32 and tokens_.rank() >= 1 and token_index.dtype() == .u32 and token_index.rank() <= 1, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {} and {}", .{ tokens_, token_index });
+        stdx.debug.assert(tokens_.dtype() == .u32 and tokens_.rank() >= 1 and token_index.dtype() == .u32 and token_index.rank() <= 1, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {f} and {f}", .{ tokens_, token_index });
         const tokens = tokens_.withPartialTags(.{.s});
         const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, token_index, kv_cache });
         const new_tokens, const new_rng = self.sampleTokens(self.lm_head, out, rng, self.gen_opts);

@@ -200,8 +200,8 @@ pub const TransformerLayer = struct {
         kv_cache: KvCache,
     ) struct { Tensor, KvCache } {
         // Self Attention
-        //log.debug("TransformerLayer({}) -> {}", .{ x0, self.input_layernorm.forward(x0) });
-        stdx.debug.assert(x0.rank() >= 2 and x0.shape().hasTags(.{ .s, .d }), "TransformerLayer expected input shape: {{..., .s, .d}}, received: {}", .{x0});
+        //log.debug("TransformerLayer({f}) -> {f}", .{ x0, self.input_layernorm.forward(x0) });
+        stdx.debug.assert(x0.rank() >= 2 and x0.shape().hasTags(.{ .s, .d }), "TransformerLayer expected input shape: {{..., .s, .d}}, received: {f}", .{x0});
         const x0_normalized = zml.call(self.input_layernorm, .forward, .{x0});
         const delta0, const updated_kv_cache = zml.call(self.self_attn, .forward, .{ x0_normalized, token_index, kv_cache });

@@ -336,7 +336,7 @@ pub const KvCache = struct {
         return .{
             .k = try zml.Buffer.constant(platform, kv_shape, 1),
             .v = try zml.Buffer.constant(platform, kv_shape, 1),
-            .layer_index = try zml.Buffer.constant(platform, zml.Shape.init(.{}, .u32), 0),
+            .layer_index = try zml.Buffer.scalar(platform, 0, .u32),
         };
     }
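
The `{}` to `{f}` edits above follow Zig 0.15's formatted-printing rework: a value that prints itself through a `format` method now needs the explicit `{f}` specifier. A minimal sketch of the new convention, assuming a Zig 0.15 toolchain (the `Point` type is illustrative, not part of this repo):

const std = @import("std");

const Point = struct {
    x: i32,
    y: i32,

    // Zig 0.15: `format` takes a *std.Io.Writer rather than an anytype writer.
    pub fn format(self: Point, writer: *std.Io.Writer) std.Io.Writer.Error!void {
        try writer.print("({d}, {d})", .{ self.x, self.y });
    }
};

pub fn main() void {
    // `{f}` explicitly selects the format method; a bare `{}` no longer does.
    std.debug.print("{f}\n", .{Point{ .x = 1, .y = 2 }});
}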

LLaMA example, main.zig:

@@ -1,19 +1,19 @@
+const std = @import("std");
 const asynk = @import("async");
 const clap = @import("clap");
-const std = @import("std");
 const stdx = @import("stdx");
 const zml = @import("zml");
+const Buffer = zml.Buffer;
+const Tensor = zml.Tensor;
+const ShapeOf = zml.ShapeOf;
 const llama = @import("llama.zig");
 const LlamaLM = llama.LlamaLM;
 const Llama = llama.Llama;
 const KvCache = llama.KvCache;
 const TransformerLayer = llama.TransformerLayer;
 const SelfAttn = llama.SelfAttn;
-const Buffer = zml.Buffer;
-const Tensor = zml.Tensor;
-const ShapeOf = zml.ShapeOf;

 const log = std.log.scoped(.llama);

@@ -23,7 +23,7 @@ pub const std_options: std.Options = .{
 };

 pub fn tokenizePrompt(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, config: LlamaLM.Config, prompt: []const u8, skip_llama3_encoding: bool) ![]u32 {
-    var tokens = std.ArrayList(u32).init(allocator);
+    var tokens = std.array_list.Managed(u32).init(allocator);
     var encoder = try tokenizer.encoder();
     defer encoder.deinit();
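
The `std.ArrayList` renames here and below come from Zig 0.15 making the unmanaged list the default: `std.ArrayList(T)` no longer stores an allocator, and the old managed behavior moved to `std.array_list.Managed(T)`. A short illustrative sketch of the two styles:

const std = @import("std");

pub fn main() !void {
    const allocator = std.heap.page_allocator;

    // Old managed behavior, new name: the list remembers its allocator.
    var tokens = std.array_list.Managed(u32).init(allocator);
    defer tokens.deinit();
    try tokens.append(42);

    // New default: std.ArrayList is unmanaged, allocator passed per call.
    var ids: std.ArrayList(u32) = .empty;
    defer ids.deinit(allocator);
    try ids.append(allocator, 42);
}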
@@ -101,7 +101,7 @@ pub fn generateText(
     defer current_token.deinit();

     // Here we collect the generated text
-    var output = std.ArrayList(u8).init(allocator);
+    var output = std.array_list.Managed(u8).init(allocator);
     defer output.deinit();

     const output_tokens_len = max_seq_len - prompt_tok.len - 1;

@@ -179,21 +179,22 @@ pub fn asyncMain() !void {
         .PATH = clap.parsers.string,
     };
     var diag: clap.Diagnostic = .{};
-    const stderr = std.io.getStdErr().writer();
+    var stderr_buffer: [1024]u8 = undefined;
+    var stderr = std.fs.File.stderr().writer(&stderr_buffer);
     var res = clap.parse(clap.Help, &params, parsers, .{
         .diagnostic = &diag,
         .allocator = allocator,
     }) catch |err| {
-        diag.report(stderr, err) catch {};
-        stderr.print("usage: ", .{}) catch {};
-        clap.usage(stderr, clap.Help, &params) catch {};
-        stderr.print("\n", .{}) catch {};
+        diag.report(&stderr.interface, err) catch {};
+        stderr.interface.print("usage: ", .{}) catch {};
+        clap.usage(&stderr.interface, clap.Help, &params) catch {};
+        stderr.interface.print("\n", .{}) catch {};
         return;
     };
     defer res.deinit();

     if (res.args.help != 0) {
-        clap.help(std.io.getStdErr().writer(), clap.Help, &params, .{}) catch {};
+        clap.help(&stderr.interface, clap.Help, &params, .{}) catch {};
         return;
     }
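
The stderr rework reflects Zig 0.15's writer API: `std.io.getStdErr().writer()` is gone, and a `std.fs.File` writer now takes an explicit buffer, exposing the generic `std.Io.Writer` through its `interface` field. A standalone sketch of the pattern:

const std = @import("std");

pub fn main() void {
    // File writers are explicitly buffered in Zig 0.15.
    var buf: [1024]u8 = undefined;
    var stderr = std.fs.File.stderr().writer(&buf);
    const w = &stderr.interface; // the generic *std.Io.Writer

    w.print("usage: my-tool [--flag]\n", .{}) catch {};
    w.flush() catch {}; // buffered: nothing is written out until flushed
}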
@@ -224,7 +225,9 @@ pub fn asyncMain() !void {
     const config = blk: {
         var config_json_file = try asynk.File.open(model_config_path, .{ .mode = .read_only });
         defer config_json_file.close() catch unreachable;
-        var reader = std.json.reader(allocator, config_json_file.reader());
+        var config_json_buffer: [256]u8 = undefined;
+        var config_reader = config_json_file.reader(&config_json_buffer);
+        var reader = std.json.Reader.init(allocator, &config_reader.interface);
         defer reader.deinit();
         const config_obj = try std.json.parseFromTokenSourceLeaky(llama.LlamaLM.Config, allocator, &reader, .{ .ignore_unknown_fields = true });
         break :blk config_obj;
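
JSON reading follows the same model: `std.json.reader` over an `anytype` reader is replaced by `std.json.Reader.init` over a `*std.Io.Reader`, obtained here from a buffered file reader. A self-contained sketch under those assumptions (the `Config` struct and `config.json` path are hypothetical):

const std = @import("std");

const Config = struct { dim: u32, vocab_size: u32 };

fn loadConfig(allocator: std.mem.Allocator) !Config {
    var file = try std.fs.cwd().openFile("config.json", .{});
    defer file.close();

    // Buffered file reader; its generic side lives in `.interface`.
    var read_buffer: [256]u8 = undefined;
    var file_reader = file.reader(&read_buffer);

    var json_reader = std.json.Reader.init(allocator, &file_reader.interface);
    defer json_reader.deinit();

    return std.json.parseFromTokenSourceLeaky(Config, allocator, &json_reader, .{
        .ignore_unknown_fields = true,
    });
}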
@@ -298,16 +301,16 @@ pub fn asyncMain() !void {
         platform,
     });

-    log.info("\tLoading Llama weights from {?s}...", .{model_weights_path});
+    log.info("\tLoading Llama weights from {s}...", .{model_weights_path});
     var llama_weights = try zml.aio.loadBuffers(llama.LlamaLM, .{ config, llama_options }, ts, model_arena.allocator(), platform);
     defer zml.aio.unloadBuffers(&llama_weights);
-    log.info("\tLoaded weights in {}", .{std.fmt.fmtDuration(start.read())});
+    log.info("\tLoaded weights in {D}", .{start.read()});

     var llama_module_prefill = (try fut_mod_prefill.awaitt()).prepare(llama_weights);
     defer llama_module_prefill.deinit();
     var llama_module = (try fut_mod.awaitt()).prepare(llama_weights);
     defer llama_module.deinit();
-    log.info("\tCompiled model in {}", .{std.fmt.fmtDuration(start.read())});
+    log.info("\tCompiled model in {D}", .{start.read()});

     log.info("Creating KvCache", .{});
     const kv_cache = try llama.KvCache.initBuffer(kv_shape, platform);

@@ -315,7 +318,7 @@ pub fn asyncMain() !void {
     var tokenizer = blk: {
         log.info("Loading tokenizer from {s}", .{model_tokenizer_path});
         var timer = try stdx.time.Timer.start();
-        defer log.info("Loaded tokenizer from {s} [{}]", .{ model_tokenizer_path, timer.read() });
+        defer log.info("Loaded tokenizer from {s} [{D}]", .{ model_tokenizer_path, timer.read() });

         break :blk try zml.tokenizer.Tokenizer.fromFile(model_arena.allocator(), model_tokenizer_path);
     };
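
The logging changes swap `std.fmt.fmtDuration(...)` printed with `{}` for Zig 0.15's `{D}` specifier, which renders a nanosecond count as a human-readable duration directly. For example:

const std = @import("std");

pub fn main() !void {
    var timer = try std.time.Timer.start();
    std.Thread.sleep(25 * std.time.ns_per_ms);
    // `{D}` pretty-prints nanoseconds, e.g. "25.1ms", replacing fmtDuration.
    std.debug.print("elapsed: {D}\n", .{timer.read()});
}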

ModernBERT example, main.zig:

@@ -1,7 +1,4 @@
 const std = @import("std");
-const log = std.log.scoped(.modernbert);
-const modernbert = @import("modernbert.zig");
 const asynk = @import("async");
 const clap = @import("clap");

@@ -9,6 +6,10 @@ const stdx = @import("stdx");
 const zml = @import("zml");
 const Tensor = zml.Tensor;
+const modernbert = @import("modernbert.zig");
+
+const log = std.log.scoped(.modernbert);

 pub const std_options: std.Options = .{
     .log_level = .info,
     .log_scope_levels = &[_]std.log.ScopeLevel{

@@ -42,20 +43,20 @@ pub fn main() !void {
 pub fn asyncMain() !void {
     const allocator = std.heap.c_allocator;

-    const stderr = std.io.getStdErr().writer();
+    const stderr = std.fs.File.stderr();
     var diag: clap.Diagnostic = .{};
     var cli = clap.parse(clap.Help, &params, clap_parsers, .{
         .diagnostic = &diag,
         .allocator = allocator,
     }) catch |err| {
-        try diag.report(stderr, err);
+        try diag.reportToFile(stderr, err);
         try printUsageAndExit(stderr);
     };
     defer cli.deinit();

     if (cli.args.help != 0) {
-        try clap.help(stderr, clap.Help, &params, .{});
+        try clap.helpToFile(stderr, clap.Help, &params, .{});
         return;
     }
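
Where the LLaMA example wraps stderr in a buffered writer, this file leans on zig-clap's `*ToFile` helpers, which accept a `std.fs.File` directly; that is why `stderr` becomes a plain `std.fs.File.stderr()`. A compact sketch of that error path, assuming a zig-clap release that ships these helpers (the one-flag parameter list is hypothetical):

const std = @import("std");
const clap = @import("clap");

const params = clap.parseParamsComptime(
    \\-h, --help  Display this help and exit.
    \\
);

pub fn main() !void {
    const stderr = std.fs.File.stderr();
    var diag: clap.Diagnostic = .{};
    var cli = clap.parse(clap.Help, &params, clap.parsers.default, .{
        .diagnostic = &diag,
        .allocator = std.heap.page_allocator,
    }) catch |err| {
        // The *ToFile variants write to a std.fs.File, no writer needed.
        diag.reportToFile(stderr, err) catch {};
        clap.usageToFile(stderr, clap.Help, &params) catch {};
        return err;
    };
    defer cli.deinit();

    if (cli.args.help != 0)
        try clap.helpToFile(stderr, clap.Help, &params, .{});
}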
@@ -80,7 +81,9 @@ pub fn asyncMain() !void {
     // Detects the format of the model file (base on filename) and open it.
     const model_file = cli.args.model orelse {
-        stderr.print("Error: missing --model=...\n\n", .{}) catch {};
+        var buf: [256]u8 = undefined;
+        var writer = stderr.writer(&buf);
+        writer.interface.print("Error: missing --model=...\n\n", .{}) catch {};
         printUsageAndExit(stderr);
         unreachable;
     };

@@ -96,7 +99,7 @@ pub fn asyncMain() !void {
     if (cli.args.tokenizer) |tok| {
         log.info("\tLoading tokenizer from {s}", .{tok});
         var timer = try stdx.time.Timer.start();
-        defer log.info("\tLoaded tokenizer from {s} [{}]", .{ tok, timer.read() });
+        defer log.info("\tLoaded tokenizer from {s} [{D}]", .{ tok, timer.read() });

         break :blk try zml.tokenizer.Tokenizer.fromFile(model_arena, tok);
     } else {

@@ -123,13 +126,13 @@ pub fn asyncMain() !void {
     const seq_len = @as(i64, @intCast(cli.args.@"seq-len" orelse 256));
     const input_shape = zml.Shape.init(.{ .b = 1, .s = seq_len }, .u32);

-    var start = try std.time.Timer.start();
+    var start = try stdx.time.Timer.start();

     // Load weights
-    log.info("\tLoading ModernBERT weights from {?s}...", .{model_file});
+    log.info("\tLoading ModernBERT weights from {s}...", .{model_file});
     var bert_weights = try zml.aio.loadBuffers(modernbert.ModernBertForMaskedLM, .{modernbert_options}, tensor_store, model_arena, platform);
     defer zml.aio.unloadBuffers(&bert_weights);
-    log.info("\tLoaded weights in {d}ms", .{start.read() / std.time.ns_per_ms});
+    log.info("\tLoaded weights in {D}", .{start.read()});

     // Compile the model
     log.info("\tCompiling ModernBERT model...", .{});

@@ -143,7 +146,7 @@ pub fn asyncMain() !void {
     });
     var bert_module = (try fut_mod.awaitt()).prepare(bert_weights);
     defer bert_module.deinit();
-    log.info("\tLoaded weights and compiled model in {d}ms", .{start.read() / std.time.ns_per_ms});
+    log.info("\tLoaded weights and compiled model in {D}", .{start.read()});

     const text = cli.args.text orelse "Paris is the [MASK] of France.";
     log.info("\tInput text: {s}", .{text});

@@ -213,7 +216,7 @@ pub fn unmask(
 }

 pub fn tokenize(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, prompt: []const u8) ![]const u32 {
-    var tokens = std.ArrayList(u32).init(allocator);
+    var tokens = std.array_list.Managed(u32).init(allocator);
     var encoder = try tokenizer.encoder();
     defer encoder.deinit();

@@ -228,7 +231,7 @@ pub fn tokenize(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer
 }

 fn findMaskPositions(allocator: std.mem.Allocator, tokens: []const u32, mask_token: u32) ![]usize {
-    var mask_positions = std.ArrayList(usize).init(allocator);
+    var mask_positions = std.array_list.Managed(usize).init(allocator);
     defer mask_positions.deinit();

     for (tokens, 0..) |token, i| {

@@ -267,9 +270,7 @@ fn bool_parser(in: []const u8) error{}!bool {
     return std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
 }

-fn printUsageAndExit(stderr: anytype) noreturn {
-    stderr.print("usage: ", .{}) catch {};
-    clap.usage(stderr, clap.Help, &params) catch {};
-    stderr.print("\n", .{}) catch {};
+fn printUsageAndExit(stderr: std.fs.File) noreturn {
+    clap.usageToFile(stderr, clap.Help, &params) catch {};
     std.process.exit(0);
 }

SimpleLayer example, main.zig:

@@ -1,6 +1,7 @@
 const std = @import("std");
+const zml = @import("zml");
 const asynk = @import("async");
-const zml = @import("zml");

 /// Model definition
 const Layer = struct {

@@ -38,9 +39,9 @@ pub fn asyncMain() !void {
     context.printAvailablePlatforms(platform);

     // Our weights and bias to use
-    var weights = [4]f16{ 2.0, 2.0, 2.0, 2.0 };
-    var bias = [4]f16{ 1.0, 2.0, 3.0, 4.0 };
-    const input_shape = zml.Shape.init(.{4}, .f16);
+    var weights = [4]f32{ 2.0, 2.0, 2.0, 2.0 };
+    var bias = [4]f32{ 1.0, 2.0, 3.0, 4.0 };
+    const input_shape = zml.Shape.init(.{4}, .f32);

     // We manually produce a BufferStore. You would not normally do that.
     // A BufferStore is usually created by loading model data from a file.

@@ -80,8 +81,8 @@ pub fn asyncMain() !void {
     // prepare an input buffer
     // Here, we use zml.HostBuffer.fromSlice to show how you would create a HostBuffer
     // with a specific shape from an array.
-    // For situations where e.g. you have an [4]f16 array but need a .{2, 2} input shape.
-    var input = [4]f16{ 5.0, 5.0, 5.0, 5.0 };
+    // For situations where e.g. you have an [4]f32 array but need a .{2, 2} input shape.
+    var input = [4]f32{ 5.0, 5.0, 5.0, 5.0 };
     var input_buffer = try zml.Buffer.from(platform, zml.HostBuffer.fromSlice(input_shape, &input), .{});
     defer input_buffer.deinit();

@@ -92,7 +93,7 @@ pub fn asyncMain() !void {
     // fetch the result to CPU memory
     const cpu_result = try result.toHostAlloc(arena);
     std.debug.print(
-        "\nThe result of {d} * {d} + {d} = {d}\n",
-        .{ &weights, &input, &bias, cpu_result.items(f16) },
+        "\nThe result of {any} * {any} + {any} = {any}\n",
+        .{ &weights, &input, &bias, cpu_result.items(f32) },
     );
 }
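
The last hunk swaps `{d}` for `{any}` because Zig 0.15 no longer applies numeric format specifiers to aggregates; `{any}` formats arrays and slices through reflection instead. For example:

const std = @import("std");

pub fn main() void {
    const weights = [4]f32{ 2.0, 2.0, 2.0, 2.0 };
    // `{d}` on an array is a compile error in Zig 0.15; use `{any}`.
    std.debug.print("{any}\n", .{&weights});
}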