Main issue with current `scatter` implementation is that it uses
broadcasting dims of `stablehlo.scatter`.
While nice in theory, the optimizer doesn't handle them well and they
often are unrolled into while loop.
Here I convert the batching dim to extra iotas indices.