From 0a2ab7c8cb40307728f13fa81c3bf73beee47ad4 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 28 Jan 2025 09:35:58 +0000 Subject: [PATCH] Remove usingnamespace from MLIR. --- mlir/dialects/arith.zig | 4 +- mlir/dialects/stablehlo.zig | 310 +++++------ mlir/mlir.zig | 1045 +++++++++++++++-------------------- zml/mlir.zig | 176 ------ zml/mlirx.zig | 101 ++++ zml/module.zig | 15 +- zml/nn/cuda.zig | 4 +- zml/ops.zig | 47 +- zml/tensor.zig | 104 ++-- zml/zml.zig | 2 +- 10 files changed, 777 insertions(+), 1031 deletions(-) delete mode 100644 zml/mlir.zig create mode 100644 zml/mlirx.zig diff --git a/mlir/dialects/arith.zig b/mlir/dialects/arith.zig index f6d86d5..39cbc2a 100644 --- a/mlir/dialects/arith.zig +++ b/mlir/dialects/arith.zig @@ -73,7 +73,7 @@ pub fn cmpi(ctx: mlir.Context, predicate: CmpIPredicate, lhs: mlir.Value, rhs: m .operands = &.{ lhs, rhs }, .result_type_inference = true, .attributes = &.{ - .{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute) }, + .{ "predicate", .int(ctx, .i64, @intFromEnum(predicate)) }, }, .location = location, }); @@ -103,7 +103,7 @@ pub fn cmpf(ctx: mlir.Context, predicate: CmpFPredicate, lhs: mlir.Value, rhs: m .operands = &.{ lhs, rhs }, .result_type_inference = true, .attributes = &.{ - .{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute) }, + .{ "predicate", .int(ctx, .i64, @intFromEnum(predicate)) }, }, .location = location, }); diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig index 24466dd..8010985 100644 --- a/mlir/dialects/stablehlo.zig +++ b/mlir/dialects/stablehlo.zig @@ -127,10 +127,10 @@ pub const DotPrecision = union(enum) { // When we specify the dot algorithm, we should not specify the precision. .algorithm => .DEFAULT, }); - return precision.as(mlir.Attribute); + return precision.asAttr(); } - pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.Type) ?mlir.Attribute { + pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.RankedTensorType) ?mlir.Attribute { return switch (self) { .algorithm => |algo| algo.asAttr(ctx, operand_type), else => null, @@ -156,15 +156,14 @@ pub const DotAlgorithm = struct { .allow_imprecise_accumulation = false, }; - pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, operand_type: mlir.Type) mlir.Attribute { - const tensor_type = operand_type.as(mlir.RankedTensorType); + pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, tensor_type: mlir.RankedTensorType) mlir.Attribute { const elem_type = tensor_type.getElementType(); return mlir.Attribute.wrap(c.stablehloDotAlgorithmGet( - ctx.inner(), - elem_type.inner(), - elem_type.inner(), - self.accumulation.asType(ctx).inner(), + ctx._inner, + elem_type._inner, + elem_type._inner, + self.accumulation.asType(ctx)._inner, self.component_count, self.component_count, self.num_primitive_operations, @@ -197,11 +196,11 @@ pub fn dot_general( .rhs_batching_dimensions = opts.rhs_batching_dimensions, .lhs_contracting_dimensions = opts.lhs_contracting_dimensions, .rhs_contracting_dimensions = opts.rhs_contracting_dimensions, - }).as(mlir.Attribute), + }).asAttr(), }, .{ "precision_config", .array(ctx, &precisions) }, // keep algorithm as the last attribute so we can omit it when it's not set. - .{ "algorithm", opts.precision.algorithmAttr(ctx, lhs.getType()) orelse undefined }, + .{ "algorithm", opts.precision.algorithmAttr(ctx, lhs.getType().as(mlir.RankedTensorType).?) orelse undefined }, }; const n_attributes = if (opts.precision == .algorithm) attributes.len else attributes.len - 1; return mlir.Operation.make(ctx, "stablehlo.dot_general", .{ @@ -214,19 +213,15 @@ pub fn dot_general( pub fn constant( ctx: mlir.Context, - result_type: mlir.RankedTensorType, + dims: []const i64, elem_type: mlir.DenseElementsAttributeTypes, raw_bytes: []const u8, location: mlir.Location, ) mlir.Operation { - const attribute = switch (elem_type) { - inline else => |dt| mlir.DenseElementsAttribute(dt).init(result_type.as(mlir.Type), raw_bytes).as(mlir.Attribute), - }; - return mlir.Operation.make(ctx, "stablehlo.constant", .{ .operands = &.{}, - .results = &.{result_type.as(mlir.Type)}, - .attributes = &.{.{ "value", attribute }}, + .results = &.{.tensor(dims, elem_type.mlirType(ctx))}, + .attributes = &.{.{ "value", .denseElementsFromBytes(ctx, dims, elem_type, raw_bytes) }}, .location = location, }); } @@ -285,10 +280,10 @@ pub fn concatenate(ctx: mlir.Context, inputs: []const mlir.Value, dimension: i64 }); } -pub fn reshape(ctx: mlir.Context, value: mlir.Value, result_type: mlir.RankedTensorType, location: mlir.Location) mlir.Operation { +pub fn reshape(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, location: mlir.Location) mlir.Operation { return mlir.Operation.make(ctx, "stablehlo.reshape", .{ .operands = &.{value}, - .results = &.{result_type.as(mlir.Type)}, + .results = &.{result_type}, .location = location, }); } @@ -332,7 +327,7 @@ pub fn gather( args.start_indices_batching_dims, args.start_index_map, args.index_vector_dim, - ).as(mlir.Attribute) }, + ).asAttr() }, .{ "slice_sizes", .dense(ctx, .i64, slice_sizes) }, .{ "indices_are_sorted", .boolean(ctx, args.indices_are_sorted) }, }, @@ -358,22 +353,20 @@ pub const ScatterArgs = struct { unique_indices: bool = false, pub fn getScatterDimensionNumbers(self: ScatterArgs, ctx: mlir.Context) mlir.Attribute { - return mlir.Attribute.wrap( - c.stablehloScatterDimensionNumbersGet( - ctx.inner(), - @intCast(self.update_window_dims.len), - self.update_window_dims.ptr, - @intCast(self.inserted_window_dims.len), - self.inserted_window_dims.ptr, - @intCast(self.input_batching_dims.len), - self.input_batching_dims.ptr, - @intCast(self.scatter_indices_batching_dims.len), - self.scatter_indices_batching_dims.ptr, - @intCast(self.scatter_dims_to_operand_dims.len), - self.scatter_dims_to_operand_dims.ptr, - self.index_vector_dim, - ), - ); + return .{ ._inner = c.stablehloScatterDimensionNumbersGet( + ctx._inner, + @intCast(self.update_window_dims.len), + self.update_window_dims.ptr, + @intCast(self.inserted_window_dims.len), + self.inserted_window_dims.ptr, + @intCast(self.input_batching_dims.len), + self.input_batching_dims.ptr, + @intCast(self.scatter_indices_batching_dims.len), + self.scatter_indices_batching_dims.ptr, + @intCast(self.scatter_dims_to_operand_dims.len), + self.scatter_dims_to_operand_dims.ptr, + self.index_vector_dim, + ) }; } }; @@ -431,8 +424,8 @@ pub fn compare(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, comparison_d .operands = &.{ lhs, rhs }, .result_type_inference = true, .attributes = &.{ - .{ "comparison_direction", comparison_direction.as(mlir.Attribute) }, - .{ "compare_type", compare_type.as(mlir.Attribute) }, + .{ "comparison_direction", comparison_direction.asAttr() }, + .{ "compare_type", compare_type.asAttr() }, }, .location = location, }); @@ -580,7 +573,7 @@ pub fn triangular_solve(ctx: mlir.Context, value: mlir.Value, other: mlir.Value, .{ "left_side", .i1FromBool(ctx, opts.left_side) }, .{ "lower", .i1FromBool(ctx, opts.lower) }, .{ "unit_diagonal", .i1FromBool(ctx, opts.unit_diagonal) }, - .{ "transpose_a", Transpose.init(ctx, opts.transpose_a).as(mlir.Attribute) }, + .{ "transpose_a", Transpose.init(ctx, opts.transpose_a).asAttr() }, }, .location = location, }); @@ -596,7 +589,7 @@ pub fn fft(ctx: mlir.Context, value: mlir.Value, location: mlir.Location, opts: .operands = &.{value}, .result_type_inference = true, .attributes = &.{ - .{ "fft_type", FftType.init(ctx, opts.kind).as(mlir.Attribute) }, + .{ "fft_type", FftType.init(ctx, opts.kind).asAttr() }, .{ "fft_length", .dense(ctx, .i64, opts.length) }, }, .location = location, @@ -608,7 +601,7 @@ pub fn rng(ctx: mlir.Context, a: mlir.Value, b: mlir.Value, shape: mlir.Value, r .operands = &.{ a, b, shape }, .result_type_inference = true, .attributes = &.{ - .{ "rng_distribution", RngDistribution.init(ctx, rng_distribution).as(mlir.Attribute) }, + .{ "rng_distribution", RngDistribution.init(ctx, rng_distribution).asAttr() }, }, .location = location, }); @@ -619,7 +612,7 @@ pub fn rng_bit_generator(ctx: mlir.Context, rng_algorithm: RngAlgorithm.Type, in .operands = &.{initial_state}, .results = &.{ res_state_type, res_type }, .attributes = &.{ - .{ "rng_algorithm", RngAlgorithm.init(ctx, rng_algorithm).as(mlir.Attribute) }, + .{ "rng_algorithm", RngAlgorithm.init(ctx, rng_algorithm).asAttr() }, }, .location = location, }); @@ -695,7 +688,7 @@ pub fn convolution( ) mlir.Operation { var max_precisions: [2]mlir.Attribute = undefined; for (opts.precision_config, 0..) |p, i| { - max_precisions[i] = PrecisionAttribute.init(ctx, p).as(mlir.Attribute); + max_precisions[i] = PrecisionAttribute.init(ctx, p).asAttr(); } var window_reversal: [3]i32 = undefined; for (opts.window_reversal, 0..) |w, i| { @@ -721,7 +714,7 @@ pub fn convolution( .output_batch_dimension = opts.output_batch_dimension, .output_feature_dimension = opts.output_feature_dimension, .output_spatial_dimensions = opts.output_spatial_dimensions, - }).as(mlir.Attribute), + }).asAttr(), }, .{ "feature_group_count", .int(ctx, .i64, opts.feature_group_count) }, .{ "batch_group_count", .int(ctx, .i64, opts.batch_group_count) }, @@ -756,13 +749,13 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa const backend_config = opts.backend_config orelse mlir.Attribute.string(ctx, ""); if (@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi)) { stdx.debug.assert( - backend_config.is_a(mlir.StringAttribute), + backend_config.isA(mlir.StringAttribute), "API version < 4 requires a string as backend_config, got {}", .{backend_config}, ); } else { stdx.debug.assert( - backend_config.is_a(mlir.DictionaryAttribute), + backend_config.isA(mlir.DictionaryAttribute), "API version >= 4 requires a dictionary as backend_config, got {}", .{backend_config}, ); @@ -780,7 +773,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa var output_operand_aliases: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; for (opts.output_operand_aliases) |alias| { output_operand_aliases.appendAssumeCapacity( - OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute), + OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).asAttr(), ); } attrs.appendAssumeCapacity(.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) }); @@ -805,7 +798,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa const operand_layouts = blk: { var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; for (inputs) |input| { - const ranked_type = input.getType().as(mlir.RankedTensorType); + const ranked_type = input.getType().as(mlir.RankedTensorType).?; const ol = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_type.getRank() ..]; ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol)); } @@ -824,7 +817,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa const result_layouts = blk: { var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; for (res_types) |t| { - const ranked_t = t.as(mlir.RankedTensorType); + const ranked_t = t.as(mlir.RankedTensorType).?; const rl = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_t.getRank() ..]; ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl)); } @@ -846,13 +839,10 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa pub const DotDimensionNumbersAttribute = struct { _inner: c.MlirAttribute, - pub usingnamespace mlir.MlirHelpers(DotDimensionNumbersAttribute, .{ - .is_a_fn = c.stablehloAttributeIsADotDimensionNumbers, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.stablehloAttributeIsADotDimensionNumbers; const Self = DotDimensionNumbersAttribute; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); pub fn init(ctx: mlir.Context, args: struct { lhs_batching_dimensions: []const i64, @@ -860,9 +850,9 @@ pub const DotDimensionNumbersAttribute = struct { lhs_contracting_dimensions: []const i64, rhs_contracting_dimensions: []const i64, }) Self { - return Self.wrap( - c.stablehloDotDimensionNumbersGet( - ctx.inner(), + return .{ + ._inner = c.stablehloDotDimensionNumbersGet( + ctx._inner, @intCast(args.lhs_batching_dimensions.len), args.lhs_batching_dimensions.ptr, @intCast(args.rhs_batching_dimensions.len), @@ -872,52 +862,49 @@ pub const DotDimensionNumbersAttribute = struct { @intCast(args.rhs_contracting_dimensions.len), args.rhs_contracting_dimensions.ptr, ), - ); + }; } pub fn getLhsBatchingDimensionsSize(self: Self) usize { - return @intCast(c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsSize(self.inner())); + return @intCast(c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsSize(self._inner)); } pub fn getLhsBatchingDimensionsElem(self: Self, pos: usize) i64 { - return c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsElem(self.inner(), @intCast(pos)); + return c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsElem(self._inner, @intCast(pos)); } pub fn getRhsBatchingDimensionsSize(self: Self) usize { - return @intCast(c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsSize(self.inner())); + return @intCast(c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsSize(self._inner)); } pub fn getRhsBatchingDimensionsElem(self: Self, pos: usize) i64 { - return c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsElem(self.inner(), @intCast(pos)); + return c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsElem(self._inner, @intCast(pos)); } pub fn getLhsContractingDimensionsSize(self: Self) usize { - return @intCast(c.stablehloDotDimensionNumbersGetLhsContractingDimensionsSize(self.inner())); + return @intCast(c.stablehloDotDimensionNumbersGetLhsContractingDimensionsSize(self._inner)); } pub fn getLhsContractingDimensionsElem(self: Self, pos: usize) i64 { - return c.stablehloDotDimensionNumbersGetLhsContractingDimensionsElem(self.inner(), @intCast(pos)); + return c.stablehloDotDimensionNumbersGetLhsContractingDimensionsElem(self._inner, @intCast(pos)); } pub fn getRhsContractingDimensionsSize(self: Self) usize { - return @intCast(c.stablehloDotDimensionNumbersGetRhsContractingDimensionsSize(self.inner())); + return @intCast(c.stablehloDotDimensionNumbersGetRhsContractingDimensionsSize(self._inner)); } pub fn getRhsContractingDimensionsElem(self: Self, pos: usize) i64 { - return c.stablehloDotDimensionNumbersGetRhsContractingDimensionsElem(self.inner(), @intCast(pos)); + return c.stablehloDotDimensionNumbersGetRhsContractingDimensionsElem(self._inner, @intCast(pos)); } }; pub const GatherDimensionNumbersAttribute = struct { _inner: c.MlirAttribute, - pub usingnamespace mlir.MlirHelpers(GatherDimensionNumbersAttribute, .{ - .is_a_fn = c.stablehloAttributeIsAGatherDimensionNumbers, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.stablehloAttributeIsAGatherDimensionNumbers; const Self = GatherDimensionNumbersAttribute; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); pub fn init( ctx: mlir.Context, @@ -928,9 +915,9 @@ pub const GatherDimensionNumbersAttribute = struct { start_index_map: []const i64, index_vector_dim: i64, ) Self { - return Self.wrap( - c.stablehloGatherDimensionNumbersGet( - ctx.inner(), + return .{ + ._inner = c.stablehloGatherDimensionNumbersGet( + ctx._inner, @intCast(offset_dims.len), offset_dims.ptr, @intCast(collapsed_slice_dims.len), @@ -943,64 +930,61 @@ pub const GatherDimensionNumbersAttribute = struct { start_index_map.ptr, index_vector_dim, ), - ); + }; } pub fn getOffsetDimsSize(self: Self) usize { - return @intCast(c.stablehloGatherDimensionNumbersGetOffsetDimsSize(self.inner())); + return @intCast(c.stablehloGatherDimensionNumbersGetOffsetDimsSize(self._inner)); } pub fn getOffsetDimsElem(self: Self, pos: usize) i64 { - return c.stablehloGatherDimensionNumbersGetOffsetDimsElem(self.inner(), @intCast(pos)); + return c.stablehloGatherDimensionNumbersGetOffsetDimsElem(self._inner, @intCast(pos)); } pub fn getCollapsedSliceDimsSize(self: Self) usize { - return @intCast(c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsSize(self.inner())); + return @intCast(c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsSize(self._inner)); } pub fn getCollapsedSliceDimsElem(self: Self, pos: usize) i64 { - return c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsElem(self.inner(), @intCast(pos)); + return c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsElem(self._inner, @intCast(pos)); } pub fn getStartIndexMapSize(self: Self) usize { - return @intCast(c.stablehloGatherDimensionNumbersGetStartIndexMapSize(self.inner())); + return @intCast(c.stablehloGatherDimensionNumbersGetStartIndexMapSize(self._inner)); } pub fn getOperandBatchingDimsSize(self: Self) usize { - return @intCast(c.stablehloGatherDimensionNumbersGetOperandBatchingDimsSize(self.inner())); + return @intCast(c.stablehloGatherDimensionNumbersGetOperandBatchingDimsSize(self._inner)); } pub fn getOperandBatchingDimsElem(self: Self, pos: usize) i64 { - return c.stablehloGatherDimensionNumbersGetOperandBatchingDimsElem(self.inner(), @intCast(pos)); + return c.stablehloGatherDimensionNumbersGetOperandBatchingDimsElem(self._inner, @intCast(pos)); } pub fn getStartIndicesBatchingDimsSize(self: Self) usize { - return @intCast(c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsSize(self.inner())); + return @intCast(c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsSize(self._inner)); } pub fn getStartIndicesBatchingDimsElem(self: Self, pos: usize) i64 { - return c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsElem(self.inner(), @intCast(pos)); + return c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsElem(self._inner, @intCast(pos)); } pub fn getStartIndexMapElem(self: Self, pos: usize) i64 { - return c.stablehloGatherDimensionNumbersGetStartIndexMapElem(self.inner(), @intCast(pos)); + return c.stablehloGatherDimensionNumbersGetStartIndexMapElem(self._inner, @intCast(pos)); } pub fn getIndexVectorDim(self: Self) usize { - return @intCast(c.stablehloGatherDimensionNumbersGetIndexVectorDim(self.inner())); + return @intCast(c.stablehloGatherDimensionNumbersGetIndexVectorDim(self._inner)); } }; pub const ConvDimensionNumbersAttribute = struct { _inner: c.MlirAttribute, - pub usingnamespace mlir.MlirHelpers(ConvDimensionNumbersAttribute, .{ - .is_a_fn = c.stablehloAttributeIsAConvDimensionNumbers, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.stablehloAttributeIsAConvDimensionNumbers; const Self = ConvDimensionNumbersAttribute; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); pub fn init(ctx: mlir.Context, args: struct { input_batch_dimension: i64, @@ -1013,9 +997,9 @@ pub const ConvDimensionNumbersAttribute = struct { output_feature_dimension: i64, output_spatial_dimensions: []const i64, }) Self { - return Self.wrap( - c.stablehloConvDimensionNumbersGet( - ctx.inner(), + return .{ + ._inner = c.stablehloConvDimensionNumbersGet( + ctx._inner, args.input_batch_dimension, args.input_feature_dimension, @intCast(args.input_spatial_dimensions.len), @@ -1029,67 +1013,64 @@ pub const ConvDimensionNumbersAttribute = struct { @intCast(args.output_spatial_dimensions.len), args.output_spatial_dimensions.ptr, ), - ); + }; } pub fn getInputBatchDimension(self: Self) i64 { - return c.stablehloConvDimensionNumbersGetInputBatchDimension(self.inner()); + return c.stablehloConvDimensionNumbersGetInputBatchDimension(self._inner); } pub fn getInputFeatureDimension(self: Self) i64 { - return c.stablehloConvDimensionNumbersGetInputFeatureDimension(self.inner()); + return c.stablehloConvDimensionNumbersGetInputFeatureDimension(self._inner); } pub fn getInputSpatialDimensionsSize(self: Self) usize { - return @intCast(c.stablehloConvDimensionNumbersGetInputSpatialDimensionsSize(self.inner())); + return @intCast(c.stablehloConvDimensionNumbersGetInputSpatialDimensionsSize(self._inner)); } pub fn getInputSpatialDimensionsElem(self: Self, pos: usize) i64 { - return c.stablehloConvDimensionNumbersGetInputSpatialDimensionsElem(self.inner(), @intCast(pos)); + return c.stablehloConvDimensionNumbersGetInputSpatialDimensionsElem(self._inner, @intCast(pos)); } pub fn getKernelInputFeatureDimension(self: Self) i64 { - return c.stablehloConvDimensionNumbersGetKernelInputFeatureDimension(self.inner()); + return c.stablehloConvDimensionNumbersGetKernelInputFeatureDimension(self._inner); } pub fn getKernelOutputFeatureDimension(self: Self) i64 { - return c.stablehloConvDimensionNumbersGetKernelOutputFeatureDimension(self.inner()); + return c.stablehloConvDimensionNumbersGetKernelOutputFeatureDimension(self._inner); } pub fn getKernelSpatialDimensionsSize(self: Self) usize { - return @intCast(c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsSize(self.inner())); + return @intCast(c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsSize(self._inner)); } pub fn getKernelSpatialDimensionsElem(self: Self, pos: usize) i64 { - return c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsElem(self.inner(), @intCast(pos)); + return c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsElem(self._inner, @intCast(pos)); } pub fn getOutputBatchDimension(self: Self) i64 { - return c.stablehloConvDimensionNumbersGetOutputBatchDimension(self.inner()); + return c.stablehloConvDimensionNumbersGetOutputBatchDimension(self._inner); } pub fn getOutputFeatureDimension(self: Self) i64 { - return c.stablehloConvDimensionNumbersGetOutputFeatureDimension(self.inner()); + return c.stablehloConvDimensionNumbersGetOutputFeatureDimension(self._inner); } pub fn getOutputSpatialDimensionsSize(self: Self) usize { - return @intCast(c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsSize(self.inner())); + return @intCast(c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsSize(self._inner)); } pub fn getOutputSpatialDimensionsElem(self: Self, pos: usize) i64 { - return c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsElem(self.inner(), @intCast(pos)); + return c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsElem(self._inner, @intCast(pos)); } }; pub const OutputOperandAliasAttribute = struct { _inner: c.MlirAttribute, - pub usingnamespace mlir.MlirHelpers(OutputOperandAliasAttribute, .{ - .is_a_fn = c.stablehloAttributeIsAOutputOperandAlias, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.stablehloAttributeIsAOutputOperandAlias; + pub const asAttr = mlir.Attribute.fromAny(OutputOperandAliasAttribute); + pub const eql = mlir.Attribute.eqlAny(OutputOperandAliasAttribute); pub fn init( ctx: mlir.Context, @@ -1097,27 +1078,24 @@ pub const OutputOperandAliasAttribute = struct { operand_index: i64, operand_tuple_indices: []const i64, ) OutputOperandAliasAttribute { - return OutputOperandAliasAttribute.wrap(c.stablehloOutputOperandAliasGet( - ctx.inner(), + return .{ ._inner = c.stablehloOutputOperandAliasGet( + ctx._inner, @intCast(output_tuple_indices.len), output_tuple_indices.ptr, @intCast(operand_index), @intCast(operand_tuple_indices.len), operand_tuple_indices.ptr, - )); + ) }; } }; pub const PrecisionAttribute = struct { _inner: c.MlirAttribute, - pub usingnamespace mlir.MlirHelpers(PrecisionAttribute, .{ - .is_a_fn = c.stablehloAttributeIsAPrecisionAttr, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.stablehloAttributeIsAPrecisionAttr; const Self = PrecisionAttribute; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); pub const Precision = enum { DEFAULT, @@ -1126,11 +1104,11 @@ pub const PrecisionAttribute = struct { }; pub fn init(ctx: mlir.Context, value: Precision) Self { - return Self.wrap(c.stablehloPrecisionAttrGet(ctx.inner(), mlir.stringRef(@tagName(value)))); + return .{ ._inner = c.stablehloPrecisionAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; } pub fn getValue(self: Self) Precision { - const value = mlir.fromStringRef(c.stablehloPrecisionAttrGetValue(self.inner())); + const value = mlir.fromStringRef(c.stablehloPrecisionAttrGetValue(self._inner)); return std.meta.stringToEnum(Precision, value) orelse unreachable; } }; @@ -1138,13 +1116,10 @@ pub const PrecisionAttribute = struct { pub const ComparisonDirection = struct { _inner: c.MlirAttribute, - pub usingnamespace mlir.MlirHelpers(ComparisonDirection, .{ - .is_a_fn = c.stablehloAttributeIsAComparisonDirectionAttr, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.stablehloAttributeIsAComparisonDirectionAttr; const Self = ComparisonDirection; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); pub const Direction = enum { EQ, @@ -1156,11 +1131,11 @@ pub const ComparisonDirection = struct { }; pub fn init(ctx: mlir.Context, value: Direction) Self { - return Self.wrap(c.stablehloComparisonDirectionAttrGet(ctx.inner(), mlir.stringRef(@tagName(value)))); + return .{ ._inner = c.stablehloComparisonDirectionAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; } pub fn getValue(self: Self) Direction { - const value = mlir.fromStringRef(c.stablehloComparisonDirectionAttrGetValue(self.inner())); + const value = mlir.fromStringRef(c.stablehloComparisonDirectionAttrGetValue(self._inner)); return std.meta.stringToEnum(Direction, value) orelse unreachable; } }; @@ -1168,13 +1143,10 @@ pub const ComparisonDirection = struct { pub const CompareType = struct { _inner: c.MlirAttribute, - pub usingnamespace mlir.MlirHelpers(CompareType, .{ - .is_a_fn = c.stablehloAttributeIsAComparisonTypeAttr, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.stablehloAttributeIsAComparisonTypeAttr; const Self = CompareType; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); pub const Type = enum { SIGNED, @@ -1184,11 +1156,11 @@ pub const CompareType = struct { }; pub fn init(ctx: mlir.Context, value: Type) Self { - return Self.wrap(c.stablehloComparisonTypeAttrGet(ctx.inner(), mlir.stringRef(@tagName(value)))); + return .{ ._inner = c.stablehloComparisonTypeAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; } pub fn getValue(self: Self) Type { - const value = mlir.fromStringRef(c.stablehloComparisonTypeAttrGetValue(self.inner())); + const value = mlir.fromStringRef(c.stablehloComparisonTypeAttrGetValue(self._inner)); return std.meta.stringToEnum(Type, value) orelse unreachable; } }; @@ -1196,13 +1168,10 @@ pub const CompareType = struct { pub const Transpose = struct { _inner: c.MlirAttribute, - pub usingnamespace mlir.MlirHelpers(Transpose, .{ - .is_a_fn = c.stablehloAttributeIsATransposeAttr, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.stablehloAttributeIsATransposeAttr; const Self = Transpose; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); pub const Type = enum { NO_TRANSPOSE, @@ -1211,11 +1180,11 @@ pub const Transpose = struct { }; pub fn init(ctx: mlir.Context, value: Type) Self { - return Self.wrap(c.stablehloTransposeAttrGet(ctx.inner(), mlir.stringRef(@tagName(value)))); + return .{ ._inner = c.stablehloTransposeAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; } pub fn getValue(self: Self) Type { - const value = mlir.fromStringRef(c.stablehloTransposeAttrGetValue(self.inner())); + const value = mlir.fromStringRef(c.stablehloTransposeAttrGetValue(self._inner)); return std.meta.stringToEnum(Type, value) orelse unreachable; } }; @@ -1223,13 +1192,10 @@ pub const Transpose = struct { pub const FftType = struct { _inner: c.MlirAttribute, - pub usingnamespace mlir.MlirHelpers(FftType, .{ - .is_a_fn = c.stablehloAttributeIsAFftTypeAttr, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.stablehloAttributeIsAFftTypeAttr; const Self = FftType; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); pub const Type = enum { FFT, @@ -1239,11 +1205,11 @@ pub const FftType = struct { }; pub fn init(ctx: mlir.Context, value: Type) Self { - return Self.wrap(c.stablehloFftTypeAttrGet(ctx.inner(), mlir.stringRef(@tagName(value)))); + return .{ ._inner = c.stablehloFftTypeAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; } pub fn getValue(self: Self) Type { - const value = mlir.fromStringRef(c.stablehloFftTypeAttrGetValue(self.inner())); + const value = mlir.fromStringRef(c.stablehloFftTypeAttrGetValue(self._inner)); return std.meta.stringToEnum(Type, value) orelse unreachable; } }; @@ -1251,13 +1217,10 @@ pub const FftType = struct { pub const RngDistribution = struct { _inner: c.MlirAttribute, - pub usingnamespace mlir.MlirHelpers(RngDistribution, .{ - .is_a_fn = c.stablehloAttributeIsARngDistributionAttr, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.stablehloAttributeIsARngDistributionAttr; const Self = RngDistribution; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); pub const Type = enum { UNIFORM, @@ -1265,11 +1228,11 @@ pub const RngDistribution = struct { }; pub fn init(ctx: mlir.Context, value: Type) Self { - return Self.wrap(c.stablehloRngDistributionAttrGet(ctx.inner(), mlir.stringRef(@tagName(value)))); + return .{ ._inner = c.stablehloRngDistributionAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; } pub fn getValue(self: Self) Type { - const value = mlir.fromStringRef(c.stablehloRngDistributionAttrGetValue(self.inner())); + const value = mlir.fromStringRef(c.stablehloRngDistributionAttrGetValue(self._inner)); return std.meta.stringToEnum(Type, value) orelse unreachable; } }; @@ -1277,13 +1240,10 @@ pub const RngDistribution = struct { pub const RngAlgorithm = struct { _inner: c.MlirAttribute, - pub usingnamespace mlir.MlirHelpers(RngAlgorithm, .{ - .is_a_fn = c.stablehloAttributeIsARngAlgorithmAttr, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.stablehloAttributeIsARngAlgorithmAttr; const Self = RngAlgorithm; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); pub const Type = enum { DEFAULT, @@ -1292,11 +1252,11 @@ pub const RngAlgorithm = struct { }; pub fn init(ctx: mlir.Context, value: Type) Self { - return Self.wrap(c.stablehloRngAlgorithmAttrGet(ctx.inner(), mlir.stringRef(@tagName(value)))); + return .{ ._inner = c.stablehloRngAlgorithmAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; } pub fn getValue(self: Self) Type { - const value = mlir.fromStringRef(c.stablehloRngAlgorithmAttrGetValue(self.inner())); + const value = mlir.fromStringRef(c.stablehloRngAlgorithmAttrGetValue(self._inner)); return std.meta.stringToEnum(Type, value) orelse unreachable; } }; diff --git a/mlir/mlir.zig b/mlir/mlir.zig index a4f378c..918c304 100644 --- a/mlir/mlir.zig +++ b/mlir/mlir.zig @@ -37,151 +37,24 @@ pub fn successOr(res: c.MlirLogicalResult, err: anytype) !void { return if (res.value == 0) err; } -pub fn MlirTypeMethods(comptime InnerT: type) type { - return struct { - is_null_fn: ?fn (InnerT) callconv(.C) bool = null, - is_a_fn: ?fn (InnerT) callconv(.C) bool = null, - equal_fn: ?fn (InnerT, InnerT) callconv(.C) bool = null, - dump_fn: ?fn (InnerT) callconv(.C) void = null, - deinit_fn: ?fn (InnerT) callconv(.C) void = null, - }; -} - /// Alternative to MlirWrapperType pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.C) void; -fn MlirHelpersMethods(OuterT: type) type { - switch (@typeInfo(OuterT)) { - .@"struct" => |info| { - if (info.fields.len != 1) @compileError("Mlir wrapper type can only wrap one Mlir value. Received: " ++ @typeName(OuterT)); - }, - else => @compileError("MlirHelpersMethods is only available on an Mlir wrapper struct. Received: " ++ @typeName(OuterT)), - } - - return struct { - pub const InnerT = std.meta.FieldType(OuterT, ._inner); - comptime { - std.debug.assert(@sizeOf(InnerT) == @sizeOf(OuterT)); - } - - is_null_fn: ?fn (InnerT) callconv(.C) bool = null, - is_a_fn: ?fn (InnerT) callconv(.C) bool = null, - equal_fn: ?fn (InnerT, InnerT) callconv(.C) bool = null, - deinit_fn: ?fn (InnerT) callconv(.C) void = null, - dump_fn: ?fn (InnerT) callconv(.C) void = null, - print_fn: ?fn (InnerT, ?*const MlirStrCallback, ?*anyopaque) callconv(.C) void = null, - }; -} - -pub fn MlirHelpers(comptime OuterT: type, comptime methods: MlirHelpersMethods(OuterT)) type { - const InnerT = @TypeOf(methods).InnerT; - return struct { - pub const Methods = methods; - - pub inline fn wrap(raw: InnerT) OuterT { - return .{ ._inner = raw }; - } - - pub inline fn inner(self: OuterT) InnerT { - return self._inner; - } - - pub inline fn innerPtr(self: *OuterT) *InnerT { - return &self._inner; - } - - pub inline fn is_a(self: OuterT, comptime otherT: type) bool { - if (otherT.Methods.is_a_fn) |is_a_fn| { - return is_a_fn(self.inner()); - } - return false; - } - - pub inline fn as(self: OuterT, comptime OtherT: type) OtherT { - if (OtherT.Methods.is_a_fn) |is_a_fn| { - stdx.debug.assert(is_a_fn(self.inner()), "Wrongly tried to cast {} into {}", .{ OuterT, OtherT }); - return OtherT.wrap(self.inner()); - } - // if the other type doesn't have an is_a_fn, try. - return OtherT.wrap(self.inner()); - } - - pub usingnamespace if (Methods.is_null_fn) |is_null| struct { - pub inline fn wrapOr(raw: InnerT) ?OuterT { - return if (is_null(raw)) null else OuterT.wrap(raw); - } - } else struct {}; - - pub usingnamespace if (Methods.equal_fn) |equal| struct { - pub inline fn eql(self: OuterT, other: OuterT) bool { - return equal(self.inner(), other.inner()); - } - } else struct {}; - - pub usingnamespace if (Methods.deinit_fn) |_deinit| struct { - pub inline fn deinit(self: *OuterT) void { - _deinit(self.inner()); - self.* = undefined; - } - } else struct {}; - - pub usingnamespace if (Methods.dump_fn) |_dump| struct { - pub inline fn dump(self: OuterT) void { - return _dump(self.inner()); - } - } else struct {}; - - pub usingnamespace if (Methods.print_fn) |print| struct { - pub fn format( - self: OuterT, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - writer: anytype, - ) !void { - _ = fmt; - _ = options; - - const Writer = struct { - writer: @TypeOf(writer), - err: ?@TypeOf(writer).Error = null, - fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.C) void { - var ctx: *@This() = @alignCast(@ptrCast(opaque_ctx)); - if (ctx.err) |_| return; - _ = ctx.writer.write(mlir_str.data[0..mlir_str.length]) catch |err| { - ctx.err = err; - return; - }; - } - }; - - var context: Writer = .{ .writer = writer }; - print(self.inner(), &Writer.printCallback, &context); - if (context.err) |err| return err; - } - } else struct {}; - }; -} - pub const Registry = struct { _inner: c.MlirDialectRegistry, - pub usingnamespace MlirHelpers(Registry, .{ - .is_null_fn = c.mlirDialectRegistryIsNull, - .deinit_fn = c.mlirDialectRegistryDestroy, - }); - const Self = Registry; - pub fn init() !Self { - return Self.wrapOr(c.mlirDialectRegistryCreate()) orelse Error.MlirUnexpected; + pub const deinit = helpers.deinit(Registry, c.mlirDialectRegistryDestroy); + + pub fn init() !Registry { + return helpers.init(Registry, c.mlirDialectRegistryCreate(), c.mlirDialectRegistryIsNull) orelse Error.MlirUnexpected; } }; pub const Context = struct { _inner: c.MlirContext, - pub usingnamespace MlirHelpers(Context, .{ - .is_null_fn = c.mlirContextIsNull, - .deinit_fn = c.mlirContextDestroy, - }); const Self = Context; + pub const deinit = helpers.deinit(Context, c.mlirContextDestroy); + pub const wrapOr = helpers.wrapOr(Context, c.mlirContextIsNull); pub fn init() !Self { return Self.wrapOr(c.mlirContextCreate()) orelse Error.MlirUnexpected; @@ -189,32 +62,32 @@ pub const Context = struct { pub fn initWithRegistry(registry: Registry, threadingEnabled: bool) !Self { return Self.wrapOr( - c.mlirContextCreateWithRegistry(registry.inner(), threadingEnabled), + c.mlirContextCreateWithRegistry(registry._inner, threadingEnabled), ) orelse Error.InvalidMlir; } pub fn setMultiThreading(self: *Self, enabled: bool) void { - c.mlirContextEnableMultithreading(self.inner(), enabled); + c.mlirContextEnableMultithreading(self._inner, enabled); } pub fn appendDialectRegistry(self: *Self, registry: Registry) void { - c.mlirContextAppendDialectRegistry(self.inner(), registry.inner()); + c.mlirContextAppendDialectRegistry(self._inner, registry._inner); } pub fn loadAllAvailableDialects(self: *Self) void { - c.mlirContextLoadAllAvailableDialects(self.inner()); + c.mlirContextLoadAllAvailableDialects(self._inner); } pub fn numRegisteredDialects(self: Self) usize { - return @intCast(c.mlirContextGetNumRegisteredDialects(self.inner())); + return @intCast(c.mlirContextGetNumRegisteredDialects(self._inner)); } pub fn numLoadedDialects(self: Self) usize { - return @intCast(c.mlirContextGetNumLoadedDialects(self.inner())); + return @intCast(c.mlirContextGetNumLoadedDialects(self._inner)); } pub fn isRegisteredOperation(self: Self, op: [:0]const u8) bool { - return c.mlirContextIsRegisteredOperation(self.inner(), stringRef(op)); + return c.mlirContextIsRegisteredOperation(self._inner, stringRef(op)); } pub fn location(self: Self, src: std.builtin.SourceLocation) Location { @@ -224,36 +97,36 @@ pub const Context = struct { pub const Module = struct { _inner: c.MlirModule, - pub usingnamespace MlirHelpers(Module, .{ - .is_null_fn = c.mlirModuleIsNull, - .deinit_fn = c.mlirModuleDestroy, - }); + + pub const deinit = helpers.deinit(Module, c.mlirModuleDestroy); + pub const wrapOr = helpers.wrapOr(Module, c.mlirModuleIsNull); + const Self = Module; pub fn init(loc: Location) Self { - return Self.wrap(c.mlirModuleCreateEmpty(loc.inner())); + return .{ ._inner = c.mlirModuleCreateEmpty(loc._inner) }; } pub fn parse(ctx: Context, source: [:0]const u8) !Module { return Module.wrapOr( - c.mlirModuleCreateParse(ctx.inner(), stringRef(source)), + c.mlirModuleCreateParse(ctx._inner, stringRef(source)), ) orelse Error.InvalidMlir; } pub fn fromOperation(operation: Operation) Module { - return Module.wrap(c.mlirModuleFromOperation(operation.inner())); + return .{ ._inner = c.mlirModuleFromOperation(operation._inner) }; } pub fn context(self: Module) Context { - return Context.wrap(c.mlirModuleGetContext(self.inner())); + return .{ ._inner = c.mlirModuleGetContext(self._inner) }; } pub fn getBody(self: Module) Block { - return Block.wrap(c.mlirModuleGetBody(self.inner())); + return .{ ._inner = c.mlirModuleGetBody(self._inner) }; } pub fn op(self: Module) Operation { - return Operation.wrap(c.mlirModuleGetOperation(self.inner())); + return .{ ._inner = c.mlirModuleGetOperation(self._inner) }; } pub fn hash(self: Module, hasher: *std.hash.XxHash64) void { @@ -264,34 +137,33 @@ pub const Module = struct { pub const PassManager = struct { _inner: c.MlirPassManager, - pub usingnamespace MlirHelpers(PassManager, .{ - .is_null_fn = c.mlirPassManagerIsNull, - .deinit_fn = c.mlirPassManagerDestroy, - }); + pub const deinit = helpers.deinit(PassManager, c.mlirPassManagerDestroy); + pub const wrapOr = helpers.wrapOr(PassManager, c.mlirPassManagerIsNull); + const Self = PassManager; pub fn init(ctx: Context) !Self { return Self.wrapOr( - c.mlirPassManagerCreate(ctx.inner()), + c.mlirPassManagerCreate(ctx._inner), ) orelse Error.MlirUnexpected; } pub fn initOnOperation(ctx: Context, op: [:0]const u8) !Self { return Self.wrapOr( - c.mlirPassManagerCreateOnOperation(ctx.inner(), stringRef(op)), + c.mlirPassManagerCreateOnOperation(ctx._inner, stringRef(op)), ) orelse Error.MlirUnexpected; } pub fn asOpPassManager(self: Self) OpPassManager { - return OpPassManager.wrap(c.mlirPassManagerGetAsOpPassManager(self.inner())); + return .{ ._inner = c.mlirPassManagerGetAsOpPassManager(self._inner) }; } pub fn enableIRPrinting(self: *Self) void { - c.mlirPassManagerEnableIRPrinting(self.inner()); + c.mlirPassManagerEnableIRPrinting(self._inner); } pub fn runOnOp(self: *Self, op: Operation) error{InvalidMlir}!void { - if (c.mlirPassManagerRunOnOp(self.inner(), op.inner()).value == 0) { + if (c.mlirPassManagerRunOnOp(self._inner, op._inner).value == 0) { return Error.InvalidMlir; } } @@ -304,11 +176,10 @@ fn _mlir_passpipeline_error(err: c.MlirStringRef, ctx: ?*anyopaque) callconv(.C) pub const OpPassManager = struct { _inner: c.MlirOpPassManager, - pub usingnamespace MlirHelpers(OpPassManager, .{}); pub fn addPipeline(self: *OpPassManager, pipeline: [:0]const u8) error{OutOfMemory}!void { if (c.mlirOpPassManagerAddPipeline( - self.inner(), + self._inner, stringRef(pipeline), &_mlir_passpipeline_error, null, @@ -320,23 +191,22 @@ pub const OpPassManager = struct { pub const Identifier = struct { _inner: c.MlirIdentifier, - pub usingnamespace MlirHelpers(Identifier, .{}); const Self = Identifier; pub fn get(ctx: Context, str_: [:0]const u8) Self { - return Self.wrap(c.mlirIdentifierGet(ctx.inner(), stringRef(str_))); + return .{ ._inner = c.mlirIdentifierGet(ctx._inner, stringRef(str_)) }; } pub fn context(self: Self) Context { - return Context.wrap(c.mlirIdentifierGetContext(self.inner())); + return .{ ._inner = c.mlirIdentifierGetContext(self._inner) }; } pub fn str(self: Self) []const u8 { - return fromStringRef(c.mlirIdentifierStr(self.inner())); + return fromStringRef(c.mlirIdentifierStr(self._inner)); } pub fn equals(self: Self, other: Self) bool { - return c.mlirIdentifierEqual(self.inner(), other.inner()); + return c.mlirIdentifierEqual(self._inner, other._inner); } }; @@ -344,19 +214,34 @@ pub const AttrTuple = struct { [:0]const u8, Attribute }; pub const Attribute = struct { _inner: c.MlirAttribute, - pub usingnamespace MlirHelpers(Attribute, .{ - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); - const Self = Attribute; + + pub const dump = helpers.dump(Attribute, c.mlirAttributeDump); + pub const eql = helpers.eql(Attribute, c.mlirAttributeEqual); + pub const format = helpers.format(Attribute, c.mlirAttributePrint); + pub const wrapOr = helpers.wrapOr(Attribute, c.mlirAttributeIsNull); + + pub fn wrap(c_attr: c.MlirAttribute) Attribute { + return .{ ._inner = c_attr }; + } pub fn parse(ctx: Context, attr: [:0]const u8) !Attribute { return Attribute.wrapOr( - c.mlirAttributeParseGet(ctx.inner(), stringRef(attr)), + c.mlirAttributeParseGet(ctx._inner, stringRef(attr)), ) orelse Error.InvalidMlir; } + pub fn fromAny(SpecificAttr: type) fn (x: SpecificAttr) Attribute { + return struct { + fn cast(x: SpecificAttr) Attribute { + return .{ ._inner = x._inner }; + } + }.cast; + } + + pub fn isA(self: Attribute, SpecificAttr: type) bool { + return SpecificAttr.is_a_fn(self._inner); + } + // utilities function to built common attributes. // All attributes are upcasted to the Attribute type, making it easier to chain construct, // but losing type information. @@ -374,7 +259,7 @@ pub const Attribute = struct { } pub fn unit(ctx: Context) Attribute { - return .wrap(c.mlirUnitAttrGet(ctx.inner())); + return .wrap(c.mlirUnitAttrGet(ctx._inner)); } pub fn boolean(ctx: Context, value: bool) Attribute { @@ -390,7 +275,7 @@ pub const Attribute = struct { } pub fn float(ctx: Context, comptime float_type: FloatTypes, value: f64) Attribute { - return .wrap(FloatAttribute(float_type).init(ctx, value)._inner); + return FloatAttribute(float_type).init(ctx, value).asAttr(); } pub fn array(ctx: Context, attrs: []const Attribute) Attribute { @@ -403,16 +288,25 @@ pub const Attribute = struct { /// Use a tensor as an attribute. /// The tensor is specified by dims, dtype and a flat slice of values. - pub fn denseElements(ctx: Context, dims: []const i64, comptime dt: DenseElementsAttributeTypes, values: anytype) Attribute { - return .wrap(DenseElementsAttribute(dt).init(.tensor(dims, dt.mlirType(ctx)), values).inner()); + pub fn denseElements(ctx: Context, dims: []const i64, comptime dt: DenseElementsAttributeTypes, values: []const dt.ZigType()) Attribute { + return DenseElementsAttribute(dt).init(.tensor(dims, dt.mlirType(ctx)), values).asAttr(); + } + + pub fn denseElementsFromBytes(ctx: Context, dims: []const i64, dt: DenseElementsAttributeTypes, raw_bytes: []const u8) Attribute { + const shape: Type = .tensor(dims, dt.mlirType(ctx)); + return .{ ._inner = c.mlirDenseElementsAttrRawBufferGet( + shape._inner, + @intCast(raw_bytes.len), + raw_bytes.ptr, + ) }; } pub fn symbol(ctx: Context, flat_name: [:0]const u8) Attribute { - return .wrap(FlatSymbolRefAttribute.init(ctx, flat_name).inner()); + return FlatSymbolRefAttribute.init(ctx, flat_name).asAttr(); } pub fn named(attr: Attribute, ctx: Context, name: [:0]const u8) NamedAttribute { - return NamedAttribute.init(Identifier.get(ctx, name), attr); + return NamedAttribute.named(ctx, name, attr); } pub fn dict(ctx: Context, named_attrs: []const AttrTuple) Attribute { @@ -429,115 +323,92 @@ pub const Attribute = struct { }; pub const NamedAttribute = extern struct { - name: c.MlirIdentifier, - attribute: c.MlirAttribute, + _inner: c.MlirNamedAttribute, + + pub fn wrap(c_named_attribute: c.MlirNamedAttribute) NamedAttribute { + return @bitCast(c_named_attribute); + } pub fn named(ctx: Context, name: [:0]const u8, attr: Attribute) NamedAttribute { - return .{ + return .{ ._inner = .{ .name = c.mlirIdentifierGet(ctx._inner, stringRef(name)), - .attribute = attr.inner(), - }; + .attribute = attr._inner, + } }; } pub fn init(name: Identifier, attr: Attribute) NamedAttribute { - return .{ - .name = name.inner(), - .attribute = attr.inner(), - }; + return .{ ._inner = .{ + .name = name._inner, + .attribute = attr._inner, + } }; } }; pub const StringAttribute = struct { _inner: c.MlirAttribute, - pub usingnamespace MlirHelpers(StringAttribute, .{ - .is_a_fn = c.mlirAttributeIsAString, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.mlirAttributeIsAString; const Self = StringAttribute; + pub const asAttr = Attribute.fromAny(Self); + pub const eql = Attribute.eqlAny(Self); pub fn init(ctx: Context, str: []const u8) Self { - return Self.wrap(c.mlirStringAttrGet(ctx.inner(), stringRef(str))); + return .{ ._inner = c.mlirStringAttrGet(ctx._inner, stringRef(str)) }; } pub fn value(self: Self) []const u8 { - return fromStringRef(c.mlirStringAttrGetValue(self.inner())); - } - - pub fn asAttr(self: StringAttribute) Attribute { - return .{ ._inner = self._inner }; + return fromStringRef(c.mlirStringAttrGetValue(self._inner)); } }; pub const BoolAttribute = struct { _inner: c.MlirAttribute, - pub usingnamespace MlirHelpers(BoolAttribute, .{ - .is_a_fn = c.mlirAttributeIsABool, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.mlirAttributeIsABool; const Self = BoolAttribute; + pub const asAttr = Attribute.fromAny(Self); + pub const eql = Attribute.eqlAny(Self); pub fn init(ctx: Context, value_: bool) Self { - return Self.wrap(c.mlirBoolAttrGet(ctx.inner(), if (value_) 1 else 0)); + return .{ ._inner = c.mlirBoolAttrGet(ctx._inner, if (value_) 1 else 0) }; } pub fn value(self: Self) bool { - return c.mlirBoolAttrGetValue(self.inner()); - } - - pub fn asAttr(self: Self) Attribute { - return self.as(Attribute); + return c.mlirBoolAttrGetValue(self._inner); } }; pub const TypeAttribute = struct { _inner: c.MlirAttribute, - pub usingnamespace MlirHelpers(TypeAttribute, .{ - .is_a_fn = c.mlirAttributeIsAType, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.mlirAttributeIsAType; + pub const eql = Attribute.eqlAny(TypeAttribute); + pub fn init(type_: Type) TypeAttribute { - return TypeAttribute.wrap(c.mlirTypeAttrGet(type_.inner())); + return .{ ._inner = c.mlirTypeAttrGet(type_._inner) }; } pub fn typ(self: TypeAttribute) Type { - return Type.wrap(c.mlirAttributeGetType(self.inner())); + return .{ ._inner = c.mlirAttributeGetType(self._inner) }; } - pub fn asAttr(self: TypeAttribute) Attribute { - return self.as(Attribute); - } + pub const asAttr = Attribute.fromAny(TypeAttribute); }; pub const ArrayAttribute = struct { _inner: c.MlirAttribute, - pub usingnamespace MlirHelpers(ArrayAttribute, .{ - .is_a_fn = c.mlirAttributeIsAArray, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.mlirAttributeIsAArray; const Self = ArrayAttribute; + pub const asAttr = Attribute.fromAny(Self); + pub const eql = Attribute.eqlAny(Self); pub fn init(ctx: Context, attrs: []const Attribute) Self { - return Self.wrap(c.mlirArrayAttrGet(ctx.inner(), @intCast(attrs.len), @ptrCast(attrs.ptr))); + return .{ ._inner = c.mlirArrayAttrGet(ctx._inner, @intCast(attrs.len), @ptrCast(attrs.ptr)) }; } pub fn size(self: Self) usize { - return @intCast(c.mlirArrayAttrGetNumElements(self.inner())); + return @intCast(c.mlirArrayAttrGetNumElements(self._inner)); } pub fn get(self: Self, index: usize) Attribute { - return Attribute.wrap(c.mlirArrayAttrGetElement(self.inner(), @intCast(index))); - } - - pub fn asAttr(self: Self) Attribute { - return .{ ._inner = self._inner }; + return .{ ._inner = c.mlirArrayAttrGetElement(self._inner, @intCast(index)) }; } }; @@ -551,28 +422,23 @@ pub fn IntegerAttribute(comptime it: IntegerTypes) type { return struct { _inner: c.MlirAttribute, - pub usingnamespace MlirHelpers(@This(), .{ - .is_a_fn = c.mlirAttributeIsAInteger, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.mlirAttributeIsAInteger; + pub const IntegerTypeType = IntegerType(it); const IntAttr = @This(); + pub const asAttr = Attribute.fromAny(IntAttr); + pub const eql = Attribute.eqlAny(IntAttr); + pub fn init(ctx: Context, value: i64) IntAttr { - return IntAttr.wrap(c.mlirIntegerAttrGet( - IntegerType(it).init(ctx).inner(), + return .{ ._inner = c.mlirIntegerAttrGet( + IntegerType(it).init(ctx)._inner, value, - )); + ) }; } pub fn get(value: IntAttr) ZigType { - return @intCast(getter(value.inner())); - } - - pub fn asAttr(self: IntAttr) Attribute { - return .{ ._inner = self._inner }; + return @intCast(getter(value._inner)); } }; } @@ -580,23 +446,20 @@ pub fn IntegerAttribute(comptime it: IntegerTypes) type { pub fn FloatAttribute(comptime ft: FloatTypes) type { return struct { _inner: c.MlirAttribute, - pub usingnamespace MlirHelpers(@This(), .{ - .is_a_fn = c.mlirAttributeIsAFloat, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.mlirAttributeIsAFloat; const FloatAttr = @This(); + pub const asAttr = Attribute.fromAny(FloatAttr); + pub fn init(ctx: Context, value: f64) FloatAttr { - return FloatAttr.wrap(c.mlirFloatAttrDoubleGet( - ctx.inner(), - FloatType(ft).init(ctx).inner(), + return .{ ._inner = c.mlirFloatAttrDoubleGet( + ctx._inner, + FloatType(ft).init(ctx)._inner, value, - )); + ) }; } pub fn get(value: FloatAttr) f64 { - return c.mlirFloatAttrGetValueDouble(value.inner()); + return c.mlirFloatAttrGetValueDouble(value._inner); } }; } @@ -624,7 +487,7 @@ pub const DenseArrayTypes = enum { }; pub fn DenseArrayAttribute(comptime dt: DenseArrayTypes) type { - const is_a_fn, const get_fn, const get_element_fn = switch (dt) { + const _is_a_fn, const get_fn, const get_element_fn = switch (dt) { .bool => .{ c.mlirAttributeIsADenseBoolArray, c.mlirDenseBoolArrayGet, c.mlirDenseBoolArrayGetElement }, .i8 => .{ c.mlirAttributeIsADenseI8Array, c.mlirDenseI8ArrayGet, c.mlirDenseI8ArrayGetElement }, .i16 => .{ c.mlirAttributeIsADenseI16Array, c.mlirDenseI16ArrayGet, c.mlirDenseI16ArrayGetElement }, @@ -636,42 +499,25 @@ pub fn DenseArrayAttribute(comptime dt: DenseArrayTypes) type { return struct { _inner: c.MlirAttribute, - pub usingnamespace MlirHelpers(@This(), .{ - .is_a_fn = is_a_fn, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); const Attr = @This(); const ElementType = dt; const ElementTypeZig = dt.ZigType(); + pub const asAttr = Attribute.fromAny(Attr); + pub const eql = Attribute.eqlAny(Attr); + pub const is_a_fn = _is_a_fn; + pub fn init(ctx: Context, values: []const ElementTypeZig) Attr { - return Attr.wrap(get_fn(ctx.inner(), @intCast(values.len), @ptrCast(values.ptr))); + return .{ ._inner = get_fn(ctx._inner, @intCast(values.len), @ptrCast(values.ptr)) }; } pub fn get(self: Attr, pos: usize) ElementTypeZig { - return get_element_fn(self.inner(), @intCast(pos)); + return get_element_fn(self._inner, @intCast(pos)); } pub fn len(self: Attr) usize { - return @intCast(c.mlirDenseArrayGetNumElements(self.inner())); + return @intCast(c.mlirDenseArrayGetNumElements(self._inner)); } - - pub fn asAttr(self: Attr) Attribute { - return Attribute.wrap(self._inner); - } - - pub usingnamespace switch (dt) { - .bool, .i64 => struct { - const DenseArray = DenseArrayAttribute(switch (dt) { - .bool => .bool, - .i64 => .i64, - else => @compileError("DenseArrayAttribute: unreachable"), - }); - }, - else => struct {}, - }; }; } @@ -736,155 +582,140 @@ pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type { const Attr = @This(); - pub usingnamespace MlirHelpers(Attr, .{ - .is_a_fn = c.mlirAttributeIsADenseElements, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.mlirAttributeIsADenseElements; + pub const asAttr = Attribute.fromAny(Attr); + pub const eql = Attribute.eqlAny(Attr); - pub fn init(shaped_type: Type, slice: anytype) Attr { - const bytes = std.mem.sliceAsBytes(slice); - const v = Attr.wrapOr( - c.mlirDenseElementsAttrRawBufferGet( - shaped_type.inner(), - @intCast(bytes.len), - @ptrCast(bytes.ptr), - ), - ) orelse unreachable; - return v; + pub fn init(shaped_type: Type, slice: []const dt.ZigType()) Attr { + const raw_bytes = std.mem.sliceAsBytes(slice); + const res: Attr = .{ ._inner = c.mlirDenseElementsAttrRawBufferGet( + shaped_type._inner, + @intCast(raw_bytes.len), + @ptrCast(raw_bytes.ptr), + ) }; + std.debug.assert(Attribute.wrapOr(res._inner) != null); + return res; } pub fn len(self: Attr) usize { - return @intCast(c.mlirElementsAttrGetNumElements(self.inner())); + return @intCast(c.mlirElementsAttrGetNumElements(self._inner)); } - pub fn constSlice(self: Attr) []const dt.ZigType() { - const ptr: [*]const dt.ZigType() = @constCast(@ptrCast(@alignCast(c.mlirDenseElementsAttrGetRawData(self.inner()) orelse unreachable))); + pub fn items(self: Attr) []const dt.ZigType() { + const raw_bytes: [*]const u8 = c.mlirDenseElementsAttrGetRawData(self._inner) orelse unreachable; + const ptr: [*]const dt.ZigType() = @alignCast(@ptrCast(raw_bytes)); + // Note the mlir API returns us the number of elements, not the number of bytes, + // that's why we track the element type at comptime to allow items to work. return ptr[0..self.len()]; } - pub fn data(self: Attr) []const u8 { - return std.mem.sliceAsBytes(self.constSlice()); + pub fn bytes(self: Attr) []const u8 { + return std.mem.sliceAsBytes(self.items()); } }; } pub const FlatSymbolRefAttribute = struct { _inner: c.MlirAttribute, - pub usingnamespace MlirHelpers(FlatSymbolRefAttribute, .{ - .is_a_fn = c.mlirAttributeIsAFlatSymbolRef, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); - + pub const is_a_fn = c.mlirAttributeIsAFlatSymbolRef; const Self = FlatSymbolRefAttribute; + pub const eql = Attribute.eqlAny(Self); pub fn init(ctx: Context, str: [:0]const u8) Self { - return Self.wrap(c.mlirFlatSymbolRefAttrGet(ctx.inner(), stringRef(str))); + return .{ ._inner = c.mlirFlatSymbolRefAttrGet(ctx._inner, stringRef(str)) }; } pub fn value(self: Self) []const u8 { - return fromStringRef(c.mlirFlatSymbolRefAttrGetValue(self.inner())); + return fromStringRef(c.mlirFlatSymbolRefAttrGetValue(self._inner)); } + + pub const asAttr = Attribute.fromAny(Self); }; pub const OperationState = struct { _inner: c.MlirOperationState, - pub usingnamespace MlirHelpers( - OperationState, - .{}, - ); const Self = OperationState; pub fn init(name: [:0]const u8, loc: Location) Self { - return Self.wrap(c.mlirOperationStateGet(stringRef(name), loc.inner())); + return .{ ._inner = c.mlirOperationStateGet(stringRef(name), loc._inner) }; } pub fn addResult(self: *Self, type_: Type) void { - c.mlirOperationStateAddResults(self.innerPtr(), 1, &[_]c.MlirType{type_.inner()}); + c.mlirOperationStateAddResults(&self._inner, 1, &[_]c.MlirType{type_._inner}); } pub fn addResults(self: *Self, types: []const Type) void { - c.mlirOperationStateAddResults(self.innerPtr(), @intCast(types.len), @ptrCast(types.ptr)); + c.mlirOperationStateAddResults(&self._inner, @intCast(types.len), @ptrCast(types.ptr)); } pub fn addOperand(self: *Self, value: Value) void { - c.mlirOperationStateAddOperands(self.innerPtr(), 1, &[_]c.MlirValue{value.inner()}); + c.mlirOperationStateAddOperands(&self._inner, 1, &[_]c.MlirValue{value._inner}); } pub fn addOperands(self: *Self, values: []const Value) void { - c.mlirOperationStateAddOperands(self.innerPtr(), @intCast(values.len), @ptrCast(values.ptr)); + c.mlirOperationStateAddOperands(&self._inner, @intCast(values.len), @ptrCast(values.ptr)); } pub fn addRegion(self: *Self, region: *Region) void { - c.mlirOperationStateAddOwnedRegions(self.innerPtr(), 1, &[_]c.MlirRegion{region.inner()}); + c.mlirOperationStateAddOwnedRegions(&self._inner, 1, &[_]c.MlirRegion{region._inner}); } pub fn addRegions(self: *Self, regions: []const Region) void { - c.mlirOperationStateAddOwnedRegions(self.innerPtr(), @intCast(regions.len), @ptrCast(regions.ptr)); + c.mlirOperationStateAddOwnedRegions(&self._inner, @intCast(regions.len), @ptrCast(regions.ptr)); } pub fn addAttribute(self: *Self, ctx: Context, name: [:0]const u8, attr: Attribute) void { - c.mlirOperationStateAddAttributes(self.innerPtr(), 1, @ptrCast(&.{ + c.mlirOperationStateAddAttributes(&self._inner, 1, @ptrCast(&.{ .{ - .name = Identifier.get(ctx, name).inner(), - .attribute = attr.inner(), + .name = Identifier.get(ctx, name)._inner, + .attribute = attr._inner, }, })); } pub fn addAttributeRaw(self: *Self, name: Identifier, attr: Attribute) void { - c.mlirOperationStateAddAttributes(self.innerPtr(), 1, @ptrCast(&.{ + c.mlirOperationStateAddAttributes(&self._inner, 1, @ptrCast(&.{ .{ - .name = name.inner(), - .attribute = attr.inner(), + .name = name._inner, + .attribute = attr._inner, }, })); } pub fn addAttributes(self: *Self, attributes: []const NamedAttribute) void { - c.mlirOperationStateAddAttributes(self.innerPtr(), @intCast(attributes.len), @ptrCast(attributes.ptr)); + c.mlirOperationStateAddAttributes(&self._inner, @intCast(attributes.len), @ptrCast(attributes.ptr)); } pub fn resultTypeInference(self: *Self, enabled: bool) void { - self.innerPtr().enableResultTypeInference = enabled; + self._inner.enableResultTypeInference = enabled; } }; pub const DictionaryAttribute = struct { _inner: c.MlirAttribute, - pub usingnamespace MlirHelpers(DictionaryAttribute, .{ - .is_a_fn = c.mlirAttributeIsADictionary, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); + pub const is_a_fn = c.mlirAttributeIsADictionary; + pub const asAttr = Attribute.fromAny(DictionaryAttribute); + pub const eql = Attribute.eqlAny(DictionaryAttribute); pub fn init(ctx: Context, attributes: []const NamedAttribute) DictionaryAttribute { - return DictionaryAttribute.wrap(c.mlirDictionaryAttrGet( - ctx.inner(), + return .{ ._inner = c.mlirDictionaryAttrGet( + ctx._inner, @intCast(attributes.len), @ptrCast(attributes.ptr), - )); + ) }; } pub fn size(self: DictionaryAttribute) usize { - return @intCast(c.mlirDictionaryAttrGetNumElements(self.inner())); + return @intCast(c.mlirDictionaryAttrGetNumElements(self._inner)); } pub fn get(self: DictionaryAttribute, pos: usize) NamedAttribute { - return NamedAttribute.wrap(c.mlirDictionaryAttrGetElement(self.inner(), @intCast(pos))); + return .wrap(c.mlirDictionaryAttrGetElement(self._inner, @bitCast(pos))); } - pub fn getByName(self: DictionaryAttribute, name: [:0]const u8) ?NamedAttribute { - return NamedAttribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self.inner(), name)); - } - - pub fn asAttr(self: DictionaryAttribute) Attribute { - return .{ ._inner = self._inner }; + pub fn getByName(self: DictionaryAttribute, name: [:0]const u8) ?Attribute { + return Attribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self._inner, name)); } }; @@ -892,20 +723,14 @@ pub const Operation = struct { const Self = Operation; _inner: c.MlirOperation, - pub usingnamespace MlirHelpers( - Operation, - .{ - .is_null_fn = c.mlirOperationIsNull, - .deinit_fn = c.mlirOperationDestroy, - .dump_fn = c.mlirOperationDump, - .equal_fn = c.mlirOperationEqual, - }, - ); + pub const dump = helpers.dump(Operation, c.mlirOperationDestroy); + pub const deinit = helpers.deinit(Operation, c.mlirOperationDestroy); + pub const wrapOr = helpers.wrapOr(Operation, c.mlirOperationIsNull); + + pub const eql = Attribute.eqlAny(Self); pub fn init(state: *OperationState) !Self { - return Self.wrapOr( - c.mlirOperationCreate(state.innerPtr()), - ) orelse Error.InvalidMlir; + return Self.wrapOr(c.mlirOperationCreate(&state._inner)) orelse Error.InvalidMlir; } pub fn make(ctx: Context, op_name: [:0]const u8, args: struct { @@ -992,66 +817,66 @@ pub const Operation = struct { pub fn initParse(ctx: Context, str: [:0]const u8) !Self { return Self.wrapOr( - c.mlirOperationCreateParse(ctx.inner(), stringRef(str), stringRef("pouet")), + c.mlirOperationCreateParse(ctx._inner, stringRef(str), stringRef("pouet")), ) orelse Error.InvalidMlir; } pub fn clone(self: Self) !Self { return Self.wrapOr( - c.mlirOperationClone(self.inner()), + c.mlirOperationClone(self._inner), ) orelse Error.InvalidMlir; } pub fn name(self: Self) Identifier { - return Identifier.wrap(c.mlirOperationGetName(self.inner())); + return .{ ._inner = c.mlirOperationGetName(self._inner) }; } pub fn removeFromParent(self: *Self) void { - c.mlirOperationRemoveFromParent(self.inner()); + c.mlirOperationRemoveFromParent(self._inner); } pub fn numOperands(self: Self) usize { - return @intCast(c.mlirOperationGetNumOperands(self.inner())); + return @intCast(c.mlirOperationGetNumOperands(self._inner)); } pub fn operand(self: Self, index: usize) Value { - return Value.wrap(c.mlirOperationGetOperand(self.inner(), @intCast(index))); + return .{ ._inner = c.mlirOperationGetOperand(self._inner, @intCast(index)) }; } pub fn setOperand(self: *Self, index: usize, value: Value) void { - c.mlirOperationSetOperand(self.inner(), @intCast(index), value.inner()); + c.mlirOperationSetOperand(self._inner, @intCast(index), value._inner); } pub fn numResults(self: Self) usize { - return @intCast(c.mlirOperationGetNumResults(self.inner())); + return @intCast(c.mlirOperationGetNumResults(self._inner)); } pub fn result(self: Self, index: usize) Value { - return Value.wrap(c.mlirOperationGetResult(self.inner(), @intCast(index))); + return .{ ._inner = c.mlirOperationGetResult(self._inner, @intCast(index)) }; } pub fn nextInBlock(self: Self) Self { - return Self.wrap(c.mlirOperationGetNextInBlock(self.inner())); + return .{ ._inner = c.mlirOperationGetNextInBlock(self._inner) }; } // pub fn previousInBlock(self: Self) Self { - // return Self.wrap(c.mlirOperationGetPrevInBlock(self.inner())); + // return .{ ._inner = c.mlirOperationGetPrevInBlock(self._inner) }; // } pub fn block(self: Self) ?Block { - return Block.wrapOr(c.mlirOperationGetBlock(self.inner())); + return Block.wrapOr(c.mlirOperationGetBlock(self._inner)); } pub fn parent(self: Self) ?Self { - return Self.wrapOr(c.mlirOperationGetParentOperation(self.inner())); + return Self.wrapOr(c.mlirOperationGetParentOperation(self._inner)); } pub fn region(self: Self, index: usize) Region { - return Region.wrap(c.mlirOperationGetRegion(self.inner(), @intCast(index))); + return .{ ._inner = c.mlirOperationGetRegion(self._inner, @intCast(index)) }; } pub fn context(self: Self) Context { - return Context.wrap(c.mlirOperationGetContext(self.inner())); + return .{ ._inner = c.mlirOperationGetContext(self._inner) }; } pub fn writeBytecode(self: Self, writer: anytype) void { @@ -1059,7 +884,7 @@ pub const Operation = struct { const WriterContext = @TypeOf(writer_context); c.mlirOperationWriteBytecode( - self.inner(), + self._inner, (struct { pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.C) void { const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_)); @@ -1086,7 +911,7 @@ pub const Operation = struct { var writer_context: WriterContext = .{ .writer = writer }; try successOr(c.mlirOperationWriteBytecodeWithConfig( - self.inner(), + self._inner, cfg, (struct { pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.C) void { @@ -1126,7 +951,7 @@ pub const Operation = struct { var writer_context = .{ .writer = writer }; const WriterContext = @TypeOf(writer_context); c.mlirOperationPrintWithFlags( - self.inner(), + self._inner, pflags, (struct { pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.C) void { @@ -1139,11 +964,11 @@ pub const Operation = struct { } pub fn verify(self: Self) bool { - return c.mlirOperationVerify(self.inner()); + return c.mlirOperationVerify(self._inner); } pub fn getLocation(self: Self) Location { - return Location.wrap(c.mlirOperationGetLocation(self.inner())); + return .{ ._inner = c.mlirOperationGetLocation(self._inner) }; } pub const WalkOrder = enum(c.MlirWalkOrder) { @@ -1162,11 +987,11 @@ pub const Operation = struct { const ContextType = @TypeOf(inner_ctx); c.mlirOperationWalk( - self.inner(), + self._inner, (struct { pub fn callback(op: c.MlirOperation, ctx_: ?*anyopaque) callconv(.C) c.MlirWalkResult { const inner_ctx_: *ContextType = @ptrCast(@alignCast(ctx_)); - return @intFromEnum(walkfn(inner_ctx_.ctx, Operation.wrap(op))); + return @intFromEnum(walkfn(inner_ctx_.ctx, .{ ._inner = op })); } }).callback, &inner_ctx, @@ -1175,19 +1000,19 @@ pub const Operation = struct { } pub fn getAttribute(self: Self, pos: usize) NamedAttribute { - return NamedAttribute.wrap(c.mlirOperationGetAttribute(self.inner(), @intCast(pos))); + return .{ ._inner = c.mlirOperationGetAttribute(self._inner, @intCast(pos)) }; } pub fn getAttributeByName(self: Self, name_: [:0]const u8) ?Attribute { - return Attribute.wrapOr(c.mlirOperationGetAttributeByName(self.inner(), stringRef(name_))); + return Attribute.wrapOr(c.mlirOperationGetAttributeByName(self._inner, stringRef(name_))); } pub fn setAttributeByName(self: Self, name_: [:0]const u8, attr: Attribute) void { - c.mlirOperationSetAttributeByName(self.inner(), stringRef(name_), attr.inner()); + c.mlirOperationSetAttributeByName(self._inner, stringRef(name_), attr._inner); } pub fn removeAttributeByName(self: Self, name_: [:0]const u8) bool { - return c.mlirOperationRemoveAttributeByName(self.inner(), stringRef(name_)); + return c.mlirOperationRemoveAttributeByName(self._inner, stringRef(name_)); } pub fn hash(op: Operation, hasher: *std.hash.XxHash64) void { @@ -1240,100 +1065,91 @@ pub const OpPrintingFlags = struct { pub const OpOperand = struct { _inner: c.MlirOpOperand, - pub usingnamespace MlirHelpers(OpOperand, .{ - .is_null_fn = c.mlirOpOperandIsNull, - }); - const Self = OpOperand; pub fn owner(self: Self) Operation { - return Operation.wrap(c.mlirOpOperandGetOwner(self.inner())); + return .{ ._inner = c.mlirOpOperandGetOwner(self._inner) }; } pub fn number(self: Self) usize { - return @intCast(c.mlirOpOperandGetOperandNumber(self.inner())); + return @intCast(c.mlirOpOperandGetOperandNumber(self._inner)); } pub fn nextUse(self: Self) ?Self { return Self.wrapOr( - c.mlirOpOperandGetNextUse(self.inner()), + c.mlirOpOperandGetNextUse(self._inner), ); } }; pub const Region = struct { _inner: c.MlirRegion, - pub usingnamespace MlirHelpers(Region, .{ - .is_null_fn = c.mlirRegionIsNull, - .deinit_fn = c.mlirRegionDestroy, - .equal_fn = c.mlirRegionEqual, - }); + + pub const eql = helpers.eql(Region, c.mlirBlockEqual); + pub const deinit = helpers.deinit(Region, c.mlirRegionDestroy); + pub const wrapOr = helpers.wrapOr(Region, c.mlirRegionIsNull); const Self = Region; pub fn init() !Self { - return Self.wrapOr( - c.mlirRegionCreate(), - ) orelse Error.InvalidMlir; + return Self.wrapOr(c.mlirRegionCreate()) orelse Error.InvalidMlir; } pub fn appendBlock(self: *Self, block: Block) void { - c.mlirRegionAppendOwnedBlock(self.inner(), block.inner()); + c.mlirRegionAppendOwnedBlock(self._inner, block._inner); } pub fn insertBlock(self: *Self, index: isize, block: Block) void { - c.mlirRegionInsertOwnedBlock(self.inner(), index, block.inner()); + c.mlirRegionInsertOwnedBlock(self._inner, index, block._inner); } pub fn insertBlockBefore(self: *Self, reference: Block, block: Block) void { - c.mlirRegionInsertOwnedBlockBefore(self.inner(), reference.inner(), block.inner()); + c.mlirRegionInsertOwnedBlockBefore(self._inner, reference._inner, block._inner); } pub fn insertBlockAfter(self: *Self, reference: Block, block: Block) void { - c.mlirRegionInsertOwnedBlockAfter(self.inner(), reference.inner(), block.inner()); + c.mlirRegionInsertOwnedBlockAfter(self._inner, reference._inner, block._inner); } pub fn firstBlock(self: Self) Block { - return Block.wrap(c.mlirRegionGetFirstBlock(self.inner())); + return .{ ._inner = c.mlirRegionGetFirstBlock(self._inner) }; } }; pub const Value = struct { _inner: c.MlirValue, - pub usingnamespace MlirHelpers(Value, .{ - .is_null_fn = c.mlirValueIsNull, - .equal_fn = c.mlirValueEqual, - .dump_fn = c.mlirValueDump, - .print_fn = c.mlirValuePrint, - }); + pub const dump = helpers.dump(Value, c.mlirValueDump); + pub const eql = helpers.eql(Value, c.mlirValueEqual); + pub const format = helpers.format(Value, c.mlirValuePrint).format; + pub const wrapOr = helpers.wrapOr(Value, c.mlirValueIsNull); pub fn getType(val: Value) Type { - return Type.wrap(c.mlirValueGetType(val.inner())); + return .{ ._inner = c.mlirValueGetType(val._inner) }; } pub fn setType(val: *Value, typ: Type) void { - c.mlirValueSetType(val.inner(), typ.inner()); + c.mlirValueSetType(val._inner, typ._inner); } pub fn firstUse(val: Value) OpOperand { - return OpOperand.wrap(c.mlirValueGetFirstUse(val.inner())); + return .{ ._inner = c.mlirValueGetFirstUse(val._inner) }; } pub fn replaceAllUsesWith(val: Value, with: Value) void { - c.mlirValueReplaceAllUsesOfWith(val.inner(), with.inner()); + c.mlirValueReplaceAllUsesOfWith(val._inner, with._inner); } pub fn owner(val: Value) Operation { - return Operation.wrap(c.mlirOpResultGetOwner(val.inner())); + return .{ ._inner = c.mlirOpResultGetOwner(val._inner) }; } pub fn isABlockArgument(val: Value) bool { - return c.mlirValueIsABlockArgument(val.inner()); + return c.mlirValueIsABlockArgument(val._inner); } pub fn isAOpResult(val: Value) bool { - return c.mlirValueIsAOpResult(val.inner()); + return c.mlirValueIsAOpResult(val._inner); } pub const Kind = union(enum) { @@ -1360,7 +1176,7 @@ pub const BlockArgument = struct { _inner: c.MlirValue, pub fn block(arg: BlockArgument) Block { - return Block.wrap(c.mlirBlockArgumentGetOwner(arg._inner)); + return .{ ._inner = c.mlirBlockArgumentGetOwner(arg._inner) }; } pub fn number(arg: BlockArgument) usize { @@ -1376,67 +1192,98 @@ pub const BlockArgument = struct { pub const Type = struct { _inner: c.MlirType, - pub usingnamespace MlirHelpers(Type, .{ - .is_null_fn = c.mlirTypeIsNull, - .dump_fn = c.mlirTypeDump, - .equal_fn = c.mlirTypeEqual, - .print_fn = c.mlirTypePrint, - }); + pub const dump = helpers.eql(Type, c.mlirTypeDump); + pub const eql = helpers.eql(Type, c.mlirTypeEqual); + pub const format = helpers.format(Type, c.mlirTypePrint); + pub const wrapOr = helpers.wrapOr(Type, c.mlirTypeIsNull); pub fn parse(ctx: Context, str: [:0]const u8) !Type { return Type.wrapOr( - c.mlirTypeParseGet(ctx.inner(), stringRef(str)), + c.mlirTypeParseGet(ctx._inner, stringRef(str)), ) orelse Error.InvalidMlir; } + pub fn as(generic: Type, SpecificType: type) ?SpecificType { + if (@hasDecl(SpecificType, "is_a_fn")) { + return if (SpecificType.is_a_fn(generic._inner)) + .{ ._inner = generic._inner } + else + null; + } + @compileError("Mlir subclass of type need `is_a_fn` attribute: " ++ @typeName(SpecificType)); + } + + pub fn fromAny(SpecificType: type) fn (x: SpecificType) Type { + stdx.debug.assertComptime(@hasDecl(SpecificType, "asType"), "Type.fromAny expects a type subclass, got: {}. Missing `asAttr` declaration.", .{SpecificType}); + return struct { + fn cast(x: SpecificType) Type { + return .{ ._inner = x._inner }; + } + }.cast; + } + + pub fn eqlAny(SpecificType: type) fn (SpecificType, SpecificType) bool { + return struct { + fn eql(a: SpecificType, b: SpecificType) bool { + return a.asType().eql(b.asType()); + } + }.eql; + } + + pub fn formatAny(SpecificType: type) fn (SpecificType, SpecificType) type { + return struct { + pub fn format(self: SpecificType, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { + return try Type.format(self.asType(), fmt, options, writer); + } + }; + } + pub fn index(ctx: Context) Type { - return IndexType.init(ctx).as(Type); + return IndexType.init(ctx).asType(); } pub fn int(ctx: Context, int_type: IntegerTypes) Type { return switch (int_type) { .unknown => @panic("Unknown integer type"), - inline else => |t| IntegerType(t).init(ctx).as(Type), + inline else => |t| IntegerType(t).init(ctx).asType(), }; } pub fn float(ctx: Context, float_type: FloatTypes) Type { return switch (float_type) { - inline else => |t| FloatType(t).init(ctx).as(Type), + inline else => |t| FloatType(t).init(ctx).asType(), }; } pub fn complex(ctx: Context, complex_type: ComplexTypes) Type { return switch (complex_type) { - inline else => |t| ComplexType(t).init(ctx).as(Type), + .unknown => @panic("Unknown complex type can't be created like this"), // What's the point ? + inline else => |t| ComplexType(t).init(ctx).asType(), }; } pub fn tuple(ctx: Context, types: []const Type) Type { - return (TupleType.init(ctx, types) catch unreachable).as(Type); + return (TupleType.init(ctx, types) catch unreachable).asType(); } pub fn function(ctx: Context, args: []const Type, results: []const Type) Type { - return (FunctionType.init(ctx, args, results) catch unreachable).as(Type); + return (FunctionType.init(ctx, args, results) catch unreachable).asType(); } pub fn tensor(dimensions: []const i64, elem_type: Type) Type { - return RankedTensorType.init(dimensions, elem_type).as(Type); + return RankedTensorType.init(dimensions, elem_type).asType(); } }; pub const IndexType = struct { _inner: c.MlirType, - pub usingnamespace MlirHelpers(IndexType, .{ - .is_null_fn = c.mlirTypeIsNull, - .dump_fn = c.mlirTypeDump, - .equal_fn = c.mlirTypeEqual, - .print_fn = c.mlirTypePrint, - }); + pub const asType = Type.fromAny(IndexType); + pub const eql = Type.eqlAny(IndexType); + pub const format = Type.formatAny(IndexType).format; pub fn init(ctx: Context) IndexType { - return IndexType.wrap(c.mlirIndexTypeGet(ctx.inner())); + return .{ ._inner = c.mlirIndexTypeGet(ctx._inner) }; } }; @@ -1486,30 +1333,29 @@ pub fn IntegerType(comptime it: IntegerTypes) type { _inner: c.MlirType, const Int = @This(); - pub usingnamespace MlirHelpers(Int, .{ - .is_a_fn = switch (it) { - .unknown => c.mlirTypeIsAInteger, - else => typeIsAIntegerExact, - }, - .is_null_fn = c.mlirTypeIsNull, - .dump_fn = c.mlirTypeDump, - .equal_fn = c.mlirTypeEqual, - }); - const IntegerTypeType = it; + pub const is_a_fn = switch (it) { + .unknown => c.mlirTypeIsAInteger, + else => typeIsAIntegerExact, + }; + + pub const asType = Type.fromAny(Int); + pub const eql = Type.eqlAny(Int); + pub const format = helpers.format(Int, c.mlirTypePrint); fn typeIsAIntegerExact(typ: c.MlirType) callconv(.C) bool { const bit_width = Config[0]; const is_sign = Config[2]; return c.mlirTypeIsAInteger(typ) and (c.mlirIntegerTypeGetWidth(typ) == bit_width) and is_sign(typ); } - pub usingnamespace if (it != .unknown) struct { - pub const BitWidth = Config[0]; + pub const BitWidth = Config[0]; + + pub const init = if (it != .unknown) struct { pub fn init(ctx: Context) Int { const type_get = Config[1]; - return Int.wrap(type_get(ctx.inner(), BitWidth)); + return .{ ._inner = type_get(ctx._inner, BitWidth) }; } - } else struct {}; + }.init else {}; }; } @@ -1526,7 +1372,7 @@ pub const FloatTypes = enum { pub fn asType(self: FloatTypes, ctx: Context) Type { return switch (self) { - inline else => |ft| FloatType(ft).init(ctx).as(Type), + inline else => |ft| FloatType(ft).init(ctx).asType(), }; } }; @@ -1549,16 +1395,15 @@ pub fn FloatType(comptime ft: FloatTypes) type { const Self = @This(); - pub usingnamespace MlirHelpers(Self, .{ - .is_a_fn = Config[0], - .is_null_fn = c.mlirTypeIsNull, - .dump_fn = c.mlirTypeDump, - .equal_fn = c.mlirTypeEqual, - }); + pub const is_a_fn = Config[0]; + + pub const asType = Type.fromAny(Self); + pub const eql = Type.eqlAny(Self); + pub const format = helpers.format(Self, c.mlirTypePrint); pub fn init(ctx: Context) Self { const type_get = Config[1]; - return Self.wrap(type_get(ctx.inner())); + return .{ ._inner = type_get(ctx._inner) }; } }; } @@ -1603,219 +1448,173 @@ pub fn ComplexType(comptime ct: ComplexTypes) type { return c.mlirTypeIsAComplex(typ); } - pub usingnamespace MlirHelpers(@This(), .{ - .is_a_fn = Config[0], - .is_null_fn = c.mlirTypeIsNull, - .dump_fn = c.mlirTypeDump, - .equal_fn = c.mlirTypeEqual, - }); + pub const is_a_fn = Config[0]; - pub usingnamespace if (ct != .unknown) struct { - pub const ComplexTypeType = ct; + pub const asType = Type.fromAny(Complex); + pub const eql = Type.eqlAny(Complex); + pub const format = Type.formatAny(Complex).format; + pub const ComplexTypeType: ComplexTypes = ct; + pub const init = if (ct != .unknown) struct { pub fn init(ctx: Context) Complex { const type_get = Config[1]; - return Complex.wrap(type_get(ctx.inner())); + return .{ ._inner = type_get(ctx._inner) }; } - } else struct {}; + }.init else {}; }; } pub const TupleType = struct { _inner: c.MlirType, - pub usingnamespace MlirHelpers(TupleType, .{ - .is_a_fn = c.mlirTypeIsATuple, - .is_null_fn = c.mlirTypeIsNull, - .dump_fn = c.mlirTypeDump, - .equal_fn = c.mlirTypeEqual, - }); + pub const is_a_fn = c.mlirTypeIsATuple; const Self = TupleType; pub fn init(ctx: Context, elements: []const Type) !Self { return Self.wrapOr(c.mlirTupleTypeGet( - ctx.inner(), + ctx._inner, @intCast(elements.len), @ptrCast(elements.ptr), )) orelse Error.InvalidMlir; } pub fn getNumTypes(self: Self) usize { - return @intCast(c.mlirTupleTypeGetNumTypes(self.inner())); + return @intCast(c.mlirTupleTypeGetNumTypes(self._inner)); } pub fn getElementType(self: Self, index: usize) Type { - return Type.wrap(c.mlirTupleTypeGetType(self.inner(), @intCast(index))); + return .{ ._inner = c.mlirTupleTypeGetType(self._inner, @intCast(index)) }; } + + pub const asType = Type.fromAny(Self); }; pub const FunctionType = struct { _inner: c.MlirType, - pub usingnamespace MlirHelpers(FunctionType, .{ - .is_a_fn = c.mlirTypeIsAFunction, - .is_null_fn = c.mlirTypeIsNull, - .dump_fn = c.mlirTypeDump, - .equal_fn = c.mlirTypeEqual, - }); - + pub const is_a_fn = c.mlirTypeIsAFunction; const Self = FunctionType; + pub const asType = Type.fromAny(Self); + pub const eql = Type.eqlAny(IndexType); pub fn init(ctx: Context, args: []const Type, results: []const Type) !Self { - return Self.wrapOr(c.mlirFunctionTypeGet( - ctx.inner(), + const func = Type.wrapOr(c.mlirFunctionTypeGet( + ctx._inner, @intCast(args.len), @ptrCast(args.ptr), @intCast(results.len), @ptrCast(results.ptr), - )) orelse Error.InvalidMlir; + )) orelse return Error.InvalidMlir; + return func.as(Self).?; } }; pub const RankedTensorType = struct { _inner: c.MlirType, - pub usingnamespace MlirHelpers(RankedTensorType, .{ - .is_a_fn = c.mlirTypeIsARankedTensor, - .is_null_fn = c.mlirTypeIsNull, - .dump_fn = c.mlirTypeDump, - .equal_fn = c.mlirTypeEqual, - .print_fn = c.mlirTypePrint, - }); + pub const is_a_fn = c.mlirTypeIsARankedTensor; + pub const eql = Type.eqlAny(RankedTensorType); + pub const format = helpers.format(Type, c.mlirTypePrint); pub fn init(dimensions: []const i64, elemType: Type) RankedTensorType { - return RankedTensorType.wrap( - c.mlirRankedTensorTypeGet( - @intCast(dimensions.len), - @ptrCast(dimensions.ptr), - elemType.inner(), - c.mlirAttributeGetNull(), - ), - ); + return .{ ._inner = c.mlirRankedTensorTypeGet( + @intCast(dimensions.len), + @ptrCast(dimensions.ptr), + elemType._inner, + c.mlirAttributeGetNull(), + ) }; } pub fn getElementType(self: RankedTensorType) Type { - return Type.wrap(c.mlirShapedTypeGetElementType(self.inner())); + return .{ ._inner = c.mlirShapedTypeGetElementType(self._inner) }; } pub fn getRank(self: RankedTensorType) usize { - return @intCast(c.mlirShapedTypeGetRank(self.inner())); + return @intCast(c.mlirShapedTypeGetRank(self._inner)); } pub fn getDimension(self: RankedTensorType, dim: usize) i64 { - return c.mlirShapedTypeGetDimSize(self.inner(), @intCast(dim)); + return c.mlirShapedTypeGetDimSize(self._inner, @intCast(dim)); } + + pub const asType = Type.fromAny(RankedTensorType); }; pub const Dialect = struct { _inner: c.MlirDialect, - pub usingnamespace MlirHelpers(Dialect, .{ - .equal_fn = c.mlirDialectEqual, - .is_null_fn = c.mlirDialectIsNull, - }); const Self = Dialect; pub fn getContext(self: Self) Context { - return Context.wrap(c.mlirDialectGetContext(self.inner())); + return .{ ._inner = c.mlirDialectGetContext(self._inner) }; } pub fn getNamespace(self: Self) []const u8 { - return fromStringRef(c.mlirDialectGetNamespace(self.inner())); + return fromStringRef(c.mlirDialectGetNamespace(self._inner)); } }; pub const DialectHandle = struct { _inner: c.MlirDialectHandle, - pub usingnamespace MlirHelpers( - DialectHandle, - .{}, - ); pub fn getNamespace(self: DialectHandle) []const u8 { - return fromStringRef(c.mlirDialectHandleGetNamespace(self.inner())); + return fromStringRef(c.mlirDialectHandleGetNamespace(self._inner)); } pub fn insertDialect(self: DialectHandle, registry: Registry) void { - c.mlirDialectHandleInsertDialect(self.inner(), registry.inner()); + c.mlirDialectHandleInsertDialect(self._inner, registry._inner); } pub fn registerDialect(self: DialectHandle, ctx: Context) void { - c.mlirDialectHandleRegisterDialect(self.inner(), ctx.inner()); + c.mlirDialectHandleRegisterDialect(self._inner, ctx._inner); } pub fn loadDialect(self: DialectHandle, ctx: Context) Dialect { - return Dialect.wrap(c.mlirDialectHandleLoadDialect(self.inner(), ctx.inner())); + return .{ ._inner = c.mlirDialectHandleLoadDialect(self._inner, ctx._inner) }; } pub fn fromString(comptime namespace: []const u8) DialectHandle { - return DialectHandle.wrap(@field(c, "mlirGetDialectHandle__" ++ namespace ++ "__")()); - } -}; - -pub const ShapedType = struct { - _inner: c.MlirType, - pub usingnamespace MlirHelpers(ShapedType, .{ - .is_a_fn = c.mlirTypeIsAShaped, - .is_null_fn = c.mlirTypeIsNull, - .dump_fn = c.mlirTypeDump, - .equal_fn = c.mlirTypeEqual, - }); - const Self = ShapedType; - - pub fn rank(self: Self) usize { - return @intCast(c.mlirShapedTypeGetRank(self.inner())); - } - - pub fn elementType(self: Self) Type { - return Type.wrap(c.mlirShapedTypeGetElementType(self.inner())); - } - - pub fn dimension(self: Self, dim: usize) usize { - return @intCast(c.mlirShapedTypeGetDimSize(self.inner(), @intCast(dim))); + return .{ ._inner = @field(c, "mlirGetDialectHandle__" ++ namespace ++ "__")() }; } }; pub const Location = struct { _inner: c.MlirLocation, - pub usingnamespace MlirHelpers(Location, .{ - .is_null_fn = c.mlirLocationIsNull, - .equal_fn = c.mlirLocationEqual, - .print_fn = c.mlirLocationPrint, - }); + pub const eql = helpers.eql(Type, c.mlirLocationEqual); + pub const format = helpers.format(Location, c.mlirLocationPrint); pub fn fromSrc(ctx: Context, src: std.builtin.SourceLocation) Location { - return Location.wrap(c.mlirLocationFileLineColGet( - ctx.inner(), + return .{ ._inner = c.mlirLocationFileLineColGet( + ctx._inner, stringRef(src.file), @intCast(src.line), @intCast(src.column), - )); + ) }; } pub fn fileLineCol(ctx: Context, file: []const u8, line: usize, column: usize) Location { - return Location.wrap(c.mlirLocationFileLineColGet( - ctx.inner(), + return .{ ._inner = c.mlirLocationFileLineColGet( + ctx._inner, stringRef(file), @intCast(line), @intCast(column), - )); + ) }; } pub fn callSite(callee: Location, caller: Location) Location { - return Location.wrap(c.mlirLocationCallSiteGet(callee.inner(), caller.inner())); + return .{ ._inner = c.mlirLocationCallSiteGet(callee._inner, caller._inner) }; } pub fn fused(ctx: Context, locations: []const Location, metadata: Attribute) Location { - return Location.wrap(c.mlirLocationFusedGet( - ctx.inner(), + return .{ ._inner = c.mlirLocationFusedGet( + ctx._inner, @intCast(locations.len), @ptrCast(locations.ptr), - metadata.inner(), - )); + metadata._inner, + ) }; } pub fn named(loc: Location, ctx: Context, loc_name: [:0]const u8) Location { - return Location.wrap(c.mlirLocationNameGet(ctx.inner(), stringRef(loc_name), loc.inner())); + return .{ ._inner = c.mlirLocationNameGet(ctx._inner, stringRef(loc_name), loc._inner) }; } pub fn namedFmt(loc: Location, ctx: Context, comptime fmt: [:0]const u8, args: anytype) Location { @@ -1828,17 +1627,17 @@ pub const Location = struct { } pub fn unknown(ctx: Context) Location { - return Location.wrap(c.mlirLocationUnknownGet(ctx.inner())); + return .{ ._inner = c.mlirLocationUnknownGet(ctx._inner) }; } }; pub const Block = struct { _inner: c.MlirBlock, - pub usingnamespace MlirHelpers(Block, .{ - .is_null_fn = c.mlirBlockIsNull, - .deinit_fn = c.mlirBlockDestroy, - .equal_fn = c.mlirBlockEqual, - }); + + pub const wrapOr = helpers.wrapOr(Block, c.mlirBlockIsNull); + pub const deinit = helpers.deinit(Block, c.mlirBlockDestroy); + + pub const eql = helpers.eql(Block, c.mlirBlockEqual); pub fn init(args: []const Type, locs: []const Location) !Block { const block = Block.wrapOr( @@ -1848,34 +1647,32 @@ pub const Block = struct { } pub fn argument(self: Block, index: usize) Value { - const arg = c.mlirBlockGetArgument(self.inner(), @intCast(index)); - stdx.debug.assert(!Value.Methods.is_null_fn.?(arg), "Block doesn't have argument {}, only got {}", .{ index, self.numArguments() }); - return Value.wrap(arg); + return .{ ._inner = c.mlirBlockGetArgument(self._inner, @intCast(index)) }; } pub fn numArguments(self: Block) usize { - return @intCast(c.mlirBlockGetNumArguments(self.inner())); + return @intCast(c.mlirBlockGetNumArguments(self._inner)); } pub fn addArgument(self: *Block, typ: Type, loc: Location) Value { - return Value.wrap(c.mlirBlockAddArgument(self.inner(), typ.inner(), loc.inner())); + return .{ ._inner = c.mlirBlockAddArgument(self._inner, typ._inner, loc._inner) }; } pub fn insertArgument(self: *Block, index: usize, typ: Type, loc: Location) Value { - return Value.wrap(c.mlirBlockInsertArgument(self.inner(), @intCast(index), typ.inner(), loc.inner())); + return .{ ._inner = c.mlirBlockInsertArgument(self._inner, @intCast(index), typ._inner, loc._inner) }; } pub fn equals(self: Block, other: Block) bool { - return c.mlirBlockEqual(self.inner(), other.inner()); + return c.mlirBlockEqual(self._inner, other._inner); } pub fn appendOperation(self: Block, op: Operation) void { - c.mlirBlockAppendOwnedOperation(self.inner(), op.inner()); + c.mlirBlockAppendOwnedOperation(self._inner, op._inner); } pub fn appendOperations(self: *Block, ops: []const Operation) void { for (ops) |op| { - c.mlirBlockAppendOwnedOperation(self.inner(), op.inner()); + c.mlirBlockAppendOwnedOperation(self._inner, op._inner); } } @@ -1904,3 +1701,83 @@ pub const Block = struct { self.appendOperation(op); } }; + +pub const helpers = struct { + pub fn eql(T: type, equal_fn: fn (@FieldType(T, "_inner"), @FieldType(T, "_inner")) callconv(.C) bool) fn (T, T) bool { + return struct { + fn eql(a: T, b: T) bool { + return equal_fn(a._inner, b._inner); + } + }.eql; + } + + pub fn deinit(T: type, deinit_fn: fn (@FieldType(T, "_inner")) callconv(.C) void) fn (*T) void { + return struct { + fn deinit(a: *T) void { + deinit_fn(a._inner); + a.* = undefined; + } + }.deinit; + } + + pub fn dump(T: type, dump_fn: fn (@FieldType(T, "_inner")) callconv(.C) void) fn (T) void { + return struct { + fn dump(a: T) void { + return dump_fn(a._inner); + } + }.dump; + } + + pub fn isNull(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.C) bool) fn (T) bool { + return struct { + fn isNull(a: T) bool { + return is_null_fn(a._inner); + } + }.isNull; + } + + pub fn format(Any: type, print_fn: fn (@FieldType(Any, "_inner"), ?*const MlirStrCallback, ?*anyopaque) callconv(.C) void) type { + return struct { + pub fn format( + self: Any, + comptime fmt: []const u8, + options: std.fmt.FormatOptions, + writer: anytype, + ) !void { + _ = fmt; + _ = options; + + const Writer = struct { + writer: @TypeOf(writer), + err: ?@TypeOf(writer).Error = null, + fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.C) void { + var ctx: *@This() = @alignCast(@ptrCast(opaque_ctx)); + if (ctx.err) |_| return; + _ = ctx.writer.write(mlir_str.data[0..mlir_str.length]) catch |err| { + ctx.err = err; + return; + }; + } + }; + + var context: Writer = .{ .writer = writer }; + print_fn(self._inner, &Writer.printCallback, &context); + if (context.err) |err| return err; + } + }; + } + + pub fn wrapOr(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.C) bool) fn (@FieldType(T, "_inner")) ?T { + return struct { + fn wrapOr(inner: @FieldType(T, "_inner")) ?T { + if (is_null_fn(inner)) return null; + return .{ ._inner = inner }; + } + }.wrapOr; + } + + pub fn init(T: type, inner: @FieldType(T, "_inner"), is_null_fn: fn (@FieldType(T, "_inner")) callconv(.C) bool) ?T { + if (is_null_fn(inner)) return null; + return .{ ._inner = inner }; + } +}; diff --git a/zml/mlir.zig b/zml/mlir.zig deleted file mode 100644 index 0709189..0000000 --- a/zml/mlir.zig +++ /dev/null @@ -1,176 +0,0 @@ -const mlir = @This(); - -const builtin = @import("builtin"); -const std = @import("std"); -const stdx = @import("stdx"); - -const dtype = @import("dtype.zig"); - -const Shape = @import("shape.zig").Shape; -const Tensor = @import("tensor.zig").Tensor; - -const log = std.log.scoped(.@"zml/mlir"); - -pub usingnamespace @import("mlir"); - -pub const ext = struct { - pub fn mlirType(ctx: mlir.Context, sh: Shape) mlir.Type { - return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype())).as(mlir.Type); - } - - pub fn denseElementAttrType(dt: dtype.DataType) ?mlir.DenseElementsAttributeTypes { - return switch (dt) { - .bool => .bool, - .i8 => .i8, - .i16 => .i16, - .i32 => .i32, - .i64 => .i64, - .u8 => .u8, - .u16 => .u16, - .u32 => .u32, - .u64 => .u64, - .bf16 => .bf16, - .f16 => .f16, - .f32 => .f32, - .f64 => .f64, - else => null, - }; - } - - pub fn denseElementsAttr(dt: dtype.DataType, _: usize, bytes: []const u8, ranked_type: mlir.RankedTensorType) mlir.Attribute { - const ranked_type_ = ranked_type.as(mlir.Type); - return switch (dt) { - .bool => mlir.DenseElementsAttribute(.bool).init(ranked_type_, bytes).as(mlir.Attribute), - .i8 => mlir.DenseElementsAttribute(.i8).init(ranked_type_, bytes).as(mlir.Attribute), - .i16 => mlir.DenseElementsAttribute(.i16).init(ranked_type_, bytes).as(mlir.Attribute), - .i32 => mlir.DenseElementsAttribute(.i32).init(ranked_type_, bytes).as(mlir.Attribute), - .i64 => mlir.DenseElementsAttribute(.i64).init(ranked_type_, bytes).as(mlir.Attribute), - .u8 => mlir.DenseElementsAttribute(.u8).init(ranked_type_, bytes).as(mlir.Attribute), - .u16 => mlir.DenseElementsAttribute(.u16).init(ranked_type_, bytes).as(mlir.Attribute), - .u32 => mlir.DenseElementsAttribute(.u32).init(ranked_type_, bytes).as(mlir.Attribute), - .u64 => mlir.DenseElementsAttribute(.u64).init(ranked_type_, bytes).as(mlir.Attribute), - .bf16 => mlir.DenseElementsAttribute(.bf16).init(ranked_type_, bytes).as(mlir.Attribute), - .f16 => mlir.DenseElementsAttribute(.f16).init(ranked_type_, bytes).as(mlir.Attribute), - .f32 => mlir.DenseElementsAttribute(.f32).init(ranked_type_, bytes).as(mlir.Attribute), - .f64 => mlir.DenseElementsAttribute(.f64).init(ranked_type_, bytes).as(mlir.Attribute), - inline else => |tag| @panic("Unsupported data type: " ++ @tagName(tag)), - }; - } - - pub const RankedTensorType = struct { - pub fn fromShape(ctx: mlir.Context, sh: Shape) mlir.RankedTensorType { - return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype())); - } - }; - - pub const Type = struct { - pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type { - return switch (dt) { - .bool => mlir.IntegerType(.i1).init(ctx).as(mlir.Type), - .f8e4m3b11fnuz => mlir.FloatType(.f8e4m3b11fnuz).init(ctx).as(mlir.Type), - .f8e4m3fn => mlir.FloatType(.f8e4m3fn).init(ctx).as(mlir.Type), - .f8e4m3fnuz => mlir.FloatType(.f8e4m3fnuz).init(ctx).as(mlir.Type), - .f8e5m2 => mlir.FloatType(.f8e5m2).init(ctx).as(mlir.Type), - .f8e5m2fnuz => mlir.FloatType(.f8e5m2fnuz).init(ctx).as(mlir.Type), - .bf16 => mlir.FloatType(.bf16).init(ctx).as(mlir.Type), - .f16 => mlir.FloatType(.f16).init(ctx).as(mlir.Type), - .f32 => mlir.FloatType(.f32).init(ctx).as(mlir.Type), - .f64 => mlir.FloatType(.f64).init(ctx).as(mlir.Type), - .i4 => mlir.IntegerType(.i4).init(ctx).as(mlir.Type), - .i8 => mlir.IntegerType(.i8).init(ctx).as(mlir.Type), - .i16 => mlir.IntegerType(.i16).init(ctx).as(mlir.Type), - .i32 => mlir.IntegerType(.i32).init(ctx).as(mlir.Type), - .i64 => mlir.IntegerType(.i64).init(ctx).as(mlir.Type), - .u4 => mlir.IntegerType(.u4).init(ctx).as(mlir.Type), - .u8 => mlir.IntegerType(.u8).init(ctx).as(mlir.Type), - .u16 => mlir.IntegerType(.u16).init(ctx).as(mlir.Type), - .u32 => mlir.IntegerType(.u32).init(ctx).as(mlir.Type), - .u64 => mlir.IntegerType(.u64).init(ctx).as(mlir.Type), - .c64 => mlir.ComplexType(.c64).init(ctx).as(mlir.Type), - .c128 => mlir.ComplexType(.c128).init(ctx).as(mlir.Type), - }; - } - - pub fn toDType(mlir_type: mlir.Type) dtype.DataType { - const mapping = .{ - .{ .bool, mlir.IntegerType(.i1) }, - - .{ .f8e4m3b11fnuz, mlir.FloatType(.f8e4m3b11fnuz) }, - .{ .f8e4m3fn, mlir.FloatType(.f8e4m3fn) }, - .{ .f8e4m3fnuz, mlir.FloatType(.f8e4m3fnuz) }, - .{ .f8e5m2, mlir.FloatType(.f8e5m2) }, - .{ .f8e5m2fnuz, mlir.FloatType(.f8e5m2fnuz) }, - .{ .bf16, mlir.FloatType(.bf16) }, - .{ .f16, mlir.FloatType(.f16) }, - .{ .f32, mlir.FloatType(.f32) }, - .{ .f64, mlir.FloatType(.f64) }, - - .{ .i4, mlir.IntegerType(.i4) }, - .{ .i8, mlir.IntegerType(.i8) }, - .{ .i16, mlir.IntegerType(.i16) }, - .{ .i32, mlir.IntegerType(.i32) }, - .{ .i64, mlir.IntegerType(.i64) }, - - .{ .u4, mlir.IntegerType(.u4) }, - .{ .u8, mlir.IntegerType(.u8) }, - .{ .u16, mlir.IntegerType(.u16) }, - .{ .u32, mlir.IntegerType(.u32) }, - .{ .u64, mlir.IntegerType(.u64) }, - - .{ .c64, mlir.ComplexType(.c64) }, - .{ .c128, mlir.ComplexType(.c128) }, - }; - - inline for (mapping) |entry| { - const dt, const mlirT = entry; - if (mlir_type.is_a(mlirT)) { - return dt; - } - } - - stdx.debug.panic("Could not convert mlir.Type to DataType: {}", .{mlir_type}); - } - }; - - pub const Attribute = struct { - pub fn fromData(data: dtype.Data, ctx: mlir.Context) mlir.Attribute { - switch (data) { - .bool => |val| { - return mlir.IntegerAttribute(.i1).init(ctx, @intFromBool(val)).as(mlir.Attribute); - }, - inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz => |val, tag| { - const float_type = @field(mlir.FloatTypes, @tagName(tag)); - const float_attr = mlir.FloatAttribute(float_type).init(ctx, val.toF32()); - return float_attr.as(mlir.Attribute); - }, - inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |val, tag| { - const int_type = @field(mlir.IntegerTypes, @tagName(tag)); - const int_attr = mlir.IntegerAttribute(int_type).init(ctx, @intCast(val)); - return int_attr.as(mlir.Attribute); - }, - inline else => |_, tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}), - } - } - }; - - pub const DenseElementsAttribute = struct { - pub fn fromData(data: dtype.Data, result_type: mlir.Type) mlir.Attribute { - return switch (data.dtype()) { - .bool => mlir.DenseElementsAttribute(.bool).init(result_type, data.constSlice()).as(mlir.Attribute), - .i8 => mlir.DenseElementsAttribute(.i8).init(result_type, data.constSlice()).as(mlir.Attribute), - .i16 => mlir.DenseElementsAttribute(.i16).init(result_type, data.constSlice()).as(mlir.Attribute), - .i32 => mlir.DenseElementsAttribute(.i32).init(result_type, data.constSlice()).as(mlir.Attribute), - .i64 => mlir.DenseElementsAttribute(.i64).init(result_type, data.constSlice()).as(mlir.Attribute), - .u8 => mlir.DenseElementsAttribute(.u8).init(result_type, data.constSlice()).as(mlir.Attribute), - .u16 => mlir.DenseElementsAttribute(.u16).init(result_type, data.constSlice()).as(mlir.Attribute), - .u32 => mlir.DenseElementsAttribute(.u32).init(result_type, data.constSlice()).as(mlir.Attribute), - .u64 => mlir.DenseElementsAttribute(.u64).init(result_type, data.constSlice()).as(mlir.Attribute), - .bf16 => mlir.DenseElementsAttribute(.bf16).init(result_type, data.constSlice()).as(mlir.Attribute), - .f16 => mlir.DenseElementsAttribute(.f16).init(result_type, data.constSlice()).as(mlir.Attribute), - .f32 => mlir.DenseElementsAttribute(.f32).init(result_type, data.constSlice()).as(mlir.Attribute), - .f64 => mlir.DenseElementsAttribute(.f64).init(result_type, data.constSlice()).as(mlir.Attribute), - inline else => |tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}), - }; - } - }; -}; diff --git a/zml/mlirx.zig b/zml/mlirx.zig new file mode 100644 index 0000000..28f4038 --- /dev/null +++ b/zml/mlirx.zig @@ -0,0 +1,101 @@ +const std = @import("std"); + +const mlir = @import("mlir"); + +const dtype = @import("dtype.zig"); +const Shape = @import("shape.zig").Shape; + +const mlirx = @This(); + +/// Returns the mlir.Type corresponding to a given zml.Shape. +pub fn tensorType(ctx: mlir.Context, sh: Shape) mlir.Type { + return .tensor(sh.dims(), mlirx.Type.fromDType(ctx, sh.dtype())); +} + +pub fn denseElementAttrType(dt: dtype.DataType) ?mlir.DenseElementsAttributeTypes { + return switch (dt) { + .bool => .bool, + .i8 => .i8, + .i16 => .i16, + .i32 => .i32, + .i64 => .i64, + .u8 => .u8, + .u16 => .u16, + .u32 => .u32, + .u64 => .u64, + .bf16 => .bf16, + .f16 => .f16, + .f32 => .f32, + .f64 => .f64, + else => null, + }; +} + +pub const Type = struct { + pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type { + return switch (dt) { + .bool => .int(ctx, .i1), + .f8e4m3b11fnuz => .float(ctx, .f8e4m3b11fnuz), + .f8e4m3fn => .float(ctx, .f8e4m3fn), + .f8e4m3fnuz => .float(ctx, .f8e4m3fnuz), + .f8e5m2 => .float(ctx, .f8e5m2), + .f8e5m2fnuz => .float(ctx, .f8e5m2fnuz), + .bf16 => .float(ctx, .bf16), + .f16 => .float(ctx, .f16), + .f32 => .float(ctx, .f32), + .f64 => .float(ctx, .f64), + .i4 => .int(ctx, .i4), + .i8 => .int(ctx, .i8), + .i16 => .int(ctx, .i16), + .i32 => .int(ctx, .i32), + .i64 => .int(ctx, .i64), + .u4 => .int(ctx, .u4), + .u8 => .int(ctx, .u8), + .u16 => .int(ctx, .u16), + .u32 => .int(ctx, .u32), + .u64 => .int(ctx, .u64), + .c64 => .complex(ctx, .c64), + .c128 => .complex(ctx, .c128), + }; + } + + pub fn toDType(mlir_type: mlir.Type) dtype.DataType { + const mapping = .{ + .{ .bool, mlir.IntegerType(.i1) }, + + .{ .f8e4m3b11fnuz, mlir.FloatType(.f8e4m3b11fnuz) }, + .{ .f8e4m3fn, mlir.FloatType(.f8e4m3fn) }, + .{ .f8e4m3fnuz, mlir.FloatType(.f8e4m3fnuz) }, + .{ .f8e5m2, mlir.FloatType(.f8e5m2) }, + .{ .f8e5m2fnuz, mlir.FloatType(.f8e5m2fnuz) }, + .{ .bf16, mlir.FloatType(.bf16) }, + .{ .f16, mlir.FloatType(.f16) }, + .{ .f32, mlir.FloatType(.f32) }, + .{ .f64, mlir.FloatType(.f64) }, + + .{ .i4, mlir.IntegerType(.i4) }, + .{ .i8, mlir.IntegerType(.i8) }, + .{ .i16, mlir.IntegerType(.i16) }, + .{ .i32, mlir.IntegerType(.i32) }, + .{ .i64, mlir.IntegerType(.i64) }, + + .{ .u4, mlir.IntegerType(.u4) }, + .{ .u8, mlir.IntegerType(.u8) }, + .{ .u16, mlir.IntegerType(.u16) }, + .{ .u32, mlir.IntegerType(.u32) }, + .{ .u64, mlir.IntegerType(.u64) }, + + .{ .c64, mlir.ComplexType(.c64) }, + .{ .c128, mlir.ComplexType(.c128) }, + }; + + inline for (mapping) |entry| { + const dt, const mlirT = entry; + if (mlirT.is_a_fn(mlir_type._inner)) { + return dt; + } + } + + std.debug.panic("Could not convert mlir.Type to DataType: {}", .{mlir_type}); + } +}; diff --git a/zml/module.zig b/zml/module.zig index 9650010..c6824b4 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -2,21 +2,18 @@ const std = @import("std"); const asynk = @import("async"); const dialect = @import("mlir/dialects"); -const runfiles = @import("runfiles"); +const mlir = @import("mlir"); const stdx = @import("stdx"); const xla_pb = @import("//xla:xla_proto"); const BaseExe = @import("exe.zig").BaseExe; const Buffer = @import("buffer.zig").Buffer; -const Bufferized = @import("tensor.zig").Bufferized; const meta = @import("meta.zig"); -const mlir = @import("mlir.zig"); -const Location = mlir.Location; +const mlirx = @import("mlirx.zig"); const ops = @import("ops.zig"); const pjrt = @import("pjrtx.zig"); const Platform = @import("platform.zig").Platform; const Shape = @import("shape.zig").Shape; -const ShapeOf = @import("tensor.zig").ShapeOf; const Target = @import("platform.zig").Target; const Tensor = @import("tensor.zig").Tensor; const Tracer = @import("tools/tracer.zig").Tracer; @@ -170,8 +167,8 @@ pub const CompilationContext = struct { const sharding = self._platform.sharding(); const mlir_ctx = self._mlir_ctx; - module.op().setAttributeByName("mhlo.num_replicas", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_replicas).asAttr()); - module.op().setAttributeByName("mhlo.num_partitions", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_partitions).asAttr()); + module.op().setAttributeByName("mhlo.num_replicas", .int(mlir_ctx, .i32, sharding.num_replicas)); + module.op().setAttributeByName("mhlo.num_partitions", .int(mlir_ctx, .i32, sharding.num_partitions)); const module_hash = computeModuleHash(self._platform, module); var module_dir: ?[]const u8 = null; @@ -346,7 +343,7 @@ pub const CompilationContext = struct { stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{}); const input_types = try arena.alloc(mlir.Type, tensor_count); - for (input_types, input_shapes.items) |*t, sh| t.* = mlir.ext.mlirType(mlir_ctx, sh); + for (input_types, input_shapes.items) |*t, sh| t.* = mlirx.tensorType(mlir_ctx, sh); const og_block_args = self._block_args; defer { @@ -947,7 +944,7 @@ pub fn fillMlirTypes(v: anytype, mlir_ctx: mlir.Context, types: []mlir.Type) voi var context = LocalContext{ .mlir_ctx = mlir_ctx, .types = types }; meta.visit((struct { fn cb(inner_context: *LocalContext, tensor: *const Tensor) void { - inner_context.types[inner_context.index] = mlir.ext.mlirType(inner_context.mlir_ctx, tensor.shape()); + inner_context.types[inner_context.index] = mlirx.tensorType(inner_context.mlir_ctx, tensor.shape()); inner_context.index += 1; } }).cb, &context, v); diff --git a/zml/nn/cuda.zig b/zml/nn/cuda.zig index 8652f9b..d0de654 100644 --- a/zml/nn/cuda.zig +++ b/zml/nn/cuda.zig @@ -5,7 +5,7 @@ const dialect = @import("mlir/dialects"); const Context = @import("../context.zig").Context; const DataType = @import("../dtype.zig").DataType; const Data = @import("../dtype.zig").Data; -const mlir = @import("../mlir.zig"); +const mlirx = @import("../mlirx.zig"); const module = @import("../module.zig"); const CompilationContext = module.CompilationContext; const SdpaOpts = @import("../nn.zig").SdpaOpts; @@ -130,7 +130,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { .api_version = .original, }, &.{ - mlir.ext.mlirType(mlir_ctx, q.shape()), + mlirx.tensorType(mlir_ctx, q.shape()), .tensor(&.{0}, .int(mlir_ctx, .u8)), }, loc, diff --git a/zml/ops.zig b/zml/ops.zig index 81fb0a8..e0af64e 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -1,24 +1,16 @@ const std = @import("std"); -const assert = std.debug.assert; +const mlir = @import("mlir"); const stdx = @import("stdx"); const _collectAxes = @import("tensor.zig")._collectAxes; -const buffer = @import("buffer.zig"); -const Buffer = buffer.Buffer; -const Bufferized = @import("tensor.zig").Bufferized; +const Buffer = @import("buffer.zig").Buffer; +const CompilationContext = @import("module.zig").CompilationContext; const Context = @import("context.zig").Context; -const Data = @import("dtype.zig").Data; -const DataType = @import("dtype.zig").DataType; -const helpers = @import("helpers.zig"); -const HostBuffer = @import("hostbuffer.zig").HostBuffer; const meta = @import("meta.zig"); -const mlir = @import("mlir.zig"); -const module = @import("module.zig"); -const CompilationContext = module.CompilationContext; +const mlirx = @import("mlirx.zig"); const Platform = @import("platform.zig").Platform; const Shape = @import("shape.zig").Shape; -const ShapeOf = @import("tensor.zig").ShapeOf; const Tensor = @import("tensor.zig").Tensor; const EnumLiteral = @TypeOf(.enum_literal); @@ -200,14 +192,14 @@ pub fn reduce( mlir_ctx, val, inner_ctx.broadcasting_axes[0 .. tensor.rank() - inner_ctx.n_reduced], - mlir.ext.RankedTensorType.fromShape(mlir_ctx, reduced_shape).as(mlir.Type), + mlirx.tensorType(mlir_ctx, reduced_shape), inner_ctx.loc, ); tensor.* = Tensor._result(reduced_shape, broad_val.result(0)); inner_ctx.index += 1; } }).cb, &local_context, &res); - assert(local_context.index == op.numResults()); + std.debug.assert(local_context.index == op.numResults()); return res; } @@ -248,7 +240,8 @@ pub fn reduceWindow( .{ "window_strides", .dense(ctx.mlirCtx(), .i64, opts.window_strides) }, .{ "base_dilations", .dense(ctx.mlirCtx(), .i64, opts.base_dilations) }, .{ "window_dilations", .dense(ctx.mlirCtx(), .i64, opts.window_dilations) }, - .{ "padding", .denseElements(ctx.mlirCtx(), &.{ @intCast(opts.padding.len), 2 }, .i64, opts.padding) }, + // Cast the [][2]i64 to []i64 (safe) + .{ "padding", .denseElements(ctx.mlirCtx(), &.{ @intCast(opts.padding.len), 2 }, .i64, @ptrCast(opts.padding)) }, }, .location = loc, }); @@ -609,8 +602,8 @@ pub fn sort( .result_type_inference = true, .blocks = &.{block}, .attributes = &.{ - .{ "dimension", mlir.IntegerAttribute(.i64).init(ctx.mlirCtx(), dimension).as(mlir.Attribute) }, - .{ "is_stable", mlir.BoolAttribute.init(ctx.mlirCtx(), is_stable).as(mlir.Attribute) }, + .{ "dimension", .int(ctx.mlirCtx(), .i64, dimension) }, + .{ "is_stable", .boolean(ctx.mlirCtx(), is_stable) }, }, .location = loc, }); @@ -767,7 +760,7 @@ pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base inner_ctx.index += 1; } }).cb, &context, &res); - assert(context.index == op.numResults()); + std.debug.assert(context.index == op.numResults()); return res; } @@ -817,7 +810,7 @@ pub fn triton(inputs: anytype, outputs: anytype, opts: TritonOps) [outputs.len]T var res_types: [outputs.len]mlir.Type = undefined; inline for (outputs, 0..) |output, i| { - res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output); + res_types[i] = mlirx.tensorType(ctx.mlirCtx(), output); } const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &.{ @@ -1031,7 +1024,7 @@ pub fn scatter( inner_ctx.index += 1; } }).cb, &local_context, &res); - assert(local_context.index == op.numResults()); + std.debug.assert(local_context.index == op.numResults()); return res; } @@ -1327,30 +1320,30 @@ pub fn customCall(target_name: [:0]const u8, inputs: anytype, outputs: anytype, } fn customCallInternal(target_name: [:0]const u8, inputs: []const Tensor, outputs: []const Shape, metadata: anytype, opts: CustomCallOptions) []Tensor { - const ctx = module.CompilationContext.current(); + const ctx = CompilationContext.current(); const values = ctx.allocator().alloc(mlir.Value, inputs.len) catch unreachable; ctx.extractValues(inputs, values); const res_types = ctx.allocator().alloc(mlir.Type, outputs.len) catch unreachable; for (outputs, 0..) |output, i| { - res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output); + res_types[i] = mlirx.tensorType(ctx.mlirCtx(), output); } const metadata_type_info = @typeInfo(@TypeOf(metadata)); var metadata_attributes_tuple: [metadata_type_info.@"struct".fields.len]mlir.AttrTuple = undefined; inline for (metadata_type_info.@"struct".fields, 0..) |field, i| { const attribute: mlir.Attribute = switch (@typeInfo(field.type)) { - .int, .comptime_int => mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@field(metadata, field.name))), + .int, .comptime_int => .int(ctx.mlirCtx(), .u64, @bitCast(@field(metadata, field.name))), else => @compileError("Unsupported metadata type: " ++ @typeName(field.type)), }; metadata_attributes_tuple[i] = .{ field.name, attribute }; } - const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &(.{ - .{ "pjrt_api", mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_api))) }, - .{ "pjrt_client", mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_client))) }, - } ++ metadata_attributes_tuple)); + const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &(metadata_attributes_tuple ++ [_]mlir.AttrTuple{ + .{ "pjrt_api", .int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_api))) }, + .{ "pjrt_client", .int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_client))) }, + })); const operands_layouts = ctx.allocator().alloc([]const usize, inputs.len) catch unreachable; for (inputs, 0..) |input, i| { diff --git a/zml/tensor.zig b/zml/tensor.zig index 02d7386..74bf177 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -1,20 +1,17 @@ const std = @import("std"); -const assert = std.debug.assert; -const testing = std.testing; const builtin = @import("builtin"); +const mlir = @import("mlir"); const stdx = @import("stdx"); const Buffer = @import("buffer.zig").Buffer; +const CompilationContext = @import("module.zig").CompilationContext; const Data = @import("dtype.zig").Data; const DataType = @import("dtype.zig").DataType; const HostBuffer = @import("hostbuffer.zig").HostBuffer; const Memory = @import("buffer.zig").Buffer.Memory; const meta = @import("meta.zig"); -const mlir = @import("mlir.zig"); -const Location = mlir.Location; -const module = @import("module.zig"); -const CompilationContext = module.CompilationContext; +const mlirx = @import("mlirx.zig"); const ops = @import("ops.zig"); const Platform = @import("platform.zig").Platform; const Shape = @import("shape.zig").Shape; @@ -112,12 +109,12 @@ pub const Tensor = struct { /// /// The shape is derived from the type of the mlir.Value. pub fn fromMlirValue(val: mlir.Value) Tensor { - const ranked_tensor = val.getType().as(mlir.RankedTensorType); + const ranked_tensor = val.getType().as(mlir.RankedTensorType).?; const n = ranked_tensor.getRank(); stdx.debug.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK }); - var sh: Shape = .{ ._dtype = mlir.ext.Type.toDType(ranked_tensor.getElementType()) }; + var sh: Shape = .{ ._dtype = mlirx.Type.toDType(ranked_tensor.getElementType()) }; for (0..n) |i| { sh._dims.appendAssumeCapacity(ranked_tensor.getDimension(i)); } @@ -322,7 +319,7 @@ pub const Tensor = struct { const op = dialect.stablehlo.bitcast_convert( self.getContext().mlirCtx(), self.value(), - mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type), + mlirx.tensorType(self.getContext().mlirCtx(), res_shape), loc, ); @@ -559,8 +556,8 @@ pub const Tensor = struct { ctx.mlirCtx(), self.algorithm, self._state.value(), - mlir.ext.mlirType(ctx.mlirCtx(), self._state._shape), - mlir.ext.mlirType(ctx.mlirCtx(), sh), + mlirx.tensorType(ctx.mlirCtx(), self._state._shape), + mlirx.tensorType(ctx.mlirCtx(), sh), loc, ); return .{ self.update(op.result(0)), _result(sh, op.result(1)) }; @@ -870,7 +867,7 @@ pub const Tensor = struct { self.value(), other.value(), used_opts, - mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape).as(mlir.Type), + mlirx.tensorType(self.getContext().mlirCtx(), new_shape), loc, ); @@ -1052,7 +1049,7 @@ pub const Tensor = struct { const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) }); const mlir_ctx = self.getContext().mlirCtx(); - const res_type = mlir.ext.mlirType(mlir_ctx, self.shape().withDtype(to)); + const res_type = mlirx.tensorType(mlir_ctx, self.shape().withDtype(to)); const op = dialect.stablehlo.convert(mlir_ctx, self.value(), res_type, loc); return _result(self._shape.withDtype(to), op.result(0)); } @@ -1217,7 +1214,7 @@ pub const Tensor = struct { mlir_ctx, lhs.value(), rhs.value(), - mlir.ext.mlirType(mlir_ctx, res_shape), + mlirx.tensorType(mlir_ctx, res_shape), loc, .{ .lhs_batching_dimensions = lhs_batching_axes.constSlice(), @@ -1392,7 +1389,7 @@ pub const Tensor = struct { [2][5]f32{ .{ 0, 1, 1, 0, 1 }, .{ 3, 1, 0, 2, 1 } }, ); const res = try zml.testing.compileAndCall(platform, Local._cumsum, .{x}); - try testing.expectEqual( + try std.testing.expectEqual( [2][5]f32{ .{ 0, 1, 2, 2, 3 }, .{ 3, 4, 4, 6, 7 } }, try res.getValue([2][5]f32), ); @@ -1424,7 +1421,7 @@ pub const Tensor = struct { const op = dialect.stablehlo.transpose( self.getContext().mlirCtx(), self.value(), - mlir.ext.mlirType(self.getContext().mlirCtx(), res_shape), + mlirx.tensorType(self.getContext().mlirCtx(), res_shape), loc, .{ .permutation = toI64(permutation) }, ); @@ -1457,7 +1454,7 @@ pub const Tensor = struct { const reshaped_val = dialect.stablehlo.reshape( self.getContext().mlirCtx(), self.value(), - mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape), + mlirx.tensorType(self.getContext().mlirCtx(), new_shape), loc, ); return _result(new_shape, reshaped_val.result(0)); @@ -1474,7 +1471,7 @@ pub const Tensor = struct { const reshaped_val = dialect.stablehlo.reshape( self.getContext().mlirCtx(), self.value(), - mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape), + mlirx.tensorType(self.getContext().mlirCtx(), new_shape), loc, ); return _result(new_shape, reshaped_val.result(0)); @@ -1512,7 +1509,7 @@ pub const Tensor = struct { const reshaped_val = dialect.stablehlo.reshape( self.getContext().mlirCtx(), self.value(), - mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape), + mlirx.tensorType(self.getContext().mlirCtx(), new_shape), loc, ); // log.debug("flatten({d}, {d}) -> {d}", .{ self.dims(), axis_, new_shape[0 .. self.rank() - 1] }); @@ -1586,7 +1583,7 @@ pub const Tensor = struct { const mlir_ctx = self.getContext().mlirCtx(); const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "slices={any}", .{slices}); - const result_type = mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type); + const result_type = mlirx.tensorType(mlir_ctx, res_shape); const slice_op = dialect.stablehlo.slice( mlir_ctx, self.value(), @@ -1620,15 +1617,15 @@ pub const Tensor = struct { { const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, 0, .{ .end = 1 } }); - try testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32)); + try std.testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32)); } { const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, 1, .{ .start = 1, .step = 2 } }); - try testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32)); + try std.testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32)); } { const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, -1, .{ .start = -2 } }); - try testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32)); + try std.testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32)); } } @@ -1838,7 +1835,7 @@ pub const Tensor = struct { const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable; const sh = Shape.init(.{n_steps}, dt); - var op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlir.ext.mlirType(ctx.mlirCtx(), sh), loc); + var op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlirx.tensorType(ctx.mlirCtx(), sh), loc); var res = _result(sh, op.result(0)); if (args.step != 1) { @@ -1868,7 +1865,7 @@ pub const Tensor = struct { var op = dialect.stablehlo.iota( mlir_ctx, a, - mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type), + mlirx.tensorType(mlir_ctx, res_shape), loc, ); return _result(res_shape, op.result(0)); @@ -1890,7 +1887,7 @@ pub const Tensor = struct { const loc = ctx.location(@src(), "linspace({}, dtype={})", .{ args, dt }); const sh = Shape.init(.{args.steps}, dt); - var iota_op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlir.ext.mlirType(ctx.mlirCtx(), sh), loc); + var iota_op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlirx.tensorType(ctx.mlirCtx(), sh), loc); var res = _result(sh, iota_op.result(0)); if (args.steps != 1) { @@ -1933,21 +1930,19 @@ pub const Tensor = struct { /// Returns a constant Tensor with the given value. pub fn constant(dimz: anytype, val: Data) Tensor { const sh = Shape.init(dimz, val.dtype()); - const singleton_sh = Shape.init(.{}, val.dtype()); const ctx = CompilationContext.current().mlirCtx(); const loc = CompilationContext.current().location(@src(), "dims={d}, value={}", .{ sh, val }); - const res_type = mlir.ext.RankedTensorType.fromShape(ctx, singleton_sh); - var constant_op = if (mlir.ext.denseElementAttrType(val.dtype())) |elem_type| - dialect.stablehlo.constant(ctx, res_type, elem_type, val.constSlice(), loc) + var constant_op = if (mlirx.denseElementAttrType(val.dtype())) |elem_type| + dialect.stablehlo.constant(ctx, &.{}, elem_type, val.constSlice(), loc) else blk: { // Not all dtype can be serialized in the IR. If that's not possible, use f32. const val_f32 = val.as(f32); - break :blk dialect.stablehlo.constant(ctx, res_type, .f32, std.mem.asBytes(&val_f32), loc); + break :blk dialect.stablehlo.constant(ctx, &.{}, .f32, std.mem.asBytes(&val_f32), loc); }; if (sh.rank() > 0) { - constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type), loc); + constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlirx.tensorType(ctx, sh), loc); } return _result(sh, constant_op.result(0)).convert(val.dtype()); } @@ -1955,10 +1950,9 @@ pub const Tensor = struct { /// Embeds a buffer with concrete values into an Mlir program. pub fn constantTensor(val: HostBuffer) Tensor { const ctx = CompilationContext.current().mlirCtx(); - const result_type = mlir.ext.RankedTensorType.fromShape(ctx, val.shape()); const loc = ctx.location(@src()); - const elem_type = mlir.ext.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()}); - const constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.bytes(), loc); + const elem_type = mlirx.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()}); + const constant_op = dialect.stablehlo.constant(ctx, val.shape().dims(), elem_type, val.bytes(), loc); return _result(val.shape(), constant_op.result(0)); } @@ -1994,7 +1988,7 @@ pub const Tensor = struct { return _result(res_shape, self.value()); } const ctx = self.getContext(); - const result_type = mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type); + const result_type = mlirx.tensorType(ctx.mlirCtx(), res_shape); const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ }); const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc); @@ -2052,7 +2046,7 @@ pub const Tensor = struct { /// Reshapes the input Tensor with the given shape. pub fn reshape(self: Tensor, output_shape_: anytype) Tensor { const output_shape = self._shape.reshape(output_shape_); - const tensor_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), output_shape); + const tensor_type = mlirx.tensorType(self.getContext().mlirCtx(), output_shape); const loc = self.getContext().location(@src(), "reshape({any})", .{output_shape}); const reshape_value = dialect.stablehlo.reshape(self.getContext().mlirCtx(), self.value(), tensor_type, loc); return _result(output_shape, reshape_value.result(0)); @@ -2846,9 +2840,9 @@ pub const Tensor = struct { const res = argmax.call(.{x}); const max_ = res.values.getValue(f32); const max_idx = res.indices.getValue(i32); - try testing.expectEqual(max_, 7.9); + try std.testing.expectEqual(max_, 7.9); // We should always return the first max found. - try testing.expectEqual(max_idx, 2); + try std.testing.expectEqual(max_idx, 2); } // Test with Nan @@ -2857,8 +2851,8 @@ pub const Tensor = struct { const res = argmax.call(.{x}); const max_ = try res.values.getValue(f32); const max_idx = try res.indices.getValue(i32); - try testing.expect(std.math.isNan(max_)); - try testing.expectEqual(max_idx, 1); + try std.testing.expect(std.math.isNan(max_)); + try std.testing.expectEqual(max_idx, 1); } } @@ -2907,7 +2901,7 @@ pub const Tensor = struct { const x = try zml.Buffer.fromSlice(platform, .{ 2, 5 }, &[_]f32{ -0.9264, 0.7156, 1.0202, 0.3992, 1.2349, 1.0003, -0.1932, 1.3935, 0.7316, 0.0851 }); const res = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 1, .{} }); const res_cpu = try res.toHostAlloc(allocator); - try testing.expectEqualSlices(i32, &.{ 0, 3, 1, 2, 4, 1, 4, 3, 0, 2 }, res_cpu.items(i32)); + try std.testing.expectEqualSlices(i32, &.{ 0, 3, 1, 2, 4, 1, 4, 3, 0, 2 }, res_cpu.items(i32)); } // 3D Tensor, dim = 1, descending { @@ -2920,7 +2914,7 @@ pub const Tensor = struct { }); const res_dev = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 1, .{ .descending = true } }); const res = try res_dev.toHostAlloc(allocator); - try testing.expectEqualSlices(i32, &.{ + try std.testing.expectEqualSlices(i32, &.{ 4, 1, 1, 2, 0, 2, 0, 0, 3, 4, 2, 0, 4, 4, 1, 3, 4, 4, 1, 0, 1, 4, 2, 0, 2, 4, 2, 2, 0, 3, @@ -2942,7 +2936,7 @@ pub const Tensor = struct { }); const res_dev = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 3, .{} }); const res = try res_dev.toHostAlloc(allocator); - try testing.expectEqualSlices(i32, &.{ + try std.testing.expectEqualSlices(i32, &.{ 2, 1, 3, 0, 2, 3, 1, 0, 3, 2, 0, 1, @@ -3262,7 +3256,7 @@ pub const Tensor = struct { const z = try zml.Buffer.scalar(platform, 4, .i32); const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 0, .{ .len = 2, .start = z } }); - try testing.expectEqual([2]T{ 4, 5 }, try res.getValue([2]T)); + try std.testing.expectEqual([2]T{ 4, 5 }, try res.getValue([2]T)); } { @@ -3271,7 +3265,7 @@ pub const Tensor = struct { const z = try zml.Buffer.scalar(platform, 3, .i32); const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 1, .{ .len = 2, .start = z } }); - try testing.expectEqual([4]T{ 3, 4, 8, 9 }, res.getValue([4]T)); + try std.testing.expectEqual([4]T{ 3, 4, 8, 9 }, res.getValue([4]T)); } } @@ -3389,7 +3383,7 @@ pub const Tensor = struct { }._fwd, .{ x.withTags(.{.a}), .{ .a = idx }, y.withTags(.{.a}) }, ); - try testing.expectEqual([10]f32{ 0, 1, 2, 3, -1, -1, 6, 7, 8, 9 }, try res.getValue([10]f32)); + try std.testing.expectEqual([10]f32{ 0, 1, 2, 3, -1, -1, 6, 7, 8, 9 }, try res.getValue([10]f32)); } { @@ -3407,7 +3401,7 @@ pub const Tensor = struct { }._fwd, .{ x.withTags(.{ .a, .b }), idx, y.withTags(.{.a}) }, ); - try testing.expectEqualDeep( + try std.testing.expectEqualDeep( [2][5]f32{ .{ 0, 1, 2, -1, 4 }, .{ 5, 6, 7, -1, 9 } }, try res.getValue([2][5]f32), ); @@ -3427,7 +3421,7 @@ pub const Tensor = struct { }._fwd, .{ x, idx, y }, ); - try testing.expectEqualDeep( + try std.testing.expectEqualDeep( [2][5]f32{ .{ 0, 1, 2, -1, 4 }, .{ 5, 6, 7, -1, 9 } }, res.getValue([2][5]f32), ); @@ -3448,7 +3442,7 @@ pub const Tensor = struct { }._fwd, .{ x.withTags(.{ .a, .b }), .{ .a = idx_a, .b = idx_b }, y.withTags(.{.a}) }, ); - try testing.expectEqualDeep( + try std.testing.expectEqualDeep( [2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, -1, 9 } }, res.getValue([2][5]f32), ); @@ -3466,7 +3460,7 @@ pub const Tensor = struct { } }; const res = try zml.testing.compileAndCall(platform, A._fwd, .{ x, .{ idx_a, idx_b }, y }); - try testing.expectEqualDeep( + try std.testing.expectEqualDeep( [2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, -1, 9 } }, res.getValue([2][5]f32), ); @@ -3531,7 +3525,7 @@ pub const Tensor = struct { const x = try zml.Buffer.fromArray(platform, [2][2]u8{ .{ 1, 2 }, .{ 3, 4 } }); { const res = try zml.testing.compileAndCall(platform, Local._toDiag, .{x}); - try testing.expectEqual( + try std.testing.expectEqual( [2][2][2]u8{ .{ .{ 1, 0 }, .{ 0, 2 }, @@ -3582,7 +3576,7 @@ pub const Tensor = struct { }); { const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, 0 }); - try testing.expectEqual( + try std.testing.expectEqual( [3][3]u8{ .{ 1, 0, 0 }, .{ 1, 1, 0 }, @@ -3593,7 +3587,7 @@ pub const Tensor = struct { } { const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, 1 }); - try testing.expectEqual( + try std.testing.expectEqual( [3][3]u8{ .{ 1, 1, 0 }, .{ 1, 1, 1 }, @@ -3604,7 +3598,7 @@ pub const Tensor = struct { } { const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, -1 }); - try testing.expectEqual( + try std.testing.expectEqual( [3][3]u8{ .{ 0, 0, 0 }, .{ 1, 0, 0 }, diff --git a/zml/zml.zig b/zml/zml.zig index 18282b9..bcb86db 100644 --- a/zml/zml.zig +++ b/zml/zml.zig @@ -25,7 +25,7 @@ pub const nn = @import("nn.zig"); pub const module = @import("module.zig"); pub const meta = @import("meta.zig"); pub const platform = @import("platform.zig"); -pub const mlir = @import("mlir.zig"); +pub const mlir = @import("mlirx.zig"); pub const pjrt = @import("pjrtx.zig"); pub const testing = @import("testing.zig"); pub const torch = @import("torch.zig");