Update MNIST and ModernBERT example scripts to use the new gather and topk APIs.
This commit is contained in:
parent
e641d05dd2
commit
7264fff493
@ -27,7 +27,7 @@ const Mnist = struct {
|
||||
/// just two linear layers + relu activation
|
||||
pub fn forward(self: Mnist, input: zml.Tensor) zml.Tensor {
|
||||
// std.log.info("Compiling for target: {s}", .{@tagName(input.getContext().target())});
|
||||
var x = input.flattenAll().convert(.f32);
|
||||
var x = input.flatten().convert(.f32);
|
||||
const layers: []const Layer = &.{ self.fc1, self.fc2 };
|
||||
for (layers) |layer| {
|
||||
x = layer.forward(x);
|
||||
|
||||
@ -43,7 +43,7 @@ pub const ModernBertForMaskedLM = struct {
|
||||
const biased_logits = logits.add(self.decoder.bias.withTags(.{.voc}).broad(logits.shape()));
|
||||
|
||||
const probabilities = biased_logits.softmax(.voc);
|
||||
return probabilities.topK(5, .voc, .{ .descending = true });
|
||||
return probabilities.topK(.{ .best_words = .voc }, 5, .{ .descending = true });
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user