mlir: rework DenseElementsAttribute to correctly slice inputs and modify .as() to return a concrete value instead of an optional

This commit is contained in:
Tarry Singh 2024-07-15 12:32:24 +00:00
parent 201f5245c1
commit aec1d96e6d
12 changed files with 206 additions and 268 deletions

View File

@ -1,16 +1,6 @@
load("@rules_zig//zig:defs.bzl", "zig_library") load("@rules_zig//zig:defs.bzl", "zig_library")
load("//bazel:zig.bzl", "zig_cc_test") load("//bazel:zig.bzl", "zig_cc_test")
cc_library(
name = "mlirx",
srcs = ["mlirx.cc"],
hdrs = ["mlirx.h"],
includes = ["."],
deps = [
"@llvm-project//mlir:CAPIIR",
],
)
cc_library( cc_library(
name = "c", name = "c",
hdrs = ["c.h"], hdrs = ["c.h"],
@ -30,7 +20,7 @@ zig_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":c", ":c",
":mlirx", "//stdx",
], ],
) )

View File

@ -72,7 +72,7 @@ pub fn cmpi(ctx: mlir.Context, predicate: CmpIPredicate, lhs: mlir.Value, rhs: m
.operands = &.{ lhs, rhs }, .operands = &.{ lhs, rhs },
.result_type_inference = true, .result_type_inference = true,
.attributes = &.{ .attributes = &.{
.{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute).? }, .{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -102,7 +102,7 @@ pub fn cmpf(ctx: mlir.Context, predicate: CmpFPredicate, lhs: mlir.Value, rhs: m
.operands = &.{ lhs, rhs }, .operands = &.{ lhs, rhs },
.result_type_inference = true, .result_type_inference = true,
.attributes = &.{ .attributes = &.{
.{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute).? }, .{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });

View File

@ -14,14 +14,14 @@ pub fn func(
}, },
) mlir.Operation { ) mlir.Operation {
var attrs_tuple_buffer = std.BoundedArray(mlir.AttrTuple, 4){}; var attrs_tuple_buffer = std.BoundedArray(mlir.AttrTuple, 4){};
attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute).? }); attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute) });
attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type).?).as(mlir.Attribute).? }); attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type)).as(mlir.Attribute) });
if (args.arg_attrs.len > 0) { if (args.arg_attrs.len > 0) {
attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", mlir.ArrayAttribute.init(ctx, args.arg_attrs).as(mlir.Attribute).? }); attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", mlir.ArrayAttribute.init(ctx, args.arg_attrs).as(mlir.Attribute) });
} }
if (args.res_attrs.len > 0) { if (args.res_attrs.len > 0) {
attrs_tuple_buffer.appendAssumeCapacity(.{ "res_attrs", mlir.ArrayAttribute.init(ctx, args.res_attrs).as(mlir.Attribute).? }); attrs_tuple_buffer.appendAssumeCapacity(.{ "res_attrs", mlir.ArrayAttribute.init(ctx, args.res_attrs).as(mlir.Attribute) });
} }
return mlir.Operation.make(ctx, "func.func", .{ return mlir.Operation.make(ctx, "func.func", .{
@ -36,7 +36,7 @@ pub fn call(ctx: mlir.Context, name: [:0]const u8, values: []const mlir.Value, r
.variadic_operands = &.{values}, .variadic_operands = &.{values},
.results = results, .results = results,
.verify = true, .verify = true,
.attributes = &.{.{ "callee", mlir.FlatSymbolRefAttribute.init(ctx, name).as(mlir.Attribute).? }}, .attributes = &.{.{ "callee", mlir.FlatSymbolRefAttribute.init(ctx, name).as(mlir.Attribute) }},
.location = loc, .location = loc,
}); });
} }

View File

