Add Normalizer.fromHfJson to read HuggingFace tokenizer JSON and map it to the internal normalizer options, including a configurable magic space token and a debug flag for tracing token merges. Adjust the default handling of extra whitespace to align with HF defaults.

This commit is contained in:
Tarry Singh 2023-03-29 16:10:29 +00:00
parent ef922e3aea
commit 05d23beb23
3 changed files with 190 additions and 69 deletions
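
For orientation, a minimal sketch of the changed entry point (the import path here is hypothetical; fromHfJson and its signature appear in the last file of this diff):

    const std = @import("std");
    const tokenizer = @import("tokenizer.zig"); // hypothetical import path

    pub fn demo(allocator: std.mem.Allocator) !void {
        // The loader now derives the Normalizer from the "normalizer" object in
        // tokenizer.json instead of always assuming the gpt2 profile.
        var tok = try tokenizer.fromHfJson(allocator, "tokenizer.json");
        defer tok.deinit();
    }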

View File

@@ -56,15 +56,10 @@ pub fn getGgufTokenizer(self: zml.aio.BufferStore, allocator: std.mem.Allocator)
         .pad = pad orelse NOT_FOUND,
     };

-    // default options
-    const gguf_normalizer: zml.tokenizer.Normalizer = .{ .flags = .{
-        .escape_whitespaces = true,
-        .remove_extra_whitespaces = true,
-        .add_dummy_prefix = true,
-        .add_dummy_suffix = false,
-        .lower_case_ascii = false,
-        .split_on_punct_ascii = false,
-    } };
+    const gguf_normalizer = if (tokenizer_impl == .gpt2)
+        zml.tokenizer.Normalizer.wellKnown(.gpt2)
+    else
+        zml.tokenizer.Normalizer.wellKnown(.sentencepiece);

     const extra_tokens: u8 = if (tokenizer_impl == .gpt2) 1 else 0;
     const n_tokens: u32 = @intCast(tokens.len + extra_tokens);
@@ -99,8 +94,6 @@ pub fn getGgufTokenizer(self: zml.aio.BufferStore, allocator: std.mem.Allocator)
     // Gpt2 tokenizer always splits on spaces.
     if (tokenizer_impl == .gpt2) {
-        tokenizer.normalizer.?.flags.add_dummy_prefix = true;
-        tokenizer.normalizer.?.flags.escape_whitespaces = false;
         tokenizer.special_tokens.hard_space = tokenizer.next_token_id;
         tokenizer.addOwnedToken(0, " ");
     }
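
The hand-rolled default flags and the gpt2-specific flag mutation are now folded into the two wellKnown profiles defined later in this commit. A sketch of what each call selects, mirroring the wellKnown bodies below:

    // .sentencepiece: escapes " " to the "▁" marker and adds a dummy prefix.
    const sp = zml.tokenizer.Normalizer.wellKnown(.sentencepiece);
    // .gpt2: same flags but no space escaping; the loader still registers a
    // dedicated hard_space token for it, as in the hunk above.
    const gpt2 = zml.tokenizer.Normalizer.wellKnown(.gpt2);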

View File

@@ -75,9 +75,15 @@ pub fn normalizerFromSpec(spec: sentencepiece_proto.NormalizerSpec) Normalizer {
     if (!std.mem.eql(u8, spec.name.?.getSlice(), "identity")) std.debug.panic("Normalizer only supports NormalizerSpec with name \"identity\", got \"{s}\"", .{spec.name.?.getSlice()});
     if (!spec.escape_whitespaces.?) std.debug.panic("Normalizer only supports NormalizerSpec with \"escape_whitespaces\" flag set", .{});
     if (spec.remove_extra_whitespaces) |_| {} else std.debug.panic("Normalizer only supports NormalizerSpec with \"remove_extra_whitespaces\" flag set", .{});
-    if (spec.add_dummy_prefix) |_| {} else std.debug.panic("Normalizer only supports NormalizerSpec with \"add_dummy_prefix\" flag set", .{});
-    return .{ .flags = .{
-        .remove_extra_whitespaces = spec.remove_extra_whitespaces orelse false,
-        .add_dummy_prefix = spec.add_dummy_prefix orelse true,
-    } };
+    return Normalizer.init(
+        .{
+            .remove_extra_whitespaces = spec.remove_extra_whitespaces orelse false,
+            .add_dummy_prefix = spec.add_dummy_prefix orelse false,
+            .add_dummy_suffix = false,
+            .lower_case_ascii = false,
+            .split_on_punct_ascii = false,
+        },
+        if (spec.escape_whitespaces orelse false) Normalizer.sentencepiece_space else null,
+    );
 }
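
Normalizer.init (introduced below) replaces the struct-literal construction: it takes the full flag set plus an optional escaped-space string. For a spec with escape_whitespaces set, the call above reduces to something like this sketch:

    const n = Normalizer.init(.{
        .remove_extra_whitespaces = true, // taken from spec.remove_extra_whitespaces
        .add_dummy_prefix = false, // an unset spec field now falls back to false
        .add_dummy_suffix = false,
        .lower_case_ascii = false,
        .split_on_punct_ascii = false,
    }, Normalizer.sentencepiece_space); // non-null because escape_whitespaces is set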

View File

@@ -128,13 +128,15 @@ pub const Tokenizer = struct {
         add_bos: bool = true,
         add_eos: bool = false,
         pad_to: u32 = 0,
+        // Print tokenization intermediary steps.
+        debug: bool = false,
     };

     pub fn encode(self: *const Tokenizer, allocator: std.mem.Allocator, raw: []const u8, options: EncodeOptions) ![]u32 {
-        // log.debug("Tokenizer.encode('{s}')", .{raw});
+        if (options.debug) log.debug("Tokenizer.encode('{s}')", .{raw});
         const input = if (self.normalizer) |n| try n.normalize(allocator, raw) else raw;
         defer if (self.normalizer) |_| allocator.free(input);
-        // log.debug("Tokenizer.encode.normalize -> '{s}'", .{input});
+        if (options.debug) log.debug("Tokenizer.encode.normalize -> '{s}'", .{input});

         // Allocate a buffer that can fit all indices as well as extra character if requested.
         // We then slice it so that the token merging code doesn't see the bos token.
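
The previously commented-out log statements become real tracing gated on the new per-call flag. Enabling it is a one-liner at the call site (sketch; `tokenizer` is any initialized Tokenizer):

    // Logs the raw input, its normalized form, and (below) every merge step.
    const ids = try tokenizer.encode(allocator, "Hello world!", .{ .debug = true });
    defer allocator.free(ids);
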
@@ -158,7 +160,12 @@ pub const Tokenizer = struct {
         var stable_off: usize = 0;
         while (true) {
             // Step by step visualization of the progress.
-            // log.debug("tokens: {d} -> {s}", .{ tok_buff[0..num_tokens], try self.decodeWithOpts(allocator, tok_buff[0..num_tokens], .{ .sep = "|" }) });
+            if (options.debug) {
+                var _debug_buf: [256]u8 = undefined;
+                var debug_progress: std.ArrayList(u8) = .{ .items = _debug_buf[0..0], .capacity = _debug_buf.len, .allocator = undefined };
+                self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {};
+                log.debug("tokens: {d} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items });
+            }
             var best_score: f32 = -1e10;
             var best_token: u32 = 0;
             var best_idx: ?usize = null;
@@ -273,7 +280,7 @@ pub const Tokenizer = struct {
     /// Converts the given slice of tokens back into bytes.
     /// Note that if the tokenizer allows sub-unicode bytes, it's possible
     /// the output is not valid utf8.
-    pub fn decode(self: *const Tokenizer, allocator: std.mem.Allocator, input: []const u32) ![]u8 {
+    pub fn decode(self: *const Tokenizer, allocator: std.mem.Allocator, input: []const u32) error{OutOfMemory}![]u8 {
         var output = std.ArrayList(u8).init(allocator);
         errdefer output.deinit();
@@ -286,18 +293,21 @@ pub const Tokenizer = struct {
         output: *std.ArrayList(u8),
         input: []const u32,
         opts: struct { sep: []const u8 = "" },
-    ) !void {
+    ) error{OutOfMemory}!void {
+        const escaped = if (self.normalizer) |n| n.escapedSpace() else null;
         // Flag used to indicate if the first dummy whitespace has been consumed.
         for (input) |id| {
             // Retrieve the slice corresponding to the id.
             var piece = self.lookupPiece(id);
             // Convert `▁` to a regular space.
-            if (std.mem.startsWith(u8, piece, Normalizer.space_symbol)) {
-                piece = piece[Normalizer.space_symbol.len..];
-                // don't output a space at beginning of text.
-                if (output.items.len > 0) try output.append(' ');
+            if (escaped) |escspc| {
+                // we modify piece inside the loop, so we can use it in the condition
+                while (std.mem.startsWith(u8, piece, escaped.?)) {
+                    piece = piece[escspc.len..];
+                    // don't output a space at beginning of text.
+                    if (output.items.len > 0) try output.append(' ');
+                }
             }
             try output.appendSlice(piece);
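
The `if` around the escaped-space prefix became a `while`: a piece may start with several markers (e.g. "▁▁foo" encodes two leading spaces), and each one now yields a real space. The loop's effect in isolation (standalone sketch):

    var piece: []const u8 = "▁▁foo";
    const escspc = "▁"; // same bytes as Normalizer.sentencepiece_space
    var spaces: usize = 0;
    while (std.mem.startsWith(u8, piece, escspc)) {
        piece = piece[escspc.len..];
        spaces += 1;
    }
    // spaces == 2, piece == "foo"
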
@@ -349,7 +359,7 @@ test Tokenizer {
         .eos = 2,
     };

-    var tokenizer = try Tokenizer.init(allocator, 10, 5, .{}, special_tokens, true);
+    var tokenizer = try Tokenizer.init(allocator, 10, 5, null, special_tokens, true);
     defer tokenizer.deinit();

     try tokenizer.addToken(10, "hello");
@@ -446,23 +456,37 @@ test CharTokenIterator {
     }
 }

-/// Text normalizer. Most tokenizer assumes the input text have been prepocessed
-/// with on of those.
+/// Text normalizer.
+/// Most tokenizers assume the input text has been preprocessed with one of those.
 pub const Normalizer = struct {
-    pub const space_symbol = "▁"; // \xe2\x96\x81
+    /// Space token used by sentencepiece derived tokenizers.
+    pub const sentencepiece_space = "▁"; // \xe2\x96\x81
+
+    _whitespace: std.BoundedArray(u8, 8) = .{},

     flags: packed struct {
-        escape_whitespaces: bool = true,
-        remove_extra_whitespaces: bool = true,
-        add_dummy_prefix: bool = true,
-        add_dummy_suffix: bool = false,
+        remove_extra_whitespaces: bool,
+        add_dummy_prefix: bool,
+        add_dummy_suffix: bool,
         /// Cheap lower casing.
         /// TODO: try to match Python "lower"
-        lower_case_ascii: bool = false,
+        lower_case_ascii: bool,
         /// cheap ascii punct splitting.
         // doing this processing ahead of time simplifies the logic
-        split_on_punct_ascii: bool = false,
-    } = .{},
+        split_on_punct_ascii: bool,
+    },
+
+    pub fn init(flags: std.meta.FieldType(Normalizer, .flags), escaped_whitespace: ?[]const u8) Normalizer {
+        var res: Normalizer = .{ .flags = flags };
+        if (escaped_whitespace) |escaped| {
+            res._whitespace.appendSliceAssumeCapacity(escaped);
+        }
+        return res;
+    }
+
+    pub inline fn escapedSpace(self: Normalizer) ?[]const u8 {
+        return if (self._whitespace.len > 1) self._whitespace.constSlice() else null;
+    }

     fn addSlice(data: []const u8, consumed: usize, normalized: *std.ArrayList(u8), normalized_to_origin: *std.ArrayList(usize)) !void {
         try normalized.appendSlice(data);
@@ -512,7 +536,7 @@ pub const Normalizer = struct {
         }

         // Pre-allocate outputs
-        const space = if (self.flags.escape_whitespaces) Normalizer.space_symbol else " ";
+        const space = self.escapedSpace() orelse " ";
         const overhead = if (self.flags.split_on_punct_ascii) space.len + 1 else space.len;
         var normalized = try std.ArrayList(u8).initCapacity(allocator, trimmed_input.len * overhead + 2 * space.len);
         errdefer normalized.deinit();
@@ -545,7 +569,7 @@ pub const Normalizer = struct {
         if (slice.len == 1) ascii: {
             // The more advanced logic only works with ascii atm
             var byte = slice[0];
-            if (self.flags.escape_whitespaces and byte == ' ') {
+            if (self.escapedSpace() != null and byte == ' ') {
                 // replace the space token by the special token
                 try addSlice(space, origin, &normalized, &normalized_to_origin);
                 is_prev_word = false;
@@ -592,26 +616,106 @@ pub const Normalizer = struct {
     pub fn wellKnown(impl: KnownImplementation) Normalizer {
         return switch (impl) {
-            .sentencepiece => .{ .flags = .{
-                .escape_whitespaces = true,
+            .sentencepiece => init(.{
                 .remove_extra_whitespaces = true,
                 .add_dummy_prefix = true,
                 .add_dummy_suffix = false,
                 .lower_case_ascii = false,
                 .split_on_punct_ascii = false,
-            } },
-            .gpt2 => .{ .flags = .{
-                .escape_whitespaces = false,
+            }, sentencepiece_space),
+            .gpt2 => init(.{
                 .remove_extra_whitespaces = true,
                 .add_dummy_prefix = true,
                 .add_dummy_suffix = false,
                 .lower_case_ascii = false,
                 .split_on_punct_ascii = false,
-            } },
+            }, null),
         };
     }
+
+    pub fn fromHfJson(config: std.json.ObjectMap) error{InvalidNormalizerJson}!Normalizer {
+        var normalizer: Normalizer = .{ .flags = .{
+            .remove_extra_whitespaces = false,
+            .add_dummy_suffix = false,
+            .add_dummy_prefix = false,
+            .lower_case_ascii = false,
+            .split_on_punct_ascii = false,
+        } };
+        // Normalizer config can be a single normalizer, or a sequence of normalizers.
+        const maybe_steps = object_get(config, .array, "normalizers");
+        const steps = if (maybe_steps) |st| st.items else &.{std.json.Value{ .object = config }};
+        for (steps) |step_val| {
+            if (step_val != .object) {
+                return error.InvalidNormalizerJson;
+            }
+            const step = step_val.object;
+            const step_type = object_get(step, .string, "type") orelse {
+                return error.InvalidNormalizerJson;
+            };
+            if (std.mem.eql(u8, "Prepend", step_type)) {
+                normalizer.flags.add_dummy_prefix = true;
+            } else if (std.mem.eql(u8, "Append", step_type)) {
+                normalizer.flags.add_dummy_suffix = true;
+            } else if (std.mem.eql(u8, "Replace", step_type)) {
+                const pattern = object_get(step, .object, "pattern") orelse return error.InvalidNormalizerJson;
+                const str_pattern = object_get(pattern, .string, "String") orelse return error.InvalidNormalizerJson;
+                if (std.mem.eql(u8, str_pattern, " ")) {
+                    normalizer._whitespace.appendSliceAssumeCapacity(
+                        object_get(step, .string, "content") orelse return error.InvalidNormalizerJson,
+                    );
+                } else {
+                    log.warn("Normalizer Replace pattern not supported: '{s}' -> '{s}'", .{ str_pattern, object_get(step, .string, "content") orelse "" });
+                }
+            } else {
+                log.warn("Unknown normalizer type: {s}", .{step_type});
+            }
+        }
+        return normalizer;
+    }
+
+    test "Normalizer.fromHfJson" {
+        const config_json =
+            \\{
+            \\  "type": "Sequence",
+            \\  "normalizers": [
+            \\    {
+            \\      "type": "Prepend",
+            \\      "prepend": "▁"
+            \\    },
+            \\    {
+            \\      "type": "Replace",
+            \\      "pattern": {
+            \\        "String": " "
+            \\      },
+            \\      "content": "▁"
+            \\    }
+            \\  ]
+            \\}
+        ;
+        var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
+        defer arena.deinit();
+        const config = try std.json.parseFromSliceLeaky(std.json.Value, arena.allocator(), config_json, .{});
+        const normalizer = try Normalizer.fromHfJson(config.object);
+        const expected = Normalizer{
+            ._whitespace = .{ .buffer = [_]u8{ 0xe2, 0x96, 0x81 } ++ [_]u8{0} ** 5, .len = 3 },
+            .flags = .{
+                .remove_extra_whitespaces = false,
+                .add_dummy_prefix = true,
+                .add_dummy_suffix = false,
+                .lower_case_ascii = false,
+                .split_on_punct_ascii = false,
+            },
+        };
+        try std.testing.expectEqual(expected.flags, normalizer.flags);
+        try std.testing.expectEqualStrings(expected.escapedSpace().?, normalizer.escapedSpace().?);
+    }
 };

 pub const KnownImplementation = enum(u8) {
     sentencepiece,
     gpt2,
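
Two details of fromHfJson are worth noting: the step mapping is deliberately small ("Prepend" sets add_dummy_prefix, "Append" sets add_dummy_suffix, and a Replace whose pattern is a single space stores the replacement as the escaped-space token), and every flag starts out false, so remove_extra_whitespaces now defaults off for HF configs while the sentencepiece profile keeps it on. A sketch of driving it directly, assuming `arena` is an arena allocator and `config_json` holds the "normalizer" object of a tokenizer.json:

    const parsed = try std.json.parseFromSliceLeaky(std.json.Value, arena, config_json, .{});
    const n = try Normalizer.fromHfJson(parsed.object);
    // HF default: extra whitespace is preserved unless a step says otherwise.
    std.debug.assert(!n.flags.remove_extra_whitespaces);
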
@@ -634,14 +738,13 @@ fn isPunct(unicode_char: []const u8) bool {
 }

 test Normalizer {
-    try testing.expectEqualSlices(u8, "▁", Normalizer.space_symbol);
     {
         const n: Normalizer = .{ .flags = .{
-            .escape_whitespaces = false,
             .remove_extra_whitespaces = true,
             .add_dummy_prefix = true,
             .add_dummy_suffix = false,
+            .lower_case_ascii = false,
+            .split_on_punct_ascii = false,
         } };
         const res = try n.normalizeWithMapping(testing.allocator, "Hellŏ world!");
         defer res.deinit(testing.allocator);
@@ -657,10 +760,11 @@ test Normalizer {
     {
         const n: Normalizer = .{ .flags = .{
-            .escape_whitespaces = false,
             .remove_extra_whitespaces = true,
             .add_dummy_prefix = true,
             .add_dummy_suffix = true,
+            .lower_case_ascii = false,
+            .split_on_punct_ascii = false,
        } };
         const res = try n.normalize(testing.allocator, "Hello world!");
         defer testing.allocator.free(res);
@@ -669,12 +773,16 @@ test Normalizer {
     }

     {
-        const n: Normalizer = .{ .flags = .{
-            .escape_whitespaces = true,
-            .remove_extra_whitespaces = false,
-            .add_dummy_prefix = true,
-            .add_dummy_suffix = false,
-        } };
+        const n = Normalizer.init(
+            .{
+                .remove_extra_whitespaces = false,
+                .add_dummy_prefix = true,
+                .add_dummy_suffix = false,
+                .lower_case_ascii = false,
+                .split_on_punct_ascii = false,
+            },
+            Normalizer.sentencepiece_space,
+        );
         const res = try n.normalize(testing.allocator, "Hello world!");
         defer testing.allocator.free(res);
@@ -683,11 +791,11 @@ test Normalizer {
     {
         const n: Normalizer = .{ .flags = .{
-            .escape_whitespaces = false,
             .remove_extra_whitespaces = true,
             .add_dummy_prefix = false,
             .add_dummy_suffix = true,
             .lower_case_ascii = true,
+            .split_on_punct_ascii = false,
         } };
         const res = try n.normalize(testing.allocator, "Hello world!");
         defer testing.allocator.free(res);
@@ -697,10 +805,10 @@ test Normalizer {
     {
         const n: Normalizer = .{ .flags = .{
-            .escape_whitespaces = false,
             .remove_extra_whitespaces = true,
             .add_dummy_prefix = false,
             .add_dummy_suffix = true,
+            .lower_case_ascii = false,
             .split_on_punct_ascii = true,
         } };
         const res = try n.normalize(testing.allocator, "Hello world!");
@@ -795,15 +903,14 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tokenizer {
     const file = try std.fs.cwd().openFile(tokenizer_path, .{});
     defer file.close();

-    const file_content = try file.readToEndAlloc(allocator, 32 * 1024 * 1024);
-    defer allocator.free(file_content);
-    // TODO create local arena and use parseFromSliceLeaky.
-    const parsed = try std.json.parseFromSlice(std.json.Value, allocator, file_content, .{
+    var arena_state = std.heap.ArenaAllocator.init(allocator);
+    defer arena_state.deinit();
+    const arena = arena_state.allocator();
+    const file_content = try file.readToEndAlloc(arena, 32 * 1024 * 1024);
+    const info = try std.json.parseFromSliceLeaky(std.json.Value, arena, file_content, .{
         .duplicate_field_behavior = .use_last,
     });
-    defer parsed.deinit();
-    const info = parsed.value;

     const main_object = switch (info) {
         .object => |obj| if (obj.get("added_tokens") == null or obj.get("model") == null) {
             return error.InvalidFormat;
@@ -816,10 +923,12 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tokenizer {
     const vocab = main_object.get("model").?.object.get("vocab").?.object;
     const vocab_size: u32 = @intCast(vocab.count() + added_tokens.items.len);

-    const normalizer = Normalizer.wellKnown(.gpt2);
-    var gpt2_decoder = try Gpt2TextDecoder.init(allocator);
-    defer gpt2_decoder.deinit();
+    const normalizer = if (object_get(main_object, .object, "normalizer")) |normalizer_config|
+        try Normalizer.fromHfJson(normalizer_config)
+    else
+        Normalizer.wellKnown(.gpt2);

+    // delay init of special tokens.
     var tokenizer = try Tokenizer.init(allocator, vocab_size, 256, normalizer, undefined, true);
     errdefer tokenizer.deinit();
@@ -833,12 +942,14 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tokenizer {
         std.debug.assert(all_tokens.items.ptr == original_alloc);
     }

-    var it = vocab.iterator();
     // gpt2 based tokenizer got a special way of encoding unicode.
     // we don't know in advance if this will be used by this tokenizer or not.
     // so we assume it is the case, but if we find some unicode character,
     // outside of the range used by gpt2 we know it was wrong, and start over.
     var is_gpt2_vocab: bool = true;
+    var gpt2_decoder = try Gpt2TextDecoder.init(allocator);
+    defer gpt2_decoder.deinit();
+    var it = vocab.iterator();
     while (it.next()) |kv| {
         const token = gpt2_decoder.decode(&all_tokens, kv.key_ptr.*) catch |err| {
             switch (err) {
@@ -901,3 +1012,14 @@ fn dup(buffer: *std.ArrayList(u8), str: []const u8) ![]const u8 {
     try buffer.appendSlice(str);
     return buffer.items[n..];
 }
+
+/// Returns the given entry in a json object only if it has the right type.
+fn object_get(
+    object: std.json.ObjectMap,
+    comptime kind: std.meta.FieldEnum(std.json.Value),
+    key: []const u8,
+) ?std.meta.FieldType(std.json.Value, kind) {
+    const val = object.get(key) orelse return null;
+    if (val != kind) return null;
+    return @field(val, @tagName(kind));
+}
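
object_get checks the value's tag before unwrapping, so a missing key and a key holding the wrong JSON type both collapse to null. A usage sketch mirroring the call sites above:

    // Yields the ObjectMap only if "normalizer" exists and is a JSON object;
    // a string or array stored under that key also comes back as null.
    if (object_get(main_object, .object, "normalizer")) |normalizer_config| {
        const n = try Normalizer.fromHfJson(normalizer_config);
        _ = n;
    }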