Add Normalizer.fromHfJson to read HuggingFace tokenizer JSON

Map the HF normalizer config onto the internal Normalizer options, make the
magic space token configurable (rather than hard-coding "▁"), and add a debug
flag that logs intermediary token merges during encode. Default handling of
extra whitespaces is adjusted to align with HF defaults.
parent ef922e3aea
commit 05d23beb23
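
A minimal usage sketch (not part of the diff below; it assumes the file-level
loader is exposed as zml.tokenizer.fromHfJson, that a HF tokenizer.json is on
disk, and that the wrapping main function is hypothetical):

    const std = @import("std");
    const zml = @import("zml");

    pub fn main() !void {
        var gpa = std.heap.GeneralPurposeAllocator(.{}){};
        defer _ = gpa.deinit();
        const allocator = gpa.allocator();

        // fromHfJson parses tokenizer.json; a "normalizer" section is mapped by
        // Normalizer.fromHfJson ("Prepend" -> add_dummy_prefix, "Append" ->
        // add_dummy_suffix, a Replace of " " -> the configurable escaped space
        // token). Without that section, the gpt2 defaults apply.
        var tokenizer = try zml.tokenizer.fromHfJson(allocator, "tokenizer.json");
        defer tokenizer.deinit();

        // The new EncodeOptions.debug flag logs the intermediary token merges.
        const ids = try tokenizer.encode(allocator, "Hello world!", .{ .debug = true });
        defer allocator.free(ids);
        std.debug.print("{any}\n", .{ids});
    }
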
@@ -56,15 +56,10 @@ pub fn getGgufTokenizer(self: zml.aio.BufferStore, allocator: std.mem.Allocator)
         .pad = pad orelse NOT_FOUND,
     };
 
-    // default options
-    const gguf_normalizer: zml.tokenizer.Normalizer = .{ .flags = .{
-        .escape_whitespaces = true,
-        .remove_extra_whitespaces = true,
-        .add_dummy_prefix = true,
-        .add_dummy_suffix = false,
-        .lower_case_ascii = false,
-        .split_on_punct_ascii = false,
-    } };
+    const gguf_normalizer = if (tokenizer_impl == .gpt2)
+        zml.tokenizer.Normalizer.wellKnown(.gpt2)
+    else
+        zml.tokenizer.Normalizer.wellKnown(.sentencepiece);
 
     const extra_tokens: u8 = if (tokenizer_impl == .gpt2) 1 else 0;
     const n_tokens: u32 = @intCast(tokens.len + extra_tokens);
@@ -99,8 +94,6 @@ pub fn getGgufTokenizer(self: zml.aio.BufferStore, allocator: std.mem.Allocator)
 
     // Gpt2 tokenizer always splits on spaces.
     if (tokenizer_impl == .gpt2) {
-        tokenizer.normalizer.?.flags.add_dummy_prefix = true;
-        tokenizer.normalizer.?.flags.escape_whitespaces = false;
         tokenizer.special_tokens.hard_space = tokenizer.next_token_id;
         tokenizer.addOwnedToken(0, " ");
     }
@@ -75,9 +75,15 @@ pub fn normalizerFromSpec(spec: sentencepiece_proto.NormalizerSpec) Normalizer {
     if (!std.mem.eql(u8, spec.name.?.getSlice(), "identity")) std.debug.panic("Normalizer only supports NormalizerSpec with name \"identity\", got \"{s}\"", .{spec.name.?.getSlice()});
     if (!spec.escape_whitespaces.?) std.debug.panic("Normalizer only supports NormalizerSpec with \"escape_whitespaces\" flag set", .{});
     if (spec.remove_extra_whitespaces) |_| {} else std.debug.panic("Normalizer only supports NormalizerSpec with \"remove_extra_whitespaces\" flag set", .{});
-    if (spec.add_dummy_prefix) |_| {} else std.debug.panic("Normalizer only supports NormalizerSpec with \"add_dummy_prefix\" flag set", .{});
-    return .{ .flags = .{
-        .remove_extra_whitespaces = spec.remove_extra_whitespaces orelse false,
-        .add_dummy_prefix = spec.add_dummy_prefix orelse true,
-    } };
+    return Normalizer.init(
+        .{
+            .remove_extra_whitespaces = spec.remove_extra_whitespaces orelse false,
+            .add_dummy_prefix = spec.add_dummy_prefix orelse false,
+            .add_dummy_suffix = false,
+            .lower_case_ascii = false,
+            .split_on_punct_ascii = false,
+        },
+        if (spec.escape_whitespaces orelse false) Normalizer.sentencepiece_space else null,
+    );
 }
@@ -128,13 +128,15 @@ pub const Tokenizer = struct {
         add_bos: bool = true,
         add_eos: bool = false,
         pad_to: u32 = 0,
+        // Print tokenization intermediary steps.
+        debug: bool = false,
     };
 
     pub fn encode(self: *const Tokenizer, allocator: std.mem.Allocator, raw: []const u8, options: EncodeOptions) ![]u32 {
-        // log.debug("Tokenizer.encode('{s}')", .{raw});
+        if (options.debug) log.debug("Tokenizer.encode('{s}')", .{raw});
         const input = if (self.normalizer) |n| try n.normalize(allocator, raw) else raw;
         defer if (self.normalizer) |_| allocator.free(input);
-        // log.debug("Tokenizer.encode.normalize -> '{s}'", .{input});
+        if (options.debug) log.debug("Tokenizer.encode.normalize -> '{s}'", .{input});
 
         // Allocate a buffer that can fit all indices as well as extra character if requested.
         // We then slice it so that the token merging code doesn't see the bos token.
@@ -158,7 +160,12 @@ pub const Tokenizer = struct {
         var stable_off: usize = 0;
         while (true) {
             // Step by step visualization of the progress.
-            // log.debug("tokens: {d} -> {s}", .{ tok_buff[0..num_tokens], try self.decodeWithOpts(allocator, tok_buff[0..num_tokens], .{ .sep = "|" }) });
+            if (options.debug) {
+                var _debug_buf: [256]u8 = undefined;
+                var debug_progress: std.ArrayList(u8) = .{ .items = _debug_buf[0..0], .capacity = _debug_buf.len, .allocator = undefined };
+                self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {};
+                log.debug("tokens: {d} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items });
+            }
             var best_score: f32 = -1e10;
             var best_token: u32 = 0;
             var best_idx: ?usize = null;
@@ -273,7 +280,7 @@ pub const Tokenizer = struct {
     /// Converts the given slice of tokens back into bytes.
     /// Note that if the tokenizer allows sub-unicode bytes, it's possible
     /// the output is not valid utf8.
-    pub fn decode(self: *const Tokenizer, allocator: std.mem.Allocator, input: []const u32) ![]u8 {
+    pub fn decode(self: *const Tokenizer, allocator: std.mem.Allocator, input: []const u32) error{OutOfMemory}![]u8 {
         var output = std.ArrayList(u8).init(allocator);
         errdefer output.deinit();
 
@@ -286,18 +293,21 @@ pub const Tokenizer = struct {
         output: *std.ArrayList(u8),
         input: []const u32,
         opts: struct { sep: []const u8 = "" },
-    ) !void {
+    ) error{OutOfMemory}!void {
+        const escaped = if (self.normalizer) |n| n.escapedSpace() else null;
         // Flag used to indicate if the first dummy whitespace has been consumed.
         for (input) |id| {
             // Retrieve the slice corresponding to the id.
             var piece = self.lookupPiece(id);
 
             // Convert `▁` to a regular space.
-            if (std.mem.startsWith(u8, piece, Normalizer.space_symbol)) {
-                piece = piece[Normalizer.space_symbol.len..];
-                // don't output a space at beginning of text.
-                if (output.items.len > 0) try output.append(' ');
+            if (escaped) |escspc| {
+                // we modify piece inside the loop, so we can use it in the condition
+                while (std.mem.startsWith(u8, piece, escaped.?)) {
+                    piece = piece[escspc.len..];
+                    // don't output a space at beginning of text.
+                    if (output.items.len > 0) try output.append(' ');
+                }
             }
 
             try output.appendSlice(piece);
@@ -349,7 +359,7 @@ test Tokenizer {
         .eos = 2,
     };
 
-    var tokenizer = try Tokenizer.init(allocator, 10, 5, .{}, special_tokens, true);
+    var tokenizer = try Tokenizer.init(allocator, 10, 5, null, special_tokens, true);
     defer tokenizer.deinit();
 
     try tokenizer.addToken(10, "hello");
@@ -446,23 +456,37 @@ test CharTokenIterator {
     }
 }
 
-/// Text normalizer. Most tokenizer assumes the input text have been prepocessed
-/// with on of those.
+/// Text normalizer.
+/// Most tokenizer assumes the input text have been prepocessed with on of those.
 pub const Normalizer = struct {
-    pub const space_symbol = "▁"; // \xe2\x96\x81
+    /// Space token used by sentencepiece derived tokenizer.
+    pub const sentencepiece_space = "▁"; // \xe2\x96\x81
+
+    _whitespace: std.BoundedArray(u8, 8) = .{},
+
     flags: packed struct {
-        escape_whitespaces: bool = true,
-        remove_extra_whitespaces: bool = true,
-        add_dummy_prefix: bool = true,
-        add_dummy_suffix: bool = false,
+        remove_extra_whitespaces: bool,
+        add_dummy_prefix: bool,
+        add_dummy_suffix: bool,
         /// Cheap lower casing.
         /// TODO: try to match Python "lower"
-        lower_case_ascii: bool = false,
+        lower_case_ascii: bool,
         /// cheap ascii punct splitting.
         // doing this processing ahead of time simplifies the logic
-        split_on_punct_ascii: bool = false,
-    } = .{},
+        split_on_punct_ascii: bool,
+    },
+
+    pub fn init(flags: std.meta.FieldType(Normalizer, .flags), escaped_whitespace: ?[]const u8) Normalizer {
+        var res: Normalizer = .{ .flags = flags };
+        if (escaped_whitespace) |escaped| {
+            res._whitespace.appendSliceAssumeCapacity(escaped);
+        }
+        return res;
+    }
+
+    pub inline fn escapedSpace(self: Normalizer) ?[]const u8 {
+        return if (self._whitespace.len > 1) self._whitespace.constSlice() else null;
+    }
+
     fn addSlice(data: []const u8, consumed: usize, normalized: *std.ArrayList(u8), normalized_to_origin: *std.ArrayList(usize)) !void {
         try normalized.appendSlice(data);
@@ -512,7 +536,7 @@ pub const Normalizer = struct {
         }
 
         // Pre-allocate outputs
-        const space = if (self.flags.escape_whitespaces) Normalizer.space_symbol else " ";
+        const space = self.escapedSpace() orelse " ";
         const overhead = if (self.flags.split_on_punct_ascii) space.len + 1 else space.len;
         var normalized = try std.ArrayList(u8).initCapacity(allocator, trimmed_input.len * overhead + 2 * space.len);
         errdefer normalized.deinit();
@@ -545,7 +569,7 @@ pub const Normalizer = struct {
             if (slice.len == 1) ascii: {
                 // The more advanced logic only works with ascii atm
                 var byte = slice[0];
-                if (self.flags.escape_whitespaces and byte == ' ') {
+                if (self.escapedSpace() != null and byte == ' ') {
                     // replace the space token by the special token
                     try addSlice(space, origin, &normalized, &normalized_to_origin);
                     is_prev_word = false;
@@ -592,26 +616,106 @@ pub const Normalizer = struct {
 
     pub fn wellKnown(impl: KnownImplementation) Normalizer {
         return switch (impl) {
-            .sentencepiece => .{ .flags = .{
-                .escape_whitespaces = true,
+            .sentencepiece => init(.{
                 .remove_extra_whitespaces = true,
                 .add_dummy_prefix = true,
                 .add_dummy_suffix = false,
                 .lower_case_ascii = false,
                 .split_on_punct_ascii = false,
-            } },
-            .gpt2 => .{ .flags = .{
-                .escape_whitespaces = false,
+            }, sentencepiece_space),
+            .gpt2 => init(.{
                 .remove_extra_whitespaces = true,
                 .add_dummy_prefix = true,
                 .add_dummy_suffix = false,
                 .lower_case_ascii = false,
                 .split_on_punct_ascii = false,
-            } },
+            }, null),
         };
     }
-};
 
+    pub fn fromHfJson(config: std.json.ObjectMap) error{InvalidNormalizerJson}!Normalizer {
+        var normalizer: Normalizer = .{ .flags = .{
+            .remove_extra_whitespaces = false,
+            .add_dummy_suffix = false,
+            .add_dummy_prefix = false,
+            .lower_case_ascii = false,
+            .split_on_punct_ascii = false,
+        } };
+
+        // Normalizer config can be a single normalizer, or a sequence of normalizers.
+        const maybe_steps = object_get(config, .array, "normalizers");
+        const steps = if (maybe_steps) |st| st.items else &.{std.json.Value{ .object = config }};
+
+        for (steps) |step_val| {
+            if (step_val != .object) {
+                return error.InvalidNormalizerJson;
+            }
+            const step = step_val.object;
+
+            const step_type = object_get(step, .string, "type") orelse {
+                return error.InvalidNormalizerJson;
+            };
+            if (std.mem.eql(u8, "Prepend", step_type)) {
+                normalizer.flags.add_dummy_prefix = true;
+            } else if (std.mem.eql(u8, "Append", step_type)) {
+                normalizer.flags.add_dummy_suffix = true;
+            } else if (std.mem.eql(u8, "Replace", step_type)) {
+                const pattern = object_get(step, .object, "pattern") orelse return error.InvalidNormalizerJson;
+                const str_pattern = object_get(pattern, .string, "String") orelse return error.InvalidNormalizerJson;
+
+                if (std.mem.eql(u8, str_pattern, " ")) {
+                    normalizer._whitespace.appendSliceAssumeCapacity(
+                        object_get(step, .string, "content") orelse return error.InvalidNormalizerJson,
+                    );
+                } else {
+                    log.warn("Normalizer Replace pattern not supported: '{s}' -> '{s}'", .{ str_pattern, object_get(pattern, .string, "content") orelse "" });
+                }
+            } else {
+                log.warn("Unknown normalizer type: {s}", .{step_type});
+            }
+        }
+
+        return normalizer;
+    }
+
+    test "Normalizer.fromHfJson" {
+        const config_json =
+            \\{
+            \\  "type": "Sequence",
+            \\  "normalizers": [
+            \\    {
+            \\      "type": "Prepend",
+            \\      "prepend": "▁"
+            \\    },
+            \\    {
+            \\      "type": "Replace",
+            \\      "pattern": {
+            \\        "String": " "
+            \\      },
+            \\      "content": "▁"
+            \\    }
+            \\  ]
+            \\}
+        ;
+        var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
+        defer arena.deinit();
+        const config = try std.json.parseFromSliceLeaky(std.json.Value, arena.allocator(), config_json, .{});
+        const normalizer = try Normalizer.fromHfJson(config.object);
+
+        const expected = Normalizer{
+            ._whitespace = .{ .buffer = [_]u8{ 0xe2, 0x96, 0x81 } ++ [_]u8{0} ** 5, .len = 3 },
+            .flags = .{
+                .remove_extra_whitespaces = false,
+                .add_dummy_prefix = true,
+                .add_dummy_suffix = false,
+                .lower_case_ascii = false,
+                .split_on_punct_ascii = false,
+            },
+        };
+        try std.testing.expectEqual(expected.flags, normalizer.flags);
+        try std.testing.expectEqualStrings(expected.escapedSpace().?, normalizer.escapedSpace().?);
+    }
+};
 pub const KnownImplementation = enum(u8) {
     sentencepiece,
     gpt2,
@@ -634,14 +738,13 @@ fn isPunct(unicode_char: []const u8) bool {
 }
 
 test Normalizer {
-    try testing.expectEqualSlices(u8, "▁", Normalizer.space_symbol);
-
     {
         const n: Normalizer = .{ .flags = .{
-            .escape_whitespaces = false,
             .remove_extra_whitespaces = true,
             .add_dummy_prefix = true,
             .add_dummy_suffix = false,
+            .lower_case_ascii = false,
+            .split_on_punct_ascii = false,
         } };
         const res = try n.normalizeWithMapping(testing.allocator, "Hellŏ world!");
         defer res.deinit(testing.allocator);
@@ -657,10 +760,11 @@ test Normalizer {
 
     {
         const n: Normalizer = .{ .flags = .{
-            .escape_whitespaces = false,
             .remove_extra_whitespaces = true,
             .add_dummy_prefix = true,
             .add_dummy_suffix = true,
+            .lower_case_ascii = false,
+            .split_on_punct_ascii = false,
         } };
         const res = try n.normalize(testing.allocator, "Hello world!");
         defer testing.allocator.free(res);
@@ -669,12 +773,16 @@ test Normalizer {
     }
 
     {
-        const n: Normalizer = .{ .flags = .{
-            .escape_whitespaces = true,
+        const n = Normalizer.init(
+            .{
                 .remove_extra_whitespaces = false,
                 .add_dummy_prefix = true,
                 .add_dummy_suffix = false,
-        } };
+                .lower_case_ascii = false,
+                .split_on_punct_ascii = false,
+            },
+            Normalizer.sentencepiece_space,
+        );
         const res = try n.normalize(testing.allocator, "Hello world!");
         defer testing.allocator.free(res);
 
@@ -683,11 +791,11 @@ test Normalizer {
 
     {
         const n: Normalizer = .{ .flags = .{
-            .escape_whitespaces = false,
             .remove_extra_whitespaces = true,
             .add_dummy_prefix = false,
             .add_dummy_suffix = true,
             .lower_case_ascii = true,
+            .split_on_punct_ascii = false,
         } };
         const res = try n.normalize(testing.allocator, "Hello world!");
         defer testing.allocator.free(res);
@@ -697,10 +805,10 @@ test Normalizer {
 
     {
         const n: Normalizer = .{ .flags = .{
-            .escape_whitespaces = false,
             .remove_extra_whitespaces = true,
             .add_dummy_prefix = false,
             .add_dummy_suffix = true,
+            .lower_case_ascii = false,
             .split_on_punct_ascii = true,
         } };
         const res = try n.normalize(testing.allocator, "Hello world!");
@@ -795,15 +903,14 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
     const file = try std.fs.cwd().openFile(tokenizer_path, .{});
     defer file.close();
 
-    const file_content = try file.readToEndAlloc(allocator, 32 * 1024 * 1024);
-    defer allocator.free(file_content);
-    // TODO create local arena and use parseFromSliceLeaky.
-    const parsed = try std.json.parseFromSlice(std.json.Value, allocator, file_content, .{
+    var arena_state = std.heap.ArenaAllocator.init(allocator);
+    defer arena_state.deinit();
+    const arena = arena_state.allocator();
+    const file_content = try file.readToEndAlloc(arena, 32 * 1024 * 1024);
+
+    const info = try std.json.parseFromSliceLeaky(std.json.Value, arena, file_content, .{
         .duplicate_field_behavior = .use_last,
     });
-    defer parsed.deinit();
-    const info = parsed.value;
 
     const main_object = switch (info) {
         .object => |obj| if (obj.get("added_tokens") == null or obj.get("model") == null) {
             return error.InvalidFormat;
@@ -816,10 +923,12 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
     const vocab = main_object.get("model").?.object.get("vocab").?.object;
     const vocab_size: u32 = @intCast(vocab.count() + added_tokens.items.len);
 
-    const normalizer = Normalizer.wellKnown(.gpt2);
-    var gpt2_decoder = try Gpt2TextDecoder.init(allocator);
-    defer gpt2_decoder.deinit();
+    const normalizer = if (object_get(main_object, .object, "normalizer")) |normalizer_config|
+        try Normalizer.fromHfJson(normalizer_config)
+    else
+        Normalizer.wellKnown(.gpt2);
+
+    // delay init of special tokens.
     var tokenizer = try Tokenizer.init(allocator, vocab_size, 256, normalizer, undefined, true);
     errdefer tokenizer.deinit();
 
@@ -833,12 +942,14 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
         std.debug.assert(all_tokens.items.ptr == original_alloc);
     }
 
-    var it = vocab.iterator();
     // gpt2 based tokenizer got a special way of encoding unicode.
     // we don't know in advance if this will be used by this tokenizer or not.
     // so we assume it is the case, but if we find some unicode character,
     // outside of the range used by gpt2 we know it was wrong, and start over.
     var is_gpt2_vocab: bool = true;
+    var gpt2_decoder = try Gpt2TextDecoder.init(allocator);
+    defer gpt2_decoder.deinit();
+    var it = vocab.iterator();
     while (it.next()) |kv| {
         const token = gpt2_decoder.decode(&all_tokens, kv.key_ptr.*) catch |err| {
             switch (err) {
@@ -901,3 +1012,14 @@ fn dup(buffer: *std.ArrayList(u8), str: []const u8) ![]const u8 {
     try buffer.appendSlice(str);
     return buffer.items[n..];
 }
+
+/// Returns the given entry in a json object only if it has the right type.
+fn object_get(
+    object: std.json.ObjectMap,
+    comptime kind: std.meta.FieldEnum(std.json.Value),
+    key: []const u8,
+) ?std.meta.FieldType(std.json.Value, kind) {
+    const val = object.get(key) orelse return null;
+    if (val != kind) return null;
+    return @field(val, @tagName(kind));
+}