stablehlo: fix forwarding of optional operand and result layout attributes in custom call
This commit is contained in:
parent
44933c9b89
commit
cbe6e730bd
@ -749,8 +749,8 @@ pub const CustomCallOpts = struct {
|
|||||||
call_target_name: [:0]const u8,
|
call_target_name: [:0]const u8,
|
||||||
has_side_effect: bool,
|
has_side_effect: bool,
|
||||||
backend_config: BackendConfig = .{ .string = &.{} },
|
backend_config: BackendConfig = .{ .string = &.{} },
|
||||||
operand_layouts: []const []const usize = &.{},
|
operand_layouts: ?[]const []const usize = null,
|
||||||
result_layouts: []const []const usize = &.{},
|
result_layouts: ?[]const []const usize = null,
|
||||||
output_operand_aliases: []const i64 = &.{},
|
output_operand_aliases: []const i64 = &.{},
|
||||||
addional_attributes: []const mlir.AttrTuple = &.{},
|
addional_attributes: []const mlir.AttrTuple = &.{},
|
||||||
api_version: ApiVersion,
|
api_version: ApiVersion,
|
||||||
@ -760,32 +760,6 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
|||||||
const MAX_OPERANDS = 64;
|
const MAX_OPERANDS = 64;
|
||||||
const MAX_RESULTS = 16;
|
const MAX_RESULTS = 16;
|
||||||
|
|
||||||
const operand_layouts = blk: {
|
|
||||||
var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
|
|
||||||
for (opts.operand_layouts) |ol| {
|
|
||||||
const tensor_type = mlir.RankedTensorType.init(
|
|
||||||
&.{@intCast(ol.len)},
|
|
||||||
mlir.IndexType.init(ctx).as(mlir.Type),
|
|
||||||
).as(mlir.Type);
|
|
||||||
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, ol);
|
|
||||||
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
|
|
||||||
}
|
|
||||||
break :blk ret;
|
|
||||||
};
|
|
||||||
|
|
||||||
const result_layouts = blk: {
|
|
||||||
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
|
||||||
for (opts.result_layouts) |rl| {
|
|
||||||
const tensor_type = mlir.RankedTensorType.init(
|
|
||||||
&.{@intCast(rl.len)},
|
|
||||||
mlir.IndexType.init(ctx).as(mlir.Type),
|
|
||||||
).as(mlir.Type);
|
|
||||||
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, rl);
|
|
||||||
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
|
|
||||||
}
|
|
||||||
break :blk ret;
|
|
||||||
};
|
|
||||||
|
|
||||||
const output_operand_aliases = blk: {
|
const output_operand_aliases = blk: {
|
||||||
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
||||||
for (opts.output_operand_aliases) |alias| {
|
for (opts.output_operand_aliases) |alias| {
|
||||||
@ -821,15 +795,49 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
|||||||
};
|
};
|
||||||
|
|
||||||
var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{};
|
var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{};
|
||||||
|
|
||||||
attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{
|
attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{
|
||||||
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, @intFromEnum(opts.api_version)).as(mlir.Attribute) },
|
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, @intFromEnum(opts.api_version)).as(mlir.Attribute) },
|
||||||
.{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) },
|
.{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) },
|
||||||
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) },
|
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) },
|
||||||
.{ "backend_config", backend_config },
|
.{ "backend_config", backend_config },
|
||||||
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases.constSlice()).as(mlir.Attribute) },
|
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases.constSlice()).as(mlir.Attribute) },
|
||||||
.{ "operand_layouts", mlir.ArrayAttribute.init(ctx, operand_layouts.constSlice()).as(mlir.Attribute) },
|
|
||||||
.{ "result_layouts", mlir.ArrayAttribute.init(ctx, result_layouts.constSlice()).as(mlir.Attribute) },
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if (opts.operand_layouts) |layouts| {
|
||||||
|
const operand_layouts = blk: {
|
||||||
|
var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
|
||||||
|
for (layouts) |ol| {
|
||||||
|
const tensor_type = mlir.RankedTensorType.init(
|
||||||
|
&.{@intCast(ol.len)},
|
||||||
|
mlir.IndexType.init(ctx).as(mlir.Type),
|
||||||
|
).as(mlir.Type);
|
||||||
|
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, ol);
|
||||||
|
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
|
||||||
|
}
|
||||||
|
break :blk ret;
|
||||||
|
};
|
||||||
|
const attr: mlir.AttrTuple = .{ "operand_layouts", mlir.ArrayAttribute.init(ctx, operand_layouts.constSlice()).as(mlir.Attribute) };
|
||||||
|
attrs.appendAssumeCapacity(attr);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (opts.result_layouts) |layouts| {
|
||||||
|
const result_layouts = blk: {
|
||||||
|
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
||||||
|
for (layouts) |rl| {
|
||||||
|
const tensor_type = mlir.RankedTensorType.init(
|
||||||
|
&.{@intCast(rl.len)},
|
||||||
|
mlir.IndexType.init(ctx).as(mlir.Type),
|
||||||
|
).as(mlir.Type);
|
||||||
|
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, rl);
|
||||||
|
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
|
||||||
|
}
|
||||||
|
break :blk ret;
|
||||||
|
};
|
||||||
|
const attr: mlir.AttrTuple = .{ "result_layouts", mlir.ArrayAttribute.init(ctx, result_layouts.constSlice()).as(mlir.Attribute) };
|
||||||
|
attrs.appendAssumeCapacity(attr);
|
||||||
|
}
|
||||||
|
|
||||||
attrs.appendSliceAssumeCapacity(opts.addional_attributes);
|
attrs.appendSliceAssumeCapacity(opts.addional_attributes);
|
||||||
|
|
||||||
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
|
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
|
||||||
|
|||||||
@ -9,11 +9,19 @@ platforms:
|
|||||||
--cpu=darwin_arm64
|
--cpu=darwin_arm64
|
||||||
--apple_platform_type=macos
|
--apple_platform_type=macos
|
||||||
|
|
||||||
|
@zml//platforms:macos_amd64
|
||||||
|
--cpu=darwin_amd64
|
||||||
|
--apple_platform_type=macos
|
||||||
|
|
||||||
flags:
|
flags:
|
||||||
--cpu=darwin_arm64
|
--cpu=darwin_arm64
|
||||||
--apple_platform_type=macos
|
--apple_platform_type=macos
|
||||||
@zml//platforms:macos_arm64
|
@zml//platforms:macos_arm64
|
||||||
|
|
||||||
|
--cpu=darwin_amd64
|
||||||
|
--apple_platform_type=macos
|
||||||
|
@zml//platforms:macos_amd64
|
||||||
|
|
||||||
--cpu=k8
|
--cpu=k8
|
||||||
@zml//platforms:linux_amd64
|
@zml//platforms:linux_amd64
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user