Set default values for operand_layouts and result_layouts in StableHLO dialect.
This commit is contained in:
parent
e6286b6097
commit
c961d705f1
@ -786,12 +786,32 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
|||||||
attrs.appendAssumeCapacity(.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) });
|
attrs.appendAssumeCapacity(.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const MINOR_TO_MAJOR = blk: {
|
||||||
|
const MAX_RANK = 8;
|
||||||
|
var ret: [MAX_RANK]usize = undefined;
|
||||||
|
for (0..MAX_RANK) |i| {
|
||||||
|
ret[i] = @intCast(MAX_RANK - i - 1);
|
||||||
|
}
|
||||||
|
break :blk ret;
|
||||||
|
};
|
||||||
|
|
||||||
if (opts.operand_layouts) |layouts| {
|
if (opts.operand_layouts) |layouts| {
|
||||||
var operand_layouts: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
|
var operand_layouts: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
|
||||||
for (layouts) |ol| {
|
for (layouts) |ol| {
|
||||||
operand_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol));
|
operand_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol));
|
||||||
}
|
}
|
||||||
attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) });
|
attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) });
|
||||||
|
} else {
|
||||||
|
const operand_layouts = blk: {
|
||||||
|
var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
|
||||||
|
for (inputs) |input| {
|
||||||
|
const ranked_type = input.getType().as(mlir.RankedTensorType);
|
||||||
|
const ol = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_type.getRank() ..];
|
||||||
|
ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol));
|
||||||
|
}
|
||||||
|
break :blk ret;
|
||||||
|
};
|
||||||
|
attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) });
|
||||||
}
|
}
|
||||||
|
|
||||||
if (opts.result_layouts) |layouts| {
|
if (opts.result_layouts) |layouts| {
|
||||||
@ -800,6 +820,17 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
|||||||
result_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl));
|
result_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl));
|
||||||
}
|
}
|
||||||
attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) });
|
attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) });
|
||||||
|
} else {
|
||||||
|
const result_layouts = blk: {
|
||||||
|
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
||||||
|
for (res_types) |t| {
|
||||||
|
const ranked_t = t.as(mlir.RankedTensorType);
|
||||||
|
const rl = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_t.getRank() ..];
|
||||||
|
ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl));
|
||||||
|
}
|
||||||
|
break :blk ret;
|
||||||
|
};
|
||||||
|
attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) });
|
||||||
}
|
}
|
||||||
|
|
||||||
attrs.appendSlice(opts.additional_attributes) catch @panic("Too many additional_attributes");
|
attrs.appendSlice(opts.additional_attributes) catch @panic("Too many additional_attributes");
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user