@ -98,7 +98,7 @@ pub fn cholesky(ctx: mlir.Context, value: mlir.Value, lower: bool, location: mli
.operands = &.{value}, .operands = &.{value},
.result_type_inference = true, .result_type_inference = true,
.attributes = &.{ .attributes = &.{
.{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(lower))).as(mlir.Attribute).? }, .{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(lower))).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -126,7 +126,7 @@ pub const DotPrecision = union(enum) {
// When we specify the dot algorithm, we should not specify the precision. // When we specify the dot algorithm, we should not specify the precision.
.algorithm => .DEFAULT, .algorithm => .DEFAULT,
}); });
return precision.as(mlir.Attribute).?; return precision.as(mlir.Attribute);
} }
pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.Type) ?mlir.Attribute { pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.Type) ?mlir.Attribute {
@ -156,14 +156,14 @@ pub const DotAlgorithm = struct {
}; };
pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, operand_type: mlir.Type) mlir.Attribute { pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, operand_type: mlir.Type) mlir.Attribute {
const tensor_type = operand_type.as(mlir.RankedTensorType) orelse @panic("dot_general expects RankedTensor as input"); const tensor_type = operand_type.as(mlir.RankedTensorType);
const elem_type = tensor_type.getElementType(); const elem_type = tensor_type.getElementType();
return mlir.Attribute.wrap(c.stablehloDotAlgorithmGet( return mlir.Attribute.wrap(c.stablehloDotAlgorithmGet(
ctx.inner(), ctx.inner(),
elem_type.inner(), elem_type.inner(),
elem_type.inner(), elem_type.inner(),
self.accumulation.asType(ctx).?.inner(), self.accumulation.asType(ctx).inner(),
self.component_count, self.component_count,
self.component_count, self.component_count,
self.num_primitive_operations, self.num_primitive_operations,
@ -196,7 +196,7 @@ pub fn dot_general(
.rhs_batching_dimensions = opts.rhs_batching_dimensions, .rhs_batching_dimensions = opts.rhs_batching_dimensions,
.lhs_contracting_dimensions = opts.lhs_contracting_dimensions, .lhs_contracting_dimensions = opts.lhs_contracting_dimensions,
.rhs_contracting_dimensions = opts.rhs_contracting_dimensions, .rhs_contracting_dimensions = opts.rhs_contracting_dimensions,
}).as(mlir.Attribute).?, }).as(mlir.Attribute),
}, },
.{ "precision_config", mlir.ArrayAttribute.init(ctx, &precisions).asAttr() }, .{ "precision_config", mlir.ArrayAttribute.init(ctx, &precisions).asAttr() },
// keep algorithm as the last attribute so we can omit it when it's not set. // keep algorithm as the last attribute so we can omit it when it's not set.
@ -219,12 +219,12 @@ pub fn constant(
location: mlir.Location, location: mlir.Location,
) mlir.Operation { ) mlir.Operation {
const attribute = switch (elem_type) { const attribute = switch (elem_type) {
inline else => |dt| mlir.DenseIntOrFPElementsAttribute(dt).init(result_type.as(mlir.Type).?, raw_bytes).as(mlir.Attribute).?, inline else => |dt| mlir.DenseElementsAttribute(dt).init(result_type.as(mlir.Type), raw_bytes).as(mlir.Attribute),
}; };
return mlir.Operation.make(ctx, "stablehlo.constant", .{ return mlir.Operation.make(ctx, "stablehlo.constant", .{
.operands = &.{}, .operands = &.{},
.results = &.{result_type.as(mlir.Type).?}, .results = &.{result_type.as(mlir.Type)},
.attributes = &.{.{ "value", attribute }}, .attributes = &.{.{ "value", attribute }},
.location = location, .location = location,
}); });
@ -243,7 +243,7 @@ pub fn broadcast_in_dim(ctx: mlir.Context, operand: mlir.Value, dims: []const i6
.operands = &.{operand}, .operands = &.{operand},
.results = &.{result_type}, .results = &.{result_type},
.attributes = &.{ .attributes = &.{
.{ "broadcast_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dims).as(mlir.Attribute).? }, .{ "broadcast_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dims).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -254,7 +254,7 @@ pub fn transpose(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, l
.operands = &.{value}, .operands = &.{value},
.results = &.{result_type}, .results = &.{result_type},
.attributes = &.{ .attributes = &.{
.{ "permutation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.permutation).as(mlir.Attribute).? }, .{ "permutation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.permutation).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -265,9 +265,9 @@ pub fn slice(ctx: mlir.Context, operand: mlir.Value, start_indices: []const i64,
.operands = &.{operand}, .operands = &.{operand},
.results = &.{result_type}, .results = &.{result_type},
.attributes = &.{ .attributes = &.{
.{ "start_indices", mlir.DenseArrayAttribute(.i64).init(ctx, start_indices).as(mlir.Attribute).? }, .{ "start_indices", mlir.DenseArrayAttribute(.i64).init(ctx, start_indices).as(mlir.Attribute) },
.{ "limit_indices", mlir.DenseArrayAttribute(.i64).init(ctx, limit_indices).as(mlir.Attribute).? }, .{ "limit_indices", mlir.DenseArrayAttribute(.i64).init(ctx, limit_indices).as(mlir.Attribute) },
.{ "strides", mlir.DenseArrayAttribute(.i64).init(ctx, strides).as(mlir.Attribute).? }, .{ "strides", mlir.DenseArrayAttribute(.i64).init(ctx, strides).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -278,7 +278,7 @@ pub fn concatenate(ctx: mlir.Context, inputs: []const mlir.Value, dimension: i64
.operands = inputs, .operands = inputs,
.result_type_inference = true, .result_type_inference = true,
.attributes = &.{ .attributes = &.{
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute).? }, .{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -287,7 +287,7 @@ pub fn concatenate(ctx: mlir.Context, inputs: []const mlir.Value, dimension: i64
pub fn reshape(ctx: mlir.Context, value: mlir.Value, result_type: mlir.RankedTensorType, location: mlir.Location) mlir.Operation { pub fn reshape(ctx: mlir.Context, value: mlir.Value, result_type: mlir.RankedTensorType, location: mlir.Location) mlir.Operation {
return mlir.Operation.make(ctx, "stablehlo.reshape", .{ return mlir.Operation.make(ctx, "stablehlo.reshape", .{
.operands = &.{value}, .operands = &.{value},
.results = &.{result_type.as(mlir.Type).?}, .results = &.{result_type.as(mlir.Type)},
.location = location, .location = location,
}); });
} }
@ -331,9 +331,9 @@ pub fn gather(
args.start_indices_batching_dims, args.start_indices_batching_dims,
args.start_index_map, args.start_index_map,
args.index_vector_dim, args.index_vector_dim,
).as(mlir.Attribute).? }, ).as(mlir.Attribute) },
.{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, slice_sizes).as(mlir.Attribute).? }, .{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, slice_sizes).as(mlir.Attribute) },
.{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute).? }, .{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}, },
@ -393,8 +393,8 @@ pub fn scatter(
.blocks = &.{update_block}, .blocks = &.{update_block},
.attributes = &.{ .attributes = &.{
.{ "scatter_dimension_numbers", args.getScatterDimensionNumbers(ctx) }, .{ "scatter_dimension_numbers", args.getScatterDimensionNumbers(ctx) },
.{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute).? }, .{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute) },
.{ "unique_indices", mlir.BoolAttribute.init(ctx, args.unique_indices).as(mlir.Attribute).? }, .{ "unique_indices", mlir.BoolAttribute.init(ctx, args.unique_indices).as(mlir.Attribute) },
}, },
.result_type_inference = true, .result_type_inference = true,
.location = location, .location = location,
@ -407,7 +407,7 @@ pub fn iota(ctx: mlir.Context, dimension: i64, result_type: mlir.Type, location:
.operands = &.{}, .operands = &.{},
.results = &.{result_type}, .results = &.{result_type},
.attributes = &.{ .attributes = &.{
.{ "iota_dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute).? }, .{ "iota_dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -419,7 +419,7 @@ pub fn reverse(ctx: mlir.Context, operand: mlir.Value, dimensions: []const i64,
.operands = &.{operand}, .operands = &.{operand},
.results = &.{result_type}, .results = &.{result_type},
.attributes = &.{ .attributes = &.{
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute).? }, .{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -430,8 +430,8 @@ pub fn compare(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, comparison_d
.operands = &.{ lhs, rhs }, .operands = &.{ lhs, rhs },
.result_type_inference = true, .result_type_inference = true,
.attributes = &.{ .attributes = &.{
.{ "comparison_direction", comparison_direction.as(mlir.Attribute).? }, .{ "comparison_direction", comparison_direction.as(mlir.Attribute) },
.{ "compare_type", compare_type.as(mlir.Attribute).? }, .{ "compare_type", compare_type.as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -452,7 +452,7 @@ pub fn reduce(
const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0..block_n_args]; const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0..block_n_args];
var reduce_elem_types: [MaxBlockArguments]mlir.Type = undefined; var reduce_elem_types: [MaxBlockArguments]mlir.Type = undefined;
for (inputs, 0..) |input, i| { for (inputs, 0..) |input, i| {
const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type).?; const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type);
reduce_elem_types[i] = arg_type; reduce_elem_types[i] = arg_type;
reduce_elem_types[inputs.len + i] = arg_type; reduce_elem_types[inputs.len + i] = arg_type;
} }
@ -474,7 +474,7 @@ pub fn reduce(
.result_type_inference = true, .result_type_inference = true,
.block = block, .block = block,
.attributes = &.{ .attributes = &.{
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute).? }, .{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -494,7 +494,7 @@ pub fn sort(
const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0 .. inputs.len * 2]; const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0 .. inputs.len * 2];
var sort_elem_types: [MaxBlockArguments]mlir.Type = undefined; var sort_elem_types: [MaxBlockArguments]mlir.Type = undefined;
for (inputs, 0..) |input, i| { for (inputs, 0..) |input, i| {
const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type).?; const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type);
sort_elem_types[i * 2] = arg_type; sort_elem_types[i * 2] = arg_type;
sort_elem_types[i * 2 + 1] = arg_type; sort_elem_types[i * 2 + 1] = arg_type;
} }
@ -511,8 +511,8 @@ pub fn sort(
.result_type_inference = true, .result_type_inference = true,
.block = block, .block = block,
.attributes = &.{ .attributes = &.{
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute).? }, .{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) },
.{ "is_stable", mlir.BoolAttribute.init(ctx, is_stable).as(mlir.Attribute).? }, .{ "is_stable", mlir.BoolAttribute.init(ctx, is_stable).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -523,7 +523,7 @@ pub fn dynamicSlice(ctx: mlir.Context, operand: mlir.Value, new_dims: []const i6
.variadic_operands = &.{ &.{operand}, start_indices }, .variadic_operands = &.{ &.{operand}, start_indices },
.result_type_inference = true, .result_type_inference = true,
.attributes = &.{ .attributes = &.{
.{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, new_dims).as(mlir.Attribute).? }, .{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, new_dims).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -556,9 +556,9 @@ pub fn pad(ctx: mlir.Context, value: mlir.Value, padding_value: mlir.Value, opts
.operands = &.{ value, padding_value }, .operands = &.{ value, padding_value },
.result_type_inference = true, .result_type_inference = true,
.attributes = &.{ .attributes = &.{
.{ "edge_padding_low", mlir.DenseArrayAttribute(.i64).init(ctx, opts.low).as(mlir.Attribute).? }, .{ "edge_padding_low", mlir.DenseArrayAttribute(.i64).init(ctx, opts.low).as(mlir.Attribute) },
.{ "edge_padding_high", mlir.DenseArrayAttribute(.i64).init(ctx, opts.high).as(mlir.Attribute).? }, .{ "edge_padding_high", mlir.DenseArrayAttribute(.i64).init(ctx, opts.high).as(mlir.Attribute) },
.{ "interior_padding", mlir.DenseArrayAttribute(.i64).init(ctx, opts.interior).as(mlir.Attribute).? }, .{ "interior_padding", mlir.DenseArrayAttribute(.i64).init(ctx, opts.interior).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -576,10 +576,10 @@ pub fn triangular_solve(ctx: mlir.Context, value: mlir.Value, other: mlir.Value,
.operands = &.{ value, other }, .operands = &.{ value, other },
.result_type_inference = true, .result_type_inference = true,
.attributes = &.{ .attributes = &.{
.{ "left_side", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.left_side))).as(mlir.Attribute).? }, .{ "left_side", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.left_side))).as(mlir.Attribute) },
.{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.lower))).as(mlir.Attribute).? }, .{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.lower))).as(mlir.Attribute) },
.{ "unit_diagonal", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.unit_diagonal))).as(mlir.Attribute).? }, .{ "unit_diagonal", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.unit_diagonal))).as(mlir.Attribute) },
.{ "transpose_a", Transpose.init(ctx, opts.transpose_a).as(mlir.Attribute).? }, .{ "transpose_a", Transpose.init(ctx, opts.transpose_a).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -595,8 +595,8 @@ pub fn fft(ctx: mlir.Context, value: mlir.Value, location: mlir.Location, opts:
.operands = &.{value}, .operands = &.{value},
.result_type_inference = true, .result_type_inference = true,
.attributes = &.{ .attributes = &.{
.{ "fft_type", FftType.init(ctx, opts.kind).as(mlir.Attribute).? }, .{ "fft_type", FftType.init(ctx, opts.kind).as(mlir.Attribute) },
.{ "fft_length", mlir.DenseArrayAttribute(.i64).init(ctx, opts.length).as(mlir.Attribute).? }, .{ "fft_length", mlir.DenseArrayAttribute(.i64).init(ctx, opts.length).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -607,7 +607,7 @@ pub fn rng(ctx: mlir.Context, a: mlir.Value, b: mlir.Value, shape: mlir.Value, r
.operands = &.{ a, b, shape }, .operands = &.{ a, b, shape },
.result_type_inference = true, .result_type_inference = true,
.attributes = &.{ .attributes = &.{
.{ "rng_distribution", RngDistribution.init(ctx, rng_distribution).as(mlir.Attribute).? }, .{ "rng_distribution", RngDistribution.init(ctx, rng_distribution).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -618,7 +618,7 @@ pub fn rng_bit_generator(ctx: mlir.Context, rng_algorithm: RngAlgorithm.Type, in
.operands = &.{initial_state}, .operands = &.{initial_state},
.results = &.{ res_state_type, res_type }, .results = &.{ res_state_type, res_type },
.attributes = &.{ .attributes = &.{
.{ "rng_algorithm", RngAlgorithm.init(ctx, rng_algorithm).as(mlir.Attribute).? }, .{ "rng_algorithm", RngAlgorithm.init(ctx, rng_algorithm).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -629,8 +629,8 @@ pub fn reduce_precision(ctx: mlir.Context, value: mlir.Value, exponent_bits: i32
.operands = &.{value}, .operands = &.{value},
.result_type_inference = true, .result_type_inference = true,
.attributes = &.{ .attributes = &.{
.{ "exponent_bits", mlir.IntegerAttribute(.i32).init(ctx, exponent_bits).as(mlir.Attribute).? }, .{ "exponent_bits", mlir.IntegerAttribute(.i32).init(ctx, exponent_bits).as(mlir.Attribute) },
.{ "mantissa_bits", mlir.IntegerAttribute(.i32).init(ctx, mantissa_bits).as(mlir.Attribute).? }, .{ "mantissa_bits", mlir.IntegerAttribute(.i32).init(ctx, mantissa_bits).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -657,7 +657,7 @@ pub fn get_tuple_element(ctx: mlir.Context, tuple_value: mlir.Value, index: i64,
.operands = &.{tuple_value}, .operands = &.{tuple_value},
.result_type_inference = true, .result_type_inference = true,
.attributes = &.{ .attributes = &.{
.{ "index", mlir.IntegerAttribute(.i32).init(ctx, index).as(mlir.Attribute).? }, .{ "index", mlir.IntegerAttribute(.i32).init(ctx, index).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -694,23 +694,23 @@ pub fn convolution(
) mlir.Operation { ) mlir.Operation {
var max_precisions: [2]mlir.Attribute = undefined; var max_precisions: [2]mlir.Attribute = undefined;
for (opts.precision_config, 0..) |p, i| { for (opts.precision_config, 0..) |p, i| {
max_precisions[i] = PrecisionAttribute.init(ctx, p).as(mlir.Attribute).?; max_precisions[i] = PrecisionAttribute.init(ctx, p).as(mlir.Attribute);
} }
var window_reversal: [3]i32 = undefined; var window_reversal: [3]i32 = undefined;
for (opts.window_reversal, 0..) |w, i| { for (opts.window_reversal, 0..) |w, i| {
window_reversal[i] = @intCast(@intFromBool(w)); window_reversal[i] = @intCast(@intFromBool(w));
} }
const pad_type = mlir.IntegerType(.i64).init(ctx).as(mlir.Type).?; const pad_type = mlir.IntegerType(.i64).init(ctx).as(mlir.Type);
const pad_shape = mlir.RankedTensorType.init(opts.pad_shape, pad_type).as(mlir.Type).?; const pad_shape = mlir.RankedTensorType.init(opts.pad_shape, pad_type).as(mlir.Type);
return mlir.Operation.make(ctx, "stablehlo.convolution", .{ return mlir.Operation.make(ctx, "stablehlo.convolution", .{
.operands = &.{ lhs, rhs }, .operands = &.{ lhs, rhs },
.results = &.{res_type}, .results = &.{res_type},
.attributes = &.{ .attributes = &.{
.{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_strides).as(mlir.Attribute).? }, .{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_strides).as(mlir.Attribute) },
.{ "padding", mlir.DenseIntOrFPElementsAttribute(.i64).init(pad_shape, std.mem.sliceAsBytes(opts.pad_value)).as(mlir.Attribute).? }, .{ "padding", mlir.DenseElementsAttribute(.i64).init(pad_shape, opts.pad_value).as(mlir.Attribute) },
.{ "lhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.lhs_dilation).as(mlir.Attribute).? }, .{ "lhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.lhs_dilation).as(mlir.Attribute) },
.{ "rhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.rhs_dilation).as(mlir.Attribute).? }, .{ "rhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.rhs_dilation).as(mlir.Attribute) },
.{ "window_reversal", mlir.DenseArrayAttribute(.bool).init(ctx, window_reversal[0..opts.window_reversal.len]).as(mlir.Attribute).? }, .{ "window_reversal", mlir.DenseArrayAttribute(.bool).init(ctx, window_reversal[0..opts.window_reversal.len]).as(mlir.Attribute) },
.{ .{
"dimension_numbers", ConvDimensionNumbersAttribute.init(ctx, .{ "dimension_numbers", ConvDimensionNumbersAttribute.init(ctx, .{
.input_batch_dimension = opts.input_batch_dimension, .input_batch_dimension = opts.input_batch_dimension,
@ -722,11 +722,11 @@ pub fn convolution(
.output_batch_dimension = opts.output_batch_dimension, .output_batch_dimension = opts.output_batch_dimension,
.output_feature_dimension = opts.output_feature_dimension, .output_feature_dimension = opts.output_feature_dimension,
.output_spatial_dimensions = opts.output_spatial_dimensions, .output_spatial_dimensions = opts.output_spatial_dimensions,
}).as(mlir.Attribute).?, }).as(mlir.Attribute),
}, },
.{ "feature_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.feature_group_count).as(mlir.Attribute).? }, .{ "feature_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.feature_group_count).as(mlir.Attribute) },
.{ "batch_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.batch_group_count).as(mlir.Attribute).? }, .{ "batch_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.batch_group_count).as(mlir.Attribute) },
.{ "precision_config", mlir.ArrayAttribute.init(ctx, &max_precisions).as(mlir.Attribute).? }, .{ "precision_config", mlir.ArrayAttribute.init(ctx, &max_precisions).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });
@ -747,18 +747,18 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
const output_operand_aliases = allocator.alloc(mlir.Attribute, opts.output_operand_aliases.len) catch unreachable; const output_operand_aliases = allocator.alloc(mlir.Attribute, opts.output_operand_aliases.len) catch unreachable;
for (opts.output_operand_aliases, 0..) |alias, i| { for (opts.output_operand_aliases, 0..) |alias, i| {
output_operand_aliases[i] = OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute).?; output_operand_aliases[i] = OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute);
} }
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{ return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
.operands = inputs, .operands = inputs,
.results = res_types, .results = res_types,
.attributes = &.{ .attributes = &.{
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, opts.api_version).as(mlir.Attribute).? }, .{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, opts.api_version).as(mlir.Attribute) },
.{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute).? }, .{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) },
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute).? }, .{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) },
.{ "backend_config", mlir.StringAttribute.init(ctx, opts.backend_config).as(mlir.Attribute).? }, .{ "backend_config", mlir.StringAttribute.init(ctx, opts.backend_config).as(mlir.Attribute) },
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases).as(mlir.Attribute).? }, .{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases).as(mlir.Attribute) },
}, },
.location = location, .location = location,
}); });

View File

@ -1,5 +1,6 @@
const builtin = @import("builtin"); const builtin = @import("builtin");
const std = @import("std"); const std = @import("std");
const stdx = @import("stdx");
const log = std.log.scoped(.mlir); const log = std.log.scoped(.mlir);
const c = @import("c"); const c = @import("c");
@ -95,9 +96,10 @@ pub fn MlirHelpers(comptime OuterT: type, comptime methods: MlirHelpersMethods(O
return false; return false;
} }
pub inline fn as(self: OuterT, comptime OtherT: type) ?OtherT { pub inline fn as(self: OuterT, comptime OtherT: type) OtherT {
if (OtherT.Methods.is_a_fn) |is_a_fn| { if (OtherT.Methods.is_a_fn) |is_a_fn| {
return if (is_a_fn(self.inner())) OtherT.wrap(self.inner()) else null; stdx.debug.assert(is_a_fn(self.inner()), "Wrongly tried to cast {} into {}", .{ OuterT, OtherT });
return OtherT.wrap(self.inner());
} }
// if the other type doesn't have an is_a_fn, try. // if the other type doesn't have an is_a_fn, try.
return OtherT.wrap(self.inner()); return OtherT.wrap(self.inner());
@ -425,7 +427,7 @@ pub const BoolAttribute = struct {
} }
pub fn asAttr(self: Self) Attribute { pub fn asAttr(self: Self) Attribute {
return self.as(Attribute).?; return self.as(Attribute);
} }
}; };
@ -446,7 +448,7 @@ pub const TypeAttribute = struct {
} }
pub fn asAttr(self: TypeAttribute) Attribute { pub fn asAttr(self: TypeAttribute) Attribute {
return self.as(Attribute).?; return self.as(Attribute);
} }
}; };
@ -591,10 +593,6 @@ pub fn DenseArrayAttribute(comptime dt: DenseArrayTypes) type {
.i64 => .i64, .i64 => .i64,
else => @compileError("DenseArrayAttribute: unreachable"), else => @compileError("DenseArrayAttribute: unreachable"),
}); });
pub fn toElements(self: Attr) DenseArray {
return DenseArray.wrap(c.mlirDenseArrayToElements(self.inner()));
}
}, },
else => struct {}, else => struct {},
}; };
@ -615,28 +613,32 @@ pub const DenseElementsAttributeTypes = enum {
f16, f16,
f32, f32,
f64, f64,
index,
}; };
pub fn DenseIntOrFPElementsAttribute(comptime dt: DenseElementsAttributeTypes) type { pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
const ZigInDataType, const ZigOutDataType, const initFn, const getValue = switch (dt) { const ZigType = switch (dt) {
.bool => .{ bool, bool, c.mlirDenseElementsAttrBoolGet, c.mlirDenseElementsAttrGetBoolValue }, .bool => bool,
.i8 => .{ i8, i8, c.mlirDenseElementsAttrInt8Get, c.mlirDenseElementsAttrGetInt8Value }, .i8 => i8,
.i16 => .{ i16, i16, c.mlirDenseElementsAttrInt16Get, c.mlirDenseElementsAttrGetInt16Value }, .i16 => i16,
.i32 => .{ i32, i32, c.mlirDenseElementsAttrInt32Get, c.mlirDenseElementsAttrGetInt32Value }, .i32 => i32,
.i64 => .{ i64, i64, c.mlirDenseElementsAttrInt64Get, c.mlirDenseElementsAttrGetInt64Value }, .i64 => i64,
.u8 => .{ u8, u8, c.mlirDenseElementsAttrUInt8Get, c.mlirDenseElementsAttrGetUInt8Value }, .u8 => u8,
.u16 => .{ u16, u16, c.mlirDenseElementsAttrUInt16Get, c.mlirDenseElementsAttrGetUInt16Value }, .u16 => u16,
.u32 => .{ u32, u32, c.mlirDenseElementsAttrUInt32Get, c.mlirDenseElementsAttrGetUInt32Value }, .u32 => u32,
.u64 => .{ u64, u64, c.mlirDenseElementsAttrUInt64Get, c.mlirDenseElementsAttrGetUInt64Value }, .u64 => u64,
.bf16 => .{ u16, f32, c.mlirDenseElementsAttrBFloat16Get, c.mlirDenseElementsAttrGetFloatValue }, .bf16 => u16,
.f16 => .{ f16, f32, c.mlirDenseElementsAttrFloat16Get, c.mlirDenseElementsAttrGetFloatValue }, .f16 => f16,
.f32 => .{ f32, f32, c.mlirDenseElementsAttrFloatGet, c.mlirDenseElementsAttrGetFloatValue }, .f32 => f32,
.f64 => .{ f64, f64, c.mlirDenseElementsAttrDoubleGet, c.mlirDenseElementsAttrGetDoubleValue }, .f64 => f64,
.index => usize,
}; };
return struct { return struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,
const Attr = @This(); const Attr = @This();
pub usingnamespace MlirHelpers(Attr, .{ pub usingnamespace MlirHelpers(Attr, .{
.is_a_fn = c.mlirAttributeIsADenseElements, .is_a_fn = c.mlirAttributeIsADenseElements,
.is_null_fn = c.mlirAttributeIsNull, .is_null_fn = c.mlirAttributeIsNull,
@ -644,13 +646,29 @@ pub fn DenseIntOrFPElementsAttribute(comptime dt: DenseElementsAttributeTypes) t
.equal_fn = c.mlirAttributeEqual, .equal_fn = c.mlirAttributeEqual,
}); });
pub fn init(shaped_type: Type, raw_values: []const u8) Attr { pub fn init(shaped_type: Type, slice: anytype) Attr {
const values = std.mem.bytesAsSlice(ZigInDataType, raw_values); const bytes = std.mem.sliceAsBytes(slice);
return Attr.wrap(initFn(shaped_type.inner(), @intCast(values.len), @ptrCast(@alignCast(values.ptr)))); const v = Attr.wrapOr(
c.mlirDenseElementsAttrRawBufferGet(
shaped_type.inner(),
@intCast(bytes.len),
@ptrCast(bytes.ptr),
),
) orelse unreachable;
return v;
} }
pub fn get(self: Attr, pos: usize) ZigOutDataType { pub fn len(self: Attr) usize {
return getValue(self.inner(), @intCast(pos)); return @intCast(c.mlirElementsAttrGetNumElements(self.inner()));
}
pub fn constSlice(self: Attr) []const ZigType {
const ptr: [*]const ZigType = @constCast(@ptrCast(@alignCast(c.mlirDenseElementsAttrGetRawData(self.inner()) orelse unreachable)));
return ptr[0..self.len()];
}
pub fn data(self: Attr) []const u8 {
return std.mem.sliceAsBytes(self.constSlice());
} }
}; };
} }
@ -1338,12 +1356,9 @@ pub const FloatTypes = enum {
f32, f32,
f64, f64,
unknown, pub fn asType(self: FloatTypes, ctx: Context) Type {
pub fn asType(self: FloatTypes, ctx: Context) ?Type {
return switch (self) { return switch (self) {
.unknown => null, inline else => |ft| FloatType(ft).init(ctx).as(Type),
inline else => |ft| FloatType(ft).init(ctx).asType(),
}; };
} }
}; };
@ -1359,48 +1374,23 @@ pub fn FloatType(comptime ft: FloatTypes) type {
.f16 => .{ c.mlirTypeIsAF16, c.mlirF16TypeGet }, .f16 => .{ c.mlirTypeIsAF16, c.mlirF16TypeGet },
.f32 => .{ c.mlirTypeIsAF32, c.mlirF32TypeGet }, .f32 => .{ c.mlirTypeIsAF32, c.mlirF32TypeGet },
.f64 => .{ c.mlirTypeIsAF64, c.mlirF64TypeGet }, .f64 => .{ c.mlirTypeIsAF64, c.mlirF64TypeGet },
.unknown => .{ null, null },
}; };
return struct { return struct {
_inner: c.MlirType, _inner: c.MlirType,
const Float = @This();
pub usingnamespace MlirHelpers(Float, .{ const Self = @This();
.is_a_fn = switch (ft) {
.unknown => typeIsAUnknownFloat, pub usingnamespace MlirHelpers(Self, .{
else => Config[0], .is_a_fn = Config[0],
},
.is_null_fn = c.mlirTypeIsNull, .is_null_fn = c.mlirTypeIsNull,
.dump_fn = c.mlirTypeDump, .dump_fn = c.mlirTypeDump,
.equal_fn = c.mlirTypeEqual, .equal_fn = c.mlirTypeEqual,
}); });
pub usingnamespace if (ft != .unknown) struct { pub fn init(ctx: Context) Self {
pub const FloatTypeType = ft;
pub fn init(ctx: Context) Float {
const type_get = Config[1]; const type_get = Config[1];
return Float.wrap(type_get(ctx.inner())); return Self.wrap(type_get(ctx.inner()));
}
} else struct {};
fn typeIsAUnknownFloat(typ: c.MlirType) callconv(.C) bool {
const is_a_fns = .{
c.mlirTypeIsABF16,
c.mlirTypeIsAF16,
c.mlirTypeIsAF32,
c.mlirTypeIsF64,
};
inline for (is_a_fns) |is_a_fn| {
if (is_a_fn(typ)) {
return true;
}
}
return false;
}
pub fn asType(self: Float) Type {
return self.as(Type).?;
} }
}; };
} }
@ -1545,10 +1535,6 @@ pub const RankedTensorType = struct {
pub fn getDimension(self: RankedTensorType, dim: usize) i64 { pub fn getDimension(self: RankedTensorType, dim: usize) i64 {
return c.mlirShapedTypeGetDimSize(self.inner(), @intCast(dim)); return c.mlirShapedTypeGetDimSize(self.inner(), @intCast(dim));
} }
pub fn asType(self: RankedTensorType) Type {
return self.as(Type).?;
}
}; };
pub const Dialect = struct { pub const Dialect = struct {

View File

@ -1,27 +0,0 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlirx.h"
namespace mlirx {
static mlir::Attribute ArrayToElements(mlir::Attribute attr) {
if (auto array = attr.dyn_cast<mlir::DenseI64ArrayAttr>()) {
return mlir::DenseIntElementsAttr::get(
mlir::RankedTensorType::get(array.size(), array.getElementType()),
array.asArrayRef());
}
if (auto array = attr.dyn_cast<mlir::DenseBoolArrayAttr>()) {
return mlir::DenseIntElementsAttr::get(
mlir::RankedTensorType::get(array.size(), array.getElementType()),
array.asArrayRef());
}
return attr;
}
}
MlirAttribute mlirDenseArrayToElements(MlirAttribute attr) {
return wrap(mlirx::ArrayToElements(unwrap(attr)));
}

View File

@ -1,16 +0,0 @@
#ifndef MLIRX_CC_H
#define MLIRX_CC_H
#include "mlir-c/IR.h"
#ifdef __cplusplus
extern "C" {
#endif
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseArrayToElements(MlirAttribute attr);
#ifdef __cplusplus
}
#endif
#endif // MLIRX_CC_H

View File

@ -15,7 +15,7 @@ pub usingnamespace @import("mlir");
pub const ext = struct { pub const ext = struct {
pub fn mlirType(ctx: mlir.Context, sh: Shape) mlir.Type { pub fn mlirType(ctx: mlir.Context, sh: Shape) mlir.Type {
return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype())).as(mlir.Type).?; return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype())).as(mlir.Type);
} }
pub fn denseElementAttrType(dt: dtype.DataType) ?mlir.DenseElementsAttributeTypes { pub fn denseElementAttrType(dt: dtype.DataType) ?mlir.DenseElementsAttributeTypes {
@ -38,21 +38,21 @@ pub const ext = struct {
} }
pub fn denseElementsAttr(dt: dtype.DataType, _: usize, bytes: []const u8, ranked_type: mlir.RankedTensorType) mlir.Attribute { pub fn denseElementsAttr(dt: dtype.DataType, _: usize, bytes: []const u8, ranked_type: mlir.RankedTensorType) mlir.Attribute {
const ranked_type_ = ranked_type.as(mlir.Type).?; const ranked_type_ = ranked_type.as(mlir.Type);
return switch (dt) { return switch (dt) {
.bool => mlir.DenseIntOrFPElementsAttribute(.bool).init(ranked_type_, bytes).as(mlir.Attribute).?, .bool => mlir.DenseElementsAttribute(.bool).init(ranked_type_, bytes).as(mlir.Attribute),
.i8 => mlir.DenseIntOrFPElementsAttribute(.i8).init(ranked_type_, bytes).as(mlir.Attribute).?, .i8 => mlir.DenseElementsAttribute(.i8).init(ranked_type_, bytes).as(mlir.Attribute),
.i16 => mlir.DenseIntOrFPElementsAttribute(.i16).init(ranked_type_, bytes).as(mlir.Attribute).?, .i16 => mlir.DenseElementsAttribute(.i16).init(ranked_type_, bytes).as(mlir.Attribute),
.i32 => mlir.DenseIntOrFPElementsAttribute(.i32).init(ranked_type_, bytes).as(mlir.Attribute).?, .i32 => mlir.DenseElementsAttribute(.i32).init(ranked_type_, bytes).as(mlir.Attribute),
.i64 => mlir.DenseIntOrFPElementsAttribute(.i64).init(ranked_type_, bytes).as(mlir.Attribute).?, .i64 => mlir.DenseElementsAttribute(.i64).init(ranked_type_, bytes).as(mlir.Attribute),
.u8 => mlir.DenseIntOrFPElementsAttribute(.u8).init(ranked_type_, bytes).as(mlir.Attribute).?, .u8 => mlir.DenseElementsAttribute(.u8).init(ranked_type_, bytes).as(mlir.Attribute),
.u16 => mlir.DenseIntOrFPElementsAttribute(.u16).init(ranked_type_, bytes).as(mlir.Attribute).?, .u16 => mlir.DenseElementsAttribute(.u16).init(ranked_type_, bytes).as(mlir.Attribute),
.u32 => mlir.DenseIntOrFPElementsAttribute(.u32).init(ranked_type_, bytes).as(mlir.Attribute).?, .u32 => mlir.DenseElementsAttribute(.u32).init(ranked_type_, bytes).as(mlir.Attribute),
.u64 => mlir.DenseIntOrFPElementsAttribute(.u64).init(ranked_type_, bytes).as(mlir.Attribute).?, .u64 => mlir.DenseElementsAttribute(.u64).init(ranked_type_, bytes).as(mlir.Attribute),
.bf16 => mlir.DenseIntOrFPElementsAttribute(.bf16).init(ranked_type_, bytes).as(mlir.Attribute).?, .bf16 => mlir.DenseElementsAttribute(.bf16).init(ranked_type_, bytes).as(mlir.Attribute),
.f16 => mlir.DenseIntOrFPElementsAttribute(.f16).init(ranked_type_, bytes).as(mlir.Attribute).?, .f16 => mlir.DenseElementsAttribute(.f16).init(ranked_type_, bytes).as(mlir.Attribute),
.f32 => mlir.DenseIntOrFPElementsAttribute(.f32).init(ranked_type_, bytes).as(mlir.Attribute).?, .f32 => mlir.DenseElementsAttribute(.f32).init(ranked_type_, bytes).as(mlir.Attribute),
.f64 => mlir.DenseIntOrFPElementsAttribute(.f64).init(ranked_type_, bytes).as(mlir.Attribute).?, .f64 => mlir.DenseElementsAttribute(.f64).init(ranked_type_, bytes).as(mlir.Attribute),
inline else => |tag| @panic("Unsupported data type: " ++ @tagName(tag)), inline else => |tag| @panic("Unsupported data type: " ++ @tagName(tag)),
}; };
} }
@ -66,28 +66,28 @@ pub const ext = struct {
pub const Type = struct { pub const Type = struct {
pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type { pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type {
return switch (dt) { return switch (dt) {
.bool => mlir.IntegerType(.i1).init(ctx).as(mlir.Type).?, .bool => mlir.IntegerType(.i1).init(ctx).as(mlir.Type),
.f8e4m3b11fnuz => mlir.FloatType(.f8e4m3b11fnuz).init(ctx).as(mlir.Type).?, .f8e4m3b11fnuz => mlir.FloatType(.f8e4m3b11fnuz).init(ctx).as(mlir.Type),
.f8e4m3fn => mlir.FloatType(.f8e4m3fn).init(ctx).as(mlir.Type).?, .f8e4m3fn => mlir.FloatType(.f8e4m3fn).init(ctx).as(mlir.Type),
.f8e4m3fnuz => mlir.FloatType(.f8e4m3fnuz).init(ctx).as(mlir.Type).?, .f8e4m3fnuz => mlir.FloatType(.f8e4m3fnuz).init(ctx).as(mlir.Type),
.f8e5m2 => mlir.FloatType(.f8e5m2).init(ctx).as(mlir.Type).?, .f8e5m2 => mlir.FloatType(.f8e5m2).init(ctx).as(mlir.Type),
.f8e5m2fnuz => mlir.FloatType(.f8e5m2fnuz).init(ctx).as(mlir.Type).?, .f8e5m2fnuz => mlir.FloatType(.f8e5m2fnuz).init(ctx).as(mlir.Type),
.bf16 => mlir.FloatType(.bf16).init(ctx).as(mlir.Type).?, .bf16 => mlir.FloatType(.bf16).init(ctx).as(mlir.Type),
.f16 => mlir.FloatType(.f16).init(ctx).as(mlir.Type).?, .f16 => mlir.FloatType(.f16).init(ctx).as(mlir.Type),
.f32 => mlir.FloatType(.f32).init(ctx).as(mlir.Type).?, .f32 => mlir.FloatType(.f32).init(ctx).as(mlir.Type),
.f64 => mlir.FloatType(.f64).init(ctx).as(mlir.Type).?, .f64 => mlir.FloatType(.f64).init(ctx).as(mlir.Type),
.i4 => mlir.IntegerType(.i4).init(ctx).as(mlir.Type).?, .i4 => mlir.IntegerType(.i4).init(ctx).as(mlir.Type),
.i8 => mlir.IntegerType(.i8).init(ctx).as(mlir.Type).?, .i8 => mlir.IntegerType(.i8).init(ctx).as(mlir.Type),
.i16 => mlir.IntegerType(.i16).init(ctx).as(mlir.Type).?, .i16 => mlir.IntegerType(.i16).init(ctx).as(mlir.Type),
.i32 => mlir.IntegerType(.i32).init(ctx).as(mlir.Type).?, .i32 => mlir.IntegerType(.i32).init(ctx).as(mlir.Type),
.i64 => mlir.IntegerType(.i64).init(ctx).as(mlir.Type).?, .i64 => mlir.IntegerType(.i64).init(ctx).as(mlir.Type),
.u4 => mlir.IntegerType(.u4).init(ctx).as(mlir.Type).?, .u4 => mlir.IntegerType(.u4).init(ctx).as(mlir.Type),
.u8 => mlir.IntegerType(.u8).init(ctx).as(mlir.Type).?, .u8 => mlir.IntegerType(.u8).init(ctx).as(mlir.Type),
.u16 => mlir.IntegerType(.u16).init(ctx).as(mlir.Type).?, .u16 => mlir.IntegerType(.u16).init(ctx).as(mlir.Type),
.u32 => mlir.IntegerType(.u32).init(ctx).as(mlir.Type).?, .u32 => mlir.IntegerType(.u32).init(ctx).as(mlir.Type),
.u64 => mlir.IntegerType(.u64).init(ctx).as(mlir.Type).?, .u64 => mlir.IntegerType(.u64).init(ctx).as(mlir.Type),
.c64 => mlir.ComplexType(.c64).init(ctx).as(mlir.Type).?, .c64 => mlir.ComplexType(.c64).init(ctx).as(mlir.Type),
.c128 => mlir.ComplexType(.c128).init(ctx).as(mlir.Type).?, .c128 => mlir.ComplexType(.c128).init(ctx).as(mlir.Type),
}; };
} }
@ -123,7 +123,7 @@ pub const ext = struct {
inline for (mapping) |entry| { inline for (mapping) |entry| {
const dt, const mlirT = entry; const dt, const mlirT = entry;
if (mlir_type.as(mlirT)) |_| { if (mlir_type.is_a(mlirT)) {
return dt; return dt;
} }
} }
@ -136,39 +136,39 @@ pub const ext = struct {
pub fn fromData(data: dtype.Data, ctx: mlir.Context) mlir.Attribute { pub fn fromData(data: dtype.Data, ctx: mlir.Context) mlir.Attribute {
switch (data) { switch (data) {
.bool => |val| { .bool => |val| {
return mlir.IntegerAttribute(.i1).init(ctx, @intFromBool(val)).as(mlir.Attribute).?; return mlir.IntegerAttribute(.i1).init(ctx, @intFromBool(val)).as(mlir.Attribute);
}, },
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz => |val, tag| { inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz => |val, tag| {
const float_type = @field(mlir.FloatTypes, @tagName(tag)); const float_type = @field(mlir.FloatTypes, @tagName(tag));
const float_attr = mlir.FloatAttribute(float_type).init(ctx, val.toF32()); const float_attr = mlir.FloatAttribute(float_type).init(ctx, val.toF32());
return float_attr.as(mlir.Attribute).?; return float_attr.as(mlir.Attribute);
}, },
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |val, tag| { inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |val, tag| {
const int_type = @field(mlir.IntegerTypes, @tagName(tag)); const int_type = @field(mlir.IntegerTypes, @tagName(tag));
const int_attr = mlir.IntegerAttribute(int_type).init(ctx, @intCast(val)); const int_attr = mlir.IntegerAttribute(int_type).init(ctx, @intCast(val));
return int_attr.as(mlir.Attribute).?; return int_attr.as(mlir.Attribute);
}, },
inline else => |_, tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}), inline else => |_, tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}),
} }
} }
}; };
pub const DenseIntOrFPElementsAttribute = struct { pub const DenseElementsAttribute = struct {
pub fn fromData(data: dtype.Data, result_type: mlir.Type) mlir.Attribute { pub fn fromData(data: dtype.Data, result_type: mlir.Type) mlir.Attribute {
return switch (data.dtype()) { return switch (data.dtype()) {
.bool => mlir.DenseIntOrFPElementsAttribute(.bool).init(result_type, data.constSlice()).as(mlir.Attribute).?, .bool => mlir.DenseElementsAttribute(.bool).init(result_type, data.constSlice()).as(mlir.Attribute),
.i8 => mlir.DenseIntOrFPElementsAttribute(.i8).init(result_type, data.constSlice()).as(mlir.Attribute).?, .i8 => mlir.DenseElementsAttribute(.i8).init(result_type, data.constSlice()).as(mlir.Attribute),
.i16 => mlir.DenseIntOrFPElementsAttribute(.i16).init(result_type, data.constSlice()).as(mlir.Attribute).?, .i16 => mlir.DenseElementsAttribute(.i16).init(result_type, data.constSlice()).as(mlir.Attribute),
.i32 => mlir.DenseIntOrFPElementsAttribute(.i32).init(result_type, data.constSlice()).as(mlir.Attribute).?, .i32 => mlir.DenseElementsAttribute(.i32).init(result_type, data.constSlice()).as(mlir.Attribute),
.i64 => mlir.DenseIntOrFPElementsAttribute(.i64).init(result_type, data.constSlice()).as(mlir.Attribute).?, .i64 => mlir.DenseElementsAttribute(.i64).init(result_type, data.constSlice()).as(mlir.Attribute),
.u8 => mlir.DenseIntOrFPElementsAttribute(.u8).init(result_type, data.constSlice()).as(mlir.Attribute).?, .u8 => mlir.DenseElementsAttribute(.u8).init(result_type, data.constSlice()).as(mlir.Attribute),
.u16 => mlir.DenseIntOrFPElementsAttribute(.u16).init(result_type, data.constSlice()).as(mlir.Attribute).?, .u16 => mlir.DenseElementsAttribute(.u16).init(result_type, data.constSlice()).as(mlir.Attribute),
.u32 => mlir.DenseIntOrFPElementsAttribute(.u32).init(result_type, data.constSlice()).as(mlir.Attribute).?, .u32 => mlir.DenseElementsAttribute(.u32).init(result_type, data.constSlice()).as(mlir.Attribute),
.u64 => mlir.DenseIntOrFPElementsAttribute(.u64).init(result_type, data.constSlice()).as(mlir.Attribute).?, .u64 => mlir.DenseElementsAttribute(.u64).init(result_type, data.constSlice()).as(mlir.Attribute),
.bf16 => mlir.DenseIntOrFPElementsAttribute(.bf16).init(result_type, data.constSlice()).as(mlir.Attribute).?, .bf16 => mlir.DenseElementsAttribute(.bf16).init(result_type, data.constSlice()).as(mlir.Attribute),
.f16 => mlir.DenseIntOrFPElementsAttribute(.f16).init(result_type, data.constSlice()).as(mlir.Attribute).?, .f16 => mlir.DenseElementsAttribute(.f16).init(result_type, data.constSlice()).as(mlir.Attribute),
.f32 => mlir.DenseIntOrFPElementsAttribute(.f32).init(result_type, data.constSlice()).as(mlir.Attribute).?, .f32 => mlir.DenseElementsAttribute(.f32).init(result_type, data.constSlice()).as(mlir.Attribute),
.f64 => mlir.DenseIntOrFPElementsAttribute(.f64).init(result_type, data.constSlice()).as(mlir.Attribute).?, .f64 => mlir.DenseElementsAttribute(.f64).init(result_type, data.constSlice()).as(mlir.Attribute),
inline else => |tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}), inline else => |tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}),
}; };
} }

