From 058e1415fabce9f852e897f24a3d3ffdd2643c0c Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 7 Feb 2023 12:42:34 +0000 Subject: [PATCH] zml: deprecate buggy Tensor.chunk; introduce chunkExact and chunkAllowTrailing with clarified behavior --- .../20240829.0-54aa1a5/overlay/MODULE.bazel | 2 +- zml/tensor.zig | 167 ++++++++++++------ 2 files changed, 113 insertions(+), 56 deletions(-) diff --git a/third_party/modules/stablehlo/20240829.0-54aa1a5/overlay/MODULE.bazel b/third_party/modules/stablehlo/20240829.0-54aa1a5/overlay/MODULE.bazel index 14386c0..29e45e0 100644 --- a/third_party/modules/stablehlo/20240829.0-54aa1a5/overlay/MODULE.bazel +++ b/third_party/modules/stablehlo/20240829.0-54aa1a5/overlay/MODULE.bazel @@ -6,7 +6,7 @@ module( bazel_dep(name = "bazel_skylib", version = "1.7.1") bazel_dep(name = "rules_cc", version = "0.0.9") -bazel_dep(name = "llvm-raw", version = "20240823.0-f142f8a") +bazel_dep(name = "llvm-raw", version = "20240823.0-94c024a") llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm") llvm.configure( diff --git a/zml/tensor.zig b/zml/tensor.zig index a1ad219..1a0bf91 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -1410,9 +1410,18 @@ pub const Tensor = struct { res_shape = res_shape.setDim(a, std.math.divCeil(i64, args.end.? - args.start, args.step) catch unreachable); } - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "slices={any}", .{slices}); - const result_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type).?; - const slice_op = dialect.stablehlo.slice(self.getContext().mlirCtx(), self.value(), start_indices[0..self.rank()], limit_indices[0..self.rank()], strides[0..self.rank()], result_type, loc); + const mlir_ctx = self.getContext().mlirCtx(); + const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "slices={any}", .{slices}); + const result_type = mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type).?; + const slice_op = dialect.stablehlo.slice( + mlir_ctx, + self.value(), + start_indices[0..self.rank()], + limit_indices[0..self.rank()], + strides[0..self.rank()], + result_type, + loc, + ); return _result(res_shape, slice_op.result(0)); } @@ -2437,42 +2446,103 @@ pub const Tensor = struct { return self._shape.axes(axes_); } - pub fn chunkAlloc(self: Tensor, allocator: std.mem.Allocator, chunks: i64, axis_: i64) ![]Tensor { + /// Chunk a given tensor into exactly n parts of equal shape. + /// `self.dim(axis_)` must be divisible by n_chunks. + pub fn chunkExact(self: Tensor, axis_: anytype, n_chunks: comptime_int) [n_chunks]Tensor { const a = self.axis(axis_); - const length = self.dim(a); - const chunk_size: i64 = @divFloor(length, chunks); - const full_chunks: i64 = @divFloor(length, chunk_size); - const tail_chunk_size: i64 = @rem(length, chunk_size); - - const n_chunks = if (tail_chunk_size > 0) full_chunks + 1 else full_chunks; - const result = try allocator.alloc(Tensor, @intCast(n_chunks)); - self.chunk(result, axis_); - return result; - } - - pub fn chunkExact(self: Tensor, n_chunks: comptime_int, axis_: anytype) [n_chunks]Tensor { - const a = self.axis(axis_); - const length = self.dim(a); - _ = @divExact(length, n_chunks); - var res: [n_chunks]Tensor = undefined; - self.chunk(&res, a); - return res; - } - - pub fn chunk(self: Tensor, chunks: []Tensor, axis_: i64) void { - const a = self.axis(axis_); - const length = self.dim(a); - const n_chunks: i64 = @intCast(chunks.len); - const chunk_size: i64 = @divFloor(length, n_chunks); - const full_chunks: usize = @intCast(@divFloor(length, chunk_size)); - const tail_chunk_size: i64 = @mod(length, chunk_size); - for (0..full_chunks) |i| { + const d = self.dim(a); + const chunk_size = @divExact(d, n_chunks); + var chunks: [n_chunks]Tensor = undefined; + for (0..n_chunks) |i| { const start: i64 = @as(i64, @intCast(i)) * chunk_size; chunks[i] = self.slice1d(a, .{ .start = start, .end = start + chunk_size }); } + return chunks; + } + + test chunkExact { + const zml = @import("zml.zig"); + const platform = zml.testing.env(); + + // Only test shapes + var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform); + defer comp.deinit(); + comp.activate(); + defer comp.deactivate(); + + inline for (.{ + .{ .{ .a = 12 }, .a, 3, .{ .a = 4 } }, + .{ .{ .a = 12, .b = 2 }, .a, 3, .{ .a = 4, .b = 2 } }, + .{ .{ 12, 2 }, 0, 3, .{ 4, 2 } }, + }) |testcase| { + const x_shape, const ax, const n_chunks, const res = testcase; + const x = Tensor.constant(x_shape, .{ .f16 = 0 }); + const chunks = x.chunkExact(ax, n_chunks); + + const res_shape = Shape.init(res, .f16); + for (&chunks) |chk| { + try zml.testing.expectEqualShapes(res_shape, chk.shape()); + } + } + } + + /// Chunk a given tensor into n parts of equal shape, and one part with the remaining items. + /// When `self.dim(axis_)` is divisible by `n_chunks` it will return exactly `n_chunks`. + pub fn chunkAllowTrailing( + self: Tensor, + axis_: i64, + n_chunks: comptime_int, + ) std.BoundedArray(Tensor, n_chunks + 1) { + const a = self.axis(axis_); + const d = self.dim(a); + const chunk_size: i64 = @divFloor(d, n_chunks); + const tail_chunk_size: i64 = @rem(d, chunk_size); + + var chunks: std.BoundedArray(Tensor, n_chunks + 1) = .{}; + for (0..n_chunks) |i| { + const start: i64 = @as(i64, @intCast(i)) * chunk_size; + chunks.appendAssumeCapacity( + self.slice1d(a, .{ .start = start, .end = start + chunk_size }), + ); + } if (tail_chunk_size != 0) { - const start: i64 = @as(i64, @intCast(full_chunks)) * chunk_size; - chunks[full_chunks] = self.slice1d(a, .{ .start = start }); + const start: i64 = n_chunks * chunk_size; + chunks.appendAssumeCapacity(self.slice1d(a, .{ .start = start })); + } + return chunks; + } + + test chunkAllowTrailing { + const zml = @import("zml.zig"); + const platform = zml.testing.env(); + + // Only test shapes + var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform); + defer comp.deinit(); + comp.activate(); + defer comp.deactivate(); + + inline for (.{ + .{ .{ .a = 10 }, .a, 3, .{ .a = 3 }, .{ .a = 1 } }, + .{ .{ .a = 10, .b = 2 }, .a, 3, .{ .a = 3, .b = 2 }, .{ .a = 1, .b = 2 } }, + .{ .{ 10, 2 }, 0, 3, .{ 3, 2 }, .{ 1, 2 } }, + .{ .{ 12, 2 }, 0, 3, .{ 4, 2 }, .{} }, + }) |testcase| { + const x_shape, const ax, const n_chunks, const res, const trailing = testcase; + const x = Tensor.constant(x_shape, .{ .f16 = 0 }); + const chunks = x.chunkAllowTrailing(x.axis(ax), n_chunks); + + const res_shape = Shape.init(res, .f16); + for (chunks.constSlice()[0..n_chunks]) |chk| { + try zml.testing.expectEqualShapes(res_shape, chk.shape()); + } + const trailing_shape = Shape.init(trailing, .f16); + if (trailing_shape.rank() > 0) { + try std.testing.expectEqual(n_chunks + 1, chunks.len); + try zml.testing.expectEqualShapes(trailing_shape, chunks.get(n_chunks).shape()); + } else { + try std.testing.expectEqual(n_chunks, chunks.len); + } } } @@ -2487,28 +2557,15 @@ pub const Tensor = struct { meta.assert(split_sum == length, "split expects sum of 'split_size_or_sections' values and axis dimension to be equal, got {} and {}", .{ split_sum, length }); } - return switch (split_size_or_sections.len) { - 1 => { - var chunk_count: i64 = @divFloor(length, split_size_or_sections[0]); - if (@as(usize, @intCast(length)) % @as(usize, @intCast(split_size_or_sections[0])) != 0) { - chunk_count += 1; - } - const res = try allocator.alloc(Tensor, @intCast(chunk_count)); - self.chunk(res, a); - return res; - }, - else => { - const res = try allocator.alloc(Tensor, split_size_or_sections.len); - errdefer allocator.dealloc(res); + const res = try allocator.alloc(Tensor, split_size_or_sections.len); + errdefer allocator.dealloc(res); - var start: i64 = 0; - for (split_size_or_sections, 0..) |n, i| { - res[i] = self.slice1d(a, .{ .start = start, .end = start + n }); - start += n; - } - return res; - }, - }; + var start: i64 = 0; + for (split_size_or_sections, 0..) |n, i| { + res[i] = self.slice1d(a, .{ .start = start, .end = start + n }); + start += n; + } + return res; } /// Slices the input Tensor along a specific axis, with a start offset known at runtime.