Update MNIST example (BUILD.bazel and mnist.zig) to remove torch loader references.
This commit is contained in:
parent
1cf26756a1
commit
01db09c24b
@ -4,11 +4,11 @@ load("@rules_zig//zig:defs.bzl", "zig_binary")
|
||||
zig_binary(
|
||||
name = "mnist",
|
||||
args = [
|
||||
"$(location @mnist//:mnist.pt)",
|
||||
"$(location @mnist//:mnist.safetensors)",
|
||||
"$(location @mnist//:t10k-images.idx3-ubyte)",
|
||||
],
|
||||
data = [
|
||||
"@mnist//:mnist.pt",
|
||||
"@mnist//:mnist.safetensors",
|
||||
"@mnist//:t10k-images.idx3-ubyte",
|
||||
],
|
||||
main = "mnist.zig",
|
||||
|
||||
@ -30,7 +30,7 @@ const Mnist = struct {
|
||||
var x = input.flattenAll().convert(.f32);
|
||||
const layers: []const Layer = &.{ self.fc1, self.fc2 };
|
||||
for (layers) |layer| {
|
||||
x = zml.call(layer, .forward, .{x});
|
||||
x = layer.forward(x);
|
||||
}
|
||||
return x.argMax(0).indices.convert(.u8);
|
||||
}
|
||||
@ -66,7 +66,7 @@ pub fn asyncMain() !void {
|
||||
|
||||
// Read model shapes.
|
||||
// Note this works because Mnist struct uses the same layer names as the pytorch model
|
||||
var buffer_store = try zml.aio.torch.open(allocator, pt_model);
|
||||
var buffer_store = try zml.aio.detectFormatAndOpen(allocator, pt_model);
|
||||
defer buffer_store.deinit();
|
||||
|
||||
const mnist_model = try zml.aio.populateModel(Mnist, allocator, buffer_store);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user