stablehlo: extend dot_general API to include DotAlgorithm support by merging precision and algorithm attributes into a union, aligning with spec requirements. Currently not exposed to users due to limited algorithm support.
This commit is contained in:
parent
6d720126ac
commit
344e07fb6e
@ -112,34 +112,101 @@ pub fn clamp(ctx: mlir.Context, min: mlir.Value, value: mlir.Value, max: mlir.Va
|
||||
});
|
||||
}
|
||||
|
||||
/// General matrix multiplication "a la Einstein sum"
|
||||
/// Note: stablehlo doesn't do type inference for dot_general
|
||||
pub fn dot_general(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, result_type: mlir.Type, location: mlir.Location, opts: struct {
|
||||
lhs_batching_dimensions: []const i64,
|
||||
rhs_batching_dimensions: []const i64,
|
||||
lhs_contracting_dimensions: []const i64,
|
||||
rhs_contracting_dimensions: []const i64,
|
||||
precision: []const PrecisionAttribute.Precision,
|
||||
}) mlir.Operation {
|
||||
var maxPrecisions: [10]mlir.Attribute = undefined;
|
||||
for (opts.precision, 0..) |p, i| {
|
||||
maxPrecisions[i] = PrecisionAttribute.init(ctx, p).as(mlir.Attribute).?;
|
||||
pub const DotPrecision = union(enum) {
|
||||
fast,
|
||||
high,
|
||||
highest,
|
||||
algorithm: DotAlgorithm,
|
||||
|
||||
pub fn precisionAttr(self: DotPrecision, ctx: mlir.Context) mlir.Attribute {
|
||||
const precision = PrecisionAttribute.init(ctx, switch (self) {
|
||||
.fast => .DEFAULT,
|
||||
.high => .HIGH,
|
||||
.highest => .HIGHEST,
|
||||
// When we specify the dot algorithm, we should not specify the precision.
|
||||
.algorithm => .DEFAULT,
|
||||
});
|
||||
return precision.as(mlir.Attribute).?;
|
||||
}
|
||||
|
||||
pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.Type) ?mlir.Attribute {
|
||||
return switch (self) {
|
||||
.algorithm => |algo| algo.asAttr(ctx, operand_type),
|
||||
else => null,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
pub const DotAlgorithm = struct {
|
||||
accumulation: mlir.FloatTypes,
|
||||
// Note stablehlo distinguish between left/right component_count
|
||||
// but all the supported algorithm have the same component_count on both side.
|
||||
component_count: u8 = 1,
|
||||
num_primitive_operations: u8 = 1,
|
||||
allow_imprecise_accumulation: bool = false,
|
||||
|
||||
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
|
||||
// not sure where this is available.
|
||||
pub const bf16_6x: DotAlgorithm = .{
|
||||
.operand = .bf16,
|
||||
.accumulation = .f32,
|
||||
.component_count = 1,
|
||||
.num_primitive_operations = 6,
|
||||
.allow_imprecise_accumulation = false,
|
||||
};
|
||||
|
||||
pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, operand_type: mlir.Type) mlir.Attribute {
|
||||
const tensor_type = operand_type.as(mlir.RankedTensorType) orelse @panic("dot_general expects RankedTensor as input");
|
||||
const elem_type = tensor_type.getElementType();
|
||||
|
||||
return mlir.Attribute.wrap(c.stablehloDotAlgorithmGet(
|
||||
ctx.inner(),
|
||||
elem_type.inner(),
|
||||
elem_type.inner(),
|
||||
self.accumulation.asType(ctx).?.inner(),
|
||||
self.component_count,
|
||||
self.component_count,
|
||||
self.num_primitive_operations,
|
||||
self.allow_imprecise_accumulation,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
/// General matrix multiplication "a la Einstein sum"
|
||||
/// Note: stablehlo doesn't do type inference for dot_general
|
||||
pub fn dot_general(
|
||||
ctx: mlir.Context,
|
||||
lhs: mlir.Value,
|
||||
rhs: mlir.Value,
|
||||
result_type: mlir.Type,
|
||||
location: mlir.Location,
|
||||
opts: struct {
|
||||
lhs_batching_dimensions: []const i64,
|
||||
rhs_batching_dimensions: []const i64,
|
||||
lhs_contracting_dimensions: []const i64,
|
||||
rhs_contracting_dimensions: []const i64,
|
||||
precision: DotPrecision,
|
||||
},
|
||||
) mlir.Operation {
|
||||
const precisions = [1]mlir.Attribute{opts.precision.precisionAttr(ctx)} ** 2;
|
||||
const attributes = [3]mlir.Operation.AttrTuple{
|
||||
.{
|
||||
"dot_dimension_numbers", DotDimensionNumbersAttribute.init(ctx, .{
|
||||
.lhs_batching_dimensions = opts.lhs_batching_dimensions,
|
||||
.rhs_batching_dimensions = opts.rhs_batching_dimensions,
|
||||
.lhs_contracting_dimensions = opts.lhs_contracting_dimensions,
|
||||
.rhs_contracting_dimensions = opts.rhs_contracting_dimensions,
|
||||
}).as(mlir.Attribute).?,
|
||||
},
|
||||
.{ "precision_config", mlir.ArrayAttribute.init(ctx, &precisions).asAttr() },
|
||||
// keep algorithm as the last attribute so we can omit it when it's not set.
|
||||
.{ "algorithm", opts.precision.algorithmAttr(ctx, lhs.getType()) orelse undefined },
|
||||
};
|
||||
const n_attributes = if (opts.precision == .algorithm) attributes.len else attributes.len - 1;
|
||||
return mlir.Operation.make(ctx, "stablehlo.dot_general", .{
|
||||
.operands = &.{ lhs, rhs },
|
||||
.results = &.{result_type},
|
||||
.attributes = &.{
|
||||
.{
|
||||
"dot_dimension_numbers", DotDimensionNumbersAttribute.init(ctx, .{
|
||||
.lhs_batching_dimensions = opts.lhs_batching_dimensions,
|
||||
.rhs_batching_dimensions = opts.rhs_batching_dimensions,
|
||||
.lhs_contracting_dimensions = opts.lhs_contracting_dimensions,
|
||||
.rhs_contracting_dimensions = opts.rhs_contracting_dimensions,
|
||||
}).as(mlir.Attribute).?,
|
||||
},
|
||||
.{ "precision_config", mlir.ArrayAttribute.init(ctx, maxPrecisions[0..opts.precision.len]).as(mlir.Attribute).? },
|
||||
},
|
||||
.attributes = attributes[0..n_attributes],
|
||||
.location = location,
|
||||
});
|
||||
}
|
||||
|
||||
@ -434,15 +434,17 @@ pub const TypeAttribute = struct {
|
||||
.dump_fn = c.mlirAttributeDump,
|
||||
.equal_fn = c.mlirAttributeEqual,
|
||||
});
|
||||
const Self = TypeAttribute;
|
||||
|
||||
pub fn init(type_: Type) Self {
|
||||
return Self.wrap(c.mlirTypeAttrGet(type_.inner()));
|
||||
pub fn init(type_: Type) TypeAttribute {
|
||||
return TypeAttribute.wrap(c.mlirTypeAttrGet(type_.inner()));
|
||||
}
|
||||
|
||||
pub fn typ(self: Self) Type {
|
||||
pub fn typ(self: TypeAttribute) Type {
|
||||
return Type.wrap(c.mlirAttributeGetType(self.inner()));
|
||||
}
|
||||
|
||||
pub fn asAttr(self: TypeAttribute) Attribute {
|
||||
return self.as(Attribute).?;
|
||||
}
|
||||
};
|
||||
|
||||
pub const ArrayAttribute = struct {
|
||||
@ -788,9 +790,9 @@ pub const Operation = struct {
|
||||
) orelse Error.InvalidMlir;
|
||||
}
|
||||
|
||||
pub fn make(ctx: Context, op_name: [:0]const u8, args: struct {
|
||||
pub const AttrTuple = struct { [:0]const u8, Attribute };
|
||||
pub const AttrTuple = struct { [:0]const u8, Attribute };
|
||||
|
||||
pub fn make(ctx: Context, op_name: [:0]const u8, args: struct {
|
||||
operands: ?[]const Value = null,
|
||||
variadic_operands: ?[]const []const Value = null,
|
||||
results: ?[]const Type = null,
|
||||
@ -1301,6 +1303,13 @@ pub const FloatTypes = enum {
|
||||
f64,
|
||||
|
||||
unknown,
|
||||
|
||||
pub fn asType(self: FloatTypes, ctx: Context) ?Type {
|
||||
return switch (self) {
|
||||
.unknown => null,
|
||||
inline else => |ft| FloatType(ft).init(ctx).asType(),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
pub fn FloatType(comptime ft: FloatTypes) type {
|
||||
@ -1353,6 +1362,10 @@ pub fn FloatType(comptime ft: FloatTypes) type {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
pub fn asType(self: Float) Type {
|
||||
return self.as(Type).?;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -1163,19 +1163,20 @@ pub const Tensor = struct {
|
||||
res_shape = res_shape.appendDim(rhs._shape.dim(r), rhs._shape.tag(r));
|
||||
}
|
||||
|
||||
const loc = lhs.getContext().mlirCtx().location(@src());
|
||||
const mlir_ctx = lhs.getContext().mlirCtx();
|
||||
const loc = mlir_ctx.location(@src());
|
||||
const op = dialect.stablehlo.dot_general(
|
||||
lhs.getContext().mlirCtx(),
|
||||
mlir_ctx,
|
||||
lhs.value(),
|
||||
rhs.value(),
|
||||
mlir.ext.mlirType(lhs.getContext().mlirCtx(), res_shape),
|
||||
mlir.ext.mlirType(mlir_ctx, res_shape),
|
||||
loc,
|
||||
.{
|
||||
.lhs_batching_dimensions = lhs_batching_axes.constSlice(),
|
||||
.rhs_batching_dimensions = rhs_batching_axes.constSlice(),
|
||||
.lhs_contracting_dimensions = lhs_contracting_axes.constSlice(),
|
||||
.rhs_contracting_dimensions = rhs_contracting_axes.constSlice(),
|
||||
.precision = &.{ .DEFAULT, .DEFAULT },
|
||||
.precision = .fast,
|
||||
},
|
||||
);
|
||||
return _result(res_shape, op.result(0));
|
||||
|
||||
Loading…
Reference in New Issue
Block a user