runtimes/rocm: sandbox ROCm dependencies and ensure they load on the main thread due to TLS usage in static C++ destructors.
This commit is contained in:
parent
eba0e72532
commit
7d9fdf94e7
14
MODULE.bazel
14
MODULE.bazel
@ -79,9 +79,6 @@ pip.parse(
|
|||||||
)
|
)
|
||||||
use_repo(pip, "neuron_py_deps")
|
use_repo(pip, "neuron_py_deps")
|
||||||
|
|
||||||
common_apt_packages = use_extension("//runtimes/common:packages.bzl", "common_apt_packages")
|
|
||||||
use_repo(common_apt_packages, "libdrm-amdgpu1", "libdrm2", "libelf1", "libnuma1", "libtinfo6", "libzstd1", "zlib1g")
|
|
||||||
|
|
||||||
cpu = use_extension("//runtimes/cpu:cpu.bzl", "cpu_pjrt_plugin")
|
cpu = use_extension("//runtimes/cpu:cpu.bzl", "cpu_pjrt_plugin")
|
||||||
use_repo(cpu, "libpjrt_cpu_darwin_amd64", "libpjrt_cpu_darwin_arm64", "libpjrt_cpu_linux_amd64")
|
use_repo(cpu, "libpjrt_cpu_darwin_amd64", "libpjrt_cpu_darwin_arm64", "libpjrt_cpu_linux_amd64")
|
||||||
|
|
||||||
@ -89,10 +86,7 @@ cuda = use_extension("//runtimes/cuda:cuda.bzl", "cuda_packages")
|
|||||||
use_repo(cuda, "libpjrt_cuda")
|
use_repo(cuda, "libpjrt_cuda")
|
||||||
|
|
||||||
rocm = use_extension("//runtimes/rocm:rocm.bzl", "rocm_packages")
|
rocm = use_extension("//runtimes/rocm:rocm.bzl", "rocm_packages")
|
||||||
|
use_repo(rocm, "libpjrt_rocm")
|
||||||
inject_repo(rocm, "libdrm-amdgpu1", "libdrm2", "libelf1", "libnuma1", "libtinfo6", "libzstd1", "zlib1g")
|
|
||||||
|
|
||||||
use_repo(rocm, "hipblaslt", "libpjrt_rocm", "rocblas")
|
|
||||||
|
|
||||||
tpu = use_extension("//runtimes/tpu:tpu.bzl", "tpu_packages")
|
tpu = use_extension("//runtimes/tpu:tpu.bzl", "tpu_packages")
|
||||||
use_repo(tpu, "libpjrt_tpu")
|
use_repo(tpu, "libpjrt_tpu")
|
||||||
@ -150,12 +144,6 @@ non_module_deps = use_extension("//:third_party/non_module_deps.bzl", "non_modul
|
|||||||
use_repo(non_module_deps, "com_google_sentencepiece", "org_swig_swig")
|
use_repo(non_module_deps, "com_google_sentencepiece", "org_swig_swig")
|
||||||
|
|
||||||
apt = use_extension("@rules_distroless//apt:extensions.bzl", "apt")
|
apt = use_extension("@rules_distroless//apt:extensions.bzl", "apt")
|
||||||
apt.install(
|
|
||||||
name = "apt_common",
|
|
||||||
lock = "//runtimes/common:packages.lock.json",
|
|
||||||
manifest = "//runtimes/common:packages.yaml",
|
|
||||||
)
|
|
||||||
use_repo(apt, "apt_common")
|
|
||||||
apt.install(
|
apt.install(
|
||||||
name = "apt_cuda",
|
name = "apt_cuda",
|
||||||
lock = "//runtimes/cuda:packages.lock.json",
|
lock = "//runtimes/cuda:packages.lock.json",
|
||||||
|
|||||||
@ -1,6 +0,0 @@
|
|||||||
load("@rules_zig//zig:defs.bzl", "zig_library")
|
|
||||||
|
|
||||||
exports_files(
|
|
||||||
["packages.lock.json"],
|
|
||||||
visibility = ["//runtimes:__subpackages__"],
|
|
||||||
)
|
|
||||||
@ -27,6 +27,9 @@ cc_import(
|
|||||||
def _filegroup(**kwargs):
|
def _filegroup(**kwargs):
|
||||||
return """filegroup({})""".format(_kwargs(**kwargs))
|
return """filegroup({})""".format(_kwargs(**kwargs))
|
||||||
|
|
||||||
|
def _patchelf(**kwargs):
|
||||||
|
return """patchelf({})""".format(_kwargs(**kwargs))
|
||||||
|
|
||||||
def _load(bzl, name):
|
def _load(bzl, name):
|
||||||
return """load({}, {})""".format(repr(bzl), repr(name))
|
return """load({}, {})""".format(repr(bzl), repr(name))
|
||||||
|
|
||||||
@ -40,36 +43,6 @@ def _read(mctx, labels):
|
|||||||
})
|
})
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
_DEBIAN_PACKAGES = {
|
|
||||||
"libdrm2": _cc_import(name = "libdrm2", shared_library = "usr/lib/x86_64-linux-gnu/libdrm.so.2"),
|
|
||||||
"libelf1": _cc_import(name = "libelf1", shared_library = "usr/lib/x86_64-linux-gnu/libelf.so.1"),
|
|
||||||
"libnuma1": _cc_import(name = "libnuma1", shared_library = "usr/lib/x86_64-linux-gnu/libnuma.so.1"),
|
|
||||||
"libzstd1": _cc_import(name = "libzstd1", shared_library = "usr/lib/x86_64-linux-gnu/libzstd.so.1"),
|
|
||||||
"libdrm-amdgpu1": _cc_import(name = "libdrm-amdgpu1", shared_library = "usr/lib/x86_64-linux-gnu/libdrm_amdgpu.so.1"),
|
|
||||||
"libtinfo6": _cc_import(name = "libtinfo6", shared_library = "lib/x86_64-linux-gnu/libtinfo.so.6"),
|
|
||||||
"zlib1g": _cc_import(name = "zlib1g", shared_library = "lib/x86_64-linux-gnu/libz.so.1"),
|
|
||||||
}
|
|
||||||
|
|
||||||
def _common_apt_packages_impl(mctx):
|
|
||||||
loaded_packages = packages.read(mctx, ["packages.lock.json"])
|
|
||||||
for pkg_name, build_file_content in _DEBIAN_PACKAGES.items():
|
|
||||||
pkg = loaded_packages[pkg_name]
|
|
||||||
http_deb_archive(
|
|
||||||
name = pkg_name,
|
|
||||||
urls = pkg["urls"],
|
|
||||||
sha256 = pkg["sha256"],
|
|
||||||
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + build_file_content,
|
|
||||||
)
|
|
||||||
return mctx.extension_metadata(
|
|
||||||
reproducible = True,
|
|
||||||
root_module_direct_deps = "all",
|
|
||||||
root_module_direct_dev_deps = [],
|
|
||||||
)
|
|
||||||
|
|
||||||
common_apt_packages = module_extension(
|
|
||||||
implementation = _common_apt_packages_impl,
|
|
||||||
)
|
|
||||||
|
|
||||||
packages = struct(
|
packages = struct(
|
||||||
read = _read,
|
read = _read,
|
||||||
cc_import = _cc_import,
|
cc_import = _cc_import,
|
||||||
@ -77,4 +50,5 @@ packages = struct(
|
|||||||
cc_library = _cc_library,
|
cc_library = _cc_library,
|
||||||
filegroup = _filegroup,
|
filegroup = _filegroup,
|
||||||
load_ = _load,
|
load_ = _load,
|
||||||
|
patchelf = _patchelf,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,258 +0,0 @@
|
|||||||
{
|
|
||||||
"packages": [
|
|
||||||
{
|
|
||||||
"arch": "amd64",
|
|
||||||
"dependencies": [
|
|
||||||
{
|
|
||||||
"key": "libc6_2.36-9-p-deb12u10_amd64",
|
|
||||||
"name": "libc6",
|
|
||||||
"version": "2.36-9+deb12u10"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
|
|
||||||
"name": "libgcc-s1",
|
|
||||||
"version": "12.2.0-14+deb12u1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
|
|
||||||
"name": "gcc-12-base",
|
|
||||||
"version": "12.2.0-14+deb12u1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "libdrm-common_2.4.114-1_amd64",
|
|
||||||
"name": "libdrm-common",
|
|
||||||
"version": "2.4.114-1"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"key": "libdrm2_2.4.114-1-p-b1_amd64",
|
|
||||||
"name": "libdrm2",
|
|
||||||
"sha256": "be18fb670797ba32da9628cf3e8acd83160d8db8c8dd842501dd8e401c3b5371",
|
|
||||||
"urls": [
|
|
||||||
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/libd/libdrm/libdrm2_2.4.114-1+b1_amd64.deb"
|
|
||||||
],
|
|
||||||
"version": "2.4.114-1+b1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"arch": "amd64",
|
|
||||||
"dependencies": [],
|
|
||||||
"key": "libc6_2.36-9-p-deb12u10_amd64",
|
|
||||||
"name": "libc6",
|
|
||||||
"sha256": "5dc83256f10ca4d0f2a53dd6583ffd0d0e319af30074ea6c82fb0ae77bd16365",
|
|
||||||
"urls": [
|
|
||||||
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/g/glibc/libc6_2.36-9+deb12u10_amd64.deb"
|
|
||||||
],
|
|
||||||
"version": "2.36-9+deb12u10"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"arch": "amd64",
|
|
||||||
"dependencies": [],
|
|
||||||
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
|
|
||||||
"name": "libgcc-s1",
|
|
||||||
"sha256": "3016e62cb4b7cd8038822870601f5ed131befe942774d0f745622cc77d8a88f7",
|
|
||||||
"urls": [
|
|
||||||
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/g/gcc-12/libgcc-s1_12.2.0-14+deb12u1_amd64.deb"
|
|
||||||
],
|
|
||||||
"version": "12.2.0-14+deb12u1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"arch": "amd64",
|
|
||||||
"dependencies": [],
|
|
||||||
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
|
|
||||||
"name": "gcc-12-base",
|
|
||||||
"sha256": "1896a2aacf4ad681ff5eacc24a5b0ca4d5d9c9b9c9e4b6de5197bc1e116ea619",
|
|
||||||
"urls": [
|
|
||||||
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/g/gcc-12/gcc-12-base_12.2.0-14+deb12u1_amd64.deb"
|
|
||||||
],
|
|
||||||
"version": "12.2.0-14+deb12u1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"arch": "amd64",
|
|
||||||
"dependencies": [],
|
|
||||||
"key": "libdrm-common_2.4.114-1_amd64",
|
|
||||||
"name": "libdrm-common",
|
|
||||||
"sha256": "32f9664138b38b224383c6986457d5ad2ec8efd559b1a0ce7749405f7a451aad",
|
|
||||||
"urls": [
|
|
||||||
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/libd/libdrm/libdrm-common_2.4.114-1_all.deb"
|
|
||||||
],
|
|
||||||
"version": "2.4.114-1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"arch": "amd64",
|
|
||||||
"dependencies": [
|
|
||||||
{
|
|
||||||
"key": "zlib1g_1-1.2.13.dfsg-1_amd64",
|
|
||||||
"name": "zlib1g",
|
|
||||||
"version": "1:1.2.13.dfsg-1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "libc6_2.36-9-p-deb12u10_amd64",
|
|
||||||
"name": "libc6",
|
|
||||||
"version": "2.36-9+deb12u10"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
|
|
||||||
"name": "libgcc-s1",
|
|
||||||
"version": "12.2.0-14+deb12u1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
|
|
||||||
"name": "gcc-12-base",
|
|
||||||
"version": "12.2.0-14+deb12u1"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"key": "libelf1_0.188-2.1_amd64",
|
|
||||||
"name": "libelf1",
|
|
||||||
"sha256": "619add379c606b3ac6c1a175853b918e6939598a83d8ebadf3bdfd50d10b3c8c",
|
|
||||||
"urls": [
|
|
||||||
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/e/elfutils/libelf1_0.188-2.1_amd64.deb"
|
|
||||||
],
|
|
||||||
"version": "0.188-2.1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"arch": "amd64",
|
|
||||||
"dependencies": [
|
|
||||||
{
|
|
||||||
"key": "libc6_2.36-9-p-deb12u10_amd64",
|
|
||||||
"name": "libc6",
|
|
||||||
"version": "2.36-9+deb12u10"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
|
|
||||||
"name": "libgcc-s1",
|
|
||||||
"version": "12.2.0-14+deb12u1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
|
|
||||||
"name": "gcc-12-base",
|
|
||||||
"version": "12.2.0-14+deb12u1"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"key": "zlib1g_1-1.2.13.dfsg-1_amd64",
|
|
||||||
"name": "zlib1g",
|
|
||||||
"sha256": "d7dd1d1411fedf27f5e27650a6eff20ef294077b568f4c8c5e51466dc7c08ce4",
|
|
||||||
"urls": [
|
|
||||||
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/z/zlib/zlib1g_1.2.13.dfsg-1_amd64.deb"
|
|
||||||
],
|
|
||||||
"version": "1:1.2.13.dfsg-1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"arch": "amd64",
|
|
||||||
"dependencies": [
|
|
||||||
{
|
|
||||||
"key": "libc6_2.36-9-p-deb12u10_amd64",
|
|
||||||
"name": "libc6",
|
|
||||||
"version": "2.36-9+deb12u10"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
|
|
||||||
"name": "libgcc-s1",
|
|
||||||
"version": "12.2.0-14+deb12u1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
|
|
||||||
"name": "gcc-12-base",
|
|
||||||
"version": "12.2.0-14+deb12u1"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"key": "libnuma1_2.0.16-1_amd64",
|
|
||||||
"name": "libnuma1",
|
|
||||||
"sha256": "639e1ab6bd66ead40db8a22c332d7199679fa22db261cac34444eb8eb4c17dda",
|
|
||||||
"urls": [
|
|
||||||
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/n/numactl/libnuma1_2.0.16-1_amd64.deb"
|
|
||||||
],
|
|
||||||
"version": "2.0.16-1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"arch": "amd64",
|
|
||||||
"dependencies": [
|
|
||||||
{
|
|
||||||
"key": "libc6_2.36-9-p-deb12u10_amd64",
|
|
||||||
"name": "libc6",
|
|
||||||
"version": "2.36-9+deb12u10"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
|
|
||||||
"name": "libgcc-s1",
|
|
||||||
"version": "12.2.0-14+deb12u1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
|
|
||||||
"name": "gcc-12-base",
|
|
||||||
"version": "12.2.0-14+deb12u1"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"key": "libzstd1_1.5.4-p-dfsg2-5_amd64",
|
|
||||||
"name": "libzstd1",
|
|
||||||
"sha256": "6315b5ac38b724a710fb96bf1042019398cb656718b1522279a5185ed39318fa",
|
|
||||||
"urls": [
|
|
||||||
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/libz/libzstd/libzstd1_1.5.4+dfsg2-5_amd64.deb"
|
|
||||||
],
|
|
||||||
"version": "1.5.4+dfsg2-5"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"arch": "amd64",
|
|
||||||
"dependencies": [
|
|
||||||
{
|
|
||||||
"key": "libdrm2_2.4.114-1-p-b1_amd64",
|
|
||||||
"name": "libdrm2",
|
|
||||||
"version": "2.4.114-1+b1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "libc6_2.36-9-p-deb12u10_amd64",
|
|
||||||
"name": "libc6",
|
|
||||||
"version": "2.36-9+deb12u10"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
|
|
||||||
"name": "libgcc-s1",
|
|
||||||
"version": "12.2.0-14+deb12u1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
|
|
||||||
"name": "gcc-12-base",
|
|
||||||
"version": "12.2.0-14+deb12u1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "libdrm-common_2.4.114-1_amd64",
|
|
||||||
"name": "libdrm-common",
|
|
||||||
"version": "2.4.114-1"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"key": "libdrm-amdgpu1_2.4.114-1-p-b1_amd64",
|
|
||||||
"name": "libdrm-amdgpu1",
|
|
||||||
"sha256": "b75a71e96f1faac0f131ac657e09efcbe8968eef62cc34b8abfcff2ff9f0cccd",
|
|
||||||
"urls": [
|
|
||||||
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/libd/libdrm/libdrm-amdgpu1_2.4.114-1+b1_amd64.deb"
|
|
||||||
],
|
|
||||||
"version": "2.4.114-1+b1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"arch": "amd64",
|
|
||||||
"dependencies": [
|
|
||||||
{
|
|
||||||
"key": "libc6_2.36-9-p-deb12u10_amd64",
|
|
||||||
"name": "libc6",
|
|
||||||
"version": "2.36-9+deb12u10"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "libgcc-s1_12.2.0-14-p-deb12u1_amd64",
|
|
||||||
"name": "libgcc-s1",
|
|
||||||
"version": "12.2.0-14+deb12u1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"key": "gcc-12-base_12.2.0-14-p-deb12u1_amd64",
|
|
||||||
"name": "gcc-12-base",
|
|
||||||
"version": "12.2.0-14+deb12u1"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"key": "libtinfo6_6.4-4_amd64",
|
|
||||||
"name": "libtinfo6",
|
|
||||||
"sha256": "072d908f38f51090ca28ca5afa3b46b2957dc61fe35094c0b851426859a49a51",
|
|
||||||
"urls": [
|
|
||||||
"https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/pool/main/n/ncurses/libtinfo6_6.4-4_amd64.deb"
|
|
||||||
],
|
|
||||||
"version": "6.4-4"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"version": 1
|
|
||||||
}
|
|
||||||
@ -1,20 +0,0 @@
|
|||||||
#
|
|
||||||
# bazel run @rocm_apt//:lock
|
|
||||||
#
|
|
||||||
version: 1
|
|
||||||
|
|
||||||
sources:
|
|
||||||
- channel: bookworm main
|
|
||||||
url: https://snapshot-cloudflare.debian.org/archive/debian/20250529T205323Z/
|
|
||||||
|
|
||||||
archs:
|
|
||||||
- "amd64"
|
|
||||||
|
|
||||||
packages:
|
|
||||||
- "libdrm2"
|
|
||||||
- "libelf1"
|
|
||||||
- "libnuma1"
|
|
||||||
- "libzstd1"
|
|
||||||
- "libdrm-amdgpu1"
|
|
||||||
- "libtinfo6"
|
|
||||||
- "zlib1g"
|
|
||||||
@ -8,11 +8,6 @@ cc_shared_library(
|
|||||||
deps = ["@zml//runtimes/cuda:zmlxcuda_lib"],
|
deps = ["@zml//runtimes/cuda:zmlxcuda_lib"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_import(
|
|
||||||
name = "zmlxcuda",
|
|
||||||
shared_library = ":zmlxcuda_so",
|
|
||||||
)
|
|
||||||
|
|
||||||
patchelf(
|
patchelf(
|
||||||
name = "libpjrt_cuda.patchelf",
|
name = "libpjrt_cuda.patchelf",
|
||||||
shared_library = "libpjrt_cuda.so",
|
shared_library = "libpjrt_cuda.so",
|
||||||
|
|||||||
@ -7,17 +7,6 @@ cc_library(
|
|||||||
"-lc",
|
"-lc",
|
||||||
"-ldl",
|
"-ldl",
|
||||||
],
|
],
|
||||||
)
|
|
||||||
|
|
||||||
cc_shared_library(
|
|
||||||
name = "zmlxrocm_so",
|
|
||||||
shared_lib_name = "libzmlxrocm.so.0",
|
|
||||||
deps = [":zmlxrocm_lib"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_import(
|
|
||||||
name = "zmlxrocm",
|
|
||||||
shared_library = ":zmlxrocm_so",
|
|
||||||
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory")
|
load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory")
|
||||||
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "string_list_flag")
|
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "string_list_flag")
|
||||||
|
load("@zml//bazel:patchelf.bzl", "patchelf")
|
||||||
load("@zml//bazel:cc_import.bzl", "cc_import")
|
load("@zml//bazel:cc_import.bzl", "cc_import")
|
||||||
|
|
||||||
string_list_flag(
|
string_list_flag(
|
||||||
@ -19,65 +20,58 @@ config_setting(
|
|||||||
flag_values = {":hipblaslt": "True"},
|
flag_values = {":hipblaslt": "True"},
|
||||||
)
|
)
|
||||||
|
|
||||||
copy_to_directory(
|
|
||||||
name = "sandbox",
|
|
||||||
srcs = [
|
|
||||||
"@rocm-device-libs//:runfiles",
|
|
||||||
"@rocm-llvm//:lld",
|
|
||||||
],
|
|
||||||
include_external_repositories = ["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "zmlxrocm_lib",
|
|
||||||
data = ["@rocblas//:runfiles"],
|
|
||||||
srcs = ["@zml//runtimes/rocm:zmlxrocm_srcs"],
|
|
||||||
linkopts = [
|
|
||||||
"-lc",
|
|
||||||
"-ldl",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_shared_library(
|
cc_shared_library(
|
||||||
name = "zmlxrocm_so",
|
name = "zmlxrocm_so",
|
||||||
shared_lib_name = "libzmlxrocm.so.0",
|
shared_lib_name = "lib/libzmlxrocm.so.0",
|
||||||
deps = [":zmlxrocm_lib"],
|
deps = ["@zml//runtimes/rocm:zmlxrocm_lib"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_import(
|
patchelf(
|
||||||
name = "zmlxrocm",
|
name = "libpjrt_rocm.patchelf",
|
||||||
shared_library = ":zmlxrocm_so",
|
shared_library = "libpjrt_rocm.so",
|
||||||
)
|
add_needed = [
|
||||||
|
"libzmlxrocm.so.0",
|
||||||
cc_import(
|
# So that RPATH is taken into account.
|
||||||
name = "libpjrt_rocm",
|
"librocblas.so.4",
|
||||||
data = [
|
"libMIOpen.so.1",
|
||||||
":sandbox",
|
|
||||||
"@rocblas//:runfiles",
|
|
||||||
] + select({
|
] + select({
|
||||||
":_hipblaslt": ["@hipblaslt//:runfiles"],
|
"_hipblaslt": [
|
||||||
|
"libhipblaslt.so.0",
|
||||||
|
],
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
}),
|
}),
|
||||||
add_needed = ["libzmlxrocm.so.0"],
|
|
||||||
rename_dynamic_symbols = {
|
rename_dynamic_symbols = {
|
||||||
"dlopen": "zmlxrocm_dlopen",
|
"dlopen": "zmlxrocm_dlopen",
|
||||||
},
|
},
|
||||||
shared_library = "libpjrt_rocm.so",
|
set_rpath = "$ORIGIN",
|
||||||
soname = "libpjrt_rocm.so",
|
)
|
||||||
visibility = ["//visibility:public"],
|
|
||||||
deps = [
|
copy_to_directory(
|
||||||
|
name = "sandbox",
|
||||||
|
srcs = [
|
||||||
|
":zmlxrocm_so",
|
||||||
|
":libpjrt_rocm.patchelf",
|
||||||
"@comgr//:amd_comgr",
|
"@comgr//:amd_comgr",
|
||||||
"@hip-runtime-amd//:amdhip",
|
"@hip-runtime-amd//:amdhip",
|
||||||
"@hipblaslt",
|
"@hip-runtime-amd//:hiprtc",
|
||||||
|
"@hipblaslt//:hipblaslt",
|
||||||
|
"@hipfft",
|
||||||
|
"@hipsolver",
|
||||||
"@hsa-amd-aqlprofile//:hsa-amd-aqlprofile",
|
"@hsa-amd-aqlprofile//:hsa-amd-aqlprofile",
|
||||||
"@hsa-rocr//:hsa-runtime",
|
"@hsa-rocr//:hsa-runtime",
|
||||||
"@miopen-hip//:MIOpen",
|
"@miopen-hip//:MIOpen",
|
||||||
"@rccl",
|
"@rccl",
|
||||||
"@rocblas",
|
"@rocblas//:rocblas",
|
||||||
|
"@rocblas//:runfiles",
|
||||||
"@rocm-core",
|
"@rocm-core",
|
||||||
|
"@rocm-device-libs//:runfiles",
|
||||||
|
"@rocm-llvm//:lld",
|
||||||
"@rocm-smi-lib//:rocm_smi",
|
"@rocm-smi-lib//:rocm_smi",
|
||||||
"@rocprofiler-register",
|
"@rocprofiler-register",
|
||||||
|
"@rocfft",
|
||||||
|
"@rocsolver",
|
||||||
"@roctracer",
|
"@roctracer",
|
||||||
|
"@roctracer//:roctx",
|
||||||
"@libelf1",
|
"@libelf1",
|
||||||
"@libdrm2",
|
"@libdrm2",
|
||||||
"@libnuma1",
|
"@libnuma1",
|
||||||
@ -85,6 +79,35 @@ cc_import(
|
|||||||
"@libdrm-amdgpu1",
|
"@libdrm-amdgpu1",
|
||||||
"@libtinfo6",
|
"@libtinfo6",
|
||||||
"@zlib1g",
|
"@zlib1g",
|
||||||
"@zml//runtimes/rocm:zmlxrocm",
|
|
||||||
],
|
# lld dependencies
|
||||||
|
"@libxml2",
|
||||||
|
"@libicu70//:libicuuc70",
|
||||||
|
"@libicu70//:libicudata70",
|
||||||
|
"@liblzma5",
|
||||||
|
] + select({
|
||||||
|
":_hipblaslt": ["@hipblaslt//:runfiles"],
|
||||||
|
"//conditions:default": [],
|
||||||
|
}),
|
||||||
|
replace_prefixes = {
|
||||||
|
"libpjrt_rocm.patchelf": "lib",
|
||||||
|
"lib/x86_64-linux-gnu": "lib",
|
||||||
|
"usr/lib/x86_64-linux-gnu": "lib",
|
||||||
|
"libdrm-amdgpu1": "lib",
|
||||||
|
"libelf1": "lib",
|
||||||
|
"hipblaslt": "lib",
|
||||||
|
"rocblas": "lib",
|
||||||
|
"libxml2": "lib",
|
||||||
|
"libicuuc70": "lib",
|
||||||
|
"liblzma5": "lib",
|
||||||
|
"lld": "llvm/bin",
|
||||||
|
},
|
||||||
|
add_directory_to_runfiles = True,
|
||||||
|
include_external_repositories = ["**"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "libpjrt_rocm",
|
||||||
|
data = [":sandbox"],
|
||||||
|
visibility = ["@zml//runtimes/rocm:__subpackages__"],
|
||||||
)
|
)
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -1,27 +1,24 @@
|
|||||||
#
|
#
|
||||||
# bazel run @rocm_apt//:lock
|
# bazel run @apt_rocm//:lock
|
||||||
#
|
#
|
||||||
version: 1
|
version: 1
|
||||||
|
|
||||||
sources:
|
sources:
|
||||||
- channel: jammy main
|
- channel: jammy main
|
||||||
url: https://repo.radeon.com/rocm/apt/6.3.4/
|
url: https://repo.radeon.com/rocm/apt/6.3.4/
|
||||||
# - channel: bookworm main
|
- channel: jammy main
|
||||||
# url: https://snapshot-cloudflare.debian.org/archive/debian/20241127T143620Z/
|
url: https://snapshot.ubuntu.com/ubuntu/20250711T030400Z
|
||||||
|
- channel: jammy-security main
|
||||||
|
url: https://snapshot.ubuntu.com/ubuntu/20250711T030400Z
|
||||||
|
- channel: jammy-updates main
|
||||||
|
url: https://snapshot.ubuntu.com/ubuntu/20250711T030400Z
|
||||||
|
|
||||||
archs:
|
archs:
|
||||||
- "amd64"
|
- "amd64"
|
||||||
|
|
||||||
|
# readelf -d libpjrt_rosm.so | grep NEEDED
|
||||||
packages:
|
packages:
|
||||||
# - "libdrm2"
|
# - "rocm-smi-lib"
|
||||||
# - "libelf1"
|
|
||||||
# - "libnuma1"
|
|
||||||
# - "libzstd1"
|
|
||||||
# - "libdrm-amdgpu1"
|
|
||||||
# - "libtinfo6"
|
|
||||||
# - "zlib1g"
|
|
||||||
- "rocm-core"
|
|
||||||
- "rocm-smi-lib"
|
|
||||||
- "hsa-rocr"
|
- "hsa-rocr"
|
||||||
- "hsa-amd-aqlprofile"
|
- "hsa-amd-aqlprofile"
|
||||||
- "comgr"
|
- "comgr"
|
||||||
@ -31,8 +28,13 @@ packages:
|
|||||||
- "rocm-device-libs"
|
- "rocm-device-libs"
|
||||||
- "hip-dev"
|
- "hip-dev"
|
||||||
- "rocblas"
|
- "rocblas"
|
||||||
- "roctracer"
|
- "rocsolver"
|
||||||
|
- "hipsolver"
|
||||||
|
- "hipfft"
|
||||||
|
# - "roctracer"
|
||||||
- "hipblaslt"
|
- "hipblaslt"
|
||||||
- "hipblaslt-dev"
|
# - "hipblaslt-dev"
|
||||||
- "hip-runtime-amd"
|
- "hip-runtime-amd"
|
||||||
- "rocm-llvm"
|
# - "rocm-llvm"
|
||||||
|
# rocm-llvm > ld.ldd missing dependency
|
||||||
|
- "libxml2"
|
||||||
|
|||||||
@ -8,46 +8,66 @@ package(default_visibility = ["//visibility:public"])
|
|||||||
|
|
||||||
_ROCM_STRIP_PREFIX = "opt/rocm-6.3.4"
|
_ROCM_STRIP_PREFIX = "opt/rocm-6.3.4"
|
||||||
|
|
||||||
# def _kwargs(**kwargs):
|
_UBUNTU_PACKAGES = {
|
||||||
# return repr(struct(**kwargs))[len("struct("):-1]
|
"libdrm2": packages.filegroup(name = "libdrm2", srcs = ["usr/lib/x86_64-linux-gnu/libdrm.so.2"]),
|
||||||
|
"libelf1": "\n".join([
|
||||||
# def packages.cc_import(**kwargs):
|
packages.load_("@zml//bazel:patchelf.bzl", "patchelf"),
|
||||||
# return """cc_import({})""".format(_kwargs(**kwargs))
|
packages.patchelf(
|
||||||
|
name = "libelf1",
|
||||||
# def packages.filegroup(**kwargs):
|
shared_library = "usr/lib/x86_64-linux-gnu/libelf.so.1",
|
||||||
# return """filegroup({})""".format(_kwargs(**kwargs))
|
set_rpath = '$ORIGIN',
|
||||||
|
),
|
||||||
# def packages.load_(bzl, name):
|
]),
|
||||||
# return """load({}, {})""".format(repr(bzl), repr(name))
|
"libnuma1": packages.filegroup(name = "libnuma1", srcs = ["usr/lib/x86_64-linux-gnu/libnuma.so.1"]),
|
||||||
|
"libzstd1": packages.filegroup(name = "libzstd1", srcs = ["usr/lib/x86_64-linux-gnu/libzstd.so.1"]),
|
||||||
# _UBUNTU_PACKAGES = {
|
"libdrm-amdgpu1": "\n".join([
|
||||||
# "libdrm2": packages.cc_import(name = "libdrm2", shared_library = "usr/lib/x86_64-linux-gnu/libdrm.so.2"),
|
packages.load_("@zml//bazel:patchelf.bzl", "patchelf"),
|
||||||
# "libelf1": packages.cc_import(name = "libelf1", shared_library = "usr/lib/x86_64-linux-gnu/libelf.so.1"),
|
packages.patchelf(
|
||||||
# "libnuma1": packages.cc_import(name = "libnuma1", shared_library = "usr/lib/x86_64-linux-gnu/libnuma.so.1"),
|
name = "libdrm-amdgpu1",
|
||||||
# "libzstd1": packages.cc_import(name = "libzstd1", shared_library = "usr/lib/x86_64-linux-gnu/libzstd.so.1"),
|
shared_library = "usr/lib/x86_64-linux-gnu/libdrm_amdgpu.so.1",
|
||||||
# "libdrm-amdgpu1": packages.cc_import(name = "libdrm-amdgpu1", shared_library = "usr/lib/x86_64-linux-gnu/libdrm_amdgpu.so.1"),
|
set_rpath = '$ORIGIN',
|
||||||
# "libtinfo6": packages.cc_import(name = "libtinfo6", shared_library = "lib/x86_64-linux-gnu/libtinfo.so.6"),
|
),
|
||||||
# "zlib1g": packages.cc_import(name = "zlib1g", shared_library = "lib/x86_64-linux-gnu/libz.so.1"),
|
]),
|
||||||
# }
|
"libtinfo6": packages.filegroup(name = "libtinfo6", srcs = ["lib/x86_64-linux-gnu/libtinfo.so.6"]),
|
||||||
|
"zlib1g": packages.filegroup(name = "zlib1g", srcs = ["lib/x86_64-linux-gnu/libz.so.1"]),
|
||||||
|
"liblzma5": packages.filegroup(name = "liblzma5", srcs = ["lib/x86_64-linux-gnu/liblzma.so.5"]),
|
||||||
|
"libxml2": "\n".join([
|
||||||
|
packages.load_("@zml//bazel:patchelf.bzl", "patchelf"),
|
||||||
|
packages.patchelf(
|
||||||
|
name = "libxml2",
|
||||||
|
shared_library = "usr/lib/x86_64-linux-gnu/libxml2.so.2",
|
||||||
|
set_rpath = '$ORIGIN',
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
"libicu70": "\n".join([
|
||||||
|
packages.load_("@zml//bazel:patchelf.bzl", "patchelf"),
|
||||||
|
packages.patchelf(
|
||||||
|
name = "libicuuc70",
|
||||||
|
shared_library = "usr/lib/x86_64-linux-gnu/libicuuc.so.70",
|
||||||
|
set_rpath = '$ORIGIN',
|
||||||
|
),
|
||||||
|
packages.filegroup(name = "libicudata70", srcs = ["usr/lib/x86_64-linux-gnu/libicudata.so.70"])
|
||||||
|
]),
|
||||||
|
}
|
||||||
|
|
||||||
_ROCM_PACKAGES = {
|
_ROCM_PACKAGES = {
|
||||||
"rocm-core": packages.cc_import(name = "rocm-core", shared_library = "lib/librocm-core.so.1"),
|
"rocm-core": packages.filegroup(name = "rocm-core", srcs = ["lib/librocm-core.so.1"]),
|
||||||
"rocm-smi-lib": packages.cc_import(name = "rocm_smi", shared_library = "lib/librocm_smi64.so.7"),
|
"rocm-smi-lib": packages.filegroup(name = "rocm_smi", srcs = ["lib/librocm_smi64.so.7"]),
|
||||||
"hsa-rocr": packages.cc_import(name = "hsa-runtime", shared_library = "lib/libhsa-runtime64.so.1"),
|
"hsa-rocr": packages.filegroup(name = "hsa-runtime", srcs = ["lib/libhsa-runtime64.so.1"]),
|
||||||
"hsa-amd-aqlprofile": packages.cc_import(name = "hsa-amd-aqlprofile", shared_library = "lib/libhsa-amd-aqlprofile64.so.1"),
|
"hsa-amd-aqlprofile": packages.filegroup(name = "hsa-amd-aqlprofile", srcs = ["lib/libhsa-amd-aqlprofile64.so.1"]),
|
||||||
"comgr": packages.cc_import(name = "amd_comgr", shared_library = "lib/libamd_comgr.so.2"),
|
"comgr": packages.filegroup(name = "amd_comgr", srcs = ["lib/libamd_comgr.so.2"]),
|
||||||
"rocprofiler-register": packages.cc_import(name = "rocprofiler-register", shared_library = "lib/librocprofiler-register.so.0"),
|
"rocprofiler-register": packages.filegroup(name = "rocprofiler-register", srcs = ["lib/librocprofiler-register.so.0"]),
|
||||||
"miopen-hip": "\n".join([
|
"miopen-hip": "\n".join([
|
||||||
packages.cc_import(name = "MIOpen", shared_library = "lib/libMIOpen.so.1"),
|
packages.filegroup(name = "MIOpen", srcs = ["lib/libMIOpen.so.1"]),
|
||||||
"""filegroup(name = "runfiles", srcs = glob(["share/miopen/**"]))""",
|
"""filegroup(name = "runfiles", srcs = glob(["share/miopen/**"]))""",
|
||||||
]),
|
]),
|
||||||
"rccl": packages.cc_import(name = "rccl", shared_library = "lib/librccl.so.1"),
|
"rccl": packages.filegroup(name = "rccl", srcs = ["lib/librccl.so.1"]),
|
||||||
"rocm-device-libs": """filegroup(name = "runfiles", srcs = glob(["amdgcn/**"]))""",
|
"rocm-device-libs": """filegroup(name = "runfiles", srcs = glob(["amdgcn/**"]))""",
|
||||||
"hip-dev": """filegroup(name = "runfiles", srcs = glob(["share/**"]))""",
|
"hip-dev": """filegroup(name = "runfiles", srcs = glob(["share/**"]))""",
|
||||||
"rocblas": "\n".join([
|
"rocblas": "\n".join([
|
||||||
packages.load_("@zml//bazel:cc_import.bzl", "cc_import"),
|
packages.load_("@zml//bazel:patchelf.bzl", "patchelf"),
|
||||||
packages.load_("@zml//runtimes/rocm:gfx.bzl", "bytecode_select"),
|
packages.load_("@zml//runtimes/rocm:gfx.bzl", "bytecode_select"),
|
||||||
packages.cc_import(
|
packages.patchelf(
|
||||||
name = "rocblas",
|
name = "rocblas",
|
||||||
shared_library = "lib/librocblas.so.4",
|
shared_library = "lib/librocblas.so.4",
|
||||||
add_needed = ["libzmlxrocm.so.0"],
|
add_needed = ["libzmlxrocm.so.0"],
|
||||||
@ -69,14 +89,16 @@ _ROCM_PACKAGES = {
|
|||||||
],
|
],
|
||||||
),
|
),
|
||||||
]),
|
]),
|
||||||
|
"rocfft": packages.filegroup(name = "rocfft", srcs = ["lib/librocfft.so.0"]),
|
||||||
|
"rocsolver": packages.filegroup(name = "rocsolver", srcs = ["lib/librocsolver.so.0"]),
|
||||||
"roctracer": "\n".join([
|
"roctracer": "\n".join([
|
||||||
packages.cc_import(name = "roctracer", shared_library = "lib/libroctracer64.so.4", deps = [":roctx"]),
|
packages.filegroup(name = "roctracer", srcs = ["lib/libroctracer64.so.4"]),
|
||||||
packages.cc_import(name = "roctx", shared_library = "lib/libroctx64.so.4"),
|
packages.filegroup(name = "roctx", srcs = ["lib/libroctx64.so.4"]),
|
||||||
]),
|
]),
|
||||||
"hipblaslt": "\n".join([
|
"hipblaslt": "\n".join([
|
||||||
packages.load_("@zml//bazel:cc_import.bzl", "cc_import"),
|
packages.load_("@zml//bazel:patchelf.bzl", "patchelf"),
|
||||||
packages.load_("@zml//runtimes/rocm:gfx.bzl", "bytecode_select"),
|
packages.load_("@zml//runtimes/rocm:gfx.bzl", "bytecode_select"),
|
||||||
packages.cc_import(
|
packages.patchelf(
|
||||||
name = "hipblaslt",
|
name = "hipblaslt",
|
||||||
shared_library = "lib/libhipblaslt.so.0",
|
shared_library = "lib/libhipblaslt.so.0",
|
||||||
add_needed = ["libzmlxrocm.so.0"],
|
add_needed = ["libzmlxrocm.so.0"],
|
||||||
@ -102,11 +124,21 @@ _ROCM_PACKAGES = {
|
|||||||
],
|
],
|
||||||
),
|
),
|
||||||
]),
|
]),
|
||||||
|
"hipfft": packages.filegroup(name = "hipfft", srcs = ["lib/libhipfft.so.0"]),
|
||||||
"hip-runtime-amd": "\n".join([
|
"hip-runtime-amd": "\n".join([
|
||||||
packages.cc_import(name = "amdhip", shared_library = "lib/libamdhip64.so.6", deps = [":hiprtc"]),
|
packages.filegroup(name = "amdhip", srcs = ["lib/libamdhip64.so.6"]),
|
||||||
packages.cc_import(name = "hiprtc", shared_library = "lib/libhiprtc.so.6"),
|
packages.filegroup(name = "hiprtc", srcs = ["lib/libhiprtc.so.6"]),
|
||||||
|
]),
|
||||||
|
"hipsolver": packages.filegroup(name = "hipsolver", srcs = ["lib/libhipsolver.so.0"]),
|
||||||
|
"rocm-llvm": "\n".join([
|
||||||
|
packages.load_("@zml//bazel:patchelf.bzl", "patchelf"),
|
||||||
|
packages.patchelf(
|
||||||
|
name = "lld",
|
||||||
|
#TODO: Rename attr to elf_file or file ?
|
||||||
|
shared_library = "llvm/bin/ld.lld",
|
||||||
|
set_rpath = '$ORIGIN/../../lib',
|
||||||
|
),
|
||||||
]),
|
]),
|
||||||
"rocm-llvm": packages.filegroup(name = "lld", srcs = ["llvm/bin/ld.lld"], visibility = ["//visibility:public"]),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def _rocm_impl(mctx):
|
def _rocm_impl(mctx):
|
||||||
@ -114,6 +146,15 @@ def _rocm_impl(mctx):
|
|||||||
"@zml//runtimes/rocm:packages.lock.json",
|
"@zml//runtimes/rocm:packages.lock.json",
|
||||||
])
|
])
|
||||||
|
|
||||||
|
for pkg_name, build_file_content in _UBUNTU_PACKAGES.items():
|
||||||
|
pkg = loaded_packages[pkg_name]
|
||||||
|
http_deb_archive(
|
||||||
|
name = pkg_name,
|
||||||
|
urls = pkg["urls"],
|
||||||
|
sha256 = pkg["sha256"],
|
||||||
|
build_file_content = _BUILD_FILE_DEFAULT_VISIBILITY + build_file_content,
|
||||||
|
)
|
||||||
|
|
||||||
for pkg_name, build_file_content in _ROCM_PACKAGES.items():
|
for pkg_name, build_file_content in _ROCM_PACKAGES.items():
|
||||||
pkg = loaded_packages[pkg_name]
|
pkg = loaded_packages[pkg_name]
|
||||||
http_deb_archive(
|
http_deb_archive(
|
||||||
|
|||||||
@ -8,6 +8,8 @@ const pjrt = @import("pjrt");
|
|||||||
const runfiles = @import("runfiles");
|
const runfiles = @import("runfiles");
|
||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
|
const log = std.log.scoped(.@"zml/runtime/rocm");
|
||||||
|
|
||||||
const ROCmEnvEntry = struct {
|
const ROCmEnvEntry = struct {
|
||||||
name: [:0]const u8,
|
name: [:0]const u8,
|
||||||
rpath: []const u8,
|
rpath: []const u8,
|
||||||
@ -16,10 +18,10 @@ const ROCmEnvEntry = struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const rocm_env_entries: []const ROCmEnvEntry = &.{
|
const rocm_env_entries: []const ROCmEnvEntry = &.{
|
||||||
.{ .name = "HIPBLASLT_EXT_OP_LIBRARY_PATH", .rpath = "hipblaslt/lib/hipblaslt/library/hipblasltExtOpLibrary.dat", .dirname = false, .mandatory = false },
|
.{ .name = "HIPBLASLT_EXT_OP_LIBRARY_PATH", .rpath = "/lib/hipblaslt/library/hipblasltExtOpLibrary.dat", .dirname = false, .mandatory = false },
|
||||||
.{ .name = "HIPBLASLT_TENSILE_LIBPATH", .rpath = "hipblaslt/lib/hipblaslt/library/TensileManifest.txt", .dirname = true, .mandatory = false },
|
.{ .name = "HIPBLASLT_TENSILE_LIBPATH", .rpath = "/lib/hipblaslt/library/TensileManifest.txt", .dirname = true, .mandatory = false },
|
||||||
.{ .name = "ROCBLAS_TENSILE_LIBPATH", .rpath = "rocblas/lib/rocblas/library/TensileManifest.txt", .dirname = true, .mandatory = true },
|
.{ .name = "ROCBLAS_TENSILE_LIBPATH", .rpath = "/lib/rocblas/library/TensileManifest.txt", .dirname = true, .mandatory = true },
|
||||||
.{ .name = "ROCM_PATH", .rpath = "libpjrt_rocm/sandbox", .dirname = false, .mandatory = true },
|
.{ .name = "ROCM_PATH", .rpath = "/", .dirname = false, .mandatory = true },
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn isEnabled() bool {
|
pub fn isEnabled() bool {
|
||||||
@ -33,20 +35,9 @@ fn hasRocmDevices() bool {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn setupRocmEnv() !void {
|
fn setupRocmEnv(allocator: std.mem.Allocator, rocm_data_dir: []const u8) !void {
|
||||||
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
|
|
||||||
defer arena.deinit();
|
|
||||||
|
|
||||||
const r = blk: {
|
|
||||||
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
|
|
||||||
stdx.debug.panic("Unable to find Runfiles directory", .{});
|
|
||||||
};
|
|
||||||
const source_repo = bazel_builtin.current_repository;
|
|
||||||
break :blk r_.withSourceRepo(source_repo);
|
|
||||||
};
|
|
||||||
|
|
||||||
for (rocm_env_entries) |entry| {
|
for (rocm_env_entries) |entry| {
|
||||||
var real_path = r.rlocationAlloc(arena.allocator(), entry.rpath) catch null orelse {
|
var real_path: []const u8 = std.fmt.allocPrintZ(allocator, "{s}/{s}", .{ rocm_data_dir, entry.rpath }) catch null orelse {
|
||||||
if (entry.mandatory) {
|
if (entry.mandatory) {
|
||||||
stdx.debug.panic("Unable to find {s} in {s}\n", .{ entry.name, bazel_builtin.current_repository });
|
stdx.debug.panic("Unable to find {s} in {s}\n", .{ entry.name, bazel_builtin.current_repository });
|
||||||
}
|
}
|
||||||
@ -59,7 +50,7 @@ fn setupRocmEnv() !void {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = c.setenv(entry.name, try arena.allocator().dupeZ(u8, real_path), 1);
|
_ = c.setenv(entry.name, try allocator.dupeZ(u8, real_path), 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,7 +65,35 @@ pub fn load() !*const pjrt.Api {
|
|||||||
return error.Unavailable;
|
return error.Unavailable;
|
||||||
}
|
}
|
||||||
|
|
||||||
try setupRocmEnv();
|
var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
|
||||||
|
defer arena.deinit();
|
||||||
|
|
||||||
return try asynk.callBlocking(pjrt.Api.loadFrom, .{"libpjrt_rocm.so"});
|
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena.allocator() }) orelse {
|
||||||
|
stdx.debug.panic("Unable to find runfiles", .{});
|
||||||
|
};
|
||||||
|
|
||||||
|
const source_repo = bazel_builtin.current_repository;
|
||||||
|
const r = r_.withSourceRepo(source_repo);
|
||||||
|
|
||||||
|
var path_buf: [std.fs.max_path_bytes]u8 = undefined;
|
||||||
|
const sandbox_path = try r.rlocation("libpjrt_rocm/sandbox", &path_buf) orelse {
|
||||||
|
log.err("Failed to find sandbox path for ROCm runtime", .{});
|
||||||
|
return error.FileNotFound;
|
||||||
|
};
|
||||||
|
|
||||||
|
try setupRocmEnv(arena.allocator(), sandbox_path);
|
||||||
|
|
||||||
|
var lib_path_buf: [std.fs.max_path_bytes]u8 = undefined;
|
||||||
|
const lib_path = try stdx.fs.path.bufJoinZ(&lib_path_buf, &.{ sandbox_path, "lib", "libpjrt_rocm.so" });
|
||||||
|
|
||||||
|
// We must load the PJRT plugin from the main thread.
|
||||||
|
//
|
||||||
|
// This is because libamdhip64.so use thread local storage as part of the static destructors...
|
||||||
|
//
|
||||||
|
// This destructor accesses a thread-local variable. If the destructor is
|
||||||
|
// executed in a different thread than the one that originally called dlopen()
|
||||||
|
// on the library, the thread-local storage (TLS) offset may be resolved
|
||||||
|
// relative to the TLS base of the main thread, rather than the thread actually
|
||||||
|
// executing the destructor. Accessing this variable results in a segmentation fault...
|
||||||
|
return try pjrt.Api.loadFrom(lib_path);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user