Radix/zml/tokenizer/sentencepiece/sentencepiece.zig

192 lines
6.6 KiB
Zig
Raw Permalink Normal View History

const std = @import("std");
const c = @import("c");
const ffi = @import("ffi");
const stdx = @import("stdx");
const StringToTokenRatio = 3;
pub const Error = error{
Cancelled,
Unknown,
InvalidArgument,
DeadlineExceeded,
NotFound,
AlreadyExists,
PermissionDenied,
ResourceExhausted,
FailedPrecondition,
Aborted,
OutOfRange,
Unimplemented,
Internal,
Unavailable,
DataLoss,
Unauthenticated,
};
fn assertOk(code: c.sentencepiece_util_StatusCode) Error!void {
return switch (code) {
c.sentencepiece_util_StatusCode_kOk => {},
c.sentencepiece_util_StatusCode_kCancelled => Error.Cancelled,
c.sentencepiece_util_StatusCode_kUnknown => Error.Unknown,
c.sentencepiece_util_StatusCode_kInvalidArgument => Error.InvalidArgument,
c.sentencepiece_util_StatusCode_kDeadlineExceeded => Error.DeadlineExceeded,
c.sentencepiece_util_StatusCode_kNotFound => Error.NotFound,
c.sentencepiece_util_StatusCode_kAlreadyExists => Error.AlreadyExists,
c.sentencepiece_util_StatusCode_kPermissionDenied => Error.PermissionDenied,
c.sentencepiece_util_StatusCode_kResourceExhausted => Error.ResourceExhausted,
c.sentencepiece_util_StatusCode_kFailedPrecondition => Error.FailedPrecondition,
c.sentencepiece_util_StatusCode_kAborted => Error.Aborted,
c.sentencepiece_util_StatusCode_kOutOfRange => Error.OutOfRange,
c.sentencepiece_util_StatusCode_kUnimplemented => Error.Unimplemented,
c.sentencepiece_util_StatusCode_kInternal => Error.Internal,
c.sentencepiece_util_StatusCode_kUnavailable => Error.Unavailable,
c.sentencepiece_util_StatusCode_kDataLoss => Error.DataLoss,
c.sentencepiece_util_StatusCode_kUnauthenticated => Error.Unauthenticated,
else => unreachable,
};
}
pub const Encoder = struct {
inner: *SentencePieceProcessor,
vec: *c.std_vector_int,
fn init(inner: *SentencePieceProcessor) Encoder {
return .{
.inner = inner,
.vec = c.std_vector_int_new() orelse unreachable,
};
}
pub fn deinit(self: *Encoder) void {
c.std_vector_int_delete(self.vec);
}
pub fn reset(self: *Encoder) void {
c.std_vector_int_clear(self.vec);
}
pub fn encode(self: *Encoder, input: []const u8) ![]const u32 {
c.std_vector_int_reserve(self.vec, input.len / StringToTokenRatio);
try assertOk(c.SentencePieceProcessor_Encode(@ptrCast(self.inner), ffi.ZigSlice.from(input), self.vec));
return self.ids();
}
pub fn ids(self: *const Encoder) []const u32 {
return ffi.ZigSlice.to(u32, .{
.ptr = c.std_vector_int_data(self.vec),
.len = c.std_vector_int_size(self.vec),
});
}
};
pub const Decoder = struct {
const StringBufferSize = 64;
const StringBuffer = stdx.BoundedArray(u8, StringBufferSize);
const TokenIdsBufferSize = 4;
inner: *SentencePieceProcessor,
vec: *c.std_vector_int,
str: *c.std_string,
last_string: StringBuffer = .{ .len = 0 },
fn init(inner: *SentencePieceProcessor) !Decoder {
const vec = try (c.std_vector_int_new() orelse std.mem.Allocator.Error.OutOfMemory);
c.std_vector_int_reserve(vec, TokenIdsBufferSize);
errdefer c.std_vector_int_delete(vec);
const str = try (c.std_string_new() orelse std.mem.Allocator.Error.OutOfMemory);
c.std_string_reserve(str, StringBufferSize);
errdefer c.std_string_delete(str);
return .{
.inner = inner,
.vec = vec,
.str = str,
};
}
pub fn deinit(self: *Decoder) void {
c.std_vector_int_delete(self.vec);
c.std_string_delete(self.str);
}
pub fn reset(self: *Decoder) void {
c.std_vector_int_clear(self.vec);
c.std_string_clear(self.str);
}
pub fn decode(self: *Decoder, ids_: []const u32) ![]const u8 {
c.std_vector_int_reserve(self.vec, ids_.len);
c.std_string_reserve(self.str, ids_.len * StringToTokenRatio);
for (ids_) |id| {
c.std_vector_int_push_back(self.vec, @intCast(id));
}
try assertOk(c.SentencePieceProcessor_Decode(@ptrCast(self.inner), self.vec, self.str));
return self.string();
}
pub fn string(self: *const Decoder) []const u8 {
const res = c.std_string_data(self.str);
return ffi.ZigSlice.to(u8, res);
}
fn ids(self: *const Decoder) []u32 {
const ptr: [*c]u32 = @ptrCast(c.std_vector_int_data(self.vec));
return ptr[0..c.std_vector_int_size(self.vec)];
}
pub fn next(self: *Decoder, token_id: u32) !?[]const u8 {
const current_ids = self.ids();
if (current_ids.len >= c.std_vector_int_capacity(self.vec)) {
std.mem.copyForwards(u32, current_ids[0 .. current_ids.len - 1], current_ids[1..]);
current_ids[current_ids.len - 1] = token_id;
} else {
c.std_vector_int_push_back(self.vec, @intCast(token_id));
}
try assertOk(c.SentencePieceProcessor_Decode(@ptrCast(self.inner), self.vec, self.str));
const new_string = self.string();
if (self.last_string.len == 0) {
self.last_string = try StringBuffer.fromSlice(new_string);
return new_string;
}
var view = try std.unicode.Utf8View.init(self.last_string.constSlice());
var it = view.iterator();
while (it.nextCodepointSlice()) |cp| {
const start = it.i - cp.len;
if (std.mem.startsWith(u8, new_string, self.last_string.constSlice()[start..])) {
const chunk = new_string[self.last_string.len - start ..];
self.last_string = try StringBuffer.fromSlice(new_string);
return chunk;
}
}
return &.{};
}
};
pub const SentencePieceProcessor = opaque {
pub fn fromFile(model: []const u8) !*SentencePieceProcessor {
const sp: *SentencePieceProcessor = @ptrCast(c.SentencePieceProcessor_new());
errdefer sp.deinit();
try assertOk(c.SentencePieceProcessor_Load(@ptrCast(sp), ffi.ZigSlice.from(model)));
return sp;
}
pub fn deinit(self: *SentencePieceProcessor) void {
c.SentencePieceProcessor_delete(@ptrCast(self));
}
pub fn encoder(self: *SentencePieceProcessor) !Encoder {
return Encoder.init(self);
}
pub fn decoder(self: *SentencePieceProcessor) !Decoder {
return try Decoder.init(self);
}
pub fn tokenToId(self: *SentencePieceProcessor, token: []const u8) u32 {
return @intCast(c.SentencePieceProcessor_PieceToId(@ptrCast(self), ffi.ZigSlice.from(token)));
}
};