stablehlo: extend dot_general API to include DotAlgorithm support by merging precision and algorithm attributes into a union, aligning with spec requirements. Currently not exposed to users due to limited algorithm support.

This commit is contained in:
Tarry Singh 2023-06-07 11:20:25 +00:00
parent 6d720126ac
commit 344e07fb6e
3 changed files with 115 additions and 34 deletions

View File

@ -112,34 +112,101 @@ pub fn clamp(ctx: mlir.Context, min: mlir.Value, value: mlir.Value, max: mlir.Va
}); });
} }
/// General matrix multiplication "a la Einstein sum" pub const DotPrecision = union(enum) {
/// Note: stablehlo doesn't do type inference for dot_general fast,
pub fn dot_general(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, result_type: mlir.Type, location: mlir.Location, opts: struct { high,
lhs_batching_dimensions: []const i64, highest,
rhs_batching_dimensions: []const i64, algorithm: DotAlgorithm,
lhs_contracting_dimensions: []const i64,
rhs_contracting_dimensions: []const i64, pub fn precisionAttr(self: DotPrecision, ctx: mlir.Context) mlir.Attribute {
precision: []const PrecisionAttribute.Precision, const precision = PrecisionAttribute.init(ctx, switch (self) {
}) mlir.Operation { .fast => .DEFAULT,
var maxPrecisions: [10]mlir.Attribute = undefined; .high => .HIGH,
for (opts.precision, 0..) |p, i| { .highest => .HIGHEST,
maxPrecisions[i] = PrecisionAttribute.init(ctx, p).as(mlir.Attribute).?; // When we specify the dot algorithm, we should not specify the precision.
.algorithm => .DEFAULT,
});
return precision.as(mlir.Attribute).?;
} }
/// Returns the `#stablehlo.dot_algorithm` attribute when an explicit algorithm
/// was requested, or null for the plain precision variants (fast/high/highest).
pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.Type) ?mlir.Attribute {
    // Only the .algorithm variant carries a DotAlgorithm payload.
    if (self != .algorithm) return null;
    return self.algorithm.asAttr(ctx, operand_type);
}
};
/// Description of a StableHLO dot algorithm (see the `dot_general` spec).
/// The lhs/rhs precision types are deliberately not stored here: `asAttr`
/// derives both from the element type of the actual operand tensor.
pub const DotAlgorithm = struct {
    /// Float type used to accumulate partial products.
    accumulation: mlir.FloatTypes,
    // Note: stablehlo distinguishes between left/right component_count,
    // but all the supported algorithms use the same component_count on both sides.
    component_count: u8 = 1,
    num_primitive_operations: u8 = 1,
    allow_imprecise_accumulation: bool = false,

    // bf16_6x: each input is decomposed into 3 bf16 components, then 6 dot
    // operations are done on those components, and the result is accumulated in f32.
    // (Per the StableHLO presets the decomposition is encoded via
    // num_primitive_operations, not component_count.)
    // NOTE(review): unclear which backends actually support this — confirm before exposing.
    pub const bf16_6x: DotAlgorithm = .{
        // BUGFIX: the original initializer set `.operand = .bf16`, but this
        // struct has no `operand` field (the operand precision type comes from
        // the tensor element type in `asAttr`), which is a compile error.
        .accumulation = .f32,
        .component_count = 1,
        .num_primitive_operations = 6,
        .allow_imprecise_accumulation = false,
    };

    /// Builds the `#stablehlo.dot_algorithm` attribute for `dot_general`.
    /// Panics if `operand_type` is not a ranked tensor type.
    pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, operand_type: mlir.Type) mlir.Attribute {
        const tensor_type = operand_type.as(mlir.RankedTensorType) orelse @panic("dot_general expects RankedTensor as input");
        const elem_type = tensor_type.getElementType();
        return mlir.Attribute.wrap(c.stablehloDotAlgorithmGet(
            ctx.inner(),
            // lhs and rhs precision types both match the operand element type.
            elem_type.inner(),
            elem_type.inner(),
            self.accumulation.asType(ctx).?.inner(),
            // Same component count on both sides (see field comment above).
            self.component_count,
            self.component_count,
            self.num_primitive_operations,
            self.allow_imprecise_accumulation,
        ));
    }
};
/// General matrix multiplication "a la Einstein sum"
/// Note: stablehlo doesn't do type inference for dot_general
pub fn dot_general(
ctx: mlir.Context,
lhs: mlir.Value,
rhs: mlir.Value,
result_type: mlir.Type,
location: mlir.Location,
opts: struct {
lhs_batching_dimensions: []const i64,
rhs_batching_dimensions: []const i64,
lhs_contracting_dimensions: []const i64,
rhs_contracting_dimensions: []const i64,
precision: DotPrecision,
},
) mlir.Operation {
const precisions = [1]mlir.Attribute{opts.precision.precisionAttr(ctx)} ** 2;
const attributes = [3]mlir.Operation.AttrTuple{
.{
"dot_dimension_numbers", DotDimensionNumbersAttribute.init(ctx, .{
.lhs_batching_dimensions = opts.lhs_batching_dimensions,
.rhs_batching_dimensions = opts.rhs_batching_dimensions,
.lhs_contracting_dimensions = opts.lhs_contracting_dimensions,
.rhs_contracting_dimensions = opts.rhs_contracting_dimensions,
}).as(mlir.Attribute).?,
},
.{ "precision_config", mlir.ArrayAttribute.init(ctx, &precisions).asAttr() },
// keep algorithm as the last attribute so we can omit it when it's not set.
.{ "algorithm", opts.precision.algorithmAttr(ctx, lhs.getType()) orelse undefined },
};
const n_attributes = if (opts.precision == .algorithm) attributes.len else attributes.len - 1;
return mlir.Operation.make(ctx, "stablehlo.dot_general", .{ return mlir.Operation.make(ctx, "stablehlo.dot_general", .{
.operands = &.{ lhs, rhs }, .operands = &.{ lhs, rhs },
.results = &.{result_type}, .results = &.{result_type},
.attributes = &.{ .attributes = attributes[0..n_attributes],
.{
"dot_dimension_numbers", DotDimensionNumbersAttribute.init(ctx, .{
.lhs_batching_dimensions = opts.lhs_batching_dimensions,
.rhs_batching_dimensions = opts.rhs_batching_dimensions,
.lhs_contracting_dimensions = opts.lhs_contracting_dimensions,
.rhs_contracting_dimensions = opts.rhs_contracting_dimensions,
}).as(mlir.Attribute).?,
},
.{ "precision_config", mlir.ArrayAttribute.init(ctx, maxPrecisions[0..opts.precision.len]).as(mlir.Attribute).? },
},
.location = location, .location = location,
}); });
} }

