Add async I/O, SentencePiece, NN, and tensor utilities for ModernBERT support and update Bazel build configuration.
This commit is contained in:
parent
17d02621e7
commit
18eb0e5a7b
7
tools/BUILD.bazel
Normal file
7
tools/BUILD.bazel
Normal file
@ -0,0 +1,7 @@
|
||||
load("@rules_python//python:py_library.bzl", "py_library")
|
||||
|
||||
py_library(
|
||||
name = "zml_utils",
|
||||
srcs = ["zml_utils.py"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
56
zml/aio.zig
56
zml/aio.zig
@ -574,6 +574,58 @@ pub fn unloadBuffers(model: anytype) void {
|
||||
}).cb, {}, model);
|
||||
}
|
||||
|
||||
/// Assists in debuggigng `BufferNotFound` error
|
||||
/// This is useful when a buffer key is not found and you want to identify possible alternatives (or typos)
|
||||
fn findSimilarBufferKeys(original_key: []const u8, store: BufferStore, temp_allocator: std.mem.Allocator) void {
|
||||
const suffixes = [_][]const u8{ "", ".weight", ".bias" };
|
||||
var shown_keys = std.StringHashMap(void).init(temp_allocator);
|
||||
defer shown_keys.deinit();
|
||||
|
||||
// remove suffix .weight and .bias
|
||||
var base_key = original_key;
|
||||
for (suffixes) |suffix| {
|
||||
if (std.mem.endsWith(u8, original_key, suffix)) {
|
||||
base_key = original_key[0 .. original_key.len - suffix.len];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// first test: look for exact matches
|
||||
var matches: usize = 0;
|
||||
var it = store.buffers.iterator();
|
||||
while (it.next()) |entry| {
|
||||
const key = entry.key_ptr.*;
|
||||
if (std.mem.startsWith(u8, key, base_key)) {
|
||||
if (matches == 0) log.warn("Similar buffers found:", .{});
|
||||
if (!shown_keys.contains(key)) {
|
||||
log.warn(" - {s}: {}", .{ key, entry.value_ptr.*.shape() });
|
||||
shown_keys.put(key, {}) catch continue;
|
||||
matches += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// second test: progressive partial matches
|
||||
if (matches == 0) {
|
||||
var components = std.mem.splitScalar(u8, base_key, '.');
|
||||
while (components.next()) |component| {
|
||||
matches = 0;
|
||||
it = store.buffers.iterator();
|
||||
while (it.next()) |entry| {
|
||||
const key = entry.key_ptr.*;
|
||||
if (std.mem.indexOf(u8, key, component) != null and !shown_keys.contains(key)) {
|
||||
if (matches == 0) log.warn("Partial matches for '{s}':", .{component});
|
||||
log.warn(" - {s}: {}", .{ key, entry.value_ptr.*.shape() });
|
||||
shown_keys.put(key, {}) catch continue;
|
||||
matches += 1;
|
||||
if (matches >= 5) break;
|
||||
}
|
||||
}
|
||||
if (matches > 0) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// deinit all buffers in the given struct
|
||||
pub fn awaitAll(buffers: anytype) !void {
|
||||
// TODO: implement once we have async buffers.
|
||||
@ -600,6 +652,10 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
||||
buf_with_metadata._shape = obj._shape;
|
||||
obj.* = try zml.Buffer.from(platform, buf_with_metadata);
|
||||
} else {
|
||||
log.err("Buffer not found: {s}", .{prefix});
|
||||
|
||||
findSimilarBufferKeys(prefix, buffer_store, allocator);
|
||||
|
||||
return error.BufferNotFound;
|
||||
};
|
||||
} else if (T == zml.Shape) return;
|
||||
|
||||
@ -31,6 +31,7 @@ pub fn normalizerFromSpec(spec: sentencepiece_proto.NormalizerSpec) Normalizer {
|
||||
.add_dummy_suffix = false,
|
||||
.lower_case_ascii = false,
|
||||
.split_on_punct_ascii = false,
|
||||
.use_nfc = false,
|
||||
},
|
||||
if (spec.escape_whitespaces orelse false) Normalizer.sentencepiece_space else null,
|
||||
);
|
||||
|
||||
17
zml/nn.zig
17
zml/nn.zig
@ -69,15 +69,15 @@ pub const Activation = union(enum) {
|
||||
.leakyReLU => |slope| x.leakyReLU(slope),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn elu(x: Tensor, alpha: f32) Tensor {
|
||||
return x.cmp(.GE, Tensor.scalar(0, x.dtype())).select(
|
||||
x,
|
||||
x.exp().addConstant(-1).scale(alpha),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
pub fn elu(x: Tensor, alpha: f32) Tensor {
|
||||
return x.cmp(.GE, Tensor.scalar(0, x.dtype())).select(
|
||||
x,
|
||||
x.exp().addConstant(-1).scale(alpha),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn chainModules(module_list: anytype, input: Tensor) Tensor {
|
||||
const T = @TypeOf(module_list);
|
||||
switch (@typeInfo(T)) {
|
||||
@ -765,8 +765,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
||||
k = k.mul(head_scaling.convert(k.dtype()));
|
||||
|
||||
var attn_weights = q.dot(k, .{.hd});
|
||||
// log.debug("attn_weights : {}", .{attn_weights});
|
||||
// log.debug("attn_mask : {?}", .{attn_mask});
|
||||
// log.debug("attn_weights : {}, attn_mask : {?}", .{ attn_weights, attn_mask });
|
||||
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
|
||||
attn_weights = attn_weights.convert(.f32).softmax(.k).convert(q.dtype());
|
||||
|
||||
|
||||
@ -1262,8 +1262,15 @@ pub const Tensor = struct {
|
||||
/// Returns a Tensor containing the softmax function applied to each element of the input Tensor.
|
||||
pub fn softmax(self: Tensor, axis_: anytype) Tensor {
|
||||
const a = self.axis(axis_);
|
||||
const max_val = self.max(a);
|
||||
const row_mask = max_val.cmp(.GT, Tensor.scalar(-std.math.inf(f64), self.dtype()));
|
||||
|
||||
const exp_diff_max = self.sub(self.max(a).broad(self._shape)).exp();
|
||||
return exp_diff_max.div(exp_diff_max.sum(a).broad(self._shape));
|
||||
const res = exp_diff_max.div(exp_diff_max.sum(a).broad(self._shape));
|
||||
|
||||
// If a row is full -inf return full 0 instead of full nan,
|
||||
// this fix attention when mask hides a full row.
|
||||
return row_mask.broad(self.shape()).select(res, Tensor.scalar(0, self.dtype()));
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing the log of the sum of exponential over the given axis.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user