Radix/third_party/modules/xla/20250317.2-71c67e2/patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch

132 lines
4.8 KiB
Diff

From 367df40470c00b9a4f83e3c5bc5553e7b0878351 Mon Sep 17 00:00:00 2001
From: Hugo Mano <hugo@zml.ai>
Date: Wed, 5 Feb 2025 19:25:03 +0100
Subject: [PATCH 1/8] Added FFI handler registration API to the FFI PjRt
PR: https://github.com/openxla/xla/pull/13420
---
xla/pjrt/c/BUILD | 5 ++++
xla/pjrt/c/pjrt_c_api_ffi_extension.h | 16 ++++++++++++
xla/pjrt/c/pjrt_c_api_ffi_internal.cc | 35 +++++++++++++++++++++++++--
3 files changed, 54 insertions(+), 2 deletions(-)
diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD
index ad2ed95bce..0e7c35c30f 100644
--- a/xla/pjrt/c/BUILD
+++ b/xla/pjrt/c/BUILD
@@ -69,7 +69,12 @@ cc_library(
":pjrt_c_api_wrapper_impl",
"//xla/ffi:execution_context",
"//xla/ffi:type_id_registry",
+ "//xla/ffi:ffi_api",
+ "//xla/ffi/api:c_api",
+ "//xla/ffi/api:ffi",
+ "//xla/service:custom_call_target_registry",
"@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings:str_format",
],
)
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
index a33bd4aa9c..3309194538 100644
--- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h
+++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
@@ -66,12 +66,28 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data);
// Adds a user data to the execute context.
typedef PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args);
+struct PJRT_FFI_Register_Handler_Args {
+ size_t struct_size;
+ const char* target_name;  // length is given by target_name_size
+ size_t target_name_size;
+ int api_version; // 0 for an untyped custom call, 1 for a typed FFI call
+ void* handler;
+ const char* platform_name;  // length is given by platform_name_size
+ size_t platform_name_size;  // last field: must be the one named in the traits
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Register_Handler_Args, platform_name_size);
+
+// Registers a custom call handler (see api_version) for the given platform.
+typedef PJRT_Error* PJRT_FFI_Register_Handler(
+ PJRT_FFI_Register_Handler_Args* args);
+
typedef struct PJRT_FFI_Extension {
size_t struct_size;
PJRT_Extension_Type type;
PJRT_Extension_Base* next;
PJRT_FFI_TypeID_Register* type_id_register;
PJRT_FFI_UserData_Add* user_data_add;
+ PJRT_FFI_Register_Handler* register_handler;  // new last field, see traits
} PJRT_FFI;
-PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, user_data_add);
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, register_handler);
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
index 0375b39d0b..3527a0756e 100644
--- a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
+++ b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
@@ -13,15 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "xla/pjrt/c/pjrt_c_api_ffi_internal.h"
+#include <string>
#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/api/ffi.h"
#include "xla/ffi/execution_context.h"
-#include "xla/ffi/type_id_registry.h"
+ #include "xla/ffi/type_id_registry.h"
+#include "xla/ffi/ffi_api.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
+#include "xla/service/custom_call_target_registry.h"
namespace pjrt {
@@ -55,6 +60,31 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) {
return nullptr;
}
+static PJRT_Error* PJRT_FFI_Register_Handler(
+ PJRT_FFI_Register_Handler_Args* args) {  // returns nullptr on success
+ PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
+ "PJRT_FFI_Register_Handler_Args",
+ PJRT_FFI_Register_Handler_Args_STRUCT_SIZE, args->struct_size));
+ std::string target_name(args->target_name, args->target_name_size);  // copy
+ std::string platform_name(args->platform_name, args->platform_name_size);
+ switch (args->api_version) {  // selects which registry receives the handler
+ case 0:  // untyped call: global custom call target registry
+ xla::CustomCallTargetRegistry::Global()->Register(
+ target_name, args->handler, platform_name);
+ return nullptr;
+ case 1:  // typed call: handler is registered as an XLA_FFI_Handler
+ xla::ffi::Ffi::RegisterStaticHandler(
+ xla::ffi::GetXlaFfiApi(), target_name, platform_name,
+ reinterpret_cast<XLA_FFI_Handler*>(args->handler));
+ return nullptr;
+ default:  // any other version is reported as an Unimplemented error
+ return new PJRT_Error{absl::UnimplementedError(
+ absl::StrFormat("API version %d not supported for PJRT GPU plugin. "
+ "Supported versions are 0 and 1.",
+ args->api_version))};
+ }
+}
+
PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
return {
/*struct_size=*/PJRT_FFI_Extension_STRUCT_SIZE,
@@ -62,6 +92,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
/*next=*/next,
/*type_id_register=*/PJRT_FFI_TypeID_Register,
/*user_data_add=*/PJRT_FFI_UserData_Add,
+ /*register_handler=*/PJRT_FFI_Register_Handler,
};
}
--
2.43.0