Add async I/O, SentencePiece, NN, and tensor utilities for ModernBERT support and update Bazel build configuration.
This commit is contained in:
parent
17d02621e7
commit
18eb0e5a7b
7
tools/BUILD.bazel
Normal file
7
tools/BUILD.bazel
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
load("@rules_python//python:py_library.bzl", "py_library")
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "zml_utils",
|
||||||
|
srcs = ["zml_utils.py"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
56
zml/aio.zig
56
zml/aio.zig
@ -574,6 +574,58 @@ pub fn unloadBuffers(model: anytype) void {
|
|||||||
}).cb, {}, model);
|
}).cb, {}, model);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Assists in debuggigng `BufferNotFound` error
|
||||||
|
/// This is useful when a buffer key is not found and you want to identify possible alternatives (or typos)
|
||||||
|
fn findSimilarBufferKeys(original_key: []const u8, store: BufferStore, temp_allocator: std.mem.Allocator) void {
|
||||||
|
const suffixes = [_][]const u8{ "", ".weight", ".bias" };
|
||||||
|
var shown_keys = std.StringHashMap(void).init(temp_allocator);
|
||||||
|
defer shown_keys.deinit();
|
||||||
|
|
||||||
|
// remove suffix .weight and .bias
|
||||||
|
var base_key = original_key;
|
||||||
|
for (suffixes) |suffix| {
|
||||||
|
if (std.mem.endsWith(u8, original_key, suffix)) {
|
||||||
|
base_key = original_key[0 .. original_key.len - suffix.len];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// first test: look for exact matches
|
||||||
|
var matches: usize = 0;
|
||||||
|
var it = store.buffers.iterator();
|
||||||
|
while (it.next()) |entry| {
|
||||||
|
const key = entry.key_ptr.*;
|
||||||
|
if (std.mem.startsWith(u8, key, base_key)) {
|
||||||
|
if (matches == 0) log.warn("Similar buffers found:", .{});
|
||||||
|
if (!shown_keys.contains(key)) {
|
||||||
|
log.warn(" - {s}: {}", .{ key, entry.value_ptr.*.shape() });
|
||||||
|
shown_keys.put(key, {}) catch continue;
|
||||||
|
matches += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// second test: progressive partial matches
|
||||||
|
if (matches == 0) {
|
||||||
|
var components = std.mem.splitScalar(u8, base_key, '.');
|
||||||
|
while (components.next()) |component| {
|
||||||
|
matches = 0;
|
||||||
|
it = store.buffers.iterator();
|
||||||
|
while (it.next()) |entry| {
|
||||||
|
const key = entry.key_ptr.*;
|
||||||
|
if (std.mem.indexOf(u8, key, component) != null and !shown_keys.contains(key)) {
|
||||||
|
if (matches == 0) log.warn("Partial matches for '{s}':", .{component});
|
||||||
|
log.warn(" - {s}: {}", .{ key, entry.value_ptr.*.shape() });
|
||||||
|
shown_keys.put(key, {}) catch continue;
|
||||||
|
matches += 1;
|
||||||
|
if (matches >= 5) break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (matches > 0) break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// deinit all buffers in the given struct
|
/// deinit all buffers in the given struct
|
||||||
pub fn awaitAll(buffers: anytype) !void {
|
pub fn awaitAll(buffers: anytype) !void {
|
||||||
// TODO: implement once we have async buffers.
|
// TODO: implement once we have async buffers.
|
||||||
@ -600,6 +652,10 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
buf_with_metadata._shape = obj._shape;
|
buf_with_metadata._shape = obj._shape;
|
||||||
obj.* = try zml.Buffer.from(platform, buf_with_metadata);
|
obj.* = try zml.Buffer.from(platform, buf_with_metadata);
|
||||||
} else {
|
} else {
|
||||||
|
log.err("Buffer not found: {s}", .{prefix});
|
||||||
|
|
||||||
|
findSimilarBufferKeys(prefix, buffer_store, allocator);
|
||||||
|
|
||||||
return error.BufferNotFound;
|
return error.BufferNotFound;
|
||||||
};
|
};
|
||||||
} else if (T == zml.Shape) return;
|
} else if (T == zml.Shape) return;
|
||||||
|
|||||||
@ -31,6 +31,7 @@ pub fn normalizerFromSpec(spec: sentencepiece_proto.NormalizerSpec) Normalizer {
|
|||||||
.add_dummy_suffix = false,
|
.add_dummy_suffix = false,
|
||||||
.lower_case_ascii = false,
|
.lower_case_ascii = false,
|
||||||
.split_on_punct_ascii = false,
|
.split_on_punct_ascii = false,
|
||||||
|
.use_nfc = false,
|
||||||
},
|
},
|
||||||
if (spec.escape_whitespaces orelse false) Normalizer.sentencepiece_space else null,
|
if (spec.escape_whitespaces orelse false) Normalizer.sentencepiece_space else null,
|
||||||
);
|
);
|
||||||
|
|||||||
@ -69,14 +69,14 @@ pub const Activation = union(enum) {
|
|||||||
.leakyReLU => |slope| x.leakyReLU(slope),
|
.leakyReLU => |slope| x.leakyReLU(slope),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
pub fn elu(x: Tensor, alpha: f32) Tensor {
|
pub fn elu(x: Tensor, alpha: f32) Tensor {
|
||||||
return x.cmp(.GE, Tensor.scalar(0, x.dtype())).select(
|
return x.cmp(.GE, Tensor.scalar(0, x.dtype())).select(
|
||||||
x,
|
x,
|
||||||
x.exp().addConstant(-1).scale(alpha),
|
x.exp().addConstant(-1).scale(alpha),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
pub fn chainModules(module_list: anytype, input: Tensor) Tensor {
|
pub fn chainModules(module_list: anytype, input: Tensor) Tensor {
|
||||||
const T = @TypeOf(module_list);
|
const T = @TypeOf(module_list);
|
||||||
@ -765,8 +765,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
|||||||
k = k.mul(head_scaling.convert(k.dtype()));
|
k = k.mul(head_scaling.convert(k.dtype()));
|
||||||
|
|
||||||
var attn_weights = q.dot(k, .{.hd});
|
var attn_weights = q.dot(k, .{.hd});
|
||||||
// log.debug("attn_weights : {}", .{attn_weights});
|
// log.debug("attn_weights : {}, attn_mask : {?}", .{ attn_weights, attn_mask });
|
||||||
// log.debug("attn_mask : {?}", .{attn_mask});
|
|
||||||
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
|
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
|
||||||
attn_weights = attn_weights.convert(.f32).softmax(.k).convert(q.dtype());
|
attn_weights = attn_weights.convert(.f32).softmax(.k).convert(q.dtype());
|
||||||
|
|
||||||
|
|||||||
@ -1262,8 +1262,15 @@ pub const Tensor = struct {
|
|||||||
/// Returns a Tensor containing the softmax function applied to each element of the input Tensor.
|
/// Returns a Tensor containing the softmax function applied to each element of the input Tensor.
|
||||||
pub fn softmax(self: Tensor, axis_: anytype) Tensor {
|
pub fn softmax(self: Tensor, axis_: anytype) Tensor {
|
||||||
const a = self.axis(axis_);
|
const a = self.axis(axis_);
|
||||||
|
const max_val = self.max(a);
|
||||||
|
const row_mask = max_val.cmp(.GT, Tensor.scalar(-std.math.inf(f64), self.dtype()));
|
||||||
|
|
||||||
const exp_diff_max = self.sub(self.max(a).broad(self._shape)).exp();
|
const exp_diff_max = self.sub(self.max(a).broad(self._shape)).exp();
|
||||||
return exp_diff_max.div(exp_diff_max.sum(a).broad(self._shape));
|
const res = exp_diff_max.div(exp_diff_max.sum(a).broad(self._shape));
|
||||||
|
|
||||||
|
// If a row is full -inf return full 0 instead of full nan,
|
||||||
|
// this fix attention when mask hides a full row.
|
||||||
|
return row_mask.broad(self.shape()).select(res, Tensor.scalar(0, self.dtype()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the log of the sum of exponential over the given axis.
|
/// Returns a Tensor containing the log of the sum of exponential over the given axis.
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user