diff --git a/examples/benchmark/main.zig b/examples/benchmark/main.zig index 8d7604e..cbbe330 100644 --- a/examples/benchmark/main.zig +++ b/examples/benchmark/main.zig @@ -5,11 +5,16 @@ const flags = @import("tigerbeetle/flags"); const async_ = asynk.async_; +// set log level to debug to print the generated IR +pub const std_options = .{ + .log_level = .debug, +}; + /// Model definition const Benchmark = struct { pub fn forward(self: Benchmark, a: zml.Tensor, b: zml.Tensor) zml.Tensor { _ = self; - return a.dot(b, .{.k}); + return a.withSharding(.{.k}).dot(b.withSharding(.{.k}), .{.k}).withSharding(.{.m}); } }; @@ -68,7 +73,7 @@ pub fn asyncMain() !void { deviceKind, }); // we only list 1 CPU device - if (target == .cpu) break; + if (target == .cpu and platform.sharding().num_partitions == 1) break; } } }