diff --git a/tools/BUILD.bazel b/tools/BUILD.bazel
new file mode 100644
index 0000000..5dc3d21
--- /dev/null
+++ b/tools/BUILD.bazel
@@ -0,0 +1,7 @@
+load("@rules_python//python:py_library.bzl", "py_library")
+
+py_library(
+    name = "zml_utils",
+    srcs = ["zml_utils.py"],
+    visibility = ["//visibility:public"],
+)
diff --git a/zml/aio.zig b/zml/aio.zig
index 27abbd8..c9e5191 100644
--- a/zml/aio.zig
+++ b/zml/aio.zig
@@ -574,6 +574,58 @@ pub fn unloadBuffers(model: anytype) void {
     }).cb, {}, model);
 }
 
+/// Assists in debugging `BufferNotFound` error
+/// This is useful when a buffer key is not found and you want to identify possible alternatives (or typos)
+fn findSimilarBufferKeys(original_key: []const u8, store: BufferStore, temp_allocator: std.mem.Allocator) void {
+    const suffixes = [_][]const u8{ "", ".weight", ".bias" };
+    var shown_keys = std.StringHashMap(void).init(temp_allocator);
+    defer shown_keys.deinit();
+
+    // remove suffix .weight and .bias
+    var base_key = original_key;
+    for (suffixes) |suffix| {
+        if (std.mem.endsWith(u8, original_key, suffix)) {
+            base_key = original_key[0 .. original_key.len - suffix.len];
+            break;
+        }
+    }
+
+    // first pass: look for keys sharing the same prefix
+    var matches: usize = 0;
+    var it = store.buffers.iterator();
+    while (it.next()) |entry| {
+        const key = entry.key_ptr.*;
+        if (std.mem.startsWith(u8, key, base_key)) {
+            if (matches == 0) log.warn("Similar buffers found:", .{});
+            if (!shown_keys.contains(key)) {
+                log.warn(" - {s}: {}", .{ key, entry.value_ptr.*.shape() });
+                shown_keys.put(key, {}) catch continue;
+                matches += 1;
+            }
+        }
+    }
+
+    // second pass: progressive partial matches
+    if (matches == 0) {
+        var components = std.mem.splitScalar(u8, base_key, '.');
+        while (components.next()) |component| {
+            matches = 0;
+            it = store.buffers.iterator();
+            while (it.next()) |entry| {
+                const key = entry.key_ptr.*;
+                if (std.mem.indexOf(u8, key, component) != null and !shown_keys.contains(key)) {
+                    if (matches == 0) log.warn("Partial matches for '{s}':", .{component});
+                    log.warn(" - {s}: {}", .{ key, entry.value_ptr.*.shape() });
+                    shown_keys.put(key, {}) catch continue;
+                    matches += 1;
+                    if (matches >= 5) break;
+                }
+            }
+            if (matches > 0) break;
+        }
+    }
+}
+
 /// deinit all buffers in the given struct
 pub fn awaitAll(buffers: anytype) !void {
     // TODO: implement once we have async buffers.
@@ -600,6 +652,10 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
             buf_with_metadata._shape = obj._shape;
             obj.* = try zml.Buffer.from(platform, buf_with_metadata);
         } else {
+            log.err("Buffer not found: {s}", .{prefix});
+
+            findSimilarBufferKeys(prefix, buffer_store, allocator);
+
             return error.BufferNotFound;
         };
     } else if (T == zml.Shape) return;
diff --git a/zml/aio/sentencepiece.zig b/zml/aio/sentencepiece.zig
index b459319..298d8fd 100644
--- a/zml/aio/sentencepiece.zig
+++ b/zml/aio/sentencepiece.zig
@@ -31,6 +31,7 @@ pub fn normalizerFromSpec(spec: sentencepiece_proto.NormalizerSpec) Normalizer {
         .add_dummy_suffix = false,
         .lower_case_ascii = false,
         .split_on_punct_ascii = false,
+        .use_nfc = false,
     },
     if (spec.escape_whitespaces orelse false) Normalizer.sentencepiece_space else null,
 );
diff --git a/zml/nn.zig b/zml/nn.zig
index 7d10a3a..6aac847 100644
--- a/zml/nn.zig
+++ b/zml/nn.zig
@@ -69,15 +69,15 @@ pub const Activation = union(enum) {
             .leakyReLU => |slope| x.leakyReLU(slope),
         };
     }
-
-    pub fn elu(x: Tensor, alpha: f32) Tensor {
-        return x.cmp(.GE, Tensor.scalar(0, x.dtype())).select(
-            x,
-            x.exp().addConstant(-1).scale(alpha),
-        );
-    }
 };
 
+pub fn elu(x: Tensor, alpha: f32) Tensor {
+    return x.cmp(.GE, Tensor.scalar(0, x.dtype())).select(
+        x,
+        x.exp().addConstant(-1).scale(alpha),
+    );
+}
+
 pub fn chainModules(module_list: anytype, input: Tensor) Tensor {
     const T = @TypeOf(module_list);
     switch (@typeInfo(T)) {
@@ -765,8 +765,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
     k = k.mul(head_scaling.convert(k.dtype()));
 
     var attn_weights = q.dot(k, .{.hd});
-    // log.debug("attn_weights : {}", .{attn_weights});
-    // log.debug("attn_mask : {?}", .{attn_mask});
+    // log.debug("attn_weights : {}, attn_mask : {?}", .{ attn_weights, attn_mask });
 
     if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
     attn_weights = attn_weights.convert(.f32).softmax(.k).convert(q.dtype());
diff --git a/zml/tensor.zig b/zml/tensor.zig
index 15543d7..a385f25 100644
--- a/zml/tensor.zig
+++ b/zml/tensor.zig
@@ -1262,8 +1262,15 @@ pub const Tensor = struct {
     /// Returns a Tensor containing the softmax function applied to each element of the input Tensor.
     pub fn softmax(self: Tensor, axis_: anytype) Tensor {
         const a = self.axis(axis_);
+        const max_val = self.max(a);
+        const row_mask = max_val.cmp(.GT, Tensor.scalar(-std.math.inf(f64), self.dtype()));
+
         const exp_diff_max = self.sub(self.max(a).broad(self._shape)).exp();
-        return exp_diff_max.div(exp_diff_max.sum(a).broad(self._shape));
+        const res = exp_diff_max.div(exp_diff_max.sum(a).broad(self._shape));
+
+        // If a row is full -inf return full 0 instead of full nan,
+        // this fixes attention when mask hides a full row.
+        return row_mask.broad(self.shape()).select(res, Tensor.scalar(0, self.dtype()));
     }
 
     /// Returns a Tensor containing the log of the sum of exponential over the given axis.