From 344e07fb6eab43d44af214223cb36d7153a4fd6e Mon Sep 17 00:00:00 2001
From: Tarry Singh
Date: Wed, 7 Jun 2023 11:20:25 +0000
Subject: [PATCH] stablehlo: extend dot_general API to include DotAlgorithm
 support by merging precision and algorithm attributes into a union, aligning
 with spec requirements. Currently not exposed to users due to limited
 algorithm support.

---
 mlir/dialects/stablehlo.zig | 113 ++++++++++++++++++++++++++++--------
 mlir/mlir.zig               |  27 ++++++---
 zml/tensor.zig              |   9 +--
 3 files changed, 115 insertions(+), 34 deletions(-)

diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig
index 0c54835..bb8946d 100644
--- a/mlir/dialects/stablehlo.zig
+++ b/mlir/dialects/stablehlo.zig
@@ -112,34 +112,101 @@ pub fn clamp(ctx: mlir.Context, min: mlir.Value, value: mlir.Value, max: mlir.Va
     });
 }
 
-/// General matrix multiplication "a la Einstein sum"
-/// Note: stablehlo doesn't do type inference for dot_general
-pub fn dot_general(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, result_type: mlir.Type, location: mlir.Location, opts: struct {
-    lhs_batching_dimensions: []const i64,
-    rhs_batching_dimensions: []const i64,
-    lhs_contracting_dimensions: []const i64,
-    rhs_contracting_dimensions: []const i64,
-    precision: []const PrecisionAttribute.Precision,
-}) mlir.Operation {
-    var maxPrecisions: [10]mlir.Attribute = undefined;
-    for (opts.precision, 0..) |p, i| {
-        maxPrecisions[i] = PrecisionAttribute.init(ctx, p).as(mlir.Attribute).?;
+pub const DotPrecision = union(enum) {
+    fast,
+    high,
+    highest,
+    algorithm: DotAlgorithm,
+
+    pub fn precisionAttr(self: DotPrecision, ctx: mlir.Context) mlir.Attribute {
+        const precision = PrecisionAttribute.init(ctx, switch (self) {
+            .fast => .DEFAULT,
+            .high => .HIGH,
+            .highest => .HIGHEST,
+            // When we specify the dot algorithm, we should not specify the precision.
+            .algorithm => .DEFAULT,
+        });
+        return precision.as(mlir.Attribute).?;
     }
+    pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.Type) ?mlir.Attribute {
+        return switch (self) {
+            .algorithm => |algo| algo.asAttr(ctx, operand_type),
+            else => null,
+        };
+    }
+};
+
+pub const DotAlgorithm = struct {
+    accumulation: mlir.FloatTypes,
+    // Note: stablehlo distinguishes between left/right component_count,
+    // but all the supported algorithms have the same component_count on both sides.
+    component_count: u8 = 1,
+    num_primitive_operations: u8 = 1,
+    allow_imprecise_accumulation: bool = false,
+
+    // bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
+    // not sure where this is available.
+    pub const bf16_6x: DotAlgorithm = .{
+        // The operand precision is implied by the input element type (see asAttr).
+        .accumulation = .f32,
+        .component_count = 1,
+        .num_primitive_operations = 6,
+        .allow_imprecise_accumulation = false,
+    };
+
+    pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, operand_type: mlir.Type) mlir.Attribute {
+        const tensor_type = operand_type.as(mlir.RankedTensorType) orelse @panic("dot_general expects RankedTensor as input");
+        const elem_type = tensor_type.getElementType();
+
+        return mlir.Attribute.wrap(c.stablehloDotAlgorithmGet(
+            ctx.inner(),
+            elem_type.inner(),
+            elem_type.inner(),
+            self.accumulation.asType(ctx).?.inner(),
+            self.component_count,
+            self.component_count,
+            self.num_primitive_operations,
+            self.allow_imprecise_accumulation,
+        ));
+    }
+};
+
+/// General matrix multiplication "a la Einstein sum"
+/// Note: stablehlo doesn't do type inference for dot_general
+pub fn dot_general(
+    ctx: mlir.Context,
+    lhs: mlir.Value,
+    rhs: mlir.Value,
+    result_type: mlir.Type,
+    location: mlir.Location,
+    opts: struct {
+        lhs_batching_dimensions: []const i64,
+        rhs_batching_dimensions: []const i64,
+        lhs_contracting_dimensions: []const i64,
+        rhs_contracting_dimensions: []const i64,
+        precision: DotPrecision,
+    },
+) mlir.Operation {
+    const precisions = [1]mlir.Attribute{opts.precision.precisionAttr(ctx)} ** 2;
+    const attributes = [3]mlir.Operation.AttrTuple{
+        .{
+            "dot_dimension_numbers", DotDimensionNumbersAttribute.init(ctx, .{
+                .lhs_batching_dimensions = opts.lhs_batching_dimensions,
+                .rhs_batching_dimensions = opts.rhs_batching_dimensions,
+                .lhs_contracting_dimensions = opts.lhs_contracting_dimensions,
+                .rhs_contracting_dimensions = opts.rhs_contracting_dimensions,
+            }).as(mlir.Attribute).?,
+        },
+        .{ "precision_config", mlir.ArrayAttribute.init(ctx, &precisions).asAttr() },
+        // keep algorithm as the last attribute so we can omit it when it's not set.
+        .{ "algorithm", opts.precision.algorithmAttr(ctx, lhs.getType()) orelse undefined },
+    };
+    const n_attributes = if (opts.precision == .algorithm) attributes.len else attributes.len - 1;
     return mlir.Operation.make(ctx, "stablehlo.dot_general", .{
         .operands = &.{ lhs, rhs },
         .results = &.{result_type},
-        .attributes = &.{
-            .{
-                "dot_dimension_numbers", DotDimensionNumbersAttribute.init(ctx, .{
-                    .lhs_batching_dimensions = opts.lhs_batching_dimensions,
-                    .rhs_batching_dimensions = opts.rhs_batching_dimensions,
-                    .lhs_contracting_dimensions = opts.lhs_contracting_dimensions,
-                    .rhs_contracting_dimensions = opts.rhs_contracting_dimensions,
-                }).as(mlir.Attribute).?,
-            },
-            .{ "precision_config", mlir.ArrayAttribute.init(ctx, maxPrecisions[0..opts.precision.len]).as(mlir.Attribute).? },
-        },
+        .attributes = attributes[0..n_attributes],
         .location = location,
     });
 }
 
