1216 lines
48 KiB
Zig
1216 lines
48 KiB
Zig
|
|
//! Text tokenizer implementations
|
|||
|
|
//! Disclaimer: this is not a very robust implementation:
|
|||
|
|
//! In particular the normalization is pretty minimalist: it only works with ASCII and doesn't do Unicode normalization.
|
|||
|
|
//! Mostly used for testing models that don't have an official HF/sentencepiece tokenizer.
|
|||
|
|
const builtin = @import("builtin");
|
|||
|
|
const std = @import("std");
|
|||
|
|
|
|||
|
|
const testing = std.testing;
|
|||
|
|
|
|||
|
|
const log = std.log.scoped(.@"zml/tokenizer");
|
|||
|
|
|
|||
|
|
test {
    // Force semantic analysis of every public decl (and run their nested tests).
    std.testing.refAllDecls(@This());
    std.testing.refAllDecls(Normalizer);
    std.testing.refAllDecls(Tokenizer);
}
|
|||
|
|
|
|||
|
|
/// Byte Pair Encoding tokenizer generally used for LLM.
|
|||
|
|
pub const Tokenizer = struct {
    /// Token id -> token bytes. Only populated when `alloc_tokens` was true in `init`.
    tokens: [][]const u8,
    /// Token bytes -> token id. When several ids share the same bytes, the first id inserted wins.
    token_lookup: std.StringHashMapUnmanaged(u32),
    special_tokens: SpecialTokens,

    /// Token id -> merge score; the highest-scoring merge is applied first during encoding.
    scores: []f32,
    max_token_len: u32,
    normalizer: ?Normalizer,
    // Allows to split unknown unicode characters into bytes.
    byte_fallback: bool = false,

    /// Owns `tokens`, `scores` and `token_lookup`; everything is freed at once in `deinit`.
    arena_state: std.heap.ArenaAllocator,
    vocab_size: u32,
    next_token_id: u32 = 0,

    pub const SpecialTokens = struct {
        eos: u32,
        bos: u32,
        unk: u32,
        // maxInt(u32) means "no such token".
        pad: u32 = std.math.maxInt(u32),
        hard_space: u32 = std.math.maxInt(u32),
    };

    /// Creates an empty tokenizer with room for `vocab_size` tokens.
    /// When `alloc_tokens` is false, `tokens`/`scores` are left empty
    /// (the caller is expected to provide backing storage another way).
    pub fn init(
        allocator: std.mem.Allocator,
        vocab_size: u32,
        max_token_len: u32,
        normalizer: ?Normalizer,
        special_tokens: SpecialTokens,
        alloc_tokens: bool,
    ) !Tokenizer {
        var arena_state = std.heap.ArenaAllocator.init(allocator);
        errdefer arena_state.deinit();
        const arena = arena_state.allocator();

        var token_lookup: std.StringHashMapUnmanaged(u32) = .{};
        errdefer token_lookup.deinit(arena);

        try token_lookup.ensureTotalCapacity(arena, @intCast(vocab_size));

        const tokens: [][]const u8 = if (alloc_tokens) try arena.alloc([]u8, vocab_size) else &.{};
        errdefer if (alloc_tokens) arena.free(tokens);

        const scores: []f32 = if (alloc_tokens) try arena.alloc(f32, vocab_size) else &.{};
        errdefer if (alloc_tokens) arena.free(scores);

        return .{
            .tokens = tokens,
            .scores = scores,
            .max_token_len = max_token_len,
            .token_lookup = token_lookup,
            .arena_state = arena_state,
            .normalizer = normalizer,
            .vocab_size = vocab_size,
            .special_tokens = special_tokens,
        };
    }

    pub fn deinit(self: Tokenizer) void {
        self.arena_state.deinit();
    }

    pub fn encoder(self: *Tokenizer) !Encoder {
        return Encoder.init(self);
    }

    pub fn decoder(self: *Tokenizer) !Decoder {
        return Decoder.init(self);
    }

    /// Reads a new word directly into the tokenizer arena.
    pub fn readTokenInto(self: *Tokenizer, score: f32, len: usize, tok_reader: anytype) !void {
        const arena = self.arena_state.allocator();

        const token = try arena.alloc(u8, len);
        const n = try tok_reader.read(token);
        std.debug.assert(n == len);

        return self.addOwnedToken(score, token);
    }

    /// Adds a new token (and copy it)
    pub fn addToken(self: *Tokenizer, score: f32, token: []const u8) !void {
        const arena = self.arena_state.allocator();
        return self.addOwnedToken(score, try arena.dupe(u8, token));
    }

    /// Adds a new token (without copying it).
    /// The token bytes must outlive the tokenizer.
    pub fn addOwnedToken(self: *Tokenizer, score: f32, token: []const u8) void {
        const i = self.next_token_id;
        std.debug.assert(i < self.vocab_size);
        self.next_token_id += 1;

        self.scores[i] = score;
        self.tokens[i] = token;
        const v = self.token_lookup.getOrPutAssumeCapacity(token);
        // Keep the first id registered for duplicated token bytes.
        if (!v.found_existing) {
            v.value_ptr.* = i;
        }
    }

    /// Like `addOwnedToken` but writes at an explicit index.
    /// NOTE(review): this still bumps `next_token_id`; presumably callers mix it
    /// with `addOwnedToken` in a way where the total count stays right — confirm.
    pub fn addOwnedTokenByIndex(self: *Tokenizer, i: u32, score: f32, token: []const u8) void {
        std.debug.assert(i < self.vocab_size);
        self.next_token_id += 1;
        self.scores[i] = score;
        self.tokens[i] = token;
        const v = self.token_lookup.getOrPutAssumeCapacity(token);
        if (!v.found_existing) {
            v.value_ptr.* = @intCast(i);
        }
    }

    /// Returns the id of the token with exactly these bytes, if any.
    pub fn lookup(self: *const Tokenizer, str: []const u8) ?u32 {
        return self.token_lookup.get(str);
    }

    /// Alias of `lookup`, kept for API symmetry with HF-style tokenizers.
    pub fn tokenToId(self: *const Tokenizer, token: []const u8) ?u32 {
        return self.token_lookup.get(token);
    }

    pub const EncodeOptions = struct {
        /// Should the beginning of sentence '<s>' token be added.
        add_bos: bool = true,
        add_eos: bool = false,
        /// Right-pad the output with `special_tokens.pad` up to this length.
        pad_to: u32 = 0,
        // Print tokenization intermediary steps.
        debug: bool = false,
    };

    /// Encodes `raw` text into token ids. Caller owns (and frees) the returned slice.
    pub fn encode(self: *const Tokenizer, allocator: std.mem.Allocator, raw: []const u8, options: EncodeOptions) ![]u32 {
        if (options.debug) log.debug("Tokenizer.encode('{s}')", .{raw});
        const input = if (self.normalizer) |n| try n.normalize(allocator, raw) else raw;
        defer if (self.normalizer) |_| allocator.free(input);
        if (options.debug) log.debug("Tokenizer.encode.normalize -> '{s}'", .{input});

        // Allocate a buffer that can fit all indices as well as extra character if requested.
        // We then slice it so that the token merging code doesn't see the bos token.
        const tok_buff_alloc = try allocator.alloc(u32, @max(options.pad_to, input.len + 2));
        // Fixed: previously leaked on the error paths below.
        errdefer allocator.free(tok_buff_alloc);
        const tok_buff = if (options.add_bos) tok_buff_alloc[1..] else tok_buff_alloc;

        const MergeState = union(enum) { ready: u32, nope, hard_space, idk };
        const mergeable = try allocator.alloc(MergeState, tok_buff.len);
        // Scratch only: always released, success or error (was a manual free before).
        defer allocator.free(mergeable);

        // Seed the buffer with one token per codepoint (or per byte / <unk>).
        var num_tokens: usize = 0;
        var it: CharTokenIterator = .{ .input = input };
        while (try it.nextCodepointToken(self)) |token| : (num_tokens += 1) {
            tok_buff[num_tokens] = token;
            mergeable[num_tokens] = if (token == self.special_tokens.hard_space)
                .hard_space
            else
                .idk;
        }

        var stable_prefix: usize = 0;
        var stable_off: usize = 0;
        // This code is a bit overcomplicated cause it abstracts over two algorithms:
        // BPE and sentencepiece unigram model.
        // Normally BPE is pre-split on spaces then the regular merge algorithm is applied.
        // With the unigram model you work at sentence level and handle spaces like any other bytes.
        // To speed up BPE we track a "stable prefix" that can no longer merge and skip it.
        // This doesn't handle BPE vocabs with multi-space tokens or custom splitting rules;
        // it's only meant for tinyllama/gguf models (HF tokenizer bindings cover the rest).
        // Fixed: guard `num_tokens >= 2`, the old `while (true)` underflowed
        // `num_tokens - 1` on empty input.
        while (num_tokens >= 2) {
            // Step by step visualization of the progress.
            if (options.debug) {
                var _debug_buf: [256]u8 = undefined;
                var _debug_alloc = std.heap.FixedBufferAllocator.init(&_debug_buf);
                var debug_progress = std.ArrayList(u8).init(_debug_alloc.allocator());
                self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {};
                log.debug("tokens: {d} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items });
            }
            var best_score: f32 = -1e10;
            var best_token: u32 = 0;
            var best_idx: ?usize = null;
            var input_off: usize = stable_off;

            // Find best tokens to merge in all available tokens
            for (stable_prefix..num_tokens - 1) |i| {
                if (tok_buff[i] == self.special_tokens.unk) {
                    // <unk> stands for exactly one input byte and never merges.
                    input_off += 1;
                    continue;
                }
                const cur_tok = self.tokens[tok_buff[i]];
                defer input_off += cur_tok.len;

                // Lookup merge for current token, if not already done.
                switch (mergeable[i]) {
                    .nope => continue,
                    .ready => {},
                    .hard_space => {
                        // Since tokens are not allowed to merge through hard sep,
                        // we don't need to merge the sentence-wide best token.
                        // We can just merge the best token since beginning.
                        if (best_idx != null) break;
                        // OTOH if there was no merge possible since beginning,
                        // we can skip the beginning in future iterations.
                        stable_prefix = i + 1;
                        stable_off = input_off + cur_tok.len;
                        continue;
                    },
                    .idk => {
                        // Special tokens can't be concatenated.
                        if (builtin.mode == .Debug and tok_buff[i] != self.special_tokens.unk) {
                            // Detects memory corruption of tokens.
                            if (cur_tok.len == 0 or cur_tok.len > self.max_token_len) @panic("Token looks corrupted !");

                            if (!std.mem.eql(u8, cur_tok, input[input_off..][0..cur_tok.len])) {
                                log.err("current token '{s}' not found in input string '{s}' !", .{ cur_tok, input[input_off..] });
                                @panic("invalid tokenization");
                            }
                        }
                        const next_tok = self.tokens[tok_buff[i + 1]];
                        // if `next_tok` is `.unk`, length is 1; otherwise, it's the length of the token.
                        const next_tok_len = if (tok_buff[i + 1] == self.special_tokens.unk) 1 else next_tok.len;
                        const concat_tokens = input[input_off..][0 .. cur_tok.len + next_tok_len];
                        // Save the result
                        mergeable[i] = if (self.lookup(concat_tokens)) |tok|
                            .{ .ready = tok }
                        else
                            .nope;
                    },
                }

                switch (mergeable[i]) {
                    .idk, .hard_space => unreachable,
                    .nope => continue,
                    .ready => |tok| {
                        if (self.scores[tok] > best_score) {
                            best_score = self.scores[tok];
                            best_token = tok;
                            best_idx = i;
                        }
                    },
                }
            }

            if (best_idx) |bidx| {
                // Apply the merge.
                tok_buff[bidx] = best_token;
                std.mem.copyForwards(u32, tok_buff[bidx + 1 ..], tok_buff[bidx + 2 .. num_tokens]);
                std.mem.copyForwards(MergeState, mergeable[bidx + 1 ..], mergeable[bidx + 2 .. num_tokens]);
                num_tokens -= 1;
                // We got two new merge lookups to do.
                mergeable[bidx] = .idk;
                if (bidx > 0 and mergeable[bidx - 1] != .hard_space) mergeable[bidx - 1] = .idk;
            } else {
                // No merge candidate => we are done !
                break;
            }
        }

        if (options.add_eos) {
            tok_buff[num_tokens] = self.special_tokens.eos;
            num_tokens += 1;
        }
        if (options.add_bos) {
            tok_buff_alloc[0] = self.special_tokens.bos;
            num_tokens += 1;
        }
        if (num_tokens < options.pad_to) {
            for (num_tokens..options.pad_to) |i| {
                tok_buff_alloc[i] = self.special_tokens.pad;
            }
            num_tokens = options.pad_to;
        }

        // Release extra memory we don't need anymore.
        if (allocator.resize(tok_buff_alloc, num_tokens)) {
            return tok_buff_alloc[0..num_tokens];
        }
        // Fixed: if `resize` fails, we must not return a short slice of the original
        // allocation — `allocator.free` on it would mismatch the allocated length.
        const out = try allocator.dupe(u32, tok_buff_alloc[0..num_tokens]);
        allocator.free(tok_buff_alloc);
        return out;
    }

    /// Returns a slice corresponding to the given id. Handles unknown ids and special ids.
    pub fn lookupPiece(self: *const Tokenizer, id: usize) []const u8 {
        return if (id == self.special_tokens.bos or id == self.special_tokens.eos or id == self.special_tokens.pad)
            ""
        else if (id == self.special_tokens.unk)
            "<unk>"
            // Fixed off-by-one: `id == tokens.len` used to fall through and index out of bounds.
        else if (id >= self.tokens.len)
            "<oob>" // this means we received an invalid id, but we didn't want to panic.
        else
            self.tokens[id];
    }

    /// Converts the given slice of tokens back into bytes.
    /// Note that if the tokenizer allows sub-unicode bytes, it's possible
    /// the output is not valid utf8.
    pub fn decode(self: *const Tokenizer, allocator: std.mem.Allocator, input: []const u32) error{OutOfMemory}![]u8 {
        var output = std.ArrayList(u8).init(allocator);
        errdefer output.deinit();

        try self.decodeWithOpts(&output, input, .{});
        return output.toOwnedSlice();
    }

    /// Decodes into an existing list; `opts.sep` is inserted after every piece
    /// (used for debug visualization).
    pub fn decodeWithOpts(
        self: *const Tokenizer,
        output: *std.ArrayList(u8),
        input: []const u32,
        opts: struct { sep: []const u8 = "" },
    ) error{OutOfMemory}!void {
        const escaped = if (self.normalizer) |n| n.escapedSpace() else null;
        for (input) |id| {
            // Retrieve the slice corresponding to the id.
            var piece = self.lookupPiece(id);

            // Convert the escaped-space marker (e.g. `▁`) to a regular space.
            if (escaped) |escspc| {
                // we modify piece inside the loop, so we can use it in the condition
                while (std.mem.startsWith(u8, piece, escspc)) {
                    piece = piece[escspc.len..];
                    // don't output a space at beginning of text.
                    if (output.items.len > 0) try output.append(' ');
                }
            }

            try output.appendSlice(piece);
            if (opts.sep.len > 0) try output.appendSlice(opts.sep);
        }
    }

    /// Some tokenizers have bytes encoded in hex like this: "<0x40>".
    /// This break the tokenization algorithm because the input text
    /// will contain "@" not "<0x40>",
    /// and if the input contains "<0x40>" it needs to not be treated as a single byte.
    /// So we replace byte fallbacks strings, by their corresponding character.
    /// This enables the normal tokenization algorithm to work.
    pub fn rewriteByteFallbackTokens(tokenizer: *Tokenizer) !void {
        tokenizer.byte_fallback = true;
        var single_bytes = try tokenizer.arena_state.allocator().alloc(u8, 256);
        var byte_fallback_buf = "<0x00>".*;

        for (0..256) |i| {
            const c: u8 = @truncate(i);
            single_bytes[i] = c;

            // First lookup the byte fallback entry.
            // Note: we assume upper case, but we could try both upper and lower case if needed.
            _ = std.fmt.bufPrintIntToSlice(byte_fallback_buf[3..5], c, 16, .upper, .{ .fill = '0', .width = 2 });
            const entry = tokenizer.token_lookup.getEntry(&byte_fallback_buf) orelse {
                log.err("Tokenizer has \"byte_fallback\" = true, but doesn't contains the byte fallback token {s}", .{byte_fallback_buf});
                return error.InvalidInput;
            };

            // Check if the character is already present in the vocab.
            // In that case, nothing to do,
            // but note that the fallback token will be "unreachable",
            // ie there is no way the tokenizer can produce it.
            if (tokenizer.token_lookup.get(&.{c})) |_| continue;

            const idx: u32 = entry.value_ptr.*;
            tokenizer.token_lookup.removeByPtr(entry.key_ptr);
            tokenizer.addOwnedTokenByIndex(idx, tokenizer.scores[idx], single_bytes[i .. i + 1]);
        }
    }
};
|
|||
|
|
|
|||
|
|
test Tokenizer {
    const allocator = std.testing.allocator;
    const specials: Tokenizer.SpecialTokens = .{
        .unk = 0,
        .bos = 1,
        .eos = 2,
    };

    var tok = try Tokenizer.init(allocator, 10, 5, null, specials, true);
    defer tok.deinit();

    try tok.addToken(10, "hello");
    try tok.addToken(3.5, "world");

    // Ids are assigned in insertion order.
    try testing.expectEqual(@as(?u32, 0), tok.lookup("hello"));
    try testing.expectEqual(@as(?u32, 1), tok.lookup("world"));

    // TODO: test Tokenizer.decode, Tokenizer.encode, Tokenizer.readTokenInto
}
|
|||
|
|
|
|||
|
|
/// Stateful wrapper around `Tokenizer.encode` that reuses one arena across calls.
pub const Encoder = struct {
    inner: *Tokenizer,
    arena: std.heap.ArenaAllocator,
    /// Ids produced by the most recent `encode` call.
    current_ids: []const u32 = &.{},

    fn init(inner: *Tokenizer) !Encoder {
        var scratch = std.heap.ArenaAllocator.init(std.heap.page_allocator);
        // Warm the arena up front: the page allocator is expensive, so grab a
        // chunk once and recycle it for the small reallocs that follow.
        _ = try scratch.allocator().alloc(u32, 4096);
        std.debug.assert(scratch.reset(.retain_capacity));
        return .{ .inner = inner, .arena = scratch };
    }

    pub fn deinit(self: *Encoder) void {
        self.arena.deinit();
    }

    /// Drops the previous result and recycles the arena memory.
    pub fn reset(self: *Encoder) void {
        self.current_ids = &.{};
        std.debug.assert(self.arena.reset(.retain_capacity));
    }

    /// Encodes `input`; the returned ids live until the next `encode`/`reset`.
    pub fn encode(self: *Encoder, input: []const u8) ![]const u32 {
        self.reset();
        const token_ids = try self.inner.encode(self.arena.allocator(), input, .{
            .add_bos = true,
            .add_eos = false,
            .pad_to = 0,
            // Print tokenization intermediary steps.
            .debug = false,
        });
        self.current_ids = token_ids;
        return token_ids;
    }

    pub fn ids(self: *const Encoder) []const u32 {
        return self.current_ids;
    }
};
|
|||
|
|
|
|||
|
|
/// Stateful wrapper around `Tokenizer.decode` supporting incremental (streaming) decoding.
pub const Decoder = struct {
    const StringBuffer = std.BoundedArray(u8, 128);
    const TokensIdsBuffer = std.BoundedArray(u32, 4);

    inner: *Tokenizer,
    arena: std.heap.ArenaAllocator,

    /// Result of the most recent `decode`, null before the first call / after `reset`.
    current_string: ?[]const u8 = null,
    /// Text decoded so far from `last_token_ids` (streaming state for `next`).
    /// NOTE(review): not cleared by `reset` — confirm whether that is intended.
    last_string: StringBuffer = .{ .len = 0 },
    /// Sliding window of the most recent token ids (streaming state for `next`).
    last_token_ids: TokensIdsBuffer = .{ .len = 0 },

    fn init(inner: *Tokenizer) !Decoder {
        var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
        // Warmup the arena. Page allocator is expensive, avoid calling it for small reallocs.
        _ = try arena.allocator().alloc(u32, 4096);
        std.debug.assert(arena.reset(.retain_capacity));
        return .{ .inner = inner, .arena = arena };
    }

    pub fn deinit(self: *Decoder) void {
        self.arena.deinit();
    }

    pub fn reset(self: *Decoder) void {
        std.debug.assert(self.arena.reset(.retain_capacity));
        self.current_string = null;
    }

    /// Decodes `ids`; the returned bytes live until the next `decode`/`reset`.
    pub fn decode(self: *Decoder, ids: []const u32) ![]const u8 {
        self.reset();
        const res = try self.inner.decode(self.arena.allocator(), ids);
        self.current_string = res;
        return res;
    }

    /// Returns the last decoded string, or an empty slice if nothing was decoded yet.
    pub fn string(self: *const Decoder) []const u8 {
        // Fixed: `current_string` is `?[]const u8`; returning it directly did not
        // compile against the declared `[]const u8` return type.
        return self.current_string orelse &.{};
    }

    /// Feeds one token id and returns the newly produced text chunk, if any.
    /// Keeps a small window of recent ids so multi-token characters decode correctly.
    pub fn next(self: *Decoder, token_id: u32) !?[]const u8 {
        // Evict the oldest id once the window is full.
        if (self.last_token_ids.len >= self.last_token_ids.capacity()) {
            _ = self.last_token_ids.orderedRemove(0);
        }
        self.last_token_ids.appendAssumeCapacity(token_id);
        const new_string = try self.decode(self.last_token_ids.constSlice());
        if (self.last_string.len == 0) {
            self.last_string = try StringBuffer.fromSlice(new_string);
            return new_string;
        }
        // Find the suffix of `last_string` that `new_string` still starts with;
        // everything past that overlap is the fresh chunk to emit.
        var view = try std.unicode.Utf8View.init(self.last_string.constSlice());
        var it = view.iterator();
        while (it.nextCodepointSlice()) |cp| {
            const start = it.i - cp.len;
            if (std.mem.startsWith(u8, new_string, self.last_string.constSlice()[start..])) {
                const chunk = new_string[self.last_string.len - start ..];
                self.last_string = try StringBuffer.fromSlice(new_string);
                return chunk;
            }
        }
        return null;
    }
};
|
|||
|
|
|
|||
|
|
/// Given a slice, split it in the most simple tokens using the given tokenizer tokens.
|
|||
|
|
/// The output of this can be used to initialize the tokenization algorithm.
|
|||
|
|
/// Normally we split the input text into utf8 codepoint,
|
|||
|
|
/// but if we find an unknown codepoint we either split it in bytes, or use the special "unknown" token,
|
|||
|
|
/// depending on the tokenizer configuration.
|
|||
|
|
const CharTokenIterator = struct {
    // `by_codepoint`: emit one token per utf8 codepoint.
    // `by_byte: n`: currently splitting an unknown codepoint, n bytes left to emit.
    state: union(enum) { by_codepoint, by_byte: u8 } = .by_codepoint,
    // Remaining, not-yet-tokenized input; shrinks as tokens are emitted.
    input: []const u8,

    /// Returns the next initial token id, or null when the input is exhausted.
    /// Unknown codepoints are split into single-byte tokens when
    /// `tokenizer.byte_fallback` is set, otherwise mapped to `special_tokens.unk`.
    fn nextCodepointToken(self: *CharTokenIterator, tokenizer: *const Tokenizer) error{ TruncatedInput, Utf8InvalidStartByte }!?u32 {
        if (self.input.len == 0) return null;
        return switch (self.state) {
            .by_byte => |*byte_left| {
                // Single-byte tokens must exist in the vocab once byte fallback is enabled.
                const idx = tokenizer.lookup(self.input[0..1]) orelse {
                    // Normally this has been caught when calling `rewriteByteFallbackTokens`.
                    std.debug.panic("Tokenizer has \"byte_fallback\" = true, but doesn't contains the byte fallback for token '<0x{X:02}>'", .{self.input[0]});
                };

                self.input = self.input[1..];
                // Saturating decrement; return to codepoint mode once the split codepoint is done.
                byte_left.* -|= 1;
                if (byte_left.* == 0) self.state = .by_codepoint;
                return idx;
            },
            .by_codepoint => {
                // Try to lookup valid utf8 codepoint first.
                const utf8_len = try std.unicode.utf8ByteSequenceLength(self.input[0]);
                if (self.input.len < utf8_len) return error.TruncatedInput;
                if (tokenizer.lookup(self.input[0..utf8_len])) |idx| {
                    self.input = self.input[utf8_len..];
                    return idx;
                }

                // Otherwise split in bytes if it's allowed.
                if (tokenizer.byte_fallback) {
                    // TODO: replace this by a continue statement next time we bump Zig.
                    // Switch to byte mode and re-enter: the recursive call consumes the
                    // first byte of this codepoint.
                    self.state = .{ .by_byte = utf8_len };
                    return self.nextCodepointToken(tokenizer);
                }

                // Or mark the full utf8 codepoint as unknown.
                log.debug("Token not found for char '{s}'", .{self.input[0..utf8_len]});
                self.input = self.input[utf8_len..];
                return tokenizer.special_tokens.unk;
            },
        };
    }
};
|
|||
|
|
|
|||
|
|
test CharTokenIterator {
    const specials: Tokenizer.SpecialTokens = .{ .unk = 0, .bos = 1, .eos = 2 };
    var tok = try Tokenizer.init(std.testing.allocator, 16, 4, null, specials, true);
    defer tok.deinit();

    // Vocabulary: "ζ" is known as a whole codepoint, "ℳ" only byte by byte.
    tok.addOwnedToken(1.0, "<unk>"); // 0
    tok.addOwnedToken(1.0, "<s>"); // 1
    tok.addOwnedToken(1.0, "</s>"); // 2
    tok.addOwnedToken(1.0, "ζ"); // 3
    tok.addOwnedToken(1.0, &.{0xE2}); // 4: ℳ, first byte
    tok.addOwnedToken(1.0, &.{0x84}); // 5: ℳ, second byte
    tok.addOwnedToken(1.0, &.{0xB3}); // 6: ℳ, third byte
    tok.addOwnedToken(1.0, "L"); // 7

    // Without byte fallback, the unknown "ℳ" becomes the <unk> id.
    {
        tok.byte_fallback = false;
        var iter: CharTokenIterator = .{ .input = "ζℳL" };
        var got: std.BoundedArray(u32, 8) = .{};
        while (try iter.nextCodepointToken(&tok)) |id| {
            got.appendAssumeCapacity(id);
        }
        try std.testing.expectEqualSlices(u32, &[_]u32{ 3, 0, 7 }, got.constSlice());
    }

    // With byte fallback, the unknown "ℳ" is emitted one byte-token at a time.
    {
        tok.byte_fallback = true;
        var iter: CharTokenIterator = .{ .input = "ζℳL" };
        var got: std.BoundedArray(u32, 8) = .{};
        while (try iter.nextCodepointToken(&tok)) |id| {
            got.appendAssumeCapacity(id);
        }
        try std.testing.expectEqualSlices(u32, &[_]u32{ 3, 4, 5, 6, 7 }, got.constSlice());
    }
}
|
|||
|
|
|
|||
|
|
/// Text normalizer.
|
|||
|
|
/// Most tokenizers assume the input text has been preprocessed with one of those.
|
|||
|
|
pub const Normalizer = struct {
|
|||
|
|
/// Space token used by sentencepiece derived tokenizer.
|
|||
|
|
pub const sentencepiece_space = "▁"; // \xe2\x96\x81
|
|||
|
|
|
|||
|
|
_whitespace: std.BoundedArray(u8, 8) = .{},
|
|||
|
|
|
|||
|
|
flags: packed struct {
|
|||
|
|
remove_extra_whitespaces: bool,
|
|||
|
|
add_dummy_prefix: bool,
|
|||
|
|
add_dummy_suffix: bool,
|
|||
|
|
/// Cheap lower casing.
|
|||
|
|
/// TODO: try to match Python "lower"
|
|||
|
|
lower_case_ascii: bool,
|
|||
|
|
/// cheap ascii punct splitting.
|
|||
|
|
// doing this processing ahead of time simplifies the logic
|
|||
|
|
split_on_punct_ascii: bool,
|
|||
|
|
},
|
|||
|
|
|
|||
|
|
pub fn init(flags: std.meta.FieldType(Normalizer, .flags), escaped_whitespace: ?[]const u8) Normalizer {
|
|||
|
|
var res: Normalizer = .{ .flags = flags };
|
|||
|
|
if (escaped_whitespace) |escaped| {
|
|||
|
|
res._whitespace.appendSliceAssumeCapacity(escaped);
|
|||
|
|
}
|
|||
|
|
return res;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub inline fn escapedSpace(self: Normalizer) ?[]const u8 {
|
|||
|
|
return if (self._whitespace.len > 1) self._whitespace.constSlice() else null;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
fn addSlice(data: []const u8, consumed: usize, normalized: *std.ArrayList(u8), normalized_to_origin: *std.ArrayList(usize)) !void {
|
|||
|
|
try normalized.appendSlice(data);
|
|||
|
|
for (data) |_| try normalized_to_origin.append(consumed);
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub const Result = struct {
|
|||
|
|
/// Normalized string
|
|||
|
|
normalized: []const u8,
|
|||
|
|
/// Mapping between chars in the original string and chars in the new string
|
|||
|
|
normalized_to_origin: []const usize,
|
|||
|
|
|
|||
|
|
pub fn deinit(self: Result, allocator: std.mem.Allocator) void {
|
|||
|
|
allocator.free(self.normalized);
|
|||
|
|
allocator.free(self.normalized_to_origin);
|
|||
|
|
}
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
/// Simplifed version of Sentencepiece normalizer.
|
|||
|
|
///
|
|||
|
|
/// Llama2 uses a normalizer called "identity" so this basically only handles trailing
|
|||
|
|
/// whitespaces and replaces whitespace with the "▁" (U+2581) character.
|
|||
|
|
pub fn normalize(self: Normalizer, allocator: std.mem.Allocator, input: []const u8) ![]const u8 {
|
|||
|
|
const res = try self.normalizeWithMapping(allocator, input);
|
|||
|
|
allocator.free(res.normalized_to_origin);
|
|||
|
|
return res.normalized;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/// Returns both the normalized string and a mapping between the normalized string and the original.
|
|||
|
|
pub fn normalizeWithMapping(self: Normalizer, allocator: std.mem.Allocator, input: []const u8) !Result {
|
|||
|
|
// Number of bytes consumed from the input.
|
|||
|
|
var consumed: usize = 0;
|
|||
|
|
var trimmed_input = input;
|
|||
|
|
|
|||
|
|
// Skip leading whitespaces.
|
|||
|
|
if (self.flags.remove_extra_whitespaces) {
|
|||
|
|
while (trimmed_input.len != 0) {
|
|||
|
|
if (trimmed_input[0] != ' ') break;
|
|||
|
|
trimmed_input = trimmed_input[1..];
|
|||
|
|
consumed += 1;
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// If the trimmed input is empty, we are done.
|
|||
|
|
if (trimmed_input.len == 0) {
|
|||
|
|
return .{ .normalized = &.{}, .normalized_to_origin = &.{} };
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Pre-allocate outputs
|
|||
|
|
const space = self.escapedSpace() orelse " ";
|
|||
|
|
const overhead = if (self.flags.split_on_punct_ascii) space.len + 1 else space.len;
|
|||
|
|
var normalized = try std.ArrayList(u8).initCapacity(allocator, trimmed_input.len * overhead + 2 * space.len);
|
|||
|
|
errdefer normalized.deinit();
|
|||
|
|
var normalized_to_origin = try std.ArrayList(usize).initCapacity(allocator, normalized.capacity);
|
|||
|
|
errdefer normalized_to_origin.deinit();
|
|||
|
|
|
|||
|
|
// If the spec asks for it, add a whitespace at the beginning.
|
|||
|
|
if (self.flags.add_dummy_prefix) try addSlice(space, consumed, &normalized, &normalized_to_origin);
|
|||
|
|
|
|||
|
|
var is_prev_space: bool = true;
|
|||
|
|
var is_prev_word: bool = false;
|
|||
|
|
|
|||
|
|
while (trimmed_input.len != 0) {
|
|||
|
|
// NOTE(Corendos): This might feel weird but normally the slice we get comes from a normalizing process and can contain multiple codepoints.
|
|||
|
|
// Since we have an "identity" normalizer, each slice is actually a unicode character.
|
|||
|
|
const multibyte_length = try std.unicode.utf8ByteSequenceLength(trimmed_input[0]);
|
|||
|
|
var slice = trimmed_input[0..multibyte_length];
|
|||
|
|
const origin = consumed;
|
|||
|
|
consumed += multibyte_length;
|
|||
|
|
trimmed_input = trimmed_input[multibyte_length..];
|
|||
|
|
|
|||
|
|
if (self.flags.remove_extra_whitespaces and is_prev_space) {
|
|||
|
|
while (slice.len > 0 and slice[0] == ' ') {
|
|||
|
|
slice = slice[1..];
|
|||
|
|
}
|
|||
|
|
if (slice.len == 0) continue;
|
|||
|
|
}
|
|||
|
|
is_prev_space = slice[slice.len - 1] == ' ';
|
|||
|
|
|
|||
|
|
if (slice.len == 1) ascii: {
|
|||
|
|
// The more advanced logic only works with ascii atm
|
|||
|
|
var byte = slice[0];
|
|||
|
|
if (self.escapedSpace() != null and byte == ' ') {
|
|||
|
|
// replace the space token by the special token
|
|||
|
|
try addSlice(space, origin, &normalized, &normalized_to_origin);
|
|||
|
|
is_prev_word = false;
|
|||
|
|
break :ascii;
|
|||
|
|
} else if (self.flags.split_on_punct_ascii) {
|
|||
|
|
if (is_prev_word and isPunct(slice)) {
|
|||
|
|
// Insert a space, but continue handling the rest
|
|||
|
|
try addSlice(space, origin, &normalized, &normalized_to_origin);
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if (self.flags.lower_case_ascii) {
|
|||
|
|
byte = std.ascii.toLower(byte);
|
|||
|
|
}
|
|||
|
|
try normalized.append(byte);
|
|||
|
|
try normalized_to_origin.append(origin);
|
|||
|
|
} else {
|
|||
|
|
// we can safely copy to the output.
|
|||
|
|
try addSlice(slice, origin, &normalized, &normalized_to_origin);
|
|||
|
|
}
|
|||
|
|
is_prev_word = !is_prev_space and !isPunct(slice);
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Skip trailing whitespaces
|
|||
|
|
if (self.flags.remove_extra_whitespaces) {
|
|||
|
|
while (std.mem.endsWith(u8, normalized.items, space)) {
|
|||
|
|
const length = normalized.items.len - space.len;
|
|||
|
|
consumed = normalized_to_origin.items[length];
|
|||
|
|
try normalized.resize(length);
|
|||
|
|
try normalized_to_origin.resize(length);
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
try normalized_to_origin.append(consumed);
|
|||
|
|
|
|||
|
|
std.debug.assert(normalized_to_origin.items.len == normalized.items.len + 1);
|
|||
|
|
|
|||
|
|
if (self.flags.add_dummy_suffix) try addSlice(space, consumed, &normalized, &normalized_to_origin);
|
|||
|
|
|
|||
|
|
return .{
|
|||
|
|
.normalized = try normalized.toOwnedSlice(),
|
|||
|
|
.normalized_to_origin = try normalized_to_origin.toOwnedSlice(),
|
|||
|
|
};
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub fn wellKnown(impl: KnownImplementation) Normalizer {
|
|||
|
|
return switch (impl) {
|
|||
|
|
.sentencepiece => init(.{
|
|||
|
|
.remove_extra_whitespaces = true,
|
|||
|
|
.add_dummy_prefix = true,
|
|||
|
|
.add_dummy_suffix = false,
|
|||
|
|
.lower_case_ascii = false,
|
|||
|
|
.split_on_punct_ascii = false,
|
|||
|
|
}, sentencepiece_space),
|
|||
|
|
.llama3 => init(.{
|
|||
|
|
.remove_extra_whitespaces = true,
|
|||
|
|
.add_dummy_prefix = false,
|
|||
|
|
.add_dummy_suffix = false,
|
|||
|
|
.lower_case_ascii = false,
|
|||
|
|
.split_on_punct_ascii = false,
|
|||
|
|
}, null),
|
|||
|
|
.gpt2 => init(.{
|
|||
|
|
.remove_extra_whitespaces = true,
|
|||
|
|
.add_dummy_prefix = true,
|
|||
|
|
.add_dummy_suffix = false,
|
|||
|
|
.lower_case_ascii = false,
|
|||
|
|
.split_on_punct_ascii = false,
|
|||
|
|
}, null),
|
|||
|
|
};
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub fn fromHfJson(config: std.json.ObjectMap) error{InvalidNormalizerJson}!Normalizer {
|
|||
|
|
var normalizer: Normalizer = .{ .flags = .{
|
|||
|
|
.remove_extra_whitespaces = false,
|
|||
|
|
.add_dummy_suffix = false,
|
|||
|
|
.add_dummy_prefix = false,
|
|||
|
|
.lower_case_ascii = false,
|
|||
|
|
.split_on_punct_ascii = false,
|
|||
|
|
} };
|
|||
|
|
|
|||
|
|
// Normalizer config can be a single normalizer, or a sequence of normalizers.
|
|||
|
|
const maybe_steps = objectGet(config, .array, "normalizers");
|
|||
|
|
const steps = if (maybe_steps) |st| st.items else &.{std.json.Value{ .object = config }};
|
|||
|
|
|
|||
|
|
for (steps) |step_val| {
|
|||
|
|
if (step_val != .object) {
|
|||
|
|
return error.InvalidNormalizerJson;
|
|||
|
|
}
|
|||
|
|
const step = step_val.object;
|
|||
|
|
|
|||
|
|
const step_type = objectGet(step, .string, "type") orelse {
|
|||
|
|
return error.InvalidNormalizerJson;
|
|||
|
|
};
|
|||
|
|
if (std.mem.eql(u8, "Prepend", step_type)) {
|
|||
|
|
normalizer.flags.add_dummy_prefix = true;
|
|||
|
|
} else if (std.mem.eql(u8, "Append", step_type)) {
|
|||
|
|
normalizer.flags.add_dummy_suffix = true;
|
|||
|
|
} else if (std.mem.eql(u8, "Replace", step_type)) {
|
|||
|
|
const pattern = objectGet(step, .object, "pattern") orelse return error.InvalidNormalizerJson;
|
|||
|
|
const str_pattern = objectGet(pattern, .string, "String") orelse return error.InvalidNormalizerJson;
|
|||
|
|
|
|||
|
|
if (std.mem.eql(u8, str_pattern, " ")) {
|
|||
|
|
normalizer._whitespace.appendSliceAssumeCapacity(
|
|||
|
|
objectGet(step, .string, "content") orelse return error.InvalidNormalizerJson,
|
|||
|
|
);
|
|||
|
|
} else {
|
|||
|
|
log.warn("Normalizer Replace pattern not supported: '{s}' -> '{s}'", .{ str_pattern, objectGet(pattern, .string, "content") orelse "" });
|
|||
|
|
}
|
|||
|
|
} else {
|
|||
|
|
log.warn("Unknown normalizer type: {s}", .{step_type});
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return normalizer;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
    test "Normalizer.fromHfJson" {
        // A minimal HF "Sequence" normalizer: a Prepend step plus a
        // Replace step mapping " " to the sentencepiece escaped space "▁".
        const config_json =
            \\{
            \\  "type": "Sequence",
            \\  "normalizers": [
            \\    {
            \\      "type": "Prepend",
            \\      "prepend": "▁"
            \\    },
            \\    {
            \\      "type": "Replace",
            \\      "pattern": {
            \\        "String": " "
            \\      },
            \\      "content": "▁"
            \\    }
            \\  ]
            \\}
        ;
        var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
        defer arena.deinit();
        const config = try std.json.parseFromSliceLeaky(std.json.Value, arena.allocator(), config_json, .{});
        const normalizer = try Normalizer.fromHfJson(config.object);

        // "▁" is utf-8 0xe2 0x96 0x81, stored in the fixed-size whitespace buffer.
        const expected = Normalizer{
            ._whitespace = .{ .buffer = [_]u8{ 0xe2, 0x96, 0x81 } ++ [_]u8{0} ** 5, .len = 3 },
            .flags = .{
                .remove_extra_whitespaces = false,
                .add_dummy_prefix = true,
                .add_dummy_suffix = false,
                .lower_case_ascii = false,
                .split_on_punct_ascii = false,
            },
        };
        try std.testing.expectEqual(expected.flags, normalizer.flags);
        try std.testing.expectEqualStrings(expected.escapedSpace().?, normalizer.escapedSpace().?);
    }
|
|||
|
|
};
|
|||
|
|
/// Tokenizer families with a known normalizer preset (see `Normalizer.wellKnown`).
/// The integer values are positional; do not reorder.
pub const KnownImplementation = enum(u8) {
    /// Dummy "▁" prefix and escaped spaces.
    sentencepiece,
    /// Dummy prefix, no escaped-space token.
    gpt2,
    /// No dummy prefix, no escaped-space token.
    llama3,
};
|
|||
|
|
|
|||
|
|
/// Returns true if the given unicode character (passed as its utf-8 bytes)
/// is ascii punctuation, or an ascii control character other than tab.
/// Spaces, tabs, multi-byte characters and empty slices return false.
fn isPunct(unicode_char: []const u8) bool {
    // TODO use unicode categories
    // Guard against empty slices too: the original `> 1` check let an
    // empty slice fall through to `unicode_char[0]` and panic.
    if (unicode_char.len != 1) return false;

    return switch (unicode_char[0]) {
        ' ', '\t' => false,
        // Control characters (tab excluded above) count as punctuation
        // so they act as word separators.
        0...8 => true,
        10...31 => true,
        // The four ascii punctuation ranges around the alphanumeric blocks.
        '!'...'/' => true,
        ':'...'@' => true,
        '['...'`' => true,
        '{'...'~' => true,
        else => false,
    };
}
|
|||
|
|
|
|||
|
|
test Normalizer {
    // Whitespace collapsing + dummy prefix, with origin mapping.
    // Inputs deliberately contain a double space so that
    // `remove_extra_whitespaces` has something to do.
    {
        const n: Normalizer = .{ .flags = .{
            .remove_extra_whitespaces = true,
            .add_dummy_prefix = true,
            .add_dummy_suffix = false,
            .lower_case_ascii = false,
            .split_on_punct_ascii = false,
        } };
        const res = try n.normalizeWithMapping(testing.allocator, "Hellŏ  world!");
        defer res.deinit(testing.allocator);

        try testing.expectEqualSlices(u8, " Hellŏ world!", res.normalized);
        try testing.expectEqualSlices(
            usize,
            // H e l l ŏ ␣ w o r l d !
            // "ŏ" is two bytes (both mapping to origin 4), and the second
            // input space (origin 7) is dropped, hence the 6 -> 8 jump.
            &.{ 0, 0, 1, 2, 3, 4, 4, 6, 8, 9, 10, 11, 12, 13, 14 },
            res.normalized_to_origin,
        );
    }

    // Dummy prefix and suffix together.
    {
        const n: Normalizer = .{ .flags = .{
            .remove_extra_whitespaces = true,
            .add_dummy_prefix = true,
            .add_dummy_suffix = true,
            .lower_case_ascii = false,
            .split_on_punct_ascii = false,
        } };
        const res = try n.normalize(testing.allocator, "Hello  world!");
        defer testing.allocator.free(res);

        try testing.expectEqualSlices(u8, " Hello world! ", res);
    }

    // Escaped spaces kept verbatim: no collapsing, so the double space
    // becomes a double "▁".
    {
        const n = Normalizer.init(
            .{
                .remove_extra_whitespaces = false,
                .add_dummy_prefix = true,
                .add_dummy_suffix = false,
                .lower_case_ascii = false,
                .split_on_punct_ascii = false,
            },
            Normalizer.sentencepiece_space,
        );
        const res = try n.normalize(testing.allocator, "Hello  world!");
        defer testing.allocator.free(res);

        try testing.expectEqualSlices(u8, "▁Hello▁▁world!", res);
    }

    // Ascii lower-casing.
    {
        const n: Normalizer = .{ .flags = .{
            .remove_extra_whitespaces = true,
            .add_dummy_prefix = false,
            .add_dummy_suffix = true,
            .lower_case_ascii = true,
            .split_on_punct_ascii = false,
        } };
        const res = try n.normalize(testing.allocator, "Hello  world!");
        defer testing.allocator.free(res);

        try testing.expectEqualSlices(u8, "hello world! ", res);
    }

    // Punctuation splitting: a space is inserted before the "!".
    {
        const n: Normalizer = .{ .flags = .{
            .remove_extra_whitespaces = true,
            .add_dummy_prefix = false,
            .add_dummy_suffix = true,
            .lower_case_ascii = false,
            .split_on_punct_ascii = true,
        } };
        const res = try n.normalize(testing.allocator, "Hello  world!");
        defer testing.allocator.free(res);

        try testing.expectEqualSlices(u8, "Hello world ! ", res);
    }
}
|
|||
|
|
|
|||
|
|
/// gpt2 had its own way of storing text: every raw byte of the vocab is
/// remapped to a printable unicode codepoint. Unfortunately this has
/// contaminated other models.
/// This implementation precomputes the mapping from GPT2-encoded codepoints
/// back to raw bytes, and does lookups at runtime.
pub const Gpt2TextDecoder = struct {
    /// One encoded codepoint as utf-8 bytes. All codes used by the scheme
    /// are below 0x200, so two bytes always suffice.
    const Code = std.BoundedArray(u8, 2);

    // TODO: benchmark this is more efficient than doing the conversion at runtime.
    code_to_byte: std.AutoArrayHashMap(Code, u8),

    /// Builds the 256-entry code -> byte table.
    pub fn init(allocator: std.mem.Allocator) !Gpt2TextDecoder {
        var self = Gpt2TextDecoder{
            .code_to_byte = std.AutoArrayHashMap(Code, u8).init(allocator),
        };
        try self.code_to_byte.ensureTotalCapacity(256);
        // Capacity is reserved above; nothing below can fail.
        errdefer unreachable;

        // "Printable" bytes are encoded as their own codepoint; the others
        // are assigned codepoints 256, 257, ... in order of appearance.
        // `n` counts how many non-printable bytes we have seen so far.
        var n: usize = 0;
        for (0..256) |index| {
            var code: Code = .{ .buffer = .{ 0, 0 }, .len = 0 }; // 0-init
            const i: u8 = @intCast(index);
            if (isPrintableByte(i)) {
                if (std.ascii.isASCII(i)) {
                    // Printable ascii: single identical byte.
                    code.appendAssumeCapacity(i);
                } else {
                    // Printable latin-1 byte: utf-8 encode the same codepoint.
                    const codepoint: u21 = @as(u21, @intCast(i));
                    code.len = @intCast(std.unicode.utf8Encode(codepoint, &code.buffer) catch unreachable);
                }
            } else {
                // Non-printable byte: shifted into the 256+ codepoint range.
                const codepoint: u21 = 256 + @as(u21, @intCast(n));
                code.len = @intCast(std.unicode.utf8Encode(codepoint, &code.buffer) catch unreachable);
                n += 1;
            }

            self.code_to_byte.putAssumeCapacityNoClobber(code, i);
        }
        return self;
    }

    pub fn deinit(self: *Gpt2TextDecoder) void {
        self.code_to_byte.deinit();
    }

    /// Transform bytes representing text under the gpt2 encoding,
    /// and write to the `unicode` buffer utf-8 bytes.
    /// Returns only the slice decoded by this call; the returned slice is
    /// owned by `unicode` and is invalidated by its next grow.
    pub fn decode(self: Gpt2TextDecoder, unicode: *std.ArrayList(u8), bytes: []const u8) ![]const u8 {
        const start = unicode.items.len;
        var it = std.unicode.Utf8Iterator{ .i = 0, .bytes = bytes };
        while (it.nextCodepointSlice()) |codepoint| {
            const code: Code = switch (codepoint.len) {
                1 => .{ .buffer = .{ codepoint[0], 0 }, .len = 1 }, // 0-init
                2 => .{ .buffer = .{ codepoint[0], codepoint[1] }, .len = 2 },
                // Valid gpt2 codes never need more than 2 utf-8 bytes,
                // so a longer codepoint means this is not gpt2-encoded text.
                else => return error.InvalidInput,
            };
            const byte = self.code_to_byte.get(code) orelse return error.InvalidInput;
            try unicode.append(byte);
        }
        return unicode.items[start..];
    }

    /// The byte ranges gpt2 keeps as-is: '!'..'~', 0xa1..0xac, 0xae..0xff.
    inline fn isPrintableByte(c: u8) bool {
        return ('!' <= c and c <= '~') or (0xa1 <= c and c <= 0xac) or (0xae <= c and c <= 0xff);
    }
};
|
|||
|
|
|
|||
|
|
test Gpt2TextDecoder {
    var decoder = try Gpt2TextDecoder.init(testing.allocator);
    defer decoder.deinit();

    // `out` accumulates across calls; each call returns only its own slice.
    var out = std.ArrayList(u8).init(testing.allocator);
    defer out.deinit();

    // Ascii is not changed.
    try testing.expectEqualStrings("getTitle", try decoder.decode(&out, "getTitle"));
    // Leading space are represented with 'Ġ'
    try testing.expectEqualStrings(" UINavigationController", try decoder.decode(&out, "ĠUINavigationController"));
    // Russian is wild
    try testing.expectEqualStrings(" работ", try decoder.decode(&out, "ĠÑĢабоÑĤ"));
}
|
|||
|
|
|
|||
|
|
/// Open a json file in HF format and load the vocab from it.
/// The returned Tokenizer owns its memory; release it with `deinit`.
pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tokenizer {
    const file = try std.fs.cwd().openFile(tokenizer_path, .{});
    defer file.close();

    // All json parsing happens in a temporary arena freed on exit;
    // token bytes are copied into the tokenizer's own arena below.
    var arena_state = std.heap.ArenaAllocator.init(allocator);
    defer arena_state.deinit();
    const arena = arena_state.allocator();
    // 32 MiB cap on the tokenizer.json size.
    const file_content = try file.readToEndAlloc(arena, 32 * 1024 * 1024);

    const info = try std.json.parseFromSliceLeaky(std.json.Value, arena, file_content, .{
        .duplicate_field_behavior = .use_last,
    });
    // A valid HF tokenizer.json is an object with both "model" and "added_tokens".
    const main_object = switch (info) {
        .object => |obj| if (obj.get("added_tokens") == null or obj.get("model") == null) {
            return error.InvalidFormat;
        } else obj,
        else => return error.InvalidFormat,
    };

    const model = objectGet(main_object, .object, "model") orelse return error.InvalidFormat;
    const vocab = objectGet(model, .object, "vocab") orelse return error.InvalidFormat;
    const added_tokens = if (objectGet(main_object, .array, "added_tokens")) |added| added.items else &.{};
    const vocab_size: u32 = @intCast(vocab.count() + added_tokens.len);

    // Use the declared normalizer if present, llama3's preset otherwise.
    const normalizer = if (objectGet(main_object, .object, "normalizer")) |normalizer_config|
        try Normalizer.fromHfJson(normalizer_config)
    else
        Normalizer.wellKnown(.llama3);

    // delay init of special tokens: they are resolved once the vocab is loaded.
    var tokenizer = try Tokenizer.init(allocator, vocab_size, 256, normalizer, undefined, true);
    errdefer tokenizer.deinit();

    // Buffer containing all concatenated tokens.
    // Reserve a big chunk, to avoid grow event, but release over-allocated memory.
    var all_tokens = try std.ArrayList(u8).initCapacity(tokenizer.arena_state.allocator(), file_content.len);
    const original_alloc = all_tokens.items.ptr;
    // A re-alloc event here means we have invalidated all slices inside the tokenizer.
    // If this is too annoying we could switch to a custom type instead of slices.
    defer {
        std.debug.assert(all_tokens.items.ptr == original_alloc);
    }

    // gpt2 based tokenizer got a special way of encoding unicode.
    // we don't know in advance if this will be used by this tokenizer or not.
    // so we assume it is the case, but if we find some unicode character,
    // outside of the range used by gpt2 we know it was wrong, and start over.
    var is_gpt2_vocab: bool = true;
    var gpt2_decoder = try Gpt2TextDecoder.init(allocator);
    defer gpt2_decoder.deinit();
    var it = vocab.iterator();
    while (it.next()) |kv| {
        const token = gpt2_decoder.decode(&all_tokens, kv.key_ptr.*) catch |err| {
            switch (err) {
                // Not gpt2-encoded after all: abort this pass.
                error.InvalidInput => {
                    is_gpt2_vocab = false;
                    break;
                },
                else => return err,
            }
        };
        const idx: u32 = @intCast(kv.value_ptr.*.integer);
        // Score decreases with the token index so earlier (more frequent)
        // tokens win ties during BPE merging.
        tokenizer.addOwnedTokenByIndex(idx, @floatFromInt(vocab_size - idx), token);
    }

    if (!is_gpt2_vocab) {
        // We were wrong, this is not a gpt2 vocab, start over,
        // and reset the tokenizer state.
        tokenizer.next_token_id = 0;
        tokenizer.token_lookup.clearRetainingCapacity();
        all_tokens.clearRetainingCapacity();
        it = vocab.iterator();
        while (it.next()) |kv| {
            const idx: u32 = @intCast(kv.value_ptr.*.integer);
            // Second pass: copy token bytes verbatim, no gpt2 decoding.
            const token = try dup(&all_tokens, kv.key_ptr.*);
            tokenizer.addOwnedTokenByIndex(idx, @floatFromInt(vocab_size - idx), token);
        }
    }

    // More tokens, typically added during fine tuning of the model.
    for (added_tokens) |token_obj| {
        if (token_obj != .object) return error.InvalidFormat;
        const v = objectGet(token_obj.object, .string, "content") orelse return error.InvalidFormat;
        const id: u32 = @intCast(objectGet(token_obj.object, .integer, "id") orelse return error.InvalidFormat);
        const token = try if (is_gpt2_vocab)
            gpt2_decoder.decode(&all_tokens, v)
        else
            dup(&all_tokens, v);

        // Added tokens get score 0: they are matched literally, not merged.
        tokenizer.addOwnedTokenByIndex(id, 0, token);
    }
    // We won't add more tokens here, let release.
    all_tokens.shrinkAndFree(all_tokens.items.len);

    // Resolve the unknown token: explicit "unk_token" id wins over "<unk>".
    var unk = tokenizer.lookup("<unk>");
    if (objectGet(model, .integer, "unk_token")) |unk_tok| {
        unk = @intCast(unk_tok);
    }

    tokenizer.special_tokens = .{
        // TODO allow users to specify special tokens or read them from a tokenizer_config.json file
        .bos = tokenizer.lookup("<s>") orelse tokenizer.lookup("<|begin_of_text|>") orelse @panic("bos token not found !"),
        .eos = tokenizer.lookup("</s>") orelse tokenizer.lookup("<|end_of_text|>") orelse @panic("eos token not found !"),
        .unk = unk orelse std.math.maxInt(u32),
    };

    const byte_fallback = objectGet(model, .bool, "byte_fallback") orelse false;
    if (!byte_fallback and unk == null) {
        // GPT2 tokenizer have byte fallback already encoded in the model,
        // but the json generally don't have the field set.
        // We can detect it though because they don't specify an unknown token.
        if (is_gpt2_vocab) {
            tokenizer.byte_fallback = true;
        } else {
            log.warn("The given tokenizer can't handle unknown token: no unknown token was set, and byte_fallback is disabled too ! The tokenizer will panic when facing unknown tokens.", .{});
        }
    } else if (byte_fallback) {
        try tokenizer.rewriteByteFallbackTokens();
    }
    return tokenizer;
}
|
|||
|
|
|
|||
|
|
/// Returns a copy of the given string, stored inside the given ArrayList.
/// The returned slice points into `buffer.items` and is invalidated by the
/// buffer's next grow.
fn dup(buffer: *std.ArrayList(u8), str: []const u8) ![]const u8 {
    try buffer.ensureUnusedCapacity(str.len);
    const start = buffer.items.len;
    buffer.appendSliceAssumeCapacity(str);
    return buffer.items[start..][0..str.len];
}
|
|||
|
|
|
|||
|
|
/// Returns the given entry in a json object only if it has the right type.
/// `kind` selects both the expected json variant and the return type;
/// returns null when the key is absent or the value has another type.
fn objectGet(
    object: std.json.ObjectMap,
    comptime kind: std.meta.FieldEnum(std.json.Value),
    key: []const u8,
) ?std.meta.FieldType(std.json.Value, kind) {
    if (object.get(key)) |val| {
        if (val == kind) return @field(val, @tagName(kind));
    }
    return null;
}
|
|||
|
|
|
|||
|
|
/// Loads a tokenizer from the tiny-llama binary format:
/// a u32 max token length, then a stream of (score, length, bytes) records.
/// `vocab_size` is an upper bound; the actual count read is stored back.
pub fn fromTinyLlamaFile(allocator: std.mem.Allocator, tokenizer_path: []const u8, vocab_size: u32) !Tokenizer {
    const tokenizer_file = try std.fs.cwd().openFile(tokenizer_path, .{});
    defer tokenizer_file.close();
    var tok_reader = std.io.bufferedReader(tokenizer_file.reader());
    const r = tok_reader.reader();

    const max_token_len = try r.readInt(u32, .little);
    // This format uses fixed ids for the special tokens.
    const special_tokens: Tokenizer.SpecialTokens = .{
        .unk = 0,
        .bos = 1,
        .eos = 2,
    };
    var tokenizer = try Tokenizer.init(allocator, vocab_size, max_token_len, null, special_tokens, true);
    var i: u32 = 0;
    // Read records until the stream errors out (typically EndOfStream),
    // then record how many tokens were actually read.
    while (readToken(&tokenizer, &r)) : (i += 1) {
        // Pass
    } else |_| {
        if (i < vocab_size) {
            log.info("Read {d} words out of {?d}", .{ i, vocab_size });
        }
        tokenizer.vocab_size = i;
    }
    try tokenizer.rewriteByteFallbackTokens();
    return tokenizer;
}
|
|||
|
|
|
|||
|
|
/// Reads one token record from the stream: an f32 score (stored as
/// little-endian u32 bits), a u32 byte length, then the token bytes.
fn readToken(tokenizer: *Tokenizer, tok_reader: anytype) !void {
    const score_bits = try tok_reader.readInt(u32, .little);
    const token_len: usize = try tok_reader.readInt(u32, .little);
    try tokenizer.readTokenInto(@bitCast(score_bits), token_len, tok_reader);
}
|