Switch ROCm dlopen handling to patchelf's rename_dynamic_symbols for more robust dynamic symbol import.

This commit is contained in:
Tarry Singh 2023-05-03 17:33:46 +00:00
parent fefd84b1bb
commit cfe38f27ca
8 changed files with 174 additions and 172 deletions

View File

@ -36,21 +36,22 @@ _cc_import_runfiles = rule(
) )
def cc_import( def cc_import(
name, name,
static_library = None, static_library = None,
pic_static_library = None, pic_static_library = None,
shared_library = None, shared_library = None,
interface_library = None, interface_library = None,
data = None, data = None,
deps = None, deps = None,
visibility = None, visibility = None,
soname = None, soname = None,
add_needed = None, add_needed = None,
remove_needed = None, remove_needed = None,
replace_needed = None, replace_needed = None,
**kwargs): rename_dynamic_symbols = None,
if shared_library and (soname or add_needed or remove_needed or replace_needed): **kwargs):
patched_name = "{}_patchelf".format(name) if shared_library and (soname or add_needed or remove_needed or replace_needed or rename_dynamic_symbols):
patched_name = "{}.patchelf".format(name)
patchelf( patchelf(
name = patched_name, name = patched_name,
shared_library = shared_library, shared_library = shared_library,
@ -58,11 +59,12 @@ def cc_import(
add_needed = add_needed, add_needed = add_needed,
remove_needed = remove_needed, remove_needed = remove_needed,
replace_needed = replace_needed, replace_needed = replace_needed,
rename_dynamic_symbols = rename_dynamic_symbols,
) )
shared_library = ":" + patched_name shared_library = ":" + patched_name
if data: if data:
_cc_import( _cc_import(
name = name + "_no_runfiles", name = name + ".norunfiles",
static_library = static_library, static_library = static_library,
pic_static_library = pic_static_library, pic_static_library = pic_static_library,
shared_library = shared_library, shared_library = shared_library,
@ -71,15 +73,10 @@ def cc_import(
deps = deps, deps = deps,
**kwargs **kwargs
) )
_cc_import_runfiles( native.cc_library(
name = name, name = name,
src = ":{}_no_runfiles".format(name),
static_library = static_library,
pic_static_library = pic_static_library,
shared_library = shared_library,
interface_library = interface_library,
data = data, data = data,
deps = deps, deps = [name + ".norunfiles"],
visibility = visibility, visibility = visibility,
) )
else: else:
@ -89,7 +86,6 @@ def cc_import(
pic_static_library = pic_static_library, pic_static_library = pic_static_library,
shared_library = shared_library, shared_library = shared_library,
interface_library = interface_library, interface_library = interface_library,
data = data,
deps = deps, deps = deps,
visibility = visibility, visibility = visibility,
**kwargs **kwargs

29
bazel/dpkg.bzl Normal file
View File

@ -0,0 +1,29 @@
def _add_package(packages, fields):
    """Files one parsed stanza under packages[<Package>][<Version>].

    Args:
      packages: dict being built, mutated in place.
      fields: dict of field name -> value for one stanza; ignored when empty.
    """
    if not fields:
        return
    pkg_name = fields["Package"]
    pkg_version = fields["Version"]
    packages.setdefault(pkg_name, {})[pkg_version] = struct(**fields)

def _packages_to_dict(txt):
    """Parses the text of a package index into a nested dict.

    The text is read as stanzas of `Key: value` lines separated by empty
    lines. A line starting with a space is appended verbatim to the value of
    the field that precedes it.

    Args:
      txt: full content of the index, as a string.

    Returns:
      A dict {package name: {version: struct(<one attribute per field>)}}.
    """
    packages = {}
    current_pkg = {}
    key = None
    for line in txt.splitlines():
        if line == "":
            # End of stanza.
            _add_package(packages, current_pkg)
            current_pkg = {}
            key = None
            continue
        if line.startswith(" "):
            if key == None:
                fail("dpkg: continuation line without a preceding field: {}".format(repr(line)))
            current_pkg[key] += line
            continue
        split = line.split(": ", 1)
        key = split[0]
        current_pkg[key] = split[1] if len(split) > 1 else ""

    # The last stanza is not necessarily followed by an empty line; without
    # this flush it would be silently dropped.
    _add_package(packages, current_pkg)
    return packages
def _read_packages(mctx, label):
    """Reads the file designated by `label` through `mctx` and parses it.

    Args:
      mctx: context object providing `read`.
      label: label string of the package index file.

    Returns:
      The result of `_packages_to_dict` on the file content.
    """
    return _packages_to_dict(mctx.read(Label(label)))
# Public namespace of this file: exposes the package index reader.
dpkg = struct(read_packages = _read_packages)

View File

@ -1,6 +1,3 @@
def _render_kv(e):
return e
def _patchelf_impl(ctx): def _patchelf_impl(ctx):
output_name = ctx.file.shared_library.basename output_name = ctx.file.shared_library.basename
if ctx.attr.soname: if ctx.attr.soname:
@ -26,8 +23,19 @@ def _patchelf_impl(ctx):
for k, v in ctx.attr.replace_needed.items(): for k, v in ctx.attr.replace_needed.items():
commands.append('"$1" --replace-needed "{}" "{}" "$3"'.format(k, v)) commands.append('"$1" --replace-needed "{}" "{}" "$3"'.format(k, v))
renamed_syms = ctx.actions.declare_file("{}.rename.txt".format(ctx.label.name))
if ctx.attr.rename_dynamic_symbols:
content = "\n".join([
"{} {}".format(k, v)
for k, v in ctx.attr.rename_dynamic_symbols.items()
])
ctx.actions.write(renamed_syms, content)
commands.append('"$1" --rename-dynamic-symbols "{}" "$3"'.format(renamed_syms.path))
else:
ctx.actions.write(renamed_syms, "")
ctx.actions.run_shell( ctx.actions.run_shell(
inputs = [ctx.file.shared_library], inputs = [ctx.file.shared_library, renamed_syms],
outputs = [output], outputs = [output],
arguments = [ctx.executable._patchelf.path, ctx.file.shared_library.path, output.path], arguments = [ctx.executable._patchelf.path, ctx.file.shared_library.path, output.path],
command = "\n".join(commands), command = "\n".join(commands),
@ -48,6 +56,7 @@ patchelf = rule(
"add_needed": attr.string_list(), "add_needed": attr.string_list(),
"remove_needed": attr.string_list(), "remove_needed": attr.string_list(),
"replace_needed": attr.string_dict(), "replace_needed": attr.string_dict(),
"rename_dynamic_symbols": attr.string_dict(),
"_patchelf": attr.label( "_patchelf": attr.label(
default = "@patchelf", default = "@patchelf",
allow_single_file = True, allow_single_file = True,

View File

@ -1,6 +1,6 @@
filegroup( filegroup(
name = "zmlrocmhooks_srcs", name = "zmlxrocm_srcs",
srcs = ["zmlrocmhooks.cc"], srcs = ["zmlxrocm.cc"],
visibility = ["@libpjrt_rocm//:__subpackages__"], visibility = ["@libpjrt_rocm//:__subpackages__"],
) )

View File

@ -31,9 +31,9 @@ copy_to_directory(
) )
cc_library( cc_library(
name = "zmlrocmhooks_lib", name = "zmlxrocm_lib",
data = ["@rocblas//:runfiles"], data = ["@rocblas//:runfiles"],
srcs = ["@zml//runtimes/rocm:zmlrocmhooks_srcs"], srcs = ["@zml//runtimes/rocm:zmlxrocm_srcs"],
linkopts = [ linkopts = [
"-lc", "-lc",
"-ldl", "-ldl",
@ -42,14 +42,14 @@ cc_library(
) )
cc_shared_library( cc_shared_library(
name = "zmlrocmhooks_so", name = "zmlxrocm_so",
shared_lib_name = "libzmlrocmhooks.so.0", shared_lib_name = "libzmlxrocm.so.0",
deps = [":zmlrocmhooks_lib"], deps = [":zmlxrocm_lib"],
) )
cc_import( cc_import(
name = "zmlrocmhooks", name = "zmlxrocm",
shared_library = ":zmlrocmhooks_so", shared_library = ":zmlxrocm_so",
) )
cc_import( cc_import(
@ -64,7 +64,7 @@ cc_import(
shared_library = "libpjrt_rocm.so", shared_library = "libpjrt_rocm.so",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":zmlrocmhooks", ":zmlxrocm",
"@comgr//:amd_comgr", "@comgr//:amd_comgr",
"@hip-runtime-amd//:amdhip", "@hip-runtime-amd//:amdhip",
"@hipblaslt", "@hipblaslt",

View File

@ -1,17 +1,18 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("//bazel:dpkg.bzl", "dpkg")
load("//bazel:http_deb_archive.bzl", "http_deb_archive") load("//bazel:http_deb_archive.bzl", "http_deb_archive")
ROCM_VERSION = "6.2.2" ROCM_VERSION = "6.2.2"
BASE_URL = "https://repo.radeon.com/rocm/apt/{}".format(ROCM_VERSION) BASE_URL = "https://repo.radeon.com/rocm/apt/{}".format(ROCM_VERSION)
STRIP_PREFIX = "opt/rocm-6.2.2" STRIP_PREFIX = "opt/rocm-6.2.2"
def pkg_kwargs(pkg, packages): def pkg_kwargs(pkg):
return { return dict(
"name": pkg, name = pkg.Package,
"urls": [BASE_URL + "/" + packages[pkg]["Filename"]], urls = [BASE_URL + "/" + pkg.Filename],
"sha256": packages[pkg]["SHA256"], sha256 = pkg.SHA256,
"strip_prefix": STRIP_PREFIX, strip_prefix = STRIP_PREFIX,
} )
def _ubuntu_package(path, deb_path, sha256, name, shared_library): def _ubuntu_package(path, deb_path, sha256, name, shared_library):
return { return {
@ -118,7 +119,10 @@ load("@zml//runtimes/rocm:gfx.bzl", "bytecode_select")
cc_import( cc_import(
name = "rocblas", name = "rocblas",
shared_library = "lib/librocblas.so.4", shared_library = "lib/librocblas.so.4",
add_needed = ["libzmlrocmhooks.so.0"], add_needed = ["libzmlxrocm.so.0"],
rename_dynamic_symbols = {
"dlopen": "zmlxrocm_dlopen",
},
visibility = ["@libpjrt_rocm//:__subpackages__"], visibility = ["@libpjrt_rocm//:__subpackages__"],
) )
@ -147,7 +151,10 @@ load("@zml//bazel:cc_import.bzl", "cc_import")
cc_import( cc_import(
name = "hipblaslt", name = "hipblaslt",
shared_library = "lib/libhipblaslt.so.0", shared_library = "lib/libhipblaslt.so.0",
add_needed = ["libzmlrocmhooks.so.0"], add_needed = ["libzmlxrocm.so.0"],
rename_dynamic_symbols = {
"dlopen": "zmlxrocm_dlopen",
},
visibility = ["@libpjrt_rocm//:__subpackages__"], visibility = ["@libpjrt_rocm//:__subpackages__"],
) )
""", """,
@ -193,32 +200,14 @@ filegroup(
""", """,
} }
def _packages_to_dict(txt):
packages = {}
current_pkg = {}
for line in txt.splitlines():
if line == "":
if current_pkg:
packages[current_pkg["Package"]] = current_pkg
current_pkg = {}
continue
if line.startswith(" "):
current_pkg[key] += line
continue
split = line.split(": ", 1)
key = split[0]
value = len(split) > 1 and split[1] or ""
current_pkg[key] = value
return packages
def _rocm_impl(mctx): def _rocm_impl(mctx):
data = mctx.read(Label("@zml//runtimes/rocm:packages.amd64.txt")) all_packages = dpkg.read_packages(mctx, "@zml//runtimes/rocm:packages.amd64.txt")
PACKAGES = _packages_to_dict(data)
for pkg, build_file_content in _PACKAGES.items(): for pkg_name, build_file_content in _PACKAGES.items():
pkg = all_packages[pkg_name].values()[0]
http_deb_archive( http_deb_archive(
build_file_content = build_file_content, build_file_content = build_file_content,
**pkg_kwargs(pkg, PACKAGES) **pkg_kwargs(pkg)
) )
for repository, kwargs in _UBUNTU_PACKAGES.items(): for repository, kwargs in _UBUNTU_PACKAGES.items():

View File

@ -1,103 +0,0 @@
#include <string>
#include <iostream>
#include <dlfcn.h>
#include <errno.h>
#include <fstream>
#include <stdlib.h>
#include "tools/cpp/runfiles/runfiles.h"
namespace zml
{
using bazel::tools::cpp::runfiles::Runfiles;
std::unique_ptr<Runfiles> runfiles;
std::string ROCBLAS_TENSILE_LIBPATH;
std::string HIPBLASLT_TENSILE_LIBPATH;
std::string HIPBLASLT_EXT_OP_LIBRARY_PATH;
std::string ROCM_PATH;
typedef void *(*dlopen_func)(const char *filename, int flags);
dlopen_func dlopen_orig = nullptr;
__attribute__((constructor)) static void setup(int argc, char **argv)
{
runfiles = std::unique_ptr<Runfiles>(Runfiles::Create(argv[0], BAZEL_CURRENT_REPOSITORY));
HIPBLASLT_EXT_OP_LIBRARY_PATH = runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library/hipblasltExtOpLibrary.dat");
if (HIPBLASLT_EXT_OP_LIBRARY_PATH != "")
{
setenv("HIPBLASLT_EXT_OP_LIBRARY_PATH", HIPBLASLT_EXT_OP_LIBRARY_PATH.c_str(), 1);
}
HIPBLASLT_TENSILE_LIBPATH = runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library");
if (HIPBLASLT_TENSILE_LIBPATH != "")
{
setenv("HIPBLASLT_TENSILE_LIBPATH", HIPBLASLT_TENSILE_LIBPATH.c_str(), 1);
}
ROCBLAS_TENSILE_LIBPATH = runfiles->Rlocation("rocblas/lib/rocblas/library");
setenv("ROCBLAS_TENSILE_LIBPATH", ROCBLAS_TENSILE_LIBPATH.c_str(), 1);
ROCM_PATH = runfiles->Rlocation("libpjrt_rocm/sandbox");
setenv("ROCM_PATH", ROCM_PATH.c_str(), 1);
}
static void *rocm_dlopen(const char *filename, int flags)
{
if (filename != NULL)
{
char *replacements[] = {
"librocm-core.so",
"librocm-core.so.1",
"librocm_smi64.so",
"librocm_smi64.so.7",
"libhsa-runtime64.so",
"libhsa-runtime64.so.1",
"libhsa-amd-aqlprofile64.so",
"libhsa-amd-aqlprofile64.so.1",
"libamd_comgr.so",
"libamd_comgr.so.2",
"librocprofiler-register.so",
"librocprofiler-register.so.0",
"libMIOpen.so",
"libMIOpen.so.1",
"librccl.so",
"librccl.so.1",
"librocblas.so",
"librocblas.so.4",
"libroctracer64.so",
"libroctracer64.so.4",
"libroctx64.so",
"libroctx64.so.4",
"libhipblaslt.so",
"libhipblaslt.so.0",
"libamdhip64.so",
"libamdhip64.so.6",
"libhiprtc.so",
"libhiprtc.so.6",
NULL,
NULL,
};
for (int i = 0; replacements[i] != NULL; i += 2)
{
if (strcmp(filename, replacements[i]) == 0)
{
filename = replacements[i + 1];
break;
}
}
}
return dlopen_orig(filename, flags);
}
}
// C-linkage definitions that make this library's `dlopen` symbol an
// indirect function (ifunc) implemented by zml::rocm_dlopen.
extern "C"
{
// ifunc resolver for `dlopen`: stores the next `dlopen` definition in symbol
// search order (dlsym with RTLD_NEXT) in zml::dlopen_orig, then returns
// zml::rocm_dlopen as the function that calls to `dlopen` resolve to.
zml::dlopen_func _zml_rocm_resolve_dlopen()
{
zml::dlopen_orig = (zml::dlopen_func)dlsym(RTLD_NEXT, "dlopen");
return zml::rocm_dlopen;
}
// `dlopen` itself has no body here; the ifunc attribute names the resolver
// above as the source of its implementation.
extern void *dlopen(const char *filename, int flags) __attribute__((ifunc("_zml_rocm_resolve_dlopen")));
}

82
runtimes/rocm/zmlxrocm.cc Normal file
View File

@ -0,0 +1,82 @@
#include <dlfcn.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <fstream>
#include <iostream>
#include <memory>
#include <string>

#include "tools/cpp/runfiles/runfiles.h"
__attribute__((constructor)) static void setup_runfiles(int argc, char **argv)
{
using bazel::tools::cpp::runfiles::Runfiles;
auto runfiles = std::unique_ptr<Runfiles>(Runfiles::Create(argv[0], BAZEL_CURRENT_REPOSITORY));
auto HIPBLASLT_EXT_OP_LIBRARY_PATH =
runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library/hipblasltExtOpLibrary.dat");
if (HIPBLASLT_EXT_OP_LIBRARY_PATH != "")
{
setenv("HIPBLASLT_EXT_OP_LIBRARY_PATH", HIPBLASLT_EXT_OP_LIBRARY_PATH.c_str(), 1);
}
auto HIPBLASLT_TENSILE_LIBPATH = runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library");
if (HIPBLASLT_TENSILE_LIBPATH != "")
{
setenv("HIPBLASLT_TENSILE_LIBPATH", HIPBLASLT_TENSILE_LIBPATH.c_str(), 1);
}
auto ROCBLAS_TENSILE_LIBPATH = runfiles->Rlocation("rocblas/lib/rocblas/library");
setenv("ROCBLAS_TENSILE_LIBPATH", ROCBLAS_TENSILE_LIBPATH.c_str(), 1);
auto ROCM_PATH = runfiles->Rlocation("libpjrt_rocm/sandbox");
setenv("ROCM_PATH", ROCM_PATH.c_str(), 1);
}
namespace
{
    // One rewrite rule: a library name as requested by the caller and the
    // versioned name that is handed to dlopen in its place.
    struct SonameAlias
    {
        const char *requested;
        const char *versioned;
    };

    // Built once at compile time instead of on every call.
    constexpr SonameAlias kSonameAliases[] = {
        {"librocm-core.so", "librocm-core.so.1"},
        {"librocm_smi64.so", "librocm_smi64.so.7"},
        {"libhsa-runtime64.so", "libhsa-runtime64.so.1"},
        {"libhsa-amd-aqlprofile64.so", "libhsa-amd-aqlprofile64.so.1"},
        {"libamd_comgr.so", "libamd_comgr.so.2"},
        {"librocprofiler-register.so", "librocprofiler-register.so.0"},
        {"libMIOpen.so", "libMIOpen.so.1"},
        {"librccl.so", "librccl.so.1"},
        {"librocblas.so", "librocblas.so.4"},
        {"libroctracer64.so", "libroctracer64.so.4"},
        {"libroctx64.so", "libroctx64.so.4"},
        {"libhipblaslt.so", "libhipblaslt.so.0"},
        {"libamdhip64.so", "libamdhip64.so.6"},
        {"libhiprtc.so", "libhiprtc.so.6"},
    };

    // Returns the versioned name for `filename` when it exactly matches a
    // `requested` entry of the table, otherwise `filename` itself.
    // `filename` must not be null.
    const char *zmlxrocm_resolve_soname(const char *filename)
    {
        for (const SonameAlias &alias : kSonameAliases)
        {
            if (strcmp(filename, alias.requested) == 0)
            {
                return alias.versioned;
            }
        }
        return filename;
    }
} // namespace

// dlopen wrapper exported with C linkage.
//
// When `filename` is one of the unversioned ROCm library names in
// kSonameAliases, the matching versioned name is opened instead; every other
// name, and a null `filename`, is forwarded to dlopen unchanged.
//
// Returns whatever dlopen returns for the (possibly rewritten) name.
extern "C" void *zmlxrocm_dlopen(const char *filename, int flags)
{
    if (filename != nullptr)
    {
        filename = zmlxrocm_resolve_soname(filename);
    }
    return dlopen(filename, flags);
}