Workspace: bump XLA to newer version.
This commit is contained in:
parent
9ef838be25
commit
1cafcc3c60
@ -22,7 +22,7 @@ bazel_dep(name = "sentencepiece", version = "20240618.0-d7ace0a")
|
||||
bazel_dep(name = "toolchains_llvm_bootstrapped", version = "0.2.3")
|
||||
bazel_dep(name = "toolchains_protoc", version = "0.4.1")
|
||||
bazel_dep(name = "with_cfg.bzl", version = "0.9.1")
|
||||
bazel_dep(name = "xla", version = "20250527.0-cb67f2f")
|
||||
bazel_dep(name = "xla", version = "20250612.0-6e48cbb")
|
||||
bazel_dep(name = "zig-protobuf", version = "20250318.0-930153e")
|
||||
bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf")
|
||||
|
||||
|
||||
@ -25,22 +25,22 @@ def _cpu_pjrt_plugin_impl(mctx):
|
||||
http_archive(
|
||||
name = "libpjrt_cpu_linux_amd64",
|
||||
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_LINUX,
|
||||
sha256 = "ca92bccefa168881f98d01354971d6f598381cc4c5f07b161a0908d327610b66",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v9.0.1/pjrt-cpu_linux-amd64.tar.gz",
|
||||
sha256 = "4106ca11ab41bc9ec000d536ae084442139b5639ca329bfb62c7e0742acdc47a",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v10.0.0/pjrt-cpu_linux-amd64.tar.gz",
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "libpjrt_cpu_darwin_amd64",
|
||||
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN,
|
||||
sha256 = "b6d05b5cd0382a7bd8943b8df98dc229853e402488127895e47786395afb73a7",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v9.0.1/pjrt-cpu_darwin-amd64.tar.gz",
|
||||
sha256 = "7be4d98f0737601fba7b29563917054aac3d09365139e6d3f5f96023a8c71c87",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v10.0.0/pjrt-cpu_darwin-amd64.tar.gz",
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "libpjrt_cpu_darwin_arm64",
|
||||
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN,
|
||||
sha256 = "e1ac13cf80b0975eec1dc0643a6ec08001d6e07a6a0d500a38e1c4477f49a78c",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v9.0.1/pjrt-cpu_darwin-arm64.tar.gz",
|
||||
sha256 = "442cccd98d7adf4afe0f818ebba265baca6b68dea95b10ef2b4d4229b81d5412",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v10.0.0/pjrt-cpu_darwin-arm64.tar.gz",
|
||||
)
|
||||
|
||||
return mctx.extension_metadata(
|
||||
|
||||
@ -214,8 +214,8 @@ def _cuda_impl(mctx):
|
||||
http_archive(
|
||||
name = "libpjrt_cuda",
|
||||
build_file = "libpjrt_cuda.BUILD.bazel",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v9.0.1/pjrt-cuda_linux-amd64.tar.gz",
|
||||
sha256 = "2ae18dacd9762e0ae89f223764b1793f8a4d7bd7238bfcd84d2342d7fb37a106",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v10.0.0/pjrt-cuda_linux-amd64.tar.gz",
|
||||
sha256 = "eddf4db325aaeb1692e9eff1b5021dbeda27c08e527cae87295a61d94e654395",
|
||||
)
|
||||
|
||||
return mctx.extension_metadata(
|
||||
|
||||
@ -127,8 +127,8 @@ def _rocm_impl(mctx):
|
||||
http_archive(
|
||||
name = "libpjrt_rocm",
|
||||
build_file = "libpjrt_rocm.BUILD.bazel",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v9.0.1/pjrt-rocm_linux-amd64.tar.gz",
|
||||
sha256 = "31223c61645e6a3966841be6ebbc8c56609835a792c75ad1e1442fd5afed759b",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v10.0.0/pjrt-rocm_linux-amd64.tar.gz",
|
||||
sha256 = "ce5badf1ba5d1073a7de1e4d1d2a97fd1b66876d1fa255f913ffd410f50e6bc5",
|
||||
)
|
||||
|
||||
return mctx.extension_metadata(
|
||||
|
||||
52
third_party/modules/xla/20250612.0-6e48cbb/MODULE.bazel
vendored
Normal file
52
third_party/modules/xla/20250612.0-6e48cbb/MODULE.bazel
vendored
Normal file
@ -0,0 +1,52 @@
|
||||
module(
|
||||
name = "xla",
|
||||
version = "20250612.0-6e48cbb",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "platforms", version = "0.0.8")
|
||||
bazel_dep(name = "bazel_skylib", version = "1.5.0")
|
||||
bazel_dep(name = "rules_cc", version = "0.0.17")
|
||||
bazel_dep(name = "rules_apple", version = "3.22.0", repo_name = "build_bazel_rules_apple")
|
||||
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
|
||||
bazel_dep(name = "rules_python", version = "0.39.0")
|
||||
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||
bazel_dep(name = "rules_java", version = "7.3.2")
|
||||
bazel_dep(name = "rules_pkg", version = "0.9.1")
|
||||
bazel_dep(name = "zlib", version = "1.2.13")
|
||||
bazel_dep(name = "re2", version = "2024-07-02.bcr.1", repo_name = "com_googlesource_code_re2")
|
||||
bazel_dep(name = "rules_license", version = "0.0.8")
|
||||
bazel_dep(name = "rules_shell", version = "0.4.1")
|
||||
bazel_dep(name = "bazel_features", version = "1.25.0", repo_name = "proto_bazel_features")
|
||||
|
||||
workspace_private = use_extension("//:workspace_private.bzl", "workspace_private")
|
||||
use_repo(
|
||||
workspace_private,
|
||||
"com_github_grpc_grpc",
|
||||
"com_google_protobuf",
|
||||
"local_config_cuda",
|
||||
"local_config_remote_execution",
|
||||
"local_config_rocm",
|
||||
"local_config_tensorrt",
|
||||
"python_version_repo",
|
||||
"tsl",
|
||||
)
|
||||
|
||||
workspace_public = use_extension("//:xla.bzl", "xla")
|
||||
use_repo(
|
||||
workspace_public,
|
||||
"llvm-raw",
|
||||
"stablehlo",
|
||||
"triton",
|
||||
)
|
||||
|
||||
llvm = use_extension("//:llvm.bzl", "llvm")
|
||||
llvm.configure(
|
||||
targets = [
|
||||
"AArch64",
|
||||
"AMDGPU",
|
||||
"NVPTX",
|
||||
"X86",
|
||||
],
|
||||
)
|
||||
use_repo(llvm, "llvm-project")
|
||||
52
third_party/modules/xla/20250612.0-6e48cbb/overlay/MODULE.bazel
vendored
Normal file
52
third_party/modules/xla/20250612.0-6e48cbb/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,52 @@
|
||||
module(
|
||||
name = "xla",
|
||||
version = "20250527.0-cb67f2f",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "platforms", version = "0.0.8")
|
||||
bazel_dep(name = "bazel_skylib", version = "1.5.0")
|
||||
bazel_dep(name = "rules_cc", version = "0.0.17")
|
||||
bazel_dep(name = "rules_apple", version = "3.22.0", repo_name = "build_bazel_rules_apple")
|
||||
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
|
||||
bazel_dep(name = "rules_python", version = "0.39.0")
|
||||
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||
bazel_dep(name = "rules_java", version = "7.3.2")
|
||||
bazel_dep(name = "rules_pkg", version = "0.9.1")
|
||||
bazel_dep(name = "zlib", version = "1.2.13")
|
||||
bazel_dep(name = "re2", version = "2024-07-02.bcr.1", repo_name = "com_googlesource_code_re2")
|
||||
bazel_dep(name = "rules_license", version = "0.0.8")
|
||||
bazel_dep(name = "rules_shell", version = "0.4.1")
|
||||
bazel_dep(name = "bazel_features", version = "1.25.0", repo_name = "proto_bazel_features")
|
||||
|
||||
workspace_private = use_extension("//:workspace_private.bzl", "workspace_private")
|
||||
use_repo(
|
||||
workspace_private,
|
||||
"com_github_grpc_grpc",
|
||||
"com_google_protobuf",
|
||||
"local_config_cuda",
|
||||
"local_config_remote_execution",
|
||||
"local_config_rocm",
|
||||
"local_config_tensorrt",
|
||||
"python_version_repo",
|
||||
"tsl",
|
||||
)
|
||||
|
||||
workspace_public = use_extension("//:xla.bzl", "xla")
|
||||
use_repo(
|
||||
workspace_public,
|
||||
"llvm-raw",
|
||||
"stablehlo",
|
||||
"triton",
|
||||
)
|
||||
|
||||
llvm = use_extension("//:llvm.bzl", "llvm")
|
||||
llvm.configure(
|
||||
targets = [
|
||||
"AArch64",
|
||||
"AMDGPU",
|
||||
"NVPTX",
|
||||
"X86",
|
||||
],
|
||||
)
|
||||
use_repo(llvm, "llvm-project")
|
||||
30
third_party/modules/xla/20250612.0-6e48cbb/overlay/llvm.bzl
vendored
Normal file
30
third_party/modules/xla/20250612.0-6e48cbb/overlay/llvm.bzl
vendored
Normal file
@ -0,0 +1,30 @@
|
||||
load("@llvm-raw//utils/bazel:configure.bzl", _llvm_configure = "llvm_configure")
|
||||
|
||||
def _llvm_impl(mctx):
|
||||
_targets = {}
|
||||
for mod in mctx.modules:
|
||||
for conf in mod.tags.configure:
|
||||
for target in conf.targets:
|
||||
_targets[target] = True
|
||||
_llvm_configure(
|
||||
name = "llvm-project",
|
||||
targets = _targets.keys(),
|
||||
)
|
||||
return mctx.extension_metadata(
|
||||
reproducible = True,
|
||||
root_module_direct_deps = "all",
|
||||
root_module_direct_dev_deps = [],
|
||||
)
|
||||
|
||||
llvm = module_extension(
|
||||
implementation = _llvm_impl,
|
||||
tag_classes = {
|
||||
"configure": tag_class(
|
||||
attrs = {
|
||||
"targets": attr.string_list(
|
||||
default = [],
|
||||
),
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
71
third_party/modules/xla/20250612.0-6e48cbb/overlay/workspace_private.bzl
vendored
Normal file
71
third_party/modules/xla/20250612.0-6e48cbb/overlay/workspace_private.bzl
vendored
Normal file
@ -0,0 +1,71 @@
|
||||
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
|
||||
load("//third_party/gpus:rocm_configure.bzl", "rocm_configure")
|
||||
load("//third_party/llvm:workspace.bzl", llvm = "repo")
|
||||
load("//third_party/py:python_repo.bzl", "python_repository")
|
||||
load("//third_party/pybind11_bazel:workspace.bzl", pybind11_bazel = "repo")
|
||||
load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo")
|
||||
load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
|
||||
load("//third_party/triton:workspace.bzl", triton = "repo")
|
||||
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
|
||||
load("//third_party:repo.bzl", "tf_vendored")
|
||||
load("//tools/toolchains/remote:configure.bzl", "remote_execution_configure")
|
||||
|
||||
def _workspace_private_impl(mctx):
|
||||
cuda_configure(name = "local_config_cuda")
|
||||
remote_execution_configure(name = "local_config_remote_execution")
|
||||
rocm_configure(name = "local_config_rocm")
|
||||
tensorrt_configure(name = "local_config_tensorrt")
|
||||
tf_vendored(name = "tsl", relpath = "third_party/tsl")
|
||||
pybind11_bazel()
|
||||
tf_http_archive(
|
||||
name = "com_github_grpc_grpc",
|
||||
sha256 = "b956598d8cbe168b5ee717b5dafa56563eb5201a947856a6688bbeac9cac4e1f",
|
||||
strip_prefix = "grpc-b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd",
|
||||
system_build_file = "//third_party/systemlibs:grpc.BUILD",
|
||||
patch_file = [
|
||||
"//third_party/grpc:generate_cc_env_fix.patch",
|
||||
"//third_party/grpc:register_go_toolchain.patch",
|
||||
],
|
||||
system_link_files = {
|
||||
"//third_party/systemlibs:BUILD.bazel": "bazel/BUILD.bazel",
|
||||
"//third_party/systemlibs:grpc.BUILD": "src/compiler/BUILD",
|
||||
"//third_party/systemlibs:grpc.bazel.grpc_deps.bzl": "bazel/grpc_deps.bzl",
|
||||
"//third_party/systemlibs:grpc.bazel.grpc_extra_deps.bzl": "bazel/grpc_extra_deps.bzl",
|
||||
"//third_party/systemlibs:grpc.bazel.cc_grpc_library.bzl": "bazel/cc_grpc_library.bzl",
|
||||
"//third_party/systemlibs:grpc.bazel.generate_cc.bzl": "bazel/generate_cc.bzl",
|
||||
"//third_party/systemlibs:grpc.bazel.protobuf.bzl": "bazel/protobuf.bzl",
|
||||
},
|
||||
urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz"),
|
||||
)
|
||||
tf_http_archive(
|
||||
name = "com_google_protobuf",
|
||||
patch_file = ["//third_party/protobuf:protobuf.patch"],
|
||||
sha256 = "f66073dee0bc159157b0bd7f502d7d1ee0bc76b3c1eac9836927511bdc4b3fc1",
|
||||
strip_prefix = "protobuf-3.21.9",
|
||||
system_build_file = "//third_party/systemlibs:protobuf.BUILD",
|
||||
system_link_files = {
|
||||
"//third_party/systemlibs:protobuf.bzl": "protobuf.bzl",
|
||||
"//third_party/systemlibs:protobuf_deps.bzl": "protobuf_deps.bzl",
|
||||
},
|
||||
urls = tf_mirror_urls("https://github.com/protocolbuffers/protobuf/archive/v3.21.9.zip"),
|
||||
)
|
||||
python_repository(
|
||||
name = "python_version_repo",
|
||||
requirements_versions = ["3.11"],
|
||||
requirements_locks = ["//:requirements_lock_3_11.txt"],
|
||||
local_wheel_workspaces = [],
|
||||
local_wheel_dist_folder = None,
|
||||
default_python_version = None,
|
||||
local_wheel_inclusion_list = ["*"],
|
||||
local_wheel_exclusion_list = [],
|
||||
)
|
||||
|
||||
return mctx.extension_metadata(
|
||||
reproducible = True,
|
||||
root_module_direct_deps = "all",
|
||||
root_module_direct_dev_deps = [],
|
||||
)
|
||||
|
||||
workspace_private = module_extension(
|
||||
implementation = _workspace_private_impl,
|
||||
)
|
||||
17
third_party/modules/xla/20250612.0-6e48cbb/overlay/xla.bzl
vendored
Normal file
17
third_party/modules/xla/20250612.0-6e48cbb/overlay/xla.bzl
vendored
Normal file
@ -0,0 +1,17 @@
|
||||
load("//third_party/llvm:workspace.bzl", llvm = "repo")
|
||||
load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo")
|
||||
load("//third_party/triton:workspace.bzl", triton = "repo")
|
||||
|
||||
def _xla_impl(mctx):
|
||||
triton()
|
||||
llvm("llvm-raw")
|
||||
stablehlo()
|
||||
return mctx.extension_metadata(
|
||||
reproducible = True,
|
||||
root_module_direct_deps = "all",
|
||||
root_module_direct_dev_deps = [],
|
||||
)
|
||||
|
||||
xla = module_extension(
|
||||
implementation = _xla_impl,
|
||||
)
|
||||
41
third_party/modules/xla/20250612.0-6e48cbb/patches/0001-bazel-migration-to-bazel-8.1.1.patch
vendored
Normal file
41
third_party/modules/xla/20250612.0-6e48cbb/patches/0001-bazel-migration-to-bazel-8.1.1.patch
vendored
Normal file
@ -0,0 +1,41 @@
|
||||
From 6cf475b500521c1b8be06f590fdbc1818f0dc44b Mon Sep 17 00:00:00 2001
|
||||
From: Jean-Baptiste Dalido <jb@zml.ai>
|
||||
Date: Mon, 6 Jan 2025 13:33:13 +0100
|
||||
Subject: [PATCH] bazel: migration to bazel 8.0.1
|
||||
|
||||
---
|
||||
.bazelversion | 2 +-
|
||||
third_party/tsl/third_party/gpus/cuda_configure.bzl | 4 ++--
|
||||
2 files changed, 3 insertions(+), 3 deletions(-)
|
||||
|
||||
diff --git a/.bazelversion b/.bazelversion
|
||||
index f22d756da3..fa5fce04b3 100644
|
||||
--- a/.bazelversion
|
||||
+++ b/.bazelversion
|
||||
@@ -1 +1 @@
|
||||
-7.4.1
|
||||
+8.1.1
|
||||
\ No newline at end of file
|
||||
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
|
||||
index d62531152d..71d80a5a99 100644
|
||||
--- a/third_party/gpus/cuda_configure.bzl
|
||||
+++ b/third_party/gpus/cuda_configure.bzl
|
||||
@@ -33,14 +33,14 @@ NB: DEPRECATED! Use `hermetic/cuda_configure` rule instead.
|
||||
load(
|
||||
"@bazel_tools//tools/cpp:lib_cc_configure.bzl",
|
||||
"escape_string",
|
||||
- "get_env_var",
|
||||
)
|
||||
load(
|
||||
"@bazel_tools//tools/cpp:windows_cc_configure.bzl",
|
||||
- "find_msvc_tool",
|
||||
"find_vc_path",
|
||||
"setup_vc_env_vars",
|
||||
)
|
||||
+load("@rules_cc//cc/private/toolchain:windows_cc_configure.bzl", "find_msvc_tool")
|
||||
+load("@rules_cc//cc/private/toolchain:lib_cc_configure.bzl", "get_env_var")
|
||||
load("//third_party/clang_toolchain:download_clang.bzl", "download_clang")
|
||||
load(
|
||||
"//third_party/remote_config:common.bzl",
|
||||
--
|
||||
2.39.3 (Apple Git-146)
|
||||
@ -0,0 +1,135 @@
|
||||
From 2ae9bb9d24b569c2c6bfab3c54b428103614944d Mon Sep 17 00:00:00 2001
|
||||
From: Hugo Mano <hugo@zml.ai>
|
||||
Date: Tue, 27 May 2025 11:48:17 +0200
|
||||
Subject: [PATCH 1/8] Added FFI handler registration API to the FFI PjRt
|
||||
|
||||
PR: https://github.com/openxla/xla/pull/13420
|
||||
---
|
||||
xla/pjrt/c/BUILD | 5 +++++
|
||||
xla/pjrt/c/pjrt_c_api_ffi_extension.h | 21 ++++++++++++++++++
|
||||
xla/pjrt/c/pjrt_c_api_ffi_internal.cc | 32 ++++++++++++++++++++++++++-
|
||||
3 files changed, 57 insertions(+), 1 deletion(-)
|
||||
|
||||
diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD
|
||||
index 79f18fa0bc..0f33dd8a6e 100644
|
||||
--- a/xla/pjrt/c/BUILD
|
||||
+++ b/xla/pjrt/c/BUILD
|
||||
@@ -69,8 +69,13 @@ cc_library(
|
||||
":pjrt_c_api_wrapper_impl",
|
||||
"//xla/ffi:execution_context",
|
||||
"//xla/ffi:type_id_registry",
|
||||
+ "//xla/ffi:ffi_api",
|
||||
+ "//xla/ffi/api:c_api",
|
||||
+ "//xla/ffi/api:ffi",
|
||||
+ "//xla/service:custom_call_target_registry",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings:string_view",
|
||||
+ "@com_google_absl//absl/strings:str_format",
|
||||
],
|
||||
)
|
||||
|
||||
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
||||
index 995a2c7e50..b8f10bc2f7 100644
|
||||
--- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
||||
+++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
||||
@@ -69,10 +69,31 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data);
|
||||
// Adds a user data to the execute context.
|
||||
typedef PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args);
|
||||
|
||||
+typedef enum PJRT_FFI_Handler_TraitsBits {
|
||||
+ PJRT_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE = 1u << 0,
|
||||
+} PJRT_FFI_Handler_TraitsBits;
|
||||
+
|
||||
+struct PJRT_FFI_Register_Handler_Args {
|
||||
+ size_t struct_size;
|
||||
+ const char* target_name;
|
||||
+ size_t target_name_size;
|
||||
+ int api_version; // 0 for an untyped call, 1 -- for typed
|
||||
+ void* handler;
|
||||
+ const char* platform_name;
|
||||
+ size_t platform_name_size;
|
||||
+ PJRT_FFI_Handler_TraitsBits traits;
|
||||
+};
|
||||
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Register_Handler_Args, traits);
|
||||
+
|
||||
+// Registers an FFI call handler for a specific platform.
|
||||
+typedef PJRT_Error* PJRT_FFI_Register_Handler(
|
||||
+ PJRT_FFI_Register_Handler_Args* args);
|
||||
+
|
||||
typedef struct PJRT_FFI_Extension {
|
||||
PJRT_Extension_Base base;
|
||||
PJRT_FFI_TypeID_Register* type_id_register;
|
||||
PJRT_FFI_UserData_Add* user_data_add;
|
||||
+ PJRT_FFI_Register_Handler* register_handler;
|
||||
} PJRT_FFI;
|
||||
PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, user_data_add);
|
||||
|
||||
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
|
||||
index 5fa88eab33..763270331b 100644
|
||||
--- a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
|
||||
+++ b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
|
||||
@@ -13,16 +13,20 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
-#include "xla/pjrt/c/pjrt_c_api_ffi_internal.h"
|
||||
+#include <string>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
+#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
+#include "xla/ffi/api/c_api.h"
|
||||
#include "xla/ffi/execution_context.h"
|
||||
#include "xla/ffi/type_id_registry.h"
|
||||
+#include "xla/ffi/ffi_api.h"
|
||||
#include "xla/pjrt/c/pjrt_c_api.h"
|
||||
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
|
||||
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
|
||||
#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
|
||||
+#include "xla/service/custom_call_target_registry.h"
|
||||
|
||||
namespace pjrt {
|
||||
|
||||
@@ -68,6 +72,31 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
+static PJRT_Error* PJRT_FFI_Register_Handler(
|
||||
+ PJRT_FFI_Register_Handler_Args* args) {
|
||||
+ PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
|
||||
+ "PJRT_FFI_Register_Handler_Args",
|
||||
+ PJRT_FFI_Register_Handler_Args_STRUCT_SIZE, args->struct_size));
|
||||
+ std::string target_name(args->target_name, args->target_name_size);
|
||||
+ std::string platform_name(args->platform_name, args->platform_name_size);
|
||||
+ switch (args->api_version) {
|
||||
+ case 0:
|
||||
+ xla::CustomCallTargetRegistry::Global()->Register(
|
||||
+ target_name, args->handler, platform_name);
|
||||
+ return nullptr;
|
||||
+ case 1:
|
||||
+ xla::ffi::Ffi::RegisterStaticHandler(
|
||||
+ xla::ffi::GetXlaFfiApi(), target_name, platform_name,
|
||||
+ reinterpret_cast<XLA_FFI_Handler*>(args->handler));
|
||||
+ return nullptr;
|
||||
+ default:
|
||||
+ return new PJRT_Error{absl::UnimplementedError(
|
||||
+ absl::StrFormat("API version %d not supported for PJRT GPU plugin. "
|
||||
+ "Supported versions are 0 and 1.",
|
||||
+ args->api_version))};
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
|
||||
return {
|
||||
PJRT_Extension_Base{
|
||||
@@ -77,6 +106,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
|
||||
},
|
||||
/*type_id_register=*/PJRT_FFI_TypeID_Register,
|
||||
/*user_data_add=*/PJRT_FFI_UserData_Add,
|
||||
+ /*register_handler=*/PJRT_FFI_Register_Handler,
|
||||
};
|
||||
}
|
||||
|
||||
--
|
||||
2.39.5 (Apple Git-154)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
17
third_party/modules/xla/20250612.0-6e48cbb/source.json
vendored
Normal file
17
third_party/modules/xla/20250612.0-6e48cbb/source.json
vendored
Normal file
@ -0,0 +1,17 @@
|
||||
{
|
||||
"strip_prefix": "xla-6e48cbb8d33d771c964697e39bfaf678bcc6de31",
|
||||
"url": "https://github.com/openxla/xla/archive/6e48cbb8d33d771c964697e39bfaf678bcc6de31.tar.gz",
|
||||
"integrity": "sha256-i9lYvZ2MkzfyVW2Iu3qIucXIgGEhkbwsYXCrUZ6Yze8=",
|
||||
"overlay": {
|
||||
"llvm.bzl": "",
|
||||
"MODULE.bazel": "",
|
||||
"workspace_private.bzl": "",
|
||||
"xla.bzl": ""
|
||||
},
|
||||
"patch_strip": 1,
|
||||
"patches": {
|
||||
"0001-bazel-migration-to-bazel-8.1.1.patch": "",
|
||||
"0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch": "",
|
||||
"0003-Revert-Add-optional-allowOtherDialects-field-to-stab.patch": ""
|
||||
}
|
||||
}
|
||||
3
third_party/modules/xla/metadata.json
vendored
3
third_party/modules/xla/metadata.json
vendored
@ -21,7 +21,8 @@
|
||||
"20250317.0-71c67e2",
|
||||
"20250317.1-71c67e2",
|
||||
"20250317.2-71c67e2",
|
||||
"20250527.0-cb67f2f"
|
||||
"20250527.0-cb67f2f",
|
||||
"20250612.0-6e48cbb"
|
||||
],
|
||||
"yanked_versions": {}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user