Add async I/O, SentencePiece, NN, and tensor utilities for ModernBERT support and update Bazel build configuration.

This commit is contained in:
Tarry Singh 2024-06-14 15:27:06 +00:00
parent 17d02621e7
commit 18eb0e5a7b
5 changed files with 80 additions and 10 deletions

7
tools/BUILD.bazel Normal file
View File

@ -0,0 +1,7 @@
# Load the Python library rule from rules_python.
load("@rules_python//python:py_library.bzl", "py_library")
# Python helper library wrapping zml_utils.py; public so any package in the
# workspace may depend on it.
py_library(
    name = "zml_utils",
    srcs = ["zml_utils.py"],
    visibility = ["//visibility:public"],
)

View File

@ -574,6 +574,58 @@ pub fn unloadBuffers(model: anytype) void {
}).cb, {}, model); }).cb, {}, model);
} }
/// Assists in debugging `BufferNotFound` errors.
/// When a buffer key is not found, logs a handful of similar keys from the
/// store to help identify possible alternatives (or typos).
/// Best-effort diagnostic: allocation failures are swallowed, never propagated.
fn findSimilarBufferKeys(original_key: []const u8, store: BufferStore, temp_allocator: std.mem.Allocator) void {
    // Cap suggestions per pass so the log stays readable on large stores.
    const max_suggestions = 5;
    const suffixes = [_][]const u8{ "", ".weight", ".bias" };
    // Tracks keys already printed so the two passes never repeat a suggestion.
    var shown_keys = std.StringHashMap(void).init(temp_allocator);
    defer shown_keys.deinit();
    // Strip a trailing ".weight" / ".bias" so both variants of a layer match.
    var base_key = original_key;
    for (suffixes) |suffix| {
        if (std.mem.endsWith(u8, original_key, suffix)) {
            base_key = original_key[0 .. original_key.len - suffix.len];
            break;
        }
    }
    // First pass: keys sharing the full base prefix.
    var matches: usize = 0;
    var it = store.buffers.iterator();
    while (it.next()) |entry| {
        const key = entry.key_ptr.*;
        if (!std.mem.startsWith(u8, key, base_key)) continue;
        if (shown_keys.contains(key)) continue;
        // Record before logging so an OOM in `put` can't print a key that is
        // never counted (which could re-print the header on the next hit).
        shown_keys.put(key, {}) catch continue;
        if (matches == 0) log.warn("Similar buffers found:", .{});
        log.warn(" - {s}: {}", .{ key, entry.value_ptr.*.shape() });
        matches += 1;
        if (matches >= max_suggestions) break;
    }
    // Second pass: no prefix match — try each dot-separated component of the
    // key individually, stopping at the first component that yields matches.
    if (matches == 0) {
        var components = std.mem.splitScalar(u8, base_key, '.');
        while (components.next()) |component| {
            matches = 0;
            it = store.buffers.iterator();
            while (it.next()) |entry| {
                const key = entry.key_ptr.*;
                if (std.mem.indexOf(u8, key, component) == null) continue;
                if (shown_keys.contains(key)) continue;
                shown_keys.put(key, {}) catch continue;
                if (matches == 0) log.warn("Partial matches for '{s}':", .{component});
                log.warn(" - {s}: {}", .{ key, entry.value_ptr.*.shape() });
                matches += 1;
                if (matches >= max_suggestions) break;
            }
            if (matches > 0) break;
        }
    }
}
/// deinit all buffers in the given struct /// deinit all buffers in the given struct
pub fn awaitAll(buffers: anytype) !void { pub fn awaitAll(buffers: anytype) !void {
// TODO: implement once we have async buffers. // TODO: implement once we have async buffers.
@ -600,6 +652,10 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
buf_with_metadata._shape = obj._shape; buf_with_metadata._shape = obj._shape;
obj.* = try zml.Buffer.from(platform, buf_with_metadata); obj.* = try zml.Buffer.from(platform, buf_with_metadata);
} else { } else {
log.err("Buffer not found: {s}", .{prefix});
findSimilarBufferKeys(prefix, buffer_store, allocator);
return error.BufferNotFound; return error.BufferNotFound;
}; };
} else if (T == zml.Shape) return; } else if (T == zml.Shape) return;

