zml: deprecate buggy Tensor.chunk; introduce chunkExact and chunkAllowTrailing with clarified behavior
This commit is contained in:
parent
7e131a106b
commit
058e1415fa
@ -6,7 +6,7 @@ module(
|
||||
|
||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||
bazel_dep(name = "llvm-raw", version = "20240823.0-f142f8a")
|
||||
bazel_dep(name = "llvm-raw", version = "20240823.0-94c024a")
|
||||
|
||||
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
||||
llvm.configure(
|
||||
|
||||
167
zml/tensor.zig
167
zml/tensor.zig
@ -1410,9 +1410,18 @@ pub const Tensor = struct {
|
||||
res_shape = res_shape.setDim(a, std.math.divCeil(i64, args.end.? - args.start, args.step) catch unreachable);
|
||||
}
|
||||
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "slices={any}", .{slices});
|
||||
const result_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type).?;
|
||||
const slice_op = dialect.stablehlo.slice(self.getContext().mlirCtx(), self.value(), start_indices[0..self.rank()], limit_indices[0..self.rank()], strides[0..self.rank()], result_type, loc);
|
||||
const mlir_ctx = self.getContext().mlirCtx();
|
||||
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "slices={any}", .{slices});
|
||||
const result_type = mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type).?;
|
||||
const slice_op = dialect.stablehlo.slice(
|
||||
mlir_ctx,
|
||||
self.value(),
|
||||
start_indices[0..self.rank()],
|
||||
limit_indices[0..self.rank()],
|
||||
strides[0..self.rank()],
|
||||
result_type,
|
||||
loc,
|
||||
);
|
||||
|
||||
return _result(res_shape, slice_op.result(0));
|
||||
}
|
||||
@ -2437,42 +2446,103 @@ pub const Tensor = struct {
|
||||
return self._shape.axes(axes_);
|
||||
}
|
||||
|
||||
pub fn chunkAlloc(self: Tensor, allocator: std.mem.Allocator, chunks: i64, axis_: i64) ![]Tensor {
|
||||
/// Chunk a given tensor into exactly n parts of equal shape.
|
||||
/// `self.dim(axis_)` must be divisible by n_chunks.
|
||||
pub fn chunkExact(self: Tensor, axis_: anytype, n_chunks: comptime_int) [n_chunks]Tensor {
|
||||
const a = self.axis(axis_);
|
||||
const length = self.dim(a);
|
||||
const chunk_size: i64 = @divFloor(length, chunks);
|
||||
const full_chunks: i64 = @divFloor(length, chunk_size);
|
||||
const tail_chunk_size: i64 = @rem(length, chunk_size);
|
||||
|
||||
const n_chunks = if (tail_chunk_size > 0) full_chunks + 1 else full_chunks;
|
||||
const result = try allocator.alloc(Tensor, @intCast(n_chunks));
|
||||
self.chunk(result, axis_);
|
||||
return result;
|
||||
}
|
||||
|
||||
pub fn chunkExact(self: Tensor, n_chunks: comptime_int, axis_: anytype) [n_chunks]Tensor {
|
||||
const a = self.axis(axis_);
|
||||
const length = self.dim(a);
|
||||
_ = @divExact(length, n_chunks);
|
||||
var res: [n_chunks]Tensor = undefined;
|
||||
self.chunk(&res, a);
|
||||
return res;
|
||||
}
|
||||
|
||||
pub fn chunk(self: Tensor, chunks: []Tensor, axis_: i64) void {
|
||||
const a = self.axis(axis_);
|
||||
const length = self.dim(a);
|
||||
const n_chunks: i64 = @intCast(chunks.len);
|
||||
const chunk_size: i64 = @divFloor(length, n_chunks);
|
||||
const full_chunks: usize = @intCast(@divFloor(length, chunk_size));
|
||||
const tail_chunk_size: i64 = @mod(length, chunk_size);
|
||||
for (0..full_chunks) |i| {
|
||||
const d = self.dim(a);
|
||||
const chunk_size = @divExact(d, n_chunks);
|
||||
var chunks: [n_chunks]Tensor = undefined;
|
||||
for (0..n_chunks) |i| {
|
||||
const start: i64 = @as(i64, @intCast(i)) * chunk_size;
|
||||
chunks[i] = self.slice1d(a, .{ .start = start, .end = start + chunk_size });
|
||||
}
|
||||
return chunks;
|
||||
}
|
||||
|
||||
test chunkExact {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
// Only test shapes
|
||||
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
||||
defer comp.deinit();
|
||||
comp.activate();
|
||||
defer comp.deactivate();
|
||||
|
||||
inline for (.{
|
||||
.{ .{ .a = 12 }, .a, 3, .{ .a = 4 } },
|
||||
.{ .{ .a = 12, .b = 2 }, .a, 3, .{ .a = 4, .b = 2 } },
|
||||
.{ .{ 12, 2 }, 0, 3, .{ 4, 2 } },
|
||||
}) |testcase| {
|
||||
const x_shape, const ax, const n_chunks, const res = testcase;
|
||||
const x = Tensor.constant(x_shape, .{ .f16 = 0 });
|
||||
const chunks = x.chunkExact(ax, n_chunks);
|
||||
|
||||
const res_shape = Shape.init(res, .f16);
|
||||
for (&chunks) |chk| {
|
||||
try zml.testing.expectEqualShapes(res_shape, chk.shape());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Chunk a given tensor into n parts of equal shape, and one part with the remaining items.
|
||||
/// When `self.dim(axis_)` is divisible by `n_chunks` it will return exactly `n_chunks`.
|
||||
pub fn chunkAllowTrailing(
|
||||
self: Tensor,
|
||||
axis_: i64,
|
||||
n_chunks: comptime_int,
|
||||
) std.BoundedArray(Tensor, n_chunks + 1) {
|
||||
const a = self.axis(axis_);
|
||||
const d = self.dim(a);
|
||||
const chunk_size: i64 = @divFloor(d, n_chunks);
|
||||
const tail_chunk_size: i64 = @rem(d, chunk_size);
|
||||
|
||||
var chunks: std.BoundedArray(Tensor, n_chunks + 1) = .{};
|
||||
for (0..n_chunks) |i| {
|
||||
const start: i64 = @as(i64, @intCast(i)) * chunk_size;
|
||||
chunks.appendAssumeCapacity(
|
||||
self.slice1d(a, .{ .start = start, .end = start + chunk_size }),
|
||||
);
|
||||
}
|
||||
if (tail_chunk_size != 0) {
|
||||
const start: i64 = @as(i64, @intCast(full_chunks)) * chunk_size;
|
||||
chunks[full_chunks] = self.slice1d(a, .{ .start = start });
|
||||
const start: i64 = n_chunks * chunk_size;
|
||||
chunks.appendAssumeCapacity(self.slice1d(a, .{ .start = start }));
|
||||
}
|
||||
return chunks;
|
||||
}
|
||||
|
||||
test chunkAllowTrailing {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
// Only test shapes
|
||||
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
||||
defer comp.deinit();
|
||||
comp.activate();
|
||||
defer comp.deactivate();
|
||||
|
||||
inline for (.{
|
||||
.{ .{ .a = 10 }, .a, 3, .{ .a = 3 }, .{ .a = 1 } },
|
||||
.{ .{ .a = 10, .b = 2 }, .a, 3, .{ .a = 3, .b = 2 }, .{ .a = 1, .b = 2 } },
|
||||
.{ .{ 10, 2 }, 0, 3, .{ 3, 2 }, .{ 1, 2 } },
|
||||
.{ .{ 12, 2 }, 0, 3, .{ 4, 2 }, .{} },
|
||||
}) |testcase| {
|
||||
const x_shape, const ax, const n_chunks, const res, const trailing = testcase;
|
||||
const x = Tensor.constant(x_shape, .{ .f16 = 0 });
|
||||
const chunks = x.chunkAllowTrailing(x.axis(ax), n_chunks);
|
||||
|
||||
const res_shape = Shape.init(res, .f16);
|
||||
for (chunks.constSlice()[0..n_chunks]) |chk| {
|
||||
try zml.testing.expectEqualShapes(res_shape, chk.shape());
|
||||
}
|
||||
const trailing_shape = Shape.init(trailing, .f16);
|
||||
if (trailing_shape.rank() > 0) {
|
||||
try std.testing.expectEqual(n_chunks + 1, chunks.len);
|
||||
try zml.testing.expectEqualShapes(trailing_shape, chunks.get(n_chunks).shape());
|
||||
} else {
|
||||
try std.testing.expectEqual(n_chunks, chunks.len);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2487,28 +2557,15 @@ pub const Tensor = struct {
|
||||
meta.assert(split_sum == length, "split expects sum of 'split_size_or_sections' values and axis dimension to be equal, got {} and {}", .{ split_sum, length });
|
||||
}
|
||||
|
||||
return switch (split_size_or_sections.len) {
|
||||
1 => {
|
||||
var chunk_count: i64 = @divFloor(length, split_size_or_sections[0]);
|
||||
if (@as(usize, @intCast(length)) % @as(usize, @intCast(split_size_or_sections[0])) != 0) {
|
||||
chunk_count += 1;
|
||||
}
|
||||
const res = try allocator.alloc(Tensor, @intCast(chunk_count));
|
||||
self.chunk(res, a);
|
||||
return res;
|
||||
},
|
||||
else => {
|
||||
const res = try allocator.alloc(Tensor, split_size_or_sections.len);
|
||||
errdefer allocator.dealloc(res);
|
||||
const res = try allocator.alloc(Tensor, split_size_or_sections.len);
|
||||
errdefer allocator.dealloc(res);
|
||||
|
||||
var start: i64 = 0;
|
||||
for (split_size_or_sections, 0..) |n, i| {
|
||||
res[i] = self.slice1d(a, .{ .start = start, .end = start + n });
|
||||
start += n;
|
||||
}
|
||||
return res;
|
||||
},
|
||||
};
|
||||
var start: i64 = 0;
|
||||
for (split_size_or_sections, 0..) |n, i| {
|
||||
res[i] = self.slice1d(a, .{ .start = start, .end = start + n });
|
||||
start += n;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Slices the input Tensor along a specific axis, with a start offset known at runtime.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user