stablehlo: fix forwarding of optional operand and result layout attributes in custom call
This commit is contained in:
parent
44933c9b89
commit
cbe6e730bd
@ -749,8 +749,8 @@ pub const CustomCallOpts = struct {
|
||||
call_target_name: [:0]const u8,
|
||||
has_side_effect: bool,
|
||||
backend_config: BackendConfig = .{ .string = &.{} },
|
||||
operand_layouts: []const []const usize = &.{},
|
||||
result_layouts: []const []const usize = &.{},
|
||||
operand_layouts: ?[]const []const usize = null,
|
||||
result_layouts: ?[]const []const usize = null,
|
||||
output_operand_aliases: []const i64 = &.{},
|
||||
addional_attributes: []const mlir.AttrTuple = &.{},
|
||||
api_version: ApiVersion,
|
||||
@ -760,32 +760,6 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
||||
const MAX_OPERANDS = 64;
|
||||
const MAX_RESULTS = 16;
|
||||
|
||||
const operand_layouts = blk: {
|
||||
var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
|
||||
for (opts.operand_layouts) |ol| {
|
||||
const tensor_type = mlir.RankedTensorType.init(
|
||||
&.{@intCast(ol.len)},
|
||||
mlir.IndexType.init(ctx).as(mlir.Type),
|
||||
).as(mlir.Type);
|
||||
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, ol);
|
||||
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
|
||||
}
|
||||
break :blk ret;
|
||||
};
|
||||
|
||||
const result_layouts = blk: {
|
||||
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
||||
for (opts.result_layouts) |rl| {
|
||||
const tensor_type = mlir.RankedTensorType.init(
|
||||
&.{@intCast(rl.len)},
|
||||
mlir.IndexType.init(ctx).as(mlir.Type),
|
||||
).as(mlir.Type);
|
||||
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, rl);
|
||||
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
|
||||
}
|
||||
break :blk ret;
|
||||
};
|
||||
|
||||
const output_operand_aliases = blk: {
|
||||
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
||||
for (opts.output_operand_aliases) |alias| {
|
||||
@ -821,15 +795,49 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
||||
};
|
||||
|
||||
var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{};
|
||||
|
||||
attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{
|
||||
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, @intFromEnum(opts.api_version)).as(mlir.Attribute) },
|
||||
.{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) },
|
||||
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) },
|
||||
.{ "backend_config", backend_config },
|
||||
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases.constSlice()).as(mlir.Attribute) },
|
||||
.{ "operand_layouts", mlir.ArrayAttribute.init(ctx, operand_layouts.constSlice()).as(mlir.Attribute) },
|
||||
.{ "result_layouts", mlir.ArrayAttribute.init(ctx, result_layouts.constSlice()).as(mlir.Attribute) },
|
||||
});
|
||||
|
||||
if (opts.operand_layouts) |layouts| {
|
||||
const operand_layouts = blk: {
|
||||
var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
|
||||
for (layouts) |ol| {
|
||||
const tensor_type = mlir.RankedTensorType.init(
|
||||
&.{@intCast(ol.len)},
|
||||
mlir.IndexType.init(ctx).as(mlir.Type),
|
||||
).as(mlir.Type);
|
||||
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, ol);
|
||||
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
|
||||
}
|
||||
break :blk ret;
|
||||
};
|
||||
const attr: mlir.AttrTuple = .{ "operand_layouts", mlir.ArrayAttribute.init(ctx, operand_layouts.constSlice()).as(mlir.Attribute) };
|
||||
attrs.appendAssumeCapacity(attr);
|
||||
}
|
||||
|
||||
if (opts.result_layouts) |layouts| {
|
||||
const result_layouts = blk: {
|
||||
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
||||
for (layouts) |rl| {
|
||||
const tensor_type = mlir.RankedTensorType.init(
|
||||
&.{@intCast(rl.len)},
|
||||
mlir.IndexType.init(ctx).as(mlir.Type),
|
||||
).as(mlir.Type);
|
||||
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, rl);
|
||||
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
|
||||
}
|
||||
break :blk ret;
|
||||
};
|
||||
const attr: mlir.AttrTuple = .{ "result_layouts", mlir.ArrayAttribute.init(ctx, result_layouts.constSlice()).as(mlir.Attribute) };
|
||||
attrs.appendAssumeCapacity(attr);
|
||||
}
|
||||
|
||||
attrs.appendSliceAssumeCapacity(opts.addional_attributes);
|
||||
|
||||
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
|
||||
|
||||
@ -9,11 +9,19 @@ platforms:
|
||||
--cpu=darwin_arm64
|
||||
--apple_platform_type=macos
|
||||
|
||||
@zml//platforms:macos_amd64
|
||||
--cpu=darwin_amd64
|
||||
--apple_platform_type=macos
|
||||
|
||||
flags:
|
||||
--cpu=darwin_arm64
|
||||
--apple_platform_type=macos
|
||||
@zml//platforms:macos_arm64
|
||||
|
||||
--cpu=darwin_amd64
|
||||
--apple_platform_type=macos
|
||||
@zml//platforms:macos_amd64
|
||||
|
||||
--cpu=k8
|
||||
@zml//platforms:linux_amd64
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user