Update tokenizer to handle byte_fallback for Llama3 GPT2 vocab and add a Llama3‑specific normalizer; adjust tinyllama.zig and hostbuffer.zig to use the new tokenization logic.

This commit is contained in:
Tarry Singh 2024-02-05 15:22:44 +00:00
parent b643f7bc53
commit b8a0aaee5a
3 changed files with 59 additions and 29 deletions

View File

@ -112,7 +112,7 @@ fn newBuff(store: *zml.aio.BufferStore, name: []const u8, sh: anytype, offset: u
const n = shape.byteSize();
const buff = zml.HostBuffer.fromBytes(shape, store.files[0].data[offset..][0..n]);
store.buffers.putAssumeCapacityNoClobber(name, buff);
zml.log.info("Found {s}: {}", .{ name, shape });
zml.log.debug("Found {s}: {}", .{ name, shape });
return offset + n;
}
@ -125,7 +125,7 @@ fn splitBuff(store: *zml.aio.BufferStore, comptime fmt: []const u8, sh: anytype,
const buff = zml.HostBuffer.fromBytes(shape, store.files[0].data[off..][0..n]);
store.buffers.putAssumeCapacityNoClobber(name, buff);
off += n;
if (i == 0) zml.log.info("Found {s}: {}", .{ name, shape });
if (i == 0) zml.log.debug("Found {s}: {}", .{ name, shape });
}
return off;
}

View File

