Add initial Bazel build configuration, async runtime implementation, and core MLIR dialect definitions for ZML.

This commit is contained in:
Tarry Singh 2023-01-02 14:28:25 +00:00
commit 266da6d4be
173 changed files with 40644 additions and 0 deletions

0
BUILD.bazel Normal file
View File

79
MODULE.bazel Normal file
View File

@ -0,0 +1,79 @@
# ZML Bazel module: declares the module name, direct dependencies, and the
# toolchain/runtime extensions used by the build.
module(
    name = "zml",
)

# NOTE(review): `new_git_repository` is bound here but not referenced below —
# confirm it is needed before keeping it.
new_git_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository")

# Core rule sets.
bazel_dep(name = "bazel_skylib", version = "1.7.1")
bazel_dep(name = "hermetic_cc_toolchain", version = "3.1.0")
bazel_dep(name = "patchelf", version = "0.18.0")
bazel_dep(name = "platforms", version = "0.0.10")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "rules_pkg", version = "1.0.1")
bazel_dep(name = "rules_proto", version = "6.0.2")
bazel_dep(name = "buildifier_prebuilt", version = "6.4.0", dev_dependency = True)

# aspect_bazel_lib: dev-only toolchains (jq).
bazel_dep(name = "aspect_bazel_lib", version = "2.8.1.1")
bazel_lib_toolchains = use_extension("@aspect_bazel_lib//lib:extensions.bzl", "toolchains", dev_dependency = True)
use_repo(bazel_lib_toolchains, "jq_toolchains")

# Hermetic C/C++ toolchains built on the Zig SDK.
toolchains = use_extension("@hermetic_cc_toolchain//toolchain:ext.bzl", "toolchains")
use_repo(toolchains, "zig_sdk")

# Zig toolchain: pinned nightly, fetched through the ZML mirror.
bazel_dep(name = "rules_zig", version = "20240913.0-1957d05")
zig = use_extension("@rules_zig//zig:extensions.bzl", "zig")
zig.index(file = "//bazel:zig_index.json")
zig.toolchain(zig_version = "0.14.0-dev.363+c3faae6bf")
zig.mirrors(urls = [
    "https://mirror.zml.ai/zig",
])
use_repo(zig, "zig_toolchains")
register_toolchains("@rules_zig//zig/target:all")
register_toolchains("@zig_toolchains//:all")
register_toolchains(
    "@zig_sdk//toolchain:linux_amd64_gnu.2.31",
    "@zig_sdk//toolchain:linux_arm64_gnu.2.31",
)

# PJRT runtime plugins, one extension per accelerator backend.
cpu = use_extension("//runtimes/cpu:cpu.bzl", "cpu_pjrt_plugin")
use_repo(cpu, "libpjrt_cpu_darwin_arm64", "libpjrt_cpu_linux_amd64")
cuda = use_extension("//runtimes/cuda:cuda.bzl", "cuda_packages")
use_repo(cuda, "libpjrt_cuda")
rocm = use_extension("//runtimes/rocm:rocm.bzl", "rocm_packages")
use_repo(rocm, "libpjrt_rocm")
tpu = use_extension("//runtimes/tpu:tpu.bzl", "tpu_packages")
use_repo(tpu, "libpjrt_tpu")

# Zig language server (editor tooling).
zls = use_extension("//third_party/zls:zls.bzl", "repo")
use_repo(zls, "zls_aarch64-macos", "zls_x86_64-linux")
register_toolchains("//third_party/zls:all")

# Async runtime dependencies.
bazel_dep(name = "libxev", version = "20240910.0-a2d9b31")

# LLVM/MLIR, configured for the target architectures ZML supports.
bazel_dep(name = "llvm-raw", version = "20240823.0-f142f8a")
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
llvm.configure(
    targets = [
        "AArch64",
        "X86",
        "NVPTX",
    ],
)
use_repo(llvm, "llvm-project")

# Compiler stack: StableHLO dialect + XLA (with its TSL support library).
bazel_dep(name = "stablehlo", version = "20240829.0-54aa1a5")
bazel_dep(name = "xla", version = "20240902.0-d18cd64")
tsl = use_extension("@xla//:tsl.bzl", "tsl")
use_repo(tsl, "tsl")

# Misc Zig libraries: coroutines, tokenizer, protobuf, YAML.
bazel_dep(name = "zigcoro", version = "20240829.0-fc1db29")
bazel_dep(name = "sentencepiece", version = "20240618.0-d7ace0a")
bazel_dep(name = "zig-protobuf", version = "20240722.0-c644d11")
bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf")

1305
MODULE.bazel.lock Normal file

File diff suppressed because it is too large Load Diff

11
async/BUILD.bazel Normal file
View File

@ -0,0 +1,11 @@
# Async runtime library: a thin Zig wrapper over zigcoro coroutines and the
# libxev event loop (see async.zig).
load("@rules_zig//zig:defs.bzl", "zig_library")

zig_library(
    name = "async",
    main = "async.zig",
    visibility = ["//visibility:public"],
    deps = [
        "@zigcoro//:libcoro",
        "@libxev//:xev",
    ],
)

400
async/async.zig Normal file
View File

@ -0,0 +1,400 @@
const std = @import("std");
const xev = @import("xev");
const libcoro = @import("libcoro");
const aio = libcoro.asyncio;
/// Normalize a concrete tuple type into the equivalent reified tuple type
/// (`std.meta.Tuple`). Needed because a source-level tuple type and the
/// reified tuple built from the same field types are distinct types.
fn NormalizedTuple(comptime T: type) type {
    const info = @typeInfo(T).Struct;
    var field_types: [info.fields.len]type = undefined;
    inline for (info.fields, 0..) |field, idx| field_types[idx] = field.type;
    return std.meta.Tuple(&field_types);
}
/// Comptime reflection helper describing the signature of `func`.
/// `func` may be a function value or a function type; `argsT`, when non-null,
/// overrides the computed argument-tuple type.
pub fn FnSignature(comptime func: anytype, comptime argsT: ?type) type {
    return struct {
        /// The function type (unwrapped when `func` is itself a type).
        pub const FuncT = if (@TypeOf(func) == type) func else @TypeOf(func);
        /// Tuple type of the call arguments; the empty tuple for 0-arg functions.
        pub const ArgsT = blk: {
            if (@typeInfo(FuncT).Fn.params.len == 0) {
                break :blk @TypeOf(.{});
            }
            break :blk argsT orelse std.meta.ArgsTuple(FuncT);
        };
        /// Full return type, including any error union.
        pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined)));
        /// Payload of the return type (error union stripped when present).
        pub const ReturnPayloadT = blk: {
            break :blk switch (@typeInfo(ReturnT)) {
                .ErrorUnion => |u| u.payload,
                else => ReturnT,
            };
        };
        /// Error set of the return type, or null when the function cannot fail.
        pub const ReturnErrorSet: ?type = blk: {
            break :blk switch (@typeInfo(ReturnT)) {
                .ErrorUnion => |u| u.error_set,
                else => null,
            };
        };
    };
}
/// Frame type for an async invocation of `func` using its natural
/// argument-tuple type (see `FrameEx` for an explicit args type).
pub fn Frame(comptime func: anytype) type {
    return FrameEx(func, FnSignature(func, null).ArgsT);
}
/// Frame type for an async invocation of `func` with an explicit
/// argument-tuple type `argsT`.
pub fn FrameEx(comptime func: anytype, comptime argsT: type) type {
    return FrameExx(func, argsT);
}
/// Implementation of `FrameEx`: wraps a libcoro frame and exposes a typed
/// `await_`.
fn FrameExx(comptime func: anytype, comptime argsT: type) type {
    return struct {
        const Self = @This();
        const Signature = FnSignature(func, argsT);
        const FrameT = libcoro.FrameT(func, .{ .ArgsT = Signature.ArgsT });

        inner: FrameT,

        /// Waits for the coroutine to finish and returns its result.
        /// The frame is deinitialized on return; `self` is invalid afterwards.
        pub fn await_(self: *Self) Signature.ReturnT {
            defer {
                self.inner.deinit();
                self.* = undefined;
            }
            return libcoro.xawait(self.inner);
        }

        /// Wraps an untyped libcoro frame handle into this typed frame.
        fn from(other: anytype) !Self {
            return .{ .inner = FrameT.wrap(other.frame()) };
        }
    };
}
/// Starts `func(args...)` as a coroutine on the current thread's executor
/// and returns its typed frame. The caller must eventually `await_` the
/// returned frame (which also releases it).
pub fn async_(comptime func: anytype, args: anytype) !FrameEx(func, @TypeOf(args)) {
    const frame = try aio.xasync(func, args, null);
    return FrameEx(func, @TypeOf(args)).from(frame);
}
/// Runs the blocking function `func` on the thread pool, suspending the
/// calling coroutine until it completes. Typed convenience wrapper over
/// `callGeneric` that fixes `args` to the function's natural argument tuple.
pub fn call(comptime func: anytype, args: std.meta.ArgsTuple(@TypeOf(func))) @TypeOf(callGeneric(func, args)) {
    return callGeneric(func, args);
}
/// Schedules `func(args...)` on the xev thread pool and suspends the calling
/// coroutine (via a `Notification`) until the call completes, then returns
/// its result. Use this to run blocking work without stalling the event loop.
/// Panics if the completion notification cannot be created or delivered.
pub fn callGeneric(comptime func: anytype, args: anytype) FnSignature(func, @TypeOf(args)).ReturnT {
    const Signature = FnSignature(func, @TypeOf(args));
    // Task state shared with the pool thread: the xev task node, the wake-up
    // notification, a pointer to the caller's args, and the result slot.
    const TaskT = struct {
        const Self = @This();
        _task: xev.ThreadPool.Task = .{ .callback = &Self.run },
        notif: Notification,
        args: *const Signature.ArgsT,
        result: Signature.ReturnT = undefined,

        /// Thread-pool entry point: recovers the task from its embedded
        /// `_task` node, runs `func`, stores the result and signals the waiter.
        pub fn run(task_: *xev.ThreadPool.Task) void {
            const task: *Self = @alignCast(@fieldParentPtr("_task", task_));
            task.result = @call(.auto, func, task.args.*);
            task.notif.notify() catch @panic("Unable to notify");
        }
    };
    // `newtask` lives on this coroutine's stack; it stays valid because we
    // block below until the pool thread has finished with it.
    var newtask: TaskT = .{
        .notif = Notification.init() catch @panic("Notification.init failed"),
        .args = &args,
    };
    defer newtask.notif.deinit();
    AsyncThread.current.thread_pool.schedule(xev.ThreadPool.Batch.from(&newtask._task));
    newtask.notif.wait() catch @panic("Unable to wait for notification");
    return newtask.result;
}
/// Runs one iteration of the current thread's coroutine executor.
pub fn tick() void {
    AsyncThread.current.executor.exec.tick();
}

/// Suspends the calling coroutine for `ms` milliseconds without blocking the
/// event loop.
pub fn sleep(ms: u64) !void {
    try aio.sleep(AsyncThread.current.executor, ms);
}
/// One-shot, coroutine-aware notification: `wait` suspends the calling
/// coroutine until another thread calls `notify`. Built on `xev.Async`, so
/// `notify` is safe to call from a thread-pool thread.
pub const Notification = struct {
    inner: aio.AsyncNotification,

    /// Registers a new notification with the current thread's executor.
    pub fn init() !Notification {
        return .{
            .inner = aio.AsyncNotification.init(AsyncThread.current.executor, try xev.Async.init()),
        };
    }

    /// Signals the waiter; may be called from any thread.
    pub fn notify(self: *Notification) !void {
        try self.inner.notif.notify();
    }

    /// Suspends the calling coroutine until `notify` is called.
    pub fn wait(self: *Notification) !void {
        try self.inner.wait();
    }

    /// Releases the underlying xev handle; `self` is invalid afterwards.
    pub fn deinit(self: *Notification) void {
        self.inner.notif.deinit();
        self.* = undefined;
    }
};
/// Per-thread async runtime context: the coroutine executor, event loop and
/// thread pool backing all APIs in this module. `AsyncThread.current` must be
/// initialized (by `main`) before any other function in this file is used.
pub const AsyncThread = struct {
    // Thread-local runtime handle; set up by `main` below.
    threadlocal var current: AsyncThread = undefined;

    executor: *aio.Executor,
    loop: *xev.Loop,
    thread_pool: *xev.ThreadPool,

    /// Entry point of the async runtime: builds the thread pool, event loop
    /// and executor, installs them as `AsyncThread.current`, then runs
    /// `func(args...)` as the root coroutine to completion.
    /// Coroutine stacks are allocated from `allocator` (16 MiB default size).
    /// Returns the payload of `func`'s result, propagating its error if any.
    pub fn main(allocator: std.mem.Allocator, comptime func: anytype, args: anytype) !FnSignature(func, NormalizedTuple(@TypeOf(args))).ReturnPayloadT {
        const Signature = FnSignature(func, NormalizedTuple(@TypeOf(args)));
        var thread_pool = xev.ThreadPool.init(.{});
        defer {
            thread_pool.shutdown();
            thread_pool.deinit();
        }
        var loop = try xev.Loop.init(.{
            .thread_pool = &thread_pool,
        });
        defer loop.deinit();
        var executor = aio.Executor.init(&loop);
        // These pointers reference locals of this frame; they stay valid for
        // the lifetime of the runtime because `main` only returns after
        // `aio.run` has finished.
        AsyncThread.current = .{
            .executor = &executor,
            .loop = &loop,
            .thread_pool = &thread_pool,
        };
        aio.initEnv(.{
            .stack_allocator = allocator,
            .default_stack_size = 16 * 1024 * 1024,
        });
        // Branch on whether `func` can fail so the non-failing case does not
        // produce a mismatched `try` on a non-error-union value.
        if (Signature.ReturnErrorSet) |_| {
            return try aio.run(&executor, func, args, null);
        } else {
            return aio.run(&executor, func, args, null);
        }
    }
};
/// Wraps the process's standard input in an async-aware `File`.
/// Errors from registering the fd with the event loop are propagated,
/// matching the `!File` return type. (Previously the error was swallowed by
/// a `catch @panic`, making the error union in the signature unreachable.)
pub fn StdIn() !File {
    return File.init(std.io.getStdIn());
}
/// Wraps the process's standard output in an async-aware `File`.
/// Panics if the fd cannot be registered with the event loop.
pub fn StdOut() File {
    const stdout = std.io.getStdOut();
    return File.init(stdout) catch @panic("Unable to open stdout");
}

/// Wraps the process's standard error in an async-aware `File`.
/// Panics if the fd cannot be registered with the event loop.
pub fn StdErr() File {
    const stderr = std.io.getStdErr();
    return File.init(stderr) catch @panic("Unable to open stderr");
}
/// Async-aware file handle mirroring the `std.fs.File` API: reads and writes
/// suspend the calling coroutine instead of blocking the thread. Blocking
/// operations with no async counterpart (stat/seek) are dispatched to the
/// thread pool via `call`.
pub const File = struct {
    pub const SeekError = FnSignature(File.seekTo, null).ReturnErrorSet.? || FnSignature(File.seekBy, null).ReturnErrorSet.?;
    pub const GetSeekPosError = SeekError || FnSignature(File.stat, null).ReturnErrorSet.?;
    pub const Reader = std.io.GenericReader(File, FnSignature(File.read, null).ReturnErrorSet.?, File.read);
    pub const Writer = std.io.GenericWriter(File, FnSignature(File.write, null).ReturnErrorSet.?, File.write);
    pub const SeekableStream = std.io.SeekableStream(
        File,
        SeekError,
        GetSeekPosError,
        seekTo,
        seekBy,
        getPos,
        getEndPos,
    );

    inner: aio.File,

    /// Re-wraps the underlying fd as a plain `std.fs.File` for blocking ops.
    fn asFile(self: File) std.fs.File {
        return .{ .handle = self.inner.file.fd };
    }

    /// Registers an existing `std.fs.File` with the current thread's executor.
    pub fn init(file_: std.fs.File) !File {
        return .{ .inner = aio.File.init(AsyncThread.current.executor, try xev.File.init(file_)) };
    }

    /// Registers a raw fd with the current thread's executor.
    pub fn fromFd(fd: std.fs.File.Handle) !File {
        return .{ .inner = aio.File.init(AsyncThread.current.executor, try xev.File.initFd(fd)) };
    }

    /// Opens `path` relative to the cwd (open happens on the thread pool).
    pub fn open(path: []const u8, flags: std.fs.File.OpenFlags) !File {
        return init(try call(std.fs.Dir.openFile, .{ std.fs.cwd(), path, flags }));
    }

    /// Reads into `buf`, suspending the caller; returns 0 at end of file.
    pub fn read(self: File, buf: []u8) !usize {
        // NOTE(Corentin): Early return is required to avoid error with xev on Linux with io_uring backend.
        if (buf.len == 0) return 0;
        return self.inner.read(.{ .slice = buf }) catch |err| switch (err) {
            // NOTE(Corentin): read shouldn't return an error on EOF, but a read length of 0 instead. This is to be iso with std.fs.File.
            error.EOF => 0,
            else => err,
        };
    }

    /// Positional read at `offset`, suspending the caller; returns 0 at EOF.
    pub fn pread(self: File, buf: []u8, offset: u64) !usize {
        // NOTE(Corentin): Early return is required to avoid error with xev on Linux with io_uring backend.
        if (buf.len == 0) return 0;
        return self.inner.pread(.{ .slice = buf }, offset) catch |err| switch (err) {
            // NOTE(Corentin): pread shouldn't return an error on EOF, but a read length of 0 instead. This is to be iso with std.fs.File.
            error.EOF => 0,
            else => err,
        };
    }

    /// Writes `buf`, suspending the caller; returns the number of bytes written.
    pub fn write(self: File, buf: []const u8) !usize {
        return self.inner.write(.{ .slice = buf });
    }

    /// Positional write at `offset`, suspending the caller.
    pub fn pwrite(self: File, buf: []const u8, offset: u64) !usize {
        return self.inner.pwrite(.{ .slice = buf }, offset);
    }

    pub fn close(self: File) !void {
        return self.inner.close();
    }

    pub fn reader(self: File) Reader {
        return .{ .context = self };
    }

    pub fn seekableStream(file: File) SeekableStream {
        return .{ .context = file };
    }

    pub fn writer(self: File) Writer {
        return .{ .context = self };
    }

    // The following blocking std.fs.File operations are run on the thread
    // pool via `call` so they do not stall the event loop.
    pub fn stat(self: File) !std.fs.File.Stat {
        return try call(std.fs.File.stat, .{self.asFile()});
    }

    pub fn seekBy(self: File, offset: i64) !void {
        try call(std.fs.File.seekBy, .{ self.asFile(), offset });
    }

    pub fn seekTo(self: File, offset: u64) !void {
        try call(std.fs.File.seekTo, .{ self.asFile(), offset });
    }

    pub fn getPos(self: File) !u64 {
        return try call(std.fs.File.getPos, .{self.asFile()});
    }

    pub fn getEndPos(self: File) !u64 {
        return try call(std.fs.File.getEndPos, .{self.asFile()});
    }
};
/// Async socket wrappers over the libxev/zigcoro event loop.
pub const Socket = struct {
    /// Async TCP stream socket.
    pub const TCP = struct {
        pub const Reader = std.io.GenericReader(TCP, FnSignature(TCP.read, null).ReturnErrorSet.?, TCP.read);
        pub const Writer = std.io.GenericWriter(TCP, FnSignature(TCP.write, null).ReturnErrorSet.?, TCP.write);

        inner: aio.TCP,

        /// Creates a socket for `addr` registered with the current executor.
        pub fn init(addr: std.net.Address) !TCP {
            return .{ .inner = aio.TCP.init(AsyncThread.current.executor, try xev.TCP.init(addr)) };
        }

        // NOTE(review): deinit shuts the connection down but does not close
        // the fd; confirm `close` is the intended release path.
        pub fn deinit(self: *TCP) void {
            self.inner.shutdown();
        }

        pub fn connect(self: *TCP, addr: std.net.Address) !void {
            return self.inner.connect(addr);
        }

        // Fixed: read/write now take `self: TCP` by value so their types
        // match the read/write function signatures declared in `Reader` /
        // `Writer` above (which use a TCP value as context), mirroring how
        // `File` and `UDP` declare theirs.
        pub fn read(self: TCP, buf: []u8) !usize {
            return self.inner.read(.{ .slice = buf });
        }

        pub fn write(self: TCP, buf: []const u8) !usize {
            return self.inner.write(.{ .slice = buf });
        }

        pub fn close(self: *TCP) !void {
            defer self.* = undefined;
            return self.inner.close();
        }

        // Fixed: these previously took `self: File`, which matches neither
        // this type nor the declared Reader/Writer context type.
        pub fn reader(self: TCP) Reader {
            return .{ .context = self };
        }

        pub fn writer(self: TCP) Writer {
            return .{ .context = self };
        }
    };

    /// Async UDP datagram socket.
    pub const UDP = struct {
        pub const Reader = std.io.GenericReader(UDP, FnSignature(UDP.read, null).ReturnErrorSet.?, UDP.read);
        /// Write context bundling the socket with a destination address,
        /// since UDP writes are addressed per datagram.
        pub const WriterContext = struct {
            file: UDP,
            addr: std.net.Address,
        };
        pub const Writer = std.io.GenericWriter(WriterContext, FnSignature(UDP.write, null).ReturnErrorSet.?, struct {
            fn call(self: WriterContext, buf: []const u8) !usize {
                return self.file.write(self.addr, buf);
            }
        }.call);

        inner: aio.UDP,

        /// Creates a socket bound for `addr` registered with the current executor.
        pub fn init(addr: std.net.Address) !UDP {
            return .{ .inner = aio.UDP.init(AsyncThread.current.executor, try xev.UDP.init(addr)) };
        }

        pub fn read(self: UDP, buf: []u8) !usize {
            return self.inner.read(.{ .slice = buf });
        }

        pub fn write(self: UDP, addr: std.net.Address, buf: []const u8) !usize {
            return self.inner.write(addr, .{ .slice = buf });
        }

        pub fn close(self: *UDP) !void {
            defer self.* = undefined;
            return self.inner.close();
        }

        // Fixed: these previously took `self: File`, which matches neither
        // this type nor `Reader`'s context / `WriterContext.file` (both UDP).
        pub fn reader(self: UDP) Reader {
            return .{ .context = self };
        }

        pub fn writer(self: UDP, addr: std.net.Address) Writer {
            return .{
                .context = .{
                    .file = self,
                    .addr = addr,
                },
            };
        }
    };
};
/// Coroutine-aware mutex built on a single-slot channel: `lock` sends into
/// the channel, suspending while it is full (i.e. while the mutex is held);
/// `unlock` drains the slot, waking one waiter.
pub const Mutex = struct {
    const VoidChannel = libcoro.Channel(void, .{ .capacity = 1 });

    inner: VoidChannel,

    /// Creates a mutex bound to the current thread's executor.
    pub fn init() Mutex {
        return .{ .inner = VoidChannel.init(&AsyncThread.current.executor.exec) };
    }

    /// Acquires the mutex, suspending the calling coroutine if it is held.
    pub fn lock(self: *Mutex) !void {
        try self.inner.send({});
    }

    /// Releases the mutex. Must only be called by the current holder.
    pub fn unlock(self: *Mutex) void {
        _ = self.inner.recv();
    }
};

0
bazel/BUILD.bazel Normal file
View File

96
bazel/cc_import.bzl Normal file
View File

@ -0,0 +1,96 @@
load("@bazel_tools//tools/build_defs/cc:cc_import.bzl", _cc_import = "cc_import")
load(":patchelf.bzl", "patchelf")
def _cc_import_runfiles_impl(ctx):
    """Forwards the CcInfo of `src` while attaching runfiles.

    Runfiles are collected from `data`, from each library attribute that is
    set, and from all `deps`.
    """
    runfiles = ctx.runfiles(files = ctx.files.data)
    candidates = [
        ctx.attr.static_library,
        ctx.attr.pic_static_library,
        ctx.attr.shared_library,
        ctx.attr.interface_library,
    ] + list(ctx.attr.deps)
    for target in candidates:
        if not target:
            continue
        target_runfiles = target[DefaultInfo].default_runfiles
        if target_runfiles:
            runfiles = runfiles.merge(target_runfiles)
    return [ctx.attr.src[CcInfo], DefaultInfo(runfiles = runfiles)]
# Wrapper rule that re-exports the CcInfo of `src` (a cc_import target) with
# runfiles merged from the library attributes, `data` and `deps`.
_cc_import_runfiles = rule(
    implementation = _cc_import_runfiles_impl,
    attrs = {
        "src": attr.label(providers = [CcInfo]),
        "static_library": attr.label(allow_single_file = [".a", ".lib"]),
        "pic_static_library": attr.label(allow_single_file = [".pic.a", ".pic.lib"]),
        "shared_library": attr.label(allow_single_file = True),
        "interface_library": attr.label(allow_single_file = [".ifso", ".tbd", ".lib", ".so", ".dylib"]),
        "data": attr.label_list(allow_files = True),
        "deps": attr.label_list(),
    },
)
def cc_import(
        name,
        static_library = None,
        pic_static_library = None,
        shared_library = None,
        interface_library = None,
        data = None,
        deps = None,
        visibility = None,
        soname = None,
        add_needed = None,
        remove_needed = None,
        replace_needed = None,
        **kwargs):
    """Drop-in replacement for `cc_import` with patchelf and runfiles support.

    When any of `soname`/`add_needed`/`remove_needed`/`replace_needed` is set,
    the shared library is first rewritten with patchelf. When `data` is set,
    the target is wrapped so the runfiles actually propagate (native cc_import
    drops them).
    """
    if shared_library and (soname or add_needed or remove_needed or replace_needed):
        # Rewrite the ELF dynamic section first, then import the patched copy.
        patched_name = "{}_patchelf".format(name)
        patchelf(
            name = patched_name,
            shared_library = shared_library,
            soname = soname,
            add_needed = add_needed,
            remove_needed = remove_needed,
            replace_needed = replace_needed,
        )
        shared_library = ":" + patched_name
    if data:
        # Two-step import: the real cc_import under a private name, wrapped by
        # _cc_import_runfiles which carries `name` and merges the runfiles.
        _cc_import(
            name = name + "_no_runfiles",
            static_library = static_library,
            pic_static_library = pic_static_library,
            shared_library = shared_library,
            interface_library = interface_library,
            data = data,
            deps = deps,
            **kwargs
        )
        _cc_import_runfiles(
            name = name,
            src = ":{}_no_runfiles".format(name),
            static_library = static_library,
            pic_static_library = pic_static_library,
            shared_library = shared_library,
            interface_library = interface_library,
            data = data,
            deps = deps,
            visibility = visibility,
        )
    else:
        _cc_import(
            name = name,
            static_library = static_library,
            pic_static_library = pic_static_library,
            shared_library = shared_library,
            interface_library = interface_library,
            data = data,
            deps = deps,
            visibility = visibility,
            **kwargs
        )

