Update workspace dependencies to newer LLVM, XLA, StableHLO, and PJRT versions and expose new pjrt plugin attribute APIs and stablehlo version APIs in build and runtime configurations.

This commit is contained in:
Tarry Singh 2023-08-07 12:28:36 +00:00
parent 726a2d0691
commit 01eff33fa0
28 changed files with 1957 additions and 32 deletions

View File

@ -15,6 +15,7 @@ bazel_dep(name = "rules_proto", version = "6.0.2")
bazel_dep(name = "buildifier_prebuilt", version = "6.4.0", dev_dependency = True) bazel_dep(name = "buildifier_prebuilt", version = "6.4.0", dev_dependency = True)
bazel_dep(name = "aspect_bazel_lib", version = "2.8.1.1") bazel_dep(name = "aspect_bazel_lib", version = "2.8.1.1")
bazel_lib_toolchains = use_extension("@aspect_bazel_lib//lib:extensions.bzl", "toolchains", dev_dependency = True) bazel_lib_toolchains = use_extension("@aspect_bazel_lib//lib:extensions.bzl", "toolchains", dev_dependency = True)
use_repo(bazel_lib_toolchains, "jq_toolchains") use_repo(bazel_lib_toolchains, "jq_toolchains")
@ -22,6 +23,7 @@ toolchains = use_extension("@hermetic_cc_toolchain//toolchain:ext.bzl", "toolcha
use_repo(toolchains, "zig_sdk") use_repo(toolchains, "zig_sdk")
bazel_dep(name = "rules_zig", version = "20240913.0-1957d05") bazel_dep(name = "rules_zig", version = "20240913.0-1957d05")
zig = use_extension("@rules_zig//zig:extensions.bzl", "zig") zig = use_extension("@rules_zig//zig:extensions.bzl", "zig")
zig.index(file = "//bazel:zig_index.json") zig.index(file = "//bazel:zig_index.json")
zig.toolchain(zig_version = "0.14.0-dev.363+c3faae6bf") zig.toolchain(zig_version = "0.14.0-dev.363+c3faae6bf")
@ -31,7 +33,9 @@ zig.mirrors(urls = [
use_repo(zig, "zig_toolchains") use_repo(zig, "zig_toolchains")
register_toolchains("@rules_zig//zig/target:all") register_toolchains("@rules_zig//zig/target:all")
register_toolchains("@zig_toolchains//:all") register_toolchains("@zig_toolchains//:all")
register_toolchains( register_toolchains(
"@zig_sdk//toolchain:linux_amd64_gnu.2.31", "@zig_sdk//toolchain:linux_amd64_gnu.2.31",
"@zig_sdk//toolchain:linux_arm64_gnu.2.31", "@zig_sdk//toolchain:linux_arm64_gnu.2.31",
@ -55,7 +59,7 @@ use_repo(zls, "zls_aarch64-macos", "zls_x86_64-linux")
register_toolchains("//third_party/zls:all") register_toolchains("//third_party/zls:all")
bazel_dep(name = "libxev", version = "20241119.0-6afcde9") bazel_dep(name = "libxev", version = "20241119.0-6afcde9")
bazel_dep(name = "llvm-raw", version = "20240919.0-94c024a") bazel_dep(name = "llvm-raw", version = "20241022.0-6c4267f")
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm") llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
llvm.configure( llvm.configure(
@ -67,8 +71,8 @@ llvm.configure(
) )
use_repo(llvm, "llvm-project") use_repo(llvm, "llvm-project")
bazel_dep(name = "stablehlo", version = "20240917.0-78c753a") bazel_dep(name = "stablehlo", version = "20241021.0-1c0b606")
bazel_dep(name = "xla", version = "20240919.0-1b18dd6") bazel_dep(name = "xla", version = "20241025.0-4663f04")
tsl = use_extension("@xla//:tsl.bzl", "tsl") tsl = use_extension("@xla//:tsl.bzl", "tsl")
use_repo(tsl, "tsl") use_repo(tsl, "tsl")

File diff suppressed because it is too large Load Diff

View File

@ -1239,19 +1239,52 @@ pub const RngAlgorithm = struct {
}; };
pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehloCompatibilityRequirement) []const u8 { pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehloCompatibilityRequirement) []const u8 {
const Context = struct { const state = struct {
str: []const u8 = &.{}, var buf: [32]u8 = undefined;
};
var context = Context{};
c.stablehloVersionFromCompatibilityRequirement(requirement, (struct { fn call(req: c.MlirStablehloCompatibilityRequirement) []u8 {
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { var stream = std.io.fixedBufferStream(&buf);
const inner_ctx: *Context = @ptrCast(@alignCast(userdata)); var context = .{ .writer = stream.writer() };
inner_ctx.str = mlir.fromStringRef(mlir_str); const WriterContext = @TypeOf(context);
c.stablehloVersionFromCompatibilityRequirement(req, (struct {
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
}
}).callback, &context);
return buf[0..stream.pos];
} }
}).callback, &context); };
return context.str; return state.call(requirement);
}
pub fn getCurrentVersion() []const u8 {
const state = struct {
var buf: [32]u8 = undefined;
var str: []const u8 = undefined;
var once = std.once(call);
fn call() void {
var stream = std.io.fixedBufferStream(&buf);
var writer_ = stream.writer();
const ContextWriter = @TypeOf(writer_);
c.stablehloGetCurrentVersion((struct {
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
const writer: *ContextWriter = @ptrCast(@alignCast(userdata));
_ = writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
}
}).callback, &writer_);
str = buf[0..stream.pos];
}
};
state.once.call();
return state.str;
} }
pub fn getMinimumVersion() []const u8 { pub fn getMinimumVersion() []const u8 {

View File

@ -15,6 +15,7 @@ zig_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":profiler_options_proto", ":profiler_options_proto",
"//stdx",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",

View File

@ -1,5 +1,6 @@
const builtin = @import("builtin"); const builtin = @import("builtin");
const std = @import("std"); const std = @import("std");
const stdx = @import("stdx");
const c = @import("c"); const c = @import("c");
@ -140,6 +141,17 @@ pub const Api = struct {
}; };
} }
pub fn stablehloCurrentVersion(self: *const Api, buf: []u8) ?[]u8 {
if (self.getPluginAttribute("stablehlo_current_version")) |v| {
stdx.debug.assert(v.kind() == .int64list, "fetched attribute \"stablehlo_current_version\" from the plugin with type `{}`, expected `.int64list`", .{v.kind()});
stdx.debug.assert(v.inner.value_size == 3, "expect version format to have 3 elements representing `major.minor.patch` format, got {} elements", .{v.inner.value_size});
const value = v.inner.unnamed_0.int64_array_value[0..v.inner.value_size];
return std.fmt.bufPrint(buf, "{d}.{d}.{d}", .{ value[0], value[1], value[2] }) catch unreachable;
}
return null;
}
pub fn customCallRegistry(api: *const Api) ?CustomCallRegistry { pub fn customCallRegistry(api: *const Api) ?CustomCallRegistry {
if (api.lookupExtension(c.PJRT_Gpu_Custom_Call, c.PJRT_Extension_Type_Gpu_Custom_Call)) |ext| { if (api.lookupExtension(c.PJRT_Gpu_Custom_Call, c.PJRT_Extension_Type_Gpu_Custom_Call)) |ext| {
return .{ .inner = ext.custom_call.? }; return .{ .inner = ext.custom_call.? };
@ -147,6 +159,27 @@ pub const Api = struct {
// log.warn("No Custom Call registry found for platform: {}", .{self}); // log.warn("No Custom Call registry found for platform: {}", .{self});
return null; return null;
} }
fn getPluginAttribute(api: *const Api, key: []const u8) ?NamedValue {
const attributes = api.getPluginAttributes();
for (attributes) |attr| {
if (std.mem.eql(u8, attr.name(), key)) {
return attr;
}
}
return null;
}
fn getPluginAttributes(api: *const Api) []const NamedValue {
const ret = api.call(.PJRT_Plugin_Attributes, .{
.extension_start = null,
}) catch unreachable;
if (ret.attributes == null) return &.{};
return @ptrCast(ret.attributes[0..ret.num_attributes]);
}
}; };
pub const ErrorCode = enum(c.PJRT_Error_Code) { pub const ErrorCode = enum(c.PJRT_Error_Code) {
@ -870,6 +903,8 @@ pub const NamedValue = extern struct {
/// * a context struct passed as a slice of bytes /// * a context struct passed as a slice of bytes
pub const CustomCall = fn (*anyopaque, [*]*anyopaque, [*]const u8, usize) callconv(.C) void; pub const CustomCall = fn (*anyopaque, [*]*anyopaque, [*]const u8, usize) callconv(.C) void;
// todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize
// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5
pub const CustomCallRegistry = extern struct { pub const CustomCallRegistry = extern struct {
inner: *const c.PJRT_Gpu_Register_Custom_Call, inner: *const c.PJRT_Gpu_Register_Custom_Call,
@ -878,7 +913,7 @@ pub const CustomCallRegistry = extern struct {
.function_name = name.ptr, .function_name = name.ptr,
.function_name_size = name.len, .function_name_size = name.len,
.api_version = @intCast(api_version), .api_version = @intCast(api_version),
.custom_call_function = @ptrCast(@constCast(func)), .handler_execute = @ptrCast(@constCast(func)),
}); });
const result = self.inner(&ret); const result = self.inner(&ret);
if (result) |pjrt_c_error| { if (result) |pjrt_c_error| {

View File

@ -12,15 +12,15 @@ def _cpu_pjrt_plugin_impl(mctx):
http_archive( http_archive(
name = "libpjrt_cpu_linux_amd64", name = "libpjrt_cpu_linux_amd64",
build_file_content = _BUILD.format(ext = "so"), build_file_content = _BUILD.format(ext = "so"),
sha256 = "2058c999a4866716f1dae0c42476c09da0f6deff7e77e34c5223b61f5e0027fb", sha256 = "646b8ea61e690af0e4133637343674fb072e7d5e3a29694e6f84bb66ea75a6f0",
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.2.2/pjrt-cpu_linux-amd64.tar.gz", url = "https://github.com/zml/pjrt-artifacts/releases/download/v3.0.0/pjrt-cpu_linux-amd64.tar.gz",
) )
http_archive( http_archive(
name = "libpjrt_cpu_darwin_arm64", name = "libpjrt_cpu_darwin_arm64",
build_file_content = _BUILD.format(ext = "dylib"), build_file_content = _BUILD.format(ext = "dylib"),
sha256 = "727b0380a577b2759468a4e0b3574e1d81e1b4348c3942d23284d590c7ca91a5", sha256 = "f166ee5ba1d50383731aa79831d4bd2ef3338c5948ae92c2442105d20280506c",
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.2.2/pjrt-cpu_darwin-arm64.tar.gz", url = "https://github.com/zml/pjrt-artifacts/releases/download/v3.0.0/pjrt-cpu_darwin-arm64.tar.gz",
) )
return mctx.extension_metadata( return mctx.extension_metadata(

View File

@ -179,11 +179,12 @@ cc_import(
) )
""", """,
) )
http_archive( http_archive(
name = "libpjrt_cuda", name = "libpjrt_cuda",
build_file = "libpjrt_cuda.BUILD.bazel", build_file = "libpjrt_cuda.BUILD.bazel",
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.2.3/pjrt-cuda_linux-amd64.tar.gz", url = "https://github.com/zml/pjrt-artifacts/releases/download/v3.0.0/pjrt-cuda_linux-amd64.tar.gz",
sha256 = "14f39ffef0c9ac529b1a8957750b0b5f5d2f6d310c0d997436051c53a9eb1618", sha256 = "1af968c5357b0b78e43416e2b583512d203aa67a770c6b7e616006e7dd63aecc",
) )
return mctx.extension_metadata( return mctx.extension_metadata(

View File

@ -216,8 +216,8 @@ def _rocm_impl(mctx):
http_archive( http_archive(
name = "libpjrt_rocm", name = "libpjrt_rocm",
build_file = "libpjrt_rocm.BUILD.bazel", build_file = "libpjrt_rocm.BUILD.bazel",
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.2.2/pjrt-rocm_linux-amd64.tar.gz", url = "https://github.com/zml/pjrt-artifacts/releases/download/v3.0.0/pjrt-rocm_linux-amd64.tar.gz",
sha256 = "dcb2f8e1fd29e3d7ba8d3018d97a060888e5bcf4847a683cb11686caa6ad9fa2", sha256 = "a7da45dfca820d3defa6de8e782cc334a3f6bdffe65fa972c048994923c2e110",
) )
return mctx.extension_metadata( return mctx.extension_metadata(

View File

@ -0,0 +1,11 @@
module(
name = "llvm-raw",
version = "20241022.0-6c4267f",
compatibility_level = 1,
)
bazel_dep(name = "bazel_skylib", version = "1.7.1")
bazel_dep(name = "platforms", version = "0.0.10")
bazel_dep(name = "zstd", version = "1.5.6", repo_name = "llvm_zstd")
bazel_dep(name = "zlib", version = "1.3.1.bcr.3", repo_name = "llvm_zlib")
bazel_dep(name = "rules_python", version = "0.29.0")

View File

@ -0,0 +1,11 @@
module(
name = "llvm-raw",
version = "20241022.0-6c4267f",
compatibility_level = 1,
)
bazel_dep(name = "bazel_skylib", version = "1.7.1")
bazel_dep(name = "platforms", version = "0.0.10")
bazel_dep(name = "zstd", version = "1.5.6", repo_name = "llvm_zstd")
bazel_dep(name = "zlib", version = "1.3.1.bcr.3", repo_name = "llvm_zlib")
bazel_dep(name = "rules_python", version = "0.29.0")

View File

@ -0,0 +1,28 @@
load("//utils/bazel:configure.bzl", _llvm_configure = "llvm_configure")
def _llvm_impl(mctx):
_targets = {}
for mod in mctx.modules:
for conf in mod.tags.configure:
for target in conf.targets:
_targets[target] = True
_llvm_configure(
name = "llvm-project",
targets = _targets.keys(),
)
return mctx.extension_metadata(
reproducible = True,
root_module_direct_deps = "all",
root_module_direct_dev_deps = [],
)
llvm = module_extension(
implementation = _llvm_impl,
tag_classes = {
"configure": tag_class(
attrs = {
"targets": attr.string_list(mandatory = True),
},
),
},
)

View File

@ -0,0 +1,10 @@
{
"strip_prefix": "llvm-project-6c4267fb1779bc5550bb413f33250f9365acfbc6",
"url": "https://github.com/llvm/llvm-project/archive/6c4267fb1779bc5550bb413f33250f9365acfbc6.tar.gz",
"integrity": "sha256-cBDuj+hiRvq8rtvtIfqawr0lQuDSrWFypEgeApT981Q=",
"overlay": {
"BUILD.bazel": "",
"MODULE.bazel": "",
"utils/bazel/extension.bzl": ""
}
}

View File

@ -13,6 +13,7 @@
"versions": [ "versions": [
"20240823.0-f142f8a", "20240823.0-f142f8a",
"20240919.0-94c024a", "20240919.0-94c024a",
"20241022.0-6c4267f",
], ],
"yanked_versions": {} "yanked_versions": {}
} }

View File

@ -0,0 +1,15 @@
module(
name = "stablehlo",
version = "20241021.0-1c0b606",
compatibility_level = 1,
)
bazel_dep(name = "bazel_skylib", version = "1.7.1")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "llvm-raw", version = "20241022.0-6c4267f")
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
llvm.configure(
targets = ["AArch64", "X86", "NVPTX"],
)
use_repo(llvm, "llvm-project")

View File

@ -0,0 +1,15 @@
module(
name = "stablehlo",
version = "20241021.0-1c0b606",
compatibility_level = 1,
)
bazel_dep(name = "bazel_skylib", version = "1.7.1")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "llvm-raw", version = "20241022.0-6c4267f")
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
llvm.configure(
targets = ["AArch64", "X86", "NVPTX"],
)
use_repo(llvm, "llvm-project")

View File

@ -0,0 +1,8 @@
{
"strip_prefix": "stablehlo-f7f8e4e35296deeff2e12e39421ac8d9599ba340",
"url": "https://github.com/openxla/stablehlo/archive/f7f8e4e35296deeff2e12e39421ac8d9599ba340.tar.gz",
"integrity": "sha256-/chY6BYV7KcrXIuqXTNIUknISCiA1JYhWE8gG5cEeV4=",
"overlay": {
"MODULE.bazel": ""
}
}

View File

@ -13,6 +13,7 @@
"versions": [ "versions": [
"20240829.0-54aa1a5", "20240829.0-54aa1a5",
"20240917.0-78c753a", "20240917.0-78c753a",
"20241021.0-1c0b606",
], ],
"yanked_versions": {} "yanked_versions": {}
} }

View File

@ -0,0 +1,34 @@
module(
name = "xla",
version = "20241025.0-4663f04",
compatibility_level = 1,
)
bazel_dep(name = "platforms", version = "0.0.8")
bazel_dep(name = "bazel_skylib", version = "1.5.0")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "rules_apple", version = "3.2.1", repo_name = "build_bazel_rules_apple")
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
bazel_dep(name = "rules_python", version = "0.29.0")
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
bazel_dep(name = "rules_java", version = "7.3.2")
bazel_dep(name = "rules_pkg", version = "0.9.1")
bazel_dep(name = "zlib", version = "1.2.13")
bazel_dep(name = "re2", version = "2024-02-01", repo_name = "com_googlesource_code_re2")
bazel_dep(name = "rules_license", version = "0.0.8")
bazel_dep(name = "stablehlo", version = "20241021.0-1c0b606")
tsl = use_extension("//:tsl.bzl", "tsl")
use_repo(tsl, "tsl")
xla_workspace = use_extension("//:workspace.bzl", "xla_workspace")
use_repo(
xla_workspace,
"com_github_grpc_grpc",
"com_google_protobuf",
"local_config_cuda",
"local_config_remote_execution",
"local_config_rocm",
"local_config_tensorrt",
)

View File

@ -0,0 +1,34 @@
module(
name = "xla",
version = "20241025.0-4663f04",
compatibility_level = 1,
)
bazel_dep(name = "platforms", version = "0.0.8")
bazel_dep(name = "bazel_skylib", version = "1.5.0")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "rules_apple", version = "3.2.1", repo_name = "build_bazel_rules_apple")
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
bazel_dep(name = "rules_python", version = "0.29.0")
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
bazel_dep(name = "rules_java", version = "7.3.2")
bazel_dep(name = "rules_pkg", version = "0.9.1")
bazel_dep(name = "zlib", version = "1.2.13")
bazel_dep(name = "re2", version = "2024-02-01", repo_name = "com_googlesource_code_re2")
bazel_dep(name = "rules_license", version = "0.0.8")
bazel_dep(name = "stablehlo", version = "20241021.0-1c0b606")
tsl = use_extension("//:tsl.bzl", "tsl")
use_repo(tsl, "tsl")
xla_workspace = use_extension("//:workspace.bzl", "xla_workspace")
use_repo(
xla_workspace,
"com_github_grpc_grpc",
"com_google_protobuf",
"local_config_cuda",
"local_config_remote_execution",
"local_config_rocm",
"local_config_tensorrt",
)

View File

@ -0,0 +1,19 @@
load("//third_party:repo.bzl", "tf_vendored")
load("//third_party/py:python_init_repositories.bzl", "python_init_repositories")
def _tsl_impl(mctx):
python_init_repositories(
requirements = {
"3.11": "//:requirements_lock_3_11.txt",
},
)
tf_vendored(name = "tsl", relpath = "third_party/tsl")
return mctx.extension_metadata(
reproducible = True,
root_module_direct_deps = ["tsl"],
root_module_direct_dev_deps = [],
)
tsl = module_extension(
implementation = _tsl_impl,
)

View File

@ -0,0 +1,52 @@
load("@tsl//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
load("@tsl//third_party/gpus:cuda_configure.bzl", "cuda_configure")
load("@tsl//third_party/gpus:rocm_configure.bzl", "rocm_configure")
load("@tsl//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
load("@tsl//tools/toolchains/remote:configure.bzl", "remote_execution_configure")
def _xla_workspace_impl(mctx):
cuda_configure(name = "local_config_cuda")
remote_execution_configure(name = "local_config_remote_execution")
rocm_configure(name = "local_config_rocm")
tensorrt_configure(name = "local_config_tensorrt")
tf_http_archive(
name = "com_github_grpc_grpc",
sha256 = "b956598d8cbe168b5ee717b5dafa56563eb5201a947856a6688bbeac9cac4e1f",
strip_prefix = "grpc-b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd",
system_build_file = "@tsl//third_party/systemlibs:grpc.BUILD",
patch_file = [
"@tsl//third_party/grpc:generate_cc_env_fix.patch",
"@tsl//third_party/grpc:register_go_toolchain.patch",
],
system_link_files = {
"@tsl//third_party/systemlibs:BUILD": "bazel/BUILD",
"@tsl//third_party/systemlibs:grpc.BUILD": "src/compiler/BUILD",
"@tsl//third_party/systemlibs:grpc.bazel.grpc_deps.bzl": "bazel/grpc_deps.bzl",
"@tsl//third_party/systemlibs:grpc.bazel.grpc_extra_deps.bzl": "bazel/grpc_extra_deps.bzl",
"@tsl//third_party/systemlibs:grpc.bazel.cc_grpc_library.bzl": "bazel/cc_grpc_library.bzl",
"@tsl//third_party/systemlibs:grpc.bazel.generate_cc.bzl": "bazel/generate_cc.bzl",
"@tsl//third_party/systemlibs:grpc.bazel.protobuf.bzl": "bazel/protobuf.bzl",
},
urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz"),
)
tf_http_archive(
name = "com_google_protobuf",
patch_file = ["@tsl//third_party/protobuf:protobuf.patch"],
sha256 = "f66073dee0bc159157b0bd7f502d7d1ee0bc76b3c1eac9836927511bdc4b3fc1",
strip_prefix = "protobuf-3.21.9",
system_build_file = "@tsl//third_party/systemlibs:protobuf.BUILD",
system_link_files = {
"@tsl//third_party/systemlibs:protobuf.bzl": "protobuf.bzl",
"@tsl//third_party/systemlibs:protobuf_deps.bzl": "protobuf_deps.bzl",
},
urls = tf_mirror_urls("https://github.com/protocolbuffers/protobuf/archive/v3.21.9.zip"),
)
return mctx.extension_metadata(
reproducible = True,
root_module_direct_deps = "all",
root_module_direct_dev_deps = [],
)
xla_workspace = module_extension(
implementation = _xla_workspace_impl,
)

View File

@ -0,0 +1,27 @@
From 4db5de34f70d991fedbe28915c8239b97ba7a064 Mon Sep 17 00:00:00 2001
From: Steeve Morin <steeve.morin@gmail.com>
Date: Mon, 18 Mar 2024 17:17:34 +0100
Subject: [PATCH 3/3] [PJRT C API] Ensure C compliance for Profiler Extension
---
xla/pjrt/c/pjrt_c_api_profiler_extension.h | 2 ++
1 file changed, 2 insertions(+)
diff --git a/xla/pjrt/c/pjrt_c_api_profiler_extension.h b/xla/pjrt/c/pjrt_c_api_profiler_extension.h
index c821916ad..89a596123 100644
--- a/xla/pjrt/c/pjrt_c_api_profiler_extension.h
+++ b/xla/pjrt/c/pjrt_c_api_profiler_extension.h
@@ -16,8 +16,10 @@ limitations under the License.
#ifndef XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_
#define XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_
+#ifdef __cplusplus
#include <cstddef>
#include <cstdint>
+#endif
#include "xla/backends/profiler/plugin/profiler_c_api.h"
#include "xla/pjrt/c/pjrt_c_api.h"
--
2.39.3 (Apple Git-146)

View File

@ -0,0 +1,14 @@
{
"strip_prefix": "xla-4663f04d2653d686833f9c306bd0d899c3127358",
"url": "https://github.com/openxla/xla/archive/4663f04d2653d686833f9c306bd0d899c3127358.tar.gz",
"integrity": "sha256-tQRn0TXkPcmKpCcIuTYM10XNB1qC/d6++4rA1mlnnKw=",
"overlay": {
"tsl.bzl": "",
"workspace.bzl": "",
"MODULE.bazel": ""
},
"patch_strip": 1,
"patches": {
"0003-PJRT-C-API-Ensure-C-compliance-for-Profiler-Extensio.patch": ""
}
}

View File

@ -13,6 +13,7 @@
"versions": [ "versions": [
"20240902.0-d18cd64", "20240902.0-d18cd64",
"20240919.0-1b18dd6", "20240919.0-1b18dd6",
"20241025.0-4663f04",
], ],
"yanked_versions": {} "yanked_versions": {}
} }

View File

@ -1236,6 +1236,10 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
.key = .{ .Const = "xla_gpu_enable_triton_gemm" }, .key = .{ .Const = "xla_gpu_enable_triton_gemm" },
.value = .{ .value = .{ .bool_field = false } }, .value = .{ .value = .{ .bool_field = false } },
}); });
// try options.env_option_overrides.append(arena, .{
// .key = .{ .Const = "xla_gpu_enable_latency_hiding_scheduler" },
// .value = .{ .value = .{ .bool_field = true } },
// });
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena }) orelse { var r_ = try runfiles.Runfiles.create(.{ .allocator = arena }) orelse {
log.warn("Bazel runfile not found !", .{}); log.warn("Bazel runfile not found !", .{});
break :cuda_dir; break :cuda_dir;

View File

@ -5,6 +5,7 @@ const mlir = @import("mlir");
const pjrt = @import("pjrt"); const pjrt = @import("pjrt");
const std = @import("std"); const std = @import("std");
const stdx = @import("stdx"); const stdx = @import("stdx");
const c = @import("c");
const dtype = @import("dtype.zig"); const dtype = @import("dtype.zig");
const meta = @import("meta.zig"); const meta = @import("meta.zig");
@ -95,7 +96,12 @@ pub const Client = opaque {
var serialized_buffer = std.ArrayList(u8).init(allocator); var serialized_buffer = std.ArrayList(u8).init(allocator);
defer serialized_buffer.deinit(); defer serialized_buffer.deinit();
dialects.stablehlo.serializePortableArtifact(bytecode.items, dialects.stablehlo.getMinimumVersion(), serialized_buffer.writer()) catch |err| {
// spec ref: https://github.com/openxla/xla/blob/39967ad6782a861ca029ab8d1a2b25f7e0c3902b/xla/pjrt/pjrt_c_api_client.cc#L399
var stablehlo_version_buf: [32]u8 = undefined;
const stablehlo_version = api.stablehloCurrentVersion(&stablehlo_version_buf) orelse dialects.stablehlo.stablehloVersionFromCompatibilityRequirement(c.WEEK_12);
dialects.stablehlo.serializePortableArtifact(bytecode.items, stablehlo_version, serialized_buffer.writer()) catch |err| {
log.err("failed to serialize to portable artifact: {}", .{err}); log.err("failed to serialize to portable artifact: {}", .{err});
return err; return err;
}; };

View File

@ -2152,10 +2152,9 @@ pub const Tensor = struct {
.{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8 } }, .{ .{ .a = 10, .b = 20 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8 } },
.{ .{ .a = 10, .b = 20, .c = 30 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8, .c = 30 } }, .{ .{ .a = 10, .b = 20, .c = 30 }, .b, .{ .n = 8 }, .{ .a = 10, .n = 8, .c = 30 } },
// batching axes are implicits. // batching axes are implicits.
// TODO: support of batching is broken atm .{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10 }, .{ .a = 10 } },
// .{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10 }, .{ .a = 10 } }, .{ .{ .a = 10, .b = 20 }, .a, .{ .b = 20 }, .{ .b = 20 } },
// .{ .{ .a = 10, .b = 20 }, .a, .{ .b = 20 }, .{ .b = 20 } }, .{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8 } },
// .{ .{ .a = 10, .b = 20 }, .b, .{ .a = 10, .n = 8 }, .{ .a = 10, .n = 8 } },
// stablehlo.gather is biased toward indices shape (like gatherSlice). // stablehlo.gather is biased toward indices shape (like gatherSlice).
// This make it awkward to use when you have both batching dimension and new indices dimensions. // This make it awkward to use when you have both batching dimension and new indices dimensions.
// For now we reject those, and let user explicitly transpose self or indices if needed. // For now we reject those, and let user explicitly transpose self or indices if needed.
@ -2295,9 +2294,8 @@ pub const Tensor = struct {
.{ .{ .a = 10, .b = 20 }, .{ .b = 17, .a = 7 }, .{ .n = 8, ._ = 2 }, .{ .n = 8, .a = 7, .b = 17 } }, .{ .{ .a = 10, .b = 20 }, .{ .b = 17, .a = 7 }, .{ .n = 8, ._ = 2 }, .{ .n = 8, .a = 7, .b = 17 } },
.{ .{ .a = 10, .b = 20, .c = 20 }, .{ .b = 17 }, .{ .n = 8, ._ = 1 }, .{ .n = 8, .a = 10, .b = 17, .c = 20 } }, .{ .{ .a = 10, .b = 20, .c = 20 }, .{ .b = 17 }, .{ .n = 8, ._ = 1 }, .{ .n = 8, .a = 10, .b = 17, .c = 20 } },
// batching dims // batching dims
// TODO: support of batching is broken atm .{ .{ .a = 10, .b = 20 }, .{ .b = 17 }, .{ .a = 10, ._ = 1 }, .{ .a = 10, .b = 17 } },
// .{ .{ .a = 10, .b = 20 }, .{ .b = 17 }, .{ .a = 10, ._ = 1 }, .{ .a = 10, .b = 17 } }, .{ .{ .b = 200, .a = 100, .c = 300 }, .{ .c = 300 }, .{ .a = 100, .b = 200, ._ = 1 }, .{ .a = 100, .b = 200, .c = 300 } },
// .{ .{ .b = 200, .a = 100, .c = 300 }, .{ .c = 300 }, .{ .a = 100, .b = 200, ._ = 1 }, .{ .a = 100, .b = 200, .c = 300 } },
}) |testcase| { }) |testcase| {
const x_shape, const slice_dims, const idx_shape, const res_shape = testcase; const x_shape, const slice_dims, const idx_shape, const res_shape = testcase;
const x = Tensor.constant(x_shape, .{ .f16 = 0 }); const x = Tensor.constant(x_shape, .{ .f16 = 0 });
@ -2591,8 +2589,7 @@ pub const Tensor = struct {
); );
defer values.deinit(); defer values.deinit();
// TODO: support of batching is broken atm const result = try zml.testing.compileAndCall(platform, Local.scatter, .{ operand, operand.shape().axes(.{ .c, .b }), start_indices, values });
const result = zml.testing.compileAndCall(platform, Local.scatter, .{ operand, operand.shape().axes(.{ .c, .b }), start_indices, values }) catch return error.SkipZigTest;
const expected = [2][3][4][2]u16{ const expected = [2][3][4][2]u16{
.{ .{