runtimes/rocm: sandbox ROCm dependencies and ensure they load on the main thread due to TLS usage in static C++ destructors.

This commit is contained in:
Tarry Singh 2025-04-14 16:38:15 +00:00
parent eba0e72532
commit 7d9fdf94e7
12 changed files with 4155 additions and 514 deletions

View File

@ -79,9 +79,6 @@ pip.parse(
) )
use_repo(pip, "neuron_py_deps") use_repo(pip, "neuron_py_deps")
common_apt_packages = use_extension("//runtimes/common:packages.bzl", "common_apt_packages")
use_repo(common_apt_packages, "libdrm-amdgpu1", "libdrm2", "libelf1", "libnuma1", "libtinfo6", "libzstd1", "zlib1g")
cpu = use_extension("//runtimes/cpu:cpu.bzl", "cpu_pjrt_plugin") cpu = use_extension("//runtimes/cpu:cpu.bzl", "cpu_pjrt_plugin")
use_repo(cpu, "libpjrt_cpu_darwin_amd64", "libpjrt_cpu_darwin_arm64", "libpjrt_cpu_linux_amd64") use_repo(cpu, "libpjrt_cpu_darwin_amd64", "libpjrt_cpu_darwin_arm64", "libpjrt_cpu_linux_amd64")
@ -89,10 +86,7 @@ cuda = use_extension("//runtimes/cuda:cuda.bzl", "cuda_packages")
use_repo(cuda, "libpjrt_cuda") use_repo(cuda, "libpjrt_cuda")
rocm = use_extension("//runtimes/rocm:rocm.bzl", "rocm_packages") rocm = use_extension("//runtimes/rocm:rocm.bzl", "rocm_packages")
use_repo(rocm, "libpjrt_rocm")
inject_repo(rocm, "libdrm-amdgpu1", "libdrm2", "libelf1", "libnuma1", "libtinfo6", "libzstd1", "zlib1g")
use_repo(rocm, "hipblaslt", "libpjrt_rocm", "rocblas")
tpu = use_extension("//runtimes/tpu:tpu.bzl", "tpu_packages") tpu = use_extension("//runtimes/tpu:tpu.bzl", "tpu_packages")
use_repo(tpu, "libpjrt_tpu") use_repo(tpu, "libpjrt_tpu")
@ -150,12 +144,6 @@ non_module_deps = use_extension("//:third_party/non_module_deps.bzl", "non_modul
use_repo(non_module_deps, "com_google_sentencepiece", "org_swig_swig") use_repo(non_module_deps, "com_google_sentencepiece", "org_swig_swig")
apt = use_extension("@rules_distroless//apt:extensions.bzl", "apt") apt = use_extension("@rules_distroless//apt:extensions.bzl", "apt")
apt.install(
name = "apt_common",
lock = "//runtimes/common:packages.lock.json",
manifest = "//runtimes/common:packages.yaml",
)
use_repo(apt, "apt_common")
apt.install( apt.install(
name = "apt_cuda", name = "apt_cuda",
lock = "//runtimes/cuda:packages.lock.json", lock = "//runtimes/cuda:packages.lock.json",

View File

@ -1,6 +0,0 @@
load("@rules_zig//zig:defs.bzl", "zig_library")
exports_files(
["packages.lock.json"],
visibility = ["//runtimes:__subpackages__"],
)

View File

@ -27,6 +27,9 @@ cc_import(
def _filegroup(**kwargs): def _filegroup(**kwargs):
return """filegroup({})""".format(_kwargs(**kwargs)) return """filegroup({})""".format(_kwargs(**kwargs))
# Renders the text of a `patchelf(...)` rule call from the given keyword
# arguments. The returned string is meant to be embedded in generated BUILD
# file content; the generated file must also load the `patchelf` rule itself.
def _patchelf(**kwargs):
return """patchelf({})""".format(_kwargs(**kwargs))
def _load(bzl, name): def _load(bzl, name):
return """load({}, {})""".format(repr(bzl), repr(name)) return """load({}, {})""".format(repr(bzl), repr(name))
@ -40,36 +43,6 @@ def _read(mctx, labels):
}) })
return ret return ret
# Maps a Debian package name to the BUILD file content generated for that
# package's repository: a single cc_import target, named after the package,
# wrapping the shared library at the given path inside the extracted .deb.
_DEBIAN_PACKAGES = {
"libdrm2": _cc_import(name = "libdrm2", shared_library = "usr/lib/x86_64-linux-gnu/libdrm.so.2"),
"libelf1": _cc_import(name = "libelf1", shared_library = "usr/lib/x86_64-linux-gnu/libelf.so.1"),
"libnuma1": _cc_import(name = "libnuma1", shared_library = "usr/lib/x86_64-linux-gnu/libnuma.so.1"),
"libzstd1": _cc_import(name = "libzstd1", shared_library = "usr/lib/x86_64-linux-gnu/libzstd.so.1"),
"libdrm-amdgpu1": _cc_import(name = "libdrm-amdgpu1", shared_library = "usr/lib/x86_64-linux-gnu/libdrm_amdgpu.so.1"),
"libtinfo6": _cc_import(name = "libtinfo6", shared_library = "lib/x86_64-linux-gnu/libtinfo.so.6"),
"zlib1g": _cc_import(name = "zlib1g", shared_library = "lib/x86_64-linux-gnu/libz.so.1"),
}
# Module extension implementation: declares one `http_deb_archive` repository
# per entry of _DEBIAN_PACKAGES.
#
# The download URLs and sha256 of each package are looked up by package name in
# the data read from packages.lock.json; a package missing from the lock file
# fails the lookup. The BUILD file of each repository is the default-visibility
# preamble followed by the content registered in _DEBIAN_PACKAGES.
#
# Returns extension metadata marking the extension as reproducible and reporting
# all generated repositories as direct dependencies of the root module.
def _common_apt_packages_impl(mctx):
loaded_packages = packages.read(mctx, ["packages.lock.json"])
for pkg_name, build_file_content in _DEBIAN_PACKAGES.items():
pkg = loaded_packages[pkg_name]
http_deb_archive(
name = pkg_name,
urls = pkg["urls"],
sha256 = pkg["sha256"],
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + build_file_content,
)
return mctx.extension_metadata(
reproducible = True,
root_module_direct_deps = "all",
root_module_direct_dev_deps = [],
)
# Module extension entry point; its implementation is _common_apt_packages_impl,
# which creates one repository per package listed in _DEBIAN_PACKAGES.
common_apt_packages = module_extension(
implementation = _common_apt_packages_impl,
)
packages = struct( packages = struct(
read = _read, read = _read,
cc_import = _cc_import, cc_import = _cc_import,
@ -77,4 +50,5 @@ packages = struct(
cc_library = _cc_library, cc_library = _cc_library,
filegroup = _filegroup, filegroup = _filegroup,
load_ = _load, load_ = _load,
patchelf = _patchelf,
) )

View File

@ -1,258 +0,0 @@
{
"packages": [
{
"arch": "amd64",
"dependencies": [
{
"key": "libc6_2.36-9-p-deb12u10_amd64",
"name": "libc6",
"version": "2.36-9+deb12u10"
},
{
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
"name": "libgcc-s1",
"version": "12.2.0-14+deb12u1"
},
{
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
"name": "gcc-12-base",
"version": "12.2.0-14+deb12u1"
},
{
"key": "libdrm-common_2.4.114-1_amd64",
"name": "libdrm-common",
"version": "2.4.114-1"
}
],
"key": "libdrm2_2.4.114-1-p-b1_amd64",
"name": "libdrm2",
"sha256": "be18fb670797ba32da9628cf3e8acd83160d8db8c8dd842501dd8e401c3b5371",
"urls": [
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/libd/libdrm/libdrm2_2.4.114-1+b1_amd64.deb"
],
"version": "2.4.114-1+b1"
},
{
"arch": "amd64",
"dependencies": [],
"key": "libc6_2.36-9-p-deb12u10_amd64",
"name": "libc6",
"sha256": "5dc83256f10ca4d0f2a53dd6583ffd0d0e319af30074ea6c82fb0ae77bd16365",
"urls": [
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/g/glibc/libc6_2.36-9+deb12u10_amd64.deb"
],
"version": "2.36-9+deb12u10"
},
{
"arch": "amd64",
"dependencies": [],
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
"name": "libgcc-s1",
"sha256": "3016e62cb4b7cd8038822870601f5ed131befe942774d0f745622cc77d8a88f7",
"urls": [
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/g/gcc-12/libgcc-s1_12.2.0-14+deb12u1_amd64.deb"
],
"version": "12.2.0-14+deb12u1"
},
{
"arch": "amd64",
"dependencies": [],
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
"name": "gcc-12-base",
"sha256": "1896a2aacf4ad681ff5eacc24a5b0ca4d5d9c9b9c9e4b6de5197bc1e116ea619",
"urls": [
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/g/gcc-12/gcc-12-base_12.2.0-14+deb12u1_amd64.deb"
],
"version": "12.2.0-14+deb12u1"
},
{
"arch": "amd64",
"dependencies": [],
"key": "libdrm-common_2.4.114-1_amd64",
"name": "libdrm-common",
"sha256": "32f9664138b38b224383c6986457d5ad2ec8efd559b1a0ce7749405f7a451aad",
"urls": [
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/libd/libdrm/libdrm-common_2.4.114-1_all.deb"
],
"version": "2.4.114-1"
},
{
"arch": "amd64",
"dependencies": [
{
"key": "zlib1g_1-1.2.13.dfsg-1_amd64",
"name": "zlib1g",
"version": "1:1.2.13.dfsg-1"
},
{
"key": "libc6_2.36-9-p-deb12u10_amd64",
"name": "libc6",
"version": "2.36-9+deb12u10"
},
{
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
"name": "libgcc-s1",
"version": "12.2.0-14+deb12u1"
},
{
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
"name": "gcc-12-base",
"version": "12.2.0-14+deb12u1"
}
],
"key": "libelf1_0.188-2.1_amd64",
"name": "libelf1",
"sha256": "619add379c606b3ac6c1a175853b918e6939598a83d8ebadf3bdfd50d10b3c8c",
"urls": [
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/e/elfutils/libelf1_0.188-2.1_amd64.deb"
],
"version": "0.188-2.1"
},
{
"arch": "amd64",
"dependencies": [
{
"key": "libc6_2.36-9-p-deb12u10_amd64",
"name": "libc6",
"version": "2.36-9+deb12u10"
},
{
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
"name": "libgcc-s1",
"version": "12.2.0-14+deb12u1"
},
{
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
"name": "gcc-12-base",
"version": "12.2.0-14+deb12u1"
}
],
"key": "zlib1g_1-1.2.13.dfsg-1_amd64",
"name": "zlib1g",
"sha256": "d7dd1d1411fedf27f5e27650a6eff20ef294077b568f4c8c5e51466dc7c08ce4",
"urls": [
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/z/zlib/zlib1g_1.2.13.dfsg-1_amd64.deb"
],
"version": "1:1.2.13.dfsg-1"
},
{
"arch": "amd64",
"dependencies": [
{
"key": "libc6_2.36-9-p-deb12u10_amd64",
"name": "libc6",
"version": "2.36-9+deb12u10"
},
{
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
"name": "libgcc-s1",
"version": "12.2.0-14+deb12u1"
},
{
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
"name": "gcc-12-base",
"version": "12.2.0-14+deb12u1"
}
],
"key": "libnuma1_2.0.16-1_amd64",
"name": "libnuma1",
"sha256": "639e1ab6bd66ead40db8a22c332d7199679fa22db261cac34444eb8eb4c17dda",
"urls": [
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/n/numactl/libnuma1_2.0.16-1_amd64.deb"
],
"version": "2.0.16-1"
},
{
"arch": "amd64",
"dependencies": [
{
"key": "libc6_2.36-9-p-deb12u10_amd64",
"name": "libc6",
"version": "2.36-9+deb12u10"
},
{
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
"name": "libgcc-s1",
"version": "12.2.0-14+deb12u1"
},
{
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
"name": "gcc-12-base",
"version": "12.2.0-14+deb12u1"
}
],
"key": "libzstd1_1.5.4-p-dfsg2-5_amd64",
"name": "libzstd1",
"sha256": "6315b5ac38b724a710fb96bf1042019398cb656718b1522279a5185ed39318fa",
"urls": [
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/libz/libzstd/libzstd1_1.5.4+dfsg2-5_amd64.deb"
],
"version": "1.5.4+dfsg2-5"
},
{
"arch": "amd64",
"dependencies": [
{
"key": "libdrm2_2.4.114-1-p-b1_amd64",
"name": "libdrm2",
"version": "2.4.114-1+b1"
},
{
"key": "libc6_2.36-9-p-deb12u10_amd64",
"name": "libc6",
"version": "2.36-9+deb12u10"
},
{
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
"name": "libgcc-s1",
"version": "12.2.0-14+deb12u1"
},
{
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
"name": "gcc-12-base",
"version": "12.2.0-14+deb12u1"
},
{
"key": "libdrm-common_2.4.114-1_amd64",
"name": "libdrm-common",
"version": "2.4.114-1"
}
],
"key": "libdrm-amdgpu1_2.4.114-1-p-b1_amd64",
"name": "libdrm-amdgpu1",
"sha256": "b75a71e96f1faac0f131ac657e09efcbe8968eef62cc34b8abfcff2ff9f0cccd",
"urls": [
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/libd/libdrm/libdrm-amdgpu1_2.4.114-1+b1_amd64.deb"
],
"version": "2.4.114-1+b1"
},
{
"arch": "amd64",
"dependencies": [
{
"key": "libc6_2.36-9-p-deb12u10_amd64",
"name": "libc6",
"version": "2.36-9+deb12u10"
},
{
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
"name": "libgcc-s1",
"version": "12.2.0-14+deb12u1"
},
{
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
"name": "gcc-12-base",
"version": "12.2.0-14+deb12u1"
}
],
"key": "libtinfo6_6.4-4_amd64",
"name": "libtinfo6",
"sha256": "072d908f38f51090ca28ca5afa3b46b2957dc61fe35094c0b851426859a49a51",
"urls": [
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/n/ncurses/libtinfo6_6.4-4_amd64.deb"
],
"version": "6.4-4"
}
],
"version": 1
}

View File

@ -1,20 +0,0 @@
#
# bazel run @rocm_apt//:lock
#
version: 1
sources:
- channel: bookworm main
url: https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/
archs:
- "amd64"
packages:
- "libdrm2"
- "libelf1"
- "libnuma1"
- "libzstd1"
- "libdrm-amdgpu1"
- "libtinfo6"
- "zlib1g"

View File

@ -8,11 +8,6 @@ cc_shared_library(
deps = ["@zml//runtimes/cuda:zmlxcuda_lib"], deps = ["@zml//runtimes/cuda:zmlxcuda_lib"],
) )
cc_import(
name = "zmlxcuda",
shared_library = ":zmlxcuda_so",
)
patchelf( patchelf(
name = "libpjrt_cuda.patchelf", name = "libpjrt_cuda.patchelf",
shared_library = "libpjrt_cuda.so", shared_library = "libpjrt_cuda.so",

View File

@ -7,17 +7,6 @@ cc_library(
"-lc", "-lc",
"-ldl", "-ldl",
], ],
)
cc_shared_library(
name = "zmlxrocm_so",
shared_lib_name = "libzmlxrocm.so.0",
deps = [":zmlxrocm_lib"],
)
cc_import(
name = "zmlxrocm",
shared_library = ":zmlxrocm_so",
visibility = ["@libpjrt_rocm//:__subpackages__"], visibility = ["@libpjrt_rocm//:__subpackages__"],
) )

