//! Radix/zml/shape.zig
//!
//! Describes the shape of a tensor: its element dtype, its dimensions,
//! an optional tag per axis and a per-axis sharding flag.

const builtin = @import("builtin");
const std = @import("std");
const testing = std.testing;
const meta = @import("meta.zig");
const DataType = @import("dtype.zig").DataType;
const EnumLiteral = @TypeOf(.enum_literal);
const log = std.log.scoped(.shape);
test {
    // Reference every public declaration of `Shape` so the semantic analyzer
    // checks them even when nothing else in the test build uses them.
    testing.refAllDecls(Shape);
}
/// Represents the shape of a tensor: an element dtype, up to `MAX_RANK`
/// dimensions, an optional tag per axis and a sharding flag per axis.
///
/// Axes can be designated either by integer index (negative values count
/// from the end) or by tag (an enum literal like `.a`, a `Tag`, or a
/// `std.builtin.Type.StructField`).
pub const Shape = struct {
    pub const MAX_RANK: u8 = 8;
    /// Axis tag: a null-terminated name. At runtime tags are compared by
    /// pointer (see `axisFromTagMaybe`), at comptime by content.
    pub const Tag = [*:0]const u8;
    /// Placeholder tag for axes that have not been tagged.
    pub const TagUnknown = "_".ptr;
    const TagLast = "last".ptr;
    pub const DimsArray = std.BoundedArray(i64, MAX_RANK);
    pub const TagsArray = std.BoundedArray(Tag, MAX_RANK);
    pub const AxesArray = std.BoundedArray(u3, MAX_RANK);
    /// One flag per axis, set by `withSharding`.
    pub const ShardingInfo = @Vector(MAX_RANK, bool);

    const UnknownTags: TagsArray = .{ .len = 0, .buffer = [_]Tag{TagUnknown} ** MAX_RANK };

    _dtype: DataType,
    _dims: DimsArray = .{},
    _tags: TagsArray = UnknownTags,
    _sharding_info: ShardingInfo = @splat(false),

    /// Parses a dimension description into parallel arrays of dims and tags.
    ///
    /// Accepted inputs:
    /// - a `Shape`: its dims and tags are returned unchanged;
    /// - a slice of integers: untagged dims;
    /// - a tuple such as `.{ 10, 20 }`: untagged dims;
    /// - a struct such as `.{ .a = 10, .b = 20 }`: dims tagged with the field names.
    /// A tuple/struct field may also be `.auto`, which is stored as `-1` so that
    /// `reshape` / `splitAxis` can infer it.
    pub fn parseDimensions(v: anytype) struct { DimsArray, TagsArray } {
        const T = @TypeOf(v);
        if (T == Shape) {
            return .{ v._dims, v._tags };
        }

        if (comptime meta.isSliceOfAny(T, meta.isInteger)) {
            // `appendAssumeCapacity` below must never overflow the bounded arrays.
            meta.assert(v.len <= MAX_RANK, "Too many dimensions! Max: {d}, passed: {d}", .{ MAX_RANK, v.len });
            var dims_ = DimsArray.init(0) catch unreachable;
            var tags_ = TagsArray.init(0) catch unreachable;
            for (v) |d| {
                dims_.appendAssumeCapacity(@intCast(d));
                tags_.appendAssumeCapacity(TagUnknown);
            }
            return .{ dims_, tags_ };
        }

        if (comptime meta.isStruct(T)) {
            const fields = std.meta.fields(T);
            meta.assertComptime(fields.len <= MAX_RANK, "Too many dimensions in {} ({d}). Max supported is {d}.", .{ T, fields.len, MAX_RANK });
            var dims_: DimsArray = .{};
            var tags_: TagsArray = .{};
            inline for (fields) |field| {
                const fv = @field(v, field.name);
                if (comptime meta.isInteger(field.type)) {
                    dims_.appendAssumeCapacity(@intCast(fv));
                } else if (comptime isAutoDim(fv)) {
                    dims_.appendAssumeCapacity(-1);
                } else {
                    meta.compileError("Field {s} should be an integer or an auto dimension", .{field.name});
                }
                // Tuples have positional field names ("0", "1", ...), which are not tags.
                if (comptime meta.isTuple(T)) {
                    tags_.appendAssumeCapacity(TagUnknown);
                } else {
                    tags_.appendAssumeCapacity(toTag(field));
                }
            }
            return .{ dims_, tags_ };
        }

        meta.compileError("expected a dimension tuple eg '.{{ .a = 10, .b = 20}}' or '.{{ 10, 20 }}', got {}", .{T});
    }

    test parseDimensions {
        const ref = Shape.init(.{ .a = 0, .b = 1 }, .f32);
        const dims_, const tags_ = parseDimensions(.{ .a = 0, .b = 1 });
        try testing.expectEqualSlices(i64, ref.dims(), dims_.constSlice());
        try testing.expectEqualSlices(Tag, ref.tags(), tags_.constSlice());
    }

    /// Resolves a list of axes (integers and/or tags) against `self`, returning
    /// the axis indices and the tags `self` has on those axes.
    /// Passing a `Shape` resolves that shape's tags against `self`.
    pub fn parseAxes(self: Shape, v: anytype) struct { AxesArray, TagsArray } {
        const T = @TypeOf(v);
        if (T == Shape) {
            // Resolve the tags of the *argument* against `self`.
            return self.parseAxes(v.tags());
        }

        var axes_ = AxesArray.init(0) catch unreachable;
        var tags_ = TagsArray.init(0) catch unreachable;
        if (comptime meta.isSliceOfAny(T, isAxisConvertible)) {
            for (v) |d| {
                axes_.appendAssumeCapacity(self.axis(d));
                tags_.appendAssumeCapacity(self.tag(d));
            }
            return .{ axes_, tags_ };
        }

        if (comptime meta.isTupleOfAny(T, isAxisConvertible)) {
            inline for (std.meta.fields(T)) |field| {
                axes_.appendAssumeCapacity(self.axis(@field(v, field.name)));
                tags_.appendAssumeCapacity(self.tag(@field(v, field.name)));
            }
            return .{ axes_, tags_ };
        }

        meta.compileError("Wrong type, got {}. Expected .{{.a, .b}}", .{T});
    }

    test parseAxes {
        const abc = Shape.init(.{ .a = 10, .b = 11, .c = 12 }, .f32);
        const axes_, const tags_ = abc.parseAxes(Shape.init(.{ .c = 1, .a = 2 }, .f32));
        try testing.expectEqualSlices(u3, &.{ 2, 0 }, axes_.constSlice());
        try testing.expectEqualSlices(Tag, &.{ toTag(.c), toTag(.a) }, tags_.constSlice());
    }

    /// Converts a tuple of enum literals such as `.{ .a, .b }` into a `TagsArray`.
    pub fn parseTags(v: anytype) TagsArray {
        const T = @TypeOf(v);
        meta.assertComptime(meta.isTupleOf(T, EnumLiteral), "Wrong type, got {}. Expected .{{ .a, .b }}", .{T});
        var tags_ = TagsArray.init(0) catch unreachable;
        inline for (v) |field| {
            tags_.appendAssumeCapacity(toTag(field));
        }
        return tags_;
    }

    /// Creates a shape from a struct or tuple literal, eg:
    /// Shape.init(.{ .h = 1024, .w = 512, .c = 3 }, .f32);
    /// Shape.init(.{ 1024, 512, 3 }, .f32);
    pub fn init(dimz: anytype, dt: DataType) Shape {
        var res: Shape = .{ ._dtype = dt };
        res._dims, res._tags = parseDimensions(dimz);
        return res;
    }

    /// Creates an untagged Shape with dims set to `.{0, 1, 2, ..., rank-1}`.
    /// Panics if `rank_` exceeds `MAX_RANK`.
    pub fn range(rank_: usize, dt: DataType) Shape {
        var res: Shape = .{ ._dtype = dt };
        for (0..rank_) |i| {
            res._dims.append(@intCast(i)) catch {
                meta.panic("Too many dimensions! Max: {d}, passed: {d}", .{ res._dims.capacity(), rank_ });
            };
            res._tags.append(TagUnknown) catch unreachable;
        }
        return res;
    }

    /// Returns the element type.
    pub fn dtype(self: Shape) DataType {
        return self._dtype;
    }

    /// Returns the number of axes.
    pub fn rank(self: Shape) u4 {
        self.ensureDimsAndTagsAreSync();
        return self._dims.len;
    }

    /// Returns the dimension of the given axis (integer or tag).
    pub fn dim(self: Shape, ax: anytype) i64 {
        self.ensureDimsAndTagsAreSync();
        return self._dims.get(self.axis(ax));
    }

    /// Returns all dimensions. The slice points into `self`.
    pub fn dims(self: *const Shape) []const i64 {
        self.ensureDimsAndTagsAreSync();
        return self._dims.constSlice();
    }

    /// True for types that `axis` accepts: integers and tag-like values.
    fn isAxisConvertible(comptime T: type) bool {
        return meta.isInteger(T) or isTagConvertible(T);
    }

    /// True for types that `toTag` accepts.
    fn isTagConvertible(comptime T: type) bool {
        return switch (T) {
            EnumLiteral => true,
            std.builtin.Type.StructField => true,
            Tag => true,
            else => false,
        };
    }

    /// Converts an enum literal, a struct field or a `Tag` into a `Tag`.
    fn toTag(v: anytype) Tag {
        const T = @TypeOf(v);
        return switch (T) {
            EnumLiteral => @tagName(v).ptr,
            std.builtin.Type.StructField => v.name.ptr,
            Tag => v,
            else => meta.compileError("Value should be an EnumLiteral, a Shape.Tag or a StructField, got {}", .{T}),
        };
    }

    /// Asserts the invariant that there is exactly one tag per dimension.
    inline fn ensureDimsAndTagsAreSync(self: Shape) void {
        meta.assert(self._dims.len == self._tags.len, "Tags and dims have diverged! dims={d} tags={d}", .{ self._dims.len, self._tags.len });
    }

    /// Returns the tag of the given axis (may be `TagUnknown`).
    pub fn tag(self: Shape, ax: anytype) Tag {
        self.ensureDimsAndTagsAreSync();
        return self._tags.get(self.axis(ax));
    }

    /// Returns a printable name for a given axis.
    /// Either the tag itself, or a digit if it's not tagged.
    pub fn debugTag(self: Shape, ax: usize) []const u8 {
        const t = self.tag(ax);
        if (t != TagUnknown) return std.mem.span(t);
        return "01234567"[ax .. ax + 1];
    }

    /// Returns a copy of `self` where axis `ax` is tagged with `tag_`.
    pub fn setTag(self: Shape, ax: anytype, tag_: anytype) Shape {
        var res = self;
        res._tags.set(self.axis(ax), toTag(tag_));
        return res;
    }

    /// Returns all tags. The slice points into `self`.
    pub fn tags(self: *const Shape) []const Tag {
        self.ensureDimsAndTagsAreSync();
        return self._tags.constSlice();
    }

    /// Returns the axis carrying `tag_`, or null if there is none.
    pub fn hasTag(self: Shape, tag_: anytype) ?u3 {
        return self.axisFromTagMaybe(toTag(tag_));
    }

    /// Returns true if every tag of the given tuple or slice is present in `self`.
    pub fn hasTags(self: Shape, tagz: anytype) bool {
        const T = @TypeOf(tagz);

        if (comptime meta.isSliceOf(T, Tag) or meta.isSliceOf(T, EnumLiteral)) {
            for (tagz) |t| {
                if (self.hasTag(t) == null) {
                    return false;
                }
            }
            return true;
        }

        if (comptime meta.isTupleOf(T, Tag) or meta.isTupleOf(T, EnumLiteral)) {
            inline for (tagz) |t| {
                if (self.hasTag(t) == null) {
                    return false;
                }
            }
            return true;
        }

        meta.compileError("Expected tuple of tags, got {any}", .{T});
    }

    /// Returns true if no axis is tagged `TagUnknown`.
    pub fn isFullyTagged(self: Shape) bool {
        for (self._tags.constSlice()) |t| {
            if (t == TagUnknown) return false;
        }
        return true;
    }

    /// Resolves an axis given as an integer (negative counts from the end)
    /// or as a tag into an axis index. Panics if the axis doesn't exist.
    pub fn axis(self: Shape, axis_: anytype) u3 {
        self.ensureDimsAndTagsAreSync();

        const T = @TypeOf(axis_);
        if (comptime meta.isInteger(T)) {
            return self.axisFromInt(@intCast(axis_));
        }

        if (comptime isTagConvertible(T)) {
            return self.axisFromTag(toTag(axis_));
        }

        meta.compileError("Wrong axis type, expected .literal, or an integer, got: {any}", .{T});
    }

    /// Resolves several axes at once. Accepts a `Shape` (its tags are used),
    /// a slice of integers or tags, or a tuple of integers / tags.
    pub fn axes(self: Shape, axes_: anytype) AxesArray {
        self.ensureDimsAndTagsAreSync();

        const T = @TypeOf(axes_);
        if (T == Shape) {
            return self.axes(axes_.tags());
        }

        var res = AxesArray.init(0) catch unreachable;

        if (comptime meta.isSliceOfAny(T, meta.isInteger) or meta.isSliceOf(T, Tag)) {
            for (axes_) |ax| {
                res.appendAssumeCapacity(self.axis(ax));
            }
            return res;
        }

        if (comptime meta.isStruct(T)) {
            inline for (std.meta.fields(T)) |field| {
                res.appendAssumeCapacity(self.axis(@field(axes_, field.name)));
            }
            return res;
        }

        meta.compileError("axes expects an int-tuple or a tuple of enum literal, got {}", .{T});
    }

    /// Converts a possibly negative integer axis into an axis index.
    /// Negative values count from the end (`-1` is the last axis).
    /// NOTE(review): `d == rank` is accepted, presumably so `insert(rank, ...)`
    /// can append at the end.
    fn axisFromInt(self: Shape, d: isize) u3 {
        const rk: i8 = self.rank();
        if (d < -rk or d > rk) {
            meta.panic("Tensor {} doesn't have dimension: {d}", .{ self, d });
        }
        return if (d < 0)
            @intCast(d + rk)
        else
            @intCast(d);
    }

    /// Returns the axis tagged `d`, or null. `TagUnknown` never matches.
    fn axisFromTagMaybe(self: Shape, d: Tag) ?u3 {
        if (d == TagUnknown) {
            return null;
        }
        // At comptime, compare tag contents rather than pointers.
        if (@inComptime()) {
            for (0.., self.tags()) |tagIndex, t| {
                const a: []const u8 = std.mem.span(t);
                const b: []const u8 = std.mem.span(d);
                if (std.mem.eql(u8, a, b)) {
                    return @intCast(tagIndex);
                }
            }
            return null;
        }
        if (std.mem.indexOfScalar(Tag, self.tags(), d)) |d_| {
            return @intCast(d_);
        }
        return null;
    }

    /// Returns the axis tagged `d`; panics if there is none or if `d` is `TagUnknown`.
    fn axisFromTag(self: Shape, d: Tag) u3 {
        meta.assert(d != TagUnknown, "The unknown tag .{s} can't be used to fetch axis in {}", .{ d, self });
        return self.axisFromTagMaybe(d) orelse {
            meta.panic("Tensor {} doesn't have dimension with tag: {s}", .{ self, d });
        };
    }

    test axis {
        try testing.expectEqual(1, Shape.init(.{ 5, 2 }, .f32).axis(1));
        try testing.expectEqual(1, Shape.init(.{ 5, 2 }, .f32).axis(-1));
        try testing.expectEqual(1, Shape.init(.{ .a = 5, .b = 2 }, .f32).axis(.b));
    }

    /// The number of elements inside a Tensor described by this shape.
    /// Asserts that no dimension is negative.
    pub fn count(self: Shape) usize {
        var res: i64 = 1;
        for (self.dims()) |d| {
            meta.assert(d >= 0, "Can't count elements in shape with negative dimension: {}", .{self});
            res *= d;
        }
        return @intCast(res);
    }

    /// Total size in bytes needed to represent this shape.
    pub fn byteSize(self: Shape) usize {
        return self.dtype().sizeOf() * self.count();
    }

    /// Compares dims and dtype, ignoring tags.
    pub fn eql(self: Shape, other: Shape) bool {
        return std.mem.eql(i64, self.dims(), other.dims()) and self.dtype() == other.dtype();
    }

    /// Compares dims only, ignoring tags and dtype.
    pub fn eqlDims(self: Shape, other: Shape) bool {
        return std.mem.eql(i64, self.dims(), other.dims());
    }

    /// Compares dims, dtype and tags.
    pub fn eqlWithTags(self: Shape, other: Shape) bool {
        // `eql` already compares the dtypes.
        return self.eql(other) and std.mem.eql(Tag, self.tags(), other.tags());
    }

    /// Format the shape.
    /// Default format: "Shape({.a=10, .b=20}, dtype=.f32)"
    /// Bare format {_}: "{.a=10, .b=20}, dtype=.f32"
    /// Sharded axes are followed by '!'.
    pub fn format(
        self: Shape,
        comptime fmt: []const u8,
        options: std.fmt.FormatOptions,
        writer: anytype,
    ) !void {
        _ = options;
        const bare_fmt = fmt.len == 1 and fmt[0] == '_';
        // `writeAll`, unlike `write`, never silently drops part of the output.
        try writer.writeAll(if (bare_fmt) "{" else "Shape({");
        for (self.dims(), self.tags(), 0..) |d, t, i| {
            const prefix = if (i == 0) "" else ",";
            if (t != TagUnknown) {
                try writer.print("{s}.{s}={d}", .{ prefix, t, d });
            } else {
                try writer.print("{s}{d}", .{ prefix, d });
            }
            if (self._sharding_info[i]) {
                try writer.writeByte('!');
            }
        }
        try writer.print("}}, dtype=.{s}", .{@tagName(self.dtype())});
        if (!bare_fmt) try writer.writeAll(")");
    }

    /// Returns a shape with the same dtype and the dims/tags described by
    /// `new_shape_` (see `parseDimensions`). At most one dim may be `-1` /
    /// `.auto`; it is inferred from the element count.
    /// Asserts the element count is unchanged.
    pub fn reshape(self: Shape, new_shape_: anytype) Shape {
        var new_shape: Shape = .{ ._dtype = self.dtype() };
        new_shape._dims, new_shape._tags = parseDimensions(new_shape_);
        new_shape.inferMissingAxis(self.count());
        meta.assert(self.count() == new_shape.count(), "Can't reshape {d} to {d}", .{ self.dims(), new_shape.dims() });
        return new_shape;
    }

    /// Replaces a single `-1` dim with the value that makes the element count
    /// equal to `n_`. Leaves the shape untouched if `n_` isn't divisible.
    fn inferMissingAxis(self: *Shape, n_: usize) void {
        meta.assert(std.mem.count(i64, self.dims(), &.{-1}) < 2, "Cannot infer multiple dimensions when reshaping to: {}", .{self.*});

        const inferred_ax = std.mem.indexOfScalar(i64, self.dims(), -1) orelse return;
        // We can't use `self.count()` yet cause we have negative dims.
        var tmp_count: i64 = 1;
        for (self.dims()) |d| {
            if (d > 0) {
                tmp_count *= d;
            }
        }
        const n: i64 = @intCast(n_);
        // Abort, `reshape` will panic with more context.
        if (@mod(n, tmp_count) != 0) {
            return;
        }

        self._dims.set(inferred_ax, @divExact(n, tmp_count));
    }

    test reshape {
        const x = Shape.init(.{ 2, 5, 3 }, .f32);
        {
            const res = x.reshape(.{ .auto, 3 });
            try testing.expectEqualSlices(i64, &.{ 10, 3 }, res.dims());
        }
        {
            const res = x.reshape(.{ 10, -1 });
            try testing.expectEqualSlices(i64, &.{ 10, 3 }, res.dims());
        }
        {
            const res = x.reshape(.{-1});
            try testing.expectEqualSlices(i64, &.{30}, res.dims());
        }
    }

    /// Returns a copy of `self` where axis `ax` has dimension `d`.
    pub fn setDim(self: Shape, ax: anytype, d: i64) Shape {
        var res = self;
        res._dims.set(self.axis(ax), d);
        return res;
    }

    pub const set = setDim;

    fn isAutoDim(v: anytype) bool {
        return toTag(v) == toTag(.auto);
    }

    fn isDynDim(v: anytype) bool {
        return toTag(v) == toTag(.dyn);
    }

    /// Returns the sharding flags of `self` after inserting `n` unsharded axes
    /// before axis `ax`: flags of axes at or after `ax` move `n` positions right.
    fn shardingWithInserted(self: Shape, ax: usize, n: usize) ShardingInfo {
        const old: [MAX_RANK]bool = self._sharding_info;
        var out: [MAX_RANK]bool = [_]bool{false} ** MAX_RANK;
        for (old[0..self.rank()], 0..) |sharded, i| {
            const j = if (i < ax) i else i + n;
            if (j < MAX_RANK) out[j] = sharded;
        }
        return out;
    }

    /// Returns the sharding flags of `self` after removing axis `ax`:
    /// flags of the axes after `ax` move one position left.
    fn shardingWithRemoved(self: Shape, ax: usize) ShardingInfo {
        const old: [MAX_RANK]bool = self._sharding_info;
        var out: [MAX_RANK]bool = [_]bool{false} ** MAX_RANK;
        for (old[0..self.rank()], 0..) |sharded, i| {
            if (i == ax) continue;
            out[if (i < ax) i else i - 1] = sharded;
        }
        return out;
    }

    /// Inserts one or more axes with the given dimensions, before the given axis.
    /// Negative axis is interpreted wrt the current shape.
    /// `.last` axis can be used to insert at the end (ie to append).
    /// ```
    /// Shape.init(.{ 10, 11, 12 }, .f32).insert(1, .{2})     -> .{ 10, 2, 11, 12 }
    /// Shape.init(.{ 10, 11, 12 }, .f32).insert(-1, .{2})    -> .{ 10, 11, 2, 12 }
    /// Shape.init(.{ 10, 11, 12 }, .f32).insert(.last, .{2}) -> .{ 10, 11, 12, 2 }
    /// ```
    /// The inserted axes are not sharded; existing sharding flags follow their axes.
    pub fn insert(self: Shape, axis_: anytype, dimz: anytype) Shape {
        const dims_, const tags_ = parseDimensions(dimz);
        // Widen to usize: u4 + u4 could overflow.
        meta.assert(@as(usize, self.rank()) + dims_.len <= MAX_RANK, "Can't insert {d} axes in {}, max rank is {d}.", .{ dims_.len, self, MAX_RANK });
        const ax = if (@TypeOf(axis_) == EnumLiteral and axis_ == .last)
            self.rank()
        else
            self.axis(axis_);

        var res = self;
        res._dims.insertSlice(ax, dims_.constSlice()) catch unreachable;
        res._tags.insertSlice(ax, tags_.constSlice()) catch unreachable;
        res._sharding_info = self.shardingWithInserted(ax, dims_.len);
        return res;
    }

    test insert {
        try testing.expectEqualSlices(i64, &.{ 10, 1, 11, 12 }, Shape.init(.{ 10, 11, 12 }, .f32).insert(1, .{1}).dims());
        try testing.expectEqualSlices(i64, &.{ 10, 11, 12, 1, 13 }, Shape.init(.{ 10, 11, 12, 13 }, .f32).insert(-1, .{1}).dims());
        try testing.expectEqualSlices(i64, &.{ 10, 11, 12, 13, 1 }, Shape.init(.{ 10, 11, 12, 13 }, .f32).insert(.last, .{1}).dims());

        const sharded = Shape.init(.{ .a = 10, .b = 11 }, .f32).withSharding(.{.b}).insert(0, .{ .x = 1 });
        try testing.expect(!sharded._sharding_info[1]);
        try testing.expect(sharded._sharding_info[2]);
    }

    /// Inserts a single axis of dimension `d` tagged `tag_` before `axis_`
    /// (`.last` appends). The new axis is not sharded.
    pub fn insertTag(self: Shape, axis_: anytype, d: i64, tag_: anytype) Shape {
        // A shape of rank MAX_RANK - 1 can still take one more axis.
        meta.assert(self.rank() < MAX_RANK, "Can't insert new axis in {}, it's already at max rank.", .{self});

        const ax = if (@TypeOf(axis_) == EnumLiteral and axis_ == .last)
            self.rank()
        else
            self.axis(axis_);

        var res = self;
        res._dims.insert(ax, d) catch unreachable;
        res._tags.insert(ax, toTag(tag_)) catch unreachable;
        res._sharding_info = self.shardingWithInserted(ax, 1);
        return res;
    }

    test insertTag {
        const s = Shape.init(.{ 1, 2, 3, 4, 5, 6, 7 }, .f32).insertTag(.last, 8, .h);
        try testing.expectEqual(8, s.rank());
        try testing.expectEqual(7, s.axis(.h));
    }

    /// Appends the axes described by `v` (see `parseDimensions`) at the end.
    pub fn append(self: Shape, v: anytype) Shape {
        var res = self;
        const dims_, const tags_ = parseDimensions(v);
        meta.assert(@as(usize, self.rank()) + dims_.len <= MAX_RANK, "Can't append {d} axes to {}, max rank is {d}.", .{ dims_.len, self, MAX_RANK });
        res._dims.appendSliceAssumeCapacity(dims_.constSlice());
        res._tags.appendSliceAssumeCapacity(tags_.constSlice());
        return res;
    }

    test append {
        try testing.expectEqualSlices(
            i64,
            &.{ 10, 11, 12, 1 },
            Shape.init(.{ 10, 11, 12 }, .f32).append(.{1}).dims(),
        );
        try testing.expect(
            Shape.init(.{ .a = 10, .b = 11, .c = 12 }, .f32).eqlWithTags(
                Shape.init(.{ .a = 10, .b = 11 }, .f32).append(.{ .c = 12 }),
            ),
        );
    }

    /// Appends a single axis of dimension `d`, tagged `tag_` (or untagged if null).
    pub fn appendDim(self: Shape, d: i64, tag_: ?Tag) Shape {
        meta.assert(self.rank() < MAX_RANK, "Can't append a new axis to {}, it's already at max rank.", .{self});
        var res = self;
        res._dims.appendAssumeCapacity(d);
        res._tags.appendAssumeCapacity(if (tag_) |t| t else TagUnknown);
        return res;
    }

    /// Returns a copy of `self` without the given axis; the sharding flags of
    /// the following axes move along with them.
    pub fn remove(self: Shape, axis_: anytype) Shape {
        const a = self.axis(axis_);
        var res = self;
        _ = res._dims.orderedRemove(a);
        _ = res._tags.orderedRemove(a);
        res._sharding_info = self.shardingWithRemoved(a);
        return res;
    }

    pub const drop = remove;

    test remove {
        try std.testing.expectEqualSlices(i64, &.{ 10, 12 }, Shape.init(.{ 10, 11, 12 }, .f32).remove(1).dims());
        try std.testing.expectEqualSlices(i64, &.{ 10, 11, 12 }, Shape.init(.{ 10, 11, 12, 13 }, .f32).remove(-1).dims());

        const sharded = Shape.init(.{ .a = 10, .b = 11, .c = 12 }, .f32).withSharding(.{.c}).remove(.a);
        try testing.expect(sharded._sharding_info[1]);
        try testing.expect(!sharded._sharding_info[2]);
    }

    /// Reorders the axes: axis `i` of the result is axis `permutations[i]` of
    /// `self`. Dims, tags and sharding flags are all permuted.
    pub fn transpose(self: Shape, permutations: anytype) Shape {
        std.debug.assert(self.rank() == permutations.len);
        const permutations_ = self.axes(permutations);

        var res = self;
        for (permutations_.constSlice(), 0..) |permutation, i| {
            res._dims.set(i, self.dim(permutation));
            res._tags.set(i, self.tag(permutation));
            res._sharding_info[i] = self._sharding_info[permutation];
        }
        return res;
    }

    test transpose {
        try testing.expect(
            Shape.init(.{ 12, 11, 10 }, .f32).eql(
                Shape.init(.{ 10, 11, 12 }, .f32).transpose(.{ 2, 1, 0 }),
            ),
        );
        try testing.expect(
            Shape.init(.{ .a = 10, .c = 12, .b = 11, .d = 13 }, .f32).eqlWithTags(
                Shape.init(.{ .a = 10, .b = 11, .c = 12, .d = 13 }, .f32).transpose(.{ 0, 2, 1, 3 }),
            ),
        );
        {
            const sharded = Shape.init(.{ .a = 10, .b = 11 }, .f32).withSharding(.{.a}).transpose(.{ .b, .a });
            try testing.expect(!sharded._sharding_info[0]);
            try testing.expect(sharded._sharding_info[1]);
        }
    }

    /// Tag each ax of this shape with tags from a tuple, a slice or another `Shape`.
    /// The number of tags must equal the rank.
    pub fn withTags(self: Shape, tagz: anytype) Shape {
        const T = @TypeOf(tagz);

        if (T == Shape) {
            return self.withTags(tagz.tags());
        }

        var res = self;

        if (comptime meta.isSliceOf(T, Tag) or meta.isSliceOf(T, EnumLiteral)) {
            meta.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {any}", .{ self, tagz });
            for (tagz, 0..) |tag_, i| {
                res._tags.set(i, toTag(tag_));
            }
            return res;
        }

        if (comptime meta.isTupleOf(T, Tag) or meta.isTupleOf(T, EnumLiteral)) {
            meta.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {}", .{ self, tagz });
            inline for (tagz, 0..) |tag_, i| {
                res._tags.set(i, toTag(tag_));
            }
            return res;
        }

        meta.compileError("Expected a tuple of enum literals eg: .{ .a, .b, .c } got: {any}", .{@TypeOf(tagz)});
    }

    test withTags {
        {
            const tagged = Shape.init(.{ 0, 1 }, .f32).withTags(.{ .a, .b });
            try testing.expectEqual(0, tagged.axis(.a));
            try testing.expectEqual(1, tagged.axis(.b));
        }
        {
            const tagged = Shape.init(.{ 0, 1, 2 }, .f32).withTags(.{ ._, .a, .b });
            try testing.expectEqual(1, tagged.axis(.a));
            try testing.expectEqual(2, tagged.axis(.b));
        }
        {
            const tagged = Shape.init(.{ 0, 1, 2, 3 }, .f32).withTags(.{ ._, ._, .a, .b });
            try testing.expectEqual(2, tagged.axis(.a));
            try testing.expectEqual(3, tagged.axis(.b));
        }
    }

    /// Tag the last axes of this shape with tags from a tuple, a slice or another `Shape`.
    pub fn withPartialTags(self: Shape, tagz: anytype) Shape {
        const T = @TypeOf(tagz);

        if (T == Shape) {
            return self.withPartialTags(tagz.tags());
        }

        var res = self;

        if (comptime meta.isSliceOf(T, Tag) or meta.isSliceOf(T, EnumLiteral)) {
            meta.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {any}", .{ self, tagz });
            for (tagz, self.rank() - tagz.len..) |tag_, i| {
                res._tags.set(i, toTag(tag_));
            }
            return res;
        }

        if (comptime meta.isTupleOf(T, Tag) or meta.isTupleOf(T, EnumLiteral)) {
            meta.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {}", .{ self, tagz });
            inline for (tagz, self.rank() - tagz.len..) |tag_, i| {
                res._tags.set(i, toTag(tag_));
            }
            return res;
        }

        meta.compileError("Expected a tuple of enum literals eg: .{ .a, .b, .c } got: {any}", .{@TypeOf(tagz)});
    }

    test withPartialTags {
        {
            const tagged = Shape.init(.{ 0, 1 }, .f32).withPartialTags(.{ .a, .b });
            try testing.expectEqual(0, tagged.axis(.a));
            try testing.expectEqual(1, tagged.axis(.b));
        }
        {
            const tagged = Shape.init(.{ 0, 1, 2 }, .f32).withPartialTags(.{ .a, .b });
            try testing.expectEqual(1, tagged.axis(.a));
            try testing.expectEqual(2, tagged.axis(.b));
        }
        {
            const tagged = Shape.init(.{ 0, 1, 2, 3, 4 }, .f32).withPartialTags(.{ .a, .b });
            try testing.expectEqual(3, tagged.axis(.a));
            try testing.expectEqual(4, tagged.axis(.b));
        }
        {
            const tagged = Shape.init(.{ 0, 1, 2, 3, 4, 5, 6 }, .f32).withPartialTags(.{ .a, .b, .c });
            try testing.expectEqual(4, tagged.axis(.a));
            try testing.expectEqual(5, tagged.axis(.b));
            try testing.expectEqual(6, tagged.axis(.c));
        }
    }

    /// Returns a copy of `self` with dtype `dt`.
    pub fn withDtype(self: Shape, dt: DataType) Shape {
        var res = self;
        res._dtype = dt;
        return res;
    }

    /// Returns a copy where exactly the given axes are flagged as sharded
    /// (any previous flags are cleared).
    pub fn withSharding(self: Shape, axes_: anytype) Shape {
        var res = self;
        // Reset sharding.
        res._sharding_info = @splat(false);
        for (self.axes(axes_).constSlice()) |ax| {
            res._sharding_info[ax] = true;
        }
        return res;
    }

    /// Renames some of the tags in this shape.
    /// Shape.init(.{ .a = 10, .b = 20 }, .f32).rename(.{ .b = .batch }); // .{ .a = 10, .batch = 20 };
    pub fn rename(self: Shape, renames: anytype) Shape {
        const T = @TypeOf(renames);
        meta.assertComptime(meta.isStructOfAny(T, isAxisConvertible), "Must pass a struct of enum literals. Passed: {any}", .{T});
        var res = self;
        inline for (std.meta.fields(T)) |field| {
            res._tags.set(self.axis(field), toTag(@field(renames, field.name)));
        }
        return res;
    }

    test rename {
        {
            const tagged = Shape.init(.{ .a = 0, .b = 1 }, .f32).rename(.{ .a = .x, .b = .y });
            try testing.expectEqual(0, tagged.dim(.x));
            try testing.expectEqual(1, tagged.dim(.y));
        }
        {
            const tagged = Shape.init(.{ .a = 0, .b = 1, .c = 2 }, .f32).rename(.{ .a = .x, .c = .z });
            try testing.expectEqual(0, tagged.dim(.x));
            try testing.expectEqual(1, tagged.dim(.b));
            try testing.expectEqual(2, tagged.dim(.z));
        }
    }

    /// Returns row-major strides in bytes: the last axis has stride
    /// `dtype().sizeOf()`, each earlier axis the next axis' dim times its stride.
    pub fn computeStrides(self: Shape) std.BoundedArray(i64, MAX_RANK) {
        const base_stride = self.dtype().sizeOf();
        const rk = self.rank();
        var strides: std.BoundedArray(i64, MAX_RANK) = .{ .len = @intCast(self.rank()) };
        if (rk == 0) return strides;
        strides.buffer[rk - 1] = base_stride;
        for (1..rk) |i| {
            const j = @as(usize, rk) - 1 - i;
            strides.buffer[j] = self._dims.get(j + 1) * strides.buffer[j + 1];
        }
        return strides;
    }

    /// Returns the permutation needed to transpose this shape
    /// so that the given axes are contiguous (and last, in the given order).
    pub fn contiguousPerm(self: Shape, axes_: anytype) AxesArray {
        const axes__, _ = self.parseAxes(axes_);
        var perms = AxesArray.init(0) catch unreachable;
        for (0..self.rank()) |i| {
            if (std.mem.indexOfScalar(u3, axes__.constSlice(), @intCast(i))) |_| {
                continue;
            }
            perms.appendAssumeCapacity(@intCast(i));
        }
        perms.appendSliceAssumeCapacity(axes__.constSlice());
        return perms;
    }

    test contiguousPerm {
        const abc = Shape.init(.{ .a = 10, .b = 11, .c = 12 }, .f32);
        try testing.expect(
            Shape.init(.{ .b = 11, .c = 12, .a = 10 }, .f32).eqlWithTags(
                abc.transpose(abc.contiguousPerm(.{.a}).constSlice()),
            ),
        );
        try testing.expect(
            Shape.init(.{ .c = 12, .b = 11, .a = 10 }, .f32).eqlWithTags(
                abc.transpose(abc.contiguousPerm(.{ .b, .a }).constSlice()),
            ),
        );

        const abcd = Shape.init(.{ .a = 10, .b = 11, .c = 12, .d = 13 }, .f32);
        try testing.expect(
            Shape.init(.{ .a = 10, .c = 12, .b = 11, .d = 13 }, .f32).eqlWithTags(
                abcd.transpose(abcd.contiguousPerm(.{ .b, .d }).constSlice()),
            ),
        );

        const abcde = Shape.init(.{ .a = 10, .b = 11, .c = 12, .d = 13, .e = 14 }, .f32);
        try testing.expect(
            Shape.init(.{ .a = 10, .b = 11, .d = 13, .c = 12, .e = 14 }, .f32).eqlWithTags(
                abcde.transpose(abcde.contiguousPerm(.{ .b, .d, .c, .e }).constSlice()),
            ),
        );
    }

    /// Splits the given axis in several axes.
    /// eg: `Shape.init(.{ .a = 10, .b = 3 }, .f32).splitAxis(.a, .{ .a1 = 5, .a2 = 2 }); -> .{ .a1 = 5, .a2 = 2, .b = 3 }`
    /// The number of elements in the split shape must match the number of elements
    /// in the target axis. One split dim may be `.auto` / `-1` and is inferred.
    /// NOTE(review): `_sharding_info` is copied unchanged, not remapped to the new axes.
    pub fn splitAxis(self: Shape, axis_: anytype, split_shape_: anytype) Shape {
        const ax = self.axis(axis_);
        const dims_, const tags_ = parseDimensions(split_shape_);
        var new_shape = self;
        new_shape._dims.replaceRange(ax, 1, dims_.constSlice()) catch unreachable;
        new_shape._tags.replaceRange(ax, 1, tags_.constSlice()) catch unreachable;
        new_shape.inferMissingAxis(self.count());
        return new_shape;
    }

    test splitAxis {
        try testing.expect(
            Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).eql(
                Shape.init(.{ .a = 10, .b = 3 }, .f32).splitAxis(.a, .{ .a1 = 5, .a2 = 2 }),
            ),
        );
        try testing.expect(
            Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).eql(
                Shape.init(.{ .a = 10, .b = 3 }, .f32).splitAxis(.a, .{ .a1 = .auto, .a2 = 2 }),
            ),
        );
    }

    /// Applies `splitAxis` for every field of a struct like `.{ .a = .{ .a1 = 5, .a2 = 2 } }`.
    pub fn splitAxes(self: Shape, axes_: anytype) Shape {
        const T = @TypeOf(axes_);
        meta.assertComptime(meta.isStruct(T), "Must pass struct of enum literals like .{ .a = .{ .a1, .a2 } }. Passed: {any}", .{T});

        var res = self;
        inline for (std.meta.fields(T)) |field| {
            res = res.splitAxis(field, @field(axes_, field.name));
        }
        return res;
    }

    test splitAxes {
        try testing.expect(
            Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).eql(
                Shape.init(.{ .a = 10, .b = 3 }, .f32).splitAxes(.{ .a = .{ .a1 = 5, .a2 = .auto } }),
            ),
        );
        try testing.expect(
            Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).eql(
                Shape.init(.{ .a = 10, .b = 3 }, .f32).splitAxes(.{ .a = .{ .a1 = .auto, .a2 = 2 } }),
            ),
        );
    }

    /// Merge the given contiguous axes into one axis, tagged `axis_` when it is tag-like.
    /// eg: `Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).mergeAxis(.a, .{ .a1, .a2 }); -> .{ .a = 10, .b = 3 }`
    /// NOTE(review): `_sharding_info` is copied unchanged, not remapped to the merged axes.
    pub fn mergeAxis(self: Shape, axis_: anytype, axes_: anytype) Shape {
        const axes__ = self.axes(axes_);

        const first_axis = axes__.get(0);
        const last_axis = axes__.get(axes__.len - 1);

        var new_dim: i64 = 1;
        for (axes__.constSlice(), first_axis..) |ax, counter| {
            new_dim *= self.dim(ax);
            meta.assert(ax == counter, "Can't merge shape {} along non-contiguous axes {any}", .{ self, axes_ });
        }

        var new_shape = self;
        new_shape._dims.set(first_axis, new_dim);
        // Drop the merged axes after `first_axis` by shifting the tail left.
        new_shape._dims.replaceRange(first_axis + 1, self.dims()[first_axis + 1 ..].len, self.dims()[last_axis + 1 ..]) catch unreachable;
        new_shape._tags.set(first_axis, if (comptime isTagConvertible(@TypeOf(axis_))) toTag(axis_) else TagUnknown);
        new_shape._tags.replaceRange(first_axis + 1, self.dims()[first_axis + 1 ..].len, self.tags()[last_axis + 1 ..]) catch unreachable;
        return new_shape;
    }

    test mergeAxis {
        try testing.expect(
            Shape.init(.{ .a = 10, .b = 3, .c = 4 }, .f32).eqlWithTags(
                Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3, .c = 4 }, .f32).mergeAxis(.a, .{ .a1, .a2 }),
            ),
        );
        try testing.expect(
            Shape.init(.{ .a = 5, .c = 6 }, .f32).eqlWithTags(
                Shape.init(.{ .a = 5, .b1 = 2, .b2 = 3 }, .f32).mergeAxis(.c, .{ .b1, .b2 }),
            ),
        );
        try testing.expect(
            Shape.init(.{ .a = 10, .b = 3 }, .f32).eqlWithTags(
                Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).mergeAxis(.a, .{ toTag(.a1), toTag(.a2) }),
            ),
        );
        try testing.expect(
            Shape.init(.{ .a = 10, .b = 3 }, .f32).eqlWithTags(
                Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).mergeAxis(toTag(.a), @as([]const Tag, &.{ toTag(.a1), toTag(.a2) })),
            ),
        );
        try testing.expect(
            Shape.init(.{ .a = 10, .b = 3 }, .f32).eqlWithTags(
                Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).mergeAxis(.a, @as([]const usize, &.{ 0, 1 })),
            ),
        );
    }

    /// Applies `mergeAxis` for every field of a struct like `.{ .a = .{ .a1, .a2 } }`.
    pub fn mergeAxes(self: Shape, axes_: anytype) Shape {
        const T = @TypeOf(axes_);
        meta.assertComptime(meta.isStruct(T), "Must pass struct of enum literals like .{ .a = .{ .a1, .a2 } }. Passed: {any}", .{T});

        var res = self;
        inline for (std.meta.fields(T)) |field| {
            meta.assertComptime(meta.isTupleOfAny(field.type, isAxisConvertible) or meta.isSliceOfAny(field.type, isAxisConvertible), "Must pass struct of axes. Passed: {any}", .{field.type});
            res = res.mergeAxis(field, @field(axes_, field.name));
        }
        return res;
    }

    test mergeAxes {
        try testing.expect(
            Shape.init(.{ .a = 10, .b = 3 }, .f32).eqlWithTags(
                Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).mergeAxes(.{ .a = .{ .a1, .a2 } }),
            ),
        );
        try testing.expect(
            Shape.init(.{ .a = 10, .b = 3 }, .f32).eqlWithTags(
                Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).mergeAxes(.{ .a = .{ toTag(.a1), toTag(.a2) } }),
            ),
        );
        try testing.expect(
            Shape.init(.{ .a = 10, .b = 3 }, .f32).eqlWithTags(
                Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).mergeAxes(.{ .a = .{ 0, 1 } }),
            ),
        );
        try testing.expect(
            Shape.init(.{ .a = 10, .b = 3 }, .f32).eqlWithTags(
                Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).mergeAxes(.{ .a = @as([]const usize, &.{ 0, 1 }) }),
            ),
        );
    }

    /// Returns the tags of `a` that also appear in `b`, in the order of `a`.
    /// Tags are compared by pointer; `a` must hold at most `MAX_RANK` tags.
    fn intersectTags(a: []const Tag, b: []const Tag) TagsArray {
        var res = TagsArray.init(0) catch unreachable;
        for (a) |tag_| {
            if (std.mem.indexOfScalar(Tag, b, tag_) != null) {
                res.appendAssumeCapacity(tag_);
            }
        }
        return res;
    }

    test intersectTags {
        const ab = Shape.init(.{ .a = 1, .b = 2 }, .f32);
        const bc = Shape.init(.{ .b = 3, .c = 4 }, .f32);
        const common = intersectTags(ab.tags(), bc.tags());
        try testing.expectEqualSlices(Tag, &.{toTag(.b)}, common.constSlice());
    }

    /// Parses a slice of `T`, a tuple, or a struct literal whose fields are `T`
    /// into values and tags. Tags are only filled for (non-tuple) structs.
    pub fn parseStruct(T: type, v: anytype) struct { std.BoundedArray(T, MAX_RANK), TagsArray } {
        const V = @TypeOf(v);

        var vals_: std.BoundedArray(T, MAX_RANK) = .{};
        var tags_: TagsArray = .{};

        if (comptime meta.isSliceOf(V, T)) {
            for (v) |d| {
                vals_.appendAssumeCapacity(d);
            }
            return .{ vals_, tags_ };
        }

        if (comptime meta.isStruct(V)) {
            const fields = std.meta.fields(V);
            meta.assertComptime(fields.len <= MAX_RANK, "Too many fields in struct {} ({d}). Max supported is {d}.", .{ V, fields.len, MAX_RANK });
            inline for (fields) |field| {
                const fv = @field(v, field.name);
                vals_.appendAssumeCapacity(fv);

                if (!comptime meta.isTuple(V)) {
                    tags_.appendAssumeCapacity(toTag(field));
                }
            }
            return .{ vals_, tags_ };
        }

        meta.compileError("parseStruct expects struct or tuple, got {}", .{V});
    }

    test parseStruct {
        const vals_, const tags_ = parseStruct(f32, .{ .a = 0.1, .b = 1.2 });

        try testing.expectEqualSlices(f32, &.{ 0.1, 1.2 }, vals_.constSlice());
        try testing.expectEqualSlices(Tag, &.{ "a".ptr, "b".ptr }, tags_.constSlice());
    }

    /// Parses per-axis options into one value per axis of `self`.
    /// - a slice of `T` must contain exactly `rank()` values;
    /// - a struct literal like `.{ .b = x }` sets the named axes, the others get `default`.
    pub fn parseAxesOptions(self: Shape, T: type, options: anytype, default: T) std.BoundedArray(T, MAX_RANK) {
        const V = @TypeOf(options);

        var res: std.BoundedArray(T, MAX_RANK) = .{};
        if (comptime meta.isSliceOf(V, T)) {
            meta.assert(options.len == self.rank(), "expects exactly {} options in slice, for {} got {}", .{ self.rank(), self, options.len });
            for (options) |d| {
                res.appendAssumeCapacity(d);
            }
            return res;
        }

        if (comptime meta.isStruct(V)) {
            for (0..self.rank()) |_| res.appendAssumeCapacity(default);
            const fields = std.meta.fields(V);
            meta.assertComptime(fields.len <= MAX_RANK, "expects up to {} options struct literal, got {}", .{ MAX_RANK, fields.len });
            inline for (fields) |field| {
                const a = self.axis(field);
                res.buffer[a] = @field(options, field.name);
            }
            return res;
        }

        meta.compileError("parseAxesOptions expects struct or slice, got {}", .{V});
    }

    test parseAxesOptions {
        const shape = Shape.init(.{ .a = 10, .b = 20, .c = 30 }, .u8);
        const scaling = shape.parseAxesOptions(f32, .{ .b = 1.2, .a = 0.1 }, 1.0);
        try testing.expectEqualSlices(f32, &.{ 0.1, 1.2, 1.0 }, scaling.constSlice());

        const from_slice = shape.parseAxesOptions(f32, @as([]const f32, &.{ 0.5, 0.25, 2.0 }), 1.0);
        try testing.expectEqualSlices(f32, &.{ 0.5, 0.25, 2.0 }, from_slice.constSlice());
    }

    test "comptimeShape" {
        comptime {
            const s = Shape.init(.{ .a = 5, .b = 6, .c = 7 }, .f32);
            try std.testing.expectEqual(3, s.rank());
            try std.testing.expectEqual(4 * 5 * 6 * 7, s.byteSize());
            try std.testing.expectEqual(1, s.axis(.b));
        }

        // comptime only the shape
        {
            const s = comptime Shape.init(.{ .a = 5, .b = 6, .c = 7 }, .f32);
            try std.testing.expectEqual(3, s.rank());
            try std.testing.expectEqual(4 * 5 * 6 * 7, s.byteSize());
            try std.testing.expectEqual(1, s.axis(.b));
        }
    }
};