View File

@ -0,0 +1,45 @@
load(
"@bazel_tools//tools/build_defs/repo:utils.bzl",
"get_auth",
"patch",
"workspace_and_buildfile",
)
def _http_deb_archive_impl(rctx):
    """Downloads a Debian package and extracts its data tarball.

    The `.deb` is unpacked into `tmp/`, then the contained
    `data.tar.{gz,xz,zst}` is extracted into the repository root (honoring
    `strip_prefix`) and `tmp/` is removed. Fails if no data tarball is found.
    """
    if rctx.attr.build_file and rctx.attr.build_file_content:
        fail("Only one of build_file and build_file_content can be provided.")
    rctx.download_and_extract(
        url = rctx.attr.urls,
        output = "tmp",
        sha256 = rctx.attr.sha256,
        type = "deb",
        stripPrefix = "",
        canonical_id = " ".join(rctx.attr.urls),
        auth = get_auth(rctx, rctx.attr.urls),
    )
    extracted = False
    for ext in ["gz", "xz", "zst"]:
        data = "tmp/data.tar.{}".format(ext)
        if rctx.path(data).exists:
            rctx.extract(
                archive = data,
                output = "",
                stripPrefix = rctx.attr.strip_prefix,
            )
            rctx.delete("tmp")
            extracted = True
            break
    if not extracted:
        # Previously this silently produced an empty repository.
        fail("No data.tar.{gz,xz,zst} found in deb archive from: " + ", ".join(rctx.attr.urls))
    workspace_and_buildfile(rctx)
    patch(rctx)
http_deb_archive = repository_rule(
_http_deb_archive_impl,
attrs = {
"urls": attr.string_list(mandatory = True),
"sha256": attr.string(mandatory = True),
"strip_prefix": attr.string(),
"build_file": attr.label(allow_single_file = True),
"build_file_content": attr.string(),
"workspace_file": attr.label(allow_single_file = True),
"workspace_file_content": attr.string(),
},
)

173
bazel/huggingface.bzl Normal file
View File

@ -0,0 +1,173 @@
load(
"@bazel_tools//tools/build_defs/repo:utils.bzl",
"patch",
"workspace_and_buildfile",
)
# Hugging Face Hub API endpoints: directory listing, small (non-LFS) files,
# and LFS file resolution, respectively.
# NOTE(review): "REMPLATE" is a typo for "TEMPLATE"; renaming it requires
# updating its use in _huggingface_repository_impl in the same change.
TREE_URL_TEMPLATE = "https://huggingface.co/api/models/{model}/tree/{commit}/{path}"
RAW_FILE_URL_REMPLATE = "https://huggingface.co/{model}/raw/{commit}/{path}"
LFS_FILE_URL_TEMPLATE = "https://huggingface.co/{model}/resolve/{commit}/{path}"
def _glob(rctx, str, patterns):
    """Returns True when `str` matches any of the shell glob `patterns`.

    Matching is delegated to bash `[[ str = pattern ]]`, so patterns use shell
    globbing syntax (`*`, `?`, `[...]`), not regexes. Requires bash on PATH.
    """
    cmd = "\n".join([
        """[[ "{str}" = {pattern} ]] && exit 0""".format(str = str, pattern = pattern)
        for pattern in patterns
    ] + ["exit 1"])
    return rctx.execute(["bash", "-c", cmd]).return_code == 0
def _ls(rctx, headers, path):
    """Lists one directory of the model repository via the HF tree API.

    Returns the decoded JSON listing (a list of entries with "type"/"path"
    keys). The listing is downloaded to a temporary `<path>.index.json` file
    because repository_ctx has no direct download-to-memory API.
    """
    url = TREE_URL_TEMPLATE.format(
        model = rctx.attr.model,
        commit = rctx.attr.commit,
        path = path,
    )
    rctx.download(url, path + ".index.json", headers = headers)
    ret = json.decode(rctx.read(path + ".index.json"))
    rctx.delete(path + ".index.json")
    return ret
def _get_token_via_env(rctx):
    """Token from the HUGGINGFACE_TOKEN environment variable, or None."""
    return rctx.getenv("HUGGINGFACE_TOKEN")

def _get_token_via_file(rctx):
    """Token from the huggingface-cli cache file (~/.cache/huggingface/token), or None."""
    p = rctx.path(rctx.getenv("HOME") + "/.cache/huggingface/token")
    if p.exists:
        return rctx.read(p)

def _get_token_via_git_credentials(rctx):
    """Token from `git credential fill` for huggingface.co, or None."""
    input = """\
protocol=https
host=huggingface.co
"""
    res = rctx.execute(["bash", "-c", "echo '{}' | git credential fill".format(input)])
    if res.return_code != 0:
        return None
    for line in res.stdout.split("\n"):
        if line.startswith("password="):
            return line[len("password="):]
    return None

def _get_token(rctx):
    """Resolves a Hugging Face access token, trying env var, then the CLI
    cache file, then git credentials. Returns the stripped token or None."""
    t = _get_token_via_env(rctx) or \
        _get_token_via_file(rctx) or \
        _get_token_via_git_credentials(rctx)
    if t:
        return t.strip()
def _huggingface_repository_impl(rctx):
    """Mirrors a Hugging Face model repository into a Bazel repository.

    Walks the repo tree breadth-first, filters files with the
    `includes`/`excludes` glob lists, and downloads matches in parallel
    (block = False). Authenticated when a token is found (see _get_token).
    """
    headers = {
        "Accept": "application/json",
        "Accept-Encoding": "gzip, deflate",
    }
    token = _get_token(rctx)
    if token:
        headers["Authorization"] = "Bearer " + token
    includes = rctx.attr.includes
    excludes = rctx.attr.excludes
    stack = [""]
    downloads = []

    # Starlark has no `while`: a large bounded `for` emulates one; the loop
    # exits via `break` once the directory stack is drained.
    for _ in range(9999999):
        if (not stack):
            break
        path = stack.pop()
        for entry in _ls(rctx, headers, path):
            if entry["type"] == "directory":
                stack.append(entry["path"])
            elif entry["type"] == "file":
                if (excludes and _glob(rctx, entry["path"], excludes)):
                    continue
                if (not includes or _glob(rctx, entry["path"], includes)):
                    # LFS-tracked files must go through the resolve endpoint;
                    # the raw endpoint only serves the small pointer file.
                    tpl = RAW_FILE_URL_REMPLATE
                    if ("lfs" in entry):
                        tpl = LFS_FILE_URL_TEMPLATE
                    url = tpl.format(
                        model = rctx.attr.model,
                        commit = rctx.attr.commit,
                        path = entry["path"],
                    )
                    downloads.append(rctx.download(
                        url = url,
                        output = entry["path"],
                        canonical_id = entry["oid"],
                        headers = headers,
                        block = False,
                    ))

    # Join all pending parallel downloads before finishing the repository.
    for download in downloads:
        download.wait()
    workspace_and_buildfile(rctx)
    patch(rctx)
huggingface_repository = repository_rule(
implementation = _huggingface_repository_impl,
attrs = {
"model": attr.string(mandatory = True),
"commit": attr.string(mandatory = True),
"includes": attr.string_list(default = []),
"excludes": attr.string_list(default = []),
"patches": attr.label_list(),
"patch_tool": attr.string(default = ""),
"patch_args": attr.string_list(default = ["-p0"]),
"patch_cmds": attr.string_list(default = []),
"patch_cmds_win": attr.string_list(default = []),
"build_file": attr.label(allow_single_file = True),
"build_file_content": attr.string(),
"workspace_file": attr.label(allow_single_file = True),
"workspace_file_content": attr.string(),
},
)
def _huggingface_impl(mctx):
    """Module extension body: instantiates one huggingface_repository per
    `model` tag declared by any module."""
    for mod in mctx.modules:
        for model in mod.tags.model:
            huggingface_repository(
                name = model.name,
                model = model.model,
                commit = model.commit,
                includes = model.includes,
                excludes = model.excludes,
                patches = model.patches,
                patch_tool = model.patch_tool,
                patch_args = model.patch_args,
                patch_cmds = model.patch_cmds,
                patch_cmds_win = model.patch_cmds_win,
                build_file = model.build_file,
                build_file_content = model.build_file_content,
                workspace_file = model.workspace_file,
                workspace_file_content = model.workspace_file_content,
            )

    # Downloads are pinned by commit + oid, so the extension is reproducible
    # and all declared repos are usable by the root module.
    return mctx.extension_metadata(
        reproducible = True,
        root_module_direct_deps = "all",
        root_module_direct_dev_deps = [],
    )
# `huggingface` module extension: `huggingface.model(...)` tags declare model
# snapshots to fetch; attrs mirror huggingface_repository's.
huggingface = module_extension(
    implementation = _huggingface_impl,
    tag_classes = {
        "model": tag_class(
            attrs = {
                "name": attr.string(mandatory = True),
                "model": attr.string(mandatory = True),
                "commit": attr.string(mandatory = True),
                "includes": attr.string_list(default = []),
                "excludes": attr.string_list(default = []),
                "patches": attr.label_list(),
                "patch_tool": attr.string(default = ""),
                "patch_args": attr.string_list(default = ["-p0"]),
                "patch_cmds": attr.string_list(default = []),
                "patch_cmds_win": attr.string_list(default = []),
                "build_file": attr.label(allow_single_file = True),
                "build_file_content": attr.string(),
                "workspace_file": attr.label(allow_single_file = True),
                "workspace_file_content": attr.string(),
            },
        ),
    },
)

58
bazel/patchelf.bzl Normal file
View File

@ -0,0 +1,58 @@
# NOTE(review): identity helper, unused in this file — candidate for removal.
def _render_kv(e):
    return e
def _patchelf_impl(ctx):
    """Copies `shared_library` and rewrites its ELF dynamic section with
    patchelf (soname, DT_NEEDED additions/removals/replacements).

    The output is placed under a directory named after the rule so the file
    can keep (or take) the requested soname as its basename.
    """
    output_name = ctx.file.shared_library.basename
    if ctx.attr.soname:
        output_name = ctx.attr.soname
    output = ctx.actions.declare_file("{}/{}".format(ctx.attr.name, output_name))

    # Shell arguments: $1 = patchelf binary, $2 = input library, $3 = output.
    # The input is copied first because patchelf edits in place.
    commands = [
        "set -e",
        'cp -f "$2" "$3"',
        'chmod +w "$3"',
    ]
    if ctx.attr.soname:
        commands.append('"$1" --set-soname "{}" "$3"'.format(ctx.attr.soname))
    if ctx.attr.remove_needed:
        for v in ctx.attr.remove_needed:
            commands.append('"$1" --remove-needed "{}" "$3"'.format(v))
    if ctx.attr.add_needed:
        for v in ctx.attr.add_needed:
            commands.append('"$1" --add-needed "{}" "$3"'.format(v))
    if ctx.attr.replace_needed:
        for k, v in ctx.attr.replace_needed.items():
            commands.append('"$1" --replace-needed "{}" "{}" "$3"'.format(k, v))
    ctx.actions.run_shell(
        inputs = [ctx.file.shared_library],
        outputs = [output],
        arguments = [ctx.executable._patchelf.path, ctx.file.shared_library.path, output.path],
        command = "\n".join(commands),
        tools = [ctx.executable._patchelf],
    )
    return [
        DefaultInfo(
            files = depset([output]),
        ),
    ]
# Rule wrapper for the patchelf tool: produces a copy of `shared_library`
# with its soname and DT_NEEDED entries rewritten.
patchelf = rule(
    implementation = _patchelf_impl,
    attrs = {
        "shared_library": attr.label(allow_single_file = True, mandatory = True),
        "soname": attr.string(),
        "add_needed": attr.string_list(),
        "remove_needed": attr.string_list(),
        "replace_needed": attr.string_dict(),
        "_patchelf": attr.label(
            default = "@patchelf",
            allow_single_file = True,
            executable = True,
            cfg = "exec",
        ),
    },
)

39
bazel/zig.bzl Normal file
View File

@ -0,0 +1,39 @@
load("@rules_zig//zig:defs.bzl", "BINARY_KIND", "zig_binary")
def zig_cc_binary(name, args = None, env = None, data = [], deps = [], visibility = None, **kwargs):
    """Builds a Zig static library and links it into a cc_binary.

    This routes final linking through the C/C++ toolchain so Zig code can link
    against cc_library deps. `kwargs` go to the zig_binary; `args`/`env`/`data`
    apply to the resulting cc_binary.
    """
    zig_binary(
        name = "{}_lib".format(name),
        kind = BINARY_KIND.static_lib,
        deps = deps + [
            "@rules_zig//zig/lib:libc",
        ],
        **kwargs
    )
    native.cc_binary(
        name = name,
        args = args,
        env = env,
        data = data,
        deps = [":{}_lib".format(name)],
        visibility = visibility,
    )
def zig_cc_test(name, env = None, data = [], deps = [], test_runner = None, visibility = None, **kwargs):
    """Builds a Zig test library and links it into a cc_test.

    Mirror of zig_cc_binary for tests; `linkstatic = True` keeps the test
    self-contained.
    """
    zig_binary(
        name = "{}_test_lib".format(name),
        kind = BINARY_KIND.test_lib,
        test_runner = test_runner,
        data = data,
        deps = deps + [
            "@rules_zig//zig/lib:libc",
        ],
        **kwargs
    )
    native.cc_test(
        name = name,
        env = env,
        data = data,
        deps = [":{}_test_lib".format(name)],
        visibility = visibility,
        linkstatic = True,
    )

1090
bazel/zig_index.json Normal file

File diff suppressed because it is too large Load Diff

109
bazel/zig_proto_library.bzl Normal file
View File

@ -0,0 +1,109 @@
"""Starlark implementation of zig_proto_library"""
load("@rules_proto//proto:defs.bzl", "proto_common")
load(
"@rules_zig//zig/private/providers:zig_module_info.bzl",
"ZigModuleInfo",
"zig_module_info",
)
def _zig_proto_library_impl(ctx):
    """Re-exports the aspect-generated Zig module under the chosen import name."""
    if len(ctx.attr.deps) != 1:
        fail("zig_proto_library '{}' requires exactly one 'deps'".format(ctx.label.name))
    (dep,) = ctx.attr.deps

    # The aspect already generated the Zig module for us; rebuild it so its
    # import name matches what the user chose (default: the rule name).
    # The former `if import_name:` guard was dead code — `ctx.label.name` is
    # never empty, so `import_name` is always truthy.
    module = dep[ZigModuleInfo]
    import_name = ctx.attr.import_name or ctx.label.name
    keys = ["canonical_name", "cdeps", "copts", "deps", "extra_srcs", "linkopts", "main", "srcs"]
    args = {k: getattr(module, k) for k in keys}
    args["name"] = import_name
    return [zig_module_info(**args)]
def zig_proto_library_aspect_impl(target, ctx):
    """
    For each `.proto` in the given target dependencies,
    generate a `.pb.zig` file, and a `zig_module` to import it.
    """
    toolchain = ctx.attr._zig_proto_toolchain[proto_common.ProtoLangToolchainInfo]
    proto_info = target[ProtoInfo]

    # Tuple destructuring enforces exactly one direct source: .proto files
    # are compiled one by one.
    (proto_src,) = proto_info.direct_sources
    pb_zig_name = proto_src.basename[:-len(".proto")] + ".pb.zig"
    zig_src = ctx.actions.declare_file(pb_zig_name, sibling = proto_src)

    # (A stray no-op `ctx.actions.run` expression statement was removed here:
    # it accessed the method without calling it and registered no action.)
    proto_common.compile(
        ctx.actions,
        proto_info = proto_info,
        proto_lang_toolchain_info = toolchain,
        generated_files = [zig_src],
    )

    # Modules generated by this same aspect on the proto dependencies.
    zig_proto_modules = [p[ZigModuleInfo] for p in ctx.rule.attr.deps]
    import_name = get_import_name(target, proto_src)
    module = zig_module_info(
        name = import_name,
        canonical_name = str(target.label),
        main = zig_src,
        srcs = [],
        extra_srcs = [],
        copts = [],
        linkopts = [],
        deps = [toolchain.runtime[ZigModuleInfo]] + zig_proto_modules,
        cdeps = [],
    )
    return [module]
def get_import_name(target, proto_src):
    """
    When the Zig protoc plugin is generating .pb.zig files,
    it generates import names based on the path received from protoc.
    We need to create Zig modules with the same name.
    """
    name = str(target.label)

    # Special handling of protobuf's builtin well-known types.
    if "com_google_protobuf//:" in name:
        name = "google_protobuf_" + proto_src.basename
    else:
        # Keep only the part after the repository prefix. (A duplicated
        # `rsplit` line was removed; after the first split no "//" remains,
        # so the second was a no-op.)
        name = name.rsplit("//")[-1]
    return name.replace(".", "_").replace(":", "_").replace("/", "_")
# Aspect that walks proto_library deps and generates one Zig module per
# .proto file, using the zig-protobuf protoc toolchain.
zig_proto_library_aspect = aspect(
    attrs = {
        "_zig_proto_toolchain": attr.label(
            default = "@zig-protobuf//:zig_toolchain",
            providers = [proto_common.ProtoLangToolchainInfo],
        ),
    },
    implementation = zig_proto_library_aspect_impl,
    provides = [ZigModuleInfo],
    attr_aspects = ["deps"],
)
zig_proto_library = rule(
    doc = """
    Converts a single `proto_library` into a zig module.
    """,
    implementation = _zig_proto_library_impl,
    attrs = {
        "deps": attr.label_list(
            aspects = [zig_proto_library_aspect],
            providers = [ProtoInfo],
        ),
        "import_name": attr.string(
            doc = "The import name of the Zig module.",
            default = "",
        ),
    },
    provides = [ZigModuleInfo],
)

0
build.zig Normal file
View File

39
mlir/BUILD.bazel Normal file
View File

@ -0,0 +1,39 @@
# MLIR bindings for Zig: C++ helpers (mlirx), aggregated C-API headers (c),
# and the Zig wrapper library (mlir).
load("@rules_zig//zig:defs.bzl", "zig_library")
load("//bazel:zig.bzl", "zig_cc_test")

# C++ helpers complementing the MLIR C API.
cc_library(
    name = "mlirx",
    srcs = ["mlirx.cc"],
    hdrs = ["mlirx.h"],
    includes = ["."],
    deps = [
        "@llvm-project//mlir:CAPIIR",
    ],
)

# Aggregated MLIR C-API headers (see c.h) plus the C-API implementations.
cc_library(
    name = "c",
    hdrs = ["c.h"],
    visibility = ["//mlir:__subpackages__"],
    deps = [
        "@llvm-project//mlir:CAPIArith",
        "@llvm-project//mlir:CAPIIR",
        "@llvm-project//mlir:CAPIMath",
        "@llvm-project//mlir:CAPITransforms",
    ],
)

# Public Zig wrapper over the MLIR C API.
zig_library(
    name = "mlir",
    main = "mlir.zig",
    visibility = ["//visibility:public"],
    deps = [
        ":c",
        ":mlirx",
    ],
)

zig_cc_test(
    name = "test",
    deps = [":mlir"],
)

8
mlir/c.h Normal file
View File

@ -0,0 +1,8 @@
// Aggregate header pulling in every MLIR C-API header needed by the Zig
// bindings (exposed as the ":c" target in mlir/BUILD.bazel).
// Added `#pragma once`: the header previously had no include guard.
#pragma once

#include <mlir-c/BuiltinAttributes.h>
#include <mlir-c/BuiltinTypes.h>
#include <mlir-c/Dialect/Arith.h>
#include <mlir-c/Dialect/Func.h>
#include <mlir-c/Dialect/Math.h>
#include <mlir-c/IR.h>
#include <mlir-c/Pass.h>
#include <mlir-c/Transforms.h>

41
mlir/dialects/BUILD.bazel Normal file
View File

@ -0,0 +1,41 @@
# Zig bindings for the MLIR dialects used by ZML (arith, func, math, tensor)
# plus a separate StableHLO dialect library.
load("@rules_zig//zig:defs.bzl", "zig_library")
load("//bazel:zig.bzl", "zig_cc_test")

zig_library(
    name = "dialects",
    srcs = [
        "arith.zig",
        "func.zig",
        "math.zig",
        "tensor.zig",
    ],
    import_name = "mlir/dialects",
    main = "dialects.zig",
    visibility = ["//visibility:public"],
    deps = [
        ":stablehlo",
        "//mlir",
    ],
)

zig_cc_test(
    name = "test",
    deps = [":dialects"],
)

# StableHLO lives in its own library because it pulls in the StableHLO C API.
zig_library(
    name = "stablehlo",
    import_name = "mlir/dialects/stablehlo",
    main = "stablehlo.zig",
    visibility = ["//mlir/dialects:__subpackages__"],
    deps = [
        "//mlir",
        "//mlir:c",
        "@stablehlo//:stablehlo_capi",
    ],
)

zig_cc_test(
    name = "stablehlo_test",
    deps = [":stablehlo"],
)

109
mlir/dialects/arith.zig Normal file
View File

@ -0,0 +1,109 @@
const std = @import("std");
const mlir = @import("mlir");
/// Builds an `arith.constant` operation holding `value`; the result type is
/// inferred from the attribute.
pub fn constant(ctx: mlir.Context, value: mlir.Attribute, location: mlir.Location) mlir.Operation {
    return mlir.Operation.make(ctx, "arith.constant", .{
        .attributes = &.{
            .{ "value", value },
        },
        .result_type_inference = true,
        .location = location,
    });
}
/// Builds a `(ctx, lhs, rhs, location) -> Operation` constructor for the
/// binary arith op named `op_name`; result types are inferred by MLIR.
fn binary_fn(comptime op_name: [:0]const u8) fn (mlir.Context, mlir.Value, mlir.Value, mlir.Location) mlir.Operation {
    const Impl = struct {
        pub fn call(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, location: mlir.Location) mlir.Operation {
            return mlir.Operation.make(ctx, op_name, .{
                .operands = &.{ lhs, rhs },
                .result_type_inference = true,
                .location = location,
            });
        }
    };
    return Impl.call;
}
/// Builds a `(ctx, value, new_type, location) -> Operation` constructor for
/// the cast op named `op_name`; the explicit result type is `new_type`.
fn cast_fn(comptime op_name: [:0]const u8) fn (mlir.Context, mlir.Value, mlir.Type, mlir.Location) mlir.Operation {
    const Impl = struct {
        pub fn call(ctx: mlir.Context, value: mlir.Value, new_type: mlir.Type, location: mlir.Location) mlir.Operation {
            return mlir.Operation.make(ctx, op_name, .{
                .operands = &.{value},
                .results = &.{new_type},
                .location = location,
            });
        }
    };
    return Impl.call;
}
// Binary arithmetic ops (integer and float variants).
pub const addi = binary_fn("arith.addi");
pub const addf = binary_fn("arith.addf");
pub const subi = binary_fn("arith.subi");
pub const subf = binary_fn("arith.subf");
pub const muli = binary_fn("arith.muli");
pub const mulf = binary_fn("arith.mulf");
pub const divsi = binary_fn("arith.divsi");
pub const divui = binary_fn("arith.divui");
pub const divf = binary_fn("arith.divf");

// Width/representation casts: extensions, truncations, int<->float.
pub const extsi = cast_fn("arith.extsi");
pub const extui = cast_fn("arith.extui");
pub const extf = cast_fn("arith.extf");
pub const trunci = cast_fn("arith.trunci");
pub const truncf = cast_fn("arith.truncf");
pub const fptosi = cast_fn("arith.fptosi");
pub const fptoui = cast_fn("arith.fptoui");
pub const sitofp = cast_fn("arith.sitofp");
pub const uitofp = cast_fn("arith.uitofp");
/// Integer comparison predicates for `arith.cmpi`.
/// NOTE(review): declaration order must match MLIR's arith CmpIPredicate
/// integer encoding, since `cmpi` passes `@intFromEnum(predicate)` —
/// confirm against the pinned MLIR version.
pub const CmpIPredicate = enum {
    eq,
    ne,
    slt,
    sle,
    sgt,
    sge,
    ult,
    ule,
    ugt,
    uge,
};
/// Builds an `arith.cmpi` comparing `lhs` and `rhs` with `predicate`.
/// The predicate is encoded as an i64 attribute via `@intFromEnum`.
pub fn cmpi(ctx: mlir.Context, predicate: CmpIPredicate, lhs: mlir.Value, rhs: mlir.Value, location: mlir.Location) mlir.Operation {
    return mlir.Operation.make(ctx, "arith.cmpi", .{
        .operands = &.{ lhs, rhs },
        .result_type_inference = true,
        .attributes = &.{
            .{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute).? },
        },
        .location = location,
    });
}
/// Float comparison predicates for `arith.cmpf` (`o` = ordered, `u` =
/// unordered with respect to NaN).
/// NOTE(review): declaration order must match MLIR's arith CmpFPredicate
/// integer encoding, since `cmpf` passes `@intFromEnum(predicate)` —
/// confirm against the pinned MLIR version.
pub const CmpFPredicate = enum {
    false,
    oeq,
    ogt,
    oge,
    olt,
    ole,
    one,
    ord,
    ueq,
    ugt,
    uge,
    ult,
    ule,
    une,
    uno,
    true,
};
/// Builds an `arith.cmpf` comparing `lhs` and `rhs` with `predicate`.
/// The predicate is encoded as an i64 attribute via `@intFromEnum`.
pub fn cmpf(ctx: mlir.Context, predicate: CmpFPredicate, lhs: mlir.Value, rhs: mlir.Value, location: mlir.Location) mlir.Operation {
    return mlir.Operation.make(ctx, "arith.cmpf", .{
        .operands = &.{ lhs, rhs },
        .result_type_inference = true,
        .attributes = &.{
            .{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute).? },
        },
        .location = location,
    });
}

View File

@ -0,0 +1,11 @@
const std = @import("std");
pub const arith = @import("arith.zig");
pub const func = @import("func.zig");
pub const math = @import("math.zig");
pub const tensor = @import("tensor.zig");
pub const stablehlo = @import("mlir/dialects/stablehlo");
test {
    // Reference every re-exported dialect module so their nested tests run.
    std.testing.refAllDecls(@This());
}

45
mlir/dialects/func.zig Normal file
View File

