Radix/mlir/dialects/scf.zig

62 lines
1.8 KiB
Zig
Raw Permalink Normal View History

const std = @import("std");
const mlir = @import("mlir");
/// Returns the function type of an `scf.for` body-builder callback.
///
/// The callback receives the MLIR context, the loop body block and the
/// caller-supplied `ExtraArgs` value, and returns the operation that
/// `@"for"` appends to that block (`@"for"` asserts it is an `scf.yield`).
pub fn ForBody(comptime ExtraArgs: type) type {
    const Callback = fn (ctx: mlir.Context, body_block: mlir.Block, extra_args: ExtraArgs) mlir.Operation;
    return Callback;
}
/// Iteration bounds for `@"for"`: the start, end and step values that become
/// the first three operands of the generated `scf.for` operation.
pub const ForRange = struct {
/// Initial value of the induction variable; its type is also used as the
/// type of the body block's first (induction-variable) argument.
start: mlir.Value,
/// Loop bound. NOTE(review): presumably exclusive, per MLIR `scf.for`
/// semantics — confirm against the targeted MLIR version.
end: mlir.Value,
/// Per-iteration increment of the induction variable (per `scf.for` semantics).
step: mlir.Value,
};
/// Builds an `scf.for` operation whose single body block is populated by `body`.
///
/// Body block arguments: the first is the loop induction variable (typed like
/// `range.start`). It is followed by one argument per entry of `init_values`
/// (the loop-carried values), each typed like the corresponding init value.
/// Every block argument uses `loc` as its location.
///
/// `body` is called with `ctx`, the freshly created block and `extra_args`. It
/// must return an `scf.yield` operation (checked with `std.debug.assert`),
/// which is then appended to the block.
///
/// The returned operation's operands are `range.start`, `range.end` and
/// `range.step`, followed by `init_values`. Its result types are the types of
/// `init_values`. It is built with `.verify = false`.
///
/// Panics if more than 31 loop-carried values are passed (the block-argument
/// scratch buffers are fixed-size) or if the body block cannot be created.
pub fn @"for"(
    ExtraArgs: type,
    ctx: mlir.Context,
    range: ForRange,
    init_values: []const mlir.Value,
    body: ForBody(ExtraArgs),
    extra_args: ExtraArgs,
    loc: mlir.Location,
) mlir.Operation {
    // Capacity of the on-stack scratch buffers: one slot for the induction
    // variable plus the loop-carried values.
    const max_block_args = 32;
    const n_args = init_values.len;
    // Check explicitly rather than relying on the slice bounds check below,
    // which is a safety panic in Debug/ReleaseSafe but UB in ReleaseFast.
    if (n_args >= max_block_args) {
        std.debug.panic(
            "scf.for: at most {d} loop-carried values are supported, got {d}",
            .{ max_block_args - 1, n_args },
        );
    }

    var init_types_buf: [max_block_args]mlir.Type = undefined;
    var locs_buf: [max_block_args]mlir.Location = undefined;
    // The first block argument is the for loop induction variable,
    // followed by all the loop-carried variables.
    const init_types = init_types_buf[0 .. n_args + 1];
    const locs = locs_buf[0 .. n_args + 1];
    init_types[0] = range.start.getType();
    locs[0] = loc;
    for (1.., init_values) |i, val| {
        init_types[i] = val.getType();
        locs[i] = loc;
    }

    // Block creation can fail. `catch unreachable` would be UB in release
    // builds, so report the error explicitly instead.
    const block = mlir.Block.init(init_types, locs) catch |err| std.debug.panic(
        "scf.for: failed to create the loop body block: {s}",
        .{@errorName(err)},
    );

    const yield_op = body(ctx, block, extra_args);
    std.debug.assert(std.mem.eql(u8, "scf.yield", yield_op.name().str()));
    block.appendOperationRecursive(yield_op, .open);

    // `init_types[1..]` skips the induction variable: the loop's result types
    // correspond one-to-one to the loop-carried values.
    return mlir.Operation.make(ctx, "scf.for", .{
        .variadic_operands = &.{ &.{ range.start, range.end, range.step }, init_values },
        .results = init_types[1..],
        .blocks = &.{block},
        .location = loc,
        .verify = false,
    });
}
/// Builds an `scf.yield` operation whose operands are the values in `res`.
///
/// The operation has no results and is created with `.verify = false`. An
/// `scf.yield` is, for example, what a `ForBody` callback must return to
/// `@"for"`.
pub fn yield(
    ctx: mlir.Context,
    res: []const mlir.Value,
    loc: mlir.Location,
) mlir.Operation {
    const yield_op = mlir.Operation.make(ctx, "scf.yield", .{
        .variadic_operands = &.{res},
        .results = &.{},
        .location = loc,
        .verify = false,
    });
    return yield_op;
}