View File

@ -123,7 +123,7 @@ pub const CompilationContext = struct {
const loc = mlir_ctx.location(@src()).named(mlir_ctx, "main"); const loc = mlir_ctx.location(@src()).named(mlir_ctx, "main");
const module = mlir.Module.init(loc); const module = mlir.Module.init(loc);
module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, "zml").as(mlir.Attribute).?); module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, "zml").as(mlir.Attribute));
var canonicalizer = try mlir.PassManager.init(mlir_ctx); var canonicalizer = try mlir.PassManager.init(mlir_ctx);
{ {
@ -492,7 +492,7 @@ pub const CompilationContext = struct {
attributes[a].appendAssumeCapacity( attributes[a].appendAssumeCapacity(
mlir.NamedAttribute.init( mlir.NamedAttribute.init(
mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"), mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"),
mlir.IntegerAttribute(.i32).init(self.mlirCtx(), @intCast(index)).as(mlir.Attribute).?, mlir.IntegerAttribute(.i32).init(self.mlirCtx(), @intCast(index)).as(mlir.Attribute),
), ),
); );
// log.debug("attribute: {}", .{attributes[a].constSlice()}); // log.debug("attribute: {}", .{attributes[a].constSlice()});

View File

@ -132,7 +132,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
}, },
&.{ &.{
mlir.ext.mlirType(mlir_ctx, q.shape()), mlir.ext.mlirType(mlir_ctx, q.shape()),
mlir.RankedTensorType.init(&.{0}, mlir.IntegerType(.u8).init(mlir_ctx).as(mlir.Type).?).asType(), mlir.RankedTensorType.init(&.{0}, mlir.IntegerType(.u8).init(mlir_ctx).as(mlir.Type)).as(mlir.Type),
}, },
loc, loc,
); );

