3733 lines
174 KiB
Zig
3733 lines
174 KiB
Zig
const builtin = @import("builtin");
|
||
const std = @import("std");
|
||
const assert = std.debug.assert;
|
||
const testing = std.testing;
|
||
|
||
const meta = @import("meta.zig");
|
||
const mlir = @import("mlir.zig");
|
||
const ops = @import("ops.zig");
|
||
const module = @import("module.zig");
|
||
|
||
const Location = mlir.Location;
|
||
const CompilationContext = module.CompilationContext;
|
||
const Shape = @import("shape.zig").Shape;
|
||
const Buffer = @import("buffer.zig").Buffer;
|
||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||
const Data = @import("dtype.zig").Data;
|
||
const DataType = @import("dtype.zig").DataType;
|
||
const Platform = @import("platform.zig").Platform;
|
||
const EnumLiteral = @TypeOf(.enum_literal);
|
||
|
||
const dialect = struct {
|
||
const stablehlo = @import("mlir/dialects").stablehlo;
|
||
};
|
||
|
||
const scoped_log = std.log.scoped(.zml_tensor);
|
||
|
||
test {
|
||
std.testing.refAllDecls(Tensor);
|
||
}
|
||
|
||
/// Represents an abstract Tensor object, which can be the input,
|
||
/// output, weight or activations of a neural network.
|
||
/// Tensor are abstract in the sense they only represent a computation,
|
||
/// but not a specific memory buffer.
|
||
/// Tensor namespace contains most of linear algebra needed to
|
||
/// represent mathematical operations.
|
||
/// More operations are available in `zml.nn` and `zml.torch` namespaces.
|
||
pub const Tensor = struct {
|
||
_shape: Shape,
|
||
_id: _Id,
|
||
_donation: _Donation = .no_buffer,
|
||
|
||
pub const _Donation = union(enum) { no_buffer, input_buffer, arg: u16 };
|
||
pub const _Id = union(enum) { mlir: mlir.Value, buffer_id: u64, arg_id: u64 };
|
||
|
||
pub const MAX_RANK = Shape.MAX_RANK;
|
||
|
||
/// Returns the current compilation context.
|
||
pub fn getContext(self: Tensor) *CompilationContext {
|
||
_ = self;
|
||
return CompilationContext.current();
|
||
}
|
||
|
||
pub fn format(
|
||
self: Tensor,
|
||
comptime fmt: []const u8,
|
||
options: std.fmt.FormatOptions,
|
||
writer: anytype,
|
||
) !void {
|
||
_ = fmt;
|
||
_ = options;
|
||
try writer.print("Tensor({_})", .{self._shape});
|
||
}
|
||
|
||
/// Returns the shape of a Tensor.
|
||
pub fn shape(self: Tensor) Shape {
|
||
return self._shape;
|
||
}
|
||
|
||
/// Returns the datatype of a Tensor.
|
||
pub fn dtype(self: Tensor) DataType {
|
||
return self._shape.dtype();
|
||
}
|
||
|
||
/// Returns the rank of a Tensor.
|
||
pub inline fn rank(self: Tensor) u4 {
|
||
return self._shape.rank();
|
||
}
|
||
|
||
/// Returns the number of element of a Tensor.
|
||
pub fn count(self: Tensor) usize {
|
||
return self._shape.count();
|
||
}
|
||
|
||
/// Returns the size in bytes of a Tensor.
|
||
pub fn byteSize(self: Tensor) usize {
|
||
return self._shape.byteSize();
|
||
}
|
||
|
||
/// Internal use
|
||
///
|
||
/// Creates a tensor from a Shape and an mlir.Value.
|
||
pub fn _result(sh: Shape, val: mlir.Value) Tensor {
|
||
const res: Tensor = .{
|
||
._shape = sh,
|
||
._id = .{ .mlir = val },
|
||
};
|
||
|
||
if (builtin.mode == .Debug) {
|
||
// Check that the MLIR value actually have the same shape.
|
||
const other = fromMlirValue(val);
|
||
meta.internalAssert(sh.eql(other._shape), "Created a {} from Mlir value but expected {}", .{ other._shape, res._shape });
|
||
}
|
||
|
||
return res;
|
||
}
|
||
|
||
/// Creates a Tensor from a mlir.Value
|
||
///
|
||
/// The shape is derived from the type of the mlir.Value.
|
||
pub fn fromMlirValue(val: mlir.Value) Tensor {
|
||
const ranked_tensor = val.getType().as(mlir.RankedTensorType).?;
|
||
const n = ranked_tensor.getRank();
|
||
|
||
meta.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK });
|
||
|
||
var sh: Shape = .{ ._dtype = mlir.ext.Type.toDType(ranked_tensor.getElementType()) };
|
||
for (0..n) |i| {
|
||
sh._dims.appendAssumeCapacity(ranked_tensor.getDimension(i));
|
||
}
|
||
sh._tags.resize(n) catch unreachable;
|
||
|
||
return .{ ._shape = sh, ._id = .{ .mlir = val } };
|
||
}
|
||
|
||
/// Returns the dimension of axis 'axis_'.
|
||
///
|
||
/// 'axis_' can be a signed integer or a tag.
|
||
pub fn dim(self: Tensor, axis_: anytype) i64 {
|
||
return self._shape.dim(axis_);
|
||
}
|
||
|
||
/// Returns the dimensions of a Tensor as a slice.
|
||
pub fn dims(self: *const Tensor) []const i64 {
|
||
return self._shape.dims();
|
||
}
|
||
|
||
/// Returns the index of axis 'axis_'.
|
||
///
|
||
/// 'axis_' can be a signed integer or a tag.
|
||
pub fn axis(self: Tensor, axis_: anytype) u3 {
|
||
return self._shape.axis(axis_);
|
||
}
|
||
|
||
/// Returns a Tensor tagged with the tags in 'tagz'.
|
||
pub fn withTags(self: Tensor, tagz: anytype) Tensor {
|
||
var res = self;
|
||
res._shape = self._shape.withTags(tagz);
|
||
return res;
|
||
}
|
||
|
||
/// Returns a Tensor tagged partially with the tags in 'tagz'.
|
||
///
|
||
/// If 'tagz' is of length n, the n last dimensions of the Tensor will be tagged.
|
||
pub fn withPartialTags(self: Tensor, tagz: anytype) Tensor {
|
||
var res = self;
|
||
res._shape = self._shape.withPartialTags(tagz);
|
||
return res;
|
||
}
|
||
|
||
pub fn withSharding(self: Tensor, axes_: anytype) Tensor {
|
||
return switch (self._id) {
|
||
.arg_id, .mlir => {
|
||
const ctx = self.getContext();
|
||
var res = self;
|
||
res._shape = self._shape.withSharding(axes_);
|
||
|
||
const sharding = ctx.getShardingAttr(res._shape);
|
||
|
||
const op = dialect.stablehlo.sharding(
|
||
ctx.mlirCtx(),
|
||
&.{self.value()},
|
||
sharding,
|
||
&.{self.value().getType()},
|
||
ctx.mlirCtx().location(@src()),
|
||
);
|
||
|
||
return _result(res._shape, op.result(0));
|
||
},
|
||
.buffer_id => {
|
||
var res = self;
|
||
res._shape = self._shape.withSharding(axes_);
|
||
return res;
|
||
},
|
||
};
|
||
}
|
||
|
||
/// Returns a Tensor with new tag names.
|
||
pub fn rename(self: Tensor, renames: anytype) Tensor {
|
||
var res = self;
|
||
res._shape = self._shape.rename(renames);
|
||
return res;
|
||
}
|
||
|
||
pub fn renameAxis(self: Tensor, ax: i8, name: EnumLiteral) Tensor {
|
||
var res = self;
|
||
res._shape._tags.set(self.axis(ax), @tagName(name).ptr);
|
||
return res;
|
||
}
|
||
|
||
/// Returns the mlir.Value associated with the Tensor.
|
||
///
|
||
/// This will fail if used outside of a compilation context.
|
||
pub fn value(self: Tensor) mlir.Value {
|
||
return self.getContext().getValueAndDonation(self)[0];
|
||
}
|
||
/// Tell PJRT compiler that memory should be reuse between the two tensors.
|
||
/// The compiler is already aggressively reusing tensors for intermediate results,
|
||
/// but this API allows to reuse buffer between input and output arguments
|
||
/// of a given function.
|
||
/// Note this is visible from the outside. The caller of a function with donations
|
||
/// is not allowed to reuse the donated input buffer after the call.
|
||
/// For `reuseBuffer` to be effective, it needs to propagate all the way through the output.
|
||
pub fn reuseBuffer(self: Tensor, origin: Tensor) Tensor {
|
||
// Note: check donation docs, this may be too permissive.
|
||
meta.assert(self.byteSize() == origin.byteSize(), "Can't reuse buffers between tensors of different size: {} and {}", .{ self, origin });
|
||
|
||
// TODO: should we store all donations inside the context ?
|
||
var res = self;
|
||
res._donation = self.getContext().getValueAndDonation(origin)[1];
|
||
|
||
return res;
|
||
}
|
||
|
||
var _global_tensor_counter: u64 = 0;
|
||
|
||
/// Internal use
|
||
pub fn reserveIdRange(len: u32) u64 {
|
||
return @atomicRmw(u64, &_global_tensor_counter, .Add, len, .seq_cst);
|
||
}
|
||
|
||
/// Internal use
|
||
pub fn setUniqueId(self: *Tensor) void {
|
||
self._id = .{ .buffer_id = reserveIdRange(1) };
|
||
}
|
||
|
||
/// Returns a Tensor containing the absolute value of each element of the input Tensor.
|
||
pub fn abs(self: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.abs(self.getContext().mlirCtx(), self.value(), loc);
|
||
const dt = switch (self.dtype()) {
|
||
.c64 => .f32,
|
||
.c128 => .f64,
|
||
else => self.dtype(),
|
||
};
|
||
|
||
return _result(self._shape.withDtype(dt), op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor whose elements have been bitcast to a target datatype.
|
||
///
|
||
/// The Tensor shape needs to be compatible with the target datatype.
|
||
pub fn bitCast(self: Tensor, dt: DataType) Tensor {
|
||
const src_bit_size = self.dtype().bitSizeOf();
|
||
const tgt_bit_size = dt.bitSizeOf();
|
||
var res_shape = if (src_bit_size == tgt_bit_size)
|
||
self._shape
|
||
else if (src_bit_size > tgt_bit_size) gt: {
|
||
const new_dim = std.math.divExact(u16, src_bit_size, tgt_bit_size) catch std.debug.panic("bitcast expects target datatype to be a multiple of source datatype when upcasting, got {} (bitsize of {}) and {} (bitsize of {})", .{ self.dtype(), src_bit_size, dt, tgt_bit_size });
|
||
var res = self._shape;
|
||
res = res.append(.{ .bitcast = new_dim });
|
||
break :gt res;
|
||
} else lt: {
|
||
// several contiguous elements of self maps to one element of the result
|
||
meta.assert(self.dim(-1) * src_bit_size == tgt_bit_size, "bitcast expects elements of the input tensor last dimension to map to one element of the target datatype, got {0} elements (bitsize of {0}x{1}={2}) and {3} (bitsize of {4})", .{ self.dim(-1), src_bit_size, self.dim(-1) * src_bit_size, dt, tgt_bit_size });
|
||
break :lt self._shape.remove(-1);
|
||
};
|
||
|
||
res_shape = res_shape.withDtype(dt);
|
||
|
||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "bitCast({})", .{dt});
|
||
const op = dialect.stablehlo.bitcast_convert(
|
||
self.getContext().mlirCtx(),
|
||
self.value(),
|
||
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type).?,
|
||
loc,
|
||
);
|
||
|
||
return _result(res_shape, op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise number of leading 0 bits in the input Tensor.
|
||
pub fn countLeadingZeros(self: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.count_leading_zeros(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape, op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing booleans indicating if each element of the input Tensor is finite.
|
||
pub fn isFinite(self: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.is_finite(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape.withDtype(.bool), op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise logistic operation on the input Tensor.
|
||
pub fn logistic(self: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.logistic(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape, op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise number of bits set in the input Tensor.
|
||
pub fn popcnt(self: Tensor) Tensor {
|
||
meta.assert(self.dtype().isInteger(), "popcnt expects tensor type to be an integer, got {}", .{self.dtype()});
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.popcnt(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape, op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the sign of the input Tensor element-wise.
|
||
pub fn sign(self: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.sign(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape, op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise remainder of dividend 'self' and divisor 'other'.
|
||
pub fn remainder(self: Tensor, other: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.remainder(self.getContext().mlirCtx(), self.value(), other.value(), loc);
|
||
return _result(self._shape, op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise remainder of dividend 'self' and divisor 'other'.
|
||
///
|
||
/// See https://pytorch.org/docs/stable/generated/torch.fmod.html for more details.
|
||
pub fn fmod(self: Tensor, divisor: f32) Tensor {
|
||
return self.remainder(Tensor.scalar(divisor, .f32).broadcast(self._shape, &.{}));
|
||
}
|
||
|
||
test fmod {
|
||
const zml = @import("zml.zig");
|
||
const platform = zml.testing.env();
|
||
|
||
const inputs: [2][6]f32 = .{ .{ -3.0, -2, -1, 1, 2, 3 }, .{ 1, 2, 3, 4, 5, -5 } };
|
||
const expectations: [2][6]f32 = .{ .{ -1.0, -0.0, -1.0, 1.0, 0.0, 1.0 }, .{ 1.0000, 0.5000, 0.0000, 1.0000, 0.5000, -0.5000 } };
|
||
const divisors: [2]f32 = .{ 2, -1.5 };
|
||
|
||
inline for (inputs, expectations, divisors) |i, e, d| {
|
||
const input = try zml.Buffer.fromSlice(platform, .{6}, &i);
|
||
const output = try zml.testing.compileAndCall(platform, Tensor.fmod, .{ input, d });
|
||
|
||
try zml.testing.expectClose(zml.HostBuffer.fromSlice(.{6}, &e), output, 1e-4);
|
||
}
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise left-shift operation of 'self' by 'other'.
|
||
pub fn shiftLeft(self: Tensor, other: Tensor) Tensor {
|
||
return binaryOp("shiftLeft", dialect.stablehlo.shift_left)(self, other);
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise arithmetic right-shift operation of 'self' by 'other'.
|
||
pub fn shiftRightArithmetic(self: Tensor, other: Tensor) Tensor {
|
||
return binaryOp("shiftRightArithmetic", dialect.stablehlo.shift_right_arithmetic)(self, other);
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise logical right-shift operation of 'self' by 'other'.
|
||
pub fn shiftRightLogical(self: Tensor, other: Tensor) Tensor {
|
||
return binaryOp("shiftRightLogical", dialect.stablehlo.shift_right_logical)(self, other);
|
||
}
|
||
|
||
/// Returns the Cholesky decomposition of the input Tensor.
|
||
///
|
||
/// 'lower' controls the form of the outut Tensor. The output will be lower-triangular if 'lower' is true
|
||
/// and upper-triangular otherwise.
|
||
pub fn cholesky(self: Tensor, lower: bool) Tensor {
|
||
meta.assert(self.rank() <= 2, "cholesky expects tensor rank to be <= 2, got {}", .{self.rank()});
|
||
|
||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "lower={}", .{lower});
|
||
const op = dialect.stablehlo.cholesky(self.getContext().mlirCtx(), self.value(), lower, loc);
|
||
return _result(self._shape, op.result(0));
|
||
}
|
||
|
||
/// Solves the system of linear equations formed by the input tensors.
|
||
pub fn triangularSolve(self: Tensor, other: Tensor, opts: dialect.stablehlo.TriangularSolveOpts) Tensor {
|
||
meta.assert(self.dtype() == other.dtype(), "triangularSolve expects tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() });
|
||
meta.assert(self.rank() <= 2 and self.rank() == other.rank(), "triangularSolve expects tensors to have the same rank and be <= 2, got {} and {}", .{ self.rank(), other.rank() });
|
||
|
||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "opts={}", .{opts});
|
||
const op = dialect.stablehlo.triangular_solve(self.getContext().mlirCtx(), self.value(), other.value(), loc, opts);
|
||
return _result(self._shape, op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise rounding towards the nearest integer, breaking ties away from zero, of the input Tensor.
|
||
pub fn roundNearestAfz(self: Tensor) Tensor {
|
||
meta.assert(self.dtype().isFloat(), "roundNearestAfz expects tensor type to be a float, got {}", .{self.dtype()});
|
||
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.round_nearest_afz(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape, op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise rounding towards the nearest integer, breaking ties towards the even integer, of the input Tensor.
|
||
pub fn roundNearestEven(self: Tensor) Tensor {
|
||
meta.assert(self.dtype().isFloat(), "roundNearestEven expects tensor type to be a float, got {}", .{self.dtype()});
|
||
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.round_nearest_even(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape, op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor of complex number converted from a pair of real and imaginary Tensors.
|
||
pub fn complex(re: Tensor, im: Tensor) Tensor {
|
||
meta.assert(re._shape.eql(im._shape), "complex expects tensor shapes to match, got {} and {}", .{ re._shape, im._shape });
|
||
meta.assert(re.dtype() == .f32 or re.dtype() == .f64, "complex expects tensors type to be f32 or f64, got {}", .{re.dtype()});
|
||
|
||
const loc = re.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.complex(re.getContext().mlirCtx(), re.value(), im.value(), loc);
|
||
const dt: DataType = if (re.dtype() == .f32) .c64 else .c128;
|
||
return _result(re._shape.withDtype(dt), op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise real part of the input Tensor.
|
||
///
|
||
/// Tensor type can float or complex.
|
||
pub fn real(self: Tensor) Tensor {
|
||
meta.assert(self.dtype().isComplex() or self.dtype().isFloat(), "real expects tensor type to be a float or a complex, got {}", .{self.dtype()});
|
||
|
||
if (self.dtype().isFloat()) {
|
||
return self;
|
||
}
|
||
|
||
const dt: DataType = switch (self.dtype()) {
|
||
.c64 => .f32,
|
||
.c128 => .f64,
|
||
else => unreachable,
|
||
};
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.real(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape.withDtype(dt), op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise imaginary part of the input Tensor.
|
||
///
|
||
/// Tensor type can float or complex.
|
||
pub fn imag(self: Tensor) Tensor {
|
||
meta.assert(self.dtype().isFloat() or self.dtype().isComplex(), "imag expects tensor type to be a float or a complex, got {}", .{self.dtype()});
|
||
|
||
// Real tensors don't have imaginary part.
|
||
if (self.dtype().isFloat()) {
|
||
return Tensor.constant(self._shape, self.dtype().zero());
|
||
}
|
||
|
||
const dt: DataType = switch (self.dtype()) {
|
||
.bf16, .f16, .f32, .f64 => self.dtype(),
|
||
.c64 => .f32,
|
||
.c128 => .f64,
|
||
else => unreachable,
|
||
};
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.imag(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape.withDtype(dt), op.result(0));
|
||
}
|
||
|
||
/// Returns the Fast Fourier Transform (FFT) of the input Tensor.
|
||
pub fn fft(self: Tensor, opts: dialect.stablehlo.FftOpts) Tensor {
|
||
// TODO: support tagged API.
|
||
|
||
meta.assert(1 <= opts.length.len and opts.length.len <= 3, "fft expects 'opts.length' length to be between 1 and 3 (inclusive), got {}", .{opts.length.len});
|
||
meta.assert(opts.length.len <= self.rank(), "fft expects 'opts.length' length to be less than tensor rank, got {} and {}", .{ opts.length.len, self.rank() });
|
||
|
||
const sh = switch (opts.kind) {
|
||
.FFT, .IFFT => blk: {
|
||
meta.assert(self.dtype().isComplex(), "fft({any}) expects tensor type to be complex, got {}", .{ opts, self.dtype() });
|
||
|
||
break :blk self._shape;
|
||
},
|
||
.RFFT => blk: {
|
||
meta.assert(self.dtype() == .f32 or self.dtype() == .f64, "fft({}) expects tensor type to be f32 or f64, got {}", .{ opts, self.dtype() });
|
||
meta.assert(std.mem.eql(i64, self.dims()[self.rank() - opts.length.len ..], opts.length), "fft({}) expects tensor last dimensions to match given lengths, got {} and {}", .{ opts, self.dims()[self.rank() - opts.length.len ..].len, opts.length.len });
|
||
|
||
const dt: DataType = switch (self.dtype()) {
|
||
.f32 => .c64,
|
||
else => .c128,
|
||
};
|
||
const shape_ = self._shape.setDim(-1, @divExact(self.dim(-1), 2) + 1);
|
||
break :blk shape_.withDtype(dt);
|
||
},
|
||
.IRFFT => blk: {
|
||
meta.assert(self.dtype().isComplex(), "fft({any}) expects tensor type to be complex, got {}", .{ opts, self.dtype() });
|
||
meta.assert(std.mem.eql(i64, self.dims()[self.rank() - opts.length.len ..], opts.length), "fft({any}) expects tensor last dimensions to match given lengths, got {} and {}", .{ opts, self.dims()[self.rank() - opts.length.len ..].len, opts.length.len });
|
||
|
||
const dt: DataType = switch (self.dtype()) {
|
||
.c64 => .f32,
|
||
else => .f64,
|
||
};
|
||
const shape_ = self._shape.setDim(-1, @divExact(self.dim(-1) - 1, 2));
|
||
break :blk shape_.withDtype(dt);
|
||
},
|
||
};
|
||
|
||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "opts={}", .{opts});
|
||
const op = dialect.stablehlo.fft(self.getContext().mlirCtx(), self.value(), loc, opts);
|
||
return _result(sh, op.result(0));
|
||
}
|
||
|
||
pub const Rng = struct {
|
||
_state: Tensor,
|
||
algorithm: dialect.stablehlo.RngAlgorithm.Type = .DEFAULT,
|
||
|
||
pub fn shape() ShapeOf(Rng) {
|
||
return .{
|
||
._state = Shape.init(.{2}, .u64),
|
||
};
|
||
}
|
||
|
||
pub fn init(platform: Platform, seed: u128) !Bufferized(Rng) {
|
||
const bits: [2]u64 = @bitCast(seed);
|
||
return .{
|
||
._state = try Buffer.fromSlice(platform, Shape.init(.{2}, .u64), &bits),
|
||
.algorithm = undefined,
|
||
};
|
||
}
|
||
|
||
/// Returns a Tensor of the given shape, filled with uniform random bits, and a new Rng state.
|
||
///
|
||
/// The given Rng state should not be used anymore (or you'll get the same numbers again).
|
||
/// The output is guaranteed to be deterministic function of `self` Rng state,
|
||
/// but it is not guaranteed to be deterministic between implementations.
|
||
pub fn bitGenerator(self: Rng, sh: Shape) struct { Rng, Tensor } {
|
||
const ctx = CompilationContext.current();
|
||
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "rand.bitGen({})", .{sh});
|
||
const op = dialect.stablehlo.rng_bit_generator(
|
||
ctx.mlirCtx(),
|
||
self.algorithm,
|
||
self._state.value(),
|
||
mlir.ext.mlirType(ctx.mlirCtx(), self._state._shape),
|
||
mlir.ext.mlirType(ctx.mlirCtx(), sh),
|
||
loc,
|
||
);
|
||
return .{ self.update(op.result(0)), _result(sh, op.result(1)) };
|
||
}
|
||
|
||
fn update(self: Rng, new_state: mlir.Value) Rng {
|
||
return .{
|
||
._state = _result(self._state._shape, new_state).reuseBuffer(self._state),
|
||
.algorithm = self.algorithm,
|
||
};
|
||
}
|
||
|
||
/// Returns a Tensor of the given shape, filled with uniformly sampled floating point numbers from an interval,
|
||
/// and a new Rng state.
|
||
///
|
||
/// https://en.wikipedia.org/wiki/Continuous_uniform_distribution
|
||
pub fn uniform(
|
||
self: Rng,
|
||
shape_: Shape,
|
||
opts: struct { min: f64 = 0, max: f64 = 1 },
|
||
) struct { Rng, Tensor } {
|
||
const dt = if (shape_.dtype().isFloat()) shape_.dtype() else .f32;
|
||
|
||
const mantissa_bit_count = @import("dtype.zig").mantissaSize(dt);
|
||
const bit_count: usize = dt.bitSizeOf();
|
||
const rng_bit_count = if (mantissa_bit_count < 8) 8 else bit_count;
|
||
const uint_dtype: DataType = switch (bit_count) {
|
||
8 => .u8,
|
||
16 => .u16,
|
||
32 => .u32,
|
||
64 => .u64,
|
||
else => meta.panic("uniform don't support non-byte aligned dtype. Got: {}", .{shape_}),
|
||
};
|
||
|
||
const rng, const bits = self.bitGenerator(shape_.withDtype(uint_dtype));
|
||
|
||
// Erase bits outside of mantissa.
|
||
var float_bits = bits.shiftRightLogical(scalar(rng_bit_count - mantissa_bit_count, uint_dtype));
|
||
|
||
// Set exponent bits to represent e^0 (eg 127 for f32).
|
||
float_bits = float_bits.logical(.OR, scalar(1, dt).bitCast(uint_dtype));
|
||
|
||
// float_bits now uniformly represents number in [1, 2[ range.
|
||
// Let's convert to floats, and subtract one to go to [0, 1[ range.
|
||
var floats = float_bits.bitCast(dt).sub(scalar(1, dt));
|
||
floats = floats.mul(scalar(opts.max - opts.min, dt)).addConstant(opts.min);
|
||
|
||
// Convert back to integer if needed.
|
||
return .{ rng, floats.convert(shape_.dtype()) };
|
||
}
|
||
|
||
test uniform {
|
||
const zml = @import("zml.zig");
|
||
const Stats = struct {
|
||
const Stats = @This();
|
||
|
||
mean: Tensor,
|
||
variance: Tensor,
|
||
min: Tensor,
|
||
max: Tensor,
|
||
|
||
pub fn uniformStats(
|
||
rand: Rng,
|
||
shape_: Shape,
|
||
opts: struct { min: f64, max: f64 },
|
||
) struct { Rng, Stats } {
|
||
const rng, const data = rand.uniform(shape_, .{ .min = opts.min, .max = opts.max });
|
||
const mean_ = data.mean(0);
|
||
const variance = data.sub(mean_.broad(data.shape())).pow(Tensor.scalar(2, .f32)).mean(0);
|
||
return .{ rng, .{
|
||
.mean = mean_,
|
||
.variance = variance,
|
||
.min = data.min(0),
|
||
.max = data.max(0),
|
||
} };
|
||
}
|
||
};
|
||
|
||
const platform = zml.testing.env();
|
||
// Compute stats over a uniform distribution on [-2, 10].
|
||
const rand, const stats = try zml.testing.compileAndCall(
|
||
platform,
|
||
Stats.uniformStats,
|
||
.{
|
||
try Rng.init(platform, 1234),
|
||
Shape.init(.{1024}, .f32),
|
||
.{ .min = -2, .max = 10 },
|
||
},
|
||
);
|
||
// Check the Rng state has been modified.
|
||
try std.testing.expect(try rand._state.getValue(i128) != 1234);
|
||
|
||
// Check the mean and variance are close to theoritical values.
|
||
const mean_ = try stats.mean.getValue(f32);
|
||
try std.testing.expectApproxEqAbs(4, mean_, 0.03);
|
||
|
||
const variance = try stats.variance.getValue(f32);
|
||
try std.testing.expectApproxEqAbs(12.0 * 12.0 / 12.0, variance, 0.01);
|
||
|
||
// Check that no value is outside of the interval
|
||
// and we have samples close to the edges.
|
||
const min_ = try stats.min.getValue(f32);
|
||
try std.testing.expect(min_ >= -2);
|
||
try std.testing.expectApproxEqAbs(-2, min_, 0.05);
|
||
|
||
const max_ = try stats.max.getValue(f32);
|
||
try std.testing.expect(max_ < 10);
|
||
try std.testing.expectApproxEqAbs(10, max_, 0.05);
|
||
}
|
||
|
||
/// Returns a Tensor of the given shape, filled with floating point numbers sampled from a normal distribution.
|
||
///
|
||
/// Note: this uses stablehlo.rng which is deprecated.
|
||
/// https://github.com/openxla/stablehlo/blob/main/rfcs/20240503-opset-deprecations.md
|
||
pub fn normal(sh: Shape, opts: struct { mean: f64 = 0, stddev: f64 = 1 }) Tensor {
|
||
meta.assert(sh.dtype().isFloat(), "normal expects tensor type to be a float, got {}", .{sh.dtype()});
|
||
|
||
const ctx = CompilationContext.current().mlirCtx();
|
||
const loc = ctx.location(@src()).namedFmt(ctx, "rand.normal({}, opts={})", .{ sh, opts });
|
||
const a = Tensor.constant(.{}, Data.init(sh.dtype(), opts.mean));
|
||
const b = Tensor.constant(.{}, Data.init(sh.dtype(), opts.stddev));
|
||
const res_shape = Tensor.constantTensor(HostBuffer.fromSlice(.{sh.rank()}, sh.dims()));
|
||
const op = dialect.stablehlo.rng(ctx, a.value(), b.value(), res_shape.value(), .NORMAL, loc);
|
||
return _result(sh, op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor of the given shape, filled with floating point numbers sampled from a Gumbel distribution, and a new Rng state.
|
||
///
|
||
/// Often used in ML because of the reparametrization tricks.
|
||
/// Sampling from a gumbel distribution is equivalent to sample
|
||
/// from a softmax distribution, but doesn't require to compute the sum of exponentials.
|
||
/// https://en.wikipedia.org/wiki/Gumbel_distribution#Gumbel_reparametrization_tricks
|
||
/// See `sampleTokens` for a practical use case.
|
||
/// Note: we only implement the μ=0, β=1 version.
|
||
pub fn gumbel(self: Rng, shape_: Shape) struct { Rng, Tensor } {
|
||
const rand, const u = self.uniform(
|
||
shape_,
|
||
// We don't want 0 to be sampled otherwise `log` will return -inf.
|
||
.{ .min = std.math.floatEps(f64), .max = 1 },
|
||
);
|
||
return .{ rand, u.log().scale(-1).log().scale(-1) };
|
||
}
|
||
|
||
test gumbel {
|
||
const zml = @import("zml.zig");
|
||
const Stats = struct {
|
||
const Stats = @This();
|
||
|
||
mean: Tensor,
|
||
variance: Tensor,
|
||
actual_dist: Tensor,
|
||
|
||
pub fn gumbelStats(rand: Rng, target_dist: Tensor) struct { Rng, Stats } {
|
||
const s = Shape.init(.{ .n = 1024, .d = 4 }, .f32);
|
||
const rng, const data = rand.gumbel(s);
|
||
const flat = data.flattenAll();
|
||
const mean_ = flat.mean(0);
|
||
const variance = flat.sub(mean_.broad(flat.shape())).pow(Tensor.scalar(2, .f32)).mean(0);
|
||
|
||
// Test out the gumbel reparametrization trick
|
||
var x = target_dist.log().withTags(.{.d}).broad(s);
|
||
x = x.add(data);
|
||
const samples = x.argMax(.d, .i32).indices.squeeze(.d);
|
||
|
||
// count 0, 1, 2 and 3 in samples:
|
||
// - map 0 to 1, 1 to 2**16, 2 to 2**32, 3 to N**58
|
||
// - sum in u64
|
||
// - split to [4]u16
|
||
const powers = blk: {
|
||
var powers: [4]u64 = undefined;
|
||
for (&powers, 0..) |*p, i| p.* = std.math.pow(u64, 2, i * 16);
|
||
break :blk powers;
|
||
};
|
||
const values = Tensor.constantTensor(HostBuffer.fromArray(&powers)).withTags(.{.d});
|
||
const counts = values.gatherValues(.d, samples, .{}).sum(.n).bitCast(.u16);
|
||
const actual_dist = counts.reshape(target_dist.shape()).convert(target_dist.dtype()).divByConst(s.dim(.n));
|
||
return .{ rng, .{ .mean = mean_, .variance = variance, .actual_dist = actual_dist } };
|
||
}
|
||
};
|
||
|
||
const platform = zml.testing.env();
|
||
const tgt_dist = [_]f32{ 2.0, 1.0, 4.0, 3.0 };
|
||
const rand, const stats = try zml.testing.compileAndCall(platform, Stats.gumbelStats, .{
|
||
try Rng.init(platform, 1234), try HostBuffer.fromArray(&tgt_dist).toDevice(platform),
|
||
});
|
||
// Check the Rng state has been modified.
|
||
try std.testing.expect(try rand._state.getValue(i128) != 1234);
|
||
|
||
// Check the mean and variance are close to theoritical values.
|
||
const mean_ = try stats.mean.getValue(f32);
|
||
try std.testing.expectApproxEqAbs(0.5772, mean_, 0.02);
|
||
|
||
const variance = try stats.variance.getValue(f32);
|
||
const pi = std.math.pi;
|
||
try std.testing.expectApproxEqAbs(pi * pi / 6.0, variance, 0.03);
|
||
|
||
// Check the distribution obtained with the gumbel trick matches the target distribution.
|
||
const actual_dist = try stats.actual_dist.getValue([4]f32);
|
||
scoped_log.debug("tgt_dist: {d}, actual_dist: {d}", .{ tgt_dist, actual_dist });
|
||
for (tgt_dist, actual_dist) |tgt, actual| {
|
||
// We normalize tgt_dist to make it a well formed distribution.
|
||
// We didn't do it before calling gumbel, because the gumbel trick
|
||
// doesn't require normalized distributions as input.
|
||
try std.testing.expectApproxEqAbs(tgt / 10.0, actual, 0.05);
|
||
}
|
||
}
|
||
};
|
||
|
||
/// Returns a Tensor containing the element-wise conversion to another floating point type.
|
||
pub fn reducePrecision(self: Tensor, exponent_bits: i32, mantissa_bits: i32) Tensor {
|
||
meta.assert(self.dtype().isFloat(), "reducePrecision expects tensor type to be a float, got {}", .{self.dtype()});
|
||
meta.assert(1 <= exponent_bits, "reducePrecision expects 'exponent_bits' to be >= 1, got {}", .{exponent_bits});
|
||
meta.assert(0 <= mantissa_bits, "reducePrecision expects 'mantissa_bits' to be positive, got {}", .{mantissa_bits});
|
||
|
||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reducePrecision(exponent_bits={}, mantissa_bits={})", .{ exponent_bits, mantissa_bits });
|
||
const op = dialect.stablehlo.reduce_precision(self.getContext().mlirCtx(), self.value(), exponent_bits, mantissa_bits, loc);
|
||
return _result(self._shape, op.result(0));
|
||
}
|
||
|
||
inline fn convolution(self: Tensor, other: Tensor, opts: dialect.stablehlo.ConvolutionOpts, loc: mlir.Location) Tensor {
|
||
meta.assert(self.rank() == other.rank(), "convolution expects tensor ranks to match, got {} and {}", .{ self.rank(), other.rank() });
|
||
const N = self.rank();
|
||
meta.guard(opts.window_strides.len == N - 2, @src());
|
||
for (opts.window_strides) |s| meta.guard(0 < s, @src());
|
||
meta.guard(opts.lhs_dilation.len == N - 2, @src());
|
||
for (opts.lhs_dilation) |d| meta.guard(0 < d, @src());
|
||
meta.guard(opts.rhs_dilation.len == N - 2, @src());
|
||
for (opts.rhs_dilation) |d| meta.guard(0 < d, @src());
|
||
meta.guard(opts.window_reversal.len == N - 2, @src());
|
||
meta.guard(@rem(self.dim(opts.input_batch_dimension), opts.batch_group_count) == 0, @src());
|
||
meta.guard(@rem(self.dim(opts.input_feature_dimension), opts.feature_group_count) == 0, @src());
|
||
meta.guard(opts.input_spatial_dimensions.len == N - 2, @src());
|
||
meta.guard(opts.input_batch_dimension != opts.input_feature_dimension, @src());
|
||
meta.guard(0 <= opts.input_batch_dimension and opts.input_batch_dimension < N, @src());
|
||
meta.guard(0 <= opts.input_feature_dimension and opts.input_feature_dimension < N, @src());
|
||
for (opts.input_spatial_dimensions, 0..) |d, i| {
|
||
meta.guard(d != opts.input_batch_dimension, @src());
|
||
meta.guard(d != opts.input_feature_dimension, @src());
|
||
meta.guard(0 <= d and d < N, @src());
|
||
if (i < opts.input_spatial_dimensions.len - 1) continue;
|
||
meta.guard(std.mem.indexOfScalar(i64, opts.input_spatial_dimensions[i + 1 ..], d) == null, @src());
|
||
}
|
||
meta.guard(other.dim(opts.kernel_input_feature_dimension) == @divTrunc(self.dim(opts.input_feature_dimension), opts.feature_group_count), @src());
|
||
meta.guard(@rem(other.dim(opts.kernel_output_feature_dimension), opts.batch_group_count) == 0, @src());
|
||
meta.guard(@rem(other.dim(opts.kernel_output_feature_dimension), opts.feature_group_count) == 0, @src());
|
||
meta.guard(opts.kernel_spatial_dimensions.len == N - 2, @src());
|
||
meta.guard(opts.kernel_input_feature_dimension != opts.kernel_output_feature_dimension, @src());
|
||
meta.guard(0 <= opts.kernel_input_feature_dimension and opts.kernel_input_feature_dimension < N, @src());
|
||
meta.guard(0 <= opts.kernel_output_feature_dimension and opts.kernel_output_feature_dimension < N, @src());
|
||
for (opts.kernel_spatial_dimensions, 0..) |d, i| {
|
||
meta.guard(d != opts.kernel_input_feature_dimension, @src());
|
||
meta.guard(d != opts.kernel_output_feature_dimension, @src());
|
||
meta.guard(0 <= d and d < N, @src());
|
||
if (i < opts.kernel_spatial_dimensions.len - 1) continue;
|
||
meta.guard(std.mem.indexOfScalar(i64, opts.kernel_spatial_dimensions[i + 1 ..], d) == null, @src());
|
||
}
|
||
meta.guard(opts.output_spatial_dimensions.len == N - 2, @src());
|
||
meta.guard(opts.output_batch_dimension != opts.output_feature_dimension, @src());
|
||
meta.guard(0 <= opts.output_batch_dimension and opts.output_batch_dimension < N, @src());
|
||
meta.guard(0 <= opts.output_feature_dimension and opts.output_feature_dimension < N, @src());
|
||
for (opts.output_spatial_dimensions, 0..) |d, i| {
|
||
meta.guard(d != opts.output_batch_dimension, @src());
|
||
meta.guard(d != opts.output_feature_dimension, @src());
|
||
meta.guard(0 <= d and d < N, @src());
|
||
if (i < opts.output_spatial_dimensions.len - 1) continue;
|
||
meta.guard(std.mem.indexOfScalar(i64, opts.output_spatial_dimensions[i + 1 ..], d) == null, @src());
|
||
}
|
||
meta.guard(0 < opts.feature_group_count, @src());
|
||
meta.guard(0 < opts.batch_group_count, @src());
|
||
meta.guard(opts.feature_group_count == 1 or opts.batch_group_count == 1, @src());
|
||
var used_opts = opts;
|
||
used_opts.pad_shape = &.{ @intCast(N - 2), 2 };
|
||
used_opts.precision_config = &.{ .DEFAULT, .DEFAULT };
|
||
|
||
var new_shape = self._shape;
|
||
var res_dim: i64 = undefined;
|
||
|
||
for (0..N) |i| {
|
||
if (i == @as(usize, @intCast(opts.output_batch_dimension))) {
|
||
res_dim = @divTrunc(self.dim(opts.input_batch_dimension), opts.batch_group_count);
|
||
} else if (i == @as(usize, @intCast(opts.output_feature_dimension))) {
|
||
res_dim = other.dim(opts.kernel_output_feature_dimension);
|
||
} else {
|
||
// calculate spatial dimension value
|
||
const spatial_dim: usize = std.mem.indexOfScalar(i64, opts.output_spatial_dimensions, @as(i64, @intCast(i))).?;
|
||
const lhs_dim = opts.input_spatial_dimensions[spatial_dim];
|
||
const rhs_dim = opts.kernel_spatial_dimensions[spatial_dim];
|
||
const dilated_input_shape_lhs_dim: i64 = if (self.dim(lhs_dim) == 0) 0 else (self.dim(lhs_dim) - 1) * opts.lhs_dilation[spatial_dim] + 1;
|
||
const left_pad_value, const right_pad_value = if (opts.pad_value.len == 1)
|
||
.{ opts.pad_value[0], opts.pad_value[0] }
|
||
else
|
||
.{ opts.pad_value[2 * spatial_dim], opts.pad_value[2 * spatial_dim + 1] };
|
||
const padded_input_shape_lhs_dim = left_pad_value + dilated_input_shape_lhs_dim + right_pad_value;
|
||
const dilated_window_shape_lhs_dim: i64 = if (other.dim(rhs_dim) == 0) 0 else (other.dim(rhs_dim) - 1) * opts.rhs_dilation[spatial_dim] + 1;
|
||
const is_empty_window_lhs_dim = padded_input_shape_lhs_dim == 0 or dilated_window_shape_lhs_dim > padded_input_shape_lhs_dim;
|
||
res_dim = if (is_empty_window_lhs_dim) 0 else @divTrunc(padded_input_shape_lhs_dim - dilated_window_shape_lhs_dim, opts.window_strides[spatial_dim]) + 1;
|
||
}
|
||
|
||
new_shape = new_shape.set(i, res_dim);
|
||
}
|
||
|
||
// inferred shape '[1, 256, 1, 12008]' is incompatible with return type of operation '[1, 256, 1, 11978]'
|
||
const op = dialect.stablehlo.convolution(
|
||
self.getContext().mlirCtx(),
|
||
self.value(),
|
||
other.value(),
|
||
used_opts,
|
||
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape).as(mlir.Type).?,
|
||
loc,
|
||
);
|
||
|
||
return _result(new_shape, op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the result of the 1D convolution of 'input' by 'kernel'.
|
||
pub fn conv1d(
|
||
input: Tensor,
|
||
kernel: Tensor,
|
||
opts: struct {
|
||
window_strides: i64 = 1,
|
||
padding: []const i64 = &.{ 0, 0 },
|
||
lhs_dilation: i64 = 1,
|
||
rhs_dilation: i64 = 1,
|
||
window_reversal: bool = false,
|
||
input_batch_dimension: i64 = 0,
|
||
input_feature_dimension: i64 = 1,
|
||
input_spatial_dimensions: i64 = 2,
|
||
kernel_output_feature_dimension: i64 = 0,
|
||
kernel_input_feature_dimension: i64 = 1,
|
||
kernel_spatial_dimensions: i64 = 2,
|
||
output_batch_dimension: i64 = 0,
|
||
output_feature_dimension: i64 = 1,
|
||
output_spatial_dimensions: i64 = 2,
|
||
feature_group_count: i64 = 1,
|
||
batch_group_count: i64 = 1,
|
||
},
|
||
) Tensor {
|
||
const loc = input.getContext().mlirCtx().location(@src()).namedFmt(input.getContext().mlirCtx(), "opts={}", .{opts});
|
||
return input.convolution(kernel, .{
|
||
.window_strides = &.{opts.window_strides},
|
||
.pad_value = opts.padding,
|
||
.lhs_dilation = &.{opts.lhs_dilation},
|
||
.rhs_dilation = &.{opts.rhs_dilation},
|
||
.window_reversal = &.{opts.window_reversal},
|
||
.input_batch_dimension = opts.input_batch_dimension,
|
||
.input_feature_dimension = opts.input_feature_dimension,
|
||
.input_spatial_dimensions = &.{opts.input_spatial_dimensions},
|
||
.kernel_input_feature_dimension = opts.kernel_input_feature_dimension,
|
||
.kernel_output_feature_dimension = opts.kernel_output_feature_dimension,
|
||
.kernel_spatial_dimensions = &.{opts.kernel_spatial_dimensions},
|
||
.output_batch_dimension = opts.output_batch_dimension,
|
||
.output_feature_dimension = opts.output_feature_dimension,
|
||
.output_spatial_dimensions = &.{opts.output_spatial_dimensions},
|
||
.feature_group_count = opts.feature_group_count,
|
||
.batch_group_count = opts.batch_group_count,
|
||
}, loc);
|
||
}
|
||
|
||
/// Returns a Tensor containing the result of the 2D convolution of 'input' by 'kernel'.
|
||
/// Defaults values correspond to a (B, C_in, W, H) image, (C_out, C_in, W, H) kernel weights and (B, C_out, W, H) output.
|
||
pub fn conv2d(
|
||
input: Tensor,
|
||
kernel: Tensor,
|
||
opts: struct {
|
||
window_strides: []const i64 = &.{ 1, 1 },
|
||
padding: []const i64 = &.{ 0, 0, 0, 0 },
|
||
lhs_dilation: []const i64 = &.{ 1, 1 },
|
||
rhs_dilation: []const i64 = &.{ 1, 1 },
|
||
window_reversal: []const bool = &.{ false, false },
|
||
input_batch_dimension: i64 = 0,
|
||
input_feature_dimension: i64 = 1,
|
||
input_spatial_dimensions: []const i64 = &.{ 2, 3 },
|
||
kernel_input_feature_dimension: i64 = 1,
|
||
kernel_output_feature_dimension: i64 = 0,
|
||
kernel_spatial_dimensions: []const i64 = &.{ 2, 3 },
|
||
output_batch_dimension: i64 = 0,
|
||
output_feature_dimension: i64 = 1,
|
||
output_spatial_dimensions: []const i64 = &.{ 2, 3 },
|
||
feature_group_count: i64 = 1,
|
||
batch_group_count: i64 = 1,
|
||
},
|
||
) Tensor {
|
||
const loc = input.getContext().mlirCtx().location(@src()).namedFmt(input.getContext().mlirCtx(), "opts={}", .{opts});
|
||
return input.convolution(kernel, .{
|
||
.window_strides = opts.window_strides,
|
||
.pad_value = opts.padding,
|
||
.lhs_dilation = opts.lhs_dilation,
|
||
.rhs_dilation = opts.rhs_dilation,
|
||
.window_reversal = opts.window_reversal,
|
||
.input_batch_dimension = opts.input_batch_dimension,
|
||
.input_feature_dimension = opts.input_feature_dimension,
|
||
.input_spatial_dimensions = opts.input_spatial_dimensions,
|
||
.kernel_input_feature_dimension = opts.kernel_input_feature_dimension,
|
||
.kernel_output_feature_dimension = opts.kernel_output_feature_dimension,
|
||
.kernel_spatial_dimensions = opts.kernel_spatial_dimensions,
|
||
.output_batch_dimension = opts.output_batch_dimension,
|
||
.output_feature_dimension = opts.output_feature_dimension,
|
||
.output_spatial_dimensions = opts.output_spatial_dimensions,
|
||
.feature_group_count = opts.feature_group_count,
|
||
.batch_group_count = opts.batch_group_count,
|
||
}, loc);
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise addition of the input Tensors.
|
||
pub fn add(self: Tensor, other: Tensor) Tensor {
|
||
return binaryOp("add", dialect.stablehlo.add)(self, other);
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise subtraction of the input Tensors.
|
||
pub fn sub(self: Tensor, other: Tensor) Tensor {
|
||
return binaryOp("subtract", dialect.stablehlo.subtract)(self, other);
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise multiplication of the input Tensors.
|
||
pub fn mul(self: Tensor, other: Tensor) Tensor {
|
||
return binaryOp("mul", dialect.stablehlo.multiply)(self, other);
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise division of the input Tensors.
|
||
pub fn div(self: Tensor, other: Tensor) Tensor {
|
||
return binaryOp("div", dialect.stablehlo.divide)(self, other);
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise exponentiation of the input Tensors.
|
||
pub fn pow(self: Tensor, other: Tensor) Tensor {
|
||
return binaryOp("pow", dialect.stablehlo.power)(self, other);
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise maximum operation of the input Tensors.
|
||
pub fn maximum(self: Tensor, other: Tensor) Tensor {
|
||
return binaryOp("maximum", dialect.stablehlo.maximum)(self, other);
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise minimum operation of the input Tensors.
|
||
pub fn minimum(self: Tensor, other: Tensor) Tensor {
|
||
return binaryOp("minimum", dialect.stablehlo.minimum)(self, other);
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise addition of the input Tensor with a constant.
|
||
pub fn addConstant(self: Tensor, b: anytype) Tensor {
|
||
return self.add(Tensor.scalar(b, self.dtype()));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise division of the input Tensor by a constant.
|
||
pub fn divByConst(self: Tensor, b: anytype) Tensor {
|
||
return self.div(Tensor.scalar(b, self.dtype()));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise multiplication of the input Tensor by a constant.
|
||
pub inline fn scale(self: Tensor, val: anytype) Tensor {
|
||
return self.mul(Tensor.scalar(val, self.dtype()));
|
||
}
|
||
|
||
pub const LogicalOp = enum { OR, XOR, AND };
|
||
|
||
/// Returns a Tensor containing the element-wise logical operation of the input Tensors.
|
||
pub fn logical(self: Tensor, comptime logical_op: LogicalOp, other: Tensor) Tensor {
|
||
return switch (logical_op) {
|
||
.OR => binaryOp("or", dialect.stablehlo.or_)(self, other),
|
||
.XOR => binaryOp("xor", dialect.stablehlo.xor)(self, other),
|
||
.AND => binaryOp("and", dialect.stablehlo.and_)(self, other),
|
||
};
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise floor operation of the input Tensor.
|
||
pub fn floor(self: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
return _result(self._shape, dialect.stablehlo.floor(self.getContext().mlirCtx(), self.value(), loc).result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise ceil operation of the input Tensor.
|
||
pub fn ceil(self: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
return _result(self._shape, dialect.stablehlo.ceil(self.getContext().mlirCtx(), self.value(), loc).result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise conversion to another type.
|
||
pub fn convert(self: Tensor, dt: DataType) Tensor {
|
||
if (dt == self.dtype()) {
|
||
return self;
|
||
}
|
||
|
||
const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), dt)).as(mlir.Type).?;
|
||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "dtype={}", .{dt});
|
||
|
||
const op = dialect.stablehlo.convert(self.getContext().mlirCtx(), self.value(), res_type, loc);
|
||
return _result(self._shape.withDtype(dt), op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise rounding operation of the input Tensor.
|
||
pub fn round(self: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const sine_op = dialect.stablehlo.round_nearest_even(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape, sine_op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise clamping operation of the input Tensor.
|
||
pub fn clamp(self: Tensor, min_: Tensor, max_: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.clamp(self.getContext().mlirCtx(), min_.value(), self.value(), max_.value(), loc);
|
||
return _result(self._shape, op.result(0));
|
||
}
|
||
|
||
/// See torch.matmul
|
||
pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor {
|
||
return @import("torch.zig").matmul(lhs, rhs);
|
||
}
|
||
|
||
/// Matrix multiplication, where contracting axes are specified using their tags.
|
||
/// eg dot(.{ .a, .b, .c }, .{ .a, .c, .d }, .{ .c }) -> .{ .a, .c, .d }
|
||
/// Axes with the same tag on both sides, and which aren't contracting,
|
||
/// are considered "batching axes".
|
||
pub fn dot(lhs: Tensor, rhs: Tensor, comptime contracting: anytype) Tensor {
|
||
var contracting_axes: [contracting.len][2]i8 = undefined;
|
||
inline for (contracting, 0..) |c, i| {
|
||
contracting_axes[i] = .{ lhs.axis(c), rhs.axis(c) };
|
||
}
|
||
|
||
var batching_axes: [MAX_RANK][2]i8 = undefined;
|
||
var n_batching: u8 = 0;
|
||
for (lhs._shape.tags(), 0..) |l, li| {
|
||
meta.assert(l != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, lhs });
|
||
|
||
for (rhs._shape.tags(), 0..) |r, ri| {
|
||
meta.assert(r != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, rhs });
|
||
|
||
if (l == r) {
|
||
for (contracting_axes) |ct| {
|
||
if (l == lhs._shape.tag(ct[0])) {
|
||
break;
|
||
}
|
||
} else {
|
||
// tag is both in lhs and rhs but not in contracting -> it's a batching dim.
|
||
batching_axes[n_batching] = .{ @intCast(li), @intCast(ri) };
|
||
n_batching += 1;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
return dotGeneral(lhs, rhs, contracting_axes[0..], batching_axes[0..n_batching]);
|
||
}
|
||
|
||
test dot {
|
||
const zml = @import("zml.zig");
|
||
const platform = zml.testing.env();
|
||
|
||
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
||
defer comp.deinit();
|
||
|
||
comp.activate();
|
||
defer comp.deactivate();
|
||
|
||
inline for (.{
|
||
.{ .{ .c = 20 }, .{ .c = 20 }, .{.c}, .{} },
|
||
.{
|
||
.{ .a = 20, .b = 21, .c = 22 },
|
||
.{ .a = 20, .d = 23, .c = 22 },
|
||
.{.c},
|
||
.{ .a = 20, .b = 21, .d = 23 },
|
||
},
|
||
.{
|
||
.{ .a = 20, .b = 21, .c = 22 },
|
||
.{ .c = 22, .d = 23, .e = 24 },
|
||
.{.c},
|
||
.{ .a = 20, .b = 21, .d = 23, .e = 24 },
|
||
},
|
||
.{
|
||
.{ .a = 20, .b = 21, .c = 22 },
|
||
.{ .c = 22, .d = 23, .a = 20 },
|
||
.{ .c, .a },
|
||
.{ .b = 21, .d = 23 },
|
||
},
|
||
}) |testcase| {
|
||
const x_shape, const y_shape, const ctr, const z_shape = testcase;
|
||
const x = Tensor.constant(x_shape, .{ .f32 = 0.0 });
|
||
const y = Tensor.constant(y_shape, .{ .f32 = 0.0 });
|
||
const z = x.dot(y, ctr);
|
||
|
||
try zml.testing.expectEqualShapes(Shape.init(z_shape, .f32), z.shape());
|
||
}
|
||
}
|
||
|
||
/// Generalized matrix multiplication of two tensors along the specified axes.
|
||
/// In this version batching dimensions need to be explicitly specified.
|
||
/// The result shape is made of (batching_axes ++ lhs_result_axes ++ rhs_result_axes.
|
||
/// Where "result axes" are non-contracting, non-batching axes of each input tensor.
|
||
pub fn dotGeneral(
|
||
lhs: Tensor,
|
||
rhs: Tensor,
|
||
contracting_axes: []const [2]i8,
|
||
batching_axes: []const [2]i8,
|
||
) Tensor {
|
||
meta.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {} and {}", .{ lhs.dtype(), rhs.dtype() });
|
||
|
||
const Axes = std.BoundedArray(i64, MAX_RANK);
|
||
|
||
var res_shape: Shape = .{ ._dtype = lhs.dtype() };
|
||
// Validate batching axes
|
||
var lhs_batching_axes: Axes = .{};
|
||
var rhs_batching_axes: Axes = .{};
|
||
for (batching_axes) |b_axes| {
|
||
const l, const r = b_axes;
|
||
meta.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs });
|
||
var t = lhs._shape.tag(l);
|
||
if (t == Shape.TagUnknown) t = rhs._shape.tag(r);
|
||
res_shape = res_shape.appendDim(lhs._shape.dim(l), t);
|
||
lhs_batching_axes.appendAssumeCapacity(lhs._shape.axis(l));
|
||
rhs_batching_axes.appendAssumeCapacity(rhs._shape.axis(r));
|
||
}
|
||
|
||
// Validate contracting axes
|
||
var lhs_contracting_axes: Axes = .{};
|
||
var rhs_contracting_axes: Axes = .{};
|
||
for (contracting_axes) |c_axes| {
|
||
const l, const r = c_axes;
|
||
meta.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects contracting dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs });
|
||
lhs_contracting_axes.appendAssumeCapacity(lhs._shape.axis(l));
|
||
rhs_contracting_axes.appendAssumeCapacity(rhs._shape.axis(r));
|
||
}
|
||
|
||
// Result shape is obtained by concatenating batching dimensions, (already done)
|
||
// then dimensions from lhs axes that aren't contracting nor batching,
|
||
// then dimensions from rhs axes that aren't contracting nor batching.
|
||
for (0..lhs.rank()) |l| {
|
||
if (std.mem.indexOfScalar(i64, lhs_contracting_axes.constSlice(), @intCast(l))) |_| {
|
||
continue;
|
||
}
|
||
if (std.mem.indexOfScalar(i64, lhs_batching_axes.constSlice(), @intCast(l))) |_| {
|
||
continue;
|
||
}
|
||
res_shape = res_shape.appendDim(lhs._shape.dim(l), lhs._shape.tag(l));
|
||
}
|
||
for (0..rhs.rank()) |r| {
|
||
if (std.mem.indexOfScalar(i64, rhs_contracting_axes.constSlice(), @intCast(r))) |_| {
|
||
continue;
|
||
}
|
||
if (std.mem.indexOfScalar(i64, rhs_batching_axes.constSlice(), @intCast(r))) |_| {
|
||
continue;
|
||
}
|
||
res_shape = res_shape.appendDim(rhs._shape.dim(r), rhs._shape.tag(r));
|
||
}
|
||
|
||
const loc = lhs.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.dot_general(
|
||
lhs.getContext().mlirCtx(),
|
||
lhs.value(),
|
||
rhs.value(),
|
||
mlir.ext.mlirType(lhs.getContext().mlirCtx(), res_shape),
|
||
loc,
|
||
.{
|
||
.lhs_batching_dimensions = lhs_batching_axes.constSlice(),
|
||
.rhs_batching_dimensions = rhs_batching_axes.constSlice(),
|
||
.lhs_contracting_dimensions = lhs_contracting_axes.constSlice(),
|
||
.rhs_contracting_dimensions = rhs_contracting_axes.constSlice(),
|
||
.precision = &.{ .DEFAULT, .DEFAULT },
|
||
},
|
||
);
|
||
return _result(res_shape, op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the sigmoid function applied to each element of the input Tensor.
|
||
pub fn sigmoid(self: Tensor) Tensor {
|
||
// until the metal plugin supports `stablehlo.logistics`, implement in the way JAX does it
|
||
const one = Tensor.constant(&.{}, self.dtype().one()).broadcast(self._shape, &.{});
|
||
return one.div(one.add(self.negate().exp()));
|
||
}
|
||
|
||
/// Returns a Tensor containing the ReLU activation function applied to each element of the input Tensor.
|
||
pub fn relu(self: Tensor) Tensor {
|
||
return self.maximum(Tensor.constant(self.dims(), self.dtype().zero()));
|
||
}
|
||
|
||
/// Returns a Tensor containing the leaky-ReLU activation function applied to each element of the input Tensor.
|
||
///
|
||
/// LeakyReLU(x) = max(0,x) + negative_slope * min(0,x)
|
||
/// ref: https://paperswithcode.com/method/leaky-relu
|
||
pub fn leakyReLU(self: Tensor, negative_slope: f32) Tensor {
|
||
const below_zero = self.scale(negative_slope).minimum(Tensor.scalar(0, self.dtype()));
|
||
return self.relu().add(below_zero);
|
||
}
|
||
|
||
test leakyReLU {
|
||
const zml = @import("zml.zig");
|
||
const platform = zml.testing.env();
|
||
|
||
const input = try zml.Buffer.fromSlice(platform, .{2}, &[_]f32{ -0.6884, 1.6795 });
|
||
const res = try zml.testing.compileAndCall(platform, leakyReLU, .{ input, 0.1 });
|
||
|
||
const expectation = zml.HostBuffer.fromArray(&[2]f32{ -0.0688, 1.6795 });
|
||
try zml.testing.expectClose(expectation, res, 1e-4);
|
||
}
|
||
|
||
/// Returns a Tensor containing the SwiGLU activation function applied to the input Tensor.
|
||
pub fn swiglu(self: Tensor, beta: f32, w: Tensor, b: Tensor) Tensor {
|
||
const sigmoid_tensor = self.mul(Tensor.constant(self._shape, Data.init(self.dtype(), beta))).sigmoid();
|
||
const one_minus_sigmoid_tensor = Tensor.constant(self._shape, Data.init(self.dtype(), 1)).sub(sigmoid_tensor);
|
||
|
||
return self.mul(sigmoid_tensor).add(one_minus_sigmoid_tensor.mul(w.matmul(self).add(b)));
|
||
}
|
||
|
||
/// Returns a Tensor containing the Gaussian Error Linear Units (GeLU) activation function applied to each element of the input Tensor.
|
||
///
|
||
/// We use an approximation of the erf function using tanh:
|
||
/// gelu(x) ≃ 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
|
||
/// see: https://paperswithcode.com/method/gelu
|
||
pub fn gelu(x: Tensor) Tensor {
|
||
const scaled_x_cube = x.mul(x).mul(x).scale(0.044715);
|
||
const one = Tensor.constant(x._shape, x.dtype().one());
|
||
const one_plus_tanh = Tensor.add(x, scaled_x_cube).scale(std.math.sqrt(2.0 / std.math.pi)).tanh().add(one);
|
||
return one_plus_tanh.mul(x).scale(0.5);
|
||
}
|
||
|
||
/// Returns a Tensor containing an approximation of the Gaussian Error Linear Units (GeLU) activation function applied to each element of the input Tensor.
|
||
///
|
||
/// It's an even more crude approximation than gelu.
|
||
pub fn quickGelu(x: Tensor) Tensor {
|
||
return x.scale(1.702).sigmoid().mul(x);
|
||
}
|
||
|
||
/// Returns a Tensor containing the Sigmoid Linear Unit (SiLU) activation function applied to each element of the input Tensor.
|
||
///
|
||
/// silu(x) = x σ(x)
|
||
/// https://paperswithcode.com/method/silu
|
||
pub fn silu(x: Tensor) Tensor {
|
||
return x.mul(x.sigmoid());
|
||
}
|
||
|
||
/// Returns a Tensor containing the softmax function applied to each element of the input Tensor.
|
||
pub fn softmax(self: Tensor, axis_: anytype) Tensor {
|
||
const a = self.axis(axis_);
|
||
const exp_diff_max = self.sub(self.max(a).broad(self._shape)).exp();
|
||
return exp_diff_max.div(exp_diff_max.sum(a).broad(self._shape));
|
||
}
|
||
|
||
/// Returns a Tensor containing the log of the sum of exponential over the given axis.
|
||
pub fn logSumExp(self: Tensor, axis_: anytype) Tensor {
|
||
const a = self.axis(axis_);
|
||
// stabilization: shift `self` by it's max value before passing to exponent.
|
||
const M = self.max(a);
|
||
const log_sum_exp = log(sum(exp(self.sub(M.broad(self._shape))), a));
|
||
// restore the shift again
|
||
return M.add(log_sum_exp);
|
||
}
|
||
|
||
/// Returns a Tensor containing the sum of elements over the given axis.
|
||
pub fn sum(self: Tensor, axis_: anytype) Tensor {
|
||
const a = self.axis(axis_);
|
||
return ops.reduce(
|
||
struct {
|
||
pub fn acc(x: Tensor, res: Tensor) Tensor {
|
||
return res.add(x.convert(res.dtype()));
|
||
}
|
||
}.acc,
|
||
self,
|
||
Tensor.scalar(0, self.dtype()),
|
||
&.{a},
|
||
);
|
||
}
|
||
|
||
/// Returns a Tensor containing the mean of elements over the given axis.
|
||
pub fn mean(self: Tensor, axis_: anytype) Tensor {
|
||
return self.sum(axis_).divByConst(self.dim(axis_));
|
||
}
|
||
|
||
/// Returns a transposed Tensor computed using the given axes.
|
||
pub fn transpose(self: Tensor, axes_: anytype) Tensor {
|
||
const axes__ = self.axes(axes_).constSlice();
|
||
const default_perm = [MAX_RANK]i64{ 7, 6, 5, 4, 3, 2, 1, 0 };
|
||
const no_op = [MAX_RANK]i64{ 0, 1, 2, 3, 4, 5, 6, 7 };
|
||
|
||
const permutation: []const i64 = if (axes__.len == 0)
|
||
default_perm[MAX_RANK - self.rank() ..]
|
||
else
|
||
toI64(axes__);
|
||
|
||
meta.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {} and {}", .{ self.rank(), permutation.len });
|
||
|
||
if (std.mem.eql(i64, permutation, no_op[0..self.rank()])) {
|
||
return self;
|
||
}
|
||
|
||
const res_shape = self._shape.transpose(permutation);
|
||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "tr({any})", .{axes_});
|
||
const op = dialect.stablehlo.transpose(
|
||
self.getContext().mlirCtx(),
|
||
self.value(),
|
||
mlir.ext.mlirType(self.getContext().mlirCtx(), res_shape),
|
||
loc,
|
||
.{ .permutation = toI64(permutation) },
|
||
);
|
||
return _result(res_shape, op.result(0));
|
||
}
|
||
|
||
pub fn swapAxes(self: Tensor, a: anytype, b: anytype) Tensor {
|
||
if (self.axis(a) == self.axis(b)) return self;
|
||
var perm: Shape.AxesArray = .{};
|
||
for (0..self.rank()) |i| {
|
||
perm.appendAssumeCapacity(@intCast(i));
|
||
}
|
||
perm.set(self.axis(a), self.axis(b));
|
||
perm.set(self.axis(b), self.axis(a));
|
||
return self.transpose(perm.constSlice());
|
||
}
|
||
|
||
/// Returns a Tensor with the given axis unflattened.
|
||
///
|
||
/// unflatten((d0, d1, axis_m, d3), 2, n) -> (d0, d1, n, d2_m, d3)
|
||
pub fn unflatten(self: Tensor, axis_: i8, n: i64) Tensor {
|
||
meta.assert(self.rank() < Tensor.MAX_RANK, "unflatten expects input tensor rank to be less than {}, got {}", .{ Tensor.MAX_RANK, self.rank() });
|
||
|
||
const a = if (axis_ >= 0) self.axis(axis_) else self.axis(axis_) + 1;
|
||
const new_dim = std.math.divExact(i64, self.dim(a), n) catch std.debug.panic("unflatten expects chosen dimension to be divisible by 'n' but {} is not divisible by {}", .{ self.dim(a), n });
|
||
const new_shape = self._shape.set(a, n).insert(a + 1, .{ ._ = new_dim });
|
||
|
||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "axis={}, n={}", .{ axis_, n });
|
||
const reshaped_val = dialect.stablehlo.reshape(
|
||
self.getContext().mlirCtx(),
|
||
self.value(),
|
||
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape),
|
||
loc,
|
||
);
|
||
return _result(new_shape, reshaped_val.result(0));
|
||
}
|
||
|
||
/// Splits the given axis in several axes.
|
||
/// eg: `Tensor.init(.{ .a = 10, .b = 3 }).split(.a, .{.a1 = 5, .a2 = 2});`
|
||
/// The number of elements in the split shape must match the number of element
|
||
/// in the target axis.
|
||
pub fn splitAxis(self: Tensor, ax: anytype, split_shape: anytype) Tensor {
|
||
const new_shape = self._shape.splitAxis(ax, split_shape);
|
||
|
||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "splitAxis({}, {any})", .{ ax, split_shape });
|
||
const reshaped_val = dialect.stablehlo.reshape(
|
||
self.getContext().mlirCtx(),
|
||
self.value(),
|
||
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape),
|
||
loc,
|
||
);
|
||
return _result(new_shape, reshaped_val.result(0));
|
||
}
|
||
|
||
/// Merges two or more contiguous axes into one axis.
|
||
pub fn merge(self: Tensor, merges_: anytype) Tensor {
|
||
return self.reshape(self._shape.mergeAxes(merges_));
|
||
}
|
||
|
||
/// Merges two or more non-contiguous axes into one axis.
|
||
/// Will make a transpose if needed.
|
||
/// .{ .a, .b, .c }.mergeTranspose(.{ .a, .c }, .ac) -> .{ .b, .ac }
|
||
pub fn mergeTranspose(self: Tensor, axes_: anytype, merged: EnumLiteral) Tensor {
|
||
const cont = self.contiguous(axes_);
|
||
return cont.reshape(cont._shape.mergeAxis(axes_, merged));
|
||
}
|
||
|
||
/// Transposes the input Tensor, such has the given axes end up in contiguous position.
|
||
/// .{ .a, .b, .c, .d }.contiguous(.{ .c, .a }) -> .{ .b, .d, .c, .a }
|
||
pub fn contiguous(self: Tensor, axes_: anytype) Tensor {
|
||
const perm = self._shape.contiguousPerm(axes_);
|
||
return self.transpose(perm.constSlice());
|
||
}
|
||
|
||
/// Flattens the given axis and the next one, into one new axis.
|
||
pub fn flatten(self: Tensor, axis_: anytype) Tensor {
|
||
const old_shape = self._shape;
|
||
const a = self.axis(axis_);
|
||
// meta.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis });
|
||
const new_shape = old_shape.remove(a + 1).set(a, old_shape.dim(a) * old_shape.dim(a + 1));
|
||
|
||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "axis={}", .{axis_});
|
||
|
||
const reshaped_val = dialect.stablehlo.reshape(
|
||
self.getContext().mlirCtx(),
|
||
self.value(),
|
||
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape),
|
||
loc,
|
||
);
|
||
// log.debug("flatten({d}, {d}) -> {d}", .{ self.dims(), axis_, new_shape[0 .. self.rank() - 1] });
|
||
return _result(new_shape, reshaped_val.result(0));
|
||
}
|
||
|
||
pub inline fn flattenAll(self: Tensor) Tensor {
|
||
return self.reshape(.{self.count()});
|
||
}
|
||
|
||
pub const Slice = struct {
|
||
start: i64 = 0,
|
||
end: ?i64 = null,
|
||
step: i64 = 1,
|
||
};
|
||
|
||
/// Slices the input Tensor over the given axis using the given parameters.
|
||
pub fn slice1d(self: Tensor, axis_: anytype, s: Slice) Tensor {
|
||
var slices = [_]Slice{.{}} ** MAX_RANK;
|
||
slices[self.axis(axis_)] = s;
|
||
return self.slice(slices[0..self.rank()]);
|
||
}
|
||
|
||
/// Slices the input Tensor using the given parameters.
|
||
pub fn slice(self: Tensor, slices: []const Slice) Tensor {
|
||
var start_indices: [MAX_RANK]i64 = undefined;
|
||
var strides: [MAX_RANK]i64 = undefined;
|
||
var limit_indices: [MAX_RANK]i64 = undefined;
|
||
var res_shape: Shape = self._shape;
|
||
|
||
for (slices, 0..) |s, a| {
|
||
meta.assert(s.step > 0, "slice expects 'step' to be positive, got {} at index {}", .{ s.step, a });
|
||
|
||
const args: Slice = .{
|
||
.start = self.wrapIndex(a, s.start),
|
||
.end = if (s.end) |end| self.wrapIndex(a, end) else self.dim(a),
|
||
.step = s.step,
|
||
};
|
||
start_indices[a] = args.start;
|
||
limit_indices[a] = args.end.?;
|
||
strides[a] = args.step;
|
||
res_shape = res_shape.setDim(a, std.math.divCeil(i64, args.end.? - args.start, args.step) catch unreachable);
|
||
}
|
||
|
||
const mlir_ctx = self.getContext().mlirCtx();
|
||
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "slices={any}", .{slices});
|
||
const result_type = mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type).?;
|
||
const slice_op = dialect.stablehlo.slice(
|
||
mlir_ctx,
|
||
self.value(),
|
||
start_indices[0..self.rank()],
|
||
limit_indices[0..self.rank()],
|
||
strides[0..self.rank()],
|
||
result_type,
|
||
loc,
|
||
);
|
||
|
||
return _result(res_shape, slice_op.result(0));
|
||
}
|
||
|
||
test slice {
|
||
const zml = @import("zml.zig");
|
||
const platform = zml.testing.env();
|
||
|
||
const x = try zml.Buffer.fromSlice(platform, .{ 2, 5 }, &[_]f32{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
|
||
|
||
// Wrap slice1d to hide the anytype in the signature.
|
||
const Local = struct {
|
||
pub fn slice1dAxis(input: Tensor, ax: i8, slice_: Tensor.Slice) Tensor {
|
||
return input.slice1d(ax, slice_);
|
||
}
|
||
};
|
||
|
||
{
|
||
const res = try zml.testing.compileAndCallWithTensors(platform, Local.slice1dAxis, .{ x.shape(), 0, .{ .end = 1 } }, .{ x, 0, .{ .end = 1 } });
|
||
try testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32));
|
||
}
|
||
{
|
||
const res = try zml.testing.compileAndCallWithTensors(platform, Local.slice1dAxis, .{ x.shape(), 1, .{ .start = 1, .step = 2 } }, .{ x, 0, .{ .start = 1, .step = 2 } });
|
||
try testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32));
|
||
}
|
||
{
|
||
const res = try zml.testing.compileAndCallWithTensors(platform, Local.slice1dAxis, .{ x.shape(), -1, .{ .start = -2 } }, .{ x, 0, .{ .start = -2 } });
|
||
try testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32));
|
||
}
|
||
}
|
||
|
||
inline fn wrapIndex(self: Tensor, axis_: usize, idx: i64) i64 {
|
||
return if (idx < 0) self.dim(axis_) + idx else idx;
|
||
}
|
||
|
||
pub fn choose1d(self: Tensor, axis_: i64, i: i64) Tensor {
|
||
return self.slice1d(axis_, .{ .start = i, .end = i + 1 }).squeeze(axis_);
|
||
}
|
||
|
||
/// Concatenates the input Tensors along the given axis.
|
||
pub fn concatenate(tensors: []const Tensor, axis_: anytype) Tensor {
|
||
meta.assert(tensors.len <= 32, "concatenate only supports up to 32 tensors, got {}", .{tensors.len});
|
||
var buffer: [32]mlir.Value = undefined;
|
||
std.debug.assert(tensors.len <= buffer.len);
|
||
std.debug.assert(tensors.len > 0);
|
||
const a = tensors[0].axis(axis_);
|
||
// TODO(Corendos): Check that tensor axes match.
|
||
|
||
var concatenated_dim: i64 = 0;
|
||
for (tensors, 0..) |t, i| {
|
||
buffer[i] = t.value();
|
||
concatenated_dim += t.dim(a);
|
||
}
|
||
|
||
const res_shape = tensors[0]._shape.set(a, concatenated_dim);
|
||
const loc = tensors[0].getContext().mlirCtx().location(@src()).namedFmt(tensors[0].getContext().mlirCtx(), "axis={}", .{axis_});
|
||
const op = dialect.stablehlo.concatenate(tensors[0].getContext().mlirCtx(), buffer[0..tensors.len], a, loc);
|
||
// log.debug("concatenate({}, {}, {d}) -> {d}", .{ tensors[0], tensors[1], a, res_shape });
|
||
return _result(res_shape, op.result(0));
|
||
}
|
||
|
||
/// Concatenates the input Tensors along a new axis. The Tensors must have the same shape.
|
||
/// For x, y, z of shape .{ .a = 10, .b = 11, .c = 12 }:
|
||
/// - Tensor.stack(&.{x, y, z}, .b, .layers) -> .{ .a, .layers, .b, .c }
|
||
/// - Tensor.stack(&.{x, y, z}, 1, .layers) -> .{ .a, .layers, .b, .c }
|
||
/// - Tensor.stack(&.{x, y, z}, .last, .layers) -> .{ .a, .b, .c, .layers }
|
||
pub fn stack(tensors: []const Tensor, axis_: anytype, tag: anytype) Tensor {
|
||
// Note: we could ask the compilation context for some memory instead of stack allocating
|
||
meta.assert(tensors.len <= 32, "stack only supports up to 32 tensors, got {}", .{tensors.len});
|
||
|
||
const shape0 = tensors[0]._shape;
|
||
const res_shape = shape0.insertTag(axis_, 1, tag);
|
||
|
||
for (tensors[1..]) |tensor| {
|
||
meta.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {} and {}", .{ tensor._shape, shape0 });
|
||
}
|
||
|
||
var reshaped: [32]Tensor = undefined;
|
||
for (tensors, 0..) |tensor, i| {
|
||
reshaped[i] = tensor.reshape(res_shape);
|
||
}
|
||
|
||
// Be careful here: we need to resolve ax before calling concatenate,
|
||
// because we added an axis, so all
|
||
const ax = if (@TypeOf(axis_) == EnumLiteral and axis_ == .last)
|
||
shape0.rank()
|
||
else
|
||
shape0.axis(axis_);
|
||
|
||
return Tensor.concatenate(reshaped[0..tensors.len], ax);
|
||
}
|
||
|
||
/// Repeats a Tensor several times along the given axis.
|
||
///
|
||
/// * repeat1d(x, concat(&.{x, x, x, x}, axis);
|
||
/// * repeat1d([0, 1, 2, 3], 0, 2) = [0, 1, 2, 3, 0, 1, 2, 3]
|
||
pub fn repeat1d(self: Tensor, axis_: anytype, n_rep: u63) Tensor {
|
||
if (n_rep == 1) {
|
||
return self;
|
||
}
|
||
|
||
const a = self.axis(axis_);
|
||
const broadshape = self._shape.insert(a + 1, .{n_rep});
|
||
const repeat_dims = Shape.range(self.rank() + 1, self.dtype()).remove(a + 1);
|
||
|
||
var res = self.broadcast(broadshape, repeat_dims.dims()).flatten(a);
|
||
// Restor the tag that has been lost by flatten.
|
||
res._shape._tags.set(a, self._shape.tag(a));
|
||
|
||
return res;
|
||
}
|
||
|
||
/// Repeats a Tensor several times along the given axes.
|
||
pub fn repeat(self: Tensor, n_reps: []const u63) Tensor {
|
||
// TODO: this should support the tagged syntax: x.repeat(.{ .a = 3, .b = 2});
|
||
meta.assert(n_reps.len == self.rank(), "repeat expects tensor rank and 'n_reps' length to be equal, got {} and {}", .{ self.rank(), n_reps.len });
|
||
|
||
var res = self;
|
||
for (n_reps, 0..) |n_rep, a| {
|
||
if (n_rep == 1) continue;
|
||
|
||
res = res.repeat1d(a, n_rep);
|
||
}
|
||
return res;
|
||
}
|
||
|
||
/// Repeats in line each value along the given axis.
|
||
///
|
||
/// * stutter1d([0, 1, 2, 3], 0, 2) = [0, 0, 1, 1, 2, 2, 3, 3]
|
||
pub fn stutter1d(self: Tensor, axis_: i64, n_rep: u63) Tensor {
|
||
const a = self.axis(axis_);
|
||
const broadshape = self._shape.insert(a + 1, .{n_rep});
|
||
const stutter_dims = Shape.range(self.rank() + 1, self.dtype()).remove(a + 1);
|
||
|
||
return self.broadcast(broadshape, stutter_dims.dims()).flatten(a);
|
||
}
|
||
|
||
/// Repeats in line each value along the given axes.
|
||
pub fn stutter(self: Tensor, n_reps: []const u63) Tensor {
|
||
meta.assert(n_reps.len == self.rank(), "stutter expects tensor rank and 'n_reps' length to be equal, got {} and {}", .{ self.rank(), n_reps.len });
|
||
|
||
var res = self;
|
||
for (n_reps, 0..) |n_rep, a| {
|
||
if (n_rep == 1) continue;
|
||
res = res.stutter1d(@intCast(a), n_rep);
|
||
}
|
||
return res;
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise negation of the input Tensor.
|
||
pub fn negate(self: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const negate_op = dialect.stablehlo.negate(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape, negate_op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise cosine of the input Tensor.
|
||
pub fn cos(self: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const cosine_op = dialect.stablehlo.cosine(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape, cosine_op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise sine of the input Tensor.
|
||
pub fn sin(self: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const sine_op = dialect.stablehlo.sine(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape, sine_op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise exponential operation of the input Tensor.
|
||
pub fn exp(self: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.exponential(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape, op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise logarithm operation of the input Tensor.
|
||
pub fn log(self: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.log(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape, op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise square-root of the input Tensor.
|
||
pub fn sqrt(self: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const sqrt_op = dialect.stablehlo.sqrt(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape, sqrt_op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise reverse square-root of the input Tensor.
|
||
pub fn rsqrt(self: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const rsqrt_op = dialect.stablehlo.rsqrt(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape, rsqrt_op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise hyperbolic tangent of the input Tensor.
|
||
pub fn tanh(self: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const tanh_op = dialect.stablehlo.tanh(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape, tanh_op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise exponential minus one operation of the input Tensor.
|
||
pub fn exponentialMinusOne(self: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const expm1_op = dialect.stablehlo.exponential_minus_one(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape, expm1_op.result(0));
|
||
}
|
||
|
||
pub const ArangeArgs = HostBuffer.ArangeArgs;
|
||
|
||
/// Returns a Tensor containing evenly spaced values within a given interval.
|
||
pub fn arange(args: ArangeArgs, dt: DataType) Tensor {
|
||
meta.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end });
|
||
meta.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});
|
||
|
||
const ctx = CompilationContext.current();
|
||
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "{}, dtype={}", .{ args, dt });
|
||
|
||
const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable;
|
||
const sh = Shape.init(.{n_steps}, dt);
|
||
var op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlir.ext.mlirType(ctx.mlirCtx(), sh), loc);
|
||
var res = _result(sh, op.result(0));
|
||
|
||
if (args.step != 1) {
|
||
res = res.scale(args.step);
|
||
}
|
||
|
||
if (args.start != 0) {
|
||
res = res.addConstant(args.start);
|
||
}
|
||
|
||
return res;
|
||
}
|
||
|
||
/// Returns a Tensor containing values in increasing order starting from 0 along the given axis.
|
||
pub fn iota(sh: Shape, dt: DataType, axis_: anytype) Tensor {
|
||
const ctx = CompilationContext.current();
|
||
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "iota({}, {})", .{ sh, axis_ });
|
||
|
||
const a = sh.axis(axis_);
|
||
const res_shape = sh.withDtype(dt);
|
||
var op = dialect.stablehlo.iota(ctx.mlirCtx(), @intCast(a), mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type).?, loc);
|
||
return _result(res_shape, op.result(0));
|
||
}
|
||
|
||
pub const LinspaceArgs = struct {
|
||
start: f64,
|
||
end: f64,
|
||
steps: i64,
|
||
};
|
||
|
||
/// Returns a Tensor containing 'args.steps' values evenly spaced from 'args.start' to 'args.end', inclusive.
|
||
pub fn linspace(args: LinspaceArgs, dt: DataType) Tensor {
|
||
meta.assert(args.start < args.end, "linspace expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end });
|
||
meta.assert(args.steps > 0, "linspace expects 'args.steps' to be positive, got {}", .{args.steps});
|
||
meta.assert(dt.isFloat(), "linspace expects type to be a float, got {} (hint: use arange instead)", .{dt});
|
||
|
||
const ctx = CompilationContext.current();
|
||
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "linspace({}, dtype={})", .{ args, dt });
|
||
|
||
const sh = Shape.init(.{args.steps}, dt);
|
||
var iota_op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlir.ext.mlirType(ctx.mlirCtx(), sh), loc);
|
||
var res = _result(sh, iota_op.result(0));
|
||
|
||
if (args.steps != 1) {
|
||
res = res.scale(args.steps);
|
||
}
|
||
|
||
if (args.start != 0) {
|
||
res = res.addConstant(args.start);
|
||
}
|
||
|
||
return res;
|
||
}
|
||
|
||
/// Returns a 0d Tensor with the given value.
|
||
pub fn scalar(val: anytype, dt: DataType) Tensor {
|
||
return Tensor.constant(.{}, Data.init(dt, val));
|
||
}
|
||
|
||
/// Returns a constant Tensor with the given value.
|
||
pub fn constant(dimz: anytype, val: Data) Tensor {
|
||
const sh = Shape.init(dimz, val.dtype());
|
||
const singleton_sh = Shape.init(.{}, val.dtype());
|
||
const ctx = CompilationContext.current().mlirCtx();
|
||
const loc = ctx.location(@src()).namedFmt(ctx, "dims={d}, value={}", .{ sh, val });
|
||
const result_type = mlir.ext.RankedTensorType.fromShape(ctx, singleton_sh);
|
||
const elem_type = mlir.ext.denseElementAttrType(val.dtype());
|
||
var constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.constSlice(), loc);
|
||
if (sh.rank() > 0) {
|
||
constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type).?, loc);
|
||
}
|
||
return _result(sh, constant_op.result(0));
|
||
}
|
||
|
||
/// Embeds a buffer with concrete values into an Mlir program.
|
||
pub fn constantTensor(val: HostBuffer) Tensor {
|
||
const ctx = CompilationContext.current().mlirCtx();
|
||
const result_type = mlir.ext.RankedTensorType.fromShape(ctx, val.shape());
|
||
const loc = ctx.location(@src());
|
||
const elem_type = mlir.ext.denseElementAttrType(val.dtype());
|
||
const constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.data, loc);
|
||
return _result(val.shape(), constant_op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the result of the outer product between the input Tensors.
|
||
pub fn outer(self: Tensor, other: Tensor) Tensor {
|
||
meta.assert(self.rank() < 2 and other.rank() < 2 and self.rank() + other.rank() != 0, "outer expects tensor ranks to be at most 1, got {} and {}", .{ self.rank(), other.rank() });
|
||
|
||
if (self.rank() + other.rank() == 1) {
|
||
return self.mul(other);
|
||
}
|
||
|
||
const dimz = .{ self.dim(0), other.dim(0) };
|
||
const left = self.broadcast(Shape.init(dimz, self.dtype()), &.{0});
|
||
const right = other.broadcast(Shape.init(dimz, other.dtype()), &.{1});
|
||
return left.mul(right);
|
||
}
|
||
|
||
/// Creates a 2D Tensor with its diagonal set to the input vector.
|
||
pub fn diag(self: Tensor) Tensor {
|
||
meta.assert(self.rank() == 1, "diag only supports tensor of rank 1, got {}", .{self.rank()}); // TODO: support 2D tensors
|
||
//
|
||
const indices = Tensor.arange(.{ .end = self.dim(0) }, .i64);
|
||
const sh = Shape.init(.{ self.dim(0), self.dim(0) }, self.dtype());
|
||
const indices_1 = indices.broadcast(sh, &.{1});
|
||
const indices_0 = indices.broadcast(sh, &.{0});
|
||
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
// TODO: handle more complex cases (2D, specifying `diagonal` index).
|
||
const op = dialect.stablehlo.select(
|
||
self.getContext().mlirCtx(),
|
||
indices_1.cmp(.EQ, indices_0).value(),
|
||
self.broadcast(sh, &.{1}).value(),
|
||
Tensor.constant(self.dims(), self.dtype().zero()).value(),
|
||
loc,
|
||
);
|
||
return _result(sh, op.result(0));
|
||
}
|
||
|
||
/// Given a tensor and a shape of the same rank,
|
||
/// will "broadcast" the given axes, so that `self` has the given shape.
|
||
/// This happens by virtually repeating the data several time along each give axes.
|
||
/// Note: most of the time the optimizer will make it so that the broadcast doesn't trigger a copy.
|
||
/// Note: the tags of the return tensor will be from the `output_shape`.
|
||
/// This means if you use and un-tagged broadcast on a tagged tensor,
|
||
/// you will lose the tags.
|
||
/// To avoid use favorise `.broad(shape)` when working with tagged tensors.
|
||
pub fn broadcast(self: Tensor, output_shape: Shape, axes_: []const i64) Tensor {
|
||
const res_shape = output_shape.withDtype(self.dtype());
|
||
|
||
const result_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type).?;
|
||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "broadcast({any}, axes={d})", .{ res_shape, axes_ });
|
||
const broadcast_op = dialect.stablehlo.broadcast_in_dim(self.getContext().mlirCtx(), self.value(), axes_, result_type, loc);
|
||
|
||
return _result(res_shape, broadcast_op.result(0));
|
||
}
|
||
|
||
/// Broadcasts a Tensor to the given shape, adding axes at the beginning.
|
||
pub fn broadcastLeft(self: Tensor, output_shape: Shape) Tensor {
|
||
meta.assert(self.rank() <= output_shape.rank(), "broadcastLeft expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() });
|
||
|
||
const a = output_shape.rank() - self.rank();
|
||
if (self.rank() == output_shape.rank() and std.mem.eql(i64, self.dims(), output_shape.dims())) {
|
||
return self;
|
||
}
|
||
|
||
return self.broadcast(output_shape, Shape.range(output_shape.rank(), output_shape.dtype()).dims()[a..]);
|
||
}
|
||
|
||
/// Broadcasts a Tensor to the given shape, adding axes at the end.
|
||
pub fn broadcastRight(self: Tensor, output_shape: Shape) Tensor {
|
||
meta.assert(self.rank() <= output_shape.rank(), "broadcastRight expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() });
|
||
|
||
if (self.rank() == output_shape.rank() and self._shape.eql(output_shape)) {
|
||
return self;
|
||
}
|
||
|
||
return self.broadcast(output_shape, Shape.range(self.rank(), output_shape.dtype()).dims());
|
||
}
|
||
|
||
/// Broadcasts a Tensor to the given shape, extending dimensions if needed.
|
||
pub fn broad(self: Tensor, other: Shape) Tensor {
|
||
// Non ambiguous broadcasting
|
||
if (self._shape.rank() == 0 or self._shape.rank() == other.rank()) {
|
||
return self.broadcast(other, Shape.range(self._shape.rank(), self.dtype()).dims());
|
||
}
|
||
|
||
// check that each axis of self maps to an axis of other
|
||
var axes_: std.BoundedArray(i64, MAX_RANK) = .{};
|
||
for (self._shape.tags()) |t| {
|
||
if (t != Shape.TagUnknown) {
|
||
if (other.hasTag(t)) |ax| {
|
||
axes_.appendAssumeCapacity(@intCast(other.axis(ax)));
|
||
} else {
|
||
std.debug.panic("Can't broadcast {} to {}", .{ self, other });
|
||
}
|
||
}
|
||
}
|
||
return self.broadcast(other, axes_.constSlice());
|
||
}
|
||
|
||
/// Reshapes the input Tensor with the given shape.
|
||
pub fn reshape(self: Tensor, output_shape_: anytype) Tensor {
|
||
const output_shape = self._shape.reshape(output_shape_);
|
||
const tensor_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), output_shape);
|
||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reshape({any})", .{output_shape});
|
||
const reshape_value = dialect.stablehlo.reshape(self.getContext().mlirCtx(), self.value(), tensor_type, loc);
|
||
return _result(output_shape, reshape_value.result(0));
|
||
}
|
||
|
||
pub const Pad1dOpts = struct { low: i64, high: i64, interior: i64 = 0 };
|
||
|
||
/// Pads the input Tensor with the given value over the given axis.
|
||
pub fn pad1d(self: Tensor, axis_: i8, pad_value: anytype, opts: Pad1dOpts) Tensor {
|
||
const ZEROS = [_]i64{0} ** MAX_RANK;
|
||
var padding_low = ZEROS;
|
||
var padding_high = ZEROS;
|
||
var padding_interior = ZEROS;
|
||
const a = self.axis(axis_);
|
||
padding_low[a] = opts.low;
|
||
padding_high[a] = opts.high;
|
||
padding_interior[a] = opts.interior;
|
||
|
||
const rk = self.rank();
|
||
return self.pad(
|
||
pad_value,
|
||
.{ .low = padding_low[0..rk], .high = padding_high[0..rk], .interior = padding_interior[0..rk] },
|
||
);
|
||
}
|
||
|
||
/// Pads the input Tensor with the given values.
|
||
pub fn pad(self: Tensor, pad_value: anytype, opts: dialect.stablehlo.PadOpts) Tensor {
|
||
meta.assert(opts.low.len == self.rank(), "pad expects tensor rank and 'opts.low' length to be equal, got {} and {}", .{ self.rank(), opts.low.len });
|
||
meta.assert(opts.high.len == self.rank(), "pad expects tensor rank and 'opts.high' length to be equal, got {} and {}", .{ self.rank(), opts.high.len });
|
||
const ZEROS = [_]i64{0} ** MAX_RANK;
|
||
const interior = opts.interior orelse ZEROS[0..self.rank()];
|
||
|
||
var new_shape = self._shape;
|
||
|
||
for (0..self.rank()) |i| {
|
||
const d = self.dim(i) + opts.low[i] + (@max(self.dim(i) - 1, 0) * interior[i]) + opts.high[i];
|
||
new_shape = new_shape.set(i, d);
|
||
}
|
||
|
||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "pad(value={}, opts={})", .{ pad_value, opts });
|
||
var full_opts = opts;
|
||
full_opts.interior = opts.interior orelse ZEROS[0..self.rank()];
|
||
const pad_value_tensor = Tensor.scalar(pad_value, self.dtype());
|
||
const pad_op = dialect.stablehlo.pad(self.getContext().mlirCtx(), self.value(), pad_value_tensor.value(), full_opts, loc);
|
||
|
||
return _result(new_shape, pad_op.result(0));
|
||
}
|
||
|
||
/// Inserts 1-dim axes at the given position, with the given tags.
|
||
/// `.{.a = 5, .b = 4}.insert(.b, .{ .c, .d }) -> .{ .a = 5, .c = 1, .d = 1, .b = 4 }`
|
||
pub fn insertAxes(self: Tensor, axis_: anytype, tags: anytype) Tensor {
|
||
const tags_ = Shape.parseTags(tags);
|
||
const ax = if (@TypeOf(axis_) == EnumLiteral and axis_ == .last)
|
||
self.rank()
|
||
else
|
||
self.axis(axis_);
|
||
|
||
var res_shape = self._shape;
|
||
const ones = [_]i64{1} ** MAX_RANK;
|
||
res_shape._dims.insertSlice(ax, ones[0..tags_.len]) catch unreachable;
|
||
res_shape._tags.insertSlice(ax, tags_.constSlice()) catch unreachable;
|
||
|
||
return self.reshape(res_shape);
|
||
}
|
||
|
||
/// Appends a 1-dim axis, with the given tag.
|
||
pub fn appendAxes(self: Tensor, t: anytype) Tensor {
|
||
meta.assert(self.rank() < Tensor.MAX_RANK - t.len, "appendAxis expects tensor rank to be small enough in order to extend it, got {} and {} (max is {})", .{ self.rank(), t.len, Tensor.MAX_RANK });
|
||
|
||
return self.insertAxes(.last, t);
|
||
}
|
||
|
||
/// Drops a 1-dim axis at the given index
|
||
pub fn squeeze(self: Tensor, axis_: anytype) Tensor {
|
||
const a = self.axis(axis_);
|
||
meta.assert(self.dim(a) == 1, "squeeze expects axis to be squeezed to have a dimension of 1, got {}", .{self.dim(a)});
|
||
|
||
const new_shape = self._shape.remove(a);
|
||
// log.debug("squeeze({}, {d}={d}) -> ({})", .{ self, axis, a, new_shape });
|
||
|
||
return _result(new_shape, self.reshape(new_shape).value());
|
||
}
|
||
|
||
/// Returns a Tensor with the given axes reversed.
|
||
pub fn reverse(self: Tensor, axes_: anytype) Tensor {
|
||
const actual_axes = self._shape.axes(axes_);
|
||
|
||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reverse({any})", .{axes_});
|
||
const reverse_op = dialect.stablehlo.reverse(self.getContext().mlirCtx(), self.value(), toI64(actual_axes.constSlice()), loc);
|
||
return _result(self._shape, reverse_op.result(0));
|
||
}
|
||
|
||
pub const GatherOpts = struct { indices_are_sorted: bool = false };
|
||
|
||
/// For each coordinate in `indices`,
|
||
/// `gatherValues` extracts a single value of the given tensor.
|
||
///
|
||
/// * axes_ is a single axis, or a tuple of axis: .b, or .{ .b, .c }
|
||
/// * indices is an integer tensor
|
||
/// * result is a tensor whose shape is similar to the input shape
|
||
/// where the gathered axes have been replaced by axes from 'indices'.
|
||
///
|
||
/// Some example input for the base case where we work on one axis:
|
||
/// - gatherValues(f:[a]->float, .a, ind:[n]->int)[n] == f[ind[n]]
|
||
/// - gatherValues(f:[a, b], .a, ind:[n])[n, b] == f[ind[n], b]
|
||
/// - gatherValues(f: [a,b,c], .{.b}, ind: [n,m])[a, n, m, c] == f[a, ind[n, m], c]
|
||
///
|
||
/// If an axis in common between `self` and `indices`,
|
||
/// it is treated as a "batching" axis, meaning that semantically
|
||
/// the operator is doing a gatherValues one time per dimension of this axis:
|
||
/// - gatherValues(f: [a,b,c], .{.b}, ind: [a,n])[a, n] == f[a, ind[a, n]]
|
||
///
|
||
/// It is an error to have an axis present in `self`, `axes_` and `indices`.
|
||
///
|
||
/// If several axes are passed, then the last axis of indices is treated as coordinates:
|
||
/// - gatherValues(f: [a,b,c], .{.b, .c}, ind: [n,2])[a, n] == f[a, ind[n][0], ind[n][1]]
|
||
/// - gatherValues(f: [a,b,c,d], .{.b, .c}, ind: [a, n,2])[a, n, d] == f[a, ind[a, n][0], ind[a, n][1], d]
|
||
///
|
||
/// It is possible to use gatherValues without tags, but batching won't be available.
|
||
pub fn gatherValues(self: Tensor, coord_axes: anytype, indices: Tensor, opts: GatherOpts) Tensor {
|
||
// scoped_log.debug("gatherValues({}, {any}, {})", .{ self, coord_axes, indices });
|
||
const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes);
|
||
|
||
meta.assert(coord_axes_.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{});
|
||
for (coord_axes_.constSlice(), 0..) |a, i| {
|
||
if (i > 0) {
|
||
meta.assert(a == coord_axes_.get(i - 1) + 1, "gatherValues expects 'coord_axes' to be sequential. But {any} aren't sequential in {}", .{ coord_axes, self });
|
||
}
|
||
}
|
||
|
||
const AxisKind = enum { batching, offset, collapsed, indices };
|
||
var self_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{};
|
||
var indices_batch_axes: Shape.DimsArray = .{};
|
||
for (self._shape.tags(), 0..self.rank()) |t, self_ax| {
|
||
const maybe_coord_ax = std.mem.indexOfScalar(u3, coord_axes_.constSlice(), @intCast(self_ax));
|
||
if (indices._shape.hasTag(t)) |id_ax| {
|
||
// tag is both in self and indices -> it's a batching dim
|
||
// Note: tags are required for batching.
|
||
self_kind.appendAssumeCapacity(.batching);
|
||
indices_batch_axes.appendAssumeCapacity(id_ax);
|
||
meta.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. Axis {s} has been found both in 'self={any}', in 'coord_axes_={any}' and in 'indices={}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices });
|
||
} else if (maybe_coord_ax) |_| {
|
||
// for gatherValues we collapsed all gathered axes
|
||
// (contrary to gatherSlices where we collapse none)
|
||
self_kind.appendAssumeCapacity(.collapsed);
|
||
} else {
|
||
self_kind.appendAssumeCapacity(.offset);
|
||
}
|
||
}
|
||
|
||
// When we receive several coord_axes we need an extra dimension to store
|
||
// one index per axis, which makes the coordinates of one value.
|
||
// Otherwi se stablehlo uses the "indices.rank()" default value.
|
||
const index_coord_axis = if (single_coord)
|
||
indices.rank()
|
||
else blk: {
|
||
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
|
||
meta.assert(indices.dim(ax) == coord_axes_.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ coord_axes, coord_axes_.len, indices });
|
||
break :blk ax;
|
||
};
|
||
|
||
// compute res shape
|
||
var res_shape = Shape.init(.{}, self.dtype());
|
||
var res_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{};
|
||
for (self_kind.constSlice(), 0..) |kind, ax_usize| {
|
||
const ax: u3 = @intCast(ax_usize);
|
||
if (ax == coord_axes_.get(0)) {
|
||
// The first val_ax is special cause this is the place where we insert indices axes.
|
||
for (indices._shape.tags(), 0..indices.rank()) |t, id_ax| {
|
||
if (id_ax == index_coord_axis) continue;
|
||
if (std.mem.indexOfScalar(i64, indices_batch_axes.constSlice(), @intCast(id_ax))) |_| {
|
||
// batching dim are already in res
|
||
continue;
|
||
}
|
||
|
||
res_shape = res_shape.appendDim(indices.dim(id_ax), t);
|
||
res_kind.appendAssumeCapacity(.indices);
|
||
}
|
||
}
|
||
switch (kind) {
|
||
.collapsed => continue,
|
||
else => {
|
||
res_shape = res_shape.appendDim(self.dim(ax), self._shape.tag(ax));
|
||
res_kind.appendAssumeCapacity(kind);
|
||
},
|
||
}
|
||
}
|
||
|
||
// This is not a gather, but a dynamicSlice.
|
||
// Sometimes the backend recognize this pattern, but not always.
|
||
// So let us handle that.
|
||
if (indices.count() == 1) {
|
||
return self.dynamicSlice1d(coord_axes_.get(0), 1, indices.flattenAll().squeeze(0)).reshape(res_shape);
|
||
}
|
||
|
||
var slice_dims: Shape.DimsArray = .{};
|
||
for (self_kind.constSlice(), self.dims()) |k, d| {
|
||
slice_dims.appendAssumeCapacity(switch (k) {
|
||
.batching, .collapsed => 1,
|
||
.offset => d,
|
||
.indices => unreachable,
|
||
});
|
||
}
|
||
|
||
// scoped_log.debug("gatherValues --> {} {any}", .{ res_shape, res_kind.constSlice() });
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const gather_op = dialect.stablehlo.gather(
|
||
self.getContext().mlirCtx(),
|
||
self.value(),
|
||
indices.value(),
|
||
slice_dims.constSlice(),
|
||
loc,
|
||
.{
|
||
.offset_dims = _collectAxes(AxisKind, res_kind, .offset).constSlice(),
|
||
.collapsed_slice_dims = _collectAxes(AxisKind, self_kind, .collapsed).constSlice(),
|
||
.operand_batching_dims = _collectAxes(AxisKind, self_kind, .batching).constSlice(),
|
||
.start_indices_batching_dims = indices_batch_axes.constSlice(),
|
||
.start_index_map = _collectAxes(AxisKind, self_kind, .collapsed).constSlice(),
|
||
.index_vector_dim = index_coord_axis,
|
||
.indices_are_sorted = opts.indices_are_sorted,
|
||
},
|
||
);
|
||
|
||
const mlir_shape = fromMlirValue(gather_op.result(0)).shape();
|
||
meta.assert(mlir_shape.eql(res_shape), "gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={}, indices={}. You should transpose one or the other.", .{ self, indices });
|
||
return _result(res_shape, gather_op.result(0));
|
||
}
|
||
|
||
test gatherValues {
|
||
const zml = @import("zml.zig");
|
||
const platform = zml.testing.env();
|
||
|
||
{
|
||
// Only test shapes
|
||
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
||
defer comp.deinit();
|
||
comp.activate();
|
||
defer comp.deactivate();
|
||
|
||
inline for (.{
|
||
.{ .{ .a = 10 }, .a, .{}, .{} },
|
||
.{ .{ .a = 10 }, .a, .{ .n = 8 }, .{ .n = 8 } },
|
||
.{ .{ .a = 10, .b = 20 }, .a, .{}, .{ .b = 20 } },
|
||
.{ .{ .a = 10, .b = 20 }, .a, .{ .n = 8 }, .{ .n = 8, .b = 20 } },
|
||
.{ .{ .a = 10, .b = 20 }, 0, .{ .n = 8 }, .{ .n = 8, .b = 20 } },
|
||
// Favor val shape, instead of indices shape.
|
||
.{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8 } },
|
||
.{ .{ .a = 10, .b = 20, .c = 30 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8, .c = 30 } },
|
||
// batching axes are implicits.
|
||
.{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10 }, .{ .a = 10 } },
|
||
.{ .{ .a = 10, .b = 20 }, .a, .{ .b = 20 }, .{ .b = 20 } },
|
||
.{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8 } },
|
||
// stablehlo.gather is biased toward indices shape (like gatherSlice).
|
||
// This make it awkward to use when you have both batching dimension and new indices dimensions.
|
||
// For now we reject those, and let user explicitly transpose self or indices if needed.
|
||
// .{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8, .a = 10 }, .{ .a = 10, .n = 8 } },
|
||
// Also handle tuples
|
||
.{ .{ .a = 10, .b = 20 }, .{ .a, .b }, .{ .n = 8, ._ = 2 }, .{ .n = 8 } },
|
||
.{ .{ 10, 20 }, .{ -2, -1 }, .{ 8, 2 }, .{8} },
|
||
// and 1-tuple
|
||
.{ .{ .a = 10, .b = 20 }, .{.b}, .{ .n = 8, ._ = 1 }, .{ .a = 10, .n = 8 } },
|
||
}) |testcase| {
|
||
const x_shape, const tag, const idx_shape, const res_shape = testcase;
|
||
const x = Tensor.constant(x_shape, .{ .f16 = 0 });
|
||
const idx = Tensor.constant(idx_shape, .{ .i32 = 0 });
|
||
const y = gatherValues(x, tag, idx, .{});
|
||
try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape());
|
||
try std.testing.expect(y.value().owner().verify());
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Gathers slices along the given axes with runtime indices.
|
||
/// * slice_shape represents the shape of the slices to extract,
|
||
/// it must be smaller than original shape.
|
||
/// It must use a subset of self axes.
|
||
/// If slice_shape is **not** tagged, then it must have the same rank than self.
|
||
/// * `indices` represents a set of coordinates.
|
||
/// The coordinates are read from the `.coord` axis, or last axis if `.coord` is not found.
|
||
/// The coordinate axis must have `slice_shape.rank()` dims.
|
||
/// The coordinates represent the "top-left" corner of the slice to extract.
|
||
/// * the output tensor starts with axes from `indices`.
|
||
/// * if the input tensor has tagged axes, matching `indices` axes,
|
||
/// they will be considered "batching" axes.
|
||
///
|
||
/// Sample input/output shapes:
|
||
/// * gatherSlices([A, B, C, D], .{.b=B', .c=C'}, [N, 2]) -> [N, A, B', C', D]
|
||
/// * gatherSlices(x(a,b,c,d), .{.b=B', .c=C'}, g(n,m)) = z(n, a, b', c', d) = x(a, g(n, 0) + b', g(n, 1) + c', d)
|
||
///
|
||
/// Note: the axis order of the result is different from gatherValues.
|
||
/// This is because gatherSlices, favorizes contiguous copy of the extracted slices,
|
||
/// while gatherValues, always copy values one by one, and as such don't have the same issues.
|
||
/// In our example the contiguous dimension .d is not sliced
|
||
/// and gatherSlices can copy data by group of C'*D elements.
|
||
pub fn gatherSlices(self: Tensor, slice_shape: Shape, indices: Tensor, opts: GatherOpts) Tensor {
|
||
// scoped_log.debug("gatherSlice({}, {_}, {})", .{ self, slice_shape, indices });
|
||
|
||
const tagged_api = slice_shape.isFullyTagged();
|
||
if (tagged_api) {
|
||
for (slice_shape.tags()) |t| {
|
||
meta.assert(self._shape.hasTag(t) != null, "gatherSlices expects `slices_shape` to only use tags from `self`. But {s} wasn't found in {}", .{ t, self });
|
||
}
|
||
} else {
|
||
// For untagged api, we require all slices to be specified.
|
||
// Note: we could relax this and right align the slice.
|
||
meta.assert(slice_shape.rank() == self.rank(), "gatherSlices expects `slice_shape.rank()` to match `self.rank()`. Got: gatherSlices({}, slice={_}). To avoid specifying all axes in `slice_shape`, you can use tags.", .{ self, slice_shape });
|
||
}
|
||
|
||
const index_coord_axis = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
|
||
meta.assert(indices.dim(index_coord_axis) == slice_shape.rank(), "gatherSlices({}, slice={_}, indices) expects 'indices' to be a tensor [..., {}], got {}", .{ self, slice_shape, slice_shape.rank(), indices });
|
||
|
||
// Compute result shape
|
||
var res_shape = indices._shape.remove(index_coord_axis).withDtype(self.dtype());
|
||
var slice_dims = self._shape._dims;
|
||
var self_batch_axes: std.BoundedArray(i64, MAX_RANK) = .{};
|
||
var indices_batch_axes: std.BoundedArray(i64, MAX_RANK) = .{};
|
||
var start_index_map: std.BoundedArray(i64, MAX_RANK) = .{};
|
||
var self_offset_axes: std.BoundedArray(i64, MAX_RANK) = .{};
|
||
for (self._shape.tags(), 0..self.rank()) |t, self_ax| {
|
||
const maybe_slice_ax: ?u3 = if (tagged_api) slice_shape.hasTag(t) else @intCast(self_ax);
|
||
|
||
if (tagged_api and indices._shape.hasTag(t) != null) {
|
||
// tag is both in self and indices -> it's a batching dim
|
||
// Note: tags are required for batching.
|
||
self_batch_axes.appendAssumeCapacity(@intCast(self_ax));
|
||
indices_batch_axes.appendAssumeCapacity(indices._shape.axis(t));
|
||
slice_dims.set(self_ax, 1);
|
||
meta.assert(slice_shape.hasTag(t) == null, "gatherSlices expect axes to be either batches or slices axes. Axis {s} has been found both in `slices={_}` and `indices={}`", .{ t, slice_shape, indices });
|
||
} else if (maybe_slice_ax) |slice_ax| {
|
||
// Specified axes contains the start offset of the slices,
|
||
// and are collected in `start_index_map`.
|
||
const slice_dim = slice_shape.dim(slice_ax);
|
||
meta.assert(slice_dim <= self._shape.dim(self_ax), "gatherSlices expects `slice_shape` to be smaller than `self.shape()`. On axis {s}, got {} > {}.", .{ t, slice_shape, self._shape });
|
||
slice_dims.set(self_ax, slice_dim);
|
||
res_shape = res_shape.appendDim(slice_dim, t);
|
||
start_index_map.appendAssumeCapacity(@intCast(self_ax));
|
||
self_offset_axes.appendAssumeCapacity(res_shape.rank() - 1);
|
||
} else {
|
||
// non-batching, non-indexed axes
|
||
res_shape = res_shape.appendDim(self.dim(self_ax), t);
|
||
self_offset_axes.appendAssumeCapacity(res_shape.rank() - 1);
|
||
}
|
||
}
|
||
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const gather_op = dialect.stablehlo.gather(
|
||
self.getContext().mlirCtx(),
|
||
self.value(),
|
||
indices.value(),
|
||
slice_dims.constSlice(),
|
||
loc,
|
||
.{
|
||
.offset_dims = self_offset_axes.constSlice(),
|
||
.collapsed_slice_dims = &.{},
|
||
.operand_batching_dims = self_batch_axes.constSlice(),
|
||
.start_indices_batching_dims = indices_batch_axes.constSlice(),
|
||
.start_index_map = start_index_map.constSlice(),
|
||
.index_vector_dim = index_coord_axis,
|
||
.indices_are_sorted = opts.indices_are_sorted,
|
||
},
|
||
);
|
||
return _result(res_shape, gather_op.result(0));
|
||
}
|
||
|
||
test gatherSlices {
|
||
const zml = @import("zml.zig");
|
||
const platform = zml.testing.env();
|
||
|
||
{
|
||
// Only test shapes
|
||
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
||
defer comp.deinit();
|
||
comp.activate();
|
||
defer comp.deactivate();
|
||
|
||
inline for (.{
|
||
.{ .{ .a = 10 }, .{}, .{ ._ = 0 }, .{ .a = 10 } },
|
||
.{ .{ .a = 10 }, .{ .a = 7 }, .{ ._ = 1 }, .{ .a = 7 } },
|
||
.{ .{ .a = 10 }, .{ .a = 7 }, .{ .n = 8, ._ = 1 }, .{ .n = 8, .a = 7 } },
|
||
.{ .{ .a = 10 }, .{ .a = 7 }, .{ .coord = 1, .n = 8 }, .{ .n = 8, .a = 7 } },
|
||
// tags aren't required.
|
||
.{ .{10}, .{7}, .{ .n = 8, ._ = 1 }, .{ .n = 8, ._ = 7 } },
|
||
.{ .{ .a = 10, .b = 20 }, .{ .a = 7 }, .{ ._ = 1 }, .{ .a = 7, .b = 20 } },
|
||
.{ .{ .a = 10, .b = 20 }, .{ .a = 7 }, .{ .n = 8, ._ = 1 }, .{ .n = 8, .a = 7, .b = 20 } },
|
||
.{ .{ .a = 10, .b = 20 }, .{ .a = 7 }, .{ .n = 8, .coord = 1, .m = 9 }, .{ .n = 8, .m = 9, .a = 7, .b = 20 } },
|
||
.{ .{ .a = 10, .b = 20 }, .{ .b = 17 }, .{ .n = 8, ._ = 1 }, .{ .n = 8, .a = 10, .b = 17 } },
|
||
.{ .{ .a = 10, .b = 20 }, .{ .a = 7, .b = 17 }, .{ .n = 8, ._ = 2 }, .{ .n = 8, .a = 7, .b = 17 } },
|
||
// Note: currently the order of the axes in the slice is not used.
|
||
.{ .{ .a = 10, .b = 20 }, .{ .b = 17, .a = 7 }, .{ .n = 8, ._ = 2 }, .{ .n = 8, .a = 7, .b = 17 } },
|
||
.{ .{ .a = 10, .b = 20, .c = 20 }, .{ .b = 17 }, .{ .n = 8, ._ = 1 }, .{ .n = 8, .a = 10, .b = 17, .c = 20 } },
|
||
// batching dims
|
||
.{ .{ .a = 10, .b = 20 }, .{ .b = 17 }, .{ .a = 10, ._ = 1 }, .{ .a = 10, .b = 17 } },
|
||
.{ .{ .b = 200, .a = 100, .c = 300 }, .{ .c = 300 }, .{ .a = 100, .b = 200, ._ = 1 }, .{ .a = 100, .b = 200, .c = 300 } },
|
||
}) |testcase| {
|
||
const x_shape, const slice_dims, const idx_shape, const res_shape = testcase;
|
||
const x = Tensor.constant(x_shape, .{ .f16 = 0 });
|
||
const slice_shape = Shape.init(slice_dims, .u16);
|
||
const idx = Tensor.constant(idx_shape, .{ .i32 = 0 });
|
||
const y = gatherSlices(x, slice_shape, idx, .{});
|
||
try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape());
|
||
try std.testing.expect(y.value().owner().verify());
|
||
|
||
const mod = try zml.compileFn(
|
||
std.testing.allocator,
|
||
gatherSlices,
|
||
.{ x.shape(), slice_shape, idx.shape(), .{ .indices_are_sorted = true } },
|
||
platform,
|
||
);
|
||
defer mod.deinit();
|
||
}
|
||
}
|
||
|
||
// Test with actual values.
|
||
const range = try zml.HostBuffer.arange(std.testing.allocator, .{ .end = 2 * 4 * 6 }, .u16);
|
||
defer range.deinit(std.testing.allocator);
|
||
const operand = try range.reshape(.{ .a = 2, .b = 4, .c = 6 }).toDevice(platform);
|
||
defer operand.deinit();
|
||
const start_indices = (try zml.Buffer.fromArray(platform, [2][2]i32{ .{ 2, 1 }, .{ 0, 3 } })).withTags(.{ .n, ._ });
|
||
defer start_indices.deinit();
|
||
|
||
const result = try zml.testing.compileAndCall(platform, gatherSlices, .{ operand, Shape.init(.{ .b = 2, .c = 3 }, .u16), start_indices, .{} });
|
||
|
||
const expected = zml.HostBuffer.fromArray(&[2][2][2][3]u16{
|
||
.{
|
||
.{ .{ 13, 14, 15 }, .{ 19, 20, 21 } },
|
||
.{ .{ 37, 38, 39 }, .{ 43, 44, 45 } },
|
||
},
|
||
.{
|
||
.{ .{ 3, 4, 5 }, .{ 9, 10, 11 } },
|
||
.{ .{ 27, 28, 29 }, .{ 33, 34, 35 } },
|
||
},
|
||
});
|
||
try zml.testing.expectClose(expected, result, 0);
|
||
}
|
||
|
||
pub const ScatterOpts = struct {
|
||
/// Promise scatter that all coordinates in `indices` are sorted, wrt to the final in memory offset.
|
||
/// Result is undefined if the promise is violated.
|
||
indices_are_sorted: bool = false,
|
||
|
||
/// Promise scatter that slices don't overlap.
|
||
/// Result is undefined if the promise is violated.
|
||
indices_are_unique: bool = false,
|
||
|
||
/// Function used to update previous value in `self` with values from `updates`.
|
||
/// If `update_fn` is not associative (ie the order of execution matters),
|
||
/// then you should make sure the slices don't overlap,
|
||
/// otherwise the result will depend on the runtime scheduling
|
||
/// of the operator which is backend specific.
|
||
update_fn: *const fn (*const anyopaque, Tensor, Tensor) Tensor = increment,
|
||
|
||
/// Extra data that may be needed for a custom update function.
|
||
/// `override` and `increment` don't need it, leaving it to undefined works.
|
||
update_fn_ctx: *const anyopaque = undefined,
|
||
|
||
pub fn increment(_: *const anyopaque, old_value: Tensor, new_value: Tensor) Tensor {
|
||
return old_value.add(new_value);
|
||
}
|
||
|
||
pub fn override(_: *const anyopaque, old_value: Tensor, new_value: Tensor) Tensor {
|
||
_ = old_value;
|
||
return new_value;
|
||
}
|
||
};
|
||
|
||
/// Update the given tensors, by copying `values` into self slices.
|
||
/// The slices are chosen at runtime by interpreting indices as coordinates into `self`.
|
||
/// * `indices` represents a set of coordinates into `self`.
|
||
/// For the sake of simplifying the creation of `indices` tensor,
|
||
/// it's allowed to not mention a specific axis if the coordinate for this axis is always `0`.
|
||
/// Similarly to `gatherValues`, the coordinates are read from the `.coord` axis, or last axis if `.coord` is not found.
|
||
/// The coordinates represent the "top-left" corner of the slice to extract.
|
||
/// `indices.dim(.coord)` must match `coord_axes.len`.
|
||
/// Other axes identify one "slice" and they must be found inside `updates`.
|
||
///
|
||
/// * the output tensor starts with axes from `indices`.
|
||
/// * if the input tensor has tagged axes, matching `indices` axes,
|
||
/// they will be considered "batching" axes.
|
||
///
|
||
/// Sample input/output shapes:
|
||
/// * scatterSlices([A, B, C, D], .{b, c}, [N, 2], [N, B', C']) -> [A, B, C, D]
|
||
/// * scatterSlices(x(a,b,c,d), g(n,m), y[n,b,c]) [A,B,C,D] {
|
||
/// var z = x;
|
||
/// for (0..N) |n| { z[a,g[n,0]+b',g[n,1]+c',d] = y[n,a,b',c',d]; }
|
||
/// }
|
||
///
|
||
/// **Warning**: if `opts.update_fn` is not associative not all calls to `scatterSlices` are sound.
|
||
/// In particular if you scatter overlapping slices, with `zml.Tensor.ScatterOpts.override`,
|
||
/// then the result will depend on the execution order that you don't control.
|
||
pub fn scatterSlices(self: Tensor, coord_axes: anytype, indices: Tensor, updates: Tensor, opts: ScatterOpts) Tensor {
|
||
const loc = @src();
|
||
// scoped_log.debug("scatterSlices({}, {any}, {}, {})", .{ self, coord_axes, indices, updates });
|
||
|
||
meta.assert(self.dtype() == updates.dtype(), "scatterSlices expects input and 'updates' tensors to be of the same type, got {} and {}", .{ self.dtype(), updates.dtype() });
|
||
|
||
const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes);
|
||
const AxisKind = enum { batching, update_window, inserted_window, window_id };
|
||
var self_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{};
|
||
var indices_batch_axes: Shape.DimsArray = .{};
|
||
for (self._shape.tags()) |t| {
|
||
if (updates._shape.hasTag(t)) |_| {
|
||
if (indices._shape.hasTag(t)) |id_ax| {
|
||
// tag is in self, indices and updates -> it's a batching dim
|
||
self_kind.appendAssumeCapacity(.batching);
|
||
indices_batch_axes.appendAssumeCapacity(id_ax);
|
||
} else {
|
||
self_kind.appendAssumeCapacity(.update_window);
|
||
}
|
||
} else {
|
||
self_kind.appendAssumeCapacity(.inserted_window);
|
||
}
|
||
}
|
||
// scoped_log.warn(" self_kind -> {any}", .{self_kind.constSlice()});
|
||
|
||
const index_coord_axis = if (single_coord)
|
||
indices.rank()
|
||
else blk: {
|
||
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
|
||
meta.assert(indices.dim(ax) == coord_axes_.len, "scatterSlices({}, coord_axes={any}, indices, updates) expects 'indices' to be a tensor [..., {}], got {}", .{ self, coord_axes, coord_axes_.len, indices });
|
||
|
||
break :blk ax;
|
||
};
|
||
if (indices.count() == 1) {
|
||
return self.dynamicUpdateSlice1d(updates, coord_axes_.get(0), indices.reshape(.{}));
|
||
}
|
||
|
||
var up_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{};
|
||
// Note: we assume the scatter_dims appear in the same order inside indices and inside self.
|
||
for (updates._shape.tags(), 0..) |t, up_ax| {
|
||
if (self._shape.hasTag(t)) |self_ax| {
|
||
if (self_kind.get(self_ax) == .batching) {
|
||
up_kind.appendAssumeCapacity(.batching);
|
||
} else {
|
||
meta.assert(updates.dim(up_ax) <= self.dim(self_ax), "scatterSlices expects the slices described in 'updates' to fit inside 'self', but along axis .{s} it doesn't. Got self={}, updates={}.", .{ t, self, updates });
|
||
up_kind.appendAssumeCapacity(.update_window);
|
||
}
|
||
} else if (t == Shape.TagUnknown or indices._shape.hasTag(t) != null) {
|
||
up_kind.appendAssumeCapacity(.window_id);
|
||
} else {
|
||
std.debug.panic("scatterSlices expects 'updates' to be made of axes from 'self={}' and from 'indices={}', got unknown tag {s} in {}", .{ self, indices, t, updates });
|
||
}
|
||
}
|
||
const n_indices_axes = updates.rank() - _collectAxes(AxisKind, up_kind, .update_window).len;
|
||
if (single_coord) {
|
||
meta.assert(n_indices_axes == indices.rank(), "scatterSlices({}, {any}) expects 'updates' to contain all axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates });
|
||
} else {
|
||
meta.assert(n_indices_axes == indices.rank() - 1, "scatterSlices({}, {any}) expects 'updates' to contain all-but-last axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates });
|
||
}
|
||
|
||
const ctx = self.getContext();
|
||
const mlir_ctx = ctx.mlirCtx();
|
||
|
||
const _scalar: Tensor = .{ ._shape = Shape.init(.{}, self.dtype()), ._id = undefined };
|
||
const UpdateS = ops.BlockSign(ScatterOpts.increment);
|
||
const update_block = ctx.makeBlock(UpdateS, opts.update_fn, opts.update_fn_ctx, .{ _scalar, _scalar });
|
||
|
||
const op = dialect.stablehlo.scatter(
|
||
mlir_ctx,
|
||
&.{self.value()},
|
||
&.{indices.value()},
|
||
&.{updates.value()},
|
||
update_block,
|
||
.{
|
||
.update_window_dims = _collectAxes(AxisKind, up_kind, .update_window).constSlice(),
|
||
.inserted_window_dims = _collectAxes(AxisKind, self_kind, .inserted_window).constSlice(),
|
||
.input_batching_dims = _collectAxes(AxisKind, self_kind, .batching).constSlice(),
|
||
.scatter_indices_batching_dims = indices_batch_axes.constSlice(),
|
||
.scatter_dims_to_operand_dims = toI64(coord_axes_.constSlice()),
|
||
.index_vector_dim = index_coord_axis,
|
||
.indices_are_sorted = opts.indices_are_sorted,
|
||
.unique_indices = opts.indices_are_unique,
|
||
},
|
||
mlir_ctx.location(loc),
|
||
);
|
||
return _result(self._shape, op.result(0));
|
||
}
|
||
|
||
test scatterSlices {
|
||
const zml = @import("zml.zig");
|
||
const platform = zml.testing.env();
|
||
|
||
const Local = struct {
|
||
pub fn scatter(self: Tensor, coord_axes: Shape.AxesArray, indices: Tensor, updates: Tensor) Tensor {
|
||
return self.scatterSlices(
|
||
coord_axes.constSlice(),
|
||
indices,
|
||
updates,
|
||
.{ .update_fn = ScatterOpts.increment },
|
||
);
|
||
}
|
||
};
|
||
|
||
{
|
||
// Only test shapes
|
||
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
||
defer comp.deinit();
|
||
comp.activate();
|
||
defer comp.deactivate();
|
||
|
||
inline for (.{
|
||
.{ .{ .a = 10 }, .a, .{}, .{ .a = 3 } },
|
||
.{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8, .b = 2 } },
|
||
// I'm not sure I like this variant, cause `b` is not mentionned in updates.
|
||
// So 'stablehlo.scatter' is implicitly broadcasting the updates along `b` axis.
|
||
// OTOH asking the user to do the broadcasting isn't trivial cause they will need to do shape wrangling and that's annoying.
|
||
.{ .{ .a = 10, .b = 20 }, .a, .{ .n = 8 }, .{ .n = 8, .a = 2 } },
|
||
.{ .{ .a = 10, .b = 20 }, .{ .b, .a }, .{ .n = 8, ._ = 2 }, .{ .n = 8, .a = 3, .b = 2 } },
|
||
.{ .{ .a = 10, .b = 20 }, .{ .a, .b }, .{ .n = 8, ._ = 2 }, .{ .a = 3, .n = 8, .b = 2 } },
|
||
}) |testcase| {
|
||
const x_shape, const axes_, const idx_shape, const updates_shapes = testcase;
|
||
const x = Tensor.constant(x_shape, .{ .f16 = 0 });
|
||
const idx = Tensor.constant(idx_shape, .{ .i32 = 0 });
|
||
const updates = Tensor.constant(updates_shapes, .{ .f16 = 0 });
|
||
|
||
const y = scatterSlices(x, axes_, idx, updates, .{});
|
||
// Shape doesn't change with scatterSlices
|
||
try zml.testing.expectEqualShapes(x.shape(), y.shape());
|
||
try std.testing.expect(y.value().owner().verify());
|
||
}
|
||
}
|
||
// Test with actual values, no batching.
|
||
{
|
||
const a_host = try zml.HostBuffer.arange(std.testing.allocator, .{ .end = 9 }, .i32);
|
||
const a = (try zml.Buffer.from(platform, a_host.reshape(.{ 3, 3 }))).withTags(.{ .a, .b });
|
||
defer a.deinit();
|
||
a_host.deinit(std.testing.allocator);
|
||
|
||
const scatter_indices = try zml.Buffer.fromArray(platform, [2][1]i32{ .{0}, .{2} });
|
||
const updates = try zml.Buffer.fromArray(platform, [2][3]i32{ .{ 10, 20, 30 }, .{ 70, 80, 90 } });
|
||
|
||
const expected = [3][3]i32{ .{ 10, 21, 32 }, .{ 3, 4, 5 }, .{ 76, 87, 98 } };
|
||
const result = try zml.testing.compileAndCall(platform, Local.scatter, .{
|
||
a,
|
||
a.shape().axes(.{.a}),
|
||
scatter_indices.withTags(.{ .n, .coord }),
|
||
updates.withTags(.{ .n, .b }),
|
||
});
|
||
try std.testing.expect(a.shape().eql(result.shape()));
|
||
try std.testing.expectEqual(expected, result.getValue(@TypeOf(expected)));
|
||
}
|
||
{
|
||
// Test with actual values and batching along axis .a
|
||
const operand = try zml.Buffer.constant(platform, Shape.init(.{ .a = 2, .b = 3, .c = 4, .d = 2 }, .u16), 0);
|
||
defer operand.deinit();
|
||
const start_indices = (try zml.Buffer.fromArray(
|
||
platform,
|
||
[2][2][3][2]i32{
|
||
.{
|
||
.{ .{ 0, 0 }, .{ 1, 0 }, .{ 2, 1 } },
|
||
.{ .{ 0, 1 }, .{ 1, 1 }, .{ 0, 9 } },
|
||
},
|
||
.{
|
||
.{ .{ 0, 0 }, .{ 2, 1 }, .{ 2, 2 } },
|
||
.{ .{ 1, 2 }, .{ 0, 1 }, .{ 1, 0 } },
|
||
},
|
||
},
|
||
)).withTags(.{ .n, .a, .m, .coord });
|
||
defer start_indices.deinit();
|
||
|
||
const values = try zml.Buffer.constant(
|
||
platform,
|
||
Shape.init(.{ .n = 2, .a = 2, .m = 3, .c = 2, .d = 2 }, .u16),
|
||
1,
|
||
);
|
||
defer values.deinit();
|
||
|
||
const result = try zml.testing.compileAndCall(platform, Local.scatter, .{ operand, operand.shape().axes(.{ .c, .b }), start_indices, values });
|
||
|
||
const expected = [2][3][4][2]u16{
|
||
.{
|
||
.{ .{ 2, 2 }, .{ 3, 3 }, .{ 1, 1 }, .{ 0, 0 } },
|
||
.{ .{ 0, 0 }, .{ 0, 0 }, .{ 2, 2 }, .{ 2, 2 } },
|
||
.{ .{ 0, 0 }, .{ 0, 0 }, .{ 1, 1 }, .{ 1, 1 } },
|
||
},
|
||
.{
|
||
.{ .{ 0, 0 }, .{ 1, 1 }, .{ 1, 1 }, .{ 0, 0 } },
|
||
.{ .{ 2, 2 }, .{ 3, 3 }, .{ 1, 1 }, .{ 0, 0 } },
|
||
.{ .{ 0, 0 }, .{ 1, 1 }, .{ 1, 1 }, .{ 0, 0 } },
|
||
},
|
||
};
|
||
try std.testing.expect(operand.shape().eql(result.shape()));
|
||
try std.testing.expectEqual(expected, result.getValue(@TypeOf(expected)));
|
||
}
|
||
}
|
||
|
||
/// Returns a Tensor containing the maximum over a given axis.
|
||
pub fn max(self: Tensor, axis_: anytype) Tensor {
|
||
const a = self.axis(axis_);
|
||
return ops.reduce(
|
||
struct {
|
||
pub fn cmp(x: Tensor, res: Tensor) Tensor {
|
||
return res.maximum(x.convert(res.dtype()));
|
||
}
|
||
}.cmp,
|
||
self,
|
||
Tensor.constant(&.{}, self.dtype().minValue()),
|
||
&.{a},
|
||
);
|
||
}
|
||
|
||
/// Returns a Tensor containing the minimum over a given axis.
|
||
pub fn min(self: Tensor, axis_: anytype) Tensor {
|
||
const a = self.axis(axis_);
|
||
return ops.reduce(
|
||
struct {
|
||
pub fn cmp(x: Tensor, res: Tensor) Tensor {
|
||
return res.minimum(x.convert(res.dtype()));
|
||
}
|
||
}.cmp,
|
||
self,
|
||
Tensor.constant(&.{}, self.dtype().maxValue()),
|
||
&.{a},
|
||
);
|
||
}
|
||
|
||
pub const ArgMaxRes = struct {
|
||
values: Tensor,
|
||
indices: Tensor,
|
||
|
||
fn cmp(left: ArgMaxRes, right: ArgMaxRes) ArgMaxRes {
|
||
const left_val = left.values;
|
||
const right_val = right.values;
|
||
const left_idx = left.indices;
|
||
const right_idx = right.indices;
|
||
|
||
const left_gt_right = left_val.cmp(.GT, right_val);
|
||
const is_nan = left_val.cmp(.NE, left_val);
|
||
const left_gt_or_nan = left_gt_right.logical(.OR, is_nan);
|
||
// we are bubbling up Nan.
|
||
const max_val = left_gt_or_nan.select(left_val, right_val);
|
||
|
||
// If left_val == right_val: keep the smallest idx.
|
||
const is_same = left_val.cmp(.EQ, right_val);
|
||
const is_first = left_idx.cmp(.LT, right_idx);
|
||
const is_same_but_first = is_same.logical(.AND, is_first);
|
||
const keep_left_idx = left_gt_or_nan.logical(.OR, is_same_but_first);
|
||
const max_idx = keep_left_idx.select(left_idx, right_idx);
|
||
|
||
return .{ .values = max_val, .indices = max_idx };
|
||
}
|
||
};
|
||
|
||
/// Returns two Tensors containing the maximum and the index of this maximum over a given axis.
|
||
///
|
||
/// Stable argmax:
|
||
/// * bubbles up Nan
|
||
/// * in case of equality the smallest index matching the maximum
|
||
pub fn argMax(x: Tensor, axis_: anytype, index_dtype: DataType) ArgMaxRes {
|
||
meta.assert(index_dtype.isInteger(), "argMax expect index type to be an integer, got {}", .{index_dtype});
|
||
|
||
const a = x.axis(axis_);
|
||
|
||
return ops.reduce(
|
||
ArgMaxRes.cmp,
|
||
.{ .values = x, .indices = Tensor.arange(.{ .end = x.dim(a) }, index_dtype).broadcast(x._shape.withDtype(index_dtype), &.{a}) },
|
||
.{ .values = Tensor.constant(&.{}, x.dtype().minValue()), .indices = Tensor.scalar(0, index_dtype) },
|
||
&.{a},
|
||
);
|
||
}
|
||
|
||
test argMax {
|
||
const zml = @import("zml.zig");
|
||
const platform = zml.testing.env();
|
||
const allocator = std.testing.allocator;
|
||
const ArgMaxTest = struct {
|
||
pub fn forward(x: Tensor) Tensor.ArgMaxRes {
|
||
return x.argMax(1, .i32);
|
||
}
|
||
};
|
||
|
||
const argmax = try zml.compileFn(allocator, ArgMaxTest.forward, .{Shape.init(.{ 1, 5 }, .f32)}, platform);
|
||
defer argmax.deinit();
|
||
// Test with tie
|
||
{
|
||
const x = try zml.Buffer.fromArray(platform, [1][5]f32{.{ 5.0, 4.1, 7.9, 0, 7.9 }});
|
||
const res = argmax.call(.{x});
|
||
const max_ = res.values.getValue(f32);
|
||
const max_idx = res.indices.getValue(i32);
|
||
try testing.expectEqual(max_, 7.9);
|
||
// We should always return the first max found.
|
||
try testing.expectEqual(max_idx, 2);
|
||
}
|
||
|
||
// Test with Nan
|
||
{
|
||
const x = try zml.Buffer.fromArray(platform, [1][5]f32{.{ 5.0, std.math.nan(f32), 7.9, 0, 7.9 }});
|
||
const res = argmax.call(.{x});
|
||
const max_ = try res.values.getValue(f32);
|
||
const max_idx = try res.indices.getValue(i32);
|
||
try testing.expect(std.math.isNan(max_));
|
||
try testing.expectEqual(max_idx, 1);
|
||
}
|
||
}
|
||
|
||
pub const SortRes = ArgMaxRes;
|
||
|
||
/// Returns two Tensors. The first contains the sorted values and the second one contains the sorted indices.
|
||
pub fn sort(self: Tensor, axis_: i64, opts: struct { descending: bool = false }) SortRes {
|
||
const a = self.axis(axis_);
|
||
const indices = Tensor.arange(.{ .end = self.dim(a) }, .i32).broadcast(self._shape, &.{a});
|
||
const res = ops.sort(
|
||
struct {
|
||
fn call(direction: dialect.stablehlo.ComparisonDirection.Direction, lhs: Tensor, rhs: Tensor, _: Tensor, _: Tensor) Tensor {
|
||
return lhs.cmp(direction, rhs);
|
||
}
|
||
}.call,
|
||
if (opts.descending) .GT else .LT,
|
||
.{ self, indices },
|
||
self.axis(axis_),
|
||
true,
|
||
);
|
||
return .{ .values = res[0], .indices = res[1] };
|
||
}
|
||
|
||
/// Returns a Tensor containing the indices corresponding to the sorted values over the given axis.
|
||
pub fn argsort(self: Tensor, axis_: i64, opts: struct { descending: bool = false }) Tensor {
|
||
return self.sort(axis_, .{ .descending = opts.descending }).indices;
|
||
}
|
||
|
||
test argsort {
|
||
const zml = @import("zml.zig");
|
||
const platform = zml.testing.env();
|
||
|
||
var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator);
|
||
defer arena_state.deinit();
|
||
const allocator = arena_state.allocator();
|
||
// 2D Tensor - dim = 1, ascending
|
||
{
|
||
const x = try zml.Buffer.fromSlice(platform, .{ 2, 5 }, &[_]f32{ -0.9264, 0.7156, 1.0202, 0.3992, 1.2349, 1.0003, -0.1932, 1.3935, 0.7316, 0.0851 });
|
||
const res = try zml.testing.compileAndCall(platform, Tensor.argsort, .{ x, 1, .{} });
|
||
const res_cpu = try res.toHostAlloc(allocator);
|
||
try testing.expectEqualSlices(i32, &.{ 0, 3, 1, 2, 4, 1, 4, 3, 0, 2 }, res_cpu.items(i32));
|
||
}
|
||
// 3D Tensor, dim = 1, descending
|
||
{
|
||
const x = try zml.Buffer.fromSlice(platform, .{ 1, 5, 10 }, &[_]f16{
|
||
-0.2505, 1.2520, -0.7041, 0.1066, 1.2773, -1.7246, 0.8389, 1.1094, 0.0601, 1.0684,
|
||
0.9619, 1.3916, 1.2246, -0.1406, 0.3674, -1.2480, -1.7051, -0.0934, 0.3435, 0.4373,
|
||
1.3809, 0.5444, -0.6079, 1.2031, -0.6880, 1.2979, -0.1869, 0.2991, 0.0156, 0.1847,
|
||
0.6626, -0.3040, -0.8726, -1.4805, -1.6943, 1.1055, -2.0078, -0.5288, 0.8813, 0.8008,
|
||
2.0527, 1.1230, 0.5430, 0.2494, -0.9434, 0.7876, 0.1818, 0.9258, -2.4902, 1.5918,
|
||
});
|
||
const res_dev = try zml.testing.compileAndCall(platform, Tensor.argsort, .{ x, 1, .{ .descending = true } });
|
||
const res = try res_dev.toHostAlloc(allocator);
|
||
try testing.expectEqualSlices(i32, &.{
|
||
4, 1, 1, 2, 0, 2, 0, 0, 3, 4,
|
||
2, 0, 4, 4, 1, 3, 4, 4, 1, 0,
|
||
1, 4, 2, 0, 2, 4, 2, 2, 0, 3,
|
||
3, 2, 0, 1, 4, 1, 1, 1, 2, 1,
|
||
0, 3, 3, 3, 3, 0, 3, 3, 4, 2,
|
||
}, res.items(i32));
|
||
}
|
||
// 4D Tensor, dim = 3, ascending
|
||
{
|
||
const x = try zml.Buffer.fromSlice(platform, .{ 4, 2, 1, 4 }, &[_]i32{
|
||
89, 31, 22, 42,
|
||
64, 39, 0, 30,
|
||
64, 71, 46, 31,
|
||
89, 82, 78, 86,
|
||
55, 32, 43, 19,
|
||
93, 24, 45, 72,
|
||
64, 86, 62, 88,
|
||
57, 21, 19, 12,
|
||
});
|
||
const res_dev = try zml.testing.compileAndCallWithTensors(platform, Tensor.argsort, .{ x.shape(), 3, .{} }, .{ x, 0, .{} });
|
||
const res = try res_dev.toHostAlloc(allocator);
|
||
try testing.expectEqualSlices(i32, &.{
|
||
2, 1, 3, 0,
|
||
2, 3, 1, 0,
|
||
3, 2, 0, 1,
|
||
2, 1, 3, 0,
|
||
3, 1, 2, 0,
|
||
1, 2, 3, 0,
|
||
2, 0, 1, 3,
|
||
3, 2, 1, 0,
|
||
}, res.items(i32));
|
||
}
|
||
}
|
||
|
||
/// Returns a Tensor representing the result of Top-K over the given axis.
|
||
pub fn topK(self: Tensor, k: u32, axis_: anytype, opts: struct { descending: bool = true }) SortRes {
|
||
const a = self.axis(axis_);
|
||
const result = self.sort(a, .{ .descending = opts.descending });
|
||
return .{
|
||
.values = result.values.slice1d(a, .{ .end = k }),
|
||
.indices = result.indices.slice1d(a, .{ .end = k }),
|
||
};
|
||
}
|
||
|
||
pub const MaxPoolRes = ArgMaxRes;
|
||
|
||
/// Computes the 1d maxPool operation on the input Tensor.
|
||
pub fn maxPool1d(self: Tensor, opts: struct {
|
||
window_dimensions: i64,
|
||
window_strides: ?i64,
|
||
base_dilations: i64 = 1,
|
||
window_dilations: i64 = 1,
|
||
padding: []const i64 = &.{0},
|
||
}) MaxPoolRes {
|
||
// TODO migrate to the following syntax.
|
||
// maxPool(.{.a = .{ .stride = 5, .dilation = 2, .padding = .{0, 1} },
|
||
// .b = .{ .stride = 8, .dilation = 2, .padding = .{0, 1} }),
|
||
// maxPool(.{
|
||
// .stride = .{ .a = 5, .b = 8 },
|
||
// .dilation = .{ .a = 2, .b = 2 },
|
||
// .padding = .{ .a = .{ 0, 2 }, .b = .{0, 2}
|
||
// })
|
||
|
||
meta.assert(opts.padding.len == 1 or opts.padding.len == 2 * self.rank(), "maxPool1d expects 'opts.padding' length to be a single integer or to be equal to the double of input tensor rank, got {} (input tensor rank is {})", .{ opts.padding.len, self.rank() });
|
||
|
||
// Note: the problem is initPoolArg assuming last axis
|
||
// TODO: support maxPool on non last axis
|
||
const a = self.axis(-1);
|
||
|
||
const window_dimensions = initPoolArg(self.rank(), &.{opts.window_dimensions});
|
||
const window_strides = if (opts.window_strides) |ws| initPoolArg(self.rank(), &.{ws}) else window_dimensions;
|
||
const base_dilation = initPoolArg(self.rank(), &.{opts.base_dilations});
|
||
const window_dilations = initPoolArg(self.rank(), &.{opts.window_dilations});
|
||
|
||
return ops.reduceWindow(
|
||
MaxPoolRes.cmp,
|
||
.{ .values = self, .indices = iota(self._shape, .i32, a) },
|
||
.{ .values = Tensor.constant(.{}, self.dtype().minValue()), .indices = Tensor.scalar(0, .i32) },
|
||
.{
|
||
.window_dimensions = window_dimensions[0..self.rank()],
|
||
.window_strides = window_strides[0..self.rank()],
|
||
.base_dilations = base_dilation[0..self.rank()],
|
||
.window_dilations = window_dilations[0..self.rank()],
|
||
.padding_values = opts.padding,
|
||
.padding_shape = &.{ @intCast(self.rank()), 2 },
|
||
},
|
||
);
|
||
}
|
||
|
||
/// Computes the 2d maxPool operation on the input Tensor.
|
||
pub fn maxPool2d(self: Tensor, opts: struct {
|
||
window_dimensions: []const i64,
|
||
window_strides: ?[]const i64 = null,
|
||
base_dilations: []const i64 = &.{ 1, 1 },
|
||
window_dilations: []const i64 = &.{ 1, 1 },
|
||
padding: []const i64 = &.{0},
|
||
}) MaxPoolRes {
|
||
// TODO: rewrite using modern ZML, add ops.reduceWindow
|
||
meta.guard(self.rank() == 3 or self.rank() == 4, @src());
|
||
meta.guard(opts.window_dimensions.len == 2, @src());
|
||
meta.guard(opts.window_strides == null or opts.window_strides.?.len == 2, @src());
|
||
meta.guard(opts.base_dilations.len == 2, @src());
|
||
meta.guard(opts.window_dilations.len == 2, @src());
|
||
meta.assert(opts.padding.len == 1 or opts.padding.len == 2 * self.rank(), "Padding needs to either be a single integer, or to be 2x time the number of input rank. In maxPool({}, .padding={d})", .{ self, opts.padding });
|
||
|
||
// TODO: support maxPool on non last axis
|
||
// Note: the problem is initPoolArg assuming last axis
|
||
const a = self.axis(-1);
|
||
|
||
const window_dimensions = initPoolArg(self.rank(), opts.window_dimensions);
|
||
const window_strides = if (opts.window_strides) |ws| initPoolArg(self.rank(), ws) else window_dimensions;
|
||
const base_dilation = initPoolArg(self.rank(), opts.base_dilations);
|
||
const window_dilations = initPoolArg(self.rank(), opts.window_dilations);
|
||
|
||
return ops.reduceWindow(
|
||
MaxPoolRes.cmp,
|
||
.{ .values = self, .indices = iota(self._shape, .i32, a) },
|
||
.{ .values = Tensor.constant(.{}, self.dtype().minValue()), .indices = Tensor.scalar(0, .i32) },
|
||
.{
|
||
.window_dimensions = window_dimensions[0..self.rank()],
|
||
.window_strides = window_strides[0..self.rank()],
|
||
.base_dilations = base_dilation[0..self.rank()],
|
||
.window_dilations = window_dilations[0..self.rank()],
|
||
.padding_values = opts.padding,
|
||
.padding_shape = &.{ @intCast(self.rank()), 2 },
|
||
},
|
||
);
|
||
}
|
||
|
||
pub inline fn axes(self: Tensor, axes_: anytype) std.BoundedArray(u3, Tensor.MAX_RANK) {
|
||
return self._shape.axes(axes_);
|
||
}
|
||
|
||
/// Chunk a given tensor into exactly n parts of equal shape.
|
||
/// `self.dim(axis_)` must be divisible by n_chunks.
|
||
pub fn chunkExact(self: Tensor, axis_: anytype, n_chunks: comptime_int) [n_chunks]Tensor {
|
||
const a = self.axis(axis_);
|
||
const d = self.dim(a);
|
||
const chunk_size = @divExact(d, n_chunks);
|
||
var chunks: [n_chunks]Tensor = undefined;
|
||
for (0..n_chunks) |i| {
|
||
const start: i64 = @as(i64, @intCast(i)) * chunk_size;
|
||
chunks[i] = self.slice1d(a, .{ .start = start, .end = start + chunk_size });
|
||
}
|
||
return chunks;
|
||
}
|
||
|
||
test chunkExact {
|
||
const zml = @import("zml.zig");
|
||
const platform = zml.testing.env();
|
||
|
||
// Only test shapes
|
||
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
||
defer comp.deinit();
|
||
comp.activate();
|
||
defer comp.deactivate();
|
||
|
||
inline for (.{
|
||
.{ .{ .a = 12 }, .a, 3, .{ .a = 4 } },
|
||
.{ .{ .a = 12, .b = 2 }, .a, 3, .{ .a = 4, .b = 2 } },
|
||
.{ .{ 12, 2 }, 0, 3, .{ 4, 2 } },
|
||
}) |testcase| {
|
||
const x_shape, const ax, const n_chunks, const res = testcase;
|
||
const x = Tensor.constant(x_shape, .{ .f16 = 0 });
|
||
const chunks = x.chunkExact(ax, n_chunks);
|
||
|
||
const res_shape = Shape.init(res, .f16);
|
||
for (&chunks) |chk| {
|
||
try zml.testing.expectEqualShapes(res_shape, chk.shape());
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Chunk a given tensor into n parts of equal shape, and one part with the remaining items.
|
||
/// When `self.dim(axis_)` is divisible by `n_chunks` it will return exactly `n_chunks`.
|
||
pub fn chunkAllowTrailing(
|
||
self: Tensor,
|
||
axis_: i64,
|
||
n_chunks: comptime_int,
|
||
) std.BoundedArray(Tensor, n_chunks + 1) {
|
||
const a = self.axis(axis_);
|
||
const d = self.dim(a);
|
||
const chunk_size: i64 = @divFloor(d, n_chunks);
|
||
const tail_chunk_size: i64 = @rem(d, chunk_size);
|
||
|
||
var chunks: std.BoundedArray(Tensor, n_chunks + 1) = .{};
|
||
for (0..n_chunks) |i| {
|
||
const start: i64 = @as(i64, @intCast(i)) * chunk_size;
|
||
chunks.appendAssumeCapacity(
|
||
self.slice1d(a, .{ .start = start, .end = start + chunk_size }),
|
||
);
|
||
}
|
||
if (tail_chunk_size != 0) {
|
||
const start: i64 = n_chunks * chunk_size;
|
||
chunks.appendAssumeCapacity(self.slice1d(a, .{ .start = start }));
|
||
}
|
||
return chunks;
|
||
}
|
||
|
||
test chunkAllowTrailing {
|
||
const zml = @import("zml.zig");
|
||
const platform = zml.testing.env();
|
||
|
||
// Only test shapes
|
||
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
||
defer comp.deinit();
|
||
comp.activate();
|
||
defer comp.deactivate();
|
||
|
||
inline for (.{
|
||
.{ .{ .a = 10 }, .a, 3, .{ .a = 3 }, .{ .a = 1 } },
|
||
.{ .{ .a = 10, .b = 2 }, .a, 3, .{ .a = 3, .b = 2 }, .{ .a = 1, .b = 2 } },
|
||
.{ .{ 10, 2 }, 0, 3, .{ 3, 2 }, .{ 1, 2 } },
|
||
.{ .{ 12, 2 }, 0, 3, .{ 4, 2 }, .{} },
|
||
}) |testcase| {
|
||
const x_shape, const ax, const n_chunks, const res, const trailing = testcase;
|
||
const x = Tensor.constant(x_shape, .{ .f16 = 0 });
|
||
const chunks = x.chunkAllowTrailing(x.axis(ax), n_chunks);
|
||
|
||
const res_shape = Shape.init(res, .f16);
|
||
for (chunks.constSlice()[0..n_chunks]) |chk| {
|
||
try zml.testing.expectEqualShapes(res_shape, chk.shape());
|
||
}
|
||
const trailing_shape = Shape.init(trailing, .f16);
|
||
if (trailing_shape.rank() > 0) {
|
||
try std.testing.expectEqual(n_chunks + 1, chunks.len);
|
||
try zml.testing.expectEqualShapes(trailing_shape, chunks.get(n_chunks).shape());
|
||
} else {
|
||
try std.testing.expectEqual(n_chunks, chunks.len);
|
||
}
|
||
}
|
||
}
|
||
|
||
pub fn split(self: Tensor, allocator: std.mem.Allocator, split_size_or_sections: []const i64, axis_: i64) ![]Tensor {
|
||
meta.assert(split_size_or_sections.len > 0, "split expects 'split_size_or_sections' length to be positive, got {}", .{split_size_or_sections.len});
|
||
|
||
const a = self.axis(axis_);
|
||
const length = self.dim(a);
|
||
if (split_size_or_sections.len != 1) {
|
||
var split_sum: i64 = 0;
|
||
for (split_size_or_sections) |n| split_sum += n;
|
||
meta.assert(split_sum == length, "split expects sum of 'split_size_or_sections' values and axis dimension to be equal, got {} and {}", .{ split_sum, length });
|
||
}
|
||
|
||
const res = try allocator.alloc(Tensor, split_size_or_sections.len);
|
||
errdefer allocator.dealloc(res);
|
||
|
||
var start: i64 = 0;
|
||
for (split_size_or_sections, 0..) |n, i| {
|
||
res[i] = self.slice1d(a, .{ .start = start, .end = start + n });
|
||
start += n;
|
||
}
|
||
return res;
|
||
}
|
||
|
||
/// Slices the input Tensor along a specific axis, with a start offset known at runtime.
|
||
/// Note: this doesn't support tagging, if you have tags,
|
||
/// you should use `dynamicSlice` directly.
|
||
pub fn dynamicSlice1d(self: Tensor, axis_: i8, len: u63, start_indices: Tensor) Tensor {
|
||
meta.assert(start_indices.rank() == 0, "dynamicSlice1d expects 'start_indices' tensor rank to be equal to 0, got {}", .{start_indices.rank()});
|
||
|
||
const a = self.axis(axis_);
|
||
const new_shape = self._shape.set(a, len);
|
||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "axis={}, len={}", .{ axis_, len });
|
||
var indices: [Tensor.MAX_RANK]mlir.Value = undefined;
|
||
for (0..self.rank()) |i| {
|
||
indices[i] = if (i == a)
|
||
start_indices.value()
|
||
else
|
||
constant(.{}, start_indices.dtype().zero()).value();
|
||
}
|
||
|
||
const op = dialect.stablehlo.dynamicSlice(
|
||
self.getContext().mlirCtx(),
|
||
self.value(),
|
||
new_shape.dims(),
|
||
indices[0..self.rank()],
|
||
loc,
|
||
);
|
||
|
||
return _result(new_shape, op.result(0));
|
||
}
|
||
|
||
pub const DynSlice = struct { start: Tensor, len: i64 };
|
||
|
||
/// Slices a Tensor across many axes, with runtime known offsets.
|
||
///
|
||
/// Due to the nature of stablehlo, the length of the slices need to be known when compiling the IR.
|
||
/// When using the tagged API it is allowed to not specify some axes.
|
||
/// But with the non-tagged API all slices need to be specified.
|
||
/// Examples:
|
||
/// ```
|
||
/// Tensor(.{.a=20,.b=30,.c=40 }).dynamicSlice(.{ .a = .{ .start = a_off, .len = 11});
|
||
/// Tensor(.{.a=20,.b=30,.c=40 }).dynamicSlice(.{
|
||
/// .a = .{ .start = a_off, .len = 11 },
|
||
/// .b = .{ .start = b_off, .len = 12 },
|
||
/// });
|
||
/// Tensor(.{ 20,30,40}).dynamicSlice(.{.{ .start = scalar(0, .i32), .len = 20 }, .{ .start = b_off, .len = 12 }, .{ .start = scalar(0, .i32), .len = 40 }});
|
||
/// ```
|
||
pub fn dynamicSlice(self: Tensor, slices_: anytype) Tensor {
|
||
// TODO: the untagged api is a bit verbose. Should I allow: `Tensor(.{ 20,30,40}).dynamicSlice(.{.{}, .{ .start = b_off, .len = 12 }, .{}});` ??
|
||
//
|
||
const slices, const slices_tags = Shape.parseStruct(DynSlice, slices_);
|
||
|
||
// TODO use slices and slices_tags for the format.
|
||
// Currently this prints: "dynSlice(struct{q: struct{start: tensor.Tensor, comptime len: comptime_int = 1}}{ .q = struct{start: tensor.Tensor, comptime len: comptime_int = 1}{ .start = Tensor({1,10}, dtype=.i64), .len = 1 } })"
|
||
// which is kinda ugly.
|
||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "dynSlice({any})", .{slices_});
|
||
|
||
const idx_dtype = if (slices.len > 0) slices.get(0).start.dtype() else .i32;
|
||
const zero = Tensor.scalar(0, idx_dtype).value();
|
||
var offset_values = [_]mlir.Value{zero} ** MAX_RANK;
|
||
var res_shape = self._shape;
|
||
for (slices.constSlice(), 0..) |slice_, i| {
|
||
const offset = slice_.start;
|
||
const len = slice_.len;
|
||
if (slices_tags.len == 0) {
|
||
meta.assert(self.rank() == slices.len, "dynamicSlice expects tensor rank and 'slices_' length to be equal, got {} and {}", .{ self.rank(), slices.len });
|
||
|
||
offset_values[i] = offset.value();
|
||
res_shape._dims.set(i, len);
|
||
|
||
meta.assert(len <= self.dim(i), "dynamicSlice expects slices 'len' to be less than or equal to their corresponding dimension in input tensor, got {} and {} for index {}", .{ len, self.dim(i), i });
|
||
} else {
|
||
const t = slices_tags.get(i);
|
||
const a = res_shape.hasTag(t) orelse meta.panic("dynamicSlice expects input tensor to have tags used in 'slices_' but {s} is missing (input shape is {})", .{ t, self._shape });
|
||
|
||
meta.assert(len <= self.dim(a), "dynamicSlice expects slices 'len' to be less than their corresponding dimension in input tensor, got {} and {} for axis {s}", .{ len, self.dim(a), t });
|
||
|
||
offset_values[a] = offset.value();
|
||
res_shape._dims.set(a, len);
|
||
}
|
||
}
|
||
const op = dialect.stablehlo.dynamicSlice(self.getContext().mlirCtx(), self.value(), res_shape.dims(), offset_values[0..self.rank()], loc);
|
||
return _result(res_shape, op.result(0));
|
||
}
|
||
|
||
test dynamicSlice {
|
||
const zml = @import("zml.zig");
|
||
const platform = zml.testing.env();
|
||
const T = f32;
|
||
|
||
{
|
||
const x = try zml.Buffer.fromArray(platform, [10]T{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
|
||
const z = try zml.Buffer.scalar(platform, 4, .i32);
|
||
const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 0, 2, z });
|
||
|
||
try testing.expectEqual([2]T{ 4, 5 }, try res.getValue([2]T));
|
||
}
|
||
|
||
{
|
||
// Strided
|
||
const x = try zml.Buffer.fromArray(platform, [2][5]T{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, 8, 9 } });
|
||
const z = try zml.Buffer.scalar(platform, 3, .i32);
|
||
|
||
const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 1, 2, z });
|
||
try testing.expectEqual([4]T{ 3, 4, 8, 9 }, res.getValue([4]T));
|
||
}
|
||
}
|
||
|
||
/// Updates a slice of the input Tensor along a specific axis using the given 'update' Tensor, with a start offset known at runtime.
|
||
pub fn dynamicUpdateSlice1d(self: Tensor, update: Tensor, axis_: i64, offset: Tensor) Tensor {
|
||
const placeholder = Tensor.scalar(0, .i32);
|
||
var start_indices = [_]Tensor{placeholder} ** MAX_RANK;
|
||
start_indices[self.axis(axis_)] = offset;
|
||
return self.dynamicUpdateSlice(start_indices[0..self.rank()], update);
|
||
}
|
||
|
||
/// Updates a part of the input Tensor using the given 'update' Tensor, with runtime known offsets.
|
||
///
|
||
/// The offsets are specified similarly to the dynamicSlice api.
|
||
/// It's semantically equivalent to:
|
||
/// self.dynamicSlice(offsets_) := update
|
||
/// Examples:
|
||
/// ```
|
||
/// Tensor(.{ .a = 2, .b = 5 }).dynamicUpdateSlice(.{ .a = scalar(1, .i32) }, Tensor(.{ .b = 5 }));
|
||
/// ```
|
||
pub fn dynamicUpdateSlice(self: Tensor, offset_: anytype, update_: Tensor) Tensor {
|
||
meta.assert(self.dtype() == update_.dtype(), "dynamicUpdateSlice expects input and 'update_' tensors to be of the same type, got {} and {}", .{ self.dtype(), update_.dtype() });
|
||
|
||
const offset, const offset_tags = Shape.parseStruct(Tensor, offset_);
|
||
// log.debug("offset: {any}, offset_tags: {any}", .{ offset, offset_tags });
|
||
for (offset.constSlice(), 0..) |start_idx, i| {
|
||
meta.assert(start_idx.rank() == 0, "dynamicUpdateSlice expects 'offset_' tensor ranks to be equal to 0, got {} at index {}", .{ start_idx.rank(), i });
|
||
}
|
||
|
||
const tagged_api = update_._shape.isFullyTagged() and self._shape.isFullyTagged() and offset_tags.len > 0;
|
||
// When using tags, we can safely insert axis with a 1-dim.
|
||
// the offset into the inserted axis will need to be specified through indices.
|
||
var update = update_;
|
||
if (tagged_api) {
|
||
// Check that all update tags are known.
|
||
for (update._shape._tags.constSlice()) |t| {
|
||
meta.assert(self._shape.hasTag(t) != null, "dynamicUpdateSlice expects 'update_' tensor tags to be a subset of input tensor tags but {s} is missing (input shape is {})", .{ t, self._shape });
|
||
}
|
||
|
||
var update_shape = self._shape;
|
||
var prev_ax: i8 = -1;
|
||
for (self._shape.tags(), 0..) |t, self_ax| {
|
||
if (update._shape.hasTag(t)) |up_ax| {
|
||
meta.assert(up_ax == prev_ax + 1, "dynamicUpdateSlice expects 'update_' and input tensor axis to have the same order, got {} and {}. (hint: you need to explicitly transpose 'update_')", .{ update_._shape, self._shape });
|
||
|
||
update_shape._dims.set(self_ax, update.dim(up_ax));
|
||
prev_ax = up_ax;
|
||
} else {
|
||
update_shape._dims.set(self_ax, 1);
|
||
}
|
||
}
|
||
update = update.reshape(update_shape);
|
||
}
|
||
|
||
meta.assert(self.rank() == update.rank(), "dynamicUpdateSlice expects input and computed update tensors to have the same rank, got {} and {} (hint: it's probably an issue on our side)", .{ self.rank(), update.rank() });
|
||
|
||
for (self.dims(), update.dims(), 0..) |self_d, up_d, ax| {
|
||
const t = self._shape.debugTag(ax);
|
||
meta.assert(up_d <= self_d, "dynamicUpdateSlice expects 'update_' dimensions to be less than or equal to their corresponding dimension in input tensor, got {} and {} for axis .{s}", .{ up_d, self_d, t });
|
||
|
||
if (tagged_api and up_d < self_d) {
|
||
const axis_has_offset = std.mem.indexOfScalar(Shape.Tag, offset_tags.constSlice(), self._shape._tags.get(ax)) != null;
|
||
|
||
meta.assert(axis_has_offset, "dynamicUpdateSlice expects 'update_' dimensions to be equal to their corresponding dimension in input tensor, got {} and {} for axis .{s} (hint: you need to provide an offset)", .{ up_d, self_d, t });
|
||
}
|
||
}
|
||
|
||
const idx_dtype = if (offset.len > 0) offset.get(0).dtype() else .i32;
|
||
const zero = Tensor.scalar(0, idx_dtype).value();
|
||
var offset_values: [MAX_RANK]mlir.Value = undefined;
|
||
if (offset_tags.len == 0) {
|
||
// Without offset tags we need the same number of offset than rank.
|
||
meta.assert(self.rank() == offset.len, "dynamicUpdateSlice expects input tensor rank and 'offset_' length to be equal, got {} and {}", .{ self.rank(), offset.len });
|
||
|
||
for (offset.constSlice(), 0..) |idx, i| {
|
||
offset_values[i] = idx.value();
|
||
}
|
||
} else {
|
||
// If an axis isn't specified, update the full slice.
|
||
// This is only allowed when using tagged sliced.
|
||
offset_values = .{zero} ** MAX_RANK;
|
||
for (offset.constSlice(), offset_tags.constSlice()) |start, t| {
|
||
const a = self._shape.hasTag(t) orelse meta.panic("dynamicUpdateSlice expects input tensor to have tags used in 'offset_' but {s} is missing (input shape is {})", .{ t, self._shape });
|
||
offset_values[a] = start.value();
|
||
}
|
||
}
|
||
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.dynamic_update_slice(
|
||
self.getContext().mlirCtx(),
|
||
self.value(),
|
||
update.value(),
|
||
offset_values[0..self.rank()],
|
||
loc,
|
||
);
|
||
return _result(self._shape, op.result(0));
|
||
}
|
||
|
||
test dynamicUpdateSlice {
|
||
const zml = @import("zml.zig");
|
||
const platform = zml.testing.env();
|
||
|
||
{
|
||
const x = try zml.Buffer.fromArray(platform, [10]f32{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
|
||
const y = try zml.Buffer.fromArray(platform, [2]f32{ -1, -1 });
|
||
const idx = try zml.Buffer.scalar(platform, 4, .i32);
|
||
const res = try zml.testing.compileAndCall(
|
||
platform,
|
||
struct {
|
||
pub fn forward(x_: Tensor, idx_: struct { a: Tensor }, y_: Tensor) Tensor {
|
||
return x_.dynamicUpdateSlice(idx_, y_);
|
||
}
|
||
}.forward,
|
||
.{ x.withTags(.{.a}), .{ .a = idx }, y.withTags(.{.a}) },
|
||
);
|
||
try testing.expectEqual([10]f32{ 0, 1, 2, 3, -1, -1, 6, 7, 8, 9 }, try res.getValue([10]f32));
|
||
}
|
||
|
||
{
|
||
// Updates 2D, tagged api
|
||
const x = try zml.Buffer.fromArray(platform, [2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, 8, 9 } });
|
||
const y = try zml.Buffer.fromArray(platform, [2]f32{ -1, -1 });
|
||
const idx = try zml.Buffer.scalar(platform, 3, .i32);
|
||
|
||
const res = try zml.testing.compileAndCall(
|
||
platform,
|
||
struct {
|
||
pub fn forward(x_: Tensor, idx_: Tensor, y_: Tensor) Tensor {
|
||
return x_.dynamicUpdateSlice(.{ .b = idx_ }, y_);
|
||
}
|
||
}.forward,
|
||
.{ x.withTags(.{ .a, .b }), idx, y.withTags(.{.a}) },
|
||
);
|
||
try testing.expectEqualDeep(
|
||
[2][5]f32{ .{ 0, 1, 2, -1, 4 }, .{ 5, 6, 7, -1, 9 } },
|
||
try res.getValue([2][5]f32),
|
||
);
|
||
}
|
||
|
||
{
|
||
// Updates 2D slice, un-tagged api. Note that `y` needs to have a 1 dimension axis.
|
||
const x = try zml.Buffer.fromArray(platform, [2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, 8, 9 } });
|
||
const y = try zml.Buffer.fromArray(platform, [2][1]f32{ .{-1}, .{-1} });
|
||
const idx = try zml.Buffer.scalar(platform, 3, .i32);
|
||
const res = try zml.testing.compileAndCall(
|
||
platform,
|
||
struct {
|
||
pub fn forward(x_: Tensor, idx_: Tensor, y_: Tensor) Tensor {
|
||
return x_.dynamicUpdateSlice(.{ zml.Tensor.scalar(0, .i32), idx_ }, y_);
|
||
}
|
||
}.forward,
|
||
.{ x, idx, y },
|
||
);
|
||
try testing.expectEqualDeep(
|
||
[2][5]f32{ .{ 0, 1, 2, -1, 4 }, .{ 5, 6, 7, -1, 9 } },
|
||
res.getValue([2][5]f32),
|
||
);
|
||
}
|
||
|
||
{
|
||
// Updates 2D, partial update
|
||
const x = try zml.Buffer.fromArray(platform, [2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, 8, 9 } });
|
||
const y = try zml.Buffer.fromArray(platform, [1]f32{-1});
|
||
const idx_a = try zml.Buffer.scalar(platform, 1, .i32);
|
||
const idx_b = try zml.Buffer.scalar(platform, 3, .i32);
|
||
const res = try zml.testing.compileAndCall(
|
||
platform,
|
||
struct {
|
||
pub fn forward(x_: Tensor, idx_: struct { a: Tensor, b: Tensor }, y_: Tensor) Tensor {
|
||
return x_.dynamicUpdateSlice(idx_, y_);
|
||
}
|
||
}.forward,
|
||
.{ x.withTags(.{ .a, .b }), .{ .a = idx_a, .b = idx_b }, y.withTags(.{.a}) },
|
||
);
|
||
try testing.expectEqualDeep(
|
||
[2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, -1, 9 } },
|
||
res.getValue([2][5]f32),
|
||
);
|
||
}
|
||
|
||
{
|
||
// Updates 2D, partial update, un-tagged api.
|
||
const x = try zml.Buffer.fromArray(platform, [2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, 8, 9 } });
|
||
const y = try zml.Buffer.fromArray(platform, [1][1]f32{.{-1}});
|
||
const idx_a = try zml.Buffer.scalar(platform, 1, .i32);
|
||
const idx_b = try zml.Buffer.scalar(platform, 3, .i32);
|
||
const A = struct {
|
||
pub fn forward(x_: Tensor, idx_: [2]Tensor, y_: Tensor) Tensor {
|
||
return x_.dynamicUpdateSlice(&idx_, y_);
|
||
}
|
||
};
|
||
const res = try zml.testing.compileAndCall(platform, A.forward, .{ x, .{ idx_a, idx_b }, y });
|
||
try testing.expectEqualDeep(
|
||
[2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, -1, 9 } },
|
||
res.getValue([2][5]f32),
|
||
);
|
||
}
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise result of the given 'cmp' comparison between the two input Tensors.
|
||
pub fn cmp(self: Tensor, direction: dialect.stablehlo.ComparisonDirection.Direction, other: Tensor) Tensor {
|
||
meta.assert(self.dtype() == other.dtype(), "cmp expects input tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() });
|
||
|
||
if (self.rank() == 0 and other.rank() != 0) return self.broadcast(other._shape, &.{}).cmp(direction, other);
|
||
if (self.rank() != 0 and other.rank() == 0) return self.cmp(direction, other.broadcast(self._shape, &.{}));
|
||
|
||
meta.assert(self._shape.eql(other._shape), "cmp expects input tensor shapes to match, got {} and {}", .{ self._shape, other._shape });
|
||
|
||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "cmp(.{s})", .{@tagName(direction)});
|
||
const op = dialect.stablehlo.compare(
|
||
self.getContext().mlirCtx(),
|
||
self.value(),
|
||
other.value(),
|
||
dialect.stablehlo.ComparisonDirection.init(self.getContext().mlirCtx(), direction),
|
||
getComparisonType(self.getContext().mlirCtx(), self.dtype()),
|
||
loc,
|
||
);
|
||
|
||
return _result(self._shape.withDtype(.bool), op.result(0));
|
||
}
|
||
|
||
/// For each element at index `i`, if `bool_tensor[i] == true`, `output[i] = on_true[i]`
|
||
/// otherwise, if `bool_tensor[i] == false`, `output[i] = on_false[i]`
|
||
pub fn select(bool_tensor: Tensor, on_true: Tensor, on_false: Tensor) Tensor {
|
||
meta.assert(bool_tensor.dtype() == .bool, "select expects input tensor type to be a boolean, got {}", .{bool_tensor.dtype()});
|
||
meta.assert(on_true.dtype() == on_false.dtype(), "select expects 'on_true' and 'on_false' tensor types to be equal, got {} and {}", .{ on_true.dtype(), on_false.dtype() });
|
||
meta.assert(bool_tensor._shape.eqlDims(on_true._shape), "select expects input tensor and 'on_true' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_true._shape });
|
||
meta.assert(bool_tensor._shape.eqlDims(on_false._shape), "select expects input tensor and 'on_false' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_false._shape });
|
||
|
||
const loc = bool_tensor.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.select(
|
||
bool_tensor.getContext().mlirCtx(),
|
||
bool_tensor.value(),
|
||
on_true.value(),
|
||
on_false.value(),
|
||
loc,
|
||
);
|
||
|
||
return _result(on_true._shape, op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing the element-wise not logical operation of the input Tensor.
|
||
pub fn not(self: Tensor) Tensor {
|
||
const loc = self.getContext().mlirCtx().location(@src());
|
||
const op = dialect.stablehlo.not(self.getContext().mlirCtx(), self.value(), loc);
|
||
return _result(self._shape, op.result(0));
|
||
}
|
||
|
||
/// Returns a Tensor containing boolean indicating if there is a non-zero value over the given axis.
|
||
pub fn any(self: Tensor, axis_: i64) Tensor {
|
||
const pred = self.cmp(.NE, Tensor.constant(self.dims(), self.dtype().zero()));
|
||
const red = ops.reduce(
|
||
struct {
|
||
pub fn acc(x: Tensor, res: Tensor) Tensor {
|
||
return res.logical(.OR, x);
|
||
}
|
||
}.acc,
|
||
pred,
|
||
Tensor.scalar(0, pred.dtype()),
|
||
&.{self.axis(axis_)},
|
||
);
|
||
return red;
|
||
}
|
||
|
||
/// Given a set of N vectors of lengths A, B, C, D,
|
||
/// returns N tensors of rank N, and shape (A, B, C, D).
|
||
/// For any coordinate (a, b, c, d),
|
||
/// we have:
|
||
/// - res[0][a, b, c, d] == A[a]
|
||
/// - res[1][a, b, c, d] == B[b]
|
||
/// - res[2][a, b, c, d] == C[c]
|
||
/// - res[3][a, b, c, d] == D[d]
|
||
/// This is implemented with broadcasting, so typically it won't copy.
|
||
/// In Pytorch/Numpy this is know as `meshgrid` with "ij" mode.
|
||
/// See torch.meshgrid for the "xy" mode.
|
||
pub fn cartesianProduct(comptime N: u3, vectors: [N]Tensor) [N]Tensor {
|
||
var out: @TypeOf(vectors) = undefined;
|
||
_cartesianProduct(&vectors, &out);
|
||
return out;
|
||
}
|
||
|
||
fn _cartesianProduct(vectors: []const Tensor, out: []Tensor) void {
|
||
meta.assert(vectors.len >= 1, "cartesianProduct expects at least one input.", .{});
|
||
meta.assert(vectors.len < Tensor.MAX_RANK, "cartesianProduct expects at most {} input vectors, received {} !", .{ Tensor.MAX_RANK - 1, vectors.len });
|
||
for (vectors) |x| {
|
||
meta.assert(x.rank() <= 1, "cartesianProduct expects 0 or 1 rank input vectors. Got: {any}", .{vectors});
|
||
meta.assert(vectors[0].dtype() == x.dtype(), "cartesianProduct expects input vectors to have all the same dtype. Got: {any}", .{vectors});
|
||
}
|
||
|
||
var res_shape = Shape.init(.{}, vectors[0].dtype());
|
||
for (vectors) |x| {
|
||
if (x.rank() == 0) {
|
||
res_shape = res_shape.appendDim(1, null);
|
||
} else {
|
||
res_shape = res_shape.appendDim(x.dim(0), x.shape().tag(0));
|
||
}
|
||
}
|
||
|
||
for (out, vectors, 0..) |*o, x, i| {
|
||
o.* = x.broadcast(res_shape, &[1]i64{@intCast(i)});
|
||
}
|
||
}
|
||
|
||
test cartesianProduct {
|
||
const zml = @import("zml.zig");
|
||
const client = zml.testing.env();
|
||
|
||
const x = try zml.Buffer.fromSlice(client, .{6}, &[_]i32{ 0, 1, 2, 3, 4, 5 });
|
||
const y = try zml.Buffer.fromSlice(client, .{4}, &[_]i32{ 0, 1, 2, 3 });
|
||
|
||
const Local = struct {
|
||
pub fn cartesianProduct2(a: Tensor, b: Tensor) [2]Tensor {
|
||
return cartesianProduct(2, .{ a, b });
|
||
}
|
||
};
|
||
|
||
{
|
||
const xs, const ys = try zml.testing.compileAndCall(client, Local.cartesianProduct2, .{ x, y });
|
||
try std.testing.expectEqualSlices(i64, &.{ 6, 4 }, xs.shape().dims());
|
||
try std.testing.expectEqualSlices(i64, &.{ 6, 4 }, ys.shape().dims());
|
||
try std.testing.expectEqualDeep(
|
||
[6][4]i32{
|
||
.{ 0, 0, 0, 0 },
|
||
.{ 1, 1, 1, 1 },
|
||
.{ 2, 2, 2, 2 },
|
||
.{ 3, 3, 3, 3 },
|
||
.{ 4, 4, 4, 4 },
|
||
.{ 5, 5, 5, 5 },
|
||
},
|
||
try xs.getValue([6][4]i32),
|
||
);
|
||
try std.testing.expectEqualDeep(
|
||
[6][4]i32{
|
||
.{ 0, 1, 2, 3 },
|
||
.{ 0, 1, 2, 3 },
|
||
.{ 0, 1, 2, 3 },
|
||
.{ 0, 1, 2, 3 },
|
||
.{ 0, 1, 2, 3 },
|
||
.{ 0, 1, 2, 3 },
|
||
},
|
||
try ys.getValue([6][4]i32),
|
||
);
|
||
}
|
||
}
|
||
|
||
/// Given a set of N vectors of lengths A, B, C, D,
|
||
/// returns 1 tensors of rank N+1, and shape (A, B, C, D, N).
|
||
/// For any coordinate (a, b, c, d),
|
||
/// we have:
|
||
/// - res[a, b, c, d] == (A[a], B[b], C[c], D[d])
|
||
pub fn cartesianProductStacked(vectors: []const Tensor) Tensor {
|
||
var out = std.BoundedArray(Tensor, Tensor.MAX_RANK).init(vectors.len) catch unreachable;
|
||
_cartesianProduct(vectors, out.slice());
|
||
|
||
return Tensor.stack(out.constSlice(), .last, .coord);
|
||
}
|
||
|
||
test cartesianProductStacked {
|
||
const zml = @import("zml.zig");
|
||
const platform = zml.testing.env();
|
||
const x = try zml.Buffer.fromSlice(platform, .{6}, &[_]i32{ 0, 1, 2, 3, 4, 5 });
|
||
const y = try zml.Buffer.fromSlice(platform, .{4}, &[_]i32{ 0, 1, 2, 3 });
|
||
|
||
const Local = struct {
|
||
pub fn cartesianProduct2(a: Tensor, b: Tensor) Tensor {
|
||
return cartesianProductStacked(&.{ a, b });
|
||
}
|
||
};
|
||
|
||
const z = try zml.testing.compileAndCall(platform, Local.cartesianProduct2, .{ x, y });
|
||
try std.testing.expectEqualDeep(
|
||
[6][4][2]i32{
|
||
.{ .{ 0, 0 }, .{ 0, 1 }, .{ 0, 2 }, .{ 0, 3 } },
|
||
.{ .{ 1, 0 }, .{ 1, 1 }, .{ 1, 2 }, .{ 1, 3 } },
|
||
.{ .{ 2, 0 }, .{ 2, 1 }, .{ 2, 2 }, .{ 2, 3 } },
|
||
.{ .{ 3, 0 }, .{ 3, 1 }, .{ 3, 2 }, .{ 3, 3 } },
|
||
.{ .{ 4, 0 }, .{ 4, 1 }, .{ 4, 2 }, .{ 4, 3 } },
|
||
.{ .{ 5, 0 }, .{ 5, 1 }, .{ 5, 2 }, .{ 5, 3 } },
|
||
},
|
||
try z.getValue([6][4][2]i32),
|
||
);
|
||
}
|
||
|
||
fn binaryOp(
|
||
op_name: []const u8,
|
||
op_fn: fn (mlir.Context, mlir.Value, mlir.Value, mlir.Location) mlir.Operation,
|
||
) fn (Tensor, Tensor) Tensor {
|
||
return struct {
|
||
pub fn binaryOpHelper(self: Tensor, other: Tensor) Tensor {
|
||
meta.assert(self.dtype() == other.dtype(), "{s} expects tensor to be of same type, got {} and {}", .{ op_name, self.dtype(), other.dtype() });
|
||
|
||
if (self.rank() == 0 and other.rank() != 0) {
|
||
return binaryOpHelper(self.broad(other._shape), other);
|
||
}
|
||
|
||
if (self.rank() != 0 and other.rank() == 0) {
|
||
return binaryOpHelper(self, other.broad(self._shape));
|
||
}
|
||
|
||
meta.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {} and {}", .{ op_name, self._shape, other._shape });
|
||
|
||
const mlirCtx = self.getContext().mlirCtx();
|
||
const location = mlirCtx.location(@src());
|
||
const ret = @call(.auto, op_fn, .{ mlirCtx, self.value(), other.value(), location });
|
||
return _result(self._shape, ret.result(0));
|
||
}
|
||
}.binaryOpHelper;
|
||
}
|
||
};
|
||
|
||
fn initPoolArg(rank: usize, data: []const i64) [Tensor.MAX_RANK]i64 {
|
||
// TODO use shape
|
||
var result = [_]i64{1} ** Tensor.MAX_RANK;
|
||
const start = rank - data.len;
|
||
@memcpy(result[start .. start + data.len], data);
|
||
return result;
|
||
}
|
||
|
||
fn getPoolResDims(dt: DataType, in_dims: []const i64, base_dilations: @Vector(Tensor.MAX_RANK, i64), padding: []const i64, window_dimensions: @Vector(Tensor.MAX_RANK, i64), window_dilations: @Vector(Tensor.MAX_RANK, i64), window_strides: @Vector(Tensor.MAX_RANK, i64)) Shape {
|
||
// TODO use shape
|
||
var input_dims = [_]i64{1} ** Tensor.MAX_RANK;
|
||
@memcpy(input_dims[0..in_dims.len], in_dims);
|
||
|
||
const input_dims_: @Vector(Tensor.MAX_RANK, i64) = input_dims;
|
||
const splat_one: @Vector(Tensor.MAX_RANK, i64) = @splat(1);
|
||
const dilated_input_shape: @Vector(Tensor.MAX_RANK, i64) = (input_dims_ - splat_one) * base_dilations + splat_one;
|
||
var pad_slice0: @Vector(Tensor.MAX_RANK, i64) = @splat(padding[0]);
|
||
var pad_slice1: @Vector(Tensor.MAX_RANK, i64) = @splat(padding[0]);
|
||
if (padding.len > 1) {
|
||
var idx: usize = 0;
|
||
while (idx < in_dims.len * 2) : (idx += 2) {
|
||
pad_slice0[idx / 2] = padding[idx];
|
||
pad_slice1[idx / 2] = padding[idx + 1];
|
||
}
|
||
}
|
||
const padded_input_shape: @Vector(Tensor.MAX_RANK, i64) = pad_slice0 + dilated_input_shape + pad_slice1;
|
||
const dilated_window_shape = (window_dimensions - splat_one) * window_dilations + splat_one;
|
||
const dims = @divFloor(padded_input_shape - dilated_window_shape, window_strides) + splat_one;
|
||
const dims_arr: [Tensor.MAX_RANK]i64 = @bitCast(dims);
|
||
return Shape.init(dims_arr[0..in_dims.len], dt);
|
||
}
|
||
|
||
fn getComparisonType(ctx: mlir.Context, dtype: DataType) dialect.stablehlo.CompareType {
|
||
return dialect.stablehlo.CompareType.init(ctx, switch (dtype) {
|
||
.i4, .i8, .i16, .i32, .i64 => .SIGNED,
|
||
.bool, .u4, .u8, .u16, .u32, .u64 => .UNSIGNED,
|
||
.f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16, .f16, .f32, .f64 => .FLOAT,
|
||
.c64, .c128 => @panic("Can't compare complex numbers"),
|
||
});
|
||
}
|
||
|
||
test "Tensor.maxPool1d" {
|
||
const zml = @import("zml.zig");
|
||
const platform = zml.testing.env();
|
||
|
||
const MaxPool = struct {
|
||
pub fn forward(x: zml.Tensor) Tensor.ArgMaxRes {
|
||
return x.maxPool1d(.{
|
||
.window_dimensions = 3,
|
||
.window_strides = 2,
|
||
});
|
||
}
|
||
};
|
||
|
||
var data: [20]f32 = undefined;
|
||
for (&data, 0..) |*v, i| v.* = @floatFromInt(i);
|
||
|
||
const x = try zml.Buffer.fromSlice(platform, .{ 2, 2, 5 }, &data);
|
||
const result = try zml.testing.compileAndCall(platform, MaxPool.forward, .{x});
|
||
try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2 }, .f32), result.values.shape());
|
||
try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2 }, .i32), result.indices.shape());
|
||
const buffer = result.values.getValue([2][2][2]f32);
|
||
try std.testing.expectEqualDeep(
|
||
[2][2][2]f32{
|
||
[2][2]f32{
|
||
[2]f32{ 2, 4 },
|
||
[2]f32{ 7, 9 },
|
||
},
|
||
[2][2]f32{
|
||
[2]f32{ 12, 14 },
|
||
[2]f32{ 17, 19 },
|
||
},
|
||
},
|
||
buffer,
|
||
);
|
||
}
|
||
|
||
test "Tensor.maxPool2d" {
|
||
const zml = @import("zml.zig");
|
||
const platform = zml.testing.env();
|
||
|
||
const MaxPool = struct {
|
||
pub fn forward(x: Tensor) Tensor.ArgMaxRes {
|
||
return x.maxPool2d(.{
|
||
.window_dimensions = &.{ 3, 2 },
|
||
.window_strides = &.{ 2, 1 },
|
||
});
|
||
}
|
||
};
|
||
|
||
var data: [100]f32 = undefined;
|
||
for (&data, 0..) |*v, i| v.* = @floatFromInt(i);
|
||
const x = try zml.Buffer.fromSlice(platform, .{ 2, 2, 5, 5 }, &data);
|
||
|
||
const result = try zml.testing.compileAndCall(platform, MaxPool.forward, .{x});
|
||
try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2, 4 }, .f32), result.values.shape());
|
||
try zml.testing.expectEqualShapes(Shape.init(.{ 2, 2, 2, 4 }, .i32), result.indices.shape());
|
||
var buffer: [2][2][2][4]f32 = undefined;
|
||
_ = try result.values.toHost(std.mem.asBytes(&buffer));
|
||
try std.testing.expectEqualDeep(
|
||
[2][2][2][4]f32{
|
||
.{
|
||
.{ .{ 11, 12, 13, 14 }, .{ 21, 22, 23, 24 } },
|
||
.{ .{ 36, 37, 38, 39 }, .{ 46, 47, 48, 49 } },
|
||
},
|
||
.{
|
||
.{ .{ 61, 62, 63, 64 }, .{ 71, 72, 73, 74 } },
|
||
.{ .{ 86, 87, 88, 89 }, .{ 96, 97, 98, 99 } },
|
||
},
|
||
},
|
||
buffer,
|
||
);
|
||
}
|
||
|
||
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
|
||
pub fn Bufferized(comptime T: type) type {
|
||
return meta.MapType(Tensor, Buffer).map(T);
|
||
}
|
||
|
||
/// Return a clone of a type with Tensors replaced by Shapes.
|
||
/// Recursively descends into the type.
|
||
/// See also: shapesOf() and its tests, and meta.MapType().
|
||
pub fn ShapeOf(comptime T: type) type {
|
||
const M = meta.MapType(Tensor, Shape);
|
||
return M.map(T);
|
||
}
|
||
|
||
/// Return a clone of the argument where each instance of a Tensor is replaced
|
||
/// by its Shape. This is similar to ShapeOf(), but with runtime values.
|
||
/// See also: meta.mapAlloc().
|
||
pub fn shapesOf(model: anytype, allocator: std.mem.Allocator) !ShapeOf(@TypeOf(model)) {
|
||
var shapes: ShapeOf(@TypeOf(model)) = undefined;
|
||
try meta.mapAlloc(struct {
|
||
fn shapeFromTensorCallback(_: void, tensor: Tensor) Shape {
|
||
return tensor.shape();
|
||
}
|
||
}.shapeFromTensorCallback, allocator, {}, model, &shapes);
|
||
return shapes;
|
||
}
|
||
|
||
test shapesOf {
|
||
const alloc = std.testing.allocator;
|
||
|
||
// Tensor in struct
|
||
{
|
||
const S = struct {
|
||
a: Tensor,
|
||
};
|
||
const shape = Shape.init(.{ 28, 28 }, .f32);
|
||
const s: S = .{
|
||
.a = Tensor{ ._shape = shape, ._id = undefined },
|
||
};
|
||
const shapes = try shapesOf(s, alloc);
|
||
try std.testing.expectEqual(shape, shapes.a);
|
||
}
|
||
|
||
// single Tensor
|
||
{
|
||
const shape = Shape.init(.{ 28, 28 }, .f32);
|
||
const tensor = Tensor{ ._shape = shape, ._id = undefined };
|
||
const shapes = try shapesOf(tensor, alloc);
|
||
try std.testing.expectEqual(shape, shapes);
|
||
}
|
||
|
||
// nn linear layer, no bias
|
||
{
|
||
const nn = @import("nn.zig");
|
||
const shape = Shape.init(.{ 28, 28 }, .f32);
|
||
const layer: nn.Linear = .{
|
||
.weight = Tensor{ ._shape = shape, ._id = undefined },
|
||
.bias = null,
|
||
};
|
||
|
||
const shapes = try shapesOf(layer, alloc);
|
||
try std.testing.expectEqual(shape, shapes.weight);
|
||
try std.testing.expectEqual(null, shapes.bias);
|
||
}
|
||
|
||
// model
|
||
{
|
||
const Mnist = struct {
|
||
fc1: Layer,
|
||
fc2: Layer,
|
||
|
||
const Layer = struct {
|
||
weight: Tensor,
|
||
bias: Tensor,
|
||
};
|
||
};
|
||
|
||
const fc1_weight_shape = Shape.init(.{ 500, 784 }, .f32);
|
||
const fc1_bias_shape = Shape.init(.{500}, .f32);
|
||
const fc2_weight_shape = Shape.init(.{ 10, 500 }, .f32);
|
||
const fc2_bias_shape = Shape.init(.{10}, .f32);
|
||
const mnist: Mnist = .{
|
||
.fc1 = .{
|
||
.weight = Tensor{ ._shape = fc1_weight_shape, ._id = undefined },
|
||
.bias = Tensor{ ._shape = fc1_bias_shape, ._id = undefined },
|
||
},
|
||
.fc2 = .{
|
||
.weight = Tensor{ ._shape = fc2_weight_shape, ._id = undefined },
|
||
.bias = Tensor{ ._shape = fc2_bias_shape, ._id = undefined },
|
||
},
|
||
};
|
||
|
||
const shapes = try shapesOf(mnist, alloc);
|
||
try std.testing.expectEqual(fc1_weight_shape, shapes.fc1.weight);
|
||
try std.testing.expectEqual(fc1_bias_shape, shapes.fc1.bias);
|
||
try std.testing.expectEqual(fc2_weight_shape, shapes.fc2.weight);
|
||
try std.testing.expectEqual(fc2_bias_shape, shapes.fc2.bias);
|
||
}
|
||
}
|
||
|
||
fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK), value: T) std.BoundedArray(i64, Tensor.MAX_RANK) {
|
||
var res: std.BoundedArray(i64, Tensor.MAX_RANK) = .{};
|
||
for (bounded_array.constSlice(), 0..) |v, ax| {
|
||
if (v == value) {
|
||
res.appendAssumeCapacity(@intCast(ax));
|
||
}
|
||
}
|
||
return res;
|
||
}
|
||
|
||
fn _parseGatherCoord(self: Tensor, axes_: anytype) struct { bool, std.BoundedArray(u3, Tensor.MAX_RANK) } {
|
||
const AxesT = @TypeOf(axes_);
|
||
const axes_is_scalar = AxesT == EnumLiteral or AxesT == comptime_int or @typeInfo(AxesT) == .Int;
|
||
|
||
const coord_axes = if (axes_is_scalar)
|
||
std.BoundedArray(u3, Tensor.MAX_RANK).fromSlice(&.{self.axis(axes_)}) catch unreachable
|
||
else
|
||
self.axes(axes_);
|
||
|
||
return .{ axes_is_scalar, coord_axes };
|
||
}
|
||
|
||
fn parseArrayInfo(T: type) Shape {
|
||
return switch (@typeInfo(T)) {
|
||
.Array => |arr| {
|
||
const s = parseArrayInfo(arr.child);
|
||
return s.insert(0, .{arr.len});
|
||
},
|
||
else => .{ ._dtype = DataType.fromZigType(T) },
|
||
};
|
||
}
|
||
|
||
inline fn toI64(values: anytype) []i64 {
|
||
var res: [Tensor.MAX_RANK]i64 = undefined;
|
||
for (values, 0..) |val, i| res[i] = @intCast(val);
|
||
return res[0..values.len];
|
||
}
|