zml: deprecate buggy Tensor.chunk; introduce chunkExact and chunkAllowTrailing with clarified behavior

Tarry Singh 2023-02-07 12:42:34 +00:00
parent 7e131a106b
commit 058e1415fa
2 changed files with 113 additions and 56 deletions
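At a glance, the new API splits the old chunking behavior in two: `chunkExact` returns a fixed-size array of `n_chunks` equal parts and requires the axis dimension to divide exactly, while `chunkAllowTrailing` returns a `std.BoundedArray` that may carry one extra, smaller trailing chunk. A minimal usage sketch (hypothetical tensor `x` whose axis 0 has dimension 10; see the definitions in the diff below):

// Sketch only; `x` is a hypothetical Tensor whose axis 0 has dim 10.
const halves = x.chunkExact(0, 2); // [2]Tensor, each with dim 5 on axis 0
// 10 is not divisible by 3, so chunkExact(0, 3) would trip @divExact.
// chunkAllowTrailing instead yields 3 chunks of dim 3 plus a trailing chunk of dim 1:
const parts = x.chunkAllowTrailing(0, 3); // std.BoundedArray(Tensor, 4), parts.len == 4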


@@ -6,7 +6,7 @@ module(
 bazel_dep(name = "bazel_skylib", version = "1.7.1")
 bazel_dep(name = "rules_cc", version = "0.0.9")
-bazel_dep(name = "llvm-raw", version = "20240823.0-f142f8a")
+bazel_dep(name = "llvm-raw", version = "20240823.0-94c024a")
 llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
 llvm.configure(


@@ -1410,9 +1410,18 @@ pub const Tensor = struct {
         res_shape = res_shape.setDim(a, std.math.divCeil(i64, args.end.? - args.start, args.step) catch unreachable);
     }
-    const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "slices={any}", .{slices});
-    const result_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type).?;
-    const slice_op = dialect.stablehlo.slice(self.getContext().mlirCtx(), self.value(), start_indices[0..self.rank()], limit_indices[0..self.rank()], strides[0..self.rank()], result_type, loc);
+    const mlir_ctx = self.getContext().mlirCtx();
+    const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "slices={any}", .{slices});
+    const result_type = mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type).?;
+    const slice_op = dialect.stablehlo.slice(
+        mlir_ctx,
+        self.value(),
+        start_indices[0..self.rank()],
+        limit_indices[0..self.rank()],
+        strides[0..self.rank()],
+        result_type,
+        loc,
+    );
     return _result(res_shape, slice_op.result(0));
 }
@@ -2437,42 +2446,103 @@ pub const Tensor = struct {
     return self._shape.axes(axes_);
 }

-pub fn chunkAlloc(self: Tensor, allocator: std.mem.Allocator, chunks: i64, axis_: i64) ![]Tensor {
-    const a = self.axis(axis_);
-    const length = self.dim(a);
-    const chunk_size: i64 = @divFloor(length, chunks);
-    const full_chunks: i64 = @divFloor(length, chunk_size);
-    const tail_chunk_size: i64 = @rem(length, chunk_size);
-    const n_chunks = if (tail_chunk_size > 0) full_chunks + 1 else full_chunks;
-    const result = try allocator.alloc(Tensor, @intCast(n_chunks));
-    self.chunk(result, axis_);
-    return result;
-}
-
-pub fn chunkExact(self: Tensor, n_chunks: comptime_int, axis_: anytype) [n_chunks]Tensor {
-    const a = self.axis(axis_);
-    const length = self.dim(a);
-    _ = @divExact(length, n_chunks);
-    var res: [n_chunks]Tensor = undefined;
-    self.chunk(&res, a);
-    return res;
-}
-
-pub fn chunk(self: Tensor, chunks: []Tensor, axis_: i64) void {
-    const a = self.axis(axis_);
-    const length = self.dim(a);
-    const n_chunks: i64 = @intCast(chunks.len);
-    const chunk_size: i64 = @divFloor(length, n_chunks);
-    const full_chunks: usize = @intCast(@divFloor(length, chunk_size));
-    const tail_chunk_size: i64 = @mod(length, chunk_size);
-    for (0..full_chunks) |i| {
-        const start: i64 = @as(i64, @intCast(i)) * chunk_size;
-        chunks[i] = self.slice1d(a, .{ .start = start, .end = start + chunk_size });
-    }
-    if (tail_chunk_size != 0) {
-        const start: i64 = @as(i64, @intCast(full_chunks)) * chunk_size;
-        chunks[full_chunks] = self.slice1d(a, .{ .start = start });
-    }
-}
+/// Chunk a given tensor into exactly n parts of equal shape.
+/// `self.dim(axis_)` must be divisible by n_chunks.
+pub fn chunkExact(self: Tensor, axis_: anytype, n_chunks: comptime_int) [n_chunks]Tensor {
+    const a = self.axis(axis_);
+    const d = self.dim(a);
+    const chunk_size = @divExact(d, n_chunks);
+    var chunks: [n_chunks]Tensor = undefined;
+    for (0..n_chunks) |i| {
+        const start: i64 = @as(i64, @intCast(i)) * chunk_size;
+        chunks[i] = self.slice1d(a, .{ .start = start, .end = start + chunk_size });
+    }
+    return chunks;
+}
+
+test chunkExact {
+    const zml = @import("zml.zig");
+    const platform = zml.testing.env();
+
+    // Only test shapes
+    var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
+    defer comp.deinit();
+    comp.activate();
+    defer comp.deactivate();
+
+    inline for (.{
+        .{ .{ .a = 12 }, .a, 3, .{ .a = 4 } },
+        .{ .{ .a = 12, .b = 2 }, .a, 3, .{ .a = 4, .b = 2 } },
+        .{ .{ 12, 2 }, 0, 3, .{ 4, 2 } },
+    }) |testcase| {
+        const x_shape, const ax, const n_chunks, const res = testcase;
+        const x = Tensor.constant(x_shape, .{ .f16 = 0 });
+        const chunks = x.chunkExact(ax, n_chunks);
+        const res_shape = Shape.init(res, .f16);
+        for (&chunks) |chk| {
+            try zml.testing.expectEqualShapes(res_shape, chk.shape());
+        }
+    }
+}
+
+/// Chunk a given tensor into n parts of equal shape, and one part with the remaining items.
+/// When `self.dim(axis_)` is divisible by `n_chunks` it will return exactly `n_chunks`.
+pub fn chunkAllowTrailing(
+    self: Tensor,
+    axis_: i64,
+    n_chunks: comptime_int,
+) std.BoundedArray(Tensor, n_chunks + 1) {
+    const a = self.axis(axis_);
+    const d = self.dim(a);
+    const chunk_size: i64 = @divFloor(d, n_chunks);
+    const tail_chunk_size: i64 = @rem(d, chunk_size);
+    var chunks: std.BoundedArray(Tensor, n_chunks + 1) = .{};
+    for (0..n_chunks) |i| {
+        const start: i64 = @as(i64, @intCast(i)) * chunk_size;
+        chunks.appendAssumeCapacity(
+            self.slice1d(a, .{ .start = start, .end = start + chunk_size }),
+        );
+    }
+    if (tail_chunk_size != 0) {
+        const start: i64 = n_chunks * chunk_size;
+        chunks.appendAssumeCapacity(self.slice1d(a, .{ .start = start }));
+    }
+    return chunks;
+}
+
+test chunkAllowTrailing {
+    const zml = @import("zml.zig");
+    const platform = zml.testing.env();
+
+    // Only test shapes
+    var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
+    defer comp.deinit();
+    comp.activate();
+    defer comp.deactivate();
+
+    inline for (.{
+        .{ .{ .a = 10 }, .a, 3, .{ .a = 3 }, .{ .a = 1 } },
+        .{ .{ .a = 10, .b = 2 }, .a, 3, .{ .a = 3, .b = 2 }, .{ .a = 1, .b = 2 } },
+        .{ .{ 10, 2 }, 0, 3, .{ 3, 2 }, .{ 1, 2 } },
+        .{ .{ 12, 2 }, 0, 3, .{ 4, 2 }, .{} },
+    }) |testcase| {
+        const x_shape, const ax, const n_chunks, const res, const trailing = testcase;
+        const x = Tensor.constant(x_shape, .{ .f16 = 0 });
+        const chunks = x.chunkAllowTrailing(x.axis(ax), n_chunks);
+        const res_shape = Shape.init(res, .f16);
+        for (chunks.constSlice()[0..n_chunks]) |chk| {
+            try zml.testing.expectEqualShapes(res_shape, chk.shape());
+        }
+        const trailing_shape = Shape.init(trailing, .f16);
+        if (trailing_shape.rank() > 0) {
+            try std.testing.expectEqual(n_chunks + 1, chunks.len);
+            try zml.testing.expectEqualShapes(trailing_shape, chunks.get(n_chunks).shape());
+        } else {
+            try std.testing.expectEqual(n_chunks, chunks.len);
+        }
+    }
+}
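Why the old `chunk` was deprecated as buggy: a worked example against the removed code above (a sketch, not part of the commit), using a caller-provided slice of 3 chunks on an axis of dimension 10:

// With the removed `chunk`: chunks.len == 3, self.dim(a) == 10
//   chunk_size      = @divFloor(10, 3) = 3
//   full_chunks     = @divFloor(10, 3) = 3
//   tail_chunk_size = @mod(10, 3)      = 1
// Since tail_chunk_size != 0, the tail branch wrote chunks[3], one past the
// end of the caller's 3-element slice.
var out: [3]Tensor = undefined;
x.chunk(&out, 0); // hypothetical call: out-of-bounds write, since 10 % 3 != 0

`chunkAllowTrailing` avoids this by sizing its `BoundedArray` as `n_chunks + 1`, so the trailing chunk always has room, and `chunkExact` rejects inexact splits via `@divExact` (a checked panic in safe builds).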
@@ -2487,28 +2557,15 @@ pub const Tensor = struct {
         meta.assert(split_sum == length, "split expects sum of 'split_size_or_sections' values and axis dimension to be equal, got {} and {}", .{ split_sum, length });
     }

-    return switch (split_size_or_sections.len) {
-        1 => {
-            var chunk_count: i64 = @divFloor(length, split_size_or_sections[0]);
-            if (@as(usize, @intCast(length)) % @as(usize, @intCast(split_size_or_sections[0])) != 0) {
-                chunk_count += 1;
-            }
-            const res = try allocator.alloc(Tensor, @intCast(chunk_count));
-            self.chunk(res, a);
-            return res;
-        },
-        else => {
-            const res = try allocator.alloc(Tensor, split_size_or_sections.len);
-            errdefer allocator.dealloc(res);
-            var start: i64 = 0;
-            for (split_size_or_sections, 0..) |n, i| {
-                res[i] = self.slice1d(a, .{ .start = start, .end = start + n });
-                start += n;
-            }
-            return res;
-        },
-    };
+    const res = try allocator.alloc(Tensor, split_size_or_sections.len);
+    errdefer allocator.dealloc(res);
+    var start: i64 = 0;
+    for (split_size_or_sections, 0..) |n, i| {
+        res[i] = self.slice1d(a, .{ .start = start, .end = start + n });
+        start += n;
+    }
+    return res;
 }

 /// Slices the input Tensor along a specific axis, with a start offset known at runtime.
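Note the behavior change in `split`: the removed `1 => ...` branch treated a single-element `split_size_or_sections` as a chunk size and delegated to the buggy `chunk` (producing ceil(length / size) parts), while the new code treats every entry as an explicit section size that must sum to the axis dimension, per the assert kept above. A sketch of the new path, assuming `x.dim(0) == 10` (hypothetical call; the argument order is illustrative, only `allocator`, the axis, and `split_size_or_sections` are visible in the diff):

// Sections {3, 3, 4} sum to 10, producing slices [0..3), [3..6), [6..10).
const parts = try x.split(allocator, 0, &.{ 3, 3, 4 });
// A lone section size must now match the whole axis: &.{ 10 } is fine,
// but the old chunk-count behavior of &.{ 3 } on a dim of 10 is gone.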