Also: bump the XLA deps `com_github_grpc_grpc` and `com_google_protobuf`, inject `rules_ml_toolchain`, and fix the `zig_proto_library` rule.
136 lines
4.9 KiB
Diff
136 lines
4.9 KiB
Diff
From 2ae9bb9d24b569c2c6bfab3c54b428103614944d Mon Sep 17 00:00:00 2001
|
|
From: Hugo Mano <hugo@zml.ai>
|
|
Date: Tue, 27 May 2025 11:48:17 +0200
|
|
Subject: [PATCH 1/8] Add FFI handler registration API to the PJRT FFI extension
|
|
|
|
PR: https://github.com/openxla/xla/pull/13420
|
|
---
|
|
xla/pjrt/c/BUILD | 5 +++++
|
|
xla/pjrt/c/pjrt_c_api_ffi_extension.h | 21 ++++++++++++++++++
|
|
xla/pjrt/c/pjrt_c_api_ffi_internal.cc | 32 ++++++++++++++++++++++++++-
|
|
3 files changed, 57 insertions(+), 1 deletion(-)
|
|
|
|
diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD
index 79f18fa0bc..0f33dd8a6e 100644
--- a/xla/pjrt/c/BUILD
+++ b/xla/pjrt/c/BUILD
@@ -69,8 +69,13 @@ cc_library(
         ":pjrt_c_api_wrapper_impl",
         "//xla/ffi:execution_context",
+        "//xla/ffi:ffi_api",
         "//xla/ffi:type_id_registry",
+        "//xla/ffi/api:c_api",
+        "//xla/ffi/api:ffi",
+        "//xla/service:custom_call_target_registry",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/strings:string_view",
     ],
 )
 
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
index 995a2c7e50..b8f10bc2f7 100644
--- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h
+++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
@@ -69,10 +69,36 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data);
 // Adds a user data to the execute context.
 typedef PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args);
 
+// Traits that can be attached to an FFI handler registered via
+// `PJRT_FFI_Register_Handler`. Keep the bit values in sync with
+// `XLA_FFI_Handler_TraitsBits` in xla/ffi/api/c_api.h.
+typedef enum PJRT_FFI_Handler_TraitsBits {
+  PJRT_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE = 1u << 0,
+} PJRT_FFI_Handler_TraitsBits;
+
+// Arguments for `PJRT_FFI_Register_Handler`. String fields are passed as
+// (pointer, size) pairs and need not be null-terminated.
+struct PJRT_FFI_Register_Handler_Args {
+  size_t struct_size;
+  const char* target_name;
+  size_t target_name_size;
+  int api_version;  // 0: untyped custom call, 1: typed XLA FFI handler.
+  void* handler;    // For api_version 1, an `XLA_FFI_Handler*`.
+  const char* platform_name;
+  size_t platform_name_size;
+  PJRT_FFI_Handler_TraitsBits traits;  // Bitmask of the bits above.
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Register_Handler_Args, traits);
+
+// Registers an FFI call handler for a specific platform.
+typedef PJRT_Error* PJRT_FFI_Register_Handler(
+    PJRT_FFI_Register_Handler_Args* args);
+
 typedef struct PJRT_FFI_Extension {
   PJRT_Extension_Base base;
   PJRT_FFI_TypeID_Register* type_id_register;
   PJRT_FFI_UserData_Add* user_data_add;
+  PJRT_FFI_Register_Handler* register_handler;
 } PJRT_FFI;
-PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, user_data_add);
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, register_handler);
 
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
index 5fa88eab33..763270331b 100644
--- a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
+++ b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
@@ -13,16 +13,22 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
 #include "xla/pjrt/c/pjrt_c_api_ffi_internal.h"
 
+#include <string>
+
 #include "absl/status/status.h"
+#include "absl/strings/str_format.h"
 #include "absl/strings/string_view.h"
+#include "xla/ffi/api/c_api.h"
 #include "xla/ffi/execution_context.h"
+#include "xla/ffi/ffi_api.h"
 #include "xla/ffi/type_id_registry.h"
 #include "xla/pjrt/c/pjrt_c_api.h"
 #include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
 #include "xla/pjrt/c/pjrt_c_api_helpers.h"
 #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
+#include "xla/service/custom_call_target_registry.h"
 
 namespace pjrt {
 
@@ -68,6 +74,51 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) {
   return nullptr;
 }
 
+// Registers an FFI handler with XLA on behalf of a PJRT client.
+//
+// `args->api_version` selects the registry:
+//   0 -- untyped custom call; `handler` is stored as an opaque pointer in
+//        `xla::CustomCallTargetRegistry`.
+//   1 -- typed XLA FFI handler; `handler` must be an `XLA_FFI_Handler*` and
+//        `traits` is forwarded to the XLA FFI handler registry.
+// Returns nullptr on success, or a PJRT_Error if the args struct is too small,
+// `handler` is null, `api_version` is unsupported, or XLA rejects the
+// registration.
+static PJRT_Error* PJRT_FFI_Register_Handler(
+    PJRT_FFI_Register_Handler_Args* args) {
+  PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
+      "PJRT_FFI_Register_Handler_Args",
+      PJRT_FFI_Register_Handler_Args_STRUCT_SIZE, args->struct_size));
+  if (args->handler == nullptr) {
+    return new PJRT_Error{absl::InvalidArgumentError(
+        "PJRT_FFI_Register_Handler: handler must not be null.")};
+  }
+  // Names arrive as (pointer, size) pairs and need not be null-terminated.
+  std::string target_name(args->target_name, args->target_name_size);
+  std::string platform_name(args->platform_name, args->platform_name_size);
+  switch (args->api_version) {
+    case 0:
+      xla::CustomCallTargetRegistry::Global()->Register(
+          target_name, args->handler, platform_name);
+      return nullptr;
+    case 1:
+      // RegisterStaticHandler returns an owned XLA_FFI_Error* (nullptr on
+      // success). TakeStatus frees it and converts it to absl::Status, so a
+      // failed registration is reported to the caller instead of leaked.
+      PJRT_RETURN_IF_ERROR(xla::ffi::TakeStatus(
+          xla::ffi::Ffi::RegisterStaticHandler(
+              xla::ffi::GetXlaFfiApi(), target_name, platform_name,
+              reinterpret_cast<XLA_FFI_Handler*>(args->handler),
+              static_cast<XLA_FFI_Handler_Traits>(args->traits))));
+      return nullptr;
+    default:
+      return new PJRT_Error{absl::UnimplementedError(absl::StrFormat(
+          "API version %d not supported by the PJRT FFI extension. "
+          "Supported versions are 0 and 1.",
+          args->api_version))};
+  }
+}
+
 PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
   return {
       PJRT_Extension_Base{
@@ -77,6 +128,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
       },
       /*type_id_register=*/PJRT_FFI_TypeID_Register,
       /*user_data_add=*/PJRT_FFI_UserData_Add,
+      /*register_handler=*/PJRT_FFI_Register_Handler,
   };
 }
 
--
|
|
2.39.5 (Apple Git-154)
|
|
|