diff --git a/stdx/BUILD.bazel b/stdx/BUILD.bazel index 3bc8943..b170125 100644 --- a/stdx/BUILD.bazel +++ b/stdx/BUILD.bazel @@ -4,6 +4,7 @@ zig_library( name = "stdx", srcs = [ "debug.zig", + "flags.zig", "io.zig", "json.zig", "math.zig", diff --git a/stdx/flags.zig b/stdx/flags.zig new file mode 100644 index 0000000..c672398 --- /dev/null +++ b/stdx/flags.zig @@ -0,0 +1,582 @@ +//! From TigerBeetle, under Apache 2.0 attribution license. +//! https://github.com/tigerbeetle/tigerbeetle/blob/main/src/flags.zig TigerBeetle/ +//! +//! The purpose of `flags` is to define standard behavior for parsing CLI arguments and provide +//! a specific parsing library, implementing this behavior. +//! +//! These are TigerBeetle CLI guidelines: +//! +//! - The main principle is robustness --- make operator errors harder to make. +//! - For production usage, avoid defaults. +//! - Thoroughly validate options. +//! - In particular, check that no options are repeated. +//! - Use only long options (`--addresses`). +//! - Exception: `-h/--help` is allowed. +//! - Use `--key=value` syntax for an option with an argument. +//! Don't use `--key value`, as that can be ambiguous (e.g., `--key --verbose`). +//! - Use subcommand syntax when appropriate. +//! - Use positional arguments when appropriate. +//! +//! Design choices for this particular `flags` library: +//! +//! - Be an 80% solution. Parsing arguments is a surprisingly vast topic: auto-generated help, +//! bash completions, typo correction. Rather than providing a definitive solution, `flags` +//! is just one possible option. It is ok to re-implement arg parsing in a different way, as long +//! as the CLI guidelines are observed. +//! +//! - No auto-generated help. Zig doesn't expose doc comments through `@typeInfo`, so it's hard to +//! implement auto-help nicely. Additionally, a fully hand-crafted `--help` message can be of +//! higher quality. +//! +//! - Fatal errors. 
//!   It might be "cleaner" to use `try` to propagate the error to the caller, but
//!   during early CLI parsing, it is much simpler to terminate the process directly and save the
//!   caller the hassle of propagating errors. The `fatal` function is public, to allow the caller
//!   to run additional validation or parsing using the same error reporting mechanism.
//!
//! - Concise DSL. Most CLI parsing is done for ad-hoc tools like benchmarking, where the ability
//!   to quickly add a new argument is valuable. As this is an 80% solution, production code may
//!   use a more verbose approach if it gives better UX.
//!
//! - Caller manages ArgsIterator. ArgsIterator owns the backing memory of the args, so we let the
//!   caller manage its lifetime. `parse` discards the program name itself, so the iterator must
//!   be handed over without having been advanced.

const std = @import("std");
const builtin = @import("builtin");
const assert = std.debug.assert;

/// Report a fatal CLI error and terminate the process.
///
/// Writes `"error: "`, then `fmt_string` formatted with `args`, then a newline to stderr, and
/// exits with status 1. Never returns.
///
/// A failure to write to stderr is deliberately ignored: the process is about to exit and there
/// is no other channel left to report it on.
pub fn fatal(comptime fmt_string: []const u8, args: anytype) noreturn {
    const message_format = "error: " ++ fmt_string ++ "\n";
    std.io.getStdErr().writer().print(message_format, args) catch {};
    std.posix.exit(1);
}

/// Parse CLI arguments for subcommands specified as Zig `struct` or `union(enum)`:
///
/// ```
/// const CliArgs = union(enum) {
///     start: struct { addresses: []const u8, replica: u32 },
///     format: struct {
///         verbose: bool = false,
///         positional: struct {
///             path: []const u8,
///         },
///     },
///
///     pub const help =
///         \\ tigerbeetle start --addresses=<addresses> --replica=<replica>
///         \\ tigerbeetle format [--verbose] <path>
///     ;
/// };
///
/// const cli_args = parse(&args, CliArgs);
/// ```
///
/// `positional` field is treated specially, it designates positional arguments.
///
/// If a `pub const help` declaration is present on a `union(enum)` of commands, it is used to
/// implement the `-h/--help` argument.
pub fn parse(args: *std.process.ArgIterator, comptime CliArgs: type) CliArgs {
    // Discard the executable name. This call is required for correctness, so it is not wrapped in
    // `assert`: an empty argv would otherwise be undefined behaviour in ReleaseFast builds.
    if (!args.skip()) fatal("missing executable name in process arguments", .{});

    return switch (@typeInfo(CliArgs)) {
        .Union => parse_commands(args, CliArgs),
        .Struct => parse_flags(args, CliArgs),
        // `CliArgs` is comptime-known, so misuse is reported at compile time instead of reaching
        // a runtime `unreachable`.
        else => @compileError(
            "flags.parse: CliArgs must be a struct or a union(enum), found " ++
                @typeName(CliArgs),
        ),
    };
}

/// Parse CLI arguments for current process.
/// See `stdx.flags.parse` documentation for more.
pub fn parseProcessArgs(comptime CliArgs: type) CliArgs {
    var args = std.process.args();
    return parse(&args, CliArgs);
}

/// Parse a subcommand: the next argument selects a field of the `Commands` union, and the
/// remaining arguments are parsed as that field's flags.
///
/// `Commands` must be a union with at least two fields, none of whose names contain `_`.
/// If `Commands` declares `pub const help`, a first argument of `-h` or `--help` prints it to
/// stdout and exits with status 0 (status 1 if the write fails).
/// Terminates through `fatal` when the subcommand is missing or unknown.
fn parse_commands(args: *std.process.ArgIterator, comptime Commands: type) Commands {
    comptime assert(@typeInfo(Commands) == .Union);
    comptime assert(std.meta.fields(Commands).len >= 2);

    const first_arg = args.next() orelse fatal(
        "subcommand required, expected {s}",
        .{comptime fields_to_comma_list(Commands)},
    );

    // NB: help must be declared as *pub* const to be visible here.
    if (@hasDecl(Commands, "help")) {
        if (std.mem.eql(u8, first_arg, "-h") or std.mem.eql(u8, first_arg, "--help")) {
            std.io.getStdOut().writeAll(Commands.help) catch std.posix.exit(1);
            std.posix.exit(0);
        }
    }

    inline for (comptime std.meta.fields(Commands)) |field| {
        comptime assert(std.mem.indexOf(u8, field.name, "_") == null);
        if (std.mem.eql(u8, first_arg, field.name)) {
            return @unionInit(Commands, field.name, parse_flags(args, field.type));
        }
    }
    fatal("unknown subcommand: '{s}'", .{first_arg});
}

/// Parse all remaining arguments into a `Flags` struct (or `void`, which accepts no arguments).
///
/// Every field except `positional` is a `--flag`; underscores in the field name become dashes.
/// The fields of the optional `positional` struct are filled, in declaration order, from
/// arguments that match no flag. Flags must precede positional arguments.
///
/// Checked at compile time:
/// - `bool` flags default to `false`, optional flags and positionals default to `null`;
/// - positional fields with defaults come after all required ones;
/// - every value type is accepted by `assert_valid_value_type`.
///
/// Terminates through `fatal` on unknown, duplicated, malformed or missing arguments.
fn parse_flags(args: *std.process.ArgIterator, comptime Flags: type) Flags {
    @setEvalBranchQuota(5_000);

    if (Flags == void) {
        if (args.next()) |arg| {
            fatal("unexpected argument: '{s}'", .{arg});
        }
        return {};
    }

    assert(@typeInfo(Flags) == .Struct);

    // Named (non-positional) fields; only the first `field_count` entries are initialized.
    comptime var fields: [std.meta.fields(Flags).len]std.builtin.Type.StructField = undefined;
    comptime var field_count = 0;

    comptime var positional_fields: []const std.builtin.Type.StructField = &.{};

    comptime for (std.meta.fields(Flags)) |field| {
        if (std.mem.eql(u8, field.name, "positional")) {
            assert(@typeInfo(field.type) == .Struct);
            positional_fields = std.meta.fields(field.type);
            // The dispatch below switches over `0...positional_fields.len - 1`, which cannot be
            // formed for an empty struct; reject it here with a readable message.
            if (positional_fields.len == 0) {
                @compileError("flags: `positional` must declare at least one field");
            }
            var optional_tail = false;
            for (positional_fields) |positional_field| {
                if (default_value(positional_field) == null) {
                    if (optional_tail) @panic("optional positional arguments must be last");
                } else {
                    optional_tail = true;
                }
                switch (@typeInfo(positional_field.type)) {
                    .Optional => |optional| {
                        // optional flags should have a default
                        assert(default_value(positional_field) != null);
                        assert(default_value(positional_field).? == null);
                        assert_valid_value_type(optional.child);
                    },
                    else => {
                        assert_valid_value_type(positional_field.type);
                    },
                }
            }
        } else {
            fields[field_count] = field;
            field_count += 1;

            switch (@typeInfo(field.type)) {
                .Bool => {
                    // boolean flags should have a default
                    assert(default_value(field) != null);
                    assert(default_value(field).? == false);
                },
                .Optional => |optional| {
                    // optional flags should have a default
                    assert(default_value(field) != null);
                    assert(default_value(field).? == null);

                    assert_valid_value_type(optional.child);
                },
                else => {
                    assert_valid_value_type(field.type);
                },
            }
        }
    };

    var result: Flags = undefined;
    // One `u32` occurrence counter per field of `Flags` (including `positional`), all starting
    // at zero. Would use std.enums.EnumFieldStruct(Flags, u32, 0) here but Flags is a Struct not
    // an Enum.
    var counts = comptime blk: {
        var count_fields = std.meta.fields(Flags)[0..std.meta.fields(Flags).len].*;
        for (&count_fields) |*field| {
            field.type = u32;
            field.alignment = @alignOf(u32);
            field.default_value = @ptrCast(&@as(u32, 0));
        }
        break :blk @Type(.{ .Struct = .{
            .layout = .auto,
            .fields = &count_fields,
            .decls = &.{},
            .is_tuple = false,
        } }){};
    };

    // When parsing arguments, we must consider longer arguments first, such that `--foo-bar=92` is
    // not confused for a misspelled `--foo=92`. Using `std.sort` for comptime-only values does not
    // work, so open-code insertion sort, and comptime assert order during the actual parsing.
    comptime {
        for (fields[0..field_count], 0..) |*field_right, i| {
            for (fields[0..i]) |*field_left| {
                if (field_left.name.len < field_right.name.len) {
                    std.mem.swap(std.builtin.Type.StructField, field_left, field_right);
                }
            }
        }
    }

    var parsed_positional = false;
    next_arg: while (args.next()) |arg| {
        comptime var field_len_prev = std.math.maxInt(usize);
        inline for (fields[0..field_count]) |field| {
            const flag = comptime flag_name(field);

            // Verifies the longest-first ordering established above.
            comptime assert(field_len_prev >= field.name.len);
            field_len_prev = field.name.len;
            if (std.mem.startsWith(u8, arg, flag)) {
                if (parsed_positional) {
                    fatal("unexpected trailing option: '{s}'", .{arg});
                }

                @field(counts, field.name) += 1;
                const flag_value = parse_flag(field.type, flag, arg);
                @field(result, field.name) = flag_value;
                continue :next_arg;
            }
        }

        if (@hasField(Flags, "positional")) {
            counts.positional += 1;
            switch (counts.positional - 1) {
                inline 0...positional_fields.len - 1 => |positional_index| {
                    const positional_field = positional_fields[positional_index];
                    const flag = comptime flag_name_positional(positional_field);

                    if (arg.len == 0) fatal("{s}: empty argument", .{flag});
                    // Prevent ambiguity between a flag and positional argument value. We could add
                    // support for bare ` -- ` as a disambiguation mechanism once we have a real
                    // use-case.
                    if (arg[0] == '-') fatal("unexpected argument: '{s}'", .{arg});
                    parsed_positional = true;

                    @field(result.positional, positional_field.name) =
                        parse_value(positional_field.type, flag, arg);
                    continue :next_arg;
                },
                else => {}, // Fall-through to the unexpected argument error.
            }
        }

        fatal("unexpected argument: '{s}'", .{arg});
    }

    // Apply defaults to flags that were not given; reject missing required and repeated flags.
    inline for (fields[0..field_count]) |field| {
        const flag = flag_name(field);
        switch (@field(counts, field.name)) {
            0 => if (default_value(field)) |default| {
                @field(result, field.name) = default;
            } else {
                fatal("{s}: argument is required", .{flag});
            },
            1 => {},
            else => fatal("{s}: duplicate argument", .{flag}),
        }
    }

    // Positionals beyond the ones supplied must have a default.
    if (@hasField(Flags, "positional")) {
        assert(counts.positional <= positional_fields.len);
        inline for (positional_fields, 0..) |positional_field, positional_index| {
            if (positional_index >= counts.positional) {
                const flag = comptime flag_name_positional(positional_field);
                if (default_value(positional_field)) |default| {
                    @field(result.positional, positional_field.name) = default;
                } else {
                    fatal("{s}: argument is required", .{flag});
                }
            }
        }
    }

    return result;
}

/// Compile-time check that `T` can be parsed from a flag or positional value.
///
/// Accepted: `[]const u8`, `[:0]const u8`, `ByteSize`, any integer type, and exhaustive enums
/// with at least two fields. Anything else is a compile error naming the offending type.
fn assert_valid_value_type(comptime T: type) void {
    comptime {
        if (T == []const u8 or T == [:0]const u8 or T == ByteSize or @typeInfo(T) == .Int) return;

        if (@typeInfo(T) == .Enum) {
            const info = @typeInfo(T).Enum;
            assert(info.is_exhaustive);
            assert(info.fields.len >= 2);
            return;
        }

        @compileError("flags: unsupported value type: " ++ @typeName(T));
    }
}

/// Parse, e.g., `--cluster=123` into `123` integer
///
/// `flag` is the `--name` that `arg` starts with. For `bool` the argument must be exactly the
/// flag (no `=value`) and the result is `true`; for every other type the text after `=` is
/// handed to `parse_value`.
fn parse_flag(comptime T: type, flag: []const u8, arg: [:0]const u8) T {
    assert(flag[0] == '-' and flag[1] == '-');

    if (T == bool) {
        if (!std.mem.eql(u8, arg, flag)) {
            fatal("{s}: argument does not require a value in '{s}'", .{ flag, arg });
        }
        return true;
    }

    const value = parse_flag_split_value(flag, arg);
    assert(value.len > 0);
    return parse_value(T, flag, value);
}

/// Splits the value part from a `--arg=value` syntax.
fn parse_flag_split_value(flag: []const u8, arg: [:0]const u8) [:0]const u8 {
    assert(flag[0] == '-' and flag[1] == '-');
    assert(std.mem.startsWith(u8, arg, flag));

    const value = arg[flag.len..];
    if (value.len == 0) {
        fatal("{s}: expected value separator '='", .{flag});
    }
    if (value[0] != '=') {
        fatal(
            "{s}: expected value separator '=', but found '{c}' in '{s}'",
            .{ flag, value[0], arg },
        );
    }
    if (value.len == 1) fatal("{s}: argument requires a value", .{flag});
    return value[1..];
}

/// Convert the non-empty string `value` into a `T`, where `T` (or its child, when `T` is an
/// optional) is a string slice, `ByteSize`, an integer or an enum.
///
/// `flag` is only used in error messages; it is either a `--flag` or a `<positional>` name.
/// Terminates through `fatal` when the value cannot be parsed.
fn parse_value(comptime T: type, flag: []const u8, value: [:0]const u8) T {
    comptime assert(T != bool);
    assert((flag[0] == '-' and flag[1] == '-') or flag[0] == '<');
    assert(value.len > 0);

    const V = switch (@typeInfo(T)) {
        .Optional => |optional| optional.child,
        else => T,
    };

    if (V == []const u8 or V == [:0]const u8) return value;
    if (V == ByteSize) return parse_value_size(flag, value);
    if (@typeInfo(V) == .Int) return parse_value_int(V, flag, value);
    if (@typeInfo(V) == .Enum) return parse_value_enum(V, flag, value);
    comptime unreachable;
}

/// Parse a size such as `1024`, `10kib` or `1GiB`, turning every `ByteSizeParseError` into a
/// user-facing `fatal` message.
fn parse_value_size(flag: []const u8, value: []const u8) ByteSize {
    assert((flag[0] == '-' and flag[1] == '-') or flag[0] == '<');

    return ByteSize.parse(value) catch |err| {
        switch (err) {
            error.ParseOverflow => fatal(
                "{s}: value exceeds 64-bit unsigned integer: '{s}'",
                .{ flag, value },
            ),
            error.InvalidSize => fatal(
                "{s}: expected a size, but found '{s}'",
                .{ flag, value },
            ),
            error.InvalidUnit => fatal(
                "{s}: invalid unit in size '{s}', (needed KiB, MiB, GiB or TiB)",
                .{ flag, value },
            ),
            error.BytesOverflow => fatal(
                "{s}: size in bytes exceeds 64-bit unsigned integer: '{s}'",
                .{ flag, value },
            ),
        }
    };
}

/// Size units; the integer value of each tag is its multiplier in bytes.
pub const ByteUnit = enum(u64) {
    bytes = 1,
    kib = 1024,
    mib = 1024 * 1024,
    gib = 1024 * 1024 * 1024,
    tib = 1024 * 1024 * 1024 * 1024,
};

const ByteSizeParseError = error{
    /// The numeric part does not fit in a `u64`.
    ParseOverflow,
    /// The numeric part is missing.
    InvalidSize,
    /// The suffix after the digits is not the name of a `ByteUnit` tag.
    InvalidUnit,
    /// The number fits in a `u64`, but number times unit does not.
    BytesOverflow,
};

/// A size as the user wrote it: an amount plus a unit.
pub const ByteSize = struct {
    value: u64,
    unit: ByteUnit = .bytes,

    /// Parse leading decimal digits followed by an optional unit suffix. The suffix is matched
    /// case-insensitively against the `ByteUnit` tag names; no suffix means bytes.
    /// Guarantees that `value * unit` fits in a `u64`, which `bytes()` relies on.
    fn parse(value: []const u8) ByteSizeParseError!ByteSize {
        assert(value.len != 0);

        // Split at the first non-digit character.
        const split: struct {
            value_input: []const u8,
            unit_input: []const u8,
        } = split: for (0..value.len) |i| {
            if (!std.ascii.isDigit(value[i])) {
                break :split .{
                    .value_input = value[0..i],
                    .unit_input = value[i..],
                };
            }
        } else {
            break :split .{
                .value_input = value,
                .unit_input = "",
            };
        };

        const amount = std.fmt.parseUnsigned(u64, split.value_input, 10) catch |err| {
            switch (err) {
                error.Overflow => {
                    return ByteSizeParseError.ParseOverflow;
                },
                error.InvalidCharacter => {
                    // The only case this can happen is for the empty string
                    return ByteSizeParseError.InvalidSize;
                },
            }
        };

        const unit = if (split.unit_input.len > 0)
            unit: inline for (comptime std.enums.values(ByteUnit)) |tag| {
                if (std.ascii.eqlIgnoreCase(split.unit_input, @tagName(tag))) {
                    break :unit tag;
                }
            } else {
                return ByteSizeParseError.InvalidUnit;
            }
        else
            ByteUnit.bytes;

        // Only checks for overflow here; the product is recomputed on demand by `bytes()`.
        _ = std.math.mul(u64, amount, @intFromEnum(unit)) catch {
            return ByteSizeParseError.BytesOverflow;
        };

        return ByteSize{ .value = amount, .unit = unit };
    }

    /// Total size in bytes. Overflow is treated as unreachable: values produced by `parse` have
    /// already been checked, but a hand-built `ByteSize` must uphold this itself.
    pub fn bytes(size: *const ByteSize) u64 {
        return std.math.mul(
            u64,
            size.value,
            @intFromEnum(size.unit),
        ) catch unreachable;
    }

    /// Display suffix of the unit; empty for plain bytes.
    pub fn suffix(size: *const ByteSize) []const u8 {
        return switch (size.unit) {
            .bytes => "",
            .kib => "KiB",
            .mib => "MiB",
            .gib => "GiB",
            .tib => "TiB",
        };
    }
};

test parse_value_size {
    const kib = 1024;
    const mib = kib * 1024;
    const gib = mib * 1024;
    const tib = gib * 1024;

    // { expected bytes, input, expected amount, expected unit }
    const cases = .{
        .{ 0, "0", 0, ByteUnit.bytes },
        .{ 1, "1", 1, ByteUnit.bytes },
        .{ 140737488355328, "140737488355328", 140737488355328, ByteUnit.bytes },
        .{ 140737488355328, "128TiB", 128, ByteUnit.tib },
        .{ 1 * tib, "1TiB", 1, ByteUnit.tib },
        .{ 10 * tib, "10tib", 10, ByteUnit.tib },
        .{ 1 * gib, "1GiB", 1, ByteUnit.gib },
        .{ 10 * gib, "10gib", 10, ByteUnit.gib },
        .{ 1 * mib, "1MiB", 1, ByteUnit.mib },
        .{ 10 * mib, "10mib", 10, ByteUnit.mib },
        .{ 1 * kib, "1KiB", 1, ByteUnit.kib },
        .{ 10 * kib, "10kib", 10, ByteUnit.kib },
    };

    inline for (cases) |case| {
        const bytes = case[0];
        const input = case[1];
        const unit_val = case[2];
        const unit = case[3];
        const got = parse_value_size("--size", input);
        // Testing expectations report the mismatching values and stay well-defined in release
        // test builds, unlike `assert`.
        try std.testing.expectEqual(@as(u64, bytes), got.bytes());
        try std.testing.expectEqual(@as(u64, unit_val), got.value);
        try std.testing.expectEqual(unit, got.unit);
    }
}

test "ByteSize.parse rejects malformed sizes" {
    try std.testing.expectError(error.InvalidSize, ByteSize.parse("KiB"));
    try std.testing.expectError(error.InvalidUnit, ByteSize.parse("10XiB"));
    try std.testing.expectError(error.ParseOverflow, ByteSize.parse("99999999999999999999"));
    try std.testing.expectError(
        error.BytesOverflow,
        ByteSize.parse("18446744073709551615KiB"),
    );
}

/// Parse string value into an integer, providing a nice error message for the user.
fn parse_value_int(comptime T: type, flag: []const u8, value: [:0]const u8) T {
    assert((flag[0] == '-' and flag[1] == '-') or flag[0] == '<');

    return std.fmt.parseInt(T, value, 10) catch |err| {
        switch (err) {
            error.Overflow => fatal(
                "{s}: value exceeds {d}-bit {s} integer: '{s}'",
                .{ flag, @typeInfo(T).Int.bits, @tagName(@typeInfo(T).Int.signedness), value },
            ),
            error.InvalidCharacter => fatal(
                "{s}: expected an integer value, but found '{s}' (invalid digit)",
                .{ flag, value },
            ),
        }
    };
}

/// Parse `value` as the name of one of `E`'s fields; on a mismatch, terminate through `fatal`
/// with a message listing every accepted name.
fn parse_value_enum(comptime E: type, flag: []const u8, value: [:0]const u8) E {
    assert((flag[0] == '-' and flag[1] == '-') or flag[0] == '<');
    comptime assert(@typeInfo(E).Enum.is_exhaustive);

    return std.meta.stringToEnum(E, value) orelse fatal(
        "{s}: expected one of {s}, but found '{s}'",
        .{ flag, comptime fields_to_comma_list(E), value },
    );
}

/// Build, at compile time, a human-readable list of `E`'s field names:
/// `'a' or 'b'` for two fields, `'a', 'b', or 'c'` for more. `E` needs at least two fields.
fn fields_to_comma_list(comptime E: type) []const u8 {
    comptime {
        const field_count = std.meta.fields(E).len;
        assert(field_count >= 2);

        var result: []const u8 = "";
        for (std.meta.fields(E), 0..) |field, field_index| {
            const separator = switch (field_index) {
                0 => "",
                field_count - 1 => if (field_count == 2) " or " else ", or ",
                else => ", ",
            };
            result = result ++ separator ++ "'" ++ field.name ++ "'"
;
        }
        return result;
    }
}

/// Command-line spelling of a struct field: `--` followed by the field name with every `_`
/// replaced by `-`. Must not be called for the `positional` field.
pub fn flag_name(comptime field: std.builtin.Type.StructField) []const u8 {
    // TODO(Zig): Cleanup when this is fixed after Zig 0.11.
    // Without comptime blk, the compiler thinks the result is a runtime slice returning a UAF.
    return comptime blk: {
        assert(!std.mem.eql(u8, field.name, "positional"));

        var result: []const u8 = "--";
        var index = 0;
        while (std.mem.indexOf(u8, field.name[index..], "_")) |i| {
            result = result ++ field.name[index..][0..i] ++ "-";
            index = index + i + 1;
        }
        result = result ++ field.name[index..];
        break :blk result;
    };
}

test flag_name {
    const plain = @typeInfo(struct { statsd: bool }).Struct.fields[0];
    try std.testing.expectEqualStrings("--statsd", flag_name(plain));

    // Underscores in the field name become dashes on the command line.
    const multi_word = @typeInfo(struct { cache_grid_size: u32 }).Struct.fields[0];
    try std.testing.expectEqualStrings("--cache-grid-size", flag_name(multi_word));
}

/// Display name of a positional argument, used in error messages: the field name wrapped in
/// angle brackets. Positional field names must not contain `_`.
fn flag_name_positional(comptime field: std.builtin.Type.StructField) []const u8 {
    comptime assert(std.mem.indexOf(u8, field.name, "_") == null);
    return "<" ++ field.name ++ ">";
}

/// This is essentially `field.default_value`, but with a useful type instead of `?*anyopaque`.
+pub fn default_value(comptime field: std.builtin.Type.StructField) ?field.type { + return if (field.default_value) |default_opaque| + @as(*const field.type, @ptrCast(@alignCast(default_opaque))).* + else + null; +} diff --git a/stdx/stdx.zig b/stdx/stdx.zig index 7b03260..2cbc84a 100644 --- a/stdx/stdx.zig +++ b/stdx/stdx.zig @@ -1,4 +1,5 @@ pub const debug = @import("debug.zig"); +pub const flags = @import("flags.zig"); pub const io = @import("io.zig"); pub const json = @import("json.zig"); pub const math = @import("math.zig"); diff --git a/zml/aio.zig b/zml/aio.zig index 55155c4..175ffb3 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -10,6 +10,7 @@ const posix = @import("posix.zig"); pub const gguf = @import("aio/gguf.zig"); pub const nemo = @import("aio/nemo.zig"); pub const safetensors = @import("aio/safetensors.zig"); +pub const tinyllama = @import("aio/tinyllama.zig"); pub const torch = @import("aio/torch.zig"); pub const yaml = @import("aio/yaml.zig"); @@ -35,6 +36,8 @@ pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8) try gguf.open(allocator, model_path) else if (std.mem.endsWith(u8, model_path, ".pt")) try torch.open(allocator, model_path) + else if (std.mem.endsWith(u8, model_path, ".tinyllama")) + try tinyllama.open(allocator, model_path) else { std.debug.panic("File extension not recognized: {s}", .{model_path}); }; diff --git a/zml/aio/tinyllama.zig b/zml/aio/tinyllama.zig index e3fe3f9..8114c90 100644 --- a/zml/aio/tinyllama.zig +++ b/zml/aio/tinyllama.zig @@ -129,35 +129,3 @@ fn splitBuff(store: *zml.aio.BufferStore, comptime fmt: []const u8, sh: anytype, } return off; } - -pub fn loadTokenizer(allocator: std.mem.Allocator, tokenizer_path: []const u8, vocab_size: u32) !zml.tokenizer.Tokenizer { - const tokenizer_file = try std.fs.cwd().openFile(tokenizer_path, .{}); - defer tokenizer_file.close(); - var tok_reader = std.io.bufferedReader(tokenizer_file.reader()); - const r = tok_reader.reader(); - - const max_token_len 
= try r.readInt(u32, .little); - const special_tokens: zml.tokenizer.Tokenizer.SpecialTokens = .{ - .unk = 0, - .bos = 1, - .eos = 2, - }; - var tokenizer = try zml.tokenizer.Tokenizer.init(allocator, vocab_size, max_token_len, null, special_tokens, true); - var i: u32 = 0; - while (readToken(&tokenizer, &r)) : (i += 1) { - // Pass - } else |_| { - if (i < vocab_size) { - zml.log.info("Read {d} words out of {?d}", .{ i, vocab_size }); - } - tokenizer.vocab_size = i; - } - try tokenizer.rewriteByteFallbackTokens(); - return tokenizer; -} - -fn readToken(tokenizer: *zml.tokenizer.Tokenizer, tok_reader: anytype) !void { - const score: f32 = @bitCast(try tok_reader.readInt(u32, .little)); - const len: usize = @intCast(try tok_reader.readInt(u32, .little)); - try tokenizer.readTokenInto(score, len, tok_reader); -} diff --git a/zml/tokenizer/BUILD.bazel b/zml/tokenizer/BUILD.bazel index 2e0a71f..3bd9da9 100644 --- a/zml/tokenizer/BUILD.bazel +++ b/zml/tokenizer/BUILD.bazel @@ -16,6 +16,7 @@ zig_library( name = "tokenizer", import_name = "zml/tokenizer", main = "tokenizer.zig", + srcs = ["homemade.zig"], visibility = ["//visibility:public"], deps = [ "//async", @@ -30,6 +31,8 @@ zig_cc_binary( main = "main.zig", visibility = ["//visibility:public"], deps = [ + "//stdx", + "//async", ":tokenizer", ], ) diff --git a/zml/tokenizer/hftokenizers/hftokenizers.zig b/zml/tokenizer/hftokenizers/hftokenizers.zig index d5c252f..538a25a 100644 --- a/zml/tokenizer/hftokenizers/hftokenizers.zig +++ b/zml/tokenizer/hftokenizers/hftokenizers.zig @@ -91,7 +91,7 @@ pub const Decoder = struct { }; pub const HFTokenizer = opaque { - pub fn from_file(model: []const u8) !*HFTokenizer { + pub fn fromFile(model: []const u8) !*HFTokenizer { return @ptrCast(c.hftokenizers_new(ffi.ZigSlice.from(model))); } @@ -107,7 +107,7 @@ pub const HFTokenizer = opaque { return Decoder.init(self); } - pub fn token_to_id(self: *HFTokenizer, token: []const u8) ?u32 { + pub fn tokenToId(self: *HFTokenizer, token: 
[]const u8) ?u32 { return c.hftokenizers_token_to_id(@ptrCast(self), ffi.ZigSlice.from(token)); } }; diff --git a/zml/tokenizer/homemade.zig b/zml/tokenizer/homemade.zig new file mode 100644 index 0000000..448dd0e --- /dev/null +++ b/zml/tokenizer/homemade.zig @@ -0,0 +1,1215 @@ +//! Text tokenizer implementations +//! Disclaimer this is not a very robust implementation: +//! In particular the normalization is pretty minimalist, only works with ascii, and don't do unicode normalization. +//! Mostly used for testing models that don't have an official HF/sentencepiece tokenizer. +const builtin = @import("builtin"); +const std = @import("std"); + +const testing = std.testing; + +const log = std.log.scoped(.@"zml/tokenizer"); + +test { + std.testing.refAllDecls(@This()); + std.testing.refAllDecls(Normalizer); + std.testing.refAllDecls(Tokenizer); +} + +/// Byte Pair Encoding tokenizer generally used for LLM. +pub const Tokenizer = struct { + tokens: [][]const u8, + token_lookup: std.StringHashMapUnmanaged(u32), + special_tokens: SpecialTokens, + + scores: []f32, + max_token_len: u32, + normalizer: ?Normalizer, + // Allows to split unknown unicode characters into bytes. 
+ byte_fallback: bool = false, + + arena_state: std.heap.ArenaAllocator, + vocab_size: u32, + next_token_id: u32 = 0, + + pub const SpecialTokens = struct { + eos: u32, + bos: u32, + unk: u32, + pad: u32 = std.math.maxInt(u32), + hard_space: u32 = std.math.maxInt(u32), + }; + + pub fn init( + allocator: std.mem.Allocator, + vocab_size: u32, + max_token_len: u32, + normalizer: ?Normalizer, + special_tokens: SpecialTokens, + alloc_tokens: bool, + ) !Tokenizer { + var arena_state = std.heap.ArenaAllocator.init(allocator); + errdefer arena_state.deinit(); + const arena = arena_state.allocator(); + + var token_lookup: std.StringHashMapUnmanaged(u32) = .{}; + errdefer token_lookup.deinit(arena); + + try token_lookup.ensureTotalCapacity(arena, @intCast(vocab_size)); + + const tokens: [][]const u8 = if (alloc_tokens) try arena.alloc([]u8, vocab_size) else &.{}; + errdefer if (alloc_tokens) arena.free(tokens); + + const scores: []f32 = if (alloc_tokens) try arena.alloc(f32, vocab_size) else &.{}; + errdefer if (alloc_tokens) arena.free(scores); + + return .{ + .tokens = tokens, + .scores = scores, + .max_token_len = max_token_len, + .token_lookup = token_lookup, + .arena_state = arena_state, + .normalizer = normalizer, + .vocab_size = vocab_size, + .special_tokens = special_tokens, + }; + } + + pub fn deinit(self: Tokenizer) void { + self.arena_state.deinit(); + } + + pub fn encoder(self: *Tokenizer) !Encoder { + return Encoder.init(self); + } + + pub fn decoder(self: *Tokenizer) !Decoder { + return Decoder.init(self); + } + + /// Reads a new word directly into the tokenizer arena. 
+ pub fn readTokenInto(self: *Tokenizer, score: f32, len: usize, tok_reader: anytype) !void { + const arena = self.arena_state.allocator(); + + const token = try arena.alloc(u8, len); + const n = try tok_reader.read(token); + std.debug.assert(n == len); + + return self.addOwnedToken(score, token); + } + + /// Adds a new token (and copy it) + pub fn addToken(self: *Tokenizer, score: f32, token: []const u8) !void { + const arena = self.arena_state.allocator(); + return self.addOwnedToken(score, try arena.dupe(u8, token)); + } + + /// Adds a new token (without copying it) + pub fn addOwnedToken(self: *Tokenizer, score: f32, token: []const u8) void { + const i = self.next_token_id; + std.debug.assert(i < self.vocab_size); + self.next_token_id += 1; + + self.scores[i] = score; + self.tokens[i] = token; + const v = self.token_lookup.getOrPutAssumeCapacity(token); + if (!v.found_existing) { + v.value_ptr.* = i; + } + } + + pub fn addOwnedTokenByIndex(self: *Tokenizer, i: u32, score: f32, token: []const u8) void { + std.debug.assert(i < self.vocab_size); + self.next_token_id += 1; + self.scores[i] = score; + self.tokens[i] = token; + const v = self.token_lookup.getOrPutAssumeCapacity(token); + if (!v.found_existing) { + v.value_ptr.* = @intCast(i); + } + } + + pub fn lookup(self: *const Tokenizer, str: []const u8) ?u32 { + return self.token_lookup.get(str); + } + + pub fn tokenToId(self: *const Tokenizer, token: []const u8) ?u32 { + return self.token_lookup.get(token); + } + + pub const EncodeOptions = struct { + /// Should the beginning of sentence '' token be added. + add_bos: bool = true, + add_eos: bool = false, + pad_to: u32 = 0, + // Print tokenization intermediary steps. 
+ debug: bool = false, + }; + + pub fn encode(self: *const Tokenizer, allocator: std.mem.Allocator, raw: []const u8, options: EncodeOptions) ![]u32 { + if (options.debug) log.debug("Tokenizer.encode('{s}')", .{raw}); + const input = if (self.normalizer) |n| try n.normalize(allocator, raw) else raw; + defer if (self.normalizer) |_| allocator.free(input); + if (options.debug) log.debug("Tokenizer.encode.normalize -> '{s}'", .{input}); + + // Allocate a buffer that can fit all indices as well as extra character if requested. + // We then slice it so that the token merging code doesn't see the bos token. + const tok_buff_alloc = try allocator.alloc(u32, @max(options.pad_to, input.len + 2)); + const tok_buff = if (options.add_bos) tok_buff_alloc[1..] else tok_buff_alloc; + + const MergeState = union(enum) { ready: u32, nope, hard_space, idk }; + const mergeable = try allocator.alloc(MergeState, tok_buff.len); + + var num_tokens: usize = 0; + var it: CharTokenIterator = .{ .input = input }; + while (try it.nextCodepointToken(self)) |token| : (num_tokens += 1) { + tok_buff[num_tokens] = token; + mergeable[num_tokens] = if (token == self.special_tokens.hard_space) + .hard_space + else + .idk; + } + + var stable_prefix: usize = 0; + var stable_off: usize = 0; + while (true) { + // This code is a bit overcomplicated cause I'm abstracting over two algorithms: + // BPE and sentencepiece unigram model. + // Normally BPE is pre-split on spaces then the regular merge algorithm is applied. + // With unigram model you work at sentence level and you handle spaces as you would any other bytes, + // hoping the final tokens mostly align with spaces. + // This seemed like a good idea, but is kinda bad because I had to add special code to speed up BPE + // by detecting when the first "word" is treated and can be safely removed from sequence. + // Also it doesn't work well with BPE vocab which have multi-space tokens (for indentation) + // and have custom splitting rules. 
+ // This is fine for now cause we now have bindings to HF tokenizers for complexe use cases + // and are only using this for tinyllama/gguf models. + // If we come back to use this in production, the implementation would gain in speed/clarity + // by splitting in two. + // The merging token logic isn't that complicated anyway. + + // Step by step visualization of the progress. + if (options.debug) { + var _debug_buf: [256]u8 = undefined; + var _debug_alloc = std.heap.FixedBufferAllocator.init(&_debug_buf); + var debug_progress = std.ArrayList(u8).init(_debug_alloc.allocator()); + self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {}; + log.debug("tokens: {d} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items }); + } + var best_score: f32 = -1e10; + var best_token: u32 = 0; + var best_idx: ?usize = null; + var input_off: usize = stable_off; + + // Find best tokens to merge in all available tokens + for (stable_prefix..num_tokens - 1) |i| { + if (tok_buff[i] == self.special_tokens.unk) { + input_off += 1; + continue; + } + const cur_tok = self.tokens[tok_buff[i]]; + defer input_off += cur_tok.len; + + // Lookup merge for current token, if not already done. + switch (mergeable[i]) { + .nope => continue, + .ready => {}, + .hard_space => { + // Since tokens are not allowed to merge through hard sep, + // we don't need to merge the sentence-wide best token. + // We can just merge the best token since beginning. + if (best_idx != null) break; + // OTOH if there was no merge possible since beginning, + // we can skip the beginning in future iterations. + stable_prefix = i + 1; + stable_off = input_off + cur_tok.len; + continue; + }, + .idk => { + + // Special tokens can't be concatenated. + if (builtin.mode == .Debug and tok_buff[i] != self.special_tokens.unk) { + // Detects memory corruption of tokens. 
+ if (cur_tok.len == 0 or cur_tok.len > self.max_token_len) @panic("Token looks corrupted !"); + + if (!std.mem.eql(u8, cur_tok, input[input_off..][0..cur_tok.len])) { + log.err("current token '{s}' not found in input string '{s}' !", .{ cur_tok, input[input_off..] }); + @panic("invalid tokenization"); + } + } + const next_tok = self.tokens[tok_buff[i + 1]]; + // if `next_tok` is `.unk`, length is 1; otherwise, it's the length of the token. + const next_tok_len = if (tok_buff[i + 1] == self.special_tokens.unk) 1 else next_tok.len; + const concat_tokens = input[input_off..][0 .. cur_tok.len + next_tok_len]; + // Save the result + mergeable[i] = if (self.lookup(concat_tokens)) |tok| + .{ .ready = tok } + else + .nope; + }, + } + + switch (mergeable[i]) { + .idk, .hard_space => unreachable, + .nope => continue, + .ready => |tok| { + if (self.scores[tok] > best_score) { + best_score = self.scores[tok]; + best_token = tok; + best_idx = i; + } + }, + } + } + + if (best_idx) |bidx| { + // Apply the merge. + tok_buff[bidx] = best_token; + std.mem.copyForwards(u32, tok_buff[bidx + 1 ..], tok_buff[bidx + 2 .. num_tokens]); + std.mem.copyForwards(MergeState, mergeable[bidx + 1 ..], mergeable[bidx + 2 .. num_tokens]); + num_tokens -= 1; + // We got two new merge lookups to do. + mergeable[bidx] = .idk; + if (bidx > 0 and mergeable[bidx - 1] != .hard_space) mergeable[bidx - 1] = .idk; + } else { + // No merge candidate => we are done ! + break; + } + } + + if (options.add_eos) { + tok_buff[num_tokens] = self.special_tokens.eos; + num_tokens += 1; + } + if (options.add_bos) { + tok_buff_alloc[0] = self.special_tokens.bos; + num_tokens += 1; + } + if (num_tokens < options.pad_to) { + for (num_tokens..options.pad_to) |i| { + tok_buff_alloc[i] = self.special_tokens.pad; + } + num_tokens = options.pad_to; + } + + // Release extra memory we don't need anymore. 
+ allocator.free(mergeable); + _ = allocator.resize(tok_buff_alloc, num_tokens); + return tok_buff_alloc[0..num_tokens]; + } + + /// Returns a slice corresponding to the given id. Handles unknown ids and special ids. + pub fn lookupPiece(self: *const Tokenizer, id: usize) []const u8 { + return if (id == self.special_tokens.bos or id == self.special_tokens.eos or id == self.special_tokens.pad) + "" + else if (id == self.special_tokens.unk) + "" + else if (id > self.tokens.len) + "" // this means we received an invalid id, but we didn't want to panic. + else + self.tokens[id]; + } + + /// Converts the given slice of tokens back into bytes. + /// Note that if the tokenizer allows sub-unicode bytes, it's possible + /// the output is not valid utf8. + pub fn decode(self: *const Tokenizer, allocator: std.mem.Allocator, input: []const u32) error{OutOfMemory}![]u8 { + var output = std.ArrayList(u8).init(allocator); + errdefer output.deinit(); + + try self.decodeWithOpts(&output, input, .{}); + return output.toOwnedSlice(); + } + + pub fn decodeWithOpts( + self: *const Tokenizer, + output: *std.ArrayList(u8), + input: []const u32, + opts: struct { sep: []const u8 = "" }, + ) error{OutOfMemory}!void { + const escaped = if (self.normalizer) |n| n.escapedSpace() else null; + // Flag used to indicate if the first dummy whitespace has been consumed. + for (input) |id| { + // Retrieve the slice corresponding to the id. + var piece = self.lookupPiece(id); + + // Convert `▁` to a regular space. + if (escaped) |escspc| { + // we modify piece inside the loop, so we can use it in the condition + while (std.mem.startsWith(u8, piece, escspc)) { + piece = piece[escspc.len..]; + // don't output a space at beginning of text. + if (output.items.len > 0) try output.append(' '); + } + } + + try output.appendSlice(piece); + if (opts.sep.len > 0) try output.appendSlice(opts.sep); + } + } + + /// Some tokenizers have bytes encoded in hex like this: "<0x40>". 
+ /// This break the tokenization algorithm because the input text + /// will contain "@" not "<0x40>", + /// and if the input contains "<0x40>" it needs to not be treated as a single byte. + /// So we replace byte fallbacks strings, by their corresponding character. + /// This enables the normal tokenization algorithm to work. + pub fn rewriteByteFallbackTokens(tokenizer: *Tokenizer) !void { + tokenizer.byte_fallback = true; + var single_bytes = try tokenizer.arena_state.allocator().alloc(u8, 256); + var byte_fallback_buf = "<0x00>".*; + + for (0..256) |i| { + const c: u8 = @truncate(i); + single_bytes[i] = c; + + // First lookup the byte fallback entry. + // Note: we assume upper case, but we could try both upper and lower case if needed. + _ = std.fmt.bufPrintIntToSlice(byte_fallback_buf[3..5], c, 16, .upper, .{ .fill = '0', .width = 2 }); + const entry = tokenizer.token_lookup.getEntry(&byte_fallback_buf) orelse { + log.err("Tokenizer has \"byte_fallback\" = true, but doesn't contains the byte fallback token {s}", .{byte_fallback_buf}); + return error.InvalidInput; + }; + + // Check if the character is already present in the vocab. + // In that case, nothing to do, + // but note that the fallback token will be "unreachable", + // ie there is no way the tokenizer can produce it. + if (tokenizer.token_lookup.get(&.{c})) |_| continue; + + const idx: u32 = entry.value_ptr.*; + tokenizer.token_lookup.removeByPtr(entry.key_ptr); + tokenizer.addOwnedTokenByIndex(idx, tokenizer.scores[idx], single_bytes[i .. 
i + 1]); + } + } +}; + +test Tokenizer { + const allocator = std.testing.allocator; + const special_tokens: Tokenizer.SpecialTokens = .{ + .unk = 0, + .bos = 1, + .eos = 2, + }; + + var tokenizer = try Tokenizer.init(allocator, 10, 5, null, special_tokens, true); + defer tokenizer.deinit(); + + try tokenizer.addToken(10, "hello"); + try tokenizer.addToken(3.5, "world"); + + try testing.expect(tokenizer.lookup("hello") == 0); + try testing.expect(tokenizer.lookup("world") == 1); + + // TODO: test Tokenizer.decode, Tokenizer.encode, Tokenizer.readTokenInto +} + +pub const Encoder = struct { + inner: *Tokenizer, + arena: std.heap.ArenaAllocator, + current_ids: []const u32 = &.{}, + + fn init(inner: *Tokenizer) !Encoder { + var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); + // Warmup the arena. Page allocator is expensive, avoid calling it for small reallocs. + _ = try arena.allocator().alloc(u32, 4096); + std.debug.assert(arena.reset(.retain_capacity)); + return .{ .inner = inner, .arena = arena }; + } + + pub fn reset(self: *Encoder) void { + self.current_ids = &.{}; + std.debug.assert(self.arena.reset(.retain_capacity)); + } + + pub fn deinit(self: *Encoder) void { + self.arena.deinit(); + } + + pub fn encode(self: *Encoder, input: []const u8) ![]const u32 { + self.reset(); + const res = try self.inner.encode(self.arena.allocator(), input, .{ + .add_bos = true, + .add_eos = false, + .pad_to = 0, + // Print tokenization intermediary steps. 
+ .debug = false, + }); + self.current_ids = res; + return res; + } + + pub fn ids(self: *const Encoder) []const u32 { + return self.current_ids; + } +}; + +pub const Decoder = struct { + const StringBuffer = std.BoundedArray(u8, 128); + const TokensIdsBuffer = std.BoundedArray(u32, 4); + + inner: *Tokenizer, + arena: std.heap.ArenaAllocator, + + current_string: ?[]const u8 = null, + last_string: StringBuffer = .{ .len = 0 }, + last_token_ids: TokensIdsBuffer = .{ .len = 0 }, + + fn init(inner: *Tokenizer) !Decoder { + var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); + // Warmup the arena. Page allocator is expensive, avoid calling it for small reallocs. + _ = try arena.allocator().alloc(u32, 4096); + std.debug.assert(arena.reset(.retain_capacity)); + return .{ .inner = inner, .arena = arena }; + } + + pub fn deinit(self: *Decoder) void { + self.arena.deinit(); + } + + pub fn reset(self: *Decoder) void { + std.debug.assert(self.arena.reset(.retain_capacity)); + self.current_string = null; + } + + pub fn decode(self: *Decoder, ids: []const u32) ![]const u8 { + self.reset(); + const res = try self.inner.decode(self.arena.allocator(), ids); + self.current_string = res; + return res; + } + + pub fn string(self: *const Decoder) []const u8 { + return self.current_string; + } + + pub fn next(self: *Decoder, token_id: u32) !?[]const u8 { + if (self.last_token_ids.len >= self.last_token_ids.capacity()) { + _ = self.last_token_ids.orderedRemove(0); + } + self.last_token_ids.appendAssumeCapacity(token_id); + const new_string = try self.decode(self.last_token_ids.constSlice()); + if (self.last_string.len == 0) { + self.last_string = try StringBuffer.fromSlice(new_string); + return new_string; + } + var view = try std.unicode.Utf8View.init(self.last_string.constSlice()); + var it = view.iterator(); + while (it.nextCodepointSlice()) |cp| { + const start = it.i - cp.len; + if (std.mem.startsWith(u8, new_string, self.last_string.constSlice()[start..])) { + const 
chunk = new_string[self.last_string.len - start ..]; + self.last_string = try StringBuffer.fromSlice(new_string); + return chunk; + } + } + return null; + } +}; + +/// Given a slice, split it in the most simple tokens using the given tokenizer tokens. +/// The output of this can be used to initialize the tokenization algorithm. +/// Normally we split the input text into utf8 codepoint, +/// but if we find an unknown codepoint we either split it in bytes, or use the special "unknown" token, +/// depending on the tokenizer configuration. +const CharTokenIterator = struct { + state: union(enum) { by_codepoint, by_byte: u8 } = .by_codepoint, + input: []const u8, + + fn nextCodepointToken(self: *CharTokenIterator, tokenizer: *const Tokenizer) error{ TruncatedInput, Utf8InvalidStartByte }!?u32 { + if (self.input.len == 0) return null; + return switch (self.state) { + .by_byte => |*byte_left| { + const idx = tokenizer.lookup(self.input[0..1]) orelse { + // Normally this has been caught when calling `rewriteByteFallbackTokens`. + std.debug.panic("Tokenizer has \"byte_fallback\" = true, but doesn't contains the byte fallback for token '<0x{X:02}>'", .{self.input[0]}); + }; + + self.input = self.input[1..]; + byte_left.* -|= 1; + if (byte_left.* == 0) self.state = .by_codepoint; + return idx; + }, + .by_codepoint => { + // Try to lookup valid utf8 codepoint first. + const utf8_len = try std.unicode.utf8ByteSequenceLength(self.input[0]); + if (self.input.len < utf8_len) return error.TruncatedInput; + if (tokenizer.lookup(self.input[0..utf8_len])) |idx| { + self.input = self.input[utf8_len..]; + return idx; + } + + // Otherwise split in bytes if it's allowed. + if (tokenizer.byte_fallback) { + // TODO: replace this by a continue statement next time we bump Zig. + self.state = .{ .by_byte = utf8_len }; + return self.nextCodepointToken(tokenizer); + } + + // Or mark the full utf8 codepoint as unknown. 
+ log.debug("Token not found for char '{s}'", .{self.input[0..utf8_len]}); + self.input = self.input[utf8_len..]; + return tokenizer.special_tokens.unk; + }, + }; + } +}; + +test CharTokenIterator { + const special_tokens: Tokenizer.SpecialTokens = .{ .unk = 0, .bos = 1, .eos = 2 }; + var tokenizer = try Tokenizer.init(std.testing.allocator, 16, 4, null, special_tokens, true); + defer tokenizer.deinit(); + + tokenizer.addOwnedToken(1.0, ""); // 0 + tokenizer.addOwnedToken(1.0, ""); // 1 + tokenizer.addOwnedToken(1.0, ""); // 2 + tokenizer.addOwnedToken(1.0, "ζ"); // 3 + tokenizer.addOwnedToken(1.0, &.{0xE2}); // 4: ℳ, first byte + tokenizer.addOwnedToken(1.0, &.{0x84}); // 5: ℳ, second byte + tokenizer.addOwnedToken(1.0, &.{0xB3}); // 6: ℳ, third byte + tokenizer.addOwnedToken(1.0, "L"); // 7 + + // No byte fallback + { + tokenizer.byte_fallback = false; + var it: CharTokenIterator = .{ .input = "ζℳL" }; + var res: std.BoundedArray(u32, 8) = .{}; + while (try it.nextCodepointToken(&tokenizer)) |token| { + res.appendAssumeCapacity(token); + } + try std.testing.expectEqualSlices(u32, &[_]u32{ 3, 0, 7 }, res.constSlice()); + } + + // with byte fallback + { + tokenizer.byte_fallback = true; + var it: CharTokenIterator = .{ .input = "ζℳL" }; + var res: std.BoundedArray(u32, 8) = .{}; + while (try it.nextCodepointToken(&tokenizer)) |token| { + res.appendAssumeCapacity(token); + } + try std.testing.expectEqualSlices(u32, &[_]u32{ 3, 4, 5, 6, 7 }, res.constSlice()); + } +} + +/// Text normalizer. +/// Most tokenizer assumes the input text have been prepocessed with on of those. +pub const Normalizer = struct { + /// Space token used by sentencepiece derived tokenizer. + pub const sentencepiece_space = "▁"; // \xe2\x96\x81 + + _whitespace: std.BoundedArray(u8, 8) = .{}, + + flags: packed struct { + remove_extra_whitespaces: bool, + add_dummy_prefix: bool, + add_dummy_suffix: bool, + /// Cheap lower casing. 
+ /// TODO: try to match Python "lower" + lower_case_ascii: bool, + /// cheap ascii punct splitting. + // doing this processing ahead of time simplifies the logic + split_on_punct_ascii: bool, + }, + + pub fn init(flags: std.meta.FieldType(Normalizer, .flags), escaped_whitespace: ?[]const u8) Normalizer { + var res: Normalizer = .{ .flags = flags }; + if (escaped_whitespace) |escaped| { + res._whitespace.appendSliceAssumeCapacity(escaped); + } + return res; + } + + pub inline fn escapedSpace(self: Normalizer) ?[]const u8 { + return if (self._whitespace.len > 1) self._whitespace.constSlice() else null; + } + + fn addSlice(data: []const u8, consumed: usize, normalized: *std.ArrayList(u8), normalized_to_origin: *std.ArrayList(usize)) !void { + try normalized.appendSlice(data); + for (data) |_| try normalized_to_origin.append(consumed); + } + + pub const Result = struct { + /// Normalized string + normalized: []const u8, + /// Mapping between chars in the original string and chars in the new string + normalized_to_origin: []const usize, + + pub fn deinit(self: Result, allocator: std.mem.Allocator) void { + allocator.free(self.normalized); + allocator.free(self.normalized_to_origin); + } + }; + + /// Simplifed version of Sentencepiece normalizer. + /// + /// Llama2 uses a normalizer called "identity" so this basically only handles trailing + /// whitespaces and replaces whitespace with the "▁" (U+2581) character. + pub fn normalize(self: Normalizer, allocator: std.mem.Allocator, input: []const u8) ![]const u8 { + const res = try self.normalizeWithMapping(allocator, input); + allocator.free(res.normalized_to_origin); + return res.normalized; + } + + /// Returns both the normalized string and a mapping between the normalized string and the original. + pub fn normalizeWithMapping(self: Normalizer, allocator: std.mem.Allocator, input: []const u8) !Result { + // Number of bytes consumed from the input. 
+ var consumed: usize = 0; + var trimmed_input = input; + + // Skip leading whitespaces. + if (self.flags.remove_extra_whitespaces) { + while (trimmed_input.len != 0) { + if (trimmed_input[0] != ' ') break; + trimmed_input = trimmed_input[1..]; + consumed += 1; + } + } + + // If the trimmed input is empty, we are done. + if (trimmed_input.len == 0) { + return .{ .normalized = &.{}, .normalized_to_origin = &.{} }; + } + + // Pre-allocate outputs + const space = self.escapedSpace() orelse " "; + const overhead = if (self.flags.split_on_punct_ascii) space.len + 1 else space.len; + var normalized = try std.ArrayList(u8).initCapacity(allocator, trimmed_input.len * overhead + 2 * space.len); + errdefer normalized.deinit(); + var normalized_to_origin = try std.ArrayList(usize).initCapacity(allocator, normalized.capacity); + errdefer normalized_to_origin.deinit(); + + // If the spec asks for it, add a whitespace at the beginning. + if (self.flags.add_dummy_prefix) try addSlice(space, consumed, &normalized, &normalized_to_origin); + + var is_prev_space: bool = true; + var is_prev_word: bool = false; + + while (trimmed_input.len != 0) { + // NOTE(Corendos): This might feel weird but normally the slice we get comes from a normalizing process and can contain multiple codepoints. + // Since we have an "identity" normalizer, each slice is actually a unicode character. 
+ const multibyte_length = try std.unicode.utf8ByteSequenceLength(trimmed_input[0]); + var slice = trimmed_input[0..multibyte_length]; + const origin = consumed; + consumed += multibyte_length; + trimmed_input = trimmed_input[multibyte_length..]; + + if (self.flags.remove_extra_whitespaces and is_prev_space) { + while (slice.len > 0 and slice[0] == ' ') { + slice = slice[1..]; + } + if (slice.len == 0) continue; + } + is_prev_space = slice[slice.len - 1] == ' '; + + if (slice.len == 1) ascii: { + // The more advanced logic only works with ascii atm + var byte = slice[0]; + if (self.escapedSpace() != null and byte == ' ') { + // replace the space token by the special token + try addSlice(space, origin, &normalized, &normalized_to_origin); + is_prev_word = false; + break :ascii; + } else if (self.flags.split_on_punct_ascii) { + if (is_prev_word and isPunct(slice)) { + // Insert a space, but continue handling the rest + try addSlice(space, origin, &normalized, &normalized_to_origin); + } + } + if (self.flags.lower_case_ascii) { + byte = std.ascii.toLower(byte); + } + try normalized.append(byte); + try normalized_to_origin.append(origin); + } else { + // we can safely copy to the output. 
+ try addSlice(slice, origin, &normalized, &normalized_to_origin); + } + is_prev_word = !is_prev_space and !isPunct(slice); + } + + // Skip trailing whitespaces + if (self.flags.remove_extra_whitespaces) { + while (std.mem.endsWith(u8, normalized.items, space)) { + const length = normalized.items.len - space.len; + consumed = normalized_to_origin.items[length]; + try normalized.resize(length); + try normalized_to_origin.resize(length); + } + } + + try normalized_to_origin.append(consumed); + + std.debug.assert(normalized_to_origin.items.len == normalized.items.len + 1); + + if (self.flags.add_dummy_suffix) try addSlice(space, consumed, &normalized, &normalized_to_origin); + + return .{ + .normalized = try normalized.toOwnedSlice(), + .normalized_to_origin = try normalized_to_origin.toOwnedSlice(), + }; + } + + pub fn wellKnown(impl: KnownImplementation) Normalizer { + return switch (impl) { + .sentencepiece => init(.{ + .remove_extra_whitespaces = true, + .add_dummy_prefix = true, + .add_dummy_suffix = false, + .lower_case_ascii = false, + .split_on_punct_ascii = false, + }, sentencepiece_space), + .llama3 => init(.{ + .remove_extra_whitespaces = true, + .add_dummy_prefix = false, + .add_dummy_suffix = false, + .lower_case_ascii = false, + .split_on_punct_ascii = false, + }, null), + .gpt2 => init(.{ + .remove_extra_whitespaces = true, + .add_dummy_prefix = true, + .add_dummy_suffix = false, + .lower_case_ascii = false, + .split_on_punct_ascii = false, + }, null), + }; + } + + pub fn fromHfJson(config: std.json.ObjectMap) error{InvalidNormalizerJson}!Normalizer { + var normalizer: Normalizer = .{ .flags = .{ + .remove_extra_whitespaces = false, + .add_dummy_suffix = false, + .add_dummy_prefix = false, + .lower_case_ascii = false, + .split_on_punct_ascii = false, + } }; + + // Normalizer config can be a single normalizer, or a sequence of normalizers. 
+ const maybe_steps = objectGet(config, .array, "normalizers"); + const steps = if (maybe_steps) |st| st.items else &.{std.json.Value{ .object = config }}; + + for (steps) |step_val| { + if (step_val != .object) { + return error.InvalidNormalizerJson; + } + const step = step_val.object; + + const step_type = objectGet(step, .string, "type") orelse { + return error.InvalidNormalizerJson; + }; + if (std.mem.eql(u8, "Prepend", step_type)) { + normalizer.flags.add_dummy_prefix = true; + } else if (std.mem.eql(u8, "Append", step_type)) { + normalizer.flags.add_dummy_suffix = true; + } else if (std.mem.eql(u8, "Replace", step_type)) { + const pattern = objectGet(step, .object, "pattern") orelse return error.InvalidNormalizerJson; + const str_pattern = objectGet(pattern, .string, "String") orelse return error.InvalidNormalizerJson; + + if (std.mem.eql(u8, str_pattern, " ")) { + normalizer._whitespace.appendSliceAssumeCapacity( + objectGet(step, .string, "content") orelse return error.InvalidNormalizerJson, + ); + } else { + log.warn("Normalizer Replace pattern not supported: '{s}' -> '{s}'", .{ str_pattern, objectGet(pattern, .string, "content") orelse "" }); + } + } else { + log.warn("Unknown normalizer type: {s}", .{step_type}); + } + } + + return normalizer; + } + + test "Normalizer.fromHfJson" { + const config_json = + \\{ + \\ "type": "Sequence", + \\ "normalizers": [ + \\ { + \\ "type": "Prepend", + \\ "prepend": "▁" + \\ }, + \\ { + \\ "type": "Replace", + \\ "pattern": { + \\ "String": " " + \\ }, + \\ "content": "▁" + \\ } + \\ ] + \\} + ; + var arena = std.heap.ArenaAllocator.init(std.testing.allocator); + defer arena.deinit(); + const config = try std.json.parseFromSliceLeaky(std.json.Value, arena.allocator(), config_json, .{}); + const normalizer = try Normalizer.fromHfJson(config.object); + + const expected = Normalizer{ + ._whitespace = .{ .buffer = [_]u8{ 0xe2, 0x96, 0x81 } ++ [_]u8{0} ** 5, .len = 3 }, + .flags = .{ + .remove_extra_whitespaces = false, + 
.add_dummy_prefix = true, + .add_dummy_suffix = false, + .lower_case_ascii = false, + .split_on_punct_ascii = false, + }, + }; + try std.testing.expectEqual(expected.flags, normalizer.flags); + try std.testing.expectEqualStrings(expected.escapedSpace().?, normalizer.escapedSpace().?); + } +}; +pub const KnownImplementation = enum(u8) { + sentencepiece, + gpt2, + llama3, +}; + +fn isPunct(unicode_char: []const u8) bool { + // TODO use unicode categories + if (unicode_char.len > 1) return false; + + return switch (unicode_char[0]) { + ' ', '\t' => false, + 0...8 => true, + 10...31 => true, + '!'...'/' => true, + ':'...'@' => true, + '['...'`' => true, + '{'...'~' => true, + else => false, + }; +} + +test Normalizer { + { + const n: Normalizer = .{ .flags = .{ + .remove_extra_whitespaces = true, + .add_dummy_prefix = true, + .add_dummy_suffix = false, + .lower_case_ascii = false, + .split_on_punct_ascii = false, + } }; + const res = try n.normalizeWithMapping(testing.allocator, "Hellŏ world!"); + defer res.deinit(testing.allocator); + + try testing.expectEqualSlices(u8, " Hellŏ world!", res.normalized); + try testing.expectEqualSlices( + usize, + // H e l l ŏ ␣ w o r l d ! + &.{ 0, 0, 1, 2, 3, 4, 4, 6, 8, 9, 10, 11, 12, 13, 14 }, + res.normalized_to_origin, + ); + } + + { + const n: Normalizer = .{ .flags = .{ + .remove_extra_whitespaces = true, + .add_dummy_prefix = true, + .add_dummy_suffix = true, + .lower_case_ascii = false, + .split_on_punct_ascii = false, + } }; + const res = try n.normalize(testing.allocator, "Hello world!"); + defer testing.allocator.free(res); + + try testing.expectEqualSlices(u8, " Hello world! 
", res); + } + + { + const n = Normalizer.init( + .{ + .remove_extra_whitespaces = false, + .add_dummy_prefix = true, + .add_dummy_suffix = false, + .lower_case_ascii = false, + .split_on_punct_ascii = false, + }, + Normalizer.sentencepiece_space, + ); + const res = try n.normalize(testing.allocator, "Hello world!"); + defer testing.allocator.free(res); + + try testing.expectEqualSlices(u8, "▁Hello▁▁world!", res); + } + + { + const n: Normalizer = .{ .flags = .{ + .remove_extra_whitespaces = true, + .add_dummy_prefix = false, + .add_dummy_suffix = true, + .lower_case_ascii = true, + .split_on_punct_ascii = false, + } }; + const res = try n.normalize(testing.allocator, "Hello world!"); + defer testing.allocator.free(res); + + try testing.expectEqualSlices(u8, "hello world! ", res); + } + + { + const n: Normalizer = .{ .flags = .{ + .remove_extra_whitespaces = true, + .add_dummy_prefix = false, + .add_dummy_suffix = true, + .lower_case_ascii = false, + .split_on_punct_ascii = true, + } }; + const res = try n.normalize(testing.allocator, "Hello world!"); + defer testing.allocator.free(res); + + try testing.expectEqualSlices(u8, "Hello world ! ", res); + } +} + +/// gpt2 had their own way of storing text. +/// Unfortunately this has contaminated other models. +/// This implementation precompupte a mapping between bytes encoded with GPT2 algorithm, +/// into utf8 bytes, and do lookups at runtime. +pub const Gpt2TextDecoder = struct { + const Code = std.BoundedArray(u8, 2); + + // TODO: benchmark this is more efficient than doing the conversion at runtime. 
+ code_to_byte: std.AutoArrayHashMap(Code, u8), + + pub fn init(allocator: std.mem.Allocator) !Gpt2TextDecoder { + var self = Gpt2TextDecoder{ + .code_to_byte = std.AutoArrayHashMap(Code, u8).init(allocator), + }; + try self.code_to_byte.ensureTotalCapacity(256); + errdefer unreachable; + + var n: usize = 0; + for (0..256) |index| { + var code: Code = .{ .buffer = .{ 0, 0 }, .len = 0 }; // 0-init + const i: u8 = @intCast(index); + if (isPrintableByte(i)) { + if (std.ascii.isASCII(i)) { + code.appendAssumeCapacity(i); + } else { + const codepoint: u21 = @as(u21, @intCast(i)); + code.len = @intCast(std.unicode.utf8Encode(codepoint, &code.buffer) catch unreachable); + } + } else { + const codepoint: u21 = 256 + @as(u21, @intCast(n)); + code.len = @intCast(std.unicode.utf8Encode(codepoint, &code.buffer) catch unreachable); + n += 1; + } + + self.code_to_byte.putAssumeCapacityNoClobber(code, i); + } + return self; + } + + pub fn deinit(self: *Gpt2TextDecoder) void { + self.code_to_byte.deinit(); + } + + /// Transform bytes representing text under the gpt2 encoding, + /// and write to the `unicode` buffer utf-8 bytes. + pub fn decode(self: Gpt2TextDecoder, unicode: *std.ArrayList(u8), bytes: []const u8) ![]const u8 { + const start = unicode.items.len; + var it = std.unicode.Utf8Iterator{ .i = 0, .bytes = bytes }; + while (it.nextCodepointSlice()) |codepoint| { + const code: Code = switch (codepoint.len) { + 1 => .{ .buffer = .{ codepoint[0], 0 }, .len = 1 }, // 0-init + 2 => .{ .buffer = .{ codepoint[0], codepoint[1] }, .len = 2 }, + else => return error.InvalidInput, + }; + const byte = self.code_to_byte.get(code) orelse return error.InvalidInput; + try unicode.append(byte); + } + return unicode.items[start..]; + } + + inline fn isPrintableByte(c: u8) bool { + return ('!' 
<= c and c <= '~') or (0xa1 <= c and c <= 0xac) or (0xae <= c and c <= 0xff); + } +}; + +test Gpt2TextDecoder { + var decoder = try Gpt2TextDecoder.init(testing.allocator); + defer decoder.deinit(); + + var out = std.ArrayList(u8).init(testing.allocator); + defer out.deinit(); + + // Ascii is not changed. + try testing.expectEqualStrings("getTitle", try decoder.decode(&out, "getTitle")); + // Leading space are represented with 'Ġ' + try testing.expectEqualStrings(" UINavigationController", try decoder.decode(&out, "ĠUINavigationController")); + // Russian is wild + try testing.expectEqualStrings(" работ", try decoder.decode(&out, "ĠÑĢабоÑĤ")); +} + +/// Open a json file in HF format and load the vocab from it. +pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tokenizer { + const file = try std.fs.cwd().openFile(tokenizer_path, .{}); + defer file.close(); + + var arena_state = std.heap.ArenaAllocator.init(allocator); + defer arena_state.deinit(); + const arena = arena_state.allocator(); + const file_content = try file.readToEndAlloc(arena, 32 * 1024 * 1024); + + const info = try std.json.parseFromSliceLeaky(std.json.Value, arena, file_content, .{ + .duplicate_field_behavior = .use_last, + }); + const main_object = switch (info) { + .object => |obj| if (obj.get("added_tokens") == null or obj.get("model") == null) { + return error.InvalidFormat; + } else obj, + else => return error.InvalidFormat, + }; + + const model = objectGet(main_object, .object, "model") orelse return error.InvalidFormat; + const vocab = objectGet(model, .object, "vocab") orelse return error.InvalidFormat; + const added_tokens = if (objectGet(main_object, .array, "added_tokens")) |added| added.items else &.{}; + const vocab_size: u32 = @intCast(vocab.count() + added_tokens.len); + + const normalizer = if (objectGet(main_object, .object, "normalizer")) |normalizer_config| + try Normalizer.fromHfJson(normalizer_config) + else + Normalizer.wellKnown(.llama3); + + // delay 
init of special tokens. + var tokenizer = try Tokenizer.init(allocator, vocab_size, 256, normalizer, undefined, true); + errdefer tokenizer.deinit(); + + // Buffer containing all concatenated tokens. + // Reserve a big chunk, to avoid grow event, but release over-allocated memory. + var all_tokens = try std.ArrayList(u8).initCapacity(tokenizer.arena_state.allocator(), file_content.len); + const original_alloc = all_tokens.items.ptr; + // A re-alloc event here means we have invalidated all slices inside the tokenizer. + // If this is too annoying we could switch to a custom type instead of slices. + defer { + std.debug.assert(all_tokens.items.ptr == original_alloc); + } + + // gpt2 based tokenizer got a special way of encoding unicode. + // we don't know in advance if this will be used by this tokenizer or not. + // so we assume it is the case, but if we find some unicode character, + // outside of the range used by gpt2 we know it was wrong, and start over. + var is_gpt2_vocab: bool = true; + var gpt2_decoder = try Gpt2TextDecoder.init(allocator); + defer gpt2_decoder.deinit(); + var it = vocab.iterator(); + while (it.next()) |kv| { + const token = gpt2_decoder.decode(&all_tokens, kv.key_ptr.*) catch |err| { + switch (err) { + error.InvalidInput => { + is_gpt2_vocab = false; + break; + }, + else => return err, + } + }; + const idx: u32 = @intCast(kv.value_ptr.*.integer); + tokenizer.addOwnedTokenByIndex(idx, @floatFromInt(vocab_size - idx), token); + } + + if (!is_gpt2_vocab) { + // We where wrong, this is not a gpt2 vocab, start over, + // and reset the tokenizer state. 
+ tokenizer.next_token_id = 0; + tokenizer.token_lookup.clearRetainingCapacity(); + all_tokens.clearRetainingCapacity(); + it = vocab.iterator(); + while (it.next()) |kv| { + const idx: u32 = @intCast(kv.value_ptr.*.integer); + const token = try dup(&all_tokens, kv.key_ptr.*); + tokenizer.addOwnedTokenByIndex(idx, @floatFromInt(vocab_size - idx), token); + } + } + + // More tokens, typically added during fine tuning of the model. + for (added_tokens) |token_obj| { + if (token_obj != .object) return error.InvalidFormat; + const v = objectGet(token_obj.object, .string, "content") orelse return error.InvalidFormat; + const id: u32 = @intCast(objectGet(token_obj.object, .integer, "id") orelse return error.InvalidFormat); + const token = try if (is_gpt2_vocab) + gpt2_decoder.decode(&all_tokens, v) + else + dup(&all_tokens, v); + + tokenizer.addOwnedTokenByIndex(id, 0, token); + } + // We won't add more tokens here, let release. + all_tokens.shrinkAndFree(all_tokens.items.len); + + var unk = tokenizer.lookup(""); + if (objectGet(model, .integer, "unk_token")) |unk_tok| { + unk = @intCast(unk_tok); + } + + tokenizer.special_tokens = .{ + // TODO allow users to specify special tokens or read them from a tokenizer_config.json file + .bos = tokenizer.lookup("") orelse tokenizer.lookup("<|begin_of_text|>") orelse @panic("bos token not found !"), + .eos = tokenizer.lookup("") orelse tokenizer.lookup("<|end_of_text|>") orelse @panic("eos token not found !"), + .unk = unk orelse std.math.maxInt(u32), + }; + + const byte_fallback = objectGet(model, .bool, "byte_fallback") orelse false; + if (!byte_fallback and unk == null) { + // GPT2 tokenizer have byte fallback already encoded in the model, + // but the json generally don't have the field set. + // We can detect it though because they don't specify an unknown token. 
+ if (is_gpt2_vocab) { + tokenizer.byte_fallback = true; + } else { + log.warn("The given tokenizer can't handle unknown token: no unknown token was set, and byte_fallback is disabled too ! The tokenizer will panic when facing unknown tokens.", .{}); + } + } else if (byte_fallback) { + try tokenizer.rewriteByteFallbackTokens(); + } + return tokenizer; +} + +/// Returns a copy of the given string, stored inside the given ArrayList. +fn dup(buffer: *std.ArrayList(u8), str: []const u8) ![]const u8 { + const n = buffer.items.len; + try buffer.appendSlice(str); + return buffer.items[n..]; +} + +/// Returns the given entry in a json object only if it has the right type. +fn objectGet( + object: std.json.ObjectMap, + comptime kind: std.meta.FieldEnum(std.json.Value), + key: []const u8, +) ?std.meta.FieldType(std.json.Value, kind) { + const val = object.get(key) orelse return null; + if (val != kind) return null; + return @field(val, @tagName(kind)); +} + +pub fn fromTinyLlamaFile(allocator: std.mem.Allocator, tokenizer_path: []const u8, vocab_size: u32) !Tokenizer { + const tokenizer_file = try std.fs.cwd().openFile(tokenizer_path, .{}); + defer tokenizer_file.close(); + var tok_reader = std.io.bufferedReader(tokenizer_file.reader()); + const r = tok_reader.reader(); + + const max_token_len = try r.readInt(u32, .little); + const special_tokens: Tokenizer.SpecialTokens = .{ + .unk = 0, + .bos = 1, + .eos = 2, + }; + var tokenizer = try Tokenizer.init(allocator, vocab_size, max_token_len, null, special_tokens, true); + var i: u32 = 0; + while (readToken(&tokenizer, &r)) : (i += 1) { + // Pass + } else |_| { + if (i < vocab_size) { + log.info("Read {d} words out of {?d}", .{ i, vocab_size }); + } + tokenizer.vocab_size = i; + } + try tokenizer.rewriteByteFallbackTokens(); + return tokenizer; +} + +fn readToken(tokenizer: *Tokenizer, tok_reader: anytype) !void { + const score: f32 = @bitCast(try tok_reader.readInt(u32, .little)); + const len: usize = @intCast(try 
tok_reader.readInt(u32, .little)); + try tokenizer.readTokenInto(score, len, tok_reader); +} diff --git a/zml/tokenizer/main.zig b/zml/tokenizer/main.zig index b8efd01..d3554f1 100644 --- a/zml/tokenizer/main.zig +++ b/zml/tokenizer/main.zig @@ -1,22 +1,65 @@ const std = @import("std"); -const tokenizer = @import("zml/tokenizer"); +const log = std.log.scoped(.@"//zml/tokenizer"); + +const asynk = @import("async"); +const stdx = @import("stdx"); +const zml_tokenizer = @import("zml/tokenizer"); + +const Flags = struct { + tokenizer: []const u8, + prompt: []const u8, + expected: []const u8 = "", + verbose: bool = false, +}; pub fn main() !void { - const model2 = "/private/var/tmp/_bazel_steeve/a67b810d44f2a673ebbd5bab86ccd5cc/external/zml~~huggingface~Meta-Llama-3.1-8B-Instruct/tokenizer.json"; + try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain); +} - var sp = try tokenizer.Tokenizer.from_file(std.heap.c_allocator, model2); - defer sp.deinit(); +pub fn asyncMain() !void { + var gpa: std.heap.GeneralPurposeAllocator(.{}) = .{}; + const allocator = gpa.allocator(); - std.debug.print("Loaded model\n", .{}); + const args = stdx.flags.parseProcessArgs(Flags); - var encoder = try sp.encoder(); + log.info("\tLoading tokenizer from {s}", .{args.tokenizer}); + var tokenizer = try zml_tokenizer.Tokenizer.fromFile(allocator, args.tokenizer); + log.info("✅\tLoaded tokenizer from {s}", .{args.tokenizer}); + defer tokenizer.deinit(); + + var encoder = try tokenizer.encoder(); defer encoder.deinit(); - var decoder = try sp.decoder(); + var decoder = try tokenizer.decoder(); defer decoder.deinit(); - const ids = try encoder.encode("Hello, world! 
plane pouet plane"); - const decoded = try decoder.decode(ids); + const prompt_tok = try encoder.encode(args.prompt); - std.debug.print("{d}\n{s}\n", .{ ids, decoded }); + log.info("Input: {s}\nOutput: {d}", .{ args.prompt, prompt_tok }); + + var errors: u8 = 0; + { + const reconstructed = try decoder.decode(prompt_tok); + if (!std.mem.eql(u8, args.prompt, reconstructed)) { + log.err("Reconstructed string from tokens doesn't match source: {s}", .{reconstructed}); + errors += 1; + } + } + + if (args.expected.len > 0) { + var expected = try std.ArrayList(u32).initCapacity(allocator, args.prompt.len); + var it = std.mem.splitSequence(u8, args.expected, ","); + while (it.next()) |int_token| { + const tok = try std.fmt.parseInt(u32, int_token, 10); + try expected.append(tok); + } + if (!std.mem.eql(u32, expected.items, prompt_tok)) { + log.err("Doesn't match expected: {d}", .{expected.items}); + errors += 1; + } + } + + if (errors == 0) log.info("All good !", .{}); + + std.process.exit(errors); } diff --git a/zml/tokenizer/sentencepiece/sentencepiece.zig b/zml/tokenizer/sentencepiece/sentencepiece.zig index a5ce8d4..7f4edc6 100644 --- a/zml/tokenizer/sentencepiece/sentencepiece.zig +++ b/zml/tokenizer/sentencepiece/sentencepiece.zig @@ -164,7 +164,7 @@ pub const Decoder = struct { }; pub const SentencePieceProcessor = opaque { - pub fn from_file(model: []const u8) !*SentencePieceProcessor { + pub fn fromFile(model: []const u8) !*SentencePieceProcessor { const sp: *SentencePieceProcessor = @ptrCast(c.SentencePieceProcessor_new()); errdefer sp.deinit(); try assertOk(c.SentencePieceProcessor_Load(@ptrCast(sp), ffi.ZigSlice.from(model))); @@ -183,7 +183,7 @@ pub const SentencePieceProcessor = opaque { return try Decoder.init(self); } - pub fn token_to_id(self: *SentencePieceProcessor, token: []const u8) u32 { + pub fn tokenToId(self: *SentencePieceProcessor, token: []const u8) u32 { return @intCast(c.SentencePieceProcessor_PieceToId(@ptrCast(self), 
ffi.ZigSlice.from(token))); } }; diff --git a/zml/tokenizer/tokenizer.zig b/zml/tokenizer/tokenizer.zig index 10307db..a28b3a0 100644 --- a/zml/tokenizer/tokenizer.zig +++ b/zml/tokenizer/tokenizer.zig @@ -3,15 +3,19 @@ const hftokenizers = @import("hftokenizers"); const sentencepiece = @import("sentencepiece"); const asynk = @import("async"); +const homemade = @import("homemade.zig"); + const Tokenizers = enum { hftokenizers, sentencepiece, + homemade, }; pub const Tokenizer = union(Tokenizers) { pub const Encoder = union(Tokenizers) { hftokenizers: hftokenizers.Encoder, sentencepiece: sentencepiece.Encoder, + homemade: homemade.Encoder, pub fn deinit(self: *Encoder) void { switch (self.*) { @@ -41,6 +45,7 @@ pub const Tokenizer = union(Tokenizers) { pub const Decoder = union(Tokenizers) { hftokenizers: hftokenizers.Decoder, sentencepiece: sentencepiece.Decoder, + homemade: homemade.Decoder, pub fn deinit(self: *Decoder) void { switch (self.*) { @@ -81,14 +86,22 @@ pub const Tokenizer = union(Tokenizers) { hftokenizers: *hftokenizers.HFTokenizer, sentencepiece: *sentencepiece.SentencePieceProcessor, + homemade: *homemade.Tokenizer, - pub fn from_file(_: std.mem.Allocator, model: []const u8) !Tokenizer { + pub fn fromFile(allocator: std.mem.Allocator, model: []const u8) !Tokenizer { if (std.mem.endsWith(u8, model, ".pb")) { - return .{ .sentencepiece = try asynk.callBlocking(sentencepiece.SentencePieceProcessor.from_file, .{model}) }; + return .{ .sentencepiece = try asynk.callBlocking(sentencepiece.SentencePieceProcessor.fromFile, .{model}) }; } if (std.mem.endsWith(u8, model, ".json")) { - return .{ .hftokenizers = try asynk.callBlocking(hftokenizers.HFTokenizer.from_file, .{model}) }; + return .{ .hftokenizers = try asynk.callBlocking(hftokenizers.HFTokenizer.fromFile, .{model}) }; } + + if (std.mem.endsWith(u8, model, ".tinyllama")) { + const tokenizer = try allocator.create(homemade.Tokenizer); + tokenizer.* = try asynk.callBlocking(homemade.fromTinyLlamaFile, 
.{ allocator, model, 32000 }); + return .{ .homemade = tokenizer }; + } + return error.InvalidArgument; } @@ -110,9 +123,9 @@ pub const Tokenizer = union(Tokenizers) { }; } - pub fn token_to_id(self: Tokenizer, token: []const u8) ?u32 { + pub fn tokenToId(self: Tokenizer, token: []const u8) ?u32 { return switch (self) { - inline else => |v| v.token_to_id(token), + inline else => |v| v.tokenToId(token), }; } };