1907 lines
62 KiB
Zig
1907 lines
62 KiB
Zig
const std = @import("std");
|
|
const builtin = @import("builtin");
|
|
|
|
const c = @import("c");
|
|
const stdx = @import("stdx");
|
|
|
|
const log = std.log.scoped(.mlir);
|
|
|
|
// Hands this file's declarations to `std.testing.refAllDecls`.
test {
    const this_file = @This();
    std.testing.refAllDecls(this_file);
}
|
|
|
|
/// Errors returned by the wrappers declared in this file.
const Error = error{
    /// The MLIR that was created is not valid.
    InvalidMlir,
    /// Any other MLIR failure; the log should give more context.
    MlirUnexpected,
    /// The requested resource/executor could not be found.
    NotFound,
    OutOfMemory,
};
|
|
|
|
/// Builds a `c.MlirStringRef` from the pointer and length of `str`.
/// The bytes are referenced, not copied.
pub inline fn stringRef(str: []const u8) c.MlirStringRef {
    const view: c.MlirStringRef = .{
        .data = str.ptr,
        .length = str.len,
    };
    return view;
}
|
|
|
|
/// Views the bytes referenced by `str` as a Zig slice of `str.length` bytes.
pub inline fn fromStringRef(str: c.MlirStringRef) []const u8 {
    // An mlir.StringRef is not required to be null terminated,
    // so the explicit length is the only valid bound.
    const byte_count = str.length;
    return str.data[0..byte_count];
}
|
|
|
|
/// Calls the C function named `mlirRegister<passes>Passes`.
/// The name is assembled at compile time from `passes`.
pub fn registerPasses(comptime passes: []const u8) void {
    const register_fn = @field(c, "mlirRegister" ++ passes ++ "Passes");
    register_fn();
}
|
|
|
|
/// Returns `err` when `res.value` is zero, and succeeds otherwise.
pub fn successOr(res: c.MlirLogicalResult, err: anytype) !void {
    if (res.value != 0) return;
    return err;
}
|
|
|
|
/// Returns a struct type listing optional C callbacks operating on `InnerT`.
/// Every entry defaults to `null`.
pub fn MlirTypeMethods(comptime InnerT: type) type {
    const Predicate = fn (InnerT) callconv(.C) bool;
    const Comparator = fn (InnerT, InnerT) callconv(.C) bool;
    const Action = fn (InnerT) callconv(.C) void;

    return struct {
        is_null_fn: ?Predicate = null,
        is_a_fn: ?Predicate = null,
        equal_fn: ?Comparator = null,
        dump_fn: ?Action = null,
        deinit_fn: ?Action = null,
    };
}
|
|
|
|
/// C-ABI callback type taking a `c.MlirStringRef` and an opaque pointer.
/// Alternative to MlirWrapperType
pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.C) void;
|
|
|
|
/// Returns the struct type describing which C functions are available for the wrapper `OuterT`.
///
/// `OuterT` must be a struct with exactly one field, named `_inner`, whose size equals the
/// size of `OuterT`; anything else is rejected at compile time.
fn MlirHelpersMethods(OuterT: type) type {
    const outer_info = @typeInfo(OuterT);
    if (outer_info != .@"struct") {
        @compileError("MlirHelpersMethods is only available on an Mlir wrapper struct. Received: " ++ @typeName(OuterT));
    }
    if (outer_info.@"struct".fields.len != 1) {
        @compileError("Mlir wrapper type can only wrap one Mlir value. Received: " ++ @typeName(OuterT));
    }

    const Inner = std.meta.FieldType(OuterT, ._inner);
    const Predicate = fn (Inner) callconv(.C) bool;
    const Action = fn (Inner) callconv(.C) void;

    return struct {
        pub const InnerT = Inner;

        comptime {
            // The wrapper must not add anything on top of the wrapped value.
            std.debug.assert(@sizeOf(InnerT) == @sizeOf(OuterT));
        }

        is_null_fn: ?Predicate = null,
        is_a_fn: ?Predicate = null,
        equal_fn: ?fn (Inner, Inner) callconv(.C) bool = null,
        deinit_fn: ?Action = null,
        dump_fn: ?Action = null,
        print_fn: ?fn (Inner, ?*const MlirStrCallback, ?*anyopaque) callconv(.C) void = null,
    };
}
|
|
|
|
/// Generates the helper declarations shared by the MLIR wrapper structs of this file.
///
/// `OuterT` is the wrapper struct (its single field `_inner` holds the C value) and `methods`
/// lists the C functions available for that value. `wrap`, `inner`, `innerPtr`, `is_a` and `as`
/// are always declared; `wrapOr`, `eql`, `deinit`, `dump` and `format` are only declared when
/// the matching entry of `methods` is set.
pub fn MlirHelpers(comptime OuterT: type, comptime methods: MlirHelpersMethods(OuterT)) type {
    const InnerT = @TypeOf(methods).InnerT;
    return struct {
        pub const Methods = methods;

        /// Wraps a raw C value, without any validity check.
        pub inline fn wrap(raw: InnerT) OuterT {
            return .{ ._inner = raw };
        }

        /// Returns the wrapped C value.
        pub inline fn inner(self: OuterT) InnerT {
            return self._inner;
        }

        /// Returns a pointer to the wrapped C value.
        pub inline fn innerPtr(self: *OuterT) *InnerT {
            return &self._inner;
        }

        /// Runs `otherT`'s `is_a_fn` on the wrapped value.
        /// Returns false when `otherT` does not provide an `is_a_fn`.
        pub inline fn is_a(self: OuterT, comptime otherT: type) bool {
            if (otherT.Methods.is_a_fn) |is_a_fn| {
                return is_a_fn(self.inner());
            }
            return false;
        }

        /// Re-wraps the inner value as `OtherT`.
        /// When `OtherT` provides an `is_a_fn`, the cast is checked with `stdx.debug.assert`.
        pub inline fn as(self: OuterT, comptime OtherT: type) OtherT {
            if (OtherT.Methods.is_a_fn) |is_a_fn| {
                stdx.debug.assert(is_a_fn(self.inner()), "Wrongly tried to cast {} into {}", .{ OuterT, OtherT });
            }
            // Without an is_a_fn the cast cannot be checked: wrap the value as-is.
            return OtherT.wrap(self.inner());
        }

        pub usingnamespace if (Methods.is_null_fn) |is_null| struct {
            /// Wraps `raw`, or returns null when `is_null_fn` reports it as null.
            pub inline fn wrapOr(raw: InnerT) ?OuterT {
                return if (is_null(raw)) null else OuterT.wrap(raw);
            }
        } else struct {};

        pub usingnamespace if (Methods.equal_fn) |equal| struct {
            /// Compares two wrapped values with `equal_fn`.
            pub inline fn eql(self: OuterT, other: OuterT) bool {
                return equal(self.inner(), other.inner());
            }
        } else struct {};

        pub usingnamespace if (Methods.deinit_fn) |_deinit| struct {
            /// Calls `deinit_fn` on the wrapped value, then invalidates the wrapper.
            pub inline fn deinit(self: *OuterT) void {
                _deinit(self.inner());
                self.* = undefined;
            }
        } else struct {};

        pub usingnamespace if (Methods.dump_fn) |_dump| struct {
            /// Calls `dump_fn` on the wrapped value.
            pub inline fn dump(self: OuterT) void {
                return _dump(self.inner());
            }
        } else struct {};

        pub usingnamespace if (Methods.print_fn) |print| struct {
            /// `std.fmt` hook: forwards every chunk produced by `print_fn` to `writer`.
            /// `fmt` and `options` are ignored. Returns the first error reported by `writer`.
            pub fn format(
                self: OuterT,
                comptime fmt: []const u8,
                options: std.fmt.FormatOptions,
                writer: anytype,
            ) !void {
                _ = fmt;
                _ = options;

                const Writer = struct {
                    writer: @TypeOf(writer),
                    err: ?@TypeOf(writer).Error = null,

                    fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.C) void {
                        const ctx: *@This() = @alignCast(@ptrCast(opaque_ctx));
                        // After a failed write the remaining chunks are skipped;
                        // the stored error is returned by `format`.
                        if (ctx.err) |_| return;
                        // `writeAll` keeps writing until the whole chunk is consumed,
                        // whereas a single `write` may stop early and drop data.
                        ctx.writer.writeAll(mlir_str.data[0..mlir_str.length]) catch |err| {
                            ctx.err = err;
                        };
                    }
                };

                var context: Writer = .{ .writer = writer };
                print(self.inner(), &Writer.printCallback, &context);
                if (context.err) |err| return err;
            }
        } else struct {};
    };
}
|
|
|
|
/// Wrapper around `c.MlirDialectRegistry`.
pub const Registry = struct {
    _inner: c.MlirDialectRegistry,
    pub usingnamespace MlirHelpers(Registry, .{
        .is_null_fn = c.mlirDialectRegistryIsNull,
        .deinit_fn = c.mlirDialectRegistryDestroy,
    });
    const Self = Registry;

    /// Creates a registry; fails with `MlirUnexpected` when the C API returns a null value.
    pub fn init() !Self {
        const raw_registry = c.mlirDialectRegistryCreate();
        if (Self.wrapOr(raw_registry)) |created| return created;
        return Error.MlirUnexpected;
    }
};
|
|
|
|
/// Wrapper around `c.MlirContext`.
pub const Context = struct {
    _inner: c.MlirContext,
    pub usingnamespace MlirHelpers(Context, .{
        .is_null_fn = c.mlirContextIsNull,
        .deinit_fn = c.mlirContextDestroy,
    });
    const Self = Context;

    /// Creates a context; fails with `MlirUnexpected` when the C API returns a null value.
    pub fn init() !Self {
        const raw_ctx = c.mlirContextCreate();
        if (Self.wrapOr(raw_ctx)) |created| return created;
        return Error.MlirUnexpected;
    }

    /// Creates a context from `registry`; fails with `InvalidMlir` when the C API
    /// returns a null value.
    pub fn initWithRegistry(registry: Registry, threadingEnabled: bool) !Self {
        const raw_ctx = c.mlirContextCreateWithRegistry(registry.inner(), threadingEnabled);
        if (Self.wrapOr(raw_ctx)) |created| return created;
        return Error.InvalidMlir;
    }

    /// Forwards to `mlirContextEnableMultithreading`.
    pub fn setMultiThreading(self: *Self, enabled: bool) void {
        const raw_ctx = self.inner();
        c.mlirContextEnableMultithreading(raw_ctx, enabled);
    }

    /// Forwards to `mlirContextAppendDialectRegistry`.
    pub fn appendDialectRegistry(self: *Self, registry: Registry) void {
        const raw_ctx = self.inner();
        c.mlirContextAppendDialectRegistry(raw_ctx, registry.inner());
    }

    /// Forwards to `mlirContextLoadAllAvailableDialects`.
    pub fn loadAllAvailableDialects(self: *Self) void {
        const raw_ctx = self.inner();
        c.mlirContextLoadAllAvailableDialects(raw_ctx);
    }

    /// Result of `mlirContextGetNumRegisteredDialects`, as a `usize`.
    pub fn numRegisteredDialects(self: Self) usize {
        const n_dialects = c.mlirContextGetNumRegisteredDialects(self.inner());
        return @intCast(n_dialects);
    }

    /// Result of `mlirContextGetNumLoadedDialects`, as a `usize`.
    pub fn numLoadedDialects(self: Self) usize {
        const n_dialects = c.mlirContextGetNumLoadedDialects(self.inner());
        return @intCast(n_dialects);
    }

    /// Result of `mlirContextIsRegisteredOperation` for the operation name `op`.
    pub fn isRegisteredOperation(self: Self, op: [:0]const u8) bool {
        const op_name = stringRef(op);
        return c.mlirContextIsRegisteredOperation(self.inner(), op_name);
    }

    /// Shortcut for `Location.fromSrc(self, src)`.
    pub fn location(self: Self, src: std.builtin.SourceLocation) Location {
        return Location.fromSrc(self, src);
    }
};
|
|
|
|
/// Wrapper around `c.MlirModule`.
pub const Module = struct {
    _inner: c.MlirModule,
    pub usingnamespace MlirHelpers(Module, .{
        .is_null_fn = c.mlirModuleIsNull,
        .deinit_fn = c.mlirModuleDestroy,
    });
    const Self = Module;

    /// Creates an empty module at `loc`.
    pub fn init(loc: Location) Self {
        const raw_module = c.mlirModuleCreateEmpty(loc.inner());
        return Self.wrap(raw_module);
    }

    /// Parses `source` into a module; fails with `InvalidMlir` when the C API
    /// returns a null module.
    pub fn parse(ctx: Context, source: [:0]const u8) !Module {
        const raw_module = c.mlirModuleCreateParse(ctx.inner(), stringRef(source));
        if (Module.wrapOr(raw_module)) |parsed| return parsed;
        return Error.InvalidMlir;
    }

    /// Wraps the result of `mlirModuleFromOperation`; the result is not checked for null.
    pub fn fromOperation(operation: Operation) Module {
        const raw_module = c.mlirModuleFromOperation(operation.inner());
        return Module.wrap(raw_module);
    }

    /// Context returned by `mlirModuleGetContext`.
    pub fn context(self: Module) Context {
        const raw_ctx = c.mlirModuleGetContext(self.inner());
        return Context.wrap(raw_ctx);
    }

    /// Block returned by `mlirModuleGetBody`.
    pub fn getBody(self: Module) Block {
        const raw_body = c.mlirModuleGetBody(self.inner());
        return Block.wrap(raw_body);
    }

    /// Operation returned by `mlirModuleGetOperation`.
    pub fn op(self: Module) Operation {
        const raw_op = c.mlirModuleGetOperation(self.inner());
        return Operation.wrap(raw_op);
    }

    /// Hashes the module by delegating to the `hash` method of its operation.
    pub fn hash(self: Module, hasher: *std.hash.XxHash64) void {
        const root_op = self.op();
        return root_op.hash(hasher);
    }
};
|
|
|
|
/// Wrapper around `c.MlirPassManager`.
pub const PassManager = struct {
    _inner: c.MlirPassManager,

    pub usingnamespace MlirHelpers(PassManager, .{
        .is_null_fn = c.mlirPassManagerIsNull,
        .deinit_fn = c.mlirPassManagerDestroy,
    });
    const Self = PassManager;

    /// Creates a pass manager; fails with `MlirUnexpected` when the C API returns a null value.
    pub fn init(ctx: Context) !Self {
        const raw_pm = c.mlirPassManagerCreate(ctx.inner());
        if (Self.wrapOr(raw_pm)) |created| return created;
        return Error.MlirUnexpected;
    }

    /// Creates a pass manager through `mlirPassManagerCreateOnOperation` with the name `op`;
    /// fails with `MlirUnexpected` when the C API returns a null value.
    pub fn initOnOperation(ctx: Context, op: [:0]const u8) !Self {
        const raw_pm = c.mlirPassManagerCreateOnOperation(ctx.inner(), stringRef(op));
        if (Self.wrapOr(raw_pm)) |created| return created;
        return Error.MlirUnexpected;
    }

    /// Wraps the result of `mlirPassManagerGetAsOpPassManager`.
    pub fn asOpPassManager(self: Self) OpPassManager {
        const raw_opm = c.mlirPassManagerGetAsOpPassManager(self.inner());
        return OpPassManager.wrap(raw_opm);
    }

    /// Forwards to `mlirPassManagerEnableIRPrinting`.
    pub fn enableIRPrinting(self: *Self) void {
        const raw_pm = self.inner();
        c.mlirPassManagerEnableIRPrinting(raw_pm);
    }

    /// Runs the pass manager on `op`; a zero logical result is reported as `InvalidMlir`.
    pub fn runOnOp(self: *Self, op: Operation) error{InvalidMlir}!void {
        const status = c.mlirPassManagerRunOnOp(self.inner(), op.inner());
        if (status.value != 0) return;
        return Error.InvalidMlir;
    }
};
|
|
|
|
/// Error callback handed to `mlirOpPassManagerAddPipeline`: prints the message on stderr.
/// `ctx` is unused.
fn _mlir_passpipeline_error(err: c.MlirStringRef, ctx: ?*anyopaque) callconv(.C) void {
    _ = ctx;
    // An MlirStringRef is length-delimited and need not be null terminated, so the message
    // is sliced with its explicit length instead of being scanned for a terminator.
    std.debug.print(">>ERROR: {s}\n", .{err.data[0..err.length]});
}
|
|
|
|
/// Wrapper around `c.MlirOpPassManager`.
pub const OpPassManager = struct {
    _inner: c.MlirOpPassManager,
    pub usingnamespace MlirHelpers(OpPassManager, .{});

    /// Adds the textual `pipeline` to this pass manager.
    /// Messages from the C API are printed by `_mlir_passpipeline_error`.
    /// A zero logical result is reported as `OutOfMemory`.
    pub fn addPipeline(self: *OpPassManager, pipeline: [:0]const u8) error{OutOfMemory}!void {
        const status = c.mlirOpPassManagerAddPipeline(
            self.inner(),
            stringRef(pipeline),
            &_mlir_passpipeline_error,
            null,
        );
        if (status.value != 0) return;
        return Error.OutOfMemory;
    }
};
|
|
|
|
/// Wrapper around `c.MlirIdentifier`.
pub const Identifier = struct {
    _inner: c.MlirIdentifier,
    pub usingnamespace MlirHelpers(Identifier, .{});
    const Self = Identifier;

    /// Identifier returned by `mlirIdentifierGet` for `str_` in `ctx`.
    pub fn get(ctx: Context, str_: [:0]const u8) Self {
        const raw_id = c.mlirIdentifierGet(ctx.inner(), stringRef(str_));
        return Self.wrap(raw_id);
    }

    /// Context returned by `mlirIdentifierGetContext`.
    pub fn context(self: Self) Context {
        const raw_ctx = c.mlirIdentifierGetContext(self.inner());
        return Context.wrap(raw_ctx);
    }

    /// Text of the identifier, as returned by `mlirIdentifierStr`.
    pub fn str(self: Self) []const u8 {
        const raw_text = c.mlirIdentifierStr(self.inner());
        return fromStringRef(raw_text);
    }

    /// Compares two identifiers with `mlirIdentifierEqual`.
    pub fn equals(self: Self, other: Self) bool {
        const lhs = self.inner();
        const rhs = other.inner();
        return c.mlirIdentifierEqual(lhs, rhs);
    }
};
|
|
|
|
/// A `(name, attribute)` pair: element 0 is the attribute name, element 1 its value.
pub const AttrTuple = struct { [:0]const u8, Attribute };
|
|
|
|
/// Wrapper around `c.MlirAttribute`, with helpers building the common attribute kinds.
pub const Attribute = struct {
    _inner: c.MlirAttribute,
    pub usingnamespace MlirHelpers(Attribute, .{
        .is_null_fn = c.mlirAttributeIsNull,
        .dump_fn = c.mlirAttributeDump,
        .equal_fn = c.mlirAttributeEqual,
    });
    const Self = Attribute;

    /// Parses `attr`; fails with `InvalidMlir` when the C API returns a null attribute.
    pub fn parse(ctx: Context, attr: [:0]const u8) !Attribute {
        const raw_attr = c.mlirAttributeParseGet(ctx.inner(), stringRef(attr));
        if (Attribute.wrapOr(raw_attr)) |parsed| return parsed;
        return Error.InvalidMlir;
    }

    // Utility functions to build common attributes.
    // Every result is upcast to `Attribute`: this makes chained construction easier,
    // at the price of losing the precise attribute type.

    /// The attribute returned by `mlirAttributeGetNull`.
    pub fn null_() Attribute {
        return Attribute.wrap(c.mlirAttributeGetNull());
    }

    /// String attribute holding `str`.
    pub fn string(ctx: Context, str: []const u8) Attribute {
        const typed = StringAttribute.init(ctx, str);
        return typed.asAttr();
    }

    /// Type attribute built from `t`.
    pub fn type_(t: Type) Attribute {
        const typed = TypeAttribute.init(t);
        return typed.asAttr();
    }

    /// Unit attribute of `ctx`.
    pub fn unit(ctx: Context) Attribute {
        return Attribute.wrap(c.mlirUnitAttrGet(ctx.inner()));
    }

    /// Bool attribute holding `value`.
    pub fn boolean(ctx: Context, value: bool) Attribute {
        const typed = BoolAttribute.init(ctx, value);
        return typed.asAttr();
    }

    /// `.i1` integer attribute holding 1 for true and 0 for false.
    pub fn i1FromBool(ctx: Context, value: bool) Attribute {
        const bit: i64 = @intFromBool(value);
        return IntegerAttribute(.i1).init(ctx, bit).asAttr();
    }

    /// Integer attribute of type `int_type` holding `value`.
    pub fn int(ctx: Context, comptime int_type: IntegerTypes, value: i64) Attribute {
        const TypedAttr = IntegerAttribute(int_type);
        return TypedAttr.init(ctx, value).asAttr();
    }

    /// Float attribute of type `float_type` holding `value`.
    pub fn float(ctx: Context, comptime float_type: FloatTypes, value: f64) Attribute {
        const typed = FloatAttribute(float_type).init(ctx, value);
        return Attribute.wrap(typed._inner);
    }

    /// Array attribute holding `attrs`.
    pub fn array(ctx: Context, attrs: []const Attribute) Attribute {
        const typed = ArrayAttribute.init(ctx, attrs);
        return typed.asAttr();
    }

    /// Dense array attribute of element type `dt` holding `values`.
    pub fn dense(ctx: Context, comptime dt: DenseArrayTypes, values: []const dt.ZigType()) Attribute {
        const TypedAttr = DenseArrayAttribute(dt);
        return TypedAttr.init(ctx, values).asAttr();
    }

    /// Use a tensor as an attribute.
    /// The tensor is described by `dims`, the element type `dt` and a flat slice of `values`.
    pub fn denseElements(ctx: Context, dims: []const i64, comptime dt: DenseElementsAttributeTypes, values: anytype) Attribute {
        const tensor_type = Type.tensor(dims, dt.mlirType(ctx));
        const typed = DenseElementsAttribute(dt).init(tensor_type, values);
        return Attribute.wrap(typed.inner());
    }

    /// Flat symbol reference attribute naming `flat_name`.
    pub fn symbol(ctx: Context, flat_name: [:0]const u8) Attribute {
        const typed = FlatSymbolRefAttribute.init(ctx, flat_name);
        return Attribute.wrap(typed.inner());
    }

    /// Pairs `attr` with the identifier `name` of `ctx`.
    pub fn named(attr: Attribute, ctx: Context, name: [:0]const u8) NamedAttribute {
        const identifier = Identifier.get(ctx, name);
        return NamedAttribute.init(identifier, attr);
    }

    /// Dictionary attribute built from `named_attrs`.
    /// The entries are staged in a 32-slot stack buffer; larger inputs trip the assertion below.
    pub fn dict(ctx: Context, named_attrs: []const AttrTuple) Attribute {
        var storage: [32]NamedAttribute = undefined;
        stdx.debug.assert(named_attrs.len <= storage.len, ".dict helper only support up to {} attribute, got {}. You will need to call mlir.DictionaryAttribute directly", .{ storage.len, named_attrs.len });

        for (named_attrs, 0..) |entry, i| {
            storage[i] = NamedAttribute.named(ctx, entry[0], entry[1]);
        }

        const filled = storage[0..named_attrs.len];
        return DictionaryAttribute.init(ctx, filled).asAttr();
    }
};
|
|
|
|
/// An identifier paired with an attribute, stored as an `extern struct`.
pub const NamedAttribute = extern struct {
    name: c.MlirIdentifier,
    attribute: c.MlirAttribute,

    /// Builds the pair from a textual `name`, resolved through `mlirIdentifierGet` in `ctx`.
    pub fn named(ctx: Context, name: [:0]const u8, attr: Attribute) NamedAttribute {
        const raw_id = c.mlirIdentifierGet(ctx._inner, stringRef(name));
        const raw_attr = attr.inner();
        return .{ .name = raw_id, .attribute = raw_attr };
    }

    /// Builds the pair from an already resolved identifier.
    pub fn init(name: Identifier, attr: Attribute) NamedAttribute {
        const raw_id = name.inner();
        const raw_attr = attr.inner();
        return .{ .name = raw_id, .attribute = raw_attr };
    }
};
|
|
|
|
/// Wrapper around a `c.MlirAttribute` checked with `mlirAttributeIsAString`.
pub const StringAttribute = struct {
    _inner: c.MlirAttribute,
    pub usingnamespace MlirHelpers(StringAttribute, .{
        .is_a_fn = c.mlirAttributeIsAString,
        .is_null_fn = c.mlirAttributeIsNull,
        .dump_fn = c.mlirAttributeDump,
        .equal_fn = c.mlirAttributeEqual,
    });
    const Self = StringAttribute;

    /// String attribute holding `str`.
    pub fn init(ctx: Context, str: []const u8) Self {
        const raw_attr = c.mlirStringAttrGet(ctx.inner(), stringRef(str));
        return Self.wrap(raw_attr);
    }

    /// Bytes returned by `mlirStringAttrGetValue`.
    pub fn value(self: Self) []const u8 {
        const raw_text = c.mlirStringAttrGetValue(self.inner());
        return fromStringRef(raw_text);
    }

    /// Upcasts to the generic `Attribute`.
    pub fn asAttr(self: StringAttribute) Attribute {
        const raw_attr = self._inner;
        return .{ ._inner = raw_attr };
    }
};
|
|
|
|
/// Wrapper around a `c.MlirAttribute` checked with `mlirAttributeIsABool`.
pub const BoolAttribute = struct {
    _inner: c.MlirAttribute,
    pub usingnamespace MlirHelpers(BoolAttribute, .{
        .is_a_fn = c.mlirAttributeIsABool,
        .is_null_fn = c.mlirAttributeIsNull,
        .dump_fn = c.mlirAttributeDump,
        .equal_fn = c.mlirAttributeEqual,
    });
    const Self = BoolAttribute;

    /// Bool attribute holding `value_` (passed to the C API as 1 or 0).
    pub fn init(ctx: Context, value_: bool) Self {
        const raw_attr = c.mlirBoolAttrGet(ctx.inner(), @intFromBool(value_));
        return Self.wrap(raw_attr);
    }

    /// Result of `mlirBoolAttrGetValue`.
    pub fn value(self: Self) bool {
        const raw_attr = self.inner();
        return c.mlirBoolAttrGetValue(raw_attr);
    }

    /// Upcasts to the generic `Attribute`.
    pub fn asAttr(self: Self) Attribute {
        return self.as(Attribute);
    }
};
|
|
|
|
/// Wrapper around a `c.MlirAttribute` checked with `mlirAttributeIsAType`.
pub const TypeAttribute = struct {
    _inner: c.MlirAttribute,
    pub usingnamespace MlirHelpers(TypeAttribute, .{
        .is_a_fn = c.mlirAttributeIsAType,
        .is_null_fn = c.mlirAttributeIsNull,
        .dump_fn = c.mlirAttributeDump,
        .equal_fn = c.mlirAttributeEqual,
    });

    /// Type attribute built from `type_` with `mlirTypeAttrGet`.
    pub fn init(type_: Type) TypeAttribute {
        const raw_attr = c.mlirTypeAttrGet(type_.inner());
        return TypeAttribute.wrap(raw_attr);
    }

    /// Type returned by `mlirAttributeGetType` for this attribute.
    // NOTE(review): this queries the type of the attribute itself; confirm that it is
    // meant to return the type given to `init`.
    pub fn typ(self: TypeAttribute) Type {
        const raw_type = c.mlirAttributeGetType(self.inner());
        return Type.wrap(raw_type);
    }

    /// Upcasts to the generic `Attribute`.
    pub fn asAttr(self: TypeAttribute) Attribute {
        return self.as(Attribute);
    }
};
|
|
|
|
/// Wrapper around a `c.MlirAttribute` checked with `mlirAttributeIsAArray`.
pub const ArrayAttribute = struct {
    _inner: c.MlirAttribute,
    pub usingnamespace MlirHelpers(ArrayAttribute, .{
        .is_a_fn = c.mlirAttributeIsAArray,
        .is_null_fn = c.mlirAttributeIsNull,
        .dump_fn = c.mlirAttributeDump,
        .equal_fn = c.mlirAttributeEqual,
    });
    const Self = ArrayAttribute;

    /// Array attribute holding `attrs`.
    pub fn init(ctx: Context, attrs: []const Attribute) Self {
        const raw_attr = c.mlirArrayAttrGet(ctx.inner(), @intCast(attrs.len), @ptrCast(attrs.ptr));
        return Self.wrap(raw_attr);
    }

    /// Number of elements, from `mlirArrayAttrGetNumElements`.
    pub fn size(self: Self) usize {
        const n_elements = c.mlirArrayAttrGetNumElements(self.inner());
        return @intCast(n_elements);
    }

    /// Element at `index`, from `mlirArrayAttrGetElement`.
    pub fn get(self: Self, index: usize) Attribute {
        const raw_element = c.mlirArrayAttrGetElement(self.inner(), @intCast(index));
        return Attribute.wrap(raw_element);
    }

    /// Upcasts to the generic `Attribute`.
    pub fn asAttr(self: Self) Attribute {
        const raw_attr = self._inner;
        return .{ ._inner = raw_attr };
    }
};
|
|
|
|
/// Returns the wrapper type for integer attributes of type `it`.
/// `IntegerAttribute(.unknown)` is a compile error.
pub fn IntegerAttribute(comptime it: IntegerTypes) type {
    // Zig type returned by `get`.
    const ZigType = comptime switch (it) {
        .i1, .i4, .i8, .i16, .i32, .i64 => i64,
        .si4, .si8, .si16, .si32, .si64 => i64,
        .u4, .u8, .u16, .u32, .u64 => u64,
        .unknown => @compileError("IntegerAttribute(unknown)"),
    };
    // C accessor used by `get` to read the value back.
    const getter = comptime switch (it) {
        .i1, .i4, .i8, .i16, .i32, .i64 => c.mlirIntegerAttrGetValueInt,
        .si4, .si8, .si16, .si32, .si64 => c.mlirIntegerAttrGetValueSInt,
        .u4, .u8, .u16, .u32, .u64 => c.mlirIntegerAttrGetValueUInt,
        .unknown => @compileError("IntegerAttribute(unknown)"),
    };

    return struct {
        _inner: c.MlirAttribute,
        pub usingnamespace MlirHelpers(@This(), .{
            .is_a_fn = c.mlirAttributeIsAInteger,
            .is_null_fn = c.mlirAttributeIsNull,
            .dump_fn = c.mlirAttributeDump,
            .equal_fn = c.mlirAttributeEqual,
        });
        pub const IntegerTypeType = IntegerType(it);
        const IntAttr = @This();

        /// Integer attribute of type `it` holding `value`.
        pub fn init(ctx: Context, value: i64) IntAttr {
            const int_type = IntegerType(it).init(ctx);
            const raw_attr = c.mlirIntegerAttrGet(int_type.inner(), value);
            return IntAttr.wrap(raw_attr);
        }

        /// Reads the value back through the accessor selected for `it`.
        pub fn get(value: IntAttr) ZigType {
            const raw_value = getter(value.inner());
            return @intCast(raw_value);
        }

        /// Upcasts to the generic `Attribute`.
        pub fn asAttr(self: IntAttr) Attribute {
            const raw_attr = self._inner;
            return .{ ._inner = raw_attr };
        }
    };
}
|
|
|
|
/// Returns the wrapper type for float attributes of type `ft`.
pub fn FloatAttribute(comptime ft: FloatTypes) type {
    return struct {
        _inner: c.MlirAttribute,
        pub usingnamespace MlirHelpers(@This(), .{
            .is_a_fn = c.mlirAttributeIsAFloat,
            .is_null_fn = c.mlirAttributeIsNull,
            .dump_fn = c.mlirAttributeDump,
            .equal_fn = c.mlirAttributeEqual,
        });
        const FloatAttr = @This();

        /// Float attribute of type `ft` built from the f64 `value`.
        pub fn init(ctx: Context, value: f64) FloatAttr {
            const float_type = FloatType(ft).init(ctx);
            const raw_attr = c.mlirFloatAttrDoubleGet(ctx.inner(), float_type.inner(), value);
            return FloatAttr.wrap(raw_attr);
        }

        /// Result of `mlirFloatAttrGetValueDouble`.
        pub fn get(value: FloatAttr) f64 {
            const raw_attr = value.inner();
            return c.mlirFloatAttrGetValueDouble(raw_attr);
        }
    };
}
|
|
|
|
/// Element types supported by `DenseArrayAttribute`.
pub const DenseArrayTypes = enum {
    bool,
    i8,
    i16,
    i32,
    i64,
    f32,
    f64,

    /// Zig type of one element of the slices handled for `dt`.
    /// Note that `.bool` maps to `i32`.
    pub fn ZigType(comptime dt: DenseArrayTypes) type {
        return switch (dt) {
            .bool, .i32 => i32,
            .i8 => i8,
            .i16 => i16,
            .i64 => i64,
            .f32 => f32,
            .f64 => f64,
        };
    }
};
|
|
|
|
/// Returns the wrapper type for dense array attributes whose elements are of type `dt`.
pub fn DenseArrayAttribute(comptime dt: DenseArrayTypes) type {
    // C entry points for `dt`: type check, constructor, element accessor.
    const is_a, const make_array, const element_at = switch (dt) {
        .bool => .{ c.mlirAttributeIsADenseBoolArray, c.mlirDenseBoolArrayGet, c.mlirDenseBoolArrayGetElement },
        .i8 => .{ c.mlirAttributeIsADenseI8Array, c.mlirDenseI8ArrayGet, c.mlirDenseI8ArrayGetElement },
        .i16 => .{ c.mlirAttributeIsADenseI16Array, c.mlirDenseI16ArrayGet, c.mlirDenseI16ArrayGetElement },
        .i32 => .{ c.mlirAttributeIsADenseI32Array, c.mlirDenseI32ArrayGet, c.mlirDenseI32ArrayGetElement },
        .i64 => .{ c.mlirAttributeIsADenseI64Array, c.mlirDenseI64ArrayGet, c.mlirDenseI64ArrayGetElement },
        .f32 => .{ c.mlirAttributeIsADenseF32Array, c.mlirDenseF32ArrayGet, c.mlirDenseF32ArrayGetElement },
        .f64 => .{ c.mlirAttributeIsADenseF64Array, c.mlirDenseF64ArrayGet, c.mlirDenseF64ArrayGetElement },
    };

    return struct {
        _inner: c.MlirAttribute,
        pub usingnamespace MlirHelpers(@This(), .{
            .is_a_fn = is_a,
            .is_null_fn = c.mlirAttributeIsNull,
            .dump_fn = c.mlirAttributeDump,
            .equal_fn = c.mlirAttributeEqual,
        });
        const Attr = @This();
        const ElementType = dt;
        const ElementTypeZig = dt.ZigType();

        /// Dense array attribute holding `values`.
        pub fn init(ctx: Context, values: []const ElementTypeZig) Attr {
            const raw_attr = make_array(ctx.inner(), @intCast(values.len), @ptrCast(values.ptr));
            return Attr.wrap(raw_attr);
        }

        /// Element at position `pos`.
        pub fn get(self: Attr, pos: usize) ElementTypeZig {
            const raw_attr = self.inner();
            return element_at(raw_attr, @intCast(pos));
        }

        /// Number of elements, from `mlirDenseArrayGetNumElements`.
        pub fn len(self: Attr) usize {
            const n_elements = c.mlirDenseArrayGetNumElements(self.inner());
            return @intCast(n_elements);
        }

        /// Upcasts to the generic `Attribute`.
        pub fn asAttr(self: Attr) Attribute {
            const raw_attr = self._inner;
            return Attribute.wrap(raw_attr);
        }

        // The `DenseArray` alias below is not `pub`.
        pub usingnamespace switch (dt) {
            .bool, .i64 => struct {
                const DenseArray = DenseArrayAttribute(switch (dt) {
                    .bool => .bool,
                    .i64 => .i64,
                    else => @compileError("DenseArrayAttribute: unreachable"),
                });
            },
            else => struct {},
        };
    };
}
|
|
|
|
/// Element types supported by `DenseElementsAttribute`.
pub const DenseElementsAttributeTypes = enum {
    bool,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64,
    bf16,
    f16,
    f32,
    f64,
    index,

    /// Zig type used to view one element of type `dt`.
    /// Note that `.bf16` maps to `u16` and `.index` to `usize`.
    pub fn ZigType(comptime dt: DenseElementsAttributeTypes) type {
        return switch (dt) {
            .bool => bool,
            .i8 => i8,
            .i16 => i16,
            .i32 => i32,
            .i64 => i64,
            .u8 => u8,
            .u16, .bf16 => u16,
            .u32 => u32,
            .u64 => u64,
            .f16 => f16,
            .f32 => f32,
            .f64 => f64,
            .index => usize,
        };
    }

    /// MLIR element type matching `dt`, created in `ctx`. `.bool` maps to the `.i1` integer type.
    pub fn mlirType(dt: DenseElementsAttributeTypes, ctx: Context) Type {
        return switch (dt) {
            .bool => Type.int(ctx, .i1),
            .i8 => Type.int(ctx, .i8),
            .i16 => Type.int(ctx, .i16),
            .i32 => Type.int(ctx, .i32),
            .i64 => Type.int(ctx, .i64),
            .u8 => Type.int(ctx, .u8),
            .u16 => Type.int(ctx, .u16),
            .u32 => Type.int(ctx, .u32),
            .u64 => Type.int(ctx, .u64),
            .bf16 => Type.float(ctx, .bf16),
            .f16 => Type.float(ctx, .f16),
            .f32 => Type.float(ctx, .f32),
            .f64 => Type.float(ctx, .f64),
            .index => Type.index(ctx),
        };
    }
};
|
|
|
|
/// Returns the wrapper type for dense elements attributes whose elements are of type `dt`.
pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
    return struct {
        _inner: c.MlirAttribute,

        const Attr = @This();

        pub usingnamespace MlirHelpers(Attr, .{
            .is_a_fn = c.mlirAttributeIsADenseElements,
            .is_null_fn = c.mlirAttributeIsNull,
            .dump_fn = c.mlirAttributeDump,
            .equal_fn = c.mlirAttributeEqual,
        });

        /// Builds the attribute from the raw bytes of `slice`, interpreted as `shaped_type`.
        /// A null result from the C API is treated as unreachable.
        pub fn init(shaped_type: Type, slice: anytype) Attr {
            const raw_bytes = std.mem.sliceAsBytes(slice);
            const raw_attr = c.mlirDenseElementsAttrRawBufferGet(
                shaped_type.inner(),
                @intCast(raw_bytes.len),
                @ptrCast(raw_bytes.ptr),
            );
            return Attr.wrapOr(raw_attr).?;
        }

        /// Number of elements, from `mlirElementsAttrGetNumElements`.
        pub fn len(self: Attr) usize {
            const n_elements = c.mlirElementsAttrGetNumElements(self.inner());
            return @intCast(n_elements);
        }

        /// Views the raw data of the attribute as `len()` elements of `dt.ZigType()`.
        /// A null raw-data pointer is treated as unreachable.
        pub fn constSlice(self: Attr) []const dt.ZigType() {
            const Element = dt.ZigType();
            const typed_ptr: [*]const Element = @constCast(@ptrCast(@alignCast(c.mlirDenseElementsAttrGetRawData(self.inner()).?)));
            const n_elements = self.len();
            return typed_ptr[0..n_elements];
        }

        /// Same memory as `constSlice`, viewed as bytes.
        pub fn data(self: Attr) []const u8 {
            const elements = self.constSlice();
            return std.mem.sliceAsBytes(elements);
        }
    };
}
|
|
|
|
/// Wrapper around a `c.MlirAttribute` checked with `mlirAttributeIsAFlatSymbolRef`.
pub const FlatSymbolRefAttribute = struct {
    _inner: c.MlirAttribute,
    pub usingnamespace MlirHelpers(FlatSymbolRefAttribute, .{
        .is_a_fn = c.mlirAttributeIsAFlatSymbolRef,
        .is_null_fn = c.mlirAttributeIsNull,
        .dump_fn = c.mlirAttributeDump,
        .equal_fn = c.mlirAttributeEqual,
    });

    const Self = FlatSymbolRefAttribute;

    /// Flat symbol reference naming `str`.
    pub fn init(ctx: Context, str: [:0]const u8) Self {
        const raw_attr = c.mlirFlatSymbolRefAttrGet(ctx.inner(), stringRef(str));
        return Self.wrap(raw_attr);
    }

    /// Bytes returned by `mlirFlatSymbolRefAttrGetValue`.
    pub fn value(self: Self) []const u8 {
        const raw_text = c.mlirFlatSymbolRefAttrGetValue(self.inner());
        return fromStringRef(raw_text);
    }
};
|
|
|
|
/// Wrapper around `c.MlirOperationState`: accumulates the results, operands, regions and
/// attributes later consumed by `Operation.init`.
pub const OperationState = struct {
    _inner: c.MlirOperationState,
    pub usingnamespace MlirHelpers(
        OperationState,
        .{},
    );

    const Self = OperationState;

    /// Starts a state for the operation called `name`, located at `loc`.
    pub fn init(name: [:0]const u8, loc: Location) Self {
        return Self.wrap(c.mlirOperationStateGet(stringRef(name), loc.inner()));
    }

    /// Appends one result type.
    pub fn addResult(self: *Self, type_: Type) void {
        c.mlirOperationStateAddResults(self.innerPtr(), 1, &[_]c.MlirType{type_.inner()});
    }

    /// Appends several result types.
    pub fn addResults(self: *Self, types: []const Type) void {
        c.mlirOperationStateAddResults(self.innerPtr(), @intCast(types.len), @ptrCast(types.ptr));
    }

    /// Appends one operand.
    pub fn addOperand(self: *Self, value: Value) void {
        c.mlirOperationStateAddOperands(self.innerPtr(), 1, &[_]c.MlirValue{value.inner()});
    }

    /// Appends several operands.
    pub fn addOperands(self: *Self, values: []const Value) void {
        c.mlirOperationStateAddOperands(self.innerPtr(), @intCast(values.len), @ptrCast(values.ptr));
    }

    /// Appends one region through `mlirOperationStateAddOwnedRegions`.
    pub fn addRegion(self: *Self, region: *Region) void {
        c.mlirOperationStateAddOwnedRegions(self.innerPtr(), 1, &[_]c.MlirRegion{region.inner()});
    }

    /// Appends several regions through `mlirOperationStateAddOwnedRegions`.
    pub fn addRegions(self: *Self, regions: []const Region) void {
        c.mlirOperationStateAddOwnedRegions(self.innerPtr(), @intCast(regions.len), @ptrCast(regions.ptr));
    }

    /// Appends the attribute `attr` under `name`, resolving the identifier in `ctx`.
    pub fn addAttribute(self: *Self, ctx: Context, name: [:0]const u8, attr: Attribute) void {
        self.addAttributeRaw(Identifier.get(ctx, name), attr);
    }

    /// Appends the attribute `attr` under the already resolved identifier `name`.
    pub fn addAttributeRaw(self: *Self, name: Identifier, attr: Attribute) void {
        // Go through the extern `NamedAttribute` struct rather than casting a pointer to an
        // anonymous struct literal: the layout of an anonymous struct is unspecified, whereas
        // `addAttributes` already hands `NamedAttribute` slices to the C API.
        const entries = [_]NamedAttribute{NamedAttribute.init(name, attr)};
        self.addAttributes(&entries);
    }

    /// Appends several named attributes.
    pub fn addAttributes(self: *Self, attributes: []const NamedAttribute) void {
        c.mlirOperationStateAddAttributes(self.innerPtr(), @intCast(attributes.len), @ptrCast(attributes.ptr));
    }

    /// Sets the `enableResultTypeInference` flag of the underlying state.
    pub fn resultTypeInference(self: *Self, enabled: bool) void {
        self.innerPtr().enableResultTypeInference = enabled;
    }
};
|
|
|
|
/// Wrapper around a `c.MlirAttribute` checked with `mlirAttributeIsADictionary`.
pub const DictionaryAttribute = struct {
    _inner: c.MlirAttribute,
    pub usingnamespace MlirHelpers(DictionaryAttribute, .{
        .is_a_fn = c.mlirAttributeIsADictionary,
        .is_null_fn = c.mlirAttributeIsNull,
        .dump_fn = c.mlirAttributeDump,
        .equal_fn = c.mlirAttributeEqual,
    });

    /// Dictionary attribute holding `attributes`.
    pub fn init(ctx: Context, attributes: []const NamedAttribute) DictionaryAttribute {
        return DictionaryAttribute.wrap(c.mlirDictionaryAttrGet(
            ctx.inner(),
            @intCast(attributes.len),
            @ptrCast(attributes.ptr),
        ));
    }

    /// Number of entries, from `mlirDictionaryAttrGetNumElements`.
    pub fn size(self: DictionaryAttribute) usize {
        return @intCast(c.mlirDictionaryAttrGetNumElements(self.inner()));
    }

    /// Entry at position `pos`.
    pub fn get(self: DictionaryAttribute, pos: usize) NamedAttribute {
        // `NamedAttribute` is a plain extern struct without the MlirHelpers `wrap` function,
        // so the C pair is copied field by field.
        const raw_entry = c.mlirDictionaryAttrGetElement(self.inner(), @intCast(pos));
        return .{ .name = raw_entry.name, .attribute = raw_entry.attribute };
    }

    /// Entry called `name`, or null when the C lookup returns a null attribute.
    pub fn getByName(self: DictionaryAttribute, name: [:0]const u8) ?NamedAttribute {
        // The C lookup takes an MlirStringRef and yields only the attribute; the identifier of
        // the returned pair is resolved from `name` in the context of this attribute.
        const raw_name = stringRef(name);
        const found = Attribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self.inner(), raw_name)) orelse return null;
        return .{
            .name = c.mlirIdentifierGet(c.mlirAttributeGetContext(self.inner()), raw_name),
            .attribute = found.inner(),
        };
    }

    /// Upcasts to the generic `Attribute`.
    pub fn asAttr(self: DictionaryAttribute) Attribute {
        return .{ ._inner = self._inner };
    }
};
|
|
|
|
pub const Operation = struct {
|
|
const Self = Operation;
|
|
_inner: c.MlirOperation,
|
|
|
|
pub usingnamespace MlirHelpers(
|
|
Operation,
|
|
.{
|
|
.is_null_fn = c.mlirOperationIsNull,
|
|
.deinit_fn = c.mlirOperationDestroy,
|
|
.dump_fn = c.mlirOperationDump,
|
|
.equal_fn = c.mlirOperationEqual,
|
|
},
|
|
);
|
|
|
|
/// Creates the operation described by `state`; fails with `InvalidMlir` when the C API
/// returns a null operation.
pub fn init(state: *OperationState) !Self {
    const raw_op = c.mlirOperationCreate(state.innerPtr());
    if (Self.wrapOr(raw_op)) |created| return created;
    return Error.InvalidMlir;
}
|
|
|
|
/// Builds the operation `op_name` from `args`, panicking when creation or verification fails.
///
/// - `operands` and `variadic_operands` are mutually exclusive (asserted), as are `results`
///   and `variadic_results`.
/// - `variadic_operands` and `tt_variadic_operands` additionally record the size of every
///   segment in an `operandSegmentSizes` attribute; at most 32 segments fit in the local buffer.
/// - `n_regions` empty regions are added first, then one region per entry of `blocks`.
/// - With `verify` set (the default), the new operation is verified and logged on failure.
pub fn make(ctx: Context, op_name: [:0]const u8, args: struct {
    operands: ?[]const Value = null,
    variadic_operands: ?[]const []const Value = null,
    tt_variadic_operands: ?[]const []const Value = null,
    results: ?[]const Type = null,
    variadic_results: ?[]const []const Type = null,
    result_type_inference: ?bool = null,
    n_regions: usize = 0,
    attributes: ?[]const AttrTuple = null,
    blocks: ?[]const Block = null,
    verify: bool = true,
    location: Location,
}) Self {
    var state = OperationState.init(op_name, args.location);
    std.debug.assert(!(args.operands != null and args.variadic_operands != null));
    if (args.operands) |operands| {
        state.addOperands(operands);
    } else if (args.variadic_operands) |operands_segments| {
        const MAX_SEGMENTS = 32;
        var segments: std.BoundedArray(i32, MAX_SEGMENTS) = .{};

        for (operands_segments) |operands| {
            state.addOperands(operands);
            segments.appendAssumeCapacity(@intCast(operands.len));
        }
        state.addAttribute(ctx, "operandSegmentSizes", .denseElements(ctx, &.{@intCast(segments.len)}, .i32, segments.constSlice()));
    } else if (args.tt_variadic_operands) |operands_segments| {
        // stablehlo and triton seem to disagree on the expected type of operandSegmentSizes:
        // this variant emits a dense array instead of dense elements.
        const MAX_SEGMENTS = 32;
        var segments: std.BoundedArray(i32, MAX_SEGMENTS) = .{};

        for (operands_segments) |operands| {
            state.addOperands(operands);
            segments.appendAssumeCapacity(@intCast(operands.len));
        }
        state.addAttribute(ctx, "operandSegmentSizes", .dense(ctx, .i32, segments.constSlice()));
    }
    if (args.result_type_inference) |enable| {
        state.resultTypeInference(enable);
    }
    std.debug.assert(!(args.results != null and args.variadic_results != null));
    if (args.results) |results| {
        state.addResults(results);
    } else if (args.variadic_results) |result_segments| {
        for (result_segments) |results| {
            state.addResults(results);
        }
    }
    for (0..args.n_regions) |_| {
        var region_ = Region.init() catch {
            @panic("Failed to create MLIR region");
        };
        state.addRegion(&region_);
    }
    if (args.attributes) |attrs| {
        for (attrs) |attr| {
            state.addAttributeRaw(
                Identifier.get(ctx, attr[0]),
                attr[1],
            );
        }
    }
    if (args.blocks) |blocks_| {
        for (blocks_) |block_| {
            var region_ = Region.init() catch {
                @panic("Failed to create MLIR region");
            };
            region_.appendBlock(block_);
            state.addRegion(&region_);
        }
    }

    const new_op = Operation.init(&state) catch {
        @panic("Failed to create MLIR operation");
    };
    if (args.verify and new_op.verify() == false) {
        log.err("Failed to verify MLIR operation:\n{}", .{new_op.mlirFormatter(.{ .debug_info = true })});
        @panic("Failed to verify MLIR operation");
    }
    return new_op;
}
|
|
|
|
/// Parses `str` into an operation, passing "pouet" as the third argument of
/// `mlirOperationCreateParse`; fails with `InvalidMlir` on a null result.
pub fn initParse(ctx: Context, str: [:0]const u8) !Self {
    const raw_op = c.mlirOperationCreateParse(ctx.inner(), stringRef(str), stringRef("pouet"));
    if (Self.wrapOr(raw_op)) |parsed| return parsed;
    return Error.InvalidMlir;
}
|
|
|
|
/// Clones the operation with `mlirOperationClone`; fails with `InvalidMlir` on a null result.
pub fn clone(self: Self) !Self {
    const raw_copy = c.mlirOperationClone(self.inner());
    if (Self.wrapOr(raw_copy)) |copied| return copied;
    return Error.InvalidMlir;
}
|
|
|
|
/// Identifier returned by `mlirOperationGetName`.
pub fn name(self: Self) Identifier {
    const raw_id = c.mlirOperationGetName(self.inner());
    return Identifier.wrap(raw_id);
}
|
|
|
|
/// Forwards to `mlirOperationRemoveFromParent`.
pub fn removeFromParent(self: *Self) void {
    const raw_op = self.inner();
    c.mlirOperationRemoveFromParent(raw_op);
}
|
|
|
|
/// Number of operands, from `mlirOperationGetNumOperands`.
pub fn numOperands(self: Self) usize {
    const n_operands = c.mlirOperationGetNumOperands(self.inner());
    return @intCast(n_operands);
}
|
|
|
|
/// Operand at `index`, from `mlirOperationGetOperand`.
pub fn operand(self: Self, index: usize) Value {
    const raw_value = c.mlirOperationGetOperand(self.inner(), @intCast(index));
    return Value.wrap(raw_value);
}
|
|
|
|
/// Replaces the operand at `index` with `value`, through `mlirOperationSetOperand`.
pub fn setOperand(self: *Self, index: usize, value: Value) void {
    const raw_op = self.inner();
    c.mlirOperationSetOperand(raw_op, @intCast(index), value.inner());
}
|
|
|
|
/// Number of results, from `mlirOperationGetNumResults`.
pub fn numResults(self: Self) usize {
    const n_results = c.mlirOperationGetNumResults(self.inner());
    return @intCast(n_results);
}
|
|
|
|
/// Result at `index`, from `mlirOperationGetResult`.
pub fn result(self: Self, index: usize) Value {
    const raw_value = c.mlirOperationGetResult(self.inner(), @intCast(index));
    return Value.wrap(raw_value);
}
|
|
|
|
/// Wraps the result of `mlirOperationGetNextInBlock`; the result is not checked for null.
pub fn nextInBlock(self: Self) Self {
    const raw_next = c.mlirOperationGetNextInBlock(self.inner());
    return Self.wrap(raw_next);
}
|
|
|
|
// pub fn previousInBlock(self: Self) Self {
|
|
// return Self.wrap(c.mlirOperationGetPrevInBlock(self.inner()));
|
|
// }
|
|
|
|
/// Returns the block containing this operation, or `null` when the C API reports none.
pub fn block(self: Self) ?Block {
    const raw_block = c.mlirOperationGetBlock(self.inner());
    return Block.wrapOr(raw_block);
}
|
|
|
|
/// Returns the parent operation, or `null` when the C API reports none.
pub fn parent(self: Self) ?Self {
    const raw_parent = c.mlirOperationGetParentOperation(self.inner());
    return Self.wrapOr(raw_parent);
}
|
|
|
|
/// Returns the region at position `index` of this operation.
pub fn region(self: Self, index: usize) Region {
    const raw_region = c.mlirOperationGetRegion(self.inner(), @intCast(index));
    return Region.wrap(raw_region);
}
|
|
|
|
/// Returns the `Context` this operation belongs to.
pub fn context(self: Self) Context {
    const raw_ctx = c.mlirOperationGetContext(self.inner());
    return Context.wrap(raw_ctx);
}
|
|
|
|
/// Serializes the operation as MLIR bytecode into `writer`.
///
/// MLIR hands the bytecode over in chunks through a C callback. Each chunk is forwarded
/// with `writeAll`, so a writer that performs short writes cannot silently drop the tail
/// of a chunk (a bare `write` may consume only part of its input).
///
/// This function has no error channel: a failing writer hits `unreachable`. Only use it
/// with writers that cannot fail.
pub fn writeBytecode(self: Self, writer: anytype) void {
    var writer_context = .{ .writer = writer };
    const WriterContext = @TypeOf(writer_context);

    c.mlirOperationWriteBytecode(
        self.inner(),
        (struct {
            pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.C) void {
                // `ctx_` is the `&writer_context` pointer passed below.
                const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_));
                inner_writer_context.writer.writeAll(str.data[0..str.length]) catch unreachable;
            }
        }).callback,
        &writer_context,
    );
}
|
|
|
|
/// Serializes the operation as MLIR bytecode into `writer`, using a bytecode writer config.
///
/// `config.desiredEmitedVersion`, when set, is forwarded to
/// `c.mlirBytecodeWriterConfigDesiredEmitVersion`.
///
/// Errors:
/// - `error.InvalidMlirBytecodeVersion` when the C call reports failure.
/// - the first error returned by `writer`, if any chunk could not be written.
pub fn writeBytecodeWithConfig(self: Self, writer: anytype, config: struct {
    desiredEmitedVersion: ?i64 = null,
}) !void {
    const cfg = c.mlirBytecodeWriterConfigCreate();
    defer c.mlirBytecodeWriterConfigDestroy(cfg);
    if (config.desiredEmitedVersion) |v| {
        c.mlirBytecodeWriterConfigDesiredEmitVersion(cfg, v);
    }

    // The C callback cannot return an error, so failures are parked here and
    // surfaced once the C call has returned.
    const WriterContext = struct {
        writer: @TypeOf(writer),
        write_error: ?@TypeOf(writer).Error = null,
    };
    var writer_context: WriterContext = .{ .writer = writer };

    try successOr(c.mlirOperationWriteBytecodeWithConfig(
        self.inner(),
        cfg,
        (struct {
            pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.C) void {
                const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_));
                // Once a write failed the output is already truncated: skip the remaining
                // chunks and keep the first error instead of overwriting it.
                if (inner_writer_context.write_error != null) return;
                // `writeAll` loops on short writes; a bare `write` could drop bytes.
                inner_writer_context.writer.writeAll(str.data[0..str.length]) catch |err| {
                    inner_writer_context.write_error = err;
                };
            }
        }).callback,
        &writer_context,
    ), error.InvalidMlirBytecodeVersion);

    if (writer_context.write_error) |err| return err;
}
|
|
|
|
/// Builds a formatter that prints the full IR of this operation with the given flags.
/// Usage: `std.log.debug("{}", .{ module.op().mlirFormatter(.{}) });`
pub fn mlirFormatter(self: Operation, flags: OpPrintingFlags) MlirFormatter {
    return MlirFormatter{
        .op = self,
        .flags = flags,
    };
}
|
|
|
|
/// Pairs an operation with printing flags so it can be passed as a `std.fmt` argument.
const MlirFormatter = struct {
    op: Operation,
    flags: OpPrintingFlags,

    /// `std.fmt` hook. The format string and options are ignored; the operation is
    /// printed to `writer` with the stored flags.
    pub fn format(self: @This(), comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
        _ = options;
        _ = fmt;
        return self.op.print(writer, self.flags);
    }
};
|
|
|
|
/// Prints the textual form of the operation into `writer`, honoring `flags`.
///
/// MLIR emits the text in chunks through a C callback. Each chunk is forwarded with
/// `writeAll`, so a writer that performs short writes cannot silently drop part of the
/// output (a bare `write` may consume only part of its input).
///
/// This function has no error channel: a failing writer hits `unreachable`.
pub fn print(self: Self, writer: anytype, flags: OpPrintingFlags) void {
    const pflags = flags.create();
    defer c.mlirOpPrintingFlagsDestroy(pflags);

    var writer_context = .{ .writer = writer };
    const WriterContext = @TypeOf(writer_context);
    c.mlirOperationPrintWithFlags(
        self.inner(),
        pflags,
        (struct {
            pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.C) void {
                // `ctx_` is the `&writer_context` pointer passed below.
                const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_));
                inner_writer_context.writer.writeAll(str.data[0..str.length]) catch unreachable;
            }
        }).callback,
        &writer_context,
    );
}
|
|
|
|
/// Runs `c.mlirOperationVerify` and returns its boolean result.
pub fn verify(self: Self) bool {
    const is_valid = c.mlirOperationVerify(self.inner());
    return is_valid;
}
|
|
|
|
/// Returns the `Location` attached to this operation.
pub fn getLocation(self: Self) Location {
    const raw_loc = c.mlirOperationGetLocation(self.inner());
    return Location.wrap(raw_loc);
}
|
|
|
|
/// Traversal order accepted by `walk`; values mirror `c.MlirWalkOrder`.
pub const WalkOrder = enum(c.MlirWalkOrder) {
    /// Same value as `c.MlirWalkPreOrder`.
    pre_order = c.MlirWalkPreOrder,
    /// Same value as `c.MlirWalkPostOrder`.
    post_order = c.MlirWalkPostOrder,
};
|
|
|
|
/// Value returned by a `walk` callback; values mirror `c.MlirWalkResult`.
pub const WalkResult = enum(c.MlirWalkResult) {
    /// Same value as `c.MlirWalkResultAdvance`.
    advance = c.MlirWalkResultAdvance,
    /// Same value as `c.MlirWalkResultInterrupt`.
    interrupt = c.MlirWalkResultInterrupt,
    /// Same value as `c.MlirWalkResultSkip`.
    skip = c.MlirWalkResultSkip,
};
|
|
|
|
/// Walks this operation with `c.mlirOperationWalk` in the requested `order`, calling
/// `walkfn(ctx, op)` for every visited operation. The value returned by `walkfn` is
/// handed back to MLIR as the walk result.
pub fn walk(self: Self, order: WalkOrder, ctx: anytype, walkfn: fn (ctx: anytype, op: Operation) WalkResult) void {
    // The user context is boxed in a local struct so a single pointer can cross the C boundary.
    var user_data = .{ .ctx = ctx };
    const UserData = @TypeOf(user_data);

    const Trampoline = struct {
        pub fn callback(op: c.MlirOperation, ctx_: ?*anyopaque) callconv(.C) c.MlirWalkResult {
            const data: *UserData = @ptrCast(@alignCast(ctx_));
            const verdict = walkfn(data.ctx, Operation.wrap(op));
            return @intFromEnum(verdict);
        }
    };

    c.mlirOperationWalk(self.inner(), Trampoline.callback, &user_data, @intFromEnum(order));
}
|
|
|
|
/// Returns the named attribute stored at position `pos`.
pub fn getAttribute(self: Self, pos: usize) NamedAttribute {
    const raw_attr = c.mlirOperationGetAttribute(self.inner(), @intCast(pos));
    return NamedAttribute.wrap(raw_attr);
}
|
|
|
|
/// Looks up the attribute called `name_`; returns `null` when the C API yields a null attribute.
pub fn getAttributeByName(self: Self, name_: [:0]const u8) ?Attribute {
    const raw_attr = c.mlirOperationGetAttributeByName(self.inner(), stringRef(name_));
    return Attribute.wrapOr(raw_attr);
}
|
|
|
|
/// Sets the attribute called `name_` to `attr` via `c.mlirOperationSetAttributeByName`.
pub fn setAttributeByName(self: Self, name_: [:0]const u8, attr: Attribute) void {
    const key = stringRef(name_);
    c.mlirOperationSetAttributeByName(self.inner(), key, attr.inner());
}
|
|
|
|
/// Removes the attribute called `name_` and returns the boolean reported by
/// `c.mlirOperationRemoveAttributeByName`.
pub fn removeAttributeByName(self: Self, name_: [:0]const u8) bool {
    const key = stringRef(name_);
    return c.mlirOperationRemoveAttributeByName(self.inner(), key);
}
|
|
|
|
/// Feeds the textual IR of `op` (printed without debug info) into `hasher`.
pub fn hash(op: Operation, hasher: *std.hash.XxHash64) void {
    const NoError = error{};

    // Adapter exposing the hasher through the `std.io.Writer` interface.
    const Sink = struct {
        fn feed(hasher_: *std.hash.XxHash64, bytes: []const u8) NoError!usize {
            hasher_.update(bytes);
            return bytes.len;
        }
    };
    const HashWriter = std.io.Writer(*std.hash.XxHash64, NoError, Sink.feed);
    const writer: HashWriter = .{ .context = hasher };

    // Hash the canonicalized IR, without debug information that can change across builds.
    // Note: before we were using op.writeBytecode(writer),
    // but it crashes on some inputs, notably for unused variables.
    // So we use the text representation of the mlir.
    // See https://github.com/zml/zml/issues/97.
    // Writes can't fail because we are writing to a hasher.
    op.print(writer, .{ .debug_info = false });
}
|
|
};
|
|
|
|
/// Zig-side description of the options used when printing an operation.
/// `create` turns it into a `c.MlirOpPrintingFlags` handle.
pub const OpPrintingFlags = struct {
    /// When set, forwarded to `c.mlirOpPrintingFlagsElideLargeElementsAttrs`.
    elide_large_elements_attrs: ?usize = null,
    /// First boolean passed to `c.mlirOpPrintingFlagsEnableDebugInfo`.
    debug_info: bool = false,
    /// Second boolean passed to `c.mlirOpPrintingFlagsEnableDebugInfo`.
    debug_info_pretty_form: bool = true,
    /// When true, `c.mlirOpPrintingFlagsPrintGenericOpForm` is called.
    print_generic_op_form: bool = false,
    /// When true, `c.mlirOpPrintingFlagsUseLocalScope` is called.
    use_local_scope: bool = false,
    /// When true, `c.mlirOpPrintingFlagsAssumeVerified` is called.
    assume_verified: bool = false,

    /// Allocates a C flags handle configured from this struct and returns it.
    /// The handle is not released here; `Operation.print` pairs it with
    /// `c.mlirOpPrintingFlagsDestroy`.
    pub fn create(self: OpPrintingFlags) c.MlirOpPrintingFlags {
        const handle = c.mlirOpPrintingFlagsCreate();

        if (self.elide_large_elements_attrs) |limit| c.mlirOpPrintingFlagsElideLargeElementsAttrs(handle, @intCast(limit));
        c.mlirOpPrintingFlagsEnableDebugInfo(handle, self.debug_info, self.debug_info_pretty_form);
        if (self.print_generic_op_form) c.mlirOpPrintingFlagsPrintGenericOpForm(handle);
        if (self.use_local_scope) c.mlirOpPrintingFlagsUseLocalScope(handle);
        if (self.assume_verified) c.mlirOpPrintingFlagsAssumeVerified(handle);

        return handle;
    }
};
|
|
|
|
/// Wrapper around `c.MlirOpOperand`.
pub const OpOperand = struct {
    _inner: c.MlirOpOperand,
    pub usingnamespace MlirHelpers(OpOperand, .{
        .is_null_fn = c.mlirOpOperandIsNull,
    });

    const Self = OpOperand;

    /// Operation that owns this operand.
    pub fn owner(self: Self) Operation {
        const raw_owner = c.mlirOpOperandGetOwner(self.inner());
        return Operation.wrap(raw_owner);
    }

    /// Operand number reported by `c.mlirOpOperandGetOperandNumber`.
    pub fn number(self: Self) usize {
        const operand_number = c.mlirOpOperandGetOperandNumber(self.inner());
        return @intCast(operand_number);
    }

    /// Next use in the use list, or `null` when the C API returns a null operand.
    pub fn nextUse(self: Self) ?Self {
        const following = c.mlirOpOperandGetNextUse(self.inner());
        return Self.wrapOr(following);
    }
};
|
|
|
|
/// Wrapper around `c.MlirRegion`.
pub const Region = struct {
    _inner: c.MlirRegion,
    pub usingnamespace MlirHelpers(Region, .{
        .is_null_fn = c.mlirRegionIsNull,
        .deinit_fn = c.mlirRegionDestroy,
        .equal_fn = c.mlirRegionEqual,
    });

    const Self = Region;

    /// Creates a region with `c.mlirRegionCreate`.
    /// Returns `Error.InvalidMlir` when the created region is null.
    pub fn init() !Self {
        const created = c.mlirRegionCreate();
        return Self.wrapOr(created) orelse Error.InvalidMlir;
    }

    /// Appends `block` at the end of the region (`c.mlirRegionAppendOwnedBlock`).
    pub fn appendBlock(self: *Self, block: Block) void {
        const raw_region = self.inner();
        c.mlirRegionAppendOwnedBlock(raw_region, block.inner());
    }

    /// Inserts `block` at position `index` (`c.mlirRegionInsertOwnedBlock`).
    pub fn insertBlock(self: *Self, index: isize, block: Block) void {
        const raw_region = self.inner();
        c.mlirRegionInsertOwnedBlock(raw_region, index, block.inner());
    }

    /// Inserts `block` before `reference` (`c.mlirRegionInsertOwnedBlockBefore`).
    pub fn insertBlockBefore(self: *Self, reference: Block, block: Block) void {
        const raw_region = self.inner();
        c.mlirRegionInsertOwnedBlockBefore(raw_region, reference.inner(), block.inner());
    }

    /// Inserts `block` after `reference` (`c.mlirRegionInsertOwnedBlockAfter`).
    pub fn insertBlockAfter(self: *Self, reference: Block, block: Block) void {
        const raw_region = self.inner();
        c.mlirRegionInsertOwnedBlockAfter(raw_region, reference.inner(), block.inner());
    }

    /// Wraps the first block of the region; the result is not checked for null.
    pub fn firstBlock(self: Self) Block {
        const head = c.mlirRegionGetFirstBlock(self.inner());
        return Block.wrap(head);
    }
};
|
|
|
|
/// Wrapper around `c.MlirValue`.
pub const Value = struct {
    _inner: c.MlirValue,

    pub usingnamespace MlirHelpers(Value, .{
        .is_null_fn = c.mlirValueIsNull,
        .equal_fn = c.mlirValueEqual,
        .dump_fn = c.mlirValueDump,
        .print_fn = c.mlirValuePrint,
    });

    /// Type of the value.
    pub fn getType(val: Value) Type {
        const raw_type = c.mlirValueGetType(val.inner());
        return Type.wrap(raw_type);
    }

    /// Changes the type of the value via `c.mlirValueSetType`.
    pub fn setType(val: *Value, typ: Type) void {
        const raw_value = val.inner();
        c.mlirValueSetType(raw_value, typ.inner());
    }

    /// Wraps the first use of the value; the result is not checked for null.
    pub fn firstUse(val: Value) OpOperand {
        const raw_use = c.mlirValueGetFirstUse(val.inner());
        return OpOperand.wrap(raw_use);
    }

    /// Replaces every use of `val` with `with`.
    pub fn replaceAllUsesWith(val: Value, with: Value) void {
        const old_value = val.inner();
        c.mlirValueReplaceAllUsesOfWith(old_value, with.inner());
    }

    /// Calls `c.mlirOpResultGetOwner` on the value and wraps the result.
    /// NOTE(review): presumably only meaningful when `isAOpResult()` is true, as `kind` checks first - confirm.
    pub fn owner(val: Value) Operation {
        const raw_owner = c.mlirOpResultGetOwner(val.inner());
        return Operation.wrap(raw_owner);
    }

    /// True when the value is a block argument.
    pub fn isABlockArgument(val: Value) bool {
        const answer = c.mlirValueIsABlockArgument(val.inner());
        return answer;
    }

    /// True when the value is an operation result.
    pub fn isAOpResult(val: Value) bool {
        const answer = c.mlirValueIsAOpResult(val.inner());
        return answer;
    }

    /// Classification of a value as returned by `kind`.
    pub const Kind = union(enum) {
        block_argument: BlockArgument,
        op_result: Operation,
        null,
    };

    /// Classifies the value: operation result (with its owner), block argument, or `.null`
    /// when neither predicate holds.
    pub fn kind(val: Value) Kind {
        if (val.isAOpResult()) return .{ .op_result = val.owner() };
        if (val.isABlockArgument()) return .{ .block_argument = .{ ._inner = val._inner } };

        // From MLIR docs:
        // https://mlir.llvm.org/doxygen/classmlir_1_1Value.html#details
        // > An SSA value is either a BlockArgument or the result of an operation.
        return .null;
    }
};
|
|
|
|
/// View of a `c.MlirValue` that is a block argument (see `Value.kind`).
pub const BlockArgument = struct {
    _inner: c.MlirValue,

    /// Block that owns this argument.
    pub fn block(arg: BlockArgument) Block {
        const raw_owner = c.mlirBlockArgumentGetOwner(arg._inner);
        return Block.wrap(raw_owner);
    }

    /// Position of the argument, bit-cast from the C return value to `usize`.
    pub fn number(arg: BlockArgument) usize {
        const arg_number = c.mlirBlockArgumentGetArgNumber(arg._inner);
        return @bitCast(arg_number);
    }

    /// `std.fmt` hook: delegates to the formatting of the underlying `Value`.
    pub fn format(self: BlockArgument, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
        const as_value: Value = .{ ._inner = self._inner };
        return as_value.format(fmt, options, writer);
    }
};
|
|
|
|
/// Wrapper around `c.MlirType`, with shortcuts to build the common builtin types.
pub const Type = struct {
    _inner: c.MlirType,

    pub usingnamespace MlirHelpers(Type, .{
        .is_null_fn = c.mlirTypeIsNull,
        .dump_fn = c.mlirTypeDump,
        .equal_fn = c.mlirTypeEqual,
        .print_fn = c.mlirTypePrint,
    });

    /// Parses `str` as a type in `ctx`.
    /// Returns `Error.InvalidMlir` when the C API yields a null type.
    pub fn parse(ctx: Context, str: [:0]const u8) !Type {
        return Type.wrapOr(
            c.mlirTypeParseGet(ctx.inner(), stringRef(str)),
        ) orelse Error.InvalidMlir;
    }

    /// The index type, as a generic `Type`.
    pub fn index(ctx: Context) Type {
        return IndexType.init(ctx).as(Type);
    }

    /// The integer type described by `int_type`. Panics on `.unknown`.
    pub fn int(ctx: Context, int_type: IntegerTypes) Type {
        return switch (int_type) {
            .unknown => @panic("Unknown integer type"),
            inline else => |t| IntegerType(t).init(ctx).as(Type),
        };
    }

    /// The float type described by `float_type`.
    pub fn float(ctx: Context, float_type: FloatTypes) Type {
        return switch (float_type) {
            inline else => |t| FloatType(t).init(ctx).as(Type),
        };
    }

    /// The complex type described by `complex_type`. Panics on `.unknown`.
    pub fn complex(ctx: Context, complex_type: ComplexTypes) Type {
        return switch (complex_type) {
            // `ComplexType(.unknown)` exposes no `init`, so letting `inline else`
            // expand this tag would make the whole function fail to compile.
            .unknown => @panic("Unknown complex type"),
            inline else => |t| ComplexType(t).init(ctx).as(Type),
        };
    }

    /// A tuple type over `types`. Construction failure is treated as unreachable.
    pub fn tuple(ctx: Context, types: []const Type) Type {
        return (TupleType.init(ctx, types) catch unreachable).as(Type);
    }

    /// A function type from `args` to `results`. Construction failure is treated as unreachable.
    pub fn function(ctx: Context, args: []const Type, results: []const Type) Type {
        return (FunctionType.init(ctx, args, results) catch unreachable).as(Type);
    }

    /// A ranked tensor type with the given `dimensions` and element type.
    pub fn tensor(dimensions: []const i64, elem_type: Type) Type {
        return RankedTensorType.init(dimensions, elem_type).as(Type);
    }
};
|
|
|
|
/// Wrapper around a `c.MlirType` created with `c.mlirIndexTypeGet`.
pub const IndexType = struct {
    _inner: c.MlirType,

    pub usingnamespace MlirHelpers(IndexType, .{
        .is_null_fn = c.mlirTypeIsNull,
        .dump_fn = c.mlirTypeDump,
        .equal_fn = c.mlirTypeEqual,
        .print_fn = c.mlirTypePrint,
    });

    /// Gets the index type from `ctx`.
    pub fn init(ctx: Context) IndexType {
        const raw_type = c.mlirIndexTypeGet(ctx.inner());
        return IndexType.wrap(raw_type);
    }
};
|
|
|
|
/// Integer flavors understood by `IntegerType`.
/// `IntegerType` builds the `i*` tags with `c.mlirIntegerTypeGet`, the `si*` tags with
/// `c.mlirIntegerTypeSignedGet` and the `u*` tags with `c.mlirIntegerTypeUnsignedGet`;
/// the number in each tag is the bit width.
pub const IntegerTypes = enum {
    i1,
    i4,
    i8,
    i16,
    i32,
    i64,
    si4,
    si8,
    si16,
    si32,
    si64,
    u4,
    u8,
    u16,
    u32,
    u64,

    /// Any integer type; `IntegerType(.unknown)` has no constructor.
    unknown,
};
|
|
|
|
/// Returns the wrapper struct for the integer flavor `it`.
/// Every flavor except `.unknown` also exposes `BitWidth` and `init(ctx)`.
pub fn IntegerType(comptime it: IntegerTypes) type {
    // Per-flavor tuple: { bit width, C constructor, C signedness predicate }.
    const Config = switch (it) {
        .i1 => .{ 1, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
        .i4 => .{ 4, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
        .i8 => .{ 8, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
        .i16 => .{ 16, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
        .i32 => .{ 32, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
        .i64 => .{ 64, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
        .si4 => .{ 4, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
        .si8 => .{ 8, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
        .si16 => .{ 16, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
        .si32 => .{ 32, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
        .si64 => .{ 64, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
        .u4 => .{ 4, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
        .u8 => .{ 8, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
        .u16 => .{ 16, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
        .u32 => .{ 32, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
        .u64 => .{ 64, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
        .unknown => .{ 0, null, null },
    };

    return struct {
        _inner: c.MlirType,

        const Int = @This();
        pub usingnamespace MlirHelpers(Int, .{
            .is_a_fn = switch (it) {
                .unknown => c.mlirTypeIsAInteger,
                else => typeIsAIntegerExact,
            },
            .is_null_fn = c.mlirTypeIsNull,
            .dump_fn = c.mlirTypeDump,
            .equal_fn = c.mlirTypeEqual,
        });
        const IntegerTypeType = it;

        /// True only for an integer type with this flavor's exact width and signedness.
        fn typeIsAIntegerExact(typ: c.MlirType) callconv(.C) bool {
            if (!c.mlirTypeIsAInteger(typ)) return false;
            if (c.mlirIntegerTypeGetWidth(typ) != Config[0]) return false;
            return Config[2](typ);
        }

        // Constructors only exist for concrete flavors.
        pub usingnamespace if (it != .unknown) struct {
            pub const BitWidth = Config[0];

            /// Gets this integer type from `ctx`.
            pub fn init(ctx: Context) Int {
                return Int.wrap(Config[1](ctx.inner(), BitWidth));
            }
        } else struct {};
    };
}
|
|
|
|
/// Float flavors understood by `FloatType`.
pub const FloatTypes = enum {
    f8e4m3b11fnuz,
    f8e4m3fn,
    f8e4m3fnuz,
    f8e5m2,
    f8e5m2fnuz,
    bf16,
    f16,
    f32,
    f64,

    /// Builds the matching MLIR float type in `ctx` and returns it as a generic `Type`.
    pub fn asType(self: FloatTypes, ctx: Context) Type {
        switch (self) {
            inline else => |ft| return FloatType(ft).init(ctx).as(Type),
        }
    }
};
|
|
|
|
/// Returns the wrapper struct for the float flavor `ft`.
pub fn FloatType(comptime ft: FloatTypes) type {
    // Per-flavor tuple: { C "is a" predicate, C constructor }.
    const Config = switch (ft) {
        .f8e4m3b11fnuz => .{ c.mlirTypeIsAFloat8E4M3B11FNUZ, c.mlirFloat8E4M3B11FNUZTypeGet },
        .f8e4m3fn => .{ c.mlirTypeIsAFloat8E4M3FN, c.mlirFloat8E4M3FNTypeGet },
        .f8e4m3fnuz => .{ c.mlirTypeIsAFloat8E4M3FNUZ, c.mlirFloat8E4M3FNUZTypeGet },
        .f8e5m2 => .{ c.mlirTypeIsAFloat8E5M2, c.mlirFloat8E5M2TypeGet },
        .f8e5m2fnuz => .{ c.mlirTypeIsAFloat8E5M2FNUZ, c.mlirFloat8E5M2FNUZTypeGet },
        .bf16 => .{ c.mlirTypeIsABF16, c.mlirBF16TypeGet },
        .f16 => .{ c.mlirTypeIsAF16, c.mlirF16TypeGet },
        .f32 => .{ c.mlirTypeIsAF32, c.mlirF32TypeGet },
        .f64 => .{ c.mlirTypeIsAF64, c.mlirF64TypeGet },
    };

    return struct {
        _inner: c.MlirType,

        const Self = @This();

        pub usingnamespace MlirHelpers(Self, .{
            .is_a_fn = Config[0],
            .is_null_fn = c.mlirTypeIsNull,
            .dump_fn = c.mlirTypeDump,
            .equal_fn = c.mlirTypeEqual,
        });

        /// Gets this float type from `ctx`.
        pub fn init(ctx: Context) Self {
            const raw_type = Config[1](ctx.inner());
            return Self.wrap(raw_type);
        }
    };
}
|
|
|
|
/// Complex flavors understood by `ComplexType`.
pub const ComplexTypes = enum {
    /// Built by `ComplexType` as a complex of `f32`.
    c64,
    /// Built by `ComplexType` as a complex of `f64`.
    c128,

    /// Any complex type; `ComplexType(.unknown)` has no constructor.
    unknown,
};
|
|
|
|
/// Returns the wrapper struct for the complex flavor `ct`.
/// Every flavor except `.unknown` also exposes `ComplexTypeType` and `init(ctx)`.
pub fn ComplexType(comptime ct: ComplexTypes) type {
    return struct {
        _inner: c.MlirType,
        const Complex = @This();

        /// Complex type whose element type is `f32`.
        fn mlirC64TypeGet(ctx: c.MlirContext) callconv(.C) c.MlirType {
            return c.mlirComplexTypeGet(c.mlirF32TypeGet(ctx));
        }

        /// Complex type whose element type is `f64`.
        fn mlirC128TypeGet(ctx: c.MlirContext) callconv(.C) c.MlirType {
            return c.mlirComplexTypeGet(c.mlirF64TypeGet(ctx));
        }

        /// True only for a complex type whose element type is `f32`.
        fn mlirTypeIsAC64(typ: c.MlirType) callconv(.C) bool {
            // The element-type accessor is only meaningful on complex types, and this
            // predicate is handed arbitrary types: check the kind first, the same way
            // `IntegerType` guards its width check with `mlirTypeIsAInteger`.
            if (!c.mlirTypeIsAComplex(typ)) return false;
            const element_type: c.MlirType = c.mlirComplexTypeGetElementType(typ);
            return c.mlirTypeIsAF32(element_type);
        }

        /// True only for a complex type whose element type is `f64`.
        fn mlirTypeIsAC128(typ: c.MlirType) callconv(.C) bool {
            // See `mlirTypeIsAC64` for why the kind is checked first.
            if (!c.mlirTypeIsAComplex(typ)) return false;
            const element_type: c.MlirType = c.mlirComplexTypeGetElementType(typ);
            return c.mlirTypeIsAF64(element_type);
        }

        // Per-flavor tuple: { "is a" predicate, constructor }.
        const Config = switch (ct) {
            .c64 => .{ mlirTypeIsAC64, mlirC64TypeGet },
            .c128 => .{ mlirTypeIsAC128, mlirC128TypeGet },
            .unknown => .{ c.mlirTypeIsAComplex, null },
        };

        /// True for any complex type, whatever its element type.
        fn typeIsAUnknownComplex(typ: c.MlirType) callconv(.C) bool {
            return c.mlirTypeIsAComplex(typ);
        }

        pub usingnamespace MlirHelpers(@This(), .{
            .is_a_fn = Config[0],
            .is_null_fn = c.mlirTypeIsNull,
            .dump_fn = c.mlirTypeDump,
            .equal_fn = c.mlirTypeEqual,
        });

        // Constructors only exist for concrete flavors.
        pub usingnamespace if (ct != .unknown) struct {
            pub const ComplexTypeType = ct;

            /// Gets this complex type from `ctx`.
            pub fn init(ctx: Context) Complex {
                const type_get = Config[1];
                return Complex.wrap(type_get(ctx.inner()));
            }
        } else struct {};
    };
}
|
|
|
|
/// Wrapper around a `c.MlirType` that is a tuple type.
pub const TupleType = struct {
    _inner: c.MlirType,
    pub usingnamespace MlirHelpers(TupleType, .{
        .is_a_fn = c.mlirTypeIsATuple,
        .is_null_fn = c.mlirTypeIsNull,
        .dump_fn = c.mlirTypeDump,
        .equal_fn = c.mlirTypeEqual,
    });

    const Self = TupleType;

    /// Builds the tuple type over `elements` in `ctx`.
    /// Returns `Error.InvalidMlir` when the C API yields a null type.
    pub fn init(ctx: Context, elements: []const Type) !Self {
        const raw_type = c.mlirTupleTypeGet(
            ctx.inner(),
            @intCast(elements.len),
            @ptrCast(elements.ptr),
        );
        return Self.wrapOr(raw_type) orelse Error.InvalidMlir;
    }

    /// Number of element types in the tuple.
    pub fn getNumTypes(self: Self) usize {
        const count = c.mlirTupleTypeGetNumTypes(self.inner());
        return @intCast(count);
    }

    /// Element type at position `index`.
    pub fn getElementType(self: Self, index: usize) Type {
        const raw_type = c.mlirTupleTypeGetType(self.inner(), @intCast(index));
        return Type.wrap(raw_type);
    }
};
|
|
|
|
/// Wrapper around a `c.MlirType` that is a function type.
pub const FunctionType = struct {
    _inner: c.MlirType,
    pub usingnamespace MlirHelpers(FunctionType, .{
        .is_a_fn = c.mlirTypeIsAFunction,
        .is_null_fn = c.mlirTypeIsNull,
        .dump_fn = c.mlirTypeDump,
        .equal_fn = c.mlirTypeEqual,
    });

    const Self = FunctionType;

    /// Builds the function type taking `args` and returning `results` in `ctx`.
    /// Returns `Error.InvalidMlir` when the C API yields a null type.
    pub fn init(ctx: Context, args: []const Type, results: []const Type) !Self {
        const raw_type = c.mlirFunctionTypeGet(
            ctx.inner(),
            @intCast(args.len),
            @ptrCast(args.ptr),
            @intCast(results.len),
            @ptrCast(results.ptr),
        );
        return Self.wrapOr(raw_type) orelse Error.InvalidMlir;
    }
};
|
|
|
|
/// Wrapper around a `c.MlirType` that is a ranked tensor type.
pub const RankedTensorType = struct {
    _inner: c.MlirType,
    pub usingnamespace MlirHelpers(RankedTensorType, .{
        .is_a_fn = c.mlirTypeIsARankedTensor,
        .is_null_fn = c.mlirTypeIsNull,
        .dump_fn = c.mlirTypeDump,
        .equal_fn = c.mlirTypeEqual,
        .print_fn = c.mlirTypePrint,
    });

    /// Builds a ranked tensor type from `dimensions` and `elemType`.
    /// A null attribute is passed as the last argument of `c.mlirRankedTensorTypeGet`.
    pub fn init(dimensions: []const i64, elemType: Type) RankedTensorType {
        const raw_type = c.mlirRankedTensorTypeGet(
            @intCast(dimensions.len),
            @ptrCast(dimensions.ptr),
            elemType.inner(),
            c.mlirAttributeGetNull(),
        );
        return RankedTensorType.wrap(raw_type);
    }

    /// Element type of the tensor.
    pub fn getElementType(self: RankedTensorType) Type {
        const raw_type = c.mlirShapedTypeGetElementType(self.inner());
        return Type.wrap(raw_type);
    }

    /// Rank of the tensor, converted to `usize`.
    pub fn getRank(self: RankedTensorType) usize {
        const tensor_rank = c.mlirShapedTypeGetRank(self.inner());
        return @intCast(tensor_rank);
    }

    /// Size of dimension `dim`, exactly as reported by `c.mlirShapedTypeGetDimSize`.
    pub fn getDimension(self: RankedTensorType, dim: usize) i64 {
        const dim_size = c.mlirShapedTypeGetDimSize(self.inner(), @intCast(dim));
        return dim_size;
    }
};
|
|
|
|
/// Wrapper around `c.MlirDialect`.
pub const Dialect = struct {
    _inner: c.MlirDialect,
    pub usingnamespace MlirHelpers(Dialect, .{
        .equal_fn = c.mlirDialectEqual,
        .is_null_fn = c.mlirDialectIsNull,
    });

    const Self = Dialect;

    /// Context that owns the dialect.
    pub fn getContext(self: Self) Context {
        const raw_ctx = c.mlirDialectGetContext(self.inner());
        return Context.wrap(raw_ctx);
    }

    /// Namespace of the dialect, as a slice over the string returned by the C API.
    pub fn getNamespace(self: Self) []const u8 {
        const namespace_ref = c.mlirDialectGetNamespace(self.inner());
        return fromStringRef(namespace_ref);
    }
};
|
|
|
|
/// Wrapper around `c.MlirDialectHandle`.
pub const DialectHandle = struct {
    _inner: c.MlirDialectHandle,
    pub usingnamespace MlirHelpers(DialectHandle, .{});

    /// Namespace associated with the handle.
    pub fn getNamespace(self: DialectHandle) []const u8 {
        const namespace_ref = c.mlirDialectHandleGetNamespace(self.inner());
        return fromStringRef(namespace_ref);
    }

    /// Inserts the dialect into `registry`.
    pub fn insertDialect(self: DialectHandle, registry: Registry) void {
        const raw_handle = self.inner();
        c.mlirDialectHandleInsertDialect(raw_handle, registry.inner());
    }

    /// Registers the dialect with `ctx`.
    pub fn registerDialect(self: DialectHandle, ctx: Context) void {
        const raw_handle = self.inner();
        c.mlirDialectHandleRegisterDialect(raw_handle, ctx.inner());
    }

    /// Loads the dialect into `ctx` and returns it.
    pub fn loadDialect(self: DialectHandle, ctx: Context) Dialect {
        const loaded = c.mlirDialectHandleLoadDialect(self.inner(), ctx.inner());
        return Dialect.wrap(loaded);
    }

    /// Resolves the handle by calling the C symbol `mlirGetDialectHandle__<namespace>__`.
    pub fn fromString(comptime namespace: []const u8) DialectHandle {
        const getter = @field(c, "mlirGetDialectHandle__" ++ namespace ++ "__");
        return DialectHandle.wrap(getter());
    }
};
|
|
|
|
/// Wrapper around a `c.MlirType` that is a shaped type.
pub const ShapedType = struct {
    _inner: c.MlirType,
    pub usingnamespace MlirHelpers(ShapedType, .{
        .is_a_fn = c.mlirTypeIsAShaped,
        .is_null_fn = c.mlirTypeIsNull,
        .dump_fn = c.mlirTypeDump,
        .equal_fn = c.mlirTypeEqual,
    });
    const Self = ShapedType;

    /// Rank of the shaped type, converted to `usize`.
    pub fn rank(self: Self) usize {
        const shape_rank = c.mlirShapedTypeGetRank(self.inner());
        return @intCast(shape_rank);
    }

    /// Element type of the shaped type.
    pub fn elementType(self: Self) Type {
        const raw_type = c.mlirShapedTypeGetElementType(self.inner());
        return Type.wrap(raw_type);
    }

    /// Size of dimension `dim`, converted to `usize` with `@intCast`.
    /// NOTE(review): the cast assumes the reported size is non-negative; presumably
    /// dynamic dimensions would not satisfy that - confirm against callers.
    pub fn dimension(self: Self, dim: usize) usize {
        const dim_size = c.mlirShapedTypeGetDimSize(self.inner(), @intCast(dim));
        return @intCast(dim_size);
    }
};
|
|
|
|
/// Wrapper around `c.MlirLocation`.
pub const Location = struct {
    _inner: c.MlirLocation,

    pub usingnamespace MlirHelpers(Location, .{
        .is_null_fn = c.mlirLocationIsNull,
        .equal_fn = c.mlirLocationEqual,
        .print_fn = c.mlirLocationPrint,
    });

    /// File/line/column location built from a Zig `@src()` value.
    pub fn fromSrc(ctx: Context, src: std.builtin.SourceLocation) Location {
        return Location.wrap(c.mlirLocationFileLineColGet(
            ctx.inner(),
            stringRef(src.file),
            @intCast(src.line),
            @intCast(src.column),
        ));
    }

    /// File/line/column location built from explicit values.
    pub fn fileLineCol(ctx: Context, file: []const u8, line: usize, column: usize) Location {
        return Location.wrap(c.mlirLocationFileLineColGet(
            ctx.inner(),
            stringRef(file),
            @intCast(line),
            @intCast(column),
        ));
    }

    /// Call-site location combining `callee` and `caller`.
    pub fn callSite(callee: Location, caller: Location) Location {
        return Location.wrap(c.mlirLocationCallSiteGet(callee.inner(), caller.inner()));
    }

    /// Fused location over `locations`, carrying `metadata`.
    pub fn fused(ctx: Context, locations: []const Location, metadata: Attribute) Location {
        return Location.wrap(c.mlirLocationFusedGet(
            ctx.inner(),
            @intCast(locations.len),
            @ptrCast(locations.ptr),
            metadata.inner(),
        ));
    }

    /// Name location called `loc_name` with `loc` as its child location.
    /// Only the pointer and length of `loc_name` are passed to C, so no sentinel is
    /// required; `[:0]const u8` arguments still coerce.
    pub fn named(loc: Location, ctx: Context, loc_name: []const u8) Location {
        return Location.wrap(c.mlirLocationNameGet(ctx.inner(), stringRef(loc_name), loc.inner()));
    }

    /// Same as `named`, with the name produced by formatting `fmt` with `args`.
    /// Names longer than 256 bytes are cut and end with "...".
    pub fn namedFmt(loc: Location, ctx: Context, comptime fmt: [:0]const u8, args: anytype) Location {
        var buf: [256]u8 = undefined;
        var stream = std.io.fixedBufferStream(&buf);
        std.fmt.format(stream.writer(), fmt, args) catch {
            // The buffer overflowed: mark the truncation in its last three bytes.
            buf[buf.len - 3 ..].* = "...".*;
        };
        // The written bytes are not null-terminated; they are passed as a plain slice
        // instead of being cast to a sentinel-terminated one.
        return loc.named(ctx, stream.getWritten());
    }

    /// The unknown location of `ctx`.
    pub fn unknown(ctx: Context) Location {
        return Location.wrap(c.mlirLocationUnknownGet(ctx.inner()));
    }
};
|
|
|
|
/// Wrapper around `c.MlirBlock`.
pub const Block = struct {
    _inner: c.MlirBlock,
    pub usingnamespace MlirHelpers(Block, .{
        .is_null_fn = c.mlirBlockIsNull,
        .deinit_fn = c.mlirBlockDestroy,
        .equal_fn = c.mlirBlockEqual,
    });

    /// Creates a block with one argument per entry of `args`; `locs` supplies the
    /// location of each argument and must hold at least `args.len` entries.
    /// Returns `error.InvalidMlir` when the created block is null.
    pub fn init(args: []const Type, locs: []const Location) !Block {
        // Only `args.len` crosses the C boundary as a count, so a shorter `locs`
        // would be read past its end.
        stdx.debug.assert(locs.len >= args.len, "Block.init needs one location per argument: got {} types but only {} locations", .{ args.len, locs.len });
        const block = Block.wrapOr(
            c.mlirBlockCreate(@intCast(args.len), @ptrCast(args.ptr), @ptrCast(locs.ptr)),
        );
        return block orelse error.InvalidMlir;
    }

    /// Returns the block argument at position `index`; asserts that it exists.
    pub fn argument(self: Block, index: usize) Value {
        const arg = c.mlirBlockGetArgument(self.inner(), @intCast(index));
        stdx.debug.assert(!Value.Methods.is_null_fn.?(arg), "Block doesn't have argument {}, only got {}", .{ index, self.numArguments() });
        return Value.wrap(arg);
    }

    /// Number of arguments of the block.
    pub fn numArguments(self: Block) usize {
        return @intCast(c.mlirBlockGetNumArguments(self.inner()));
    }

    /// Appends an argument of type `typ` located at `loc` and returns it.
    pub fn addArgument(self: *Block, typ: Type, loc: Location) Value {
        return Value.wrap(c.mlirBlockAddArgument(self.inner(), typ.inner(), loc.inner()));
    }

    /// Inserts an argument of type `typ` at position `index` and returns it.
    pub fn insertArgument(self: *Block, index: usize, typ: Type, loc: Location) Value {
        return Value.wrap(c.mlirBlockInsertArgument(self.inner(), @intCast(index), typ.inner(), loc.inner()));
    }

    /// True when both wrappers compare equal through `c.mlirBlockEqual`.
    pub fn equals(self: Block, other: Block) bool {
        return c.mlirBlockEqual(self.inner(), other.inner());
    }

    /// Appends `op` at the end of the block (`c.mlirBlockAppendOwnedOperation`).
    pub fn appendOperation(self: Block, op: Operation) void {
        c.mlirBlockAppendOwnedOperation(self.inner(), op.inner());
    }

    /// Appends every operation of `ops`, in order.
    pub fn appendOperations(self: *Block, ops: []const Operation) void {
        for (ops) |op| {
            c.mlirBlockAppendOwnedOperation(self.inner(), op.inner());
        }
    }

    /// Policy for the recursive append helpers: `.hermetic` asserts that everything
    /// referenced already lives in this block, `.open` skips that assertion.
    pub const RecursiveOpts = enum { open, hermetic };

    /// Makes sure the producer of `value` is in this block: an operation result appends
    /// its owning operation recursively, a block argument is only checked against `opt`.
    /// Panics on a value that is neither.
    pub fn appendValueRecursive(self: Block, value: Value, opt: RecursiveOpts) void {
        switch (value.kind()) {
            .op_result => |parent_op| self.appendOperationRecursive(parent_op, opt),
            .block_argument => |arg| {
                // Hermetic blocks are not allowed to use arguments from other blocks.
                stdx.debug.assert(opt == .open or self.equals(arg.block()), "Can't add {} from {?x} block to {?x} block", .{ arg, arg.block()._inner.ptr, self._inner.ptr });
            },
            .null => @panic("InvalidMlir"),
        }
    }

    /// Appends `op` to this block after recursively appending the producers of its
    /// operands. An operation that already belongs to a block is left where it is.
    pub fn appendOperationRecursive(self: Block, op: Operation, opt: RecursiveOpts) void {
        if (op.block()) |prev_block| {
            // Hermetic blocks are not allowed to reference values from other blocks.
            stdx.debug.assert(opt == .open or self.equals(prev_block), "Can't add {} from {?x} block to {?x} block", .{ op, prev_block._inner.ptr, self._inner.ptr });
            return;
        }
        for (0..op.numOperands()) |i| {
            self.appendValueRecursive(op.operand(i), opt);
        }
        self.appendOperation(op);
    }
};
|