mlir: rework DenseElementsAttribute to correctly slice inputs and modify .as() to return a concrete value instead of an optional

This commit is contained in:
Tarry Singh 2024-07-15 12:32:24 +00:00
parent 201f5245c1
commit aec1d96e6d
12 changed files with 206 additions and 268 deletions

View File

@ -1,16 +1,6 @@
load("@rules_zig//zig:defs.bzl", "zig_library")
load("//bazel:zig.bzl", "zig_cc_test")
cc_library(
name = "mlirx",
srcs = ["mlirx.cc"],
hdrs = ["mlirx.h"],
includes = ["."],
deps = [
"@llvm-project//mlir:CAPIIR",
],
)
cc_library(
name = "c",
hdrs = ["c.h"],
@ -30,7 +20,7 @@ zig_library(
visibility = ["//visibility:public"],
deps = [
":c",
":mlirx",
"//stdx",
],
)

View File

@ -72,7 +72,7 @@ pub fn cmpi(ctx: mlir.Context, predicate: CmpIPredicate, lhs: mlir.Value, rhs: m
.operands = &.{ lhs, rhs },
.result_type_inference = true,
.attributes = &.{
.{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute).? },
.{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute) },
},
.location = location,
});
@ -102,7 +102,7 @@ pub fn cmpf(ctx: mlir.Context, predicate: CmpFPredicate, lhs: mlir.Value, rhs: m
.operands = &.{ lhs, rhs },
.result_type_inference = true,
.attributes = &.{
.{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute).? },
.{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute) },
},
.location = location,
});

View File

@ -14,14 +14,14 @@ pub fn func(
},
) mlir.Operation {
var attrs_tuple_buffer = std.BoundedArray(mlir.AttrTuple, 4){};
attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute).? });
attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type).?).as(mlir.Attribute).? });
attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute) });
attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type)).as(mlir.Attribute) });
if (args.arg_attrs.len > 0) {
attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", mlir.ArrayAttribute.init(ctx, args.arg_attrs).as(mlir.Attribute).? });
attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", mlir.ArrayAttribute.init(ctx, args.arg_attrs).as(mlir.Attribute) });
}
if (args.res_attrs.len > 0) {
attrs_tuple_buffer.appendAssumeCapacity(.{ "res_attrs", mlir.ArrayAttribute.init(ctx, args.res_attrs).as(mlir.Attribute).? });
attrs_tuple_buffer.appendAssumeCapacity(.{ "res_attrs", mlir.ArrayAttribute.init(ctx, args.res_attrs).as(mlir.Attribute) });
}
return mlir.Operation.make(ctx, "func.func", .{
@ -36,7 +36,7 @@ pub fn call(ctx: mlir.Context, name: [:0]const u8, values: []const mlir.Value, r
.variadic_operands = &.{values},
.results = results,
.verify = true,
.attributes = &.{.{ "callee", mlir.FlatSymbolRefAttribute.init(ctx, name).as(mlir.Attribute).? }},
.attributes = &.{.{ "callee", mlir.FlatSymbolRefAttribute.init(ctx, name).as(mlir.Attribute) }},
.location = loc,
});
}

View File

