diff --git a/zml/nn.zig b/zml/nn.zig
index 8ba6903..2793bf7 100644
--- a/zml/nn.zig
+++ b/zml/nn.zig
@@ -52,6 +52,7 @@ pub const Activation = union(enum) {
     tanh,
     relu,
     leakyReLU: f32,
+    elu: f32,
     silu,
     gelu,
     quick_gelu,
@@ -63,10 +64,19 @@
             .relu => x.relu(),
             .silu => x.silu(),
             .gelu => x.gelu(),
+            .elu => |alpha| elu(x, alpha),
             .quick_gelu => x.quickGelu(),
             .leakyReLU => |slope| x.leakyReLU(slope),
         };
     }
+
+    /// elu(x) = x when x >= 0, alpha * (exp(x) - 1) otherwise.
+    pub fn elu(x: Tensor, alpha: f32) Tensor {
+        return x.cmp(.GE, Tensor.scalar(0, x.dtype())).select(
+            x,
+            x.exp().addConstant(-1).scale(alpha),
+        );
+    }
 };
 
 pub fn chainModules(module_list: anytype, input: Tensor) Tensor {
@@ -701,7 +711,6 @@ pub fn causalAttnMask(
     }
 
     if (dtype.isFloat()) {
-        meta.guard(dtype.isFloat(), @src()); // -inf only exists for floats
         const zeros = Tensor.constant(mask.shape(), dtype.zero());
         const minus_inf = Tensor.constant(mask.shape(), dtype.minValue());
         mask = Tensor.select(mask, zeros, minus_inf);
diff --git a/zml/shape.zig b/zml/shape.zig
index a61f25e..d4a9abc 100644
--- a/zml/shape.zig
+++ b/zml/shape.zig
@@ -523,11 +523,11 @@ pub const Shape = struct {
         return res;
     }
 
-    pub fn remove(self: Shape, d: anytype) Shape {
+    pub fn remove(self: Shape, axis_: anytype) Shape {
         var res = self;
-        const d_ = self.axis(d);
-        _ = res._dims.orderedRemove(d_);
-        _ = res._tags.orderedRemove(d_);
+        const a = self.axis(axis_);
+        _ = res._dims.orderedRemove(a);
+        _ = res._tags.orderedRemove(a);
         return res;
     }
 
diff --git a/zml/tensor.zig b/zml/tensor.zig
index a69a3c7..0428a86 100644
--- a/zml/tensor.zig
+++ b/zml/tensor.zig
@@ -1829,27 +1829,6 @@ pub const Tensor = struct {
         return left.mul(right);
     }
 
-    /// Creates a 2D Tensor with its diagonal set to the input vector.
-    pub fn diag(self: Tensor) Tensor {
-        meta.assert(self.rank() == 1, "diag only supports tensor of rank 1, got {}", .{self.rank()}); // TODO: support 2D tensors
-        //
-        const indices = Tensor.arange(.{ .end = self.dim(0) }, .i64);
-        const sh = Shape.init(.{ self.dim(0), self.dim(0) }, self.dtype());
-        const indices_1 = indices.broadcast(sh, &.{1});
-        const indices_0 = indices.broadcast(sh, &.{0});
-
-        const loc = self.getContext().mlirCtx().location(@src());
-        // TODO: handle more complex cases (2D, specifying `diagonal` index).
-        const op = dialect.stablehlo.select(
-            self.getContext().mlirCtx(),
-            indices_1.cmp(.EQ, indices_0).value(),
-            self.broadcast(sh, &.{1}).value(),
-            Tensor.constant(self.dims(), self.dtype().zero()).value(),
-            loc,
-        );
-        return _result(sh, op.result(0));
-    }
-
     /// Given a tensor and a shape of the same rank,
     /// will "broadcast" the given axes, so that `self` has the given shape.
     /// This happens by virtually repeating the data several time along each give axes.
@@ -1895,7 +1874,7 @@
     pub fn broad(self: Tensor, other: Shape) Tensor {
         // Non ambiguous broadcasting
         if (self._shape.rank() == 0 or self._shape.rank() == other.rank()) {
-            return self.broadcast(other, Shape.range(self._shape.rank(), self.dtype()).dims());
+            return self.broadcast(other, Shape.range(self._shape.rank(), .bool).dims());
         }
 
         // check that each axis of self maps to an axis of other
@@ -3340,6 +3319,131 @@
         return _result(self._shape.withDtype(.bool), op.result(0));
     }
 
+    /// For each vector in the input tensor,
+    /// creates a diagonal matrix whose diagonal is set to the vector values.
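+    /// e.g. a tensor of shape `.{ .b = 2, .d = 4 }` becomes, through `toDiagonal(.d, .{ .x, .y })`,
+    /// a tensor of shape `.{ .b = 2, .x = 4, .y = 4 }` whose `.x`/`.y` matrices are diagonal.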
+    pub fn toDiagonal(self: Tensor, axis_: anytype, new_tags: [2]EnumLiteral) Tensor {
+        meta.assert(self.rank() <= MAX_RANK - 1, "toDiagonal expects an input of rank at most {}, got {}", .{ MAX_RANK - 1, self });
+        const a = self.axis(axis_);
+        const d = self.dim(a);
+        var res_shape = self._shape;
+        res_shape._dims.replaceRange(a, 1, &.{ d, d }) catch unreachable;
+        res_shape._tags.replaceRange(a, 1, &.{ @tagName(new_tags[0]), @tagName(new_tags[1]) }) catch unreachable;
+
+        const values = self.insertAxes(a + 1, .{new_tags[1]}).broad(res_shape);
+        const zeros = Tensor.constant(res_shape, self.dtype().zero());
+
+        // The diagonal is where the row index (iota along `a`) equals the column index (iota along `a + 1`).
+        const x = Tensor.iota(res_shape, .i32, a);
+        const y = Tensor.iota(res_shape, .i32, a + 1);
+        var res = x.cmp(.EQ, y).select(values, zeros);
+        res._shape = res_shape;
+        return res;
+    }
+
+    test toDiagonal {
+        const zml = @import("zml.zig");
+        const platform = zml.testing.env();
+
+        const Local = struct {
+            pub fn _toDiag(input: Tensor) Tensor {
+                const res = input.toDiagonal(-1, .{ .x, .y });
+                std.debug.assert(res.dim(.x) == input.dim(-1));
+                std.debug.assert(res.dim(.y) == input.dim(-1));
+                return res;
+            }
+        };
+
+        const x = try zml.Buffer.fromArray(platform, [2][2]u8{ .{ 1, 2 }, .{ 3, 4 } });
+        {
+            const res = try zml.testing.compileAndCall(platform, Local._toDiag, .{x});
+            try testing.expectEqual(
+                [2][2][2]u8{ .{
+                    .{ 1, 0 },
+                    .{ 0, 2 },
+                }, .{
+                    .{ 3, 0 },
+                    .{ 0, 4 },
+                } },
+                try res.getValue([2][2][2]u8),
+            );
+        }
+    }
+
+    /// For each matrix specified by the two axes, returns its lower triangular part.
+    /// The other elements are set to 0.
+    /// Usage: `x.triangular(.{ .w, .h }, 0)` for a tensor of shape `.{ .b = 32, .w = 20, .h = 20 }`.
+    ///
+    /// * if `num_diagonals` is set to 0, the main diagonal is kept as is.
+    /// * if set to -1, the main diagonal is set to 0 as well.
+    /// * if set to n, the n diagonals above the main diagonal are also kept.
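+    /// * if set to -n, the main diagonal and the n - 1 diagonals below it are set to 0.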
+    ///
+    /// To get the upper triangular part, swap the order of the axes:
+    /// `x.triangular(.{ .h, .w }, 0)`.
+    pub fn triangular(self: Tensor, axes_: anytype, num_diagonals: i32) Tensor {
+        meta.assertComptime(meta.isTuple(@TypeOf(axes_)) and axes_.len == 2, "triangular expects exactly two axes to work on.", .{});
+        const _axes = self.axes(axes_);
+
+        const x = Tensor.iota(self.shape(), .i32, _axes.get(0));
+        const y = Tensor.iota(self.shape(), .i32, _axes.get(1));
+
+        // The element at (row i, col j) is kept iff i + num_diagonals >= j.
+        const zeros = Tensor.constant(self.shape(), self.dtype().zero());
+        return x.addConstant(num_diagonals).cmp(.GE, y).select(self, zeros);
+    }
+
+    test triangular {
+        const zml = @import("zml.zig");
+        const platform = zml.testing.env();
+
+        const Local = struct {
+            pub fn _tri(input: Tensor, num_diagonals: i32) Tensor {
+                return input.triangular(.{ -2, -1 }, num_diagonals);
+            }
+        };
+
+        const x = try zml.Buffer.fromArray(platform, [3][3]u8{
+            .{ 1, 1, 1 },
+            .{ 1, 1, 1 },
+            .{ 1, 1, 1 },
+        });
+        {
+            const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, 0 });
+            try testing.expectEqual(
+                [3][3]u8{
+                    .{ 1, 0, 0 },
+                    .{ 1, 1, 0 },
+                    .{ 1, 1, 1 },
+                },
+                try res.getValue([3][3]u8),
+            );
+        }
+        {
+            const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, 1 });
+            try testing.expectEqual(
+                [3][3]u8{
+                    .{ 1, 1, 0 },
+                    .{ 1, 1, 1 },
+                    .{ 1, 1, 1 },
+                },
+                try res.getValue([3][3]u8),
+            );
+        }
+        {
+            const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, -1 });
+            try testing.expectEqual(
+                [3][3]u8{
+                    .{ 0, 0, 0 },
+                    .{ 1, 0, 0 },
+                    .{ 1, 1, 0 },
+                },
+                try res.getValue([3][3]u8),
+            );
+        }
+    }
+
     /// For each element at index `i`, if `bool_tensor[i] == true`, `output[i] = on_true[i]`
     /// otherwise, if `bool_tensor[i] == false`, `output[i] = on_false[i]`
     pub fn select(bool_tensor: Tensor, on_true: Tensor, on_false: Tensor) Tensor {
@@ -3508,7 +3612,7 @@
     ) fn (Tensor, Tensor) Tensor {
         return struct {
             pub fn binaryOpHelper(self: Tensor, other: Tensor) Tensor {
-                meta.assert(self.dtype() == other.dtype(), "{s} expects tensor to be of same type, got {} and {}", .{ op_name, self.dtype(), other.dtype() });
+                meta.assert(self.dtype() == other.dtype(), "{s} expects tensors of the same dtype, got {} and {}", .{ op_name, self, other });
 
                 if (self.rank() == 0 and other.rank() != 0) {
                     return binaryOpHelper(self.broad(other._shape), other);
diff --git a/zml/testing.zig b/zml/testing.zig
index 2afecee..53d0f02 100644
--- a/zml/testing.zig
+++ b/zml/testing.zig
@@ -244,18 +244,25 @@ pub fn testLayerOut(
     }
 
     var buf: [1024]u8 = undefined;
+    // Collect mismatches across all output buffers instead of bailing out on the first one.
+    var failed: bool = false;
     for (0..mod.inner.result_buffer_count) |i| {
         const full_name = std.fmt.bufPrint(&buf, "{s}.{d}", .{ out_name, i }) catch unreachable;
         const expected_out = activations.get(full_name) orelse {
             log.warn("Output buffer not found: {s}", .{full_name});
             continue;
         };
-        zml.testing.expectClose(expected_out, mod.getOutputBuffer(i), tolerance) catch |err| {
-            log.err("{s}.{d} doesn't match !", .{ out_name, i });
-            return err;
+        zml.testing.expectClose(expected_out, mod.getOutputBuffer(i), tolerance) catch |err| switch (err) {
+            error.TestUnexpectedResult => {
+                log.err("{s}.{d} doesn't match!", .{ out_name, i });
+                failed = true;
+                continue;
+            },
+            else => return err,
         };
     }
+    if (failed) return error.TestUnexpectedResult;
     log.info("all good for {s} !", .{name});
 }