zml: Relocate tests next to the functions they verify and remove obsolete dynamicSlice1d test.
This commit is contained in:
parent
dfa71018a5
commit
7ef67eea27
413
zml/tensor.zig
413
zml/tensor.zig
@ -308,6 +308,22 @@ pub const Tensor = struct {
|
||||
return self.remainder(Tensor.scalar(divisor, .f32).broadcast(self._shape, &.{}));
|
||||
}
|
||||
|
||||
test fmod {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
const inputs: [2][6]f32 = .{ .{ -3.0, -2, -1, 1, 2, 3 }, .{ 1, 2, 3, 4, 5, -5 } };
|
||||
const expectations: [2][6]f32 = .{ .{ -1.0, -0.0, -1.0, 1.0, 0.0, 1.0 }, .{ 1.0000, 0.5000, 0.0000, 1.0000, 0.5000, -0.5000 } };
|
||||
const divisors: [2]f32 = .{ 2, -1.5 };
|
||||
|
||||
inline for (inputs, expectations, divisors) |i, e, d| {
|
||||
const input = try zml.Buffer.fromSlice(platform, .{6}, &i);
|
||||
const output = try zml.testing.compileAndCall(platform, Tensor.fmod, .{ input, d });
|
||||
|
||||
try zml.testing.expectClose(zml.HostBuffer.fromSlice(.{6}, &e), output, 1e-4);
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing the element-wise left-shift operation of 'self' by 'other'.
|
||||
pub fn shiftLeft(self: Tensor, other: Tensor) Tensor {
|
||||
return binaryOp("shiftLeft", dialect.stablehlo.shift_left)(self, other);
|
||||
@ -1426,6 +1442,33 @@ pub const Tensor = struct {
|
||||
return _result(res_shape, slice_op.result(0));
|
||||
}
|
||||
|
||||
test slice {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
const x = try zml.Buffer.fromSlice(platform, .{ 2, 5 }, &[_]f32{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
|
||||
|
||||
// Wrap slice1d to hide the anytype in the signature.
|
||||
const Local = struct {
|
||||
pub fn slice1dAxis(input: Tensor, ax: i8, slice_: Tensor.Slice) Tensor {
|
||||
return input.slice1d(ax, slice_);
|
||||
}
|
||||
};
|
||||
|
||||
{
|
||||
const res = try zml.testing.compileAndCallWithTensors(platform, Local.slice1dAxis, .{ x.shape(), 0, .{ .end = 1 } }, .{ x, 0, .{ .end = 1 } });
|
||||
try testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32));
|
||||
}
|
||||
{
|
||||
const res = try zml.testing.compileAndCallWithTensors(platform, Local.slice1dAxis, .{ x.shape(), 1, .{ .start = 1, .step = 2 } }, .{ x, 0, .{ .start = 1, .step = 2 } });
|
||||
try testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32));
|
||||
}
|
||||
{
|
||||
const res = try zml.testing.compileAndCallWithTensors(platform, Local.slice1dAxis, .{ x.shape(), -1, .{ .start = -2 } }, .{ x, 0, .{ .start = -2 } });
|
||||
try testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32));
|
||||
}
|
||||
}
|
||||
|
||||
inline fn wrapIndex(self: Tensor, axis_: usize, idx: i64) i64 {
|
||||
return if (idx < 0) self.dim(axis_) + idx else idx;
|
||||
}
|
||||
@ -2566,6 +2609,40 @@ pub const Tensor = struct {
|
||||
);
|
||||
}
|
||||
|
||||
test argMax {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
const allocator = std.testing.allocator;
|
||||
const ArgMaxTest = struct {
|
||||
pub fn forward(x: Tensor) Tensor.ArgMaxRes {
|
||||
return x.argMax(1, .i32);
|
||||
}
|
||||
};
|
||||
|
||||
const argmax = try zml.compileFn(allocator, ArgMaxTest.forward, .{Shape.init(.{ 1, 5 }, .f32)}, platform);
|
||||
defer argmax.deinit();
|
||||
// Test with tie
|
||||
{
|
||||
const x = try zml.Buffer.fromArray(platform, [1][5]f32{.{ 5.0, 4.1, 7.9, 0, 7.9 }});
|
||||
const res = argmax.call(.{x});
|
||||
const max_ = res.values.getValue(f32);
|
||||
const max_idx = res.indices.getValue(i32);
|
||||
try testing.expectEqual(max_, 7.9);
|
||||
// We should always return the first max found.
|
||||
try testing.expectEqual(max_idx, 2);
|
||||
}
|
||||
|
||||
// Test with Nan
|
||||
{
|
||||
const x = try zml.Buffer.fromArray(platform, [1][5]f32{.{ 5.0, std.math.nan(f32), 7.9, 0, 7.9 }});
|
||||
const res = argmax.call(.{x});
|
||||
const max_ = try res.values.getValue(f32);
|
||||
const max_idx = try res.indices.getValue(i32);
|
||||
try testing.expect(std.math.isNan(max_));
|
||||
try testing.expectEqual(max_idx, 1);
|
||||
}
|
||||
}
|
||||
|
||||
pub const SortRes = ArgMaxRes;
|
||||
|
||||
/// Returns two Tensors. The first contains the sorted values and the second one contains the sorted indices.
|
||||
@ -2591,6 +2668,66 @@ pub const Tensor = struct {
|
||||
return self.sort(axis_, .{ .descending = opts.descending }).indices;
|
||||
}
|
||||
|
||||
test argsort {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator);
|
||||
defer arena_state.deinit();
|
||||
const allocator = arena_state.allocator();
|
||||
// 2D Tensor - dim = 1, ascending
|
||||
{
|
||||
const x = try zml.Buffer.fromSlice(platform, .{ 2, 5 }, &[_]f32{ -0.9264, 0.7156, 1.0202, 0.3992, 1.2349, 1.0003, -0.1932, 1.3935, 0.7316, 0.0851 });
|
||||
const res = try zml.testing.compileAndCall(platform, Tensor.argsort, .{ x, 1, .{} });
|
||||
const res_cpu = try res.toHostAlloc(allocator);
|
||||
try testing.expectEqualSlices(i32, &.{ 0, 3, 1, 2, 4, 1, 4, 3, 0, 2 }, res_cpu.items(i32));
|
||||
}
|
||||
// 3D Tensor, dim = 1, descending
|
||||
{
|
||||
const x = try zml.Buffer.fromSlice(platform, .{ 1, 5, 10 }, &[_]f16{
|
||||
-0.2505, 1.2520, -0.7041, 0.1066, 1.2773, -1.7246, 0.8389, 1.1094, 0.0601, 1.0684,
|
||||
0.9619, 1.3916, 1.2246, -0.1406, 0.3674, -1.2480, -1.7051, -0.0934, 0.3435, 0.4373,
|
||||
1.3809, 0.5444, -0.6079, 1.2031, -0.6880, 1.2979, -0.1869, 0.2991, 0.0156, 0.1847,
|
||||
0.6626, -0.3040, -0.8726, -1.4805, -1.6943, 1.1055, -2.0078, -0.5288, 0.8813, 0.8008,
|
||||
2.0527, 1.1230, 0.5430, 0.2494, -0.9434, 0.7876, 0.1818, 0.9258, -2.4902, 1.5918,
|
||||
});
|
||||
const res_dev = try zml.testing.compileAndCall(platform, Tensor.argsort, .{ x, 1, .{ .descending = true } });
|
||||
const res = try res_dev.toHostAlloc(allocator);
|
||||
try testing.expectEqualSlices(i32, &.{
|
||||
4, 1, 1, 2, 0, 2, 0, 0, 3, 4,
|
||||
2, 0, 4, 4, 1, 3, 4, 4, 1, 0,
|
||||
1, 4, 2, 0, 2, 4, 2, 2, 0, 3,
|
||||
3, 2, 0, 1, 4, 1, 1, 1, 2, 1,
|
||||
0, 3, 3, 3, 3, 0, 3, 3, 4, 2,
|
||||
}, res.items(i32));
|
||||
}
|
||||
// 4D Tensor, dim = 3, ascending
|
||||
{
|
||||
const x = try zml.Buffer.fromSlice(platform, .{ 4, 2, 1, 4 }, &[_]i32{
|
||||
89, 31, 22, 42,
|
||||
64, 39, 0, 30,
|
||||
64, 71, 46, 31,
|
||||
89, 82, 78, 86,
|
||||
55, 32, 43, 19,
|
||||
93, 24, 45, 72,
|
||||
64, 86, 62, 88,
|
||||
57, 21, 19, 12,
|
||||
});
|
||||
const res_dev = try zml.testing.compileAndCallWithTensors(platform, Tensor.argsort, .{ x.shape(), 3, .{} }, .{ x, 0, .{} });
|
||||
const res = try res_dev.toHostAlloc(allocator);
|
||||
try testing.expectEqualSlices(i32, &.{
|
||||
2, 1, 3, 0,
|
||||
2, 3, 1, 0,
|
||||
3, 2, 0, 1,
|
||||
2, 1, 3, 0,
|
||||
3, 1, 2, 0,
|
||||
1, 2, 3, 0,
|
||||
2, 0, 1, 3,
|
||||
3, 2, 1, 0,
|
||||
}, res.items(i32));
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a Tensor representing the result of Top-K over the given axis.
|
||||
pub fn topK(self: Tensor, k: u32, axis_: anytype, opts: struct { descending: bool = true }) SortRes {
|
||||
const a = self.axis(axis_);
|
||||
@ -2892,6 +3029,29 @@ pub const Tensor = struct {
|
||||
return _result(res_shape, op.result(0));
|
||||
}
|
||||
|
||||
test dynamicSlice {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
const T = f32;
|
||||
|
||||
{
|
||||
const x = try zml.Buffer.fromArray(platform, [10]T{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
|
||||
const z = try zml.Buffer.scalar(platform, 4, .i32);
|
||||
const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 0, 2, z });
|
||||
|
||||
try testing.expectEqual([2]T{ 4, 5 }, try res.getValue([2]T));
|
||||
}
|
||||
|
||||
{
|
||||
// Strided
|
||||
const x = try zml.Buffer.fromArray(platform, [2][5]T{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, 8, 9 } });
|
||||
const z = try zml.Buffer.scalar(platform, 3, .i32);
|
||||
|
||||
const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 1, 2, z });
|
||||
try testing.expectEqual([4]T{ 3, 4, 8, 9 }, res.getValue([4]T));
|
||||
}
|
||||
}
|
||||
|
||||
/// Updates a slice of the input Tensor along a specific axis using the given 'update' Tensor, with a start offset known at runtime.
|
||||
pub fn dynamicUpdateSlice1d(self: Tensor, update: Tensor, axis_: i64, offset: Tensor) Tensor {
|
||||
const placeholder = Tensor.scalar(0, .i32);
|
||||
@ -3413,235 +3573,9 @@ test "Tensor.maxPool2d" {
|
||||
);
|
||||
}
|
||||
|
||||
fn disabledTestTensorMaxPool3d() void {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
const MaxPool = struct {
|
||||
pub fn forward(x: Tensor) Tensor.ArgMaxRes {
|
||||
return x.maxPool3d(.{
|
||||
.window_dimensions = &.{ 3, 2, 2 },
|
||||
.window_strides = &.{ 2, 1, 2 },
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
var data: [100]f32 = undefined;
|
||||
for (&data, 0..) |*v, i| v.* = @floatFromInt(i);
|
||||
const x = try zml.Buffer.fromSlice(.{ 1, 2, 5, 5, 2 }, &data, platform);
|
||||
|
||||
const result = try zml.testing.compileAndCall(platform, MaxPool.forward, .{x});
|
||||
try std.testing.expectEqualSlices(i64, &.{ 1, 2, 2, 4, 1 }, result.values.dims());
|
||||
try std.testing.expectEqualSlices(i64, &.{ 1, 2, 2, 4, 1 }, result.indices.dims());
|
||||
var buffer: [1][2][2][4][1]f32 = undefined;
|
||||
_ = result.values.toHost(std.mem.asBytes(&buffer));
|
||||
try std.testing.expectEqualDeep(
|
||||
[1][2][2][4][1]f32{
|
||||
.{
|
||||
.{
|
||||
.{ .{23}, .{25}, .{27}, .{29} },
|
||||
.{ .{43}, .{45}, .{47}, .{49} },
|
||||
},
|
||||
.{
|
||||
.{ .{73}, .{75}, .{77}, .{79} },
|
||||
.{ .{93}, .{95}, .{97}, .{99} },
|
||||
},
|
||||
},
|
||||
},
|
||||
buffer,
|
||||
);
|
||||
}
|
||||
|
||||
test "argMax" {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
const allocator = std.testing.allocator;
|
||||
const ArgMaxTest = struct {
|
||||
pub fn forward(x: Tensor) Tensor.ArgMaxRes {
|
||||
return x.argMax(1, .i32);
|
||||
}
|
||||
};
|
||||
|
||||
const argmax = try zml.compileFn(allocator, ArgMaxTest.forward, .{Shape.init(.{ 1, 5 }, .f32)}, platform);
|
||||
defer argmax.deinit();
|
||||
// Test with tie
|
||||
{
|
||||
const x = try zml.Buffer.fromArray(platform, [1][5]f32{.{ 5.0, 4.1, 7.9, 0, 7.9 }});
|
||||
const res = argmax.call(.{x});
|
||||
const max = res.values.getValue(f32);
|
||||
const max_idx = res.indices.getValue(i32);
|
||||
try testing.expectEqual(max, 7.9);
|
||||
// We should always return the first max found.
|
||||
try testing.expectEqual(max_idx, 2);
|
||||
}
|
||||
|
||||
// Test with Nan
|
||||
{
|
||||
const x = try zml.Buffer.fromArray(platform, [1][5]f32{.{ 5.0, std.math.nan(f32), 7.9, 0, 7.9 }});
|
||||
const res = argmax.call(.{x});
|
||||
const max = try res.values.getValue(f32);
|
||||
const max_idx = try res.indices.getValue(i32);
|
||||
try testing.expect(std.math.isNan(max));
|
||||
try testing.expectEqual(max_idx, 1);
|
||||
}
|
||||
}
|
||||
|
||||
test "dynamicUpdateSlice1d" {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator);
|
||||
defer arena_state.deinit();
|
||||
const allocator = arena_state.allocator();
|
||||
const T = f32;
|
||||
|
||||
{
|
||||
const x = try zml.Buffer.fromSlice(platform, .{10}, &[_]T{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
|
||||
const y = try zml.Buffer.fromSlice(platform, .{2}, &[_]T{ -1, -1 });
|
||||
const z = try zml.Buffer.fromSlice(platform, .{}, &[_]i32{4});
|
||||
const res = try zml.testing.compileAndCall(platform, Tensor.dynamicUpdateSlice1d, .{ x, y, 0, z });
|
||||
try testing.expectEqualSlices(T, &.{ 0, 1, 2, 3, -1, -1, 6, 7, 8, 9 }, &try res.getValue([10]T));
|
||||
}
|
||||
|
||||
{
|
||||
// Partial update: 3 out of 5 elements on the second row
|
||||
// Note: this seems error prone, but stablehlo allows it. Should we be more restrictive ?
|
||||
const x = try zml.Buffer.fromSlice(platform, .{ 2, 5 }, &[_]T{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
|
||||
const y = try zml.Buffer.fromSlice(platform, .{ 1, 3 }, &[_]T{ 0, 0, 0 });
|
||||
const z = try zml.Buffer.fromSlice(platform, .{}, &[_]i32{1});
|
||||
const res = try zml.testing.compileAndCall(platform, Tensor.dynamicUpdateSlice1d, .{ x, y, 0, z });
|
||||
|
||||
try testing.expectEqualSlices(T, &.{ 0, 1, 2, 3, 4, 0, 0, 0, 8, 9 }, &try res.getValue([10]T));
|
||||
}
|
||||
|
||||
{
|
||||
// Strided
|
||||
const x = try zml.Buffer.fromSlice(platform, .{ 2, 5 }, &[_]T{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
|
||||
const y = try zml.Buffer.fromSlice(platform, .{ 2, 1 }, &[_]T{ 0, 0 });
|
||||
const z = try zml.Buffer.fromSlice(platform, .{}, &[_]i32{3});
|
||||
const res_dev = try zml.testing.compileAndCall(platform, Tensor.dynamicUpdateSlice1d, .{ x, y, 1, z });
|
||||
const res = try res_dev.toHostAlloc(allocator);
|
||||
|
||||
try testing.expectEqualSlices(T, &.{ 0, 1, 2, 0, 4, 5, 6, 7, 0, 9 }, res.items(T));
|
||||
}
|
||||
}
|
||||
|
||||
test "slice" {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
const x = try zml.Buffer.fromSlice(platform, .{ 2, 5 }, &[_]f32{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
|
||||
|
||||
// Wrap slice1d to hide the anytype in the signature.
|
||||
const Local = struct {
|
||||
pub fn slice1dAxis(input: Tensor, ax: i8, slice: Tensor.Slice) Tensor {
|
||||
return input.slice1d(ax, slice);
|
||||
}
|
||||
};
|
||||
|
||||
{
|
||||
const res = try zml.testing.compileAndCallWithTensors(platform, Local.slice1dAxis, .{ x.shape(), 0, .{ .end = 1 } }, .{ x, 0, .{ .end = 1 } });
|
||||
try testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32));
|
||||
}
|
||||
{
|
||||
const res = try zml.testing.compileAndCallWithTensors(platform, Local.slice1dAxis, .{ x.shape(), 1, .{ .start = 1, .step = 2 } }, .{ x, 0, .{ .start = 1, .step = 2 } });
|
||||
try testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32));
|
||||
}
|
||||
{
|
||||
const res = try zml.testing.compileAndCallWithTensors(platform, Local.slice1dAxis, .{ x.shape(), -1, .{ .start = -2 } }, .{ x, 0, .{ .start = -2 } });
|
||||
try testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32));
|
||||
}
|
||||
}
|
||||
|
||||
test "Tensor.fmod" {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
const inputs: [2][6]f32 = .{ .{ -3.0, -2, -1, 1, 2, 3 }, .{ 1, 2, 3, 4, 5, -5 } };
|
||||
const expectations: [2][6]f32 = .{ .{ -1.0, -0.0, -1.0, 1.0, 0.0, 1.0 }, .{ 1.0000, 0.5000, 0.0000, 1.0000, 0.5000, -0.5000 } };
|
||||
const divisors: [2]f32 = .{ 2, -1.5 };
|
||||
|
||||
inline for (inputs, expectations, divisors) |i, e, d| {
|
||||
const input = try zml.Buffer.fromSlice(platform, .{6}, &i);
|
||||
const output = try zml.testing.compileAndCall(platform, Tensor.fmod, .{ input, d });
|
||||
|
||||
try zml.testing.expectClose(zml.HostBuffer.fromSlice(.{6}, &e), output, 1e-4);
|
||||
}
|
||||
}
|
||||
|
||||
test "Tensor.argsort" {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator);
|
||||
defer arena_state.deinit();
|
||||
const allocator = arena_state.allocator();
|
||||
// 2D Tensor - dim = 1, ascending
|
||||
{
|
||||
const x = try zml.Buffer.fromSlice(platform, .{ 2, 5 }, &[_]f32{ -0.9264, 0.7156, 1.0202, 0.3992, 1.2349, 1.0003, -0.1932, 1.3935, 0.7316, 0.0851 });
|
||||
const res = try zml.testing.compileAndCall(platform, Tensor.argsort, .{ x, 1, .{} });
|
||||
const res_cpu = try res.toHostAlloc(allocator);
|
||||
try testing.expectEqualSlices(i32, &.{ 0, 3, 1, 2, 4, 1, 4, 3, 0, 2 }, res_cpu.items(i32));
|
||||
}
|
||||
// 3D Tensor, dim = 1, descending
|
||||
{
|
||||
const x = try zml.Buffer.fromSlice(platform, .{ 1, 5, 10 }, &[_]f16{
|
||||
-0.2505, 1.2520, -0.7041, 0.1066, 1.2773, -1.7246, 0.8389, 1.1094, 0.0601, 1.0684,
|
||||
0.9619, 1.3916, 1.2246, -0.1406, 0.3674, -1.2480, -1.7051, -0.0934, 0.3435, 0.4373,
|
||||
1.3809, 0.5444, -0.6079, 1.2031, -0.6880, 1.2979, -0.1869, 0.2991, 0.0156, 0.1847,
|
||||
0.6626, -0.3040, -0.8726, -1.4805, -1.6943, 1.1055, -2.0078, -0.5288, 0.8813, 0.8008,
|
||||
2.0527, 1.1230, 0.5430, 0.2494, -0.9434, 0.7876, 0.1818, 0.9258, -2.4902, 1.5918,
|
||||
});
|
||||
const res_dev = try zml.testing.compileAndCall(platform, Tensor.argsort, .{ x, 1, .{ .descending = true } });
|
||||
const res = try res_dev.toHostAlloc(allocator);
|
||||
try testing.expectEqualSlices(i32, &.{
|
||||
4, 1, 1, 2, 0, 2, 0, 0, 3, 4,
|
||||
2, 0, 4, 4, 1, 3, 4, 4, 1, 0,
|
||||
1, 4, 2, 0, 2, 4, 2, 2, 0, 3,
|
||||
3, 2, 0, 1, 4, 1, 1, 1, 2, 1,
|
||||
0, 3, 3, 3, 3, 0, 3, 3, 4, 2,
|
||||
}, res.items(i32));
|
||||
}
|
||||
// 4D Tensor, dim = 3, ascending
|
||||
{
|
||||
const x = try zml.Buffer.fromSlice(platform, .{ 4, 2, 1, 4 }, &[_]i32{
|
||||
89, 31, 22, 42,
|
||||
64, 39, 0, 30,
|
||||
64, 71, 46, 31,
|
||||
89, 82, 78, 86,
|
||||
55, 32, 43, 19,
|
||||
93, 24, 45, 72,
|
||||
64, 86, 62, 88,
|
||||
57, 21, 19, 12,
|
||||
});
|
||||
const res_dev = try zml.testing.compileAndCallWithTensors(platform, Tensor.argsort, .{ x.shape(), 3, .{} }, .{ x, 0, .{} });
|
||||
const res = try res_dev.toHostAlloc(allocator);
|
||||
try testing.expectEqualSlices(i32, &.{
|
||||
2, 1, 3, 0,
|
||||
2, 3, 1, 0,
|
||||
3, 2, 0, 1,
|
||||
2, 1, 3, 0,
|
||||
3, 1, 2, 0,
|
||||
1, 2, 3, 0,
|
||||
2, 0, 1, 3,
|
||||
3, 2, 1, 0,
|
||||
}, res.items(i32));
|
||||
}
|
||||
}
|
||||
|
||||
fn parseArrayInfo(T: type) Shape {
|
||||
return switch (@typeInfo(T)) {
|
||||
.Array => |arr| {
|
||||
const s = parseArrayInfo(arr.child);
|
||||
return s.insert(0, .{arr.len});
|
||||
},
|
||||
else => .{ ._dtype = DataType.fromZigType(T) },
|
||||
};
|
||||
}
|
||||
|
||||
pub inline fn toI64(values: anytype) []i64 {
|
||||
var res: [Tensor.MAX_RANK]i64 = undefined;
|
||||
for (values, 0..) |val, i| res[i] = @intCast(val);
|
||||
return res[0..values.len];
|
||||
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
|
||||
pub fn Bufferized(comptime T: type) type {
|
||||
return meta.MapType(Tensor, Buffer).map(T);
|
||||
}
|
||||
|
||||
/// Return a clone of a type with Tensors replaced by Shapes.
|
||||
@ -3748,11 +3682,6 @@ fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK), va
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
|
||||
pub fn Bufferized(comptime T: type) type {
|
||||
return meta.MapType(Tensor, Buffer).map(T);
|
||||
}
|
||||
|
||||
fn _parseGatherCoord(self: Tensor, axes_: anytype) struct { bool, std.BoundedArray(u3, Tensor.MAX_RANK) } {
|
||||
const AxesT = @TypeOf(axes_);
|
||||
const axes_is_scalar = AxesT == EnumLiteral or AxesT == comptime_int or @typeInfo(AxesT) == .Int;
|
||||
@ -3764,3 +3693,19 @@ fn _parseGatherCoord(self: Tensor, axes_: anytype) struct { bool, std.BoundedArr
|
||||
|
||||
return .{ axes_is_scalar, coord_axes };
|
||||
}
|
||||
|
||||
fn parseArrayInfo(T: type) Shape {
|
||||
return switch (@typeInfo(T)) {
|
||||
.Array => |arr| {
|
||||
const s = parseArrayInfo(arr.child);
|
||||
return s.insert(0, .{arr.len});
|
||||
},
|
||||
else => .{ ._dtype = DataType.fromZigType(T) },
|
||||
};
|
||||
}
|
||||
|
||||
inline fn toI64(values: anytype) []i64 {
|
||||
var res: [Tensor.MAX_RANK]i64 = undefined;
|
||||
for (values, 0..) |val, i| res[i] = @intCast(val);
|
||||
return res[0..values.len];
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user