From 05faa5021eed13ebec07d7abed98aa3d43c99a62 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Wed, 17 May 2023 09:01:27 +0000 Subject: [PATCH] zml.tensor: add cumulativeSum operator and refactor maxPoolND. Introduce cumulative sum using reduceWindow. Simplify reduceWindow signature by merging padding_shape and padding_value. Update maxPool1D/2D to accept tuple arguments. Revise pad to use tagged or AOS syntax; remove SOA syntax. --- mlir/dialects/stablehlo.zig | 4 +- zml/ops.zig | 10 +- zml/shape.zig | 33 +++++++ zml/tensor.zig | 180 +++++++++++++++++++++++------------- 4 files changed, 158 insertions(+), 69 deletions(-) diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig index 9f3606e..fff3021 100644 --- a/mlir/dialects/stablehlo.zig +++ b/mlir/dialects/stablehlo.zig @@ -481,7 +481,7 @@ pub fn round_nearest_even(ctx: mlir.Context, value: mlir.Value, location: mlir.L pub const PadOpts = struct { low: []const i64, high: []const i64, - interior: ?[]const i64, + interior: []const i64, }; pub fn pad(ctx: mlir.Context, value: mlir.Value, padding_value: mlir.Value, opts: PadOpts, location: mlir.Location) mlir.Operation { @@ -491,7 +491,7 @@ pub fn pad(ctx: mlir.Context, value: mlir.Value, padding_value: mlir.Value, opts .attributes = &.{ .{ "edge_padding_low", mlir.DenseArrayAttribute(.i64).init(ctx, opts.low).as(mlir.Attribute).? }, .{ "edge_padding_high", mlir.DenseArrayAttribute(.i64).init(ctx, opts.high).as(mlir.Attribute).? }, - .{ "interior_padding", mlir.DenseArrayAttribute(.i64).init(ctx, opts.interior.?).as(mlir.Attribute).? }, + .{ "interior_padding", mlir.DenseArrayAttribute(.i64).init(ctx, opts.interior).as(mlir.Attribute).? }, }, .location = location, }); diff --git a/zml/ops.zig b/zml/ops.zig index 35765ce..74c6060 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -211,8 +211,7 @@ pub const ReduceWindowOpts = struct { window_strides: []const i64, base_dilations: []const i64, window_dilations: []const i64, - padding_values: []const i64, - padding_shape: []const i64, + padding: []const [2]i64, }; pub fn reduceWindow( @@ -235,7 +234,10 @@ pub fn reduceWindow( const loc = ctx.mlirCtx().location(@src()); - const pad_shape = mlir.RankedTensorType.init(opts.padding_shape, mlir.ext.Type.fromDType(ctx.mlirCtx(), .i64)).as(mlir.Type).?; + const pad_shape = mlir.RankedTensorType.init( + &.{ @intCast(opts.padding.len), 2 }, + mlir.ext.Type.fromDType(ctx.mlirCtx(), .i64), + ).as(mlir.Type).?; const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.reduce_window", .{ .variadic_operands = &.{ input_values[0..], init_values[0..] }, .result_type_inference = true, @@ -245,7 +247,7 @@ pub fn reduceWindow( .{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_strides).as(mlir.Attribute).? }, .{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.base_dilations).as(mlir.Attribute).? }, .{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dilations).as(mlir.Attribute).? }, - .{ "padding", mlir.DenseIntOrFPElementsAttribute(.i64).init(pad_shape, std.mem.sliceAsBytes(opts.padding_values)).as(mlir.Attribute).? }, + .{ "padding", mlir.DenseIntOrFPElementsAttribute(.i64).init(pad_shape, std.mem.sliceAsBytes(opts.padding)).as(mlir.Attribute).? }, }, .location = loc, }); diff --git a/zml/shape.zig b/zml/shape.zig index 2ed4016..a61f25e 100644 --- a/zml/shape.zig +++ b/zml/shape.zig @@ -943,6 +943,39 @@ pub const Shape = struct { try testing.expectEqualSlices(Tag, &.{ "a".ptr, "b".ptr }, tags_.constSlice()); } + /// Parses a struct literal into a list of options for each axes. + pub fn parseAxesOptions(self: Shape, T: type, options: anytype, default: T) std.BoundedArray(T, MAX_RANK) { + const V = @TypeOf(options); + + var res: std.BoundedArray(T, MAX_RANK) = .{}; + if (comptime meta.isSliceOf(V, T)) { + meta.assert(options.len == self.rank(), "expects exactly {} options in slice, for {} got {}", .{ self.rank(), self, options.len }); + for (options) |d| { + res.appendAssumeCapacity(d); + } + } + + if (comptime meta.isStruct(V)) { + for (0..self.rank()) |_| res.appendAssumeCapacity(default); + const fields = std.meta.fields(V); + meta.assertComptime(fields.len <= MAX_RANK, "expects up to {} options struct literal, got {}", .{ V, MAX_RANK, fields.len }); + inline for (fields) |field| { + const a = self.axis(field); + res.buffer[a] = @field(options, field.name); + } + return res; + } + + meta.compileError("parseStruct expects struct or tuple, got {}", .{V}); + } + + test parseAxesOptions { + const shape = Shape.init(.{ .a = 10, .b = 20, .c = 30 }, .u8); + const scaling = shape.parseAxesOptions(f32, .{ .b = 1.2, .a = 0.1 }, 1.0); + + try testing.expectEqualSlices(f32, &.{ 0.1, 1.2, 1.0 }, scaling.constSlice()); + } + test "comptimeShape" { comptime { const s = Shape.init(.{ .a = 5, .b = 6, .c = 7 }, .f32); diff --git a/zml/tensor.zig b/zml/tensor.zig index f1065ae..a69a3c7 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -1268,6 +1268,7 @@ pub const Tensor = struct { } /// Returns a Tensor containing the sum of elements over the given axis. + /// Output shape is the input shape with the axis_ dim set to 1. pub fn sum(self: Tensor, axis_: anytype) Tensor { const a = self.axis(axis_); return ops.reduce( @@ -1283,10 +1284,62 @@ pub const Tensor = struct { } /// Returns a Tensor containing the mean of elements over the given axis. + /// Output shape is the input shape with the axis_ dim set to 1. pub fn mean(self: Tensor, axis_: anytype) Tensor { return self.sum(axis_).divByConst(self.dim(axis_)); } + /// Returns a Tensor containing the cumulative sum of elements over the given axis. + /// Output shape is the same as input shape. + /// [0, 1, 0, 1, 0, 0, 1, 1].cumulativeSum(0) -> [0, 1, 1, 2, 2, 2, 3, 4] + /// The last value contains the sum of all element in the array. + pub fn cumulativeSum(self: Tensor, axis_: anytype) Tensor { + const rk = self.rank(); + const a = self.axis(axis_); + + const ones = [_]i64{1} ** MAX_RANK; + var window_dimensions = ones; + window_dimensions[a] = self.dim(a); + var padding = [_][2]i64{.{ 0, 0 }} ** MAX_RANK; + padding[a] = .{ self.dim(a) - 1, 0 }; + + var res = ops.reduceWindow( + Tensor.add, + self, + Tensor.scalar(0, self.dtype()), + .{ + .base_dilations = ones[0..rk], + .window_dilations = ones[0..rk], + .window_strides = ones[0..rk], + .window_dimensions = window_dimensions[0..rk], + .padding = padding[0..rk], + }, + ); + res._shape = self._shape; + return res; + } + + test cumulativeSum { + const zml = @import("zml.zig"); + const platform = zml.testing.env(); + + const Local = struct { + pub fn _cumsum(input: Tensor) Tensor { + return input.withPartialTags(.{.n}).cumulativeSum(.n); + } + }; + + const x = try zml.Buffer.fromArray( + platform, + [2][5]f32{ .{ 0, 1, 1, 0, 1 }, .{ 3, 1, 0, 2, 1 } }, + ); + const res = try zml.testing.compileAndCall(platform, Local._cumsum, .{x}); + try testing.expectEqual( + [2][5]f32{ .{ 0, 1, 2, 2, 3 }, .{ 3, 4, 4, 6, 7 } }, + try res.getValue([2][5]f32), + ); + } + /// Returns a transposed Tensor computed using the given axes. pub fn transpose(self: Tensor, axes_: anytype) Tensor { const axes__ = self.axes(axes_).constSlice(); @@ -1868,47 +1921,45 @@ pub const Tensor = struct { return _result(output_shape, reshape_value.result(0)); } - pub const Pad1dOpts = struct { low: i64, high: i64, interior: i64 = 0 }; - - /// Pads the input Tensor with the given value over the given axis. - pub fn pad1d(self: Tensor, axis_: i8, pad_value: anytype, opts: Pad1dOpts) Tensor { - const ZEROS = [_]i64{0} ** MAX_RANK; - var padding_low = ZEROS; - var padding_high = ZEROS; - var padding_interior = ZEROS; - const a = self.axis(axis_); - padding_low[a] = opts.low; - padding_high[a] = opts.high; - padding_interior[a] = opts.interior; - - const rk = self.rank(); - return self.pad( - pad_value, - .{ .low = padding_low[0..rk], .high = padding_high[0..rk], .interior = padding_interior[0..rk] }, - ); - } + pub const Pad = struct { + low: i32 = 0, + high: i32 = 0, + interior: i32 = 0, + }; /// Pads the input Tensor with the given values. - pub fn pad(self: Tensor, pad_value: anytype, opts: dialect.stablehlo.PadOpts) Tensor { - meta.assert(opts.low.len == self.rank(), "pad expects tensor rank and 'opts.low' length to be equal, got {} and {}", .{ self.rank(), opts.low.len }); - meta.assert(opts.high.len == self.rank(), "pad expects tensor rank and 'opts.high' length to be equal, got {} and {}", .{ self.rank(), opts.high.len }); + /// Usage: x.pad(0, .{ .a = .{ .low = 1, .high = 1 }}); + pub fn pad(self: Tensor, padding_value: anytype, paddings: anytype) Tensor { + const _paddings = self.shape().parseAxesOptions(Pad, paddings, .{}); + const ZEROS = [_]i64{0} ** MAX_RANK; - const interior = opts.interior orelse ZEROS[0..self.rank()]; + var low = ZEROS; + var high = ZEROS; + var interior = ZEROS; - var new_shape = self._shape; + var res_shape = self._shape; + for (_paddings.constSlice(), 0..) |padding, i| { + low[i] = padding.low; + high[i] = padding.high; + interior[i] = padding.interior; - for (0..self.rank()) |i| { - const d = self.dim(i) + opts.low[i] + (@max(self.dim(i) - 1, 0) * interior[i]) + opts.high[i]; - new_shape = new_shape.set(i, d); + var d: i64 = self.dim(i); + d += low[i] + (@max(d - 1, 0) * interior[i]) + high[i]; + res_shape._dims.set(i, d); } - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "pad(value={}, opts={})", .{ pad_value, opts }); - var full_opts = opts; - full_opts.interior = opts.interior orelse ZEROS[0..self.rank()]; - const pad_value_tensor = Tensor.scalar(pad_value, self.dtype()); - const pad_op = dialect.stablehlo.pad(self.getContext().mlirCtx(), self.value(), pad_value_tensor.value(), full_opts, loc); + const rk = self.rank(); + const mlir_ctx = self.getContext().mlirCtx(); + const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "pad({},{})", .{ padding_value, _paddings }); + const pad_op = dialect.stablehlo.pad( + mlir_ctx, + self.value(), + Tensor.scalar(padding_value, self.dtype()).value(), + .{ .low = low[0..rk], .high = high[0..rk], .interior = interior[0..rk] }, + loc, + ); - return _result(new_shape, pad_op.result(0)); + return _result(res_shape, pad_op.result(0)); } /// Inserts 1-dim axes at the given position, with the given tags. @@ -2760,7 +2811,7 @@ pub const Tensor = struct { window_strides: ?i64, base_dilations: i64 = 1, window_dilations: i64 = 1, - padding: []const i64 = &.{0}, + padding: [2]i64 = .{ 0, 0 }, }) MaxPoolRes { // TODO migrate to the following syntax. // maxPool(.{.a = .{ .stride = 5, .dilation = 2, .padding = .{0, 1} }, @@ -2771,16 +2822,21 @@ pub const Tensor = struct { // .padding = .{ .a = .{ 0, 2 }, .b = .{0, 2} // }) - meta.assert(opts.padding.len == 1 or opts.padding.len == 2 * self.rank(), "maxPool1d expects 'opts.padding' length to be a single integer or to be equal to the double of input tensor rank, got {} (input tensor rank is {})", .{ opts.padding.len, self.rank() }); - - // Note: the problem is initPoolArg assuming last axis // TODO: support maxPool on non last axis const a = self.axis(-1); + const ones = [_]i64{1} ** Tensor.MAX_RANK; + var window_dimensions = ones; + window_dimensions[a] = opts.window_dimensions; + var window_strides = window_dimensions; + if (opts.window_strides) |stride| window_strides[a] = stride; - const window_dimensions = initPoolArg(self.rank(), &.{opts.window_dimensions}); - const window_strides = if (opts.window_strides) |ws| initPoolArg(self.rank(), &.{ws}) else window_dimensions; - const base_dilation = initPoolArg(self.rank(), &.{opts.base_dilations}); - const window_dilations = initPoolArg(self.rank(), &.{opts.window_dilations}); + var base_dilations = ones; + base_dilations[a] = opts.base_dilations; + var window_dilations = ones; + window_dilations[a] = opts.window_dilations; + + var padding = [_][2]i64{.{ 0, 0 }} ** Tensor.MAX_RANK; + padding[a] = opts.padding; return ops.reduceWindow( MaxPoolRes.cmp, @@ -2789,38 +2845,36 @@ pub const Tensor = struct { .{ .window_dimensions = window_dimensions[0..self.rank()], .window_strides = window_strides[0..self.rank()], - .base_dilations = base_dilation[0..self.rank()], + .base_dilations = base_dilations[0..self.rank()], .window_dilations = window_dilations[0..self.rank()], - .padding_values = opts.padding, - .padding_shape = &.{ @intCast(self.rank()), 2 }, + .padding = padding[0..self.rank()], }, ); } /// Computes the 2d maxPool operation on the input Tensor. pub fn maxPool2d(self: Tensor, opts: struct { - window_dimensions: []const i64, - window_strides: ?[]const i64 = null, - base_dilations: []const i64 = &.{ 1, 1 }, - window_dilations: []const i64 = &.{ 1, 1 }, - padding: []const i64 = &.{0}, + window_dimensions: [2]i64, + window_strides: ?[2]i64 = null, + base_dilations: [2]i64 = .{ 1, 1 }, + window_dilations: [2]i64 = .{ 1, 1 }, + padding: [2][2]i64 = .{ .{ 0, 0 }, .{ 0, 0 } }, }) MaxPoolRes { - // TODO: rewrite using modern ZML, add ops.reduceWindow + // TODO: rewrite using modern ZML meta.guard(self.rank() == 3 or self.rank() == 4, @src()); - meta.guard(opts.window_dimensions.len == 2, @src()); - meta.guard(opts.window_strides == null or opts.window_strides.?.len == 2, @src()); - meta.guard(opts.base_dilations.len == 2, @src()); - meta.guard(opts.window_dilations.len == 2, @src()); - meta.assert(opts.padding.len == 1 or opts.padding.len == 2 * self.rank(), "Padding needs to either be a single integer, or to be 2x time the number of input rank. In maxPool({}, .padding={d})", .{ self, opts.padding }); // TODO: support maxPool on non last axis // Note: the problem is initPoolArg assuming last axis const a = self.axis(-1); - const window_dimensions = initPoolArg(self.rank(), opts.window_dimensions); - const window_strides = if (opts.window_strides) |ws| initPoolArg(self.rank(), ws) else window_dimensions; - const base_dilation = initPoolArg(self.rank(), opts.base_dilations); - const window_dilations = initPoolArg(self.rank(), opts.window_dilations); + const window_dimensions = initPoolArg(self.rank(), &opts.window_dimensions); + const window_strides = if (opts.window_strides) |ws| initPoolArg(self.rank(), &ws) else window_dimensions; + const base_dilation = initPoolArg(self.rank(), &opts.base_dilations); + const window_dilations = initPoolArg(self.rank(), &opts.window_dilations); + + var padding = [_][2]i64{.{ 0, 0 }} ** Tensor.MAX_RANK; + padding[a - 1] = opts.padding[0]; + padding[a] = opts.padding[1]; return ops.reduceWindow( MaxPoolRes.cmp, @@ -2831,8 +2885,7 @@ pub const Tensor = struct { .window_strides = window_strides[0..self.rank()], .base_dilations = base_dilation[0..self.rank()], .window_dilations = window_dilations[0..self.rank()], - .padding_values = opts.padding, - .padding_shape = &.{ @intCast(self.rank()), 2 }, + .padding = padding[0..self.rank()], }, ); } @@ -3086,6 +3139,7 @@ pub const Tensor = struct { /// Tensor(.{ .a = 2, .b = 5 }).dynamicUpdateSlice(.{ .a = scalar(1, .i32) }, Tensor(.{ .b = 5 })); /// ``` pub fn dynamicUpdateSlice(self: Tensor, offset_: anytype, update_: Tensor) Tensor { + // TODO: add updateSlice for when the offset isn't dynamic meta.assert(self.dtype() == update_.dtype(), "dynamicUpdateSlice expects input and 'update_' tensors to be of the same type, got {} and {}", .{ self.dtype(), update_.dtype() }); const offset, const offset_tags = Shape.parseStruct(Tensor, offset_); @@ -3559,8 +3613,8 @@ test "Tensor.maxPool2d" { const MaxPool = struct { pub fn forward(x: Tensor) Tensor.ArgMaxRes { return x.maxPool2d(.{ - .window_dimensions = &.{ 3, 2 }, - .window_strides = &.{ 2, 1 }, + .window_dimensions = .{ 3, 2 }, + .window_strides = .{ 2, 1 }, }); } };