Update benchmark example to use new user sharding hints and drop deprecated module options.
This commit is contained in:
parent
8746a5ce78
commit
e30e35deeb
@ -5,11 +5,16 @@ const flags = @import("tigerbeetle/flags");
|
|||||||
|
|
||||||
const async_ = asynk.async_;
|
const async_ = asynk.async_;
|
||||||
|
|
||||||
|
// set log level to debug to print the generated IR
|
||||||
|
pub const std_options = .{
|
||||||
|
.log_level = .debug,
|
||||||
|
};
|
||||||
|
|
||||||
/// Model definition
|
/// Model definition
|
||||||
const Benchmark = struct {
|
const Benchmark = struct {
|
||||||
pub fn forward(self: Benchmark, a: zml.Tensor, b: zml.Tensor) zml.Tensor {
|
pub fn forward(self: Benchmark, a: zml.Tensor, b: zml.Tensor) zml.Tensor {
|
||||||
_ = self;
|
_ = self;
|
||||||
return a.dot(b, .{.k});
|
return a.withSharding(.{.k}).dot(b.withSharding(.{.k}), .{.k}).withSharding(.{.m});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -68,7 +73,7 @@ pub fn asyncMain() !void {
|
|||||||
deviceKind,
|
deviceKind,
|
||||||
});
|
});
|
||||||
// we only list 1 CPU device
|
// we only list 1 CPU device
|
||||||
if (target == .cpu) break;
|
if (target == .cpu and platform.sharding().num_partitions == 1) break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user