Update tokenizer to handle byte_fallback for Llama3 GPT2 vocab and add a Llama3‑specific normalizer; adjust tinyllama.zig and hostbuffer.zig to use the new tokenization logic.
This commit is contained in:
parent
b643f7bc53
commit
b8a0aaee5a
@ -112,7 +112,7 @@ fn newBuff(store: *zml.aio.BufferStore, name: []const u8, sh: anytype, offset: u
|
||||
const n = shape.byteSize();
|
||||
const buff = zml.HostBuffer.fromBytes(shape, store.files[0].data[offset..][0..n]);
|
||||
store.buffers.putAssumeCapacityNoClobber(name, buff);
|
||||
zml.log.info("Found {s}: {}", .{ name, shape });
|
||||
zml.log.debug("Found {s}: {}", .{ name, shape });
|
||||
return offset + n;
|
||||
}
|
||||
|
||||
@ -125,7 +125,7 @@ fn splitBuff(store: *zml.aio.BufferStore, comptime fmt: []const u8, sh: anytype,
|
||||
const buff = zml.HostBuffer.fromBytes(shape, store.files[0].data[off..][0..n]);
|
||||
store.buffers.putAssumeCapacityNoClobber(name, buff);
|
||||
off += n;
|
||||
if (i == 0) zml.log.info("Found {s}: {}", .{ name, shape });
|
||||
if (i == 0) zml.log.debug("Found {s}: {}", .{ name, shape });
|
||||
}
|
||||
return off;
|
||||
}
|
||||
|
||||
@ -285,14 +285,20 @@ pub const HostBuffer = struct {
|
||||
fn prettyPrintIndented(self: HostBuffer, num_rows: u8, indent_level: u8, writer: anytype) !void {
|
||||
if (self.rank() == 1) {
|
||||
try writer.writeByteNTimes(' ', indent_level);
|
||||
switch (self.dtype()) {
|
||||
return switch (self.dtype()) {
|
||||
inline else => |dt| {
|
||||
const values = self.items(dt.toZigType());
|
||||
const n = @min(values.len, 1024);
|
||||
// Write first rows
|
||||
const num_cols: u32 = 12;
|
||||
const n: u64 = @intCast(self.dim(0));
|
||||
if (n <= num_cols) {
|
||||
try writer.print("{any},\n", .{values[0..n]});
|
||||
},
|
||||
} else {
|
||||
const half = @divExact(num_cols, 2);
|
||||
try writer.print("{any}, ..., {any},\n", .{ values[0..half], values[n - half ..] });
|
||||
}
|
||||
return;
|
||||
},
|
||||
};
|
||||
}
|
||||
try writer.writeByteNTimes(' ', indent_level);
|
||||
_ = try writer.write("{\n");
|
||||
|
||||
@ -164,7 +164,8 @@ pub const Tokenizer = struct {
|
||||
// Step by step visualization of the progress.
|
||||
if (options.debug) {
|
||||
var _debug_buf: [256]u8 = undefined;
|
||||
var debug_progress: std.ArrayList(u8) = .{ .items = _debug_buf[0..0], .capacity = _debug_buf.len, .allocator = undefined };
|
||||
var _debug_alloc = std.heap.FixedBufferAllocator.init(&_debug_buf);
|
||||
var debug_progress = std.ArrayList(u8).init(_debug_alloc.allocator());
|
||||
self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {};
|
||||
log.debug("tokens: {d} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items });
|
||||
}
|
||||
@ -305,7 +306,7 @@ pub const Tokenizer = struct {
|
||||
// Convert `▁` to a regular space.
|
||||
if (escaped) |escspc| {
|
||||
// we modify piece inside the loop, so we can use it in the condition
|
||||
while (std.mem.startsWith(u8, piece, escaped.?)) {
|
||||
while (std.mem.startsWith(u8, piece, escspc)) {
|
||||
piece = piece[escspc.len..];
|
||||
// don't output a space at beginning of text.
|
||||
if (output.items.len > 0) try output.append(' ');
|
||||
@ -625,6 +626,13 @@ pub const Normalizer = struct {
|
||||
.lower_case_ascii = false,
|
||||
.split_on_punct_ascii = false,
|
||||
}, sentencepiece_space),
|
||||
.llama3 => init(.{
|
||||
.remove_extra_whitespaces = true,
|
||||
.add_dummy_prefix = false,
|
||||
.add_dummy_suffix = false,
|
||||
.lower_case_ascii = false,
|
||||
.split_on_punct_ascii = false,
|
||||
}, null),
|
||||
.gpt2 => init(.{
|
||||
.remove_extra_whitespaces = true,
|
||||
.add_dummy_prefix = true,
|
||||
@ -645,7 +653,7 @@ pub const Normalizer = struct {
|
||||
} };
|
||||
|
||||
// Normalizer config can be a single normalizer, or a sequence of normalizers.
|
||||
const maybe_steps = object_get(config, .array, "normalizers");
|
||||
const maybe_steps = objectGet(config, .array, "normalizers");
|
||||
const steps = if (maybe_steps) |st| st.items else &.{std.json.Value{ .object = config }};
|
||||
|
||||
for (steps) |step_val| {
|
||||
@ -654,7 +662,7 @@ pub const Normalizer = struct {
|
||||
}
|
||||
const step = step_val.object;
|
||||
|
||||
const step_type = object_get(step, .string, "type") orelse {
|
||||
const step_type = objectGet(step, .string, "type") orelse {
|
||||
return error.InvalidNormalizerJson;
|
||||
};
|
||||
if (std.mem.eql(u8, "Prepend", step_type)) {
|
||||
@ -662,15 +670,15 @@ pub const Normalizer = struct {
|
||||
} else if (std.mem.eql(u8, "Append", step_type)) {
|
||||
normalizer.flags.add_dummy_suffix = true;
|
||||
} else if (std.mem.eql(u8, "Replace", step_type)) {
|
||||
const pattern = object_get(step, .object, "pattern") orelse return error.InvalidNormalizerJson;
|
||||
const str_pattern = object_get(pattern, .string, "String") orelse return error.InvalidNormalizerJson;
|
||||
const pattern = objectGet(step, .object, "pattern") orelse return error.InvalidNormalizerJson;
|
||||
const str_pattern = objectGet(pattern, .string, "String") orelse return error.InvalidNormalizerJson;
|
||||
|
||||
if (std.mem.eql(u8, str_pattern, " ")) {
|
||||
normalizer._whitespace.appendSliceAssumeCapacity(
|
||||
object_get(step, .string, "content") orelse return error.InvalidNormalizerJson,
|
||||
objectGet(step, .string, "content") orelse return error.InvalidNormalizerJson,
|
||||
);
|
||||
} else {
|
||||
log.warn("Normalizer Replace pattern not supported: '{s}' -> '{s}'", .{ str_pattern, object_get(pattern, .string, "content") orelse "" });
|
||||
log.warn("Normalizer Replace pattern not supported: '{s}' -> '{s}'", .{ str_pattern, objectGet(pattern, .string, "content") orelse "" });
|
||||
}
|
||||
} else {
|
||||
log.warn("Unknown normalizer type: {s}", .{step_type});
|
||||
@ -721,6 +729,7 @@ pub const Normalizer = struct {
|
||||
pub const KnownImplementation = enum(u8) {
|
||||
sentencepiece,
|
||||
gpt2,
|
||||
llama3,
|
||||
};
|
||||
|
||||
fn isPunct(unicode_char: []const u8) bool {
|
||||
@ -920,15 +929,15 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
|
||||
else => return error.InvalidFormat,
|
||||
};
|
||||
|
||||
// TODO: remove all panics
|
||||
const added_tokens = main_object.get("added_tokens").?.array;
|
||||
const vocab = main_object.get("model").?.object.get("vocab").?.object;
|
||||
const vocab_size: u32 = @intCast(vocab.count() + added_tokens.items.len);
|
||||
const model = objectGet(main_object, .object, "model") orelse return error.InvalidFormat;
|
||||
const vocab = objectGet(model, .object, "vocab") orelse return error.InvalidFormat;
|
||||
const added_tokens = if (objectGet(main_object, .array, "added_tokens")) |added| added.items else &.{};
|
||||
const vocab_size: u32 = @intCast(vocab.count() + added_tokens.len);
|
||||
|
||||
const normalizer = if (object_get(main_object, .object, "normalizer")) |normalizer_config|
|
||||
const normalizer = if (objectGet(main_object, .object, "normalizer")) |normalizer_config|
|
||||
try Normalizer.fromHfJson(normalizer_config)
|
||||
else
|
||||
Normalizer.wellKnown(.gpt2);
|
||||
Normalizer.wellKnown(.llama3);
|
||||
|
||||
// delay init of special tokens.
|
||||
var tokenizer = try Tokenizer.init(allocator, vocab_size, 256, normalizer, undefined, true);
|
||||
@ -981,9 +990,10 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
|
||||
}
|
||||
|
||||
// More tokens, typically added during fine tuning of the model.
|
||||
for (added_tokens.items) |token_obj| {
|
||||
const v = token_obj.object.get("content").?.string;
|
||||
const id: u32 = @intCast(token_obj.object.get("id").?.integer);
|
||||
for (added_tokens) |token_obj| {
|
||||
if (token_obj != .object) return error.InvalidFormat;
|
||||
const v = objectGet(token_obj.object, .string, "content") orelse return error.InvalidFormat;
|
||||
const id: u32 = @intCast(objectGet(token_obj.object, .integer, "id") orelse return error.InvalidFormat);
|
||||
const token = try if (is_gpt2_vocab)
|
||||
gpt2_decoder.decode(&all_tokens, v)
|
||||
else
|
||||
@ -994,16 +1004,30 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
|
||||
// We won't add more tokens here, let release.
|
||||
all_tokens.shrinkAndFree(all_tokens.items.len);
|
||||
|
||||
var unk = tokenizer.lookup("<unk>");
|
||||
if (objectGet(model, .integer, "unk_token")) |unk_tok| {
|
||||
unk = @intCast(unk_tok);
|
||||
}
|
||||
|
||||
tokenizer.special_tokens = .{
|
||||
// TODO allow users to specify special tokens or read them from a tokenizer_config.json file
|
||||
.bos = tokenizer.lookup("<s>") orelse tokenizer.lookup("<|begin_of_text|>") orelse @panic("bos token not found !"),
|
||||
.eos = tokenizer.lookup("</s>") orelse tokenizer.lookup("<|end_of_text|>") orelse @panic("eos token not found !"),
|
||||
.unk = tokenizer.lookup("<unk>") orelse std.math.maxInt(u32),
|
||||
.unk = unk orelse std.math.maxInt(u32),
|
||||
};
|
||||
|
||||
if (main_object.get("model").?.object.get("byte_fallback")) |byte_fallback| {
|
||||
if (byte_fallback == .bool and byte_fallback.bool) {
|
||||
try tokenizer.rewriteByteFallbackTokens();
|
||||
const byte_fallback = objectGet(model, .bool, "byte_fallback") orelse false;
|
||||
if (!byte_fallback and unk == null) {
|
||||
// GPT2 tokenizer have byte fallback already encoded in the model,
|
||||
// but the json generally don't have the field set.
|
||||
// We can detect it though because they don't specify an unknown token.
|
||||
if (is_gpt2_vocab) {
|
||||
tokenizer.byte_fallback = true;
|
||||
} else {
|
||||
log.warn("The given tokenizer can't handle unknown token: no unknown token was set, and byte_fallback is disabled too ! The tokenizer will panic when facing unknown tokens.", .{});
|
||||
}
|
||||
} else if (byte_fallback) {
|
||||
try tokenizer.rewriteByteFallbackTokens();
|
||||
}
|
||||
return tokenizer;
|
||||
}
|
||||
@ -1016,7 +1040,7 @@ fn dup(buffer: *std.ArrayList(u8), str: []const u8) ![]const u8 {
|
||||
}
|
||||
|
||||
/// Returns the given entry in a json object only if it has the right type.
|
||||
fn object_get(
|
||||
fn objectGet(
|
||||
object: std.json.ObjectMap,
|
||||
comptime kind: std.meta.FieldEnum(std.json.Value),
|
||||
key: []const u8,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user