Radix/third_party/xla/xla.bzl

86 lines
2.8 KiB
Python
Raw Normal View History

load("@xla//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls", "tf_vendored")
load("@xla//third_party/llvm:workspace.bzl", llvm = "repo")
load("@xla//third_party/stablehlo:workspace.bzl", stablehlo = "repo")
load("@xla//third_party/triton:workspace.bzl", triton = "repo")
# Starlark source snippet that is prepended to the generated `build_defs.bzl`
# files of the stub repositories declared in `_dummy_repos`. It defines four
# constant-behaviour lambdas that the stub files alias to the public helper
# names:
#   always_newer_than(wanted_ver, if_true, if_false = []) -> if_true
#   always_if_true(if_true, if_false = [])                -> if_true
#   always_if_false(if_true, if_false = [])               -> if_false
#   always_false(*args, **kwargs)                         -> False
# NOTE(review): `always_newer_than` ignores `wanted_ver` and yields `if_true`,
# while the other stubs that use it report the toolchain as not configured —
# presumably intentional; confirm against the callers of `if_*_newer_than`.
_BZL_HELPERS = """\
always_newer_than = lambda wanted_ver, if_true, if_false = []: if_true
always_if_true = lambda if_true, if_false = []: if_true
always_if_false = lambda if_true, if_false = []: if_false
always_false = lambda *args, **kwargs: False
"""
def _simple_files_impl(rctx):
    """Materializes a repository made only of literal text files.

    Writes an empty `BUILD.bazel` at the repository root first, then one file
    per entry of the `files` attribute. Because the root `BUILD.bazel` is
    written first, an entry keyed `BUILD.bazel` in `files` is written after it.

    Args:
        rctx: the repository context; its `files` attribute maps a path
            relative to the repository root to the text to store there.
    """
    rctx.file("BUILD.bazel", "")
    for rel_path, text in rctx.attr.files.items():
        rctx.file(rel_path, text)
# Repository rule that creates a repository from an inline
# {relative path: file content} dictionary (see `_simple_files_impl`).
simple_files = repository_rule(
    attrs = {
        "files": attr.string_dict(),
    },
    implementation = _simple_files_impl,
)
def _dummy_repos(mctx):
    """Declares the stub `simple_files` repositories used by this extension.

    Each repository holds small generated `.bzl` / `BUILD.bazel` files whose
    symbols are aliases of the constant helpers in `_BZL_HELPERS` or literal
    constants. Repositories are declared in the order they appear in the
    table below.

    Args:
        mctx: the module extension context; not read by this function.
    """
    cuda_build_defs = _BZL_HELPERS + """\
cuda_library = always_false
if_cuda = always_if_false
if_cuda_is_configured = always_if_false
if_cuda_newer_than = always_newer_than
is_cuda_configured = always_false
"""
    rocm_build_defs = _BZL_HELPERS + """\
if_rocm = always_if_false
if_rocm_is_configured = always_if_false
if_rocm_newer_than = always_newer_than
is_rocm_configured = always_false
if_gpu_is_configured = always_if_false
if_cuda_or_rocm = always_if_false
"""

    # Repository name -> {path relative to the repository root: file content}.
    stub_repos = {
        "local_config_cuda": {
            "cuda/BUILD.bazel": "",
            "cuda/build_defs.bzl": cuda_build_defs,
        },
        "local_config_rocm": {
            "rocm/BUILD.bazel": "",
            "rocm/build_defs.bzl": rocm_build_defs,
        },
        "local_config_remote_execution": {
            "remote_execution.bzl": """gpu_test_tags = lambda: []""",
        },
        "local_config_tensorrt": {
            "build_defs.bzl": _BZL_HELPERS + """if_tensorrt = always_if_false""",
        },
        "python_version_repo": {
            "py_version.bzl": """USE_PYWRAP_RULES = False""",
        },
        "rules_ml_toolchain": {
            "third_party/gpus/BUILD.bazel": "",
            "third_party/gpus/nvidia_common_rules.bzl": """cuda_rpath_flags = lambda *args, **kwargs: []""",
        },
    }
    for repo_name, repo_files in stub_repos.items():
        simple_files(name = repo_name, files = repo_files)
def _xla_impl(mctx):
    """Implementation of the `xla` module extension.

    In order: calls the loaded `llvm` (with the name "llvm-raw"), `stablehlo`
    and `triton` repo functions, declares the patched gRPC 1.74.0 archive,
    declares the vendored `tsl` repository, and declares the stub
    repositories from `_dummy_repos`.

    Args:
        mctx: the module extension context.

    Returns:
        Extension metadata marking the extension reproducible, with all
        generated repositories reported as direct dependencies of the root
        module and none as dev dependencies.
    """
    llvm("llvm-raw")
    stablehlo()
    triton()

    grpc_urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/refs/tags/v1.74.0.tar.gz")
    tf_http_archive(
        name = "com_github_grpc_grpc",
        urls = grpc_urls,
        sha256 = "dd6a2fa311ba8441bbefd2764c55b99136ff10f7ea42954be96006a2723d33fc",
        strip_prefix = "grpc-1.74.0",
        patch_file = ["//third_party/grpc:grpc.patch"],
    )

    tf_vendored(name = "tsl", relpath = "third_party/tsl")
    _dummy_repos(mctx)

    metadata = mctx.extension_metadata(
        reproducible = True,
        root_module_direct_deps = "all",
        root_module_direct_dev_deps = [],
    )
    return metadata
# Module extension entry point; all repositories are declared by `_xla_impl`.
xla = module_extension(implementation = _xla_impl)