diff --git a/zml/tensor.zig b/zml/tensor.zig index a8d71e8..f1065ae 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -293,13 +293,6 @@ pub const Tensor = struct { return _result(self._shape.withDtype(.bool), op.result(0)); } - /// Returns a Tensor containing the element-wise logistic operation on the input Tensor. - pub fn logistic(self: Tensor) Tensor { - const loc = self.getContext().mlirCtx().location(@src()); - const op = dialect.stablehlo.logistic(self.getContext().mlirCtx(), self.value(), loc); - return _result(self._shape, op.result(0)); - } - /// Returns a Tensor containing the element-wise number of bits set in the input Tensor. pub fn popcnt(self: Tensor) Tensor { meta.assert(self.dtype().isInteger(), "popcnt expects tensor type to be an integer, got {}", .{self.dtype()}); @@ -1190,11 +1183,13 @@ pub const Tensor = struct { /// Returns a Tensor containing the sigmoid function applied to each element of the input Tensor. pub fn sigmoid(self: Tensor) Tensor { - // until the metal plugin supports `stablehlo.logistics`, implement in the way JAX does it - const one = Tensor.constant(&.{}, self.dtype().one()).broadcast(self._shape, &.{}); - return one.div(one.add(self.negate().exp())); + const loc = self.getContext().mlirCtx().location(@src()); + const op = dialect.stablehlo.logistic(self.getContext().mlirCtx(), self.value(), loc); + return _result(self._shape, op.result(0)); } + pub const logistic = sigmoid; + /// Returns a Tensor containing the ReLU activation function applied to each element of the input Tensor. pub fn relu(self: Tensor) Tensor { return self.maximum(Tensor.constant(self.dims(), self.dtype().zero()));