Add a mini-DSL for creating common MLIR attributes and types, leveraging Zig 0.14 to simplify mlir.Type and mlir.Attribute creation
This commit is contained in:
parent
63ef78efcc
commit
ac63c30e12
@ -9,6 +9,7 @@ cc_library(
|
||||
"@llvm-project//mlir:CAPIArith",
|
||||
"@llvm-project//mlir:CAPIIR",
|
||||
"@llvm-project//mlir:CAPIMath",
|
||||
"@llvm-project//mlir:CAPISCF",
|
||||
"@llvm-project//mlir:CAPITransforms",
|
||||
],
|
||||
)
|
||||
|
||||
1
mlir/c.h
1
mlir/c.h
@ -3,6 +3,7 @@
|
||||
#include <mlir-c/Dialect/Arith.h>
|
||||
#include <mlir-c/Dialect/Func.h>
|
||||
#include <mlir-c/Dialect/Math.h>
|
||||
#include <mlir-c/Dialect/SCF.h>
|
||||
#include <mlir-c/IR.h>
|
||||
#include <mlir-c/Pass.h>
|
||||
#include <mlir-c/Transforms.h>
|
||||
|
||||
@ -7,6 +7,7 @@ zig_library(
|
||||
"arith.zig",
|
||||
"func.zig",
|
||||
"math.zig",
|
||||
"scf.zig",
|
||||
"tensor.zig",
|
||||
],
|
||||
import_name = "mlir/dialects",
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
const std = @import("std");
|
||||
|
||||
const mlir = @import("mlir");
|
||||
|
||||
pub fn constant(ctx: mlir.Context, value: mlir.Attribute, location: mlir.Location) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, "arith.constant", .{
|
||||
.attributes = &.{
|
||||
.{ "value", value },
|
||||
},
|
||||
.attributes = &.{.{ "value", value }},
|
||||
.result_type_inference = true,
|
||||
.location = location,
|
||||
});
|
||||
@ -44,6 +43,8 @@ pub const mulf = binary_fn("arith.mulf");
|
||||
pub const divsi = binary_fn("arith.divsi");
|
||||
pub const divui = binary_fn("arith.divui");
|
||||
pub const divf = binary_fn("arith.divf");
|
||||
pub const maxnumf = binary_fn("arith.maxnumf");
|
||||
pub const maxnumi = binary_fn("arith.maxnumi");
|
||||
pub const extsi = cast_fn("arith.extsi");
|
||||
pub const extui = cast_fn("arith.extui");
|
||||
pub const extf = cast_fn("arith.extf");
|
||||
|
||||
@ -3,6 +3,7 @@ const std = @import("std");
|
||||
pub const arith = @import("arith.zig");
|
||||
pub const func = @import("func.zig");
|
||||
pub const math = @import("math.zig");
|
||||
pub const scf = @import("scf.zig");
|
||||
pub const tensor = @import("tensor.zig");
|
||||
pub const stablehlo = @import("mlir/dialects/stablehlo");
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
const std = @import("std");
|
||||
|
||||
const mlir = @import("mlir");
|
||||
|
||||
pub fn func(
|
||||
@ -14,14 +15,14 @@ pub fn func(
|
||||
},
|
||||
) mlir.Operation {
|
||||
var attrs_tuple_buffer = std.BoundedArray(mlir.AttrTuple, 4){};
|
||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute) });
|
||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type)).as(mlir.Attribute) });
|
||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", .string(ctx, args.sym_name) });
|
||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", .type_(.function(ctx, args.args, args.results)) });
|
||||
if (args.arg_attrs.len > 0) {
|
||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", mlir.ArrayAttribute.init(ctx, args.arg_attrs).as(mlir.Attribute) });
|
||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", .array(ctx, args.arg_attrs) });
|
||||
}
|
||||
|
||||
if (args.res_attrs.len > 0) {
|
||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "res_attrs", mlir.ArrayAttribute.init(ctx, args.res_attrs).as(mlir.Attribute) });
|
||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "res_attrs", .array(ctx, args.res_attrs) });
|
||||
}
|
||||
|
||||
return mlir.Operation.make(ctx, "func.func", .{
|
||||
@ -36,7 +37,7 @@ pub fn call(ctx: mlir.Context, name: [:0]const u8, values: []const mlir.Value, r
|
||||
.variadic_operands = &.{values},
|
||||
.results = results,
|
||||
.verify = true,
|
||||
.attributes = &.{.{ "callee", mlir.FlatSymbolRefAttribute.init(ctx, name).as(mlir.Attribute) }},
|
||||
.attributes = &.{.{ "callee", .symbol(ctx, name) }},
|
||||
.location = loc,
|
||||
});
|
||||
}
|
||||
|
||||
@ -8,7 +8,7 @@ fn unary_fn(comptime op_name: [:0]const u8) type {
|
||||
pub fn call(ctx: mlir.Context, value: mlir.Value, location: mlir.Location) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, namespace ++ "." ++ op_name, .{
|
||||
.operands = &.{value},
|
||||
.results = &.{},
|
||||
.results = &.{value.getType()},
|
||||
.location = location,
|
||||
});
|
||||
}
|
||||
@ -32,4 +32,5 @@ pub const fpowi = binary_fn("fpowi").call;
|
||||
pub const tanh = unary_fn("tanh").call;
|
||||
pub const sqrt = unary_fn("sqrt").call;
|
||||
pub const exp = unary_fn("exp").call;
|
||||
pub const exp2 = unary_fn("exp2").call;
|
||||
pub const log = unary_fn("log").call;
|
||||
|
||||
61
mlir/dialects/scf.zig
Normal file
61
mlir/dialects/scf.zig
Normal file
@ -0,0 +1,61 @@
|
||||
const std = @import("std");
|
||||
|
||||
const mlir = @import("mlir");
|
||||
|
||||
/// Returns the function type of a loop-body callback as accepted by `@"for"`.
/// The callback receives the context, the block it should populate and a
/// caller-supplied payload of type `ExtraArgs`, and returns an operation.
pub fn ForBody(ExtraArgs: type) type {
    const Body = fn (mlir.Context, mlir.Block, ExtraArgs) mlir.Operation;
    return Body;
}
|
||||
|
||||
/// The three values describing the iteration space of an `scf.for` loop.
/// `@"for"` passes them, in this order, as the first operands of the op.
pub const ForRange = struct {
    /// First operand of `scf.for`; its type is also used for the induction variable.
    start: mlir.Value,
    /// Second operand of `scf.for`.
    end: mlir.Value,
    /// Third operand of `scf.for`.
    step: mlir.Value,
};
|
||||
|
||||
/// Builds an `scf.for` operation whose single block is filled in by `body`.
///
/// - `ExtraArgs`: type of the payload forwarded untouched to `body`.
/// - `range`: start/end/step values, passed as the first three operands.
/// - `init_values`: initial values of the loop-carried variables; their types
///   are also used as the result types of the returned operation.
/// - `body`: called once with `(ctx, block, extra_args)`; it must return the
///   `scf.yield` operation that terminates the block (asserted below).
/// - `loc`: location attached to the operation and to every block argument.
///
/// Panics when more than 31 loop-carried values are given, because the block
/// argument types and locations are staged in fixed 32-entry stack buffers.
pub fn @"for"(
    ExtraArgs: type,
    ctx: mlir.Context,
    range: ForRange,
    init_values: []const mlir.Value,
    body: ForBody(ExtraArgs),
    extra_args: ExtraArgs,
    loc: mlir.Location,
) mlir.Operation {
    // Capacity of the stack buffers: one induction variable + loop-carried values.
    const max_block_args = 32;

    // The first block argument is the for loop induction variable,
    // followed then by all the loop-carried variables.
    const n_block_args = init_values.len + 1;

    // Slicing a fixed buffer past its end is undefined behaviour in unchecked
    // builds, so reject oversized inputs explicitly in every build mode.
    if (n_block_args > max_block_args) {
        std.debug.panic(
            "scf.for supports at most {d} loop-carried values, got {d}",
            .{ max_block_args - 1, init_values.len },
        );
    }

    var init_types_buf: [max_block_args]mlir.Type = undefined;
    var locs_buf: [max_block_args]mlir.Location = undefined;
    const init_types = init_types_buf[0..n_block_args];
    const locs = locs_buf[0..n_block_args];

    init_types[0] = range.start.getType();
    locs[0] = loc;
    for (1.., init_values) |i, val| {
        init_types[i] = val.getType();
        locs[i] = loc;
    }

    const block = mlir.Block.init(init_types, locs) catch unreachable;

    // The callback populates the block and hands back its terminator.
    const yield_op = body(ctx, block, extra_args);
    std.debug.assert(std.mem.eql(u8, "scf.yield", yield_op.name().str()));
    block.appendOperationRecursive(yield_op, .open);

    // NOTE(review): verification is explicitly disabled for this op; the
    // reason is not visible here — confirm before enabling it.
    return mlir.Operation.make(ctx, "scf.for", .{
        .variadic_operands = &.{ &.{ range.start, range.end, range.step }, init_values },
        // Results mirror the loop-carried variables (induction variable excluded).
        .results = init_types[1..],
        .blocks = &.{block},
        .location = loc,
        .verify = false,
    });
}
|
||||
|
||||
/// Creates an `scf.yield` operation that takes `res` as its operands and
/// produces no results. The operation is built with verification disabled.
pub fn yield(ctx: mlir.Context, res: []const mlir.Value, loc: mlir.Location) mlir.Operation {
    const op_name = "scf.yield";
    return mlir.Operation.make(ctx, op_name, .{
        .location = loc,
        .verify = false,
        .results = &.{},
        .variadic_operands = &.{res},
    });
}
|
||||
@ -99,7 +99,7 @@ pub fn cholesky(ctx: mlir.Context, value: mlir.Value, lower: bool, location: mli
|
||||
.operands = &.{value},
|
||||
.result_type_inference = true,
|
||||
.attributes = &.{
|
||||
.{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(lower))).as(mlir.Attribute) },
|
||||
.{ "lower", .i1FromBool(ctx, lower) },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -189,7 +189,7 @@ pub fn dot_general(
|
||||
precision: DotPrecision,
|
||||
},
|
||||
) mlir.Operation {
|
||||
const precisions = [1]mlir.Attribute{opts.precision.precisionAttr(ctx)} ** 2;
|
||||
const precisions: [2]mlir.Attribute = @splat(opts.precision.precisionAttr(ctx));
|
||||
const attributes = [3]mlir.AttrTuple{
|
||||
.{
|
||||
"dot_dimension_numbers", DotDimensionNumbersAttribute.init(ctx, .{
|
||||
@ -199,7 +199,7 @@ pub fn dot_general(
|
||||
.rhs_contracting_dimensions = opts.rhs_contracting_dimensions,
|
||||
}).as(mlir.Attribute),
|
||||
},
|
||||
.{ "precision_config", mlir.ArrayAttribute.init(ctx, &precisions).asAttr() },
|
||||
.{ "precision_config", .array(ctx, &precisions) },
|
||||
// keep algorithm as the last attribute so we can omit it when it's not set.
|
||||
.{ "algorithm", opts.precision.algorithmAttr(ctx, lhs.getType()) orelse undefined },
|
||||
};
|
||||
@ -244,7 +244,7 @@ pub fn broadcast_in_dim(ctx: mlir.Context, operand: mlir.Value, dims: []const i6
|
||||
.operands = &.{operand},
|
||||
.results = &.{result_type},
|
||||
.attributes = &.{
|
||||
.{ "broadcast_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dims).as(mlir.Attribute) },
|
||||
.{ "broadcast_dimensions", .dense(ctx, .i64, dims) },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -255,7 +255,7 @@ pub fn transpose(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, l
|
||||
.operands = &.{value},
|
||||
.results = &.{result_type},
|
||||
.attributes = &.{
|
||||
.{ "permutation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.permutation).as(mlir.Attribute) },
|
||||
.{ "permutation", .dense(ctx, .i64, opts.permutation) },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -266,9 +266,9 @@ pub fn slice(ctx: mlir.Context, operand: mlir.Value, start_indices: []const i64,
|
||||
.operands = &.{operand},
|
||||
.results = &.{result_type},
|
||||
.attributes = &.{
|
||||
.{ "start_indices", mlir.DenseArrayAttribute(.i64).init(ctx, start_indices).as(mlir.Attribute) },
|
||||
.{ "limit_indices", mlir.DenseArrayAttribute(.i64).init(ctx, limit_indices).as(mlir.Attribute) },
|
||||
.{ "strides", mlir.DenseArrayAttribute(.i64).init(ctx, strides).as(mlir.Attribute) },
|
||||
.{ "start_indices", .dense(ctx, .i64, start_indices) },
|
||||
.{ "limit_indices", .dense(ctx, .i64, limit_indices) },
|
||||
.{ "strides", .dense(ctx, .i64, strides) },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -279,7 +279,7 @@ pub fn concatenate(ctx: mlir.Context, inputs: []const mlir.Value, dimension: i64
|
||||
.operands = inputs,
|
||||
.result_type_inference = true,
|
||||
.attributes = &.{
|
||||
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) },
|
||||
.{ "dimension", .int(ctx, .i64, dimension) },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -333,8 +333,8 @@ pub fn gather(
|
||||
args.start_index_map,
|
||||
args.index_vector_dim,
|
||||
).as(mlir.Attribute) },
|
||||
.{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, slice_sizes).as(mlir.Attribute) },
|
||||
.{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute) },
|
||||
.{ "slice_sizes", .dense(ctx, .i64, slice_sizes) },
|
||||
.{ "indices_are_sorted", .boolean(ctx, args.indices_are_sorted) },
|
||||
},
|
||||
.location = location,
|
||||
},
|
||||
@ -394,8 +394,8 @@ pub fn scatter(
|
||||
.blocks = &.{update_block},
|
||||
.attributes = &.{
|
||||
.{ "scatter_dimension_numbers", args.getScatterDimensionNumbers(ctx) },
|
||||
.{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute) },
|
||||
.{ "unique_indices", mlir.BoolAttribute.init(ctx, args.unique_indices).as(mlir.Attribute) },
|
||||
.{ "indices_are_sorted", .boolean(ctx, args.indices_are_sorted) },
|
||||
.{ "unique_indices", .boolean(ctx, args.unique_indices) },
|
||||
},
|
||||
.result_type_inference = true,
|
||||
.location = location,
|
||||
@ -408,7 +408,7 @@ pub fn iota(ctx: mlir.Context, dimension: i64, result_type: mlir.Type, location:
|
||||
.operands = &.{},
|
||||
.results = &.{result_type},
|
||||
.attributes = &.{
|
||||
.{ "iota_dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) },
|
||||
.{ "iota_dimension", .int(ctx, .i64, dimension) },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -420,7 +420,7 @@ pub fn reverse(ctx: mlir.Context, operand: mlir.Value, dimensions: []const i64,
|
||||
.operands = &.{operand},
|
||||
.results = &.{result_type},
|
||||
.attributes = &.{
|
||||
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute) },
|
||||
.{ "dimensions", .dense(ctx, .i64, dimensions) },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -453,7 +453,7 @@ pub fn reduce(
|
||||
const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0..block_n_args];
|
||||
var reduce_elem_types: [MaxBlockArguments]mlir.Type = undefined;
|
||||
for (inputs, 0..) |input, i| {
|
||||
const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type);
|
||||
const arg_type: mlir.Type = .tensor(&.{}, elementTypeOrSelf(input.getType()));
|
||||
reduce_elem_types[i] = arg_type;
|
||||
reduce_elem_types[inputs.len + i] = arg_type;
|
||||
}
|
||||
@ -475,7 +475,7 @@ pub fn reduce(
|
||||
.result_type_inference = true,
|
||||
.block = block,
|
||||
.attributes = &.{
|
||||
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute) },
|
||||
.{ "dimensions", .dense(ctx, .i64, dimensions) },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -495,7 +495,7 @@ pub fn sort(
|
||||
const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0 .. inputs.len * 2];
|
||||
var sort_elem_types: [MaxBlockArguments]mlir.Type = undefined;
|
||||
for (inputs, 0..) |input, i| {
|
||||
const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type);
|
||||
const arg_type: mlir.Type = .tensor(&.{}, elementTypeOrSelf(input.getType()));
|
||||
sort_elem_types[i * 2] = arg_type;
|
||||
sort_elem_types[i * 2 + 1] = arg_type;
|
||||
}
|
||||
@ -512,19 +512,19 @@ pub fn sort(
|
||||
.result_type_inference = true,
|
||||
.block = block,
|
||||
.attributes = &.{
|
||||
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) },
|
||||
.{ "is_stable", mlir.BoolAttribute.init(ctx, is_stable).as(mlir.Attribute) },
|
||||
.{ "dimension", .int(ctx, .i64, dimension) },
|
||||
.{ "is_stable", .boolean(ctx, is_stable) },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn dynamicSlice(ctx: mlir.Context, operand: mlir.Value, new_dims: []const i64, start_indices: []const mlir.Value, location: mlir.Location) mlir.Operation {
|
||||
pub fn dynamic_slice(ctx: mlir.Context, operand: mlir.Value, new_dims: []const i64, start_indices: []const mlir.Value, location: mlir.Location) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, "stablehlo.dynamic_slice", .{
|
||||
.variadic_operands = &.{ &.{operand}, start_indices },
|
||||
.result_type_inference = true,
|
||||
.attributes = &.{
|
||||
.{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, new_dims).as(mlir.Attribute) },
|
||||
.{ "slice_sizes", .dense(ctx, .i64, new_dims) },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -557,9 +557,9 @@ pub fn pad(ctx: mlir.Context, value: mlir.Value, padding_value: mlir.Value, opts
|
||||
.operands = &.{ value, padding_value },
|
||||
.result_type_inference = true,
|
||||
.attributes = &.{
|
||||
.{ "edge_padding_low", mlir.DenseArrayAttribute(.i64).init(ctx, opts.low).as(mlir.Attribute) },
|
||||
.{ "edge_padding_high", mlir.DenseArrayAttribute(.i64).init(ctx, opts.high).as(mlir.Attribute) },
|
||||
.{ "interior_padding", mlir.DenseArrayAttribute(.i64).init(ctx, opts.interior).as(mlir.Attribute) },
|
||||
.{ "edge_padding_low", .dense(ctx, .i64, opts.low) },
|
||||
.{ "edge_padding_high", .dense(ctx, .i64, opts.high) },
|
||||
.{ "interior_padding", .dense(ctx, .i64, opts.interior) },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -577,9 +577,9 @@ pub fn triangular_solve(ctx: mlir.Context, value: mlir.Value, other: mlir.Value,
|
||||
.operands = &.{ value, other },
|
||||
.result_type_inference = true,
|
||||
.attributes = &.{
|
||||
.{ "left_side", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.left_side))).as(mlir.Attribute) },
|
||||
.{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.lower))).as(mlir.Attribute) },
|
||||
.{ "unit_diagonal", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.unit_diagonal))).as(mlir.Attribute) },
|
||||
.{ "left_side", .i1FromBool(ctx, opts.left_side) },
|
||||
.{ "lower", .i1FromBool(ctx, opts.lower) },
|
||||
.{ "unit_diagonal", .i1FromBool(ctx, opts.unit_diagonal) },
|
||||
.{ "transpose_a", Transpose.init(ctx, opts.transpose_a).as(mlir.Attribute) },
|
||||
},
|
||||
.location = location,
|
||||
@ -597,7 +597,7 @@ pub fn fft(ctx: mlir.Context, value: mlir.Value, location: mlir.Location, opts:
|
||||
.result_type_inference = true,
|
||||
.attributes = &.{
|
||||
.{ "fft_type", FftType.init(ctx, opts.kind).as(mlir.Attribute) },
|
||||
.{ "fft_length", mlir.DenseArrayAttribute(.i64).init(ctx, opts.length).as(mlir.Attribute) },
|
||||
.{ "fft_length", .dense(ctx, .i64, opts.length) },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -630,8 +630,8 @@ pub fn reduce_precision(ctx: mlir.Context, value: mlir.Value, exponent_bits: i32
|
||||
.operands = &.{value},
|
||||
.result_type_inference = true,
|
||||
.attributes = &.{
|
||||
.{ "exponent_bits", mlir.IntegerAttribute(.i32).init(ctx, exponent_bits).as(mlir.Attribute) },
|
||||
.{ "mantissa_bits", mlir.IntegerAttribute(.i32).init(ctx, mantissa_bits).as(mlir.Attribute) },
|
||||
.{ "exponent_bits", .int(ctx, .i32, exponent_bits) },
|
||||
.{ "mantissa_bits", .int(ctx, .i32, mantissa_bits) },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -658,7 +658,7 @@ pub fn get_tuple_element(ctx: mlir.Context, tuple_value: mlir.Value, index: i64,
|
||||
.operands = &.{tuple_value},
|
||||
.result_type_inference = true,
|
||||
.attributes = &.{
|
||||
.{ "index", mlir.IntegerAttribute(.i32).init(ctx, index).as(mlir.Attribute) },
|
||||
.{ "index", .int(ctx, .i32, index) },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -701,17 +701,15 @@ pub fn convolution(
|
||||
for (opts.window_reversal, 0..) |w, i| {
|
||||
window_reversal[i] = @intCast(@intFromBool(w));
|
||||
}
|
||||
const pad_type = mlir.IntegerType(.i64).init(ctx).as(mlir.Type);
|
||||
const pad_shape = mlir.RankedTensorType.init(opts.pad_shape, pad_type).as(mlir.Type);
|
||||
return mlir.Operation.make(ctx, "stablehlo.convolution", .{
|
||||
.operands = &.{ lhs, rhs },
|
||||
.results = &.{res_type},
|
||||
.attributes = &.{
|
||||
.{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_strides).as(mlir.Attribute) },
|
||||
.{ "padding", mlir.DenseElementsAttribute(.i64).init(pad_shape, opts.pad_value).as(mlir.Attribute) },
|
||||
.{ "lhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.lhs_dilation).as(mlir.Attribute) },
|
||||
.{ "rhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.rhs_dilation).as(mlir.Attribute) },
|
||||
.{ "window_reversal", mlir.DenseArrayAttribute(.bool).init(ctx, window_reversal[0..opts.window_reversal.len]).as(mlir.Attribute) },
|
||||
.{ "window_strides", .dense(ctx, .i64, opts.window_strides) },
|
||||
.{ "padding", .denseElements(ctx, opts.pad_shape, .i64, opts.pad_value) },
|
||||
.{ "lhs_dilation", .dense(ctx, .i64, opts.lhs_dilation) },
|
||||
.{ "rhs_dilation", .dense(ctx, .i64, opts.rhs_dilation) },
|
||||
.{ "window_reversal", .dense(ctx, .bool, window_reversal[0..opts.window_reversal.len]) },
|
||||
.{
|
||||
"dimension_numbers", ConvDimensionNumbersAttribute.init(ctx, .{
|
||||
.input_batch_dimension = opts.input_batch_dimension,
|
||||
@ -725,9 +723,9 @@ pub fn convolution(
|
||||
.output_spatial_dimensions = opts.output_spatial_dimensions,
|
||||
}).as(mlir.Attribute),
|
||||
},
|
||||
.{ "feature_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.feature_group_count).as(mlir.Attribute) },
|
||||
.{ "batch_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.batch_group_count).as(mlir.Attribute) },
|
||||
.{ "precision_config", mlir.ArrayAttribute.init(ctx, &max_precisions).as(mlir.Attribute) },
|
||||
.{ "feature_group_count", .int(ctx, .i64, opts.feature_group_count) },
|
||||
.{ "batch_group_count", .int(ctx, .i64, opts.batch_group_count) },
|
||||
.{ "precision_config", .array(ctx, &max_precisions) },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
@ -764,25 +762,20 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
||||
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
||||
for (opts.output_operand_aliases) |alias| {
|
||||
ret.appendAssumeCapacity(
|
||||
OutputOperandAliasAttribute.init(
|
||||
ctx,
|
||||
&.{},
|
||||
alias,
|
||||
&.{},
|
||||
).as(mlir.Attribute),
|
||||
OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute),
|
||||
);
|
||||
}
|
||||
break :blk ret;
|
||||
};
|
||||
|
||||
const backend_config = switch (opts.backend_config) {
|
||||
const backend_config: mlir.Attribute = switch (opts.backend_config) {
|
||||
.string => blk: {
|
||||
stdx.debug.assert(
|
||||
@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi),
|
||||
"Only API version of less than 4 is supported for backend_config as string",
|
||||
.{},
|
||||
);
|
||||
break :blk mlir.StringAttribute.init(ctx, opts.backend_config.string).as(mlir.Attribute);
|
||||
break :blk .string(ctx, opts.backend_config.string);
|
||||
},
|
||||
.dict => blk: {
|
||||
stdx.debug.assert(
|
||||
@ -797,45 +790,33 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
||||
var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{};
|
||||
|
||||
attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{
|
||||
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, @intFromEnum(opts.api_version)).as(mlir.Attribute) },
|
||||
.{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) },
|
||||
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) },
|
||||
.{ "api_version", .int(ctx, .i32, @intFromEnum(opts.api_version)) },
|
||||
.{ "call_target_name", .string(ctx, opts.call_target_name) },
|
||||
.{ "has_side_effect", .boolean(ctx, opts.has_side_effect) },
|
||||
.{ "backend_config", backend_config },
|
||||
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases.constSlice()).as(mlir.Attribute) },
|
||||
.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) },
|
||||
});
|
||||
|
||||
if (opts.operand_layouts) |layouts| {
|
||||
const operand_layouts = blk: {
|
||||
var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
|
||||
for (layouts) |ol| {
|
||||
const tensor_type = mlir.RankedTensorType.init(
|
||||
&.{@intCast(ol.len)},
|
||||
mlir.IndexType.init(ctx).as(mlir.Type),
|
||||
).as(mlir.Type);
|
||||
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, ol);
|
||||
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
|
||||
ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol));
|
||||
}
|
||||
break :blk ret;
|
||||
};
|
||||
const attr: mlir.AttrTuple = .{ "operand_layouts", mlir.ArrayAttribute.init(ctx, operand_layouts.constSlice()).as(mlir.Attribute) };
|
||||
attrs.appendAssumeCapacity(attr);
|
||||
attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) });
|
||||
}
|
||||
|
||||
if (opts.result_layouts) |layouts| {
|
||||
const result_layouts = blk: {
|
||||
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
||||
for (layouts) |rl| {
|
||||
const tensor_type = mlir.RankedTensorType.init(
|
||||
&.{@intCast(rl.len)},
|
||||
mlir.IndexType.init(ctx).as(mlir.Type),
|
||||
).as(mlir.Type);
|
||||
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, rl);
|
||||
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
|
||||
ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl));
|
||||
}
|
||||
break :blk ret;
|
||||
};
|
||||
const attr: mlir.AttrTuple = .{ "result_layouts", mlir.ArrayAttribute.init(ctx, result_layouts.constSlice()).as(mlir.Attribute) };
|
||||
attrs.appendAssumeCapacity(attr);
|
||||
attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) });
|
||||
}
|
||||
|
||||
attrs.appendSliceAssumeCapacity(opts.addional_attributes);
|
||||
@ -852,9 +833,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
||||
pub fn annotate_device_placement(ctx: mlir.Context, inputs: []const mlir.Value, memory_kind: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation {
|
||||
const frontend_attributes = mlir.DictionaryAttribute.init(
|
||||
ctx,
|
||||
&.{
|
||||
mlir.NamedAttribute.init(mlir.Identifier.get(ctx, "_xla_buffer_placement"), memory_kind.asAttr()),
|
||||
},
|
||||
&.{.named(ctx, "_xla_buffer_placement", memory_kind.asAttr())},
|
||||
).asAttr();
|
||||
|
||||
return custom_call(ctx, inputs, .{
|
||||
|
||||
305
mlir/mlir.zig
305
mlir/mlir.zig
@ -1,9 +1,10 @@
|
||||
const builtin = @import("builtin");
|
||||
const std = @import("std");
|
||||
const stdx = @import("stdx");
|
||||
const log = std.log.scoped(.mlir);
|
||||
const builtin = @import("builtin");
|
||||
|
||||
const c = @import("c");
|
||||
const stdx = @import("stdx");
|
||||
|
||||
const log = std.log.scoped(.mlir);
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(@This());
|
||||
@ -254,6 +255,10 @@ pub const Module = struct {
|
||||
pub fn op(self: Module) Operation {
|
||||
return Operation.wrap(c.mlirModuleGetOperation(self.inner()));
|
||||
}
|
||||
|
||||
/// Feeds this module into `hasher` by delegating to the hash of the
/// operation returned by `op()`.
pub fn hash(self: Module, hasher: *std.hash.XxHash64) void {
    const module_op = self.op();
    module_op.hash(hasher);
}
|
||||
};
|
||||
|
||||
pub const PassManager = struct {
|
||||
@ -352,21 +357,81 @@ pub const Attribute = struct {
|
||||
) orelse Error.InvalidMlir;
|
||||
}
|
||||
|
||||
pub fn getNull() Self {
|
||||
return Self.wrap(c.mlirAttributeGetNull());
|
||||
// Utility functions to build common attributes.
|
||||
// All attributes are upcast to the Attribute type, which makes chained construction easier,
|
||||
// but losing type information.
|
||||
|
||||
/// Returns the attribute wrapping the result of `mlirAttributeGetNull`.
pub fn null_() Attribute {
    const raw = c.mlirAttributeGetNull();
    return Attribute.wrap(raw);
}
|
||||
|
||||
/// Builds a `StringAttribute` from `str` and returns it as a generic `Attribute`.
pub fn string(ctx: Context, str: []const u8) Attribute {
    const typed = StringAttribute.init(ctx, str);
    return typed.asAttr();
}
|
||||
|
||||
/// Builds a `TypeAttribute` holding `t` and returns it as a generic `Attribute`.
pub fn type_(t: Type) Attribute {
    const typed = TypeAttribute.init(t);
    return typed.asAttr();
}
|
||||
|
||||
/// Returns the unit attribute obtained from `mlirUnitAttrGet` for `ctx`.
pub fn unit(ctx: Context) Attribute {
    const raw = c.mlirUnitAttrGet(ctx.inner());
    return Attribute.wrap(raw);
}
|
||||
|
||||
/// Builds a `BoolAttribute` from `value` and returns it as a generic `Attribute`.
pub fn boolean(ctx: Context, value: bool) Attribute {
    const typed = BoolAttribute.init(ctx, value);
    return typed.asAttr();
}
|
||||
|
||||
/// Builds an `i1` integer attribute whose value is `@intFromBool(value)`
/// (as opposed to `boolean`, which builds a `BoolAttribute`).
pub fn i1FromBool(ctx: Context, value: bool) Attribute {
    const I1Attr = IntegerAttribute(.i1);
    return I1Attr.init(ctx, @intFromBool(value)).asAttr();
}
|
||||
|
||||
/// Builds an integer attribute of the compile-time selected `int_type`
/// holding `value`, returned as a generic `Attribute`.
pub fn int(ctx: Context, comptime int_type: IntegerTypes, value: i64) Attribute {
    const IntAttr = IntegerAttribute(int_type);
    return IntAttr.init(ctx, value).asAttr();
}
|
||||
|
||||
/// Builds a float attribute of the compile-time selected `float_type`
/// holding `value`, and re-wraps its raw handle as a generic `Attribute`.
pub fn float(ctx: Context, comptime float_type: FloatTypes, value: f64) Attribute {
    const typed = FloatAttribute(float_type).init(ctx, value);
    return Attribute.wrap(typed._inner);
}
|
||||
|
||||
/// Builds an `ArrayAttribute` from `attrs` and returns it as a generic `Attribute`.
pub fn array(ctx: Context, attrs: []const Attribute) Attribute {
    const typed = ArrayAttribute.init(ctx, attrs);
    return typed.asAttr();
}
|
||||
|
||||
/// Builds a dense array attribute with element kind `dt` from `values`
/// (whose element type is `dt.ZigType()`), returned as a generic `Attribute`.
pub fn dense(ctx: Context, comptime dt: DenseArrayTypes, values: []const dt.ZigType()) Attribute {
    const DenseAttr = DenseArrayAttribute(dt);
    return DenseAttr.init(ctx, values).asAttr();
}
|
||||
|
||||
/// Use a tensor as an attribute.
/// The tensor is described by its `dims`, its element kind `dt` and a flat
/// slice of `values`; the result is returned as a generic `Attribute`.
pub fn denseElements(ctx: Context, dims: []const i64, comptime dt: DenseElementsAttributeTypes, values: anytype) Attribute {
    const ElementsAttr = DenseElementsAttribute(dt);
    const typed = ElementsAttr.init(.tensor(dims, dt.mlirType(ctx)), values);
    return Attribute.wrap(typed.inner());
}
|
||||
|
||||
/// Builds a `FlatSymbolRefAttribute` referring to `flat_name` and re-wraps
/// its raw handle as a generic `Attribute`.
pub fn symbol(ctx: Context, flat_name: [:0]const u8) Attribute {
    const sym_ref = FlatSymbolRefAttribute.init(ctx, flat_name);
    return Attribute.wrap(sym_ref.inner());
}
|
||||
|
||||
/// Pairs `attr` with the identifier obtained for `name` in `ctx`,
/// producing a `NamedAttribute`.
pub fn named(attr: Attribute, ctx: Context, name: [:0]const u8) NamedAttribute {
    const identifier = Identifier.get(ctx, name);
    return NamedAttribute.init(identifier, attr);
}
|
||||
};
|
||||
|
||||
pub const NamedAttribute = struct {
|
||||
_inner: c.MlirNamedAttribute,
|
||||
pub usingnamespace MlirHelpers(NamedAttribute, .{});
|
||||
const Self = NamedAttribute;
|
||||
pub const NamedAttribute = extern struct {
|
||||
name: c.MlirIdentifier,
|
||||
attribute: c.MlirAttribute,
|
||||
|
||||
pub fn init(name: Identifier, attr: Attribute) Self {
|
||||
return Self.wrap(.{
|
||||
pub fn named(ctx: Context, name: [:0]const u8, attr: Attribute) NamedAttribute {
|
||||
return .{
|
||||
.name = c.mlirIdentifierGet(ctx._inner, stringRef(name)),
|
||||
.attribute = attr.inner(),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn init(name: Identifier, attr: Attribute) NamedAttribute {
|
||||
return .{
|
||||
.name = name.inner(),
|
||||
.attribute = attr.inner(),
|
||||
});
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
@ -393,21 +458,6 @@ pub const StringAttribute = struct {
|
||||
}
|
||||
};
|
||||
|
||||
pub const UnitAttribute = struct {
|
||||
_inner: c.MlirAttribute,
|
||||
pub usingnamespace MlirHelpers(UnitAttribute, .{
|
||||
.is_a_fn = c.mlirAttributeIsAUnit,
|
||||
.is_null_fn = c.mlirAttributeIsNull,
|
||||
.dump_fn = c.mlirAttributeDump,
|
||||
.equal_fn = c.mlirAttributeEqual,
|
||||
});
|
||||
const Self = UnitAttribute;
|
||||
|
||||
pub fn init(ctx: Context) Self {
|
||||
return Self.wrap(c.mlirUnitAttrGet(ctx.inner()));
|
||||
}
|
||||
};
|
||||
|
||||
pub const BoolAttribute = struct {
|
||||
_inner: c.MlirAttribute,
|
||||
pub usingnamespace MlirHelpers(BoolAttribute, .{
|
||||
@ -481,8 +531,8 @@ pub const ArrayAttribute = struct {
|
||||
|
||||
pub fn IntegerAttribute(comptime it: IntegerTypes) type {
|
||||
const ZigType, const getter = comptime switch (it) {
|
||||
.i1, .i4, .i8, .i16, .i32, .i64 => .{ u64, c.mlirIntegerAttrGetValueInt },
|
||||
.si4, .si8, .si16, .si32, .si64 => .{ u64, c.mlirIntegerAttrGetValueSInt },
|
||||
.i1, .i4, .i8, .i16, .i32, .i64 => .{ i64, c.mlirIntegerAttrGetValueInt },
|
||||
.si4, .si8, .si16, .si32, .si64 => .{ i64, c.mlirIntegerAttrGetValueSInt },
|
||||
.u4, .u8, .u16, .u32, .u64 => .{ u64, c.mlirIntegerAttrGetValueUInt },
|
||||
.unknown => @compileError("IntegerAttribute(unknown)"),
|
||||
};
|
||||
@ -547,38 +597,48 @@ pub const DenseArrayTypes = enum {
|
||||
i64,
|
||||
f32,
|
||||
f64,
|
||||
|
||||
/// Maps a dense-array element kind to the Zig type used for its elements.
/// Note that `.bool` elements are represented as `i32`, like `.i32`.
pub fn ZigType(comptime dt: DenseArrayTypes) type {
    return switch (dt) {
        .bool, .i32 => i32,
        .i8 => i8,
        .i16 => i16,
        .i64 => i64,
        .f32 => f32,
        .f64 => f64,
    };
}
|
||||
};
|
||||
|
||||
pub fn DenseArrayAttribute(comptime dt: DenseArrayTypes) type {
|
||||
const Config = switch (dt) {
|
||||
.bool => .{ i32, c.mlirAttributeIsADenseBoolArray, c.mlirDenseBoolArrayGet, c.mlirDenseBoolArrayGetElement },
|
||||
.i8 => .{ i8, c.mlirAttributeIsADenseI8Array, c.mlirDenseI8ArrayGet, c.mlirDenseI8ArrayGetElement },
|
||||
.i16 => .{ i16, c.mlirAttributeIsADenseI16Array, c.mlirDenseI16ArrayGet, c.mlirDenseI16ArrayGetElement },
|
||||
.i32 => .{ i32, c.mlirAttributeIsADenseI32Array, c.mlirDenseI32ArrayGet, c.mlirDenseI32ArrayGetElement },
|
||||
.i64 => .{ i64, c.mlirAttributeIsADenseI64Array, c.mlirDenseI64ArrayGet, c.mlirDenseI64ArrayGetElement },
|
||||
.f32 => .{ f32, c.mlirAttributeIsADenseF32Array, c.mlirDenseF32ArrayGet, c.mlirDenseF32ArrayGetElement },
|
||||
.f64 => .{ f64, c.mlirAttributeIsADenseF64Array, c.mlirDenseF64ArrayGet, c.mlirDenseF64ArrayGetElement },
|
||||
const is_a_fn, const get_fn, const get_element_fn = switch (dt) {
|
||||
.bool => .{ c.mlirAttributeIsADenseBoolArray, c.mlirDenseBoolArrayGet, c.mlirDenseBoolArrayGetElement },
|
||||
.i8 => .{ c.mlirAttributeIsADenseI8Array, c.mlirDenseI8ArrayGet, c.mlirDenseI8ArrayGetElement },
|
||||
.i16 => .{ c.mlirAttributeIsADenseI16Array, c.mlirDenseI16ArrayGet, c.mlirDenseI16ArrayGetElement },
|
||||
.i32 => .{ c.mlirAttributeIsADenseI32Array, c.mlirDenseI32ArrayGet, c.mlirDenseI32ArrayGetElement },
|
||||
.i64 => .{ c.mlirAttributeIsADenseI64Array, c.mlirDenseI64ArrayGet, c.mlirDenseI64ArrayGetElement },
|
||||
.f32 => .{ c.mlirAttributeIsADenseF32Array, c.mlirDenseF32ArrayGet, c.mlirDenseF32ArrayGetElement },
|
||||
.f64 => .{ c.mlirAttributeIsADenseF64Array, c.mlirDenseF64ArrayGet, c.mlirDenseF64ArrayGetElement },
|
||||
};
|
||||
|
||||
return struct {
|
||||
_inner: c.MlirAttribute,
|
||||
pub usingnamespace MlirHelpers(@This(), .{
|
||||
.is_a_fn = Config[1],
|
||||
.is_a_fn = is_a_fn,
|
||||
.is_null_fn = c.mlirAttributeIsNull,
|
||||
.dump_fn = c.mlirAttributeDump,
|
||||
.equal_fn = c.mlirAttributeEqual,
|
||||
});
|
||||
const Attr = @This();
|
||||
const ElementType = dt;
|
||||
const ElementTypeZig = Config[0];
|
||||
const ElementTypeZig = dt.ZigType();
|
||||
|
||||
pub fn init(ctx: Context, values: []const ElementTypeZig) Attr {
|
||||
const get_fn = Config[2];
|
||||
return Attr.wrap(get_fn(ctx.inner(), @intCast(values.len), @ptrCast(values.ptr)));
|
||||
}
|
||||
|
||||
pub fn get(self: Attr, pos: usize) ElementTypeZig {
|
||||
const get_element_fn = Config[3];
|
||||
return get_element_fn(self.inner(), @intCast(pos));
|
||||
}
|
||||
|
||||
@ -586,6 +646,10 @@ pub fn DenseArrayAttribute(comptime dt: DenseArrayTypes) type {
|
||||
return @intCast(c.mlirDenseArrayGetNumElements(self.inner()));
|
||||
}
|
||||
|
||||
pub fn asAttr(self: Attr) Attribute {
|
||||
return Attribute.wrap(self._inner);
|
||||
}
|
||||
|
||||
pub usingnamespace switch (dt) {
|
||||
.bool, .i64 => struct {
|
||||
const DenseArray = DenseArrayAttribute(switch (dt) {
|
||||
@ -614,26 +678,47 @@ pub const DenseElementsAttributeTypes = enum {
|
||||
f32,
|
||||
f64,
|
||||
index,
|
||||
|
||||
/// Returns the Zig type used for one element of a dense-elements attribute of kind `dt`.
/// `.bf16` is carried as `u16` and `.index` as `usize`; every other kind maps to the
/// Zig type of the same name.
pub fn ZigType(comptime dt: DenseElementsAttributeTypes) type {
|
||||
return switch (dt) {
|
||||
.bool => bool,
|
||||
.i8 => i8,
|
||||
.i16 => i16,
|
||||
.i32 => i32,
|
||||
.i64 => i64,
|
||||
.u8 => u8,
|
||||
.u16 => u16,
|
||||
.u32 => u32,
|
||||
.u64 => u64,
|
||||
.bf16 => u16,
|
||||
.f16 => f16,
|
||||
.f32 => f32,
|
||||
.f64 => f64,
|
||||
.index => usize,
|
||||
};
|
||||
}
|
||||
|
||||
/// Returns the MLIR `Type` in `ctx` that corresponds to the element kind `dt`.
/// Each arm is a declaration literal resolved against the `Type` result type
/// (i.e. `Type.int`, `Type.float`, `Type.index`). `.bool` maps to the `.i1` integer type.
pub fn mlirType(dt: DenseElementsAttributeTypes, ctx: Context) Type {
|
||||
return switch (dt) {
|
||||
.bool => .int(ctx, .i1),
|
||||
.i8 => .int(ctx, .i8),
|
||||
.i16 => .int(ctx, .i16),
|
||||
.i32 => .int(ctx, .i32),
|
||||
.i64 => .int(ctx, .i64),
|
||||
.u8 => .int(ctx, .u8),
|
||||
.u16 => .int(ctx, .u16),
|
||||
.u32 => .int(ctx, .u32),
|
||||
.u64 => .int(ctx, .u64),
|
||||
.bf16 => .float(ctx, .bf16),
|
||||
.f16 => .float(ctx, .f16),
|
||||
.f32 => .float(ctx, .f32),
|
||||
.f64 => .float(ctx, .f64),
|
||||
.index => .index(ctx),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
|
||||
const ZigType = switch (dt) {
|
||||
.bool => bool,
|
||||
.i8 => i8,
|
||||
.i16 => i16,
|
||||
.i32 => i32,
|
||||
.i64 => i64,
|
||||
.u8 => u8,
|
||||
.u16 => u16,
|
||||
.u32 => u32,
|
||||
.u64 => u64,
|
||||
.bf16 => u16,
|
||||
.f16 => f16,
|
||||
.f32 => f32,
|
||||
.f64 => f64,
|
||||
.index => usize,
|
||||
};
|
||||
|
||||
return struct {
|
||||
_inner: c.MlirAttribute,
|
||||
|
||||
@ -662,8 +747,8 @@ pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
|
||||
return @intCast(c.mlirElementsAttrGetNumElements(self.inner()));
|
||||
}
|
||||
|
||||
pub fn constSlice(self: Attr) []const ZigType {
|
||||
const ptr: [*]const ZigType = @constCast(@ptrCast(@alignCast(c.mlirDenseElementsAttrGetRawData(self.inner()) orelse unreachable)));
|
||||
pub fn constSlice(self: Attr) []const dt.ZigType() {
|
||||
const ptr: [*]const dt.ZigType() = @constCast(@ptrCast(@alignCast(c.mlirDenseElementsAttrGetRawData(self.inner()) orelse unreachable)));
|
||||
return ptr[0..self.len()];
|
||||
}
|
||||
|
||||
@ -814,6 +899,7 @@ pub const Operation = struct {
|
||||
pub fn make(ctx: Context, op_name: [:0]const u8, args: struct {
|
||||
operands: ?[]const Value = null,
|
||||
variadic_operands: ?[]const []const Value = null,
|
||||
tt_variadic_operands: ?[]const []const Value = null,
|
||||
results: ?[]const Type = null,
|
||||
variadic_results: ?[]const []const Type = null,
|
||||
result_type_inference: ?bool = null,
|
||||
@ -828,9 +914,24 @@ pub const Operation = struct {
|
||||
if (args.operands) |operands| {
|
||||
state.addOperands(operands);
|
||||
} else if (args.variadic_operands) |operands_segments| {
|
||||
const MAX_SEGMENTS = 32;
|
||||
var segments: std.BoundedArray(i32, MAX_SEGMENTS) = .{};
|
||||
|
||||
for (operands_segments) |operands| {
|
||||
state.addOperands(operands);
|
||||
segments.appendAssumeCapacity(@intCast(operands.len));
|
||||
}
|
||||
state.addAttribute(ctx, "operandSegmentSizes", .denseElements(ctx, &.{@intCast(segments.len)}, .i32, segments.constSlice()));
|
||||
} else if (args.tt_variadic_operands) |operands_segments| {
|
||||
// stablehlo and triton seem to disagree on the expected type of operandSegmentSizes, let's fix that.
|
||||
const MAX_SEGMENTS = 32;
|
||||
var segments: std.BoundedArray(i32, MAX_SEGMENTS) = .{};
|
||||
|
||||
for (operands_segments) |operands| {
|
||||
state.addOperands(operands);
|
||||
segments.appendAssumeCapacity(@intCast(operands.len));
|
||||
}
|
||||
state.addAttribute(ctx, "operandSegmentSizes", .dense(ctx, .i32, segments.constSlice()));
|
||||
}
|
||||
if (args.result_type_inference) |enable| {
|
||||
state.resultTypeInference(enable);
|
||||
@ -1076,6 +1177,26 @@ pub const Operation = struct {
|
||||
pub fn removeAttributeByName(self: Self, name_: [:0]const u8) bool {
|
||||
return c.mlirOperationRemoveAttributeByName(self.inner(), stringRef(name_));
|
||||
}
|
||||
|
||||
/// Feeds the printed form of `op` (with `debug_info = false`) into `hasher`.
/// The hasher is updated in place; nothing is returned and `final()` is not called here.
pub fn hash(op: Operation, hasher: *std.hash.XxHash64) void {
|
||||
// Empty error set: the writer adapter below has no failure path.
const NoError = error{};
|
||||
// Write callback for the adapter: forwards every byte to the hasher and reports all bytes as written.
const write = struct {
|
||||
fn write(hasher_: *std.hash.XxHash64, bytes: []const u8) NoError!usize {
|
||||
hasher_.update(bytes);
|
||||
return bytes.len;
|
||||
}
|
||||
}.write;
|
||||
const HashWriter = std.io.Writer(*std.hash.XxHash64, NoError, write);
|
||||
const writer: HashWriter = .{ .context = hasher };
|
||||
|
||||
// Hash the canonicalized IR, without debug information that can change across builds.
|
||||
// Note: before we were using op.writeBytecode(writer),
|
||||
// but it crashes on some inputs, notably for unused variables.
|
||||
// So we use the text representation of the mlir.
|
||||
// See https://github.com/zml/zml/issues/97.
|
||||
// Writes can't fail because we are writing to a hasher.
|
||||
op.print(writer, .{ .debug_info = false });
|
||||
}
|
||||
};
|
||||
|
||||
pub const OpPrintingFlags = struct {
|
||||
@ -1255,6 +1376,41 @@ pub const Type = struct {
|
||||
c.mlirTypeParseGet(ctx.inner(), stringRef(str)),
|
||||
) orelse Error.InvalidMlir;
|
||||
}
|
||||
|
||||
/// Shorthand for `IndexType.init(ctx)` converted to a generic `Type`.
pub fn index(ctx: Context) Type {
|
||||
return IndexType.init(ctx).as(Type);
|
||||
}
|
||||
|
||||
/// Shorthand for `IntegerType(int_type).init(ctx)` converted to a generic `Type`.
/// `int_type` is a runtime value; passing `.unknown` panics.
pub fn int(ctx: Context, int_type: IntegerTypes) Type {
|
||||
return switch (int_type) {
|
||||
.unknown => @panic("Unknown integer type"),
|
||||
// `inline else` makes `t` comptime-known, so it can be used as the type-function argument.
inline else => |t| IntegerType(t).init(ctx).as(Type),
|
||||
};
|
||||
}
|
||||
|
||||
/// Shorthand for `FloatType(float_type).init(ctx)` converted to a generic `Type`.
pub fn float(ctx: Context, float_type: FloatTypes) Type {
|
||||
return switch (float_type) {
|
||||
// `inline else` makes `t` comptime-known, so it can be used as the type-function argument.
inline else => |t| FloatType(t).init(ctx).as(Type),
|
||||
};
|
||||
}
|
||||
|
||||
/// Shorthand for `ComplexType(complex_type).init(ctx)` converted to a generic `Type`.
pub fn complex(ctx: Context, complex_type: ComplexTypes) Type {
|
||||
return switch (complex_type) {
|
||||
// `inline else` makes `t` comptime-known, so it can be used as the type-function argument.
inline else => |t| ComplexType(t).init(ctx).as(Type),
|
||||
};
|
||||
}
|
||||
|
||||
/// Shorthand for `TupleType.init(ctx, types)` converted to a generic `Type`.
/// NOTE(review): any error returned by `TupleType.init` is treated as unreachable here.
pub fn tuple(ctx: Context, types: []const Type) Type {
|
||||
return (TupleType.init(ctx, types) catch unreachable).as(Type);
|
||||
}
|
||||
|
||||
/// Shorthand for `FunctionType.init(ctx, args, results)` converted to a generic `Type`.
/// NOTE(review): any error returned by `FunctionType.init` is treated as unreachable here.
pub fn function(ctx: Context, args: []const Type, results: []const Type) Type {
|
||||
return (FunctionType.init(ctx, args, results) catch unreachable).as(Type);
|
||||
}
|
||||
|
||||
/// Shorthand for `RankedTensorType.init(dimensions, elem_type)` converted to a generic `Type`.
/// Unlike the other helpers, this one takes no `Context` parameter.
pub fn tensor(dimensions: []const i64, elem_type: Type) Type {
|
||||
return RankedTensorType.init(dimensions, elem_type).as(Type);
|
||||
}
|
||||
};
|
||||
|
||||
pub const IndexType = struct {
|
||||
@ -1680,7 +1836,9 @@ pub const Block = struct {
|
||||
}
|
||||
|
||||
pub fn argument(self: Block, index: usize) Value {
|
||||
return Value.wrap(c.mlirBlockGetArgument(self.inner(), @intCast(index)));
|
||||
const arg = c.mlirBlockGetArgument(self.inner(), @intCast(index));
|
||||
stdx.debug.assert(!Value.Methods.is_null_fn.?(arg), "Block doesn't have argument {}, only got {}", .{ index, self.numArguments() });
|
||||
return Value.wrap(arg);
|
||||
}
|
||||
|
||||
pub fn numArguments(self: Block) usize {
|
||||
@ -1708,4 +1866,29 @@ pub const Block = struct {
|
||||
c.mlirBlockAppendOwnedOperation(self.inner(), op.inner());
|
||||
}
|
||||
}
|
||||
|
||||
/// Policy passed to `appendValueRecursive` / `appendOperationRecursive`.
/// With `.hermetic`, those helpers assert that values and operations coming from
/// another block are not used; with `.open` that check is skipped.
pub const RecursiveOpts = enum { open, hermetic };
|
||||
|
||||
/// Makes sure the producer of `value` is appended to this block.
/// - op result: delegates to `appendOperationRecursive` on the defining operation.
/// - block argument: nothing is appended; only the hermetic check is performed.
/// - null value: panics.
pub fn appendValueRecursive(self: Block, value: Value, opt: RecursiveOpts) void {
|
||||
switch (value.kind()) {
|
||||
.op_result => |parent_op| self.appendOperationRecursive(parent_op, opt),
|
||||
.block_argument => |arg| {
|
||||
// Hermetic blocks are not allowed to use arguments from other blocks.
|
||||
stdx.debug.assert(opt == .open or self.eql(arg.block()), "Can't add {} from {?x} block to {?x} block", .{ arg, arg.block()._inner.ptr, self._inner.ptr });
|
||||
},
|
||||
.null => @panic("InvalidMlir"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Appends `op` to this block after first handling each of its operands
/// through `appendValueRecursive`, so producers are appended before their user.
/// An operation that already belongs to a block is left where it is.
pub fn appendOperationRecursive(self: Block, op: Operation, opt: RecursiveOpts) void {
|
||||
// Already attached somewhere: never move it, only validate the hermetic constraint.
if (op.block()) |prev_block| {
|
||||
// Hermetic blocks are not allowed to reference values from other blocks.
|
||||
stdx.debug.assert(opt == .open or self.equals(prev_block), "Can't add {} from {?x} block to {?x} block", .{ op, prev_block._inner.ptr, self._inner.ptr });
|
||||
return;
|
||||
}
|
||||
// Operands first, then the operation itself.
for (0..op.numOperands()) |i| {
|
||||
self.appendValueRecursive(op.operand(i), opt);
|
||||
}
|
||||
self.appendOperation(op);
|
||||
}
|
||||
};
|
||||
|
||||
184
zml/module.zig
184
zml/module.zig
@ -6,15 +6,14 @@ const runfiles = @import("runfiles");
|
||||
const stdx = @import("stdx");
|
||||
const xla_pb = @import("//xla:xla_proto");
|
||||
|
||||
const meta = @import("meta.zig");
|
||||
const mlir = @import("mlir.zig");
|
||||
const ops = @import("ops.zig");
|
||||
const pjrt = @import("pjrtx.zig");
|
||||
|
||||
const BaseExe = @import("exe.zig").BaseExe;
|
||||
const Buffer = @import("buffer.zig").Buffer;
|
||||
const Bufferized = @import("tensor.zig").Bufferized;
|
||||
const meta = @import("meta.zig");
|
||||
const mlir = @import("mlir.zig");
|
||||
const Location = mlir.Location;
|
||||
const ops = @import("ops.zig");
|
||||
const pjrt = @import("pjrtx.zig");
|
||||
const Platform = @import("platform.zig").Platform;
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
const ShapeOf = @import("tensor.zig").ShapeOf;
|
||||
@ -28,46 +27,6 @@ test {
|
||||
std.testing.refAllDecls(@This());
|
||||
}
|
||||
|
||||
pub const BlockKind = enum { open, hermetic };
|
||||
|
||||
const Block = union(BlockKind) {
|
||||
open: mlir.Block,
|
||||
hermetic: mlir.Block,
|
||||
|
||||
pub fn block(self: Block) mlir.Block {
|
||||
return switch (self) {
|
||||
inline .open, .hermetic => |t| t,
|
||||
};
|
||||
}
|
||||
|
||||
fn appendTensorRecursive(self: Block, x: *const Tensor) void {
|
||||
self.appendValueRecursive(x.value());
|
||||
}
|
||||
|
||||
fn appendValueRecursive(self: Block, value: mlir.Value) void {
|
||||
switch (value.kind()) {
|
||||
.op_result => |parent_op| self.appendOperationRecursive(parent_op),
|
||||
.block_argument => |arg| {
|
||||
// Hermetic blocks are not allowed to use arguments from other blocks.
|
||||
stdx.debug.assert(self == .open or self.block().eql(arg.block()), "Can't add {} from {?x} block to {?x} block", .{ arg, arg.block()._inner.ptr, self.block()._inner.ptr });
|
||||
},
|
||||
.null => @panic("InvalidMlir"),
|
||||
}
|
||||
}
|
||||
|
||||
fn appendOperationRecursive(self: Block, op: mlir.Operation) void {
|
||||
if (op.block()) |prev_block| {
|
||||
// Hermetic blocks are not allowed to reference values from other blocks.
|
||||
std.debug.assert(self == .open or prev_block.equals(self.block()));
|
||||
return;
|
||||
}
|
||||
for (0..op.numOperands()) |i| {
|
||||
self.appendValueRecursive(op.operand(i));
|
||||
}
|
||||
self.block().appendOperation(op);
|
||||
}
|
||||
};
|
||||
|
||||
pub const MlirFn = struct {
|
||||
name: []const u8,
|
||||
num_args: u32,
|
||||
@ -94,7 +53,7 @@ pub const CompilationContext = struct {
|
||||
|
||||
_module: mlir.Module,
|
||||
|
||||
_blocks: std.BoundedArray(Block, 64) = .{},
|
||||
_blocks: std.BoundedArray(TaggedBlock, 64) = .{},
|
||||
_fn_cache: FnCache = .{},
|
||||
|
||||
_block_args: TensorToBlockArg = .{},
|
||||
@ -104,6 +63,7 @@ pub const CompilationContext = struct {
|
||||
_previous: ?*CompilationContext = null,
|
||||
threadlocal var _current: ?*CompilationContext = null;
|
||||
|
||||
/// A block paired with the `RecursiveOpts` policy to use when appending to it.
/// Tuple layout: element 0 is the block, element 1 is the policy.
const TaggedBlock = struct { mlir.Block, mlir.Block.RecursiveOpts };
|
||||
const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation });
|
||||
const AttributeList = std.BoundedArray(mlir.NamedAttribute, 3);
|
||||
|
||||
@ -123,7 +83,7 @@ pub const CompilationContext = struct {
|
||||
|
||||
const loc = mlir_ctx.location(@src()).named(mlir_ctx, "main");
|
||||
const module = mlir.Module.init(loc);
|
||||
module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, "zml").as(mlir.Attribute));
|
||||
module.op().setAttributeByName("sym_name", .string(mlir_ctx, "zml"));
|
||||
|
||||
var canonicalizer = try mlir.PassManager.init(mlir_ctx);
|
||||
{
|
||||
@ -280,26 +240,22 @@ pub const CompilationContext = struct {
|
||||
);
|
||||
}
|
||||
|
||||
fn currentBlock(self: *const CompilationContext) ?Block {
|
||||
fn currentBlock(self: *const CompilationContext) ?TaggedBlock {
|
||||
return if (self._blocks.len > 0) self._blocks.get(self._blocks.len - 1) else null;
|
||||
}
|
||||
|
||||
pub fn openBlock(self: *CompilationContext, kind: BlockKind, args: []const mlir.Type, locs: []const mlir.Location) !Block {
|
||||
const mlir_block = try mlir.Block.init(args, locs);
|
||||
const block: Block = switch (kind) {
|
||||
.open => .{ .open = mlir_block },
|
||||
.hermetic => .{ .hermetic = mlir_block },
|
||||
};
|
||||
pub fn openBlock(self: *CompilationContext, kind: mlir.Block.RecursiveOpts, args: []const mlir.Type, locs: []const mlir.Location) !TaggedBlock {
|
||||
const block: TaggedBlock = .{ try mlir.Block.init(args, locs), kind };
|
||||
self.pushBlock(block);
|
||||
return block;
|
||||
}
|
||||
|
||||
pub fn closeBlock(self: *CompilationContext, block: Block) void {
|
||||
pub fn closeBlock(self: *CompilationContext, block: TaggedBlock) void {
|
||||
const popped = self._blocks.pop();
|
||||
std.debug.assert(block.block().eql(popped.?.block()));
|
||||
std.debug.assert(block[0].eql(popped.?[0]));
|
||||
}
|
||||
|
||||
fn pushBlock(self: *CompilationContext, block: Block) void {
|
||||
fn pushBlock(self: *CompilationContext, block: TaggedBlock) void {
|
||||
self._blocks.appendAssumeCapacity(block);
|
||||
}
|
||||
|
||||
@ -311,7 +267,7 @@ pub const CompilationContext = struct {
|
||||
/// But their shapes/tags can be safely propagated further.
|
||||
pub fn makeBlock(
|
||||
self: *CompilationContext,
|
||||
kind: BlockKind,
|
||||
kind: mlir.Block.RecursiveOpts,
|
||||
comptime S: ops.BlockSignature,
|
||||
func: *const S.Fn,
|
||||
blkctx: S.BlkCtx,
|
||||
@ -326,7 +282,7 @@ pub const CompilationContext = struct {
|
||||
// Before creating a new block, assign all received values to previous block,
|
||||
// otherwise they will be assigned to this block
|
||||
if (self.currentBlock()) |prev_block| {
|
||||
meta.visit(Block.appendTensorRecursive, prev_block, &blkctx);
|
||||
meta.visit(_appendTensorRecursive, prev_block, &blkctx);
|
||||
}
|
||||
|
||||
const block = self.openBlock(kind, &input_types, &locations) catch unreachable;
|
||||
@ -337,15 +293,20 @@ pub const CompilationContext = struct {
|
||||
// So we create a copy of the arguments, and replace values
|
||||
// by the block arguments.
|
||||
var blk_args = args;
|
||||
std.debug.assert(assignBlockArguments(&blk_args, block.block(), 0) == N);
|
||||
std.debug.assert(assignBlockArguments(&blk_args, block[0], 0) == N);
|
||||
|
||||
const block_res = @call(.auto, func, S.blkArgs(blkctx, blk_args));
|
||||
var block_res_values: [S.nOut]mlir.Value = undefined;
|
||||
self.extractValues(&block_res, &block_res_values);
|
||||
const block_ret = dialect.stablehlo.returns_(self.mlirCtx(), &block_res_values, loc);
|
||||
block.appendOperationRecursive(block_ret);
|
||||
block[0].appendOperationRecursive(block_ret, block[1]);
|
||||
|
||||
return .{ block.block(), block_res };
|
||||
return .{ block[0], block_res };
|
||||
}
|
||||
|
||||
/// Appends the MLIR value behind tensor `x` to the block of `tagged_block`,
/// using the policy stored alongside that block.
fn _appendTensorRecursive(tagged_block: TaggedBlock, x: *const Tensor) void {
|
||||
// Split the tuple into the block and its append policy.
const block, const tag = tagged_block;
|
||||
block.appendValueRecursive(x.value(), tag);
|
||||
}
|
||||
|
||||
pub const EmitMlirOpts = struct {
|
||||
@ -411,7 +372,7 @@ pub const CompilationContext = struct {
|
||||
defer self.closeBlock(fn_body);
|
||||
|
||||
try self._block_args.ensureUnusedCapacity(self.allocator(), @intCast(tensor_count));
|
||||
const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), 0);
|
||||
const assigned_args_count = self.mapBlockArguments(args, fn_body[0], 0);
|
||||
std.debug.assert(assigned_args_count == tensor_count);
|
||||
|
||||
fn_res.* = forward: {
|
||||
@ -424,7 +385,7 @@ pub const CompilationContext = struct {
|
||||
self.extractValuesAndTypes(fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations);
|
||||
|
||||
const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc);
|
||||
fn_body.appendOperationRecursive(fn_ret);
|
||||
fn_body[0].appendOperationRecursive(fn_ret, fn_body[1]);
|
||||
}
|
||||
|
||||
const arg_attrs = try arena.alloc(AttributeList, tensor_count);
|
||||
@ -446,7 +407,7 @@ pub const CompilationContext = struct {
|
||||
.arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs),
|
||||
.results = fn_res_types,
|
||||
.res_attrs = try finalizeAttributeList(arena, mlir_ctx, res_attrs),
|
||||
.block = fn_body.block(),
|
||||
.block = fn_body[0],
|
||||
.location = loc,
|
||||
});
|
||||
|
||||
@ -475,6 +436,7 @@ pub const CompilationContext = struct {
|
||||
/// Given a list of donations mapping output buffers to input buffers,
|
||||
/// generate donation attribute for each `n_args` input argument.
|
||||
fn addDonationsAttributes(self: CompilationContext, attributes: []AttributeList, donations: []const Tensor._Donation) void {
|
||||
const ctx = self.mlirCtx();
|
||||
var n_donations: usize = 0;
|
||||
for (donations, 0..) |donation, index| {
|
||||
switch (donation) {
|
||||
@ -489,12 +451,7 @@ pub const CompilationContext = struct {
|
||||
// When the time comes, do a fancier lookup here to check if an argument
|
||||
// is donated twice.
|
||||
stdx.debug.assert(attributes[a].len == 0, "Donation error ! Argument {} has been donated twice ! To {} and to {}", .{ a, index, attributes[a].buffer[0] });
|
||||
attributes[a].appendAssumeCapacity(
|
||||
mlir.NamedAttribute.init(
|
||||
mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"),
|
||||
mlir.IntegerAttribute(.i32).init(self.mlirCtx(), @intCast(index)).as(mlir.Attribute),
|
||||
),
|
||||
);
|
||||
attributes[a].appendAssumeCapacity(.named(ctx, "tf.aliasing_output", .int(ctx, .i32, @intCast(index))));
|
||||
// log.debug("attribute: {}", .{attributes[a].constSlice()});
|
||||
},
|
||||
}
|
||||
@ -552,46 +509,29 @@ pub const CompilationContext = struct {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.StringAttribute {
|
||||
const mlir_ctx = self.mlirCtx();
|
||||
|
||||
const num_partitions = self._platform.sharding().num_partitions;
|
||||
var sharding_str: std.BoundedArray(u8, 128) = .{};
|
||||
|
||||
writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable;
|
||||
return mlir.StringAttribute.init(mlir_ctx, sharding_str.constSlice());
|
||||
}
|
||||
|
||||
fn addShardingAttributes(self: CompilationContext, arg_attrs: []AttributeList, res_attrs: []AttributeList, input_shapes: []const Shape, output_shapes: []const Shape) void {
|
||||
const mlir_ctx = self.mlirCtx();
|
||||
const ctx = self.mlirCtx();
|
||||
if (!self._platform.compilation_options.sharding_enabled) return;
|
||||
|
||||
const mhlo_default_layout = mlir.NamedAttribute.init(
|
||||
mlir.Identifier.get(mlir_ctx, "mhlo.layout_mode"),
|
||||
mlir.StringAttribute.init(mlir_ctx, "default").asAttr(),
|
||||
);
|
||||
const default_layout = mlir.NamedAttribute.named(ctx, "mhlo.layout_mode", .string(ctx, "default"));
|
||||
for (arg_attrs, input_shapes) |*attr, shape| {
|
||||
attr.appendAssumeCapacity(mhlo_default_layout);
|
||||
|
||||
const sharding_attr = self.getShardingAttr(shape);
|
||||
attr.appendAssumeCapacity(mlir.NamedAttribute.init(
|
||||
mlir.Identifier.get(mlir_ctx, "mhlo.sharding"),
|
||||
sharding_attr.asAttr(),
|
||||
));
|
||||
attr.appendAssumeCapacity(default_layout);
|
||||
attr.appendAssumeCapacity(.named(ctx, "mhlo.sharding", self.getShardingAttr(shape)));
|
||||
}
|
||||
|
||||
for (res_attrs, output_shapes) |*attr, shape| {
|
||||
attr.appendAssumeCapacity(mhlo_default_layout);
|
||||
|
||||
const sharding_attr = self.getShardingAttr(shape);
|
||||
|
||||
attr.appendAssumeCapacity(mlir.NamedAttribute.init(
|
||||
mlir.Identifier.get(mlir_ctx, "mhlo.sharding"),
|
||||
sharding_attr.asAttr(),
|
||||
));
|
||||
attr.appendAssumeCapacity(default_layout);
|
||||
attr.appendAssumeCapacity(.named(ctx, "mhlo.sharding", self.getShardingAttr(shape)));
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds the sharding attribute for `shape` as a string `mlir.Attribute`.
/// The text is produced by `writeShardingRepresentation` using the platform's
/// `num_partitions`.
pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute {
|
||||
const ctx = self.mlirCtx();
|
||||
const num_partitions = self._platform.sharding().num_partitions;
|
||||
// Fixed 128-byte scratch buffer for the textual representation.
var sharding_str: std.BoundedArray(u8, 128) = .{};
|
||||
// NOTE(review): a write error is treated as unreachable, which assumes the
// representation always fits in the 128-byte buffer — confirm for large shapes.
writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable;
|
||||
return mlir.Attribute.string(ctx, sharding_str.constSlice());
|
||||
}
|
||||
|
||||
fn writeShardingRepresentation(shape: Shape, num_partitions: u8, writer: anytype) @TypeOf(writer).Error!void {
|
||||
const n_sharded: u8 = @popCount(@as(u8, @bitCast(shape._sharding_info)));
|
||||
if (n_sharded == 0 or num_partitions == 1) {
|
||||
@ -819,20 +759,11 @@ pub const CompilationContext = struct {
|
||||
|
||||
fn computeModuleHash(platform: Platform, module: mlir.Module) u64 {
|
||||
var hasher = std.hash.XxHash64.init(0);
|
||||
var hasher_writer = xxHash64Writer(&hasher);
|
||||
const writer = hasher_writer.writer();
|
||||
module.hash(&hasher);
|
||||
|
||||
// Hash the canonicalized IR, without debug information that can change across builds.
|
||||
module.op().print(writer, .{ .debug_info = false });
|
||||
// Note: before we were using module.op().writeBytecode(writer),
|
||||
// but it crashes on some inputs, notably for unused variables.
|
||||
// So we use the text representation of the mlir.
|
||||
// See https://github.com/zml/zml/issues/97.
|
||||
// Writes can't fail because we are writing to a hasher.
|
||||
writer.writeAll(platform.pjrt_client.getPlatformName(platform.pjrt_api)) catch unreachable;
|
||||
hasher.update(platform.pjrt_client.getPlatformName(platform.pjrt_api));
|
||||
const api_version = platform.pjrt_api.version();
|
||||
writer.writeInt(i64, api_version.major, .little) catch unreachable;
|
||||
writer.writeInt(i64, api_version.minor, .little) catch unreachable;
|
||||
hasher.update(std.mem.sliceAsBytes(&[_]i64{ api_version.major, api_version.minor }));
|
||||
|
||||
return hasher.final();
|
||||
}
|
||||
@ -1002,26 +933,6 @@ fn assignBlockArguments(v: anytype, block: mlir.Block, start: usize) usize {
|
||||
return context.index;
|
||||
}
|
||||
|
||||
pub const XxHash64Writer = struct {
|
||||
hasher: *std.hash.XxHash64,
|
||||
|
||||
pub const Error = error{};
|
||||
pub const Writer = std.io.Writer(*XxHash64Writer, Error, write);
|
||||
|
||||
pub fn writer(self: *XxHash64Writer) Writer {
|
||||
return .{ .context = self };
|
||||
}
|
||||
|
||||
pub fn write(self: *XxHash64Writer, bytes: []const u8) Error!usize {
|
||||
self.hasher.update(bytes);
|
||||
return bytes.len;
|
||||
}
|
||||
};
|
||||
|
||||
pub fn xxHash64Writer(hasher: *std.hash.XxHash64) XxHash64Writer {
|
||||
return .{ .hasher = hasher };
|
||||
}
|
||||
|
||||
pub const FnCache = std.AutoHashMapUnmanaged(FnKey, MlirFn);
|
||||
pub const FnKey = struct { fn_ptr: *const anyopaque, input_hash: u64 };
|
||||
|
||||
@ -1165,12 +1076,11 @@ pub fn hashShape(hasher: *std.hash.Wyhash, shape: Shape) void {
|
||||
}
|
||||
}
|
||||
|
||||
const HashStrategy = std.hash.Strategy;
|
||||
const tensorAwareHash = hash; // alias for when "hash" is ambiguous
|
||||
|
||||
/// Provides generic hashing for any eligible type.
|
||||
/// Strategy is provided to determine if pointers should be followed or not.
|
||||
pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: HashStrategy) void {
|
||||
pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: std.hash.Strategy) void {
|
||||
const Key = @TypeOf(key);
|
||||
if (Key == Tensor) return hashShape(hasher, key.shape());
|
||||
if (Key == Shape) return hashShape(hasher, key);
|
||||
@ -1287,7 +1197,7 @@ pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: HashStrategy
|
||||
}
|
||||
}
|
||||
|
||||
fn hashArray(hasher: anytype, key: anytype, comptime strat: HashStrategy) void {
|
||||
fn hashArray(hasher: anytype, key: anytype, comptime strat: std.hash.Strategy) void {
|
||||
for (key) |element| {
|
||||
hash(hasher, element, strat);
|
||||
}
|
||||
|
||||
@ -131,7 +131,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
||||
},
|
||||
&.{
|
||||
mlir.ext.mlirType(mlir_ctx, q.shape()),
|
||||
mlir.RankedTensorType.init(&.{0}, mlir.IntegerType(.u8).init(mlir_ctx).as(mlir.Type)).as(mlir.Type),
|
||||
.tensor(&.{0}, .int(mlir_ctx, .u8)),
|
||||
},
|
||||
loc,
|
||||
);
|
||||
|
||||
33
zml/ops.zig
33
zml/ops.zig
@ -147,9 +147,7 @@ pub fn reduce(
|
||||
.variadic_operands = &.{ &input_values, &init_values },
|
||||
.result_type_inference = true,
|
||||
.blocks = &.{body_block},
|
||||
.attributes = &.{
|
||||
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), axes).as(mlir.Attribute) },
|
||||
},
|
||||
.attributes = &.{.{ "dimensions", .dense(ctx.mlirCtx(), .i64, axes) }},
|
||||
// We can't verify right away, because the weights captured by the reduce haven't been added yet.
|
||||
.verify = false,
|
||||
.location = loc,
|
||||
@ -236,21 +234,16 @@ pub fn reduceWindow(
|
||||
ctx.extractValues(&inits, &init_values);
|
||||
|
||||
const loc = ctx.mlirCtx().location(@src());
|
||||
|
||||
const pad_shape = mlir.RankedTensorType.init(
|
||||
&.{ @intCast(opts.padding.len), 2 },
|
||||
mlir.ext.Type.fromDType(ctx.mlirCtx(), .i64),
|
||||
).as(mlir.Type);
|
||||
const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.reduce_window", .{
|
||||
.variadic_operands = &.{ input_values[0..], init_values[0..] },
|
||||
.result_type_inference = true,
|
||||
.blocks = &.{body_block},
|
||||
.attributes = &.{
|
||||
.{ "window_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dimensions).as(mlir.Attribute) },
|
||||
.{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_strides).as(mlir.Attribute) },
|
||||
.{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.base_dilations).as(mlir.Attribute) },
|
||||
.{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dilations).as(mlir.Attribute) },
|
||||
.{ "padding", mlir.DenseElementsAttribute(.i64).init(pad_shape, opts.padding).as(mlir.Attribute) },
|
||||
.{ "window_dimensions", .dense(ctx.mlirCtx(), .i64, opts.window_dimensions) },
|
||||
.{ "window_strides", .dense(ctx.mlirCtx(), .i64, opts.window_strides) },
|
||||
.{ "base_dilations", .dense(ctx.mlirCtx(), .i64, opts.base_dilations) },
|
||||
.{ "window_dilations", .dense(ctx.mlirCtx(), .i64, opts.window_dilations) },
|
||||
.{ "padding", .denseElements(ctx.mlirCtx(), &.{ @intCast(opts.padding.len), 2 }, .i64, opts.padding) },
|
||||
},
|
||||
.location = loc,
|
||||
});
|
||||
@ -841,13 +834,13 @@ pub fn triton(inputs: anytype, outputs: anytype, opts: TritonOps) [outputs.len]T
|
||||
}
|
||||
|
||||
const attrs = mlir.DictionaryAttribute.init(ctx.mlirCtx(), &.{
|
||||
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "name"), mlir.StringAttribute.init(ctx.mlirCtx(), opts.name).as(mlir.Attribute)),
|
||||
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "ir"), mlir.StringAttribute.init(ctx.mlirCtx(), opts.ir).as(mlir.Attribute)),
|
||||
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "grid_x"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.grid[0])).as(mlir.Attribute)),
|
||||
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "grid_y"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.grid[1])).as(mlir.Attribute)),
|
||||
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "grid_z"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.grid[2])).as(mlir.Attribute)),
|
||||
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "num_stages"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.num_stages)).as(mlir.Attribute)),
|
||||
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "num_warps"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.num_warps)).as(mlir.Attribute)),
|
||||
.named(ctx.mlirCtx(), "name", .string(ctx.mlirCtx(), opts.name)),
|
||||
.named(ctx.mlirCtx(), "ir", .string(ctx.mlirCtx(), opts.ir)),
|
||||
.named(ctx.mlirCtx(), "grid_x", .int(ctx.mlirCtx(), .i32, opts.grid[0])),
|
||||
.named(ctx.mlirCtx(), "grid_y", .int(ctx.mlirCtx(), .i32, opts.grid[1])),
|
||||
.named(ctx.mlirCtx(), "grid_z", .int(ctx.mlirCtx(), .i32, opts.grid[2])),
|
||||
.named(ctx.mlirCtx(), "num_stages", .int(ctx.mlirCtx(), .i32, opts.num_stages)),
|
||||
.named(ctx.mlirCtx(), "num_warps", .int(ctx.mlirCtx(), .i32, opts.num_warps)),
|
||||
});
|
||||
|
||||
const MINOR_TO_MAJOR = blk: {
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
const builtin = @import("builtin");
|
||||
const std = @import("std");
|
||||
const testing = std.testing;
|
||||
const builtin = @import("builtin");
|
||||
|
||||
const stdx = @import("stdx");
|
||||
|
||||
const testing = std.testing;
|
||||
|
||||
const DataType = @import("dtype.zig").DataType;
|
||||
|
||||
const EnumLiteral = @TypeOf(.enum_literal);
|
||||
|
||||
const log = std.log.scoped(.shape);
|
||||
@ -131,6 +132,10 @@ pub const Shape = struct {
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Creates a Shape with only `_dtype` set to `dt`; every other field keeps its default value
/// (presumably yielding a rank-0 shape — TODO confirm against the field defaults).
pub fn scalar(dt: DataType) Shape {
|
||||
return .{ ._dtype = dt };
|
||||
}
|
||||
|
||||
/// Creates a Shape with dims set to `.{0, 1, 2, ..., rank-1}`.
|
||||
pub fn range(rank_: usize, dt: DataType) Shape {
|
||||
var res: Shape = .{ ._dtype = dt };
|
||||
|
||||
@ -1,28 +1,29 @@
|
||||
const builtin = @import("builtin");
|
||||
const std = @import("std");
|
||||
const assert = std.debug.assert;
|
||||
const testing = std.testing;
|
||||
const builtin = @import("builtin");
|
||||
|
||||
const stdx = @import("stdx");
|
||||
|
||||
const meta = @import("meta.zig");
|
||||
const mlir = @import("mlir.zig");
|
||||
const ops = @import("ops.zig");
|
||||
const module = @import("module.zig");
|
||||
|
||||
const Location = mlir.Location;
|
||||
const CompilationContext = module.CompilationContext;
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
const Buffer = @import("buffer.zig").Buffer;
|
||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||
const Data = @import("dtype.zig").Data;
|
||||
const DataType = @import("dtype.zig").DataType;
|
||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||
const meta = @import("meta.zig");
|
||||
const mlir = @import("mlir.zig");
|
||||
const Location = mlir.Location;
|
||||
const module = @import("module.zig");
|
||||
const CompilationContext = module.CompilationContext;
|
||||
const ops = @import("ops.zig");
|
||||
const Platform = @import("platform.zig").Platform;
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
|
||||
const EnumLiteral = @TypeOf(.enum_literal);
|
||||
|
||||
/// Local namespace grouping the MLIR dialect bindings referenced in this file;
/// only `stablehlo` is exposed here.
const dialect = struct {
|
||||
const stablehlo = @import("mlir/dialects").stablehlo;
|
||||
};
|
||||
|
||||
const assert = std.debug.assert;
|
||||
const testing = std.testing;
|
||||
const scoped_log = std.log.scoped(.@"zml/tensor");
|
||||
|
||||
test {
|
||||
@ -173,15 +174,13 @@ pub const Tensor = struct {
|
||||
var res = self;
|
||||
res._shape = self._shape.withSharding(axes_);
|
||||
|
||||
const sharding = ctx.getShardingAttr(res._shape);
|
||||
|
||||
const op = dialect.stablehlo.custom_call(
|
||||
ctx.mlirCtx(),
|
||||
&.{self.value()},
|
||||
.{
|
||||
.call_target_name = "Sharding",
|
||||
.has_side_effect = false,
|
||||
.addional_attributes = &.{.{ "mhlo.sharding", sharding.asAttr() }},
|
||||
.addional_attributes = &.{.{ "mhlo.sharding", ctx.getShardingAttr(res._shape) }},
|
||||
.api_version = .original,
|
||||
},
|
||||
&.{self.value().getType()},
|
||||
@ -1014,11 +1013,11 @@ pub const Tensor = struct {
|
||||
if (to == self.dtype()) {
|
||||
return self;
|
||||
}
|
||||
|
||||
const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), to)).as(mlir.Type);
|
||||
const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) });
|
||||
|
||||
const op = dialect.stablehlo.convert(self.getContext().mlirCtx(), self.value(), res_type, loc);
|
||||
const mlir_ctx = self.getContext().mlirCtx();
|
||||
const res_type = mlir.ext.mlirType(mlir_ctx, self.shape().withDtype(to));
|
||||
const op = dialect.stablehlo.convert(mlir_ctx, self.value(), res_type, loc);
|
||||
return _result(self._shape.withDtype(to), op.result(0));
|
||||
}
|
||||
|
||||
@ -1455,7 +1454,7 @@ pub const Tensor = struct {
|
||||
/// .{ .a, .b, .c }.mergeTranspose(.{ .a, .c }, .ac) -> .{ .b, .ac }
|
||||
pub fn mergeTranspose(self: Tensor, axes_: anytype, merged: EnumLiteral) Tensor {
|
||||
// Step 1: call `contiguous` with `axes_`; step 2: reshape the result to the
// shape produced by `mergeAxis`, which combines those axes under the `merged` tag.
const cont = self.contiguous(axes_);
|
||||
// NOTE(review): the two `return` lines below come from a rendered diff hunk
// (one line changed); presumably the first is the pre-change form and the
// second, with `mergeAxis(merged, axes_)`, is the post-change form — confirm.
return cont.reshape(cont._shape.mergeAxis(axes_, merged));
|
||||
return cont.reshape(cont._shape.mergeAxis(merged, axes_));
|
||||
}
|
||||
|
||||
/// Transposes the input Tensor, such that the given axes end up in contiguous positions.
|
||||
@ -3107,7 +3106,7 @@ pub const Tensor = struct {
|
||||
var start_indices = [_]mlir.Value{constant(.{}, slice_.start.dtype().zero()).value()} ** MAX_RANK;
|
||||
start_indices[a] = slice_.start.value();
|
||||
|
||||
const op = dialect.stablehlo.dynamicSlice(
|
||||
const op = dialect.stablehlo.dynamic_slice(
|
||||
self.getContext().mlirCtx(),
|
||||
self.value(),
|
||||
new_shape.dims(),
|
||||
@ -3168,7 +3167,7 @@ pub const Tensor = struct {
|
||||
res_shape._dims.set(a, len);
|
||||
}
|
||||
}
|
||||
const op = dialect.stablehlo.dynamicSlice(self.getContext().mlirCtx(), self.value(), res_shape.dims(), offset_values[0..self.rank()], loc);
|
||||
const op = dialect.stablehlo.dynamic_slice(self.getContext().mlirCtx(), self.value(), res_shape.dims(), offset_values[0..self.rank()], loc);
|
||||
return _result(res_shape, op.result(0));
|
||||
}
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@ pub const nn = @import("nn.zig");
|
||||
pub const module = @import("module.zig");
|
||||
pub const meta = @import("meta.zig");
|
||||
pub const platform = @import("platform.zig");
|
||||
pub const mlir = @import("mlir.zig");
|
||||
pub const pjrt = @import("pjrtx.zig");
|
||||
pub const testing = @import("testing.zig");
|
||||
pub const torch = @import("torch.zig");
|
||||
|
||||
Loading…
Reference in New Issue
Block a user