From c488b634fc90f47aa8cd28ca37ba7da92ae1342d Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Mon, 7 Jul 2025 16:48:07 +0000 Subject: [PATCH] runtimes/rocm: implement zmlxrocm in Zig Also, sandbox `amdgpu.ids` and restore safetensors json parsing. --- bazel/cc_import.bzl | 2 +- bazel/patchelf.bzl | 8 +- bazel/zig.bzl | 63 -- mlir/c.h | 9 - mlir/dialects/stablehlo.zig | 1366 ------------------------ runtimes/cuda/libpjrt_cuda.BUILD.bazel | 2 +- runtimes/neuron/neuron.bzl | 4 +- runtimes/rocm/BUILD.bazel | 19 +- runtimes/rocm/libpjrt_rocm.BUILD.bazel | 57 +- runtimes/rocm/packages.lock.json | 51 +- runtimes/rocm/packages.yaml | 4 +- runtimes/rocm/rocm.bzl | 26 +- runtimes/rocm/rocm.zig | 38 +- runtimes/rocm/zmlxrocm.c | 52 - runtimes/rocm/zmlxrocm.zig | 50 + zml/aio/json.zig | 107 ++ zml/aio/safetensors.zig | 19 +- 17 files changed, 280 insertions(+), 1597 deletions(-) delete mode 100644 bazel/zig.bzl delete mode 100644 mlir/c.h delete mode 100644 mlir/dialects/stablehlo.zig delete mode 100644 runtimes/rocm/zmlxrocm.c create mode 100644 runtimes/rocm/zmlxrocm.zig create mode 100644 zml/aio/json.zig diff --git a/bazel/cc_import.bzl b/bazel/cc_import.bzl index d65d90d..7c0e8de 100644 --- a/bazel/cc_import.bzl +++ b/bazel/cc_import.bzl @@ -54,7 +54,7 @@ def cc_import( patched_name = "{}.patchelf".format(name) patchelf( name = patched_name, - shared_library = shared_library, + src = shared_library, soname = soname, add_needed = add_needed, remove_needed = remove_needed, diff --git a/bazel/patchelf.bzl b/bazel/patchelf.bzl index 265c397..ac79033 100644 --- a/bazel/patchelf.bzl +++ b/bazel/patchelf.bzl @@ -1,5 +1,5 @@ def _patchelf_impl(ctx): - output_name = ctx.file.shared_library.basename + output_name = ctx.file.src.basename if ctx.attr.soname: output_name = ctx.attr.soname output = ctx.actions.declare_file("{}/{}".format(ctx.attr.name, output_name)) @@ -43,9 +43,9 @@ def _patchelf_impl(ctx): ctx.actions.write(renamed_syms, "") ctx.actions.run_shell( - inputs = 
[ctx.file.shared_library, renamed_syms], + inputs = [ctx.file.src, renamed_syms], outputs = [output], - arguments = [ctx.executable._patchelf.path, ctx.file.shared_library.path, output.path], + arguments = [ctx.executable._patchelf.path, ctx.file.src.path, output.path], command = "\n".join(commands), tools = [ctx.executable._patchelf], ) @@ -59,7 +59,7 @@ def _patchelf_impl(ctx): patchelf = rule( implementation = _patchelf_impl, attrs = { - "shared_library": attr.label(allow_single_file = True, mandatory = True), + "src": attr.label(allow_single_file = True, mandatory = True), "soname": attr.string(), "add_needed": attr.string_list(), "remove_needed": attr.string_list(), diff --git a/bazel/zig.bzl b/bazel/zig.bzl deleted file mode 100644 index 26773a5..0000000 --- a/bazel/zig.bzl +++ /dev/null @@ -1,63 +0,0 @@ -load("@rules_cc//cc:cc_binary.bzl", "cc_binary") -load("@rules_cc//cc:cc_test.bzl", "cc_test") -load("@rules_zig//zig:defs.bzl", "BINARY_KIND", "zig_binary") - -def zig_cc_binary( - name, - copts = [], - args = None, - env = None, - data = [], - deps = [], - tags = [], - visibility = None, - **kwargs): - zig_binary( - name = "{}_lib".format(name), - kind = BINARY_KIND.static_lib, - copts = copts + ["-lc"], - deps = deps, - visibility = visibility, - **kwargs - ) - cc_binary( - name = name, - args = args, - env = env, - data = data, - deps = [":{}_lib".format(name)], - tags = tags, - visibility = visibility, - ) - -def zig_cc_test( - name, - copts = [], - env = None, - data = [], - deps = [], - test_runner = None, - tags = [], - visibility = None, - **kwargs): - zig_binary( - name = "{}_test_lib".format(name), - kind = BINARY_KIND.test_lib, - test_runner = test_runner, - tags = tags, - copts = copts + ["-lc"], - deps = deps + [ - "@rules_zig//zig/lib:libc", - ], - visibility = visibility, - **kwargs - ) - cc_test( - name = name, - env = env, - data = data, - deps = [":{}_test_lib".format(name)], - tags = tags, - visibility = visibility, - linkstatic = True, - 
) diff --git a/mlir/c.h b/mlir/c.h deleted file mode 100644 index 4ad1523..0000000 --- a/mlir/c.h +++ /dev/null @@ -1,9 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig deleted file mode 100644 index 8010985..0000000 --- a/mlir/dialects/stablehlo.zig +++ /dev/null @@ -1,1366 +0,0 @@ -const std = @import("std"); - -const c = @import("c"); -const mlir = @import("mlir"); -const stdx = @import("stdx"); - -pub const abs = functors.unary_fn("stablehlo.abs").call; -pub const cosine = functors.unary_fn("stablehlo.cosine").call; -pub const sine = functors.unary_fn("stablehlo.sine").call; -pub const exponential = functors.unary_fn("stablehlo.exponential").call; -pub const exponential_minus_one = functors.unary_fn("stablehlo.exponential_minus_one").call; -pub const floor = functors.unary_fn("stablehlo.floor").call; -pub const log = functors.unary_fn("stablehlo.log").call; -pub const log_plus_one = functors.unary_fn("stablehlo.log_plus_one").call; -pub const not = functors.unary_fn("stablehlo.not").call; -pub const negate = functors.unary_fn("stablehlo.negate").call; -pub const sqrt = functors.unary_fn("stablehlo.sqrt").call; -pub const tanh = functors.unary_fn("stablehlo.tanh").call; -pub const cbrt = functors.unary_fn("stablehlo.cbrt").call; -pub const ceil = functors.unary_fn("stablehlo.ceil").call; -pub const rsqrt = functors.unary_fn("stablehlo.rsqrt").call; -pub const count_leading_zeros = functors.unary_fn("stablehlo.count_leading_zeros").call; -pub const is_finite = functors.unary_fn("stablehlo.is_finite").call; -pub const logistic = functors.unary_fn("stablehlo.logistic").call; -pub const popcnt = functors.unary_fn("stablehlo.popcnt").call; -pub const sign = functors.unary_fn("stablehlo.sign").call; -pub const real = functors.unary_fn("stablehlo.real").call; -pub const imag = functors.unary_fn("stablehlo.imag").call; - -pub const add = 
functors.binary_fn("stablehlo.add").call; -pub const multiply = functors.binary_fn("stablehlo.multiply").call; -pub const divide = functors.binary_fn("stablehlo.divide").call; -pub const subtract = functors.binary_fn("stablehlo.subtract").call; -pub const or_ = functors.binary_fn("stablehlo.or").call; -pub const xor = functors.binary_fn("stablehlo.xor").call; -pub const and_ = functors.binary_fn("stablehlo.and").call; -pub const atan2 = functors.binary_fn("stablehlo.atan2").call; -pub const maximum = functors.binary_fn("stablehlo.maximum").call; -pub const minimum = functors.binary_fn("stablehlo.minimum").call; -pub const power = functors.binary_fn("stablehlo.power").call; -pub const remainder = functors.binary_fn("stablehlo.remainder").call; -pub const shift_left = functors.binary_fn("stablehlo.shift_left").call; -pub const shift_right_arithmetic = functors.binary_fn("stablehlo.shift_right_arithmetic").call; -pub const shift_right_logical = functors.binary_fn("stablehlo.shift_right_logical").call; -pub const complex = functors.binary_fn("stablehlo.complex").call; - -const functors = struct { - fn unary_fn(comptime op_name: [:0]const u8) type { - return struct { - pub fn call(ctx: mlir.Context, value: mlir.Value, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, op_name, .{ - .operands = &.{value}, - .result_type_inference = true, - .location = location, - }); - } - }; - } - - pub fn binary_fn(comptime op_name: [:0]const u8) type { - return struct { - pub fn call(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, op_name, .{ - .operands = &.{ lhs, rhs }, - .result_type_inference = true, - .location = location, - }); - } - }; - } -}; - -pub fn return_(ctx: mlir.Context, value: mlir.Value, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.return", .{ - .variadic_operands = &.{&.{value}}, - .verify = false, - .location = 
location, - }); -} - -pub fn returns_(ctx: mlir.Context, values: []const mlir.Value, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.return", .{ - .variadic_operands = &.{values}, - .verify = false, - .location = location, - }); -} - -pub fn bitcast_convert(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.bitcast_convert", .{ - .operands = &.{value}, - .results = &.{result_type}, - .location = location, - }); -} - -pub fn cholesky(ctx: mlir.Context, value: mlir.Value, lower: bool, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.cholesky", .{ - .operands = &.{value}, - .result_type_inference = true, - .attributes = &.{ - .{ "lower", .i1FromBool(ctx, lower) }, - }, - .location = location, - }); -} - -pub fn clamp(ctx: mlir.Context, min: mlir.Value, value: mlir.Value, max: mlir.Value, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.clamp", .{ - .operands = &.{ min, value, max }, - .result_type_inference = true, - .location = location, - }); -} - -pub const DotPrecision = union(enum) { - fast, - high, - highest, - algorithm: DotAlgorithm, - - pub fn precisionAttr(self: DotPrecision, ctx: mlir.Context) mlir.Attribute { - const precision = PrecisionAttribute.init(ctx, switch (self) { - .fast => .DEFAULT, - .high => .HIGH, - .highest => .HIGHEST, - // When we specify the dot algorithm, we should not specify the precision. 
- .algorithm => .DEFAULT, - }); - return precision.asAttr(); - } - - pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.RankedTensorType) ?mlir.Attribute { - return switch (self) { - .algorithm => |algo| algo.asAttr(ctx, operand_type), - else => null, - }; - } -}; - -pub const DotAlgorithm = struct { - accumulation: mlir.FloatTypes, - // Note stablehlo distinguish between left/right component_count - // but all the supported algorithm have the same component_count on both side. - component_count: u8 = 1, - num_primitive_operations: u8 = 1, - allow_imprecise_accumulation: bool = false, - - // bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32. - // not sure where this is available. - pub const bf16_6x: DotAlgorithm = .{ - .operand = .bf16, - .accumulation = .f32, - .component_count = 1, - .num_primitive_operations = 6, - .allow_imprecise_accumulation = false, - }; - - pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, tensor_type: mlir.RankedTensorType) mlir.Attribute { - const elem_type = tensor_type.getElementType(); - - return mlir.Attribute.wrap(c.stablehloDotAlgorithmGet( - ctx._inner, - elem_type._inner, - elem_type._inner, - self.accumulation.asType(ctx)._inner, - self.component_count, - self.component_count, - self.num_primitive_operations, - self.allow_imprecise_accumulation, - )); - } -}; - -/// General matrix multiplication "a la Einstein sum" -/// Note: stablehlo doesn't do type inference for dot_general -pub fn dot_general( - ctx: mlir.Context, - lhs: mlir.Value, - rhs: mlir.Value, - result_type: mlir.Type, - location: mlir.Location, - opts: struct { - lhs_batching_dimensions: []const i64, - rhs_batching_dimensions: []const i64, - lhs_contracting_dimensions: []const i64, - rhs_contracting_dimensions: []const i64, - precision: DotPrecision, - }, -) mlir.Operation { - const precisions: [2]mlir.Attribute = 
@splat(opts.precision.precisionAttr(ctx)); - const attributes = [3]mlir.AttrTuple{ - .{ - "dot_dimension_numbers", DotDimensionNumbersAttribute.init(ctx, .{ - .lhs_batching_dimensions = opts.lhs_batching_dimensions, - .rhs_batching_dimensions = opts.rhs_batching_dimensions, - .lhs_contracting_dimensions = opts.lhs_contracting_dimensions, - .rhs_contracting_dimensions = opts.rhs_contracting_dimensions, - }).asAttr(), - }, - .{ "precision_config", .array(ctx, &precisions) }, - // keep algorithm as the last attribute so we can omit it when it's not set. - .{ "algorithm", opts.precision.algorithmAttr(ctx, lhs.getType().as(mlir.RankedTensorType).?) orelse undefined }, - }; - const n_attributes = if (opts.precision == .algorithm) attributes.len else attributes.len - 1; - return mlir.Operation.make(ctx, "stablehlo.dot_general", .{ - .operands = &.{ lhs, rhs }, - .results = &.{result_type}, - .attributes = attributes[0..n_attributes], - .location = location, - }); -} - -pub fn constant( - ctx: mlir.Context, - dims: []const i64, - elem_type: mlir.DenseElementsAttributeTypes, - raw_bytes: []const u8, - location: mlir.Location, -) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.constant", .{ - .operands = &.{}, - .results = &.{.tensor(dims, elem_type.mlirType(ctx))}, - .attributes = &.{.{ "value", .denseElementsFromBytes(ctx, dims, elem_type, raw_bytes) }}, - .location = location, - }); -} - -pub fn convert(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.convert", .{ - .operands = &.{value}, - .results = &.{result_type}, - .location = location, - }); -} - -pub fn broadcast_in_dim(ctx: mlir.Context, operand: mlir.Value, dims: []const i64, result_type: mlir.Type, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.broadcast_in_dim", .{ - .operands = &.{operand}, - .results = &.{result_type}, - .attributes = &.{ - .{ 
"broadcast_dimensions", .dense(ctx, .i64, dims) }, - }, - .location = location, - }); -} - -pub fn transpose(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, location: mlir.Location, opts: struct { permutation: []const i64 }) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.transpose", .{ - .operands = &.{value}, - .results = &.{result_type}, - .attributes = &.{ - .{ "permutation", .dense(ctx, .i64, opts.permutation) }, - }, - .location = location, - }); -} - -pub fn slice(ctx: mlir.Context, operand: mlir.Value, start_indices: []const i64, limit_indices: []const i64, strides: []const i64, result_type: mlir.Type, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.slice", .{ - .operands = &.{operand}, - .results = &.{result_type}, - .attributes = &.{ - .{ "start_indices", .dense(ctx, .i64, start_indices) }, - .{ "limit_indices", .dense(ctx, .i64, limit_indices) }, - .{ "strides", .dense(ctx, .i64, strides) }, - }, - .location = location, - }); -} - -pub fn concatenate(ctx: mlir.Context, inputs: []const mlir.Value, dimension: i64, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.concatenate", .{ - .operands = inputs, - .result_type_inference = true, - .attributes = &.{ - .{ "dimension", .int(ctx, .i64, dimension) }, - }, - .location = location, - }); -} - -pub fn reshape(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.reshape", .{ - .operands = &.{value}, - .results = &.{result_type}, - .location = location, - }); -} - -pub fn select(ctx: mlir.Context, condition: mlir.Value, then: mlir.Value, else_: mlir.Value, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.select", .{ - .operands = &.{ condition, then, else_ }, - .results = &.{then.getType()}, - .location = location, - }); -} - -pub fn gather( - ctx: mlir.Context, - value: 
mlir.Value, - indices: mlir.Value, - slice_sizes: []const i64, - location: mlir.Location, - args: struct { - offset_dims: []const i64, - collapsed_slice_dims: []const i64, - operand_batching_dims: []const i64, - start_indices_batching_dims: []const i64, - start_index_map: []const i64, - index_vector_dim: i64, - indices_are_sorted: bool = false, - }, -) mlir.Operation { - return mlir.Operation.make( - ctx, - "stablehlo.gather", - .{ - .operands = &.{ value, indices }, - .result_type_inference = true, - .attributes = &.{ - .{ "dimension_numbers", GatherDimensionNumbersAttribute.init( - ctx, - args.offset_dims, - args.collapsed_slice_dims, - args.operand_batching_dims, - args.start_indices_batching_dims, - args.start_index_map, - args.index_vector_dim, - ).asAttr() }, - .{ "slice_sizes", .dense(ctx, .i64, slice_sizes) }, - .{ "indices_are_sorted", .boolean(ctx, args.indices_are_sorted) }, - }, - .location = location, - }, - ); -} - -fn elementTypeOrSelf(typ: mlir.Type) mlir.Type { - return if (typ.as(mlir.ShapedType)) |shaped| { - return shaped.elementType(); - } else typ; -} - -pub const ScatterArgs = struct { - update_window_dims: []const i64, - inserted_window_dims: []const i64, - input_batching_dims: []const i64, - scatter_indices_batching_dims: []const i64, - scatter_dims_to_operand_dims: []const i64, - index_vector_dim: i64, - indices_are_sorted: bool = false, - unique_indices: bool = false, - - pub fn getScatterDimensionNumbers(self: ScatterArgs, ctx: mlir.Context) mlir.Attribute { - return .{ ._inner = c.stablehloScatterDimensionNumbersGet( - ctx._inner, - @intCast(self.update_window_dims.len), - self.update_window_dims.ptr, - @intCast(self.inserted_window_dims.len), - self.inserted_window_dims.ptr, - @intCast(self.input_batching_dims.len), - self.input_batching_dims.ptr, - @intCast(self.scatter_indices_batching_dims.len), - self.scatter_indices_batching_dims.ptr, - @intCast(self.scatter_dims_to_operand_dims.len), - self.scatter_dims_to_operand_dims.ptr, - 
self.index_vector_dim, - ) }; - } -}; - -pub fn scatter( - ctx: mlir.Context, - inputs: []const mlir.Value, - scatter_indices: []const mlir.Value, - updates: []const mlir.Value, - update_block: mlir.Block, - args: ScatterArgs, - location: mlir.Location, -) mlir.Operation { - return mlir.Operation.make( - ctx, - "stablehlo.scatter", - .{ - .variadic_operands = &.{ inputs, scatter_indices, updates }, - .blocks = &.{update_block}, - .attributes = &.{ - .{ "scatter_dimension_numbers", args.getScatterDimensionNumbers(ctx) }, - .{ "indices_are_sorted", .boolean(ctx, args.indices_are_sorted) }, - .{ "unique_indices", .boolean(ctx, args.unique_indices) }, - }, - .result_type_inference = true, - .location = location, - }, - ); -} - -pub fn iota(ctx: mlir.Context, dimension: i64, result_type: mlir.Type, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.iota", .{ - .operands = &.{}, - .results = &.{result_type}, - .attributes = &.{ - .{ "iota_dimension", .int(ctx, .i64, dimension) }, - }, - .location = location, - }); -} - -pub fn reverse(ctx: mlir.Context, operand: mlir.Value, dimensions: []const i64, location: mlir.Location) mlir.Operation { - const result_type = operand.getType(); - return mlir.Operation.make(ctx, "stablehlo.reverse", .{ - .operands = &.{operand}, - .results = &.{result_type}, - .attributes = &.{ - .{ "dimensions", .dense(ctx, .i64, dimensions) }, - }, - .location = location, - }); -} - -pub fn compare(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, comparison_direction: ComparisonDirection, compare_type: CompareType, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.compare", .{ - .operands = &.{ lhs, rhs }, - .result_type_inference = true, - .attributes = &.{ - .{ "comparison_direction", comparison_direction.asAttr() }, - .{ "compare_type", compare_type.asAttr() }, - }, - .location = location, - }); -} - -pub fn reduce( - ctx: mlir.Context, - inputs: []const mlir.Value, - 
init_values: []const mlir.Value, - dimensions: []const i64, - blkctx: anytype, - blkfn: fn (anytype, mlir.Context, []const mlir.Value, []const mlir.Value) mlir.Operation, - location: mlir.Location, -) mlir.Operation { - const MaxBlockArguments = 32; - - const block_n_args = inputs.len + init_values.len; - const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0..block_n_args]; - var reduce_elem_types: [MaxBlockArguments]mlir.Type = undefined; - for (inputs, 0..) |input, i| { - const arg_type: mlir.Type = .tensor(&.{}, elementTypeOrSelf(input.getType())); - reduce_elem_types[i] = arg_type; - reduce_elem_types[inputs.len + i] = arg_type; - } - var block = mlir.Block.open(reduce_elem_types[0..block_n_args], locations) catch unreachable; - { - defer block.close(); - - var block_inputs: [MaxBlockArguments / 2]mlir.Value = undefined; - var block_accs: [MaxBlockArguments / 2]mlir.Value = undefined; - for (0..inputs.len) |i| { - block_inputs[i] = block.argument(i); - block_accs[i] = block.argument(inputs.len + i); - } - _ = blkfn(blkctx, ctx, block_inputs[0..inputs.len], block_accs[0..init_values.len]); - } - - return mlir.Operation.make(ctx, "stablehlo.reduce", .{ - .variadic_operands = &.{ inputs, init_values }, - .result_type_inference = true, - .block = block, - .attributes = &.{ - .{ "dimensions", .dense(ctx, .i64, dimensions) }, - }, - .location = location, - }); -} - -pub fn sort( - ctx: mlir.Context, - inputs: []const mlir.Value, - dimension: i64, - is_stable: bool, - blkctx: anytype, - compfn: fn (anytype, mlir.Context, []const mlir.Value) mlir.Operation, - location: mlir.Location, -) mlir.Operation { - const MaxBlockArguments = 32; - - const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0 .. inputs.len * 2]; - var sort_elem_types: [MaxBlockArguments]mlir.Type = undefined; - for (inputs, 0..) 
|input, i| { - const arg_type: mlir.Type = .tensor(&.{}, elementTypeOrSelf(input.getType())); - sort_elem_types[i * 2] = arg_type; - sort_elem_types[i * 2 + 1] = arg_type; - } - var block = mlir.Block.init(sort_elem_types[0 .. inputs.len * 2], locations) catch unreachable; - - var block_inputs: [MaxBlockArguments]mlir.Value = undefined; - for (0..inputs.len * 2) |i| { - block_inputs[i] = block.argument(i); - } - _ = compfn(blkctx, ctx, block_inputs[0 .. inputs.len * 2]); - - return mlir.Operation.make(ctx, "stablehlo.sort", .{ - .variadic_operands = &.{inputs}, - .result_type_inference = true, - .block = block, - .attributes = &.{ - .{ "dimension", .int(ctx, .i64, dimension) }, - .{ "is_stable", .boolean(ctx, is_stable) }, - }, - .location = location, - }); -} - -pub fn dynamic_slice(ctx: mlir.Context, operand: mlir.Value, new_dims: []const i64, start_indices: []const mlir.Value, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.dynamic_slice", .{ - .variadic_operands = &.{ &.{operand}, start_indices }, - .result_type_inference = true, - .attributes = &.{ - .{ "slice_sizes", .dense(ctx, .i64, new_dims) }, - }, - .location = location, - }); -} - -pub fn round_nearest_afz(ctx: mlir.Context, value: mlir.Value, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.round_nearest_afz", .{ - .operands = &.{value}, - .result_type_inference = true, - .location = location, - }); -} - -pub fn round_nearest_even(ctx: mlir.Context, value: mlir.Value, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.round_nearest_even", .{ - .operands = &.{value}, - .result_type_inference = true, - .location = location, - }); -} - -pub const PadOpts = struct { - low: []const i64, - high: []const i64, - interior: []const i64, -}; - -pub fn pad(ctx: mlir.Context, value: mlir.Value, padding_value: mlir.Value, opts: PadOpts, location: mlir.Location) mlir.Operation { - return 
mlir.Operation.make(ctx, "stablehlo.pad", .{ - .operands = &.{ value, padding_value }, - .result_type_inference = true, - .attributes = &.{ - .{ "edge_padding_low", .dense(ctx, .i64, opts.low) }, - .{ "edge_padding_high", .dense(ctx, .i64, opts.high) }, - .{ "interior_padding", .dense(ctx, .i64, opts.interior) }, - }, - .location = location, - }); -} - -pub const TriangularSolveOpts = struct { - left_side: bool, - lower: bool, - unit_diagonal: bool, - transpose_a: Transpose.Type, -}; - -pub fn triangular_solve(ctx: mlir.Context, value: mlir.Value, other: mlir.Value, location: mlir.Location, opts: TriangularSolveOpts) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.triangular_solve", .{ - .operands = &.{ value, other }, - .result_type_inference = true, - .attributes = &.{ - .{ "left_side", .i1FromBool(ctx, opts.left_side) }, - .{ "lower", .i1FromBool(ctx, opts.lower) }, - .{ "unit_diagonal", .i1FromBool(ctx, opts.unit_diagonal) }, - .{ "transpose_a", Transpose.init(ctx, opts.transpose_a).asAttr() }, - }, - .location = location, - }); -} - -pub const FftOpts = struct { - kind: FftType.Type, - length: []const i64, -}; - -pub fn fft(ctx: mlir.Context, value: mlir.Value, location: mlir.Location, opts: FftOpts) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.fft", .{ - .operands = &.{value}, - .result_type_inference = true, - .attributes = &.{ - .{ "fft_type", FftType.init(ctx, opts.kind).asAttr() }, - .{ "fft_length", .dense(ctx, .i64, opts.length) }, - }, - .location = location, - }); -} - -pub fn rng(ctx: mlir.Context, a: mlir.Value, b: mlir.Value, shape: mlir.Value, rng_distribution: RngDistribution.Type, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.rng", .{ - .operands = &.{ a, b, shape }, - .result_type_inference = true, - .attributes = &.{ - .{ "rng_distribution", RngDistribution.init(ctx, rng_distribution).asAttr() }, - }, - .location = location, - }); -} - -pub fn rng_bit_generator(ctx: 
mlir.Context, rng_algorithm: RngAlgorithm.Type, initial_state: mlir.Value, res_state_type: mlir.Type, res_type: mlir.Type, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.rng_bit_generator", .{ - .operands = &.{initial_state}, - .results = &.{ res_state_type, res_type }, - .attributes = &.{ - .{ "rng_algorithm", RngAlgorithm.init(ctx, rng_algorithm).asAttr() }, - }, - .location = location, - }); -} - -pub fn reduce_precision(ctx: mlir.Context, value: mlir.Value, exponent_bits: i32, mantissa_bits: i32, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.reduce_precision", .{ - .operands = &.{value}, - .result_type_inference = true, - .attributes = &.{ - .{ "exponent_bits", .int(ctx, .i32, exponent_bits) }, - .{ "mantissa_bits", .int(ctx, .i32, mantissa_bits) }, - }, - .location = location, - }); -} - -pub fn dynamic_update_slice(ctx: mlir.Context, operand: mlir.Value, update: mlir.Value, start_indices: []const mlir.Value, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.dynamic_update_slice", .{ - .variadic_operands = &.{ &.{operand}, &.{update}, start_indices }, - .result_type_inference = true, - .location = location, - }); -} - -pub fn tuple(ctx: mlir.Context, values: []const mlir.Value, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.tuple", .{ - .operands = values, - .result_type_inference = true, - .location = location, - }); -} - -pub fn get_tuple_element(ctx: mlir.Context, tuple_value: mlir.Value, index: i64, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.get_tuple_element", .{ - .operands = &.{tuple_value}, - .result_type_inference = true, - .attributes = &.{ - .{ "index", .int(ctx, .i32, index) }, - }, - .location = location, - }); -} - -pub const ConvolutionOpts = struct { - window_strides: []const i64, - pad_value: []const i64, - pad_shape: []const i64 = &.{}, - 
lhs_dilation: []const i64, - rhs_dilation: []const i64, - window_reversal: []const bool, - input_batch_dimension: i64, - input_feature_dimension: i64, - input_spatial_dimensions: []const i64, - kernel_input_feature_dimension: i64, - kernel_output_feature_dimension: i64, - kernel_spatial_dimensions: []const i64, - output_batch_dimension: i64, - output_feature_dimension: i64, - output_spatial_dimensions: []const i64, - feature_group_count: i64, - batch_group_count: i64, - precision_config: []const PrecisionAttribute.Precision = &.{}, -}; - -pub fn convolution( - ctx: mlir.Context, - lhs: mlir.Value, - rhs: mlir.Value, - opts: ConvolutionOpts, - res_type: mlir.Type, - location: mlir.Location, -) mlir.Operation { - var max_precisions: [2]mlir.Attribute = undefined; - for (opts.precision_config, 0..) |p, i| { - max_precisions[i] = PrecisionAttribute.init(ctx, p).asAttr(); - } - var window_reversal: [3]i32 = undefined; - for (opts.window_reversal, 0..) |w, i| { - window_reversal[i] = @intCast(@intFromBool(w)); - } - return mlir.Operation.make(ctx, "stablehlo.convolution", .{ - .operands = &.{ lhs, rhs }, - .results = &.{res_type}, - .attributes = &.{ - .{ "window_strides", .dense(ctx, .i64, opts.window_strides) }, - .{ "padding", .denseElements(ctx, opts.pad_shape, .i64, opts.pad_value) }, - .{ "lhs_dilation", .dense(ctx, .i64, opts.lhs_dilation) }, - .{ "rhs_dilation", .dense(ctx, .i64, opts.rhs_dilation) }, - .{ "window_reversal", .dense(ctx, .bool, window_reversal[0..opts.window_reversal.len]) }, - .{ - "dimension_numbers", ConvDimensionNumbersAttribute.init(ctx, .{ - .input_batch_dimension = opts.input_batch_dimension, - .input_feature_dimension = opts.input_feature_dimension, - .input_spatial_dimensions = opts.input_spatial_dimensions, - .kernel_input_feature_dimension = opts.kernel_input_feature_dimension, - .kernel_output_feature_dimension = opts.kernel_output_feature_dimension, - .kernel_spatial_dimensions = opts.kernel_spatial_dimensions, - 
.output_batch_dimension = opts.output_batch_dimension, - .output_feature_dimension = opts.output_feature_dimension, - .output_spatial_dimensions = opts.output_spatial_dimensions, - }).asAttr(), - }, - .{ "feature_group_count", .int(ctx, .i64, opts.feature_group_count) }, - .{ "batch_group_count", .int(ctx, .i64, opts.batch_group_count) }, - .{ "precision_config", .array(ctx, &max_precisions) }, - }, - .location = location, - }); -} - -pub const CustomCallOpts = struct { - pub const ApiVersion = enum(i32) { - original = 1, - status_returning = 2, - status_returning_unified = 3, - typed_ffi = 4, - }; - - call_target_name: [:0]const u8, - has_side_effect: bool, - backend_config: ?mlir.Attribute, - operand_layouts: ?[]const []const usize = null, - result_layouts: ?[]const []const usize = null, - output_operand_aliases: []const i64 = &.{}, - additional_attributes: []const mlir.AttrTuple = &.{}, - api_version: ApiVersion, -}; - -pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCallOpts, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation { - const MAX_OPERANDS = 64; - const MAX_RESULTS = 16; - - const backend_config = opts.backend_config orelse mlir.Attribute.string(ctx, ""); - if (@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi)) { - stdx.debug.assert( - backend_config.isA(mlir.StringAttribute), - "API version < 4 requires a string as backend_config, got {}", - .{backend_config}, - ); - } else { - stdx.debug.assert( - backend_config.isA(mlir.DictionaryAttribute), - "API version >= 4 requires a dictionary as backend_config, got {}", - .{backend_config}, - ); - } - - var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{}; - attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{ - .{ "api_version", .int(ctx, .i32, @intFromEnum(opts.api_version)) }, - .{ "call_target_name", .string(ctx, opts.call_target_name) }, - .{ "has_side_effect", .boolean(ctx, opts.has_side_effect) }, - .{ "backend_config", 
backend_config }, - }); - - { - var output_operand_aliases: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; - for (opts.output_operand_aliases) |alias| { - output_operand_aliases.appendAssumeCapacity( - OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).asAttr(), - ); - } - attrs.appendAssumeCapacity(.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) }); - } - - const MINOR_TO_MAJOR = blk: { - const MAX_RANK = 8; - var ret: [MAX_RANK]usize = undefined; - for (0..MAX_RANK) |i| { - ret[i] = @intCast(MAX_RANK - i - 1); - } - break :blk ret; - }; - - if (opts.operand_layouts) |layouts| { - var operand_layouts: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; - for (layouts) |ol| { - operand_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol)); - } - attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) }); - } else { - const operand_layouts = blk: { - var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; - for (inputs) |input| { - const ranked_type = input.getType().as(mlir.RankedTensorType).?; - const ol = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_type.getRank() ..]; - ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol)); - } - break :blk ret; - }; - attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) }); - } - - if (opts.result_layouts) |layouts| { - var result_layouts: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; - for (layouts) |rl| { - result_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl)); - } - attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) }); - } else { - const result_layouts = blk: { - var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; - for (res_types) |t| { - const ranked_t = t.as(mlir.RankedTensorType).?; - const rl = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_t.getRank() ..]; 
- ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl)); - } - break :blk ret; - }; - attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) }); - } - - attrs.appendSlice(opts.additional_attributes) catch @panic("Too many additional_attributes"); - - return mlir.Operation.make(ctx, "stablehlo.custom_call", .{ - .operands = inputs, - .results = res_types, - .attributes = attrs.constSlice(), - .location = location, - }); -} - -pub const DotDimensionNumbersAttribute = struct { - _inner: c.MlirAttribute, - - pub const is_a_fn = c.stablehloAttributeIsADotDimensionNumbers; - const Self = DotDimensionNumbersAttribute; - pub const asAttr = mlir.Attribute.fromAny(Self); - pub const eql = mlir.Attribute.eqlAny(Self); - - pub fn init(ctx: mlir.Context, args: struct { - lhs_batching_dimensions: []const i64, - rhs_batching_dimensions: []const i64, - lhs_contracting_dimensions: []const i64, - rhs_contracting_dimensions: []const i64, - }) Self { - return .{ - ._inner = c.stablehloDotDimensionNumbersGet( - ctx._inner, - @intCast(args.lhs_batching_dimensions.len), - args.lhs_batching_dimensions.ptr, - @intCast(args.rhs_batching_dimensions.len), - args.rhs_batching_dimensions.ptr, - @intCast(args.lhs_contracting_dimensions.len), - args.lhs_contracting_dimensions.ptr, - @intCast(args.rhs_contracting_dimensions.len), - args.rhs_contracting_dimensions.ptr, - ), - }; - } - - pub fn getLhsBatchingDimensionsSize(self: Self) usize { - return @intCast(c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsSize(self._inner)); - } - - pub fn getLhsBatchingDimensionsElem(self: Self, pos: usize) i64 { - return c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsElem(self._inner, @intCast(pos)); - } - - pub fn getRhsBatchingDimensionsSize(self: Self) usize { - return @intCast(c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsSize(self._inner)); - } - - pub fn getRhsBatchingDimensionsElem(self: Self, pos: usize) i64 { - return 
c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsElem(self._inner, @intCast(pos)); - } - - pub fn getLhsContractingDimensionsSize(self: Self) usize { - return @intCast(c.stablehloDotDimensionNumbersGetLhsContractingDimensionsSize(self._inner)); - } - - pub fn getLhsContractingDimensionsElem(self: Self, pos: usize) i64 { - return c.stablehloDotDimensionNumbersGetLhsContractingDimensionsElem(self._inner, @intCast(pos)); - } - - pub fn getRhsContractingDimensionsSize(self: Self) usize { - return @intCast(c.stablehloDotDimensionNumbersGetRhsContractingDimensionsSize(self._inner)); - } - - pub fn getRhsContractingDimensionsElem(self: Self, pos: usize) i64 { - return c.stablehloDotDimensionNumbersGetRhsContractingDimensionsElem(self._inner, @intCast(pos)); - } -}; - -pub const GatherDimensionNumbersAttribute = struct { - _inner: c.MlirAttribute, - - pub const is_a_fn = c.stablehloAttributeIsAGatherDimensionNumbers; - const Self = GatherDimensionNumbersAttribute; - pub const asAttr = mlir.Attribute.fromAny(Self); - pub const eql = mlir.Attribute.eqlAny(Self); - - pub fn init( - ctx: mlir.Context, - offset_dims: []const i64, - collapsed_slice_dims: []const i64, - operand_batching_dims: []const i64, - start_indices_batching_dims: []const i64, - start_index_map: []const i64, - index_vector_dim: i64, - ) Self { - return .{ - ._inner = c.stablehloGatherDimensionNumbersGet( - ctx._inner, - @intCast(offset_dims.len), - offset_dims.ptr, - @intCast(collapsed_slice_dims.len), - collapsed_slice_dims.ptr, - @intCast(operand_batching_dims.len), - operand_batching_dims.ptr, - @intCast(start_indices_batching_dims.len), - start_indices_batching_dims.ptr, - @intCast(start_index_map.len), - start_index_map.ptr, - index_vector_dim, - ), - }; - } - - pub fn getOffsetDimsSize(self: Self) usize { - return @intCast(c.stablehloGatherDimensionNumbersGetOffsetDimsSize(self._inner)); - } - - pub fn getOffsetDimsElem(self: Self, pos: usize) i64 { - return 
c.stablehloGatherDimensionNumbersGetOffsetDimsElem(self._inner, @intCast(pos)); - } - - pub fn getCollapsedSliceDimsSize(self: Self) usize { - return @intCast(c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsSize(self._inner)); - } - - pub fn getCollapsedSliceDimsElem(self: Self, pos: usize) i64 { - return c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsElem(self._inner, @intCast(pos)); - } - - pub fn getStartIndexMapSize(self: Self) usize { - return @intCast(c.stablehloGatherDimensionNumbersGetStartIndexMapSize(self._inner)); - } - - pub fn getOperandBatchingDimsSize(self: Self) usize { - return @intCast(c.stablehloGatherDimensionNumbersGetOperandBatchingDimsSize(self._inner)); - } - - pub fn getOperandBatchingDimsElem(self: Self, pos: usize) i64 { - return c.stablehloGatherDimensionNumbersGetOperandBatchingDimsElem(self._inner, @intCast(pos)); - } - - pub fn getStartIndicesBatchingDimsSize(self: Self) usize { - return @intCast(c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsSize(self._inner)); - } - - pub fn getStartIndicesBatchingDimsElem(self: Self, pos: usize) i64 { - return c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsElem(self._inner, @intCast(pos)); - } - - pub fn getStartIndexMapElem(self: Self, pos: usize) i64 { - return c.stablehloGatherDimensionNumbersGetStartIndexMapElem(self._inner, @intCast(pos)); - } - - pub fn getIndexVectorDim(self: Self) usize { - return @intCast(c.stablehloGatherDimensionNumbersGetIndexVectorDim(self._inner)); - } -}; - -pub const ConvDimensionNumbersAttribute = struct { - _inner: c.MlirAttribute, - - pub const is_a_fn = c.stablehloAttributeIsAConvDimensionNumbers; - const Self = ConvDimensionNumbersAttribute; - pub const asAttr = mlir.Attribute.fromAny(Self); - pub const eql = mlir.Attribute.eqlAny(Self); - - pub fn init(ctx: mlir.Context, args: struct { - input_batch_dimension: i64, - input_feature_dimension: i64, - input_spatial_dimensions: []const i64, - kernel_input_feature_dimension: 
i64, - kernel_output_feature_dimension: i64, - kernel_spatial_dimensions: []const i64, - output_batch_dimension: i64, - output_feature_dimension: i64, - output_spatial_dimensions: []const i64, - }) Self { - return .{ - ._inner = c.stablehloConvDimensionNumbersGet( - ctx._inner, - args.input_batch_dimension, - args.input_feature_dimension, - @intCast(args.input_spatial_dimensions.len), - args.input_spatial_dimensions.ptr, - args.kernel_input_feature_dimension, - args.kernel_output_feature_dimension, - @intCast(args.kernel_spatial_dimensions.len), - args.kernel_spatial_dimensions.ptr, - args.output_batch_dimension, - args.output_feature_dimension, - @intCast(args.output_spatial_dimensions.len), - args.output_spatial_dimensions.ptr, - ), - }; - } - - pub fn getInputBatchDimension(self: Self) i64 { - return c.stablehloConvDimensionNumbersGetInputBatchDimension(self._inner); - } - - pub fn getInputFeatureDimension(self: Self) i64 { - return c.stablehloConvDimensionNumbersGetInputFeatureDimension(self._inner); - } - - pub fn getInputSpatialDimensionsSize(self: Self) usize { - return @intCast(c.stablehloConvDimensionNumbersGetInputSpatialDimensionsSize(self._inner)); - } - - pub fn getInputSpatialDimensionsElem(self: Self, pos: usize) i64 { - return c.stablehloConvDimensionNumbersGetInputSpatialDimensionsElem(self._inner, @intCast(pos)); - } - - pub fn getKernelInputFeatureDimension(self: Self) i64 { - return c.stablehloConvDimensionNumbersGetKernelInputFeatureDimension(self._inner); - } - - pub fn getKernelOutputFeatureDimension(self: Self) i64 { - return c.stablehloConvDimensionNumbersGetKernelOutputFeatureDimension(self._inner); - } - - pub fn getKernelSpatialDimensionsSize(self: Self) usize { - return @intCast(c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsSize(self._inner)); - } - - pub fn getKernelSpatialDimensionsElem(self: Self, pos: usize) i64 { - return c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsElem(self._inner, @intCast(pos)); - } - - 
pub fn getOutputBatchDimension(self: Self) i64 { - return c.stablehloConvDimensionNumbersGetOutputBatchDimension(self._inner); - } - - pub fn getOutputFeatureDimension(self: Self) i64 { - return c.stablehloConvDimensionNumbersGetOutputFeatureDimension(self._inner); - } - - pub fn getOutputSpatialDimensionsSize(self: Self) usize { - return @intCast(c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsSize(self._inner)); - } - - pub fn getOutputSpatialDimensionsElem(self: Self, pos: usize) i64 { - return c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsElem(self._inner, @intCast(pos)); - } -}; - -pub const OutputOperandAliasAttribute = struct { - _inner: c.MlirAttribute, - - pub const is_a_fn = c.stablehloAttributeIsAOutputOperandAlias; - pub const asAttr = mlir.Attribute.fromAny(OutputOperandAliasAttribute); - pub const eql = mlir.Attribute.eqlAny(OutputOperandAliasAttribute); - - pub fn init( - ctx: mlir.Context, - output_tuple_indices: []const i64, - operand_index: i64, - operand_tuple_indices: []const i64, - ) OutputOperandAliasAttribute { - return .{ ._inner = c.stablehloOutputOperandAliasGet( - ctx._inner, - @intCast(output_tuple_indices.len), - output_tuple_indices.ptr, - @intCast(operand_index), - @intCast(operand_tuple_indices.len), - operand_tuple_indices.ptr, - ) }; - } -}; - -pub const PrecisionAttribute = struct { - _inner: c.MlirAttribute, - - pub const is_a_fn = c.stablehloAttributeIsAPrecisionAttr; - const Self = PrecisionAttribute; - pub const asAttr = mlir.Attribute.fromAny(Self); - pub const eql = mlir.Attribute.eqlAny(Self); - - pub const Precision = enum { - DEFAULT, - HIGH, - HIGHEST, - }; - - pub fn init(ctx: mlir.Context, value: Precision) Self { - return .{ ._inner = c.stablehloPrecisionAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; - } - - pub fn getValue(self: Self) Precision { - const value = mlir.fromStringRef(c.stablehloPrecisionAttrGetValue(self._inner)); - return std.meta.stringToEnum(Precision, value) orelse 
unreachable; - } -}; - -pub const ComparisonDirection = struct { - _inner: c.MlirAttribute, - - pub const is_a_fn = c.stablehloAttributeIsAComparisonDirectionAttr; - const Self = ComparisonDirection; - pub const asAttr = mlir.Attribute.fromAny(Self); - pub const eql = mlir.Attribute.eqlAny(Self); - - pub const Direction = enum { - EQ, - NE, - GE, - GT, - LE, - LT, - }; - - pub fn init(ctx: mlir.Context, value: Direction) Self { - return .{ ._inner = c.stablehloComparisonDirectionAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; - } - - pub fn getValue(self: Self) Direction { - const value = mlir.fromStringRef(c.stablehloComparisonDirectionAttrGetValue(self._inner)); - return std.meta.stringToEnum(Direction, value) orelse unreachable; - } -}; - -pub const CompareType = struct { - _inner: c.MlirAttribute, - - pub const is_a_fn = c.stablehloAttributeIsAComparisonTypeAttr; - const Self = CompareType; - pub const asAttr = mlir.Attribute.fromAny(Self); - pub const eql = mlir.Attribute.eqlAny(Self); - - pub const Type = enum { - SIGNED, - UNSIGNED, - FLOAT, - TOTALORDER, - }; - - pub fn init(ctx: mlir.Context, value: Type) Self { - return .{ ._inner = c.stablehloComparisonTypeAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; - } - - pub fn getValue(self: Self) Type { - const value = mlir.fromStringRef(c.stablehloComparisonTypeAttrGetValue(self._inner)); - return std.meta.stringToEnum(Type, value) orelse unreachable; - } -}; - -pub const Transpose = struct { - _inner: c.MlirAttribute, - - pub const is_a_fn = c.stablehloAttributeIsATransposeAttr; - const Self = Transpose; - pub const asAttr = mlir.Attribute.fromAny(Self); - pub const eql = mlir.Attribute.eqlAny(Self); - - pub const Type = enum { - NO_TRANSPOSE, - TRANSPOSE, - ADJOINT, - }; - - pub fn init(ctx: mlir.Context, value: Type) Self { - return .{ ._inner = c.stablehloTransposeAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; - } - - pub fn getValue(self: Self) Type { - const value = 
mlir.fromStringRef(c.stablehloTransposeAttrGetValue(self._inner)); - return std.meta.stringToEnum(Type, value) orelse unreachable; - } -}; - -pub const FftType = struct { - _inner: c.MlirAttribute, - - pub const is_a_fn = c.stablehloAttributeIsAFftTypeAttr; - const Self = FftType; - pub const asAttr = mlir.Attribute.fromAny(Self); - pub const eql = mlir.Attribute.eqlAny(Self); - - pub const Type = enum { - FFT, - IFFT, - RFFT, - IRFFT, - }; - - pub fn init(ctx: mlir.Context, value: Type) Self { - return .{ ._inner = c.stablehloFftTypeAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; - } - - pub fn getValue(self: Self) Type { - const value = mlir.fromStringRef(c.stablehloFftTypeAttrGetValue(self._inner)); - return std.meta.stringToEnum(Type, value) orelse unreachable; - } -}; - -pub const RngDistribution = struct { - _inner: c.MlirAttribute, - - pub const is_a_fn = c.stablehloAttributeIsARngDistributionAttr; - const Self = RngDistribution; - pub const asAttr = mlir.Attribute.fromAny(Self); - pub const eql = mlir.Attribute.eqlAny(Self); - - pub const Type = enum { - UNIFORM, - NORMAL, - }; - - pub fn init(ctx: mlir.Context, value: Type) Self { - return .{ ._inner = c.stablehloRngDistributionAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; - } - - pub fn getValue(self: Self) Type { - const value = mlir.fromStringRef(c.stablehloRngDistributionAttrGetValue(self._inner)); - return std.meta.stringToEnum(Type, value) orelse unreachable; - } -}; - -pub const RngAlgorithm = struct { - _inner: c.MlirAttribute, - - pub const is_a_fn = c.stablehloAttributeIsARngAlgorithmAttr; - const Self = RngAlgorithm; - pub const asAttr = mlir.Attribute.fromAny(Self); - pub const eql = mlir.Attribute.eqlAny(Self); - - pub const Type = enum { - DEFAULT, - THREE_FRY, - PHILOX, - }; - - pub fn init(ctx: mlir.Context, value: Type) Self { - return .{ ._inner = c.stablehloRngAlgorithmAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) }; - } - - pub fn getValue(self: Self) Type { - 
const value = mlir.fromStringRef(c.stablehloRngAlgorithmAttrGetValue(self._inner)); - return std.meta.stringToEnum(Type, value) orelse unreachable; - } -}; - -pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehloCompatibilityRequirement) []const u8 { - const state = struct { - var buf: [32]u8 = undefined; - - fn call(req: c.MlirStablehloCompatibilityRequirement) []u8 { - var stream = std.io.fixedBufferStream(&buf); - var context = .{ .writer = stream.writer() }; - const WriterContext = @TypeOf(context); - - c.stablehloVersionFromCompatibilityRequirement(req, (struct { - pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { - const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); - _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; - } - }).callback, &context); - - return buf[0..stream.pos]; - } - }; - - return state.call(requirement); -} - -pub fn stablehloGetSmallerVersion(version1: []const u8, version2: []const u8) []const u8 { - var buf: [32]u8 = undefined; - - var stream = std.io.fixedBufferStream(&buf); - var context = .{ .writer = stream.writer() }; - const WriterContext = @TypeOf(context); - - _ = c.stablehloGetSmallerVersion(mlir.stringRef(version1), mlir.stringRef(version2), (struct { - pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { - const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); - _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; - } - }).callback, &context); - - return if (std.mem.eql(u8, buf[0..stream.pos], version1)) version1 else version2; -} - -pub fn getCurrentVersion() []const u8 { - const state = struct { - var buf: [32]u8 = undefined; - var str: []const u8 = undefined; - var once = std.once(call); - - fn call() void { - var stream = std.io.fixedBufferStream(&buf); - var writer_ = stream.writer(); - const ContextWriter = @TypeOf(writer_); - - 
c.stablehloGetCurrentVersion((struct { - pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { - const writer: *ContextWriter = @ptrCast(@alignCast(userdata)); - _ = writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; - } - }).callback, &writer_); - - str = buf[0..stream.pos]; - } - }; - - state.once.call(); - return state.str; -} - -pub fn getMinimumVersion() []const u8 { - const state = struct { - var buf: [32]u8 = undefined; - var str: []const u8 = undefined; - var once = std.once(call); - - fn call() void { - var stream = std.io.fixedBufferStream(&buf); - var context = .{ .writer = stream.writer() }; - const WriterContext = @TypeOf(context); - - c.stablehloGetMinimumVersion((struct { - pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { - const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); - _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; - } - }).callback, &context); - - str = buf[0..stream.pos]; - } - }; - - state.once.call(); - return state.str; -} - -pub fn serializePortableArtifact(bytecode: []const u8, target_version: []const u8, writer: anytype) !void { - var context = .{ .writer = writer }; - const WriterContext = @TypeOf(context); - - try mlir.successOr(c.stablehloSerializePortableArtifactFromStringRef(mlir.stringRef(bytecode), mlir.stringRef(target_version), (struct { - pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { - const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); - _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; - } - }).callback, &context), error.InvalidMlirBytecodeVersion); -} diff --git a/runtimes/cuda/libpjrt_cuda.BUILD.bazel b/runtimes/cuda/libpjrt_cuda.BUILD.bazel index d49d135..6cb5d74 100644 --- a/runtimes/cuda/libpjrt_cuda.BUILD.bazel +++ b/runtimes/cuda/libpjrt_cuda.BUILD.bazel @@ -10,7 +10,7 @@ cc_shared_library( patchelf( name = 
"libpjrt_cuda.patchelf", - shared_library = "libpjrt_cuda.so", + src = "libpjrt_cuda.so", add_needed = [ "libzmlxcuda.so.0", ], diff --git a/runtimes/neuron/neuron.bzl b/runtimes/neuron/neuron.bzl index 89996fa..a277b35 100644 --- a/runtimes/neuron/neuron.bzl +++ b/runtimes/neuron/neuron.bzl @@ -31,7 +31,7 @@ _NEURON_PACKAGES = { ), packages.patchelf( name = "libnrt.patchelf", - shared_library = "lib/libnrt.so.1", + src = "lib/libnrt.so.1", set_rpath = '$ORIGIN', add_needed = [ # readelf -d ./opt/aws/neuron/libl/libncfw.so @@ -43,7 +43,7 @@ _NEURON_PACKAGES = { ), packages.patchelf( name = "libncfw.patchelf", - shared_library = "lib/libncfw.so", + src = "lib/libncfw.so", soname = "libncfw.so.2", ), ]), diff --git a/runtimes/rocm/BUILD.bazel b/runtimes/rocm/BUILD.bazel index 536ccb9..f94e806 100644 --- a/runtimes/rocm/BUILD.bazel +++ b/runtimes/rocm/BUILD.bazel @@ -1,13 +1,14 @@ load("@rules_cc//cc:cc_library.bzl", "cc_library") -load("@rules_zig//zig:defs.bzl", "zig_library") +load("@rules_zig//zig:defs.bzl", "zig_library", "zig_shared_library") -cc_library( - name = "zmlxrocm_lib", - srcs = ["zmlxrocm.c"], - linkopts = [ - "-lc", - "-ldl", - ], +zig_shared_library( + name = "zmlxrocm", + main = "zmlxrocm.zig", + # Use Clang's compiler-rt, but disable stack checking + # to avoid requiring on the _zig_probe_stack symbol. 
+ copts = ["-fno-stack-check"], + shared_lib_name = "libzmlxrocm.so.0", + deps = ["//stdx"], visibility = ["@libpjrt_rocm//:__subpackages__"], ) @@ -51,6 +52,6 @@ zig_library( filegroup( name = "layers", - srcs = ["@libpjrt_rocm//:amdgpu_ids_layer"], + srcs = [], visibility = ["//visibility:public"], ) diff --git a/runtimes/rocm/libpjrt_rocm.BUILD.bazel b/runtimes/rocm/libpjrt_rocm.BUILD.bazel index 9b1a143..7a59252 100644 --- a/runtimes/rocm/libpjrt_rocm.BUILD.bazel +++ b/runtimes/rocm/libpjrt_rocm.BUILD.bazel @@ -20,26 +20,12 @@ config_setting( flag_values = {":hipblaslt": "True"}, ) -cc_shared_library( - name = "zmlxrocm_so", - shared_lib_name = "lib/libzmlxrocm.so.0", - deps = ["@zml//runtimes/rocm:zmlxrocm_lib"], -) - patchelf( - name = "libpjrt_rocm.patchelf", - shared_library = "libpjrt_rocm.so", + name = "libpjrt_rocm_so", + src = "libpjrt_rocm.so", add_needed = [ "libzmlxrocm.so.0", - # So that RPATH is taken into account. - "librocblas.so.4", - "libMIOpen.so.1", - ] + select({ - "_hipblaslt": [ - "libhipblaslt.so.0", - ], - "//conditions:default": [], - }), + ], rename_dynamic_symbols = { "dlopen": "zmlxrocm_dlopen", }, @@ -49,49 +35,52 @@ patchelf( copy_to_directory( name = "sandbox", srcs = [ - ":zmlxrocm_so", - ":libpjrt_rocm.patchelf", + ":libpjrt_rocm_so", "@comgr//:amd_comgr", - "@hip-runtime-amd//:amdhip_patched", + "@hip-runtime-amd//:amdhip", "@hip-runtime-amd//:hiprtc", "@hipblaslt//:hipblaslt", "@hipfft", "@hipsolver", "@hsa-amd-aqlprofile//:hsa-amd-aqlprofile", "@hsa-rocr//:hsa-runtime", + "@libdrm-amdgpu-amdgpu1", + "@libdrm-amdgpu-common//:amdgpu_ids", + "@libdrm2-amdgpu", + "@libelf1", + "@libnuma1", + "@libtinfo6", + "@libzstd1", "@miopen-hip//:MIOpen", "@rccl", "@rocblas//:rocblas", "@rocblas//:runfiles", + "@rocfft", "@rocm-core", "@rocm-device-libs//:runfiles", "@rocm-smi-lib//:rocm_smi", "@rocprofiler-register", - "@rocfft", "@rocsolver", "@roctracer", "@roctracer//:roctx", - "@libelf1", - "@libdrm2-amdgpu", - "@libnuma1", - 
"@libzstd1", - "@libdrm-amdgpu-amdgpu1", - "@libtinfo6", "@zlib1g", + "@zml//runtimes/rocm:zmlxrocm", ] + select({ ":_hipblaslt": ["@hipblaslt//:runfiles"], "//conditions:default": [], }), replace_prefixes = { - "libpjrt_rocm.patchelf": "lib", - "lib/x86_64-linux-gnu": "lib", - "usr/lib/x86_64-linux-gnu": "lib", - "libelf1": "lib", + "amdhip": "lib", "hipblaslt": "lib", - "rocblas": "lib", - "opt/amdgpu/lib/x86_64-linux-gnu": "lib", + "lib/x86_64-linux-gnu": "lib", "libdrm-amdgpu-amdgpu1": "lib", - "amdhip_patched": "lib", + "libelf1": "lib", + "libpjrt_rocm_so": "lib", + "opt/amdgpu/lib/x86_64-linux-gnu": "lib", + "opt/amdgpu/share": "share", + "rocblas": "lib", + "runtimes/rocm": "lib", + "usr/lib/x86_64-linux-gnu": "lib", }, add_directory_to_runfiles = True, include_external_repositories = ["**"], diff --git a/runtimes/rocm/packages.lock.json b/runtimes/rocm/packages.lock.json index 42a3aaf..213b57c 100755 --- a/runtimes/rocm/packages.lock.json +++ b/runtimes/rocm/packages.lock.json @@ -1277,7 +1277,13 @@ }, { "arch": "amd64", - "dependencies": [], + "dependencies": [ + { + "key": "rocm-core_6.4.1.60401-83_22.04_amd64", + "name": "rocm-core", + "version": "6.4.1.60401-83~22.04" + } + ], "key": "roctracer_4.1.60401.60401-83_22.04_amd64", "name": "roctracer", "sha256": "58cead537cf07c8a8770bfe28346c3b3c92cc4297b51e307c9032b04434b187c", @@ -4153,6 +4159,49 @@ "https://repo.radeon.com/rocm/apt/6.4.1/pool/main/r/rocfft/rocfft_1.0.32.60401-83~22.04_amd64.deb" ], "version": "1.0.32.60401-83~22.04" + }, + { + "arch": "amd64", + "dependencies": [ + { + "key": "hipblaslt_0.12.1.60401-83_22.04_amd64", + "name": "hipblaslt", + "version": "0.12.1.60401-83~22.04" + }, + { + "key": "rocm-core_6.4.1.60401-83_22.04_amd64", + "name": "rocm-core", + "version": "6.4.1.60401-83~22.04" + }, + { + "key": "roctracer_4.1.60401.60401-83_22.04_amd64", + "name": "roctracer", + "version": "4.1.60401.60401-83~22.04" + }, + { + "key": "hipblas-common-dev_1.0.0.60401-83_22.04_amd64", + "name": 
"hipblas-common-dev", + "version": "1.0.0.60401-83~22.04" + } + ], + "key": "hipblaslt-dev_0.12.1.60401-83_22.04_amd64", + "name": "hipblaslt-dev", + "sha256": "46eb2285c76d246b162eb54cc7f9e5cb7bcdd0aa83d57ecaea440e57260f2f4a", + "urls": [ + "https://repo.radeon.com/rocm/apt/6.4.1/pool/main/h/hipblaslt-dev/hipblaslt-dev_0.12.1.60401-83~22.04_amd64.deb" + ], + "version": "0.12.1.60401-83~22.04" + }, + { + "arch": "amd64", + "dependencies": [], + "key": "hipblas-common-dev_1.0.0.60401-83_22.04_amd64", + "name": "hipblas-common-dev", + "sha256": "5df3e4a8a1959cbf94106f7bf87d7fb71bf06e726cde00c6092ef29bbd8156f0", + "urls": [ + "https://repo.radeon.com/rocm/apt/6.4.1/pool/main/h/hipblas-common-dev/hipblas-common-dev_1.0.0.60401-83~22.04_amd64.deb" + ], + "version": "1.0.0.60401-83~22.04" } ], "version": 1 diff --git a/runtimes/rocm/packages.yaml b/runtimes/rocm/packages.yaml index fd425f4..fcb6583 100644 --- a/runtimes/rocm/packages.yaml +++ b/runtimes/rocm/packages.yaml @@ -35,7 +35,7 @@ packages: - "rocsolver" - "hipsolver" - "hipfft" - # - "roctracer" + - "roctracer" - "hipblaslt" - # - "hipblaslt-dev" + - "hipblaslt-dev" - "hip-runtime-amd" diff --git a/runtimes/rocm/rocm.bzl b/runtimes/rocm/rocm.bzl index b38f69b..1d11da3 100644 --- a/runtimes/rocm/rocm.bzl +++ b/runtimes/rocm/rocm.bzl @@ -14,7 +14,7 @@ _UBUNTU_PACKAGES = { packages.load_("@zml//bazel:patchelf.bzl", "patchelf"), packages.patchelf( name = "libelf1", - shared_library = "usr/lib/x86_64-linux-gnu/libelf.so.1", + src = "usr/lib/x86_64-linux-gnu/libelf.so.1", set_rpath = "$ORIGIN", ), ]), @@ -25,8 +25,12 @@ _UBUNTU_PACKAGES = { packages.load_("@zml//bazel:patchelf.bzl", "patchelf"), packages.patchelf( name = "libdrm-amdgpu-amdgpu1", - shared_library = "opt/amdgpu/lib/x86_64-linux-gnu/libdrm_amdgpu.so.1", + src = "opt/amdgpu/lib/x86_64-linux-gnu/libdrm_amdgpu.so.1", + add_needed = ["libzmlxrocm.so.0"], set_rpath = "$ORIGIN", + rename_dynamic_symbols = { + "fopen64": "zmlxrocm_fopen64", + }, ), ]), 
"libtinfo6": packages.filegroup(name = "libtinfo6", srcs = ["lib/x86_64-linux-gnu/libtinfo.so.6"]), @@ -38,14 +42,7 @@ _ROCM_PACKAGES = { "rocm-smi-lib": packages.filegroup(name = "rocm_smi", srcs = ["lib/librocm_smi64.so.7"]), "hsa-rocr": packages.filegroup(name = "hsa-runtime", srcs = ["lib/libhsa-runtime64.so.1"]), "hsa-amd-aqlprofile": packages.filegroup(name = "hsa-amd-aqlprofile", srcs = ["lib/libhsa-amd-aqlprofile64.so.1"]), - "comgr": "\n".join([ - packages.filegroup( - name = "amd_comgr", - srcs = [ - "lib/libamd_comgr.so.3", - ], - ), - ]), + "comgr": packages.filegroup(name = "amd_comgr", srcs = ["lib/libamd_comgr.so.3"]), "rocprofiler-register": packages.filegroup(name = "rocprofiler-register", srcs = ["lib/librocprofiler-register.so.0"]), "miopen-hip": "\n".join([ packages.filegroup(name = "MIOpen", srcs = ["lib/libMIOpen.so.1"]), @@ -59,7 +56,7 @@ _ROCM_PACKAGES = { packages.load_("@zml//runtimes/rocm:gfx.bzl", "bytecode_select"), packages.patchelf( name = "rocblas", - shared_library = "lib/librocblas.so.4", + src = "lib/librocblas.so.4", add_needed = ["libzmlxrocm.so.0"], rename_dynamic_symbols = { "dlopen": "zmlxrocm_dlopen", @@ -90,7 +87,7 @@ _ROCM_PACKAGES = { packages.load_("@zml//runtimes/rocm:gfx.bzl", "bytecode_select"), packages.patchelf( name = "hipblaslt", - shared_library = "lib/libhipblaslt.so.0", + src = "lib/libhipblaslt.so.0", add_needed = ["libzmlxrocm.so.0"], rename_dynamic_symbols = { "dlopen": "zmlxrocm_dlopen", @@ -116,10 +113,9 @@ _ROCM_PACKAGES = { "hipfft": packages.filegroup(name = "hipfft", srcs = ["lib/libhipfft.so.0"]), "hip-runtime-amd": "\n".join([ packages.load_("@zml//bazel:patchelf.bzl", "patchelf"), - packages.filegroup(name = "amdhip", srcs = ["lib/libamdhip64.so.6"]), packages.patchelf( - name = "amdhip_patched", - shared_library = ":amdhip", + name = "amdhip", + src = "lib/libamdhip64.so.6", add_needed = ["libzmlxrocm.so.0"], rename_dynamic_symbols = { "dlopen": "zmlxrocm_dlopen", diff --git 
a/runtimes/rocm/rocm.zig b/runtimes/rocm/rocm.zig index 98be603..0861f92 100644 --- a/runtimes/rocm/rocm.zig +++ b/runtimes/rocm/rocm.zig @@ -1,5 +1,5 @@ -const builtin = @import("builtin"); const std = @import("std"); +const builtin = @import("builtin"); const asynk = @import("async"); const bazel_builtin = @import("bazel_builtin"); @@ -10,20 +10,6 @@ const stdx = @import("stdx"); const log = std.log.scoped(.@"zml/runtime/rocm"); -const ROCmEnvEntry = struct { - name: [:0]const u8, - rpath: []const u8, - dirname: bool, - mandatory: bool, -}; - -const rocm_env_entries: []const ROCmEnvEntry = &.{ - .{ .name = "HIPBLASLT_EXT_OP_LIBRARY_PATH", .rpath = "/lib/hipblaslt/library/hipblasltExtOpLibrary.dat", .dirname = false, .mandatory = false }, - .{ .name = "HIPBLASLT_TENSILE_LIBPATH", .rpath = "/lib/hipblaslt/library/TensileManifest.txt", .dirname = true, .mandatory = false }, - .{ .name = "ROCBLAS_TENSILE_LIBPATH", .rpath = "/lib/rocblas/library/TensileManifest.txt", .dirname = true, .mandatory = true }, - .{ .name = "ROCM_PATH", .rpath = "/", .dirname = false, .mandatory = true }, -}; - pub fn isEnabled() bool { return @hasDecl(c, "ZML_RUNTIME_ROCM"); } @@ -35,23 +21,9 @@ fn hasRocmDevices() bool { return true; } -fn setupRocmEnv(allocator: std.mem.Allocator, rocm_data_dir: []const u8) !void { - for (rocm_env_entries) |entry| { - var real_path: []const u8 = std.fmt.allocPrintZ(allocator, "{s}/{s}", .{ rocm_data_dir, entry.rpath }) catch null orelse { - if (entry.mandatory) { - stdx.debug.panic("Unable to find {s} in {s}\n", .{ entry.name, bazel_builtin.current_repository }); - } - continue; - }; - - if (entry.dirname) { - real_path = std.fs.path.dirname(real_path) orelse { - stdx.debug.panic("Unable to dirname on {s}", .{real_path}); - }; - } - - _ = c.setenv(entry.name, try allocator.dupeZ(u8, real_path), 1); - } +fn setupRocmEnv(rocm_data_dir: []const u8) !void { + var buf: [std.fs.max_path_bytes]u8 = undefined; + _ = c.setenv("ROCM_PATH", try 
const std = @import("std");

const stdx = @import("stdx");

/// `dlopen` interposer. ROCm libraries reference each other by unversioned
/// soname, but the sandbox ships only the versioned files next to this
/// shared object, so known names are rewritten to
/// "<self dir>/<versioned name>" before calling the real `dlopen`.
/// Unknown names, NULL, and paths that do not fit the buffer fall through
/// to `dlopen` unchanged.
pub export fn zmlxrocm_dlopen(filename: [*c]const u8, flags: c_int) ?*anyopaque {
    // Unversioned soname -> versioned soname bundled in the sandbox.
    const replacements: std.StaticStringMap([:0]const u8) = .initComptime(.{
        .{ "librocm-core.so", "librocm-core.so.1" },
        .{ "librocm_smi64.so", "librocm_smi64.so.7" },
        .{ "libhsa-runtime64.so", "libhsa-runtime64.so.1" },
        .{ "libhsa-amd-aqlprofile64.so", "libhsa-amd-aqlprofile64.so.1" },
        .{ "libamd_comgr.so", "libamd_comgr.so.3" },
        .{ "librocprofiler-register.so", "librocprofiler-register.so.0" },
        .{ "libMIOpen.so", "libMIOpen.so.1" },
        .{ "librccl.so", "librccl.so.1" },
        .{ "librocblas.so", "librocblas.so.4" },
        .{ "libroctracer64.so", "libroctracer64.so.4" },
        .{ "libroctx64.so", "libroctx64.so.4" },
        .{ "libhipblaslt.so", "libhipblaslt.so.0" },
        .{ "libamdhip64.so", "libamdhip64.so.6" },
        .{ "libhiprtc.so", "libhiprtc.so.6" },
    });

    var buf: [std.fs.max_path_bytes]u8 = undefined;
    const new_filename: [*c]const u8 = if (filename) |f| blk: {
        const replacement = replacements.get(std.fs.path.basename(std.mem.span(f))) orelse break :blk f;
        // If the joined path would overflow the buffer, fall back to the
        // caller's original name instead of `catch unreachable` (UB in
        // release builds); the real dlopen then fails or succeeds normally.
        const full_path = stdx.fs.path.bufJoinZ(&buf, &.{
            stdx.fs.selfSharedObjectDirPath(),
            replacement,
        }) catch break :blk f;
        break :blk full_path;
    } else null;

    return std.c.dlopen(new_filename, @bitCast(flags));
}

/// `fopen64` interposer: redirects libdrm's hard-coded
/// "/opt/amdgpu/share/libdrm/amdgpu.ids" lookup to the copy shipped
/// relative to this shared object. All other paths pass through unchanged.
pub export fn zmlxrocm_fopen64(pathname: [*c]const u8, mode: [*c]const u8) ?*std.c.FILE {
    const replacements: std.StaticStringMap([]const u8) = .initComptime(.{
        .{ "/opt/amdgpu/share/libdrm/amdgpu.ids", "../share/libdrm/amdgpu.ids" },
    });

    var buf: [std.fs.max_path_bytes]u8 = undefined;
    // Guard against a NULL pathname before `std.mem.span` — the previous
    // version dereferenced unconditionally, unlike its dlopen twin above.
    // Overflow of the join buffer likewise degrades to the original path.
    const new_pathname: [*c]const u8 = if (pathname) |p| blk: {
        const replacement = replacements.get(std.mem.span(p)) orelse break :blk p;
        const full_path = stdx.fs.path.bufJoinZ(&buf, &.{
            stdx.fs.selfSharedObjectDirPath(),
            replacement,
        }) catch break :blk p;
        break :blk full_path;
    } else null;

    return std.c.fopen64(new_pathname, mode);
}
const asynk = @import("async");
const std = @import("std");
const zml = @import("../zml.zig");

const StringBuilder = std.ArrayListUnmanaged(u8);
const Allocator = std.mem.Allocator;

/// Reads a JSON file and flattens it into the metadata map of a fresh
/// `zml.aio.BufferStore`, keyed by dotted paths ("a.b.0.c").
/// Caller owns the returned store; all parsed data lives in its arena.
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
    const file = try std.fs.cwd().openFile(path, .{});
    defer file.close();
    var res: zml.aio.BufferStore = .{
        .arena = std.heap.ArenaAllocator.init(allocator),
    };
    errdefer res.arena.deinit();
    const arena = res.arena.allocator();

    const json_data = try file.reader().readAllAlloc(arena, (try file.metadata()).size());
    // `parseFromSliceLeaky` never frees what it allocates, so it must be fed
    // the store's arena; feeding it the caller's allocator (as before) leaked
    // every parsed node.
    const metadata = try std.json.parseFromSliceLeaky(std.json.Value, arena, json_data, .{ .allocate = .alloc_if_needed });

    var it = metadata.object.iterator();
    while (it.next()) |entry| {
        var prefix_buf: [1024]u8 = undefined;
        var prefix = StringBuilder.initBuffer(&prefix_buf);
        // Seed the prefix with the top-level key. Previously the key was
        // dropped, so {"a": {"b": 1}} was stored under "b" instead of "a.b"
        // and sibling top-level objects could overwrite each other.
        prefix.appendSliceAssumeCapacity(entry.key_ptr.*);
        // Allocate the stored copies from the arena so they live exactly as
        // long as the store.
        try parseMetadata(arena, &res, prefix, entry.value_ptr.*);
    }

    return res;
}

/// Recursively flattens `val` into `store._metadata` under the dotted `prefix`
/// key. Homogeneous scalar arrays become typed slices; heterogeneous arrays
/// recurse with "<prefix>.<index>" keys and objects with "<prefix>.<key>".
/// All stored keys and strings are copied with `allocator`.
pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, val: std.json.Value) !void {
    const metadata = &store._metadata;
    const key = prefix.items;
    return switch (val) {
        .null => try metadata.put(allocator, try allocator.dupe(u8, key), .null),
        .bool => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .bool = v }),
        .integer => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .int = v }),
        .float => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .float = v }),
        .number_string, .string => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .string = try allocator.dupe(u8, v) }),
        .array => |v| {
            if (v.items.len == 0) return;
            return if (validSlice(v)) |item_type| {
                const data: zml.aio.Metadata = switch (item_type) {
                    .bool => blk: {
                        const values = try allocator.alloc(bool, v.items.len);
                        for (v.items, 0..) |item, i| values[i] = item.bool;
                        break :blk .{ .array_bool = values };
                    },
                    .integer => blk: {
                        const values = try allocator.alloc(i64, v.items.len);
                        for (v.items, 0..) |item, i| values[i] = item.integer;
                        break :blk .{ .array_int = values };
                    },
                    .float => blk: {
                        const values = try allocator.alloc(f64, v.items.len);
                        for (v.items, 0..) |item, i| values[i] = item.float;
                        break :blk .{ .array_float = values };
                    },
                    inline .string, .number_string => |tag| blk: {
                        const values = try allocator.alloc([]const u8, v.items.len);
                        for (v.items, 0..) |item, i| {
                            // Copy each string so stored metadata does not
                            // alias the (possibly temporary) JSON buffer —
                            // consistent with the scalar string case above,
                            // which already dupes.
                            values[i] = try allocator.dupe(u8, @field(item, @tagName(tag)));
                        }
                        break :blk .{ .array_string = values };
                    },
                    .null, .array, .object => unreachable, // excluded by validSlice
                };
                try metadata.put(allocator, try allocator.dupe(u8, key), data);
            } else {
                // Heterogeneous array: recurse per element with a numeric
                // path component appended to the prefix.
                for (v.items, 0..) |item, i| {
                    var new_prefix = prefix;
                    if (prefix.items.len > 0)
                        new_prefix.appendAssumeCapacity('.');
                    new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
                    try parseMetadata(allocator, store, new_prefix, item);
                }
            };
        },
        .object => |v| {
            var obj_iter = v.iterator();
            while (obj_iter.next()) |entry| {
                var new_prefix = prefix;
                if (prefix.items.len > 0)
                    new_prefix.appendAssumeCapacity('.');
                new_prefix.appendSliceAssumeCapacity(entry.key_ptr.*);
                try parseMetadata(allocator, store, new_prefix, entry.value_ptr.*);
            }
        },
    };
}
/// Returns the common scalar tag of `v`'s items when every element carries
/// the same scalar JSON type (bool/integer/float/string/number_string), or
/// null for empty arrays, non-scalar elements, or mixed types — i.e. whether
/// the array can be stored as a single homogeneous Zig slice.
fn validSlice(v: std.json.Array) ?std.meta.Tag(std.json.Value) {
    const items = v.items;
    if (items.len == 0) return null;

    // A tagged union coerces to its tag enum, yielding the first item's type.
    const first: std.meta.Tag(std.json.Value) = items[0];
    return switch (first) {
        // Aggregates and null can never form a typed slice.
        .null, .array, .object => null,
        // Scalar: accept only if every remaining item carries the same tag.
        else => for (items[1..]) |item| {
            if (item != first) break null;
        } else first,
    };
}
const full_filename = try std.fs.path.join(allocator, &.{ std.fs.path.dirname(path).?, filename }); try loadFile(allocator, store, files, full_filename); } + + if (index.object.get("__metadata__")) |metadata| { + var prefix_buf: [1024]u8 = undefined; + try json.parseMetadata(allocator, store, StringBuilder.initBuffer(&prefix_buf), metadata); + } } fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void { @@ -85,6 +89,11 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.Array var it = metadata.object.iterator(); while (it.next()) |entry| { const key = entry.key_ptr.*; + if (std.mem.eql(u8, key, "__metadata__")) { + var prefix_buf: [1024]u8 = undefined; + try json.parseMetadata(allocator, store, StringBuilder.initBuffer(&prefix_buf), entry.value_ptr.*); + continue; + } const val = entry.value_ptr.*; const shape_field = val.object.get("shape").?.array; if (shape_field.items.len > zml.Shape.MAX_RANK) {