Radix/zml/tokenizer/sentencepiece/main.zig

290 lines
10 KiB
Zig

const std = @import("std");
const c = @import("c");
const ffi = @import("ffi");
/// Errors surfaced by the SentencePiece bindings.
///
/// There is one entry per non-OK `sentencepiece_util_StatusCode` value handled
/// by `SentencePieceProcessor.assertOk`, listed in the same order as the
/// switch there.
pub const SentencePieceError = error{
    Cancelled,
    Unknown,
    InvalidArgument,
    DeadlineExceeded,
    NotFound,
    AlreadyExists,
    PermissionDenied,
    ResourceExhausted,
    FailedPrecondition,
    Aborted,
    OutOfRange,
    Unimplemented,
    Internal,
    Unavailable,
    DataLoss,
    Unauthenticated,
};
/// Incrementally turns a stream of token ids into text, one token per call.
///
/// The wrapped decoder holds a sliding window of at most `TokensSize` tokens.
/// Every call re-decodes the window and compares the result with a private
/// copy of the previously decoded window to work out which bytes are new.
///
/// `last_tokens` is a slice into `buffer`, so a `DecoderStream` must not be
/// copied or moved once `next` has been called on it.
pub const DecoderStream = struct {
    /// Maximum number of tokens kept in the decoder's window.
    const TokensSize = 4;
    /// Capacity, in bytes, of the copy of the previously decoded window.
    const StringSize = 128;

    decoder: SentencePieceProcessor.Decoder,
    /// Backing storage for `last_tokens`.
    buffer: [StringSize]u8 = undefined,
    /// Text (not token ids) decoded from the window the last time a chunk was
    /// emitted; empty until the first non-empty decode.
    last_tokens: []u8 = &.{},

    /// Wraps `decoder` and pre-reserves room for the token window and the
    /// decoded text. The stream keeps a by-value copy of `decoder`.
    pub fn init(decoder: SentencePieceProcessor.Decoder) DecoderStream {
        var ret: DecoderStream = .{
            .decoder = decoder,
        };
        ret.decoder.reserve_tokens(TokensSize);
        ret.decoder.reserve_string(StringSize);
        return ret;
    }

    /// Feeds `next_token` into the window and returns the text it added.
    ///
    /// Returns `null` when the freshly decoded window cannot be lined up with
    /// the previous one; in that case the remembered text is left untouched
    /// so that a later call gets another chance to match against it.
    ///
    /// The returned slice is a view of the decoder's string storage and is
    /// overwritten by the next call.
    ///
    /// Errors: whatever `Decoder.decode` reports, plus
    /// `error.DecodedTextTooLong` when the decoded window does not fit in the
    /// `StringSize`-byte history buffer (the token has already been pushed
    /// into the window at that point).
    pub fn next(self: *DecoderStream, next_token: u32) !?[]const u8 {
        const window_full = self.decoder.tokens().len >= TokensSize;
        if (window_full) {
            // Slide the window: drop the oldest token, append the new one.
            const tokens = self.decoder.tokens();
            std.mem.copyForwards(u32, tokens[0 .. TokensSize - 1], tokens[1..TokensSize]);
            tokens[TokensSize - 1] = next_token;
        } else {
            self.decoder.append(next_token);
        }

        const new_tokens = try self.decoder.decode();
        // The history copy lives in a fixed buffer; refuse text that would
        // overflow it instead of slicing out of bounds.
        if (new_tokens.len > StringSize) return error.DecodedTextTooLong;

        if (self.last_tokens.len == 0) {
            self.remember(new_tokens);
            return new_tokens;
        }

        // While the window is still filling nothing was dropped, so the new
        // text extends the whole previous text and the overlap may start at
        // offset 0. Once the window slides, the oldest token's text is gone
        // and the overlap can only start after it; skipping offset 0 there
        // also avoids mis-aligning on repeated text.
        const first_offset: usize = if (window_full) 1 else 0;
        for (first_offset..self.last_tokens.len) |i| {
            if (std.mem.startsWith(u8, new_tokens, self.last_tokens[i..])) {
                const toks = new_tokens[self.last_tokens.len - i ..];
                self.remember(new_tokens);
                return toks;
            }
        }
        return null;
    }

    /// Copies `text` into `buffer` and points `last_tokens` at the copy.
    /// Callers must ensure `text.len <= StringSize`.
    fn remember(self: *DecoderStream, text: []const u8) void {
        self.last_tokens = self.buffer[0..text.len];
        @memcpy(self.last_tokens, text);
    }
};
/// Opaque handle to a SentencePiece processor owned by the C bindings in `c`.
///
/// Obtain one with `load` and release it with `deinit`. `encoder` and
/// `decoder` hand out helper objects that own their own C++ containers and
/// must be released with their own `deinit`.
pub const SentencePieceProcessor = opaque {
    /// Encoding helper that owns a `std_vector_int` receiving the token ids.
    pub const Encoder = struct {
        inner: *SentencePieceProcessor,
        vec: *c.std_vector_int,

        fn init(inner: *SentencePieceProcessor) Encoder {
            return .{
                .inner = inner,
                // A null here means the C side failed to allocate; panic
                // explicitly rather than invoking undefined behaviour in
                // unchecked builds.
                .vec = c.std_vector_int_new() orelse @panic("std_vector_int_new returned null"),
            };
        }

        /// Frees the token vector. The processor itself is not touched.
        pub fn deinit(self: *Encoder) void {
            c.std_vector_int_delete(self.vec);
        }

        /// Reserves capacity for `size` token ids in the output vector.
        pub fn reserve(self: *Encoder, size: usize) void {
            c.std_vector_int_reserve(self.vec, size);
        }

        /// Clears the output vector.
        pub fn reset(self: *Encoder) void {
            c.std_vector_int_clear(self.vec);
        }

        /// Encodes `input` into the owned vector and returns a view of it.
        ///
        /// The returned slice is built from the vector's data pointer and
        /// size, so it refers to storage owned by this encoder.
        /// NOTE(review): whether the C call clears previous contents first is
        /// not visible from this file.
        pub fn encode(self: *Encoder, input: []const u8) ![]const u32 {
            try assertOk(c.SentencePieceProcessor_Encode(@ptrCast(self.inner), ffi.ZigSlice.from(input), self.vec));
            return ffi.ZigSlice.to(u32, .{
                .ptr = c.std_vector_int_data(self.vec),
                .len = c.std_vector_int_size(self.vec),
            });
        }
    };

    /// Decoding helper that owns a token vector (input) and a `std_string`
    /// (output).
    pub const Decoder = struct {
        inner: *SentencePieceProcessor,
        vec: *c.std_vector_int,
        str: *c.std_string,

        fn init(inner: *SentencePieceProcessor) Decoder {
            return .{
                .inner = inner,
                // See Encoder.init: fail loudly instead of `unreachable`.
                .vec = c.std_vector_int_new() orelse @panic("std_vector_int_new returned null"),
                .str = c.std_string_new() orelse @panic("std_string_new returned null"),
            };
        }

        /// Appends one token id to the pending input vector.
        pub fn append(self: *Decoder, token: u32) void {
            c.std_vector_int_push_back(self.vec, @intCast(token));
        }

        /// Frees both the token vector and the output string.
        pub fn deinit(self: *Decoder) void {
            c.std_vector_int_delete(self.vec);
            c.std_string_delete(self.str);
        }

        /// Reserves capacity for `size` token ids.
        pub fn reserve_tokens(self: *Decoder, size: usize) void {
            c.std_vector_int_reserve(self.vec, size);
        }

        /// Reserves capacity for `size` bytes of decoded text.
        pub fn reserve_string(self: *Decoder, size: usize) void {
            c.std_string_reserve(self.str, size);
        }

        /// Clears the pending tokens and the decoded text.
        pub fn reset(self: *Decoder) void {
            c.std_vector_int_clear(self.vec);
            c.std_string_clear(self.str);
        }

        /// Decodes the pending tokens into the owned string and returns a
        /// view of it (see `string`).
        pub fn decode(self: *Decoder) ![]const u8 {
            try assertOk(c.SentencePieceProcessor_Decode(@ptrCast(self.inner), self.vec, self.str));
            return self.string();
        }

        /// Returns a view of the owned output string's current contents.
        pub fn string(self: *const Decoder) []const u8 {
            const res = c.std_string_data(self.str);
            return ffi.ZigSlice.to(u8, res);
        }

        /// Returns a mutable view of the pending token ids, reinterpreting
        /// the vector's `int` storage as `u32`.
        pub fn tokens(self: *const Decoder) []u32 {
            const ptr: [*c]u32 = @ptrCast(c.std_vector_int_data(self.vec));
            return ptr[0..c.std_vector_int_size(self.vec)];
        }
    };

    /// Translates a SentencePiece status code into a Zig error (or success).
    ///
    /// Codes that are not one of the known constants are reported as
    /// `error.Unknown` rather than treated as unreachable, because the value
    /// comes from foreign code.
    fn assertOk(code: c.sentencepiece_util_StatusCode) SentencePieceError!void {
        return switch (code) {
            c.sentencepiece_util_StatusCode_kOk => {},
            c.sentencepiece_util_StatusCode_kCancelled => error.Cancelled,
            c.sentencepiece_util_StatusCode_kUnknown => error.Unknown,
            c.sentencepiece_util_StatusCode_kInvalidArgument => error.InvalidArgument,
            c.sentencepiece_util_StatusCode_kDeadlineExceeded => error.DeadlineExceeded,
            c.sentencepiece_util_StatusCode_kNotFound => error.NotFound,
            c.sentencepiece_util_StatusCode_kAlreadyExists => error.AlreadyExists,
            c.sentencepiece_util_StatusCode_kPermissionDenied => error.PermissionDenied,
            c.sentencepiece_util_StatusCode_kResourceExhausted => error.ResourceExhausted,
            c.sentencepiece_util_StatusCode_kFailedPrecondition => error.FailedPrecondition,
            c.sentencepiece_util_StatusCode_kAborted => error.Aborted,
            c.sentencepiece_util_StatusCode_kOutOfRange => error.OutOfRange,
            c.sentencepiece_util_StatusCode_kUnimplemented => error.Unimplemented,
            c.sentencepiece_util_StatusCode_kInternal => error.Internal,
            c.sentencepiece_util_StatusCode_kUnavailable => error.Unavailable,
            c.sentencepiece_util_StatusCode_kDataLoss => error.DataLoss,
            c.sentencepiece_util_StatusCode_kUnauthenticated => error.Unauthenticated,
            else => error.Unknown,
        };
    }

    /// Creates a processor and loads `model` into it.
    ///
    /// `model` is forwarded unchanged to `SentencePieceProcessor_Load`.
    /// On failure the freshly created processor is destroyed before the error
    /// is returned. On success the caller owns the result and must call
    /// `deinit`.
    pub fn load(model: []const u8) !*SentencePieceProcessor {
        const sp: *SentencePieceProcessor = @ptrCast(c.SentencePieceProcessor_new());
        errdefer sp.deinit();
        try sp.load_from(model);
        return sp;
    }

    /// Destroys the processor.
    pub fn deinit(self: *SentencePieceProcessor) void {
        c.SentencePieceProcessor_delete(@ptrCast(self));
    }

    fn load_from(self: *SentencePieceProcessor, model: []const u8) !void {
        try assertOk(c.SentencePieceProcessor_Load(@ptrCast(self), ffi.ZigSlice.from(model)));
    }

    /// Creates a new `Encoder` bound to this processor.
    pub fn encoder(self: *SentencePieceProcessor) Encoder {
        return Encoder.init(self);
    }

    /// Creates a new `Decoder` bound to this processor.
    pub fn decoder(self: *SentencePieceProcessor) Decoder {
        return Decoder.init(self);
    }
};
/// Copies `path` into a fixed-size, NUL-terminated buffer and returns it by
/// value.
///
/// Only the first `path.len` bytes and the terminator written right after
/// them are initialised; the remainder of the array is left undefined.
/// `path` must be at most `std.fs.max_path_bytes` long.
pub fn as_path(path: []const u8) [std.fs.max_path_bytes:0]u8 {
    var buf: [std.fs.max_path_bytes:0]u8 = undefined;
    const used = path.len;
    std.mem.copyForwards(u8, buf[0..used], path);
    buf[used] = 0;
    return buf;
}
/// Demo entry point.
///
/// Loads a SentencePiece model, tokenizes this source file (embedded at
/// compile time) and feeds the tokens one by one through `DecoderStream`,
/// printing every decoded chunk together with the microseconds elapsed since
/// the previous print.
///
/// The model path is taken from the first command-line argument when one is
/// given; otherwise the original development path is used.
pub fn main() !void {
    const default_model_path = "/Users/steeve/Downloads/poolside.sp.pb";

    const allocator = std.heap.page_allocator;
    const args = try std.process.argsAlloc(allocator);
    defer std.process.argsFree(allocator, args);

    var model_path: []const u8 = default_model_path;
    if (args.len > 1) model_path = args[1];

    const sp = try SentencePieceProcessor.load(model_path);
    defer sp.deinit();
    std.debug.print("Loaded model\n", .{});

    var encoder = sp.encoder();
    defer encoder.deinit();
    var decoder = sp.decoder();
    defer decoder.deinit();

    const ss = @embedFile("main.zig");
    const tokens = try encoder.encode(ss);

    // The stream stores a by-value copy of `decoder`, which shares the same
    // underlying C++ containers; they are released once, by the deferred
    // `decoder.deinit()` above.
    var stream = DecoderStream.init(decoder);
    var start = try std.time.Timer.start();
    for (tokens) |token| {
        if (try stream.next(token)) |chunk| {
            std.debug.print("{d}us - {s}\n", .{ start.lap() / std.time.ns_per_us, chunk });
        }
    }
}