Replace the silu implementation with `stablehlo.logistic` for higher precision; move the logistic logic into `sigmoid`, and alias `logistic` to `sigmoid` (breaking change).

This commit is contained in:
Tarry Singh 2023-05-01 10:40:50 +00:00
parent 021111d07d
commit fefd84b1bb

View File

@ -293,13 +293,6 @@ pub const Tensor = struct {
return _result(self._shape.withDtype(.bool), op.result(0)); return _result(self._shape.withDtype(.bool), op.result(0));
} }
/// Returns a Tensor containing the element-wise logistic operation on the input Tensor.
/// Lowers to a single `stablehlo.logistic` op; the result reuses the input's shape.
pub fn logistic(self: Tensor) Tensor {
    const mlir_ctx = self.getContext().mlirCtx();
    // Tag the emitted op with the current source location for MLIR diagnostics.
    const logistic_op = dialect.stablehlo.logistic(mlir_ctx, self.value(), mlir_ctx.location(@src()));
    return _result(self._shape, logistic_op.result(0));
}
/// Returns a Tensor containing the element-wise number of bits set in the input Tensor. /// Returns a Tensor containing the element-wise number of bits set in the input Tensor.
pub fn popcnt(self: Tensor) Tensor { pub fn popcnt(self: Tensor) Tensor {
meta.assert(self.dtype().isInteger(), "popcnt expects tensor type to be an integer, got {}", .{self.dtype()}); meta.assert(self.dtype().isInteger(), "popcnt expects tensor type to be an integer, got {}", .{self.dtype()});
@ -1190,11 +1183,13 @@ pub const Tensor = struct {
/// Returns a Tensor containing the sigmoid function applied to each element of the input Tensor.
/// Lowers to the `stablehlo.logistic` op for higher precision than the
/// 1 / (1 + exp(-x)) decomposition; the result reuses the input's shape.
pub fn sigmoid(self: Tensor) Tensor {
    // Attach the current source location to the emitted MLIR op for diagnostics.
    const loc = self.getContext().mlirCtx().location(@src());
    const op = dialect.stablehlo.logistic(self.getContext().mlirCtx(), self.value(), loc);
    return _result(self._shape, op.result(0));
}

/// Alias: `logistic` and `sigmoid` are the same operation.
pub const logistic = sigmoid;
/// Returns a Tensor containing the ReLU activation function applied to each element of the input Tensor. /// Returns a Tensor containing the ReLU activation function applied to each element of the input Tensor.
pub fn relu(self: Tensor) Tensor { pub fn relu(self: Tensor) Tensor {
return self.maximum(Tensor.constant(self.dims(), self.dtype().zero())); return self.maximum(Tensor.constant(self.dims(), self.dtype().zero()));