Add Bazel build rule and test for Llama3 tokenizer’s byte fallback and unknown token handling.

This commit is contained in:
Foke Singh 2024-02-02 10:25:48 +00:00
parent 5120fe00dc
commit b643f7bc53
2 changed files with 69 additions and 0 deletions

View File

@ -1,8 +1,10 @@
load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar")
load("@aspect_bazel_lib//lib:transitions.bzl", "platform_transition_filegroup")
load("@bazel_skylib//rules:native_binary.bzl", "native_binary")
load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load", "oci_push")
load("@zml//bazel:zig.bzl", "zig_cc_binary")
zig_cc_binary(
name = "llama",
srcs = [
@ -164,6 +166,22 @@ zig_cc_binary(
],
)
zig_cc_binary(
name = "test_tokenizer",
main = "test_tokenizer.zig",
deps = [
"//third_party/tigerbeetle:flags",
"@zml//stdx",
"@zml//zml",
],
# Note: all Llama-3.x tokenizers are the same,
# but using the 3.2-1B version because downloading the tokenizer triggers downloading the model.
args = [
"--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer)",
],
data = ["@Meta-Llama-3.2-1B-Instruct//:tokenizer"],
)
mtree_spec(
name = "mtree",
srcs = [":llama"],

View File

@ -0,0 +1,51 @@
const std = @import("std");
const log = std.log.scoped(.@"//llama:test_tokenizer");
const flags = @import("tigerbeetle/flags");
const zml = @import("zml");
const Flags = struct {
tokenizer: []const u8,
prompt: []const u8 =
\\Examples of titles:
\\📉 Stock Market Trends
\\🍪 Perfect Chocolate Chip Recipe
\\Evolution of Music Streaming
\\Remote Work Productivity Tips
\\Artificial Intelligence in Healthcare
\\🎮 Video Game Development Insights
\\
,
expected: []const u8 = "128000,41481,315,15671,512,9468,241,231,12937,8152,50730,198,9468,235,103,24118,39520,32013,26371,198,35212,3294,315,10948,45910,198,25732,5664,5761,1968,26788,198,9470,16895,22107,304,39435,198,9468,236,106,8519,4140,11050,73137,198",
verbose: bool = false,
};
pub fn main() !void {
var gpa: std.heap.GeneralPurposeAllocator(.{}) = .{};
const allocator = gpa.allocator();
var raw_args = std.process.args();
const args = flags.parse(&raw_args, Flags);
log.info("\tLoading tokenizer from {s}", .{args.tokenizer});
var tokenizer = try zml.aio.detectFormatAndLoadTokenizer(allocator, args.tokenizer);
log.info("\tLoaded tokenizer from {s}", .{args.tokenizer});
defer tokenizer.deinit();
const prompt_tok = try tokenizer.encode(allocator, args.prompt, .{ .debug = args.verbose });
log.info("Input: {s}\nOutput: {d}", .{ args.prompt, prompt_tok });
if (args.expected.len > 0) {
var expected = try std.ArrayList(u32).initCapacity(allocator, args.prompt.len);
var it = std.mem.splitSequence(u8, args.expected, ",");
while (it.next()) |int_token| {
const tok = try std.fmt.parseInt(u32, int_token, 10);
try expected.append(tok);
}
if (std.mem.eql(u32, expected.items, prompt_tok)) {
log.info("All good !", .{});
} else {
log.err("Doesn't match expected: {d}", .{expected.items});
}
}
}