# Radix/zml/BUILD.bazel

load("@com_google_protobuf//bazel:upb_proto_library.bzl", "upb_c_proto_library")
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_zig//zig:defs.bzl", "zig_library", "zig_test")
# upb C bindings generated from XLA's `xla_data_proto` target; the `zml`
# Zig library depends on them via `:xla_data_upb`.
upb_c_proto_library(
    name = "xla_data_upb",
    deps = [
        "@xla//xla:xla_data_proto",
    ],
)
# upb C bindings generated from the PJRT `compile_options_proto` target;
# the `zml` Zig library depends on them via `:xla_compile_options_upb`.
upb_c_proto_library(
    name = "xla_compile_options_upb",
    deps = [
        "@xla//xla/pjrt/proto:compile_options_proto",
    ],
)
# Header-only C library (no `srcs`) exposing `posix.h`. Used as a dep of the
# `zml` Zig library; presumably paired with `posix.zig` (TODO confirm).
cc_library(
    name = "posix",
    hdrs = [
        "posix.h",
    ],
)
# The main ZML Zig library. `zml.zig` is the root module; every other Zig
# source in this package is listed in `srcs` so it can be imported from it.
zig_library(
    name = "zml",
    # Kept in lexicographic order, matching buildifier's list sorting.
    srcs = [
        "aio.zig",
        "aio/json.zig",
        "aio/safetensors.zig",
        "aio/tinyllama.zig",
        "aio/torch.zig",
        "aio/torch/eval.zig",
        "aio/torch/file.zig",
        "aio/torch/pickle.zig",
        "aio/torch/py.zig",
        "buffer.zig",
        "callback.zig",
        "context.zig",
        "dtype.zig",
        "exe.zig",
        "floats.zig",
        "helpers.zig",
        "hostbuffer.zig",
        "meta.zig",
        "mlirx.zig",
        "module.zig",
        "nn.zig",
        "nn/cuda.zig",
        "ops.zig",
        "pjrtx.zig",
        "platform.zig",
        "posix.zig",
        "quantization.zig",
        "shape.zig",
        "tensor.zig",
        # NOTE(review): also exported by the `test_runner` filegroup below;
        # confirm it is really needed as a library source.
        "test_runner.zig",
        "testing.zig",
        "torch.zig",
        "zml.zig",
    ],
    # -lc links libc; -freference-trace=20 makes the Zig compiler print up
    # to 20 reference-trace entries when reporting compile errors.
    zigopts = [
        "-lc",
        "-freference-trace=20",
    ],
    main = "zml.zig",
    visibility = ["//visibility:public"],
    deps = [
        ":posix",
        ":xla_compile_options_upb",
        ":xla_data_upb",
        "//async",
        "//mlir",
        "//mlir/dialects",
        "//pjrt",
        "//runtimes",
        "//stdx",
        "//upb",
        "//zml/tokenizer",
        "//zml/tools",
        "@rules_zig//zig/runfiles",
    ],
)
# All ZML tests: runs the tests of the `zml` library using the custom
# runner exported by the `test_runner` filegroup.
zig_test(
    name = "test",
    # Test fixtures; presumably read by the aio/torch tests (TODO confirm).
    data = [
        "aio/torch/simple.pt",
        "aio/torch/simple_test_4.pickle",
    ],
    test_runner = ":test_runner",
    deps = [
        ":zml",
    ],
)
# Publicly visible wrapper around `test_runner.zig`, so any package can pass
# it as the `test_runner` of a `zig_test` (as `:test` in this package does).
filegroup(
    name = "test_runner",
    visibility = ["//visibility:public"],
    srcs = [
        "test_runner.zig",
    ],
)