add mini-DSL for creating MLIR common attributes and types, leveraging Zig 0.14 to simplify mlir.Type and mlir.Attribute creation

This commit is contained in:
Tarry Singh 2024-08-26 14:19:00 +00:00
parent 63ef78efcc
commit ac63c30e12
16 changed files with 463 additions and 325 deletions

View File

@ -9,6 +9,7 @@ cc_library(
"@llvm-project//mlir:CAPIArith",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:CAPIMath",
"@llvm-project//mlir:CAPISCF",
"@llvm-project//mlir:CAPITransforms",
],
)

View File

@ -3,6 +3,7 @@
#include <mlir-c/Dialect/Arith.h>
#include <mlir-c/Dialect/Func.h>
#include <mlir-c/Dialect/Math.h>
#include <mlir-c/Dialect/SCF.h>
#include <mlir-c/IR.h>
#include <mlir-c/Pass.h>
#include <mlir-c/Transforms.h>

View File

@ -7,6 +7,7 @@ zig_library(
"arith.zig",
"func.zig",
"math.zig",
"scf.zig",
"tensor.zig",
],
import_name = "mlir/dialects",

View File

@ -1,11 +1,10 @@
const std = @import("std");
const mlir = @import("mlir");
pub fn constant(ctx: mlir.Context, value: mlir.Attribute, location: mlir.Location) mlir.Operation {
return mlir.Operation.make(ctx, "arith.constant", .{
.attributes = &.{
.{ "value", value },
},
.attributes = &.{.{ "value", value }},
.result_type_inference = true,
.location = location,
});
@ -44,6 +43,8 @@ pub const mulf = binary_fn("arith.mulf");
pub const divsi = binary_fn("arith.divsi");
pub const divui = binary_fn("arith.divui");
pub const divf = binary_fn("arith.divf");
pub const maxnumf = binary_fn("arith.maxnumf");
pub const maxnumi = binary_fn("arith.maxnumi");
pub const extsi = cast_fn("arith.extsi");
pub const extui = cast_fn("arith.extui");
pub const extf = cast_fn("arith.extf");

View File

@ -3,6 +3,7 @@ const std = @import("std");
pub const arith = @import("arith.zig");
pub const func = @import("func.zig");
pub const math = @import("math.zig");
pub const scf = @import("scf.zig");
pub const tensor = @import("tensor.zig");
pub const stablehlo = @import("mlir/dialects/stablehlo");

View File

@ -1,4 +1,5 @@
const std = @import("std");
const mlir = @import("mlir");
pub fn func(
@ -14,14 +15,14 @@ pub fn func(
},
) mlir.Operation {
var attrs_tuple_buffer = std.BoundedArray(mlir.AttrTuple, 4){};
attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute) });
attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type)).as(mlir.Attribute) });
attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", .string(ctx, args.sym_name) });
attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", .type_(.function(ctx, args.args, args.results)) });
if (args.arg_attrs.len > 0) {
attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", mlir.ArrayAttribute.init(ctx, args.arg_attrs).as(mlir.Attribute) });
attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", .array(ctx, args.arg_attrs) });
}
if (args.res_attrs.len > 0) {
attrs_tuple_buffer.appendAssumeCapacity(.{ "res_attrs", mlir.ArrayAttribute.init(ctx, args.res_attrs).as(mlir.Attribute) });
attrs_tuple_buffer.appendAssumeCapacity(.{ "res_attrs", .array(ctx, args.res_attrs) });
}
return mlir.Operation.make(ctx, "func.func", .{
@ -36,7 +37,7 @@ pub fn call(ctx: mlir.Context, name: [:0]const u8, values: []const mlir.Value, r
.variadic_operands = &.{values},
.results = results,
.verify = true,
.attributes = &.{.{ "callee", mlir.FlatSymbolRefAttribute.init(ctx, name).as(mlir.Attribute) }},
.attributes = &.{.{ "callee", .symbol(ctx, name) }},
.location = loc,
});
}

View File

