From 01db09c24b8b6822fa10375adbe8925b5b57c8fe Mon Sep 17 00:00:00 2001 From: Foke Singh Date: Wed, 23 Jul 2025 12:53:46 +0000 Subject: [PATCH] Update MNIST example (BUILD.bazel and mnist.zig) to remove torch loader references. --- examples/mnist/BUILD.bazel | 4 ++-- examples/mnist/mnist.zig | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/mnist/BUILD.bazel b/examples/mnist/BUILD.bazel index dfc635b..11ac0c6 100644 --- a/examples/mnist/BUILD.bazel +++ b/examples/mnist/BUILD.bazel @@ -4,11 +4,11 @@ load("@rules_zig//zig:defs.bzl", "zig_binary") zig_binary( name = "mnist", args = [ - "$(location @mnist//:mnist.pt)", + "$(location @mnist//:mnist.safetensors)", "$(location @mnist//:t10k-images.idx3-ubyte)", ], data = [ - "@mnist//:mnist.pt", + "@mnist//:mnist.safetensors", "@mnist//:t10k-images.idx3-ubyte", ], main = "mnist.zig", diff --git a/examples/mnist/mnist.zig b/examples/mnist/mnist.zig index d4cfa32..5c85f5b 100644 --- a/examples/mnist/mnist.zig +++ b/examples/mnist/mnist.zig @@ -30,7 +30,7 @@ const Mnist = struct { var x = input.flattenAll().convert(.f32); const layers: []const Layer = &.{ self.fc1, self.fc2 }; for (layers) |layer| { - x = zml.call(layer, .forward, .{x}); + x = layer.forward(x); } return x.argMax(0).indices.convert(.u8); } @@ -66,7 +66,7 @@ pub fn asyncMain() !void { // Read model shapes. // Note this works because Mnist struct uses the same layer names as the pytorch model - var buffer_store = try zml.aio.torch.open(allocator, pt_model); + var buffer_store = try zml.aio.detectFormatAndOpen(allocator, pt_model); defer buffer_store.deinit(); const mnist_model = try zml.aio.populateModel(Mnist, allocator, buffer_store);