Add Bazel build rule and test for Llama3 tokenizer’s byte fallback and unknown token handling.
This commit is contained in:
parent
5120fe00dc
commit
b643f7bc53
@ -1,8 +1,10 @@
|
||||
load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar")
|
||||
load("@aspect_bazel_lib//lib:transitions.bzl", "platform_transition_filegroup")
|
||||
load("@bazel_skylib//rules:native_binary.bzl", "native_binary")
|
||||
load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load", "oci_push")
|
||||
load("@zml//bazel:zig.bzl", "zig_cc_binary")
|
||||
|
||||
|
||||
zig_cc_binary(
|
||||
name = "llama",
|
||||
srcs = [
|
||||
@ -164,6 +166,22 @@ zig_cc_binary(
|
||||
],
|
||||
)
|
||||
|
||||
# Round-trip check for the Llama-3 tokenizer: encodes a fixed prompt and
# compares the produced token ids against a known-good sequence.
zig_cc_binary(
    name = "test_tokenizer",
    main = "test_tokenizer.zig",
    # Note: all Llama-3.x tokenizers are the same,
    # but using the 3.2-1B version because downloading the tokenizer triggers downloading the model.
    args = [
        "--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer)",
    ],
    data = ["@Meta-Llama-3.2-1B-Instruct//:tokenizer"],
    deps = [
        "//third_party/tigerbeetle:flags",
        "@zml//stdx",
        "@zml//zml",
    ],
)
|
||||
|
||||
mtree_spec(
|
||||
name = "mtree",
|
||||
srcs = [":llama"],
|
||||
New file: examples/llama/test_tokenizer.zig (+51 lines)
|
||||
const std = @import("std");
|
||||
const log = std.log.scoped(.@"//llama:test_tokenizer");
|
||||
|
||||
const flags = @import("tigerbeetle/flags");
|
||||
const zml = @import("zml");
|
||||
|
||||
/// Command-line flags for the tokenizer check, parsed by tigerbeetle's
/// `flags.parse` in `main`. Each field name becomes a `--name=...` flag;
/// fields with defaults are optional on the command line.
const Flags = struct {
    /// Path to the tokenizer file to load (required — no default).
    tokenizer: []const u8,
    /// Text to encode. The default mixes plain ASCII with multi-byte
    /// emoji to exercise the tokenizer's byte-fallback handling.
    prompt: []const u8 =
        \\Examples of titles:
        \\📉 Stock Market Trends
        \\🍪 Perfect Chocolate Chip Recipe
        \\Evolution of Music Streaming
        \\Remote Work Productivity Tips
        \\Artificial Intelligence in Healthcare
        \\🎮 Video Game Development Insights
        \\
    ,
    /// Comma-separated token ids that encoding `prompt` must produce.
    /// The default corresponds to the default prompt only — presumably
    /// generated with the reference Llama-3 tokenizer (verify if either
    /// default changes). Empty string disables the comparison in `main`.
    expected: []const u8 = "128000,41481,315,15671,512,9468,241,231,12937,8152,50730,198,9468,235,103,24118,39520,32013,26371,198,35212,3294,315,10948,45910,198,25732,5664,5761,1968,26788,198,9470,16895,22107,304,39435,198,9468,236,106,8519,4140,11050,73137,198",
    /// When true, asks the tokenizer to emit per-token debug output.
    verbose: bool = false,
};
|
||||
|
||||
/// Loads the tokenizer given by `--tokenizer`, encodes `--prompt`, and — when
/// `--expected` is non-empty — compares the produced token ids against it.
/// Returns an error (non-zero exit code) on mismatch so Bazel reports failure.
pub fn main() !void {
    var gpa: std.heap.GeneralPurposeAllocator(.{}) = .{};
    // Report any leaked allocations at exit; GPA logs them on deinit.
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    var raw_args = std.process.args();
    const args = flags.parse(&raw_args, Flags);

    log.info("\tLoading tokenizer from {s}", .{args.tokenizer});
    var tokenizer = try zml.aio.detectFormatAndLoadTokenizer(allocator, args.tokenizer);
    log.info("✅\tLoaded tokenizer from {s}", .{args.tokenizer});
    defer tokenizer.deinit();

    const prompt_tok = try tokenizer.encode(allocator, args.prompt, .{ .debug = args.verbose });
    // NOTE(review): assumes `encode` returns an allocator-owned slice of
    // token ids — confirm against zml's tokenizer API.
    defer allocator.free(prompt_tok);

    log.info("Input: {s}\nOutput: {d}", .{ args.prompt, prompt_tok });
    if (args.expected.len > 0) {
        // `args.prompt.len` is only a capacity hint; the list grows as needed.
        var expected = try std.ArrayList(u32).initCapacity(allocator, args.prompt.len);
        defer expected.deinit();
        // Single-byte separator: splitScalar is the idiomatic choice here.
        var it = std.mem.splitScalar(u8, args.expected, ',');
        while (it.next()) |int_token| {
            const tok = try std.fmt.parseInt(u32, int_token, 10);
            try expected.append(tok);
        }
        if (std.mem.eql(u32, expected.items, prompt_tok)) {
            log.info("All good !", .{});
        } else {
            log.err("Doesn't match expected: {d}", .{expected.items});
            // Bug fix: previously we only logged and exited 0, so the test
            // could never fail. Propagate an error for a non-zero exit code.
            return error.TokenizerOutputMismatch;
        }
    }
}
|
||||
Loading…
Reference in New Issue
Block a user