Replace silu implementation with stablehlo.logistic for higher precision, move logistic logic into sigmoid and alias logistic to sigmoid (breaking change).
This commit is contained in:
parent
021111d07d
commit
fefd84b1bb
@ -293,13 +293,6 @@ pub const Tensor = struct {
|
|||||||
return _result(self._shape.withDtype(.bool), op.result(0));
|
return _result(self._shape.withDtype(.bool), op.result(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the element-wise logistic operation on the input Tensor.
|
|
||||||
pub fn logistic(self: Tensor) Tensor {
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src());
|
|
||||||
const op = dialect.stablehlo.logistic(self.getContext().mlirCtx(), self.value(), loc);
|
|
||||||
return _result(self._shape, op.result(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a Tensor containing the element-wise number of bits set in the input Tensor.
|
/// Returns a Tensor containing the element-wise number of bits set in the input Tensor.
|
||||||
pub fn popcnt(self: Tensor) Tensor {
|
pub fn popcnt(self: Tensor) Tensor {
|
||||||
meta.assert(self.dtype().isInteger(), "popcnt expects tensor type to be an integer, got {}", .{self.dtype()});
|
meta.assert(self.dtype().isInteger(), "popcnt expects tensor type to be an integer, got {}", .{self.dtype()});
|
||||||
@ -1190,11 +1183,13 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns a Tensor containing the sigmoid function applied to each element of the input Tensor.
|
/// Returns a Tensor containing the sigmoid function applied to each element of the input Tensor.
|
||||||
pub fn sigmoid(self: Tensor) Tensor {
|
pub fn sigmoid(self: Tensor) Tensor {
|
||||||
// until the metal plugin supports `stablehlo.logistic`, implement in the way JAX does it
|
const loc = self.getContext().mlirCtx().location(@src());
|
||||||
const one = Tensor.constant(&.{}, self.dtype().one()).broadcast(self._shape, &.{});
|
const op = dialect.stablehlo.logistic(self.getContext().mlirCtx(), self.value(), loc);
|
||||||
return one.div(one.add(self.negate().exp()));
|
return _result(self._shape, op.result(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub const logistic = sigmoid;
|
||||||
|
|
||||||
/// Returns a Tensor containing the ReLU activation function applied to each element of the input Tensor.
|
/// Returns a Tensor containing the ReLU activation function applied to each element of the input Tensor.
|
||||||
pub fn relu(self: Tensor) Tensor {
|
pub fn relu(self: Tensor) Tensor {
|
||||||
return self.maximum(Tensor.constant(self.dims(), self.dtype().zero()));
|
return self.maximum(Tensor.constant(self.dims(), self.dtype().zero()));
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user