zml.tensor: add triangular operator to zero out the upper‑right matrix region with configurable offset, and toDiagonal (diag_embed) to embed a vector as a diagonal matrix, correcting previous diag naming. Also add ELU activation under zml.nn.Activation.
This commit is contained in:
parent
05faa5021e
commit
2f54e2a5f3
10
zml/nn.zig
10
zml/nn.zig
@ -52,6 +52,7 @@ pub const Activation = union(enum) {
|
||||
tanh,
|
||||
relu,
|
||||
leakyReLU: f32,
|
||||
elu: f32,
|
||||
silu,
|
||||
gelu,
|
||||
quick_gelu,
|
||||
@ -63,10 +64,18 @@ pub const Activation = union(enum) {
|
||||
.relu => x.relu(),
|
||||
.silu => x.silu(),
|
||||
.gelu => x.gelu(),
|
||||
.elu => |alpha| elu(x, alpha),
|
||||
.quick_gelu => x.quickGelu(),
|
||||
.leakyReLU => |slope| x.leakyReLU(slope),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn elu(x: Tensor, alpha: f32) Tensor {
|
||||
return x.cmp(.GE, Tensor.scalar(0, x.dtype())).select(
|
||||
x,
|
||||
x.exp().addConstant(-1).scale(alpha),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
pub fn chainModules(module_list: anytype, input: Tensor) Tensor {
|
||||
@ -701,7 +710,6 @@ pub fn causalAttnMask(
|
||||
}
|
||||
|
||||
if (dtype.isFloat()) {
|
||||
meta.guard(dtype.isFloat(), @src()); // -inf only exists for floats
|
||||
const zeros = Tensor.constant(mask.shape(), dtype.zero());
|
||||
const minus_inf = Tensor.constant(mask.shape(), dtype.minValue());
|
||||
mask = Tensor.select(mask, zeros, minus_inf);
|
||||
|
||||
@ -523,11 +523,11 @@ pub const Shape = struct {
|
||||
return res;
|
||||
}
|
||||
|
||||
pub fn remove(self: Shape, d: anytype) Shape {
|
||||
pub fn remove(self: Shape, axis_: anytype) Shape {
|
||||
var res = self;
|
||||
const d_ = self.axis(d);
|
||||
_ = res._dims.orderedRemove(d_);
|
||||
_ = res._tags.orderedRemove(d_);
|
||||
const a = self.axis(axis_);
|
||||
_ = res._dims.orderedRemove(a);
|
||||
_ = res._tags.orderedRemove(a);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
145
zml/tensor.zig
145
zml/tensor.zig
@ -1829,27 +1829,6 @@ pub const Tensor = struct {
|
||||
return left.mul(right);
|
||||
}
|
||||
|
||||
/// Creates a 2D Tensor with its diagonal set to the input vector.
|
||||
pub fn diag(self: Tensor) Tensor {
|
||||
meta.assert(self.rank() == 1, "diag only supports tensor of rank 1, got {}", .{self.rank()}); // TODO: support 2D tensors
|
||||
//
|
||||
const indices = Tensor.arange(.{ .end = self.dim(0) }, .i64);
|
||||
const sh = Shape.init(.{ self.dim(0), self.dim(0) }, self.dtype());
|
||||
const indices_1 = indices.broadcast(sh, &.{1});
|
||||
const indices_0 = indices.broadcast(sh, &.{0});
|
||||
|
||||
const loc = self.getContext().mlirCtx().location(@src());
|
||||
// TODO: handle more complex cases (2D, specifying `diagonal` index).
|
||||
const op = dialect.stablehlo.select(
|
||||
self.getContext().mlirCtx(),
|
||||
indices_1.cmp(.EQ, indices_0).value(),
|
||||
self.broadcast(sh, &.{1}).value(),
|
||||
Tensor.constant(self.dims(), self.dtype().zero()).value(),
|
||||
loc,
|
||||
);
|
||||
return _result(sh, op.result(0));
|
||||
}
|
||||
|
||||
/// Given a tensor and a shape of the same rank,
|
||||
/// will "broadcast" the given axes, so that `self` has the given shape.
|
||||
/// This happens by virtually repeating the data several time along each give axes.
|
||||
@ -1895,7 +1874,7 @@ pub const Tensor = struct {
|
||||
pub fn broad(self: Tensor, other: Shape) Tensor {
|
||||
// Non ambiguous broadcasting
|
||||
if (self._shape.rank() == 0 or self._shape.rank() == other.rank()) {
|
||||
return self.broadcast(other, Shape.range(self._shape.rank(), self.dtype()).dims());
|
||||
return self.broadcast(other, Shape.range(self._shape.rank(), .bool).dims());
|
||||
}
|
||||
|
||||
// check that each axis of self maps to an axis of other
|
||||
@ -3340,6 +3319,126 @@ pub const Tensor = struct {
|
||||
return _result(self._shape.withDtype(.bool), op.result(0));
|
||||
}
|
||||
|
||||
/// For each vector in the input tensor,
|
||||
/// creates a diagonal-matrix where diagonal values are set to the vector values.
|
||||
pub fn toDiagonal(self: Tensor, axis_: anytype, new_tags: [2]EnumLiteral) Tensor {
|
||||
meta.assert(self.rank() < MAX_RANK - 1, "toDiagonal expects input up to {} rank, got {}", .{ MAX_RANK - 1, self });
|
||||
const a = self.axis(axis_);
|
||||
const d = self.dim(a);
|
||||
var res_shape = self._shape;
|
||||
res_shape._dims.replaceRange(a, 1, &.{ d, d }) catch unreachable;
|
||||
res_shape._tags.replaceRange(a, 1, &.{ @tagName(new_tags[0]), @tagName(new_tags[1]) }) catch unreachable;
|
||||
|
||||
const values = self.insertAxes(a + 1, .{new_tags[1]}).broad(res_shape);
|
||||
const zeros = Tensor.constant(res_shape, self.dtype().zero());
|
||||
|
||||
const x = Tensor.iota(res_shape, .i32, a);
|
||||
const y = Tensor.iota(res_shape, .i32, a + 1);
|
||||
var res = x.cmp(.EQ, y).select(values, zeros);
|
||||
res._shape = res_shape;
|
||||
return res;
|
||||
}
|
||||
|
||||
test toDiagonal {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
const Local = struct {
|
||||
pub fn _toDiag(input: Tensor) Tensor {
|
||||
const res = input.toDiagonal(-1, .{ .x, .y });
|
||||
std.debug.assert(res.dim(.x) == input.dim(-1));
|
||||
std.debug.assert(res.dim(.y) == input.dim(-1));
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
const x = try zml.Buffer.fromArray(platform, [2][2]u8{ .{ 1, 2 }, .{ 3, 4 } });
|
||||
{
|
||||
const res = try zml.testing.compileAndCall(platform, Local._toDiag, .{x});
|
||||
try testing.expectEqual(
|
||||
[2][2][2]u8{ .{
|
||||
.{ 1, 0 },
|
||||
.{ 0, 2 },
|
||||
}, .{
|
||||
.{ 3, 0 },
|
||||
.{ 0, 4 },
|
||||
} },
|
||||
try res.getValue([2][2][2]u8),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// For each matrix specified by the two axes, returns the lower triangular part of it.
|
||||
/// The other elements are set to 0.
|
||||
/// Usage: `.{ .b = 32, .w = 20, .h = 20 }.triangular(.{ .w, .h}, 0);`
|
||||
///
|
||||
/// * if `num_diagonals` is set to 0, the diagonal is not modified.
|
||||
/// * if set to -1, the diagonal is set to 0
|
||||
/// * if set to n, the n "quasi diagonal" above the diagonal are conserved.
|
||||
///
|
||||
/// To get the upper triangular part, swap the order of axes:
|
||||
/// `.{ .b = 32, .w = 20, .h = 20 }.triangular(.{ .h, .w }, 0);`
|
||||
pub fn triangular(self: Tensor, axes_: anytype, num_diagonals: i32) Tensor {
|
||||
meta.assertComptime(meta.isTuple(@TypeOf(axes_)) and axes_.len == 2, "triangular expects exactly two axes to work on.", .{});
|
||||
const _axes = self.axes(axes_);
|
||||
|
||||
const x = Tensor.iota(self.shape(), .i32, _axes.get(0));
|
||||
const y = Tensor.iota(self.shape(), .i32, _axes.get(1));
|
||||
|
||||
const zeros = Tensor.constant(self.shape(), self.dtype().zero());
|
||||
return x.addConstant(num_diagonals).cmp(.GE, y).select(self, zeros);
|
||||
}
|
||||
|
||||
test triangular {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
const Local = struct {
|
||||
pub fn _tri(input: Tensor, num_diagonals: i32) Tensor {
|
||||
return input.triangular(.{ -2, -1 }, num_diagonals);
|
||||
}
|
||||
};
|
||||
|
||||
const x = try zml.Buffer.fromArray(platform, [3][3]u8{
|
||||
.{ 1, 1, 1 },
|
||||
.{ 1, 1, 1 },
|
||||
.{ 1, 1, 1 },
|
||||
});
|
||||
{
|
||||
const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, 0 });
|
||||
try testing.expectEqual(
|
||||
[3][3]u8{
|
||||
.{ 1, 0, 0 },
|
||||
.{ 1, 1, 0 },
|
||||
.{ 1, 1, 1 },
|
||||
},
|
||||
try res.getValue([3][3]u8),
|
||||
);
|
||||
}
|
||||
{
|
||||
const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, 1 });
|
||||
try testing.expectEqual(
|
||||
[3][3]u8{
|
||||
.{ 1, 1, 0 },
|
||||
.{ 1, 1, 1 },
|
||||
.{ 1, 1, 1 },
|
||||
},
|
||||
try res.getValue([3][3]u8),
|
||||
);
|
||||
}
|
||||
{
|
||||
const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, -1 });
|
||||
try testing.expectEqual(
|
||||
[3][3]u8{
|
||||
.{ 0, 0, 0 },
|
||||
.{ 1, 0, 0 },
|
||||
.{ 1, 1, 0 },
|
||||
},
|
||||
try res.getValue([3][3]u8),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// For each element at index `i`, if `bool_tensor[i] == true`, `output[i] = on_true[i]`
|
||||
/// otherwise, if `bool_tensor[i] == false`, `output[i] = on_false[i]`
|
||||
pub fn select(bool_tensor: Tensor, on_true: Tensor, on_false: Tensor) Tensor {
|
||||
@ -3508,7 +3607,7 @@ pub const Tensor = struct {
|
||||
) fn (Tensor, Tensor) Tensor {
|
||||
return struct {
|
||||
pub fn binaryOpHelper(self: Tensor, other: Tensor) Tensor {
|
||||
meta.assert(self.dtype() == other.dtype(), "{s} expects tensor to be of same type, got {} and {}", .{ op_name, self.dtype(), other.dtype() });
|
||||
meta.assert(self.dtype() == other.dtype(), "{s} expects tensor to be of same type, got {} and {}", .{ op_name, self, other });
|
||||
|
||||
if (self.rank() == 0 and other.rank() != 0) {
|
||||
return binaryOpHelper(self.broad(other._shape), other);
|
||||
|
||||
@ -244,18 +244,24 @@ pub fn testLayerOut(
|
||||
}
|
||||
|
||||
var buf: [1024]u8 = undefined;
|
||||
var failed: bool = false;
|
||||
for (0..mod.inner.result_buffer_count) |i| {
|
||||
const full_name = std.fmt.bufPrint(&buf, "{s}.{d}", .{ out_name, i }) catch unreachable;
|
||||
const expected_out = activations.get(full_name) orelse {
|
||||
log.warn("Output buffer not found: {s}", .{full_name});
|
||||
continue;
|
||||
};
|
||||
zml.testing.expectClose(expected_out, mod.getOutputBuffer(i), tolerance) catch |err| {
|
||||
zml.testing.expectClose(expected_out, mod.getOutputBuffer(i), tolerance) catch |err| switch (err) {
|
||||
error.TestUnexpectedResult => {
|
||||
log.err("{s}.{d} doesn't match !", .{ out_name, i });
|
||||
return err;
|
||||
failed = true;
|
||||
continue;
|
||||
},
|
||||
else => return err,
|
||||
};
|
||||
}
|
||||
|
||||
if (failed) return error.TestUnexpectedResult;
|
||||
log.info("all good for {s} !", .{name});
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user