Update tokenizer to handle byte_fallback for Llama3 GPT2 vocab and add a Llama3‑specific normalizer; adjust tinyllama.zig and hostbuffer.zig to use the new tokenization logic.
This commit is contained in:
parent
b643f7bc53
commit
b8a0aaee5a
@ -112,7 +112,7 @@ fn newBuff(store: *zml.aio.BufferStore, name: []const u8, sh: anytype, offset: u
|
|||||||
const n = shape.byteSize();
|
const n = shape.byteSize();
|
||||||
const buff = zml.HostBuffer.fromBytes(shape, store.files[0].data[offset..][0..n]);
|
const buff = zml.HostBuffer.fromBytes(shape, store.files[0].data[offset..][0..n]);
|
||||||
store.buffers.putAssumeCapacityNoClobber(name, buff);
|
store.buffers.putAssumeCapacityNoClobber(name, buff);
|
||||||
zml.log.info("Found {s}: {}", .{ name, shape });
|
zml.log.debug("Found {s}: {}", .{ name, shape });
|
||||||
return offset + n;
|
return offset + n;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -125,7 +125,7 @@ fn splitBuff(store: *zml.aio.BufferStore, comptime fmt: []const u8, sh: anytype,
|
|||||||
const buff = zml.HostBuffer.fromBytes(shape, store.files[0].data[off..][0..n]);
|
const buff = zml.HostBuffer.fromBytes(shape, store.files[0].data[off..][0..n]);
|
||||||
store.buffers.putAssumeCapacityNoClobber(name, buff);
|
store.buffers.putAssumeCapacityNoClobber(name, buff);
|
||||||
off += n;
|
off += n;
|
||||||
if (i == 0) zml.log.info("Found {s}: {}", .{ name, shape });
|
if (i == 0) zml.log.debug("Found {s}: {}", .{ name, shape });
|
||||||
}
|
}
|
||||||
return off;
|
return off;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -285,14 +285,20 @@ pub const HostBuffer = struct {
|
|||||||
fn prettyPrintIndented(self: HostBuffer, num_rows: u8, indent_level: u8, writer: anytype) !void {
|
fn prettyPrintIndented(self: HostBuffer, num_rows: u8, indent_level: u8, writer: anytype) !void {
|
||||||
if (self.rank() == 1) {
|
if (self.rank() == 1) {
|
||||||
try writer.writeByteNTimes(' ', indent_level);
|
try writer.writeByteNTimes(' ', indent_level);
|
||||||
switch (self.dtype()) {
|
return switch (self.dtype()) {
|
||||||
inline else => |dt| {
|
inline else => |dt| {
|
||||||
const values = self.items(dt.toZigType());
|
const values = self.items(dt.toZigType());
|
||||||
const n = @min(values.len, 1024);
|
// Write first rows
|
||||||
try writer.print("{any},\n", .{values[0..n]});
|
const num_cols: u32 = 12;
|
||||||
|
const n: u64 = @intCast(self.dim(0));
|
||||||
|
if (n <= num_cols) {
|
||||||
|
try writer.print("{any},\n", .{values[0..n]});
|
||||||
|
} else {
|
||||||
|
const half = @divExact(num_cols, 2);
|
||||||
|
try writer.print("{any}, ..., {any},\n", .{ values[0..half], values[n - half ..] });
|
||||||
|
}
|
||||||
},
|
},
|
||||||
}
|
};
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
try writer.writeByteNTimes(' ', indent_level);
|
try writer.writeByteNTimes(' ', indent_level);
|
||||||
_ = try writer.write("{\n");
|
_ = try writer.write("{\n");
|
||||||
|
|||||||
@ -164,7 +164,8 @@ pub const Tokenizer = struct {
|
|||||||
// Step by step visualization of the progress.
|
// Step by step visualization of the progress.
|
||||||
if (options.debug) {
|
if (options.debug) {
|
||||||
var _debug_buf: [256]u8 = undefined;
|
var _debug_buf: [256]u8 = undefined;
|
||||||
var debug_progress: std.ArrayList(u8) = .{ .items = _debug_buf[0..0], .capacity = _debug_buf.len, .allocator = undefined };
|
var _debug_alloc = std.heap.FixedBufferAllocator.init(&_debug_buf);
|
||||||
|
var debug_progress = std.ArrayList(u8).init(_debug_alloc.allocator());
|
||||||
self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {};
|
self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {};
|
||||||
log.debug("tokens: {d} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items });
|
log.debug("tokens: {d} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items });
|
||||||
}
|
}
|
||||||
@ -305,7 +306,7 @@ pub const Tokenizer = struct {
|
|||||||
// Convert `▁` to a regular space.
|
// Convert `▁` to a regular space.
|
||||||
if (escaped) |escspc| {
|
if (escaped) |escspc| {
|
||||||
// we modify piece inside the loop, so we can use it in the condition
|
// we modify piece inside the loop, so we can use it in the condition
|
||||||
while (std.mem.startsWith(u8, piece, escaped.?)) {
|
while (std.mem.startsWith(u8, piece, escspc)) {
|
||||||
piece = piece[escspc.len..];
|
piece = piece[escspc.len..];
|
||||||
// don't output a space at beginning of text.
|
// don't output a space at beginning of text.
|
||||||
if (output.items.len > 0) try output.append(' ');
|
if (output.items.len > 0) try output.append(' ');
|
||||||
@ -625,6 +626,13 @@ pub const Normalizer = struct {
|
|||||||
.lower_case_ascii = false,
|
.lower_case_ascii = false,
|
||||||
.split_on_punct_ascii = false,
|
.split_on_punct_ascii = false,
|
||||||
}, sentencepiece_space),
|
}, sentencepiece_space),
|
||||||
|
.llama3 => init(.{
|
||||||
|
.remove_extra_whitespaces = true,
|
||||||
|
.add_dummy_prefix = false,
|
||||||
|
.add_dummy_suffix = false,
|
||||||
|
.lower_case_ascii = false,
|
||||||
|
.split_on_punct_ascii = false,
|
||||||
|
}, null),
|
||||||
.gpt2 => init(.{
|
.gpt2 => init(.{
|
||||||
.remove_extra_whitespaces = true,
|
.remove_extra_whitespaces = true,
|
||||||
.add_dummy_prefix = true,
|
.add_dummy_prefix = true,
|
||||||
@ -645,7 +653,7 @@ pub const Normalizer = struct {
|
|||||||
} };
|
} };
|
||||||
|
|
||||||
// Normalizer config can be a single normalizer, or a sequence of normalizers.
|
// Normalizer config can be a single normalizer, or a sequence of normalizers.
|
||||||
const maybe_steps = object_get(config, .array, "normalizers");
|
const maybe_steps = objectGet(config, .array, "normalizers");
|
||||||
const steps = if (maybe_steps) |st| st.items else &.{std.json.Value{ .object = config }};
|
const steps = if (maybe_steps) |st| st.items else &.{std.json.Value{ .object = config }};
|
||||||
|
|
||||||
for (steps) |step_val| {
|
for (steps) |step_val| {
|
||||||
@ -654,7 +662,7 @@ pub const Normalizer = struct {
|
|||||||
}
|
}
|
||||||
const step = step_val.object;
|
const step = step_val.object;
|
||||||
|
|
||||||
const step_type = object_get(step, .string, "type") orelse {
|
const step_type = objectGet(step, .string, "type") orelse {
|
||||||
return error.InvalidNormalizerJson;
|
return error.InvalidNormalizerJson;
|
||||||
};
|
};
|
||||||
if (std.mem.eql(u8, "Prepend", step_type)) {
|
if (std.mem.eql(u8, "Prepend", step_type)) {
|
||||||
@ -662,15 +670,15 @@ pub const Normalizer = struct {
|
|||||||
} else if (std.mem.eql(u8, "Append", step_type)) {
|
} else if (std.mem.eql(u8, "Append", step_type)) {
|
||||||
normalizer.flags.add_dummy_suffix = true;
|
normalizer.flags.add_dummy_suffix = true;
|
||||||
} else if (std.mem.eql(u8, "Replace", step_type)) {
|
} else if (std.mem.eql(u8, "Replace", step_type)) {
|
||||||
const pattern = object_get(step, .object, "pattern") orelse return error.InvalidNormalizerJson;
|
const pattern = objectGet(step, .object, "pattern") orelse return error.InvalidNormalizerJson;
|
||||||
const str_pattern = object_get(pattern, .string, "String") orelse return error.InvalidNormalizerJson;
|
const str_pattern = objectGet(pattern, .string, "String") orelse return error.InvalidNormalizerJson;
|
||||||
|
|
||||||
if (std.mem.eql(u8, str_pattern, " ")) {
|
if (std.mem.eql(u8, str_pattern, " ")) {
|
||||||
normalizer._whitespace.appendSliceAssumeCapacity(
|
normalizer._whitespace.appendSliceAssumeCapacity(
|
||||||
object_get(step, .string, "content") orelse return error.InvalidNormalizerJson,
|
objectGet(step, .string, "content") orelse return error.InvalidNormalizerJson,
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
log.warn("Normalizer Replace pattern not supported: '{s}' -> '{s}'", .{ str_pattern, object_get(pattern, .string, "content") orelse "" });
|
log.warn("Normalizer Replace pattern not supported: '{s}' -> '{s}'", .{ str_pattern, objectGet(pattern, .string, "content") orelse "" });
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.warn("Unknown normalizer type: {s}", .{step_type});
|
log.warn("Unknown normalizer type: {s}", .{step_type});
|
||||||
@ -721,6 +729,7 @@ pub const Normalizer = struct {
|
|||||||
pub const KnownImplementation = enum(u8) {
|
pub const KnownImplementation = enum(u8) {
|
||||||
sentencepiece,
|
sentencepiece,
|
||||||
gpt2,
|
gpt2,
|
||||||
|
llama3,
|
||||||
};
|
};
|
||||||
|
|
||||||
fn isPunct(unicode_char: []const u8) bool {
|
fn isPunct(unicode_char: []const u8) bool {
|
||||||
@ -920,15 +929,15 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
|
|||||||
else => return error.InvalidFormat,
|
else => return error.InvalidFormat,
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: remove all panics
|
const model = objectGet(main_object, .object, "model") orelse return error.InvalidFormat;
|
||||||
const added_tokens = main_object.get("added_tokens").?.array;
|
const vocab = objectGet(model, .object, "vocab") orelse return error.InvalidFormat;
|
||||||
const vocab = main_object.get("model").?.object.get("vocab").?.object;
|
const added_tokens = if (objectGet(main_object, .array, "added_tokens")) |added| added.items else &.{};
|
||||||
const vocab_size: u32 = @intCast(vocab.count() + added_tokens.items.len);
|
const vocab_size: u32 = @intCast(vocab.count() + added_tokens.len);
|
||||||
|
|
||||||
const normalizer = if (object_get(main_object, .object, "normalizer")) |normalizer_config|
|
const normalizer = if (objectGet(main_object, .object, "normalizer")) |normalizer_config|
|
||||||
try Normalizer.fromHfJson(normalizer_config)
|
try Normalizer.fromHfJson(normalizer_config)
|
||||||
else
|
else
|
||||||
Normalizer.wellKnown(.gpt2);
|
Normalizer.wellKnown(.llama3);
|
||||||
|
|
||||||
// delay init of special tokens.
|
// delay init of special tokens.
|
||||||
var tokenizer = try Tokenizer.init(allocator, vocab_size, 256, normalizer, undefined, true);
|
var tokenizer = try Tokenizer.init(allocator, vocab_size, 256, normalizer, undefined, true);
|
||||||
@ -981,9 +990,10 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
|
|||||||
}
|
}
|
||||||
|
|
||||||
// More tokens, typically added during fine tuning of the model.
|
// More tokens, typically added during fine tuning of the model.
|
||||||
for (added_tokens.items) |token_obj| {
|
for (added_tokens) |token_obj| {
|
||||||
const v = token_obj.object.get("content").?.string;
|
if (token_obj != .object) return error.InvalidFormat;
|
||||||
const id: u32 = @intCast(token_obj.object.get("id").?.integer);
|
const v = objectGet(token_obj.object, .string, "content") orelse return error.InvalidFormat;
|
||||||
|
const id: u32 = @intCast(objectGet(token_obj.object, .integer, "id") orelse return error.InvalidFormat);
|
||||||
const token = try if (is_gpt2_vocab)
|
const token = try if (is_gpt2_vocab)
|
||||||
gpt2_decoder.decode(&all_tokens, v)
|
gpt2_decoder.decode(&all_tokens, v)
|
||||||
else
|
else
|
||||||
@ -994,16 +1004,30 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
|
|||||||
// No more tokens will be added past this point, so release the unused capacity.
|
// No more tokens will be added past this point, so release the unused capacity.
|
||||||
all_tokens.shrinkAndFree(all_tokens.items.len);
|
all_tokens.shrinkAndFree(all_tokens.items.len);
|
||||||
|
|
||||||
|
var unk = tokenizer.lookup("<unk>");
|
||||||
|
if (objectGet(model, .integer, "unk_token")) |unk_tok| {
|
||||||
|
unk = @intCast(unk_tok);
|
||||||
|
}
|
||||||
|
|
||||||
tokenizer.special_tokens = .{
|
tokenizer.special_tokens = .{
|
||||||
|
// TODO allow users to specify special tokens or read them from a tokenizer_config.json file
|
||||||
.bos = tokenizer.lookup("<s>") orelse tokenizer.lookup("<|begin_of_text|>") orelse @panic("bos token not found !"),
|
.bos = tokenizer.lookup("<s>") orelse tokenizer.lookup("<|begin_of_text|>") orelse @panic("bos token not found !"),
|
||||||
.eos = tokenizer.lookup("</s>") orelse tokenizer.lookup("<|end_of_text|>") orelse @panic("eos token not found !"),
|
.eos = tokenizer.lookup("</s>") orelse tokenizer.lookup("<|end_of_text|>") orelse @panic("eos token not found !"),
|
||||||
.unk = tokenizer.lookup("<unk>") orelse std.math.maxInt(u32),
|
.unk = unk orelse std.math.maxInt(u32),
|
||||||
};
|
};
|
||||||
|
|
||||||
if (main_object.get("model").?.object.get("byte_fallback")) |byte_fallback| {
|
const byte_fallback = objectGet(model, .bool, "byte_fallback") orelse false;
|
||||||
if (byte_fallback == .bool and byte_fallback.bool) {
|
if (!byte_fallback and unk == null) {
|
||||||
try tokenizer.rewriteByteFallbackTokens();
|
// GPT2 tokenizers already have byte fallback encoded in the model,
|
||||||
|
// but the JSON generally doesn't have the field set.
|
||||||
|
// We can detect it, though, because they don't specify an unknown token.
|
||||||
|
if (is_gpt2_vocab) {
|
||||||
|
tokenizer.byte_fallback = true;
|
||||||
|
} else {
|
||||||
|
log.warn("The given tokenizer can't handle unknown token: no unknown token was set, and byte_fallback is disabled too ! The tokenizer will panic when facing unknown tokens.", .{});
|
||||||
}
|
}
|
||||||
|
} else if (byte_fallback) {
|
||||||
|
try tokenizer.rewriteByteFallbackTokens();
|
||||||
}
|
}
|
||||||
return tokenizer;
|
return tokenizer;
|
||||||
}
|
}
|
||||||
@ -1016,7 +1040,7 @@ fn dup(buffer: *std.ArrayList(u8), str: []const u8) ![]const u8 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the given entry in a json object only if it has the right type.
|
/// Returns the given entry in a json object only if it has the right type.
|
||||||
fn object_get(
|
fn objectGet(
|
||||||
object: std.json.ObjectMap,
|
object: std.json.ObjectMap,
|
||||||
comptime kind: std.meta.FieldEnum(std.json.Value),
|
comptime kind: std.meta.FieldEnum(std.json.Value),
|
||||||
key: []const u8,
|
key: []const u8,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user