From 0606ea1d7caaed59bd63c1820b1f3adf9aeffab5 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Wed, 1 Feb 2023 15:58:30 +0000 Subject: [PATCH] =?UTF-8?q?Update=20Bazel=20workspace=20and=20runtime=20BU?= =?UTF-8?q?ILD=20files=20to=20newer=20XLA,=20StableHLO,=20and=20LLVM=20ver?= =?UTF-8?q?sions,=20enabling=20batching=E2=80=91dims=20support=20for=20the?= =?UTF-8?q?=20gather=20operator.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- MODULE.bazel | 6 +-- runtimes/cpu/cpu.bzl | 8 +-- runtimes/cuda/cuda.bzl | 4 +- runtimes/rocm/rocm.bzl | 4 +- .../llvm-raw/20240919.0-94c024a/MODULE.bazel | 10 ++++ .../20240919.0-94c024a/overlay/BUILD.bazel | 0 .../20240919.0-94c024a/overlay/MODULE.bazel | 10 ++++ .../overlay/utils/bazel/extension.bzl | 28 ++++++++++ .../llvm-raw/20240919.0-94c024a/source.json | 10 ++++ third_party/modules/llvm-raw/metadata.json | 3 +- .../stablehlo/20240829.0-54aa1a5/MODULE.bazel | 2 +- .../stablehlo/20240917.0-78c753a/MODULE.bazel | 15 ++++++ .../20240917.0-78c753a/overlay/MODULE.bazel | 15 ++++++ .../stablehlo/20240917.0-78c753a/source.json | 8 +++ third_party/modules/stablehlo/metadata.json | 3 +- .../xla/20240919.0-1b18dd6/MODULE.bazel | 34 ++++++++++++ .../20240919.0-1b18dd6/overlay/MODULE.bazel | 34 ++++++++++++ .../xla/20240919.0-1b18dd6/overlay/tsl.bzl | 13 +++++ .../20240919.0-1b18dd6/overlay/workspace.bzl | 52 +++++++++++++++++++ ...e-C-compliance-for-Profiler-Extensio.patch | 27 ++++++++++ .../xla/20240919.0-1b18dd6/source.json | 14 +++++ third_party/modules/xla/metadata.json | 3 +- zml/tensor.zig | 22 +++----- 23 files changed, 295 insertions(+), 30 deletions(-) create mode 100644 third_party/modules/llvm-raw/20240919.0-94c024a/MODULE.bazel create mode 100644 third_party/modules/llvm-raw/20240919.0-94c024a/overlay/BUILD.bazel create mode 100644 third_party/modules/llvm-raw/20240919.0-94c024a/overlay/MODULE.bazel create mode 100644 third_party/modules/llvm-raw/20240919.0-94c024a/overlay/utils/bazel/extension.bzl create mode 100644 third_party/modules/llvm-raw/20240919.0-94c024a/source.json create mode 100644 third_party/modules/stablehlo/20240917.0-78c753a/MODULE.bazel create mode 100644 third_party/modules/stablehlo/20240917.0-78c753a/overlay/MODULE.bazel create mode 100644 third_party/modules/stablehlo/20240917.0-78c753a/source.json create mode 100644 third_party/modules/xla/20240919.0-1b18dd6/MODULE.bazel create mode 100644 third_party/modules/xla/20240919.0-1b18dd6/overlay/MODULE.bazel create mode 100644 third_party/modules/xla/20240919.0-1b18dd6/overlay/tsl.bzl create mode 100644 third_party/modules/xla/20240919.0-1b18dd6/overlay/workspace.bzl create mode 100644 third_party/modules/xla/20240919.0-1b18dd6/patches/0003-PJRT-C-API-Ensure-C-compliance-for-Profiler-Extensio.patch create mode 100644 third_party/modules/xla/20240919.0-1b18dd6/source.json diff --git a/MODULE.bazel b/MODULE.bazel index f09b01b..e9a8b7c 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -55,7 +55,7 @@ use_repo(zls, "zls_aarch64-macos", "zls_x86_64-linux") register_toolchains("//third_party/zls:all") bazel_dep(name = "libxev", version = "20240910.0-a2d9b31") -bazel_dep(name = "llvm-raw", version = "20240823.0-f142f8a") +bazel_dep(name = "llvm-raw", version = "20240919.0-94c024a") llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm") llvm.configure( @@ -67,8 +67,8 @@ llvm.configure( ) use_repo(llvm, "llvm-project") -bazel_dep(name = "stablehlo", version = "20240829.0-54aa1a5") -bazel_dep(name = "xla", version = "20240902.0-d18cd64") +bazel_dep(name = "stablehlo", version = "20240917.0-78c753a") +bazel_dep(name = "xla", version = "20240919.0-1b18dd6") tsl = use_extension("@xla//:tsl.bzl", "tsl") use_repo(tsl, "tsl") diff --git a/runtimes/cpu/cpu.bzl b/runtimes/cpu/cpu.bzl index b358538..8cd1599 100644 --- a/runtimes/cpu/cpu.bzl +++ b/runtimes/cpu/cpu.bzl @@ -12,15 +12,15 @@ def _cpu_pjrt_plugin_impl(mctx): http_archive( name = "libpjrt_cpu_linux_amd64", build_file_content = _BUILD.format(ext = "so"), - sha256 = "14317143acd6a38656e97280e8010c0b8d8c0863dff2ae82834b6f2fe747427b", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.1.13/pjrt-cpu_linux-amd64.tar.gz", + sha256 = "2058c999a4866716f1dae0c42476c09da0f6deff7e77e34c5223b61f5e0027fb", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.2.2/pjrt-cpu_linux-amd64.tar.gz", ) http_archive( name = "libpjrt_cpu_darwin_arm64", build_file_content = _BUILD.format(ext = "dylib"), - sha256 = "3a26e1372f68fc11028c4ec22a0c72693f08e7690ba8c5f28b17f5baa9c9dc77", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.1.13/pjrt-cpu_darwin-arm64.tar.gz", + sha256 = "727b0380a577b2759468a4e0b3574e1d81e1b4348c3942d23284d590c7ca91a5", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.2.2/pjrt-cpu_darwin-arm64.tar.gz", ) return mctx.extension_metadata( diff --git a/runtimes/cuda/cuda.bzl b/runtimes/cuda/cuda.bzl index 2f3acba..f981e54 100644 --- a/runtimes/cuda/cuda.bzl +++ b/runtimes/cuda/cuda.bzl @@ -182,8 +182,8 @@ cc_import( http_archive( name = "libpjrt_cuda", build_file = "libpjrt_cuda.BUILD.bazel", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.1.13/pjrt-cuda_linux-amd64.tar.gz", - sha256 = "b705f761e24d85ecd750df992a88715d9c461b7561c31722b9f878eeab32f39e", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.2.2/pjrt-cuda_linux-amd64.tar.gz", + sha256 = "45e91e8649bcccc43900f90d6dcbf0cfe87d3d2ee76f1763f41263d2ed44d31b", ) return mctx.extension_metadata( diff --git a/runtimes/rocm/rocm.bzl b/runtimes/rocm/rocm.bzl index 603095b..d23f574 100644 --- a/runtimes/rocm/rocm.bzl +++ b/runtimes/rocm/rocm.bzl @@ -227,8 +227,8 @@ def _rocm_impl(mctx): http_archive( name = "libpjrt_rocm", build_file = "libpjrt_rocm.BUILD.bazel", - url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.1.13/pjrt-rocm_linux-amd64.tar.gz", - sha256 = "5900cec41274e80ab799bc13f31cdc87202f8e168d7e753b1c10796912f5ebef", + url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.2.2/pjrt-rocm_linux-amd64.tar.gz", + sha256 = "dcb2f8e1fd29e3d7ba8d3018d97a060888e5bcf4847a683cb11686caa6ad9fa2", ) return mctx.extension_metadata( diff --git a/third_party/modules/llvm-raw/20240919.0-94c024a/MODULE.bazel b/third_party/modules/llvm-raw/20240919.0-94c024a/MODULE.bazel new file mode 100644 index 0000000..c9130b8 --- /dev/null +++ b/third_party/modules/llvm-raw/20240919.0-94c024a/MODULE.bazel @@ -0,0 +1,10 @@ +module( + name = "llvm-raw", + version = "20240919.0-94c024a", + compatibility_level = 1, +) + +bazel_dep(name = "bazel_skylib", version = "1.7.1") +bazel_dep(name = "platforms", version = "0.0.10") +bazel_dep(name = "zstd", version = "1.5.6", repo_name = "llvm_zstd") +bazel_dep(name = "zlib", version = "1.3.1.bcr.3", repo_name = "llvm_zlib") diff --git a/third_party/modules/llvm-raw/20240919.0-94c024a/overlay/BUILD.bazel b/third_party/modules/llvm-raw/20240919.0-94c024a/overlay/BUILD.bazel new file mode 100644 index 0000000..e69de29 diff --git a/third_party/modules/llvm-raw/20240919.0-94c024a/overlay/MODULE.bazel b/third_party/modules/llvm-raw/20240919.0-94c024a/overlay/MODULE.bazel new file mode 100644 index 0000000..c9130b8 --- /dev/null +++ b/third_party/modules/llvm-raw/20240919.0-94c024a/overlay/MODULE.bazel @@ -0,0 +1,10 @@ +module( + name = "llvm-raw", + version = "20240919.0-94c024a", + compatibility_level = 1, +) + +bazel_dep(name = "bazel_skylib", version = "1.7.1") +bazel_dep(name = "platforms", version = "0.0.10") +bazel_dep(name = "zstd", version = "1.5.6", repo_name = "llvm_zstd") +bazel_dep(name = "zlib", version = "1.3.1.bcr.3", repo_name = "llvm_zlib") diff --git a/third_party/modules/llvm-raw/20240919.0-94c024a/overlay/utils/bazel/extension.bzl b/third_party/modules/llvm-raw/20240919.0-94c024a/overlay/utils/bazel/extension.bzl new file mode 100644 index 0000000..f247ed4 --- /dev/null +++ b/third_party/modules/llvm-raw/20240919.0-94c024a/overlay/utils/bazel/extension.bzl @@ -0,0 +1,28 @@ +load("//utils/bazel:configure.bzl", _llvm_configure = "llvm_configure") + +def _llvm_impl(mctx): + _targets = {} + for mod in mctx.modules: + for conf in mod.tags.configure: + for target in conf.targets: + _targets[target] = True + _llvm_configure( + name = "llvm-project", + targets = _targets.keys(), + ) + return mctx.extension_metadata( + reproducible = True, + root_module_direct_deps = "all", + root_module_direct_dev_deps = [], + ) + +llvm = module_extension( + implementation = _llvm_impl, + tag_classes = { + "configure": tag_class( + attrs = { + "targets": attr.string_list(mandatory = True), + }, + ), + }, +) diff --git a/third_party/modules/llvm-raw/20240919.0-94c024a/source.json b/third_party/modules/llvm-raw/20240919.0-94c024a/source.json new file mode 100644 index 0000000..2aec201 --- /dev/null +++ b/third_party/modules/llvm-raw/20240919.0-94c024a/source.json @@ -0,0 +1,10 @@ +{ + "strip_prefix": "llvm-project-94c024adedcb53059c29d7c2d62982053b60e86a", + "url": "https://github.com/llvm/llvm-project/archive/94c024adedcb53059c29d7c2d62982053b60e86a.tar.gz", + "integrity": "sha256-IEzt6quG8GXvZMs4id0ukt3UqPXVtrwctLJ2aU+2p5g=", + "overlay": { + "BUILD.bazel": "", + "MODULE.bazel": "", + "utils/bazel/extension.bzl": "" + } +} diff --git a/third_party/modules/llvm-raw/metadata.json b/third_party/modules/llvm-raw/metadata.json index e32ca27..077104b 100644 --- a/third_party/modules/llvm-raw/metadata.json +++ b/third_party/modules/llvm-raw/metadata.json @@ -11,7 +11,8 @@ "github:llvm/llvm-project" ], "versions": [ - "20240823.0-f142f8a" + "20240823.0-f142f8a", + "20240919.0-94c024a", ], "yanked_versions": {} } diff --git a/third_party/modules/stablehlo/20240829.0-54aa1a5/MODULE.bazel b/third_party/modules/stablehlo/20240829.0-54aa1a5/MODULE.bazel index 14386c0..29e45e0 100644 --- a/third_party/modules/stablehlo/20240829.0-54aa1a5/MODULE.bazel +++ b/third_party/modules/stablehlo/20240829.0-54aa1a5/MODULE.bazel @@ -6,7 +6,7 @@ module( bazel_dep(name = "bazel_skylib", version = "1.7.1") bazel_dep(name = "rules_cc", version = "0.0.9") -bazel_dep(name = "llvm-raw", version = "20240823.0-f142f8a") +bazel_dep(name = "llvm-raw", version = "20240823.0-94c024a") llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm") llvm.configure( diff --git a/third_party/modules/stablehlo/20240917.0-78c753a/MODULE.bazel b/third_party/modules/stablehlo/20240917.0-78c753a/MODULE.bazel new file mode 100644 index 0000000..7cbfb93 --- /dev/null +++ b/third_party/modules/stablehlo/20240917.0-78c753a/MODULE.bazel @@ -0,0 +1,15 @@ +module( + name = "stablehlo", + version = "20240917.0-78c753a", + compatibility_level = 1, +) + +bazel_dep(name = "bazel_skylib", version = "1.7.1") +bazel_dep(name = "rules_cc", version = "0.0.9") +bazel_dep(name = "llvm-raw", version = "20240919.0-94c024a") + +llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm") +llvm.configure( + targets = ["AArch64", "X86", "NVPTX"], +) +use_repo(llvm, "llvm-project") diff --git a/third_party/modules/stablehlo/20240917.0-78c753a/overlay/MODULE.bazel b/third_party/modules/stablehlo/20240917.0-78c753a/overlay/MODULE.bazel new file mode 100644 index 0000000..7cbfb93 --- /dev/null +++ b/third_party/modules/stablehlo/20240917.0-78c753a/overlay/MODULE.bazel @@ -0,0 +1,15 @@ +module( + name = "stablehlo", + version = "20240917.0-78c753a", + compatibility_level = 1, +) + +bazel_dep(name = "bazel_skylib", version = "1.7.1") +bazel_dep(name = "rules_cc", version = "0.0.9") +bazel_dep(name = "llvm-raw", version = "20240919.0-94c024a") + +llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm") +llvm.configure( + targets = ["AArch64", "X86", "NVPTX"], +) +use_repo(llvm, "llvm-project") diff --git a/third_party/modules/stablehlo/20240917.0-78c753a/source.json b/third_party/modules/stablehlo/20240917.0-78c753a/source.json new file mode 100644 index 0000000..cd677bd --- /dev/null +++ b/third_party/modules/stablehlo/20240917.0-78c753a/source.json @@ -0,0 +1,8 @@ +{ + "strip_prefix": "stablehlo-78c753ad13ad8205cacc5fcc12418c1ac97276c7", + "url": "https://github.com/openxla/stablehlo/archive/78c753ad13ad8205cacc5fcc12418c1ac97276c7.tar.gz", + "integrity": "sha256-2qFcPLXs7glf7RFDUiqbKXiBBamW7n8a30egR4Yu/bc=", + "overlay": { + "MODULE.bazel": "" + } +} diff --git a/third_party/modules/stablehlo/metadata.json b/third_party/modules/stablehlo/metadata.json index 7e0964b..1fb4617 100644 --- a/third_party/modules/stablehlo/metadata.json +++ b/third_party/modules/stablehlo/metadata.json @@ -11,7 +11,8 @@ "github:openxla/stablehlo" ], "versions": [ - "20240829.0-54aa1a5" + "20240829.0-54aa1a5", + "20240917.0-78c753a", ], "yanked_versions": {} } diff --git a/third_party/modules/xla/20240919.0-1b18dd6/MODULE.bazel b/third_party/modules/xla/20240919.0-1b18dd6/MODULE.bazel new file mode 100644 index 0000000..bd3a004 --- /dev/null +++ b/third_party/modules/xla/20240919.0-1b18dd6/MODULE.bazel @@ -0,0 +1,34 @@ +module( + name = "xla", + version = "20240919.0-1b18dd6", + compatibility_level = 1, +) + +bazel_dep(name = "platforms", version = "0.0.8") +bazel_dep(name = "bazel_skylib", version = "1.5.0") +bazel_dep(name = "rules_cc", version = "0.0.9") +bazel_dep(name = "rules_apple", version = "3.2.1", repo_name = "build_bazel_rules_apple") +bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl") +bazel_dep(name = "rules_python", version = "0.29.0") +bazel_dep(name = "rules_proto", version = "6.0.0-rc1") +bazel_dep(name = "rules_java", version = "7.3.2") +bazel_dep(name = "rules_pkg", version = "0.9.1") +bazel_dep(name = "zlib", version = "1.2.13") +bazel_dep(name = "re2", version = "2024-02-01", repo_name = "com_googlesource_code_re2") +bazel_dep(name = "rules_license", version = "0.0.8") + +bazel_dep(name = "stablehlo", version = "20240917.0-78c753a") + +tsl = use_extension("//:tsl.bzl", "tsl") +use_repo(tsl, "tsl") + +xla_workspace = use_extension("//:workspace.bzl", "xla_workspace") +use_repo( + xla_workspace, + "com_github_grpc_grpc", + "com_google_protobuf", + "local_config_cuda", + "local_config_remote_execution", + "local_config_rocm", + "local_config_tensorrt", +) diff --git a/third_party/modules/xla/20240919.0-1b18dd6/overlay/MODULE.bazel b/third_party/modules/xla/20240919.0-1b18dd6/overlay/MODULE.bazel new file mode 100644 index 0000000..bd3a004 --- /dev/null +++ b/third_party/modules/xla/20240919.0-1b18dd6/overlay/MODULE.bazel @@ -0,0 +1,34 @@ +module( + name = "xla", + version = "20240919.0-1b18dd6", + compatibility_level = 1, +) + +bazel_dep(name = "platforms", version = "0.0.8") +bazel_dep(name = "bazel_skylib", version = "1.5.0") +bazel_dep(name = "rules_cc", version = "0.0.9") +bazel_dep(name = "rules_apple", version = "3.2.1", repo_name = "build_bazel_rules_apple") +bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl") +bazel_dep(name = "rules_python", version = "0.29.0") +bazel_dep(name = "rules_proto", version = "6.0.0-rc1") +bazel_dep(name = "rules_java", version = "7.3.2") +bazel_dep(name = "rules_pkg", version = "0.9.1") +bazel_dep(name = "zlib", version = "1.2.13") +bazel_dep(name = "re2", version = "2024-02-01", repo_name = "com_googlesource_code_re2") +bazel_dep(name = "rules_license", version = "0.0.8") + +bazel_dep(name = "stablehlo", version = "20240917.0-78c753a") + +tsl = use_extension("//:tsl.bzl", "tsl") +use_repo(tsl, "tsl") + +xla_workspace = use_extension("//:workspace.bzl", "xla_workspace") +use_repo( + xla_workspace, + "com_github_grpc_grpc", + "com_google_protobuf", + "local_config_cuda", + "local_config_remote_execution", + "local_config_rocm", + "local_config_tensorrt", +) diff --git a/third_party/modules/xla/20240919.0-1b18dd6/overlay/tsl.bzl b/third_party/modules/xla/20240919.0-1b18dd6/overlay/tsl.bzl new file mode 100644 index 0000000..50c1df1 --- /dev/null +++ b/third_party/modules/xla/20240919.0-1b18dd6/overlay/tsl.bzl @@ -0,0 +1,13 @@ +load("//third_party:repo.bzl", "tf_vendored") + +def _tsl_impl(mctx): + tf_vendored(name = "tsl", relpath = "third_party/tsl") + return mctx.extension_metadata( + reproducible = True, + root_module_direct_deps = "all", + root_module_direct_dev_deps = [], + ) + +tsl = module_extension( + implementation = _tsl_impl, +) diff --git a/third_party/modules/xla/20240919.0-1b18dd6/overlay/workspace.bzl b/third_party/modules/xla/20240919.0-1b18dd6/overlay/workspace.bzl new file mode 100644 index 0000000..87c035b --- /dev/null +++ b/third_party/modules/xla/20240919.0-1b18dd6/overlay/workspace.bzl @@ -0,0 +1,52 @@ +load("@tsl//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") +load("@tsl//third_party/gpus:cuda_configure.bzl", "cuda_configure") +load("@tsl//third_party/gpus:rocm_configure.bzl", "rocm_configure") +load("@tsl//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure") +load("@tsl//tools/toolchains/remote:configure.bzl", "remote_execution_configure") + +def _xla_workspace_impl(mctx): + cuda_configure(name = "local_config_cuda") + remote_execution_configure(name = "local_config_remote_execution") + rocm_configure(name = "local_config_rocm") + tensorrt_configure(name = "local_config_tensorrt") + tf_http_archive( + name = "com_github_grpc_grpc", + sha256 = "b956598d8cbe168b5ee717b5dafa56563eb5201a947856a6688bbeac9cac4e1f", + strip_prefix = "grpc-b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd", + system_build_file = "@tsl//third_party/systemlibs:grpc.BUILD", + patch_file = [ + "@tsl//third_party/grpc:generate_cc_env_fix.patch", + "@tsl//third_party/grpc:register_go_toolchain.patch", + ], + system_link_files = { + "@tsl//third_party/systemlibs:BUILD": "bazel/BUILD", + "@tsl//third_party/systemlibs:grpc.BUILD": "src/compiler/BUILD", + "@tsl//third_party/systemlibs:grpc.bazel.grpc_deps.bzl": "bazel/grpc_deps.bzl", + "@tsl//third_party/systemlibs:grpc.bazel.grpc_extra_deps.bzl": "bazel/grpc_extra_deps.bzl", + "@tsl//third_party/systemlibs:grpc.bazel.cc_grpc_library.bzl": "bazel/cc_grpc_library.bzl", + "@tsl//third_party/systemlibs:grpc.bazel.generate_cc.bzl": "bazel/generate_cc.bzl", + "@tsl//third_party/systemlibs:grpc.bazel.protobuf.bzl": "bazel/protobuf.bzl", + }, + urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz"), + ) + tf_http_archive( + name = "com_google_protobuf", + patch_file = ["@tsl//third_party/protobuf:protobuf.patch"], + sha256 = "f66073dee0bc159157b0bd7f502d7d1ee0bc76b3c1eac9836927511bdc4b3fc1", + strip_prefix = "protobuf-3.21.9", + system_build_file = "@tsl//third_party/systemlibs:protobuf.BUILD", + system_link_files = { + "@tsl//third_party/systemlibs:protobuf.bzl": "protobuf.bzl", + "@tsl//third_party/systemlibs:protobuf_deps.bzl": "protobuf_deps.bzl", + }, + urls = tf_mirror_urls("https://github.com/protocolbuffers/protobuf/archive/v3.21.9.zip"), + ) + return mctx.extension_metadata( + reproducible = True, + root_module_direct_deps = "all", + root_module_direct_dev_deps = [], + ) + +xla_workspace = module_extension( + implementation = _xla_workspace_impl, +) diff --git a/third_party/modules/xla/20240919.0-1b18dd6/patches/0003-PJRT-C-API-Ensure-C-compliance-for-Profiler-Extensio.patch b/third_party/modules/xla/20240919.0-1b18dd6/patches/0003-PJRT-C-API-Ensure-C-compliance-for-Profiler-Extensio.patch new file mode 100644 index 0000000..0f58f10 --- /dev/null +++ b/third_party/modules/xla/20240919.0-1b18dd6/patches/0003-PJRT-C-API-Ensure-C-compliance-for-Profiler-Extensio.patch @@ -0,0 +1,27 @@ +From 4db5de34f70d991fedbe28915c8239b97ba7a064 Mon Sep 17 00:00:00 2001 +From: Steeve Morin +Date: Mon, 18 Mar 2024 17:17:34 +0100 +Subject: [PATCH 3/3] [PJRT C API] Ensure C compliance for Profiler Extension + +--- + xla/pjrt/c/pjrt_c_api_profiler_extension.h | 2 ++ + 1 file changed, 2 insertions(+) + +diff --git a/xla/pjrt/c/pjrt_c_api_profiler_extension.h b/xla/pjrt/c/pjrt_c_api_profiler_extension.h +index c821916ad..89a596123 100644 +--- a/xla/pjrt/c/pjrt_c_api_profiler_extension.h ++++ b/xla/pjrt/c/pjrt_c_api_profiler_extension.h +@@ -16,8 +16,10 @@ limitations under the License. + #ifndef XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_ + #define XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_ + ++#ifdef __cplusplus + #include + #include ++#endif + + #include "xla/backends/profiler/plugin/profiler_c_api.h" + #include "xla/pjrt/c/pjrt_c_api.h" +-- +2.39.3 (Apple Git-146) + diff --git a/third_party/modules/xla/20240919.0-1b18dd6/source.json b/third_party/modules/xla/20240919.0-1b18dd6/source.json new file mode 100644 index 0000000..9853728 --- /dev/null +++ b/third_party/modules/xla/20240919.0-1b18dd6/source.json @@ -0,0 +1,14 @@ +{ + "strip_prefix": "xla-d391119197eab771a84c1f8a59a7f50b7da4b43d", + "url": "https://github.com/openxla/xla/archive/d391119197eab771a84c1f8a59a7f50b7da4b43d.tar.gz", + "integrity": "sha256-ToPcvPhoGX6Ny3yLx6kJlQSwj7k5Xy4T+BjIp3DAj/s=", + "overlay": { + "tsl.bzl": "", + "workspace.bzl": "", + "MODULE.bazel": "" + }, + "patch_strip": 1, + "patches": { + "0003-PJRT-C-API-Ensure-C-compliance-for-Profiler-Extensio.patch": "" + } +} diff --git a/third_party/modules/xla/metadata.json b/third_party/modules/xla/metadata.json index 568c92c..e0dc325 100644 --- a/third_party/modules/xla/metadata.json +++ b/third_party/modules/xla/metadata.json @@ -11,7 +11,8 @@ "github:openxla/xla" ], "versions": [ - "20240902.0-d18cd64" + "20240902.0-d18cd64", + "20240919.0-1b18dd6", ], "yanked_versions": {} } diff --git a/zml/tensor.zig b/zml/tensor.zig index d2e6032..a1ad219 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -2052,7 +2052,6 @@ pub const Tensor = struct { .{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8 } }, .{ .{ .a = 10, .b = 20, .c = 30 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8, .c = 30 } }, // batching axes are implicits. - // TODO: batched gather don't compile https://github.com/zml/zml/issues/400 .{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10 }, .{ .a = 10 } }, .{ .{ .a = 10, .b = 20 }, .a, .{ .b = 20 }, .{ .b = 20 } }, .{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8 } }, @@ -2206,20 +2205,13 @@ pub const Tensor = struct { try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape()); try std.testing.expect(y.value().owner().verify()); - // The batching dims test case doesn't pass. - // The weird part is that the MLIR seems valid, but pjrt doesn't accept it. - // TODO: https://github.com/zml/zml/issues/400 - const mod = zml.compileFn(std.testing.allocator, gatherSlices, .{ x.shape(), slice_shape, idx.shape(), .{ .indices_are_sorted = true } }, platform); - - if (mod) |m| { - m.deinit(); - } else |err| { - if (@hasField(@TypeOf(idx_shape), "a")) { - scoped_log.warn("Skipping compilation test of gather with batching dims: https://github.com/zml/zml/issues/400", .{}); - } else { - return err; - } - } + const mod = try zml.compileFn( + std.testing.allocator, + gatherSlices, + .{ x.shape(), slice_shape, idx.shape(), .{ .indices_are_sorted = true } }, + platform, + ); + defer mod.deinit(); } }