zml/ops: add wiring for operand output alias in zml.ops.triton

This commit is contained in:
Tarry Singh 2024-09-09 15:00:28 +00:00
parent 7e0fcecfc9
commit 1f5ff96c10

View File

@ -819,6 +819,7 @@ pub const TritonOps = struct {
grid: [3]i32, grid: [3]i32,
num_stages: i32, num_stages: i32,
num_warps: i32, num_warps: i32,
output_operand_aliases: []const i64 = &.{},
}; };
/// Generate an MLIR call to the given member function with the given tensors. /// Generate an MLIR call to the given member function with the given tensors.
@ -877,6 +878,7 @@ pub fn triton(inputs: anytype, outputs: anytype, opts: TritonOps) [outputs.len]T
.api_version = .typed_ffi, .api_version = .typed_ffi,
.operand_layouts = &operands_layouts, .operand_layouts = &operands_layouts,
.result_layouts = &results_layouts, .result_layouts = &results_layouts,
.output_operand_aliases = opts.output_operand_aliases,
}, },
&res_types, &res_types,
ctx.mlirCtx().location(@src()), ctx.mlirCtx().location(@src()),