zml.tensor: add cumulativeSum operator and refactor maxPoolND. Introduce cumulative sum using reduceWindow. Simplify reduceWindow signature by merging padding_shape and padding_value. Update maxPool1D/2D to accept tuple arguments. Revise pad to use tagged or AOS syntax; remove SOA syntax.

This commit is contained in:
Tarry Singh 2023-05-17 09:01:27 +00:00
parent 54e7eb30b4
commit 05faa5021e
4 changed files with 158 additions and 69 deletions

View File

@ -481,7 +481,7 @@ pub fn round_nearest_even(ctx: mlir.Context, value: mlir.Value, location: mlir.L
pub const PadOpts = struct {
low: []const i64,
high: []const i64,
interior: ?[]const i64,
interior: []const i64,
};
pub fn pad(ctx: mlir.Context, value: mlir.Value, padding_value: mlir.Value, opts: PadOpts, location: mlir.Location) mlir.Operation {
@ -491,7 +491,7 @@ pub fn pad(ctx: mlir.Context, value: mlir.Value, padding_value: mlir.Value, opts
.attributes = &.{
.{ "edge_padding_low", mlir.DenseArrayAttribute(.i64).init(ctx, opts.low).as(mlir.Attribute).? },
.{ "edge_padding_high", mlir.DenseArrayAttribute(.i64).init(ctx, opts.high).as(mlir.Attribute).? },
.{ "interior_padding", mlir.DenseArrayAttribute(.i64).init(ctx, opts.interior.?).as(mlir.Attribute).? },
.{ "interior_padding", mlir.DenseArrayAttribute(.i64).init(ctx, opts.interior).as(mlir.Attribute).? },
},
.location = location,
});

View File

@ -211,8 +211,7 @@ pub const ReduceWindowOpts = struct {
window_strides: []const i64,
base_dilations: []const i64,
window_dilations: []const i64,
padding_values: []const i64,
padding_shape: []const i64,
padding: []const [2]i64,
};
pub fn reduceWindow(
@ -235,7 +234,10 @@ pub fn reduceWindow(
const loc = ctx.mlirCtx().location(@src());
const pad_shape = mlir.RankedTensorType.init(opts.padding_shape, mlir.ext.Type.fromDType(ctx.mlirCtx(), .i64)).as(mlir.Type).?;
const pad_shape = mlir.RankedTensorType.init(
&.{ @intCast(opts.padding.len), 2 },
mlir.ext.Type.fromDType(ctx.mlirCtx(), .i64),
).as(mlir.Type).?;
const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.reduce_window", .{
.variadic_operands = &.{ input_values[0..], init_values[0..] },
.result_type_inference = true,
@ -245,7 +247,7 @@ pub fn reduceWindow(
.{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_strides).as(mlir.Attribute).? },
.{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.base_dilations).as(mlir.Attribute).? },
.{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dilations).as(mlir.Attribute).? },
.{ "padding", mlir.DenseIntOrFPElementsAttribute(.i64).init(pad_shape, std.mem.sliceAsBytes(opts.padding_values)).as(mlir.Attribute).? },
.{ "padding", mlir.DenseIntOrFPElementsAttribute(.i64).init(pad_shape, std.mem.sliceAsBytes(opts.padding)).as(mlir.Attribute).? },
},
.location = loc,
});

View File

