Update MNIST and ModernBERT example scripts to use the new gather and topk APIs.
This commit is contained in:
parent
e641d05dd2
commit
7264fff493
@ -27,7 +27,7 @@ const Mnist = struct {
|
|||||||
/// just two linear layers + relu activation
|
/// just two linear layers + relu activation
|
||||||
pub fn forward(self: Mnist, input: zml.Tensor) zml.Tensor {
|
pub fn forward(self: Mnist, input: zml.Tensor) zml.Tensor {
|
||||||
// std.log.info("Compiling for target: {s}", .{@tagName(input.getContext().target())});
|
// std.log.info("Compiling for target: {s}", .{@tagName(input.getContext().target())});
|
||||||
var x = input.flattenAll().convert(.f32);
|
var x = input.flatten().convert(.f32);
|
||||||
const layers: []const Layer = &.{ self.fc1, self.fc2 };
|
const layers: []const Layer = &.{ self.fc1, self.fc2 };
|
||||||
for (layers) |layer| {
|
for (layers) |layer| {
|
||||||
x = layer.forward(x);
|
x = layer.forward(x);
|
||||||
|
|||||||
@ -43,7 +43,7 @@ pub const ModernBertForMaskedLM = struct {
|
|||||||
const biased_logits = logits.add(self.decoder.bias.withTags(.{.voc}).broad(logits.shape()));
|
const biased_logits = logits.add(self.decoder.bias.withTags(.{.voc}).broad(logits.shape()));
|
||||||
|
|
||||||
const probabilities = biased_logits.softmax(.voc);
|
const probabilities = biased_logits.softmax(.voc);
|
||||||
return probabilities.topK(5, .voc, .{ .descending = true });
|
return probabilities.topK(.{ .best_words = .voc }, 5, .{ .descending = true });
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user