zml: set iota default dtype to .i32, with fallback to .i64 for axes with many elements, simplifying usage.

This commit is contained in:
Tarry Singh 2023-06-15 12:45:52 +00:00
parent 344e07fb6e
commit b244a18621
2 changed files with 29 additions and 12 deletions

View File

@ -1509,6 +1509,10 @@ pub const RankedTensorType = struct {
pub fn getDimension(self: RankedTensorType, dim: usize) i64 { pub fn getDimension(self: RankedTensorType, dim: usize) i64 {
return c.mlirShapedTypeGetDimSize(self.inner(), @intCast(dim)); return c.mlirShapedTypeGetDimSize(self.inner(), @intCast(dim));
} }
/// Returns this RankedTensorType viewed as a generic MLIR `Type`.
/// NOTE(review): the `.?` unwrap panics if the cast returns null — presumably a
/// RankedTensorType always casts to Type, but confirm against the `as` helper.
pub fn asType(self: RankedTensorType) Type {
return self.as(Type).?;
}
}; };
pub const Dialect = struct { pub const Dialect = struct {

View File

@ -1746,13 +1746,18 @@ pub const Tensor = struct {
} }
/// Returns a Tensor containing values in increasing order starting from 0 along the given axis. /// Returns a Tensor containing values in increasing order starting from 0 along the given axis.
pub fn iota(sh: Shape, dt: DataType, axis_: anytype) Tensor { ///
const ctx = CompilationContext.current(); /// The output dtype will be `.i32`, unless the given axis has too large a dimension, in which case we use `.i64`.
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "iota({}, {})", .{ sh, axis_ }); /// In most program this shouldn't matter, because typically this will be used in a comparison,
/// or explicitly converted by the user to do floating point arithmetic.
pub fn iota(sh: Shape, axis_: anytype) Tensor {
const a = sh.axis(axis_); const a = sh.axis(axis_);
const dt: DataType = if (sh.dim(a) <= std.math.maxInt(i32)) .i32 else .i64;
const res_shape = sh.withDtype(dt); const res_shape = sh.withDtype(dt);
var op = dialect.stablehlo.iota(ctx.mlirCtx(), @intCast(a), mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type).?, loc); const mlir_ctx = CompilationContext.current().mlirCtx();
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "iota({}, {})", .{ res_shape, axis_ });
var op = dialect.stablehlo.iota(mlir_ctx, a, mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).asType(), loc);
return _result(res_shape, op.result(0)); return _result(res_shape, op.result(0));
} }
@ -2694,7 +2699,7 @@ pub const Tensor = struct {
pub const SortRes = ArgMaxRes; pub const SortRes = ArgMaxRes;
/// Returns two Tensors. The first contains the sorted values and the second one contains the sorted indices. /// Returns two Tensors. The first contains the sorted values and the second one contains the sorted indices.
pub fn sort(self: Tensor, axis_: i64, opts: struct { descending: bool = false }) SortRes { pub fn sort(self: Tensor, axis_: anytype, opts: struct { descending: bool = false }) SortRes {
const a = self.axis(axis_); const a = self.axis(axis_);
const indices = Tensor.arange(.{ .end = self.dim(a) }, .i32).broadcast(self._shape, &.{a}); const indices = Tensor.arange(.{ .end = self.dim(a) }, .i32).broadcast(self._shape, &.{a});
const res = ops.sort( const res = ops.sort(
@ -2823,7 +2828,7 @@ pub const Tensor = struct {
return ops.reduceWindow( return ops.reduceWindow(
MaxPoolRes.cmp, MaxPoolRes.cmp,
.{ .values = self, .indices = iota(self._shape, .i32, a) }, .{ .values = self, .indices = iota(self._shape, a) },
.{ .values = Tensor.constant(.{}, self.dtype().minValue()), .indices = Tensor.scalar(0, .i32) }, .{ .values = Tensor.constant(.{}, self.dtype().minValue()), .indices = Tensor.scalar(0, .i32) },
.{ .{
.window_dimensions = window_dimensions[0..self.rank()], .window_dimensions = window_dimensions[0..self.rank()],
@ -2861,7 +2866,7 @@ pub const Tensor = struct {
return ops.reduceWindow( return ops.reduceWindow(
MaxPoolRes.cmp, MaxPoolRes.cmp,
.{ .values = self, .indices = iota(self._shape, .i32, a) }, .{ .values = self, .indices = iota(self._shape, a) },
.{ .values = Tensor.constant(.{}, self.dtype().minValue()), .indices = Tensor.scalar(0, .i32) }, .{ .values = Tensor.constant(.{}, self.dtype().minValue()), .indices = Tensor.scalar(0, .i32) },
.{ .{
.window_dimensions = window_dimensions[0..self.rank()], .window_dimensions = window_dimensions[0..self.rank()],
@ -3336,8 +3341,8 @@ pub const Tensor = struct {
const values = self.insertAxes(a + 1, .{new_tags[1]}).broad(res_shape); const values = self.insertAxes(a + 1, .{new_tags[1]}).broad(res_shape);
const zeros = Tensor.constant(res_shape, self.dtype().zero()); const zeros = Tensor.constant(res_shape, self.dtype().zero());
const x = Tensor.iota(res_shape, .i32, a); const x = Tensor.iota(res_shape, a);
const y = Tensor.iota(res_shape, .i32, a + 1); const y = Tensor.iota(res_shape, a + 1);
var res = x.cmp(.EQ, y).select(values, zeros); var res = x.cmp(.EQ, y).select(values, zeros);
res._shape = res_shape; res._shape = res_shape;
return res; return res;
@ -3386,8 +3391,8 @@ pub const Tensor = struct {
meta.assertComptime(meta.isTuple(@TypeOf(axes_)) and axes_.len == 2, "triangular expects exactly two axes to work on.", .{}); meta.assertComptime(meta.isTuple(@TypeOf(axes_)) and axes_.len == 2, "triangular expects exactly two axes to work on.", .{});
const _axes = self.axes(axes_); const _axes = self.axes(axes_);
const x = Tensor.iota(self.shape(), .i32, _axes.get(0)); const x = Tensor.iota(self.shape(), _axes.get(0));
const y = Tensor.iota(self.shape(), .i32, _axes.get(1)); const y = Tensor.iota(self.shape(), _axes.get(1));
const zeros = Tensor.constant(self.shape(), self.dtype().zero()); const zeros = Tensor.constant(self.shape(), self.dtype().zero());
return x.addConstant(num_diagonals).cmp(.GE, y).select(self, zeros); return x.addConstant(num_diagonals).cmp(.GE, y).select(self, zeros);
@ -3448,6 +3453,14 @@ pub const Tensor = struct {
pub fn select(bool_tensor: Tensor, on_true: Tensor, on_false: Tensor) Tensor { pub fn select(bool_tensor: Tensor, on_true: Tensor, on_false: Tensor) Tensor {
meta.assert(bool_tensor.dtype() == .bool, "select expects input tensor type to be a boolean, got {}", .{bool_tensor.dtype()}); meta.assert(bool_tensor.dtype() == .bool, "select expects input tensor type to be a boolean, got {}", .{bool_tensor.dtype()});
meta.assert(on_true.dtype() == on_false.dtype(), "select expects 'on_true' and 'on_false' tensor types to be equal, got {} and {}", .{ on_true.dtype(), on_false.dtype() }); meta.assert(on_true.dtype() == on_false.dtype(), "select expects 'on_true' and 'on_false' tensor types to be equal, got {} and {}", .{ on_true.dtype(), on_false.dtype() });
if (bool_tensor.rank() != 0 and on_true.rank() == 0) {
return bool_tensor.select(on_true.broad(bool_tensor.shape()), on_false);
}
if (bool_tensor.rank() != 0 and on_false.rank() == 0) {
return bool_tensor.select(on_true, on_false.broad(bool_tensor.shape()));
}
meta.assert(bool_tensor._shape.eqlDims(on_true._shape), "select expects input tensor and 'on_true' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_true._shape }); meta.assert(bool_tensor._shape.eqlDims(on_true._shape), "select expects input tensor and 'on_true' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_true._shape });
meta.assert(bool_tensor._shape.eqlDims(on_false._shape), "select expects input tensor and 'on_false' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_false._shape }); meta.assert(bool_tensor._shape.eqlDims(on_false._shape), "select expects input tensor and 'on_false' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_false._shape });