@ -0,0 +1,45 @@
const std = @import("std");
const mlir = @import("mlir");
/// Builds a `func.func` operation named `sym_name` with the given argument and
/// result types, containing `block` as its body.
pub fn func(
    ctx: mlir.Context,
    args: struct {
        sym_name: [:0]const u8,
        args: []const mlir.Type,
        arg_attrs: []const mlir.Attribute = &.{},
        results: []const mlir.Type,
        block: mlir.Block,
        location: mlir.Location,
    },
) mlir.Operation {
    const NamedAttr = struct { [:0]const u8, mlir.Attribute };

    // The function's type attribute is derived from its argument/result types.
    const fn_type = mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable;
    const fn_type_attr = mlir.TypeAttribute.init(fn_type.as(mlir.Type).?).as(mlir.Attribute).?;

    // At most three attributes: sym_name, function_type, and (optionally) arg_attrs.
    var named_attrs = std.BoundedArray(NamedAttr, 3){};
    named_attrs.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute).? });
    named_attrs.appendAssumeCapacity(.{ "function_type", fn_type_attr });
    if (args.arg_attrs.len > 0) {
        named_attrs.appendAssumeCapacity(.{ "arg_attrs", mlir.ArrayAttribute.init(ctx, args.arg_attrs).as(mlir.Attribute).? });
    }

    return mlir.Operation.make(ctx, "func.func", .{
        .blocks = &.{args.block},
        .attributes = named_attrs.constSlice(),
        .location = args.location,
    });
}
/// Builds a `func.call` operation invoking the symbol `name` with `values`.
/// Result types must be given explicitly (no inference for calls).
pub fn call(ctx: mlir.Context, name: [:0]const u8, values: []const mlir.Value, results: []const mlir.Type, loc: mlir.Location) mlir.Operation {
    const callee = mlir.FlatSymbolRefAttribute.init(ctx, name).as(mlir.Attribute).?;
    return mlir.Operation.make(ctx, "func.call", .{
        .location = loc,
        .attributes = &.{.{ "callee", callee }},
        .verify = true,
        .results = results,
        .variadic_operands = &.{values},
    });
}
/// Builds a `func.return` terminator yielding `values`.
/// Verification is disabled here (`.verify = false`), mirroring how the rest
/// of the builders in this file construct terminators.
pub fn return_(ctx: mlir.Context, values: []const mlir.Value, loc: mlir.Location) mlir.Operation {
    return mlir.Operation.make(ctx, "func.return", .{
        .location = loc,
        .verify = false,
        .operands = values,
    });
}

35
mlir/dialects/math.zig Normal file
View File

@ -0,0 +1,35 @@
const std = @import("std");
const mlir = @import("mlir");
const namespace = "math";
/// Returns a wrapper struct whose `call` builds the unary `math.<op_name>` op.
fn unary_fn(comptime op_name: [:0]const u8) type {
    const full_name = namespace ++ "." ++ op_name;
    return struct {
        pub fn call(ctx: mlir.Context, value: mlir.Value, location: mlir.Location) mlir.Operation {
            // NOTE(review): `.results = &.{}` — presumably Operation.make infers
            // the result type when none is supplied; confirm against mlir.zig.
            return mlir.Operation.make(ctx, full_name, .{
                .location = location,
                .results = &.{},
                .operands = &.{value},
            });
        }
    };
}
/// Returns a wrapper struct whose `call` builds the binary `math.<op_name>` op.
fn binary_fn(comptime op_name: [:0]const u8) type {
    const full_name = namespace ++ "." ++ op_name;
    return struct {
        pub fn call(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, location: mlir.Location) mlir.Operation {
            // NOTE(review): `.results = &.{}` — presumably Operation.make infers
            // the result type when none is supplied; confirm against mlir.zig.
            return mlir.Operation.make(ctx, full_name, .{
                .location = location,
                .results = &.{},
                .operands = &.{ lhs, rhs },
            });
        }
    };
}
// Public op builders: `math.<name>(ctx, operands..., location)`.
pub const ipowi = binary_fn("ipowi").call; // integer base, integer exponent
pub const fpowi = binary_fn("fpowi").call; // float base, integer exponent
pub const tanh = unary_fn("tanh").call;
pub const sqrt = unary_fn("sqrt").call;
pub const exp = unary_fn("exp").call;
pub const log = unary_fn("log").call;

1330
mlir/dialects/stablehlo.zig Normal file

File diff suppressed because it is too large Load Diff

36
mlir/dialects/tensor.zig Normal file
View File

@ -0,0 +1,36 @@
const std = @import("std");
const mlir = @import("mlir");
/// Builds a `tensor.empty` op: an uninitialized tensor of type `result`.
pub fn empty(ctx: mlir.Context, opts: struct {
    result: mlir.Type,
    location: mlir.Location,
}) mlir.Operation {
    return mlir.Operation.make(ctx, "tensor.empty", .{
        .location = opts.location,
        .results = &.{opts.result},
    });
}
/// Builds a `tensor.splat` op: broadcast the scalar `value` to every element
/// of a tensor of type `result`.
pub fn splat(ctx: mlir.Context, opts: struct {
    value: mlir.Value,
    result: mlir.Type,
    location: mlir.Location,
}) mlir.Operation {
    return mlir.Operation.make(ctx, "tensor.splat", .{
        .location = opts.location,
        .results = &.{opts.result},
        .operands = &.{opts.value},
    });
}
/// Builds a `tensor.cast` op converting `source` to the tensor type `dest`.
pub fn cast(ctx: mlir.Context, opts: struct {
    source: mlir.Value,
    dest: mlir.Type,
    location: mlir.Location,
}) mlir.Operation {
    return mlir.Operation.make(ctx, "tensor.cast", .{
        .location = opts.location,
        .results = &.{opts.dest},
        .operands = &.{opts.source},
    });
}

1933
mlir/mlir.zig Normal file

File diff suppressed because it is too large Load Diff

27
mlir/mlirx.cc Normal file
View File

@ -0,0 +1,27 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlirx.h"
namespace mlirx {

// Converts a DenseI64ArrayAttr or DenseBoolArrayAttr into an equivalent
// DenseIntElementsAttr (a rank-1 ranked tensor with the same element type and
// contents). Any other attribute kind is returned unchanged.
// NOTE(review): this uses the member-function `attr.dyn_cast<T>()`, which was
// removed from mlir::Attribute in newer LLVM/MLIR releases in favor of the
// free function `mlir::dyn_cast<T>(attr)` — confirm against the pinned MLIR.
static mlir::Attribute ArrayToElements(mlir::Attribute attr) {
    if (auto array = attr.dyn_cast<mlir::DenseI64ArrayAttr>()) {
        return mlir::DenseIntElementsAttr::get(
            mlir::RankedTensorType::get(array.size(), array.getElementType()),
            array.asArrayRef());
    }
    if (auto array = attr.dyn_cast<mlir::DenseBoolArrayAttr>()) {
        return mlir::DenseIntElementsAttr::get(
            mlir::RankedTensorType::get(array.size(), array.getElementType()),
            array.asArrayRef());
    }
    return attr;
}

}
// C-ABI entry point declared in mlirx.h: unwraps the C handle, performs the
// dense-array -> elements conversion, and wraps the result back.
MlirAttribute mlirDenseArrayToElements(MlirAttribute attr) {
    return wrap(mlirx::ArrayToElements(unwrap(attr)));
}

16
mlir/mlirx.h Normal file
View File

@ -0,0 +1,16 @@
// Extra C-API helpers layered on top of the stock MLIR C bindings.
#ifndef MLIRX_CC_H
#define MLIRX_CC_H

#include "mlir-c/IR.h"

#ifdef __cplusplus
extern "C" {
#endif

// Converts a dense array attribute (i64 or bool) into a DenseIntElementsAttr;
// any other attribute kind is returned unchanged. See mlirx.cc.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseArrayToElements(MlirAttribute attr);

#ifdef __cplusplus
}
#endif

#endif // MLIRX_CC_H

30
pjrt/BUILD.bazel Normal file
View File

@ -0,0 +1,30 @@
load("@rules_zig//zig:defs.bzl", "zig_library")
load("//bazel:zig_proto_library.bzl", "zig_proto_library")
# Exposes the system <dlfcn.h> (via the pjrt/dlfcn.h shim) so the Linux
# dlopen code path in pjrt.zig can depend on it.
cc_library(
    name = "dlfcn",
    hdrs = ["dlfcn.h"],
)
# Zig bindings for the PJRT C API. `main` is the module root; profiler.zig is
# compiled as an additional source. The dlfcn dep is Linux-only (dlopen path).
zig_library(
    name = "pjrt",
    srcs = ["profiler.zig"],
    main = "pjrt.zig",
    visibility = ["//visibility:public"],
    deps = [
        ":profiler_options_proto",
        "//runtimes",
        "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
        "@xla//xla/pjrt/c:pjrt_c_api_hdrs",
        "@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",
    ] + select({
        "@platforms//os:linux": [":dlfcn"],
        "//conditions:default": [],
    }),
)
# Generated Zig code for TSL's ProfileOptions proto, imported from profiler.zig
# as "//tsl:profiler_options_proto".
zig_proto_library(
    name = "profiler_options_proto",
    import_name = "//tsl:profiler_options_proto",
    deps = ["@tsl//tsl/profiler/protobuf:profiler_options_proto"],
)

1
pjrt/dlfcn.h Normal file
View File

@ -0,0 +1 @@
#include <dlfcn.h>

886
pjrt/pjrt.zig Normal file
View File

@ -0,0 +1,886 @@
const builtin = @import("builtin");
const std = @import("std");
const c = @import("c");
const log = std.log.scoped(.pjrt);
pub const Profiler = @import("profiler.zig").Profiler;
test {
    // Reference all decls (including the Profiler re-export) so nested tests run.
    std.testing.refAllDecls(@This());
}
// We could calculate the struct sizes like PJRT does, but it turns out that
// some of those sizes were wrong in PJRT itself [1], and the wrong values got
// propagated into binary plugins. In order to mirror that, we just use the
// value as computed by PJRT itself, obtained through comptime reflection. One
// could argue for removing this one day since [1] has been fixed; the problem
// is that the same issue could happen again, as the way PJRT computes the
// sizes is not very robust.
//
// 1. https://github.com/openxla/xla/issues/10032

/// Looks up the `<name>_STRUCT_SIZE` constant the PJRT C headers define for
/// the args struct type `T`.
fn pjrtStructSize(comptime T: type) usize {
    // unsafe on purpose, we want this to fail loudly (at comptime) if the C
    // type-name mangling ever changes
    const typedef_name = comptime blk: {
        const needle = ".struct_";
        const idx = std.mem.indexOf(u8, @typeName(T), needle).?;
        break :blk @typeName(T)[idx + needle.len ..];
    };
    return @field(c, typedef_name ++ "_STRUCT_SIZE");
}
/// Returns `v` with its `struct_size` field stamped with the size PJRT
/// expects; every PJRT args struct must go through this before the C call.
inline fn pjrtStruct(v: anytype) @TypeOf(v) {
    var ret = v;
    ret.struct_size = pjrtStructSize(@TypeOf(v));
    return ret;
}
/// Zig error set mirroring `PJRT_Error_Code` one-to-one
/// (see `ErrorCode.toApiError` for the mapping).
pub const ApiError = error{
    Cancelled,
    Unknown,
    InvalidArgument,
    DeadlineExceeded,
    NotFound,
    AlreadyExists,
    PermissionDenied,
    ResourceExhausted,
    FailedPrecondition,
    Aborted,
    OutOfRange,
    Unimplemented,
    Internal,
    Unavailable,
    DataLoss,
    Unauthenticated,
};
/// Mixin giving opaque wrapper types an `inner()` accessor that reinterprets
/// `self` as a pointer to the underlying C struct (dropping const).
fn InnerMixin(comptime innerT: type) type {
    return struct {
        fn inner(self: anytype) *innerT {
            return @ptrCast(@constCast(@alignCast(self)));
        }
    };
}
/// Wrapper over a dynamically loaded `PJRT_Api` function-pointer table.
pub const Api = struct {
    pub const Version = struct {
        major: i64,
        minor: i64,
    };

    // One enum tag per function-pointer field in the C API table.
    const Funcs = std.meta.FieldEnum(c.PJRT_Api);

    inner: c.PJRT_Api,

    /// Loads a PJRT plugin shared library, resolves `GetPjrtApi`, and runs
    /// `PJRT_Plugin_Initialize`. The library handle is intentionally never
    /// closed: pointers into the plugin outlive this function.
    pub fn loadFrom(library: []const u8) !*const Api {
        var lib: std.DynLib = switch (builtin.os.tag) {
            .linux => blk: {
                // NOTE(review): dlopen is called directly on Linux, presumably
                // to pass RTLD_NODELETE so the plugin stays mapped for the
                // process lifetime — confirm.
                const library_c = try std.posix.toPosixPath(library);
                break :blk .{
                    .inner = .{
                        .handle = c.dlopen(&library_c, c.RTLD_LAZY | c.RTLD_LOCAL | c.RTLD_NODELETE) orelse {
                            return error.FileNotFound;
                        },
                    },
                };
            },
            else => try std.DynLib.open(library),
        };
        const DynGetPjrtApi = lib.lookup(*const fn () callconv(.C) *const Api, "GetPjrtApi") orelse {
            std.debug.panic("Unable to find GetPjrtApi symbol in library: {s}", .{library});
        };
        const api = DynGetPjrtApi();
        log.info("Loaded library: {s}", .{library});
        _ = api.call(.PJRT_Plugin_Initialize, .{}) catch unreachable;
        return api;
    }

    /// Comptime: the argument-struct type of the C function behind `func`,
    /// i.e. the pointee of the single pointer parameter it takes.
    fn CallFnArgType(comptime func: Funcs) type {
        const fti = @typeInfo(std.meta.FieldType(c.PJRT_Api, func));
        const fn_ptr = @typeInfo(fti.Optional.child);
        const fn_type_info = @typeInfo(fn_ptr.Pointer.child);
        const arg_array_type_info = @typeInfo(fn_type_info.Fn.params[0].type.?);
        return arg_array_type_info.Pointer.child;
    }

    /// Stamps `arg` with its struct_size, invokes the C entry point, and
    /// converts a returned `PJRT_Error*` into a Zig error (logging its
    /// message). Returns the (possibly out-parameter-filled) args struct.
    inline fn call(self: *const Api, comptime method: Funcs, arg: CallFnArgType(method)) ApiError!@TypeOf(arg) {
        var ret = pjrtStruct(arg);
        const fn_ptr = @field(&self.inner, @tagName(method)).?;
        const result = fn_ptr(&ret);
        // Some entry points return void instead of PJRT_Error*.
        if (@TypeOf(result) == void) {
            return ret;
        }
        if (result) |pjrt_c_error| {
            const pjrt_error: *Error = @ptrCast(pjrt_c_error);
            log.err("[{s}] {s}", .{ @tagName(method), pjrt_error.getMessage(self) });
            return pjrt_error.getCode(self).toApiError();
        }
        return ret;
    }

    /// Walks the extension linked list looking for `ext_id`; returns the
    /// extension struct reinterpreted as `ExtensionT`, or null if absent.
    pub fn lookupExtension(self: *const Api, comptime ExtensionT: type, ext_id: c_int) ?*const ExtensionT {
        var cur: [*c]const c.PJRT_Extension_Base = @alignCast(@ptrCast(self.inner.extension_start));
        while (cur != null) : (cur = cur.*.next) {
            if (cur.*.type == ext_id) {
                return @alignCast(@ptrCast(cur));
            }
        }
        return null;
    }

    /// The PJRT C API version reported by the plugin.
    pub inline fn version(self: *const Api) Version {
        return .{
            .major = @intCast(self.inner.pjrt_api_version.major_version),
            .minor = @intCast(self.inner.pjrt_api_version.minor_version),
        };
    }
};
/// Typed view of `PJRT_Error_Code`.
pub const ErrorCode = enum(c.PJRT_Error_Code) {
    cancelled = c.PJRT_Error_Code_CANCELLED,
    unknown = c.PJRT_Error_Code_UNKNOWN,
    invalid_argument = c.PJRT_Error_Code_INVALID_ARGUMENT,
    deadline_exceeded = c.PJRT_Error_Code_DEADLINE_EXCEEDED,
    not_found = c.PJRT_Error_Code_NOT_FOUND,
    already_exists = c.PJRT_Error_Code_ALREADY_EXISTS,
    permission_denied = c.PJRT_Error_Code_PERMISSION_DENIED,
    resource_exhausted = c.PJRT_Error_Code_RESOURCE_EXHAUSTED,
    failed_precondition = c.PJRT_Error_Code_FAILED_PRECONDITION,
    aborted = c.PJRT_Error_Code_ABORTED,
    out_of_range = c.PJRT_Error_Code_OUT_OF_RANGE,
    unimplemented = c.PJRT_Error_Code_UNIMPLEMENTED,
    internal = c.PJRT_Error_Code_INTERNAL,
    unavailable = c.PJRT_Error_Code_UNAVAILABLE,
    data_loss = c.PJRT_Error_Code_DATA_LOSS,
    unauthenticated = c.PJRT_Error_Code_UNAUTHENTICATED,

    /// One-to-one mapping into the `ApiError` error set.
    pub fn toApiError(code: ErrorCode) ApiError {
        return switch (code) {
            .cancelled => error.Cancelled,
            .unknown => error.Unknown,
            .invalid_argument => error.InvalidArgument,
            .deadline_exceeded => error.DeadlineExceeded,
            .not_found => error.NotFound,
            .already_exists => error.AlreadyExists,
            .permission_denied => error.PermissionDenied,
            .resource_exhausted => error.ResourceExhausted,
            .failed_precondition => error.FailedPrecondition,
            .aborted => error.Aborted,
            .out_of_range => error.OutOfRange,
            .unimplemented => error.Unimplemented,
            .internal => error.Internal,
            .unavailable => error.Unavailable,
            .data_loss => error.DataLoss,
            .unauthenticated => error.Unauthenticated,
        };
    }
};
/// Handle to a `PJRT_Error` returned by the plugin.
pub const Error = opaque {
    /// Releases the error object back to the plugin.
    pub fn deinit(self: *Error, api: *const Api) void {
        _ = api.call(.PJRT_Error_Destroy, .{
            .@"error" = @ptrCast(self),
        }) catch unreachable;
    }

    /// The PJRT error code carried by this error.
    pub fn getCode(self: *Error, api: *const Api) ErrorCode {
        const ret = api.call(.PJRT_Error_GetCode, .{
            .@"error" = @ptrCast(self),
        }) catch unreachable;
        return @enumFromInt(ret.code);
    }

    /// The human-readable message; the slice is owned by the error object.
    pub fn getMessage(self: *Error, api: *const Api) []const u8 {
        const ret = api.call(.PJRT_Error_Message, .{
            .@"error" = @ptrCast(self),
        }) catch unreachable;
        return ret.message[0..ret.message_size];
    }
};
pub const ClientInitError = error{LoadingFailed} || ApiError;
/// Handle to a `PJRT_Client`.
pub const Client = opaque {
    const inner = InnerMixin(c.PJRT_Client).inner;

    pub const ProgramFormat = enum {
        hlo,
        mlir,
    };

    /// Creates a client with the given plugin-specific options.
    /// No key-value store callbacks are wired up (single-process use).
    pub fn init(api: *const Api, create_options: []const NamedValue) ClientInitError!*Client {
        // log.info("Loaded PJRT runtime plugin: {s}", .{api.Platform});
        const ret = try api.call(.PJRT_Client_Create, .{
            .create_options = @ptrCast(create_options.ptr),
            .num_options = create_options.len,
            .kv_get_callback = null,
            .kv_put_callback = null,
            .kv_put_user_arg = null,
            .kv_get_user_arg = null,
        });
        return @ptrCast(ret.client.?);
    }

    /// Destroys the client; errors are ignored (deallocation must succeed).
    pub fn deinit(self: *Client, api: *const Api) void {
        _ = api.call(.PJRT_Client_Destroy, .{
            .client = self.inner(),
        }) catch {};
    }

    /// Platform name (e.g. reported by the plugin); slice owned by the client.
    pub fn getPlatformName(self: *const Client, api: *const Api) []const u8 {
        const ret = api.call(.PJRT_Client_PlatformName, .{
            .client = self.inner(),
        }) catch unreachable;
        return ret.platform_name[0..ret.platform_name_size];
    }

    /// All devices visible to this client (addressable or not).
    pub fn getDevices(self: *const Client, api: *const Api) []const *Device {
        const ret = api.call(.PJRT_Client_Devices, .{
            .client = self.inner(),
        }) catch unreachable;
        return @ptrCast(ret.devices[0..ret.num_devices]);
    }

    /// Devices this process can launch work on.
    pub fn getAddressableDevices(self: *const Client, api: *const Api) []const *Device {
        const ret = api.call(.PJRT_Client_AddressableDevices, .{
            .client = self.inner(),
        }) catch unreachable;
        return @ptrCast(ret.addressable_devices[0..ret.num_addressable_devices]);
    }

    pub const CompileArgs = struct {
        bytecode: []const u8,
        bytecode_format: ProgramFormat,
        // Serialized xla CompileOptionsProto.
        compile_options_pb: []const u8,
    };

    /// Compiles an HLO/MLIR program into a loaded executable.
    pub fn compile(self: *const Client, api: *const Api, args: CompileArgs) ApiError!*LoadedExecutable {
        const bytecode_format_ = @tagName(args.bytecode_format);
        const ret = try api.call(.PJRT_Client_Compile, .{
            .program = &pjrtStruct(c.PJRT_Program{
                .code = @ptrCast(@constCast(args.bytecode.ptr)),
                .code_size = args.bytecode.len,
                .format = @ptrCast(@constCast(bytecode_format_.ptr)),
                .format_size = bytecode_format_.len,
            }),
            .compile_options = @ptrCast(@constCast(args.compile_options_pb.ptr)),
            .compile_options_size = args.compile_options_pb.len,
            .client = self.inner(),
        });
        return @ptrCast(ret.executable.?);
    }

    pub const BufferFromHostBufferArgs = struct {
        data: []const u8,
        buffer_type: BufferType,
        dims: []const i64,
        byte_strides: ?[]const i64,
        device: *const Device,
        host_buffer_semantics: HostBufferSemantics,
    };

    /// Copies/transfers a host buffer to a device. Returns the device buffer
    /// and an event that fires when `args.data` may be reused (per the chosen
    /// host-buffer semantics).
    pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, *Event } {
        const ret = try api.call(.PJRT_Client_BufferFromHostBuffer, .{
            .client = self.inner(),
            .data = @ptrCast(@constCast(args.data.ptr)),
            .type = @intFromEnum(args.buffer_type),
            .dims = @ptrCast(@constCast(args.dims.ptr)),
            .num_dims = args.dims.len,
            .byte_strides = if (args.byte_strides) |bs| @ptrCast(@constCast(bs.ptr)) else null,
            .num_byte_strides = if (args.byte_strides) |bs| bs.len else 0,
            .host_buffer_semantics = @intFromEnum(args.host_buffer_semantics),
            .device = @ptrCast(@constCast(args.device)),
            .memory = null, // TODO
            .device_layout = null, // TODO
            .done_with_host_buffer = null,
            .buffer = null,
        });
        return .{
            @ptrCast(ret.buffer.?),
            @ptrCast(ret.done_with_host_buffer.?),
        };
    }

    /// Returns the Profiler for this API.
    /// Not all platform have a profiling api, for those the profiler object will do nothing.
    /// Platforms with known profiler extensions: cuda, xpu
    pub fn getProfiler(self: *const Client, api: *const Api, options: Profiler.Options) Profiler {
        if (api.version().minor >= 45) {
            if (api.lookupExtension(c.PJRT_Profiler_Extension, c.PJRT_Extension_Type_Profiler)) |ext| {
                return Profiler.init(ext.profiler_api.*, options);
            }
        }
        // NOTE(review): `self` is a pointer to an opaque type; formatting it
        // with `{}` likely does not compile/print anything useful — consider
        // `self.getPlatformName(api)` instead. Confirm.
        log.warn("No profiler found for platform: {}", .{self});
        return Profiler.init(null, options);
    }

    // pub fn getGpuCustomCallRegistry(self: *const Client, api: *const Api) ?*GpuCustomCallRegistry {
    //     if (api.lookupExtension(c.PJRT_Gpu_Custom_Call, c.PJRT_Extension_Type_Gpu_Custom_Call)) |ext| {
    //         return .{ .custom_call_register = ext.custom_call.? };
    //     }
    //     log.warn("No Gpu Custom Call registry found for platform: {}", .{self});
    //     return null;
    // }

    /// Loads a previously serialized executable (see `Executable.serialize`).
    pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {
        const ret = try api.call(.PJRT_Executable_DeserializeAndLoad, .{
            .client = self.inner(),
            .serialized_executable = bytes.ptr,
            .serialized_executable_size = bytes.len,
        });
        return @ptrCast(ret.loaded_executable.?);
    }

    pub const CreateViewOfDeviceBufferArgs = struct {
        data: []const u8,
        dims: []const i64,
        element_type: BufferType,
        layout: MemoryLayout,
        device: *const Device,
        on_delete_callback: ?*const fn (device_buffer_ptr: ?*anyopaque, ctx: ?*anyopaque) callconv(.C) void = null,
        on_delete_callback_arg: ?*anyopaque = null,
        stream: ?isize = null,
    };

    /// Wraps pre-existing device memory in a Buffer without copying.
    /// `on_delete_callback` (if any) runs when the buffer is destroyed.
    pub fn createViewOfDeviceBuffer(self: *const Client, api: *const Api, args: CreateViewOfDeviceBufferArgs) ApiError!*Buffer {
        const layout = args.layout.toCStruct();
        const ret = try api.call(.PJRT_Client_CreateViewOfDeviceBuffer, .{
            .client = self.inner(),
            .device_buffer_ptr = @ptrCast(@constCast(args.data.ptr)),
            .dims = args.dims.ptr,
            .num_dims = args.dims.len,
            .element_type = @intFromEnum(args.element_type),
            .layout = @ptrCast(@constCast(&layout)),
            .device = @ptrCast(@constCast(args.device)),
            .on_delete_callback = args.on_delete_callback,
            .on_delete_callback_arg = args.on_delete_callback_arg,
            .stream = if (args.stream) |stream| stream else 0,
        });
        return @ptrCast(ret.buffer.?);
    }
};
// // pub const CustomCallSignature = *const fn (*anyopaque, **anyopaque, [*c]const u8, usize) callconv(.C) void;
// // pub const GpuCustomCallRegistry = struct {
// // custom_call_register: *const c.PJRT_Gpu_Register_Custom_Call,
// // pub fn registerCustomCall(self: GpuCustomCallRegistry, api: *const Api, api_version: usize, name: []const u8, func: CustomCallSignature) ApiError!void {
// // var ret = pjrtStruct(c.PJRT_Gpu_Register_Custom_Call_Args{
// // .function_name = name.ptr,
// // .function_name_size = name.len,
// // .api_version = @intCast(api_version),
// // .custom_call_function = @ptrCast(@constCast(func)),
// // });
// // const result = self.custom_call_register(&ret);
// // if (result) |pjrt_c_error| {
// // const pjrt_error = .{ .inner = pjrt_c_error };
// // log.err("{s}", .{pjrt_error.getMessage(api)});
// // return pjrt_error.getCode().toApiError();
// // }
// // }
// // };
// // const OldPjrtExtension = extern struct {
// // type: c.PJRT_Extension_Type,
// // next: [*]OldPjrtExtension,
// // };
/// Handle to a `PJRT_Device`.
pub const Device = opaque {
    const inner = InnerMixin(c.PJRT_Device).inner;

    /// Static description (id, kind, debug strings) of this device.
    pub fn getDescription(self: *const Device, api: *const Api) *const DeviceDescription {
        const ret = api.call(.PJRT_Device_GetDescription, .{
            .device = self.inner(),
        }) catch unreachable;
        return @ptrCast(ret.device_description.?);
    }

    /// Whether this process can launch work on the device.
    pub fn isAddressable(self: *const Device, api: *const Api) bool {
        const ret = api.call(.PJRT_Device_IsAddressable, .{
            .device = self.inner(),
        }) catch unreachable;
        return ret.is_addressable;
    }

    /// Opaque, platform-specific local hardware id.
    pub fn getLocalHardwareId(self: *const Device, api: *const Api) usize {
        const ret = api.call(.PJRT_Device_LocalHardwareId, .{
            .device = self.inner(),
        }) catch unreachable;
        return @intCast(ret.local_hardware_id);
    }
};
/// Handle to a `PJRT_DeviceDescription`. All returned slices are owned by the
/// underlying description object.
pub const DeviceDescription = opaque {
    const inner = InnerMixin(c.PJRT_DeviceDescription).inner;

    /// Globally unique device id.
    pub fn getId(self: *const DeviceDescription, api: *const Api) usize {
        const ret = api.call(.PJRT_DeviceDescription_Id, .{
            .device_description = self.inner(),
        }) catch unreachable;
        return @intCast(ret.id);
    }

    /// Index of the process this device belongs to.
    pub fn getProcessIndex(self: *const DeviceDescription, api: *const Api) usize {
        const ret = api.call(.PJRT_DeviceDescription_ProcessIndex, .{
            .device_description = self.inner(),
        }) catch unreachable;
        return @intCast(ret.process_index);
    }

    /// Device kind string (e.g. a hardware model name).
    pub fn getKind(self: *const DeviceDescription, api: *const Api) []const u8 {
        const ret = api.call(.PJRT_DeviceDescription_Kind, .{
            .device_description = self.inner(),
        }) catch unreachable;
        return ret.device_kind[0..ret.device_kind_size];
    }

    /// Verbose debug string.
    pub fn debugString(self: *const DeviceDescription, api: *const Api) []const u8 {
        const ret = api.call(.PJRT_DeviceDescription_DebugString, .{
            .device_description = self.inner(),
        }) catch unreachable;
        return ret.debug_string[0..ret.debug_string_size];
    }

    /// Short display string.
    pub fn toString(self: *const DeviceDescription, api: *const Api) []const u8 {
        const ret = api.call(.PJRT_DeviceDescription_ToString, .{
            .device_description = self.inner(),
        }) catch unreachable;
        return ret.to_string[0..ret.to_string_size];
    }
};
/// Errors from `Executable.getCostAnalysis`.
pub const GetCostAnalysisError = std.mem.Allocator.Error || ApiError;

