This PR removes XLA as a bzlmod module and brings it in as a non-bzlmod dependency instead. The module will never be upstreamed as-is, so we keep it private. We also fetch llvm-raw and stablehlo from it, which is fine. While at it, stub out ("dummify") the various local_config XLA symbols so the imports resolve, since ZML itself does not use those parts. Closes
136 lines
4.9 KiB
Diff
136 lines
4.9 KiB
Diff
From 2ae9bb9d24b569c2c6bfab3c54b428103614944d Mon Sep 17 00:00:00 2001
|
|
From: Hugo Mano <hugo@zml.ai>
|
|
Date: Tue, 27 May 2025 11:48:17 +0200
|
|
Subject: [PATCH 1/8] Added FFI handler registration API to the FFI PjRt
|
|
|
|
PR: https://github.com/openxla/xla/pull/13420
|
|
---
|
|
xla/pjrt/c/BUILD | 5 +++++
|
|
xla/pjrt/c/pjrt_c_api_ffi_extension.h | 21 ++++++++++++++++++
|
|
xla/pjrt/c/pjrt_c_api_ffi_internal.cc | 32 ++++++++++++++++++++++++++-
|
|
3 files changed, 57 insertions(+), 1 deletion(-)
|
|
|
|
diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD
|
|
index 79f18fa0bc..0f33dd8a6e 100644
|
|
--- a/xla/pjrt/c/BUILD
|
|
+++ b/xla/pjrt/c/BUILD
|
|
@@ -69,8 +69,13 @@ cc_library(
|
|
":pjrt_c_api_wrapper_impl",
|
|
"//xla/ffi:execution_context",
|
|
+ "//xla/ffi:ffi_api",
|
|
"//xla/ffi:type_id_registry",
|
|
+ "//xla/ffi/api:c_api",
|
|
+ "//xla/ffi/api:ffi",
|
|
+ "//xla/service:custom_call_target_registry",
|
|
"@com_google_absl//absl/status",
|
|
+ "@com_google_absl//absl/strings:str_format",
|
|
"@com_google_absl//absl/strings:string_view",
|
|
],
|
|
)
|
|
|
|
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
|
index 995a2c7e50..b8f10bc2f7 100644
|
|
--- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
|
+++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
|
@@ -69,10 +69,32 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data);
|
|
// Adds a user data to the execute context.
|
|
typedef PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args);
|
|


+// Bit flags for PJRT_FFI_Register_Handler_Args.traits.
|
|
+typedef enum PJRT_FFI_Handler_TraitsBits {
|
|
+ PJRT_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE = 1u << 0,
|
|
+} PJRT_FFI_Handler_TraitsBits;
|
|
+
|
|
+struct PJRT_FFI_Register_Handler_Args {
|
|
+ size_t struct_size;
|
|
+ const char* target_name; // Not required to be NUL-terminated.
|
|
+ size_t target_name_size;
|
|
+ int api_version; // 0 for an untyped call, 1 -- for typed
|
|
+ void* handler; // 0: untyped custom-call target, 1: XLA_FFI_Handler*.
|
|
+ const char* platform_name; // Not required to be NUL-terminated.
|
|
+ size_t platform_name_size;
|
|
+ PJRT_FFI_Handler_TraitsBits traits;
|
|
+};
|
|
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Register_Handler_Args, traits);
|
|
+
|
|
+// Registers an FFI call handler for a specific platform.
|
|
+typedef PJRT_Error* PJRT_FFI_Register_Handler(
|
|
+ PJRT_FFI_Register_Handler_Args* args);
|
|
+
|
|
typedef struct PJRT_FFI_Extension {
|
|
PJRT_Extension_Base base;
|
|
PJRT_FFI_TypeID_Register* type_id_register;
|
|
PJRT_FFI_UserData_Add* user_data_add;
|
|
+ PJRT_FFI_Register_Handler* register_handler;
|
|
} PJRT_FFI;
|
|
-PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, user_data_add);
|
|
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, register_handler);
|
|
|
|
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
|
|
index 5fa88eab33..763270331b 100644
|
|
--- a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
|
|
+++ b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
|
|
@@ -13,16 +13,22 @@ See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|