View File

@ -434,15 +434,17 @@ pub const TypeAttribute = struct {
.dump_fn = c.mlirAttributeDump, .dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual, .equal_fn = c.mlirAttributeEqual,
}); });
const Self = TypeAttribute; pub fn init(type_: Type) TypeAttribute {
return TypeAttribute.wrap(c.mlirTypeAttrGet(type_.inner()));
pub fn init(type_: Type) Self {
return Self.wrap(c.mlirTypeAttrGet(type_.inner()));
} }
pub fn typ(self: Self) Type { pub fn typ(self: TypeAttribute) Type {
return Type.wrap(c.mlirAttributeGetType(self.inner())); return Type.wrap(c.mlirAttributeGetType(self.inner()));
} }
/// Upcasts this TypeAttribute to a generic Attribute (cannot fail).
pub fn asAttr(self: TypeAttribute) Attribute {
    return self.as(Attribute) orelse unreachable;
}
}; };
pub const ArrayAttribute = struct { pub const ArrayAttribute = struct {
@ -788,9 +790,9 @@ pub const Operation = struct {
) orelse Error.InvalidMlir; ) orelse Error.InvalidMlir;
} }
pub fn make(ctx: Context, op_name: [:0]const u8, args: struct { pub const AttrTuple = struct { [:0]const u8, Attribute };
pub const AttrTuple = struct { [:0]const u8, Attribute };
pub fn make(ctx: Context, op_name: [:0]const u8, args: struct {
operands: ?[]const Value = null, operands: ?[]const Value = null,
variadic_operands: ?[]const []const Value = null, variadic_operands: ?[]const []const Value = null,
results: ?[]const Type = null, results: ?[]const Type = null,
@ -1301,6 +1303,13 @@ pub const FloatTypes = enum {
f64, f64,
unknown, unknown,
/// Converts this float-type tag into a concrete MLIR Type, or null for
/// `.unknown`, which has no MLIR equivalent.
pub fn asType(self: FloatTypes, ctx: Context) ?Type {
    // Guard first so the inline dispatch below only sees real float tags.
    if (self == .unknown) return null;
    return switch (self) {
        .unknown => unreachable,
        inline else => |ft| FloatType(ft).init(ctx).asType(),
    };
}
}; };
pub fn FloatType(comptime ft: FloatTypes) type { pub fn FloatType(comptime ft: FloatTypes) type {
@ -1353,6 +1362,10 @@ pub fn FloatType(comptime ft: FloatTypes) type {
} }
return false; return false;
} }
/// Upcasts this concrete float type to a generic Type (cannot fail).
pub fn asType(self: Float) Type {
    return self.as(Type) orelse unreachable;
}
}; };
} }

View File

@ -1163,19 +1163,20 @@ pub const Tensor = struct {
res_shape = res_shape.appendDim(rhs._shape.dim(r), rhs._shape.tag(r)); res_shape = res_shape.appendDim(rhs._shape.dim(r), rhs._shape.tag(r));
} }
const loc = lhs.getContext().mlirCtx().location(@src()); const mlir_ctx = lhs.getContext().mlirCtx();
const loc = mlir_ctx.location(@src());
const op = dialect.stablehlo.dot_general( const op = dialect.stablehlo.dot_general(
lhs.getContext().mlirCtx(), mlir_ctx,
lhs.value(), lhs.value(),
rhs.value(), rhs.value(),
mlir.ext.mlirType(lhs.getContext().mlirCtx(), res_shape), mlir.ext.mlirType(mlir_ctx, res_shape),
loc, loc,
.{ .{
.lhs_batching_dimensions = lhs_batching_axes.constSlice(), .lhs_batching_dimensions = lhs_batching_axes.constSlice(),
.rhs_batching_dimensions = rhs_batching_axes.constSlice(), .rhs_batching_dimensions = rhs_batching_axes.constSlice(),
.lhs_contracting_dimensions = lhs_contracting_axes.constSlice(), .lhs_contracting_dimensions = lhs_contracting_axes.constSlice(),
.rhs_contracting_dimensions = rhs_contracting_axes.constSlice(), .rhs_contracting_dimensions = rhs_contracting_axes.constSlice(),
.precision = &.{ .DEFAULT, .DEFAULT }, .precision = .fast,
}, },
); );
return _result(res_shape, op.result(0)); return _result(res_shape, op.result(0));