zml.tensor: add triangular operator to zero out the upper‑right matrix region with configurable offset, and toDiagonal (diag_embed) to embed a vector as a diagonal matrix, correcting previous diag naming. Also add ELU activation under zml.nn.Activation.

This commit is contained in:
Tarry Singh 2023-05-18 16:39:21 +00:00
parent 05faa5021e
commit 2f54e2a5f3
4 changed files with 144 additions and 31 deletions

View File

@ -52,6 +52,7 @@ pub const Activation = union(enum) {
tanh,
relu,
leakyReLU: f32,
elu: f32,
silu,
gelu,
quick_gelu,
@ -63,10 +64,18 @@ pub const Activation = union(enum) {
.relu => x.relu(),
.silu => x.silu(),
.gelu => x.gelu(),
.elu => |alpha| elu(x, alpha),
.quick_gelu => x.quickGelu(),
.leakyReLU => |slope| x.leakyReLU(slope),
};
}
/// Exponential Linear Unit.
/// Elements that compare `>= 0` are passed through unchanged;
/// every other element is replaced by `alpha * (exp(x) - 1)`.
pub fn elu(x: Tensor, alpha: f32) Tensor {
    const zero = Tensor.scalar(0, x.dtype());
    const is_non_negative = x.cmp(.GE, zero);
    // Value used wherever the input is below zero.
    const negative_branch = x.exp().addConstant(-1).scale(alpha);
    return is_non_negative.select(x, negative_branch);
}
};
pub fn chainModules(module_list: anytype, input: Tensor) Tensor {
@ -701,7 +710,6 @@ pub fn causalAttnMask(
}
if (dtype.isFloat()) {
meta.guard(dtype.isFloat(), @src()); // -inf only exists for floats
const zeros = Tensor.constant(mask.shape(), dtype.zero());
const minus_inf = Tensor.constant(mask.shape(), dtype.minValue());
mask = Tensor.select(mask, zeros, minus_inf);

View File

@ -523,11 +523,11 @@ pub const Shape = struct {
return res;
}
/// Returns a copy of this shape with one axis dropped from both `_dims` and `_tags`.
/// The axis argument is resolved to an index through `self.axis` before removal; `self` is left untouched.
// NOTE(review): this span lists two signatures and two sets of removal statements that differ
// only in local naming (`d`/`d_` vs `axis_`/`a`) — presumably the before/after lines of a rename; confirm which set is live.
pub fn remove(self: Shape, d: anytype) Shape {
pub fn remove(self: Shape, axis_: anytype) Shape {
var res = self;
const d_ = self.axis(d);
_ = res._dims.orderedRemove(d_);
_ = res._tags.orderedRemove(d_);
const a = self.axis(axis_);
_ = res._dims.orderedRemove(a);
_ = res._tags.orderedRemove(a);
return res;
}

View File

@ -1829,27 +1829,6 @@ pub const Tensor = struct {
return left.mul(right);
}
/// Creates a 2D Tensor with its diagonal set to the input vector.
/// Every off-diagonal element is set to the zero value of the input dtype.
/// Only rank-1 inputs are accepted; the result has dims `{ n, n }` where `n = self.dim(0)`.
pub fn diag(self: Tensor) Tensor {
    meta.assert(self.rank() == 1, "diag only supports tensor of rank 1, got {}", .{self.rank()}); // TODO: support 2D tensors
    const indices = Tensor.arange(.{ .end = self.dim(0) }, .i64);
    const sh = Shape.init(.{ self.dim(0), self.dim(0) }, self.dtype());
    // The same index vector laid out along each of the two output axes;
    // the two layouts are equal exactly on the diagonal.
    const indices_1 = indices.broadcast(sh, &.{1});
    const indices_0 = indices.broadcast(sh, &.{0});
    const loc = self.getContext().mlirCtx().location(@src());
    // TODO: handle more complex cases (2D, specifying `diagonal` index).
    const op = dialect.stablehlo.select(
        self.getContext().mlirCtx(),
        indices_1.cmp(.EQ, indices_0).value(),
        self.broadcast(sh, &.{1}).value(),
        // Both branches of the select must have the result shape `sh`;
        // building the zeros from the rank-1 `self.dims()` gave mismatched operands.
        Tensor.constant(sh, self.dtype().zero()).value(),
        loc,
    );
    return _result(sh, op.result(0));
}
/// Given a tensor and a shape of the same rank,
/// will "broadcast" the given axes, so that `self` has the given shape.
/// This happens by virtually repeating the data several times along each given axis.
@ -1895,7 +1874,7 @@ pub const Tensor = struct {
pub fn broad(self: Tensor, other: Shape) Tensor {
// Non ambiguous broadcasting
if (self._shape.rank() == 0 or self._shape.rank() == other.rank()) {
return self.broadcast(other, Shape.range(self._shape.rank(), self.dtype()).dims());
return self.broadcast(other, Shape.range(self._shape.rank(), .bool).dims());
}
// check that each axis of self maps to an axis of other
@ -3340,6 +3319,126 @@ pub const Tensor = struct {
return _result(self._shape.withDtype(.bool), op.result(0));
}
/// For each vector in the input tensor,
/// creates a diagonal-matrix where diagonal values are set to the vector values.
///
/// * `axis_`: the axis holding the vectors. It is replaced in place by two axes of the same size,
///   so the result has one more axis than the input.
/// * `new_tags`: tags given to the two new axes, in order.
///
/// Elements off the diagonal are set to the zero value of the input dtype.
pub fn toDiagonal(self: Tensor, axis_: anytype, new_tags: [2]EnumLiteral) Tensor {
    // The check rejects any rank above `MAX_RANK - 2`; report that same bound in the message.
    meta.assert(self.rank() < MAX_RANK - 1, "toDiagonal expects input up to {} rank, got {}", .{ MAX_RANK - 2, self });
    const a = self.axis(axis_);
    const d = self.dim(a);

    // Output shape: axis `a` of size `d` becomes two axes of size `d` carrying the new tags.
    var res_shape = self._shape;
    res_shape._dims.replaceRange(a, 1, &.{ d, d }) catch unreachable;
    res_shape._tags.replaceRange(a, 1, &.{ @tagName(new_tags[0]), @tagName(new_tags[1]) }) catch unreachable;

    // Repeat the input along the newly inserted axis so it matches the output shape.
    const values = self.insertAxes(a + 1, .{new_tags[1]}).broad(res_shape);
    const zeros = Tensor.constant(res_shape, self.dtype().zero());

    // Position along each of the two new axes; they are equal exactly on the diagonal.
    const x = Tensor.iota(res_shape, .i32, a);
    const y = Tensor.iota(res_shape, .i32, a + 1);
    var res = x.cmp(.EQ, y).select(values, zeros);
    // Force the computed shape (including the new tags) onto the result.
    res._shape = res_shape;
    return res;
}
test toDiagonal {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();

    const Local = struct {
        /// Expands the last axis into a `.x` by `.y` diagonal matrix and checks both new dims.
        pub fn _toDiag(input: Tensor) Tensor {
            const diagonal = input.toDiagonal(-1, .{ .x, .y });
            std.debug.assert(diagonal.dim(.x) == input.dim(-1));
            std.debug.assert(diagonal.dim(.y) == input.dim(-1));
            return diagonal;
        }
    };

    // Two vectors, each expected to become its own 2x2 diagonal matrix.
    const input = try zml.Buffer.fromArray(platform, [2][2]u8{ .{ 1, 2 }, .{ 3, 4 } });
    const expected = [2][2][2]u8{ .{
        .{ 1, 0 },
        .{ 0, 2 },
    }, .{
        .{ 3, 0 },
        .{ 0, 4 },
    } };

    const output = try zml.testing.compileAndCall(platform, Local._toDiag, .{input});
    try testing.expectEqual(expected, try output.getValue([2][2][2]u8));
}
/// Keeps the lower triangular part of every matrix spanned by the two given axes
/// and sets the remaining elements to 0.
/// Example: `.{ .b = 32, .w = 20, .h = 20 }.triangular(.{ .w, .h}, 0);`
///
/// An element is kept when `index along axes_[0] + num_diagonals >= index along axes_[1]`:
/// * `num_diagonals == 0` leaves the diagonal untouched.
/// * `num_diagonals == -1` also zeroes the diagonal.
/// * `num_diagonals == n` additionally keeps the n "quasi diagonals" above the diagonal.
///
/// Swapping the two axes yields the upper triangular part instead:
/// `.{ .b = 32, .w = 20, .h = 20 }.triangular(.{ .h, .w }, 0);`
pub fn triangular(self: Tensor, axes_: anytype, num_diagonals: i32) Tensor {
    meta.assertComptime(meta.isTuple(@TypeOf(axes_)) and axes_.len == 2, "triangular expects exactly two axes to work on.", .{});

    const axis_pair = self.axes(axes_);
    const full_shape = self.shape();

    // Position of every element along each of the two matrix axes.
    const first_index = Tensor.iota(full_shape, .i32, axis_pair.get(0));
    const second_index = Tensor.iota(full_shape, .i32, axis_pair.get(1));
    const zeros = Tensor.constant(full_shape, self.dtype().zero());

    const keep_mask = first_index.addConstant(num_diagonals).cmp(.GE, second_index);
    return keep_mask.select(self, zeros);
}
test triangular {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();

    const Local = struct {
        /// Applies `triangular` over the last two axes.
        pub fn _tri(input: Tensor, num_diagonals: i32) Tensor {
            return input.triangular(.{ -2, -1 }, num_diagonals);
        }
    };

    const Matrix = [3][3]u8;
    const ones = try zml.Buffer.fromArray(platform, Matrix{
        .{ 1, 1, 1 },
        .{ 1, 1, 1 },
        .{ 1, 1, 1 },
    });

    // Offset 0: the diagonal and everything below it are kept.
    const with_diagonal = try zml.testing.compileAndCall(platform, Local._tri, .{ ones, 0 });
    try testing.expectEqual(
        Matrix{
            .{ 1, 0, 0 },
            .{ 1, 1, 0 },
            .{ 1, 1, 1 },
        },
        try with_diagonal.getValue(Matrix),
    );

    // Offset 1: one extra diagonal above the main one is kept.
    const one_above = try zml.testing.compileAndCall(platform, Local._tri, .{ ones, 1 });
    try testing.expectEqual(
        Matrix{
            .{ 1, 1, 0 },
            .{ 1, 1, 1 },
            .{ 1, 1, 1 },
        },
        try one_above.getValue(Matrix),
    );

    // Offset -1: the main diagonal is zeroed as well.
    const below_only = try zml.testing.compileAndCall(platform, Local._tri, .{ ones, -1 });
    try testing.expectEqual(
        Matrix{
            .{ 0, 0, 0 },
            .{ 1, 0, 0 },
            .{ 1, 1, 0 },
        },
        try below_only.getValue(Matrix),
    );
}
/// For each element at index `i`, if `bool_tensor[i] == true`, `output[i] = on_true[i]`
/// otherwise, if `bool_tensor[i] == false`, `output[i] = on_false[i]`
pub fn select(bool_tensor: Tensor, on_true: Tensor, on_false: Tensor) Tensor {
@ -3508,7 +3607,7 @@ pub const Tensor = struct {
) fn (Tensor, Tensor) Tensor {
return struct {
pub fn binaryOpHelper(self: Tensor, other: Tensor) Tensor {
meta.assert(self.dtype() == other.dtype(), "{s} expects tensor to be of same type, got {} and {}", .{ op_name, self.dtype(), other.dtype() });
meta.assert(self.dtype() == other.dtype(), "{s} expects tensor to be of same type, got {} and {}", .{ op_name, self, other });
if (self.rank() == 0 and other.rank() != 0) {
return binaryOpHelper(self.broad(other._shape), other);

View File

@ -244,18 +244,24 @@ pub fn testLayerOut(
}
var buf: [1024]u8 = undefined;
var failed: bool = false;
for (0..mod.inner.result_buffer_count) |i| {
const full_name = std.fmt.bufPrint(&buf, "{s}.{d}", .{ out_name, i }) catch unreachable;
const expected_out = activations.get(full_name) orelse {
log.warn("Output buffer not found: {s}", .{full_name});
continue;
};
zml.testing.expectClose(expected_out, mod.getOutputBuffer(i), tolerance) catch |err| {
log.err("{s}.{d} doesn't match !", .{ out_name, i });
return err;
zml.testing.expectClose(expected_out, mod.getOutputBuffer(i), tolerance) catch |err| switch (err) {
error.TestUnexpectedResult => {
log.err("{s}.{d} doesn't match !", .{ out_name, i });
failed = true;
continue;
},
else => return err,
};
}
if (failed) return error.TestUnexpectedResult;
log.info("all good for {s} !", .{name});
}