zml.tensor: add cumulativeSum operator and refactor maxPoolND. Introduce cumulative sum using reduceWindow. Simplify reduceWindow signature by merging padding_shape and padding_value. Update maxPool1D/2D to accept tuple arguments. Revise pad to use tagged or AOS syntax; remove SOA syntax.
This commit is contained in:
parent
54e7eb30b4
commit
05faa5021e
@ -481,7 +481,7 @@ pub fn round_nearest_even(ctx: mlir.Context, value: mlir.Value, location: mlir.L
|
|||||||
pub const PadOpts = struct {
|
pub const PadOpts = struct {
|
||||||
low: []const i64,
|
low: []const i64,
|
||||||
high: []const i64,
|
high: []const i64,
|
||||||
interior: ?[]const i64,
|
interior: []const i64,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn pad(ctx: mlir.Context, value: mlir.Value, padding_value: mlir.Value, opts: PadOpts, location: mlir.Location) mlir.Operation {
|
pub fn pad(ctx: mlir.Context, value: mlir.Value, padding_value: mlir.Value, opts: PadOpts, location: mlir.Location) mlir.Operation {
|
||||||
@ -491,7 +491,7 @@ pub fn pad(ctx: mlir.Context, value: mlir.Value, padding_value: mlir.Value, opts
|
|||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "edge_padding_low", mlir.DenseArrayAttribute(.i64).init(ctx, opts.low).as(mlir.Attribute).? },
|
.{ "edge_padding_low", mlir.DenseArrayAttribute(.i64).init(ctx, opts.low).as(mlir.Attribute).? },
|
||||||
.{ "edge_padding_high", mlir.DenseArrayAttribute(.i64).init(ctx, opts.high).as(mlir.Attribute).? },
|
.{ "edge_padding_high", mlir.DenseArrayAttribute(.i64).init(ctx, opts.high).as(mlir.Attribute).? },
|
||||||
.{ "interior_padding", mlir.DenseArrayAttribute(.i64).init(ctx, opts.interior.?).as(mlir.Attribute).? },
|
.{ "interior_padding", mlir.DenseArrayAttribute(.i64).init(ctx, opts.interior).as(mlir.Attribute).? },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
|
|||||||
10
zml/ops.zig
10
zml/ops.zig
@ -211,8 +211,7 @@ pub const ReduceWindowOpts = struct {
|
|||||||
window_strides: []const i64,
|
window_strides: []const i64,
|
||||||
base_dilations: []const i64,
|
base_dilations: []const i64,
|
||||||
window_dilations: []const i64,
|
window_dilations: []const i64,
|
||||||
padding_values: []const i64,
|
padding: []const [2]i64,
|
||||||
padding_shape: []const i64,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn reduceWindow(
|
pub fn reduceWindow(
|
||||||
@ -235,7 +234,10 @@ pub fn reduceWindow(
|
|||||||
|
|
||||||
const loc = ctx.mlirCtx().location(@src());
|
const loc = ctx.mlirCtx().location(@src());
|
||||||
|
|
||||||
const pad_shape = mlir.RankedTensorType.init(opts.padding_shape, mlir.ext.Type.fromDType(ctx.mlirCtx(), .i64)).as(mlir.Type).?;
|
const pad_shape = mlir.RankedTensorType.init(
|
||||||
|
&.{ @intCast(opts.padding.len), 2 },
|
||||||
|
mlir.ext.Type.fromDType(ctx.mlirCtx(), .i64),
|
||||||
|
).as(mlir.Type).?;
|
||||||
const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.reduce_window", .{
|
const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.reduce_window", .{
|
||||||
.variadic_operands = &.{ input_values[0..], init_values[0..] },
|
.variadic_operands = &.{ input_values[0..], init_values[0..] },
|
||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
@ -245,7 +247,7 @@ pub fn reduceWindow(
|
|||||||
.{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_strides).as(mlir.Attribute).? },
|
.{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_strides).as(mlir.Attribute).? },
|
||||||
.{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.base_dilations).as(mlir.Attribute).? },
|
.{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.base_dilations).as(mlir.Attribute).? },
|
||||||
.{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dilations).as(mlir.Attribute).? },
|
.{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dilations).as(mlir.Attribute).? },
|
||||||
.{ "padding", mlir.DenseIntOrFPElementsAttribute(.i64).init(pad_shape, std.mem.sliceAsBytes(opts.padding_values)).as(mlir.Attribute).? },
|
.{ "padding", mlir.DenseIntOrFPElementsAttribute(.i64).init(pad_shape, std.mem.sliceAsBytes(opts.padding)).as(mlir.Attribute).? },
|
||||||
},
|
},
|
||||||
.location = loc,
|
.location = loc,
|
||||||
});
|
});
|
||||||
|
|||||||
@ -943,6 +943,39 @@ pub const Shape = struct {
|
|||||||
try testing.expectEqualSlices(Tag, &.{ "a".ptr, "b".ptr }, tags_.constSlice());
|
try testing.expectEqualSlices(Tag, &.{ "a".ptr, "b".ptr }, tags_.constSlice());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Expands per-axis `options` into one `T` value per axis of `self`.
///
/// `options` can be:
/// * a slice of `T` holding exactly `self.rank()` entries, copied in order;
/// * a struct literal whose fields are resolved to axes through `self.axis`
///   (e.g. `.{ .b = 1.2, .a = 0.1 }`); axes that are not mentioned get `default`.
///
/// Any other type for `options` is rejected at compile time.
pub fn parseAxesOptions(self: Shape, T: type, options: anytype, default: T) std.BoundedArray(T, MAX_RANK) {
    const V = @TypeOf(options);

    var res: std.BoundedArray(T, MAX_RANK) = .{};
    if (comptime meta.isSliceOf(V, T)) {
        meta.assert(options.len == self.rank(), "expects exactly {} options in slice, for {} got {}", .{ self.rank(), self, options.len });
        for (options) |d| {
            res.appendAssumeCapacity(d);
        }
        // Return here: otherwise the slice case reaches the compileError at the end of the function.
        return res;
    }

    if (comptime meta.isStruct(V)) {
        // Start from `default` everywhere, then overwrite the axes named in `options`.
        for (0..self.rank()) |_| res.appendAssumeCapacity(default);
        const fields = std.meta.fields(V);
        meta.assertComptime(fields.len <= MAX_RANK, "expects up to {} options struct literal, got {}", .{ MAX_RANK, fields.len });
        inline for (fields) |field| {
            const a = self.axis(field);
            res.buffer[a] = @field(options, field.name);
        }
        return res;
    }

    meta.compileError("parseAxesOptions expects a struct literal or a slice, got {}", .{V});
}
|
||||||
|
|
||||||
|
test parseAxesOptions {
    // Options are given out of axis order and `.c` is left out,
    // so `.c` is expected to receive the default value.
    const abc = Shape.init(.{ .a = 10, .b = 20, .c = 30 }, .u8);
    const per_axis = abc.parseAxesOptions(f32, .{ .b = 1.2, .a = 0.1 }, 1.0);

    const expected = [_]f32{ 0.1, 1.2, 1.0 };
    try testing.expectEqualSlices(f32, &expected, per_axis.constSlice());
}
|
||||||
|
|
||||||
test "comptimeShape" {
|
test "comptimeShape" {
|
||||||
comptime {
|
comptime {
|
||||||
const s = Shape.init(.{ .a = 5, .b = 6, .c = 7 }, .f32);
|
const s = Shape.init(.{ .a = 5, .b = 6, .c = 7 }, .f32);
|
||||||
|
|||||||
180
zml/tensor.zig
180
zml/tensor.zig
@ -1268,6 +1268,7 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the sum of elements over the given axis.
|
/// Returns a Tensor containing the sum of elements over the given axis.
|
||||||
|
/// Output shape is the input shape with the axis_ dim set to 1.
|
||||||
pub fn sum(self: Tensor, axis_: anytype) Tensor {
|
pub fn sum(self: Tensor, axis_: anytype) Tensor {
|
||||||
const a = self.axis(axis_);
|
const a = self.axis(axis_);
|
||||||
return ops.reduce(
|
return ops.reduce(
|
||||||
@ -1283,10 +1284,62 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the mean of elements over the given axis.
|
/// Returns a Tensor containing the mean of elements over the given axis.
|
||||||
|
/// Output shape is the input shape with the axis_ dim set to 1.
|
||||||
pub fn mean(self: Tensor, axis_: anytype) Tensor {
|
pub fn mean(self: Tensor, axis_: anytype) Tensor {
|
||||||
return self.sum(axis_).divByConst(self.dim(axis_));
|
return self.sum(axis_).divByConst(self.dim(axis_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Computes the running sum of the elements along the given axis.
/// The result has the same shape as the input.
/// [0, 1, 0, 1, 0, 0, 1, 1].cumulativeSum(0) -> [0, 1, 1, 2, 2, 2, 3, 4]
/// The last value along the axis is the sum of all the elements along it.
pub fn cumulativeSum(self: Tensor, axis_: anytype) Tensor {
    const rank_ = self.rank();
    const a = self.axis(axis_);
    const axis_len = self.dim(a);

    const unit = [_]i64{1} ** MAX_RANK;

    // The window spans the whole axis; every other axis keeps a window of 1.
    var window = unit;
    window[a] = axis_len;

    // Pad only the low side of the axis, by `dim - 1`.
    var pads = [_][2]i64{.{ 0, 0 }} ** MAX_RANK;
    pads[a] = .{ axis_len - 1, 0 };

    var out = ops.reduceWindow(
        Tensor.add,
        self,
        Tensor.scalar(0, self.dtype()),
        .{
            .window_dimensions = window[0..rank_],
            .window_strides = unit[0..rank_],
            .base_dilations = unit[0..rank_],
            .window_dilations = unit[0..rank_],
            .padding = pads[0..rank_],
        },
    );
    // The result is given the input shape.
    out._shape = self._shape;
    return out;
}
|
||||||
|
|
||||||
|
test cumulativeSum {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();

    const Local = struct {
        pub fn _cumsum(input: Tensor) Tensor {
            const tagged = input.withPartialTags(.{.n});
            return tagged.cumulativeSum(.n);
        }
    };

    const input = [2][5]f32{ .{ 0, 1, 1, 0, 1 }, .{ 3, 1, 0, 2, 1 } };
    // Each expected row is the running sum of the matching input row.
    const expected = [2][5]f32{ .{ 0, 1, 2, 2, 3 }, .{ 3, 4, 4, 6, 7 } };

    const x = try zml.Buffer.fromArray(platform, input);
    const res = try zml.testing.compileAndCall(platform, Local._cumsum, .{x});
    try testing.expectEqual(expected, try res.getValue([2][5]f32));
}
|
||||||
|
|
||||||
/// Returns a transposed Tensor computed using the given axes.
|
/// Returns a transposed Tensor computed using the given axes.
|
||||||
pub fn transpose(self: Tensor, axes_: anytype) Tensor {
|
pub fn transpose(self: Tensor, axes_: anytype) Tensor {
|
||||||
const axes__ = self.axes(axes_).constSlice();
|
const axes__ = self.axes(axes_).constSlice();
|
||||||
@ -1868,47 +1921,45 @@ pub const Tensor = struct {
|
|||||||
return _result(output_shape, reshape_value.result(0));
|
return _result(output_shape, reshape_value.result(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const Pad1dOpts = struct { low: i64, high: i64, interior: i64 = 0 };
|
pub const Pad = struct {
|
||||||
|
low: i32 = 0,
|
||||||
/// Pads the input Tensor with the given value over the given axis.
|
high: i32 = 0,
|
||||||
pub fn pad1d(self: Tensor, axis_: i8, pad_value: anytype, opts: Pad1dOpts) Tensor {
|
interior: i32 = 0,
|
||||||
const ZEROS = [_]i64{0} ** MAX_RANK;
|
};
|
||||||
var padding_low = ZEROS;
|
|
||||||
var padding_high = ZEROS;
|
|
||||||
var padding_interior = ZEROS;
|
|
||||||
const a = self.axis(axis_);
|
|
||||||
padding_low[a] = opts.low;
|
|
||||||
padding_high[a] = opts.high;
|
|
||||||
padding_interior[a] = opts.interior;
|
|
||||||
|
|
||||||
const rk = self.rank();
|
|
||||||
return self.pad(
|
|
||||||
pad_value,
|
|
||||||
.{ .low = padding_low[0..rk], .high = padding_high[0..rk], .interior = padding_interior[0..rk] },
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Pads the input Tensor with the given values.
|
/// Pads the input Tensor with the given values.
|
||||||
pub fn pad(self: Tensor, pad_value: anytype, opts: dialect.stablehlo.PadOpts) Tensor {
|
/// Usage: x.pad(0, .{ .a = .{ .low = 1, .high = 1 }});
|
||||||
meta.assert(opts.low.len == self.rank(), "pad expects tensor rank and 'opts.low' length to be equal, got {} and {}", .{ self.rank(), opts.low.len });
|
pub fn pad(self: Tensor, padding_value: anytype, paddings: anytype) Tensor {
|
||||||
meta.assert(opts.high.len == self.rank(), "pad expects tensor rank and 'opts.high' length to be equal, got {} and {}", .{ self.rank(), opts.high.len });
|
const _paddings = self.shape().parseAxesOptions(Pad, paddings, .{});
|
||||||
|
|
||||||
const ZEROS = [_]i64{0} ** MAX_RANK;
|
const ZEROS = [_]i64{0} ** MAX_RANK;
|
||||||
const interior = opts.interior orelse ZEROS[0..self.rank()];
|
var low = ZEROS;
|
||||||
|
var high = ZEROS;
|
||||||
|
var interior = ZEROS;
|
||||||
|
|
||||||
var new_shape = self._shape;
|
var res_shape = self._shape;
|
||||||
|
for (_paddings.constSlice(), 0..) |padding, i| {
|
||||||
|
low[i] = padding.low;
|
||||||
|
high[i] = padding.high;
|
||||||
|
interior[i] = padding.interior;
|
||||||
|
|
||||||
for (0..self.rank()) |i| {
|
var d: i64 = self.dim(i);
|
||||||
const d = self.dim(i) + opts.low[i] + (@max(self.dim(i) - 1, 0) * interior[i]) + opts.high[i];
|
d += low[i] + (@max(d - 1, 0) * interior[i]) + high[i];
|
||||||
new_shape = new_shape.set(i, d);
|
res_shape._dims.set(i, d);
|
||||||
}
|
}
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "pad(value={}, opts={})", .{ pad_value, opts });
|
const rk = self.rank();
|
||||||
var full_opts = opts;
|
const mlir_ctx = self.getContext().mlirCtx();
|
||||||
full_opts.interior = opts.interior orelse ZEROS[0..self.rank()];
|
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "pad({},{})", .{ padding_value, _paddings });
|
||||||
const pad_value_tensor = Tensor.scalar(pad_value, self.dtype());
|
const pad_op = dialect.stablehlo.pad(
|
||||||
const pad_op = dialect.stablehlo.pad(self.getContext().mlirCtx(), self.value(), pad_value_tensor.value(), full_opts, loc);
|
mlir_ctx,
|
||||||
|
self.value(),
|
||||||
|
Tensor.scalar(padding_value, self.dtype()).value(),
|
||||||
|
.{ .low = low[0..rk], .high = high[0..rk], .interior = interior[0..rk] },
|
||||||
|
loc,
|
||||||
|
);
|
||||||
|
|
||||||
return _result(new_shape, pad_op.result(0));
|
return _result(res_shape, pad_op.result(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Inserts 1-dim axes at the given position, with the given tags.
|
/// Inserts 1-dim axes at the given position, with the given tags.
|
||||||
@ -2760,7 +2811,7 @@ pub const Tensor = struct {
|
|||||||
window_strides: ?i64,
|
window_strides: ?i64,
|
||||||
base_dilations: i64 = 1,
|
base_dilations: i64 = 1,
|
||||||
window_dilations: i64 = 1,
|
window_dilations: i64 = 1,
|
||||||
padding: []const i64 = &.{0},
|
padding: [2]i64 = .{ 0, 0 },
|
||||||
}) MaxPoolRes {
|
}) MaxPoolRes {
|
||||||
// TODO migrate to the following syntax.
|
// TODO migrate to the following syntax.
|
||||||
// maxPool(.{.a = .{ .stride = 5, .dilation = 2, .padding = .{0, 1} },
|
// maxPool(.{.a = .{ .stride = 5, .dilation = 2, .padding = .{0, 1} },
|
||||||
@ -2771,16 +2822,21 @@ pub const Tensor = struct {
|
|||||||
// .padding = .{ .a = .{ 0, 2 }, .b = .{0, 2}
|
// .padding = .{ .a = .{ 0, 2 }, .b = .{0, 2}
|
||||||
// })
|
// })
|
||||||
|
|
||||||
meta.assert(opts.padding.len == 1 or opts.padding.len == 2 * self.rank(), "maxPool1d expects 'opts.padding' length to be a single integer or to be equal to the double of input tensor rank, got {} (input tensor rank is {})", .{ opts.padding.len, self.rank() });
|
|
||||||
|
|
||||||
// Note: the problem is initPoolArg assuming last axis
|
|
||||||
// TODO: support maxPool on non last axis
|
// TODO: support maxPool on non last axis
|
||||||
const a = self.axis(-1);
|
const a = self.axis(-1);
|
||||||
|
const ones = [_]i64{1} ** Tensor.MAX_RANK;
|
||||||
|
var window_dimensions = ones;
|
||||||
|
window_dimensions[a] = opts.window_dimensions;
|
||||||
|
var window_strides = window_dimensions;
|
||||||
|
if (opts.window_strides) |stride| window_strides[a] = stride;
|
||||||
|
|
||||||
const window_dimensions = initPoolArg(self.rank(), &.{opts.window_dimensions});
|
var base_dilations = ones;
|
||||||
const window_strides = if (opts.window_strides) |ws| initPoolArg(self.rank(), &.{ws}) else window_dimensions;
|
base_dilations[a] = opts.base_dilations;
|
||||||
const base_dilation = initPoolArg(self.rank(), &.{opts.base_dilations});
|
var window_dilations = ones;
|
||||||
const window_dilations = initPoolArg(self.rank(), &.{opts.window_dilations});
|
window_dilations[a] = opts.window_dilations;
|
||||||
|
|
||||||
|
var padding = [_][2]i64{.{ 0, 0 }} ** Tensor.MAX_RANK;
|
||||||
|
padding[a] = opts.padding;
|
||||||
|
|
||||||
return ops.reduceWindow(
|
return ops.reduceWindow(
|
||||||
MaxPoolRes.cmp,
|
MaxPoolRes.cmp,
|
||||||
@ -2789,38 +2845,36 @@ pub const Tensor = struct {
|
|||||||
.{
|
.{
|
||||||
.window_dimensions = window_dimensions[0..self.rank()],
|
.window_dimensions = window_dimensions[0..self.rank()],
|
||||||
.window_strides = window_strides[0..self.rank()],
|
.window_strides = window_strides[0..self.rank()],
|
||||||
.base_dilations = base_dilation[0..self.rank()],
|
.base_dilations = base_dilations[0..self.rank()],
|
||||||
.window_dilations = window_dilations[0..self.rank()],
|
.window_dilations = window_dilations[0..self.rank()],
|
||||||
.padding_values = opts.padding,
|
.padding = padding[0..self.rank()],
|
||||||
.padding_shape = &.{ @intCast(self.rank()), 2 },
|
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Computes the 2d maxPool operation on the input Tensor.
|
/// Computes the 2d maxPool operation on the input Tensor.
|
||||||
pub fn maxPool2d(self: Tensor, opts: struct {
|
pub fn maxPool2d(self: Tensor, opts: struct {
|
||||||
window_dimensions: []const i64,
|
window_dimensions: [2]i64,
|
||||||
window_strides: ?[]const i64 = null,
|
window_strides: ?[2]i64 = null,
|
||||||
base_dilations: []const i64 = &.{ 1, 1 },
|
base_dilations: [2]i64 = .{ 1, 1 },
|
||||||
window_dilations: []const i64 = &.{ 1, 1 },
|
window_dilations: [2]i64 = .{ 1, 1 },
|
||||||
padding: []const i64 = &.{0},
|
padding: [2][2]i64 = .{ .{ 0, 0 }, .{ 0, 0 } },
|
||||||
}) MaxPoolRes {
|
}) MaxPoolRes {
|
||||||
// TODO: rewrite using modern ZML, add ops.reduceWindow
|
// TODO: rewrite using modern ZML
|
||||||
meta.guard(self.rank() == 3 or self.rank() == 4, @src());
|
meta.guard(self.rank() == 3 or self.rank() == 4, @src());
|
||||||
meta.guard(opts.window_dimensions.len == 2, @src());
|
|
||||||
meta.guard(opts.window_strides == null or opts.window_strides.?.len == 2, @src());
|
|
||||||
meta.guard(opts.base_dilations.len == 2, @src());
|
|
||||||
meta.guard(opts.window_dilations.len == 2, @src());
|
|
||||||
meta.assert(opts.padding.len == 1 or opts.padding.len == 2 * self.rank(), "Padding needs to either be a single integer, or to be 2x time the number of input rank. In maxPool({}, .padding={d})", .{ self, opts.padding });
|
|
||||||
|
|
||||||
// TODO: support maxPool on non last axis
|
// TODO: support maxPool on non last axis
|
||||||
// Note: the problem is initPoolArg assuming last axis
|
// Note: the problem is initPoolArg assuming last axis
|
||||||
const a = self.axis(-1);
|
const a = self.axis(-1);
|
||||||
|
|
||||||
const window_dimensions = initPoolArg(self.rank(), opts.window_dimensions);
|
const window_dimensions = initPoolArg(self.rank(), &opts.window_dimensions);
|
||||||
const window_strides = if (opts.window_strides) |ws| initPoolArg(self.rank(), ws) else window_dimensions;
|
const window_strides = if (opts.window_strides) |ws| initPoolArg(self.rank(), &ws) else window_dimensions;
|
||||||
const base_dilation = initPoolArg(self.rank(), opts.base_dilations);
|
const base_dilation = initPoolArg(self.rank(), &opts.base_dilations);
|
||||||
const window_dilations = initPoolArg(self.rank(), opts.window_dilations);
|
const window_dilations = initPoolArg(self.rank(), &opts.window_dilations);
|
||||||
|
|
||||||
|
var padding = [_][2]i64{.{ 0, 0 }} ** Tensor.MAX_RANK;
|
||||||
|
padding[a - 1] = opts.padding[0];
|
||||||
|
padding[a] = opts.padding[1];
|
||||||
|
|
||||||
return ops.reduceWindow(
|
return ops.reduceWindow(
|
||||||
MaxPoolRes.cmp,
|
MaxPoolRes.cmp,
|
||||||
@ -2831,8 +2885,7 @@ pub const Tensor = struct {
|
|||||||
.window_strides = window_strides[0..self.rank()],
|
.window_strides = window_strides[0..self.rank()],
|
||||||
.base_dilations = base_dilation[0..self.rank()],
|
.base_dilations = base_dilation[0..self.rank()],
|
||||||
.window_dilations = window_dilations[0..self.rank()],
|
.window_dilations = window_dilations[0..self.rank()],
|
||||||
.padding_values = opts.padding,
|
.padding = padding[0..self.rank()],
|
||||||
.padding_shape = &.{ @intCast(self.rank()), 2 },
|
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -3086,6 +3139,7 @@ pub const Tensor = struct {
|
|||||||
/// Tensor(.{ .a = 2, .b = 5 }).dynamicUpdateSlice(.{ .a = scalar(1, .i32) }, Tensor(.{ .b = 5 }));
|
/// Tensor(.{ .a = 2, .b = 5 }).dynamicUpdateSlice(.{ .a = scalar(1, .i32) }, Tensor(.{ .b = 5 }));
|
||||||
/// ```
|
/// ```
|
||||||
pub fn dynamicUpdateSlice(self: Tensor, offset_: anytype, update_: Tensor) Tensor {
|
pub fn dynamicUpdateSlice(self: Tensor, offset_: anytype, update_: Tensor) Tensor {
|
||||||
|
// TODO: add updateSlice for when the offset isn't dynamic
|
||||||
meta.assert(self.dtype() == update_.dtype(), "dynamicUpdateSlice expects input and 'update_' tensors to be of the same type, got {} and {}", .{ self.dtype(), update_.dtype() });
|
meta.assert(self.dtype() == update_.dtype(), "dynamicUpdateSlice expects input and 'update_' tensors to be of the same type, got {} and {}", .{ self.dtype(), update_.dtype() });
|
||||||
|
|
||||||
const offset, const offset_tags = Shape.parseStruct(Tensor, offset_);
|
const offset, const offset_tags = Shape.parseStruct(Tensor, offset_);
|
||||||
@ -3559,8 +3613,8 @@ test "Tensor.maxPool2d" {
|
|||||||
const MaxPool = struct {
|
const MaxPool = struct {
|
||||||
pub fn forward(x: Tensor) Tensor.ArgMaxRes {
|
pub fn forward(x: Tensor) Tensor.ArgMaxRes {
|
||||||
return x.maxPool2d(.{
|
return x.maxPool2d(.{
|
||||||
.window_dimensions = &.{ 3, 2 },
|
.window_dimensions = .{ 3, 2 },
|
||||||
.window_strides = &.{ 2, 1 },
|
.window_strides = .{ 2, 1 },
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user