From 9b7eea8ac2286b270b14757a566c06f641ce0571 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Wed, 21 Jun 2023 14:45:14 +0000 Subject: [PATCH] Add stdx utilities and rework async signature inference; tidy executable logging. --- MODULE.bazel.lock | 1172 +++++++++++++++++++------------------- async/meta.zig | 94 ++- async/threaded.zig | 57 +- pjrt/pjrt.zig | 5 +- pjrt/profiler.zig | 2 +- stdx/BUILD.bazel | 13 + stdx/debug.zig | 33 ++ stdx/math.zig | 25 + stdx/meta.zig | 158 +++++ stdx/signature.zig | 65 +++ stdx/stdx.zig | 3 + zml/BUILD.bazel | 1 + zml/aio.zig | 18 +- zml/aio/gguf.zig | 2 +- zml/aio/gguf/core.zig | 2 +- zml/aio/nemo.zig | 2 +- zml/aio/safetensors.zig | 2 +- zml/aio/tinyllama.zig | 6 +- zml/aio/torch.zig | 2 +- zml/aio/torch/eval.zig | 5 +- zml/aio/torch/file.zig | 13 +- zml/aio/torch/pickle.zig | 2 +- zml/aio/torch/py.zig | 2 +- zml/buffer.zig | 22 +- zml/context.zig | 46 +- zml/helpers.zig | 2 +- zml/hostbuffer.zig | 16 +- zml/meta.zig | 249 +------- zml/mlir.zig | 10 +- zml/module.zig | 132 +++-- zml/nn.zig | 50 +- zml/ops.zig | 28 +- zml/pjrtx.zig | 25 +- zml/platform.zig | 20 +- zml/shape.zig | 113 ++-- zml/tensor.zig | 281 ++++----- zml/test_runner.zig | 10 +- zml/testing.zig | 11 +- zml/tokenizer.zig | 10 +- zml/torch.zig | 22 +- 40 files changed, 1490 insertions(+), 1241 deletions(-) create mode 100644 stdx/BUILD.bazel create mode 100644 stdx/debug.zig create mode 100644 stdx/math.zig create mode 100644 stdx/meta.zig create mode 100644 stdx/signature.zig create mode 100644 stdx/stdx.zig diff --git a/MODULE.bazel.lock b/MODULE.bazel.lock index ee6947a..d433354 100644 --- a/MODULE.bazel.lock +++ b/MODULE.bazel.lock @@ -207,20 +207,20 @@ "@@apple_support~//crosstool:setup.bzl%apple_cc_configure_extension": { "general": { "bzlTransitiveDigest": "hDHJiBbKme6a+N8oiSQcVnU1v5B7tHMjJaAzS6GFfPc=", - "usagesDigest": "atH8xayh8CVhGZG9cm/kh7fV7XwOYgQy6Zhe1AzBH3g=", + "usagesDigest": "8MpdqvE6998JkKyqx+txaHjL8YIQGCDb8inyLInqs4w=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, "envVariables": {}, "generatedRepoSpecs": { - "local_config_apple_cc": { - "bzlFile": "@@apple_support~//crosstool:setup.bzl", - "ruleClassName": "_apple_cc_autoconf", - "attributes": {} - }, "local_config_apple_cc_toolchains": { "bzlFile": "@@apple_support~//crosstool:setup.bzl", "ruleClassName": "_apple_cc_autoconf_toolchains", "attributes": {} + }, + "local_config_apple_cc": { + "bzlFile": "@@apple_support~//crosstool:setup.bzl", + "ruleClassName": "_apple_cc_autoconf", + "attributes": {} } }, "recordedRepoMappingEntries": [ @@ -234,19 +234,96 @@ }, "@@aspect_bazel_lib~//lib:extensions.bzl%toolchains": { "general": { - "bzlTransitiveDigest": "p8GR5EeSAP5RvoQ0821nQQ9nelAyTrNoUULFFBUdGlU=", - "usagesDigest": "JTRHwJCDIFLoz4veHZ+G6HnnCkU60EauNsYpX/JQlYc=", + "bzlTransitiveDigest": "NXGl3qDbDYRSfhwBauRAtJl0BrupptFIJpGJvKIaZZw=", + "usagesDigest": "+nDx/9K+hc6aq2Kz9+r5P5pxJwJUEaQ7gWkUFfbAKw0=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, "envVariables": {}, "generatedRepoSpecs": { - "expand_template_windows_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", - "ruleClassName": "expand_template_platform_repo", + "copy_directory_darwin_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", + "ruleClassName": "copy_directory_platform_repo", + "attributes": { + "platform": "darwin_amd64" + } + }, + "copy_directory_darwin_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", + "ruleClassName": "copy_directory_platform_repo", + "attributes": { + "platform": "darwin_arm64" + } + }, + "copy_directory_freebsd_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", + "ruleClassName": "copy_directory_platform_repo", + "attributes": { + "platform": "freebsd_amd64" + } + }, + "copy_directory_linux_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", + "ruleClassName": "copy_directory_platform_repo", + "attributes": { + "platform": "linux_amd64" + } + }, + "copy_directory_linux_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", + "ruleClassName": "copy_directory_platform_repo", + "attributes": { + "platform": "linux_arm64" + } + }, + "copy_directory_windows_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", + "ruleClassName": "copy_directory_platform_repo", "attributes": { "platform": "windows_amd64" } }, + "copy_directory_toolchains": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", + "ruleClassName": "copy_directory_toolchains_repo", + "attributes": { + "user_repository_name": "copy_directory" + } + }, + "copy_to_directory_darwin_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", + "ruleClassName": "copy_to_directory_platform_repo", + "attributes": { + "platform": "darwin_amd64" + } + }, + "copy_to_directory_darwin_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", + "ruleClassName": "copy_to_directory_platform_repo", + "attributes": { + "platform": "darwin_arm64" + } + }, + "copy_to_directory_freebsd_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", + "ruleClassName": "copy_to_directory_platform_repo", + "attributes": { + "platform": "freebsd_amd64" + } + }, + "copy_to_directory_linux_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", + "ruleClassName": "copy_to_directory_platform_repo", + "attributes": { + "platform": "linux_amd64" + } + }, + "copy_to_directory_linux_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", + "ruleClassName": "copy_to_directory_platform_repo", + "attributes": { + "platform": "linux_arm64" + } + }, "copy_to_directory_windows_amd64": { "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", "ruleClassName": "copy_to_directory_platform_repo", @@ -254,6 +331,13 @@ "platform": "windows_amd64" } }, + "copy_to_directory_toolchains": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", + "ruleClassName": "copy_to_directory_toolchains_repo", + "attributes": { + "user_repository_name": "copy_to_directory" + } + }, "jq_darwin_amd64": { "bzlFile": "@@aspect_bazel_lib~//lib/private:jq_toolchain.bzl", "ruleClassName": "jq_platform_repo", @@ -262,18 +346,20 @@ "version": "1.7" } }, - "copy_to_directory_freebsd_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", - "ruleClassName": "copy_to_directory_platform_repo", + "jq_darwin_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:jq_toolchain.bzl", + "ruleClassName": "jq_platform_repo", "attributes": { - "platform": "freebsd_amd64" + "platform": "darwin_arm64", + "version": "1.7" } }, - "expand_template_linux_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", - "ruleClassName": "expand_template_platform_repo", + "jq_linux_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:jq_toolchain.bzl", + "ruleClassName": "jq_platform_repo", "attributes": { - "platform": "linux_amd64" + "platform": "linux_amd64", + "version": "1.7" } }, "jq_linux_arm64": { @@ -284,6 +370,102 @@ "version": "1.7" } }, + "jq_windows_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:jq_toolchain.bzl", + "ruleClassName": "jq_platform_repo", + "attributes": { + "platform": "windows_amd64", + "version": "1.7" + } + }, + "jq": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:jq_toolchain.bzl", + "ruleClassName": "jq_host_alias_repo", + "attributes": {} + }, + "jq_toolchains": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:jq_toolchain.bzl", + "ruleClassName": "jq_toolchains_repo", + "attributes": { + "user_repository_name": "jq" + } + }, + "yq_darwin_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", + "ruleClassName": "yq_platform_repo", + "attributes": { + "platform": "darwin_amd64", + "version": "4.25.2" + } + }, + "yq_darwin_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", + "ruleClassName": "yq_platform_repo", + "attributes": { + "platform": "darwin_arm64", + "version": "4.25.2" + } + }, + "yq_linux_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", + "ruleClassName": "yq_platform_repo", + "attributes": { + "platform": "linux_amd64", + "version": "4.25.2" + } + }, + "yq_linux_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", + "ruleClassName": "yq_platform_repo", + "attributes": { + "platform": "linux_arm64", + "version": "4.25.2" + } + }, + "yq_linux_s390x": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", + "ruleClassName": "yq_platform_repo", + "attributes": { + "platform": "linux_s390x", + "version": "4.25.2" + } + }, + "yq_linux_ppc64le": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", + "ruleClassName": "yq_platform_repo", + "attributes": { + "platform": "linux_ppc64le", + "version": "4.25.2" + } + }, + "yq_windows_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", + "ruleClassName": "yq_platform_repo", + "attributes": { + "platform": "windows_amd64", + "version": "4.25.2" + } + }, + "yq": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", + "ruleClassName": "yq_host_alias_repo", + "attributes": {} + }, + "yq_toolchains": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", + "ruleClassName": "yq_toolchains_repo", + "attributes": { + "user_repository_name": "yq" + } + }, + "coreutils_darwin_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:coreutils_toolchain.bzl", + "ruleClassName": "coreutils_platform_repo", + "attributes": { + "platform": "darwin_amd64", + "version": "0.0.26" + } + }, "coreutils_darwin_arm64": { "bzlFile": "@@aspect_bazel_lib~//lib/private:coreutils_toolchain.bzl", "ruleClassName": "coreutils_platform_repo", @@ -292,32 +474,11 @@ "version": "0.0.26" } }, - "copy_to_directory_linux_arm64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", - "ruleClassName": "copy_to_directory_platform_repo", - "attributes": { - "platform": "linux_arm64" - } - }, - "bsd_tar_linux_arm64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:tar_toolchain.bzl", - "ruleClassName": "bsdtar_binary_repo", - "attributes": { - "platform": "linux_arm64" - } - }, - "copy_directory_darwin_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", - "ruleClassName": "copy_directory_platform_repo", - "attributes": { - "platform": "darwin_amd64" - } - }, - "coreutils_darwin_amd64": { + "coreutils_linux_amd64": { "bzlFile": "@@aspect_bazel_lib~//lib/private:coreutils_toolchain.bzl", "ruleClassName": "coreutils_platform_repo", "attributes": { - "platform": "darwin_amd64", + "platform": "linux_amd64", "version": "0.0.26" } }, @@ -329,92 +490,31 @@ "version": "0.0.26" } }, - "zstd_linux_arm64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:zstd_toolchain.bzl", - "ruleClassName": "zstd_binary_repo", + "coreutils_windows_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:coreutils_toolchain.bzl", + "ruleClassName": "coreutils_platform_repo", "attributes": { - "platform": "linux_arm64" + "platform": "windows_amd64", + "version": "0.0.26" } }, - "yq_linux_s390x": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", - "ruleClassName": "yq_platform_repo", + "coreutils_toolchains": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:coreutils_toolchain.bzl", + "ruleClassName": "coreutils_toolchains_repo", "attributes": { - "platform": "linux_s390x", - "version": "4.25.2" + "user_repository_name": "coreutils" } }, - "yq": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", - "ruleClassName": "yq_host_alias_repo", - "attributes": {} - }, - "expand_template_darwin_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", - "ruleClassName": "expand_template_platform_repo", + "bsd_tar_darwin_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:tar_toolchain.bzl", + "ruleClassName": "bsdtar_binary_repo", "attributes": { "platform": "darwin_amd64" } }, - "copy_directory_linux_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", - "ruleClassName": "copy_directory_platform_repo", - "attributes": { - "platform": "linux_amd64" - } - }, - "jq_darwin_arm64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:jq_toolchain.bzl", - "ruleClassName": "jq_platform_repo", - "attributes": { - "platform": "darwin_arm64", - "version": "1.7" - } - }, - "yq_darwin_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", - "ruleClassName": "yq_platform_repo", - "attributes": { - "platform": "darwin_amd64", - "version": "4.25.2" - } - }, - "copy_directory_linux_arm64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", - "ruleClassName": "copy_directory_platform_repo", - "attributes": { - "platform": "linux_arm64" - } - }, - "expand_template_toolchains": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", - "ruleClassName": "expand_template_toolchains_repo", - "attributes": { - "user_repository_name": "expand_template" - } - }, - "bats_assert": { - "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_archive", - "attributes": { - "sha256": "98ca3b685f8b8993e48ec057565e6e2abcc541034ed5b0e81f191505682037fd", - "urls": [ - "https://github.com/bats-core/bats-assert/archive/v2.1.0.tar.gz" - ], - "strip_prefix": "bats-assert-2.1.0", - "build_file_content": "load(\"@aspect_bazel_lib//lib:copy_to_directory.bzl\", \"copy_to_directory\")\n\ncopy_to_directory(\n name = \"assert\",\n hardlink = \"on\",\n srcs = glob([\n \"src/**\",\n \"load.bash\",\n ]),\n out = \"bats-assert\",\n visibility = [\"//visibility:public\"]\n)\n" - } - }, - "copy_to_directory_darwin_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", - "ruleClassName": "copy_to_directory_platform_repo", - "attributes": { - "platform": "darwin_amd64" - } - }, - "zstd_darwin_arm64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:zstd_toolchain.bzl", - "ruleClassName": "zstd_binary_repo", + "bsd_tar_darwin_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:tar_toolchain.bzl", + "ruleClassName": "bsdtar_binary_repo", "attributes": { "platform": "darwin_arm64" } @@ -426,11 +526,39 @@ "platform": "linux_amd64" } }, - "yq_toolchains": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", - "ruleClassName": "yq_toolchains_repo", + "bsd_tar_linux_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:tar_toolchain.bzl", + "ruleClassName": "bsdtar_binary_repo", "attributes": { - "user_repository_name": "yq" + "platform": "linux_arm64" + } + }, + "bsd_tar_windows_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:tar_toolchain.bzl", + "ruleClassName": "bsdtar_binary_repo", + "attributes": { + "platform": "windows_amd64" + } + }, + "bsd_tar_toolchains": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:tar_toolchain.bzl", + "ruleClassName": "tar_toolchains_repo", + "attributes": { + "user_repository_name": "bsd_tar" + } + }, + "zstd_darwin_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:zstd_toolchain.bzl", + "ruleClassName": "zstd_binary_repo", + "attributes": { + "platform": "darwin_amd64" + } + }, + "zstd_darwin_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:zstd_toolchain.bzl", + "ruleClassName": "zstd_binary_repo", + "attributes": { + "platform": "darwin_arm64" } }, "zstd_linux_amd64": { @@ -440,6 +568,69 @@ "platform": "linux_amd64" } }, + "zstd_linux_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:zstd_toolchain.bzl", + "ruleClassName": "zstd_binary_repo", + "attributes": { + "platform": "linux_arm64" + } + }, + "zstd_toolchains": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:zstd_toolchain.bzl", + "ruleClassName": "zstd_toolchains_repo", + "attributes": { + "user_repository_name": "zstd" + } + }, + "expand_template_darwin_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", + "ruleClassName": "expand_template_platform_repo", + "attributes": { + "platform": "darwin_amd64" + } + }, + "expand_template_darwin_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", + "ruleClassName": "expand_template_platform_repo", + "attributes": { + "platform": "darwin_arm64" + } + }, + "expand_template_freebsd_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", + "ruleClassName": "expand_template_platform_repo", + "attributes": { + "platform": "freebsd_amd64" + } + }, + "expand_template_linux_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", + "ruleClassName": "expand_template_platform_repo", + "attributes": { + "platform": "linux_amd64" + } + }, + "expand_template_linux_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", + "ruleClassName": "expand_template_platform_repo", + "attributes": { + "platform": "linux_arm64" + } + }, + "expand_template_windows_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", + "ruleClassName": "expand_template_platform_repo", + "attributes": { + "platform": "windows_amd64" + } + }, + "expand_template_toolchains": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", + "ruleClassName": "expand_template_toolchains_repo", + "attributes": { + "user_repository_name": "expand_template" + } + }, "bats_support": { "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", "ruleClassName": "http_archive", @@ -452,95 +643,16 @@ "build_file_content": "load(\"@aspect_bazel_lib//lib:copy_to_directory.bzl\", \"copy_to_directory\")\n\ncopy_to_directory(\n name = \"support\",\n hardlink = \"on\",\n srcs = glob([\n \"src/**\",\n \"load.bash\",\n ]),\n out = \"bats-support\",\n visibility = [\"//visibility:public\"]\n)\n" } }, - "bsd_tar_windows_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:tar_toolchain.bzl", - "ruleClassName": "bsdtar_binary_repo", + "bats_assert": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", "attributes": { - "platform": "windows_amd64" - } - }, - "jq": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:jq_toolchain.bzl", - "ruleClassName": "jq_host_alias_repo", - "attributes": {} - }, - "expand_template_darwin_arm64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", - "ruleClassName": "expand_template_platform_repo", - "attributes": { - "platform": "darwin_arm64" - } - }, - "bsd_tar_darwin_arm64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:tar_toolchain.bzl", - "ruleClassName": "bsdtar_binary_repo", - "attributes": { - "platform": "darwin_arm64" - } - }, - "copy_to_directory_linux_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", - "ruleClassName": "copy_to_directory_platform_repo", - "attributes": { - "platform": "linux_amd64" - } - }, - "coreutils_linux_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:coreutils_toolchain.bzl", - "ruleClassName": "coreutils_platform_repo", - "attributes": { - "platform": "linux_amd64", - "version": "0.0.26" - } - }, - "copy_directory_toolchains": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", - "ruleClassName": "copy_directory_toolchains_repo", - "attributes": { - "user_repository_name": "copy_directory" - } - }, - "yq_linux_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", - "ruleClassName": "yq_platform_repo", - "attributes": { - "platform": "linux_amd64", - "version": "4.25.2" - } - }, - "copy_to_directory_darwin_arm64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", - "ruleClassName": "copy_to_directory_platform_repo", - "attributes": { - "platform": "darwin_arm64" - } - }, - "coreutils_toolchains": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:coreutils_toolchain.bzl", - "ruleClassName": "coreutils_toolchains_repo", - "attributes": { - "user_repository_name": "coreutils" - } - }, - "copy_directory_freebsd_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", - "ruleClassName": "copy_directory_platform_repo", - "attributes": { - "platform": "freebsd_amd64" - } - }, - "zstd_darwin_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:zstd_toolchain.bzl", - "ruleClassName": "zstd_binary_repo", - "attributes": { - "platform": "darwin_amd64" - } - }, - "zstd_toolchains": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:zstd_toolchain.bzl", - "ruleClassName": "zstd_toolchains_repo", - "attributes": { - "user_repository_name": "zstd" + "sha256": "98ca3b685f8b8993e48ec057565e6e2abcc541034ed5b0e81f191505682037fd", + "urls": [ + "https://github.com/bats-core/bats-assert/archive/v2.1.0.tar.gz" + ], + "strip_prefix": "bats-assert-2.1.0", + "build_file_content": "load(\"@aspect_bazel_lib//lib:copy_to_directory.bzl\", \"copy_to_directory\")\n\ncopy_to_directory(\n name = \"assert\",\n hardlink = \"on\",\n srcs = glob([\n \"src/**\",\n \"load.bash\",\n ]),\n out = \"bats-assert\",\n visibility = [\"//visibility:public\"]\n)\n" } }, "bats_file": { @@ -555,35 +667,6 @@ "build_file_content": "load(\"@aspect_bazel_lib//lib:copy_to_directory.bzl\", \"copy_to_directory\")\n\ncopy_to_directory(\n name = \"file\",\n hardlink = \"on\",\n srcs = glob([\n \"src/**\",\n \"load.bash\",\n ]),\n out = \"bats-file\",\n visibility = [\"//visibility:public\"]\n)\n" } }, - "expand_template_linux_arm64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", - "ruleClassName": "expand_template_platform_repo", - "attributes": { - "platform": "linux_arm64" - } - }, - "jq_linux_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:jq_toolchain.bzl", - "ruleClassName": "jq_platform_repo", - "attributes": { - "platform": "linux_amd64", - "version": "1.7" - } - }, - "bsd_tar_darwin_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:tar_toolchain.bzl", - "ruleClassName": "bsdtar_binary_repo", - "attributes": { - "platform": "darwin_amd64" - } - }, - "bsd_tar_toolchains": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:tar_toolchain.bzl", - "ruleClassName": "tar_toolchains_repo", - "attributes": { - "user_repository_name": "bsd_tar" - } - }, "bats_toolchains": { "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", "ruleClassName": "http_archive", @@ -595,89 +678,6 @@ "strip_prefix": "bats-core-1.10.0", "build_file_content": "load(\"@local_config_platform//:constraints.bzl\", \"HOST_CONSTRAINTS\")\nload(\"@aspect_bazel_lib//lib/private:bats_toolchain.bzl\", \"bats_toolchain\")\nload(\"@aspect_bazel_lib//lib:copy_to_directory.bzl\", \"copy_to_directory\")\n\ncopy_to_directory(\n name = \"core\",\n hardlink = \"on\",\n srcs = glob([\n \"lib/**\",\n \"libexec/**\"\n ]) + [\"bin/bats\"],\n out = \"bats-core\",\n)\n\nbats_toolchain(\n name = \"toolchain\",\n core = \":core\",\n libraries = [\"@bats_support//:support\", \"@bats_assert//:assert\", \"@bats_file//:file\"]\n)\n\ntoolchain(\n name = \"bats_toolchain\",\n exec_compatible_with = HOST_CONSTRAINTS,\n toolchain = \":toolchain\",\n toolchain_type = \"@aspect_bazel_lib//lib:bats_toolchain_type\",\n)\n" } - }, - "yq_windows_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", - "ruleClassName": "yq_platform_repo", - "attributes": { - "platform": "windows_amd64", - "version": "4.25.2" - } - }, - "jq_windows_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:jq_toolchain.bzl", - "ruleClassName": "jq_platform_repo", - "attributes": { - "platform": "windows_amd64", - "version": "1.7" - } - }, - "expand_template_freebsd_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", - "ruleClassName": "expand_template_platform_repo", - "attributes": { - "platform": "freebsd_amd64" - } - }, - "yq_linux_ppc64le": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", - "ruleClassName": "yq_platform_repo", - "attributes": { - "platform": "linux_ppc64le", - "version": "4.25.2" - } - }, - "copy_to_directory_toolchains": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", - "ruleClassName": "copy_to_directory_toolchains_repo", - "attributes": { - "user_repository_name": "copy_to_directory" - } - }, - "jq_toolchains": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:jq_toolchain.bzl", - "ruleClassName": "jq_toolchains_repo", - "attributes": { - "user_repository_name": "jq" - } - }, - "copy_directory_darwin_arm64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", - "ruleClassName": "copy_directory_platform_repo", - "attributes": { - "platform": "darwin_arm64" - } - }, - "copy_directory_windows_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", - "ruleClassName": "copy_directory_platform_repo", - "attributes": { - "platform": "windows_amd64" - } - }, - "yq_darwin_arm64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", - "ruleClassName": "yq_platform_repo", - "attributes": { - "platform": "darwin_arm64", - "version": "4.25.2" - } - }, - "coreutils_windows_amd64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:coreutils_toolchain.bzl", - "ruleClassName": "coreutils_platform_repo", - "attributes": { - "platform": "windows_amd64", - "version": "0.0.26" - } - }, - "yq_linux_arm64": { - "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", - "ruleClassName": "yq_platform_repo", - "attributes": { - "platform": "linux_arm64", - "version": "4.25.2" - } } }, "recordedRepoMappingEntries": [ @@ -701,22 +701,34 @@ }, "@@buildifier_prebuilt~//:defs.bzl%buildifier_prebuilt_deps_extension": { "general": { - "bzlTransitiveDigest": "RaNT6gZicQa6HpONHOm8ejwF7zVAk4fIUHrbuHme7z4=", - "usagesDigest": "nThSTPRdiQbhDFl8FRM2nsKJftWMtPBQHrp/mdk716w=", + "bzlTransitiveDigest": "beIcQnY+rGQj1WVqpFpKHQLJEY1e8GEbQAtEYPYZRoc=", + "usagesDigest": "MbIuhDGRTlw07fpxjzM2N+5FUBehV3EnCmO7eEN86tc=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, "envVariables": {}, "generatedRepoSpecs": { - "buildozer_darwin_amd64": { + "buildifier_darwin_amd64": { "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", "ruleClassName": "http_file", "attributes": { "urls": [ - "https://github.com/bazelbuild/buildtools/releases/download/v6.4.0/buildozer-darwin-amd64" + "https://github.com/bazelbuild/buildtools/releases/download/v6.4.0/buildifier-darwin-amd64" ], - "downloaded_file_path": "buildozer", + "downloaded_file_path": "buildifier", "executable": true, - "sha256": "d29e347ecd6b5673d72cb1a8de05bf1b06178dd229ff5eb67fad5100c840cc8e" + "sha256": "eeb47b2de27f60efe549348b183fac24eae80f1479e8b06cac0799c486df5bed" + } + }, + "buildifier_darwin_arm64": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_file", + "attributes": { + "urls": [ + "https://github.com/bazelbuild/buildtools/releases/download/v6.4.0/buildifier-darwin-arm64" + ], + "downloaded_file_path": "buildifier", + "executable": true, + "sha256": "fa07ba0d20165917ca4cc7609f9b19a8a4392898148b7babdf6bb2a7dd963f05" } }, "buildifier_linux_amd64": { @@ -731,6 +743,42 @@ "sha256": "be63db12899f48600bad94051123b1fd7b5251e7661b9168582ce52396132e92" } }, + "buildifier_linux_arm64": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_file", + "attributes": { + "urls": [ + "https://github.com/bazelbuild/buildtools/releases/download/v6.4.0/buildifier-linux-arm64" + ], + "downloaded_file_path": "buildifier", + "executable": true, + "sha256": "18540fc10f86190f87485eb86963e603e41fa022f88a2d1b0cf52ff252b5e1dd" + } + }, + "buildifier_windows_amd64": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_file", + "attributes": { + "urls": [ + "https://github.com/bazelbuild/buildtools/releases/download/v6.4.0/buildifier-windows-amd64.exe" + ], + "downloaded_file_path": "buildifier.exe", + "executable": true, + "sha256": "da8372f35e34b65fb6d997844d041013bb841e55f58b54d596d35e49680fe13c" + } + }, + "buildozer_darwin_amd64": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_file", + "attributes": { + "urls": [ + "https://github.com/bazelbuild/buildtools/releases/download/v6.4.0/buildozer-darwin-amd64" + ], + "downloaded_file_path": "buildozer", + "executable": true, + "sha256": "d29e347ecd6b5673d72cb1a8de05bf1b06178dd229ff5eb67fad5100c840cc8e" + } + }, "buildozer_darwin_arm64": { "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", "ruleClassName": "http_file", @@ -755,18 +803,6 @@ "sha256": "8dfd6345da4e9042daa738d7fdf34f699c5dfce4632f7207956fceedd8494119" } }, - "buildozer_windows_amd64": { - "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_file", - "attributes": { - "urls": [ - "https://github.com/bazelbuild/buildtools/releases/download/v6.4.0/buildozer-windows-amd64.exe" - ], - "downloaded_file_path": "buildozer.exe", - "executable": true, - "sha256": "e7f05bf847f7c3689dd28926460ce6e1097ae97380ac8e6ae7147b7b706ba19b" - } - }, "buildozer_linux_arm64": { "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", "ruleClassName": "http_file", @@ -779,16 +815,16 @@ "sha256": "6559558fded658c8fa7432a9d011f7c4dcbac6b738feae73d2d5c352e5f605fa" } }, - "buildifier_windows_amd64": { + "buildozer_windows_amd64": { "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", "ruleClassName": "http_file", "attributes": { "urls": [ - "https://github.com/bazelbuild/buildtools/releases/download/v6.4.0/buildifier-windows-amd64.exe" + "https://github.com/bazelbuild/buildtools/releases/download/v6.4.0/buildozer-windows-amd64.exe" ], - "downloaded_file_path": "buildifier.exe", + "downloaded_file_path": "buildozer.exe", "executable": true, - "sha256": "da8372f35e34b65fb6d997844d041013bb841e55f58b54d596d35e49680fe13c" + "sha256": "e7f05bf847f7c3689dd28926460ce6e1097ae97380ac8e6ae7147b7b706ba19b" } }, "buildifier_prebuilt_toolchains": { @@ -797,42 +833,6 @@ "attributes": { "assets_json": "[{\"arch\":\"amd64\",\"name\":\"buildifier\",\"platform\":\"darwin\",\"sha256\":\"eeb47b2de27f60efe549348b183fac24eae80f1479e8b06cac0799c486df5bed\",\"version\":\"v6.4.0\"},{\"arch\":\"arm64\",\"name\":\"buildifier\",\"platform\":\"darwin\",\"sha256\":\"fa07ba0d20165917ca4cc7609f9b19a8a4392898148b7babdf6bb2a7dd963f05\",\"version\":\"v6.4.0\"},{\"arch\":\"amd64\",\"name\":\"buildifier\",\"platform\":\"linux\",\"sha256\":\"be63db12899f48600bad94051123b1fd7b5251e7661b9168582ce52396132e92\",\"version\":\"v6.4.0\"},{\"arch\":\"arm64\",\"name\":\"buildifier\",\"platform\":\"linux\",\"sha256\":\"18540fc10f86190f87485eb86963e603e41fa022f88a2d1b0cf52ff252b5e1dd\",\"version\":\"v6.4.0\"},{\"arch\":\"amd64\",\"name\":\"buildifier\",\"platform\":\"windows\",\"sha256\":\"da8372f35e34b65fb6d997844d041013bb841e55f58b54d596d35e49680fe13c\",\"version\":\"v6.4.0\"},{\"arch\":\"amd64\",\"name\":\"buildozer\",\"platform\":\"darwin\",\"sha256\":\"d29e347ecd6b5673d72cb1a8de05bf1b06178dd229ff5eb67fad5100c840cc8e\",\"version\":\"v6.4.0\"},{\"arch\":\"arm64\",\"name\":\"buildozer\",\"platform\":\"darwin\",\"sha256\":\"9b9e71bdbec5e7223871e913b65d12f6d8fa026684daf991f00e52ed36a6978d\",\"version\":\"v6.4.0\"},{\"arch\":\"amd64\",\"name\":\"buildozer\",\"platform\":\"linux\",\"sha256\":\"8dfd6345da4e9042daa738d7fdf34f699c5dfce4632f7207956fceedd8494119\",\"version\":\"v6.4.0\"},{\"arch\":\"arm64\",\"name\":\"buildozer\",\"platform\":\"linux\",\"sha256\":\"6559558fded658c8fa7432a9d011f7c4dcbac6b738feae73d2d5c352e5f605fa\",\"version\":\"v6.4.0\"},{\"arch\":\"amd64\",\"name\":\"buildozer\",\"platform\":\"windows\",\"sha256\":\"e7f05bf847f7c3689dd28926460ce6e1097ae97380ac8e6ae7147b7b706ba19b\",\"version\":\"v6.4.0\"}]" } - }, - "buildifier_darwin_amd64": { - "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_file", - "attributes": { - "urls": [ - "https://github.com/bazelbuild/buildtools/releases/download/v6.4.0/buildifier-darwin-amd64" - ], - "downloaded_file_path": "buildifier", - "executable": true, - "sha256": "eeb47b2de27f60efe549348b183fac24eae80f1479e8b06cac0799c486df5bed" - } - }, - "buildifier_darwin_arm64": { - "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_file", - "attributes": { - "urls": [ - "https://github.com/bazelbuild/buildtools/releases/download/v6.4.0/buildifier-darwin-arm64" - ], - "downloaded_file_path": "buildifier", - "executable": true, - "sha256": "fa07ba0d20165917ca4cc7609f9b19a8a4392898148b7babdf6bb2a7dd963f05" - } - }, - "buildifier_linux_arm64": { - "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_file", - "attributes": { - "urls": [ - "https://github.com/bazelbuild/buildtools/releases/download/v6.4.0/buildifier-linux-arm64" - ], - "downloaded_file_path": "buildifier", - "executable": true, - "sha256": "18540fc10f86190f87485eb86963e603e41fa022f88a2d1b0cf52ff252b5e1dd" - } } }, "recordedRepoMappingEntries": [ @@ -851,8 +851,8 @@ }, "@@hermetic_cc_toolchain~//toolchain:ext.bzl%toolchains": { "general": { - "bzlTransitiveDigest": "YrUOAIuZsxXEx+Q8Bh4BooEu9qcrrOJRxaoC368wqAs=", - "usagesDigest": "EoOqmoW16SGIu0l1zbP8oSGHkhceATca/trgf094fVo=", + "bzlTransitiveDigest": "vVlpD5ojpRWqHagv1w9Ie1/gHSDqYYOw/54ltM0wVI0=", + "usagesDigest": "G/riWK/idagMqTscQCoOap3TFm1VQbLHJKM0uRq24pk=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, "envVariables": {}, @@ -901,7 +901,7 @@ "@@platforms//host:extension.bzl%host_platform": { "general": { "bzlTransitiveDigest": "WewbYICdNVp22bzUQafEVMzMIpBnNjJ3zqKlUOCDIGc=", - "usagesDigest": "V1R2Y2oMxKNfx2WCWpSCaUV1WefW1o8HZGm3v1vHgY4=", + "usagesDigest": "hgylFkgWSg0ulUwWZzEM1aIftlUnbmw2ynWLdEfHnZc=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, "envVariables": {}, @@ -917,14 +917,68 @@ }, "@@rules_python~//python/extensions:python.bzl%python": { "general": { - "bzlTransitiveDigest": "uFBdNKDewkX8DPVvKoPEcW0b3dXtksiE7SShRkuCA3E=", - "usagesDigest": "fQEsnAYwqRJT7/lTBAe+NllONXk6f/Tc57oiPxLG8SI=", + "bzlTransitiveDigest": "QlZZ0JYUiuYjq3DHBAxqL2QVlVN8ZVDqZ/1KU88faZ8=", + "usagesDigest": "abUgYqI1bdv/jc3Xu7C2SbT7mmtxAziRT/kUCFERO+A=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, "envVariables": { "RULES_PYTHON_BZLMOD_DEBUG": null }, "generatedRepoSpecs": { + "python_3_11_aarch64-apple-darwin": { + "bzlFile": "@@rules_python~//python:repositories.bzl", + "ruleClassName": "python_repository", + "attributes": { + "sha256": "b042c966920cf8465385ca3522986b12d745151a72c060991088977ca36d3883", + "patches": [], + "platform": "aarch64-apple-darwin", + "python_version": "3.11.7", + "release_filename": "20240107/cpython-3.11.7+20240107-aarch64-apple-darwin-install_only.tar.gz", + "urls": [ + "https://github.com/indygreg/python-build-standalone/releases/download/20240107/cpython-3.11.7+20240107-aarch64-apple-darwin-install_only.tar.gz" + ], + "distutils_content": "", + "strip_prefix": "python", + "coverage_tool": "", + "ignore_root_user_error": false + } + }, + "python_3_11_aarch64-unknown-linux-gnu": { + "bzlFile": "@@rules_python~//python:repositories.bzl", + "ruleClassName": "python_repository", + "attributes": { + "sha256": "b102eaf865eb715aa98a8a2ef19037b6cc3ae7dfd4a632802650f29de635aa13", + "patches": [], + "platform": "aarch64-unknown-linux-gnu", + "python_version": "3.11.7", + "release_filename": "20240107/cpython-3.11.7+20240107-aarch64-unknown-linux-gnu-install_only.tar.gz", + "urls": [ + "https://github.com/indygreg/python-build-standalone/releases/download/20240107/cpython-3.11.7+20240107-aarch64-unknown-linux-gnu-install_only.tar.gz" + ], + "distutils_content": "", + "strip_prefix": "python", + "coverage_tool": "", + "ignore_root_user_error": false + } + }, + "python_3_11_ppc64le-unknown-linux-gnu": { + "bzlFile": "@@rules_python~//python:repositories.bzl", + "ruleClassName": "python_repository", + "attributes": { + "sha256": "b44e1b74afe75c7b19143413632c4386708ae229117f8f950c2094e9681d34c7", + "patches": [], + "platform": "ppc64le-unknown-linux-gnu", + "python_version": "3.11.7", + "release_filename": "20240107/cpython-3.11.7+20240107-ppc64le-unknown-linux-gnu-install_only.tar.gz", + "urls": [ + "https://github.com/indygreg/python-build-standalone/releases/download/20240107/cpython-3.11.7+20240107-ppc64le-unknown-linux-gnu-install_only.tar.gz" + ], + "distutils_content": "", + "strip_prefix": "python", + "coverage_tool": "", + "ignore_root_user_error": false + } + }, "python_3_11_s390x-unknown-linux-gnu": { "bzlFile": "@@rules_python~//python:repositories.bzl", "ruleClassName": "python_repository", @@ -943,6 +997,60 @@ "ignore_root_user_error": false } }, + "python_3_11_x86_64-apple-darwin": { + "bzlFile": "@@rules_python~//python:repositories.bzl", + "ruleClassName": "python_repository", + "attributes": { + "sha256": "a0e615eef1fafdc742da0008425a9030b7ea68a4ae4e73ac557ef27b112836d4", + "patches": [], + "platform": "x86_64-apple-darwin", + "python_version": "3.11.7", + "release_filename": "20240107/cpython-3.11.7+20240107-x86_64-apple-darwin-install_only.tar.gz", + "urls": [ + "https://github.com/indygreg/python-build-standalone/releases/download/20240107/cpython-3.11.7+20240107-x86_64-apple-darwin-install_only.tar.gz" + ], + "distutils_content": "", + "strip_prefix": "python", + "coverage_tool": "", + "ignore_root_user_error": false + } + }, + "python_3_11_x86_64-pc-windows-msvc": { + "bzlFile": "@@rules_python~//python:repositories.bzl", + "ruleClassName": "python_repository", + "attributes": { + "sha256": "67077e6fa918e4f4fd60ba169820b00be7c390c497bf9bc9cab2c255ea8e6f3e", + "patches": [], + "platform": "x86_64-pc-windows-msvc", + "python_version": "3.11.7", + "release_filename": "20240107/cpython-3.11.7+20240107-x86_64-pc-windows-msvc-shared-install_only.tar.gz", + "urls": [ + "https://github.com/indygreg/python-build-standalone/releases/download/20240107/cpython-3.11.7+20240107-x86_64-pc-windows-msvc-shared-install_only.tar.gz" + ], + "distutils_content": "", + "strip_prefix": "python", + "coverage_tool": "", + "ignore_root_user_error": false + } + }, + "python_3_11_x86_64-unknown-linux-gnu": { + "bzlFile": "@@rules_python~//python:repositories.bzl", + "ruleClassName": "python_repository", + "attributes": { + "sha256": "4a51ce60007a6facf64e5495f4cf322e311ba9f39a8cd3f3e4c026eae488e140", + "patches": [], + "platform": "x86_64-unknown-linux-gnu", + "python_version": "3.11.7", + "release_filename": "20240107/cpython-3.11.7+20240107-x86_64-unknown-linux-gnu-install_only.tar.gz", + "urls": [ + "https://github.com/indygreg/python-build-standalone/releases/download/20240107/cpython-3.11.7+20240107-x86_64-unknown-linux-gnu-install_only.tar.gz" + ], + "distutils_content": "", + "strip_prefix": "python", + "coverage_tool": "", + "ignore_root_user_error": false + } + }, "python_3_11_host": { "bzlFile": "@@rules_python~//python/private:toolchains_repo.bzl", "ruleClassName": "host_toolchain", @@ -977,78 +1085,6 @@ ] } }, - "python_3_11_aarch64-unknown-linux-gnu": { - "bzlFile": "@@rules_python~//python:repositories.bzl", - "ruleClassName": "python_repository", - "attributes": { - "sha256": "b102eaf865eb715aa98a8a2ef19037b6cc3ae7dfd4a632802650f29de635aa13", - "patches": [], - "platform": "aarch64-unknown-linux-gnu", - "python_version": "3.11.7", - "release_filename": "20240107/cpython-3.11.7+20240107-aarch64-unknown-linux-gnu-install_only.tar.gz", - "urls": [ - "https://github.com/indygreg/python-build-standalone/releases/download/20240107/cpython-3.11.7+20240107-aarch64-unknown-linux-gnu-install_only.tar.gz" - ], - "distutils_content": "", - "strip_prefix": "python", - "coverage_tool": "", - "ignore_root_user_error": false - } - }, - "python_3_11_aarch64-apple-darwin": { - "bzlFile": "@@rules_python~//python:repositories.bzl", - "ruleClassName": "python_repository", - "attributes": { - "sha256": "b042c966920cf8465385ca3522986b12d745151a72c060991088977ca36d3883", - "patches": [], - "platform": "aarch64-apple-darwin", - "python_version": "3.11.7", - "release_filename": "20240107/cpython-3.11.7+20240107-aarch64-apple-darwin-install_only.tar.gz", - "urls": [ - "https://github.com/indygreg/python-build-standalone/releases/download/20240107/cpython-3.11.7+20240107-aarch64-apple-darwin-install_only.tar.gz" - ], - "distutils_content": "", - "strip_prefix": "python", - "coverage_tool": "", - "ignore_root_user_error": false - } - }, - "python_3_11_ppc64le-unknown-linux-gnu": { - "bzlFile": "@@rules_python~//python:repositories.bzl", - "ruleClassName": "python_repository", - "attributes": { - "sha256": "b44e1b74afe75c7b19143413632c4386708ae229117f8f950c2094e9681d34c7", - "patches": [], - "platform": "ppc64le-unknown-linux-gnu", - "python_version": "3.11.7", - "release_filename": "20240107/cpython-3.11.7+20240107-ppc64le-unknown-linux-gnu-install_only.tar.gz", - "urls": [ - "https://github.com/indygreg/python-build-standalone/releases/download/20240107/cpython-3.11.7+20240107-ppc64le-unknown-linux-gnu-install_only.tar.gz" - ], - "distutils_content": "", - "strip_prefix": "python", - "coverage_tool": "", - "ignore_root_user_error": false - } - }, - "python_3_11_x86_64-apple-darwin": { - "bzlFile": "@@rules_python~//python:repositories.bzl", - "ruleClassName": "python_repository", - "attributes": { - "sha256": "a0e615eef1fafdc742da0008425a9030b7ea68a4ae4e73ac557ef27b112836d4", - "patches": [], - "platform": "x86_64-apple-darwin", - "python_version": "3.11.7", - "release_filename": "20240107/cpython-3.11.7+20240107-x86_64-apple-darwin-install_only.tar.gz", - "urls": [ - "https://github.com/indygreg/python-build-standalone/releases/download/20240107/cpython-3.11.7+20240107-x86_64-apple-darwin-install_only.tar.gz" - ], - "distutils_content": "", - "strip_prefix": "python", - "coverage_tool": "", - "ignore_root_user_error": false - } - }, "pythons_hub": { "bzlFile": "@@rules_python~//python/private/bzlmod:pythons_hub.bzl", "ruleClassName": "hub_repo", @@ -1076,42 +1112,6 @@ "3.11": "python_3_11" } } - }, - "python_3_11_x86_64-pc-windows-msvc": { - "bzlFile": "@@rules_python~//python:repositories.bzl", - "ruleClassName": "python_repository", - "attributes": { - "sha256": "67077e6fa918e4f4fd60ba169820b00be7c390c497bf9bc9cab2c255ea8e6f3e", - "patches": [], - "platform": "x86_64-pc-windows-msvc", - "python_version": "3.11.7", - "release_filename": "20240107/cpython-3.11.7+20240107-x86_64-pc-windows-msvc-shared-install_only.tar.gz", - "urls": [ - "https://github.com/indygreg/python-build-standalone/releases/download/20240107/cpython-3.11.7+20240107-x86_64-pc-windows-msvc-shared-install_only.tar.gz" - ], - "distutils_content": "", - "strip_prefix": "python", - "coverage_tool": "", - "ignore_root_user_error": false - } - }, - "python_3_11_x86_64-unknown-linux-gnu": { - "bzlFile": "@@rules_python~//python:repositories.bzl", - "ruleClassName": "python_repository", - "attributes": { - "sha256": "4a51ce60007a6facf64e5495f4cf322e311ba9f39a8cd3f3e4c026eae488e140", - "patches": [], - "platform": "x86_64-unknown-linux-gnu", - "python_version": "3.11.7", - "release_filename": "20240107/cpython-3.11.7+20240107-x86_64-unknown-linux-gnu-install_only.tar.gz", - "urls": [ - "https://github.com/indygreg/python-build-standalone/releases/download/20240107/cpython-3.11.7+20240107-x86_64-unknown-linux-gnu-install_only.tar.gz" - ], - "distutils_content": "", - "strip_prefix": "python", - "coverage_tool": "", - "ignore_root_user_error": false - } } }, "recordedRepoMappingEntries": [ @@ -1130,18 +1130,23 @@ }, "@@rules_python~//python/private/bzlmod:internal_deps.bzl%internal_deps": { "general": { - "bzlTransitiveDigest": "ZKHEOaFX4G/WohTzMTV688TfCsSfLvbwBC2ma87N24w=", - "usagesDigest": "4Fj9JSpEDoJSLPRSbvSTol2bTL7baZjuA3k9U7kG/1k=", + "bzlTransitiveDigest": "zuSxKukgpeqO/JtGBnEt6151zIWBas9w8S5LaqEyZNw=", + "usagesDigest": "r7vtlnQfWxEwrL+QFXux06yzeWEkq/hrcwAssoCoSLY=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, "envVariables": {}, "generatedRepoSpecs": { - "pypi__wheel": { + "rules_python_internal": { + "bzlFile": "@@rules_python~//python/private:internal_config_repo.bzl", + "ruleClassName": "internal_config_repo", + "attributes": {} + }, + "pypi__build": { "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", "ruleClassName": "http_archive", "attributes": { - "url": "https://files.pythonhosted.org/packages/b8/8b/31273bf66016be6ad22bb7345c37ff350276cfd46e389a0c2ac5da9d9073/wheel-0.41.2-py3-none-any.whl", - "sha256": "75909db2664838d015e3d9139004ee16711748a52c8f336b52882266540215d8", + "url": "https://files.pythonhosted.org/packages/58/91/17b00d5fac63d3dca605f1b8269ba3c65e98059e1fd99d00283e42a454f0/build-0.10.0-py3-none-any.whl", + "sha256": "af266720050a66c893a6096a2f410989eeac74ff9a68ba194b3f6473e8e26171", "type": "zip", "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" } @@ -1156,76 +1161,6 @@ "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" } }, - "pypi__importlib_metadata": { - "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_archive", - "attributes": { - "url": "https://files.pythonhosted.org/packages/cc/37/db7ba97e676af155f5fcb1a35466f446eadc9104e25b83366e8088c9c926/importlib_metadata-6.8.0-py3-none-any.whl", - "sha256": "3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb", - "type": "zip", - "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" - } - }, - "pypi__pyproject_hooks": { - "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_archive", - "attributes": { - "url": "https://files.pythonhosted.org/packages/d5/ea/9ae603de7fbb3df820b23a70f6aff92bf8c7770043254ad8d2dc9d6bcba4/pyproject_hooks-1.0.0-py3-none-any.whl", - "sha256": "283c11acd6b928d2f6a7c73fa0d01cb2bdc5f07c57a2eeb6e83d5e56b97976f8", - "type": "zip", - "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" - } - }, - "pypi__pep517": { - "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_archive", - "attributes": { - "url": "https://files.pythonhosted.org/packages/ee/2f/ef63e64e9429111e73d3d6cbee80591672d16f2725e648ebc52096f3d323/pep517-0.13.0-py3-none-any.whl", - "sha256": "4ba4446d80aed5b5eac6509ade100bff3e7943a8489de249654a5ae9b33ee35b", - "type": "zip", - "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" - } - }, - "pypi__packaging": { - "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_archive", - "attributes": { - "url": "https://files.pythonhosted.org/packages/ab/c3/57f0601a2d4fe15de7a553c00adbc901425661bf048f2a22dfc500caf121/packaging-23.1-py3-none-any.whl", - "sha256": "994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61", - "type": "zip", - "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" - } - }, - "pypi__pip_tools": { - "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_archive", - "attributes": { - "url": "https://files.pythonhosted.org/packages/e8/df/47e6267c6b5cdae867adbdd84b437393e6202ce4322de0a5e0b92960e1d6/pip_tools-7.3.0-py3-none-any.whl", - "sha256": "8717693288720a8c6ebd07149c93ab0be1fced0b5191df9e9decd3263e20d85e", - "type": "zip", - "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" - } - }, - "pypi__setuptools": { - "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_archive", - "attributes": { - "url": "https://files.pythonhosted.org/packages/4f/ab/0bcfebdfc3bfa8554b2b2c97a555569c4c1ebc74ea288741ea8326c51906/setuptools-68.1.2-py3-none-any.whl", - "sha256": "3d8083eed2d13afc9426f227b24fd1659489ec107c0e86cec2ffdde5c92e790b", - "type": "zip", - "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" - } - }, - "pypi__zipp": { - "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_archive", - "attributes": { - "url": "https://files.pythonhosted.org/packages/8c/08/d3006317aefe25ea79d3b76c9650afabaf6d63d1c8443b236e7405447503/zipp-3.16.2-py3-none-any.whl", - "sha256": "679e51dd4403591b2d6838a48de3d283f3d188412a9782faadf845f298736ba0", - "type": "zip", - "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" - } - }, "pypi__colorama": { "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", "ruleClassName": "http_archive", @@ -1236,27 +1171,12 @@ "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" } }, - "pypi__build": { + "pypi__importlib_metadata": { "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", "ruleClassName": "http_archive", "attributes": { - "url": "https://files.pythonhosted.org/packages/58/91/17b00d5fac63d3dca605f1b8269ba3c65e98059e1fd99d00283e42a454f0/build-0.10.0-py3-none-any.whl", - "sha256": "af266720050a66c893a6096a2f410989eeac74ff9a68ba194b3f6473e8e26171", - "type": "zip", - "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" - } - }, - "rules_python_internal": { - "bzlFile": "@@rules_python~//python/private:internal_config_repo.bzl", - "ruleClassName": "internal_config_repo", - "attributes": {} - }, - "pypi__pip": { - "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_archive", - "attributes": { - "url": "https://files.pythonhosted.org/packages/50/c2/e06851e8cc28dcad7c155f4753da8833ac06a5c704c109313b8d5a62968a/pip-23.2.1-py3-none-any.whl", - "sha256": "7ccf472345f20d35bdc9d1841ff5f313260c2c33fe417f48c30ac46cccabf5be", + "url": "https://files.pythonhosted.org/packages/cc/37/db7ba97e676af155f5fcb1a35466f446eadc9104e25b83366e8088c9c926/importlib_metadata-6.8.0-py3-none-any.whl", + "sha256": "3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb", "type": "zip", "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" } @@ -1281,6 +1201,66 @@ "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" } }, + "pypi__packaging": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "url": "https://files.pythonhosted.org/packages/ab/c3/57f0601a2d4fe15de7a553c00adbc901425661bf048f2a22dfc500caf121/packaging-23.1-py3-none-any.whl", + "sha256": "994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61", + "type": "zip", + "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" + } + }, + "pypi__pep517": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "url": "https://files.pythonhosted.org/packages/ee/2f/ef63e64e9429111e73d3d6cbee80591672d16f2725e648ebc52096f3d323/pep517-0.13.0-py3-none-any.whl", + "sha256": "4ba4446d80aed5b5eac6509ade100bff3e7943a8489de249654a5ae9b33ee35b", + "type": "zip", + "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" + } + }, + "pypi__pip": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "url": "https://files.pythonhosted.org/packages/50/c2/e06851e8cc28dcad7c155f4753da8833ac06a5c704c109313b8d5a62968a/pip-23.2.1-py3-none-any.whl", + "sha256": "7ccf472345f20d35bdc9d1841ff5f313260c2c33fe417f48c30ac46cccabf5be", + "type": "zip", + "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" + } + }, + "pypi__pip_tools": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "url": "https://files.pythonhosted.org/packages/e8/df/47e6267c6b5cdae867adbdd84b437393e6202ce4322de0a5e0b92960e1d6/pip_tools-7.3.0-py3-none-any.whl", + "sha256": "8717693288720a8c6ebd07149c93ab0be1fced0b5191df9e9decd3263e20d85e", + "type": "zip", + "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" + } + }, + "pypi__pyproject_hooks": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "url": "https://files.pythonhosted.org/packages/d5/ea/9ae603de7fbb3df820b23a70f6aff92bf8c7770043254ad8d2dc9d6bcba4/pyproject_hooks-1.0.0-py3-none-any.whl", + "sha256": "283c11acd6b928d2f6a7c73fa0d01cb2bdc5f07c57a2eeb6e83d5e56b97976f8", + "type": "zip", + "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" + } + }, + "pypi__setuptools": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "url": "https://files.pythonhosted.org/packages/4f/ab/0bcfebdfc3bfa8554b2b2c97a555569c4c1ebc74ea288741ea8326c51906/setuptools-68.1.2-py3-none-any.whl", + "sha256": "3d8083eed2d13afc9426f227b24fd1659489ec107c0e86cec2ffdde5c92e790b", + "type": "zip", + "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" + } + }, "pypi__tomli": { "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", "ruleClassName": "http_archive", @@ -1290,6 +1270,26 @@ "type": "zip", "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" } + }, + "pypi__wheel": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "url": "https://files.pythonhosted.org/packages/b8/8b/31273bf66016be6ad22bb7345c37ff350276cfd46e389a0c2ac5da9d9073/wheel-0.41.2-py3-none-any.whl", + "sha256": "75909db2664838d015e3d9139004ee16711748a52c8f336b52882266540215d8", + "type": "zip", + "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" + } + }, + "pypi__zipp": { + "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", + "ruleClassName": "http_archive", + "attributes": { + "url": "https://files.pythonhosted.org/packages/8c/08/d3006317aefe25ea79d3b76c9650afabaf6d63d1c8443b236e7405447503/zipp-3.16.2-py3-none-any.whl", + "sha256": "679e51dd4403591b2d6838a48de3d283f3d188412a9782faadf845f298736ba0", + "type": "zip", + "build_file_content": "package(default_visibility = [\"//visibility:public\"])\n\nload(\"@rules_python//python:defs.bzl\", \"py_library\")\n\npy_library(\n name = \"lib\",\n srcs = glob([\"**/*.py\"]),\n data = glob([\"**/*\"], exclude=[\n # These entries include those put into user-installed dependencies by\n # data_exclude in /python/pip_install/tools/bazel.py\n # to avoid non-determinism following pip install's behavior.\n \"**/*.py\",\n \"**/*.pyc\",\n \"**/*.pyc.*\", # During pyc creation, temp files named *.pyc.NNN are created\n \"**/* *\",\n \"**/*.dist-info/RECORD\",\n \"BUILD\",\n \"WORKSPACE\",\n ]),\n # This makes this directory a top-level in the python import\n # search path for anything that depends on this.\n imports = [\".\"],\n)\n" + } } }, "recordedRepoMappingEntries": [ diff --git a/async/meta.zig b/async/meta.zig index aba974b..b0f3e56 100644 --- a/async/meta.zig +++ b/async/meta.zig @@ -1,26 +1,86 @@ const std = @import("std"); +pub fn ArgsTuple(comptime funcT: anytype, comptime argsT: ?type) type { + const params = @typeInfo(funcT).Fn.params; + if (params.len == 0) { + return @TypeOf(.{}); + } + + if (@typeInfo(funcT).Fn.is_generic == false) { + return std.meta.ArgsTuple(funcT); + } + + const args = std.meta.fields(argsT orelse @compileError("generic function requires an explicit ArgsTuple")); + var tuple_fields: [params.len]std.builtin.Type.StructField = undefined; + inline for (params, args, 0..) |param, arg, i| { + if (param.type == null) { + tuple_fields[i] = arg; + continue; + } + const T = param.type.?; + var num_buf: [32]u8 = undefined; + tuple_fields[i] = .{ + .name = blk: { + const s = std.fmt.formatIntBuf(&num_buf, i, 10, .lower, .{}); + num_buf[s] = 0; + break :blk num_buf[0..s :0]; + }, + .type = T, + .default_value = null, + .is_comptime = false, + .alignment = if (@sizeOf(T) > 0) @alignOf(T) else 0, + }; + } + + return @Type(.{ + .Struct = .{ + .is_tuple = true, + .layout = .auto, + .decls = &.{}, + .fields = &tuple_fields, + }, + }); +} + +pub fn TupleRange(comptime T: type, comptime start: usize, comptime end: usize) type { + const fields = std.meta.fields(T); + var new_fields: [end - start]std.builtin.Type.StructField = undefined; + inline for (start..end, 0..) |i, j| { + var new_field = fields[i]; + var num_buf: [32]u8 = undefined; + new_field.name = blk: { + const s = std.fmt.formatIntBuf(&num_buf, j, 10, .lower, .{}); + num_buf[s] = 0; + break :blk num_buf[0..s :0]; + }; + new_fields[j] = new_field; + } + return @Type(.{ + .Struct = .{ + .is_tuple = true, + .layout = .auto, + .decls = &.{}, + .fields = &new_fields, + }, + }); +} + pub fn FnSignature(comptime func: anytype, comptime argsT: ?type) type { + return FnSignatureX(func, ArgsTuple(@TypeOf(func), argsT)); +} + +pub fn FnSignatureX(comptime func: anytype, comptime argsT: type) type { return struct { - pub const FuncT = if (@TypeOf(func) == type) func else @TypeOf(func); - pub const ArgsT = blk: { - if (@typeInfo(FuncT).Fn.params.len == 0) { - break :blk @TypeOf(.{}); - } - break :blk argsT orelse std.meta.ArgsTuple(FuncT); - }; + pub const FuncT = @TypeOf(func); + pub const ArgsT = argsT; pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined))); - pub const ReturnPayloadT = blk: { - break :blk switch (@typeInfo(ReturnT)) { - .ErrorUnion => |u| u.payload, - else => ReturnT, - }; + pub const ReturnPayloadT = switch (@typeInfo(ReturnT)) { + .ErrorUnion => |u| u.payload, + else => ReturnT, }; - pub const ReturnErrorSet: ?type = blk: { - break :blk switch (@typeInfo(ReturnT)) { - .ErrorUnion => |u| u.error_set, - else => null, - }; + pub const ReturnErrorSet: ?type = switch (@typeInfo(ReturnT)) { + .ErrorUnion => |u| u.error_set, + else => null, }; }; } diff --git a/async/threaded.zig b/async/threaded.zig index 51065bf..5858de6 100644 --- a/async/threaded.zig +++ b/async/threaded.zig @@ -2,16 +2,22 @@ const std = @import("std"); const xev = @import("xev"); const FnSignature = @import("meta.zig").FnSignature; +const NormalizedTuple = @import("meta.zig").NormalizedTuple; pub fn Frame(comptime func: anytype) type { const Signature = FnSignature(func, null); - return FrameEx(func, Signature.ArgsT); + return FrameExx(func, Signature); } pub fn FrameEx(comptime func: anytype, comptime argsT: type) type { + const Signature = FnSignature(func, argsT); + return FrameExx(func, Signature); +} + +pub fn FrameExx(comptime func: anytype, comptime Signature: type) type { return struct { const Self = @This(); - const Signature = FnSignature(func, argsT); + const Signature_ = Signature; const Task = struct { _task: xev.ThreadPool.Task = .{ .callback = &Self.run }, event: std.Thread.ResetEvent = .{}, @@ -27,7 +33,8 @@ pub fn FrameEx(comptime func: anytype, comptime argsT: type) type { task.event.set(); } - pub fn await_(self: *Self) Signature.ReturnT { + pub const await_ = wait; + pub fn wait(self: *Self) Signature.ReturnT { defer { AsyncThread.current.mutex.lock(); AsyncThread.current.allocator.destroy(self._task); @@ -39,11 +46,7 @@ pub fn FrameEx(comptime func: anytype, comptime argsT: type) type { }; } -pub fn asyncc(comptime func: anytype, args: FnSignature(func, null).ArgsT) !FrameEx(func, @TypeOf(args)) { - return asyncGeneric(func, args); -} - -pub fn asyncGeneric(comptime func: anytype, args: anytype) !FrameEx(func, @TypeOf(args)) { +pub fn asyncc(comptime func: anytype, args: anytype) !FrameEx(func, @TypeOf(args)) { const FrameT = FrameEx(func, @TypeOf(args)); AsyncThread.current.mutex.lock(); @@ -58,15 +61,11 @@ pub fn asyncGeneric(comptime func: anytype, args: anytype) !FrameEx(func, @TypeO return .{ ._task = task }; } -pub fn callBlocking(comptime func: anytype, args: FnSignature(func, null).ArgsT) @TypeOf(callBlockingGeneric(func, args)) { - return callBlockingGeneric(func, args); -} - -pub fn callBlockingGeneric(comptime func: anytype, args: anytype) FnSignature(func, @TypeOf(args)).ReturnT { +pub inline fn callBlocking(comptime func: anytype, args: anytype) FnSignature(func, @TypeOf(args)).ReturnT { return @call(.auto, func, args); } -pub fn sleep(ms: u64) !void { +pub inline fn sleep(ms: u64) !void { std.time.sleep(ms * std.time.ns_per_ms); } @@ -77,7 +76,7 @@ pub const AsyncThread = struct { thread_pool: xev.ThreadPool, mutex: std.Thread.Mutex, - pub fn main(allocator_: std.mem.Allocator, comptime func: anytype, args: anytype) !void { + pub fn main(allocator_: std.mem.Allocator, comptime mainFunc: anytype) !void { current = .{ .allocator = allocator_, .thread_pool = xev.ThreadPool.init(.{}), @@ -89,7 +88,7 @@ pub const AsyncThread = struct { current.thread_pool.deinit(); } - return @call(.auto, func, args); + return try mainFunc(); } }; @@ -114,15 +113,15 @@ pub const Notification = struct { } }; -pub fn StdIn() !File { +pub fn getStdIn() !File { return File.init(std.io.getStdIn()) catch @panic("Unable to open stdin"); } -pub fn StdOut() File { +pub fn getStdOut() File { return File.init(std.io.getStdOut()) catch @panic("Unable to open stdout"); } -pub fn StdErr() File { +pub fn getStdErr() File { return File.init(std.io.getStdErr()) catch @panic("Unable to open stderr"); } @@ -217,3 +216,23 @@ pub const File = struct { }; pub const Mutex = std.Thread.Mutex; + +pub fn logFn( + comptime message_level: std.log.Level, + comptime scope: @Type(.EnumLiteral), + comptime format: []const u8, + args: anytype, +) void { + const level_txt = comptime message_level.asText(); + const prefix2 = if (scope == .default) ": " else "(" ++ @tagName(scope) ++ "): "; + const stderr = getStdErr().writer(); + var bw = std.io.bufferedWriter(stderr); + const writer = bw.writer(); + + std.debug.lockStdErr(); + defer std.debug.unlockStdErr(); + nosuspend { + writer.print(level_txt ++ prefix2 ++ format ++ "\n", args) catch return; + bw.flush() catch return; + } +} diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index 9ecf289..23bf40f 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -504,13 +504,14 @@ pub const LoadedExecutable = opaque { return @ptrCast(ret.addressable_devices); } - pub fn execute(self: *const LoadedExecutable, api: *const Api, args: struct { + pub const ExecuteArgs = struct { num_args: usize, arguments: []const [*]const *const Buffer, results: []const [*]*Buffer, events: []?*Event, non_donatable_input_indices: []const i64 = &.{}, - }) ApiError!void { + }; + pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ApiError!void { var options = pjrtStruct(c.PJRT_ExecuteOptions{ .send_callbacks = null, .recv_callbacks = null, diff --git a/pjrt/profiler.zig b/pjrt/profiler.zig index e148bb5..6265e64 100644 --- a/pjrt/profiler.zig +++ b/pjrt/profiler.zig @@ -2,7 +2,7 @@ const std = @import("std"); const c = @import("c"); const tsl_proto = @import("//tsl:profiler_options_proto"); -const log = std.log.scoped(.zml_profiler); +const log = std.log.scoped(.@"zml/profiler"); /// Pjrt Profiler extension pub const Profiler = struct { diff --git a/stdx/BUILD.bazel b/stdx/BUILD.bazel new file mode 100644 index 0000000..26c190e --- /dev/null +++ b/stdx/BUILD.bazel @@ -0,0 +1,13 @@ +load("@rules_zig//zig:defs.bzl", "zig_library") + +zig_library( + name = "stdx", + srcs = [ + "debug.zig", + "math.zig", + "meta.zig", + "signature.zig", + ], + main = "stdx.zig", + visibility = ["//visibility:public"], +) diff --git a/stdx/debug.zig b/stdx/debug.zig new file mode 100644 index 0000000..4686b3b --- /dev/null +++ b/stdx/debug.zig @@ -0,0 +1,33 @@ +const std = @import("std"); + +pub inline fn guard(check: bool, src: std.builtin.SourceLocation) void { + assert(check, "Invalid inputs {s}@{s}:{d}", .{ src.file, src.fn_name, src.line }); +} + +pub inline fn internalAssert(check: bool, comptime msg: []const u8, args: anytype) void { + assert(check, "internal error: " ++ msg, args); +} + +pub inline fn assert(check: bool, comptime msg: []const u8, args: anytype) void { + if (!check) { + panic(msg, args); + } +} + +pub inline fn panic(comptime format: []const u8, args: anytype) noreturn { + std.debug.panic(format, args); +} + +pub inline fn compileLog(comptime msg: []const u8, comptime args: anytype) void { + @compileLog(std.fmt.comptimePrint(msg, args)); +} + +pub inline fn compileError(comptime msg: []const u8, comptime args: anytype) noreturn { + @compileError(std.fmt.comptimePrint(msg, args)); +} + +pub inline fn assertComptime(comptime check: bool, comptime msg: []const u8, comptime args: anytype) void { + if (check == false) { + compileError(msg, args); + } +} diff --git a/stdx/math.zig b/stdx/math.zig new file mode 100644 index 0000000..db20a5d --- /dev/null +++ b/stdx/math.zig @@ -0,0 +1,25 @@ +pub inline fn divFloor(comptime T: type, numerator: anytype, denominator: anytype) T { + return @divFloor(floatCast(T, numerator), floatCast(T, denominator)); +} + +pub inline fn divExact(comptime T: type, numerator: anytype, denominator: anytype) T { + return @divExact(floatCast(T, numerator), floatCast(T, denominator)); +} + +pub inline fn divTrunc(comptime T: type, numerator: anytype, denominator: anytype) T { + return @divTrunc(floatCast(T, numerator), floatCast(T, denominator)); +} + +pub inline fn floatCast(comptime T: type, x: anytype) T { + return switch (@typeInfo(@TypeOf(x))) { + .Float => @floatCast(x), + else => @floatFromInt(x), + }; +} + +pub inline fn intCast(comptime T: type, x: anytype) T { + return switch (@typeInfo(@TypeOf(x))) { + .Int => @intCast(x), + else => @intFromFloat(x), + }; +} diff --git a/stdx/meta.zig b/stdx/meta.zig new file mode 100644 index 0000000..9e37c04 --- /dev/null +++ b/stdx/meta.zig @@ -0,0 +1,158 @@ +const std = @import("std"); +const debug = @import("debug.zig"); + +const compileError = debug.compileError; + +pub const FnSignature = @import("signature.zig").FnSignature; + +pub fn isStruct(comptime T: type) bool { + return switch (@typeInfo(T)) { + .Struct => true, + else => false, + }; +} + +pub fn isTuple(comptime T: type) bool { + return switch (@typeInfo(T)) { + .Struct => |info| info.is_tuple, + else => false, + }; +} + +pub fn isStructOf(comptime T: type, comptime Elem: type) bool { + return switch (@typeInfo(T)) { + .Struct => |info| blk: { + inline for (info.fields) |field| { + if (field.type != Elem) { + break :blk false; + } + } + break :blk true; + }, + else => false, + }; +} + +pub fn isStructOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool { + return switch (@typeInfo(T)) { + .Struct => |info| blk: { + inline for (info.fields) |field| { + if (f(field.type) == false) { + break :blk false; + } + } + break :blk true; + }, + else => false, + }; +} + +pub fn isTupleOf(comptime T: type, comptime Elem: type) bool { + return isTuple(T) and isStructOf(T, Elem); +} + +pub fn isTupleOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool { + return isTuple(T) and isStructOfAny(T, f); +} + +pub fn isSliceOf(comptime T: type, comptime Elem: type) bool { + return switch (@typeInfo(T)) { + .Pointer => |info| switch (info.size) { + .Slice => info.child == Elem, + .One => switch (@typeInfo(info.child)) { + // As Zig, convert pointer to Array as a slice. + .Array => |arr_info| arr_info.child == Elem, + else => false, + }, + else => false, + }, + else => false, + }; +} + +pub fn isInteger(comptime T: type) bool { + return switch (@typeInfo(T)) { + .Int, .ComptimeInt => true, + else => false, + }; +} + +pub fn isSliceOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool { + return switch (@typeInfo(T)) { + .Pointer => |info| info.size == .Slice and f(info.child), + else => false, + }; +} + +pub fn DeclEnum(comptime T: type) type { + const field_infos = std.meta.declarations(T); + if (field_infos.len == 0) { + compileError("Struct {} has no declarations", .{T}); + } + return std.meta.DeclEnum(UnwrapPtr(T)); +} + +pub fn UnwrapPtr(comptime T: type) type { + return switch (@typeInfo(T)) { + .Pointer => |info| switch (info.size) { + .One => info.child, + else => T, + }, + else => T, + }; +} + +pub fn asSlice(comptime T: type) type { + const err_msg = "Type " ++ @typeName(T) ++ " can't be interpreted as a slice"; + return switch (@typeInfo(T)) { + .Pointer => |info| switch (info.size) { + .Slice => info.child, + .One => switch (@typeInfo(info.child)) { + // As Zig, convert pointer to Array as a slice. + .Array => |arr_info| arr_info.child, + else => compileError(err_msg), + }, + else => compileError(err_msg), + }, + else => compileError(err_msg), + }; +} + +pub fn TupleRange(comptime T: type, comptime start: ?usize, comptime end: ?usize) type { + return TupleRangeX(T, start orelse 0, end orelse std.meta.fields(T).len); +} + +pub fn TupleRangeX(comptime T: type, comptime start: usize, comptime end: usize) type { + const fields = std.meta.fields(T); + var new_fields: [end - start]std.builtin.Type.StructField = undefined; + inline for (start..end, 0..) |i, j| { + var new_field = fields[i]; + var num_buf: [32]u8 = undefined; + new_field.name = blk: { + const s = std.fmt.formatIntBuf(&num_buf, j, 10, .lower, .{}); + num_buf[s] = 0; + break :blk num_buf[0..s :0]; + }; + new_fields[j] = new_field; + } + return @Type(.{ + .Struct = .{ + .is_tuple = true, + .layout = .auto, + .decls = &.{}, + .fields = &new_fields, + }, + }); +} + +pub fn FnParam(comptime func: anytype, comptime n: comptime_int) type { + return @typeInfo(@TypeOf(func)).Fn.params[n].type orelse compileError("anytype is not supported"); +} + +pub fn FnArgs(comptime func: anytype) type { + return FnSignature(func, null).ArgsT; +} + +pub fn FnResult(comptime func: anytype) type { + return FnSignature(func, null).ReturnT; +} diff --git a/stdx/signature.zig b/stdx/signature.zig new file mode 100644 index 0000000..48aa4e0 --- /dev/null +++ b/stdx/signature.zig @@ -0,0 +1,65 @@ +const std = @import("std"); + +const compileError = @import("meta.zig").compileError; + +pub fn ArgsTuple(comptime funcT: anytype, comptime argsT: ?type) type { + const params = @typeInfo(funcT).Fn.params; + if (params.len == 0) { + return @TypeOf(.{}); + } + + if (@typeInfo(funcT).Fn.is_generic == false) { + return std.meta.ArgsTuple(funcT); + } + + const args = std.meta.fields(argsT orelse compileError("generic function requires an explicit ArgsTuple", .{})); + var tuple_fields: [params.len]std.builtin.Type.StructField = undefined; + inline for (params, args, 0..) |param, arg, i| { + if (param.type == null) { + tuple_fields[i] = arg; + continue; + } + const T = param.type.?; + var num_buf: [32]u8 = undefined; + tuple_fields[i] = .{ + .name = blk: { + const s = std.fmt.formatIntBuf(&num_buf, i, 10, .lower, .{}); + num_buf[s] = 0; + break :blk num_buf[0..s :0]; + }, + .type = T, + .default_value = null, + .is_comptime = false, + .alignment = if (@sizeOf(T) > 0) @alignOf(T) else 0, + }; + } + + return @Type(.{ + .Struct = .{ + .is_tuple = true, + .layout = .auto, + .decls = &.{}, + .fields = &tuple_fields, + }, + }); +} + +pub fn FnSignature(comptime func: anytype, comptime argsT: ?type) type { + return FnSignatureX(func, ArgsTuple(@TypeOf(func), argsT)); +} + +fn FnSignatureX(comptime func: anytype, comptime argsT: type) type { + return struct { + pub const FuncT = @TypeOf(func); + pub const ArgsT = argsT; + pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined))); + pub const ReturnPayloadT = switch (@typeInfo(ReturnT)) { + .ErrorUnion => |u| u.payload, + else => ReturnT, + }; + pub const ReturnErrorSet: ?type = switch (@typeInfo(ReturnT)) { + .ErrorUnion => |u| u.error_set, + else => null, + }; + }; +} diff --git a/stdx/stdx.zig b/stdx/stdx.zig new file mode 100644 index 0000000..b8447fe --- /dev/null +++ b/stdx/stdx.zig @@ -0,0 +1,3 @@ +pub const math = @import("math.zig"); +pub const meta = @import("meta.zig"); +pub const debug = @import("debug.zig"); diff --git a/zml/BUILD.bazel b/zml/BUILD.bazel index 1997182..d44cbd0 100644 --- a/zml/BUILD.bazel +++ b/zml/BUILD.bazel @@ -32,6 +32,7 @@ zig_library( "//mlir/dialects", "//pjrt", "//runtimes", + "//stdx", "//zml/tools", "@rules_zig//zig/lib:libc", "@rules_zig//zig/runfiles", diff --git a/zml/aio.zig b/zml/aio.zig index 358903b..f263b82 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -1,8 +1,10 @@ -const builtin = @import("builtin"); const asynk = @import("async"); -const std = @import("std"); -const zml = @import("zml.zig"); +const builtin = @import("builtin"); const c = @import("c"); +const std = @import("std"); +const stdx = @import("stdx"); + +const zml = @import("zml.zig"); const posix = @import("posix.zig"); pub const gguf = @import("aio/gguf.zig"); @@ -13,7 +15,7 @@ pub const tinyllama = @import("aio/tinyllama.zig"); pub const torch = @import("aio/torch.zig"); pub const yaml = @import("aio/yaml.zig"); -pub const log = std.log.scoped(.zml_aio); +pub const log = std.log.scoped(.@"zml/aio"); const HostBuffer = @import("hostbuffer.zig").HostBuffer; test { @@ -256,7 +258,11 @@ pub const MemoryMappedFile = struct { 0, }); - try asynk.callBlocking(posix.madvise, .{ data_.ptr, @intCast(data_.len), @intCast(c.MADV_SEQUENTIAL) }); + try asynk.callBlocking(posix.madvise, .{ + data_.ptr, + @as(usize, @intCast(data_.len)), + @as(u32, @intCast(c.MADV_SEQUENTIAL)), + }); return .{ .file = file, @@ -600,7 +606,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi // obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us. var buf_with_metadata = host_buffer; log.debug("Loading buffer {s} ({})", .{ prefix, obj._shape }); - zml.meta.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix }); + stdx.debug.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix }); buf_with_metadata._shape = obj._shape; obj.* = try zml.Buffer.from(platform, buf_with_metadata); } else { diff --git a/zml/aio/gguf.zig b/zml/aio/gguf.zig index 9740871..5a8e71b 100644 --- a/zml/aio/gguf.zig +++ b/zml/aio/gguf.zig @@ -8,7 +8,7 @@ const HostBuffer = @import("../hostbuffer.zig").HostBuffer; const Allocator = std.mem.Allocator; const assert = std.debug.assert; -const log = std.log.scoped(.zml_io); +const log = std.log.scoped(.@"zml/io"); pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore { var file = try core.GgufFile.open(path); diff --git a/zml/aio/gguf/core.zig b/zml/aio/gguf/core.zig index c8ed6a2..27102a2 100644 --- a/zml/aio/gguf/core.zig +++ b/zml/aio/gguf/core.zig @@ -3,7 +3,7 @@ const std = @import("std"); const zml = @import("../../zml.zig"); const assert = std.debug.assert; -const log = std.log.scoped(.zml_io); +const log = std.log.scoped(.@"zml/io"); pub const GgufErrors = error{ ValueTypeMismatch, diff --git a/zml/aio/nemo.zig b/zml/aio/nemo.zig index d90de12..c48e511 100644 --- a/zml/aio/nemo.zig +++ b/zml/aio/nemo.zig @@ -1,5 +1,5 @@ const std = @import("std"); -const log = std.log.scoped(.zml_aio); +const log = std.log.scoped(.@"zml/aio"); const asynk = @import("async"); const yaml = @import("zig-yaml"); diff --git a/zml/aio/safetensors.zig b/zml/aio/safetensors.zig index b5f3c80..e293a3f 100644 --- a/zml/aio/safetensors.zig +++ b/zml/aio/safetensors.zig @@ -7,7 +7,7 @@ const MemoryMappedFile = @import("../aio.zig").MemoryMappedFile; const StringBuilder = std.ArrayListUnmanaged(u8); const Allocator = std.mem.Allocator; -const log = std.log.scoped(.zml_io); +const log = std.log.scoped(.@"zml/io"); pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore { var res: zml.aio.BufferStore = .{ diff --git a/zml/aio/tinyllama.zig b/zml/aio/tinyllama.zig index 499aff6..f50c246 100644 --- a/zml/aio/tinyllama.zig +++ b/zml/aio/tinyllama.zig @@ -1,8 +1,8 @@ /// Tools to load models from https://huggingface.co/karpathy/tinyllamas/ /// Originally made to be run with https://github.com/karpathy/llama2.c -const std = @import("std"); - const asynk = @import("async"); +const std = @import("std"); +const stdx = @import("stdx"); const zml = @import("../zml.zig"); @@ -86,7 +86,7 @@ pub fn open(allocator: std.mem.Allocator, model_path: []const u8) !zml.aio.Buffe const weights_size = off; std.log.info("Loaded a tinyllama file of {} bytes.\nThis is the parsed configuration of this llama model: {}", .{ weights_size, c }); if (file.stat() catch null) |stat| { - zml.meta.assert(weights_size == stat.size, "Expected to have a tinyllama file of {} bytes but file only got {} !\nThis is the parsed configuration of this llama model: {}", .{ weights_size, stat.size, c }); + stdx.debug.assert(weights_size == stat.size, "Expected to have a tinyllama file of {} bytes but file only got {} !\nThis is the parsed configuration of this llama model: {}", .{ weights_size, stat.size, c }); } { diff --git a/zml/aio/torch.zig b/zml/aio/torch.zig index cde998b..ff0d116 100644 --- a/zml/aio/torch.zig +++ b/zml/aio/torch.zig @@ -7,7 +7,7 @@ const py = @import("torch/py.zig"); const File = @import("torch/file.zig").File; const StringBuilder = std.ArrayListUnmanaged(u8); -const log = std.log.scoped(.zml_aio); +const log = std.log.scoped(.@"zml/aio"); test { std.testing.refAllDecls(@This()); diff --git a/zml/aio/torch/eval.zig b/zml/aio/torch/eval.zig index d62fe3c..af4605d 100644 --- a/zml/aio/torch/eval.zig +++ b/zml/aio/torch/eval.zig @@ -1,6 +1,5 @@ const std = @import("std"); -const zml = @import("../../zml.zig"); -const meta = zml.meta; +const stdx = @import("stdx"); const py = @import("py.zig"); const pickle = @import("pickle.zig"); @@ -228,7 +227,7 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo }, } }, - .proto => |proto| meta.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}), + .proto => |proto| stdx.debug.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}), .tuple1 => try stack.append(blk: { const tup_values = try arena.alloc(py.Any, 1); tup_values[0] = try pop(&stack); diff --git a/zml/aio/torch/file.zig b/zml/aio/torch/file.zig index afe2331..b3794bb 100644 --- a/zml/aio/torch/file.zig +++ b/zml/aio/torch/file.zig @@ -1,8 +1,6 @@ -const std = @import("std"); -const testing = std.testing; -const log = std.log.scoped(.zml_aio); - const asynk = @import("async"); +const std = @import("std"); +const stdx = @import("stdx"); const zml = @import("../../zml.zig"); const pickle = @import("pickle.zig"); @@ -10,6 +8,9 @@ const py = @import("py.zig"); const eval = @import("eval.zig"); const HostBuffer = zml.HostBuffer; +const testing = std.testing; +const log = std.log.scoped(.@"zml/aio"); + // TODO(cryptodeal): use zml.aio.PrefixBuilder instead const StringBuilder = std.ArrayListUnmanaged(u8); @@ -329,7 +330,7 @@ pub const File = struct { }, .dict => { const n = @divExact(seq.values.len, 2); - log.info("found dict with {} entries", .{n}); + log.debug("found dict with {} entries", .{n}); for (0..n) |i| { const key, const val = seq.values[2 * i ..][0..2].*; switch (key) { @@ -534,7 +535,7 @@ pub const File = struct { } fn parseDims(values: []py.Any) error{InvalidInput}!zml.Shape.DimsArray { - zml.meta.assert(values.len <= zml.Tensor.MAX_RANK, "Found Pytorch tensor with unsupported rank {}", .{values.len}); + stdx.debug.assert(values.len <= zml.Tensor.MAX_RANK, "Found Pytorch tensor with unsupported rank {}", .{values.len}); var result: zml.Shape.DimsArray = .{}; for (values) |val| { switch (val) { diff --git a/zml/aio/torch/pickle.zig b/zml/aio/torch/pickle.zig index 5aeb529..ecf3968 100644 --- a/zml/aio/torch/pickle.zig +++ b/zml/aio/torch/pickle.zig @@ -1,6 +1,6 @@ const std = @import("std"); -const log = std.log.scoped(.zml_aio); +const log = std.log.scoped(.@"zml/aio"); /// All possible pickle operators. /// Reference: https://github.com/python/cpython/blob/3.13/Lib/pickletools.py diff --git a/zml/aio/torch/py.zig b/zml/aio/torch/py.zig index 452e4d3..c523f28 100644 --- a/zml/aio/torch/py.zig +++ b/zml/aio/torch/py.zig @@ -1,6 +1,6 @@ const std = @import("std"); const math = std.math; -const log = std.log.scoped(.zml_aio); +const log = std.log.scoped(.@"zml/aio"); const pickle = @import("pickle.zig"); diff --git a/zml/buffer.zig b/zml/buffer.zig index 9205d8f..8a59aad 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -1,9 +1,11 @@ +const asynk = @import("async"); const std = @import("std"); -const testing = std.testing; +const stdx = @import("stdx"); const meta = @import("meta.zig"); const pjrt = @import("pjrtx.zig"); -const asynk = @import("async"); + +const testing = std.testing; const Context = @import("context.zig").Context; const Data = @import("dtype.zig").Data; @@ -42,12 +44,12 @@ pub const Buffer = struct { // We shard only on the first axis so that the chunks are still contiguous. // TODO: support more advanced sharding specs - meta.assert(platform.sharding().num_replicas == 1, "ZML doesn't support num_replicas > 1 for now, got: {}", .{platform.sharding()}); + stdx.debug.assert(platform.sharding().num_replicas == 1, "ZML doesn't support num_replicas > 1 for now, got: {}", .{platform.sharding()}); const sharding_ax: ?u3 = std.simd.firstTrue(host_buffer.shape()._sharding_info); const n_partitions = platform.sharding().num_partitions; const chunk_size = if (sharding_ax) |ax| cs: { // This kind of sharding error should be detected earlier on. - meta.assert(@rem(host_buffer.dim(ax), n_partitions) == 0, "Buffer.from({}) expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ host_buffer, ax, n_partitions }); + stdx.debug.assert(@rem(host_buffer.dim(ax), n_partitions) == 0, "Buffer.from({}) expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ host_buffer, ax, n_partitions }); break :cs @divExact(host_buffer.dim(ax), n_partitions); } else 0; @@ -88,8 +90,8 @@ pub const Buffer = struct { /// Wraps pre-exisiting `pjrt.Buffer` shards into one `zml.Buffer`. pub fn fromPjrtBuffers(platform: Platform, shape_: Shape, pjrt_buffers: []const *pjrt.Buffer) Buffer { - meta.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len }); - meta.assert(pjrt_buffers.len > 0, "fromPjrtBuffers expects at least one buffer, got 0.", .{}); + stdx.debug.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len }); + stdx.debug.assert(pjrt_buffers.len > 0, "fromPjrtBuffers expects at least one buffer, got 0.", .{}); var shards: Shards = .{}; shards.appendSliceAssumeCapacity(pjrt_buffers); return .{ @@ -190,9 +192,9 @@ pub const Buffer = struct { /// Fetches the content of the given buffer into a stack variable of the given type. pub fn getValue(self: Buffer, T: type) !T { - meta.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) }); + stdx.debug.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) }); var res: T = undefined; - meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); + stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); const maybe_event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res)); if (maybe_event) |event| { try event.await_(self._api); @@ -204,7 +206,7 @@ pub const Buffer = struct { /// and return a new `HostBuffer` object with the same shape. /// The returned `HostBuffer` doesn't own the memory. pub fn toHost(self: Buffer, output: []u8) !HostBuffer { - meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); + stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); const maybe_event = try self._shards.get(0).toHostBuffer(self._api, output); if (maybe_event) |event| { try event.await_(self._api); @@ -216,7 +218,7 @@ pub const Buffer = struct { /// The returned `HostBuffer` does own the memory. pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer { const output = try HostBuffer.empty(allocator, self.shape()); - meta.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); + stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); const maybe_event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data)); if (maybe_event) |event| { try event.await_(self._api); diff --git a/zml/context.zig b/zml/context.zig index 5997cf8..9e8c14d 100644 --- a/zml/context.zig +++ b/zml/context.zig @@ -1,22 +1,22 @@ const builtin = @import("builtin"); -const std = @import("std"); -const mlir = @import("mlir"); const c = @import("c"); +const mlir = @import("mlir"); const runfiles = @import("runfiles"); const runtimes = @import("runtimes"); +const std = @import("std"); +const stdx = @import("stdx"); const platform = @import("platform.zig"); const pjrt = @import("pjrtx.zig"); -const available_targets = @import("platform.zig").available_targets; const HostBuffer = @import("hostbuffer.zig").HostBuffer; -const Target = @import("platform.zig").Target; -const Platform = @import("platform.zig").Platform; - -const log = std.log.scoped(.zml); - const PjrtApiMap = std.EnumArray(Target, ?*const pjrt.Api); +const Platform = @import("platform.zig").Platform; const PlatformsMap = std.EnumArray(Target, ?Platform); +const Target = @import("platform.zig").Target; + +const available_targets = @import("platform.zig").available_targets; +const log = std.log.scoped(.@"zml/context"); /// Every program using ZML must start with a `zml.Context.init(.{});` /// The ZML context contains global state to interact with the different @@ -145,6 +145,36 @@ pub const Context = struct { return platform_ orelse @panic("No platform found !"); } + pub fn printAvailablePlatforms(self: Context, selected: platform.Platform) void { + // List available targets + log.info("Available Platforms:", .{}); + const selected_prefix = "✅"; + const not_selected_prefix = "• "; + const selected_postfix = "(AUTO-SELECTED)"; + const not_selected_postfix = ""; + + for (platform.available_targets) |target| { + log.info(" {s} {s} {s}", .{ + if (target == selected.target) selected_prefix else not_selected_prefix, + @tagName(target), + if (target == selected.target) selected_postfix else not_selected_postfix, + }); + + // now the platform's devices + if (self.platforms.get(target)) |pfm| { + for (pfm.getDevices(), 0..) |device, index| { + const deviceKind = device.getDescription(pfm.pjrt_api).getKind(pfm.pjrt_api); + log.info(" ◦ #{d}: {s}", .{ + index, + deviceKind, + }); + // we only list 1 CPU device + if (target == .cpu) break; + } + } + } + } + pub const HostCallbackCtx = struct { host: HostBuffer, mutex: std.Thread.Mutex = std.Thread.Mutex{}, diff --git a/zml/helpers.zig b/zml/helpers.zig index 2856bf2..c756825 100644 --- a/zml/helpers.zig +++ b/zml/helpers.zig @@ -6,7 +6,7 @@ const Shape = @import("shape.zig").Shape; const Tensor = @import("tensor.zig").Tensor; const EnumLiteral = @TypeOf(.enum_literal); -const log = std.log.scoped(.zml_tensor); +const log = std.log.scoped(.@"zml/tensor"); test { std.testing.refAllDecls(@This()); diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig index 863a7c6..aed629e 100644 --- a/zml/hostbuffer.zig +++ b/zml/hostbuffer.zig @@ -1,6 +1,6 @@ const std = @import("std"); +const stdx = @import("stdx"); -const meta = @import("meta.zig"); const Buffer = @import("buffer.zig").Buffer; const Data = @import("dtype.zig").Data; const DataType = @import("dtype.zig").DataType; @@ -108,13 +108,13 @@ pub const HostBuffer = struct { /// The memory is initialized with increasing numbers. /// The caller owns the memory, and need to call `deinit()`. pub fn arange(allocator: std.mem.Allocator, args: ArangeArgs, dt: DataType) !HostBuffer { - meta.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end }); - meta.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step}); + stdx.debug.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end }); + stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step}); const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable; const b = dt.sizeOf(); const res = try empty(allocator, Shape.init(.{n_steps}, dt)); - meta.assert(dt.class() == .integer, "arange expects type to be integer, got {} instead.", .{dt}); + stdx.debug.assert(dt.class() == .integer, "arange expects type to be integer, got {} instead.", .{dt}); var data_ = @constCast(res.data); switch (dt) { inline else => { @@ -201,7 +201,7 @@ pub const HostBuffer = struct { } pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer { - meta.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {}", .{self}); + stdx.debug.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {}", .{self}); var res = self; res._shape = self._shape.reshape(shape_); return res; @@ -219,9 +219,9 @@ pub const HostBuffer = struct { const start: i64 = if (s.start < 0) s.start + d else s.start; var end = s.end orelse d; if (end < 0) end += d; - meta.assert(start >= 0 and start < d, "slice1d({}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, start }); - meta.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, end }); - meta.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({})", .{ self, ax, start, end }); + stdx.debug.assert(start >= 0 and start < d, "slice1d({}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, start }); + stdx.debug.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, end }); + stdx.debug.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({})", .{ self, ax, start, end }); // If strides weren't set it means original buffer is contiguous. // But it won't be anymore after slicing. The strides don't change though. diff --git a/zml/meta.zig b/zml/meta.zig index 826e159..0363802 100644 --- a/zml/meta.zig +++ b/zml/meta.zig @@ -1,4 +1,8 @@ const std = @import("std"); +const stdx = @import("stdx"); + +const FnParam = stdx.meta.FnParam; +const asSlice = stdx.meta.asSlice; const testing = std.testing; @@ -6,215 +10,6 @@ test { std.testing.refAllDecls(@This()); } -/// Computes floating point value division between two integers. -pub fn divFloat(T: type, numerator: anytype, denominator: anytype) T { - return toFloat(T, numerator) / toFloat(T, denominator); -} - -fn toFloat(T: type, x: anytype) T { - return switch (@typeInfo(@TypeOf(x))) { - .Float => @floatCast(x), - else => @floatFromInt(x), - }; -} - -pub fn guard(check: bool, src: std.builtin.SourceLocation) void { - assert(check, "Invalid inputs {s}@{s}:{d}", .{ src.file, src.fn_name, src.line }); -} - -pub inline fn internalAssert(check: bool, comptime msg: []const u8, args: anytype) void { - assert(check, "ZML internal error: " ++ msg, args); -} - -pub fn assert(check: bool, comptime msg: []const u8, args: anytype) void { - if (!check) panic(msg, args); -} - -pub fn panic(comptime msg: []const u8, args: anytype) noreturn { - std.log.err(msg, args); - @panic(msg); -} - -pub fn compileLog(comptime msg: []const u8, comptime args: anytype) void { - @compileLog(std.fmt.comptimePrint(msg, args)); -} - -pub fn compileError(comptime msg: []const u8, comptime args: anytype) noreturn { - @compileError(std.fmt.comptimePrint(msg, args)); -} - -pub fn assertComptime(comptime check: bool, comptime msg: []const u8, comptime args: anytype) void { - if (check == false) { - compileError(msg, args); - } -} - -pub fn isStruct(comptime T: type) bool { - return switch (@typeInfo(T)) { - .Struct => true, - else => false, - }; -} - -pub fn isTuple(comptime T: type) bool { - return switch (@typeInfo(T)) { - .Struct => |info| info.is_tuple, - else => false, - }; -} - -pub fn isStructOf(comptime T: type, comptime Elem: type) bool { - return switch (@typeInfo(T)) { - .Struct => |info| blk: { - inline for (info.fields) |field| { - if (field.type != Elem) { - break :blk false; - } - } - break :blk true; - }, - else => false, - }; -} - -pub fn isStructOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool { - return switch (@typeInfo(T)) { - .Struct => |info| blk: { - inline for (info.fields) |field| { - if (f(field.type) == false) { - break :blk false; - } - } - break :blk true; - }, - else => false, - }; -} - -pub fn isTupleOf(comptime T: type, comptime Elem: type) bool { - return isTuple(T) and isStructOf(T, Elem); -} - -pub fn isTupleOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool { - return isTuple(T) and isStructOfAny(T, f); -} - -pub fn isSliceOf(comptime T: type, comptime Elem: type) bool { - return switch (@typeInfo(T)) { - .Pointer => |info| switch (info.size) { - .Slice => info.child == Elem, - .One => switch (@typeInfo(info.child)) { - // As Zig, convert pointer to Array as a slice. - .Array => |arr_info| arr_info.child == Elem, - else => false, - }, - else => false, - }, - else => false, - }; -} - -pub fn asSlice(comptime T: type) type { - const err_msg = "Type " ++ @typeName(T) ++ " can't be interpreted as a slice"; - return switch (@typeInfo(T)) { - .Pointer => |info| switch (info.size) { - .Slice => info.child, - .One => switch (@typeInfo(info.child)) { - // As Zig, convert pointer to Array as a slice. - .Array => |arr_info| arr_info.child, - else => @compileError(err_msg), - }, - else => @compileError(err_msg), - }, - else => @compileError(err_msg), - }; -} - -pub fn isInteger(comptime T: type) bool { - return switch (@typeInfo(T)) { - .Int, .ComptimeInt => true, - else => false, - }; -} - -pub fn isSliceOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool { - return switch (@typeInfo(T)) { - .Pointer => |info| info.size == .Slice and f(info.child), - else => false, - }; -} - -pub fn DeclEnum(comptime T: type) type { - const field_infos = std.meta.declarations(T); - if (field_infos.len == 0) compileError("Struct {} has no declarations", .{T}); - return std.meta.DeclEnum(UnwrapPtr(T)); -} - -pub fn UnwrapPtr(comptime T: type) type { - return switch (@typeInfo(T)) { - .Pointer => |info| switch (info.size) { - .One => info.child, - else => T, - }, - else => T, - }; -} - -pub fn FnParam(func: anytype, n: comptime_int) type { - return @typeInfo(@TypeOf(func)).Fn.params[n].type orelse @compileError("anytype not supported in callbacks"); -} - -pub fn FnParams(func: anytype) type { - return std.meta.ArgsTuple(@TypeOf(func)); -} - -pub fn FnResult(func: anytype) type { - return @typeInfo(@TypeOf(func)).Fn.return_type.?; -} - -pub fn FnResultPayload(func: anytype) type { - const return_type = @typeInfo(@TypeOf(func)).Fn.return_type.?; - const payload_type = switch (@typeInfo(return_type)) { - .ErrorUnion => |u| u.payload, - else => return_type, - }; - return payload_type; -} - -pub fn FnResultErrorSet(func: anytype) ?type { - const return_type = @typeInfo(@TypeOf(func)).Fn.return_type.?; - const error_set = switch (@typeInfo(return_type)) { - .ErrorUnion => |u| u.error_set, - else => null, - }; - return error_set; -} - -pub fn Signature(comptime func: anytype, comptime argsT: ?type) type { - return struct { - pub const FuncT = if (@TypeOf(func) == type) func else @TypeOf(func); - pub const ArgsT = blk: { - if (@typeInfo(FuncT).Fn.params.len == 0) { - break :blk @TypeOf(.{}); - } - break :blk argsT orelse std.meta.ArgsTuple(FuncT); - }; - pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined))); - pub const ReturnPayloadT = blk: { - break :blk switch (@typeInfo(ReturnT)) { - .ErrorUnion => |u| u.payload, - else => ReturnT, - }; - }; - pub const ReturnErrorSet: ?type = blk: { - break :blk switch (@typeInfo(ReturnT)) { - .ErrorUnion => |u| u.error_set, - else => null, - }; - }; - }; -} - pub fn MapType(From: type, To: type) type { return struct { pub fn map(T: type) type { @@ -299,7 +94,7 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam const type_info_to_ptr = @typeInfo(@TypeOf(to)); if (type_info_to_ptr != .Pointer) { - @compileError("convertType is expecting a mutable `to` argument but received: " ++ @typeName(@TypeOf(to))); + stdx.debug.compileError("convertType is expecting a mutable `to` argument but received: " ++ @typeName(@TypeOf(to))); } const ToStruct = type_info_to_ptr.Pointer.child; const type_info_to = @typeInfo(ToStruct); @@ -348,7 +143,7 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam } else if (field.default_value) |_| { @field(to, field.name) = null; } else { - compileError("Mapping {} to {} failed. Missing field {s}", .{ FromStruct, ToStruct, field.name }); + stdx.meta.compileError("Mapping {} to {} failed. Missing field {s}", .{ FromStruct, ToStruct, field.name }); }, else => @field(to, field.name) = @field(from, field.name), } @@ -374,7 +169,7 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam } to.* = items; }, - else => @compileError("zml.meta.mapAlloc doesn't support: " ++ @typeName(FromStruct)), + else => stdx.meta.compileError("zml.meta.mapAlloc doesn't support: " ++ @typeName(FromStruct)), }, .Optional => if (from) |f| { to.* = @as(@typeInfo(type_info_to_ptr.Pointer.child).Optional.child, undefined); @@ -383,7 +178,7 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam to.* = null; }, .Int, .Float => to.* = from, - else => @compileError("zml.meta.mapAlloc doesn't support: " ++ @typeName(FromStruct)), + else => stdx.meta.compileError("zml.meta.mapAlloc doesn't support: " ++ @typeName(FromStruct)), } } @@ -444,12 +239,12 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void { const type_info_v = @typeInfo(T); const K = switch (@typeInfo(FnParam(cb, 1))) { .Pointer => |info| info.child, - else => @compileError("zml.meta.visit is expecting a pointer value as second parameter in callback to use but found " ++ @typeName(FnParam(cb, 1))), + else => stdx.meta.compileError("zml.meta.visit is expecting a pointer value as second parameter in callback to use but found " ++ @typeName(FnParam(cb, 1))), }; if (type_info_v != .Pointer) { const Callback = @TypeOf(cb); - @compileError("zml.meta.visit is expecting a pointer input to go with following callback signature: " ++ @typeName(Callback) ++ " but received: " ++ @typeName(T)); + stdx.meta.compileError("zml.meta.visit is expecting a pointer input to go with following callback signature: " ++ @typeName(Callback) ++ " but received: " ++ @typeName(T)); } const ptr_info = type_info_v.Pointer; if (@typeInfo(ptr_info.child) == .Fn) return; @@ -512,7 +307,7 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void { } } }, - else => @compileError("Only single pointer and slice are supported. Received " ++ @typeName(T)), + else => stdx.meta.compileError("Only single pointer and slice are supported. Received " ++ @typeName(T)), } } @@ -601,10 +396,8 @@ test visit { /// Only T elements of values will be looked at. /// This only works for simple types, in particular `zip` doesn't follow pointers. /// Which means that zip only allocate temp memory, and nothing need to be freed after the call. -pub fn zip(func: anytype, allocator: std.mem.Allocator, values: anytype, args: anytype) error{OutOfMemory}!asSlice(@TypeOf(values)) { +pub fn zip(comptime func: anytype, allocator: std.mem.Allocator, values: anytype, args: anytype) error{OutOfMemory}!asSlice(@TypeOf(values)) { const sliceT = @typeInfo(FnParam(func, 0)); - assertComptime(sliceT == .Pointer and sliceT.Pointer.size == .Slice and sliceT.Pointer.child == FnResult(func), "zip requires a `fn([]const T, Args) T`, received: {}", .{@TypeOf(func)}); - const T = sliceT.Pointer.child; const V = asSlice(@TypeOf(values)); if (V == T) { @@ -613,13 +406,13 @@ pub fn zip(func: anytype, allocator: std.mem.Allocator, values: anytype, args: a // const fn_args return switch (@typeInfo(V)) { - .Pointer => @compileError("zip only accept by value arguments. Received: " ++ @typeName(V)), + .Pointer => stdx.meta.compileError("zip only accept by value arguments. Received: " ++ @typeName(V)), .Struct => |struct_info| { var out: V = values[0]; inline for (struct_info.fields) |f| { if (f.is_comptime) continue; if (@typeInfo(f.type) == .Pointer) { - @compileError("zip doesn't follow pointers and don't accept struct containing them. Received: " ++ @typeName(V)); + stdx.meta.compileError("zip doesn't follow pointers and don't accept struct containing them. Received: " ++ @typeName(V)); } var fields = try allocator.alloc(f.type, values.len); defer allocator.free(fields); @@ -632,7 +425,7 @@ pub fn zip(func: anytype, allocator: std.mem.Allocator, values: anytype, args: a }, .Array => |arr_info| { if (@typeInfo(arr_info.child) == .Pointer) { - @compileError("zip doesn't follow pointers and don't accept struct containing them. Received: " ++ @typeName(V)); + stdx.meta.compileError("zip doesn't follow pointers and don't accept struct containing them. Received: " ++ @typeName(V)); } var out: V = undefined; var slice = try allocator.alloc(arr_info.child, values.len); @@ -645,7 +438,7 @@ pub fn zip(func: anytype, allocator: std.mem.Allocator, values: anytype, args: a } return out; }, - .Union, .Optional => @compileError("zip doesn't yet support " ++ @typeName(V)), + .Union, .Optional => stdx.meta.compileError("zip doesn't yet support " ++ @typeName(V)), else => values[0], }; } @@ -668,11 +461,11 @@ test zip { /// Given a func(X) -> Y or a func(Ctx, X) -> Y, /// finds all X in the given object, and write the result of func(X) into an arraylist. -pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(FnResult(func)), obj: anytype) error{OutOfMemory}!void { - assertComptime(@typeInfo(@TypeOf(func)).Fn.params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)}); +pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT), obj: anytype) error{OutOfMemory}!void { + stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).Fn.params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)}); const LocalContext = struct { func_ctx: _CollectCtx(func), - out: *std.ArrayList(FnResult(func)), + out: *std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT), oom: bool = false, }; var context = LocalContext{ .func_ctx = func_ctx, .out = out }; @@ -691,10 +484,10 @@ pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(F fn _CollectCtx(func: anytype) type { const params = @typeInfo(@TypeOf(func)).Fn.params; if (params.len == 1) return void; - return params[0].type orelse @compileError("anytype not supported in collect"); + return params[0].type orelse stdx.meta.compileError("anytype not supported in collect"); } fn _CollectArg(func: anytype) type { const params = @typeInfo(@TypeOf(func)).Fn.params; - return params[params.len - 1].type orelse @compileError("anytype not supported in collect"); + return params[params.len - 1].type orelse stdx.meta.compileError("anytype not supported in collect"); } diff --git a/zml/mlir.zig b/zml/mlir.zig index b90763d..1b0c0a5 100644 --- a/zml/mlir.zig +++ b/zml/mlir.zig @@ -2,14 +2,14 @@ const mlir = @This(); const builtin = @import("builtin"); const std = @import("std"); +const stdx = @import("stdx"); const dtype = @import("dtype.zig"); -const meta = @import("meta.zig"); const Shape = @import("shape.zig").Shape; const Tensor = @import("tensor.zig").Tensor; -const log = std.log.scoped(.zml_mlir); +const log = std.log.scoped(.@"zml/mlir"); pub usingnamespace @import("mlir"); @@ -128,7 +128,7 @@ pub const ext = struct { } } - meta.panic("Could not convert mlir.Type to DataType: {}", .{mlir_type}); + stdx.debug.panic("Could not convert mlir.Type to DataType: {}", .{mlir_type}); } }; @@ -148,7 +148,7 @@ pub const ext = struct { const int_attr = mlir.IntegerAttribute(int_type).init(ctx, @intCast(val)); return int_attr.as(mlir.Attribute).?; }, - inline else => |_, tag| meta.panic("Unsupported data type: {any}", .{tag}), + inline else => |_, tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}), } } }; @@ -169,7 +169,7 @@ pub const ext = struct { .f16 => mlir.DenseIntOrFPElementsAttribute(.f16).init(result_type, data.constSlice()).as(mlir.Attribute).?, .f32 => mlir.DenseIntOrFPElementsAttribute(.f32).init(result_type, data.constSlice()).as(mlir.Attribute).?, .f64 => mlir.DenseIntOrFPElementsAttribute(.f64).init(result_type, data.constSlice()).as(mlir.Attribute).?, - inline else => |tag| meta.panic("Unsupported data type: {any}", .{tag}), + inline else => |tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}), }; } }; diff --git a/zml/module.zig b/zml/module.zig index bf4976f..c7cbfc5 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -1,32 +1,31 @@ +const asynk = @import("async"); const builtin = @import("builtin"); -const std = @import("std"); - +const dialect = @import("mlir/dialects"); +const protobuf = @import("io/protobuf"); const runfiles = @import("runfiles"); - +const std = @import("std"); +const stdx = @import("stdx"); const xla_pb = @import("//xla:xla_proto"); + const meta = @import("meta.zig"); const mlir = @import("mlir.zig"); const ops = @import("ops.zig"); const pjrt = @import("pjrtx.zig"); -const protobuf = @import("io/protobuf"); -const asynk = @import("async"); const aio = @import("aio.zig"); -const dialect = @import("mlir/dialects"); - -const assert = std.debug.assert; +const Buffer = @import("buffer.zig").Buffer; +const Bufferized = @import("tensor.zig").Bufferized; const Context = @import("context.zig").Context; const Location = mlir.Location; const Platform = @import("platform.zig").Platform; +const Shape = @import("shape.zig").Shape; +const ShapeOf = @import("tensor.zig").ShapeOf; const Target = @import("platform.zig").Target; const Tensor = @import("tensor.zig").Tensor; -const ShapeOf = @import("tensor.zig").ShapeOf; -const Shape = @import("shape.zig").Shape; -const Buffer = @import("buffer.zig").Buffer; -const Bufferized = @import("tensor.zig").Bufferized; const Tracer = @import("tools/tracer.zig").Tracer; -const log = std.log.scoped(.zml_module); +const assert = std.debug.assert; +const log = std.log.scoped(.@"zml/module"); test { std.testing.refAllDecls(@This()); @@ -101,7 +100,7 @@ pub const CompilationContext = struct { } pub fn deactivate(self: *CompilationContext) void { - std.debug.assert(_current != null and _current.? == self); + assert(_current != null and _current.? == self); _current = self._previous; self._previous = null; } @@ -163,7 +162,7 @@ pub const CompilationContext = struct { // So we create a copy of the arguments, and replace values // by the block arguments. var blk_args = args; - assert(assignBlockArguments(&blk_args, block, 0) == N); + std.debug.assert(assignBlockArguments(&blk_args, block, 0) == N); const loc = self.mlirCtx().location(@src()); const block_res = @call(.auto, func, S.blkArgs(blkctx, blk_args)); @@ -209,9 +208,9 @@ pub const CompilationContext = struct { var input_shapes = try std.ArrayList(Shape).initCapacity(arena, tensor_count); meta.collect(Tensor.shape, {}, &input_shapes, model) catch unreachable; - meta.internalAssert(input_shapes.items.len == model_tensor_count, "model has changed ?", .{}); + stdx.debug.internalAssert(input_shapes.items.len == model_tensor_count, "model has changed ?", .{}); meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable; - meta.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{}); + stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{}); const input_types = try arena.alloc(mlir.Type, tensor_count); for (input_types, input_shapes.items) |*t, sh| t.* = mlir.ext.mlirType(mlir_ctx, sh); @@ -311,7 +310,7 @@ pub const CompilationContext = struct { // This will break the day we writer another attribute before donation. // When the time come, do a more fancy lookup here to check if an argument // is donated twice. - meta.assert(attributes[a].len == 0, "Donation error ! Argument {} has been donated twice ! To {} and to {}", .{ a, index, attributes[a].buffer[0] }); + stdx.debug.assert(attributes[a].len == 0, "Donation error ! Argument {} has been donated twice ! To {} and to {}", .{ a, index, attributes[a].buffer[0] }); attributes[a].appendAssumeCapacity( mlir.NamedAttribute.init( mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"), @@ -507,7 +506,7 @@ pub const CompilationContext = struct { extractValues(args, values[function.n_model..]); const op = dialect.func.call(self.mlirCtx(), function.name, values, function.res_types, loc); - var res: meta.FnResult(func) = undefined; + var res: stdx.meta.FnResult(func) = undefined; assignResults(&res, function.res_shapes, op); return res; } @@ -531,7 +530,7 @@ pub const CompilationContext = struct { const res = ctx.self._buffer_to_arg.getOrPutAssumeCapacity(tensor._id); if (res.found_existing) { - std.debug.panic("Failed compilation because received two tensors arguments with the same ID: {} and {}({}).", .{ res.key_ptr.*, tensor, tensor._id }); + stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {} and {}({}).", .{ res.key_ptr.*, tensor, tensor._id }); } else { res.value_ptr.* = .{ arg_value, .{ .arg = @intCast(ctx.index) } }; } @@ -677,9 +676,9 @@ fn fillBuffers(v: anytype, buffers: []const [*]*pjrt.Buffer, start: u32, len: u3 }; meta.visit((struct { fn cb(ctx: *LocalContext, buffer: *const Buffer) void { - // meta.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index }); + // stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index }); const model_sharding = ctx.buffers.len; - meta.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len }); + stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len }); for (buffer._shards.constSlice(), 0..) |shard, d| { ctx.buffers[d][ctx.index] = shard; } @@ -718,7 +717,7 @@ pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjr buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.buffer_shapes[i], shards.constSlice()); } }).cb, &local_ctx, v); - meta.internalAssert(local_ctx.index == expected_count, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index }); + stdx.debug.internalAssert(local_ctx.index == expected_count, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index }); } /// Visit the given struct and assign op results to each tensor found. @@ -761,6 +760,13 @@ const BaseExe = struct { num_devices: u8, /// Allocator backing result_buffer_shapes and deinit by ExeWithWeights _allocator: std.heap.ArenaAllocator, + + pub fn serialize(self: BaseExe, writer: anytype) !void { + var executable = try self.exe.getExecutable(self.pjrt_api); + var serialize_result = try executable.serialize(self.platform.pjrt_api); + defer serialize_result.deinit(); + try writer.writeAll(serialize_result.bytes); + } }; /// Represents a ZML model, compiled into a PJRT executable. @@ -779,6 +785,16 @@ pub fn Exe(comptime func: anytype) type { pub fn prepare(self: Self, allocator: std.mem.Allocator, model: Bufferized(Signature.ModelT)) !ExeWithWeights(func) { return ExeWithWeights(func).initFromModel(allocator, self.inner, model); } + + pub fn serialize(self: Self, writer: anytype) !void { + return try self.inner.serialize(writer); + } + + // pub fn deserialize(allocator: std.mem.Allocator, platform: Platform, reader: anytype) !Self { + // const bytes = try reader.readToEndAlloc(allocator, max_pjrt_executable_size); + // defer allocator.free(bytes); + // return platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, bytes); + // } }; } @@ -906,7 +922,7 @@ fn compileInternal( var timer = std.time.Timer.start() catch null; const tensor_args = context.tensorFromShapes(ModuleSignature(func).ArgsT, arena, args); // Run in a dedicated thread because compilation relies on `threadlocal`. - const f = try asynk.callBlockingGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args }); + const f = try asynk.callBlocking(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args }); context._module.getBody().appendOperation(f.mlir_fn); const sharding = context._platform.sharding(); @@ -927,7 +943,7 @@ fn compileInternal( if (timer) |*t| { const time_ms = @divFloor(t.lap(), std.time.ns_per_ms); - if (time_ms > 1000) log.info("Compilation took {d:.3}s", .{meta.divFloat(f32, time_ms, 1000)}); + if (time_ms > 1000) log.info("Compilation took {d:.3}s", .{stdx.math.divFloor(f32, time_ms, 1000)}); } var arena_state_exe = std.heap.ArenaAllocator.init(allocator); @@ -945,12 +961,7 @@ fn compileInternal( }; } -/// Compiles a Model struct with the given configuration and shapes, for the given platform. -/// The steps are: -/// * lookup at tensors available in the store and create a `model: Model` struct with them -/// * call `model.init(init_args)` to fields of the model that aren't Tensor, ie hyperparemeters/config -/// * generate MLIR by calling `model.forward` with tensor of the given shapes and other arguments -pub fn compile( +pub fn load( allocator: std.mem.Allocator, comptime Model: type, init_args: anytype, @@ -973,33 +984,62 @@ pub fn compile( return compileModel(allocator, model, func, args_shapes, platform); } +/// Compiles a Model struct with the given configuration and shapes, for the given platform. +/// The steps are: +/// * lookup at tensors available in the store and create a `model: Model` struct with them +/// * call `model.init(init_args)` to fields of the model that aren't Tensor, ie hyperparemeters/config +/// * generate MLIR by calling `model.forward` with tensor of the given shapes and other arguments +pub fn compile( + allocator: std.mem.Allocator, + comptime func: anytype, + init_args: anytype, + args_shapes: ShapeOf(ModuleSignature(func).ArgsT), + buffer_store: aio.BufferStore, + platform: Platform, +) !Exe(func) { + const ModelT = ModuleSignature(func).ModelT; + + var arena_state = std.heap.ArenaAllocator.init(allocator); + defer arena_state.deinit(); + const arena = arena_state.allocator(); + var model = try aio.populateModel(ModelT, arena, buffer_store); + + // If the Model has a "init" function, call it with the given parameters. + if (@hasDecl(ModelT, "init")) { + // TODO(Corentin,@Improvement): Add a warning/error if there is no init function but init_args is non-void. + @call(.auto, ModelT.init, .{@as(*ModelT, &model)} ++ init_args); + } + + return compileModel(allocator, func, model, args_shapes, platform); +} + /// Compiles a Model struct with the given configuration and shapes, for the given platform. /// Generate MLIR by calling `model.forward` with tensor of the given shapes and other arguments pub fn compileModel( allocator: std.mem.Allocator, - model: anytype, - comptime func: @TypeOf(.literal), - args_shapes: ShapeOf(ModuleSignature(@field(@TypeOf(model), @tagName(func))).ArgsT), + comptime func: anytype, + model: ModuleSignature(func).ModelT, + args_shapes: ShapeOf(ModuleSignature(func).ArgsT), platform: Platform, -) !Exe(@field(@TypeOf(model), @tagName(func))) { - const Model = @TypeOf(model); - const name = @typeName(Model) ++ "." ++ @tagName(func); +) !Exe(func) { + const ModelT = ModuleSignature(func).ModelT; + const name = @typeName(ModelT) ++ ".forward"; log.info("Compiling {s} with {}", .{ name, args_shapes }); var context = try CompilationContext.init(allocator, name, platform); defer context.deinit(); - const raw_module = try compileInternal(allocator, &context, @field(Model, @tagName(func)), model, args_shapes); + const raw_module = try compileInternal(allocator, &context, func, model, args_shapes); - return Exe(@field(Model, @tagName(func))){ .inner = raw_module }; + return .{ .inner = raw_module }; } /// Compiles a function with the given configuration and shapes, for the given platform. /// Generate MLIR by calling the given function with tensor of the given shapes. pub fn compileFn( allocator: std.mem.Allocator, - func: anytype, - args: ShapeOf(meta.FnParams(func)), + comptime func: anytype, + args: ShapeOf(stdx.meta.FnArgs(func)), platform: Platform, ) !ExeWithWeights(FnWithVoidArg(func)) { const name = @typeName(@TypeOf(func)); @@ -1008,7 +1048,7 @@ pub fn compileFn( const Local = struct { // This is the function we will actually compile. - pub fn forward(_: void, inner_args: meta.FnParams(func)) meta.FnResult(func) { + pub fn forward(_: void, inner_args: stdx.meta.FnArgs(func)) stdx.meta.FnResult(func) { return @call(.auto, func, inner_args); } }; @@ -1019,10 +1059,10 @@ pub fn compileFn( return try ExeWithWeights(FnWithVoidArg(func)).initFromModel(allocator, raw_module, void_model); } -fn FnWithVoidArg(func: anytype) type { +fn FnWithVoidArg(comptime func: anytype) type { const fn_info = @typeInfo(@TypeOf(func)).Fn; const void_param = std.builtin.Type.Fn.Param{ .is_generic = false, .is_noalias = false, .type = void }; - meta.assertComptime(!fn_info.is_generic, "Can't do reflection on generic function: {}", .{@TypeOf(func)}); + stdx.debug.assertComptime(!fn_info.is_generic, "Can't do reflection on generic function: {}", .{@TypeOf(func)}); return @Type(.{ .Fn = .{ .calling_convention = fn_info.calling_convention, .is_generic = false, @@ -1268,10 +1308,10 @@ pub fn ModuleSignature(comptime func: anytype) Sign { const FuncT = if (@TypeOf(func) == type) func else @TypeOf(func); return .{ .FuncT = FuncT, - .ModelT = @typeInfo(FuncT).Fn.params[0].type orelse @compileError("cannot create,ModuleSignature for function with an 'anytype' parameter"), + .ModelT = @typeInfo(FuncT).Fn.params[0].type orelse @compileError("cannot create ModuleSignature for function with an 'anytype' parameter"), .ArgsT = blk: { const function_info = @typeInfo(FuncT); - if (function_info.Fn.params[1..].len == 0) { + if (function_info.Fn.params.len < 2) { break :blk @TypeOf(.{}); } diff --git a/zml/nn.zig b/zml/nn.zig index 44a9065..0a0abd3 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -1,20 +1,20 @@ //! Common layer definition and functions for Neural Networks (NN) const std = @import("std"); -const assert = std.debug.assert; -const testing = std.testing; +const stdx = @import("stdx"); -const zml = @import("zml.zig"); -const meta = @import("meta.zig"); +const cuda = @import("nn/cuda.zig"); const helpers = @import("helpers.zig"); +const meta = @import("meta.zig"); const ops = @import("ops.zig"); +const zml = @import("zml.zig"); const DataType = @import("dtype.zig").DataType; const Shape = @import("shape.zig").Shape; const Tensor = @import("tensor.zig").Tensor; -const log = std.log.scoped(.zml_tensor); - -const cuda = @import("nn/cuda.zig"); +const assert = std.debug.assert; +const log = std.log.scoped(.@"zml/tensor"); +const testing = std.testing; test { _ = cuda; @@ -41,8 +41,8 @@ pub const TokenEmbedding = struct { weight: Tensor, pub fn forward(self: TokenEmbedding, idx: Tensor) Tensor { - meta.assert(idx.dtype().isInteger(), "TokenEmbedding expects an integer input, received: {}", .{idx}); - meta.assert(self.weight.rank() == 2, "TokenEmbedding expects it's weight Tensor to be a 2D matrix, got {}", .{self.weight}); + stdx.debug.assert(idx.dtype().isInteger(), "TokenEmbedding expects an integer input, received: {}", .{idx}); + stdx.debug.assert(self.weight.rank() == 2, "TokenEmbedding expects it's weight Tensor to be a 2D matrix, got {}", .{self.weight}); return self.weight.gatherValues(0, idx, .{}); } }; @@ -159,7 +159,7 @@ pub const CosSin = [2]Tensor; /// See: https://paperswithcode.com/method/rope pub fn rope(x: Tensor, cos_sin_cache: CosSin, opts: RopeOpts) Tensor { const cos, const sin = cos_sin_cache; - meta.assert(x.dim(-1) == 2 * cos.dim(-1), "Couldn't compute rope({}, {}, {})", .{ x, cos, sin }); + stdx.debug.assert(x.dim(-1) == 2 * cos.dim(-1), "Couldn't compute rope({}, {}, {})", .{ x, cos, sin }); // broadcast cos / sin to .{ batch, .seq, .half_dim } const x_real, const x_imag = splitRealImg(x, opts.impl); const has_tags = cos.shape().tag(0) != Shape.TagUnknown; @@ -178,9 +178,9 @@ pub fn rope(x: Tensor, cos_sin_cache: CosSin, opts: RopeOpts) Tensor { pub fn ropeCosSin(sh: anytype, dtype: DataType, opts: RopeOpts) CosSin { const shape = Shape.init(sh, dtype); - meta.assert(shape.rank() == 2, "ropeCosSin({}) shape need to exactly have 2 axes", .{shape}); + stdx.debug.assert(shape.rank() == 2, "ropeCosSin({}) shape need to exactly have 2 axes", .{shape}); const seq_len, const head_dim = .{ shape.dim(0), shape.dim(1) }; - meta.assert(@mod(head_dim, 2) == 0, "ropeCosSin requires an even head_dim, got {}", .{head_dim}); + stdx.debug.assert(@mod(head_dim, 2) == 0, "ropeCosSin requires an even head_dim, got {}", .{head_dim}); // compute sin and cos in f32 before downcasting to x type. const inv_freq = invFreq(head_dim, opts.freq_base, .f32); @@ -364,8 +364,8 @@ pub fn upsample( ) Tensor { // TODO(james): make `nearest` compatible with resizeBilinear and resizeBicubic, and wrap them here. // resize* have API which are more explicit, this assume you want to scale the N-2 last axes. - meta.assert(3 <= input.rank() and input.rank() <= 5, "upsample is only implemented for (3,4,5)-D tensors, received {}", .{input}); - meta.assert(opts.scale_factor.len == 1 or opts.scale_factor.len == input.rank() - 2, "scale factors", .{}); + stdx.debug.assert(3 <= input.rank() and input.rank() <= 5, "upsample is only implemented for (3,4,5)-D tensors, received {}", .{input}); + stdx.debug.assert(opts.scale_factor.len == 1 or opts.scale_factor.len == input.rank() - 2, "scale factors", .{}); return switch (opts.mode) { .nearest => { var scale_factors: [3]f64 = undefined; @@ -398,7 +398,7 @@ pub fn nearest(input: Tensor, scale_factor: []const f64) Tensor { var res = input; for (spatial_dims) |d| { const n = out_shape.dim(d); - const ratio = meta.divFloat(f32, input.dim(d), n); + const ratio = stdx.math.divFloor(f32, input.dim(d), n); const offsets = Tensor.arange(.{ .end = n }, .f32).addConstant(0.5).scale(ratio).floor().convert(.i32); res = res.gatherValues(d, offsets, .{ .indices_are_sorted = true }); } @@ -576,7 +576,7 @@ pub fn resizeLinear1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Te const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype(); const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype); - const ratio = og_len.convert(dtype).scale(meta.divFloat(f32, 1, new_len)); + const ratio = og_len.convert(dtype).scale(stdx.math.divFloor(f32, 1, new_len)); const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio); const left = scaled.floor(); const right = left.addConstant(1); @@ -638,7 +638,7 @@ pub fn resizeCubic1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Ten const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype(); const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype); - const ratio = og_len.convert(dtype).scale(meta.divFloat(f32, 1, new_len)); + const ratio = og_len.convert(dtype).scale(stdx.math.divFloor(f32, 1, new_len)); const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio); const t = scaled.sub(scaled.floor()); const pos = Tensor.stack(&.{ @@ -693,11 +693,11 @@ pub fn causalAttnMask( attn_window_len: ?u32, ) Tensor { const attn_shape = Shape.init(attn_shape_, dtype); - meta.assert(attn_shape.rank() == 2, "causalAttnMask({}) shape need to be exactly 2 axes", .{attn_shape}); + stdx.debug.assert(attn_shape.rank() == 2, "causalAttnMask({}) shape need to be exactly 2 axes", .{attn_shape}); const qlen = attn_shape.dim(-2); - const q_idx = Tensor.iota(attn_shape, .i32, -2); + const q_idx = Tensor.iota(attn_shape, -2); const klen = attn_shape.dim(-1); - const k_idx = Tensor.iota(attn_shape, .i32, -1); + const k_idx = Tensor.iota(attn_shape, -1); // all elements > main diagonal must be 0 // (q_idx - window_len < k_idx <= q_idx) @@ -748,16 +748,16 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { const err_template = "sdpa(q: {}, k: {}, v: {}, attn: {?}) is invalid ! "; const err_args = .{ q, k, v, opts.attn_mask }; - meta.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ "q is missing tags {{.h, .q, .hd}}", err_args); - meta.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args); - meta.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "v is missing tags {{.h, .k, .hd}}", err_args); + stdx.debug.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ "q is missing tags {{.h, .q, .hd}}", err_args); + stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args); + stdx.debug.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "v is missing tags {{.h, .k, .hd}}", err_args); if (opts.allow_cudnn and cuda.canUseCudnnSdpa(q.dim(.hd), q.dtype())) { return cuda.sdpa(q, k, v, opts); } if (q.dim(.h) != k.dim(.h)) { - meta.assert(@mod(q.dim(.h), k.dim(.h)) == 0, err_template ++ "Different number of heads for keys and queries, but can't repeat keys.", err_args); + stdx.debug.assert(@mod(q.dim(.h), k.dim(.h)) == 0, err_template ++ "Different number of heads for keys and queries, but can't repeat keys.", err_args); // Note: we don't try to repeat queries. // Repeating keys is the interesting optimisation cause it reduces KV cache memory usage. const num_rep: u63 = @intCast(@divExact(q.dim(.h), k.dim(.h))); @@ -766,7 +766,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { const attn_mask = if (opts.attn_mask) |m| m else null; const dims = helpers.collectDims(.{ .h, .q, .k, .hd }, &.{ q, k, v, attn_mask }, .strict) catch { - meta.panic(err_template ++ "Inputs have incompatible shapes.", err_args); + stdx.debug.panic(err_template ++ "Inputs have incompatible shapes.", err_args); }; const sqrtHeadDim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(dims.hd))); const scale_logit = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, k.dtype()); diff --git a/zml/ops.zig b/zml/ops.zig index f9044d8..e171047 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -1,11 +1,13 @@ const std = @import("std"); -const mlir = @import("mlir.zig"); +const stdx = @import("stdx"); +const buffer = @import("buffer.zig"); const helpers = @import("helpers.zig"); -const module = @import("module.zig"); const meta = @import("meta.zig"); +const mlir = @import("mlir.zig"); +const module = @import("module.zig"); -const Buffer = @import("buffer.zig").Buffer; +const Buffer = buffer.Buffer; const CompilationContext = module.CompilationContext; const Context = @import("context.zig").Context; const Data = @import("dtype.zig").Data; @@ -20,14 +22,14 @@ const dialect = struct { }; const assert = std.debug.assert; -const log = std.log.scoped(.zml); +const log = std.log.scoped(.@"zml/tensor"); test { std.testing.refAllDecls(@This()); } /// Generate an MLIR call to the given member function with the given tensors. -pub fn call(self: anytype, comptime func: meta.DeclEnum(@TypeOf(self)), args: anytype) @TypeOf(@call(.auto, @field(meta.UnwrapPtr(@TypeOf(self)), @tagName(func)), .{self} ++ args)) { +pub fn call(self: anytype, comptime func: stdx.meta.DeclEnum(@TypeOf(self)), args: anytype) @TypeOf(@call(.auto, @field(stdx.meta.UnwrapPtr(@TypeOf(self)), @tagName(func)), .{self} ++ args)) { // TODO: this should use `self.getContext().callFunc(self, args)` return @call(.auto, @field(@TypeOf(self), @tagName(func)), .{self} ++ args); @@ -121,8 +123,8 @@ test "simple while" { pub fn reduce( comptime body_fn: anytype, - inputs: meta.FnParam(body_fn, 0), - inits: meta.FnParam(body_fn, 0), + inputs: stdx.meta.FnParam(body_fn, 0), + inits: stdx.meta.FnParam(body_fn, 0), axes: []const i64, ) BlockSignNoCtx(body_fn).Return { // TODO: actualAxes @@ -155,7 +157,7 @@ pub fn reduce( // `stablehlo.reduce` drops axes. We want to avoid that to propagate tags. // So we need to broadcast the output of `stablehlo.reduce` to the input shapes. - // To that order, we initialize `result` to `inputs`, then we use meta.visit, + // To that order, we initialize `result` to `inputs`, then we use stdx.meta.visit, // to find the correct mlir.Value, but we first broadcast before creating the final // Tensor struct. var broadcasting_axes: std.BoundedArray(i64, Tensor.MAX_RANK) = .{}; @@ -217,10 +219,10 @@ pub const ReduceWindowOpts = struct { pub fn reduceWindow( comptime body_fn: anytype, - inputs: meta.FnParam(body_fn, 0), - inits: meta.FnParam(body_fn, 0), + inputs: stdx.meta.FnParam(body_fn, 0), + inits: stdx.meta.FnParam(body_fn, 0), opts: ReduceWindowOpts, -) meta.FnResult(body_fn) { +) stdx.meta.FnResult(body_fn) { const BodyS = comptime BlockSignNoCtx(body_fn); comptime { if (BodyS.Return != @TypeOf(inputs)) @compileError("reduce body function need to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn)); @@ -263,7 +265,7 @@ pub fn reduceWindow( pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_: anytype) BlockSign(func).Return { const num_steps: u32, const step_tag = blk: { const dims, const tags = Shape.parseDimensions(num_steps_); - meta.assert(dims.len == 1, "zml.for_ only supports one num_step, Received: {any}", .{num_steps_}); + stdx.debug.assert(dims.len == 1, "zml.for_ only supports one num_step, Received: {any}", .{num_steps_}); break :blk .{ @intCast(dims.get(0)), tags.get(0) }; }; const S = comptime BlockSign(func); @@ -290,7 +292,7 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_: } fn updateResBuffer(inputs: []const Tensor, idx: Tensor) Tensor { - meta.internalAssert(inputs.len == 2, "too many tensors", .{}); + stdx.debug.internalAssert(inputs.len == 2, "too many tensors", .{}); const res, const step_res = inputs[0..2].*; return res.dynamicUpdateSlice1d(step_res.insertAxes(0, .{._}), 0, idx); } diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index 9be7129..d50724f 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -1,22 +1,21 @@ -const builtin = @import("builtin"); -const std = @import("std"); - const asynk = @import("async"); +const builtin = @import("builtin"); +const dialects = @import("mlir/dialects"); const mlir = @import("mlir"); - const pjrt = @import("pjrt"); +const std = @import("std"); +const stdx = @import("stdx"); + const dtype = @import("dtype.zig"); const meta = @import("meta.zig"); -const dialects = @import("mlir/dialects"); - -pub const Profiler = pjrt.Profiler; -pub const ApiError = pjrt.ApiError; -pub const ErrorCode = pjrt.ErrorCode; const Target = @import("platform.zig").Target; const log = std.log.scoped(.zml); +pub const Profiler = pjrt.Profiler; +pub const ApiError = pjrt.ApiError; +pub const ErrorCode = pjrt.ErrorCode; pub const Buffer = pjrt.Buffer; pub const BufferType = pjrt.BufferType; pub const Device = pjrt.Device; @@ -181,14 +180,16 @@ pub const LoadedExecutable = opaque { return self.inner().getAddressableDevices(api); } - pub fn execute(self: *const LoadedExecutable, api: *const Api, args: struct { + pub const ExecuteArgs = struct { arguments: []const [*]const *const Buffer, num_args: usize, results: []const [*]*Buffer, events: []?*Event, non_donatable_input_indices: []const i64 = &.{}, - }) ExecuteError!void { - try asynk.callBlocking(pjrt.LoadedExecutable.execute, .{ self.inner(), api, .{ + }; + + pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ExecuteError!void { + try asynk.callBlocking(pjrt.LoadedExecutable.execute, .{ self.inner(), api, pjrt.LoadedExecutable.ExecuteArgs{ .num_args = args.num_args, .arguments = @ptrCast(args.arguments), .results = @ptrCast(args.results), diff --git a/zml/platform.zig b/zml/platform.zig index d037c2b..85b253c 100644 --- a/zml/platform.zig +++ b/zml/platform.zig @@ -1,28 +1,18 @@ -const builtin = @import("builtin"); -const std = @import("std"); - const asynk = @import("async"); +const builtin = @import("builtin"); const runtimes = @import("runtimes"); +const std = @import("std"); +const stdx = @import("stdx"); const meta = @import("meta.zig"); const module = @import("module.zig"); const pjrt = @import("pjrtx.zig"); + const log = std.log.scoped(.zml); pub const Target = runtimes.Platform; -pub const available_targets = switch (builtin.os.tag) { - .macos => [_]Target{ - .cpu, - }, - .linux => [_]Target{ - .cpu, - .cuda, - .rocm, - .tpu, - }, - else => [_]Target{}, -}; +pub const available_targets = std.enums.values(Target); pub const CompilationOptions = struct { xla_dump_to: ?[]const u8 = null, diff --git a/zml/shape.zig b/zml/shape.zig index d4a9abc..bafc886 100644 --- a/zml/shape.zig +++ b/zml/shape.zig @@ -1,11 +1,12 @@ const builtin = @import("builtin"); const std = @import("std"); +const stdx = @import("stdx"); + const testing = std.testing; -const meta = @import("meta.zig"); const DataType = @import("dtype.zig").DataType; - const EnumLiteral = @TypeOf(.enum_literal); + const log = std.log.scoped(.shape); test { @@ -39,7 +40,7 @@ pub const Shape = struct { return .{ v._dims, v._tags }; } - if (comptime meta.isSliceOfAny(T, meta.isInteger)) { + if (comptime stdx.meta.isSliceOfAny(T, stdx.meta.isInteger)) { var dims_ = DimsArray.init(0) catch unreachable; var tags_ = TagsArray.init(0) catch unreachable; for (v) |d| { @@ -49,19 +50,19 @@ pub const Shape = struct { return .{ dims_, tags_ }; } - if (comptime meta.isStruct(T)) { + if (comptime stdx.meta.isStruct(T)) { var dims_: DimsArray = .{}; var tags_: TagsArray = .{}; inline for (std.meta.fields(T)) |field| { const fv = @field(v, field.name); - if (comptime meta.isInteger(field.type)) { + if (comptime stdx.meta.isInteger(field.type)) { dims_.appendAssumeCapacity(@intCast(fv)); } else if (comptime isAutoDim(fv)) { dims_.appendAssumeCapacity(-1); } else { - meta.compileError("Field {s} should be an integer or an auto dimension", .{field.name}); + stdx.meta.compileError("Field {s} should be an integer or an auto dimension", .{field.name}); } - if (comptime meta.isTuple(T)) { + if (comptime stdx.meta.isTuple(T)) { tags_.appendAssumeCapacity(TagUnknown); } else { tags_.appendAssumeCapacity(toTag(field)); @@ -71,7 +72,7 @@ pub const Shape = struct { return .{ dims_, tags_ }; } - meta.compileError("expected a dimension tuple eg '.{{ .a = 10, .b = 20}}' or '.{{ 10, 20 }}', got {}", .{T}); + stdx.meta.compileError("expected a dimension tuple eg '.{{ .a = 10, .b = 20}}' or '.{{ 10, 20 }}', got {}", .{T}); } test parseDimensions { @@ -92,7 +93,7 @@ pub const Shape = struct { var axes_ = AxesArray.init(0) catch unreachable; var tags_ = TagsArray.init(0) catch unreachable; - if (comptime meta.isSliceOfAny(T, isAxisConvertible)) { + if (comptime stdx.meta.isSliceOfAny(T, isAxisConvertible)) { for (v) |d| { axes_.appendAssumeCapacity(self.axis(d)); tags_.appendAssumeCapacity(self.tag(d)); @@ -100,7 +101,7 @@ pub const Shape = struct { return .{ axes_, tags_ }; } - if (comptime meta.isTupleOfAny(T, isAxisConvertible)) { + if (comptime stdx.meta.isTupleOfAny(T, isAxisConvertible)) { inline for (std.meta.fields(T)) |field| { axes_.appendAssumeCapacity(self.axis(@field(v, field.name))); tags_.appendAssumeCapacity(self.tag(@field(v, field.name))); @@ -108,12 +109,12 @@ pub const Shape = struct { return .{ axes_, tags_ }; } - meta.compileError("Wrong type, got {}. Expected .{{.a, .b}}", .{T}); + stdx.meta.compileError("Wrong type, got {}. Expected .{{.a, .b}}", .{T}); } pub fn parseTags(v: anytype) TagsArray { const T = @TypeOf(v); - meta.assertComptime(meta.isTupleOf(T, EnumLiteral), "Wrong type, got {}. Expected .{{ .a, .b }}", .{T}); + stdx.debug.assertComptime(stdx.meta.isTupleOf(T, EnumLiteral), "Wrong type, got {}. Expected .{{ .a, .b }}", .{T}); var tags_ = TagsArray.init(0) catch unreachable; inline for (v) |field| { tags_.appendAssumeCapacity(toTag(field)); @@ -135,7 +136,7 @@ pub const Shape = struct { var res: Shape = .{ ._dtype = dt }; for (0..rank_) |i| { res._dims.append(@intCast(i)) catch { - meta.panic("Too many dimensions! Max: {d}, passed: {d}", .{ res._dims.capacity(), rank_ }); + stdx.debug.panic("Too many dimensions! Max: {d}, passed: {d}", .{ res._dims.capacity(), rank_ }); }; res._tags.append(TagUnknown) catch unreachable; } @@ -162,7 +163,7 @@ pub const Shape = struct { } fn isAxisConvertible(comptime T: type) bool { - return meta.isInteger(T) or isTagConvertible(T); + return stdx.meta.isInteger(T) or isTagConvertible(T); } fn isTagConvertible(comptime T: type) bool { @@ -180,12 +181,12 @@ pub const Shape = struct { EnumLiteral => @tagName(v).ptr, std.builtin.Type.StructField => v.name.ptr, Tag => v, - else => meta.compileError("Value should be an EnumLiteral, a Shape.Tag or a StructField, got {}", .{T}), + else => stdx.meta.compileError("Value should be an EnumLiteral, a Shape.Tag or a StructField, got {}", .{T}), }; } inline fn ensureDimsAndTagsAreSync(self: Shape) void { - meta.assert(self._dims.len == self._tags.len, "Tags and dims have diverged! dims={d} tags={d}", .{ self._dims.len, self._tags.len }); + stdx.debug.assert(self._dims.len == self._tags.len, "Tags and dims have diverged! dims={d} tags={d}", .{ self._dims.len, self._tags.len }); } pub fn tag(self: Shape, ax: anytype) Tag { @@ -220,7 +221,7 @@ pub const Shape = struct { pub fn hasTags(self: Shape, tagz: anytype) bool { const T = @TypeOf(tagz); - if (comptime meta.isSliceOf(T, Tag) or meta.isSliceOf(T, EnumLiteral)) { + if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) { for (tagz) |t| { if (self.hasTag(t) == null) { return false; @@ -229,7 +230,7 @@ pub const Shape = struct { return true; } - if (comptime meta.isTupleOf(T, Tag) or meta.isTupleOf(T, EnumLiteral)) { + if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) { inline for (tagz) |t| { if (self.hasTag(t) == null) { return false; @@ -238,7 +239,7 @@ pub const Shape = struct { return true; } - meta.compileError("Expected tuple of tags, got {any}", .{T}); + stdx.meta.compileError("Expected tuple of tags, got {any}", .{T}); } pub fn isFullyTagged(self: Shape) bool { @@ -252,7 +253,7 @@ pub const Shape = struct { self.ensureDimsAndTagsAreSync(); const T = @TypeOf(axis_); - if (comptime meta.isInteger(T)) { + if (comptime stdx.meta.isInteger(T)) { return self.axisFromInt(@intCast(axis_)); } @@ -260,7 +261,7 @@ pub const Shape = struct { return self.axisFromTag(toTag(axis_)); } - meta.compileError("Wrong axis type, expected .literal, or an integer, got: {any}", .{T}); + stdx.meta.compileError("Wrong axis type, expected .literal, or an integer, got: {any}", .{T}); } pub fn axes(self: Shape, axes_: anytype) AxesArray { @@ -274,27 +275,27 @@ pub const Shape = struct { var res = AxesArray.init(0) catch unreachable; - if (comptime meta.isSliceOfAny(T, meta.isInteger) or meta.isSliceOf(T, Tag)) { + if (comptime stdx.meta.isSliceOfAny(T, stdx.meta.isInteger) or stdx.meta.isSliceOf(T, Tag)) { for (axes_) |ax| { res.appendAssumeCapacity(self.axis(ax)); } return res; } - if (comptime meta.isStruct(T)) { + if (comptime stdx.meta.isStruct(T)) { inline for (std.meta.fields(T)) |field| { res.appendAssumeCapacity(self.axis(@field(axes_, field.name))); } return res; } - meta.compileError("axes expects an int-tuple or a tuple of enum literal, got {}", .{T}); + stdx.meta.compileError("axes expects an int-tuple or a tuple of enum literal, got {}", .{T}); } fn axisFromInt(self: Shape, d: isize) u3 { const rk: i8 = self.rank(); if (d < -rk or d > rk) { - meta.panic("Tensor {} doesn't have dimension: {d}", .{ self, d }); + stdx.debug.panic("Tensor {} doesn't have dimension: {d}", .{ self, d }); } return if (d < 0) @intCast(d + rk) @@ -323,9 +324,9 @@ pub const Shape = struct { } fn axisFromTag(self: Shape, d: Tag) u3 { - meta.assert(d != TagUnknown, "The unknown tag .{s} can't be used to fetch axis in {}", .{ d, self }); + stdx.debug.assert(d != TagUnknown, "The unknown tag .{s} can't be used to fetch axis in {}", .{ d, self }); return self.axisFromTagMaybe(d) orelse { - meta.panic("Tensor {} doesn't have dimension with tag: {s}", .{ self, d }); + stdx.debug.panic("Tensor {} doesn't have dimension with tag: {s}", .{ self, d }); }; } @@ -339,7 +340,7 @@ pub const Shape = struct { pub fn count(self: Shape) usize { var res: i64 = 1; for (self.dims()) |d| { - meta.assert(d >= 0, "Can't count elements in shape with negative dimension: {}", .{self}); + stdx.debug.assert(d >= 0, "Can't count elements in shape with negative dimension: {}", .{self}); res *= d; } return @intCast(res); @@ -398,12 +399,12 @@ pub const Shape = struct { var new_shape: Shape = .{ ._dtype = self.dtype() }; new_shape._dims, new_shape._tags = parseDimensions(new_shape_); new_shape.inferMissingAxis(self.count()); - meta.assert(self.count() == new_shape.count(), "Can't reshape {d} to {d}", .{ self.dims(), new_shape.dims() }); + stdx.debug.assert(self.count() == new_shape.count(), "Can't reshape {d} to {d}", .{ self.dims(), new_shape.dims() }); return new_shape; } fn inferMissingAxis(self: *Shape, n_: usize) void { - meta.assert(std.mem.count(i64, self.dims(), &.{-1}) < 2, "Cannot infer multiple dimensions when reshaping to: {}", .{self.*}); + stdx.debug.assert(std.mem.count(i64, self.dims(), &.{-1}) < 2, "Cannot infer multiple dimensions when reshaping to: {}", .{self.*}); const inferred_ax = std.mem.indexOfScalar(i64, self.dims(), -1) orelse return; // We can't use `self.count()` yet cause we have negative dims. @@ -481,7 +482,7 @@ pub const Shape = struct { } pub fn insertTag(self: Shape, axis_: anytype, d: i64, tag_: anytype) Shape { - meta.assert(self.rank() < MAX_RANK - 1, "Can't insert new axis in {}, it's already at max rank.", .{self}); + stdx.debug.assert(self.rank() < MAX_RANK - 1, "Can't insert new axis in {}, it's already at max rank.", .{self}); const ax = if (@TypeOf(axis_) == EnumLiteral and axis_ == .last) self.rank() @@ -573,23 +574,23 @@ pub const Shape = struct { var res = self; - if (comptime meta.isSliceOf(T, Tag) or meta.isSliceOf(T, EnumLiteral)) { - meta.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {any}", .{ self, tagz }); + if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) { + stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {any}", .{ self, tagz }); for (tagz, 0..) |tag_, i| { res._tags.set(i, toTag(tag_)); } return res; } - if (comptime meta.isTupleOf(T, Tag) or meta.isTupleOf(T, EnumLiteral)) { - meta.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {}", .{ self, tagz }); + if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) { + stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {}", .{ self, tagz }); inline for (tagz, 0..) |tag_, i| { res._tags.set(i, toTag(tag_)); } return res; } - meta.compileError("Expected a tuple of enum literals eg: .{ .a, .b, .c } got: {any}", .{@TypeOf(tagz)}); + stdx.meta.compileError("Expected a tuple of enum literals eg: .{ .a, .b, .c } got: {any}", .{@TypeOf(tagz)}); } test withTags { @@ -620,23 +621,23 @@ pub const Shape = struct { var res = self; - if (comptime meta.isSliceOf(T, Tag) or meta.isSliceOf(T, EnumLiteral)) { - meta.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {any}", .{ self, tagz }); + if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) { + stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {any}", .{ self, tagz }); for (tagz, self.rank() - tagz.len..) |tag_, i| { res._tags.set(i, toTag(tag_)); } return res; } - if (comptime meta.isTupleOf(T, Tag) or meta.isTupleOf(T, EnumLiteral)) { - meta.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {}", .{ self, tagz }); + if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) { + stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {}", .{ self, tagz }); inline for (tagz, self.rank() - tagz.len..) |tag_, i| { res._tags.set(i, toTag(tag_)); } return res; } - meta.compileError("Expected a tuple of enum literals eg: .{ .a, .b, .c } got: {any}", .{@TypeOf(tagz)}); + stdx.meta.compileError("Expected a tuple of enum literals eg: .{ .a, .b, .c } got: {any}", .{@TypeOf(tagz)}); } test withPartialTags { @@ -683,7 +684,7 @@ pub const Shape = struct { /// Shape.init(.{ .a = 10, .b = 20 }).rename(.{ .b = .batch }); // .{ .a = 10, .batch = 20 }; pub fn rename(self: Shape, renames: anytype) Shape { const T = @TypeOf(renames); - meta.assertComptime(meta.isStructOfAny(T, isAxisConvertible), "Must pass a struct of enum literals. Passed: {any}", .{T}); + stdx.debug.assertComptime(stdx.meta.isStructOfAny(T, isAxisConvertible), "Must pass a struct of enum literals. Passed: {any}", .{T}); var res = self; inline for (std.meta.fields(T)) |field| { res._tags.set(self.axis(field), toTag(@field(renames, field.name))); @@ -789,7 +790,7 @@ pub const Shape = struct { pub fn splitAxes(self: Shape, axes_: anytype) Shape { const T = @TypeOf(axes_); - meta.assertComptime(meta.isStruct(T), "Must pass struct of enum literals like .{ .a = .{ .a1, .a2 } }. Passed: {any}", .{T}); + stdx.debug.assertComptime(stdx.meta.isStruct(T), "Must pass struct of enum literals like .{ .a = .{ .a1, .a2 } }. Passed: {any}", .{T}); var res = self; inline for (std.meta.fields(T)) |field| { @@ -822,7 +823,7 @@ pub const Shape = struct { var new_dim: i64 = 1; for (axes__.constSlice(), first_axis..) |ax, counter| { new_dim *= self.dim(ax); - meta.assert(ax == counter, "Can't merge shape {} along non-contiguous axes {any}", .{ self, axes_ }); + stdx.debug.assert(ax == counter, "Can't merge shape {} along non-contiguous axes {any}", .{ self, axes_ }); } var new_shape = self; @@ -863,11 +864,11 @@ pub const Shape = struct { pub fn mergeAxes(self: Shape, axes_: anytype) Shape { const T = @TypeOf(axes_); - meta.assertComptime(meta.isStruct(T), "Must pass struct of enum literals like .{ .a = .{ .a1, .a2 } }. Passed: {any}", .{T}); + stdx.debug.assertComptime(stdx.meta.isStruct(T), "Must pass struct of enum literals like .{ .a = .{ .a1, .a2 } }. Passed: {any}", .{T}); var res = self; inline for (std.meta.fields(T)) |field| { - meta.assertComptime(meta.isTupleOfAny(field.type, isAxisConvertible) or meta.isSliceOfAny(field.type, isAxisConvertible), "Must pass struct of axes. Passed: {any}", .{field.type}); + stdx.debug.assertComptime(stdx.meta.isTupleOfAny(field.type, isAxisConvertible) or stdx.meta.isSliceOfAny(field.type, isAxisConvertible), "Must pass struct of axes. Passed: {any}", .{field.type}); res = res.mergeAxis(field, @field(axes_, field.name)); } return res; @@ -912,28 +913,28 @@ pub const Shape = struct { var vals_: std.BoundedArray(T, MAX_RANK) = .{}; var tags_: TagsArray = .{}; - if (comptime meta.isSliceOf(V, T)) { + if (comptime stdx.meta.isSliceOf(V, T)) { for (v) |d| { vals_.appendAssumeCapacity(d); } return .{ vals_, tags_ }; } - if (comptime meta.isStruct(V)) { + if (comptime stdx.meta.isStruct(V)) { const fields = std.meta.fields(V); - meta.assertComptime(fields.len <= MAX_RANK, "Too many fields in struct {} ({d}). Max supported is {d}.", .{ V, fields.len, MAX_RANK }); + stdx.debug.assertComptime(fields.len <= MAX_RANK, "Too many fields in struct {} ({d}). Max supported is {d}.", .{ V, fields.len, MAX_RANK }); inline for (fields) |field| { const fv = @field(v, field.name); vals_.appendAssumeCapacity(fv); - if (!comptime meta.isTuple(V)) { + if (!comptime stdx.meta.isTuple(V)) { tags_.appendAssumeCapacity(toTag(field)); } } return .{ vals_, tags_ }; } - meta.compileError("parseStruct expects struct or tuple, got {}", .{V}); + stdx.meta.compileError("parseStruct expects struct or tuple, got {}", .{V}); } test parseStruct { @@ -948,17 +949,17 @@ pub const Shape = struct { const V = @TypeOf(options); var res: std.BoundedArray(T, MAX_RANK) = .{}; - if (comptime meta.isSliceOf(V, T)) { - meta.assert(options.len == self.rank(), "expects exactly {} options in slice, for {} got {}", .{ self.rank(), self, options.len }); + if (comptime stdx.meta.isSliceOf(V, T)) { + stdx.debug.assert(options.len == self.rank(), "expects exactly {} options in slice, for {} got {}", .{ self.rank(), self, options.len }); for (options) |d| { res.appendAssumeCapacity(d); } } - if (comptime meta.isStruct(V)) { + if (comptime stdx.meta.isStruct(V)) { for (0..self.rank()) |_| res.appendAssumeCapacity(default); const fields = std.meta.fields(V); - meta.assertComptime(fields.len <= MAX_RANK, "expects up to {} options struct literal, got {}", .{ V, MAX_RANK, fields.len }); + stdx.debug.assertComptime(fields.len <= MAX_RANK, "expects up to {} options struct literal, got {}", .{ V, MAX_RANK, fields.len }); inline for (fields) |field| { const a = self.axis(field); res.buffer[a] = @field(options, field.name); @@ -966,7 +967,7 @@ pub const Shape = struct { return res; } - meta.compileError("parseStruct expects struct or tuple, got {}", .{V}); + stdx.meta.compileError("parseStruct expects struct or tuple, got {}", .{V}); } test parseAxesOptions { diff --git a/zml/tensor.zig b/zml/tensor.zig index 8be051f..0af4559 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -1,7 +1,6 @@ const builtin = @import("builtin"); const std = @import("std"); -const assert = std.debug.assert; -const testing = std.testing; +const stdx = @import("stdx"); const meta = @import("meta.zig"); const mlir = @import("mlir.zig"); @@ -22,7 +21,9 @@ const dialect = struct { const stablehlo = @import("mlir/dialects").stablehlo; }; -const scoped_log = std.log.scoped(.zml_tensor); +const assert = std.debug.assert; +const testing = std.testing; +const scoped_log = std.log.scoped(.@"zml/tensor"); test { std.testing.refAllDecls(Tensor); @@ -99,7 +100,7 @@ pub const Tensor = struct { if (builtin.mode == .Debug) { // Check that the MLIR value actually have the same shape. const other = fromMlirValue(val); - meta.internalAssert(sh.eql(other._shape), "Created a {} from Mlir value but expected {}", .{ other._shape, res._shape }); + stdx.debug.internalAssert(sh.eql(other._shape), "Created a {} from Mlir value but expected {}", .{ other._shape, res._shape }); } return res; @@ -112,7 +113,7 @@ pub const Tensor = struct { const ranked_tensor = val.getType().as(mlir.RankedTensorType).?; const n = ranked_tensor.getRank(); - meta.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK }); + stdx.debug.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK }); var sh: Shape = .{ ._dtype = mlir.ext.Type.toDType(ranked_tensor.getElementType()) }; for (0..n) |i| { @@ -213,7 +214,7 @@ pub const Tensor = struct { /// For `reuseBuffer` to be effective, it needs to propagate all the way through the output. pub fn reuseBuffer(self: Tensor, origin: Tensor) Tensor { // Note: check donation docs, this may be too permissive. - meta.assert(self.byteSize() == origin.byteSize(), "Can't reuse buffers between tensors of different size: {} and {}", .{ self, origin }); + stdx.debug.assert(self.byteSize() == origin.byteSize(), "Can't reuse buffers between tensors of different size: {} and {}", .{ self, origin }); // TODO: should we store all donations inside the context ? var res = self; @@ -262,7 +263,7 @@ pub const Tensor = struct { break :gt res; } else lt: { // several contiguous elements of self maps to one element of the result - meta.assert(self.dim(-1) * src_bit_size == tgt_bit_size, "bitcast expects elements of the input tensor last dimension to map to one element of the target datatype, got {0} elements (bitsize of {0}x{1}={2}) and {3} (bitsize of {4})", .{ self.dim(-1), src_bit_size, self.dim(-1) * src_bit_size, dt, tgt_bit_size }); + stdx.debug.assert(self.dim(-1) * src_bit_size == tgt_bit_size, "bitcast expects elements of the input tensor last dimension to map to one element of the target datatype, got {0} elements (bitsize of {0}x{1}={2}) and {3} (bitsize of {4})", .{ self.dim(-1), src_bit_size, self.dim(-1) * src_bit_size, dt, tgt_bit_size }); break :lt self._shape.remove(-1); }; @@ -295,7 +296,7 @@ pub const Tensor = struct { /// Returns a Tensor containing the element-wise number of bits set in the input Tensor. pub fn popcnt(self: Tensor) Tensor { - meta.assert(self.dtype().isInteger(), "popcnt expects tensor type to be an integer, got {}", .{self.dtype()}); + stdx.debug.assert(self.dtype().isInteger(), "popcnt expects tensor type to be an integer, got {}", .{self.dtype()}); const loc = self.getContext().mlirCtx().location(@src()); const op = dialect.stablehlo.popcnt(self.getContext().mlirCtx(), self.value(), loc); return _result(self._shape, op.result(0)); @@ -358,7 +359,7 @@ pub const Tensor = struct { /// 'lower' controls the form of the outut Tensor. The output will be lower-triangular if 'lower' is true /// and upper-triangular otherwise. pub fn cholesky(self: Tensor, lower: bool) Tensor { - meta.assert(self.rank() <= 2, "cholesky expects tensor rank to be <= 2, got {}", .{self.rank()}); + stdx.debug.assert(self.rank() <= 2, "cholesky expects tensor rank to be <= 2, got {}", .{self.rank()}); const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "lower={}", .{lower}); const op = dialect.stablehlo.cholesky(self.getContext().mlirCtx(), self.value(), lower, loc); @@ -367,8 +368,8 @@ pub const Tensor = struct { /// Solves the system of linear equations formed by the input tensors. pub fn triangularSolve(self: Tensor, other: Tensor, opts: dialect.stablehlo.TriangularSolveOpts) Tensor { - meta.assert(self.dtype() == other.dtype(), "triangularSolve expects tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() }); - meta.assert(self.rank() <= 2 and self.rank() == other.rank(), "triangularSolve expects tensors to have the same rank and be <= 2, got {} and {}", .{ self.rank(), other.rank() }); + stdx.debug.assert(self.dtype() == other.dtype(), "triangularSolve expects tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() }); + stdx.debug.assert(self.rank() <= 2 and self.rank() == other.rank(), "triangularSolve expects tensors to have the same rank and be <= 2, got {} and {}", .{ self.rank(), other.rank() }); const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "opts={}", .{opts}); const op = dialect.stablehlo.triangular_solve(self.getContext().mlirCtx(), self.value(), other.value(), loc, opts); @@ -377,7 +378,7 @@ pub const Tensor = struct { /// Returns a Tensor containing the element-wise rounding towards the nearest integer, breaking ties away from zero, of the input Tensor. pub fn roundNearestAfz(self: Tensor) Tensor { - meta.assert(self.dtype().isFloat(), "roundNearestAfz expects tensor type to be a float, got {}", .{self.dtype()}); + stdx.debug.assert(self.dtype().isFloat(), "roundNearestAfz expects tensor type to be a float, got {}", .{self.dtype()}); const loc = self.getContext().mlirCtx().location(@src()); const op = dialect.stablehlo.round_nearest_afz(self.getContext().mlirCtx(), self.value(), loc); @@ -386,7 +387,7 @@ pub const Tensor = struct { /// Returns a Tensor containing the element-wise rounding towards the nearest integer, breaking ties towards the even integer, of the input Tensor. pub fn roundNearestEven(self: Tensor) Tensor { - meta.assert(self.dtype().isFloat(), "roundNearestEven expects tensor type to be a float, got {}", .{self.dtype()}); + stdx.debug.assert(self.dtype().isFloat(), "roundNearestEven expects tensor type to be a float, got {}", .{self.dtype()}); const loc = self.getContext().mlirCtx().location(@src()); const op = dialect.stablehlo.round_nearest_even(self.getContext().mlirCtx(), self.value(), loc); @@ -395,8 +396,8 @@ pub const Tensor = struct { /// Returns a Tensor of complex number converted from a pair of real and imaginary Tensors. pub fn complex(re: Tensor, im: Tensor) Tensor { - meta.assert(re._shape.eql(im._shape), "complex expects tensor shapes to match, got {} and {}", .{ re._shape, im._shape }); - meta.assert(re.dtype() == .f32 or re.dtype() == .f64, "complex expects tensors type to be f32 or f64, got {}", .{re.dtype()}); + stdx.debug.assert(re._shape.eql(im._shape), "complex expects tensor shapes to match, got {} and {}", .{ re._shape, im._shape }); + stdx.debug.assert(re.dtype() == .f32 or re.dtype() == .f64, "complex expects tensors type to be f32 or f64, got {}", .{re.dtype()}); const loc = re.getContext().mlirCtx().location(@src()); const op = dialect.stablehlo.complex(re.getContext().mlirCtx(), re.value(), im.value(), loc); @@ -408,7 +409,7 @@ pub const Tensor = struct { /// /// Tensor type can float or complex. pub fn real(self: Tensor) Tensor { - meta.assert(self.dtype().isComplex() or self.dtype().isFloat(), "real expects tensor type to be a float or a complex, got {}", .{self.dtype()}); + stdx.debug.assert(self.dtype().isComplex() or self.dtype().isFloat(), "real expects tensor type to be a float or a complex, got {}", .{self.dtype()}); if (self.dtype().isFloat()) { return self; @@ -428,7 +429,7 @@ pub const Tensor = struct { /// /// Tensor type can float or complex. pub fn imag(self: Tensor) Tensor { - meta.assert(self.dtype().isFloat() or self.dtype().isComplex(), "imag expects tensor type to be a float or a complex, got {}", .{self.dtype()}); + stdx.debug.assert(self.dtype().isFloat() or self.dtype().isComplex(), "imag expects tensor type to be a float or a complex, got {}", .{self.dtype()}); // Real tensors don't have imaginary part. if (self.dtype().isFloat()) { @@ -450,18 +451,18 @@ pub const Tensor = struct { pub fn fft(self: Tensor, opts: dialect.stablehlo.FftOpts) Tensor { // TODO: support tagged API. - meta.assert(1 <= opts.length.len and opts.length.len <= 3, "fft expects 'opts.length' length to be between 1 and 3 (inclusive), got {}", .{opts.length.len}); - meta.assert(opts.length.len <= self.rank(), "fft expects 'opts.length' length to be less than tensor rank, got {} and {}", .{ opts.length.len, self.rank() }); + stdx.debug.assert(1 <= opts.length.len and opts.length.len <= 3, "fft expects 'opts.length' length to be between 1 and 3 (inclusive), got {}", .{opts.length.len}); + stdx.debug.assert(opts.length.len <= self.rank(), "fft expects 'opts.length' length to be less than tensor rank, got {} and {}", .{ opts.length.len, self.rank() }); const sh = switch (opts.kind) { .FFT, .IFFT => blk: { - meta.assert(self.dtype().isComplex(), "fft({any}) expects tensor type to be complex, got {}", .{ opts, self.dtype() }); + stdx.debug.assert(self.dtype().isComplex(), "fft({any}) expects tensor type to be complex, got {}", .{ opts, self.dtype() }); break :blk self._shape; }, .RFFT => blk: { - meta.assert(self.dtype() == .f32 or self.dtype() == .f64, "fft({}) expects tensor type to be f32 or f64, got {}", .{ opts, self.dtype() }); - meta.assert(std.mem.eql(i64, self.dims()[self.rank() - opts.length.len ..], opts.length), "fft({}) expects tensor last dimensions to match given lengths, got {} and {}", .{ opts, self.dims()[self.rank() - opts.length.len ..].len, opts.length.len }); + stdx.debug.assert(self.dtype() == .f32 or self.dtype() == .f64, "fft({}) expects tensor type to be f32 or f64, got {}", .{ opts, self.dtype() }); + stdx.debug.assert(std.mem.eql(i64, self.dims()[self.rank() - opts.length.len ..], opts.length), "fft({}) expects tensor last dimensions to match given lengths, got {} and {}", .{ opts, self.dims()[self.rank() - opts.length.len ..].len, opts.length.len }); const dt: DataType = switch (self.dtype()) { .f32 => .c64, @@ -471,8 +472,8 @@ pub const Tensor = struct { break :blk shape_.withDtype(dt); }, .IRFFT => blk: { - meta.assert(self.dtype().isComplex(), "fft({any}) expects tensor type to be complex, got {}", .{ opts, self.dtype() }); - meta.assert(std.mem.eql(i64, self.dims()[self.rank() - opts.length.len ..], opts.length), "fft({any}) expects tensor last dimensions to match given lengths, got {} and {}", .{ opts, self.dims()[self.rank() - opts.length.len ..].len, opts.length.len }); + stdx.debug.assert(self.dtype().isComplex(), "fft({any}) expects tensor type to be complex, got {}", .{ opts, self.dtype() }); + stdx.debug.assert(std.mem.eql(i64, self.dims()[self.rank() - opts.length.len ..], opts.length), "fft({any}) expects tensor last dimensions to match given lengths, got {} and {}", .{ opts, self.dims()[self.rank() - opts.length.len ..].len, opts.length.len }); const dt: DataType = switch (self.dtype()) { .c64 => .f32, @@ -551,7 +552,7 @@ pub const Tensor = struct { 16 => .u16, 32 => .u32, 64 => .u64, - else => meta.panic("uniform don't support non-byte aligned dtype. Got: {}", .{shape_}), + else => stdx.debug.panic("uniform don't support non-byte aligned dtype. Got: {}", .{shape_}), }; const rng, const bits = self.bitGenerator(shape_.withDtype(uint_dtype)); @@ -635,7 +636,7 @@ pub const Tensor = struct { /// Note: this uses stablehlo.rng which is deprecated. /// https://github.com/openxla/stablehlo/blob/main/rfcs/20240503-opset-deprecations.md pub fn normal(sh: Shape, opts: struct { mean: f64 = 0, stddev: f64 = 1 }) Tensor { - meta.assert(sh.dtype().isFloat(), "normal expects tensor type to be a float, got {}", .{sh.dtype()}); + stdx.debug.assert(sh.dtype().isFloat(), "normal expects tensor type to be a float, got {}", .{sh.dtype()}); const ctx = CompilationContext.current().mlirCtx(); const loc = ctx.location(@src()).namedFmt(ctx, "rand.normal({}, opts={})", .{ sh, opts }); @@ -731,9 +732,9 @@ pub const Tensor = struct { /// Returns a Tensor containing the element-wise conversion to another floating point type. pub fn reducePrecision(self: Tensor, exponent_bits: i32, mantissa_bits: i32) Tensor { - meta.assert(self.dtype().isFloat(), "reducePrecision expects tensor type to be a float, got {}", .{self.dtype()}); - meta.assert(1 <= exponent_bits, "reducePrecision expects 'exponent_bits' to be >= 1, got {}", .{exponent_bits}); - meta.assert(0 <= mantissa_bits, "reducePrecision expects 'mantissa_bits' to be positive, got {}", .{mantissa_bits}); + stdx.debug.assert(self.dtype().isFloat(), "reducePrecision expects tensor type to be a float, got {}", .{self.dtype()}); + stdx.debug.assert(1 <= exponent_bits, "reducePrecision expects 'exponent_bits' to be >= 1, got {}", .{exponent_bits}); + stdx.debug.assert(0 <= mantissa_bits, "reducePrecision expects 'mantissa_bits' to be positive, got {}", .{mantissa_bits}); const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reducePrecision(exponent_bits={}, mantissa_bits={})", .{ exponent_bits, mantissa_bits }); const op = dialect.stablehlo.reduce_precision(self.getContext().mlirCtx(), self.value(), exponent_bits, mantissa_bits, loc); @@ -741,56 +742,56 @@ pub const Tensor = struct { } inline fn convolution(self: Tensor, other: Tensor, opts: dialect.stablehlo.ConvolutionOpts, loc: mlir.Location) Tensor { - meta.assert(self.rank() == other.rank(), "convolution expects tensor ranks to match, got {} and {}", .{ self.rank(), other.rank() }); + stdx.debug.assert(self.rank() == other.rank(), "convolution expects tensor ranks to match, got {} and {}", .{ self.rank(), other.rank() }); const N = self.rank(); - meta.guard(opts.window_strides.len == N - 2, @src()); - for (opts.window_strides) |s| meta.guard(0 < s, @src()); - meta.guard(opts.lhs_dilation.len == N - 2, @src()); - for (opts.lhs_dilation) |d| meta.guard(0 < d, @src()); - meta.guard(opts.rhs_dilation.len == N - 2, @src()); - for (opts.rhs_dilation) |d| meta.guard(0 < d, @src()); - meta.guard(opts.window_reversal.len == N - 2, @src()); - meta.guard(@rem(self.dim(opts.input_batch_dimension), opts.batch_group_count) == 0, @src()); - meta.guard(@rem(self.dim(opts.input_feature_dimension), opts.feature_group_count) == 0, @src()); - meta.guard(opts.input_spatial_dimensions.len == N - 2, @src()); - meta.guard(opts.input_batch_dimension != opts.input_feature_dimension, @src()); - meta.guard(0 <= opts.input_batch_dimension and opts.input_batch_dimension < N, @src()); - meta.guard(0 <= opts.input_feature_dimension and opts.input_feature_dimension < N, @src()); + stdx.debug.guard(opts.window_strides.len == N - 2, @src()); + for (opts.window_strides) |s| stdx.debug.guard(0 < s, @src()); + stdx.debug.guard(opts.lhs_dilation.len == N - 2, @src()); + for (opts.lhs_dilation) |d| stdx.debug.guard(0 < d, @src()); + stdx.debug.guard(opts.rhs_dilation.len == N - 2, @src()); + for (opts.rhs_dilation) |d| stdx.debug.guard(0 < d, @src()); + stdx.debug.guard(opts.window_reversal.len == N - 2, @src()); + stdx.debug.guard(@rem(self.dim(opts.input_batch_dimension), opts.batch_group_count) == 0, @src()); + stdx.debug.guard(@rem(self.dim(opts.input_feature_dimension), opts.feature_group_count) == 0, @src()); + stdx.debug.guard(opts.input_spatial_dimensions.len == N - 2, @src()); + stdx.debug.guard(opts.input_batch_dimension != opts.input_feature_dimension, @src()); + stdx.debug.guard(0 <= opts.input_batch_dimension and opts.input_batch_dimension < N, @src()); + stdx.debug.guard(0 <= opts.input_feature_dimension and opts.input_feature_dimension < N, @src()); for (opts.input_spatial_dimensions, 0..) |d, i| { - meta.guard(d != opts.input_batch_dimension, @src()); - meta.guard(d != opts.input_feature_dimension, @src()); - meta.guard(0 <= d and d < N, @src()); + stdx.debug.guard(d != opts.input_batch_dimension, @src()); + stdx.debug.guard(d != opts.input_feature_dimension, @src()); + stdx.debug.guard(0 <= d and d < N, @src()); if (i < opts.input_spatial_dimensions.len - 1) continue; - meta.guard(std.mem.indexOfScalar(i64, opts.input_spatial_dimensions[i + 1 ..], d) == null, @src()); + stdx.debug.guard(std.mem.indexOfScalar(i64, opts.input_spatial_dimensions[i + 1 ..], d) == null, @src()); } - meta.guard(other.dim(opts.kernel_input_feature_dimension) == @divTrunc(self.dim(opts.input_feature_dimension), opts.feature_group_count), @src()); - meta.guard(@rem(other.dim(opts.kernel_output_feature_dimension), opts.batch_group_count) == 0, @src()); - meta.guard(@rem(other.dim(opts.kernel_output_feature_dimension), opts.feature_group_count) == 0, @src()); - meta.guard(opts.kernel_spatial_dimensions.len == N - 2, @src()); - meta.guard(opts.kernel_input_feature_dimension != opts.kernel_output_feature_dimension, @src()); - meta.guard(0 <= opts.kernel_input_feature_dimension and opts.kernel_input_feature_dimension < N, @src()); - meta.guard(0 <= opts.kernel_output_feature_dimension and opts.kernel_output_feature_dimension < N, @src()); + stdx.debug.guard(other.dim(opts.kernel_input_feature_dimension) == @divTrunc(self.dim(opts.input_feature_dimension), opts.feature_group_count), @src()); + stdx.debug.guard(@rem(other.dim(opts.kernel_output_feature_dimension), opts.batch_group_count) == 0, @src()); + stdx.debug.guard(@rem(other.dim(opts.kernel_output_feature_dimension), opts.feature_group_count) == 0, @src()); + stdx.debug.guard(opts.kernel_spatial_dimensions.len == N - 2, @src()); + stdx.debug.guard(opts.kernel_input_feature_dimension != opts.kernel_output_feature_dimension, @src()); + stdx.debug.guard(0 <= opts.kernel_input_feature_dimension and opts.kernel_input_feature_dimension < N, @src()); + stdx.debug.guard(0 <= opts.kernel_output_feature_dimension and opts.kernel_output_feature_dimension < N, @src()); for (opts.kernel_spatial_dimensions, 0..) |d, i| { - meta.guard(d != opts.kernel_input_feature_dimension, @src()); - meta.guard(d != opts.kernel_output_feature_dimension, @src()); - meta.guard(0 <= d and d < N, @src()); + stdx.debug.guard(d != opts.kernel_input_feature_dimension, @src()); + stdx.debug.guard(d != opts.kernel_output_feature_dimension, @src()); + stdx.debug.guard(0 <= d and d < N, @src()); if (i < opts.kernel_spatial_dimensions.len - 1) continue; - meta.guard(std.mem.indexOfScalar(i64, opts.kernel_spatial_dimensions[i + 1 ..], d) == null, @src()); + stdx.debug.guard(std.mem.indexOfScalar(i64, opts.kernel_spatial_dimensions[i + 1 ..], d) == null, @src()); } - meta.guard(opts.output_spatial_dimensions.len == N - 2, @src()); - meta.guard(opts.output_batch_dimension != opts.output_feature_dimension, @src()); - meta.guard(0 <= opts.output_batch_dimension and opts.output_batch_dimension < N, @src()); - meta.guard(0 <= opts.output_feature_dimension and opts.output_feature_dimension < N, @src()); + stdx.debug.guard(opts.output_spatial_dimensions.len == N - 2, @src()); + stdx.debug.guard(opts.output_batch_dimension != opts.output_feature_dimension, @src()); + stdx.debug.guard(0 <= opts.output_batch_dimension and opts.output_batch_dimension < N, @src()); + stdx.debug.guard(0 <= opts.output_feature_dimension and opts.output_feature_dimension < N, @src()); for (opts.output_spatial_dimensions, 0..) |d, i| { - meta.guard(d != opts.output_batch_dimension, @src()); - meta.guard(d != opts.output_feature_dimension, @src()); - meta.guard(0 <= d and d < N, @src()); + stdx.debug.guard(d != opts.output_batch_dimension, @src()); + stdx.debug.guard(d != opts.output_feature_dimension, @src()); + stdx.debug.guard(0 <= d and d < N, @src()); if (i < opts.output_spatial_dimensions.len - 1) continue; - meta.guard(std.mem.indexOfScalar(i64, opts.output_spatial_dimensions[i + 1 ..], d) == null, @src()); + stdx.debug.guard(std.mem.indexOfScalar(i64, opts.output_spatial_dimensions[i + 1 ..], d) == null, @src()); } - meta.guard(0 < opts.feature_group_count, @src()); - meta.guard(0 < opts.batch_group_count, @src()); - meta.guard(opts.feature_group_count == 1 or opts.batch_group_count == 1, @src()); + stdx.debug.guard(0 < opts.feature_group_count, @src()); + stdx.debug.guard(0 < opts.batch_group_count, @src()); + stdx.debug.guard(opts.feature_group_count == 1 or opts.batch_group_count == 1, @src()); var used_opts = opts; used_opts.pad_shape = &.{ @intCast(N - 2), 2 }; used_opts.precision_config = &.{ .DEFAULT, .DEFAULT }; @@ -1042,10 +1043,10 @@ pub const Tensor = struct { var batching_axes: [MAX_RANK][2]i8 = undefined; var n_batching: u8 = 0; for (lhs._shape.tags(), 0..) |l, li| { - meta.assert(l != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, lhs }); + stdx.debug.assert(l != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, lhs }); for (rhs._shape.tags(), 0..) |r, ri| { - meta.assert(r != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, rhs }); + stdx.debug.assert(r != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, rhs }); if (l == r) { for (contracting_axes) |ct| { @@ -1114,7 +1115,7 @@ pub const Tensor = struct { contracting_axes: []const [2]i8, batching_axes: []const [2]i8, ) Tensor { - meta.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {} and {}", .{ lhs.dtype(), rhs.dtype() }); + stdx.debug.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {} and {}", .{ lhs.dtype(), rhs.dtype() }); const Axes = std.BoundedArray(i64, MAX_RANK); @@ -1124,7 +1125,7 @@ pub const Tensor = struct { var rhs_batching_axes: Axes = .{}; for (batching_axes) |b_axes| { const l, const r = b_axes; - meta.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs }); + stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs }); var t = lhs._shape.tag(l); if (t == Shape.TagUnknown) t = rhs._shape.tag(r); res_shape = res_shape.appendDim(lhs._shape.dim(l), t); @@ -1137,7 +1138,7 @@ pub const Tensor = struct { var rhs_contracting_axes: Axes = .{}; for (contracting_axes) |c_axes| { const l, const r = c_axes; - meta.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects contracting dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs }); + stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects contracting dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs }); lhs_contracting_axes.appendAssumeCapacity(lhs._shape.axis(l)); rhs_contracting_axes.appendAssumeCapacity(rhs._shape.axis(r)); } @@ -1353,7 +1354,7 @@ pub const Tensor = struct { else toI64(axes__); - meta.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {} and {}", .{ self.rank(), permutation.len }); + stdx.debug.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {} and {}", .{ self.rank(), permutation.len }); if (std.mem.eql(i64, permutation, no_op[0..self.rank()])) { return self; @@ -1386,7 +1387,7 @@ pub const Tensor = struct { /// /// unflatten((d0, d1, axis_m, d3), 2, n) -> (d0, d1, n, d2_m, d3) pub fn unflatten(self: Tensor, axis_: i8, n: i64) Tensor { - meta.assert(self.rank() < Tensor.MAX_RANK, "unflatten expects input tensor rank to be less than {}, got {}", .{ Tensor.MAX_RANK, self.rank() }); + stdx.debug.assert(self.rank() < Tensor.MAX_RANK, "unflatten expects input tensor rank to be less than {}, got {}", .{ Tensor.MAX_RANK, self.rank() }); const a = if (axis_ >= 0) self.axis(axis_) else self.axis(axis_) + 1; const new_dim = std.math.divExact(i64, self.dim(a), n) catch std.debug.panic("unflatten expects chosen dimension to be divisible by 'n' but {} is not divisible by {}", .{ self.dim(a), n }); @@ -1443,7 +1444,7 @@ pub const Tensor = struct { pub fn flatten(self: Tensor, axis_: anytype) Tensor { const old_shape = self._shape; const a = self.axis(axis_); - // meta.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis }); + // stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis }); const new_shape = old_shape.remove(a + 1).set(a, old_shape.dim(a) * old_shape.dim(a + 1)); const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "axis={}", .{axis_}); @@ -1483,7 +1484,7 @@ pub const Tensor = struct { var res_shape: Shape = self._shape; for (slices, 0..) |s, a| { - meta.assert(s.step > 0, "slice expects 'step' to be positive, got {} at index {}", .{ s.step, a }); + stdx.debug.assert(s.step > 0, "slice expects 'step' to be positive, got {} at index {}", .{ s.step, a }); const args: Slice = .{ .start = self.wrapIndex(a, s.start), @@ -1549,7 +1550,7 @@ pub const Tensor = struct { /// Concatenates the input Tensors along the given axis. pub fn concatenate(tensors: []const Tensor, axis_: anytype) Tensor { - meta.assert(tensors.len <= 32, "concatenate only supports up to 32 tensors, got {}", .{tensors.len}); + stdx.debug.assert(tensors.len <= 32, "concatenate only supports up to 32 tensors, got {}", .{tensors.len}); var buffer: [32]mlir.Value = undefined; std.debug.assert(tensors.len <= buffer.len); std.debug.assert(tensors.len > 0); @@ -1576,13 +1577,13 @@ pub const Tensor = struct { /// - Tensor.stack(&.{x, y, z}, .last, .layers) -> .{ .a, .b, .c, .layers } pub fn stack(tensors: []const Tensor, axis_: anytype, tag: anytype) Tensor { // Note: we could ask the compilation context for some memory instead of stack allocating - meta.assert(tensors.len <= 32, "stack only supports up to 32 tensors, got {}", .{tensors.len}); + stdx.debug.assert(tensors.len <= 32, "stack only supports up to 32 tensors, got {}", .{tensors.len}); const shape0 = tensors[0]._shape; const res_shape = shape0.insertTag(axis_, 1, tag); for (tensors[1..]) |tensor| { - meta.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {} and {}", .{ tensor._shape, shape0 }); + stdx.debug.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {} and {}", .{ tensor._shape, shape0 }); } var reshaped: [32]Tensor = undefined; @@ -1623,7 +1624,7 @@ pub const Tensor = struct { /// Repeats a Tensor several times along the given axes. pub fn repeat(self: Tensor, n_reps: []const u63) Tensor { // TODO: this should support the tagged syntax: x.repeat(.{ .a = 3, .b = 2}); - meta.assert(n_reps.len == self.rank(), "repeat expects tensor rank and 'n_reps' length to be equal, got {} and {}", .{ self.rank(), n_reps.len }); + stdx.debug.assert(n_reps.len == self.rank(), "repeat expects tensor rank and 'n_reps' length to be equal, got {} and {}", .{ self.rank(), n_reps.len }); var res = self; for (n_reps, 0..) |n_rep, a| { @@ -1647,7 +1648,7 @@ pub const Tensor = struct { /// Repeats in line each value along the given axes. pub fn stutter(self: Tensor, n_reps: []const u63) Tensor { - meta.assert(n_reps.len == self.rank(), "stutter expects tensor rank and 'n_reps' length to be equal, got {} and {}", .{ self.rank(), n_reps.len }); + stdx.debug.assert(n_reps.len == self.rank(), "stutter expects tensor rank and 'n_reps' length to be equal, got {} and {}", .{ self.rank(), n_reps.len }); var res = self; for (n_reps, 0..) |n_rep, a| { @@ -1724,8 +1725,8 @@ pub const Tensor = struct { /// Returns a Tensor containing evenly spaced values within a given interval. pub fn arange(args: ArangeArgs, dt: DataType) Tensor { - meta.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end }); - meta.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step}); + stdx.debug.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end }); + stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step}); const ctx = CompilationContext.current(); const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "{}, dtype={}", .{ args, dt }); @@ -1770,9 +1771,9 @@ pub const Tensor = struct { /// Returns a Tensor containing 'args.steps' values evenly spaced from 'args.start' to 'args.end', inclusive. pub fn linspace(args: LinspaceArgs, dt: DataType) Tensor { - meta.assert(args.start < args.end, "linspace expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end }); - meta.assert(args.steps > 0, "linspace expects 'args.steps' to be positive, got {}", .{args.steps}); - meta.assert(dt.isFloat(), "linspace expects type to be a float, got {} (hint: use arange instead)", .{dt}); + stdx.debug.assert(args.start < args.end, "linspace expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end }); + stdx.debug.assert(args.steps > 0, "linspace expects 'args.steps' to be positive, got {}", .{args.steps}); + stdx.debug.assert(dt.isFloat(), "linspace expects type to be a float, got {} (hint: use arange instead)", .{dt}); const ctx = CompilationContext.current(); const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "linspace({}, dtype={})", .{ args, dt }); @@ -1824,7 +1825,7 @@ pub const Tensor = struct { /// Returns a Tensor containing the result of the outer product between the input Tensors. pub fn outer(self: Tensor, other: Tensor) Tensor { - meta.assert(self.rank() < 2 and other.rank() < 2 and self.rank() + other.rank() != 0, "outer expects tensor ranks to be at most 1, got {} and {}", .{ self.rank(), other.rank() }); + stdx.debug.assert(self.rank() < 2 and other.rank() < 2 and self.rank() + other.rank() != 0, "outer expects tensor ranks to be at most 1, got {} and {}", .{ self.rank(), other.rank() }); if (self.rank() + other.rank() == 1) { return self.mul(other); @@ -1856,7 +1857,7 @@ pub const Tensor = struct { /// Broadcasts a Tensor to the given shape, adding axes at the beginning. pub fn broadcastLeft(self: Tensor, output_shape: Shape) Tensor { - meta.assert(self.rank() <= output_shape.rank(), "broadcastLeft expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() }); + stdx.debug.assert(self.rank() <= output_shape.rank(), "broadcastLeft expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() }); const a = output_shape.rank() - self.rank(); if (self.rank() == output_shape.rank() and std.mem.eql(i64, self.dims(), output_shape.dims())) { @@ -1868,7 +1869,7 @@ pub const Tensor = struct { /// Broadcasts a Tensor to the given shape, adding axes at the end. pub fn broadcastRight(self: Tensor, output_shape: Shape) Tensor { - meta.assert(self.rank() <= output_shape.rank(), "broadcastRight expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() }); + stdx.debug.assert(self.rank() <= output_shape.rank(), "broadcastRight expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() }); if (self.rank() == output_shape.rank() and self._shape.eql(output_shape)) { return self; @@ -1967,7 +1968,7 @@ pub const Tensor = struct { /// Appends a 1-dim axis, with the given tag. pub fn appendAxes(self: Tensor, t: anytype) Tensor { - meta.assert(self.rank() < Tensor.MAX_RANK - t.len, "appendAxis expects tensor rank to be small enough in order to extend it, got {} and {} (max is {})", .{ self.rank(), t.len, Tensor.MAX_RANK }); + stdx.debug.assert(self.rank() < Tensor.MAX_RANK - t.len, "appendAxis expects tensor rank to be small enough in order to extend it, got {} and {} (max is {})", .{ self.rank(), t.len, Tensor.MAX_RANK }); return self.insertAxes(.last, t); } @@ -1975,7 +1976,7 @@ pub const Tensor = struct { /// Drops a 1-dim axis at the given index pub fn squeeze(self: Tensor, axis_: anytype) Tensor { const a = self.axis(axis_); - meta.assert(self.dim(a) == 1, "squeeze expects axis to be squeezed to have a dimension of 1, got {}", .{self.dim(a)}); + stdx.debug.assert(self.dim(a) == 1, "squeeze expects axis to be squeezed to have a dimension of 1, got {}", .{self.dim(a)}); const new_shape = self._shape.remove(a); // log.debug("squeeze({}, {d}={d}) -> ({})", .{ self, axis, a, new_shape }); @@ -2023,10 +2024,10 @@ pub const Tensor = struct { // scoped_log.debug("gatherValues({}, {any}, {})", .{ self, coord_axes, indices }); const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes); - meta.assert(coord_axes_.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{}); + stdx.debug.assert(coord_axes_.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{}); for (coord_axes_.constSlice(), 0..) |a, i| { if (i > 0) { - meta.assert(a == coord_axes_.get(i - 1) + 1, "gatherValues expects 'coord_axes' to be sequential. But {any} aren't sequential in {}", .{ coord_axes, self }); + stdx.debug.assert(a == coord_axes_.get(i - 1) + 1, "gatherValues expects 'coord_axes' to be sequential. But {any} aren't sequential in {}", .{ coord_axes, self }); } } @@ -2040,7 +2041,7 @@ pub const Tensor = struct { // Note: tags are required for batching. self_kind.appendAssumeCapacity(.batching); indices_batch_axes.appendAssumeCapacity(id_ax); - meta.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. Axis {s} has been found both in 'self={any}', in 'coord_axes_={any}' and in 'indices={}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices }); + stdx.debug.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. Axis {s} has been found both in 'self={any}', in 'coord_axes_={any}' and in 'indices={}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices }); } else if (maybe_coord_ax) |_| { // for gatherValues we collapsed all gathered axes // (contrary to gatherSlices where we collapse none) @@ -2057,7 +2058,7 @@ pub const Tensor = struct { indices.rank() else blk: { const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1); - meta.assert(indices.dim(ax) == coord_axes_.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ coord_axes, coord_axes_.len, indices }); + stdx.debug.assert(indices.dim(ax) == coord_axes_.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ coord_axes, coord_axes_.len, indices }); break :blk ax; }; @@ -2124,7 +2125,7 @@ pub const Tensor = struct { ); const mlir_shape = fromMlirValue(gather_op.result(0)).shape(); - meta.assert(mlir_shape.eql(res_shape), "gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={}, indices={}. You should transpose one or the other.", .{ self, indices }); + stdx.debug.assert(mlir_shape.eql(res_shape), "gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={}, indices={}. You should transpose one or the other.", .{ self, indices }); return _result(res_shape, gather_op.result(0)); } @@ -2201,16 +2202,16 @@ pub const Tensor = struct { const tagged_api = slice_shape.isFullyTagged(); if (tagged_api) { for (slice_shape.tags()) |t| { - meta.assert(self._shape.hasTag(t) != null, "gatherSlices expects `slices_shape` to only use tags from `self`. But {s} wasn't found in {}", .{ t, self }); + stdx.debug.assert(self._shape.hasTag(t) != null, "gatherSlices expects `slices_shape` to only use tags from `self`. But {s} wasn't found in {}", .{ t, self }); } } else { // For untagged api, we require all slices to be specified. // Note: we could relax this and right align the slice. - meta.assert(slice_shape.rank() == self.rank(), "gatherSlices expects `slice_shape.rank()` to match `self.rank()`. Got: gatherSlices({}, slice={_}). To avoid specifying all axes in `slice_shape`, you can use tags.", .{ self, slice_shape }); + stdx.debug.assert(slice_shape.rank() == self.rank(), "gatherSlices expects `slice_shape.rank()` to match `self.rank()`. Got: gatherSlices({}, slice={_}). To avoid specifying all axes in `slice_shape`, you can use tags.", .{ self, slice_shape }); } const index_coord_axis = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1); - meta.assert(indices.dim(index_coord_axis) == slice_shape.rank(), "gatherSlices({}, slice={_}, indices) expects 'indices' to be a tensor [..., {}], got {}", .{ self, slice_shape, slice_shape.rank(), indices }); + stdx.debug.assert(indices.dim(index_coord_axis) == slice_shape.rank(), "gatherSlices({}, slice={_}, indices) expects 'indices' to be a tensor [..., {}], got {}", .{ self, slice_shape, slice_shape.rank(), indices }); // Compute result shape var res_shape = indices._shape.remove(index_coord_axis).withDtype(self.dtype()); @@ -2228,12 +2229,12 @@ pub const Tensor = struct { self_batch_axes.appendAssumeCapacity(@intCast(self_ax)); indices_batch_axes.appendAssumeCapacity(indices._shape.axis(t)); slice_dims.set(self_ax, 1); - meta.assert(slice_shape.hasTag(t) == null, "gatherSlices expect axes to be either batches or slices axes. Axis {s} has been found both in `slices={_}` and `indices={}`", .{ t, slice_shape, indices }); + stdx.debug.assert(slice_shape.hasTag(t) == null, "gatherSlices expect axes to be either batches or slices axes. Axis {s} has been found both in `slices={_}` and `indices={}`", .{ t, slice_shape, indices }); } else if (maybe_slice_ax) |slice_ax| { // Specified axes contains the start offset of the slices, // and are collected in `start_index_map`. const slice_dim = slice_shape.dim(slice_ax); - meta.assert(slice_dim <= self._shape.dim(self_ax), "gatherSlices expects `slice_shape` to be smaller than `self.shape()`. On axis {s}, got {} > {}.", .{ t, slice_shape, self._shape }); + stdx.debug.assert(slice_dim <= self._shape.dim(self_ax), "gatherSlices expects `slice_shape` to be smaller than `self.shape()`. On axis {s}, got {} > {}.", .{ t, slice_shape, self._shape }); slice_dims.set(self_ax, slice_dim); res_shape = res_shape.appendDim(slice_dim, t); start_index_map.appendAssumeCapacity(@intCast(self_ax)); @@ -2395,7 +2396,7 @@ pub const Tensor = struct { const loc = @src(); // scoped_log.debug("scatterSlices({}, {any}, {}, {})", .{ self, coord_axes, indices, updates }); - meta.assert(self.dtype() == updates.dtype(), "scatterSlices expects input and 'updates' tensors to be of the same type, got {} and {}", .{ self.dtype(), updates.dtype() }); + stdx.debug.assert(self.dtype() == updates.dtype(), "scatterSlices expects input and 'updates' tensors to be of the same type, got {} and {}", .{ self.dtype(), updates.dtype() }); const single_coord, const coord_axes_ = _parseGatherCoord(self, coord_axes); const AxisKind = enum { batching, update_window, inserted_window, window_id }; @@ -2420,7 +2421,7 @@ pub const Tensor = struct { indices.rank() else blk: { const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1); - meta.assert(indices.dim(ax) == coord_axes_.len, "scatterSlices({}, coord_axes={any}, indices, updates) expects 'indices' to be a tensor [..., {}], got {}", .{ self, coord_axes, coord_axes_.len, indices }); + stdx.debug.assert(indices.dim(ax) == coord_axes_.len, "scatterSlices({}, coord_axes={any}, indices, updates) expects 'indices' to be a tensor [..., {}], got {}", .{ self, coord_axes, coord_axes_.len, indices }); break :blk ax; }; @@ -2435,7 +2436,7 @@ pub const Tensor = struct { if (self_kind.get(self_ax) == .batching) { up_kind.appendAssumeCapacity(.batching); } else { - meta.assert(updates.dim(up_ax) <= self.dim(self_ax), "scatterSlices expects the slices described in 'updates' to fit inside 'self', but along axis .{s} it doesn't. Got self={}, updates={}.", .{ t, self, updates }); + stdx.debug.assert(updates.dim(up_ax) <= self.dim(self_ax), "scatterSlices expects the slices described in 'updates' to fit inside 'self', but along axis .{s} it doesn't. Got self={}, updates={}.", .{ t, self, updates }); up_kind.appendAssumeCapacity(.update_window); } } else if (t == Shape.TagUnknown or indices._shape.hasTag(t) != null) { @@ -2446,9 +2447,9 @@ pub const Tensor = struct { } const n_indices_axes = updates.rank() - _collectAxes(AxisKind, up_kind, .update_window).len; if (single_coord) { - meta.assert(n_indices_axes == indices.rank(), "scatterSlices({}, {any}) expects 'updates' to contain all axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates }); + stdx.debug.assert(n_indices_axes == indices.rank(), "scatterSlices({}, {any}) expects 'updates' to contain all axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates }); } else { - meta.assert(n_indices_axes == indices.rank() - 1, "scatterSlices({}, {any}) expects 'updates' to contain all-but-last axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates }); + stdx.debug.assert(n_indices_axes == indices.rank() - 1, "scatterSlices({}, {any}) expects 'updates' to contain all-but-last axes from 'indices', got indices={}, updates={}", .{ self, coord_axes, indices, updates }); } const ctx = self.getContext(); @@ -2671,7 +2672,7 @@ pub const Tensor = struct { /// * bubbles up Nan /// * in case of equality the smallest index matching the maximum pub fn argMax(x: Tensor, axis_: anytype, index_dtype: DataType) ArgMaxRes { - meta.assert(index_dtype.isInteger(), "argMax expect index type to be an integer, got {}", .{index_dtype}); + stdx.debug.assert(index_dtype.isInteger(), "argMax expect index type to be an integer, got {}", .{index_dtype}); const a = x.axis(axis_); @@ -2870,7 +2871,7 @@ pub const Tensor = struct { padding: [2][2]i64 = .{ .{ 0, 0 }, .{ 0, 0 } }, }) MaxPoolRes { // TODO: rewrite using modern ZML - meta.guard(self.rank() == 3 or self.rank() == 4, @src()); + stdx.debug.guard(self.rank() == 3 or self.rank() == 4, @src()); // TODO: support maxPool on non last axis // Note: the problem is initPoolArg assuming last axis @@ -3004,14 +3005,14 @@ pub const Tensor = struct { } pub fn split(self: Tensor, allocator: std.mem.Allocator, split_size_or_sections: []const i64, axis_: i64) ![]Tensor { - meta.assert(split_size_or_sections.len > 0, "split expects 'split_size_or_sections' length to be positive, got {}", .{split_size_or_sections.len}); + stdx.debug.assert(split_size_or_sections.len > 0, "split expects 'split_size_or_sections' length to be positive, got {}", .{split_size_or_sections.len}); const a = self.axis(axis_); const length = self.dim(a); if (split_size_or_sections.len != 1) { var split_sum: i64 = 0; for (split_size_or_sections) |n| split_sum += n; - meta.assert(split_sum == length, "split expects sum of 'split_size_or_sections' values and axis dimension to be equal, got {} and {}", .{ split_sum, length }); + stdx.debug.assert(split_sum == length, "split expects sum of 'split_size_or_sections' values and axis dimension to be equal, got {} and {}", .{ split_sum, length }); } const res = try allocator.alloc(Tensor, split_size_or_sections.len); @@ -3029,7 +3030,7 @@ pub const Tensor = struct { /// Note: this doesn't support tagging, if you have tags, /// you should use `dynamicSlice` directly. pub fn dynamicSlice1d(self: Tensor, axis_: i8, len: u63, start_indices: Tensor) Tensor { - meta.assert(start_indices.rank() == 0, "dynamicSlice1d expects 'start_indices' tensor rank to be equal to 0, got {}", .{start_indices.rank()}); + stdx.debug.assert(start_indices.rank() == 0, "dynamicSlice1d expects 'start_indices' tensor rank to be equal to 0, got {}", .{start_indices.rank()}); const a = self.axis(axis_); const new_shape = self._shape.set(a, len); @@ -3087,17 +3088,17 @@ pub const Tensor = struct { const offset = slice_.start; const len = slice_.len; if (slices_tags.len == 0) { - meta.assert(self.rank() == slices.len, "dynamicSlice expects tensor rank and 'slices_' length to be equal, got {} and {}", .{ self.rank(), slices.len }); + stdx.debug.assert(self.rank() == slices.len, "dynamicSlice expects tensor rank and 'slices_' length to be equal, got {} and {}", .{ self.rank(), slices.len }); offset_values[i] = offset.value(); res_shape._dims.set(i, len); - meta.assert(len <= self.dim(i), "dynamicSlice expects slices 'len' to be less than or equal to their corresponding dimension in input tensor, got {} and {} for index {}", .{ len, self.dim(i), i }); + stdx.debug.assert(len <= self.dim(i), "dynamicSlice expects slices 'len' to be less than or equal to their corresponding dimension in input tensor, got {} and {} for index {}", .{ len, self.dim(i), i }); } else { const t = slices_tags.get(i); - const a = res_shape.hasTag(t) orelse meta.panic("dynamicSlice expects input tensor to have tags used in 'slices_' but {s} is missing (input shape is {})", .{ t, self._shape }); + const a = res_shape.hasTag(t) orelse stdx.debug.panic("dynamicSlice expects input tensor to have tags used in 'slices_' but {s} is missing (input shape is {})", .{ t, self._shape }); - meta.assert(len <= self.dim(a), "dynamicSlice expects slices 'len' to be less than their corresponding dimension in input tensor, got {} and {} for axis {s}", .{ len, self.dim(a), t }); + stdx.debug.assert(len <= self.dim(a), "dynamicSlice expects slices 'len' to be less than their corresponding dimension in input tensor, got {} and {} for axis {s}", .{ len, self.dim(a), t }); offset_values[a] = offset.value(); res_shape._dims.set(a, len); @@ -3149,12 +3150,12 @@ pub const Tensor = struct { /// ``` pub fn dynamicUpdateSlice(self: Tensor, offset_: anytype, update_: Tensor) Tensor { // TODO: add updateSlice for when the offset isn't dynamic - meta.assert(self.dtype() == update_.dtype(), "dynamicUpdateSlice expects input and 'update_' tensors to be of the same type, got {} and {}", .{ self.dtype(), update_.dtype() }); + stdx.debug.assert(self.dtype() == update_.dtype(), "dynamicUpdateSlice expects input and 'update_' tensors to be of the same type, got {} and {}", .{ self.dtype(), update_.dtype() }); const offset, const offset_tags = Shape.parseStruct(Tensor, offset_); // log.debug("offset: {any}, offset_tags: {any}", .{ offset, offset_tags }); for (offset.constSlice(), 0..) |start_idx, i| { - meta.assert(start_idx.rank() == 0, "dynamicUpdateSlice expects 'offset_' tensor ranks to be equal to 0, got {} at index {}", .{ start_idx.rank(), i }); + stdx.debug.assert(start_idx.rank() == 0, "dynamicUpdateSlice expects 'offset_' tensor ranks to be equal to 0, got {} at index {}", .{ start_idx.rank(), i }); } const tagged_api = update_._shape.isFullyTagged() and self._shape.isFullyTagged() and offset_tags.len > 0; @@ -3164,14 +3165,14 @@ pub const Tensor = struct { if (tagged_api) { // Check that all update tags are known. for (update._shape._tags.constSlice()) |t| { - meta.assert(self._shape.hasTag(t) != null, "dynamicUpdateSlice expects 'update_' tensor tags to be a subset of input tensor tags but {s} is missing (input shape is {})", .{ t, self._shape }); + stdx.debug.assert(self._shape.hasTag(t) != null, "dynamicUpdateSlice expects 'update_' tensor tags to be a subset of input tensor tags but {s} is missing (input shape is {})", .{ t, self._shape }); } var update_shape = self._shape; var prev_ax: i8 = -1; for (self._shape.tags(), 0..) |t, self_ax| { if (update._shape.hasTag(t)) |up_ax| { - meta.assert(up_ax == prev_ax + 1, "dynamicUpdateSlice expects 'update_' and input tensor axis to have the same order, got {} and {}. (hint: you need to explicitly transpose 'update_')", .{ update_._shape, self._shape }); + stdx.debug.assert(up_ax == prev_ax + 1, "dynamicUpdateSlice expects 'update_' and input tensor axis to have the same order, got {} and {}. (hint: you need to explicitly transpose 'update_')", .{ update_._shape, self._shape }); update_shape._dims.set(self_ax, update.dim(up_ax)); prev_ax = up_ax; @@ -3182,16 +3183,16 @@ pub const Tensor = struct { update = update.reshape(update_shape); } - meta.assert(self.rank() == update.rank(), "dynamicUpdateSlice expects input and computed update tensors to have the same rank, got {} and {} (hint: it's probably an issue on our side)", .{ self.rank(), update.rank() }); + stdx.debug.assert(self.rank() == update.rank(), "dynamicUpdateSlice expects input and computed update tensors to have the same rank, got {} and {} (hint: it's probably an issue on our side)", .{ self.rank(), update.rank() }); for (self.dims(), update.dims(), 0..) |self_d, up_d, ax| { const t = self._shape.debugTag(ax); - meta.assert(up_d <= self_d, "dynamicUpdateSlice expects 'update_' dimensions to be less than or equal to their corresponding dimension in input tensor, got {} and {} for axis .{s}", .{ up_d, self_d, t }); + stdx.debug.assert(up_d <= self_d, "dynamicUpdateSlice expects 'update_' dimensions to be less than or equal to their corresponding dimension in input tensor, got {} and {} for axis .{s}", .{ up_d, self_d, t }); if (tagged_api and up_d < self_d) { const axis_has_offset = std.mem.indexOfScalar(Shape.Tag, offset_tags.constSlice(), self._shape._tags.get(ax)) != null; - meta.assert(axis_has_offset, "dynamicUpdateSlice expects 'update_' dimensions to be equal to their corresponding dimension in input tensor, got {} and {} for axis .{s} (hint: you need to provide an offset)", .{ up_d, self_d, t }); + stdx.debug.assert(axis_has_offset, "dynamicUpdateSlice expects 'update_' dimensions to be equal to their corresponding dimension in input tensor, got {} and {} for axis .{s} (hint: you need to provide an offset)", .{ up_d, self_d, t }); } } @@ -3200,7 +3201,7 @@ pub const Tensor = struct { var offset_values: [MAX_RANK]mlir.Value = undefined; if (offset_tags.len == 0) { // Without offset tags we need the same number of offset than rank. - meta.assert(self.rank() == offset.len, "dynamicUpdateSlice expects input tensor rank and 'offset_' length to be equal, got {} and {}", .{ self.rank(), offset.len }); + stdx.debug.assert(self.rank() == offset.len, "dynamicUpdateSlice expects input tensor rank and 'offset_' length to be equal, got {} and {}", .{ self.rank(), offset.len }); for (offset.constSlice(), 0..) |idx, i| { offset_values[i] = idx.value(); @@ -3210,7 +3211,7 @@ pub const Tensor = struct { // This is only allowed when using tagged sliced. offset_values = .{zero} ** MAX_RANK; for (offset.constSlice(), offset_tags.constSlice()) |start, t| { - const a = self._shape.hasTag(t) orelse meta.panic("dynamicUpdateSlice expects input tensor to have tags used in 'offset_' but {s} is missing (input shape is {})", .{ t, self._shape }); + const a = self._shape.hasTag(t) orelse stdx.debug.panic("dynamicUpdateSlice expects input tensor to have tags used in 'offset_' but {s} is missing (input shape is {})", .{ t, self._shape }); offset_values[a] = start.value(); } } @@ -3329,12 +3330,12 @@ pub const Tensor = struct { /// Returns a Tensor containing the element-wise result of the given 'cmp' comparison between the two input Tensors. pub fn cmp(self: Tensor, direction: dialect.stablehlo.ComparisonDirection.Direction, other: Tensor) Tensor { - meta.assert(self.dtype() == other.dtype(), "cmp expects input tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() }); + stdx.debug.assert(self.dtype() == other.dtype(), "cmp expects input tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() }); if (self.rank() == 0 and other.rank() != 0) return self.broadcast(other._shape, &.{}).cmp(direction, other); if (self.rank() != 0 and other.rank() == 0) return self.cmp(direction, other.broadcast(self._shape, &.{})); - meta.assert(self._shape.eql(other._shape), "cmp expects input tensor shapes to match, got {} and {}", .{ self._shape, other._shape }); + stdx.debug.assert(self._shape.eql(other._shape), "cmp expects input tensor shapes to match, got {} and {}", .{ self._shape, other._shape }); const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "cmp(.{s})", .{@tagName(direction)}); const op = dialect.stablehlo.compare( @@ -3352,7 +3353,7 @@ pub const Tensor = struct { /// For each vector in the input tensor, /// creates a diagonal-matrix where diagonal values are set to the vector values. pub fn toDiagonal(self: Tensor, axis_: anytype, new_tags: [2]EnumLiteral) Tensor { - meta.assert(self.rank() < MAX_RANK - 1, "toDiagonal expects input up to {} rank, got {}", .{ MAX_RANK - 1, self }); + stdx.debug.assert(self.rank() < MAX_RANK - 1, "toDiagonal expects input up to {} rank, got {}", .{ MAX_RANK - 1, self }); const a = self.axis(axis_); const d = self.dim(a); var res_shape = self._shape; @@ -3409,7 +3410,7 @@ pub const Tensor = struct { /// To get the upper triangular part, swap the order of axes: /// `.{ .b = 32, .w = 20, .h = 20 }.triangular(.{ .h, .w }, 0);` pub fn triangular(self: Tensor, axes_: anytype, num_diagonals: i32) Tensor { - meta.assertComptime(meta.isTuple(@TypeOf(axes_)) and axes_.len == 2, "triangular expects exactly two axes to work on.", .{}); + stdx.debug.assertComptime(stdx.meta.isTuple(@TypeOf(axes_)) and axes_.len == 2, "triangular expects exactly two axes to work on.", .{}); const _axes = self.axes(axes_); const x = Tensor.iota(self.shape(), _axes.get(0)); @@ -3472,8 +3473,8 @@ pub const Tensor = struct { /// For each element at index `i`, if `bool_tensor[i] == true`, `output[i] = on_true[i]` /// otherwise, if `bool_tensor[i] == false`, `output[i] = on_false[i]` pub fn select(bool_tensor: Tensor, on_true: Tensor, on_false: Tensor) Tensor { - meta.assert(bool_tensor.dtype() == .bool, "select expects input tensor type to be a boolean, got {}", .{bool_tensor.dtype()}); - meta.assert(on_true.dtype() == on_false.dtype(), "select expects 'on_true' and 'on_false' tensor types to be equal, got {} and {}", .{ on_true.dtype(), on_false.dtype() }); + stdx.debug.assert(bool_tensor.dtype() == .bool, "select expects input tensor type to be a boolean, got {}", .{bool_tensor.dtype()}); + stdx.debug.assert(on_true.dtype() == on_false.dtype(), "select expects 'on_true' and 'on_false' tensor types to be equal, got {} and {}", .{ on_true.dtype(), on_false.dtype() }); if (bool_tensor.rank() != 0 and on_true.rank() == 0) { return bool_tensor.select(on_true.broad(bool_tensor.shape()), on_false); @@ -3482,8 +3483,8 @@ pub const Tensor = struct { return bool_tensor.select(on_true, on_false.broad(bool_tensor.shape())); } - meta.assert(bool_tensor._shape.eqlDims(on_true._shape), "select expects input tensor and 'on_true' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_true._shape }); - meta.assert(bool_tensor._shape.eqlDims(on_false._shape), "select expects input tensor and 'on_false' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_false._shape }); + stdx.debug.assert(bool_tensor._shape.eqlDims(on_true._shape), "select expects input tensor and 'on_true' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_true._shape }); + stdx.debug.assert(bool_tensor._shape.eqlDims(on_false._shape), "select expects input tensor and 'on_false' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_false._shape }); const loc = bool_tensor.getContext().mlirCtx().location(@src()); const op = dialect.stablehlo.select( @@ -3538,11 +3539,11 @@ pub const Tensor = struct { } fn _cartesianProduct(vectors: []const Tensor, out: []Tensor) void { - meta.assert(vectors.len >= 1, "cartesianProduct expects at least one input.", .{}); - meta.assert(vectors.len < Tensor.MAX_RANK, "cartesianProduct expects at most {} input vectors, received {} !", .{ Tensor.MAX_RANK - 1, vectors.len }); + stdx.debug.assert(vectors.len >= 1, "cartesianProduct expects at least one input.", .{}); + stdx.debug.assert(vectors.len < Tensor.MAX_RANK, "cartesianProduct expects at most {} input vectors, received {} !", .{ Tensor.MAX_RANK - 1, vectors.len }); for (vectors) |x| { - meta.assert(x.rank() <= 1, "cartesianProduct expects 0 or 1 rank input vectors. Got: {any}", .{vectors}); - meta.assert(vectors[0].dtype() == x.dtype(), "cartesianProduct expects input vectors to have all the same dtype. Got: {any}", .{vectors}); + stdx.debug.assert(x.rank() <= 1, "cartesianProduct expects 0 or 1 rank input vectors. Got: {any}", .{vectors}); + stdx.debug.assert(vectors[0].dtype() == x.dtype(), "cartesianProduct expects input vectors to have all the same dtype. Got: {any}", .{vectors}); } var res_shape = Shape.init(.{}, vectors[0].dtype()); @@ -3645,7 +3646,7 @@ pub const Tensor = struct { ) fn (Tensor, Tensor) Tensor { return struct { pub fn binaryOpHelper(self: Tensor, other: Tensor) Tensor { - meta.assert(self.dtype() == other.dtype(), "{s} expects tensor to be of same type, got {} and {}", .{ op_name, self, other }); + stdx.debug.assert(self.dtype() == other.dtype(), "{s} expects tensor to be of same type, got {} and {}", .{ op_name, self, other }); if (self.rank() == 0 and other.rank() != 0) { return binaryOpHelper(self.broad(other._shape), other); @@ -3655,7 +3656,7 @@ pub const Tensor = struct { return binaryOpHelper(self, other.broad(self._shape)); } - meta.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {} and {}", .{ op_name, self._shape, other._shape }); + stdx.debug.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {} and {}", .{ op_name, self._shape, other._shape }); const mlirCtx = self.getContext().mlirCtx(); const location = mlirCtx.location(@src()); diff --git a/zml/test_runner.zig b/zml/test_runner.zig index 7c0fe69..fde6de9 100644 --- a/zml/test_runner.zig +++ b/zml/test_runner.zig @@ -1,8 +1,8 @@ //! Test runner for unit test based on https://github.com/ziglang/zig/blob/master/lib/compiler/test_runner.zig with async -const builtin = @import("builtin"); - -const std = @import("std"); const asynk = @import("async"); +const builtin = @import("builtin"); +const std = @import("std"); + const io = std.io; const testing = std.testing; const assert = std.debug.assert; @@ -21,10 +21,10 @@ var fba = std.heap.FixedBufferAllocator.init(&fba_buffer); pub fn main() anyerror!void { testing.log_level = log_level; - try asynk.AsyncThread.main(testing.allocator, asyncMain, .{}); + try asynk.AsyncThread.main(testing.allocator, asyncMain); } -pub fn asyncMain() void { +pub fn asyncMain() !void { const test_fn_list: []const std.builtin.TestFn = builtin.test_functions; var ok_count: usize = 0; var skip_count: usize = 0; diff --git a/zml/testing.zig b/zml/testing.zig index 53d0f02..52b1b87 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -1,11 +1,12 @@ -const std = @import("std"); const builtin = @import("builtin"); +const std = @import("std"); +const stdx = @import("stdx"); const zml = @import("zml.zig"); const meta = @import("meta.zig"); const shapesOf = @import("tensor.zig").shapesOf; -const log = std.log.scoped(.zml_testing); +const log = std.log.scoped(.@"zml/testing"); var _ctx: ?zml.Context = null; @@ -128,7 +129,7 @@ pub fn expectEqualShapes(expected: zml.Shape, actual: zml.Shape) error{TestExpec /// Compile a function and immediatly call it with the given buffers. /// The compiled module is discarded after the call. /// Useful during testing when a module is typically called only once. -pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bufferized(meta.FnParams(func))) !zml.Bufferized(zml.meta.FnResult(func)) { +pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bufferized(stdx.meta.FnArgs(func))) !zml.Bufferized(stdx.meta.FnResult(func)) { // This simplify test API and also ensure this fn isn't used outside of tests. const allocator = std.testing.allocator; var arena = std.heap.ArenaAllocator.init(allocator); @@ -139,7 +140,7 @@ pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bu return x.shape(); } }; - var shape_args: zml.ShapeOf(meta.FnParams(func)) = undefined; + var shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)) = undefined; try meta.mapAlloc(Local.bufferToShape, allocator, {}, buffer_args, &shape_args); const mod = try zml.compileFn(allocator, func, shape_args, platform); @@ -151,7 +152,7 @@ pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bu /// Compile a function and immediatly call it with the given buffers. /// The compiled module is discarded after the call. /// Useful during testing when a module is typically called only once. -pub fn compileAndCallWithTensors(platform: zml.Platform, func: anytype, shape_args: zml.ShapeOf(meta.FnParams(func)), buffer_args: zml.Bufferized(meta.FnParams(func))) !zml.Bufferized(zml.meta.FnResult(func)) { +pub fn compileAndCallWithTensors(platform: zml.Platform, func: anytype, shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)), buffer_args: zml.Bufferized(stdx.meta.FnArgs(func))) !zml.Bufferized(stdx.meta.FnResult(func)) { // This simplify test API and also ensure this fn isn't used outside of tests. const allocator = std.testing.allocator; var arena = std.heap.ArenaAllocator.init(allocator); diff --git a/zml/tokenizer.zig b/zml/tokenizer.zig index 2e211ea..07c684a 100644 --- a/zml/tokenizer.zig +++ b/zml/tokenizer.zig @@ -1,13 +1,15 @@ //! Text tokenizer implementations -const std = @import("std"); const builtin = @import("builtin"); -const testing = std.testing; +const std = @import("std"); +const stdx = @import("stdx"); -const log = std.log.scoped(.zml_tokenizer); +const testing = std.testing; const helpers = @import("helpers.zig"); const meta = @import("meta.zig"); +const log = std.log.scoped(.@"zml/tokenizer"); + test { std.testing.refAllDecls(@This()); std.testing.refAllDecls(Normalizer); @@ -202,7 +204,7 @@ pub const Tokenizer = struct { // Detects memory corruption of tokens. if (cur_tok.len == 0 or cur_tok.len > self.max_token_len) @panic("Token looks corrupted !"); - meta.assert(std.mem.eql(u8, cur_tok, input[input_off..][0..cur_tok.len]), "current token '{s}' not found in input string '{s}' !", .{ cur_tok, input[input_off..] }); + stdx.debug.assert(std.mem.eql(u8, cur_tok, input[input_off..][0..cur_tok.len]), "current token '{s}' not found in input string '{s}' !", .{ cur_tok, input[input_off..] }); } const next_tok = self.tokens[tok_buff[i + 1]]; // if `next_tok` is `.unk`, length is 1; otherwise, it's the length of the token. diff --git a/zml/torch.zig b/zml/torch.zig index c85cd23..0fe69c0 100644 --- a/zml/torch.zig +++ b/zml/torch.zig @@ -1,9 +1,11 @@ const std = @import("std"); -const log = std.log.scoped(.zml_torch); +const stdx = @import("stdx"); const zml = @import("zml.zig"); + const Tensor = zml.Tensor; -const meta = zml.meta; + +const log = std.log.scoped(.zml_torch); /// Multiplies a matrix or a vector with a tensor, /// following the semantic of pytorch `@` operator. @@ -14,7 +16,7 @@ const meta = zml.meta; /// * `matmul(.{10}, .{10}) -> .{}` /// * `matmul(.{10}, .{10}) -> .{}` pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor { - meta.assert(lhs.rank() >= 1 and rhs.rank() >= 1, "Can't matmul({}, {}) ! The two tensors need to have at least rank 1.", .{ lhs, rhs }); + stdx.debug.assert(lhs.rank() >= 1 and rhs.rank() >= 1, "Can't matmul({}, {}) ! The two tensors need to have at least rank 1.", .{ lhs, rhs }); const contracting = [_][2]i8{.{ -1, if (rhs.rank() >= 2) rhs.rank() - 2 else 0 }}; if (lhs.rank() == 1 or rhs.rank() <= 2) { @@ -22,7 +24,7 @@ pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor { return lhs.dotGeneral(rhs, &contracting, &.{}); } - meta.assert(lhs.rank() == 2, "Can't matmul({}, {}) ! One of the two tensors need to have a rank less than 2.", .{ lhs, rhs }); + stdx.debug.assert(lhs.rank() == 2, "Can't matmul({}, {}) ! One of the two tensors need to have a rank less than 2.", .{ lhs, rhs }); // Pytorch treats the extra dimensions of rhs has batching dimensions, // and implicitly broadcast lhs along those. @@ -91,7 +93,7 @@ pub fn unsqueeze( self: Tensor, axis_: anytype, ) Tensor { - meta.assert(self.rank() < Tensor.MAX_RANK - 1, "Can't unsqueeze {}, it's already at max rank.", .{self}); + stdx.debug.assert(self.rank() < Tensor.MAX_RANK - 1, "Can't unsqueeze {}, it's already at max rank.", .{self}); const a = switch (@typeInfo(@TypeOf(axis_))) { .Int, .ComptimeInt => if (axis_ < 0) @as(i8, self.rank()) + 1 + axis_ @@ -125,9 +127,9 @@ test unsqueeze { /// ref: https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#pixelshuffle pub fn pixelShuffle(tensor: Tensor, upscale_factor: u32) Tensor { const shape = tensor.shape(); - meta.assert(shape.hasTags(.{ .c, .w, .h }), "pixelShuffle({}) is invalide. Missing tags {{.c, .w, .h}}", .{tensor}); + stdx.debug.assert(shape.hasTags(.{ .c, .w, .h }), "pixelShuffle({}) is invalide. Missing tags {{.c, .w, .h}}", .{tensor}); - meta.assert(@mod(shape.dim(.c), upscale_factor * upscale_factor) == 0, "pixelShuffle({}) is invalide. Number of channels {}, isn't divisible by upscale factor {}**2", .{ tensor, shape.dim(.c), upscale_factor }); + stdx.debug.assert(@mod(shape.dim(.c), upscale_factor * upscale_factor) == 0, "pixelShuffle({}) is invalide. Number of channels {}, isn't divisible by upscale factor {}**2", .{ tensor, shape.dim(.c), upscale_factor }); const s = tensor.splitAxis(.c, .{ .c = -1, .upscale_h = upscale_factor, .upscale_w = upscale_factor }); const perm = s.shape().contiguousPerm(.{ .h, .upscale_h, .w, .upscale_w }); @@ -173,7 +175,7 @@ test pixelShuffle { /// ref: https://pytorch.org/docs/stable/generated/torch.roll.html pub fn roll(self: Tensor, shifts: []const i64, axes_: []const u8) Tensor { // TODO(hugo) accept following syntax: x.roll(.{ .a = 5, .b = 8 }) - meta.assert(self.rank() > 0 and shifts.len == axes_.len, "Shifts length ({d}) and dims length ({d}) are not equal, we expect the same length.", .{ shifts.len, axes_.len }); + stdx.debug.assert(self.rank() > 0 and shifts.len == axes_.len, "Shifts length ({d}) and dims length ({d}) are not equal, we expect the same length.", .{ shifts.len, axes_.len }); if (shifts.len != 1 or axes_.len != 1) { const tail_shifts = shifts[1..shifts.len]; @@ -219,8 +221,8 @@ pub const MeshgridIndexing = enum { xy, ij }; /// * for ‘ij’ indexing, outputs are of shape (M, N, P) /// * for ‘xy’ indexing, outputs are of shape (N, M, P) pub fn meshgrid(comptime N: u3, vectors: [N]Tensor, indexing: MeshgridIndexing) [N]Tensor { - meta.assertComptime(vectors.len >= 1, "Invalid meshgrid. No input.", .{}); - meta.assertComptime(vectors.len <= Tensor.MAX_RANK, "Invalid meshgrid(...). Too many inputs: {}", .{vectors.len}); + stdx.debug.assertComptime(vectors.len >= 1, "Invalid meshgrid. No input.", .{}); + stdx.debug.assertComptime(vectors.len <= Tensor.MAX_RANK, "Invalid meshgrid(...). Too many inputs: {}", .{vectors.len}); if (vectors.len == 1) return vectors;