From ecf52ad724cf2ad75a9ea9b96287ba8ae56635f6 Mon Sep 17 00:00:00 2001
From: Tarry Singh
Date: Tue, 28 Feb 2023 14:40:25 +0000
Subject: [PATCH] zml.tokenizer: Implement proper byte fallback support by
 converting hex byte strings (e.g., "<0x40>") to their characters and
 splitting unknown UTF-8 codepoints into bytes, fixing tokenization.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 zml/aio/sentencepiece.zig |   4 +
 zml/tokenizer.zig         | 237 ++++++++++++++++++++++++++++++++------
 2 files changed, 203 insertions(+), 38 deletions(-)

diff --git a/zml/aio/sentencepiece.zig b/zml/aio/sentencepiece.zig
index 8b4b04a..3596752 100644
--- a/zml/aio/sentencepiece.zig
+++ b/zml/aio/sentencepiece.zig
@@ -49,6 +49,10 @@ pub fn loadTokenizerFromModelProto(allocator: std.mem.Allocator, model: sentence
     for (model.pieces.items) |*piece| {
         try tokenizer.addToken(piece.score.?, piece.piece.?.getSlice());
     }
+    const byte_fallback = model.trainer_spec.?.byte_fallback orelse false;
+    if (byte_fallback) {
+        try tokenizer.rewriteByteFallbackTokens();
+    }
 
     return tokenizer;
 }

diff --git a/zml/tokenizer.zig b/zml/tokenizer.zig
index 6a1f394..a034ae6 100644
--- a/zml/tokenizer.zig
+++ b/zml/tokenizer.zig
@@ -10,6 +10,8 @@ const meta = @import("meta.zig");
 
 test {
     std.testing.refAllDecls(@This());
+    std.testing.refAllDecls(Normalizer);
+    std.testing.refAllDecls(Tokenizer);
 }
 
 /// Byte Pair Encoding tokenizer generally used for LLM.
@@ -21,6 +23,8 @@ pub const Tokenizer = struct {
     scores: []f32,
     max_token_len: u32,
     normalizer: ?Normalizer,
+    // Allows splitting unknown unicode characters into bytes.
+    byte_fallback: bool = false,
 
     arena_state: std.heap.ArenaAllocator,
     vocab_size: u32,
@@ -81,14 +85,13 @@ pub const Tokenizer = struct {
         const n = try tok_reader.read(token);
         std.debug.assert(n == len);
 
-        self.addOwnedToken(score, token);
+        return self.addOwnedToken(score, token);
     }
 
     /// Adds a new token (and copy it)
     pub fn addToken(self: *Tokenizer, score: f32, token: []const u8) !void {
         const arena = self.arena_state.allocator();
-
-        self.addOwnedToken(score, try arena.dupe(u8, token));
+        return self.addOwnedToken(score, try arena.dupe(u8, token));
     }
 
     /// Adds a new token (without copying it)
@@ -142,24 +145,13 @@ pub const Tokenizer = struct {
         const mergeable = try allocator.alloc(MergeState, tok_buff.len);
 
         var num_tokens: usize = 0;
-        var off: usize = 0;
-        while (off < input.len) {
-            const utf_len = try std.unicode.utf8ByteSequenceLength(input[off]);
-            defer off += utf_len;
-
-            mergeable[num_tokens] = .idk;
-            defer num_tokens += 1;
-
-            const char = input[off..][0..utf_len];
-            tok_buff[num_tokens] = self.lookup(char) orelse
-                // TODO: split unknown token into bytes if model supports it
-                self.special_tokens.unk;
-            if (tok_buff[num_tokens] == self.special_tokens.unk) {
-                log.debug("Token not found for char '{s}' (@{x})", .{ char, char });
-            }
-            if (tok_buff[num_tokens] == self.special_tokens.hard_space) {
-                mergeable[num_tokens] = .hard_space;
-            }
+        var it: CharTokenIterator = .{ .input = input };
+        while (try it.nextCodepointToken(self)) |token| : (num_tokens += 1) {
+            tok_buff[num_tokens] = token;
+            mergeable[num_tokens] = if (token == self.special_tokens.hard_space)
+                .hard_space
+            else
+                .idk;
         }
 
         var stable_prefix: usize = 0;
@@ -197,7 +189,6 @@ pub const Tokenizer = struct {
                     continue;
                 },
                 .idk => {
-                    const next_tok = self.tokens[tok_buff[i + 1]];
                     // Special tokens can't be concatenated.
                     if (builtin.mode == .Debug and tok_buff[i] != self.special_tokens.unk) {
@@ -206,7 +197,10 @@ pub const Tokenizer = struct {
                         meta.assert(std.mem.eql(u8, cur_tok, input[input_off..][0..cur_tok.len]), "current token '{s}' not found in input string '{s}' !", .{ cur_tok, input[input_off..] });
                     }
-                    const concat_tokens = input[input_off..][0 .. cur_tok.len + next_tok.len];
+                    const next_tok = self.tokens[tok_buff[i + 1]];
+                    // If `next_tok` is `.unk`, length is 1; otherwise, it's the length of the token.
+                    const next_tok_len = if (tok_buff[i + 1] == self.special_tokens.unk) 1 else next_tok.len;
+                    const concat_tokens = input[input_off..][0 .. cur_tok.len + next_tok_len];
                     // Save the result
                     mergeable[i] = if (self.lookup(concat_tokens)) |tok|
                         .{ .ready = tok }
@@ -310,6 +304,41 @@ pub const Tokenizer = struct {
             if (opts.sep.len > 0) try output.appendSlice(opts.sep);
         }
     }
+
+    /// Some tokenizers have bytes encoded in hex like this: "<0x40>".
+    /// This breaks the tokenization algorithm because the input text
+    /// will contain "@", not "<0x40>",
+    /// and if the input does contain "<0x40>" it must not be treated as a single byte.
+    /// So we replace byte fallback strings with their corresponding characters.
+    /// This enables the normal tokenization algorithm to work.
+    pub fn rewriteByteFallbackTokens(tokenizer: *Tokenizer) !void {
+        tokenizer.byte_fallback = true;
+        var single_bytes = try tokenizer.arena_state.allocator().alloc(u8, 256);
+        var byte_fallback_buf = "<0x00>".*;
+
+        for (0..256) |i| {
+            const c: u8 = @truncate(i);
+            single_bytes[i] = c;
+
+            // First look up the byte fallback entry.
+            // Note: we assume upper case, but we could try both upper and lower case if needed.
+            _ = std.fmt.bufPrintIntToSlice(byte_fallback_buf[3..5], c, 16, .upper, .{ .fill = '0', .width = 2 });
+            const entry = tokenizer.token_lookup.getEntry(&byte_fallback_buf) orelse {
+                log.err("Tokenizer has \"byte_fallback\" = true, but doesn't contain the byte fallback token {s}", .{byte_fallback_buf});
+                return error.InvalidInput;
+            };
+
+            // Check if the character is already present in the vocab.
+            // In that case there is nothing to do,
+            // but note that the fallback token will be "unreachable",
+            // i.e. there is no way the tokenizer can produce it.
+            if (tokenizer.token_lookup.get(&.{c})) |_| continue;
+
+            const idx: u32 = entry.value_ptr.*;
+            tokenizer.token_lookup.removeByPtr(entry.key_ptr);
+            tokenizer.addOwnedTokenByIndex(idx, tokenizer.scores[idx], single_bytes[i .. i + 1]);
+        }
+    }
 };
 
 test Tokenizer {
@@ -332,6 +361,91 @@ test Tokenizer {
     // TODO: test Tokenizer.decode, Tokenizer.encode, Tokenizer.readTokenInto
 }
 
+/// Given a slice, split it into the simplest tokens using the given tokenizer's tokens.
+/// The output of this can be used to initialize the tokenization algorithm.
+/// Normally we split the input text into utf8 codepoints,
+/// but if we find an unknown codepoint we either split it into bytes, or use the special "unknown" token,
+/// depending on the tokenizer configuration.
+const CharTokenIterator = struct {
+    state: union(enum) { by_codepoint, by_byte: u8 } = .by_codepoint,
+    input: []const u8,
+
+    fn nextCodepointToken(self: *CharTokenIterator, tokenizer: *const Tokenizer) error{ TruncatedInput, Utf8InvalidStartByte }!?u32 {
+        if (self.input.len == 0) return null;
+        return switch (self.state) {
+            .by_byte => |*byte_left| {
+                const idx = tokenizer.lookup(self.input[0..1]) orelse {
+                    // Normally this has been caught when calling `rewriteByteFallbackTokens`.
+                    std.debug.panic("Tokenizer has \"byte_fallback\" = true, but doesn't contain the byte fallback token '<0x{X:02}>'", .{self.input[0]});
+                };
+
+                self.input = self.input[1..];
+                byte_left.* -|= 1;
+                if (byte_left.* == 0) self.state = .by_codepoint;
+                return idx;
+            },
+            .by_codepoint => {
+                // Try to look up a valid utf8 codepoint first.
+                const utf8_len = try std.unicode.utf8ByteSequenceLength(self.input[0]);
+                if (self.input.len < utf8_len) return error.TruncatedInput;
+                if (tokenizer.lookup(self.input[0..utf8_len])) |idx| {
+                    self.input = self.input[utf8_len..];
+                    return idx;
+                }
+
+                // Otherwise split into bytes if it's allowed.
+                if (tokenizer.byte_fallback) {
+                    // TODO: replace this with a continue statement next time we bump Zig.
+                    self.state = .{ .by_byte = utf8_len };
+                    return self.nextCodepointToken(tokenizer);
+                }
+
+                // Or mark the full utf8 codepoint as unknown.
+                log.debug("Token not found for char '{s}'", .{self.input[0..utf8_len]});
+                self.input = self.input[utf8_len..];
+                return tokenizer.special_tokens.unk;
+            },
+        };
+    }
+};
+
+test CharTokenIterator {
+    const special_tokens: Tokenizer.SpecialTokens = .{ .unk = 0, .bos = 1, .eos = 2 };
+    var tokenizer = try Tokenizer.init(std.testing.allocator, 16, 4, null, special_tokens, true);
+    defer tokenizer.deinit();
+
+    tokenizer.addOwnedToken(1.0, "<unk>"); // 0
+    tokenizer.addOwnedToken(1.0, "<s>"); // 1
+    tokenizer.addOwnedToken(1.0, "</s>"); // 2
+    tokenizer.addOwnedToken(1.0, "ζ"); // 3
+    tokenizer.addOwnedToken(1.0, &.{0xE2}); // 4: ℳ, first byte
+    tokenizer.addOwnedToken(1.0, &.{0x84}); // 5: ℳ, second byte
+    tokenizer.addOwnedToken(1.0, &.{0xB3}); // 6: ℳ, third byte
+    tokenizer.addOwnedToken(1.0, "L"); // 7
+
+    // No byte fallback
+    {
+        tokenizer.byte_fallback = false;
+        var it: CharTokenIterator = .{ .input = "ζℳL" };
+        var res: std.BoundedArray(u32, 8) = .{};
+        while (try it.nextCodepointToken(&tokenizer)) |token| {
+            res.appendAssumeCapacity(token);
+        }
+        try std.testing.expectEqualSlices(u32, &[_]u32{ 3, 0, 7 }, res.constSlice());
+    }
+
+    // With byte fallback
+    {
+        tokenizer.byte_fallback = true;
+        var it: CharTokenIterator = .{ .input = "ζℳL" };
+        var res: std.BoundedArray(u32, 8) = .{};
+        while (try it.nextCodepointToken(&tokenizer)) |token| {
+            res.appendAssumeCapacity(token);
+        }
+        try std.testing.expectEqualSlices(u32, &[_]u32{ 3, 4, 5, 6, 7 }, res.constSlice());
+    }
+}
+
 /// Text normalizer. Most tokenizer assumes the input text have been prepocessed
 /// with on of those.
 pub const Normalizer = struct {
@@ -613,7 +727,6 @@ pub const Gpt2TextDecoder = struct {
         try self.code_to_byte.ensureTotalCapacity(256);
         errdefer unreachable;
 
-        // The eon
         var n: usize = 0;
         for (0..256) |index| {
             var code: Code = .{ .buffer = .{ 0, 0 }, .len = 0 }; // 0-init
@@ -703,40 +816,88 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
     const vocab = main_object.get("model").?.object.get("vocab").?.object;
     const vocab_size: u32 = @intCast(vocab.count() + added_tokens.items.len);
 
-    // TODO not all tokenizer.json are Gpt2 encoded, detect when it's needed or not.
     const normalizer = Normalizer.wellKnown(.gpt2);
-    var decoder = try Gpt2TextDecoder.init(allocator);
-    defer decoder.deinit();
+    var gpt2_decoder = try Gpt2TextDecoder.init(allocator);
+    defer gpt2_decoder.deinit();
 
     var tokenizer = try Tokenizer.init(allocator, vocab_size, 256, normalizer, undefined, true);
+    errdefer tokenizer.deinit();
 
     // Buffer containing all concatenated tokens.
     // Reserve a big chunk, to avoid grow event, but release over-allocated memory.
-    var all_tokens = try std.ArrayList(u8).initCapacity(tokenizer.arena_state.allocator(), 24 * vocab.count());
-    defer all_tokens.shrinkAndFree(all_tokens.items.len);
+    var all_tokens = try std.ArrayList(u8).initCapacity(tokenizer.arena_state.allocator(), file_content.len);
+    const original_alloc = all_tokens.items.ptr;
+    // A re-alloc event here would mean we have invalidated all slices inside the tokenizer.
+    // If this is too annoying we could switch to a custom type instead of slices.
+    defer {
+        std.debug.assert(all_tokens.items.ptr == original_alloc);
+    }
 
     var it = vocab.iterator();
+    // gpt2-based tokenizers have a special way of encoding unicode.
+    // We don't know in advance whether this tokenizer uses it or not,
+    // so we assume it does, but if we find some unicode character
+    // outside of the range used by gpt2, we know the assumption was wrong and start over.
+    var is_gpt2_vocab: bool = true;
     while (it.next()) |kv| {
-        const token = try decoder.decode(&all_tokens, kv.key_ptr.*);
+        const token = gpt2_decoder.decode(&all_tokens, kv.key_ptr.*) catch |err| {
+            switch (err) {
+                error.InvalidInput => {
+                    is_gpt2_vocab = false;
+                    break;
+                },
+                else => return err,
+            }
+        };
         const idx: u32 = @intCast(kv.value_ptr.*.integer);
-        // std.debug.assert(idx == tokenizer.next_token_id);
         tokenizer.addOwnedTokenByIndex(idx, @floatFromInt(vocab_size - idx), token);
     }
 
-    for (added_tokens.items) |token_obj| {
-        const token = try decoder.decode(&all_tokens, token_obj.object.get("content").?.string);
-        tokenizer.addOwnedTokenByIndex(
-            @intCast(token_obj.object.get("id").?.integer),
-            0,
-            token,
-        );
+    if (!is_gpt2_vocab) {
+        // We were wrong, this is not a gpt2 vocab: start over
+        // and reset the tokenizer state.
+        tokenizer.next_token_id = 0;
+        tokenizer.token_lookup.clearRetainingCapacity();
+        all_tokens.clearRetainingCapacity();
+        it = vocab.iterator();
+        while (it.next()) |kv| {
+            const idx: u32 = @intCast(kv.value_ptr.*.integer);
+            const token = try dup(&all_tokens, kv.key_ptr.*);
+            tokenizer.addOwnedTokenByIndex(idx, @floatFromInt(vocab_size - idx), token);
+        }
     }
 
+    // More tokens, typically added during fine-tuning of the model.
+    for (added_tokens.items) |token_obj| {
+        const v = token_obj.object.get("content").?.string;
+        const id: u32 = @intCast(token_obj.object.get("id").?.integer);
+        const token = try if (is_gpt2_vocab)
+            gpt2_decoder.decode(&all_tokens, v)
+        else
+            dup(&all_tokens, v);
+
+        tokenizer.addOwnedTokenByIndex(id, 0, token);
+    }
+    // We won't add more tokens here, so release the over-allocated memory.
+    all_tokens.shrinkAndFree(all_tokens.items.len);
+
     tokenizer.special_tokens = .{
         .bos = tokenizer.lookup("<s>") orelse tokenizer.lookup("<|begin_of_text|>") orelse @panic("bos token not found !"),
         .eos = tokenizer.lookup("</s>") orelse tokenizer.lookup("<|end_of_text|>") orelse @panic("eos token not found !"),
         .unk = tokenizer.lookup("<unk>") orelse std.math.maxInt(u32),
     };
+    if (main_object.get("model").?.object.get("byte_fallback")) |byte_fallback| {
+        if (byte_fallback == .bool and byte_fallback.bool) {
+            try tokenizer.rewriteByteFallbackTokens();
+        }
+    }
 
     return tokenizer;
 }
+
+/// Returns a copy of the given string, stored inside the given ArrayList.
+fn dup(buffer: *std.ArrayList(u8), str: []const u8) ![]const u8 {
+    const n = buffer.items.len;
+    try buffer.appendSlice(str);
+    return buffer.items[n..];
+}
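Usage sketch (not part of the diff above): one way the new byte fallback path could be exercised from zml/tokenizer.zig, in the same style as `test CharTokenIterator`. The test name, vocabulary contents, scores and sizes below are made up for illustration; the sketch only relies on `addToken`, `rewriteByteFallbackTokens`, `lookup` and `CharTokenIterator` as they appear in this patch.

    test "byte fallback usage sketch" {
        const special_tokens: Tokenizer.SpecialTokens = .{ .unk = 0, .bos = 1, .eos = 2 };
        // 3 special tokens + 256 byte fallback entries; max token length 8 covers "<0xNN>".
        var tokenizer = try Tokenizer.init(std.testing.allocator, 3 + 256, 8, null, special_tokens, true);
        defer tokenizer.deinit();

        try tokenizer.addToken(1.0, "<unk>"); // 0
        try tokenizer.addToken(1.0, "<s>"); // 1
        try tokenizer.addToken(1.0, "</s>"); // 2

        // A sentencepiece-style vocab with byte_fallback carries all 256 "<0xNN>" entries.
        var buf = "<0x00>".*;
        for (0..256) |i| {
            const c: u8 = @truncate(i);
            _ = std.fmt.bufPrintIntToSlice(buf[3..5], c, 16, .upper, .{ .fill = '0', .width = 2 });
            try tokenizer.addToken(1.0, &buf);
        }

        // Rewrites every "<0xNN>" entry into its single byte and enables `byte_fallback`,
        // so the plain character "@" is now found in the vocab.
        try tokenizer.rewriteByteFallbackTokens();
        try std.testing.expect(tokenizer.lookup("@") != null);

        // "ζ" is not in this vocab, so the iterator falls back to its two utf8 bytes.
        var it: CharTokenIterator = .{ .input = "ζ" };
        var count: usize = 0;
        while (try it.nextCodepointToken(&tokenizer)) |_| count += 1;
        try std.testing.expectEqual(@as(usize, 2), count);
    }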