/// Owned serialized-executable bytes; `bytes` points into memory owned by
/// `handle` and is released via `deleter`.
pub const SerializeResult = struct {
    bytes: []const u8,
    handle: *anyopaque,
    deleter: *const fn (?*anyopaque) callconv(.C) void,

    /// Releases the backing storage; `bytes` is invalid afterwards.
    pub fn deinit(self: *SerializeResult) void {
        self.deleter(self.handle);
        self.bytes = &.{};
        self.* = undefined;
    }
};
/// Handle to a `PJRT_Executable` (unloaded; see `LoadedExecutable`).
pub const Executable = opaque {
    const inner = InnerMixin(c.PJRT_Executable).inner;

    pub fn deinit(self: *Executable, api: *const Api) void {
        _ = api.call(.PJRT_Executable_Destroy, .{
            .executable = self.inner(),
        }) catch unreachable;
    }

    /// Named cost properties (flops, bytes accessed, ...) computed by the
    /// plugin; the returned slice is owned by the executable.
    pub fn getCostAnalysis(self: *const Executable, api: *const Api) GetCostAnalysisError![]*const NamedValue {
        const ret = try api.call(.PJRT_Executable_GetCostAnalysis, .{
            .executable = self.inner(),
        });
        const values: [*]*const NamedValue = @ptrCast(ret.properties);
        return values[0..ret.num_properties];
    }

    /// Serializes the executable; caller must `deinit` the result.
    pub fn serialize(self: *const Executable, api: *const Api) ApiError!SerializeResult {
        const ret = try api.call(.PJRT_Executable_Serialize, .{
            .executable = self.inner(),
        });
        return .{
            .bytes = ret.serialized_bytes[0..ret.serialized_bytes_size],
            .handle = ret.serialized_executable.?,
            .deleter = @ptrCast(ret.serialized_executable_deleter.?),
        };
    }
};
/// Handle to a `PJRT_LoadedExecutable` (compiled and loaded onto devices).
pub const LoadedExecutable = opaque {
    const inner = InnerMixin(c.PJRT_LoadedExecutable).inner;

    /// Destroys the loaded executable; errors are ignored (deallocation must
    /// succeed). Fix: dropped the former `self.* = undefined;` —
    /// LoadedExecutable is an opaque type with unknown size, so storing
    /// through the pointer is a compile error once this function is analyzed.
    pub fn deinit(self: *LoadedExecutable, api: *const Api) void {
        _ = api.call(.PJRT_LoadedExecutable_Destroy, .{
            .executable = self.inner(),
        }) catch {};
    }

    /// Releases the executable's device memory without destroying the handle.
    pub fn delete(self: *LoadedExecutable, api: *const Api) void {
        _ = api.call(.PJRT_LoadedExecutable_Delete, .{
            .executable = self.inner(),
        }) catch unreachable;
    }

    pub fn isDeleted(self: *const LoadedExecutable, api: *const Api) bool {
        const ret = api.call(.PJRT_LoadedExecutable_IsDeleted, .{
            .executable = self.inner(),
        }) catch unreachable;
        return ret.is_deleted;
    }

    /// Devices this executable runs on.
    /// Fix: returns device handles (`[]const *Device`, matching
    /// `Client.getDevices`) sliced by `num_addressable_devices`; the previous
    /// `[]Device` — a slice of an opaque type — combined with an unsliced
    /// `@ptrCast` of the raw C pointer could not compile when referenced.
    pub fn getAddressableDevices(self: *const LoadedExecutable, api: *const Api) []const *Device {
        const ret = api.call(.PJRT_LoadedExecutable_AddressableDevices, .{
            .executable = self.inner(),
        }) catch unreachable;
        return @ptrCast(ret.addressable_devices[0..ret.num_addressable_devices]);
    }

    /// Launches the executable: one argument list / result list per device.
    /// `events[i]` is filled with the per-device completion event.
    pub fn execute(self: *const LoadedExecutable, api: *const Api, args: struct {
        num_args: usize,
        arguments: []const [*]const *const Buffer,
        results: []const [*]*Buffer,
        events: []*Event,
        non_donatable_input_indices: []const i64 = &.{},
    }) ApiError!void {
        var options = pjrtStruct(c.PJRT_ExecuteOptions{
            .send_callbacks = null,
            .recv_callbacks = null,
            .num_send_ops = 0,
            .num_recv_ops = 0,
            .launch_id = 0,
            .non_donatable_input_indices = @ptrCast(args.non_donatable_input_indices.ptr),
            .num_non_donatable_input_indices = args.non_donatable_input_indices.len,
        });
        _ = try api.call(.PJRT_LoadedExecutable_Execute, .{
            .executable = self.inner(),
            .options = @ptrCast(&options),
            .argument_lists = @ptrCast(args.arguments.ptr),
            .num_devices = @intCast(args.arguments.len),
            .num_args = args.num_args,
            .output_lists = @ptrCast(args.results.ptr),
            .device_complete_events = @ptrCast(args.events.ptr),
            .execute_device = null,
        });
    }

    /// The underlying (unloaded) executable, e.g. for serialization.
    pub fn getExecutable(self: *LoadedExecutable, api: *const Api) ApiError!*Executable {
        const ret = try api.call(.PJRT_LoadedExecutable_GetExecutable, .{
            .loaded_executable = self.inner(),
        });
        return @ptrCast(ret.executable.?);
    }
};
/// Element type of a device buffer, mirroring `PJRT_Buffer_Type`.
pub const BufferType = enum(c.PJRT_Buffer_Type) {
    INVALID = c.PJRT_Buffer_Type_INVALID,
    PRED = c.PJRT_Buffer_Type_PRED, // boolean predicate
    S8 = c.PJRT_Buffer_Type_S8,
    S16 = c.PJRT_Buffer_Type_S16,
    S32 = c.PJRT_Buffer_Type_S32,
    S64 = c.PJRT_Buffer_Type_S64,
    U8 = c.PJRT_Buffer_Type_U8,
    U16 = c.PJRT_Buffer_Type_U16,
    U32 = c.PJRT_Buffer_Type_U32,
    U64 = c.PJRT_Buffer_Type_U64,
    F16 = c.PJRT_Buffer_Type_F16,
    F32 = c.PJRT_Buffer_Type_F32,
    F64 = c.PJRT_Buffer_Type_F64,
    BF16 = c.PJRT_Buffer_Type_BF16,
    C64 = c.PJRT_Buffer_Type_C64, // complex64
    C128 = c.PJRT_Buffer_Type_C128, // complex128
    F8E5M2 = c.PJRT_Buffer_Type_F8E5M2,
    F8E4M3FN = c.PJRT_Buffer_Type_F8E4M3FN,
    F8E4M3B11FNUZ = c.PJRT_Buffer_Type_F8E4M3B11FNUZ,
    F8E5M2FNUZ = c.PJRT_Buffer_Type_F8E5M2FNUZ,
    F8E4M3FNUZ = c.PJRT_Buffer_Type_F8E4M3FNUZ,
    S4 = c.PJRT_Buffer_Type_S4,
    U4 = c.PJRT_Buffer_Type_U4,
};
/// Discriminant for `MemoryLayout`, mirroring `PJRT_Buffer_MemoryLayout_Type`.
pub const MemoryLayoutType = enum(c.PJRT_Buffer_MemoryLayout_Type) {
    Tiled = c.PJRT_Buffer_MemoryLayout_Type_Tiled,
    Strides = c.PJRT_Buffer_MemoryLayout_Type_Strides,
};
/// Device-buffer memory layout: either dimension-ordered tiling or explicit
/// byte strides.
pub const MemoryLayout = union(MemoryLayoutType) {
    pub const Type = MemoryLayoutType;

    pub const Tiled = struct {
        minor_to_major: []const i64,
        tile_dims: []const i64,
        tile_dims_sizes: []const usize,
    };

    pub const Strides = struct {
        byte_strides: []const i64,
    };

    Tiled: Tiled,
    Strides: Strides,

    /// Converts to the C representation. The returned struct borrows the
    /// slices in `self` (raw pointers), so `self`'s storage must outlive any
    /// use of the C struct.
    fn toCStruct(self: MemoryLayout) c.PJRT_Buffer_MemoryLayout {
        return pjrtStruct(switch (self) {
            .Tiled => |v| c.PJRT_Buffer_MemoryLayout{
                .type = c.PJRT_Buffer_MemoryLayout_Type_Tiled,
                .unnamed_0 = .{
                    .tiled = c.PJRT_Buffer_MemoryLayout_Tiled{
                        .minor_to_major = v.minor_to_major.ptr,
                        .minor_to_major_size = v.minor_to_major.len,
                        .tile_dims = v.tile_dims.ptr,
                        .tile_dim_sizes = v.tile_dims_sizes.ptr,
                        .num_tiles = v.tile_dims_sizes.len,
                    },
                },
            },
            .Strides => |v| c.PJRT_Buffer_MemoryLayout{
                .type = c.PJRT_Buffer_MemoryLayout_Type_Strides,
                .unnamed_0 = .{
                    .strides = c.PJRT_Buffer_MemoryLayout_Strides{
                        .byte_strides = v.byte_strides.ptr,
                        .num_byte_strides = v.byte_strides.len,
                    },
                },
            },
        });
    }
};
/// How long the host buffer passed to `bufferFromHostBuffer` must stay valid,
/// mirroring `PJRT_HostBufferSemantics`.
pub const HostBufferSemantics = enum(c.PJRT_HostBufferSemantics) {
    ImmutableOnlyDuringCall = c.PJRT_HostBufferSemantics_kImmutableOnlyDuringCall,
    ImmutableUntilTransferCompletes = c.PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes,
    ImmutableZeroCopy = c.PJRT_HostBufferSemantics_kImmutableZeroCopy,
    MutableZeroCopy = c.PJRT_HostBufferSemantics_kMutableZeroCopy,
};
/// Handle to a `PJRT_Buffer` (device memory).
pub const Buffer = opaque {
    const inner = InnerMixin(c.PJRT_Buffer).inner;

    pub fn deinit(self: *Buffer, api: *const Api) void {
        _ = api.call(.PJRT_Buffer_Destroy, .{
            .buffer = self.inner(),
        }) catch unreachable;
    }

    /// The device this buffer lives on.
    pub fn getDevice(self: *const Buffer, api: *const Api) ApiError!*Device {
        const ret = try api.call(.PJRT_Buffer_Device, .{
            .buffer = self.inner(),
        });
        return @ptrCast(ret.device.?);
    }

    /// Releases the device memory without destroying the handle.
    pub fn delete(self: *Buffer, api: *const Api) void {
        _ = api.call(.PJRT_Buffer_Delete, .{
            .buffer = self.inner(),
        }) catch unreachable;
    }

    pub fn isDeleted(self: *const Buffer, api: *const Api) bool {
        const ret = api.call(.PJRT_Buffer_IsDeleted, .{
            .buffer = self.inner(),
        }) catch unreachable;
        return ret.is_deleted;
    }

    pub fn isOnCpu(self: *const Buffer, api: *const Api) bool {
        const ret = api.call(.PJRT_Buffer_IsOnCpu, .{
            .buffer = self.inner(),
        }) catch unreachable;
        return ret.is_on_cpu;
    }

    /// Starts an async device-to-host copy into `dst`; the returned event
    /// fires when the copy completes.
    pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!*Event {
        const ret = try api.call(.PJRT_Buffer_ToHostBuffer, .{
            .src = self.inner(),
            .dst = @ptrCast(dst.ptr),
            .dst_size = dst.len,
        });
        return @ptrCast(ret.event.?);
    }

    pub fn getElementType(self: *const Buffer, api: *const Api) BufferType {
        const ret = api.call(.PJRT_Buffer_ElementType, .{
            .buffer = self.inner(),
        }) catch unreachable;
        return @enumFromInt(ret.type);
    }

    /// Logical dimensions; the slice is owned by the buffer.
    pub fn getDimensions(self: *const Buffer, api: *const Api) []const i64 {
        const ret = api.call(.PJRT_Buffer_Dimensions, .{
            .buffer = self.inner(),
        }) catch unreachable;
        return ret.dims[0..ret.num_dims];
    }

    pub fn getUnpaddedDimensions(self: *const Buffer, api: *const Api) ApiError![]const i64 {
        const ret = try api.call(.PJRT_Buffer_UnpaddedDimensions, .{
            .buffer = self.inner(),
        });
        return ret.dims[0..ret.num_dims];
    }

    pub fn getOnDeviceSizeInBytes(self: *const Buffer, api: *const Api) ApiError!usize {
        const ret = try api.call(.PJRT_Buffer_OnDeviceSizeInBytes, .{
            .buffer = self.inner(),
        });
        return ret.on_device_size_in_bytes;
    }

    /// Copies this buffer to another device, returning the new buffer.
    /// Fix: takes `*const Device` and returns `*Buffer` — opaque types cannot
    /// be passed or returned by value — and actually calls `device.inner()`
    /// (it was previously referenced without parentheses, which is a compile
    /// error once this function is analyzed).
    pub fn copyToDevice(self: *const Buffer, api: *const Api, device: *const Device) ApiError!*Buffer {
        const ret = try api.call(.PJRT_Buffer_CopyToDevice, .{
            .buffer = self.inner(),
            .dst_device = device.inner(),
        });
        return @ptrCast(ret.dst_buffer.?);
    }

    /// Event that fires once the buffer's contents are ready on device.
    pub fn getReadyEvent(self: *const Buffer, api: *const Api) *Event {
        const ret = api.call(.PJRT_Buffer_ReadyEvent, .{
            .buffer = self.inner(),
        }) catch unreachable;
        return @ptrCast(ret.event.?);
    }

    /// Raw pointer to the underlying device memory.
    pub fn getOpaqueDeviceMemoryDataPointer(self: *const Buffer, api: *const Api) ApiError!*anyopaque {
        const ret = try api.call(.PJRT_Buffer_OpaqueDeviceMemoryDataPointer, .{
            .buffer = self.inner(),
        });
        return ret.device_memory_ptr.?;
    }
};
/// Handle to a `PJRT_Event` (async completion notification).
pub const Event = opaque {
    const inner = InnerMixin(c.PJRT_Event).inner;

    pub fn deinit(self: *Event, api: *const Api) void {
        _ = api.call(.PJRT_Event_Destroy, .{
            .event = self.inner(),
        }) catch unreachable;
    }

    pub fn isReady(self: *const Event, api: *const Api) bool {
        const ret = api.call(.PJRT_Event_IsReady, .{
            .event = self.inner(),
        }) catch unreachable;
        return ret.is_ready;
    }

    /// Error status of a completed event.
    /// NOTE(review): `@ptrCast(ret)` casts the whole args struct *by value*,
    /// which does not look like it can compile when this is first referenced;
    /// `Api.call` also already converts the C error into an ApiError. The
    /// intended semantics need clarifying before this can be fixed.
    pub fn getEventError(self: *const Event, api: *const Api) ApiError!?*Error {
        const ret = try api.call(.PJRT_Event_Error, .{
            .event = self.inner(),
        });
        return @ptrCast(ret);
    }

    /// Blocks until the event fires.
    pub fn await_(self: *const Event, api: *const Api) ApiError!void {
        _ = try api.call(.PJRT_Event_Await, .{
            .event = self.inner(),
        });
    }

    /// Registers `func(err, user_arg)` to run when the event fires.
    pub fn onReady(self: *Event, api: *const Api, func: *const fn (err: ?*Error, user_arg: ?*anyopaque) callconv(.C) void, user_arg: ?*anyopaque) ApiError!void {
        _ = try api.call(.PJRT_Event_OnReady, .{
            .event = self.inner(),
            .callback = @ptrCast(func),
            .user_arg = user_arg,
        });
    }
};
/// ABI-compatible wrapper around `c.PJRT_NamedValue`: a tagged name/value pair
/// used for plugin creation options and cost-analysis properties.
pub const NamedValue = extern struct {
    comptime {
        // Must remain layout-identical to the C struct: slices of NamedValue
        // are passed straight through to the C API.
        std.debug.assert(@sizeOf(NamedValue) == @sizeOf(c.PJRT_NamedValue));
    }

    inner: c.PJRT_NamedValue,

    pub const Kind = enum(c.PJRT_NamedValue_Type) {
        string = c.PJRT_NamedValue_kString,
        int64 = c.PJRT_NamedValue_kInt64,
        int64list = c.PJRT_NamedValue_kInt64List,
        float = c.PJRT_NamedValue_kFloat,
        bool = c.PJRT_NamedValue_kBool,
    };

    pub fn kind(self: NamedValue) Kind {
        return @enumFromInt(self.inner.type);
    }

    pub fn name(self: NamedValue) []const u8 {
        return self.inner.name[0..self.inner.name_size];
    }

    /// Builds a NamedValue by dispatching on `value`'s type at comptime.
    /// Fixes over the previous version:
    /// - string literals (`*const [N:0]u8`) and comptime int/float literals
    ///   are now accepted; they previously fell into a runtime `unreachable`;
    /// - a genuinely unsupported type is now a compile error instead of a
    ///   runtime trap.
    pub fn from(name_: []const u8, value: anytype) NamedValue {
        const T = @TypeOf(value);
        return switch (T) {
            []u8, []const u8 => fromString(name_, value),
            i64, comptime_int => fromInt64(name_, value),
            []i64, []const i64 => fromInt64List(name_, value),
            f32, comptime_float => fromFloat(name_, value),
            bool => fromBool(name_, value),
            else => if (comptime isStringLike(T))
                fromString(name_, value)
            else
                @compileError("NamedValue.from: unsupported type " ++ @typeName(T)),
        };
    }

    /// True for pointer types that coerce to `[]const u8`: u8 slices and
    /// single pointers to u8 arrays (the type of string literals).
    fn isStringLike(comptime T: type) bool {
        const info = @typeInfo(T);
        if (info != .Pointer) return false;
        const ptr = info.Pointer;
        if (ptr.size == .Slice) return ptr.child == u8;
        if (ptr.size == .One) {
            const child = @typeInfo(ptr.child);
            return child == .Array and child.Array.child == u8;
        }
        return false;
    }

    pub fn fromString(name_: []const u8, value: []const u8) NamedValue {
        return .{ .inner = pjrtStruct(c.PJRT_NamedValue{
            .name = @ptrCast(@constCast(name_.ptr)),
            .name_size = name_.len,
            .type = c.PJRT_NamedValue_kString,
            .unnamed_0 = .{ .string_value = @ptrCast(@constCast(value.ptr)) },
            .value_size = value.len,
        }) };
    }

    pub fn fromInt64(name_: []const u8, value: i64) NamedValue {
        return .{ .inner = pjrtStruct(c.PJRT_NamedValue{
            .name = @ptrCast(@constCast(name_.ptr)),
            .name_size = name_.len,
            .type = c.PJRT_NamedValue_kInt64,
            .unnamed_0 = .{ .int64_value = value },
            .value_size = 1,
        }) };
    }

    pub fn fromInt64List(name_: []const u8, value: []const i64) NamedValue {
        return .{ .inner = pjrtStruct(c.PJRT_NamedValue{
            .name = @ptrCast(@constCast(name_.ptr)),
            .name_size = name_.len,
            .type = c.PJRT_NamedValue_kInt64List,
            .unnamed_0 = .{ .int64_array_value = @ptrCast(@constCast(value.ptr)) },
            .value_size = value.len,
        }) };
    }

    pub fn fromFloat(name_: []const u8, value: f32) NamedValue {
        return .{ .inner = pjrtStruct(c.PJRT_NamedValue{
            .name = @ptrCast(@constCast(name_.ptr)),
            .name_size = name_.len,
            .type = c.PJRT_NamedValue_kFloat,
            .unnamed_0 = .{ .float_value = value },
            .value_size = 1,
        }) };
    }

    pub fn fromBool(name_: []const u8, value: bool) NamedValue {
        return .{ .inner = pjrtStruct(c.PJRT_NamedValue{
            .name = @ptrCast(@constCast(name_.ptr)),
            .name_size = name_.len,
            .type = c.PJRT_NamedValue_kBool,
            .unnamed_0 = .{ .bool_value = value },
            .value_size = 1,
        }) };
    }

    /// std.fmt integration: prints the name and the active payload.
    pub fn format(
        self: NamedValue,
        comptime fmt: []const u8,
        options: std.fmt.FormatOptions,
        writer: anytype,
    ) !void {
        _ = fmt;
        _ = options;
        try writer.print("{s}{{ .name = {s},", .{ @typeName(NamedValue), self.inner.name[0..self.inner.name_size] });
        const u = self.inner.unnamed_0;
        switch (self.kind()) {
            .string => try writer.print(" .string = {s} ", .{u.string_value[0..self.inner.value_size]}),
            .int64 => try writer.print(" .int64 = {d} ", .{u.int64_value}),
            .int64list => try writer.print(" .int64list = {d} ", .{u.int64_array_value[0..self.inner.value_size]}),
            .float => try writer.print(" .float = {d} ", .{u.float_value}),
            .bool => try writer.print(" .bool = {} ", .{u.bool_value}),
        }
        try writer.writeAll("}");
    }
};

205
pjrt/profiler.zig Normal file
View File

