xla: bump to commit b3fbfee, temporarily disable libnvptxcompiler due to missing support in PjRT CUDA plugin v13.0, add nvshmem to sandbox for PjRT CUDA plugin
This commit is contained in:
parent
f35119f768
commit
01da2184fe
@ -1308,7 +1308,6 @@ pub const FFI = extern struct {
|
|||||||
options: RegisterFfiOptions,
|
options: RegisterFfiOptions,
|
||||||
) ApiError!void {
|
) ApiError!void {
|
||||||
var ret = pjrtStruct(c.PJRT_FFI_Register_Handler_Args{
|
var ret = pjrtStruct(c.PJRT_FFI_Register_Handler_Args{
|
||||||
.api_version = 1,
|
|
||||||
.target_name = target_name.ptr,
|
.target_name = target_name.ptr,
|
||||||
.target_name_size = target_name.len,
|
.target_name_size = target_name.len,
|
||||||
.handler = @ptrCast(@constCast(func)),
|
.handler = @ptrCast(@constCast(func)),
|
||||||
|
|||||||
@ -23,22 +23,22 @@ def _cpu_pjrt_plugin_impl(mctx):
|
|||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cpu_linux_amd64",
|
name = "libpjrt_cpu_linux_amd64",
|
||||||
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_LINUX,
|
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_LINUX,
|
||||||
sha256 = "cf5ea44b14a6ddc320c5b1d7bb88328dda099571e106b17a214b6ec586e321b8",
|
sha256 = "124dc500291a5930f910ca23533520e22c90797110b29fd2c0d8274475f4a220",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v12.0.0/pjrt-cpu_linux-amd64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cpu_linux-amd64.tar.gz",
|
||||||
)
|
)
|
||||||
|
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cpu_darwin_amd64",
|
name = "libpjrt_cpu_darwin_amd64",
|
||||||
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN,
|
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN,
|
||||||
sha256 = "1b48c26b0d2709730df6921a86d9ae97ab71d3f2cfa17fd1a370e85235370914",
|
sha256 = "6e5b59874880f4db37c53fb1d52520d410b0078f9d2606a90762c6c622693c26",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v12.0.0/pjrt-cpu_darwin-amd64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cpu_darwin-amd64.tar.gz",
|
||||||
)
|
)
|
||||||
|
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cpu_darwin_arm64",
|
name = "libpjrt_cpu_darwin_arm64",
|
||||||
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN,
|
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN,
|
||||||
sha256 = "66dc6c65933a6d7985b1c69837d8abe13750cd1f06704427f6989c3f952d3511",
|
sha256 = "e54bc3c3b71313c49e38fc342851bffe2cd3a55e3990839947d49496a1a71270",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v12.0.0/pjrt-cpu_darwin-arm64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cpu_darwin-arm64.tar.gz",
|
||||||
)
|
)
|
||||||
|
|
||||||
return mctx.extension_metadata(
|
return mctx.extension_metadata(
|
||||||
|
|||||||
@ -17,6 +17,10 @@ CUDNN_REDIST_PREFIX = "https://developer.download.nvidia.com/compute/cudnn/redis
|
|||||||
CUDNN_VERSION = "9.8.0"
|
CUDNN_VERSION = "9.8.0"
|
||||||
CUDNN_REDIST_JSON_SHA256 = "a1599fa1f8dcb81235157be5de5ab7d3936e75dfc4e1e442d07970afad3c4843"
|
CUDNN_REDIST_JSON_SHA256 = "a1599fa1f8dcb81235157be5de5ab7d3936e75dfc4e1e442d07970afad3c4843"
|
||||||
|
|
||||||
|
NVSHMEM_REDIST_PREFIX = "https://developer.download.nvidia.com/compute/nvshmem/redist/"
|
||||||
|
NVSHMEM_VERSION = "3.2.5"
|
||||||
|
NVSHMEM_REDIST_JSON_SHA256 = "6945425d3bfd24de23c045996f93ec720c010379bfd6f0860ac5f2716659442d"
|
||||||
|
|
||||||
_UBUNTU_PACKAGES = {
|
_UBUNTU_PACKAGES = {
|
||||||
"zlib1g": packages.filegroup(name = "zlib1g", srcs = ["lib/x86_64-linux-gnu/libz.so.1"]),
|
"zlib1g": packages.filegroup(name = "zlib1g", srcs = ["lib/x86_64-linux-gnu/libz.so.1"]),
|
||||||
}
|
}
|
||||||
@ -119,6 +123,17 @@ CUDNN_PACKAGES = {
|
|||||||
]),
|
]),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NVSHMEM_PACKAGES = {
|
||||||
|
"libnvshmem": packages.filegroup(
|
||||||
|
name = "libnvshmem",
|
||||||
|
srcs = [
|
||||||
|
"lib/libnvshmem_host.so.3",
|
||||||
|
"lib/nvshmem_bootstrap_uid.so.3",
|
||||||
|
"lib/nvshmem_transport_ibrc.so.3",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
def _read_redist_json(mctx, url, sha256):
|
def _read_redist_json(mctx, url, sha256):
|
||||||
fname = ".{}.json".format(sha256)
|
fname = ".{}.json".format(sha256)
|
||||||
mctx.download(
|
mctx.download(
|
||||||
@ -138,6 +153,12 @@ def _cuda_impl(mctx):
|
|||||||
sha256 = CUDA_REDIST_JSON_SHA256,
|
sha256 = CUDA_REDIST_JSON_SHA256,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
NVSHMEM_REDIST = _read_redist_json(
|
||||||
|
mctx,
|
||||||
|
url = NVSHMEM_REDIST_PREFIX + "redistrib_{}.json".format(NVSHMEM_VERSION),
|
||||||
|
sha256 = NVSHMEM_REDIST_JSON_SHA256,
|
||||||
|
)
|
||||||
|
|
||||||
CUDNN_REDIST = _read_redist_json(
|
CUDNN_REDIST = _read_redist_json(
|
||||||
mctx,
|
mctx,
|
||||||
url = CUDNN_REDIST_PREFIX + "redistrib_{}.json".format(CUDNN_VERSION),
|
url = CUDNN_REDIST_PREFIX + "redistrib_{}.json".format(CUDNN_VERSION),
|
||||||
@ -180,6 +201,20 @@ def _cuda_impl(mctx):
|
|||||||
strip_prefix = paths.basename(arch_data["relative_path"]).replace(".tar.xz", ""),
|
strip_prefix = paths.basename(arch_data["relative_path"]).replace(".tar.xz", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for pkg, build_file_content in NVSHMEM_PACKAGES.items():
|
||||||
|
pkg_data = NVSHMEM_REDIST[pkg]
|
||||||
|
arch_data = pkg_data.get(ARCH)
|
||||||
|
if not arch_data:
|
||||||
|
continue
|
||||||
|
arch_data = arch_data.get("cuda12", arch_data)
|
||||||
|
http_archive(
|
||||||
|
name = pkg,
|
||||||
|
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + build_file_content,
|
||||||
|
url = NVSHMEM_REDIST_PREFIX + arch_data["relative_path"],
|
||||||
|
sha256 = arch_data["sha256"],
|
||||||
|
strip_prefix = paths.basename(arch_data["relative_path"]).replace(".tar.xz", ""),
|
||||||
|
)
|
||||||
|
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "nccl",
|
name = "nccl",
|
||||||
urls = ["https://files.pythonhosted.org/packages/11/0c/8c78b7603f4e685624a3ea944940f1e75f36d71bd6504330511f4a0e1557/nvidia_nccl_cu12-2.25.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl"],
|
urls = ["https://files.pythonhosted.org/packages/11/0c/8c78b7603f4e685624a3ea944940f1e75f36d71bd6504330511f4a0e1557/nvidia_nccl_cu12-2.25.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl"],
|
||||||
@ -194,8 +229,8 @@ def _cuda_impl(mctx):
|
|||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cuda",
|
name = "libpjrt_cuda",
|
||||||
build_file = "libpjrt_cuda.BUILD.bazel",
|
build_file = "libpjrt_cuda.BUILD.bazel",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v11.0.0/pjrt-cuda_linux-amd64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cuda_linux-amd64.tar.gz",
|
||||||
sha256 = "08fa022a6067ddfb5c951bdf11ddc398e63de21fdcacc9ffd07f70b1463482c2",
|
sha256 = "6cdac9bac6db904e4423c9745c61000cf3acaf3c7da8016ab0016f076869048a",
|
||||||
)
|
)
|
||||||
|
|
||||||
return mctx.extension_metadata(
|
return mctx.extension_metadata(
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory")
|
load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory")
|
||||||
load("@zml//bazel:cc_import.bzl", "cc_import")
|
|
||||||
load("@zml//bazel:patchelf.bzl", "patchelf")
|
load("@zml//bazel:patchelf.bzl", "patchelf")
|
||||||
|
|
||||||
cc_shared_library(
|
cc_shared_library(
|
||||||
@ -30,6 +29,7 @@ copy_to_directory(
|
|||||||
"@cuda_nvrtc",
|
"@cuda_nvrtc",
|
||||||
"@cuda_nvtx",
|
"@cuda_nvtx",
|
||||||
"@cudnn",
|
"@cudnn",
|
||||||
|
"@libnvshmem",
|
||||||
"@libcublas",
|
"@libcublas",
|
||||||
"@libcufft",
|
"@libcufft",
|
||||||
"@libcusolver",
|
"@libcusolver",
|
||||||
|
|||||||
@ -153,8 +153,8 @@ def _rocm_impl(mctx):
|
|||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_rocm",
|
name = "libpjrt_rocm",
|
||||||
build_file = "libpjrt_rocm.BUILD.bazel",
|
build_file = "libpjrt_rocm.BUILD.bazel",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v12.0.0/pjrt-rocm_linux-amd64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-rocm_linux-amd64.tar.gz",
|
||||||
sha256 = "709982b959595750545a01d125adf4893c42f05c60ec290425276bba8aa49f64",
|
sha256 = "945c43c68325c0e91cd41eaa594a9f9f6e78da7cc06892d83bf345b69f7bd714",
|
||||||
)
|
)
|
||||||
|
|
||||||
return mctx.extension_metadata(
|
return mctx.extension_metadata(
|
||||||
|
|||||||
@ -4,9 +4,9 @@ def _tpu_impl(mctx):
|
|||||||
# https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
# https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_tpu",
|
name = "libpjrt_tpu",
|
||||||
url = "https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20250102+nightly-py3-none-linux_x86_64.whl",
|
url = "https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20250807+nightly-py3-none-manylinux_2_31_x86_64.whl",
|
||||||
type = "zip",
|
type = "zip",
|
||||||
sha256 = "df8339b4f852bd19ad4ed380facc08f28c04e214e9dabb88863e70907b08817e",
|
sha256 = "41c19fa5ae4a32fbd05f0260527ba2d93afb6cf128e6c4de7773e9011c7b3df5",
|
||||||
build_file = "libpjrt_tpu.BUILD.bazel",
|
build_file = "libpjrt_tpu.BUILD.bazel",
|
||||||
)
|
)
|
||||||
return mctx.extension_metadata(
|
return mctx.extension_metadata(
|
||||||
|
|||||||
@ -1,41 +0,0 @@
|
|||||||
From 6cf475b500521c1b8be06f590fdbc1818f0dc44b Mon Sep 17 00:00:00 2001
|
|
||||||
From: Jean-Baptiste Dalido <jb@zml.ai>
|
|
||||||
Date: Mon, 6 Jan 2025 13:33:13 +0100
|
|
||||||
Subject: [PATCH] bazel: migration to bazel 8.0.1
|
|
||||||
|
|
||||||
---
|
|
||||||
.bazelversion | 2 +-
|
|
||||||
third_party/tsl/third_party/gpus/cuda_configure.bzl | 4 ++--
|
|
||||||
2 files changed, 3 insertions(+), 3 deletions(-)
|
|
||||||
|
|
||||||
diff --git a/.bazelversion b/.bazelversion
|
|
||||||
index f22d756da3..fa5fce04b3 100644
|
|
||||||
--- a/.bazelversion
|
|
||||||
+++ b/.bazelversion
|
|
||||||
@@ -1 +1 @@
|
|
||||||
-7.4.1
|
|
||||||
+8.1.1
|
|
||||||
\ No newline at end of file
|
|
||||||
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
|
|
||||||
index d62531152d..71d80a5a99 100644
|
|
||||||
--- a/third_party/gpus/cuda_configure.bzl
|
|
||||||
+++ b/third_party/gpus/cuda_configure.bzl
|
|
||||||
@@ -33,14 +33,14 @@ NB: DEPRECATED! Use `hermetic/cuda_configure` rule instead.
|
|
||||||
load(
|
|
||||||
"@bazel_tools//tools/cpp:lib_cc_configure.bzl",
|
|
||||||
"escape_string",
|
|
||||||
- "get_env_var",
|
|
||||||
)
|
|
||||||
load(
|
|
||||||
"@bazel_tools//tools/cpp:windows_cc_configure.bzl",
|
|
||||||
- "find_msvc_tool",
|
|
||||||
"find_vc_path",
|
|
||||||
"setup_vc_env_vars",
|
|
||||||
)
|
|
||||||
+load("@rules_cc//cc/private/toolchain:windows_cc_configure.bzl", "find_msvc_tool")
|
|
||||||
+load("@rules_cc//cc/private/toolchain:lib_cc_configure.bzl", "get_env_var")
|
|
||||||
load("//third_party/clang_toolchain:download_clang.bzl", "download_clang")
|
|
||||||
load(
|
|
||||||
"//third_party/remote_config:common.bzl",
|
|
||||||
--
|
|
||||||
2.39.3 (Apple Git-146)
|
|
||||||
@ -1,135 +0,0 @@
|
|||||||
From 2ae9bb9d24b569c2c6bfab3c54b428103614944d Mon Sep 17 00:00:00 2001
|
|
||||||
From: Hugo Mano <hugo@zml.ai>
|
|
||||||
Date: Tue, 27 May 2025 11:48:17 +0200
|
|
||||||
Subject: [PATCH 1/8] Added FFI handler registration API to the FFI PjRt
|
|
||||||
|
|
||||||
PR: https://github.com/openxla/xla/pull/13420
|
|
||||||
---
|
|
||||||
xla/pjrt/c/BUILD | 5 +++++
|
|
||||||
xla/pjrt/c/pjrt_c_api_ffi_extension.h | 21 ++++++++++++++++++
|
|
||||||
xla/pjrt/c/pjrt_c_api_ffi_internal.cc | 32 ++++++++++++++++++++++++++-
|
|
||||||
3 files changed, 57 insertions(+), 1 deletion(-)
|
|
||||||
|
|
||||||
diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD
|
|
||||||
index 79f18fa0bc..0f33dd8a6e 100644
|
|
||||||
--- a/xla/pjrt/c/BUILD
|
|
||||||
+++ b/xla/pjrt/c/BUILD
|
|
||||||
@@ -69,8 +69,13 @@ cc_library(
|
|
||||||
":pjrt_c_api_wrapper_impl",
|
|
||||||
"//xla/ffi:execution_context",
|
|
||||||
"//xla/ffi:type_id_registry",
|
|
||||||
+ "//xla/ffi:ffi_api",
|
|
||||||
+ "//xla/ffi/api:c_api",
|
|
||||||
+ "//xla/ffi/api:ffi",
|
|
||||||
+ "//xla/service:custom_call_target_registry",
|
|
||||||
"@com_google_absl//absl/status",
|
|
||||||
"@com_google_absl//absl/strings:string_view",
|
|
||||||
+ "@com_google_absl//absl/strings:str_format",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
|
||||||
index 995a2c7e50..b8f10bc2f7 100644
|
|
||||||
--- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
|
||||||
+++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
|
||||||
@@ -69,10 +69,31 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data);
|
|
||||||
// Adds a user data to the execute context.
|
|
||||||
typedef PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args);
|
|
||||||
|
|
||||||
+typedef enum PJRT_FFI_Handler_TraitsBits {
|
|
||||||
+ PJRT_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE = 1u << 0,
|
|
||||||
+} PJRT_FFI_Handler_TraitsBits;
|
|
||||||
+
|
|
||||||
+struct PJRT_FFI_Register_Handler_Args {
|
|
||||||
+ size_t struct_size;
|
|
||||||
+ const char* target_name;
|
|
||||||
+ size_t target_name_size;
|
|
||||||
+ int api_version; // 0 for an untyped call, 1 -- for typed
|
|
||||||
+ void* handler;
|
|
||||||
+ const char* platform_name;
|
|
||||||
+ size_t platform_name_size;
|
|
||||||
+ PJRT_FFI_Handler_TraitsBits traits;
|
|
||||||
+};
|
|
||||||
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Register_Handler_Args, traits);
|
|
||||||
+
|
|
||||||
+// Registers an FFI call handler for a specific platform.
|
|
||||||
+typedef PJRT_Error* PJRT_FFI_Register_Handler(
|
|
||||||
+ PJRT_FFI_Register_Handler_Args* args);
|
|
||||||
+
|
|
||||||
typedef struct PJRT_FFI_Extension {
|
|
||||||
PJRT_Extension_Base base;
|
|
||||||
PJRT_FFI_TypeID_Register* type_id_register;
|
|
||||||
PJRT_FFI_UserData_Add* user_data_add;
|
|
||||||
+ PJRT_FFI_Register_Handler* register_handler;
|
|
||||||
} PJRT_FFI;
|
|
||||||
PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, user_data_add);
|
|
||||||
|
|
||||||
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
|
|
||||||
index 5fa88eab33..763270331b 100644
|
|
||||||
--- a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
|
|
||||||
+++ b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
|
|
||||||
@@ -13,16 +13,20 @@ See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
-#include "xla/pjrt/c/pjrt_c_api_ffi_internal.h"
|
|
||||||
+#include <string>
|
|
||||||
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
+#include "absl/strings/str_format.h"
|
|
||||||
#include "absl/strings/string_view.h"
|
|
||||||
+#include "xla/ffi/api/c_api.h"
|
|
||||||
#include "xla/ffi/execution_context.h"
|
|
||||||
#include "xla/ffi/type_id_registry.h"
|
|
||||||
+#include "xla/ffi/ffi_api.h"
|
|
||||||
#include "xla/pjrt/c/pjrt_c_api.h"
|
|
||||||
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
|
|
||||||
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
|
|
||||||
#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
|
|
||||||
+#include "xla/service/custom_call_target_registry.h"
|
|
||||||
|
|
||||||
namespace pjrt {
|
|
||||||
|
|
||||||
@@ -68,6 +72,31 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
+static PJRT_Error* PJRT_FFI_Register_Handler(
|
|
||||||
+ PJRT_FFI_Register_Handler_Args* args) {
|
|
||||||
+ PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
|
|
||||||
+ "PJRT_FFI_Register_Handler_Args",
|
|
||||||
+ PJRT_FFI_Register_Handler_Args_STRUCT_SIZE, args->struct_size));
|
|
||||||
+ std::string target_name(args->target_name, args->target_name_size);
|
|
||||||
+ std::string platform_name(args->platform_name, args->platform_name_size);
|
|
||||||
+ switch (args->api_version) {
|
|
||||||
+ case 0:
|
|
||||||
+ xla::CustomCallTargetRegistry::Global()->Register(
|
|
||||||
+ target_name, args->handler, platform_name);
|
|
||||||
+ return nullptr;
|
|
||||||
+ case 1:
|
|
||||||
+ xla::ffi::Ffi::RegisterStaticHandler(
|
|
||||||
+ xla::ffi::GetXlaFfiApi(), target_name, platform_name,
|
|
||||||
+ reinterpret_cast<XLA_FFI_Handler*>(args->handler));
|
|
||||||
+ return nullptr;
|
|
||||||
+ default:
|
|
||||||
+ return new PJRT_Error{absl::UnimplementedError(
|
|
||||||
+ absl::StrFormat("API version %d not supported for PJRT GPU plugin. "
|
|
||||||
+ "Supported versions are 0 and 1.",
|
|
||||||
+ args->api_version))};
|
|
||||||
+ }
|
|
||||||
+}
|
|
||||||
+
|
|
||||||
PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
|
|
||||||
return {
|
|
||||||
PJRT_Extension_Base{
|
|
||||||
@@ -77,6 +106,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
|
|
||||||
},
|
|
||||||
/*type_id_register=*/PJRT_FFI_TypeID_Register,
|
|
||||||
/*user_data_add=*/PJRT_FFI_UserData_Add,
|
|
||||||
+ /*register_handler=*/PJRT_FFI_Register_Handler,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
--
|
|
||||||
2.39.5 (Apple Git-154)
|
|
||||||
|
|
||||||
@ -1,124 +0,0 @@
|
|||||||
From 6078da86a46b6f0d983dccb9ae4f36fc90640247 Mon Sep 17 00:00:00 2001
|
|
||||||
From: Hugo Mano <hugo@zml.ai>
|
|
||||||
Date: Fri, 11 Jul 2025 14:05:16 +0200
|
|
||||||
Subject: [PATCH] zml patch
|
|
||||||
|
|
||||||
---
|
|
||||||
third_party/stablehlo/workspace.bzl | 1 +
|
|
||||||
third_party/stablehlo/zml.patch | 93 +++++++++++++++++++++++++++++
|
|
||||||
2 files changed, 94 insertions(+)
|
|
||||||
create mode 100644 third_party/stablehlo/zml.patch
|
|
||||||
|
|
||||||
diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl
|
|
||||||
index d9d5063744..44980948d0 100644
|
|
||||||
--- a/third_party/stablehlo/workspace.bzl
|
|
||||||
+++ b/third_party/stablehlo/workspace.bzl
|
|
||||||
@@ -15,5 +15,6 @@ def repo():
|
|
||||||
urls = tf_mirror_urls("https://github.com/openxla/stablehlo/archive/{commit}.zip".format(commit = STABLEHLO_COMMIT)),
|
|
||||||
patch_file = [
|
|
||||||
"//third_party/stablehlo:temporary.patch", # Autogenerated, don't remove.
|
|
||||||
+ "//third_party/stablehlo:zml.patch", # Autogenerated, don't remove.
|
|
||||||
],
|
|
||||||
)
|
|
||||||
diff --git a/third_party/stablehlo/zml.patch b/third_party/stablehlo/zml.patch
|
|
||||||
new file mode 100644
|
|
||||||
index 0000000000..2a09384582
|
|
||||||
--- /dev/null
|
|
||||||
+++ b/third_party/stablehlo/zml.patch
|
|
||||||
@@ -0,0 +1,93 @@
|
|
||||||
+From e38ab68376dd8a17ebf4469d2c8350f521310182 Mon Sep 17 00:00:00 2001
|
|
||||||
+From: Hugo Mano <hugo@zml.ai>
|
|
||||||
+Date: Fri, 11 Jul 2025 12:08:35 +0200
|
|
||||||
+Subject: [PATCH] zml patch
|
|
||||||
+
|
|
||||||
+---
|
|
||||||
+ stablehlo/dialect/Serialization.cpp | 5 ++---
|
|
||||||
+ stablehlo/dialect/Serialization.h | 3 +--
|
|
||||||
+ stablehlo/integrations/c/StablehloDialectApi.cpp | 3 +--
|
|
||||||
+ stablehlo/integrations/c/StablehloDialectApi.h | 2 +-
|
|
||||||
+ stablehlo/tools/StablehloTranslateMain.cpp | 2 +-
|
|
||||||
+ 5 files changed, 6 insertions(+), 9 deletions(-)
|
|
||||||
+
|
|
||||||
+diff --git a/stablehlo/dialect/Serialization.cpp b/stablehlo/dialect/Serialization.cpp
|
|
||||||
+index cb89d673..4370d588 100644
|
|
||||||
+--- a/stablehlo/dialect/Serialization.cpp
|
|
||||||
++++ b/stablehlo/dialect/Serialization.cpp
|
|
||||||
+@@ -39,8 +39,7 @@ namespace stablehlo {
|
|
||||||
+
|
|
||||||
+ LogicalResult serializePortableArtifact(ModuleOp module,
|
|
||||||
+ StringRef targetVersion,
|
|
||||||
+- raw_ostream& os,
|
|
||||||
+- bool allowOtherDialects) {
|
|
||||||
++ raw_ostream& os) {
|
|
||||||
+ MLIRContext* context = module.getContext();
|
|
||||||
+
|
|
||||||
+ // Convert StableHLO --> VHLO.
|
|
||||||
+@@ -49,7 +48,7 @@ LogicalResult serializePortableArtifact(ModuleOp module,
|
|
||||||
+ {
|
|
||||||
+ PassManager pm(context);
|
|
||||||
+ StablehloLegalizeToVhloPassOptions options;
|
|
||||||
+- options.allowOtherDialects = allowOtherDialects;
|
|
||||||
++ options.allowOtherDialects = false;
|
|
||||||
+ pm.addPass(stablehlo::createStablehloLegalizeToVhloPass(options));
|
|
||||||
+ if (!succeeded(pm.run(module))) {
|
|
||||||
+ return failure();
|
|
||||||
+diff --git a/stablehlo/dialect/Serialization.h b/stablehlo/dialect/Serialization.h
|
|
||||||
+index 811ca97b..abe95e63 100644
|
|
||||||
+--- a/stablehlo/dialect/Serialization.h
|
|
||||||
++++ b/stablehlo/dialect/Serialization.h
|
|
||||||
+@@ -34,8 +34,7 @@ namespace stablehlo {
|
|
||||||
+ // unsupported dialects.
|
|
||||||
+ LogicalResult serializePortableArtifact(ModuleOp module,
|
|
||||||
+ StringRef targetVersion,
|
|
||||||
+- raw_ostream& os,
|
|
||||||
+- bool allowOtherDialects = false);
|
|
||||||
++ raw_ostream& os);
|
|
||||||
+
|
|
||||||
+ // Read StableHLO portable artifact
|
|
||||||
+ //
|
|
||||||
+diff --git a/stablehlo/integrations/c/StablehloDialectApi.cpp b/stablehlo/integrations/c/StablehloDialectApi.cpp
|
|
||||||
+index 343f8d0b..8f52e4d5 100644
|
|
||||||
+--- a/stablehlo/integrations/c/StablehloDialectApi.cpp
|
|
||||||
++++ b/stablehlo/integrations/c/StablehloDialectApi.cpp
|
|
||||||
+@@ -81,8 +81,7 @@ MlirLogicalResult stablehloSerializePortableArtifactFromModule(
|
|
||||||
+ MlirStringCallback callback, void *userData, bool allowOtherDialects) {
|
|
||||||
+ mlir::detail::CallbackOstream stream(callback, userData);
|
|
||||||
+ if (failed(mlir::stablehlo::serializePortableArtifact(
|
|
||||||
+- unwrap(moduleStr), unwrap(targetVersion), stream,
|
|
||||||
+- allowOtherDialects)))
|
|
||||||
++ unwrap(moduleStr), unwrap(targetVersion), stream)))
|
|
||||||
+ return mlirLogicalResultFailure();
|
|
||||||
+ return mlirLogicalResultSuccess();
|
|
||||||
+ }
|
|
||||||
+diff --git a/stablehlo/integrations/c/StablehloDialectApi.h b/stablehlo/integrations/c/StablehloDialectApi.h
|
|
||||||
+index 385156bf..24d11c1d 100644
|
|
||||||
+--- a/stablehlo/integrations/c/StablehloDialectApi.h
|
|
||||||
++++ b/stablehlo/integrations/c/StablehloDialectApi.h
|
|
||||||
+@@ -93,7 +93,7 @@ stablehloSerializePortableArtifactFromModule(MlirModule moduleStr,
|
|
||||||
+ MlirStringRef targetVersion,
|
|
||||||
+ MlirStringCallback callback,
|
|
||||||
+ void* userData,
|
|
||||||
+- bool allowOtherDialects = false);
|
|
||||||
++ bool allowOtherDialects);
|
|
||||||
+
|
|
||||||
+ // Read a StableHLO program from a portable artifact, returning the module as
|
|
||||||
+ // MLIR bytecode. Note, this bytecode returned is not a portable artifact,
|
|
||||||
+diff --git a/stablehlo/tools/StablehloTranslateMain.cpp b/stablehlo/tools/StablehloTranslateMain.cpp
|
|
||||||
+index fdf0d6a9..8d5c8752 100644
|
|
||||||
+--- a/stablehlo/tools/StablehloTranslateMain.cpp
|
|
||||||
++++ b/stablehlo/tools/StablehloTranslateMain.cpp
|
|
||||||
+@@ -323,7 +323,7 @@ TranslateFromMLIRRegistration serializeRegistration(
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ return stablehlo::serializePortableArtifact(
|
|
||||||
+- module, targetVersion, os, allowOtherDialectsOption.getValue());
|
|
||||||
++ module, targetVersion, os);
|
|
||||||
+ },
|
|
||||||
+ [](DialectRegistry ®istry) {
|
|
||||||
+ mlir::registerAllDialects(registry);
|
|
||||||
+--
|
|
||||||
+2.39.5 (Apple Git-154)
|
|
||||||
+
|
|
||||||
--
|
|
||||||
2.39.5 (Apple Git-154)
|
|
||||||
|
|
||||||
6
third_party/xla/repo.bzl
vendored
6
third_party/xla/repo.bzl
vendored
@ -4,11 +4,9 @@ def repo():
|
|||||||
git_repository(
|
git_repository(
|
||||||
name = "xla",
|
name = "xla",
|
||||||
remote = "https://github.com/openxla/xla.git",
|
remote = "https://github.com/openxla/xla.git",
|
||||||
commit = "ef07e787ea1303fa2f8d8a175d24d434bfb84107",
|
commit = "b3fbfeeb076f2b536897180f4a274680ed9d52eb",
|
||||||
patch_args = ["-p1"],
|
patch_args = ["-p1"],
|
||||||
patches = [
|
patches = [
|
||||||
"//third_party/xla:patches/0001-bazel-migration-to-bazel-8.1.1.patch",
|
# patches live in the patches directory
|
||||||
"//third_party/xla:patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch",
|
|
||||||
"//third_party/xla:patches/0003-Remove-unconventional-C-code-in-headers.patch",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
10
third_party/xla/xla.bzl
vendored
10
third_party/xla/xla.bzl
vendored
@ -65,12 +65,10 @@ def _xla_impl(mctx):
|
|||||||
|
|
||||||
tf_http_archive(
|
tf_http_archive(
|
||||||
name = "com_github_grpc_grpc",
|
name = "com_github_grpc_grpc",
|
||||||
sha256 = "afbc5d78d6ba6d509cc6e264de0d49dcd7304db435cbf2d630385bacf49e066c",
|
sha256 = "dd6a2fa311ba8441bbefd2764c55b99136ff10f7ea42954be96006a2723d33fc",
|
||||||
strip_prefix = "grpc-1.68.2",
|
strip_prefix = "grpc-1.74.0",
|
||||||
patch_file = [
|
patch_file = ["//third_party/grpc:grpc.patch"],
|
||||||
"//third_party/grpc:grpc.patch",
|
urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/refs/tags/v1.74.0.tar.gz"),
|
||||||
],
|
|
||||||
urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/refs/tags/v1.68.2.tar.gz"),
|
|
||||||
)
|
)
|
||||||
tf_vendored(name = "tsl", relpath = "third_party/tsl")
|
tf_vendored(name = "tsl", relpath = "third_party/tsl")
|
||||||
|
|
||||||
|
|||||||
@ -905,7 +905,6 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
|||||||
try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_triton_gemm", false, upb_arena);
|
try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_triton_gemm", false, upb_arena);
|
||||||
try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_latency_hiding_scheduler", true, upb_arena);
|
try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_latency_hiding_scheduler", true, upb_arena);
|
||||||
try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_llvm_module_compilation_parallelism", true, upb_arena);
|
try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_llvm_module_compilation_parallelism", true, upb_arena);
|
||||||
try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_libnvptxcompiler", true, upb_arena);
|
|
||||||
},
|
},
|
||||||
.rocm => {
|
.rocm => {
|
||||||
// Disable Triton GEMM on ROCM. For some reason it's much, much slower when
|
// Disable Triton GEMM on ROCM. For some reason it's much, much slower when
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user