Add XLA 20250718.0‑6319f0d with ROCm 6.4.1 support, update Bazel module files and runtime configs, and apply migration, FFI‑handler and header‑cleanup patches.

This commit is contained in:
Tarry Singh 2025-05-12 12:10:27 +00:00
parent cba9ce9615
commit 55c5b540f8
18 changed files with 1405 additions and 938 deletions

View File

@ -22,7 +22,7 @@ bazel_dep(name = "sentencepiece", version = "20240618.0-d7ace0a")
bazel_dep(name = "toolchains_llvm_bootstrapped", version = "0.2.4") bazel_dep(name = "toolchains_llvm_bootstrapped", version = "0.2.4")
bazel_dep(name = "toolchains_protoc", version = "0.4.1") bazel_dep(name = "toolchains_protoc", version = "0.4.1")
bazel_dep(name = "with_cfg.bzl", version = "0.9.1") bazel_dep(name = "with_cfg.bzl", version = "0.9.1")
bazel_dep(name = "xla", version = "20250710.0-22ea002") bazel_dep(name = "xla", version = "20250718.0-6319f0d")
bazel_dep(name = "zig-protobuf", version = "20250716.0-97f1e31") bazel_dep(name = "zig-protobuf", version = "20250716.0-97f1e31")
bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf") bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf")

View File

