Update MNIST example (BUILD.bazel and mnist.zig) to remove torch loader references.
This commit is contained in:
parent
1cf26756a1
commit
01db09c24b
@ -4,11 +4,11 @@ load("@rules_zig//zig:defs.bzl", "zig_binary")
|
|||||||
zig_binary(
|
zig_binary(
|
||||||
name = "mnist",
|
name = "mnist",
|
||||||
args = [
|
args = [
|
||||||
"$(location @mnist//:mnist.pt)",
|
"$(location @mnist//:mnist.safetensors)",
|
||||||
"$(location @mnist//:t10k-images.idx3-ubyte)",
|
"$(location @mnist//:t10k-images.idx3-ubyte)",
|
||||||
],
|
],
|
||||||
data = [
|
data = [
|
||||||
"@mnist//:mnist.pt",
|
"@mnist//:mnist.safetensors",
|
||||||
"@mnist//:t10k-images.idx3-ubyte",
|
"@mnist//:t10k-images.idx3-ubyte",
|
||||||
],
|
],
|
||||||
main = "mnist.zig",
|
main = "mnist.zig",
|
||||||
|
|||||||
@ -30,7 +30,7 @@ const Mnist = struct {
|
|||||||
var x = input.flattenAll().convert(.f32);
|
var x = input.flattenAll().convert(.f32);
|
||||||
const layers: []const Layer = &.{ self.fc1, self.fc2 };
|
const layers: []const Layer = &.{ self.fc1, self.fc2 };
|
||||||
for (layers) |layer| {
|
for (layers) |layer| {
|
||||||
x = zml.call(layer, .forward, .{x});
|
x = layer.forward(x);
|
||||||
}
|
}
|
||||||
return x.argMax(0).indices.convert(.u8);
|
return x.argMax(0).indices.convert(.u8);
|
||||||
}
|
}
|
||||||
@ -66,7 +66,7 @@ pub fn asyncMain() !void {
|
|||||||
|
|
||||||
// Read model shapes.
|
// Read model shapes.
|
||||||
// Note this works because Mnist struct uses the same layer names as the pytorch model
|
// Note this works because Mnist struct uses the same layer names as the pytorch model
|
||||||
var buffer_store = try zml.aio.torch.open(allocator, pt_model);
|
var buffer_store = try zml.aio.detectFormatAndOpen(allocator, pt_model);
|
||||||
defer buffer_store.deinit();
|
defer buffer_store.deinit();
|
||||||
|
|
||||||
const mnist_model = try zml.aio.populateModel(Mnist, allocator, buffer_store);
|
const mnist_model = try zml.aio.populateModel(Mnist, allocator, buffer_store);
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user