Switch ROCm dlopen handling to patchelf's rename_dynamic_symbols for more robust dynamic symbol import.

This commit is contained in:
Tarry Singh 2023-05-03 17:33:46 +00:00
parent fefd84b1bb
commit cfe38f27ca
8 changed files with 174 additions and 172 deletions

View File

@ -36,21 +36,22 @@ _cc_import_runfiles = rule(
)
def cc_import(
name,
static_library = None,
pic_static_library = None,
shared_library = None,
interface_library = None,
data = None,
deps = None,
visibility = None,
soname = None,
add_needed = None,
remove_needed = None,
replace_needed = None,
**kwargs):
if shared_library and (soname or add_needed or remove_needed or replace_needed):
patched_name = "{}_patchelf".format(name)
name,
static_library = None,
pic_static_library = None,
shared_library = None,
interface_library = None,
data = None,
deps = None,
visibility = None,
soname = None,
add_needed = None,
remove_needed = None,
replace_needed = None,
rename_dynamic_symbols = None,
**kwargs):
if shared_library and (soname or add_needed or remove_needed or replace_needed or rename_dynamic_symbols):
patched_name = "{}.patchelf".format(name)
patchelf(
name = patched_name,
shared_library = shared_library,
@ -58,11 +59,12 @@ def cc_import(
add_needed = add_needed,
remove_needed = remove_needed,
replace_needed = replace_needed,
rename_dynamic_symbols = rename_dynamic_symbols,
)
shared_library = ":" + patched_name
if data:
_cc_import(
name = name + "_no_runfiles",
name = name + ".norunfiles",
static_library = static_library,
pic_static_library = pic_static_library,
shared_library = shared_library,
@ -71,15 +73,10 @@ def cc_import(
deps = deps,
**kwargs
)
_cc_import_runfiles(
native.cc_library(
name = name,
src = ":{}_no_runfiles".format(name),
static_library = static_library,
pic_static_library = pic_static_library,
shared_library = shared_library,
interface_library = interface_library,
data = data,
deps = deps,
deps = [name + ".norunfiles"],
visibility = visibility,
)
else:
@ -89,7 +86,6 @@ def cc_import(
pic_static_library = pic_static_library,
shared_library = shared_library,
interface_library = interface_library,
data = data,
deps = deps,
visibility = visibility,
**kwargs

29
bazel/dpkg.bzl Normal file
View File

@ -0,0 +1,29 @@
def _packages_to_dict(txt):
packages = {}
current_pkg = {}
for line in txt.splitlines():
if line == "":
if current_pkg:
pkg_name = current_pkg["Package"]
pkg_version = current_pkg["Version"]
if pkg_name not in packages:
packages[pkg_name] = {}
packages[pkg_name][pkg_version] = struct(**current_pkg)
current_pkg = {}
continue
if line.startswith(" "):
current_pkg[key] += line
continue
split = line.split(": ", 1)
key = split[0]
value = len(split) > 1 and split[1] or ""
current_pkg[key] = value
return packages
def _read_packages(mctx, label):
data = mctx.read(Label(label))
return _packages_to_dict(data)
dpkg = struct(
read_packages = _read_packages,
)

View File

@ -1,6 +1,3 @@
def _render_kv(e):
return e
def _patchelf_impl(ctx):
output_name = ctx.file.shared_library.basename
if ctx.attr.soname:
@ -26,8 +23,19 @@ def _patchelf_impl(ctx):
for k, v in ctx.attr.replace_needed.items():
commands.append('"$1" --replace-needed "{}" "{}" "$3"'.format(k, v))
renamed_syms = ctx.actions.declare_file("{}.rename.txt".format(ctx.label.name))
if ctx.attr.rename_dynamic_symbols:
content = "\n".join([
"{} {}".format(k, v)
for k, v in ctx.attr.rename_dynamic_symbols.items()
])
ctx.actions.write(renamed_syms, content)
commands.append('"$1" --rename-dynamic-symbols "{}" "$3"'.format(renamed_syms.path))
else:
ctx.actions.write(renamed_syms, "")
ctx.actions.run_shell(
inputs = [ctx.file.shared_library],
inputs = [ctx.file.shared_library, renamed_syms],
outputs = [output],
arguments = [ctx.executable._patchelf.path, ctx.file.shared_library.path, output.path],
command = "\n".join(commands),
@ -48,6 +56,7 @@ patchelf = rule(
"add_needed": attr.string_list(),
"remove_needed": attr.string_list(),
"replace_needed": attr.string_dict(),
"rename_dynamic_symbols": attr.string_dict(),
"_patchelf": attr.label(
default = "@patchelf",
allow_single_file = True,

View File

@ -1,6 +1,6 @@
filegroup(
name = "zmlrocmhooks_srcs",
srcs = ["zmlrocmhooks.cc"],
name = "zmlxrocm_srcs",
srcs = ["zmlxrocm.cc"],
visibility = ["@libpjrt_rocm//:__subpackages__"],
)

View File

@ -31,9 +31,9 @@ copy_to_directory(
)
cc_library(
name = "zmlrocmhooks_lib",
name = "zmlxrocm_lib",
data = ["@rocblas//:runfiles"],
srcs = ["@zml//runtimes/rocm:zmlrocmhooks_srcs"],
srcs = ["@zml//runtimes/rocm:zmlxrocm_srcs"],
linkopts = [
"-lc",
"-ldl",
@ -42,14 +42,14 @@ cc_library(
)
cc_shared_library(
name = "zmlrocmhooks_so",
shared_lib_name = "libzmlrocmhooks.so.0",
deps = [":zmlrocmhooks_lib"],
name = "zmlxrocm_so",
shared_lib_name = "libzmlxrocm.so.0",
deps = [":zmlxrocm_lib"],
)
cc_import(
name = "zmlrocmhooks",
shared_library = ":zmlrocmhooks_so",
name = "zmlxrocm",
shared_library = ":zmlxrocm_so",
)
cc_import(
@ -64,7 +64,7 @@ cc_import(
shared_library = "libpjrt_rocm.so",
visibility = ["//visibility:public"],
deps = [
":zmlrocmhooks",
":zmlxrocm",
"@comgr//:amd_comgr",
"@hip-runtime-amd//:amdhip",
"@hipblaslt",

View File

@ -1,17 +1,18 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("//bazel:dpkg.bzl", "dpkg")
load("//bazel:http_deb_archive.bzl", "http_deb_archive")
ROCM_VERSION = "6.2.2"
BASE_URL = "https://repo.radeon.com/rocm/apt/{}".format(ROCM_VERSION)
STRIP_PREFIX = "opt/rocm-6.2.2"
def pkg_kwargs(pkg, packages):
return {
"name": pkg,
"urls": [BASE_URL + "/" + packages[pkg]["Filename"]],
"sha256": packages[pkg]["SHA256"],
"strip_prefix": STRIP_PREFIX,
}
def pkg_kwargs(pkg):
return dict(
name = pkg.Package,
urls = [BASE_URL + "/" + pkg.Filename],
sha256 = pkg.SHA256,
strip_prefix = STRIP_PREFIX,
)
def _ubuntu_package(path, deb_path, sha256, name, shared_library):
return {
@ -118,7 +119,10 @@ load("@zml//runtimes/rocm:gfx.bzl", "bytecode_select")
cc_import(
name = "rocblas",
shared_library = "lib/librocblas.so.4",
add_needed = ["libzmlrocmhooks.so.0"],
add_needed = ["libzmlxrocm.so.0"],
rename_dynamic_symbols = {
"dlopen": "zmlxrocm_dlopen",
},
visibility = ["@libpjrt_rocm//:__subpackages__"],
)
@ -147,7 +151,10 @@ load("@zml//bazel:cc_import.bzl", "cc_import")
cc_import(
name = "hipblaslt",
shared_library = "lib/libhipblaslt.so.0",
add_needed = ["libzmlrocmhooks.so.0"],
add_needed = ["libzmlxrocm.so.0"],
rename_dynamic_symbols = {
"dlopen": "zmlxrocm_dlopen",
},
visibility = ["@libpjrt_rocm//:__subpackages__"],
)
""",
@ -193,32 +200,14 @@ filegroup(
""",
}
def _packages_to_dict(txt):
packages = {}
current_pkg = {}
for line in txt.splitlines():
if line == "":
if current_pkg:
packages[current_pkg["Package"]] = current_pkg
current_pkg = {}
continue
if line.startswith(" "):
current_pkg[key] += line
continue
split = line.split(": ", 1)
key = split[0]
value = len(split) > 1 and split[1] or ""
current_pkg[key] = value
return packages
def _rocm_impl(mctx):
data = mctx.read(Label("@zml//runtimes/rocm:packages.amd64.txt"))
PACKAGES = _packages_to_dict(data)
all_packages = dpkg.read_packages(mctx, "@zml//runtimes/rocm:packages.amd64.txt")
for pkg, build_file_content in _PACKAGES.items():
for pkg_name, build_file_content in _PACKAGES.items():
pkg = all_packages[pkg_name].values()[0]
http_deb_archive(
build_file_content = build_file_content,
**pkg_kwargs(pkg, PACKAGES)
**pkg_kwargs(pkg)
)
for repository, kwargs in _UBUNTU_PACKAGES.items():

View File

@ -1,103 +0,0 @@
#include <string>
#include <iostream>
#include <dlfcn.h>
#include <errno.h>
#include <fstream>
#include <stdlib.h>
#include "tools/cpp/runfiles/runfiles.h"
namespace zml
{
using bazel::tools::cpp::runfiles::Runfiles;
std::unique_ptr<Runfiles> runfiles;
std::string ROCBLAS_TENSILE_LIBPATH;
std::string HIPBLASLT_TENSILE_LIBPATH;
std::string HIPBLASLT_EXT_OP_LIBRARY_PATH;
std::string ROCM_PATH;
typedef void *(*dlopen_func)(const char *filename, int flags);
dlopen_func dlopen_orig = nullptr;
__attribute__((constructor)) static void setup(int argc, char **argv)
{
runfiles = std::unique_ptr<Runfiles>(Runfiles::Create(argv[0], BAZEL_CURRENT_REPOSITORY));
HIPBLASLT_EXT_OP_LIBRARY_PATH = runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library/hipblasltExtOpLibrary.dat");
if (HIPBLASLT_EXT_OP_LIBRARY_PATH != "")
{
setenv("HIPBLASLT_EXT_OP_LIBRARY_PATH", HIPBLASLT_EXT_OP_LIBRARY_PATH.c_str(), 1);
}
HIPBLASLT_TENSILE_LIBPATH = runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library");
if (HIPBLASLT_TENSILE_LIBPATH != "")
{
setenv("HIPBLASLT_TENSILE_LIBPATH", HIPBLASLT_TENSILE_LIBPATH.c_str(), 1);
}
ROCBLAS_TENSILE_LIBPATH = runfiles->Rlocation("rocblas/lib/rocblas/library");
setenv("ROCBLAS_TENSILE_LIBPATH", ROCBLAS_TENSILE_LIBPATH.c_str(), 1);
ROCM_PATH = runfiles->Rlocation("libpjrt_rocm/sandbox");
setenv("ROCM_PATH", ROCM_PATH.c_str(), 1);
}
static void *rocm_dlopen(const char *filename, int flags)
{
if (filename != NULL)
{
char *replacements[] = {
"librocm-core.so",
"librocm-core.so.1",
"librocm_smi64.so",
"librocm_smi64.so.7",
"libhsa-runtime64.so",
"libhsa-runtime64.so.1",
"libhsa-amd-aqlprofile64.so",
"libhsa-amd-aqlprofile64.so.1",
"libamd_comgr.so",
"libamd_comgr.so.2",
"librocprofiler-register.so",
"librocprofiler-register.so.0",
"libMIOpen.so",
"libMIOpen.so.1",
"librccl.so",
"librccl.so.1",
"librocblas.so",
"librocblas.so.4",
"libroctracer64.so",
"libroctracer64.so.4",
"libroctx64.so",
"libroctx64.so.4",
"libhipblaslt.so",
"libhipblaslt.so.0",
"libamdhip64.so",
"libamdhip64.so.6",
"libhiprtc.so",
"libhiprtc.so.6",
NULL,
NULL,
};
for (int i = 0; replacements[i] != NULL; i += 2)
{
if (strcmp(filename, replacements[i]) == 0)
{
filename = replacements[i + 1];
break;
}
}
}
return dlopen_orig(filename, flags);
}
}
extern "C"
{
zml::dlopen_func _zml_rocm_resolve_dlopen()
{
zml::dlopen_orig = (zml::dlopen_func)dlsym(RTLD_NEXT, "dlopen");
return zml::rocm_dlopen;
}
extern void *dlopen(const char *filename, int flags) __attribute__((ifunc("_zml_rocm_resolve_dlopen")));
}

82
runtimes/rocm/zmlxrocm.cc Normal file
View File

@ -0,0 +1,82 @@
#include <dlfcn.h>
#include <errno.h>
#include <stdlib.h>
#include <fstream>
#include <iostream>
#include <string>
#include "tools/cpp/runfiles/runfiles.h"
__attribute__((constructor)) static void setup_runfiles(int argc, char **argv)
{
using bazel::tools::cpp::runfiles::Runfiles;
auto runfiles = std::unique_ptr<Runfiles>(Runfiles::Create(argv[0], BAZEL_CURRENT_REPOSITORY));
auto HIPBLASLT_EXT_OP_LIBRARY_PATH =
runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library/hipblasltExtOpLibrary.dat");
if (HIPBLASLT_EXT_OP_LIBRARY_PATH != "")
{
setenv("HIPBLASLT_EXT_OP_LIBRARY_PATH", HIPBLASLT_EXT_OP_LIBRARY_PATH.c_str(), 1);
}
auto HIPBLASLT_TENSILE_LIBPATH = runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library");
if (HIPBLASLT_TENSILE_LIBPATH != "")
{
setenv("HIPBLASLT_TENSILE_LIBPATH", HIPBLASLT_TENSILE_LIBPATH.c_str(), 1);
}
auto ROCBLAS_TENSILE_LIBPATH = runfiles->Rlocation("rocblas/lib/rocblas/library");
setenv("ROCBLAS_TENSILE_LIBPATH", ROCBLAS_TENSILE_LIBPATH.c_str(), 1);
auto ROCM_PATH = runfiles->Rlocation("libpjrt_rocm/sandbox");
setenv("ROCM_PATH", ROCM_PATH.c_str(), 1);
}
extern "C" void *zmlxrocm_dlopen(const char *filename, int flags)
{
if (filename != NULL)
{
char *replacements[] = {
"librocm-core.so",
"librocm-core.so.1",
"librocm_smi64.so",
"librocm_smi64.so.7",
"libhsa-runtime64.so",
"libhsa-runtime64.so.1",
"libhsa-amd-aqlprofile64.so",
"libhsa-amd-aqlprofile64.so.1",
"libamd_comgr.so",
"libamd_comgr.so.2",
"librocprofiler-register.so",
"librocprofiler-register.so.0",
"libMIOpen.so",
"libMIOpen.so.1",
"librccl.so",
"librccl.so.1",
"librocblas.so",
"librocblas.so.4",
"libroctracer64.so",
"libroctracer64.so.4",
"libroctx64.so",
"libroctx64.so.4",
"libhipblaslt.so",
"libhipblaslt.so.0",
"libamdhip64.so",
"libamdhip64.so.6",
"libhiprtc.so",
"libhiprtc.so.6",
NULL,
NULL,
};
for (int i = 0; replacements[i] != NULL; i += 2)
{
if (strcmp(filename, replacements[i]) == 0)
{
filename = replacements[i + 1];
break;
}
}
}
return dlopen(filename, flags);
}