From cfe38f27ca02540542acc200f79cd26d19631b38 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Wed, 3 May 2023 17:33:46 +0000 Subject: [PATCH] Switch ROCm dlopen handling to patchelf's rename_dynamic_symbols for more robust dynamic symbol import. --- bazel/cc_import.bzl | 44 +++++------ bazel/dpkg.bzl | 29 +++++++ bazel/patchelf.bzl | 17 +++- runtimes/rocm/BUILD.bazel | 4 +- runtimes/rocm/libpjrt_rocm.BUILD.bazel | 16 ++-- runtimes/rocm/rocm.bzl | 51 +++++------- runtimes/rocm/zmlrocmhooks.cc | 103 ------------------------- runtimes/rocm/zmlxrocm.cc | 82 ++++++++++++++++++++ 8 files changed, 174 insertions(+), 172 deletions(-) create mode 100644 bazel/dpkg.bzl delete mode 100644 runtimes/rocm/zmlrocmhooks.cc create mode 100644 runtimes/rocm/zmlxrocm.cc diff --git a/bazel/cc_import.bzl b/bazel/cc_import.bzl index 92cb014..1f78f46 100644 --- a/bazel/cc_import.bzl +++ b/bazel/cc_import.bzl @@ -36,21 +36,22 @@ _cc_import_runfiles = rule( ) def cc_import( - name, - static_library = None, - pic_static_library = None, - shared_library = None, - interface_library = None, - data = None, - deps = None, - visibility = None, - soname = None, - add_needed = None, - remove_needed = None, - replace_needed = None, - **kwargs): - if shared_library and (soname or add_needed or remove_needed or replace_needed): - patched_name = "{}_patchelf".format(name) + name, + static_library = None, + pic_static_library = None, + shared_library = None, + interface_library = None, + data = None, + deps = None, + visibility = None, + soname = None, + add_needed = None, + remove_needed = None, + replace_needed = None, + rename_dynamic_symbols = None, + **kwargs): + if shared_library and (soname or add_needed or remove_needed or replace_needed or rename_dynamic_symbols): + patched_name = "{}.patchelf".format(name) patchelf( name = patched_name, shared_library = shared_library, @@ -58,11 +59,12 @@ def cc_import( add_needed = add_needed, remove_needed = remove_needed, replace_needed = replace_needed, + rename_dynamic_symbols = rename_dynamic_symbols, ) shared_library = ":" + patched_name if data: _cc_import( - name = name + "_no_runfiles", + name = name + ".norunfiles", static_library = static_library, pic_static_library = pic_static_library, shared_library = shared_library, @@ -71,15 +73,10 @@ def cc_import( deps = deps, **kwargs ) - _cc_import_runfiles( + native.cc_library( name = name, - src = ":{}_no_runfiles".format(name), - static_library = static_library, - pic_static_library = pic_static_library, - shared_library = shared_library, - interface_library = interface_library, data = data, - deps = deps, + deps = [name + ".norunfiles"], visibility = visibility, ) else: @@ -89,7 +86,6 @@ def cc_import( pic_static_library = pic_static_library, shared_library = shared_library, interface_library = interface_library, - data = data, deps = deps, visibility = visibility, **kwargs diff --git a/bazel/dpkg.bzl b/bazel/dpkg.bzl new file mode 100644 index 0000000..343f188 --- /dev/null +++ b/bazel/dpkg.bzl @@ -0,0 +1,29 @@ +def _packages_to_dict(txt): + packages = {} + current_pkg = {} + for line in txt.splitlines(): + if line == "": + if current_pkg: + pkg_name = current_pkg["Package"] + pkg_version = current_pkg["Version"] + if pkg_name not in packages: + packages[pkg_name] = {} + packages[pkg_name][pkg_version] = struct(**current_pkg) + current_pkg = {} + continue + if line.startswith(" "): + current_pkg[key] += line + continue + split = line.split(": ", 1) + key = split[0] + value = len(split) > 1 and split[1] or "" + current_pkg[key] = value + return packages + +def _read_packages(mctx, label): + data = mctx.read(Label(label)) + return _packages_to_dict(data) + +dpkg = struct( + read_packages = _read_packages, +) diff --git a/bazel/patchelf.bzl b/bazel/patchelf.bzl index 244dbf6..bf59092 100644 --- a/bazel/patchelf.bzl +++ b/bazel/patchelf.bzl @@ -1,6 +1,3 @@ -def _render_kv(e): - return e - def _patchelf_impl(ctx): output_name = ctx.file.shared_library.basename if ctx.attr.soname: @@ -26,8 +23,19 @@ def _patchelf_impl(ctx): for k, v in ctx.attr.replace_needed.items(): commands.append('"$1" --replace-needed "{}" "{}" "$3"'.format(k, v)) + renamed_syms = ctx.actions.declare_file("{}.rename.txt".format(ctx.label.name)) + if ctx.attr.rename_dynamic_symbols: + content = "\n".join([ + "{} {}".format(k, v) + for k, v in ctx.attr.rename_dynamic_symbols.items() + ]) + ctx.actions.write(renamed_syms, content) + commands.append('"$1" --rename-dynamic-symbols "{}" "$3"'.format(renamed_syms.path)) + else: + ctx.actions.write(renamed_syms, "") + ctx.actions.run_shell( - inputs = [ctx.file.shared_library], + inputs = [ctx.file.shared_library, renamed_syms], outputs = [output], arguments = [ctx.executable._patchelf.path, ctx.file.shared_library.path, output.path], command = "\n".join(commands), @@ -48,6 +56,7 @@ patchelf = rule( "add_needed": attr.string_list(), "remove_needed": attr.string_list(), "replace_needed": attr.string_dict(), + "rename_dynamic_symbols": attr.string_dict(), "_patchelf": attr.label( default = "@patchelf", allow_single_file = True, diff --git a/runtimes/rocm/BUILD.bazel b/runtimes/rocm/BUILD.bazel index a75fc83..0294432 100644 --- a/runtimes/rocm/BUILD.bazel +++ b/runtimes/rocm/BUILD.bazel @@ -1,6 +1,6 @@ filegroup( - name = "zmlrocmhooks_srcs", - srcs = ["zmlrocmhooks.cc"], + name = "zmlxrocm_srcs", + srcs = ["zmlxrocm.cc"], visibility = ["@libpjrt_rocm//:__subpackages__"], ) diff --git a/runtimes/rocm/libpjrt_rocm.BUILD.bazel b/runtimes/rocm/libpjrt_rocm.BUILD.bazel index e6e333a..d28bfd2 100644 --- a/runtimes/rocm/libpjrt_rocm.BUILD.bazel +++ b/runtimes/rocm/libpjrt_rocm.BUILD.bazel @@ -31,9 +31,9 @@ copy_to_directory( ) cc_library( - name = "zmlrocmhooks_lib", + name = "zmlxrocm_lib", data = ["@rocblas//:runfiles"], - srcs = ["@zml//runtimes/rocm:zmlrocmhooks_srcs"], + srcs = ["@zml//runtimes/rocm:zmlxrocm_srcs"], linkopts = [ "-lc", "-ldl", @@ -42,14 +42,14 @@ cc_library( ) cc_shared_library( - name = "zmlrocmhooks_so", - shared_lib_name = "libzmlrocmhooks.so.0", - deps = [":zmlrocmhooks_lib"], + name = "zmlxrocm_so", + shared_lib_name = "libzmlxrocm.so.0", + deps = [":zmlxrocm_lib"], ) cc_import( - name = "zmlrocmhooks", - shared_library = ":zmlrocmhooks_so", + name = "zmlxrocm", + shared_library = ":zmlxrocm_so", ) cc_import( @@ -64,7 +64,7 @@ cc_import( shared_library = "libpjrt_rocm.so", visibility = ["//visibility:public"], deps = [ - ":zmlrocmhooks", + ":zmlxrocm", "@comgr//:amd_comgr", "@hip-runtime-amd//:amdhip", "@hipblaslt", diff --git a/runtimes/rocm/rocm.bzl b/runtimes/rocm/rocm.bzl index d261be1..d93c9a8 100644 --- a/runtimes/rocm/rocm.bzl +++ b/runtimes/rocm/rocm.bzl @@ -1,17 +1,18 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("//bazel:dpkg.bzl", "dpkg") load("//bazel:http_deb_archive.bzl", "http_deb_archive") ROCM_VERSION = "6.2.2" BASE_URL = "https://repo.radeon.com/rocm/apt/{}".format(ROCM_VERSION) STRIP_PREFIX = "opt/rocm-6.2.2" -def pkg_kwargs(pkg, packages): - return { - "name": pkg, - "urls": [BASE_URL + "/" + packages[pkg]["Filename"]], - "sha256": packages[pkg]["SHA256"], - "strip_prefix": STRIP_PREFIX, - } +def pkg_kwargs(pkg): + return dict( + name = pkg.Package, + urls = [BASE_URL + "/" + pkg.Filename], + sha256 = pkg.SHA256, + strip_prefix = STRIP_PREFIX, + ) def _ubuntu_package(path, deb_path, sha256, name, shared_library): return { @@ -118,7 +119,10 @@ load("@zml//runtimes/rocm:gfx.bzl", "bytecode_select") cc_import( name = "rocblas", shared_library = "lib/librocblas.so.4", - add_needed = ["libzmlrocmhooks.so.0"], + add_needed = ["libzmlxrocm.so.0"], + rename_dynamic_symbols = { + "dlopen": "zmlxrocm_dlopen", + }, visibility = ["@libpjrt_rocm//:__subpackages__"], ) @@ -147,7 +151,10 @@ load("@zml//bazel:cc_import.bzl", "cc_import") cc_import( name = "hipblaslt", shared_library = "lib/libhipblaslt.so.0", - add_needed = ["libzmlrocmhooks.so.0"], + add_needed = ["libzmlxrocm.so.0"], + rename_dynamic_symbols = { + "dlopen": "zmlxrocm_dlopen", + }, visibility = ["@libpjrt_rocm//:__subpackages__"], ) """, @@ -193,32 +200,14 @@ filegroup( """, } -def _packages_to_dict(txt): - packages = {} - current_pkg = {} - for line in txt.splitlines(): - if line == "": - if current_pkg: - packages[current_pkg["Package"]] = current_pkg - current_pkg = {} - continue - if line.startswith(" "): - current_pkg[key] += line - continue - split = line.split(": ", 1) - key = split[0] - value = len(split) > 1 and split[1] or "" - current_pkg[key] = value - return packages - def _rocm_impl(mctx): - data = mctx.read(Label("@zml//runtimes/rocm:packages.amd64.txt")) - PACKAGES = _packages_to_dict(data) + all_packages = dpkg.read_packages(mctx, "@zml//runtimes/rocm:packages.amd64.txt") - for pkg, build_file_content in _PACKAGES.items(): + for pkg_name, build_file_content in _PACKAGES.items(): + pkg = all_packages[pkg_name].values()[0] http_deb_archive( build_file_content = build_file_content, - **pkg_kwargs(pkg, PACKAGES) + **pkg_kwargs(pkg) ) for repository, kwargs in _UBUNTU_PACKAGES.items(): diff --git a/runtimes/rocm/zmlrocmhooks.cc b/runtimes/rocm/zmlrocmhooks.cc deleted file mode 100644 index 0c0fc60..0000000 --- a/runtimes/rocm/zmlrocmhooks.cc +++ /dev/null @@ -1,103 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "tools/cpp/runfiles/runfiles.h" - -namespace zml -{ - using bazel::tools::cpp::runfiles::Runfiles; - - std::unique_ptr runfiles; - std::string ROCBLAS_TENSILE_LIBPATH; - std::string HIPBLASLT_TENSILE_LIBPATH; - std::string HIPBLASLT_EXT_OP_LIBRARY_PATH; - std::string ROCM_PATH; - - typedef void *(*dlopen_func)(const char *filename, int flags); - dlopen_func dlopen_orig = nullptr; - - __attribute__((constructor)) static void setup(int argc, char **argv) - { - runfiles = std::unique_ptr(Runfiles::Create(argv[0], BAZEL_CURRENT_REPOSITORY)); - - HIPBLASLT_EXT_OP_LIBRARY_PATH = runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library/hipblasltExtOpLibrary.dat"); - if (HIPBLASLT_EXT_OP_LIBRARY_PATH != "") - { - setenv("HIPBLASLT_EXT_OP_LIBRARY_PATH", HIPBLASLT_EXT_OP_LIBRARY_PATH.c_str(), 1); - } - - HIPBLASLT_TENSILE_LIBPATH = runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library"); - if (HIPBLASLT_TENSILE_LIBPATH != "") - { - setenv("HIPBLASLT_TENSILE_LIBPATH", HIPBLASLT_TENSILE_LIBPATH.c_str(), 1); - } - - ROCBLAS_TENSILE_LIBPATH = runfiles->Rlocation("rocblas/lib/rocblas/library"); - setenv("ROCBLAS_TENSILE_LIBPATH", ROCBLAS_TENSILE_LIBPATH.c_str(), 1); - - ROCM_PATH = runfiles->Rlocation("libpjrt_rocm/sandbox"); - setenv("ROCM_PATH", ROCM_PATH.c_str(), 1); - } - - static void *rocm_dlopen(const char *filename, int flags) - { - if (filename != NULL) - { - char *replacements[] = { - "librocm-core.so", - "librocm-core.so.1", - "librocm_smi64.so", - "librocm_smi64.so.7", - "libhsa-runtime64.so", - "libhsa-runtime64.so.1", - "libhsa-amd-aqlprofile64.so", - "libhsa-amd-aqlprofile64.so.1", - "libamd_comgr.so", - "libamd_comgr.so.2", - "librocprofiler-register.so", - "librocprofiler-register.so.0", - "libMIOpen.so", - "libMIOpen.so.1", - "librccl.so", - "librccl.so.1", - "librocblas.so", - "librocblas.so.4", - "libroctracer64.so", - "libroctracer64.so.4", - "libroctx64.so", - "libroctx64.so.4", - "libhipblaslt.so", - "libhipblaslt.so.0", - "libamdhip64.so", - "libamdhip64.so.6", - "libhiprtc.so", - "libhiprtc.so.6", - NULL, - NULL, - }; - for (int i = 0; replacements[i] != NULL; i += 2) - { - if (strcmp(filename, replacements[i]) == 0) - { - filename = replacements[i + 1]; - break; - } - } - } - return dlopen_orig(filename, flags); - } -} - -extern "C" -{ - zml::dlopen_func _zml_rocm_resolve_dlopen() - { - zml::dlopen_orig = (zml::dlopen_func)dlsym(RTLD_NEXT, "dlopen"); - return zml::rocm_dlopen; - } - - extern void *dlopen(const char *filename, int flags) __attribute__((ifunc("_zml_rocm_resolve_dlopen"))); -} diff --git a/runtimes/rocm/zmlxrocm.cc b/runtimes/rocm/zmlxrocm.cc new file mode 100644 index 0000000..7a42fc4 --- /dev/null +++ b/runtimes/rocm/zmlxrocm.cc @@ -0,0 +1,82 @@ +#include +#include +#include + +#include +#include +#include + +#include "tools/cpp/runfiles/runfiles.h" + +__attribute__((constructor)) static void setup_runfiles(int argc, char **argv) +{ + using bazel::tools::cpp::runfiles::Runfiles; + auto runfiles = std::unique_ptr(Runfiles::Create(argv[0], BAZEL_CURRENT_REPOSITORY)); + + auto HIPBLASLT_EXT_OP_LIBRARY_PATH = + runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library/hipblasltExtOpLibrary.dat"); + if (HIPBLASLT_EXT_OP_LIBRARY_PATH != "") + { + setenv("HIPBLASLT_EXT_OP_LIBRARY_PATH", HIPBLASLT_EXT_OP_LIBRARY_PATH.c_str(), 1); + } + + auto HIPBLASLT_TENSILE_LIBPATH = runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library"); + if (HIPBLASLT_TENSILE_LIBPATH != "") + { + setenv("HIPBLASLT_TENSILE_LIBPATH", HIPBLASLT_TENSILE_LIBPATH.c_str(), 1); + } + + auto ROCBLAS_TENSILE_LIBPATH = runfiles->Rlocation("rocblas/lib/rocblas/library"); + setenv("ROCBLAS_TENSILE_LIBPATH", ROCBLAS_TENSILE_LIBPATH.c_str(), 1); + + auto ROCM_PATH = runfiles->Rlocation("libpjrt_rocm/sandbox"); + setenv("ROCM_PATH", ROCM_PATH.c_str(), 1); +} + +extern "C" void *zmlxrocm_dlopen(const char *filename, int flags) +{ + if (filename != NULL) + { + char *replacements[] = { + "librocm-core.so", + "librocm-core.so.1", + "librocm_smi64.so", + "librocm_smi64.so.7", + "libhsa-runtime64.so", + "libhsa-runtime64.so.1", + "libhsa-amd-aqlprofile64.so", + "libhsa-amd-aqlprofile64.so.1", + "libamd_comgr.so", + "libamd_comgr.so.2", + "librocprofiler-register.so", + "librocprofiler-register.so.0", + "libMIOpen.so", + "libMIOpen.so.1", + "librccl.so", + "librccl.so.1", + "librocblas.so", + "librocblas.so.4", + "libroctracer64.so", + "libroctracer64.so.4", + "libroctx64.so", + "libroctx64.so.4", + "libhipblaslt.so", + "libhipblaslt.so.0", + "libamdhip64.so", + "libamdhip64.so.6", + "libhiprtc.so", + "libhiprtc.so.6", + NULL, + NULL, + }; + for (int i = 0; replacements[i] != NULL; i += 2) + { + if (strcmp(filename, replacements[i]) == 0) + { + filename = replacements[i + 1]; + break; + } + } + } + return dlopen(filename, flags); +}