From c961d705f1eeab32b271ccb5c691bf243b44f3de Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Thu, 26 Dec 2024 09:29:45 +0000 Subject: [PATCH] Set default values for operand_layouts and result_layouts in StableHLO dialect. --- mlir/dialects/stablehlo.zig | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig index 6328321..24466dd 100644 --- a/mlir/dialects/stablehlo.zig +++ b/mlir/dialects/stablehlo.zig @@ -786,12 +786,32 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa attrs.appendAssumeCapacity(.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) }); } + const MINOR_TO_MAJOR = blk: { + const MAX_RANK = 8; + var ret: [MAX_RANK]usize = undefined; + for (0..MAX_RANK) |i| { + ret[i] = @intCast(MAX_RANK - i - 1); + } + break :blk ret; + }; + if (opts.operand_layouts) |layouts| { var operand_layouts: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; for (layouts) |ol| { operand_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol)); } attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) }); + } else { + const operand_layouts = blk: { + var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; + for (inputs) |input| { + const ranked_type = input.getType().as(mlir.RankedTensorType); + const ol = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_type.getRank() ..]; + ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol)); + } + break :blk ret; + }; + attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) }); } if (opts.result_layouts) |layouts| { @@ -800,6 +820,17 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa result_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl)); } attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) }); + } else { + const result_layouts = blk: { + var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; + for (res_types) |t| { + const ranked_t = t.as(mlir.RankedTensorType); + const rl = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_t.getRank() ..]; + ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl)); + } + break :blk ret; + }; + attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) }); } attrs.appendSlice(opts.additional_attributes) catch @panic("Too many additional_attributes");