1776 lines
58 KiB
Zig
Executable File
1776 lines
58 KiB
Zig
Executable File
const std = @import("std");
|
|
const builtin = @import("builtin");
|
|
|
|
const c = @import("c");
|
|
const stdx = @import("stdx");
|
|
|
|
const log = std.log.scoped(.mlir);
|
|
|
|
test {
    // Make sure every top-level declaration of this file is semantically analyzed.
    std.testing.refAllDecls(@This());

    // Smoke-test context creation. The context is destroyed again so the test
    // does not leak the native MLIR context it just allocated.
    const ctx = try Context.init();
    defer c.mlirContextDestroy(ctx._inner);
}
|
|
|
|
/// Errors produced by the MLIR wrappers in this file.
const Error = error{
    /// The MLIR C API produced (or was handed) invalid IR.
    InvalidMlir,
    /// Any other MLIR failure; the log usually carries more context.
    MlirUnexpected,
    /// A requested resource or executor does not exist.
    NotFound,
    /// An allocation failed.
    OutOfMemory,
};
|
|
|
|
/// Borrows `str` as an `MlirStringRef` without copying it.
/// The caller must keep `str` alive for as long as the reference is used.
pub inline fn stringRef(str: []const u8) c.MlirStringRef {
    const ref: c.MlirStringRef = .{ .length = str.len, .data = str.ptr };
    return ref;
}
|
|
|
|
/// Views an `MlirStringRef` as a Zig slice, without copying.
/// An `MlirStringRef` need not be null terminated, so the slice is bounded
/// by the explicit `length` field.
pub inline fn fromStringRef(str: c.MlirStringRef) []const u8 {
    const byte_count = str.length;
    return str.data[0..byte_count];
}
|
|
|
|
/// Invokes the C entry point named `mlirRegister<passes>Passes`,
/// e.g. `registerPasses("Transforms")` calls `c.mlirRegisterTransformsPasses()`.
pub fn registerPasses(comptime passes: []const u8) void {
    const register_fn = @field(c, "mlirRegister" ++ passes ++ "Passes");
    register_fn();
}
|
|
|
|
/// Turns an `MlirLogicalResult` into an error union.
/// Returns normally on success (`value != 0`) and returns `err` on failure.
pub fn successOr(res: c.MlirLogicalResult, err: anytype) !void {
    if (res.value != 0) return;
    return err;
}
|
|
|
|
/// Alternative to MlirWrapperType.
/// C callback signature receiving a chunk of text plus the opaque user-data
/// pointer that was passed alongside the callback.
pub const MlirStrCallback = fn (chunk: c.MlirStringRef, user_data: ?*anyopaque) callconv(.c) void;
|
|
|
|
/// Owning handle to an `MlirDialectRegistry`.
pub const Registry = struct {
    _inner: c.MlirDialectRegistry,

    /// Releases the registry through `mlirDialectRegistryDestroy`.
    pub const deinit = helpers.deinit(Registry, c.mlirDialectRegistryDestroy);

    /// Creates an empty registry.
    /// Fails with `MlirUnexpected` when MLIR hands back a null registry.
    pub fn init() !Registry {
        const raw = c.mlirDialectRegistryCreate();
        if (helpers.init(Registry, raw, c.mlirDialectRegistryIsNull)) |registry| {
            return registry;
        }
        return Error.MlirUnexpected;
    }
};
|
|
|
|
/// Owning handle to an `MlirContext`.
pub const Context = struct {
    _inner: c.MlirContext,

    const Self = Context;

    pub const deinit = helpers.deinit(Context, c.mlirContextDestroy);
    pub const wrapOr = helpers.wrapOr(Context, c.mlirContextIsNull);

    /// Creates a new context; `MlirUnexpected` if MLIR returns a null context.
    pub fn init() !Self {
        const raw = c.mlirContextCreate();
        return Self.wrapOr(raw) orelse Error.MlirUnexpected;
    }

    /// Creates a context from `registry`, with multithreading set by
    /// `threadingEnabled`; `InvalidMlir` if MLIR returns a null context.
    pub fn initWithRegistry(registry: Registry, threadingEnabled: bool) !Self {
        const raw = c.mlirContextCreateWithRegistry(registry._inner, threadingEnabled);
        return Self.wrapOr(raw) orelse Error.InvalidMlir;
    }

    /// Forwards to `mlirContextEnableMultithreading`.
    pub fn setMultiThreading(self: *Self, enabled: bool) void {
        c.mlirContextEnableMultithreading(self._inner, enabled);
    }

    /// Forwards to `mlirContextAppendDialectRegistry`.
    pub fn appendDialectRegistry(self: *Self, registry: Registry) void {
        c.mlirContextAppendDialectRegistry(self._inner, registry._inner);
    }

    /// Forwards to `mlirContextLoadAllAvailableDialects`.
    pub fn loadAllAvailableDialects(self: *Self) void {
        c.mlirContextLoadAllAvailableDialects(self._inner);
    }

    /// Number of dialects registered with this context.
    pub fn numRegisteredDialects(self: Self) usize {
        const count = c.mlirContextGetNumRegisteredDialects(self._inner);
        return @intCast(count);
    }

    /// Number of dialects currently loaded in this context.
    pub fn numLoadedDialects(self: Self) usize {
        const count = c.mlirContextGetNumLoadedDialects(self._inner);
        return @intCast(count);
    }

    /// Whether `op` names an operation registered in this context.
    pub fn isRegisteredOperation(self: Self, op: [:0]const u8) bool {
        const op_name = stringRef(op);
        return c.mlirContextIsRegisteredOperation(self._inner, op_name);
    }

    /// Builds a `Location` for the Zig source position `src`.
    pub fn location(self: Self, src: std.builtin.SourceLocation) Location {
        return Location.fromSrc(self, src);
    }
};
|
|
|
|
/// Handle to an `MlirModule`.
pub const Module = struct {
    _inner: c.MlirModule,

    pub const deinit = helpers.deinit(Module, c.mlirModuleDestroy);
    pub const wrapOr = helpers.wrapOr(Module, c.mlirModuleIsNull);

    const Self = Module;

    /// Creates an empty module located at `loc`.
    pub fn init(loc: Location) Self {
        const raw = c.mlirModuleCreateEmpty(loc._inner);
        return .{ ._inner = raw };
    }

    /// Parses `source` as textual MLIR.
    /// Fails with `InvalidMlir` when MLIR returns a null module.
    pub fn parse(ctx: Context, source: [:0]const u8) !Module {
        const raw = c.mlirModuleCreateParse(ctx._inner, stringRef(source));
        if (Module.wrapOr(raw)) |module| return module;
        return Error.InvalidMlir;
    }

    /// Views `operation` as a module via `mlirModuleFromOperation`.
    pub fn fromOperation(operation: Operation) Module {
        const raw = c.mlirModuleFromOperation(operation._inner);
        return .{ ._inner = raw };
    }

    /// Context that owns this module.
    pub fn context(self: Module) Context {
        const raw = c.mlirModuleGetContext(self._inner);
        return .{ ._inner = raw };
    }

    /// The module's body block.
    pub fn getBody(self: Module) Block {
        const raw = c.mlirModuleGetBody(self._inner);
        return .{ ._inner = raw };
    }

    /// The module viewed as an operation.
    pub fn op(self: Module) Operation {
        const raw = c.mlirModuleGetOperation(self._inner);
        return .{ ._inner = raw };
    }

    /// Feeds the module's operation into `hasher` (see `Operation.hash`).
    pub fn hash(self: Module, hasher: *std.hash.XxHash64) void {
        const module_op = self.op();
        module_op.hash(hasher);
    }
};
|
|
|
|
/// Owning handle to an `MlirPassManager`.
pub const PassManager = struct {
    _inner: c.MlirPassManager,

    pub const deinit = helpers.deinit(PassManager, c.mlirPassManagerDestroy);
    pub const wrapOr = helpers.wrapOr(PassManager, c.mlirPassManagerIsNull);

    const Self = PassManager;

    /// Creates a pass manager; `MlirUnexpected` if MLIR returns null.
    pub fn init(ctx: Context) !Self {
        const raw = c.mlirPassManagerCreate(ctx._inner);
        if (Self.wrapOr(raw)) |pm| return pm;
        return Error.MlirUnexpected;
    }

    /// Creates a pass manager anchored on the operation named `op`;
    /// `MlirUnexpected` if MLIR returns null.
    pub fn initOnOperation(ctx: Context, op: [:0]const u8) !Self {
        const raw = c.mlirPassManagerCreateOnOperation(ctx._inner, stringRef(op));
        if (Self.wrapOr(raw)) |pm| return pm;
        return Error.MlirUnexpected;
    }

    /// Views this pass manager as an `OpPassManager`.
    pub fn asOpPassManager(self: Self) OpPassManager {
        const raw = c.mlirPassManagerGetAsOpPassManager(self._inner);
        return .{ ._inner = raw };
    }

    /// Forwards to `mlirPassManagerEnableIRPrinting`.
    pub fn enableIRPrinting(self: *Self) void {
        c.mlirPassManagerEnableIRPrinting(self._inner);
    }

    /// Runs the configured pipeline on `op`; `InvalidMlir` when it fails.
    pub fn runOnOp(self: *Self, op: Operation) error{InvalidMlir}!void {
        const res = c.mlirPassManagerRunOnOp(self._inner, op._inner);
        if (res.value != 0) return;
        return Error.InvalidMlir;
    }
};
|
|
|
|
/// Error callback handed to `mlirOpPassManagerAddPipeline`; prints the
/// diagnostic text to stderr.
///
/// `err` is an `MlirStringRef`, which is NOT guaranteed to be null terminated,
/// so it is sliced by its explicit length instead of being printed as a C string
/// (printing `err.data` directly can read past the end of the buffer).
fn _mlir_passpipeline_error(err: c.MlirStringRef, ctx: ?*anyopaque) callconv(.c) void {
    _ = ctx;
    std.debug.print(">>ERROR: {s}\n", .{fromStringRef(err)});
}
|
|
|
|
/// Handle to an `MlirOpPassManager`.
pub const OpPassManager = struct {
    _inner: c.MlirOpPassManager,

    /// Parses the textual `pipeline` and appends it to this pass manager.
    /// Parser diagnostics go through `_mlir_passpipeline_error`.
    /// NOTE(review): any failure is reported as `error.OutOfMemory`, which does
    /// not describe a parse error; kept as-is because callers depend on this error set.
    pub fn addPipeline(self: *OpPassManager, pipeline: [:0]const u8) error{OutOfMemory}!void {
        const res = c.mlirOpPassManagerAddPipeline(
            self._inner,
            stringRef(pipeline),
            &_mlir_passpipeline_error,
            null,
        );
        if (res.value != 0) return;
        return Error.OutOfMemory;
    }
};
|
|
|
|
/// Handle to an `MlirIdentifier`.
pub const Identifier = struct {
    _inner: c.MlirIdentifier,
    const Self = Identifier;

    /// Gets the identifier for `str_` in `ctx`.
    pub fn get(ctx: Context, str_: [:0]const u8) Self {
        const raw = c.mlirIdentifierGet(ctx._inner, stringRef(str_));
        return .{ ._inner = raw };
    }

    /// Context that owns this identifier.
    pub fn context(self: Self) Context {
        const raw = c.mlirIdentifierGetContext(self._inner);
        return .{ ._inner = raw };
    }

    /// The identifier text, as a borrowed slice.
    pub fn str(self: Self) []const u8 {
        const ref = c.mlirIdentifierStr(self._inner);
        return fromStringRef(ref);
    }

    /// Whether both handles denote the same identifier (`mlirIdentifierEqual`).
    pub fn equals(self: Self, other: Self) bool {
        return c.mlirIdentifierEqual(self._inner, other._inner);
    }
};
|
|
|
|
/// `(name, attribute)` pair consumed by `Attribute.dict` and `Operation.make`.
pub const AttrTuple = struct {
    [:0]const u8,
    Attribute,
};
|
|
|
|
/// Untyped handle to an `MlirAttribute`, plus helpers to build common attributes.
pub const Attribute = struct {
    _inner: c.MlirAttribute,

    pub const dump = helpers.dump(Attribute, c.mlirAttributeDump);
    pub const eql = helpers.eql(Attribute, c.mlirAttributeEqual);
    pub const format = helpers.format(Attribute, c.mlirAttributePrint);
    pub const wrapOr = helpers.wrapOr(Attribute, c.mlirAttributeIsNull);

    /// Wraps a raw C attribute handle.
    pub fn wrap(c_attr: c.MlirAttribute) Attribute {
        return .{ ._inner = c_attr };
    }

    /// Parses `attr` as textual MLIR; `InvalidMlir` when MLIR returns a null attribute.
    pub fn parse(ctx: Context, attr: [:0]const u8) !Attribute {
        return Attribute.wrapOr(
            c.mlirAttributeParseGet(ctx._inner, stringRef(attr)),
        ) orelse Error.InvalidMlir;
    }

    /// Returns a function upcasting a typed attribute wrapper (any struct whose
    /// `_inner` is a `c.MlirAttribute`) to a plain `Attribute`.
    pub fn fromAny(SpecificAttr: type) fn (x: SpecificAttr) Attribute {
        return struct {
            fn cast(x: SpecificAttr) Attribute {
                return .{ ._inner = x._inner };
            }
        }.cast;
    }

    /// Returns an equality function for a typed attribute wrapper whose `_inner`
    /// is a `c.MlirAttribute`. Used by the `eql` declarations of the specific
    /// attribute types below (e.g. `StringAttribute.eql`), which previously
    /// referenced this function without it being declared.
    pub fn eqlAny(SpecificAttr: type) fn (a: SpecificAttr, b: SpecificAttr) bool {
        return struct {
            fn eqlImpl(a: SpecificAttr, b: SpecificAttr) bool {
                return c.mlirAttributeEqual(a._inner, b._inner);
            }
        }.eqlImpl;
    }

    /// Whether this attribute is of kind `SpecificAttr` (via its `is_a_fn`).
    pub fn isA(self: Attribute, SpecificAttr: type) bool {
        return SpecificAttr.is_a_fn(self._inner);
    }

    // Utility functions to build common attributes.
    // All attributes are upcast to the Attribute type, which makes chained
    // construction easier at the cost of losing type information.

    /// The null attribute.
    pub fn null_() Attribute {
        return .wrap(c.mlirAttributeGetNull());
    }

    pub fn string(ctx: Context, str: []const u8) Attribute {
        return StringAttribute.init(ctx, str).asAttr();
    }

    pub fn type_(t: Type) Attribute {
        return TypeAttribute.init(t).asAttr();
    }

    pub fn unit(ctx: Context) Attribute {
        return .wrap(c.mlirUnitAttrGet(ctx._inner));
    }

    pub fn boolean(ctx: Context, value: bool) Attribute {
        return BoolAttribute.init(ctx, value).asAttr();
    }

    /// `value` as an `i1` integer attribute (0 or 1).
    pub fn i1FromBool(ctx: Context, value: bool) Attribute {
        return IntegerAttribute(.i1).init(ctx, @intFromBool(value)).asAttr();
    }

    pub fn int(ctx: Context, comptime int_type: IntegerTypes, value: i64) Attribute {
        return IntegerAttribute(int_type).init(ctx, value).asAttr();
    }

    pub fn float(ctx: Context, comptime float_type: FloatTypes, value: f64) Attribute {
        return FloatAttribute(float_type).init(ctx, value).asAttr();
    }

    pub fn array(ctx: Context, attrs: []const Attribute) Attribute {
        return ArrayAttribute.init(ctx, attrs).asAttr();
    }

    pub fn dense(ctx: Context, comptime dt: DenseArrayTypes, values: []const dt.ZigType()) Attribute {
        return DenseArrayAttribute(dt).init(ctx, values).asAttr();
    }

    /// Use a tensor as an attribute.
    /// The tensor is specified by dims, dtype and a flat slice of values.
    pub fn denseElements(ctx: Context, dims: []const i64, comptime dt: DenseElementsAttributeTypes, values: []const dt.ZigType()) Attribute {
        return DenseElementsAttribute(dt).init(.tensor(dims, dt.mlirType(ctx)), values).asAttr();
    }

    /// Like `denseElements`, but with the element type chosen at runtime and the
    /// tensor data given as raw bytes.
    pub fn denseElementsFromBytes(ctx: Context, dims: []const i64, dt: DenseElementsAttributeTypes, raw_bytes: []const u8) Attribute {
        const shape: Type = .tensor(dims, dt.mlirType(ctx));
        return .{ ._inner = c.mlirDenseElementsAttrRawBufferGet(
            shape._inner,
            @intCast(raw_bytes.len),
            raw_bytes.ptr,
        ) };
    }

    pub fn symbol(ctx: Context, flat_name: [:0]const u8) Attribute {
        return FlatSymbolRefAttribute.init(ctx, flat_name).asAttr();
    }

    /// Pairs this attribute with `name`.
    pub fn named(attr: Attribute, ctx: Context, name: [:0]const u8) NamedAttribute {
        return NamedAttribute.named(ctx, name, attr);
    }

    /// Builds a dictionary attribute from up to 32 `(name, attribute)` pairs.
    /// The pairs are staged in a fixed stack buffer; larger dictionaries must
    /// use `DictionaryAttribute.init` directly.
    pub fn dict(ctx: Context, named_attrs: []const AttrTuple) Attribute {
        var attr_buf: [32]NamedAttribute = undefined;
        stdx.debug.assert(named_attrs.len <= attr_buf.len, ".dict helper only support up to {} attribute, got {}. You will need to call mlir.DictionaryAttribute directly", .{ attr_buf.len, named_attrs.len });

        const attrs = attr_buf[0..named_attrs.len];
        for (attrs, named_attrs) |*attr, tuple| {
            attr.* = .named(ctx, tuple[0], tuple[1]);
        }

        return DictionaryAttribute.init(ctx, attrs).asAttr();
    }
};
|
|
|
|
/// `(identifier, attribute)` pair. This is an `extern struct` holding just the
/// C value, so slices of it are passed straight to the C API.
pub const NamedAttribute = extern struct {
    _inner: c.MlirNamedAttribute,

    /// Wraps a raw C named attribute.
    pub fn wrap(c_named_attribute: c.MlirNamedAttribute) NamedAttribute {
        return .{ ._inner = c_named_attribute };
    }

    /// Pairs `attr` with the identifier `name` created in `ctx`.
    pub fn named(ctx: Context, name: [:0]const u8, attr: Attribute) NamedAttribute {
        const ident = Identifier.get(ctx, name);
        return NamedAttribute.init(ident, attr);
    }

    /// Pairs `attr` with an existing identifier.
    pub fn init(name: Identifier, attr: Attribute) NamedAttribute {
        const raw: c.MlirNamedAttribute = .{
            .name = name._inner,
            .attribute = attr._inner,
        };
        return .{ ._inner = raw };
    }
};
|
|
|
|
/// Typed wrapper for MLIR string attributes.
pub const StringAttribute = struct {
    _inner: c.MlirAttribute,

    const Self = StringAttribute;

    pub const is_a_fn = c.mlirAttributeIsAString;
    pub const asAttr = Attribute.fromAny(Self);
    pub const eql = Attribute.eqlAny(Self);

    /// Builds a string attribute holding `str`.
    pub fn init(ctx: Context, str: []const u8) Self {
        const raw = c.mlirStringAttrGet(ctx._inner, stringRef(str));
        return .{ ._inner = raw };
    }

    /// The stored string, as a borrowed slice.
    pub fn value(self: Self) []const u8 {
        const ref = c.mlirStringAttrGetValue(self._inner);
        return fromStringRef(ref);
    }
};
|
|
|
|
/// Typed wrapper for MLIR boolean attributes.
pub const BoolAttribute = struct {
    _inner: c.MlirAttribute,

    const Self = BoolAttribute;

    pub const is_a_fn = c.mlirAttributeIsABool;
    pub const asAttr = Attribute.fromAny(Self);
    pub const eql = Attribute.eqlAny(Self);

    /// Builds a boolean attribute; the C API takes the flag as an integer 0/1.
    pub fn init(ctx: Context, value_: bool) Self {
        const raw = c.mlirBoolAttrGet(ctx._inner, @intFromBool(value_));
        return .{ ._inner = raw };
    }

    /// The stored boolean.
    pub fn value(self: Self) bool {
        return c.mlirBoolAttrGetValue(self._inner);
    }
};
|
|
|
|
/// Typed wrapper for MLIR type attributes (an attribute whose value is a `Type`).
pub const TypeAttribute = struct {
    _inner: c.MlirAttribute,

    pub const is_a_fn = c.mlirAttributeIsAType;
    pub const eql = Attribute.eqlAny(TypeAttribute);

    /// Wraps `type_` in a type attribute.
    pub fn init(type_: Type) TypeAttribute {
        return .{ ._inner = c.mlirTypeAttrGet(type_._inner) };
    }

    /// The `Type` stored in this attribute.
    ///
    /// Uses `mlirTypeAttrGetValue`: `mlirAttributeGetType` would return the
    /// attribute's *own* type (not a typed attribute, so NoneType), not the
    /// wrapped type value.
    pub fn typ(self: TypeAttribute) Type {
        return .{ ._inner = c.mlirTypeAttrGetValue(self._inner) };
    }

    pub const asAttr = Attribute.fromAny(TypeAttribute);
};
|
|
|
|
/// Typed wrapper for MLIR array attributes.
pub const ArrayAttribute = struct {
    _inner: c.MlirAttribute,

    const Self = ArrayAttribute;

    pub const is_a_fn = c.mlirAttributeIsAArray;
    pub const asAttr = Attribute.fromAny(Self);
    pub const eql = Attribute.eqlAny(Self);

    /// Builds an array attribute from `attrs` (each `Attribute` is a bare C handle).
    pub fn init(ctx: Context, attrs: []const Attribute) Self {
        const count: isize = @intCast(attrs.len);
        const raw = c.mlirArrayAttrGet(ctx._inner, count, @ptrCast(attrs.ptr));
        return .{ ._inner = raw };
    }

    /// Number of elements in the array.
    pub fn size(self: Self) usize {
        const count = c.mlirArrayAttrGetNumElements(self._inner);
        return @intCast(count);
    }

    /// Element at `index`.
    pub fn get(self: Self, index: usize) Attribute {
        const raw = c.mlirArrayAttrGetElement(self._inner, @intCast(index));
        return .{ ._inner = raw };
    }
};
|
|
|
|
/// Typed wrapper for MLIR integer attributes whose element type is `it`.
/// Values are read back through the signless, signed or unsigned accessor
/// that matches `it`.
pub fn IntegerAttribute(comptime it: IntegerTypes) type {
    const ValueType, const read_value = comptime switch (it) {
        .unknown => @compileError("IntegerAttribute(unknown)"),
        .u4, .u8, .u16, .u32, .u64 => .{ u64, c.mlirIntegerAttrGetValueUInt },
        .si4, .si8, .si16, .si32, .si64 => .{ i64, c.mlirIntegerAttrGetValueSInt },
        .i1, .i4, .i8, .i16, .i32, .i64 => .{ i64, c.mlirIntegerAttrGetValueInt },
    };

    return struct {
        _inner: c.MlirAttribute,

        pub const is_a_fn = c.mlirAttributeIsAInteger;

        pub const IntegerTypeType = IntegerType(it);
        const IntAttr = @This();

        pub const asAttr = Attribute.fromAny(IntAttr);
        pub const eql = Attribute.eqlAny(IntAttr);

        /// Builds an integer attribute of type `it` holding `value`.
        pub fn init(ctx: Context, value: i64) IntAttr {
            const int_type = IntegerTypeType.init(ctx);
            const raw = c.mlirIntegerAttrGet(int_type._inner, value);
            return .{ ._inner = raw };
        }

        /// The stored integer, converted to the Zig type chosen for `it`.
        pub fn get(value: IntAttr) ValueType {
            const stored = read_value(value._inner);
            return @intCast(stored);
        }
    };
}
|
|
|
|
/// Typed wrapper for MLIR float attributes whose element type is `ft`.
/// Values cross the C boundary as `f64`.
pub fn FloatAttribute(comptime ft: FloatTypes) type {
    return struct {
        _inner: c.MlirAttribute,

        pub const is_a_fn = c.mlirAttributeIsAFloat;
        const FloatAttr = @This();
        pub const asAttr = Attribute.fromAny(FloatAttr);

        /// Builds a float attribute of type `ft` holding `value`.
        pub fn init(ctx: Context, value: f64) FloatAttr {
            const float_type = FloatType(ft).init(ctx);
            const raw = c.mlirFloatAttrDoubleGet(ctx._inner, float_type._inner, value);
            return .{ ._inner = raw };
        }

        /// The stored value as an `f64`.
        pub fn get(value: FloatAttr) f64 {
            const stored = c.mlirFloatAttrGetValueDouble(value._inner);
            return stored;
        }
    };
}
|
|
|
|
/// Element kinds supported by `DenseArrayAttribute`.
pub const DenseArrayTypes = enum {
    bool,
    i8,
    i16,
    i32,
    i64,
    f32,
    f64,

    /// Zig element type used when exchanging values with the C API.
    /// Booleans are carried as `i32` (presumably matching the C `int`
    /// element type of `mlirDenseBoolArrayGet`).
    pub fn ZigType(comptime dt: DenseArrayTypes) type {
        return switch (dt) {
            .bool, .i32 => i32,
            .i8 => i8,
            .i16 => i16,
            .i64 => i64,
            .f32 => f32,
            .f64 => f64,
        };
    }
};
|
|
|
|
/// Typed wrapper for MLIR dense array attributes (`array<T: ...>`) with element kind `dt`.
pub fn DenseArrayAttribute(comptime dt: DenseArrayTypes) type {
    const _is_a_fn, const get_fn, const get_element_fn = switch (dt) {
        .bool => .{ c.mlirAttributeIsADenseBoolArray, c.mlirDenseBoolArrayGet, c.mlirDenseBoolArrayGetElement },
        .i8 => .{ c.mlirAttributeIsADenseI8Array, c.mlirDenseI8ArrayGet, c.mlirDenseI8ArrayGetElement },
        .i16 => .{ c.mlirAttributeIsADenseI16Array, c.mlirDenseI16ArrayGet, c.mlirDenseI16ArrayGetElement },
        .i32 => .{ c.mlirAttributeIsADenseI32Array, c.mlirDenseI32ArrayGet, c.mlirDenseI32ArrayGetElement },
        .i64 => .{ c.mlirAttributeIsADenseI64Array, c.mlirDenseI64ArrayGet, c.mlirDenseI64ArrayGetElement },
        .f32 => .{ c.mlirAttributeIsADenseF32Array, c.mlirDenseF32ArrayGet, c.mlirDenseF32ArrayGetElement },
        .f64 => .{ c.mlirAttributeIsADenseF64Array, c.mlirDenseF64ArrayGet, c.mlirDenseF64ArrayGetElement },
    };

    return struct {
        _inner: c.MlirAttribute,

        const Attr = @This();
        const ElementType = dt;
        const ElementTypeZig = dt.ZigType();

        pub const asAttr = Attribute.fromAny(Attr);
        pub const eql = Attribute.eqlAny(Attr);
        pub const is_a_fn = _is_a_fn;

        /// Builds the attribute from `values`.
        pub fn init(ctx: Context, values: []const ElementTypeZig) Attr {
            return .{ ._inner = get_fn(ctx._inner, @intCast(values.len), @ptrCast(values.ptr)) };
        }

        /// Element at `pos`.
        ///
        /// The bool element getter returns a C `bool` while `ElementTypeZig`
        /// is `i32` for `.bool`, so booleans are converted to 0/1. The check is
        /// on the getter's result type, so it is resolved at compile time.
        pub fn get(self: Attr, pos: usize) ElementTypeZig {
            const raw = get_element_fn(self._inner, @intCast(pos));
            return if (@TypeOf(raw) == bool) @intFromBool(raw) else raw;
        }

        /// Number of elements.
        pub fn len(self: Attr) usize {
            return @intCast(c.mlirDenseArrayGetNumElements(self._inner));
        }
    };
}
|
|
|
|
/// Element kinds supported by `DenseElementsAttribute`.
pub const DenseElementsAttributeTypes = enum {
    bool,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64,
    bf16,
    f16,
    f32,
    f64,
    index,

    /// Zig type used to lay out one element of the raw buffer.
    pub fn ZigType(comptime dt: DenseElementsAttributeTypes) type {
        return switch (dt) {
            .bool => bool,
            .i8 => i8,
            .i16 => i16,
            .i32 => i32,
            .i64 => i64,
            .u8 => u8,
            // bf16 has no native Zig type; its 16-bit pattern is carried as u16.
            .u16, .bf16 => u16,
            .u32 => u32,
            .u64 => u64,
            .f16 => f16,
            .f32 => f32,
            .f64 => f64,
            .index => usize,
        };
    }

    /// MLIR element type corresponding to `dt` (booleans map to `i1`).
    pub fn mlirType(dt: DenseElementsAttributeTypes, ctx: Context) Type {
        return switch (dt) {
            .bool => .int(ctx, .i1),
            .i8 => .int(ctx, .i8),
            .i16 => .int(ctx, .i16),
            .i32 => .int(ctx, .i32),
            .i64 => .int(ctx, .i64),
            .u8 => .int(ctx, .u8),
            .u16 => .int(ctx, .u16),
            .u32 => .int(ctx, .u32),
            .u64 => .int(ctx, .u64),
            .bf16 => .float(ctx, .bf16),
            .f16 => .float(ctx, .f16),
            .f32 => .float(ctx, .f32),
            .f64 => .float(ctx, .f64),
            .index => .index(ctx),
        };
    }
};
|
|
|
|
/// Typed wrapper for MLIR dense elements attributes with element kind `dt`.
pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
    return struct {
        _inner: c.MlirAttribute,

        const Attr = @This();

        pub const is_a_fn = c.mlirAttributeIsADenseElements;
        pub const asAttr = Attribute.fromAny(Attr);
        pub const eql = Attribute.eqlAny(Attr);

        /// Builds the attribute for `shaped_type` from the raw bytes of `slice`.
        /// Asserts that MLIR accepted the buffer (non-null attribute).
        pub fn init(shaped_type: Type, slice: []const dt.ZigType()) Attr {
            const raw_bytes = std.mem.sliceAsBytes(slice);
            const res: Attr = .{ ._inner = c.mlirDenseElementsAttrRawBufferGet(
                shaped_type._inner,
                @intCast(raw_bytes.len),
                @ptrCast(raw_bytes.ptr),
            ) };
            std.debug.assert(Attribute.wrapOr(res._inner) != null);
            return res;
        }

        /// Logical number of elements described by the attribute's shape.
        pub fn len(self: Attr) usize {
            return @intCast(c.mlirElementsAttrGetNumElements(self._inner));
        }

        /// Borrowed view of the stored elements.
        ///
        /// For a splat attribute MLIR stores a single element, so only that one
        /// value is returned: slicing `len()` elements would read past the end
        /// of the underlying storage. Use `len()` for the logical element count.
        pub fn items(self: Attr) []const dt.ZigType() {
            const raw_bytes: [*]const u8 = c.mlirDenseElementsAttrGetRawData(self._inner) orelse unreachable;
            const ptr: [*]const dt.ZigType() = @ptrCast(@alignCast(raw_bytes));
            // Note the mlir API returns us the number of elements, not the number of bytes,
            // that's why we track the element type at comptime to allow items to work.
            const stored_count: usize = if (c.mlirDenseElementsAttrIsSplat(self._inner)) 1 else self.len();
            return ptr[0..stored_count];
        }

        /// Raw bytes of the stored elements (same coverage as `items`).
        pub fn bytes(self: Attr) []const u8 {
            return std.mem.sliceAsBytes(self.items());
        }
    };
}
|
|
|
|
/// Typed wrapper for MLIR flat symbol reference attributes (`@name`).
pub const FlatSymbolRefAttribute = struct {
    _inner: c.MlirAttribute,

    const Self = FlatSymbolRefAttribute;

    pub const is_a_fn = c.mlirAttributeIsAFlatSymbolRef;
    pub const eql = Attribute.eqlAny(Self);
    pub const asAttr = Attribute.fromAny(Self);

    /// Builds a reference to the symbol named `str`.
    pub fn init(ctx: Context, str: [:0]const u8) Self {
        const raw = c.mlirFlatSymbolRefAttrGet(ctx._inner, stringRef(str));
        return .{ ._inner = raw };
    }

    /// The referenced symbol name, as a borrowed slice.
    pub fn value(self: Self) []const u8 {
        const ref = c.mlirFlatSymbolRefAttrGetValue(self._inner);
        return fromStringRef(ref);
    }
};
|
|
|
|
/// Builder state used to create an operation with `Operation.init`.
pub const OperationState = struct {
    _inner: c.MlirOperationState,

    const Self = OperationState;

    /// Starts building an operation named `name` at `loc`.
    pub fn init(name: [:0]const u8, loc: Location) Self {
        return .{ ._inner = c.mlirOperationStateGet(stringRef(name), loc._inner) };
    }

    pub fn addResult(self: *Self, type_: Type) void {
        c.mlirOperationStateAddResults(&self._inner, 1, &[_]c.MlirType{type_._inner});
    }

    pub fn addResults(self: *Self, types: []const Type) void {
        c.mlirOperationStateAddResults(&self._inner, @intCast(types.len), @ptrCast(types.ptr));
    }

    pub fn addOperand(self: *Self, value: Value) void {
        c.mlirOperationStateAddOperands(&self._inner, 1, &[_]c.MlirValue{value._inner});
    }

    pub fn addOperands(self: *Self, values: []const Value) void {
        c.mlirOperationStateAddOperands(&self._inner, @intCast(values.len), @ptrCast(values.ptr));
    }

    /// Hands `region` to the state via `mlirOperationStateAddOwnedRegions`.
    pub fn addRegion(self: *Self, region: *Region) void {
        c.mlirOperationStateAddOwnedRegions(&self._inner, 1, &[_]c.MlirRegion{region._inner});
    }

    pub fn addRegions(self: *Self, regions: []const Region) void {
        c.mlirOperationStateAddOwnedRegions(&self._inner, @intCast(regions.len), @ptrCast(regions.ptr));
    }

    /// Adds attribute `attr` under `name` (identifier created in `ctx`).
    pub fn addAttribute(self: *Self, ctx: Context, name: [:0]const u8, attr: Attribute) void {
        self.addAttributeRaw(Identifier.get(ctx, name), attr);
    }

    /// Adds attribute `attr` under an existing identifier.
    ///
    /// The pair is built as a typed `c.MlirNamedAttribute` array instead of
    /// `@ptrCast`-ing an anonymous struct literal: anonymous structs have an
    /// unspecified field layout, so the C side could read the fields swapped.
    pub fn addAttributeRaw(self: *Self, name: Identifier, attr: Attribute) void {
        const named_attrs = [_]c.MlirNamedAttribute{.{
            .name = name._inner,
            .attribute = attr._inner,
        }};
        c.mlirOperationStateAddAttributes(&self._inner, named_attrs.len, &named_attrs);
    }

    /// Adds several attributes; `NamedAttribute` is an extern wrapper around
    /// `c.MlirNamedAttribute`, so the slice is passed through directly.
    pub fn addAttributes(self: *Self, attributes: []const NamedAttribute) void {
        c.mlirOperationStateAddAttributes(&self._inner, @intCast(attributes.len), @ptrCast(attributes.ptr));
    }

    /// Sets `enableResultTypeInference` on the underlying C state.
    pub fn resultTypeInference(self: *Self, enabled: bool) void {
        self._inner.enableResultTypeInference = enabled;
    }
};
|
|
|
|
/// Typed wrapper for MLIR dictionary attributes.
pub const DictionaryAttribute = struct {
    _inner: c.MlirAttribute,

    pub const is_a_fn = c.mlirAttributeIsADictionary;
    pub const asAttr = Attribute.fromAny(DictionaryAttribute);
    pub const eql = Attribute.eqlAny(DictionaryAttribute);

    /// Builds a dictionary from `attributes`.
    pub fn init(ctx: Context, attributes: []const NamedAttribute) DictionaryAttribute {
        return .{ ._inner = c.mlirDictionaryAttrGet(
            ctx._inner,
            @intCast(attributes.len),
            @ptrCast(attributes.ptr),
        ) };
    }

    /// Number of entries.
    pub fn size(self: DictionaryAttribute) usize {
        return @intCast(c.mlirDictionaryAttrGetNumElements(self._inner));
    }

    /// Entry at position `pos`.
    pub fn get(self: DictionaryAttribute, pos: usize) NamedAttribute {
        return .wrap(c.mlirDictionaryAttrGetElement(self._inner, @intCast(pos)));
    }

    /// Value stored under `name`, or null when absent.
    /// The name must be passed as an `MlirStringRef`, hence `stringRef`.
    pub fn getByName(self: DictionaryAttribute, name: [:0]const u8) ?Attribute {
        return Attribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self._inner, stringRef(name)));
    }
};
|
|
|
|
pub const Operation = struct {
|
|
const Self = Operation;
|
|
_inner: c.MlirOperation,
|
|
|
|
pub const dump = helpers.dump(Operation, c.mlirOperationDestroy);
|
|
pub const deinit = helpers.deinit(Operation, c.mlirOperationDestroy);
|
|
pub const wrapOr = helpers.wrapOr(Operation, c.mlirOperationIsNull);
|
|
|
|
pub const eql = Attribute.eqlAny(Self);
|
|
|
|
pub fn init(state: *OperationState) !Self {
|
|
return Self.wrapOr(c.mlirOperationCreate(&state._inner)) orelse Error.InvalidMlir;
|
|
}
|
|
|
|
pub fn make(ctx: Context, op_name: [:0]const u8, args: struct {
|
|
operands: ?[]const Value = null,
|
|
variadic_operands: ?[]const []const Value = null,
|
|
tt_variadic_operands: ?[]const []const Value = null,
|
|
results: ?[]const Type = null,
|
|
variadic_results: ?[]const []const Type = null,
|
|
result_type_inference: ?bool = null,
|
|
n_regions: usize = 0,
|
|
attributes: ?[]const AttrTuple = null,
|
|
blocks: ?[]const Block = null,
|
|
verify: bool = true,
|
|
location: Location,
|
|
}) Self {
|
|
var state = OperationState.init(op_name, args.location);
|
|
std.debug.assert(!(args.operands != null and args.variadic_operands != null));
|
|
if (args.operands) |operands| {
|
|
state.addOperands(operands);
|
|
} else if (args.variadic_operands) |operands_segments| {
|
|
const MAX_SEGMENTS = 32;
|
|
var segments: stdx.BoundedArray(i32, MAX_SEGMENTS) = .{};
|
|
|
|
for (operands_segments) |operands| {
|
|
state.addOperands(operands);
|
|
segments.appendAssumeCapacity(@intCast(operands.len));
|
|
}
|
|
state.addAttribute(ctx, "operandSegmentSizes", .denseElements(ctx, &.{@intCast(segments.len)}, .i32, segments.constSlice()));
|
|
} else if (args.tt_variadic_operands) |operands_segments| {
|
|
// stablehlo and triton seems to disagree on the expected type of operandSegmentSizes, let's fix that.
|
|
const MAX_SEGMENTS = 32;
|
|
var segments: stdx.BoundedArray(i32, MAX_SEGMENTS) = .{};
|
|
|
|
for (operands_segments) |operands| {
|
|
state.addOperands(operands);
|
|
segments.appendAssumeCapacity(@intCast(operands.len));
|
|
}
|
|
state.addAttribute(ctx, "operandSegmentSizes", .dense(ctx, .i32, segments.constSlice()));
|
|
}
|
|
if (args.result_type_inference) |enable| {
|
|
state.resultTypeInference(enable);
|
|
}
|
|
std.debug.assert(!(args.results != null and args.variadic_results != null));
|
|
if (args.results) |results| {
|
|
state.addResults(results);
|
|
} else if (args.variadic_results) |result_segments| {
|
|
for (result_segments) |results| {
|
|
state.addResults(results);
|
|
}
|
|
}
|
|
for (0..args.n_regions) |_| {
|
|
var region_ = Region.init() catch {
|
|
@panic("Failed to create MLIR region");
|
|
};
|
|
state.addRegion(®ion_);
|
|
}
|
|
if (args.attributes) |attrs| {
|
|
for (attrs) |attr| {
|
|
state.addAttributeRaw(
|
|
Identifier.get(ctx, attr[0]),
|
|
attr[1],
|
|
);
|
|
}
|
|
}
|
|
if (args.blocks) |blocks_| {
|
|
for (blocks_) |block_| {
|
|
var region_ = Region.init() catch {
|
|
@panic("Failed to create MLIR region");
|
|
};
|
|
region_.appendBlock(block_);
|
|
state.addRegion(®ion_);
|
|
}
|
|
}
|
|
|
|
const new_op = Operation.init(&state) catch {
|
|
@panic("Failed to create MLIR operation");
|
|
};
|
|
if (args.verify and new_op.verify() == false) {
|
|
log.err("Failed to verify MLIR operation:\n{f}", .{new_op.mlirFormatter(.{ .debug_info = true })});
|
|
@panic("Failed to verify MLIR operation");
|
|
}
|
|
return new_op;
|
|
}
|
|
|
|
pub fn initParse(ctx: Context, str: [:0]const u8) !Self {
|
|
return Self.wrapOr(
|
|
c.mlirOperationCreateParse(ctx._inner, stringRef(str), stringRef("pouet")),
|
|
) orelse Error.InvalidMlir;
|
|
}
|
|
|
|
pub fn clone(self: Self) !Self {
|
|
return Self.wrapOr(
|
|
c.mlirOperationClone(self._inner),
|
|
) orelse Error.InvalidMlir;
|
|
}
|
|
|
|
pub fn name(self: Self) Identifier {
|
|
return .{ ._inner = c.mlirOperationGetName(self._inner) };
|
|
}
|
|
|
|
pub fn removeFromParent(self: *Self) void {
|
|
c.mlirOperationRemoveFromParent(self._inner);
|
|
}
|
|
|
|
pub fn numOperands(self: Self) usize {
|
|
return @intCast(c.mlirOperationGetNumOperands(self._inner));
|
|
}
|
|
|
|
pub fn operand(self: Self, index: usize) Value {
|
|
return .{ ._inner = c.mlirOperationGetOperand(self._inner, @intCast(index)) };
|
|
}
|
|
|
|
pub fn setOperand(self: *Self, index: usize, value: Value) void {
|
|
c.mlirOperationSetOperand(self._inner, @intCast(index), value._inner);
|
|
}
|
|
|
|
pub fn numResults(self: Self) usize {
|
|
return @intCast(c.mlirOperationGetNumResults(self._inner));
|
|
}
|
|
|
|
pub fn result(self: Self, index: usize) Value {
|
|
return .{ ._inner = c.mlirOperationGetResult(self._inner, @intCast(index)) };
|
|
}
|
|
|
|
pub fn nextInBlock(self: Self) Self {
|
|
return .{ ._inner = c.mlirOperationGetNextInBlock(self._inner) };
|
|
}
|
|
|
|
// pub fn previousInBlock(self: Self) Self {
|
|
// return .{ ._inner = c.mlirOperationGetPrevInBlock(self._inner) };
|
|
// }
|
|
|
|
pub fn block(self: Self) ?Block {
|
|
return Block.wrapOr(c.mlirOperationGetBlock(self._inner));
|
|
}
|
|
|
|
pub fn parent(self: Self) ?Self {
|
|
return Self.wrapOr(c.mlirOperationGetParentOperation(self._inner));
|
|
}
|
|
|
|
pub fn region(self: Self, index: usize) Region {
|
|
return .{ ._inner = c.mlirOperationGetRegion(self._inner, @intCast(index)) };
|
|
}
|
|
|
|
pub fn context(self: Self) Context {
|
|
return .{ ._inner = c.mlirOperationGetContext(self._inner) };
|
|
}
|
|
|
|
pub fn writeBytecode(self: Self, writer: anytype) void {
|
|
var writer_context = .{ .writer = writer };
|
|
const WriterContext = @TypeOf(writer_context);
|
|
|
|
c.mlirOperationWriteBytecode(
|
|
self._inner,
|
|
(struct {
|
|
pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void {
|
|
const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_));
|
|
_ = inner_writer_context.writer.write(str.data[0..str.length]) catch unreachable;
|
|
}
|
|
}).callback,
|
|
&writer_context,
|
|
);
|
|
}
|
|
|
|
pub fn writeBytecodeWithConfig(self: Self, writer: anytype, config: struct {
|
|
desiredEmitedVersion: ?i64 = null,
|
|
}) !void {
|
|
const cfg = c.mlirBytecodeWriterConfigCreate();
|
|
defer c.mlirBytecodeWriterConfigDestroy(cfg);
|
|
if (config.desiredEmitedVersion) |v| {
|
|
c.mlirBytecodeWriterConfigDesiredEmitVersion(cfg, v);
|
|
}
|
|
|
|
const WriterContext = struct {
|
|
writer: @TypeOf(writer),
|
|
write_error: ?@TypeOf(writer).Error = null,
|
|
};
|
|
var writer_context: WriterContext = .{ .writer = writer };
|
|
|
|
try successOr(c.mlirOperationWriteBytecodeWithConfig(
|
|
self._inner,
|
|
cfg,
|
|
(struct {
|
|
pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void {
|
|
const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_));
|
|
_ = inner_writer_context.writer.write(str.data[0..str.length]) catch |err| {
|
|
inner_writer_context.write_error = err;
|
|
};
|
|
}
|
|
}).callback,
|
|
&writer_context,
|
|
), error.InvalidMlirBytecodeVersion);
|
|
|
|
if (writer_context.write_error) |err| return err;
|
|
}
|
|
|
|
/// Enable a full dump of the IR.
|
|
/// Usage `std.log.debug("{}", .{ module.op().mlirFormatter(.{}) });
|
|
pub fn mlirFormatter(self: Operation, flags: OpPrintingFlags) MlirFormatter {
|
|
return .{ .op = self, .flags = flags };
|
|
}
|
|
|
|
const MlirFormatter = struct {
|
|
op: Operation,
|
|
flags: OpPrintingFlags,
|
|
|
|
pub fn format(self: @This(), writer: anytype) !void {
|
|
self.op.print(writer, self.flags);
|
|
}
|
|
};
|
|
|
|
pub fn print(self: Self, writer: *std.Io.Writer, flags: OpPrintingFlags) void {
|
|
const pflags = flags.create();
|
|
defer c.mlirOpPrintingFlagsDestroy(pflags);
|
|
|
|
c.mlirOperationPrintWithFlags(
|
|
self._inner,
|
|
pflags,
|
|
(struct {
|
|
pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void {
|
|
const _writer: *std.Io.Writer = @ptrCast(@alignCast(ctx_));
|
|
_writer.writeAll(str.data[0..str.length]) catch @panic("Mlir print failed");
|
|
}
|
|
}).callback,
|
|
writer,
|
|
);
|
|
}
|
|
|
|
pub fn verify(self: Self) bool {
|
|
return c.mlirOperationVerify(self._inner);
|
|
}
|
|
|
|
pub fn getLocation(self: Self) Location {
|
|
return .{ ._inner = c.mlirOperationGetLocation(self._inner) };
|
|
}
|
|
|
|
pub const WalkOrder = enum(c.MlirWalkOrder) {
|
|
pre_order = c.MlirWalkPreOrder,
|
|
post_order = c.MlirWalkPostOrder,
|
|
};
|
|
|
|
pub const WalkResult = enum(c.MlirWalkResult) {
|
|
advance = c.MlirWalkResultAdvance,
|
|
interrupt = c.MlirWalkResultInterrupt,
|
|
skip = c.MlirWalkResultSkip,
|
|
};
|
|
|
|
/// Traverses this operation and its nested operations in `order`, calling
/// `walkfn(ctx, op)` for each one; the returned `WalkResult` steers the traversal.
pub fn walk(self: Self, order: WalkOrder, ctx: anytype, walkfn: fn (ctx: anytype, op: Operation) WalkResult) void {
    // Box `ctx` so the C callback can get it back through an opaque pointer.
    var boxed = .{ .ctx = ctx };
    const Boxed = @TypeOf(boxed);

    const Trampoline = struct {
        fn visit(raw_op: c.MlirOperation, user_data: ?*anyopaque) callconv(.c) c.MlirWalkResult {
            const box: *Boxed = @ptrCast(@alignCast(user_data));
            const result = walkfn(box.ctx, .{ ._inner = raw_op });
            return @intFromEnum(result);
        }
    };

    c.mlirOperationWalk(self._inner, Trampoline.visit, &boxed, @intFromEnum(order));
}
|
|
|
|
/// Returns the attribute at index `pos` (no bounds check is done on this side).
pub fn getAttribute(self: Self, pos: usize) NamedAttribute {
    const raw_attr = c.mlirOperationGetAttribute(self._inner, @intCast(pos));
    return NamedAttribute{ ._inner = raw_attr };
}
|
|
|
|
/// Looks up the attribute called `name_`; null when MLIR returns a null attribute.
pub fn getAttributeByName(self: Self, name_: [:0]const u8) ?Attribute {
    const raw_attr = c.mlirOperationGetAttributeByName(self._inner, stringRef(name_));
    return Attribute.wrapOr(raw_attr);
}
|
|
|
|
/// Sets the attribute called `name_` to `attr` on this operation.
pub fn setAttributeByName(self: Self, name_: [:0]const u8, attr: Attribute) void {
    const key = stringRef(name_);
    c.mlirOperationSetAttributeByName(self._inner, key, attr._inner);
}
|
|
|
|
/// Removes the attribute called `name_`; returns the boolean reported by MLIR
/// (presumably whether an attribute was actually removed).
pub fn removeAttributeByName(self: Self, name_: [:0]const u8) bool {
    const key = stringRef(name_);
    return c.mlirOperationRemoveAttributeByName(self._inner, key);
}
|
|
|
|
/// Feeds the textual IR of `op` into `hasher`, with debug info disabled so that
/// information which can change across builds does not affect the digest.
pub fn hash(op: Operation, hasher: *std.hash.XxHash64) void {
    // Note: this used to hash `op.writeBytecode(writer)`, but bytecode emission
    // crashed on some inputs (notably unused variables), so the text form is hashed.
    // See https://github.com/zml/zml/issues/97.
    const Feed = struct {
        fn update(chunk: c.MlirStringRef, user_data: ?*anyopaque) callconv(.c) void {
            const digest: *std.hash.XxHash64 = @ptrCast(@alignCast(user_data));
            digest.update(chunk.data[0..chunk.length]);
        }
    };

    const no_debug_info = OpPrintingFlags.create(.{ .debug_info = false });
    defer c.mlirOpPrintingFlagsDestroy(no_debug_info);

    c.mlirOperationPrintWithFlags(op._inner, no_debug_info, Feed.update, hasher);
}
|
|
};
|
|
|
|
/// Zig-side description of MLIR's operation printing options.
///
/// `create` builds the C `MlirOpPrintingFlags`; callers in this file release it
/// with `c.mlirOpPrintingFlagsDestroy`.
pub const OpPrintingFlags = struct {
    /// When set, forwarded to `mlirOpPrintingFlagsElideLargeElementsAttrs` as the size limit.
    elide_large_elements_attrs: ?usize = null,
    debug_info: bool = false,
    debug_info_pretty_form: bool = true,
    print_generic_op_form: bool = false,
    use_local_scope: bool = false,
    assume_verified: bool = false,

    /// Allocates a C printing-flags object configured from `self`.
    pub fn create(self: OpPrintingFlags) c.MlirOpPrintingFlags {
        const out = c.mlirOpPrintingFlagsCreate();

        if (self.elide_large_elements_attrs) |limit| c.mlirOpPrintingFlagsElideLargeElementsAttrs(out, @intCast(limit));
        c.mlirOpPrintingFlagsEnableDebugInfo(out, self.debug_info, self.debug_info_pretty_form);
        if (self.print_generic_op_form) c.mlirOpPrintingFlagsPrintGenericOpForm(out);
        if (self.use_local_scope) c.mlirOpPrintingFlagsUseLocalScope(out);
        if (self.assume_verified) c.mlirOpPrintingFlagsAssumeVerified(out);

        return out;
    }
};
|
|
|
|
/// A single use of a value: the operand slot of some operation.
pub const OpOperand = struct {
    _inner: c.MlirOpOperand,

    const Self = OpOperand;

    /// Wraps a raw operand, returning null for MLIR's null operand.
    /// `nextUse` relies on this; it was previously undeclared, so `nextUse` did not compile.
    pub const wrapOr = helpers.wrapOr(OpOperand, c.mlirOpOperandIsNull);

    /// Returns the operation that owns this operand.
    pub fn owner(self: Self) Operation {
        return .{ ._inner = c.mlirOpOperandGetOwner(self._inner) };
    }

    /// Returns the index of this operand in its owner's operand list.
    pub fn number(self: Self) usize {
        return @intCast(c.mlirOpOperandGetOperandNumber(self._inner));
    }

    /// Returns the next use of the same value, or null at the end of the use list.
    pub fn nextUse(self: Self) ?Self {
        return Self.wrapOr(
            c.mlirOpOperandGetNextUse(self._inner),
        );
    }
};
|
|
|
|
/// Wrapper around an MLIR region, an ordered list of blocks.
pub const Region = struct {
    _inner: c.MlirRegion,

    // Region handles must be compared with the region equality function;
    // `mlirBlockEqual` takes `MlirBlock` arguments and does not fit here.
    pub const eql = helpers.eql(Region, c.mlirRegionEqual);
    pub const deinit = helpers.deinit(Region, c.mlirRegionDestroy);
    pub const wrapOr = helpers.wrapOr(Region, c.mlirRegionIsNull);

    const Self = Region;

    /// Creates a new region; returns `error.InvalidMlir` if MLIR returns a null region.
    pub fn init() !Self {
        return Self.wrapOr(c.mlirRegionCreate()) orelse Error.InvalidMlir;
    }

    /// Appends `block` at the end of the region (the region takes ownership of it).
    pub fn appendBlock(self: *Self, block: Block) void {
        c.mlirRegionAppendOwnedBlock(self._inner, block._inner);
    }

    /// Inserts `block` at position `index` (the region takes ownership of it).
    pub fn insertBlock(self: *Self, index: isize, block: Block) void {
        c.mlirRegionInsertOwnedBlock(self._inner, index, block._inner);
    }

    /// Inserts `block` right before `reference` (the region takes ownership of it).
    pub fn insertBlockBefore(self: *Self, reference: Block, block: Block) void {
        c.mlirRegionInsertOwnedBlockBefore(self._inner, reference._inner, block._inner);
    }

    /// Inserts `block` right after `reference` (the region takes ownership of it).
    pub fn insertBlockAfter(self: *Self, reference: Block, block: Block) void {
        c.mlirRegionInsertOwnedBlockAfter(self._inner, reference._inner, block._inner);
    }

    /// Returns the first block of the region.
    /// NOTE(review): presumably a null block when the region is empty — confirm before relying on it.
    pub fn firstBlock(self: Self) Block {
        return .{ ._inner = c.mlirRegionGetFirstBlock(self._inner) };
    }
};
|
|
|
|
/// Handle on an MLIR SSA value: either an operation result or a block argument.
pub const Value = struct {
    _inner: c.MlirValue,

    pub const dump = helpers.dump(Value, c.mlirValueDump);
    pub const eql = helpers.eql(Value, c.mlirValueEqual);
    pub const format = helpers.format(Value, c.mlirValuePrint).format;
    pub const wrapOr = helpers.wrapOr(Value, c.mlirValueIsNull);

    /// Returns the type of this value.
    pub fn getType(val: Value) Type {
        const raw_type = c.mlirValueGetType(val._inner);
        return Type{ ._inner = raw_type };
    }

    /// Changes the type of this value in the IR (the wrapper's handle is unchanged).
    pub fn setType(val: *Value, typ: Type) void {
        c.mlirValueSetType(val._inner, typ._inner);
    }

    /// Returns the first use of this value.
    /// NOTE(review): likely a null operand when the value is unused — TODO confirm.
    pub fn firstUse(val: Value) OpOperand {
        const raw_use = c.mlirValueGetFirstUse(val._inner);
        return OpOperand{ ._inner = raw_use };
    }

    /// Makes every user of `val` use `with` instead.
    pub fn replaceAllUsesWith(val: Value, with: Value) void {
        c.mlirValueReplaceAllUsesOfWith(val._inner, with._inner);
    }

    /// Returns the operation producing this value; only meaningful when `isAOpResult()` holds.
    pub fn owner(val: Value) Operation {
        const producer = c.mlirOpResultGetOwner(val._inner);
        return Operation{ ._inner = producer };
    }

    pub fn isABlockArgument(val: Value) bool {
        return c.mlirValueIsABlockArgument(val._inner);
    }

    pub fn isAOpResult(val: Value) bool {
        return c.mlirValueIsAOpResult(val._inner);
    }

    /// Where a value comes from.
    pub const Kind = union(enum) {
        block_argument: BlockArgument,
        op_result: Operation,
        null,
    };

    /// Classifies the value, checking "operation result" before "block argument".
    pub fn kind(val: Value) Kind {
        // Per the MLIR docs (https://mlir.llvm.org/doxygen/classmlir_1_1Value.html#details):
        // > An SSA value is either a BlockArgument or the result of an operation.
        // so `.null` should presumably only be reached for a null handle.
        return if (val.isAOpResult())
            .{ .op_result = val.owner() }
        else if (val.isABlockArgument())
            .{ .block_argument = .{ ._inner = val._inner } }
        else
            .null;
    }
};
|
|
|
|
/// A value known to be an argument of a block.
pub const BlockArgument = struct {
    _inner: c.MlirValue,

    /// Returns the block declaring this argument.
    pub fn block(arg: BlockArgument) Block {
        const owner_block = c.mlirBlockArgumentGetOwner(arg._inner);
        return Block{ ._inner = owner_block };
    }

    /// Returns the position of this argument in its block's argument list.
    pub fn number(arg: BlockArgument) usize {
        const position = c.mlirBlockArgumentGetArgNumber(arg._inner);
        return @bitCast(position);
    }

    /// Prints the argument exactly like the equivalent `Value`.
    pub fn format(self: BlockArgument, writer: anytype) !void {
        const as_value: Value = .{ ._inner = self._inner };
        return as_value.format(writer);
    }
};
|
|
|
|
/// Generic handle on an MLIR type.
///
/// Specialised wrappers (`IndexType`, `IntegerType(...)`, `RankedTensorType`, ...) convert
/// to it with their `asType` and are recovered from it with `Type.as`.
pub const Type = struct {
    _inner: c.MlirType,

    pub const dump = helpers.dump(Type, c.mlirTypeDump);
    pub const eql = helpers.eql(Type, c.mlirTypeEqual);
    // `helpers.format` returns a namespace; its `format` member is the actual method.
    pub const format = helpers.format(Type, c.mlirTypePrint).format;
    pub const wrapOr = helpers.wrapOr(Type, c.mlirTypeIsNull);

    /// Parses `str` as a type in `ctx`; returns `error.InvalidMlir` when MLIR cannot parse it.
    pub fn parse(ctx: Context, str: [:0]const u8) !Type {
        return Type.wrapOr(
            c.mlirTypeParseGet(ctx._inner, stringRef(str)),
        ) orelse Error.InvalidMlir;
    }

    /// Downcasts to `SpecificType` if its `is_a_fn` accepts this type, otherwise returns null.
    pub fn as(generic: Type, SpecificType: type) ?SpecificType {
        if (!@hasDecl(SpecificType, "is_a_fn")) {
            @compileError("Mlir subclass of type need `is_a_fn` attribute: " ++ @typeName(SpecificType));
        }
        return if (SpecificType.is_a_fn(generic._inner)) .{ ._inner = generic._inner } else null;
    }

    /// Returns a function converting a `SpecificType` wrapper into a generic `Type`.
    pub fn fromAny(SpecificType: type) fn (x: SpecificType) Type {
        stdx.debug.assertComptime(@hasDecl(SpecificType, "asType"), "Type.fromAny expects a type subclass, got: {}. Missing `asType` declaration.", .{SpecificType});
        return struct {
            fn cast(x: SpecificType) Type {
                return .{ ._inner = x._inner };
            }
        }.cast;
    }

    /// Returns an equality function for `SpecificType` that compares the generic types.
    pub fn eqlAny(SpecificType: type) fn (SpecificType, SpecificType) bool {
        return struct {
            fn eql(a: SpecificType, b: SpecificType) bool {
                return a.asType().eql(b.asType());
            }
        }.eql;
    }

    /// Returns a namespace whose `format` prints a `SpecificType` through `Type.format`.
    /// Use it as `Type.formatAny(T).format`.
    pub fn formatAny(SpecificType: type) type {
        return struct {
            pub fn format(self: SpecificType, writer: *std.Io.Writer) !void {
                return Type.format(self.asType(), writer);
            }
        };
    }

    /// The builtin `index` type.
    pub fn index(ctx: Context) Type {
        return IndexType.init(ctx).asType();
    }

    /// Builds the integer type `int_type`; panics for `.unknown`.
    pub fn int(ctx: Context, int_type: IntegerTypes) Type {
        return switch (int_type) {
            .unknown => @panic("Unknown integer type"),
            inline else => |t| IntegerType(t).init(ctx).asType(),
        };
    }

    /// Builds the floating-point type `float_type`.
    pub fn float(ctx: Context, float_type: FloatTypes) Type {
        return switch (float_type) {
            inline else => |t| FloatType(t).init(ctx).asType(),
        };
    }

    /// Builds the complex type `complex_type`; panics for `.unknown`.
    pub fn complex(ctx: Context, complex_type: ComplexTypes) Type {
        return switch (complex_type) {
            .unknown => @panic("Unknown complex type can't be created like this"), // What's the point ?
            inline else => |t| ComplexType(t).init(ctx).asType(),
        };
    }

    /// Builds `tuple<types...>`. Panics (instead of hitting `unreachable`, which is UB
    /// in release builds) if MLIR returns a null type.
    pub fn tuple(ctx: Context, types: []const Type) Type {
        const tuple_type = TupleType.init(ctx, types) catch @panic("Type.tuple: MLIR returned a null tuple type");
        return tuple_type.asType();
    }

    /// Builds the function type `(args) -> (results)`. Panics if MLIR returns a null type.
    pub fn function(ctx: Context, args: []const Type, results: []const Type) Type {
        const func_type = FunctionType.init(ctx, args, results) catch @panic("Type.function: MLIR returned a null function type");
        return func_type.asType();
    }

    /// Builds a ranked tensor type with the given shape and element type.
    pub fn tensor(dimensions: []const i64, elem_type: Type) Type {
        return RankedTensorType.init(dimensions, elem_type).asType();
    }
};
|
|
|
|
/// Wrapper for MLIR's builtin `index` type.
pub const IndexType = struct {
    _inner: c.MlirType,

    /// Lets `Type.as(IndexType)` recognise index types (it requires `is_a_fn`).
    pub const is_a_fn = c.mlirTypeIsAIndex;

    pub const asType = Type.fromAny(IndexType);
    pub const eql = Type.eqlAny(IndexType);
    pub const format = helpers.format(IndexType, c.mlirTypePrint).format;

    /// Returns the `index` type of `ctx`.
    pub fn init(ctx: Context) IndexType {
        return .{ ._inner = c.mlirIndexTypeGet(ctx._inner) };
    }
};
|
|
|
|
/// Integer types understood by `IntegerType`.
///
/// `iN` are signless, `siN` signed and `uN` unsigned N-bit integers; `unknown`
/// matches any integer type but cannot be constructed.
pub const IntegerTypes = enum {
    // signless
    i1, i4, i8, i16, i32, i64,
    // signed
    si4, si8, si16, si32, si64,
    // unsigned
    u4, u8, u16, u32, u64,

    unknown,
};
|
|
|
|
/// Returns the wrapper type for the integer kind `it`.
///
/// For `.unknown`, `is_a_fn` accepts any integer type and no `init` is provided.
pub fn IntegerType(comptime it: IntegerTypes) type {
    // { bit width, constructor, signedness predicate }
    const Config = switch (it) {
        .i1 => .{ 1, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
        .i4 => .{ 4, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
        .i8 => .{ 8, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
        .i16 => .{ 16, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
        .i32 => .{ 32, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
        .i64 => .{ 64, c.mlirIntegerTypeGet, c.mlirIntegerTypeIsSignless },
        .si4 => .{ 4, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
        .si8 => .{ 8, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
        .si16 => .{ 16, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
        .si32 => .{ 32, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
        .si64 => .{ 64, c.mlirIntegerTypeSignedGet, c.mlirIntegerTypeIsSigned },
        .u4 => .{ 4, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
        .u8 => .{ 8, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
        .u16 => .{ 16, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
        .u32 => .{ 32, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
        .u64 => .{ 64, c.mlirIntegerTypeUnsignedGet, c.mlirIntegerTypeIsUnsigned },
        .unknown => .{ 0, null, null },
    };

    return struct {
        _inner: c.MlirType,

        const Int = @This();

        pub const is_a_fn = switch (it) {
            .unknown => c.mlirTypeIsAInteger,
            else => typeIsAIntegerExact,
        };

        pub const asType = Type.fromAny(Int);
        pub const eql = Type.eqlAny(Int);
        // `helpers.format` returns a namespace; `.format` is the method std.fmt looks for.
        pub const format = helpers.format(Int, c.mlirTypePrint).format;

        /// Matches integers with exactly this bit width and signedness.
        fn typeIsAIntegerExact(typ: c.MlirType) callconv(.c) bool {
            const bit_width = Config[0];
            const is_sign = Config[2];
            return c.mlirTypeIsAInteger(typ) and (c.mlirIntegerTypeGetWidth(typ) == bit_width) and is_sign(typ);
        }

        pub const BitWidth = Config[0];

        pub const init = if (it != .unknown) struct {
            pub fn init(ctx: Context) Int {
                const type_get = Config[1];
                return .{ ._inner = type_get(ctx._inner, BitWidth) };
            }
        }.init else {};
    };
}
|
|
|
|
/// Floating-point types supported by `FloatType`.
pub const FloatTypes = enum {
    // 8-bit formats
    f8e4m3b11fnuz,
    f8e4m3fn,
    f8e4m3fnuz,
    f8e5m2,
    f8e5m2fnuz,
    // wider formats
    bf16,
    f16,
    f32,
    f64,

    /// Builds the MLIR type corresponding to this tag in `ctx`.
    pub fn asType(self: FloatTypes, ctx: Context) Type {
        switch (self) {
            inline else => |tag| {
                const specific = FloatType(tag).init(ctx);
                return specific.asType();
            },
        }
    }
};
|
|
|
|
/// Returns the wrapper type for the floating-point kind `ft`.
pub fn FloatType(comptime ft: FloatTypes) type {
    // { is-a predicate, constructor }
    const Config = switch (ft) {
        .f8e4m3b11fnuz => .{ c.mlirTypeIsAFloat8E4M3B11FNUZ, c.mlirFloat8E4M3B11FNUZTypeGet },
        .f8e4m3fn => .{ c.mlirTypeIsAFloat8E4M3FN, c.mlirFloat8E4M3FNTypeGet },
        .f8e4m3fnuz => .{ c.mlirTypeIsAFloat8E4M3FNUZ, c.mlirFloat8E4M3FNUZTypeGet },
        .f8e5m2 => .{ c.mlirTypeIsAFloat8E5M2, c.mlirFloat8E5M2TypeGet },
        .f8e5m2fnuz => .{ c.mlirTypeIsAFloat8E5M2FNUZ, c.mlirFloat8E5M2FNUZTypeGet },
        .bf16 => .{ c.mlirTypeIsABF16, c.mlirBF16TypeGet },
        .f16 => .{ c.mlirTypeIsAF16, c.mlirF16TypeGet },
        .f32 => .{ c.mlirTypeIsAF32, c.mlirF32TypeGet },
        .f64 => .{ c.mlirTypeIsAF64, c.mlirF64TypeGet },
    };

    return struct {
        _inner: c.MlirType,

        const Self = @This();

        pub const is_a_fn = Config[0];

        pub const asType = Type.fromAny(Self);
        pub const eql = Type.eqlAny(Self);
        // `helpers.format` returns a namespace; `.format` is the method std.fmt looks for.
        pub const format = helpers.format(Self, c.mlirTypePrint).format;

        /// Returns this float type in `ctx`.
        pub fn init(ctx: Context) Self {
            const type_get = Config[1];
            return .{ ._inner = type_get(ctx._inner) };
        }
    };
}
|
|
|
|
/// Complex types supported by `ComplexType`.
pub const ComplexTypes = enum {
    /// `complex<f32>`
    c64,
    /// `complex<f64>`
    c128,

    /// Matches any complex type; `ComplexType(.unknown)` has no `init`.
    unknown,
};
|
|
|
|
/// Returns the wrapper type for the complex kind `ct`.
pub fn ComplexType(comptime ct: ComplexTypes) type {
    return struct {
        _inner: c.MlirType,

        const Complex = @This();

        fn mlirC64TypeGet(ctx: c.MlirContext) callconv(.c) c.MlirType {
            return c.mlirComplexTypeGet(c.mlirF32TypeGet(ctx));
        }

        fn mlirC128TypeGet(ctx: c.MlirContext) callconv(.c) c.MlirType {
            return c.mlirComplexTypeGet(c.mlirF64TypeGet(ctx));
        }

        // The element type can only be queried on a complex type, so check that first;
        // otherwise `Type.as(ComplexType(.c64))` on e.g. a tensor type is an invalid cast.
        fn mlirTypeIsAC64(typ: c.MlirType) callconv(.c) bool {
            return c.mlirTypeIsAComplex(typ) and c.mlirTypeIsAF32(c.mlirComplexTypeGetElementType(typ));
        }

        fn mlirTypeIsAC128(typ: c.MlirType) callconv(.c) bool {
            return c.mlirTypeIsAComplex(typ) and c.mlirTypeIsAF64(c.mlirComplexTypeGetElementType(typ));
        }

        // { is-a predicate, constructor }
        const Config = switch (ct) {
            .c64 => .{ mlirTypeIsAC64, mlirC64TypeGet },
            .c128 => .{ mlirTypeIsAC128, mlirC128TypeGet },
            .unknown => .{ c.mlirTypeIsAComplex, null },
        };

        pub const is_a_fn = Config[0];

        pub const asType = Type.fromAny(Complex);
        pub const eql = Type.eqlAny(Complex);
        pub const format = helpers.format(Complex, c.mlirTypePrint).format;
        pub const ComplexTypeType: ComplexTypes = ct;

        pub const init = if (ct != .unknown) struct {
            pub fn init(ctx: Context) Complex {
                const type_get = Config[1];
                return .{ ._inner = type_get(ctx._inner) };
            }
        }.init else {};
    };
}
|
|
|
|
/// Wrapper for MLIR's builtin `tuple<...>` type.
pub const TupleType = struct {
    _inner: c.MlirType,

    pub const is_a_fn = c.mlirTypeIsATuple;

    // `init` relies on `wrapOr`; it was previously undeclared, so `init` did not compile.
    pub const wrapOr = helpers.wrapOr(TupleType, c.mlirTypeIsNull);
    pub const asType = Type.fromAny(TupleType);
    pub const eql = Type.eqlAny(TupleType);
    pub const format = helpers.format(TupleType, c.mlirTypePrint).format;

    const Self = TupleType;

    /// Builds `tuple<elements...>`; returns `error.InvalidMlir` if MLIR returns a null type.
    pub fn init(ctx: Context, elements: []const Type) !Self {
        return Self.wrapOr(c.mlirTupleTypeGet(
            ctx._inner,
            @intCast(elements.len),
            @ptrCast(elements.ptr),
        )) orelse Error.InvalidMlir;
    }

    /// Number of element types in the tuple.
    pub fn getNumTypes(self: Self) usize {
        return @intCast(c.mlirTupleTypeGetNumTypes(self._inner));
    }

    /// Element type at position `index`.
    pub fn getElementType(self: Self, index: usize) Type {
        return .{ ._inner = c.mlirTupleTypeGetType(self._inner, @intCast(index)) };
    }
};
|
|
|
|
/// Wrapper for MLIR's builtin function type `(args) -> (results)`.
pub const FunctionType = struct {
    _inner: c.MlirType,

    pub const is_a_fn = c.mlirTypeIsAFunction;

    const Self = FunctionType;

    pub const asType = Type.fromAny(Self);
    // Must be instantiated for this type: `Type.eqlAny(IndexType)` produced a function
    // taking `IndexType` arguments.
    pub const eql = Type.eqlAny(Self);
    pub const format = helpers.format(Self, c.mlirTypePrint).format;

    /// Builds the function type; returns `error.InvalidMlir` if MLIR returns a null type.
    pub fn init(ctx: Context, args: []const Type, results: []const Type) !Self {
        const func = Type.wrapOr(c.mlirFunctionTypeGet(
            ctx._inner,
            @intCast(args.len),
            @ptrCast(args.ptr),
            @intCast(results.len),
            @ptrCast(results.ptr),
        )) orelse return Error.InvalidMlir;
        return func.as(Self).?;
    }
};
|
|
|
|
/// Wrapper for MLIR ranked tensor types.
pub const RankedTensorType = struct {
    _inner: c.MlirType,

    pub const is_a_fn = c.mlirTypeIsARankedTensor;
    pub const asType = Type.fromAny(RankedTensorType);
    pub const eql = Type.eqlAny(RankedTensorType);
    // Must be instantiated for this type (not `Type`), and `.format` selects the method.
    pub const format = helpers.format(RankedTensorType, c.mlirTypePrint).format;

    /// Builds a ranked tensor of shape `dimensions` and element type `elemType`,
    /// passing a null attribute as the encoding.
    pub fn init(dimensions: []const i64, elemType: Type) RankedTensorType {
        return .{ ._inner = c.mlirRankedTensorTypeGet(
            @intCast(dimensions.len),
            @ptrCast(dimensions.ptr),
            elemType._inner,
            c.mlirAttributeGetNull(),
        ) };
    }

    /// Element type of the tensor.
    pub fn getElementType(self: RankedTensorType) Type {
        return .{ ._inner = c.mlirShapedTypeGetElementType(self._inner) };
    }

    /// Number of dimensions.
    pub fn getRank(self: RankedTensorType) usize {
        return @intCast(c.mlirShapedTypeGetRank(self._inner));
    }

    /// Size of dimension `dim` as reported by MLIR.
    pub fn getDimension(self: RankedTensorType, dim: usize) i64 {
        return c.mlirShapedTypeGetDimSize(self._inner, @intCast(dim));
    }
};
|
|
|
|
/// A dialect loaded into a `Context`.
pub const Dialect = struct {
    _inner: c.MlirDialect,

    const Self = Dialect;

    /// Returns the context this dialect belongs to.
    pub fn getContext(self: Self) Context {
        const owning_ctx = c.mlirDialectGetContext(self._inner);
        return Context{ ._inner = owning_ctx };
    }

    /// Returns the dialect namespace; the slice points into MLIR-owned memory (not copied).
    pub fn getNamespace(self: Self) []const u8 {
        const namespace = c.mlirDialectGetNamespace(self._inner);
        return fromStringRef(namespace);
    }
};
|
|
|
|
/// Handle on a dialect that can be inserted into a `Registry` or registered/loaded into a `Context`.
pub const DialectHandle = struct {
    _inner: c.MlirDialectHandle,

    /// Returns the dialect namespace; the slice points into MLIR-owned memory (not copied).
    pub fn getNamespace(self: DialectHandle) []const u8 {
        const namespace = c.mlirDialectHandleGetNamespace(self._inner);
        return fromStringRef(namespace);
    }

    /// Adds this dialect to `registry`.
    pub fn insertDialect(self: DialectHandle, registry: Registry) void {
        c.mlirDialectHandleInsertDialect(self._inner, registry._inner);
    }

    /// Registers this dialect with `ctx`.
    pub fn registerDialect(self: DialectHandle, ctx: Context) void {
        c.mlirDialectHandleRegisterDialect(self._inner, ctx._inner);
    }

    /// Loads this dialect into `ctx` and returns it.
    pub fn loadDialect(self: DialectHandle, ctx: Context) Dialect {
        const loaded = c.mlirDialectHandleLoadDialect(self._inner, ctx._inner);
        return Dialect{ ._inner = loaded };
    }

    /// Resolves the handle at comptime via the C accessor `mlirGetDialectHandle__<namespace>__`.
    pub fn fromString(comptime namespace: []const u8) DialectHandle {
        const getter = @field(c, "mlirGetDialectHandle__" ++ namespace ++ "__");
        return .{ ._inner = getter() };
    }
};
|
|
|
|
/// Source location attached to MLIR operations.
pub const Location = struct {
    _inner: c.MlirLocation,

    // Both helpers must be instantiated for `Location` (not `Type`); `.format` selects the method.
    pub const eql = helpers.eql(Location, c.mlirLocationEqual);
    pub const format = helpers.format(Location, c.mlirLocationPrint).format;

    /// Builds a file:line:col location from a Zig `@src()` value.
    pub fn fromSrc(ctx: Context, src: std.builtin.SourceLocation) Location {
        return .{ ._inner = c.mlirLocationFileLineColGet(
            ctx._inner,
            stringRef(src.file),
            @intCast(src.line),
            @intCast(src.column),
        ) };
    }

    /// Builds a file:line:col location.
    pub fn fileLineCol(ctx: Context, file: []const u8, line: usize, column: usize) Location {
        return .{ ._inner = c.mlirLocationFileLineColGet(
            ctx._inner,
            stringRef(file),
            @intCast(line),
            @intCast(column),
        ) };
    }

    /// Builds a call-site location: `callee` called from `caller`.
    pub fn callSite(callee: Location, caller: Location) Location {
        return .{ ._inner = c.mlirLocationCallSiteGet(callee._inner, caller._inner) };
    }

    /// Fuses `locations` into a single location carrying `metadata`.
    pub fn fused(ctx: Context, locations: []const Location, metadata: Attribute) Location {
        return .{ ._inner = c.mlirLocationFusedGet(
            ctx._inner,
            @intCast(locations.len),
            @ptrCast(locations.ptr),
            metadata._inner,
        ) };
    }

    /// Wraps `loc` in a location named `loc_name`.
    /// Only the pointer and length are passed to MLIR, so `loc_name` needs no null terminator
    /// (sentinel-terminated slices still coerce to this parameter).
    pub fn named(loc: Location, ctx: Context, loc_name: []const u8) Location {
        return .{ ._inner = c.mlirLocationNameGet(ctx._inner, stringRef(loc_name), loc._inner) };
    }

    /// Like `named`, with the name formatted from `fmt`/`args` into a 256-byte stack buffer.
    /// Names that do not fit are truncated and end with "...".
    pub fn namedFmt(loc: Location, ctx: Context, comptime fmt: [:0]const u8, args: anytype) Location {
        var buf: [256]u8 = undefined;
        var stream = std.io.fixedBufferStream(&buf);
        stream.writer().print(fmt, args) catch {
            // Out of space: keep the whole buffer and mark the truncation.
            buf[buf.len - 3 ..].* = "...".*;
            stream.pos = buf.len;
        };
        return loc.named(ctx, stream.getWritten());
    }

    /// The unknown location of `ctx`.
    pub fn unknown(ctx: Context) Location {
        return .{ ._inner = c.mlirLocationUnknownGet(ctx._inner) };
    }
};
|
|
|
|
/// Wrapper around an MLIR block: typed arguments plus an ordered list of operations.
pub const Block = struct {
    _inner: c.MlirBlock,

    pub const wrapOr = helpers.wrapOr(Block, c.mlirBlockIsNull);
    pub const deinit = helpers.deinit(Block, c.mlirBlockDestroy);

    pub const eql = helpers.eql(Block, c.mlirBlockEqual);

    /// Creates a detached block with one argument per `args` entry, located at the matching
    /// `locs` entry. Returns `error.InvalidMlir` if MLIR returns a null block.
    pub fn init(args: []const Type, locs: []const Location) !Block {
        // `mlirBlockCreate` reads `args.len` locations: a shorter `locs` would be read out of bounds.
        stdx.debug.assert(args.len == locs.len, "Block.init expects one location per argument, got {d} arguments and {d} locations", .{ args.len, locs.len });
        const block = Block.wrapOr(
            c.mlirBlockCreate(@intCast(args.len), @ptrCast(args.ptr), @ptrCast(locs.ptr)),
        );
        return block orelse error.InvalidMlir;
    }

    /// Argument at position `index`.
    pub fn argument(self: Block, index: usize) Value {
        return .{ ._inner = c.mlirBlockGetArgument(self._inner, @intCast(index)) };
    }

    /// Number of block arguments.
    pub fn numArguments(self: Block) usize {
        return @intCast(c.mlirBlockGetNumArguments(self._inner));
    }

    /// Appends a new argument of type `typ` and returns it.
    pub fn addArgument(self: *Block, typ: Type, loc: Location) Value {
        return .{ ._inner = c.mlirBlockAddArgument(self._inner, typ._inner, loc._inner) };
    }

    /// Inserts a new argument at `index` and returns it.
    pub fn insertArgument(self: *Block, index: usize, typ: Type, loc: Location) Value {
        return .{ ._inner = c.mlirBlockInsertArgument(self._inner, @intCast(index), typ._inner, loc._inner) };
    }

    /// Same as `eql`.
    pub fn equals(self: Block, other: Block) bool {
        return c.mlirBlockEqual(self._inner, other._inner);
    }

    /// Appends `op` at the end of the block (the block takes ownership of it).
    pub fn appendOperation(self: Block, op: Operation) void {
        c.mlirBlockAppendOwnedOperation(self._inner, op._inner);
    }

    /// Appends every operation of `ops`, in order.
    pub fn appendOperations(self: *Block, ops: []const Operation) void {
        for (ops) |op| {
            c.mlirBlockAppendOwnedOperation(self._inner, op._inner);
        }
    }

    /// `open` allows referencing values from other blocks; `hermetic` asserts against it.
    pub const RecursiveOpts = enum { open, hermetic };

    /// Makes `value` available in this block by appending its producing operation
    /// (and, recursively, that operation's operands) if it is not in a block yet.
    pub fn appendValueRecursive(self: Block, value: Value, opt: RecursiveOpts) void {
        switch (value.kind()) {
            .op_result => |parent_op| self.appendOperationRecursive(parent_op, opt),
            .block_argument => |arg| {
                // Hermetic blocks are not allowed to use arguments from other blocks.
                stdx.debug.assert(opt == .open or self.eql(arg.block()), "Can't add {f} from {*} block to {*} block", .{ arg, arg.block()._inner.ptr, self._inner.ptr });
            },
            .null => @panic("InvalidMlir"),
        }
    }

    /// Appends `op` after recursively appending the producers of its operands.
    /// Operations that already belong to a block are left where they are.
    pub fn appendOperationRecursive(self: Block, op: Operation, opt: RecursiveOpts) void {
        if (op.block()) |prev_block| {
            // Hermetic blocks are not allowed to reference values from other blocks.
            // `{f}` is required to call a format method; `mlirFormatter` provides one.
            stdx.debug.assert(opt == .open or self.equals(prev_block), "Can't add {f} from {*} block to {*} block", .{ op.mlirFormatter(.{}), prev_block._inner.ptr, self._inner.ptr });
            return;
        }
        for (0..op.numOperands()) |i| {
            self.appendValueRecursive(op.operand(i), opt);
        }
        self.appendOperation(op);
    }
};
|
|
|
|
/// Comptime generators for the thin methods shared by the C-handle wrappers in this file.
/// Every wrapper `T` is a struct whose `_inner` field holds the raw MLIR handle.
pub const helpers = struct {
    /// Builds `eql(a, b)` from an MLIR equality function.
    pub fn eql(T: type, equal_fn: fn (@FieldType(T, "_inner"), @FieldType(T, "_inner")) callconv(.c) bool) fn (T, T) bool {
        return struct {
            fn eql(a: T, b: T) bool {
                return equal_fn(a._inner, b._inner);
            }
        }.eql;
    }

    /// Builds `deinit(self)`: destroys the handle, then sets the wrapper to `undefined`.
    pub fn deinit(T: type, deinit_fn: fn (@FieldType(T, "_inner")) callconv(.c) void) fn (*T) void {
        return struct {
            fn deinit(a: *T) void {
                deinit_fn(a._inner);
                a.* = undefined;
            }
        }.deinit;
    }

    /// Builds `dump(self)` from an MLIR dump function.
    pub fn dump(T: type, dump_fn: fn (@FieldType(T, "_inner")) callconv(.c) void) fn (T) void {
        return struct {
            fn dump(a: T) void {
                return dump_fn(a._inner);
            }
        }.dump;
    }

    /// Builds `isNull(self)` from an MLIR null-check function.
    pub fn isNull(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) fn (T) bool {
        return struct {
            fn isNull(a: T) bool {
                return is_null_fn(a._inner);
            }
        }.isNull;
    }

    /// Returns a namespace whose `format` method prints `Any` through an MLIR print function;
    /// use it as `helpers.format(T, print_fn).format`.
    ///
    /// MLIR streams text through a C callback that cannot return errors, so the first writer
    /// error is recorded, later chunks are ignored, and the error is returned at the end.
    pub fn format(Any: type, print_fn: fn (@FieldType(Any, "_inner"), ?*const MlirStrCallback, ?*anyopaque) callconv(.c) void) type {
        return struct {
            pub fn format(self: Any, writer: *std.Io.Writer) !void {
                const WriterWithErr = struct {
                    writer: *std.Io.Writer,
                    err: ?std.Io.Writer.Error = null,
                    fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.c) void {
                        const ctx: *@This() = @ptrCast(@alignCast(opaque_ctx));
                        if (ctx.err) |_| return;
                        // `write` may accept only part of the chunk; `writeAll` keeps going until
                        // every byte is written, so no IR text is silently dropped.
                        ctx.writer.writeAll(mlir_str.data[0..mlir_str.length]) catch |err| {
                            ctx.err = err;
                        };
                    }
                };

                var context: WriterWithErr = .{ .writer = writer };
                print_fn(self._inner, &WriterWithErr.printCallback, &context);
                if (context.err) |err| return err;
            }
        };
    }

    /// Builds `wrapOr(inner)`: returns null when `is_null_fn(inner)` holds, else wraps it.
    pub fn wrapOr(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) fn (@FieldType(T, "_inner")) ?T {
        return struct {
            fn wrapOr(inner: @FieldType(T, "_inner")) ?T {
                if (is_null_fn(inner)) return null;
                return .{ ._inner = inner };
            }
        }.wrapOr;
    }

    /// Wraps `inner` as a `T`, or returns null when `is_null_fn(inner)` holds.
    pub fn init(T: type, inner: @FieldType(T, "_inner"), is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) ?T {
        if (is_null_fn(inner)) return null;
        return .{ ._inner = inner };
    }
};
|