diff --git a/mlir/BUILD.bazel b/mlir/BUILD.bazel index 327fcc3..0bebfa7 100644 --- a/mlir/BUILD.bazel +++ b/mlir/BUILD.bazel @@ -1,16 +1,6 @@ load("@rules_zig//zig:defs.bzl", "zig_library") load("//bazel:zig.bzl", "zig_cc_test") -cc_library( - name = "mlirx", - srcs = ["mlirx.cc"], - hdrs = ["mlirx.h"], - includes = ["."], - deps = [ - "@llvm-project//mlir:CAPIIR", - ], -) - cc_library( name = "c", hdrs = ["c.h"], @@ -30,7 +20,7 @@ zig_library( visibility = ["//visibility:public"], deps = [ ":c", - ":mlirx", + "//stdx", ], ) diff --git a/mlir/dialects/arith.zig b/mlir/dialects/arith.zig index 2c9e945..c005f6a 100644 --- a/mlir/dialects/arith.zig +++ b/mlir/dialects/arith.zig @@ -72,7 +72,7 @@ pub fn cmpi(ctx: mlir.Context, predicate: CmpIPredicate, lhs: mlir.Value, rhs: m .operands = &.{ lhs, rhs }, .result_type_inference = true, .attributes = &.{ - .{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute).? }, + .{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute) }, }, .location = location, }); @@ -102,7 +102,7 @@ pub fn cmpf(ctx: mlir.Context, predicate: CmpFPredicate, lhs: mlir.Value, rhs: m .operands = &.{ lhs, rhs }, .result_type_inference = true, .attributes = &.{ - .{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute).? }, + .{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute) }, }, .location = location, }); diff --git a/mlir/dialects/func.zig b/mlir/dialects/func.zig index d947fed..ec31e55 100644 --- a/mlir/dialects/func.zig +++ b/mlir/dialects/func.zig @@ -14,14 +14,14 @@ pub fn func( }, ) mlir.Operation { var attrs_tuple_buffer = std.BoundedArray(mlir.AttrTuple, 4){}; - attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute).? }); - attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type).?).as(mlir.Attribute).? }); + attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute) }); + attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type)).as(mlir.Attribute) }); if (args.arg_attrs.len > 0) { - attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", mlir.ArrayAttribute.init(ctx, args.arg_attrs).as(mlir.Attribute).? }); + attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", mlir.ArrayAttribute.init(ctx, args.arg_attrs).as(mlir.Attribute) }); } if (args.res_attrs.len > 0) { - attrs_tuple_buffer.appendAssumeCapacity(.{ "res_attrs", mlir.ArrayAttribute.init(ctx, args.res_attrs).as(mlir.Attribute).? }); + attrs_tuple_buffer.appendAssumeCapacity(.{ "res_attrs", mlir.ArrayAttribute.init(ctx, args.res_attrs).as(mlir.Attribute) }); } return mlir.Operation.make(ctx, "func.func", .{ @@ -36,7 +36,7 @@ pub fn call(ctx: mlir.Context, name: [:0]const u8, values: []const mlir.Value, r .variadic_operands = &.{values}, .results = results, .verify = true, - .attributes = &.{.{ "callee", mlir.FlatSymbolRefAttribute.init(ctx, name).as(mlir.Attribute).? }}, + .attributes = &.{.{ "callee", mlir.FlatSymbolRefAttribute.init(ctx, name).as(mlir.Attribute) }}, .location = loc, }); } diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig index d7d8673..4311de3 100644 --- a/mlir/dialects/stablehlo.zig +++ b/mlir/dialects/stablehlo.zig @@ -98,7 +98,7 @@ pub fn cholesky(ctx: mlir.Context, value: mlir.Value, lower: bool, location: mli .operands = &.{value}, .result_type_inference = true, .attributes = &.{ - .{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(lower))).as(mlir.Attribute).? }, + .{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(lower))).as(mlir.Attribute) }, }, .location = location, }); @@ -126,7 +126,7 @@ pub const DotPrecision = union(enum) { // When we specify the dot algorithm, we should not specify the precision. .algorithm => .DEFAULT, }); - return precision.as(mlir.Attribute).?; + return precision.as(mlir.Attribute); } pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.Type) ?mlir.Attribute { @@ -156,14 +156,14 @@ pub const DotAlgorithm = struct { }; pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, operand_type: mlir.Type) mlir.Attribute { - const tensor_type = operand_type.as(mlir.RankedTensorType) orelse @panic("dot_general expects RankedTensor as input"); + const tensor_type = operand_type.as(mlir.RankedTensorType); const elem_type = tensor_type.getElementType(); return mlir.Attribute.wrap(c.stablehloDotAlgorithmGet( ctx.inner(), elem_type.inner(), elem_type.inner(), - self.accumulation.asType(ctx).?.inner(), + self.accumulation.asType(ctx).inner(), self.component_count, self.component_count, self.num_primitive_operations, @@ -196,7 +196,7 @@ pub fn dot_general( .rhs_batching_dimensions = opts.rhs_batching_dimensions, .lhs_contracting_dimensions = opts.lhs_contracting_dimensions, .rhs_contracting_dimensions = opts.rhs_contracting_dimensions, - }).as(mlir.Attribute).?, + }).as(mlir.Attribute), }, .{ "precision_config", mlir.ArrayAttribute.init(ctx, &precisions).asAttr() }, // keep algorithm as the last attribute so we can omit it when it's not set. @@ -219,12 +219,12 @@ pub fn constant( location: mlir.Location, ) mlir.Operation { const attribute = switch (elem_type) { - inline else => |dt| mlir.DenseIntOrFPElementsAttribute(dt).init(result_type.as(mlir.Type).?, raw_bytes).as(mlir.Attribute).?, + inline else => |dt| mlir.DenseElementsAttribute(dt).init(result_type.as(mlir.Type), raw_bytes).as(mlir.Attribute), }; return mlir.Operation.make(ctx, "stablehlo.constant", .{ .operands = &.{}, - .results = &.{result_type.as(mlir.Type).?}, + .results = &.{result_type.as(mlir.Type)}, .attributes = &.{.{ "value", attribute }}, .location = location, }); @@ -243,7 +243,7 @@ pub fn broadcast_in_dim(ctx: mlir.Context, operand: mlir.Value, dims: []const i6 .operands = &.{operand}, .results = &.{result_type}, .attributes = &.{ - .{ "broadcast_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dims).as(mlir.Attribute).? }, + .{ "broadcast_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dims).as(mlir.Attribute) }, }, .location = location, }); @@ -254,7 +254,7 @@ pub fn transpose(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, l .operands = &.{value}, .results = &.{result_type}, .attributes = &.{ - .{ "permutation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.permutation).as(mlir.Attribute).? }, + .{ "permutation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.permutation).as(mlir.Attribute) }, }, .location = location, }); @@ -265,9 +265,9 @@ pub fn slice(ctx: mlir.Context, operand: mlir.Value, start_indices: []const i64, .operands = &.{operand}, .results = &.{result_type}, .attributes = &.{ - .{ "start_indices", mlir.DenseArrayAttribute(.i64).init(ctx, start_indices).as(mlir.Attribute).? }, - .{ "limit_indices", mlir.DenseArrayAttribute(.i64).init(ctx, limit_indices).as(mlir.Attribute).? }, - .{ "strides", mlir.DenseArrayAttribute(.i64).init(ctx, strides).as(mlir.Attribute).? }, + .{ "start_indices", mlir.DenseArrayAttribute(.i64).init(ctx, start_indices).as(mlir.Attribute) }, + .{ "limit_indices", mlir.DenseArrayAttribute(.i64).init(ctx, limit_indices).as(mlir.Attribute) }, + .{ "strides", mlir.DenseArrayAttribute(.i64).init(ctx, strides).as(mlir.Attribute) }, }, .location = location, }); @@ -278,7 +278,7 @@ pub fn concatenate(ctx: mlir.Context, inputs: []const mlir.Value, dimension: i64 .operands = inputs, .result_type_inference = true, .attributes = &.{ - .{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute).? }, + .{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) }, }, .location = location, }); @@ -287,7 +287,7 @@ pub fn concatenate(ctx: mlir.Context, inputs: []const mlir.Value, dimension: i64 pub fn reshape(ctx: mlir.Context, value: mlir.Value, result_type: mlir.RankedTensorType, location: mlir.Location) mlir.Operation { return mlir.Operation.make(ctx, "stablehlo.reshape", .{ .operands = &.{value}, - .results = &.{result_type.as(mlir.Type).?}, + .results = &.{result_type.as(mlir.Type)}, .location = location, }); } @@ -331,9 +331,9 @@ pub fn gather( args.start_indices_batching_dims, args.start_index_map, args.index_vector_dim, - ).as(mlir.Attribute).? }, - .{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, slice_sizes).as(mlir.Attribute).? }, - .{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute).? }, + ).as(mlir.Attribute) }, + .{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, slice_sizes).as(mlir.Attribute) }, + .{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute) }, }, .location = location, }, @@ -393,8 +393,8 @@ pub fn scatter( .blocks = &.{update_block}, .attributes = &.{ .{ "scatter_dimension_numbers", args.getScatterDimensionNumbers(ctx) }, - .{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute).? }, - .{ "unique_indices", mlir.BoolAttribute.init(ctx, args.unique_indices).as(mlir.Attribute).? }, + .{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute) }, + .{ "unique_indices", mlir.BoolAttribute.init(ctx, args.unique_indices).as(mlir.Attribute) }, }, .result_type_inference = true, .location = location, @@ -407,7 +407,7 @@ pub fn iota(ctx: mlir.Context, dimension: i64, result_type: mlir.Type, location: .operands = &.{}, .results = &.{result_type}, .attributes = &.{ - .{ "iota_dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute).? }, + .{ "iota_dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) }, }, .location = location, }); @@ -419,7 +419,7 @@ pub fn reverse(ctx: mlir.Context, operand: mlir.Value, dimensions: []const i64, .operands = &.{operand}, .results = &.{result_type}, .attributes = &.{ - .{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute).? }, + .{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute) }, }, .location = location, }); @@ -430,8 +430,8 @@ pub fn compare(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, comparison_d .operands = &.{ lhs, rhs }, .result_type_inference = true, .attributes = &.{ - .{ "comparison_direction", comparison_direction.as(mlir.Attribute).? }, - .{ "compare_type", compare_type.as(mlir.Attribute).? }, + .{ "comparison_direction", comparison_direction.as(mlir.Attribute) }, + .{ "compare_type", compare_type.as(mlir.Attribute) }, }, .location = location, }); @@ -452,7 +452,7 @@ pub fn reduce( const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0..block_n_args]; var reduce_elem_types: [MaxBlockArguments]mlir.Type = undefined; for (inputs, 0..) |input, i| { - const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type).?; + const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type); reduce_elem_types[i] = arg_type; reduce_elem_types[inputs.len + i] = arg_type; } @@ -474,7 +474,7 @@ pub fn reduce( .result_type_inference = true, .block = block, .attributes = &.{ - .{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute).? }, + .{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute) }, }, .location = location, }); @@ -494,7 +494,7 @@ pub fn sort( const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0 .. inputs.len * 2]; var sort_elem_types: [MaxBlockArguments]mlir.Type = undefined; for (inputs, 0..) |input, i| { - const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type).?; + const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type); sort_elem_types[i * 2] = arg_type; sort_elem_types[i * 2 + 1] = arg_type; } @@ -511,8 +511,8 @@ pub fn sort( .result_type_inference = true, .block = block, .attributes = &.{ - .{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute).? }, - .{ "is_stable", mlir.BoolAttribute.init(ctx, is_stable).as(mlir.Attribute).? }, + .{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) }, + .{ "is_stable", mlir.BoolAttribute.init(ctx, is_stable).as(mlir.Attribute) }, }, .location = location, }); @@ -523,7 +523,7 @@ pub fn dynamicSlice(ctx: mlir.Context, operand: mlir.Value, new_dims: []const i6 .variadic_operands = &.{ &.{operand}, start_indices }, .result_type_inference = true, .attributes = &.{ - .{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, new_dims).as(mlir.Attribute).? }, + .{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, new_dims).as(mlir.Attribute) }, }, .location = location, }); @@ -556,9 +556,9 @@ pub fn pad(ctx: mlir.Context, value: mlir.Value, padding_value: mlir.Value, opts .operands = &.{ value, padding_value }, .result_type_inference = true, .attributes = &.{ - .{ "edge_padding_low", mlir.DenseArrayAttribute(.i64).init(ctx, opts.low).as(mlir.Attribute).? }, - .{ "edge_padding_high", mlir.DenseArrayAttribute(.i64).init(ctx, opts.high).as(mlir.Attribute).? }, - .{ "interior_padding", mlir.DenseArrayAttribute(.i64).init(ctx, opts.interior).as(mlir.Attribute).? }, + .{ "edge_padding_low", mlir.DenseArrayAttribute(.i64).init(ctx, opts.low).as(mlir.Attribute) }, + .{ "edge_padding_high", mlir.DenseArrayAttribute(.i64).init(ctx, opts.high).as(mlir.Attribute) }, + .{ "interior_padding", mlir.DenseArrayAttribute(.i64).init(ctx, opts.interior).as(mlir.Attribute) }, }, .location = location, }); @@ -576,10 +576,10 @@ pub fn triangular_solve(ctx: mlir.Context, value: mlir.Value, other: mlir.Value, .operands = &.{ value, other }, .result_type_inference = true, .attributes = &.{ - .{ "left_side", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.left_side))).as(mlir.Attribute).? }, - .{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.lower))).as(mlir.Attribute).? }, - .{ "unit_diagonal", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.unit_diagonal))).as(mlir.Attribute).? }, - .{ "transpose_a", Transpose.init(ctx, opts.transpose_a).as(mlir.Attribute).? }, + .{ "left_side", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.left_side))).as(mlir.Attribute) }, + .{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.lower))).as(mlir.Attribute) }, + .{ "unit_diagonal", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.unit_diagonal))).as(mlir.Attribute) }, + .{ "transpose_a", Transpose.init(ctx, opts.transpose_a).as(mlir.Attribute) }, }, .location = location, }); @@ -595,8 +595,8 @@ pub fn fft(ctx: mlir.Context, value: mlir.Value, location: mlir.Location, opts: .operands = &.{value}, .result_type_inference = true, .attributes = &.{ - .{ "fft_type", FftType.init(ctx, opts.kind).as(mlir.Attribute).? }, - .{ "fft_length", mlir.DenseArrayAttribute(.i64).init(ctx, opts.length).as(mlir.Attribute).? }, + .{ "fft_type", FftType.init(ctx, opts.kind).as(mlir.Attribute) }, + .{ "fft_length", mlir.DenseArrayAttribute(.i64).init(ctx, opts.length).as(mlir.Attribute) }, }, .location = location, }); @@ -607,7 +607,7 @@ pub fn rng(ctx: mlir.Context, a: mlir.Value, b: mlir.Value, shape: mlir.Value, r .operands = &.{ a, b, shape }, .result_type_inference = true, .attributes = &.{ - .{ "rng_distribution", RngDistribution.init(ctx, rng_distribution).as(mlir.Attribute).? }, + .{ "rng_distribution", RngDistribution.init(ctx, rng_distribution).as(mlir.Attribute) }, }, .location = location, }); @@ -618,7 +618,7 @@ pub fn rng_bit_generator(ctx: mlir.Context, rng_algorithm: RngAlgorithm.Type, in .operands = &.{initial_state}, .results = &.{ res_state_type, res_type }, .attributes = &.{ - .{ "rng_algorithm", RngAlgorithm.init(ctx, rng_algorithm).as(mlir.Attribute).? }, + .{ "rng_algorithm", RngAlgorithm.init(ctx, rng_algorithm).as(mlir.Attribute) }, }, .location = location, }); @@ -629,8 +629,8 @@ pub fn reduce_precision(ctx: mlir.Context, value: mlir.Value, exponent_bits: i32 .operands = &.{value}, .result_type_inference = true, .attributes = &.{ - .{ "exponent_bits", mlir.IntegerAttribute(.i32).init(ctx, exponent_bits).as(mlir.Attribute).? }, - .{ "mantissa_bits", mlir.IntegerAttribute(.i32).init(ctx, mantissa_bits).as(mlir.Attribute).? }, + .{ "exponent_bits", mlir.IntegerAttribute(.i32).init(ctx, exponent_bits).as(mlir.Attribute) }, + .{ "mantissa_bits", mlir.IntegerAttribute(.i32).init(ctx, mantissa_bits).as(mlir.Attribute) }, }, .location = location, }); @@ -657,7 +657,7 @@ pub fn get_tuple_element(ctx: mlir.Context, tuple_value: mlir.Value, index: i64, .operands = &.{tuple_value}, .result_type_inference = true, .attributes = &.{ - .{ "index", mlir.IntegerAttribute(.i32).init(ctx, index).as(mlir.Attribute).? }, + .{ "index", mlir.IntegerAttribute(.i32).init(ctx, index).as(mlir.Attribute) }, }, .location = location, }); @@ -694,23 +694,23 @@ pub fn convolution( ) mlir.Operation { var max_precisions: [2]mlir.Attribute = undefined; for (opts.precision_config, 0..) |p, i| { - max_precisions[i] = PrecisionAttribute.init(ctx, p).as(mlir.Attribute).?; + max_precisions[i] = PrecisionAttribute.init(ctx, p).as(mlir.Attribute); } var window_reversal: [3]i32 = undefined; for (opts.window_reversal, 0..) |w, i| { window_reversal[i] = @intCast(@intFromBool(w)); } - const pad_type = mlir.IntegerType(.i64).init(ctx).as(mlir.Type).?; - const pad_shape = mlir.RankedTensorType.init(opts.pad_shape, pad_type).as(mlir.Type).?; + const pad_type = mlir.IntegerType(.i64).init(ctx).as(mlir.Type); + const pad_shape = mlir.RankedTensorType.init(opts.pad_shape, pad_type).as(mlir.Type); return mlir.Operation.make(ctx, "stablehlo.convolution", .{ .operands = &.{ lhs, rhs }, .results = &.{res_type}, .attributes = &.{ - .{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_strides).as(mlir.Attribute).? }, - .{ "padding", mlir.DenseIntOrFPElementsAttribute(.i64).init(pad_shape, std.mem.sliceAsBytes(opts.pad_value)).as(mlir.Attribute).? }, - .{ "lhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.lhs_dilation).as(mlir.Attribute).? }, - .{ "rhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.rhs_dilation).as(mlir.Attribute).? }, - .{ "window_reversal", mlir.DenseArrayAttribute(.bool).init(ctx, window_reversal[0..opts.window_reversal.len]).as(mlir.Attribute).? }, + .{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_strides).as(mlir.Attribute) }, + .{ "padding", mlir.DenseElementsAttribute(.i64).init(pad_shape, opts.pad_value).as(mlir.Attribute) }, + .{ "lhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.lhs_dilation).as(mlir.Attribute) }, + .{ "rhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.rhs_dilation).as(mlir.Attribute) }, + .{ "window_reversal", mlir.DenseArrayAttribute(.bool).init(ctx, window_reversal[0..opts.window_reversal.len]).as(mlir.Attribute) }, .{ "dimension_numbers", ConvDimensionNumbersAttribute.init(ctx, .{ .input_batch_dimension = opts.input_batch_dimension, @@ -722,11 +722,11 @@ pub fn convolution( .output_batch_dimension = opts.output_batch_dimension, .output_feature_dimension = opts.output_feature_dimension, .output_spatial_dimensions = opts.output_spatial_dimensions, - }).as(mlir.Attribute).?, + }).as(mlir.Attribute), }, - .{ "feature_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.feature_group_count).as(mlir.Attribute).? }, - .{ "batch_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.batch_group_count).as(mlir.Attribute).? }, - .{ "precision_config", mlir.ArrayAttribute.init(ctx, &max_precisions).as(mlir.Attribute).? }, + .{ "feature_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.feature_group_count).as(mlir.Attribute) }, + .{ "batch_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.batch_group_count).as(mlir.Attribute) }, + .{ "precision_config", mlir.ArrayAttribute.init(ctx, &max_precisions).as(mlir.Attribute) }, }, .location = location, }); @@ -747,18 +747,18 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa const output_operand_aliases = allocator.alloc(mlir.Attribute, opts.output_operand_aliases.len) catch unreachable; for (opts.output_operand_aliases, 0..) |alias, i| { - output_operand_aliases[i] = OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute).?; + output_operand_aliases[i] = OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute); } return mlir.Operation.make(ctx, "stablehlo.custom_call", .{ .operands = inputs, .results = res_types, .attributes = &.{ - .{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, opts.api_version).as(mlir.Attribute).? }, - .{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute).? }, - .{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute).? }, - .{ "backend_config", mlir.StringAttribute.init(ctx, opts.backend_config).as(mlir.Attribute).? }, - .{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases).as(mlir.Attribute).? }, + .{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, opts.api_version).as(mlir.Attribute) }, + .{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) }, + .{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) }, + .{ "backend_config", mlir.StringAttribute.init(ctx, opts.backend_config).as(mlir.Attribute) }, + .{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases).as(mlir.Attribute) }, }, .location = location, }); diff --git a/mlir/mlir.zig b/mlir/mlir.zig index 35d60a7..1ccc766 100644 --- a/mlir/mlir.zig +++ b/mlir/mlir.zig @@ -1,5 +1,6 @@ const builtin = @import("builtin"); const std = @import("std"); +const stdx = @import("stdx"); const log = std.log.scoped(.mlir); const c = @import("c"); @@ -95,9 +96,10 @@ pub fn MlirHelpers(comptime OuterT: type, comptime methods: MlirHelpersMethods(O return false; } - pub inline fn as(self: OuterT, comptime OtherT: type) ?OtherT { + pub inline fn as(self: OuterT, comptime OtherT: type) OtherT { if (OtherT.Methods.is_a_fn) |is_a_fn| { - return if (is_a_fn(self.inner())) OtherT.wrap(self.inner()) else null; + stdx.debug.assert(is_a_fn(self.inner()), "Wrongly tried to cast {} into {}", .{ OuterT, OtherT }); + return OtherT.wrap(self.inner()); } // if the other type doesn't have an is_a_fn, try. return OtherT.wrap(self.inner()); @@ -425,7 +427,7 @@ pub const BoolAttribute = struct { } pub fn asAttr(self: Self) Attribute { - return self.as(Attribute).?; + return self.as(Attribute); } }; @@ -446,7 +448,7 @@ pub const TypeAttribute = struct { } pub fn asAttr(self: TypeAttribute) Attribute { - return self.as(Attribute).?; + return self.as(Attribute); } }; @@ -591,10 +593,6 @@ pub fn DenseArrayAttribute(comptime dt: DenseArrayTypes) type { .i64 => .i64, else => @compileError("DenseArrayAttribute: unreachable"), }); - - pub fn toElements(self: Attr) DenseArray { - return DenseArray.wrap(c.mlirDenseArrayToElements(self.inner())); - } }, else => struct {}, }; @@ -615,28 +613,32 @@ pub const DenseElementsAttributeTypes = enum { f16, f32, f64, + index, }; -pub fn DenseIntOrFPElementsAttribute(comptime dt: DenseElementsAttributeTypes) type { - const ZigInDataType, const ZigOutDataType, const initFn, const getValue = switch (dt) { - .bool => .{ bool, bool, c.mlirDenseElementsAttrBoolGet, c.mlirDenseElementsAttrGetBoolValue }, - .i8 => .{ i8, i8, c.mlirDenseElementsAttrInt8Get, c.mlirDenseElementsAttrGetInt8Value }, - .i16 => .{ i16, i16, c.mlirDenseElementsAttrInt16Get, c.mlirDenseElementsAttrGetInt16Value }, - .i32 => .{ i32, i32, c.mlirDenseElementsAttrInt32Get, c.mlirDenseElementsAttrGetInt32Value }, - .i64 => .{ i64, i64, c.mlirDenseElementsAttrInt64Get, c.mlirDenseElementsAttrGetInt64Value }, - .u8 => .{ u8, u8, c.mlirDenseElementsAttrUInt8Get, c.mlirDenseElementsAttrGetUInt8Value }, - .u16 => .{ u16, u16, c.mlirDenseElementsAttrUInt16Get, c.mlirDenseElementsAttrGetUInt16Value }, - .u32 => .{ u32, u32, c.mlirDenseElementsAttrUInt32Get, c.mlirDenseElementsAttrGetUInt32Value }, - .u64 => .{ u64, u64, c.mlirDenseElementsAttrUInt64Get, c.mlirDenseElementsAttrGetUInt64Value }, - .bf16 => .{ u16, f32, c.mlirDenseElementsAttrBFloat16Get, c.mlirDenseElementsAttrGetFloatValue }, - .f16 => .{ f16, f32, c.mlirDenseElementsAttrFloat16Get, c.mlirDenseElementsAttrGetFloatValue }, - .f32 => .{ f32, f32, c.mlirDenseElementsAttrFloatGet, c.mlirDenseElementsAttrGetFloatValue }, - .f64 => .{ f64, f64, c.mlirDenseElementsAttrDoubleGet, c.mlirDenseElementsAttrGetDoubleValue }, +pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type { + const ZigType = switch (dt) { + .bool => bool, + .i8 => i8, + .i16 => i16, + .i32 => i32, + .i64 => i64, + .u8 => u8, + .u16 => u16, + .u32 => u32, + .u64 => u64, + .bf16 => u16, + .f16 => f16, + .f32 => f32, + .f64 => f64, + .index => usize, }; return struct { _inner: c.MlirAttribute, + const Attr = @This(); + pub usingnamespace MlirHelpers(Attr, .{ .is_a_fn = c.mlirAttributeIsADenseElements, .is_null_fn = c.mlirAttributeIsNull, @@ -644,13 +646,29 @@ pub fn DenseIntOrFPElementsAttribute(comptime dt: DenseElementsAttributeTypes) t .equal_fn = c.mlirAttributeEqual, }); - pub fn init(shaped_type: Type, raw_values: []const u8) Attr { - const values = std.mem.bytesAsSlice(ZigInDataType, raw_values); - return Attr.wrap(initFn(shaped_type.inner(), @intCast(values.len), @ptrCast(@alignCast(values.ptr)))); + pub fn init(shaped_type: Type, slice: anytype) Attr { + const bytes = std.mem.sliceAsBytes(slice); + const v = Attr.wrapOr( + c.mlirDenseElementsAttrRawBufferGet( + shaped_type.inner(), + @intCast(bytes.len), + @ptrCast(bytes.ptr), + ), + ) orelse unreachable; + return v; } - pub fn get(self: Attr, pos: usize) ZigOutDataType { - return getValue(self.inner(), @intCast(pos)); + pub fn len(self: Attr) usize { + return @intCast(c.mlirElementsAttrGetNumElements(self.inner())); + } + + pub fn constSlice(self: Attr) []const ZigType { + const ptr: [*]const ZigType = @constCast(@ptrCast(@alignCast(c.mlirDenseElementsAttrGetRawData(self.inner()) orelse unreachable))); + return ptr[0..self.len()]; + } + + pub fn data(self: Attr) []const u8 { + return std.mem.sliceAsBytes(self.constSlice()); } }; } @@ -1338,12 +1356,9 @@ pub const FloatTypes = enum { f32, f64, - unknown, - - pub fn asType(self: FloatTypes, ctx: Context) ?Type { + pub fn asType(self: FloatTypes, ctx: Context) Type { return switch (self) { - .unknown => null, - inline else => |ft| FloatType(ft).init(ctx).asType(), + inline else => |ft| FloatType(ft).init(ctx).as(Type), }; } }; @@ -1359,48 +1374,23 @@ pub fn FloatType(comptime ft: FloatTypes) type { .f16 => .{ c.mlirTypeIsAF16, c.mlirF16TypeGet }, .f32 => .{ c.mlirTypeIsAF32, c.mlirF32TypeGet }, .f64 => .{ c.mlirTypeIsAF64, c.mlirF64TypeGet }, - .unknown => .{ null, null }, }; return struct { _inner: c.MlirType, - const Float = @This(); - pub usingnamespace MlirHelpers(Float, .{ - .is_a_fn = switch (ft) { - .unknown => typeIsAUnknownFloat, - else => Config[0], - }, + + const Self = @This(); + + pub usingnamespace MlirHelpers(Self, .{ + .is_a_fn = Config[0], .is_null_fn = c.mlirTypeIsNull, .dump_fn = c.mlirTypeDump, .equal_fn = c.mlirTypeEqual, }); - pub usingnamespace if (ft != .unknown) struct { - pub const FloatTypeType = ft; - - pub fn init(ctx: Context) Float { - const type_get = Config[1]; - return Float.wrap(type_get(ctx.inner())); - } - } else struct {}; - - fn typeIsAUnknownFloat(typ: c.MlirType) callconv(.C) bool { - const is_a_fns = .{ - c.mlirTypeIsABF16, - c.mlirTypeIsAF16, - c.mlirTypeIsAF32, - c.mlirTypeIsF64, - }; - inline for (is_a_fns) |is_a_fn| { - if (is_a_fn(typ)) { - return true; - } - } - return false; - } - - pub fn asType(self: Float) Type { - return self.as(Type).?; + pub fn init(ctx: Context) Self { + const type_get = Config[1]; + return Self.wrap(type_get(ctx.inner())); } }; } @@ -1545,10 +1535,6 @@ pub const RankedTensorType = struct { pub fn getDimension(self: RankedTensorType, dim: usize) i64 { return c.mlirShapedTypeGetDimSize(self.inner(), @intCast(dim)); } - - pub fn asType(self: RankedTensorType) Type { - return self.as(Type).?; - } }; pub const Dialect = struct { diff --git a/mlir/mlirx.cc b/mlir/mlirx.cc deleted file mode 100644 index f5fd45f..0000000 --- a/mlir/mlirx.cc +++ /dev/null @@ -1,27 +0,0 @@ -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/CAPI/IR.h" -#include "mlir/CAPI/Support.h" -#include "mlirx.h" - -namespace mlirx { - - static mlir::Attribute ArrayToElements(mlir::Attribute attr) { - if (auto array = attr.dyn_cast()) { - return mlir::DenseIntElementsAttr::get( - mlir::RankedTensorType::get(array.size(), array.getElementType()), - array.asArrayRef()); - } - if (auto array = attr.dyn_cast()) { - return mlir::DenseIntElementsAttr::get( - mlir::RankedTensorType::get(array.size(), array.getElementType()), - array.asArrayRef()); - } - return attr; - } - -} - -MlirAttribute mlirDenseArrayToElements(MlirAttribute attr) { - return wrap(mlirx::ArrayToElements(unwrap(attr))); -} diff --git a/mlir/mlirx.h b/mlir/mlirx.h deleted file mode 100644 index 9946477..0000000 --- a/mlir/mlirx.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef MLIRX_CC_H -#define MLIRX_CC_H - -#include "mlir-c/IR.h" - -#ifdef __cplusplus -extern "C" { -#endif - -MLIR_CAPI_EXPORTED MlirAttribute mlirDenseArrayToElements(MlirAttribute attr); - -#ifdef __cplusplus -} -#endif - -#endif // MLIRX_CC_H diff --git a/zml/mlir.zig b/zml/mlir.zig index 38cb104..0709189 100644 --- a/zml/mlir.zig +++ b/zml/mlir.zig @@ -15,7 +15,7 @@ pub usingnamespace @import("mlir"); pub const ext = struct { pub fn mlirType(ctx: mlir.Context, sh: Shape) mlir.Type { - return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype())).as(mlir.Type).?; + return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype())).as(mlir.Type); } pub fn denseElementAttrType(dt: dtype.DataType) ?mlir.DenseElementsAttributeTypes { @@ -38,21 +38,21 @@ pub const ext = struct { } pub fn denseElementsAttr(dt: dtype.DataType, _: usize, bytes: []const u8, ranked_type: mlir.RankedTensorType) mlir.Attribute { - const ranked_type_ = ranked_type.as(mlir.Type).?; + const ranked_type_ = ranked_type.as(mlir.Type); return switch (dt) { - .bool => mlir.DenseIntOrFPElementsAttribute(.bool).init(ranked_type_, bytes).as(mlir.Attribute).?, - .i8 => mlir.DenseIntOrFPElementsAttribute(.i8).init(ranked_type_, bytes).as(mlir.Attribute).?, - .i16 => mlir.DenseIntOrFPElementsAttribute(.i16).init(ranked_type_, bytes).as(mlir.Attribute).?, - .i32 => mlir.DenseIntOrFPElementsAttribute(.i32).init(ranked_type_, bytes).as(mlir.Attribute).?, - .i64 => mlir.DenseIntOrFPElementsAttribute(.i64).init(ranked_type_, bytes).as(mlir.Attribute).?, - .u8 => mlir.DenseIntOrFPElementsAttribute(.u8).init(ranked_type_, bytes).as(mlir.Attribute).?, - .u16 => mlir.DenseIntOrFPElementsAttribute(.u16).init(ranked_type_, bytes).as(mlir.Attribute).?, - .u32 => mlir.DenseIntOrFPElementsAttribute(.u32).init(ranked_type_, bytes).as(mlir.Attribute).?, - .u64 => mlir.DenseIntOrFPElementsAttribute(.u64).init(ranked_type_, bytes).as(mlir.Attribute).?, - .bf16 => mlir.DenseIntOrFPElementsAttribute(.bf16).init(ranked_type_, bytes).as(mlir.Attribute).?, - .f16 => mlir.DenseIntOrFPElementsAttribute(.f16).init(ranked_type_, bytes).as(mlir.Attribute).?, - .f32 => mlir.DenseIntOrFPElementsAttribute(.f32).init(ranked_type_, bytes).as(mlir.Attribute).?, - .f64 => mlir.DenseIntOrFPElementsAttribute(.f64).init(ranked_type_, bytes).as(mlir.Attribute).?, + .bool => mlir.DenseElementsAttribute(.bool).init(ranked_type_, bytes).as(mlir.Attribute), + .i8 => mlir.DenseElementsAttribute(.i8).init(ranked_type_, bytes).as(mlir.Attribute), + .i16 => mlir.DenseElementsAttribute(.i16).init(ranked_type_, bytes).as(mlir.Attribute), + .i32 => mlir.DenseElementsAttribute(.i32).init(ranked_type_, bytes).as(mlir.Attribute), + .i64 => mlir.DenseElementsAttribute(.i64).init(ranked_type_, bytes).as(mlir.Attribute), + .u8 => mlir.DenseElementsAttribute(.u8).init(ranked_type_, bytes).as(mlir.Attribute), + .u16 => mlir.DenseElementsAttribute(.u16).init(ranked_type_, bytes).as(mlir.Attribute), + .u32 => mlir.DenseElementsAttribute(.u32).init(ranked_type_, bytes).as(mlir.Attribute), + .u64 => mlir.DenseElementsAttribute(.u64).init(ranked_type_, bytes).as(mlir.Attribute), + .bf16 => mlir.DenseElementsAttribute(.bf16).init(ranked_type_, bytes).as(mlir.Attribute), + .f16 => mlir.DenseElementsAttribute(.f16).init(ranked_type_, bytes).as(mlir.Attribute), + .f32 => mlir.DenseElementsAttribute(.f32).init(ranked_type_, bytes).as(mlir.Attribute), + .f64 => mlir.DenseElementsAttribute(.f64).init(ranked_type_, bytes).as(mlir.Attribute), inline else => |tag| @panic("Unsupported data type: " ++ @tagName(tag)), }; } @@ -66,28 +66,28 @@ pub const ext = struct { pub const Type = struct { pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type { return switch (dt) { - .bool => mlir.IntegerType(.i1).init(ctx).as(mlir.Type).?, - .f8e4m3b11fnuz => mlir.FloatType(.f8e4m3b11fnuz).init(ctx).as(mlir.Type).?, - .f8e4m3fn => mlir.FloatType(.f8e4m3fn).init(ctx).as(mlir.Type).?, - .f8e4m3fnuz => mlir.FloatType(.f8e4m3fnuz).init(ctx).as(mlir.Type).?, - .f8e5m2 => mlir.FloatType(.f8e5m2).init(ctx).as(mlir.Type).?, - .f8e5m2fnuz => mlir.FloatType(.f8e5m2fnuz).init(ctx).as(mlir.Type).?, - .bf16 => mlir.FloatType(.bf16).init(ctx).as(mlir.Type).?, - .f16 => mlir.FloatType(.f16).init(ctx).as(mlir.Type).?, - .f32 => mlir.FloatType(.f32).init(ctx).as(mlir.Type).?, - .f64 => mlir.FloatType(.f64).init(ctx).as(mlir.Type).?, - .i4 => mlir.IntegerType(.i4).init(ctx).as(mlir.Type).?, - .i8 => mlir.IntegerType(.i8).init(ctx).as(mlir.Type).?, - .i16 => mlir.IntegerType(.i16).init(ctx).as(mlir.Type).?, - .i32 => mlir.IntegerType(.i32).init(ctx).as(mlir.Type).?, - .i64 => mlir.IntegerType(.i64).init(ctx).as(mlir.Type).?, - .u4 => mlir.IntegerType(.u4).init(ctx).as(mlir.Type).?, - .u8 => mlir.IntegerType(.u8).init(ctx).as(mlir.Type).?, - .u16 => mlir.IntegerType(.u16).init(ctx).as(mlir.Type).?, - .u32 => mlir.IntegerType(.u32).init(ctx).as(mlir.Type).?, - .u64 => mlir.IntegerType(.u64).init(ctx).as(mlir.Type).?, - .c64 => mlir.ComplexType(.c64).init(ctx).as(mlir.Type).?, - .c128 => mlir.ComplexType(.c128).init(ctx).as(mlir.Type).?, + .bool => mlir.IntegerType(.i1).init(ctx).as(mlir.Type), + .f8e4m3b11fnuz => mlir.FloatType(.f8e4m3b11fnuz).init(ctx).as(mlir.Type), + .f8e4m3fn => mlir.FloatType(.f8e4m3fn).init(ctx).as(mlir.Type), + .f8e4m3fnuz => mlir.FloatType(.f8e4m3fnuz).init(ctx).as(mlir.Type), + .f8e5m2 => mlir.FloatType(.f8e5m2).init(ctx).as(mlir.Type), + .f8e5m2fnuz => mlir.FloatType(.f8e5m2fnuz).init(ctx).as(mlir.Type), + .bf16 => mlir.FloatType(.bf16).init(ctx).as(mlir.Type), + .f16 => mlir.FloatType(.f16).init(ctx).as(mlir.Type), + .f32 => mlir.FloatType(.f32).init(ctx).as(mlir.Type), + .f64 => mlir.FloatType(.f64).init(ctx).as(mlir.Type), + .i4 => mlir.IntegerType(.i4).init(ctx).as(mlir.Type), + .i8 => mlir.IntegerType(.i8).init(ctx).as(mlir.Type), + .i16 => mlir.IntegerType(.i16).init(ctx).as(mlir.Type), + .i32 => mlir.IntegerType(.i32).init(ctx).as(mlir.Type), + .i64 => mlir.IntegerType(.i64).init(ctx).as(mlir.Type), + .u4 => mlir.IntegerType(.u4).init(ctx).as(mlir.Type), + .u8 => mlir.IntegerType(.u8).init(ctx).as(mlir.Type), + .u16 => mlir.IntegerType(.u16).init(ctx).as(mlir.Type), + .u32 => mlir.IntegerType(.u32).init(ctx).as(mlir.Type), + .u64 => mlir.IntegerType(.u64).init(ctx).as(mlir.Type), + .c64 => mlir.ComplexType(.c64).init(ctx).as(mlir.Type), + .c128 => mlir.ComplexType(.c128).init(ctx).as(mlir.Type), }; } @@ -123,7 +123,7 @@ pub const ext = struct { inline for (mapping) |entry| { const dt, const mlirT = entry; - if (mlir_type.as(mlirT)) |_| { + if (mlir_type.is_a(mlirT)) { return dt; } } @@ -136,39 +136,39 @@ pub const ext = struct { pub fn fromData(data: dtype.Data, ctx: mlir.Context) mlir.Attribute { switch (data) { .bool => |val| { - return mlir.IntegerAttribute(.i1).init(ctx, @intFromBool(val)).as(mlir.Attribute).?; + return mlir.IntegerAttribute(.i1).init(ctx, @intFromBool(val)).as(mlir.Attribute); }, inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz => |val, tag| { const float_type = @field(mlir.FloatTypes, @tagName(tag)); const float_attr = mlir.FloatAttribute(float_type).init(ctx, val.toF32()); - return float_attr.as(mlir.Attribute).?; + return float_attr.as(mlir.Attribute); }, inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |val, tag| { const int_type = @field(mlir.IntegerTypes, @tagName(tag)); const int_attr = mlir.IntegerAttribute(int_type).init(ctx, @intCast(val)); - return int_attr.as(mlir.Attribute).?; + return int_attr.as(mlir.Attribute); }, inline else => |_, tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}), } } }; - pub const DenseIntOrFPElementsAttribute = struct { + pub const DenseElementsAttribute = struct { pub fn fromData(data: dtype.Data, result_type: mlir.Type) mlir.Attribute { return switch (data.dtype()) { - .bool => mlir.DenseIntOrFPElementsAttribute(.bool).init(result_type, data.constSlice()).as(mlir.Attribute).?, - .i8 => mlir.DenseIntOrFPElementsAttribute(.i8).init(result_type, data.constSlice()).as(mlir.Attribute).?, - .i16 => mlir.DenseIntOrFPElementsAttribute(.i16).init(result_type, data.constSlice()).as(mlir.Attribute).?, - .i32 => mlir.DenseIntOrFPElementsAttribute(.i32).init(result_type, data.constSlice()).as(mlir.Attribute).?, - .i64 => mlir.DenseIntOrFPElementsAttribute(.i64).init(result_type, data.constSlice()).as(mlir.Attribute).?, - .u8 => mlir.DenseIntOrFPElementsAttribute(.u8).init(result_type, data.constSlice()).as(mlir.Attribute).?, - .u16 => mlir.DenseIntOrFPElementsAttribute(.u16).init(result_type, data.constSlice()).as(mlir.Attribute).?, - .u32 => mlir.DenseIntOrFPElementsAttribute(.u32).init(result_type, data.constSlice()).as(mlir.Attribute).?, - .u64 => mlir.DenseIntOrFPElementsAttribute(.u64).init(result_type, data.constSlice()).as(mlir.Attribute).?, - .bf16 => mlir.DenseIntOrFPElementsAttribute(.bf16).init(result_type, data.constSlice()).as(mlir.Attribute).?, - .f16 => mlir.DenseIntOrFPElementsAttribute(.f16).init(result_type, data.constSlice()).as(mlir.Attribute).?, - .f32 => mlir.DenseIntOrFPElementsAttribute(.f32).init(result_type, data.constSlice()).as(mlir.Attribute).?, - .f64 => mlir.DenseIntOrFPElementsAttribute(.f64).init(result_type, data.constSlice()).as(mlir.Attribute).?, + .bool => mlir.DenseElementsAttribute(.bool).init(result_type, data.constSlice()).as(mlir.Attribute), + .i8 => mlir.DenseElementsAttribute(.i8).init(result_type, data.constSlice()).as(mlir.Attribute), + .i16 => mlir.DenseElementsAttribute(.i16).init(result_type, data.constSlice()).as(mlir.Attribute), + .i32 => mlir.DenseElementsAttribute(.i32).init(result_type, data.constSlice()).as(mlir.Attribute), + .i64 => mlir.DenseElementsAttribute(.i64).init(result_type, data.constSlice()).as(mlir.Attribute), + .u8 => mlir.DenseElementsAttribute(.u8).init(result_type, data.constSlice()).as(mlir.Attribute), + .u16 => mlir.DenseElementsAttribute(.u16).init(result_type, data.constSlice()).as(mlir.Attribute), + .u32 => mlir.DenseElementsAttribute(.u32).init(result_type, data.constSlice()).as(mlir.Attribute), + .u64 => mlir.DenseElementsAttribute(.u64).init(result_type, data.constSlice()).as(mlir.Attribute), + .bf16 => mlir.DenseElementsAttribute(.bf16).init(result_type, data.constSlice()).as(mlir.Attribute), + .f16 => mlir.DenseElementsAttribute(.f16).init(result_type, data.constSlice()).as(mlir.Attribute), + .f32 => mlir.DenseElementsAttribute(.f32).init(result_type, data.constSlice()).as(mlir.Attribute), + .f64 => mlir.DenseElementsAttribute(.f64).init(result_type, data.constSlice()).as(mlir.Attribute), inline else => |tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}), }; } diff --git a/zml/module.zig b/zml/module.zig index 2377449..0b284c7 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -123,7 +123,7 @@ pub const CompilationContext = struct { const loc = mlir_ctx.location(@src()).named(mlir_ctx, "main"); const module = mlir.Module.init(loc); - module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, "zml").as(mlir.Attribute).?); + module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, "zml").as(mlir.Attribute)); var canonicalizer = try mlir.PassManager.init(mlir_ctx); { @@ -492,7 +492,7 @@ pub const CompilationContext = struct { attributes[a].appendAssumeCapacity( mlir.NamedAttribute.init( mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"), - mlir.IntegerAttribute(.i32).init(self.mlirCtx(), @intCast(index)).as(mlir.Attribute).?, + mlir.IntegerAttribute(.i32).init(self.mlirCtx(), @intCast(index)).as(mlir.Attribute), ), ); // log.debug("attribute: {}", .{attributes[a].constSlice()}); diff --git a/zml/nn/cuda.zig b/zml/nn/cuda.zig index b1960bb..94e8e2e 100644 --- a/zml/nn/cuda.zig +++ b/zml/nn/cuda.zig @@ -132,7 +132,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { }, &.{ mlir.ext.mlirType(mlir_ctx, q.shape()), - mlir.RankedTensorType.init(&.{0}, mlir.IntegerType(.u8).init(mlir_ctx).as(mlir.Type).?).asType(), + mlir.RankedTensorType.init(&.{0}, mlir.IntegerType(.u8).init(mlir_ctx).as(mlir.Type)).as(mlir.Type), }, loc, ); diff --git a/zml/ops.zig b/zml/ops.zig index 99c7ace..852ec34 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -148,7 +148,7 @@ pub fn reduce( .result_type_inference = true, .blocks = &.{body_block}, .attributes = &.{ - .{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), axes).as(mlir.Attribute).? }, + .{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), axes).as(mlir.Attribute) }, }, // We can't verify right away, cause the weights captured by the reduce haven't been added yet. .verify = false, @@ -197,7 +197,7 @@ pub fn reduce( mlir_ctx, val, inner_ctx.broadcasting_axes[0 .. tensor.rank() - inner_ctx.n_reduced], - mlir.ext.RankedTensorType.fromShape(mlir_ctx, reduced_shape).as(mlir.Type).?, + mlir.ext.RankedTensorType.fromShape(mlir_ctx, reduced_shape).as(mlir.Type), inner_ctx.loc, ); tensor.* = Tensor._result(reduced_shape, broad_val.result(0)); @@ -240,17 +240,17 @@ pub fn reduceWindow( const pad_shape = mlir.RankedTensorType.init( &.{ @intCast(opts.padding.len), 2 }, mlir.ext.Type.fromDType(ctx.mlirCtx(), .i64), - ).as(mlir.Type).?; + ).as(mlir.Type); const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.reduce_window", .{ .variadic_operands = &.{ input_values[0..], init_values[0..] }, .result_type_inference = true, .blocks = &.{body_block}, .attributes = &.{ - .{ "window_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dimensions).as(mlir.Attribute).? }, - .{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_strides).as(mlir.Attribute).? }, - .{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.base_dilations).as(mlir.Attribute).? }, - .{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dilations).as(mlir.Attribute).? }, - .{ "padding", mlir.DenseIntOrFPElementsAttribute(.i64).init(pad_shape, std.mem.sliceAsBytes(opts.padding)).as(mlir.Attribute).? }, + .{ "window_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dimensions).as(mlir.Attribute) }, + .{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_strides).as(mlir.Attribute) }, + .{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.base_dilations).as(mlir.Attribute) }, + .{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dilations).as(mlir.Attribute) }, + .{ "padding", mlir.DenseElementsAttribute(.i64).init(pad_shape, opts.padding).as(mlir.Attribute) }, }, .location = loc, }); @@ -611,8 +611,8 @@ pub fn sort( .result_type_inference = true, .blocks = &.{block}, .attributes = &.{ - .{ "dimension", mlir.IntegerAttribute(.i64).init(ctx.mlirCtx(), dimension).as(mlir.Attribute).? }, - .{ "is_stable", mlir.BoolAttribute.init(ctx.mlirCtx(), is_stable).as(mlir.Attribute).? }, + .{ "dimension", mlir.IntegerAttribute(.i64).init(ctx.mlirCtx(), dimension).as(mlir.Attribute) }, + .{ "is_stable", mlir.BoolAttribute.init(ctx.mlirCtx(), is_stable).as(mlir.Attribute) }, }, .location = loc, }); diff --git a/zml/tensor.zig b/zml/tensor.zig index a9da569..81578a2 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -110,7 +110,7 @@ pub const Tensor = struct { /// /// The shape is derived from the type of the mlir.Value. pub fn fromMlirValue(val: mlir.Value) Tensor { - const ranked_tensor = val.getType().as(mlir.RankedTensorType).?; + const ranked_tensor = val.getType().as(mlir.RankedTensorType); const n = ranked_tensor.getRank(); stdx.debug.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK }); @@ -281,7 +281,7 @@ pub const Tensor = struct { const op = dialect.stablehlo.bitcast_convert( self.getContext().mlirCtx(), self.value(), - mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type).?, + mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type), loc, ); @@ -830,7 +830,7 @@ pub const Tensor = struct { self.value(), other.value(), used_opts, - mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape).as(mlir.Type).?, + mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape).as(mlir.Type), loc, ); @@ -1010,7 +1010,7 @@ pub const Tensor = struct { return self; } - const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), to)).as(mlir.Type).?; + const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), to)).as(mlir.Type); const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) }); const op = dialect.stablehlo.convert(self.getContext().mlirCtx(), self.value(), res_type, loc); @@ -1520,7 +1520,7 @@ pub const Tensor = struct { const mlir_ctx = self.getContext().mlirCtx(); const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "slices={any}", .{slices}); - const result_type = mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type).?; + const result_type = mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type); const slice_op = dialect.stablehlo.slice( mlir_ctx, self.value(), @@ -1785,7 +1785,12 @@ pub const Tensor = struct { const loc = ctx.location(@src(), "iota({_}, {})", .{ res_shape, a }); const mlir_ctx = ctx.mlirCtx(); - var op = dialect.stablehlo.iota(mlir_ctx, a, mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).asType(), loc); + var op = dialect.stablehlo.iota( + mlir_ctx, + a, + mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type), + loc, + ); return _result(res_shape, op.result(0)); } @@ -1857,7 +1862,7 @@ pub const Tensor = struct { }; if (sh.rank() > 0) { - constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type).?, loc); + constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type), loc); } return _result(sh, constant_op.result(0)).convert(val.dtype()); } @@ -1904,7 +1909,7 @@ pub const Tensor = struct { return _result(res_shape, self.value()); } const ctx = self.getContext(); - const result_type = mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type).?; + const result_type = mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type); const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ }); const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc);