From b8b4d33379985b532610534758f5ffb74cebc2ac Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 23 Dec 2025 17:24:34 +0000 Subject: [PATCH] Update XLA to latest version --- MODULE.bazel | 5 +++ pjrt/pjrt.zig | 21 +++++++-- runtimes/cpu/cpu.bzl | 12 ++--- runtimes/cuda/cuda.bzl | 4 +- runtimes/rocm/rocm.bzl | 6 +-- runtimes/rocm/zmlxrocm.zig | 2 + ...header-C-compliant-for-PJRT-FFI-exte.patch | 45 +++++++++++++++++++ third_party/xla/repo.bzl | 4 +- third_party/xla/xla.bzl | 22 ++++++++- zml/callback.zig | 3 +- 10 files changed, 106 insertions(+), 18 deletions(-) create mode 100644 third_party/xla/patches/0001-PjRT-C-API-male-header-C-compliant-for-PJRT-FFI-exte.patch diff --git a/MODULE.bazel b/MODULE.bazel index 73468e4..22e95e7 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -11,6 +11,7 @@ bazel_dep(name = "patchelf", version = "0.18.0") bazel_dep(name = "pcre2", version = "10.45") bazel_dep(name = "platforms", version = "1.0.0") bazel_dep(name = "protobuf", version = "32.0", repo_name = "com_google_protobuf") + # Needs to be added before rules_cc so that the cc toolchain declared by # apple_support wins over the one in rules_cc. bazel_dep(name = "apple_support", version = "1.24.2") @@ -111,10 +112,14 @@ use_repo( "local_config_cuda", "local_config_remote_execution", "local_config_rocm", + "local_config_sycl", "local_config_tensorrt", "python_version_repo", "rules_ml_toolchain", + "rules_shell", "stablehlo", + "sycl_configure", + "sycl_configure_ext", "triton", "tsl", ) diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index cad1ce6..f9f30bd 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -1311,6 +1311,20 @@ pub const Ffi = extern struct { } }; + pub const TypeInfo = struct { + deleter: ?*const fn (*anyopaque) callconv(.c) void = null, + serialize: ?*const fn () callconv(.c) void = null, + deserialize: ?*const fn () callconv(.c) void = null, + + pub fn toCStruct(self: TypeInfo) c.PJRT_FFI_Type_Info { + return .{ + .deleter = @ptrCast(self.deleter), + .serialize = @ptrCast(self.serialize), + .deserialize = @ptrCast(self.deserialize), + }; + } + }; + // todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize // introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5 pub fn register( @@ -1337,13 +1351,14 @@ pub const Ffi = extern struct { } } - pub fn registerTypeId(self: *const Ffi, api: *const Api, type_name: []const u8) ApiError!ffi.TypeId { - var ret = pjrtStruct(c.PJRT_FFI_TypeID_Register_Args{ + pub fn registerTypeId(self: *const Ffi, api: *const Api, type_name: []const u8, type_info: ?*const c.PJRT_FFI_Type_Info) ApiError!ffi.TypeId { + var ret = pjrtStruct(c.PJRT_FFI_Type_Register_Args{ .type_name = type_name.ptr, .type_name_size = type_name.len, .type_id = 0, // let the plugin assign a unique type ID + .type_info = @ptrCast(@constCast(type_info)), }); - const result = self.inner.type_id_register.?(&ret); + const result = self.inner.type_register.?(&ret); if (result) |pjrt_c_error| { const pjrt_error: *Error = @ptrCast(pjrt_c_error); return pjrt_error.getCode(api).toApiError(); diff --git a/runtimes/cpu/cpu.bzl b/runtimes/cpu/cpu.bzl index 2a095f4..675ff3b 100644 --- a/runtimes/cpu/cpu.bzl +++ b/runtimes/cpu/cpu.bzl @@ -23,22 +23,22 @@ def _cpu_pjrt_plugin_impl(mctx): http_archive( name = "libpjrt_cpu_linux_amd64", build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_LINUX, - sha256 = "124dc500291a5930f910ca23533520e22c90797110b29fd2c0d8274475f4a220", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cpu_linux-amd64.tar.gz", + sha256 = "ecc26dc792d2577474348eb48f3989aba8c3bb8d3cbd6df77ccf43357092a700", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v14.0.1/pjrt-cpu_linux-amd64.tar.gz", ) http_archive( name = "libpjrt_cpu_darwin_amd64", build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN, - sha256 = "6e5b59874880f4db37c53fb1d52520d410b0078f9d2606a90762c6c622693c26", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cpu_darwin-amd64.tar.gz", + sha256 = "4a21db4ecd015fb772614ce4b491551d483ce11321c8784e3d0e07a9a425d5eb", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v14.0.1/pjrt-cpu_darwin-amd64.tar.gz", ) http_archive( name = "libpjrt_cpu_darwin_arm64", build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + _BUILD_DARWIN, - sha256 = "a6354bfed828a011e6d809eda2230e10c40c80044c67fe618b2a9615c047f092", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cpu_darwin-arm64.tar.gz", + sha256 = "e0ab4492468999ae7861a27837427846a708f4346fdae9ad1e84b80e1313566a", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v14.0.1/pjrt-cpu_darwin-arm64.tar.gz", ) return mctx.extension_metadata( diff --git a/runtimes/cuda/cuda.bzl b/runtimes/cuda/cuda.bzl index a99d460..0b1a35b 100644 --- a/runtimes/cuda/cuda.bzl +++ b/runtimes/cuda/cuda.bzl @@ -229,8 +229,8 @@ def _cuda_impl(mctx): http_archive( name = "libpjrt_cuda", build_file = "libpjrt_cuda.BUILD.bazel", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-cuda_linux-amd64.tar.gz", - sha256 = "6cdac9bac6db904e4423c9745c61000cf3acaf3c7da8016ab0016f076869048a", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v14.0.1/pjrt-cuda_linux-amd64.tar.gz", + sha256 = "4b618f05f9cd4cd14966717f7a521b1aa80b425999755870ce2d1caf45685578", ) return mctx.extension_metadata( diff --git a/runtimes/rocm/rocm.bzl b/runtimes/rocm/rocm.bzl index e71f1a9..f06daac 100644 --- a/runtimes/rocm/rocm.bzl +++ b/runtimes/rocm/rocm.bzl @@ -121,7 +121,7 @@ _ROCM_PACKAGES = { "dlopen": "zmlxrocm_dlopen", }, ), - packages.filegroup(name = "hiprtc", srcs = ["lib/libhiprtc.so.6"]), + packages.filegroup(name = "hiprtc", srcs = ["lib/libhiprtc.so.6", "lib/libhiprtc-builtins.so.6"]), ]), "hipsolver": packages.filegroup(name = "hipsolver", srcs = ["lib/libhipsolver.so.0"]), } @@ -153,8 +153,8 @@ def _rocm_impl(mctx): http_archive( name = "libpjrt_rocm", build_file = "libpjrt_rocm.BUILD.bazel", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v13.0.0/pjrt-rocm_linux-amd64.tar.gz", - sha256 = "945c43c68325c0e91cd41eaa594a9f9f6e78da7cc06892d83bf345b69f7bd714", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v14.0.1/pjrt-rocm_linux-amd64.tar.gz", + sha256 = "087858044f17bc06b70d7cbffc33e7f2bf590d732f3ce2c24425e41453ea1cf4", ) return mctx.extension_metadata( diff --git a/runtimes/rocm/zmlxrocm.zig b/runtimes/rocm/zmlxrocm.zig index 19f840c..fccfbe9 100644 --- a/runtimes/rocm/zmlxrocm.zig +++ b/runtimes/rocm/zmlxrocm.zig @@ -11,7 +11,9 @@ pub export fn zmlxrocm_dlopen(filename: [*c]const u8, flags: c_int) ?*anyopaque .{ "libamd_comgr.so", "libamd_comgr.so.3" }, .{ "librocprofiler-register.so", "librocprofiler-register.so.0" }, .{ "libMIOpen.so", "libMIOpen.so.1" }, + .{ "libMIOpen.so.1", "libMIOpen.so.1" }, .{ "librccl.so", "librccl.so.1" }, + .{ "librocblas.so.4", "librocblas.so.4" }, .{ "librocblas.so", "librocblas.so.4" }, .{ "libroctracer64.so", "libroctracer64.so.4" }, .{ "libroctx64.so", "libroctx64.so.4" }, diff --git a/third_party/xla/patches/0001-PjRT-C-API-male-header-C-compliant-for-PJRT-FFI-exte.patch b/third_party/xla/patches/0001-PjRT-C-API-male-header-C-compliant-for-PJRT-FFI-exte.patch new file mode 100644 index 0000000..bfb6580 --- /dev/null +++ b/third_party/xla/patches/0001-PjRT-C-API-male-header-C-compliant-for-PJRT-FFI-exte.patch @@ -0,0 +1,45 @@ +From 0d88ac9b06c8bc78db817d85e90cd60d38e6561a Mon Sep 17 00:00:00 2001 +From: Hugo Mano +Date: Mon, 3 Nov 2025 16:54:54 +0100 +Subject: [PATCH] PjRT C API: make PJRT FFI C extension header compliant + + +XLA PR: https://github.com/openxla/xla/pull/33470 + +--- + xla/pjrt/c/pjrt_c_api_ffi_extension.h | 8 ++++---- + 1 file changed, 4 insertions(+), 4 deletions(-) + +diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h +index e756650911..33b78238b9 100644 +--- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h ++++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h +@@ -32,13 +32,13 @@ extern "C" { + // See: https://en.wikipedia.org/wiki/Foreign_function_interface + #define PJRT_API_FFI_EXTENSION_VERSION 3 + +-struct PJRT_FFI_Type_Info { ++typedef struct PJRT_FFI_Type_Info { + void (*deleter)(void* object); + void (*serialize)(); // placeholder for future use + void (*deserialize)(); // placeholder for future use +-}; ++} PJRT_FFI_Type_Info; + +-struct PJRT_FFI_Type_Register_Args { ++typedef struct PJRT_FFI_Type_Register_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + +@@ -46,7 +46,7 @@ struct PJRT_FFI_Type_Register_Args { + size_t type_name_size; + int64_t type_id; // in-out + PJRT_FFI_Type_Info* type_info; +-}; ++} PJRT_FFI_Type_Register_Args; + PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Type_Register_Args, type_info); + + // Registers external type in a static type registry. If `type_id` is set to `0` +-- +2.50.1 (Apple Git-155) + diff --git a/third_party/xla/repo.bzl b/third_party/xla/repo.bzl index e9a8b40..4e45d2d 100644 --- a/third_party/xla/repo.bzl +++ b/third_party/xla/repo.bzl @@ -4,9 +4,9 @@ def repo(): git_repository( name = "xla", remote = "https://github.com/openxla/xla.git", - commit = "b3fbfeeb076f2b536897180f4a274680ed9d52eb", + commit = "9a77a882bb2bc75cb8c29620ff8cd0fd089bdc86", patch_args = ["-p1"], patches = [ - # patches live in the patches directory + "third_party/xla/patches/0001-PjRT-C-API-male-header-C-compliant-for-PJRT-FFI-exte.patch", ], ) diff --git a/third_party/xla/xla.bzl b/third_party/xla/xla.bzl index 3780bc5..8444ffd 100644 --- a/third_party/xla/xla.bzl +++ b/third_party/xla/xla.bzl @@ -42,6 +42,15 @@ if_rocm_newer_than = always_newer_than is_rocm_configured = always_false if_gpu_is_configured = always_if_false if_cuda_or_rocm = always_if_false +""", + }) + simple_files(name = "local_config_sycl", files = { + "BUILD.bazel": "", + "sycl/BUILD.bazel": "", + "crosstool/BUILD.bazel": "", + "sycl/build_defs.bzl": _BZL_HELPERS + """\ +if_sycl = always_if_false +if_sycl_is_configured = always_if_false """, }) simple_files(name = "local_config_remote_execution", files = { @@ -56,6 +65,17 @@ if_cuda_or_rocm = always_if_false simple_files(name = "rules_ml_toolchain", files = { "third_party/gpus/BUILD.bazel": "", "third_party/gpus/nvidia_common_rules.bzl": """cuda_rpath_flags = lambda *args, **kwargs: []""", + "third_party/extensions/sycl_configure.bzl": "", + }) + simple_files(name = "sycl_configure_ext", files = {}) + simple_files(name = "sycl_configure", files = {}) + simple_files(name = "rules_shell", files = { + "BUILD.bazel": "", + "shell/BUILD.bazel": "", + "shell/sh_binary.bzl": """ +def sh_binary(**kwargs): + native.sh_binary(**kwargs) +""", }) def _xla_impl(mctx): @@ -70,7 +90,7 @@ def _xla_impl(mctx): patch_file = ["//third_party/grpc:grpc.patch"], urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/refs/tags/v1.74.0.tar.gz"), ) - tf_vendored(name = "tsl", relpath = "third_party/tsl") + tf_vendored(name = "tsl", path = "third_party/tsl") _dummy_repos(mctx) diff --git a/zml/callback.zig b/zml/callback.zig index 709e6c8..04d09de 100644 --- a/zml/callback.zig +++ b/zml/callback.zig @@ -151,7 +151,8 @@ pub fn register(Callback: type, platform: Platform) pjrt.ApiError!void { const target_name = "zml$" ++ @typeName(Callback); const proxy_cb = proxy(Callback); - Callback.type_id = try ffi.registerTypeId(platform.pjrt_api, @typeName(Callback)); + const type_info: pjrt.Ffi.TypeInfo = .{}; + Callback.type_id = try ffi.registerTypeId(platform.pjrt_api, @typeName(Callback), &type_info.toCStruct()); try ffi.register(platform.pjrt_api, target_name, @tagName(platform.target), &proxy_cb, Callback.callback_config.traits); log.debug("Registered custom call {} with target name \"{s}\"", .{ Callback, target_name }); }