Update XLA to latest version

This commit is contained in:
Tarry Singh 2025-12-23 17:24:34 +00:00
parent 57bef8d66c
commit b8b4d33379
10 changed files with 106 additions and 18 deletions

View File

@ -11,6 +11,7 @@ bazel_dep(name = "patchelf", version = "0.18.0")
bazel_dep(name = "pcre2", version = "10.45") bazel_dep(name = "pcre2", version = "10.45")
bazel_dep(name = "platforms", version = "1.0.0") bazel_dep(name = "platforms", version = "1.0.0")
bazel_dep(name = "protobuf", version = "32.0", repo_name = "com_google_protobuf") bazel_dep(name = "protobuf", version = "32.0", repo_name = "com_google_protobuf")
# Needs to be added before rules_cc so that the cc toolchain declared by # Needs to be added before rules_cc so that the cc toolchain declared by
# apple_support wins over the one in rules_cc. # apple_support wins over the one in rules_cc.
bazel_dep(name = "apple_support", version = "1.24.2") bazel_dep(name = "apple_support", version = "1.24.2")
@ -111,10 +112,14 @@ use_repo(
"local_config_cuda", "local_config_cuda",
"local_config_remote_execution", "local_config_remote_execution",
"local_config_rocm", "local_config_rocm",
"local_config_sycl",
"local_config_tensorrt", "local_config_tensorrt",
"python_version_repo", "python_version_repo",
"rules_ml_toolchain", "rules_ml_toolchain",
"rules_shell",
"stablehlo", "stablehlo",
"sycl_configure",
"sycl_configure_ext",
"triton", "triton",
"tsl", "tsl",
) )

View File

