zml.tokenizer: Implement proper byte fallback support by converting hex byte strings (e.g., “<0x40>”) to their characters and splitting unknown UTF‑8 codepoints into bytes, fixing tokenization.
parent 2f129f76c9
commit ecf52ad724
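For context on what the diff below fixes: SentencePiece-style byte fallback stores raw bytes in the vocab under hex-named tokens such as "<0x40>", while the input text contains the actual character ("@"), and an out-of-vocabulary codepoint has to be split into its UTF-8 bytes. A minimal Zig sketch of that mapping follows; the helper byteFallbackToken is hypothetical and only for illustration, not part of this change.

    const std = @import("std");

    // Hypothetical helper (illustration only): format a byte as its
    // SentencePiece-style byte-fallback token, e.g. 0x40 -> "<0x40>".
    fn byteFallbackToken(buf: []u8, byte: u8) ![]const u8 {
        return std.fmt.bufPrint(buf, "<0x{X:0>2}>", .{byte});
    }

    test "byte fallback tokens map to raw bytes" {
        var buf: [8]u8 = undefined;
        // '@' is byte 0x40, so its fallback token is "<0x40>".
        try std.testing.expectEqualStrings("<0x40>", try byteFallbackToken(&buf, '@'));

        // An out-of-vocabulary codepoint is split into its UTF-8 bytes:
        // 'ℳ' (U+2133) encodes as 0xE2 0x84 0xB3, i.e. three fallback tokens.
        try std.testing.expectEqualSlices(u8, &.{ 0xE2, 0x84, 0xB3 }, "ℳ");
    }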
@@ -49,6 +49,10 @@ pub fn loadTokenizerFromModelProto(allocator: std.mem.Allocator, model: sentence
     for (model.pieces.items) |*piece| {
         try tokenizer.addToken(piece.score.?, piece.piece.?.getSlice());
     }
+    const byte_fallback = model.trainer_spec.?.byte_fallback orelse false;
+    if (byte_fallback) {
+        try tokenizer.rewriteByteFallbackTokens();
+    }

     return tokenizer;
 }
@@ -10,6 +10,8 @@ const meta = @import("meta.zig");

 test {
     std.testing.refAllDecls(@This());
+    std.testing.refAllDecls(Normalizer);
+    std.testing.refAllDecls(Tokenizer);
 }

 /// Byte Pair Encoding tokenizer generally used for LLM.
@@ -21,6 +23,8 @@ pub const Tokenizer = struct {
     scores: []f32,
     max_token_len: u32,
     normalizer: ?Normalizer,
+    // Allows to split unknown unicode characters into bytes.
+    byte_fallback: bool = false,

     arena_state: std.heap.ArenaAllocator,
     vocab_size: u32,
@@ -81,14 +85,13 @@ pub const Tokenizer = struct {
         const n = try tok_reader.read(token);
         std.debug.assert(n == len);

-        self.addOwnedToken(score, token);
+        return self.addOwnedToken(score, token);
     }

     /// Adds a new token (and copy it)
     pub fn addToken(self: *Tokenizer, score: f32, token: []const u8) !void {
         const arena = self.arena_state.allocator();
-        self.addOwnedToken(score, try arena.dupe(u8, token));
-
+        return self.addOwnedToken(score, try arena.dupe(u8, token));
     }

     /// Adds a new token (without copying it)
@@ -142,24 +145,13 @@ pub const Tokenizer = struct {
         const mergeable = try allocator.alloc(MergeState, tok_buff.len);

         var num_tokens: usize = 0;
-        var off: usize = 0;
-        while (off < input.len) {
-            const utf_len = try std.unicode.utf8ByteSequenceLength(input[off]);
-            defer off += utf_len;
-
-            mergeable[num_tokens] = .idk;
-            defer num_tokens += 1;
-
-            const char = input[off..][0..utf_len];
-            tok_buff[num_tokens] = self.lookup(char) orelse
-                // TODO: split unknown token into bytes if model supports it
-                self.special_tokens.unk;
-            if (tok_buff[num_tokens] == self.special_tokens.unk) {
-                log.debug("Token not found for char '{s}' (@{x})", .{ char, char });
-            }
-            if (tok_buff[num_tokens] == self.special_tokens.hard_space) {
-                mergeable[num_tokens] = .hard_space;
-            }
+        var it: CharTokenIterator = .{ .input = input };
+        while (try it.nextCodepointToken(self)) |token| : (num_tokens += 1) {
+            tok_buff[num_tokens] = token;
+            mergeable[num_tokens] = if (token == self.special_tokens.hard_space)
+                .hard_space
+            else
+                .idk;
         }

         var stable_prefix: usize = 0;
@@ -197,7 +189,6 @@ pub const Tokenizer = struct {
                     continue;
                 },
                 .idk => {
-                    const next_tok = self.tokens[tok_buff[i + 1]];

                     // Special tokens can't be concatenated.
                     if (builtin.mode == .Debug and tok_buff[i] != self.special_tokens.unk) {
@@ -206,7 +197,10 @@ pub const Tokenizer = struct {

                         meta.assert(std.mem.eql(u8, cur_tok, input[input_off..][0..cur_tok.len]), "current token '{s}' not found in input string '{s}' !", .{ cur_tok, input[input_off..] });
                     }
-                    const concat_tokens = input[input_off..][0 .. cur_tok.len + next_tok.len];
+                    const next_tok = self.tokens[tok_buff[i + 1]];
+                    // if `next_tok` is `.unk`, length is 1; otherwise, it's the length of the token.
+                    const next_tok_len = if (tok_buff[i + 1] == self.special_tokens.unk) 1 else next_tok.len;
+                    const concat_tokens = input[input_off..][0 .. cur_tok.len + next_tok_len];
                     // Save the result
                     mergeable[i] = if (self.lookup(concat_tokens)) |tok|
                         .{ .ready = tok }
@@ -310,6 +304,41 @@ pub const Tokenizer = struct {
             if (opts.sep.len > 0) try output.appendSlice(opts.sep);
         }
     }
+
+    /// Some tokenizers have bytes encoded in hex like this: "<0x40>".
+    /// This break the tokenization algorithm because the input text
+    /// will contain "@" not "<0x40>",
+    /// and if the input contains "<0x40>" it needs to not be treated as a single byte.
+    /// So we replace byte fallbacks strings, by their corresponding character.
+    /// This enables the normal tokenization algorithm to work.
+    pub fn rewriteByteFallbackTokens(tokenizer: *Tokenizer) !void {
+        tokenizer.byte_fallback = true;
+        var single_bytes = try tokenizer.arena_state.allocator().alloc(u8, 256);
+        var byte_fallback_buf = "<0x00>".*;
+
+        for (0..256) |i| {
+            const c: u8 = @truncate(i);
+            single_bytes[i] = c;
+
+            // First lookup the byte fallback entry.
+            // Note: we assume upper case, but we could try both upper and lower case if needed.
+            _ = std.fmt.bufPrintIntToSlice(byte_fallback_buf[3..5], c, 16, .upper, .{ .fill = '0', .width = 2 });
+            const entry = tokenizer.token_lookup.getEntry(&byte_fallback_buf) orelse {
+                log.err("Tokenizer has \"byte_fallback\" = true, but doesn't contains the byte fallback token {s}", .{byte_fallback_buf});
+                return error.InvalidInput;
+            };
+
+            // Check if the character is already present in the vocab.
+            // In that case, nothing to do,
+            // but note that the fallback token will be "unreachable",
+            // ie there is no way the tokenizer can produce it.
+            if (tokenizer.token_lookup.get(&.{c})) |_| continue;
+
+            const idx: u32 = entry.value_ptr.*;
+            tokenizer.token_lookup.removeByPtr(entry.key_ptr);
+            tokenizer.addOwnedTokenByIndex(idx, tokenizer.scores[idx], single_bytes[i .. i + 1]);
+        }
+    }
 };

 test Tokenizer {
@@ -332,6 +361,91 @@ test Tokenizer {
     // TODO: test Tokenizer.decode, Tokenizer.encode, Tokenizer.readTokenInto
 }

+/// Given a slice, split it in the most simple tokens using the given tokenizer tokens.
+/// The output of this can be used to initialize the tokenization algorithm.
+/// Normally we split the input text into utf8 codepoint,
+/// but if we find an unknown codepoint we either split it in bytes, or use the special "unknown" token,
+/// depending on the tokenizer configuration.
+const CharTokenIterator = struct {
+    state: union(enum) { by_codepoint, by_byte: u8 } = .by_codepoint,
+    input: []const u8,
+
+    fn nextCodepointToken(self: *CharTokenIterator, tokenizer: *const Tokenizer) error{ TruncatedInput, Utf8InvalidStartByte }!?u32 {
+        if (self.input.len == 0) return null;
+        return switch (self.state) {
+            .by_byte => |*byte_left| {
+                const idx = tokenizer.lookup(self.input[0..1]) orelse {
+                    // Normally this has been caught when calling `rewriteByteFallbackTokens`.
+                    std.debug.panic("Tokenizer has \"byte_fallback\" = true, but doesn't contains the byte fallback for token '<0x{X:02}>'", .{self.input[0]});
+                };
+
+                self.input = self.input[1..];
+                byte_left.* -|= 1;
+                if (byte_left.* == 0) self.state = .by_codepoint;
+                return idx;
+            },
+            .by_codepoint => {
+                // Try to lookup valid utf8 codepoint first.
+                const utf8_len = try std.unicode.utf8ByteSequenceLength(self.input[0]);
+                if (self.input.len < utf8_len) return error.TruncatedInput;
+                if (tokenizer.lookup(self.input[0..utf8_len])) |idx| {
+                    self.input = self.input[utf8_len..];
+                    return idx;
+                }
+
+                // Otherwise split in bytes if it's allowed.
+                if (tokenizer.byte_fallback) {
+                    // TODO: replace this by a continue statement next time we bump Zig.
+                    self.state = .{ .by_byte = utf8_len };
+                    return self.nextCodepointToken(tokenizer);
+                }
+
+                // Or mark the full utf8 codepoint as unknown.
+                log.debug("Token not found for char '{s}'", .{self.input[0..utf8_len]});
+                self.input = self.input[utf8_len..];
+                return tokenizer.special_tokens.unk;
+            },
+        };
+    }
+};
+
+test CharTokenIterator {
+    const special_tokens: Tokenizer.SpecialTokens = .{ .unk = 0, .bos = 1, .eos = 2 };
+    var tokenizer = try Tokenizer.init(std.testing.allocator, 16, 4, null, special_tokens, true);
+    defer tokenizer.deinit();
+
+    tokenizer.addOwnedToken(1.0, "<unk>"); // 0
+    tokenizer.addOwnedToken(1.0, "<s>"); // 1
+    tokenizer.addOwnedToken(1.0, "</s>"); // 2
+    tokenizer.addOwnedToken(1.0, "ζ"); // 3
+    tokenizer.addOwnedToken(1.0, &.{0xE2}); // 4: ℳ, first byte
+    tokenizer.addOwnedToken(1.0, &.{0x84}); // 5: ℳ, second byte
+    tokenizer.addOwnedToken(1.0, &.{0xB3}); // 6: ℳ, third byte
+    tokenizer.addOwnedToken(1.0, "L"); // 7
+
+    // No byte fallback
+    {
+        tokenizer.byte_fallback = false;
+        var it: CharTokenIterator = .{ .input = "ζℳL" };
+        var res: std.BoundedArray(u32, 8) = .{};
+        while (try it.nextCodepointToken(&tokenizer)) |token| {
+            res.appendAssumeCapacity(token);
+        }
+        try std.testing.expectEqualSlices(u32, &[_]u32{ 3, 0, 7 }, res.constSlice());
+    }
+
+    // with byte fallback
+    {
+        tokenizer.byte_fallback = true;
+        var it: CharTokenIterator = .{ .input = "ζℳL" };
+        var res: std.BoundedArray(u32, 8) = .{};
+        while (try it.nextCodepointToken(&tokenizer)) |token| {
+            res.appendAssumeCapacity(token);
+        }
+        try std.testing.expectEqualSlices(u32, &[_]u32{ 3, 4, 5, 6, 7 }, res.constSlice());
+    }
+}
+
 /// Text normalizer. Most tokenizer assumes the input text have been prepocessed
 /// with on of those.
 pub const Normalizer = struct {
@@ -613,7 +727,6 @@ pub const Gpt2TextDecoder = struct {
         try self.code_to_byte.ensureTotalCapacity(256);
         errdefer unreachable;

-        // The eon
         var n: usize = 0;
         for (0..256) |index| {
             var code: Code = .{ .buffer = .{ 0, 0 }, .len = 0 }; // 0-init
@@ -703,34 +816,70 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
     const vocab = main_object.get("model").?.object.get("vocab").?.object;
     const vocab_size: u32 = @intCast(vocab.count() + added_tokens.items.len);

-    // TODO not all tokenizer.json are Gpt2 encoded, detect when it's needed or not.
     const normalizer = Normalizer.wellKnown(.gpt2);
-    var decoder = try Gpt2TextDecoder.init(allocator);
-    defer decoder.deinit();
+    var gpt2_decoder = try Gpt2TextDecoder.init(allocator);
+    defer gpt2_decoder.deinit();

     var tokenizer = try Tokenizer.init(allocator, vocab_size, 256, normalizer, undefined, true);
+    errdefer tokenizer.deinit();

     // Buffer containing all concatenated tokens.
     // Reserve a big chunk, to avoid grow event, but release over-allocated memory.
-    var all_tokens = try std.ArrayList(u8).initCapacity(tokenizer.arena_state.allocator(), 24 * vocab.count());
-    defer all_tokens.shrinkAndFree(all_tokens.items.len);
+    var all_tokens = try std.ArrayList(u8).initCapacity(tokenizer.arena_state.allocator(), file_content.len);
+    const original_alloc = all_tokens.items.ptr;
+    // A re-alloc event here means we have invalidated all slices inside the tokenizer.
+    // If this is too annoying we could switch to a custom type instead of slices.
+    defer {
+        std.debug.assert(all_tokens.items.ptr == original_alloc);
+    }

     var it = vocab.iterator();
+    // gpt2 based tokenizer got a special way of encoding unicode.
+    // we don't know in advance if this will be used by this tokenizer or not.
+    // so we assume it is the case, but if we find some unicode character,
+    // outside of the range used by gpt2 we know it was wrong, and start over.
+    var is_gpt2_vocab: bool = true;
     while (it.next()) |kv| {
-        const token = try decoder.decode(&all_tokens, kv.key_ptr.*);
+        const token = gpt2_decoder.decode(&all_tokens, kv.key_ptr.*) catch |err| {
+            switch (err) {
+                error.InvalidInput => {
+                    is_gpt2_vocab = false;
+                    break;
+                },
+                else => return err,
+            }
+        };
         const idx: u32 = @intCast(kv.value_ptr.*.integer);
-        // std.debug.assert(idx == tokenizer.next_token_id);
         tokenizer.addOwnedTokenByIndex(idx, @floatFromInt(vocab_size - idx), token);
     }

-    for (added_tokens.items) |token_obj| {
-        const token = try decoder.decode(&all_tokens, token_obj.object.get("content").?.string);
-        tokenizer.addOwnedTokenByIndex(
-            @intCast(token_obj.object.get("id").?.integer),
-            0,
-            token,
-        );
+    if (!is_gpt2_vocab) {
+        // We where wrong, this is not a gpt2 vocab, start over,
+        // and reset the tokenizer state.
+        tokenizer.next_token_id = 0;
+        tokenizer.token_lookup.clearRetainingCapacity();
+        all_tokens.clearRetainingCapacity();
+        it = vocab.iterator();
+        while (it.next()) |kv| {
+            const idx: u32 = @intCast(kv.value_ptr.*.integer);
+            const token = try dup(&all_tokens, kv.key_ptr.*);
+            tokenizer.addOwnedTokenByIndex(idx, @floatFromInt(vocab_size - idx), token);
         }
+    }
+
+    // More tokens, typically added during fine tuning of the model.
+    for (added_tokens.items) |token_obj| {
+        const v = token_obj.object.get("content").?.string;
+        const id: u32 = @intCast(token_obj.object.get("id").?.integer);
+        const token = try if (is_gpt2_vocab)
+            gpt2_decoder.decode(&all_tokens, v)
+        else
+            dup(&all_tokens, v);
+
+        tokenizer.addOwnedTokenByIndex(id, 0, token);
+    }
+    // We won't add more tokens here, let release.
+    all_tokens.shrinkAndFree(all_tokens.items.len);

     tokenizer.special_tokens = .{
         .bos = tokenizer.lookup("<s>") orelse tokenizer.lookup("<|begin_of_text|>") orelse @panic("bos token not found !"),
@@ -738,5 +887,17 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
         .unk = tokenizer.lookup("<unk>") orelse std.math.maxInt(u32),
     };

+    if (main_object.get("model").?.object.get("byte_fallback")) |byte_fallback| {
+        if (byte_fallback == .bool and byte_fallback.bool) {
+            try tokenizer.rewriteByteFallbackTokens();
+        }
+    }
     return tokenizer;
 }
+
+/// Returns a copy of the given string, stored inside the given ArrayList.
+fn dup(buffer: *std.ArrayList(u8), str: []const u8) ![]const u8 {
+    const n = buffer.items.len;
+    try buffer.appendSlice(str);
+    return buffer.items[n..];
+}
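The fromHfJson hunks above also introduce a detect-and-restart pattern: the vocab is first decoded assuming GPT-2 byte-level encoding, and on the first error.InvalidInput the tokenizer state is reset and every key is re-read verbatim. A minimal sketch of that control flow, where decodeGpt2 is a hypothetical stand-in for Gpt2TextDecoder.decode (not the real API):

    const std = @import("std");

    // Hypothetical stand-in for Gpt2TextDecoder.decode: rejects keys that are
    // not valid GPT-2 byte-level text (here simply any key containing 0xFF).
    fn decodeGpt2(key: []const u8) error{InvalidInput}![]const u8 {
        if (std.mem.indexOfScalar(u8, key, 0xFF) != null) return error.InvalidInput;
        return key;
    }

    test "assume a gpt2 vocab, start over on the first invalid key" {
        const keys = [_][]const u8{ "hello", "\xFFnot-gpt2", "world" };

        // First pass: optimistically decode as GPT-2 byte-level text.
        var is_gpt2_vocab = true;
        for (keys) |key| {
            _ = decodeGpt2(key) catch {
                is_gpt2_vocab = false;
                break;
            };
        }

        // A second pass (not shown) would reset the tokenizer state and
        // re-read every key verbatim, as the new fromHfJson does.
        try std.testing.expect(!is_gpt2_vocab);
    }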