View File

@ -31,6 +31,7 @@ pub fn normalizerFromSpec(spec: sentencepiece_proto.NormalizerSpec) Normalizer {
.add_dummy_suffix = false, .add_dummy_suffix = false,
.lower_case_ascii = false, .lower_case_ascii = false,
.split_on_punct_ascii = false, .split_on_punct_ascii = false,
.use_nfc = false,
}, },
if (spec.escape_whitespaces orelse false) Normalizer.sentencepiece_space else null, if (spec.escape_whitespaces orelse false) Normalizer.sentencepiece_space else null,
); );

View File

@ -69,15 +69,15 @@ pub const Activation = union(enum) {
.leakyReLU => |slope| x.leakyReLU(slope), .leakyReLU => |slope| x.leakyReLU(slope),
}; };
} }
/// Exponential Linear Unit: x for x >= 0, alpha * (exp(x) - 1) otherwise.
pub fn elu(x: Tensor, alpha: f32) Tensor {
    const zero = Tensor.scalar(0, x.dtype());
    const negative_branch = x.exp().addConstant(-1).scale(alpha);
    return x.cmp(.GE, zero).select(x, negative_branch);
}
}; };
pub fn elu(x: Tensor, alpha: f32) Tensor {
return x.cmp(.GE, Tensor.scalar(0, x.dtype())).select(
x,
x.exp().addConstant(-1).scale(alpha),
);
}
pub fn chainModules(module_list: anytype, input: Tensor) Tensor { pub fn chainModules(module_list: anytype, input: Tensor) Tensor {
const T = @TypeOf(module_list); const T = @TypeOf(module_list);
switch (@typeInfo(T)) { switch (@typeInfo(T)) {
@ -765,8 +765,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
k = k.mul(head_scaling.convert(k.dtype())); k = k.mul(head_scaling.convert(k.dtype()));
var attn_weights = q.dot(k, .{.hd}); var attn_weights = q.dot(k, .{.hd});
// log.debug("attn_weights : {}", .{attn_weights}); // log.debug("attn_weights : {}, attn_mask : {?}", .{ attn_weights, attn_mask });
// log.debug("attn_mask : {?}", .{attn_mask});
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape())); if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
attn_weights = attn_weights.convert(.f32).softmax(.k).convert(q.dtype()); attn_weights = attn_weights.convert(.f32).softmax(.k).convert(q.dtype());

View File

@ -1262,8 +1262,15 @@ pub const Tensor = struct {
/// Returns a Tensor containing the softmax function applied to each element of the input Tensor. /// Returns a Tensor containing the softmax function applied to each element of the input Tensor.
pub fn softmax(self: Tensor, axis_: anytype) Tensor { pub fn softmax(self: Tensor, axis_: anytype) Tensor {
const a = self.axis(axis_); const a = self.axis(axis_);
const max_val = self.max(a);
const row_mask = max_val.cmp(.GT, Tensor.scalar(-std.math.inf(f64), self.dtype()));
const exp_diff_max = self.sub(self.max(a).broad(self._shape)).exp(); const exp_diff_max = self.sub(self.max(a).broad(self._shape)).exp();
return exp_diff_max.div(exp_diff_max.sum(a).broad(self._shape)); const res = exp_diff_max.div(exp_diff_max.sum(a).broad(self._shape));
// If a row is entirely -inf, return all zeros instead of all NaNs;
// this fixes attention when the mask hides a full row.
return row_mask.broad(self.shape()).select(res, Tensor.scalar(0, self.dtype()));
} }
/// Returns a Tensor containing the log of the sum of exponential over the given axis. /// Returns a Tensor containing the log of the sum of exponential over the given axis.