stablehlo: fix forwarding of optional operand and result layout attributes in custom call

This commit is contained in:
Tarry Singh 2024-07-31 17:53:18 +00:00
parent 44933c9b89
commit cbe6e730bd
2 changed files with 46 additions and 30 deletions

View File

@ -749,8 +749,8 @@ pub const CustomCallOpts = struct {
call_target_name: [:0]const u8,
has_side_effect: bool,
backend_config: BackendConfig = .{ .string = &.{} },
operand_layouts: []const []const usize = &.{},
result_layouts: []const []const usize = &.{},
operand_layouts: ?[]const []const usize = null,
result_layouts: ?[]const []const usize = null,
output_operand_aliases: []const i64 = &.{},
addional_attributes: []const mlir.AttrTuple = &.{},
api_version: ApiVersion,
@ -760,32 +760,6 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
const MAX_OPERANDS = 64;
const MAX_RESULTS = 16;
const operand_layouts = blk: {
var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
for (opts.operand_layouts) |ol| {
const tensor_type = mlir.RankedTensorType.init(
&.{@intCast(ol.len)},
mlir.IndexType.init(ctx).as(mlir.Type),
).as(mlir.Type);
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, ol);
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
}
break :blk ret;
};
const result_layouts = blk: {
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
for (opts.result_layouts) |rl| {
const tensor_type = mlir.RankedTensorType.init(
&.{@intCast(rl.len)},
mlir.IndexType.init(ctx).as(mlir.Type),
).as(mlir.Type);
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, rl);
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
}
break :blk ret;
};
const output_operand_aliases = blk: {
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
for (opts.output_operand_aliases) |alias| {
@ -821,15 +795,49 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
};
var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{};
attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, @intFromEnum(opts.api_version)).as(mlir.Attribute) },
.{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) },
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) },
.{ "backend_config", backend_config },
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases.constSlice()).as(mlir.Attribute) },
.{ "operand_layouts", mlir.ArrayAttribute.init(ctx, operand_layouts.constSlice()).as(mlir.Attribute) },
.{ "result_layouts", mlir.ArrayAttribute.init(ctx, result_layouts.constSlice()).as(mlir.Attribute) },
});
if (opts.operand_layouts) |layouts| {
const operand_layouts = blk: {
var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
for (layouts) |ol| {
const tensor_type = mlir.RankedTensorType.init(
&.{@intCast(ol.len)},
mlir.IndexType.init(ctx).as(mlir.Type),
).as(mlir.Type);
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, ol);
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
}
break :blk ret;
};
const attr: mlir.AttrTuple = .{ "operand_layouts", mlir.ArrayAttribute.init(ctx, operand_layouts.constSlice()).as(mlir.Attribute) };
attrs.appendAssumeCapacity(attr);
}
if (opts.result_layouts) |layouts| {
const result_layouts = blk: {
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
for (layouts) |rl| {
const tensor_type = mlir.RankedTensorType.init(
&.{@intCast(rl.len)},
mlir.IndexType.init(ctx).as(mlir.Type),
).as(mlir.Type);
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, rl);
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
}
break :blk ret;
};
const attr: mlir.AttrTuple = .{ "result_layouts", mlir.ArrayAttribute.init(ctx, result_layouts.constSlice()).as(mlir.Attribute) };
attrs.appendAssumeCapacity(attr);
}
attrs.appendSliceAssumeCapacity(opts.addional_attributes);
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{

View File

@ -9,11 +9,19 @@ platforms:
--cpu=darwin_arm64
--apple_platform_type=macos
@zml//platforms:macos_amd64
--cpu=darwin_amd64
--apple_platform_type=macos
flags:
--cpu=darwin_arm64
--apple_platform_type=macos
@zml//platforms:macos_arm64
--cpu=darwin_amd64
--apple_platform_type=macos
@zml//platforms:macos_amd64
--cpu=k8
@zml//platforms:linux_amd64