//! Radix/zml/tokenizer/hftokenizers/hftokenizers.zig
//! Zig bindings over the HuggingFace `tokenizers` library, reached through a
//! C FFI layer (`hftokenizers_*` functions).

const std = @import("std");
const c = @import("c");
const ffi = @import("ffi");
const stdx = @import("stdx");
pub const Encoder = struct {
    inner: *HFTokenizer,
    current_ids: ?[]const u32 = null,

    /// Wrap a borrowed tokenizer handle. Ownership of `inner` stays with the
    /// caller.
    fn init(inner: *HFTokenizer) Encoder {
        return .{ .inner = inner };
    }

    /// Hand the previously returned id buffer back to the FFI side, if any,
    /// and clear the cached slice.
    pub fn reset(self: *Encoder) void {
        const held = self.current_ids orelse return;
        c.hftokenizers_tokens_drop(ffi.ZigSlice.from(held));
        self.current_ids = null;
    }

    /// Release any outstanding FFI-owned id buffer. The tokenizer handle
    /// itself is not freed here.
    pub fn deinit(self: *Encoder) void {
        self.reset();
    }

    /// Tokenize `input` into ids. The returned slice is owned by the FFI side
    /// and stays valid only until the next `encode`/`reset`/`deinit` call on
    /// this encoder.
    pub fn encode(self: *Encoder, input: []const u8) ![]const u32 {
        self.reset();
        const raw = c.hftokenizers_encode(@ptrCast(self.inner), ffi.ZigSlice.from(input));
        self.current_ids = ffi.ZigSlice.to(u32, raw);
        return self.ids();
    }

    /// Ids produced by the most recent `encode`, or an empty slice when none
    /// are held.
    pub fn ids(self: *const Encoder) []const u32 {
        return self.current_ids orelse &.{};
    }
};
/// Streaming token-to-text decoder. `decode` does a one-shot decode of a full
/// id slice; `next` supports incremental decoding, emitting only the newly
/// produced text for each appended token id.
pub const Decoder = struct {
// Scratch capacity for the previously emitted string (128 bytes) and the
// sliding window of recent token ids (4) used by incremental decoding.
const StringBuffer = stdx.BoundedArray(u8, 128);
const TokensIdsBuffer = stdx.BoundedArray(u32, 4);
// Borrowed tokenizer handle; not owned (freed) by the decoder.
inner: *HFTokenizer,
// Last string returned by the FFI decode call; owned by the FFI side and
// released via `hftokenizers_str_drop` in `reset`.
current_string: ?[]const u8 = null,
// Local copy of the previous decode of the id window, used by `next` to
// compute the newly produced suffix.
last_string: StringBuffer = .{ .len = 0 },
// Sliding window of the most recent token ids fed to `next`.
last_token_ids: TokensIdsBuffer = .{ .len = 0 },
fn init(inner: *HFTokenizer) Decoder {
return .{ .inner = inner };
}
// Releases the FFI-owned string, if any. Does not free the tokenizer.
pub fn deinit(self: *Decoder) void {
self.reset();
}
// Return the currently held FFI string to the FFI side and clear it.
pub fn reset(self: *Decoder) void {
if (self.current_string) |current_string_| {
c.hftokenizers_str_drop(ffi.ZigSlice.from(current_string_));
self.current_string = null;
}
}
/// Decode `ids` to text. The returned slice is owned by the FFI side and is
/// valid only until the next `decode`/`reset`/`deinit` on this decoder.
pub fn decode(self: *Decoder, ids: []const u32) ![]const u8 {
self.reset();
self.current_string = ffi.ZigSlice.to(u8, c.hftokenizers_decode(@ptrCast(self.inner), ffi.ZigSlice.from(ids)));
return self.string();
}
/// Text produced by the most recent `decode`, or an empty slice.
pub fn string(self: *const Decoder) []const u8 {
return self.current_string orelse &.{};
}
/// Incremental decode: feed one token id, get back only the newly produced
/// text, or null when no new text can be emitted yet for this step.
pub fn next(self: *Decoder, token_id: u32) !?[]const u8 {
// Keep at most capacity() recent ids: drop the oldest before appending.
if (self.last_token_ids.len >= self.last_token_ids.capacity()) {
_ = self.last_token_ids.orderedRemove(0);
}
self.last_token_ids.appendAssumeCapacity(token_id);
// Re-decode the whole window; `new_string` is FFI-owned until the next
// `decode`/`reset` call, so chunks returned below must be consumed promptly.
const new_string = try self.decode(self.last_token_ids.constSlice());
// First emission (or nothing remembered yet): return everything decoded.
if (self.last_string.len == 0) {
self.last_string = try StringBuffer.fromSlice(new_string);
return new_string;
}
// Find the longest suffix of the previous decode — starting at a UTF-8
// codepoint boundary — that is a prefix of the new decode; the bytes of
// `new_string` past that overlap are the newly generated text.
var view = try std.unicode.Utf8View.init(self.last_string.constSlice());
var it = view.iterator();
while (it.nextCodepointSlice()) |cp| {
// `start` is the byte offset of the codepoint just consumed, so the
// first iteration tests the whole previous string (longest suffix).
const start = it.i - cp.len;
if (std.mem.startsWith(u8, new_string, self.last_string.constSlice()[start..])) {
// `last_string.len - start` bytes overlap; the remainder is new.
const chunk = new_string[self.last_string.len - start ..];
self.last_string = try StringBuffer.fromSlice(new_string);
return chunk;
}
}
// No overlap found: the window's decode diverged from the previous output
// (e.g. pending ids don't yet form displayable text). Emit nothing now.
// NOTE(review): `last_string` is intentionally left unchanged here so the
// next token can still be matched against the last emitted text — confirm.
return null;
}
};
/// Opaque handle to a tokenizer living on the FFI (Rust) side.
pub const HFTokenizer = opaque {
    /// Load a tokenizer from the serialized model file at `model`.
    pub fn fromFile(model: []const u8) !*HFTokenizer {
        const handle = c.hftokenizers_new(ffi.ZigSlice.from(model));
        return @ptrCast(handle);
    }

    /// Free the underlying FFI tokenizer. Encoders/decoders created from this
    /// handle must not be used afterwards.
    pub fn deinit(self: *HFTokenizer) void {
        c.hftokenizers_drop(@ptrCast(self));
    }

    /// Create an encoder bound to this tokenizer (borrowed, not owned).
    pub fn encoder(self: *HFTokenizer) !Encoder {
        return Encoder.init(self);
    }

    /// Create a decoder bound to this tokenizer (borrowed, not owned).
    pub fn decoder(self: *HFTokenizer) !Decoder {
        return Decoder.init(self);
    }

    /// Look up the id of `token`, or null when the token is unknown.
    /// The FFI signals "not found" with the sentinel `maxInt(u32)`.
    pub fn tokenToId(self: *HFTokenizer, token: []const u8) ?u32 {
        const raw_id = c.hftokenizers_token_to_id(@ptrCast(self), ffi.ZigSlice.from(token));
        return switch (raw_id) {
            std.math.maxInt(u32) => null,
            else => raw_id,
        };
    }
};