zml: deprecate buggy Tensor.chunk; introduce chunkExact and chunkAllowTrailing with clarified behavior
This commit is contained in:
parent
7e131a106b
commit
058e1415fa
@ -6,7 +6,7 @@ module(
|
|||||||
|
|
||||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||||
bazel_dep(name = "llvm-raw", version = "20240823.0-f142f8a")
|
bazel_dep(name = "llvm-raw", version = "20240823.0-94c024a")
|
||||||
|
|
||||||
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
||||||
llvm.configure(
|
llvm.configure(
|
||||||
|
|||||||
167
zml/tensor.zig
167
zml/tensor.zig
@ -1410,9 +1410,18 @@ pub const Tensor = struct {
|
|||||||
res_shape = res_shape.setDim(a, std.math.divCeil(i64, args.end.? - args.start, args.step) catch unreachable);
|
res_shape = res_shape.setDim(a, std.math.divCeil(i64, args.end.? - args.start, args.step) catch unreachable);
|
||||||
}
|
}
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "slices={any}", .{slices});
|
const mlir_ctx = self.getContext().mlirCtx();
|
||||||
const result_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type).?;
|
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "slices={any}", .{slices});
|
||||||
const slice_op = dialect.stablehlo.slice(self.getContext().mlirCtx(), self.value(), start_indices[0..self.rank()], limit_indices[0..self.rank()], strides[0..self.rank()], result_type, loc);
|
const result_type = mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type).?;
|
||||||
|
const slice_op = dialect.stablehlo.slice(
|
||||||
|
mlir_ctx,
|
||||||
|
self.value(),
|
||||||
|
start_indices[0..self.rank()],
|
||||||
|
limit_indices[0..self.rank()],
|
||||||
|
strides[0..self.rank()],
|
||||||
|
result_type,
|
||||||
|
loc,
|
||||||
|
);
|
||||||
|
|
||||||
return _result(res_shape, slice_op.result(0));
|
return _result(res_shape, slice_op.result(0));
|
||||||
}
|
}
|
||||||
@ -2437,42 +2446,103 @@ pub const Tensor = struct {
|
|||||||
return self._shape.axes(axes_);
|
return self._shape.axes(axes_);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn chunkAlloc(self: Tensor, allocator: std.mem.Allocator, chunks: i64, axis_: i64) ![]Tensor {
|
/// Chunk a given tensor into exactly `n_chunks` parts of equal shape along `axis_`.
|
||||||
|
/// `self.dim(axis_)` must be divisible by `n_chunks`.
|
||||||
|
pub fn chunkExact(self: Tensor, axis_: anytype, n_chunks: comptime_int) [n_chunks]Tensor {
|
||||||
const a = self.axis(axis_);
|
const a = self.axis(axis_);
|
||||||
const length = self.dim(a);
|
const d = self.dim(a);
|
||||||
const chunk_size: i64 = @divFloor(length, chunks);
|
const chunk_size = @divExact(d, n_chunks);
|
||||||
const full_chunks: i64 = @divFloor(length, chunk_size);
|
var chunks: [n_chunks]Tensor = undefined;
|
||||||
const tail_chunk_size: i64 = @rem(length, chunk_size);
|
for (0..n_chunks) |i| {
|
||||||
|
|
||||||
const n_chunks = if (tail_chunk_size > 0) full_chunks + 1 else full_chunks;
|
|
||||||
const result = try allocator.alloc(Tensor, @intCast(n_chunks));
|
|
||||||
self.chunk(result, axis_);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn chunkExact(self: Tensor, n_chunks: comptime_int, axis_: anytype) [n_chunks]Tensor {
|
|
||||||
const a = self.axis(axis_);
|
|
||||||
const length = self.dim(a);
|
|
||||||
_ = @divExact(length, n_chunks);
|
|
||||||
var res: [n_chunks]Tensor = undefined;
|
|
||||||
self.chunk(&res, a);
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn chunk(self: Tensor, chunks: []Tensor, axis_: i64) void {
|
|
||||||
const a = self.axis(axis_);
|
|
||||||
const length = self.dim(a);
|
|
||||||
const n_chunks: i64 = @intCast(chunks.len);
|
|
||||||
const chunk_size: i64 = @divFloor(length, n_chunks);
|
|
||||||
const full_chunks: usize = @intCast(@divFloor(length, chunk_size));
|
|
||||||
const tail_chunk_size: i64 = @mod(length, chunk_size);
|
|
||||||
for (0..full_chunks) |i| {
|
|
||||||
const start: i64 = @as(i64, @intCast(i)) * chunk_size;
|
const start: i64 = @as(i64, @intCast(i)) * chunk_size;
|
||||||
chunks[i] = self.slice1d(a, .{ .start = start, .end = start + chunk_size });
|
chunks[i] = self.slice1d(a, .{ .start = start, .end = start + chunk_size });
|
||||||
}
|
}
|
||||||
|
return chunks;
|
||||||
|
}
|
||||||
|
|
||||||
|
test chunkExact {
|
||||||
|
const zml = @import("zml.zig");
|
||||||
|
const platform = zml.testing.env();
|
||||||
|
|
||||||
|
// Only the shapes of the resulting chunks are checked, not their values.
|
||||||
|
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
||||||
|
defer comp.deinit();
|
||||||
|
comp.activate();
|
||||||
|
defer comp.deactivate();
|
||||||
|
|
||||||
|
inline for (.{
|
||||||
|
.{ .{ .a = 12 }, .a, 3, .{ .a = 4 } },
|
||||||
|
.{ .{ .a = 12, .b = 2 }, .a, 3, .{ .a = 4, .b = 2 } },
|
||||||
|
.{ .{ 12, 2 }, 0, 3, .{ 4, 2 } },
|
||||||
|
}) |testcase| {
|
||||||
|
const x_shape, const ax, const n_chunks, const res = testcase;
|
||||||
|
const x = Tensor.constant(x_shape, .{ .f16 = 0 });
|
||||||
|
const chunks = x.chunkExact(ax, n_chunks);
|
||||||
|
|
||||||
|
const res_shape = Shape.init(res, .f16);
|
||||||
|
for (&chunks) |chk| {
|
||||||
|
try zml.testing.expectEqualShapes(res_shape, chk.shape());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Chunk a given tensor into `n_chunks` parts of equal shape, plus one trailing part holding the remaining items, if any.
|
||||||
|
/// When `self.dim(axis_)` is divisible by `n_chunks`, exactly `n_chunks` parts are returned and there is no trailing part.
|
||||||
|
pub fn chunkAllowTrailing(
|
||||||
|
self: Tensor,
|
||||||
|
axis_: i64,
|
||||||
|
n_chunks: comptime_int,
|
||||||
|
) std.BoundedArray(Tensor, n_chunks + 1) {
|
||||||
|
const a = self.axis(axis_);
|
||||||
|
const d = self.dim(a);
|
||||||
|
const chunk_size: i64 = @divFloor(d, n_chunks);
|
||||||
|
const tail_chunk_size: i64 = @rem(d, chunk_size);
|
||||||
|
|
||||||
|
var chunks: std.BoundedArray(Tensor, n_chunks + 1) = .{};
|
||||||
|
for (0..n_chunks) |i| {
|
||||||
|
const start: i64 = @as(i64, @intCast(i)) * chunk_size;
|
||||||
|
chunks.appendAssumeCapacity(
|
||||||
|
self.slice1d(a, .{ .start = start, .end = start + chunk_size }),
|
||||||
|
);
|
||||||
|
}
|
||||||
if (tail_chunk_size != 0) {
|
if (tail_chunk_size != 0) {
|
||||||
const start: i64 = @as(i64, @intCast(full_chunks)) * chunk_size;
|
const start: i64 = n_chunks * chunk_size;
|
||||||
chunks[full_chunks] = self.slice1d(a, .{ .start = start });
|
chunks.appendAssumeCapacity(self.slice1d(a, .{ .start = start }));
|
||||||
|
}
|
||||||
|
return chunks;
|
||||||
|
}
|
||||||
|
|
||||||
|
test chunkAllowTrailing {
|
||||||
|
const zml = @import("zml.zig");
|
||||||
|
const platform = zml.testing.env();
|
||||||
|
|
||||||
|
// Only the shapes of the resulting chunks are checked, not their values.
|
||||||
|
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
||||||
|
defer comp.deinit();
|
||||||
|
comp.activate();
|
||||||
|
defer comp.deactivate();
|
||||||
|
|
||||||
|
inline for (.{
|
||||||
|
.{ .{ .a = 10 }, .a, 3, .{ .a = 3 }, .{ .a = 1 } },
|
||||||
|
.{ .{ .a = 10, .b = 2 }, .a, 3, .{ .a = 3, .b = 2 }, .{ .a = 1, .b = 2 } },
|
||||||
|
.{ .{ 10, 2 }, 0, 3, .{ 3, 2 }, .{ 1, 2 } },
|
||||||
|
.{ .{ 12, 2 }, 0, 3, .{ 4, 2 }, .{} },
|
||||||
|
}) |testcase| {
|
||||||
|
const x_shape, const ax, const n_chunks, const res, const trailing = testcase;
|
||||||
|
const x = Tensor.constant(x_shape, .{ .f16 = 0 });
|
||||||
|
const chunks = x.chunkAllowTrailing(x.axis(ax), n_chunks);
|
||||||
|
|
||||||
|
const res_shape = Shape.init(res, .f16);
|
||||||
|
for (chunks.constSlice()[0..n_chunks]) |chk| {
|
||||||
|
try zml.testing.expectEqualShapes(res_shape, chk.shape());
|
||||||
|
}
|
||||||
|
const trailing_shape = Shape.init(trailing, .f16);
|
||||||
|
if (trailing_shape.rank() > 0) {
|
||||||
|
try std.testing.expectEqual(n_chunks + 1, chunks.len);
|
||||||
|
try zml.testing.expectEqualShapes(trailing_shape, chunks.get(n_chunks).shape());
|
||||||
|
} else {
|
||||||
|
try std.testing.expectEqual(n_chunks, chunks.len);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2487,28 +2557,15 @@ pub const Tensor = struct {
|
|||||||
meta.assert(split_sum == length, "split expects sum of 'split_size_or_sections' values and axis dimension to be equal, got {} and {}", .{ split_sum, length });
|
meta.assert(split_sum == length, "split expects sum of 'split_size_or_sections' values and axis dimension to be equal, got {} and {}", .{ split_sum, length });
|
||||||
}
|
}
|
||||||
|
|
||||||
return switch (split_size_or_sections.len) {
|
const res = try allocator.alloc(Tensor, split_size_or_sections.len);
|
||||||
1 => {
|
errdefer allocator.dealloc(res);
|
||||||
var chunk_count: i64 = @divFloor(length, split_size_or_sections[0]);
|
|
||||||
if (@as(usize, @intCast(length)) % @as(usize, @intCast(split_size_or_sections[0])) != 0) {
|
|
||||||
chunk_count += 1;
|
|
||||||
}
|
|
||||||
const res = try allocator.alloc(Tensor, @intCast(chunk_count));
|
|
||||||
self.chunk(res, a);
|
|
||||||
return res;
|
|
||||||
},
|
|
||||||
else => {
|
|
||||||
const res = try allocator.alloc(Tensor, split_size_or_sections.len);
|
|
||||||
errdefer allocator.dealloc(res);
|
|
||||||
|
|
||||||
var start: i64 = 0;
|
var start: i64 = 0;
|
||||||
for (split_size_or_sections, 0..) |n, i| {
|
for (split_size_or_sections, 0..) |n, i| {
|
||||||
res[i] = self.slice1d(a, .{ .start = start, .end = start + n });
|
res[i] = self.slice1d(a, .{ .start = start, .end = start + n });
|
||||||
start += n;
|
start += n;
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Slices the input Tensor along a specific axis, with a start offset known at runtime.
|
/// Slices the input Tensor along a specific axis, with a start offset known at runtime.
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user