From 367df40470c00b9a4f83e3c5bc5553e7b0878351 Mon Sep 17 00:00:00 2001
From: Hugo Mano
Date: Wed, 5 Feb 2025 19:25:03 +0100
Subject: [PATCH 1/8] Added FFI handler registration API to the FFI PjRt

PR: https://github.com/openxla/xla/pull/13420
---
Review notes (outside the commit message):
  * PJRT_FFI_Register_Handler first validates args->struct_size, then copies
    target_name and platform_name from their (pointer, size) pairs into
    std::string before dispatching on args->api_version.
  * api_version 0 registers args->handler with
    xla::CustomCallTargetRegistry::Global(); api_version 1 registers it through
    xla::ffi::Ffi::RegisterStaticHandler as an XLA_FFI_Handler*; any other
    value returns an absl::UnimplementedError wrapped in a PJRT_Error.

 xla/pjrt/c/BUILD                      |  5 ++++
 xla/pjrt/c/pjrt_c_api_ffi_extension.h | 16 ++++++++++++
 xla/pjrt/c/pjrt_c_api_ffi_internal.cc | 35 +++++++++++++++++++++++++--
 3 files changed, 54 insertions(+), 2 deletions(-)

diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD
index ad2ed95bce..0e7c35c30f 100644
--- a/xla/pjrt/c/BUILD
+++ b/xla/pjrt/c/BUILD
@@ -69,7 +69,12 @@ cc_library(
         ":pjrt_c_api_wrapper_impl",
         "//xla/ffi:execution_context",
         "//xla/ffi:type_id_registry",
+        "//xla/ffi:ffi_api",
+        "//xla/ffi/api:c_api",
+        "//xla/ffi/api:ffi",
+        "//xla/service:custom_call_target_registry",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings:str_format",
     ],
 )
 
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
index a33bd4aa9c..3309194538 100644
--- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h
+++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
@@ -66,12 +66,28 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data);
 // Adds a user data to the execute context.
 typedef PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args);
 
+struct PJRT_FFI_Register_Handler_Args {
+  size_t struct_size;
+  const char* target_name;
+  size_t target_name_size;
+  int api_version;  // 0 for an untyped call, 1 -- for typed
+  void* handler;
+  const char* platform_name;
+  size_t platform_name_size;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Register_Handler_Args, handler);
+
+// Registers an FFI call handler for a specific platform.
+typedef PJRT_Error* PJRT_FFI_Register_Handler(
+    PJRT_FFI_Register_Handler_Args* args);
+
 typedef struct PJRT_FFI_Extension {
   size_t struct_size;
   PJRT_Extension_Type type;
   PJRT_Extension_Base* next;
   PJRT_FFI_TypeID_Register* type_id_register;
   PJRT_FFI_UserData_Add* user_data_add;
+  PJRT_FFI_Register_Handler* register_handler;
 } PJRT_FFI;
 PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, user_data_add);
 
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
index 0375b39d0b..3527a0756e 100644
--- a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
+++ b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
@@ -13,15 +13,20 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "xla/pjrt/c/pjrt_c_api_ffi_internal.h"
+#include <string>
 
 #include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/api/ffi.h"
 #include "xla/ffi/execution_context.h"
-#include "xla/ffi/type_id_registry.h"
+ #include "xla/ffi/type_id_registry.h"
+#include "xla/ffi/ffi_api.h"
 #include "xla/pjrt/c/pjrt_c_api.h"
 #include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
 #include "xla/pjrt/c/pjrt_c_api_helpers.h"
 #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
+#include "xla/service/custom_call_target_registry.h"
 
 namespace pjrt {
 
@@ -55,6 +60,31 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) {
   return nullptr;
 }
 
+static PJRT_Error* PJRT_FFI_Register_Handler(
+    PJRT_FFI_Register_Handler_Args* args) {
+  PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
+      "PJRT_FFI_Register_Handler_Args",
+      PJRT_FFI_Register_Handler_Args_STRUCT_SIZE, args->struct_size));
+  std::string target_name(args->target_name, args->target_name_size);
+  std::string platform_name(args->platform_name, args->platform_name_size);
+  switch (args->api_version) {
+    case 0:
+      xla::CustomCallTargetRegistry::Global()->Register(
+          target_name, args->handler, platform_name);
+      return nullptr;
+    case 1:
+      xla::ffi::Ffi::RegisterStaticHandler(
+          xla::ffi::GetXlaFfiApi(), target_name, platform_name,
+          reinterpret_cast<XLA_FFI_Handler*>(args->handler));
+      return nullptr;
+    default:
+      return new PJRT_Error{absl::UnimplementedError(
+          absl::StrFormat("API version %d not supported for PJRT GPU plugin. "
+                          "Supported versions are 0 and 1.",
+                          args->api_version))};
+  }
+}
+
 PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
   return {
       /*struct_size=*/PJRT_FFI_Extension_STRUCT_SIZE,
@@ -62,6 +92,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
       /*next=*/next,
       /*type_id_register=*/PJRT_FFI_TypeID_Register,
       /*user_data_add=*/PJRT_FFI_UserData_Add,
+      /*register_handler=*/PJRT_FFI_Register_Handler,
   };
 }
 
-- 
2.43.0