@ -8,7 +8,7 @@ fn unary_fn(comptime op_name: [:0]const u8) type {
pub fn call(ctx: mlir.Context, value: mlir.Value, location: mlir.Location) mlir.Operation {
return mlir.Operation.make(ctx, namespace ++ "." ++ op_name, .{
.operands = &.{value},
.results = &.{},
.results = &.{value.getType()},
.location = location,
});
}
@ -32,4 +32,5 @@ pub const fpowi = binary_fn("fpowi").call;
pub const tanh = unary_fn("tanh").call;
pub const sqrt = unary_fn("sqrt").call;
pub const exp = unary_fn("exp").call;
pub const exp2 = unary_fn("exp2").call;
pub const log = unary_fn("log").call;

61
mlir/dialects/scf.zig Normal file
View File

@ -0,0 +1,61 @@
const std = @import("std");
const mlir = @import("mlir");
/// Signature of the body callback passed to `for`: receives the MLIR context,
/// the loop block (whose arguments are the induction variable followed by the
/// loop-carried values), and caller-supplied extra arguments. It must return
/// the terminating operation (expected to be an `scf.yield`, see `for`).
pub fn ForBody(ExtraArgs: type) type {
return fn (mlir.Context, mlir.Block, ExtraArgs) mlir.Operation;
}
/// Iteration bounds of an `scf.for` loop: from `start` to `end`, advancing by
/// `step`. All three are runtime SSA values; `start`'s type is also used for
/// the induction variable (assumed index/integer-like — not checked here).
pub const ForRange = struct {
start: mlir.Value,
end: mlir.Value,
step: mlir.Value,
};
/// Builds an `scf.for` operation.
///
/// The loop body is produced by calling `body(ctx, block, extra_args)`, where
/// `block`'s arguments are the induction variable (argument 0, typed like
/// `range.start`) followed by one argument per entry of `init_values`.
/// `body` must return an `scf.yield` operation (asserted), which is appended
/// to the block as its terminator.
///
/// The resulting op's results mirror the types of `init_values` (the
/// loop-carried variables); the induction variable is not a result.
/// Supports at most 31 loop-carried values (fixed-size internal buffers).
pub fn @"for"(
    ExtraArgs: type,
    ctx: mlir.Context,
    range: ForRange,
    init_values: []const mlir.Value,
    body: ForBody(ExtraArgs),
    extra_args: ExtraArgs,
    loc: mlir.Location,
) mlir.Operation {
    const n_args = init_values.len;
    var init_types_buf: [32]mlir.Type = undefined;
    var locs_buf: [32]mlir.Location = undefined;
    // Guard the fixed-size buffers: one extra slot is needed for the
    // induction variable. Without this, a large `init_values` would slice
    // past the buffer ends.
    std.debug.assert(n_args + 1 <= init_types_buf.len);
    // The first block argument is the for loop induction variable,
    // followed then by all the loop-carried variables.
    const init_types = init_types_buf[0 .. n_args + 1];
    const locs = locs_buf[0 .. n_args + 1];
    init_types[0] = range.start.getType();
    locs[0] = loc;
    for (1.., init_values) |i, val| {
        init_types[i] = val.getType();
        locs[i] = loc;
    }
    const block = mlir.Block.init(init_types, locs) catch unreachable;
    const yield_op = body(ctx, block, extra_args);
    // The region must be terminated by `scf.yield`.
    std.debug.assert(std.mem.eql(u8, "scf.yield", yield_op.name().str()));
    block.appendOperationRecursive(yield_op, .open);
    return mlir.Operation.make(ctx, "scf.for", .{
        // Operands: (lower bound, upper bound, step), then the loop-carried
        // initial values.
        .variadic_operands = &.{ &.{ range.start, range.end, range.step }, init_values },
        // Results are the loop-carried variable types (induction var excluded).
        .results = init_types[1..],
        .blocks = &.{block},
        .location = loc,
        .verify = false,
    });
}
/// Builds the `scf.yield` terminator forwarding `res` out of the enclosing
/// scf region (e.g. the body of `scf.for`). Produces no results itself.
pub fn yield(ctx: mlir.Context, res: []const mlir.Value, loc: mlir.Location) mlir.Operation {
return mlir.Operation.make(ctx, "scf.yield", .{
.variadic_operands = &.{res},
.results = &.{},
.location = loc,
// NOTE(review): verification is skipped here, presumably because the op is
// only valid once attached to its parent region — confirm.
.verify = false,
});
}

View File

@ -99,7 +99,7 @@ pub fn cholesky(ctx: mlir.Context, value: mlir.Value, lower: bool, location: mli
.operands = &.{value},
.result_type_inference = true,
.attributes = &.{
.{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(lower))).as(mlir.Attribute) },
.{ "lower", .i1FromBool(ctx, lower) },
},
.location = location,
});
@ -189,7 +189,7 @@ pub fn dot_general(
precision: DotPrecision,
},
) mlir.Operation {
const precisions = [1]mlir.Attribute{opts.precision.precisionAttr(ctx)} ** 2;
const precisions: [2]mlir.Attribute = @splat(opts.precision.precisionAttr(ctx));
const attributes = [3]mlir.AttrTuple{
.{
"dot_dimension_numbers", DotDimensionNumbersAttribute.init(ctx, .{
@ -199,7 +199,7 @@ pub fn dot_general(
.rhs_contracting_dimensions = opts.rhs_contracting_dimensions,
}).as(mlir.Attribute),
},
.{ "precision_config", mlir.ArrayAttribute.init(ctx, &precisions).asAttr() },
.{ "precision_config", .array(ctx, &precisions) },
// keep algorithm as the last attribute so we can omit it when it's not set.
.{ "algorithm", opts.precision.algorithmAttr(ctx, lhs.getType()) orelse undefined },
};
@ -244,7 +244,7 @@ pub fn broadcast_in_dim(ctx: mlir.Context, operand: mlir.Value, dims: []const i6
.operands = &.{operand},
.results = &.{result_type},
.attributes = &.{
.{ "broadcast_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dims).as(mlir.Attribute) },
.{ "broadcast_dimensions", .dense(ctx, .i64, dims) },
},
.location = location,
});
@ -255,7 +255,7 @@ pub fn transpose(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, l
.operands = &.{value},
.results = &.{result_type},
.attributes = &.{
.{ "permutation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.permutation).as(mlir.Attribute) },
.{ "permutation", .dense(ctx, .i64, opts.permutation) },
},
.location = location,
});
@ -266,9 +266,9 @@ pub fn slice(ctx: mlir.Context, operand: mlir.Value, start_indices: []const i64,
.operands = &.{operand},
.results = &.{result_type},
.attributes = &.{
.{ "start_indices", mlir.DenseArrayAttribute(.i64).init(ctx, start_indices).as(mlir.Attribute) },
.{ "limit_indices", mlir.DenseArrayAttribute(.i64).init(ctx, limit_indices).as(mlir.Attribute) },
.{ "strides", mlir.DenseArrayAttribute(.i64).init(ctx, strides).as(mlir.Attribute) },
.{ "start_indices", .dense(ctx, .i64, start_indices) },
.{ "limit_indices", .dense(ctx, .i64, limit_indices) },
.{ "strides", .dense(ctx, .i64, strides) },
},
.location = location,
});
@ -279,7 +279,7 @@ pub fn concatenate(ctx: mlir.Context, inputs: []const mlir.Value, dimension: i64
.operands = inputs,
.result_type_inference = true,
.attributes = &.{
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) },
.{ "dimension", .int(ctx, .i64, dimension) },
},
.location = location,
});
@ -333,8 +333,8 @@ pub fn gather(
args.start_index_map,
args.index_vector_dim,
).as(mlir.Attribute) },
.{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, slice_sizes).as(mlir.Attribute) },
.{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute) },
.{ "slice_sizes", .dense(ctx, .i64, slice_sizes) },
.{ "indices_are_sorted", .boolean(ctx, args.indices_are_sorted) },
},
.location = location,
},
@ -394,8 +394,8 @@ pub fn scatter(
.blocks = &.{update_block},
.attributes = &.{
.{ "scatter_dimension_numbers", args.getScatterDimensionNumbers(ctx) },
.{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute) },
.{ "unique_indices", mlir.BoolAttribute.init(ctx, args.unique_indices).as(mlir.Attribute) },
.{ "indices_are_sorted", .boolean(ctx, args.indices_are_sorted) },
.{ "unique_indices", .boolean(ctx, args.unique_indices) },
},
.result_type_inference = true,
.location = location,
@ -408,7 +408,7 @@ pub fn iota(ctx: mlir.Context, dimension: i64, result_type: mlir.Type, location:
.operands = &.{},
.results = &.{result_type},
.attributes = &.{
.{ "iota_dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) },
.{ "iota_dimension", .int(ctx, .i64, dimension) },
},
.location = location,
});
@ -420,7 +420,7 @@ pub fn reverse(ctx: mlir.Context, operand: mlir.Value, dimensions: []const i64,
.operands = &.{operand},
.results = &.{result_type},
.attributes = &.{
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute) },
.{ "dimensions", .dense(ctx, .i64, dimensions) },
},
.location = location,
});
@ -453,7 +453,7 @@ pub fn reduce(
const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0..block_n_args];
var reduce_elem_types: [MaxBlockArguments]mlir.Type = undefined;
for (inputs, 0..) |input, i| {
const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type);
const arg_type: mlir.Type = .tensor(&.{}, elementTypeOrSelf(input.getType()));
reduce_elem_types[i] = arg_type;
reduce_elem_types[inputs.len + i] = arg_type;
}
@ -475,7 +475,7 @@ pub fn reduce(
.result_type_inference = true,
.block = block,
.attributes = &.{
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute) },
.{ "dimensions", .dense(ctx, .i64, dimensions) },
},
.location = location,
});
@ -495,7 +495,7 @@ pub fn sort(
const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0 .. inputs.len * 2];
var sort_elem_types: [MaxBlockArguments]mlir.Type = undefined;
for (inputs, 0..) |input, i| {
const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type);
const arg_type: mlir.Type = .tensor(&.{}, elementTypeOrSelf(input.getType()));
sort_elem_types[i * 2] = arg_type;
sort_elem_types[i * 2 + 1] = arg_type;
}
@ -512,19 +512,19 @@ pub fn sort(
.result_type_inference = true,
.block = block,
.attributes = &.{
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) },
.{ "is_stable", mlir.BoolAttribute.init(ctx, is_stable).as(mlir.Attribute) },
.{ "dimension", .int(ctx, .i64, dimension) },
.{ "is_stable", .boolean(ctx, is_stable) },
},
.location = location,
});
}
pub fn dynamicSlice(ctx: mlir.Context, operand: mlir.Value, new_dims: []const i64, start_indices: []const mlir.Value, location: mlir.Location) mlir.Operation {
pub fn dynamic_slice(ctx: mlir.Context, operand: mlir.Value, new_dims: []const i64, start_indices: []const mlir.Value, location: mlir.Location) mlir.Operation {
return mlir.Operation.make(ctx, "stablehlo.dynamic_slice", .{
.variadic_operands = &.{ &.{operand}, start_indices },
.result_type_inference = true,
.attributes = &.{
.{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, new_dims).as(mlir.Attribute) },
.{ "slice_sizes", .dense(ctx, .i64, new_dims) },
},
.location = location,
});
@ -557,9 +557,9 @@ pub fn pad(ctx: mlir.Context, value: mlir.Value, padding_value: mlir.Value, opts
.operands = &.{ value, padding_value },
.result_type_inference = true,
.attributes = &.{
.{ "edge_padding_low", mlir.DenseArrayAttribute(.i64).init(ctx, opts.low).as(mlir.Attribute) },
.{ "edge_padding_high", mlir.DenseArrayAttribute(.i64).init(ctx, opts.high).as(mlir.Attribute) },
.{ "interior_padding", mlir.DenseArrayAttribute(.i64).init(ctx, opts.interior).as(mlir.Attribute) },
.{ "edge_padding_low", .dense(ctx, .i64, opts.low) },
.{ "edge_padding_high", .dense(ctx, .i64, opts.high) },
.{ "interior_padding", .dense(ctx, .i64, opts.interior) },
},
.location = location,
});
@ -577,9 +577,9 @@ pub fn triangular_solve(ctx: mlir.Context, value: mlir.Value, other: mlir.Value,
.operands = &.{ value, other },
.result_type_inference = true,
.attributes = &.{
.{ "left_side", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.left_side))).as(mlir.Attribute) },
.{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.lower))).as(mlir.Attribute) },
.{ "unit_diagonal", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.unit_diagonal))).as(mlir.Attribute) },
.{ "left_side", .i1FromBool(ctx, opts.left_side) },
.{ "lower", .i1FromBool(ctx, opts.lower) },
.{ "unit_diagonal", .i1FromBool(ctx, opts.unit_diagonal) },
.{ "transpose_a", Transpose.init(ctx, opts.transpose_a).as(mlir.Attribute) },
},
.location = location,
@ -597,7 +597,7 @@ pub fn fft(ctx: mlir.Context, value: mlir.Value, location: mlir.Location, opts:
.result_type_inference = true,
.attributes = &.{
.{ "fft_type", FftType.init(ctx, opts.kind).as(mlir.Attribute) },
.{ "fft_length", mlir.DenseArrayAttribute(.i64).init(ctx, opts.length).as(mlir.Attribute) },
.{ "fft_length", .dense(ctx, .i64, opts.length) },
},
.location = location,
});
@ -630,8 +630,8 @@ pub fn reduce_precision(ctx: mlir.Context, value: mlir.Value, exponent_bits: i32
.operands = &.{value},
.result_type_inference = true,
.attributes = &.{
.{ "exponent_bits", mlir.IntegerAttribute(.i32).init(ctx, exponent_bits).as(mlir.Attribute) },
.{ "mantissa_bits", mlir.IntegerAttribute(.i32).init(ctx, mantissa_bits).as(mlir.Attribute) },
.{ "exponent_bits", .int(ctx, .i32, exponent_bits) },
.{ "mantissa_bits", .int(ctx, .i32, mantissa_bits) },
},
.location = location,
});
@ -658,7 +658,7 @@ pub fn get_tuple_element(ctx: mlir.Context, tuple_value: mlir.Value, index: i64,
.operands = &.{tuple_value},
.result_type_inference = true,
.attributes = &.{
.{ "index", mlir.IntegerAttribute(.i32).init(ctx, index).as(mlir.Attribute) },
.{ "index", .int(ctx, .i32, index) },
},
.location = location,
});
@ -701,17 +701,15 @@ pub fn convolution(
for (opts.window_reversal, 0..) |w, i| {
window_reversal[i] = @intCast(@intFromBool(w));
}
const pad_type = mlir.IntegerType(.i64).init(ctx).as(mlir.Type);
const pad_shape = mlir.RankedTensorType.init(opts.pad_shape, pad_type).as(mlir.Type);
return mlir.Operation.make(ctx, "stablehlo.convolution", .{
.operands = &.{ lhs, rhs },
.results = &.{res_type},
.attributes = &.{
.{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_strides).as(mlir.Attribute) },
.{ "padding", mlir.DenseElementsAttribute(.i64).init(pad_shape, opts.pad_value).as(mlir.Attribute) },
.{ "lhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.lhs_dilation).as(mlir.Attribute) },
.{ "rhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.rhs_dilation).as(mlir.Attribute) },
.{ "window_reversal", mlir.DenseArrayAttribute(.bool).init(ctx, window_reversal[0..opts.window_reversal.len]).as(mlir.Attribute) },
.{ "window_strides", .dense(ctx, .i64, opts.window_strides) },
.{ "padding", .denseElements(ctx, opts.pad_shape, .i64, opts.pad_value) },
.{ "lhs_dilation", .dense(ctx, .i64, opts.lhs_dilation) },
.{ "rhs_dilation", .dense(ctx, .i64, opts.rhs_dilation) },
.{ "window_reversal", .dense(ctx, .bool, window_reversal[0..opts.window_reversal.len]) },
.{
"dimension_numbers", ConvDimensionNumbersAttribute.init(ctx, .{
.input_batch_dimension = opts.input_batch_dimension,
@ -725,9 +723,9 @@ pub fn convolution(
.output_spatial_dimensions = opts.output_spatial_dimensions,
}).as(mlir.Attribute),
},
.{ "feature_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.feature_group_count).as(mlir.Attribute) },
.{ "batch_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.batch_group_count).as(mlir.Attribute) },
.{ "precision_config", mlir.ArrayAttribute.init(ctx, &max_precisions).as(mlir.Attribute) },
.{ "feature_group_count", .int(ctx, .i64, opts.feature_group_count) },
.{ "batch_group_count", .int(ctx, .i64, opts.batch_group_count) },
.{ "precision_config", .array(ctx, &max_precisions) },
},
.location = location,
});
@ -764,25 +762,20 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
for (opts.output_operand_aliases) |alias| {
ret.appendAssumeCapacity(
OutputOperandAliasAttribute.init(
ctx,
&.{},
alias,
&.{},
).as(mlir.Attribute),
OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute),
);
}
break :blk ret;
};
const backend_config = switch (opts.backend_config) {
const backend_config: mlir.Attribute = switch (opts.backend_config) {
.string => blk: {
stdx.debug.assert(
@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi),
"Only API version of less than 4 is supported for backend_config as string",
.{},
);
break :blk mlir.StringAttribute.init(ctx, opts.backend_config.string).as(mlir.Attribute);
break :blk .string(ctx, opts.backend_config.string);
},
.dict => blk: {
stdx.debug.assert(
@ -797,45 +790,33 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{};
attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, @intFromEnum(opts.api_version)).as(mlir.Attribute) },
.{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) },
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) },
.{ "api_version", .int(ctx, .i32, @intFromEnum(opts.api_version)) },
.{ "call_target_name", .string(ctx, opts.call_target_name) },
.{ "has_side_effect", .boolean(ctx, opts.has_side_effect) },
.{ "backend_config", backend_config },
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases.constSlice()).as(mlir.Attribute) },
.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) },
});
if (opts.operand_layouts) |layouts| {
const operand_layouts = blk: {
var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
for (layouts) |ol| {
const tensor_type = mlir.RankedTensorType.init(
&.{@intCast(ol.len)},
mlir.IndexType.init(ctx).as(mlir.Type),
).as(mlir.Type);
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, ol);
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol));
}
break :blk ret;
};
const attr: mlir.AttrTuple = .{ "operand_layouts", mlir.ArrayAttribute.init(ctx, operand_layouts.constSlice()).as(mlir.Attribute) };
attrs.appendAssumeCapacity(attr);
attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) });
}
if (opts.result_layouts) |layouts| {
const result_layouts = blk: {
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
for (layouts) |rl| {
const tensor_type = mlir.RankedTensorType.init(
&.{@intCast(rl.len)},
mlir.IndexType.init(ctx).as(mlir.Type),
).as(mlir.Type);
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, rl);
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl));
}
break :blk ret;
};
const attr: mlir.AttrTuple = .{ "result_layouts", mlir.ArrayAttribute.init(ctx, result_layouts.constSlice()).as(mlir.Attribute) };
attrs.appendAssumeCapacity(attr);
attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) });
}
attrs.appendSliceAssumeCapacity(opts.addional_attributes);
@ -852,9 +833,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
pub fn annotate_device_placement(ctx: mlir.Context, inputs: []const mlir.Value, memory_kind: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation {
const frontend_attributes = mlir.DictionaryAttribute.init(
ctx,
&.{
mlir.NamedAttribute.init(mlir.Identifier.get(ctx, "_xla_buffer_placement"), memory_kind.asAttr()),
},
&.{.named(ctx, "_xla_buffer_placement", memory_kind.asAttr())},
).asAttr();
return custom_call(ctx, inputs, .{

View File

@ -1,9 +1,10 @@
const builtin = @import("builtin");
const std = @import("std");
const stdx = @import("stdx");
const log = std.log.scoped(.mlir);
const builtin = @import("builtin");
const c = @import("c");
const stdx = @import("stdx");
const log = std.log.scoped(.mlir);
test {
std.testing.refAllDecls(@This());
@ -254,6 +255,10 @@ pub const Module = struct {
pub fn op(self: Module) Operation {
return Operation.wrap(c.mlirModuleGetOperation(self.inner()));
}
/// Folds a hash of this module's IR into `hasher` by delegating to the
/// module-level operation's `hash` (see `Operation.hash`).
pub fn hash(self: Module, hasher: *std.hash.XxHash64) void {
return self.op().hash(hasher);
}
};
pub const PassManager = struct {
@ -352,21 +357,81 @@ pub const Attribute = struct {
) orelse Error.InvalidMlir;
}
pub fn getNull() Self {
return Self.wrap(c.mlirAttributeGetNull());
// Utility functions to build common attributes.
// All constructors upcast to the generic `Attribute` type, which makes chained
// construction easier, but loses the concrete attribute type information.
/// The null (absent) attribute.
pub fn null_() Attribute {
return .wrap(c.mlirAttributeGetNull());
}
/// String attribute holding `str`.
pub fn string(ctx: Context, str: []const u8) Attribute {
return StringAttribute.init(ctx, str).asAttr();
}
/// Wraps the type `t` in a type attribute.
pub fn type_(t: Type) Attribute {
return TypeAttribute.init(t).asAttr();
}
/// Unit attribute (presence-only marker with no payload).
pub fn unit(ctx: Context) Attribute {
return .wrap(c.mlirUnitAttrGet(ctx.inner()));
}
/// Boolean attribute.
pub fn boolean(ctx: Context, value: bool) Attribute {
return BoolAttribute.init(ctx, value).asAttr();
}
/// i1 integer attribute encoding a bool as 0/1, for ops whose flags are
/// declared as i1 integers rather than bool attributes.
pub fn i1FromBool(ctx: Context, value: bool) Attribute {
return IntegerAttribute(.i1).init(ctx, @intFromBool(value)).asAttr();
}
/// Integer attribute of the given comptime integer type.
pub fn int(ctx: Context, comptime int_type: IntegerTypes, value: i64) Attribute {
return IntegerAttribute(int_type).init(ctx, value).asAttr();
}
/// Float attribute of the given comptime float type.
pub fn float(ctx: Context, comptime float_type: FloatTypes, value: f64) Attribute {
return .wrap(FloatAttribute(float_type).init(ctx, value)._inner);
}
/// Array attribute wrapping a slice of attributes.
pub fn array(ctx: Context, attrs: []const Attribute) Attribute {
return ArrayAttribute.init(ctx, attrs).asAttr();
}
/// Dense array attribute of element type `dt`, from a matching Zig slice
/// (see `DenseArrayTypes.ZigType` for the element-type mapping).
pub fn dense(ctx: Context, comptime dt: DenseArrayTypes, values: []const dt.ZigType()) Attribute {
return DenseArrayAttribute(dt).init(ctx, values).asAttr();
}
/// Use a tensor as an attribute.
/// The tensor is specified by dims, dtype and a flat slice of values.
pub fn denseElements(ctx: Context, dims: []const i64, comptime dt: DenseElementsAttributeTypes, values: anytype) Attribute {
return .wrap(DenseElementsAttribute(dt).init(.tensor(dims, dt.mlirType(ctx)), values).inner());
}
/// Flat symbol reference attribute (e.g. the callee of `func.call`).
pub fn symbol(ctx: Context, flat_name: [:0]const u8) Attribute {
return .wrap(FlatSymbolRefAttribute.init(ctx, flat_name).inner());
}
/// Method-style helper pairing this attribute with `name` into a NamedAttribute.
pub fn named(attr: Attribute, ctx: Context, name: [:0]const u8) NamedAttribute {
return NamedAttribute.init(Identifier.get(ctx, name), attr);
}
};
pub const NamedAttribute = struct {
_inner: c.MlirNamedAttribute,
pub usingnamespace MlirHelpers(NamedAttribute, .{});
const Self = NamedAttribute;
pub const NamedAttribute = extern struct {
name: c.MlirIdentifier,
attribute: c.MlirAttribute,
pub fn init(name: Identifier, attr: Attribute) Self {
return Self.wrap(.{
pub fn named(ctx: Context, name: [:0]const u8, attr: Attribute) NamedAttribute {
return .{
.name = c.mlirIdentifierGet(ctx._inner, stringRef(name)),
.attribute = attr.inner(),
};
}
pub fn init(name: Identifier, attr: Attribute) NamedAttribute {
return .{
.name = name.inner(),
.attribute = attr.inner(),
});
};
}
};
@ -393,21 +458,6 @@ pub const StringAttribute = struct {
}
};
pub const UnitAttribute = struct {
_inner: c.MlirAttribute,
pub usingnamespace MlirHelpers(UnitAttribute, .{
.is_a_fn = c.mlirAttributeIsAUnit,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
const Self = UnitAttribute;
pub fn init(ctx: Context) Self {
return Self.wrap(c.mlirUnitAttrGet(ctx.inner()));
}
};
pub const BoolAttribute = struct {
_inner: c.MlirAttribute,
pub usingnamespace MlirHelpers(BoolAttribute, .{
@ -481,8 +531,8 @@ pub const ArrayAttribute = struct {
pub fn IntegerAttribute(comptime it: IntegerTypes) type {
const ZigType, const getter = comptime switch (it) {
.i1, .i4, .i8, .i16, .i32, .i64 => .{ u64, c.mlirIntegerAttrGetValueInt },
.si4, .si8, .si16, .si32, .si64 => .{ u64, c.mlirIntegerAttrGetValueSInt },
.i1, .i4, .i8, .i16, .i32, .i64 => .{ i64, c.mlirIntegerAttrGetValueInt },
.si4, .si8, .si16, .si32, .si64 => .{ i64, c.mlirIntegerAttrGetValueSInt },
.u4, .u8, .u16, .u32, .u64 => .{ u64, c.mlirIntegerAttrGetValueUInt },
.unknown => @compileError("IntegerAttribute(unknown)"),
};
@ -547,38 +597,48 @@ pub const DenseArrayTypes = enum {
i64,
f32,
f64,
/// Zig element type corresponding to each dense-array element kind.
/// `.bool` maps to i32 to match the MLIR C dense-bool-array API, which
/// takes/returns int elements rather than bools.
pub fn ZigType(comptime dt: DenseArrayTypes) type {
return switch (dt) {
.bool => i32,
.i8 => i8,
.i16 => i16,
.i32 => i32,
.i64 => i64,
.f32 => f32,
.f64 => f64,
};
}
};
pub fn DenseArrayAttribute(comptime dt: DenseArrayTypes) type {
const Config = switch (dt) {
.bool => .{ i32, c.mlirAttributeIsADenseBoolArray, c.mlirDenseBoolArrayGet, c.mlirDenseBoolArrayGetElement },
.i8 => .{ i8, c.mlirAttributeIsADenseI8Array, c.mlirDenseI8ArrayGet, c.mlirDenseI8ArrayGetElement },
.i16 => .{ i16, c.mlirAttributeIsADenseI16Array, c.mlirDenseI16ArrayGet, c.mlirDenseI16ArrayGetElement },
.i32 => .{ i32, c.mlirAttributeIsADenseI32Array, c.mlirDenseI32ArrayGet, c.mlirDenseI32ArrayGetElement },
.i64 => .{ i64, c.mlirAttributeIsADenseI64Array, c.mlirDenseI64ArrayGet, c.mlirDenseI64ArrayGetElement },
.f32 => .{ f32, c.mlirAttributeIsADenseF32Array, c.mlirDenseF32ArrayGet, c.mlirDenseF32ArrayGetElement },
.f64 => .{ f64, c.mlirAttributeIsADenseF64Array, c.mlirDenseF64ArrayGet, c.mlirDenseF64ArrayGetElement },
const is_a_fn, const get_fn, const get_element_fn = switch (dt) {
.bool => .{ c.mlirAttributeIsADenseBoolArray, c.mlirDenseBoolArrayGet, c.mlirDenseBoolArrayGetElement },
.i8 => .{ c.mlirAttributeIsADenseI8Array, c.mlirDenseI8ArrayGet, c.mlirDenseI8ArrayGetElement },
.i16 => .{ c.mlirAttributeIsADenseI16Array, c.mlirDenseI16ArrayGet, c.mlirDenseI16ArrayGetElement },
.i32 => .{ c.mlirAttributeIsADenseI32Array, c.mlirDenseI32ArrayGet, c.mlirDenseI32ArrayGetElement },
.i64 => .{ c.mlirAttributeIsADenseI64Array, c.mlirDenseI64ArrayGet, c.mlirDenseI64ArrayGetElement },
.f32 => .{ c.mlirAttributeIsADenseF32Array, c.mlirDenseF32ArrayGet, c.mlirDenseF32ArrayGetElement },
.f64 => .{ c.mlirAttributeIsADenseF64Array, c.mlirDenseF64ArrayGet, c.mlirDenseF64ArrayGetElement },
};
return struct {
_inner: c.MlirAttribute,
pub usingnamespace MlirHelpers(@This(), .{
.is_a_fn = Config[1],
.is_a_fn = is_a_fn,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
const Attr = @This();
const ElementType = dt;
const ElementTypeZig = Config[0];
const ElementTypeZig = dt.ZigType();
pub fn init(ctx: Context, values: []const ElementTypeZig) Attr {
const get_fn = Config[2];
return Attr.wrap(get_fn(ctx.inner(), @intCast(values.len), @ptrCast(values.ptr)));
}
pub fn get(self: Attr, pos: usize) ElementTypeZig {
const get_element_fn = Config[3];
return get_element_fn(self.inner(), @intCast(pos));
}
@ -586,6 +646,10 @@ pub fn DenseArrayAttribute(comptime dt: DenseArrayTypes) type {
return @intCast(c.mlirDenseArrayGetNumElements(self.inner()));
}
pub fn asAttr(self: Attr) Attribute {
return Attribute.wrap(self._inner);
}
pub usingnamespace switch (dt) {
.bool, .i64 => struct {
const DenseArray = DenseArrayAttribute(switch (dt) {
@ -614,26 +678,47 @@ pub const DenseElementsAttributeTypes = enum {
f32,
f64,
index,
/// Zig element type used to stage the raw values of a dense-elements
/// attribute of kind `dt`.
/// `.bf16` maps to u16 (raw bit pattern; Zig has no native bf16 type) and
/// `.index` to usize — assumed to match the C API's raw data layout, TODO confirm.
pub fn ZigType(comptime dt: DenseElementsAttributeTypes) type {
return switch (dt) {
.bool => bool,
.i8 => i8,
.i16 => i16,
.i32 => i32,
.i64 => i64,
.u8 => u8,
.u16 => u16,
.u32 => u32,
.u64 => u64,
.bf16 => u16,
.f16 => f16,
.f32 => f32,
.f64 => f64,
.index => usize,
};
}
/// MLIR element type corresponding to `dt`, built in `ctx`
/// (used to construct the tensor type backing a dense-elements attribute).
pub fn mlirType(dt: DenseElementsAttributeTypes, ctx: Context) Type {
return switch (dt) {
.bool => .int(ctx, .i1),
.i8 => .int(ctx, .i8),
.i16 => .int(ctx, .i16),
.i32 => .int(ctx, .i32),
.i64 => .int(ctx, .i64),
.u8 => .int(ctx, .u8),
.u16 => .int(ctx, .u16),
.u32 => .int(ctx, .u32),
.u64 => .int(ctx, .u64),
.bf16 => .float(ctx, .bf16),
.f16 => .float(ctx, .f16),
.f32 => .float(ctx, .f32),
.f64 => .float(ctx, .f64),
.index => .index(ctx),
};
}
};
pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
const ZigType = switch (dt) {
.bool => bool,
.i8 => i8,
.i16 => i16,
.i32 => i32,
.i64 => i64,
.u8 => u8,
.u16 => u16,
.u32 => u32,
.u64 => u64,
.bf16 => u16,
.f16 => f16,
.f32 => f32,
.f64 => f64,
.index => usize,
};
return struct {
_inner: c.MlirAttribute,
@ -662,8 +747,8 @@ pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
return @intCast(c.mlirElementsAttrGetNumElements(self.inner()));
}
pub fn constSlice(self: Attr) []const ZigType {
const ptr: [*]const ZigType = @constCast(@ptrCast(@alignCast(c.mlirDenseElementsAttrGetRawData(self.inner()) orelse unreachable)));
pub fn constSlice(self: Attr) []const dt.ZigType() {
const ptr: [*]const dt.ZigType() = @constCast(@ptrCast(@alignCast(c.mlirDenseElementsAttrGetRawData(self.inner()) orelse unreachable)));
return ptr[0..self.len()];
}
@ -814,6 +899,7 @@ pub const Operation = struct {
pub fn make(ctx: Context, op_name: [:0]const u8, args: struct {
operands: ?[]const Value = null,
variadic_operands: ?[]const []const Value = null,
tt_variadic_operands: ?[]const []const Value = null,
results: ?[]const Type = null,
variadic_results: ?[]const []const Type = null,
result_type_inference: ?bool = null,
@ -828,9 +914,24 @@ pub const Operation = struct {
if (args.operands) |operands| {
state.addOperands(operands);
} else if (args.variadic_operands) |operands_segments| {
const MAX_SEGMENTS = 32;
var segments: std.BoundedArray(i32, MAX_SEGMENTS) = .{};
for (operands_segments) |operands| {
state.addOperands(operands);
segments.appendAssumeCapacity(@intCast(operands.len));
}
state.addAttribute(ctx, "operandSegmentSizes", .denseElements(ctx, &.{@intCast(segments.len)}, .i32, segments.constSlice()));
} else if (args.tt_variadic_operands) |operands_segments| {
// stablehlo and triton seems to disagree on the expected type of operandSegmentSizes, let's fix that.
const MAX_SEGMENTS = 32;
var segments: std.BoundedArray(i32, MAX_SEGMENTS) = .{};
for (operands_segments) |operands| {
state.addOperands(operands);
segments.appendAssumeCapacity(@intCast(operands.len));
}
state.addAttribute(ctx, "operandSegmentSizes", .dense(ctx, .i32, segments.constSlice()));
}
if (args.result_type_inference) |enable| {
state.resultTypeInference(enable);
@ -1076,6 +1177,26 @@ pub const Operation = struct {
pub fn removeAttributeByName(self: Self, name_: [:0]const u8) bool {
return c.mlirOperationRemoveAttributeByName(self.inner(), stringRef(name_));
}
/// Folds a hash of this operation's textual IR (printed without debug info,
/// so it is stable across builds) into `hasher`.
pub fn hash(op: Operation, hasher: *std.hash.XxHash64) void {
const NoError = error{};
// Adapter turning the hasher into a std.io.Writer so it can be fed to print().
const write = struct {
fn write(hasher_: *std.hash.XxHash64, bytes: []const u8) NoError!usize {
hasher_.update(bytes);
return bytes.len;
}
}.write;
const HashWriter = std.io.Writer(*std.hash.XxHash64, NoError, write);
const writer: HashWriter = .{ .context = hasher };
// Hash the canonicalized IR, without debug information that can change across builds.
// Note: before, we were using op.writeBytecode(writer),
// but it crashes on some inputs, notably for unused variables.
// So we use the text representation of the mlir.
// See https://github.com/zml/zml/issues/97.
// Writes can't fail because we are writing to a hasher.
op.print(writer, .{ .debug_info = false });
}
};
pub const OpPrintingFlags = struct {
@ -1255,6 +1376,41 @@ pub const Type = struct {
c.mlirTypeParseGet(ctx.inner(), stringRef(str)),
) orelse Error.InvalidMlir;
}
/// Builds the MLIR `index` type for `ctx`.
pub fn index(ctx: Context) Type {
    const index_type = IndexType.init(ctx);
    return index_type.as(Type);
}
/// Builds an MLIR integer type from the `int_type` tag.
/// Panics on `.unknown`, which carries no concrete width.
pub fn int(ctx: Context, int_type: IntegerTypes) Type {
return switch (int_type) {
.unknown => @panic("Unknown integer type"),
// `inline else` instantiates IntegerType(t) at comptime for each remaining tag.
inline else => |t| IntegerType(t).init(ctx).as(Type),
};
}
/// Builds an MLIR float type from the `float_type` tag.
pub fn float(ctx: Context, float_type: FloatTypes) Type {
return switch (float_type) {
// `inline else` instantiates FloatType(t) at comptime for each tag.
inline else => |t| FloatType(t).init(ctx).as(Type),
};
}
/// Builds an MLIR complex type from the `complex_type` tag.
pub fn complex(ctx: Context, complex_type: ComplexTypes) Type {
return switch (complex_type) {
// `inline else` instantiates ComplexType(t) at comptime for each tag.
inline else => |t| ComplexType(t).init(ctx).as(Type),
};
}
/// Builds an MLIR tuple type from `types`.
/// NOTE(review): TupleType.init failures are asserted away with
/// `catch unreachable` — confirm it cannot fail for valid inputs.
pub fn tuple(ctx: Context, types: []const Type) Type {
return (TupleType.init(ctx, types) catch unreachable).as(Type);
}
/// Builds an MLIR function type mapping `args` to `results`.
/// NOTE(review): FunctionType.init failures are asserted away with
/// `catch unreachable` — confirm it cannot fail for valid inputs.
pub fn function(ctx: Context, args: []const Type, results: []const Type) Type {
return (FunctionType.init(ctx, args, results) catch unreachable).as(Type);
}
/// Builds a ranked tensor type with the given `dimensions` and element type.
pub fn tensor(dimensions: []const i64, elem_type: Type) Type {
    const ranked = RankedTensorType.init(dimensions, elem_type);
    return ranked.as(Type);
}
};
pub const IndexType = struct {
@ -1680,7 +1836,9 @@ pub const Block = struct {
}
/// Returns the `index`-th argument of this block.
/// Asserts with a descriptive message that `index` is in range instead of
/// silently wrapping a null MLIR value.
pub fn argument(self: Block, index: usize) Value {
    const arg = c.mlirBlockGetArgument(self.inner(), @intCast(index));
    stdx.debug.assert(!Value.Methods.is_null_fn.?(arg), "Block doesn't have argument {}, only got {}", .{ index, self.numArguments() });
    return Value.wrap(arg);
}
pub fn numArguments(self: Block) usize {
@ -1708,4 +1866,29 @@ pub const Block = struct {
c.mlirBlockAppendOwnedOperation(self.inner(), op.inner());
}
}
pub const RecursiveOpts = enum { open, hermetic };
/// Recursively appends the producer of `value` (and, transitively, the
/// producers of its operands) to this block.
/// - `.op_result`: walk to the defining op and append it recursively.
/// - `.block_argument`: nothing to append; in `.hermetic` mode the argument
///   must already belong to this block.
/// - `.null`: invalid IR — panic.
pub fn appendValueRecursive(self: Block, value: Value, opt: RecursiveOpts) void {
switch (value.kind()) {
.op_result => |parent_op| self.appendOperationRecursive(parent_op, opt),
.block_argument => |arg| {
// Hermetic blocks are not allowed to use arguments from other blocks.
stdx.debug.assert(opt == .open or self.eql(arg.block()), "Can't add {} from {?x} block to {?x} block", .{ arg, arg.block()._inner.ptr, self._inner.ptr });
},
.null => @panic("InvalidMlir"),
}
}
/// Appends `op` (and, transitively, the defining ops of its operands) to
/// this block. An op that already belongs to a block is left where it is.
pub fn appendOperationRecursive(self: Block, op: Operation, opt: RecursiveOpts) void {
if (op.block()) |prev_block| {
// Hermetic blocks are not allowed to reference values from other blocks.
// NOTE(review): this calls `self.equals(...)` while appendValueRecursive
// uses `self.eql(...)` for the same kind of check — confirm both methods
// exist on Block, or unify on one name.
stdx.debug.assert(opt == .open or self.equals(prev_block), "Can't add {} from {?x} block to {?x} block", .{ op, prev_block._inner.ptr, self._inner.ptr });
return;
}
// Operands must be materialized in the block before the op that uses them.
for (0..op.numOperands()) |i| {
self.appendValueRecursive(op.operand(i), opt);
}
self.appendOperation(op);
}
};

View File

@ -6,15 +6,14 @@ const runfiles = @import("runfiles");
const stdx = @import("stdx");
const xla_pb = @import("//xla:xla_proto");
const meta = @import("meta.zig");
const mlir = @import("mlir.zig");
const ops = @import("ops.zig");
const pjrt = @import("pjrtx.zig");
const BaseExe = @import("exe.zig").BaseExe;
const Buffer = @import("buffer.zig").Buffer;
const Bufferized = @import("tensor.zig").Bufferized;
const meta = @import("meta.zig");
const mlir = @import("mlir.zig");
const Location = mlir.Location;
const ops = @import("ops.zig");
const pjrt = @import("pjrtx.zig");
const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape;
const ShapeOf = @import("tensor.zig").ShapeOf;
@ -28,46 +27,6 @@ test {
std.testing.refAllDecls(@This());
}
pub const BlockKind = enum { open, hermetic };
const Block = union(BlockKind) {
open: mlir.Block,
hermetic: mlir.Block,
pub fn block(self: Block) mlir.Block {
return switch (self) {
inline .open, .hermetic => |t| t,
};
}
fn appendTensorRecursive(self: Block, x: *const Tensor) void {
self.appendValueRecursive(x.value());
}
fn appendValueRecursive(self: Block, value: mlir.Value) void {
switch (value.kind()) {
.op_result => |parent_op| self.appendOperationRecursive(parent_op),
.block_argument => |arg| {
// Hermetic blocks are not allowed to use arguments from other blocks.
stdx.debug.assert(self == .open or self.block().eql(arg.block()), "Can't add {} from {?x} block to {?x} block", .{ arg, arg.block()._inner.ptr, self.block()._inner.ptr });
},
.null => @panic("InvalidMlir"),
}
}
fn appendOperationRecursive(self: Block, op: mlir.Operation) void {
if (op.block()) |prev_block| {
// Hermetic blocks are not allowed to reference values from other blocks.
std.debug.assert(self == .open or prev_block.equals(self.block()));
return;
}
for (0..op.numOperands()) |i| {
self.appendValueRecursive(op.operand(i));
}
self.block().appendOperation(op);
}
};
pub const MlirFn = struct {
name: []const u8,
num_args: u32,
@ -94,7 +53,7 @@ pub const CompilationContext = struct {
_module: mlir.Module,
_blocks: std.BoundedArray(Block, 64) = .{},
_blocks: std.BoundedArray(TaggedBlock, 64) = .{},
_fn_cache: FnCache = .{},
_block_args: TensorToBlockArg = .{},
@ -104,6 +63,7 @@ pub const CompilationContext = struct {
_previous: ?*CompilationContext = null,
threadlocal var _current: ?*CompilationContext = null;
const TaggedBlock = struct { mlir.Block, mlir.Block.RecursiveOpts };
const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation });
const AttributeList = std.BoundedArray(mlir.NamedAttribute, 3);
@ -123,7 +83,7 @@ pub const CompilationContext = struct {
const loc = mlir_ctx.location(@src()).named(mlir_ctx, "main");
const module = mlir.Module.init(loc);
module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, "zml").as(mlir.Attribute));
module.op().setAttributeByName("sym_name", .string(mlir_ctx, "zml"));
var canonicalizer = try mlir.PassManager.init(mlir_ctx);
{
@ -280,26 +240,22 @@ pub const CompilationContext = struct {
);
}
fn currentBlock(self: *const CompilationContext) ?Block {
fn currentBlock(self: *const CompilationContext) ?TaggedBlock {
return if (self._blocks.len > 0) self._blocks.get(self._blocks.len - 1) else null;
}
pub fn openBlock(self: *CompilationContext, kind: BlockKind, args: []const mlir.Type, locs: []const mlir.Location) !Block {
const mlir_block = try mlir.Block.init(args, locs);
const block: Block = switch (kind) {
.open => .{ .open = mlir_block },
.hermetic => .{ .hermetic = mlir_block },
};
pub fn openBlock(self: *CompilationContext, kind: mlir.Block.RecursiveOpts, args: []const mlir.Type, locs: []const mlir.Location) !TaggedBlock {
const block: TaggedBlock = .{ try mlir.Block.init(args, locs), kind };
self.pushBlock(block);
return block;
}
pub fn closeBlock(self: *CompilationContext, block: Block) void {
pub fn closeBlock(self: *CompilationContext, block: TaggedBlock) void {
const popped = self._blocks.pop();
std.debug.assert(block.block().eql(popped.?.block()));
std.debug.assert(block[0].eql(popped.?[0]));
}
fn pushBlock(self: *CompilationContext, block: Block) void {
fn pushBlock(self: *CompilationContext, block: TaggedBlock) void {
self._blocks.appendAssumeCapacity(block);
}
@ -311,7 +267,7 @@ pub const CompilationContext = struct {
/// But their shapes/tags can be safely propagated further.
pub fn makeBlock(
self: *CompilationContext,
kind: BlockKind,
kind: mlir.Block.RecursiveOpts,
comptime S: ops.BlockSignature,
func: *const S.Fn,
blkctx: S.BlkCtx,
@ -326,7 +282,7 @@ pub const CompilationContext = struct {
// Before creating a new block, assign all received values to previous block,
// otherwise they will be assign to this block
if (self.currentBlock()) |prev_block| {
meta.visit(Block.appendTensorRecursive, prev_block, &blkctx);
meta.visit(_appendTensorRecursive, prev_block, &blkctx);
}
const block = self.openBlock(kind, &input_types, &locations) catch unreachable;
@ -337,15 +293,20 @@ pub const CompilationContext = struct {
// So we create a copy of the arguments, and replace values
// by the block arguments.
var blk_args = args;
std.debug.assert(assignBlockArguments(&blk_args, block.block(), 0) == N);
std.debug.assert(assignBlockArguments(&blk_args, block[0], 0) == N);
const block_res = @call(.auto, func, S.blkArgs(blkctx, blk_args));
var block_res_values: [S.nOut]mlir.Value = undefined;
self.extractValues(&block_res, &block_res_values);
const block_ret = dialect.stablehlo.returns_(self.mlirCtx(), &block_res_values, loc);
block.appendOperationRecursive(block_ret);
block[0].appendOperationRecursive(block_ret, block[1]);
return .{ block.block(), block_res };
return .{ block[0], block_res };
}
/// Adapter for meta.visit: appends the tensor's producing value to the
/// tagged block, honoring the block's open/hermetic recursion mode.
fn _appendTensorRecursive(tagged_block: TaggedBlock, x: *const Tensor) void {
    tagged_block[0].appendValueRecursive(x.value(), tagged_block[1]);
}
pub const EmitMlirOpts = struct {
@ -411,7 +372,7 @@ pub const CompilationContext = struct {
defer self.closeBlock(fn_body);
try self._block_args.ensureUnusedCapacity(self.allocator(), @intCast(tensor_count));
const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), 0);
const assigned_args_count = self.mapBlockArguments(args, fn_body[0], 0);
std.debug.assert(assigned_args_count == tensor_count);
fn_res.* = forward: {
@ -424,7 +385,7 @@ pub const CompilationContext = struct {
self.extractValuesAndTypes(fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations);
const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc);
fn_body.appendOperationRecursive(fn_ret);
fn_body[0].appendOperationRecursive(fn_ret, fn_body[1]);
}
const arg_attrs = try arena.alloc(AttributeList, tensor_count);
@ -446,7 +407,7 @@ pub const CompilationContext = struct {
.arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs),
.results = fn_res_types,
.res_attrs = try finalizeAttributeList(arena, mlir_ctx, res_attrs),
.block = fn_body.block(),
.block = fn_body[0],
.location = loc,
});
@ -475,6 +436,7 @@ pub const CompilationContext = struct {
/// Given a list of donations mapping output buffers to input buffers,
/// generate donation attribute for each `n_args` input argument.
fn addDonationsAttributes(self: CompilationContext, attributes: []AttributeList, donations: []const Tensor._Donation) void {
const ctx = self.mlirCtx();
var n_donations: usize = 0;
for (donations, 0..) |donation, index| {
switch (donation) {
@ -489,12 +451,7 @@ pub const CompilationContext = struct {
// When the time come, do a more fancy lookup here to check if an argument
// is donated twice.
stdx.debug.assert(attributes[a].len == 0, "Donation error ! Argument {} has been donated twice ! To {} and to {}", .{ a, index, attributes[a].buffer[0] });
attributes[a].appendAssumeCapacity(
mlir.NamedAttribute.init(
mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"),
mlir.IntegerAttribute(.i32).init(self.mlirCtx(), @intCast(index)).as(mlir.Attribute),
),
);
attributes[a].appendAssumeCapacity(.named(ctx, "tf.aliasing_output", .int(ctx, .i32, @intCast(index))));
// log.debug("attribute: {}", .{attributes[a].constSlice()});
},
}
@ -552,46 +509,29 @@ pub const CompilationContext = struct {
}
}
pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.StringAttribute {
const mlir_ctx = self.mlirCtx();
const num_partitions = self._platform.sharding().num_partitions;
var sharding_str: std.BoundedArray(u8, 128) = .{};
writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable;
return mlir.StringAttribute.init(mlir_ctx, sharding_str.constSlice());
}
fn addShardingAttributes(self: CompilationContext, arg_attrs: []AttributeList, res_attrs: []AttributeList, input_shapes: []const Shape, output_shapes: []const Shape) void {
const mlir_ctx = self.mlirCtx();
const ctx = self.mlirCtx();
if (!self._platform.compilation_options.sharding_enabled) return;
const mhlo_default_layout = mlir.NamedAttribute.init(
mlir.Identifier.get(mlir_ctx, "mhlo.layout_mode"),
mlir.StringAttribute.init(mlir_ctx, "default").asAttr(),
);
const default_layout = mlir.NamedAttribute.named(ctx, "mhlo.layout_mode", .string(ctx, "default"));
for (arg_attrs, input_shapes) |*attr, shape| {
attr.appendAssumeCapacity(mhlo_default_layout);
const sharding_attr = self.getShardingAttr(shape);
attr.appendAssumeCapacity(mlir.NamedAttribute.init(
mlir.Identifier.get(mlir_ctx, "mhlo.sharding"),
sharding_attr.asAttr(),
));
attr.appendAssumeCapacity(default_layout);
attr.appendAssumeCapacity(.named(ctx, "mhlo.sharding", self.getShardingAttr(shape)));
}
for (res_attrs, output_shapes) |*attr, shape| {
attr.appendAssumeCapacity(mhlo_default_layout);
const sharding_attr = self.getShardingAttr(shape);
attr.appendAssumeCapacity(mlir.NamedAttribute.init(
mlir.Identifier.get(mlir_ctx, "mhlo.sharding"),
sharding_attr.asAttr(),
));
attr.appendAssumeCapacity(default_layout);
attr.appendAssumeCapacity(.named(ctx, "mhlo.sharding", self.getShardingAttr(shape)));
}
}
/// Builds a string attribute describing the sharding of `shape` for the
/// platform's current partition count.
pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute {
const ctx = self.mlirCtx();
const num_partitions = self._platform.sharding().num_partitions;
// NOTE(review): the representation is written into a 128-byte BoundedArray
// and a write failure is `catch unreachable` — confirm 128 bytes cannot
// overflow for high-rank shapes.
var sharding_str: std.BoundedArray(u8, 128) = .{};
writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable;
return mlir.Attribute.string(ctx, sharding_str.constSlice());
}
fn writeShardingRepresentation(shape: Shape, num_partitions: u8, writer: anytype) @TypeOf(writer).Error!void {
const n_sharded: u8 = @popCount(@as(u8, @bitCast(shape._sharding_info)));
if (n_sharded == 0 or num_partitions == 1) {
@ -819,20 +759,11 @@ pub const CompilationContext = struct {
fn computeModuleHash(platform: Platform, module: mlir.Module) u64 {
var hasher = std.hash.XxHash64.init(0);
var hasher_writer = xxHash64Writer(&hasher);
const writer = hasher_writer.writer();
module.hash(&hasher);
// Hash the canonicalized IR, without debug information that can change across builds.
module.op().print(writer, .{ .debug_info = false });
// Note: before we were using module.op().writeBytecode(writer),
// but it crashes on some inputs, notably for unused variables.
// So we use the text representation of the mlir.
// See https://github.com/zml/zml/issues/97.
// Writes can't fail because we are writing to a hasher.
writer.writeAll(platform.pjrt_client.getPlatformName(platform.pjrt_api)) catch unreachable;
hasher.update(platform.pjrt_client.getPlatformName(platform.pjrt_api));
const api_version = platform.pjrt_api.version();
writer.writeInt(i64, api_version.major, .little) catch unreachable;
writer.writeInt(i64, api_version.minor, .little) catch unreachable;
hasher.update(std.mem.sliceAsBytes(&[_]i64{ api_version.major, api_version.minor }));
return hasher.final();
}
@ -1002,26 +933,6 @@ fn assignBlockArguments(v: anytype, block: mlir.Block, start: usize) usize {
return context.index;
}
pub const XxHash64Writer = struct {
hasher: *std.hash.XxHash64,
pub const Error = error{};
pub const Writer = std.io.Writer(*XxHash64Writer, Error, write);
pub fn writer(self: *XxHash64Writer) Writer {
return .{ .context = self };
}
pub fn write(self: *XxHash64Writer, bytes: []const u8) Error!usize {
self.hasher.update(bytes);
return bytes.len;
}
};
pub fn xxHash64Writer(hasher: *std.hash.XxHash64) XxHash64Writer {
return .{ .hasher = hasher };
}
pub const FnCache = std.AutoHashMapUnmanaged(FnKey, MlirFn);
pub const FnKey = struct { fn_ptr: *const anyopaque, input_hash: u64 };
@ -1165,12 +1076,11 @@ pub fn hashShape(hasher: *std.hash.Wyhash, shape: Shape) void {
}
}
const HashStrategy = std.hash.Strategy;
const tensorAwareHash = hash; // alias for when "hash" is ambiguous
/// Provides generic hashing for any eligible type.
/// Strategy is provided to determine if pointers should be followed or not.
pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: HashStrategy) void {
pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: std.hash.Strategy) void {
const Key = @TypeOf(key);
if (Key == Tensor) return hashShape(hasher, key.shape());
if (Key == Shape) return hashShape(hasher, key);
@ -1287,7 +1197,7 @@ pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: HashStrategy
}
}
fn hashArray(hasher: anytype, key: anytype, comptime strat: HashStrategy) void {
fn hashArray(hasher: anytype, key: anytype, comptime strat: std.hash.Strategy) void {
for (key) |element| {
hash(hasher, element, strat);
}

View File

@ -131,7 +131,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
},
&.{
mlir.ext.mlirType(mlir_ctx, q.shape()),
mlir.RankedTensorType.init(&.{0}, mlir.IntegerType(.u8).init(mlir_ctx).as(mlir.Type)).as(mlir.Type),
.tensor(&.{0}, .int(mlir_ctx, .u8)),
},
loc,
);

View File

@ -147,9 +147,7 @@ pub fn reduce(
.variadic_operands = &.{ &input_values, &init_values },
.result_type_inference = true,
.blocks = &.{body_block},
.attributes = &.{
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), axes).as(mlir.Attribute) },
},
.attributes = &.{.{ "dimensions", .dense(ctx.mlirCtx(), .i64, axes) }},
// We can't verify right away, cause the weights captured by the reduce haven't been added yet.
.verify = false,
.location = loc,
@ -236,21 +234,16 @@ pub fn reduceWindow(
ctx.extractValues(&inits, &init_values);
const loc = ctx.mlirCtx().location(@src());
const pad_shape = mlir.RankedTensorType.init(
&.{ @intCast(opts.padding.len), 2 },
mlir.ext.Type.fromDType(ctx.mlirCtx(), .i64),
).as(mlir.Type);
const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.reduce_window", .{
.variadic_operands = &.{ input_values[0..], init_values[0..] },
.result_type_inference = true,
.blocks = &.{body_block},
.attributes = &.{
.{ "window_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dimensions).as(mlir.Attribute) },
.{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_strides).as(mlir.Attribute) },
.{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.base_dilations).as(mlir.Attribute) },
.{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dilations).as(mlir.Attribute) },
.{ "padding", mlir.DenseElementsAttribute(.i64).init(pad_shape, opts.padding).as(mlir.Attribute) },
.{ "window_dimensions", .dense(ctx.mlirCtx(), .i64, opts.window_dimensions) },
.{ "window_strides", .dense(ctx.mlirCtx(), .i64, opts.window_strides) },
.{ "base_dilations", .dense(ctx.mlirCtx(), .i64, opts.base_dilations) },
.{ "window_dilations", .dense(ctx.mlirCtx(), .i64, opts.window_dilations) },
.{ "padding", .denseElements(ctx.mlirCtx(), &.{ @intCast(opts.padding.len), 2 }, .i64, opts.padding) },
},
.location = loc,
});
@ -841,13 +834,13 @@ pub fn triton(inputs: anytype, outputs: anytype, opts: TritonOps) [outputs.len]T
}
const attrs = mlir.DictionaryAttribute.init(ctx.mlirCtx(), &.{
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "name"), mlir.StringAttribute.init(ctx.mlirCtx(), opts.name).as(mlir.Attribute)),
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "ir"), mlir.StringAttribute.init(ctx.mlirCtx(), opts.ir).as(mlir.Attribute)),
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "grid_x"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.grid[0])).as(mlir.Attribute)),
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "grid_y"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.grid[1])).as(mlir.Attribute)),
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "grid_z"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.grid[2])).as(mlir.Attribute)),
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "num_stages"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.num_stages)).as(mlir.Attribute)),
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "num_warps"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.num_warps)).as(mlir.Attribute)),
.named(ctx.mlirCtx(), "name", .string(ctx.mlirCtx(), opts.name)),
.named(ctx.mlirCtx(), "ir", .string(ctx.mlirCtx(), opts.ir)),
.named(ctx.mlirCtx(), "grid_x", .int(ctx.mlirCtx(), .i32, opts.grid[0])),
.named(ctx.mlirCtx(), "grid_y", .int(ctx.mlirCtx(), .i32, opts.grid[1])),
.named(ctx.mlirCtx(), "grid_z", .int(ctx.mlirCtx(), .i32, opts.grid[2])),
.named(ctx.mlirCtx(), "num_stages", .int(ctx.mlirCtx(), .i32, opts.num_stages)),
.named(ctx.mlirCtx(), "num_warps", .int(ctx.mlirCtx(), .i32, opts.num_warps)),
});
const MINOR_TO_MAJOR = blk: {

View File

@ -1,10 +1,11 @@
const builtin = @import("builtin");
const std = @import("std");
const testing = std.testing;
const builtin = @import("builtin");
const stdx = @import("stdx");
const testing = std.testing;
const DataType = @import("dtype.zig").DataType;
const EnumLiteral = @TypeOf(.enum_literal);
const log = std.log.scoped(.shape);
@ -131,6 +132,10 @@ pub const Shape = struct {
return res;
}
/// A rank-0 (scalar) shape of the given data type.
pub fn scalar(dt: DataType) Shape {
    const res: Shape = .{ ._dtype = dt };
    return res;
}
/// Creates a Shape with dims set to `.{0, 1, 2, ..., rank-1}`.
pub fn range(rank_: usize, dt: DataType) Shape {
var res: Shape = .{ ._dtype = dt };

View File

@ -1,28 +1,29 @@
const builtin = @import("builtin");
const std = @import("std");
const assert = std.debug.assert;
const testing = std.testing;
const builtin = @import("builtin");
const stdx = @import("stdx");
const meta = @import("meta.zig");
const mlir = @import("mlir.zig");
const ops = @import("ops.zig");
const module = @import("module.zig");
const Location = mlir.Location;
const CompilationContext = module.CompilationContext;
const Shape = @import("shape.zig").Shape;
const Buffer = @import("buffer.zig").Buffer;
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const Data = @import("dtype.zig").Data;
const DataType = @import("dtype.zig").DataType;
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const meta = @import("meta.zig");
const mlir = @import("mlir.zig");
const Location = mlir.Location;
const module = @import("module.zig");
const CompilationContext = module.CompilationContext;
const ops = @import("ops.zig");
const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape;
const EnumLiteral = @TypeOf(.enum_literal);
const dialect = struct {
const stablehlo = @import("mlir/dialects").stablehlo;
};
const assert = std.debug.assert;
const testing = std.testing;
const scoped_log = std.log.scoped(.@"zml/tensor");
test {
@ -173,15 +174,13 @@ pub const Tensor = struct {
var res = self;
res._shape = self._shape.withSharding(axes_);
const sharding = ctx.getShardingAttr(res._shape);
const op = dialect.stablehlo.custom_call(
ctx.mlirCtx(),
&.{self.value()},
.{
.call_target_name = "Sharding",
.has_side_effect = false,
.addional_attributes = &.{.{ "mhlo.sharding", sharding.asAttr() }},
.addional_attributes = &.{.{ "mhlo.sharding", ctx.getShardingAttr(res._shape) }},
.api_version = .original,
},
&.{self.value().getType()},
@ -1014,11 +1013,11 @@ pub const Tensor = struct {
if (to == self.dtype()) {
return self;
}
const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), to)).as(mlir.Type);
const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) });
const op = dialect.stablehlo.convert(self.getContext().mlirCtx(), self.value(), res_type, loc);
const mlir_ctx = self.getContext().mlirCtx();
const res_type = mlir.ext.mlirType(mlir_ctx, self.shape().withDtype(to));
const op = dialect.stablehlo.convert(mlir_ctx, self.value(), res_type, loc);
return _result(self._shape.withDtype(to), op.result(0));
}
@ -1455,7 +1454,7 @@ pub const Tensor = struct {
/// .{ .a, .b, .c }.mergeTranspose(.{ .a, .c }, .ac) -> .{ .b, .ac }
/// Transposes so that `axes_` end up contiguous, then merges them into the
/// single axis `merged`.
pub fn mergeTranspose(self: Tensor, axes_: anytype, merged: EnumLiteral) Tensor {
    const cont = self.contiguous(axes_);
    // NOTE(review): the block previously contained two consecutive returns
    // differing only in argument order; the later `(merged, axes_)` order is
    // kept — confirm it matches mergeAxis's signature.
    return cont.reshape(cont._shape.mergeAxis(merged, axes_));
}
/// Transposes the input Tensor, such has the given axes end up in contiguous position.
@ -3107,7 +3106,7 @@ pub const Tensor = struct {
var start_indices = [_]mlir.Value{constant(.{}, slice_.start.dtype().zero()).value()} ** MAX_RANK;
start_indices[a] = slice_.start.value();
const op = dialect.stablehlo.dynamicSlice(
const op = dialect.stablehlo.dynamic_slice(
self.getContext().mlirCtx(),
self.value(),
new_shape.dims(),
@ -3168,7 +3167,7 @@ pub const Tensor = struct {
res_shape._dims.set(a, len);
}
}
const op = dialect.stablehlo.dynamicSlice(self.getContext().mlirCtx(), self.value(), res_shape.dims(), offset_values[0..self.rank()], loc);
const op = dialect.stablehlo.dynamic_slice(self.getContext().mlirCtx(), self.value(), res_shape.dims(), offset_values[0..self.rank()], loc);
return _result(res_shape, op.result(0));
}

View File

@ -25,6 +25,7 @@ pub const nn = @import("nn.zig");
pub const module = @import("module.zig");
pub const meta = @import("meta.zig");
pub const platform = @import("platform.zig");
pub const mlir = @import("mlir.zig");
pub const pjrt = @import("pjrtx.zig");
pub const testing = @import("testing.zig");
pub const torch = @import("torch.zig");