Add HuggingFace tokenizer bindings and SentencePiece integration; update BUILD files, async utilities, and FFI modules to support the new tokenizers.
This commit is contained in:
parent
5048e7dc89
commit
959bc48c42
33
MODULE.bazel
33
MODULE.bazel
@ -7,7 +7,10 @@ new_git_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:git.bzl"
|
|||||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||||
bazel_dep(name = "hermetic_cc_toolchain", version = "3.1.1")
|
bazel_dep(name = "hermetic_cc_toolchain", version = "3.1.1")
|
||||||
bazel_dep(name = "patchelf", version = "0.18.0")
|
bazel_dep(name = "patchelf", version = "0.18.0")
|
||||||
|
bazel_dep(name = "pcre2", version = "10.43")
|
||||||
|
bazel_dep(name = "abseil-cpp", version = "20240722.0.bcr.2")
|
||||||
bazel_dep(name = "platforms", version = "0.0.10")
|
bazel_dep(name = "platforms", version = "0.0.10")
|
||||||
|
bazel_dep(name = "protobuf", version = "29.2")
|
||||||
bazel_dep(name = "rules_cc", version = "0.0.17")
|
bazel_dep(name = "rules_cc", version = "0.0.17")
|
||||||
bazel_dep(name = "rules_pkg", version = "1.0.1")
|
bazel_dep(name = "rules_pkg", version = "1.0.1")
|
||||||
bazel_dep(name = "rules_proto", version = "7.1.0")
|
bazel_dep(name = "rules_proto", version = "7.1.0")
|
||||||
@ -114,3 +117,33 @@ apt.install(
|
|||||||
manifest = "//runtimes/neuron:packages.yaml",
|
manifest = "//runtimes/neuron:packages.yaml",
|
||||||
)
|
)
|
||||||
use_repo(apt, "neuron_bookworm")
|
use_repo(apt, "neuron_bookworm")
|
||||||
|
|
||||||
|
non_module_deps = use_extension("//:third_party/non_module_deps.bzl", "non_module_deps")
|
||||||
|
use_repo(non_module_deps, "com_google_sentencepiece", "org_swig_swig")
|
||||||
|
|
||||||
|
bazel_dep(name = "rules_rust", version = "0.57.0")
|
||||||
|
rust = use_extension("@rules_rust//rust:extensions.bzl", "rust")
|
||||||
|
rust.toolchain(
|
||||||
|
edition = "2021",
|
||||||
|
versions = ["1.84.0"],
|
||||||
|
extra_target_triples = [
|
||||||
|
"aarch64-apple-darwin",
|
||||||
|
"aarch64-unknown-linux-gnu",
|
||||||
|
"x86_64-unknown-linux-gnu",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
use_repo(rust, "rust_toolchains")
|
||||||
|
register_toolchains("@rust_toolchains//:all")
|
||||||
|
|
||||||
|
crate = use_extension("@rules_rust//crate_universe:extensions.bzl", "crate")
|
||||||
|
crate.from_cargo(
|
||||||
|
name = "crates",
|
||||||
|
cargo_lockfile = "//zml/tokenizer/hftokenizers:Cargo.lock",
|
||||||
|
manifests = ["//zml/tokenizer/hftokenizers:Cargo.toml"],
|
||||||
|
supported_platform_triples = [
|
||||||
|
"aarch64-apple-darwin",
|
||||||
|
"aarch64-unknown-linux-gnu",
|
||||||
|
"x86_64-unknown-linux-gnu",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
use_repo(crate, "crates")
|
||||||
|
|||||||
1318
MODULE.bazel.lock
1318
MODULE.bazel.lock
File diff suppressed because one or more lines are too long
@ -485,7 +485,7 @@ pub fn Channel(comptime T: type, capacity: usize) type {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn send(self: *Self, val: T) void {
|
pub fn send(self: *Self, val: T) void {
|
||||||
self.inner.send(val) catch unreachable;
|
self.inner.send(val);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn recv(self: *Self) ?T {
|
pub fn recv(self: *Self) ?T {
|
||||||
|
|||||||
@ -134,8 +134,8 @@ const Coro = struct {
|
|||||||
return initFromStack(func, stack_, storage);
|
return initFromStack(func, stack_, storage);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn deinit(self: Coro) void {
|
pub fn deinit(_: Coro) void {
|
||||||
_ = self; // autofix
|
// empty
|
||||||
}
|
}
|
||||||
|
|
||||||
fn initFromStack(func: *const fn () void, stack_: stack.Stack, storage: ?*anyopaque) !Frame {
|
fn initFromStack(func: *const fn () void, stack_: stack.Stack, storage: ?*anyopaque) !Frame {
|
||||||
@ -423,8 +423,7 @@ const CoroId = struct {
|
|||||||
const StackOverflow = struct {
|
const StackOverflow = struct {
|
||||||
const magic_number: usize = 0x5E574D6D;
|
const magic_number: usize = 0x5E574D6D;
|
||||||
|
|
||||||
fn check(coro: Frame) !void {
|
fn check(_: Frame) !void {
|
||||||
_ = coro; // autofix
|
|
||||||
// const stack = coro.stack.ptr;
|
// const stack = coro.stack.ptr;
|
||||||
// const sp = coro.impl.stack_pointer;
|
// const sp = coro.impl.stack_pointer;
|
||||||
// const magic_number_ptr: *usize = @ptrCast(stack);
|
// const magic_number_ptr: *usize = @ptrCast(stack);
|
||||||
@ -435,8 +434,7 @@ const StackOverflow = struct {
|
|||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
fn setMagicNumber(stack_: stack.Stack) !void {
|
fn setMagicNumber(_: stack.Stack) !void {
|
||||||
_ = stack_; // autofix
|
|
||||||
// if (stack.len <= @sizeOf(usize)) {
|
// if (stack.len <= @sizeOf(usize)) {
|
||||||
// return Error.StackTooSmall;
|
// return Error.StackTooSmall;
|
||||||
// }
|
// }
|
||||||
|
|||||||
@ -69,8 +69,8 @@ pub const StackAllocator = struct {
|
|||||||
return .{ .allocator = allocator };
|
return .{ .allocator = allocator };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn deinit(self: *StackAllocator) void {
|
pub fn deinit(_: *StackAllocator) void {
|
||||||
_ = self; // autofix
|
// empty
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn create(self: *StackAllocator) !Stack {
|
pub fn create(self: *StackAllocator) !Stack {
|
||||||
|
|||||||
149
bazel/swig.bzl
Normal file
149
bazel/swig.bzl
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
load("@rules_cc//cc:action_names.bzl", "C_COMPILE_ACTION_NAME")
|
||||||
|
load("@rules_cc//cc:find_cc_toolchain.bzl", "find_cc_toolchain", "use_cc_toolchain")
|
||||||
|
|
||||||
|
def _swig_cc_library_impl(ctx):
|
||||||
|
args = ctx.actions.args()
|
||||||
|
|
||||||
|
if ctx.attr.cpp:
|
||||||
|
args.add("-c++")
|
||||||
|
|
||||||
|
args.add("-std=c++17")
|
||||||
|
args.add("-c")
|
||||||
|
args.add("-O")
|
||||||
|
args.add("-module", ctx.attr.module)
|
||||||
|
args.add_joined("-features", ctx.attr.enabled_features, join_with = ",")
|
||||||
|
|
||||||
|
if ctx.attr.defines:
|
||||||
|
args.add_all(ctx.attr.defines, format_each = "-D%s")
|
||||||
|
|
||||||
|
cc_toolchain = find_cc_toolchain(ctx)
|
||||||
|
if (cc_toolchain):
|
||||||
|
args.add_all(cc_toolchain.built_in_include_directories, format_each = "-I%s")
|
||||||
|
|
||||||
|
feature_configuration = cc_common.configure_features(
|
||||||
|
ctx = ctx,
|
||||||
|
cc_toolchain = cc_toolchain,
|
||||||
|
requested_features = ctx.features,
|
||||||
|
unsupported_features = ctx.disabled_features,
|
||||||
|
)
|
||||||
|
c_compile_variables = cc_common.create_compile_variables(
|
||||||
|
feature_configuration = feature_configuration,
|
||||||
|
cc_toolchain = cc_toolchain,
|
||||||
|
user_compile_flags = ctx.fragments.cpp.copts + ctx.fragments.cpp.conlyopts,
|
||||||
|
)
|
||||||
|
cc_compile_command_line = cc_common.get_memory_inefficient_command_line(
|
||||||
|
feature_configuration = feature_configuration,
|
||||||
|
action_name = C_COMPILE_ACTION_NAME,
|
||||||
|
variables = c_compile_variables,
|
||||||
|
)
|
||||||
|
for arg in cc_compile_command_line:
|
||||||
|
if (arg.startswith("-I") or arg.startswith("-D")):
|
||||||
|
args.add(arg)
|
||||||
|
|
||||||
|
cc_info = cc_common.merge_cc_infos(direct_cc_infos = [dep[CcInfo] for dep in ctx.attr.deps])
|
||||||
|
args.add_all(cc_info.compilation_context.defines, format_each = "-D%s")
|
||||||
|
args.add_all(cc_info.compilation_context.local_defines, format_each = "-D%s")
|
||||||
|
args.add_all(cc_info.compilation_context.framework_includes, format_each = "-I%s")
|
||||||
|
args.add_all(cc_info.compilation_context.includes, format_each = "-I%s")
|
||||||
|
args.add_all(cc_info.compilation_context.quote_includes, format_each = "-I%s")
|
||||||
|
args.add_all(cc_info.compilation_context.system_includes, format_each = "-I%s")
|
||||||
|
|
||||||
|
output_cpp = ctx.actions.declare_file("%s.cpp" % ctx.attr.module)
|
||||||
|
output_h = ctx.actions.declare_file("%s.h" % ctx.attr.module)
|
||||||
|
args.add("-outdir", output_h.dirname)
|
||||||
|
|
||||||
|
outputs = [
|
||||||
|
output_cpp,
|
||||||
|
output_h,
|
||||||
|
]
|
||||||
|
args.add("-o", output_cpp)
|
||||||
|
args.add("-w-305")
|
||||||
|
args.add(ctx.file.interface)
|
||||||
|
|
||||||
|
inputs = depset(ctx.attr.srcs, transitive = [
|
||||||
|
ctx.attr.interface.files,
|
||||||
|
cc_info.compilation_context.headers,
|
||||||
|
ctx.attr._swig_lib.files,
|
||||||
|
])
|
||||||
|
|
||||||
|
ctx.actions.run(
|
||||||
|
inputs = inputs,
|
||||||
|
outputs = outputs,
|
||||||
|
executable = ctx.executable._swig,
|
||||||
|
arguments = [args],
|
||||||
|
env = {
|
||||||
|
"SWIG_LIB": ctx.files._swig_lib[0].dirname,
|
||||||
|
},
|
||||||
|
mnemonic = "SwigC",
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
DefaultInfo(
|
||||||
|
files = depset(outputs),
|
||||||
|
),
|
||||||
|
OutputGroupInfo(
|
||||||
|
hdrs = depset([output_h]),
|
||||||
|
srcs = depset([output_cpp]),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
_swig_cc_library = rule(
|
||||||
|
_swig_cc_library_impl,
|
||||||
|
attrs = {
|
||||||
|
"interface": attr.label(
|
||||||
|
mandatory = True,
|
||||||
|
allow_single_file = True,
|
||||||
|
),
|
||||||
|
"srcs": attr.label_list(
|
||||||
|
allow_files = True,
|
||||||
|
),
|
||||||
|
"deps": attr.label_list(
|
||||||
|
providers = [CcInfo],
|
||||||
|
),
|
||||||
|
"defines": attr.string_list(),
|
||||||
|
"enabled_features": attr.string_list(),
|
||||||
|
"module": attr.string(
|
||||||
|
mandatory = True,
|
||||||
|
),
|
||||||
|
"cpp": attr.bool(
|
||||||
|
default = True,
|
||||||
|
),
|
||||||
|
"intgosize": attr.int(
|
||||||
|
default = 64,
|
||||||
|
),
|
||||||
|
"_swig": attr.label(
|
||||||
|
default = "@org_swig_swig//:swig",
|
||||||
|
cfg = "exec",
|
||||||
|
executable = True,
|
||||||
|
),
|
||||||
|
"_swig_lib": attr.label(
|
||||||
|
default = "@org_swig_swig//:lib",
|
||||||
|
allow_files = True,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
toolchains = use_cc_toolchain(),
|
||||||
|
fragments = ["cpp"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def swig_cc_library(name, deps = [], **kwargs):
|
||||||
|
_swig_cc_library(
|
||||||
|
name = "{}.swig".format(name),
|
||||||
|
deps = deps,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
native.filegroup(
|
||||||
|
name = "{}.hdrs".format(name),
|
||||||
|
srcs = [":{}.swig".format(name)],
|
||||||
|
output_group = "hdrs",
|
||||||
|
)
|
||||||
|
native.filegroup(
|
||||||
|
name = "{}.srcs".format(name),
|
||||||
|
srcs = [":{}.swig".format(name)],
|
||||||
|
output_group = "srcs",
|
||||||
|
)
|
||||||
|
native.cc_library(
|
||||||
|
name = name,
|
||||||
|
hdrs = [":{}.hdrs".format(name)],
|
||||||
|
srcs = [":{}.srcs".format(name)],
|
||||||
|
deps = deps,
|
||||||
|
)
|
||||||
22
ffi/BUILD.bazel
Normal file
22
ffi/BUILD.bazel
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "cc",
|
||||||
|
hdrs = [
|
||||||
|
"zig_allocator.h",
|
||||||
|
"zig_slice.h",
|
||||||
|
],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
|
zig_library(
|
||||||
|
name = "zig",
|
||||||
|
srcs = [
|
||||||
|
"zig_allocator.zig",
|
||||||
|
"zig_slice.zig",
|
||||||
|
],
|
||||||
|
import_name = "ffi",
|
||||||
|
main = "ffi.zig",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [":cc"],
|
||||||
|
)
|
||||||
12
ffi/ffi.zig
Normal file
12
ffi/ffi.zig
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
const c = @import("c");
|
||||||
|
|
||||||
|
pub const ZigAllocator = @import("zig_allocator.zig").ZigAllocator;
|
||||||
|
pub const ZigSlice = @import("zig_slice.zig").ZigSlice;
|
||||||
|
|
||||||
|
pub fn as_path(path: []const u8) [std.fs.max_path_bytes:0]u8 {
|
||||||
|
var result: [std.fs.max_path_bytes:0]u8 = undefined;
|
||||||
|
@memcpy(result[0..path.len], path);
|
||||||
|
result[path.len] = 0;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
22
ffi/zig_allocator.h
Normal file
22
ffi/zig_allocator.h
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
typedef struct
|
||||||
|
{
|
||||||
|
const void *ctx;
|
||||||
|
void *(*alloc)(const void *ctx, size_t elem, size_t nelems, size_t alignment);
|
||||||
|
void (*free)(const void *ctx, void *ptr, size_t elem, size_t nelems, size_t alignment);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
template <typename T> [[nodiscard]] T *allocate(size_t n)
|
||||||
|
{
|
||||||
|
return static_cast<T *>(this->alloc(this->ctx, sizeof(T), n, _Alignof(T)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T> [[nodiscard]] void deallocate(T *p, size_t n)
|
||||||
|
{
|
||||||
|
this->free(this->ctx, static_cast<void *>(p), sizeof(T), n, _Alignof(T));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
} zig_allocator;
|
||||||
25
ffi/zig_allocator.zig
Normal file
25
ffi/zig_allocator.zig
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
const c = @import("c");
|
||||||
|
|
||||||
|
pub const ZigAllocator = struct {
|
||||||
|
pub inline fn from(allocator: std.mem.Allocator) c.zig_allocator {
|
||||||
|
return .{
|
||||||
|
.ctx = @ptrCast(@alignCast(&allocator)),
|
||||||
|
.alloc = &alloc,
|
||||||
|
.free = &free,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn alloc(ctx: ?*const anyopaque, elem: usize, nelems: usize, alignment: usize) callconv(.C) ?*anyopaque {
|
||||||
|
const self: *const std.mem.Allocator = @ptrCast(@alignCast(ctx));
|
||||||
|
const ret = self.rawAlloc(elem * nelems, std.math.log2_int(usize, alignment), @returnAddress()) orelse return null;
|
||||||
|
return @ptrCast(ret);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn free(ctx: ?*const anyopaque, ptr: ?*anyopaque, elem: usize, nelems: usize, alignment: usize) callconv(.C) void {
|
||||||
|
const self: *const std.mem.Allocator = @ptrCast(@alignCast(ctx));
|
||||||
|
const memory: [*c]u8 = @ptrCast(ptr);
|
||||||
|
const size = elem * nelems;
|
||||||
|
self.rawFree(memory[0..size], std.math.log2_int(usize, alignment), @returnAddress());
|
||||||
|
}
|
||||||
|
};
|
||||||
9
ffi/zig_slice.h
Normal file
9
ffi/zig_slice.h
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
typedef struct
|
||||||
|
{
|
||||||
|
void *ptr;
|
||||||
|
size_t len;
|
||||||
|
} zig_slice;
|
||||||
15
ffi/zig_slice.zig
Normal file
15
ffi/zig_slice.zig
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
const c = @import("c");
|
||||||
|
|
||||||
|
pub const ZigSlice = struct {
|
||||||
|
pub fn from(slice: anytype) c.zig_slice {
|
||||||
|
return .{
|
||||||
|
.ptr = @ptrCast(@constCast(slice.ptr)),
|
||||||
|
.len = slice.len,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to(comptime T: type, slice: c.zig_slice) []T {
|
||||||
|
return @as([*c]T, @ptrCast(@alignCast(slice.ptr)))[0..slice.len];
|
||||||
|
}
|
||||||
|
};
|
||||||
@ -5,10 +5,12 @@ zig_library(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"debug.zig",
|
"debug.zig",
|
||||||
"io.zig",
|
"io.zig",
|
||||||
|
"json.zig",
|
||||||
"math.zig",
|
"math.zig",
|
||||||
"meta.zig",
|
"meta.zig",
|
||||||
"queue.zig",
|
"queue.zig",
|
||||||
"signature.zig",
|
"signature.zig",
|
||||||
|
"time.zig",
|
||||||
],
|
],
|
||||||
main = "stdx.zig",
|
main = "stdx.zig",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
|
|||||||
72
stdx/json.zig
Normal file
72
stdx/json.zig
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
pub const std = @import("std");
|
||||||
|
|
||||||
|
pub fn Union(comptime T: type) type {
|
||||||
|
return struct {
|
||||||
|
const Self = @This();
|
||||||
|
|
||||||
|
value: T,
|
||||||
|
|
||||||
|
pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) !Self {
|
||||||
|
return jsonParseFromValue(
|
||||||
|
allocator,
|
||||||
|
try std.json.innerParse(
|
||||||
|
std.json.Value,
|
||||||
|
allocator,
|
||||||
|
source,
|
||||||
|
options,
|
||||||
|
),
|
||||||
|
options,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn jsonParseFromValue(allocator: std.mem.Allocator, source: std.json.Value, options: std.json.ParseOptions) !Self {
|
||||||
|
inline for (std.meta.fields(T)) |field| {
|
||||||
|
switch (field.type) {
|
||||||
|
bool => if (source == .bool) return .{ .value = @unionInit(T, field.name, source.bool) },
|
||||||
|
[]const u8 => switch (source) {
|
||||||
|
.string => |v| return .{ .value = @unionInit(T, field.name, v) },
|
||||||
|
.number_string => |v| return .{ .value = @unionInit(T, field.name, v) },
|
||||||
|
else => {},
|
||||||
|
},
|
||||||
|
else => switch (@typeInfo(field.type)) {
|
||||||
|
.Int => if (source == .integer) return .{ .value = @unionInit(T, field.name, @intCast(source.integer)) },
|
||||||
|
.Float => if (source == .float) return .{ .value = @unionInit(T, field.name, @floatCast(source.float)) },
|
||||||
|
.Struct => if (source == .object) return .{ .value = @unionInit(T, field.name, try std.json.innerParseFromValue(field.type, allocator, source.object, options)) },
|
||||||
|
inline else => switch (source) {
|
||||||
|
.number_string, .array => return .{ .value = @unionInit(T, field.name, try std.json.innerParseFromValue(field.type, allocator, source, options)) },
|
||||||
|
else => {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return error.UnexpectedToken;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn NeverNull(comptime T: type, comptime default_value: T) type {
|
||||||
|
return struct {
|
||||||
|
const Self = @This();
|
||||||
|
|
||||||
|
value: T = default_value,
|
||||||
|
|
||||||
|
pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) !Self {
|
||||||
|
return .{ .value = (try std.json.innerParse(?T, allocator, source, options)) orelse default_value };
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn jsonParseFromValue(allocator: std.mem.Allocator, source: std.json.Value, options: std.json.ParseOptions) !Self {
|
||||||
|
return .{ .value = (try std.json.innerParseFromValue(?T, allocator, source, options)) orelse default_value };
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fillDefaultStructValues(comptime T: type, r: *T) !void {
|
||||||
|
inline for (@typeInfo(T).Struct.fields) |field| {
|
||||||
|
if (field.default_value) |default_ptr| {
|
||||||
|
if (@field(r, field.name) == null) {
|
||||||
|
const default = @as(*align(1) const field.type, @ptrCast(default_ptr)).*;
|
||||||
|
@field(r, field.name) = default;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,5 +1,7 @@
|
|||||||
pub const debug = @import("debug.zig");
|
pub const debug = @import("debug.zig");
|
||||||
pub const io = @import("io.zig");
|
pub const io = @import("io.zig");
|
||||||
|
pub const json = @import("json.zig");
|
||||||
pub const math = @import("math.zig");
|
pub const math = @import("math.zig");
|
||||||
pub const meta = @import("meta.zig");
|
pub const meta = @import("meta.zig");
|
||||||
pub const queue = @import("queue.zig");
|
pub const queue = @import("queue.zig");
|
||||||
|
pub const time = @import("time.zig");
|
||||||
|
|||||||
34
stdx/time.zig
Normal file
34
stdx/time.zig
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
|
||||||
|
pub const Duration = struct {
|
||||||
|
ns: u64,
|
||||||
|
|
||||||
|
pub fn format(
|
||||||
|
self: Duration,
|
||||||
|
comptime fmt: []const u8,
|
||||||
|
options: std.fmt.FormatOptions,
|
||||||
|
writer: anytype,
|
||||||
|
) @TypeOf(writer).Error!void {
|
||||||
|
return try std.fmt.fmtDuration(self.ns).format(fmt, options, writer);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const Timer = struct {
|
||||||
|
inner: std.time.Timer,
|
||||||
|
|
||||||
|
pub fn start() !Timer {
|
||||||
|
return .{ .inner = try std.time.Timer.start() };
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn lap(self: *Timer) Duration {
|
||||||
|
return .{ .ns = try self.inner.lap() };
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read(self: *Timer) Duration {
|
||||||
|
return .{ .ns = self.inner.read() };
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reset(self: *Timer) void {
|
||||||
|
self.inner.reset();
|
||||||
|
}
|
||||||
|
};
|
||||||
38
third_party/com_google_sentencepiece/BUILD.bazel
vendored
Normal file
38
third_party/com_google_sentencepiece/BUILD.bazel
vendored
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
load(":fwd.bzl", "include_fwd")
|
||||||
|
|
||||||
|
include_fwd(
|
||||||
|
name = "absl_hdrs",
|
||||||
|
includes = [
|
||||||
|
"absl/container/flat_hash_map.h",
|
||||||
|
"absl/container/flat_hash_set.h",
|
||||||
|
"absl/flags/flag.h",
|
||||||
|
"absl/flags/parse.h",
|
||||||
|
"absl/flags/usage.h",
|
||||||
|
"absl/strings/match.h",
|
||||||
|
"absl/strings/numbers.h",
|
||||||
|
"absl/strings/str_cat.h",
|
||||||
|
"absl/strings/str_format.h",
|
||||||
|
"absl/strings/str_join.h",
|
||||||
|
"absl/strings/str_replace.h",
|
||||||
|
"absl/strings/str_split.h",
|
||||||
|
"absl/strings/string_view.h",
|
||||||
|
"absl/strings/strip.h",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "absl",
|
||||||
|
hdrs = [":absl_hdrs"],
|
||||||
|
include_prefix = "third_party",
|
||||||
|
includes = ["."],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"@abseil-cpp//absl/container:flat_hash_map",
|
||||||
|
"@abseil-cpp//absl/container:flat_hash_set",
|
||||||
|
"@abseil-cpp//absl/flags:flag",
|
||||||
|
"@abseil-cpp//absl/flags:parse",
|
||||||
|
"@abseil-cpp//absl/flags:usage",
|
||||||
|
"@abseil-cpp//absl/strings",
|
||||||
|
"@abseil-cpp//absl/strings:string_view",
|
||||||
|
],
|
||||||
|
)
|
||||||
14
third_party/com_google_sentencepiece/fwd.bzl
vendored
Normal file
14
third_party/com_google_sentencepiece/fwd.bzl
vendored
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
def _include_fwd_impl(ctx):
|
||||||
|
files = []
|
||||||
|
for include in ctx.attr.includes:
|
||||||
|
f = ctx.actions.declare_file(include)
|
||||||
|
ctx.actions.write(f, '#include "{}"'.format(include))
|
||||||
|
files.append(f)
|
||||||
|
return [DefaultInfo(files = depset(files))]
|
||||||
|
|
||||||
|
include_fwd = rule(
|
||||||
|
implementation = _include_fwd_impl,
|
||||||
|
attrs = {
|
||||||
|
"includes": attr.string_list(),
|
||||||
|
},
|
||||||
|
)
|
||||||
9
third_party/com_google_sentencepiece/repo.bzl
vendored
Normal file
9
third_party/com_google_sentencepiece/repo.bzl
vendored
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository")
|
||||||
|
|
||||||
|
def repo():
|
||||||
|
new_git_repository(
|
||||||
|
name = "com_google_sentencepiece",
|
||||||
|
remote = "https://github.com/google/sentencepiece.git",
|
||||||
|
commit = "d8f741853847553169444afc12c00f4bbff3e9ce",
|
||||||
|
build_file = "//third_party/com_google_sentencepiece:sentencepiece.bazel",
|
||||||
|
)
|
||||||
95
third_party/com_google_sentencepiece/sentencepiece.bazel
vendored
Normal file
95
third_party/com_google_sentencepiece/sentencepiece.bazel
vendored
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
load("@bazel_skylib//rules:copy_file.bzl", "copy_file")
|
||||||
|
load("@protobuf//bazel:cc_proto_library.bzl", "cc_proto_library")
|
||||||
|
load("@protobuf//bazel:proto_library.bzl", "proto_library")
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//visibility:public"],
|
||||||
|
features = [
|
||||||
|
"layering_check",
|
||||||
|
"parse_headers",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2, BSD, MIT
|
||||||
|
|
||||||
|
proto_library(
|
||||||
|
name = "sentencepiece_proto",
|
||||||
|
srcs = ["src/sentencepiece.proto"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_proto_library(
|
||||||
|
name = "sentencepiece_cc_proto",
|
||||||
|
deps = [":sentencepiece_proto"],
|
||||||
|
)
|
||||||
|
|
||||||
|
proto_library(
|
||||||
|
name = "sentencepiece_model_proto",
|
||||||
|
srcs = ["src/sentencepiece_model.proto"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_proto_library(
|
||||||
|
name = "sentencepiece_model_cc_proto",
|
||||||
|
deps = [":sentencepiece_model_proto"],
|
||||||
|
)
|
||||||
|
|
||||||
|
copy_file(
|
||||||
|
name = "config_h",
|
||||||
|
src = "config.h.in",
|
||||||
|
out = "config.h",
|
||||||
|
allow_symlink = True,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "darts_clone",
|
||||||
|
hdrs = glob([
|
||||||
|
"third_party/darts_clone/*.h",
|
||||||
|
]),
|
||||||
|
includes = ["."],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "sentencepiece_processor",
|
||||||
|
srcs = [
|
||||||
|
"src/bpe_model.cc",
|
||||||
|
"src/char_model.cc",
|
||||||
|
"src/error.cc",
|
||||||
|
"src/filesystem.cc",
|
||||||
|
"src/model_factory.cc",
|
||||||
|
"src/model_interface.cc",
|
||||||
|
"src/normalizer.cc",
|
||||||
|
"src/sentencepiece_processor.cc",
|
||||||
|
"src/unigram_model.cc",
|
||||||
|
"src/util.cc",
|
||||||
|
"src/word_model.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
":config_h",
|
||||||
|
"src/common.h",
|
||||||
|
"src/bpe_model.h",
|
||||||
|
"src/char_model.h",
|
||||||
|
"src/filesystem.h",
|
||||||
|
"src/freelist.h",
|
||||||
|
"src/init.h",
|
||||||
|
"src/model_factory.h",
|
||||||
|
"src/model_interface.h",
|
||||||
|
"src/normalizer.h",
|
||||||
|
"src/sentencepiece_processor.h",
|
||||||
|
"src/trainer_interface.h",
|
||||||
|
"src/unigram_model.h",
|
||||||
|
"src/util.h",
|
||||||
|
"src/word_model.h",
|
||||||
|
],
|
||||||
|
defines = [
|
||||||
|
"_USE_EXTERNAL_PROTOBUF",
|
||||||
|
"_USE_EXTERNAL_ABSL",
|
||||||
|
],
|
||||||
|
includes = [
|
||||||
|
"src",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":darts_clone",
|
||||||
|
":sentencepiece_cc_proto",
|
||||||
|
":sentencepiece_model_cc_proto",
|
||||||
|
"@zml//third_party/com_google_sentencepiece:absl",
|
||||||
|
],
|
||||||
|
)
|
||||||
16
third_party/non_module_deps.bzl
vendored
Normal file
16
third_party/non_module_deps.bzl
vendored
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
load("//:third_party/org_swig_swig/repo.bzl", org_swig_swig = "repo")
|
||||||
|
load("//third_party/com_google_sentencepiece:repo.bzl", com_google_sentencepiece = "repo")
|
||||||
|
|
||||||
|
def _non_module_deps_impl(mctx):
|
||||||
|
com_google_sentencepiece()
|
||||||
|
org_swig_swig()
|
||||||
|
|
||||||
|
return mctx.extension_metadata(
|
||||||
|
reproducible = True,
|
||||||
|
root_module_direct_deps = "all",
|
||||||
|
root_module_direct_dev_deps = [],
|
||||||
|
)
|
||||||
|
|
||||||
|
non_module_deps = module_extension(
|
||||||
|
implementation = _non_module_deps_impl,
|
||||||
|
)
|
||||||
10
third_party/org_swig_swig/repo.bzl
vendored
Normal file
10
third_party/org_swig_swig/repo.bzl
vendored
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||||
|
|
||||||
|
def repo():
|
||||||
|
http_archive(
|
||||||
|
name = "org_swig_swig",
|
||||||
|
url = "http://prdownloads.sourceforge.net/swig/swig-4.3.0.tar.gz",
|
||||||
|
sha256 = "f7203ef796f61af986c70c05816236cbd0d31b7aa9631e5ab53020ab7804aa9e",
|
||||||
|
strip_prefix = "swig-4.3.0",
|
||||||
|
build_file = "//:third_party/org_swig_swig/swig.bazel",
|
||||||
|
)
|
||||||
109
third_party/org_swig_swig/swig.bazel
vendored
Normal file
109
third_party/org_swig_swig/swig.bazel
vendored
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
licenses(["restricted"]) # GPLv3
|
||||||
|
|
||||||
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "lib",
|
||||||
|
srcs = glob([
|
||||||
|
"Lib/*.*",
|
||||||
|
"Lib/c/*.*",
|
||||||
|
"Lib/std/*.*",
|
||||||
|
"Lib/typemaps/*.*",
|
||||||
|
]),
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "swig",
|
||||||
|
srcs = glob([
|
||||||
|
"Source/CParse/*.h",
|
||||||
|
"Source/CParse/*.c",
|
||||||
|
"Source/DOH/*.h",
|
||||||
|
"Source/DOH/*.c",
|
||||||
|
"Source/Include/*.h",
|
||||||
|
"Source/Preprocessor/*.h",
|
||||||
|
"Source/Preprocessor/*.c",
|
||||||
|
"Source/Swig/*.h",
|
||||||
|
"Source/Swig/*.c",
|
||||||
|
]) + [
|
||||||
|
"Source/Include/swigconfig.h",
|
||||||
|
"Source/Modules/allocate.cxx",
|
||||||
|
"Source/Modules/c.cxx",
|
||||||
|
"Source/Modules/contract.cxx",
|
||||||
|
"Source/Modules/directors.cxx",
|
||||||
|
"Source/Modules/emit.cxx",
|
||||||
|
"Source/Modules/interface.cxx",
|
||||||
|
"Source/Modules/lang.cxx",
|
||||||
|
"Source/Modules/main.cxx",
|
||||||
|
"Source/Modules/nested.cxx",
|
||||||
|
"Source/Modules/overload.cxx",
|
||||||
|
"Source/Modules/swigmain-lite.cxx",
|
||||||
|
"Source/Modules/swigmod.h",
|
||||||
|
"Source/Modules/typepass.cxx",
|
||||||
|
"Source/Modules/utils.cxx",
|
||||||
|
"Source/Modules/xml.cxx",
|
||||||
|
],
|
||||||
|
includes = [
|
||||||
|
"Source/CParse",
|
||||||
|
"Source/DOH",
|
||||||
|
"Source/Include",
|
||||||
|
"Source/Modules",
|
||||||
|
"Source/Preprocessor",
|
||||||
|
"Source/Swig",
|
||||||
|
],
|
||||||
|
data = [":lib"],
|
||||||
|
output_licenses = ["unencumbered"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = ["@pcre2"],
|
||||||
|
)
|
||||||
|
|
||||||
|
genrule(
|
||||||
|
name = "swigconfig",
|
||||||
|
outs = ["Source/Include/swigconfig.h"],
|
||||||
|
cmd = """\
|
||||||
|
cat <<EOF >$@
|
||||||
|
#define HAVE_BOOL
|
||||||
|
#define HAVE_PCRE
|
||||||
|
#define HAVE_POPEN
|
||||||
|
#define PACKAGE_BUGREPORT \"http://www.swig.org\"
|
||||||
|
#define PACKAGE_VERSION \"4.3.0\"
|
||||||
|
#define STDC_HEADERS
|
||||||
|
#define SWIG_CXX \"bazel4lyfe\"
|
||||||
|
#define SWIG_LIB \"external/org_swig_swig/Lib\"
|
||||||
|
#define SWIG_LIB_WIN_UNIX \"\"
|
||||||
|
#define SWIG_PLATFORM \"bazel4lyfe\"
|
||||||
|
EOF
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
genrule(
|
||||||
|
name = "get_rid_of_stuff_we_dont_need_yet",
|
||||||
|
srcs = ["Source/Modules/swigmain.cxx"],
|
||||||
|
outs = ["Source/Modules/swigmain-lite.cxx"],
|
||||||
|
cmd = """\
|
||||||
|
sed -e '/swig_allegrocl/d' \
|
||||||
|
-e '/swig_chicken/d' \
|
||||||
|
-e '/swig_clisp/d' \
|
||||||
|
-e '/swig_csharp/d' \
|
||||||
|
-e '/swig_d/d' \
|
||||||
|
-e '/swig_guile/d' \
|
||||||
|
-e '/swig_go/d' \
|
||||||
|
-e '/swig_java/d' \
|
||||||
|
-e '/swig_lua/d' \
|
||||||
|
-e '/swig_modula3/d' \
|
||||||
|
-e '/swig_mzscheme/d' \
|
||||||
|
-e '/swig_ocaml/d' \
|
||||||
|
-e '/swig_octave/d' \
|
||||||
|
-e '/swig_perl/d' \
|
||||||
|
-e '/swig_php/d' \
|
||||||
|
-e '/swig_pike/d' \
|
||||||
|
-e '/swig_python/d' \
|
||||||
|
-e '/swig_r/d' \
|
||||||
|
-e '/swig_ruby/d' \
|
||||||
|
-e '/swig_scilab/d' \
|
||||||
|
-e '/swig_sexp/d' \
|
||||||
|
-e '/swig_tcl/d' \
|
||||||
|
-e '/swig_uffi/d' \
|
||||||
|
$< >$@
|
||||||
|
""",
|
||||||
|
)
|
||||||
@ -32,6 +32,7 @@ zig_library(
|
|||||||
"//pjrt",
|
"//pjrt",
|
||||||
"//runtimes",
|
"//runtimes",
|
||||||
"//stdx",
|
"//stdx",
|
||||||
|
"//zml/tokenizer",
|
||||||
"//zml/tools",
|
"//zml/tools",
|
||||||
"@rules_zig//zig/lib:libc",
|
"@rules_zig//zig/lib:libc",
|
||||||
"@rules_zig//zig/runfiles",
|
"@rules_zig//zig/runfiles",
|
||||||
|
|||||||
23
zml/aio.zig
23
zml/aio.zig
@ -10,8 +10,6 @@ const posix = @import("posix.zig");
|
|||||||
pub const gguf = @import("aio/gguf.zig");
|
pub const gguf = @import("aio/gguf.zig");
|
||||||
pub const nemo = @import("aio/nemo.zig");
|
pub const nemo = @import("aio/nemo.zig");
|
||||||
pub const safetensors = @import("aio/safetensors.zig");
|
pub const safetensors = @import("aio/safetensors.zig");
|
||||||
pub const sentencepiece = @import("aio/sentencepiece.zig");
|
|
||||||
pub const tinyllama = @import("aio/tinyllama.zig");
|
|
||||||
pub const torch = @import("aio/torch.zig");
|
pub const torch = @import("aio/torch.zig");
|
||||||
pub const yaml = @import("aio/yaml.zig");
|
pub const yaml = @import("aio/yaml.zig");
|
||||||
|
|
||||||
@ -23,8 +21,6 @@ test {
|
|||||||
std.testing.refAllDecls(gguf);
|
std.testing.refAllDecls(gguf);
|
||||||
std.testing.refAllDecls(nemo);
|
std.testing.refAllDecls(nemo);
|
||||||
std.testing.refAllDecls(safetensors);
|
std.testing.refAllDecls(safetensors);
|
||||||
std.testing.refAllDecls(sentencepiece);
|
|
||||||
std.testing.refAllDecls(tinyllama);
|
|
||||||
std.testing.refAllDecls(torch);
|
std.testing.refAllDecls(torch);
|
||||||
std.testing.refAllDecls(yaml);
|
std.testing.refAllDecls(yaml);
|
||||||
}
|
}
|
||||||
@ -39,29 +35,11 @@ pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8)
|
|||||||
try gguf.open(allocator, model_path)
|
try gguf.open(allocator, model_path)
|
||||||
else if (std.mem.endsWith(u8, model_path, ".pt"))
|
else if (std.mem.endsWith(u8, model_path, ".pt"))
|
||||||
try torch.open(allocator, model_path)
|
try torch.open(allocator, model_path)
|
||||||
else if (std.mem.endsWith(u8, model_path, ".tinyllama"))
|
|
||||||
try tinyllama.open(allocator, model_path)
|
|
||||||
else {
|
else {
|
||||||
std.debug.panic("File extension not recognized: {s}", .{model_path});
|
std.debug.panic("File extension not recognized: {s}", .{model_path});
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn detectFormatAndLoadTokenizer(allocator: std.mem.Allocator, tokenizer_path: []const u8) !zml.tokenizer.Tokenizer {
|
|
||||||
return if (std.mem.endsWith(u8, tokenizer_path, ".json"))
|
|
||||||
try zml.tokenizer.fromHfJson(allocator, tokenizer_path)
|
|
||||||
else if (std.mem.endsWith(u8, tokenizer_path, ".gguf")) {
|
|
||||||
const store = try gguf.open(allocator, tokenizer_path);
|
|
||||||
return gguf.getGgufTokenizer(store, allocator);
|
|
||||||
} else if (std.mem.endsWith(u8, tokenizer_path, ".pb") or std.mem.endsWith(u8, tokenizer_path, ".model"))
|
|
||||||
try sentencepiece.loadTokenizerFromPath(allocator, tokenizer_path)
|
|
||||||
else if (std.mem.endsWith(u8, tokenizer_path, ".tinyllama"))
|
|
||||||
try zml.aio.tinyllama.loadTokenizer(allocator, tokenizer_path, 32000)
|
|
||||||
else {
|
|
||||||
log.err("Failed to recognized tokenizer format of: {s}", .{tokenizer_path});
|
|
||||||
return error.FormatNotRecognized;
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Creates a Model struct with tensor shapes read from the given BufferStore.
|
/// Creates a Model struct with tensor shapes read from the given BufferStore.
|
||||||
/// The result can be used to pass to `compileModel`.
|
/// The result can be used to pass to `compileModel`.
|
||||||
///
|
///
|
||||||
@ -445,6 +423,7 @@ fn _populateStruct(
|
|||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
.Void => true,
|
.Void => true,
|
||||||
|
.Union => true,
|
||||||
else => if (required) {
|
else => if (required) {
|
||||||
log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) });
|
log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) });
|
||||||
return error.UnsupportedMetadataType;
|
return error.UnsupportedMetadataType;
|
||||||
|
|||||||
@ -31,76 +31,6 @@ pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getGgufTokenizer(self: zml.aio.BufferStore, allocator: std.mem.Allocator) !zml.tokenizer.Tokenizer {
|
|
||||||
const tokens = self.metadataSlice("tokenizer.ggml.tokens", .string) orelse {
|
|
||||||
log.err("GGUF File: Tokens not found", .{});
|
|
||||||
return error.TokensNotFound;
|
|
||||||
};
|
|
||||||
const scores = self.metadataSlice("tokenizer.ggml.scores", .float) orelse {
|
|
||||||
log.err("GGUF File: Scores not found", .{});
|
|
||||||
return error.ScoresNotFound;
|
|
||||||
};
|
|
||||||
assert(tokens.len == scores.len);
|
|
||||||
const tokenizer_type = self.metadata("tokenizer.ggml.model", .string) orelse "llama";
|
|
||||||
const tokenizer_impl: zml.tokenizer.KnownImplementation = if (std.mem.eql(u8, tokenizer_type, "gpt2")) .gpt2 else .sentencepiece;
|
|
||||||
const bos = self.metadata("tokenizer.ggml.bos_token_id", .int);
|
|
||||||
const eos = self.metadata("tokenizer.ggml.eos_token_id", .int);
|
|
||||||
const unk = self.metadata("tokenizer.ggml.unknown_token_id", .int);
|
|
||||||
const pad = self.metadata("tokenizer.ggml.padding_token_id", .int);
|
|
||||||
|
|
||||||
const NOT_FOUND = std.math.maxInt(u32);
|
|
||||||
const special_tokens: zml.tokenizer.Tokenizer.SpecialTokens = .{
|
|
||||||
.bos = @intCast(bos.?),
|
|
||||||
.eos = @intCast(eos.?),
|
|
||||||
.unk = @intCast(unk orelse NOT_FOUND),
|
|
||||||
.pad = @intCast(pad orelse NOT_FOUND),
|
|
||||||
};
|
|
||||||
|
|
||||||
const gguf_normalizer = if (tokenizer_impl == .gpt2)
|
|
||||||
zml.tokenizer.Normalizer.wellKnown(.gpt2)
|
|
||||||
else
|
|
||||||
zml.tokenizer.Normalizer.wellKnown(.sentencepiece);
|
|
||||||
|
|
||||||
const extra_tokens: u8 = if (tokenizer_impl == .gpt2) 1 else 0;
|
|
||||||
const n_tokens: u32 = @intCast(tokens.len + extra_tokens);
|
|
||||||
|
|
||||||
var tokenizer = try zml.tokenizer.Tokenizer.init(
|
|
||||||
allocator,
|
|
||||||
n_tokens,
|
|
||||||
32,
|
|
||||||
gguf_normalizer,
|
|
||||||
special_tokens,
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
|
|
||||||
var gpt2_unicode = if (tokenizer_impl == .gpt2)
|
|
||||||
try zml.tokenizer.Gpt2TextDecoder.init(allocator)
|
|
||||||
else
|
|
||||||
null;
|
|
||||||
defer if (gpt2_unicode) |*gpt2| gpt2.deinit();
|
|
||||||
var decoded = std.ArrayList(u8).init(allocator);
|
|
||||||
defer decoded.deinit();
|
|
||||||
|
|
||||||
// copy the tokens to the tokenizer arena.
|
|
||||||
for (tokens, 0..tokens.len) |t, i| {
|
|
||||||
if (tokenizer_impl == .gpt2) {
|
|
||||||
decoded.clearRetainingCapacity();
|
|
||||||
try tokenizer.addToken(@floatCast(scores[i]), try gpt2_unicode.?.decode(&decoded, t));
|
|
||||||
// log.debug("token: {s} -> {s}", .{t, decoded.items});
|
|
||||||
} else {
|
|
||||||
try tokenizer.addToken(@floatCast(scores[i]), t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Gpt2 tokenizer always splits on spaces.
|
|
||||||
if (tokenizer_impl == .gpt2) {
|
|
||||||
tokenizer.special_tokens.hard_space = tokenizer.next_token_id;
|
|
||||||
tokenizer.addOwnedToken(0, " ");
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokenizer;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn loadMetadata(allocator: Allocator, store: *zml.aio.BufferStore, file: *core.GgufFile) !void {
|
fn loadMetadata(allocator: Allocator, store: *zml.aio.BufferStore, file: *core.GgufFile) !void {
|
||||||
try store._metadata.ensureTotalCapacity(allocator, @intCast(file.header.metadata_kv_count));
|
try store._metadata.ensureTotalCapacity(allocator, @intCast(file.header.metadata_kv_count));
|
||||||
|
|
||||||
|
|||||||
@ -4,58 +4,6 @@ const zml = @import("../zml.zig");
|
|||||||
|
|
||||||
const sentencepiece_proto = @import("//sentencepiece:model_proto");
|
const sentencepiece_proto = @import("//sentencepiece:model_proto");
|
||||||
const Normalizer = zml.tokenizer.Normalizer;
|
const Normalizer = zml.tokenizer.Normalizer;
|
||||||
const Tokenizer = zml.tokenizer.Tokenizer;
|
|
||||||
|
|
||||||
pub fn loadTokenizerFromPath(allocator: std.mem.Allocator, path: []const u8) !Tokenizer {
|
|
||||||
const file = try asynk.File.open(path, .{});
|
|
||||||
defer file.close() catch unreachable;
|
|
||||||
|
|
||||||
return loadTokenizerFromFile(allocator, file);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn loadTokenizerFromFile(allocator: std.mem.Allocator, file: asynk.File) !Tokenizer {
|
|
||||||
const reader = file.reader();
|
|
||||||
const input = try reader.readAllAlloc(allocator, 16 * 1024 * 1024);
|
|
||||||
defer allocator.free(input);
|
|
||||||
|
|
||||||
var proto_arena = std.heap.ArenaAllocator.init(allocator);
|
|
||||||
defer proto_arena.deinit();
|
|
||||||
|
|
||||||
const model = try sentencepiece_proto.ModelProto.decode(input, proto_arena.allocator());
|
|
||||||
// no deinit, memory will be freed by the proto_arena
|
|
||||||
|
|
||||||
return loadTokenizerFromModelProto(allocator, model);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn loadTokenizerFromModelProto(allocator: std.mem.Allocator, model: sentencepiece_proto.ModelProto) !Tokenizer {
|
|
||||||
std.debug.assert(model.trainer_spec.?.model_type.? == .BPE);
|
|
||||||
const special_tokens: Tokenizer.SpecialTokens = .{
|
|
||||||
.unk = @intCast(model.trainer_spec.?.unk_id.?),
|
|
||||||
.bos = @intCast(model.trainer_spec.?.bos_id.?),
|
|
||||||
.eos = @intCast(model.trainer_spec.?.eos_id.?),
|
|
||||||
.pad = parseTokenId(model.trainer_spec.?.pad_id),
|
|
||||||
};
|
|
||||||
|
|
||||||
var tokenizer = try Tokenizer.init(
|
|
||||||
allocator,
|
|
||||||
@intCast(model.pieces.items.len),
|
|
||||||
@intCast(model.trainer_spec.?.max_sentencepiece_length.?),
|
|
||||||
normalizerFromSpec(model.normalizer_spec.?),
|
|
||||||
special_tokens,
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
errdefer tokenizer.deinit();
|
|
||||||
|
|
||||||
for (model.pieces.items) |*piece| {
|
|
||||||
try tokenizer.addToken(piece.score.?, piece.piece.?.getSlice());
|
|
||||||
}
|
|
||||||
const byte_fallback = model.trainer_spec.?.byte_fallback orelse false;
|
|
||||||
if (byte_fallback) {
|
|
||||||
try tokenizer.rewriteByteFallbackTokens();
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokenizer;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parseTokenId(id: ?i32) u32 {
|
fn parseTokenId(id: ?i32) u32 {
|
||||||
if (id) |idx| {
|
if (id) |idx| {
|
||||||
|
|||||||
@ -38,7 +38,7 @@ pub const HostBuffer = struct {
|
|||||||
/// The returned HostBuffer doesn't take ownership of the slice
|
/// The returned HostBuffer doesn't take ownership of the slice
|
||||||
/// that will still need to be deallocated.
|
/// that will still need to be deallocated.
|
||||||
pub fn fromBytes(shape_: Shape, data_: []const u8) HostBuffer {
|
pub fn fromBytes(shape_: Shape, data_: []const u8) HostBuffer {
|
||||||
std.debug.assert(shape_.byteSize() == data_.len);
|
stdx.debug.assert(shape_.byteSize() == data_.len, "shape {} and data {} don't match", .{ shape_.byteSize(), data_.len });
|
||||||
return .{
|
return .{
|
||||||
._shape = shape_,
|
._shape = shape_,
|
||||||
.data = data_,
|
.data = data_,
|
||||||
|
|||||||
@ -175,7 +175,7 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
|
|||||||
} else {
|
} else {
|
||||||
to.* = null;
|
to.* = null;
|
||||||
},
|
},
|
||||||
.Int, .Float, .Enum => to.* = from,
|
.Int, .Float, .Enum, .Union => to.* = from,
|
||||||
else => stdx.debug.compileError("zml.meta.mapAlloc doesn't support: {}", .{FromStruct}),
|
else => stdx.debug.compileError("zml.meta.mapAlloc doesn't support: {}", .{FromStruct}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2027,7 +2027,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Appends a 1-dim axis, with the given tag.
|
/// Appends a 1-dim axis, with the given tag.
|
||||||
pub fn appendAxes(self: Tensor, t: anytype) Tensor {
|
pub fn appendAxes(self: Tensor, t: anytype) Tensor {
|
||||||
stdx.debug.assert(self.rank() < Tensor.MAX_RANK - t.len, "appendAxis expects tensor rank to be small enough in order to extend it, got {} and {} (max is {})", .{ self.rank(), t.len, Tensor.MAX_RANK });
|
// stdx.debug.assert(self.rank() < Tensor.MAX_RANK - t.len, "appendAxis expects tensor rank to be small enough in order to extend it, got {} and {} (max is {})", .{ self.rank(), t.len, Tensor.MAX_RANK });
|
||||||
|
|
||||||
return self.insertAxes(.last, t);
|
return self.insertAxes(.last, t);
|
||||||
}
|
}
|
||||||
|
|||||||
35
zml/tokenizer/BUILD.bazel
Normal file
35
zml/tokenizer/BUILD.bazel
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||||
|
load("@zml//bazel:zig.bzl", "zig_cc_binary")
|
||||||
|
load("//bazel:swig.bzl", "swig_cc_library")
|
||||||
|
|
||||||
|
swig_cc_library(
|
||||||
|
name = "sentencepiece_swig",
|
||||||
|
interface = "sentencepiece.i",
|
||||||
|
module = "sentencepiece",
|
||||||
|
deps = [
|
||||||
|
"//ffi:cc",
|
||||||
|
"@com_google_sentencepiece//:sentencepiece_processor",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
zig_library(
|
||||||
|
name = "tokenizer",
|
||||||
|
import_name = "zml/tokenizer",
|
||||||
|
main = "tokenizer.zig",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//async",
|
||||||
|
"//ffi:zig",
|
||||||
|
"//zml/tokenizer/hftokenizers",
|
||||||
|
"//zml/tokenizer/sentencepiece",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
zig_cc_binary(
|
||||||
|
name = "main",
|
||||||
|
main = "main.zig",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":tokenizer",
|
||||||
|
],
|
||||||
|
)
|
||||||
38
zml/tokenizer/hftokenizers/BUILD.bazel
Normal file
38
zml/tokenizer/hftokenizers/BUILD.bazel
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
load("@rules_rust//rust:defs.bzl", "rust_static_library")
|
||||||
|
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||||
|
load("@zml//bazel:zig.bzl", "zig_cc_binary")
|
||||||
|
|
||||||
|
rust_static_library(
|
||||||
|
name = "hftokenizers_rs",
|
||||||
|
srcs = ["hftokenizers.rs"],
|
||||||
|
crate_name = "zml_tokenizer_hftokenizers",
|
||||||
|
edition = "2021",
|
||||||
|
deps = ["@crates//:tokenizers"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "hftokenizers_cc",
|
||||||
|
hdrs = ["hftokenizers.h"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":hftokenizers_rs",
|
||||||
|
"//ffi:cc",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
zig_library(
|
||||||
|
name = "hftokenizers",
|
||||||
|
main = "hftokenizers.zig",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":hftokenizers_cc",
|
||||||
|
"//ffi:zig",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
zig_cc_binary(
|
||||||
|
name = "main",
|
||||||
|
main = "main.zig",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [":hftokenizers"],
|
||||||
|
)
|
||||||
669
zml/tokenizer/hftokenizers/Cargo.lock
generated
Normal file
669
zml/tokenizer/hftokenizers/Cargo.lock
generated
Normal file
@ -0,0 +1,669 @@
|
|||||||
|
# This file is automatically @generated by Cargo.
|
||||||
|
# It is not intended for manual editing.
|
||||||
|
version = 4
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "aho-corasick"
|
||||||
|
version = "1.1.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
|
||||||
|
dependencies = [
|
||||||
|
"memchr",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "base64"
|
||||||
|
version = "0.13.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bit-set"
|
||||||
|
version = "0.5.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
|
||||||
|
dependencies = [
|
||||||
|
"bit-vec",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bit-vec"
|
||||||
|
version = "0.6.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bumpalo"
|
||||||
|
version = "3.16.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "byteorder"
|
||||||
|
version = "1.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cfg-if"
|
||||||
|
version = "1.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-deque"
|
||||||
|
version = "0.8.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-epoch",
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-epoch"
|
||||||
|
version = "0.9.18"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-utils"
|
||||||
|
version = "0.8.21"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "darling"
|
||||||
|
version = "0.20.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989"
|
||||||
|
dependencies = [
|
||||||
|
"darling_core",
|
||||||
|
"darling_macro",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "darling_core"
|
||||||
|
version = "0.20.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5"
|
||||||
|
dependencies = [
|
||||||
|
"fnv",
|
||||||
|
"ident_case",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"strsim",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "darling_macro"
|
||||||
|
version = "0.20.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806"
|
||||||
|
dependencies = [
|
||||||
|
"darling_core",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "derive_builder"
|
||||||
|
version = "0.20.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947"
|
||||||
|
dependencies = [
|
||||||
|
"derive_builder_macro",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "derive_builder_core"
|
||||||
|
version = "0.20.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8"
|
||||||
|
dependencies = [
|
||||||
|
"darling",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "derive_builder_macro"
|
||||||
|
version = "0.20.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
|
||||||
|
dependencies = [
|
||||||
|
"derive_builder_core",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "either"
|
||||||
|
version = "1.13.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "esaxx-rs"
|
||||||
|
version = "0.1.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fancy-regex"
|
||||||
|
version = "0.13.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2"
|
||||||
|
dependencies = [
|
||||||
|
"bit-set",
|
||||||
|
"regex-automata",
|
||||||
|
"regex-syntax",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fnv"
|
||||||
|
version = "1.0.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "getrandom"
|
||||||
|
version = "0.2.15"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"js-sys",
|
||||||
|
"libc",
|
||||||
|
"wasi",
|
||||||
|
"wasm-bindgen",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ident_case"
|
||||||
|
version = "1.0.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "itertools"
|
||||||
|
version = "0.11.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57"
|
||||||
|
dependencies = [
|
||||||
|
"either",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "itertools"
|
||||||
|
version = "0.12.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
|
||||||
|
dependencies = [
|
||||||
|
"either",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "itoa"
|
||||||
|
version = "1.0.14"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "js-sys"
|
||||||
|
version = "0.3.76"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6717b6b5b077764fb5966237269cb3c64edddde4b14ce42647430a78ced9e7b7"
|
||||||
|
dependencies = [
|
||||||
|
"once_cell",
|
||||||
|
"wasm-bindgen",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "lazy_static"
|
||||||
|
version = "1.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "libc"
|
||||||
|
version = "0.2.169"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "log"
|
||||||
|
version = "0.4.22"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "macro_rules_attribute"
|
||||||
|
version = "0.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8a82271f7bc033d84bbca59a3ce3e4159938cb08a9c3aebbe54d215131518a13"
|
||||||
|
dependencies = [
|
||||||
|
"macro_rules_attribute-proc_macro",
|
||||||
|
"paste",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "macro_rules_attribute-proc_macro"
|
||||||
|
version = "0.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b8dd856d451cc0da70e2ef2ce95a18e39a93b7558bedf10201ad28503f918568"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "memchr"
|
||||||
|
version = "2.7.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "minimal-lexical"
|
||||||
|
version = "0.2.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "monostate"
|
||||||
|
version = "0.1.13"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0d208407d7552cd041d8cdb69a1bc3303e029c598738177a3d87082004dc0e1e"
|
||||||
|
dependencies = [
|
||||||
|
"monostate-impl",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "monostate-impl"
|
||||||
|
version = "0.1.13"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "nom"
|
||||||
|
version = "7.1.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
|
||||||
|
dependencies = [
|
||||||
|
"memchr",
|
||||||
|
"minimal-lexical",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "once_cell"
|
||||||
|
version = "1.20.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "paste"
|
||||||
|
version = "1.0.15"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ppv-lite86"
|
||||||
|
version = "0.2.20"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04"
|
||||||
|
dependencies = [
|
||||||
|
"zerocopy",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "proc-macro2"
|
||||||
|
version = "1.0.92"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0"
|
||||||
|
dependencies = [
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "quote"
|
||||||
|
version = "1.0.37"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand"
|
||||||
|
version = "0.8.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
"rand_chacha",
|
||||||
|
"rand_core",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand_chacha"
|
||||||
|
version = "0.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
||||||
|
dependencies = [
|
||||||
|
"ppv-lite86",
|
||||||
|
"rand_core",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand_core"
|
||||||
|
version = "0.6.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
||||||
|
dependencies = [
|
||||||
|
"getrandom",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rayon"
|
||||||
|
version = "1.10.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
|
||||||
|
dependencies = [
|
||||||
|
"either",
|
||||||
|
"rayon-core",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rayon-cond"
|
||||||
|
version = "0.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9"
|
||||||
|
dependencies = [
|
||||||
|
"either",
|
||||||
|
"itertools 0.11.0",
|
||||||
|
"rayon",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rayon-core"
|
||||||
|
version = "1.12.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-deque",
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex"
|
||||||
|
version = "1.11.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191"
|
||||||
|
dependencies = [
|
||||||
|
"aho-corasick",
|
||||||
|
"memchr",
|
||||||
|
"regex-automata",
|
||||||
|
"regex-syntax",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex-automata"
|
||||||
|
version = "0.4.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908"
|
||||||
|
dependencies = [
|
||||||
|
"aho-corasick",
|
||||||
|
"memchr",
|
||||||
|
"regex-syntax",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex-syntax"
|
||||||
|
version = "0.8.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ryu"
|
||||||
|
version = "1.0.18"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde"
|
||||||
|
version = "1.0.216"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e"
|
||||||
|
dependencies = [
|
||||||
|
"serde_derive",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_derive"
|
||||||
|
version = "1.0.216"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_json"
|
||||||
|
version = "1.0.134"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d00f4175c42ee48b15416f6193a959ba3a0d67fc699a0db9ad12df9f83991c7d"
|
||||||
|
dependencies = [
|
||||||
|
"itoa",
|
||||||
|
"memchr",
|
||||||
|
"ryu",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "smallvec"
|
||||||
|
version = "1.13.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "spm_precompiled"
|
||||||
|
version = "0.1.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326"
|
||||||
|
dependencies = [
|
||||||
|
"base64",
|
||||||
|
"nom",
|
||||||
|
"serde",
|
||||||
|
"unicode-segmentation",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "strsim"
|
||||||
|
version = "0.11.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "syn"
|
||||||
|
version = "2.0.91"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d53cbcb5a243bd33b7858b1d7f4aca2153490815872d86d955d6ea29f743c035"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "thiserror"
|
||||||
|
version = "1.0.69"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
|
||||||
|
dependencies = [
|
||||||
|
"thiserror-impl",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "thiserror-impl"
|
||||||
|
version = "1.0.69"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokenizers"
|
||||||
|
version = "0.21.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9ecededfed68a69bc657e486510089e255e53c3d38cc7d4d59c8742668ca2cae"
|
||||||
|
dependencies = [
|
||||||
|
"aho-corasick",
|
||||||
|
"derive_builder",
|
||||||
|
"esaxx-rs",
|
||||||
|
"fancy-regex",
|
||||||
|
"getrandom",
|
||||||
|
"itertools 0.12.1",
|
||||||
|
"lazy_static",
|
||||||
|
"log",
|
||||||
|
"macro_rules_attribute",
|
||||||
|
"monostate",
|
||||||
|
"paste",
|
||||||
|
"rand",
|
||||||
|
"rayon",
|
||||||
|
"rayon-cond",
|
||||||
|
"regex",
|
||||||
|
"regex-syntax",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"spm_precompiled",
|
||||||
|
"thiserror",
|
||||||
|
"unicode-normalization-alignments",
|
||||||
|
"unicode-segmentation",
|
||||||
|
"unicode_categories",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-ident"
|
||||||
|
version = "1.0.14"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-normalization-alignments"
|
||||||
|
version = "0.1.12"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de"
|
||||||
|
dependencies = [
|
||||||
|
"smallvec",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-segmentation"
|
||||||
|
version = "1.12.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode_categories"
|
||||||
|
version = "0.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wasi"
|
||||||
|
version = "0.11.0+wasi-snapshot-preview1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wasm-bindgen"
|
||||||
|
version = "0.2.99"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a474f6281d1d70c17ae7aa6a613c87fce69a127e2624002df63dcb39d6cf6396"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"once_cell",
|
||||||
|
"wasm-bindgen-macro",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wasm-bindgen-backend"
|
||||||
|
version = "0.2.99"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5f89bb38646b4f81674e8f5c3fb81b562be1fd936d84320f3264486418519c79"
|
||||||
|
dependencies = [
|
||||||
|
"bumpalo",
|
||||||
|
"log",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
"wasm-bindgen-shared",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wasm-bindgen-macro"
|
||||||
|
version = "0.2.99"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2cc6181fd9a7492eef6fef1f33961e3695e4579b9872a6f7c83aee556666d4fe"
|
||||||
|
dependencies = [
|
||||||
|
"quote",
|
||||||
|
"wasm-bindgen-macro-support",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wasm-bindgen-macro-support"
|
||||||
|
version = "0.2.99"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
"wasm-bindgen-backend",
|
||||||
|
"wasm-bindgen-shared",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wasm-bindgen-shared"
|
||||||
|
version = "0.2.99"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zerocopy"
|
||||||
|
version = "0.7.35"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0"
|
||||||
|
dependencies = [
|
||||||
|
"byteorder",
|
||||||
|
"zerocopy-derive",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zerocopy-derive"
|
||||||
|
version = "0.7.35"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zml_tokenizer_hftokenizers"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"tokenizers",
|
||||||
|
]
|
||||||
13
zml/tokenizer/hftokenizers/Cargo.toml
Normal file
13
zml/tokenizer/hftokenizers/Cargo.toml
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
[package]
|
||||||
|
name = "zml_tokenizer_hftokenizers"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
# onig-sys fails to build with zig cc, disable it via the unstable_wasm feature, which switches
|
||||||
|
# the regex library to using fancy.
|
||||||
|
tokenizers = { version = "0.21.0", default-features = false, features = ["unstable_wasm"] }
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
name = "zml_tokenizer_hftokenizers"
|
||||||
|
path = "hftokenizers.rs"
|
||||||
16
zml/tokenizer/hftokenizers/hftokenizers.h
Normal file
16
zml/tokenizer/hftokenizers/hftokenizers.h
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
#pragma once

#include <stdint.h>
#include <stdlib.h>

#include "ffi/zig_slice.h"

/* Opaque handle to a Rust tokenizers::Tokenizer (see hftokenizers.rs). */
typedef struct hftokenizers hftokenizers;

/* Loads a tokenizer from the file whose UTF-8 path is passed as a slice.
 * The returned handle is owned by the caller; release it with
 * hftokenizers_drop. */
hftokenizers *hftokenizers_new(zig_slice);
void hftokenizers_drop(hftokenizers *tokenizer);
/* Encodes UTF-8 text into a slice of u32 token ids allocated on the Rust
 * side; release the result with hftokenizers_tokens_drop. */
zig_slice hftokenizers_encode(hftokenizers *tokenizer, zig_slice text);
void hftokenizers_tokens_drop(zig_slice tokens);
/* Decodes u32 token ids into UTF-8 bytes allocated on the Rust side;
 * release the result with hftokenizers_str_drop. */
zig_slice hftokenizers_decode(hftokenizers *tokenizer, zig_slice tokens);
void hftokenizers_str_drop(zig_slice text);
/* Returns the id of a token string, or UINT32_MAX when unknown
 * (see hftokenizers.rs). */
uint32_t hftokenizers_token_to_id(hftokenizers *tokenizer, zig_slice token);
|
||||||
101
zml/tokenizer/hftokenizers/hftokenizers.rs
Normal file
101
zml/tokenizer/hftokenizers/hftokenizers.rs
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
/// FFI mirror of a Zig slice: raw pointer plus length.
///
/// `#[repr(C)]` guarantees the layout matches the `zig_slice` struct in
/// `ffi/zig_slice.h`, so values can be passed by value over the C ABI.
#[repr(C)]
struct ZigSlice<T> {
    ptr: *mut T,
    len: usize,
}

impl<T> ZigSlice<T> {
    /// Borrows the contents as a shared slice.
    ///
    /// Callers in this module must guarantee `ptr` points to `len`
    /// initialized `T`s that outlive the borrow.
    fn as_slice(&self) -> &[T] {
        // `from_raw_parts` requires a non-null, aligned pointer even when
        // `len == 0`; short-circuit so a zeroed `zig_slice` coming from the
        // Zig side (null ptr, zero len) is not undefined behavior.
        if self.len == 0 {
            return &[];
        }
        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
    }

    /// Borrows the contents as a mutable slice (used to rebuild the boxed
    /// slices handed out by encode/decode so they can be freed).
    fn as_slice_mut(&self) -> &mut [T] {
        if self.len == 0 {
            return &mut [];
        }
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
    }
}
|
||||||
|
|
||||||
|
#[no_mangle]
|
||||||
|
extern "C" fn hftokenizers_new(path: ZigSlice<u8>) -> *mut tokenizers::Tokenizer {
|
||||||
|
return Box::into_raw(Box::new(
|
||||||
|
tokenizers::Tokenizer::from_file(std::path::Path::new(
|
||||||
|
std::str::from_utf8(path.as_slice()).unwrap(),
|
||||||
|
))
|
||||||
|
.unwrap(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[no_mangle]
|
||||||
|
extern "C" fn hftokenizers_drop(t: *mut tokenizers::Tokenizer) {
|
||||||
|
drop(unsafe { Box::from_raw(t) });
|
||||||
|
}
|
||||||
|
|
||||||
|
#[no_mangle]
|
||||||
|
extern "C" fn hftokenizers_encode(
|
||||||
|
t: *mut tokenizers::Tokenizer,
|
||||||
|
string: ZigSlice<u8>,
|
||||||
|
) -> ZigSlice<u32> {
|
||||||
|
let input_str = std::str::from_utf8(string.as_slice()).unwrap();
|
||||||
|
let encoded = unsafe { t.as_ref() }
|
||||||
|
.unwrap()
|
||||||
|
.encode_fast(input_str, false)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Convert the result to a boxed slice
|
||||||
|
let mut ids: Box<[u32]> = encoded.get_ids().to_owned().into_boxed_slice();
|
||||||
|
|
||||||
|
// Retrieve the zig slice associated to the boxed slice.
|
||||||
|
let slice = ZigSlice {
|
||||||
|
ptr: ids.as_mut_ptr(),
|
||||||
|
len: ids.len(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Leak the box so that it's not deallocated.
|
||||||
|
Box::leak(ids);
|
||||||
|
|
||||||
|
return slice;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[no_mangle]
|
||||||
|
extern "C" fn hftokenizers_tokens_drop(tokens: ZigSlice<u32>) {
|
||||||
|
// Reconstruct the Box from the zig slice so that it's dropped.
|
||||||
|
drop(unsafe { Box::from_raw(tokens.as_slice_mut()) });
|
||||||
|
}
|
||||||
|
|
||||||
|
#[no_mangle]
|
||||||
|
extern "C" fn hftokenizers_decode(
|
||||||
|
t: *mut tokenizers::Tokenizer,
|
||||||
|
ids: ZigSlice<u32>,
|
||||||
|
) -> ZigSlice<u8> {
|
||||||
|
let decoded = unsafe { t.as_ref() }
|
||||||
|
.unwrap()
|
||||||
|
.decode(ids.as_slice(), false)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Convert the result to a boxed slice
|
||||||
|
let mut string: Box<[u8]> = decoded.into_bytes().into_boxed_slice();
|
||||||
|
|
||||||
|
// Retrieve the zig slice associated to the boxed slice.
|
||||||
|
let slice = ZigSlice {
|
||||||
|
ptr: string.as_mut_ptr(),
|
||||||
|
len: string.len(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Leak the box so that it's not deallocated.
|
||||||
|
Box::leak(string);
|
||||||
|
|
||||||
|
return slice;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[no_mangle]
|
||||||
|
extern "C" fn hftokenizers_str_drop(tokens: ZigSlice<u8>) {
|
||||||
|
drop(unsafe { Box::from_raw(tokens.as_slice_mut()) });
|
||||||
|
}
|
||||||
|
|
||||||
|
#[no_mangle]
|
||||||
|
extern "C" fn hftokenizers_token_to_id(t: *mut tokenizers::Tokenizer, token: ZigSlice<u8>) -> u32 {
|
||||||
|
let id = unsafe { t.as_ref() }
|
||||||
|
.unwrap()
|
||||||
|
.token_to_id(std::str::from_utf8(token.as_slice()).unwrap())
|
||||||
|
.unwrap_or(u32::MAX);
|
||||||
|
return id;
|
||||||
|
}
|
||||||
113
zml/tokenizer/hftokenizers/hftokenizers.zig
Normal file
113
zml/tokenizer/hftokenizers/hftokenizers.zig
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
const c = @import("c");
|
||||||
|
const ffi = @import("ffi");
|
||||||
|
|
||||||
|
/// Wraps encode calls on a tokenizer and owns the most recent id slice,
/// whose memory lives on the Rust side and must be released via the FFI.
pub const Encoder = struct {
    inner: *HFTokenizer,
    // Ids from the last encode(); allocated by the Rust library and freed
    // through hftokenizers_tokens_drop in reset().
    current_ids: ?[]const u32 = null,

    fn init(inner: *HFTokenizer) Encoder {
        return .{ .inner = inner };
    }

    /// Releases the ids of the previous encode, if any.
    pub fn reset(self: *Encoder) void {
        if (self.current_ids) |current_ids_| {
            c.hftokenizers_tokens_drop(ffi.ZigSlice.from(current_ids_));
            self.current_ids = null;
        }
    }

    pub fn deinit(self: *Encoder) void {
        self.reset();
    }

    /// Encodes `input` and returns its token ids. The returned slice is
    /// only valid until the next encode()/reset()/deinit() call.
    pub fn encode(self: *Encoder, input: []const u8) ![]const u32 {
        self.reset();
        self.current_ids = ffi.ZigSlice.to(u32, c.hftokenizers_encode(@ptrCast(self.inner), ffi.ZigSlice.from(input)));
        return self.ids();
    }

    /// Ids of the last encode, or an empty slice when none.
    pub fn ids(self: *const Encoder) []const u32 {
        return self.current_ids orelse &.{};
    }
};
|
||||||
|
|
||||||
|
/// Wraps decode calls and supports incremental (streaming) decoding via
/// next(), which keeps a small sliding window of recent token ids.
pub const Decoder = struct {
    const StringBuffer = std.BoundedArray(u8, 128);
    const TokensIdsBuffer = std.BoundedArray(u32, 4);

    inner: *HFTokenizer,
    // Last full decode result; memory is owned by the Rust library and
    // freed through hftokenizers_str_drop in reset().
    current_string: ?[]const u8 = null,
    // Text produced by the previous window, used to diff out the newly
    // generated suffix in next().
    last_string: StringBuffer = .{ .len = 0 },
    // Sliding window of the most recent token ids (at most 4).
    last_token_ids: TokensIdsBuffer = .{ .len = 0 },

    fn init(inner: *HFTokenizer) Decoder {
        return .{ .inner = inner };
    }

    pub fn deinit(self: *Decoder) void {
        self.reset();
    }

    /// Releases the string of the previous decode, if any.
    pub fn reset(self: *Decoder) void {
        if (self.current_string) |current_string_| {
            c.hftokenizers_str_drop(ffi.ZigSlice.from(current_string_));
            self.current_string = null;
        }
    }

    /// Decodes `ids`; the returned slice is valid only until the next
    /// decode()/reset()/deinit() call.
    pub fn decode(self: *Decoder, ids: []const u32) ![]const u8 {
        self.reset();
        self.current_string = ffi.ZigSlice.to(u8, c.hftokenizers_decode(@ptrCast(self.inner), ffi.ZigSlice.from(ids)));
        return self.string();
    }

    /// Result of the last decode, or an empty slice when none.
    pub fn string(self: *const Decoder) []const u8 {
        return self.current_string orelse &.{};
    }

    /// Feeds one token id and returns the newly produced text, or null if
    /// no new suffix could be isolated.
    ///
    /// Decodes the sliding window and diffs the result against the text
    /// the previous window produced, aligning on codepoint boundaries.
    /// NOTE(review): assumes one window decodes to at most 128 bytes
    /// (StringBuffer); fromSlice errors otherwise — confirm this bound.
    pub fn next(self: *Decoder, token_id: u32) !?[]const u8 {
        // Evict the oldest id once the window is full.
        if (self.last_token_ids.len >= self.last_token_ids.capacity()) {
            _ = self.last_token_ids.orderedRemove(0);
        }
        self.last_token_ids.appendAssumeCapacity(token_id);
        const new_string = try self.decode(self.last_token_ids.constSlice());
        if (self.last_string.len == 0) {
            // First call: everything decoded so far is new output.
            self.last_string = try StringBuffer.fromSlice(new_string);
            return new_string;
        }
        // Find a suffix of the previous text (starting on a codepoint
        // boundary) that prefixes the new text; whatever follows that
        // overlap is the freshly generated chunk.
        var view = try std.unicode.Utf8View.init(self.last_string.constSlice());
        var it = view.iterator();
        while (it.nextCodepointSlice()) |cp| {
            const start = it.i - cp.len;
            if (std.mem.startsWith(u8, new_string, self.last_string.constSlice()[start..])) {
                const chunk = new_string[self.last_string.len - start ..];
                self.last_string = try StringBuffer.fromSlice(new_string);
                return chunk;
            }
        }
        return null;
    }
};
|
||||||
|
|
||||||
|
/// Opaque handle over the Rust `tokenizers::Tokenizer` (see hftokenizers.h
/// and hftokenizers.rs for the underlying C ABI).
pub const HFTokenizer = opaque {
    /// Loads a tokenizer from a tokenizer.json file path.
    pub fn from_file(model: []const u8) !*HFTokenizer {
        return @ptrCast(c.hftokenizers_new(ffi.ZigSlice.from(model)));
    }

    /// Frees the underlying Rust tokenizer.
    pub fn deinit(self: *HFTokenizer) void {
        return c.hftokenizers_drop(@ptrCast(self));
    }

    pub fn encoder(self: *HFTokenizer) !Encoder {
        return Encoder.init(self);
    }

    pub fn decoder(self: *HFTokenizer) !Decoder {
        return Decoder.init(self);
    }

    /// Looks up the id of a token string.
    /// NOTE(review): the Rust side returns maxInt(u32) for unknown tokens
    /// (see hftokenizers.rs), but this wrapper never maps that sentinel to
    /// null despite the optional return type — confirm intent.
    pub fn token_to_id(self: *HFTokenizer, token: []const u8) ?u32 {
        return c.hftokenizers_token_to_id(@ptrCast(self), ffi.ZigSlice.from(token));
    }
};
|
||||||
27
zml/tokenizer/hftokenizers/main.zig
Normal file
27
zml/tokenizer/hftokenizers/main.zig
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
const c = @import("c");
|
||||||
|
const HFTokenizers = @import("hftokenizers").HFTokenizers;
|
||||||
|
|
||||||
|
// Manual smoke test for the hftokenizers bindings.
// NOTE(review): hardcoded developer-local bazel cache path — not portable.
// NOTE(review): calls HFTokenizers.init/encode/decode and reads
// encoded.ids / decoded.str, which do not match the API exported by
// hftokenizers.zig (HFTokenizer.from_file, Encoder, Decoder) — this file
// looks stale; confirm it still builds.
pub fn main() !void {
    const tokenizer = HFTokenizers.init("/private/var/tmp/_bazel_steeve/a67b810d44f2a673ebbd5bab86ccd5cc/external/zml~~huggingface~Meta-Llama-3.1-8B-Instruct/tokenizer.json");
    defer HFTokenizers.deinit(tokenizer);

    const input = "Hello, world! plane pouet plane";
    var encoded = HFTokenizers.encode(tokenizer, input);
    defer encoded.deinit();
    var pouet = std.ArrayList(u32).init(std.heap.c_allocator);
    defer pouet.deinit();

    // try pouet.appendSlice(encoded.ids);

    var t = try std.time.Timer.start();
    // Repeatedly grows the id list and times a full decode of it.
    for (0..100) |_| {
        try pouet.appendSlice(encoded.ids);
        t.reset();
        var decoded = HFTokenizers.decode(tokenizer, pouet.items);
        defer decoded.deinit();
        const elapsed = t.lap();
        // std.debug.print("{any} {any} {d}us\n", .{tokenizer, encoded, elapsed / std.time.ns_per_us});
        std.debug.print("{any} {any} {s} {d}ns {d}us\n", .{ tokenizer, encoded, decoded.str, elapsed, elapsed / std.time.ns_per_us });
    }
}
|
||||||
22
zml/tokenizer/main.zig
Normal file
22
zml/tokenizer/main.zig
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
const tokenizer = @import("zml/tokenizer");
|
||||||
|
|
||||||
|
// Manual end-to-end check of the tokenizer facade: load, encode, decode.
// NOTE(review): hardcoded developer-local bazel cache path — not portable.
pub fn main() !void {
    const model2 = "/private/var/tmp/_bazel_steeve/a67b810d44f2a673ebbd5bab86ccd5cc/external/zml~~huggingface~Meta-Llama-3.1-8B-Instruct/tokenizer.json";

    var sp = try tokenizer.Tokenizer.from_file(std.heap.c_allocator, model2);
    defer sp.deinit();

    std.debug.print("Loaded model\n", .{});

    var encoder = try sp.encoder();
    defer encoder.deinit();

    var decoder = try sp.decoder();
    defer decoder.deinit();

    // Round-trip a short sentence through encode then decode.
    const ids = try encoder.encode("Hello, world! plane pouet plane");
    const decoded = try decoder.decode(ids);

    std.debug.print("{d}\n{s}\n", .{ ids, decoded });
}
|
||||||
35
zml/tokenizer/sentencepiece/BUILD.bazel
Normal file
35
zml/tokenizer/sentencepiece/BUILD.bazel
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
load("@rules_zig//zig:defs.bzl", "zig_library")
load("@zml//bazel:zig.bzl", "zig_cc_binary")
load("//bazel:swig.bzl", "swig_cc_library")

# C shim over the SentencePiece C++ API, generated by SWIG from
# sentencepiece.i.
swig_cc_library(
    name = "sentencepiece_swig",
    interface = "sentencepiece.i",
    module = "sentencepiece",
    deps = [
        "//ffi:cc",
        "@com_google_sentencepiece//:sentencepiece_processor",
    ],
)

# Zig bindings (`@import("sentencepiece")`) layered over the SWIG shim.
zig_library(
    name = "sentencepiece",
    import_name = "sentencepiece",
    main = "sentencepiece.zig",
    visibility = ["//visibility:public"],
    deps = [
        ":sentencepiece_swig",
        "//ffi:zig",
    ],
)

# Manual smoke-test binary (see main.zig in this package).
zig_cc_binary(
    name = "main",
    srcs = ["sentencepiece.zig"],
    main = "main.zig",
    visibility = ["//visibility:public"],
    deps = [
        ":sentencepiece_swig",
        "//ffi:zig",
    ],
)
|
||||||
289
zml/tokenizer/sentencepiece/main.zig
Normal file
289
zml/tokenizer/sentencepiece/main.zig
Normal file
@ -0,0 +1,289 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
const c = @import("c");
|
||||||
|
const ffi = @import("ffi");
|
||||||
|
|
||||||
|
/// Error set mirroring sentencepiece::util::StatusCode (every non-kOk
/// value); produced by SentencePieceProcessor.assertOk().
pub const SentencePieceError = error{
    Cancelled,
    Unknown,
    InvalidArgument,
    DeadlineExceeded,
    NotFound,
    AlreadyExists,
    PermissionDenied,
    ResourceExhausted,
    FailedPrecondition,
    Aborted,
    OutOfRange,
    Unimplemented,
    Internal,
    Unavailable,
    DataLoss,
    Unauthenticated,
};
|
||||||
|
|
||||||
|
/// Streaming decoder: feeds token ids one at a time through a sliding
/// window of TokensSize ids and emits only the newly produced text suffix.
pub const DecoderStream = struct {
    // Sliding-window length, in tokens.
    const TokensSize = 4;
    // Capacity reserved for the decoded text of one window.
    const StringSize = 128;
    decoder: SentencePieceProcessor.Decoder,
    // Backing storage for last_tokens (the previous window's decoded text).
    buffer: [StringSize]u8 = undefined,
    last_tokens: []u8 = &.{},

    pub fn init(decoder: SentencePieceProcessor.Decoder) DecoderStream {
        var ret: DecoderStream = .{
            .decoder = decoder,
        };
        // Pre-size the C++ containers so steady-state calls don't allocate.
        ret.decoder.reserve_tokens(TokensSize);
        ret.decoder.reserve_string(StringSize);
        return ret;
    }

    /// Feeds one token and returns the newly generated text chunk, or null
    /// when no new suffix could be isolated.
    /// NOTE(review): copies each window decode into the StringSize-byte
    /// buffer; a decode longer than 128 bytes would slice out of bounds —
    /// confirm window decodes stay within this bound.
    pub fn next(self: *DecoderStream, next_token: u32) !?[]const u8 {
        if (self.decoder.tokens().len >= TokensSize) {
            // Window full: shift everything left by one, append new token.
            const tokens = self.decoder.tokens();
            inline for (0..TokensSize - 1) |i| {
                tokens[i] = tokens[i + 1];
            }
            tokens[TokensSize - 1] = next_token;
        } else {
            self.decoder.append(next_token);
        }
        const new_tokens = try self.decoder.decode();
        if (self.last_tokens.len == 0) {
            // First emission: the whole decode is new output.
            self.last_tokens = self.buffer[0..new_tokens.len];
            @memcpy(self.last_tokens, new_tokens);
            return new_tokens;
        }
        // Diff against the previous window's text: find a suffix of the
        // old text that prefixes the new text; what follows it is new.
        for (1..self.last_tokens.len) |i| {
            if (std.mem.startsWith(u8, new_tokens, self.last_tokens[i..])) {
                const toks = new_tokens[self.last_tokens.len - i ..];
                self.last_tokens = self.buffer[0..new_tokens.len];
                @memcpy(self.last_tokens, new_tokens);
                return toks;
            }
        }
        return null;
    }
};
|
||||||
|
|
||||||
|
/// Opaque handle over the C++ sentencepiece::SentencePieceProcessor,
/// reached through the SWIG-generated C shims in `c` (see sentencepiece.i).
pub const SentencePieceProcessor = opaque {
    /// Encodes text into token ids, reusing one std::vector<int> across
    /// calls to avoid per-call allocation.
    pub const Encoder = struct {
        inner: *SentencePieceProcessor,
        // Output vector owned by the C++ side; freed in deinit().
        vec: *c.std_vector_int,

        fn init(inner: *SentencePieceProcessor) Encoder {
            return .{
                .inner = inner,
                .vec = c.std_vector_int_new() orelse unreachable,
            };
        }

        pub fn deinit(self: *Encoder) void {
            c.std_vector_int_delete(self.vec);
        }

        /// Pre-sizes the output vector to avoid growth during encode.
        pub fn reserve(self: *Encoder, size: usize) void {
            c.std_vector_int_reserve(self.vec, size);
        }

        pub fn reset(self: *Encoder) void {
            c.std_vector_int_clear(self.vec);
        }

        /// Encodes `input`. The returned ids alias the internal vector and
        /// are invalidated by the next encode()/reset()/deinit().
        /// NOTE(review): the vector holds C ints reinterpreted as u32 —
        /// confirm ids are never negative.
        pub fn encode(self: *Encoder, input: []const u8) ![]const u32 {
            try assertOk(c.SentencePieceProcessor_Encode(@ptrCast(self.inner), ffi.ZigSlice.from(input), self.vec));
            return ffi.ZigSlice.to(u32, .{
                .ptr = c.std_vector_int_data(self.vec),
                .len = c.std_vector_int_size(self.vec),
            });
        }
    };

    /// Decodes token ids into text, reusing a std::vector<int> (input ids)
    /// and a std::string (output) across calls.
    pub const Decoder = struct {
        inner: *SentencePieceProcessor,
        vec: *c.std_vector_int,
        str: *c.std_string,

        fn init(inner: *SentencePieceProcessor) Decoder {
            return .{
                .inner = inner,
                .vec = c.std_vector_int_new() orelse unreachable,
                .str = c.std_string_new() orelse unreachable,
            };
        }

        /// Appends one token id to the pending decode input.
        pub fn append(self: *Decoder, token: u32) void {
            c.std_vector_int_push_back(self.vec, @intCast(token));
        }

        pub fn deinit(self: *Decoder) void {
            c.std_vector_int_delete(self.vec);
            c.std_string_delete(self.str);
        }

        pub fn reserve_tokens(self: *Decoder, size: usize) void {
            c.std_vector_int_reserve(self.vec, size);
        }

        pub fn reserve_string(self: *Decoder, size: usize) void {
            c.std_string_reserve(self.str, size);
        }

        /// Clears both the pending ids and the output string.
        pub fn reset(self: *Decoder) void {
            c.std_vector_int_clear(self.vec);
            c.std_string_clear(self.str);
        }

        /// Decodes the pending ids. The returned text aliases the internal
        /// std::string and is invalidated by the next decode()/reset()/deinit().
        pub fn decode(self: *Decoder) ![]const u8 {
            try assertOk(c.SentencePieceProcessor_Decode(@ptrCast(self.inner), self.vec, self.str));
            return self.string();
        }

        /// View of the current output string.
        pub fn string(self: *const Decoder) []const u8 {
            const res = c.std_string_data(self.str);
            return ffi.ZigSlice.to(u8, res);
        }

        /// Mutable view of the pending token ids, reinterpreted as u32.
        pub fn tokens(self: *const Decoder) []u32 {
            const ptr: [*c]u32 = @ptrCast(c.std_vector_int_data(self.vec));
            return ptr[0..c.std_vector_int_size(self.vec)];
        }
    };

    /// Maps a SWIG-flattened sentencepiece::util::StatusCode to a Zig error.
    fn assertOk(code: c.sentencepiece_util_StatusCode) SentencePieceError!void {
        return switch (code) {
            c.sentencepiece_util_StatusCode_kOk => {},
            c.sentencepiece_util_StatusCode_kCancelled => error.Cancelled,
            c.sentencepiece_util_StatusCode_kUnknown => error.Unknown,
            c.sentencepiece_util_StatusCode_kInvalidArgument => error.InvalidArgument,
            c.sentencepiece_util_StatusCode_kDeadlineExceeded => error.DeadlineExceeded,
            c.sentencepiece_util_StatusCode_kNotFound => error.NotFound,
            c.sentencepiece_util_StatusCode_kAlreadyExists => error.AlreadyExists,
            c.sentencepiece_util_StatusCode_kPermissionDenied => error.PermissionDenied,
            c.sentencepiece_util_StatusCode_kResourceExhausted => error.ResourceExhausted,
            c.sentencepiece_util_StatusCode_kFailedPrecondition => error.FailedPrecondition,
            c.sentencepiece_util_StatusCode_kAborted => error.Aborted,
            c.sentencepiece_util_StatusCode_kOutOfRange => error.OutOfRange,
            c.sentencepiece_util_StatusCode_kUnimplemented => error.Unimplemented,
            c.sentencepiece_util_StatusCode_kInternal => error.Internal,
            c.sentencepiece_util_StatusCode_kUnavailable => error.Unavailable,
            c.sentencepiece_util_StatusCode_kDataLoss => error.DataLoss,
            c.sentencepiece_util_StatusCode_kUnauthenticated => error.Unauthenticated,
            else => unreachable,
        };
    }

    /// Allocates a processor and loads the serialized model at `model`.
    /// The returned handle must be released with deinit().
    pub fn load(model: []const u8) !*SentencePieceProcessor {
        const sp: *SentencePieceProcessor = @ptrCast(c.SentencePieceProcessor_new());
        errdefer sp.deinit();
        try sp.load_from(model);
        return sp;
    }

    /// Frees the underlying C++ processor.
    pub fn deinit(self: *SentencePieceProcessor) void {
        c.SentencePieceProcessor_delete(@ptrCast(self));
    }

    fn load_from(self: *SentencePieceProcessor, model: []const u8) !void {
        try assertOk(c.SentencePieceProcessor_Load(@ptrCast(self), ffi.ZigSlice.from(model)));
    }

    pub fn encoder(self: *SentencePieceProcessor) Encoder {
        return Encoder.init(self);
    }

    pub fn decoder(self: *SentencePieceProcessor) Decoder {
        return Decoder.init(self);
    }
};
|
||||||
|
|
||||||
|
/// Copies `path` into a fixed-size, zero-terminated buffer suitable for C
/// APIs expecting a NUL-terminated path.
///
/// Asserts the path fits (at most std.fs.max_path_bytes bytes); previously
/// an oversized path made the `@memcpy` out of bounds in release modes.
pub fn as_path(path: []const u8) [std.fs.max_path_bytes:0]u8 {
    var result: [std.fs.max_path_bytes:0]u8 = undefined;
    std.debug.assert(path.len <= result.len);
    @memcpy(result[0..path.len], path);
    // Index result.len is the sentinel slot, so the terminator always fits.
    result[path.len] = 0;
    return result;
}
|
||||||
|
|
||||||
|
/// Manual smoke test: loads a local SentencePiece model, encodes this
/// source file, then streams the decode back token by token, timing each
/// emitted chunk. (Large blocks of dead, commented-out experiments were
/// removed; the executable statements are unchanged.)
///
/// NOTE(review): depends on a developer-local model path — this is a
/// scratch harness, not a portable test.
pub fn main() !void {
    const sp = try SentencePieceProcessor.load("/Users/steeve/Downloads/poolside.sp.pb");
    defer sp.deinit();

    std.debug.print("Loaded model\n", .{});

    var encoder = sp.encoder();
    defer encoder.deinit();

    var decoder = sp.decoder();
    defer decoder.deinit();

    // Use this file's own source as a convenient non-trivial input.
    const ss = @embedFile("main.zig");
    const tokens = try encoder.encode(ss);

    var stream = DecoderStream.init(decoder);

    var start = try std.time.Timer.start();
    for (tokens) |token| {
        if (try stream.next(token)) |chunk| {
            std.debug.print("{d}us - {s}\n", .{ start.lap() / std.time.ns_per_us, chunk });
        }
    }
}
|
||||||
92
zml/tokenizer/sentencepiece/sentencepiece.i
Normal file
92
zml/tokenizer/sentencepiece/sentencepiece.i
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
/* File : sentencepiece.i
 * SWIG interface exposing a minimal C ABI over
 * sentencepiece::SentencePieceProcessor, plus the std::string /
 * std::vector<int> helpers the Zig bindings need. */
%module sentencepiece
%include <typemaps.i>
%include <std_vector.i>

%{
#include <stdint.h>
#include <string>
#include <vector>
#include <sentencepiece_processor.h>
#include "ffi/zig_slice.h"
%}

/* Make the generated C header self-contained for the Zig side. */
%insert("cheader") %{
#include "ffi/zig_slice.h"
%}

/* Incoming zig_slice -> absl::string_view (no copy). */
%typemap(in) absl::string_view {
    $1 = absl::string_view((char *)$input.ptr, $input.len);
}

/* Outgoing const std::string& -> zig_slice viewing the string's bytes. */
%typemap(out, optimal="1") const std::string& %{
    $result.ptr = (void *)($1->data());
    $result.len = (size_t)($1->length());
%}
%typemap(ctype) absl::string_view, const std::string& "zig_slice"

/* Expose a small slice of std::string under a C-friendly name. */
%rename(std_string) std::string;
namespace std {
    class string {
    public:
        void reserve(size_t n);
        void clear();
        size_t capacity() const;
    };
}

/* data() returning the string itself lets the out-typemap above turn it
 * into a zig_slice view. */
%extend std::string {
    const std::string& data() const {
        return *$self;
    };
}

%extend std::vector {
    T* data() {
        return $self->data();
    };
}

%template(std_vector_int) std::vector<int>;

/* Collapse sentencepiece::util::Status to its numeric StatusCode for C. */
%typemap(out) sentencepiece::util::Status %{
    $result = $1.code();
%}
%typemap(ctype) sentencepiece::util::Status "unsigned int"
%rename(sentencepiece_util_StatusCode, fullname=1) sentencepiece::util::StatusCode;
|
||||||
|
|
||||||
|
namespace sentencepiece {
|
||||||
|
namespace util {
|
||||||
|
enum class StatusCode : int {
|
||||||
|
kOk = 0,
|
||||||
|
kCancelled = 1,
|
||||||
|
kUnknown = 2,
|
||||||
|
kInvalidArgument = 3,
|
||||||
|
kDeadlineExceeded = 4,
|
||||||
|
kNotFound = 5,
|
||||||
|
kAlreadyExists = 6,
|
||||||
|
kPermissionDenied = 7,
|
||||||
|
kResourceExhausted = 8,
|
||||||
|
kFailedPrecondition = 9,
|
||||||
|
kAborted = 10,
|
||||||
|
kOutOfRange = 11,
|
||||||
|
kUnimplemented = 12,
|
||||||
|
kInternal = 13,
|
||||||
|
kUnavailable = 14,
|
||||||
|
kDataLoss = 15,
|
||||||
|
kUnauthenticated = 16,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
class SentencePieceProcessor {
|
||||||
|
public:
|
||||||
|
virtual sentencepiece::util::Status Load(absl::string_view filename);
|
||||||
|
virtual sentencepiece::util::Status Encode(absl::string_view input, std::vector<int> *ids) const;
|
||||||
|
virtual sentencepiece::util::Status Decode(const std::vector<int> &ids, std::string *detokenized) const;
|
||||||
|
virtual int PieceToId(absl::string_view piece) const;
|
||||||
|
virtual int unk_id() const;
|
||||||
|
virtual int bos_id() const;
|
||||||
|
virtual int eos_id() const;
|
||||||
|
virtual int pad_id() const;
|
||||||
|
};
|
||||||
|
}
|
||||||
189
zml/tokenizer/sentencepiece/sentencepiece.zig
Normal file
189
zml/tokenizer/sentencepiece/sentencepiece.zig
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
const c = @import("c");
|
||||||
|
const ffi = @import("ffi");
|
||||||
|
|
||||||
|
// Heuristic: one token covers roughly 3 bytes of text on average. Only used
// to pre-size buffers (vector/string reserve); a bad guess costs a
// reallocation, never correctness.
const StringToTokenRatio = 3;

/// Zig error set mirroring `sentencepiece::util::StatusCode` (minus kOk).
/// Must stay in sync with the enum exposed by the SWIG bindings in
/// sentencepiece.i; `assertOk` performs the code -> error translation.
pub const Error = error{
    Cancelled,
    Unknown,
    InvalidArgument,
    DeadlineExceeded,
    NotFound,
    AlreadyExists,
    PermissionDenied,
    ResourceExhausted,
    FailedPrecondition,
    Aborted,
    OutOfRange,
    Unimplemented,
    Internal,
    Unavailable,
    DataLoss,
    Unauthenticated,
};
|
||||||
|
|
||||||
|
/// Translates a SentencePiece status code (as returned by the C bindings)
/// into a Zig error. Returns normally for `kOk`.
///
/// Fix: unknown codes now map to `Error.Unknown` instead of `unreachable`.
/// The previous `else => unreachable` invoked illegal behavior in release
/// builds if the C library ever returned a status code outside the mirrored
/// enum (e.g. after a sentencepiece upgrade); degrading to a generic error
/// is strictly safer and callers already handle `Error.Unknown`.
fn assertOk(code: c.sentencepiece_util_StatusCode) Error!void {
    return switch (code) {
        c.sentencepiece_util_StatusCode_kOk => {},
        c.sentencepiece_util_StatusCode_kCancelled => Error.Cancelled,
        c.sentencepiece_util_StatusCode_kUnknown => Error.Unknown,
        c.sentencepiece_util_StatusCode_kInvalidArgument => Error.InvalidArgument,
        c.sentencepiece_util_StatusCode_kDeadlineExceeded => Error.DeadlineExceeded,
        c.sentencepiece_util_StatusCode_kNotFound => Error.NotFound,
        c.sentencepiece_util_StatusCode_kAlreadyExists => Error.AlreadyExists,
        c.sentencepiece_util_StatusCode_kPermissionDenied => Error.PermissionDenied,
        c.sentencepiece_util_StatusCode_kResourceExhausted => Error.ResourceExhausted,
        c.sentencepiece_util_StatusCode_kFailedPrecondition => Error.FailedPrecondition,
        c.sentencepiece_util_StatusCode_kAborted => Error.Aborted,
        c.sentencepiece_util_StatusCode_kOutOfRange => Error.OutOfRange,
        c.sentencepiece_util_StatusCode_kUnimplemented => Error.Unimplemented,
        c.sentencepiece_util_StatusCode_kInternal => Error.Internal,
        c.sentencepiece_util_StatusCode_kUnavailable => Error.Unavailable,
        c.sentencepiece_util_StatusCode_kDataLoss => Error.DataLoss,
        c.sentencepiece_util_StatusCode_kUnauthenticated => Error.Unauthenticated,
        // Forward-compatibility: don't crash on codes added by newer
        // library versions.
        else => Error.Unknown,
    };
}
|
||||||
|
|
||||||
|
/// Reusable wrapper around `SentencePieceProcessor::Encode`. Owns a C++
/// `std::vector<int>` that is reused across `encode` calls so repeated
/// encodes don't reallocate.
pub const Encoder = struct {
    inner: *SentencePieceProcessor,
    // Backend-owned output buffer holding the ids of the last encode.
    vec: *c.std_vector_int,

    /// Allocates the backing C++ vector.
    ///
    /// Fix: allocation failure now surfaces as `error.OutOfMemory` instead of
    /// `orelse unreachable` (which was illegal behavior on OOM), matching the
    /// sibling `Decoder.init`. `init` is private and its only caller,
    /// `SentencePieceProcessor.encoder`, already returns `!Encoder`, so this
    /// is interface-compatible.
    fn init(inner: *SentencePieceProcessor) !Encoder {
        const vec = try (c.std_vector_int_new() orelse std.mem.Allocator.Error.OutOfMemory);
        return .{
            .inner = inner,
            .vec = vec,
        };
    }

    /// Frees the backing C++ vector. The processor itself is not owned.
    pub fn deinit(self: *Encoder) void {
        c.std_vector_int_delete(self.vec);
    }

    /// Clears the id buffer; capacity is retained for reuse.
    pub fn reset(self: *Encoder) void {
        c.std_vector_int_clear(self.vec);
    }

    /// Encodes `input` and returns a view over the resulting token ids.
    /// The returned slice aliases the internal vector and is invalidated by
    /// the next `encode`/`reset`/`deinit`.
    pub fn encode(self: *Encoder, input: []const u8) ![]const u32 {
        // Pre-size using the bytes-per-token heuristic to limit regrowth.
        c.std_vector_int_reserve(self.vec, input.len / StringToTokenRatio);
        try assertOk(c.SentencePieceProcessor_Encode(@ptrCast(self.inner), ffi.ZigSlice.from(input), self.vec));
        return self.ids();
    }

    /// Token ids produced by the last `encode` (view into the C++ vector).
    pub fn ids(self: *const Encoder) []const u32 {
        return ffi.ZigSlice.to(u32, .{
            .ptr = c.std_vector_int_data(self.vec),
            .len = c.std_vector_int_size(self.vec),
        });
    }
};
|
||||||
|
|
||||||
|
/// Incremental detokenizer. Keeps a small sliding window of recent token ids
/// in a C++ vector plus the previously decoded text, so `next()` can emit
/// only the newly produced text chunk per token (streaming decode).
pub const Decoder = struct {
    // Capacity of the cached "previous decode" text window.
    const StringBufferSize = 64;
    const StringBuffer = std.BoundedArray(u8, StringBufferSize);
    // Number of trailing token ids kept in the sliding window for `next()`.
    const TokenIdsBufferSize = 4;

    inner: *SentencePieceProcessor,
    vec: *c.std_vector_int, // token-id window / decode input buffer (C++ owned)
    str: *c.std_string, // decode output buffer (C++ owned)
    last_string: StringBuffer = .{ .len = 0 }, // text from the previous decode, used to diff in next()

    /// Allocates the backing C++ vector and string; fails with OutOfMemory
    /// on allocation failure.
    fn init(inner: *SentencePieceProcessor) !Decoder {
        const vec = try (c.std_vector_int_new() orelse std.mem.Allocator.Error.OutOfMemory);
        // Reserving exactly TokenIdsBufferSize also fixes the capacity that
        // next() uses as its sliding-window size.
        c.std_vector_int_reserve(vec, TokenIdsBufferSize);
        errdefer c.std_vector_int_delete(vec);

        const str = try (c.std_string_new() orelse std.mem.Allocator.Error.OutOfMemory);
        c.std_string_reserve(str, StringBufferSize);
        errdefer c.std_string_delete(str);

        return .{
            .inner = inner,
            .vec = vec,
            .str = str,
        };
    }

    /// Frees the backing C++ buffers. The processor itself is not owned.
    pub fn deinit(self: *Decoder) void {
        c.std_vector_int_delete(self.vec);
        c.std_string_delete(self.str);
    }

    /// Clears both buffers; capacities are retained for reuse.
    pub fn reset(self: *Decoder) void {
        c.std_vector_int_clear(self.vec);
        c.std_string_clear(self.str);
    }

    /// One-shot decode of `ids_` into text. The returned slice aliases the
    /// internal string and is invalidated by the next decode/reset/deinit.
    ///
    /// NOTE(review): the id vector is not cleared here, so calling decode()
    /// twice without reset() appends to the previous ids — confirm whether
    /// callers are expected to reset() between decodes.
    pub fn decode(self: *Decoder, ids_: []const u32) ![]const u8 {
        c.std_vector_int_reserve(self.vec, ids_.len);
        // Pre-size the output using the bytes-per-token heuristic.
        c.std_string_reserve(self.str, ids_.len * StringToTokenRatio);
        for (ids_) |id| {
            c.std_vector_int_push_back(self.vec, @intCast(id));
        }
        try assertOk(c.SentencePieceProcessor_Decode(@ptrCast(self.inner), self.vec, self.str));
        return self.string();
    }

    /// Text produced by the last decode (view into the C++ string).
    pub fn string(self: *const Decoder) []const u8 {
        const res = c.std_string_data(self.str);
        return ffi.ZigSlice.to(u8, res);
    }

    /// Token ids currently held in the sliding window (mutable view).
    fn ids(self: *const Decoder) []u32 {
        const ptr: [*c]u32 = @ptrCast(c.std_vector_int_data(self.vec));
        return ptr[0..c.std_vector_int_size(self.vec)];
    }

    /// Feeds one token id and returns the newly decoded text chunk.
    ///
    /// Strategy: keep at most `capacity` ids (shifting the oldest out),
    /// re-decode the window, then find the longest suffix of the previous
    /// text (advancing by whole UTF-8 codepoints) that is a prefix of the
    /// new text — everything after that overlap is the fresh chunk.
    /// Returns an empty slice when no overlap is found.
    ///
    /// NOTE(review): `StringBuffer.fromSlice` errors (Overflow) if a window
    /// decode ever exceeds 64 bytes — verify TokenIdsBufferSize *
    /// worst-case-piece-length stays under StringBufferSize.
    pub fn next(self: *Decoder, token_id: u32) !?[]const u8 {
        const current_ids = self.ids();
        if (current_ids.len >= c.std_vector_int_capacity(self.vec)) {
            // Window full: drop the oldest id and append the new one in place.
            std.mem.copyForwards(u32, current_ids[0 .. current_ids.len - 1], current_ids[1..]);
            current_ids[current_ids.len - 1] = token_id;
        } else {
            c.std_vector_int_push_back(self.vec, @intCast(token_id));
        }
        try assertOk(c.SentencePieceProcessor_Decode(@ptrCast(self.inner), self.vec, self.str));
        const new_string = self.string();
        if (self.last_string.len == 0) {
            // First token: everything decoded is new output.
            self.last_string = try StringBuffer.fromSlice(new_string);
            return new_string;
        }
        var view = try std.unicode.Utf8View.init(self.last_string.constSlice());
        var it = view.iterator();
        while (it.nextCodepointSlice()) |cp| {
            // Start offset of the suffix being tested (codepoint-aligned).
            const start = it.i - cp.len;
            if (std.mem.startsWith(u8, new_string, self.last_string.constSlice()[start..])) {
                // Overlap found: the chunk after the shared prefix is new.
                const chunk = new_string[self.last_string.len - start ..];
                self.last_string = try StringBuffer.fromSlice(new_string);
                return chunk;
            }
        }
        return &.{};
    }
};
|
||||||
|
|
||||||
|
/// Opaque Zig handle over the C++ `sentencepiece::SentencePieceProcessor`
/// exposed by the SWIG bindings.
pub const SentencePieceProcessor = opaque {
    /// Creates a processor and loads the serialized model at `model`.
    /// The returned handle must be released with `deinit`.
    pub fn from_file(model: []const u8) !*SentencePieceProcessor {
        const processor: *SentencePieceProcessor = @ptrCast(c.SentencePieceProcessor_new());
        errdefer processor.deinit();
        const status = c.SentencePieceProcessor_Load(@ptrCast(processor), ffi.ZigSlice.from(model));
        try assertOk(status);
        return processor;
    }

    /// Destroys the underlying C++ object.
    pub fn deinit(self: *SentencePieceProcessor) void {
        c.SentencePieceProcessor_delete(@ptrCast(self));
    }

    /// Builds a reusable encoder bound to this processor.
    pub fn encoder(self: *SentencePieceProcessor) !Encoder {
        return Encoder.init(self);
    }

    /// Builds a reusable streaming decoder bound to this processor.
    pub fn decoder(self: *SentencePieceProcessor) !Decoder {
        return Decoder.init(self);
    }

    /// Looks up the id of a single piece (token string).
    pub fn token_to_id(self: *SentencePieceProcessor, token: []const u8) u32 {
        const raw_id = c.SentencePieceProcessor_PieceToId(@ptrCast(self), ffi.ZigSlice.from(token));
        return @intCast(raw_id);
    }
};
|
||||||
118
zml/tokenizer/tokenizer.zig
Normal file
118
zml/tokenizer/tokenizer.zig
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
const hftokenizers = @import("hftokenizers");
|
||||||
|
const sentencepiece = @import("sentencepiece");
|
||||||
|
const asynk = @import("async");
|
||||||
|
|
||||||
|
// Tag enum shared by the Tokenizer/Encoder/Decoder tagged unions below,
// one variant per supported backend.
const Tokenizers = enum {
    hftokenizers, // HuggingFace `tokenizers` (Rust) bindings
    sentencepiece, // Google SentencePiece bindings
};
|
||||||
|
|
||||||
|
/// Tagged-union facade over the supported tokenizer backends (HuggingFace
/// `tokenizers` and SentencePiece). All dispatch uses `inline else`, so each
/// method compiles to a direct call into the selected backend.
///
/// Fix: `Decoder.string` and `Decoder.ids` take `self` by value but switched
/// on `self.*` — a dereference of a non-pointer, which does not compile.
/// They now switch on `self` directly, matching `Encoder.ids`.
pub const Tokenizer = union(Tokenizers) {
    pub const Encoder = union(Tokenizers) {
        hftokenizers: hftokenizers.Encoder,
        sentencepiece: sentencepiece.Encoder,

        /// Releases backend resources.
        pub fn deinit(self: *Encoder) void {
            switch (self.*) {
                inline else => |*v| v.deinit(),
            }
        }

        /// Clears buffered state; capacity is kept for reuse.
        pub fn reset(self: *Encoder) void {
            switch (self.*) {
                inline else => |*v| v.reset(),
            }
        }

        /// Encodes `input`; the returned ids are backend-owned and
        /// invalidated by the next encode/reset/deinit.
        pub fn encode(self: *Encoder, input: []const u8) ![]const u32 {
            return switch (self.*) {
                inline else => |*v| v.encode(input),
            };
        }

        /// Token ids from the last `encode`.
        pub fn ids(self: Encoder) []const u32 {
            return switch (self) {
                inline else => |v| v.ids(),
            };
        }
    };

    pub const Decoder = union(Tokenizers) {
        hftokenizers: hftokenizers.Decoder,
        sentencepiece: sentencepiece.Decoder,

        /// Releases backend resources.
        pub fn deinit(self: *Decoder) void {
            switch (self.*) {
                inline else => |*v| v.deinit(),
            }
        }

        /// Clears buffered state; capacity is kept for reuse.
        pub fn reset(self: *Decoder) void {
            switch (self.*) {
                inline else => |*v| v.reset(),
            }
        }

        /// One-shot decode of `ids_`; the returned text is backend-owned.
        pub fn decode(self: *Decoder, ids_: []const u32) ![]const u8 {
            return switch (self.*) {
                inline else => |*v| v.decode(ids_),
            };
        }

        /// Text produced by the last decode.
        /// (Fixed: was `switch (self.*)` on a by-value `self`.)
        pub fn string(self: Decoder) []const u8 {
            return switch (self) {
                inline else => |v| v.string(),
            };
        }

        /// Token ids currently buffered by the backend decoder.
        /// (Fixed: was `switch (self.*)` on a by-value `self`.)
        pub fn ids(self: Decoder) []u32 {
            return switch (self) {
                inline else => |v| v.ids(),
            };
        }

        /// Streaming decode: feeds one token id and returns the newly
        /// produced text chunk, if any.
        pub fn next(self: *Decoder, token_id: u32) !?[]const u8 {
            return switch (self.*) {
                inline else => |*v| v.next(token_id),
            };
        }
    };

    hftokenizers: *hftokenizers.HFTokenizer,
    sentencepiece: *sentencepiece.SentencePieceProcessor,

    /// Loads a tokenizer, choosing the backend from the file extension:
    /// `.pb` -> SentencePiece model proto, `.json` -> HuggingFace tokenizer.
    /// Returns `error.InvalidArgument` for anything else. The allocator is
    /// currently unused but kept in the signature for API stability.
    pub fn from_file(_: std.mem.Allocator, model: []const u8) !Tokenizer {
        if (std.mem.endsWith(u8, model, ".pb")) {
            // Model loading does blocking file I/O; keep it off the event loop.
            return .{ .sentencepiece = try asynk.callBlocking(sentencepiece.SentencePieceProcessor.from_file, .{model}) };
        }
        if (std.mem.endsWith(u8, model, ".json")) {
            return .{ .hftokenizers = try asynk.callBlocking(hftokenizers.HFTokenizer.from_file, .{model}) };
        }
        return error.InvalidArgument;
    }

    /// Releases the underlying backend tokenizer.
    pub fn deinit(self: *Tokenizer) void {
        switch (self.*) {
            inline else => |t| t.deinit(),
        }
    }

    /// Creates an encoder for this tokenizer's backend, wrapped in the
    /// matching `Encoder` union variant.
    pub fn encoder(self: Tokenizer) !Encoder {
        return switch (self) {
            inline else => |v, tag| @unionInit(Encoder, @tagName(tag), try v.encoder()),
        };
    }

    /// Creates a decoder for this tokenizer's backend, wrapped in the
    /// matching `Decoder` union variant.
    pub fn decoder(self: Tokenizer) !Decoder {
        return switch (self) {
            inline else => |v, tag| @unionInit(Decoder, @tagName(tag), try v.decoder()),
        };
    }

    /// Looks up a token string's id in the backend vocabulary.
    pub fn token_to_id(self: Tokenizer, token: []const u8) ?u32 {
        return switch (self) {
            inline else => |v| v.token_to_id(token),
        };
    }
};
|
||||||
@ -29,7 +29,8 @@ pub const pjrt = @import("pjrtx.zig");
|
|||||||
pub const testing = @import("testing.zig");
|
pub const testing = @import("testing.zig");
|
||||||
pub const torch = @import("torch.zig");
|
pub const torch = @import("torch.zig");
|
||||||
|
|
||||||
pub const tokenizer = @import("tokenizer.zig");
|
// pub const tokenizer = @import("tokenizer.zig");
|
||||||
|
pub const tokenizer = @import("zml/tokenizer");
|
||||||
|
|
||||||
pub const call = ops.call;
|
pub const call = ops.call;
|
||||||
pub const compile = exe.compile;
|
pub const compile = exe.compile;
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user