diff --git a/mlir/BUILD.bazel b/mlir/BUILD.bazel index 0bebfa7..b4d6d29 100644 --- a/mlir/BUILD.bazel +++ b/mlir/BUILD.bazel @@ -9,6 +9,7 @@ cc_library( "@llvm-project//mlir:CAPIArith", "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:CAPIMath", + "@llvm-project//mlir:CAPISCF", "@llvm-project//mlir:CAPITransforms", ], ) diff --git a/mlir/c.h b/mlir/c.h index 014666a..4ad1523 100644 --- a/mlir/c.h +++ b/mlir/c.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include diff --git a/mlir/dialects/BUILD.bazel b/mlir/dialects/BUILD.bazel index 9570cb8..c4c8220 100644 --- a/mlir/dialects/BUILD.bazel +++ b/mlir/dialects/BUILD.bazel @@ -7,6 +7,7 @@ zig_library( "arith.zig", "func.zig", "math.zig", + "scf.zig", "tensor.zig", ], import_name = "mlir/dialects", diff --git a/mlir/dialects/arith.zig b/mlir/dialects/arith.zig index c005f6a..f6d86d5 100644 --- a/mlir/dialects/arith.zig +++ b/mlir/dialects/arith.zig @@ -1,11 +1,10 @@ const std = @import("std"); + const mlir = @import("mlir"); pub fn constant(ctx: mlir.Context, value: mlir.Attribute, location: mlir.Location) mlir.Operation { return mlir.Operation.make(ctx, "arith.constant", .{ - .attributes = &.{ - .{ "value", value }, - }, + .attributes = &.{.{ "value", value }}, .result_type_inference = true, .location = location, }); @@ -44,6 +43,8 @@ pub const mulf = binary_fn("arith.mulf"); pub const divsi = binary_fn("arith.divsi"); pub const divui = binary_fn("arith.divui"); pub const divf = binary_fn("arith.divf"); +pub const maxnumf = binary_fn("arith.maxnumf"); +pub const maxnumi = binary_fn("arith.maxnumi"); pub const extsi = cast_fn("arith.extsi"); pub const extui = cast_fn("arith.extui"); pub const extf = cast_fn("arith.extf"); diff --git a/mlir/dialects/dialects.zig b/mlir/dialects/dialects.zig index 1da34b7..47f74d8 100644 --- a/mlir/dialects/dialects.zig +++ b/mlir/dialects/dialects.zig @@ -3,6 +3,7 @@ const std = @import("std"); pub const arith = @import("arith.zig"); pub const func = 
@import("func.zig"); pub const math = @import("math.zig"); +pub const scf = @import("scf.zig"); pub const tensor = @import("tensor.zig"); pub const stablehlo = @import("mlir/dialects/stablehlo"); diff --git a/mlir/dialects/func.zig b/mlir/dialects/func.zig index ec31e55..9db3612 100644 --- a/mlir/dialects/func.zig +++ b/mlir/dialects/func.zig @@ -1,4 +1,5 @@ const std = @import("std"); + const mlir = @import("mlir"); pub fn func( @@ -14,14 +15,14 @@ pub fn func( }, ) mlir.Operation { var attrs_tuple_buffer = std.BoundedArray(mlir.AttrTuple, 4){}; - attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute) }); - attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type)).as(mlir.Attribute) }); + attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", .string(ctx, args.sym_name) }); + attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", .type_(.function(ctx, args.args, args.results)) }); if (args.arg_attrs.len > 0) { - attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", mlir.ArrayAttribute.init(ctx, args.arg_attrs).as(mlir.Attribute) }); + attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", .array(ctx, args.arg_attrs) }); } if (args.res_attrs.len > 0) { - attrs_tuple_buffer.appendAssumeCapacity(.{ "res_attrs", mlir.ArrayAttribute.init(ctx, args.res_attrs).as(mlir.Attribute) }); + attrs_tuple_buffer.appendAssumeCapacity(.{ "res_attrs", .array(ctx, args.res_attrs) }); } return mlir.Operation.make(ctx, "func.func", .{ @@ -36,7 +37,7 @@ pub fn call(ctx: mlir.Context, name: [:0]const u8, values: []const mlir.Value, r .variadic_operands = &.{values}, .results = results, .verify = true, - .attributes = &.{.{ "callee", mlir.FlatSymbolRefAttribute.init(ctx, name).as(mlir.Attribute) }}, + .attributes = &.{.{ "callee", .symbol(ctx, name) }}, .location = loc, }); } diff --git 
a/mlir/dialects/math.zig b/mlir/dialects/math.zig index 766f40c..62ef622 100644 --- a/mlir/dialects/math.zig +++ b/mlir/dialects/math.zig @@ -8,7 +8,7 @@ fn unary_fn(comptime op_name: [:0]const u8) type { pub fn call(ctx: mlir.Context, value: mlir.Value, location: mlir.Location) mlir.Operation { return mlir.Operation.make(ctx, namespace ++ "." ++ op_name, .{ .operands = &.{value}, - .results = &.{}, + .results = &.{value.getType()}, .location = location, }); } @@ -32,4 +32,5 @@ pub const fpowi = binary_fn("fpowi").call; pub const tanh = unary_fn("tanh").call; pub const sqrt = unary_fn("sqrt").call; pub const exp = unary_fn("exp").call; +pub const exp2 = unary_fn("exp2").call; pub const log = unary_fn("log").call; diff --git a/mlir/dialects/scf.zig b/mlir/dialects/scf.zig new file mode 100644 index 0000000..61fcb18 --- /dev/null +++ b/mlir/dialects/scf.zig @@ -0,0 +1,61 @@ +const std = @import("std"); + +const mlir = @import("mlir"); + +pub fn ForBody(ExtraArgs: type) type { + return fn (mlir.Context, mlir.Block, ExtraArgs) mlir.Operation; +} + +pub const ForRange = struct { + start: mlir.Value, + end: mlir.Value, + step: mlir.Value, +}; + +pub fn @"for"( + ExtraArgs: type, + ctx: mlir.Context, + range: ForRange, + init_values: []const mlir.Value, + body: ForBody(ExtraArgs), + extra_args: ExtraArgs, + loc: mlir.Location, +) mlir.Operation { + const n_args = init_values.len; + var init_types_buf: [32]mlir.Type = undefined; + var locs_buf: [32]mlir.Location = undefined; + + // The first block argument is the for loop induction variable, + // followed then by all the loop-carried variables. + const init_types = init_types_buf[0 .. n_args + 1]; + const locs = locs_buf[0 .. 
n_args + 1]; + init_types[0] = range.start.getType(); + locs[0] = loc; + for (1.., init_values) |i, val| { + init_types[i] = val.getType(); + locs[i] = loc; + } + + const block = mlir.Block.init(init_types, locs) catch unreachable; + const yield_op = @call(.auto, body, .{ ctx, block, extra_args }); + std.debug.assert(std.mem.eql(u8, "scf.yield", yield_op.name().str())); + block.appendOperationRecursive(yield_op, .open); + + const for_op = mlir.Operation.make(ctx, "scf.for", .{ + .variadic_operands = &.{ &.{ range.start, range.end, range.step }, init_values }, + .results = init_types[1..], + .blocks = &.{block}, + .location = loc, + .verify = false, + }); + return for_op; +} + +pub fn yield(ctx: mlir.Context, res: []const mlir.Value, loc: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "scf.yield", .{ + .variadic_operands = &.{res}, + .results = &.{}, + .location = loc, + .verify = false, + }); +} diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig index 117f809..c5870a2 100644 --- a/mlir/dialects/stablehlo.zig +++ b/mlir/dialects/stablehlo.zig @@ -99,7 +99,7 @@ pub fn cholesky(ctx: mlir.Context, value: mlir.Value, lower: bool, location: mli .operands = &.{value}, .result_type_inference = true, .attributes = &.{ - .{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(lower))).as(mlir.Attribute) }, + .{ "lower", .i1FromBool(ctx, lower) }, }, .location = location, }); @@ -189,7 +189,7 @@ pub fn dot_general( precision: DotPrecision, }, ) mlir.Operation { - const precisions = [1]mlir.Attribute{opts.precision.precisionAttr(ctx)} ** 2; + const precisions: [2]mlir.Attribute = @splat(opts.precision.precisionAttr(ctx)); const attributes = [3]mlir.AttrTuple{ .{ "dot_dimension_numbers", DotDimensionNumbersAttribute.init(ctx, .{ @@ -199,7 +199,7 @@ pub fn dot_general( .rhs_contracting_dimensions = opts.rhs_contracting_dimensions, }).as(mlir.Attribute), }, - .{ "precision_config", mlir.ArrayAttribute.init(ctx, 
&precisions).asAttr() }, + .{ "precision_config", .array(ctx, &precisions) }, // keep algorithm as the last attribute so we can omit it when it's not set. .{ "algorithm", opts.precision.algorithmAttr(ctx, lhs.getType()) orelse undefined }, }; @@ -244,7 +244,7 @@ pub fn broadcast_in_dim(ctx: mlir.Context, operand: mlir.Value, dims: []const i6 .operands = &.{operand}, .results = &.{result_type}, .attributes = &.{ - .{ "broadcast_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dims).as(mlir.Attribute) }, + .{ "broadcast_dimensions", .dense(ctx, .i64, dims) }, }, .location = location, }); @@ -255,7 +255,7 @@ pub fn transpose(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, l .operands = &.{value}, .results = &.{result_type}, .attributes = &.{ - .{ "permutation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.permutation).as(mlir.Attribute) }, + .{ "permutation", .dense(ctx, .i64, opts.permutation) }, }, .location = location, }); @@ -266,9 +266,9 @@ pub fn slice(ctx: mlir.Context, operand: mlir.Value, start_indices: []const i64, .operands = &.{operand}, .results = &.{result_type}, .attributes = &.{ - .{ "start_indices", mlir.DenseArrayAttribute(.i64).init(ctx, start_indices).as(mlir.Attribute) }, - .{ "limit_indices", mlir.DenseArrayAttribute(.i64).init(ctx, limit_indices).as(mlir.Attribute) }, - .{ "strides", mlir.DenseArrayAttribute(.i64).init(ctx, strides).as(mlir.Attribute) }, + .{ "start_indices", .dense(ctx, .i64, start_indices) }, + .{ "limit_indices", .dense(ctx, .i64, limit_indices) }, + .{ "strides", .dense(ctx, .i64, strides) }, }, .location = location, }); @@ -279,7 +279,7 @@ pub fn concatenate(ctx: mlir.Context, inputs: []const mlir.Value, dimension: i64 .operands = inputs, .result_type_inference = true, .attributes = &.{ - .{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) }, + .{ "dimension", .int(ctx, .i64, dimension) }, }, .location = location, }); @@ -333,8 +333,8 @@ pub fn gather( 
args.start_index_map, args.index_vector_dim, ).as(mlir.Attribute) }, - .{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, slice_sizes).as(mlir.Attribute) }, - .{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute) }, + .{ "slice_sizes", .dense(ctx, .i64, slice_sizes) }, + .{ "indices_are_sorted", .boolean(ctx, args.indices_are_sorted) }, }, .location = location, }, @@ -394,8 +394,8 @@ pub fn scatter( .blocks = &.{update_block}, .attributes = &.{ .{ "scatter_dimension_numbers", args.getScatterDimensionNumbers(ctx) }, - .{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute) }, - .{ "unique_indices", mlir.BoolAttribute.init(ctx, args.unique_indices).as(mlir.Attribute) }, + .{ "indices_are_sorted", .boolean(ctx, args.indices_are_sorted) }, + .{ "unique_indices", .boolean(ctx, args.unique_indices) }, }, .result_type_inference = true, .location = location, @@ -408,7 +408,7 @@ pub fn iota(ctx: mlir.Context, dimension: i64, result_type: mlir.Type, location: .operands = &.{}, .results = &.{result_type}, .attributes = &.{ - .{ "iota_dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) }, + .{ "iota_dimension", .int(ctx, .i64, dimension) }, }, .location = location, }); @@ -420,7 +420,7 @@ pub fn reverse(ctx: mlir.Context, operand: mlir.Value, dimensions: []const i64, .operands = &.{operand}, .results = &.{result_type}, .attributes = &.{ - .{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute) }, + .{ "dimensions", .dense(ctx, .i64, dimensions) }, }, .location = location, }); @@ -453,7 +453,7 @@ pub fn reduce( const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0..block_n_args]; var reduce_elem_types: [MaxBlockArguments]mlir.Type = undefined; for (inputs, 0..) 
|input, i| { - const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type); + const arg_type: mlir.Type = .tensor(&.{}, elementTypeOrSelf(input.getType())); reduce_elem_types[i] = arg_type; reduce_elem_types[inputs.len + i] = arg_type; } @@ -475,7 +475,7 @@ pub fn reduce( .result_type_inference = true, .block = block, .attributes = &.{ - .{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute) }, + .{ "dimensions", .dense(ctx, .i64, dimensions) }, }, .location = location, }); @@ -495,7 +495,7 @@ pub fn sort( const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0 .. inputs.len * 2]; var sort_elem_types: [MaxBlockArguments]mlir.Type = undefined; for (inputs, 0..) |input, i| { - const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type); + const arg_type: mlir.Type = .tensor(&.{}, elementTypeOrSelf(input.getType())); sort_elem_types[i * 2] = arg_type; sort_elem_types[i * 2 + 1] = arg_type; } @@ -512,19 +512,19 @@ pub fn sort( .result_type_inference = true, .block = block, .attributes = &.{ - .{ "dimension", mlir.IntegerAttribute(.i64).init(ctx, dimension).as(mlir.Attribute) }, - .{ "is_stable", mlir.BoolAttribute.init(ctx, is_stable).as(mlir.Attribute) }, + .{ "dimension", .int(ctx, .i64, dimension) }, + .{ "is_stable", .boolean(ctx, is_stable) }, }, .location = location, }); } -pub fn dynamicSlice(ctx: mlir.Context, operand: mlir.Value, new_dims: []const i64, start_indices: []const mlir.Value, location: mlir.Location) mlir.Operation { +pub fn dynamic_slice(ctx: mlir.Context, operand: mlir.Value, new_dims: []const i64, start_indices: []const mlir.Value, location: mlir.Location) mlir.Operation { return mlir.Operation.make(ctx, "stablehlo.dynamic_slice", .{ .variadic_operands = &.{ &.{operand}, start_indices }, .result_type_inference = true, .attributes = &.{ - .{ "slice_sizes", mlir.DenseArrayAttribute(.i64).init(ctx, 
new_dims).as(mlir.Attribute) }, + .{ "slice_sizes", .dense(ctx, .i64, new_dims) }, }, .location = location, }); @@ -557,9 +557,9 @@ pub fn pad(ctx: mlir.Context, value: mlir.Value, padding_value: mlir.Value, opts .operands = &.{ value, padding_value }, .result_type_inference = true, .attributes = &.{ - .{ "edge_padding_low", mlir.DenseArrayAttribute(.i64).init(ctx, opts.low).as(mlir.Attribute) }, - .{ "edge_padding_high", mlir.DenseArrayAttribute(.i64).init(ctx, opts.high).as(mlir.Attribute) }, - .{ "interior_padding", mlir.DenseArrayAttribute(.i64).init(ctx, opts.interior).as(mlir.Attribute) }, + .{ "edge_padding_low", .dense(ctx, .i64, opts.low) }, + .{ "edge_padding_high", .dense(ctx, .i64, opts.high) }, + .{ "interior_padding", .dense(ctx, .i64, opts.interior) }, }, .location = location, }); @@ -577,9 +577,9 @@ pub fn triangular_solve(ctx: mlir.Context, value: mlir.Value, other: mlir.Value, .operands = &.{ value, other }, .result_type_inference = true, .attributes = &.{ - .{ "left_side", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.left_side))).as(mlir.Attribute) }, - .{ "lower", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.lower))).as(mlir.Attribute) }, - .{ "unit_diagonal", mlir.IntegerAttribute(.i1).init(ctx, @intCast(@intFromBool(opts.unit_diagonal))).as(mlir.Attribute) }, + .{ "left_side", .i1FromBool(ctx, opts.left_side) }, + .{ "lower", .i1FromBool(ctx, opts.lower) }, + .{ "unit_diagonal", .i1FromBool(ctx, opts.unit_diagonal) }, .{ "transpose_a", Transpose.init(ctx, opts.transpose_a).as(mlir.Attribute) }, }, .location = location, @@ -597,7 +597,7 @@ pub fn fft(ctx: mlir.Context, value: mlir.Value, location: mlir.Location, opts: .result_type_inference = true, .attributes = &.{ .{ "fft_type", FftType.init(ctx, opts.kind).as(mlir.Attribute) }, - .{ "fft_length", mlir.DenseArrayAttribute(.i64).init(ctx, opts.length).as(mlir.Attribute) }, + .{ "fft_length", .dense(ctx, .i64, opts.length) }, }, .location = location, }); @@ 
-630,8 +630,8 @@ pub fn reduce_precision(ctx: mlir.Context, value: mlir.Value, exponent_bits: i32 .operands = &.{value}, .result_type_inference = true, .attributes = &.{ - .{ "exponent_bits", mlir.IntegerAttribute(.i32).init(ctx, exponent_bits).as(mlir.Attribute) }, - .{ "mantissa_bits", mlir.IntegerAttribute(.i32).init(ctx, mantissa_bits).as(mlir.Attribute) }, + .{ "exponent_bits", .int(ctx, .i32, exponent_bits) }, + .{ "mantissa_bits", .int(ctx, .i32, mantissa_bits) }, }, .location = location, }); @@ -658,7 +658,7 @@ pub fn get_tuple_element(ctx: mlir.Context, tuple_value: mlir.Value, index: i64, .operands = &.{tuple_value}, .result_type_inference = true, .attributes = &.{ - .{ "index", mlir.IntegerAttribute(.i32).init(ctx, index).as(mlir.Attribute) }, + .{ "index", .int(ctx, .i32, index) }, }, .location = location, }); @@ -701,17 +701,15 @@ pub fn convolution( for (opts.window_reversal, 0..) |w, i| { window_reversal[i] = @intCast(@intFromBool(w)); } - const pad_type = mlir.IntegerType(.i64).init(ctx).as(mlir.Type); - const pad_shape = mlir.RankedTensorType.init(opts.pad_shape, pad_type).as(mlir.Type); return mlir.Operation.make(ctx, "stablehlo.convolution", .{ .operands = &.{ lhs, rhs }, .results = &.{res_type}, .attributes = &.{ - .{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_strides).as(mlir.Attribute) }, - .{ "padding", mlir.DenseElementsAttribute(.i64).init(pad_shape, opts.pad_value).as(mlir.Attribute) }, - .{ "lhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.lhs_dilation).as(mlir.Attribute) }, - .{ "rhs_dilation", mlir.DenseArrayAttribute(.i64).init(ctx, opts.rhs_dilation).as(mlir.Attribute) }, - .{ "window_reversal", mlir.DenseArrayAttribute(.bool).init(ctx, window_reversal[0..opts.window_reversal.len]).as(mlir.Attribute) }, + .{ "window_strides", .dense(ctx, .i64, opts.window_strides) }, + .{ "padding", .denseElements(ctx, opts.pad_shape, .i64, opts.pad_value) }, + .{ "lhs_dilation", .dense(ctx, .i64, 
opts.lhs_dilation) }, + .{ "rhs_dilation", .dense(ctx, .i64, opts.rhs_dilation) }, + .{ "window_reversal", .dense(ctx, .bool, window_reversal[0..opts.window_reversal.len]) }, .{ "dimension_numbers", ConvDimensionNumbersAttribute.init(ctx, .{ .input_batch_dimension = opts.input_batch_dimension, @@ -725,9 +723,9 @@ pub fn convolution( .output_spatial_dimensions = opts.output_spatial_dimensions, }).as(mlir.Attribute), }, - .{ "feature_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.feature_group_count).as(mlir.Attribute) }, - .{ "batch_group_count", mlir.IntegerAttribute(.i64).init(ctx, opts.batch_group_count).as(mlir.Attribute) }, - .{ "precision_config", mlir.ArrayAttribute.init(ctx, &max_precisions).as(mlir.Attribute) }, + .{ "feature_group_count", .int(ctx, .i64, opts.feature_group_count) }, + .{ "batch_group_count", .int(ctx, .i64, opts.batch_group_count) }, + .{ "precision_config", .array(ctx, &max_precisions) }, }, .location = location, }); @@ -764,25 +762,20 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; for (opts.output_operand_aliases) |alias| { ret.appendAssumeCapacity( - OutputOperandAliasAttribute.init( - ctx, - &.{}, - alias, - &.{}, - ).as(mlir.Attribute), + OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute), ); } break :blk ret; }; - const backend_config = switch (opts.backend_config) { + const backend_config: mlir.Attribute = switch (opts.backend_config) { .string => blk: { stdx.debug.assert( @intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi), "Only API version of less than 4 is supported for backend_config as string", .{}, ); - break :blk mlir.StringAttribute.init(ctx, opts.backend_config.string).as(mlir.Attribute); + break :blk .string(ctx, opts.backend_config.string); }, .dict => blk: { stdx.debug.assert( @@ -797,45 +790,33 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const 
mlir.Value, opts: CustomCa var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{}; attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{ - .{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, @intFromEnum(opts.api_version)).as(mlir.Attribute) }, - .{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) }, - .{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) }, + .{ "api_version", .int(ctx, .i32, @intFromEnum(opts.api_version)) }, + .{ "call_target_name", .string(ctx, opts.call_target_name) }, + .{ "has_side_effect", .boolean(ctx, opts.has_side_effect) }, .{ "backend_config", backend_config }, - .{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases.constSlice()).as(mlir.Attribute) }, + .{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) }, }); if (opts.operand_layouts) |layouts| { const operand_layouts = blk: { var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; for (layouts) |ol| { - const tensor_type = mlir.RankedTensorType.init( - &.{@intCast(ol.len)}, - mlir.IndexType.init(ctx).as(mlir.Type), - ).as(mlir.Type); - const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, ol); - ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute)); + ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol)); } break :blk ret; }; - const attr: mlir.AttrTuple = .{ "operand_layouts", mlir.ArrayAttribute.init(ctx, operand_layouts.constSlice()).as(mlir.Attribute) }; - attrs.appendAssumeCapacity(attr); + attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) }); } if (opts.result_layouts) |layouts| { const result_layouts = blk: { var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; for (layouts) |rl| { - const tensor_type = mlir.RankedTensorType.init( - &.{@intCast(rl.len)}, - mlir.IndexType.init(ctx).as(mlir.Type), - ).as(mlir.Type); - const layout_attr = 
mlir.DenseElementsAttribute(.index).init(tensor_type, rl); - ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute)); + ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl)); } break :blk ret; }; - const attr: mlir.AttrTuple = .{ "result_layouts", mlir.ArrayAttribute.init(ctx, result_layouts.constSlice()).as(mlir.Attribute) }; - attrs.appendAssumeCapacity(attr); + attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) }); } attrs.appendSliceAssumeCapacity(opts.addional_attributes); @@ -852,9 +833,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa pub fn annotate_device_placement(ctx: mlir.Context, inputs: []const mlir.Value, memory_kind: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation { const frontend_attributes = mlir.DictionaryAttribute.init( ctx, - &.{ - mlir.NamedAttribute.init(mlir.Identifier.get(ctx, "_xla_buffer_placement"), memory_kind.asAttr()), - }, + &.{.named(ctx, "_xla_buffer_placement", memory_kind.asAttr())}, ).asAttr(); return custom_call(ctx, inputs, .{ diff --git a/mlir/mlir.zig b/mlir/mlir.zig index 1ccc766..4e2336a 100644 --- a/mlir/mlir.zig +++ b/mlir/mlir.zig @@ -1,9 +1,10 @@ -const builtin = @import("builtin"); const std = @import("std"); -const stdx = @import("stdx"); -const log = std.log.scoped(.mlir); +const builtin = @import("builtin"); const c = @import("c"); +const stdx = @import("stdx"); + +const log = std.log.scoped(.mlir); test { std.testing.refAllDecls(@This()); @@ -254,6 +255,10 @@ pub const Module = struct { pub fn op(self: Module) Operation { return Operation.wrap(c.mlirModuleGetOperation(self.inner())); } + + pub fn hash(self: Module, hasher: *std.hash.XxHash64) void { + return self.op().hash(hasher); + } }; pub const PassManager = struct { @@ -352,21 +357,81 @@ pub const Attribute = struct { ) orelse Error.InvalidMlir; } - pub fn getNull() Self { - return 
Self.wrap(c.mlirAttributeGetNull()); + // Utility functions to build common attributes. + // All attributes are upcast to the Attribute type, making it easier to chain-construct, + // but losing type information. + + pub fn null_() Attribute { + return .wrap(c.mlirAttributeGetNull()); + } + + pub fn string(ctx: Context, str: []const u8) Attribute { + return StringAttribute.init(ctx, str).asAttr(); + } + + pub fn type_(t: Type) Attribute { + return TypeAttribute.init(t).asAttr(); + } + + pub fn unit(ctx: Context) Attribute { + return .wrap(c.mlirUnitAttrGet(ctx.inner())); + } + + pub fn boolean(ctx: Context, value: bool) Attribute { + return BoolAttribute.init(ctx, value).asAttr(); + } + + pub fn i1FromBool(ctx: Context, value: bool) Attribute { + return IntegerAttribute(.i1).init(ctx, @intFromBool(value)).asAttr(); + } + + pub fn int(ctx: Context, comptime int_type: IntegerTypes, value: i64) Attribute { + return IntegerAttribute(int_type).init(ctx, value).asAttr(); + } + + pub fn float(ctx: Context, comptime float_type: FloatTypes, value: f64) Attribute { + return .wrap(FloatAttribute(float_type).init(ctx, value)._inner); + } + + pub fn array(ctx: Context, attrs: []const Attribute) Attribute { + return ArrayAttribute.init(ctx, attrs).asAttr(); + } + + pub fn dense(ctx: Context, comptime dt: DenseArrayTypes, values: []const dt.ZigType()) Attribute { + return DenseArrayAttribute(dt).init(ctx, values).asAttr(); + } + + /// Use a tensor as an attribute. + /// The tensor is specified by dims, dtype and a flat slice of values. 
+ pub fn denseElements(ctx: Context, dims: []const i64, comptime dt: DenseElementsAttributeTypes, values: anytype) Attribute { + return .wrap(DenseElementsAttribute(dt).init(.tensor(dims, dt.mlirType(ctx)), values).inner()); + } + + pub fn symbol(ctx: Context, flat_name: [:0]const u8) Attribute { + return .wrap(FlatSymbolRefAttribute.init(ctx, flat_name).inner()); + } + + pub fn named(attr: Attribute, ctx: Context, name: [:0]const u8) NamedAttribute { + return NamedAttribute.init(Identifier.get(ctx, name), attr); } }; -pub const NamedAttribute = struct { - _inner: c.MlirNamedAttribute, - pub usingnamespace MlirHelpers(NamedAttribute, .{}); - const Self = NamedAttribute; +pub const NamedAttribute = extern struct { + name: c.MlirIdentifier, + attribute: c.MlirAttribute, - pub fn init(name: Identifier, attr: Attribute) Self { - return Self.wrap(.{ + pub fn named(ctx: Context, name: [:0]const u8, attr: Attribute) NamedAttribute { + return .{ + .name = c.mlirIdentifierGet(ctx._inner, stringRef(name)), + .attribute = attr.inner(), + }; + } + + pub fn init(name: Identifier, attr: Attribute) NamedAttribute { + return .{ .name = name.inner(), .attribute = attr.inner(), - }); + }; } }; @@ -393,21 +458,6 @@ pub const StringAttribute = struct { } }; -pub const UnitAttribute = struct { - _inner: c.MlirAttribute, - pub usingnamespace MlirHelpers(UnitAttribute, .{ - .is_a_fn = c.mlirAttributeIsAUnit, - .is_null_fn = c.mlirAttributeIsNull, - .dump_fn = c.mlirAttributeDump, - .equal_fn = c.mlirAttributeEqual, - }); - const Self = UnitAttribute; - - pub fn init(ctx: Context) Self { - return Self.wrap(c.mlirUnitAttrGet(ctx.inner())); - } -}; - pub const BoolAttribute = struct { _inner: c.MlirAttribute, pub usingnamespace MlirHelpers(BoolAttribute, .{ @@ -481,8 +531,8 @@ pub const ArrayAttribute = struct { pub fn IntegerAttribute(comptime it: IntegerTypes) type { const ZigType, const getter = comptime switch (it) { - .i1, .i4, .i8, .i16, .i32, .i64 => .{ u64, 
c.mlirIntegerAttrGetValueInt }, - .si4, .si8, .si16, .si32, .si64 => .{ u64, c.mlirIntegerAttrGetValueSInt }, + .i1, .i4, .i8, .i16, .i32, .i64 => .{ i64, c.mlirIntegerAttrGetValueInt }, + .si4, .si8, .si16, .si32, .si64 => .{ i64, c.mlirIntegerAttrGetValueSInt }, .u4, .u8, .u16, .u32, .u64 => .{ u64, c.mlirIntegerAttrGetValueUInt }, .unknown => @compileError("IntegerAttribute(unknown)"), }; @@ -547,38 +597,48 @@ pub const DenseArrayTypes = enum { i64, f32, f64, + + pub fn ZigType(comptime dt: DenseArrayTypes) type { + return switch (dt) { + .bool => i32, + .i8 => i8, + .i16 => i16, + .i32 => i32, + .i64 => i64, + .f32 => f32, + .f64 => f64, + }; + } }; pub fn DenseArrayAttribute(comptime dt: DenseArrayTypes) type { - const Config = switch (dt) { - .bool => .{ i32, c.mlirAttributeIsADenseBoolArray, c.mlirDenseBoolArrayGet, c.mlirDenseBoolArrayGetElement }, - .i8 => .{ i8, c.mlirAttributeIsADenseI8Array, c.mlirDenseI8ArrayGet, c.mlirDenseI8ArrayGetElement }, - .i16 => .{ i16, c.mlirAttributeIsADenseI16Array, c.mlirDenseI16ArrayGet, c.mlirDenseI16ArrayGetElement }, - .i32 => .{ i32, c.mlirAttributeIsADenseI32Array, c.mlirDenseI32ArrayGet, c.mlirDenseI32ArrayGetElement }, - .i64 => .{ i64, c.mlirAttributeIsADenseI64Array, c.mlirDenseI64ArrayGet, c.mlirDenseI64ArrayGetElement }, - .f32 => .{ f32, c.mlirAttributeIsADenseF32Array, c.mlirDenseF32ArrayGet, c.mlirDenseF32ArrayGetElement }, - .f64 => .{ f64, c.mlirAttributeIsADenseF64Array, c.mlirDenseF64ArrayGet, c.mlirDenseF64ArrayGetElement }, + const is_a_fn, const get_fn, const get_element_fn = switch (dt) { + .bool => .{ c.mlirAttributeIsADenseBoolArray, c.mlirDenseBoolArrayGet, c.mlirDenseBoolArrayGetElement }, + .i8 => .{ c.mlirAttributeIsADenseI8Array, c.mlirDenseI8ArrayGet, c.mlirDenseI8ArrayGetElement }, + .i16 => .{ c.mlirAttributeIsADenseI16Array, c.mlirDenseI16ArrayGet, c.mlirDenseI16ArrayGetElement }, + .i32 => .{ c.mlirAttributeIsADenseI32Array, c.mlirDenseI32ArrayGet, c.mlirDenseI32ArrayGetElement }, + .i64 
=> .{ c.mlirAttributeIsADenseI64Array, c.mlirDenseI64ArrayGet, c.mlirDenseI64ArrayGetElement }, + .f32 => .{ c.mlirAttributeIsADenseF32Array, c.mlirDenseF32ArrayGet, c.mlirDenseF32ArrayGetElement }, + .f64 => .{ c.mlirAttributeIsADenseF64Array, c.mlirDenseF64ArrayGet, c.mlirDenseF64ArrayGetElement }, }; return struct { _inner: c.MlirAttribute, pub usingnamespace MlirHelpers(@This(), .{ - .is_a_fn = Config[1], + .is_a_fn = is_a_fn, .is_null_fn = c.mlirAttributeIsNull, .dump_fn = c.mlirAttributeDump, .equal_fn = c.mlirAttributeEqual, }); const Attr = @This(); const ElementType = dt; - const ElementTypeZig = Config[0]; + const ElementTypeZig = dt.ZigType(); pub fn init(ctx: Context, values: []const ElementTypeZig) Attr { - const get_fn = Config[2]; return Attr.wrap(get_fn(ctx.inner(), @intCast(values.len), @ptrCast(values.ptr))); } pub fn get(self: Attr, pos: usize) ElementTypeZig { - const get_element_fn = Config[3]; return get_element_fn(self.inner(), @intCast(pos)); } @@ -586,6 +646,10 @@ pub fn DenseArrayAttribute(comptime dt: DenseArrayTypes) type { return @intCast(c.mlirDenseArrayGetNumElements(self.inner())); } + pub fn asAttr(self: Attr) Attribute { + return Attribute.wrap(self._inner); + } + pub usingnamespace switch (dt) { .bool, .i64 => struct { const DenseArray = DenseArrayAttribute(switch (dt) { @@ -614,26 +678,47 @@ pub const DenseElementsAttributeTypes = enum { f32, f64, index, + + pub fn ZigType(comptime dt: DenseElementsAttributeTypes) type { + return switch (dt) { + .bool => bool, + .i8 => i8, + .i16 => i16, + .i32 => i32, + .i64 => i64, + .u8 => u8, + .u16 => u16, + .u32 => u32, + .u64 => u64, + .bf16 => u16, + .f16 => f16, + .f32 => f32, + .f64 => f64, + .index => usize, + }; + } + + pub fn mlirType(dt: DenseElementsAttributeTypes, ctx: Context) Type { + return switch (dt) { + .bool => .int(ctx, .i1), + .i8 => .int(ctx, .i8), + .i16 => .int(ctx, .i16), + .i32 => .int(ctx, .i32), + .i64 => .int(ctx, .i64), + .u8 => .int(ctx, .u8), + .u16 => 
.int(ctx, .u16), + .u32 => .int(ctx, .u32), + .u64 => .int(ctx, .u64), + .bf16 => .float(ctx, .bf16), + .f16 => .float(ctx, .f16), + .f32 => .float(ctx, .f32), + .f64 => .float(ctx, .f64), + .index => .index(ctx), + }; + } }; pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type { - const ZigType = switch (dt) { - .bool => bool, - .i8 => i8, - .i16 => i16, - .i32 => i32, - .i64 => i64, - .u8 => u8, - .u16 => u16, - .u32 => u32, - .u64 => u64, - .bf16 => u16, - .f16 => f16, - .f32 => f32, - .f64 => f64, - .index => usize, - }; - return struct { _inner: c.MlirAttribute, @@ -662,8 +747,8 @@ pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type { return @intCast(c.mlirElementsAttrGetNumElements(self.inner())); } - pub fn constSlice(self: Attr) []const ZigType { - const ptr: [*]const ZigType = @constCast(@ptrCast(@alignCast(c.mlirDenseElementsAttrGetRawData(self.inner()) orelse unreachable))); + pub fn constSlice(self: Attr) []const dt.ZigType() { + const ptr: [*]const dt.ZigType() = @constCast(@ptrCast(@alignCast(c.mlirDenseElementsAttrGetRawData(self.inner()) orelse unreachable))); return ptr[0..self.len()]; } @@ -814,6 +899,7 @@ pub const Operation = struct { pub fn make(ctx: Context, op_name: [:0]const u8, args: struct { operands: ?[]const Value = null, variadic_operands: ?[]const []const Value = null, + tt_variadic_operands: ?[]const []const Value = null, results: ?[]const Type = null, variadic_results: ?[]const []const Type = null, result_type_inference: ?bool = null, @@ -828,9 +914,24 @@ pub const Operation = struct { if (args.operands) |operands| { state.addOperands(operands); } else if (args.variadic_operands) |operands_segments| { + const MAX_SEGMENTS = 32; + var segments: std.BoundedArray(i32, MAX_SEGMENTS) = .{}; + for (operands_segments) |operands| { state.addOperands(operands); + segments.appendAssumeCapacity(@intCast(operands.len)); } + state.addAttribute(ctx, "operandSegmentSizes", .denseElements(ctx, 
&.{@intCast(segments.len)}, .i32, segments.constSlice())); + } else if (args.tt_variadic_operands) |operands_segments| { + // stablehlo and triton seem to disagree on the expected type of operandSegmentSizes, let's fix that. + const MAX_SEGMENTS = 32; + var segments: std.BoundedArray(i32, MAX_SEGMENTS) = .{}; + + for (operands_segments) |operands| { + state.addOperands(operands); + segments.appendAssumeCapacity(@intCast(operands.len)); + } + state.addAttribute(ctx, "operandSegmentSizes", .dense(ctx, .i32, segments.constSlice())); + } if (args.result_type_inference) |enable| { state.resultTypeInference(enable); @@ -1076,6 +1177,26 @@ pub const Operation = struct { pub fn removeAttributeByName(self: Self, name_: [:0]const u8) bool { return c.mlirOperationRemoveAttributeByName(self.inner(), stringRef(name_)); } + + pub fn hash(op: Operation, hasher: *std.hash.XxHash64) void { + const NoError = error{}; + const write = struct { + fn write(hasher_: *std.hash.XxHash64, bytes: []const u8) NoError!usize { + hasher_.update(bytes); + return bytes.len; + } + }.write; + const HashWriter = std.io.Writer(*std.hash.XxHash64, NoError, write); + const writer: HashWriter = .{ .context = hasher }; + + // Hash the canonicalized IR, without debug information that can change across builds. + // Note: before we were using op.writeBytecode(writer), + // but it crashes on some inputs, notably for unused variables. + // So we use the text representation of the mlir. + // See https://github.com/zml/zml/issues/97. + // Writes can't fail because we are writing to a hasher. 
+ op.print(writer, .{ .debug_info = false }); + } }; pub const OpPrintingFlags = struct { @@ -1255,6 +1376,41 @@ pub const Type = struct { c.mlirTypeParseGet(ctx.inner(), stringRef(str)), ) orelse Error.InvalidMlir; } + + pub fn index(ctx: Context) Type { + return IndexType.init(ctx).as(Type); + } + + pub fn int(ctx: Context, int_type: IntegerTypes) Type { + return switch (int_type) { + .unknown => @panic("Unknown integer type"), + inline else => |t| IntegerType(t).init(ctx).as(Type), + }; + } + + pub fn float(ctx: Context, float_type: FloatTypes) Type { + return switch (float_type) { + inline else => |t| FloatType(t).init(ctx).as(Type), + }; + } + + pub fn complex(ctx: Context, complex_type: ComplexTypes) Type { + return switch (complex_type) { + inline else => |t| ComplexType(t).init(ctx).as(Type), + }; + } + + pub fn tuple(ctx: Context, types: []const Type) Type { + return (TupleType.init(ctx, types) catch unreachable).as(Type); + } + + pub fn function(ctx: Context, args: []const Type, results: []const Type) Type { + return (FunctionType.init(ctx, args, results) catch unreachable).as(Type); + } + + pub fn tensor(dimensions: []const i64, elem_type: Type) Type { + return RankedTensorType.init(dimensions, elem_type).as(Type); + } }; pub const IndexType = struct { @@ -1680,7 +1836,9 @@ pub const Block = struct { } pub fn argument(self: Block, index: usize) Value { - return Value.wrap(c.mlirBlockGetArgument(self.inner(), @intCast(index))); + const arg = c.mlirBlockGetArgument(self.inner(), @intCast(index)); + stdx.debug.assert(!Value.Methods.is_null_fn.?(arg), "Block doesn't have argument {}, only got {}", .{ index, self.numArguments() }); + return Value.wrap(arg); } pub fn numArguments(self: Block) usize { @@ -1708,4 +1866,29 @@ pub const Block = struct { c.mlirBlockAppendOwnedOperation(self.inner(), op.inner()); } } + + pub const RecursiveOpts = enum { open, hermetic }; + + pub fn appendValueRecursive(self: Block, value: Value, opt: RecursiveOpts) void { + switch 
(value.kind()) { + .op_result => |parent_op| self.appendOperationRecursive(parent_op, opt), + .block_argument => |arg| { + // Hermetic blocks are not allowed to use arguments from other blocks. + stdx.debug.assert(opt == .open or self.eql(arg.block()), "Can't add {} from {?x} block to {?x} block", .{ arg, arg.block()._inner.ptr, self._inner.ptr }); + }, + .null => @panic("InvalidMlir"), + } + } + + pub fn appendOperationRecursive(self: Block, op: Operation, opt: RecursiveOpts) void { + if (op.block()) |prev_block| { + // Hermetic blocks are not allowed to reference values from other blocks. + stdx.debug.assert(opt == .open or self.equals(prev_block), "Can't add {} from {?x} block to {?x} block", .{ op, prev_block._inner.ptr, self._inner.ptr }); + return; + } + for (0..op.numOperands()) |i| { + self.appendValueRecursive(op.operand(i), opt); + } + self.appendOperation(op); + } }; diff --git a/zml/module.zig b/zml/module.zig index 0b284c7..beb7a01 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -6,15 +6,14 @@ const runfiles = @import("runfiles"); const stdx = @import("stdx"); const xla_pb = @import("//xla:xla_proto"); -const meta = @import("meta.zig"); -const mlir = @import("mlir.zig"); -const ops = @import("ops.zig"); -const pjrt = @import("pjrtx.zig"); - const BaseExe = @import("exe.zig").BaseExe; const Buffer = @import("buffer.zig").Buffer; const Bufferized = @import("tensor.zig").Bufferized; +const meta = @import("meta.zig"); +const mlir = @import("mlir.zig"); const Location = mlir.Location; +const ops = @import("ops.zig"); +const pjrt = @import("pjrtx.zig"); const Platform = @import("platform.zig").Platform; const Shape = @import("shape.zig").Shape; const ShapeOf = @import("tensor.zig").ShapeOf; @@ -28,46 +27,6 @@ test { std.testing.refAllDecls(@This()); } -pub const BlockKind = enum { open, hermetic }; - -const Block = union(BlockKind) { - open: mlir.Block, - hermetic: mlir.Block, - - pub fn block(self: Block) mlir.Block { - return switch (self) { - inline 
.open, .hermetic => |t| t, - }; - } - - fn appendTensorRecursive(self: Block, x: *const Tensor) void { - self.appendValueRecursive(x.value()); - } - - fn appendValueRecursive(self: Block, value: mlir.Value) void { - switch (value.kind()) { - .op_result => |parent_op| self.appendOperationRecursive(parent_op), - .block_argument => |arg| { - // Hermetic blocks are not allowed to use arguments from other blocks. - stdx.debug.assert(self == .open or self.block().eql(arg.block()), "Can't add {} from {?x} block to {?x} block", .{ arg, arg.block()._inner.ptr, self.block()._inner.ptr }); - }, - .null => @panic("InvalidMlir"), - } - } - - fn appendOperationRecursive(self: Block, op: mlir.Operation) void { - if (op.block()) |prev_block| { - // Hermetic blocks are not allowed to reference values from other blocks. - std.debug.assert(self == .open or prev_block.equals(self.block())); - return; - } - for (0..op.numOperands()) |i| { - self.appendValueRecursive(op.operand(i)); - } - self.block().appendOperation(op); - } -}; - pub const MlirFn = struct { name: []const u8, num_args: u32, @@ -94,7 +53,7 @@ pub const CompilationContext = struct { _module: mlir.Module, - _blocks: std.BoundedArray(Block, 64) = .{}, + _blocks: std.BoundedArray(TaggedBlock, 64) = .{}, _fn_cache: FnCache = .{}, _block_args: TensorToBlockArg = .{}, @@ -104,6 +63,7 @@ pub const CompilationContext = struct { _previous: ?*CompilationContext = null, threadlocal var _current: ?*CompilationContext = null; + const TaggedBlock = struct { mlir.Block, mlir.Block.RecursiveOpts }; const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation }); const AttributeList = std.BoundedArray(mlir.NamedAttribute, 3); @@ -123,7 +83,7 @@ pub const CompilationContext = struct { const loc = mlir_ctx.location(@src()).named(mlir_ctx, "main"); const module = mlir.Module.init(loc); - module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, "zml").as(mlir.Attribute)); + 
module.op().setAttributeByName("sym_name", .string(mlir_ctx, "zml")); var canonicalizer = try mlir.PassManager.init(mlir_ctx); { @@ -280,26 +240,22 @@ pub const CompilationContext = struct { ); } - fn currentBlock(self: *const CompilationContext) ?Block { + fn currentBlock(self: *const CompilationContext) ?TaggedBlock { return if (self._blocks.len > 0) self._blocks.get(self._blocks.len - 1) else null; } - pub fn openBlock(self: *CompilationContext, kind: BlockKind, args: []const mlir.Type, locs: []const mlir.Location) !Block { - const mlir_block = try mlir.Block.init(args, locs); - const block: Block = switch (kind) { - .open => .{ .open = mlir_block }, - .hermetic => .{ .hermetic = mlir_block }, - }; + pub fn openBlock(self: *CompilationContext, kind: mlir.Block.RecursiveOpts, args: []const mlir.Type, locs: []const mlir.Location) !TaggedBlock { + const block: TaggedBlock = .{ try mlir.Block.init(args, locs), kind }; self.pushBlock(block); return block; } - pub fn closeBlock(self: *CompilationContext, block: Block) void { + pub fn closeBlock(self: *CompilationContext, block: TaggedBlock) void { const popped = self._blocks.pop(); - std.debug.assert(block.block().eql(popped.?.block())); + std.debug.assert(block[0].eql(popped.?[0])); } - fn pushBlock(self: *CompilationContext, block: Block) void { + fn pushBlock(self: *CompilationContext, block: TaggedBlock) void { self._blocks.appendAssumeCapacity(block); } @@ -311,7 +267,7 @@ pub const CompilationContext = struct { /// But their shapes/tags can be safely propagated further. 
pub fn makeBlock( self: *CompilationContext, - kind: BlockKind, + kind: mlir.Block.RecursiveOpts, comptime S: ops.BlockSignature, func: *const S.Fn, blkctx: S.BlkCtx, @@ -326,7 +282,7 @@ pub const CompilationContext = struct { // Before creating a new block, assign all received values to previous block, // otherwise they will be assign to this block if (self.currentBlock()) |prev_block| { - meta.visit(Block.appendTensorRecursive, prev_block, &blkctx); + meta.visit(_appendTensorRecursive, prev_block, &blkctx); } const block = self.openBlock(kind, &input_types, &locations) catch unreachable; @@ -337,15 +293,20 @@ pub const CompilationContext = struct { // So we create a copy of the arguments, and replace values // by the block arguments. var blk_args = args; - std.debug.assert(assignBlockArguments(&blk_args, block.block(), 0) == N); + std.debug.assert(assignBlockArguments(&blk_args, block[0], 0) == N); const block_res = @call(.auto, func, S.blkArgs(blkctx, blk_args)); var block_res_values: [S.nOut]mlir.Value = undefined; self.extractValues(&block_res, &block_res_values); const block_ret = dialect.stablehlo.returns_(self.mlirCtx(), &block_res_values, loc); - block.appendOperationRecursive(block_ret); + block[0].appendOperationRecursive(block_ret, block[1]); - return .{ block.block(), block_res }; + return .{ block[0], block_res }; + } + + fn _appendTensorRecursive(tagged_block: TaggedBlock, x: *const Tensor) void { + const block, const tag = tagged_block; + block.appendValueRecursive(x.value(), tag); } pub const EmitMlirOpts = struct { @@ -411,7 +372,7 @@ pub const CompilationContext = struct { defer self.closeBlock(fn_body); try self._block_args.ensureUnusedCapacity(self.allocator(), @intCast(tensor_count)); - const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), 0); + const assigned_args_count = self.mapBlockArguments(args, fn_body[0], 0); std.debug.assert(assigned_args_count == tensor_count); fn_res.* = forward: { @@ -424,7 +385,7 @@ pub const 
CompilationContext = struct { self.extractValuesAndTypes(fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations); const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc); - fn_body.appendOperationRecursive(fn_ret); + fn_body[0].appendOperationRecursive(fn_ret, fn_body[1]); } const arg_attrs = try arena.alloc(AttributeList, tensor_count); @@ -446,7 +407,7 @@ pub const CompilationContext = struct { .arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs), .results = fn_res_types, .res_attrs = try finalizeAttributeList(arena, mlir_ctx, res_attrs), - .block = fn_body.block(), + .block = fn_body[0], .location = loc, }); @@ -475,6 +436,7 @@ pub const CompilationContext = struct { /// Given a list of donations mapping output buffers to input buffers, /// generate donation attribute for each `n_args` input argument. fn addDonationsAttributes(self: CompilationContext, attributes: []AttributeList, donations: []const Tensor._Donation) void { + const ctx = self.mlirCtx(); var n_donations: usize = 0; for (donations, 0..) |donation, index| { switch (donation) { @@ -489,12 +451,7 @@ pub const CompilationContext = struct { // When the time come, do a more fancy lookup here to check if an argument // is donated twice. stdx.debug.assert(attributes[a].len == 0, "Donation error ! Argument {} has been donated twice ! 
To {} and to {}", .{ a, index, attributes[a].buffer[0] }); - attributes[a].appendAssumeCapacity( - mlir.NamedAttribute.init( - mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"), - mlir.IntegerAttribute(.i32).init(self.mlirCtx(), @intCast(index)).as(mlir.Attribute), - ), - ); + attributes[a].appendAssumeCapacity(.named(ctx, "tf.aliasing_output", .int(ctx, .i32, @intCast(index)))); // log.debug("attribute: {}", .{attributes[a].constSlice()}); }, } @@ -552,46 +509,29 @@ pub const CompilationContext = struct { } } - pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.StringAttribute { - const mlir_ctx = self.mlirCtx(); - - const num_partitions = self._platform.sharding().num_partitions; - var sharding_str: std.BoundedArray(u8, 128) = .{}; - - writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable; - return mlir.StringAttribute.init(mlir_ctx, sharding_str.constSlice()); - } - fn addShardingAttributes(self: CompilationContext, arg_attrs: []AttributeList, res_attrs: []AttributeList, input_shapes: []const Shape, output_shapes: []const Shape) void { - const mlir_ctx = self.mlirCtx(); + const ctx = self.mlirCtx(); if (!self._platform.compilation_options.sharding_enabled) return; - - const mhlo_default_layout = mlir.NamedAttribute.init( - mlir.Identifier.get(mlir_ctx, "mhlo.layout_mode"), - mlir.StringAttribute.init(mlir_ctx, "default").asAttr(), - ); + const default_layout = mlir.NamedAttribute.named(ctx, "mhlo.layout_mode", .string(ctx, "default")); for (arg_attrs, input_shapes) |*attr, shape| { - attr.appendAssumeCapacity(mhlo_default_layout); - - const sharding_attr = self.getShardingAttr(shape); - attr.appendAssumeCapacity(mlir.NamedAttribute.init( - mlir.Identifier.get(mlir_ctx, "mhlo.sharding"), - sharding_attr.asAttr(), - )); + attr.appendAssumeCapacity(default_layout); + attr.appendAssumeCapacity(.named(ctx, "mhlo.sharding", self.getShardingAttr(shape))); } for (res_attrs, output_shapes) |*attr, shape| { - 
attr.appendAssumeCapacity(mhlo_default_layout); - - const sharding_attr = self.getShardingAttr(shape); - - attr.appendAssumeCapacity(mlir.NamedAttribute.init( - mlir.Identifier.get(mlir_ctx, "mhlo.sharding"), - sharding_attr.asAttr(), - )); + attr.appendAssumeCapacity(default_layout); + attr.appendAssumeCapacity(.named(ctx, "mhlo.sharding", self.getShardingAttr(shape))); } } + pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute { + const ctx = self.mlirCtx(); + const num_partitions = self._platform.sharding().num_partitions; + var sharding_str: std.BoundedArray(u8, 128) = .{}; + writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable; + return mlir.Attribute.string(ctx, sharding_str.constSlice()); + } + fn writeShardingRepresentation(shape: Shape, num_partitions: u8, writer: anytype) @TypeOf(writer).Error!void { const n_sharded: u8 = @popCount(@as(u8, @bitCast(shape._sharding_info))); if (n_sharded == 0 or num_partitions == 1) { @@ -819,20 +759,11 @@ pub const CompilationContext = struct { fn computeModuleHash(platform: Platform, module: mlir.Module) u64 { var hasher = std.hash.XxHash64.init(0); - var hasher_writer = xxHash64Writer(&hasher); - const writer = hasher_writer.writer(); + module.hash(&hasher); - // Hash the canonicalized IR, without debug information that can change across builds. - module.op().print(writer, .{ .debug_info = false }); - // Note: before we where using module.op().writeBytecode(writer), - // but it crashes on some inputs, notably for unused variables. - // So we use the text representation of the mlir. - // See https://github.com/zml/zml/issues/97. - // Writes can't fail because we are writing to a hasher. 
- writer.writeAll(platform.pjrt_client.getPlatformName(platform.pjrt_api)) catch unreachable; + hasher.update(platform.pjrt_client.getPlatformName(platform.pjrt_api)); const api_version = platform.pjrt_api.version(); - writer.writeInt(i64, api_version.major, .little) catch unreachable; - writer.writeInt(i64, api_version.minor, .little) catch unreachable; + hasher.update(std.mem.sliceAsBytes(&[_]i64{ api_version.major, api_version.minor })); return hasher.final(); } @@ -1002,26 +933,6 @@ fn assignBlockArguments(v: anytype, block: mlir.Block, start: usize) usize { return context.index; } -pub const XxHash64Writer = struct { - hasher: *std.hash.XxHash64, - - pub const Error = error{}; - pub const Writer = std.io.Writer(*XxHash64Writer, Error, write); - - pub fn writer(self: *XxHash64Writer) Writer { - return .{ .context = self }; - } - - pub fn write(self: *XxHash64Writer, bytes: []const u8) Error!usize { - self.hasher.update(bytes); - return bytes.len; - } -}; - -pub fn xxHash64Writer(hasher: *std.hash.XxHash64) XxHash64Writer { - return .{ .hasher = hasher }; -} - pub const FnCache = std.AutoHashMapUnmanaged(FnKey, MlirFn); pub const FnKey = struct { fn_ptr: *const anyopaque, input_hash: u64 }; @@ -1165,12 +1076,11 @@ pub fn hashShape(hasher: *std.hash.Wyhash, shape: Shape) void { } } -const HashStrategy = std.hash.Strategy; const tensorAwareHash = hash; // alias for when "hash" is ambiguous /// Provides generic hashing for any eligible type. /// Strategy is provided to determine if pointers should be followed or not. 
-pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: HashStrategy) void { +pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: std.hash.Strategy) void { const Key = @TypeOf(key); if (Key == Tensor) return hashShape(hasher, key.shape()); if (Key == Shape) return hashShape(hasher, key); @@ -1287,7 +1197,7 @@ pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: HashStrategy } } -fn hashArray(hasher: anytype, key: anytype, comptime strat: HashStrategy) void { +fn hashArray(hasher: anytype, key: anytype, comptime strat: std.hash.Strategy) void { for (key) |element| { hash(hasher, element, strat); } diff --git a/zml/nn/cuda.zig b/zml/nn/cuda.zig index 81ff600..4d5f4f0 100644 --- a/zml/nn/cuda.zig +++ b/zml/nn/cuda.zig @@ -131,7 +131,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { }, &.{ mlir.ext.mlirType(mlir_ctx, q.shape()), - mlir.RankedTensorType.init(&.{0}, mlir.IntegerType(.u8).init(mlir_ctx).as(mlir.Type)).as(mlir.Type), + .tensor(&.{0}, .int(mlir_ctx, .u8)), }, loc, ); diff --git a/zml/ops.zig b/zml/ops.zig index c049985..ad85403 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -147,9 +147,7 @@ pub fn reduce( .variadic_operands = &.{ &input_values, &init_values }, .result_type_inference = true, .blocks = &.{body_block}, - .attributes = &.{ - .{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), axes).as(mlir.Attribute) }, - }, + .attributes = &.{.{ "dimensions", .dense(ctx.mlirCtx(), .i64, axes) }}, // We can't verify right away, cause the weights captured by the reduce haven't been added yet. 
.verify = false, .location = loc, @@ -236,21 +234,16 @@ pub fn reduceWindow( ctx.extractValues(&inits, &init_values); const loc = ctx.mlirCtx().location(@src()); - - const pad_shape = mlir.RankedTensorType.init( - &.{ @intCast(opts.padding.len), 2 }, - mlir.ext.Type.fromDType(ctx.mlirCtx(), .i64), - ).as(mlir.Type); const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.reduce_window", .{ .variadic_operands = &.{ input_values[0..], init_values[0..] }, .result_type_inference = true, .blocks = &.{body_block}, .attributes = &.{ - .{ "window_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dimensions).as(mlir.Attribute) }, - .{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_strides).as(mlir.Attribute) }, - .{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.base_dilations).as(mlir.Attribute) }, - .{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dilations).as(mlir.Attribute) }, - .{ "padding", mlir.DenseElementsAttribute(.i64).init(pad_shape, opts.padding).as(mlir.Attribute) }, + .{ "window_dimensions", .dense(ctx.mlirCtx(), .i64, opts.window_dimensions) }, + .{ "window_strides", .dense(ctx.mlirCtx(), .i64, opts.window_strides) }, + .{ "base_dilations", .dense(ctx.mlirCtx(), .i64, opts.base_dilations) }, + .{ "window_dilations", .dense(ctx.mlirCtx(), .i64, opts.window_dilations) }, + .{ "padding", .denseElements(ctx.mlirCtx(), &.{ @intCast(opts.padding.len), 2 }, .i64, opts.padding) }, }, .location = loc, }); @@ -841,13 +834,13 @@ pub fn triton(inputs: anytype, outputs: anytype, opts: TritonOps) [outputs.len]T } const attrs = mlir.DictionaryAttribute.init(ctx.mlirCtx(), &.{ - mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "name"), mlir.StringAttribute.init(ctx.mlirCtx(), opts.name).as(mlir.Attribute)), - mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "ir"), mlir.StringAttribute.init(ctx.mlirCtx(), 
opts.ir).as(mlir.Attribute)), - mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "grid_x"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.grid[0])).as(mlir.Attribute)), - mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "grid_y"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.grid[1])).as(mlir.Attribute)), - mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "grid_z"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.grid[2])).as(mlir.Attribute)), - mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "num_stages"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.num_stages)).as(mlir.Attribute)), - mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "num_warps"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.num_warps)).as(mlir.Attribute)), + .named(ctx.mlirCtx(), "name", .string(ctx.mlirCtx(), opts.name)), + .named(ctx.mlirCtx(), "ir", .string(ctx.mlirCtx(), opts.ir)), + .named(ctx.mlirCtx(), "grid_x", .int(ctx.mlirCtx(), .i32, opts.grid[0])), + .named(ctx.mlirCtx(), "grid_y", .int(ctx.mlirCtx(), .i32, opts.grid[1])), + .named(ctx.mlirCtx(), "grid_z", .int(ctx.mlirCtx(), .i32, opts.grid[2])), + .named(ctx.mlirCtx(), "num_stages", .int(ctx.mlirCtx(), .i32, opts.num_stages)), + .named(ctx.mlirCtx(), "num_warps", .int(ctx.mlirCtx(), .i32, opts.num_warps)), }); const MINOR_TO_MAJOR = blk: { diff --git a/zml/shape.zig b/zml/shape.zig index bd76309..8da8519 100644 --- a/zml/shape.zig +++ b/zml/shape.zig @@ -1,10 +1,11 @@ -const builtin = @import("builtin"); const std = @import("std"); +const testing = std.testing; +const builtin = @import("builtin"); + const stdx = @import("stdx"); -const testing = std.testing; - const DataType = @import("dtype.zig").DataType; + const EnumLiteral = @TypeOf(.enum_literal); const log = std.log.scoped(.shape); @@ -131,6 +132,10 @@ pub const Shape = struct { return res; } + pub fn scalar(dt: DataType) Shape { + return .{ 
._dtype = dt }; + } + /// Creates a Shape with dims set to `.{0, 1, 2, ..., rank-1}`. pub fn range(rank_: usize, dt: DataType) Shape { var res: Shape = .{ ._dtype = dt }; diff --git a/zml/tensor.zig b/zml/tensor.zig index 83e6c28..49381e9 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -1,28 +1,29 @@ -const builtin = @import("builtin"); const std = @import("std"); +const assert = std.debug.assert; +const testing = std.testing; +const builtin = @import("builtin"); + const stdx = @import("stdx"); -const meta = @import("meta.zig"); -const mlir = @import("mlir.zig"); -const ops = @import("ops.zig"); -const module = @import("module.zig"); - -const Location = mlir.Location; -const CompilationContext = module.CompilationContext; -const Shape = @import("shape.zig").Shape; const Buffer = @import("buffer.zig").Buffer; -const HostBuffer = @import("hostbuffer.zig").HostBuffer; const Data = @import("dtype.zig").Data; const DataType = @import("dtype.zig").DataType; +const HostBuffer = @import("hostbuffer.zig").HostBuffer; +const meta = @import("meta.zig"); +const mlir = @import("mlir.zig"); +const Location = mlir.Location; +const module = @import("module.zig"); +const CompilationContext = module.CompilationContext; +const ops = @import("ops.zig"); const Platform = @import("platform.zig").Platform; +const Shape = @import("shape.zig").Shape; + const EnumLiteral = @TypeOf(.enum_literal); const dialect = struct { const stablehlo = @import("mlir/dialects").stablehlo; }; -const assert = std.debug.assert; -const testing = std.testing; const scoped_log = std.log.scoped(.@"zml/tensor"); test { @@ -173,15 +174,13 @@ pub const Tensor = struct { var res = self; res._shape = self._shape.withSharding(axes_); - const sharding = ctx.getShardingAttr(res._shape); - const op = dialect.stablehlo.custom_call( ctx.mlirCtx(), &.{self.value()}, .{ .call_target_name = "Sharding", .has_side_effect = false, - .addional_attributes = &.{.{ "mhlo.sharding", sharding.asAttr() }}, + .addional_attributes = 
&.{.{ "mhlo.sharding", ctx.getShardingAttr(res._shape) }}, .api_version = .original, }, &.{self.value().getType()}, @@ -1014,11 +1013,11 @@ pub const Tensor = struct { if (to == self.dtype()) { return self; } - - const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), to)).as(mlir.Type); const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) }); - const op = dialect.stablehlo.convert(self.getContext().mlirCtx(), self.value(), res_type, loc); + const mlir_ctx = self.getContext().mlirCtx(); + const res_type = mlir.ext.mlirType(mlir_ctx, self.shape().withDtype(to)); + const op = dialect.stablehlo.convert(mlir_ctx, self.value(), res_type, loc); return _result(self._shape.withDtype(to), op.result(0)); } @@ -1455,7 +1454,7 @@ pub const Tensor = struct { /// .{ .a, .b, .c }.mergeTranspose(.{ .a, .c }, .ac) -> .{ .b, .ac } pub fn mergeTranspose(self: Tensor, axes_: anytype, merged: EnumLiteral) Tensor { const cont = self.contiguous(axes_); - return cont.reshape(cont._shape.mergeAxis(axes_, merged)); + return cont.reshape(cont._shape.mergeAxis(merged, axes_)); } /// Transposes the input Tensor, such has the given axes end up in contiguous position. 
@@ -3107,7 +3106,7 @@ pub const Tensor = struct { var start_indices = [_]mlir.Value{constant(.{}, slice_.start.dtype().zero()).value()} ** MAX_RANK; start_indices[a] = slice_.start.value(); - const op = dialect.stablehlo.dynamicSlice( + const op = dialect.stablehlo.dynamic_slice( self.getContext().mlirCtx(), self.value(), new_shape.dims(), @@ -3168,7 +3167,7 @@ pub const Tensor = struct { res_shape._dims.set(a, len); } } - const op = dialect.stablehlo.dynamicSlice(self.getContext().mlirCtx(), self.value(), res_shape.dims(), offset_values[0..self.rank()], loc); + const op = dialect.stablehlo.dynamic_slice(self.getContext().mlirCtx(), self.value(), res_shape.dims(), offset_values[0..self.rank()], loc); return _result(res_shape, op.result(0)); } diff --git a/zml/zml.zig b/zml/zml.zig index 088ad68..18282b9 100644 --- a/zml/zml.zig +++ b/zml/zml.zig @@ -25,6 +25,7 @@ pub const nn = @import("nn.zig"); pub const module = @import("module.zig"); pub const meta = @import("meta.zig"); pub const platform = @import("platform.zig"); +pub const mlir = @import("mlir.zig"); pub const pjrt = @import("pjrtx.zig"); pub const testing = @import("testing.zig"); pub const torch = @import("torch.zig");