diff --git a/mlir/mlir.zig b/mlir/mlir.zig
index 8383c67..de5087a 100644
--- a/mlir/mlir.zig
+++ b/mlir/mlir.zig
@@ -434,15 +434,17 @@ pub const TypeAttribute = struct {
         .dump_fn = c.mlirAttributeDump,
         .equal_fn = c.mlirAttributeEqual,
     });
-    const Self = TypeAttribute;
-
-    pub fn init(type_: Type) Self {
-        return Self.wrap(c.mlirTypeAttrGet(type_.inner()));
+    pub fn init(type_: Type) TypeAttribute {
+        return TypeAttribute.wrap(c.mlirTypeAttrGet(type_.inner()));
     }
 
-    pub fn typ(self: Self) Type {
+    pub fn typ(self: TypeAttribute) Type {
         return Type.wrap(c.mlirAttributeGetType(self.inner()));
     }
+
+    pub fn asAttr(self: TypeAttribute) Attribute {
+        return self.as(Attribute).?;
+    }
 };
 
 pub const ArrayAttribute = struct {
@@ -788,9 +790,9 @@ pub const Operation = struct {
         ) orelse Error.InvalidMlir;
     }
 
-    pub fn make(ctx: Context, op_name: [:0]const u8, args: struct {
-        pub const AttrTuple = struct { [:0]const u8, Attribute };
+    pub const AttrTuple = struct { [:0]const u8, Attribute };
 
+    pub fn make(ctx: Context, op_name: [:0]const u8, args: struct {
         operands: ?[]const Value = null,
         variadic_operands: ?[]const []const Value = null,
         results: ?[]const Type = null,
@@ -1301,6 +1303,13 @@ pub const FloatTypes = enum {
     f32,
     f64,
     unknown,
+
+    pub fn asType(self: FloatTypes, ctx: Context) ?Type {
+        return switch (self) {
+            .unknown => null,
+            inline else => |ft| FloatType(ft).init(ctx).asType(),
+        };
+    }
 };
 
 pub fn FloatType(comptime ft: FloatTypes) type {
@@ -1353,6 +1362,10 @@ pub fn FloatType(comptime ft: FloatTypes) type {
             }
             return false;
         }
+
+        pub fn asType(self: Float) Type {
+            return self.as(Type).?;
+        }
     };
 }
 
diff --git a/zml/tensor.zig b/zml/tensor.zig
index 82e5b6b..554aa4b 100644
--- a/zml/tensor.zig
+++ b/zml/tensor.zig
@@ -1163,19 +1163,20 @@ pub const Tensor = struct {
             res_shape = res_shape.appendDim(rhs._shape.dim(r), rhs._shape.tag(r));
         }
 
-        const loc = lhs.getContext().mlirCtx().location(@src());
+        const mlir_ctx = lhs.getContext().mlirCtx();
+        const loc = mlir_ctx.location(@src());
         const op = dialect.stablehlo.dot_general(
-            lhs.getContext().mlirCtx(),
+            mlir_ctx,
             lhs.value(),
             rhs.value(),
-            mlir.ext.mlirType(lhs.getContext().mlirCtx(), res_shape),
+            mlir.ext.mlirType(mlir_ctx, res_shape),
             loc,
             .{
                 .lhs_batching_dimensions = lhs_batching_axes.constSlice(),
                 .rhs_batching_dimensions = rhs_batching_axes.constSlice(),
                 .lhs_contracting_dimensions = lhs_contracting_axes.constSlice(),
                 .rhs_contracting_dimensions = rhs_contracting_axes.constSlice(),
-                .precision = &.{ .DEFAULT, .DEFAULT },
+                .precision = .fast,
             },
         );
         return _result(res_shape, op.result(0));