@ -0,0 +1,205 @@
const std = @import("std");
const c = @import("c");
const tsl_proto = @import("//tsl:profiler_options_proto");
const log = std.log.scoped(.zml_profiler);
/// Pjrt Profiler extension
/// Wrapper over the PJRT PLUGIN_Profiler extension.
/// When constructed with a null `api`, every method is a no-op, so call
/// sites need no branching on profiler availability.
pub const Profiler = struct {
    api: ?c.PLUGIN_Profiler_Api,
    inner: *c.PLUGIN_Profiler,
    // Last error pointer returned by the plugin; recorded by `check`.
    last_error: ?*Error = null,
    status: Status = .ready,

    /// Lifecycle states; transitions are enforced by `transition` and go
    /// ready -> started -> stopped -> done.
    pub const Status = enum { ready, started, stopped, done };
    pub const Error = c.PLUGIN_Profiler_Error;
    pub const Options = tsl_proto.ProfileOptions;

    /// Creates a plugin profiler. A null `api` yields an inert instance.
    /// `options` is serialized into a fixed stack buffer; the size is
    /// assumed sufficient (`catch unreachable`), as is profiler creation.
    pub fn init(api: ?c.PLUGIN_Profiler_Api, options: Options) Profiler {
        if (api == null) {
            return .{ .api = null, .inner = undefined };
        }
        // Stack buffer sized to hold the encoded options; avoids heap allocation.
        var buffer: [std.fs.max_path_bytes + @sizeOf(Options) * 4]u8 = undefined;
        var fba = std.heap.FixedBufferAllocator.init(&buffer);
        const byte_options = options.encode(fba.allocator()) catch unreachable;
        var res: Profiler = .{ .api = api, .inner = undefined };
        var args: c.PLUGIN_Profiler_Create_Args = .{
            .options = byte_options.ptr,
            .options_size = byte_options.len,
            .profiler = undefined, // out
        };
        res.check(api.?.create.?(&args)) catch unreachable;
        res.inner = args.profiler.?;
        return res;
    }

    /// Moves `status` from `expected` to `next`, panicking (programmer
    /// error) on any out-of-order call.
    fn transition(self: *Profiler, fn_name: []const u8, expected: Status, next: Status) void {
        if (self.status == expected) {
            self.status = next;
            return;
        }
        std.debug.panic("Profiler can't `{s}()`. Current status: {}, expected: {}", .{ fn_name, self.status, expected });
    }

    /// Starts profiling. Valid only from the `ready` state.
    pub fn start(self: *Profiler) void {
        self.transition("start", .ready, .started);
        if (self.api == null) return;
        var args: c.PLUGIN_Profiler_Start_Args = .{ .profiler = self.inner };
        self.check(self.api.?.start.?(&args)) catch unreachable;
    }

    /// Stops profiling. Valid only from the `started` state.
    pub fn stop(self: *Profiler) void {
        self.transition("stop", .started, .stopped);
        if (self.api == null) return;
        var args: c.PLUGIN_Profiler_Stop_Args = .{ .profiler = self.inner };
        self.check(self.api.?.stop.?(&args)) catch unreachable;
    }

    /// Fetches collected profile data. Valid only from the `stopped` state.
    /// Two-phase protocol: the first collect_data call reports the size
    /// (buffer may be null), the second fills a buffer we allocate.
    /// Caller frees the result via `ProfilingData.free(allocator)`.
    pub fn collectData(self: *Profiler, allocator: std.mem.Allocator) !ProfilingData {
        self.transition("collect_data", .stopped, .done);
        if (self.api == null) return .{ .external = &.{} };
        var args: c.PLUGIN_Profiler_CollectData_Args = .{
            .struct_size = c.PLUGIN_Profiler_CollectData_Args_STRUCT_SIZE,
            .profiler = self.inner,
            .buffer = null,
            .buffer_size_in_bytes = 0,
        };
        try self.check(self.api.?.collect_data.?(&args));
        std.debug.assert(args.buffer_size_in_bytes > 0);
        const buffer: ProfilingData = if (args.buffer == null) blk: {
            std.log.debug("Plugin profiler wants us to allocate {d} bytes for profile data", .{args.buffer_size_in_bytes});
            // The plugin want us to allocate memory for it:
            const buffer = try allocator.alloc(u8, args.buffer_size_in_bytes);
            args.buffer = buffer.ptr;
            try self.check(self.api.?.collect_data.?(&args));
            break :blk .{ .owned = buffer };
        } else blk: {
            std.log.debug("Plugin profiler has {d} bytes of profile data", .{args.buffer_size_in_bytes});
            // Drop sentinel. The profiler plugin returns a null terminated string.
            // But this is creating issues if we save the sentinel on disk,
            // because it will trip up protobuf readers.
            var data = args.buffer[0..args.buffer_size_in_bytes];
            data = if (data.len > 0 and data[data.len - 1] == 0) data[0 .. data.len - 1] else data;
            break :blk .{ .external = data };
        };
        // printDataAsXSpace(allocator, buffer.items());
        return buffer;
    }

    /// Collects profile data and writes it to `dir/file_name`, truncating
    /// any existing file. No file is created when there is no data.
    pub fn dumpDataTo(
        self: *Profiler,
        allocator: std.mem.Allocator,
        dir: std.fs.Dir,
        file_name: []const u8,
    ) !void {
        const profile_data = try self.collectData(allocator);
        defer profile_data.free(allocator);
        if (profile_data.items().len == 0) return;
        const file = try dir.createFile(file_name, .{ .truncate = true });
        defer file.close();
        log.info("Writing profiling data to {s} ({} bytes)", .{ file_name, profile_data.items().len });
        return try file.writeAll(profile_data.items());
    }

    /// Converts a plugin error pointer into a Zig error, stashing the raw
    /// pointer in `last_error` for later inspection.
    fn check(self: *Profiler, c_error: ?*Error) !void {
        if (c_error) |err| {
            self.last_error = err;
            return error.PjrtProfilerError;
        }
    }

    /// Destroys the plugin profiler. Warns on misuse (never stopped, or
    /// data never collected) but proceeds regardless.
    pub fn deinit(self: Profiler) void {
        switch (self.status) {
            .started => log.warn("Profiler was never stopped", .{}),
            .stopped => log.warn("Profiler data was never collected", .{}),
            else => {},
        }
        if (self.api == null) return;
        var args: c.PLUGIN_Profiler_Destroy_Args = .{ .profiler = self.inner };
        _ = self.api.?.destroy.?(&args);
    }
};
// If this was working it would be a good alternative to xspace_to_json.cc
// const xspace = @import("xspace.pb.zig");
// pub fn printDataAsXSpace(allocator: std.mem.Allocator, data: []const u8) void {
// var arena = std.heap.ArenaAllocator.init(allocator);
// defer arena.deinit();
//
// const space = xspace.XSpace.decode(data, arena.allocator()) catch |e| {
// std.log.err("Couldn't load profiling data: {}", .{e});
// return;
// };
//
// for (space.errors.items) |err| {
// std.log.err("{s}", .{err.getSlice()});
// }
// for (space.warnings.items) |warning| {
// std.log.warn("{s}", .{warning.getSlice()});
// }
// for (space.hostnames.items) |host| {
// std.log.info("Profiled host {s}", .{host.getSlice()});
// }
// for (space.planes.items) |plane| {
// var event_metadata = std.hash_map.AutoHashMap(i64, xspace.XEventMetadata).init(arena.allocator());
// event_metadata.ensureTotalCapacity(@intCast(plane.event_metadata.items.len)) catch return;
// defer event_metadata.deinit();
// for (plane.event_metadata.items) |event_meta_entry| {
// if (event_meta_entry.value) |event_meta| {
// event_metadata.putAssumeCapacity(event_meta.id, event_meta);
// }
// }
// std.log.info("Profiled device {s}", .{plane.name.getSlice()});
// for (plane.lines.items) |line| {
// std.log.info(
// "{d} -> {d} xline {s} ({d} events)",
// .{ line.timestamp_ns, line.duration_ps, line.name.getSlice(), line.events.items.len },
// );
// const ps_per_ns: i64 = 1000;
// var duration_ns: i64 = 0;
// var last_metadata_id: i64 = 0;
// for (line.events.items) |event| {
// if (event.metadata_id != last_metadata_id and duration_ns != 0) {
// const duration_us = @as(f32, @floatFromInt(duration_ns)) / std.time.ns_per_us;
// const meta = event_metadata.get(event.metadata_id).?;
// std.log.info("event {s}: {d:.1}μs", .{ meta.name.getSlice(), duration_us });
// last_metadata_id = event.metadata_id;
// duration_ns = 0;
// }
// duration_ns += @divFloor(event.duration_ps, ps_per_ns);
// const duration_us = @as(f32, @floatFromInt(duration_ns)) / std.time.ns_per_us;
// const meta = event_metadata.get(event.metadata_id).?;
// std.log.info("event {s}: {d:.1}μs", .{ meta.name.getSlice(), duration_us });
// }
// }
// }
// }
/// Profiling bytes returned by `Profiler.collectData`.
/// `.owned` buffers were allocated by us and must be released via `free`;
/// `.external` buffers belong to the plugin and are never freed here.
const ProfilingData = union(enum) {
    owned: []const u8,
    external: []const u8,

    /// Returns the underlying bytes regardless of ownership.
    pub fn items(self: ProfilingData) []const u8 {
        return switch (self) {
            .owned => |bytes| bytes,
            .external => |bytes| bytes,
        };
    }

    /// Releases the buffer if, and only if, we own it.
    pub fn free(self: ProfilingData, allocator: std.mem.Allocator) void {
        if (self == .owned) allocator.free(self.owned);
    }
};

21
platform_mappings Normal file
View File

@ -0,0 +1,21 @@
platforms:
@zml//platforms:linux_amd64
--cpu=k8
@zml//platforms:linux_arm64
--cpu=aarch64
@zml//platforms:macos_arm64
--cpu=darwin_arm64
--apple_platform_type=macos
flags:
--cpu=darwin_arm64
--apple_platform_type=macos
@zml//platforms:macos_arm64
--cpu=k8
@zml//platforms:linux_amd64
--cpu=aarch64
@zml//platforms:linux_arm64

26
platforms/BUILD.bazel Normal file
View File

@ -0,0 +1,26 @@
# Target platforms selectable via --platforms (see also the repo's
# platform_mappings file, which bridges these to legacy --cpu flags).

# x86_64 Linux.
platform(
    name = "linux_amd64",
    constraint_values = [
        "@platforms//cpu:x86_64",
        "@platforms//os:linux",
    ],
    visibility = ["//visibility:public"],
)

# aarch64 Linux.
platform(
    name = "linux_arm64",
    constraint_values = [
        "@platforms//cpu:aarch64",
        "@platforms//os:linux",
    ],
    visibility = ["//visibility:public"],
)

# Apple Silicon macOS.
platform(
    name = "macos_arm64",
    constraint_values = [
        "@platforms//cpu:aarch64",
        "@platforms//os:macos",
    ],
    visibility = ["//visibility:public"],
)

42
runtimes/BUILD.bazel Normal file
View File

@ -0,0 +1,42 @@
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")

# PJRT runtime backends and whether each is enabled by default.
# Toggle on the command line with e.g. --@zml//runtimes:cuda=true.
RUNTIMES = {
    "cpu": True,
    "cuda": False,
    "rocm": False,
    "tpu": False,
}

# One user-facing bool_flag per runtime.
[
    bool_flag(
        name = runtime,
        build_setting_default = default,
    )
    for runtime, default in RUNTIMES.items()
]

# Internal config_setting (leading underscore) matching each flag's "True"
# state, for use in selects below.
[
    config_setting(
        name = "_{}".format(runtime),
        flag_values = {":{}".format(runtime): "True"},
    )
    for runtime in RUNTIMES.keys()
]

# Aggregates every enabled runtime's PJRT plugin into a single dep.
cc_library(
    name = "runtimes",
    visibility = ["//visibility:public"],
    deps = select({
        ":_cpu": ["//runtimes/cpu"],
        "//conditions:default": [],
    }) + select({
        ":_cuda": ["//runtimes/cuda"],
        "//conditions:default": [],
    }) + select({
        ":_rocm": ["//runtimes/rocm"],
        "//conditions:default": [],
    }) + select({
        ":_tpu": ["//runtimes/tpu"],
        "//conditions:default": [],
    }),
)

8
runtimes/cpu/BUILD.bazel Normal file
View File

@ -0,0 +1,8 @@
# Picks the prebuilt CPU PJRT plugin for the target OS. No
# //conditions:default branch: analysis fails on any other OS.
# NOTE(review): macOS maps only to the arm64 artifact and Linux only to
# amd64 — linux_arm64 appears unsupported here; confirm intended.
alias(
    name = "cpu",
    actual = select({
        "@platforms//os:macos": "@libpjrt_cpu_darwin_arm64//:libpjrt_cpu",
        "@platforms//os:linux": "@libpjrt_cpu_linux_amd64//:libpjrt_cpu",
    }),
    visibility = ["//visibility:public"],
)

34
runtimes/cpu/cpu.bzl Normal file
View File

@ -0,0 +1,34 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
# BUILD content generated into each fetched plugin repo; `ext` is the
# platform shared-library suffix (so / dylib).
_BUILD = """\
cc_import(
    name = "libpjrt_cpu",
    shared_library = "libpjrt_cpu.{ext}",
    visibility = ["//visibility:public"],
)
"""

# Module extension impl: fetches the prebuilt CPU PJRT plugin archives for
# linux-amd64 and darwin-arm64 (sha256-pinned, hence reproducible).
def _cpu_pjrt_plugin_impl(mctx):
    http_archive(
        name = "libpjrt_cpu_linux_amd64",
        build_file_content = _BUILD.format(ext = "so"),
        sha256 = "14317143acd6a38656e97280e8010c0b8d8c0863dff2ae82834b6f2fe747427b",
        url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.1.13/pjrt-cpu_linux-amd64.tar.gz",
    )

    http_archive(
        name = "libpjrt_cpu_darwin_arm64",
        build_file_content = _BUILD.format(ext = "dylib"),
        sha256 = "3a26e1372f68fc11028c4ec22a0c72693f08e7690ba8c5f28b17f5baa9c9dc77",
        url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.1.13/pjrt-cpu_darwin-arm64.tar.gz",
    )

    return mctx.extension_metadata(
        reproducible = True,
        root_module_direct_deps = "all",
        root_module_direct_dev_deps = [],
    )

cpu_pjrt_plugin = module_extension(
    implementation = _cpu_pjrt_plugin_impl,
)

View File

@ -0,0 +1,5 @@
# Public entry point for the CUDA runtime; forwards to the fetched
# @libpjrt_cuda repository.
alias(
    name = "cuda",
    actual = "@libpjrt_cuda",
    visibility = ["//visibility:public"],
)

197
runtimes/cuda/cuda.bzl Normal file
View File

@ -0,0 +1,197 @@
load("@bazel_skylib//lib:paths.bzl", "paths")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("//bazel:http_deb_archive.bzl", "http_deb_archive")
# Only linux-x86_64 artifacts are fetched; other arches in the redistrib
# JSON are skipped.
ARCH = "linux-x86_64"

# Versions select which redistrib_*.json files are read in _cuda_impl.
CUDA_VERSION = "12.6.1"
CUDNN_VERSION = "9.3.0"

# BUILD content template for a repo exposing a single versioned .so,
# visible only to the @libpjrt_cuda repo.
_CC_IMPORT_TPL = """\
cc_import(
    name = "{name}",
    shared_library = "lib/{shared_library}",
    visibility = ["@libpjrt_cuda//:__subpackages__"],
)
"""
# CUDA redistrib packages to fetch, mapped to the BUILD content generated
# for each resulting repository. Simple single-library packages use
# _CC_IMPORT_TPL; packages with extra files or internal deps get bespoke
# BUILD content.
CUDA_PACKAGES = {
    "cuda_cudart": _CC_IMPORT_TPL.format(name = "cudart", shared_library = "libcudart.so.12"),
    "cuda_cupti": _CC_IMPORT_TPL.format(name = "cupti", shared_library = "libcupti.so.12"),
    "libcufft": _CC_IMPORT_TPL.format(name = "cufft", shared_library = "libcufft.so.11"),
    "libcusolver": _CC_IMPORT_TPL.format(name = "cusolver", shared_library = "libcusolver.so.11"),
    "libcusparse": _CC_IMPORT_TPL.format(name = "cusparse", shared_library = "libcusparse.so.12"),
    "libnvjitlink": _CC_IMPORT_TPL.format(name = "nvjitlink", shared_library = "libnvJitLink.so.12"),
    # nvcc provides ptxas, libdevice bitcode, and libnvvm in addition to
    # shared libraries, hence the custom BUILD.
    "cuda_nvcc": """\
filegroup(
    name = "ptxas",
    srcs = ["bin/ptxas"],
    visibility = ["@libpjrt_cuda//:__subpackages__"],
)

filegroup(
    name = "libdevice",
    srcs = ["nvvm/libdevice/libdevice.10.bc"],
    visibility = ["@libpjrt_cuda//:__subpackages__"],
)

cc_import(
    name = "nvvm",
    shared_library = "nvvm/lib64/libnvvm.so.4",
    visibility = ["@libpjrt_cuda//:__subpackages__"],
)
""",
    "cuda_nvrtc": """\
cc_import(
    name = "nvrtc",
    shared_library = "lib/libnvrtc.so.12",
    visibility = ["@libpjrt_cuda//:__subpackages__"],
    deps = [":nvrtc_builtins"],
)

cc_import(
    name = "nvrtc_builtins",
    shared_library = "lib/libnvrtc-builtins.so.12.6",
)
""",
    "libcublas": """\
cc_import(
    name = "cublasLt",
    shared_library = "lib/libcublasLt.so.12",
)

cc_import(
    name = "cublas",
    shared_library = "lib/libcublas.so.12",
    visibility = ["@libpjrt_cuda//:__subpackages__"],
    deps = [":cublasLt"],
)
""",
}
# cuDNN redistrib packages. The top-level libcudnn.so.9 dynamically pulls
# in the component libraries, so they are modeled as cc_import deps.
CUDNN_PACKAGES = {
    "cudnn": """\
cc_import(
    name = "cudnn",
    shared_library = "lib/libcudnn.so.9",
    visibility = ["@libpjrt_cuda//:__subpackages__"],
    deps = [
        ":cudnn_adv",
        ":cudnn_ops",
        ":cudnn_cnn",
        ":cudnn_graph",
        ":cudnn_engines_precompiled",
        ":cudnn_engines_runtime_compiled",
        ":cudnn_heuristic",
    ],
)

cc_import(
    name = "cudnn_adv",
    shared_library = "lib/libcudnn_adv.so.9",
)

cc_import(
    name = "cudnn_ops",
    shared_library = "lib/libcudnn_ops.so.9",
)

cc_import(
    name = "cudnn_cnn",
    shared_library = "lib/libcudnn_cnn.so.9",
    deps = [":cudnn_ops"],
)

cc_import(
    name = "cudnn_graph",
    shared_library = "lib/libcudnn_graph.so.9",
)

cc_import(
    name = "cudnn_engines_precompiled",
    shared_library = "lib/libcudnn_engines_precompiled.so.9",
)

cc_import(
    name = "cudnn_engines_runtime_compiled",
    shared_library = "lib/libcudnn_engines_runtime_compiled.so.9",
)

cc_import(
    name = "cudnn_heuristic",
    shared_library = "lib/libcudnn_heuristic.so.9",
)
""",
}
# Module extension impl: fetches CUDA + cuDNN redistrib archives (URLs and
# sha256s resolved from the checked-in NVIDIA redistrib JSON files), NCCL
# and zlib from deb packages, and the prebuilt CUDA PJRT plugin.
def _cuda_impl(mctx):
    CUDA_REDIST = json.decode(mctx.read(Label("@zml//runtimes/cuda:cuda.redistrib_{}.json".format(CUDA_VERSION))))
    CUDNN_REDIST = json.decode(mctx.read(Label("@zml//runtimes/cuda:cudnn.redistrib_{}.json".format(CUDNN_VERSION))))

    for pkg, build_file_content in CUDA_PACKAGES.items():
        pkg_data = CUDA_REDIST[pkg]
        arch_data = pkg_data.get(ARCH)
        if not arch_data:
            # Package not published for linux-x86_64; skip it.
            continue
        http_archive(
            name = pkg,
            build_file_content = build_file_content,
            url = "https://developer.download.nvidia.com/compute/cuda/redist/" + arch_data["relative_path"],
            sha256 = arch_data["sha256"],
            strip_prefix = paths.basename(arch_data["relative_path"]).replace(".tar.xz", ""),
        )

    for pkg, build_file_content in CUDNN_PACKAGES.items():
        pkg_data = CUDNN_REDIST[pkg]
        arch_data = pkg_data.get(ARCH)
        if not arch_data:
            continue
        # cuDNN entries are keyed by CUDA major version when multiple
        # variants exist; prefer the cuda12 build, else the entry itself.
        arch_data = arch_data.get("cuda12", arch_data)
        http_archive(
            name = pkg,
            build_file_content = build_file_content,
            url = "https://developer.download.nvidia.com/compute/cudnn/redist/" + arch_data["relative_path"],
            sha256 = arch_data["sha256"],
            strip_prefix = paths.basename(arch_data["relative_path"]).replace(".tar.xz", ""),
        )

    # NCCL ships as an Ubuntu deb rather than a redistrib tarball.
    http_deb_archive(
        name = "libnccl",
        urls = ["https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/libnccl2_2.22.3-1+cuda12.6_amd64.deb"],
        sha256 = "2f64685bcd503150ab45d00503236a56da58a15eac5fd36508045a74f4e10678",
        build_file_content = """\
cc_import(
    name = "nccl",
    shared_library = "usr/lib/x86_64-linux-gnu/libnccl.so.2",
    visibility = ["@libpjrt_cuda//:__subpackages__"],
)
""",
    )

    http_deb_archive(
        name = "zlib",
        urls = ["http://archive.ubuntu.com/ubuntu/pool/main/z/zlib/zlib1g_1.3.dfsg-3.1ubuntu2.1_amd64.deb"],
        sha256 = "7074b6a2f6367a10d280c00a1cb02e74277709180bab4f2491a2f355ab2d6c20",
        build_file_content = """\
cc_import(
    name = "zlib",
    shared_library = "usr/lib/x86_64-linux-gnu/libz.so.1",
    visibility = ["@libpjrt_cuda//:__subpackages__"],
)
""",
    )

    # The PJRT CUDA plugin itself; its BUILD file wires up all deps above.
    http_archive(
        name = "libpjrt_cuda",
        build_file = "libpjrt_cuda.BUILD.bazel",
        url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.1.13/pjrt-cuda_linux-amd64.tar.gz",
        sha256 = "b705f761e24d85ecd750df992a88715d9c461b7561c31722b9f878eeab32f39e",
    )

    return mctx.extension_metadata(
        reproducible = True,
        root_module_direct_deps = ["libpjrt_cuda"],
        root_module_direct_dev_deps = [],
    )

cuda_packages = module_extension(
    implementation = _cuda_impl,
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,77 @@
{
"release_date": "2024-08-01",
"release_label": "9.3.0",
"release_product": "cudnn",
"cudnn": {
"name": "NVIDIA CUDA Deep Neural Network library",
"license": "cudnn",
"license_path": "cudnn/LICENSE.txt",
"version": "9.3.0.75",
"linux-x86_64": {
"cuda11": {
"relative_path": "cudnn/linux-x86_64/cudnn-linux-x86_64-9.3.0.75_cuda11-archive.tar.xz",
"sha256": "069da084cd368f39fb830d4c4e931803064fb65e766f4cf7df2da4a346a2ba9f",
"md5": "4ab5cfb90cfe16c02cbdb88788165de5",
"size": "748177916"
},
"cuda12": {
"relative_path": "cudnn/linux-x86_64/cudnn-linux-x86_64-9.3.0.75_cuda12-archive.tar.xz",
"sha256": "3d6ef10aa06dc9339a477e2b057e085ff8500bbdee79e42c7e13655c9eff2c26",
"md5": "2fa73268de8bbdab5560f4aa1a5a73ab",
"size": "756509380"
}
},
"cuda_variant": [
"11",
"12"
],
"linux-sbsa": {
"cuda11": {
"relative_path": "cudnn/linux-sbsa/cudnn-linux-sbsa-9.3.0.75_cuda11-archive.tar.xz",
"sha256": "04b56fbf7bee15c24e339c2ba94d17aa88b9e334d0cd19e75853dc5452794bf7",
"md5": "eb3d809ff9c853721d342aa68564ed77",
"size": "746866932"
},
"cuda12": {
"relative_path": "cudnn/linux-sbsa/cudnn-linux-sbsa-9.3.0.75_cuda12-archive.tar.xz",
"sha256": "1226dd9b989c898638552963d000a3cfbb8a0a0cb9baf8cb09e6abd77ed7d639",
"md5": "bdf5c7ba6ae34cc0dbcdfb98f3da746d",
"size": "754450192"
}
},
"windows-x86_64": {
"cuda11": {
"relative_path": "cudnn/windows-x86_64/cudnn-windows-x86_64-9.3.0.75_cuda11-archive.zip",
"sha256": "0e6f9d39343b88208b1daea280d62d6f7a90355395999806de6624c8361c36fc",
"md5": "d7414555bd6c9e8a91139f00864fa2fa",
"size": "562755549"
},
"cuda12": {
"relative_path": "cudnn/windows-x86_64/cudnn-windows-x86_64-9.3.0.75_cuda12-archive.zip",
"sha256": "864a85dc67c7f92b9a8639f323acb4af63ad65de2ca82dccdf2c0b6a701c27c0",
"md5": "5f82d9233dd22a6664abff13cd22a224",
"size": "566118754"
}
},
"linux-aarch64": {
"cuda12": {
"relative_path": "cudnn/linux-aarch64/cudnn-linux-aarch64-9.3.0.75_cuda12-archive.tar.xz",
"sha256": "1aae4bfced63f930b4677f9e928e197878dece2edd51ca8fa7ab8363d0e6ed60",
"md5": "e5a94b1b0bbd313c0c11235d969ca028",
"size": "796357916"
}
}
},
"cudnn_samples": {
"name": "NVIDIA cuDNN samples",
"license": "cudnn",
"license_path": "cudnn_samples/LICENSE.txt",
"version": "9.3.0.75",
"source": {
"relative_path": "cudnn_samples/source/cudnn_samples-source-9.3.0.75-archive.tar.xz",
"sha256": "5da5536bba749158b245ddf69b4a0e2f4bf122e9b3c0e7406d8c2ddd5dd97518",
"md5": "d855a9694631213fb6e78da90ae60fbe",
"size": "1666860"
}
}
}

View File

@ -0,0 +1,32 @@
load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory")
load("@zml//bazel:cc_import.bzl", "cc_import")
# Collects ptxas and libdevice into one directory; this becomes the
# ROCM/CUDA "sandbox" runtime data dir shipped alongside the plugin.
copy_to_directory(
    name = "sandbox",
    srcs = [
        "@cuda_nvcc//:libdevice",
        "@cuda_nvcc//:ptxas",
    ],
    include_external_repositories = ["**"],
)

# The CUDA PJRT plugin plus every CUDA/cuDNN/NCCL library it loads.
cc_import(
    name = "libpjrt_cuda",
    data = [":sandbox"],
    shared_library = "libpjrt_cuda.so",
    visibility = ["@zml//runtimes/cuda:__subpackages__"],
    deps = [
        "@cuda_cudart//:cudart",
        "@cuda_cupti//:cupti",
        "@cuda_nvcc//:nvvm",
        "@cuda_nvrtc//:nvrtc",
        "@cudnn//:cudnn",
        "@libcublas//:cublas",
        "@libcufft//:cufft",
        "@libcusolver//:cusolver",
        "@libcusparse//:cusparse",
        "@libnccl//:nccl",
        "@libnvjitlink//:nvjitlink",
        "@zlib",
    ],
)

21
runtimes/rocm/BUILD.bazel Normal file
View File

@ -0,0 +1,21 @@
# Sources for the dlopen-interposing hook library, built inside the
# @libpjrt_rocm repo (see libpjrt_rocm.BUILD.bazel).
filegroup(
    name = "zmlrocmhooks_srcs",
    srcs = ["zmlrocmhooks.cc"],
    visibility = ["@libpjrt_rocm//:__subpackages__"],
)

# Convenience aliases into the fetched @libpjrt_rocm repo.
alias(
    name = "hipblaslt",
    actual = "@libpjrt_rocm//:hipblaslt",
)

alias(
    name = "gfx",
    actual = "@libpjrt_rocm//:gfx",
)

# Public entry point for the ROCm runtime.
alias(
    name = "rocm",
    actual = "@libpjrt_rocm",
    visibility = ["//visibility:public"],
)

49
runtimes/rocm/gfx.bzl Normal file
View File

@ -0,0 +1,49 @@
load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
# Every AMD GPU architecture we ship compiled kernels for.
_ALL_GFX = ["gfx900", "gfx906", "gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx1010", "gfx1012", "gfx1030", "gfx1100", "gfx1101", "gfx1102"]

def _compute_enabled_gfx(values):
    """Expands a flag value list into a set-like dict of enabled gfx arches.

    Entries are processed in order: "all" enables every known arch, "none"
    clears the set, and any other entry enables that single arch.
    """
    enabled = {}
    for value in values:
        if value == "all":
            enabled = {gfx: True for gfx in _ALL_GFX}
        elif value == "none":
            enabled = {}
        else:
            enabled[value] = True
    return enabled
def _gfx_from_file(file):
    # Derives the gfx arch name from a bytecode file: strip the extension
    # (and its dot), take the chunk after the last "_", then keep what
    # precedes the first "-" (e.g. "Kernels_gfx90a-xnack.ext" -> "gfx90a").
    return file.basename[:-len(file.extension) - 1].rpartition("_")[-1].partition("-")[0]
def _is_file_enabled(file, enabled_gfx):
    """True when the file's gfx arch is enabled, or is the generic fallback."""
    arch = _gfx_from_file(file)
    if arch == "fallback":
        return True
    return arch in enabled_gfx
# Rule impl: filters `bytecodes` down to the files whose gfx arch is in
# the string_list_flag referenced by `enabled_gfx` (fallback files always
# pass).
def _bytecode_select_impl(ctx):
    enabled_gfx = _compute_enabled_gfx(ctx.attr.enabled_gfx[BuildSettingInfo].value)
    return [
        DefaultInfo(
            files = depset([
                file
                for file in ctx.files.bytecodes
                if _is_file_enabled(file, enabled_gfx)
            ]),
        ),
    ]

# Selects a subset of GPU bytecode files based on a gfx build setting.
bytecode_select = rule(
    implementation = _bytecode_select_impl,
    attrs = {
        "bytecodes": attr.label_list(allow_files = True),
        "enabled_gfx": attr.label(mandatory = True),
    },
)
def if_gfx(gfx, value):
    # Returns `value` only when the config_setting for this gfx arch
    # (declared as "_<gfx>" in //runtimes/rocm) is active; otherwise [].
    return select({
        "@zml//runtimes/rocm:_{}".format(gfx): value,
        "//conditions:default": [],
    })

View File

@ -0,0 +1,88 @@
load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory")
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "string_list_flag")
load("@zml//bazel:cc_import.bzl", "cc_import")
string_list_flag(
name = "gfx",
build_setting_default = ["all"],
visibility = [
"@rocblas//:__subpackages__",
"@hipblaslt-dev//:__subpackages__",
],
)
bool_flag(
name = "hipblaslt",
build_setting_default = True,
)
config_setting(
name = "_hipblaslt",
flag_values = {":hipblaslt": "True"},
)
copy_to_directory(
name = "sandbox",
srcs = [
"@rocm-device-libs//:runfiles",
"@rocm-llvm//:lld",
],
include_external_repositories = ["*"],
)
cc_library(
name = "zmlrocmhooks_lib",
data = ["@rocblas//:runfiles"],
srcs = ["@zml//runtimes/rocm:zmlrocmhooks_srcs"],
linkopts = [
"-lc",
"-ldl",
],
deps = ["@bazel_tools//tools/cpp/runfiles"],
)
cc_shared_library(
name = "zmlrocmhooks_so",
shared_lib_name = "libzmlrocmhooks.so.0",
deps = [":zmlrocmhooks_lib"],
)
cc_import(
name = "zmlrocmhooks",
shared_library = ":zmlrocmhooks_so",
)
cc_import(
name = "libpjrt_rocm",
data = [
":sandbox",
"@rocblas//:runfiles",
] + select({
":_hipblaslt": ["@hipblaslt-dev//:runfiles"],
"//conditions:default": [],
}),
shared_library = "libpjrt_rocm.so",
visibility = ["//visibility:public"],
deps = [
":zmlrocmhooks",
"@comgr//:amd_comgr",
"@hip-runtime-amd//:amdhip",
"@hipblaslt",
"@hsa-amd-aqlprofile//:hsa-amd-aqlprofile",
"@hsa-rocr//:hsa-runtime",
"@miopen-hip//:MIOpen",
"@rccl",
"@rocblas",
"@rocm-core",
"@rocm-smi-lib//:rocm_smi",
"@rocprofiler-register",
"@roctracer",
"@libelf",
"@libdrm",
"@libnuma",
"@libzstd",
"@libdrm-amdgpu",
"@libtinfo",
"@zlib1g",
],
)