@ -98,7 +98,7 @@ pub fn cholesky(ctx: mlir.Context, value: mlir.Value, lower: bool, location: mli
.operands = &.{value},
.result_type_inference = true,
.attributes = &.{
.{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(lower))).as(mlir.Attribute).? },
.{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(lower))).as(mlir.Attribute) },
},
.location = location,
});
@ -126,7 +126,7 @@ pub const DotPrecision = union(enum) {
// When we specify the dot algorithm, we should not specify the precision.
.algorithm => .DEFAULT,
});
return precision.as(mlir.Attribute).?;
return precision.as(mlir.Attribute);
}
pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.Type) ?mlir.Attribute {
@ -156,14 +156,14 @@ pub const DotAlgorithm = struct {
};
pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, operand_type: mlir.Type) mlir.Attribute {
const tensor_type = operand_type.as(mlir.RankedTensorType) orelse @panic("dot_general expects RankedTensor as input");
const tensor_type = operand_type.as(mlir.RankedTensorType);
const elem_type = tensor_type.getElementType();
return mlir.Attribute.wrap(c.stablehloDotAlgorithmGet(
ctx.inner(),
elem_type.inner(),
elem_type.inner(),
self.accumulation.asType(ctx).?.inner(),
self.accumulation.asType(ctx).inner(),
self.component_count,
self.component_count,
self.num_primitive_operations,
@ -196,7 +196,7 @@ pub fn dot_general(
.rhs_batching_dimensions = opts.rhs_batching_dimensions,
.lhs_contracting_dimensions = opts.lhs_contracting_dimensions,
.rhs_contracting_dimensions = opts.rhs_contracting_dimensions,
}).as(mlir.Attribute).?,
}).as(mlir.Attribute),
},
.{ "precision_config", mlir.ArrayAttribute.init(ctx, &precisions).asAttr() },
// keep algorithm as the last attribute so we can omit it when it's not set.
@ -219,12 +219,12 @@ pub fn constant(
location: mlir.Location,
) mlir.Operation {
const attribute = switch (elem_type) {
inline else => |dt| mlir.DenseIntOrFPElementsAttribute(dt).init(result_type.as(mlir.Type).?, raw_bytes).as(mlir.Attribute).?,
inline else => |dt| mlir.DenseElementsAttribute(dt).init(result_type.as(mlir.Type), raw_bytes).as(mlir.Attribute),
};
return mlir.Operation.make(ctx, "stablehlo.constant", .{
.operands = &.{},
.results = &.{result_type.as(mlir.Type).?},
.results = &.{result_type.as(mlir.Type)},
.attributes = &.{.{ "value", attribute }},
.location = location,
});
@ -243,7 +243,7 @@ pub fn broadcast_in_dim(ctx: mlir.Context, operand: mlir.Value, dims: []const i6
.operands = &.{operand},
.results = &.{result_type},
.attributes = &.{
.{ "broadcast_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dims).as(mlir.Attribute).? },
.{ "broadcast_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dims).as(mlir.Attribute) },
},
.location = location,
});
@ -254,7 +254,7 @@ pub fn transpose(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, l
.operands = &.{value},
.results = &.{result_type},
.attributes = &.{
.{ "permutation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.permutation).as(mlir.Attribute).? },
.{ "permutation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.permutation).as(mlir.Attribute) },
},
.location = location,
});
@ -265,9 +265,9 @@ pub fn slice(ctx: mlir.Context, operand: mlir.Value, start_indices: []const i64,
.operands = &.{operand},
.results = &.{result_type},
.attributes = &.{
.{ "start_indices", mlir.DenseArrayAttribute(.i64).init(ctx, start_indices).as(mlir.Attribute).? },
.{ "limit_indices", mlir.DenseArrayAttribute(.i64).init(ctx, limit_indices).as(mlir.Attribute).? },
.{ "strides", mlir.DenseArrayAttribute(.i64).init(ctx, strides).as(mlir.Attribute).? },
.{ "start_indices", mlir.DenseArrayAttribute(.i64).init(ctx, start_indices).as(mlir.Attribute) },
.{ "limit_indices", mlir.DenseArrayAttribute(.i64).init(ctx, limit_indices).as(mlir.Attribute) },
.{ "strides", mlir.DenseArrayAttribute(.i64).init(ctx, strides).as(mlir.Attribute) },
},
.location = location,
});
@ -278,7 +278,7 @@ pub fn concatenate(ctx: mlir.Context, inputs: []const mlir.Value, dimension: i64
.operands = inputs,
.result_type_inference = true,
.attributes = &.{
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute).? },
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) },
},
.location = location,
});
@ -287,7 +287,7 @@ pub fn concatenate(ctx: mlir.Context, inputs: []const mlir.Value, dimension: i64
pub fn reshape(ctx: mlir.Context, value: mlir.Value, result_type: mlir.RankedTensorType, location: mlir.Location) mlir.Operation {
return mlir.Operation.make(ctx, "stablehlo.reshape", .{
.operands = &.{value},
.results = &.{result_type.as(mlir.Type).?},
.results = &.{result_type.as(mlir.Type)},
.location = location,
});
}
@ -331,9 +331,9 @@ pub fn gather(
args.start_indices_batching_dims,
args.start_index_map,
args.index_vector_dim,
).as(mlir.Attribute).? },
.{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, slice_sizes).as(mlir.Attribute).? },
.{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute).? },
).as(mlir.Attribute) },
.{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, slice_sizes).as(mlir.Attribute) },
.{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute) },
},
.location = location,
},
@ -393,8 +393,8 @@ pub fn scatter(
.blocks = &.{update_block},
.attributes = &.{
.{ "scatter_dimension_numbers", args.getScatterDimensionNumbers(ctx) },
.{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute).? },
.{ "unique_indices", mlir.BoolAttribute.init(ctx, args.unique_indices).as(mlir.Attribute).? },
.{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute) },
.{ "unique_indices", mlir.BoolAttribute.init(ctx, args.unique_indices).as(mlir.Attribute) },
},
.result_type_inference = true,
.location = location,
@ -407,7 +407,7 @@ pub fn iota(ctx: mlir.Context, dimension: i64, result_type: mlir.Type, location:
.operands = &.{},
.results = &.{result_type},
.attributes = &.{
.{ "iota_dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute).? },
.{ "iota_dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) },
},
.location = location,
});
@ -419,7 +419,7 @@ pub fn reverse(ctx: mlir.Context, operand: mlir.Value, dimensions: []const i64,
.operands = &.{operand},
.results = &.{result_type},
.attributes = &.{
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute).? },
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute) },
},
.location = location,
});
@ -430,8 +430,8 @@ pub fn compare(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, comparison_d
.operands = &.{ lhs, rhs },
.result_type_inference = true,
.attributes = &.{
.{ "comparison_direction", comparison_direction.as(mlir.Attribute).? },
.{ "compare_type", compare_type.as(mlir.Attribute).? },
.{ "comparison_direction", comparison_direction.as(mlir.Attribute) },
.{ "compare_type", compare_type.as(mlir.Attribute) },
},
.location = location,
});
@ -452,7 +452,7 @@ pub fn reduce(
const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0..block_n_args];
var reduce_elem_types: [MaxBlockArguments]mlir.Type = undefined;
for (inputs, 0..) |input, i| {
const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type).?;
const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type);
reduce_elem_types[i] = arg_type;
reduce_elem_types[inputs.len + i] = arg_type;
}
@ -474,7 +474,7 @@ pub fn reduce(
.result_type_inference = true,
.block = block,
.attributes = &.{
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute).? },
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute) },
},
.location = location,
});
@ -494,7 +494,7 @@ pub fn sort(
const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0 .. inputs.len * 2];
var sort_elem_types: [MaxBlockArguments]mlir.Type = undefined;
for (inputs, 0..) |input, i| {
const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type).?;
const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type);
sort_elem_types[i * 2] = arg_type;
sort_elem_types[i * 2 + 1] = arg_type;
}
@ -511,8 +511,8 @@ pub fn sort(
.result_type_inference = true,
.block = block,
.attributes = &.{
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute).? },
.{ "is_stable", mlir.BoolAttribute.init(ctx, is_stable).as(mlir.Attribute).? },
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) },
.{ "is_stable", mlir.BoolAttribute.init(ctx, is_stable).as(mlir.Attribute) },
},
.location = location,
});
@ -523,7 +523,7 @@ pub fn dynamicSlice(ctx: mlir.Context, operand: mlir.Value, new_dims: []const i6
.variadic_operands = &.{ &.{operand}, start_indices },
.result_type_inference = true,
.attributes = &.{
.{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, new_dims).as(mlir.Attribute).? },
.{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, new_dims).as(mlir.Attribute) },
},
.location = location,
});
@ -556,9 +556,9 @@ pub fn pad(ctx: mlir.Context, value: mlir.Value, padding_value: mlir.Value, opts
.operands = &.{ value, padding_value },
.result_type_inference = true,
.attributes = &.{
.{ "edge_padding_low", mlir.DenseArrayAttribute(.i64).init(ctx, opts.low).as(mlir.Attribute).? },
.{ "edge_padding_high", mlir.DenseArrayAttribute(.i64).init(ctx, opts.high).as(mlir.Attribute).? },
.{ "interior_padding", mlir.DenseArrayAttribute(.i64).init(ctx, opts.interior).as(mlir.Attribute).? },
.{ "edge_padding_low", mlir.DenseArrayAttribute(.i64).init(ctx, opts.low).as(mlir.Attribute) },
.{ "edge_padding_high", mlir.DenseArrayAttribute(.i64).init(ctx, opts.high).as(mlir.Attribute) },
.{ "interior_padding", mlir.DenseArrayAttribute(.i64).init(ctx, opts.interior).as(mlir.Attribute) },
},
.location = location,
});
@ -576,10 +576,10 @@ pub fn triangular_solve(ctx: mlir.Context, value: mlir.Value, other: mlir.Value,
.operands = &.{ value, other },
.result_type_inference = true,
.attributes = &.{
.{ "left_side", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.left_side))).as(mlir.Attribute).? },
.{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.lower))).as(mlir.Attribute).? },
.{ "unit_diagonal", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.unit_diagonal))).as(mlir.Attribute).? },
.{ "transpose_a", Transpose.init(ctx, opts.transpose_a).as(mlir.Attribute).? },
.{ "left_side", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.left_side))).as(mlir.Attribute) },
.{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.lower))).as(mlir.Attribute) },
.{ "unit_diagonal", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.unit_diagonal))).as(mlir.Attribute) },
.{ "transpose_a", Transpose.init(ctx, opts.transpose_a).as(mlir.Attribute) },
},
.location = location,
});
@ -595,8 +595,8 @@ pub fn fft(ctx: mlir.Context, value: mlir.Value, location: mlir.Location, opts:
.operands = &.{value},
.result_type_inference = true,
.attributes = &.{
.{ "fft_type", FftType.init(ctx, opts.kind).as(mlir.Attribute).? },
.{ "fft_length", mlir.DenseArrayAttribute(.i64).init(ctx, opts.length).as(mlir.Attribute).? },
.{ "fft_type", FftType.init(ctx, opts.kind).as(mlir.Attribute) },
.{ "fft_length", mlir.DenseArrayAttribute(.i64).init(ctx, opts.length).as(mlir.Attribute) },
},
.location = location,
});
@ -607,7 +607,7 @@ pub fn rng(ctx: mlir.Context, a: mlir.Value, b: mlir.Value, shape: mlir.Value, r
.operands = &.{ a, b, shape },
.result_type_inference = true,
.attributes = &.{
.{ "rng_distribution", RngDistribution.init(ctx, rng_distribution).as(mlir.Attribute).? },
.{ "rng_distribution", RngDistribution.init(ctx, rng_distribution).as(mlir.Attribute) },
},
.location = location,
});
@ -618,7 +618,7 @@ pub fn rng_bit_generator(ctx: mlir.Context, rng_algorithm: RngAlgorithm.Type, in
.operands = &.{initial_state},
.results = &.{ res_state_type, res_type },
.attributes = &.{
.{ "rng_algorithm", RngAlgorithm.init(ctx, rng_algorithm).as(mlir.Attribute).? },
.{ "rng_algorithm", RngAlgorithm.init(ctx, rng_algorithm).as(mlir.Attribute) },
},
.location = location,
});
@ -629,8 +629,8 @@ pub fn reduce_precision(ctx: mlir.Context, value: mlir.Value, exponent_bits: i32
.operands = &.{value},
.result_type_inference = true,
.attributes = &.{
.{ "exponent_bits", mlir.IntegerAttribute(.i32).init(ctx, exponent_bits).as(mlir.Attribute).? },
.{ "mantissa_bits", mlir.IntegerAttribute(.i32).init(ctx, mantissa_bits).as(mlir.Attribute).? },
.{ "exponent_bits", mlir.IntegerAttribute(.i32).init(ctx, exponent_bits).as(mlir.Attribute) },
.{ "mantissa_bits", mlir.IntegerAttribute(.i32).init(ctx, mantissa_bits).as(mlir.Attribute) },
},
.location = location,
});
@ -657,7 +657,7 @@ pub fn get_tuple_element(ctx: mlir.Context, tuple_value: mlir.Value, index: i64,
.operands = &.{tuple_value},
.result_type_inference = true,
.attributes = &.{
.{ "index", mlir.IntegerAttribute(.i32).init(ctx, index).as(mlir.Attribute).? },
.{ "index", mlir.IntegerAttribute(.i32).init(ctx, index).as(mlir.Attribute) },
},
.location = location,
});
@ -694,23 +694,23 @@ pub fn convolution(
) mlir.Operation {
var max_precisions: [2]mlir.Attribute = undefined;
for (opts.precision_config, 0..) |p, i| {
max_precisions[i] = PrecisionAttribute.init(ctx, p).as(mlir.Attribute).?;
max_precisions[i] = PrecisionAttribute.init(ctx, p).as(mlir.Attribute);
}
var window_reversal: [3]i32 = undefined;
for (opts.window_reversal, 0..) |w, i| {
window_reversal[i] = @intCast(@intFromBool(w));
}
const pad_type = mlir.IntegerType(.i64).init(ctx).as(mlir.Type).?;
const pad_shape = mlir.RankedTensorType.init(opts.pad_shape, pad_type).as(mlir.Type).?;
const pad_type = mlir.IntegerType(.i64).init(ctx).as(mlir.Type);
const pad_shape = mlir.RankedTensorType.init(opts.pad_shape, pad_type).as(mlir.Type);
return mlir.Operation.make(ctx, "stablehlo.convolution", .{
.operands = &.{ lhs, rhs },
.results = &.{res_type},
.attributes = &.{
.{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_strides).as(mlir.Attribute).? },
.{ "padding", mlir.DenseIntOrFPElementsAttribute(.i64).init(pad_shape, std.mem.sliceAsBytes(opts.pad_value)).as(mlir.Attribute).? },
.{ "lhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.lhs_dilation).as(mlir.Attribute).? },
.{ "rhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.rhs_dilation).as(mlir.Attribute).? },
.{ "window_reversal", mlir.DenseArrayAttribute(.bool).init(ctx, window_reversal[0..opts.window_reversal.len]).as(mlir.Attribute).? },
.{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_strides).as(mlir.Attribute) },
.{ "padding", mlir.DenseElementsAttribute(.i64).init(pad_shape, opts.pad_value).as(mlir.Attribute) },
.{ "lhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.lhs_dilation).as(mlir.Attribute) },
.{ "rhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.rhs_dilation).as(mlir.Attribute) },
.{ "window_reversal", mlir.DenseArrayAttribute(.bool).init(ctx, window_reversal[0..opts.window_reversal.len]).as(mlir.Attribute) },
.{
"dimension_numbers", ConvDimensionNumbersAttribute.init(ctx, .{
.input_batch_dimension = opts.input_batch_dimension,
@ -722,11 +722,11 @@ pub fn convolution(
.output_batch_dimension = opts.output_batch_dimension,
.output_feature_dimension = opts.output_feature_dimension,
.output_spatial_dimensions = opts.output_spatial_dimensions,
}).as(mlir.Attribute).?,
}).as(mlir.Attribute),
},
.{ "feature_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.feature_group_count).as(mlir.Attribute).? },
.{ "batch_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.batch_group_count).as(mlir.Attribute).? },
.{ "precision_config", mlir.ArrayAttribute.init(ctx, &max_precisions).as(mlir.Attribute).? },
.{ "feature_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.feature_group_count).as(mlir.Attribute) },
.{ "batch_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.batch_group_count).as(mlir.Attribute) },
.{ "precision_config", mlir.ArrayAttribute.init(ctx, &max_precisions).as(mlir.Attribute) },
},
.location = location,
});
@ -747,18 +747,18 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
const output_operand_aliases = allocator.alloc(mlir.Attribute, opts.output_operand_aliases.len) catch unreachable;
for (opts.output_operand_aliases, 0..) |alias, i| {
output_operand_aliases[i] = OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute).?;
output_operand_aliases[i] = OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute);
}
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
.operands = inputs,
.results = res_types,
.attributes = &.{
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, opts.api_version).as(mlir.Attribute).? },
.{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute).? },
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute).? },
.{ "backend_config", mlir.StringAttribute.init(ctx, opts.backend_config).as(mlir.Attribute).? },
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases).as(mlir.Attribute).? },
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, opts.api_version).as(mlir.Attribute) },
.{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) },
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) },
.{ "backend_config", mlir.StringAttribute.init(ctx, opts.backend_config).as(mlir.Attribute) },
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases).as(mlir.Attribute) },
},
.location = location,
});

