Radix/zml/aio/sentencepiece.zig

90 lines
3.5 KiB
Zig
Raw Normal View History

const std = @import("std");
const asynk = @import("async");
const zml = @import("../zml.zig");
const sentencepiece_proto = @import("//sentencepiece:model_proto");
const Normalizer = zml.tokenizer.Normalizer;
const Tokenizer = zml.tokenizer.Tokenizer;
/// Opens the file at `path` and loads a SentencePiece tokenizer from its contents.
///
/// The file is always closed before returning. A failure to close it is logged
/// rather than treated as impossible, because closing a file is an I/O operation
/// that can fail.
/// Caller owns the returned `Tokenizer` (allocated with `allocator`).
pub fn loadTokenizerFromPath(allocator: std.mem.Allocator, path: []const u8) !Tokenizer {
    const file = try asynk.File.open(path, .{});
    // `catch unreachable` here would be UB in ReleaseFast if close ever failed.
    // The read has already happened by then, so a warning is enough.
    defer file.close() catch |err| {
        std.log.warn("failed to close tokenizer file \"{s}\": {s}", .{ path, @errorName(err) });
    };
    return loadTokenizerFromFile(allocator, file);
}
/// Reads a serialized SentencePiece `ModelProto` from `file` and builds a tokenizer.
///
/// At most 16 MiB is read. A larger input fails with the reader's "stream too
/// long" error.
/// The raw bytes and the decoded proto are both released before returning.
/// NOTE(review): this presumes `loadTokenizerFromModelProto` copies everything
/// it keeps (e.g. piece strings) into `allocator` memory — confirm.
pub fn loadTokenizerFromFile(allocator: std.mem.Allocator, file: asynk.File) !Tokenizer {
    const max_model_bytes = 16 * 1024 * 1024;

    const file_reader = file.reader();
    const serialized = try file_reader.readAllAlloc(allocator, max_model_bytes);
    defer allocator.free(serialized);

    // Every allocation made while decoding lives in this arena. Releasing the
    // arena frees the whole proto at once, so the model needs no deinit of its own.
    var decode_arena = std.heap.ArenaAllocator.init(allocator);
    defer decode_arena.deinit();

    const model = try sentencepiece_proto.ModelProto.decode(serialized, decode_arena.allocator());
    return loadTokenizerFromModelProto(allocator, model);
}
/// Builds a BPE `Tokenizer` from an already-decoded SentencePiece `ModelProto`.
///
/// Errors:
///   - `error.MissingTrainerSpec` if the proto has no `trainer_spec`.
///   - `error.UnsupportedModelType` if `trainer_spec.model_type` is missing or
///     is not `.BPE`.
///   - Any error from `Tokenizer.init`, `addToken`, or `rewriteByteFallbackTokens`.
/// Other required proto fields are still unwrapped with `.?` and panic in safe
/// builds when absent.
/// Caller owns the returned tokenizer and must `deinit` it. On error, the
/// partially built tokenizer is released here.
pub fn loadTokenizerFromModelProto(allocator: std.mem.Allocator, model: sentencepiece_proto.ModelProto) !Tokenizer {
    const trainer_spec = model.trainer_spec orelse return error.MissingTrainerSpec;

    // The model comes from an external file, so validate it with a real error.
    // A failing `std.debug.assert` is undefined behavior in ReleaseFast.
    const model_type = trainer_spec.model_type orelse return error.UnsupportedModelType;
    if (model_type != .BPE) {
        std.log.err("Only BPE SentencePiece models are supported, got model_type {}", .{model_type});
        return error.UnsupportedModelType;
    }

    // NOTE(review): SentencePiece allows bos_id/eos_id to be -1 (disabled).
    // `@intCast` to u32 traps on a negative id; confirm such models are out of scope.
    const special_tokens: Tokenizer.SpecialTokens = .{
        .unk = @intCast(trainer_spec.unk_id.?),
        .bos = @intCast(trainer_spec.bos_id.?),
        .eos = @intCast(trainer_spec.eos_id.?),
        .pad = parseTokenId(trainer_spec.pad_id),
    };

    var tokenizer = try Tokenizer.init(
        allocator,
        @intCast(model.pieces.items.len),
        @intCast(trainer_spec.max_sentencepiece_length.?),
        normalizerFromSpec(model.normalizer_spec.?),
        special_tokens,
        // NOTE(review): meaning of this flag is defined by `Tokenizer.init` — confirm.
        true,
    );
    errdefer tokenizer.deinit();

    // Register pieces in proto order, together with their scores.
    for (model.pieces.items) |*piece| {
        try tokenizer.addToken(piece.score.?, piece.piece.?.getSlice());
    }

    // An absent `byte_fallback` field is treated as disabled.
    if (trainer_spec.byte_fallback orelse false) {
        try tokenizer.rewriteByteFallbackTokens();
    }
    return tokenizer;
}
/// Converts an optional SentencePiece token id into a `u32`.
///
/// Returns `std.math.maxInt(u32)` as a "no such token" sentinel when the id is
/// absent (`null`) or negative. SentencePiece uses negative ids, e.g. `-1`, to
/// mean "disabled". Id 0 is a valid token (some models use `pad_id = 0`), so it
/// is returned as-is rather than treated as unset.
fn parseTokenId(id: ?i32) u32 {
    const idx = id orelse return std.math.maxInt(u32);
    // `std.math.cast` yields null only for negative values here, since any
    // non-negative i32 fits in u32.
    return std.math.cast(u32, idx) orelse std.math.maxInt(u32);
}
/// Translates a SentencePiece `NormalizerSpec` into a `Normalizer`.
///
/// The spec is logged at info level first. Only a narrow subset of specs is
/// accepted; anything else aborts via `std.debug.panic`:
///   - `normalization_rule_tsv` must be absent or empty;
///   - `name` must be "identity";
///   - `escape_whitespaces` must be true;
///   - `remove_extra_whitespaces` must be present. Either value is forwarded.
/// `name` and `escape_whitespaces` are unwrapped with `.?`, so a missing value
/// also panics in safe builds.
pub fn normalizerFromSpec(spec: sentencepiece_proto.NormalizerSpec) Normalizer {
    std.log.info("NormalizerSpec: {}", .{spec});

    // Custom normalization rule tables are not implemented.
    if (spec.normalization_rule_tsv) |rules| {
        if (!rules.isEmpty()) {
            std.debug.panic("SentencePiece model with normalization rules not supported: model.normalizer_spec.normalization_rule_tsv: {s}", .{rules.getSlice()});
        }
    }

    const name = spec.name.?.getSlice();
    if (!std.mem.eql(u8, name, "identity")) {
        std.debug.panic("Normalizer only supports NormalizerSpec with name \"identity\", got \"{s}\"", .{name});
    }

    const escape_whitespaces = spec.escape_whitespaces.?;
    if (!escape_whitespaces) {
        std.debug.panic("Normalizer only supports NormalizerSpec with \"escape_whitespaces\" flag set", .{});
    }

    // Only presence is required here; the actual value is passed through below.
    if (spec.remove_extra_whitespaces == null) {
        std.debug.panic("Normalizer only supports NormalizerSpec with \"remove_extra_whitespaces\" flag set", .{});
    }

    return Normalizer.init(
        .{
            .add_dummy_suffix = false,
            .lower_case_ascii = false,
            .split_on_punct_ascii = false,
            .remove_extra_whitespaces = spec.remove_extra_whitespaces orelse false,
            .add_dummy_prefix = spec.add_dummy_prefix orelse false,
        },
        if (escape_whitespaces) Normalizer.sentencepiece_space else null,
    );
}