diff --git a/mlir/mlir.zig b/mlir/mlir.zig index de5087a..1bb1727 100644 --- a/mlir/mlir.zig +++ b/mlir/mlir.zig @@ -1509,6 +1509,10 @@ pub const RankedTensorType = struct { pub fn getDimension(self: RankedTensorType, dim: usize) i64 { return c.mlirShapedTypeGetDimSize(self.inner(), @intCast(dim)); } + + pub fn asType(self: RankedTensorType) Type { + return self.as(Type).?; + } }; pub const Dialect = struct { diff --git a/zml/tensor.zig b/zml/tensor.zig index 554aa4b..ba79165 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -1746,13 +1746,18 @@ pub const Tensor = struct { } /// Returns a Tensor containing values in increasing order starting from 0 along the given axis. - pub fn iota(sh: Shape, dt: DataType, axis_: anytype) Tensor { - const ctx = CompilationContext.current(); - const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "iota({}, {})", .{ sh, axis_ }); - + /// + /// The output dtype will be `.i32`, unless the given axis has a dimension too big for `.i32`, in which case we use `.i64`. + /// In most programs this shouldn't matter, because typically this will be used in a comparison, + /// or explicitly converted by the user to do floating point arithmetic. + pub fn iota(sh: Shape, axis_: anytype) Tensor { const a = sh.axis(axis_); + const dt: DataType = if (sh.dim(a) <= std.math.maxInt(i32)) .i32 else .i64; const res_shape = sh.withDtype(dt); - var op = dialect.stablehlo.iota(ctx.mlirCtx(), @intCast(a), mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type).?, loc); + const mlir_ctx = CompilationContext.current().mlirCtx(); + const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "iota({}, {})", .{ res_shape, axis_ }); + + var op = dialect.stablehlo.iota(mlir_ctx, a, mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).asType(), loc); return _result(res_shape, op.result(0)); } @@ -2694,7 +2699,7 @@ pub const Tensor = struct { pub const SortRes = ArgMaxRes; /// Returns two Tensors. 
The first contains the sorted values and the second one contains the sorted indices. - pub fn sort(self: Tensor, axis_: i64, opts: struct { descending: bool = false }) SortRes { + pub fn sort(self: Tensor, axis_: anytype, opts: struct { descending: bool = false }) SortRes { const a = self.axis(axis_); const indices = Tensor.arange(.{ .end = self.dim(a) }, .i32).broadcast(self._shape, &.{a}); const res = ops.sort( @@ -2823,7 +2828,7 @@ pub const Tensor = struct { return ops.reduceWindow( MaxPoolRes.cmp, - .{ .values = self, .indices = iota(self._shape, .i32, a) }, + .{ .values = self, .indices = iota(self._shape, a) }, .{ .values = Tensor.constant(.{}, self.dtype().minValue()), .indices = Tensor.scalar(0, .i32) }, .{ .window_dimensions = window_dimensions[0..self.rank()], @@ -2861,7 +2866,7 @@ pub const Tensor = struct { return ops.reduceWindow( MaxPoolRes.cmp, - .{ .values = self, .indices = iota(self._shape, .i32, a) }, + .{ .values = self, .indices = iota(self._shape, a) }, .{ .values = Tensor.constant(.{}, self.dtype().minValue()), .indices = Tensor.scalar(0, .i32) }, .{ .window_dimensions = window_dimensions[0..self.rank()], @@ -3336,8 +3341,8 @@ pub const Tensor = struct { const values = self.insertAxes(a + 1, .{new_tags[1]}).broad(res_shape); const zeros = Tensor.constant(res_shape, self.dtype().zero()); - const x = Tensor.iota(res_shape, .i32, a); - const y = Tensor.iota(res_shape, .i32, a + 1); + const x = Tensor.iota(res_shape, a); + const y = Tensor.iota(res_shape, a + 1); var res = x.cmp(.EQ, y).select(values, zeros); res._shape = res_shape; return res; @@ -3386,8 +3391,8 @@ pub const Tensor = struct { meta.assertComptime(meta.isTuple(@TypeOf(axes_)) and axes_.len == 2, "triangular expects exactly two axes to work on.", .{}); const _axes = self.axes(axes_); - const x = Tensor.iota(self.shape(), .i32, _axes.get(0)); - const y = Tensor.iota(self.shape(), .i32, _axes.get(1)); + const x = Tensor.iota(self.shape(), _axes.get(0)); + const y = 
Tensor.iota(self.shape(), _axes.get(1)); const zeros = Tensor.constant(self.shape(), self.dtype().zero()); return x.addConstant(num_diagonals).cmp(.GE, y).select(self, zeros); @@ -3448,6 +3453,14 @@ pub const Tensor = struct { pub fn select(bool_tensor: Tensor, on_true: Tensor, on_false: Tensor) Tensor { meta.assert(bool_tensor.dtype() == .bool, "select expects input tensor type to be a boolean, got {}", .{bool_tensor.dtype()}); meta.assert(on_true.dtype() == on_false.dtype(), "select expects 'on_true' and 'on_false' tensor types to be equal, got {} and {}", .{ on_true.dtype(), on_false.dtype() }); + + if (bool_tensor.rank() != 0 and on_true.rank() == 0) { + return bool_tensor.select(on_true.broad(bool_tensor.shape()), on_false); + } + if (bool_tensor.rank() != 0 and on_false.rank() == 0) { + return bool_tensor.select(on_true, on_false.broad(bool_tensor.shape())); + } + meta.assert(bool_tensor._shape.eqlDims(on_true._shape), "select expects input tensor and 'on_true' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_true._shape }); meta.assert(bool_tensor._shape.eqlDims(on_false._shape), "select expects input tensor and 'on_false' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_false._shape });