Update Bazel workspace and runtime BUILD files to newer XLA, StableHLO, and LLVM versions, enabling batching‑dims support for the gather operator.
This commit is contained in:
parent
897786e440
commit
0606ea1d7c
@ -55,7 +55,7 @@ use_repo(zls, "zls_aarch64-macos", "zls_x86_64-linux")
|
|||||||
register_toolchains("//third_party/zls:all")
|
register_toolchains("//third_party/zls:all")
|
||||||
|
|
||||||
bazel_dep(name = "libxev", version = "20240910.0-a2d9b31")
|
bazel_dep(name = "libxev", version = "20240910.0-a2d9b31")
|
||||||
bazel_dep(name = "llvm-raw", version = "20240823.0-f142f8a")
|
bazel_dep(name = "llvm-raw", version = "20240919.0-94c024a")
|
||||||
|
|
||||||
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
||||||
llvm.configure(
|
llvm.configure(
|
||||||
@ -67,8 +67,8 @@ llvm.configure(
|
|||||||
)
|
)
|
||||||
use_repo(llvm, "llvm-project")
|
use_repo(llvm, "llvm-project")
|
||||||
|
|
||||||
bazel_dep(name = "stablehlo", version = "20240829.0-54aa1a5")
|
bazel_dep(name = "stablehlo", version = "20240917.0-78c753a")
|
||||||
bazel_dep(name = "xla", version = "20240902.0-d18cd64")
|
bazel_dep(name = "xla", version = "20240919.0-1b18dd6")
|
||||||
|
|
||||||
tsl = use_extension("@xla//:tsl.bzl", "tsl")
|
tsl = use_extension("@xla//:tsl.bzl", "tsl")
|
||||||
use_repo(tsl, "tsl")
|
use_repo(tsl, "tsl")
|
||||||
|
|||||||
@ -12,15 +12,15 @@ def _cpu_pjrt_plugin_impl(mctx):
|
|||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cpu_linux_amd64",
|
name = "libpjrt_cpu_linux_amd64",
|
||||||
build_file_content = _BUILD.format(ext = "so"),
|
build_file_content = _BUILD.format(ext = "so"),
|
||||||
sha256 = "14317143acd6a38656e97280e8010c0b8d8c0863dff2ae82834b6f2fe747427b",
|
sha256 = "2058c999a4866716f1dae0c42476c09da0f6deff7e77e34c5223b61f5e0027fb",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.1.13/pjrt-cpu_linux-amd64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.2.2/pjrt-cpu_linux-amd64.tar.gz",
|
||||||
)
|
)
|
||||||
|
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cpu_darwin_arm64",
|
name = "libpjrt_cpu_darwin_arm64",
|
||||||
build_file_content = _BUILD.format(ext = "dylib"),
|
build_file_content = _BUILD.format(ext = "dylib"),
|
||||||
sha256 = "3a26e1372f68fc11028c4ec22a0c72693f08e7690ba8c5f28b17f5baa9c9dc77",
|
sha256 = "727b0380a577b2759468a4e0b3574e1d81e1b4348c3942d23284d590c7ca91a5",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.1.13/pjrt-cpu_darwin-arm64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.2.2/pjrt-cpu_darwin-arm64.tar.gz",
|
||||||
)
|
)
|
||||||
|
|
||||||
return mctx.extension_metadata(
|
return mctx.extension_metadata(
|
||||||
|
|||||||
@ -182,8 +182,8 @@ cc_import(
|
|||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_cuda",
|
name = "libpjrt_cuda",
|
||||||
build_file = "libpjrt_cuda.BUILD.bazel",
|
build_file = "libpjrt_cuda.BUILD.bazel",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.1.13/pjrt-cuda_linux-amd64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.2.2/pjrt-cuda_linux-amd64.tar.gz",
|
||||||
sha256 = "b705f761e24d85ecd750df992a88715d9c461b7561c31722b9f878eeab32f39e",
|
sha256 = "45e91e8649bcccc43900f90d6dcbf0cfe87d3d2ee76f1763f41263d2ed44d31b",
|
||||||
)
|
)
|
||||||
|
|
||||||
return mctx.extension_metadata(
|
return mctx.extension_metadata(
|
||||||
|
|||||||
@ -227,8 +227,8 @@ def _rocm_impl(mctx):
|
|||||||
http_archive(
|
http_archive(
|
||||||
name = "libpjrt_rocm",
|
name = "libpjrt_rocm",
|
||||||
build_file = "libpjrt_rocm.BUILD.bazel",
|
build_file = "libpjrt_rocm.BUILD.bazel",
|
||||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.1.13/pjrt-rocm_linux-amd64.tar.gz",
|
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.2.2/pjrt-rocm_linux-amd64.tar.gz",
|
||||||
sha256 = "5900cec41274e80ab799bc13f31cdc87202f8e168d7e753b1c10796912f5ebef",
|
sha256 = "dcb2f8e1fd29e3d7ba8d3018d97a060888e5bcf4847a683cb11686caa6ad9fa2",
|
||||||
)
|
)
|
||||||
|
|
||||||
return mctx.extension_metadata(
|
return mctx.extension_metadata(
|
||||||
|
|||||||
10
third_party/modules/llvm-raw/20240919.0-94c024a/MODULE.bazel
vendored
Normal file
10
third_party/modules/llvm-raw/20240919.0-94c024a/MODULE.bazel
vendored
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
module(
|
||||||
|
name = "llvm-raw",
|
||||||
|
version = "20240919.0-94c024a",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||||
|
bazel_dep(name = "platforms", version = "0.0.10")
|
||||||
|
bazel_dep(name = "zstd", version = "1.5.6", repo_name = "llvm_zstd")
|
||||||
|
bazel_dep(name = "zlib", version = "1.3.1.bcr.3", repo_name = "llvm_zlib")
|
||||||
0
third_party/modules/llvm-raw/20240919.0-94c024a/overlay/BUILD.bazel
vendored
Normal file
0
third_party/modules/llvm-raw/20240919.0-94c024a/overlay/BUILD.bazel
vendored
Normal file
10
third_party/modules/llvm-raw/20240919.0-94c024a/overlay/MODULE.bazel
vendored
Normal file
10
third_party/modules/llvm-raw/20240919.0-94c024a/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
module(
|
||||||
|
name = "llvm-raw",
|
||||||
|
version = "20240919.0-94c024a",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||||
|
bazel_dep(name = "platforms", version = "0.0.10")
|
||||||
|
bazel_dep(name = "zstd", version = "1.5.6", repo_name = "llvm_zstd")
|
||||||
|
bazel_dep(name = "zlib", version = "1.3.1.bcr.3", repo_name = "llvm_zlib")
|
||||||
28
third_party/modules/llvm-raw/20240919.0-94c024a/overlay/utils/bazel/extension.bzl
vendored
Normal file
28
third_party/modules/llvm-raw/20240919.0-94c024a/overlay/utils/bazel/extension.bzl
vendored
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
load("//utils/bazel:configure.bzl", _llvm_configure = "llvm_configure")
|
||||||
|
|
||||||
|
def _llvm_impl(mctx):
|
||||||
|
_targets = {}
|
||||||
|
for mod in mctx.modules:
|
||||||
|
for conf in mod.tags.configure:
|
||||||
|
for target in conf.targets:
|
||||||
|
_targets[target] = True
|
||||||
|
_llvm_configure(
|
||||||
|
name = "llvm-project",
|
||||||
|
targets = _targets.keys(),
|
||||||
|
)
|
||||||
|
return mctx.extension_metadata(
|
||||||
|
reproducible = True,
|
||||||
|
root_module_direct_deps = "all",
|
||||||
|
root_module_direct_dev_deps = [],
|
||||||
|
)
|
||||||
|
|
||||||
|
llvm = module_extension(
|
||||||
|
implementation = _llvm_impl,
|
||||||
|
tag_classes = {
|
||||||
|
"configure": tag_class(
|
||||||
|
attrs = {
|
||||||
|
"targets": attr.string_list(mandatory = True),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
10
third_party/modules/llvm-raw/20240919.0-94c024a/source.json
vendored
Normal file
10
third_party/modules/llvm-raw/20240919.0-94c024a/source.json
vendored
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
{
|
||||||
|
"strip_prefix": "llvm-project-94c024adedcb53059c29d7c2d62982053b60e86a",
|
||||||
|
"url": "https://github.com/llvm/llvm-project/archive/94c024adedcb53059c29d7c2d62982053b60e86a.tar.gz",
|
||||||
|
"integrity": "sha256-IEzt6quG8GXvZMs4id0ukt3UqPXVtrwctLJ2aU+2p5g=",
|
||||||
|
"overlay": {
|
||||||
|
"BUILD.bazel": "",
|
||||||
|
"MODULE.bazel": "",
|
||||||
|
"utils/bazel/extension.bzl": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
3
third_party/modules/llvm-raw/metadata.json
vendored
3
third_party/modules/llvm-raw/metadata.json
vendored
@ -11,7 +11,8 @@
|
|||||||
"github:llvm/llvm-project"
|
"github:llvm/llvm-project"
|
||||||
],
|
],
|
||||||
"versions": [
|
"versions": [
|
||||||
"20240823.0-f142f8a"
|
"20240823.0-f142f8a",
|
||||||
|
"20240919.0-94c024a",
|
||||||
],
|
],
|
||||||
"yanked_versions": {}
|
"yanked_versions": {}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,7 +6,7 @@ module(
|
|||||||
|
|
||||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||||
bazel_dep(name = "llvm-raw", version = "20240823.0-f142f8a")
|
bazel_dep(name = "llvm-raw", version = "20240823.0-94c024a")
|
||||||
|
|
||||||
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
||||||
llvm.configure(
|
llvm.configure(
|
||||||
|
|||||||
15
third_party/modules/stablehlo/20240917.0-78c753a/MODULE.bazel
vendored
Normal file
15
third_party/modules/stablehlo/20240917.0-78c753a/MODULE.bazel
vendored
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
module(
|
||||||
|
name = "stablehlo",
|
||||||
|
version = "20240917.0-78c753a",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||||
|
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||||
|
bazel_dep(name = "llvm-raw", version = "20240919.0-94c024a")
|
||||||
|
|
||||||
|
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
||||||
|
llvm.configure(
|
||||||
|
targets = ["AArch64", "X86", "NVPTX"],
|
||||||
|
)
|
||||||
|
use_repo(llvm, "llvm-project")
|
||||||
15
third_party/modules/stablehlo/20240917.0-78c753a/overlay/MODULE.bazel
vendored
Normal file
15
third_party/modules/stablehlo/20240917.0-78c753a/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
module(
|
||||||
|
name = "stablehlo",
|
||||||
|
version = "20240917.0-78c753a",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||||
|
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||||
|
bazel_dep(name = "llvm-raw", version = "20240919.0-94c024a")
|
||||||
|
|
||||||
|
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
||||||
|
llvm.configure(
|
||||||
|
targets = ["AArch64", "X86", "NVPTX"],
|
||||||
|
)
|
||||||
|
use_repo(llvm, "llvm-project")
|
||||||
8
third_party/modules/stablehlo/20240917.0-78c753a/source.json
vendored
Normal file
8
third_party/modules/stablehlo/20240917.0-78c753a/source.json
vendored
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
{
|
||||||
|
"strip_prefix": "stablehlo-78c753ad13ad8205cacc5fcc12418c1ac97276c7",
|
||||||
|
"url": "https://github.com/openxla/stablehlo/archive/78c753ad13ad8205cacc5fcc12418c1ac97276c7.tar.gz",
|
||||||
|
"integrity": "sha256-2qFcPLXs7glf7RFDUiqbKXiBBamW7n8a30egR4Yu/bc=",
|
||||||
|
"overlay": {
|
||||||
|
"MODULE.bazel": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
3
third_party/modules/stablehlo/metadata.json
vendored
3
third_party/modules/stablehlo/metadata.json
vendored
@ -11,7 +11,8 @@
|
|||||||
"github:openxla/stablehlo"
|
"github:openxla/stablehlo"
|
||||||
],
|
],
|
||||||
"versions": [
|
"versions": [
|
||||||
"20240829.0-54aa1a5"
|
"20240829.0-54aa1a5",
|
||||||
|
"20240917.0-78c753a",
|
||||||
],
|
],
|
||||||
"yanked_versions": {}
|
"yanked_versions": {}
|
||||||
}
|
}
|
||||||
|
|||||||
34
third_party/modules/xla/20240919.0-1b18dd6/MODULE.bazel
vendored
Normal file
34
third_party/modules/xla/20240919.0-1b18dd6/MODULE.bazel
vendored
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
module(
|
||||||
|
name = "xla",
|
||||||
|
version = "20240919.0-1b18dd6",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "platforms", version = "0.0.8")
|
||||||
|
bazel_dep(name = "bazel_skylib", version = "1.5.0")
|
||||||
|
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||||
|
bazel_dep(name = "rules_apple", version = "3.2.1", repo_name = "build_bazel_rules_apple")
|
||||||
|
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
|
||||||
|
bazel_dep(name = "rules_python", version = "0.29.0")
|
||||||
|
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||||
|
bazel_dep(name = "rules_java", version = "7.3.2")
|
||||||
|
bazel_dep(name = "rules_pkg", version = "0.9.1")
|
||||||
|
bazel_dep(name = "zlib", version = "1.2.13")
|
||||||
|
bazel_dep(name = "re2", version = "2024-02-01", repo_name = "com_googlesource_code_re2")
|
||||||
|
bazel_dep(name = "rules_license", version = "0.0.8")
|
||||||
|
|
||||||
|
bazel_dep(name = "stablehlo", version = "20240917.0-78c753a")
|
||||||
|
|
||||||
|
tsl = use_extension("//:tsl.bzl", "tsl")
|
||||||
|
use_repo(tsl, "tsl")
|
||||||
|
|
||||||
|
xla_workspace = use_extension("//:workspace.bzl", "xla_workspace")
|
||||||
|
use_repo(
|
||||||
|
xla_workspace,
|
||||||
|
"com_github_grpc_grpc",
|
||||||
|
"com_google_protobuf",
|
||||||
|
"local_config_cuda",
|
||||||
|
"local_config_remote_execution",
|
||||||
|
"local_config_rocm",
|
||||||
|
"local_config_tensorrt",
|
||||||
|
)
|
||||||
34
third_party/modules/xla/20240919.0-1b18dd6/overlay/MODULE.bazel
vendored
Normal file
34
third_party/modules/xla/20240919.0-1b18dd6/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
module(
|
||||||
|
name = "xla",
|
||||||
|
version = "20240919.0-1b18dd6",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "platforms", version = "0.0.8")
|
||||||
|
bazel_dep(name = "bazel_skylib", version = "1.5.0")
|
||||||
|
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||||
|
bazel_dep(name = "rules_apple", version = "3.2.1", repo_name = "build_bazel_rules_apple")
|
||||||
|
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
|
||||||
|
bazel_dep(name = "rules_python", version = "0.29.0")
|
||||||
|
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||||
|
bazel_dep(name = "rules_java", version = "7.3.2")
|
||||||
|
bazel_dep(name = "rules_pkg", version = "0.9.1")
|
||||||
|
bazel_dep(name = "zlib", version = "1.2.13")
|
||||||
|
bazel_dep(name = "re2", version = "2024-02-01", repo_name = "com_googlesource_code_re2")
|
||||||
|
bazel_dep(name = "rules_license", version = "0.0.8")
|
||||||
|
|
||||||
|
bazel_dep(name = "stablehlo", version = "20240917.0-78c753a")
|
||||||
|
|
||||||
|
tsl = use_extension("//:tsl.bzl", "tsl")
|
||||||
|
use_repo(tsl, "tsl")
|
||||||
|
|
||||||
|
xla_workspace = use_extension("//:workspace.bzl", "xla_workspace")
|
||||||
|
use_repo(
|
||||||
|
xla_workspace,
|
||||||
|
"com_github_grpc_grpc",
|
||||||
|
"com_google_protobuf",
|
||||||
|
"local_config_cuda",
|
||||||
|
"local_config_remote_execution",
|
||||||
|
"local_config_rocm",
|
||||||
|
"local_config_tensorrt",
|
||||||
|
)
|
||||||
13
third_party/modules/xla/20240919.0-1b18dd6/overlay/tsl.bzl
vendored
Normal file
13
third_party/modules/xla/20240919.0-1b18dd6/overlay/tsl.bzl
vendored
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
load("//third_party:repo.bzl", "tf_vendored")
|
||||||
|
|
||||||
|
def _tsl_impl(mctx):
|
||||||
|
tf_vendored(name = "tsl", relpath = "third_party/tsl")
|
||||||
|
return mctx.extension_metadata(
|
||||||
|
reproducible = True,
|
||||||
|
root_module_direct_deps = "all",
|
||||||
|
root_module_direct_dev_deps = [],
|
||||||
|
)
|
||||||
|
|
||||||
|
tsl = module_extension(
|
||||||
|
implementation = _tsl_impl,
|
||||||
|
)
|
||||||
52
third_party/modules/xla/20240919.0-1b18dd6/overlay/workspace.bzl
vendored
Normal file
52
third_party/modules/xla/20240919.0-1b18dd6/overlay/workspace.bzl
vendored
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
load("@tsl//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
|
||||||
|
load("@tsl//third_party/gpus:cuda_configure.bzl", "cuda_configure")
|
||||||
|
load("@tsl//third_party/gpus:rocm_configure.bzl", "rocm_configure")
|
||||||
|
load("@tsl//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
|
||||||
|
load("@tsl//tools/toolchains/remote:configure.bzl", "remote_execution_configure")
|
||||||
|
|
||||||
|
def _xla_workspace_impl(mctx):
|
||||||
|
cuda_configure(name = "local_config_cuda")
|
||||||
|
remote_execution_configure(name = "local_config_remote_execution")
|
||||||
|
rocm_configure(name = "local_config_rocm")
|
||||||
|
tensorrt_configure(name = "local_config_tensorrt")
|
||||||
|
tf_http_archive(
|
||||||
|
name = "com_github_grpc_grpc",
|
||||||
|
sha256 = "b956598d8cbe168b5ee717b5dafa56563eb5201a947856a6688bbeac9cac4e1f",
|
||||||
|
strip_prefix = "grpc-b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd",
|
||||||
|
system_build_file = "@tsl//third_party/systemlibs:grpc.BUILD",
|
||||||
|
patch_file = [
|
||||||
|
"@tsl//third_party/grpc:generate_cc_env_fix.patch",
|
||||||
|
"@tsl//third_party/grpc:register_go_toolchain.patch",
|
||||||
|
],
|
||||||
|
system_link_files = {
|
||||||
|
"@tsl//third_party/systemlibs:BUILD": "bazel/BUILD",
|
||||||
|
"@tsl//third_party/systemlibs:grpc.BUILD": "src/compiler/BUILD",
|
||||||
|
"@tsl//third_party/systemlibs:grpc.bazel.grpc_deps.bzl": "bazel/grpc_deps.bzl",
|
||||||
|
"@tsl//third_party/systemlibs:grpc.bazel.grpc_extra_deps.bzl": "bazel/grpc_extra_deps.bzl",
|
||||||
|
"@tsl//third_party/systemlibs:grpc.bazel.cc_grpc_library.bzl": "bazel/cc_grpc_library.bzl",
|
||||||
|
"@tsl//third_party/systemlibs:grpc.bazel.generate_cc.bzl": "bazel/generate_cc.bzl",
|
||||||
|
"@tsl//third_party/systemlibs:grpc.bazel.protobuf.bzl": "bazel/protobuf.bzl",
|
||||||
|
},
|
||||||
|
urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz"),
|
||||||
|
)
|
||||||
|
tf_http_archive(
|
||||||
|
name = "com_google_protobuf",
|
||||||
|
patch_file = ["@tsl//third_party/protobuf:protobuf.patch"],
|
||||||
|
sha256 = "f66073dee0bc159157b0bd7f502d7d1ee0bc76b3c1eac9836927511bdc4b3fc1",
|
||||||
|
strip_prefix = "protobuf-3.21.9",
|
||||||
|
system_build_file = "@tsl//third_party/systemlibs:protobuf.BUILD",
|
||||||
|
system_link_files = {
|
||||||
|
"@tsl//third_party/systemlibs:protobuf.bzl": "protobuf.bzl",
|
||||||
|
"@tsl//third_party/systemlibs:protobuf_deps.bzl": "protobuf_deps.bzl",
|
||||||
|
},
|
||||||
|
urls = tf_mirror_urls("https://github.com/protocolbuffers/protobuf/archive/v3.21.9.zip"),
|
||||||
|
)
|
||||||
|
return mctx.extension_metadata(
|
||||||
|
reproducible = True,
|
||||||
|
root_module_direct_deps = "all",
|
||||||
|
root_module_direct_dev_deps = [],
|
||||||
|
)
|
||||||
|
|
||||||
|
xla_workspace = module_extension(
|
||||||
|
implementation = _xla_workspace_impl,
|
||||||
|
)
|
||||||
@ -0,0 +1,27 @@
|
|||||||
|
From 4db5de34f70d991fedbe28915c8239b97ba7a064 Mon Sep 17 00:00:00 2001
|
||||||
|
From: Steeve Morin <steeve.morin@gmail.com>
|
||||||
|
Date: Mon, 18 Mar 2024 17:17:34 +0100
|
||||||
|
Subject: [PATCH 3/3] [PJRT C API] Ensure C compliance for Profiler Extension
|
||||||
|
|
||||||
|
---
|
||||||
|
xla/pjrt/c/pjrt_c_api_profiler_extension.h | 2 ++
|
||||||
|
1 file changed, 2 insertions(+)
|
||||||
|
|
||||||
|
diff --git a/xla/pjrt/c/pjrt_c_api_profiler_extension.h b/xla/pjrt/c/pjrt_c_api_profiler_extension.h
|
||||||
|
index c821916ad..89a596123 100644
|
||||||
|
--- a/xla/pjrt/c/pjrt_c_api_profiler_extension.h
|
||||||
|
+++ b/xla/pjrt/c/pjrt_c_api_profiler_extension.h
|
||||||
|
@@ -16,8 +16,10 @@ limitations under the License.
|
||||||
|
#ifndef XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_
|
||||||
|
#define XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_
|
||||||
|
|
||||||
|
+#ifdef __cplusplus
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
|
+#endif
|
||||||
|
|
||||||
|
#include "xla/backends/profiler/plugin/profiler_c_api.h"
|
||||||
|
#include "xla/pjrt/c/pjrt_c_api.h"
|
||||||
|
--
|
||||||
|
2.39.3 (Apple Git-146)
|
||||||
|
|
||||||
14
third_party/modules/xla/20240919.0-1b18dd6/source.json
vendored
Normal file
14
third_party/modules/xla/20240919.0-1b18dd6/source.json
vendored
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
{
|
||||||
|
"strip_prefix": "xla-d391119197eab771a84c1f8a59a7f50b7da4b43d",
|
||||||
|
"url": "https://github.com/openxla/xla/archive/d391119197eab771a84c1f8a59a7f50b7da4b43d.tar.gz",
|
||||||
|
"integrity": "sha256-ToPcvPhoGX6Ny3yLx6kJlQSwj7k5Xy4T+BjIp3DAj/s=",
|
||||||
|
"overlay": {
|
||||||
|
"tsl.bzl": "",
|
||||||
|
"workspace.bzl": "",
|
||||||
|
"MODULE.bazel": ""
|
||||||
|
},
|
||||||
|
"patch_strip": 1,
|
||||||
|
"patches": {
|
||||||
|
"0003-PJRT-C-API-Ensure-C-compliance-for-Profiler-Extensio.patch": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
3
third_party/modules/xla/metadata.json
vendored
3
third_party/modules/xla/metadata.json
vendored
@ -11,7 +11,8 @@
|
|||||||
"github:openxla/xla"
|
"github:openxla/xla"
|
||||||
],
|
],
|
||||||
"versions": [
|
"versions": [
|
||||||
"20240902.0-d18cd64"
|
"20240902.0-d18cd64",
|
||||||
|
"20240919.0-1b18dd6",
|
||||||
],
|
],
|
||||||
"yanked_versions": {}
|
"yanked_versions": {}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2052,7 +2052,6 @@ pub const Tensor = struct {
|
|||||||
.{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8 } },
|
.{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8 } },
|
||||||
.{ .{ .a = 10, .b = 20, .c = 30 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8, .c = 30 } },
|
.{ .{ .a = 10, .b = 20, .c = 30 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8, .c = 30 } },
|
||||||
// batching axes are implicits.
|
// batching axes are implicits.
|
||||||
// TODO: batched gather don't compile https://github.com/zml/zml/issues/400
|
|
||||||
.{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10 }, .{ .a = 10 } },
|
.{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10 }, .{ .a = 10 } },
|
||||||
.{ .{ .a = 10, .b = 20 }, .a, .{ .b = 20 }, .{ .b = 20 } },
|
.{ .{ .a = 10, .b = 20 }, .a, .{ .b = 20 }, .{ .b = 20 } },
|
||||||
.{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8 } },
|
.{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8 } },
|
||||||
@ -2206,20 +2205,13 @@ pub const Tensor = struct {
|
|||||||
try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape());
|
try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape());
|
||||||
try std.testing.expect(y.value().owner().verify());
|
try std.testing.expect(y.value().owner().verify());
|
||||||
|
|
||||||
// The batching dims test case doesn't pass.
|
const mod = try zml.compileFn(
|
||||||
// The weird part is that the MLIR seems valid, but pjrt doesn't accept it.
|
std.testing.allocator,
|
||||||
// TODO: https://github.com/zml/zml/issues/400
|
gatherSlices,
|
||||||
const mod = zml.compileFn(std.testing.allocator, gatherSlices, .{ x.shape(), slice_shape, idx.shape(), .{ .indices_are_sorted = true } }, platform);
|
.{ x.shape(), slice_shape, idx.shape(), .{ .indices_are_sorted = true } },
|
||||||
|
platform,
|
||||||
if (mod) |m| {
|
);
|
||||||
m.deinit();
|
defer mod.deinit();
|
||||||
} else |err| {
|
|
||||||
if (@hasField(@TypeOf(idx_shape), "a")) {
|
|
||||||
scoped_log.warn("Skipping compilation test of gather with batching dims: https://github.com/zml/zml/issues/400", .{});
|
|
||||||
} else {
|
|
||||||
return err;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user