From b643f7bc539841e68359c7b7c73b296037800e20 Mon Sep 17 00:00:00 2001
From: Foke Singh
Date: Fri, 2 Feb 2024 10:25:48 +0000
Subject: [PATCH] =?UTF-8?q?Add=20Bazel=20build=20rule=20and=20test=20for?=
 =?UTF-8?q?=20Llama3=20tokenizer=E2=80=99s=20byte=20fallback=20and=20unkno?=
 =?UTF-8?q?wn=20token=20handling.?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 examples/llama/BUILD.bazel        | 18 ++++++++++
 examples/llama/test_tokenizer.zig | 55 +++++++++++++++++++++++++++++++
 2 files changed, 73 insertions(+)
 create mode 100644 examples/llama/test_tokenizer.zig

diff --git a/examples/llama/BUILD.bazel b/examples/llama/BUILD.bazel
index 8444e78..13e5e3a 100644
--- a/examples/llama/BUILD.bazel
+++ b/examples/llama/BUILD.bazel
@@ -1,8 +1,10 @@
 load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar")
 load("@aspect_bazel_lib//lib:transitions.bzl", "platform_transition_filegroup")
+load("@bazel_skylib//rules:native_binary.bzl", "native_binary")
 load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load", "oci_push")
 load("@zml//bazel:zig.bzl", "zig_cc_binary")
+
 zig_cc_binary(
     name = "llama",
     srcs = [
@@ -164,6 +166,22 @@ zig_cc_binary(
     ],
 )
 
+zig_cc_binary(
+    name = "test_tokenizer",
+    main = "test_tokenizer.zig",
+    deps = [
+        "//third_party/tigerbeetle:flags",
+        "@zml//stdx",
+        "@zml//zml",
+    ],
+    # Note: all Llama-3.x tokenizers are the same,
+    # but using the 3.2-1B version because downloading the tokenizer triggers downloading the model.
+    args = [
+        "--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer)",
+    ],
+    data = ["@Meta-Llama-3.2-1B-Instruct//:tokenizer"],
+)
+
 mtree_spec(
     name = "mtree",
     srcs = [":llama"],
diff --git a/examples/llama/test_tokenizer.zig b/examples/llama/test_tokenizer.zig
new file mode 100644
index 0000000..304a6cc
--- /dev/null
+++ b/examples/llama/test_tokenizer.zig
@@ -0,0 +1,55 @@
+//! Tokenizer regression test: encodes a fixed prompt with the Llama-3 tokenizer
+//! and compares the token ids against `--expected` (comma-separated u32 list).
+
+const std = @import("std");
+const log = std.log.scoped(.@"//llama:test_tokenizer");
+
+const flags = @import("tigerbeetle/flags");
+const zml = @import("zml");
+
+const Flags = struct {
+    tokenizer: []const u8,
+    prompt: []const u8 =
+        \\Examples of titles:
+        \\📉 Stock Market Trends
+        \\🍪 Perfect Chocolate Chip Recipe
+        \\Evolution of Music Streaming
+        \\Remote Work Productivity Tips
+        \\Artificial Intelligence in Healthcare
+        \\🎮 Video Game Development Insights
+        \\
+    ,
+    expected: []const u8 = "128000,41481,315,15671,512,9468,241,231,12937,8152,50730,198,9468,235,103,24118,39520,32013,26371,198,35212,3294,315,10948,45910,198,25732,5664,5761,1968,26788,198,9470,16895,22107,304,39435,198,9468,236,106,8519,4140,11050,73137,198",
+    verbose: bool = false,
+};
+
+pub fn main() !void {
+    var gpa: std.heap.GeneralPurposeAllocator(.{}) = .{};
+    const allocator = gpa.allocator();
+
+    var raw_args = std.process.args();
+    const args = flags.parse(&raw_args, Flags);
+
+    log.info("\tLoading tokenizer from {s}", .{args.tokenizer});
+    var tokenizer = try zml.aio.detectFormatAndLoadTokenizer(allocator, args.tokenizer);
+    log.info("✅\tLoaded tokenizer from {s}", .{args.tokenizer});
+    defer tokenizer.deinit();
+
+    const prompt_tok = try tokenizer.encode(allocator, args.prompt, .{ .debug = args.verbose });
+
+    log.info("Input: {s}\nOutput: {d}", .{ args.prompt, prompt_tok });
+    if (args.expected.len > 0) {
+        var expected = try std.ArrayList(u32).initCapacity(allocator, args.prompt.len);
+        var it = std.mem.splitSequence(u8, args.expected, ",");
+        while (it.next()) |int_token| {
+            const tok = try std.fmt.parseInt(u32, int_token, 10);
+            try expected.append(tok);
+        }
+        if (std.mem.eql(u32, expected.items, prompt_tok)) {
+            log.info("All good !", .{});
+        } else {
+            log.err("Doesn't match expected: {d}", .{expected.items});
+            return error.TokenizerMismatch; // exit nonzero so Bazel actually reports the test as failed
+        }
+    }
+}