Set default values for operand_layouts and result_layouts in StableHLO dialect.

This commit is contained in:
Tarry Singh 2024-12-26 09:29:45 +00:00
parent e6286b6097
commit c961d705f1

View File

@ -786,12 +786,32 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
attrs.appendAssumeCapacity(.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) }); attrs.appendAssumeCapacity(.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) });
} }
const MINOR_TO_MAJOR = blk: {
const MAX_RANK = 8;
var ret: [MAX_RANK]usize = undefined;
for (0..MAX_RANK) |i| {
ret[i] = @intCast(MAX_RANK - i - 1);
}
break :blk ret;
};
if (opts.operand_layouts) |layouts| { if (opts.operand_layouts) |layouts| {
var operand_layouts: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; var operand_layouts: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
for (layouts) |ol| { for (layouts) |ol| {
operand_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol)); operand_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol));
} }
attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) }); attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) });
} else {
const operand_layouts = blk: {
var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
for (inputs) |input| {
const ranked_type = input.getType().as(mlir.RankedTensorType);
const ol = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_type.getRank() ..];
ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol));
}
break :blk ret;
};
attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) });
} }
if (opts.result_layouts) |layouts| { if (opts.result_layouts) |layouts| {
@ -800,6 +820,17 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
result_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl)); result_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl));
} }
attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) }); attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) });
} else {
const result_layouts = blk: {
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
for (res_types) |t| {
const ranked_t = t.as(mlir.RankedTensorType);
const rl = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_t.getRank() ..];
ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl));
}
break :blk ret;
};
attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) });
} }
attrs.appendSlice(opts.additional_attributes) catch @panic("Too many additional_attributes"); attrs.appendSlice(opts.additional_attributes) catch @panic("Too many additional_attributes");