@ -943,6 +943,39 @@ pub const Shape = struct {
try testing.expectEqualSlices(Tag, &.{ "a".ptr, "b".ptr }, tags_.constSlice());
}
/// Parses a struct literal into a list of options for each axes.
pub fn parseAxesOptions(self: Shape, T: type, options: anytype, default: T) std.BoundedArray(T, MAX_RANK) {
const V = @TypeOf(options);
var res: std.BoundedArray(T, MAX_RANK) = .{};
if (comptime meta.isSliceOf(V, T)) {
meta.assert(options.len == self.rank(), "expects exactly {} options in slice, for {} got {}", .{ self.rank(), self, options.len });
for (options) |d| {
res.appendAssumeCapacity(d);
}
}
if (comptime meta.isStruct(V)) {
for (0..self.rank()) |_| res.appendAssumeCapacity(default);
const fields = std.meta.fields(V);
meta.assertComptime(fields.len <= MAX_RANK, "expects up to {} options struct literal, got {}", .{ V, MAX_RANK, fields.len });
inline for (fields) |field| {
const a = self.axis(field);
res.buffer[a] = @field(options, field.name);
}
return res;
}
meta.compileError("parseStruct expects struct or tuple, got {}", .{V});
}
test parseAxesOptions {
const shape = Shape.init(.{ .a = 10, .b = 20, .c = 30 }, .u8);
const scaling = shape.parseAxesOptions(f32, .{ .b = 1.2, .a = 0.1 }, 1.0);
try testing.expectEqualSlices(f32, &.{ 0.1, 1.2, 1.0 }, scaling.constSlice());
}
test "comptimeShape" {
comptime {
const s = Shape.init(.{ .a = 5, .b = 6, .c = 7 }, .f32);

View File

@ -1268,6 +1268,7 @@ pub const Tensor = struct {
}
/// Returns a Tensor containing the sum of elements over the given axis.
/// Output shape is the input shape with the axis_ dim set to 1.
pub fn sum(self: Tensor, axis_: anytype) Tensor {
const a = self.axis(axis_);
return ops.reduce(
@ -1283,10 +1284,62 @@ pub const Tensor = struct {
}
/// Returns a Tensor containing the mean of elements over the given axis.
/// Output shape is the input shape with the axis_ dim set to 1.
pub fn mean(self: Tensor, axis_: anytype) Tensor {
return self.sum(axis_).divByConst(self.dim(axis_));
}
/// Returns a Tensor containing the cumulative sum of elements over the given axis.
/// Output shape is the same as input shape.
/// [0, 1, 0, 1, 0, 0, 1, 1].cumulativeSum(0) -> [0, 1, 1, 2, 2, 2, 3, 4]
/// The last value contains the sum of all element in the array.
pub fn cumulativeSum(self: Tensor, axis_: anytype) Tensor {
const rk = self.rank();
const a = self.axis(axis_);
const ones = [_]i64{1} ** MAX_RANK;
var window_dimensions = ones;
window_dimensions[a] = self.dim(a);
var padding = [_][2]i64{.{ 0, 0 }} ** MAX_RANK;
padding[a] = .{ self.dim(a) - 1, 0 };
var res = ops.reduceWindow(
Tensor.add,
self,
Tensor.scalar(0, self.dtype()),
.{
.base_dilations = ones[0..rk],
.window_dilations = ones[0..rk],
.window_strides = ones[0..rk],
.window_dimensions = window_dimensions[0..rk],
.padding = padding[0..rk],
},
);
res._shape = self._shape;
return res;
}
test cumulativeSum {
const zml = @import("zml.zig");
const platform = zml.testing.env();
const Local = struct {
pub fn _cumsum(input: Tensor) Tensor {
return input.withPartialTags(.{.n}).cumulativeSum(.n);
}
};
const x = try zml.Buffer.fromArray(
platform,
[2][5]f32{ .{ 0, 1, 1, 0, 1 }, .{ 3, 1, 0, 2, 1 } },
);
const res = try zml.testing.compileAndCall(platform, Local._cumsum, .{x});
try testing.expectEqual(
[2][5]f32{ .{ 0, 1, 2, 2, 3 }, .{ 3, 4, 4, 6, 7 } },
try res.getValue([2][5]f32),
);
}
/// Returns a transposed Tensor computed using the given axes.
pub fn transpose(self: Tensor, axes_: anytype) Tensor {
const axes__ = self.axes(axes_).constSlice();
@ -1868,47 +1921,45 @@ pub const Tensor = struct {
return _result(output_shape, reshape_value.result(0));
}
pub const Pad1dOpts = struct { low: i64, high: i64, interior: i64 = 0 };
/// Pads the input Tensor with the given value over the given axis.
pub fn pad1d(self: Tensor, axis_: i8, pad_value: anytype, opts: Pad1dOpts) Tensor {
const ZEROS = [_]i64{0} ** MAX_RANK;
var padding_low = ZEROS;
var padding_high = ZEROS;
var padding_interior = ZEROS;
const a = self.axis(axis_);
padding_low[a] = opts.low;
padding_high[a] = opts.high;
padding_interior[a] = opts.interior;
const rk = self.rank();
return self.pad(
pad_value,
.{ .low = padding_low[0..rk], .high = padding_high[0..rk], .interior = padding_interior[0..rk] },
);
}
pub const Pad = struct {
low: i32 = 0,
high: i32 = 0,
interior: i32 = 0,
};
/// Pads the input Tensor with the given values.
pub fn pad(self: Tensor, pad_value: anytype, opts: dialect.stablehlo.PadOpts) Tensor {
meta.assert(opts.low.len == self.rank(), "pad expects tensor rank and 'opts.low' length to be equal, got {} and {}", .{ self.rank(), opts.low.len });
meta.assert(opts.high.len == self.rank(), "pad expects tensor rank and 'opts.high' length to be equal, got {} and {}", .{ self.rank(), opts.high.len });
/// Usage: x.pad(0, .{ .a = .{ .low = 1, .high = 1 }});
pub fn pad(self: Tensor, padding_value: anytype, paddings: anytype) Tensor {
const _paddings = self.shape().parseAxesOptions(Pad, paddings, .{});
const ZEROS = [_]i64{0} ** MAX_RANK;
const interior = opts.interior orelse ZEROS[0..self.rank()];
var low = ZEROS;
var high = ZEROS;
var interior = ZEROS;
var new_shape = self._shape;
var res_shape = self._shape;
for (_paddings.constSlice(), 0..) |padding, i| {
low[i] = padding.low;
high[i] = padding.high;
interior[i] = padding.interior;
for (0..self.rank()) |i| {
const d = self.dim(i) + opts.low[i] + (@max(self.dim(i) - 1, 0) * interior[i]) + opts.high[i];
new_shape = new_shape.set(i, d);
var d: i64 = self.dim(i);
d += low[i] + (@max(d - 1, 0) * interior[i]) + high[i];
res_shape._dims.set(i, d);
}
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "pad(value={}, opts={})", .{ pad_value, opts });
var full_opts = opts;
full_opts.interior = opts.interior orelse ZEROS[0..self.rank()];
const pad_value_tensor = Tensor.scalar(pad_value, self.dtype());
const pad_op = dialect.stablehlo.pad(self.getContext().mlirCtx(), self.value(), pad_value_tensor.value(), full_opts, loc);
const rk = self.rank();
const mlir_ctx = self.getContext().mlirCtx();
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "pad({},{})", .{ padding_value, _paddings });
const pad_op = dialect.stablehlo.pad(
mlir_ctx,
self.value(),
Tensor.scalar(padding_value, self.dtype()).value(),
.{ .low = low[0..rk], .high = high[0..rk], .interior = interior[0..rk] },
loc,
);
return _result(new_shape, pad_op.result(0));
return _result(res_shape, pad_op.result(0));
}
/// Inserts 1-dim axes at the given position, with the given tags.
@ -2760,7 +2811,7 @@ pub const Tensor = struct {
window_strides: ?i64,
base_dilations: i64 = 1,
window_dilations: i64 = 1,
padding: []const i64 = &.{0},
padding: [2]i64 = .{ 0, 0 },
}) MaxPoolRes {
// TODO migrate to the following syntax.
// maxPool(.{.a = .{ .stride = 5, .dilation = 2, .padding = .{0, 1} },
@ -2771,16 +2822,21 @@ pub const Tensor = struct {
// .padding = .{ .a = .{ 0, 2 }, .b = .{0, 2}
// })
meta.assert(opts.padding.len == 1 or opts.padding.len == 2 * self.rank(), "maxPool1d expects 'opts.padding' length to be a single integer or to be equal to the double of input tensor rank, got {} (input tensor rank is {})", .{ opts.padding.len, self.rank() });
// Note: the problem is initPoolArg assuming last axis
// TODO: support maxPool on non last axis
const a = self.axis(-1);
const ones = [_]i64{1} ** Tensor.MAX_RANK;
var window_dimensions = ones;
window_dimensions[a] = opts.window_dimensions;
var window_strides = window_dimensions;
if (opts.window_strides) |stride| window_strides[a] = stride;
const window_dimensions = initPoolArg(self.rank(), &.{opts.window_dimensions});
const window_strides = if (opts.window_strides) |ws| initPoolArg(self.rank(), &.{ws}) else window_dimensions;
const base_dilation = initPoolArg(self.rank(), &.{opts.base_dilations});
const window_dilations = initPoolArg(self.rank(), &.{opts.window_dilations});
var base_dilations = ones;
base_dilations[a] = opts.base_dilations;
var window_dilations = ones;
window_dilations[a] = opts.window_dilations;
var padding = [_][2]i64{.{ 0, 0 }} ** Tensor.MAX_RANK;
padding[a] = opts.padding;
return ops.reduceWindow(
MaxPoolRes.cmp,
@ -2789,38 +2845,36 @@ pub const Tensor = struct {
.{
.window_dimensions = window_dimensions[0..self.rank()],
.window_strides = window_strides[0..self.rank()],
.base_dilations = base_dilation[0..self.rank()],
.base_dilations = base_dilations[0..self.rank()],
.window_dilations = window_dilations[0..self.rank()],
.padding_values = opts.padding,
.padding_shape = &.{ @intCast(self.rank()), 2 },
.padding = padding[0..self.rank()],
},
);
}
/// Computes the 2d maxPool operation on the input Tensor.
pub fn maxPool2d(self: Tensor, opts: struct {
window_dimensions: []const i64,
window_strides: ?[]const i64 = null,
base_dilations: []const i64 = &.{ 1, 1 },
window_dilations: []const i64 = &.{ 1, 1 },
padding: []const i64 = &.{0},
window_dimensions: [2]i64,
window_strides: ?[2]i64 = null,
base_dilations: [2]i64 = .{ 1, 1 },
window_dilations: [2]i64 = .{ 1, 1 },
padding: [2][2]i64 = .{ .{ 0, 0 }, .{ 0, 0 } },
}) MaxPoolRes {
// TODO: rewrite using modern ZML, add ops.reduceWindow
// TODO: rewrite using modern ZML
meta.guard(self.rank() == 3 or self.rank() == 4, @src());
meta.guard(opts.window_dimensions.len == 2, @src());
meta.guard(opts.window_strides == null or opts.window_strides.?.len == 2, @src());
meta.guard(opts.base_dilations.len == 2, @src());
meta.guard(opts.window_dilations.len == 2, @src());
meta.assert(opts.padding.len == 1 or opts.padding.len == 2 * self.rank(), "Padding needs to either be a single integer, or to be 2x time the number of input rank. In maxPool({}, .padding={d})", .{ self, opts.padding });
// TODO: support maxPool on non last axis
// Note: the problem is initPoolArg assuming last axis
const a = self.axis(-1);
const window_dimensions = initPoolArg(self.rank(), opts.window_dimensions);
const window_strides = if (opts.window_strides) |ws| initPoolArg(self.rank(), ws) else window_dimensions;
const base_dilation = initPoolArg(self.rank(), opts.base_dilations);
const window_dilations = initPoolArg(self.rank(), opts.window_dilations);
const window_dimensions = initPoolArg(self.rank(), &opts.window_dimensions);
const window_strides = if (opts.window_strides) |ws| initPoolArg(self.rank(), &ws) else window_dimensions;
const base_dilation = initPoolArg(self.rank(), &opts.base_dilations);
const window_dilations = initPoolArg(self.rank(), &opts.window_dilations);
var padding = [_][2]i64{.{ 0, 0 }} ** Tensor.MAX_RANK;
padding[a - 1] = opts.padding[0];
padding[a] = opts.padding[1];
return ops.reduceWindow(
MaxPoolRes.cmp,
@ -2831,8 +2885,7 @@ pub const Tensor = struct {
.window_strides = window_strides[0..self.rank()],
.base_dilations = base_dilation[0..self.rank()],
.window_dilations = window_dilations[0..self.rank()],
.padding_values = opts.padding,
.padding_shape = &.{ @intCast(self.rank()), 2 },
.padding = padding[0..self.rank()],
},
);
}
@ -3086,6 +3139,7 @@ pub const Tensor = struct {
/// Tensor(.{ .a = 2, .b = 5 }).dynamicUpdateSlice(.{ .a = scalar(1, .i32) }, Tensor(.{ .b = 5 }));
/// ```
pub fn dynamicUpdateSlice(self: Tensor, offset_: anytype, update_: Tensor) Tensor {
// TODO: add updateSlice for when the offset isn't dynamic
meta.assert(self.dtype() == update_.dtype(), "dynamicUpdateSlice expects input and 'update_' tensors to be of the same type, got {} and {}", .{ self.dtype(), update_.dtype() });
const offset, const offset_tags = Shape.parseStruct(Tensor, offset_);
@ -3559,8 +3613,8 @@ test "Tensor.maxPool2d" {
const MaxPool = struct {
pub fn forward(x: Tensor) Tensor.ArgMaxRes {
return x.maxPool2d(.{
.window_dimensions = &.{ 3, 2 },
.window_strides = &.{ 2, 1 },
.window_dimensions = .{ 3, 2 },
.window_strides = .{ 2, 1 },
});
}
};