#include "xla/pjrt/c/pjrt_c_api_ffi_internal.h"
|
|


+#include <string>
|
|
+
|
|
#include "absl/status/status.h"
|
|
+#include "absl/strings/str_format.h"
|
|
#include "absl/strings/string_view.h"
|
|
+#include "xla/ffi/api/c_api.h"
|
|
#include "xla/ffi/execution_context.h"
|
|
+#include "xla/ffi/ffi_api.h"
|
|
#include "xla/ffi/type_id_registry.h"
|
|
#include "xla/pjrt/c/pjrt_c_api.h"
|
|
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
|
|
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
|
|
#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
|
|
+#include "xla/service/custom_call_target_registry.h"
|
|


namespace pjrt {
|
|


@@ -68,6 +74,47 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) {
|
|
return nullptr;
|
|
}
|
|


+// Registers `args->handler` as a custom-call target for one platform.
|
|
+// api_version 0 uses the legacy untyped CustomCallTargetRegistry; 1 uses the
|
|
+// typed XLA FFI handler registry. Returns nullptr on success.
|
|
+static PJRT_Error* PJRT_FFI_Register_Handler(
|
|
+ PJRT_FFI_Register_Handler_Args* args) {
|
|
+ PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
|
|
+ "PJRT_FFI_Register_Handler_Args",
|
|
+ PJRT_FFI_Register_Handler_Args_STRUCT_SIZE, args->struct_size));
|
|
+ if (args->handler == nullptr) {
|
|
+ return new PJRT_Error{absl::InvalidArgumentError(
|
|
+ "PJRT_FFI_Register_Handler_Args.handler must not be null.")};
|
|
+ }
|
|
+ // Names arrive as (pointer, size) pairs, so they need not be NUL-terminated.
|
|
+ std::string target_name(args->target_name, args->target_name_size);
|
|
+ std::string platform_name(args->platform_name, args->platform_name_size);
|
|
+ switch (args->api_version) {
|
|
+ case 0:
|
|
+ // Legacy untyped custom call: Register() reports no status and `traits`
|
|
+ // does not apply.
|
|
+ xla::CustomCallTargetRegistry::Global()->Register(
|
|
+ target_name, args->handler, platform_name);
|
|
+ return nullptr;
|
|
+ case 1:
|
|
+ // RegisterStaticHandler returns an owned XLA_FFI_Error* (nullptr on
|
|
+ // success). TakeStatus frees it and converts it to absl::Status, so a
|
|
+ // failed registration reaches the caller instead of being ignored and
|
|
+ // leaked. The caller-provided traits are forwarded too.
|
|
+ PJRT_RETURN_IF_ERROR(xla::ffi::TakeStatus(
|
|
+ xla::ffi::Ffi::RegisterStaticHandler(
|
|
+ xla::ffi::GetXlaFfiApi(), target_name, platform_name,
|
|
+ reinterpret_cast<XLA_FFI_Handler*>(args->handler),
|
|
+ static_cast<XLA_FFI_Handler_Traits>(args->traits))));
|
|
+ return nullptr;
|
|
+ default:
|
|
+ return new PJRT_Error{absl::UnimplementedError(absl::StrFormat(
|
|
+ "API version %d not supported for PJRT FFI handler registration. "
|
|
+ "Supported versions are 0 and 1.",
|
|
+ args->api_version))};
|
|
+ }
|
|
+}
|
|
+
|
|
PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
|
|
return {
|
|
PJRT_Extension_Base{
|
|
@@ -77,6 +124,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
|
|
},
|
|
/*type_id_register=*/PJRT_FFI_TypeID_Register,
|
|
/*user_data_add=*/PJRT_FFI_UserData_Add,
|
|
+ /*register_handler=*/PJRT_FFI_Register_Handler,
|
|
};
|
|
}
|
|
|
|
--
|
|
2.39.5 (Apple Git-154)
|
|
|