Integrate TinyLlama support, restore the homemade tokenizer, and align Zig API naming across stdx and zml tokenizer modules.
This commit is contained in:
parent
b67685b941
commit
d4db5ccc6b
@ -4,6 +4,7 @@ zig_library(
|
|||||||
name = "stdx",
|
name = "stdx",
|
||||||
srcs = [
|
srcs = [
|
||||||
"debug.zig",
|
"debug.zig",
|
||||||
|
"flags.zig",
|
||||||
"io.zig",
|
"io.zig",
|
||||||
"json.zig",
|
"json.zig",
|
||||||
"math.zig",
|
"math.zig",
|
||||||
|
|||||||
582
stdx/flags.zig
Normal file
582
stdx/flags.zig
Normal file
@ -0,0 +1,582 @@
|
|||||||
|
//! From TigerBeetle, under Apache 2.0 attribution license.
|
||||||
|
//! https://github.com/tigerbeetle/tigerbeetle/blob/main/src/flags.zig
|
||||||
|
//!
|
||||||
|
//! The purpose of `flags` is to define standard behavior for parsing CLI arguments and provide
|
||||||
|
//! a specific parsing library, implementing this behavior.
|
||||||
|
//!
|
||||||
|
//! These are TigerBeetle CLI guidelines:
|
||||||
|
//!
|
||||||
|
//! - The main principle is robustness --- make operator errors harder to make.
|
||||||
|
//! - For production usage, avoid defaults.
|
||||||
|
//! - Thoroughly validate options.
|
||||||
|
//! - In particular, check that no options are repeated.
|
||||||
|
//! - Use only long options (`--addresses`).
|
||||||
|
//! - Exception: `-h/--help` is allowed.
|
||||||
|
//! - Use `--key=value` syntax for an option with an argument.
|
||||||
|
//! Don't use `--key value`, as that can be ambiguous (e.g., `--key --verbose`).
|
||||||
|
//! - Use subcommand syntax when appropriate.
|
||||||
|
//! - Use positional arguments when appropriate.
|
||||||
|
//!
|
||||||
|
//! Design choices for this particular `flags` library:
|
||||||
|
//!
|
||||||
|
//! - Be a 80% solution. Parsing arguments is a surprisingly vast topic: auto-generated help,
|
||||||
|
//! bash completions, typo correction. Rather than providing a definitive solution, `flags`
|
||||||
|
//! is just one possible option. It is ok to re-implement arg parsing in a different way, as long
|
||||||
|
//! as the CLI guidelines are observed.
|
||||||
|
//!
|
||||||
|
//! - No auto-generated help. Zig doesn't expose doc comments through `@typeInfo`, so its hard to
|
||||||
|
//! implement auto-help nicely. Additionally, fully hand-crafted `--help` message can be of
|
||||||
|
//! higher quality.
|
||||||
|
//!
|
||||||
|
//! - Fatal errors. It might be "cleaner" to use `try` to propagate the error to the caller, but
|
||||||
|
//! during early CLI parsing, it is much simpler to terminate the process directly and save the
|
||||||
|
//! caller the hassle of propagating errors. The `fatal` function is public, to allow the caller
|
||||||
|
//! to run additional validation or parsing using the same error reporting mechanism.
|
||||||
|
//!
|
||||||
|
//! - Concise DSL. Most cli parsing is done for ad-hoc tools like benchmarking, where the ability to
|
||||||
|
//! quickly add a new argument is valuable. As this is a 80% solution, production code may use
|
||||||
|
//! more verbose approach if it gives better UX.
|
||||||
|
//!
|
||||||
|
//! - Caller manages ArgsIterator. ArgsIterator owns the backing memory of the args, so we let the
|
||||||
|
//! caller to manage the lifetime. The caller should be skipping program name.
|
||||||
|
|
||||||
|
const std = @import("std");
|
||||||
|
const builtin = @import("builtin");
|
||||||
|
const assert = std.debug.assert;
|
||||||
|
|
||||||
|
/// Report a formatted error message on stderr, then terminate the process with exit code 1.
/// Write failures to stderr are deliberately ignored: we are exiting anyway.
pub fn fatal(comptime fmt_string: []const u8, args: anytype) noreturn {
    std.io.getStdErr().writer().print("error: " ++ fmt_string ++ "\n", args) catch {};
    std.posix.exit(1);
}
|
||||||
|
|
||||||
|
/// Parse CLI arguments for subcommands specified as Zig `struct` or `union(enum)`:
///
/// ```
/// const CliArgs = union(enum) {
///    start: struct { addresses: []const u8, replica: u32 },
///    format: struct {
///        verbose: bool = false,
///        positional: struct {
///            path: []const u8,
///        }
///    },
///
///    pub const help =
///        \\ tigerbeetle start --addresses=<addresses> --replica=<replica>
///        \\ tigerbeetle format [--verbose] <path>
/// }
///
/// const cli_args = parse_commands(&args, CliArgs);
/// ```
///
/// The `positional` field is treated specially: it designates positional arguments.
///
/// If a `pub const help` declaration is present, it implements the `-h/--help` argument.
pub fn parse(args: *std.process.ArgIterator, comptime CliArgs: type) CliArgs {
    assert(args.skip()); // The first argument is the executable name; drop it.

    switch (@typeInfo(CliArgs)) {
        .Union => return parse_commands(args, CliArgs),
        .Struct => return parse_flags(args, CliArgs),
        else => unreachable,
    }
}
|
||||||
|
|
||||||
|
/// Parse CLI arguments for the current process.
/// See `stdx.flags.parse` documentation for more.
pub fn parseProcessArgs(comptime CliArgs: type) CliArgs {
    var process_args = std.process.args();
    return parse(&process_args, CliArgs);
}
|
||||||
|
|
||||||
|
/// Dispatch on the first CLI argument to pick the subcommand (a field of the `Commands`
/// union), then parse the remaining arguments as that subcommand's flag struct.
/// Terminates the process via `fatal` on any user error.
fn parse_commands(args: *std.process.ArgIterator, comptime Commands: type) Commands {
    comptime assert(@typeInfo(Commands) == .Union);
    // A "choice" of a single subcommand is no choice at all.
    comptime assert(std.meta.fields(Commands).len >= 2);

    const first_arg = args.next() orelse fatal(
        "subcommand required, expected {s}",
        .{comptime fields_to_comma_list(Commands)},
    );

    // NB: help must be declared as *pub* const to be visible here.
    if (@hasDecl(Commands, "help")) {
        if (std.mem.eql(u8, first_arg, "-h") or std.mem.eql(u8, first_arg, "--help")) {
            std.io.getStdOut().writeAll(Commands.help) catch std.posix.exit(1);
            std.posix.exit(0);
        }
    }

    inline for (comptime std.meta.fields(Commands)) |field| {
        // Subcommand names are used verbatim on the command line, so no underscores.
        comptime assert(std.mem.indexOf(u8, field.name, "_") == null);
        if (std.mem.eql(u8, first_arg, field.name)) {
            return @unionInit(Commands, field.name, parse_flags(args, field.type));
        }
    }
    fatal("unknown subcommand: '{s}'", .{first_arg});
}
|
||||||
|
|
||||||
|
/// Parse the remaining arguments as the flag struct `Flags`.
/// `--key=value` options may appear in any order; positional arguments (the special
/// `positional` sub-struct) must come after all options. Repeated options, unknown
/// options, and missing required options are fatal errors.
fn parse_flags(args: *std.process.ArgIterator, comptime Flags: type) Flags {
    @setEvalBranchQuota(5_000);

    // A `void` subcommand takes no arguments at all.
    if (Flags == void) {
        if (args.next()) |arg| {
            fatal("unexpected argument: '{s}'", .{arg});
        }
        return {};
    }

    assert(@typeInfo(Flags) == .Struct);

    // Named (non-positional) fields, collected at comptime.
    comptime var fields: [std.meta.fields(Flags).len]std.builtin.Type.StructField = undefined;
    comptime var field_count = 0;

    comptime var positional_fields: []const std.builtin.Type.StructField = &.{};

    // Comptime validation pass: split fields into named flags and positional arguments,
    // and check that each field's type and default are sane.
    comptime for (std.meta.fields(Flags)) |field| {
        if (std.mem.eql(u8, field.name, "positional")) {
            assert(@typeInfo(field.type) == .Struct);
            positional_fields = std.meta.fields(field.type);
            var optional_tail = false;
            for (positional_fields) |positional_field| {
                // Once one positional has a default, all following ones must too.
                if (default_value(positional_field) == null) {
                    if (optional_tail) @panic("optional positional arguments must be last");
                } else {
                    optional_tail = true;
                }
                switch (@typeInfo(positional_field.type)) {
                    .Optional => |optional| {
                        // optional flags should have a default
                        assert(default_value(positional_field) != null);
                        assert(default_value(positional_field).? == null);
                        assert_valid_value_type(optional.child);
                    },
                    else => {
                        assert_valid_value_type(positional_field.type);
                    },
                }
            }
        } else {
            fields[field_count] = field;
            field_count += 1;

            switch (@typeInfo(field.type)) {
                .Bool => {
                    // boolean flags should have a default
                    assert(default_value(field) != null);
                    assert(default_value(field).? == false);
                },
                .Optional => |optional| {
                    // optional flags should have a default
                    assert(default_value(field) != null);
                    assert(default_value(field).? == null);

                    assert_valid_value_type(optional.child);
                },
                else => {
                    assert_valid_value_type(field.type);
                },
            }
        }
    };

    var result: Flags = undefined;
    // Per-field occurrence counters, used to detect missing and duplicate arguments.
    // Would use std.enums.EnumFieldStruct(Flags, u32, 0) here but Flags is a Struct not an Enum.
    var counts = comptime blk: {
        var count_fields = std.meta.fields(Flags)[0..std.meta.fields(Flags).len].*;
        for (&count_fields) |*field| {
            field.type = u32;
            field.alignment = @alignOf(u32);
            field.default_value = @ptrCast(&@as(u32, 0));
        }
        break :blk @Type(.{ .Struct = .{
            .layout = .auto,
            .fields = &count_fields,
            .decls = &.{},
            .is_tuple = false,
        } }){};
    };

    // When parsing arguments, we must consider longer arguments first, such that `--foo-bar=92` is
    // not confused for a misspelled `--foo=92`. Using `std.sort` for comptime-only values does not
    // work, so open-code insertion sort, and comptime assert order during the actual parsing.
    comptime {
        for (fields[0..field_count], 0..) |*field_right, i| {
            for (fields[0..i]) |*field_left| {
                if (field_left.name.len < field_right.name.len) {
                    std.mem.swap(std.builtin.Type.StructField, field_left, field_right);
                }
            }
        }
    }

    var parsed_positional = false;
    next_arg: while (args.next()) |arg| {
        // Verify the longest-first ordering established above.
        comptime var field_len_prev = std.math.maxInt(usize);
        inline for (fields[0..field_count]) |field| {
            const flag = comptime flag_name(field);

            comptime assert(field_len_prev >= field.name.len);
            field_len_prev = field.name.len;
            if (std.mem.startsWith(u8, arg, flag)) {
                // Options must precede all positional arguments.
                if (parsed_positional) {
                    fatal("unexpected trailing option: '{s}'", .{arg});
                }

                @field(counts, field.name) += 1;
                const flag_value = parse_flag(field.type, flag, arg);
                @field(result, field.name) = flag_value;
                continue :next_arg;
            }
        }

        if (@hasField(Flags, "positional")) {
            counts.positional += 1;
            switch (counts.positional - 1) {
                inline 0...positional_fields.len - 1 => |positional_index| {
                    const positional_field = positional_fields[positional_index];
                    const flag = comptime flag_name_positional(positional_field);

                    if (arg.len == 0) fatal("{s}: empty argument", .{flag});
                    // Prevent ambiguity between a flag and positional argument value. We could add
                    // support for bare ` -- ` as a disambiguation mechanism once we have a real
                    // use-case.
                    if (arg[0] == '-') fatal("unexpected argument: '{s}'", .{arg});
                    parsed_positional = true;

                    @field(result.positional, positional_field.name) =
                        parse_value(positional_field.type, flag, arg);
                    continue :next_arg;
                },
                else => {}, // Fall-through to the unexpected argument error.
            }
        }

        fatal("unexpected argument: '{s}'", .{arg});
    }

    // Fill in defaults for absent options; reject missing-required and duplicates.
    inline for (fields[0..field_count]) |field| {
        const flag = flag_name(field);
        switch (@field(counts, field.name)) {
            0 => if (default_value(field)) |default| {
                @field(result, field.name) = default;
            } else {
                fatal("{s}: argument is required", .{flag});
            },
            1 => {},
            else => fatal("{s}: duplicate argument", .{flag}),
        }
    }

    // Same treatment for trailing positional arguments that were not supplied.
    if (@hasField(Flags, "positional")) {
        assert(counts.positional <= positional_fields.len);
        inline for (positional_fields, 0..) |positional_field, positional_index| {
            if (positional_index >= counts.positional) {
                const flag = comptime flag_name_positional(positional_field);
                if (default_value(positional_field)) |default| {
                    @field(result.positional, positional_field.name) = default;
                } else {
                    fatal("{s}: argument is required", .{flag});
                }
            }
        }
    }

    return result;
}
|
||||||
|
|
||||||
|
/// Comptime check that `T` is a type this library can parse from a CLI string:
/// a string slice, `ByteSize`, any integer, or an exhaustive enum with at least two variants.
/// Any other type produces a compile error.
fn assert_valid_value_type(comptime T: type) void {
    comptime {
        if (T == []const u8 or T == [:0]const u8 or T == ByteSize) return;

        switch (@typeInfo(T)) {
            .Int => return,
            .Enum => |info| {
                assert(info.is_exhaustive);
                assert(info.fields.len >= 2);
                return;
            },
            else => {},
        }

        @compileLog("unsupported type", T);
        unreachable;
    }
}
|
||||||
|
|
||||||
|
/// Parse a single named option, e.g. `--cluster=123` into the `123` integer.
/// Boolean flags are bare (`--verbose`); any other type requires `--key=value`.
fn parse_flag(comptime T: type, flag: []const u8, arg: [:0]const u8) T {
    assert(flag[0] == '-' and flag[1] == '-');

    if (T == bool) {
        // A boolean flag carries no value; anything after the name is an error.
        if (std.mem.eql(u8, arg, flag)) return true;
        fatal("{s}: argument does not require a value in '{s}'", .{ flag, arg });
    }

    const value_text = parse_flag_split_value(flag, arg);
    assert(value_text.len > 0);
    return parse_value(T, flag, value_text);
}
|
||||||
|
|
||||||
|
/// Splits the value part from a `--arg=value` syntax.
/// Fatal if the separator or the value itself is missing.
fn parse_flag_split_value(flag: []const u8, arg: [:0]const u8) [:0]const u8 {
    assert(flag[0] == '-' and flag[1] == '-');
    assert(std.mem.startsWith(u8, arg, flag));

    const tail = arg[flag.len..];
    if (tail.len == 0) fatal("{s}: expected value separator '='", .{flag});
    if (tail[0] != '=') fatal(
        "{s}: expected value separator '=', but found '{c}' in '{s}'",
        .{ flag, tail[0], arg },
    );
    if (tail.len == 1) fatal("{s}: argument requires a value", .{flag});
    return tail[1..];
}
|
||||||
|
|
||||||
|
/// Convert the textual `value` of a flag or positional argument into type `T`.
/// `?T` parses exactly like its payload type; `bool` is handled earlier in `parse_flag`.
fn parse_value(comptime T: type, flag: []const u8, value: [:0]const u8) T {
    comptime assert(T != bool);
    assert((flag[0] == '-' and flag[1] == '-') or flag[0] == '<');
    assert(value.len > 0);

    // Unwrap optionals so `?u32` and `u32` share one code path.
    const Inner = if (@typeInfo(T) == .Optional) @typeInfo(T).Optional.child else T;

    if (Inner == []const u8 or Inner == [:0]const u8) return value;
    if (Inner == ByteSize) return parse_value_size(flag, value);
    return switch (@typeInfo(Inner)) {
        .Int => parse_value_int(Inner, flag, value),
        .Enum => parse_value_enum(Inner, flag, value),
        else => comptime unreachable,
    };
}
|
||||||
|
|
||||||
|
/// Parse a byte-size value such as `512KiB`, turning each `ByteSize.parse`
/// error into a user-facing fatal message.
fn parse_value_size(flag: []const u8, value: []const u8) ByteSize {
    assert((flag[0] == '-' and flag[1] == '-') or flag[0] == '<');

    return ByteSize.parse(value) catch |err| switch (err) {
        error.ParseOverflow => fatal(
            "{s}: value exceeds 64-bit unsigned integer: '{s}'",
            .{ flag, value },
        ),
        error.InvalidSize => fatal(
            "{s}: expected a size, but found '{s}'",
            .{ flag, value },
        ),
        error.InvalidUnit => fatal(
            "{s}: invalid unit in size '{s}', (needed KiB, MiB, GiB or TiB)",
            .{ flag, value },
        ),
        error.BytesOverflow => fatal(
            "{s}: size in bytes exceeds 64-bit unsigned integer: '{s}'",
            .{ flag, value },
        ),
    };
}
|
||||||
|
|
||||||
|
/// Binary (1024-based) size units; each tag's integer value is its multiplier in bytes.
pub const ByteUnit = enum(u64) {
    bytes = 1,
    kib = 1024,
    mib = 1024 * 1024,
    gib = 1024 * 1024 * 1024,
    tib = 1024 * 1024 * 1024 * 1024,
};
|
||||||
|
|
||||||
|
/// Errors produced by `ByteSize.parse`.
const ByteSizeParseError = error{
    /// The numeric part does not fit in a u64.
    ParseOverflow,
    /// The numeric part is missing or not a number.
    InvalidSize,
    /// The suffix is not one of the recognized units.
    InvalidUnit,
    /// value * unit does not fit in a u64.
    BytesOverflow,
};
|
||||||
|
|
||||||
|
/// A byte quantity as typed on the command line: a numeric value plus an optional
/// binary unit. The original value/unit pair is preserved so error messages and
/// display can reproduce the operator's spelling.
pub const ByteSize = struct {
    value: u64,
    unit: ByteUnit = .bytes,

    /// Parse strings like `128`, `4KiB`, or `10tib` (unit match is case-insensitive).
    /// A bare number means `.bytes`. See `ByteSizeParseError` for failure modes.
    fn parse(value: []const u8) ByteSizeParseError!ByteSize {
        assert(value.len != 0);

        // Split at the first non-digit: leading digits are the amount,
        // the remainder (possibly empty) is the unit suffix.
        const split: struct {
            value_input: []const u8,
            unit_input: []const u8,
        } = split: for (0..value.len) |i| {
            if (!std.ascii.isDigit(value[i])) {
                break :split .{
                    .value_input = value[0..i],
                    .unit_input = value[i..],
                };
            }
        } else {
            // All digits: no unit suffix.
            break :split .{
                .value_input = value,
                .unit_input = "",
            };
        };

        const amount = std.fmt.parseUnsigned(u64, split.value_input, 10) catch |err| {
            switch (err) {
                error.Overflow => {
                    return ByteSizeParseError.ParseOverflow;
                },
                error.InvalidCharacter => {
                    // The only case this can happen is for the empty string
                    return ByteSizeParseError.InvalidSize;
                },
            }
        };

        // Match the suffix against the ByteUnit tag names, ignoring case.
        const unit = if (split.unit_input.len > 0)
            unit: inline for (comptime std.enums.values(ByteUnit)) |tag| {
                if (std.ascii.eqlIgnoreCase(split.unit_input, @tagName(tag))) {
                    break :unit tag;
                }
            } else {
                return ByteSizeParseError.InvalidUnit;
            }
        else
            ByteUnit.bytes;

        // Reject sizes whose byte count would overflow, so `bytes()` below cannot fail.
        _ = std.math.mul(u64, amount, @intFromEnum(unit)) catch {
            return ByteSizeParseError.BytesOverflow;
        };

        return ByteSize{ .value = amount, .unit = unit };
    }

    /// Total size in bytes. Cannot overflow: `parse` already checked the product.
    pub fn bytes(size: *const ByteSize) u64 {
        return std.math.mul(
            u64,
            size.value,
            @intFromEnum(size.unit),
        ) catch unreachable;
    }

    /// Human-readable unit suffix for display (empty string for plain bytes).
    pub fn suffix(size: *const ByteSize) []const u8 {
        return switch (size.unit) {
            .bytes => "",
            .kib => "KiB",
            .mib => "MiB",
            .gib => "GiB",
            .tib => "TiB",
        };
    }
};
|
||||||
|
|
||||||
|
// Table-driven check: each case is (total bytes, input string, parsed value, parsed unit).
test parse_value_size {
    const kib = 1024;
    const mib = kib * 1024;
    const gib = mib * 1024;
    const tib = gib * 1024;

    const cases = .{
        .{ 0, "0", 0, ByteUnit.bytes },
        .{ 1, "1", 1, ByteUnit.bytes },
        .{ 140737488355328, "140737488355328", 140737488355328, ByteUnit.bytes },
        .{ 140737488355328, "128TiB", 128, ByteUnit.tib },
        .{ 1 * tib, "1TiB", 1, ByteUnit.tib },
        .{ 10 * tib, "10tib", 10, ByteUnit.tib },
        .{ 1 * gib, "1GiB", 1, ByteUnit.gib },
        .{ 10 * gib, "10gib", 10, ByteUnit.gib },
        .{ 1 * mib, "1MiB", 1, ByteUnit.mib },
        .{ 10 * mib, "10mib", 10, ByteUnit.mib },
        .{ 1 * kib, "1KiB", 1, ByteUnit.kib },
        .{ 10 * kib, "10kib", 10, ByteUnit.kib },
    };

    inline for (cases) |case| {
        const expected_total = case[0];
        const input = case[1];
        const expected_value = case[2];
        const expected_unit = case[3];

        const parsed = parse_value_size("--size", input);
        assert(expected_total == parsed.bytes());
        assert(expected_value == parsed.value);
        assert(expected_unit == parsed.unit);
    }
}
|
||||||
|
|
||||||
|
/// Parse string value into an integer, providing a nice error message for the user.
fn parse_value_int(comptime T: type, flag: []const u8, value: [:0]const u8) T {
    assert((flag[0] == '-' and flag[1] == '-') or flag[0] == '<');

    return std.fmt.parseInt(T, value, 10) catch |err| switch (err) {
        error.Overflow => fatal(
            "{s}: value exceeds {d}-bit {s} integer: '{s}'",
            .{ flag, @typeInfo(T).Int.bits, @tagName(@typeInfo(T).Int.signedness), value },
        ),
        error.InvalidCharacter => fatal(
            "{s}: expected an integer value, but found '{s}' (invalid digit)",
            .{ flag, value },
        ),
    };
}
|
||||||
|
|
||||||
|
/// Parse a string into a variant of the exhaustive enum `E`, with a helpful
/// "expected one of ..." message on failure.
fn parse_value_enum(comptime E: type, flag: []const u8, value: [:0]const u8) E {
    assert((flag[0] == '-' and flag[1] == '-') or flag[0] == '<');
    comptime assert(@typeInfo(E).Enum.is_exhaustive);

    if (std.meta.stringToEnum(E, value)) |tag| return tag;
    fatal(
        "{s}: expected one of {s}, but found '{s}'",
        .{ flag, comptime fields_to_comma_list(E), value },
    );
}
|
||||||
|
|
||||||
|
/// Render the field names of `E` (a struct, union, or enum with at least two fields)
/// as a human-readable list for error messages, e.g. `'a', 'b', or 'c'`.
fn fields_to_comma_list(comptime E: type) []const u8 {
    comptime {
        const field_count = std.meta.fields(E).len;
        assert(field_count >= 2);

        var result: []const u8 = "";
        for (std.meta.fields(E), 0..) |field, field_index| {
            // Fix: the `else` prong must be the LAST prong of a Zig switch, and the
            // `field_count - 1` case must precede it to be reachable — otherwise the
            // final " or " separator is never selected (and the switch does not compile).
            const separator = switch (field_index) {
                0 => "",
                field_count - 1 => if (field_count == 2) " or " else ", or ",
                else => ", ",
            };
            result = result ++ separator ++ "'" ++ field.name ++ "'";
        }
        return result;
    }
}
|
||||||
|
|
||||||
|
/// Map a struct field name to its command-line spelling: `foo_bar` becomes `--foo-bar`.
/// Must not be called for the special `positional` field.
pub fn flag_name(comptime field: std.builtin.Type.StructField) []const u8 {
    // TODO(Zig): Cleanup when this is fixed after Zig 0.11.
    // Without comptime blk, the compiler thinks the result is a runtime slice returning a UAF.
    return comptime blk: {
        assert(!std.mem.eql(u8, field.name, "positional"));

        var spelled: []const u8 = "--";
        var start = 0;
        // Replace each underscore with a dash, one segment at a time.
        while (std.mem.indexOf(u8, field.name[start..], "_")) |underscore| {
            spelled = spelled ++ field.name[start..][0..underscore] ++ "-";
            start += underscore + 1;
        }
        break :blk spelled ++ field.name[start..];
    };
}
|
||||||
|
|
||||||
|
// A field with no underscores maps straight to `--<name>`.
test flag_name {
    const sample_fields = @typeInfo(struct { statsd: bool }).Struct.fields;
    try std.testing.expectEqualStrings(flag_name(sample_fields[0]), "--statsd");
}
|
||||||
|
|
||||||
|
/// Positional arguments are shown in error messages as `<name>`;
/// underscores are forbidden in positional field names.
fn flag_name_positional(comptime field: std.builtin.Type.StructField) []const u8 {
    comptime assert(!std.mem.containsAtLeast(u8, field.name, 1, "_"));
    return "<" ++ field.name ++ ">";
}
|
||||||
|
|
||||||
|
/// This is essentially `field.default_value`, but with a useful type instead of `?*anyopaque`.
/// Returns `null` when the field has no default.
pub fn default_value(comptime field: std.builtin.Type.StructField) ?field.type {
    return if (field.default_value) |default_opaque|
        // The opaque pointer is known (at comptime) to point at a `field.type` value.
        @as(*const field.type, @ptrCast(@alignCast(default_opaque))).*
    else
        null;
}
|
||||||
@ -1,4 +1,5 @@
|
|||||||
pub const debug = @import("debug.zig");
|
pub const debug = @import("debug.zig");
|
||||||
|
pub const flags = @import("flags.zig");
|
||||||
pub const io = @import("io.zig");
|
pub const io = @import("io.zig");
|
||||||
pub const json = @import("json.zig");
|
pub const json = @import("json.zig");
|
||||||
pub const math = @import("math.zig");
|
pub const math = @import("math.zig");
|
||||||
|
|||||||
@ -10,6 +10,7 @@ const posix = @import("posix.zig");
|
|||||||
pub const gguf = @import("aio/gguf.zig");
|
pub const gguf = @import("aio/gguf.zig");
|
||||||
pub const nemo = @import("aio/nemo.zig");
|
pub const nemo = @import("aio/nemo.zig");
|
||||||
pub const safetensors = @import("aio/safetensors.zig");
|
pub const safetensors = @import("aio/safetensors.zig");
|
||||||
|
pub const tinyllama = @import("aio/tinyllama.zig");
|
||||||
pub const torch = @import("aio/torch.zig");
|
pub const torch = @import("aio/torch.zig");
|
||||||
pub const yaml = @import("aio/yaml.zig");
|
pub const yaml = @import("aio/yaml.zig");
|
||||||
|
|
||||||
@ -35,6 +36,8 @@ pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8)
|
|||||||
try gguf.open(allocator, model_path)
|
try gguf.open(allocator, model_path)
|
||||||
else if (std.mem.endsWith(u8, model_path, ".pt"))
|
else if (std.mem.endsWith(u8, model_path, ".pt"))
|
||||||
try torch.open(allocator, model_path)
|
try torch.open(allocator, model_path)
|
||||||
|
else if (std.mem.endsWith(u8, model_path, ".tinyllama"))
|
||||||
|
try tinyllama.open(allocator, model_path)
|
||||||
else {
|
else {
|
||||||
std.debug.panic("File extension not recognized: {s}", .{model_path});
|
std.debug.panic("File extension not recognized: {s}", .{model_path});
|
||||||
};
|
};
|
||||||
|
|||||||
@ -129,35 +129,3 @@ fn splitBuff(store: *zml.aio.BufferStore, comptime fmt: []const u8, sh: anytype,
|
|||||||
}
|
}
|
||||||
return off;
|
return off;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn loadTokenizer(allocator: std.mem.Allocator, tokenizer_path: []const u8, vocab_size: u32) !zml.tokenizer.Tokenizer {
|
|
||||||
const tokenizer_file = try std.fs.cwd().openFile(tokenizer_path, .{});
|
|
||||||
defer tokenizer_file.close();
|
|
||||||
var tok_reader = std.io.bufferedReader(tokenizer_file.reader());
|
|
||||||
const r = tok_reader.reader();
|
|
||||||
|
|
||||||
const max_token_len = try r.readInt(u32, .little);
|
|
||||||
const special_tokens: zml.tokenizer.Tokenizer.SpecialTokens = .{
|
|
||||||
.unk = 0,
|
|
||||||
.bos = 1,
|
|
||||||
.eos = 2,
|
|
||||||
};
|
|
||||||
var tokenizer = try zml.tokenizer.Tokenizer.init(allocator, vocab_size, max_token_len, null, special_tokens, true);
|
|
||||||
var i: u32 = 0;
|
|
||||||
while (readToken(&tokenizer, &r)) : (i += 1) {
|
|
||||||
// Pass
|
|
||||||
} else |_| {
|
|
||||||
if (i < vocab_size) {
|
|
||||||
zml.log.info("Read {d} words out of {?d}", .{ i, vocab_size });
|
|
||||||
}
|
|
||||||
tokenizer.vocab_size = i;
|
|
||||||
}
|
|
||||||
try tokenizer.rewriteByteFallbackTokens();
|
|
||||||
return tokenizer;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn readToken(tokenizer: *zml.tokenizer.Tokenizer, tok_reader: anytype) !void {
|
|
||||||
const score: f32 = @bitCast(try tok_reader.readInt(u32, .little));
|
|
||||||
const len: usize = @intCast(try tok_reader.readInt(u32, .little));
|
|
||||||
try tokenizer.readTokenInto(score, len, tok_reader);
|
|
||||||
}
|
|
||||||
|
|||||||
@ -16,6 +16,7 @@ zig_library(
|
|||||||
name = "tokenizer",
|
name = "tokenizer",
|
||||||
import_name = "zml/tokenizer",
|
import_name = "zml/tokenizer",
|
||||||
main = "tokenizer.zig",
|
main = "tokenizer.zig",
|
||||||
|
srcs = ["homemade.zig"],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
"//async",
|
"//async",
|
||||||
@ -30,6 +31,8 @@ zig_cc_binary(
|
|||||||
main = "main.zig",
|
main = "main.zig",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//stdx",
|
||||||
|
"//async",
|
||||||
":tokenizer",
|
":tokenizer",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -91,7 +91,7 @@ pub const Decoder = struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
pub const HFTokenizer = opaque {
|
pub const HFTokenizer = opaque {
|
||||||
pub fn from_file(model: []const u8) !*HFTokenizer {
|
pub fn fromFile(model: []const u8) !*HFTokenizer {
|
||||||
return @ptrCast(c.hftokenizers_new(ffi.ZigSlice.from(model)));
|
return @ptrCast(c.hftokenizers_new(ffi.ZigSlice.from(model)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -107,7 +107,7 @@ pub const HFTokenizer = opaque {
|
|||||||
return Decoder.init(self);
|
return Decoder.init(self);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn token_to_id(self: *HFTokenizer, token: []const u8) ?u32 {
|
pub fn tokenToId(self: *HFTokenizer, token: []const u8) ?u32 {
|
||||||
return c.hftokenizers_token_to_id(@ptrCast(self), ffi.ZigSlice.from(token));
|
return c.hftokenizers_token_to_id(@ptrCast(self), ffi.ZigSlice.from(token));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
1215
zml/tokenizer/homemade.zig
Normal file
1215
zml/tokenizer/homemade.zig
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,22 +1,65 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const tokenizer = @import("zml/tokenizer");
|
const log = std.log.scoped(.@"//zml/tokenizer");
|
||||||
|
|
||||||
|
const asynk = @import("async");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
const zml_tokenizer = @import("zml/tokenizer");
|
||||||
|
|
||||||
|
const Flags = struct {
|
||||||
|
tokenizer: []const u8,
|
||||||
|
prompt: []const u8,
|
||||||
|
expected: []const u8 = "",
|
||||||
|
verbose: bool = false,
|
||||||
|
};
|
||||||
|
|
||||||
pub fn main() !void {
|
pub fn main() !void {
|
||||||
const model2 = "/private/var/tmp/_bazel_steeve/a67b810d44f2a673ebbd5bab86ccd5cc/external/zml~~huggingface~Meta-Llama-3.1-8B-Instruct/tokenizer.json";
|
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
|
||||||
|
}
|
||||||
|
|
||||||
var sp = try tokenizer.Tokenizer.from_file(std.heap.c_allocator, model2);
|
pub fn asyncMain() !void {
|
||||||
defer sp.deinit();
|
var gpa: std.heap.GeneralPurposeAllocator(.{}) = .{};
|
||||||
|
const allocator = gpa.allocator();
|
||||||
|
|
||||||
std.debug.print("Loaded model\n", .{});
|
const args = stdx.flags.parseProcessArgs(Flags);
|
||||||
|
|
||||||
var encoder = try sp.encoder();
|
log.info("\tLoading tokenizer from {s}", .{args.tokenizer});
|
||||||
|
var tokenizer = try zml_tokenizer.Tokenizer.fromFile(allocator, args.tokenizer);
|
||||||
|
log.info("✅\tLoaded tokenizer from {s}", .{args.tokenizer});
|
||||||
|
defer tokenizer.deinit();
|
||||||
|
|
||||||
|
var encoder = try tokenizer.encoder();
|
||||||
defer encoder.deinit();
|
defer encoder.deinit();
|
||||||
|
|
||||||
var decoder = try sp.decoder();
|
var decoder = try tokenizer.decoder();
|
||||||
defer decoder.deinit();
|
defer decoder.deinit();
|
||||||
|
|
||||||
const ids = try encoder.encode("Hello, world! plane pouet plane");
|
const prompt_tok = try encoder.encode(args.prompt);
|
||||||
const decoded = try decoder.decode(ids);
|
|
||||||
|
|
||||||
std.debug.print("{d}\n{s}\n", .{ ids, decoded });
|
log.info("Input: {s}\nOutput: {d}", .{ args.prompt, prompt_tok });
|
||||||
|
|
||||||
|
var errors: u8 = 0;
|
||||||
|
{
|
||||||
|
const reconstructed = try decoder.decode(prompt_tok);
|
||||||
|
if (!std.mem.eql(u8, args.prompt, reconstructed)) {
|
||||||
|
log.err("Reconstructed string from tokens doesn't match source: {s}", .{reconstructed});
|
||||||
|
errors += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (args.expected.len > 0) {
|
||||||
|
var expected = try std.ArrayList(u32).initCapacity(allocator, args.prompt.len);
|
||||||
|
var it = std.mem.splitSequence(u8, args.expected, ",");
|
||||||
|
while (it.next()) |int_token| {
|
||||||
|
const tok = try std.fmt.parseInt(u32, int_token, 10);
|
||||||
|
try expected.append(tok);
|
||||||
|
}
|
||||||
|
if (!std.mem.eql(u32, expected.items, prompt_tok)) {
|
||||||
|
log.err("Doesn't match expected: {d}", .{expected.items});
|
||||||
|
errors += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (errors == 0) log.info("All good !", .{});
|
||||||
|
|
||||||
|
std.process.exit(errors);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -164,7 +164,7 @@ pub const Decoder = struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
pub const SentencePieceProcessor = opaque {
|
pub const SentencePieceProcessor = opaque {
|
||||||
pub fn from_file(model: []const u8) !*SentencePieceProcessor {
|
pub fn fromFile(model: []const u8) !*SentencePieceProcessor {
|
||||||
const sp: *SentencePieceProcessor = @ptrCast(c.SentencePieceProcessor_new());
|
const sp: *SentencePieceProcessor = @ptrCast(c.SentencePieceProcessor_new());
|
||||||
errdefer sp.deinit();
|
errdefer sp.deinit();
|
||||||
try assertOk(c.SentencePieceProcessor_Load(@ptrCast(sp), ffi.ZigSlice.from(model)));
|
try assertOk(c.SentencePieceProcessor_Load(@ptrCast(sp), ffi.ZigSlice.from(model)));
|
||||||
@ -183,7 +183,7 @@ pub const SentencePieceProcessor = opaque {
|
|||||||
return try Decoder.init(self);
|
return try Decoder.init(self);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn token_to_id(self: *SentencePieceProcessor, token: []const u8) u32 {
|
pub fn tokenToId(self: *SentencePieceProcessor, token: []const u8) u32 {
|
||||||
return @intCast(c.SentencePieceProcessor_PieceToId(@ptrCast(self), ffi.ZigSlice.from(token)));
|
return @intCast(c.SentencePieceProcessor_PieceToId(@ptrCast(self), ffi.ZigSlice.from(token)));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -3,15 +3,19 @@ const hftokenizers = @import("hftokenizers");
|
|||||||
const sentencepiece = @import("sentencepiece");
|
const sentencepiece = @import("sentencepiece");
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
|
|
||||||
|
const homemade = @import("homemade.zig");
|
||||||
|
|
||||||
const Tokenizers = enum {
|
const Tokenizers = enum {
|
||||||
hftokenizers,
|
hftokenizers,
|
||||||
sentencepiece,
|
sentencepiece,
|
||||||
|
homemade,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Tokenizer = union(Tokenizers) {
|
pub const Tokenizer = union(Tokenizers) {
|
||||||
pub const Encoder = union(Tokenizers) {
|
pub const Encoder = union(Tokenizers) {
|
||||||
hftokenizers: hftokenizers.Encoder,
|
hftokenizers: hftokenizers.Encoder,
|
||||||
sentencepiece: sentencepiece.Encoder,
|
sentencepiece: sentencepiece.Encoder,
|
||||||
|
homemade: homemade.Encoder,
|
||||||
|
|
||||||
pub fn deinit(self: *Encoder) void {
|
pub fn deinit(self: *Encoder) void {
|
||||||
switch (self.*) {
|
switch (self.*) {
|
||||||
@ -41,6 +45,7 @@ pub const Tokenizer = union(Tokenizers) {
|
|||||||
pub const Decoder = union(Tokenizers) {
|
pub const Decoder = union(Tokenizers) {
|
||||||
hftokenizers: hftokenizers.Decoder,
|
hftokenizers: hftokenizers.Decoder,
|
||||||
sentencepiece: sentencepiece.Decoder,
|
sentencepiece: sentencepiece.Decoder,
|
||||||
|
homemade: homemade.Decoder,
|
||||||
|
|
||||||
pub fn deinit(self: *Decoder) void {
|
pub fn deinit(self: *Decoder) void {
|
||||||
switch (self.*) {
|
switch (self.*) {
|
||||||
@ -81,14 +86,22 @@ pub const Tokenizer = union(Tokenizers) {
|
|||||||
|
|
||||||
hftokenizers: *hftokenizers.HFTokenizer,
|
hftokenizers: *hftokenizers.HFTokenizer,
|
||||||
sentencepiece: *sentencepiece.SentencePieceProcessor,
|
sentencepiece: *sentencepiece.SentencePieceProcessor,
|
||||||
|
homemade: *homemade.Tokenizer,
|
||||||
|
|
||||||
pub fn from_file(_: std.mem.Allocator, model: []const u8) !Tokenizer {
|
pub fn fromFile(allocator: std.mem.Allocator, model: []const u8) !Tokenizer {
|
||||||
if (std.mem.endsWith(u8, model, ".pb")) {
|
if (std.mem.endsWith(u8, model, ".pb")) {
|
||||||
return .{ .sentencepiece = try asynk.callBlocking(sentencepiece.SentencePieceProcessor.from_file, .{model}) };
|
return .{ .sentencepiece = try asynk.callBlocking(sentencepiece.SentencePieceProcessor.fromFile, .{model}) };
|
||||||
}
|
}
|
||||||
if (std.mem.endsWith(u8, model, ".json")) {
|
if (std.mem.endsWith(u8, model, ".json")) {
|
||||||
return .{ .hftokenizers = try asynk.callBlocking(hftokenizers.HFTokenizer.from_file, .{model}) };
|
return .{ .hftokenizers = try asynk.callBlocking(hftokenizers.HFTokenizer.fromFile, .{model}) };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (std.mem.endsWith(u8, model, ".tinyllama")) {
|
||||||
|
const tokenizer = try allocator.create(homemade.Tokenizer);
|
||||||
|
tokenizer.* = try asynk.callBlocking(homemade.fromTinyLlamaFile, .{ allocator, model, 32000 });
|
||||||
|
return .{ .homemade = tokenizer };
|
||||||
|
}
|
||||||
|
|
||||||
return error.InvalidArgument;
|
return error.InvalidArgument;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -110,9 +123,9 @@ pub const Tokenizer = union(Tokenizers) {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn token_to_id(self: Tokenizer, token: []const u8) ?u32 {
|
pub fn tokenToId(self: Tokenizer, token: []const u8) ?u32 {
|
||||||
return switch (self) {
|
return switch (self) {
|
||||||
inline else => |v| v.token_to_id(token),
|
inline else => |v| v.tokenToId(token),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user