Update MNIST example (BUILD.bazel and mnist.zig) to remove torch loader references.

This commit is contained in:
Foke Singh 2025-07-23 12:53:46 +00:00
parent 1cf26756a1
commit 01db09c24b
2 changed files with 4 additions and 4 deletions

View File

@ -4,11 +4,11 @@ load("@rules_zig//zig:defs.bzl", "zig_binary")
zig_binary( zig_binary(
name = "mnist", name = "mnist",
args = [ args = [
"$(location @mnist//:mnist.pt)", "$(location @mnist//:mnist.safetensors)",
"$(location @mnist//:t10k-images.idx3-ubyte)", "$(location @mnist//:t10k-images.idx3-ubyte)",
], ],
data = [ data = [
"@mnist//:mnist.pt", "@mnist//:mnist.safetensors",
"@mnist//:t10k-images.idx3-ubyte", "@mnist//:t10k-images.idx3-ubyte",
], ],
main = "mnist.zig", main = "mnist.zig",

View File

@ -30,7 +30,7 @@ const Mnist = struct {
var x = input.flattenAll().convert(.f32); var x = input.flattenAll().convert(.f32);
const layers: []const Layer = &.{ self.fc1, self.fc2 }; const layers: []const Layer = &.{ self.fc1, self.fc2 };
for (layers) |layer| { for (layers) |layer| {
x = zml.call(layer, .forward, .{x}); x = layer.forward(x);
} }
return x.argMax(0).indices.convert(.u8); return x.argMax(0).indices.convert(.u8);
} }
@ -66,7 +66,7 @@ pub fn asyncMain() !void {
// Read model shapes. // Read model shapes.
// Note this works because Mnist struct uses the same layer names as the pytorch model // Note this works because Mnist struct uses the same layer names as the pytorch model
var buffer_store = try zml.aio.torch.open(allocator, pt_model); var buffer_store = try zml.aio.detectFormatAndOpen(allocator, pt_model);
defer buffer_store.deinit(); defer buffer_store.deinit();
const mnist_model = try zml.aio.populateModel(Mnist, allocator, buffer_store); const mnist_model = try zml.aio.populateModel(Mnist, allocator, buffer_store);