Update tokenizer to handle byte_fallback for Llama3 GPT2 vocab and add a Llama3‑specific normalizer; adjust tinyllama.zig and hostbuffer.zig to use the new tokenization logic.

This commit is contained in:
Tarry Singh 2024-02-05 15:22:44 +00:00
parent b643f7bc53
commit b8a0aaee5a
3 changed files with 59 additions and 29 deletions

View File

@ -112,7 +112,7 @@ fn newBuff(store: *zml.aio.BufferStore, name: []const u8, sh: anytype, offset: u
const n = shape.byteSize(); const n = shape.byteSize();
const buff = zml.HostBuffer.fromBytes(shape, store.files[0].data[offset..][0..n]); const buff = zml.HostBuffer.fromBytes(shape, store.files[0].data[offset..][0..n]);
store.buffers.putAssumeCapacityNoClobber(name, buff); store.buffers.putAssumeCapacityNoClobber(name, buff);
zml.log.info("Found {s}: {}", .{ name, shape }); zml.log.debug("Found {s}: {}", .{ name, shape });
return offset + n; return offset + n;
} }
@ -125,7 +125,7 @@ fn splitBuff(store: *zml.aio.BufferStore, comptime fmt: []const u8, sh: anytype,
const buff = zml.HostBuffer.fromBytes(shape, store.files[0].data[off..][0..n]); const buff = zml.HostBuffer.fromBytes(shape, store.files[0].data[off..][0..n]);
store.buffers.putAssumeCapacityNoClobber(name, buff); store.buffers.putAssumeCapacityNoClobber(name, buff);
off += n; off += n;
if (i == 0) zml.log.info("Found {s}: {}", .{ name, shape }); if (i == 0) zml.log.debug("Found {s}: {}", .{ name, shape });
} }
return off; return off;
} }

View File

