zml: set iota default dtype to .i32, with fallback to .i64 for axes with many elements, simplifying usage.
This commit is contained in:
parent
344e07fb6e
commit
b244a18621
@ -1509,6 +1509,10 @@ pub const RankedTensorType = struct {
|
|||||||
pub fn getDimension(self: RankedTensorType, dim: usize) i64 {
|
pub fn getDimension(self: RankedTensorType, dim: usize) i64 {
|
||||||
return c.mlirShapedTypeGetDimSize(self.inner(), @intCast(dim));
|
return c.mlirShapedTypeGetDimSize(self.inner(), @intCast(dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn asType(self: RankedTensorType) Type {
|
||||||
|
return self.as(Type).?;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Dialect = struct {
|
pub const Dialect = struct {
|
||||||
|
|||||||
@ -1746,13 +1746,18 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing values in increasing order starting from 0 along the given axis.
|
/// Returns a Tensor containing values in increasing order starting from 0 along the given axis.
|
||||||
pub fn iota(sh: Shape, dt: DataType, axis_: anytype) Tensor {
|
///
|
||||||
const ctx = CompilationContext.current();
|
/// The output dtype will be `.i32`, unless the dimension of the given axis exceeds the maximum `i32` value, in which case we use `.i64`.
|
||||||
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "iota({}, {})", .{ sh, axis_ });
|
/// In most programs this shouldn't matter, because typically the result will be used in a comparison,
|
||||||
|
/// or explicitly converted by the user to do floating point arithmetic.
|
||||||
|
pub fn iota(sh: Shape, axis_: anytype) Tensor {
|
||||||
const a = sh.axis(axis_);
|
const a = sh.axis(axis_);
|
||||||
|
const dt: DataType = if (sh.dim(a) <= std.math.maxInt(i32)) .i32 else .i64;
|
||||||
const res_shape = sh.withDtype(dt);
|
const res_shape = sh.withDtype(dt);
|
||||||
var op = dialect.stablehlo.iota(ctx.mlirCtx(), @intCast(a), mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type).?, loc);
|
const mlir_ctx = CompilationContext.current().mlirCtx();
|
||||||
|
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "iota({}, {})", .{ res_shape, axis_ });
|
||||||
|
|
||||||
|
var op = dialect.stablehlo.iota(mlir_ctx, a, mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).asType(), loc);
|
||||||
return _result(res_shape, op.result(0));
|
return _result(res_shape, op.result(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2694,7 +2699,7 @@ pub const Tensor = struct {
|
|||||||
pub const SortRes = ArgMaxRes;
|
pub const SortRes = ArgMaxRes;
|
||||||
|
|
||||||
/// Returns two Tensors. The first contains the sorted values and the second one contains the sorted indices.
|
/// Returns two Tensors. The first contains the sorted values and the second one contains the sorted indices.
|
||||||
pub fn sort(self: Tensor, axis_: i64, opts: struct { descending: bool = false }) SortRes {
|
pub fn sort(self: Tensor, axis_: anytype, opts: struct { descending: bool = false }) SortRes {
|
||||||
const a = self.axis(axis_);
|
const a = self.axis(axis_);
|
||||||
const indices = Tensor.arange(.{ .end = self.dim(a) }, .i32).broadcast(self._shape, &.{a});
|
const indices = Tensor.arange(.{ .end = self.dim(a) }, .i32).broadcast(self._shape, &.{a});
|
||||||
const res = ops.sort(
|
const res = ops.sort(
|
||||||
@ -2823,7 +2828,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
return ops.reduceWindow(
|
return ops.reduceWindow(
|
||||||
MaxPoolRes.cmp,
|
MaxPoolRes.cmp,
|
||||||
.{ .values = self, .indices = iota(self._shape, .i32, a) },
|
.{ .values = self, .indices = iota(self._shape, a) },
|
||||||
.{ .values = Tensor.constant(.{}, self.dtype().minValue()), .indices = Tensor.scalar(0, .i32) },
|
.{ .values = Tensor.constant(.{}, self.dtype().minValue()), .indices = Tensor.scalar(0, .i32) },
|
||||||
.{
|
.{
|
||||||
.window_dimensions = window_dimensions[0..self.rank()],
|
.window_dimensions = window_dimensions[0..self.rank()],
|
||||||
@ -2861,7 +2866,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
return ops.reduceWindow(
|
return ops.reduceWindow(
|
||||||
MaxPoolRes.cmp,
|
MaxPoolRes.cmp,
|
||||||
.{ .values = self, .indices = iota(self._shape, .i32, a) },
|
.{ .values = self, .indices = iota(self._shape, a) },
|
||||||
.{ .values = Tensor.constant(.{}, self.dtype().minValue()), .indices = Tensor.scalar(0, .i32) },
|
.{ .values = Tensor.constant(.{}, self.dtype().minValue()), .indices = Tensor.scalar(0, .i32) },
|
||||||
.{
|
.{
|
||||||
.window_dimensions = window_dimensions[0..self.rank()],
|
.window_dimensions = window_dimensions[0..self.rank()],
|
||||||
@ -3336,8 +3341,8 @@ pub const Tensor = struct {
|
|||||||
const values = self.insertAxes(a + 1, .{new_tags[1]}).broad(res_shape);
|
const values = self.insertAxes(a + 1, .{new_tags[1]}).broad(res_shape);
|
||||||
const zeros = Tensor.constant(res_shape, self.dtype().zero());
|
const zeros = Tensor.constant(res_shape, self.dtype().zero());
|
||||||
|
|
||||||
const x = Tensor.iota(res_shape, .i32, a);
|
const x = Tensor.iota(res_shape, a);
|
||||||
const y = Tensor.iota(res_shape, .i32, a + 1);
|
const y = Tensor.iota(res_shape, a + 1);
|
||||||
var res = x.cmp(.EQ, y).select(values, zeros);
|
var res = x.cmp(.EQ, y).select(values, zeros);
|
||||||
res._shape = res_shape;
|
res._shape = res_shape;
|
||||||
return res;
|
return res;
|
||||||
@ -3386,8 +3391,8 @@ pub const Tensor = struct {
|
|||||||
meta.assertComptime(meta.isTuple(@TypeOf(axes_)) and axes_.len == 2, "triangular expects exactly two axes to work on.", .{});
|
meta.assertComptime(meta.isTuple(@TypeOf(axes_)) and axes_.len == 2, "triangular expects exactly two axes to work on.", .{});
|
||||||
const _axes = self.axes(axes_);
|
const _axes = self.axes(axes_);
|
||||||
|
|
||||||
const x = Tensor.iota(self.shape(), .i32, _axes.get(0));
|
const x = Tensor.iota(self.shape(), _axes.get(0));
|
||||||
const y = Tensor.iota(self.shape(), .i32, _axes.get(1));
|
const y = Tensor.iota(self.shape(), _axes.get(1));
|
||||||
|
|
||||||
const zeros = Tensor.constant(self.shape(), self.dtype().zero());
|
const zeros = Tensor.constant(self.shape(), self.dtype().zero());
|
||||||
return x.addConstant(num_diagonals).cmp(.GE, y).select(self, zeros);
|
return x.addConstant(num_diagonals).cmp(.GE, y).select(self, zeros);
|
||||||
@ -3448,6 +3453,14 @@ pub const Tensor = struct {
|
|||||||
pub fn select(bool_tensor: Tensor, on_true: Tensor, on_false: Tensor) Tensor {
|
pub fn select(bool_tensor: Tensor, on_true: Tensor, on_false: Tensor) Tensor {
|
||||||
meta.assert(bool_tensor.dtype() == .bool, "select expects input tensor type to be a boolean, got {}", .{bool_tensor.dtype()});
|
meta.assert(bool_tensor.dtype() == .bool, "select expects input tensor type to be a boolean, got {}", .{bool_tensor.dtype()});
|
||||||
meta.assert(on_true.dtype() == on_false.dtype(), "select expects 'on_true' and 'on_false' tensor types to be equal, got {} and {}", .{ on_true.dtype(), on_false.dtype() });
|
meta.assert(on_true.dtype() == on_false.dtype(), "select expects 'on_true' and 'on_false' tensor types to be equal, got {} and {}", .{ on_true.dtype(), on_false.dtype() });
|
||||||
|
|
||||||
|
if (bool_tensor.rank() != 0 and on_true.rank() == 0) {
|
||||||
|
return bool_tensor.select(on_true.broad(bool_tensor.shape()), on_false);
|
||||||
|
}
|
||||||
|
if (bool_tensor.rank() != 0 and on_false.rank() == 0) {
|
||||||
|
return bool_tensor.select(on_true, on_false.broad(bool_tensor.shape()));
|
||||||
|
}
|
||||||
|
|
||||||
meta.assert(bool_tensor._shape.eqlDims(on_true._shape), "select expects input tensor and 'on_true' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_true._shape });
|
meta.assert(bool_tensor._shape.eqlDims(on_true._shape), "select expects input tensor and 'on_true' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_true._shape });
|
||||||
meta.assert(bool_tensor._shape.eqlDims(on_false._shape), "select expects input tensor and 'on_false' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_false._shape });
|
meta.assert(bool_tensor._shape.eqlDims(on_false._shape), "select expects input tensor and 'on_false' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_false._shape });
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user