@ -23,22 +23,22 @@ def _cpu_pjrt_plugin_impl(mctx):
http_archive( http_archive(
name = "libpjrt_cpu_linux_amd64", name = "libpjrt_cpu_linux_amd64",
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_LINUX, build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_LINUX,
sha256 = "3369fa7a1a1bb5998b818e1fb5f2c28966a59f6096eab500ef2d8419548a1c91", sha256 = "cf5ea44b14a6ddc320c5b1d7bb88328dda099571e106b17a214b6ec586e321b8",
url = "https://github.com/zml/pjrt-artifacts/releases/download/v11.0.0/pjrt-cpu_linux-amd64.tar.gz", url = "https://github.com/zml/pjrt-artifacts/releases/download/v12.0.0/pjrt-cpu_linux-amd64.tar.gz",
) )
http_archive( http_archive(
name = "libpjrt_cpu_darwin_amd64", name = "libpjrt_cpu_darwin_amd64",
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN, build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN,
sha256 = "9947382613d30eb757dfb1bfcad0536ec9dad1e11b1189d1172abbce434b69bb", sha256 = "1b48c26b0d2709730df6921a86d9ae97ab71d3f2cfa17fd1a370e85235370914",
url = "https://github.com/zml/pjrt-artifacts/releases/download/v11.0.0/pjrt-cpu_darwin-amd64.tar.gz", url = "https://github.com/zml/pjrt-artifacts/releases/download/v12.0.0/pjrt-cpu_darwin-amd64.tar.gz",
) )
http_archive( http_archive(
name = "libpjrt_cpu_darwin_arm64", name = "libpjrt_cpu_darwin_arm64",
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN, build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN,
sha256 = "fe3818455b034c9ffbd65dec559c04c2211a200a9b4d7feec8a00d6a3ffd0acd", sha256 = "66dc6c65933a6d7985b1c69837d8abe13750cd1f06704427f6989c3f952d3511",
url = "https://github.com/zml/pjrt-artifacts/releases/download/v11.0.0/pjrt-cpu_darwin-arm64.tar.gz", url = "https://github.com/zml/pjrt-artifacts/releases/download/v12.0.0/pjrt-cpu_darwin-arm64.tar.gz",
) )
return mctx.extension_metadata( return mctx.extension_metadata(

View File

@ -1,7 +1,6 @@
load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory") load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory")
load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar") load("@aspect_bazel_lib//lib:tar.bzl", "tar")
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "string_list_flag") load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "string_list_flag")
load("@zml//bazel:cc_import.bzl", "cc_import")
load("@zml//bazel:patchelf.bzl", "patchelf") load("@zml//bazel:patchelf.bzl", "patchelf")
string_list_flag( string_list_flag(
@ -53,7 +52,7 @@ copy_to_directory(
":zmlxrocm_so", ":zmlxrocm_so",
":libpjrt_rocm.patchelf", ":libpjrt_rocm.patchelf",
"@comgr//:amd_comgr", "@comgr//:amd_comgr",
"@hip-runtime-amd//:amdhip", "@hip-runtime-amd//:amdhip_patched",
"@hip-runtime-amd//:hiprtc", "@hip-runtime-amd//:hiprtc",
"@hipblaslt//:hipblaslt", "@hipblaslt//:hipblaslt",
"@hipfft", "@hipfft",
@ -92,6 +91,7 @@ copy_to_directory(
"rocblas": "lib", "rocblas": "lib",
"opt/amdgpu/lib/x86_64-linux-gnu": "lib", "opt/amdgpu/lib/x86_64-linux-gnu": "lib",
"libdrm-amdgpu-amdgpu1": "lib", "libdrm-amdgpu-amdgpu1": "lib",
"amdhip_patched": "lib",
}, },
add_directory_to_runfiles = True, add_directory_to_runfiles = True,
include_external_repositories = ["**"], include_external_repositories = ["**"],

File diff suppressed because it is too large Load Diff

View File

@ -5,9 +5,9 @@ version: 1
sources: sources:
- channel: jammy main - channel: jammy main
url: https://repo.radeon.com/amdgpu/6.3.4/ubuntu url: https://repo.radeon.com/amdgpu/6.4.1/ubuntu
- channel: jammy main - channel: jammy main
url: https://repo.radeon.com/rocm/apt/6.3.4 url: https://repo.radeon.com/rocm/apt/6.4.1
- channel: jammy main - channel: jammy main
url: https://snapshot.ubuntu.com/ubuntu/20250711T030400Z url: https://snapshot.ubuntu.com/ubuntu/20250711T030400Z
- channel: jammy-security main - channel: jammy-security main

View File

@ -6,7 +6,7 @@ _BUILD_FILE_DEFAULT_VISIBILITY = """\
package(default_visibility = ["//visibility:public"]) package(default_visibility = ["//visibility:public"])
""" """
_ROCM_STRIP_PREFIX = "opt/rocm-6.3.4" _ROCM_STRIP_PREFIX = "opt/rocm-6.4.1"
_UBUNTU_PACKAGES = { _UBUNTU_PACKAGES = {
"libdrm2-amdgpu": packages.filegroup(name = "libdrm2-amdgpu", srcs = ["opt/amdgpu/lib/x86_64-linux-gnu/libdrm.so.2"]), "libdrm2-amdgpu": packages.filegroup(name = "libdrm2-amdgpu", srcs = ["opt/amdgpu/lib/x86_64-linux-gnu/libdrm.so.2"]),
@ -15,7 +15,7 @@ _UBUNTU_PACKAGES = {
packages.patchelf( packages.patchelf(
name = "libelf1", name = "libelf1",
shared_library = "usr/lib/x86_64-linux-gnu/libelf.so.1", shared_library = "usr/lib/x86_64-linux-gnu/libelf.so.1",
set_rpath = '$ORIGIN', set_rpath = "$ORIGIN",
), ),
]), ]),
"libdrm-amdgpu-common": packages.filegroup(name = "amdgpu_ids", srcs = ["opt/amdgpu/share/libdrm/amdgpu.ids"]), "libdrm-amdgpu-common": packages.filegroup(name = "amdgpu_ids", srcs = ["opt/amdgpu/share/libdrm/amdgpu.ids"]),
@ -26,7 +26,7 @@ _UBUNTU_PACKAGES = {
packages.patchelf( packages.patchelf(
name = "libdrm-amdgpu-amdgpu1", name = "libdrm-amdgpu-amdgpu1",
shared_library = "opt/amdgpu/lib/x86_64-linux-gnu/libdrm_amdgpu.so.1", shared_library = "opt/amdgpu/lib/x86_64-linux-gnu/libdrm_amdgpu.so.1",
set_rpath = '$ORIGIN', set_rpath = "$ORIGIN",
), ),
]), ]),
"libtinfo6": packages.filegroup(name = "libtinfo6", srcs = ["lib/x86_64-linux-gnu/libtinfo.so.6"]), "libtinfo6": packages.filegroup(name = "libtinfo6", srcs = ["lib/x86_64-linux-gnu/libtinfo.so.6"]),
@ -38,7 +38,14 @@ _ROCM_PACKAGES = {
"rocm-smi-lib": packages.filegroup(name = "rocm_smi", srcs = ["lib/librocm_smi64.so.7"]), "rocm-smi-lib": packages.filegroup(name = "rocm_smi", srcs = ["lib/librocm_smi64.so.7"]),
"hsa-rocr": packages.filegroup(name = "hsa-runtime", srcs = ["lib/libhsa-runtime64.so.1"]), "hsa-rocr": packages.filegroup(name = "hsa-runtime", srcs = ["lib/libhsa-runtime64.so.1"]),
"hsa-amd-aqlprofile": packages.filegroup(name = "hsa-amd-aqlprofile", srcs = ["lib/libhsa-amd-aqlprofile64.so.1"]), "hsa-amd-aqlprofile": packages.filegroup(name = "hsa-amd-aqlprofile", srcs = ["lib/libhsa-amd-aqlprofile64.so.1"]),
"comgr": packages.filegroup(name = "amd_comgr", srcs = ["lib/libamd_comgr.so.2"]), "comgr": "\n".join([
packages.filegroup(
name = "amd_comgr",
srcs = [
"lib/libamd_comgr.so.3",
],
),
]),
"rocprofiler-register": packages.filegroup(name = "rocprofiler-register", srcs = ["lib/librocprofiler-register.so.0"]), "rocprofiler-register": packages.filegroup(name = "rocprofiler-register", srcs = ["lib/librocprofiler-register.so.0"]),
"miopen-hip": "\n".join([ "miopen-hip": "\n".join([
packages.filegroup(name = "MIOpen", srcs = ["lib/libMIOpen.so.1"]), packages.filegroup(name = "MIOpen", srcs = ["lib/libMIOpen.so.1"]),
@ -102,14 +109,22 @@ _ROCM_PACKAGES = {
name = "runfiles", name = "runfiles",
srcs = [ srcs = [
"lib/hipblaslt/library/hipblasltExtOpLibrary.dat", "lib/hipblaslt/library/hipblasltExtOpLibrary.dat",
"lib/hipblaslt/library/TensileManifest.txt",
":bytecodes", ":bytecodes",
], ],
), ),
]), ]),
"hipfft": packages.filegroup(name = "hipfft", srcs = ["lib/libhipfft.so.0"]), "hipfft": packages.filegroup(name = "hipfft", srcs = ["lib/libhipfft.so.0"]),
"hip-runtime-amd": "\n".join([ "hip-runtime-amd": "\n".join([
packages.load_("@zml//bazel:patchelf.bzl", "patchelf"),
packages.filegroup(name = "amdhip", srcs = ["lib/libamdhip64.so.6"]), packages.filegroup(name = "amdhip", srcs = ["lib/libamdhip64.so.6"]),
packages.patchelf(
name = "amdhip_patched",
shared_library = ":amdhip",
add_needed = ["libzmlxrocm.so.0"],
rename_dynamic_symbols = {
"dlopen": "zmlxrocm_dlopen",
},
),
packages.filegroup(name = "hiprtc", srcs = ["lib/libhiprtc.so.6"]), packages.filegroup(name = "hiprtc", srcs = ["lib/libhiprtc.so.6"]),
]), ]),
"hipsolver": packages.filegroup(name = "hipsolver", srcs = ["lib/libhipsolver.so.0"]), "hipsolver": packages.filegroup(name = "hipsolver", srcs = ["lib/libhipsolver.so.0"]),
@ -142,8 +157,8 @@ def _rocm_impl(mctx):
http_archive( http_archive(
name = "libpjrt_rocm", name = "libpjrt_rocm",
build_file = "libpjrt_rocm.BUILD.bazel", build_file = "libpjrt_rocm.BUILD.bazel",
url = "https://github.com/zml/pjrt-artifacts/releases/download/v11.0.0/pjrt-rocm_linux-amd64.tar.gz", url = "https://github.com/zml/pjrt-artifacts/releases/download/v12.0.0/pjrt-rocm_linux-amd64.tar.gz",
sha256 = "a6d8ef38ae4deda244856549271a1b1a6f46499e9efb64fb71a12fd6ae792d3b", sha256 = "709982b959595750545a01d125adf4893c42f05c60ec290425276bba8aa49f64",
) )
return mctx.extension_metadata( return mctx.extension_metadata(

View File

@ -17,7 +17,7 @@ void *zmlxrocm_dlopen(const char *filename, int flags) __attribute__((visibility
"libhsa-amd-aqlprofile64.so", "libhsa-amd-aqlprofile64.so",
"libhsa-amd-aqlprofile64.so.1", "libhsa-amd-aqlprofile64.so.1",
"libamd_comgr.so", "libamd_comgr.so",
"libamd_comgr.so.2", "libamd_comgr.so.3",
"librocprofiler-register.so", "librocprofiler-register.so",
"librocprofiler-register.so.0", "librocprofiler-register.so.0",
"libMIOpen.so", "libMIOpen.so",

View File

@ -0,0 +1,58 @@
module(
name = "xla",
version = "20250718.0-6319f0d",
compatibility_level = 1,
)
bazel_dep(name = "platforms", version = "0.0.8")
bazel_dep(name = "bazel_skylib", version = "1.5.0")
bazel_dep(name = "rules_cc", version = "0.0.17")
bazel_dep(name = "rules_apple", version = "3.22.0", repo_name = "build_bazel_rules_apple")
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
bazel_dep(name = "rules_python", version = "0.39.0")
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
bazel_dep(name = "rules_java", version = "7.3.2")
bazel_dep(name = "rules_pkg", version = "0.9.1")
bazel_dep(name = "zlib", version = "1.2.13")
bazel_dep(name = "re2", version = "2024-07-02.bcr.1", repo_name = "com_googlesource_code_re2")
bazel_dep(name = "rules_license", version = "0.0.8")
bazel_dep(name = "rules_shell", version = "0.4.1")
bazel_dep(name = "bazel_features", version = "1.25.0", repo_name = "proto_bazel_features")
toolchains_private = use_extension("//:toolchains_private.bzl", "toolchains_private")
use_repo(
toolchains_private,
"rules_ml_toolchain",
)
workspace_private = use_extension("//:workspace_private.bzl", "workspace_private")
use_repo(
workspace_private,
"com_github_grpc_grpc",
"com_google_protobuf",
"local_config_cuda",
"local_config_remote_execution",
"local_config_rocm",
"local_config_tensorrt",
"python_version_repo",
"tsl",
)
workspace_public = use_extension("//:xla.bzl", "xla")
use_repo(
workspace_public,
"llvm-raw",
"stablehlo",
"triton",
)
llvm = use_extension("//:llvm.bzl", "llvm")
llvm.configure(
targets = [
"AArch64",
"AMDGPU",
"NVPTX",
"X86",
],
)
use_repo(llvm, "llvm-project")

View File

@ -0,0 +1,58 @@
module(
name = "xla",
version = "20250718.0-6319f0d",
compatibility_level = 1,
)
bazel_dep(name = "platforms", version = "0.0.8")
bazel_dep(name = "bazel_skylib", version = "1.5.0")
bazel_dep(name = "rules_cc", version = "0.0.17")
bazel_dep(name = "rules_apple", version = "3.22.0", repo_name = "build_bazel_rules_apple")
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
bazel_dep(name = "rules_python", version = "0.39.0")
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
bazel_dep(name = "rules_java", version = "7.3.2")
bazel_dep(name = "rules_pkg", version = "0.9.1")
bazel_dep(name = "zlib", version = "1.2.13")
bazel_dep(name = "re2", version = "2024-07-02.bcr.1", repo_name = "com_googlesource_code_re2")
bazel_dep(name = "rules_license", version = "0.0.8")
bazel_dep(name = "rules_shell", version = "0.4.1")
bazel_dep(name = "bazel_features", version = "1.25.0", repo_name = "proto_bazel_features")
toolchains_private = use_extension("//:toolchains_private.bzl", "toolchains_private")
use_repo(
toolchains_private,
"rules_ml_toolchain",
)
workspace_private = use_extension("//:workspace_private.bzl", "workspace_private")
use_repo(
workspace_private,
"com_github_grpc_grpc",
"com_google_protobuf",
"local_config_cuda",
"local_config_remote_execution",
"local_config_rocm",
"local_config_tensorrt",
"python_version_repo",
"tsl",
)
workspace_public = use_extension("//:xla.bzl", "xla")
use_repo(
workspace_public,
"llvm-raw",
"stablehlo",
"triton",
)
llvm = use_extension("//:llvm.bzl", "llvm")
llvm.configure(
targets = [
"AArch64",
"AMDGPU",
"NVPTX",
"X86",
],
)
use_repo(llvm, "llvm-project")

View File

@ -0,0 +1,30 @@
load("@llvm-raw//utils/bazel:configure.bzl", _llvm_configure = "llvm_configure")
def _llvm_impl(mctx):
_targets = {}
for mod in mctx.modules:
for conf in mod.tags.configure:
for target in conf.targets:
_targets[target] = True
_llvm_configure(
name = "llvm-project",
targets = _targets.keys(),
)
return mctx.extension_metadata(
reproducible = True,
root_module_direct_deps = "all",
root_module_direct_dev_deps = [],
)
llvm = module_extension(
implementation = _llvm_impl,
tag_classes = {
"configure": tag_class(
attrs = {
"targets": attr.string_list(
default = [],
),
},
),
},
)

View File

@ -0,0 +1,21 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
def _toolchains_private_impl(mctx):
http_archive(
name = "rules_ml_toolchain",
sha256 = "fb78d09234528aef2be856820b69b76486829f65e4eb3c7ffaa5803b667fa441",
strip_prefix = "rules_ml_toolchain-f4ad89fa906be2c1374785a79335c8a7dcd49df7",
urls = [
"https://github.com/zml/rules_ml_toolchain/archive/f4ad89fa906be2c1374785a79335c8a7dcd49df7.tar.gz",
],
)
return mctx.extension_metadata(
reproducible = True,
root_module_direct_deps = "all",
root_module_direct_dev_deps = [],
)
toolchains_private = module_extension(
implementation = _toolchains_private_impl,
)

View File

@ -0,0 +1,73 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls", "tf_vendored")
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
load("//third_party/gpus:rocm_configure.bzl", "rocm_configure")
load("//third_party/py:python_repo.bzl", "python_repository")
load("//third_party/pybind11_bazel:workspace.bzl", pybind11_bazel = "repo")
load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
load("//tools/toolchains/remote:configure.bzl", "remote_execution_configure")
def _workspace_private_impl(mctx):
http_archive(
name = "rules_ml_toolchain",
sha256 = "fb78d09234528aef2be856820b69b76486829f65e4eb3c7ffaa5803b667fa441",
strip_prefix = "rules_ml_toolchain-f4ad89fa906be2c1374785a79335c8a7dcd49df7",
urls = [
"https://github.com/zml/rules_ml_toolchain/archive/f4ad89fa906be2c1374785a79335c8a7dcd49df7.tar.gz",
],
)
# Use cuda_configure from XLA to make it work with bzlmod.
# A pure bzlmod solution for rules_ml_toolchain is impossible because of the legacy design.
# It relies on a "generate-then-load" pattern that creates a deadlock in Bazel's architecture:
# - Generate: First, it runs a rule to generate a .bzl file containing configuration data.
# - Load: Then, it requires a load() statement to load that same file to continue the setup.
# This fails in bzlmod because Bazel's Loading Phase (when load() statements are processed) happens before
# the Analysis Phase (when repository rules are run).
# This creates a fundamental chicken-and-egg problem: the build tries to load a file that has not been generated yet.
# Without using the official WORKSPACE.bzlmod escape hatch,
# this incompatibility cannot be resolved without modifying the upstream rules.
cuda_configure(name = "local_config_cuda")
remote_execution_configure(name = "local_config_remote_execution")
rocm_configure(name = "local_config_rocm")
tensorrt_configure(name = "local_config_tensorrt")
tf_vendored(name = "tsl", relpath = "third_party/tsl")
pybind11_bazel()
tf_http_archive(
name = "com_github_grpc_grpc",
sha256 = "afbc5d78d6ba6d509cc6e264de0d49dcd7304db435cbf2d630385bacf49e066c",
strip_prefix = "grpc-1.68.2",
patch_file = [
"//third_party/grpc:grpc.patch",
],
urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/refs/tags/v1.68.2.tar.gz"),
)
tf_http_archive(
name = "com_google_protobuf",
patch_file = ["//third_party/protobuf:protobuf.patch"],
sha256 = "f645e6e42745ce922ca5388b1883ca583bafe4366cc74cf35c3c9299005136e2",
strip_prefix = "protobuf-5.28.3",
urls = tf_mirror_urls("https://github.com/protocolbuffers/protobuf/archive/refs/tags/v5.28.3.zip"),
)
python_repository(
name = "python_version_repo",
requirements_versions = ["3.11"],
requirements_locks = ["//:requirements_lock_3_11.txt"],
local_wheel_workspaces = [],
local_wheel_dist_folder = None,
default_python_version = None,
local_wheel_inclusion_list = ["*"],
local_wheel_exclusion_list = [],
)
return mctx.extension_metadata(
reproducible = True,
root_module_direct_deps = "all",
root_module_direct_dev_deps = [],
)
workspace_private = module_extension(
implementation = _workspace_private_impl,
)

View File

@ -0,0 +1,17 @@
load("//third_party/llvm:workspace.bzl", llvm = "repo")
load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo")
load("//third_party/triton:workspace.bzl", triton = "repo")
def _xla_impl(mctx):
triton()
llvm("llvm-raw")
stablehlo()
return mctx.extension_metadata(
reproducible = True,
root_module_direct_deps = "all",
root_module_direct_dev_deps = [],
)
xla = module_extension(
implementation = _xla_impl,
)

View File

@ -0,0 +1,41 @@
From 6cf475b500521c1b8be06f590fdbc1818f0dc44b Mon Sep 17 00:00:00 2001
From: Jean-Baptiste Dalido <jb@zml.ai>
Date: Mon, 6 Jan 2025 13:33:13 +0100
Subject: [PATCH] bazel: migration to bazel 8.0.1
---
.bazelversion | 2 +-
third_party/tsl/third_party/gpus/cuda_configure.bzl | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/.bazelversion b/.bazelversion
index f22d756da3..fa5fce04b3 100644
--- a/.bazelversion
+++ b/.bazelversion
@@ -1 +1 @@
-7.4.1
+8.1.1
\ No newline at end of file
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index d62531152d..71d80a5a99 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -33,14 +33,14 @@ NB: DEPRECATED! Use `hermetic/cuda_configure` rule instead.
load(
"@bazel_tools//tools/cpp:lib_cc_configure.bzl",
"escape_string",
- "get_env_var",
)
load(
"@bazel_tools//tools/cpp:windows_cc_configure.bzl",
- "find_msvc_tool",
"find_vc_path",
"setup_vc_env_vars",
)
+load("@rules_cc//cc/private/toolchain:windows_cc_configure.bzl", "find_msvc_tool")
+load("@rules_cc//cc/private/toolchain:lib_cc_configure.bzl", "get_env_var")
load("//third_party/clang_toolchain:download_clang.bzl", "download_clang")
load(
"//third_party/remote_config:common.bzl",
--
2.39.3 (Apple Git-146)

View File

@ -0,0 +1,135 @@
From 2ae9bb9d24b569c2c6bfab3c54b428103614944d Mon Sep 17 00:00:00 2001
From: Hugo Mano <hugo@zml.ai>
Date: Tue, 27 May 2025 11:48:17 +0200
Subject: [PATCH 1/8] Added FFI handler registration API to the FFI PjRt
PR: https://github.com/openxla/xla/pull/13420
---
xla/pjrt/c/BUILD | 5 +++++
xla/pjrt/c/pjrt_c_api_ffi_extension.h | 21 ++++++++++++++++++
xla/pjrt/c/pjrt_c_api_ffi_internal.cc | 32 ++++++++++++++++++++++++++-
3 files changed, 57 insertions(+), 1 deletion(-)
diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD
index 79f18fa0bc..0f33dd8a6e 100644
--- a/xla/pjrt/c/BUILD
+++ b/xla/pjrt/c/BUILD
@@ -69,8 +69,13 @@ cc_library(
":pjrt_c_api_wrapper_impl",
"//xla/ffi:execution_context",
"//xla/ffi:type_id_registry",
+ "//xla/ffi:ffi_api",
+ "//xla/ffi/api:c_api",
+ "//xla/ffi/api:ffi",
+ "//xla/service:custom_call_target_registry",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
+ "@com_google_absl//absl/strings:str_format",
],
)
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
index 995a2c7e50..b8f10bc2f7 100644
--- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h
+++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
@@ -69,10 +69,31 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data);
// Adds a user data to the execute context.
typedef PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args);
+typedef enum PJRT_FFI_Handler_TraitsBits {
+ PJRT_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE = 1u << 0,
+} PJRT_FFI_Handler_TraitsBits;
+
+struct PJRT_FFI_Register_Handler_Args {
+ size_t struct_size;
+ const char* target_name;
+ size_t target_name_size;
+ int api_version; // 0 for an untyped call, 1 -- for typed
+ void* handler;
+ const char* platform_name;
+ size_t platform_name_size;
+ PJRT_FFI_Handler_TraitsBits traits;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Register_Handler_Args, traits);
+
+// Registers an FFI call handler for a specific platform.
+typedef PJRT_Error* PJRT_FFI_Register_Handler(
+ PJRT_FFI_Register_Handler_Args* args);
+
typedef struct PJRT_FFI_Extension {
PJRT_Extension_Base base;
PJRT_FFI_TypeID_Register* type_id_register;
PJRT_FFI_UserData_Add* user_data_add;
+ PJRT_FFI_Register_Handler* register_handler;
} PJRT_FFI;
PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, user_data_add);
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
index 5fa88eab33..763270331b 100644
--- a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
+++ b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
@@ -13,16 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "xla/pjrt/c/pjrt_c_api_ffi_internal.h"
+#include <string>
#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
+#include "xla/ffi/api/c_api.h"
#include "xla/ffi/execution_context.h"
#include "xla/ffi/type_id_registry.h"
+#include "xla/ffi/ffi_api.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
+#include "xla/service/custom_call_target_registry.h"
namespace pjrt {
@@ -68,6 +72,31 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) {
return nullptr;
}
+static PJRT_Error* PJRT_FFI_Register_Handler(
+ PJRT_FFI_Register_Handler_Args* args) {
+ PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
+ "PJRT_FFI_Register_Handler_Args",
+ PJRT_FFI_Register_Handler_Args_STRUCT_SIZE, args->struct_size));
+ std::string target_name(args->target_name, args->target_name_size);
+ std::string platform_name(args->platform_name, args->platform_name_size);
+ switch (args->api_version) {
+ case 0:
+ xla::CustomCallTargetRegistry::Global()->Register(
+ target_name, args->handler, platform_name);
+ return nullptr;
+ case 1:
+ xla::ffi::Ffi::RegisterStaticHandler(
+ xla::ffi::GetXlaFfiApi(), target_name, platform_name,
+ reinterpret_cast<XLA_FFI_Handler*>(args->handler));
+ return nullptr;
+ default:
+ return new PJRT_Error{absl::UnimplementedError(
+ absl::StrFormat("API version %d not supported for PJRT GPU plugin. "
+ "Supported versions are 0 and 1.",
+ args->api_version))};
+ }
+}
+
PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
return {
PJRT_Extension_Base{
@@ -77,6 +106,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
},
/*type_id_register=*/PJRT_FFI_TypeID_Register,
/*user_data_add=*/PJRT_FFI_UserData_Add,
+ /*register_handler=*/PJRT_FFI_Register_Handler,
};
}
--
2.39.5 (Apple Git-154)

View File

@ -0,0 +1,124 @@
From 6078da86a46b6f0d983dccb9ae4f36fc90640247 Mon Sep 17 00:00:00 2001
From: Hugo Mano <hugo@zml.ai>
Date: Fri, 11 Jul 2025 14:05:16 +0200
Subject: [PATCH] zml patch
---
third_party/stablehlo/workspace.bzl | 1 +
third_party/stablehlo/zml.patch | 93 +++++++++++++++++++++++++++++
2 files changed, 94 insertions(+)
create mode 100644 third_party/stablehlo/zml.patch
diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl
index d9d5063744..44980948d0 100644
--- a/third_party/stablehlo/workspace.bzl
+++ b/third_party/stablehlo/workspace.bzl
@@ -15,5 +15,6 @@ def repo():
urls = tf_mirror_urls("https://github.com/openxla/stablehlo/archive/{commit}.zip".format(commit = STABLEHLO_COMMIT)),
patch_file = [
"//third_party/stablehlo:temporary.patch", # Autogenerated, don't remove.
+ "//third_party/stablehlo:zml.patch", # Autogenerated, don't remove.
],
)
diff --git a/third_party/stablehlo/zml.patch b/third_party/stablehlo/zml.patch
new file mode 100644
index 0000000000..2a09384582
--- /dev/null
+++ b/third_party/stablehlo/zml.patch
@@ -0,0 +1,93 @@
+From e38ab68376dd8a17ebf4469d2c8350f521310182 Mon Sep 17 00:00:00 2001
+From: Hugo Mano <hugo@zml.ai>
+Date: Fri, 11 Jul 2025 12:08:35 +0200
+Subject: [PATCH] zml patch
+
+---
+ stablehlo/dialect/Serialization.cpp | 5 ++---
+ stablehlo/dialect/Serialization.h | 3 +--
+ stablehlo/integrations/c/StablehloDialectApi.cpp | 3 +--
+ stablehlo/integrations/c/StablehloDialectApi.h | 2 +-
+ stablehlo/tools/StablehloTranslateMain.cpp | 2 +-
+ 5 files changed, 6 insertions(+), 9 deletions(-)
+
+diff --git a/stablehlo/dialect/Serialization.cpp b/stablehlo/dialect/Serialization.cpp
+index cb89d673..4370d588 100644
+--- a/stablehlo/dialect/Serialization.cpp
++++ b/stablehlo/dialect/Serialization.cpp
+@@ -39,8 +39,7 @@ namespace stablehlo {
+
+ LogicalResult serializePortableArtifact(ModuleOp module,
+ StringRef targetVersion,
+- raw_ostream& os,
+- bool allowOtherDialects) {
++ raw_ostream& os) {
+ MLIRContext* context = module.getContext();
+
+ // Convert StableHLO --> VHLO.
+@@ -49,7 +48,7 @@ LogicalResult serializePortableArtifact(ModuleOp module,
+ {
+ PassManager pm(context);
+ StablehloLegalizeToVhloPassOptions options;
+- options.allowOtherDialects = allowOtherDialects;
++ options.allowOtherDialects = false;
+ pm.addPass(stablehlo::createStablehloLegalizeToVhloPass(options));
+ if (!succeeded(pm.run(module))) {
+ return failure();
+diff --git a/stablehlo/dialect/Serialization.h b/stablehlo/dialect/Serialization.h
+index 811ca97b..abe95e63 100644
+--- a/stablehlo/dialect/Serialization.h
++++ b/stablehlo/dialect/Serialization.h
+@@ -34,8 +34,7 @@ namespace stablehlo {
+ // unsupported dialects.
+ LogicalResult serializePortableArtifact(ModuleOp module,
+ StringRef targetVersion,
+- raw_ostream& os,
+- bool allowOtherDialects = false);
++ raw_ostream& os);
+
+ // Read StableHLO portable artifact
+ //
+diff --git a/stablehlo/integrations/c/StablehloDialectApi.cpp b/stablehlo/integrations/c/StablehloDialectApi.cpp
+index 343f8d0b..8f52e4d5 100644
+--- a/stablehlo/integrations/c/StablehloDialectApi.cpp
++++ b/stablehlo/integrations/c/StablehloDialectApi.cpp
+@@ -81,8 +81,7 @@ MlirLogicalResult stablehloSerializePortableArtifactFromModule(
+ MlirStringCallback callback, void *userData, bool allowOtherDialects) {
+ mlir::detail::CallbackOstream stream(callback, userData);
+ if (failed(mlir::stablehlo::serializePortableArtifact(
+- unwrap(moduleStr), unwrap(targetVersion), stream,
+- allowOtherDialects)))
++ unwrap(moduleStr), unwrap(targetVersion), stream)))
+ return mlirLogicalResultFailure();
+ return mlirLogicalResultSuccess();
+ }
+diff --git a/stablehlo/integrations/c/StablehloDialectApi.h b/stablehlo/integrations/c/StablehloDialectApi.h
+index 385156bf..24d11c1d 100644
+--- a/stablehlo/integrations/c/StablehloDialectApi.h
++++ b/stablehlo/integrations/c/StablehloDialectApi.h
+@@ -93,7 +93,7 @@ stablehloSerializePortableArtifactFromModule(MlirModule moduleStr,
+ MlirStringRef targetVersion,
+ MlirStringCallback callback,
+ void* userData,
+- bool allowOtherDialects = false);
++ bool allowOtherDialects);
+
+ // Read a StableHLO program from a portable artifact, returning the module as
+ // MLIR bytecode. Note, this bytecode returned is not a portable artifact,
+diff --git a/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/tools/StablehloTranslateMain.cpp
+index fdf0d6a9..8d5c8752 100644
+--- a/stablehlo/tools/StablehloTranslateMain.cpp
++++ b/stablehlo/tools/StablehloTranslateMain.cpp
+@@ -323,7 +323,7 @@ TranslateFromMLIRRegistration serializeRegistration(
+ }
+
+ return stablehlo::serializePortableArtifact(
+- module, targetVersion, os, allowOtherDialectsOption.getValue());
++ module, targetVersion, os);
+ },
+ [](DialectRegistry &registry) {
+ mlir::registerAllDialects(registry);
+--
+2.39.5 (Apple Git-154)
+
--
2.39.5 (Apple Git-154)

View File

@ -0,0 +1,18 @@
{
"strip_prefix": "xla-6319f0d3bdfd3078e04bb984a759c890b7116484",
"url": "https://github.com/openxla/xla/archive/6319f0d3bdfd3078e04bb984a759c890b7116484.tar.gz",
"integrity": "sha256-0XAMbJ4tn5yFllPGVCSh8st7lgTOUoyeypoudCB8tsY=",
"overlay": {
"toolchains_private.bzl": "",
"llvm.bzl": "",
"MODULE.bazel": "",
"workspace_private.bzl": "",
"xla.bzl": ""
},
"patch_strip": 1,
"patches": {
"0001-bazel-migration-to-bazel-8.1.1.patch": "",
"0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch": "",
"0003-Remove-unconventional-C-code-in-headers.patch": ""
}
}

View File

@ -23,7 +23,8 @@
"20250317.2-71c67e2", "20250317.2-71c67e2",
"20250527.0-cb67f2f", "20250527.0-cb67f2f",
"20250612.0-6e48cbb", "20250612.0-6e48cbb",
"20250710.0-22ea002" "20250710.0-22ea002",
"20250718.0-6319f0d"
], ],
"yanked_versions": {} "yanked_versions": {}
} }