mlir: rework DenseElementsAttribute to correctly slice inputs and modify .as() to return a concrete value instead of an optional
This commit is contained in:
parent
201f5245c1
commit
aec1d96e6d
@ -1,16 +1,6 @@
|
|||||||
load("@rules_zig//zig:defs.bzl", "zig_library")
|
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||||
load("//bazel:zig.bzl", "zig_cc_test")
|
load("//bazel:zig.bzl", "zig_cc_test")
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "mlirx",
|
|
||||||
srcs = ["mlirx.cc"],
|
|
||||||
hdrs = ["mlirx.h"],
|
|
||||||
includes = ["."],
|
|
||||||
deps = [
|
|
||||||
"@llvm-project//mlir:CAPIIR",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "c",
|
name = "c",
|
||||||
hdrs = ["c.h"],
|
hdrs = ["c.h"],
|
||||||
@ -30,7 +20,7 @@ zig_library(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":c",
|
":c",
|
||||||
":mlirx",
|
"//stdx",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -72,7 +72,7 @@ pub fn cmpi(ctx: mlir.Context, predicate: CmpIPredicate, lhs: mlir.Value, rhs: m
|
|||||||
.operands = &.{ lhs, rhs },
|
.operands = &.{ lhs, rhs },
|
||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute).? },
|
.{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -102,7 +102,7 @@ pub fn cmpf(ctx: mlir.Context, predicate: CmpFPredicate, lhs: mlir.Value, rhs: m
|
|||||||
.operands = &.{ lhs, rhs },
|
.operands = &.{ lhs, rhs },
|
||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute).? },
|
.{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
|
|||||||
@ -14,14 +14,14 @@ pub fn func(
|
|||||||
},
|
},
|
||||||
) mlir.Operation {
|
) mlir.Operation {
|
||||||
var attrs_tuple_buffer = std.BoundedArray(mlir.AttrTuple, 4){};
|
var attrs_tuple_buffer = std.BoundedArray(mlir.AttrTuple, 4){};
|
||||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute).? });
|
attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute) });
|
||||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type).?).as(mlir.Attribute).? });
|
attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type)).as(mlir.Attribute) });
|
||||||
if (args.arg_attrs.len > 0) {
|
if (args.arg_attrs.len > 0) {
|
||||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", mlir.ArrayAttribute.init(ctx, args.arg_attrs).as(mlir.Attribute).? });
|
attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", mlir.ArrayAttribute.init(ctx, args.arg_attrs).as(mlir.Attribute) });
|
||||||
}
|
}
|
||||||
|
|
||||||
if (args.res_attrs.len > 0) {
|
if (args.res_attrs.len > 0) {
|
||||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "res_attrs", mlir.ArrayAttribute.init(ctx, args.res_attrs).as(mlir.Attribute).? });
|
attrs_tuple_buffer.appendAssumeCapacity(.{ "res_attrs", mlir.ArrayAttribute.init(ctx, args.res_attrs).as(mlir.Attribute) });
|
||||||
}
|
}
|
||||||
|
|
||||||
return mlir.Operation.make(ctx, "func.func", .{
|
return mlir.Operation.make(ctx, "func.func", .{
|
||||||
@ -36,7 +36,7 @@ pub fn call(ctx: mlir.Context, name: [:0]const u8, values: []const mlir.Value, r
|
|||||||
.variadic_operands = &.{values},
|
.variadic_operands = &.{values},
|
||||||
.results = results,
|
.results = results,
|
||||||
.verify = true,
|
.verify = true,
|
||||||
.attributes = &.{.{ "callee", mlir.FlatSymbolRefAttribute.init(ctx, name).as(mlir.Attribute).? }},
|
.attributes = &.{.{ "callee", mlir.FlatSymbolRefAttribute.init(ctx, name).as(mlir.Attribute) }},
|
||||||
.location = loc,
|
.location = loc,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@ -98,7 +98,7 @@ pub fn cholesky(ctx: mlir.Context, value: mlir.Value, lower: bool, location: mli
|
|||||||
.operands = &.{value},
|
.operands = &.{value},
|
||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(lower))).as(mlir.Attribute).? },
|
.{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(lower))).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -126,7 +126,7 @@ pub const DotPrecision = union(enum) {
|
|||||||
// When we specify the dot algorithm, we should not specify the precision.
|
// When we specify the dot algorithm, we should not specify the precision.
|
||||||
.algorithm => .DEFAULT,
|
.algorithm => .DEFAULT,
|
||||||
});
|
});
|
||||||
return precision.as(mlir.Attribute).?;
|
return precision.as(mlir.Attribute);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.Type) ?mlir.Attribute {
|
pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.Type) ?mlir.Attribute {
|
||||||
@ -156,14 +156,14 @@ pub const DotAlgorithm = struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, operand_type: mlir.Type) mlir.Attribute {
|
pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, operand_type: mlir.Type) mlir.Attribute {
|
||||||
const tensor_type = operand_type.as(mlir.RankedTensorType) orelse @panic("dot_general expects RankedTensor as input");
|
const tensor_type = operand_type.as(mlir.RankedTensorType);
|
||||||
const elem_type = tensor_type.getElementType();
|
const elem_type = tensor_type.getElementType();
|
||||||
|
|
||||||
return mlir.Attribute.wrap(c.stablehloDotAlgorithmGet(
|
return mlir.Attribute.wrap(c.stablehloDotAlgorithmGet(
|
||||||
ctx.inner(),
|
ctx.inner(),
|
||||||
elem_type.inner(),
|
elem_type.inner(),
|
||||||
elem_type.inner(),
|
elem_type.inner(),
|
||||||
self.accumulation.asType(ctx).?.inner(),
|
self.accumulation.asType(ctx).inner(),
|
||||||
self.component_count,
|
self.component_count,
|
||||||
self.component_count,
|
self.component_count,
|
||||||
self.num_primitive_operations,
|
self.num_primitive_operations,
|
||||||
@ -196,7 +196,7 @@ pub fn dot_general(
|
|||||||
.rhs_batching_dimensions = opts.rhs_batching_dimensions,
|
.rhs_batching_dimensions = opts.rhs_batching_dimensions,
|
||||||
.lhs_contracting_dimensions = opts.lhs_contracting_dimensions,
|
.lhs_contracting_dimensions = opts.lhs_contracting_dimensions,
|
||||||
.rhs_contracting_dimensions = opts.rhs_contracting_dimensions,
|
.rhs_contracting_dimensions = opts.rhs_contracting_dimensions,
|
||||||
}).as(mlir.Attribute).?,
|
}).as(mlir.Attribute),
|
||||||
},
|
},
|
||||||
.{ "precision_config", mlir.ArrayAttribute.init(ctx, &precisions).asAttr() },
|
.{ "precision_config", mlir.ArrayAttribute.init(ctx, &precisions).asAttr() },
|
||||||
// keep algorithm as the last attribute so we can omit it when it's not set.
|
// keep algorithm as the last attribute so we can omit it when it's not set.
|
||||||
@ -219,12 +219,12 @@ pub fn constant(
|
|||||||
location: mlir.Location,
|
location: mlir.Location,
|
||||||
) mlir.Operation {
|
) mlir.Operation {
|
||||||
const attribute = switch (elem_type) {
|
const attribute = switch (elem_type) {
|
||||||
inline else => |dt| mlir.DenseIntOrFPElementsAttribute(dt).init(result_type.as(mlir.Type).?, raw_bytes).as(mlir.Attribute).?,
|
inline else => |dt| mlir.DenseElementsAttribute(dt).init(result_type.as(mlir.Type), raw_bytes).as(mlir.Attribute),
|
||||||
};
|
};
|
||||||
|
|
||||||
return mlir.Operation.make(ctx, "stablehlo.constant", .{
|
return mlir.Operation.make(ctx, "stablehlo.constant", .{
|
||||||
.operands = &.{},
|
.operands = &.{},
|
||||||
.results = &.{result_type.as(mlir.Type).?},
|
.results = &.{result_type.as(mlir.Type)},
|
||||||
.attributes = &.{.{ "value", attribute }},
|
.attributes = &.{.{ "value", attribute }},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -243,7 +243,7 @@ pub fn broadcast_in_dim(ctx: mlir.Context, operand: mlir.Value, dims: []const i6
|
|||||||
.operands = &.{operand},
|
.operands = &.{operand},
|
||||||
.results = &.{result_type},
|
.results = &.{result_type},
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "broadcast_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dims).as(mlir.Attribute).? },
|
.{ "broadcast_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dims).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -254,7 +254,7 @@ pub fn transpose(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, l
|
|||||||
.operands = &.{value},
|
.operands = &.{value},
|
||||||
.results = &.{result_type},
|
.results = &.{result_type},
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "permutation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.permutation).as(mlir.Attribute).? },
|
.{ "permutation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.permutation).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -265,9 +265,9 @@ pub fn slice(ctx: mlir.Context, operand: mlir.Value, start_indices: []const i64,
|
|||||||
.operands = &.{operand},
|
.operands = &.{operand},
|
||||||
.results = &.{result_type},
|
.results = &.{result_type},
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "start_indices", mlir.DenseArrayAttribute(.i64).init(ctx, start_indices).as(mlir.Attribute).? },
|
.{ "start_indices", mlir.DenseArrayAttribute(.i64).init(ctx, start_indices).as(mlir.Attribute) },
|
||||||
.{ "limit_indices", mlir.DenseArrayAttribute(.i64).init(ctx, limit_indices).as(mlir.Attribute).? },
|
.{ "limit_indices", mlir.DenseArrayAttribute(.i64).init(ctx, limit_indices).as(mlir.Attribute) },
|
||||||
.{ "strides", mlir.DenseArrayAttribute(.i64).init(ctx, strides).as(mlir.Attribute).? },
|
.{ "strides", mlir.DenseArrayAttribute(.i64).init(ctx, strides).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -278,7 +278,7 @@ pub fn concatenate(ctx: mlir.Context, inputs: []const mlir.Value, dimension: i64
|
|||||||
.operands = inputs,
|
.operands = inputs,
|
||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute).? },
|
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -287,7 +287,7 @@ pub fn concatenate(ctx: mlir.Context, inputs: []const mlir.Value, dimension: i64
|
|||||||
pub fn reshape(ctx: mlir.Context, value: mlir.Value, result_type: mlir.RankedTensorType, location: mlir.Location) mlir.Operation {
|
pub fn reshape(ctx: mlir.Context, value: mlir.Value, result_type: mlir.RankedTensorType, location: mlir.Location) mlir.Operation {
|
||||||
return mlir.Operation.make(ctx, "stablehlo.reshape", .{
|
return mlir.Operation.make(ctx, "stablehlo.reshape", .{
|
||||||
.operands = &.{value},
|
.operands = &.{value},
|
||||||
.results = &.{result_type.as(mlir.Type).?},
|
.results = &.{result_type.as(mlir.Type)},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -331,9 +331,9 @@ pub fn gather(
|
|||||||
args.start_indices_batching_dims,
|
args.start_indices_batching_dims,
|
||||||
args.start_index_map,
|
args.start_index_map,
|
||||||
args.index_vector_dim,
|
args.index_vector_dim,
|
||||||
).as(mlir.Attribute).? },
|
).as(mlir.Attribute) },
|
||||||
.{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, slice_sizes).as(mlir.Attribute).? },
|
.{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, slice_sizes).as(mlir.Attribute) },
|
||||||
.{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute).? },
|
.{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
},
|
},
|
||||||
@ -393,8 +393,8 @@ pub fn scatter(
|
|||||||
.blocks = &.{update_block},
|
.blocks = &.{update_block},
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "scatter_dimension_numbers", args.getScatterDimensionNumbers(ctx) },
|
.{ "scatter_dimension_numbers", args.getScatterDimensionNumbers(ctx) },
|
||||||
.{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute).? },
|
.{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute) },
|
||||||
.{ "unique_indices", mlir.BoolAttribute.init(ctx, args.unique_indices).as(mlir.Attribute).? },
|
.{ "unique_indices", mlir.BoolAttribute.init(ctx, args.unique_indices).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.location = location,
|
.location = location,
|
||||||
@ -407,7 +407,7 @@ pub fn iota(ctx: mlir.Context, dimension: i64, result_type: mlir.Type, location:
|
|||||||
.operands = &.{},
|
.operands = &.{},
|
||||||
.results = &.{result_type},
|
.results = &.{result_type},
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "iota_dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute).? },
|
.{ "iota_dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -419,7 +419,7 @@ pub fn reverse(ctx: mlir.Context, operand: mlir.Value, dimensions: []const i64,
|
|||||||
.operands = &.{operand},
|
.operands = &.{operand},
|
||||||
.results = &.{result_type},
|
.results = &.{result_type},
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute).? },
|
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -430,8 +430,8 @@ pub fn compare(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, comparison_d
|
|||||||
.operands = &.{ lhs, rhs },
|
.operands = &.{ lhs, rhs },
|
||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "comparison_direction", comparison_direction.as(mlir.Attribute).? },
|
.{ "comparison_direction", comparison_direction.as(mlir.Attribute) },
|
||||||
.{ "compare_type", compare_type.as(mlir.Attribute).? },
|
.{ "compare_type", compare_type.as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -452,7 +452,7 @@ pub fn reduce(
|
|||||||
const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0..block_n_args];
|
const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0..block_n_args];
|
||||||
var reduce_elem_types: [MaxBlockArguments]mlir.Type = undefined;
|
var reduce_elem_types: [MaxBlockArguments]mlir.Type = undefined;
|
||||||
for (inputs, 0..) |input, i| {
|
for (inputs, 0..) |input, i| {
|
||||||
const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type).?;
|
const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type);
|
||||||
reduce_elem_types[i] = arg_type;
|
reduce_elem_types[i] = arg_type;
|
||||||
reduce_elem_types[inputs.len + i] = arg_type;
|
reduce_elem_types[inputs.len + i] = arg_type;
|
||||||
}
|
}
|
||||||
@ -474,7 +474,7 @@ pub fn reduce(
|
|||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.block = block,
|
.block = block,
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute).? },
|
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -494,7 +494,7 @@ pub fn sort(
|
|||||||
const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0 .. inputs.len * 2];
|
const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0 .. inputs.len * 2];
|
||||||
var sort_elem_types: [MaxBlockArguments]mlir.Type = undefined;
|
var sort_elem_types: [MaxBlockArguments]mlir.Type = undefined;
|
||||||
for (inputs, 0..) |input, i| {
|
for (inputs, 0..) |input, i| {
|
||||||
const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type).?;
|
const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type);
|
||||||
sort_elem_types[i * 2] = arg_type;
|
sort_elem_types[i * 2] = arg_type;
|
||||||
sort_elem_types[i * 2 + 1] = arg_type;
|
sort_elem_types[i * 2 + 1] = arg_type;
|
||||||
}
|
}
|
||||||
@ -511,8 +511,8 @@ pub fn sort(
|
|||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.block = block,
|
.block = block,
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute).? },
|
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) },
|
||||||
.{ "is_stable", mlir.BoolAttribute.init(ctx, is_stable).as(mlir.Attribute).? },
|
.{ "is_stable", mlir.BoolAttribute.init(ctx, is_stable).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -523,7 +523,7 @@ pub fn dynamicSlice(ctx: mlir.Context, operand: mlir.Value, new_dims: []const i6
|
|||||||
.variadic_operands = &.{ &.{operand}, start_indices },
|
.variadic_operands = &.{ &.{operand}, start_indices },
|
||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, new_dims).as(mlir.Attribute).? },
|
.{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, new_dims).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -556,9 +556,9 @@ pub fn pad(ctx: mlir.Context, value: mlir.Value, padding_value: mlir.Value, opts
|
|||||||
.operands = &.{ value, padding_value },
|
.operands = &.{ value, padding_value },
|
||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "edge_padding_low", mlir.DenseArrayAttribute(.i64).init(ctx, opts.low).as(mlir.Attribute).? },
|
.{ "edge_padding_low", mlir.DenseArrayAttribute(.i64).init(ctx, opts.low).as(mlir.Attribute) },
|
||||||
.{ "edge_padding_high", mlir.DenseArrayAttribute(.i64).init(ctx, opts.high).as(mlir.Attribute).? },
|
.{ "edge_padding_high", mlir.DenseArrayAttribute(.i64).init(ctx, opts.high).as(mlir.Attribute) },
|
||||||
.{ "interior_padding", mlir.DenseArrayAttribute(.i64).init(ctx, opts.interior).as(mlir.Attribute).? },
|
.{ "interior_padding", mlir.DenseArrayAttribute(.i64).init(ctx, opts.interior).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -576,10 +576,10 @@ pub fn triangular_solve(ctx: mlir.Context, value: mlir.Value, other: mlir.Value,
|
|||||||
.operands = &.{ value, other },
|
.operands = &.{ value, other },
|
||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "left_side", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.left_side))).as(mlir.Attribute).? },
|
.{ "left_side", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.left_side))).as(mlir.Attribute) },
|
||||||
.{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.lower))).as(mlir.Attribute).? },
|
.{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.lower))).as(mlir.Attribute) },
|
||||||
.{ "unit_diagonal", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.unit_diagonal))).as(mlir.Attribute).? },
|
.{ "unit_diagonal", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.unit_diagonal))).as(mlir.Attribute) },
|
||||||
.{ "transpose_a", Transpose.init(ctx, opts.transpose_a).as(mlir.Attribute).? },
|
.{ "transpose_a", Transpose.init(ctx, opts.transpose_a).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -595,8 +595,8 @@ pub fn fft(ctx: mlir.Context, value: mlir.Value, location: mlir.Location, opts:
|
|||||||
.operands = &.{value},
|
.operands = &.{value},
|
||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "fft_type", FftType.init(ctx, opts.kind).as(mlir.Attribute).? },
|
.{ "fft_type", FftType.init(ctx, opts.kind).as(mlir.Attribute) },
|
||||||
.{ "fft_length", mlir.DenseArrayAttribute(.i64).init(ctx, opts.length).as(mlir.Attribute).? },
|
.{ "fft_length", mlir.DenseArrayAttribute(.i64).init(ctx, opts.length).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -607,7 +607,7 @@ pub fn rng(ctx: mlir.Context, a: mlir.Value, b: mlir.Value, shape: mlir.Value, r
|
|||||||
.operands = &.{ a, b, shape },
|
.operands = &.{ a, b, shape },
|
||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "rng_distribution", RngDistribution.init(ctx, rng_distribution).as(mlir.Attribute).? },
|
.{ "rng_distribution", RngDistribution.init(ctx, rng_distribution).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -618,7 +618,7 @@ pub fn rng_bit_generator(ctx: mlir.Context, rng_algorithm: RngAlgorithm.Type, in
|
|||||||
.operands = &.{initial_state},
|
.operands = &.{initial_state},
|
||||||
.results = &.{ res_state_type, res_type },
|
.results = &.{ res_state_type, res_type },
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "rng_algorithm", RngAlgorithm.init(ctx, rng_algorithm).as(mlir.Attribute).? },
|
.{ "rng_algorithm", RngAlgorithm.init(ctx, rng_algorithm).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -629,8 +629,8 @@ pub fn reduce_precision(ctx: mlir.Context, value: mlir.Value, exponent_bits: i32
|
|||||||
.operands = &.{value},
|
.operands = &.{value},
|
||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "exponent_bits", mlir.IntegerAttribute(.i32).init(ctx, exponent_bits).as(mlir.Attribute).? },
|
.{ "exponent_bits", mlir.IntegerAttribute(.i32).init(ctx, exponent_bits).as(mlir.Attribute) },
|
||||||
.{ "mantissa_bits", mlir.IntegerAttribute(.i32).init(ctx, mantissa_bits).as(mlir.Attribute).? },
|
.{ "mantissa_bits", mlir.IntegerAttribute(.i32).init(ctx, mantissa_bits).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -657,7 +657,7 @@ pub fn get_tuple_element(ctx: mlir.Context, tuple_value: mlir.Value, index: i64,
|
|||||||
.operands = &.{tuple_value},
|
.operands = &.{tuple_value},
|
||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "index", mlir.IntegerAttribute(.i32).init(ctx, index).as(mlir.Attribute).? },
|
.{ "index", mlir.IntegerAttribute(.i32).init(ctx, index).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -694,23 +694,23 @@ pub fn convolution(
|
|||||||
) mlir.Operation {
|
) mlir.Operation {
|
||||||
var max_precisions: [2]mlir.Attribute = undefined;
|
var max_precisions: [2]mlir.Attribute = undefined;
|
||||||
for (opts.precision_config, 0..) |p, i| {
|
for (opts.precision_config, 0..) |p, i| {
|
||||||
max_precisions[i] = PrecisionAttribute.init(ctx, p).as(mlir.Attribute).?;
|
max_precisions[i] = PrecisionAttribute.init(ctx, p).as(mlir.Attribute);
|
||||||
}
|
}
|
||||||
var window_reversal: [3]i32 = undefined;
|
var window_reversal: [3]i32 = undefined;
|
||||||
for (opts.window_reversal, 0..) |w, i| {
|
for (opts.window_reversal, 0..) |w, i| {
|
||||||
window_reversal[i] = @intCast(@intFromBool(w));
|
window_reversal[i] = @intCast(@intFromBool(w));
|
||||||
}
|
}
|
||||||
const pad_type = mlir.IntegerType(.i64).init(ctx).as(mlir.Type).?;
|
const pad_type = mlir.IntegerType(.i64).init(ctx).as(mlir.Type);
|
||||||
const pad_shape = mlir.RankedTensorType.init(opts.pad_shape, pad_type).as(mlir.Type).?;
|
const pad_shape = mlir.RankedTensorType.init(opts.pad_shape, pad_type).as(mlir.Type);
|
||||||
return mlir.Operation.make(ctx, "stablehlo.convolution", .{
|
return mlir.Operation.make(ctx, "stablehlo.convolution", .{
|
||||||
.operands = &.{ lhs, rhs },
|
.operands = &.{ lhs, rhs },
|
||||||
.results = &.{res_type},
|
.results = &.{res_type},
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_strides).as(mlir.Attribute).? },
|
.{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_strides).as(mlir.Attribute) },
|
||||||
.{ "padding", mlir.DenseIntOrFPElementsAttribute(.i64).init(pad_shape, std.mem.sliceAsBytes(opts.pad_value)).as(mlir.Attribute).? },
|
.{ "padding", mlir.DenseElementsAttribute(.i64).init(pad_shape, opts.pad_value).as(mlir.Attribute) },
|
||||||
.{ "lhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.lhs_dilation).as(mlir.Attribute).? },
|
.{ "lhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.lhs_dilation).as(mlir.Attribute) },
|
||||||
.{ "rhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.rhs_dilation).as(mlir.Attribute).? },
|
.{ "rhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.rhs_dilation).as(mlir.Attribute) },
|
||||||
.{ "window_reversal", mlir.DenseArrayAttribute(.bool).init(ctx, window_reversal[0..opts.window_reversal.len]).as(mlir.Attribute).? },
|
.{ "window_reversal", mlir.DenseArrayAttribute(.bool).init(ctx, window_reversal[0..opts.window_reversal.len]).as(mlir.Attribute) },
|
||||||
.{
|
.{
|
||||||
"dimension_numbers", ConvDimensionNumbersAttribute.init(ctx, .{
|
"dimension_numbers", ConvDimensionNumbersAttribute.init(ctx, .{
|
||||||
.input_batch_dimension = opts.input_batch_dimension,
|
.input_batch_dimension = opts.input_batch_dimension,
|
||||||
@ -722,11 +722,11 @@ pub fn convolution(
|
|||||||
.output_batch_dimension = opts.output_batch_dimension,
|
.output_batch_dimension = opts.output_batch_dimension,
|
||||||
.output_feature_dimension = opts.output_feature_dimension,
|
.output_feature_dimension = opts.output_feature_dimension,
|
||||||
.output_spatial_dimensions = opts.output_spatial_dimensions,
|
.output_spatial_dimensions = opts.output_spatial_dimensions,
|
||||||
}).as(mlir.Attribute).?,
|
}).as(mlir.Attribute),
|
||||||
},
|
},
|
||||||
.{ "feature_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.feature_group_count).as(mlir.Attribute).? },
|
.{ "feature_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.feature_group_count).as(mlir.Attribute) },
|
||||||
.{ "batch_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.batch_group_count).as(mlir.Attribute).? },
|
.{ "batch_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.batch_group_count).as(mlir.Attribute) },
|
||||||
.{ "precision_config", mlir.ArrayAttribute.init(ctx, &max_precisions).as(mlir.Attribute).? },
|
.{ "precision_config", mlir.ArrayAttribute.init(ctx, &max_precisions).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
@ -747,18 +747,18 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
|||||||
|
|
||||||
const output_operand_aliases = allocator.alloc(mlir.Attribute, opts.output_operand_aliases.len) catch unreachable;
|
const output_operand_aliases = allocator.alloc(mlir.Attribute, opts.output_operand_aliases.len) catch unreachable;
|
||||||
for (opts.output_operand_aliases, 0..) |alias, i| {
|
for (opts.output_operand_aliases, 0..) |alias, i| {
|
||||||
output_operand_aliases[i] = OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute).?;
|
output_operand_aliases[i] = OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute);
|
||||||
}
|
}
|
||||||
|
|
||||||
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
|
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
|
||||||
.operands = inputs,
|
.operands = inputs,
|
||||||
.results = res_types,
|
.results = res_types,
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, opts.api_version).as(mlir.Attribute).? },
|
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, opts.api_version).as(mlir.Attribute) },
|
||||||
.{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute).? },
|
.{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) },
|
||||||
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute).? },
|
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) },
|
||||||
.{ "backend_config", mlir.StringAttribute.init(ctx, opts.backend_config).as(mlir.Attribute).? },
|
.{ "backend_config", mlir.StringAttribute.init(ctx, opts.backend_config).as(mlir.Attribute) },
|
||||||
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases).as(mlir.Attribute).? },
|
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = location,
|
.location = location,
|
||||||
});
|
});
|
||||||
|
|||||||
124
mlir/mlir.zig
124
mlir/mlir.zig
@ -1,5 +1,6 @@
|
|||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
const stdx = @import("stdx");
|
||||||
const log = std.log.scoped(.mlir);
|
const log = std.log.scoped(.mlir);
|
||||||
|
|
||||||
const c = @import("c");
|
const c = @import("c");
|
||||||
@ -95,9 +96,10 @@ pub fn MlirHelpers(comptime OuterT: type, comptime methods: MlirHelpersMethods(O
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub inline fn as(self: OuterT, comptime OtherT: type) ?OtherT {
|
pub inline fn as(self: OuterT, comptime OtherT: type) OtherT {
|
||||||
if (OtherT.Methods.is_a_fn) |is_a_fn| {
|
if (OtherT.Methods.is_a_fn) |is_a_fn| {
|
||||||
return if (is_a_fn(self.inner())) OtherT.wrap(self.inner()) else null;
|
stdx.debug.assert(is_a_fn(self.inner()), "Wrongly tried to cast {} into {}", .{ OuterT, OtherT });
|
||||||
|
return OtherT.wrap(self.inner());
|
||||||
}
|
}
|
||||||
// if the other type doesn't have an is_a_fn, try.
|
// if the other type doesn't have an is_a_fn, try.
|
||||||
return OtherT.wrap(self.inner());
|
return OtherT.wrap(self.inner());
|
||||||
@ -425,7 +427,7 @@ pub const BoolAttribute = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn asAttr(self: Self) Attribute {
|
pub fn asAttr(self: Self) Attribute {
|
||||||
return self.as(Attribute).?;
|
return self.as(Attribute);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -446,7 +448,7 @@ pub const TypeAttribute = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn asAttr(self: TypeAttribute) Attribute {
|
pub fn asAttr(self: TypeAttribute) Attribute {
|
||||||
return self.as(Attribute).?;
|
return self.as(Attribute);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -591,10 +593,6 @@ pub fn DenseArrayAttribute(comptime dt: DenseArrayTypes) type {
|
|||||||
.i64 => .i64,
|
.i64 => .i64,
|
||||||
else => @compileError("DenseArrayAttribute: unreachable"),
|
else => @compileError("DenseArrayAttribute: unreachable"),
|
||||||
});
|
});
|
||||||
|
|
||||||
pub fn toElements(self: Attr) DenseArray {
|
|
||||||
return DenseArray.wrap(c.mlirDenseArrayToElements(self.inner()));
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
else => struct {},
|
else => struct {},
|
||||||
};
|
};
|
||||||
@ -615,28 +613,32 @@ pub const DenseElementsAttributeTypes = enum {
|
|||||||
f16,
|
f16,
|
||||||
f32,
|
f32,
|
||||||
f64,
|
f64,
|
||||||
|
index,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn DenseIntOrFPElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
|
pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
|
||||||
const ZigInDataType, const ZigOutDataType, const initFn, const getValue = switch (dt) {
|
const ZigType = switch (dt) {
|
||||||
.bool => .{ bool, bool, c.mlirDenseElementsAttrBoolGet, c.mlirDenseElementsAttrGetBoolValue },
|
.bool => bool,
|
||||||
.i8 => .{ i8, i8, c.mlirDenseElementsAttrInt8Get, c.mlirDenseElementsAttrGetInt8Value },
|
.i8 => i8,
|
||||||
.i16 => .{ i16, i16, c.mlirDenseElementsAttrInt16Get, c.mlirDenseElementsAttrGetInt16Value },
|
.i16 => i16,
|
||||||
.i32 => .{ i32, i32, c.mlirDenseElementsAttrInt32Get, c.mlirDenseElementsAttrGetInt32Value },
|
.i32 => i32,
|
||||||
.i64 => .{ i64, i64, c.mlirDenseElementsAttrInt64Get, c.mlirDenseElementsAttrGetInt64Value },
|
.i64 => i64,
|
||||||
.u8 => .{ u8, u8, c.mlirDenseElementsAttrUInt8Get, c.mlirDenseElementsAttrGetUInt8Value },
|
.u8 => u8,
|
||||||
.u16 => .{ u16, u16, c.mlirDenseElementsAttrUInt16Get, c.mlirDenseElementsAttrGetUInt16Value },
|
.u16 => u16,
|
||||||
.u32 => .{ u32, u32, c.mlirDenseElementsAttrUInt32Get, c.mlirDenseElementsAttrGetUInt32Value },
|
.u32 => u32,
|
||||||
.u64 => .{ u64, u64, c.mlirDenseElementsAttrUInt64Get, c.mlirDenseElementsAttrGetUInt64Value },
|
.u64 => u64,
|
||||||
.bf16 => .{ u16, f32, c.mlirDenseElementsAttrBFloat16Get, c.mlirDenseElementsAttrGetFloatValue },
|
.bf16 => u16,
|
||||||
.f16 => .{ f16, f32, c.mlirDenseElementsAttrFloat16Get, c.mlirDenseElementsAttrGetFloatValue },
|
.f16 => f16,
|
||||||
.f32 => .{ f32, f32, c.mlirDenseElementsAttrFloatGet, c.mlirDenseElementsAttrGetFloatValue },
|
.f32 => f32,
|
||||||
.f64 => .{ f64, f64, c.mlirDenseElementsAttrDoubleGet, c.mlirDenseElementsAttrGetDoubleValue },
|
.f64 => f64,
|
||||||
|
.index => usize,
|
||||||
};
|
};
|
||||||
|
|
||||||
return struct {
|
return struct {
|
||||||
_inner: c.MlirAttribute,
|
_inner: c.MlirAttribute,
|
||||||
|
|
||||||
const Attr = @This();
|
const Attr = @This();
|
||||||
|
|
||||||
pub usingnamespace MlirHelpers(Attr, .{
|
pub usingnamespace MlirHelpers(Attr, .{
|
||||||
.is_a_fn = c.mlirAttributeIsADenseElements,
|
.is_a_fn = c.mlirAttributeIsADenseElements,
|
||||||
.is_null_fn = c.mlirAttributeIsNull,
|
.is_null_fn = c.mlirAttributeIsNull,
|
||||||
@ -644,13 +646,29 @@ pub fn DenseIntOrFPElementsAttribute(comptime dt: DenseElementsAttributeTypes) t
|
|||||||
.equal_fn = c.mlirAttributeEqual,
|
.equal_fn = c.mlirAttributeEqual,
|
||||||
});
|
});
|
||||||
|
|
||||||
pub fn init(shaped_type: Type, raw_values: []const u8) Attr {
|
pub fn init(shaped_type: Type, slice: anytype) Attr {
|
||||||
const values = std.mem.bytesAsSlice(ZigInDataType, raw_values);
|
const bytes = std.mem.sliceAsBytes(slice);
|
||||||
return Attr.wrap(initFn(shaped_type.inner(), @intCast(values.len), @ptrCast(@alignCast(values.ptr))));
|
const v = Attr.wrapOr(
|
||||||
|
c.mlirDenseElementsAttrRawBufferGet(
|
||||||
|
shaped_type.inner(),
|
||||||
|
@intCast(bytes.len),
|
||||||
|
@ptrCast(bytes.ptr),
|
||||||
|
),
|
||||||
|
) orelse unreachable;
|
||||||
|
return v;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(self: Attr, pos: usize) ZigOutDataType {
|
pub fn len(self: Attr) usize {
|
||||||
return getValue(self.inner(), @intCast(pos));
|
return @intCast(c.mlirElementsAttrGetNumElements(self.inner()));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn constSlice(self: Attr) []const ZigType {
|
||||||
|
const ptr: [*]const ZigType = @constCast(@ptrCast(@alignCast(c.mlirDenseElementsAttrGetRawData(self.inner()) orelse unreachable)));
|
||||||
|
return ptr[0..self.len()];
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn data(self: Attr) []const u8 {
|
||||||
|
return std.mem.sliceAsBytes(self.constSlice());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -1338,12 +1356,9 @@ pub const FloatTypes = enum {
|
|||||||
f32,
|
f32,
|
||||||
f64,
|
f64,
|
||||||
|
|
||||||
unknown,
|
pub fn asType(self: FloatTypes, ctx: Context) Type {
|
||||||
|
|
||||||
pub fn asType(self: FloatTypes, ctx: Context) ?Type {
|
|
||||||
return switch (self) {
|
return switch (self) {
|
||||||
.unknown => null,
|
inline else => |ft| FloatType(ft).init(ctx).as(Type),
|
||||||
inline else => |ft| FloatType(ft).init(ctx).asType(),
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -1359,48 +1374,23 @@ pub fn FloatType(comptime ft: FloatTypes) type {
|
|||||||
.f16 => .{ c.mlirTypeIsAF16, c.mlirF16TypeGet },
|
.f16 => .{ c.mlirTypeIsAF16, c.mlirF16TypeGet },
|
||||||
.f32 => .{ c.mlirTypeIsAF32, c.mlirF32TypeGet },
|
.f32 => .{ c.mlirTypeIsAF32, c.mlirF32TypeGet },
|
||||||
.f64 => .{ c.mlirTypeIsAF64, c.mlirF64TypeGet },
|
.f64 => .{ c.mlirTypeIsAF64, c.mlirF64TypeGet },
|
||||||
.unknown => .{ null, null },
|
|
||||||
};
|
};
|
||||||
|
|
||||||
return struct {
|
return struct {
|
||||||
_inner: c.MlirType,
|
_inner: c.MlirType,
|
||||||
const Float = @This();
|
|
||||||
pub usingnamespace MlirHelpers(Float, .{
|
const Self = @This();
|
||||||
.is_a_fn = switch (ft) {
|
|
||||||
.unknown => typeIsAUnknownFloat,
|
pub usingnamespace MlirHelpers(Self, .{
|
||||||
else => Config[0],
|
.is_a_fn = Config[0],
|
||||||
},
|
|
||||||
.is_null_fn = c.mlirTypeIsNull,
|
.is_null_fn = c.mlirTypeIsNull,
|
||||||
.dump_fn = c.mlirTypeDump,
|
.dump_fn = c.mlirTypeDump,
|
||||||
.equal_fn = c.mlirTypeEqual,
|
.equal_fn = c.mlirTypeEqual,
|
||||||
});
|
});
|
||||||
|
|
||||||
pub usingnamespace if (ft != .unknown) struct {
|
pub fn init(ctx: Context) Self {
|
||||||
pub const FloatTypeType = ft;
|
|
||||||
|
|
||||||
pub fn init(ctx: Context) Float {
|
|
||||||
const type_get = Config[1];
|
const type_get = Config[1];
|
||||||
return Float.wrap(type_get(ctx.inner()));
|
return Self.wrap(type_get(ctx.inner()));
|
||||||
}
|
|
||||||
} else struct {};
|
|
||||||
|
|
||||||
fn typeIsAUnknownFloat(typ: c.MlirType) callconv(.C) bool {
|
|
||||||
const is_a_fns = .{
|
|
||||||
c.mlirTypeIsABF16,
|
|
||||||
c.mlirTypeIsAF16,
|
|
||||||
c.mlirTypeIsAF32,
|
|
||||||
c.mlirTypeIsF64,
|
|
||||||
};
|
|
||||||
inline for (is_a_fns) |is_a_fn| {
|
|
||||||
if (is_a_fn(typ)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn asType(self: Float) Type {
|
|
||||||
return self.as(Type).?;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -1545,10 +1535,6 @@ pub const RankedTensorType = struct {
|
|||||||
pub fn getDimension(self: RankedTensorType, dim: usize) i64 {
|
pub fn getDimension(self: RankedTensorType, dim: usize) i64 {
|
||||||
return c.mlirShapedTypeGetDimSize(self.inner(), @intCast(dim));
|
return c.mlirShapedTypeGetDimSize(self.inner(), @intCast(dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn asType(self: RankedTensorType) Type {
|
|
||||||
return self.as(Type).?;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Dialect = struct {
|
pub const Dialect = struct {
|
||||||
|
|||||||
@ -1,27 +0,0 @@
|
|||||||
#include "mlir/IR/Attributes.h"
|
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
|
||||||
#include "mlir/CAPI/IR.h"
|
|
||||||
#include "mlir/CAPI/Support.h"
|
|
||||||
#include "mlirx.h"
|
|
||||||
|
|
||||||
namespace mlirx {
|
|
||||||
|
|
||||||
static mlir::Attribute ArrayToElements(mlir::Attribute attr) {
|
|
||||||
if (auto array = attr.dyn_cast<mlir::DenseI64ArrayAttr>()) {
|
|
||||||
return mlir::DenseIntElementsAttr::get(
|
|
||||||
mlir::RankedTensorType::get(array.size(), array.getElementType()),
|
|
||||||
array.asArrayRef());
|
|
||||||
}
|
|
||||||
if (auto array = attr.dyn_cast<mlir::DenseBoolArrayAttr>()) {
|
|
||||||
return mlir::DenseIntElementsAttr::get(
|
|
||||||
mlir::RankedTensorType::get(array.size(), array.getElementType()),
|
|
||||||
array.asArrayRef());
|
|
||||||
}
|
|
||||||
return attr;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirAttribute mlirDenseArrayToElements(MlirAttribute attr) {
|
|
||||||
return wrap(mlirx::ArrayToElements(unwrap(attr)));
|
|
||||||
}
|
|
||||||
16
mlir/mlirx.h
16
mlir/mlirx.h
@ -1,16 +0,0 @@
|
|||||||
#ifndef MLIRX_CC_H
|
|
||||||
#define MLIRX_CC_H
|
|
||||||
|
|
||||||
#include "mlir-c/IR.h"
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseArrayToElements(MlirAttribute attr);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#endif // MLIRX_CC_H
|
|
||||||
110
zml/mlir.zig
110
zml/mlir.zig
@ -15,7 +15,7 @@ pub usingnamespace @import("mlir");
|
|||||||
|
|
||||||
pub const ext = struct {
|
pub const ext = struct {
|
||||||
pub fn mlirType(ctx: mlir.Context, sh: Shape) mlir.Type {
|
pub fn mlirType(ctx: mlir.Context, sh: Shape) mlir.Type {
|
||||||
return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype())).as(mlir.Type).?;
|
return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype())).as(mlir.Type);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn denseElementAttrType(dt: dtype.DataType) ?mlir.DenseElementsAttributeTypes {
|
pub fn denseElementAttrType(dt: dtype.DataType) ?mlir.DenseElementsAttributeTypes {
|
||||||
@ -38,21 +38,21 @@ pub const ext = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn denseElementsAttr(dt: dtype.DataType, _: usize, bytes: []const u8, ranked_type: mlir.RankedTensorType) mlir.Attribute {
|
pub fn denseElementsAttr(dt: dtype.DataType, _: usize, bytes: []const u8, ranked_type: mlir.RankedTensorType) mlir.Attribute {
|
||||||
const ranked_type_ = ranked_type.as(mlir.Type).?;
|
const ranked_type_ = ranked_type.as(mlir.Type);
|
||||||
return switch (dt) {
|
return switch (dt) {
|
||||||
.bool => mlir.DenseIntOrFPElementsAttribute(.bool).init(ranked_type_, bytes).as(mlir.Attribute).?,
|
.bool => mlir.DenseElementsAttribute(.bool).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||||
.i8 => mlir.DenseIntOrFPElementsAttribute(.i8).init(ranked_type_, bytes).as(mlir.Attribute).?,
|
.i8 => mlir.DenseElementsAttribute(.i8).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||||
.i16 => mlir.DenseIntOrFPElementsAttribute(.i16).init(ranked_type_, bytes).as(mlir.Attribute).?,
|
.i16 => mlir.DenseElementsAttribute(.i16).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||||
.i32 => mlir.DenseIntOrFPElementsAttribute(.i32).init(ranked_type_, bytes).as(mlir.Attribute).?,
|
.i32 => mlir.DenseElementsAttribute(.i32).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||||
.i64 => mlir.DenseIntOrFPElementsAttribute(.i64).init(ranked_type_, bytes).as(mlir.Attribute).?,
|
.i64 => mlir.DenseElementsAttribute(.i64).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||||
.u8 => mlir.DenseIntOrFPElementsAttribute(.u8).init(ranked_type_, bytes).as(mlir.Attribute).?,
|
.u8 => mlir.DenseElementsAttribute(.u8).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||||
.u16 => mlir.DenseIntOrFPElementsAttribute(.u16).init(ranked_type_, bytes).as(mlir.Attribute).?,
|
.u16 => mlir.DenseElementsAttribute(.u16).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||||
.u32 => mlir.DenseIntOrFPElementsAttribute(.u32).init(ranked_type_, bytes).as(mlir.Attribute).?,
|
.u32 => mlir.DenseElementsAttribute(.u32).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||||
.u64 => mlir.DenseIntOrFPElementsAttribute(.u64).init(ranked_type_, bytes).as(mlir.Attribute).?,
|
.u64 => mlir.DenseElementsAttribute(.u64).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||||
.bf16 => mlir.DenseIntOrFPElementsAttribute(.bf16).init(ranked_type_, bytes).as(mlir.Attribute).?,
|
.bf16 => mlir.DenseElementsAttribute(.bf16).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||||
.f16 => mlir.DenseIntOrFPElementsAttribute(.f16).init(ranked_type_, bytes).as(mlir.Attribute).?,
|
.f16 => mlir.DenseElementsAttribute(.f16).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||||
.f32 => mlir.DenseIntOrFPElementsAttribute(.f32).init(ranked_type_, bytes).as(mlir.Attribute).?,
|
.f32 => mlir.DenseElementsAttribute(.f32).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||||
.f64 => mlir.DenseIntOrFPElementsAttribute(.f64).init(ranked_type_, bytes).as(mlir.Attribute).?,
|
.f64 => mlir.DenseElementsAttribute(.f64).init(ranked_type_, bytes).as(mlir.Attribute),
|
||||||
inline else => |tag| @panic("Unsupported data type: " ++ @tagName(tag)),
|
inline else => |tag| @panic("Unsupported data type: " ++ @tagName(tag)),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -66,28 +66,28 @@ pub const ext = struct {
|
|||||||
pub const Type = struct {
|
pub const Type = struct {
|
||||||
pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type {
|
pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type {
|
||||||
return switch (dt) {
|
return switch (dt) {
|
||||||
.bool => mlir.IntegerType(.i1).init(ctx).as(mlir.Type).?,
|
.bool => mlir.IntegerType(.i1).init(ctx).as(mlir.Type),
|
||||||
.f8e4m3b11fnuz => mlir.FloatType(.f8e4m3b11fnuz).init(ctx).as(mlir.Type).?,
|
.f8e4m3b11fnuz => mlir.FloatType(.f8e4m3b11fnuz).init(ctx).as(mlir.Type),
|
||||||
.f8e4m3fn => mlir.FloatType(.f8e4m3fn).init(ctx).as(mlir.Type).?,
|
.f8e4m3fn => mlir.FloatType(.f8e4m3fn).init(ctx).as(mlir.Type),
|
||||||
.f8e4m3fnuz => mlir.FloatType(.f8e4m3fnuz).init(ctx).as(mlir.Type).?,
|
.f8e4m3fnuz => mlir.FloatType(.f8e4m3fnuz).init(ctx).as(mlir.Type),
|
||||||
.f8e5m2 => mlir.FloatType(.f8e5m2).init(ctx).as(mlir.Type).?,
|
.f8e5m2 => mlir.FloatType(.f8e5m2).init(ctx).as(mlir.Type),
|
||||||
.f8e5m2fnuz => mlir.FloatType(.f8e5m2fnuz).init(ctx).as(mlir.Type).?,
|
.f8e5m2fnuz => mlir.FloatType(.f8e5m2fnuz).init(ctx).as(mlir.Type),
|
||||||
.bf16 => mlir.FloatType(.bf16).init(ctx).as(mlir.Type).?,
|
.bf16 => mlir.FloatType(.bf16).init(ctx).as(mlir.Type),
|
||||||
.f16 => mlir.FloatType(.f16).init(ctx).as(mlir.Type).?,
|
.f16 => mlir.FloatType(.f16).init(ctx).as(mlir.Type),
|
||||||
.f32 => mlir.FloatType(.f32).init(ctx).as(mlir.Type).?,
|
.f32 => mlir.FloatType(.f32).init(ctx).as(mlir.Type),
|
||||||
.f64 => mlir.FloatType(.f64).init(ctx).as(mlir.Type).?,
|
.f64 => mlir.FloatType(.f64).init(ctx).as(mlir.Type),
|
||||||
.i4 => mlir.IntegerType(.i4).init(ctx).as(mlir.Type).?,
|
.i4 => mlir.IntegerType(.i4).init(ctx).as(mlir.Type),
|
||||||
.i8 => mlir.IntegerType(.i8).init(ctx).as(mlir.Type).?,
|
.i8 => mlir.IntegerType(.i8).init(ctx).as(mlir.Type),
|
||||||
.i16 => mlir.IntegerType(.i16).init(ctx).as(mlir.Type).?,
|
.i16 => mlir.IntegerType(.i16).init(ctx).as(mlir.Type),
|
||||||
.i32 => mlir.IntegerType(.i32).init(ctx).as(mlir.Type).?,
|
.i32 => mlir.IntegerType(.i32).init(ctx).as(mlir.Type),
|
||||||
.i64 => mlir.IntegerType(.i64).init(ctx).as(mlir.Type).?,
|
.i64 => mlir.IntegerType(.i64).init(ctx).as(mlir.Type),
|
||||||
.u4 => mlir.IntegerType(.u4).init(ctx).as(mlir.Type).?,
|
.u4 => mlir.IntegerType(.u4).init(ctx).as(mlir.Type),
|
||||||
.u8 => mlir.IntegerType(.u8).init(ctx).as(mlir.Type).?,
|
.u8 => mlir.IntegerType(.u8).init(ctx).as(mlir.Type),
|
||||||
.u16 => mlir.IntegerType(.u16).init(ctx).as(mlir.Type).?,
|
.u16 => mlir.IntegerType(.u16).init(ctx).as(mlir.Type),
|
||||||
.u32 => mlir.IntegerType(.u32).init(ctx).as(mlir.Type).?,
|
.u32 => mlir.IntegerType(.u32).init(ctx).as(mlir.Type),
|
||||||
.u64 => mlir.IntegerType(.u64).init(ctx).as(mlir.Type).?,
|
.u64 => mlir.IntegerType(.u64).init(ctx).as(mlir.Type),
|
||||||
.c64 => mlir.ComplexType(.c64).init(ctx).as(mlir.Type).?,
|
.c64 => mlir.ComplexType(.c64).init(ctx).as(mlir.Type),
|
||||||
.c128 => mlir.ComplexType(.c128).init(ctx).as(mlir.Type).?,
|
.c128 => mlir.ComplexType(.c128).init(ctx).as(mlir.Type),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -123,7 +123,7 @@ pub const ext = struct {
|
|||||||
|
|
||||||
inline for (mapping) |entry| {
|
inline for (mapping) |entry| {
|
||||||
const dt, const mlirT = entry;
|
const dt, const mlirT = entry;
|
||||||
if (mlir_type.as(mlirT)) |_| {
|
if (mlir_type.is_a(mlirT)) {
|
||||||
return dt;
|
return dt;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -136,39 +136,39 @@ pub const ext = struct {
|
|||||||
pub fn fromData(data: dtype.Data, ctx: mlir.Context) mlir.Attribute {
|
pub fn fromData(data: dtype.Data, ctx: mlir.Context) mlir.Attribute {
|
||||||
switch (data) {
|
switch (data) {
|
||||||
.bool => |val| {
|
.bool => |val| {
|
||||||
return mlir.IntegerAttribute(.i1).init(ctx, @intFromBool(val)).as(mlir.Attribute).?;
|
return mlir.IntegerAttribute(.i1).init(ctx, @intFromBool(val)).as(mlir.Attribute);
|
||||||
},
|
},
|
||||||
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz => |val, tag| {
|
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz => |val, tag| {
|
||||||
const float_type = @field(mlir.FloatTypes, @tagName(tag));
|
const float_type = @field(mlir.FloatTypes, @tagName(tag));
|
||||||
const float_attr = mlir.FloatAttribute(float_type).init(ctx, val.toF32());
|
const float_attr = mlir.FloatAttribute(float_type).init(ctx, val.toF32());
|
||||||
return float_attr.as(mlir.Attribute).?;
|
return float_attr.as(mlir.Attribute);
|
||||||
},
|
},
|
||||||
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |val, tag| {
|
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |val, tag| {
|
||||||
const int_type = @field(mlir.IntegerTypes, @tagName(tag));
|
const int_type = @field(mlir.IntegerTypes, @tagName(tag));
|
||||||
const int_attr = mlir.IntegerAttribute(int_type).init(ctx, @intCast(val));
|
const int_attr = mlir.IntegerAttribute(int_type).init(ctx, @intCast(val));
|
||||||
return int_attr.as(mlir.Attribute).?;
|
return int_attr.as(mlir.Attribute);
|
||||||
},
|
},
|
||||||
inline else => |_, tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}),
|
inline else => |_, tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const DenseIntOrFPElementsAttribute = struct {
|
pub const DenseElementsAttribute = struct {
|
||||||
pub fn fromData(data: dtype.Data, result_type: mlir.Type) mlir.Attribute {
|
pub fn fromData(data: dtype.Data, result_type: mlir.Type) mlir.Attribute {
|
||||||
return switch (data.dtype()) {
|
return switch (data.dtype()) {
|
||||||
.bool => mlir.DenseIntOrFPElementsAttribute(.bool).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
.bool => mlir.DenseElementsAttribute(.bool).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||||
.i8 => mlir.DenseIntOrFPElementsAttribute(.i8).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
.i8 => mlir.DenseElementsAttribute(.i8).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||||
.i16 => mlir.DenseIntOrFPElementsAttribute(.i16).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
.i16 => mlir.DenseElementsAttribute(.i16).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||||
.i32 => mlir.DenseIntOrFPElementsAttribute(.i32).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
.i32 => mlir.DenseElementsAttribute(.i32).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||||
.i64 => mlir.DenseIntOrFPElementsAttribute(.i64).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
.i64 => mlir.DenseElementsAttribute(.i64).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||||
.u8 => mlir.DenseIntOrFPElementsAttribute(.u8).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
.u8 => mlir.DenseElementsAttribute(.u8).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||||
.u16 => mlir.DenseIntOrFPElementsAttribute(.u16).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
.u16 => mlir.DenseElementsAttribute(.u16).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||||
.u32 => mlir.DenseIntOrFPElementsAttribute(.u32).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
.u32 => mlir.DenseElementsAttribute(.u32).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||||
.u64 => mlir.DenseIntOrFPElementsAttribute(.u64).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
.u64 => mlir.DenseElementsAttribute(.u64).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||||
.bf16 => mlir.DenseIntOrFPElementsAttribute(.bf16).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
.bf16 => mlir.DenseElementsAttribute(.bf16).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||||
.f16 => mlir.DenseIntOrFPElementsAttribute(.f16).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
.f16 => mlir.DenseElementsAttribute(.f16).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||||
.f32 => mlir.DenseIntOrFPElementsAttribute(.f32).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
.f32 => mlir.DenseElementsAttribute(.f32).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||||
.f64 => mlir.DenseIntOrFPElementsAttribute(.f64).init(result_type, data.constSlice()).as(mlir.Attribute).?,
|
.f64 => mlir.DenseElementsAttribute(.f64).init(result_type, data.constSlice()).as(mlir.Attribute),
|
||||||
inline else => |tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}),
|
inline else => |tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@ -123,7 +123,7 @@ pub const CompilationContext = struct {
|
|||||||
|
|
||||||
const loc = mlir_ctx.location(@src()).named(mlir_ctx, "main");
|
const loc = mlir_ctx.location(@src()).named(mlir_ctx, "main");
|
||||||
const module = mlir.Module.init(loc);
|
const module = mlir.Module.init(loc);
|
||||||
module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, "zml").as(mlir.Attribute).?);
|
module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, "zml").as(mlir.Attribute));
|
||||||
|
|
||||||
var canonicalizer = try mlir.PassManager.init(mlir_ctx);
|
var canonicalizer = try mlir.PassManager.init(mlir_ctx);
|
||||||
{
|
{
|
||||||
@ -492,7 +492,7 @@ pub const CompilationContext = struct {
|
|||||||
attributes[a].appendAssumeCapacity(
|
attributes[a].appendAssumeCapacity(
|
||||||
mlir.NamedAttribute.init(
|
mlir.NamedAttribute.init(
|
||||||
mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"),
|
mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"),
|
||||||
mlir.IntegerAttribute(.i32).init(self.mlirCtx(), @intCast(index)).as(mlir.Attribute).?,
|
mlir.IntegerAttribute(.i32).init(self.mlirCtx(), @intCast(index)).as(mlir.Attribute),
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
// log.debug("attribute: {}", .{attributes[a].constSlice()});
|
// log.debug("attribute: {}", .{attributes[a].constSlice()});
|
||||||
|
|||||||
@ -132,7 +132,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
|||||||
},
|
},
|
||||||
&.{
|
&.{
|
||||||
mlir.ext.mlirType(mlir_ctx, q.shape()),
|
mlir.ext.mlirType(mlir_ctx, q.shape()),
|
||||||
mlir.RankedTensorType.init(&.{0}, mlir.IntegerType(.u8).init(mlir_ctx).as(mlir.Type).?).asType(),
|
mlir.RankedTensorType.init(&.{0}, mlir.IntegerType(.u8).init(mlir_ctx).as(mlir.Type)).as(mlir.Type),
|
||||||
},
|
},
|
||||||
loc,
|
loc,
|
||||||
);
|
);
|
||||||
|
|||||||
20
zml/ops.zig
20
zml/ops.zig
@ -148,7 +148,7 @@ pub fn reduce(
|
|||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.blocks = &.{body_block},
|
.blocks = &.{body_block},
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), axes).as(mlir.Attribute).? },
|
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), axes).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
// We can't verify right away, cause the weights captured by the reduce haven't been added yet.
|
// We can't verify right away, cause the weights captured by the reduce haven't been added yet.
|
||||||
.verify = false,
|
.verify = false,
|
||||||
@ -197,7 +197,7 @@ pub fn reduce(
|
|||||||
mlir_ctx,
|
mlir_ctx,
|
||||||
val,
|
val,
|
||||||
inner_ctx.broadcasting_axes[0 .. tensor.rank() - inner_ctx.n_reduced],
|
inner_ctx.broadcasting_axes[0 .. tensor.rank() - inner_ctx.n_reduced],
|
||||||
mlir.ext.RankedTensorType.fromShape(mlir_ctx, reduced_shape).as(mlir.Type).?,
|
mlir.ext.RankedTensorType.fromShape(mlir_ctx, reduced_shape).as(mlir.Type),
|
||||||
inner_ctx.loc,
|
inner_ctx.loc,
|
||||||
);
|
);
|
||||||
tensor.* = Tensor._result(reduced_shape, broad_val.result(0));
|
tensor.* = Tensor._result(reduced_shape, broad_val.result(0));
|
||||||
@ -240,17 +240,17 @@ pub fn reduceWindow(
|
|||||||
const pad_shape = mlir.RankedTensorType.init(
|
const pad_shape = mlir.RankedTensorType.init(
|
||||||
&.{ @intCast(opts.padding.len), 2 },
|
&.{ @intCast(opts.padding.len), 2 },
|
||||||
mlir.ext.Type.fromDType(ctx.mlirCtx(), .i64),
|
mlir.ext.Type.fromDType(ctx.mlirCtx(), .i64),
|
||||||
).as(mlir.Type).?;
|
).as(mlir.Type);
|
||||||
const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.reduce_window", .{
|
const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.reduce_window", .{
|
||||||
.variadic_operands = &.{ input_values[0..], init_values[0..] },
|
.variadic_operands = &.{ input_values[0..], init_values[0..] },
|
||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.blocks = &.{body_block},
|
.blocks = &.{body_block},
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "window_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dimensions).as(mlir.Attribute).? },
|
.{ "window_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dimensions).as(mlir.Attribute) },
|
||||||
.{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_strides).as(mlir.Attribute).? },
|
.{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_strides).as(mlir.Attribute) },
|
||||||
.{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.base_dilations).as(mlir.Attribute).? },
|
.{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.base_dilations).as(mlir.Attribute) },
|
||||||
.{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dilations).as(mlir.Attribute).? },
|
.{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dilations).as(mlir.Attribute) },
|
||||||
.{ "padding", mlir.DenseIntOrFPElementsAttribute(.i64).init(pad_shape, std.mem.sliceAsBytes(opts.padding)).as(mlir.Attribute).? },
|
.{ "padding", mlir.DenseElementsAttribute(.i64).init(pad_shape, opts.padding).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = loc,
|
.location = loc,
|
||||||
});
|
});
|
||||||
@ -611,8 +611,8 @@ pub fn sort(
|
|||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.blocks = &.{block},
|
.blocks = &.{block},
|
||||||
.attributes = &.{
|
.attributes = &.{
|
||||||
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx.mlirCtx(), dimension).as(mlir.Attribute).? },
|
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx.mlirCtx(), dimension).as(mlir.Attribute) },
|
||||||
.{ "is_stable", mlir.BoolAttribute.init(ctx.mlirCtx(), is_stable).as(mlir.Attribute).? },
|
.{ "is_stable", mlir.BoolAttribute.init(ctx.mlirCtx(), is_stable).as(mlir.Attribute) },
|
||||||
},
|
},
|
||||||
.location = loc,
|
.location = loc,
|
||||||
});
|
});
|
||||||
|
|||||||
@ -110,7 +110,7 @@ pub const Tensor = struct {
|
|||||||
///
|
///
|
||||||
/// The shape is derived from the type of the mlir.Value.
|
/// The shape is derived from the type of the mlir.Value.
|
||||||
pub fn fromMlirValue(val: mlir.Value) Tensor {
|
pub fn fromMlirValue(val: mlir.Value) Tensor {
|
||||||
const ranked_tensor = val.getType().as(mlir.RankedTensorType).?;
|
const ranked_tensor = val.getType().as(mlir.RankedTensorType);
|
||||||
const n = ranked_tensor.getRank();
|
const n = ranked_tensor.getRank();
|
||||||
|
|
||||||
stdx.debug.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK });
|
stdx.debug.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK });
|
||||||
@ -281,7 +281,7 @@ pub const Tensor = struct {
|
|||||||
const op = dialect.stablehlo.bitcast_convert(
|
const op = dialect.stablehlo.bitcast_convert(
|
||||||
self.getContext().mlirCtx(),
|
self.getContext().mlirCtx(),
|
||||||
self.value(),
|
self.value(),
|
||||||
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type).?,
|
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type),
|
||||||
loc,
|
loc,
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -830,7 +830,7 @@ pub const Tensor = struct {
|
|||||||
self.value(),
|
self.value(),
|
||||||
other.value(),
|
other.value(),
|
||||||
used_opts,
|
used_opts,
|
||||||
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape).as(mlir.Type).?,
|
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape).as(mlir.Type),
|
||||||
loc,
|
loc,
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -1010,7 +1010,7 @@ pub const Tensor = struct {
|
|||||||
return self;
|
return self;
|
||||||
}
|
}
|
||||||
|
|
||||||
const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), to)).as(mlir.Type).?;
|
const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), to)).as(mlir.Type);
|
||||||
const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) });
|
const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) });
|
||||||
|
|
||||||
const op = dialect.stablehlo.convert(self.getContext().mlirCtx(), self.value(), res_type, loc);
|
const op = dialect.stablehlo.convert(self.getContext().mlirCtx(), self.value(), res_type, loc);
|
||||||
@ -1520,7 +1520,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
const mlir_ctx = self.getContext().mlirCtx();
|
const mlir_ctx = self.getContext().mlirCtx();
|
||||||
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "slices={any}", .{slices});
|
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "slices={any}", .{slices});
|
||||||
const result_type = mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type).?;
|
const result_type = mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type);
|
||||||
const slice_op = dialect.stablehlo.slice(
|
const slice_op = dialect.stablehlo.slice(
|
||||||
mlir_ctx,
|
mlir_ctx,
|
||||||
self.value(),
|
self.value(),
|
||||||
@ -1785,7 +1785,12 @@ pub const Tensor = struct {
|
|||||||
const loc = ctx.location(@src(), "iota({_}, {})", .{ res_shape, a });
|
const loc = ctx.location(@src(), "iota({_}, {})", .{ res_shape, a });
|
||||||
|
|
||||||
const mlir_ctx = ctx.mlirCtx();
|
const mlir_ctx = ctx.mlirCtx();
|
||||||
var op = dialect.stablehlo.iota(mlir_ctx, a, mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).asType(), loc);
|
var op = dialect.stablehlo.iota(
|
||||||
|
mlir_ctx,
|
||||||
|
a,
|
||||||
|
mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type),
|
||||||
|
loc,
|
||||||
|
);
|
||||||
return _result(res_shape, op.result(0));
|
return _result(res_shape, op.result(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1857,7 +1862,7 @@ pub const Tensor = struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if (sh.rank() > 0) {
|
if (sh.rank() > 0) {
|
||||||
constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type).?, loc);
|
constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type), loc);
|
||||||
}
|
}
|
||||||
return _result(sh, constant_op.result(0)).convert(val.dtype());
|
return _result(sh, constant_op.result(0)).convert(val.dtype());
|
||||||
}
|
}
|
||||||
@ -1904,7 +1909,7 @@ pub const Tensor = struct {
|
|||||||
return _result(res_shape, self.value());
|
return _result(res_shape, self.value());
|
||||||
}
|
}
|
||||||
const ctx = self.getContext();
|
const ctx = self.getContext();
|
||||||
const result_type = mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type).?;
|
const result_type = mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type);
|
||||||
const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ });
|
const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ });
|
||||||
const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc);
|
const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc);
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user