@ -285,14 +285,20 @@ pub const HostBuffer = struct {
fn prettyPrintIndented(self: HostBuffer, num_rows: u8, indent_level: u8, writer: anytype) !void {
if (self.rank() == 1) {
try writer.writeByteNTimes(' ', indent_level);
switch (self.dtype()) {
return switch (self.dtype()) {
inline else => |dt| {
const values = self.items(dt.toZigType());
const n = @min(values.len, 1024);
try writer.print("{any},\n", .{values[0..n]});
// Write first rows
const num_cols: u32 = 12;
const n: u64 = @intCast(self.dim(0));
if (n <= num_cols) {
try writer.print("{any},\n", .{values[0..n]});
} else {
const half = @divExact(num_cols, 2);
try writer.print("{any}, ..., {any},\n", .{ values[0..half], values[n - half ..] });
}
},
}
return;
};
}
try writer.writeByteNTimes(' ', indent_level);
_ = try writer.write("{\n");

View File

@ -164,7 +164,8 @@ pub const Tokenizer = struct {
// Step by step visualization of the progress.
if (options.debug) {
var _debug_buf: [256]u8 = undefined;
var debug_progress: std.ArrayList(u8) = .{ .items = _debug_buf[0..0], .capacity = _debug_buf.len, .allocator = undefined };
var _debug_alloc = std.heap.FixedBufferAllocator.init(&_debug_buf);
var debug_progress = std.ArrayList(u8).init(_debug_alloc.allocator());
self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {};
log.debug("tokens: {d} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items });
}
@ -305,7 +306,7 @@ pub const Tokenizer = struct {
// Convert `` to a regular space.
if (escaped) |escspc| {
// we modify piece inside the loop, so we can use it in the condition
while (std.mem.startsWith(u8, piece, escaped.?)) {
while (std.mem.startsWith(u8, piece, escspc)) {
piece = piece[escspc.len..];
// don't output a space at beginning of text.
if (output.items.len > 0) try output.append(' ');
@ -625,6 +626,13 @@ pub const Normalizer = struct {
.lower_case_ascii = false,
.split_on_punct_ascii = false,
}, sentencepiece_space),
.llama3 => init(.{
.remove_extra_whitespaces = true,
.add_dummy_prefix = false,
.add_dummy_suffix = false,
.lower_case_ascii = false,
.split_on_punct_ascii = false,
}, null),
.gpt2 => init(.{
.remove_extra_whitespaces = true,
.add_dummy_prefix = true,
@ -645,7 +653,7 @@ pub const Normalizer = struct {
} };
// Normalizer config can be a single normalizer, or a sequence of normalizers.
const maybe_steps = object_get(config, .array, "normalizers");
const maybe_steps = objectGet(config, .array, "normalizers");
const steps = if (maybe_steps) |st| st.items else &.{std.json.Value{ .object = config }};
for (steps) |step_val| {
@ -654,7 +662,7 @@ pub const Normalizer = struct {
}
const step = step_val.object;
const step_type = object_get(step, .string, "type") orelse {
const step_type = objectGet(step, .string, "type") orelse {
return error.InvalidNormalizerJson;
};
if (std.mem.eql(u8, "Prepend", step_type)) {
@ -662,15 +670,15 @@ pub const Normalizer = struct {
} else if (std.mem.eql(u8, "Append", step_type)) {
normalizer.flags.add_dummy_suffix = true;
} else if (std.mem.eql(u8, "Replace", step_type)) {
const pattern = object_get(step, .object, "pattern") orelse return error.InvalidNormalizerJson;
const str_pattern = object_get(pattern, .string, "String") orelse return error.InvalidNormalizerJson;
const pattern = objectGet(step, .object, "pattern") orelse return error.InvalidNormalizerJson;
const str_pattern = objectGet(pattern, .string, "String") orelse return error.InvalidNormalizerJson;
if (std.mem.eql(u8, str_pattern, " ")) {
normalizer._whitespace.appendSliceAssumeCapacity(
object_get(step, .string, "content") orelse return error.InvalidNormalizerJson,
objectGet(step, .string, "content") orelse return error.InvalidNormalizerJson,
);
} else {
log.warn("Normalizer Replace pattern not supported: '{s}' -> '{s}'", .{ str_pattern, object_get(pattern, .string, "content") orelse "" });
log.warn("Normalizer Replace pattern not supported: '{s}' -> '{s}'", .{ str_pattern, objectGet(pattern, .string, "content") orelse "" });
}
} else {
log.warn("Unknown normalizer type: {s}", .{step_type});
@ -721,6 +729,7 @@ pub const Normalizer = struct {
pub const KnownImplementation = enum(u8) {
sentencepiece,
gpt2,
llama3,
};
fn isPunct(unicode_char: []const u8) bool {
@ -920,15 +929,15 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
else => return error.InvalidFormat,
};
// TODO: remove all panics
const added_tokens = main_object.get("added_tokens").?.array;
const vocab = main_object.get("model").?.object.get("vocab").?.object;
const vocab_size: u32 = @intCast(vocab.count() + added_tokens.items.len);
const model = objectGet(main_object, .object, "model") orelse return error.InvalidFormat;
const vocab = objectGet(model, .object, "vocab") orelse return error.InvalidFormat;
const added_tokens = if (objectGet(main_object, .array, "added_tokens")) |added| added.items else &.{};
const vocab_size: u32 = @intCast(vocab.count() + added_tokens.len);
const normalizer = if (object_get(main_object, .object, "normalizer")) |normalizer_config|
const normalizer = if (objectGet(main_object, .object, "normalizer")) |normalizer_config|
try Normalizer.fromHfJson(normalizer_config)
else
Normalizer.wellKnown(.gpt2);
Normalizer.wellKnown(.llama3);
// delay init of special tokens.
var tokenizer = try Tokenizer.init(allocator, vocab_size, 256, normalizer, undefined, true);
@ -981,9 +990,10 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
}
// More tokens, typically added during fine tuning of the model.
for (added_tokens.items) |token_obj| {
const v = token_obj.object.get("content").?.string;
const id: u32 = @intCast(token_obj.object.get("id").?.integer);
for (added_tokens) |token_obj| {
if (token_obj != .object) return error.InvalidFormat;
const v = objectGet(token_obj.object, .string, "content") orelse return error.InvalidFormat;
const id: u32 = @intCast(objectGet(token_obj.object, .integer, "id") orelse return error.InvalidFormat);
const token = try if (is_gpt2_vocab)
gpt2_decoder.decode(&all_tokens, v)
else
@ -994,16 +1004,30 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
// We won't add more tokens here, let release.
all_tokens.shrinkAndFree(all_tokens.items.len);
var unk = tokenizer.lookup("<unk>");
if (objectGet(model, .integer, "unk_token")) |unk_tok| {
unk = @intCast(unk_tok);
}
tokenizer.special_tokens = .{
// TODO allow users to specify special tokens or read them from a tokenizer_config.json file
.bos = tokenizer.lookup("<s>") orelse tokenizer.lookup("<|begin_of_text|>") orelse @panic("bos token not found !"),
.eos = tokenizer.lookup("</s>") orelse tokenizer.lookup("<|end_of_text|>") orelse @panic("eos token not found !"),
.unk = tokenizer.lookup("<unk>") orelse std.math.maxInt(u32),
.unk = unk orelse std.math.maxInt(u32),
};
if (main_object.get("model").?.object.get("byte_fallback")) |byte_fallback| {
if (byte_fallback == .bool and byte_fallback.bool) {
try tokenizer.rewriteByteFallbackTokens();
const byte_fallback = objectGet(model, .bool, "byte_fallback") orelse false;
if (!byte_fallback and unk == null) {
// GPT2 tokenizer have byte fallback already encoded in the model,
// but the json generally don't have the field set.
// We can detect it though because they don't specify an unknown token.
if (is_gpt2_vocab) {
tokenizer.byte_fallback = true;
} else {
log.warn("The given tokenizer can't handle unknown token: no unknown token was set, and byte_fallback is disabled too ! The tokenizer will panic when facing unknown tokens.", .{});
}
} else if (byte_fallback) {
try tokenizer.rewriteByteFallbackTokens();
}
return tokenizer;
}
@ -1016,7 +1040,7 @@ fn dup(buffer: *std.ArrayList(u8), str: []const u8) ![]const u8 {
}
/// Returns the given entry in a json object only if it has the right type.
fn object_get(
fn objectGet(
object: std.json.ObjectMap,
comptime kind: std.meta.FieldEnum(std.json.Value),
key: []const u8,