Update workspace dependencies to newer LLVM, XLA, StableHLO, and PJRT versions and expose new pjrt plugin attribute APIs and stablehlo version APIs in build and runtime configurations.
This commit is contained in:
parent
726a2d0691
commit
01eff33fa0
10
MODULE.bazel
10
MODULE.bazel
@ -15,6 +15,7 @@ bazel_dep(name = "rules_proto", version = "6.0.2")
|
||||
bazel_dep(name = "buildifier_prebuilt", version = "6.4.0", dev_dependency = True)
|
||||
|
||||
bazel_dep(name = "aspect_bazel_lib", version = "2.8.1.1")
|
||||
|
||||
bazel_lib_toolchains = use_extension("@aspect_bazel_lib//lib:extensions.bzl", "toolchains", dev_dependency = True)
|
||||
use_repo(bazel_lib_toolchains, "jq_toolchains")
|
||||
|
||||
@ -22,6 +23,7 @@ toolchains = use_extension("@hermetic_cc_toolchain//toolchain:ext.bzl", "toolcha
|
||||
use_repo(toolchains, "zig_sdk")
|
||||
|
||||
bazel_dep(name = "rules_zig", version = "20240913.0-1957d05")
|
||||
|
||||
zig = use_extension("@rules_zig//zig:extensions.bzl", "zig")
|
||||
zig.index(file = "//bazel:zig_index.json")
|
||||
zig.toolchain(zig_version = "0.14.0-dev.363+c3faae6bf")
|
||||
@ -31,7 +33,9 @@ zig.mirrors(urls = [
|
||||
use_repo(zig, "zig_toolchains")
|
||||
|
||||
register_toolchains("@rules_zig//zig/target:all")
|
||||
|
||||
register_toolchains("@zig_toolchains//:all")
|
||||
|
||||
register_toolchains(
|
||||
"@zig_sdk//toolchain:linux_amd64_gnu.2.31",
|
||||
"@zig_sdk//toolchain:linux_arm64_gnu.2.31",
|
||||
@ -55,7 +59,7 @@ use_repo(zls, "zls_aarch64-macos", "zls_x86_64-linux")
|
||||
register_toolchains("//third_party/zls:all")
|
||||
|
||||
bazel_dep(name = "libxev", version = "20241119.0-6afcde9")
|
||||
bazel_dep(name = "llvm-raw", version = "20240919.0-94c024a")
|
||||
bazel_dep(name = "llvm-raw", version = "20241022.0-6c4267f")
|
||||
|
||||
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
||||
llvm.configure(
|
||||
@ -67,8 +71,8 @@ llvm.configure(
|
||||
)
|
||||
use_repo(llvm, "llvm-project")
|
||||
|
||||
bazel_dep(name = "stablehlo", version = "20240917.0-78c753a")
|
||||
bazel_dep(name = "xla", version = "20240919.0-1b18dd6")
|
||||
bazel_dep(name = "stablehlo", version = "20241021.0-1c0b606")
|
||||
bazel_dep(name = "xla", version = "20241025.0-4663f04")
|
||||
|
||||
tsl = use_extension("@xla//:tsl.bzl", "tsl")
|
||||
use_repo(tsl, "tsl")
|
||||
|
||||
1563
MODULE.bazel.lock
1563
MODULE.bazel.lock
File diff suppressed because it is too large
Load Diff
@ -1239,19 +1239,52 @@ pub const RngAlgorithm = struct {
|
||||
};
|
||||
|
||||
pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehloCompatibilityRequirement) []const u8 {
|
||||
const Context = struct {
|
||||
str: []const u8 = &.{},
|
||||
};
|
||||
var context = Context{};
|
||||
const state = struct {
|
||||
var buf: [32]u8 = undefined;
|
||||
|
||||
c.stablehloVersionFromCompatibilityRequirement(requirement, (struct {
|
||||
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
|
||||
const inner_ctx: *Context = @ptrCast(@alignCast(userdata));
|
||||
inner_ctx.str = mlir.fromStringRef(mlir_str);
|
||||
fn call(req: c.MlirStablehloCompatibilityRequirement) []u8 {
|
||||
var stream = std.io.fixedBufferStream(&buf);
|
||||
var context = .{ .writer = stream.writer() };
|
||||
const WriterContext = @TypeOf(context);
|
||||
|
||||
c.stablehloVersionFromCompatibilityRequirement(req, (struct {
|
||||
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
|
||||
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
||||
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
||||
}
|
||||
}).callback, &context);
|
||||
|
||||
return buf[0..stream.pos];
|
||||
}
|
||||
}).callback, &context);
|
||||
};
|
||||
|
||||
return context.str;
|
||||
return state.call(requirement);
|
||||
}
|
||||
|
||||
pub fn getCurrentVersion() []const u8 {
|
||||
const state = struct {
|
||||
var buf: [32]u8 = undefined;
|
||||
var str: []const u8 = undefined;
|
||||
var once = std.once(call);
|
||||
|
||||
fn call() void {
|
||||
var stream = std.io.fixedBufferStream(&buf);
|
||||
var writer_ = stream.writer();
|
||||
const ContextWriter = @TypeOf(writer_);
|
||||
|
||||
c.stablehloGetCurrentVersion((struct {
|
||||
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
|
||||
const writer: *ContextWriter = @ptrCast(@alignCast(userdata));
|
||||
_ = writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
||||
}
|
||||
}).callback, &writer_);
|
||||
|
||||
str = buf[0..stream.pos];
|
||||
}
|
||||
};
|
||||
|
||||
state.once.call();
|
||||
return state.str;
|
||||
}
|
||||
|
||||
pub fn getMinimumVersion() []const u8 {
|
||||
|
||||
@ -15,6 +15,7 @@ zig_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":profiler_options_proto",
|
||||
"//stdx",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
const builtin = @import("builtin");
|
||||
const std = @import("std");
|
||||
const stdx = @import("stdx");
|
||||
|
||||
const c = @import("c");
|
||||
|
||||
@ -140,6 +141,17 @@ pub const Api = struct {
|
||||
};
|
||||
}
|
||||
|
||||
pub fn stablehloCurrentVersion(self: *const Api, buf: []u8) ?[]u8 {
|
||||
if (self.getPluginAttribute("stablehlo_current_version")) |v| {
|
||||
stdx.debug.assert(v.kind() == .int64list, "fetched attribute \"stablehlo_current_version\" from the plugin with type `{}`, expected `.int64list`", .{v.kind()});
|
||||
stdx.debug.assert(v.inner.value_size == 3, "expect version format to have 3 elements representing `major.minor.patch` format, got {} elements", .{v.inner.value_size});
|
||||
const value = v.inner.unnamed_0.int64_array_value[0..v.inner.value_size];
|
||||
return std.fmt.bufPrint(buf, "{d}.{d}.{d}", .{ value[0], value[1], value[2] }) catch unreachable;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
pub fn customCallRegistry(api: *const Api) ?CustomCallRegistry {
|
||||
if (api.lookupExtension(c.PJRT_Gpu_Custom_Call, c.PJRT_Extension_Type_Gpu_Custom_Call)) |ext| {
|
||||
return .{ .inner = ext.custom_call.? };
|
||||
@ -147,6 +159,27 @@ pub const Api = struct {
|
||||
// log.warn("No Custom Call registry found for platform: {}", .{self});
|
||||
return null;
|
||||
}
|
||||
|
||||
fn getPluginAttribute(api: *const Api, key: []const u8) ?NamedValue {
|
||||
const attributes = api.getPluginAttributes();
|
||||
for (attributes) |attr| {
|
||||
if (std.mem.eql(u8, attr.name(), key)) {
|
||||
return attr;
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
fn getPluginAttributes(api: *const Api) []const NamedValue {
|
||||
const ret = api.call(.PJRT_Plugin_Attributes, .{
|
||||
.extension_start = null,
|
||||
}) catch unreachable;
|
||||
|
||||
if (ret.attributes == null) return &.{};
|
||||
|
||||
return @ptrCast(ret.attributes[0..ret.num_attributes]);
|
||||
}
|
||||
};
|
||||
|
||||
pub const ErrorCode = enum(c.PJRT_Error_Code) {
|
||||
@ -870,6 +903,8 @@ pub const NamedValue = extern struct {
|
||||
/// * a context struct passed as a slice of bytes
|
||||
pub const CustomCall = fn (*anyopaque, [*]*anyopaque, [*]const u8, usize) callconv(.C) void;
|
||||
|
||||
// todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize
|
||||
// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5
|
||||
pub const CustomCallRegistry = extern struct {
|
||||
inner: *const c.PJRT_Gpu_Register_Custom_Call,
|
||||
|
||||
@ -878,7 +913,7 @@ pub const CustomCallRegistry = extern struct {
|
||||
.function_name = name.ptr,
|
||||
.function_name_size = name.len,
|
||||
.api_version = @intCast(api_version),
|
||||
.custom_call_function = @ptrCast(@constCast(func)),
|
||||
.handler_execute = @ptrCast(@constCast(func)),
|
||||
});
|
||||
const result = self.inner(&ret);
|
||||
if (result) |pjrt_c_error| {
|
||||
|
||||
@ -12,15 +12,15 @@ def _cpu_pjrt_plugin_impl(mctx):
|
||||
http_archive(
|
||||
name = "libpjrt_cpu_linux_amd64",
|
||||
build_file_content = _BUILD.format(ext = "so"),
|
||||
sha256 = "2058c999a4866716f1dae0c42476c09da0f6deff7e77e34c5223b61f5e0027fb",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.2.2/pjrt-cpu_linux-amd64.tar.gz",
|
||||
sha256 = "646b8ea61e690af0e4133637343674fb072e7d5e3a29694e6f84bb66ea75a6f0",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v3.0.0/pjrt-cpu_linux-amd64.tar.gz",
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "libpjrt_cpu_darwin_arm64",
|
||||
build_file_content = _BUILD.format(ext = "dylib"),
|
||||
sha256 = "727b0380a577b2759468a4e0b3574e1d81e1b4348c3942d23284d590c7ca91a5",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.2.2/pjrt-cpu_darwin-arm64.tar.gz",
|
||||
sha256 = "f166ee5ba1d50383731aa79831d4bd2ef3338c5948ae92c2442105d20280506c",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v3.0.0/pjrt-cpu_darwin-arm64.tar.gz",
|
||||
)
|
||||
|
||||
return mctx.extension_metadata(
|
||||
|
||||
@ -179,11 +179,12 @@ cc_import(
|
||||
)
|
||||
""",
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "libpjrt_cuda",
|
||||
build_file = "libpjrt_cuda.BUILD.bazel",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.2.3/pjrt-cuda_linux-amd64.tar.gz",
|
||||
sha256 = "14f39ffef0c9ac529b1a8957750b0b5f5d2f6d310c0d997436051c53a9eb1618",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v3.0.0/pjrt-cuda_linux-amd64.tar.gz",
|
||||
sha256 = "1af968c5357b0b78e43416e2b583512d203aa67a770c6b7e616006e7dd63aecc",
|
||||
)
|
||||
|
||||
return mctx.extension_metadata(
|
||||
|
||||
@ -216,8 +216,8 @@ def _rocm_impl(mctx):
|
||||
http_archive(
|
||||
name = "libpjrt_rocm",
|
||||
build_file = "libpjrt_rocm.BUILD.bazel",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.2.2/pjrt-rocm_linux-amd64.tar.gz",
|
||||
sha256 = "dcb2f8e1fd29e3d7ba8d3018d97a060888e5bcf4847a683cb11686caa6ad9fa2",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v3.0.0/pjrt-rocm_linux-amd64.tar.gz",
|
||||
sha256 = "a7da45dfca820d3defa6de8e782cc334a3f6bdffe65fa972c048994923c2e110",
|
||||
)
|
||||
|
||||
return mctx.extension_metadata(
|
||||
|
||||
11
third_party/modules/llvm-raw/20241022.0-6c4267f/MODULE.bazel
vendored
Normal file
11
third_party/modules/llvm-raw/20241022.0-6c4267f/MODULE.bazel
vendored
Normal file
@ -0,0 +1,11 @@
|
||||
module(
|
||||
name = "llvm-raw",
|
||||
version = "20241022.0-6c4267f",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||
bazel_dep(name = "platforms", version = "0.0.10")
|
||||
bazel_dep(name = "zstd", version = "1.5.6", repo_name = "llvm_zstd")
|
||||
bazel_dep(name = "zlib", version = "1.3.1.bcr.3", repo_name = "llvm_zlib")
|
||||
bazel_dep(name = "rules_python", version = "0.29.0")
|
||||
0
third_party/modules/llvm-raw/20241022.0-6c4267f/overlay/BUILD.bazel
vendored
Normal file
0
third_party/modules/llvm-raw/20241022.0-6c4267f/overlay/BUILD.bazel
vendored
Normal file
11
third_party/modules/llvm-raw/20241022.0-6c4267f/overlay/MODULE.bazel
vendored
Normal file
11
third_party/modules/llvm-raw/20241022.0-6c4267f/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,11 @@
|
||||
module(
|
||||
name = "llvm-raw",
|
||||
version = "20241022.0-6c4267f",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||
bazel_dep(name = "platforms", version = "0.0.10")
|
||||
bazel_dep(name = "zstd", version = "1.5.6", repo_name = "llvm_zstd")
|
||||
bazel_dep(name = "zlib", version = "1.3.1.bcr.3", repo_name = "llvm_zlib")
|
||||
bazel_dep(name = "rules_python", version = "0.29.0")
|
||||
28
third_party/modules/llvm-raw/20241022.0-6c4267f/overlay/utils/bazel/extension.bzl
vendored
Normal file
28
third_party/modules/llvm-raw/20241022.0-6c4267f/overlay/utils/bazel/extension.bzl
vendored
Normal file
@ -0,0 +1,28 @@
|
||||
load("//utils/bazel:configure.bzl", _llvm_configure = "llvm_configure")
|
||||
|
||||
def _llvm_impl(mctx):
|
||||
_targets = {}
|
||||
for mod in mctx.modules:
|
||||
for conf in mod.tags.configure:
|
||||
for target in conf.targets:
|
||||
_targets[target] = True
|
||||
_llvm_configure(
|
||||
name = "llvm-project",
|
||||
targets = _targets.keys(),
|
||||
)
|
||||
return mctx.extension_metadata(
|
||||
reproducible = True,
|
||||
root_module_direct_deps = "all",
|
||||
root_module_direct_dev_deps = [],
|
||||
)
|
||||
|
||||
llvm = module_extension(
|
||||
implementation = _llvm_impl,
|
||||
tag_classes = {
|
||||
"configure": tag_class(
|
||||
attrs = {
|
||||
"targets": attr.string_list(mandatory = True),
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
10
third_party/modules/llvm-raw/20241022.0-6c4267f/source.json
vendored
Normal file
10
third_party/modules/llvm-raw/20241022.0-6c4267f/source.json
vendored
Normal file
@ -0,0 +1,10 @@
|
||||
{
|
||||
"strip_prefix": "llvm-project-6c4267fb1779bc5550bb413f33250f9365acfbc6",
|
||||
"url": "https://github.com/llvm/llvm-project/archive/6c4267fb1779bc5550bb413f33250f9365acfbc6.tar.gz",
|
||||
"integrity": "sha256-cBDuj+hiRvq8rtvtIfqawr0lQuDSrWFypEgeApT981Q=",
|
||||
"overlay": {
|
||||
"BUILD.bazel": "",
|
||||
"MODULE.bazel": "",
|
||||
"utils/bazel/extension.bzl": ""
|
||||
}
|
||||
}
|
||||
1
third_party/modules/llvm-raw/metadata.json
vendored
1
third_party/modules/llvm-raw/metadata.json
vendored
@ -13,6 +13,7 @@
|
||||
"versions": [
|
||||
"20240823.0-f142f8a",
|
||||
"20240919.0-94c024a",
|
||||
"20241022.0-6c4267f",
|
||||
],
|
||||
"yanked_versions": {}
|
||||
}
|
||||
|
||||
15
third_party/modules/stablehlo/20241021.0-1c0b606/MODULE.bazel
vendored
Normal file
15
third_party/modules/stablehlo/20241021.0-1c0b606/MODULE.bazel
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
module(
|
||||
name = "stablehlo",
|
||||
version = "20241021.0-1c0b606",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||
bazel_dep(name = "llvm-raw", version = "20241022.0-6c4267f")
|
||||
|
||||
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
||||
llvm.configure(
|
||||
targets = ["AArch64", "X86", "NVPTX"],
|
||||
)
|
||||
use_repo(llvm, "llvm-project")
|
||||
15
third_party/modules/stablehlo/20241021.0-1c0b606/overlay/MODULE.bazel
vendored
Normal file
15
third_party/modules/stablehlo/20241021.0-1c0b606/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
module(
|
||||
name = "stablehlo",
|
||||
version = "20241021.0-1c0b606",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||
bazel_dep(name = "llvm-raw", version = "20241022.0-6c4267f")
|
||||
|
||||
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
||||
llvm.configure(
|
||||
targets = ["AArch64", "X86", "NVPTX"],
|
||||
)
|
||||
use_repo(llvm, "llvm-project")
|
||||
8
third_party/modules/stablehlo/20241021.0-1c0b606/source.json
vendored
Normal file
8
third_party/modules/stablehlo/20241021.0-1c0b606/source.json
vendored
Normal file
@ -0,0 +1,8 @@
|
||||
{
|
||||
"strip_prefix": "stablehlo-f7f8e4e35296deeff2e12e39421ac8d9599ba340",
|
||||
"url": "https://github.com/openxla/stablehlo/archive/f7f8e4e35296deeff2e12e39421ac8d9599ba340.tar.gz",
|
||||
"integrity": "sha256-/chY6BYV7KcrXIuqXTNIUknISCiA1JYhWE8gG5cEeV4=",
|
||||
"overlay": {
|
||||
"MODULE.bazel": ""
|
||||
}
|
||||
}
|
||||
1
third_party/modules/stablehlo/metadata.json
vendored
1
third_party/modules/stablehlo/metadata.json
vendored
@ -13,6 +13,7 @@
|
||||
"versions": [
|
||||
"20240829.0-54aa1a5",
|
||||
"20240917.0-78c753a",
|
||||
"20241021.0-1c0b606",
|
||||
],
|
||||
"yanked_versions": {}
|
||||
}
|
||||
|
||||
34
third_party/modules/xla/20241025.0-4663f04/MODULE.bazel
vendored
Normal file
34
third_party/modules/xla/20241025.0-4663f04/MODULE.bazel
vendored
Normal file
@ -0,0 +1,34 @@
|
||||
module(
|
||||
name = "xla",
|
||||
version = "20241025.0-4663f04",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "platforms", version = "0.0.8")
|
||||
bazel_dep(name = "bazel_skylib", version = "1.5.0")
|
||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||
bazel_dep(name = "rules_apple", version = "3.2.1", repo_name = "build_bazel_rules_apple")
|
||||
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
|
||||
bazel_dep(name = "rules_python", version = "0.29.0")
|
||||
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||
bazel_dep(name = "rules_java", version = "7.3.2")
|
||||
bazel_dep(name = "rules_pkg", version = "0.9.1")
|
||||
bazel_dep(name = "zlib", version = "1.2.13")
|
||||
bazel_dep(name = "re2", version = "2024-02-01", repo_name = "com_googlesource_code_re2")
|
||||
bazel_dep(name = "rules_license", version = "0.0.8")
|
||||
|
||||
bazel_dep(name = "stablehlo", version = "20241021.0-1c0b606")
|
||||
|
||||
tsl = use_extension("//:tsl.bzl", "tsl")
|
||||
use_repo(tsl, "tsl")
|
||||
|
||||
xla_workspace = use_extension("//:workspace.bzl", "xla_workspace")
|
||||
use_repo(
|
||||
xla_workspace,
|
||||
"com_github_grpc_grpc",
|
||||
"com_google_protobuf",
|
||||
"local_config_cuda",
|
||||
"local_config_remote_execution",
|
||||
"local_config_rocm",
|
||||
"local_config_tensorrt",
|
||||
)
|
||||
34
third_party/modules/xla/20241025.0-4663f04/overlay/MODULE.bazel
vendored
Normal file
34
third_party/modules/xla/20241025.0-4663f04/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,34 @@
|
||||
module(
|
||||
name = "xla",
|
||||
version = "20241025.0-4663f04",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "platforms", version = "0.0.8")
|
||||
bazel_dep(name = "bazel_skylib", version = "1.5.0")
|
||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||
bazel_dep(name = "rules_apple", version = "3.2.1", repo_name = "build_bazel_rules_apple")
|
||||
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
|
||||
bazel_dep(name = "rules_python", version = "0.29.0")
|
||||
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||
bazel_dep(name = "rules_java", version = "7.3.2")
|
||||
bazel_dep(name = "rules_pkg", version = "0.9.1")
|
||||
bazel_dep(name = "zlib", version = "1.2.13")
|
||||
bazel_dep(name = "re2", version = "2024-02-01", repo_name = "com_googlesource_code_re2")
|
||||
bazel_dep(name = "rules_license", version = "0.0.8")
|
||||
|
||||
bazel_dep(name = "stablehlo", version = "20241021.0-1c0b606")
|
||||
|
||||
tsl = use_extension("//:tsl.bzl", "tsl")
|
||||
use_repo(tsl, "tsl")
|
||||
|
||||
xla_workspace = use_extension("//:workspace.bzl", "xla_workspace")
|
||||
use_repo(
|
||||
xla_workspace,
|
||||
"com_github_grpc_grpc",
|
||||
"com_google_protobuf",
|
||||
"local_config_cuda",
|
||||
"local_config_remote_execution",
|
||||
"local_config_rocm",
|
||||
"local_config_tensorrt",
|
||||
)
|
||||
19
third_party/modules/xla/20241025.0-4663f04/overlay/tsl.bzl
vendored
Normal file
19
third_party/modules/xla/20241025.0-4663f04/overlay/tsl.bzl
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
load("//third_party:repo.bzl", "tf_vendored")
|
||||
load("//third_party/py:python_init_repositories.bzl", "python_init_repositories")
|
||||
|
||||
def _tsl_impl(mctx):
|
||||
python_init_repositories(
|
||||
requirements = {
|
||||
"3.11": "//:requirements_lock_3_11.txt",
|
||||
},
|
||||
)
|
||||
tf_vendored(name = "tsl", relpath = "third_party/tsl")
|
||||
return mctx.extension_metadata(
|
||||
reproducible = True,
|
||||
root_module_direct_deps = ["tsl"],
|
||||
root_module_direct_dev_deps = [],
|
||||
)
|
||||
|
||||
tsl = module_extension(
|
||||
implementation = _tsl_impl,
|
||||
)
|
||||
52
third_party/modules/xla/20241025.0-4663f04/overlay/workspace.bzl
vendored
Normal file
52
third_party/modules/xla/20241025.0-4663f04/overlay/workspace.bzl
vendored
Normal file
@ -0,0 +1,52 @@
|
||||
load("@tsl//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
|
||||
load("@tsl//third_party/gpus:cuda_configure.bzl", "cuda_configure")
|
||||
load("@tsl//third_party/gpus:rocm_configure.bzl", "rocm_configure")
|
||||
load("@tsl//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
|
||||
load("@tsl//tools/toolchains/remote:configure.bzl", "remote_execution_configure")
|
||||
|
||||
def _xla_workspace_impl(mctx):
|
||||
cuda_configure(name = "local_config_cuda")
|
||||
remote_execution_configure(name = "local_config_remote_execution")
|
||||
rocm_configure(name = "local_config_rocm")
|
||||
tensorrt_configure(name = "local_config_tensorrt")
|
||||
tf_http_archive(
|
||||
name = "com_github_grpc_grpc",
|
||||
sha256 = "b956598d8cbe168b5ee717b5dafa56563eb5201a947856a6688bbeac9cac4e1f",
|
||||
strip_prefix = "grpc-b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd",
|
||||
system_build_file = "@tsl//third_party/systemlibs:grpc.BUILD",
|
||||
patch_file = [
|
||||
"@tsl//third_party/grpc:generate_cc_env_fix.patch",
|
||||
"@tsl//third_party/grpc:register_go_toolchain.patch",
|
||||
],
|
||||
system_link_files = {
|
||||
"@tsl//third_party/systemlibs:BUILD": "bazel/BUILD",
|
||||
"@tsl//third_party/systemlibs:grpc.BUILD": "src/compiler/BUILD",
|
||||
"@tsl//third_party/systemlibs:grpc.bazel.grpc_deps.bzl": "bazel/grpc_deps.bzl",
|
||||
"@tsl//third_party/systemlibs:grpc.bazel.grpc_extra_deps.bzl": "bazel/grpc_extra_deps.bzl",
|
||||
"@tsl//third_party/systemlibs:grpc.bazel.cc_grpc_library.bzl": "bazel/cc_grpc_library.bzl",
|
||||
"@tsl//third_party/systemlibs:grpc.bazel.generate_cc.bzl": "bazel/generate_cc.bzl",
|
||||
"@tsl//third_party/systemlibs:grpc.bazel.protobuf.bzl": "bazel/protobuf.bzl",
|
||||
},
|
||||
urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz"),
|
||||
)
|
||||
tf_http_archive(
|
||||
name = "com_google_protobuf",
|
||||
patch_file = ["@tsl//third_party/protobuf:protobuf.patch"],
|
||||
sha256 = "f66073dee0bc159157b0bd7f502d7d1ee0bc76b3c1eac9836927511bdc4b3fc1",
|
||||
strip_prefix = "protobuf-3.21.9",
|
||||
system_build_file = "@tsl//third_party/systemlibs:protobuf.BUILD",
|
||||
system_link_files = {
|
||||
"@tsl//third_party/systemlibs:protobuf.bzl": "protobuf.bzl",
|
||||
"@tsl//third_party/systemlibs:protobuf_deps.bzl": "protobuf_deps.bzl",
|
||||
},
|
||||
urls = tf_mirror_urls("https://github.com/protocolbuffers/protobuf/archive/v3.21.9.zip"),
|
||||
)
|
||||
return mctx.extension_metadata(
|
||||
reproducible = True,
|
||||
root_module_direct_deps = "all",
|
||||
root_module_direct_dev_deps = [],
|
||||
)
|
||||
|
||||
xla_workspace = module_extension(
|
||||
implementation = _xla_workspace_impl,
|
||||
)
|
||||
@ -0,0 +1,27 @@
|
||||
From 4db5de34f70d991fedbe28915c8239b97ba7a064 Mon Sep 17 00:00:00 2001
|
||||
From: Steeve Morin <steeve.morin@gmail.com>
|
||||
Date: Mon, 18 Mar 2024 17:17:34 +0100
|
||||
Subject: [PATCH 3/3] [PJRT C API] Ensure C compliance for Profiler Extension
|
||||
|
||||
---
|
||||
xla/pjrt/c/pjrt_c_api_profiler_extension.h | 2 ++
|
||||
1 file changed, 2 insertions(+)
|
||||
|
||||
diff --git a/xla/pjrt/c/pjrt_c_api_profiler_extension.h b/xla/pjrt/c/pjrt_c_api_profiler_extension.h
|
||||
index c821916ad..89a596123 100644
|
||||
--- a/xla/pjrt/c/pjrt_c_api_profiler_extension.h
|
||||
+++ b/xla/pjrt/c/pjrt_c_api_profiler_extension.h
|
||||
@@ -16,8 +16,10 @@ limitations under the License.
|
||||
#ifndef XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_
|
||||
#define XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_
|
||||
|
||||
+#ifdef __cplusplus
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
+#endif
|
||||
|
||||
#include "xla/backends/profiler/plugin/profiler_c_api.h"
|
||||
#include "xla/pjrt/c/pjrt_c_api.h"
|
||||
--
|
||||
2.39.3 (Apple Git-146)
|
||||
|
||||
14
third_party/modules/xla/20241025.0-4663f04/source.json
vendored
Normal file
14
third_party/modules/xla/20241025.0-4663f04/source.json
vendored
Normal file
@ -0,0 +1,14 @@
|
||||
{
|
||||
"strip_prefix": "xla-4663f04d2653d686833f9c306bd0d899c3127358",
|
||||
"url": "https://github.com/openxla/xla/archive/4663f04d2653d686833f9c306bd0d899c3127358.tar.gz",
|
||||
"integrity": "sha256-tQRn0TXkPcmKpCcIuTYM10XNB1qC/d6++4rA1mlnnKw=",
|
||||
"overlay": {
|
||||
"tsl.bzl": "",
|
||||
"workspace.bzl": "",
|
||||
"MODULE.bazel": ""
|
||||
},
|
||||
"patch_strip": 1,
|
||||
"patches": {
|
||||
"0003-PJRT-C-API-Ensure-C-compliance-for-Profiler-Extensio.patch": ""
|
||||
}
|
||||
}
|
||||
1
third_party/modules/xla/metadata.json
vendored
1
third_party/modules/xla/metadata.json
vendored
@ -13,6 +13,7 @@
|
||||
"versions": [
|
||||
"20240902.0-d18cd64",
|
||||
"20240919.0-1b18dd6",
|
||||
"20241025.0-4663f04",
|
||||
],
|
||||
"yanked_versions": {}
|
||||
}
|
||||
|
||||
@ -1236,6 +1236,10 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
||||
.key = .{ .Const = "xla_gpu_enable_triton_gemm" },
|
||||
.value = .{ .value = .{ .bool_field = false } },
|
||||
});
|
||||
// try options.env_option_overrides.append(arena, .{
|
||||
// .key = .{ .Const = "xla_gpu_enable_latency_hiding_scheduler" },
|
||||
// .value = .{ .value = .{ .bool_field = true } },
|
||||
// });
|
||||
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena }) orelse {
|
||||
log.warn("Bazel runfile not found !", .{});
|
||||
break :cuda_dir;
|
||||
|
||||
@ -5,6 +5,7 @@ const mlir = @import("mlir");
|
||||
const pjrt = @import("pjrt");
|
||||
const std = @import("std");
|
||||
const stdx = @import("stdx");
|
||||
const c = @import("c");
|
||||
|
||||
const dtype = @import("dtype.zig");
|
||||
const meta = @import("meta.zig");
|
||||
@ -95,7 +96,12 @@ pub const Client = opaque {
|
||||
|
||||
var serialized_buffer = std.ArrayList(u8).init(allocator);
|
||||
defer serialized_buffer.deinit();
|
||||
dialects.stablehlo.serializePortableArtifact(bytecode.items, dialects.stablehlo.getMinimumVersion(), serialized_buffer.writer()) catch |err| {
|
||||
|
||||
// spec ref: https://github.com/openxla/xla/blob/39967ad6782a861ca029ab8d1a2b25f7e0c3902b/xla/pjrt/pjrt_c_api_client.cc#L399
|
||||
var stablehlo_version_buf: [32]u8 = undefined;
|
||||
const stablehlo_version = api.stablehloCurrentVersion(&stablehlo_version_buf) orelse dialects.stablehlo.stablehloVersionFromCompatibilityRequirement(c.WEEK_12);
|
||||
|
||||
dialects.stablehlo.serializePortableArtifact(bytecode.items, stablehlo_version, serialized_buffer.writer()) catch |err| {
|
||||
log.err("failed to serialize to portable artifact: {}", .{err});
|
||||
return err;
|
||||
};
|
||||
|
||||
@ -2152,10 +2152,9 @@ pub const Tensor = struct {
|
||||
.{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8 } },
|
||||
.{ .{ .a = 10, .b = 20, .c = 30 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8, .c = 30 } },
|
||||
// batching axes are implicits.
|
||||
// TODO: support of batching is broken atm
|
||||
// .{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10 }, .{ .a = 10 } },
|
||||
// .{ .{ .a = 10, .b = 20 }, .a, .{ .b = 20 }, .{ .b = 20 } },
|
||||
// .{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8 } },
|
||||
.{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10 }, .{ .a = 10 } },
|
||||
.{ .{ .a = 10, .b = 20 }, .a, .{ .b = 20 }, .{ .b = 20 } },
|
||||
.{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8 } },
|
||||
// stablehlo.gather is biased toward indices shape (like gatherSlice).
|
||||
// This make it awkward to use when you have both batching dimension and new indices dimensions.
|
||||
// For now we reject those, and let user explicitly transpose self or indices if needed.
|
||||
@ -2295,9 +2294,8 @@ pub const Tensor = struct {
|
||||
.{ .{ .a = 10, .b = 20 }, .{ .b = 17, .a = 7 }, .{ .n = 8, ._ = 2 }, .{ .n = 8, .a = 7, .b = 17 } },
|
||||
.{ .{ .a = 10, .b = 20, .c = 20 }, .{ .b = 17 }, .{ .n = 8, ._ = 1 }, .{ .n = 8, .a = 10, .b = 17, .c = 20 } },
|
||||
// batching dims
|
||||
// TODO: support of batching is broken atm
|
||||
// .{ .{ .a = 10, .b = 20 }, .{ .b = 17 }, .{ .a = 10, ._ = 1 }, .{ .a = 10, .b = 17 } },
|
||||
// .{ .{ .b = 200, .a = 100, .c = 300 }, .{ .c = 300 }, .{ .a = 100, .b = 200, ._ = 1 }, .{ .a = 100, .b = 200, .c = 300 } },
|
||||
.{ .{ .a = 10, .b = 20 }, .{ .b = 17 }, .{ .a = 10, ._ = 1 }, .{ .a = 10, .b = 17 } },
|
||||
.{ .{ .b = 200, .a = 100, .c = 300 }, .{ .c = 300 }, .{ .a = 100, .b = 200, ._ = 1 }, .{ .a = 100, .b = 200, .c = 300 } },
|
||||
}) |testcase| {
|
||||
const x_shape, const slice_dims, const idx_shape, const res_shape = testcase;
|
||||
const x = Tensor.constant(x_shape, .{ .f16 = 0 });
|
||||
@ -2591,8 +2589,7 @@ pub const Tensor = struct {
|
||||
);
|
||||
defer values.deinit();
|
||||
|
||||
// TODO: support of batching is broken atm
|
||||
const result = zml.testing.compileAndCall(platform, Local.scatter, .{ operand, operand.shape().axes(.{ .c, .b }), start_indices, values }) catch return error.SkipZigTest;
|
||||
const result = try zml.testing.compileAndCall(platform, Local.scatter, .{ operand, operand.shape().axes(.{ .c, .b }), start_indices, values });
|
||||
|
||||
const expected = [2][3][4][2]u16{
|
||||
.{
|
||||
|
||||
Loading…
Reference in New Issue
Block a user