zml.tokenizer: Implement proper byte fallback support by converting hex byte strings (e.g., "<0x40>") to their characters and splitting unknown UTF-8 codepoints into bytes, fixing tokenization of inputs containing out-of-vocabulary codepoints.

Tarry Singh 2023-02-28 14:40:25 +00:00
parent 2f129f76c9
commit ecf52ad724
2 changed files with 203 additions and 38 deletions


@ -49,6 +49,10 @@ pub fn loadTokenizerFromModelProto(allocator: std.mem.Allocator, model: sentence
for (model.pieces.items) |*piece| {
try tokenizer.addToken(piece.score.?, piece.piece.?.getSlice());
}
const byte_fallback = model.trainer_spec.?.byte_fallback orelse false;
if (byte_fallback) {
try tokenizer.rewriteByteFallbackTokens();
}
return tokenizer;
}
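For reference, a byte-fallback piece encodes exactly one raw byte as a six-character hex string. A minimal standalone sketch of that mapping (the helper name is illustrative, not part of this change):

const std = @import("std");

/// Decodes a byte-fallback piece such as "<0x40>" back to its raw byte ('@').
fn byteFromFallbackPiece(piece: []const u8) !u8 {
    if (piece.len != 6 or !std.mem.startsWith(u8, piece, "<0x") or piece[5] != '>')
        return error.InvalidInput;
    return std.fmt.parseInt(u8, piece[3..5], 16);
}

test byteFromFallbackPiece {
    try std.testing.expectEqual(@as(u8, '@'), try byteFromFallbackPiece("<0x40>"));
    try std.testing.expectEqual(@as(u8, 0xE2), try byteFromFallbackPiece("<0xE2>"));
}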


@ -10,6 +10,8 @@ const meta = @import("meta.zig");
test {
std.testing.refAllDecls(@This());
std.testing.refAllDecls(Normalizer);
std.testing.refAllDecls(Tokenizer);
}
/// Byte Pair Encoding tokenizer, generally used for LLMs.
@ -21,6 +23,8 @@ pub const Tokenizer = struct {
scores: []f32,
max_token_len: u32,
normalizer: ?Normalizer,
// Allows splitting unknown Unicode codepoints into bytes.
byte_fallback: bool = false,
arena_state: std.heap.ArenaAllocator,
vocab_size: u32,
@ -81,14 +85,13 @@ pub const Tokenizer = struct {
const n = try tok_reader.read(token);
std.debug.assert(n == len);
self.addOwnedToken(score, token);
return self.addOwnedToken(score, token);
}
/// Adds a new token (and copies it)
pub fn addToken(self: *Tokenizer, score: f32, token: []const u8) !void {
const arena = self.arena_state.allocator();
self.addOwnedToken(score, try arena.dupe(u8, token));
return self.addOwnedToken(score, try arena.dupe(u8, token));
}
/// Adds a new token (without copying it)
@ -142,24 +145,13 @@ pub const Tokenizer = struct {
const mergeable = try allocator.alloc(MergeState, tok_buff.len);
var num_tokens: usize = 0;
var off: usize = 0;
while (off < input.len) {
const utf_len = try std.unicode.utf8ByteSequenceLength(input[off]);
defer off += utf_len;
mergeable[num_tokens] = .idk;
defer num_tokens += 1;
const char = input[off..][0..utf_len];
tok_buff[num_tokens] = self.lookup(char) orelse
// TODO: split unknown token into bytes if model supports it
self.special_tokens.unk;
if (tok_buff[num_tokens] == self.special_tokens.unk) {
log.debug("Token not found for char '{s}' (@{x})", .{ char, char });
}
if (tok_buff[num_tokens] == self.special_tokens.hard_space) {
mergeable[num_tokens] = .hard_space;
}
var it: CharTokenIterator = .{ .input = input };
while (try it.nextCodepointToken(self)) |token| : (num_tokens += 1) {
tok_buff[num_tokens] = token;
mergeable[num_tokens] = if (token == self.special_tokens.hard_space)
.hard_space
else
.idk;
}
var stable_prefix: usize = 0;
@ -197,7 +189,6 @@ pub const Tokenizer = struct {
continue;
},
.idk => {
const next_tok = self.tokens[tok_buff[i + 1]];
// Special tokens can't be concatenated.
if (builtin.mode == .Debug and tok_buff[i] != self.special_tokens.unk) {
@ -206,7 +197,10 @@ pub const Tokenizer = struct {
meta.assert(std.mem.eql(u8, cur_tok, input[input_off..][0..cur_tok.len]), "current token '{s}' not found in input string '{s}' !", .{ cur_tok, input[input_off..] });
}
const concat_tokens = input[input_off..][0 .. cur_tok.len + next_tok.len];
const next_tok = self.tokens[tok_buff[i + 1]];
// if `next_tok` is `.unk`, length is 1; otherwise, it's the length of the token.
const next_tok_len = if (tok_buff[i + 1] == self.special_tokens.unk) 1 else next_tok.len;
const concat_tokens = input[input_off..][0 .. cur_tok.len + next_tok_len];
// Save the result
mergeable[i] = if (self.lookup(concat_tokens)) |tok|
.{ .ready = tok }
@ -310,6 +304,41 @@ pub const Tokenizer = struct {
if (opts.sep.len > 0) try output.appendSlice(opts.sep);
}
}
/// Some tokenizers have bytes encoded in hex like this: "<0x40>".
/// This breaks the tokenization algorithm, because the input text
/// will contain "@", not "<0x40>";
/// and if the input does contain the literal string "<0x40>", it must not be treated as a single byte.
/// So we replace byte-fallback strings with their corresponding characters.
/// This enables the normal tokenization algorithm to work.
pub fn rewriteByteFallbackTokens(tokenizer: *Tokenizer) !void {
tokenizer.byte_fallback = true;
var single_bytes = try tokenizer.arena_state.allocator().alloc(u8, 256);
var byte_fallback_buf = "<0x00>".*;
for (0..256) |i| {
const c: u8 = @truncate(i);
single_bytes[i] = c;
// First, look up the byte-fallback entry.
// Note: we assume upper case, but we could try both upper and lower case if needed.
_ = std.fmt.bufPrintIntToSlice(byte_fallback_buf[3..5], c, 16, .upper, .{ .fill = '0', .width = 2 });
const entry = tokenizer.token_lookup.getEntry(&byte_fallback_buf) orelse {
log.err("Tokenizer has \"byte_fallback\" = true, but doesn't contains the byte fallback token {s}", .{byte_fallback_buf});
return error.InvalidInput;
};
// Check if the character is already present in the vocab.
// In that case, nothing to do,
// but note that the fallback token will be "unreachable",
// i.e. there is no way the tokenizer can produce it.
if (tokenizer.token_lookup.get(&.{c})) |_| continue;
const idx: u32 = entry.value_ptr.*;
tokenizer.token_lookup.removeByPtr(entry.key_ptr);
tokenizer.addOwnedTokenByIndex(idx, tokenizer.scores[idx], single_bytes[i .. i + 1]);
}
}
};
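The effect of rewriteByteFallbackTokens on the lookup table can be pictured with a plain hash map. A simplified sketch, assuming a StringHashMap stand-in for the real token_lookup type:

const std = @import("std");

test "byte fallback rewrite, simplified" {
    var lookup = std.StringHashMap(u32).init(std.testing.allocator);
    defer lookup.deinit();
    try lookup.put("<0x40>", 7); // hypothetical index of the fallback piece
    // The rewrite re-keys the entry from the hex string to the raw byte,
    // keeping the same token index (and score).
    const idx = lookup.get("<0x40>").?;
    _ = lookup.remove("<0x40>");
    try lookup.put("@", idx);
    try std.testing.expectEqual(@as(u32, 7), lookup.get("@").?);
    try std.testing.expectEqual(@as(?u32, null), lookup.get("<0x40>"));
}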
test Tokenizer {
@ -332,6 +361,91 @@ test Tokenizer {
// TODO: test Tokenizer.decode, Tokenizer.encode, Tokenizer.readTokenInto
}
/// Given a slice, splits it into the simplest tokens using the given tokenizer's vocabulary.
/// The output of this can be used to initialize the tokenization algorithm.
/// Normally we split the input text into UTF-8 codepoints,
/// but if we find an unknown codepoint we either split it into bytes or use the special "unknown" token,
/// depending on the tokenizer configuration.
const CharTokenIterator = struct {
state: union(enum) { by_codepoint, by_byte: u8 } = .by_codepoint,
input: []const u8,
fn nextCodepointToken(self: *CharTokenIterator, tokenizer: *const Tokenizer) error{ TruncatedInput, Utf8InvalidStartByte }!?u32 {
if (self.input.len == 0) return null;
return switch (self.state) {
.by_byte => |*byte_left| {
const idx = tokenizer.lookup(self.input[0..1]) orelse {
// Normally this has already been caught when calling `rewriteByteFallbackTokens`.
std.debug.panic("Tokenizer has \"byte_fallback\" = true, but doesn't contain the byte fallback token '<0x{X:02}>'", .{self.input[0]});
};
self.input = self.input[1..];
byte_left.* -|= 1;
if (byte_left.* == 0) self.state = .by_codepoint;
return idx;
},
.by_codepoint => {
// Try to lookup valid utf8 codepoint first.
const utf8_len = try std.unicode.utf8ByteSequenceLength(self.input[0]);
if (self.input.len < utf8_len) return error.TruncatedInput;
if (tokenizer.lookup(self.input[0..utf8_len])) |idx| {
self.input = self.input[utf8_len..];
return idx;
}
// Otherwise split in bytes if it's allowed.
if (tokenizer.byte_fallback) {
// TODO: replace this with a `continue` statement next time we bump Zig.
self.state = .{ .by_byte = utf8_len };
return self.nextCodepointToken(tokenizer);
}
// Or mark the full utf8 codepoint as unknown.
log.debug("Token not found for char '{s}'", .{self.input[0..utf8_len]});
self.input = self.input[utf8_len..];
return tokenizer.special_tokens.unk;
},
};
}
};
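To see why the by_byte state is seeded with utf8_len, note how a single unknown codepoint expands into individual byte tokens; 'ℳ' (U+2133) is the same character used in the test vocabulary below:

const std = @import("std");

test "an unknown codepoint splits into its UTF-8 bytes" {
    var buf: [4]u8 = undefined;
    const len = try std.unicode.utf8Encode('ℳ', &buf);
    // With byte_fallback enabled, each of these bytes is looked up
    // as its own single-byte token.
    try std.testing.expectEqualSlices(u8, &.{ 0xE2, 0x84, 0xB3 }, buf[0..len]);
}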
test CharTokenIterator {
const special_tokens: Tokenizer.SpecialTokens = .{ .unk = 0, .bos = 1, .eos = 2 };
var tokenizer = try Tokenizer.init(std.testing.allocator, 16, 4, null, special_tokens, true);
defer tokenizer.deinit();
tokenizer.addOwnedToken(1.0, "<unk>"); // 0
tokenizer.addOwnedToken(1.0, "<s>"); // 1
tokenizer.addOwnedToken(1.0, "</s>"); // 2
tokenizer.addOwnedToken(1.0, "ζ"); // 3
tokenizer.addOwnedToken(1.0, &.{0xE2}); // 4: , first byte
tokenizer.addOwnedToken(1.0, &.{0x84}); // 5: , second byte
tokenizer.addOwnedToken(1.0, &.{0xB3}); // 6: , third byte
tokenizer.addOwnedToken(1.0, "L"); // 7
// No byte fallback
{
tokenizer.byte_fallback = false;
var it: CharTokenIterator = .{ .input = "ζℳL" };
var res: std.BoundedArray(u32, 8) = .{};
while (try it.nextCodepointToken(&tokenizer)) |token| {
res.appendAssumeCapacity(token);
}
try std.testing.expectEqualSlices(u32, &[_]u32{ 3, 0, 7 }, res.constSlice());
}
// with byte fallback
{
tokenizer.byte_fallback = true;
var it: CharTokenIterator = .{ .input = "ζℳL" };
var res: std.BoundedArray(u32, 8) = .{};
while (try it.nextCodepointToken(&tokenizer)) |token| {
res.appendAssumeCapacity(token);
}
try std.testing.expectEqualSlices(u32, &[_]u32{ 3, 4, 5, 6, 7 }, res.constSlice());
}
}
/// Text normalizer. Most tokenizers assume the input text has been preprocessed
/// with one of these.
pub const Normalizer = struct {
@ -613,7 +727,6 @@ pub const Gpt2TextDecoder = struct {
try self.code_to_byte.ensureTotalCapacity(256);
errdefer unreachable;
// The eon
var n: usize = 0;
for (0..256) |index| {
var code: Code = .{ .buffer = .{ 0, 0 }, .len = 0 }; // 0-init
@ -703,34 +816,70 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
const vocab = main_object.get("model").?.object.get("vocab").?.object;
const vocab_size: u32 = @intCast(vocab.count() + added_tokens.items.len);
// TODO: not all tokenizer.json files are GPT-2 encoded; detect when decoding is needed.
const normalizer = Normalizer.wellKnown(.gpt2);
var decoder = try Gpt2TextDecoder.init(allocator);
defer decoder.deinit();
var gpt2_decoder = try Gpt2TextDecoder.init(allocator);
defer gpt2_decoder.deinit();
var tokenizer = try Tokenizer.init(allocator, vocab_size, 256, normalizer, undefined, true);
errdefer tokenizer.deinit();
// Buffer containing all concatenated tokens.
// Reserve a big chunk to avoid grow events, but release the over-allocated memory.
var all_tokens = try std.ArrayList(u8).initCapacity(tokenizer.arena_state.allocator(), 24 * vocab.count());
defer all_tokens.shrinkAndFree(all_tokens.items.len);
var all_tokens = try std.ArrayList(u8).initCapacity(tokenizer.arena_state.allocator(), file_content.len);
const original_alloc = all_tokens.items.ptr;
// A re-alloc event here means we have invalidated all slices inside the tokenizer.
// If this is too annoying we could switch to a custom type instead of slices.
defer {
std.debug.assert(all_tokens.items.ptr == original_alloc);
}
var it = vocab.iterator();
// GPT-2-based tokenizers use a special way of encoding Unicode.
// We don't know in advance whether this tokenizer uses it or not,
// so we assume it does; but if we find some Unicode character
// outside of the range used by GPT-2, we know the assumption was wrong, and start over.
var is_gpt2_vocab: bool = true;
while (it.next()) |kv| {
const token = try decoder.decode(&all_tokens, kv.key_ptr.*);
const token = gpt2_decoder.decode(&all_tokens, kv.key_ptr.*) catch |err| {
switch (err) {
error.InvalidInput => {
is_gpt2_vocab = false;
break;
},
else => return err,
}
};
const idx: u32 = @intCast(kv.value_ptr.*.integer);
// std.debug.assert(idx == tokenizer.next_token_id);
tokenizer.addOwnedTokenByIndex(idx, @floatFromInt(vocab_size - idx), token);
}
for (added_tokens.items) |token_obj| {
const token = try decoder.decode(&all_tokens, token_obj.object.get("content").?.string);
tokenizer.addOwnedTokenByIndex(
@intCast(token_obj.object.get("id").?.integer),
0,
token,
);
if (!is_gpt2_vocab) {
// We were wrong: this is not a GPT-2 vocab. Start over
// and reset the tokenizer state.
tokenizer.next_token_id = 0;
tokenizer.token_lookup.clearRetainingCapacity();
all_tokens.clearRetainingCapacity();
it = vocab.iterator();
while (it.next()) |kv| {
const idx: u32 = @intCast(kv.value_ptr.*.integer);
const token = try dup(&all_tokens, kv.key_ptr.*);
tokenizer.addOwnedTokenByIndex(idx, @floatFromInt(vocab_size - idx), token);
}
}
// More tokens, typically added during fine-tuning of the model.
for (added_tokens.items) |token_obj| {
const v = token_obj.object.get("content").?.string;
const id: u32 = @intCast(token_obj.object.get("id").?.integer);
const token = try if (is_gpt2_vocab)
gpt2_decoder.decode(&all_tokens, v)
else
dup(&all_tokens, v);
tokenizer.addOwnedTokenByIndex(id, 0, token);
}
// We won't add more tokens here; release the over-allocated memory.
all_tokens.shrinkAndFree(all_tokens.items.len);
tokenizer.special_tokens = .{
.bos = tokenizer.lookup("<s>") orelse tokenizer.lookup("<|begin_of_text|>") orelse @panic("bos token not found !"),
@ -738,5 +887,17 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
.unk = tokenizer.lookup("<unk>") orelse std.math.maxInt(u32),
};
if (main_object.get("model").?.object.get("byte_fallback")) |byte_fallback| {
if (byte_fallback == .bool and byte_fallback.bool) {
try tokenizer.rewriteByteFallbackTokens();
}
}
return tokenizer;
}
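For context on the is_gpt2_vocab detection above: the standard GPT-2 byte-level scheme remaps raw bytes into printable codepoints, e.g. a space (0x20) is stored as 'Ġ' (U+0120 = 0x20 + 0x100). A minimal sketch of that mapping, which is the kind of encoding Gpt2TextDecoder.decode undoes:

const std = @import("std");

test "gpt2 byte-level encoding of a space" {
    var buf: [4]u8 = undefined;
    const len = try std.unicode.utf8Encode(0x20 + 0x100, &buf);
    try std.testing.expectEqualSlices(u8, "Ġ", buf[0..len]);
}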
/// Returns a copy of the given string, stored inside the given ArrayList.
fn dup(buffer: *std.ArrayList(u8), str: []const u8) ![]const u8 {
const n = buffer.items.len;
try buffer.appendSlice(str);
return buffer.items[n..];
}
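Putting it together, a hypothetical caller of fromHfJson (the module import path and the tokenizer file name are assumptions, not part of this change):

const std = @import("std");
const tokenizer = @import("tokenizer.zig");

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    // Loads a Hugging Face tokenizer.json; byte fallback is enabled
    // automatically when the file declares "byte_fallback": true.
    var tok = try tokenizer.fromHfJson(gpa.allocator(), "tokenizer.json");
    defer tok.deinit();
}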