diff --git a/examples/llama/BUILD.bazel b/examples/llama/BUILD.bazel
index 8444e78..13e5e3a 100644
--- a/examples/llama/BUILD.bazel
+++ b/examples/llama/BUILD.bazel
@@ -1,8 +1,10 @@
 load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar")
 load("@aspect_bazel_lib//lib:transitions.bzl", "platform_transition_filegroup")
+load("@bazel_skylib//rules:native_binary.bzl", "native_binary")
 load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load", "oci_push")
 load("@zml//bazel:zig.bzl", "zig_cc_binary")
+
 
 zig_cc_binary(
     name = "llama",
     srcs = [
@@ -164,6 +166,22 @@ zig_cc_binary(
     ],
 )
 
+zig_cc_binary(
+    name = "test_tokenizer",
+    main = "test_tokenizer.zig",
+    deps = [
+        "//third_party/tigerbeetle:flags",
+        "@zml//stdx",
+        "@zml//zml",
+    ],
+    # Note: all Llama-3.x tokenizers are the same,
+    # but using the 3.2-1B version because downloading the tokenizer triggers downloading the model.
+    args = [
+        "--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer)",
+    ],
+    data = ["@Meta-Llama-3.2-1B-Instruct//:tokenizer"],
+)
+
 mtree_spec(
     name = "mtree",
     srcs = [":llama"],
diff --git a/examples/llama/test_tokenizer.zig b/examples/llama/test_tokenizer.zig
new file mode 100644
index 0000000..304a6cc
--- /dev/null
+++ b/examples/llama/test_tokenizer.zig
@@ -0,0 +1,52 @@
+const std = @import("std");
+const log = std.log.scoped(.@"//llama:test_tokenizer");
+
+const flags = @import("tigerbeetle/flags");
+const zml = @import("zml");
+
+// CLI flags: path to the tokenizer file, the prompt to encode,
+// and the comma-separated reference token ids to compare against.
+const Flags = struct {
+    tokenizer: []const u8,
+    prompt: []const u8 =
+        \\Examples of titles:
+        \\📉 Stock Market Trends
+        \\🍪 Perfect Chocolate Chip Recipe
+        \\Evolution of Music Streaming
+        \\Remote Work Productivity Tips
+        \\Artificial Intelligence in Healthcare
+        \\🎮 Video Game Development Insights
+        \\
+    ,
+    expected: []const u8 = "128000,41481,315,15671,512,9468,241,231,12937,8152,50730,198,9468,235,103,24118,39520,32013,26371,198,35212,3294,315,10948,45910,198,25732,5664,5761,1968,26788,198,9470,16895,22107,304,39435,198,9468,236,106,8519,4140,11050,73137,198",
+    verbose: bool = false,
+};
+
+pub fn main() !void {
+    var gpa: std.heap.GeneralPurposeAllocator(.{}) = .{};
+    const allocator = gpa.allocator();
+
+    var raw_args = std.process.args();
+    const args = flags.parse(&raw_args, Flags);
+
+    log.info("\tLoading tokenizer from {s}", .{args.tokenizer});
+    var tokenizer = try zml.aio.detectFormatAndLoadTokenizer(allocator, args.tokenizer);
+    log.info("✅\tLoaded tokenizer from {s}", .{args.tokenizer});
+    defer tokenizer.deinit();
+
+    const prompt_tok = try tokenizer.encode(allocator, args.prompt, .{ .debug = args.verbose });
+
+    log.info("Input: {s}\nOutput: {d}", .{ args.prompt, prompt_tok });
+    if (args.expected.len > 0) {
+        var expected = try std.ArrayList(u32).initCapacity(allocator, args.prompt.len);
+        var it = std.mem.splitSequence(u8, args.expected, ",");
+        while (it.next()) |int_token| {
+            const tok = try std.fmt.parseInt(u32, int_token, 10);
+            try expected.append(tok);
+        }
+        if (std.mem.eql(u32, expected.items, prompt_tok)) {
+            log.info("All good !", .{});
+        } else {
+            log.err("Doesn't match expected: {d}", .{expected.items});
+            // Exit non-zero so bazel/CI actually reports the mismatch as a failure.
+            return error.TokenizerMismatch;
+        }
+    }
+}