Update MNIST and ModernBERT example scripts to use the new gather and topk APIs.

This commit is contained in:
Foke Singh 2025-09-24 15:42:09 +00:00
parent e641d05dd2
commit 7264fff493
2 changed files with 2 additions and 2 deletions

View File

@ -27,7 +27,7 @@ const Mnist = struct {
/// just two linear layers + relu activation /// just two linear layers + relu activation
pub fn forward(self: Mnist, input: zml.Tensor) zml.Tensor { pub fn forward(self: Mnist, input: zml.Tensor) zml.Tensor {
// std.log.info("Compiling for target: {s}", .{@tagName(input.getContext().target())}); // std.log.info("Compiling for target: {s}", .{@tagName(input.getContext().target())});
var x = input.flattenAll().convert(.f32); var x = input.flatten().convert(.f32);
const layers: []const Layer = &.{ self.fc1, self.fc2 }; const layers: []const Layer = &.{ self.fc1, self.fc2 };
for (layers) |layer| { for (layers) |layer| {
x = layer.forward(x); x = layer.forward(x);

View File

@ -43,7 +43,7 @@ pub const ModernBertForMaskedLM = struct {
const biased_logits = logits.add(self.decoder.bias.withTags(.{.voc}).broad(logits.shape())); const biased_logits = logits.add(self.decoder.bias.withTags(.{.voc}).broad(logits.shape()));
const probabilities = biased_logits.softmax(.voc); const probabilities = biased_logits.softmax(.voc);
return probabilities.topK(5, .voc, .{ .descending = true }); return probabilities.topK(.{ .best_words = .voc }, 5, .{ .descending = true });
} }
}; };