Update XLA dependency to version 20250527.0‑cb67f2f and refresh related Bazel BUILD, MODULE, overlay and patch files.
This commit is contained in:
parent
fa13287931
commit
95453c7242
37
MODULE.bazel
37
MODULE.bazel
@ -2,8 +2,6 @@ module(
|
|||||||
name = "zml",
|
name = "zml",
|
||||||
)
|
)
|
||||||
|
|
||||||
new_git_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository")
|
|
||||||
|
|
||||||
bazel_dep(name = "abseil-cpp", version = "20240722.0.bcr.2")
|
bazel_dep(name = "abseil-cpp", version = "20240722.0.bcr.2")
|
||||||
bazel_dep(name = "aspect_bazel_lib", version = "2.14.0")
|
bazel_dep(name = "aspect_bazel_lib", version = "2.14.0")
|
||||||
bazel_dep(name = "aspect_rules_py", version = "1.3.2")
|
bazel_dep(name = "aspect_rules_py", version = "1.3.2")
|
||||||
@ -22,9 +20,9 @@ bazel_dep(name = "rules_rust", version = "0.60.0")
|
|||||||
bazel_dep(name = "rules_uv", version = "0.65.0")
|
bazel_dep(name = "rules_uv", version = "0.65.0")
|
||||||
bazel_dep(name = "rules_zig", version = "20250519.0-233b207")
|
bazel_dep(name = "rules_zig", version = "20250519.0-233b207")
|
||||||
bazel_dep(name = "sentencepiece", version = "20240618.0-d7ace0a")
|
bazel_dep(name = "sentencepiece", version = "20240618.0-d7ace0a")
|
||||||
bazel_dep(name = "toolchains_protoc", version = "0.3.7")
|
bazel_dep(name = "toolchains_protoc", version = "0.4.1")
|
||||||
bazel_dep(name = "with_cfg.bzl", version = "0.9.1")
|
bazel_dep(name = "with_cfg.bzl", version = "0.9.1")
|
||||||
bazel_dep(name = "xla", version = "20250317.2-71c67e2")
|
bazel_dep(name = "xla", version = "20250527.0-cb67f2f")
|
||||||
bazel_dep(name = "zig-protobuf", version = "20250318.0-930153e")
|
bazel_dep(name = "zig-protobuf", version = "20250318.0-930153e")
|
||||||
bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf")
|
bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf")
|
||||||
|
|
||||||
@ -57,7 +55,7 @@ register_toolchains("@rules_zig//zig/target:all")
|
|||||||
register_toolchains("@zig_toolchains//:all")
|
register_toolchains("@zig_toolchains//:all")
|
||||||
|
|
||||||
toolchains = use_extension("@hermetic_cc_toolchain//toolchain:ext.bzl", "toolchains")
|
toolchains = use_extension("@hermetic_cc_toolchain//toolchain:ext.bzl", "toolchains")
|
||||||
use_repo(toolchains, "zig_sdk")
|
use_repo(toolchains, "zig_sdk", "zig_sdk-linux-amd64", "zig_sdk-linux-arm64", "zig_sdk-macos-amd64", "zig_sdk-macos-arm64", "zig_sdk-windows-amd64")
|
||||||
|
|
||||||
register_toolchains(
|
register_toolchains(
|
||||||
"@zig_sdk//toolchain:linux_amd64_gnu.2.31",
|
"@zig_sdk//toolchain:linux_amd64_gnu.2.31",
|
||||||
@ -88,14 +86,18 @@ common_apt_packages = use_extension("//runtimes/common:packages.bzl", "common_ap
|
|||||||
use_repo(common_apt_packages, "libdrm-amdgpu1", "libdrm2", "libelf1", "libnuma1", "libtinfo6", "libzstd1", "zlib1g")
|
use_repo(common_apt_packages, "libdrm-amdgpu1", "libdrm2", "libelf1", "libnuma1", "libtinfo6", "libzstd1", "zlib1g")
|
||||||
|
|
||||||
cpu = use_extension("//runtimes/cpu:cpu.bzl", "cpu_pjrt_plugin")
|
cpu = use_extension("//runtimes/cpu:cpu.bzl", "cpu_pjrt_plugin")
|
||||||
use_repo(cpu, "libpjrt_cpu_darwin_arm64", "libpjrt_cpu_darwin_amd64", "libpjrt_cpu_linux_amd64")
|
use_repo(cpu, "libpjrt_cpu_darwin_amd64", "libpjrt_cpu_darwin_arm64", "libpjrt_cpu_linux_amd64")
|
||||||
|
|
||||||
cuda = use_extension("//runtimes/cuda:cuda.bzl", "cuda_packages")
|
cuda = use_extension("//runtimes/cuda:cuda.bzl", "cuda_packages")
|
||||||
|
|
||||||
inject_repo(cuda, "zlib1g")
|
inject_repo(cuda, "zlib1g")
|
||||||
|
|
||||||
use_repo(cuda, "libpjrt_cuda")
|
use_repo(cuda, "libpjrt_cuda")
|
||||||
|
|
||||||
rocm = use_extension("//runtimes/rocm:rocm.bzl", "rocm_packages")
|
rocm = use_extension("//runtimes/rocm:rocm.bzl", "rocm_packages")
|
||||||
|
|
||||||
inject_repo(rocm, "libdrm-amdgpu1", "libdrm2", "libelf1", "libnuma1", "libtinfo6", "libzstd1", "zlib1g")
|
inject_repo(rocm, "libdrm-amdgpu1", "libdrm2", "libelf1", "libnuma1", "libtinfo6", "libzstd1", "zlib1g")
|
||||||
|
|
||||||
use_repo(rocm, "hipblaslt", "libpjrt_rocm", "rocblas")
|
use_repo(rocm, "hipblaslt", "libpjrt_rocm", "rocblas")
|
||||||
|
|
||||||
tpu = use_extension("//runtimes/tpu:tpu.bzl", "tpu_packages")
|
tpu = use_extension("//runtimes/tpu:tpu.bzl", "tpu_packages")
|
||||||
@ -105,27 +107,21 @@ neuron = use_extension("//runtimes/neuron:neuron.bzl", "neuron_packages")
|
|||||||
use_repo(neuron, "aws-neuronx-collectives", "aws-neuronx-runtime-lib")
|
use_repo(neuron, "aws-neuronx-collectives", "aws-neuronx-runtime-lib")
|
||||||
|
|
||||||
zls = use_extension("//third_party/zls:zls.bzl", "repo")
|
zls = use_extension("//third_party/zls:zls.bzl", "repo")
|
||||||
use_repo(zls, "zls_aarch64-macos", "zls_x86_64-macos", "zls_x86_64-linux")
|
use_repo(zls, "zls_aarch64-macos", "zls_x86_64-linux", "zls_x86_64-macos")
|
||||||
|
|
||||||
register_toolchains("//third_party/zls:all")
|
register_toolchains("//third_party/zls:all")
|
||||||
|
|
||||||
tsl = use_extension("@xla//:tsl.bzl", "tsl")
|
xla = use_extension("@xla//:xla.bzl", "xla")
|
||||||
use_repo(tsl, "tsl", "python_version_repo")
|
|
||||||
|
|
||||||
xla = use_extension("@xla//:workspace.bzl", "xla_workspace")
|
|
||||||
use_repo(
|
use_repo(
|
||||||
xla,
|
xla,
|
||||||
"com_github_grpc_grpc",
|
"llvm-raw",
|
||||||
"local_config_cuda",
|
|
||||||
"local_config_remote_execution",
|
|
||||||
"local_config_rocm",
|
|
||||||
"local_config_tensorrt",
|
|
||||||
"pybind11_bazel",
|
|
||||||
"stablehlo",
|
"stablehlo",
|
||||||
"triton",
|
"triton",
|
||||||
)
|
)
|
||||||
llvm_configure = use_extension("@xla//:llvm.bzl", "llvm_configure")
|
|
||||||
use_repo(llvm_configure, "llvm-project")
|
llvm = use_extension("@xla//:llvm.bzl", "llvm")
|
||||||
|
llvm.configure()
|
||||||
|
use_repo(llvm, "llvm-project")
|
||||||
|
|
||||||
rust = use_extension("@rules_rust//rust:extensions.bzl", "rust")
|
rust = use_extension("@rules_rust//rust:extensions.bzl", "rust")
|
||||||
rust.toolchain(
|
rust.toolchain(
|
||||||
@ -157,7 +153,6 @@ crate.from_cargo(
|
|||||||
use_repo(crate, "crates")
|
use_repo(crate, "crates")
|
||||||
|
|
||||||
non_module_deps = use_extension("//:third_party/non_module_deps.bzl", "non_module_deps")
|
non_module_deps = use_extension("//:third_party/non_module_deps.bzl", "non_module_deps")
|
||||||
inject_repo(non_module_deps, "python_version_repo", "xla", "tsl")
|
|
||||||
use_repo(non_module_deps, "com_google_sentencepiece", "org_swig_swig")
|
use_repo(non_module_deps, "com_google_sentencepiece", "org_swig_swig")
|
||||||
|
|
||||||
apt = use_extension("@rules_distroless//apt:extensions.bzl", "apt")
|
apt = use_extension("@rules_distroless//apt:extensions.bzl", "apt")
|
||||||
@ -167,14 +162,12 @@ apt.install(
|
|||||||
manifest = "//runtimes/common:packages.yaml",
|
manifest = "//runtimes/common:packages.yaml",
|
||||||
)
|
)
|
||||||
use_repo(apt, "apt_common")
|
use_repo(apt, "apt_common")
|
||||||
|
|
||||||
apt.install(
|
apt.install(
|
||||||
name = "apt_rocm",
|
name = "apt_rocm",
|
||||||
lock = "//runtimes/rocm:packages.lock.json",
|
lock = "//runtimes/rocm:packages.lock.json",
|
||||||
manifest = "//runtimes/rocm:packages.yaml",
|
manifest = "//runtimes/rocm:packages.yaml",
|
||||||
)
|
)
|
||||||
use_repo(apt, "apt_rocm")
|
use_repo(apt, "apt_rocm")
|
||||||
|
|
||||||
apt.install(
|
apt.install(
|
||||||
name = "apt_neuron",
|
name = "apt_neuron",
|
||||||
lock = "//runtimes/neuron:packages.lock.json",
|
lock = "//runtimes/neuron:packages.lock.json",
|
||||||
|
|||||||
4052
MODULE.bazel.lock
4052
MODULE.bazel.lock
File diff suppressed because one or more lines are too long
@ -14,10 +14,10 @@ cc_library(
|
|||||||
zig_library(
|
zig_library(
|
||||||
name = "pjrt",
|
name = "pjrt",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
"convert/trace_container.zig",
|
||||||
|
"convert/xplane_schema.zig",
|
||||||
"ffi.zig",
|
"ffi.zig",
|
||||||
"profiler.zig",
|
"profiler.zig",
|
||||||
"convert/trace_container.zig",
|
|
||||||
"convert/xplane_schema.zig"
|
|
||||||
],
|
],
|
||||||
main = "pjrt.zig",
|
main = "pjrt.zig",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
@ -41,24 +41,27 @@ zig_library(
|
|||||||
zig_proto_library(
|
zig_proto_library(
|
||||||
name = "profiler_options_proto",
|
name = "profiler_options_proto",
|
||||||
import_name = "//tsl:profiler_options_proto",
|
import_name = "//tsl:profiler_options_proto",
|
||||||
deps = ["@tsl//tsl/profiler/protobuf:profiler_options_proto"],
|
deps = ["@xla//third_party/tsl/tsl/profiler/protobuf:profiler_options_proto"],
|
||||||
)
|
)
|
||||||
|
|
||||||
zig_proto_library(
|
zig_proto_library(
|
||||||
name = "xplane_proto",
|
name = "xplane_proto",
|
||||||
import_name = "//tsl:xplane_proto",
|
import_name = "//tsl:xplane_proto",
|
||||||
deps = ["@tsl//tsl/profiler/protobuf:xplane_proto"],
|
deps = ["@xla//third_party/tsl/tsl/profiler/protobuf:xplane_proto"],
|
||||||
)
|
)
|
||||||
|
|
||||||
zig_proto_library(
|
zig_proto_library(
|
||||||
name = "trace_events_proto",
|
name = "trace_events_proto",
|
||||||
import_name = "//tsl:trace_events_proto",
|
import_name = "//tsl:trace_events_proto",
|
||||||
deps = ["@tsl//tsl/profiler/protobuf:trace_events_proto"],
|
deps = ["@xla//third_party/tsl/tsl/profiler/protobuf:trace_events_proto"],
|
||||||
)
|
)
|
||||||
|
|
||||||
zig_cc_binary(
|
zig_cc_binary(
|
||||||
name = "xspace_to_json",
|
name = "xspace_to_json",
|
||||||
srcs = ["convert/trace_container.zig", "convert/xplane_schema.zig"],
|
srcs = [
|
||||||
|
"convert/trace_container.zig",
|
||||||
|
"convert/xplane_schema.zig",
|
||||||
|
],
|
||||||
main = "xspace_to_json.zig",
|
main = "xspace_to_json.zig",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
|||||||
@ -25,22 +25,22 @@ def _cpu_pjrt_plugin_impl(mctx):
|
|||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cpu_linux_amd64",
|
name = "libpjrt_cpu_linux_amd64",
|
||||||
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_LINUX,
|
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_LINUX,
|
||||||
sha256 = "1cda1325095c12bd0019838d28ee92d811ac478d22ed3c08020d5a0cd2d9f34a",
|
sha256 = "ca92bccefa168881f98d01354971d6f598381cc4c5f07b161a0908d327610b66",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v7.0.0/pjrt-cpu_linux-amd64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v9.0.1/pjrt-cpu_linux-amd64.tar.gz",
|
||||||
)
|
)
|
||||||
|
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cpu_darwin_amd64",
|
name = "libpjrt_cpu_darwin_amd64",
|
||||||
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN,
|
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN,
|
||||||
sha256 = "35af82d9e5c70d16ac15f4c18024a2dd5ed2faefc89940eafe3d5350d2cbd9e7",
|
sha256 = "b6d05b5cd0382a7bd8943b8df98dc229853e402488127895e47786395afb73a7",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v7.0.0/pjrt-cpu_darwin-amd64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v9.0.1/pjrt-cpu_darwin-amd64.tar.gz",
|
||||||
)
|
)
|
||||||
|
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cpu_darwin_arm64",
|
name = "libpjrt_cpu_darwin_arm64",
|
||||||
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN,
|
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN,
|
||||||
sha256 = "da4deaf850d715997614768b2fc0283595ee8181133ab3243d65635e3439de69",
|
sha256 = "e1ac13cf80b0975eec1dc0643a6ec08001d6e07a6a0d500a38e1c4477f49a78c",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v7.0.0/pjrt-cpu_darwin-arm64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v9.0.1/pjrt-cpu_darwin-arm64.tar.gz",
|
||||||
)
|
)
|
||||||
|
|
||||||
return mctx.extension_metadata(
|
return mctx.extension_metadata(
|
||||||
|
|||||||
@ -9,12 +9,12 @@ package(default_visibility = ["//visibility:public"])
|
|||||||
ARCH = "linux-x86_64"
|
ARCH = "linux-x86_64"
|
||||||
|
|
||||||
CUDA_REDIST_PREFIX = "https://developer.download.nvidia.com/compute/cuda/redist/"
|
CUDA_REDIST_PREFIX = "https://developer.download.nvidia.com/compute/cuda/redist/"
|
||||||
CUDA_VERSION = "12.8.1"
|
CUDA_VERSION = "12.9.0"
|
||||||
CUDA_REDIST_JSON_SHA256 = "249e28a83008d711d5f72880541c8be6253f6d61608461de4fcb715554a6cf17"
|
CUDA_REDIST_JSON_SHA256 = "4e4e17a12adcf8cac40b990e1618406cd7ad52da1817819166af28a9dfe21d4a"
|
||||||
|
|
||||||
CUDNN_REDIST_PREFIX = "https://developer.download.nvidia.com/compute/cudnn/redist/"
|
CUDNN_REDIST_PREFIX = "https://developer.download.nvidia.com/compute/cudnn/redist/"
|
||||||
CUDNN_VERSION = "9.8.0"
|
CUDNN_VERSION = "9.10.1"
|
||||||
CUDNN_REDIST_JSON_SHA256 = "a1599fa1f8dcb81235157be5de5ab7d3936e75dfc4e1e442d07970afad3c4843"
|
CUDNN_REDIST_JSON_SHA256 = "2ac8d48d3ab4de1acdce65fa3e8ecfb14750d4e101b05fe3307d2f95f2740563"
|
||||||
|
|
||||||
CUDA_PACKAGES = {
|
CUDA_PACKAGES = {
|
||||||
"cuda_cudart": "\n".join([
|
"cuda_cudart": "\n".join([
|
||||||
@ -40,7 +40,7 @@ CUDA_PACKAGES = {
|
|||||||
"cuda_nvtx": packages.cc_import_glob_hdrs(
|
"cuda_nvtx": packages.cc_import_glob_hdrs(
|
||||||
name = "nvtx",
|
name = "nvtx",
|
||||||
hdrs_glob = ["include/nvtx3/**/*.h"],
|
hdrs_glob = ["include/nvtx3/**/*.h"],
|
||||||
shared_library = "lib/libnvToolsExt.so.1",
|
shared_library = "lib/libnvtx3interop.so.1",
|
||||||
),
|
),
|
||||||
"libcufft": packages.cc_import(
|
"libcufft": packages.cc_import(
|
||||||
name = "cufft",
|
name = "cufft",
|
||||||
@ -88,7 +88,7 @@ CUDA_PACKAGES = {
|
|||||||
),
|
),
|
||||||
packages.cc_import(
|
packages.cc_import(
|
||||||
name = "nvrtc_builtins",
|
name = "nvrtc_builtins",
|
||||||
shared_library = "lib/libnvrtc-builtins.so.12.8",
|
shared_library = "lib/libnvrtc-builtins.so.12.9",
|
||||||
),
|
),
|
||||||
]),
|
]),
|
||||||
"libcublas": "\n".join([
|
"libcublas": "\n".join([
|
||||||
@ -202,9 +202,9 @@ def _cuda_impl(mctx):
|
|||||||
|
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "nccl",
|
name = "nccl",
|
||||||
urls = ["https://files.pythonhosted.org/packages/11/0c/8c78b7603f4e685624a3ea944940f1e75f36d71bd6504330511f4a0e1557/nvidia_nccl_cu12-2.25.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl"],
|
urls = ["https://files.pythonhosted.org/packages/48/fb/ec4ac065d9b0d56f72eaf1d9b0df601e33da28197b32ca351dc05b342611/nvidia_nccl_cu12-2.26.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl"],
|
||||||
type = "zip",
|
type = "zip",
|
||||||
sha256 = "362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a",
|
sha256 = "ea5ed3e053c735f16809bee7111deac62ac35b10128a8c102960a0462ce16cbe",
|
||||||
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + packages.cc_import(
|
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + packages.cc_import(
|
||||||
name = "nccl",
|
name = "nccl",
|
||||||
shared_library = "nvidia/nccl/lib/libnccl.so.2",
|
shared_library = "nvidia/nccl/lib/libnccl.so.2",
|
||||||
@ -214,8 +214,8 @@ def _cuda_impl(mctx):
|
|||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cuda",
|
name = "libpjrt_cuda",
|
||||||
build_file = "libpjrt_cuda.BUILD.bazel",
|
build_file = "libpjrt_cuda.BUILD.bazel",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v7.0.0/pjrt-cuda_linux-amd64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v9.0.1/pjrt-cuda_linux-amd64.tar.gz",
|
||||||
sha256 = "64029cd3d68118b166198e3246877ed706ed35eb732b1770a9bf530b5b0a8ab4",
|
sha256 = "2ae18dacd9762e0ae89f223764b1793f8a4d7bd7238bfcd84d2342d7fb37a106",
|
||||||
)
|
)
|
||||||
|
|
||||||
return mctx.extension_metadata(
|
return mctx.extension_metadata(
|
||||||
|
|||||||
@ -33,6 +33,9 @@ pip_compile(
|
|||||||
python_platform = "x86_64-unknown-linux-gnu",
|
python_platform = "x86_64-unknown-linux-gnu",
|
||||||
requirements_in = "requirements.in",
|
requirements_in = "requirements.in",
|
||||||
requirements_txt = "requirements.lock.txt",
|
requirements_txt = "requirements.lock.txt",
|
||||||
|
tags = [
|
||||||
|
"no_ci",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_console_script_binary(
|
py_console_script_binary(
|
||||||
@ -67,6 +70,9 @@ compile_pip_requirements(
|
|||||||
src = "requirements.in",
|
src = "requirements.in",
|
||||||
py_binary = py_binary_with_script,
|
py_binary = py_binary_with_script,
|
||||||
requirements_txt = "requirements.lock.txt",
|
requirements_txt = "requirements.lock.txt",
|
||||||
|
tags = [
|
||||||
|
"no_ci",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
runfiles_to_default(
|
runfiles_to_default(
|
||||||
|
|||||||
@ -127,8 +127,8 @@ def _rocm_impl(mctx):
|
|||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_rocm",
|
name = "libpjrt_rocm",
|
||||||
build_file = "libpjrt_rocm.BUILD.bazel",
|
build_file = "libpjrt_rocm.BUILD.bazel",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v7.0.0/pjrt-rocm_linux-amd64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v9.0.1/pjrt-rocm_linux-amd64.tar.gz",
|
||||||
sha256 = "13a2ced965a44a0e8ad0d752c8ac5aa99107b17b41bd850967d40b82e102ec50",
|
sha256 = "31223c61645e6a3966841be6ebbc8c56609835a792c75ad1e1442fd5afed759b",
|
||||||
)
|
)
|
||||||
|
|
||||||
return mctx.extension_metadata(
|
return mctx.extension_metadata(
|
||||||
|
|||||||
52
third_party/modules/xla/20250527.0-cb67f2f/MODULE.bazel
vendored
Normal file
52
third_party/modules/xla/20250527.0-cb67f2f/MODULE.bazel
vendored
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
module(
|
||||||
|
name = "xla",
|
||||||
|
version = "20250527.0-cb67f2f",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "platforms", version = "0.0.8")
|
||||||
|
bazel_dep(name = "bazel_skylib", version = "1.5.0")
|
||||||
|
bazel_dep(name = "rules_cc", version = "0.0.17")
|
||||||
|
bazel_dep(name = "rules_apple", version = "3.22.0", repo_name = "build_bazel_rules_apple")
|
||||||
|
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
|
||||||
|
bazel_dep(name = "rules_python", version = "0.39.0")
|
||||||
|
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||||
|
bazel_dep(name = "rules_java", version = "7.3.2")
|
||||||
|
bazel_dep(name = "rules_pkg", version = "0.9.1")
|
||||||
|
bazel_dep(name = "zlib", version = "1.2.13")
|
||||||
|
bazel_dep(name = "re2", version = "2024-07-02.bcr.1", repo_name = "com_googlesource_code_re2")
|
||||||
|
bazel_dep(name = "rules_license", version = "0.0.8")
|
||||||
|
bazel_dep(name = "rules_shell", version = "0.4.1")
|
||||||
|
bazel_dep(name = "bazel_features", version = "1.25.0", repo_name = "proto_bazel_features")
|
||||||
|
|
||||||
|
workspace_private = use_extension("//:workspace_private.bzl", "workspace_private")
|
||||||
|
use_repo(
|
||||||
|
workspace_private,
|
||||||
|
"com_github_grpc_grpc",
|
||||||
|
"com_google_protobuf",
|
||||||
|
"local_config_cuda",
|
||||||
|
"local_config_remote_execution",
|
||||||
|
"local_config_rocm",
|
||||||
|
"local_config_tensorrt",
|
||||||
|
"python_version_repo",
|
||||||
|
"tsl",
|
||||||
|
)
|
||||||
|
|
||||||
|
workspace_public = use_extension("//:xla.bzl", "xla")
|
||||||
|
use_repo(
|
||||||
|
workspace_public,
|
||||||
|
"llvm-raw",
|
||||||
|
"stablehlo",
|
||||||
|
"triton",
|
||||||
|
)
|
||||||
|
|
||||||
|
llvm = use_extension("//:llvm.bzl", "llvm")
|
||||||
|
llvm.configure(
|
||||||
|
targets = [
|
||||||
|
"AArch64",
|
||||||
|
"AMDGPU",
|
||||||
|
"NVPTX",
|
||||||
|
"X86",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
use_repo(llvm, "llvm-project")
|
||||||
52
third_party/modules/xla/20250527.0-cb67f2f/overlay/MODULE.bazel
vendored
Normal file
52
third_party/modules/xla/20250527.0-cb67f2f/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
module(
|
||||||
|
name = "xla",
|
||||||
|
version = "20250527.0-cb67f2f",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "platforms", version = "0.0.8")
|
||||||
|
bazel_dep(name = "bazel_skylib", version = "1.5.0")
|
||||||
|
bazel_dep(name = "rules_cc", version = "0.0.17")
|
||||||
|
bazel_dep(name = "rules_apple", version = "3.22.0", repo_name = "build_bazel_rules_apple")
|
||||||
|
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
|
||||||
|
bazel_dep(name = "rules_python", version = "0.39.0")
|
||||||
|
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||||
|
bazel_dep(name = "rules_java", version = "7.3.2")
|
||||||
|
bazel_dep(name = "rules_pkg", version = "0.9.1")
|
||||||
|
bazel_dep(name = "zlib", version = "1.2.13")
|
||||||
|
bazel_dep(name = "re2", version = "2024-07-02.bcr.1", repo_name = "com_googlesource_code_re2")
|
||||||
|
bazel_dep(name = "rules_license", version = "0.0.8")
|
||||||
|
bazel_dep(name = "rules_shell", version = "0.4.1")
|
||||||
|
bazel_dep(name = "bazel_features", version = "1.25.0", repo_name = "proto_bazel_features")
|
||||||
|
|
||||||
|
workspace_private = use_extension("//:workspace_private.bzl", "workspace_private")
|
||||||
|
use_repo(
|
||||||
|
workspace_private,
|
||||||
|
"com_github_grpc_grpc",
|
||||||
|
"com_google_protobuf",
|
||||||
|
"local_config_cuda",
|
||||||
|
"local_config_remote_execution",
|
||||||
|
"local_config_rocm",
|
||||||
|
"local_config_tensorrt",
|
||||||
|
"python_version_repo",
|
||||||
|
"tsl",
|
||||||
|
)
|
||||||
|
|
||||||
|
workspace_public = use_extension("//:xla.bzl", "xla")
|
||||||
|
use_repo(
|
||||||
|
workspace_public,
|
||||||
|
"llvm-raw",
|
||||||
|
"stablehlo",
|
||||||
|
"triton",
|
||||||
|
)
|
||||||
|
|
||||||
|
llvm = use_extension("//:llvm.bzl", "llvm")
|
||||||
|
llvm.configure(
|
||||||
|
targets = [
|
||||||
|
"AArch64",
|
||||||
|
"AMDGPU",
|
||||||
|
"NVPTX",
|
||||||
|
"X86",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
use_repo(llvm, "llvm-project")
|
||||||
30
third_party/modules/xla/20250527.0-cb67f2f/overlay/llvm.bzl
vendored
Normal file
30
third_party/modules/xla/20250527.0-cb67f2f/overlay/llvm.bzl
vendored
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
load("@llvm-raw//utils/bazel:configure.bzl", _llvm_configure = "llvm_configure")
|
||||||
|
|
||||||
|
def _llvm_impl(mctx):
|
||||||
|
_targets = {}
|
||||||
|
for mod in mctx.modules:
|
||||||
|
for conf in mod.tags.configure:
|
||||||
|
for target in conf.targets:
|
||||||
|
_targets[target] = True
|
||||||
|
_llvm_configure(
|
||||||
|
name = "llvm-project",
|
||||||
|
targets = _targets.keys(),
|
||||||
|
)
|
||||||
|
return mctx.extension_metadata(
|
||||||
|
reproducible = True,
|
||||||
|
root_module_direct_deps = "all",
|
||||||
|
root_module_direct_dev_deps = [],
|
||||||
|
)
|
||||||
|
|
||||||
|
llvm = module_extension(
|
||||||
|
implementation = _llvm_impl,
|
||||||
|
tag_classes = {
|
||||||
|
"configure": tag_class(
|
||||||
|
attrs = {
|
||||||
|
"targets": attr.string_list(
|
||||||
|
default = [],
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
71
third_party/modules/xla/20250527.0-cb67f2f/overlay/workspace_private.bzl
vendored
Normal file
71
third_party/modules/xla/20250527.0-cb67f2f/overlay/workspace_private.bzl
vendored
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
|
||||||
|
load("//third_party/gpus:rocm_configure.bzl", "rocm_configure")
|
||||||
|
load("//third_party/llvm:workspace.bzl", llvm = "repo")
|
||||||
|
load("//third_party/py:python_repo.bzl", "python_repository")
|
||||||
|
load("//third_party/pybind11_bazel:workspace.bzl", pybind11_bazel = "repo")
|
||||||
|
load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo")
|
||||||
|
load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
|
||||||
|
load("//third_party/triton:workspace.bzl", triton = "repo")
|
||||||
|
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
|
||||||
|
load("//third_party:repo.bzl", "tf_vendored")
|
||||||
|
load("//tools/toolchains/remote:configure.bzl", "remote_execution_configure")
|
||||||
|
|
||||||
|
def _workspace_private_impl(mctx):
|
||||||
|
cuda_configure(name = "local_config_cuda")
|
||||||
|
remote_execution_configure(name = "local_config_remote_execution")
|
||||||
|
rocm_configure(name = "local_config_rocm")
|
||||||
|
tensorrt_configure(name = "local_config_tensorrt")
|
||||||
|
tf_vendored(name = "tsl", relpath = "third_party/tsl")
|
||||||
|
pybind11_bazel()
|
||||||
|
tf_http_archive(
|
||||||
|
name = "com_github_grpc_grpc",
|
||||||
|
sha256 = "b956598d8cbe168b5ee717b5dafa56563eb5201a947856a6688bbeac9cac4e1f",
|
||||||
|
strip_prefix = "grpc-b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd",
|
||||||
|
system_build_file = "//third_party/systemlibs:grpc.BUILD",
|
||||||
|
patch_file = [
|
||||||
|
"//third_party/grpc:generate_cc_env_fix.patch",
|
||||||
|
"//third_party/grpc:register_go_toolchain.patch",
|
||||||
|
],
|
||||||
|
system_link_files = {
|
||||||
|
"//third_party/systemlibs:BUILD.bazel": "bazel/BUILD.bazel",
|
||||||
|
"//third_party/systemlibs:grpc.BUILD": "src/compiler/BUILD",
|
||||||
|
"//third_party/systemlibs:grpc.bazel.grpc_deps.bzl": "bazel/grpc_deps.bzl",
|
||||||
|
"//third_party/systemlibs:grpc.bazel.grpc_extra_deps.bzl": "bazel/grpc_extra_deps.bzl",
|
||||||
|
"//third_party/systemlibs:grpc.bazel.cc_grpc_library.bzl": "bazel/cc_grpc_library.bzl",
|
||||||
|
"//third_party/systemlibs:grpc.bazel.generate_cc.bzl": "bazel/generate_cc.bzl",
|
||||||
|
"//third_party/systemlibs:grpc.bazel.protobuf.bzl": "bazel/protobuf.bzl",
|
||||||
|
},
|
||||||
|
urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz"),
|
||||||
|
)
|
||||||
|
tf_http_archive(
|
||||||
|
name = "com_google_protobuf",
|
||||||
|
patch_file = ["//third_party/protobuf:protobuf.patch"],
|
||||||
|
sha256 = "f66073dee0bc159157b0bd7f502d7d1ee0bc76b3c1eac9836927511bdc4b3fc1",
|
||||||
|
strip_prefix = "protobuf-3.21.9",
|
||||||
|
system_build_file = "//third_party/systemlibs:protobuf.BUILD",
|
||||||
|
system_link_files = {
|
||||||
|
"//third_party/systemlibs:protobuf.bzl": "protobuf.bzl",
|
||||||
|
"//third_party/systemlibs:protobuf_deps.bzl": "protobuf_deps.bzl",
|
||||||
|
},
|
||||||
|
urls = tf_mirror_urls("https://github.com/protocolbuffers/protobuf/archive/v3.21.9.zip"),
|
||||||
|
)
|
||||||
|
python_repository(
|
||||||
|
name = "python_version_repo",
|
||||||
|
requirements_versions = ["3.11"],
|
||||||
|
requirements_locks = ["//:requirements_lock_3_11.txt"],
|
||||||
|
local_wheel_workspaces = [],
|
||||||
|
local_wheel_dist_folder = None,
|
||||||
|
default_python_version = None,
|
||||||
|
local_wheel_inclusion_list = ["*"],
|
||||||
|
local_wheel_exclusion_list = [],
|
||||||
|
)
|
||||||
|
|
||||||
|
return mctx.extension_metadata(
|
||||||
|
reproducible = True,
|
||||||
|
root_module_direct_deps = "all",
|
||||||
|
root_module_direct_dev_deps = [],
|
||||||
|
)
|
||||||
|
|
||||||
|
workspace_private = module_extension(
|
||||||
|
implementation = _workspace_private_impl,
|
||||||
|
)
|
||||||
17
third_party/modules/xla/20250527.0-cb67f2f/overlay/xla.bzl
vendored
Normal file
17
third_party/modules/xla/20250527.0-cb67f2f/overlay/xla.bzl
vendored
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
load("//third_party/llvm:workspace.bzl", llvm = "repo")
|
||||||
|
load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo")
|
||||||
|
load("//third_party/triton:workspace.bzl", triton = "repo")
|
||||||
|
|
||||||
|
def _xla_impl(mctx):
|
||||||
|
triton()
|
||||||
|
llvm("llvm-raw")
|
||||||
|
stablehlo()
|
||||||
|
return mctx.extension_metadata(
|
||||||
|
reproducible = True,
|
||||||
|
root_module_direct_deps = "all",
|
||||||
|
root_module_direct_dev_deps = [],
|
||||||
|
)
|
||||||
|
|
||||||
|
xla = module_extension(
|
||||||
|
implementation = _xla_impl,
|
||||||
|
)
|
||||||
41
third_party/modules/xla/20250527.0-cb67f2f/patches/0001-bazel-migration-to-bazel-8.1.1.patch
vendored
Normal file
41
third_party/modules/xla/20250527.0-cb67f2f/patches/0001-bazel-migration-to-bazel-8.1.1.patch
vendored
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
From 6cf475b500521c1b8be06f590fdbc1818f0dc44b Mon Sep 17 00:00:00 2001
|
||||||
|
From: Jean-Baptiste Dalido <jb@zml.ai>
|
||||||
|
Date: Mon, 6 Jan 2025 13:33:13 +0100
|
||||||
|
Subject: [PATCH] bazel: migration to bazel 8.0.1
|
||||||
|
|
||||||
|
---
|
||||||
|
.bazelversion | 2 +-
|
||||||
|
third_party/tsl/third_party/gpus/cuda_configure.bzl | 4 ++--
|
||||||
|
2 files changed, 3 insertions(+), 3 deletions(-)
|
||||||
|
|
||||||
|
diff --git a/.bazelversion b/.bazelversion
|
||||||
|
index f22d756da3..fa5fce04b3 100644
|
||||||
|
--- a/.bazelversion
|
||||||
|
+++ b/.bazelversion
|
||||||
|
@@ -1 +1 @@
|
||||||
|
-7.4.1
|
||||||
|
+8.1.1
|
||||||
|
\ No newline at end of file
|
||||||
|
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
|
||||||
|
index d62531152d..71d80a5a99 100644
|
||||||
|
--- a/third_party/gpus/cuda_configure.bzl
|
||||||
|
+++ b/third_party/gpus/cuda_configure.bzl
|
||||||
|
@@ -33,14 +33,14 @@ NB: DEPRECATED! Use `hermetic/cuda_configure` rule instead.
|
||||||
|
load(
|
||||||
|
"@bazel_tools//tools/cpp:lib_cc_configure.bzl",
|
||||||
|
"escape_string",
|
||||||
|
- "get_env_var",
|
||||||
|
)
|
||||||
|
load(
|
||||||
|
"@bazel_tools//tools/cpp:windows_cc_configure.bzl",
|
||||||
|
- "find_msvc_tool",
|
||||||
|
"find_vc_path",
|
||||||
|
"setup_vc_env_vars",
|
||||||
|
)
|
||||||
|
+load("@rules_cc//cc/private/toolchain:windows_cc_configure.bzl", "find_msvc_tool")
|
||||||
|
+load("@rules_cc//cc/private/toolchain:lib_cc_configure.bzl", "get_env_var")
|
||||||
|
load("//third_party/clang_toolchain:download_clang.bzl", "download_clang")
|
||||||
|
load(
|
||||||
|
"//third_party/remote_config:common.bzl",
|
||||||
|
--
|
||||||
|
2.39.3 (Apple Git-146)
|
||||||
@ -0,0 +1,135 @@
|
|||||||
|
From 2ae9bb9d24b569c2c6bfab3c54b428103614944d Mon Sep 17 00:00:00 2001
|
||||||
|
From: Hugo Mano <hugo@zml.ai>
|
||||||
|
Date: Tue, 27 May 2025 11:48:17 +0200
|
||||||
|
Subject: [PATCH 1/8] Added FFI handler registration API to the FFI PjRt
|
||||||
|
|
||||||
|
PR: https://github.com/openxla/xla/pull/13420
|
||||||
|
---
|
||||||
|
xla/pjrt/c/BUILD | 5 +++++
|
||||||
|
xla/pjrt/c/pjrt_c_api_ffi_extension.h | 21 ++++++++++++++++++
|
||||||
|
xla/pjrt/c/pjrt_c_api_ffi_internal.cc | 32 ++++++++++++++++++++++++++-
|
||||||
|
3 files changed, 57 insertions(+), 1 deletion(-)
|
||||||
|
|
||||||
|
diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD
|
||||||
|
index 79f18fa0bc..0f33dd8a6e 100644
|
||||||
|
--- a/xla/pjrt/c/BUILD
|
||||||
|
+++ b/xla/pjrt/c/BUILD
|
||||||
|
@@ -69,8 +69,13 @@ cc_library(
|
||||||
|
":pjrt_c_api_wrapper_impl",
|
||||||
|
"//xla/ffi:execution_context",
|
||||||
|
"//xla/ffi:type_id_registry",
|
||||||
|
+ "//xla/ffi:ffi_api",
|
||||||
|
+ "//xla/ffi/api:c_api",
|
||||||
|
+ "//xla/ffi/api:ffi",
|
||||||
|
+ "//xla/service:custom_call_target_registry",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/strings:string_view",
|
||||||
|
+ "@com_google_absl//absl/strings:str_format",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
||||||
|
index 995a2c7e50..b8f10bc2f7 100644
|
||||||
|
--- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
||||||
|
+++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
||||||
|
@@ -69,10 +69,31 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data);
|
||||||
|
// Adds a user data to the execute context.
|
||||||
|
typedef PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args);
|
||||||
|
|
||||||
|
+typedef enum PJRT_FFI_Handler_TraitsBits {
|
||||||
|
+ PJRT_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE = 1u << 0,
|
||||||
|
+} PJRT_FFI_Handler_TraitsBits;
|
||||||
|
+
|
||||||
|
+struct PJRT_FFI_Register_Handler_Args {
|
||||||
|
+ size_t struct_size;
|
||||||
|
+ const char* target_name;
|
||||||
|
+ size_t target_name_size;
|
||||||
|
+ int api_version; // 0 for an untyped call, 1 -- for typed
|
||||||
|
+ void* handler;
|
||||||
|
+ const char* platform_name;
|
||||||
|
+ size_t platform_name_size;
|
||||||
|
+ PJRT_FFI_Handler_TraitsBits traits;
|
||||||
|
+};
|
||||||
|
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Register_Handler_Args, traits);
|
||||||
|
+
|
||||||
|
+// Registers an FFI call handler for a specific platform.
|
||||||
|
+typedef PJRT_Error* PJRT_FFI_Register_Handler(
|
||||||
|
+ PJRT_FFI_Register_Handler_Args* args);
|
||||||
|
+
|
||||||
|
typedef struct PJRT_FFI_Extension {
|
||||||
|
PJRT_Extension_Base base;
|
||||||
|
PJRT_FFI_TypeID_Register* type_id_register;
|
||||||
|
PJRT_FFI_UserData_Add* user_data_add;
|
||||||
|
+ PJRT_FFI_Register_Handler* register_handler;
|
||||||
|
} PJRT_FFI;
|
||||||
|
PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, user_data_add);
|
||||||
|
|
||||||
|
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
|
||||||
|
index 5fa88eab33..763270331b 100644
|
||||||
|
--- a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
|
||||||
|
+++ b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
|
||||||
|
@@ -13,16 +13,20 @@ See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
-#include "xla/pjrt/c/pjrt_c_api_ffi_internal.h"
|
||||||
|
+#include <string>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
+#include "absl/strings/str_format.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
+#include "xla/ffi/api/c_api.h"
|
||||||
|
#include "xla/ffi/execution_context.h"
|
||||||
|
#include "xla/ffi/type_id_registry.h"
|
||||||
|
+#include "xla/ffi/ffi_api.h"
|
||||||
|
#include "xla/pjrt/c/pjrt_c_api.h"
|
||||||
|
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
|
||||||
|
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
|
||||||
|
#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
|
||||||
|
+#include "xla/service/custom_call_target_registry.h"
|
||||||
|
|
||||||
|
namespace pjrt {
|
||||||
|
|
||||||
|
@@ -68,6 +72,31 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
+static PJRT_Error* PJRT_FFI_Register_Handler(
|
||||||
|
+ PJRT_FFI_Register_Handler_Args* args) {
|
||||||
|
+ PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
|
||||||
|
+ "PJRT_FFI_Register_Handler_Args",
|
||||||
|
+ PJRT_FFI_Register_Handler_Args_STRUCT_SIZE, args->struct_size));
|
||||||
|
+ std::string target_name(args->target_name, args->target_name_size);
|
||||||
|
+ std::string platform_name(args->platform_name, args->platform_name_size);
|
||||||
|
+ switch (args->api_version) {
|
||||||
|
+ case 0:
|
||||||
|
+ xla::CustomCallTargetRegistry::Global()->Register(
|
||||||
|
+ target_name, args->handler, platform_name);
|
||||||
|
+ return nullptr;
|
||||||
|
+ case 1:
|
||||||
|
+ xla::ffi::Ffi::RegisterStaticHandler(
|
||||||
|
+ xla::ffi::GetXlaFfiApi(), target_name, platform_name,
|
||||||
|
+ reinterpret_cast<XLA_FFI_Handler*>(args->handler));
|
||||||
|
+ return nullptr;
|
||||||
|
+ default:
|
||||||
|
+ return new PJRT_Error{absl::UnimplementedError(
|
||||||
|
+ absl::StrFormat("API version %d not supported for PJRT GPU plugin. "
|
||||||
|
+ "Supported versions are 0 and 1.",
|
||||||
|
+ args->api_version))};
|
||||||
|
+ }
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
|
||||||
|
return {
|
||||||
|
PJRT_Extension_Base{
|
||||||
|
@@ -77,6 +106,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
|
||||||
|
},
|
||||||
|
/*type_id_register=*/PJRT_FFI_TypeID_Register,
|
||||||
|
/*user_data_add=*/PJRT_FFI_UserData_Add,
|
||||||
|
+ /*register_handler=*/PJRT_FFI_Register_Handler,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
--
|
||||||
|
2.39.5 (Apple Git-154)
|
||||||
|
|
||||||
@ -0,0 +1,186 @@
|
|||||||
|
From 2a62bc5df9774810313142eb0c9390aab3cd18f8 Mon Sep 17 00:00:00 2001
|
||||||
|
From: Hugo Mano <hugo@zml.ai>
|
||||||
|
Date: Thu, 29 May 2025 08:02:32 +0200
|
||||||
|
Subject: [PATCH] Revert "Add optional allowOtherDialects field to
|
||||||
|
stablehlo.serialize_portable_artifact."
|
||||||
|
|
||||||
|
Commit: https://github.com/openxla/xla/commit/e7137a383809a24875a95237b1d1f6485acdf710
|
||||||
|
|
||||||
|
Issue: C does not support default arguments with ZigTranslateC
|
||||||
|
---
|
||||||
|
third_party/stablehlo/temporary.patch | 151 ++++----------------------
|
||||||
|
1 file changed, 21 insertions(+), 130 deletions(-)
|
||||||
|
|
||||||
|
diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch
|
||||||
|
index 6e1fd159f9..d17c141b18 100755
|
||||||
|
--- a/third_party/stablehlo/temporary.patch
|
||||||
|
+++ b/third_party/stablehlo/temporary.patch
|
||||||
|
@@ -1,3 +1,23 @@
|
||||||
|
+diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp
|
||||||
|
+--- stablehlo/stablehlo/dialect/StablehloOps.cpp
|
||||||
|
++++ stablehlo/stablehlo/dialect/StablehloOps.cpp
|
||||||
|
+@@ -511,12 +511,10 @@
|
||||||
|
+ void CustomCallOp::getEffects(
|
||||||
|
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>&
|
||||||
|
+ effects) {
|
||||||
|
+- // Note: `has_side_effect` "defaults" to `false` but isn't required to exist.
|
||||||
|
+- // This semantic contradiction means, in practical terms, that the attribute
|
||||||
|
+- // won't exist by default but should be *treated* as `false` if missing.
|
||||||
|
+- // `getHasSideEffect()` abstracts this nuance away and returns `false` by
|
||||||
|
+- // default, whereas `getHasSideEffectAttr()` may return a null attribute.
|
||||||
|
+- if (!getHasSideEffect()) return;
|
||||||
|
++ // CustomCall has "all possible effects" unless the has_side_effect is present
|
||||||
|
++ // and set to false.
|
||||||
|
++ auto hasSideEffect = getHasSideEffectAttr();
|
||||||
|
++ if (hasSideEffect && !hasSideEffect.getValue()) return;
|
||||||
|
+ effects.emplace_back(MemoryEffects::Allocate::get());
|
||||||
|
+ effects.emplace_back(MemoryEffects::Free::get());
|
||||||
|
+ effects.emplace_back(MemoryEffects::Write::get());
|
||||||
|
diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.h b/stablehlo/stablehlo/dialect/StablehloOps.h
|
||||||
|
--- stablehlo/stablehlo/dialect/StablehloOps.h
|
||||||
|
+++ stablehlo/stablehlo/dialect/StablehloOps.h
|
||||||
|
@@ -82,135 +102,6 @@ diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/d
|
||||||
|
]> {
|
||||||
|
let summary = "Recv operation";
|
||||||
|
let description = [{
|
||||||
|
-diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp
|
||||||
|
---- stablehlo/stablehlo/dialect/TypeInference.cpp
|
||||||
|
-+++ stablehlo/stablehlo/dialect/TypeInference.cpp
|
||||||
|
-@@ -879,7 +879,8 @@
|
||||||
|
-
|
||||||
|
- auto replicaIds = replicaGroups.getValues<int64_t>();
|
||||||
|
-
|
||||||
|
-- llvm::SmallSet<int64_t, 8> replicaIdsSeen;
|
||||||
|
-+ // Large programs can have many replicas, use a set with efficient lookup.
|
||||||
|
-+ llvm::DenseSet<int64_t> replicaIdsSeen;
|
||||||
|
- for (int64_t replicaId : replicaIds) {
|
||||||
|
- // Replica groups are stored in a 2D tensor. If the op supports non-uniform
|
||||||
|
- // groups, null replica IDs are stored as -1.
|
||||||
|
-@@ -1841,6 +1842,7 @@
|
||||||
|
- /*allGroupsMustHaveSameSize=*/true,
|
||||||
|
- /*useGlobalDeviceIds=*/false, splitCount)))
|
||||||
|
- return failure();
|
||||||
|
-+
|
||||||
|
- for (const Value& operand : operands) {
|
||||||
|
- auto operandType = cast<RankedTensorType>(operand.getType());
|
||||||
|
-
|
||||||
|
-@@ -3562,6 +3564,19 @@
|
||||||
|
- DenseIntElementsAttr replicaGroups,
|
||||||
|
- int64_t channelId, bool useGlobalDeviceIds,
|
||||||
|
- ValueRange results) {
|
||||||
|
-+ // all_gather_i3, all_gather_c2, all_gather_c4
|
||||||
|
-+ if (failed(verifyReplicaGroups(location, replicaGroups,
|
||||||
|
-+ /*allGroupsMustHaveSameSize=*/true,
|
||||||
|
-+ useGlobalDeviceIds,
|
||||||
|
-+ /*expectedGroupSize=*/std::nullopt)))
|
||||||
|
-+ return failure();
|
||||||
|
-+
|
||||||
|
-+ // all_gather_c5
|
||||||
|
-+ if (useGlobalDeviceIds && channelId < 0)
|
||||||
|
-+ return emitOptionalError(
|
||||||
|
-+ location,
|
||||||
|
-+ "channel_id cannot be negative when useGlobalDeviceIds is set");
|
||||||
|
-+
|
||||||
|
- for (const auto& [operand, result] : llvm::zip(operands, results)) {
|
||||||
|
- auto operandType = cast<RankedTensorType>(operand.getType());
|
||||||
|
- auto resultType = cast<RankedTensorType>(result.getType());
|
||||||
|
-@@ -3576,19 +3591,6 @@
|
||||||
|
- return emitOptionalError(
|
||||||
|
- location,
|
||||||
|
- "dimension size of operand at 'all_gather_dim' cannot be zero");
|
||||||
|
--
|
||||||
|
-- // all_gather_i3, all_gather_c2, all_gather_c4
|
||||||
|
-- if (failed(verifyReplicaGroups(location, replicaGroups,
|
||||||
|
-- /*allGroupsMustHaveSameSize=*/true,
|
||||||
|
-- useGlobalDeviceIds,
|
||||||
|
-- /*expectedGroupSize=*/std::nullopt)))
|
||||||
|
-- return failure();
|
||||||
|
--
|
||||||
|
-- // all_gather_c5
|
||||||
|
-- if (useGlobalDeviceIds && channelId < 0)
|
||||||
|
-- return emitOptionalError(
|
||||||
|
-- location,
|
||||||
|
-- "channel_id cannot be negative when useGlobalDeviceIds is set");
|
||||||
|
-
|
||||||
|
- // all_gather_c6
|
||||||
|
- if (resultType.getRank() != operandType.getRank())
|
||||||
|
-@@ -3788,7 +3790,7 @@
|
||||||
|
- "but instead it is of rank ", replicaGroupType.getRank());
|
||||||
|
-
|
||||||
|
- auto replicaIds = replicaGroups.getValues<int64_t>();
|
||||||
|
-- llvm::SmallSet<int64_t, 8> replicaIdsSeen;
|
||||||
|
-+ llvm::DenseSet<int64_t> replicaIdsSeen;
|
||||||
|
- for (int64_t replicaId : replicaIds) {
|
||||||
|
- // collective_broadcast_c2
|
||||||
|
- // We only check that is is not negative, as it is impossible
|
||||||
|
-diff --ruN a/stablehlo/stablehlo/integrations/c/StablehloDialectApi.cpp b/stablehlo/stablehlo/integrations/c/StablehloDialectApi.cpp
|
||||||
|
---- stablehlo/stablehlo/integrations/c/StablehloDialectApi.cpp
|
||||||
|
-+++ stablehlo/stablehlo/integrations/c/StablehloDialectApi.cpp
|
||||||
|
-@@ -78,10 +78,11 @@
|
||||||
|
-
|
||||||
|
- MlirLogicalResult stablehloSerializePortableArtifactFromModule(
|
||||||
|
- MlirModule moduleStr, MlirStringRef targetVersion,
|
||||||
|
-- MlirStringCallback callback, void *userData) {
|
||||||
|
-+ MlirStringCallback callback, void *userData, bool allowOtherDialects) {
|
||||||
|
- mlir::detail::CallbackOstream stream(callback, userData);
|
||||||
|
- if (failed(mlir::stablehlo::serializePortableArtifact(
|
||||||
|
-- unwrap(moduleStr), unwrap(targetVersion), stream)))
|
||||||
|
-+ unwrap(moduleStr), unwrap(targetVersion), stream,
|
||||||
|
-+ allowOtherDialects)))
|
||||||
|
- return mlirLogicalResultFailure();
|
||||||
|
- return mlirLogicalResultSuccess();
|
||||||
|
- }
|
||||||
|
-diff --ruN a/stablehlo/stablehlo/integrations/c/StablehloDialectApi.h b/stablehlo/stablehlo/integrations/c/StablehloDialectApi.h
|
||||||
|
---- stablehlo/stablehlo/integrations/c/StablehloDialectApi.h
|
||||||
|
-+++ stablehlo/stablehlo/integrations/c/StablehloDialectApi.h
|
||||||
|
-@@ -92,7 +92,8 @@
|
||||||
|
- stablehloSerializePortableArtifactFromModule(MlirModule moduleStr,
|
||||||
|
- MlirStringRef targetVersion,
|
||||||
|
- MlirStringCallback callback,
|
||||||
|
-- void* userData);
|
||||||
|
-+ void* userData,
|
||||||
|
-+ bool allowOtherDialects = false);
|
||||||
|
-
|
||||||
|
- // Read a StableHLO program from a portable artifact, returning the module as
|
||||||
|
- // MLIR bytecode. Note, this bytecode returned is not a portable artifact,
|
||||||
|
-diff --ruN a/stablehlo/stablehlo/integrations/python/StablehloApi.cpp b/stablehlo/stablehlo/integrations/python/StablehloApi.cpp
|
||||||
|
---- stablehlo/stablehlo/integrations/python/StablehloApi.cpp
|
||||||
|
-+++ stablehlo/stablehlo/integrations/python/StablehloApi.cpp
|
||||||
|
-@@ -102,20 +102,22 @@
|
||||||
|
- //
|
||||||
|
- m.def(
|
||||||
|
- "serialize_portable_artifact",
|
||||||
|
-- [](MlirModule module, std::string_view target) -> nb::bytes {
|
||||||
|
-+ [](MlirModule module, std::string_view target,
|
||||||
|
-+ bool allowOtherDialects) -> nb::bytes {
|
||||||
|
- StringWriterHelper accumulator;
|
||||||
|
- if (mlirLogicalResultIsFailure(
|
||||||
|
- stablehloSerializePortableArtifactFromModule(
|
||||||
|
- module, toMlirStringRef(target),
|
||||||
|
- accumulator.getMlirStringCallback(),
|
||||||
|
-- accumulator.getUserData()))) {
|
||||||
|
-+ accumulator.getUserData(), allowOtherDialects))) {
|
||||||
|
- throw nb::value_error("failed to serialize module");
|
||||||
|
- }
|
||||||
|
-
|
||||||
|
- std::string serialized = accumulator.toString();
|
||||||
|
- return nb::bytes(serialized.data(), serialized.size());
|
||||||
|
- },
|
||||||
|
-- nb::arg("module"), nb::arg("target"));
|
||||||
|
-+ nb::arg("module"), nb::arg("target"),
|
||||||
|
-+ nb::arg("allow_other_dialects") = false);
|
||||||
|
-
|
||||||
|
- m.def(
|
||||||
|
- "deserialize_portable_artifact",
|
||||||
|
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.mlir
|
||||||
|
--- stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.mlir
|
||||||
|
+++ stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.mlir
|
||||||
|
@@ -223,4 +114,4 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_convert_to_signless.
|
||||||
|
%7 = builtin.unrealized_conversion_cast %6 : memref<i16> to memref<ui16>
|
||||||
|
func.return %7 : memref<ui16>
|
||||||
|
}
|
||||||
|
-
|
||||||
|
+
|
||||||
|
--
|
||||||
|
2.39.5 (Apple Git-154)
|
||||||
|
|
||||||
17
third_party/modules/xla/20250527.0-cb67f2f/source.json
vendored
Normal file
17
third_party/modules/xla/20250527.0-cb67f2f/source.json
vendored
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
{
|
||||||
|
"strip_prefix": "xla-cb67f2f7ce4787f63f5fc80dc5c30cd3dee8f4e3",
|
||||||
|
"url": "https://github.com/openxla/xla/archive/cb67f2f7ce4787f63f5fc80dc5c30cd3dee8f4e3.tar.gz",
|
||||||
|
"integrity": "sha256-SDAAOY6cjcCQ5e1JMob5G4+BYHk7spDb5zZEDrVeA4I=",
|
||||||
|
"overlay": {
|
||||||
|
"llvm.bzl": "",
|
||||||
|
"MODULE.bazel": "",
|
||||||
|
"workspace_private.bzl": "",
|
||||||
|
"xla.bzl": ""
|
||||||
|
},
|
||||||
|
"patch_strip": 1,
|
||||||
|
"patches": {
|
||||||
|
"0001-bazel-migration-to-bazel-8.1.1.patch": "",
|
||||||
|
"0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch": "",
|
||||||
|
"0003-Revert-Add-optional-allowOtherDialects-field-to-stab.patch": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
3
third_party/modules/xla/metadata.json
vendored
3
third_party/modules/xla/metadata.json
vendored
@ -20,7 +20,8 @@
|
|||||||
"20250204-0-6789523",
|
"20250204-0-6789523",
|
||||||
"20250317.0-71c67e2",
|
"20250317.0-71c67e2",
|
||||||
"20250317.1-71c67e2",
|
"20250317.1-71c67e2",
|
||||||
"20250317.2-71c67e2"
|
"20250317.2-71c67e2",
|
||||||
|
"20250527.0-cb67f2f"
|
||||||
],
|
],
|
||||||
"yanked_versions": {}
|
"yanked_versions": {}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -44,7 +44,7 @@ zig_library(
|
|||||||
zig_proto_library(
|
zig_proto_library(
|
||||||
name = "xla_proto",
|
name = "xla_proto",
|
||||||
import_name = "//xla:xla_proto",
|
import_name = "//xla:xla_proto",
|
||||||
deps = ["@xla//xla/pjrt:compile_options_proto"],
|
deps = ["@xla//xla/pjrt/proto:compile_options_proto"],
|
||||||
)
|
)
|
||||||
|
|
||||||
zig_proto_library(
|
zig_proto_library(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user