Replace the silu implementation with `stablehlo.logistic` for higher precision, move the logistic logic into `sigmoid`, and alias `logistic` to `sigmoid` (breaking change).
This commit is contained in:
parent
021111d07d
commit
fefd84b1bb
@ -293,13 +293,6 @@ pub const Tensor = struct {
|
||||
return _result(self._shape.withDtype(.bool), op.result(0));
|
||||
}
|
||||
|
||||
/// Returns a Tensor with the logistic function applied element-wise to the
/// input Tensor, emitted as a single `stablehlo.logistic` op.
pub fn logistic(self: Tensor) Tensor {
    const mlir_ctx = self.getContext().mlirCtx();
    const logistic_op = dialect.stablehlo.logistic(mlir_ctx, self.value(), mlir_ctx.location(@src()));
    // Element-wise op: output keeps the input shape and dtype.
    return _result(self._shape, logistic_op.result(0));
}
|
||||
|
||||
/// Returns a Tensor containing the element-wise number of bits set in the input Tensor.
|
||||
pub fn popcnt(self: Tensor) Tensor {
|
||||
meta.assert(self.dtype().isInteger(), "popcnt expects tensor type to be an integer, got {}", .{self.dtype()});
|
||||
@ -1190,11 +1183,13 @@ pub const Tensor = struct {
|
||||
|
||||
/// Returns a Tensor containing the sigmoid function applied to each element of the input Tensor.
///
/// Lowered directly to `stablehlo.logistic`, which gives higher precision than
/// the former `1 / (1 + exp(-x))` decomposition. As rendered, the old
/// decomposition's unconditional `return` made the stablehlo path below it
/// unreachable; the dead path is removed and the stablehlo lowering kept.
pub fn sigmoid(self: Tensor) Tensor {
    const loc = self.getContext().mlirCtx().location(@src());
    const op = dialect.stablehlo.logistic(self.getContext().mlirCtx(), self.value(), loc);
    // Element-wise op: result keeps the input's shape and dtype.
    return _result(self._shape, op.result(0));
}

/// Alias: the logistic function is the same operation as sigmoid. Kept so
/// existing callers of `logistic` keep working after the implementation moved
/// into `sigmoid`.
pub const logistic = sigmoid;
||||
/// Returns a Tensor containing the ReLU activation function applied to each element of the input Tensor.
|
||||
pub fn relu(self: Tensor) Tensor {
|
||||
return self.maximum(Tensor.constant(self.dims(), self.dtype().zero()));
|
||||
|
||||
Loading…
Reference in New Issue
Block a user