View File

@ -148,7 +148,7 @@ pub fn reduce(
.result_type_inference = true, .result_type_inference = true,
.blocks = &.{body_block}, .blocks = &.{body_block},
.attributes = &.{ .attributes = &.{
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), axes).as(mlir.Attribute).? }, .{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), axes).as(mlir.Attribute) },
}, },
// We can't verify right away, cause the weights captured by the reduce haven't been added yet. // We can't verify right away, cause the weights captured by the reduce haven't been added yet.
.verify = false, .verify = false,
@ -197,7 +197,7 @@ pub fn reduce(
mlir_ctx, mlir_ctx,
val, val,
inner_ctx.broadcasting_axes[0 .. tensor.rank() - inner_ctx.n_reduced], inner_ctx.broadcasting_axes[0 .. tensor.rank() - inner_ctx.n_reduced],
mlir.ext.RankedTensorType.fromShape(mlir_ctx, reduced_shape).as(mlir.Type).?, mlir.ext.RankedTensorType.fromShape(mlir_ctx, reduced_shape).as(mlir.Type),
inner_ctx.loc, inner_ctx.loc,
); );
tensor.* = Tensor._result(reduced_shape, broad_val.result(0)); tensor.* = Tensor._result(reduced_shape, broad_val.result(0));
@ -240,17 +240,17 @@ pub fn reduceWindow(
const pad_shape = mlir.RankedTensorType.init( const pad_shape = mlir.RankedTensorType.init(
&.{ @intCast(opts.padding.len), 2 }, &.{ @intCast(opts.padding.len), 2 },
mlir.ext.Type.fromDType(ctx.mlirCtx(), .i64), mlir.ext.Type.fromDType(ctx.mlirCtx(), .i64),
).as(mlir.Type).?; ).as(mlir.Type);
const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.reduce_window", .{ const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.reduce_window", .{
.variadic_operands = &.{ input_values[0..], init_values[0..] }, .variadic_operands = &.{ input_values[0..], init_values[0..] },
.result_type_inference = true, .result_type_inference = true,
.blocks = &.{body_block}, .blocks = &.{body_block},
.attributes = &.{ .attributes = &.{
.{ "window_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dimensions).as(mlir.Attribute).? }, .{ "window_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dimensions).as(mlir.Attribute) },
.{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_strides).as(mlir.Attribute).? }, .{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_strides).as(mlir.Attribute) },
.{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.base_dilations).as(mlir.Attribute).? }, .{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.base_dilations).as(mlir.Attribute) },
.{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dilations).as(mlir.Attribute).? }, .{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dilations).as(mlir.Attribute) },
.{ "padding", mlir.DenseIntOrFPElementsAttribute(.i64).init(pad_shape, std.mem.sliceAsBytes(opts.padding)).as(mlir.Attribute).? }, .{ "padding", mlir.DenseElementsAttribute(.i64).init(pad_shape, opts.padding).as(mlir.Attribute) },
}, },
.location = loc, .location = loc,
}); });
@ -611,8 +611,8 @@ pub fn sort(
.result_type_inference = true, .result_type_inference = true,
.blocks = &.{block}, .blocks = &.{block},
.attributes = &.{ .attributes = &.{
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx.mlirCtx(), dimension).as(mlir.Attribute).? }, .{ "dimension", mlir.IntegerAttribute(.i64).init(ctx.mlirCtx(), dimension).as(mlir.Attribute) },
.{ "is_stable", mlir.BoolAttribute.init(ctx.mlirCtx(), is_stable).as(mlir.Attribute).? }, .{ "is_stable", mlir.BoolAttribute.init(ctx.mlirCtx(), is_stable).as(mlir.Attribute) },
}, },
.location = loc, .location = loc,
}); });

