Add HuggingFace tokenizer bindings and SentencePiece integration; update BUILD files, async utilities, and FFI modules to support the new tokenizers.

This commit is contained in:
Tarry Singh 2024-02-28 15:47:37 +00:00
parent 5048e7dc89
commit 959bc48c42
45 changed files with 3751 additions and 183 deletions
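For orientation, a minimal sketch of the new HuggingFace tokenizer API added below (names taken from zml/tokenizer/hftokenizers/hftokenizers.zig in this diff; the tokenizer.json path is a placeholder):

const std = @import("std");
const HFTokenizer = @import("hftokenizers").HFTokenizer;

pub fn main() !void {
    const tok = try HFTokenizer.from_file("tokenizer.json"); // placeholder path
    defer tok.deinit();
    var enc = try tok.encoder();
    defer enc.deinit();
    var dec = try tok.decoder();
    defer dec.deinit();
    const ids = try enc.encode("Hello, world!");
    std.debug.print("{d}\n{s}\n", .{ ids, try dec.decode(ids) });
}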

View File

@ -7,7 +7,10 @@ new_git_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:git.bzl"
bazel_dep(name = "bazel_skylib", version = "1.7.1")
bazel_dep(name = "hermetic_cc_toolchain", version = "3.1.1")
bazel_dep(name = "patchelf", version = "0.18.0")
bazel_dep(name = "pcre2", version = "10.43")
bazel_dep(name = "abseil-cpp", version = "20240722.0.bcr.2")
bazel_dep(name = "platforms", version = "0.0.10")
bazel_dep(name = "protobuf", version = "29.2")
bazel_dep(name = "rules_cc", version = "0.0.17")
bazel_dep(name = "rules_pkg", version = "1.0.1")
bazel_dep(name = "rules_proto", version = "7.1.0")
@ -114,3 +117,33 @@ apt.install(
manifest = "//runtimes/neuron:packages.yaml",
)
use_repo(apt, "neuron_bookworm")
non_module_deps = use_extension("//:third_party/non_module_deps.bzl", "non_module_deps")
use_repo(non_module_deps, "com_google_sentencepiece", "org_swig_swig")
bazel_dep(name = "rules_rust", version = "0.57.0")
rust = use_extension("@rules_rust//rust:extensions.bzl", "rust")
rust.toolchain(
edition = "2021",
versions = ["1.84.0"],
extra_target_triples = [
"aarch64-apple-darwin",
"aarch64-unknown-linux-gnu",
"x86_64-unknown-linux-gnu",
],
)
use_repo(rust, "rust_toolchains")
register_toolchains("@rust_toolchains//:all")
crate = use_extension("@rules_rust//crate_universe:extensions.bzl", "crate")
crate.from_cargo(
name = "crates",
cargo_lockfile = "//zml/tokenizer/hftokenizers:Cargo.lock",
manifests = ["//zml/tokenizer/hftokenizers:Cargo.toml"],
supported_platform_triples = [
"aarch64-apple-darwin",
"aarch64-unknown-linux-gnu",
"x86_64-unknown-linux-gnu",
],
)
use_repo(crate, "crates")

File diff suppressed because one or more lines are too long

View File

@ -485,7 +485,7 @@ pub fn Channel(comptime T: type, capacity: usize) type {
}
pub fn send(self: *Self, val: T) void {
self.inner.send(val) catch unreachable;
self.inner.send(val);
}
pub fn recv(self: *Self) ?T {

View File

@ -134,8 +134,8 @@ const Coro = struct {
return initFromStack(func, stack_, storage);
}
pub fn deinit(self: Coro) void {
_ = self; // autofix
pub fn deinit(_: Coro) void {
// empty
}
fn initFromStack(func: *const fn () void, stack_: stack.Stack, storage: ?*anyopaque) !Frame {
@ -423,8 +423,7 @@ const CoroId = struct {
const StackOverflow = struct {
const magic_number: usize = 0x5E574D6D;
fn check(coro: Frame) !void {
_ = coro; // autofix
fn check(_: Frame) !void {
// const stack = coro.stack.ptr;
// const sp = coro.impl.stack_pointer;
// const magic_number_ptr: *usize = @ptrCast(stack);
@ -435,8 +434,7 @@ const StackOverflow = struct {
// }
}
fn setMagicNumber(stack_: stack.Stack) !void {
_ = stack_; // autofix
fn setMagicNumber(_: stack.Stack) !void {
// if (stack.len <= @sizeOf(usize)) {
// return Error.StackTooSmall;
// }

View File

@ -69,8 +69,8 @@ pub const StackAllocator = struct {
return .{ .allocator = allocator };
}
pub fn deinit(self: *StackAllocator) void {
_ = self; // autofix
pub fn deinit(_: *StackAllocator) void {
// empty
}
pub fn create(self: *StackAllocator) !Stack {

149
bazel/swig.bzl Normal file
View File

@ -0,0 +1,149 @@
load("@rules_cc//cc:action_names.bzl", "C_COMPILE_ACTION_NAME")
load("@rules_cc//cc:find_cc_toolchain.bzl", "find_cc_toolchain", "use_cc_toolchain")
def _swig_cc_library_impl(ctx):
args = ctx.actions.args()
if ctx.attr.cpp:
args.add("-c++")
args.add("-std=c++17")
args.add("-c")
args.add("-O")
args.add("-module", ctx.attr.module)
args.add_joined("-features", ctx.attr.enabled_features, join_with = ",")
if ctx.attr.defines:
args.add_all(ctx.attr.defines, format_each = "-D%s")
cc_toolchain = find_cc_toolchain(ctx)
if (cc_toolchain):
args.add_all(cc_toolchain.built_in_include_directories, format_each = "-I%s")
feature_configuration = cc_common.configure_features(
ctx = ctx,
cc_toolchain = cc_toolchain,
requested_features = ctx.features,
unsupported_features = ctx.disabled_features,
)
c_compile_variables = cc_common.create_compile_variables(
feature_configuration = feature_configuration,
cc_toolchain = cc_toolchain,
user_compile_flags = ctx.fragments.cpp.copts + ctx.fragments.cpp.conlyopts,
)
cc_compile_command_line = cc_common.get_memory_inefficient_command_line(
feature_configuration = feature_configuration,
action_name = C_COMPILE_ACTION_NAME,
variables = c_compile_variables,
)
for arg in cc_compile_command_line:
if (arg.startswith("-I") or arg.startswith("-D")):
args.add(arg)
cc_info = cc_common.merge_cc_infos(direct_cc_infos = [dep[CcInfo] for dep in ctx.attr.deps])
args.add_all(cc_info.compilation_context.defines, format_each = "-D%s")
args.add_all(cc_info.compilation_context.local_defines, format_each = "-D%s")
args.add_all(cc_info.compilation_context.framework_includes, format_each = "-I%s")
args.add_all(cc_info.compilation_context.includes, format_each = "-I%s")
args.add_all(cc_info.compilation_context.quote_includes, format_each = "-I%s")
args.add_all(cc_info.compilation_context.system_includes, format_each = "-I%s")
output_cpp = ctx.actions.declare_file("%s.cpp" % ctx.attr.module)
output_h = ctx.actions.declare_file("%s.h" % ctx.attr.module)
args.add("-outdir", output_h.dirname)
outputs = [
output_cpp,
output_h,
]
args.add("-o", output_cpp)
args.add("-w-305")
args.add(ctx.file.interface)
inputs = depset(ctx.files.srcs, transitive = [
ctx.attr.interface.files,
cc_info.compilation_context.headers,
ctx.attr._swig_lib.files,
])
ctx.actions.run(
inputs = inputs,
outputs = outputs,
executable = ctx.executable._swig,
arguments = [args],
env = {
"SWIG_LIB": ctx.files._swig_lib[0].dirname,
},
mnemonic = "SwigC",
)
return [
DefaultInfo(
files = depset(outputs),
),
OutputGroupInfo(
hdrs = depset([output_h]),
srcs = depset([output_cpp]),
),
]
_swig_cc_library = rule(
_swig_cc_library_impl,
attrs = {
"interface": attr.label(
mandatory = True,
allow_single_file = True,
),
"srcs": attr.label_list(
allow_files = True,
),
"deps": attr.label_list(
providers = [CcInfo],
),
"defines": attr.string_list(),
"enabled_features": attr.string_list(),
"module": attr.string(
mandatory = True,
),
"cpp": attr.bool(
default = True,
),
"intgosize": attr.int(
default = 64,
),
"_swig": attr.label(
default = "@org_swig_swig//:swig",
cfg = "exec",
executable = True,
),
"_swig_lib": attr.label(
default = "@org_swig_swig//:lib",
allow_files = True,
),
},
toolchains = use_cc_toolchain(),
fragments = ["cpp"],
)
def swig_cc_library(name, deps = [], **kwargs):
_swig_cc_library(
name = "{}.swig".format(name),
deps = deps,
**kwargs
)
native.filegroup(
name = "{}.hdrs".format(name),
srcs = [":{}.swig".format(name)],
output_group = "hdrs",
)
native.filegroup(
name = "{}.srcs".format(name),
srcs = [":{}.swig".format(name)],
output_group = "srcs",
)
native.cc_library(
name = name,
hdrs = [":{}.hdrs".format(name)],
srcs = [":{}.srcs".format(name)],
deps = deps,
)

22
ffi/BUILD.bazel Normal file
View File

@ -0,0 +1,22 @@
load("@rules_zig//zig:defs.bzl", "zig_library")
cc_library(
name = "cc",
hdrs = [
"zig_allocator.h",
"zig_slice.h",
],
visibility = ["//visibility:public"],
)
zig_library(
name = "zig",
srcs = [
"zig_allocator.zig",
"zig_slice.zig",
],
import_name = "ffi",
main = "ffi.zig",
visibility = ["//visibility:public"],
deps = [":cc"],
)

12
ffi/ffi.zig Normal file
View File

@ -0,0 +1,12 @@
const std = @import("std");
const c = @import("c");
pub const ZigAllocator = @import("zig_allocator.zig").ZigAllocator;
pub const ZigSlice = @import("zig_slice.zig").ZigSlice;
pub fn as_path(path: []const u8) [std.fs.max_path_bytes:0]u8 {
var result: [std.fs.max_path_bytes:0]u8 = undefined;
@memcpy(result[0..path.len], path);
result[path.len] = 0;
return result;
}
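as_path copies a slice into a fixed, NUL-terminated buffer so it can be handed to C APIs expecting a path; a usage sketch (assuming path.len <= std.fs.max_path_bytes, which the function itself does not check):

const ffi = @import("ffi");

fn cPath(path: []const u8) void {
    var buf = ffi.as_path(path);
    // *[N:0]u8 coerces to a NUL-terminated many-pointer for C.
    const c_str: [*:0]const u8 = &buf;
    _ = c_str; // pass to a C function taking a const char *path
}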

22
ffi/zig_allocator.h Normal file
View File

@ -0,0 +1,22 @@
#pragma once
#include <stdlib.h>
typedef struct
{
const void *ctx;
void *(*alloc)(const void *ctx, size_t elem, size_t nelems, size_t alignment);
void (*free)(const void *ctx, void *ptr, size_t elem, size_t nelems, size_t alignment);
#ifdef __cplusplus
template <typename T> [[nodiscard]] T *allocate(size_t n)
{
return static_cast<T *>(this->alloc(this->ctx, sizeof(T), n, _Alignof(T)));
}
template <typename T> [[nodiscard]] void deallocate(T *p, size_t n)
{
this->free(this->ctx, static_cast<void *>(p), sizeof(T), n, _Alignof(T));
}
#endif
} zig_allocator;

25
ffi/zig_allocator.zig Normal file
View File

@ -0,0 +1,25 @@
const std = @import("std");
const c = @import("c");
pub const ZigAllocator = struct {
pub inline fn from(allocator: std.mem.Allocator) c.zig_allocator {
// NB: `from` is inline, so `&allocator` below points at the caller's
// in-scope copy; the returned zig_allocator must not outlive that scope.
return .{
.ctx = @ptrCast(@alignCast(&allocator)),
.alloc = &alloc,
.free = &free,
};
}
pub fn alloc(ctx: ?*const anyopaque, elem: usize, nelems: usize, alignment: usize) callconv(.C) ?*anyopaque {
const self: *const std.mem.Allocator = @ptrCast(@alignCast(ctx));
const ret = self.rawAlloc(elem * nelems, std.math.log2_int(usize, alignment), @returnAddress()) orelse return null;
return @ptrCast(ret);
}
pub fn free(ctx: ?*const anyopaque, ptr: ?*anyopaque, elem: usize, nelems: usize, alignment: usize) callconv(.C) void {
const self: *const std.mem.Allocator = @ptrCast(@alignCast(ctx));
const memory: [*c]u8 = @ptrCast(ptr);
const size = elem * nelems;
self.rawFree(memory[0..size], std.math.log2_int(usize, alignment), @returnAddress());
}
};
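A hedged sketch of how this bridge might be used (the C consumer is hypothetical; only ZigAllocator.from and c.zig_allocator come from this commit):

const std = @import("std");
const c = @import("c");
const ffi = @import("ffi");

fn withCAllocator(allocator: std.mem.Allocator) void {
    // Because `from` is inline, the pointer it stores refers to this frame's
    // copy of `allocator`: keep the zig_allocator inside this scope.
    const zalloc: c.zig_allocator = ffi.ZigAllocator.from(allocator);
    _ = zalloc; // e.g. pass to a C/C++ API that allocates via the callbacks
}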

9
ffi/zig_slice.h Normal file
View File

@ -0,0 +1,9 @@
#pragma once
#include <stdlib.h>
typedef struct
{
void *ptr;
size_t len;
} zig_slice;

15
ffi/zig_slice.zig Normal file
View File

@ -0,0 +1,15 @@
const std = @import("std");
const c = @import("c");
pub const ZigSlice = struct {
pub fn from(slice: anytype) c.zig_slice {
return .{
.ptr = @ptrCast(@constCast(slice.ptr)),
.len = slice.len,
};
}
pub fn to(comptime T: type, slice: c.zig_slice) []T {
return @as([*c]T, @ptrCast(@alignCast(slice.ptr)))[0..slice.len];
}
};
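A small round-trip sketch (a test assuming the ffi module and the c import are wired up as in ffi/BUILD.bazel above): ZigSlice.from erases a typed slice into the C-ABI {ptr, len} pair, and ZigSlice.to recovers it.

const std = @import("std");
const c = @import("c");
const ffi = @import("ffi");

test "zig_slice round trip" {
    const words: []const u32 = &[_]u32{ 1, 2, 3 };
    const s: c.zig_slice = ffi.ZigSlice.from(words);
    // The caller is responsible for remembering the element type.
    const back = ffi.ZigSlice.to(u32, s);
    try std.testing.expectEqualSlices(u32, words, back);
}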

View File

@ -5,10 +5,12 @@ zig_library(
srcs = [
"debug.zig",
"io.zig",
"json.zig",
"math.zig",
"meta.zig",
"queue.zig",
"signature.zig",
"time.zig",
],
main = "stdx.zig",
visibility = ["//visibility:public"],

72
stdx/json.zig Normal file
View File

@ -0,0 +1,72 @@
pub const std = @import("std");
pub fn Union(comptime T: type) type {
return struct {
const Self = @This();
value: T,
pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) !Self {
return jsonParseFromValue(
allocator,
try std.json.innerParse(
std.json.Value,
allocator,
source,
options,
),
options,
);
}
pub fn jsonParseFromValue(allocator: std.mem.Allocator, source: std.json.Value, options: std.json.ParseOptions) !Self {
inline for (std.meta.fields(T)) |field| {
switch (field.type) {
bool => if (source == .bool) return .{ .value = @unionInit(T, field.name, source.bool) },
[]const u8 => switch (source) {
.string => |v| return .{ .value = @unionInit(T, field.name, v) },
.number_string => |v| return .{ .value = @unionInit(T, field.name, v) },
else => {},
},
else => switch (@typeInfo(field.type)) {
.Int => if (source == .integer) return .{ .value = @unionInit(T, field.name, @intCast(source.integer)) },
.Float => if (source == .float) return .{ .value = @unionInit(T, field.name, @floatCast(source.float)) },
.Struct => if (source == .object) return .{ .value = @unionInit(T, field.name, try std.json.innerParseFromValue(field.type, allocator, source.object, options)) },
inline else => switch (source) {
.number_string, .array => return .{ .value = @unionInit(T, field.name, try std.json.innerParseFromValue(field.type, allocator, source, options)) },
else => {},
},
},
}
}
return error.UnexpectedToken;
}
};
}
pub fn NeverNull(comptime T: type, comptime default_value: T) type {
return struct {
const Self = @This();
value: T = default_value,
pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) !Self {
return .{ .value = (try std.json.innerParse(?T, allocator, source, options)) orelse default_value };
}
pub fn jsonParseFromValue(allocator: std.mem.Allocator, source: std.json.Value, options: std.json.ParseOptions) !Self {
return .{ .value = (try std.json.innerParseFromValue(?T, allocator, source, options)) orelse default_value };
}
};
}
pub fn fillDefaultStructValues(comptime T: type, r: *T) !void {
inline for (@typeInfo(T).Struct.fields) |field| {
if (field.default_value) |default_ptr| {
if (@field(r, field.name) == null) {
const default = @as(*align(1) const field.type, @ptrCast(default_ptr)).*;
@field(r, field.name) = default;
}
}
}
}
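A usage sketch for Union: tokenizer configs often store a field as either a bare value or a string, and the wrapper tries each union field against the parsed std.json.Value. The Padding/Config shapes here are illustrative, not from the commit:

const std = @import("std");
const stdx = @import("stdx");

const Padding = union(enum) {
    enabled: bool,
    strategy: []const u8,
};

const Config = struct {
    padding: stdx.json.Union(Padding),
};

test "parse a union-typed field" {
    const parsed = try std.json.parseFromSlice(Config, std.testing.allocator,
        \\{"padding": "longest"}
    , .{});
    defer parsed.deinit();
    // "longest" is a JSON string, so it lands in the []const u8 field.
    try std.testing.expectEqualStrings("longest", parsed.value.padding.value.strategy);
}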

View File

@ -1,5 +1,7 @@
pub const debug = @import("debug.zig");
pub const io = @import("io.zig");
pub const json = @import("json.zig");
pub const math = @import("math.zig");
pub const meta = @import("meta.zig");
pub const queue = @import("queue.zig");
pub const time = @import("time.zig");

34
stdx/time.zig Normal file
View File

@ -0,0 +1,34 @@
const std = @import("std");
pub const Duration = struct {
ns: u64,
pub fn format(
self: Duration,
comptime fmt: []const u8,
options: std.fmt.FormatOptions,
writer: anytype,
) @TypeOf(writer).Error!void {
return try std.fmt.fmtDuration(self.ns).format(fmt, options, writer);
}
};
pub const Timer = struct {
inner: std.time.Timer,
pub fn start() !Timer {
return .{ .inner = try std.time.Timer.start() };
}
pub fn lap(self: *Timer) Duration {
return .{ .ns = self.inner.lap() };
}
pub fn read(self: *Timer) Duration {
return .{ .ns = self.inner.read() };
}
pub fn reset(self: *Timer) void {
self.inner.reset();
}
};
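A usage sketch; Duration's format() defers to std.fmt.fmtDuration, so a plain {} prints human-readable units:

const std = @import("std");
const stdx = @import("stdx");

pub fn main() !void {
    var timer = try stdx.time.Timer.start();
    std.time.sleep(1500 * std.time.ns_per_us);
    std.debug.print("elapsed: {}\n", .{timer.read()}); // e.g. "elapsed: 1.5ms"
}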

View File

@ -0,0 +1,38 @@
load(":fwd.bzl", "include_fwd")
include_fwd(
name = "absl_hdrs",
includes = [
"absl/container/flat_hash_map.h",
"absl/container/flat_hash_set.h",
"absl/flags/flag.h",
"absl/flags/parse.h",
"absl/flags/usage.h",
"absl/strings/match.h",
"absl/strings/numbers.h",
"absl/strings/str_cat.h",
"absl/strings/str_format.h",
"absl/strings/str_join.h",
"absl/strings/str_replace.h",
"absl/strings/str_split.h",
"absl/strings/string_view.h",
"absl/strings/strip.h",
],
)
cc_library(
name = "absl",
hdrs = [":absl_hdrs"],
include_prefix = "third_party",
includes = ["."],
visibility = ["//visibility:public"],
deps = [
"@abseil-cpp//absl/container:flat_hash_map",
"@abseil-cpp//absl/container:flat_hash_set",
"@abseil-cpp//absl/flags:flag",
"@abseil-cpp//absl/flags:parse",
"@abseil-cpp//absl/flags:usage",
"@abseil-cpp//absl/strings",
"@abseil-cpp//absl/strings:string_view",
],
)

View File

@ -0,0 +1,14 @@
def _include_fwd_impl(ctx):
files = []
for include in ctx.attr.includes:
f = ctx.actions.declare_file(include)
ctx.actions.write(f, '#include "{}"'.format(include))
files.append(f)
return [DefaultInfo(files = depset(files))]
include_fwd = rule(
implementation = _include_fwd_impl,
attrs = {
"includes": attr.string_list(),
},
)

View File

@ -0,0 +1,9 @@
load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository")
def repo():
new_git_repository(
name = "com_google_sentencepiece",
remote = "https://github.com/google/sentencepiece.git",
commit = "d8f741853847553169444afc12c00f4bbff3e9ce",
build_file = "//third_party/com_google_sentencepiece:sentencepiece.bazel",
)

View File

@ -0,0 +1,95 @@
load("@bazel_skylib//rules:copy_file.bzl", "copy_file")
load("@protobuf//bazel:cc_proto_library.bzl", "cc_proto_library")
load("@protobuf//bazel:proto_library.bzl", "proto_library")
package(
default_visibility = ["//visibility:public"],
features = [
"layering_check",
"parse_headers",
],
)
licenses(["notice"]) # Apache 2, BSD, MIT
proto_library(
name = "sentencepiece_proto",
srcs = ["src/sentencepiece.proto"],
)
cc_proto_library(
name = "sentencepiece_cc_proto",
deps = [":sentencepiece_proto"],
)
proto_library(
name = "sentencepiece_model_proto",
srcs = ["src/sentencepiece_model.proto"],
)
cc_proto_library(
name = "sentencepiece_model_cc_proto",
deps = [":sentencepiece_model_proto"],
)
copy_file(
name = "config_h",
src = "config.h.in",
out = "config.h",
allow_symlink = True,
)
cc_library(
name = "darts_clone",
hdrs = glob([
"third_party/darts_clone/*.h",
]),
includes = ["."],
)
cc_library(
name = "sentencepiece_processor",
srcs = [
"src/bpe_model.cc",
"src/char_model.cc",
"src/error.cc",
"src/filesystem.cc",
"src/model_factory.cc",
"src/model_interface.cc",
"src/normalizer.cc",
"src/sentencepiece_processor.cc",
"src/unigram_model.cc",
"src/util.cc",
"src/word_model.cc",
],
hdrs = [
":config_h",
"src/common.h",
"src/bpe_model.h",
"src/char_model.h",
"src/filesystem.h",
"src/freelist.h",
"src/init.h",
"src/model_factory.h",
"src/model_interface.h",
"src/normalizer.h",
"src/sentencepiece_processor.h",
"src/trainer_interface.h",
"src/unigram_model.h",
"src/util.h",
"src/word_model.h",
],
defines = [
"_USE_EXTERNAL_PROTOBUF",
"_USE_EXTERNAL_ABSL",
],
includes = [
"src",
],
deps = [
":darts_clone",
":sentencepiece_cc_proto",
":sentencepiece_model_cc_proto",
"@zml//third_party/com_google_sentencepiece:absl",
],
)

16
third_party/non_module_deps.bzl vendored Normal file
View File

@ -0,0 +1,16 @@
load("//:third_party/org_swig_swig/repo.bzl", org_swig_swig = "repo")
load("//third_party/com_google_sentencepiece:repo.bzl", com_google_sentencepiece = "repo")
def _non_module_deps_impl(mctx):
com_google_sentencepiece()
org_swig_swig()
return mctx.extension_metadata(
reproducible = True,
root_module_direct_deps = "all",
root_module_direct_dev_deps = [],
)
non_module_deps = module_extension(
implementation = _non_module_deps_impl,
)

10
third_party/org_swig_swig/repo.bzl vendored Normal file
View File

@ -0,0 +1,10 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
def repo():
http_archive(
name = "org_swig_swig",
url = "http://prdownloads.sourceforge.net/swig/swig-4.3.0.tar.gz",
sha256 = "f7203ef796f61af986c70c05816236cbd0d31b7aa9631e5ab53020ab7804aa9e",
strip_prefix = "swig-4.3.0",
build_file = "//:third_party/org_swig_swig/swig.bazel",
)

109
third_party/org_swig_swig/swig.bazel vendored Normal file
View File

@ -0,0 +1,109 @@
licenses(["restricted"]) # GPLv3
exports_files(["LICENSE"])
filegroup(
name = "lib",
srcs = glob([
"Lib/*.*",
"Lib/c/*.*",
"Lib/std/*.*",
"Lib/typemaps/*.*",
]),
visibility = ["//visibility:public"],
)
cc_binary(
name = "swig",
srcs = glob([
"Source/CParse/*.h",
"Source/CParse/*.c",
"Source/DOH/*.h",
"Source/DOH/*.c",
"Source/Include/*.h",
"Source/Preprocessor/*.h",
"Source/Preprocessor/*.c",
"Source/Swig/*.h",
"Source/Swig/*.c",
]) + [
"Source/Include/swigconfig.h",
"Source/Modules/allocate.cxx",
"Source/Modules/c.cxx",
"Source/Modules/contract.cxx",
"Source/Modules/directors.cxx",
"Source/Modules/emit.cxx",
"Source/Modules/interface.cxx",
"Source/Modules/lang.cxx",
"Source/Modules/main.cxx",
"Source/Modules/nested.cxx",
"Source/Modules/overload.cxx",
"Source/Modules/swigmain-lite.cxx",
"Source/Modules/swigmod.h",
"Source/Modules/typepass.cxx",
"Source/Modules/utils.cxx",
"Source/Modules/xml.cxx",
],
includes = [
"Source/CParse",
"Source/DOH",
"Source/Include",
"Source/Modules",
"Source/Preprocessor",
"Source/Swig",
],
data = [":lib"],
output_licenses = ["unencumbered"],
visibility = ["//visibility:public"],
deps = ["@pcre2"],
)
genrule(
name = "swigconfig",
outs = ["Source/Include/swigconfig.h"],
cmd = """\
cat <<EOF >$@
#define HAVE_BOOL
#define HAVE_PCRE
#define HAVE_POPEN
#define PACKAGE_BUGREPORT \"http://www.swig.org\"
#define PACKAGE_VERSION \"4.3.0\"
#define STDC_HEADERS
#define SWIG_CXX \"bazel4lyfe\"
#define SWIG_LIB \"external/org_swig_swig/Lib\"
#define SWIG_LIB_WIN_UNIX \"\"
#define SWIG_PLATFORM \"bazel4lyfe\"
EOF
""",
)
genrule(
name = "get_rid_of_stuff_we_dont_need_yet",
srcs = ["Source/Modules/swigmain.cxx"],
outs = ["Source/Modules/swigmain-lite.cxx"],
cmd = """\
sed -e '/swig_allegrocl/d' \
-e '/swig_chicken/d' \
-e '/swig_clisp/d' \
-e '/swig_csharp/d' \
-e '/swig_d/d' \
-e '/swig_guile/d' \
-e '/swig_go/d' \
-e '/swig_java/d' \
-e '/swig_lua/d' \
-e '/swig_modula3/d' \
-e '/swig_mzscheme/d' \
-e '/swig_ocaml/d' \
-e '/swig_octave/d' \
-e '/swig_perl/d' \
-e '/swig_php/d' \
-e '/swig_pike/d' \
-e '/swig_python/d' \
-e '/swig_r/d' \
-e '/swig_ruby/d' \
-e '/swig_scilab/d' \
-e '/swig_sexp/d' \
-e '/swig_tcl/d' \
-e '/swig_uffi/d' \
$< >$@
""",
)

View File

@ -32,6 +32,7 @@ zig_library(
"//pjrt",
"//runtimes",
"//stdx",
"//zml/tokenizer",
"//zml/tools",
"@rules_zig//zig/lib:libc",
"@rules_zig//zig/runfiles",

View File

@ -10,8 +10,6 @@ const posix = @import("posix.zig");
pub const gguf = @import("aio/gguf.zig");
pub const nemo = @import("aio/nemo.zig");
pub const safetensors = @import("aio/safetensors.zig");
pub const sentencepiece = @import("aio/sentencepiece.zig");
pub const tinyllama = @import("aio/tinyllama.zig");
pub const torch = @import("aio/torch.zig");
pub const yaml = @import("aio/yaml.zig");
@ -23,8 +21,6 @@ test {
std.testing.refAllDecls(gguf);
std.testing.refAllDecls(nemo);
std.testing.refAllDecls(safetensors);
std.testing.refAllDecls(sentencepiece);
std.testing.refAllDecls(tinyllama);
std.testing.refAllDecls(torch);
std.testing.refAllDecls(yaml);
}
@ -39,29 +35,11 @@ pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8)
try gguf.open(allocator, model_path)
else if (std.mem.endsWith(u8, model_path, ".pt"))
try torch.open(allocator, model_path)
else if (std.mem.endsWith(u8, model_path, ".tinyllama"))
try tinyllama.open(allocator, model_path)
else {
std.debug.panic("File extension not recognized: {s}", .{model_path});
};
}
pub fn detectFormatAndLoadTokenizer(allocator: std.mem.Allocator, tokenizer_path: []const u8) !zml.tokenizer.Tokenizer {
return if (std.mem.endsWith(u8, tokenizer_path, ".json"))
try zml.tokenizer.fromHfJson(allocator, tokenizer_path)
else if (std.mem.endsWith(u8, tokenizer_path, ".gguf")) {
const store = try gguf.open(allocator, tokenizer_path);
return gguf.getGgufTokenizer(store, allocator);
} else if (std.mem.endsWith(u8, tokenizer_path, ".pb") or std.mem.endsWith(u8, tokenizer_path, ".model"))
try sentencepiece.loadTokenizerFromPath(allocator, tokenizer_path)
else if (std.mem.endsWith(u8, tokenizer_path, ".tinyllama"))
try zml.aio.tinyllama.loadTokenizer(allocator, tokenizer_path, 32000)
else {
log.err("Failed to recognized tokenizer format of: {s}", .{tokenizer_path});
return error.FormatNotRecognized;
};
}
/// Creates a Model struct with tensor shapes read from the given BufferStore.
/// The result can be used to pass to `compileModel`.
///
@ -445,6 +423,7 @@ fn _populateStruct(
return true;
},
.Void => true,
.Union => true,
else => if (required) {
log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) });
return error.UnsupportedMetadataType;

View File

@ -31,76 +31,6 @@ pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
return res;
}
pub fn getGgufTokenizer(self: zml.aio.BufferStore, allocator: std.mem.Allocator) !zml.tokenizer.Tokenizer {
const tokens = self.metadataSlice("tokenizer.ggml.tokens", .string) orelse {
log.err("GGUF File: Tokens not found", .{});
return error.TokensNotFound;
};
const scores = self.metadataSlice("tokenizer.ggml.scores", .float) orelse {
log.err("GGUF File: Scores not found", .{});
return error.ScoresNotFound;
};
assert(tokens.len == scores.len);
const tokenizer_type = self.metadata("tokenizer.ggml.model", .string) orelse "llama";
const tokenizer_impl: zml.tokenizer.KnownImplementation = if (std.mem.eql(u8, tokenizer_type, "gpt2")) .gpt2 else .sentencepiece;
const bos = self.metadata("tokenizer.ggml.bos_token_id", .int);
const eos = self.metadata("tokenizer.ggml.eos_token_id", .int);
const unk = self.metadata("tokenizer.ggml.unknown_token_id", .int);
const pad = self.metadata("tokenizer.ggml.padding_token_id", .int);
const NOT_FOUND = std.math.maxInt(u32);
const special_tokens: zml.tokenizer.Tokenizer.SpecialTokens = .{
.bos = @intCast(bos.?),
.eos = @intCast(eos.?),
.unk = @intCast(unk orelse NOT_FOUND),
.pad = @intCast(pad orelse NOT_FOUND),
};
const gguf_normalizer = if (tokenizer_impl == .gpt2)
zml.tokenizer.Normalizer.wellKnown(.gpt2)
else
zml.tokenizer.Normalizer.wellKnown(.sentencepiece);
const extra_tokens: u8 = if (tokenizer_impl == .gpt2) 1 else 0;
const n_tokens: u32 = @intCast(tokens.len + extra_tokens);
var tokenizer = try zml.tokenizer.Tokenizer.init(
allocator,
n_tokens,
32,
gguf_normalizer,
special_tokens,
true,
);
var gpt2_unicode = if (tokenizer_impl == .gpt2)
try zml.tokenizer.Gpt2TextDecoder.init(allocator)
else
null;
defer if (gpt2_unicode) |*gpt2| gpt2.deinit();
var decoded = std.ArrayList(u8).init(allocator);
defer decoded.deinit();
// copy the tokens to the tokenizer arena.
for (tokens, 0..tokens.len) |t, i| {
if (tokenizer_impl == .gpt2) {
decoded.clearRetainingCapacity();
try tokenizer.addToken(@floatCast(scores[i]), try gpt2_unicode.?.decode(&decoded, t));
// log.debug("token: {s} -> {s}", .{t, decoded.items});
} else {
try tokenizer.addToken(@floatCast(scores[i]), t);
}
}
// Gpt2 tokenizer always splits on spaces.
if (tokenizer_impl == .gpt2) {
tokenizer.special_tokens.hard_space = tokenizer.next_token_id;
tokenizer.addOwnedToken(0, " ");
}
return tokenizer;
}
fn loadMetadata(allocator: Allocator, store: *zml.aio.BufferStore, file: *core.GgufFile) !void {
try store._metadata.ensureTotalCapacity(allocator, @intCast(file.header.metadata_kv_count));

View File

@ -4,58 +4,6 @@ const zml = @import("../zml.zig");
const sentencepiece_proto = @import("//sentencepiece:model_proto");
const Normalizer = zml.tokenizer.Normalizer;
const Tokenizer = zml.tokenizer.Tokenizer;
pub fn loadTokenizerFromPath(allocator: std.mem.Allocator, path: []const u8) !Tokenizer {
const file = try asynk.File.open(path, .{});
defer file.close() catch unreachable;
return loadTokenizerFromFile(allocator, file);
}
pub fn loadTokenizerFromFile(allocator: std.mem.Allocator, file: asynk.File) !Tokenizer {
const reader = file.reader();
const input = try reader.readAllAlloc(allocator, 16 * 1024 * 1024);
defer allocator.free(input);
var proto_arena = std.heap.ArenaAllocator.init(allocator);
defer proto_arena.deinit();
const model = try sentencepiece_proto.ModelProto.decode(input, proto_arena.allocator());
// no deinit, memory will be freed by the proto_arena
return loadTokenizerFromModelProto(allocator, model);
}
pub fn loadTokenizerFromModelProto(allocator: std.mem.Allocator, model: sentencepiece_proto.ModelProto) !Tokenizer {
std.debug.assert(model.trainer_spec.?.model_type.? == .BPE);
const special_tokens: Tokenizer.SpecialTokens = .{
.unk = @intCast(model.trainer_spec.?.unk_id.?),
.bos = @intCast(model.trainer_spec.?.bos_id.?),
.eos = @intCast(model.trainer_spec.?.eos_id.?),
.pad = parseTokenId(model.trainer_spec.?.pad_id),
};
var tokenizer = try Tokenizer.init(
allocator,
@intCast(model.pieces.items.len),
@intCast(model.trainer_spec.?.max_sentencepiece_length.?),
normalizerFromSpec(model.normalizer_spec.?),
special_tokens,
true,
);
errdefer tokenizer.deinit();
for (model.pieces.items) |*piece| {
try tokenizer.addToken(piece.score.?, piece.piece.?.getSlice());
}
const byte_fallback = model.trainer_spec.?.byte_fallback orelse false;
if (byte_fallback) {
try tokenizer.rewriteByteFallbackTokens();
}
return tokenizer;
}
fn parseTokenId(id: ?i32) u32 {
if (id) |idx| {

View File

@ -38,7 +38,7 @@ pub const HostBuffer = struct {
/// The returned HostBuffer doesn't take ownership of the slice,
/// which the caller still needs to deallocate.
pub fn fromBytes(shape_: Shape, data_: []const u8) HostBuffer {
std.debug.assert(shape_.byteSize() == data_.len);
stdx.debug.assert(shape_.byteSize() == data_.len, "shape {} and data {} don't match", .{ shape_.byteSize(), data_.len });
return .{
._shape = shape_,
.data = data_,

View File

@ -175,7 +175,7 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
} else {
to.* = null;
},
.Int, .Float, .Enum => to.* = from,
.Int, .Float, .Enum, .Union => to.* = from,
else => stdx.debug.compileError("zml.meta.mapAlloc doesn't support: {}", .{FromStruct}),
}
}

View File

@ -2027,7 +2027,7 @@ pub const Tensor = struct {
/// Appends a 1-dim axis, with the given tag.
pub fn appendAxes(self: Tensor, t: anytype) Tensor {
stdx.debug.assert(self.rank() < Tensor.MAX_RANK - t.len, "appendAxis expects tensor rank to be small enough in order to extend it, got {} and {} (max is {})", .{ self.rank(), t.len, Tensor.MAX_RANK });
// stdx.debug.assert(self.rank() < Tensor.MAX_RANK - t.len, "appendAxis expects tensor rank to be small enough in order to extend it, got {} and {} (max is {})", .{ self.rank(), t.len, Tensor.MAX_RANK });
return self.insertAxes(.last, t);
}

35
zml/tokenizer/BUILD.bazel Normal file
View File

@ -0,0 +1,35 @@
load("@rules_zig//zig:defs.bzl", "zig_library")
load("@zml//bazel:zig.bzl", "zig_cc_binary")
load("//bazel:swig.bzl", "swig_cc_library")
swig_cc_library(
name = "sentencepiece_swig",
interface = "sentencepiece.i",
module = "sentencepiece",
deps = [
"//ffi:cc",
"@com_google_sentencepiece//:sentencepiece_processor",
],
)
zig_library(
name = "tokenizer",
import_name = "zml/tokenizer",
main = "tokenizer.zig",
visibility = ["//visibility:public"],
deps = [
"//async",
"//ffi:zig",
"//zml/tokenizer/hftokenizers",
"//zml/tokenizer/sentencepiece",
],
)
zig_cc_binary(
name = "main",
main = "main.zig",
visibility = ["//visibility:public"],
deps = [
":tokenizer",
],
)

View File

@ -0,0 +1,38 @@
load("@rules_rust//rust:defs.bzl", "rust_static_library")
load("@rules_zig//zig:defs.bzl", "zig_library")
load("@zml//bazel:zig.bzl", "zig_cc_binary")
rust_static_library(
name = "hftokenizers_rs",
srcs = ["hftokenizers.rs"],
crate_name = "zml_tokenizer_hftokenizers",
edition = "2021",
deps = ["@crates//:tokenizers"],
)
cc_library(
name = "hftokenizers_cc",
hdrs = ["hftokenizers.h"],
visibility = ["//visibility:public"],
deps = [
":hftokenizers_rs",
"//ffi:cc",
],
)
zig_library(
name = "hftokenizers",
main = "hftokenizers.zig",
visibility = ["//visibility:public"],
deps = [
":hftokenizers_cc",
"//ffi:zig",
],
)
zig_cc_binary(
name = "main",
main = "main.zig",
visibility = ["//visibility:public"],
deps = [":hftokenizers"],
)

669
zml/tokenizer/hftokenizers/Cargo.lock generated Normal file
View File

@ -0,0 +1,669 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "aho-corasick"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
dependencies = [
"memchr",
]
[[package]]
name = "base64"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]]
name = "bit-set"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
dependencies = [
"bit-vec",
]
[[package]]
name = "bit-vec"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
[[package]]
name = "bumpalo"
version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "crossbeam-deque"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]]
name = "darling"
version = "0.20.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989"
dependencies = [
"darling_core",
"darling_macro",
]
[[package]]
name = "darling_core"
version = "0.20.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim",
"syn",
]
[[package]]
name = "darling_macro"
version = "0.20.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806"
dependencies = [
"darling_core",
"quote",
"syn",
]
[[package]]
name = "derive_builder"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947"
dependencies = [
"derive_builder_macro",
]
[[package]]
name = "derive_builder_core"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "derive_builder_macro"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
dependencies = [
"derive_builder_core",
"syn",
]
[[package]]
name = "either"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
[[package]]
name = "esaxx-rs"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6"
[[package]]
name = "fancy-regex"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2"
dependencies = [
"bit-set",
"regex-automata",
"regex-syntax",
]
[[package]]
name = "fnv"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "getrandom"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi",
"wasm-bindgen",
]
[[package]]
name = "ident_case"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]]
name = "itertools"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
[[package]]
name = "js-sys"
version = "0.3.76"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6717b6b5b077764fb5966237269cb3c64edddde4b14ce42647430a78ced9e7b7"
dependencies = [
"once_cell",
"wasm-bindgen",
]
[[package]]
name = "lazy_static"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "libc"
version = "0.2.169"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
[[package]]
name = "log"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
[[package]]
name = "macro_rules_attribute"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a82271f7bc033d84bbca59a3ce3e4159938cb08a9c3aebbe54d215131518a13"
dependencies = [
"macro_rules_attribute-proc_macro",
"paste",
]
[[package]]
name = "macro_rules_attribute-proc_macro"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8dd856d451cc0da70e2ef2ce95a18e39a93b7558bedf10201ad28503f918568"
[[package]]
name = "memchr"
version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "minimal-lexical"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "monostate"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d208407d7552cd041d8cdb69a1bc3303e029c598738177a3d87082004dc0e1e"
dependencies = [
"monostate-impl",
"serde",
]
[[package]]
name = "monostate-impl"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "nom"
version = "7.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
dependencies = [
"memchr",
"minimal-lexical",
]
[[package]]
name = "once_cell"
version = "1.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
[[package]]
name = "paste"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]]
name = "ppv-lite86"
version = "0.2.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04"
dependencies = [
"zerocopy",
]
[[package]]
name = "proc-macro2"
version = "1.0.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
]
[[package]]
name = "rayon"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
dependencies = [
"either",
"rayon-core",
]
[[package]]
name = "rayon-cond"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9"
dependencies = [
"either",
"itertools 0.11.0",
"rayon",
]
[[package]]
name = "rayon-core"
version = "1.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
dependencies = [
"crossbeam-deque",
"crossbeam-utils",
]
[[package]]
name = "regex"
version = "1.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata",
"regex-syntax",
]
[[package]]
name = "regex-automata"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "ryu"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
[[package]]
name = "serde"
version = "1.0.216"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.216"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.134"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d00f4175c42ee48b15416f6193a959ba3a0d67fc699a0db9ad12df9f83991c7d"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
]
[[package]]
name = "smallvec"
version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
[[package]]
name = "spm_precompiled"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326"
dependencies = [
"base64",
"nom",
"serde",
"unicode-segmentation",
]
[[package]]
name = "strsim"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "syn"
version = "2.0.91"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d53cbcb5a243bd33b7858b1d7f4aca2153490815872d86d955d6ea29f743c035"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "thiserror"
version = "1.0.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tokenizers"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ecededfed68a69bc657e486510089e255e53c3d38cc7d4d59c8742668ca2cae"
dependencies = [
"aho-corasick",
"derive_builder",
"esaxx-rs",
"fancy-regex",
"getrandom",
"itertools 0.12.1",
"lazy_static",
"log",
"macro_rules_attribute",
"monostate",
"paste",
"rand",
"rayon",
"rayon-cond",
"regex",
"regex-syntax",
"serde",
"serde_json",
"spm_precompiled",
"thiserror",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
]
[[package]]
name = "unicode-ident"
version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83"
[[package]]
name = "unicode-normalization-alignments"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de"
dependencies = [
"smallvec",
]
[[package]]
name = "unicode-segmentation"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "unicode_categories"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasm-bindgen"
version = "0.2.99"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a474f6281d1d70c17ae7aa6a613c87fce69a127e2624002df63dcb39d6cf6396"
dependencies = [
"cfg-if",
"once_cell",
"wasm-bindgen-macro",
]
[[package]]
name = "wasm-bindgen-backend"
version = "0.2.99"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f89bb38646b4f81674e8f5c3fb81b562be1fd936d84320f3264486418519c79"
dependencies = [
"bumpalo",
"log",
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.99"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2cc6181fd9a7492eef6fef1f33961e3695e4579b9872a6f7c83aee556666d4fe"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
]
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.99"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2"
dependencies = [
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.99"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6"
[[package]]
name = "zerocopy"
version = "0.7.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0"
dependencies = [
"byteorder",
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.7.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "zml_tokenizer_hftokenizers"
version = "0.1.0"
dependencies = [
"tokenizers",
]

View File

@ -0,0 +1,13 @@
[package]
name = "zml_tokenizer_hftokenizers"
version = "0.1.0"
edition = "2021"
[dependencies]
# onig-sys fails to build with zig cc, so disable it via the unstable_wasm feature, which switches
# the regex library to fancy-regex.
tokenizers = { version = "0.21.0", default-features = false, features = ["unstable_wasm"] }
[lib]
name = "zml_tokenizer_hftokenizers"
path = "hftokenizers.rs"

View File

@ -0,0 +1,16 @@
#pragma once
#include <stdint.h>
#include <stdlib.h>
#include "ffi/zig_slice.h"
typedef struct hftokenizers hftokenizers;
hftokenizers *hftokenizers_new(zig_slice);
void hftokenizers_drop(hftokenizers *tokenizer);
zig_slice hftokenizers_encode(hftokenizers *tokenizer, zig_slice text);
void hftokenizers_tokens_drop(zig_slice tokens);
zig_slice hftokenizers_decode(hftokenizers *tokenizer, zig_slice tokens);
void hftokenizers_str_drop(zig_slice text);
uint32_t hftokenizers_token_to_id(hftokenizers *tokenizer, zig_slice token);

View File

@ -0,0 +1,101 @@
#[repr(C)]
struct ZigSlice<T> {
ptr: *mut T,
len: usize,
}
impl<T> ZigSlice<T> {
fn as_slice(&self) -> &[T] {
unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
}
fn as_slice_mut(&self) -> &mut [T] {
unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
}
}
#[no_mangle]
extern "C" fn hftokenizers_new(path: ZigSlice<u8>) -> *mut tokenizers::Tokenizer {
return Box::into_raw(Box::new(
tokenizers::Tokenizer::from_file(std::path::Path::new(
std::str::from_utf8(path.as_slice()).unwrap(),
))
.unwrap(),
));
}
#[no_mangle]
extern "C" fn hftokenizers_drop(t: *mut tokenizers::Tokenizer) {
drop(unsafe { Box::from_raw(t) });
}
#[no_mangle]
extern "C" fn hftokenizers_encode(
t: *mut tokenizers::Tokenizer,
string: ZigSlice<u8>,
) -> ZigSlice<u32> {
let input_str = std::str::from_utf8(string.as_slice()).unwrap();
let encoded = unsafe { t.as_ref() }
.unwrap()
.encode_fast(input_str, false)
.unwrap();
// Convert the result to a boxed slice
let mut ids: Box<[u32]> = encoded.get_ids().to_owned().into_boxed_slice();
// Retrieve the zig slice associated to the boxed slice.
let slice = ZigSlice {
ptr: ids.as_mut_ptr(),
len: ids.len(),
};
// Leak the box so that it's not deallocated.
Box::leak(ids);
return slice;
}
#[no_mangle]
extern "C" fn hftokenizers_tokens_drop(tokens: ZigSlice<u32>) {
// Reconstruct the Box from the zig slice so that it's dropped.
drop(unsafe { Box::from_raw(tokens.as_slice_mut()) });
}
#[no_mangle]
extern "C" fn hftokenizers_decode(
t: *mut tokenizers::Tokenizer,
ids: ZigSlice<u32>,
) -> ZigSlice<u8> {
let decoded = unsafe { t.as_ref() }
.unwrap()
.decode(ids.as_slice(), false)
.unwrap();
// Convert the result to a boxed slice
let mut string: Box<[u8]> = decoded.into_bytes().into_boxed_slice();
// Retrieve the zig slice associated to the boxed slice.
let slice = ZigSlice {
ptr: string.as_mut_ptr(),
len: string.len(),
};
// Leak the box so that it's not deallocated.
Box::leak(string);
return slice;
}
#[no_mangle]
extern "C" fn hftokenizers_str_drop(tokens: ZigSlice<u8>) {
drop(unsafe { Box::from_raw(tokens.as_slice_mut()) });
}
#[no_mangle]
extern "C" fn hftokenizers_token_to_id(t: *mut tokenizers::Tokenizer, token: ZigSlice<u8>) -> u32 {
let id = unsafe { t.as_ref() }
.unwrap()
.token_to_id(std::str::from_utf8(token.as_slice()).unwrap())
.unwrap_or(u32::MAX);
return id;
}

View File

@ -0,0 +1,113 @@
const std = @import("std");
const c = @import("c");
const ffi = @import("ffi");
pub const Encoder = struct {
inner: *HFTokenizer,
current_ids: ?[]const u32 = null,
fn init(inner: *HFTokenizer) Encoder {
return .{ .inner = inner };
}
pub fn reset(self: *Encoder) void {
if (self.current_ids) |current_ids_| {
c.hftokenizers_tokens_drop(ffi.ZigSlice.from(current_ids_));
self.current_ids = null;
}
}
pub fn deinit(self: *Encoder) void {
self.reset();
}
pub fn encode(self: *Encoder, input: []const u8) ![]const u32 {
self.reset();
self.current_ids = ffi.ZigSlice.to(u32, c.hftokenizers_encode(@ptrCast(self.inner), ffi.ZigSlice.from(input)));
return self.ids();
}
pub fn ids(self: *const Encoder) []const u32 {
return self.current_ids orelse &.{};
}
};
pub const Decoder = struct {
const StringBuffer = std.BoundedArray(u8, 128);
const TokensIdsBuffer = std.BoundedArray(u32, 4);
inner: *HFTokenizer,
current_string: ?[]const u8 = null,
last_string: StringBuffer = .{ .len = 0 },
last_token_ids: TokensIdsBuffer = .{ .len = 0 },
fn init(inner: *HFTokenizer) Decoder {
return .{ .inner = inner };
}
pub fn deinit(self: *Decoder) void {
self.reset();
}
pub fn reset(self: *Decoder) void {
if (self.current_string) |current_string_| {
c.hftokenizers_str_drop(ffi.ZigSlice.from(current_string_));
self.current_string = null;
}
}
pub fn decode(self: *Decoder, ids: []const u32) ![]const u8 {
self.reset();
self.current_string = ffi.ZigSlice.to(u8, c.hftokenizers_decode(@ptrCast(self.inner), ffi.ZigSlice.from(ids)));
return self.string();
}
pub fn string(self: *const Decoder) []const u8 {
return self.current_string orelse &.{};
}
pub fn next(self: *Decoder, token_id: u32) !?[]const u8 {
// Keep a small sliding window of the most recent token ids.
if (self.last_token_ids.len >= self.last_token_ids.capacity()) {
_ = self.last_token_ids.orderedRemove(0);
}
self.last_token_ids.appendAssumeCapacity(token_id);
// Re-decode the whole window, then diff it against the previous decode to
// isolate the bytes contributed by the new token.
const new_string = try self.decode(self.last_token_ids.constSlice());
if (self.last_string.len == 0) {
self.last_string = try StringBuffer.fromSlice(new_string);
return new_string;
}
// Advance through codepoint boundaries of the previous string until one of
// its suffixes is a prefix of the new string; what follows that prefix is
// the newly produced text.
var view = try std.unicode.Utf8View.init(self.last_string.constSlice());
var it = view.iterator();
while (it.nextCodepointSlice()) |cp| {
const start = it.i - cp.len;
if (std.mem.startsWith(u8, new_string, self.last_string.constSlice()[start..])) {
const chunk = new_string[self.last_string.len - start ..];
self.last_string = try StringBuffer.fromSlice(new_string);
return chunk;
}
}
// No overlap yet (e.g. the window decodes to an incomplete UTF-8 sequence).
return null;
}
};
pub const HFTokenizer = opaque {
pub fn from_file(model: []const u8) !*HFTokenizer {
return @ptrCast(c.hftokenizers_new(ffi.ZigSlice.from(model)));
}
pub fn deinit(self: *HFTokenizer) void {
return c.hftokenizers_drop(@ptrCast(self));
}
pub fn encoder(self: *HFTokenizer) !Encoder {
return Encoder.init(self);
}
pub fn decoder(self: *HFTokenizer) !Decoder {
return Decoder.init(self);
}
pub fn token_to_id(self: *HFTokenizer, token: []const u8) ?u32 {
return c.hftokenizers_token_to_id(@ptrCast(self), ffi.ZigSlice.from(token));
}
};
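Decoder.next above is the streaming entry point: it keeps a four-token window, re-decodes it, and diff-matches against the previous output so only newly produced bytes are returned (null while, e.g., a multi-byte UTF-8 sequence is still incomplete). A hedged sketch of token-by-token streaming:

const std = @import("std");
const HFTokenizer = @import("hftokenizers").HFTokenizer;

fn streamPrint(tok: *HFTokenizer, ids: []const u32) !void {
    var dec = try tok.decoder();
    defer dec.deinit();
    for (ids) |id| {
        // Emit only the bytes this token adds, if it completed a chunk.
        if (try dec.next(id)) |chunk| {
            std.debug.print("{s}", .{chunk});
        }
    }
}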

View File

@ -0,0 +1,27 @@
const std = @import("std");
const c = @import("c");
const HFTokenizers = @import("hftokenizers").HFTokenizers;
pub fn main() !void {
const tokenizer = HFTokenizers.init("/private/var/tmp/_bazel_steeve/a67b810d44f2a673ebbd5bab86ccd5cc/external/zml~~huggingface~Meta-Llama-3.1-8B-Instruct/tokenizer.json");
defer HFTokenizers.deinit(tokenizer);
const input = "Hello, world! plane pouet plane";
var encoded = HFTokenizers.encode(tokenizer, input);
defer encoded.deinit();
var pouet = std.ArrayList(u32).init(std.heap.c_allocator);
defer pouet.deinit();
// try pouet.appendSlice(encoded.ids);
var t = try std.time.Timer.start();
for (0..100) |_| {
try pouet.appendSlice(encoded.ids);
t.reset();
var decoded = HFTokenizers.decode(tokenizer, pouet.items);
defer decoded.deinit();
const elapsed = t.lap();
// std.debug.print("{any} {any} {d}us\n", .{tokenizer, encoded, elapsed / std.time.ns_per_us});
std.debug.print("{any} {any} {s} {d}ns {d}us\n", .{ tokenizer, encoded, decoded.str, elapsed, elapsed / std.time.ns_per_us });
}
}

22
zml/tokenizer/main.zig Normal file
View File

@ -0,0 +1,22 @@
const std = @import("std");
const tokenizer = @import("zml/tokenizer");
pub fn main() !void {
const model2 = "/private/var/tmp/_bazel_steeve/a67b810d44f2a673ebbd5bab86ccd5cc/external/zml~~huggingface~Meta-Llama-3.1-8B-Instruct/tokenizer.json";
var sp = try tokenizer.Tokenizer.from_file(std.heap.c_allocator, model2);
defer sp.deinit();
std.debug.print("Loaded model\n", .{});
var encoder = try sp.encoder();
defer encoder.deinit();
var decoder = try sp.decoder();
defer decoder.deinit();
const ids = try encoder.encode("Hello, world! plane pouet plane");
const decoded = try decoder.decode(ids);
std.debug.print("{d}\n{s}\n", .{ ids, decoded });
}

View File

@ -0,0 +1,35 @@
load("@rules_zig//zig:defs.bzl", "zig_library")
load("@zml//bazel:zig.bzl", "zig_cc_binary")
load("//bazel:swig.bzl", "swig_cc_library")
swig_cc_library(
name = "sentencepiece_swig",
interface = "sentencepiece.i",
module = "sentencepiece",
deps = [
"//ffi:cc",
"@com_google_sentencepiece//:sentencepiece_processor",
],
)
zig_library(
name = "sentencepiece",
import_name = "sentencepiece",
main = "sentencepiece.zig",
visibility = ["//visibility:public"],
deps = [
":sentencepiece_swig",
"//ffi:zig",
],
)
zig_cc_binary(
name = "main",
srcs = ["sentencepiece.zig"],
main = "main.zig",
visibility = ["//visibility:public"],
deps = [
":sentencepiece_swig",
"//ffi:zig",
],
)

View File

@ -0,0 +1,289 @@
const std = @import("std");
const c = @import("c");
const ffi = @import("ffi");
pub const SentencePieceError = error{
Cancelled,
Unknown,
InvalidArgument,
DeadlineExceeded,
NotFound,
AlreadyExists,
PermissionDenied,
ResourceExhausted,
FailedPrecondition,
Aborted,
OutOfRange,
Unimplemented,
Internal,
Unavailable,
DataLoss,
Unauthenticated,
};
pub const DecoderStream = struct {
const TokensSize = 4;
const StringSize = 128;
decoder: SentencePieceProcessor.Decoder,
buffer: [StringSize]u8 = undefined,
last_tokens: []u8 = &.{},
pub fn init(decoder: SentencePieceProcessor.Decoder) DecoderStream {
var ret: DecoderStream = .{
.decoder = decoder,
};
ret.decoder.reserve_tokens(TokensSize);
ret.decoder.reserve_string(StringSize);
return ret;
}
pub fn next(self: *DecoderStream, next_token: u32) !?[]const u8 {
if (self.decoder.tokens().len >= TokensSize) {
const tokens = self.decoder.tokens();
inline for (0..TokensSize - 1) |i| {
tokens[i] = tokens[i + 1];
}
tokens[TokensSize - 1] = next_token;
} else {
self.decoder.append(next_token);
}
const new_tokens = try self.decoder.decode();
if (self.last_tokens.len == 0) {
self.last_tokens = self.buffer[0..new_tokens.len];
@memcpy(self.last_tokens, new_tokens);
return new_tokens;
}
// Start at 0 so a full-prefix match (the new decode extends the
// previous one) is not missed.
for (0..self.last_tokens.len) |i| {
if (std.mem.startsWith(u8, new_tokens, self.last_tokens[i..])) {
const toks = new_tokens[self.last_tokens.len - i ..];
self.last_tokens = self.buffer[0..new_tokens.len];
@memcpy(self.last_tokens, new_tokens);
return toks;
}
}
return null;
}
};
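/// Opaque handle over the SWIG-generated C wrapper of
/// sentencepiece::SentencePieceProcessor; all methods forward to the flat
/// C functions exposed via the `c` import.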
pub const SentencePieceProcessor = opaque {
pub const Encoder = struct {
inner: *SentencePieceProcessor,
vec: *c.std_vector_int,
fn init(inner: *SentencePieceProcessor) Encoder {
return .{
.inner = inner,
.vec = c.std_vector_int_new() orelse unreachable,
};
}
pub fn deinit(self: *Encoder) void {
c.std_vector_int_delete(self.vec);
}
pub fn reserve(self: *Encoder, size: usize) void {
c.std_vector_int_reserve(self.vec, size);
}
pub fn reset(self: *Encoder) void {
c.std_vector_int_clear(self.vec);
}
pub fn encode(self: *Encoder, input: []const u8) ![]const u32 {
try assertOk(c.SentencePieceProcessor_Encode(@ptrCast(self.inner), ffi.ZigSlice.from(input), self.vec));
return ffi.ZigSlice.to(u32, .{
.ptr = c.std_vector_int_data(self.vec),
.len = c.std_vector_int_size(self.vec),
});
}
};
pub const Decoder = struct {
inner: *SentencePieceProcessor,
vec: *c.std_vector_int,
str: *c.std_string,
fn init(inner: *SentencePieceProcessor) Decoder {
return .{
.inner = inner,
.vec = c.std_vector_int_new() orelse unreachable,
.str = c.std_string_new() orelse unreachable,
};
}
pub fn append(self: *Decoder, token: u32) void {
c.std_vector_int_push_back(self.vec, @intCast(token));
}
pub fn deinit(self: *Decoder) void {
c.std_vector_int_delete(self.vec);
c.std_string_delete(self.str);
}
pub fn reserve_tokens(self: *Decoder, size: usize) void {
c.std_vector_int_reserve(self.vec, size);
}
pub fn reserve_string(self: *Decoder, size: usize) void {
c.std_string_reserve(self.str, size);
}
pub fn reset(self: *Decoder) void {
c.std_vector_int_clear(self.vec);
c.std_string_clear(self.str);
}
pub fn decode(self: *Decoder) ![]const u8 {
try assertOk(c.SentencePieceProcessor_Decode(@ptrCast(self.inner), self.vec, self.str));
return self.string();
}
pub fn string(self: *const Decoder) []const u8 {
const res = c.std_string_data(self.str);
return ffi.ZigSlice.to(u8, res);
}
pub fn tokens(self: *const Decoder) []u32 {
const ptr: [*c]u32 = @ptrCast(c.std_vector_int_data(self.vec));
return ptr[0..c.std_vector_int_size(self.vec)];
}
};
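// Translate sentencepiece::util::StatusCode values into Zig errors so
// callers can use `try` on the C API results.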
fn assertOk(code: c.sentencepiece_util_StatusCode) SentencePieceError!void {
return switch (code) {
c.sentencepiece_util_StatusCode_kOk => {},
c.sentencepiece_util_StatusCode_kCancelled => error.Cancelled,
c.sentencepiece_util_StatusCode_kUnknown => error.Unknown,
c.sentencepiece_util_StatusCode_kInvalidArgument => error.InvalidArgument,
c.sentencepiece_util_StatusCode_kDeadlineExceeded => error.DeadlineExceeded,
c.sentencepiece_util_StatusCode_kNotFound => error.NotFound,
c.sentencepiece_util_StatusCode_kAlreadyExists => error.AlreadyExists,
c.sentencepiece_util_StatusCode_kPermissionDenied => error.PermissionDenied,
c.sentencepiece_util_StatusCode_kResourceExhausted => error.ResourceExhausted,
c.sentencepiece_util_StatusCode_kFailedPrecondition => error.FailedPrecondition,
c.sentencepiece_util_StatusCode_kAborted => error.Aborted,
c.sentencepiece_util_StatusCode_kOutOfRange => error.OutOfRange,
c.sentencepiece_util_StatusCode_kUnimplemented => error.Unimplemented,
c.sentencepiece_util_StatusCode_kInternal => error.Internal,
c.sentencepiece_util_StatusCode_kUnavailable => error.Unavailable,
c.sentencepiece_util_StatusCode_kDataLoss => error.DataLoss,
c.sentencepiece_util_StatusCode_kUnauthenticated => error.Unauthenticated,
else => unreachable,
};
}
pub fn load(model: []const u8) !*SentencePieceProcessor {
const sp: *SentencePieceProcessor = @ptrCast(c.SentencePieceProcessor_new());
errdefer sp.deinit();
try sp.load_from(model);
return sp;
}
pub fn deinit(self: *SentencePieceProcessor) void {
c.SentencePieceProcessor_delete(@ptrCast(self));
}
fn load_from(self: *SentencePieceProcessor, model: []const u8) !void {
try assertOk(c.SentencePieceProcessor_Load(@ptrCast(self), ffi.ZigSlice.from(model)));
}
pub fn encoder(self: *SentencePieceProcessor) Encoder {
return Encoder.init(self);
}
pub fn decoder(self: *SentencePieceProcessor) Decoder {
return Decoder.init(self);
}
};
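/// Builds a NUL-terminated copy of `path` for C APIs that expect C strings.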
pub fn as_path(path: []const u8) [std.fs.max_path_bytes:0]u8 {
var result: [std.fs.max_path_bytes:0]u8 = undefined;
@memcpy(result[0..path.len], path);
result[path.len] = 0;
return result;
}
pub fn main() !void {
const sp = try SentencePieceProcessor.load("/Users/steeve/Downloads/poolside.sp.pb");
defer sp.deinit();
std.debug.print("Loaded model\n", .{});
var encoder = sp.encoder();
defer encoder.deinit();
var decoder = sp.decoder();
defer decoder.deinit();
// Round-trip this file's own source through the tokenizer and stream it
// back chunk by chunk, timing each emitted piece. The windowed decode
// logic lives in DecoderStream above.
const ss = @embedFile("main.zig");
const tokens = try encoder.encode(ss);
var stream = DecoderStream.init(decoder);
var start = try std.time.Timer.start();
for (tokens) |token| {
if (try stream.next(token)) |chunk| {
std.debug.print("{d}us - {s}\n", .{ start.lap() / std.time.ns_per_us, chunk });
}
}
}

View File

@ -0,0 +1,92 @@
/* File : sentencepiece.i */
%module sentencepiece
%include <typemaps.i>
%include <std_vector.i>
%{
#include <stdint.h>
#include <string>
#include <vector>
#include <sentencepiece_processor.h>
#include "ffi/zig_slice.h"
%}
%insert("cheader") %{
#include "ffi/zig_slice.h"
%}
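// Marshal absl::string_view arguments and const std::string& results
// through the zig_slice struct, so Zig passes and receives (ptr, len)
// pairs without copying.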
%typemap(in) absl::string_view {
$1 = absl::string_view((char *)$input.ptr, $input.len);
}
%typemap(out, optimal="1") const std::string& %{
$result.ptr = (void *)($1->data());
$result.len = (size_t)($1->length());
%}
%typemap(ctype) absl::string_view, const std::string& "zig_slice"
%rename(std_string) std::string;
namespace std {
class string {
public:
void reserve(size_t n);
void clear();
size_t capacity() const;
};
}
%extend std::string {
const std::string& data() const {
return *$self;
};
}
%extend std::vector {
T* data() {
return $self->data();
};
}
%template(std_vector_int) std::vector<int>;
%typemap(out) sentencepiece::util::Status %{
$result = $1.code();
%}
%typemap(ctype) sentencepiece::util::Status "unsigned int"
%rename(sentencepiece_util_StatusCode, fullname=1) sentencepiece::util::StatusCode;
namespace sentencepiece {
namespace util {
enum class StatusCode : int {
kOk = 0,
kCancelled = 1,
kUnknown = 2,
kInvalidArgument = 3,
kDeadlineExceeded = 4,
kNotFound = 5,
kAlreadyExists = 6,
kPermissionDenied = 7,
kResourceExhausted = 8,
kFailedPrecondition = 9,
kAborted = 10,
kOutOfRange = 11,
kUnimplemented = 12,
kInternal = 13,
kUnavailable = 14,
kDataLoss = 15,
kUnauthenticated = 16,
};
}
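// Only the subset of the SentencePieceProcessor API used by the Zig
// bindings is declared here; SWIG emits flat C wrappers (e.g.
// SentencePieceProcessor_Load, _Encode, _Decode) from these declarations.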
class SentencePieceProcessor {
public:
virtual sentencepiece::util::Status Load(absl::string_view filename);
virtual sentencepiece::util::Status Encode(absl::string_view input, std::vector<int> *ids) const;
virtual sentencepiece::util::Status Decode(const std::vector<int> &ids, std::string *detokenized) const;
virtual int PieceToId(absl::string_view piece) const;
virtual int unk_id() const;
virtual int bos_id() const;
virtual int eos_id() const;
virtual int pad_id() const;
};
}

View File

@ -0,0 +1,189 @@
const std = @import("std");
const c = @import("c");
const ffi = @import("ffi");
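// Pre-sizing heuristic: assume one token covers roughly three bytes of
// text on a typical vocabulary, to reserve vector/string capacity up front.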
const StringToTokenRatio = 3;
pub const Error = error{
Cancelled,
Unknown,
InvalidArgument,
DeadlineExceeded,
NotFound,
AlreadyExists,
PermissionDenied,
ResourceExhausted,
FailedPrecondition,
Aborted,
OutOfRange,
Unimplemented,
Internal,
Unavailable,
DataLoss,
Unauthenticated,
};
fn assertOk(code: c.sentencepiece_util_StatusCode) Error!void {
return switch (code) {
c.sentencepiece_util_StatusCode_kOk => {},
c.sentencepiece_util_StatusCode_kCancelled => Error.Cancelled,
c.sentencepiece_util_StatusCode_kUnknown => Error.Unknown,
c.sentencepiece_util_StatusCode_kInvalidArgument => Error.InvalidArgument,
c.sentencepiece_util_StatusCode_kDeadlineExceeded => Error.DeadlineExceeded,
c.sentencepiece_util_StatusCode_kNotFound => Error.NotFound,
c.sentencepiece_util_StatusCode_kAlreadyExists => Error.AlreadyExists,
c.sentencepiece_util_StatusCode_kPermissionDenied => Error.PermissionDenied,
c.sentencepiece_util_StatusCode_kResourceExhausted => Error.ResourceExhausted,
c.sentencepiece_util_StatusCode_kFailedPrecondition => Error.FailedPrecondition,
c.sentencepiece_util_StatusCode_kAborted => Error.Aborted,
c.sentencepiece_util_StatusCode_kOutOfRange => Error.OutOfRange,
c.sentencepiece_util_StatusCode_kUnimplemented => Error.Unimplemented,
c.sentencepiece_util_StatusCode_kInternal => Error.Internal,
c.sentencepiece_util_StatusCode_kUnavailable => Error.Unavailable,
c.sentencepiece_util_StatusCode_kDataLoss => Error.DataLoss,
c.sentencepiece_util_StatusCode_kUnauthenticated => Error.Unauthenticated,
else => unreachable,
};
}
pub const Encoder = struct {
inner: *SentencePieceProcessor,
vec: *c.std_vector_int,
fn init(inner: *SentencePieceProcessor) Encoder {
return .{
.inner = inner,
.vec = c.std_vector_int_new() orelse unreachable,
};
}
pub fn deinit(self: *Encoder) void {
c.std_vector_int_delete(self.vec);
}
pub fn reset(self: *Encoder) void {
c.std_vector_int_clear(self.vec);
}
pub fn encode(self: *Encoder, input: []const u8) ![]const u32 {
c.std_vector_int_reserve(self.vec, input.len / StringToTokenRatio);
try assertOk(c.SentencePieceProcessor_Encode(@ptrCast(self.inner), ffi.ZigSlice.from(input), self.vec));
return self.ids();
}
pub fn ids(self: *const Encoder) []const u32 {
return ffi.ZigSlice.to(u32, .{
.ptr = c.std_vector_int_data(self.vec),
.len = c.std_vector_int_size(self.vec),
});
}
};
pub const Decoder = struct {
const StringBufferSize = 64;
const StringBuffer = std.BoundedArray(u8, StringBufferSize);
const TokenIdsBufferSize = 4;
inner: *SentencePieceProcessor,
vec: *c.std_vector_int,
str: *c.std_string,
last_string: StringBuffer = .{ .len = 0 },
fn init(inner: *SentencePieceProcessor) !Decoder {
const vec = try (c.std_vector_int_new() orelse std.mem.Allocator.Error.OutOfMemory);
c.std_vector_int_reserve(vec, TokenIdsBufferSize);
errdefer c.std_vector_int_delete(vec);
const str = try (c.std_string_new() orelse std.mem.Allocator.Error.OutOfMemory);
c.std_string_reserve(str, StringBufferSize);
errdefer c.std_string_delete(str);
return .{
.inner = inner,
.vec = vec,
.str = str,
};
}
pub fn deinit(self: *Decoder) void {
c.std_vector_int_delete(self.vec);
c.std_string_delete(self.str);
}
pub fn reset(self: *Decoder) void {
c.std_vector_int_clear(self.vec);
c.std_string_clear(self.str);
}
pub fn decode(self: *Decoder, ids_: []const u32) ![]const u8 {
c.std_vector_int_reserve(self.vec, ids_.len);
c.std_string_reserve(self.str, ids_.len * StringToTokenRatio);
for (ids_) |id| {
c.std_vector_int_push_back(self.vec, @intCast(id));
}
try assertOk(c.SentencePieceProcessor_Decode(@ptrCast(self.inner), self.vec, self.str));
return self.string();
}
pub fn string(self: *const Decoder) []const u8 {
const res = c.std_string_data(self.str);
return ffi.ZigSlice.to(u8, res);
}
fn ids(self: *const Decoder) []u32 {
const ptr: [*c]u32 = @ptrCast(c.std_vector_int_data(self.vec));
return ptr[0..c.std_vector_int_size(self.vec)];
}
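/// Feed one token id and return the newly produced text. The overlap with
/// the previous decode is searched at UTF-8 codepoint boundaries so that
/// multi-byte characters are never split across chunks.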
pub fn next(self: *Decoder, token_id: u32) !?[]const u8 {
const current_ids = self.ids();
if (current_ids.len >= c.std_vector_int_capacity(self.vec)) {
std.mem.copyForwards(u32, current_ids[0 .. current_ids.len - 1], current_ids[1..]);
current_ids[current_ids.len - 1] = token_id;
} else {
c.std_vector_int_push_back(self.vec, @intCast(token_id));
}
try assertOk(c.SentencePieceProcessor_Decode(@ptrCast(self.inner), self.vec, self.str));
const new_string = self.string();
if (self.last_string.len == 0) {
self.last_string = try StringBuffer.fromSlice(new_string);
return new_string;
}
var view = try std.unicode.Utf8View.init(self.last_string.constSlice());
var it = view.iterator();
while (it.nextCodepointSlice()) |cp| {
const start = it.i - cp.len;
if (std.mem.startsWith(u8, new_string, self.last_string.constSlice()[start..])) {
const chunk = new_string[self.last_string.len - start ..];
self.last_string = try StringBuffer.fromSlice(new_string);
return chunk;
}
}
return &.{};
}
};
pub const SentencePieceProcessor = opaque {
pub fn from_file(model: []const u8) !*SentencePieceProcessor {
const sp: *SentencePieceProcessor = @ptrCast(c.SentencePieceProcessor_new());
errdefer sp.deinit();
try assertOk(c.SentencePieceProcessor_Load(@ptrCast(sp), ffi.ZigSlice.from(model)));
return sp;
}
pub fn deinit(self: *SentencePieceProcessor) void {
c.SentencePieceProcessor_delete(@ptrCast(self));
}
pub fn encoder(self: *SentencePieceProcessor) !Encoder {
return Encoder.init(self);
}
pub fn decoder(self: *SentencePieceProcessor) !Decoder {
return try Decoder.init(self);
}
pub fn token_to_id(self: *SentencePieceProcessor, token: []const u8) u32 {
return @intCast(c.SentencePieceProcessor_PieceToId(@ptrCast(self), ffi.ZigSlice.from(token)));
}
};

118
zml/tokenizer/tokenizer.zig Normal file
View File

@ -0,0 +1,118 @@
const std = @import("std");
const hftokenizers = @import("hftokenizers");
const sentencepiece = @import("sentencepiece");
const asynk = @import("async");
const Tokenizers = enum {
hftokenizers,
sentencepiece,
};
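/// Unified front over the available tokenizer backends. Dispatch uses
/// `inline else`, so each switch arm is monomorphized for the concrete
/// backend type.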
pub const Tokenizer = union(Tokenizers) {
pub const Encoder = union(Tokenizers) {
hftokenizers: hftokenizers.Encoder,
sentencepiece: sentencepiece.Encoder,
pub fn deinit(self: *Encoder) void {
switch (self.*) {
inline else => |*v| v.deinit(),
}
}
pub fn reset(self: *Encoder) void {
switch (self.*) {
inline else => |*v| v.reset(),
}
}
pub fn encode(self: *Encoder, input: []const u8) ![]const u32 {
return switch (self.*) {
inline else => |*v| v.encode(input),
};
}
pub fn ids(self: Encoder) []const u32 {
return switch (self) {
inline else => |v| v.ids(),
};
}
};
pub const Decoder = union(Tokenizers) {
hftokenizers: hftokenizers.Decoder,
sentencepiece: sentencepiece.Decoder,
pub fn deinit(self: *Decoder) void {
switch (self.*) {
inline else => |*v| v.deinit(),
}
}
pub fn reset(self: *Decoder) void {
switch (self.*) {
inline else => |*v| v.reset(),
}
}
pub fn decode(self: *Decoder, ids_: []const u32) ![]const u8 {
return switch (self.*) {
inline else => |*v| v.decode(ids_),
};
}
pub fn string(self: Decoder) []const u8 {
return switch (self) {
inline else => |v| v.string(),
};
}
pub fn ids(self: Decoder) []u32 {
return switch (self) {
inline else => |v| v.ids(),
};
}
pub fn next(self: *Decoder, token_id: u32) !?[]const u8 {
return switch (self.*) {
inline else => |*v| v.next(token_id),
};
}
};
hftokenizers: *hftokenizers.HFTokenizer,
sentencepiece: *sentencepiece.SentencePieceProcessor,
pub fn from_file(_: std.mem.Allocator, model: []const u8) !Tokenizer {
if (std.mem.endsWith(u8, model, ".pb")) {
return .{ .sentencepiece = try asynk.callBlocking(sentencepiece.SentencePieceProcessor.from_file, .{model}) };
}
if (std.mem.endsWith(u8, model, ".json")) {
return .{ .hftokenizers = try asynk.callBlocking(hftokenizers.HFTokenizer.from_file, .{model}) };
}
return error.InvalidArgument;
}
pub fn deinit(self: *Tokenizer) void {
switch (self.*) {
inline else => |t| t.deinit(),
}
}
pub fn encoder(self: Tokenizer) !Encoder {
return switch (self) {
inline else => |v, tag| @unionInit(Encoder, @tagName(tag), try v.encoder()),
};
}
pub fn decoder(self: Tokenizer) !Decoder {
return switch (self) {
inline else => |v, tag| @unionInit(Decoder, @tagName(tag), try v.decoder()),
};
}
pub fn token_to_id(self: Tokenizer, token: []const u8) ?u32 {
return switch (self) {
inline else => |v| v.token_to_id(token),
};
}
};
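// Example usage (sketch; the model path is a placeholder):
//   var tok = try Tokenizer.from_file(allocator, "path/to/tokenizer.json");
//   defer tok.deinit();
//   var enc = try tok.encoder();
//   defer enc.deinit();
//   const ids = try enc.encode("Hello, world!");
//   var dec = try tok.decoder();
//   defer dec.deinit();
//   const text = try dec.decode(ids);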

View File

@ -29,7 +29,8 @@ pub const pjrt = @import("pjrtx.zig");
pub const testing = @import("testing.zig");
pub const torch = @import("torch.zig");
pub const tokenizer = @import("tokenizer.zig");
// pub const tokenizer = @import("tokenizer.zig");
pub const tokenizer = @import("zml/tokenizer");
pub const call = ops.call;
pub const compile = exe.compile;