Add Bazel build rule and test for Llama3 tokenizer’s byte fallback and unknown token handling.
This commit is contained in:
parent
5120fe00dc
commit
b643f7bc53
@ -1,8 +1,10 @@
|
|||||||
load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar")
|
load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar")
|
||||||
load("@aspect_bazel_lib//lib:transitions.bzl", "platform_transition_filegroup")
|
load("@aspect_bazel_lib//lib:transitions.bzl", "platform_transition_filegroup")
|
||||||
|
load("@bazel_skylib//rules:native_binary.bzl", "native_binary")
|
||||||
load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load", "oci_push")
|
load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load", "oci_push")
|
||||||
load("@zml//bazel:zig.bzl", "zig_cc_binary")
|
load("@zml//bazel:zig.bzl", "zig_cc_binary")
|
||||||
|
|
||||||
|
|
||||||
zig_cc_binary(
|
zig_cc_binary(
|
||||||
name = "llama",
|
name = "llama",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -164,6 +166,22 @@ zig_cc_binary(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Binary that encodes a reference prompt with the Llama-3 tokenizer
# (see test_tokenizer.zig).
zig_cc_binary(
    name = "test_tokenizer",
    main = "test_tokenizer.zig",
    # Every Llama-3.x release ships the same tokenizer. The 3.2-1B repository
    # is used here because downloading the tokenizer also triggers
    # downloading the model.
    args = ["--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer)"],
    # Must list the tokenizer so that $(location ...) above can be expanded.
    data = ["@Meta-Llama-3.2-1B-Instruct//:tokenizer"],
    deps = [
        "//third_party/tigerbeetle:flags",
        "@zml//stdx",
        "@zml//zml",
    ],
)
|
||||||
|
|
||||||
mtree_spec(
|
mtree_spec(
|
||||||
name = "mtree",
|
name = "mtree",
|
||||||
srcs = [":llama"],
|
srcs = [":llama"],
|
||||||
|
|||||||
51
examples/llama/test_tokenizer.zig
Normal file
51
examples/llama/test_tokenizer.zig
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
const log = std.log.scoped(.@"//llama:test_tokenizer");
|
||||||
|
|
||||||
|
const flags = @import("tigerbeetle/flags");
|
||||||
|
const zml = @import("zml");
|
||||||
|
|
||||||
|
/// Command-line options of the tokenizer check, filled in by `flags.parse`.
const Flags = struct {
    /// Path of the tokenizer file to load (required).
    tokenizer: []const u8,
    /// Text that gets encoded.
    prompt: []const u8 = default_prompt,
    /// Comma-separated token ids the encoded prompt is compared against.
    /// An empty string disables the comparison.
    expected: []const u8 = default_expected,
    /// Forwarded to the tokenizer as its `debug` encode option.
    verbose: bool = false,

    /// Reference prompt mixing plain ASCII titles with emoji-prefixed ones.
    /// The final empty line makes the text end with a newline.
    const default_prompt: []const u8 =
        \\Examples of titles:
        \\📉 Stock Market Trends
        \\🍪 Perfect Chocolate Chip Recipe
        \\Evolution of Music Streaming
        \\Remote Work Productivity Tips
        \\Artificial Intelligence in Healthcare
        \\🎮 Video Game Development Insights
        \\
    ;

    /// Token ids expected for `default_prompt`.
    const default_expected: []const u8 = "128000,41481,315,15671,512,9468,241,231,12937,8152,50730,198,9468,235,103,24118,39520,32013,26371,198,35212,3294,315,10948,45910,198,25732,5664,5761,1968,26788,198,9470,16895,22107,304,39435,198,9468,236,106,8519,4140,11050,73137,198";
};
|
||||||
|
|
||||||
|
/// Entry point of the tokenizer regression check.
///
/// Loads the tokenizer given by `--tokenizer`, encodes `--prompt`, logs the
/// resulting token ids and, unless `--expected` is empty, compares them with
/// the comma-separated ids it contains.
///
/// Errors:
/// - any error from loading the tokenizer, encoding, allocating or parsing
///   an id in `--expected` is propagated;
/// - `error.TokenMismatch` when the encoded ids differ from `--expected`, so
///   the process reports failure instead of exiting successfully.
pub fn main() !void {
    var gpa: std.heap.GeneralPurposeAllocator(.{}) = .{};
    const allocator = gpa.allocator();

    var raw_args = std.process.args();
    const args = flags.parse(&raw_args, Flags);

    log.info("\tLoading tokenizer from {s}", .{args.tokenizer});
    var tokenizer = try zml.aio.detectFormatAndLoadTokenizer(allocator, args.tokenizer);
    // Release is registered right next to the acquisition.
    defer tokenizer.deinit();
    log.info("✅\tLoaded tokenizer from {s}", .{args.tokenizer});

    const prompt_tok = try tokenizer.encode(allocator, args.prompt, .{ .debug = args.verbose });

    log.info("Input: {s}\nOutput: {d}", .{ args.prompt, prompt_tok });

    // An empty `--expected` means "only print the tokens".
    if (args.expected.len == 0) return;

    // The prompt's byte length is only used as an initial capacity; the list
    // still grows on demand in `append`.
    var expected = try std.ArrayList(u32).initCapacity(allocator, args.prompt.len);
    defer expected.deinit();

    var it = std.mem.splitSequence(u8, args.expected, ",");
    while (it.next()) |int_token| {
        const tok = try std.fmt.parseInt(u32, int_token, 10);
        try expected.append(tok);
    }

    if (!std.mem.eql(u32, expected.items, prompt_tok)) {
        log.err("Doesn't match expected: {d}", .{expected.items});
        // Logging alone would leave the exit status at 0; fail explicitly.
        return error.TokenMismatch;
    }

    log.info("All good !", .{});
}
|
||||||
Loading…
Reference in New Issue
Block a user