diff --git a/zml/aio/gguf.zig b/zml/aio/gguf.zig
index efe5b37..05a6c87 100644
--- a/zml/aio/gguf.zig
+++ b/zml/aio/gguf.zig
@@ -56,15 +56,10 @@ pub fn getGgufTokenizer(self: zml.aio.BufferStore, allocator: std.mem.Allocator)
         .pad = pad orelse NOT_FOUND,
     };
 
-    // default options
-    const gguf_normalizer: zml.tokenizer.Normalizer = .{ .flags = .{
-        .escape_whitespaces = true,
-        .remove_extra_whitespaces = true,
-        .add_dummy_prefix = true,
-        .add_dummy_suffix = false,
-        .lower_case_ascii = false,
-        .split_on_punct_ascii = false,
-    } };
+    const gguf_normalizer = if (tokenizer_impl == .gpt2)
+        zml.tokenizer.Normalizer.wellKnown(.gpt2)
+    else
+        zml.tokenizer.Normalizer.wellKnown(.sentencepiece);
 
     const extra_tokens: u8 = if (tokenizer_impl == .gpt2) 1 else 0;
     const n_tokens: u32 = @intCast(tokens.len + extra_tokens);
@@ -99,8 +94,6 @@ pub fn getGgufTokenizer(self: zml.aio.BufferStore, allocator: std.mem.Allocator)
 
     // Gpt2 tokenizer always splits on spaces.
     if (tokenizer_impl == .gpt2) {
-        tokenizer.normalizer.?.flags.add_dummy_prefix = true;
-        tokenizer.normalizer.?.flags.escape_whitespaces = false;
         tokenizer.special_tokens.hard_space = tokenizer.next_token_id;
         tokenizer.addOwnedToken(0, " ");
     }
diff --git a/zml/aio/sentencepiece.zig b/zml/aio/sentencepiece.zig
index 3596752..90ee128 100644
--- a/zml/aio/sentencepiece.zig
+++ b/zml/aio/sentencepiece.zig
@@ -75,9 +75,15 @@ pub fn normalizerFromSpec(spec: sentencepiece_proto.NormalizerSpec) Normalizer {
     if (!std.mem.eql(u8, spec.name.?.getSlice(), "identity")) std.debug.panic("Normalizer only supports NormalizerSpec with name \"identity\", got \"{s}\"", .{spec.name.?.getSlice()});
     if (!spec.escape_whitespaces.?) std.debug.panic("Normalizer only supports NormalizerSpec with \"escape_whitespaces\" flag set", .{});
     if (spec.remove_extra_whitespaces) |_| {} else std.debug.panic("Normalizer only supports NormalizerSpec with \"remove_extra_whitespaces\" flag set", .{});
-    if (spec.add_dummy_prefix) |_| {} else std.debug.panic("Normalizer only supports NormalizerSpec with \"add_dummy_prefix\" flag set", .{});
-    return .{ .flags = .{
-        .remove_extra_whitespaces = spec.remove_extra_whitespaces orelse false,
-        .add_dummy_prefix = spec.add_dummy_prefix orelse true,
-    } };
+
+    return Normalizer.init(
+        .{
+            .remove_extra_whitespaces = spec.remove_extra_whitespaces orelse false,
+            .add_dummy_prefix = spec.add_dummy_prefix orelse false,
+            .add_dummy_suffix = false,
+            .lower_case_ascii = false,
+            .split_on_punct_ascii = false,
+        },
+        if (spec.escape_whitespaces orelse false) Normalizer.sentencepiece_space else null,
+    );
 }
diff --git a/zml/tokenizer.zig b/zml/tokenizer.zig
index a034ae6..2e211ea 100644
--- a/zml/tokenizer.zig
+++ b/zml/tokenizer.zig
@@ -128,13 +128,15 @@ pub const Tokenizer = struct {
         add_bos: bool = true,
         add_eos: bool = false,
         pad_to: u32 = 0,
+        // Print tokenization intermediate steps.
+        debug: bool = false,
     };
 
     pub fn encode(self: *const Tokenizer, allocator: std.mem.Allocator, raw: []const u8, options: EncodeOptions) ![]u32 {
-        // log.debug("Tokenizer.encode('{s}')", .{raw});
+        if (options.debug) log.debug("Tokenizer.encode('{s}')", .{raw});
         const input = if (self.normalizer) |n| try n.normalize(allocator, raw) else raw;
         defer if (self.normalizer) |_| allocator.free(input);
-        // log.debug("Tokenizer.encode.normalize -> '{s}'", .{input});
+        if (options.debug) log.debug("Tokenizer.encode.normalize -> '{s}'", .{input});
 
         // Allocate a buffer that can fit all indices as well as extra character if requested.
         // We then slice it so that the token merging code doesn't see the bos token.
@@ -158,7 +160,12 @@ pub const Tokenizer = struct {
         var stable_off: usize = 0;
         while (true) {
             // Step by step visualization of the progress.
-            // log.debug("tokens: {d} -> {s}", .{ tok_buff[0..num_tokens], try self.decodeWithOpts(allocator, tok_buff[0..num_tokens], .{ .sep = "|" }) });
+            if (options.debug) {
+                var _debug_buf: [256]u8 = undefined;
+                var debug_progress: std.ArrayList(u8) = .{ .items = _debug_buf[0..0], .capacity = _debug_buf.len, .allocator = undefined };
+                self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {};
+                log.debug("tokens: {d} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items });
+            }
             var best_score: f32 = -1e10;
             var best_token: u32 = 0;
             var best_idx: ?usize = null;
@@ -273,7 +280,7 @@ pub const Tokenizer = struct {
     /// Converts the given slice of tokens back into bytes.
     /// Note that if the tokenizer allows sub-unicode bytes, it's possible
     /// the output is not valid utf8.
-    pub fn decode(self: *const Tokenizer, allocator: std.mem.Allocator, input: []const u32) ![]u8 {
+    pub fn decode(self: *const Tokenizer, allocator: std.mem.Allocator, input: []const u32) error{OutOfMemory}![]u8 {
         var output = std.ArrayList(u8).init(allocator);
         errdefer output.deinit();
 
@@ -286,18 +293,21 @@ pub const Tokenizer = struct {
         output: *std.ArrayList(u8),
         input: []const u32,
         opts: struct { sep: []const u8 = "" },
-    ) !void {
+    ) error{OutOfMemory}!void {
+        const escaped = if (self.normalizer) |n| n.escapedSpace() else null;
         // Flag used to indicate if the first dummy whitespace has been consumed.
         for (input) |id| {
             // Retrieve the slice corresponding to the id.
             var piece = self.lookupPiece(id);
 
             // Convert `▁` to a regular space.
-            if (std.mem.startsWith(u8, piece, Normalizer.space_symbol)) {
-                piece = piece[Normalizer.space_symbol.len..];
-
-                // don't output a space at beginning of text.
-                if (output.items.len > 0) try output.append(' ');
+            if (escaped) |escspc| {
+                // we modify piece inside the loop, so we can use it in the condition
+                while (std.mem.startsWith(u8, piece, escaped.?)) {
+                    piece = piece[escspc.len..];
+                    // don't output a space at beginning of text.
+                    if (output.items.len > 0) try output.append(' ');
+                }
             }
 
             try output.appendSlice(piece);
@@ -349,7 +359,7 @@ test Tokenizer {
         .eos = 2,
     };
 
-    var tokenizer = try Tokenizer.init(allocator, 10, 5, .{}, special_tokens, true);
+    var tokenizer = try Tokenizer.init(allocator, 10, 5, null, special_tokens, true);
     defer tokenizer.deinit();
 
     try tokenizer.addToken(10, "hello");
@@ -446,23 +456,37 @@ test CharTokenIterator {
     }
 }
 
-/// Text normalizer. Most tokenizer assumes the input text have been prepocessed
-/// with on of those.
+/// Text normalizer.
+/// Most tokenizers assume the input text has been preprocessed with one of these.
 pub const Normalizer = struct {
-    pub const space_symbol = "▁"; // \xe2\x96\x81
+    /// Space token used by sentencepiece-derived tokenizers.
+    pub const sentencepiece_space = "▁"; // \xe2\x96\x81
+
+    _whitespace: std.BoundedArray(u8, 8) = .{},
 
     flags: packed struct {
-        escape_whitespaces: bool = true,
-        remove_extra_whitespaces: bool = true,
-        add_dummy_prefix: bool = true,
-        add_dummy_suffix: bool = false,
+        remove_extra_whitespaces: bool,
+        add_dummy_prefix: bool,
+        add_dummy_suffix: bool,
         /// Cheap lower casing.
         /// TODO: try to match Python "lower"
-        lower_case_ascii: bool = false,
+        lower_case_ascii: bool,
         /// cheap ascii punct splitting.
         // doing this processing ahead of time simplifies the logic
-        split_on_punct_ascii: bool = false,
-    } = .{},
+        split_on_punct_ascii: bool,
+    },
+
+    pub fn init(flags: std.meta.FieldType(Normalizer, .flags), escaped_whitespace: ?[]const u8) Normalizer {
+        var res: Normalizer = .{ .flags = flags };
+        if (escaped_whitespace) |escaped| {
+            res._whitespace.appendSliceAssumeCapacity(escaped);
+        }
+        return res;
+    }
+
+    pub inline fn escapedSpace(self: Normalizer) ?[]const u8 {
+        return if (self._whitespace.len > 1) self._whitespace.constSlice() else null;
+    }
 
     fn addSlice(data: []const u8, consumed: usize, normalized: *std.ArrayList(u8), normalized_to_origin: *std.ArrayList(usize)) !void {
         try normalized.appendSlice(data);
@@ -512,7 +536,7 @@ pub const Normalizer = struct {
         }
 
         // Pre-allocate outputs
-        const space = if (self.flags.escape_whitespaces) Normalizer.space_symbol else " ";
+        const space = self.escapedSpace() orelse " ";
        const overhead = if (self.flags.split_on_punct_ascii) space.len + 1 else space.len;
         var normalized = try std.ArrayList(u8).initCapacity(allocator, trimmed_input.len * overhead + 2 * space.len);
         errdefer normalized.deinit();
@@ -545,7 +569,7 @@ pub const Normalizer = struct {
             if (slice.len == 1) ascii: {
                 // The more advanced logic only works with ascii atm
                 var byte = slice[0];
-                if (self.flags.escape_whitespaces and byte == ' ') {
+                if (self.escapedSpace() != null and byte == ' ') {
                     // replace the space token by the special token
                     try addSlice(space, origin, &normalized, &normalized_to_origin);
                     is_prev_word = false;
@@ -592,26 +616,106 @@ pub const Normalizer = struct {
 
     pub fn wellKnown(impl: KnownImplementation) Normalizer {
         return switch (impl) {
-            .sentencepiece => .{ .flags = .{
-                .escape_whitespaces = true,
+            .sentencepiece => init(.{
                 .remove_extra_whitespaces = true,
                 .add_dummy_prefix = true,
                 .add_dummy_suffix = false,
                 .lower_case_ascii = false,
                 .split_on_punct_ascii = false,
-            } },
-            .gpt2 => .{ .flags = .{
-                .escape_whitespaces = false,
+            }, sentencepiece_space),
+            .gpt2 => init(.{
                 .remove_extra_whitespaces = true,
                 .add_dummy_prefix = true,
                 .add_dummy_suffix = false,
                 .lower_case_ascii = false,
                 .split_on_punct_ascii = false,
-            } },
+            }, null),
         };
     }
 
-};
+    pub fn fromHfJson(config: std.json.ObjectMap) error{InvalidNormalizerJson}!Normalizer {
+        var normalizer: Normalizer = .{ .flags = .{
+            .remove_extra_whitespaces = false,
+            .add_dummy_suffix = false,
+            .add_dummy_prefix = false,
+            .lower_case_ascii = false,
+            .split_on_punct_ascii = false,
+        } };
+
+        // Normalizer config can be a single normalizer, or a sequence of normalizers.
+        const maybe_steps = object_get(config, .array, "normalizers");
+        const steps = if (maybe_steps) |st| st.items else &.{std.json.Value{ .object = config }};
+
+        for (steps) |step_val| {
+            if (step_val != .object) {
+                return error.InvalidNormalizerJson;
+            }
+            const step = step_val.object;
+
+            const step_type = object_get(step, .string, "type") orelse {
+                return error.InvalidNormalizerJson;
+            };
+            if (std.mem.eql(u8, "Prepend", step_type)) {
+                normalizer.flags.add_dummy_prefix = true;
+            } else if (std.mem.eql(u8, "Append", step_type)) {
+                normalizer.flags.add_dummy_suffix = true;
+            } else if (std.mem.eql(u8, "Replace", step_type)) {
+                const pattern = object_get(step, .object, "pattern") orelse return error.InvalidNormalizerJson;
+                const str_pattern = object_get(pattern, .string, "String") orelse return error.InvalidNormalizerJson;
+
+                if (std.mem.eql(u8, str_pattern, " ")) {
+                    normalizer._whitespace.appendSliceAssumeCapacity(
+                        object_get(step, .string, "content") orelse return error.InvalidNormalizerJson,
+                    );
+                } else {
+                    log.warn("Normalizer Replace pattern not supported: '{s}' -> '{s}'", .{ str_pattern, object_get(pattern, .string, "content") orelse "" });
+                }
+            } else {
+                log.warn("Unknown normalizer type: {s}", .{step_type});
+            }
+        }
+
+        return normalizer;
+    }
+
+    test "Normalizer.fromHfJson" {
+        const config_json =
+            \\{
+            \\  "type": "Sequence",
+            \\  "normalizers": [
+            \\    {
+            \\      "type": "Prepend",
+            \\      "prepend": "▁"
+            \\    },
+            \\    {
+            \\      "type": "Replace",
+            \\      "pattern": {
+            \\        "String": " "
+            \\      },
+            \\      "content": "▁"
+            \\    }
+            \\  ]
+            \\}
+        ;
+        var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
+        defer arena.deinit();
+        const config = try std.json.parseFromSliceLeaky(std.json.Value, arena.allocator(), config_json, .{});
+        const normalizer = try Normalizer.fromHfJson(config.object);
+
+        const expected = Normalizer{
+            ._whitespace = .{ .buffer = [_]u8{ 0xe2, 0x96, 0x81 } ++ [_]u8{0} ** 5, .len = 3 },
+            .flags = .{
+                .remove_extra_whitespaces = false,
+                .add_dummy_prefix = true,
+                .add_dummy_suffix = false,
+                .lower_case_ascii = false,
+                .split_on_punct_ascii = false,
+            },
+        };
+        try std.testing.expectEqual(expected.flags, normalizer.flags);
+        try std.testing.expectEqualStrings(expected.escapedSpace().?, normalizer.escapedSpace().?);
+    }
+};
 pub const KnownImplementation = enum(u8) {
     sentencepiece,
     gpt2,
@@ -634,14 +738,13 @@ fn isPunct(unicode_char: []const u8) bool {
 }
 
 test Normalizer {
-    try testing.expectEqualSlices(u8, "▁", Normalizer.space_symbol);
-
     {
         const n: Normalizer = .{ .flags = .{
-            .escape_whitespaces = false,
             .remove_extra_whitespaces = true,
             .add_dummy_prefix = true,
             .add_dummy_suffix = false,
+            .lower_case_ascii = false,
+            .split_on_punct_ascii = false,
         } };
         const res = try n.normalizeWithMapping(testing.allocator, "Hellŏ world!");
         defer res.deinit(testing.allocator);
@@ -657,10 +760,11 @@ test Normalizer {
 
     {
         const n: Normalizer = .{ .flags = .{
-            .escape_whitespaces = false,
             .remove_extra_whitespaces = true,
             .add_dummy_prefix = true,
             .add_dummy_suffix = true,
+            .lower_case_ascii = false,
+            .split_on_punct_ascii = false,
         } };
         const res = try n.normalize(testing.allocator, "Hello world!");
         defer testing.allocator.free(res);
@@ -669,12 +773,16 @@ test Normalizer {
     }
 
     {
-        const n: Normalizer = .{ .flags = .{
-            .escape_whitespaces = true,
-            .remove_extra_whitespaces = false,
-            .add_dummy_prefix = true,
-            .add_dummy_suffix = false,
-        } };
+        const n = Normalizer.init(
+            .{
+                .remove_extra_whitespaces = false,
+                .add_dummy_prefix = true,
+                .add_dummy_suffix = false,
+                .lower_case_ascii = false,
+                .split_on_punct_ascii = false,
+            },
+            Normalizer.sentencepiece_space,
+        );
         const res = try n.normalize(testing.allocator, "Hello world!");
         defer testing.allocator.free(res);
 
@@ -683,11 +791,11 @@ test Normalizer {
 
     {
         const n: Normalizer = .{ .flags = .{
-            .escape_whitespaces = false,
             .remove_extra_whitespaces = true,
             .add_dummy_prefix = false,
             .add_dummy_suffix = true,
             .lower_case_ascii = true,
+            .split_on_punct_ascii = false,
         } };
         const res = try n.normalize(testing.allocator, "Hello world!");
         defer testing.allocator.free(res);
@@ -697,10 +805,10 @@ test Normalizer {
 
     {
         const n: Normalizer = .{ .flags = .{
-            .escape_whitespaces = false,
             .remove_extra_whitespaces = true,
             .add_dummy_prefix = false,
             .add_dummy_suffix = true,
+            .lower_case_ascii = false,
             .split_on_punct_ascii = true,
         } };
         const res = try n.normalize(testing.allocator, "Hello world!");
@@ -795,15 +903,14 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
     const file = try std.fs.cwd().openFile(tokenizer_path, .{});
     defer file.close();
 
-    const file_content = try file.readToEndAlloc(allocator, 32 * 1024 * 1024);
-    defer allocator.free(file_content);
-    // TODO create local arena and use parseFromSliceLeaky.
-    const parsed = try std.json.parseFromSlice(std.json.Value, allocator, file_content, .{
+    var arena_state = std.heap.ArenaAllocator.init(allocator);
+    defer arena_state.deinit();
+    const arena = arena_state.allocator();
+    const file_content = try file.readToEndAlloc(arena, 32 * 1024 * 1024);
+
+    const info = try std.json.parseFromSliceLeaky(std.json.Value, arena, file_content, .{
        .duplicate_field_behavior = .use_last,
     });
-    defer parsed.deinit();
-    const info = parsed.value;
-
     const main_object = switch (info) {
         .object => |obj| if (obj.get("added_tokens") == null or obj.get("model") == null) {
             return error.InvalidFormat;
@@ -816,10 +923,12 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
     const vocab = main_object.get("model").?.object.get("vocab").?.object;
     const vocab_size: u32 = @intCast(vocab.count() + added_tokens.items.len);
 
-    const normalizer = Normalizer.wellKnown(.gpt2);
-    var gpt2_decoder = try Gpt2TextDecoder.init(allocator);
-    defer gpt2_decoder.deinit();
+    const normalizer = if (object_get(main_object, .object, "normalizer")) |normalizer_config|
+        try Normalizer.fromHfJson(normalizer_config)
+    else
+        Normalizer.wellKnown(.gpt2);
 
+    // delay init of special tokens.
     var tokenizer = try Tokenizer.init(allocator, vocab_size, 256, normalizer, undefined, true);
     errdefer tokenizer.deinit();
 
@@ -833,12 +942,14 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
         std.debug.assert(all_tokens.items.ptr == original_alloc);
     }
 
-    var it = vocab.iterator();
     // gpt2 based tokenizer got a special way of encoding unicode.
     // we don't know in advance if this will be used by this tokenizer or not.
     // so we assume it is the case, but if we find some unicode character,
     // outside of the range used by gpt2 we know it was wrong, and start over.
     var is_gpt2_vocab: bool = true;
+    var gpt2_decoder = try Gpt2TextDecoder.init(allocator);
+    defer gpt2_decoder.deinit();
+    var it = vocab.iterator();
     while (it.next()) |kv| {
         const token = gpt2_decoder.decode(&all_tokens, kv.key_ptr.*) catch |err| {
             switch (err) {
@@ -901,3 +1012,14 @@ fn dup(buffer: *std.ArrayList(u8), str: []const u8) ![]const u8 {
     try buffer.appendSlice(str);
     return buffer.items[n..];
 }
+
+/// Returns the given entry in a json object only if it has the right type.
+fn object_get(
+    object: std.json.ObjectMap,
+    comptime kind: std.meta.FieldEnum(std.json.Value),
+    key: []const u8,
+) ?std.meta.FieldType(std.json.Value, kind) {
+    const val = object.get(key) orelse return null;
+    if (val != kind) return null;
+    return @field(val, @tagName(kind));
+}
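
For reviewers, a minimal usage sketch (not part of the patch) of the reworked Normalizer API introduced above. It assumes it would live inside zml/tokenizer.zig next to the existing `test Normalizer` block, so `std`, `testing`, and `Normalizer` refer to that file's declarations; the test name and the exact normalized output shown in the comment are illustrative only.

test "Normalizer.init sketch" {
    // Same configuration as Normalizer.wellKnown(.sentencepiece): escape spaces
    // with the sentencepiece marker and prepend a dummy prefix.
    const n = Normalizer.init(.{
        .remove_extra_whitespaces = true,
        .add_dummy_prefix = true,
        .add_dummy_suffix = false,
        .lower_case_ascii = false,
        .split_on_punct_ascii = false,
    }, Normalizer.sentencepiece_space);

    // escapedSpace() only returns a slice when an escape sequence was configured;
    // the gpt2 preset passes null, so it reports none.
    try testing.expectEqualStrings(Normalizer.sentencepiece_space, n.escapedSpace().?);
    try testing.expect(Normalizer.wellKnown(.gpt2).escapedSpace() == null);

    const res = try n.normalize(testing.allocator, "Hello world");
    defer testing.allocator.free(res);
    // With add_dummy_prefix set and an escape marker configured, the output is
    // expected to start with the marker (e.g. "▁Hello▁world").
    try testing.expect(std.mem.startsWith(u8, res, Normalizer.sentencepiece_space));
}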