@ -1311,6 +1311,20 @@ pub const Ffi = extern struct {
} }
}; };
pub const TypeInfo = struct {
deleter: ?*const fn (*anyopaque) callconv(.c) void = null,
serialize: ?*const fn () callconv(.c) void = null,
deserialize: ?*const fn () callconv(.c) void = null,
pub fn toCStruct(self: TypeInfo) c.PJRT_FFI_Type_Info {
return .{
.deleter = @ptrCast(self.deleter),
.serialize = @ptrCast(self.serialize),
.deserialize = @ptrCast(self.deserialize),
};
}
};
// todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize // todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize
// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5 // introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5
pub fn register( pub fn register(
@ -1337,13 +1351,14 @@ pub const Ffi = extern struct {
} }
} }
pub fn registerTypeId(self: *const Ffi, api: *const Api, type_name: []const u8) ApiError!ffi.TypeId { pub fn registerTypeId(self: *const Ffi, api: *const Api, type_name: []const u8, type_info: ?*const c.PJRT_FFI_Type_Info) ApiError!ffi.TypeId {
var ret = pjrtStruct(c.PJRT_FFI_TypeID_Register_Args{ var ret = pjrtStruct(c.PJRT_FFI_Type_Register_Args{
.type_name = type_name.ptr, .type_name = type_name.ptr,
.type_name_size = type_name.len, .type_name_size = type_name.len,
.type_id = 0, // let the plugin assign a unique type ID .type_id = 0, // let the plugin assign a unique type ID
.type_info = @ptrCast(@constCast(type_info)),
}); });
const result = self.inner.type_id_register.?(&ret); const result = self.inner.type_register.?(&ret);
if (result) |pjrt_c_error| { if (result) |pjrt_c_error| {
const pjrt_error: *Error = @ptrCast(pjrt_c_error); const pjrt_error: *Error = @ptrCast(pjrt_c_error);
return pjrt_error.getCode(api).toApiError(); return pjrt_error.getCode(api).toApiError();

View File

@ -23,22 +23,22 @@ def _cpu_pjrt_plugin_impl(mctx):
http_archive( http_archive(
name = "libpjrt_cpu_linux_amd64", name = "libpjrt_cpu_linux_amd64",
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_LINUX, build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_LINUX,
sha256 = "124dc500291a5930f910ca23533520e22c90797110b29fd2c0d8274475f4a220", sha256 = "ecc26dc792d2577474348eb48f3989aba8c3bb8d3cbd6df77ccf43357092a700",
url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cpu_linux-amd64.tar.gz", url = "https://github.com/zml/pjrt-artifacts/releases/download/v14.0.1/pjrt-cpu_linux-amd64.tar.gz",
) )
http_archive( http_archive(
name = "libpjrt_cpu_darwin_amd64", name = "libpjrt_cpu_darwin_amd64",
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN, build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN,
sha256 = "6e5b59874880f4db37c53fb1d52520d410b0078f9d2606a90762c6c622693c26", sha256 = "4a21db4ecd015fb772614ce4b491551d483ce11321c8784e3d0e07a9a425d5eb",
url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cpu_darwin-amd64.tar.gz", url = "https://github.com/zml/pjrt-artifacts/releases/download/v14.0.1/pjrt-cpu_darwin-amd64.tar.gz",
) )
http_archive( http_archive(
name = "libpjrt_cpu_darwin_arm64", name = "libpjrt_cpu_darwin_arm64",
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN, build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN,
sha256 = "a6354bfed828a011e6d809eda2230e10c40c80044c67fe618b2a9615c047f092", sha256 = "e0ab4492468999ae7861a27837427846a708f4346fdae9ad1e84b80e1313566a",
url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cpu_darwin-arm64.tar.gz", url = "https://github.com/zml/pjrt-artifacts/releases/download/v14.0.1/pjrt-cpu_darwin-arm64.tar.gz",
) )
return mctx.extension_metadata( return mctx.extension_metadata(

View File

@ -229,8 +229,8 @@ def _cuda_impl(mctx):
http_archive( http_archive(
name = "libpjrt_cuda", name = "libpjrt_cuda",
build_file = "libpjrt_cuda.BUILD.bazel", build_file = "libpjrt_cuda.BUILD.bazel",
url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cuda_linux-amd64.tar.gz", url = "https://github.com/zml/pjrt-artifacts/releases/download/v14.0.1/pjrt-cuda_linux-amd64.tar.gz",
sha256 = "6cdac9bac6db904e4423c9745c61000cf3acaf3c7da8016ab0016f076869048a", sha256 = "4b618f05f9cd4cd14966717f7a521b1aa80b425999755870ce2d1caf45685578",
) )
return mctx.extension_metadata( return mctx.extension_metadata(

View File

@ -121,7 +121,7 @@ _ROCM_PACKAGES = {
"dlopen": "zmlxrocm_dlopen", "dlopen": "zmlxrocm_dlopen",
}, },
), ),
packages.filegroup(name = "hiprtc", srcs = ["lib/libhiprtc.so.6"]), packages.filegroup(name = "hiprtc", srcs = ["lib/libhiprtc.so.6", "lib/libhiprtc-builtins.so.6"]),
]), ]),
"hipsolver": packages.filegroup(name = "hipsolver", srcs = ["lib/libhipsolver.so.0"]), "hipsolver": packages.filegroup(name = "hipsolver", srcs = ["lib/libhipsolver.so.0"]),
} }
@ -153,8 +153,8 @@ def _rocm_impl(mctx):
http_archive( http_archive(
name = "libpjrt_rocm", name = "libpjrt_rocm",
build_file = "libpjrt_rocm.BUILD.bazel", build_file = "libpjrt_rocm.BUILD.bazel",
url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-rocm_linux-amd64.tar.gz", url = "https://github.com/zml/pjrt-artifacts/releases/download/v14.0.1/pjrt-rocm_linux-amd64.tar.gz",
sha256 = "945c43c68325c0e91cd41eaa594a9f9f6e78da7cc06892d83bf345b69f7bd714", sha256 = "087858044f17bc06b70d7cbffc33e7f2bf590d732f3ce2c24425e41453ea1cf4",
) )
return mctx.extension_metadata( return mctx.extension_metadata(

View File

@ -11,7 +11,9 @@ pub export fn zmlxrocm_dlopen(filename: [*c]const u8, flags: c_int) ?*anyopaque
.{ "libamd_comgr.so", "libamd_comgr.so.3" }, .{ "libamd_comgr.so", "libamd_comgr.so.3" },
.{ "librocprofiler-register.so", "librocprofiler-register.so.0" }, .{ "librocprofiler-register.so", "librocprofiler-register.so.0" },
.{ "libMIOpen.so", "libMIOpen.so.1" }, .{ "libMIOpen.so", "libMIOpen.so.1" },
.{ "libMIOpen.so.1", "libMIOpen.so.1" },
.{ "librccl.so", "librccl.so.1" }, .{ "librccl.so", "librccl.so.1" },
.{ "librocblas.so.4", "librocblas.so.4" },
.{ "librocblas.so", "librocblas.so.4" }, .{ "librocblas.so", "librocblas.so.4" },
.{ "libroctracer64.so", "libroctracer64.so.4" }, .{ "libroctracer64.so", "libroctracer64.so.4" },
.{ "libroctx64.so", "libroctx64.so.4" }, .{ "libroctx64.so", "libroctx64.so.4" },

View File

@ -0,0 +1,45 @@
From 0d88ac9b06c8bc78db817d85e90cd60d38e6561a Mon Sep 17 00:00:00 2001
From: Hugo Mano <hugo@zml.ai>
Date: Mon, 3 Nov 2025 16:54:54 +0100
Subject: [PATCH] PjRT C API: make PJRT FFI C extension header compliant
XLA PR: https://github.com/openxla/xla/pull/33470
---
xla/pjrt/c/pjrt_c_api_ffi_extension.h | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
index e756650911..33b78238b9 100644
--- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h
+++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
@@ -32,13 +32,13 @@ extern "C" {
// See: https://en.wikipedia.org/wiki/Foreign_function_interface
#define PJRT_API_FFI_EXTENSION_VERSION 3
-struct PJRT_FFI_Type_Info {
+typedef struct PJRT_FFI_Type_Info {
void (*deleter)(void* object);
void (*serialize)(); // placeholder for future use
void (*deserialize)(); // placeholder for future use
-};
+} PJRT_FFI_Type_Info;
-struct PJRT_FFI_Type_Register_Args {
+typedef struct PJRT_FFI_Type_Register_Args {
size_t struct_size;
PJRT_Extension_Base* extension_start;
@@ -46,7 +46,7 @@ struct PJRT_FFI_Type_Register_Args {
size_t type_name_size;
int64_t type_id; // in-out
PJRT_FFI_Type_Info* type_info;
-};
+} PJRT_FFI_Type_Register_Args;
PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Type_Register_Args, type_info);
// Registers external type in a static type registry. If `type_id` is set to `0`
--
2.50.1 (Apple Git-155)

View File

@ -4,9 +4,9 @@ def repo():
git_repository( git_repository(
name = "xla", name = "xla",
remote = "https://github.com/openxla/xla.git", remote = "https://github.com/openxla/xla.git",
commit = "b3fbfeeb076f2b536897180f4a274680ed9d52eb", commit = "9a77a882bb2bc75cb8c29620ff8cd0fd089bdc86",
patch_args = ["-p1"], patch_args = ["-p1"],
patches = [ patches = [
# patches live in the patches directory "third_party/xla/patches/0001-PjRT-C-API-male-header-C-compliant-for-PJRT-FFI-exte.patch",
], ],
) )

View File

@ -42,6 +42,15 @@ if_rocm_newer_than = always_newer_than
is_rocm_configured = always_false is_rocm_configured = always_false
if_gpu_is_configured = always_if_false if_gpu_is_configured = always_if_false
if_cuda_or_rocm = always_if_false if_cuda_or_rocm = always_if_false
""",
})
simple_files(name = "local_config_sycl", files = {
"BUILD.bazel": "",
"sycl/BUILD.bazel": "",
"crosstool/BUILD.bazel": "",
"sycl/build_defs.bzl": _BZL_HELPERS + """\
if_sycl = always_if_false
if_sycl_is_configured = always_if_false
""", """,
}) })
simple_files(name = "local_config_remote_execution", files = { simple_files(name = "local_config_remote_execution", files = {
@ -56,6 +65,17 @@ if_cuda_or_rocm = always_if_false
simple_files(name = "rules_ml_toolchain", files = { simple_files(name = "rules_ml_toolchain", files = {
"third_party/gpus/BUILD.bazel": "", "third_party/gpus/BUILD.bazel": "",
"third_party/gpus/nvidia_common_rules.bzl": """cuda_rpath_flags = lambda *args, **kwargs: []""", "third_party/gpus/nvidia_common_rules.bzl": """cuda_rpath_flags = lambda *args, **kwargs: []""",
"third_party/extensions/sycl_configure.bzl": "",
})
simple_files(name = "sycl_configure_ext", files = {})
simple_files(name = "sycl_configure", files = {})
simple_files(name = "rules_shell", files = {
"BUILD.bazel": "",
"shell/BUILD.bazel": "",
"shell/sh_binary.bzl": """
def sh_binary(**kwargs):
native.sh_binary(**kwargs)
""",
}) })
def _xla_impl(mctx): def _xla_impl(mctx):
@ -70,7 +90,7 @@ def _xla_impl(mctx):
patch_file = ["//third_party/grpc:grpc.patch"], patch_file = ["//third_party/grpc:grpc.patch"],
urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/refs/tags/v1.74.0.tar.gz"), urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/refs/tags/v1.74.0.tar.gz"),
) )
tf_vendored(name = "tsl", relpath = "third_party/tsl") tf_vendored(name = "tsl", path = "third_party/tsl")
_dummy_repos(mctx) _dummy_repos(mctx)

View File

@ -151,7 +151,8 @@ pub fn register(Callback: type, platform: Platform) pjrt.ApiError!void {
const target_name = "zml$" ++ @typeName(Callback); const target_name = "zml$" ++ @typeName(Callback);
const proxy_cb = proxy(Callback); const proxy_cb = proxy(Callback);
Callback.type_id = try ffi.registerTypeId(platform.pjrt_api, @typeName(Callback)); const type_info: pjrt.Ffi.TypeInfo = .{};
Callback.type_id = try ffi.registerTypeId(platform.pjrt_api, @typeName(Callback), &type_info.toCStruct());
try ffi.register(platform.pjrt_api, target_name, @tagName(platform.target), &proxy_cb, Callback.callback_config.traits); try ffi.register(platform.pjrt_api, target_name, @tagName(platform.target), &proxy_cb, Callback.callback_config.traits);
log.debug("Registered custom call {} with target name \"{s}\"", .{ Callback, target_name }); log.debug("Registered custom call {} with target name \"{s}\"", .{ Callback, target_name });
} }