View File

@ -1,5 +1,6 @@
const builtin = @import("builtin");
const std = @import("std");
const stdx = @import("stdx");
const log = std.log.scoped(.mlir);
const c = @import("c");
@ -95,9 +96,10 @@ pub fn MlirHelpers(comptime OuterT: type, comptime methods: MlirHelpersMethods(O
return false;
}
pub inline fn as(self: OuterT, comptime OtherT: type) ?OtherT {
pub inline fn as(self: OuterT, comptime OtherT: type) OtherT {
if (OtherT.Methods.is_a_fn) |is_a_fn| {
return if (is_a_fn(self.inner())) OtherT.wrap(self.inner()) else null;
stdx.debug.assert(is_a_fn(self.inner()), "Wrongly tried to cast {} into {}", .{ OuterT, OtherT });
return OtherT.wrap(self.inner());
}
// if the other type doesn't have an is_a_fn, try.
return OtherT.wrap(self.inner());
@ -425,7 +427,7 @@ pub const BoolAttribute = struct {
}
pub fn asAttr(self: Self) Attribute {
return self.as(Attribute).?;
return self.as(Attribute);
}
};
@ -446,7 +448,7 @@ pub const TypeAttribute = struct {
}
pub fn asAttr(self: TypeAttribute) Attribute {
return self.as(Attribute).?;
return self.as(Attribute);
}
};
@ -591,10 +593,6 @@ pub fn DenseArrayAttribute(comptime dt: DenseArrayTypes) type {
.i64 => .i64,
else => @compileError("DenseArrayAttribute: unreachable"),
});
pub fn toElements(self: Attr) DenseArray {
return DenseArray.wrap(c.mlirDenseArrayToElements(self.inner()));
}
},
else => struct {},
};
@ -615,28 +613,32 @@ pub const DenseElementsAttributeTypes = enum {
f16,
f32,
f64,
index,
};
pub fn DenseIntOrFPElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
const ZigInDataType, const ZigOutDataType, const initFn, const getValue = switch (dt) {
.bool => .{ bool, bool, c.mlirDenseElementsAttrBoolGet, c.mlirDenseElementsAttrGetBoolValue },
.i8 => .{ i8, i8, c.mlirDenseElementsAttrInt8Get, c.mlirDenseElementsAttrGetInt8Value },
.i16 => .{ i16, i16, c.mlirDenseElementsAttrInt16Get, c.mlirDenseElementsAttrGetInt16Value },
.i32 => .{ i32, i32, c.mlirDenseElementsAttrInt32Get, c.mlirDenseElementsAttrGetInt32Value },
.i64 => .{ i64, i64, c.mlirDenseElementsAttrInt64Get, c.mlirDenseElementsAttrGetInt64Value },
.u8 => .{ u8, u8, c.mlirDenseElementsAttrUInt8Get, c.mlirDenseElementsAttrGetUInt8Value },
.u16 => .{ u16, u16, c.mlirDenseElementsAttrUInt16Get, c.mlirDenseElementsAttrGetUInt16Value },
.u32 => .{ u32, u32, c.mlirDenseElementsAttrUInt32Get, c.mlirDenseElementsAttrGetUInt32Value },
.u64 => .{ u64, u64, c.mlirDenseElementsAttrUInt64Get, c.mlirDenseElementsAttrGetUInt64Value },
.bf16 => .{ u16, f32, c.mlirDenseElementsAttrBFloat16Get, c.mlirDenseElementsAttrGetFloatValue },
.f16 => .{ f16, f32, c.mlirDenseElementsAttrFloat16Get, c.mlirDenseElementsAttrGetFloatValue },
.f32 => .{ f32, f32, c.mlirDenseElementsAttrFloatGet, c.mlirDenseElementsAttrGetFloatValue },
.f64 => .{ f64, f64, c.mlirDenseElementsAttrDoubleGet, c.mlirDenseElementsAttrGetDoubleValue },
pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
const ZigType = switch (dt) {
.bool => bool,
.i8 => i8,
.i16 => i16,
.i32 => i32,
.i64 => i64,
.u8 => u8,
.u16 => u16,
.u32 => u32,
.u64 => u64,
.bf16 => u16,
.f16 => f16,
.f32 => f32,
.f64 => f64,
.index => usize,
};
return struct {
_inner: c.MlirAttribute,
const Attr = @This();
pub usingnamespace MlirHelpers(Attr, .{
.is_a_fn = c.mlirAttributeIsADenseElements,
.is_null_fn = c.mlirAttributeIsNull,
@ -644,13 +646,29 @@ pub fn DenseIntOrFPElementsAttribute(comptime dt: DenseElementsAttributeTypes) t
.equal_fn = c.mlirAttributeEqual,
});
pub fn init(shaped_type: Type, raw_values: []const u8) Attr {
const values = std.mem.bytesAsSlice(ZigInDataType, raw_values);
return Attr.wrap(initFn(shaped_type.inner(), @intCast(values.len), @ptrCast(@alignCast(values.ptr))));
pub fn init(shaped_type: Type, slice: anytype) Attr {
const bytes = std.mem.sliceAsBytes(slice);
const v = Attr.wrapOr(
c.mlirDenseElementsAttrRawBufferGet(
shaped_type.inner(),
@intCast(bytes.len),
@ptrCast(bytes.ptr),
),
) orelse unreachable;
return v;
}
pub fn get(self: Attr, pos: usize) ZigOutDataType {
return getValue(self.inner(), @intCast(pos));
pub fn len(self: Attr) usize {
return @intCast(c.mlirElementsAttrGetNumElements(self.inner()));
}
pub fn constSlice(self: Attr) []const ZigType {
const ptr: [*]const ZigType = @constCast(@ptrCast(@alignCast(c.mlirDenseElementsAttrGetRawData(self.inner()) orelse unreachable)));
return ptr[0..self.len()];
}
pub fn data(self: Attr) []const u8 {
return std.mem.sliceAsBytes(self.constSlice());
}
};
}
@ -1338,12 +1356,9 @@ pub const FloatTypes = enum {
f32,
f64,
unknown,
pub fn asType(self: FloatTypes, ctx: Context) ?Type {
pub fn asType(self: FloatTypes, ctx: Context) Type {
return switch (self) {
.unknown => null,
inline else => |ft| FloatType(ft).init(ctx).asType(),
inline else => |ft| FloatType(ft).init(ctx).as(Type),
};
}
};
@ -1359,48 +1374,23 @@ pub fn FloatType(comptime ft: FloatTypes) type {
.f16 => .{ c.mlirTypeIsAF16, c.mlirF16TypeGet },
.f32 => .{ c.mlirTypeIsAF32, c.mlirF32TypeGet },
.f64 => .{ c.mlirTypeIsAF64, c.mlirF64TypeGet },
.unknown => .{ null, null },
};
return struct {
_inner: c.MlirType,
const Float = @This();
pub usingnamespace MlirHelpers(Float, .{
.is_a_fn = switch (ft) {
.unknown => typeIsAUnknownFloat,
else => Config[0],
},
const Self = @This();
pub usingnamespace MlirHelpers(Self, .{
.is_a_fn = Config[0],
.is_null_fn = c.mlirTypeIsNull,
.dump_fn = c.mlirTypeDump,
.equal_fn = c.mlirTypeEqual,
});
pub usingnamespace if (ft != .unknown) struct {
pub const FloatTypeType = ft;
pub fn init(ctx: Context) Float {
pub fn init(ctx: Context) Self {
const type_get = Config[1];
return Float.wrap(type_get(ctx.inner()));
}
} else struct {};
fn typeIsAUnknownFloat(typ: c.MlirType) callconv(.C) bool {
const is_a_fns = .{
c.mlirTypeIsABF16,
c.mlirTypeIsAF16,
c.mlirTypeIsAF32,
c.mlirTypeIsF64,
};
inline for (is_a_fns) |is_a_fn| {
if (is_a_fn(typ)) {
return true;
}
}
return false;
}
pub fn asType(self: Float) Type {
return self.as(Type).?;
return Self.wrap(type_get(ctx.inner()));
}
};
}
@ -1545,10 +1535,6 @@ pub const RankedTensorType = struct {
pub fn getDimension(self: RankedTensorType, dim: usize) i64 {
return c.mlirShapedTypeGetDimSize(self.inner(), @intCast(dim));
}
pub fn asType(self: RankedTensorType) Type {
return self.as(Type).?;
}
};
pub const Dialect = struct {

View File

@ -1,27 +0,0 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlirx.h"
namespace mlirx {
static mlir::Attribute ArrayToElements(mlir::Attribute attr) {
if (auto array = attr.dyn_cast<mlir::DenseI64ArrayAttr>()) {
return mlir::DenseIntElementsAttr::get(
mlir::RankedTensorType::get(array.size(), array.getElementType()),
array.asArrayRef());
}
if (auto array = attr.dyn_cast<mlir::DenseBoolArrayAttr>()) {
return mlir::DenseIntElementsAttr::get(
mlir::RankedTensorType::get(array.size(), array.getElementType()),
array.asArrayRef());
}
return attr;
}
}
MlirAttribute mlirDenseArrayToElements(MlirAttribute attr) {
return wrap(mlirx::ArrayToElements(unwrap(attr)));
}

View File

@ -1,16 +0,0 @@
#ifndef MLIRX_CC_H
#define MLIRX_CC_H
#include "mlir-c/IR.h"
#ifdef __cplusplus
extern "C" {
#endif
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseArrayToElements(MlirAttribute attr);
#ifdef __cplusplus
}
#endif
#endif // MLIRX_CC_H

View File

@ -15,7 +15,7 @@ pub usingnamespace @import("mlir");
pub const ext = struct {
pub fn mlirType(ctx: mlir.Context, sh: Shape) mlir.Type {
return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype())).as(mlir.Type).?;
return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype())).as(mlir.Type);
}
pub fn denseElementAttrType(dt: dtype.DataType) ?mlir.DenseElementsAttributeTypes {
@ -38,21 +38,21 @@ pub const ext = struct {
}
pub fn denseElementsAttr(dt: dtype.DataType, _: usize, bytes: []const u8, ranked_type: mlir.RankedTensorType) mlir.Attribute {
const ranked_type_ = ranked_type.as(mlir.Type).?;
const ranked_type_ = ranked_type.as(mlir.Type);
return switch (dt) {
.bool => mlir.DenseIntOrFPElementsAttribute(.bool).init(ranked_type_, bytes).as(mlir.Attribute).?,
.i8 => mlir.DenseIntOrFPElementsAttribute(.i8).init(ranked_type_, bytes).as(mlir.Attribute).?,
.i16 => mlir.DenseIntOrFPElementsAttribute(.i16).init(ranked_type_, bytes).as(mlir.Attribute).?,
.i32 => mlir.DenseIntOrFPElementsAttribute(.i32).init(ranked_type_, bytes).as(mlir.Attribute).?,
.i64 => mlir.DenseIntOrFPElementsAttribute(.i64).init(ranked_type_, bytes).as(mlir.Attribute).?,
.u8 => mlir.DenseIntOrFPElementsAttribute(.u8).init(ranked_type_, bytes).as(mlir.Attribute).?,
.u16 => mlir.DenseIntOrFPElementsAttribute(.u16).init(ranked_type_, bytes).as(mlir.Attribute).?,
.u32 => mlir.DenseIntOrFPElementsAttribute(.u32).init(ranked_type_, bytes).as(mlir.Attribute).?,
.u64 => mlir.DenseIntOrFPElementsAttribute(.u64).init(ranked_type_, bytes).as(mlir.Attribute).?,
.bf16 => mlir.DenseIntOrFPElementsAttribute(.bf16).init(ranked_type_, bytes).as(mlir.Attribute).?,
.f16 => mlir.DenseIntOrFPElementsAttribute(.f16).init(ranked_type_, bytes).as(mlir.Attribute).?,
.f32 => mlir.DenseIntOrFPElementsAttribute(.f32).init(ranked_type_, bytes).as(mlir.Attribute).?,
.f64 => mlir.DenseIntOrFPElementsAttribute(.f64).init(ranked_type_, bytes).as(mlir.Attribute).?,
.bool => mlir.DenseElementsAttribute(.bool).init(ranked_type_, bytes).as(mlir.Attribute),
.i8 => mlir.DenseElementsAttribute(.i8).init(ranked_type_, bytes).as(mlir.Attribute),
.i16 => mlir.DenseElementsAttribute(.i16).init(ranked_type_, bytes).as(mlir.Attribute),
.i32 => mlir.DenseElementsAttribute(.i32).init(ranked_type_, bytes).as(mlir.Attribute),
.i64 => mlir.DenseElementsAttribute(.i64).init(ranked_type_, bytes).as(mlir.Attribute),
.u8 => mlir.DenseElementsAttribute(.u8).init(ranked_type_, bytes).as(mlir.Attribute),
.u16 => mlir.DenseElementsAttribute(.u16).init(ranked_type_, bytes).as(mlir.Attribute),
.u32 => mlir.DenseElementsAttribute(.u32).init(ranked_type_, bytes).as(mlir.Attribute),
.u64 => mlir.DenseElementsAttribute(.u64).init(ranked_type_, bytes).as(mlir.Attribute),
.bf16 => mlir.DenseElementsAttribute(.bf16).init(ranked_type_, bytes).as(mlir.Attribute),
.f16 => mlir.DenseElementsAttribute(.f16).init(ranked_type_, bytes).as(mlir.Attribute),
.f32 => mlir.DenseElementsAttribute(.f32).init(ranked_type_, bytes).as(mlir.Attribute),
.f64 => mlir.DenseElementsAttribute(.f64).init(ranked_type_, bytes).as(mlir.Attribute),
inline else => |tag| @panic("Unsupported data type: " ++ @tagName(tag)),
};
}
@ -66,28 +66,28 @@ pub const ext = struct {
pub const Type = struct {
pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type {
return switch (dt) {
.bool => mlir.IntegerType(.i1).init(ctx).as(mlir.Type).?,
.f8e4m3b11fnuz => mlir.FloatType(.f8e4m3b11fnuz).init(ctx).as(mlir.Type).?,
.f8e4m3fn => mlir.FloatType(.f8e4m3fn).init(ctx).as(mlir.Type).?,
.f8e4m3fnuz => mlir.FloatType(.f8e4m3fnuz).init(ctx).as(mlir.Type).?,
.f8e5m2 => mlir.FloatType(.f8e5m2).init(ctx).as(mlir.Type).?,
.f8e5m2fnuz => mlir.FloatType(.f8e5m2fnuz).init(ctx).as(mlir.Type).?,
.bf16 => mlir.FloatType(.bf16).init(ctx).as(mlir.Type).?,
.f16 => mlir.FloatType(.f16).init(ctx).as(mlir.Type).?,
.f32 => mlir.FloatType(.f32).init(ctx).as(mlir.Type).?,
.f64 => mlir.FloatType(.f64).init(ctx).as(mlir.Type).?,
.i4 => mlir.IntegerType(.i4).init(ctx).as(mlir.Type).?,
.i8 => mlir.IntegerType(.i8).init(ctx).as(mlir.Type).?,
.i16 => mlir.IntegerType(.i16).init(ctx).as(mlir.Type).?,
.i32 => mlir.IntegerType(.i32).init(ctx).as(mlir.Type).?,
.i64 => mlir.IntegerType(.i64).init(ctx).as(mlir.Type).?,
.u4 => mlir.IntegerType(.u4).init(ctx).as(mlir.Type).?,
.u8 => mlir.IntegerType(.u8).init(ctx).as(mlir.Type).?,
.u16 => mlir.IntegerType(.u16).init(ctx).as(mlir.Type).?,
.u32 => mlir.IntegerType(.u32).init(ctx).as(mlir.Type).?,
.u64 => mlir.IntegerType(.u64).init(ctx).as(mlir.Type).?,
.c64 => mlir.ComplexType(.c64).init(ctx).as(mlir.Type).?,
.c128 => mlir.ComplexType(.c128).init(ctx).as(mlir.Type).?,
.bool => mlir.IntegerType(.i1).init(ctx).as(mlir.Type),
.f8e4m3b11fnuz => mlir.FloatType(.f8e4m3b11fnuz).init(ctx).as(mlir.Type),
.f8e4m3fn => mlir.FloatType(.f8e4m3fn).init(ctx).as(mlir.Type),
.f8e4m3fnuz => mlir.FloatType(.f8e4m3fnuz).init(ctx).as(mlir.Type),
.f8e5m2 => mlir.FloatType(.f8e5m2).init(ctx).as(mlir.Type),
.f8e5m2fnuz => mlir.FloatType(.f8e5m2fnuz).init(ctx).as(mlir.Type),
.bf16 => mlir.FloatType(.bf16).init(ctx).as(mlir.Type),
.f16 => mlir.FloatType(.f16).init(ctx).as(mlir.Type),
.f32 => mlir.FloatType(.f32).init(ctx).as(mlir.Type),
.f64 => mlir.FloatType(.f64).init(ctx).as(mlir.Type),
.i4 => mlir.IntegerType(.i4).init(ctx).as(mlir.Type),
.i8 => mlir.IntegerType(.i8).init(ctx).as(mlir.Type),
.i16 => mlir.IntegerType(.i16).init(ctx).as(mlir.Type),
.i32 => mlir.IntegerType(.i32).init(ctx).as(mlir.Type),
.i64 => mlir.IntegerType(.i64).init(ctx).as(mlir.Type),
.u4 => mlir.IntegerType(.u4).init(ctx).as(mlir.Type),
.u8 => mlir.IntegerType(.u8).init(ctx).as(mlir.Type),
.u16 => mlir.IntegerType(.u16).init(ctx).as(mlir.Type),
.u32 => mlir.IntegerType(.u32).init(ctx).as(mlir.Type),
.u64 => mlir.IntegerType(.u64).init(ctx).as(mlir.Type),
.c64 => mlir.ComplexType(.c64).init(ctx).as(mlir.Type),
.c128 => mlir.ComplexType(.c128).init(ctx).as(mlir.Type),
};
}
@ -123,7 +123,7 @@ pub const ext = struct {
inline for (mapping) |entry| {
const dt, const mlirT = entry;
if (mlir_type.as(mlirT)) |_| {
if (mlir_type.is_a(mlirT)) {
return dt;
}
}
@ -136,39 +136,39 @@ pub const ext = struct {
pub fn fromData(data: dtype.Data, ctx: mlir.Context) mlir.Attribute {
switch (data) {
.bool => |val| {
return mlir.IntegerAttribute(.i1).init(ctx, @intFromBool(val)).as(mlir.Attribute).?;
return mlir.IntegerAttribute(.i1).init(ctx, @intFromBool(val)).as(mlir.Attribute);
},
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz => |val, tag| {
const float_type = @field(mlir.FloatTypes, @tagName(tag));
const float_attr = mlir.FloatAttribute(float_type).init(ctx, val.toF32());
return float_attr.as(mlir.Attribute).?;
return float_attr.as(mlir.Attribute);
},
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |val, tag| {
const int_type = @field(mlir.IntegerTypes, @tagName(tag));
const int_attr = mlir.IntegerAttribute(int_type).init(ctx, @intCast(val));
return int_attr.as(mlir.Attribute).?;
return int_attr.as(mlir.Attribute);
},
inline else => |_, tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}),
}
}
};
pub const DenseIntOrFPElementsAttribute = struct {
pub const DenseElementsAttribute = struct {
pub fn fromData(data: dtype.Data, result_type: mlir.Type) mlir.Attribute {
return switch (data.dtype()) {
.bool => mlir.DenseIntOrFPElementsAttribute(.bool).init(result_type, data.constSlice()).as(mlir.Attribute).?,
.i8 => mlir.DenseIntOrFPElementsAttribute(.i8).init(result_type, data.constSlice()).as(mlir.Attribute).?,
.i16 => mlir.DenseIntOrFPElementsAttribute(.i16).init(result_type, data.constSlice()).as(mlir.Attribute).?,
.i32 => mlir.DenseIntOrFPElementsAttribute(.i32).init(result_type, data.constSlice()).as(mlir.Attribute).?,
.i64 => mlir.DenseIntOrFPElementsAttribute(.i64).init(result_type, data.constSlice()).as(mlir.Attribute).?,
.u8 => mlir.DenseIntOrFPElementsAttribute(.u8).init(result_type, data.constSlice()).as(mlir.Attribute).?,
.u16 => mlir.DenseIntOrFPElementsAttribute(.u16).init(result_type, data.constSlice()).as(mlir.Attribute).?,
.u32 => mlir.DenseIntOrFPElementsAttribute(.u32).init(result_type, data.constSlice()).as(mlir.Attribute).?,
.u64 => mlir.DenseIntOrFPElementsAttribute(.u64).init(result_type, data.constSlice()).as(mlir.Attribute).?,
.bf16 => mlir.DenseIntOrFPElementsAttribute(.bf16).init(result_type, data.constSlice()).as(mlir.Attribute).?,
.f16 => mlir.DenseIntOrFPElementsAttribute(.f16).init(result_type, data.constSlice()).as(mlir.Attribute).?,
.f32 => mlir.DenseIntOrFPElementsAttribute(.f32).init(result_type, data.constSlice()).as(mlir.Attribute).?,
.f64 => mlir.DenseIntOrFPElementsAttribute(.f64).init(result_type, data.constSlice()).as(mlir.Attribute).?,
.bool => mlir.DenseElementsAttribute(.bool).init(result_type, data.constSlice()).as(mlir.Attribute),
.i8 => mlir.DenseElementsAttribute(.i8).init(result_type, data.constSlice()).as(mlir.Attribute),
.i16 => mlir.DenseElementsAttribute(.i16).init(result_type, data.constSlice()).as(mlir.Attribute),
.i32 => mlir.DenseElementsAttribute(.i32).init(result_type, data.constSlice()).as(mlir.Attribute),
.i64 => mlir.DenseElementsAttribute(.i64).init(result_type, data.constSlice()).as(mlir.Attribute),
.u8 => mlir.DenseElementsAttribute(.u8).init(result_type, data.constSlice()).as(mlir.Attribute),
.u16 => mlir.DenseElementsAttribute(.u16).init(result_type, data.constSlice()).as(mlir.Attribute),
.u32 => mlir.DenseElementsAttribute(.u32).init(result_type, data.constSlice()).as(mlir.Attribute),
.u64 => mlir.DenseElementsAttribute(.u64).init(result_type, data.constSlice()).as(mlir.Attribute),
.bf16 => mlir.DenseElementsAttribute(.bf16).init(result_type, data.constSlice()).as(mlir.Attribute),
.f16 => mlir.DenseElementsAttribute(.f16).init(result_type, data.constSlice()).as(mlir.Attribute),
.f32 => mlir.DenseElementsAttribute(.f32).init(result_type, data.constSlice()).as(mlir.Attribute),
.f64 => mlir.DenseElementsAttribute(.f64).init(result_type, data.constSlice()).as(mlir.Attribute),
inline else => |tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}),
};
}

View File

@ -123,7 +123,7 @@ pub const CompilationContext = struct {
const loc = mlir_ctx.location(@src()).named(mlir_ctx, "main");
const module = mlir.Module.init(loc);
module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, "zml").as(mlir.Attribute).?);
module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, "zml").as(mlir.Attribute));
var canonicalizer = try mlir.PassManager.init(mlir_ctx);
{
@ -492,7 +492,7 @@ pub const CompilationContext = struct {
attributes[a].appendAssumeCapacity(
mlir.NamedAttribute.init(
mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"),
mlir.IntegerAttribute(.i32).init(self.mlirCtx(), @intCast(index)).as(mlir.Attribute).?,
mlir.IntegerAttribute(.i32).init(self.mlirCtx(), @intCast(index)).as(mlir.Attribute),
),
);
// log.debug("attribute: {}", .{attributes[a].constSlice()});

View File

@ -132,7 +132,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
},
&.{
mlir.ext.mlirType(mlir_ctx, q.shape()),
mlir.RankedTensorType.init(&.{0}, mlir.IntegerType(.u8).init(mlir_ctx).as(mlir.Type).?).asType(),
mlir.RankedTensorType.init(&.{0}, mlir.IntegerType(.u8).init(mlir_ctx).as(mlir.Type)).as(mlir.Type),
},
loc,
);

View File

@ -148,7 +148,7 @@ pub fn reduce(
.result_type_inference = true,
.blocks = &.{body_block},
.attributes = &.{
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), axes).as(mlir.Attribute).? },
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), axes).as(mlir.Attribute) },
},
// We can't verify right away, cause the weights captured by the reduce haven't been added yet.
.verify = false,
@ -197,7 +197,7 @@ pub fn reduce(
mlir_ctx,
val,
inner_ctx.broadcasting_axes[0 .. tensor.rank() - inner_ctx.n_reduced],
mlir.ext.RankedTensorType.fromShape(mlir_ctx, reduced_shape).as(mlir.Type).?,
mlir.ext.RankedTensorType.fromShape(mlir_ctx, reduced_shape).as(mlir.Type),
inner_ctx.loc,
);
tensor.* = Tensor._result(reduced_shape, broad_val.result(0));
@ -240,17 +240,17 @@ pub fn reduceWindow(
const pad_shape = mlir.RankedTensorType.init(
&.{ @intCast(opts.padding.len), 2 },
mlir.ext.Type.fromDType(ctx.mlirCtx(), .i64),
).as(mlir.Type).?;
).as(mlir.Type);
const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.reduce_window", .{
.variadic_operands = &.{ input_values[0..], init_values[0..] },
.result_type_inference = true,
.blocks = &.{body_block},
.attributes = &.{
.{ "window_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dimensions).as(mlir.Attribute).? },
.{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_strides).as(mlir.Attribute).? },
.{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.base_dilations).as(mlir.Attribute).? },
.{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dilations).as(mlir.Attribute).? },
.{ "padding", mlir.DenseIntOrFPElementsAttribute(.i64).init(pad_shape, std.mem.sliceAsBytes(opts.padding)).as(mlir.Attribute).? },
.{ "window_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dimensions).as(mlir.Attribute) },
.{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_strides).as(mlir.Attribute) },
.{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.base_dilations).as(mlir.Attribute) },
.{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dilations).as(mlir.Attribute) },
.{ "padding", mlir.DenseElementsAttribute(.i64).init(pad_shape, opts.padding).as(mlir.Attribute) },
},
.location = loc,
});
@ -611,8 +611,8 @@ pub fn sort(
.result_type_inference = true,
.blocks = &.{block},
.attributes = &.{
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx.mlirCtx(), dimension).as(mlir.Attribute).? },
.{ "is_stable", mlir.BoolAttribute.init(ctx.mlirCtx(), is_stable).as(mlir.Attribute).? },
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx.mlirCtx(), dimension).as(mlir.Attribute) },
.{ "is_stable", mlir.BoolAttribute.init(ctx.mlirCtx(), is_stable).as(mlir.Attribute) },
},
.location = loc,
});

