The main issue with the current `scatter` implementation is that it relies on the batching dims of `stablehlo.scatter`. While nice in theory, the optimizer doesn't handle them well, and they are often unrolled into a while loop. Here I convert the batching dim into extra iota indices.
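
To illustrate the idea of replacing a batching dim with iota indices, here is a minimal sketch in plain Python. It is not the ZML implementation and does not call any StableHLO API: nested lists stand in for tensors, and the function names (`scatter_batched`, `iota_indices`, `scatter_unbatched`) are hypothetical. It assumes a single batching dim at axis 0 and one scattered axis.

```python
def scatter_batched(operand, indices, updates):
    """Reference semantics: axis 0 is a batching dim handled implicitly,
    so each batch row is scattered on its own."""
    out = [row[:] for row in operand]
    for b, (idx_row, upd_row) in enumerate(zip(indices, updates)):
        for i, u in zip(idx_row, upd_row):
            out[b][i] = u
    return out


def iota_indices(indices):
    """Pair every index with an iota over the batch axis,
    turning indices[b][k] into the full coordinate (b, indices[b][k])."""
    return [[(b, i) for i in idx_row] for b, idx_row in enumerate(indices)]


def scatter_unbatched(operand, coords, updates):
    """Scatter without any batching dim: every update carries
    explicit (row, column) coordinates into the operand."""
    out = [row[:] for row in operand]
    for coord_row, upd_row in zip(coords, updates):
        for (b, i), u in zip(coord_row, upd_row):
            out[b][i] = u
    return out


operand = [[0, 0, 0, 0], [0, 0, 0, 0]]
indices = [[1, 3], [0, 2]]          # per-batch column indices
updates = [[10, 11], [20, 21]]

coords = iota_indices(indices)      # batch axis made explicit
assert scatter_unbatched(operand, coords, updates) == scatter_batched(operand, indices, updates)
```

Both paths produce the same result. The difference is that the second one expresses the batch position as ordinary index data, so the scatter itself needs no batching dims.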