View File

@ -1,5 +1,6 @@
load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory") load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory")
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "string_list_flag") load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "string_list_flag")
load("@zml//bazel:patchelf.bzl", "patchelf")
load("@zml//bazel:cc_import.bzl", "cc_import") load("@zml//bazel:cc_import.bzl", "cc_import")
string_list_flag( string_list_flag(
@ -19,65 +20,58 @@ config_setting(
flag_values = {":hipblaslt": "True"}, flag_values = {":hipblaslt": "True"},
) )
copy_to_directory(
name = "sandbox",
srcs = [
"@rocm-device-libs//:runfiles",
"@rocm-llvm//:lld",
],
include_external_repositories = ["*"],
)
cc_library(
name = "zmlxrocm_lib",
data = ["@rocblas//:runfiles"],
srcs = ["@zml//runtimes/rocm:zmlxrocm_srcs"],
linkopts = [
"-lc",
"-ldl",
],
)
cc_shared_library( cc_shared_library(
name = "zmlxrocm_so", name = "zmlxrocm_so",
shared_lib_name = "libzmlxrocm.so.0", shared_lib_name = "lib/libzmlxrocm.so.0",
deps = [":zmlxrocm_lib"], deps = ["@zml//runtimes/rocm:zmlxrocm_lib"],
) )
cc_import( patchelf(
name = "zmlxrocm", name = "libpjrt_rocm.patchelf",
shared_library = ":zmlxrocm_so", shared_library = "libpjrt_rocm.so",
) add_needed = [
"libzmlxrocm.so.0",
cc_import( # So that RPATH is taken into account.
name = "libpjrt_rocm", "librocblas.so.4",
data = [ "libMIOpen.so.1",
":sandbox",
"@rocblas//:runfiles",
] + select({ ] + select({
":_hipblaslt": ["@hipblaslt//:runfiles"], "_hipblaslt": [
"libhipblaslt.so.0",
],
"//conditions:default": [], "//conditions:default": [],
}), }),
add_needed = ["libzmlxrocm.so.0"],
rename_dynamic_symbols = { rename_dynamic_symbols = {
"dlopen": "zmlxrocm_dlopen", "dlopen": "zmlxrocm_dlopen",
}, },
shared_library = "libpjrt_rocm.so", set_rpath = "$ORIGIN",
soname = "libpjrt_rocm.so", )
visibility = ["//visibility:public"],
deps = [ copy_to_directory(
name = "sandbox",
srcs = [
":zmlxrocm_so",
":libpjrt_rocm.patchelf",
"@comgr//:amd_comgr", "@comgr//:amd_comgr",
"@hip-runtime-amd//:amdhip", "@hip-runtime-amd//:amdhip",
"@hipblaslt", "@hip-runtime-amd//:hiprtc",
"@hipblaslt//:hipblaslt",
"@hipfft",
"@hipsolver",
"@hsa-amd-aqlprofile//:hsa-amd-aqlprofile", "@hsa-amd-aqlprofile//:hsa-amd-aqlprofile",
"@hsa-rocr//:hsa-runtime", "@hsa-rocr//:hsa-runtime",
"@miopen-hip//:MIOpen", "@miopen-hip//:MIOpen",
"@rccl", "@rccl",
"@rocblas", "@rocblas//:rocblas",
"@rocblas//:runfiles",
"@rocm-core", "@rocm-core",
"@rocm-device-libs//:runfiles",
"@rocm-llvm//:lld",
"@rocm-smi-lib//:rocm_smi", "@rocm-smi-lib//:rocm_smi",
"@rocprofiler-register", "@rocprofiler-register",
"@rocfft",
"@rocsolver",
"@roctracer", "@roctracer",
"@roctracer//:roctx",
"@libelf1", "@libelf1",
"@libdrm2", "@libdrm2",
"@libnuma1", "@libnuma1",
@ -85,6 +79,35 @@ cc_import(
"@libdrm-amdgpu1", "@libdrm-amdgpu1",
"@libtinfo6", "@libtinfo6",
"@zlib1g", "@zlib1g",
"@zml//runtimes/rocm:zmlxrocm",
], # lld dependencies
"@libxml2",
"@libicu70//:libicuuc70",
"@libicu70//:libicudata70",
"@liblzma5",
] + select({
":_hipblaslt": ["@hipblaslt//:runfiles"],
"//conditions:default": [],
}),
replace_prefixes = {
"libpjrt_rocm.patchelf": "lib",
"lib/x86_64-linux-gnu": "lib",
"usr/lib/x86_64-linux-gnu": "lib",
"libdrm-amdgpu1": "lib",
"libelf1": "lib",
"hipblaslt": "lib",
"rocblas": "lib",
"libxml2": "lib",
"libicuuc70": "lib",
"liblzma5": "lib",
"lld": "llvm/bin",
},
add_directory_to_runfiles = True,
include_external_repositories = ["**"],
)
cc_library(
name = "libpjrt_rocm",
data = [":sandbox"],
visibility = ["@zml//runtimes/rocm:__subpackages__"],
) )

File diff suppressed because it is too large Load Diff

View File

@ -1,27 +1,24 @@
# #
# bazel run @rocm_apt//:lock # bazel run @apt_rocm//:lock
# #
version: 1 version: 1
sources: sources:
- channel: jammy main - channel: jammy main
url: https://repo.radeon.com/rocm/apt/6.3.4/ url: https://repo.radeon.com/rocm/apt/6.3.4/
# - channel: bookworm main - channel: jammy main
# url: https://snapshot-cloudflare.debian.org/archive/debian/20241127T143620Z/ url: https://snapshot.ubuntu.com/ubuntu/20250711T030400Z
- channel: jammy-security main
url: https://snapshot.ubuntu.com/ubuntu/20250711T030400Z
- channel: jammy-updates main
url: https://snapshot.ubuntu.com/ubuntu/20250711T030400Z
archs: archs:
- "amd64" - "amd64"
# readelf -d libpjrt_rocm.so | grep NEEDED
packages: packages:
# - "libdrm2" # - "rocm-smi-lib"
# - "libelf1"
# - "libnuma1"
# - "libzstd1"
# - "libdrm-amdgpu1"
# - "libtinfo6"
# - "zlib1g"
- "rocm-core"
- "rocm-smi-lib"
- "hsa-rocr" - "hsa-rocr"
- "hsa-amd-aqlprofile" - "hsa-amd-aqlprofile"
- "comgr" - "comgr"
@ -31,8 +28,13 @@ packages:
- "rocm-device-libs" - "rocm-device-libs"
- "hip-dev" - "hip-dev"
- "rocblas" - "rocblas"
- "roctracer" - "rocsolver"
- "hipsolver"
- "hipfft"
# - "roctracer"
- "hipblaslt" - "hipblaslt"
- "hipblaslt-dev" # - "hipblaslt-dev"
- "hip-runtime-amd" - "hip-runtime-amd"
- "rocm-llvm" # - "rocm-llvm"
# rocm-llvm > ld.lld missing dependency
- "libxml2"

View File

@ -8,46 +8,66 @@ package(default_visibility = ["//visibility:public"])
_ROCM_STRIP_PREFIX = "opt/rocm-6.3.4" _ROCM_STRIP_PREFIX = "opt/rocm-6.3.4"
# def _kwargs(**kwargs): _UBUNTU_PACKAGES = {
# return repr(struct(**kwargs))[len("struct("):-1] "libdrm2": packages.filegroup(name = "libdrm2", srcs = ["usr/lib/x86_64-linux-gnu/libdrm.so.2"]),
"libelf1": "\n".join([
# def packages.cc_import(**kwargs): packages.load_("@zml//bazel:patchelf.bzl", "patchelf"),
# return """cc_import({})""".format(_kwargs(**kwargs)) packages.patchelf(
name = "libelf1",
# def packages.filegroup(**kwargs): shared_library = "usr/lib/x86_64-linux-gnu/libelf.so.1",
# return """filegroup({})""".format(_kwargs(**kwargs)) set_rpath = '$ORIGIN',
),
# def packages.load_(bzl, name): ]),
# return """load({}, {})""".format(repr(bzl), repr(name)) "libnuma1": packages.filegroup(name = "libnuma1", srcs = ["usr/lib/x86_64-linux-gnu/libnuma.so.1"]),
"libzstd1": packages.filegroup(name = "libzstd1", srcs = ["usr/lib/x86_64-linux-gnu/libzstd.so.1"]),
# _UBUNTU_PACKAGES = { "libdrm-amdgpu1": "\n".join([
# "libdrm2": packages.cc_import(name = "libdrm2", shared_library = "usr/lib/x86_64-linux-gnu/libdrm.so.2"), packages.load_("@zml//bazel:patchelf.bzl", "patchelf"),
# "libelf1": packages.cc_import(name = "libelf1", shared_library = "usr/lib/x86_64-linux-gnu/libelf.so.1"), packages.patchelf(
# "libnuma1": packages.cc_import(name = "libnuma1", shared_library = "usr/lib/x86_64-linux-gnu/libnuma.so.1"), name = "libdrm-amdgpu1",
# "libzstd1": packages.cc_import(name = "libzstd1", shared_library = "usr/lib/x86_64-linux-gnu/libzstd.so.1"), shared_library = "usr/lib/x86_64-linux-gnu/libdrm_amdgpu.so.1",
# "libdrm-amdgpu1": packages.cc_import(name = "libdrm-amdgpu1", shared_library = "usr/lib/x86_64-linux-gnu/libdrm_amdgpu.so.1"), set_rpath = '$ORIGIN',
# "libtinfo6": packages.cc_import(name = "libtinfo6", shared_library = "lib/x86_64-linux-gnu/libtinfo.so.6"), ),
# "zlib1g": packages.cc_import(name = "zlib1g", shared_library = "lib/x86_64-linux-gnu/libz.so.1"), ]),
# } "libtinfo6": packages.filegroup(name = "libtinfo6", srcs = ["lib/x86_64-linux-gnu/libtinfo.so.6"]),
"zlib1g": packages.filegroup(name = "zlib1g", srcs = ["lib/x86_64-linux-gnu/libz.so.1"]),
"liblzma5": packages.filegroup(name = "liblzma5", srcs = ["lib/x86_64-linux-gnu/liblzma.so.5"]),
"libxml2": "\n".join([
packages.load_("@zml//bazel:patchelf.bzl", "patchelf"),
packages.patchelf(
name = "libxml2",
shared_library = "usr/lib/x86_64-linux-gnu/libxml2.so.2",
set_rpath = '$ORIGIN',
),
]),
"libicu70": "\n".join([
packages.load_("@zml//bazel:patchelf.bzl", "patchelf"),
packages.patchelf(
name = "libicuuc70",
shared_library = "usr/lib/x86_64-linux-gnu/libicuuc.so.70",
set_rpath = '$ORIGIN',
),
packages.filegroup(name = "libicudata70", srcs = ["usr/lib/x86_64-linux-gnu/libicudata.so.70"])
]),
}
_ROCM_PACKAGES = { _ROCM_PACKAGES = {
"rocm-core": packages.cc_import(name = "rocm-core", shared_library = "lib/librocm-core.so.1"), "rocm-core": packages.filegroup(name = "rocm-core", srcs = ["lib/librocm-core.so.1"]),
"rocm-smi-lib": packages.cc_import(name = "rocm_smi", shared_library = "lib/librocm_smi64.so.7"), "rocm-smi-lib": packages.filegroup(name = "rocm_smi", srcs = ["lib/librocm_smi64.so.7"]),
"hsa-rocr": packages.cc_import(name = "hsa-runtime", shared_library = "lib/libhsa-runtime64.so.1"), "hsa-rocr": packages.filegroup(name = "hsa-runtime", srcs = ["lib/libhsa-runtime64.so.1"]),
"hsa-amd-aqlprofile": packages.cc_import(name = "hsa-amd-aqlprofile", shared_library = "lib/libhsa-amd-aqlprofile64.so.1"), "hsa-amd-aqlprofile": packages.filegroup(name = "hsa-amd-aqlprofile", srcs = ["lib/libhsa-amd-aqlprofile64.so.1"]),
"comgr": packages.cc_import(name = "amd_comgr", shared_library = "lib/libamd_comgr.so.2"), "comgr": packages.filegroup(name = "amd_comgr", srcs = ["lib/libamd_comgr.so.2"]),
"rocprofiler-register": packages.cc_import(name = "rocprofiler-register", shared_library = "lib/librocprofiler-register.so.0"), "rocprofiler-register": packages.filegroup(name = "rocprofiler-register", srcs = ["lib/librocprofiler-register.so.0"]),
"miopen-hip": "\n".join([ "miopen-hip": "\n".join([
packages.cc_import(name = "MIOpen", shared_library = "lib/libMIOpen.so.1"), packages.filegroup(name = "MIOpen", srcs = ["lib/libMIOpen.so.1"]),
"""filegroup(name = "runfiles", srcs = glob(["share/miopen/**"]))""", """filegroup(name = "runfiles", srcs = glob(["share/miopen/**"]))""",
]), ]),
"rccl": packages.cc_import(name = "rccl", shared_library = "lib/librccl.so.1"), "rccl": packages.filegroup(name = "rccl", srcs = ["lib/librccl.so.1"]),
"rocm-device-libs": """filegroup(name = "runfiles", srcs = glob(["amdgcn/**"]))""", "rocm-device-libs": """filegroup(name = "runfiles", srcs = glob(["amdgcn/**"]))""",
"hip-dev": """filegroup(name = "runfiles", srcs = glob(["share/**"]))""", "hip-dev": """filegroup(name = "runfiles", srcs = glob(["share/**"]))""",
"rocblas": "\n".join([ "rocblas": "\n".join([
packages.load_("@zml//bazel:cc_import.bzl", "cc_import"), packages.load_("@zml//bazel:patchelf.bzl", "patchelf"),
packages.load_("@zml//runtimes/rocm:gfx.bzl", "bytecode_select"), packages.load_("@zml//runtimes/rocm:gfx.bzl", "bytecode_select"),
packages.cc_import( packages.patchelf(
name = "rocblas", name = "rocblas",
shared_library = "lib/librocblas.so.4", shared_library = "lib/librocblas.so.4",
add_needed = ["libzmlxrocm.so.0"], add_needed = ["libzmlxrocm.so.0"],
@ -69,14 +89,16 @@ _ROCM_PACKAGES = {
], ],
), ),
]), ]),
"rocfft": packages.filegroup(name = "rocfft", srcs = ["lib/librocfft.so.0"]),
"rocsolver": packages.filegroup(name = "rocsolver", srcs = ["lib/librocsolver.so.0"]),
"roctracer": "\n".join([ "roctracer": "\n".join([
packages.cc_import(name = "roctracer", shared_library = "lib/libroctracer64.so.4", deps = [":roctx"]), packages.filegroup(name = "roctracer", srcs = ["lib/libroctracer64.so.4"]),
packages.cc_import(name = "roctx", shared_library = "lib/libroctx64.so.4"), packages.filegroup(name = "roctx", srcs = ["lib/libroctx64.so.4"]),
]), ]),
"hipblaslt": "\n".join([ "hipblaslt": "\n".join([
packages.load_("@zml//bazel:cc_import.bzl", "cc_import"), packages.load_("@zml//bazel:patchelf.bzl", "patchelf"),
packages.load_("@zml//runtimes/rocm:gfx.bzl", "bytecode_select"), packages.load_("@zml//runtimes/rocm:gfx.bzl", "bytecode_select"),
packages.cc_import( packages.patchelf(
name = "hipblaslt", name = "hipblaslt",
shared_library = "lib/libhipblaslt.so.0", shared_library = "lib/libhipblaslt.so.0",
add_needed = ["libzmlxrocm.so.0"], add_needed = ["libzmlxrocm.so.0"],
@ -102,11 +124,21 @@ _ROCM_PACKAGES = {
], ],
), ),
]), ]),
"hipfft": packages.filegroup(name = "hipfft", srcs = ["lib/libhipfft.so.0"]),
"hip-runtime-amd": "\n".join([ "hip-runtime-amd": "\n".join([
packages.cc_import(name = "amdhip", shared_library = "lib/libamdhip64.so.6", deps = [":hiprtc"]), packages.filegroup(name = "amdhip", srcs = ["lib/libamdhip64.so.6"]),
packages.cc_import(name = "hiprtc", shared_library = "lib/libhiprtc.so.6"), packages.filegroup(name = "hiprtc", srcs = ["lib/libhiprtc.so.6"]),
]),
"hipsolver": packages.filegroup(name = "hipsolver", srcs = ["lib/libhipsolver.so.0"]),
"rocm-llvm": "\n".join([
packages.load_("@zml//bazel:patchelf.bzl", "patchelf"),
packages.patchelf(
name = "lld",
#TODO: Rename attr to elf_file or file ?
shared_library = "llvm/bin/ld.lld",
set_rpath = '$ORIGIN/../../lib',
),
]), ]),
"rocm-llvm": packages.filegroup(name = "lld", srcs = ["llvm/bin/ld.lld"], visibility = ["//visibility:public"]),
} }
def _rocm_impl(mctx): def _rocm_impl(mctx):
@ -114,6 +146,15 @@ def _rocm_impl(mctx):
"@zml//runtimes/rocm:packages.lock.json", "@zml//runtimes/rocm:packages.lock.json",
]) ])
for pkg_name, build_file_content in _UBUNTU_PACKAGES.items():
pkg = loaded_packages[pkg_name]
http_deb_archive(
name = pkg_name,
urls = pkg["urls"],
sha256 = pkg["sha256"],
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + build_file_content,
)
for pkg_name, build_file_content in _ROCM_PACKAGES.items(): for pkg_name, build_file_content in _ROCM_PACKAGES.items():
pkg = loaded_packages[pkg_name] pkg = loaded_packages[pkg_name]
http_deb_archive( http_deb_archive(

View File

@ -8,6 +8,8 @@ const pjrt = @import("pjrt");
const runfiles = @import("runfiles"); const runfiles = @import("runfiles");
const stdx = @import("stdx"); const stdx = @import("stdx");
const log = std.log.scoped(.@"zml/runtime/rocm");
const ROCmEnvEntry = struct { const ROCmEnvEntry = struct {
name: [:0]const u8, name: [:0]const u8,
rpath: []const u8, rpath: []const u8,
@ -16,10 +18,10 @@ const ROCmEnvEntry = struct {
}; };
const rocm_env_entries: []const ROCmEnvEntry = &.{ const rocm_env_entries: []const ROCmEnvEntry = &.{
.{ .name = "HIPBLASLT_EXT_OP_LIBRARY_PATH", .rpath = "hipblaslt/lib/hipblaslt/library/hipblasltExtOpLibrary.dat", .dirname = false, .mandatory = false }, .{ .name = "HIPBLASLT_EXT_OP_LIBRARY_PATH", .rpath = "/lib/hipblaslt/library/hipblasltExtOpLibrary.dat", .dirname = false, .mandatory = false },
.{ .name = "HIPBLASLT_TENSILE_LIBPATH", .rpath = "hipblaslt/lib/hipblaslt/library/TensileManifest.txt", .dirname = true, .mandatory = false }, .{ .name = "HIPBLASLT_TENSILE_LIBPATH", .rpath = "/lib/hipblaslt/library/TensileManifest.txt", .dirname = true, .mandatory = false },
.{ .name = "ROCBLAS_TENSILE_LIBPATH", .rpath = "rocblas/lib/rocblas/library/TensileManifest.txt", .dirname = true, .mandatory = true }, .{ .name = "ROCBLAS_TENSILE_LIBPATH", .rpath = "/lib/rocblas/library/TensileManifest.txt", .dirname = true, .mandatory = true },
.{ .name = "ROCM_PATH", .rpath = "libpjrt_rocm/sandbox", .dirname = false, .mandatory = true }, .{ .name = "ROCM_PATH", .rpath = "/", .dirname = false, .mandatory = true },
}; };
pub fn isEnabled() bool { pub fn isEnabled() bool {
@ -33,20 +35,9 @@ fn hasRocmDevices() bool {
return true; return true;
} }
fn setupRocmEnv() !void { fn setupRocmEnv(allocator: std.mem.Allocator, rocm_data_dir: []const u8) !void {
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
defer arena.deinit();
const r = blk: {
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
stdx.debug.panic("Unable to find Runfiles directory", .{});
};
const source_repo = bazel_builtin.current_repository;
break :blk r_.withSourceRepo(source_repo);
};
for (rocm_env_entries) |entry| { for (rocm_env_entries) |entry| {
var real_path = r.rlocationAlloc(arena.allocator(), entry.rpath) catch null orelse { var real_path: []const u8 = std.fmt.allocPrintZ(allocator, "{s}/{s}", .{ rocm_data_dir, entry.rpath }) catch null orelse {
if (entry.mandatory) { if (entry.mandatory) {
stdx.debug.panic("Unable to find {s} in {s}\n", .{ entry.name, bazel_builtin.current_repository }); stdx.debug.panic("Unable to find {s} in {s}\n", .{ entry.name, bazel_builtin.current_repository });
} }
@ -59,7 +50,7 @@ fn setupRocmEnv() !void {
}; };
} }
_ = c.setenv(entry.name, try arena.allocator().dupeZ(u8, real_path), 1); _ = c.setenv(entry.name, try allocator.dupeZ(u8, real_path), 1);
} }
} }
@ -74,7 +65,35 @@ pub fn load() !*const pjrt.Api {
return error.Unavailable; return error.Unavailable;
} }
try setupRocmEnv(); var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
defer arena.deinit();
return try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_rocm.so"}); var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
stdx.debug.panic("Unable to find runfiles", .{});
};
const source_repo = bazel_builtin.current_repository;
const r = r_.withSourceRepo(source_repo);
var path_buf: [std.fs.max_path_bytes]u8 = undefined;
const sandbox_path = try r.rlocation("libpjrt_rocm/sandbox", &path_buf) orelse {
log.err("Failed to find sandbox path for ROCm runtime", .{});
return error.FileNotFound;
};
try setupRocmEnv(arena.allocator(), sandbox_path);
var lib_path_buf: [std.fs.max_path_bytes]u8 = undefined;
const lib_path = try stdx.fs.path.bufJoinZ(&lib_path_buf, &.{ sandbox_path, "lib", "libpjrt_rocm.so" });
// We must load the PJRT plugin from the main thread.
//
// This is because libamdhip64.so uses thread-local storage (TLS) in one of its
// static destructors.
//
// If that destructor is executed on a different thread than the one that
// originally called dlopen() on the library, the TLS offset may be resolved
// relative to the TLS base of the main thread rather than that of the thread
// actually executing the destructor, and accessing the variable then results
// in a segmentation fault.
return try pjrt.Api.loadFrom(lib_path);
} }