View File

@ -110,7 +110,7 @@ pub const Tensor = struct {
///
/// The shape is derived from the type of the mlir.Value.
pub fn fromMlirValue(val: mlir.Value) Tensor {
const ranked_tensor = val.getType().as(mlir.RankedTensorType).?;
const ranked_tensor = val.getType().as(mlir.RankedTensorType);
const n = ranked_tensor.getRank();
stdx.debug.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK });
@ -281,7 +281,7 @@ pub const Tensor = struct {
const op = dialect.stablehlo.bitcast_convert(
self.getContext().mlirCtx(),
self.value(),
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type).?,
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type),
loc,
);
@ -830,7 +830,7 @@ pub const Tensor = struct {
self.value(),
other.value(),
used_opts,
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape).as(mlir.Type).?,
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape).as(mlir.Type),
loc,
);
@ -1010,7 +1010,7 @@ pub const Tensor = struct {
return self;
}
const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), to)).as(mlir.Type).?;
const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), to)).as(mlir.Type);
const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) });
const op = dialect.stablehlo.convert(self.getContext().mlirCtx(), self.value(), res_type, loc);
@ -1520,7 +1520,7 @@ pub const Tensor = struct {
const mlir_ctx = self.getContext().mlirCtx();
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "slices={any}", .{slices});
const result_type = mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type).?;
const result_type = mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type);
const slice_op = dialect.stablehlo.slice(
mlir_ctx,
self.value(),
@ -1785,7 +1785,12 @@ pub const Tensor = struct {
const loc = ctx.location(@src(), "iota({_}, {})", .{ res_shape, a });
const mlir_ctx = ctx.mlirCtx();
var op = dialect.stablehlo.iota(mlir_ctx, a, mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).asType(), loc);
var op = dialect.stablehlo.iota(
mlir_ctx,
a,
mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type),
loc,
);
return _result(res_shape, op.result(0));
}
@ -1857,7 +1862,7 @@ pub const Tensor = struct {
};
if (sh.rank() > 0) {
constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type).?, loc);
constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type), loc);
}
return _result(sh, constant_op.result(0)).convert(val.dtype());
}
@ -1904,7 +1909,7 @@ pub const Tensor = struct {
return _result(res_shape, self.value());
}
const ctx = self.getContext();
const result_type = mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type).?;
const result_type = mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type);
const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ });
const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc);