66 lines
2.0 KiB
Zig
66 lines
2.0 KiB
Zig
const std = @import("std");
|
|
const log = std.log.scoped(.@"//zml/tokenizer");
|
|
|
|
const asynk = @import("async");
|
|
const stdx = @import("stdx");
|
|
const zml_tokenizer = @import("zml/tokenizer");
|
|
|
|
const Flags = struct {
|
|
tokenizer: []const u8,
|
|
prompt: []const u8,
|
|
expected: []const u8 = "",
|
|
verbose: bool = false,
|
|
};
|
|
|
|
pub fn main() !void {
|
|
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
|
|
}
|
|
|
|
pub fn asyncMain() !void {
|
|
var gpa: std.heap.GeneralPurposeAllocator(.{}) = .{};
|
|
const allocator = gpa.allocator();
|
|
|
|
const args = stdx.flags.parseProcessArgs(Flags);
|
|
|
|
log.info("\tLoading tokenizer from {s}", .{args.tokenizer});
|
|
var tokenizer = try zml_tokenizer.Tokenizer.fromFile(allocator, args.tokenizer);
|
|
log.info("✅\tLoaded tokenizer from {s}", .{args.tokenizer});
|
|
defer tokenizer.deinit();
|
|
|
|
var encoder = try tokenizer.encoder();
|
|
defer encoder.deinit();
|
|
|
|
var decoder = try tokenizer.decoder();
|
|
defer decoder.deinit();
|
|
|
|
const prompt_tok = try encoder.encode(args.prompt);
|
|
|
|
log.info("Input: {s}\nOutput: {d}", .{ args.prompt, prompt_tok });
|
|
|
|
var errors: u8 = 0;
|
|
{
|
|
const reconstructed = try decoder.decode(prompt_tok);
|
|
if (!std.mem.eql(u8, args.prompt, reconstructed)) {
|
|
log.err("Reconstructed string from tokens doesn't match source: {s}", .{reconstructed});
|
|
errors += 1;
|
|
}
|
|
}
|
|
|
|
if (args.expected.len > 0) {
|
|
var expected = try std.ArrayList(u32).initCapacity(allocator, args.prompt.len);
|
|
var it = std.mem.splitSequence(u8, args.expected, ",");
|
|
while (it.next()) |int_token| {
|
|
const tok = try std.fmt.parseInt(u32, int_token, 10);
|
|
try expected.append(tok);
|
|
}
|
|
if (!std.mem.eql(u32, expected.items, prompt_tok)) {
|
|
log.err("Doesn't match expected: {d}", .{expected.items});
|
|
errors += 1;
|
|
}
|
|
}
|
|
|
|
if (errors == 0) log.info("All good !", .{});
|
|
|
|
std.process.exit(errors);
|
|
}
|