Update XLA to latest version
This commit is contained in:
parent
57bef8d66c
commit
b8b4d33379
@ -11,6 +11,7 @@ bazel_dep(name = "patchelf", version = "0.18.0")
|
|||||||
bazel_dep(name = "pcre2", version = "10.45")
|
bazel_dep(name = "pcre2", version = "10.45")
|
||||||
bazel_dep(name = "platforms", version = "1.0.0")
|
bazel_dep(name = "platforms", version = "1.0.0")
|
||||||
bazel_dep(name = "protobuf", version = "32.0", repo_name = "com_google_protobuf")
|
bazel_dep(name = "protobuf", version = "32.0", repo_name = "com_google_protobuf")
|
||||||
|
|
||||||
# Needs to be added before rules_cc so that the cc toolchain declared by
|
# Needs to be added before rules_cc so that the cc toolchain declared by
|
||||||
# apple_support wins over the one in rules_cc.
|
# apple_support wins over the one in rules_cc.
|
||||||
bazel_dep(name = "apple_support", version = "1.24.2")
|
bazel_dep(name = "apple_support", version = "1.24.2")
|
||||||
@ -111,10 +112,14 @@ use_repo(
|
|||||||
"local_config_cuda",
|
"local_config_cuda",
|
||||||
"local_config_remote_execution",
|
"local_config_remote_execution",
|
||||||
"local_config_rocm",
|
"local_config_rocm",
|
||||||
|
"local_config_sycl",
|
||||||
"local_config_tensorrt",
|
"local_config_tensorrt",
|
||||||
"python_version_repo",
|
"python_version_repo",
|
||||||
"rules_ml_toolchain",
|
"rules_ml_toolchain",
|
||||||
|
"rules_shell",
|
||||||
"stablehlo",
|
"stablehlo",
|
||||||
|
"sycl_configure",
|
||||||
|
"sycl_configure_ext",
|
||||||
"triton",
|
"triton",
|
||||||
"tsl",
|
"tsl",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1311,6 +1311,20 @@ pub const Ffi = extern struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub const TypeInfo = struct {
|
||||||
|
deleter: ?*const fn (*anyopaque) callconv(.c) void = null,
|
||||||
|
serialize: ?*const fn () callconv(.c) void = null,
|
||||||
|
deserialize: ?*const fn () callconv(.c) void = null,
|
||||||
|
|
||||||
|
pub fn toCStruct(self: TypeInfo) c.PJRT_FFI_Type_Info {
|
||||||
|
return .{
|
||||||
|
.deleter = @ptrCast(self.deleter),
|
||||||
|
.serialize = @ptrCast(self.serialize),
|
||||||
|
.deserialize = @ptrCast(self.deserialize),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize
|
// todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize
|
||||||
// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5
|
// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5
|
||||||
pub fn register(
|
pub fn register(
|
||||||
@ -1337,13 +1351,14 @@ pub const Ffi = extern struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn registerTypeId(self: *const Ffi, api: *const Api, type_name: []const u8) ApiError!ffi.TypeId {
|
pub fn registerTypeId(self: *const Ffi, api: *const Api, type_name: []const u8, type_info: ?*const c.PJRT_FFI_Type_Info) ApiError!ffi.TypeId {
|
||||||
var ret = pjrtStruct(c.PJRT_FFI_TypeID_Register_Args{
|
var ret = pjrtStruct(c.PJRT_FFI_Type_Register_Args{
|
||||||
.type_name = type_name.ptr,
|
.type_name = type_name.ptr,
|
||||||
.type_name_size = type_name.len,
|
.type_name_size = type_name.len,
|
||||||
.type_id = 0, // let the plugin assign a unique type ID
|
.type_id = 0, // let the plugin assign a unique type ID
|
||||||
|
.type_info = @ptrCast(@constCast(type_info)),
|
||||||
});
|
});
|
||||||
const result = self.inner.type_id_register.?(&ret);
|
const result = self.inner.type_register.?(&ret);
|
||||||
if (result) |pjrt_c_error| {
|
if (result) |pjrt_c_error| {
|
||||||
const pjrt_error: *Error = @ptrCast(pjrt_c_error);
|
const pjrt_error: *Error = @ptrCast(pjrt_c_error);
|
||||||
return pjrt_error.getCode(api).toApiError();
|
return pjrt_error.getCode(api).toApiError();
|
||||||
|
|||||||
@ -23,22 +23,22 @@ def _cpu_pjrt_plugin_impl(mctx):
|
|||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cpu_linux_amd64",
|
name = "libpjrt_cpu_linux_amd64",
|
||||||
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_LINUX,
|
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_LINUX,
|
||||||
sha256 = "124dc500291a5930f910ca23533520e22c90797110b29fd2c0d8274475f4a220",
|
sha256 = "ecc26dc792d2577474348eb48f3989aba8c3bb8d3cbd6df77ccf43357092a700",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cpu_linux-amd64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v14.0.1/pjrt-cpu_linux-amd64.tar.gz",
|
||||||
)
|
)
|
||||||
|
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cpu_darwin_amd64",
|
name = "libpjrt_cpu_darwin_amd64",
|
||||||
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN,
|
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN,
|
||||||
sha256 = "6e5b59874880f4db37c53fb1d52520d410b0078f9d2606a90762c6c622693c26",
|
sha256 = "4a21db4ecd015fb772614ce4b491551d483ce11321c8784e3d0e07a9a425d5eb",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cpu_darwin-amd64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v14.0.1/pjrt-cpu_darwin-amd64.tar.gz",
|
||||||
)
|
)
|
||||||
|
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cpu_darwin_arm64",
|
name = "libpjrt_cpu_darwin_arm64",
|
||||||
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN,
|
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN,
|
||||||
sha256 = "a6354bfed828a011e6d809eda2230e10c40c80044c67fe618b2a9615c047f092",
|
sha256 = "e0ab4492468999ae7861a27837427846a708f4346fdae9ad1e84b80e1313566a",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cpu_darwin-arm64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v14.0.1/pjrt-cpu_darwin-arm64.tar.gz",
|
||||||
)
|
)
|
||||||
|
|
||||||
return mctx.extension_metadata(
|
return mctx.extension_metadata(
|
||||||
|
|||||||
@ -229,8 +229,8 @@ def _cuda_impl(mctx):
|
|||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cuda",
|
name = "libpjrt_cuda",
|
||||||
build_file = "libpjrt_cuda.BUILD.bazel",
|
build_file = "libpjrt_cuda.BUILD.bazel",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cuda_linux-amd64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v14.0.1/pjrt-cuda_linux-amd64.tar.gz",
|
||||||
sha256 = "6cdac9bac6db904e4423c9745c61000cf3acaf3c7da8016ab0016f076869048a",
|
sha256 = "4b618f05f9cd4cd14966717f7a521b1aa80b425999755870ce2d1caf45685578",
|
||||||
)
|
)
|
||||||
|
|
||||||
return mctx.extension_metadata(
|
return mctx.extension_metadata(
|
||||||
|
|||||||
@ -121,7 +121,7 @@ _ROCM_PACKAGES = {
|
|||||||
"dlopen": "zmlxrocm_dlopen",
|
"dlopen": "zmlxrocm_dlopen",
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
packages.filegroup(name = "hiprtc", srcs = ["lib/libhiprtc.so.6"]),
|
packages.filegroup(name = "hiprtc", srcs = ["lib/libhiprtc.so.6", "lib/libhiprtc-builtins.so.6"]),
|
||||||
]),
|
]),
|
||||||
"hipsolver": packages.filegroup(name = "hipsolver", srcs = ["lib/libhipsolver.so.0"]),
|
"hipsolver": packages.filegroup(name = "hipsolver", srcs = ["lib/libhipsolver.so.0"]),
|
||||||
}
|
}
|
||||||
@ -153,8 +153,8 @@ def _rocm_impl(mctx):
|
|||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_rocm",
|
name = "libpjrt_rocm",
|
||||||
build_file = "libpjrt_rocm.BUILD.bazel",
|
build_file = "libpjrt_rocm.BUILD.bazel",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-rocm_linux-amd64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v14.0.1/pjrt-rocm_linux-amd64.tar.gz",
|
||||||
sha256 = "945c43c68325c0e91cd41eaa594a9f9f6e78da7cc06892d83bf345b69f7bd714",
|
sha256 = "087858044f17bc06b70d7cbffc33e7f2bf590d732f3ce2c24425e41453ea1cf4",
|
||||||
)
|
)
|
||||||
|
|
||||||
return mctx.extension_metadata(
|
return mctx.extension_metadata(
|
||||||
|
|||||||
@ -11,7 +11,9 @@ pub export fn zmlxrocm_dlopen(filename: [*c]const u8, flags: c_int) ?*anyopaque
|
|||||||
.{ "libamd_comgr.so", "libamd_comgr.so.3" },
|
.{ "libamd_comgr.so", "libamd_comgr.so.3" },
|
||||||
.{ "librocprofiler-register.so", "librocprofiler-register.so.0" },
|
.{ "librocprofiler-register.so", "librocprofiler-register.so.0" },
|
||||||
.{ "libMIOpen.so", "libMIOpen.so.1" },
|
.{ "libMIOpen.so", "libMIOpen.so.1" },
|
||||||
|
.{ "libMIOpen.so.1", "libMIOpen.so.1" },
|
||||||
.{ "librccl.so", "librccl.so.1" },
|
.{ "librccl.so", "librccl.so.1" },
|
||||||
|
.{ "librocblas.so.4", "librocblas.so.4" },
|
||||||
.{ "librocblas.so", "librocblas.so.4" },
|
.{ "librocblas.so", "librocblas.so.4" },
|
||||||
.{ "libroctracer64.so", "libroctracer64.so.4" },
|
.{ "libroctracer64.so", "libroctracer64.so.4" },
|
||||||
.{ "libroctx64.so", "libroctx64.so.4" },
|
.{ "libroctx64.so", "libroctx64.so.4" },
|
||||||
|
|||||||
45
third_party/xla/patches/0001-PjRT-C-API-male-header-C-compliant-for-PJRT-FFI-exte.patch
vendored
Normal file
45
third_party/xla/patches/0001-PjRT-C-API-male-header-C-compliant-for-PJRT-FFI-exte.patch
vendored
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
From 0d88ac9b06c8bc78db817d85e90cd60d38e6561a Mon Sep 17 00:00:00 2001
|
||||||
|
From: Hugo Mano <hugo@zml.ai>
|
||||||
|
Date: Mon, 3 Nov 2025 16:54:54 +0100
|
||||||
|
Subject: [PATCH] PjRT C API: make PJRT FFI C extension header compliant
|
||||||
|
|
||||||
|
|
||||||
|
XLA PR: https://github.com/openxla/xla/pull/33470
|
||||||
|
|
||||||
|
---
|
||||||
|
xla/pjrt/c/pjrt_c_api_ffi_extension.h | 8 ++++----
|
||||||
|
1 file changed, 4 insertions(+), 4 deletions(-)
|
||||||
|
|
||||||
|
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
||||||
|
index e756650911..33b78238b9 100644
|
||||||
|
--- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
||||||
|
+++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
||||||
|
@@ -32,13 +32,13 @@ extern "C" {
|
||||||
|
// See: https://en.wikipedia.org/wiki/Foreign_function_interface
|
||||||
|
#define PJRT_API_FFI_EXTENSION_VERSION 3
|
||||||
|
|
||||||
|
-struct PJRT_FFI_Type_Info {
|
||||||
|
+typedef struct PJRT_FFI_Type_Info {
|
||||||
|
void (*deleter)(void* object);
|
||||||
|
void (*serialize)(); // placeholder for future use
|
||||||
|
void (*deserialize)(); // placeholder for future use
|
||||||
|
-};
|
||||||
|
+} PJRT_FFI_Type_Info;
|
||||||
|
|
||||||
|
-struct PJRT_FFI_Type_Register_Args {
|
||||||
|
+typedef struct PJRT_FFI_Type_Register_Args {
|
||||||
|
size_t struct_size;
|
||||||
|
PJRT_Extension_Base* extension_start;
|
||||||
|
|
||||||
|
@@ -46,7 +46,7 @@ struct PJRT_FFI_Type_Register_Args {
|
||||||
|
size_t type_name_size;
|
||||||
|
int64_t type_id; // in-out
|
||||||
|
PJRT_FFI_Type_Info* type_info;
|
||||||
|
-};
|
||||||
|
+} PJRT_FFI_Type_Register_Args;
|
||||||
|
PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Type_Register_Args, type_info);
|
||||||
|
|
||||||
|
// Registers external type in a static type registry. If `type_id` is set to `0`
|
||||||
|
--
|
||||||
|
2.50.1 (Apple Git-155)
|
||||||
|
|
||||||
4
third_party/xla/repo.bzl
vendored
4
third_party/xla/repo.bzl
vendored
@ -4,9 +4,9 @@ def repo():
|
|||||||
git_repository(
|
git_repository(
|
||||||
name = "xla",
|
name = "xla",
|
||||||
remote = "https://github.com/openxla/xla.git",
|
remote = "https://github.com/openxla/xla.git",
|
||||||
commit = "b3fbfeeb076f2b536897180f4a274680ed9d52eb",
|
commit = "9a77a882bb2bc75cb8c29620ff8cd0fd089bdc86",
|
||||||
patch_args = ["-p1"],
|
patch_args = ["-p1"],
|
||||||
patches = [
|
patches = [
|
||||||
# patches live in the patches directory
|
"third_party/xla/patches/0001-PjRT-C-API-male-header-C-compliant-for-PJRT-FFI-exte.patch",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
22
third_party/xla/xla.bzl
vendored
22
third_party/xla/xla.bzl
vendored
@ -42,6 +42,15 @@ if_rocm_newer_than = always_newer_than
|
|||||||
is_rocm_configured = always_false
|
is_rocm_configured = always_false
|
||||||
if_gpu_is_configured = always_if_false
|
if_gpu_is_configured = always_if_false
|
||||||
if_cuda_or_rocm = always_if_false
|
if_cuda_or_rocm = always_if_false
|
||||||
|
""",
|
||||||
|
})
|
||||||
|
simple_files(name = "local_config_sycl", files = {
|
||||||
|
"BUILD.bazel": "",
|
||||||
|
"sycl/BUILD.bazel": "",
|
||||||
|
"crosstool/BUILD.bazel": "",
|
||||||
|
"sycl/build_defs.bzl": _BZL_HELPERS + """\
|
||||||
|
if_sycl = always_if_false
|
||||||
|
if_sycl_is_configured = always_if_false
|
||||||
""",
|
""",
|
||||||
})
|
})
|
||||||
simple_files(name = "local_config_remote_execution", files = {
|
simple_files(name = "local_config_remote_execution", files = {
|
||||||
@ -56,6 +65,17 @@ if_cuda_or_rocm = always_if_false
|
|||||||
simple_files(name = "rules_ml_toolchain", files = {
|
simple_files(name = "rules_ml_toolchain", files = {
|
||||||
"third_party/gpus/BUILD.bazel": "",
|
"third_party/gpus/BUILD.bazel": "",
|
||||||
"third_party/gpus/nvidia_common_rules.bzl": """cuda_rpath_flags = lambda *args, **kwargs: []""",
|
"third_party/gpus/nvidia_common_rules.bzl": """cuda_rpath_flags = lambda *args, **kwargs: []""",
|
||||||
|
"third_party/extensions/sycl_configure.bzl": "",
|
||||||
|
})
|
||||||
|
simple_files(name = "sycl_configure_ext", files = {})
|
||||||
|
simple_files(name = "sycl_configure", files = {})
|
||||||
|
simple_files(name = "rules_shell", files = {
|
||||||
|
"BUILD.bazel": "",
|
||||||
|
"shell/BUILD.bazel": "",
|
||||||
|
"shell/sh_binary.bzl": """
|
||||||
|
def sh_binary(**kwargs):
|
||||||
|
native.sh_binary(**kwargs)
|
||||||
|
""",
|
||||||
})
|
})
|
||||||
|
|
||||||
def _xla_impl(mctx):
|
def _xla_impl(mctx):
|
||||||
@ -70,7 +90,7 @@ def _xla_impl(mctx):
|
|||||||
patch_file = ["//third_party/grpc:grpc.patch"],
|
patch_file = ["//third_party/grpc:grpc.patch"],
|
||||||
urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/refs/tags/v1.74.0.tar.gz"),
|
urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/refs/tags/v1.74.0.tar.gz"),
|
||||||
)
|
)
|
||||||
tf_vendored(name = "tsl", relpath = "third_party/tsl")
|
tf_vendored(name = "tsl", path = "third_party/tsl")
|
||||||
|
|
||||||
_dummy_repos(mctx)
|
_dummy_repos(mctx)
|
||||||
|
|
||||||
|
|||||||
@ -151,7 +151,8 @@ pub fn register(Callback: type, platform: Platform) pjrt.ApiError!void {
|
|||||||
const target_name = "zml$" ++ @typeName(Callback);
|
const target_name = "zml$" ++ @typeName(Callback);
|
||||||
|
|
||||||
const proxy_cb = proxy(Callback);
|
const proxy_cb = proxy(Callback);
|
||||||
Callback.type_id = try ffi.registerTypeId(platform.pjrt_api, @typeName(Callback));
|
const type_info: pjrt.Ffi.TypeInfo = .{};
|
||||||
|
Callback.type_id = try ffi.registerTypeId(platform.pjrt_api, @typeName(Callback), &type_info.toCStruct());
|
||||||
try ffi.register(platform.pjrt_api, target_name, @tagName(platform.target), &proxy_cb, Callback.callback_config.traits);
|
try ffi.register(platform.pjrt_api, target_name, @tagName(platform.target), &proxy_cb, Callback.callback_config.traits);
|
||||||
log.debug("Registered custom call {} with target name \"{s}\"", .{ Callback, target_name });
|
log.debug("Registered custom call {} with target name \"{s}\"", .{ Callback, target_name });
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user