132 lines
4.8 KiB
Diff
132 lines
4.8 KiB
Diff
From 367df40470c00b9a4f83e3c5bc5553e7b0878351 Mon Sep 17 00:00:00 2001
|
|
From: Hugo Mano <hugo@zml.ai>
|
|
Date: Wed, 5 Feb 2025 19:25:03 +0100
|
|
Subject: [PATCH 1/8] Added FFI handler registration API to the FFI PjRt
|
|
|
|
PR: https://github.com/openxla/xla/pull/13420
|
|
---
|
|
xla/pjrt/c/BUILD | 5 ++++
|
|
xla/pjrt/c/pjrt_c_api_ffi_extension.h | 16 ++++++++++++
|
|
xla/pjrt/c/pjrt_c_api_ffi_internal.cc | 35 +++++++++++++++++++++++++--
|
|
3 files changed, 54 insertions(+), 2 deletions(-)
|
|
|
|
diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD
|
|
index ad2ed95bce..0e7c35c30f 100644
|
|
--- a/xla/pjrt/c/BUILD
|
|
+++ b/xla/pjrt/c/BUILD
|
|
@@ -69,7 +69,12 @@ cc_library(
|
|
":pjrt_c_api_wrapper_impl",
|
|
"//xla/ffi:execution_context",
|
|
"//xla/ffi:type_id_registry",
|
|
+ "//xla/ffi:ffi_api",
|
|
+ "//xla/ffi/api:c_api",
|
|
+ "//xla/ffi/api:ffi",
|
|
+ "//xla/service:custom_call_target_registry",
|
|
"@com_google_absl//absl/status",
|
|
+ "@com_google_absl//absl/strings:str_format",
|
|
],
|
|
)
|
|
|
|
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
|
index a33bd4aa9c..3309194538 100644
|
|
--- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
|
+++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
|
@@ -66,12 +66,28 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data);
|
|
// Adds a user data to the execute context.
|
|
typedef PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args);
|
|
|
|
+struct PJRT_FFI_Register_Handler_Args {
|
|
+ size_t struct_size;
|
|
+ const char* target_name;  // length is given by target_name_size
|
|
+ size_t target_name_size;
|
|
+ int api_version; // 0 for an untyped call, 1 for a typed call
|
|
+ void* handler;  // for api_version 1 this is cast to XLA_FFI_Handler*
|
|
+ const char* platform_name;  // length is given by platform_name_size
|
|
+ size_t platform_name_size;
|
|
+};
|
|
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Register_Handler_Args, platform_name_size);  // last field, so the size check covers every member
|
|
+
|
|
+// Registers an FFI call handler for a specific platform.
|
|
+typedef PJRT_Error* PJRT_FFI_Register_Handler(
|
|
+ PJRT_FFI_Register_Handler_Args* args);
|
|
+
|
|
typedef struct PJRT_FFI_Extension {
|
|
size_t struct_size;
|
|
PJRT_Extension_Type type;
|
|
PJRT_Extension_Base* next;
|
|
PJRT_FFI_TypeID_Register* type_id_register;
|
|
PJRT_FFI_UserData_Add* user_data_add;
|
|
+ PJRT_FFI_Register_Handler* register_handler;
|
|
} PJRT_FFI;
|
|
-PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, user_data_add);
|
|
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, register_handler);  // last field, so struct_size advertises register_handler
|
|
|
|
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
|
|
index 0375b39d0b..3527a0756e 100644
|
|
--- a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
|
|
+++ b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
|
|
@@ -13,15 +13,20 @@ See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
-#include "xla/pjrt/c/pjrt_c_api_ffi_internal.h"
|
|
+#include <string>
|
|
|
|
#include "absl/status/status.h"
|
|
+#include "absl/strings/str_format.h"
|
|
+#include "xla/ffi/api/c_api.h"
|
|
+#include "xla/ffi/api/ffi.h"
|
|
#include "xla/ffi/execution_context.h"
|
|
-#include "xla/ffi/type_id_registry.h"
|
|
+ #include "xla/ffi/type_id_registry.h"
|
|
+#include "xla/ffi/ffi_api.h"
|
|
#include "xla/pjrt/c/pjrt_c_api.h"
|
|
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
|
|
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
|
|
#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
|
|
+#include "xla/service/custom_call_target_registry.h"
|
|
|
|
namespace pjrt {
|
|
|
|
@@ -55,6 +60,31 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) {
|
|
return nullptr;
|
|
}
|
|
|
|
+static PJRT_Error* PJRT_FFI_Register_Handler(
|
|
+ PJRT_FFI_Register_Handler_Args* args) {
|
|
+ PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
|
|
+ "PJRT_FFI_Register_Handler_Args",
|
|
+ PJRT_FFI_Register_Handler_Args_STRUCT_SIZE, args->struct_size));
|
|
+ std::string target_name(args->target_name, args->target_name_size);
|
|
+ std::string platform_name(args->platform_name, args->platform_name_size);
|
|
+ switch (args->api_version) {
|
|
+ case 0:  // untyped call: goes to the CustomCallTargetRegistry
|
|
+ xla::CustomCallTargetRegistry::Global()->Register(
|
|
+ target_name, args->handler, platform_name);
|
|
+ return nullptr;
|
|
+ case 1:  // typed call: XLA FFI registry; a failed registration is returned to the caller
|
|
+ PJRT_RETURN_IF_ERROR(xla::ffi::TakeStatus(
|
|
+ xla::ffi::Ffi::RegisterStaticHandler(
|
|
+ xla::ffi::GetXlaFfiApi(), target_name, platform_name,
|
|
+ reinterpret_cast<XLA_FFI_Handler*>(args->handler))));
|
|
+ return nullptr;
|
|
+ default:
|
|
+ return new PJRT_Error{absl::UnimplementedError(
|
|
+ absl::StrFormat("API version %d not supported for PJRT GPU plugin. "
|
|
+ "Supported versions are 0 and 1.", args->api_version))};
|
|
+ }
|
|
+}
|
|
+
|
|
PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
|
|
return {
|
|
/*struct_size=*/PJRT_FFI_Extension_STRUCT_SIZE,
|
|
@@ -62,6 +92,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
|
|
/*next=*/next,
|
|
/*type_id_register=*/PJRT_FFI_TypeID_Register,
|
|
/*user_data_add=*/PJRT_FFI_UserData_Add,
|
|
+ /*register_handler=*/PJRT_FFI_Register_Handler,
|
|
};
|
|
}
|
|
|
|
--
|
|
2.43.0
|