From b8a0aaee5afbe9218d0f99207d6e6b2485427424 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Mon, 5 Feb 2024 15:22:44 +0000 Subject: [PATCH] =?UTF-8?q?Update=20tokenizer=20to=20handle=20byte=5Ffallb?= =?UTF-8?q?ack=20for=20Llama3=20GPT2=20vocab=20and=20add=20a=20Llama3?= =?UTF-8?q?=E2=80=91specific=20normalizer;=20adjust=20tinyllama.zig=20and?= =?UTF-8?q?=20hostbuffer.zig=20to=20use=20the=20new=20tokenization=20logic?= =?UTF-8?q?.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- zml/aio/tinyllama.zig | 4 +-- zml/hostbuffer.zig | 16 ++++++---- zml/tokenizer.zig | 68 +++++++++++++++++++++++++++++-------------- 3 files changed, 59 insertions(+), 29 deletions(-) diff --git a/zml/aio/tinyllama.zig b/zml/aio/tinyllama.zig index f50c246..c39f3e2 100644 --- a/zml/aio/tinyllama.zig +++ b/zml/aio/tinyllama.zig @@ -112,7 +112,7 @@ fn newBuff(store: *zml.aio.BufferStore, name: []const u8, sh: anytype, offset: u const n = shape.byteSize(); const buff = zml.HostBuffer.fromBytes(shape, store.files[0].data[offset..][0..n]); store.buffers.putAssumeCapacityNoClobber(name, buff); - zml.log.info("Found {s}: {}", .{ name, shape }); + zml.log.debug("Found {s}: {}", .{ name, shape }); return offset + n; } @@ -125,7 +125,7 @@ fn splitBuff(store: *zml.aio.BufferStore, comptime fmt: []const u8, sh: anytype, const buff = zml.HostBuffer.fromBytes(shape, store.files[0].data[off..][0..n]); store.buffers.putAssumeCapacityNoClobber(name, buff); off += n; - if (i == 0) zml.log.info("Found {s}: {}", .{ name, shape }); + if (i == 0) zml.log.debug("Found {s}: {}", .{ name, shape }); } return off; } diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig index a477de4..34531b9 100644 --- a/zml/hostbuffer.zig +++ b/zml/hostbuffer.zig @@ -285,14 +285,20 @@ pub const HostBuffer = struct { fn prettyPrintIndented(self: HostBuffer, num_rows: u8, indent_level: u8, writer: anytype) !void { if (self.rank() == 1) { try writer.writeByteNTimes(' ', indent_level); - 
switch (self.dtype()) { + return switch (self.dtype()) { inline else => |dt| { const values = self.items(dt.toZigType()); - const n = @min(values.len, 1024); - try writer.print("{any},\n", .{values[0..n]}); + // Write first rows + const num_cols: u32 = 12; + const n: u64 = @intCast(self.dim(0)); + if (n <= num_cols) { + try writer.print("{any},\n", .{values[0..n]}); + } else { + const half = @divExact(num_cols, 2); + try writer.print("{any}, ..., {any},\n", .{ values[0..half], values[n - half ..] }); + } }, - } - return; + }; } try writer.writeByteNTimes(' ', indent_level); _ = try writer.write("{\n"); diff --git a/zml/tokenizer.zig b/zml/tokenizer.zig index 07c684a..dbe2046 100644 --- a/zml/tokenizer.zig +++ b/zml/tokenizer.zig @@ -164,7 +164,8 @@ pub const Tokenizer = struct { // Step by step visualization of the progress. if (options.debug) { var _debug_buf: [256]u8 = undefined; - var debug_progress: std.ArrayList(u8) = .{ .items = _debug_buf[0..0], .capacity = _debug_buf.len, .allocator = undefined }; + var _debug_alloc = std.heap.FixedBufferAllocator.init(&_debug_buf); + var debug_progress = std.ArrayList(u8).init(_debug_alloc.allocator()); self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {}; log.debug("tokens: {d} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items }); } @@ -305,7 +306,7 @@ pub const Tokenizer = struct { // Convert `▁` to a regular space. if (escaped) |escspc| { // we modify piece inside the loop, so we can use it in the condition - while (std.mem.startsWith(u8, piece, escaped.?)) { + while (std.mem.startsWith(u8, piece, escspc)) { piece = piece[escspc.len..]; // don't output a space at beginning of text. 
if (output.items.len > 0) try output.append(' '); @@ -625,6 +626,13 @@ pub const Normalizer = struct { .lower_case_ascii = false, .split_on_punct_ascii = false, }, sentencepiece_space), + .llama3 => init(.{ + .remove_extra_whitespaces = true, + .add_dummy_prefix = false, + .add_dummy_suffix = false, + .lower_case_ascii = false, + .split_on_punct_ascii = false, + }, null), .gpt2 => init(.{ .remove_extra_whitespaces = true, .add_dummy_prefix = true, @@ -645,7 +653,7 @@ pub const Normalizer = struct { } }; // Normalizer config can be a single normalizer, or a sequence of normalizers. - const maybe_steps = object_get(config, .array, "normalizers"); + const maybe_steps = objectGet(config, .array, "normalizers"); const steps = if (maybe_steps) |st| st.items else &.{std.json.Value{ .object = config }}; for (steps) |step_val| { @@ -654,7 +662,7 @@ pub const Normalizer = struct { } const step = step_val.object; - const step_type = object_get(step, .string, "type") orelse { + const step_type = objectGet(step, .string, "type") orelse { return error.InvalidNormalizerJson; }; if (std.mem.eql(u8, "Prepend", step_type)) { @@ -662,15 +670,15 @@ pub const Normalizer = struct { } else if (std.mem.eql(u8, "Append", step_type)) { normalizer.flags.add_dummy_suffix = true; } else if (std.mem.eql(u8, "Replace", step_type)) { - const pattern = object_get(step, .object, "pattern") orelse return error.InvalidNormalizerJson; - const str_pattern = object_get(pattern, .string, "String") orelse return error.InvalidNormalizerJson; + const pattern = objectGet(step, .object, "pattern") orelse return error.InvalidNormalizerJson; + const str_pattern = objectGet(pattern, .string, "String") orelse return error.InvalidNormalizerJson; if (std.mem.eql(u8, str_pattern, " ")) { normalizer._whitespace.appendSliceAssumeCapacity( - object_get(step, .string, "content") orelse return error.InvalidNormalizerJson, + objectGet(step, .string, "content") orelse return error.InvalidNormalizerJson, ); } else { - 
log.warn("Normalizer Replace pattern not supported: '{s}' -> '{s}'", .{ str_pattern, object_get(pattern, .string, "content") orelse "" }); + log.warn("Normalizer Replace pattern not supported: '{s}' -> '{s}'", .{ str_pattern, objectGet(pattern, .string, "content") orelse "" }); } } else { log.warn("Unknown normalizer type: {s}", .{step_type}); @@ -721,6 +729,7 @@ pub const Normalizer = struct { pub const KnownImplementation = enum(u8) { sentencepiece, gpt2, + llama3, }; fn isPunct(unicode_char: []const u8) bool { @@ -920,15 +929,15 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok else => return error.InvalidFormat, }; - // TODO: remove all panics - const added_tokens = main_object.get("added_tokens").?.array; - const vocab = main_object.get("model").?.object.get("vocab").?.object; - const vocab_size: u32 = @intCast(vocab.count() + added_tokens.items.len); + const model = objectGet(main_object, .object, "model") orelse return error.InvalidFormat; + const vocab = objectGet(model, .object, "vocab") orelse return error.InvalidFormat; + const added_tokens = if (objectGet(main_object, .array, "added_tokens")) |added| added.items else &.{}; + const vocab_size: u32 = @intCast(vocab.count() + added_tokens.len); - const normalizer = if (object_get(main_object, .object, "normalizer")) |normalizer_config| + const normalizer = if (objectGet(main_object, .object, "normalizer")) |normalizer_config| try Normalizer.fromHfJson(normalizer_config) else - Normalizer.wellKnown(.gpt2); + Normalizer.wellKnown(.llama3); // delay init of special tokens. var tokenizer = try Tokenizer.init(allocator, vocab_size, 256, normalizer, undefined, true); @@ -981,9 +990,10 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok } // More tokens, typically added during fine tuning of the model. 
- for (added_tokens.items) |token_obj| { - const v = token_obj.object.get("content").?.string; - const id: u32 = @intCast(token_obj.object.get("id").?.integer); + for (added_tokens) |token_obj| { + if (token_obj != .object) return error.InvalidFormat; + const v = objectGet(token_obj.object, .string, "content") orelse return error.InvalidFormat; + const id: u32 = @intCast(objectGet(token_obj.object, .integer, "id") orelse return error.InvalidFormat); const token = try if (is_gpt2_vocab) gpt2_decoder.decode(&all_tokens, v) else @@ -994,16 +1004,30 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok // We won't add more tokens here, let release. all_tokens.shrinkAndFree(all_tokens.items.len); + var unk = tokenizer.lookup(""); + if (objectGet(model, .integer, "unk_token")) |unk_tok| { + unk = @intCast(unk_tok); + } + tokenizer.special_tokens = .{ + // TODO allow users to specify special tokens or read them from a tokenizer_config.json file .bos = tokenizer.lookup("") orelse tokenizer.lookup("<|begin_of_text|>") orelse @panic("bos token not found !"), .eos = tokenizer.lookup("") orelse tokenizer.lookup("<|end_of_text|>") orelse @panic("eos token not found !"), - .unk = tokenizer.lookup("") orelse std.math.maxInt(u32), + .unk = unk orelse std.math.maxInt(u32), }; - if (main_object.get("model").?.object.get("byte_fallback")) |byte_fallback| { - if (byte_fallback == .bool and byte_fallback.bool) { - try tokenizer.rewriteByteFallbackTokens(); + const byte_fallback = objectGet(model, .bool, "byte_fallback") orelse false; + if (!byte_fallback and unk == null) { + // GPT2 tokenizers have byte fallback already encoded in the model, + // but the json generally doesn't have the field set. + // We can detect it though because they don't specify an unknown token. + if (is_gpt2_vocab) { + tokenizer.byte_fallback = true; + } else { + log.warn("The given tokenizer can't handle unknown token: no unknown token was set, and byte_fallback is disabled too ! 
The tokenizer will panic when facing unknown tokens.", .{}); } + } else if (byte_fallback) { + try tokenizer.rewriteByteFallbackTokens(); } return tokenizer; } @@ -1016,7 +1040,7 @@ fn dup(buffer: *std.ArrayList(u8), str: []const u8) ![]const u8 { } /// Returns the given entry in a json object only if it has the right type. -fn object_get( +fn objectGet( object: std.json.ObjectMap, comptime kind: std.meta.FieldEnum(std.json.Value), key: []const u8,