View File

@ -110,7 +110,7 @@ pub const Tensor = struct {
/// ///
/// The shape is derived from the type of the mlir.Value. /// The shape is derived from the type of the mlir.Value.
pub fn fromMlirValue(val: mlir.Value) Tensor { pub fn fromMlirValue(val: mlir.Value) Tensor {
const ranked_tensor = val.getType().as(mlir.RankedTensorType).?; const ranked_tensor = val.getType().as(mlir.RankedTensorType);
const n = ranked_tensor.getRank(); const n = ranked_tensor.getRank();
stdx.debug.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK }); stdx.debug.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK });
@ -281,7 +281,7 @@ pub const Tensor = struct {
const op = dialect.stablehlo.bitcast_convert( const op = dialect.stablehlo.bitcast_convert(
self.getContext().mlirCtx(), self.getContext().mlirCtx(),
self.value(), self.value(),
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type).?, mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type),
loc, loc,
); );
@ -830,7 +830,7 @@ pub const Tensor = struct {
self.value(), self.value(),
other.value(), other.value(),
used_opts, used_opts,
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape).as(mlir.Type).?, mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape).as(mlir.Type),
loc, loc,
); );
@ -1010,7 +1010,7 @@ pub const Tensor = struct {
return self; return self;
} }
const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), to)).as(mlir.Type).?; const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), to)).as(mlir.Type);
const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) }); const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) });
const op = dialect.stablehlo.convert(self.getContext().mlirCtx(), self.value(), res_type, loc); const op = dialect.stablehlo.convert(self.getContext().mlirCtx(), self.value(), res_type, loc);
@ -1520,7 +1520,7 @@ pub const Tensor = struct {
const mlir_ctx = self.getContext().mlirCtx(); const mlir_ctx = self.getContext().mlirCtx();
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "slices={any}", .{slices}); const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "slices={any}", .{slices});
const result_type = mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type).?; const result_type = mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type);
const slice_op = dialect.stablehlo.slice( const slice_op = dialect.stablehlo.slice(
mlir_ctx, mlir_ctx,
self.value(), self.value(),
@ -1785,7 +1785,12 @@ pub const Tensor = struct {
const loc = ctx.location(@src(), "iota({_}, {})", .{ res_shape, a }); const loc = ctx.location(@src(), "iota({_}, {})", .{ res_shape, a });
const mlir_ctx = ctx.mlirCtx(); const mlir_ctx = ctx.mlirCtx();
var op = dialect.stablehlo.iota(mlir_ctx, a, mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).asType(), loc); var op = dialect.stablehlo.iota(
mlir_ctx,
a,
mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type),
loc,
);
return _result(res_shape, op.result(0)); return _result(res_shape, op.result(0));
} }
@ -1857,7 +1862,7 @@ pub const Tensor = struct {
}; };
if (sh.rank() > 0) { if (sh.rank() > 0) {
constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type).?, loc); constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type), loc);
} }
return _result(sh, constant_op.result(0)).convert(val.dtype()); return _result(sh, constant_op.result(0)).convert(val.dtype());
} }
@ -1904,7 +1909,7 @@ pub const Tensor = struct {
return _result(res_shape, self.value()); return _result(res_shape, self.value());
} }
const ctx = self.getContext(); const ctx = self.getContext();
const result_type = mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type).?; const result_type = mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type);
const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ }); const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ });
const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc); const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc);