stablehlo: extend dot_general API to include DotAlgorithm support by merging precision and algorithm attributes into a union, aligning with spec requirements. Currently not exposed to users due to limited algorithm support.
This commit is contained in:
parent
6d720126ac
commit
344e07fb6e
@ -112,24 +112,84 @@ pub fn clamp(ctx: mlir.Context, min: mlir.Value, value: mlir.Value, max: mlir.Va
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub const DotPrecision = union(enum) {
|
||||||
|
fast,
|
||||||
|
high,
|
||||||
|
highest,
|
||||||
|
algorithm: DotAlgorithm,
|
||||||
|
|
||||||
|
pub fn precisionAttr(self: DotPrecision, ctx: mlir.Context) mlir.Attribute {
|
||||||
|
const precision = PrecisionAttribute.init(ctx, switch (self) {
|
||||||
|
.fast => .DEFAULT,
|
||||||
|
.high => .HIGH,
|
||||||
|
.highest => .HIGHEST,
|
||||||
|
// When we specify the dot algorithm, we should not specify the precision.
|
||||||
|
.algorithm => .DEFAULT,
|
||||||
|
});
|
||||||
|
return precision.as(mlir.Attribute).?;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.Type) ?mlir.Attribute {
|
||||||
|
return switch (self) {
|
||||||
|
.algorithm => |algo| algo.asAttr(ctx, operand_type),
|
||||||
|
else => null,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const DotAlgorithm = struct {
|
||||||
|
accumulation: mlir.FloatTypes,
|
||||||
|
// Note stablehlo distinguish between left/right component_count
|
||||||
|
// but all the supported algorithm have the same component_count on both side.
|
||||||
|
component_count: u8 = 1,
|
||||||
|
num_primitive_operations: u8 = 1,
|
||||||
|
allow_imprecise_accumulation: bool = false,
|
||||||
|
|
||||||
|
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
|
||||||
|
// not sure where this is available.
|
||||||
|
pub const bf16_6x: DotAlgorithm = .{
|
||||||
|
.operand = .bf16,
|
||||||
|
.accumulation = .f32,
|
||||||
|
.component_count = 1,
|
||||||
|
.num_primitive_operations = 6,
|
||||||
|
.allow_imprecise_accumulation = false,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, operand_type: mlir.Type) mlir.Attribute {
|
||||||
|
const tensor_type = operand_type.as(mlir.RankedTensorType) orelse @panic("dot_general expects RankedTensor as input");
|
||||||
|
const elem_type = tensor_type.getElementType();
|
||||||
|
|
||||||
|
return mlir.Attribute.wrap(c.stablehloDotAlgorithmGet(
|
||||||
|
ctx.inner(),
|
||||||
|
elem_type.inner(),
|
||||||
|
elem_type.inner(),
|
||||||
|
self.accumulation.asType(ctx).?.inner(),
|
||||||
|
self.component_count,
|
||||||
|
self.component_count,
|
||||||
|
self.num_primitive_operations,
|
||||||
|
self.allow_imprecise_accumulation,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// General matrix multiplication "a la Einstein sum"
|
/// General matrix multiplication "a la Einstein sum"
|
||||||
/// Note: stablehlo doesn't do type inference for dot_general
|
/// Note: stablehlo doesn't do type inference for dot_general
|
||||||
pub fn dot_general(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, result_type: mlir.Type, location: mlir.Location, opts: struct {
|
pub fn dot_general(
|
||||||
|
ctx: mlir.Context,
|
||||||
|
lhs: mlir.Value,
|
||||||
|
rhs: mlir.Value,
|
||||||
|
result_type: mlir.Type,
|
||||||
|
location: mlir.Location,
|
||||||
|
opts: struct {
|
||||||
lhs_batching_dimensions: []const i64,
|
lhs_batching_dimensions: []const i64,
|
||||||
rhs_batching_dimensions: []const i64,
|
rhs_batching_dimensions: []const i64,
|
||||||
lhs_contracting_dimensions: []const i64,
|
lhs_contracting_dimensions: []const i64,
|
||||||
rhs_contracting_dimensions: []const i64,
|
rhs_contracting_dimensions: []const i64,
|
||||||
precision: []const PrecisionAttribute.Precision,
|
precision: DotPrecision,
|
||||||
}) mlir.Operation {
|
},
|
||||||
var maxPrecisions: [10]mlir.Attribute = undefined;
|
) mlir.Operation {
|
||||||
for (opts.precision, 0..) |p, i| {
|
const precisions = [1]mlir.Attribute{opts.precision.precisionAttr(ctx)} ** 2;
|
||||||
maxPrecisions[i] = PrecisionAttribute.init(ctx, p).as(mlir.Attribute).?;
|
const attributes = [3]mlir.Operation.AttrTuple{
|
||||||
}
|
|
||||||
|
|
||||||
return mlir.Operation.make(ctx, "stablehlo.dot_general", .{
|
|
||||||
.operands = &.{ lhs, rhs },
|
|
||||||
.results = &.{result_type},
|
|
||||||
.attributes = &.{
|
|
||||||
.{
|
.{
|
||||||
"dot_dimension_numbers", DotDimensionNumbersAttribute.init(ctx, .{
|
"dot_dimension_numbers", DotDimensionNumbersAttribute.init(ctx, .{
|
||||||
.lhs_batching_dimensions = opts.lhs_batching_dimensions,
|
.lhs_batching_dimensions = opts.lhs_batching_dimensions,
|
||||||
@ -138,8 +198,15 @@ pub fn dot_general(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, result_t
|
|||||||
.rhs_contracting_dimensions = opts.rhs_contracting_dimensions,
|
.rhs_contracting_dimensions = opts.rhs_contracting_dimensions,
|
||||||
}).as(mlir.Attribute).?,
|
}).as(mlir.Attribute).?,
|
||||||
},
|
},
|
||||||
.{ "precision_config", mlir.ArrayAttribute.init(ctx, maxPrecisions[0..opts.precision.len]).as(mlir.Attribute).? },
|
.{ "precision_config", mlir.ArrayAttribute.init(ctx, &precisions).asAttr() },
|
||||||
},
|
// keep algorithm as the last attribute so we can omit it when it's not set.
|
||||||
|
.{ "algorithm", opts.precision.algorithmAttr(ctx, lhs.getType()) orelse undefined },
|
||||||
|
};
|
||||||
|
const n_attributes = if (opts.precision == .algorithm) attributes.len else attributes.len - 1;
|
||||||
|
return mlir.Operation.make(ctx, "stablehlo.dot_general", .{
|
||||||
|
.operands = &.{ lhs, rhs },
|
||||||
|
.results = &.{result_type},
|
||||||
|
.attributes = attributes[0..n_attributes],
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@ -434,15 +434,17 @@ pub const TypeAttribute = struct {
|
|||||||
.dump_fn = c.mlirAttributeDump,
|
.dump_fn = c.mlirAttributeDump,
|
||||||
.equal_fn = c.mlirAttributeEqual,
|
.equal_fn = c.mlirAttributeEqual,
|
||||||
});
|
});
|
||||||
const Self = TypeAttribute;
|
pub fn init(type_: Type) TypeAttribute {
|
||||||
|
return TypeAttribute.wrap(c.mlirTypeAttrGet(type_.inner()));
|
||||||
pub fn init(type_: Type) Self {
|
|
||||||
return Self.wrap(c.mlirTypeAttrGet(type_.inner()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn typ(self: Self) Type {
|
pub fn typ(self: TypeAttribute) Type {
|
||||||
return Type.wrap(c.mlirAttributeGetType(self.inner()));
|
return Type.wrap(c.mlirAttributeGetType(self.inner()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn asAttr(self: TypeAttribute) Attribute {
|
||||||
|
return self.as(Attribute).?;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const ArrayAttribute = struct {
|
pub const ArrayAttribute = struct {
|
||||||
@ -788,9 +790,9 @@ pub const Operation = struct {
|
|||||||
) orelse Error.InvalidMlir;
|
) orelse Error.InvalidMlir;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn make(ctx: Context, op_name: [:0]const u8, args: struct {
|
|
||||||
pub const AttrTuple = struct { [:0]const u8, Attribute };
|
pub const AttrTuple = struct { [:0]const u8, Attribute };
|
||||||
|
|
||||||
|
pub fn make(ctx: Context, op_name: [:0]const u8, args: struct {
|
||||||
operands: ?[]const Value = null,
|
operands: ?[]const Value = null,
|
||||||
variadic_operands: ?[]const []const Value = null,
|
variadic_operands: ?[]const []const Value = null,
|
||||||
results: ?[]const Type = null,
|
results: ?[]const Type = null,
|
||||||
@ -1301,6 +1303,13 @@ pub const FloatTypes = enum {
|
|||||||
f64,
|
f64,
|
||||||
|
|
||||||
unknown,
|
unknown,
|
||||||
|
|
||||||
|
pub fn asType(self: FloatTypes, ctx: Context) ?Type {
|
||||||
|
return switch (self) {
|
||||||
|
.unknown => null,
|
||||||
|
inline else => |ft| FloatType(ft).init(ctx).asType(),
|
||||||
|
};
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn FloatType(comptime ft: FloatTypes) type {
|
pub fn FloatType(comptime ft: FloatTypes) type {
|
||||||
@ -1353,6 +1362,10 @@ pub fn FloatType(comptime ft: FloatTypes) type {
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn asType(self: Float) Type {
|
||||||
|
return self.as(Type).?;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1163,19 +1163,20 @@ pub const Tensor = struct {
|
|||||||
res_shape = res_shape.appendDim(rhs._shape.dim(r), rhs._shape.tag(r));
|
res_shape = res_shape.appendDim(rhs._shape.dim(r), rhs._shape.tag(r));
|
||||||
}
|
}
|
||||||
|
|
||||||
const loc = lhs.getContext().mlirCtx().location(@src());
|
const mlir_ctx = lhs.getContext().mlirCtx();
|
||||||
|
const loc = mlir_ctx.location(@src());
|
||||||
const op = dialect.stablehlo.dot_general(
|
const op = dialect.stablehlo.dot_general(
|
||||||
lhs.getContext().mlirCtx(),
|
mlir_ctx,
|
||||||
lhs.value(),
|
lhs.value(),
|
||||||
rhs.value(),
|
rhs.value(),
|
||||||
mlir.ext.mlirType(lhs.getContext().mlirCtx(), res_shape),
|
mlir.ext.mlirType(mlir_ctx, res_shape),
|
||||||
loc,
|
loc,
|
||||||
.{
|
.{
|
||||||
.lhs_batching_dimensions = lhs_batching_axes.constSlice(),
|
.lhs_batching_dimensions = lhs_batching_axes.constSlice(),
|
||||||
.rhs_batching_dimensions = rhs_batching_axes.constSlice(),
|
.rhs_batching_dimensions = rhs_batching_axes.constSlice(),
|
||||||
.lhs_contracting_dimensions = lhs_contracting_axes.constSlice(),
|
.lhs_contracting_dimensions = lhs_contracting_axes.constSlice(),
|
||||||
.rhs_contracting_dimensions = rhs_contracting_axes.constSlice(),
|
.rhs_contracting_dimensions = rhs_contracting_axes.constSlice(),
|
||||||
.precision = &.{ .DEFAULT, .DEFAULT },
|
.precision = .fast,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
return _result(res_shape, op.result(0));
|
return _result(res_shape, op.result(0));
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user