@ -285,14 +285,20 @@ pub const HostBuffer = struct {
fn prettyPrintIndented(self: HostBuffer, num_rows: u8, indent_level: u8, writer: anytype) !void { fn prettyPrintIndented(self: HostBuffer, num_rows: u8, indent_level: u8, writer: anytype) !void {
if (self.rank() == 1) { if (self.rank() == 1) {
try writer.writeByteNTimes(' ', indent_level); try writer.writeByteNTimes(' ', indent_level);
switch (self.dtype()) { return switch (self.dtype()) {
inline else => |dt| { inline else => |dt| {
const values = self.items(dt.toZigType()); const values = self.items(dt.toZigType());
const n = @min(values.len, 1024); // Write first rows
const num_cols: u32 = 12;
const n: u64 = @intCast(self.dim(0));
if (n <= num_cols) {
try writer.print("{any},\n", .{values[0..n]}); try writer.print("{any},\n", .{values[0..n]});
}, } else {
const half = @divExact(num_cols, 2);
try writer.print("{any}, ..., {any},\n", .{ values[0..half], values[n - half ..] });
} }
return; },
};
} }
try writer.writeByteNTimes(' ', indent_level); try writer.writeByteNTimes(' ', indent_level);
_ = try writer.write("{\n"); _ = try writer.write("{\n");

View File

@ -164,7 +164,8 @@ pub const Tokenizer = struct {
// Step by step visualization of the progress. // Step by step visualization of the progress.
if (options.debug) { if (options.debug) {
var _debug_buf: [256]u8 = undefined; var _debug_buf: [256]u8 = undefined;
var debug_progress: std.ArrayList(u8) = .{ .items = _debug_buf[0..0], .capacity = _debug_buf.len, .allocator = undefined }; var _debug_alloc = std.heap.FixedBufferAllocator.init(&_debug_buf);
var debug_progress = std.ArrayList(u8).init(_debug_alloc.allocator());
self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {}; self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {};
log.debug("tokens: {d} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items }); log.debug("tokens: {d} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items });
} }
@ -305,7 +306,7 @@ pub const Tokenizer = struct {
// Convert `` to a regular space. // Convert `` to a regular space.
if (escaped) |escspc| { if (escaped) |escspc| {
// we modify piece inside the loop, so we can use it in the condition // we modify piece inside the loop, so we can use it in the condition
while (std.mem.startsWith(u8, piece, escaped.?)) { while (std.mem.startsWith(u8, piece, escspc)) {
piece = piece[escspc.len..]; piece = piece[escspc.len..];
// don't output a space at beginning of text. // don't output a space at beginning of text.
if (output.items.len > 0) try output.append(' '); if (output.items.len > 0) try output.append(' ');
@ -625,6 +626,13 @@ pub const Normalizer = struct {
.lower_case_ascii = false, .lower_case_ascii = false,
.split_on_punct_ascii = false, .split_on_punct_ascii = false,
}, sentencepiece_space), }, sentencepiece_space),
.llama3 => init(.{
.remove_extra_whitespaces = true,
.add_dummy_prefix = false,
.add_dummy_suffix = false,
.lower_case_ascii = false,
.split_on_punct_ascii = false,
}, null),
.gpt2 => init(.{ .gpt2 => init(.{
.remove_extra_whitespaces = true, .remove_extra_whitespaces = true,
.add_dummy_prefix = true, .add_dummy_prefix = true,
@ -645,7 +653,7 @@ pub const Normalizer = struct {
} }; } };
// Normalizer config can be a single normalizer, or a sequence of normalizers. // Normalizer config can be a single normalizer, or a sequence of normalizers.
const maybe_steps = object_get(config, .array, "normalizers"); const maybe_steps = objectGet(config, .array, "normalizers");
const steps = if (maybe_steps) |st| st.items else &.{std.json.Value{ .object = config }}; const steps = if (maybe_steps) |st| st.items else &.{std.json.Value{ .object = config }};
for (steps) |step_val| { for (steps) |step_val| {
@ -654,7 +662,7 @@ pub const Normalizer = struct {
} }
const step = step_val.object; const step = step_val.object;
const step_type = object_get(step, .string, "type") orelse { const step_type = objectGet(step, .string, "type") orelse {
return error.InvalidNormalizerJson; return error.InvalidNormalizerJson;
}; };
if (std.mem.eql(u8, "Prepend", step_type)) { if (std.mem.eql(u8, "Prepend", step_type)) {
@ -662,15 +670,15 @@ pub const Normalizer = struct {
} else if (std.mem.eql(u8, "Append", step_type)) { } else if (std.mem.eql(u8, "Append", step_type)) {
normalizer.flags.add_dummy_suffix = true; normalizer.flags.add_dummy_suffix = true;
} else if (std.mem.eql(u8, "Replace", step_type)) { } else if (std.mem.eql(u8, "Replace", step_type)) {
const pattern = object_get(step, .object, "pattern") orelse return error.InvalidNormalizerJson; const pattern = objectGet(step, .object, "pattern") orelse return error.InvalidNormalizerJson;
const str_pattern = object_get(pattern, .string, "String") orelse return error.InvalidNormalizerJson; const str_pattern = objectGet(pattern, .string, "String") orelse return error.InvalidNormalizerJson;
if (std.mem.eql(u8, str_pattern, " ")) { if (std.mem.eql(u8, str_pattern, " ")) {
normalizer._whitespace.appendSliceAssumeCapacity( normalizer._whitespace.appendSliceAssumeCapacity(
object_get(step, .string, "content") orelse return error.InvalidNormalizerJson, objectGet(step, .string, "content") orelse return error.InvalidNormalizerJson,
); );
} else { } else {
log.warn("Normalizer Replace pattern not supported: '{s}' -> '{s}'", .{ str_pattern, object_get(pattern, .string, "content") orelse "" }); log.warn("Normalizer Replace pattern not supported: '{s}' -> '{s}'", .{ str_pattern, objectGet(pattern, .string, "content") orelse "" });
} }
} else { } else {
log.warn("Unknown normalizer type: {s}", .{step_type}); log.warn("Unknown normalizer type: {s}", .{step_type});
@ -721,6 +729,7 @@ pub const Normalizer = struct {
pub const KnownImplementation = enum(u8) { pub const KnownImplementation = enum(u8) {
sentencepiece, sentencepiece,
gpt2, gpt2,
llama3,
}; };
fn isPunct(unicode_char: []const u8) bool { fn isPunct(unicode_char: []const u8) bool {
@ -920,15 +929,15 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
else => return error.InvalidFormat, else => return error.InvalidFormat,
}; };
// TODO: remove all panics const model = objectGet(main_object, .object, "model") orelse return error.InvalidFormat;
const added_tokens = main_object.get("added_tokens").?.array; const vocab = objectGet(model, .object, "vocab") orelse return error.InvalidFormat;
const vocab = main_object.get("model").?.object.get("vocab").?.object; const added_tokens = if (objectGet(main_object, .array, "added_tokens")) |added| added.items else &.{};
const vocab_size: u32 = @intCast(vocab.count() + added_tokens.items.len); const vocab_size: u32 = @intCast(vocab.count() + added_tokens.len);
const normalizer = if (object_get(main_object, .object, "normalizer")) |normalizer_config| const normalizer = if (objectGet(main_object, .object, "normalizer")) |normalizer_config|
try Normalizer.fromHfJson(normalizer_config) try Normalizer.fromHfJson(normalizer_config)
else else
Normalizer.wellKnown(.gpt2); Normalizer.wellKnown(.llama3);
// delay init of special tokens. // delay init of special tokens.
var tokenizer = try Tokenizer.init(allocator, vocab_size, 256, normalizer, undefined, true); var tokenizer = try Tokenizer.init(allocator, vocab_size, 256, normalizer, undefined, true);
@ -981,9 +990,10 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
} }
// More tokens, typically added during fine tuning of the model. // More tokens, typically added during fine tuning of the model.
for (added_tokens.items) |token_obj| { for (added_tokens) |token_obj| {
const v = token_obj.object.get("content").?.string; if (token_obj != .object) return error.InvalidFormat;
const id: u32 = @intCast(token_obj.object.get("id").?.integer); const v = objectGet(token_obj.object, .string, "content") orelse return error.InvalidFormat;
const id: u32 = @intCast(objectGet(token_obj.object, .integer, "id") orelse return error.InvalidFormat);
const token = try if (is_gpt2_vocab) const token = try if (is_gpt2_vocab)
gpt2_decoder.decode(&all_tokens, v) gpt2_decoder.decode(&all_tokens, v)
else else
@ -994,16 +1004,30 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
// We won't add more tokens here, let release. // We won't add more tokens here, let release.
all_tokens.shrinkAndFree(all_tokens.items.len); all_tokens.shrinkAndFree(all_tokens.items.len);
var unk = tokenizer.lookup("<unk>");
if (objectGet(model, .integer, "unk_token")) |unk_tok| {
unk = @intCast(unk_tok);
}
tokenizer.special_tokens = .{ tokenizer.special_tokens = .{
// TODO allow users to specify special tokens or read them from a tokenizer_config.json file
.bos = tokenizer.lookup("<s>") orelse tokenizer.lookup("<|begin_of_text|>") orelse @panic("bos token not found !"), .bos = tokenizer.lookup("<s>") orelse tokenizer.lookup("<|begin_of_text|>") orelse @panic("bos token not found !"),
.eos = tokenizer.lookup("</s>") orelse tokenizer.lookup("<|end_of_text|>") orelse @panic("eos token not found !"), .eos = tokenizer.lookup("</s>") orelse tokenizer.lookup("<|end_of_text|>") orelse @panic("eos token not found !"),
.unk = tokenizer.lookup("<unk>") orelse std.math.maxInt(u32), .unk = unk orelse std.math.maxInt(u32),
}; };
if (main_object.get("model").?.object.get("byte_fallback")) |byte_fallback| { const byte_fallback = objectGet(model, .bool, "byte_fallback") orelse false;
if (byte_fallback == .bool and byte_fallback.bool) { if (!byte_fallback and unk == null) {
try tokenizer.rewriteByteFallbackTokens(); // GPT2 tokenizer have byte fallback already encoded in the model,
// but the json generally don't have the field set.
// We can detect it though because they don't specify an unknown token.
if (is_gpt2_vocab) {
tokenizer.byte_fallback = true;
} else {
log.warn("The given tokenizer can't handle unknown token: no unknown token was set, and byte_fallback is disabled too ! The tokenizer will panic when facing unknown tokens.", .{});
} }
} else if (byte_fallback) {
try tokenizer.rewriteByteFallbackTokens();
} }
return tokenizer; return tokenizer;
} }
@ -1016,7 +1040,7 @@ fn dup(buffer: *std.ArrayList(u8), str: []const u8) ![]const u8 {
} }
/// Returns the given entry in a json object only if it has the right type. /// Returns the given entry in a json object only if it has the right type.
fn object_get( fn objectGet(
object: std.json.ObjectMap, object: std.json.ObjectMap,
comptime kind: std.meta.FieldEnum(std.json.Value), comptime kind: std.meta.FieldEnum(std.json.Value),
key: []const u8, key: []const u8,