From cbe6e730bd34d33cafdb7d21e38b67665e435dc2 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Wed, 31 Jul 2024 17:53:18 +0000 Subject: [PATCH] stablehlo: fix forwarding of optional operand and result layout attributes in custom call --- mlir/dialects/stablehlo.zig | 68 +++++++++++++++++++++---------------- platform_mappings | 8 +++++ 2 files changed, 46 insertions(+), 30 deletions(-) diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig index 58ba5eb..117f809 100644 --- a/mlir/dialects/stablehlo.zig +++ b/mlir/dialects/stablehlo.zig @@ -749,8 +749,8 @@ pub const CustomCallOpts = struct { call_target_name: [:0]const u8, has_side_effect: bool, backend_config: BackendConfig = .{ .string = &.{} }, - operand_layouts: []const []const usize = &.{}, - result_layouts: []const []const usize = &.{}, + operand_layouts: ?[]const []const usize = null, + result_layouts: ?[]const []const usize = null, output_operand_aliases: []const i64 = &.{}, addional_attributes: []const mlir.AttrTuple = &.{}, api_version: ApiVersion, @@ -760,32 +760,6 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa const MAX_OPERANDS = 64; const MAX_RESULTS = 16; - const operand_layouts = blk: { - var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; - for (opts.operand_layouts) |ol| { - const tensor_type = mlir.RankedTensorType.init( - &.{@intCast(ol.len)}, - mlir.IndexType.init(ctx).as(mlir.Type), - ).as(mlir.Type); - const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, ol); - ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute)); - } - break :blk ret; - }; - - const result_layouts = blk: { - var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; - for (opts.result_layouts) |rl| { - const tensor_type = mlir.RankedTensorType.init( - &.{@intCast(rl.len)}, - mlir.IndexType.init(ctx).as(mlir.Type), - ).as(mlir.Type); - const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, rl); - ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute)); - } - break :blk ret; - }; - const output_operand_aliases = blk: { var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; for (opts.output_operand_aliases) |alias| { @@ -821,15 +795,49 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa }; var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{}; + attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{ .{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, @intFromEnum(opts.api_version)).as(mlir.Attribute) }, .{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) }, .{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) }, .{ "backend_config", backend_config }, .{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases.constSlice()).as(mlir.Attribute) }, - .{ "operand_layouts", mlir.ArrayAttribute.init(ctx, operand_layouts.constSlice()).as(mlir.Attribute) }, - .{ "result_layouts", mlir.ArrayAttribute.init(ctx, result_layouts.constSlice()).as(mlir.Attribute) }, }); + + if (opts.operand_layouts) |layouts| { + const operand_layouts = blk: { + var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; + for (layouts) |ol| { + const tensor_type = mlir.RankedTensorType.init( + &.{@intCast(ol.len)}, + mlir.IndexType.init(ctx).as(mlir.Type), + ).as(mlir.Type); + const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, ol); + ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute)); + } + break :blk ret; + }; + const attr: mlir.AttrTuple = .{ "operand_layouts", mlir.ArrayAttribute.init(ctx, operand_layouts.constSlice()).as(mlir.Attribute) }; + attrs.appendAssumeCapacity(attr); + } + + if (opts.result_layouts) |layouts| { + const result_layouts = blk: { + var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; + for (layouts) |rl| { + const tensor_type = mlir.RankedTensorType.init( + &.{@intCast(rl.len)}, + mlir.IndexType.init(ctx).as(mlir.Type), + ).as(mlir.Type); + const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, rl); + ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute)); + } + break :blk ret; + }; + const attr: mlir.AttrTuple = .{ "result_layouts", mlir.ArrayAttribute.init(ctx, result_layouts.constSlice()).as(mlir.Attribute) }; + attrs.appendAssumeCapacity(attr); + } + attrs.appendSliceAssumeCapacity(opts.addional_attributes); return mlir.Operation.make(ctx, "stablehlo.custom_call", .{ diff --git a/platform_mappings b/platform_mappings index 6fd7ceb..ed02329 100644 --- a/platform_mappings +++ b/platform_mappings @@ -9,11 +9,19 @@ platforms: --cpu=darwin_arm64 --apple_platform_type=macos + @zml//platforms:macos_amd64 + --cpu=darwin_amd64 + --apple_platform_type=macos + flags: --cpu=darwin_arm64 --apple_platform_type=macos @zml//platforms:macos_arm64 + --cpu=darwin_amd64 + --apple_platform_type=macos + @zml//platforms:macos_amd64 + --cpu=k8 @zml//platforms:linux_amd64