Switch ROCm dlopen handling to patchelf's rename_dynamic_symbols for more robust dynamic symbol import.
This commit is contained in:
parent
fefd84b1bb
commit
cfe38f27ca
@ -36,21 +36,22 @@ _cc_import_runfiles = rule(
|
||||
)
|
||||
|
||||
def cc_import(
|
||||
name,
|
||||
static_library = None,
|
||||
pic_static_library = None,
|
||||
shared_library = None,
|
||||
interface_library = None,
|
||||
data = None,
|
||||
deps = None,
|
||||
visibility = None,
|
||||
soname = None,
|
||||
add_needed = None,
|
||||
remove_needed = None,
|
||||
replace_needed = None,
|
||||
**kwargs):
|
||||
if shared_library and (soname or add_needed or remove_needed or replace_needed):
|
||||
patched_name = "{}_patchelf".format(name)
|
||||
name,
|
||||
static_library = None,
|
||||
pic_static_library = None,
|
||||
shared_library = None,
|
||||
interface_library = None,
|
||||
data = None,
|
||||
deps = None,
|
||||
visibility = None,
|
||||
soname = None,
|
||||
add_needed = None,
|
||||
remove_needed = None,
|
||||
replace_needed = None,
|
||||
rename_dynamic_symbols = None,
|
||||
**kwargs):
|
||||
if shared_library and (soname or add_needed or remove_needed or replace_needed or rename_dynamic_symbols):
|
||||
patched_name = "{}.patchelf".format(name)
|
||||
patchelf(
|
||||
name = patched_name,
|
||||
shared_library = shared_library,
|
||||
@ -58,11 +59,12 @@ def cc_import(
|
||||
add_needed = add_needed,
|
||||
remove_needed = remove_needed,
|
||||
replace_needed = replace_needed,
|
||||
rename_dynamic_symbols = rename_dynamic_symbols,
|
||||
)
|
||||
shared_library = ":" + patched_name
|
||||
if data:
|
||||
_cc_import(
|
||||
name = name + "_no_runfiles",
|
||||
name = name + ".norunfiles",
|
||||
static_library = static_library,
|
||||
pic_static_library = pic_static_library,
|
||||
shared_library = shared_library,
|
||||
@ -71,15 +73,10 @@ def cc_import(
|
||||
deps = deps,
|
||||
**kwargs
|
||||
)
|
||||
_cc_import_runfiles(
|
||||
native.cc_library(
|
||||
name = name,
|
||||
src = ":{}_no_runfiles".format(name),
|
||||
static_library = static_library,
|
||||
pic_static_library = pic_static_library,
|
||||
shared_library = shared_library,
|
||||
interface_library = interface_library,
|
||||
data = data,
|
||||
deps = deps,
|
||||
deps = [name + ".norunfiles"],
|
||||
visibility = visibility,
|
||||
)
|
||||
else:
|
||||
@ -89,7 +86,6 @@ def cc_import(
|
||||
pic_static_library = pic_static_library,
|
||||
shared_library = shared_library,
|
||||
interface_library = interface_library,
|
||||
data = data,
|
||||
deps = deps,
|
||||
visibility = visibility,
|
||||
**kwargs
|
||||
|
||||
29
bazel/dpkg.bzl
Normal file
29
bazel/dpkg.bzl
Normal file
@ -0,0 +1,29 @@
|
||||
def _packages_to_dict(txt):
|
||||
packages = {}
|
||||
current_pkg = {}
|
||||
for line in txt.splitlines():
|
||||
if line == "":
|
||||
if current_pkg:
|
||||
pkg_name = current_pkg["Package"]
|
||||
pkg_version = current_pkg["Version"]
|
||||
if pkg_name not in packages:
|
||||
packages[pkg_name] = {}
|
||||
packages[pkg_name][pkg_version] = struct(**current_pkg)
|
||||
current_pkg = {}
|
||||
continue
|
||||
if line.startswith(" "):
|
||||
current_pkg[key] += line
|
||||
continue
|
||||
split = line.split(": ", 1)
|
||||
key = split[0]
|
||||
value = len(split) > 1 and split[1] or ""
|
||||
current_pkg[key] = value
|
||||
return packages
|
||||
|
||||
def _read_packages(mctx, label):
|
||||
data = mctx.read(Label(label))
|
||||
return _packages_to_dict(data)
|
||||
|
||||
dpkg = struct(
|
||||
read_packages = _read_packages,
|
||||
)
|
||||
@ -1,6 +1,3 @@
|
||||
def _render_kv(e):
|
||||
return e
|
||||
|
||||
def _patchelf_impl(ctx):
|
||||
output_name = ctx.file.shared_library.basename
|
||||
if ctx.attr.soname:
|
||||
@ -26,8 +23,19 @@ def _patchelf_impl(ctx):
|
||||
for k, v in ctx.attr.replace_needed.items():
|
||||
commands.append('"$1" --replace-needed "{}" "{}" "$3"'.format(k, v))
|
||||
|
||||
renamed_syms = ctx.actions.declare_file("{}.rename.txt".format(ctx.label.name))
|
||||
if ctx.attr.rename_dynamic_symbols:
|
||||
content = "\n".join([
|
||||
"{} {}".format(k, v)
|
||||
for k, v in ctx.attr.rename_dynamic_symbols.items()
|
||||
])
|
||||
ctx.actions.write(renamed_syms, content)
|
||||
commands.append('"$1" --rename-dynamic-symbols "{}" "$3"'.format(renamed_syms.path))
|
||||
else:
|
||||
ctx.actions.write(renamed_syms, "")
|
||||
|
||||
ctx.actions.run_shell(
|
||||
inputs = [ctx.file.shared_library],
|
||||
inputs = [ctx.file.shared_library, renamed_syms],
|
||||
outputs = [output],
|
||||
arguments = [ctx.executable._patchelf.path, ctx.file.shared_library.path, output.path],
|
||||
command = "\n".join(commands),
|
||||
@ -48,6 +56,7 @@ patchelf = rule(
|
||||
"add_needed": attr.string_list(),
|
||||
"remove_needed": attr.string_list(),
|
||||
"replace_needed": attr.string_dict(),
|
||||
"rename_dynamic_symbols": attr.string_dict(),
|
||||
"_patchelf": attr.label(
|
||||
default = "@patchelf",
|
||||
allow_single_file = True,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
filegroup(
|
||||
name = "zmlrocmhooks_srcs",
|
||||
srcs = ["zmlrocmhooks.cc"],
|
||||
name = "zmlxrocm_srcs",
|
||||
srcs = ["zmlxrocm.cc"],
|
||||
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
||||
)
|
||||
|
||||
|
||||
@ -31,9 +31,9 @@ copy_to_directory(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "zmlrocmhooks_lib",
|
||||
name = "zmlxrocm_lib",
|
||||
data = ["@rocblas//:runfiles"],
|
||||
srcs = ["@zml//runtimes/rocm:zmlrocmhooks_srcs"],
|
||||
srcs = ["@zml//runtimes/rocm:zmlxrocm_srcs"],
|
||||
linkopts = [
|
||||
"-lc",
|
||||
"-ldl",
|
||||
@ -42,14 +42,14 @@ cc_library(
|
||||
)
|
||||
|
||||
cc_shared_library(
|
||||
name = "zmlrocmhooks_so",
|
||||
shared_lib_name = "libzmlrocmhooks.so.0",
|
||||
deps = [":zmlrocmhooks_lib"],
|
||||
name = "zmlxrocm_so",
|
||||
shared_lib_name = "libzmlxrocm.so.0",
|
||||
deps = [":zmlxrocm_lib"],
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "zmlrocmhooks",
|
||||
shared_library = ":zmlrocmhooks_so",
|
||||
name = "zmlxrocm",
|
||||
shared_library = ":zmlxrocm_so",
|
||||
)
|
||||
|
||||
cc_import(
|
||||
@ -64,7 +64,7 @@ cc_import(
|
||||
shared_library = "libpjrt_rocm.so",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":zmlrocmhooks",
|
||||
":zmlxrocm",
|
||||
"@comgr//:amd_comgr",
|
||||
"@hip-runtime-amd//:amdhip",
|
||||
"@hipblaslt",
|
||||
|
||||
@ -1,17 +1,18 @@
|
||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
load("//bazel:dpkg.bzl", "dpkg")
|
||||
load("//bazel:http_deb_archive.bzl", "http_deb_archive")
|
||||
|
||||
ROCM_VERSION = "6.2.2"
|
||||
BASE_URL = "https://repo.radeon.com/rocm/apt/{}".format(ROCM_VERSION)
|
||||
STRIP_PREFIX = "opt/rocm-6.2.2"
|
||||
|
||||
def pkg_kwargs(pkg, packages):
|
||||
return {
|
||||
"name": pkg,
|
||||
"urls": [BASE_URL + "/" + packages[pkg]["Filename"]],
|
||||
"sha256": packages[pkg]["SHA256"],
|
||||
"strip_prefix": STRIP_PREFIX,
|
||||
}
|
||||
def pkg_kwargs(pkg):
|
||||
return dict(
|
||||
name = pkg.Package,
|
||||
urls = [BASE_URL + "/" + pkg.Filename],
|
||||
sha256 = pkg.SHA256,
|
||||
strip_prefix = STRIP_PREFIX,
|
||||
)
|
||||
|
||||
def _ubuntu_package(path, deb_path, sha256, name, shared_library):
|
||||
return {
|
||||
@ -118,7 +119,10 @@ load("@zml//runtimes/rocm:gfx.bzl", "bytecode_select")
|
||||
cc_import(
|
||||
name = "rocblas",
|
||||
shared_library = "lib/librocblas.so.4",
|
||||
add_needed = ["libzmlrocmhooks.so.0"],
|
||||
add_needed = ["libzmlxrocm.so.0"],
|
||||
rename_dynamic_symbols = {
|
||||
"dlopen": "zmlxrocm_dlopen",
|
||||
},
|
||||
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
||||
)
|
||||
|
||||
@ -147,7 +151,10 @@ load("@zml//bazel:cc_import.bzl", "cc_import")
|
||||
cc_import(
|
||||
name = "hipblaslt",
|
||||
shared_library = "lib/libhipblaslt.so.0",
|
||||
add_needed = ["libzmlrocmhooks.so.0"],
|
||||
add_needed = ["libzmlxrocm.so.0"],
|
||||
rename_dynamic_symbols = {
|
||||
"dlopen": "zmlxrocm_dlopen",
|
||||
},
|
||||
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
||||
)
|
||||
""",
|
||||
@ -193,32 +200,14 @@ filegroup(
|
||||
""",
|
||||
}
|
||||
|
||||
def _packages_to_dict(txt):
|
||||
packages = {}
|
||||
current_pkg = {}
|
||||
for line in txt.splitlines():
|
||||
if line == "":
|
||||
if current_pkg:
|
||||
packages[current_pkg["Package"]] = current_pkg
|
||||
current_pkg = {}
|
||||
continue
|
||||
if line.startswith(" "):
|
||||
current_pkg[key] += line
|
||||
continue
|
||||
split = line.split(": ", 1)
|
||||
key = split[0]
|
||||
value = len(split) > 1 and split[1] or ""
|
||||
current_pkg[key] = value
|
||||
return packages
|
||||
|
||||
def _rocm_impl(mctx):
|
||||
data = mctx.read(Label("@zml//runtimes/rocm:packages.amd64.txt"))
|
||||
PACKAGES = _packages_to_dict(data)
|
||||
all_packages = dpkg.read_packages(mctx, "@zml//runtimes/rocm:packages.amd64.txt")
|
||||
|
||||
for pkg, build_file_content in _PACKAGES.items():
|
||||
for pkg_name, build_file_content in _PACKAGES.items():
|
||||
pkg = all_packages[pkg_name].values()[0]
|
||||
http_deb_archive(
|
||||
build_file_content = build_file_content,
|
||||
**pkg_kwargs(pkg, PACKAGES)
|
||||
**pkg_kwargs(pkg)
|
||||
)
|
||||
|
||||
for repository, kwargs in _UBUNTU_PACKAGES.items():
|
||||
|
||||
@ -1,103 +0,0 @@
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <dlfcn.h>
|
||||
#include <errno.h>
|
||||
#include <fstream>
|
||||
#include <stdlib.h>
|
||||
#include "tools/cpp/runfiles/runfiles.h"
|
||||
|
||||
namespace zml
|
||||
{
|
||||
using bazel::tools::cpp::runfiles::Runfiles;
|
||||
|
||||
std::unique_ptr<Runfiles> runfiles;
|
||||
std::string ROCBLAS_TENSILE_LIBPATH;
|
||||
std::string HIPBLASLT_TENSILE_LIBPATH;
|
||||
std::string HIPBLASLT_EXT_OP_LIBRARY_PATH;
|
||||
std::string ROCM_PATH;
|
||||
|
||||
typedef void *(*dlopen_func)(const char *filename, int flags);
|
||||
dlopen_func dlopen_orig = nullptr;
|
||||
|
||||
__attribute__((constructor)) static void setup(int argc, char **argv)
|
||||
{
|
||||
runfiles = std::unique_ptr<Runfiles>(Runfiles::Create(argv[0], BAZEL_CURRENT_REPOSITORY));
|
||||
|
||||
HIPBLASLT_EXT_OP_LIBRARY_PATH = runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library/hipblasltExtOpLibrary.dat");
|
||||
if (HIPBLASLT_EXT_OP_LIBRARY_PATH != "")
|
||||
{
|
||||
setenv("HIPBLASLT_EXT_OP_LIBRARY_PATH", HIPBLASLT_EXT_OP_LIBRARY_PATH.c_str(), 1);
|
||||
}
|
||||
|
||||
HIPBLASLT_TENSILE_LIBPATH = runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library");
|
||||
if (HIPBLASLT_TENSILE_LIBPATH != "")
|
||||
{
|
||||
setenv("HIPBLASLT_TENSILE_LIBPATH", HIPBLASLT_TENSILE_LIBPATH.c_str(), 1);
|
||||
}
|
||||
|
||||
ROCBLAS_TENSILE_LIBPATH = runfiles->Rlocation("rocblas/lib/rocblas/library");
|
||||
setenv("ROCBLAS_TENSILE_LIBPATH", ROCBLAS_TENSILE_LIBPATH.c_str(), 1);
|
||||
|
||||
ROCM_PATH = runfiles->Rlocation("libpjrt_rocm/sandbox");
|
||||
setenv("ROCM_PATH", ROCM_PATH.c_str(), 1);
|
||||
}
|
||||
|
||||
static void *rocm_dlopen(const char *filename, int flags)
|
||||
{
|
||||
if (filename != NULL)
|
||||
{
|
||||
char *replacements[] = {
|
||||
"librocm-core.so",
|
||||
"librocm-core.so.1",
|
||||
"librocm_smi64.so",
|
||||
"librocm_smi64.so.7",
|
||||
"libhsa-runtime64.so",
|
||||
"libhsa-runtime64.so.1",
|
||||
"libhsa-amd-aqlprofile64.so",
|
||||
"libhsa-amd-aqlprofile64.so.1",
|
||||
"libamd_comgr.so",
|
||||
"libamd_comgr.so.2",
|
||||
"librocprofiler-register.so",
|
||||
"librocprofiler-register.so.0",
|
||||
"libMIOpen.so",
|
||||
"libMIOpen.so.1",
|
||||
"librccl.so",
|
||||
"librccl.so.1",
|
||||
"librocblas.so",
|
||||
"librocblas.so.4",
|
||||
"libroctracer64.so",
|
||||
"libroctracer64.so.4",
|
||||
"libroctx64.so",
|
||||
"libroctx64.so.4",
|
||||
"libhipblaslt.so",
|
||||
"libhipblaslt.so.0",
|
||||
"libamdhip64.so",
|
||||
"libamdhip64.so.6",
|
||||
"libhiprtc.so",
|
||||
"libhiprtc.so.6",
|
||||
NULL,
|
||||
NULL,
|
||||
};
|
||||
for (int i = 0; replacements[i] != NULL; i += 2)
|
||||
{
|
||||
if (strcmp(filename, replacements[i]) == 0)
|
||||
{
|
||||
filename = replacements[i + 1];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return dlopen_orig(filename, flags);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C"
|
||||
{
|
||||
zml::dlopen_func _zml_rocm_resolve_dlopen()
|
||||
{
|
||||
zml::dlopen_orig = (zml::dlopen_func)dlsym(RTLD_NEXT, "dlopen");
|
||||
return zml::rocm_dlopen;
|
||||
}
|
||||
|
||||
extern void *dlopen(const char *filename, int flags) __attribute__((ifunc("_zml_rocm_resolve_dlopen")));
|
||||
}
|
||||
82
runtimes/rocm/zmlxrocm.cc
Normal file
82
runtimes/rocm/zmlxrocm.cc
Normal file
@ -0,0 +1,82 @@
|
||||
#include <dlfcn.h>
|
||||
#include <errno.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "tools/cpp/runfiles/runfiles.h"
|
||||
|
||||
__attribute__((constructor)) static void setup_runfiles(int argc, char **argv)
|
||||
{
|
||||
using bazel::tools::cpp::runfiles::Runfiles;
|
||||
auto runfiles = std::unique_ptr<Runfiles>(Runfiles::Create(argv[0], BAZEL_CURRENT_REPOSITORY));
|
||||
|
||||
auto HIPBLASLT_EXT_OP_LIBRARY_PATH =
|
||||
runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library/hipblasltExtOpLibrary.dat");
|
||||
if (HIPBLASLT_EXT_OP_LIBRARY_PATH != "")
|
||||
{
|
||||
setenv("HIPBLASLT_EXT_OP_LIBRARY_PATH", HIPBLASLT_EXT_OP_LIBRARY_PATH.c_str(), 1);
|
||||
}
|
||||
|
||||
auto HIPBLASLT_TENSILE_LIBPATH = runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library");
|
||||
if (HIPBLASLT_TENSILE_LIBPATH != "")
|
||||
{
|
||||
setenv("HIPBLASLT_TENSILE_LIBPATH", HIPBLASLT_TENSILE_LIBPATH.c_str(), 1);
|
||||
}
|
||||
|
||||
auto ROCBLAS_TENSILE_LIBPATH = runfiles->Rlocation("rocblas/lib/rocblas/library");
|
||||
setenv("ROCBLAS_TENSILE_LIBPATH", ROCBLAS_TENSILE_LIBPATH.c_str(), 1);
|
||||
|
||||
auto ROCM_PATH = runfiles->Rlocation("libpjrt_rocm/sandbox");
|
||||
setenv("ROCM_PATH", ROCM_PATH.c_str(), 1);
|
||||
}
|
||||
|
||||
extern "C" void *zmlxrocm_dlopen(const char *filename, int flags)
|
||||
{
|
||||
if (filename != NULL)
|
||||
{
|
||||
char *replacements[] = {
|
||||
"librocm-core.so",
|
||||
"librocm-core.so.1",
|
||||
"librocm_smi64.so",
|
||||
"librocm_smi64.so.7",
|
||||
"libhsa-runtime64.so",
|
||||
"libhsa-runtime64.so.1",
|
||||
"libhsa-amd-aqlprofile64.so",
|
||||
"libhsa-amd-aqlprofile64.so.1",
|
||||
"libamd_comgr.so",
|
||||
"libamd_comgr.so.2",
|
||||
"librocprofiler-register.so",
|
||||
"librocprofiler-register.so.0",
|
||||
"libMIOpen.so",
|
||||
"libMIOpen.so.1",
|
||||
"librccl.so",
|
||||
"librccl.so.1",
|
||||
"librocblas.so",
|
||||
"librocblas.so.4",
|
||||
"libroctracer64.so",
|
||||
"libroctracer64.so.4",
|
||||
"libroctx64.so",
|
||||
"libroctx64.so.4",
|
||||
"libhipblaslt.so",
|
||||
"libhipblaslt.so.0",
|
||||
"libamdhip64.so",
|
||||
"libamdhip64.so.6",
|
||||
"libhiprtc.so",
|
||||
"libhiprtc.so.6",
|
||||
NULL,
|
||||
NULL,
|
||||
};
|
||||
for (int i = 0; replacements[i] != NULL; i += 2)
|
||||
{
|
||||
if (strcmp(filename, replacements[i]) == 0)
|
||||
{
|
||||
filename = replacements[i + 1];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return dlopen(filename, flags);
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user