Radix/zml/shape.zig

1098 lines
40 KiB
Zig
Raw Permalink Normal View History

const std = @import("std");
const testing = std.testing;
const builtin = @import("builtin");
const stdx = @import("stdx");
const DataType = @import("dtype.zig").DataType;
const EnumLiteral = @TypeOf(.enum_literal);
const log = std.log.scoped(.shape);
test {
std.testing.refAllDecls(Shape);
}
/// Represent the shape of a tensor.
pub const Shape = struct {
pub const MAX_RANK: u8 = 8;
pub const Tag = [*:0]const u8;
pub const TagUnknown = "_".ptr;
const TagLast = "last".ptr;
pub const DimsArray = stdx.BoundedArray(i64, MAX_RANK);
pub const TagsArray = stdx.BoundedArray(Tag, MAX_RANK);
pub const AxesArray = stdx.BoundedArray(u3, MAX_RANK);
pub const ShardingInfo = @Vector(MAX_RANK, bool);
const UnknownTags: TagsArray = .{ .len = 0, .buffer = [_]Tag{TagUnknown} ** MAX_RANK };
_dtype: DataType,
_dims: DimsArray = .{},
_tags: TagsArray = UnknownTags,
_sharding_info: ShardingInfo = @splat(false),
pub fn parseDimensions(v: anytype) struct { DimsArray, TagsArray } {
const T = @TypeOf(v);
if (T == Shape) {
return .{ v._dims, v._tags };
}
if (comptime stdx.meta.isSliceOfAny(T, stdx.meta.isInteger)) {
var dims_ = DimsArray.init(0) catch unreachable;
var tags_ = TagsArray.init(0) catch unreachable;
for (v) |d| {
dims_.appendAssumeCapacity(@intCast(d));
tags_.appendAssumeCapacity(TagUnknown);
}
return .{ dims_, tags_ };
}
if (comptime stdx.meta.isStruct(T)) {
var dims_: DimsArray = .{};
var tags_: TagsArray = .{};
inline for (std.meta.fields(T)) |field| {
const fv = @field(v, field.name);
if (comptime stdx.meta.isInteger(field.type)) {
dims_.appendAssumeCapacity(@intCast(fv));
} else if (@TypeOf(fv) == EnumLiteral and comptime isAutoDim(fv)) {
dims_.appendAssumeCapacity(-1);
} else {
stdx.debug.compileError("Field {s} should be an integer or an auto dimension, got {}", .{ field.name, field.type });
}
if (comptime stdx.meta.isTuple(T)) {
tags_.appendAssumeCapacity(TagUnknown);
} else {
tags_.appendAssumeCapacity(toTag(field));
}
}
return .{ dims_, tags_ };
}
stdx.debug.compileError("expected a dimension tuple eg '.{{ .a = 10, .b = 20}}' or '.{{ 10, 20 }}', got {}", .{T});
}
test parseDimensions {
const ref = Shape.init(.{ .a = 0, .b = 1 }, .f32);
const dims_, const tags_ = parseDimensions(.{ .a = 0, .b = 1 });
try testing.expectEqualSlices(i64, ref.dims(), dims_.constSlice());
try testing.expectEqualSlices(Tag, ref.tags(), tags_.constSlice());
}
pub fn parseAxes(self: Shape, v: anytype) struct { AxesArray, TagsArray } {
const T = @TypeOf(v);
if (T == Shape) {
return self.parseAxes(self.tags());
}
var axes_ = AxesArray.init(0) catch unreachable;
var tags_ = TagsArray.init(0) catch unreachable;
if (comptime stdx.meta.isSliceOfAny(T, isAxisConvertible)) {
for (v) |d| {
axes_.appendAssumeCapacity(self.axis(d));
tags_.appendAssumeCapacity(self.tag(d));
}
return .{ axes_, tags_ };
}
if (comptime stdx.meta.isTupleOfAny(T, isAxisConvertible)) {
inline for (std.meta.fields(T)) |field| {
axes_.appendAssumeCapacity(self.axis(@field(v, field.name)));
tags_.appendAssumeCapacity(self.tag(@field(v, field.name)));
}
return .{ axes_, tags_ };
}
stdx.debug.compileError("Wrong type, got {}. Expected .{{.a, .b}}", .{T});
}
pub fn parseTags(v: anytype) TagsArray {
const T = @TypeOf(v);
stdx.debug.assertComptime(stdx.meta.isTupleOf(T, EnumLiteral), "Wrong type, got {}. Expected .{{ .a, .b }}", .{T});
var tags_ = TagsArray.init(0) catch unreachable;
inline for (v) |field| {
tags_.appendAssumeCapacity(toTag(field));
}
return tags_;
}
/// Create a shape from a struct literal, eg:
/// Shape.init(.{ .h = 1024, .w = 512, .c = 3 });
/// Shape.init(.{ 1024, 512, 3 });
pub fn init(dimz: anytype, dt: DataType) Shape {
var res: Shape = .{ ._dtype = dt };
res._dims, res._tags = parseDimensions(dimz);
return res;
}
pub fn scalar(dt: DataType) Shape {
return .{ ._dtype = dt };
}
/// Creates a Shape with dims set to `.{0, 1, 2, ..., rank-1}`.
pub fn range(rank_: usize, dt: DataType) Shape {
var res: Shape = .{ ._dtype = dt };
for (0..rank_) |i| {
res._dims.append(@intCast(i)) catch {
stdx.debug.panic("Too many dimensions! Max: {d}, passed: {d}", .{ res._dims.capacity(), rank_ });
};
res._tags.append(TagUnknown) catch unreachable;
}
return res;
}
pub fn dtype(self: Shape) DataType {
return self._dtype;
}
pub fn rank(self: Shape) u4 {
self.ensureDimsAndTagsAreSync();
return @intCast(self._dims.len);
}
pub fn dim(self: Shape, ax: anytype) i64 {
self.ensureDimsAndTagsAreSync();
return self._dims.get(self.axis(ax));
}
pub fn dims(self: *const Shape) []const i64 {
self.ensureDimsAndTagsAreSync();
return self._dims.constSlice();
}
fn isAxisConvertible(comptime T: type) bool {
return stdx.meta.isInteger(T) or isTagConvertible(T);
}
fn isTagConvertible(comptime T: type) bool {
return switch (T) {
EnumLiteral => true,
std.builtin.Type.StructField => true,
Tag => true,
else => false,
};
}
fn toTag(v: anytype) Tag {
const T = @TypeOf(v);
return switch (T) {
EnumLiteral => @tagName(v).ptr,
std.builtin.Type.StructField => v.name.ptr,
Tag => v,
else => stdx.debug.compileError("Shape tag should be an EnumLiteral, a Shape.Tag or a StructField, got {}", .{T}),
};
}
inline fn ensureDimsAndTagsAreSync(self: Shape) void {
stdx.debug.assert(self._dims.len == self._tags.len, "Tags and dims have diverged! dims={d} tags={d}", .{ self._dims.len, self._tags.len });
}
pub fn tag(self: Shape, ax: anytype) Tag {
self.ensureDimsAndTagsAreSync();
return self._tags.get(self.axis(ax));
}
/// Returns a printable name for a given axis.
/// Either the tag itself, or a digit if it's not tagged.
pub fn debugTag(self: Shape, ax: usize) []const u8 {
const t = self.tag(ax);
if (t != TagUnknown) return std.mem.span(t);
return "01234567"[ax .. ax + 1];
}
pub fn setTag(self: Shape, ax: anytype, tag_: anytype) Shape {
var res = self;
res._tags.set(self.axis(ax), toTag(tag_));
return res;
}
pub fn tags(self: *const Shape) []const Tag {
self.ensureDimsAndTagsAreSync();
return self._tags.constSlice();
}
pub fn hasTag(self: Shape, tag_: anytype) ?u3 {
return self.axisFromTagMaybe(toTag(tag_));
}
pub fn hasTags(self: Shape, tagz: anytype) bool {
const T = @TypeOf(tagz);
if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) {
for (tagz) |t| {
if (self.hasTag(t) == null) {
return false;
}
}
return true;
}
if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) {
inline for (tagz) |t| {
if (self.hasTag(t) == null) {
return false;
}
}
return true;
}
stdx.debug.compileError("Expected tuple of tags, got {any}", .{T});
}
pub fn isFullyTagged(self: Shape) bool {
for (self._tags.constSlice()) |t| {
if (t == TagUnknown) return false;
}
return true;
}
pub fn axis(self: Shape, axis_: anytype) u3 {
self.ensureDimsAndTagsAreSync();
const T = @TypeOf(axis_);
if (comptime stdx.meta.isInteger(T)) {
return self.axisFromInt(@intCast(axis_));
}
if (comptime isTagConvertible(T)) {
return self.axisFromTag(toTag(axis_));
}
stdx.debug.compileError("Wrong axis type, expected .literal, or an integer, got: {any}", .{T});
}
pub fn axes(self: Shape, axes_: anytype) AxesArray {
self.ensureDimsAndTagsAreSync();
const T = @TypeOf(axes_);
if (T == Shape) {
return self.axes(axes_.tags());
}
var res = AxesArray.init(0) catch unreachable;
if (comptime stdx.meta.isSliceOfAny(T, stdx.meta.isInteger) or stdx.meta.isSliceOf(T, Tag)) {
for (axes_) |ax| {
res.appendAssumeCapacity(self.axis(ax));
}
return res;
}
if (comptime stdx.meta.isStruct(T)) {
inline for (std.meta.fields(T)) |field| {
res.appendAssumeCapacity(self.axis(@field(axes_, field.name)));
}
return res;
}
stdx.debug.compileError("axes expects an int-tuple or a tuple of enum literal, got {}", .{T});
}
fn axisFromInt(self: Shape, a: isize) u3 {
2023-02-14 13:52:49 +00:00
const rk: i8 = self.rank();
if (a < -rk or a > rk) {
stdx.debug.panic("Tensor {f} doesn't have dimension: {d}", .{ self, a });
}
return if (a < 0)
@intCast(a + rk)
2023-02-14 13:52:49 +00:00
else
@intCast(a);
}
fn axisFromTagMaybe(self: Shape, t: Tag) ?u3 {
if (t == TagUnknown) return null;
if (axisFromLiteralInt(t)) |ax| return ax;
if (@inComptime()) {
// At comptime two duplicated strings may have two different representations
const t_bytes: []const u8 = std.mem.span(t);
for (self.tags(), 0..) |self_tag, ax| {
if (std.mem.eql(u8, t_bytes, std.mem.span(self_tag))) {
return @truncate(ax);
}
}
return null;
}
// But at runtime the comptime strings have been deduplicated and ptr match is enough.
if (std.mem.indexOfScalar(Tag, self.tags(), t)) |ax| {
return @truncate(ax);
}
return null;
}
/// Handle .{ ._0 = x } syntax.
fn axisFromLiteralInt(t: Tag) ?u3 {
// match .{ '_', '0-9', null }
if (t[0] == '_' and t[1] >= '0' and t[1] < '8' and t[2] == 0) {
return @intCast(t[1] - '0');
}
return null;
}
fn axisFromTag(self: Shape, d: Tag) u3 {
stdx.debug.assert(d != TagUnknown, "The unknown tag .{s} can't be used to fetch axis in {f}", .{ d, self });
return self.axisFromTagMaybe(d) orelse {
stdx.debug.panic("Tensor {f} doesn't have dimension with tag: {s}", .{ self, d });
};
}
test axis {
try testing.expectEqual(1, Shape.init(.{ 5, 2 }, .f32).axis(1));
try testing.expectEqual(1, Shape.init(.{ 5, 2 }, .f32).axis(-1));
try testing.expectEqual(1, Shape.init(.{ .a = 5, .b = 2 }, .f32).axis(.b));
}
/// The number of element inside a Tensor described by this shape.
pub fn count(self: Shape) usize {
var res: i64 = 1;
for (self.dims()) |d| {
stdx.debug.assert(d >= 0, "Can't count elements in shape with negative dimension: {f}", .{self});
res *= d;
}
return @intCast(res);
}
/// Total size in bytes needed to represent this shape.
pub fn byteSize(self: Shape) usize {
return self.dtype().sizeOf() * self.count();
}
/// Compares the two shapes described, ignoring tagging.
pub fn eql(self: Shape, other: Shape) bool {
return std.mem.eql(i64, self.dims(), other.dims()) and self.dtype() == other.dtype();
}
/// Compares the two shapes described, ignoring tagging and dtype.
pub fn eqlDims(self: Shape, other: Shape) bool {
return std.mem.eql(i64, self.dims(), other.dims());
}
/// Compares the two shapes described including tags.
pub fn eqlWithTags(self: Shape, other: Shape) bool {
return self.eql(other) and std.mem.eql(Tag, self.tags(), other.tags()) and self.dtype() == other.dtype();
}
/// Format the shape.
/// Default format: "Shape({.a=10, .b=20}, dtype=.f32)"
/// Bare format {_}: "{.a=10, .b=20}, dtype=.f32"
pub fn format(self: Shape, writer: *std.Io.Writer) !void {
_ = try writer.writeByte('{');
var need_comma = false;
for (self.dims(), 0..) |d, i| {
if (need_comma) try writer.writeByte(',');
const t = self.tag(i);
if (t != TagUnknown) {
try writer.print("{s}={d}", .{ t, d });
} else {
try writer.print("{d}", .{d});
}
if (self._sharding_info[i]) {
try writer.writeByte('!');
}
need_comma = true;
}
if (need_comma) try writer.writeByte(',');
_ = try writer.write(@tagName(self.dtype()));
_ = try writer.writeByte('}');
}
/// Broadcasts a Tensor to the given shape, extending dimensions if needed.
pub fn canBroadcastTo(self: Shape, other: Shape) bool {
// Already the right shape
if (std.mem.eql(i64, self.dims(), other.dims())) return true;
if (std.mem.eql(Tag, self.tags(), other.tags())) {
for (0..self.rank()) |i| {
if (self.dim(i) != 1 and self.dim(i) != other.dim(i)) return false;
}
return true;
}
// Non ambiguous broadcasting
// TODO: broad is error prone because of this:
// it will happily broadcast .{ .a = 10, .b = 1 } to .{ .b = 10, .a = 5 }
if (self.rank() == 0 or self.rank() == other.rank()) {
for (0..self.rank()) |i| {
if (self.dim(i) != 1 and self.dim(i) != other.dim(i)) return false;
}
return true;
}
for (self.dims(), self.tags()) |d, t| {
// TODO this is also wrong when the axes are in different order in the two shapes.
const other_ax = other.hasTag(t) orelse return false;
if (d != 1 and d != other.dim(other_ax)) return false;
}
return true;
}
pub fn reshape(self: Shape, new_shape_: anytype) Shape {
var new_shape: Shape = .{ ._dtype = self.dtype() };
new_shape._dims, new_shape._tags = parseDimensions(new_shape_);
new_shape.inferMissingAxis(self.count());
stdx.debug.assert(self.count() == new_shape.count(), "Can't reshape {any} to {any}", .{ self.dims(), new_shape.dims() });
return new_shape;
}
fn inferMissingAxis(self: *Shape, n_: usize) void {
stdx.debug.assert(std.mem.count(i64, self.dims(), &.{-1}) < 2, "Cannot infer multiple dimensions when reshaping to: {f}", .{self.*});
const inferred_ax = std.mem.indexOfScalar(i64, self.dims(), -1) orelse return;
// We can't use `self.count()` yet cause we have negative dims.
var tmp_count: i64 = 1;
for (self.dims()) |d| {
if (d > 0) {
tmp_count *= d;
}
}
const n: i64 = @intCast(n_);
// Abort, `reshape` will panic with more context.
if (@mod(n, tmp_count) != 0) {
return;
}
self._dims.set(inferred_ax, @divExact(n, tmp_count));
}
test reshape {
const x = Shape.init(.{ 2, 5, 3 }, .f32);
{
const res = x.reshape(.{ .auto, 3 });
try testing.expectEqualSlices(i64, &.{ 10, 3 }, res.dims());
}
{
const res = x.reshape(.{ 10, -1 });
try testing.expectEqualSlices(i64, &.{ 10, 3 }, res.dims());
}
{
const res = x.reshape(.{-1});
try testing.expectEqualSlices(i64, &.{30}, res.dims());
}
}
pub fn setDim(self: Shape, ax: anytype, d: i64) Shape {
var res = self;
res._dims.set(self.axis(ax), d);
return res;
}
pub const set = setDim;
fn isAutoDim(v: anytype) bool {
return toTag(v) == toTag(.auto);
}
fn isDynDim(v: anytype) bool {
return toTag(v) == toTag(.dyn);
}
/// Inserts one ore more axes with the given dimensions, before the given axis.
/// Negative axis is interpreted wrt the current shape.
/// `.last` axis can be used to insert at the end (ie to append).
/// ```
/// .{10, 11, 12}.insert(1, 2) -> .{10, 2, 11, 12}
/// .{10, 11, 12}.insert(-1, 2) -> .{10, 11, 2, 12}
/// .{10, 11, 12}.insert(.last, 2) -> .{10, 11, 12, 2}
/// ```
pub fn insert(self: Shape, axis_: anytype, dimz: anytype) Shape {
const dims_, const tags_ = parseDimensions(dimz);
const ax = if (@TypeOf(axis_) == EnumLiteral and axis_ == .last)
self.rank()
else
self.axis(axis_);
var res = self;
res._dims.insertSlice(ax, dims_.constSlice()) catch unreachable;
res._tags.insertSlice(ax, tags_.constSlice()) catch unreachable;
return res;
}
test insert {
try testing.expectEqualSlices(i64, &.{ 10, 1, 11, 12 }, Shape.init(.{ 10, 11, 12 }, .f32).insert(1, .{1}).dims());
try testing.expectEqualSlices(i64, &.{ 10, 11, 12, 1, 13 }, Shape.init(.{ 10, 11, 12, 13 }, .f32).insert(-1, .{1}).dims());
try testing.expectEqualSlices(i64, &.{ 10, 11, 12, 13, 1 }, Shape.init(.{ 10, 11, 12, 13 }, .f32).insert(.last, .{1}).dims());
}
pub fn insertTag(self: Shape, axis_: anytype, d: i64, tag_: anytype) Shape {
stdx.debug.assert(self.rank() < MAX_RANK - 1, "Can't insert new axis in {f}, it's already at max rank.", .{self});
const ax = if (@TypeOf(axis_) == EnumLiteral and axis_ == .last)
self.rank()
else
self.axis(axis_);
var res = self;
res._dims.insert(ax, d) catch unreachable;
res._tags.insert(ax, toTag(tag_)) catch unreachable;
return res;
}
pub fn append(self: Shape, v: anytype) Shape {
var res = self;
const dims_, const tags_ = parseDimensions(v);
res._dims.appendSliceAssumeCapacity(dims_.constSlice());
res._tags.appendSliceAssumeCapacity(tags_.constSlice());
return res;
}
test append {
try testing.expectEqualSlices(
i64,
&.{ 10, 11, 12, 1 },
Shape.init(.{ 10, 11, 12 }, .f32).append(.{1}).dims(),
);
try testing.expect(
Shape.init(.{ .a = 10, .b = 11, .c = 12 }, .f32).eqlWithTags(
Shape.init(.{ .a = 10, .b = 11 }, .f32).append(.{ .c = 12 }),
),
);
}
pub fn appendDim(self: Shape, d: i64, tag_: ?Tag) Shape {
var res = self;
res._dims.appendAssumeCapacity(d);
res._tags.appendAssumeCapacity(if (tag_) |t| t else TagUnknown);
return res;
}
pub fn remove(self: Shape, axis_: anytype) Shape {
var res = self;
const a = self.axis(axis_);
_ = res._dims.orderedRemove(a);
_ = res._tags.orderedRemove(a);
return res;
}
pub const drop = remove;
test remove {
try std.testing.expectEqualSlices(i64, &.{ 10, 12 }, Shape.init(.{ 10, 11, 12 }, .f32).remove(1).dims());
try std.testing.expectEqualSlices(i64, &.{ 10, 11, 12 }, Shape.init(.{ 10, 11, 12, 13 }, .f32).remove(-1).dims());
}
pub fn removeMany(self: Shape, axes_: anytype) Shape {
var to_remove = self.axes(axes_);
if (to_remove.len == 0) return self;
std.mem.sort(u3, to_remove.slice(), {}, std.sort.asc(u3));
var sh: Shape = self;
const rk = self.rank();
var res_ax: u32 = 0;
for (0..rk) |ax| {
if (std.mem.indexOfScalar(u3, to_remove.constSlice(), @intCast(ax))) |_| {
continue;
}
sh._dims.buffer[res_ax] = self._dims.buffer[ax];
sh._tags.buffer[res_ax] = self._tags.buffer[ax];
res_ax += 1;
}
sh._dims.len = rk - to_remove.len;
sh._tags.len = rk - to_remove.len;
return sh;
}
test removeMany {
try std.testing.expectEqualSlices(
i64,
&.{12},
Shape.init(.{ 10, 11, 12 }, .f32).removeMany(.{ 0, 1 }).dims(),
);
try std.testing.expectEqualSlices(
i64,
&.{ 10, 11 },
Shape.init(.{ 10, 11, 12, 13 }, .f32).removeMany(.{ -1, -2 }).dims(),
);
}
pub fn transpose(self: Shape, permutations: anytype) Shape {
std.debug.assert(self.rank() == permutations.len);
const permutations_ = self.axes(permutations);
var res = self;
for (permutations_.constSlice(), 0..) |permutation, i| {
res._dims.set(i, self.dim(permutation));
res._tags.set(i, self.tag(permutation));
}
return res;
}
test transpose {
try testing.expect(
Shape.init(.{ 12, 11, 10 }, .f32).eql(
Shape.init(.{ 10, 11, 12 }, .f32).transpose(.{ 2, 1, 0 }),
),
);
try testing.expect(
Shape.init(.{ .a = 10, .c = 12, .b = 11, .d = 13 }, .f32).eqlWithTags(
Shape.init(.{ .a = 10, .b = 11, .c = 12, .d = 13 }, .f32).transpose(.{ 0, 2, 1, 3 }),
),
);
}
/// Tag each ax of this shape with tags from a tuple.
pub fn withTags(self: Shape, tagz: anytype) Shape {
const T = @TypeOf(tagz);
if (T == Shape) {
return self.withTags(tagz.tags());
}
var res = self;
if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) {
stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {f}, got {any}", .{ self, tagz });
for (tagz, 0..) |tag_, i| {
res._tags.set(i, toTag(tag_));
}
return res;
}
if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) {
stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {f}, got {}", .{ self, tagz });
inline for (tagz, 0..) |tag_, i| {
res._tags.set(i, toTag(tag_));
}
return res;
}
stdx.debug.compileError("Expected a tuple of enum literals eg: .{ .a, .b, .c } got: {any}", .{@TypeOf(tagz)});
}
test withTags {
{
const tagged = Shape.init(.{ 0, 1 }, .f32).withTags(.{ .a, .b });
try testing.expectEqual(0, tagged.axis(.a));
try testing.expectEqual(1, tagged.axis(.b));
}
{
const tagged = Shape.init(.{ 0, 1, 2 }, .f32).withTags(.{ ._, .a, .b });
try testing.expectEqual(1, tagged.axis(.a));
try testing.expectEqual(2, tagged.axis(.b));
}
{
const tagged = Shape.init(.{ 0, 1, 2, 3 }, .f32).withTags(.{ ._, ._, .a, .b });
try testing.expectEqual(2, tagged.axis(.a));
try testing.expectEqual(3, tagged.axis(.b));
}
}
/// Tag the last axes of this shape with tags from a tuple.
pub fn withPartialTags(self: Shape, tagz: anytype) Shape {
const T = @TypeOf(tagz);
if (T == Shape) {
return self.withPartialTags(tagz.tags());
}
var res = self;
if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) {
stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {f}, got {any}", .{ self, tagz });
for (tagz, self.rank() - tagz.len..) |tag_, i| {
res._tags.set(i, toTag(tag_));
}
return res;
}
if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) {
stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {f}, got {}", .{ self, tagz });
inline for (tagz, self.rank() - tagz.len..) |tag_, i| {
res._tags.set(i, toTag(tag_));
}
return res;
}
stdx.debug.compileError("Expected a tuple of enum literals eg: .{ .a, .b, .c } got: {any}", .{@TypeOf(tagz)});
}
test withPartialTags {
{
const tagged = Shape.init(.{ 0, 1 }, .f32).withPartialTags(.{ .a, .b });
try testing.expectEqual(0, tagged.axis(.a));
try testing.expectEqual(1, tagged.axis(.b));
}
{
const tagged = Shape.init(.{ 0, 1, 2 }, .f32).withPartialTags(.{ .a, .b });
try testing.expectEqual(1, tagged.axis(.a));
try testing.expectEqual(2, tagged.axis(.b));
}
{
const tagged = Shape.init(.{ 0, 1, 2, 3, 4 }, .f32).withPartialTags(.{ .a, .b });
try testing.expectEqual(3, tagged.axis(.a));
try testing.expectEqual(4, tagged.axis(.b));
}
{
const tagged = Shape.init(.{ 0, 1, 2, 3, 4, 5, 6 }, .f32).withPartialTags(.{ .a, .b, .c });
try testing.expectEqual(4, tagged.axis(.a));
try testing.expectEqual(5, tagged.axis(.b));
try testing.expectEqual(6, tagged.axis(.c));
}
}
pub fn withDtype(self: Shape, dt: DataType) Shape {
var res = self;
res._dtype = dt;
return res;
}
pub fn withSharding(self: Shape, axes_: anytype) Shape {
var res = self;
// Reset sharding.
res._sharding_info = @splat(false);
for (self.axes(axes_).constSlice()) |ax| {
res._sharding_info[ax] = true;
}
return res;
}
/// Renames some of the tags in this shape.
/// Shape.init(.{ .a = 10, .b = 20 }).rename(.{ .b = .batch }); // .{ .a = 10, .batch = 20 };
pub fn rename(self: Shape, renames: anytype) Shape {
const T = @TypeOf(renames);
stdx.debug.assertComptime(stdx.meta.isStructOfAny(T, isAxisConvertible), "Must pass a struct of enum literals. Passed: {any}", .{T});
var res = self;
inline for (std.meta.fields(T)) |field| {
const new_field = @field(renames, field.name);
stdx.debug.assert(self.hasTag(new_field) == null, "{f}.rename({any}) failed because of duplicated axis {}", .{ self, renames, new_field });
res._tags.set(self.axis(field), toTag(new_field));
}
return res;
}
test rename {
{
const tagged = Shape.init(.{ .a = 0, .b = 1 }, .f32).rename(.{ .a = .x, .b = .y });
try testing.expectEqual(0, tagged.dim(.x));
try testing.expectEqual(1, tagged.dim(.y));
}
{
const tagged = Shape.init(.{ .a = 0, .b = 1, .c = 2 }, .f32).rename(.{ .a = .x, .c = .z });
try testing.expectEqual(0, tagged.dim(.x));
try testing.expectEqual(1, tagged.dim(.b));
try testing.expectEqual(2, tagged.dim(.z));
}
}
pub fn computeStrides(self: Shape) stdx.BoundedArray(i64, MAX_RANK) {
const rk = self.rank();
var strides: stdx.BoundedArray(i64, MAX_RANK) = .{ .len = rk };
if (rk == 0) return strides;
const V = @Vector(MAX_RANK, i64);
const rank_mask = std.simd.iota(u8, MAX_RANK) < @as(@Vector(MAX_RANK, u8), @splat(rk));
// For each axis compute the product of all following dimensions
// and the element size in bytes.
var d: V = @bitCast(self._dims.buffer);
d = @select(i64, rank_mask, d, @as(V, @splat(1)));
d = std.simd.shiftElementsLeft(d, 1, self.dtype().sizeOf());
d = std.simd.prefixScan(.Mul, -1, d);
strides.buffer = @bitCast(d);
return strides;
}
/// Returns the permutation needed to transpose this shape
/// so that the given axes are contiguous.
pub fn contiguousPerm(self: Shape, axes_: anytype) AxesArray {
const axes__, _ = self.parseAxes(axes_);
var perms = AxesArray.init(0) catch unreachable;
for (0..self.rank()) |i| {
if (std.mem.indexOfScalar(u3, axes__.constSlice(), @intCast(i))) |_| {
continue;
}
perms.appendAssumeCapacity(@intCast(i));
}
perms.appendSliceAssumeCapacity(axes__.constSlice());
return perms;
}
test contiguousPerm {
const abc = Shape.init(.{ .a = 10, .b = 11, .c = 12 }, .f32);
try testing.expect(
Shape.init(.{ .b = 11, .c = 12, .a = 10 }, .f32).eqlWithTags(
abc.transpose(abc.contiguousPerm(.{.a}).constSlice()),
),
);
try testing.expect(
Shape.init(.{ .c = 12, .b = 11, .a = 10 }, .f32).eqlWithTags(
abc.transpose(abc.contiguousPerm(.{ .b, .a }).constSlice()),
),
);
const abcd = Shape.init(.{ .a = 10, .b = 11, .c = 12, .d = 13 }, .f32);
try testing.expect(
Shape.init(.{ .a = 10, .c = 12, .b = 11, .d = 13 }, .f32).eqlWithTags(
abcd.transpose(abcd.contiguousPerm(.{ .b, .d }).constSlice()),
),
);
const abcde = Shape.init(.{ .a = 10, .b = 11, .c = 12, .d = 13, .e = 14 }, .f32);
try testing.expect(
Shape.init(.{ .a = 10, .b = 11, .d = 13, .c = 12, .e = 14 }, .f32).eqlWithTags(
abcde.transpose(abcde.contiguousPerm(.{ .b, .d, .c, .e }).constSlice()),
),
);
}
/// Splits the given axis in several axes.
/// eg: `Shape.init(.{ .a = 10, .b = 3 }).split(.a, .{.a1 = 5, .a2 = 2}); -> .{.a1 = 5, .a2 = 2, .b = 3}`
/// The number of elements in the split shape must match the number of element
/// in the target axis.
pub fn splitAxis(self: Shape, axis_: anytype, split_shape_: anytype) Shape {
const ax = self.axis(axis_);
const dims_, const tags_ = parseDimensions(split_shape_);
var new_shape = self;
new_shape._dims.replaceRange(ax, 1, dims_.constSlice()) catch unreachable;
new_shape._tags.replaceRange(ax, 1, tags_.constSlice()) catch unreachable;
new_shape.inferMissingAxis(self.count());
return new_shape;
}
test splitAxis {
try testing.expect(
Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).eql(
Shape.init(.{ .a = 10, .b = 3 }, .f32).splitAxis(.a, .{ .a1 = 5, .a2 = 2 }),
),
);
try testing.expect(
Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).eql(
Shape.init(.{ .a = 10, .b = 3 }, .f32).splitAxis(.a, .{ .a1 = .auto, .a2 = 2 }),
),
);
}
pub fn splitAxes(self: Shape, axes_: anytype) Shape {
const T = @TypeOf(axes_);
stdx.debug.assertComptime(stdx.meta.isStruct(T), "Must pass struct of enum literals like .{ .a = .{ .a1, .a2 } }. Passed: {any}", .{T});
var res = self;
inline for (std.meta.fields(T)) |field| {
res = res.splitAxis(field, @field(axes_, field.name));
}
return res;
}
test splitAxes {
try testing.expect(
Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).eql(
Shape.init(.{ .a = 10, .b = 3 }, .f32).splitAxes(.{ .a = .{ .a1 = 5, .a2 = .auto } }),
),
);
try testing.expect(
Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).eql(
Shape.init(.{ .a = 10, .b = 3 }, .f32).splitAxes(.{ .a = .{ .a1 = 5, .a2 = .auto } }),
),
);
}
/// Merge the given axes into one axis.
/// eg: `Shape.init(.{.a1 = 5, .a2 = 2, .b = 3}).merge(.{ .a = .{ .a1, .a2 }); -> .{ .a = 10, .b = 3 }`
pub fn mergeAxis(self: Shape, axis_: anytype, axes_: anytype) Shape {
const axes__ = self.axes(axes_);
const first_axis = axes__.get(0);
const last_axis = axes__.get(axes__.len - 1);
var new_dim: i64 = 1;
for (axes__.constSlice(), first_axis..) |ax, counter| {
new_dim *= self.dim(ax);
stdx.debug.assert(ax == counter, "Can't merge shape {f} along non-contiguous axes {any}", .{ self, axes_ });
}
var new_shape = self;
new_shape._dims.set(first_axis, new_dim);
new_shape._dims.replaceRange(first_axis + 1, self.dims()[first_axis + 1 ..].len, self.dims()[last_axis + 1 ..]) catch unreachable;
new_shape._tags.set(first_axis, if (comptime isTagConvertible(@TypeOf(axis_))) toTag(axis_) else TagUnknown);
new_shape._tags.replaceRange(first_axis + 1, self.dims()[first_axis + 1 ..].len, self.tags()[last_axis + 1 ..]) catch unreachable;
return new_shape;
}
test mergeAxis {
try testing.expect(
Shape.init(.{ .a = 10, .b = 3, .c = 4 }, .f32).eqlWithTags(
Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3, .c = 4 }, .f32).mergeAxis(.a, .{ .a1, .a2 }),
),
);
try testing.expect(
Shape.init(.{ .a = 5, .c = 6 }, .f32).eqlWithTags(
Shape.init(.{ .a = 5, .b1 = 2, .b2 = 3 }, .f32).mergeAxis(.c, .{ .b1, .b2 }),
),
);
try testing.expect(
Shape.init(.{ .a = 10, .b = 3 }, .f32).eqlWithTags(
Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).mergeAxis(.a, .{ toTag(.a1), toTag(.a2) }),
),
);
try testing.expect(
Shape.init(.{ .a = 10, .b = 3 }, .f32).eqlWithTags(
Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).mergeAxis(toTag(.a), @as([]const Tag, &.{ toTag(.a1), toTag(.a2) })),
),
);
try testing.expect(
Shape.init(.{ .a = 10, .b = 3 }, .f32).eqlWithTags(
Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).mergeAxis(.a, @as([]const usize, &.{ 0, 1 })),
),
);
}
pub fn mergeAxes(self: Shape, axes_: anytype) Shape {
const T = @TypeOf(axes_);
stdx.debug.assertComptime(stdx.meta.isStruct(T), "Must pass struct of enum literals like .{ .a = .{ .a1, .a2 } }. Passed: {any}", .{T});
var res = self;
inline for (std.meta.fields(T)) |field| {
stdx.debug.assertComptime(stdx.meta.isTupleOfAny(field.type, isAxisConvertible) or stdx.meta.isSliceOfAny(field.type, isAxisConvertible), "Must pass struct of axes. Passed: {any}", .{field.type});
res = res.mergeAxis(field, @field(axes_, field.name));
}
return res;
}
test mergeAxes {
try testing.expect(
Shape.init(.{ .a = 10, .b = 3 }, .f32).eqlWithTags(
Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).mergeAxes(.{ .a = .{ .a1, .a2 } }),
),
);
try testing.expect(
Shape.init(.{ .a = 10, .b = 3 }, .f32).eqlWithTags(
Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).mergeAxes(.{ .a = .{ toTag(.a1), toTag(.a2) } }),
),
);
try testing.expect(
Shape.init(.{ .a = 10, .b = 3 }, .f32).eqlWithTags(
Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).mergeAxes(.{ .a = .{ 0, 1 } }),
),
);
try testing.expect(
Shape.init(.{ .a = 10, .b = 3 }, .f32).eqlWithTags(
Shape.init(.{ .a1 = 5, .a2 = 2, .b = 3 }, .f32).mergeAxes(.{ .a = @as([]const usize, &.{ 0, 1 }) }),
),
);
}
fn intersectTags(a: []const Tag, b: []const Tag) TagsArray {
var res = TagsArray.init(0) catch unreachable;
for (a) |tag_| {
if (std.mem.indexOfScalar(Tag, b, tag_)) {
res.appendAssumeCapacity(tag_);
}
}
return res;
}
pub fn parseStruct(T: type, v: anytype) struct { stdx.BoundedArray(T, MAX_RANK), TagsArray } {
const V = @TypeOf(v);
var vals_: stdx.BoundedArray(T, MAX_RANK) = .{};
var tags_: TagsArray = .{};
if (comptime stdx.meta.isSliceOf(V, T)) {
for (v) |d| {
vals_.appendAssumeCapacity(d);
}
return .{ vals_, tags_ };
}
if (comptime stdx.meta.isStruct(V)) {
const fields = std.meta.fields(V);
stdx.debug.assertComptime(fields.len <= MAX_RANK, "Too many fields in struct {} ({d}). Max supported is {d}.", .{ V, fields.len, MAX_RANK });
inline for (fields) |field| {
const fv = @field(v, field.name);
vals_.appendAssumeCapacity(fv);
if (!comptime stdx.meta.isTuple(V)) {
tags_.appendAssumeCapacity(toTag(field));
}
}
return .{ vals_, tags_ };
}
stdx.debug.compileError("parseStruct expects struct or tuple, got {}", .{V});
}
test parseStruct {
const vals_, const tags_ = parseStruct(f32, .{ .a = 0.1, .b = 1.2 });
try testing.expectEqualSlices(f32, &.{ 0.1, 1.2 }, vals_.constSlice());
try testing.expectEqualSlices(Tag, &.{ "a".ptr, "b".ptr }, tags_.constSlice());
}
/// Parses a struct literal into a list of options for each axes.
pub fn parseAxesOptions(self: Shape, T: type, options: anytype, default: T) stdx.BoundedArray(T, MAX_RANK) {
const V = @TypeOf(options);
var res: stdx.BoundedArray(T, MAX_RANK) = .{};
if (comptime stdx.meta.isSliceOf(V, T)) {
stdx.debug.assert(options.len == self.rank(), "expects exactly {} options in slice, for {} got {}", .{ self.rank(), self, options.len });
for (options) |d| {
res.appendAssumeCapacity(d);
}
}
if (comptime stdx.meta.isStruct(V)) {
for (0..self.rank()) |_| res.appendAssumeCapacity(default);
const fields = std.meta.fields(V);
stdx.debug.assertComptime(fields.len <= MAX_RANK, "expects up to {} options struct literal, got {}", .{ V, MAX_RANK, fields.len });
inline for (fields) |field| {
const a = self.axis(field);
res.buffer[a] = @field(options, field.name);
}
return res;
}
stdx.debug.compileError("parseStruct expects struct or tuple, got {}", .{V});
}
test parseAxesOptions {
const shape = Shape.init(.{ .a = 10, .b = 20, .c = 30 }, .u8);
const scaling = shape.parseAxesOptions(f32, .{ .b = 1.2, .a = 0.1 }, 1.0);
try testing.expectEqualSlices(f32, &.{ 0.1, 1.2, 1.0 }, scaling.constSlice());
}
test "comptimeShape" {
comptime {
const s = Shape.init(.{ .a = 5, .b = 6, .c = 7 }, .f32);
try std.testing.expectEqual(3, s.rank());
try std.testing.expectEqual(4 * 5 * 6 * 7, s.byteSize());
try std.testing.expectEqual(1, s.axis(.b));
}
// comptime only the shape
{
const s = comptime Shape.init(.{ .a = 5, .b = 6, .c = 7 }, .f32);
try std.testing.expectEqual(3, s.rank());
try std.testing.expectEqual(4 * 5 * 6 * 7, s.byteSize());
try std.testing.expectEqual(1, s.axis(.b));
}
}
pub fn outer(self: Shape, other: Shape) Shape {
var res_shape = self;
var batching_axes: u8 = 0;
for (0..other.rank()) |ax| {
if (other.tag(ax) != Shape.TagUnknown) {
if (self.hasTag(other.tag(ax))) |batching_ax| {
stdx.debug.assert(batching_ax == batching_axes and batching_ax == ax, "outer expects batching dims to be the first dims in both tensors, got outer({f}, {f})", .{ self, other });
batching_axes += 1;
}
}
res_shape = res_shape.appendDim(other.dim(ax), other.tag(ax));
}
return res_shape;
}
};