File diff suppressed because it is too large Load Diff

242
runtimes/rocm/rocm.bzl Normal file
View File

@ -0,0 +1,242 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("//bazel:http_deb_archive.bzl", "http_deb_archive")
ROCM_VERSION = "6.2"
BASE_URL = "https://repo.radeon.com/rocm/apt/{}".format(ROCM_VERSION)

# NOTE(review): hardcodes the 6.2.0 patch level while ROCM_VERSION is the
# 6.2 series — keep the two in sync when bumping.
STRIP_PREFIX = "opt/rocm-6.2.0"

# Builds common http_deb_archive kwargs for a ROCm apt package, reading
# the Filename/SHA256 fields parsed out of the apt Packages index.
def pkg_kwargs(pkg, packages):
    return {
        "name": pkg,
        "urls": [BASE_URL + "/" + packages[pkg]["Filename"]],
        "sha256": packages[pkg]["SHA256"],
        "strip_prefix": STRIP_PREFIX,
    }
# Builds http_deb_archive kwargs for an Ubuntu archive deb exposing a
# single shared library as a publicly visible cc_import named `name`.
def _ubuntu_package(path, deb_path, sha256, name, shared_library):
    return {
        "urls": ["http://archive.ubuntu.com/ubuntu/pool/main/{}".format(path)],
        "sha256": sha256,
        "build_file_content": """\
cc_import(
    name = {name},
    shared_library = "{deb_path}{shared_library}",
    visibility = ["//visibility:public"],
)
""".format(name = repr(name), shared_library = shared_library, deb_path = deb_path),
    }
# System libraries the ROCm stack links against, taken from the Ubuntu
# archive (sha256-pinned), keyed by repository name.
_UBUNTU_PACKAGES = {
    "libdrm": _ubuntu_package(
        path = "libd/libdrm/libdrm2_2.4.107-8ubuntu1~20.04.2_amd64.deb",
        deb_path = "usr/lib/x86_64-linux-gnu/",
        sha256 = "9b01d73313841abe8e3f24c2715edced675fbe329bbd10be912a5b135cd51fb6",
        name = "libdrm",
        shared_library = "libdrm.so.2",
    ),
    "libelf": _ubuntu_package(
        path = "e/elfutils/libelf1_0.176-1.1build1_amd64.deb",
        deb_path = "usr/lib/x86_64-linux-gnu/",
        sha256 = "78a8761227efc04a1e37527f2f33ba608c6fb5d6c911616346ada5d7b9b72ee3",
        name = "libelf",
        shared_library = "libelf.so.1",
    ),
    "libnuma": _ubuntu_package(
        path = "n/numactl/libnuma1_2.0.12-1_amd64.deb",
        deb_path = "usr/lib/x86_64-linux-gnu/",
        sha256 = "0b1edf08cf9befecd21fe94e298ac25e476f87fd876ddd4adf42ef713449e637",
        name = "libnuma",
        shared_library = "libnuma.so.1",
    ),
    "libzstd": _ubuntu_package(
        path = "libz/libzstd/libzstd1_1.4.4+dfsg-3ubuntu0.1_amd64.deb",
        deb_path = "usr/lib/x86_64-linux-gnu/",
        sha256 = "7a4422dadb90510dc90765c308d65e61a3e244ceb3886394335e48cff7559e69",
        name = "libzstd",
        shared_library = "libzstd.so.1",
    ),
    "libdrm-amdgpu": _ubuntu_package(
        path = "libd/libdrm/libdrm-amdgpu1_2.4.107-8ubuntu1~20.04.2_amd64.deb",
        deb_path = "usr/lib/x86_64-linux-gnu/",
        sha256 = "0d95779b581f344e3d658e0f21f6e4b57da6eb3606c0bcb8cb874c12f5754bf2",
        name = "libdrm-amdgpu",
        shared_library = "libdrm_amdgpu.so.1",
    ),
    "libtinfo": _ubuntu_package(
        path = "n/ncurses/libtinfo6_6.2-0ubuntu2.1_amd64.deb",
        deb_path = "lib/x86_64-linux-gnu/",
        sha256 = "711a3a901c3a71561565558865699efa9c07a99fdc810ffe086a5636f89c6431",
        name = "libtinfo",
        shared_library = "libtinfo.so.6",
    ),
    "zlib1g": _ubuntu_package(
        path = "z/zlib/zlib1g_1.2.11.dfsg-2ubuntu1.5_amd64.deb",
        deb_path = "lib/x86_64-linux-gnu/",
        sha256 = "bf67018f5303466eb468680b637a5d3f3bb17b9d44decf3d82d40b35babcd3e0",
        name = "zlib1g",
        shared_library = "libz.so.1",
    ),
}
# BUILD content template for a repo exposing one versioned .so, visible
# only to the @libpjrt_rocm repo.
_CC_IMPORT_TPL = """\
cc_import(
    name = "{name}",
    shared_library = "lib/{shared_library}",
    visibility = ["@libpjrt_rocm//:__subpackages__"],
)
"""

# BUILD content template for a repo exposing runtime data files via glob.
_RUNFILES_TPL = """\
filegroup(
    name = "{name}",
    srcs = glob({glob}),
    visibility = ["@libpjrt_rocm//:__subpackages__"],
)
"""

# ROCm apt packages to fetch, mapped to the BUILD content generated for
# each repo. rocblas/hipblaslt use the patched cc_import (add_needed) so
# the libzmlrocmhooks interposer is injected as a dependency.
_PACKAGES = {
    "rocm-core": _CC_IMPORT_TPL.format(name = "rocm-core", shared_library = "librocm-core.so.1"),
    "rocm-smi-lib": _CC_IMPORT_TPL.format(name = "rocm_smi", shared_library = "librocm_smi64.so.7"),
    "hsa-rocr": _CC_IMPORT_TPL.format(name = "hsa-runtime", shared_library = "libhsa-runtime64.so.1"),
    "hsa-amd-aqlprofile": _CC_IMPORT_TPL.format(name = "hsa-amd-aqlprofile", shared_library = "libhsa-amd-aqlprofile64.so.1"),
    "comgr": _CC_IMPORT_TPL.format(name = "amd_comgr", shared_library = "libamd_comgr.so.2"),
    "rocprofiler-register": _CC_IMPORT_TPL.format(name = "rocprofiler-register", shared_library = "librocprofiler-register.so.0"),
    "miopen-hip": "".join([
        _CC_IMPORT_TPL.format(name = "MIOpen", shared_library = "libMIOpen.so.1"),
        _RUNFILES_TPL.format(name = "runfiles", glob = repr(["share/miopen/**"])),
    ]),
    "rccl": "".join([
        _CC_IMPORT_TPL.format(name = "rccl", shared_library = "librccl.so.1"),
        _RUNFILES_TPL.format(name = "runfiles", glob = repr(["share/rccl/msccl-algorithms/**"])),
    ]),
    "rocm-device-libs": _RUNFILES_TPL.format(name = "runfiles", glob = repr(["amdgcn/**"])),
    "hip-dev": _RUNFILES_TPL.format(name = "runfiles", glob = repr(["share/**"])),
    "rocblas": """\
load("@zml//bazel:cc_import.bzl", "cc_import")
load("@zml//runtimes/rocm:gfx.bzl", "bytecode_select")

cc_import(
    name = "rocblas",
    shared_library = "lib/librocblas.so.4",
    add_needed = ["libzmlrocmhooks.so.0"],
    visibility = ["@libpjrt_rocm//:__subpackages__"],
)

bytecode_select(
    name = "runfiles",
    bytecodes = glob(["lib/rocblas/library/*"]),
    enabled_gfx = "@libpjrt_rocm//:gfx",
    visibility = ["@libpjrt_rocm//:__subpackages__"],
)
""",
    "roctracer": """\
cc_import(
    name = "roctracer",
    shared_library = "lib/libroctracer64.so.4",
    visibility = ["@libpjrt_rocm//:__subpackages__"],
    deps = [":roctx"],
)

cc_import(
    name = "roctx",
    shared_library = "lib/libroctx64.so.4",
)
""",
    "hipblaslt": """\
load("@zml//bazel:cc_import.bzl", "cc_import")

cc_import(
    name = "hipblaslt",
    shared_library = "lib/libhipblaslt.so.0",
    add_needed = ["libzmlrocmhooks.so.0"],
    visibility = ["@libpjrt_rocm//:__subpackages__"],
)
""",
    "hipblaslt-dev": """\
load("@zml//runtimes/rocm:gfx.bzl", "bytecode_select")

bytecode_select(
    name = "bytecodes",
    bytecodes = glob(
        include = ["lib/hipblaslt/library/*"],
        exclude = ["lib/hipblaslt/library/hipblasltExtOpLibrary.dat"],
    ),
    enabled_gfx = "@libpjrt_rocm//:gfx",
)

filegroup(
    name = "runfiles",
    srcs = [
        "lib/hipblaslt/library/hipblasltExtOpLibrary.dat",
        ":bytecodes",
    ],
    visibility = ["@libpjrt_rocm//:__subpackages__"],
)
""",
    "hip-runtime-amd": """\
cc_import(
    name = "amdhip",
    shared_library = "lib/libamdhip64.so.6",
    visibility = ["@libpjrt_rocm//:__subpackages__"],
    deps = [":hiprtc"],
)

cc_import(
    name = "hiprtc",
    shared_library = "lib/libhiprtc.so.6",
)
""",
    "rocm-llvm": """\
filegroup(
    name = "lld",
    srcs = ["llvm/bin/ld.lld"],
    visibility = ["@libpjrt_rocm//:__subpackages__"],
)
""",
}
def _packages_to_dict(txt):
    """Parses a Debian `Packages` index into {package_name: {field: value}}.

    Stanzas are separated by blank lines; `Field: value` lines start a
    field and lines beginning with a space continue the previous field
    (the leading space is preserved, per deb822 folding).

    Fix: the original dropped the final stanza when the input did not end
    with a trailing blank line; the trailing stanza is now flushed too.
    """
    packages = {}
    current_pkg = {}
    key = None
    for line in txt.splitlines():
        if line == "":
            # Blank line terminates the current stanza.
            if current_pkg:
                packages[current_pkg["Package"]] = current_pkg
                current_pkg = {}
            continue
        if line.startswith(" "):
            # Continuation of the most recently seen field.
            current_pkg[key] += line
            continue
        split = line.split(": ", 1)
        key = split[0]
        value = len(split) > 1 and split[1] or ""
        current_pkg[key] = value

    # Flush the last stanza even without a terminating blank line.
    if current_pkg:
        packages[current_pkg["Package"]] = current_pkg
    return packages
# Module extension impl: resolves ROCm package URLs/sha256s from the
# checked-in apt Packages index, fetches supporting Ubuntu debs, then the
# prebuilt ROCm PJRT plugin.
def _rocm_impl(mctx):
    data = mctx.read(Label("@zml//runtimes/rocm:packages.amd64.txt"))
    PACKAGES = _packages_to_dict(data)

    for pkg, build_file_content in _PACKAGES.items():
        http_deb_archive(
            build_file_content = build_file_content,
            **pkg_kwargs(pkg, PACKAGES)
        )

    for repository, kwargs in _UBUNTU_PACKAGES.items():
        http_deb_archive(name = repository, **kwargs)

    http_archive(
        name = "libpjrt_rocm",
        build_file = "libpjrt_rocm.BUILD.bazel",
        url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.1.13/pjrt-rocm_linux-amd64.tar.gz",
        sha256 = "5900cec41274e80ab799bc13f31cdc87202f8e168d7e753b1c10796912f5ebef",
    )

    return mctx.extension_metadata(
        reproducible = True,
        root_module_direct_deps = ["libpjrt_rocm"],
        root_module_direct_dev_deps = [],
    )

rocm_packages = module_extension(
    implementation = _rocm_impl,
)

View File

@ -0,0 +1,103 @@
#include <dlfcn.h>
#include <errno.h>
#include <stdlib.h>
#include <string.h>

#include <fstream>
#include <iostream>
#include <memory>
#include <string>

#include "tools/cpp/runfiles/runfiles.h"
namespace zml
{
using bazel::tools::cpp::runfiles::Runfiles;
std::unique_ptr<Runfiles> runfiles;
std::string ROCBLAS_TENSILE_LIBPATH;
std::string HIPBLASLT_TENSILE_LIBPATH;
std::string HIPBLASLT_EXT_OP_LIBRARY_PATH;
std::string ROCM_PATH;
typedef void *(*dlopen_func)(const char *filename, int flags);
dlopen_func dlopen_orig = nullptr;
__attribute__((constructor)) static void setup(int argc, char **argv)
{
runfiles = std::unique_ptr<Runfiles>(Runfiles::Create(argv[0], BAZEL_CURRENT_REPOSITORY));
HIPBLASLT_EXT_OP_LIBRARY_PATH = runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library/hipblasltExtOpLibrary.dat");
if (HIPBLASLT_EXT_OP_LIBRARY_PATH != "")
{
setenv("HIPBLASLT_EXT_OP_LIBRARY_PATH", HIPBLASLT_EXT_OP_LIBRARY_PATH.c_str(), 1);
}
HIPBLASLT_TENSILE_LIBPATH = runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library");
if (HIPBLASLT_TENSILE_LIBPATH != "")
{
setenv("HIPBLASLT_TENSILE_LIBPATH", HIPBLASLT_TENSILE_LIBPATH.c_str(), 1);
}
ROCBLAS_TENSILE_LIBPATH = runfiles->Rlocation("rocblas/lib/rocblas/library");
setenv("ROCBLAS_TENSILE_LIBPATH", ROCBLAS_TENSILE_LIBPATH.c_str(), 1);
ROCM_PATH = runfiles->Rlocation("libpjrt_rocm/sandbox");
setenv("ROCM_PATH", ROCM_PATH.c_str(), 1);
}
static void *rocm_dlopen(const char *filename, int flags)
{
if (filename != NULL)
{
char *replacements[] = {
"librocm-core.so",
"librocm-core.so.1",
"librocm_smi64.so",
"librocm_smi64.so.7",
"libhsa-runtime64.so",
"libhsa-runtime64.so.1",
"libhsa-amd-aqlprofile64.so",
"libhsa-amd-aqlprofile64.so.1",
"libamd_comgr.so",
"libamd_comgr.so.2",
"librocprofiler-register.so",
"librocprofiler-register.so.0",
"libMIOpen.so",
"libMIOpen.so.1",
"librccl.so",
"librccl.so.1",
"librocblas.so",
"librocblas.so.4",
"libroctracer64.so",
"libroctracer64.so.4",
"libroctx64.so",
"libroctx64.so.4",
"libhipblaslt.so",
"libhipblaslt.so.0",
"libamdhip64.so",
"libamdhip64.so.6",
"libhiprtc.so",
"libhiprtc.so.6",
NULL,
NULL,
};
for (int i = 0; replacements[i] != NULL; i += 2)
{
if (strcmp(filename, replacements[i]) == 0)
{
filename = replacements[i + 1];
break;
}
}
}
return dlopen_orig(filename, flags);
}
}
// GNU ifunc: the dynamic linker calls _zml_rocm_resolve_dlopen exactly
// once to choose the implementation of our exported `dlopen` symbol. The
// resolver captures the next (real) dlopen via RTLD_NEXT, then installs
// the interposer above, so all dlopen calls in the process go through the
// soname-rewriting shim.
extern "C"
{
    zml::dlopen_func _zml_rocm_resolve_dlopen()
    {
        zml::dlopen_orig = (zml::dlopen_func)dlsym(RTLD_NEXT, "dlopen");
        return zml::rocm_dlopen;
    }

    extern void *dlopen(const char *filename, int flags) __attribute__((ifunc("_zml_rocm_resolve_dlopen")));
}

5
runtimes/tpu/BUILD.bazel Normal file
View File

@ -0,0 +1,5 @@
# Public entry point for the TPU runtime; forwards to the fetched
# @libpjrt_tpu repository.
alias(
    name = "tpu",
    actual = "@libpjrt_tpu",
    visibility = ["//visibility:public"],
)

View File

@ -0,0 +1,14 @@
load("@bazel_skylib//rules:copy_file.bzl", "copy_file")

# Expose the wheel's libtpu.so under the libpjrt_tpu.so name expected by the
# rest of the build. allow_symlink avoids copying the (large) shared object.
copy_file(
    name = "libpjrt_tpu_so",
    src = "libtpu/libtpu.so",
    out = "libpjrt_tpu.so",
    allow_symlink = True,
)

# Wrap the prebuilt shared library so cc_* rules can link against it.
# Visibility is restricted to the zml TPU runtime package.
cc_import(
    name = "libpjrt_tpu",
    shared_library = ":libpjrt_tpu_so",
    visibility = ["@zml//runtimes/tpu:__subpackages__"],
)

20
runtimes/tpu/tpu.bzl Normal file
View File

@ -0,0 +1,20 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

# Module-extension implementation: fetches the nightly libtpu wheel and
# exposes it as the @libpjrt_tpu repository.
def _tpu_impl(mctx):
    # Wheel index: https://storage.googleapis.com/jax-releases/libtpu_releases.html
    # A .whl is a plain zip archive, hence type = "zip".
    http_archive(
        name = "libpjrt_tpu",
        url = "https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20240915+nightly-py3-none-any.whl",
        type = "zip",
        sha256 = "4eee6c7dc92e7e60334c5b0261308a0a07c2f5b6235c7c60f465263576c602bf",
        # NOTE(review): relative label string — presumably resolved against
        # this package (//runtimes/tpu); confirm it picks up
        # libpjrt_tpu.BUILD.bazel when invoked from the root MODULE.bazel.
        build_file = "libpjrt_tpu.BUILD.bazel",
    )
    # The sha256 pin above makes the fetch fully reproducible, so the
    # extension can declare itself as such for the lockfile.
    return mctx.extension_metadata(
        reproducible = True,
        root_module_direct_deps = ["libpjrt_tpu"],
        root_module_direct_dev_deps = [],
    )

tpu_packages = module_extension(
    implementation = _tpu_impl,
)

6
third_party/bazel_registry.json vendored Normal file
View File

@ -0,0 +1,6 @@
{
"mirrors": [
"https://storage.googleapis.com/mirror.tensorflow.org/"
],
"module_base_path": "."
}

View File

@ -0,0 +1,78 @@
"aspect-build/bazel-lib"
module(
name = "aspect_bazel_lib",
version = "2.8.1.1",
bazel_compatibility = [">=6.0.0"],
compatibility_level = 1,
)
# Lower-bounds (minimum) versions for direct runtime dependencies
bazel_dep(name = "bazel_skylib", version = "1.5.0")
bazel_dep(name = "platforms", version = "0.0.8")
# 0.5.4 is the first version with bzlmod support
bazel_dep(name = "stardoc", version = "0.6.2", repo_name = "io_bazel_stardoc")
bazel_lib_toolchains = use_extension("@aspect_bazel_lib//lib:extensions.bzl", "toolchains")
bazel_lib_toolchains.copy_directory()
bazel_lib_toolchains.copy_to_directory()
bazel_lib_toolchains.jq()
bazel_lib_toolchains.yq()
bazel_lib_toolchains.coreutils()
bazel_lib_toolchains.tar()
bazel_lib_toolchains.zstd()
bazel_lib_toolchains.expand_template()
bazel_lib_toolchains.bats()
use_repo(bazel_lib_toolchains, "bats_toolchains", "bsd_tar_toolchains", "copy_directory_toolchains", "copy_to_directory_toolchains", "coreutils_toolchains", "expand_template_toolchains", "jq_toolchains", "yq_toolchains", "zstd_toolchains")
register_toolchains(
"@copy_directory_toolchains//:all",
"@copy_to_directory_toolchains//:all",
"@jq_toolchains//:all",
"@yq_toolchains//:all",
"@coreutils_toolchains//:all",
"@expand_template_toolchains//:all",
"@bats_toolchains//:all",
"@bsd_tar_toolchains//:all",
"@zstd_toolchains//:all",
)
####### Dev dependencies ########
# To allow /tools to be built from source
# NOTE: when publishing to BCR, we patch this to be dev_dependency, as we publish pre-built binaries
# along with our releases.
bazel_dep(
name = "gazelle",
version = "0.36.0",
dev_dependency = True,
)
bazel_dep(
name = "rules_go",
version = "0.46.0",
repo_name = "io_bazel_rules_go",
dev_dependency = True,
)
go_deps = use_extension(
"@gazelle//:extensions.bzl",
"go_deps",
dev_dependency = True,
)
go_deps.from_file(go_mod = "//:go.mod")
use_repo(
go_deps,
"com_github_bmatcuk_doublestar_v4",
"org_golang_x_exp",
"org_golang_x_sys",
)
host = use_extension("@aspect_bazel_lib//lib:extensions.bzl", "host", dev_dependency = True)
host.host()
use_repo(host, "aspect_bazel_lib_host")
bazel_dep(name = "bazel_skylib_gazelle_plugin", version = "1.5.0", dev_dependency = True)
bazel_dep(name = "buildifier_prebuilt", version = "6.4.0", dev_dependency = True)
bazel_dep(name = "bazel_features", version = "0.2.0", dev_dependency = True)

View File

