diff --git a/MODULE.bazel b/MODULE.bazel index e433c54..105eeca 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -19,7 +19,7 @@ bazel_dep(name = "rules_proto", version = "7.1.0") bazel_dep(name = "rules_python", version = "1.5.3") bazel_dep(name = "rules_rust", version = "0.63.0") bazel_dep(name = "rules_uv", version = "0.87.0") -bazel_dep(name = "rules_zig", version = "20250714.0-b14a4f1") +bazel_dep(name = "rules_zig", version = "20250821.0-be53625") bazel_dep(name = "toolchains_llvm_bootstrapped", version = "0.2.4") bazel_dep(name = "with_cfg.bzl", version = "0.11.0") diff --git a/async/BUILD.bazel b/async/BUILD.bazel index 65e44d0..3f3c7dc 100644 --- a/async/BUILD.bazel +++ b/async/BUILD.bazel @@ -1,6 +1,4 @@ load("@rules_zig//zig:defs.bzl", "zig_library", "zig_test") -load("@zml//bazel:zig_srcs.bzl", "zig_srcs") - zig_library( name = "async", @@ -23,11 +21,6 @@ zig_library( zig_test( name = "test", - deps = [":async"], testonly = False, -) - -zig_srcs( - name = "sources", - zig_bin = ":test", + deps = [":async"], ) diff --git a/bazel/zig_srcs.bzl b/bazel/zig_srcs.bzl index 2aa7956..ee4adb3 100644 --- a/bazel/zig_srcs.bzl +++ b/bazel/zig_srcs.bzl @@ -1,7 +1,7 @@ load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar") -load("@rules_zig//zig:defs.bzl", "zig_binary", "BINARY_KIND") +load("@rules_zig//zig:defs.bzl", "zig_static_library") -def zig_srcs(name, zig_bin="", zig_lib=""): +def zig_srcs(name, zig_bin = "", zig_lib = ""): """For a given zig_library, recursively extract all zig sources into a tarball. This also includes the files translated from C headers. 
@@ -10,21 +10,22 @@ def zig_srcs(name, zig_bin="", zig_lib=""): """ if zig_bin == "": zig_bin = "{}_bin".format(name) - zig_binary( + zig_static_library( name = zig_bin, - kind = BINARY_KIND.bc, - tags = ["manual", "@rules_zig//zig/lib:libc"], + tags = ["manual"], deps = [zig_lib], ) native.filegroup( name = "{}_files".format(name), srcs = [zig_bin], + tags = ["manual"], output_group = "srcs", ) mtree_spec( name = "{}_mtree".format(name), srcs = [":{}_files".format(name)], + tags = ["manual"], ) tar( name = name, diff --git a/mlir/BUILD.bazel b/mlir/BUILD.bazel index ae32943..87982a7 100644 --- a/mlir/BUILD.bazel +++ b/mlir/BUILD.bazel @@ -1,11 +1,9 @@ load("@rules_cc//cc:defs.bzl", "cc_library") -load("@rules_zig//zig:defs.bzl", "zig_library") -load("//bazel:zig.bzl", "zig_cc_test") -load("//bazel:zig_srcs.bzl", "zig_srcs") +load("@rules_zig//zig:defs.bzl", "zig_library", "zig_test") cc_library( name = "c", - hdrs = ["c.h"], + hdrs = ["mlir.h"], visibility = ["//mlir:__subpackages__"], deps = [ "@llvm-project//mlir:CAPIArith", @@ -18,7 +16,6 @@ cc_library( zig_library( name = "mlir", - copts = ["-lc"], main = "mlir.zig", visibility = ["//visibility:public"], deps = [ @@ -27,12 +24,7 @@ zig_library( ], ) -zig_cc_test( +zig_test( name = "test", deps = [":mlir"], ) - -zig_srcs( - name = "sources", - zig_bin = ":test_test_lib", -) diff --git a/mlir/dialects/BUILD.bazel b/mlir/dialects/BUILD.bazel index c18f12d..845775e 100644 --- a/mlir/dialects/BUILD.bazel +++ b/mlir/dialects/BUILD.bazel @@ -1,6 +1,4 @@ -load("@rules_zig//zig:defs.bzl", "zig_library") -load("//bazel:zig.bzl", "zig_cc_test") -load("//bazel:zig_srcs.bzl", "zig_srcs") +load("@rules_zig//zig:defs.bzl", "zig_library", "zig_test") zig_library( name = "dialects", @@ -15,35 +13,14 @@ zig_library( main = "dialects.zig", visibility = ["//visibility:public"], deps = [ - ":stablehlo", "//mlir", + "//mlir/dialects/stablehlo", ], ) -zig_cc_test( +zig_test( name = "test", - deps = [":dialects"], -) - -zig_srcs( - 
name = "sources", - zig_bin = ":test_test_lib", -) - -zig_library( - name = "stablehlo", - import_name = "mlir/dialects/stablehlo", - main = "stablehlo.zig", - visibility = ["//mlir/dialects:__subpackages__"], deps = [ - "//mlir", - "//mlir:c", - "//stdx", - "@stablehlo//:stablehlo_dialect_capi", + ":dialects", ], ) - -zig_cc_test( - name = "stablehlo_test", - deps = [":stablehlo"], -) diff --git a/mlir/dialects/stablehlo/BUILD.bazel b/mlir/dialects/stablehlo/BUILD.bazel new file mode 100644 index 0000000..45e6923 --- /dev/null +++ b/mlir/dialects/stablehlo/BUILD.bazel @@ -0,0 +1,19 @@ +load("@rules_zig//zig:defs.bzl", "zig_library", "zig_test") + +zig_library( + name = "stablehlo", + import_name = "mlir/dialects/stablehlo", + main = "stablehlo.zig", + visibility = ["//mlir/dialects:__subpackages__"], + deps = [ + "//mlir", + "//mlir:c", + "//stdx", + "@stablehlo//:stablehlo_dialect_capi", + ], +) + +zig_test( + name = "test", + deps = [":stablehlo"], +) diff --git a/mlir/dialects/stablehlo/stablehlo.zig b/mlir/dialects/stablehlo/stablehlo.zig new file mode 100644 index 0000000..8010985 --- /dev/null +++ b/mlir/dialects/stablehlo/stablehlo.zig @@ -0,0 +1,1366 @@ +const std = @import("std"); + +const c = @import("c"); +const mlir = @import("mlir"); +const stdx = @import("stdx"); + +pub const abs = functors.unary_fn("stablehlo.abs").call; +pub const cosine = functors.unary_fn("stablehlo.cosine").call; +pub const sine = functors.unary_fn("stablehlo.sine").call; +pub const exponential = functors.unary_fn("stablehlo.exponential").call; +pub const exponential_minus_one = functors.unary_fn("stablehlo.exponential_minus_one").call; +pub const floor = functors.unary_fn("stablehlo.floor").call; +pub const log = functors.unary_fn("stablehlo.log").call; +pub const log_plus_one = functors.unary_fn("stablehlo.log_plus_one").call; +pub const not = functors.unary_fn("stablehlo.not").call; +pub const negate = functors.unary_fn("stablehlo.negate").call; +pub const sqrt = 
functors.unary_fn("stablehlo.sqrt").call; +pub const tanh = functors.unary_fn("stablehlo.tanh").call; +pub const cbrt = functors.unary_fn("stablehlo.cbrt").call; +pub const ceil = functors.unary_fn("stablehlo.ceil").call; +pub const rsqrt = functors.unary_fn("stablehlo.rsqrt").call; +pub const count_leading_zeros = functors.unary_fn("stablehlo.count_leading_zeros").call; +pub const is_finite = functors.unary_fn("stablehlo.is_finite").call; +pub const logistic = functors.unary_fn("stablehlo.logistic").call; +pub const popcnt = functors.unary_fn("stablehlo.popcnt").call; +pub const sign = functors.unary_fn("stablehlo.sign").call; +pub const real = functors.unary_fn("stablehlo.real").call; +pub const imag = functors.unary_fn("stablehlo.imag").call; + +pub const add = functors.binary_fn("stablehlo.add").call; +pub const multiply = functors.binary_fn("stablehlo.multiply").call; +pub const divide = functors.binary_fn("stablehlo.divide").call; +pub const subtract = functors.binary_fn("stablehlo.subtract").call; +pub const or_ = functors.binary_fn("stablehlo.or").call; +pub const xor = functors.binary_fn("stablehlo.xor").call; +pub const and_ = functors.binary_fn("stablehlo.and").call; +pub const atan2 = functors.binary_fn("stablehlo.atan2").call; +pub const maximum = functors.binary_fn("stablehlo.maximum").call; +pub const minimum = functors.binary_fn("stablehlo.minimum").call; +pub const power = functors.binary_fn("stablehlo.power").call; +pub const remainder = functors.binary_fn("stablehlo.remainder").call; +pub const shift_left = functors.binary_fn("stablehlo.shift_left").call; +pub const shift_right_arithmetic = functors.binary_fn("stablehlo.shift_right_arithmetic").call; +pub const shift_right_logical = functors.binary_fn("stablehlo.shift_right_logical").call; +pub const complex = functors.binary_fn("stablehlo.complex").call; + +const functors = struct { + fn unary_fn(comptime op_name: [:0]const u8) type { + return struct { + pub fn call(ctx: mlir.Context, value: 
mlir.Value, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, op_name, .{ + .operands = &.{value}, + .result_type_inference = true, + .location = location, + }); + } + }; + } + + pub fn binary_fn(comptime op_name: [:0]const u8) type { + return struct { + pub fn call(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, op_name, .{ + .operands = &.{ lhs, rhs }, + .result_type_inference = true, + .location = location, + }); + } + }; + } +}; + +pub fn return_(ctx: mlir.Context, value: mlir.Value, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.return", .{ + .variadic_operands = &.{&.{value}}, + .verify = false, + .location = location, + }); +} + +pub fn returns_(ctx: mlir.Context, values: []const mlir.Value, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.return", .{ + .variadic_operands = &.{values}, + .verify = false, + .location = location, + }); +} + +pub fn bitcast_convert(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.bitcast_convert", .{ + .operands = &.{value}, + .results = &.{result_type}, + .location = location, + }); +} + +pub fn cholesky(ctx: mlir.Context, value: mlir.Value, lower: bool, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.cholesky", .{ + .operands = &.{value}, + .result_type_inference = true, + .attributes = &.{ + .{ "lower", .i1FromBool(ctx, lower) }, + }, + .location = location, + }); +} + +pub fn clamp(ctx: mlir.Context, min: mlir.Value, value: mlir.Value, max: mlir.Value, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.clamp", .{ + .operands = &.{ min, value, max }, + .result_type_inference = true, + .location = location, + }); +} + +pub const DotPrecision = union(enum) { + fast, + high, 
+ highest, + algorithm: DotAlgorithm, + + pub fn precisionAttr(self: DotPrecision, ctx: mlir.Context) mlir.Attribute { + const precision = PrecisionAttribute.init(ctx, switch (self) { + .fast => .DEFAULT, + .high => .HIGH, + .highest => .HIGHEST, + // When we specify the dot algorithm, we should not specify the precision. + .algorithm => .DEFAULT, + }); + return precision.asAttr(); + } + + pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.RankedTensorType) ?mlir.Attribute { + return switch (self) { + .algorithm => |algo| algo.asAttr(ctx, operand_type), + else => null, + }; + } +}; + +pub const DotAlgorithm = struct { + accumulation: mlir.FloatTypes, + // Note stablehlo distinguish between left/right component_count + // but all the supported algorithm have the same component_count on both side. + component_count: u8 = 1, + num_primitive_operations: u8 = 1, + allow_imprecise_accumulation: bool = false, + + // bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32. + // not sure where this is available. 
+ pub const bf16_6x: DotAlgorithm = .{ + // NOTE(review): the operand type is not a DotAlgorithm field; it is taken from the tensor type passed to asAttr. + .accumulation = .f32, + .component_count = 1, + .num_primitive_operations = 6, + .allow_imprecise_accumulation = false, + }; + + pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, tensor_type: mlir.RankedTensorType) mlir.Attribute { + const elem_type = tensor_type.getElementType(); + + return mlir.Attribute.wrap(c.stablehloDotAlgorithmGet( + ctx._inner, + elem_type._inner, + elem_type._inner, + self.accumulation.asType(ctx)._inner, + self.component_count, + self.component_count, + self.num_primitive_operations, + self.allow_imprecise_accumulation, + )); + } +}; + +/// General matrix multiplication "a la Einstein sum" +/// Note: stablehlo doesn't do type inference for dot_general +pub fn dot_general( + ctx: mlir.Context, + lhs: mlir.Value, + rhs: mlir.Value, + result_type: mlir.Type, + location: mlir.Location, + opts: struct { + lhs_batching_dimensions: []const i64, + rhs_batching_dimensions: []const i64, + lhs_contracting_dimensions: []const i64, + rhs_contracting_dimensions: []const i64, + precision: DotPrecision, + }, +) mlir.Operation { + const precisions: [2]mlir.Attribute = @splat(opts.precision.precisionAttr(ctx)); + const attributes = [3]mlir.AttrTuple{ + .{ + "dot_dimension_numbers", DotDimensionNumbersAttribute.init(ctx, .{ + .lhs_batching_dimensions = opts.lhs_batching_dimensions, + .rhs_batching_dimensions = opts.rhs_batching_dimensions, + .lhs_contracting_dimensions = opts.lhs_contracting_dimensions, + .rhs_contracting_dimensions = opts.rhs_contracting_dimensions, + }).asAttr(), + }, + .{ "precision_config", .array(ctx, &precisions) }, + // keep algorithm as the last attribute so we can omit it when it's not set. + .{ "algorithm", opts.precision.algorithmAttr(ctx, lhs.getType().as(mlir.RankedTensorType).?
orelse undefined }, + }; + const n_attributes = if (opts.precision == .algorithm) attributes.len else attributes.len - 1; + return mlir.Operation.make(ctx, "stablehlo.dot_general", .{ + .operands = &.{ lhs, rhs }, + .results = &.{result_type}, + .attributes = attributes[0..n_attributes], + .location = location, + }); +} + +pub fn constant( + ctx: mlir.Context, + dims: []const i64, + elem_type: mlir.DenseElementsAttributeTypes, + raw_bytes: []const u8, + location: mlir.Location, +) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.constant", .{ + .operands = &.{}, + .results = &.{.tensor(dims, elem_type.mlirType(ctx))}, + .attributes = &.{.{ "value", .denseElementsFromBytes(ctx, dims, elem_type, raw_bytes) }}, + .location = location, + }); +} + +pub fn convert(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.convert", .{ + .operands = &.{value}, + .results = &.{result_type}, + .location = location, + }); +} + +pub fn broadcast_in_dim(ctx: mlir.Context, operand: mlir.Value, dims: []const i64, result_type: mlir.Type, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.broadcast_in_dim", .{ + .operands = &.{operand}, + .results = &.{result_type}, + .attributes = &.{ + .{ "broadcast_dimensions", .dense(ctx, .i64, dims) }, + }, + .location = location, + }); +} + +pub fn transpose(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, location: mlir.Location, opts: struct { permutation: []const i64 }) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.transpose", .{ + .operands = &.{value}, + .results = &.{result_type}, + .attributes = &.{ + .{ "permutation", .dense(ctx, .i64, opts.permutation) }, + }, + .location = location, + }); +} + +pub fn slice(ctx: mlir.Context, operand: mlir.Value, start_indices: []const i64, limit_indices: []const i64, strides: []const i64, result_type: mlir.Type, location: 
mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.slice", .{ + .operands = &.{operand}, + .results = &.{result_type}, + .attributes = &.{ + .{ "start_indices", .dense(ctx, .i64, start_indices) }, + .{ "limit_indices", .dense(ctx, .i64, limit_indices) }, + .{ "strides", .dense(ctx, .i64, strides) }, + }, + .location = location, + }); +} + +pub fn concatenate(ctx: mlir.Context, inputs: []const mlir.Value, dimension: i64, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.concatenate", .{ + .operands = inputs, + .result_type_inference = true, + .attributes = &.{ + .{ "dimension", .int(ctx, .i64, dimension) }, + }, + .location = location, + }); +} + +pub fn reshape(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.reshape", .{ + .operands = &.{value}, + .results = &.{result_type}, + .location = location, + }); +} + +pub fn select(ctx: mlir.Context, condition: mlir.Value, then: mlir.Value, else_: mlir.Value, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.select", .{ + .operands = &.{ condition, then, else_ }, + .results = &.{then.getType()}, + .location = location, + }); +} + +pub fn gather( + ctx: mlir.Context, + value: mlir.Value, + indices: mlir.Value, + slice_sizes: []const i64, + location: mlir.Location, + args: struct { + offset_dims: []const i64, + collapsed_slice_dims: []const i64, + operand_batching_dims: []const i64, + start_indices_batching_dims: []const i64, + start_index_map: []const i64, + index_vector_dim: i64, + indices_are_sorted: bool = false, + }, +) mlir.Operation { + return mlir.Operation.make( + ctx, + "stablehlo.gather", + .{ + .operands = &.{ value, indices }, + .result_type_inference = true, + .attributes = &.{ + .{ "dimension_numbers", GatherDimensionNumbersAttribute.init( + ctx, + args.offset_dims, + args.collapsed_slice_dims, + 
args.operand_batching_dims, + args.start_indices_batching_dims, + args.start_index_map, + args.index_vector_dim, + ).asAttr() }, + .{ "slice_sizes", .dense(ctx, .i64, slice_sizes) }, + .{ "indices_are_sorted", .boolean(ctx, args.indices_are_sorted) }, + }, + .location = location, + }, + ); +} + +fn elementTypeOrSelf(typ: mlir.Type) mlir.Type { + return if (typ.as(mlir.ShapedType)) |shaped| { + return shaped.elementType(); + } else typ; +} + +pub const ScatterArgs = struct { + update_window_dims: []const i64, + inserted_window_dims: []const i64, + input_batching_dims: []const i64, + scatter_indices_batching_dims: []const i64, + scatter_dims_to_operand_dims: []const i64, + index_vector_dim: i64, + indices_are_sorted: bool = false, + unique_indices: bool = false, + + pub fn getScatterDimensionNumbers(self: ScatterArgs, ctx: mlir.Context) mlir.Attribute { + return .{ ._inner = c.stablehloScatterDimensionNumbersGet( + ctx._inner, + @intCast(self.update_window_dims.len), + self.update_window_dims.ptr, + @intCast(self.inserted_window_dims.len), + self.inserted_window_dims.ptr, + @intCast(self.input_batching_dims.len), + self.input_batching_dims.ptr, + @intCast(self.scatter_indices_batching_dims.len), + self.scatter_indices_batching_dims.ptr, + @intCast(self.scatter_dims_to_operand_dims.len), + self.scatter_dims_to_operand_dims.ptr, + self.index_vector_dim, + ) }; + } +}; + +pub fn scatter( + ctx: mlir.Context, + inputs: []const mlir.Value, + scatter_indices: []const mlir.Value, + updates: []const mlir.Value, + update_block: mlir.Block, + args: ScatterArgs, + location: mlir.Location, +) mlir.Operation { + return mlir.Operation.make( + ctx, + "stablehlo.scatter", + .{ + .variadic_operands = &.{ inputs, scatter_indices, updates }, + .blocks = &.{update_block}, + .attributes = &.{ + .{ "scatter_dimension_numbers", args.getScatterDimensionNumbers(ctx) }, + .{ "indices_are_sorted", .boolean(ctx, args.indices_are_sorted) }, + .{ "unique_indices", .boolean(ctx, 
args.unique_indices) }, + }, + .result_type_inference = true, + .location = location, + }, + ); +} + +pub fn iota(ctx: mlir.Context, dimension: i64, result_type: mlir.Type, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.iota", .{ + .operands = &.{}, + .results = &.{result_type}, + .attributes = &.{ + .{ "iota_dimension", .int(ctx, .i64, dimension) }, + }, + .location = location, + }); +} + +pub fn reverse(ctx: mlir.Context, operand: mlir.Value, dimensions: []const i64, location: mlir.Location) mlir.Operation { + const result_type = operand.getType(); + return mlir.Operation.make(ctx, "stablehlo.reverse", .{ + .operands = &.{operand}, + .results = &.{result_type}, + .attributes = &.{ + .{ "dimensions", .dense(ctx, .i64, dimensions) }, + }, + .location = location, + }); +} + +pub fn compare(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, comparison_direction: ComparisonDirection, compare_type: CompareType, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.compare", .{ + .operands = &.{ lhs, rhs }, + .result_type_inference = true, + .attributes = &.{ + .{ "comparison_direction", comparison_direction.asAttr() }, + .{ "compare_type", compare_type.asAttr() }, + }, + .location = location, + }); +} + +pub fn reduce( + ctx: mlir.Context, + inputs: []const mlir.Value, + init_values: []const mlir.Value, + dimensions: []const i64, + blkctx: anytype, + blkfn: fn (anytype, mlir.Context, []const mlir.Value, []const mlir.Value) mlir.Operation, + location: mlir.Location, +) mlir.Operation { + const MaxBlockArguments = 32; + + const block_n_args = inputs.len + init_values.len; + const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0..block_n_args]; + var reduce_elem_types: [MaxBlockArguments]mlir.Type = undefined; + for (inputs, 0..) 
|input, i| { + const arg_type: mlir.Type = .tensor(&.{}, elementTypeOrSelf(input.getType())); + reduce_elem_types[i] = arg_type; + reduce_elem_types[inputs.len + i] = arg_type; + } + var block = mlir.Block.open(reduce_elem_types[0..block_n_args], locations) catch unreachable; + { + defer block.close(); + + var block_inputs: [MaxBlockArguments / 2]mlir.Value = undefined; + var block_accs: [MaxBlockArguments / 2]mlir.Value = undefined; + for (0..inputs.len) |i| { + block_inputs[i] = block.argument(i); + block_accs[i] = block.argument(inputs.len + i); + } + _ = blkfn(blkctx, ctx, block_inputs[0..inputs.len], block_accs[0..init_values.len]); + } + + return mlir.Operation.make(ctx, "stablehlo.reduce", .{ + .variadic_operands = &.{ inputs, init_values }, + .result_type_inference = true, + .block = block, + .attributes = &.{ + .{ "dimensions", .dense(ctx, .i64, dimensions) }, + }, + .location = location, + }); +} + +pub fn sort( + ctx: mlir.Context, + inputs: []const mlir.Value, + dimension: i64, + is_stable: bool, + blkctx: anytype, + compfn: fn (anytype, mlir.Context, []const mlir.Value) mlir.Operation, + location: mlir.Location, +) mlir.Operation { + const MaxBlockArguments = 32; + + const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0 .. inputs.len * 2]; + var sort_elem_types: [MaxBlockArguments]mlir.Type = undefined; + for (inputs, 0..) |input, i| { + const arg_type: mlir.Type = .tensor(&.{}, elementTypeOrSelf(input.getType())); + sort_elem_types[i * 2] = arg_type; + sort_elem_types[i * 2 + 1] = arg_type; + } + var block = mlir.Block.init(sort_elem_types[0 .. inputs.len * 2], locations) catch unreachable; + + var block_inputs: [MaxBlockArguments]mlir.Value = undefined; + for (0..inputs.len * 2) |i| { + block_inputs[i] = block.argument(i); + } + _ = compfn(blkctx, ctx, block_inputs[0 .. 
inputs.len * 2]); + + return mlir.Operation.make(ctx, "stablehlo.sort", .{ + .variadic_operands = &.{inputs}, + .result_type_inference = true, + .block = block, + .attributes = &.{ + .{ "dimension", .int(ctx, .i64, dimension) }, + .{ "is_stable", .boolean(ctx, is_stable) }, + }, + .location = location, + }); +} + +pub fn dynamic_slice(ctx: mlir.Context, operand: mlir.Value, new_dims: []const i64, start_indices: []const mlir.Value, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.dynamic_slice", .{ + .variadic_operands = &.{ &.{operand}, start_indices }, + .result_type_inference = true, + .attributes = &.{ + .{ "slice_sizes", .dense(ctx, .i64, new_dims) }, + }, + .location = location, + }); +} + +pub fn round_nearest_afz(ctx: mlir.Context, value: mlir.Value, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.round_nearest_afz", .{ + .operands = &.{value}, + .result_type_inference = true, + .location = location, + }); +} + +pub fn round_nearest_even(ctx: mlir.Context, value: mlir.Value, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.round_nearest_even", .{ + .operands = &.{value}, + .result_type_inference = true, + .location = location, + }); +} + +pub const PadOpts = struct { + low: []const i64, + high: []const i64, + interior: []const i64, +}; + +pub fn pad(ctx: mlir.Context, value: mlir.Value, padding_value: mlir.Value, opts: PadOpts, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.pad", .{ + .operands = &.{ value, padding_value }, + .result_type_inference = true, + .attributes = &.{ + .{ "edge_padding_low", .dense(ctx, .i64, opts.low) }, + .{ "edge_padding_high", .dense(ctx, .i64, opts.high) }, + .{ "interior_padding", .dense(ctx, .i64, opts.interior) }, + }, + .location = location, + }); +} + +pub const TriangularSolveOpts = struct { + left_side: bool, + lower: bool, + unit_diagonal: bool, + transpose_a: 
Transpose.Type, +}; + +pub fn triangular_solve(ctx: mlir.Context, value: mlir.Value, other: mlir.Value, location: mlir.Location, opts: TriangularSolveOpts) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.triangular_solve", .{ + .operands = &.{ value, other }, + .result_type_inference = true, + .attributes = &.{ + .{ "left_side", .i1FromBool(ctx, opts.left_side) }, + .{ "lower", .i1FromBool(ctx, opts.lower) }, + .{ "unit_diagonal", .i1FromBool(ctx, opts.unit_diagonal) }, + .{ "transpose_a", Transpose.init(ctx, opts.transpose_a).asAttr() }, + }, + .location = location, + }); +} + +pub const FftOpts = struct { + kind: FftType.Type, + length: []const i64, +}; + +pub fn fft(ctx: mlir.Context, value: mlir.Value, location: mlir.Location, opts: FftOpts) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.fft", .{ + .operands = &.{value}, + .result_type_inference = true, + .attributes = &.{ + .{ "fft_type", FftType.init(ctx, opts.kind).asAttr() }, + .{ "fft_length", .dense(ctx, .i64, opts.length) }, + }, + .location = location, + }); +} + +pub fn rng(ctx: mlir.Context, a: mlir.Value, b: mlir.Value, shape: mlir.Value, rng_distribution: RngDistribution.Type, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.rng", .{ + .operands = &.{ a, b, shape }, + .result_type_inference = true, + .attributes = &.{ + .{ "rng_distribution", RngDistribution.init(ctx, rng_distribution).asAttr() }, + }, + .location = location, + }); +} + +pub fn rng_bit_generator(ctx: mlir.Context, rng_algorithm: RngAlgorithm.Type, initial_state: mlir.Value, res_state_type: mlir.Type, res_type: mlir.Type, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.rng_bit_generator", .{ + .operands = &.{initial_state}, + .results = &.{ res_state_type, res_type }, + .attributes = &.{ + .{ "rng_algorithm", RngAlgorithm.init(ctx, rng_algorithm).asAttr() }, + }, + .location = location, + }); +} + +pub fn reduce_precision(ctx: 
mlir.Context, value: mlir.Value, exponent_bits: i32, mantissa_bits: i32, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.reduce_precision", .{ + .operands = &.{value}, + .result_type_inference = true, + .attributes = &.{ + .{ "exponent_bits", .int(ctx, .i32, exponent_bits) }, + .{ "mantissa_bits", .int(ctx, .i32, mantissa_bits) }, + }, + .location = location, + }); +} + +pub fn dynamic_update_slice(ctx: mlir.Context, operand: mlir.Value, update: mlir.Value, start_indices: []const mlir.Value, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.dynamic_update_slice", .{ + .variadic_operands = &.{ &.{operand}, &.{update}, start_indices }, + .result_type_inference = true, + .location = location, + }); +} + +pub fn tuple(ctx: mlir.Context, values: []const mlir.Value, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.tuple", .{ + .operands = values, + .result_type_inference = true, + .location = location, + }); +} + +pub fn get_tuple_element(ctx: mlir.Context, tuple_value: mlir.Value, index: i64, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.get_tuple_element", .{ + .operands = &.{tuple_value}, + .result_type_inference = true, + .attributes = &.{ + .{ "index", .int(ctx, .i32, index) }, + }, + .location = location, + }); +} + +pub const ConvolutionOpts = struct { + window_strides: []const i64, + pad_value: []const i64, + pad_shape: []const i64 = &.{}, + lhs_dilation: []const i64, + rhs_dilation: []const i64, + window_reversal: []const bool, + input_batch_dimension: i64, + input_feature_dimension: i64, + input_spatial_dimensions: []const i64, + kernel_input_feature_dimension: i64, + kernel_output_feature_dimension: i64, + kernel_spatial_dimensions: []const i64, + output_batch_dimension: i64, + output_feature_dimension: i64, + output_spatial_dimensions: []const i64, + feature_group_count: i64, + batch_group_count: i64, + 
precision_config: []const PrecisionAttribute.Precision = &.{}, +}; + +pub fn convolution( + ctx: mlir.Context, + lhs: mlir.Value, + rhs: mlir.Value, + opts: ConvolutionOpts, + res_type: mlir.Type, + location: mlir.Location, +) mlir.Operation { + var max_precisions: [2]mlir.Attribute = undefined; + for (opts.precision_config, 0..) |p, i| { + max_precisions[i] = PrecisionAttribute.init(ctx, p).asAttr(); + } + var window_reversal: [3]i32 = undefined; + for (opts.window_reversal, 0..) |w, i| { + window_reversal[i] = @intCast(@intFromBool(w)); + } + return mlir.Operation.make(ctx, "stablehlo.convolution", .{ + .operands = &.{ lhs, rhs }, + .results = &.{res_type}, + .attributes = &.{ + .{ "window_strides", .dense(ctx, .i64, opts.window_strides) }, + .{ "padding", .denseElements(ctx, opts.pad_shape, .i64, opts.pad_value) }, + .{ "lhs_dilation", .dense(ctx, .i64, opts.lhs_dilation) }, + .{ "rhs_dilation", .dense(ctx, .i64, opts.rhs_dilation) }, + .{ "window_reversal", .dense(ctx, .bool, window_reversal[0..opts.window_reversal.len]) }, + .{ + "dimension_numbers", ConvDimensionNumbersAttribute.init(ctx, .{ + .input_batch_dimension = opts.input_batch_dimension, + .input_feature_dimension = opts.input_feature_dimension, + .input_spatial_dimensions = opts.input_spatial_dimensions, + .kernel_input_feature_dimension = opts.kernel_input_feature_dimension, + .kernel_output_feature_dimension = opts.kernel_output_feature_dimension, + .kernel_spatial_dimensions = opts.kernel_spatial_dimensions, + .output_batch_dimension = opts.output_batch_dimension, + .output_feature_dimension = opts.output_feature_dimension, + .output_spatial_dimensions = opts.output_spatial_dimensions, + }).asAttr(), + }, + .{ "feature_group_count", .int(ctx, .i64, opts.feature_group_count) }, + .{ "batch_group_count", .int(ctx, .i64, opts.batch_group_count) }, + .{ "precision_config", .array(ctx, &max_precisions) }, + }, + .location = location, + }); +} + +pub const CustomCallOpts = struct { + pub const 
ApiVersion = enum(i32) { + original = 1, + status_returning = 2, + status_returning_unified = 3, + typed_ffi = 4, + }; + + call_target_name: [:0]const u8, + has_side_effect: bool, + backend_config: ?mlir.Attribute, + operand_layouts: ?[]const []const usize = null, + result_layouts: ?[]const []const usize = null, + output_operand_aliases: []const i64 = &.{}, + additional_attributes: []const mlir.AttrTuple = &.{}, + api_version: ApiVersion, +}; + +pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCallOpts, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation { + const MAX_OPERANDS = 64; + const MAX_RESULTS = 16; + + const backend_config = opts.backend_config orelse mlir.Attribute.string(ctx, ""); + if (@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi)) { + stdx.debug.assert( + backend_config.isA(mlir.StringAttribute), + "API version < 4 requires a string as backend_config, got {}", + .{backend_config}, + ); + } else { + stdx.debug.assert( + backend_config.isA(mlir.DictionaryAttribute), + "API version >= 4 requires a dictionary as backend_config, got {}", + .{backend_config}, + ); + } + + var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{}; + attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{ + .{ "api_version", .int(ctx, .i32, @intFromEnum(opts.api_version)) }, + .{ "call_target_name", .string(ctx, opts.call_target_name) }, + .{ "has_side_effect", .boolean(ctx, opts.has_side_effect) }, + .{ "backend_config", backend_config }, + }); + + { + var output_operand_aliases: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; + for (opts.output_operand_aliases) |alias| { + output_operand_aliases.appendAssumeCapacity( + OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).asAttr(), + ); + } + attrs.appendAssumeCapacity(.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) }); + } + + const MINOR_TO_MAJOR = blk: { + const MAX_RANK = 8; + var ret: [MAX_RANK]usize = 
undefined; + for (0..MAX_RANK) |i| { + ret[i] = @intCast(MAX_RANK - i - 1); + } + break :blk ret; + }; + + if (opts.operand_layouts) |layouts| { + var operand_layouts: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; + for (layouts) |ol| { + operand_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol)); + } + attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) }); + } else { + const operand_layouts = blk: { + var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; + for (inputs) |input| { + const ranked_type = input.getType().as(mlir.RankedTensorType).?; + const ol = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_type.getRank() ..]; + ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol)); + } + break :blk ret; + }; + attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) }); + } + + if (opts.result_layouts) |layouts| { + var result_layouts: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; + for (layouts) |rl| { + result_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl)); + } + attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) }); + } else { + const result_layouts = blk: { + var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; + for (res_types) |t| { + const ranked_t = t.as(mlir.RankedTensorType).?; + const rl = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_t.getRank() ..]; + ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl)); + } + break :blk ret; + }; + attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) }); + } + + attrs.appendSlice(opts.additional_attributes) catch @panic("Too many additional_attributes"); + + return mlir.Operation.make(ctx, "stablehlo.custom_call", .{ + .operands = inputs, + .results = res_types, + .attributes = attrs.constSlice(), + .location = location, + }); 
+} + +pub const DotDimensionNumbersAttribute = struct { + _inner: c.MlirAttribute, + + pub const is_a_fn = c.stablehloAttributeIsADotDimensionNumbers; + const Self = DotDimensionNumbersAttribute; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); + + pub fn init(ctx: mlir.Context, args: struct { + lhs_batching_dimensions: []const i64, + rhs_batching_dimensions: []const i64, + lhs_contracting_dimensions: []const i64, + rhs_contracting_dimensions: []const i64, + }) Self { + return .{ + ._inner = c.stablehloDotDimensionNumbersGet( + ctx._inner, + @intCast(args.lhs_batching_dimensions.len), + args.lhs_batching_dimensions.ptr, + @intCast(args.rhs_batching_dimensions.len), + args.rhs_batching_dimensions.ptr, + @intCast(args.lhs_contracting_dimensions.len), + args.lhs_contracting_dimensions.ptr, + @intCast(args.rhs_contracting_dimensions.len), + args.rhs_contracting_dimensions.ptr, + ), + }; + } + + pub fn getLhsBatchingDimensionsSize(self: Self) usize { + return @intCast(c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsSize(self._inner)); + } + + pub fn getLhsBatchingDimensionsElem(self: Self, pos: usize) i64 { + return c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsElem(self._inner, @intCast(pos)); + } + + pub fn getRhsBatchingDimensionsSize(self: Self) usize { + return @intCast(c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsSize(self._inner)); + } + + pub fn getRhsBatchingDimensionsElem(self: Self, pos: usize) i64 { + return c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsElem(self._inner, @intCast(pos)); + } + + pub fn getLhsContractingDimensionsSize(self: Self) usize { + return @intCast(c.stablehloDotDimensionNumbersGetLhsContractingDimensionsSize(self._inner)); + } + + pub fn getLhsContractingDimensionsElem(self: Self, pos: usize) i64 { + return c.stablehloDotDimensionNumbersGetLhsContractingDimensionsElem(self._inner, @intCast(pos)); + } + + pub fn getRhsContractingDimensionsSize(self: Self) 
usize { + return @intCast(c.stablehloDotDimensionNumbersGetRhsContractingDimensionsSize(self._inner)); + } + + pub fn getRhsContractingDimensionsElem(self: Self, pos: usize) i64 { + return c.stablehloDotDimensionNumbersGetRhsContractingDimensionsElem(self._inner, @intCast(pos)); + } +}; + +pub const GatherDimensionNumbersAttribute = struct { + _inner: c.MlirAttribute, + + pub const is_a_fn = c.stablehloAttributeIsAGatherDimensionNumbers; + const Self = GatherDimensionNumbersAttribute; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); + + pub fn init( + ctx: mlir.Context, + offset_dims: []const i64, + collapsed_slice_dims: []const i64, + operand_batching_dims: []const i64, + start_indices_batching_dims: []const i64, + start_index_map: []const i64, + index_vector_dim: i64, + ) Self { + return .{ + ._inner = c.stablehloGatherDimensionNumbersGet( + ctx._inner, + @intCast(offset_dims.len), + offset_dims.ptr, + @intCast(collapsed_slice_dims.len), + collapsed_slice_dims.ptr, + @intCast(operand_batching_dims.len), + operand_batching_dims.ptr, + @intCast(start_indices_batching_dims.len), + start_indices_batching_dims.ptr, + @intCast(start_index_map.len), + start_index_map.ptr, + index_vector_dim, + ), + }; + } + + pub fn getOffsetDimsSize(self: Self) usize { + return @intCast(c.stablehloGatherDimensionNumbersGetOffsetDimsSize(self._inner)); + } + + pub fn getOffsetDimsElem(self: Self, pos: usize) i64 { + return c.stablehloGatherDimensionNumbersGetOffsetDimsElem(self._inner, @intCast(pos)); + } + + pub fn getCollapsedSliceDimsSize(self: Self) usize { + return @intCast(c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsSize(self._inner)); + } + + pub fn getCollapsedSliceDimsElem(self: Self, pos: usize) i64 { + return c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsElem(self._inner, @intCast(pos)); + } + + pub fn getStartIndexMapSize(self: Self) usize { + return 
@intCast(c.stablehloGatherDimensionNumbersGetStartIndexMapSize(self._inner)); + } + + pub fn getOperandBatchingDimsSize(self: Self) usize { + return @intCast(c.stablehloGatherDimensionNumbersGetOperandBatchingDimsSize(self._inner)); + } + + pub fn getOperandBatchingDimsElem(self: Self, pos: usize) i64 { + return c.stablehloGatherDimensionNumbersGetOperandBatchingDimsElem(self._inner, @intCast(pos)); + } + + pub fn getStartIndicesBatchingDimsSize(self: Self) usize { + return @intCast(c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsSize(self._inner)); + } + + pub fn getStartIndicesBatchingDimsElem(self: Self, pos: usize) i64 { + return c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsElem(self._inner, @intCast(pos)); + } + + pub fn getStartIndexMapElem(self: Self, pos: usize) i64 { + return c.stablehloGatherDimensionNumbersGetStartIndexMapElem(self._inner, @intCast(pos)); + } + + pub fn getIndexVectorDim(self: Self) usize { + return @intCast(c.stablehloGatherDimensionNumbersGetIndexVectorDim(self._inner)); + } +}; + +pub const ConvDimensionNumbersAttribute = struct { + _inner: c.MlirAttribute, + + pub const is_a_fn = c.stablehloAttributeIsAConvDimensionNumbers; + const Self = ConvDimensionNumbersAttribute; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); + + pub fn init(ctx: mlir.Context, args: struct { + input_batch_dimension: i64, + input_feature_dimension: i64, + input_spatial_dimensions: []const i64, + kernel_input_feature_dimension: i64, + kernel_output_feature_dimension: i64, + kernel_spatial_dimensions: []const i64, + output_batch_dimension: i64, + output_feature_dimension: i64, + output_spatial_dimensions: []const i64, + }) Self { + return .{ + ._inner = c.stablehloConvDimensionNumbersGet( + ctx._inner, + args.input_batch_dimension, + args.input_feature_dimension, + @intCast(args.input_spatial_dimensions.len), + args.input_spatial_dimensions.ptr, + args.kernel_input_feature_dimension, + 
args.kernel_output_feature_dimension, + @intCast(args.kernel_spatial_dimensions.len), + args.kernel_spatial_dimensions.ptr, + args.output_batch_dimension, + args.output_feature_dimension, + @intCast(args.output_spatial_dimensions.len), + args.output_spatial_dimensions.ptr, + ), + }; + } + + pub fn getInputBatchDimension(self: Self) i64 { + return c.stablehloConvDimensionNumbersGetInputBatchDimension(self._inner); + } + + pub fn getInputFeatureDimension(self: Self) i64 { + return c.stablehloConvDimensionNumbersGetInputFeatureDimension(self._inner); + } + + pub fn getInputSpatialDimensionsSize(self: Self) usize { + return @intCast(c.stablehloConvDimensionNumbersGetInputSpatialDimensionsSize(self._inner)); + } + + pub fn getInputSpatialDimensionsElem(self: Self, pos: usize) i64 { + return c.stablehloConvDimensionNumbersGetInputSpatialDimensionsElem(self._inner, @intCast(pos)); + } + + pub fn getKernelInputFeatureDimension(self: Self) i64 { + return c.stablehloConvDimensionNumbersGetKernelInputFeatureDimension(self._inner); + } + + pub fn getKernelOutputFeatureDimension(self: Self) i64 { + return c.stablehloConvDimensionNumbersGetKernelOutputFeatureDimension(self._inner); + } + + pub fn getKernelSpatialDimensionsSize(self: Self) usize { + return @intCast(c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsSize(self._inner)); + } + + pub fn getKernelSpatialDimensionsElem(self: Self, pos: usize) i64 { + return c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsElem(self._inner, @intCast(pos)); + } + + pub fn getOutputBatchDimension(self: Self) i64 { + return c.stablehloConvDimensionNumbersGetOutputBatchDimension(self._inner); + } + + pub fn getOutputFeatureDimension(self: Self) i64 { + return c.stablehloConvDimensionNumbersGetOutputFeatureDimension(self._inner); + } + + pub fn getOutputSpatialDimensionsSize(self: Self) usize { + return @intCast(c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsSize(self._inner)); + } + + pub fn 
getOutputSpatialDimensionsElem(self: Self, pos: usize) i64 { + return c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsElem(self._inner, @intCast(pos)); + } +}; + +pub const OutputOperandAliasAttribute = struct { + _inner: c.MlirAttribute, + + pub const is_a_fn = c.stablehloAttributeIsAOutputOperandAlias; + pub const asAttr = mlir.Attribute.fromAny(OutputOperandAliasAttribute); + pub const eql = mlir.Attribute.eqlAny(OutputOperandAliasAttribute); + + pub fn init( + ctx: mlir.Context, + output_tuple_indices: []const i64, + operand_index: i64, + operand_tuple_indices: []const i64, + ) OutputOperandAliasAttribute { + return .{ ._inner = c.stablehloOutputOperandAliasGet( + ctx._inner, + @intCast(output_tuple_indices.len), + output_tuple_indices.ptr, + @intCast(operand_index), + @intCast(operand_tuple_indices.len), + operand_tuple_indices.ptr, + ) }; + } +}; + +pub const PrecisionAttribute = struct { + _inner: c.MlirAttribute, + + pub const is_a_fn = c.stablehloAttributeIsAPrecisionAttr; + const Self = PrecisionAttribute; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); + + pub const Precision = enum { + DEFAULT, + HIGH, + HIGHEST, + }; + + pub fn init(ctx: mlir.Context, value: Precision) Self { + return .{ ._inner = c.stablehloPrecisionAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; + } + + pub fn getValue(self: Self) Precision { + const value = mlir.fromStringRef(c.stablehloPrecisionAttrGetValue(self._inner)); + return std.meta.stringToEnum(Precision, value) orelse unreachable; + } +}; + +pub const ComparisonDirection = struct { + _inner: c.MlirAttribute, + + pub const is_a_fn = c.stablehloAttributeIsAComparisonDirectionAttr; + const Self = ComparisonDirection; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); + + pub const Direction = enum { + EQ, + NE, + GE, + GT, + LE, + LT, + }; + + pub fn init(ctx: mlir.Context, value: Direction) Self { + return .{ 
._inner = c.stablehloComparisonDirectionAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; + } + + pub fn getValue(self: Self) Direction { + const value = mlir.fromStringRef(c.stablehloComparisonDirectionAttrGetValue(self._inner)); + return std.meta.stringToEnum(Direction, value) orelse unreachable; + } +}; + +pub const CompareType = struct { + _inner: c.MlirAttribute, + + pub const is_a_fn = c.stablehloAttributeIsAComparisonTypeAttr; + const Self = CompareType; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); + + pub const Type = enum { + SIGNED, + UNSIGNED, + FLOAT, + TOTALORDER, + }; + + pub fn init(ctx: mlir.Context, value: Type) Self { + return .{ ._inner = c.stablehloComparisonTypeAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; + } + + pub fn getValue(self: Self) Type { + const value = mlir.fromStringRef(c.stablehloComparisonTypeAttrGetValue(self._inner)); + return std.meta.stringToEnum(Type, value) orelse unreachable; + } +}; + +pub const Transpose = struct { + _inner: c.MlirAttribute, + + pub const is_a_fn = c.stablehloAttributeIsATransposeAttr; + const Self = Transpose; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); + + pub const Type = enum { + NO_TRANSPOSE, + TRANSPOSE, + ADJOINT, + }; + + pub fn init(ctx: mlir.Context, value: Type) Self { + return .{ ._inner = c.stablehloTransposeAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; + } + + pub fn getValue(self: Self) Type { + const value = mlir.fromStringRef(c.stablehloTransposeAttrGetValue(self._inner)); + return std.meta.stringToEnum(Type, value) orelse unreachable; + } +}; + +pub const FftType = struct { + _inner: c.MlirAttribute, + + pub const is_a_fn = c.stablehloAttributeIsAFftTypeAttr; + const Self = FftType; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); + + pub const Type = enum { + FFT, + IFFT, + RFFT, + IRFFT, + }; + + pub fn 
init(ctx: mlir.Context, value: Type) Self { + return .{ ._inner = c.stablehloFftTypeAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; + } + + pub fn getValue(self: Self) Type { + const value = mlir.fromStringRef(c.stablehloFftTypeAttrGetValue(self._inner)); + return std.meta.stringToEnum(Type, value) orelse unreachable; + } +}; + +pub const RngDistribution = struct { + _inner: c.MlirAttribute, + + pub const is_a_fn = c.stablehloAttributeIsARngDistributionAttr; + const Self = RngDistribution; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); + + pub const Type = enum { + UNIFORM, + NORMAL, + }; + + pub fn init(ctx: mlir.Context, value: Type) Self { + return .{ ._inner = c.stablehloRngDistributionAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; + } + + pub fn getValue(self: Self) Type { + const value = mlir.fromStringRef(c.stablehloRngDistributionAttrGetValue(self._inner)); + return std.meta.stringToEnum(Type, value) orelse unreachable; + } +}; + +pub const RngAlgorithm = struct { + _inner: c.MlirAttribute, + + pub const is_a_fn = c.stablehloAttributeIsARngAlgorithmAttr; + const Self = RngAlgorithm; + pub const asAttr = mlir.Attribute.fromAny(Self); + pub const eql = mlir.Attribute.eqlAny(Self); + + pub const Type = enum { + DEFAULT, + THREE_FRY, + PHILOX, + }; + + pub fn init(ctx: mlir.Context, value: Type) Self { + return .{ ._inner = c.stablehloRngAlgorithmAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; + } + + pub fn getValue(self: Self) Type { + const value = mlir.fromStringRef(c.stablehloRngAlgorithmAttrGetValue(self._inner)); + return std.meta.stringToEnum(Type, value) orelse unreachable; + } +}; + +pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehloCompatibilityRequirement) []const u8 { + const state = struct { + var buf: [32]u8 = undefined; + + fn call(req: c.MlirStablehloCompatibilityRequirement) []u8 { + var stream = std.io.fixedBufferStream(&buf); + var 
context = .{ .writer = stream.writer() }; + const WriterContext = @TypeOf(context); + + c.stablehloVersionFromCompatibilityRequirement(req, (struct { + pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { + const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); + _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; + } + }).callback, &context); + + return buf[0..stream.pos]; + } + }; + + return state.call(requirement); +} + +pub fn stablehloGetSmallerVersion(version1: []const u8, version2: []const u8) []const u8 { + var buf: [32]u8 = undefined; + + var stream = std.io.fixedBufferStream(&buf); + var context = .{ .writer = stream.writer() }; + const WriterContext = @TypeOf(context); + + _ = c.stablehloGetSmallerVersion(mlir.stringRef(version1), mlir.stringRef(version2), (struct { + pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { + const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); + _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; + } + }).callback, &context); + + return if (std.mem.eql(u8, buf[0..stream.pos], version1)) version1 else version2; +} + +pub fn getCurrentVersion() []const u8 { + const state = struct { + var buf: [32]u8 = undefined; + var str: []const u8 = undefined; + var once = std.once(call); + + fn call() void { + var stream = std.io.fixedBufferStream(&buf); + var writer_ = stream.writer(); + const ContextWriter = @TypeOf(writer_); + + c.stablehloGetCurrentVersion((struct { + pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { + const writer: *ContextWriter = @ptrCast(@alignCast(userdata)); + _ = writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; + } + }).callback, &writer_); + + str = buf[0..stream.pos]; + } + }; + + state.once.call(); + return state.str; +} + +pub fn getMinimumVersion() []const u8 { + const state = struct { + var buf: [32]u8 = undefined; + var 
str: []const u8 = undefined; + var once = std.once(call); + + fn call() void { + var stream = std.io.fixedBufferStream(&buf); + var context = .{ .writer = stream.writer() }; + const WriterContext = @TypeOf(context); + + c.stablehloGetMinimumVersion((struct { + pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { + const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); + _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; + } + }).callback, &context); + + str = buf[0..stream.pos]; + } + }; + + state.once.call(); + return state.str; +} + +pub fn serializePortableArtifact(bytecode: []const u8, target_version: []const u8, writer: anytype) !void { + var context = .{ .writer = writer }; + const WriterContext = @TypeOf(context); + + try mlir.successOr(c.stablehloSerializePortableArtifactFromStringRef(mlir.stringRef(bytecode), mlir.stringRef(target_version), (struct { + pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { + const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); + _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; + } + }).callback, &context), error.InvalidMlirBytecodeVersion); +} diff --git a/mlir/mlir.h b/mlir/mlir.h new file mode 100644 index 0000000..4ad1523 --- /dev/null +++ b/mlir/mlir.h @@ -0,0 +1,9 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/pjrt/BUILD.bazel b/pjrt/BUILD.bazel index 470a415..173bea7 100644 --- a/pjrt/BUILD.bazel +++ b/pjrt/BUILD.bazel @@ -1,5 +1,4 @@ load("@rules_zig//zig:defs.bzl", "zig_library") -load("@zml//bazel:zig_srcs.bzl", "zig_srcs") zig_library( name = "pjrt", @@ -15,8 +14,3 @@ zig_library( "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs", ], ) - -zig_srcs( - name = "sources", - zig_lib = ":pjrt", -) diff --git a/runtimes/neuron/BUILD.bazel b/runtimes/neuron/BUILD.bazel index deab87a..0fbe667 100644 --- 
a/runtimes/neuron/BUILD.bazel +++ b/runtimes/neuron/BUILD.bazel @@ -2,7 +2,7 @@ load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_python//python:pip.bzl", "compile_pip_requirements") load("@rules_python//python/entry_points:py_console_script_binary.bzl", "py_console_script_binary") load("@rules_uv//uv:pip.bzl", "pip_compile") -load("@rules_zig//zig:defs.bzl", "BINARY_KIND", "zig_binary", "zig_library") +load("@rules_zig//zig:defs.bzl", "zig_library", "zig_shared_library") load(":neuron.bzl", "py_binary_with_script") load(":pyenv.bzl", "pyenv_zig") @@ -21,12 +21,9 @@ zig_library( # # Additionally, it provides a way to load implicit transitive dependencies # of neuronx-cc (see add_needed of the patchelf target below). -# -# TODO(cerisier): Use a zig_cc_shared_library instead. -zig_binary( +zig_shared_library( name = "libpjrt_neuron_proxy", - copts = ["-lc"], - kind = BINARY_KIND.shared_lib, + copts = ["-fno-stack-check"], main = "libpjrt_neuron.zig", visibility = ["@libpjrt_neuron//:__subpackages__"], deps = [ diff --git a/runtimes/rocm/BUILD.bazel b/runtimes/rocm/BUILD.bazel index bab048b..536ccb9 100644 --- a/runtimes/rocm/BUILD.bazel +++ b/runtimes/rocm/BUILD.bazel @@ -1,5 +1,5 @@ +load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_zig//zig:defs.bzl", "zig_library") -load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar") cc_library( name = "zmlxrocm_lib", diff --git a/stdx/BUILD.bazel b/stdx/BUILD.bazel index 171d136..76389c5 100644 --- a/stdx/BUILD.bazel +++ b/stdx/BUILD.bazel @@ -1,5 +1,4 @@ load("@rules_zig//zig:defs.bzl", "zig_library", "zig_test") -load("@zml//bazel:zig_srcs.bzl", "zig_srcs") zig_library( name = "stdx", @@ -22,11 +21,6 @@ zig_library( zig_test( name = "test", - deps = [":stdx"], testonly = False, -) - -zig_srcs( - name = "sources", - zig_bin = ":test", + deps = [":stdx"], ) diff --git a/stdx/fs.zig b/stdx/fs.zig index 3f278e3..ea8a5fc 100644 --- a/stdx/fs.zig +++ b/stdx/fs.zig @@ -1,5 +1,28 @@ const std = 
@import("std"); +extern fn dladdr(addr: *anyopaque, info: *Dl_info) c_int; + +const Dl_info = extern struct { + dli_fname: [*c]const u8, + dli_fbase: *anyopaque, + dli_sname: [*c]const u8, + dli_saddr: *anyopaque, +}; + +fn selfSharedObjectPathImpl(addr: usize) []const u8 { + var info: Dl_info = undefined; + _ = dladdr(@ptrFromInt(addr), &info); + return std.mem.span(info.dli_fname); +} + +pub fn selfSharedObjectPath() []const u8 { + return selfSharedObjectPathImpl(@returnAddress()); +} + +pub fn selfSharedObjectDirPath() []const u8 { + return std.fs.path.dirname(selfSharedObjectPathImpl(@returnAddress())).?; +} + pub const path = struct { pub fn bufJoin(buf: []u8, paths: []const []const u8) ![]u8 { var fa: std.heap.FixedBufferAllocator = .init(buf); diff --git a/third_party/modules/rules_zig/20250821.0-be53625/MODULE.bazel b/third_party/modules/rules_zig/20250821.0-be53625/MODULE.bazel new file mode 100644 index 0000000..3c64eb3 --- /dev/null +++ b/third_party/modules/rules_zig/20250821.0-be53625/MODULE.bazel @@ -0,0 +1,75 @@ +module( + name = "rules_zig", + version = "20250821.0-be53625", + compatibility_level = 1, +) + +bazel_dep(name = "aspect_bazel_lib", version = "2.8.1") +bazel_dep(name = "bazel_skylib", version = "1.7.1") +bazel_dep(name = "platforms", version = "0.0.10") + +zig = use_extension("//zig:extensions.bzl", "zig") +zig.index(file = "//zig/private:versions.json") +use_repo(zig, "zig_toolchains") + +register_toolchains("@rules_zig//zig/target:all") + +register_toolchains("@zig_toolchains//:all") + +zig_dev = use_extension( + "//zig:extensions.bzl", + "zig", + dev_dependency = True, +) +zig_dev.toolchain(zig_version = "0.13.0") +zig_dev.toolchain(zig_version = "0.12.1") +zig_dev.toolchain(zig_version = "0.12.0") +zig_dev.toolchain(zig_version = "0.11.0") + +bazel_dep(name = "rules_cc", version = "0.0.9") + +bazel_dep(name = "stardoc", version = "0.7.0", dev_dependency = True, repo_name = "io_bazel_stardoc") +bazel_dep(name = "gazelle", version = 
"0.38.0", dev_dependency = True, repo_name = "bazel_gazelle") +bazel_dep(name = "bazel_skylib_gazelle_plugin", version = "1.7.1", dev_dependency = True) +bazel_dep( + name = "buildifier_prebuilt", + version = "7.3.1", + dev_dependency = True, +) +bazel_dep(name = "rules_multirun", version = "0.9.0", dev_dependency = True) +bazel_dep(name = "rules_python", version = "0.35.0", dev_dependency = True) +bazel_dep( + name = "rules_bazel_integration_test", + version = "0.25.0", + dev_dependency = True, +) + +bazel_binaries = use_extension( + "@rules_bazel_integration_test//:extensions.bzl", + "bazel_binaries", + dev_dependency = True, +) + +# NOTE: Keep in sync with WORKSPACE. +bazel_binaries.download(version_file = "//:.bazelversion") +bazel_binaries.download(version = "7.0.0") +use_repo( + bazel_binaries, + "bazel_binaries", + "bazel_binaries_bazelisk", + "build_bazel_bazel_.bazelversion", + "build_bazel_bazel_7_0_0", +) + +# TODO[AH] Should be an implicit transitive dependency through rules_bazel_integration_test. +# However, if we do not include it explicitly, then the runfiles resolution for +# cgrindel_bazel_starlib/shlib/lib/message.sh fails in +# rules_bazel_integration_test/tools/update_deleted_packages.sh when invoked +# through the rules_multirun target //util:update. +bazel_dep(name = "cgrindel_bazel_starlib", version = "0.21.0", dev_dependency = True) + +# Hack to get around a cc_common.link limitation. 
+# See https://github.com/bazelbuild/bazel/pull/23838 +cc_common_link = use_repo_rule("//zig:extensions.bzl", "cc_common_link") + +cc_common_link(name = "build_bazel_rules_android") diff --git a/third_party/modules/rules_zig/20250821.0-be53625/source.json b/third_party/modules/rules_zig/20250821.0-be53625/source.json new file mode 100644 index 0000000..959b2db --- /dev/null +++ b/third_party/modules/rules_zig/20250821.0-be53625/source.json @@ -0,0 +1,5 @@ +{ + "strip_prefix": "rules_zig-be53625afb13e73856ee4cab38d5aad9f86f63ef", + "url": "https://github.com/zml/rules_zig/archive/be53625afb13e73856ee4cab38d5aad9f86f63ef.tar.gz", + "integrity": "sha256-IsqEl3BHW/vttyHpVBx5WzO3an/K3eZZx8RJk7nVx8s=" +} diff --git a/third_party/modules/rules_zig/metadata.json b/third_party/modules/rules_zig/metadata.json index 00e9f2f..00757b8 100644 --- a/third_party/modules/rules_zig/metadata.json +++ b/third_party/modules/rules_zig/metadata.json @@ -7,18 +7,17 @@ "name": "ZML Engineering Team" } ], - "repository": [ - "github:zml/rules_zig" - ], + "repository": ["github:zml/rules_zig"], "versions": [ - "20240904.0-010da15", - "20240909.0-37f17ff", - "20240912.0-41bfe84", - "20240913.0-1957d05", - "20250314.0-b9739c6", - "20250519.0-233b207", - "20250613.0-567662a", - "20250714.0-b14a4f1" + "20240904.0-010da15", + "20240909.0-37f17ff", + "20240912.0-41bfe84", + "20240913.0-1957d05", + "20250314.0-b9739c6", + "20250519.0-233b207", + "20250613.0-567662a", + "20250714.0-b14a4f1", + "20250821.0-be53625" ], "yanked_versions": {} } diff --git a/zml/BUILD.bazel b/zml/BUILD.bazel index 556ae03..7901965 100644 --- a/zml/BUILD.bazel +++ b/zml/BUILD.bazel @@ -1,8 +1,6 @@ load("@com_google_protobuf//bazel:upb_proto_library.bzl", "upb_c_proto_library") load("@rules_cc//cc:defs.bzl", "cc_library") -load("@rules_zig//zig:defs.bzl", "zig_library") -load("//bazel:zig.bzl", "zig_cc_test") -load("//bazel:zig_srcs.bzl", "zig_srcs") +load("@rules_zig//zig:defs.bzl", "zig_library", "zig_test") 
upb_c_proto_library( name = "xla_data_upb", @@ -52,7 +50,7 @@ zig_library( # All ZML Tests -zig_cc_test( +zig_test( name = "test", data = [ "aio/torch/simple.pt", @@ -67,8 +65,3 @@ filegroup( srcs = ["test_runner.zig"], visibility = ["//visibility:public"], ) - -zig_srcs( - name = "sources", - zig_bin = ":test_test_lib", -) diff --git a/zml/tokenizer/BUILD.bazel b/zml/tokenizer/BUILD.bazel index d1bb0de..4489d77 100644 --- a/zml/tokenizer/BUILD.bazel +++ b/zml/tokenizer/BUILD.bazel @@ -1,5 +1,4 @@ -load("@rules_zig//zig:defs.bzl", "zig_library") -load("@zml//bazel:zig.bzl", "zig_cc_binary") +load("@rules_zig//zig:defs.bzl", "zig_binary", "zig_library") zig_library( name = "tokenizer", @@ -15,7 +14,7 @@ zig_library( ], ) -zig_cc_binary( +zig_binary( name = "main", main = "main.zig", visibility = ["//visibility:public"], diff --git a/zml/tokenizer/sentencepiece/BUILD.bazel b/zml/tokenizer/sentencepiece/BUILD.bazel index 457d811..354b289 100644 --- a/zml/tokenizer/sentencepiece/BUILD.bazel +++ b/zml/tokenizer/sentencepiece/BUILD.bazel @@ -1,6 +1,5 @@ load("@rules_zig//zig:defs.bzl", "zig_library") load("//bazel:swig.bzl", "swig_cc_library") -load("//bazel:zig_srcs.bzl", "zig_srcs") swig_cc_library( name = "sentencepiece_swig", @@ -22,8 +21,3 @@ zig_library( "//ffi:zig", ], ) - -zig_srcs( - name = "sources", - zig_lib = ":sentencepiece", -)