@ -0,0 +1,50 @@
From 5dbe6a2604e440b684add7b44531898acc8631b4 Mon Sep 17 00:00:00 2001
From: Steeve Morin <steeve.morin@gmail.com>
Date: Sun, 8 Sep 2024 21:17:29 +0200
Subject: [PATCH] fix: add repo mapping to tar archive
When using bzlmod, runfiles lookup will fail without it.
---
lib/private/tar.bzl | 17 +++++++++++++----
1 file changed, 13 insertions(+), 4 deletions(-)
diff --git a/lib/private/tar.bzl b/lib/private/tar.bzl
index 733ff60..29434f6 100644
--- a/lib/private/tar.bzl
+++ b/lib/private/tar.bzl
@@ -147,12 +147,19 @@ def _tar_impl(ctx):
args.add(ctx.file.mtree, format = "@%s")
inputs.append(ctx.file.mtree)
+ src_runfiles = []
+ for src in ctx.attr.srcs:
+ src_di = src[DefaultInfo]
+ if getattr(src_di.files_to_run, "repo_mapping_manifest", None) != None:
+ src_runfiles.append(depset(
+ direct = [src_di.files_to_run.repo_mapping_manifest],
+ transitive = [src_di.default_runfiles.files],
+ ))
+ else:
+ src_runfiles.append(src_di.default_runfiles.files)
ctx.actions.run(
executable = bsdtar.tarinfo.binary,
- inputs = depset(direct = inputs, transitive = [bsdtar.default.files] + [
- src[DefaultInfo].default_runfiles.files
- for src in ctx.attr.srcs
- ]),
+ inputs = depset(direct = inputs, transitive = [bsdtar.default.files] + src_runfiles),
outputs = [out],
arguments = [args],
mnemonic = "Tar",
@@ -234,6 +241,8 @@ def _mtree_impl(ctx):
workspace_name = str(ctx.workspace_name)
content.add(_mtree_line(runfiles_dir, type = "dir"))
+ if getattr(default_info.files_to_run, "repo_mapping_manifest", None) != None:
+ content.add(_mtree_line("{}/_repo_mapping".format(runfiles_dir), type = "file", content = default_info.files_to_run.repo_mapping_manifest.path))
content.add_all(
s.default_runfiles.files,
expand_directories = True,
--
2.39.3 (Apple Git-146)

View File

@ -0,0 +1,27 @@
diff --git a/MODULE.bazel b/MODULE.bazel
index e63fa5b..9d78a88 100644
--- a/MODULE.bazel
+++ b/MODULE.bazel
@@ -50,19 +50,19 @@ use_repo(host, "aspect_bazel_lib_host")
bazel_dep(
name = "gazelle",
version = "0.36.0",
- # In released versions: dev_dependency = True
+ dev_dependency = True,
)
bazel_dep(
name = "rules_go",
version = "0.46.0",
repo_name = "io_bazel_rules_go",
- # In released versions: dev_dependency = True
+ dev_dependency = True,
)
go_deps = use_extension(
"@gazelle//:extensions.bzl",
"go_deps",
- # In released versions: dev_dependency = True
+ dev_dependency = True,
)
go_deps.from_file(go_mod = "//:go.mod")
use_repo(

View File

@ -0,0 +1,14 @@
===================================================================
--- a/MODULE.bazel
+++ b/MODULE.bazel
@@ -1,9 +1,9 @@
"aspect-build/bazel-lib"
module(
name = "aspect_bazel_lib",
- version = "0.0.0",
+ version = "2.8.1",
bazel_compatibility = [">=6.0.0"],
compatibility_level = 1,
)

View File

@ -0,0 +1,11 @@
{
"integrity": "sha256-aINU7mvuunGUJD1z6wmSuaEujt7u7FtlRPS1MaMRIjc=",
"strip_prefix": "bazel-lib-2.8.1",
"url": "https://github.com/aspect-build/bazel-lib/releases/download/v2.8.1/bazel-lib-v2.8.1.tar.gz",
"patches": {
"go_dev_dep.patch": "sha256-DTc/hk+etl4D50M0BLRik2vHbrgDb6rds+Dj4xphWb4=",
"module_dot_bazel_version.patch": "sha256-cEMv6bY7Sc5dERv8YG0lq45zuZCsVPNn4oAN9aOkf40=",
"0001-fix-add-repo-mapping-to-tar-archive.patch": ""
},
"patch_strip": 1
}

View File

@ -0,0 +1,25 @@
{
"homepage": "https://docs.aspect.build/rules/aspect_bazel_lib",
"maintainers": [
{
"name": "Alex Eagle",
"email": "alex@aspect.dev",
"github": "alexeagle"
},
{
"name": "Derek Cormier",
"email": "derek@aspect.dev",
"github": "kormide"
}
],
"repository": [
"github:aspect-build/bazel-lib"
],
"versions": [
"2.8.1.1"
],
"yanked_versions": {
"1.31.0": "1.31.0 has a breaking change to the default yq version",
"1.34.2": "The version field has a leading 'v' due to a release automation bug"
}
}

View File

@ -0,0 +1,7 @@
module(
name = "libxev",
version = "20240825.0-dbe2291",
compatibility_level = 1,
)
bazel_dep(name = "rules_zig", version = "20240904.0-010da15")

View File

@ -0,0 +1,13 @@
load("@rules_zig//zig:defs.bzl", "zig_library")
zig_library(
name = "xev",
srcs = glob([
"src/*.zig",
"src/backend/*.zig",
"src/linux/*.zig",
"src/watcher/*.zig",
]),
main = "src/main.zig",
visibility = ["//visibility:public"],
)

View File

@ -0,0 +1,7 @@
module(
name = "libxev",
version = "20240825.0-dbe2291",
compatibility_level = 1,
)
bazel_dep(name = "rules_zig", version = "20240904.0-010da15")

View File

@ -0,0 +1,9 @@
{
"strip_prefix": "libxev-dbe22910a43e9e8ec9948d3cbd73d8488a074967",
"url": "https://github.com/mitchellh/libxev/archive/dbe22910a43e9e8ec9948d3cbd73d8488a074967.tar.gz",
"integrity": "sha256-JM70bkRUyuAQXKV1piJ6I0IClzCrXV43sdLfIKVC9Lo=",
"overlay": {
"MODULE.bazel": "",
"BUILD.bazel": ""
}
}

View File

@ -0,0 +1,7 @@
module(
name = "libxev",
version = "20240910.0-a2d9b31",
compatibility_level = 1,
)
bazel_dep(name = "rules_zig", version = "20240904.0-010da15")

View File

@ -0,0 +1,13 @@
load("@rules_zig//zig:defs.bzl", "zig_library")
zig_library(
name = "xev",
srcs = glob([
"src/*.zig",
"src/backend/*.zig",
"src/linux/*.zig",
"src/watcher/*.zig",
]),
main = "src/main.zig",
visibility = ["//visibility:public"],
)

View File

@ -0,0 +1,7 @@
module(
name = "libxev",
version = "20240910.0-a2d9b31",
compatibility_level = 1,
)
bazel_dep(name = "rules_zig", version = "20240904.0-010da15")

View File

@ -0,0 +1,9 @@
{
"strip_prefix": "libxev-a2d9b3154f1e5ed463c9f2fb98a0d739057796ce",
"url": "https://github.com/zml/libxev/archive/a2d9b3154f1e5ed463c9f2fb98a0d739057796ce.tar.gz",
"integrity": "sha256-xdMjrGSB3telt52iRqHgem0WSRPjfGlTKBRyDMXLlaQ=",
"overlay": {
"MODULE.bazel": "",
"BUILD.bazel": ""
}
}

View File

@ -0,0 +1,18 @@
{
"homepage": "https://github.com/mitchellh/libxev",
"maintainers": [
{
"email": "bzlmod@zml.ai",
"github": "zml",
"name": "ZML Engineering Team"
}
],
"repository": [
"github:mitchellh/libxev"
],
"versions": [
"20240825.0-dbe2291",
"20240910.0-a2d9b31"
],
"yanked_versions": {}
}

View File

@ -0,0 +1,10 @@
module(
name = "llvm-raw",
version = "20240823.0-f142f8a",
compatibility_level = 1,
)
bazel_dep(name = "bazel_skylib", version = "1.7.1")
bazel_dep(name = "platforms", version = "0.0.10")
bazel_dep(name = "zstd", version = "1.5.6", repo_name = "llvm_zstd")
bazel_dep(name = "zlib", version = "1.3.1.bcr.3", repo_name = "llvm_zlib")

View File

@ -0,0 +1,10 @@
module(
name = "llvm-raw",
version = "20240823.0-f142f8a",
compatibility_level = 1,
)
bazel_dep(name = "bazel_skylib", version = "1.7.1")
bazel_dep(name = "platforms", version = "0.0.10")
bazel_dep(name = "zstd", version = "1.5.6", repo_name = "llvm_zstd")
bazel_dep(name = "zlib", version = "1.3.1.bcr.3", repo_name = "llvm_zlib")

View File

@ -0,0 +1,28 @@
load("//utils/bazel:configure.bzl", _llvm_configure = "llvm_configure")

# Module extension wrapping llvm_configure: gathers the union of all
# `llvm.configure(targets = [...])` tags declared by every module in the
# dependency graph and configures a single @llvm-project repository with
# that combined backend set.
def _llvm_impl(mctx):
    # Dict used as an insertion-ordered set to de-duplicate target names.
    _targets = {}
    for mod in mctx.modules:
        for conf in mod.tags.configure:
            for target in conf.targets:
                _targets[target] = True
    _llvm_configure(
        name = "llvm-project",
        targets = _targets.keys(),
    )
    # Inputs are fully determined by the tags, so mark the extension
    # reproducible; "all" declares every generated repo as a direct dep of
    # the root module.
    return mctx.extension_metadata(
        reproducible = True,
        root_module_direct_deps = "all",
        root_module_direct_dev_deps = [],
    )

llvm = module_extension(
    implementation = _llvm_impl,
    tag_classes = {
        "configure": tag_class(
            attrs = {
                # LLVM backend names, e.g. "AArch64", "X86", "NVPTX".
                "targets": attr.string_list(mandatory = True),
            },
        ),
    },
)

View File

@ -0,0 +1,10 @@
{
"strip_prefix": "llvm-project-f142f8afe21bceb00fb495468aa0b5043e98c419",
"url": "https://github.com/llvm/llvm-project/archive/f142f8afe21bceb00fb495468aa0b5043e98c419.tar.gz",
"integrity": "sha256-8ftYOnMGq99igt2I1n9SOZpYLq79+F7HiKL/7lOwcHs=",
"overlay": {
"BUILD.bazel": "",
"MODULE.bazel": "",
"utils/bazel/extension.bzl": ""
}
}

View File

@ -0,0 +1,17 @@
{
"homepage": "https://github.com/llvm/llvm-project",
"maintainers": [
{
"email": "bzlmod@zml.ai",
"github": "zml",
"name": "ZML Engineering Team"
}
],
"repository": [
"github:llvm/llvm-project"
],
"versions": [
"20240823.0-f142f8a"
],
"yanked_versions": {}
}

View File

@ -0,0 +1,68 @@
module(
name = "rules_zig",
version = "20240904.0-010da15",
compatibility_level = 1,
)
bazel_dep(name = "aspect_bazel_lib", version = "2.8.1")
bazel_dep(name = "bazel_skylib", version = "1.7.1")
bazel_dep(name = "platforms", version = "0.0.10")
zig = use_extension("//zig:extensions.bzl", "zig")
zig.index(file = "//zig/private:versions.json")
use_repo(zig, "zig_toolchains")
register_toolchains("@rules_zig//zig/target:all")
register_toolchains("@zig_toolchains//:all")
zig_dev = use_extension(
"//zig:extensions.bzl",
"zig",
dev_dependency = True,
)
zig_dev.toolchain(zig_version = "0.13.0")
zig_dev.toolchain(zig_version = "0.12.1")
zig_dev.toolchain(zig_version = "0.12.0")
zig_dev.toolchain(zig_version = "0.11.0")
bazel_dep(name = "rules_cc", version = "0.0.9", dev_dependency = True)
bazel_dep(name = "stardoc", version = "0.7.0", dev_dependency = True, repo_name = "io_bazel_stardoc")
bazel_dep(name = "gazelle", version = "0.38.0", dev_dependency = True, repo_name = "bazel_gazelle")
bazel_dep(name = "bazel_skylib_gazelle_plugin", version = "1.7.1", dev_dependency = True)
bazel_dep(
name = "buildifier_prebuilt",
version = "7.3.1",
dev_dependency = True,
)
bazel_dep(name = "rules_multirun", version = "0.9.0", dev_dependency = True)
bazel_dep(name = "rules_python", version = "0.35.0", dev_dependency = True)
bazel_dep(
name = "rules_bazel_integration_test",
version = "0.25.0",
dev_dependency = True,
)
bazel_binaries = use_extension(
"@rules_bazel_integration_test//:extensions.bzl",
"bazel_binaries",
dev_dependency = True,
)
# NOTE: Keep in sync with WORKSPACE.
bazel_binaries.download(version_file = "//:.bazelversion")
bazel_binaries.download(version = "7.0.0")
use_repo(
bazel_binaries,
"bazel_binaries",
"bazel_binaries_bazelisk",
"build_bazel_bazel_.bazelversion",
"build_bazel_bazel_7_0_0",
)
# TODO[AH] Should be an implicit transitive dependency through rules_bazel_integration_test.
# However, if we do not include it explicitly, then the runfiles resolution for
# cgrindel_bazel_starlib/shlib/lib/message.sh fails in
# rules_bazel_integration_test/tools/update_deleted_packages.sh when invoked
# through the rules_multirun target //util:update.
bazel_dep(name = "cgrindel_bazel_starlib", version = "0.21.0", dev_dependency = True)

View File

@ -0,0 +1,5 @@
{
"type": "git_repository",
"remote": "https://github.com/zml/rules_zig.git",
"commit": "010da15abb4335479778d6b4fb2ca752a0ab80e3"
}

View File

@ -0,0 +1,68 @@
module(
name = "rules_zig",
version = "20240909.0-37f17ff",
compatibility_level = 1,
)
bazel_dep(name = "aspect_bazel_lib", version = "2.8.1")
bazel_dep(name = "bazel_skylib", version = "1.7.1")
bazel_dep(name = "platforms", version = "0.0.10")
zig = use_extension("//zig:extensions.bzl", "zig")
zig.index(file = "//zig/private:versions.json")
use_repo(zig, "zig_toolchains")
register_toolchains("@rules_zig//zig/target:all")
register_toolchains("@zig_toolchains//:all")
zig_dev = use_extension(
"//zig:extensions.bzl",
"zig",
dev_dependency = True,
)
zig_dev.toolchain(zig_version = "0.13.0")
zig_dev.toolchain(zig_version = "0.12.1")
zig_dev.toolchain(zig_version = "0.12.0")
zig_dev.toolchain(zig_version = "0.11.0")
bazel_dep(name = "rules_cc", version = "0.0.9", dev_dependency = True)
bazel_dep(name = "stardoc", version = "0.7.0", dev_dependency = True, repo_name = "io_bazel_stardoc")
bazel_dep(name = "gazelle", version = "0.38.0", dev_dependency = True, repo_name = "bazel_gazelle")
bazel_dep(name = "bazel_skylib_gazelle_plugin", version = "1.7.1", dev_dependency = True)
bazel_dep(
name = "buildifier_prebuilt",
version = "7.3.1",
dev_dependency = True,
)
bazel_dep(name = "rules_multirun", version = "0.9.0", dev_dependency = True)
bazel_dep(name = "rules_python", version = "0.35.0", dev_dependency = True)
bazel_dep(
name = "rules_bazel_integration_test",
version = "0.25.0",
dev_dependency = True,
)
bazel_binaries = use_extension(
"@rules_bazel_integration_test//:extensions.bzl",
"bazel_binaries",
dev_dependency = True,
)
# NOTE: Keep in sync with WORKSPACE.
bazel_binaries.download(version_file = "//:.bazelversion")
bazel_binaries.download(version = "7.0.0")
use_repo(
bazel_binaries,
"bazel_binaries",
"bazel_binaries_bazelisk",
"build_bazel_bazel_.bazelversion",
"build_bazel_bazel_7_0_0",
)
# TODO[AH] Should be an implicit transitive dependency through rules_bazel_integration_test.
# However, if we do not include it explicitly, then the runfiles resolution for
# cgrindel_bazel_starlib/shlib/lib/message.sh fails in
# rules_bazel_integration_test/tools/update_deleted_packages.sh when invoked
# through the rules_multirun target //util:update.
bazel_dep(name = "cgrindel_bazel_starlib", version = "0.21.0", dev_dependency = True)

View File

@ -0,0 +1,5 @@
{
"type": "git_repository",
"remote": "https://github.com/zml/rules_zig.git",
"commit": "37f17ffd2f6d04fc4bf9b583f9a1b117856a88d1"
}

View File

@ -0,0 +1,68 @@
module(
name = "rules_zig",
version = "20240912.0-41bfe84",
compatibility_level = 1,
)
bazel_dep(name = "aspect_bazel_lib", version = "2.8.1")
bazel_dep(name = "bazel_skylib", version = "1.7.1")
bazel_dep(name = "platforms", version = "0.0.10")
zig = use_extension("//zig:extensions.bzl", "zig")
zig.index(file = "//zig/private:versions.json")
use_repo(zig, "zig_toolchains")
register_toolchains("@rules_zig//zig/target:all")
register_toolchains("@zig_toolchains//:all")
zig_dev = use_extension(
"//zig:extensions.bzl",
"zig",
dev_dependency = True,
)
zig_dev.toolchain(zig_version = "0.13.0")
zig_dev.toolchain(zig_version = "0.12.1")
zig_dev.toolchain(zig_version = "0.12.0")
zig_dev.toolchain(zig_version = "0.11.0")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "stardoc", version = "0.7.0", dev_dependency = True, repo_name = "io_bazel_stardoc")
bazel_dep(name = "gazelle", version = "0.38.0", dev_dependency = True, repo_name = "bazel_gazelle")
bazel_dep(name = "bazel_skylib_gazelle_plugin", version = "1.7.1", dev_dependency = True)
bazel_dep(
name = "buildifier_prebuilt",
version = "7.3.1",
dev_dependency = True,
)
bazel_dep(name = "rules_multirun", version = "0.9.0", dev_dependency = True)
bazel_dep(name = "rules_python", version = "0.35.0", dev_dependency = True)
bazel_dep(
name = "rules_bazel_integration_test",
version = "0.25.0",
dev_dependency = True,
)
bazel_binaries = use_extension(
"@rules_bazel_integration_test//:extensions.bzl",
"bazel_binaries",
dev_dependency = True,
)
# NOTE: Keep in sync with WORKSPACE.
bazel_binaries.download(version_file = "//:.bazelversion")
bazel_binaries.download(version = "7.0.0")
use_repo(
bazel_binaries,
"bazel_binaries",
"bazel_binaries_bazelisk",
"build_bazel_bazel_.bazelversion",
"build_bazel_bazel_7_0_0",
)
# TODO[AH] Should be an implicit transitive dependency through rules_bazel_integration_test.
# However, if we do not include it explicitly, then the runfiles resolution for
# cgrindel_bazel_starlib/shlib/lib/message.sh fails in
# rules_bazel_integration_test/tools/update_deleted_packages.sh when invoked
# through the rules_multirun target //util:update.
bazel_dep(name = "cgrindel_bazel_starlib", version = "0.21.0", dev_dependency = True)

View File

@ -0,0 +1,5 @@
{
"type": "git_repository",
"remote": "https://github.com/zml/rules_zig.git",
"commit": "41bfe84e4d9a43cbe55281dddc80b683cc6fc6eb"
}

View File

@ -0,0 +1,68 @@
module(
name = "rules_zig",
version = "20240913.0-1957d05",
compatibility_level = 1,
)
bazel_dep(name = "aspect_bazel_lib", version = "2.8.1")
bazel_dep(name = "bazel_skylib", version = "1.7.1")
bazel_dep(name = "platforms", version = "0.0.10")
zig = use_extension("//zig:extensions.bzl", "zig")
zig.index(file = "//zig/private:versions.json")
use_repo(zig, "zig_toolchains")
register_toolchains("@rules_zig//zig/target:all")
register_toolchains("@zig_toolchains//:all")
zig_dev = use_extension(
"//zig:extensions.bzl",
"zig",
dev_dependency = True,
)
zig_dev.toolchain(zig_version = "0.13.0")
zig_dev.toolchain(zig_version = "0.12.1")
zig_dev.toolchain(zig_version = "0.12.0")
zig_dev.toolchain(zig_version = "0.11.0")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "stardoc", version = "0.7.0", dev_dependency = True, repo_name = "io_bazel_stardoc")
bazel_dep(name = "gazelle", version = "0.38.0", dev_dependency = True, repo_name = "bazel_gazelle")
bazel_dep(name = "bazel_skylib_gazelle_plugin", version = "1.7.1", dev_dependency = True)
bazel_dep(
name = "buildifier_prebuilt",
version = "7.3.1",
dev_dependency = True,
)
bazel_dep(name = "rules_multirun", version = "0.9.0", dev_dependency = True)
bazel_dep(name = "rules_python", version = "0.35.0", dev_dependency = True)
bazel_dep(
name = "rules_bazel_integration_test",
version = "0.25.0",
dev_dependency = True,
)
bazel_binaries = use_extension(
"@rules_bazel_integration_test//:extensions.bzl",
"bazel_binaries",
dev_dependency = True,
)
# NOTE: Keep in sync with WORKSPACE.
bazel_binaries.download(version_file = "//:.bazelversion")
bazel_binaries.download(version = "7.0.0")
use_repo(
bazel_binaries,
"bazel_binaries",
"bazel_binaries_bazelisk",
"build_bazel_bazel_.bazelversion",
"build_bazel_bazel_7_0_0",
)
# TODO[AH] Should be an implicit transitive dependency through rules_bazel_integration_test.
# However, if we do not include it explicitly, then the runfiles resolution for
# cgrindel_bazel_starlib/shlib/lib/message.sh fails in
# rules_bazel_integration_test/tools/update_deleted_packages.sh when invoked
# through the rules_multirun target //util:update.
bazel_dep(name = "cgrindel_bazel_starlib", version = "0.21.0", dev_dependency = True)

View File

@ -0,0 +1,5 @@
{
"type": "git_repository",
"remote": "https://github.com/zml/rules_zig.git",
"commit": "1957d0572193fb859e721a0fab8bd8f0fb57f3ff"
}

View File

@ -0,0 +1,20 @@
{
"homepage": "https://github.com/zml/rules_zig",
"maintainers": [
{
"email": "bzlmod@zml.ai",
"github": "zml",
"name": "ZML Engineering Team"
}
],
"repository": [
"github:zml/rules_zig"
],
"versions": [
"20240904.0-010da15",
"20240909.0-37f17ff",
"20240912.0-41bfe84",
"20240913.0-1957d05"
],
"yanked_versions": {}
}

View File

@ -0,0 +1,7 @@
module(
name = "sentencepiece",
version = "20240618.0-d7ace0a",
compatibility_level = 1,
)
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")

View File

@ -0,0 +1,7 @@
load("@rules_proto//proto:defs.bzl", "proto_library")
proto_library(
name = "sentencepiece_model_proto",
srcs = ["src/sentencepiece_model.proto"],
visibility = ["//visibility:public"],
)

View File

@ -0,0 +1,7 @@
module(
name = "sentencepiece",
version = "20240618.0-d7ace0a",
compatibility_level = 1,
)
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")

View File

@ -0,0 +1,9 @@
{
"strip_prefix": "sentencepiece-d7ace0a4df17af17d73dc8fe6ff92cfd8d08e2b8",
"url": "https://github.com/google/sentencepiece/archive/d7ace0a4df17af17d73dc8fe6ff92cfd8d08e2b8.tar.gz",
"integrity": "sha256-bt8QFNNIpqcJoCNMUrmWfUf0uNobtk8pW+XbHL+ZRLg=",
"overlay": {
"MODULE.bazel": "",
"BUILD.bazel": ""
}
}

View File

@ -0,0 +1,17 @@
{
"homepage": "https://github.com/google/sentencepiece",
"maintainers": [
{
"email": "bzlmod@zml.ai",
"github": "zml",
"name": "ZML Engineering Team"
}
],
"repository": [
"google/sentencepiece"
],
"versions": [
"20240618.0-d7ace0a"
],
"yanked_versions": {}
}

View File

@ -0,0 +1,15 @@
module(
name = "stablehlo",
version = "20240829.0-54aa1a5",
compatibility_level = 1,
)
bazel_dep(name = "bazel_skylib", version = "1.7.1")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "llvm-raw", version = "20240823.0-f142f8a")
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
llvm.configure(
targets = ["AArch64", "X86", "NVPTX"],
)
use_repo(llvm, "llvm-project")

View File

@ -0,0 +1,15 @@
module(
name = "stablehlo",
version = "20240829.0-54aa1a5",
compatibility_level = 1,
)
bazel_dep(name = "bazel_skylib", version = "1.7.1")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "llvm-raw", version = "20240823.0-f142f8a")
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
llvm.configure(
targets = ["AArch64", "X86", "NVPTX"],
)
use_repo(llvm, "llvm-project")

View File

@ -0,0 +1,83 @@
From f11d9bd7639f63cba681caa39cea27af114a4a71 Mon Sep 17 00:00:00 2001
From: Steeve Morin <steeve.morin@gmail.com>
Date: Mon, 2 Sep 2024 23:02:28 +0200
Subject: [PATCH 1/2] Remove duplicated symbols in StablehloApi.h
This causes C compilers to choke.
Refs #2494
---
stablehlo/integrations/c/StablehloApi.cpp | 2 +-
stablehlo/integrations/c/StablehloApi.h | 15 +--------------
stablehlo/integrations/python/StablehloApi.cpp | 2 +-
3 files changed, 3 insertions(+), 16 deletions(-)
diff --git a/stablehlo/integrations/c/StablehloApi.cpp b/stablehlo/integrations/c/StablehloApi.cpp
index 8d922198..9b41ff21 100644
--- a/stablehlo/integrations/c/StablehloApi.cpp
+++ b/stablehlo/integrations/c/StablehloApi.cpp
@@ -98,7 +98,7 @@ MlirLogicalResult stablehloSerializePortableArtifact(
return mlirLogicalResultSuccess();
}
-MlirLogicalResult stablehloDeserializePortableArtifact(
+MlirLogicalResult stablehloDeserializePortableArtifactToBytecode(
MlirStringRef artifactStr, MlirStringCallback callback, void *userData) {
mlir::detail::CallbackOstream stream(callback, userData);
if (failed(mlir::stablehlo::deserializePortableArtifact(unwrap(artifactStr),
diff --git a/stablehlo/integrations/c/StablehloApi.h b/stablehlo/integrations/c/StablehloApi.h
index 4c542508..f94dca8d 100644
--- a/stablehlo/integrations/c/StablehloApi.h
+++ b/stablehlo/integrations/c/StablehloApi.h
@@ -76,16 +76,6 @@ MLIR_CAPI_EXPORTED MlirLogicalResult stablehloSerializePortableArtifact(
MlirStringRef moduleStr, MlirStringRef targetVersion,
MlirStringCallback callback, void* userData);
-// Write a StableHLO program expressed as a string (either prettyprinted MLIR
-// module or MLIR bytecode) to a portable artifact.
-// Can fail if `moduleStr` cannot be parsed, or if it cannot be expressed in the
-// `targetVersion` version of StableHLO, e.g. if it's using new or removed
-// features, or if it involves unsupported dialects.
-// Returns false on failure.
-MLIR_CAPI_EXPORTED MlirLogicalResult stablehloSerializePortableArtifact(
- MlirModule moduleStr, MlirStringRef targetVersion,
- MlirStringCallback callback, void* userData);
-
// Read a StableHLO program from a portable artifact, returning the module as
// MLIR bytecode. Note, this bytecode returned is not a portable artifact,
// and has the stability of returning textual assembly format. Bytecode is
@@ -93,7 +83,7 @@ MLIR_CAPI_EXPORTED MlirLogicalResult stablehloSerializePortableArtifact(
// Can fail if `artifactStr` cannot be expressed in the current version of
// StableHLO, e.g. if it's using incompatible features.
// Returns false on failure.
-MLIR_CAPI_EXPORTED MlirLogicalResult stablehloDeserializePortableArtifact(
+MLIR_CAPI_EXPORTED MlirLogicalResult stablehloDeserializePortableArtifactAsBytecode(
MlirStringRef artifactStr, MlirStringCallback callback, void* userData);
// Read a StableHLO program from a portable artifact, returning the module as
@@ -109,9 +99,6 @@ MLIR_CAPI_EXPORTED MlirModule stablehloDeserializePortableArtifact(
// Call the Interpreter, returns MlirArrayAttr of dense element
// MlirAttribute results
-MLIR_CAPI_EXPORTED MlirModule stablehloDeserializePortableArtifact(
- MlirStringRef artifactStr, MlirContext ctx);
-
// Entrypoint for calling the StableHLO reference interpreter.
// Returns an array attribute of dense element attributes for results.
// Sets error code to non-zero on failure.
diff --git a/stablehlo/integrations/python/StablehloApi.cpp b/stablehlo/integrations/python/StablehloApi.cpp
index 46a640e1..4229ef76 100644
--- a/stablehlo/integrations/python/StablehloApi.cpp
+++ b/stablehlo/integrations/python/StablehloApi.cpp
@@ -213,7 +213,7 @@ void AddPortableApi(py::module &m) {
"deserialize_portable_artifact_str",
[](std::string_view artifact) -> py::bytes {
StringWriterHelper accumulator;
- if (mlirLogicalResultIsFailure(stablehloDeserializePortableArtifact(
+ if (mlirLogicalResultIsFailure(stablehloDeserializePortableArtifactToBytecode(
toMlirStringRef(artifact), accumulator.getMlirStringCallback(),
accumulator.getUserData()))) {
PyErr_SetString(PyExc_ValueError, "failed to deserialize module");
--
2.39.3 (Apple Git-146)

View File

@ -0,0 +1,401 @@
diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel
--- stablehlo/BUILD.bazel
+++ stablehlo/BUILD.bazel
@@ -340,6 +340,21 @@
],
)
+gentbl_cc_library(
+ name = "stablehlo_create_compatibility_expander_inc_gen",
+ tbl_outs = [
+ (
+ ["--gen-rewriters"],
+ "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td",
+ deps = [
+ ":stablehlo_ops_td_files",
+ ],
+)
+
cc_library(
name = "interpreter_ops",
srcs = [
@@ -1086,6 +1101,7 @@
"stablehlo/transforms/StablehloAggressiveSimplification.cpp",
"stablehlo/transforms/StablehloCanonicalizeDynamism.cpp",
"stablehlo/transforms/StablehloConvertToSignless.cpp",
+ "stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp",
"stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp",
"stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp",
"stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp",
@@ -1109,6 +1125,7 @@
":chlo_ops",
":chlo_rewriters_inc_gen",
":linalg_passes",
+ ":stablehlo_create_compatibility_expander_inc_gen",
":stablehlo_legalize_deprecated_ops_inc_gen",
":stablehlo_ops",
":stablehlo_ops_inc_gen",
diff --ruN a/stablehlo/stablehlo/dialect/Version.cpp b/stablehlo/stablehlo/dialect/Version.cpp
--- stablehlo/stablehlo/dialect/Version.cpp
+++ stablehlo/stablehlo/dialect/Version.cpp
@@ -82,7 +82,7 @@
case CompatibilityRequirement::WEEK_4:
return Version(1, 3, 0); // v1.3.0 - Jul 15, 2024
case CompatibilityRequirement::WEEK_12:
- return Version(1, 0, 0); // v1.0.0 - May 14, 2024
+ return Version(1, 1, 0); // v1.1.0 - May 30, 2024
case CompatibilityRequirement::MAX:
return Version::getMinimumVersion();
}
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
--- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
+++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
@@ -0,0 +1,43 @@
+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect --stablehlo-create-compatibility-expander='target=1.0.0' | FileCheck %s --check-prefixes=CHECK
+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.6.0' | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE
+
+// -----
+
+// CHECK-LABEL @tan_op_non_complex
+// CHECK: %[[sine0:.*]] = stablehlo.sine %arg0 : tensor<4xf64>
+// CHECK-NEXT: %[[cosine1:.*]] = stablehlo.cosine %arg0 : tensor<4xf64>
+// CHECK-NEXT: %[[div2:.*]] = stablehlo.divide %[[sine0]], %[[cosine1]] : tensor<4xf64>
+// CHECK-NEXT: return %[[div2]] : tensor<4xf64>
+func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
+ // CHECK-NO-DOWNGRADE: stablehlo.tan %arg0 : tensor<4xf64>
+ %1 = stablehlo.tan %arg0 : tensor<4xf64>
+ func.return %1 : tensor<4xf64>
+}
+
+// -----
+
+// CHECK-LABEL: @tan_op_complex
+// CHECK: %[[cst:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<4xf64>
+// CHECK: %[[complex:.*]] = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex<f64>>
+// CHECK: %[[real:.*]] = stablehlo.real %[[complex]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+// CHECK: %[[sine:.*]] = stablehlo.sine %[[real]] : tensor<4xf64>
+// CHECK: %[[cosine:.*]] = stablehlo.cosine %[[real]] : tensor<4xf64>
+// CHECK: %[[divide1:.*]] = stablehlo.divide %[[sine]], %[[cosine]] : tensor<4xf64>
+// CHECK: %[[imag:.*]] = stablehlo.imag %[[complex]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+// CHECK: %[[tanh:.*]] = stablehlo.tanh %[[imag]] : tensor<4xf64>
+// CHECK: %[[complex2:.*]] = stablehlo.complex %[[divide1]], %[[tanh]] : tensor<4xcomplex<f64>>
+// CHECK: %[[multiply:.*]] = stablehlo.multiply %[[divide1]], %[[tanh]] : tensor<4xf64>
+// CHECK: %[[negate:.*]] = stablehlo.negate %[[multiply]] : tensor<4xf64>
+// CHECK: %[[complex3:.*]] = stablehlo.complex %[[cst]], %[[negate]] : tensor<4xcomplex<f64>>
+// CHECK: %[[divide2:.*]] = stablehlo.divide %[[complex2]], %[[complex3]] : tensor<4xcomplex<f64>>
+// CHECK: %[[real2:.*]] = stablehlo.real %[[divide2]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+// CHECK: %[[imag2:.*]] = stablehlo.imag %[[divide2]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+// CHECK: return %[[real2]], %[[imag2]] : tensor<4xf64>, tensor<4xf64>
+func.func @tan_op_complex(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>) -> (tensor<4xf64>, tensor<4xf64>) {
+ %0 = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex<f64>>
+ // CHECK-NO-DOWNGRADE: stablehlo.tan %0 : tensor<4xcomplex<f64>>
+ %1 = stablehlo.tan %0 : tensor<4xcomplex<f64>>
+ %2 = stablehlo.real %1 : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+ %3 = stablehlo.imag %1 : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
+ func.return %2, %3 : tensor<4xf64>, tensor<4xf64>
+}
diff --ruN a/stablehlo/stablehlo/transforms/CMakeLists.txt b/stablehlo/stablehlo/transforms/CMakeLists.txt
--- stablehlo/stablehlo/transforms/CMakeLists.txt
+++ stablehlo/stablehlo/transforms/CMakeLists.txt
@@ -20,6 +20,10 @@
mlir_tablegen(ChloDecompositionPatterns.h.inc --gen-rewriters)
add_public_tablegen_target(ChloDecompositionPatternsIncGen)
+set(LLVM_TARGET_DEFINITIONS StablehloCreateCompatibilityExpanderPatterns.td)
+mlir_tablegen(StablehloCreateCompatibilityExpanderPatterns.h.inc --gen-rewriters)
+add_public_tablegen_target(StablehloCreateCompatibilityExpanderPatternsIncGen)
+
set(LLVM_TARGET_DEFINITIONS StablehloLegalizeDeprecatedOpsPatterns.td)
mlir_tablegen(StablehloLegalizeDeprecatedOpsPatterns.h.inc --gen-rewriters)
add_public_tablegen_target(StablehloLegalizeDeprecatedOpsPatternsIncGen)
@@ -27,6 +31,7 @@
set(LLVM_TARGET_DEFINITIONS VhloToVersionPatterns.td)
mlir_tablegen(VhloToVersionPatterns.h.inc --gen-rewriters)
add_public_tablegen_target(VhloToVersionPatterns)
+
add_mlir_dialect_library(StablehloPasses
PARTIAL_SOURCES_INTENDED
@@ -37,6 +42,7 @@
StablehloAggressiveSimplification.cpp
StablehloCanonicalizeDynamism.cpp
StablehloConvertToSignless.cpp
+ StablehloCreateCompatibilityExpander.cpp
StablehloLegalizeCompositeToCall.cpp
StablehloLegalizeDeprecatedOps.cpp
StablehloLegalizeQuantToMath.cpp
@@ -53,6 +59,7 @@
StablehloLegalizeDeprecatedOpsPatternsIncGen
PassesIncGen
VhloToVersionPatterns
+ StablehloCreateCompatibilityExpanderPatternsIncGen
LINK_LIBS PUBLIC
ChloOps
diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h
--- stablehlo/stablehlo/transforms/Passes.h
+++ stablehlo/stablehlo/transforms/Passes.h
@@ -25,6 +25,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "stablehlo/dialect/Version.h"
namespace mlir {
namespace stablehlo {
@@ -96,6 +97,12 @@
void populateShapeToStablehloPatterns(MLIRContext *context,
RewritePatternSet *patterns);
+/// Collection of patterns to create compatibility expander for StableHLO
+/// operations.
+void populateStablehloCreateCompatibilityExpanderPatterns(
+ RewritePatternSet *patterns, MLIRContext *context,
+ vhlo::Version targetVersion);
+
//// Additional pass constructors ////
std::unique_ptr<OperationPass<ModuleOp>> createStablehloRefineArgumentsPass(
diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td
--- stablehlo/stablehlo/transforms/Passes.td
+++ stablehlo/stablehlo/transforms/Passes.td
@@ -292,3 +292,51 @@
"mlir::stablehlo::StablehloDialect",
];
}
+
+def StablehloCreateCompatibilityExpanderPass : Pass<"stablehlo-create-compatibility-expander", "mlir::func::FuncOp"> {
+ let summary = "Create compatibility expander for StableHLO operations.";
+
+ let description = [{
+    StableHLO ops get updated or new ops are introduced in the latest versions.
+ This opt-in pass expands backward compatibility with older StableHLO
+ versions by decomposing newer StableHLO operations into equivalent
+ operations supported by those older versions.
+
+ Why is this an opt-in pass?
+
+ Occasionally, StableHLO op enhancements are used to greatly simplify the
+ handling of certain common patterns in the OpenXLA ecosystem. This
+ includes things like TanOp, which has high framework and compiler support,
+ as well as gather/scatter batching dimensions, which can be represented
+ using slices, but makes sharding much more difficult. For this category of
+ new features, we do not offer automatic downgrade, since it may throw away
+ important information used in subsequent optimizations. This pass can be
+ used to expand these ops based on a target version to maximize compatibility
+ at the expense of potentially less optimal compilation.
+
+ ```mlir
+ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
+ %1 = stablehlo.tan %arg0 : tensor<4xf64>
+ func.return %1 : tensor<4xf64>
+ }
+ ```
+
+ will become:
+
+ ```mlir
+ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
+ %0 = stablehlo.sine %arg0 : tensor<4xf64>
+ %1 = stablehlo.cosine %arg0 : tensor<4xf64>
+ %2 = stablehlo.divide %0, %1 : tensor<4xf64>
+ return %2 : tensor<4xf64>
+ }
+ ```
+ }];
+ let options = [
+ Option<"targetVersionOption", "target", "std::string", "",
+ "The target version. Must be a version of the form #.#.#.">,
+ ];
+ let dependentDialects = [
+ "mlir::stablehlo::StablehloDialect",
+ ];
+}
diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
@@ -0,0 +1,128 @@
+/* Copyright 2024 The StableHLO Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <fcntl.h>
+
+#include <cassert>
+
+#include "llvm/ADT/APFloat.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "stablehlo/dialect/StablehloOps.h"
+#include "stablehlo/dialect/Version.h"
+#include "stablehlo/transforms/Passes.h"
+
+namespace mlir {
+namespace stablehlo {
+#define GEN_PASS_DEF_STABLEHLOCREATECOMPATIBILITYEXPANDERPASS
+#include "stablehlo/transforms/Passes.h.inc"
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helpers.
+//===----------------------------------------------------------------------===//
+
+// Creates a constant with all ones.
+static Value createConstantWithAllOnes(OpBuilder &b, Location loc, Value val) {
+ auto shapedTy = dyn_cast<mlir::ShapedType>(val.getType());
+ if (!shapedTy) llvm_unreachable("Unsupported shaped type.");
+
+ mlir::DenseElementsAttr elementsAttr =
+ mlir::DenseElementsAttr::get(shapedTy, 1.0);
+
+ return b.create<mlir::stablehlo::ConstantOp>(loc, val.getType(),
+ elementsAttr);
+}
+
+// Check user-specified target version.
+vhlo::Version validateTargetVersion(llvm::StringRef versionRef) {
+ auto failOrVersion = vhlo::Version::fromString(versionRef);
+ if (failed(failOrVersion)) {
+ assert(!versionRef.empty() &&
+ "No target version specified. Target version must be of the form "
+ "`#.#.#`.");
+ assert(versionRef.empty() &&
+ "Invalid target version argument. Target version must be of the "
+ "form `#.#.#`.");
+ }
+ vhlo::Version targetVersion = *failOrVersion;
+ assert((vhlo::Version::getMinimumVersion() <= targetVersion) &&
+ "target version is less than minimum supported.");
+ assert((targetVersion <= vhlo::Version::getCurrentVersion()) &&
+ "target version is greater than current version.");
+ return targetVersion;
+}
+
+//===----------------------------------------------------------------------===//
+// Pass
+//===----------------------------------------------------------------------===//
+
+struct StablehloCreateCompatibilityExpanderPass
+ : public impl::StablehloCreateCompatibilityExpanderPassBase<
+ StablehloCreateCompatibilityExpanderPass> {
+ StablehloCreateCompatibilityExpanderPass()
+ : StablehloCreateCompatibilityExpanderPassBase<
+ StablehloCreateCompatibilityExpanderPass>() {}
+ StablehloCreateCompatibilityExpanderPass(
+ const StablehloCreateCompatibilityExpanderPassOptions &opts)
+ : StablehloCreateCompatibilityExpanderPassBase<
+ StablehloCreateCompatibilityExpanderPass>(opts) {}
+
+ public:
+ LogicalResult initialize(MLIRContext *context) override {
+ auto targetVersion = validateTargetVersion(targetVersionOption);
+
+ config.useTopDownTraversal = true;
+ RewritePatternSet patterns_(context);
+ populateStablehloCreateCompatibilityExpanderPatterns(&patterns_, context,
+ targetVersion);
+ patterns = std::move(patterns_);
+ return success();
+ }
+
+ void runOnOperation() override {
+ auto func = getOperation();
+ if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
+ func.emitError(
+ "Failed to converge StableHLOCreateCompatibilityExpanderPass in ")
+ << config.maxIterations << " iterations";
+ signalPassFailure();
+ }
+ }
+
+ private:
+ FrozenRewritePatternSet patterns;
+ GreedyRewriteConfig config;
+};
+
+#include "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc"
+
+} // namespace
+
+void populateStablehloCreateCompatibilityExpanderPatterns(
+ RewritePatternSet *patterns, MLIRContext *context,
+ vhlo::Version targetVersion) {
+ // StableHLO TanOp is introduced in v1.4.0.
+ if (targetVersion < vhlo::Version(1, 4, 0)) {
+ patterns->add<TanOp_ComplexElementType_CompatiblityExpander>(context);
+ patterns->add<TanOp_CompatiblityExpander>(context);
+ }
+}
+
+} // namespace stablehlo
+} // namespace mlir
diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
@@ -0,0 +1,47 @@
+/* Copyright 2022 The StableHLO Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+include "mlir/IR/OpBase.td"
+include "stablehlo/dialect/StablehloOps.td"
+
+def ComplexElementType : Type<
+ CPred<"isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
+ "Complex element type">;
+
+def NonComplexElementType : Type<
+ CPred<"!isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
+ "Non-complex element type">;
+
+def createConstantWithAllOnes : NativeCodeCall<"createConstantWithAllOnes($_builder, $_loc, $0)">;
+
+// Express `tan` as
+// sine(x) / cosine(x)
+def TanOp_CompatiblityExpander : Pat<(StableHLO_TanOp NonComplexElementType:$input),
+ (StableHLO_DivOp
+ (StableHLO_SineOp $input),
+ (StableHLO_CosineOp $input)
+ )>;
+
+// Express `tan(a + bi)` as
+// (tan(a) + i tanh(b)) / (1 - i tan(a) * tanh(b))
+def TanOp_ComplexElementType_CompatiblityExpander : Pat<(StableHLO_TanOp ComplexElementType:$input),
+ (StableHLO_DivOp
+ (StableHLO_ComplexOp
+ (StableHLO_TanOp:$tan (StableHLO_RealOp $input)),
+ (StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input))),
+ (StableHLO_ComplexOp
+ (createConstantWithAllOnes $tan),
+ (StableHLO_NegOp (StableHLO_MulOp $tan, $tanh)))
+ )>;

View File

@ -0,0 +1,12 @@
{
"strip_prefix": "stablehlo-54aa1a57178251981da616b877dda1a88d840d11",
"url": "https://github.com/openxla/stablehlo/archive/54aa1a57178251981da616b877dda1a88d840d11.tar.gz",
"integrity": "sha256-PFXXs3vQ+WlMiCC8twbGEFbsBEF/67Z4qtmq8CtUSUU=",
"overlay": {
"MODULE.bazel": ""
},
"patch_strip": 1,
"patches": {
"0001-Remove-duplicated-symbols-in-StablehloApi.h.patch": ""
}
}

View File

@ -0,0 +1,17 @@
{
"homepage": "https://github.com/openxla/stablehlo",
"maintainers": [
{
"email": "bzlmod@zml.ai",
"github": "zml",
"name": "ZML Engineering Team"
}
],
"repository": [
"github:openxla/stablehlo"
],
"versions": [
"20240829.0-54aa1a5"
],
"yanked_versions": {}
}

View File

@ -0,0 +1,34 @@
module(
name = "xla",
version = "20240902.0-d18cd64",
compatibility_level = 1,
)
bazel_dep(name = "platforms", version = "0.0.8")
bazel_dep(name = "bazel_skylib", version = "1.5.0")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "rules_apple", version = "3.2.1", repo_name = "build_bazel_rules_apple")
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
bazel_dep(name = "rules_python", version = "0.29.0")
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
bazel_dep(name = "rules_java", version = "7.3.2")
bazel_dep(name = "rules_pkg", version = "0.9.1")
bazel_dep(name = "zlib", version = "1.2.13")
bazel_dep(name = "re2", version = "2024-02-01", repo_name = "com_googlesource_code_re2")
bazel_dep(name = "rules_license", version = "0.0.8")
bazel_dep(name = "stablehlo", version = "20240829.0-54aa1a5")
tsl = use_extension("//:tsl.bzl", "tsl")
use_repo(tsl, "tsl")
xla_workspace = use_extension("//:workspace.bzl", "xla_workspace")
use_repo(
xla_workspace,
"com_github_grpc_grpc",
"com_google_protobuf",
"local_config_cuda",
"local_config_remote_execution",
"local_config_rocm",
"local_config_tensorrt",
)

View File

@ -0,0 +1,34 @@
module(
name = "xla",
version = "20240902.0-d18cd64",
compatibility_level = 1,
)
bazel_dep(name = "platforms", version = "0.0.8")
bazel_dep(name = "bazel_skylib", version = "1.5.0")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "rules_apple", version = "3.2.1", repo_name = "build_bazel_rules_apple")
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
bazel_dep(name = "rules_python", version = "0.29.0")
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
bazel_dep(name = "rules_java", version = "7.3.2")
bazel_dep(name = "rules_pkg", version = "0.9.1")
bazel_dep(name = "zlib", version = "1.2.13")
bazel_dep(name = "re2", version = "2024-02-01", repo_name = "com_googlesource_code_re2")
bazel_dep(name = "rules_license", version = "0.0.8")
bazel_dep(name = "stablehlo", version = "20240829.0-54aa1a5")
tsl = use_extension("//:tsl.bzl", "tsl")
use_repo(tsl, "tsl")
xla_workspace = use_extension("//:workspace.bzl", "xla_workspace")
use_repo(
xla_workspace,
"com_github_grpc_grpc",
"com_google_protobuf",
"local_config_cuda",
"local_config_remote_execution",
"local_config_rocm",
"local_config_tensorrt",
)

View File

@ -0,0 +1,13 @@
load("//third_party:repo.bzl", "tf_vendored")
def _tsl_impl(mctx):
tf_vendored(name = "tsl", relpath = "third_party/tsl")
return mctx.extension_metadata(
reproducible = True,
root_module_direct_deps = "all",
root_module_direct_dev_deps = [],
)
tsl = module_extension(
implementation = _tsl_impl,
)

View File

@ -0,0 +1,52 @@
load("@tsl//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
load("@tsl//third_party/gpus:cuda_configure.bzl", "cuda_configure")
load("@tsl//third_party/gpus:rocm_configure.bzl", "rocm_configure")
load("@tsl//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
load("@tsl//tools/toolchains/remote:configure.bzl", "remote_execution_configure")
def _xla_workspace_impl(mctx):
cuda_configure(name = "local_config_cuda")
remote_execution_configure(name = "local_config_remote_execution")
rocm_configure(name = "local_config_rocm")
tensorrt_configure(name = "local_config_tensorrt")
tf_http_archive(
name = "com_github_grpc_grpc",
sha256 = "b956598d8cbe168b5ee717b5dafa56563eb5201a947856a6688bbeac9cac4e1f",
strip_prefix = "grpc-b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd",
system_build_file = "@tsl//third_party/systemlibs:grpc.BUILD",
patch_file = [
"@tsl//third_party/grpc:generate_cc_env_fix.patch",
"@tsl//third_party/grpc:register_go_toolchain.patch",
],
system_link_files = {
"@tsl//third_party/systemlibs:BUILD": "bazel/BUILD",
"@tsl//third_party/systemlibs:grpc.BUILD": "src/compiler/BUILD",
"@tsl//third_party/systemlibs:grpc.bazel.grpc_deps.bzl": "bazel/grpc_deps.bzl",
"@tsl//third_party/systemlibs:grpc.bazel.grpc_extra_deps.bzl": "bazel/grpc_extra_deps.bzl",
"@tsl//third_party/systemlibs:grpc.bazel.cc_grpc_library.bzl": "bazel/cc_grpc_library.bzl",
"@tsl//third_party/systemlibs:grpc.bazel.generate_cc.bzl": "bazel/generate_cc.bzl",
"@tsl//third_party/systemlibs:grpc.bazel.protobuf.bzl": "bazel/protobuf.bzl",
},
urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz"),
)
tf_http_archive(
name = "com_google_protobuf",
patch_file = ["@tsl//third_party/protobuf:protobuf.patch"],
sha256 = "f66073dee0bc159157b0bd7f502d7d1ee0bc76b3c1eac9836927511bdc4b3fc1",
strip_prefix = "protobuf-3.21.9",
system_build_file = "@tsl//third_party/systemlibs:protobuf.BUILD",
system_link_files = {
"@tsl//third_party/systemlibs:protobuf.bzl": "protobuf.bzl",
"@tsl//third_party/systemlibs:protobuf_deps.bzl": "protobuf_deps.bzl",
},
urls = tf_mirror_urls("https://github.com/protocolbuffers/protobuf/archive/v3.21.9.zip"),
)
return mctx.extension_metadata(
reproducible = True,
root_module_direct_deps = "all",
root_module_direct_dev_deps = [],
)
xla_workspace = module_extension(
implementation = _xla_workspace_impl,
)

View File

@ -0,0 +1,27 @@
From 4db5de34f70d991fedbe28915c8239b97ba7a064 Mon Sep 17 00:00:00 2001
From: Steeve Morin <steeve.morin@gmail.com>
Date: Mon, 18 Mar 2024 17:17:34 +0100
Subject: [PATCH 3/3] [PJRT C API] Ensure C compliance for Profiler Extension
---
xla/pjrt/c/pjrt_c_api_profiler_extension.h | 2 ++
1 file changed, 2 insertions(+)
diff --git a/xla/pjrt/c/pjrt_c_api_profiler_extension.h b/xla/pjrt/c/pjrt_c_api_profiler_extension.h
index c821916ad..89a596123 100644
--- a/xla/pjrt/c/pjrt_c_api_profiler_extension.h
+++ b/xla/pjrt/c/pjrt_c_api_profiler_extension.h
@@ -16,8 +16,10 @@ limitations under the License.
#ifndef XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_
#define XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_
+#ifdef __cplusplus
#include <cstddef>
#include <cstdint>
+#endif
#include "xla/backends/profiler/plugin/profiler_c_api.h"
#include "xla/pjrt/c/pjrt_c_api.h"
--
2.39.3 (Apple Git-146)

View File

@ -0,0 +1,14 @@
{
"strip_prefix": "xla-d18cd64b7cd61a2ade10089665ac104f639101b1",
"url": "https://github.com/openxla/xla/archive/d18cd64b7cd61a2ade10089665ac104f639101b1.tar.gz",
"integrity": "sha256-EtKhjU91STBceUmg0TUE6cPeRkeSz3LI2S1i3EFMj/E=",
"overlay": {
"tsl.bzl": "",
"workspace.bzl": "",
"MODULE.bazel": ""
},
"patch_strip": 1,
"patches": {
"0003-PJRT-C-API-Ensure-C-compliance-for-Profiler-Extensio.patch": ""
}
}

17
third_party/modules/xla/metadata.json vendored Normal file
View File

@ -0,0 +1,17 @@
{
"homepage": "https://github.com/openxla/xla",
"maintainers": [
{
"email": "bzlmod@zml.ai",
"github": "zml",
"name": "ZML Engineering Team"
}
],
"repository": [
"github:openxla/xla"
],
"versions": [
"20240902.0-d18cd64"
],
"yanked_versions": {}
}

View File

@ -0,0 +1,8 @@
module(
name = "zig-protobuf",
version = "20240722.0-c644d11",
compatibility_level = 1,
)
bazel_dep(name = "rules_zig", version = "20240904.0-010da15")
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")

View File

@ -0,0 +1,32 @@
load("@rules_proto//proto:defs.bzl", "proto_lang_toolchain")
load("@rules_zig//zig:defs.bzl", "BINARY_KIND", "zig_binary", "zig_library")
zig_library(
name = "protobuf",
import_name = "protobuf",
main = "src/protobuf.zig",
visibility = ["//visibility:public"],
)
zig_binary(
name = "generator",
srcs = [
"bootstrapped-generator/FullName.zig",
"bootstrapped-generator/google/protobuf/compiler/plugin.pb.zig",
"bootstrapped-generator/google/protobuf/descriptor.pb.zig",
],
kind = BINARY_KIND.exe,
main = "bootstrapped-generator/main.zig",
visibility = ["//visibility:public"],
deps = [":protobuf"],
)
proto_lang_toolchain(
name = "zig_toolchain",
command_line = "--zig_out=$(OUT)",
output_files = "multiple",
plugin = ":generator",
plugin_format_flag = "--plugin=protoc-gen-zig=%s",
runtime = ":protobuf",
visibility = ["//visibility:public"],
)

Some files were not shown because too many files have changed in this diff Show More