Add initial Bazel build configuration, async runtime implementation, and core MLIR dialect definitions for ZML.
This commit is contained in:
commit
266da6d4be
0
BUILD.bazel
Normal file
0
BUILD.bazel
Normal file
79
MODULE.bazel
Normal file
79
MODULE.bazel
Normal file
@ -0,0 +1,79 @@
|
||||
module(
|
||||
name = "zml",
|
||||
)
|
||||
|
||||
new_git_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository")
|
||||
|
||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||
bazel_dep(name = "hermetic_cc_toolchain", version = "3.1.0")
|
||||
bazel_dep(name = "patchelf", version = "0.18.0")
|
||||
bazel_dep(name = "platforms", version = "0.0.10")
|
||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||
bazel_dep(name = "rules_pkg", version = "1.0.1")
|
||||
bazel_dep(name = "rules_proto", version = "6.0.2")
|
||||
|
||||
bazel_dep(name = "buildifier_prebuilt", version = "6.4.0", dev_dependency = True)
|
||||
|
||||
bazel_dep(name = "aspect_bazel_lib", version = "2.8.1.1")
|
||||
bazel_lib_toolchains = use_extension("@aspect_bazel_lib//lib:extensions.bzl", "toolchains", dev_dependency = True)
|
||||
use_repo(bazel_lib_toolchains, "jq_toolchains")
|
||||
|
||||
toolchains = use_extension("@hermetic_cc_toolchain//toolchain:ext.bzl", "toolchains")
|
||||
use_repo(toolchains, "zig_sdk")
|
||||
|
||||
bazel_dep(name = "rules_zig", version = "20240913.0-1957d05")
|
||||
zig = use_extension("@rules_zig//zig:extensions.bzl", "zig")
|
||||
zig.index(file = "//bazel:zig_index.json")
|
||||
zig.toolchain(zig_version = "0.14.0-dev.363+c3faae6bf")
|
||||
zig.mirrors(urls = [
|
||||
"https://mirror.zml.ai/zig",
|
||||
])
|
||||
use_repo(zig, "zig_toolchains")
|
||||
|
||||
register_toolchains("@rules_zig//zig/target:all")
|
||||
register_toolchains("@zig_toolchains//:all")
|
||||
register_toolchains(
|
||||
"@zig_sdk//toolchain:linux_amd64_gnu.2.31",
|
||||
"@zig_sdk//toolchain:linux_arm64_gnu.2.31",
|
||||
)
|
||||
|
||||
cpu = use_extension("//runtimes/cpu:cpu.bzl", "cpu_pjrt_plugin")
|
||||
use_repo(cpu, "libpjrt_cpu_darwin_arm64", "libpjrt_cpu_linux_amd64")
|
||||
|
||||
cuda = use_extension("//runtimes/cuda:cuda.bzl", "cuda_packages")
|
||||
use_repo(cuda, "libpjrt_cuda")
|
||||
|
||||
rocm = use_extension("//runtimes/rocm:rocm.bzl", "rocm_packages")
|
||||
use_repo(rocm, "libpjrt_rocm")
|
||||
|
||||
tpu = use_extension("//runtimes/tpu:tpu.bzl", "tpu_packages")
|
||||
use_repo(tpu, "libpjrt_tpu")
|
||||
|
||||
zls = use_extension("//third_party/zls:zls.bzl", "repo")
|
||||
use_repo(zls, "zls_aarch64-macos", "zls_x86_64-linux")
|
||||
|
||||
register_toolchains("//third_party/zls:all")
|
||||
|
||||
bazel_dep(name = "libxev", version = "20240910.0-a2d9b31")
|
||||
bazel_dep(name = "llvm-raw", version = "20240823.0-f142f8a")
|
||||
|
||||
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
||||
llvm.configure(
|
||||
targets = [
|
||||
"AArch64",
|
||||
"X86",
|
||||
"NVPTX",
|
||||
],
|
||||
)
|
||||
use_repo(llvm, "llvm-project")
|
||||
|
||||
bazel_dep(name = "stablehlo", version = "20240829.0-54aa1a5")
|
||||
bazel_dep(name = "xla", version = "20240902.0-d18cd64")
|
||||
|
||||
tsl = use_extension("@xla//:tsl.bzl", "tsl")
|
||||
use_repo(tsl, "tsl")
|
||||
|
||||
bazel_dep(name = "zigcoro", version = "20240829.0-fc1db29")
|
||||
bazel_dep(name = "sentencepiece", version = "20240618.0-d7ace0a")
|
||||
bazel_dep(name = "zig-protobuf", version = "20240722.0-c644d11")
|
||||
bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf")
|
||||
1305
MODULE.bazel.lock
Normal file
1305
MODULE.bazel.lock
Normal file
File diff suppressed because it is too large
Load Diff
11
async/BUILD.bazel
Normal file
11
async/BUILD.bazel
Normal file
@ -0,0 +1,11 @@
|
||||
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||
|
||||
zig_library(
|
||||
name = "async",
|
||||
main = "async.zig",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@zigcoro//:libcoro",
|
||||
"@libxev//:xev",
|
||||
],
|
||||
)
|
||||
400
async/async.zig
Normal file
400
async/async.zig
Normal file
@ -0,0 +1,400 @@
|
||||
const std = @import("std");
|
||||
const xev = @import("xev");
|
||||
const libcoro = @import("libcoro");
|
||||
const aio = libcoro.asyncio;
|
||||
|
||||
/// Normalize from a real tuple to a generic tuple. This is needed because
|
||||
/// real tuples are reifed tuples are not the same.
|
||||
fn NormalizedTuple(comptime T: type) type {
|
||||
const ti = @typeInfo(T).Struct;
|
||||
var types: [ti.fields.len]type = undefined;
|
||||
inline for (ti.fields, 0..) |field, i| {
|
||||
types[i] = field.type;
|
||||
}
|
||||
return std.meta.Tuple(&types);
|
||||
}
|
||||
|
||||
pub fn FnSignature(comptime func: anytype, comptime argsT: ?type) type {
|
||||
return struct {
|
||||
pub const FuncT = if (@TypeOf(func) == type) func else @TypeOf(func);
|
||||
pub const ArgsT = blk: {
|
||||
if (@typeInfo(FuncT).Fn.params.len == 0) {
|
||||
break :blk @TypeOf(.{});
|
||||
}
|
||||
break :blk argsT orelse std.meta.ArgsTuple(FuncT);
|
||||
};
|
||||
pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined)));
|
||||
pub const ReturnPayloadT = blk: {
|
||||
break :blk switch (@typeInfo(ReturnT)) {
|
||||
.ErrorUnion => |u| u.payload,
|
||||
else => ReturnT,
|
||||
};
|
||||
};
|
||||
pub const ReturnErrorSet: ?type = blk: {
|
||||
break :blk switch (@typeInfo(ReturnT)) {
|
||||
.ErrorUnion => |u| u.error_set,
|
||||
else => null,
|
||||
};
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
pub fn Frame(comptime func: anytype) type {
|
||||
const Signature = FnSignature(func, null);
|
||||
return FrameEx(func, Signature.ArgsT);
|
||||
}
|
||||
|
||||
pub fn FrameEx(comptime func: anytype, comptime argsT: type) type {
|
||||
return FrameExx(func, argsT);
|
||||
}
|
||||
|
||||
fn FrameExx(comptime func: anytype, comptime argsT: type) type {
|
||||
return struct {
|
||||
const Self = @This();
|
||||
const Signature = FnSignature(func, argsT);
|
||||
const FrameT = libcoro.FrameT(func, .{ .ArgsT = Signature.ArgsT });
|
||||
|
||||
inner: FrameT,
|
||||
|
||||
pub fn await_(self: *Self) Signature.ReturnT {
|
||||
defer {
|
||||
self.inner.deinit();
|
||||
self.* = undefined;
|
||||
}
|
||||
return libcoro.xawait(self.inner);
|
||||
}
|
||||
|
||||
fn from(other: anytype) !Self {
|
||||
return .{ .inner = FrameT.wrap(other.frame()) };
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn async_(comptime func: anytype, args: anytype) !FrameEx(func, @TypeOf(args)) {
|
||||
const frame = try aio.xasync(func, args, null);
|
||||
return FrameEx(func, @TypeOf(args)).from(frame);
|
||||
}
|
||||
|
||||
pub fn call(comptime func: anytype, args: std.meta.ArgsTuple(@TypeOf(func))) @TypeOf(callGeneric(func, args)) {
|
||||
return callGeneric(func, args);
|
||||
}
|
||||
|
||||
pub fn callGeneric(comptime func: anytype, args: anytype) FnSignature(func, @TypeOf(args)).ReturnT {
|
||||
const Signature = FnSignature(func, @TypeOf(args));
|
||||
|
||||
const TaskT = struct {
|
||||
const Self = @This();
|
||||
|
||||
_task: xev.ThreadPool.Task = .{ .callback = &Self.run },
|
||||
|
||||
notif: Notification,
|
||||
args: *const Signature.ArgsT,
|
||||
result: Signature.ReturnT = undefined,
|
||||
|
||||
pub fn run(task_: *xev.ThreadPool.Task) void {
|
||||
const task: *Self = @alignCast(@fieldParentPtr("_task", task_));
|
||||
task.result = @call(.auto, func, task.args.*);
|
||||
task.notif.notify() catch @panic("Unable to notify");
|
||||
}
|
||||
};
|
||||
|
||||
var newtask: TaskT = .{
|
||||
.notif = Notification.init() catch @panic("Notification.init failed"),
|
||||
.args = &args,
|
||||
};
|
||||
defer newtask.notif.deinit();
|
||||
|
||||
AsyncThread.current.thread_pool.schedule(xev.ThreadPool.Batch.from(&newtask._task));
|
||||
newtask.notif.wait() catch @panic("Unable to wait for notification");
|
||||
return newtask.result;
|
||||
}
|
||||
|
||||
pub fn tick() void {
|
||||
AsyncThread.current.executor.exec.tick();
|
||||
}
|
||||
|
||||
pub fn sleep(ms: u64) !void {
|
||||
try aio.sleep(AsyncThread.current.executor, ms);
|
||||
}
|
||||
|
||||
pub const Notification = struct {
|
||||
inner: aio.AsyncNotification,
|
||||
|
||||
pub fn init() !Notification {
|
||||
return .{
|
||||
.inner = aio.AsyncNotification.init(AsyncThread.current.executor, try xev.Async.init()),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn notify(self: *Notification) !void {
|
||||
try self.inner.notif.notify();
|
||||
}
|
||||
|
||||
pub fn wait(self: *Notification) !void {
|
||||
try self.inner.wait();
|
||||
}
|
||||
|
||||
pub fn deinit(self: *Notification) void {
|
||||
self.inner.notif.deinit();
|
||||
self.* = undefined;
|
||||
}
|
||||
};
|
||||
|
||||
pub const AsyncThread = struct {
|
||||
threadlocal var current: AsyncThread = undefined;
|
||||
|
||||
executor: *aio.Executor,
|
||||
loop: *xev.Loop,
|
||||
thread_pool: *xev.ThreadPool,
|
||||
|
||||
pub fn main(allocator: std.mem.Allocator, comptime func: anytype, args: anytype) !FnSignature(func, NormalizedTuple(@TypeOf(args))).ReturnPayloadT {
|
||||
const Signature = FnSignature(func, NormalizedTuple(@TypeOf(args)));
|
||||
|
||||
var thread_pool = xev.ThreadPool.init(.{});
|
||||
defer {
|
||||
thread_pool.shutdown();
|
||||
thread_pool.deinit();
|
||||
}
|
||||
|
||||
var loop = try xev.Loop.init(.{
|
||||
.thread_pool = &thread_pool,
|
||||
});
|
||||
defer loop.deinit();
|
||||
|
||||
var executor = aio.Executor.init(&loop);
|
||||
|
||||
AsyncThread.current = .{
|
||||
.executor = &executor,
|
||||
.loop = &loop,
|
||||
.thread_pool = &thread_pool,
|
||||
};
|
||||
|
||||
aio.initEnv(.{
|
||||
.stack_allocator = allocator,
|
||||
.default_stack_size = 16 * 1024 * 1024,
|
||||
});
|
||||
|
||||
if (Signature.ReturnErrorSet) |_| {
|
||||
return try aio.run(&executor, func, args, null);
|
||||
} else {
|
||||
return aio.run(&executor, func, args, null);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
pub fn StdIn() !File {
|
||||
return File.init(std.io.getStdIn()) catch @panic("Unable to open stdin");
|
||||
}
|
||||
|
||||
pub fn StdOut() File {
|
||||
return File.init(std.io.getStdOut()) catch @panic("Unable to open stdout");
|
||||
}
|
||||
|
||||
pub fn StdErr() File {
|
||||
return File.init(std.io.getStdErr()) catch @panic("Unable to open stderr");
|
||||
}
|
||||
|
||||
pub const File = struct {
|
||||
pub const SeekError = FnSignature(File.seekTo, null).ReturnErrorSet.? || FnSignature(File.seekBy, null).ReturnErrorSet.?;
|
||||
pub const GetSeekPosError = SeekError || FnSignature(File.stat, null).ReturnErrorSet.?;
|
||||
pub const Reader = std.io.GenericReader(File, FnSignature(File.read, null).ReturnErrorSet.?, File.read);
|
||||
pub const Writer = std.io.GenericWriter(File, FnSignature(File.write, null).ReturnErrorSet.?, File.write);
|
||||
pub const SeekableStream = std.io.SeekableStream(
|
||||
File,
|
||||
SeekError,
|
||||
GetSeekPosError,
|
||||
seekTo,
|
||||
seekBy,
|
||||
getPos,
|
||||
getEndPos,
|
||||
);
|
||||
|
||||
inner: aio.File,
|
||||
|
||||
fn asFile(self: File) std.fs.File {
|
||||
return .{ .handle = self.inner.file.fd };
|
||||
}
|
||||
|
||||
pub fn init(file_: std.fs.File) !File {
|
||||
return .{ .inner = aio.File.init(AsyncThread.current.executor, try xev.File.init(file_)) };
|
||||
}
|
||||
|
||||
pub fn fromFd(fd: std.fs.File.Handle) !File {
|
||||
return .{ .inner = aio.File.init(AsyncThread.current.executor, try xev.File.initFd(fd)) };
|
||||
}
|
||||
|
||||
pub fn open(path: []const u8, flags: std.fs.File.OpenFlags) !File {
|
||||
return init(try call(std.fs.Dir.openFile, .{ std.fs.cwd(), path, flags }));
|
||||
}
|
||||
|
||||
pub fn read(self: File, buf: []u8) !usize {
|
||||
// NOTE(Corentin): Early return is required to avoid error with xev on Linux with io_uring backend.
|
||||
if (buf.len == 0) return 0;
|
||||
|
||||
return self.inner.read(.{ .slice = buf }) catch |err| switch (err) {
|
||||
// NOTE(Corentin): read shouldn't return an error on EOF, but a read length of 0 instead. This is to be iso with std.fs.File.
|
||||
error.EOF => 0,
|
||||
else => err,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn pread(self: File, buf: []u8, offset: u64) !usize {
|
||||
// NOTE(Corentin): Early return is required to avoid error with xev on Linux with io_uring backend.
|
||||
if (buf.len == 0) return 0;
|
||||
|
||||
return self.inner.pread(.{ .slice = buf }, offset) catch |err| switch (err) {
|
||||
// NOTE(Corentin): pread shouldn't return an error on EOF, but a read length of 0 instead. This is to be iso with std.fs.File.
|
||||
error.EOF => 0,
|
||||
else => err,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn write(self: File, buf: []const u8) !usize {
|
||||
return self.inner.write(.{ .slice = buf });
|
||||
}
|
||||
|
||||
pub fn pwrite(self: File, buf: []const u8, offset: u64) !usize {
|
||||
return self.inner.pwrite(.{ .slice = buf }, offset);
|
||||
}
|
||||
|
||||
pub fn close(self: File) !void {
|
||||
return self.inner.close();
|
||||
}
|
||||
|
||||
pub fn reader(self: File) Reader {
|
||||
return .{ .context = self };
|
||||
}
|
||||
|
||||
pub fn seekableStream(file: File) SeekableStream {
|
||||
return .{ .context = file };
|
||||
}
|
||||
|
||||
pub fn writer(self: File) Writer {
|
||||
return .{ .context = self };
|
||||
}
|
||||
|
||||
pub fn stat(self: File) !std.fs.File.Stat {
|
||||
return try call(std.fs.File.stat, .{self.asFile()});
|
||||
}
|
||||
|
||||
pub fn seekBy(self: File, offset: i64) !void {
|
||||
try call(std.fs.File.seekBy, .{ self.asFile(), offset });
|
||||
}
|
||||
|
||||
pub fn seekTo(self: File, offset: u64) !void {
|
||||
try call(std.fs.File.seekTo, .{ self.asFile(), offset });
|
||||
}
|
||||
|
||||
pub fn getPos(self: File) !u64 {
|
||||
return try call(std.fs.File.getPos, .{self.asFile()});
|
||||
}
|
||||
|
||||
pub fn getEndPos(self: File) !u64 {
|
||||
return try call(std.fs.File.getEndPos, .{self.asFile()});
|
||||
}
|
||||
};
|
||||
|
||||
pub const Socket = struct {
|
||||
pub const TCP = struct {
|
||||
pub const Reader = std.io.GenericReader(TCP, FnSignature(TCP.read, null).ReturnErrorSet.?, TCP.read);
|
||||
pub const Writer = std.io.GenericWriter(TCP, FnSignature(TCP.write, null).ReturnErrorSet.?, TCP.write);
|
||||
|
||||
inner: aio.TCP,
|
||||
|
||||
pub fn init(addr: std.net.Address) !TCP {
|
||||
return .{ .inner = aio.TCP.init(AsyncThread.current.executor, try xev.TCP.init(addr)) };
|
||||
}
|
||||
|
||||
pub fn deinit(self: *TCP) void {
|
||||
self.inner.shutdown();
|
||||
}
|
||||
|
||||
pub fn connect(self: *TCP, addr: std.net.Address) !void {
|
||||
return self.inner.connect(addr);
|
||||
}
|
||||
|
||||
pub fn read(self: *TCP, buf: []u8) !usize {
|
||||
return self.inner.read(.{ .slice = buf });
|
||||
}
|
||||
|
||||
pub fn write(self: *TCP, buf: []const u8) !usize {
|
||||
return self.inner.write(.{ .slice = buf });
|
||||
}
|
||||
|
||||
pub fn close(self: *TCP) !void {
|
||||
defer self.* = undefined;
|
||||
return self.inner.close();
|
||||
}
|
||||
|
||||
pub fn reader(self: File) Reader {
|
||||
return .{ .context = self };
|
||||
}
|
||||
|
||||
pub fn writer(self: File) Writer {
|
||||
return .{ .context = self };
|
||||
}
|
||||
};
|
||||
|
||||
pub const UDP = struct {
|
||||
pub const Reader = std.io.GenericReader(UDP, FnSignature(UDP.read, null).ReturnErrorSet.?, UDP.read);
|
||||
pub const WriterContext = struct {
|
||||
file: UDP,
|
||||
addr: std.net.Address,
|
||||
};
|
||||
pub const Writer = std.io.GenericWriter(WriterContext, FnSignature(UDP.write, null).ReturnErrorSet.?, struct {
|
||||
fn call(self: WriterContext, buf: []const u8) !usize {
|
||||
return self.file.write(self.addr, buf);
|
||||
}
|
||||
}.call);
|
||||
|
||||
inner: aio.UDP,
|
||||
|
||||
pub fn init(addr: std.net.Address) !UDP {
|
||||
return .{ .inner = aio.UDP.init(AsyncThread.current.executor, try xev.UDP.init(addr)) };
|
||||
}
|
||||
|
||||
pub fn read(self: UDP, buf: []u8) !usize {
|
||||
return self.inner.read(.{ .slice = buf });
|
||||
}
|
||||
|
||||
pub fn write(self: UDP, addr: std.net.Address, buf: []const u8) !usize {
|
||||
return self.inner.write(addr, .{ .slice = buf });
|
||||
}
|
||||
|
||||
pub fn close(self: *UDP) !void {
|
||||
defer self.* = undefined;
|
||||
return self.inner.close();
|
||||
}
|
||||
|
||||
pub fn reader(self: File) Reader {
|
||||
return .{ .context = self };
|
||||
}
|
||||
|
||||
pub fn writer(self: File, addr: std.net.Address) Writer {
|
||||
return .{
|
||||
.context = .{
|
||||
.file = self,
|
||||
.addr = addr,
|
||||
},
|
||||
};
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
pub const Mutex = struct {
|
||||
const VoidChannel = libcoro.Channel(void, .{ .capacity = 1 });
|
||||
|
||||
inner: VoidChannel,
|
||||
|
||||
pub fn init() Mutex {
|
||||
return .{ .inner = VoidChannel.init(&AsyncThread.current.executor.exec) };
|
||||
}
|
||||
|
||||
pub fn lock(self: *Mutex) !void {
|
||||
try self.inner.send({});
|
||||
}
|
||||
|
||||
pub fn unlock(self: *Mutex) void {
|
||||
_ = self.inner.recv();
|
||||
}
|
||||
};
|
||||
0
bazel/BUILD.bazel
Normal file
0
bazel/BUILD.bazel
Normal file
96
bazel/cc_import.bzl
Normal file
96
bazel/cc_import.bzl
Normal file
@ -0,0 +1,96 @@
|
||||
load("@bazel_tools//tools/build_defs/cc:cc_import.bzl", _cc_import = "cc_import")
|
||||
load(":patchelf.bzl", "patchelf")
|
||||
|
||||
def _cc_import_runfiles_impl(ctx):
|
||||
runfiles = ctx.runfiles(files = ctx.files.data)
|
||||
transitive_runfiles_list = []
|
||||
if ctx.attr.static_library:
|
||||
transitive_runfiles_list.append(ctx.attr.static_library[DefaultInfo].default_runfiles)
|
||||
if ctx.attr.pic_static_library:
|
||||
transitive_runfiles_list.append(ctx.attr.pic_static_library[DefaultInfo].default_runfiles)
|
||||
if ctx.attr.shared_library:
|
||||
transitive_runfiles_list.append(ctx.attr.shared_library[DefaultInfo].default_runfiles)
|
||||
if ctx.attr.interface_library:
|
||||
transitive_runfiles_list.append(ctx.attr.interface_library[DefaultInfo].default_runfiles)
|
||||
for dep in ctx.attr.deps:
|
||||
transitive_runfiles_list.append(dep[DefaultInfo].default_runfiles)
|
||||
|
||||
for maybe_runfiles in transitive_runfiles_list:
|
||||
if maybe_runfiles:
|
||||
runfiles = runfiles.merge(maybe_runfiles)
|
||||
|
||||
default_info = DefaultInfo(runfiles = runfiles)
|
||||
return [ctx.attr.src[CcInfo], default_info]
|
||||
|
||||
_cc_import_runfiles = rule(
|
||||
implementation = _cc_import_runfiles_impl,
|
||||
attrs = {
|
||||
"src": attr.label(providers = [CcInfo]),
|
||||
"static_library": attr.label(allow_single_file = [".a", ".lib"]),
|
||||
"pic_static_library": attr.label(allow_single_file = [".pic.a", ".pic.lib"]),
|
||||
"shared_library": attr.label(allow_single_file = True),
|
||||
"interface_library": attr.label(allow_single_file = [".ifso", ".tbd", ".lib", ".so", ".dylib"]),
|
||||
"data": attr.label_list(allow_files = True),
|
||||
"deps": attr.label_list(),
|
||||
},
|
||||
)
|
||||
|
||||
def cc_import(
|
||||
name,
|
||||
static_library = None,
|
||||
pic_static_library = None,
|
||||
shared_library = None,
|
||||
interface_library = None,
|
||||
data = None,
|
||||
deps = None,
|
||||
visibility = None,
|
||||
soname = None,
|
||||
add_needed = None,
|
||||
remove_needed = None,
|
||||
replace_needed = None,
|
||||
**kwargs):
|
||||
if shared_library and (soname or add_needed or remove_needed or replace_needed):
|
||||
patched_name = "{}_patchelf".format(name)
|
||||
patchelf(
|
||||
name = patched_name,
|
||||
shared_library = shared_library,
|
||||
soname = soname,
|
||||
add_needed = add_needed,
|
||||
remove_needed = remove_needed,
|
||||
replace_needed = replace_needed,
|
||||
)
|
||||
shared_library = ":" + patched_name
|
||||
if data:
|
||||
_cc_import(
|
||||
name = name + "_no_runfiles",
|
||||
static_library = static_library,
|
||||
pic_static_library = pic_static_library,
|
||||
shared_library = shared_library,
|
||||
interface_library = interface_library,
|
||||
data = data,
|
||||
deps = deps,
|
||||
**kwargs
|
||||
)
|
||||
_cc_import_runfiles(
|
||||
name = name,
|
||||
src = ":{}_no_runfiles".format(name),
|
||||
static_library = static_library,
|
||||
pic_static_library = pic_static_library,
|
||||
shared_library = shared_library,
|
||||
interface_library = interface_library,
|
||||
data = data,
|
||||
deps = deps,
|
||||
visibility = visibility,
|
||||
)
|
||||
else:
|
||||
_cc_import(
|
||||
name = name,
|
||||
static_library = static_library,
|
||||
pic_static_library = pic_static_library,
|
||||
shared_library = shared_library,
|
||||
interface_library = interface_library,
|
||||
data = data,
|
||||
deps = deps,
|
||||
visibility = visibility,
|
||||
**kwargs
|
||||
)
|
||||
45
bazel/http_deb_archive.bzl
Normal file
45
bazel/http_deb_archive.bzl
Normal file
@ -0,0 +1,45 @@
|
||||
load(
|
||||
"@bazel_tools//tools/build_defs/repo:utils.bzl",
|
||||
"get_auth",
|
||||
"patch",
|
||||
"workspace_and_buildfile",
|
||||
)
|
||||
|
||||
def _http_deb_archive_impl(rctx):
|
||||
if rctx.attr.build_file and rctx.attr.build_file_content:
|
||||
fail("Only one of build_file and build_file_content can be provided.")
|
||||
download_info = rctx.download_and_extract(
|
||||
url = rctx.attr.urls,
|
||||
output = "tmp",
|
||||
sha256 = rctx.attr.sha256,
|
||||
type = "deb",
|
||||
stripPrefix = "",
|
||||
canonical_id = " ".join(rctx.attr.urls),
|
||||
auth = get_auth(rctx, rctx.attr.urls),
|
||||
)
|
||||
|
||||
for ext in ["gz", "xz", "zst"]:
|
||||
data = "tmp/data.tar.{}".format(ext)
|
||||
if rctx.path(data).exists:
|
||||
rctx.extract(
|
||||
archive = data,
|
||||
output = "",
|
||||
stripPrefix = rctx.attr.strip_prefix,
|
||||
)
|
||||
rctx.delete("tmp")
|
||||
break
|
||||
workspace_and_buildfile(rctx)
|
||||
patch(rctx)
|
||||
|
||||
http_deb_archive = repository_rule(
|
||||
_http_deb_archive_impl,
|
||||
attrs = {
|
||||
"urls": attr.string_list(mandatory = True),
|
||||
"sha256": attr.string(mandatory = True),
|
||||
"strip_prefix": attr.string(),
|
||||
"build_file": attr.label(allow_single_file = True),
|
||||
"build_file_content": attr.string(),
|
||||
"workspace_file": attr.label(allow_single_file = True),
|
||||
"workspace_file_content": attr.string(),
|
||||
},
|
||||
)
|
||||
173
bazel/huggingface.bzl
Normal file
173
bazel/huggingface.bzl
Normal file
@ -0,0 +1,173 @@
|
||||
load(
|
||||
"@bazel_tools//tools/build_defs/repo:utils.bzl",
|
||||
"patch",
|
||||
"workspace_and_buildfile",
|
||||
)
|
||||
|
||||
TREE_URL_TEMPLATE = "https://huggingface.co/api/models/{model}/tree/{commit}/{path}"
|
||||
RAW_FILE_URL_REMPLATE = "https://huggingface.co/{model}/raw/{commit}/{path}"
|
||||
LFS_FILE_URL_TEMPLATE = "https://huggingface.co/{model}/resolve/{commit}/{path}"
|
||||
|
||||
def _glob(rctx, str, patterns):
|
||||
cmd = "\n".join([
|
||||
"""[[ "{str}" = {pattern} ]] && exit 0""".format(str = str, pattern = pattern)
|
||||
for pattern in patterns
|
||||
] + ["exit 1"])
|
||||
return rctx.execute(["bash", "-c", cmd]).return_code == 0
|
||||
|
||||
def _ls(rctx, headers, path):
|
||||
url = TREE_URL_TEMPLATE.format(
|
||||
model = rctx.attr.model,
|
||||
commit = rctx.attr.commit,
|
||||
path = path,
|
||||
)
|
||||
rctx.download(url, path + ".index.json", headers = headers)
|
||||
ret = json.decode(rctx.read(path + ".index.json"))
|
||||
rctx.delete(path + ".index.json")
|
||||
return ret
|
||||
|
||||
def _get_token_via_env(rctx):
|
||||
return rctx.getenv("HUGGINGFACE_TOKEN")
|
||||
|
||||
def _get_token_via_file(rctx):
|
||||
p = rctx.path(rctx.getenv("HOME") + "/.cache/huggingface/token")
|
||||
if p.exists:
|
||||
return rctx.read(p)
|
||||
|
||||
def _get_token_via_git_credentials(rctx):
|
||||
input = """\
|
||||
protocol=https
|
||||
host=huggingface.co
|
||||
|
||||
"""
|
||||
res = rctx.execute(["bash", "-c", "echo '{}' | git credential fill".format(input)])
|
||||
if res.return_code != 0:
|
||||
return None
|
||||
for line in res.stdout.split("\n"):
|
||||
if line.startswith("password="):
|
||||
return line[len("password="):]
|
||||
return None
|
||||
|
||||
def _get_token(rctx):
|
||||
t = _get_token_via_env(rctx) or \
|
||||
_get_token_via_file(rctx) or \
|
||||
_get_token_via_git_credentials(rctx)
|
||||
if t:
|
||||
return t.strip()
|
||||
|
||||
def _huggingface_repository_impl(rctx):
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Accept-Encoding": "gzip, deflate",
|
||||
}
|
||||
|
||||
token = _get_token(rctx)
|
||||
if token:
|
||||
headers["Authorization"] = "Bearer " + token
|
||||
|
||||
includes = rctx.attr.includes
|
||||
excludes = rctx.attr.excludes
|
||||
stack = [""]
|
||||
downloads = []
|
||||
|
||||
for _ in range(9999999):
|
||||
if (not stack):
|
||||
break
|
||||
path = stack.pop()
|
||||
for entry in _ls(rctx, headers, path):
|
||||
if entry["type"] == "directory":
|
||||
stack.append(entry["path"])
|
||||
elif entry["type"] == "file":
|
||||
if (excludes and _glob(rctx, entry["path"], excludes)):
|
||||
continue
|
||||
if (not includes or _glob(rctx, entry["path"], includes)):
|
||||
tpl = RAW_FILE_URL_REMPLATE
|
||||
if ("lfs" in entry):
|
||||
tpl = LFS_FILE_URL_TEMPLATE
|
||||
url = tpl.format(
|
||||
model = rctx.attr.model,
|
||||
commit = rctx.attr.commit,
|
||||
path = entry["path"],
|
||||
)
|
||||
downloads.append(rctx.download(
|
||||
url = url,
|
||||
output = entry["path"],
|
||||
canonical_id = entry["oid"],
|
||||
headers = headers,
|
||||
block = False,
|
||||
))
|
||||
|
||||
for download in downloads:
|
||||
download.wait()
|
||||
|
||||
workspace_and_buildfile(rctx)
|
||||
patch(rctx)
|
||||
|
||||
huggingface_repository = repository_rule(
|
||||
implementation = _huggingface_repository_impl,
|
||||
attrs = {
|
||||
"model": attr.string(mandatory = True),
|
||||
"commit": attr.string(mandatory = True),
|
||||
"includes": attr.string_list(default = []),
|
||||
"excludes": attr.string_list(default = []),
|
||||
"patches": attr.label_list(),
|
||||
"patch_tool": attr.string(default = ""),
|
||||
"patch_args": attr.string_list(default = ["-p0"]),
|
||||
"patch_cmds": attr.string_list(default = []),
|
||||
"patch_cmds_win": attr.string_list(default = []),
|
||||
"build_file": attr.label(allow_single_file = True),
|
||||
"build_file_content": attr.string(),
|
||||
"workspace_file": attr.label(allow_single_file = True),
|
||||
"workspace_file_content": attr.string(),
|
||||
},
|
||||
)
|
||||
|
||||
def _huggingface_impl(mctx):
|
||||
for mod in mctx.modules:
|
||||
for model in mod.tags.model:
|
||||
huggingface_repository(
|
||||
name = model.name,
|
||||
model = model.model,
|
||||
commit = model.commit,
|
||||
includes = model.includes,
|
||||
excludes = model.excludes,
|
||||
patches = model.patches,
|
||||
patch_tool = model.patch_tool,
|
||||
patch_args = model.patch_args,
|
||||
patch_cmds = model.patch_cmds,
|
||||
patch_cmds_win = model.patch_cmds_win,
|
||||
build_file = model.build_file,
|
||||
build_file_content = model.build_file_content,
|
||||
workspace_file = model.workspace_file,
|
||||
workspace_file_content = model.workspace_file_content,
|
||||
)
|
||||
|
||||
return mctx.extension_metadata(
|
||||
reproducible = True,
|
||||
root_module_direct_deps = "all",
|
||||
root_module_direct_dev_deps = [],
|
||||
)
|
||||
|
||||
huggingface = module_extension(
|
||||
implementation = _huggingface_impl,
|
||||
tag_classes = {
|
||||
"model": tag_class(
|
||||
attrs = {
|
||||
"name": attr.string(mandatory = True),
|
||||
"model": attr.string(mandatory = True),
|
||||
"commit": attr.string(mandatory = True),
|
||||
"includes": attr.string_list(default = []),
|
||||
"excludes": attr.string_list(default = []),
|
||||
"patches": attr.label_list(),
|
||||
"patch_tool": attr.string(default = ""),
|
||||
"patch_args": attr.string_list(default = ["-p0"]),
|
||||
"patch_cmds": attr.string_list(default = []),
|
||||
"patch_cmds_win": attr.string_list(default = []),
|
||||
"build_file": attr.label(allow_single_file = True),
|
||||
"build_file_content": attr.string(),
|
||||
"workspace_file": attr.label(allow_single_file = True),
|
||||
"workspace_file_content": attr.string(),
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
58
bazel/patchelf.bzl
Normal file
58
bazel/patchelf.bzl
Normal file
@ -0,0 +1,58 @@
|
||||
def _render_kv(e):
|
||||
return e
|
||||
|
||||
def _patchelf_impl(ctx):
|
||||
output_name = ctx.file.shared_library.basename
|
||||
if ctx.attr.soname:
|
||||
output_name = ctx.attr.soname
|
||||
output = ctx.actions.declare_file("{}/{}".format(ctx.attr.name, output_name))
|
||||
|
||||
commands = [
|
||||
"set -e",
|
||||
'cp -f "$2" "$3"',
|
||||
'chmod +w "$3"',
|
||||
]
|
||||
|
||||
if ctx.attr.soname:
|
||||
commands.append('"$1" --set-soname "{}" "$3"'.format(ctx.attr.soname))
|
||||
if ctx.attr.remove_needed:
|
||||
for v in ctx.attr.remove_needed:
|
||||
commands.append('"$1" --remove-needed "{}" "$3"'.format(v))
|
||||
if ctx.attr.add_needed:
|
||||
for v in ctx.attr.add_needed:
|
||||
commands.append('"$1" --add-needed "{}" "$3"'.format(v))
|
||||
|
||||
if ctx.attr.replace_needed:
|
||||
for k, v in ctx.attr.replace_needed.items():
|
||||
commands.append('"$1" --replace-needed "{}" "{}" "$3"'.format(k, v))
|
||||
|
||||
ctx.actions.run_shell(
|
||||
inputs = [ctx.file.shared_library],
|
||||
outputs = [output],
|
||||
arguments = [ctx.executable._patchelf.path, ctx.file.shared_library.path, output.path],
|
||||
command = "\n".join(commands),
|
||||
tools = [ctx.executable._patchelf],
|
||||
)
|
||||
|
||||
return [
|
||||
DefaultInfo(
|
||||
files = depset([output]),
|
||||
),
|
||||
]
|
||||
|
||||
patchelf = rule(
|
||||
implementation = _patchelf_impl,
|
||||
attrs = {
|
||||
"shared_library": attr.label(allow_single_file = True, mandatory = True),
|
||||
"soname": attr.string(),
|
||||
"add_needed": attr.string_list(),
|
||||
"remove_needed": attr.string_list(),
|
||||
"replace_needed": attr.string_dict(),
|
||||
"_patchelf": attr.label(
|
||||
default = "@patchelf",
|
||||
allow_single_file = True,
|
||||
executable = True,
|
||||
cfg = "exec",
|
||||
),
|
||||
},
|
||||
)
|
||||
39
bazel/zig.bzl
Normal file
39
bazel/zig.bzl
Normal file
@ -0,0 +1,39 @@
|
||||
load("@rules_zig//zig:defs.bzl", "BINARY_KIND", "zig_binary")
|
||||
|
||||
def zig_cc_binary(name, args = None, env = None, data = [], deps = [], visibility = None, **kwargs):
|
||||
zig_binary(
|
||||
name = "{}_lib".format(name),
|
||||
kind = BINARY_KIND.static_lib,
|
||||
deps = deps + [
|
||||
"@rules_zig//zig/lib:libc",
|
||||
],
|
||||
**kwargs
|
||||
)
|
||||
native.cc_binary(
|
||||
name = name,
|
||||
args = args,
|
||||
env = env,
|
||||
data = data,
|
||||
deps = [":{}_lib".format(name)],
|
||||
visibility = visibility,
|
||||
)
|
||||
|
||||
def zig_cc_test(name, env = None, data = [], deps = [], test_runner = None, visibility = None, **kwargs):
|
||||
zig_binary(
|
||||
name = "{}_test_lib".format(name),
|
||||
kind = BINARY_KIND.test_lib,
|
||||
test_runner = test_runner,
|
||||
data = data,
|
||||
deps = deps + [
|
||||
"@rules_zig//zig/lib:libc",
|
||||
],
|
||||
**kwargs
|
||||
)
|
||||
native.cc_test(
|
||||
name = name,
|
||||
env = env,
|
||||
data = data,
|
||||
deps = [":{}_test_lib".format(name)],
|
||||
visibility = visibility,
|
||||
linkstatic = True,
|
||||
)
|
||||
1090
bazel/zig_index.json
Normal file
1090
bazel/zig_index.json
Normal file
File diff suppressed because it is too large
Load Diff
109
bazel/zig_proto_library.bzl
Normal file
109
bazel/zig_proto_library.bzl
Normal file
@ -0,0 +1,109 @@
|
||||
"""Starlark implementation of zig_proto_library"""
|
||||
|
||||
load("@rules_proto//proto:defs.bzl", "proto_common")
|
||||
load(
|
||||
"@rules_zig//zig/private/providers:zig_module_info.bzl",
|
||||
"ZigModuleInfo",
|
||||
"zig_module_info",
|
||||
)
|
||||
|
||||
def _zig_proto_library_impl(ctx):
|
||||
if len(ctx.attr.deps) != 1:
|
||||
fail("zig_proto_library '{}' requires exactly one 'deps'".format(ctx.label.name))
|
||||
|
||||
(dep,) = ctx.attr.deps
|
||||
|
||||
# the aspect already generated for us the Zig module
|
||||
# we just change the import name to make it match what the user chose.
|
||||
module = dep[ZigModuleInfo]
|
||||
import_name = ctx.attr.import_name or ctx.label.name
|
||||
|
||||
if import_name:
|
||||
keys = ["canonical_name", "cdeps", "copts", "deps", "extra_srcs", "linkopts", "main", "srcs"]
|
||||
args = {k: getattr(module, k) for k in keys}
|
||||
args["name"] = import_name
|
||||
module = zig_module_info(**args)
|
||||
return [module]
|
||||
|
||||
def zig_proto_library_aspect_impl(target, ctx):
|
||||
"""
|
||||
For each `.proto` in the given target dependencies,
|
||||
generate a `.pb.zig` file, and a `zig_module` to import it.
|
||||
"""
|
||||
toolchain = ctx.attr._zig_proto_toolchain[proto_common.ProtoLangToolchainInfo]
|
||||
proto_info = target[ProtoInfo]
|
||||
|
||||
# assert len(proto_info.direct_sources) == 1, "Can only compile .proto files one by one"
|
||||
(proto_src,) = proto_info.direct_sources
|
||||
pb_zig_name = proto_src.basename[:-len(".proto")] + ".pb.zig"
|
||||
zig_src = ctx.actions.declare_file(pb_zig_name, sibling = proto_src)
|
||||
|
||||
ctx.actions.run
|
||||
proto_common.compile(
|
||||
ctx.actions,
|
||||
proto_info = proto_info,
|
||||
proto_lang_toolchain_info = toolchain,
|
||||
generated_files = [zig_src],
|
||||
)
|
||||
|
||||
zig_proto_modules = [p[ZigModuleInfo] for p in ctx.rule.attr.deps]
|
||||
import_name = get_import_name(target, proto_src)
|
||||
|
||||
module = zig_module_info(
|
||||
name = import_name,
|
||||
canonical_name = str(target.label),
|
||||
main = zig_src,
|
||||
srcs = [],
|
||||
extra_srcs = [],
|
||||
copts = [],
|
||||
linkopts = [],
|
||||
deps = [toolchain.runtime[ZigModuleInfo]] + zig_proto_modules,
|
||||
cdeps = [],
|
||||
)
|
||||
return [module]
|
||||
|
||||
def get_import_name(target, proto_src):
|
||||
"""
|
||||
When the Zig protoc plugin is generating .pb.zig files,
|
||||
it generates import names based on the path received from protoc.
|
||||
We need to create Zig modules with the same name.
|
||||
"""
|
||||
name = str(target.label)
|
||||
|
||||
# special handling of builtin types
|
||||
if "com_google_protobuf//:" in name:
|
||||
name = "google_protobuf_" + proto_src.basename
|
||||
else:
|
||||
name = name.rsplit("//")[-1]
|
||||
name = name.rsplit("//")[-1]
|
||||
return name.replace(".", "_").replace(":", "_").replace("/", "_")
|
||||
|
||||
zig_proto_library_aspect = aspect(
|
||||
attrs = {
|
||||
"_zig_proto_toolchain": attr.label(
|
||||
default = "@zig-protobuf//:zig_toolchain",
|
||||
providers = [proto_common.ProtoLangToolchainInfo],
|
||||
),
|
||||
},
|
||||
implementation = zig_proto_library_aspect_impl,
|
||||
provides = [ZigModuleInfo],
|
||||
attr_aspects = ["deps"],
|
||||
)
|
||||
|
||||
zig_proto_library = rule(
|
||||
doc = """
|
||||
Converts a single `proto_library` into a zig module.
|
||||
""",
|
||||
implementation = _zig_proto_library_impl,
|
||||
attrs = {
|
||||
"deps": attr.label_list(
|
||||
aspects = [zig_proto_library_aspect],
|
||||
providers = [ProtoInfo],
|
||||
),
|
||||
"import_name": attr.string(
|
||||
doc = "The import name of the Zig module.",
|
||||
default = "",
|
||||
),
|
||||
},
|
||||
provides = [ZigModuleInfo],
|
||||
)
|
||||
39
mlir/BUILD.bazel
Normal file
39
mlir/BUILD.bazel
Normal file
@ -0,0 +1,39 @@
|
||||
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||
load("//bazel:zig.bzl", "zig_cc_test")
|
||||
|
||||
cc_library(
|
||||
name = "mlirx",
|
||||
srcs = ["mlirx.cc"],
|
||||
hdrs = ["mlirx.h"],
|
||||
includes = ["."],
|
||||
deps = [
|
||||
"@llvm-project//mlir:CAPIIR",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "c",
|
||||
hdrs = ["c.h"],
|
||||
visibility = ["//mlir:__subpackages__"],
|
||||
deps = [
|
||||
"@llvm-project//mlir:CAPIArith",
|
||||
"@llvm-project//mlir:CAPIIR",
|
||||
"@llvm-project//mlir:CAPIMath",
|
||||
"@llvm-project//mlir:CAPITransforms",
|
||||
],
|
||||
)
|
||||
|
||||
zig_library(
|
||||
name = "mlir",
|
||||
main = "mlir.zig",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":c",
|
||||
":mlirx",
|
||||
],
|
||||
)
|
||||
|
||||
zig_cc_test(
|
||||
name = "test",
|
||||
deps = [":mlir"],
|
||||
)
|
||||
8
mlir/c.h
Normal file
8
mlir/c.h
Normal file
@ -0,0 +1,8 @@
|
||||
#include <mlir-c/BuiltinAttributes.h>
|
||||
#include <mlir-c/BuiltinTypes.h>
|
||||
#include <mlir-c/Dialect/Arith.h>
|
||||
#include <mlir-c/Dialect/Func.h>
|
||||
#include <mlir-c/Dialect/Math.h>
|
||||
#include <mlir-c/IR.h>
|
||||
#include <mlir-c/Pass.h>
|
||||
#include <mlir-c/Transforms.h>
|
||||
41
mlir/dialects/BUILD.bazel
Normal file
41
mlir/dialects/BUILD.bazel
Normal file
@ -0,0 +1,41 @@
|
||||
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||
load("//bazel:zig.bzl", "zig_cc_test")
|
||||
|
||||
zig_library(
|
||||
name = "dialects",
|
||||
srcs = [
|
||||
"arith.zig",
|
||||
"func.zig",
|
||||
"math.zig",
|
||||
"tensor.zig",
|
||||
],
|
||||
import_name = "mlir/dialects",
|
||||
main = "dialects.zig",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":stablehlo",
|
||||
"//mlir",
|
||||
],
|
||||
)
|
||||
|
||||
zig_cc_test(
|
||||
name = "test",
|
||||
deps = [":dialects"],
|
||||
)
|
||||
|
||||
zig_library(
|
||||
name = "stablehlo",
|
||||
import_name = "mlir/dialects/stablehlo",
|
||||
main = "stablehlo.zig",
|
||||
visibility = ["//mlir/dialects:__subpackages__"],
|
||||
deps = [
|
||||
"//mlir",
|
||||
"//mlir:c",
|
||||
"@stablehlo//:stablehlo_capi",
|
||||
],
|
||||
)
|
||||
|
||||
zig_cc_test(
|
||||
name = "stablehlo_test",
|
||||
deps = [":stablehlo"],
|
||||
)
|
||||
109
mlir/dialects/arith.zig
Normal file
109
mlir/dialects/arith.zig
Normal file
@ -0,0 +1,109 @@
|
||||
const std = @import("std");
|
||||
const mlir = @import("mlir");
|
||||
|
||||
pub fn constant(ctx: mlir.Context, value: mlir.Attribute, location: mlir.Location) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, "arith.constant", .{
|
||||
.attributes = &.{
|
||||
.{ "value", value },
|
||||
},
|
||||
.result_type_inference = true,
|
||||
.location = location,
|
||||
});
|
||||
}
|
||||
|
||||
fn binary_fn(comptime op_name: [:0]const u8) fn (mlir.Context, mlir.Value, mlir.Value, mlir.Location) mlir.Operation {
|
||||
return struct {
|
||||
pub fn call(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, location: mlir.Location) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, op_name, .{
|
||||
.operands = &.{ lhs, rhs },
|
||||
.result_type_inference = true,
|
||||
.location = location,
|
||||
});
|
||||
}
|
||||
}.call;
|
||||
}
|
||||
|
||||
fn cast_fn(comptime op_name: [:0]const u8) fn (mlir.Context, mlir.Value, mlir.Type, mlir.Location) mlir.Operation {
|
||||
return struct {
|
||||
pub fn call(ctx: mlir.Context, value: mlir.Value, new_type: mlir.Type, location: mlir.Location) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, op_name, .{
|
||||
.operands = &.{value},
|
||||
.results = &.{new_type},
|
||||
.location = location,
|
||||
});
|
||||
}
|
||||
}.call;
|
||||
}
|
||||
|
||||
pub const addi = binary_fn("arith.addi");
|
||||
pub const addf = binary_fn("arith.addf");
|
||||
pub const subi = binary_fn("arith.subi");
|
||||
pub const subf = binary_fn("arith.subf");
|
||||
pub const muli = binary_fn("arith.muli");
|
||||
pub const mulf = binary_fn("arith.mulf");
|
||||
pub const divsi = binary_fn("arith.divsi");
|
||||
pub const divui = binary_fn("arith.divui");
|
||||
pub const divf = binary_fn("arith.divf");
|
||||
pub const extsi = cast_fn("arith.extsi");
|
||||
pub const extui = cast_fn("arith.extui");
|
||||
pub const extf = cast_fn("arith.extf");
|
||||
pub const trunci = cast_fn("arith.trunci");
|
||||
pub const truncf = cast_fn("arith.truncf");
|
||||
pub const fptosi = cast_fn("arith.fptosi");
|
||||
pub const fptoui = cast_fn("arith.fptoui");
|
||||
pub const sitofp = cast_fn("arith.sitofp");
|
||||
pub const uitofp = cast_fn("arith.uitofp");
|
||||
|
||||
pub const CmpIPredicate = enum {
|
||||
eq,
|
||||
ne,
|
||||
slt,
|
||||
sle,
|
||||
sgt,
|
||||
sge,
|
||||
ult,
|
||||
ule,
|
||||
ugt,
|
||||
uge,
|
||||
};
|
||||
|
||||
pub fn cmpi(ctx: mlir.Context, predicate: CmpIPredicate, lhs: mlir.Value, rhs: mlir.Value, location: mlir.Location) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, "arith.cmpi", .{
|
||||
.operands = &.{ lhs, rhs },
|
||||
.result_type_inference = true,
|
||||
.attributes = &.{
|
||||
.{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute).? },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
}
|
||||
|
||||
pub const CmpFPredicate = enum {
|
||||
false,
|
||||
oeq,
|
||||
ogt,
|
||||
oge,
|
||||
olt,
|
||||
ole,
|
||||
one,
|
||||
ord,
|
||||
ueq,
|
||||
ugt,
|
||||
uge,
|
||||
ult,
|
||||
ule,
|
||||
une,
|
||||
uno,
|
||||
true,
|
||||
};
|
||||
|
||||
pub fn cmpf(ctx: mlir.Context, predicate: CmpFPredicate, lhs: mlir.Value, rhs: mlir.Value, location: mlir.Location) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, "arith.cmpf", .{
|
||||
.operands = &.{ lhs, rhs },
|
||||
.result_type_inference = true,
|
||||
.attributes = &.{
|
||||
.{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute).? },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
}
|
||||
11
mlir/dialects/dialects.zig
Normal file
11
mlir/dialects/dialects.zig
Normal file
@ -0,0 +1,11 @@
|
||||
const std = @import("std");
|
||||
|
||||
pub const arith = @import("arith.zig");
|
||||
pub const func = @import("func.zig");
|
||||
pub const math = @import("math.zig");
|
||||
pub const tensor = @import("tensor.zig");
|
||||
pub const stablehlo = @import("mlir/dialects/stablehlo");
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(@This());
|
||||
}
|
||||
45
mlir/dialects/func.zig
Normal file
45
mlir/dialects/func.zig
Normal file
@ -0,0 +1,45 @@
|
||||
const std = @import("std");
|
||||
const mlir = @import("mlir");
|
||||
|
||||
pub fn func(
|
||||
ctx: mlir.Context,
|
||||
args: struct {
|
||||
sym_name: [:0]const u8,
|
||||
args: []const mlir.Type,
|
||||
arg_attrs: []const mlir.Attribute = &.{},
|
||||
results: []const mlir.Type,
|
||||
block: mlir.Block,
|
||||
location: mlir.Location,
|
||||
},
|
||||
) mlir.Operation {
|
||||
const AttrTuple = struct { [:0]const u8, mlir.Attribute };
|
||||
var attrs_tuple_buffer = std.BoundedArray(AttrTuple, 3){};
|
||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute).? });
|
||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type).?).as(mlir.Attribute).? });
|
||||
if (args.arg_attrs.len > 0) {
|
||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", mlir.ArrayAttribute.init(ctx, args.arg_attrs).as(mlir.Attribute).? });
|
||||
}
|
||||
return mlir.Operation.make(ctx, "func.func", .{
|
||||
.blocks = &.{args.block},
|
||||
.attributes = attrs_tuple_buffer.constSlice(),
|
||||
.location = args.location,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn call(ctx: mlir.Context, name: [:0]const u8, values: []const mlir.Value, results: []const mlir.Type, loc: mlir.Location) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, "func.call", .{
|
||||
.variadic_operands = &.{values},
|
||||
.results = results,
|
||||
.verify = true,
|
||||
.attributes = &.{.{ "callee", mlir.FlatSymbolRefAttribute.init(ctx, name).as(mlir.Attribute).? }},
|
||||
.location = loc,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn return_(ctx: mlir.Context, values: []const mlir.Value, loc: mlir.Location) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, "func.return", .{
|
||||
.operands = values,
|
||||
.verify = false,
|
||||
.location = loc,
|
||||
});
|
||||
}
|
||||
35
mlir/dialects/math.zig
Normal file
35
mlir/dialects/math.zig
Normal file
@ -0,0 +1,35 @@
|
||||
const std = @import("std");
|
||||
const mlir = @import("mlir");
|
||||
|
||||
const namespace = "math";
|
||||
|
||||
fn unary_fn(comptime op_name: [:0]const u8) type {
|
||||
return struct {
|
||||
pub fn call(ctx: mlir.Context, value: mlir.Value, location: mlir.Location) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, namespace ++ "." ++ op_name, .{
|
||||
.operands = &.{value},
|
||||
.results = &.{},
|
||||
.location = location,
|
||||
});
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn binary_fn(comptime op_name: [:0]const u8) type {
|
||||
return struct {
|
||||
pub fn call(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, location: mlir.Location) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, namespace ++ "." ++ op_name, .{
|
||||
.operands = &.{ lhs, rhs },
|
||||
.results = &.{},
|
||||
.location = location,
|
||||
});
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub const ipowi = binary_fn("ipowi").call;
|
||||
pub const fpowi = binary_fn("fpowi").call;
|
||||
pub const tanh = unary_fn("tanh").call;
|
||||
pub const sqrt = unary_fn("sqrt").call;
|
||||
pub const exp = unary_fn("exp").call;
|
||||
pub const log = unary_fn("log").call;
|
||||
1330
mlir/dialects/stablehlo.zig
Normal file
1330
mlir/dialects/stablehlo.zig
Normal file
File diff suppressed because it is too large
Load Diff
36
mlir/dialects/tensor.zig
Normal file
36
mlir/dialects/tensor.zig
Normal file
@ -0,0 +1,36 @@
|
||||
const std = @import("std");
|
||||
const mlir = @import("mlir");
|
||||
|
||||
pub fn empty(ctx: mlir.Context, args: struct {
|
||||
result: mlir.Type,
|
||||
location: mlir.Location,
|
||||
}) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, "tensor.empty", .{
|
||||
.results = &.{args.result},
|
||||
.location = args.location,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn splat(ctx: mlir.Context, args: struct {
|
||||
value: mlir.Value,
|
||||
result: mlir.Type,
|
||||
location: mlir.Location,
|
||||
}) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, "tensor.splat", .{
|
||||
.operands = &.{args.value},
|
||||
.results = &.{args.result},
|
||||
.location = args.location,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn cast(ctx: mlir.Context, args: struct {
|
||||
source: mlir.Value,
|
||||
dest: mlir.Type,
|
||||
location: mlir.Location,
|
||||
}) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, "tensor.cast", .{
|
||||
.operands = &.{args.source},
|
||||
.results = &.{args.dest},
|
||||
.location = args.location,
|
||||
});
|
||||
}
|
||||
1933
mlir/mlir.zig
Normal file
1933
mlir/mlir.zig
Normal file
File diff suppressed because it is too large
Load Diff
27
mlir/mlirx.cc
Normal file
27
mlir/mlirx.cc
Normal file
@ -0,0 +1,27 @@
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
#include "mlirx.h"
|
||||
|
||||
namespace mlirx {
|
||||
|
||||
static mlir::Attribute ArrayToElements(mlir::Attribute attr) {
|
||||
if (auto array = attr.dyn_cast<mlir::DenseI64ArrayAttr>()) {
|
||||
return mlir::DenseIntElementsAttr::get(
|
||||
mlir::RankedTensorType::get(array.size(), array.getElementType()),
|
||||
array.asArrayRef());
|
||||
}
|
||||
if (auto array = attr.dyn_cast<mlir::DenseBoolArrayAttr>()) {
|
||||
return mlir::DenseIntElementsAttr::get(
|
||||
mlir::RankedTensorType::get(array.size(), array.getElementType()),
|
||||
array.asArrayRef());
|
||||
}
|
||||
return attr;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
MlirAttribute mlirDenseArrayToElements(MlirAttribute attr) {
|
||||
return wrap(mlirx::ArrayToElements(unwrap(attr)));
|
||||
}
|
||||
16
mlir/mlirx.h
Normal file
16
mlir/mlirx.h
Normal file
@ -0,0 +1,16 @@
|
||||
#ifndef MLIRX_CC_H
|
||||
#define MLIRX_CC_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseArrayToElements(MlirAttribute attr);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MLIRX_CC_H
|
||||
30
pjrt/BUILD.bazel
Normal file
30
pjrt/BUILD.bazel
Normal file
@ -0,0 +1,30 @@
|
||||
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||
load("//bazel:zig_proto_library.bzl", "zig_proto_library")
|
||||
|
||||
cc_library(
|
||||
name = "dlfcn",
|
||||
hdrs = ["dlfcn.h"],
|
||||
)
|
||||
|
||||
zig_library(
|
||||
name = "pjrt",
|
||||
srcs = ["profiler.zig"],
|
||||
main = "pjrt.zig",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":profiler_options_proto",
|
||||
"//runtimes",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",
|
||||
] + select({
|
||||
"@platforms//os:linux": [":dlfcn"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
)
|
||||
|
||||
zig_proto_library(
|
||||
name = "profiler_options_proto",
|
||||
import_name = "//tsl:profiler_options_proto",
|
||||
deps = ["@tsl//tsl/profiler/protobuf:profiler_options_proto"],
|
||||
)
|
||||
1
pjrt/dlfcn.h
Normal file
1
pjrt/dlfcn.h
Normal file
@ -0,0 +1 @@
|
||||
#include <dlfcn.h>
|
||||
886
pjrt/pjrt.zig
Normal file
886
pjrt/pjrt.zig
Normal file
@ -0,0 +1,886 @@
|
||||
const builtin = @import("builtin");
|
||||
const std = @import("std");
|
||||
|
||||
const c = @import("c");
|
||||
|
||||
const log = std.log.scoped(.pjrt);
|
||||
|
||||
pub const Profiler = @import("profiler.zig").Profiler;
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(@This());
|
||||
}
|
||||
|
||||
// We could calculate it like PJRT does, but it turns out that some of those
|
||||
// were wrong in PJRT itself [1], which gets propagated to binary plugins. In
|
||||
// order to mirror that, we just the value as computed by PJRT itself, through
|
||||
// comptime reflection. We could make the argument to remove that one day since
|
||||
// [1] has been fixed. The problem is that this problem could happen again in
|
||||
// as the way PJRT does it is not very robust.
|
||||
//
|
||||
// 1. https://github.com/openxla/xla/issues/10032
|
||||
fn pjrtStructSize(comptime T: type) usize {
|
||||
// unsafe on purpose, we want this to fail if that ever changes
|
||||
const typedef_name = comptime blk: {
|
||||
const needle = ".struct_";
|
||||
const idx = std.mem.indexOf(u8, @typeName(T), needle).?;
|
||||
break :blk @typeName(T)[idx + needle.len ..];
|
||||
};
|
||||
return @field(c, typedef_name ++ "_STRUCT_SIZE");
|
||||
}
|
||||
|
||||
inline fn pjrtStruct(v: anytype) @TypeOf(v) {
|
||||
var ret = v;
|
||||
ret.struct_size = pjrtStructSize(@TypeOf(v));
|
||||
return ret;
|
||||
}
|
||||
|
||||
pub const ApiError = error{
|
||||
Cancelled,
|
||||
Unknown,
|
||||
InvalidArgument,
|
||||
DeadlineExceeded,
|
||||
NotFound,
|
||||
AlreadyExists,
|
||||
PermissionDenied,
|
||||
ResourceExhausted,
|
||||
FailedPrecondition,
|
||||
Aborted,
|
||||
OutOfRange,
|
||||
Unimplemented,
|
||||
Internal,
|
||||
Unavailable,
|
||||
DataLoss,
|
||||
Unauthenticated,
|
||||
};
|
||||
|
||||
fn InnerMixin(comptime innerT: type) type {
|
||||
return struct {
|
||||
fn inner(self: anytype) *innerT {
|
||||
return @ptrCast(@constCast(@alignCast(self)));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub const Api = struct {
|
||||
pub const Version = struct {
|
||||
major: i64,
|
||||
minor: i64,
|
||||
};
|
||||
|
||||
const Funcs = std.meta.FieldEnum(c.PJRT_Api);
|
||||
|
||||
inner: c.PJRT_Api,
|
||||
|
||||
pub fn loadFrom(library: []const u8) !*const Api {
|
||||
var lib: std.DynLib = switch (builtin.os.tag) {
|
||||
.linux => blk: {
|
||||
const library_c = try std.posix.toPosixPath(library);
|
||||
break :blk .{
|
||||
.inner = .{
|
||||
.handle = c.dlopen(&library_c, c.RTLD_LAZY | c.RTLD_LOCAL | c.RTLD_NODELETE) orelse {
|
||||
return error.FileNotFound;
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
else => try std.DynLib.open(library),
|
||||
};
|
||||
const DynGetPjrtApi = lib.lookup(*const fn () callconv(.C) *const Api, "GetPjrtApi") orelse {
|
||||
std.debug.panic("Unable to find GetPjrtApi symbol in library: {s}", .{library});
|
||||
};
|
||||
|
||||
const api = DynGetPjrtApi();
|
||||
log.info("Loaded library: {s}", .{library});
|
||||
_ = api.call(.PJRT_Plugin_Initialize, .{}) catch unreachable;
|
||||
|
||||
return api;
|
||||
}
|
||||
|
||||
fn CallFnArgType(comptime func: Funcs) type {
|
||||
const fti = @typeInfo(std.meta.FieldType(c.PJRT_Api, func));
|
||||
const fn_ptr = @typeInfo(fti.Optional.child);
|
||||
const fn_type_info = @typeInfo(fn_ptr.Pointer.child);
|
||||
const arg_array_type_info = @typeInfo(fn_type_info.Fn.params[0].type.?);
|
||||
return arg_array_type_info.Pointer.child;
|
||||
}
|
||||
|
||||
inline fn call(self: *const Api, comptime method: Funcs, arg: CallFnArgType(method)) ApiError!@TypeOf(arg) {
|
||||
var ret = pjrtStruct(arg);
|
||||
const fn_ptr = @field(&self.inner, @tagName(method)).?;
|
||||
const result = fn_ptr(&ret);
|
||||
if (@TypeOf(result) == void) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (result) |pjrt_c_error| {
|
||||
const pjrt_error: *Error = @ptrCast(pjrt_c_error);
|
||||
log.err("[{s}] {s}", .{ @tagName(method), pjrt_error.getMessage(self) });
|
||||
return pjrt_error.getCode(self).toApiError();
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
pub fn lookupExtension(self: *const Api, comptime ExtensionT: type, ext_id: c_int) ?*const ExtensionT {
|
||||
var cur: [*c]const c.PJRT_Extension_Base = @alignCast(@ptrCast(self.inner.extension_start));
|
||||
while (cur != null) : (cur = cur.*.next) {
|
||||
if (cur.*.type == ext_id) {
|
||||
return @alignCast(@ptrCast(cur));
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
pub inline fn version(self: *const Api) Version {
|
||||
return .{
|
||||
.major = @intCast(self.inner.pjrt_api_version.major_version),
|
||||
.minor = @intCast(self.inner.pjrt_api_version.minor_version),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
pub const ErrorCode = enum(c.PJRT_Error_Code) {
|
||||
cancelled = c.PJRT_Error_Code_CANCELLED,
|
||||
unknown = c.PJRT_Error_Code_UNKNOWN,
|
||||
invalid_argument = c.PJRT_Error_Code_INVALID_ARGUMENT,
|
||||
deadline_exceeded = c.PJRT_Error_Code_DEADLINE_EXCEEDED,
|
||||
not_found = c.PJRT_Error_Code_NOT_FOUND,
|
||||
already_exists = c.PJRT_Error_Code_ALREADY_EXISTS,
|
||||
permission_denied = c.PJRT_Error_Code_PERMISSION_DENIED,
|
||||
resource_exhausted = c.PJRT_Error_Code_RESOURCE_EXHAUSTED,
|
||||
failed_precondition = c.PJRT_Error_Code_FAILED_PRECONDITION,
|
||||
aborted = c.PJRT_Error_Code_ABORTED,
|
||||
out_of_range = c.PJRT_Error_Code_OUT_OF_RANGE,
|
||||
unimplemented = c.PJRT_Error_Code_UNIMPLEMENTED,
|
||||
internal = c.PJRT_Error_Code_INTERNAL,
|
||||
unavailable = c.PJRT_Error_Code_UNAVAILABLE,
|
||||
data_loss = c.PJRT_Error_Code_DATA_LOSS,
|
||||
unauthenticated = c.PJRT_Error_Code_UNAUTHENTICATED,
|
||||
|
||||
pub fn toApiError(code: ErrorCode) ApiError {
|
||||
return switch (code) {
|
||||
.cancelled => error.Cancelled,
|
||||
.unknown => error.Unknown,
|
||||
.invalid_argument => error.InvalidArgument,
|
||||
.deadline_exceeded => error.DeadlineExceeded,
|
||||
.not_found => error.NotFound,
|
||||
.already_exists => error.AlreadyExists,
|
||||
.permission_denied => error.PermissionDenied,
|
||||
.resource_exhausted => error.ResourceExhausted,
|
||||
.failed_precondition => error.FailedPrecondition,
|
||||
.aborted => error.Aborted,
|
||||
.out_of_range => error.OutOfRange,
|
||||
.unimplemented => error.Unimplemented,
|
||||
.internal => error.Internal,
|
||||
.unavailable => error.Unavailable,
|
||||
.data_loss => error.DataLoss,
|
||||
.unauthenticated => error.Unauthenticated,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
pub const Error = opaque {
|
||||
pub fn deinit(self: *Error, api: *const Api) void {
|
||||
_ = api.call(.PJRT_Error_Destroy, .{
|
||||
.@"error" = @ptrCast(self),
|
||||
}) catch unreachable;
|
||||
}
|
||||
|
||||
pub fn getCode(self: *Error, api: *const Api) ErrorCode {
|
||||
const ret = api.call(.PJRT_Error_GetCode, .{
|
||||
.@"error" = @ptrCast(self),
|
||||
}) catch unreachable;
|
||||
return @enumFromInt(ret.code);
|
||||
}
|
||||
|
||||
pub fn getMessage(self: *Error, api: *const Api) []const u8 {
|
||||
const ret = api.call(.PJRT_Error_Message, .{
|
||||
.@"error" = @ptrCast(self),
|
||||
}) catch unreachable;
|
||||
return ret.message[0..ret.message_size];
|
||||
}
|
||||
};
|
||||
|
||||
pub const ClientInitError = error{LoadingFailed} || ApiError;
|
||||
|
||||
pub const Client = opaque {
|
||||
const inner = InnerMixin(c.PJRT_Client).inner;
|
||||
|
||||
pub const ProgramFormat = enum {
|
||||
hlo,
|
||||
mlir,
|
||||
};
|
||||
|
||||
pub fn init(api: *const Api, create_options: []const NamedValue) ClientInitError!*Client {
|
||||
// log.info("Loaded PJRT runtime plugin: {s}", .{api.Platform});
|
||||
const ret = try api.call(.PJRT_Client_Create, .{
|
||||
.create_options = @ptrCast(create_options.ptr),
|
||||
.num_options = create_options.len,
|
||||
.kv_get_callback = null,
|
||||
.kv_put_callback = null,
|
||||
.kv_put_user_arg = null,
|
||||
.kv_get_user_arg = null,
|
||||
});
|
||||
return @ptrCast(ret.client.?);
|
||||
}
|
||||
|
||||
pub fn deinit(self: *Client, api: *const Api) void {
|
||||
_ = api.call(.PJRT_Client_Destroy, .{
|
||||
.client = self.inner(),
|
||||
}) catch {};
|
||||
}
|
||||
|
||||
pub fn getPlatformName(self: *const Client, api: *const Api) []const u8 {
|
||||
const ret = api.call(.PJRT_Client_PlatformName, .{
|
||||
.client = self.inner(),
|
||||
}) catch unreachable;
|
||||
return ret.platform_name[0..ret.platform_name_size];
|
||||
}
|
||||
|
||||
pub fn getDevices(self: *const Client, api: *const Api) []const *Device {
|
||||
const ret = api.call(.PJRT_Client_Devices, .{
|
||||
.client = self.inner(),
|
||||
}) catch unreachable;
|
||||
return @ptrCast(ret.devices[0..ret.num_devices]);
|
||||
}
|
||||
|
||||
pub fn getAddressableDevices(self: *const Client, api: *const Api) []const *Device {
|
||||
const ret = api.call(.PJRT_Client_AddressableDevices, .{
|
||||
.client = self.inner(),
|
||||
}) catch unreachable;
|
||||
return @ptrCast(ret.addressable_devices[0..ret.num_addressable_devices]);
|
||||
}
|
||||
|
||||
pub const CompileArgs = struct {
|
||||
bytecode: []const u8,
|
||||
bytecode_format: ProgramFormat,
|
||||
compile_options_pb: []const u8,
|
||||
};
|
||||
|
||||
pub fn compile(self: *const Client, api: *const Api, args: CompileArgs) ApiError!*LoadedExecutable {
|
||||
const bytecode_format_ = @tagName(args.bytecode_format);
|
||||
const ret = try api.call(.PJRT_Client_Compile, .{
|
||||
.program = &pjrtStruct(c.PJRT_Program{
|
||||
.code = @ptrCast(@constCast(args.bytecode.ptr)),
|
||||
.code_size = args.bytecode.len,
|
||||
.format = @ptrCast(@constCast(bytecode_format_.ptr)),
|
||||
.format_size = bytecode_format_.len,
|
||||
}),
|
||||
.compile_options = @ptrCast(@constCast(args.compile_options_pb.ptr)),
|
||||
.compile_options_size = args.compile_options_pb.len,
|
||||
.client = self.inner(),
|
||||
});
|
||||
return @ptrCast(ret.executable.?);
|
||||
}
|
||||
|
||||
pub const BufferFromHostBufferArgs = struct {
|
||||
data: []const u8,
|
||||
buffer_type: BufferType,
|
||||
dims: []const i64,
|
||||
byte_strides: ?[]const i64,
|
||||
device: *const Device,
|
||||
host_buffer_semantics: HostBufferSemantics,
|
||||
};
|
||||
|
||||
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, *Event } {
|
||||
const ret = try api.call(.PJRT_Client_BufferFromHostBuffer, .{
|
||||
.client = self.inner(),
|
||||
.data = @ptrCast(@constCast(args.data.ptr)),
|
||||
.type = @intFromEnum(args.buffer_type),
|
||||
.dims = @ptrCast(@constCast(args.dims.ptr)),
|
||||
.num_dims = args.dims.len,
|
||||
.byte_strides = if (args.byte_strides) |bs| @ptrCast(@constCast(bs.ptr)) else null,
|
||||
.num_byte_strides = if (args.byte_strides) |bs| bs.len else 0,
|
||||
.host_buffer_semantics = @intFromEnum(args.host_buffer_semantics),
|
||||
.device = @ptrCast(@constCast(args.device)),
|
||||
.memory = null, // TODO
|
||||
.device_layout = null, // TODO
|
||||
.done_with_host_buffer = null,
|
||||
.buffer = null,
|
||||
});
|
||||
return .{
|
||||
@ptrCast(ret.buffer.?),
|
||||
@ptrCast(ret.done_with_host_buffer.?),
|
||||
};
|
||||
}
|
||||
|
||||
/// Returns the Profiler for this API.
|
||||
/// Not all platform have a profiling api, for those the profiler object will do nothing.
|
||||
/// Platforms with known profiler extensions: cuda, xpu
|
||||
pub fn getProfiler(self: *const Client, api: *const Api, options: Profiler.Options) Profiler {
|
||||
if (api.version().minor >= 45) {
|
||||
if (api.lookupExtension(c.PJRT_Profiler_Extension, c.PJRT_Extension_Type_Profiler)) |ext| {
|
||||
return Profiler.init(ext.profiler_api.*, options);
|
||||
}
|
||||
}
|
||||
log.warn("No profiler found for platform: {}", .{self});
|
||||
return Profiler.init(null, options);
|
||||
}
|
||||
|
||||
// pub fn getGpuCustomCallRegistry(self: *const Client, api: *const Api) ?*GpuCustomCallRegistry {
|
||||
// if (api.lookupExtension(c.PJRT_Gpu_Custom_Call, c.PJRT_Extension_Type_Gpu_Custom_Call)) |ext| {
|
||||
// return .{ .custom_call_register = ext.custom_call.? };
|
||||
// }
|
||||
// log.warn("No Gpu Custom Call registry found for platform: {}", .{self});
|
||||
// return null;
|
||||
// }
|
||||
|
||||
pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {
|
||||
const ret = try api.call(.PJRT_Executable_DeserializeAndLoad, .{
|
||||
.client = self.inner(),
|
||||
.serialized_executable = bytes.ptr,
|
||||
.serialized_executable_size = bytes.len,
|
||||
});
|
||||
return @ptrCast(ret.loaded_executable.?);
|
||||
}
|
||||
|
||||
pub const CreateViewOfDeviceBufferArgs = struct {
|
||||
data: []const u8,
|
||||
dims: []const i64,
|
||||
element_type: BufferType,
|
||||
layout: MemoryLayout,
|
||||
device: *const Device,
|
||||
on_delete_callback: ?*const fn (device_buffer_ptr: ?*anyopaque, ctx: ?*anyopaque) callconv(.C) void = null,
|
||||
on_delete_callback_arg: ?*anyopaque = null,
|
||||
stream: ?isize = null,
|
||||
};
|
||||
|
||||
pub fn createViewOfDeviceBuffer(self: *const Client, api: *const Api, args: CreateViewOfDeviceBufferArgs) ApiError!*Buffer {
|
||||
const layout = args.layout.toCStruct();
|
||||
const ret = try api.call(.PJRT_Client_CreateViewOfDeviceBuffer, .{
|
||||
.client = self.inner(),
|
||||
.device_buffer_ptr = @ptrCast(@constCast(args.data.ptr)),
|
||||
.dims = args.dims.ptr,
|
||||
.num_dims = args.dims.len,
|
||||
.element_type = @intFromEnum(args.element_type),
|
||||
.layout = @ptrCast(@constCast(&layout)),
|
||||
.device = @ptrCast(@constCast(args.device)),
|
||||
.on_delete_callback = args.on_delete_callback,
|
||||
.on_delete_callback_arg = args.on_delete_callback_arg,
|
||||
.stream = if (args.stream) |stream| stream else 0,
|
||||
});
|
||||
return @ptrCast(ret.buffer.?);
|
||||
}
|
||||
};
|
||||
|
||||
// // pub const CustomCallSignature = *const fn (*anyopaque, **anyopaque, [*c]const u8, usize) callconv(.C) void;
|
||||
|
||||
// // pub const GpuCustomCallRegistry = struct {
|
||||
// // custom_call_register: *const c.PJRT_Gpu_Register_Custom_Call,
|
||||
|
||||
// // pub fn registerCustomCall(self: GpuCustomCallRegistry, api: *const Api, api_version: usize, name: []const u8, func: CustomCallSignature) ApiError!void {
|
||||
// // var ret = pjrtStruct(c.PJRT_Gpu_Register_Custom_Call_Args{
|
||||
// // .function_name = name.ptr,
|
||||
// // .function_name_size = name.len,
|
||||
// // .api_version = @intCast(api_version),
|
||||
// // .custom_call_function = @ptrCast(@constCast(func)),
|
||||
// // });
|
||||
// // const result = self.custom_call_register(&ret);
|
||||
// // if (result) |pjrt_c_error| {
|
||||
// // const pjrt_error = .{ .inner = pjrt_c_error };
|
||||
// // log.err("{s}", .{pjrt_error.getMessage(api)});
|
||||
// // return pjrt_error.getCode().toApiError();
|
||||
// // }
|
||||
// // }
|
||||
// // };
|
||||
|
||||
// // const OldPjrtExtension = extern struct {
|
||||
// // type: c.PJRT_Extension_Type,
|
||||
// // next: [*]OldPjrtExtension,
|
||||
// // };
|
||||
|
||||
pub const Device = opaque {
|
||||
const inner = InnerMixin(c.PJRT_Device).inner;
|
||||
|
||||
pub fn getDescription(self: *const Device, api: *const Api) *const DeviceDescription {
|
||||
const ret = api.call(.PJRT_Device_GetDescription, .{
|
||||
.device = self.inner(),
|
||||
}) catch unreachable;
|
||||
return @ptrCast(ret.device_description.?);
|
||||
}
|
||||
|
||||
pub fn isAddressable(self: *const Device, api: *const Api) bool {
|
||||
const ret = api.call(.PJRT_Device_IsAddressable, .{
|
||||
.device = self.inner(),
|
||||
}) catch unreachable;
|
||||
return ret.is_addressable;
|
||||
}
|
||||
|
||||
pub fn getLocalHardwareId(self: *const Device, api: *const Api) usize {
|
||||
const ret = api.call(.PJRT_Device_LocalHardwareId, .{
|
||||
.device = self.inner(),
|
||||
}) catch unreachable;
|
||||
return @intCast(ret.local_hardware_id);
|
||||
}
|
||||
};
|
||||
|
||||
pub const DeviceDescription = opaque {
|
||||
const inner = InnerMixin(c.PJRT_DeviceDescription).inner;
|
||||
|
||||
pub fn getId(self: *const DeviceDescription, api: *const Api) usize {
|
||||
const ret = api.call(.PJRT_DeviceDescription_Id, .{
|
||||
.device_description = self.inner(),
|
||||
}) catch unreachable;
|
||||
return @intCast(ret.id);
|
||||
}
|
||||
|
||||
pub fn getProcessIndex(self: *const DeviceDescription, api: *const Api) usize {
|
||||
const ret = api.call(.PJRT_DeviceDescription_ProcessIndex, .{
|
||||
.device_description = self.inner(),
|
||||
}) catch unreachable;
|
||||
return @intCast(ret.process_index);
|
||||
}
|
||||
|
||||
pub fn getKind(self: *const DeviceDescription, api: *const Api) []const u8 {
|
||||
const ret = api.call(.PJRT_DeviceDescription_Kind, .{
|
||||
.device_description = self.inner(),
|
||||
}) catch unreachable;
|
||||
return ret.device_kind[0..ret.device_kind_size];
|
||||
}
|
||||
|
||||
pub fn debugString(self: *const DeviceDescription, api: *const Api) []const u8 {
|
||||
const ret = api.call(.PJRT_DeviceDescription_DebugString, .{
|
||||
.device_description = self.inner(),
|
||||
}) catch unreachable;
|
||||
return ret.debug_string[0..ret.debug_string_size];
|
||||
}
|
||||
|
||||
pub fn toString(self: *const DeviceDescription, api: *const Api) []const u8 {
|
||||
const ret = api.call(.PJRT_DeviceDescription_ToString, .{
|
||||
.device_description = self.inner(),
|
||||
}) catch unreachable;
|
||||
return ret.to_string[0..ret.to_string_size];
|
||||
}
|
||||
};
|
||||
|
||||
pub const GetCostAnalysisError = std.mem.Allocator.Error || ApiError;
|
||||
|
||||
pub const SerializeResult = struct {
|
||||
bytes: []const u8,
|
||||
handle: *anyopaque,
|
||||
deleter: *const fn (?*anyopaque) callconv(.C) void,
|
||||
|
||||
pub fn deinit(self: *SerializeResult) void {
|
||||
self.deleter(self.handle);
|
||||
self.bytes = &.{};
|
||||
self.* = undefined;
|
||||
}
|
||||
};
|
||||
|
||||
pub const Executable = opaque {
|
||||
const inner = InnerMixin(c.PJRT_Executable).inner;
|
||||
|
||||
pub fn deinit(self: *Executable, api: *const Api) void {
|
||||
_ = api.call(.PJRT_Executable_Destroy, .{
|
||||
.executable = self.inner(),
|
||||
}) catch unreachable;
|
||||
}
|
||||
|
||||
pub fn getCostAnalysis(self: *const Executable, api: *const Api) GetCostAnalysisError![]*const NamedValue {
|
||||
const ret = try api.call(.PJRT_Executable_GetCostAnalysis, .{
|
||||
.executable = self.inner(),
|
||||
});
|
||||
const values: [*]*const NamedValue = @ptrCast(ret.properties);
|
||||
return values[0..ret.num_properties];
|
||||
}
|
||||
|
||||
pub fn serialize(self: *const Executable, api: *const Api) ApiError!SerializeResult {
|
||||
const ret = try api.call(.PJRT_Executable_Serialize, .{
|
||||
.executable = self.inner(),
|
||||
});
|
||||
|
||||
return .{
|
||||
.bytes = ret.serialized_bytes[0..ret.serialized_bytes_size],
|
||||
.handle = ret.serialized_executable.?,
|
||||
.deleter = @ptrCast(ret.serialized_executable_deleter.?),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
pub const LoadedExecutable = opaque {
|
||||
const inner = InnerMixin(c.PJRT_LoadedExecutable).inner;
|
||||
|
||||
pub fn deinit(self: *LoadedExecutable, api: *const Api) void {
|
||||
_ = api.call(.PJRT_LoadedExecutable_Destroy, .{
|
||||
.executable = self.inner(),
|
||||
}) catch {};
|
||||
self.* = undefined;
|
||||
}
|
||||
|
||||
pub fn delete(self: *LoadedExecutable, api: *const Api) void {
|
||||
_ = api.call(.PJRT_LoadedExecutable_Delete, .{
|
||||
.executable = self.inner(),
|
||||
}) catch unreachable;
|
||||
}
|
||||
|
||||
pub fn isDeleted(self: *const LoadedExecutable, api: *const Api) bool {
|
||||
const ret = api.call(.PJRT_LoadedExecutable_IsDeleted, .{
|
||||
.executable = self.inner(),
|
||||
}) catch unreachable;
|
||||
return ret.is_deleted;
|
||||
}
|
||||
|
||||
pub fn getAddressableDevices(self: *const LoadedExecutable, api: *const Api) []Device {
|
||||
const ret = api.call(.PJRT_LoadedExecutable_AddressableDevices, .{
|
||||
.executable = self.inner(),
|
||||
}) catch unreachable;
|
||||
return @ptrCast(ret.addressable_devices);
|
||||
}
|
||||
|
||||
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: struct {
|
||||
num_args: usize,
|
||||
arguments: []const [*]const *const Buffer,
|
||||
results: []const [*]*Buffer,
|
||||
events: []*Event,
|
||||
non_donatable_input_indices: []const i64 = &.{},
|
||||
}) ApiError!void {
|
||||
var options = pjrtStruct(c.PJRT_ExecuteOptions{
|
||||
.send_callbacks = null,
|
||||
.recv_callbacks = null,
|
||||
.num_send_ops = 0,
|
||||
.num_recv_ops = 0,
|
||||
.launch_id = 0,
|
||||
.non_donatable_input_indices = @ptrCast(args.non_donatable_input_indices.ptr),
|
||||
.num_non_donatable_input_indices = args.non_donatable_input_indices.len,
|
||||
});
|
||||
_ = try api.call(.PJRT_LoadedExecutable_Execute, .{
|
||||
.executable = self.inner(),
|
||||
.options = @ptrCast(&options),
|
||||
.argument_lists = @ptrCast(args.arguments.ptr),
|
||||
.num_devices = @intCast(args.arguments.len),
|
||||
.num_args = args.num_args,
|
||||
.output_lists = @ptrCast(args.results.ptr),
|
||||
.device_complete_events = @ptrCast(args.events.ptr),
|
||||
.execute_device = null,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn getExecutable(self: *LoadedExecutable, api: *const Api) ApiError!*Executable {
|
||||
const ret = try api.call(.PJRT_LoadedExecutable_GetExecutable, .{
|
||||
.loaded_executable = self.inner(),
|
||||
});
|
||||
return @ptrCast(ret.executable.?);
|
||||
}
|
||||
};
|
||||
|
||||
pub const BufferType = enum(c.PJRT_Buffer_Type) {
|
||||
INVALID = c.PJRT_Buffer_Type_INVALID,
|
||||
PRED = c.PJRT_Buffer_Type_PRED,
|
||||
S8 = c.PJRT_Buffer_Type_S8,
|
||||
S16 = c.PJRT_Buffer_Type_S16,
|
||||
S32 = c.PJRT_Buffer_Type_S32,
|
||||
S64 = c.PJRT_Buffer_Type_S64,
|
||||
U8 = c.PJRT_Buffer_Type_U8,
|
||||
U16 = c.PJRT_Buffer_Type_U16,
|
||||
U32 = c.PJRT_Buffer_Type_U32,
|
||||
U64 = c.PJRT_Buffer_Type_U64,
|
||||
F16 = c.PJRT_Buffer_Type_F16,
|
||||
F32 = c.PJRT_Buffer_Type_F32,
|
||||
F64 = c.PJRT_Buffer_Type_F64,
|
||||
BF16 = c.PJRT_Buffer_Type_BF16,
|
||||
C64 = c.PJRT_Buffer_Type_C64,
|
||||
C128 = c.PJRT_Buffer_Type_C128,
|
||||
F8E5M2 = c.PJRT_Buffer_Type_F8E5M2,
|
||||
F8E4M3FN = c.PJRT_Buffer_Type_F8E4M3FN,
|
||||
F8E4M3B11FNUZ = c.PJRT_Buffer_Type_F8E4M3B11FNUZ,
|
||||
F8E5M2FNUZ = c.PJRT_Buffer_Type_F8E5M2FNUZ,
|
||||
F8E4M3FNUZ = c.PJRT_Buffer_Type_F8E4M3FNUZ,
|
||||
S4 = c.PJRT_Buffer_Type_S4,
|
||||
U4 = c.PJRT_Buffer_Type_U4,
|
||||
};
|
||||
|
||||
pub const MemoryLayoutType = enum(c.PJRT_Buffer_MemoryLayout_Type) {
|
||||
Tiled = c.PJRT_Buffer_MemoryLayout_Type_Tiled,
|
||||
Strides = c.PJRT_Buffer_MemoryLayout_Type_Strides,
|
||||
};
|
||||
|
||||
pub const MemoryLayout = union(MemoryLayoutType) {
|
||||
pub const Type = MemoryLayoutType;
|
||||
|
||||
pub const Tiled = struct {
|
||||
minor_to_major: []const i64,
|
||||
tile_dims: []const i64,
|
||||
tile_dims_sizes: []const usize,
|
||||
};
|
||||
|
||||
pub const Strides = struct {
|
||||
byte_strides: []const i64,
|
||||
};
|
||||
|
||||
Tiled: Tiled,
|
||||
Strides: Strides,
|
||||
|
||||
fn toCStruct(self: MemoryLayout) c.PJRT_Buffer_MemoryLayout {
|
||||
return pjrtStruct(switch (self) {
|
||||
.Tiled => |v| c.PJRT_Buffer_MemoryLayout{
|
||||
.type = c.PJRT_Buffer_MemoryLayout_Type_Tiled,
|
||||
.unnamed_0 = .{
|
||||
.tiled = c.PJRT_Buffer_MemoryLayout_Tiled{
|
||||
.minor_to_major = v.minor_to_major.ptr,
|
||||
.minor_to_major_size = v.minor_to_major.len,
|
||||
.tile_dims = v.tile_dims.ptr,
|
||||
.tile_dim_sizes = v.tile_dims_sizes.ptr,
|
||||
.num_tiles = v.tile_dims_sizes.len,
|
||||
},
|
||||
},
|
||||
},
|
||||
.Strides => |v| c.PJRT_Buffer_MemoryLayout{
|
||||
.type = c.PJRT_Buffer_MemoryLayout_Type_Strides,
|
||||
.unnamed_0 = .{
|
||||
.strides = c.PJRT_Buffer_MemoryLayout_Strides{
|
||||
.byte_strides = v.byte_strides.ptr,
|
||||
.num_byte_strides = v.byte_strides.len,
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
pub const HostBufferSemantics = enum(c.PJRT_HostBufferSemantics) {
|
||||
ImmutableOnlyDuringCall = c.PJRT_HostBufferSemantics_kImmutableOnlyDuringCall,
|
||||
ImmutableUntilTransferCompletes = c.PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes,
|
||||
ImmutableZeroCopy = c.PJRT_HostBufferSemantics_kImmutableZeroCopy,
|
||||
MutableZeroCopy = c.PJRT_HostBufferSemantics_kMutableZeroCopy,
|
||||
};
|
||||
|
||||
pub const Buffer = opaque {
|
||||
const inner = InnerMixin(c.PJRT_Buffer).inner;
|
||||
|
||||
pub fn deinit(self: *Buffer, api: *const Api) void {
|
||||
_ = api.call(.PJRT_Buffer_Destroy, .{
|
||||
.buffer = self.inner(),
|
||||
}) catch unreachable;
|
||||
}
|
||||
|
||||
pub fn getDevice(self: *const Buffer, api: *const Api) ApiError!*Device {
|
||||
const ret = try api.call(.PJRT_Buffer_Device, .{
|
||||
.buffer = self.inner(),
|
||||
});
|
||||
return @ptrCast(ret.device.?);
|
||||
}
|
||||
|
||||
pub fn delete(self: *Buffer, api: *const Api) void {
|
||||
_ = api.call(.PJRT_Buffer_Delete, .{
|
||||
.buffer = self.inner(),
|
||||
}) catch unreachable;
|
||||
}
|
||||
|
||||
pub fn isDeleted(self: *const Buffer, api: *const Api) bool {
|
||||
const ret = api.call(.PJRT_Buffer_IsDeleted, .{
|
||||
.buffer = self.inner(),
|
||||
}) catch unreachable;
|
||||
return ret.is_deleted;
|
||||
}
|
||||
|
||||
pub fn isOnCpu(self: *const Buffer, api: *const Api) bool {
|
||||
const ret = api.call(.PJRT_Buffer_IsOnCpu, .{
|
||||
.buffer = self.inner(),
|
||||
}) catch unreachable;
|
||||
return ret.is_on_cpu;
|
||||
}
|
||||
|
||||
pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!*Event {
|
||||
const ret = try api.call(.PJRT_Buffer_ToHostBuffer, .{
|
||||
.src = self.inner(),
|
||||
.dst = @ptrCast(dst.ptr),
|
||||
.dst_size = dst.len,
|
||||
});
|
||||
return @ptrCast(ret.event.?);
|
||||
}
|
||||
|
||||
pub fn getElementType(self: *const Buffer, api: *const Api) BufferType {
|
||||
const ret = api.call(.PJRT_Buffer_ElementType, .{
|
||||
.buffer = self.inner(),
|
||||
}) catch unreachable;
|
||||
return @enumFromInt(ret.type);
|
||||
}
|
||||
|
||||
pub fn getDimensions(self: *const Buffer, api: *const Api) []const i64 {
|
||||
const ret = api.call(.PJRT_Buffer_Dimensions, .{
|
||||
.buffer = self.inner(),
|
||||
}) catch unreachable;
|
||||
return ret.dims[0..ret.num_dims];
|
||||
}
|
||||
|
||||
pub fn getUnpaddedDimensions(self: *const Buffer, api: *const Api) ApiError![]const i64 {
|
||||
const ret = try api.call(.PJRT_Buffer_UnpaddedDimensions, .{
|
||||
.buffer = self.inner(),
|
||||
});
|
||||
return ret.dims[0..ret.num_dims];
|
||||
}
|
||||
|
||||
pub fn getOnDeviceSizeInBytes(self: *const Buffer, api: *const Api) ApiError!usize {
|
||||
const ret = try api.call(.PJRT_Buffer_OnDeviceSizeInBytes, .{
|
||||
.buffer = self.inner(),
|
||||
});
|
||||
return ret.on_device_size_in_bytes;
|
||||
}
|
||||
|
||||
pub fn copyToDevice(self: *const Buffer, api: *const Api, device: Device) ApiError!Buffer {
|
||||
const ret = try api.call(.PJRT_Buffer_CopyToDevice, .{
|
||||
.buffer = self.inner(),
|
||||
.dst_device = device.inner,
|
||||
});
|
||||
return @ptrCast(ret.dst_buffer.?);
|
||||
}
|
||||
|
||||
pub fn getReadyEvent(self: *const Buffer, api: *const Api) *Event {
|
||||
const ret = api.call(.PJRT_Buffer_ReadyEvent, .{
|
||||
.buffer = self.inner(),
|
||||
}) catch unreachable;
|
||||
return @ptrCast(ret.event.?);
|
||||
}
|
||||
|
||||
pub fn getOpaqueDeviceMemoryDataPointer(self: *const Buffer, api: *const Api) ApiError!*anyopaque {
|
||||
const ret = try api.call(.PJRT_Buffer_OpaqueDeviceMemoryDataPointer, .{
|
||||
.buffer = self.inner(),
|
||||
});
|
||||
return ret.device_memory_ptr.?;
|
||||
}
|
||||
};
|
||||
|
||||
pub const Event = opaque {
|
||||
const inner = InnerMixin(c.PJRT_Event).inner;
|
||||
|
||||
pub fn deinit(self: *Event, api: *const Api) void {
|
||||
_ = api.call(.PJRT_Event_Destroy, .{
|
||||
.event = self.inner(),
|
||||
}) catch unreachable;
|
||||
}
|
||||
|
||||
pub fn isReady(self: *const Event, api: *const Api) bool {
|
||||
const ret = api.call(.PJRT_Event_IsReady, .{
|
||||
.event = self.inner(),
|
||||
}) catch unreachable;
|
||||
return ret.is_ready;
|
||||
}
|
||||
|
||||
pub fn getEventError(self: *const Event, api: *const Api) ApiError!?*Error {
|
||||
const ret = try api.call(.PJRT_Event_Error, .{
|
||||
.event = self.inner(),
|
||||
});
|
||||
return @ptrCast(ret);
|
||||
}
|
||||
|
||||
pub fn await_(self: *const Event, api: *const Api) ApiError!void {
|
||||
_ = try api.call(.PJRT_Event_Await, .{
|
||||
.event = self.inner(),
|
||||
});
|
||||
}
|
||||
|
||||
pub fn onReady(self: *Event, api: *const Api, func: *const fn (err: ?*Error, user_arg: ?*anyopaque) callconv(.C) void, user_arg: ?*anyopaque) ApiError!void {
|
||||
_ = try api.call(.PJRT_Event_OnReady, .{
|
||||
.event = self.inner(),
|
||||
.callback = @ptrCast(func),
|
||||
.user_arg = user_arg,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
pub const NamedValue = extern struct {
|
||||
comptime {
|
||||
std.debug.assert(@sizeOf(NamedValue) == @sizeOf(c.PJRT_NamedValue));
|
||||
}
|
||||
|
||||
inner: c.PJRT_NamedValue,
|
||||
|
||||
pub const Kind = enum(c.PJRT_NamedValue_Type) {
|
||||
string = c.PJRT_NamedValue_kString,
|
||||
int64 = c.PJRT_NamedValue_kInt64,
|
||||
int64list = c.PJRT_NamedValue_kInt64List,
|
||||
float = c.PJRT_NamedValue_kFloat,
|
||||
bool = c.PJRT_NamedValue_kBool,
|
||||
};
|
||||
|
||||
pub fn kind(self: NamedValue) Kind {
|
||||
return @enumFromInt(self.inner.type);
|
||||
}
|
||||
|
||||
pub fn name(self: NamedValue) []const u8 {
|
||||
return self.inner.name[0..self.inner.name_size];
|
||||
}
|
||||
|
||||
pub fn from(name_: []const u8, value: anytype) NamedValue {
|
||||
return switch (@TypeOf(value)) {
|
||||
[]u8, []const u8 => fromString(name_, value),
|
||||
i64 => fromInt64(name_, value),
|
||||
[]i64, []const i64 => fromInt64List(name_, value),
|
||||
f32 => fromFloat(name_, value),
|
||||
bool => fromBool(name_, value),
|
||||
else => unreachable,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn fromString(name_: []const u8, value: []const u8) NamedValue {
|
||||
return .{ .inner = pjrtStruct(c.PJRT_NamedValue{
|
||||
.name = @ptrCast(@constCast(name_.ptr)),
|
||||
.name_size = name_.len,
|
||||
.type = c.PJRT_NamedValue_kString,
|
||||
.unnamed_0 = .{ .string_value = @ptrCast(@constCast(value.ptr)) },
|
||||
.value_size = value.len,
|
||||
}) };
|
||||
}
|
||||
|
||||
pub fn fromInt64(name_: []const u8, value: i64) NamedValue {
|
||||
return .{ .inner = pjrtStruct(c.PJRT_NamedValue{
|
||||
.name = @ptrCast(@constCast(name_.ptr)),
|
||||
.name_size = name_.len,
|
||||
.type = c.PJRT_NamedValue_kInt64,
|
||||
.unnamed_0 = .{ .int64_value = value },
|
||||
.value_size = 1,
|
||||
}) };
|
||||
}
|
||||
|
||||
pub fn fromInt64List(name_: []const u8, value: []const i64) NamedValue {
|
||||
return .{ .inner = pjrtStruct(c.PJRT_NamedValue{
|
||||
.name = @ptrCast(@constCast(name_.ptr)),
|
||||
.name_size = name_.len,
|
||||
.type = c.PJRT_NamedValue_kInt64List,
|
||||
.unnamed_0 = .{ .int64_array_value = @ptrCast(@constCast(value.ptr)) },
|
||||
.value_size = value.len,
|
||||
}) };
|
||||
}
|
||||
|
||||
pub fn fromFloat(name_: []const u8, value: f32) NamedValue {
|
||||
return .{ .inner = pjrtStruct(c.PJRT_NamedValue{
|
||||
.name = @ptrCast(@constCast(name_.ptr)),
|
||||
.name_size = name_.len,
|
||||
.type = c.PJRT_NamedValue_kFloat,
|
||||
.unnamed_0 = .{ .float_value = value },
|
||||
.value_size = 1,
|
||||
}) };
|
||||
}
|
||||
|
||||
pub fn fromBool(name_: []const u8, value: bool) NamedValue {
|
||||
return .{ .inner = pjrtStruct(c.PJRT_NamedValue{
|
||||
.name = @ptrCast(@constCast(name_.ptr)),
|
||||
.name_size = name_.len,
|
||||
.type = c.PJRT_NamedValue_kBool,
|
||||
.unnamed_0 = .{ .bool_value = value },
|
||||
.value_size = 1,
|
||||
}) };
|
||||
}
|
||||
|
||||
pub fn format(
|
||||
self: NamedValue,
|
||||
comptime fmt: []const u8,
|
||||
options: std.fmt.FormatOptions,
|
||||
writer: anytype,
|
||||
) !void {
|
||||
_ = fmt;
|
||||
_ = options;
|
||||
try writer.print("{s}{{ .name = {s},", .{ @typeName(NamedValue), self.inner.name[0..self.inner.name_size] });
|
||||
const u = self.inner.unnamed_0;
|
||||
switch (self.kind()) {
|
||||
.string => try writer.print(" .string = {s} ", .{u.string_value[0..self.inner.value_size]}),
|
||||
.int64 => try writer.print(" .int64 = {d} ", .{u.int64_value}),
|
||||
.int64list => try writer.print(" .int64list = {d} ", .{u.int64_array_value[0..self.inner.value_size]}),
|
||||
.float => try writer.print(" .float = {d} ", .{u.float_value}),
|
||||
.bool => try writer.print(" .bool = {} ", .{u.bool_value}),
|
||||
}
|
||||
try writer.writeAll("}");
|
||||
}
|
||||
};
|
||||
205
pjrt/profiler.zig
Normal file
205
pjrt/profiler.zig
Normal file
@ -0,0 +1,205 @@
|
||||
const std = @import("std");
|
||||
const c = @import("c");
|
||||
const tsl_proto = @import("//tsl:profiler_options_proto");
|
||||
|
||||
const log = std.log.scoped(.zml_profiler);
|
||||
|
||||
/// Pjrt Profiler extension
|
||||
pub const Profiler = struct {
|
||||
api: ?c.PLUGIN_Profiler_Api,
|
||||
inner: *c.PLUGIN_Profiler,
|
||||
last_error: ?*Error = null,
|
||||
status: Status = .ready,
|
||||
|
||||
pub const Status = enum { ready, started, stopped, done };
|
||||
pub const Error = c.PLUGIN_Profiler_Error;
|
||||
pub const Options = tsl_proto.ProfileOptions;
|
||||
|
||||
pub fn init(api: ?c.PLUGIN_Profiler_Api, options: Options) Profiler {
|
||||
if (api == null) {
|
||||
return .{ .api = null, .inner = undefined };
|
||||
}
|
||||
|
||||
var buffer: [std.fs.max_path_bytes + @sizeOf(Options) * 4]u8 = undefined;
|
||||
var fba = std.heap.FixedBufferAllocator.init(&buffer);
|
||||
const byte_options = options.encode(fba.allocator()) catch unreachable;
|
||||
var res: Profiler = .{ .api = api, .inner = undefined };
|
||||
var args: c.PLUGIN_Profiler_Create_Args = .{
|
||||
.options = byte_options.ptr,
|
||||
.options_size = byte_options.len,
|
||||
.profiler = undefined, // out
|
||||
};
|
||||
res.check(api.?.create.?(&args)) catch unreachable;
|
||||
|
||||
res.inner = args.profiler.?;
|
||||
return res;
|
||||
}
|
||||
|
||||
fn transition(self: *Profiler, fn_name: []const u8, expected: Status, next: Status) void {
|
||||
if (self.status == expected) {
|
||||
self.status = next;
|
||||
return;
|
||||
}
|
||||
std.debug.panic("Profiler can't `{s}()`. Current status: {}, expected: {}", .{ fn_name, self.status, expected });
|
||||
}
|
||||
|
||||
pub fn start(self: *Profiler) void {
|
||||
self.transition("start", .ready, .started);
|
||||
if (self.api == null) return;
|
||||
var args: c.PLUGIN_Profiler_Start_Args = .{ .profiler = self.inner };
|
||||
self.check(self.api.?.start.?(&args)) catch unreachable;
|
||||
}
|
||||
|
||||
pub fn stop(self: *Profiler) void {
|
||||
self.transition("stop", .started, .stopped);
|
||||
if (self.api == null) return;
|
||||
|
||||
var args: c.PLUGIN_Profiler_Stop_Args = .{ .profiler = self.inner };
|
||||
self.check(self.api.?.stop.?(&args)) catch unreachable;
|
||||
}
|
||||
|
||||
pub fn collectData(self: *Profiler, allocator: std.mem.Allocator) !ProfilingData {
|
||||
self.transition("collect_data", .stopped, .done);
|
||||
if (self.api == null) return .{ .external = &.{} };
|
||||
|
||||
var args: c.PLUGIN_Profiler_CollectData_Args = .{
|
||||
.struct_size = c.PLUGIN_Profiler_CollectData_Args_STRUCT_SIZE,
|
||||
.profiler = self.inner,
|
||||
.buffer = null,
|
||||
.buffer_size_in_bytes = 0,
|
||||
};
|
||||
try self.check(self.api.?.collect_data.?(&args));
|
||||
std.debug.assert(args.buffer_size_in_bytes > 0);
|
||||
const buffer: ProfilingData = if (args.buffer == null) blk: {
|
||||
std.log.debug("Plugin profiler wants us to allocate {d} bytes for profile data", .{args.buffer_size_in_bytes});
|
||||
// The plugin want us to allocate memory for it:
|
||||
const buffer = try allocator.alloc(u8, args.buffer_size_in_bytes);
|
||||
args.buffer = buffer.ptr;
|
||||
try self.check(self.api.?.collect_data.?(&args));
|
||||
break :blk .{ .owned = buffer };
|
||||
} else blk: {
|
||||
std.log.debug("Plugin profiler has {d} bytes of profile data", .{args.buffer_size_in_bytes});
|
||||
// Drop sentinel. The profiler plugin returns a null terminated string.
|
||||
// But this is creating issues if we save the sentinel on disk,
|
||||
// because it will trip up protobuf readers.
|
||||
var data = args.buffer[0..args.buffer_size_in_bytes];
|
||||
data = if (data.len > 0 and data[data.len - 1] == 0) data[0 .. data.len - 1] else data;
|
||||
break :blk .{ .external = data };
|
||||
};
|
||||
|
||||
// printDataAsXSpace(allocator, buffer.items());
|
||||
return buffer;
|
||||
}
|
||||
|
||||
pub fn dumpDataTo(
|
||||
self: *Profiler,
|
||||
allocator: std.mem.Allocator,
|
||||
dir: std.fs.Dir,
|
||||
file_name: []const u8,
|
||||
) !void {
|
||||
const profile_data = try self.collectData(allocator);
|
||||
defer profile_data.free(allocator);
|
||||
|
||||
if (profile_data.items().len == 0) return;
|
||||
|
||||
const file = try dir.createFile(file_name, .{ .truncate = true });
|
||||
defer file.close();
|
||||
log.info("Writing profiling data to {s} ({} bytes)", .{ file_name, profile_data.items().len });
|
||||
return try file.writeAll(profile_data.items());
|
||||
}
|
||||
|
||||
fn check(self: *Profiler, c_error: ?*Error) !void {
|
||||
if (c_error) |err| {
|
||||
self.last_error = err;
|
||||
return error.PjrtProfilerError;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deinit(self: Profiler) void {
|
||||
switch (self.status) {
|
||||
.started => log.warn("Profiler was never stopped", .{}),
|
||||
.stopped => log.warn("Profiler data was never collected", .{}),
|
||||
else => {},
|
||||
}
|
||||
if (self.api == null) return;
|
||||
|
||||
var args: c.PLUGIN_Profiler_Destroy_Args = .{ .profiler = self.inner };
|
||||
_ = self.api.?.destroy.?(&args);
|
||||
}
|
||||
};
|
||||
|
||||
// If this was working it would be a good alternative to xspace_to_json.cc
|
||||
// const xspace = @import("xspace.pb.zig");
|
||||
// pub fn printDataAsXSpace(allocator: std.mem.Allocator, data: []const u8) void {
|
||||
// var arena = std.heap.ArenaAllocator.init(allocator);
|
||||
// defer arena.deinit();
|
||||
//
|
||||
// const space = xspace.XSpace.decode(data, arena.allocator()) catch |e| {
|
||||
// std.log.err("Couldn't load profiling data: {}", .{e});
|
||||
// return;
|
||||
// };
|
||||
//
|
||||
// for (space.errors.items) |err| {
|
||||
// std.log.err("{s}", .{err.getSlice()});
|
||||
// }
|
||||
// for (space.warnings.items) |warning| {
|
||||
// std.log.warn("{s}", .{warning.getSlice()});
|
||||
// }
|
||||
// for (space.hostnames.items) |host| {
|
||||
// std.log.info("Profiled host {s}", .{host.getSlice()});
|
||||
// }
|
||||
// for (space.planes.items) |plane| {
|
||||
// var event_metadata = std.hash_map.AutoHashMap(i64, xspace.XEventMetadata).init(arena.allocator());
|
||||
// event_metadata.ensureTotalCapacity(@intCast(plane.event_metadata.items.len)) catch return;
|
||||
// defer event_metadata.deinit();
|
||||
// for (plane.event_metadata.items) |event_meta_entry| {
|
||||
// if (event_meta_entry.value) |event_meta| {
|
||||
// event_metadata.putAssumeCapacity(event_meta.id, event_meta);
|
||||
// }
|
||||
// }
|
||||
// std.log.info("Profiled device {s}", .{plane.name.getSlice()});
|
||||
|
||||
// for (plane.lines.items) |line| {
|
||||
// std.log.info(
|
||||
// "{d} -> {d} xline {s} ({d} events)",
|
||||
// .{ line.timestamp_ns, line.duration_ps, line.name.getSlice(), line.events.items.len },
|
||||
// );
|
||||
// const ps_per_ns: i64 = 1000;
|
||||
// var duration_ns: i64 = 0;
|
||||
// var last_metadata_id: i64 = 0;
|
||||
// for (line.events.items) |event| {
|
||||
// if (event.metadata_id != last_metadata_id and duration_ns != 0) {
|
||||
// const duration_us = @as(f32, @floatFromInt(duration_ns)) / std.time.ns_per_us;
|
||||
// const meta = event_metadata.get(event.metadata_id).?;
|
||||
// std.log.info("event {s}: {d:.1}μs", .{ meta.name.getSlice(), duration_us });
|
||||
|
||||
// last_metadata_id = event.metadata_id;
|
||||
// duration_ns = 0;
|
||||
// }
|
||||
// duration_ns += @divFloor(event.duration_ps, ps_per_ns);
|
||||
|
||||
// const duration_us = @as(f32, @floatFromInt(duration_ns)) / std.time.ns_per_us;
|
||||
// const meta = event_metadata.get(event.metadata_id).?;
|
||||
// std.log.info("event {s}: {d:.1}μs", .{ meta.name.getSlice(), duration_us });
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
const ProfilingData = union(enum) {
|
||||
owned: []const u8,
|
||||
external: []const u8,
|
||||
|
||||
pub fn items(self: ProfilingData) []const u8 {
|
||||
return switch (self) {
|
||||
inline else => |x| x,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn free(self: ProfilingData, allocator: std.mem.Allocator) void {
|
||||
switch (self) {
|
||||
.owned => |data| allocator.free(data),
|
||||
.external => {},
|
||||
}
|
||||
}
|
||||
};
|
||||
21
platform_mappings
Normal file
21
platform_mappings
Normal file
@ -0,0 +1,21 @@
|
||||
platforms:
|
||||
@zml//platforms:linux_amd64
|
||||
--cpu=k8
|
||||
|
||||
@zml//platforms:linux_arm64
|
||||
--cpu=aarch64
|
||||
|
||||
@zml//platforms:macos_arm64
|
||||
--cpu=darwin_arm64
|
||||
--apple_platform_type=macos
|
||||
|
||||
flags:
|
||||
--cpu=darwin_arm64
|
||||
--apple_platform_type=macos
|
||||
@zml//platforms:macos_arm64
|
||||
|
||||
--cpu=k8
|
||||
@zml//platforms:linux_amd64
|
||||
|
||||
--cpu=aarch64
|
||||
@zml//platforms:linux_arm64
|
||||
26
platforms/BUILD.bazel
Normal file
26
platforms/BUILD.bazel
Normal file
@ -0,0 +1,26 @@
|
||||
platform(
|
||||
name = "linux_amd64",
|
||||
constraint_values = [
|
||||
"@platforms//cpu:x86_64",
|
||||
"@platforms//os:linux",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
platform(
|
||||
name = "linux_arm64",
|
||||
constraint_values = [
|
||||
"@platforms//cpu:aarch64",
|
||||
"@platforms//os:linux",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
platform(
|
||||
name = "macos_arm64",
|
||||
constraint_values = [
|
||||
"@platforms//cpu:aarch64",
|
||||
"@platforms//os:macos",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
42
runtimes/BUILD.bazel
Normal file
42
runtimes/BUILD.bazel
Normal file
@ -0,0 +1,42 @@
|
||||
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
|
||||
|
||||
RUNTIMES = {
|
||||
"cpu": True,
|
||||
"cuda": False,
|
||||
"rocm": False,
|
||||
"tpu": False,
|
||||
}
|
||||
|
||||
[
|
||||
bool_flag(
|
||||
name = runtime,
|
||||
build_setting_default = default,
|
||||
)
|
||||
for runtime, default in RUNTIMES.items()
|
||||
]
|
||||
|
||||
[
|
||||
config_setting(
|
||||
name = "_{}".format(runtime),
|
||||
flag_values = {":{}".format(runtime): "True"},
|
||||
)
|
||||
for runtime in RUNTIMES.keys()
|
||||
]
|
||||
|
||||
cc_library(
|
||||
name = "runtimes",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
":_cpu": ["//runtimes/cpu"],
|
||||
"//conditions:default": [],
|
||||
}) + select({
|
||||
":_cuda": ["//runtimes/cuda"],
|
||||
"//conditions:default": [],
|
||||
}) + select({
|
||||
":_rocm": ["//runtimes/rocm"],
|
||||
"//conditions:default": [],
|
||||
}) + select({
|
||||
":_tpu": ["//runtimes/tpu"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
)
|
||||
8
runtimes/cpu/BUILD.bazel
Normal file
8
runtimes/cpu/BUILD.bazel
Normal file
@ -0,0 +1,8 @@
|
||||
alias(
|
||||
name = "cpu",
|
||||
actual = select({
|
||||
"@platforms//os:macos": "@libpjrt_cpu_darwin_arm64//:libpjrt_cpu",
|
||||
"@platforms//os:linux": "@libpjrt_cpu_linux_amd64//:libpjrt_cpu",
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
34
runtimes/cpu/cpu.bzl
Normal file
34
runtimes/cpu/cpu.bzl
Normal file
@ -0,0 +1,34 @@
|
||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
|
||||
_BUILD = """\
|
||||
cc_import(
|
||||
name = "libpjrt_cpu",
|
||||
shared_library = "libpjrt_cpu.{ext}",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
"""
|
||||
|
||||
def _cpu_pjrt_plugin_impl(mctx):
|
||||
http_archive(
|
||||
name = "libpjrt_cpu_linux_amd64",
|
||||
build_file_content = _BUILD.format(ext = "so"),
|
||||
sha256 = "14317143acd6a38656e97280e8010c0b8d8c0863dff2ae82834b6f2fe747427b",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.1.13/pjrt-cpu_linux-amd64.tar.gz",
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "libpjrt_cpu_darwin_arm64",
|
||||
build_file_content = _BUILD.format(ext = "dylib"),
|
||||
sha256 = "3a26e1372f68fc11028c4ec22a0c72693f08e7690ba8c5f28b17f5baa9c9dc77",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.1.13/pjrt-cpu_darwin-arm64.tar.gz",
|
||||
)
|
||||
|
||||
return mctx.extension_metadata(
|
||||
reproducible = True,
|
||||
root_module_direct_deps = "all",
|
||||
root_module_direct_dev_deps = [],
|
||||
)
|
||||
|
||||
cpu_pjrt_plugin = module_extension(
|
||||
implementation = _cpu_pjrt_plugin_impl,
|
||||
)
|
||||
5
runtimes/cuda/BUILD.bazel
Normal file
5
runtimes/cuda/BUILD.bazel
Normal file
@ -0,0 +1,5 @@
|
||||
alias(
|
||||
name = "cuda",
|
||||
actual = "@libpjrt_cuda",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
197
runtimes/cuda/cuda.bzl
Normal file
197
runtimes/cuda/cuda.bzl
Normal file
@ -0,0 +1,197 @@
|
||||
load("@bazel_skylib//lib:paths.bzl", "paths")
|
||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
load("//bazel:http_deb_archive.bzl", "http_deb_archive")
|
||||
|
||||
ARCH = "linux-x86_64"
|
||||
|
||||
CUDA_VERSION = "12.6.1"
|
||||
CUDNN_VERSION = "9.3.0"
|
||||
|
||||
_CC_IMPORT_TPL = """\
|
||||
cc_import(
|
||||
name = "{name}",
|
||||
shared_library = "lib/{shared_library}",
|
||||
visibility = ["@libpjrt_cuda//:__subpackages__"],
|
||||
)
|
||||
"""
|
||||
|
||||
CUDA_PACKAGES = {
|
||||
"cuda_cudart": _CC_IMPORT_TPL.format(name = "cudart", shared_library = "libcudart.so.12"),
|
||||
"cuda_cupti": _CC_IMPORT_TPL.format(name = "cupti", shared_library = "libcupti.so.12"),
|
||||
"libcufft": _CC_IMPORT_TPL.format(name = "cufft", shared_library = "libcufft.so.11"),
|
||||
"libcusolver": _CC_IMPORT_TPL.format(name = "cusolver", shared_library = "libcusolver.so.11"),
|
||||
"libcusparse": _CC_IMPORT_TPL.format(name = "cusparse", shared_library = "libcusparse.so.12"),
|
||||
"libnvjitlink": _CC_IMPORT_TPL.format(name = "nvjitlink", shared_library = "libnvJitLink.so.12"),
|
||||
"cuda_nvcc": """\
|
||||
filegroup(
|
||||
name = "ptxas",
|
||||
srcs = ["bin/ptxas"],
|
||||
visibility = ["@libpjrt_cuda//:__subpackages__"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "libdevice",
|
||||
srcs = ["nvvm/libdevice/libdevice.10.bc"],
|
||||
visibility = ["@libpjrt_cuda//:__subpackages__"],
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "nvvm",
|
||||
shared_library = "nvvm/lib64/libnvvm.so.4",
|
||||
visibility = ["@libpjrt_cuda//:__subpackages__"],
|
||||
)
|
||||
""",
|
||||
"cuda_nvrtc": """\
|
||||
cc_import(
|
||||
name = "nvrtc",
|
||||
shared_library = "lib/libnvrtc.so.12",
|
||||
visibility = ["@libpjrt_cuda//:__subpackages__"],
|
||||
deps = [":nvrtc_builtins"],
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "nvrtc_builtins",
|
||||
shared_library = "lib/libnvrtc-builtins.so.12.6",
|
||||
)
|
||||
""",
|
||||
"libcublas": """\
|
||||
cc_import(
|
||||
name = "cublasLt",
|
||||
shared_library = "lib/libcublasLt.so.12",
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "cublas",
|
||||
shared_library = "lib/libcublas.so.12",
|
||||
visibility = ["@libpjrt_cuda//:__subpackages__"],
|
||||
deps = [":cublasLt"],
|
||||
)
|
||||
""",
|
||||
}
|
||||
|
||||
CUDNN_PACKAGES = {
|
||||
"cudnn": """\
|
||||
cc_import(
|
||||
name = "cudnn",
|
||||
shared_library = "lib/libcudnn.so.9",
|
||||
visibility = ["@libpjrt_cuda//:__subpackages__"],
|
||||
deps = [
|
||||
":cudnn_adv",
|
||||
":cudnn_ops",
|
||||
":cudnn_cnn",
|
||||
":cudnn_graph",
|
||||
":cudnn_engines_precompiled",
|
||||
":cudnn_engines_runtime_compiled",
|
||||
":cudnn_heuristic",
|
||||
],
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "cudnn_adv",
|
||||
shared_library = "lib/libcudnn_adv.so.9",
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "cudnn_ops",
|
||||
shared_library = "lib/libcudnn_ops.so.9",
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "cudnn_cnn",
|
||||
shared_library = "lib/libcudnn_cnn.so.9",
|
||||
deps = [":cudnn_ops"],
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "cudnn_graph",
|
||||
shared_library = "lib/libcudnn_graph.so.9",
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "cudnn_engines_precompiled",
|
||||
shared_library = "lib/libcudnn_engines_precompiled.so.9",
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "cudnn_engines_runtime_compiled",
|
||||
shared_library = "lib/libcudnn_engines_runtime_compiled.so.9",
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "cudnn_heuristic",
|
||||
shared_library = "lib/libcudnn_heuristic.so.9",
|
||||
)
|
||||
""",
|
||||
}
|
||||
|
||||
def _cuda_impl(mctx):
|
||||
CUDA_REDIST = json.decode(mctx.read(Label("@zml//runtimes/cuda:cuda.redistrib_{}.json".format(CUDA_VERSION))))
|
||||
CUDNN_REDIST = json.decode(mctx.read(Label("@zml//runtimes/cuda:cudnn.redistrib_{}.json".format(CUDNN_VERSION))))
|
||||
|
||||
for pkg, build_file_content in CUDA_PACKAGES.items():
|
||||
pkg_data = CUDA_REDIST[pkg]
|
||||
arch_data = pkg_data.get(ARCH)
|
||||
if not arch_data:
|
||||
continue
|
||||
http_archive(
|
||||
name = pkg,
|
||||
build_file_content = build_file_content,
|
||||
url = "https://developer.download.nvidia.com/compute/cuda/redist/" + arch_data["relative_path"],
|
||||
sha256 = arch_data["sha256"],
|
||||
strip_prefix = paths.basename(arch_data["relative_path"]).replace(".tar.xz", ""),
|
||||
)
|
||||
|
||||
for pkg, build_file_content in CUDNN_PACKAGES.items():
|
||||
pkg_data = CUDNN_REDIST[pkg]
|
||||
arch_data = pkg_data.get(ARCH)
|
||||
if not arch_data:
|
||||
continue
|
||||
arch_data = arch_data.get("cuda12", arch_data)
|
||||
http_archive(
|
||||
name = pkg,
|
||||
build_file_content = build_file_content,
|
||||
url = "https://developer.download.nvidia.com/compute/cudnn/redist/" + arch_data["relative_path"],
|
||||
sha256 = arch_data["sha256"],
|
||||
strip_prefix = paths.basename(arch_data["relative_path"]).replace(".tar.xz", ""),
|
||||
)
|
||||
|
||||
http_deb_archive(
|
||||
name = "libnccl",
|
||||
urls = ["https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/libnccl2_2.22.3-1+cuda12.6_amd64.deb"],
|
||||
sha256 = "2f64685bcd503150ab45d00503236a56da58a15eac5fd36508045a74f4e10678",
|
||||
build_file_content = """\
|
||||
cc_import(
|
||||
name = "nccl",
|
||||
shared_library = "usr/lib/x86_64-linux-gnu/libnccl.so.2",
|
||||
visibility = ["@libpjrt_cuda//:__subpackages__"],
|
||||
)
|
||||
""",
|
||||
)
|
||||
http_deb_archive(
|
||||
name = "zlib",
|
||||
urls = ["http://archive.ubuntu.com/ubuntu/pool/main/z/zlib/zlib1g_1.3.dfsg-3.1ubuntu2.1_amd64.deb"],
|
||||
sha256 = "7074b6a2f6367a10d280c00a1cb02e74277709180bab4f2491a2f355ab2d6c20",
|
||||
build_file_content = """\
|
||||
cc_import(
|
||||
name = "zlib",
|
||||
shared_library = "usr/lib/x86_64-linux-gnu/libz.so.1",
|
||||
visibility = ["@libpjrt_cuda//:__subpackages__"],
|
||||
)
|
||||
""",
|
||||
)
|
||||
http_archive(
|
||||
name = "libpjrt_cuda",
|
||||
build_file = "libpjrt_cuda.BUILD.bazel",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.1.13/pjrt-cuda_linux-amd64.tar.gz",
|
||||
sha256 = "b705f761e24d85ecd750df992a88715d9c461b7561c31722b9f878eeab32f39e",
|
||||
)
|
||||
|
||||
return mctx.extension_metadata(
|
||||
reproducible = True,
|
||||
root_module_direct_deps = ["libpjrt_cuda"],
|
||||
root_module_direct_dev_deps = [],
|
||||
)
|
||||
|
||||
cuda_packages = module_extension(
|
||||
implementation = _cuda_impl,
|
||||
)
|
||||
1037
runtimes/cuda/cuda.redistrib_12.6.1.json
Normal file
1037
runtimes/cuda/cuda.redistrib_12.6.1.json
Normal file
File diff suppressed because it is too large
Load Diff
77
runtimes/cuda/cudnn.redistrib_9.3.0.json
Normal file
77
runtimes/cuda/cudnn.redistrib_9.3.0.json
Normal file
@ -0,0 +1,77 @@
|
||||
{
|
||||
"release_date": "2024-08-01",
|
||||
"release_label": "9.3.0",
|
||||
"release_product": "cudnn",
|
||||
"cudnn": {
|
||||
"name": "NVIDIA CUDA Deep Neural Network library",
|
||||
"license": "cudnn",
|
||||
"license_path": "cudnn/LICENSE.txt",
|
||||
"version": "9.3.0.75",
|
||||
"linux-x86_64": {
|
||||
"cuda11": {
|
||||
"relative_path": "cudnn/linux-x86_64/cudnn-linux-x86_64-9.3.0.75_cuda11-archive.tar.xz",
|
||||
"sha256": "069da084cd368f39fb830d4c4e931803064fb65e766f4cf7df2da4a346a2ba9f",
|
||||
"md5": "4ab5cfb90cfe16c02cbdb88788165de5",
|
||||
"size": "748177916"
|
||||
},
|
||||
"cuda12": {
|
||||
"relative_path": "cudnn/linux-x86_64/cudnn-linux-x86_64-9.3.0.75_cuda12-archive.tar.xz",
|
||||
"sha256": "3d6ef10aa06dc9339a477e2b057e085ff8500bbdee79e42c7e13655c9eff2c26",
|
||||
"md5": "2fa73268de8bbdab5560f4aa1a5a73ab",
|
||||
"size": "756509380"
|
||||
}
|
||||
},
|
||||
"cuda_variant": [
|
||||
"11",
|
||||
"12"
|
||||
],
|
||||
"linux-sbsa": {
|
||||
"cuda11": {
|
||||
"relative_path": "cudnn/linux-sbsa/cudnn-linux-sbsa-9.3.0.75_cuda11-archive.tar.xz",
|
||||
"sha256": "04b56fbf7bee15c24e339c2ba94d17aa88b9e334d0cd19e75853dc5452794bf7",
|
||||
"md5": "eb3d809ff9c853721d342aa68564ed77",
|
||||
"size": "746866932"
|
||||
},
|
||||
"cuda12": {
|
||||
"relative_path": "cudnn/linux-sbsa/cudnn-linux-sbsa-9.3.0.75_cuda12-archive.tar.xz",
|
||||
"sha256": "1226dd9b989c898638552963d000a3cfbb8a0a0cb9baf8cb09e6abd77ed7d639",
|
||||
"md5": "bdf5c7ba6ae34cc0dbcdfb98f3da746d",
|
||||
"size": "754450192"
|
||||
}
|
||||
},
|
||||
"windows-x86_64": {
|
||||
"cuda11": {
|
||||
"relative_path": "cudnn/windows-x86_64/cudnn-windows-x86_64-9.3.0.75_cuda11-archive.zip",
|
||||
"sha256": "0e6f9d39343b88208b1daea280d62d6f7a90355395999806de6624c8361c36fc",
|
||||
"md5": "d7414555bd6c9e8a91139f00864fa2fa",
|
||||
"size": "562755549"
|
||||
},
|
||||
"cuda12": {
|
||||
"relative_path": "cudnn/windows-x86_64/cudnn-windows-x86_64-9.3.0.75_cuda12-archive.zip",
|
||||
"sha256": "864a85dc67c7f92b9a8639f323acb4af63ad65de2ca82dccdf2c0b6a701c27c0",
|
||||
"md5": "5f82d9233dd22a6664abff13cd22a224",
|
||||
"size": "566118754"
|
||||
}
|
||||
},
|
||||
"linux-aarch64": {
|
||||
"cuda12": {
|
||||
"relative_path": "cudnn/linux-aarch64/cudnn-linux-aarch64-9.3.0.75_cuda12-archive.tar.xz",
|
||||
"sha256": "1aae4bfced63f930b4677f9e928e197878dece2edd51ca8fa7ab8363d0e6ed60",
|
||||
"md5": "e5a94b1b0bbd313c0c11235d969ca028",
|
||||
"size": "796357916"
|
||||
}
|
||||
}
|
||||
},
|
||||
"cudnn_samples": {
|
||||
"name": "NVIDIA cuDNN samples",
|
||||
"license": "cudnn",
|
||||
"license_path": "cudnn_samples/LICENSE.txt",
|
||||
"version": "9.3.0.75",
|
||||
"source": {
|
||||
"relative_path": "cudnn_samples/source/cudnn_samples-source-9.3.0.75-archive.tar.xz",
|
||||
"sha256": "5da5536bba749158b245ddf69b4a0e2f4bf122e9b3c0e7406d8c2ddd5dd97518",
|
||||
"md5": "d855a9694631213fb6e78da90ae60fbe",
|
||||
"size": "1666860"
|
||||
}
|
||||
}
|
||||
}
|
||||
32
runtimes/cuda/libpjrt_cuda.BUILD.bazel
Normal file
32
runtimes/cuda/libpjrt_cuda.BUILD.bazel
Normal file
@ -0,0 +1,32 @@
|
||||
load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory")
|
||||
load("@zml//bazel:cc_import.bzl", "cc_import")
|
||||
|
||||
copy_to_directory(
|
||||
name = "sandbox",
|
||||
srcs = [
|
||||
"@cuda_nvcc//:libdevice",
|
||||
"@cuda_nvcc//:ptxas",
|
||||
],
|
||||
include_external_repositories = ["**"],
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "libpjrt_cuda",
|
||||
data = [":sandbox"],
|
||||
shared_library = "libpjrt_cuda.so",
|
||||
visibility = ["@zml//runtimes/cuda:__subpackages__"],
|
||||
deps = [
|
||||
"@cuda_cudart//:cudart",
|
||||
"@cuda_cupti//:cupti",
|
||||
"@cuda_nvcc//:nvvm",
|
||||
"@cuda_nvrtc//:nvrtc",
|
||||
"@cudnn//:cudnn",
|
||||
"@libcublas//:cublas",
|
||||
"@libcufft//:cufft",
|
||||
"@libcusolver//:cusolver",
|
||||
"@libcusparse//:cusparse",
|
||||
"@libnccl//:nccl",
|
||||
"@libnvjitlink//:nvjitlink",
|
||||
"@zlib",
|
||||
],
|
||||
)
|
||||
21
runtimes/rocm/BUILD.bazel
Normal file
21
runtimes/rocm/BUILD.bazel
Normal file
@ -0,0 +1,21 @@
|
||||
filegroup(
|
||||
name = "zmlrocmhooks_srcs",
|
||||
srcs = ["zmlrocmhooks.cc"],
|
||||
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "hipblaslt",
|
||||
actual = "@libpjrt_rocm//:hipblaslt",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "gfx",
|
||||
actual = "@libpjrt_rocm//:gfx",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "rocm",
|
||||
actual = "@libpjrt_rocm",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
49
runtimes/rocm/gfx.bzl
Normal file
49
runtimes/rocm/gfx.bzl
Normal file
@ -0,0 +1,49 @@
|
||||
load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
|
||||
|
||||
_ALL_GFX = ["gfx900", "gfx906", "gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx1010", "gfx1012", "gfx1030", "gfx1100", "gfx1101", "gfx1102"]
|
||||
|
||||
def _compute_enabled_gfx(values):
|
||||
ret = {}
|
||||
for v in values:
|
||||
if (v == "all"):
|
||||
ret = {gfx: True for gfx in _ALL_GFX}
|
||||
elif (v == "none"):
|
||||
ret = {}
|
||||
else:
|
||||
ret[v] = True
|
||||
return ret
|
||||
|
||||
def _gfx_from_file(file):
|
||||
return file.basename[:-len(file.extension) - 1].rpartition("_")[-1].partition("-")[0]
|
||||
|
||||
def _is_file_enabled(file, enabled_gfx):
|
||||
gfx = _gfx_from_file(file)
|
||||
return gfx in enabled_gfx or gfx == "fallback"
|
||||
|
||||
def _bytecode_select_impl(ctx):
|
||||
enabled_gfx = _compute_enabled_gfx(ctx.attr.enabled_gfx[BuildSettingInfo].value)
|
||||
return [
|
||||
DefaultInfo(
|
||||
files = depset([
|
||||
file
|
||||
for file in ctx.files.bytecodes
|
||||
if _is_file_enabled(file, enabled_gfx)
|
||||
]),
|
||||
),
|
||||
]
|
||||
|
||||
bytecode_select = rule(
|
||||
implementation = _bytecode_select_impl,
|
||||
attrs = {
|
||||
"bytecodes": attr.label_list(allow_files = True),
|
||||
"enabled_gfx": attr.label(mandatory = True),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def if_gfx(gfx, value):
|
||||
return select({
|
||||
"@zml//runtimes/rocm:_{}".format(gfx): value,
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
88
runtimes/rocm/libpjrt_rocm.BUILD.bazel
Normal file
88
runtimes/rocm/libpjrt_rocm.BUILD.bazel
Normal file
@ -0,0 +1,88 @@
|
||||
load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory")
|
||||
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "string_list_flag")
|
||||
load("@zml//bazel:cc_import.bzl", "cc_import")
|
||||
|
||||
string_list_flag(
|
||||
name = "gfx",
|
||||
build_setting_default = ["all"],
|
||||
visibility = [
|
||||
"@rocblas//:__subpackages__",
|
||||
"@hipblaslt-dev//:__subpackages__",
|
||||
],
|
||||
)
|
||||
|
||||
bool_flag(
|
||||
name = "hipblaslt",
|
||||
build_setting_default = True,
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "_hipblaslt",
|
||||
flag_values = {":hipblaslt": "True"},
|
||||
)
|
||||
|
||||
copy_to_directory(
|
||||
name = "sandbox",
|
||||
srcs = [
|
||||
"@rocm-device-libs//:runfiles",
|
||||
"@rocm-llvm//:lld",
|
||||
],
|
||||
include_external_repositories = ["*"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "zmlrocmhooks_lib",
|
||||
data = ["@rocblas//:runfiles"],
|
||||
srcs = ["@zml//runtimes/rocm:zmlrocmhooks_srcs"],
|
||||
linkopts = [
|
||||
"-lc",
|
||||
"-ldl",
|
||||
],
|
||||
deps = ["@bazel_tools//tools/cpp/runfiles"],
|
||||
)
|
||||
|
||||
cc_shared_library(
|
||||
name = "zmlrocmhooks_so",
|
||||
shared_lib_name = "libzmlrocmhooks.so.0",
|
||||
deps = [":zmlrocmhooks_lib"],
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "zmlrocmhooks",
|
||||
shared_library = ":zmlrocmhooks_so",
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "libpjrt_rocm",
|
||||
data = [
|
||||
":sandbox",
|
||||
"@rocblas//:runfiles",
|
||||
] + select({
|
||||
":_hipblaslt": ["@hipblaslt-dev//:runfiles"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
shared_library = "libpjrt_rocm.so",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":zmlrocmhooks",
|
||||
"@comgr//:amd_comgr",
|
||||
"@hip-runtime-amd//:amdhip",
|
||||
"@hipblaslt",
|
||||
"@hsa-amd-aqlprofile//:hsa-amd-aqlprofile",
|
||||
"@hsa-rocr//:hsa-runtime",
|
||||
"@miopen-hip//:MIOpen",
|
||||
"@rccl",
|
||||
"@rocblas",
|
||||
"@rocm-core",
|
||||
"@rocm-smi-lib//:rocm_smi",
|
||||
"@rocprofiler-register",
|
||||
"@roctracer",
|
||||
"@libelf",
|
||||
"@libdrm",
|
||||
"@libnuma",
|
||||
"@libzstd",
|
||||
"@libdrm-amdgpu",
|
||||
"@libtinfo",
|
||||
"@zlib1g",
|
||||
],
|
||||
)
|
||||
11376
runtimes/rocm/packages.amd64.txt
Normal file
11376
runtimes/rocm/packages.amd64.txt
Normal file
File diff suppressed because it is too large
Load Diff
242
runtimes/rocm/rocm.bzl
Normal file
242
runtimes/rocm/rocm.bzl
Normal file
@ -0,0 +1,242 @@
|
||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
load("//bazel:http_deb_archive.bzl", "http_deb_archive")
|
||||
|
||||
ROCM_VERSION = "6.2"
|
||||
BASE_URL = "https://repo.radeon.com/rocm/apt/{}".format(ROCM_VERSION)
|
||||
STRIP_PREFIX = "opt/rocm-6.2.0"
|
||||
|
||||
def pkg_kwargs(pkg, packages):
|
||||
return {
|
||||
"name": pkg,
|
||||
"urls": [BASE_URL + "/" + packages[pkg]["Filename"]],
|
||||
"sha256": packages[pkg]["SHA256"],
|
||||
"strip_prefix": STRIP_PREFIX,
|
||||
}
|
||||
|
||||
def _ubuntu_package(path, deb_path, sha256, name, shared_library):
|
||||
return {
|
||||
"urls": ["http://archive.ubuntu.com/ubuntu/pool/main/{}".format(path)],
|
||||
"sha256": sha256,
|
||||
"build_file_content": """\
|
||||
cc_import(
|
||||
name = {name},
|
||||
shared_library = "{deb_path}{shared_library}",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
""".format(name = repr(name), shared_library = shared_library, deb_path = deb_path),
|
||||
}
|
||||
|
||||
_UBUNTU_PACKAGES = {
|
||||
"libdrm": _ubuntu_package(
|
||||
path = "libd/libdrm/libdrm2_2.4.107-8ubuntu1~20.04.2_amd64.deb",
|
||||
deb_path = "usr/lib/x86_64-linux-gnu/",
|
||||
sha256 = "9b01d73313841abe8e3f24c2715edced675fbe329bbd10be912a5b135cd51fb6",
|
||||
name = "libdrm",
|
||||
shared_library = "libdrm.so.2",
|
||||
),
|
||||
"libelf": _ubuntu_package(
|
||||
path = "e/elfutils/libelf1_0.176-1.1build1_amd64.deb",
|
||||
deb_path = "usr/lib/x86_64-linux-gnu/",
|
||||
sha256 = "78a8761227efc04a1e37527f2f33ba608c6fb5d6c911616346ada5d7b9b72ee3",
|
||||
name = "libelf",
|
||||
shared_library = "libelf.so.1",
|
||||
),
|
||||
"libnuma": _ubuntu_package(
|
||||
path = "n/numactl/libnuma1_2.0.12-1_amd64.deb",
|
||||
deb_path = "usr/lib/x86_64-linux-gnu/",
|
||||
sha256 = "0b1edf08cf9befecd21fe94e298ac25e476f87fd876ddd4adf42ef713449e637",
|
||||
name = "libnuma",
|
||||
shared_library = "libnuma.so.1",
|
||||
),
|
||||
"libzstd": _ubuntu_package(
|
||||
path = "libz/libzstd/libzstd1_1.4.4+dfsg-3ubuntu0.1_amd64.deb",
|
||||
deb_path = "usr/lib/x86_64-linux-gnu/",
|
||||
sha256 = "7a4422dadb90510dc90765c308d65e61a3e244ceb3886394335e48cff7559e69",
|
||||
name = "libzstd",
|
||||
shared_library = "libzstd.so.1",
|
||||
),
|
||||
"libdrm-amdgpu": _ubuntu_package(
|
||||
path = "libd/libdrm/libdrm-amdgpu1_2.4.107-8ubuntu1~20.04.2_amd64.deb",
|
||||
deb_path = "usr/lib/x86_64-linux-gnu/",
|
||||
sha256 = "0d95779b581f344e3d658e0f21f6e4b57da6eb3606c0bcb8cb874c12f5754bf2",
|
||||
name = "libdrm-amdgpu",
|
||||
shared_library = "libdrm_amdgpu.so.1",
|
||||
),
|
||||
"libtinfo": _ubuntu_package(
|
||||
path = "n/ncurses/libtinfo6_6.2-0ubuntu2.1_amd64.deb",
|
||||
deb_path = "lib/x86_64-linux-gnu/",
|
||||
sha256 = "711a3a901c3a71561565558865699efa9c07a99fdc810ffe086a5636f89c6431",
|
||||
name = "libtinfo",
|
||||
shared_library = "libtinfo.so.6",
|
||||
),
|
||||
"zlib1g": _ubuntu_package(
|
||||
path = "z/zlib/zlib1g_1.2.11.dfsg-2ubuntu1.5_amd64.deb",
|
||||
deb_path = "lib/x86_64-linux-gnu/",
|
||||
sha256 = "bf67018f5303466eb468680b637a5d3f3bb17b9d44decf3d82d40b35babcd3e0",
|
||||
name = "zlib1g",
|
||||
shared_library = "libz.so.1",
|
||||
),
|
||||
}
|
||||
|
||||
_CC_IMPORT_TPL = """\
|
||||
cc_import(
|
||||
name = "{name}",
|
||||
shared_library = "lib/{shared_library}",
|
||||
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
||||
)
|
||||
"""
|
||||
|
||||
_RUNFILES_TPL = """\
|
||||
filegroup(
|
||||
name = "{name}",
|
||||
srcs = glob({glob}),
|
||||
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
||||
)
|
||||
"""
|
||||
|
||||
_PACKAGES = {
|
||||
"rocm-core": _CC_IMPORT_TPL.format(name = "rocm-core", shared_library = "librocm-core.so.1"),
|
||||
"rocm-smi-lib": _CC_IMPORT_TPL.format(name = "rocm_smi", shared_library = "librocm_smi64.so.7"),
|
||||
"hsa-rocr": _CC_IMPORT_TPL.format(name = "hsa-runtime", shared_library = "libhsa-runtime64.so.1"),
|
||||
"hsa-amd-aqlprofile": _CC_IMPORT_TPL.format(name = "hsa-amd-aqlprofile", shared_library = "libhsa-amd-aqlprofile64.so.1"),
|
||||
"comgr": _CC_IMPORT_TPL.format(name = "amd_comgr", shared_library = "libamd_comgr.so.2"),
|
||||
"rocprofiler-register": _CC_IMPORT_TPL.format(name = "rocprofiler-register", shared_library = "librocprofiler-register.so.0"),
|
||||
"miopen-hip": "".join([
|
||||
_CC_IMPORT_TPL.format(name = "MIOpen", shared_library = "libMIOpen.so.1"),
|
||||
_RUNFILES_TPL.format(name = "runfiles", glob = repr(["share/miopen/**"])),
|
||||
]),
|
||||
"rccl": "".join([
|
||||
_CC_IMPORT_TPL.format(name = "rccl", shared_library = "librccl.so.1"),
|
||||
_RUNFILES_TPL.format(name = "runfiles", glob = repr(["share/rccl/msccl-algorithms/**"])),
|
||||
]),
|
||||
"rocm-device-libs": _RUNFILES_TPL.format(name = "runfiles", glob = repr(["amdgcn/**"])),
|
||||
"hip-dev": _RUNFILES_TPL.format(name = "runfiles", glob = repr(["share/**"])),
|
||||
"rocblas": """\
|
||||
load("@zml//bazel:cc_import.bzl", "cc_import")
|
||||
load("@zml//runtimes/rocm:gfx.bzl", "bytecode_select")
|
||||
|
||||
cc_import(
|
||||
name = "rocblas",
|
||||
shared_library = "lib/librocblas.so.4",
|
||||
add_needed = ["libzmlrocmhooks.so.0"],
|
||||
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
||||
)
|
||||
|
||||
bytecode_select(
|
||||
name = "runfiles",
|
||||
bytecodes = glob(["lib/rocblas/library/*"]),
|
||||
enabled_gfx = "@libpjrt_rocm//:gfx",
|
||||
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
||||
)
|
||||
""",
|
||||
"roctracer": """\
|
||||
cc_import(
|
||||
name = "roctracer",
|
||||
shared_library = "lib/libroctracer64.so.4",
|
||||
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
||||
deps = [":roctx"],
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "roctx",
|
||||
shared_library = "lib/libroctx64.so.4",
|
||||
)
|
||||
""",
|
||||
"hipblaslt": """\
|
||||
load("@zml//bazel:cc_import.bzl", "cc_import")
|
||||
cc_import(
|
||||
name = "hipblaslt",
|
||||
shared_library = "lib/libhipblaslt.so.0",
|
||||
add_needed = ["libzmlrocmhooks.so.0"],
|
||||
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
||||
)
|
||||
""",
|
||||
"hipblaslt-dev": """\
|
||||
load("@zml//runtimes/rocm:gfx.bzl", "bytecode_select")
|
||||
|
||||
bytecode_select(
|
||||
name = "bytecodes",
|
||||
bytecodes = glob(
|
||||
include = ["lib/hipblaslt/library/*"],
|
||||
exclude = ["lib/hipblaslt/library/hipblasltExtOpLibrary.dat"],
|
||||
),
|
||||
enabled_gfx = "@libpjrt_rocm//:gfx",
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "runfiles",
|
||||
srcs = [
|
||||
"lib/hipblaslt/library/hipblasltExtOpLibrary.dat",
|
||||
":bytecodes",
|
||||
],
|
||||
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
||||
)
|
||||
""",
|
||||
"hip-runtime-amd": """\
|
||||
cc_import(
|
||||
name = "amdhip",
|
||||
shared_library = "lib/libamdhip64.so.6",
|
||||
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
||||
deps = [":hiprtc"],
|
||||
)
|
||||
cc_import(
|
||||
name = "hiprtc",
|
||||
shared_library = "lib/libhiprtc.so.6",
|
||||
)
|
||||
""",
|
||||
"rocm-llvm": """\
|
||||
filegroup(
|
||||
name = "lld",
|
||||
srcs = ["llvm/bin/ld.lld"],
|
||||
visibility = ["@libpjrt_rocm//:__subpackages__"],
|
||||
)
|
||||
""",
|
||||
}
|
||||
|
||||
def _packages_to_dict(txt):
|
||||
packages = {}
|
||||
current_pkg = {}
|
||||
for line in txt.splitlines():
|
||||
if line == "":
|
||||
if current_pkg:
|
||||
packages[current_pkg["Package"]] = current_pkg
|
||||
current_pkg = {}
|
||||
continue
|
||||
if line.startswith(" "):
|
||||
current_pkg[key] += line
|
||||
continue
|
||||
split = line.split(": ", 1)
|
||||
key = split[0]
|
||||
value = len(split) > 1 and split[1] or ""
|
||||
current_pkg[key] = value
|
||||
return packages
|
||||
|
||||
def _rocm_impl(mctx):
|
||||
data = mctx.read(Label("@zml//runtimes/rocm:packages.amd64.txt"))
|
||||
PACKAGES = _packages_to_dict(data)
|
||||
|
||||
for pkg, build_file_content in _PACKAGES.items():
|
||||
http_deb_archive(
|
||||
build_file_content = build_file_content,
|
||||
**pkg_kwargs(pkg, PACKAGES)
|
||||
)
|
||||
|
||||
for repository, kwargs in _UBUNTU_PACKAGES.items():
|
||||
http_deb_archive(name = repository, **kwargs)
|
||||
|
||||
http_archive(
|
||||
name = "libpjrt_rocm",
|
||||
build_file = "libpjrt_rocm.BUILD.bazel",
|
||||
url = "https://github.com/zml/pjrt-artifacts/releases/download/v0.1.13/pjrt-rocm_linux-amd64.tar.gz",
|
||||
sha256 = "5900cec41274e80ab799bc13f31cdc87202f8e168d7e753b1c10796912f5ebef",
|
||||
)
|
||||
|
||||
return mctx.extension_metadata(
|
||||
reproducible = True,
|
||||
root_module_direct_deps = ["libpjrt_rocm"],
|
||||
root_module_direct_dev_deps = [],
|
||||
)
|
||||
|
||||
rocm_packages = module_extension(
|
||||
implementation = _rocm_impl,
|
||||
)
|
||||
103
runtimes/rocm/zmlrocmhooks.cc
Normal file
103
runtimes/rocm/zmlrocmhooks.cc
Normal file
@ -0,0 +1,103 @@
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <dlfcn.h>
|
||||
#include <errno.h>
|
||||
#include <fstream>
|
||||
#include <stdlib.h>
|
||||
#include "tools/cpp/runfiles/runfiles.h"
|
||||
|
||||
namespace zml
|
||||
{
|
||||
using bazel::tools::cpp::runfiles::Runfiles;
|
||||
|
||||
std::unique_ptr<Runfiles> runfiles;
|
||||
std::string ROCBLAS_TENSILE_LIBPATH;
|
||||
std::string HIPBLASLT_TENSILE_LIBPATH;
|
||||
std::string HIPBLASLT_EXT_OP_LIBRARY_PATH;
|
||||
std::string ROCM_PATH;
|
||||
|
||||
typedef void *(*dlopen_func)(const char *filename, int flags);
|
||||
dlopen_func dlopen_orig = nullptr;
|
||||
|
||||
__attribute__((constructor)) static void setup(int argc, char **argv)
|
||||
{
|
||||
runfiles = std::unique_ptr<Runfiles>(Runfiles::Create(argv[0], BAZEL_CURRENT_REPOSITORY));
|
||||
|
||||
HIPBLASLT_EXT_OP_LIBRARY_PATH = runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library/hipblasltExtOpLibrary.dat");
|
||||
if (HIPBLASLT_EXT_OP_LIBRARY_PATH != "")
|
||||
{
|
||||
setenv("HIPBLASLT_EXT_OP_LIBRARY_PATH", HIPBLASLT_EXT_OP_LIBRARY_PATH.c_str(), 1);
|
||||
}
|
||||
|
||||
HIPBLASLT_TENSILE_LIBPATH = runfiles->Rlocation("hipblaslt-dev/lib/hipblaslt/library");
|
||||
if (HIPBLASLT_TENSILE_LIBPATH != "")
|
||||
{
|
||||
setenv("HIPBLASLT_TENSILE_LIBPATH", HIPBLASLT_TENSILE_LIBPATH.c_str(), 1);
|
||||
}
|
||||
|
||||
ROCBLAS_TENSILE_LIBPATH = runfiles->Rlocation("rocblas/lib/rocblas/library");
|
||||
setenv("ROCBLAS_TENSILE_LIBPATH", ROCBLAS_TENSILE_LIBPATH.c_str(), 1);
|
||||
|
||||
ROCM_PATH = runfiles->Rlocation("libpjrt_rocm/sandbox");
|
||||
setenv("ROCM_PATH", ROCM_PATH.c_str(), 1);
|
||||
}
|
||||
|
||||
static void *rocm_dlopen(const char *filename, int flags)
|
||||
{
|
||||
if (filename != NULL)
|
||||
{
|
||||
char *replacements[] = {
|
||||
"librocm-core.so",
|
||||
"librocm-core.so.1",
|
||||
"librocm_smi64.so",
|
||||
"librocm_smi64.so.7",
|
||||
"libhsa-runtime64.so",
|
||||
"libhsa-runtime64.so.1",
|
||||
"libhsa-amd-aqlprofile64.so",
|
||||
"libhsa-amd-aqlprofile64.so.1",
|
||||
"libamd_comgr.so",
|
||||
"libamd_comgr.so.2",
|
||||
"librocprofiler-register.so",
|
||||
"librocprofiler-register.so.0",
|
||||
"libMIOpen.so",
|
||||
"libMIOpen.so.1",
|
||||
"librccl.so",
|
||||
"librccl.so.1",
|
||||
"librocblas.so",
|
||||
"librocblas.so.4",
|
||||
"libroctracer64.so",
|
||||
"libroctracer64.so.4",
|
||||
"libroctx64.so",
|
||||
"libroctx64.so.4",
|
||||
"libhipblaslt.so",
|
||||
"libhipblaslt.so.0",
|
||||
"libamdhip64.so",
|
||||
"libamdhip64.so.6",
|
||||
"libhiprtc.so",
|
||||
"libhiprtc.so.6",
|
||||
NULL,
|
||||
NULL,
|
||||
};
|
||||
for (int i = 0; replacements[i] != NULL; i += 2)
|
||||
{
|
||||
if (strcmp(filename, replacements[i]) == 0)
|
||||
{
|
||||
filename = replacements[i + 1];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return dlopen_orig(filename, flags);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C"
|
||||
{
|
||||
zml::dlopen_func _zml_rocm_resolve_dlopen()
|
||||
{
|
||||
zml::dlopen_orig = (zml::dlopen_func)dlsym(RTLD_NEXT, "dlopen");
|
||||
return zml::rocm_dlopen;
|
||||
}
|
||||
|
||||
extern void *dlopen(const char *filename, int flags) __attribute__((ifunc("_zml_rocm_resolve_dlopen")));
|
||||
}
|
||||
5
runtimes/tpu/BUILD.bazel
Normal file
5
runtimes/tpu/BUILD.bazel
Normal file
@ -0,0 +1,5 @@
|
||||
alias(
|
||||
name = "tpu",
|
||||
actual = "@libpjrt_tpu",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
14
runtimes/tpu/libpjrt_tpu.BUILD.bazel
Normal file
14
runtimes/tpu/libpjrt_tpu.BUILD.bazel
Normal file
@ -0,0 +1,14 @@
|
||||
load("@bazel_skylib//rules:copy_file.bzl", "copy_file")
|
||||
|
||||
copy_file(
|
||||
name = "libpjrt_tpu_so",
|
||||
src = "libtpu/libtpu.so",
|
||||
out = "libpjrt_tpu.so",
|
||||
allow_symlink = True,
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "libpjrt_tpu",
|
||||
shared_library = ":libpjrt_tpu_so",
|
||||
visibility = ["@zml//runtimes/tpu:__subpackages__"],
|
||||
)
|
||||
20
runtimes/tpu/tpu.bzl
Normal file
20
runtimes/tpu/tpu.bzl
Normal file
@ -0,0 +1,20 @@
|
||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
|
||||
def _tpu_impl(mctx):
|
||||
# https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||
http_archive(
|
||||
name = "libpjrt_tpu",
|
||||
url = "https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20240915+nightly-py3-none-any.whl",
|
||||
type = "zip",
|
||||
sha256 = "4eee6c7dc92e7e60334c5b0261308a0a07c2f5b6235c7c60f465263576c602bf",
|
||||
build_file = "libpjrt_tpu.BUILD.bazel",
|
||||
)
|
||||
return mctx.extension_metadata(
|
||||
reproducible = True,
|
||||
root_module_direct_deps = ["libpjrt_tpu"],
|
||||
root_module_direct_dev_deps = [],
|
||||
)
|
||||
|
||||
tpu_packages = module_extension(
|
||||
implementation = _tpu_impl,
|
||||
)
|
||||
6
third_party/bazel_registry.json
vendored
Normal file
6
third_party/bazel_registry.json
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
{
|
||||
"mirrors": [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/"
|
||||
],
|
||||
"module_base_path": "."
|
||||
}
|
||||
78
third_party/modules/aspect_bazel_lib/2.8.1.1/MODULE.bazel
vendored
Normal file
78
third_party/modules/aspect_bazel_lib/2.8.1.1/MODULE.bazel
vendored
Normal file
@ -0,0 +1,78 @@
|
||||
"aspect-build/bazel-lib"
|
||||
|
||||
module(
|
||||
name = "aspect_bazel_lib",
|
||||
version = "2.8.1.1",
|
||||
bazel_compatibility = [">=6.0.0"],
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
# Lower-bounds (minimum) versions for direct runtime dependencies
|
||||
bazel_dep(name = "bazel_skylib", version = "1.5.0")
|
||||
bazel_dep(name = "platforms", version = "0.0.8")
|
||||
|
||||
# 0.5.4 is the first version with bzlmod support
|
||||
bazel_dep(name = "stardoc", version = "0.6.2", repo_name = "io_bazel_stardoc")
|
||||
|
||||
bazel_lib_toolchains = use_extension("@aspect_bazel_lib//lib:extensions.bzl", "toolchains")
|
||||
bazel_lib_toolchains.copy_directory()
|
||||
bazel_lib_toolchains.copy_to_directory()
|
||||
bazel_lib_toolchains.jq()
|
||||
bazel_lib_toolchains.yq()
|
||||
bazel_lib_toolchains.coreutils()
|
||||
bazel_lib_toolchains.tar()
|
||||
bazel_lib_toolchains.zstd()
|
||||
bazel_lib_toolchains.expand_template()
|
||||
bazel_lib_toolchains.bats()
|
||||
use_repo(bazel_lib_toolchains, "bats_toolchains", "bsd_tar_toolchains", "copy_directory_toolchains", "copy_to_directory_toolchains", "coreutils_toolchains", "expand_template_toolchains", "jq_toolchains", "yq_toolchains", "zstd_toolchains")
|
||||
|
||||
register_toolchains(
|
||||
"@copy_directory_toolchains//:all",
|
||||
"@copy_to_directory_toolchains//:all",
|
||||
"@jq_toolchains//:all",
|
||||
"@yq_toolchains//:all",
|
||||
"@coreutils_toolchains//:all",
|
||||
"@expand_template_toolchains//:all",
|
||||
"@bats_toolchains//:all",
|
||||
"@bsd_tar_toolchains//:all",
|
||||
"@zstd_toolchains//:all",
|
||||
)
|
||||
|
||||
####### Dev dependencies ########
|
||||
|
||||
# To allow /tools to be built from source
|
||||
# NOTE: when publishing to BCR, we patch this to be dev_dependency, as we publish pre-built binaries
|
||||
# along with our releases.
|
||||
|
||||
bazel_dep(
|
||||
name = "gazelle",
|
||||
version = "0.36.0",
|
||||
dev_dependency = True,
|
||||
)
|
||||
bazel_dep(
|
||||
name = "rules_go",
|
||||
version = "0.46.0",
|
||||
repo_name = "io_bazel_rules_go",
|
||||
dev_dependency = True,
|
||||
)
|
||||
|
||||
go_deps = use_extension(
|
||||
"@gazelle//:extensions.bzl",
|
||||
"go_deps",
|
||||
dev_dependency = True,
|
||||
)
|
||||
go_deps.from_file(go_mod = "//:go.mod")
|
||||
use_repo(
|
||||
go_deps,
|
||||
"com_github_bmatcuk_doublestar_v4",
|
||||
"org_golang_x_exp",
|
||||
"org_golang_x_sys",
|
||||
)
|
||||
|
||||
host = use_extension("@aspect_bazel_lib//lib:extensions.bzl", "host", dev_dependency = True)
|
||||
host.host()
|
||||
use_repo(host, "aspect_bazel_lib_host")
|
||||
|
||||
bazel_dep(name = "bazel_skylib_gazelle_plugin", version = "1.5.0", dev_dependency = True)
|
||||
bazel_dep(name = "buildifier_prebuilt", version = "6.4.0", dev_dependency = True)
|
||||
bazel_dep(name = "bazel_features", version = "0.2.0", dev_dependency = True)
|
||||
@ -0,0 +1,50 @@
|
||||
From 5dbe6a2604e440b684add7b44531898acc8631b4 Mon Sep 17 00:00:00 2001
|
||||
From: Steeve Morin <steeve.morin@gmail.com>
|
||||
Date: Sun, 8 Sep 2024 21:17:29 +0200
|
||||
Subject: [PATCH] fix: add repo mapping to tar archive
|
||||
|
||||
When using bzlmod, runfiles lookup will fail without it.
|
||||
---
|
||||
lib/private/tar.bzl | 17 +++++++++++++----
|
||||
1 file changed, 13 insertions(+), 4 deletions(-)
|
||||
|
||||
diff --git a/lib/private/tar.bzl b/lib/private/tar.bzl
|
||||
index 733ff60..29434f6 100644
|
||||
--- a/lib/private/tar.bzl
|
||||
+++ b/lib/private/tar.bzl
|
||||
@@ -147,12 +147,19 @@ def _tar_impl(ctx):
|
||||
args.add(ctx.file.mtree, format = "@%s")
|
||||
inputs.append(ctx.file.mtree)
|
||||
|
||||
+ src_runfiles = []
|
||||
+ for src in ctx.attr.srcs:
|
||||
+ src_di = src[DefaultInfo]
|
||||
+ if getattr(src_di.files_to_run, "repo_mapping_manifest", None) != None:
|
||||
+ src_runfiles.append(depset(
|
||||
+ direct = [src_di.files_to_run.repo_mapping_manifest],
|
||||
+ transitive = [src_di.default_runfiles.files],
|
||||
+ ))
|
||||
+ else:
|
||||
+ src_runfiles.append(src_di.default_runfiles.files)
|
||||
ctx.actions.run(
|
||||
executable = bsdtar.tarinfo.binary,
|
||||
- inputs = depset(direct = inputs, transitive = [bsdtar.default.files] + [
|
||||
- src[DefaultInfo].default_runfiles.files
|
||||
- for src in ctx.attr.srcs
|
||||
- ]),
|
||||
+ inputs = depset(direct = inputs, transitive = [bsdtar.default.files] + src_runfiles),
|
||||
outputs = [out],
|
||||
arguments = [args],
|
||||
mnemonic = "Tar",
|
||||
@@ -234,6 +241,8 @@ def _mtree_impl(ctx):
|
||||
workspace_name = str(ctx.workspace_name)
|
||||
|
||||
content.add(_mtree_line(runfiles_dir, type = "dir"))
|
||||
+ if getattr(default_info.files_to_run, "repo_mapping_manifest", None) != None:
|
||||
+ content.add(_mtree_line("{}/_repo_mapping".format(runfiles_dir), type = "file", content = default_info.files_to_run.repo_mapping_manifest.path))
|
||||
content.add_all(
|
||||
s.default_runfiles.files,
|
||||
expand_directories = True,
|
||||
--
|
||||
2.39.3 (Apple Git-146)
|
||||
|
||||
27
third_party/modules/aspect_bazel_lib/2.8.1.1/patches/go_dev_dep.patch
vendored
Normal file
27
third_party/modules/aspect_bazel_lib/2.8.1.1/patches/go_dev_dep.patch
vendored
Normal file
@ -0,0 +1,27 @@
|
||||
diff --git a/MODULE.bazel b/MODULE.bazel
|
||||
index e63fa5b..9d78a88 100644
|
||||
--- a/MODULE.bazel
|
||||
+++ b/MODULE.bazel
|
||||
@@ -50,19 +50,19 @@ use_repo(host, "aspect_bazel_lib_host")
|
||||
bazel_dep(
|
||||
name = "gazelle",
|
||||
version = "0.36.0",
|
||||
- # In released versions: dev_dependency = True
|
||||
+ dev_dependency = True,
|
||||
)
|
||||
bazel_dep(
|
||||
name = "rules_go",
|
||||
version = "0.46.0",
|
||||
repo_name = "io_bazel_rules_go",
|
||||
- # In released versions: dev_dependency = True
|
||||
+ dev_dependency = True,
|
||||
)
|
||||
|
||||
go_deps = use_extension(
|
||||
"@gazelle//:extensions.bzl",
|
||||
"go_deps",
|
||||
- # In released versions: dev_dependency = True
|
||||
+ dev_dependency = True,
|
||||
)
|
||||
go_deps.from_file(go_mod = "//:go.mod")
|
||||
use_repo(
|
||||
14
third_party/modules/aspect_bazel_lib/2.8.1.1/patches/module_dot_bazel_version.patch
vendored
Normal file
14
third_party/modules/aspect_bazel_lib/2.8.1.1/patches/module_dot_bazel_version.patch
vendored
Normal file
@ -0,0 +1,14 @@
|
||||
===================================================================
|
||||
--- a/MODULE.bazel
|
||||
+++ b/MODULE.bazel
|
||||
@@ -1,9 +1,9 @@
|
||||
"aspect-build/bazel-lib"
|
||||
|
||||
module(
|
||||
name = "aspect_bazel_lib",
|
||||
- version = "0.0.0",
|
||||
+ version = "2.8.1",
|
||||
bazel_compatibility = [">=6.0.0"],
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
11
third_party/modules/aspect_bazel_lib/2.8.1.1/source.json
vendored
Normal file
11
third_party/modules/aspect_bazel_lib/2.8.1.1/source.json
vendored
Normal file
@ -0,0 +1,11 @@
|
||||
{
|
||||
"integrity": "sha256-aINU7mvuunGUJD1z6wmSuaEujt7u7FtlRPS1MaMRIjc=",
|
||||
"strip_prefix": "bazel-lib-2.8.1",
|
||||
"url": "https://github.com/aspect-build/bazel-lib/releases/download/v2.8.1/bazel-lib-v2.8.1.tar.gz",
|
||||
"patches": {
|
||||
"go_dev_dep.patch": "sha256-DTc/hk+etl4D50M0BLRik2vHbrgDb6rds+Dj4xphWb4=",
|
||||
"module_dot_bazel_version.patch": "sha256-cEMv6bY7Sc5dERv8YG0lq45zuZCsVPNn4oAN9aOkf40=",
|
||||
"0001-fix-add-repo-mapping-to-tar-archive.patch": ""
|
||||
},
|
||||
"patch_strip": 1
|
||||
}
|
||||
25
third_party/modules/aspect_bazel_lib/metadata.json
vendored
Normal file
25
third_party/modules/aspect_bazel_lib/metadata.json
vendored
Normal file
@ -0,0 +1,25 @@
|
||||
{
|
||||
"homepage": "https://docs.aspect.build/rules/aspect_bazel_lib",
|
||||
"maintainers": [
|
||||
{
|
||||
"name": "Alex Eagle",
|
||||
"email": "alex@aspect.dev",
|
||||
"github": "alexeagle"
|
||||
},
|
||||
{
|
||||
"name": "Derek Cormier",
|
||||
"email": "derek@aspect.dev",
|
||||
"github": "kormide"
|
||||
}
|
||||
],
|
||||
"repository": [
|
||||
"github:aspect-build/bazel-lib"
|
||||
],
|
||||
"versions": [
|
||||
"2.8.1.1"
|
||||
],
|
||||
"yanked_versions": {
|
||||
"1.31.0": "1.31.0 has a breaking change to the default yq version",
|
||||
"1.34.2": "The version field has a leading 'v' due to a release automation bug"
|
||||
}
|
||||
}
|
||||
7
third_party/modules/libxev/20240825.0-dbe2291/MODULE.bazel
vendored
Normal file
7
third_party/modules/libxev/20240825.0-dbe2291/MODULE.bazel
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
module(
|
||||
name = "libxev",
|
||||
version = "20240825.0-dbe2291",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "rules_zig", version = "20240904.0-010da15")
|
||||
13
third_party/modules/libxev/20240825.0-dbe2291/overlay/BUILD.bazel
vendored
Normal file
13
third_party/modules/libxev/20240825.0-dbe2291/overlay/BUILD.bazel
vendored
Normal file
@ -0,0 +1,13 @@
|
||||
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||
|
||||
zig_library(
|
||||
name = "xev",
|
||||
srcs = glob([
|
||||
"src/*.zig",
|
||||
"src/backend/*.zig",
|
||||
"src/linux/*.zig",
|
||||
"src/watcher/*.zig",
|
||||
]),
|
||||
main = "src/main.zig",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
7
third_party/modules/libxev/20240825.0-dbe2291/overlay/MODULE.bazel
vendored
Normal file
7
third_party/modules/libxev/20240825.0-dbe2291/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
module(
|
||||
name = "libxev",
|
||||
version = "20240825.0-dbe2291",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "rules_zig", version = "20240904.0-010da15")
|
||||
9
third_party/modules/libxev/20240825.0-dbe2291/source.json
vendored
Normal file
9
third_party/modules/libxev/20240825.0-dbe2291/source.json
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
{
|
||||
"strip_prefix": "libxev-dbe22910a43e9e8ec9948d3cbd73d8488a074967",
|
||||
"url": "https://github.com/mitchellh/libxev/archive/dbe22910a43e9e8ec9948d3cbd73d8488a074967.tar.gz",
|
||||
"integrity": "sha256-JM70bkRUyuAQXKV1piJ6I0IClzCrXV43sdLfIKVC9Lo=",
|
||||
"overlay": {
|
||||
"MODULE.bazel": "",
|
||||
"BUILD.bazel": ""
|
||||
}
|
||||
}
|
||||
7
third_party/modules/libxev/20240910.0-a2d9b31/MODULE.bazel
vendored
Normal file
7
third_party/modules/libxev/20240910.0-a2d9b31/MODULE.bazel
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
module(
|
||||
name = "libxev",
|
||||
version = "20240910.0-a2d9b31",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "rules_zig", version = "20240904.0-010da15")
|
||||
13
third_party/modules/libxev/20240910.0-a2d9b31/overlay/BUILD.bazel
vendored
Normal file
13
third_party/modules/libxev/20240910.0-a2d9b31/overlay/BUILD.bazel
vendored
Normal file
@ -0,0 +1,13 @@
|
||||
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||
|
||||
zig_library(
|
||||
name = "xev",
|
||||
srcs = glob([
|
||||
"src/*.zig",
|
||||
"src/backend/*.zig",
|
||||
"src/linux/*.zig",
|
||||
"src/watcher/*.zig",
|
||||
]),
|
||||
main = "src/main.zig",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
7
third_party/modules/libxev/20240910.0-a2d9b31/overlay/MODULE.bazel
vendored
Normal file
7
third_party/modules/libxev/20240910.0-a2d9b31/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
module(
|
||||
name = "libxev",
|
||||
version = "20240910.0-a2d9b31",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "rules_zig", version = "20240904.0-010da15")
|
||||
9
third_party/modules/libxev/20240910.0-a2d9b31/source.json
vendored
Normal file
9
third_party/modules/libxev/20240910.0-a2d9b31/source.json
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
{
|
||||
"strip_prefix": "libxev-a2d9b3154f1e5ed463c9f2fb98a0d739057796ce",
|
||||
"url": "https://github.com/zml/libxev/archive/a2d9b3154f1e5ed463c9f2fb98a0d739057796ce.tar.gz",
|
||||
"integrity": "sha256-xdMjrGSB3telt52iRqHgem0WSRPjfGlTKBRyDMXLlaQ=",
|
||||
"overlay": {
|
||||
"MODULE.bazel": "",
|
||||
"BUILD.bazel": ""
|
||||
}
|
||||
}
|
||||
18
third_party/modules/libxev/metadata.json
vendored
Normal file
18
third_party/modules/libxev/metadata.json
vendored
Normal file
@ -0,0 +1,18 @@
|
||||
{
|
||||
"homepage": "https://github.com/mitchellh/libxev",
|
||||
"maintainers": [
|
||||
{
|
||||
"email": "bzlmod@zml.ai",
|
||||
"github": "zml",
|
||||
"name": "ZML Engineering Team"
|
||||
}
|
||||
],
|
||||
"repository": [
|
||||
"github:mitchellh/libxev"
|
||||
],
|
||||
"versions": [
|
||||
"20240825.0-dbe2291",
|
||||
"20240910.0-a2d9b31"
|
||||
],
|
||||
"yanked_versions": {}
|
||||
}
|
||||
10
third_party/modules/llvm-raw/20240823.0-f142f8a/MODULE.bazel
vendored
Normal file
10
third_party/modules/llvm-raw/20240823.0-f142f8a/MODULE.bazel
vendored
Normal file
@ -0,0 +1,10 @@
|
||||
module(
|
||||
name = "llvm-raw",
|
||||
version = "20240823.0-f142f8a",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||
bazel_dep(name = "platforms", version = "0.0.10")
|
||||
bazel_dep(name = "zstd", version = "1.5.6", repo_name = "llvm_zstd")
|
||||
bazel_dep(name = "zlib", version = "1.3.1.bcr.3", repo_name = "llvm_zlib")
|
||||
0
third_party/modules/llvm-raw/20240823.0-f142f8a/overlay/BUILD.bazel
vendored
Normal file
0
third_party/modules/llvm-raw/20240823.0-f142f8a/overlay/BUILD.bazel
vendored
Normal file
10
third_party/modules/llvm-raw/20240823.0-f142f8a/overlay/MODULE.bazel
vendored
Normal file
10
third_party/modules/llvm-raw/20240823.0-f142f8a/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,10 @@
|
||||
module(
|
||||
name = "llvm-raw",
|
||||
version = "20240823.0-f142f8a",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||
bazel_dep(name = "platforms", version = "0.0.10")
|
||||
bazel_dep(name = "zstd", version = "1.5.6", repo_name = "llvm_zstd")
|
||||
bazel_dep(name = "zlib", version = "1.3.1.bcr.3", repo_name = "llvm_zlib")
|
||||
28
third_party/modules/llvm-raw/20240823.0-f142f8a/overlay/utils/bazel/extension.bzl
vendored
Normal file
28
third_party/modules/llvm-raw/20240823.0-f142f8a/overlay/utils/bazel/extension.bzl
vendored
Normal file
@ -0,0 +1,28 @@
|
||||
load("//utils/bazel:configure.bzl", _llvm_configure = "llvm_configure")
|
||||
|
||||
def _llvm_impl(mctx):
|
||||
_targets = {}
|
||||
for mod in mctx.modules:
|
||||
for conf in mod.tags.configure:
|
||||
for target in conf.targets:
|
||||
_targets[target] = True
|
||||
_llvm_configure(
|
||||
name = "llvm-project",
|
||||
targets = _targets.keys(),
|
||||
)
|
||||
return mctx.extension_metadata(
|
||||
reproducible = True,
|
||||
root_module_direct_deps = "all",
|
||||
root_module_direct_dev_deps = [],
|
||||
)
|
||||
|
||||
llvm = module_extension(
|
||||
implementation = _llvm_impl,
|
||||
tag_classes = {
|
||||
"configure": tag_class(
|
||||
attrs = {
|
||||
"targets": attr.string_list(mandatory = True),
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
10
third_party/modules/llvm-raw/20240823.0-f142f8a/source.json
vendored
Normal file
10
third_party/modules/llvm-raw/20240823.0-f142f8a/source.json
vendored
Normal file
@ -0,0 +1,10 @@
|
||||
{
|
||||
"strip_prefix": "llvm-project-f142f8afe21bceb00fb495468aa0b5043e98c419",
|
||||
"url": "https://github.com/llvm/llvm-project/archive/f142f8afe21bceb00fb495468aa0b5043e98c419.tar.gz",
|
||||
"integrity": "sha256-8ftYOnMGq99igt2I1n9SOZpYLq79+F7HiKL/7lOwcHs=",
|
||||
"overlay": {
|
||||
"BUILD.bazel": "",
|
||||
"MODULE.bazel": "",
|
||||
"utils/bazel/extension.bzl": ""
|
||||
}
|
||||
}
|
||||
17
third_party/modules/llvm-raw/metadata.json
vendored
Normal file
17
third_party/modules/llvm-raw/metadata.json
vendored
Normal file
@ -0,0 +1,17 @@
|
||||
{
|
||||
"homepage": "https://github.com/llvm/llvm-project",
|
||||
"maintainers": [
|
||||
{
|
||||
"email": "bzlmod@zml.ai",
|
||||
"github": "zml",
|
||||
"name": "ZML Engineering Team"
|
||||
}
|
||||
],
|
||||
"repository": [
|
||||
"github:llvm/llvm-project"
|
||||
],
|
||||
"versions": [
|
||||
"20240823.0-f142f8a"
|
||||
],
|
||||
"yanked_versions": {}
|
||||
}
|
||||
68
third_party/modules/rules_zig/20240904.0-010da15/MODULE.bazel
vendored
Normal file
68
third_party/modules/rules_zig/20240904.0-010da15/MODULE.bazel
vendored
Normal file
@ -0,0 +1,68 @@
|
||||
module(
|
||||
name = "rules_zig",
|
||||
version = "20240904.0-010da15",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "aspect_bazel_lib", version = "2.8.1")
|
||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||
bazel_dep(name = "platforms", version = "0.0.10")
|
||||
|
||||
zig = use_extension("//zig:extensions.bzl", "zig")
|
||||
zig.index(file = "//zig/private:versions.json")
|
||||
use_repo(zig, "zig_toolchains")
|
||||
|
||||
register_toolchains("@rules_zig//zig/target:all")
|
||||
|
||||
register_toolchains("@zig_toolchains//:all")
|
||||
|
||||
zig_dev = use_extension(
|
||||
"//zig:extensions.bzl",
|
||||
"zig",
|
||||
dev_dependency = True,
|
||||
)
|
||||
zig_dev.toolchain(zig_version = "0.13.0")
|
||||
zig_dev.toolchain(zig_version = "0.12.1")
|
||||
zig_dev.toolchain(zig_version = "0.12.0")
|
||||
zig_dev.toolchain(zig_version = "0.11.0")
|
||||
|
||||
bazel_dep(name = "rules_cc", version = "0.0.9", dev_dependency = True)
|
||||
bazel_dep(name = "stardoc", version = "0.7.0", dev_dependency = True, repo_name = "io_bazel_stardoc")
|
||||
bazel_dep(name = "gazelle", version = "0.38.0", dev_dependency = True, repo_name = "bazel_gazelle")
|
||||
bazel_dep(name = "bazel_skylib_gazelle_plugin", version = "1.7.1", dev_dependency = True)
|
||||
bazel_dep(
|
||||
name = "buildifier_prebuilt",
|
||||
version = "7.3.1",
|
||||
dev_dependency = True,
|
||||
)
|
||||
bazel_dep(name = "rules_multirun", version = "0.9.0", dev_dependency = True)
|
||||
bazel_dep(name = "rules_python", version = "0.35.0", dev_dependency = True)
|
||||
bazel_dep(
|
||||
name = "rules_bazel_integration_test",
|
||||
version = "0.25.0",
|
||||
dev_dependency = True,
|
||||
)
|
||||
|
||||
bazel_binaries = use_extension(
|
||||
"@rules_bazel_integration_test//:extensions.bzl",
|
||||
"bazel_binaries",
|
||||
dev_dependency = True,
|
||||
)
|
||||
|
||||
# NOTE: Keep in sync with WORKSPACE.
|
||||
bazel_binaries.download(version_file = "//:.bazelversion")
|
||||
bazel_binaries.download(version = "7.0.0")
|
||||
use_repo(
|
||||
bazel_binaries,
|
||||
"bazel_binaries",
|
||||
"bazel_binaries_bazelisk",
|
||||
"build_bazel_bazel_.bazelversion",
|
||||
"build_bazel_bazel_7_0_0",
|
||||
)
|
||||
|
||||
# TODO[AH] Should be an implicit transitive dependency through rules_bazel_integration_test.
|
||||
# However, if we do not include it explicitly, then the runfiles resolution for
|
||||
# cgrindel_bazel_starlib/shlib/lib/message.sh fails in
|
||||
# rules_bazel_integration_test/tools/update_deleted_packages.sh when invoked
|
||||
# through the rules_multirun target //util:update.
|
||||
bazel_dep(name = "cgrindel_bazel_starlib", version = "0.21.0", dev_dependency = True)
|
||||
5
third_party/modules/rules_zig/20240904.0-010da15/source.json
vendored
Normal file
5
third_party/modules/rules_zig/20240904.0-010da15/source.json
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
{
|
||||
"type": "git_repository",
|
||||
"remote": "https://github.com/zml/rules_zig.git",
|
||||
"commit": "010da15abb4335479778d6b4fb2ca752a0ab80e3"
|
||||
}
|
||||
68
third_party/modules/rules_zig/20240909.0-37f17ff/MODULE.bazel
vendored
Normal file
68
third_party/modules/rules_zig/20240909.0-37f17ff/MODULE.bazel
vendored
Normal file
@ -0,0 +1,68 @@
|
||||
module(
|
||||
name = "rules_zig",
|
||||
version = "20240909.0-37f17ff",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "aspect_bazel_lib", version = "2.8.1")
|
||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||
bazel_dep(name = "platforms", version = "0.0.10")
|
||||
|
||||
zig = use_extension("//zig:extensions.bzl", "zig")
|
||||
zig.index(file = "//zig/private:versions.json")
|
||||
use_repo(zig, "zig_toolchains")
|
||||
|
||||
register_toolchains("@rules_zig//zig/target:all")
|
||||
|
||||
register_toolchains("@zig_toolchains//:all")
|
||||
|
||||
zig_dev = use_extension(
|
||||
"//zig:extensions.bzl",
|
||||
"zig",
|
||||
dev_dependency = True,
|
||||
)
|
||||
zig_dev.toolchain(zig_version = "0.13.0")
|
||||
zig_dev.toolchain(zig_version = "0.12.1")
|
||||
zig_dev.toolchain(zig_version = "0.12.0")
|
||||
zig_dev.toolchain(zig_version = "0.11.0")
|
||||
|
||||
bazel_dep(name = "rules_cc", version = "0.0.9", dev_dependency = True)
|
||||
bazel_dep(name = "stardoc", version = "0.7.0", dev_dependency = True, repo_name = "io_bazel_stardoc")
|
||||
bazel_dep(name = "gazelle", version = "0.38.0", dev_dependency = True, repo_name = "bazel_gazelle")
|
||||
bazel_dep(name = "bazel_skylib_gazelle_plugin", version = "1.7.1", dev_dependency = True)
|
||||
bazel_dep(
|
||||
name = "buildifier_prebuilt",
|
||||
version = "7.3.1",
|
||||
dev_dependency = True,
|
||||
)
|
||||
bazel_dep(name = "rules_multirun", version = "0.9.0", dev_dependency = True)
|
||||
bazel_dep(name = "rules_python", version = "0.35.0", dev_dependency = True)
|
||||
bazel_dep(
|
||||
name = "rules_bazel_integration_test",
|
||||
version = "0.25.0",
|
||||
dev_dependency = True,
|
||||
)
|
||||
|
||||
bazel_binaries = use_extension(
|
||||
"@rules_bazel_integration_test//:extensions.bzl",
|
||||
"bazel_binaries",
|
||||
dev_dependency = True,
|
||||
)
|
||||
|
||||
# NOTE: Keep in sync with WORKSPACE.
|
||||
bazel_binaries.download(version_file = "//:.bazelversion")
|
||||
bazel_binaries.download(version = "7.0.0")
|
||||
use_repo(
|
||||
bazel_binaries,
|
||||
"bazel_binaries",
|
||||
"bazel_binaries_bazelisk",
|
||||
"build_bazel_bazel_.bazelversion",
|
||||
"build_bazel_bazel_7_0_0",
|
||||
)
|
||||
|
||||
# TODO[AH] Should be an implicit transitive dependency through rules_bazel_integration_test.
|
||||
# However, if we do not include it explicitly, then the runfiles resolution for
|
||||
# cgrindel_bazel_starlib/shlib/lib/message.sh fails in
|
||||
# rules_bazel_integration_test/tools/update_deleted_packages.sh when invoked
|
||||
# through the rules_multirun target //util:update.
|
||||
bazel_dep(name = "cgrindel_bazel_starlib", version = "0.21.0", dev_dependency = True)
|
||||
5
third_party/modules/rules_zig/20240909.0-37f17ff/source.json
vendored
Normal file
5
third_party/modules/rules_zig/20240909.0-37f17ff/source.json
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
{
|
||||
"type": "git_repository",
|
||||
"remote": "https://github.com/zml/rules_zig.git",
|
||||
"commit": "37f17ffd2f6d04fc4bf9b583f9a1b117856a88d1"
|
||||
}
|
||||
68
third_party/modules/rules_zig/20240912.0-41bfe84/MODULE.bazel
vendored
Normal file
68
third_party/modules/rules_zig/20240912.0-41bfe84/MODULE.bazel
vendored
Normal file
@ -0,0 +1,68 @@
|
||||
module(
|
||||
name = "rules_zig",
|
||||
version = "20240912.0-41bfe84",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "aspect_bazel_lib", version = "2.8.1")
|
||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||
bazel_dep(name = "platforms", version = "0.0.10")
|
||||
|
||||
zig = use_extension("//zig:extensions.bzl", "zig")
|
||||
zig.index(file = "//zig/private:versions.json")
|
||||
use_repo(zig, "zig_toolchains")
|
||||
|
||||
register_toolchains("@rules_zig//zig/target:all")
|
||||
|
||||
register_toolchains("@zig_toolchains//:all")
|
||||
|
||||
zig_dev = use_extension(
|
||||
"//zig:extensions.bzl",
|
||||
"zig",
|
||||
dev_dependency = True,
|
||||
)
|
||||
zig_dev.toolchain(zig_version = "0.13.0")
|
||||
zig_dev.toolchain(zig_version = "0.12.1")
|
||||
zig_dev.toolchain(zig_version = "0.12.0")
|
||||
zig_dev.toolchain(zig_version = "0.11.0")
|
||||
|
||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||
bazel_dep(name = "stardoc", version = "0.7.0", dev_dependency = True, repo_name = "io_bazel_stardoc")
|
||||
bazel_dep(name = "gazelle", version = "0.38.0", dev_dependency = True, repo_name = "bazel_gazelle")
|
||||
bazel_dep(name = "bazel_skylib_gazelle_plugin", version = "1.7.1", dev_dependency = True)
|
||||
bazel_dep(
|
||||
name = "buildifier_prebuilt",
|
||||
version = "7.3.1",
|
||||
dev_dependency = True,
|
||||
)
|
||||
bazel_dep(name = "rules_multirun", version = "0.9.0", dev_dependency = True)
|
||||
bazel_dep(name = "rules_python", version = "0.35.0", dev_dependency = True)
|
||||
bazel_dep(
|
||||
name = "rules_bazel_integration_test",
|
||||
version = "0.25.0",
|
||||
dev_dependency = True,
|
||||
)
|
||||
|
||||
bazel_binaries = use_extension(
|
||||
"@rules_bazel_integration_test//:extensions.bzl",
|
||||
"bazel_binaries",
|
||||
dev_dependency = True,
|
||||
)
|
||||
|
||||
# NOTE: Keep in sync with WORKSPACE.
|
||||
bazel_binaries.download(version_file = "//:.bazelversion")
|
||||
bazel_binaries.download(version = "7.0.0")
|
||||
use_repo(
|
||||
bazel_binaries,
|
||||
"bazel_binaries",
|
||||
"bazel_binaries_bazelisk",
|
||||
"build_bazel_bazel_.bazelversion",
|
||||
"build_bazel_bazel_7_0_0",
|
||||
)
|
||||
|
||||
# TODO[AH] Should be an implicit transitive dependency through rules_bazel_integration_test.
|
||||
# However, if we do not include it explicitly, then the runfiles resolution for
|
||||
# cgrindel_bazel_starlib/shlib/lib/message.sh fails in
|
||||
# rules_bazel_integration_test/tools/update_deleted_packages.sh when invoked
|
||||
# through the rules_multirun target //util:update.
|
||||
bazel_dep(name = "cgrindel_bazel_starlib", version = "0.21.0", dev_dependency = True)
|
||||
5
third_party/modules/rules_zig/20240912.0-41bfe84/source.json
vendored
Normal file
5
third_party/modules/rules_zig/20240912.0-41bfe84/source.json
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
{
|
||||
"type": "git_repository",
|
||||
"remote": "https://github.com/zml/rules_zig.git",
|
||||
"commit": "41bfe84e4d9a43cbe55281dddc80b683cc6fc6eb"
|
||||
}
|
||||
68
third_party/modules/rules_zig/20240913.0-1957d05/MODULE.bazel
vendored
Normal file
68
third_party/modules/rules_zig/20240913.0-1957d05/MODULE.bazel
vendored
Normal file
@ -0,0 +1,68 @@
|
||||
module(
|
||||
name = "rules_zig",
|
||||
version = "20240913.0-1957d05",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "aspect_bazel_lib", version = "2.8.1")
|
||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||
bazel_dep(name = "platforms", version = "0.0.10")
|
||||
|
||||
zig = use_extension("//zig:extensions.bzl", "zig")
|
||||
zig.index(file = "//zig/private:versions.json")
|
||||
use_repo(zig, "zig_toolchains")
|
||||
|
||||
register_toolchains("@rules_zig//zig/target:all")
|
||||
|
||||
register_toolchains("@zig_toolchains//:all")
|
||||
|
||||
zig_dev = use_extension(
|
||||
"//zig:extensions.bzl",
|
||||
"zig",
|
||||
dev_dependency = True,
|
||||
)
|
||||
zig_dev.toolchain(zig_version = "0.13.0")
|
||||
zig_dev.toolchain(zig_version = "0.12.1")
|
||||
zig_dev.toolchain(zig_version = "0.12.0")
|
||||
zig_dev.toolchain(zig_version = "0.11.0")
|
||||
|
||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||
bazel_dep(name = "stardoc", version = "0.7.0", dev_dependency = True, repo_name = "io_bazel_stardoc")
|
||||
bazel_dep(name = "gazelle", version = "0.38.0", dev_dependency = True, repo_name = "bazel_gazelle")
|
||||
bazel_dep(name = "bazel_skylib_gazelle_plugin", version = "1.7.1", dev_dependency = True)
|
||||
bazel_dep(
|
||||
name = "buildifier_prebuilt",
|
||||
version = "7.3.1",
|
||||
dev_dependency = True,
|
||||
)
|
||||
bazel_dep(name = "rules_multirun", version = "0.9.0", dev_dependency = True)
|
||||
bazel_dep(name = "rules_python", version = "0.35.0", dev_dependency = True)
|
||||
bazel_dep(
|
||||
name = "rules_bazel_integration_test",
|
||||
version = "0.25.0",
|
||||
dev_dependency = True,
|
||||
)
|
||||
|
||||
bazel_binaries = use_extension(
|
||||
"@rules_bazel_integration_test//:extensions.bzl",
|
||||
"bazel_binaries",
|
||||
dev_dependency = True,
|
||||
)
|
||||
|
||||
# NOTE: Keep in sync with WORKSPACE.
|
||||
bazel_binaries.download(version_file = "//:.bazelversion")
|
||||
bazel_binaries.download(version = "7.0.0")
|
||||
use_repo(
|
||||
bazel_binaries,
|
||||
"bazel_binaries",
|
||||
"bazel_binaries_bazelisk",
|
||||
"build_bazel_bazel_.bazelversion",
|
||||
"build_bazel_bazel_7_0_0",
|
||||
)
|
||||
|
||||
# TODO[AH] Should be an implicit transitive dependency through rules_bazel_integration_test.
|
||||
# However, if we do not include it explicitly, then the runfiles resolution for
|
||||
# cgrindel_bazel_starlib/shlib/lib/message.sh fails in
|
||||
# rules_bazel_integration_test/tools/update_deleted_packages.sh when invoked
|
||||
# through the rules_multirun target //util:update.
|
||||
bazel_dep(name = "cgrindel_bazel_starlib", version = "0.21.0", dev_dependency = True)
|
||||
5
third_party/modules/rules_zig/20240913.0-1957d05/source.json
vendored
Normal file
5
third_party/modules/rules_zig/20240913.0-1957d05/source.json
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
{
|
||||
"type": "git_repository",
|
||||
"remote": "https://github.com/zml/rules_zig.git",
|
||||
"commit": "1957d0572193fb859e721a0fab8bd8f0fb57f3ff"
|
||||
}
|
||||
20
third_party/modules/rules_zig/metadata.json
vendored
Normal file
20
third_party/modules/rules_zig/metadata.json
vendored
Normal file
@ -0,0 +1,20 @@
|
||||
{
|
||||
"homepage": "https://github.com/zml/rules_zig",
|
||||
"maintainers": [
|
||||
{
|
||||
"email": "bzlmod@zml.ai",
|
||||
"github": "zml",
|
||||
"name": "ZML Engineering Team"
|
||||
}
|
||||
],
|
||||
"repository": [
|
||||
"github:zml/rules_zig"
|
||||
],
|
||||
"versions": [
|
||||
"20240904.0-010da15",
|
||||
"20240909.0-37f17ff",
|
||||
"20240912.0-41bfe84",
|
||||
"20240913.0-1957d05"
|
||||
],
|
||||
"yanked_versions": {}
|
||||
}
|
||||
7
third_party/modules/sentencepiece/20240618.0-d7ace0a/MODULE.bazel
vendored
Normal file
7
third_party/modules/sentencepiece/20240618.0-d7ace0a/MODULE.bazel
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
module(
|
||||
name = "sentencepiece",
|
||||
version = "20240618.0-d7ace0a",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||
7
third_party/modules/sentencepiece/20240618.0-d7ace0a/overlay/BUILD.bazel
vendored
Normal file
7
third_party/modules/sentencepiece/20240618.0-d7ace0a/overlay/BUILD.bazel
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
load("@rules_proto//proto:defs.bzl", "proto_library")
|
||||
|
||||
proto_library(
|
||||
name = "sentencepiece_model_proto",
|
||||
srcs = ["src/sentencepiece_model.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
7
third_party/modules/sentencepiece/20240618.0-d7ace0a/overlay/MODULE.bazel
vendored
Normal file
7
third_party/modules/sentencepiece/20240618.0-d7ace0a/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
module(
|
||||
name = "sentencepiece",
|
||||
version = "20240618.0-d7ace0a",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||
9
third_party/modules/sentencepiece/20240618.0-d7ace0a/source.json
vendored
Normal file
9
third_party/modules/sentencepiece/20240618.0-d7ace0a/source.json
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
{
|
||||
"strip_prefix": "sentencepiece-d7ace0a4df17af17d73dc8fe6ff92cfd8d08e2b8",
|
||||
"url": "https://github.com/google/sentencepiece/archive/d7ace0a4df17af17d73dc8fe6ff92cfd8d08e2b8.tar.gz",
|
||||
"integrity": "sha256-bt8QFNNIpqcJoCNMUrmWfUf0uNobtk8pW+XbHL+ZRLg=",
|
||||
"overlay": {
|
||||
"MODULE.bazel": "",
|
||||
"BUILD.bazel": ""
|
||||
}
|
||||
}
|
||||
17
third_party/modules/sentencepiece/metadata.json
vendored
Normal file
17
third_party/modules/sentencepiece/metadata.json
vendored
Normal file
@ -0,0 +1,17 @@
|
||||
{
|
||||
"homepage": "https://github.com/google/sentencepiece",
|
||||
"maintainers": [
|
||||
{
|
||||
"email": "bzlmod@zml.ai",
|
||||
"github": "zml",
|
||||
"name": "ZML Engineering Team"
|
||||
}
|
||||
],
|
||||
"repository": [
|
||||
"google/sentencepiece"
|
||||
],
|
||||
"versions": [
|
||||
"20240618.0-d7ace0a"
|
||||
],
|
||||
"yanked_versions": {}
|
||||
}
|
||||
15
third_party/modules/stablehlo/20240829.0-54aa1a5/MODULE.bazel
vendored
Normal file
15
third_party/modules/stablehlo/20240829.0-54aa1a5/MODULE.bazel
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
module(
|
||||
name = "stablehlo",
|
||||
version = "20240829.0-54aa1a5",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||
bazel_dep(name = "llvm-raw", version = "20240823.0-f142f8a")
|
||||
|
||||
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
||||
llvm.configure(
|
||||
targets = ["AArch64", "X86", "NVPTX"],
|
||||
)
|
||||
use_repo(llvm, "llvm-project")
|
||||
15
third_party/modules/stablehlo/20240829.0-54aa1a5/overlay/MODULE.bazel
vendored
Normal file
15
third_party/modules/stablehlo/20240829.0-54aa1a5/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
module(
|
||||
name = "stablehlo",
|
||||
version = "20240829.0-54aa1a5",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||
bazel_dep(name = "llvm-raw", version = "20240823.0-f142f8a")
|
||||
|
||||
llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm")
|
||||
llvm.configure(
|
||||
targets = ["AArch64", "X86", "NVPTX"],
|
||||
)
|
||||
use_repo(llvm, "llvm-project")
|
||||
@ -0,0 +1,83 @@
|
||||
From f11d9bd7639f63cba681caa39cea27af114a4a71 Mon Sep 17 00:00:00 2001
|
||||
From: Steeve Morin <steeve.morin@gmail.com>
|
||||
Date: Mon, 2 Sep 2024 23:02:28 +0200
|
||||
Subject: [PATCH 1/2] Remove duplicated symbols in StablehloApi.h
|
||||
|
||||
This causes C compilers to choke.
|
||||
|
||||
Refs #2494
|
||||
---
|
||||
stablehlo/integrations/c/StablehloApi.cpp | 2 +-
|
||||
stablehlo/integrations/c/StablehloApi.h | 15 +--------------
|
||||
stablehlo/integrations/python/StablehloApi.cpp | 2 +-
|
||||
3 files changed, 3 insertions(+), 16 deletions(-)
|
||||
|
||||
diff --git a/stablehlo/integrations/c/StablehloApi.cpp b/stablehlo/integrations/c/StablehloApi.cpp
|
||||
index 8d922198..9b41ff21 100644
|
||||
--- a/stablehlo/integrations/c/StablehloApi.cpp
|
||||
+++ b/stablehlo/integrations/c/StablehloApi.cpp
|
||||
@@ -98,7 +98,7 @@ MlirLogicalResult stablehloSerializePortableArtifact(
|
||||
return mlirLogicalResultSuccess();
|
||||
}
|
||||
|
||||
-MlirLogicalResult stablehloDeserializePortableArtifact(
|
||||
+MlirLogicalResult stablehloDeserializePortableArtifactToBytecode(
|
||||
MlirStringRef artifactStr, MlirStringCallback callback, void *userData) {
|
||||
mlir::detail::CallbackOstream stream(callback, userData);
|
||||
if (failed(mlir::stablehlo::deserializePortableArtifact(unwrap(artifactStr),
|
||||
diff --git a/stablehlo/integrations/c/StablehloApi.h b/stablehlo/integrations/c/StablehloApi.h
|
||||
index 4c542508..f94dca8d 100644
|
||||
--- a/stablehlo/integrations/c/StablehloApi.h
|
||||
+++ b/stablehlo/integrations/c/StablehloApi.h
|
||||
@@ -76,16 +76,6 @@ MLIR_CAPI_EXPORTED MlirLogicalResult stablehloSerializePortableArtifact(
|
||||
MlirStringRef moduleStr, MlirStringRef targetVersion,
|
||||
MlirStringCallback callback, void* userData);
|
||||
|
||||
-// Write a StableHLO program expressed as a string (either prettyprinted MLIR
|
||||
-// module or MLIR bytecode) to a portable artifact.
|
||||
-// Can fail if `moduleStr` cannot be parsed, or if it cannot be expressed in the
|
||||
-// `targetVersion` version of StableHLO, e.g. if it's using new or removed
|
||||
-// features, or if it involves unsupported dialects.
|
||||
-// Returns false on failure.
|
||||
-MLIR_CAPI_EXPORTED MlirLogicalResult stablehloSerializePortableArtifact(
|
||||
- MlirModule moduleStr, MlirStringRef targetVersion,
|
||||
- MlirStringCallback callback, void* userData);
|
||||
-
|
||||
// Read a StableHLO program from a portable artifact, returning the module as
|
||||
// MLIR bytecode. Note, this bytecode returned is not a portable artifact,
|
||||
// and has the stability of returning textual assembly format. Bytecode is
|
||||
@@ -93,7 +83,7 @@ MLIR_CAPI_EXPORTED MlirLogicalResult stablehloSerializePortableArtifact(
|
||||
// Can fail if `artifactStr` cannot be expressed in the current version of
|
||||
// StableHLO, e.g. if it's using incompatible features.
|
||||
// Returns false on failure.
|
||||
-MLIR_CAPI_EXPORTED MlirLogicalResult stablehloDeserializePortableArtifact(
|
||||
+MLIR_CAPI_EXPORTED MlirLogicalResult stablehloDeserializePortableArtifactAsBytecode(
|
||||
MlirStringRef artifactStr, MlirStringCallback callback, void* userData);
|
||||
|
||||
// Read a StableHLO program from a portable artifact, returning the module as
|
||||
@@ -109,9 +99,6 @@ MLIR_CAPI_EXPORTED MlirModule stablehloDeserializePortableArtifact(
|
||||
|
||||
// Call the Interpreter, returns MlirArrayAttr of dense element
|
||||
// MlirAttribute results
|
||||
-MLIR_CAPI_EXPORTED MlirModule stablehloDeserializePortableArtifact(
|
||||
- MlirStringRef artifactStr, MlirContext ctx);
|
||||
-
|
||||
// Entrypoint for calling the StableHLO reference interpreter.
|
||||
// Returns an array attribute of dense element attributes for results.
|
||||
// Sets error code to non-zero on failure.
|
||||
diff --git a/stablehlo/integrations/python/StablehloApi.cpp b/stablehlo/integrations/python/StablehloApi.cpp
|
||||
index 46a640e1..4229ef76 100644
|
||||
--- a/stablehlo/integrations/python/StablehloApi.cpp
|
||||
+++ b/stablehlo/integrations/python/StablehloApi.cpp
|
||||
@@ -213,7 +213,7 @@ void AddPortableApi(py::module &m) {
|
||||
"deserialize_portable_artifact_str",
|
||||
[](std::string_view artifact) -> py::bytes {
|
||||
StringWriterHelper accumulator;
|
||||
- if (mlirLogicalResultIsFailure(stablehloDeserializePortableArtifact(
|
||||
+ if (mlirLogicalResultIsFailure(stablehloDeserializePortableArtifactToBytecode(
|
||||
toMlirStringRef(artifact), accumulator.getMlirStringCallback(),
|
||||
accumulator.getUserData()))) {
|
||||
PyErr_SetString(PyExc_ValueError, "failed to deserialize module");
|
||||
--
|
||||
2.39.3 (Apple Git-146)
|
||||
|
||||
401
third_party/modules/stablehlo/20240829.0-54aa1a5/patches/temporary.patch
vendored
Normal file
401
third_party/modules/stablehlo/20240829.0-54aa1a5/patches/temporary.patch
vendored
Normal file
@ -0,0 +1,401 @@
|
||||
diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel
|
||||
--- stablehlo/BUILD.bazel
|
||||
+++ stablehlo/BUILD.bazel
|
||||
@@ -340,6 +340,21 @@
|
||||
],
|
||||
)
|
||||
|
||||
+gentbl_cc_library(
|
||||
+ name = "stablehlo_create_compatibility_expander_inc_gen",
|
||||
+ tbl_outs = [
|
||||
+ (
|
||||
+ ["--gen-rewriters"],
|
||||
+ "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc",
|
||||
+ ),
|
||||
+ ],
|
||||
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
+ td_file = "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td",
|
||||
+ deps = [
|
||||
+ ":stablehlo_ops_td_files",
|
||||
+ ],
|
||||
+)
|
||||
+
|
||||
cc_library(
|
||||
name = "interpreter_ops",
|
||||
srcs = [
|
||||
@@ -1086,6 +1101,7 @@
|
||||
"stablehlo/transforms/StablehloAggressiveSimplification.cpp",
|
||||
"stablehlo/transforms/StablehloCanonicalizeDynamism.cpp",
|
||||
"stablehlo/transforms/StablehloConvertToSignless.cpp",
|
||||
+ "stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp",
|
||||
"stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp",
|
||||
"stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp",
|
||||
"stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp",
|
||||
@@ -1109,6 +1125,7 @@
|
||||
":chlo_ops",
|
||||
":chlo_rewriters_inc_gen",
|
||||
":linalg_passes",
|
||||
+ ":stablehlo_create_compatibility_expander_inc_gen",
|
||||
":stablehlo_legalize_deprecated_ops_inc_gen",
|
||||
":stablehlo_ops",
|
||||
":stablehlo_ops_inc_gen",
|
||||
diff --ruN a/stablehlo/stablehlo/dialect/Version.cpp b/stablehlo/stablehlo/dialect/Version.cpp
|
||||
--- stablehlo/stablehlo/dialect/Version.cpp
|
||||
+++ stablehlo/stablehlo/dialect/Version.cpp
|
||||
@@ -82,7 +82,7 @@
|
||||
case CompatibilityRequirement::WEEK_4:
|
||||
return Version(1, 3, 0); // v1.3.0 - Jul 15, 2024
|
||||
case CompatibilityRequirement::WEEK_12:
|
||||
- return Version(1, 0, 0); // v1.0.0 - May 14, 2024
|
||||
+ return Version(1, 1, 0); // v1.1.0 - May 30, 2024
|
||||
case CompatibilityRequirement::MAX:
|
||||
return Version::getMinimumVersion();
|
||||
}
|
||||
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
|
||||
--- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
|
||||
+++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
|
||||
@@ -0,0 +1,43 @@
|
||||
+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect --stablehlo-create-compatibility-expander='target=1.0.0' | FileCheck %s --check-prefixes=CHECK
|
||||
+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.6.0' | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE
|
||||
+
|
||||
+// -----
|
||||
+
|
||||
+// CHECK-LABEL @tan_op_non_complex
|
||||
+// CHECK: %[[sine0:.*]] = stablehlo.sine %arg0 : tensor<4xf64>
|
||||
+// CHECK-NEXT: %[[cosine1:.*]] = stablehlo.cosine %arg0 : tensor<4xf64>
|
||||
+// CHECK-NEXT: %[[div2:.*]] = stablehlo.divide %[[sine0]], %[[cosine1]] : tensor<4xf64>
|
||||
+// CHECK-NEXT: return %[[div2]] : tensor<4xf64>
|
||||
+func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
|
||||
+ // CHECK-NO-DOWNGRADE: stablehlo.tan %arg0 : tensor<4xf64>
|
||||
+ %1 = stablehlo.tan %arg0 : tensor<4xf64>
|
||||
+ func.return %1 : tensor<4xf64>
|
||||
+}
|
||||
+
|
||||
+// -----
|
||||
+
|
||||
+// CHECK-LABEL: @tan_op_complex
|
||||
+// CHECK: %[[cst:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<4xf64>
|
||||
+// CHECK: %[[complex:.*]] = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex<f64>>
|
||||
+// CHECK: %[[real:.*]] = stablehlo.real %[[complex]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
|
||||
+// CHECK: %[[sine:.*]] = stablehlo.sine %[[real]] : tensor<4xf64>
|
||||
+// CHECK: %[[cosine:.*]] = stablehlo.cosine %[[real]] : tensor<4xf64>
|
||||
+// CHECK: %[[divide1:.*]] = stablehlo.divide %[[sine]], %[[cosine]] : tensor<4xf64>
|
||||
+// CHECK: %[[imag:.*]] = stablehlo.imag %[[complex]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
|
||||
+// CHECK: %[[tanh:.*]] = stablehlo.tanh %[[imag]] : tensor<4xf64>
|
||||
+// CHECK: %[[complex2:.*]] = stablehlo.complex %[[divide1]], %[[tanh]] : tensor<4xcomplex<f64>>
|
||||
+// CHECK: %[[multiply:.*]] = stablehlo.multiply %[[divide1]], %[[tanh]] : tensor<4xf64>
|
||||
+// CHECK: %[[negate:.*]] = stablehlo.negate %[[multiply]] : tensor<4xf64>
|
||||
+// CHECK: %[[complex3:.*]] = stablehlo.complex %[[cst]], %[[negate]] : tensor<4xcomplex<f64>>
|
||||
+// CHECK: %[[divide2:.*]] = stablehlo.divide %[[complex2]], %[[complex3]] : tensor<4xcomplex<f64>>
|
||||
+// CHECK: %[[real2:.*]] = stablehlo.real %[[divide2]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
|
||||
+// CHECK: %[[imag2:.*]] = stablehlo.imag %[[divide2]] : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
|
||||
+// CHECK: return %[[real2]], %[[imag2]] : tensor<4xf64>, tensor<4xf64>
|
||||
+func.func @tan_op_complex(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>) -> (tensor<4xf64>, tensor<4xf64>) {
|
||||
+ %0 = stablehlo.complex %arg0, %arg1 : tensor<4xcomplex<f64>>
|
||||
+ // CHECK-NO-DOWNGRADE: stablehlo.tan %0 : tensor<4xcomplex<f64>>
|
||||
+ %1 = stablehlo.tan %0 : tensor<4xcomplex<f64>>
|
||||
+ %2 = stablehlo.real %1 : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
|
||||
+ %3 = stablehlo.imag %1 : (tensor<4xcomplex<f64>>) -> tensor<4xf64>
|
||||
+ func.return %2, %3 : tensor<4xf64>, tensor<4xf64>
|
||||
+}
|
||||
diff --ruN a/stablehlo/stablehlo/transforms/CMakeLists.txt b/stablehlo/stablehlo/transforms/CMakeLists.txt
|
||||
--- stablehlo/stablehlo/transforms/CMakeLists.txt
|
||||
+++ stablehlo/stablehlo/transforms/CMakeLists.txt
|
||||
@@ -20,6 +20,10 @@
|
||||
mlir_tablegen(ChloDecompositionPatterns.h.inc --gen-rewriters)
|
||||
add_public_tablegen_target(ChloDecompositionPatternsIncGen)
|
||||
|
||||
+set(LLVM_TARGET_DEFINITIONS StablehloCreateCompatibilityExpanderPatterns.td)
|
||||
+mlir_tablegen(StablehloCreateCompatibilityExpanderPatterns.h.inc --gen-rewriters)
|
||||
+add_public_tablegen_target(StablehloCreateCompatibilityExpanderPatternsIncGen)
|
||||
+
|
||||
set(LLVM_TARGET_DEFINITIONS StablehloLegalizeDeprecatedOpsPatterns.td)
|
||||
mlir_tablegen(StablehloLegalizeDeprecatedOpsPatterns.h.inc --gen-rewriters)
|
||||
add_public_tablegen_target(StablehloLegalizeDeprecatedOpsPatternsIncGen)
|
||||
@@ -27,6 +31,7 @@
|
||||
set(LLVM_TARGET_DEFINITIONS VhloToVersionPatterns.td)
|
||||
mlir_tablegen(VhloToVersionPatterns.h.inc --gen-rewriters)
|
||||
add_public_tablegen_target(VhloToVersionPatterns)
|
||||
+
|
||||
|
||||
add_mlir_dialect_library(StablehloPasses
|
||||
PARTIAL_SOURCES_INTENDED
|
||||
@@ -37,6 +42,7 @@
|
||||
StablehloAggressiveSimplification.cpp
|
||||
StablehloCanonicalizeDynamism.cpp
|
||||
StablehloConvertToSignless.cpp
|
||||
+ StablehloCreateCompatibilityExpander.cpp
|
||||
StablehloLegalizeCompositeToCall.cpp
|
||||
StablehloLegalizeDeprecatedOps.cpp
|
||||
StablehloLegalizeQuantToMath.cpp
|
||||
@@ -53,6 +59,7 @@
|
||||
StablehloLegalizeDeprecatedOpsPatternsIncGen
|
||||
PassesIncGen
|
||||
VhloToVersionPatterns
|
||||
+ StablehloCreateCompatibilityExpanderPatternsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
ChloOps
|
||||
diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h
|
||||
--- stablehlo/stablehlo/transforms/Passes.h
|
||||
+++ stablehlo/stablehlo/transforms/Passes.h
|
||||
@@ -25,6 +25,7 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
+#include "stablehlo/dialect/Version.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace stablehlo {
|
||||
@@ -96,6 +97,12 @@
|
||||
void populateShapeToStablehloPatterns(MLIRContext *context,
|
||||
RewritePatternSet *patterns);
|
||||
|
||||
+/// Collection of patterns to create compatibility expander for StableHLO
|
||||
+/// operations.
|
||||
+void populateStablehloCreateCompatibilityExpanderPatterns(
|
||||
+ RewritePatternSet *patterns, MLIRContext *context,
|
||||
+ vhlo::Version targetVersion);
|
||||
+
|
||||
//// Additional pass constructors ////
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createStablehloRefineArgumentsPass(
|
||||
diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td
|
||||
--- stablehlo/stablehlo/transforms/Passes.td
|
||||
+++ stablehlo/stablehlo/transforms/Passes.td
|
||||
@@ -292,3 +292,51 @@
|
||||
"mlir::stablehlo::StablehloDialect",
|
||||
];
|
||||
}
|
||||
+
|
||||
+def StablehloCreateCompatibilityExpanderPass : Pass<"stablehlo-create-compatibility-expander", "mlir::func::FuncOp"> {
|
||||
+ let summary = "Create compatibility expander for StableHLO operations.";
|
||||
+
|
||||
+ let description = [{
|
||||
+ StableHLO ops gets updates or new op is introduced in the latest versions.
|
||||
+ This opt-in pass expands backward compatibility with older StableHLO
|
||||
+ versions by decomposing newer StableHLO operations into equivalent
|
||||
+ operations supported by those older versions.
|
||||
+
|
||||
+ Why is this an opt-in pass?
|
||||
+
|
||||
+ Occasionally, StableHLO op enhancements are used to greatly simplify the
|
||||
+ handling of certain common patterns in the OpenXLA ecosystem. This
|
||||
+ includes things like TanOp, which has high framework and compiler support,
|
||||
+ as well as gather/scatter batching dimensions, which can be represented
|
||||
+ using slices, but makes sharding much more difficult. For this category of
|
||||
+ new features, we do not offer automatic downgrade, since it may throw away
|
||||
+ important information used in subsequent optimizations. This pass can be
|
||||
+ used to expand these ops based on a target version to maximize compatibility
|
||||
+ at the expense of potentially less optimal compilation.
|
||||
+
|
||||
+ ```mlir
|
||||
+ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
|
||||
+ %1 = stablehlo.tan %arg0 : tensor<4xf64>
|
||||
+ func.return %1 : tensor<4xf64>
|
||||
+ }
|
||||
+ ```
|
||||
+
|
||||
+ will become:
|
||||
+
|
||||
+ ```mlir
|
||||
+ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
|
||||
+ %0 = stablehlo.sine %arg0 : tensor<4xf64>
|
||||
+ %1 = stablehlo.cosine %arg0 : tensor<4xf64>
|
||||
+ %2 = stablehlo.divide %0, %1 : tensor<4xf64>
|
||||
+ return %2 : tensor<4xf64>
|
||||
+ }
|
||||
+ ```
|
||||
+ }];
|
||||
+ let options = [
|
||||
+ Option<"targetVersionOption", "target", "std::string", "",
|
||||
+ "The target version. Must be a version of the form #.#.#.">,
|
||||
+ ];
|
||||
+ let dependentDialects = [
|
||||
+ "mlir::stablehlo::StablehloDialect",
|
||||
+ ];
|
||||
+}
|
||||
diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
|
||||
--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
|
||||
+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
|
||||
@@ -0,0 +1,128 @@
|
||||
+/* Copyright 2024 The StableHLO Authors. All Rights Reserved.
|
||||
+Licensed under the Apache License, Version 2.0 (the "License");
|
||||
+you may not use this file except in compliance with the License.
|
||||
+You may obtain a copy of the License at
|
||||
+ http://www.apache.org/licenses/LICENSE-2.0
|
||||
+Unless required by applicable law or agreed to in writing, software
|
||||
+distributed under the License is distributed on an "AS IS" BASIS,
|
||||
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
+See the License for the specific language governing permissions and
|
||||
+limitations under the License.
|
||||
+==============================================================================*/
|
||||
+
|
||||
+#include <fcntl.h>
|
||||
+
|
||||
+#include <cassert>
|
||||
+
|
||||
+#include "llvm/ADT/APFloat.h"
|
||||
+#include "llvm/Support/ErrorHandling.h"
|
||||
+#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
+#include "mlir/IR/BuiltinAttributes.h"
|
||||
+#include "mlir/IR/PatternMatch.h"
|
||||
+#include "mlir/Support/LLVM.h"
|
||||
+#include "mlir/Transforms/DialectConversion.h"
|
||||
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
+#include "stablehlo/dialect/StablehloOps.h"
|
||||
+#include "stablehlo/dialect/Version.h"
|
||||
+#include "stablehlo/transforms/Passes.h"
|
||||
+
|
||||
+namespace mlir {
|
||||
+namespace stablehlo {
|
||||
+#define GEN_PASS_DEF_STABLEHLOCREATECOMPATIBILITYEXPANDERPASS
|
||||
+#include "stablehlo/transforms/Passes.h.inc"
|
||||
+
|
||||
+namespace {
|
||||
+
|
||||
+//===----------------------------------------------------------------------===//
|
||||
+// Helpers.
|
||||
+//===----------------------------------------------------------------------===//
|
||||
+
|
||||
+// Creates a constant with all ones.
|
||||
+static Value createConstantWithAllOnes(OpBuilder &b, Location loc, Value val) {
|
||||
+ auto shapedTy = dyn_cast<mlir::ShapedType>(val.getType());
|
||||
+ if (!shapedTy) llvm_unreachable("Unsupported shaped type.");
|
||||
+
|
||||
+ mlir::DenseElementsAttr elementsAttr =
|
||||
+ mlir::DenseElementsAttr::get(shapedTy, 1.0);
|
||||
+
|
||||
+ return b.create<mlir::stablehlo::ConstantOp>(loc, val.getType(),
|
||||
+ elementsAttr);
|
||||
+}
|
||||
+
|
||||
+// Check user-specified target version.
|
||||
+vhlo::Version validateTargetVersion(llvm::StringRef versionRef) {
|
||||
+ auto failOrVersion = vhlo::Version::fromString(versionRef);
|
||||
+ if (failed(failOrVersion)) {
|
||||
+ assert(!versionRef.empty() &&
|
||||
+ "No target version specified. Target version must be of the form "
|
||||
+ "`#.#.#`.");
|
||||
+ assert(versionRef.empty() &&
|
||||
+ "Invalid target version argument. Target version must be of the "
|
||||
+ "form `#.#.#`.");
|
||||
+ }
|
||||
+ vhlo::Version targetVersion = *failOrVersion;
|
||||
+ assert((vhlo::Version::getMinimumVersion() <= targetVersion) &&
|
||||
+ "target version is less than minimum supported.");
|
||||
+ assert((targetVersion <= vhlo::Version::getCurrentVersion()) &&
|
||||
+ "target version is greater than current version.");
|
||||
+ return targetVersion;
|
||||
+}
|
||||
+
|
||||
+//===----------------------------------------------------------------------===//
|
||||
+// Pass
|
||||
+//===----------------------------------------------------------------------===//
|
||||
+
|
||||
+struct StablehloCreateCompatibilityExpanderPass
|
||||
+ : public impl::StablehloCreateCompatibilityExpanderPassBase<
|
||||
+ StablehloCreateCompatibilityExpanderPass> {
|
||||
+ StablehloCreateCompatibilityExpanderPass()
|
||||
+ : StablehloCreateCompatibilityExpanderPassBase<
|
||||
+ StablehloCreateCompatibilityExpanderPass>() {}
|
||||
+ StablehloCreateCompatibilityExpanderPass(
|
||||
+ const StablehloCreateCompatibilityExpanderPassOptions &opts)
|
||||
+ : StablehloCreateCompatibilityExpanderPassBase<
|
||||
+ StablehloCreateCompatibilityExpanderPass>(opts) {}
|
||||
+
|
||||
+ public:
|
||||
+ LogicalResult initialize(MLIRContext *context) override {
|
||||
+ auto targetVersion = validateTargetVersion(targetVersionOption);
|
||||
+
|
||||
+ config.useTopDownTraversal = true;
|
||||
+ RewritePatternSet patterns_(context);
|
||||
+ populateStablehloCreateCompatibilityExpanderPatterns(&patterns_, context,
|
||||
+ targetVersion);
|
||||
+ patterns = std::move(patterns_);
|
||||
+ return success();
|
||||
+ }
|
||||
+
|
||||
+ void runOnOperation() override {
|
||||
+ auto func = getOperation();
|
||||
+ if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
|
||||
+ func.emitError(
|
||||
+ "Failed to converge StableHLOCreateCompatibilityExpanderPass in ")
|
||||
+ << config.maxIterations << " iterations";
|
||||
+ signalPassFailure();
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ private:
|
||||
+ FrozenRewritePatternSet patterns;
|
||||
+ GreedyRewriteConfig config;
|
||||
+};
|
||||
+
|
||||
+#include "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc"
|
||||
+
|
||||
+} // namespace
|
||||
+
|
||||
+void populateStablehloCreateCompatibilityExpanderPatterns(
|
||||
+ RewritePatternSet *patterns, MLIRContext *context,
|
||||
+ vhlo::Version targetVersion) {
|
||||
+ // StableHLO TanOp is introduced in v1.4.0.
|
||||
+ if (targetVersion < vhlo::Version(1, 4, 0)) {
|
||||
+ patterns->add<TanOp_ComplexElementType_CompatiblityExpander>(context);
|
||||
+ patterns->add<TanOp_CompatiblityExpander>(context);
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+} // namespace stablehlo
|
||||
+} // namespace mlir
|
||||
diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
|
||||
--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
|
||||
+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
|
||||
@@ -0,0 +1,47 @@
|
||||
+/* Copyright 2022 The StableHLO Authors.
|
||||
+
|
||||
+Licensed under the Apache License, Version 2.0 (the "License");
|
||||
+you may not use this file except in compliance with the License.
|
||||
+You may obtain a copy of the License at
|
||||
+
|
||||
+ http://www.apache.org/licenses/LICENSE-2.0
|
||||
+
|
||||
+Unless required by applicable law or agreed to in writing, software
|
||||
+distributed under the License is distributed on an "AS IS" BASIS,
|
||||
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
+See the License for the specific language governing permissions and
|
||||
+limitations under the License.
|
||||
+==============================================================================*/
|
||||
+
|
||||
+include "mlir/IR/OpBase.td"
|
||||
+include "stablehlo/dialect/StablehloOps.td"
|
||||
+
|
||||
+def ComplexElementType : Type<
|
||||
+ CPred<"isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
|
||||
+ "Complex element type">;
|
||||
+
|
||||
+def NonComplexElementType : Type<
|
||||
+ CPred<"!isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
|
||||
+ "Non-complex element type">;
|
||||
+
|
||||
+def createConstantWithAllOnes : NativeCodeCall<"createConstantWithAllOnes($_builder, $_loc, $0)">;
|
||||
+
|
||||
+// Express `tan` as
|
||||
+// sine(x) / cosine(x)
|
||||
+def TanOp_CompatiblityExpander : Pat<(StableHLO_TanOp NonComplexElementType:$input),
|
||||
+ (StableHLO_DivOp
|
||||
+ (StableHLO_SineOp $input),
|
||||
+ (StableHLO_CosineOp $input)
|
||||
+ )>;
|
||||
+
|
||||
+// Express `tan(a + bi)` as
|
||||
+// (tan(a) + i tanh(b)) / (1 - i tan(a) * tanh(b))
|
||||
+def TanOp_ComplexElementType_CompatiblityExpander : Pat<(StableHLO_TanOp ComplexElementType:$input),
|
||||
+ (StableHLO_DivOp
|
||||
+ (StableHLO_ComplexOp
|
||||
+ (StableHLO_TanOp:$tan (StableHLO_RealOp $input)),
|
||||
+ (StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input))),
|
||||
+ (StableHLO_ComplexOp
|
||||
+ (createConstantWithAllOnes $tan),
|
||||
+ (StableHLO_NegOp (StableHLO_MulOp $tan, $tanh)))
|
||||
+ )>;
|
||||
|
||||
12
third_party/modules/stablehlo/20240829.0-54aa1a5/source.json
vendored
Normal file
12
third_party/modules/stablehlo/20240829.0-54aa1a5/source.json
vendored
Normal file
@ -0,0 +1,12 @@
|
||||
{
|
||||
"strip_prefix": "stablehlo-54aa1a57178251981da616b877dda1a88d840d11",
|
||||
"url": "https://github.com/openxla/stablehlo/archive/54aa1a57178251981da616b877dda1a88d840d11.tar.gz",
|
||||
"integrity": "sha256-PFXXs3vQ+WlMiCC8twbGEFbsBEF/67Z4qtmq8CtUSUU=",
|
||||
"overlay": {
|
||||
"MODULE.bazel": ""
|
||||
},
|
||||
"patch_strip": 1,
|
||||
"patches": {
|
||||
"0001-Remove-duplicated-symbols-in-StablehloApi.h.patch": ""
|
||||
}
|
||||
}
|
||||
17
third_party/modules/stablehlo/metadata.json
vendored
Normal file
17
third_party/modules/stablehlo/metadata.json
vendored
Normal file
@ -0,0 +1,17 @@
|
||||
{
|
||||
"homepage": "https://github.com/openxla/stablehlo",
|
||||
"maintainers": [
|
||||
{
|
||||
"email": "bzlmod@zml.ai",
|
||||
"github": "zml",
|
||||
"name": "ZML Engineering Team"
|
||||
}
|
||||
],
|
||||
"repository": [
|
||||
"github:openxla/stablehlo"
|
||||
],
|
||||
"versions": [
|
||||
"20240829.0-54aa1a5"
|
||||
],
|
||||
"yanked_versions": {}
|
||||
}
|
||||
34
third_party/modules/xla/20240902.0-d18cd64/MODULE.bazel
vendored
Normal file
34
third_party/modules/xla/20240902.0-d18cd64/MODULE.bazel
vendored
Normal file
@ -0,0 +1,34 @@
|
||||
module(
|
||||
name = "xla",
|
||||
version = "20240902.0-d18cd64",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "platforms", version = "0.0.8")
|
||||
bazel_dep(name = "bazel_skylib", version = "1.5.0")
|
||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||
bazel_dep(name = "rules_apple", version = "3.2.1", repo_name = "build_bazel_rules_apple")
|
||||
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
|
||||
bazel_dep(name = "rules_python", version = "0.29.0")
|
||||
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||
bazel_dep(name = "rules_java", version = "7.3.2")
|
||||
bazel_dep(name = "rules_pkg", version = "0.9.1")
|
||||
bazel_dep(name = "zlib", version = "1.2.13")
|
||||
bazel_dep(name = "re2", version = "2024-02-01", repo_name = "com_googlesource_code_re2")
|
||||
bazel_dep(name = "rules_license", version = "0.0.8")
|
||||
|
||||
bazel_dep(name = "stablehlo", version = "20240829.0-54aa1a5")
|
||||
|
||||
tsl = use_extension("//:tsl.bzl", "tsl")
|
||||
use_repo(tsl, "tsl")
|
||||
|
||||
xla_workspace = use_extension("//:workspace.bzl", "xla_workspace")
|
||||
use_repo(
|
||||
xla_workspace,
|
||||
"com_github_grpc_grpc",
|
||||
"com_google_protobuf",
|
||||
"local_config_cuda",
|
||||
"local_config_remote_execution",
|
||||
"local_config_rocm",
|
||||
"local_config_tensorrt",
|
||||
)
|
||||
34
third_party/modules/xla/20240902.0-d18cd64/overlay/MODULE.bazel
vendored
Normal file
34
third_party/modules/xla/20240902.0-d18cd64/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,34 @@
|
||||
module(
|
||||
name = "xla",
|
||||
version = "20240902.0-d18cd64",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "platforms", version = "0.0.8")
|
||||
bazel_dep(name = "bazel_skylib", version = "1.5.0")
|
||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||
bazel_dep(name = "rules_apple", version = "3.2.1", repo_name = "build_bazel_rules_apple")
|
||||
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
|
||||
bazel_dep(name = "rules_python", version = "0.29.0")
|
||||
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||
bazel_dep(name = "rules_java", version = "7.3.2")
|
||||
bazel_dep(name = "rules_pkg", version = "0.9.1")
|
||||
bazel_dep(name = "zlib", version = "1.2.13")
|
||||
bazel_dep(name = "re2", version = "2024-02-01", repo_name = "com_googlesource_code_re2")
|
||||
bazel_dep(name = "rules_license", version = "0.0.8")
|
||||
|
||||
bazel_dep(name = "stablehlo", version = "20240829.0-54aa1a5")
|
||||
|
||||
tsl = use_extension("//:tsl.bzl", "tsl")
|
||||
use_repo(tsl, "tsl")
|
||||
|
||||
xla_workspace = use_extension("//:workspace.bzl", "xla_workspace")
|
||||
use_repo(
|
||||
xla_workspace,
|
||||
"com_github_grpc_grpc",
|
||||
"com_google_protobuf",
|
||||
"local_config_cuda",
|
||||
"local_config_remote_execution",
|
||||
"local_config_rocm",
|
||||
"local_config_tensorrt",
|
||||
)
|
||||
13
third_party/modules/xla/20240902.0-d18cd64/overlay/tsl.bzl
vendored
Normal file
13
third_party/modules/xla/20240902.0-d18cd64/overlay/tsl.bzl
vendored
Normal file
@ -0,0 +1,13 @@
|
||||
load("//third_party:repo.bzl", "tf_vendored")
|
||||
|
||||
def _tsl_impl(mctx):
|
||||
tf_vendored(name = "tsl", relpath = "third_party/tsl")
|
||||
return mctx.extension_metadata(
|
||||
reproducible = True,
|
||||
root_module_direct_deps = "all",
|
||||
root_module_direct_dev_deps = [],
|
||||
)
|
||||
|
||||
tsl = module_extension(
|
||||
implementation = _tsl_impl,
|
||||
)
|
||||
52
third_party/modules/xla/20240902.0-d18cd64/overlay/workspace.bzl
vendored
Normal file
52
third_party/modules/xla/20240902.0-d18cd64/overlay/workspace.bzl
vendored
Normal file
@ -0,0 +1,52 @@
|
||||
load("@tsl//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
|
||||
load("@tsl//third_party/gpus:cuda_configure.bzl", "cuda_configure")
|
||||
load("@tsl//third_party/gpus:rocm_configure.bzl", "rocm_configure")
|
||||
load("@tsl//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
|
||||
load("@tsl//tools/toolchains/remote:configure.bzl", "remote_execution_configure")
|
||||
|
||||
def _xla_workspace_impl(mctx):
|
||||
cuda_configure(name = "local_config_cuda")
|
||||
remote_execution_configure(name = "local_config_remote_execution")
|
||||
rocm_configure(name = "local_config_rocm")
|
||||
tensorrt_configure(name = "local_config_tensorrt")
|
||||
tf_http_archive(
|
||||
name = "com_github_grpc_grpc",
|
||||
sha256 = "b956598d8cbe168b5ee717b5dafa56563eb5201a947856a6688bbeac9cac4e1f",
|
||||
strip_prefix = "grpc-b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd",
|
||||
system_build_file = "@tsl//third_party/systemlibs:grpc.BUILD",
|
||||
patch_file = [
|
||||
"@tsl//third_party/grpc:generate_cc_env_fix.patch",
|
||||
"@tsl//third_party/grpc:register_go_toolchain.patch",
|
||||
],
|
||||
system_link_files = {
|
||||
"@tsl//third_party/systemlibs:BUILD": "bazel/BUILD",
|
||||
"@tsl//third_party/systemlibs:grpc.BUILD": "src/compiler/BUILD",
|
||||
"@tsl//third_party/systemlibs:grpc.bazel.grpc_deps.bzl": "bazel/grpc_deps.bzl",
|
||||
"@tsl//third_party/systemlibs:grpc.bazel.grpc_extra_deps.bzl": "bazel/grpc_extra_deps.bzl",
|
||||
"@tsl//third_party/systemlibs:grpc.bazel.cc_grpc_library.bzl": "bazel/cc_grpc_library.bzl",
|
||||
"@tsl//third_party/systemlibs:grpc.bazel.generate_cc.bzl": "bazel/generate_cc.bzl",
|
||||
"@tsl//third_party/systemlibs:grpc.bazel.protobuf.bzl": "bazel/protobuf.bzl",
|
||||
},
|
||||
urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz"),
|
||||
)
|
||||
tf_http_archive(
|
||||
name = "com_google_protobuf",
|
||||
patch_file = ["@tsl//third_party/protobuf:protobuf.patch"],
|
||||
sha256 = "f66073dee0bc159157b0bd7f502d7d1ee0bc76b3c1eac9836927511bdc4b3fc1",
|
||||
strip_prefix = "protobuf-3.21.9",
|
||||
system_build_file = "@tsl//third_party/systemlibs:protobuf.BUILD",
|
||||
system_link_files = {
|
||||
"@tsl//third_party/systemlibs:protobuf.bzl": "protobuf.bzl",
|
||||
"@tsl//third_party/systemlibs:protobuf_deps.bzl": "protobuf_deps.bzl",
|
||||
},
|
||||
urls = tf_mirror_urls("https://github.com/protocolbuffers/protobuf/archive/v3.21.9.zip"),
|
||||
)
|
||||
return mctx.extension_metadata(
|
||||
reproducible = True,
|
||||
root_module_direct_deps = "all",
|
||||
root_module_direct_dev_deps = [],
|
||||
)
|
||||
|
||||
xla_workspace = module_extension(
|
||||
implementation = _xla_workspace_impl,
|
||||
)
|
||||
@ -0,0 +1,27 @@
|
||||
From 4db5de34f70d991fedbe28915c8239b97ba7a064 Mon Sep 17 00:00:00 2001
|
||||
From: Steeve Morin <steeve.morin@gmail.com>
|
||||
Date: Mon, 18 Mar 2024 17:17:34 +0100
|
||||
Subject: [PATCH 3/3] [PJRT C API] Ensure C compliance for Profiler Extension
|
||||
|
||||
---
|
||||
xla/pjrt/c/pjrt_c_api_profiler_extension.h | 2 ++
|
||||
1 file changed, 2 insertions(+)
|
||||
|
||||
diff --git a/xla/pjrt/c/pjrt_c_api_profiler_extension.h b/xla/pjrt/c/pjrt_c_api_profiler_extension.h
|
||||
index c821916ad..89a596123 100644
|
||||
--- a/xla/pjrt/c/pjrt_c_api_profiler_extension.h
|
||||
+++ b/xla/pjrt/c/pjrt_c_api_profiler_extension.h
|
||||
@@ -16,8 +16,10 @@ limitations under the License.
|
||||
#ifndef XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_
|
||||
#define XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_
|
||||
|
||||
+#ifdef __cplusplus
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
+#endif
|
||||
|
||||
#include "xla/backends/profiler/plugin/profiler_c_api.h"
|
||||
#include "xla/pjrt/c/pjrt_c_api.h"
|
||||
--
|
||||
2.39.3 (Apple Git-146)
|
||||
|
||||
14
third_party/modules/xla/20240902.0-d18cd64/source.json
vendored
Normal file
14
third_party/modules/xla/20240902.0-d18cd64/source.json
vendored
Normal file
@ -0,0 +1,14 @@
|
||||
{
|
||||
"strip_prefix": "xla-d18cd64b7cd61a2ade10089665ac104f639101b1",
|
||||
"url": "https://github.com/openxla/xla/archive/d18cd64b7cd61a2ade10089665ac104f639101b1.tar.gz",
|
||||
"integrity": "sha256-EtKhjU91STBceUmg0TUE6cPeRkeSz3LI2S1i3EFMj/E=",
|
||||
"overlay": {
|
||||
"tsl.bzl": "",
|
||||
"workspace.bzl": "",
|
||||
"MODULE.bazel": ""
|
||||
},
|
||||
"patch_strip": 1,
|
||||
"patches": {
|
||||
"0003-PJRT-C-API-Ensure-C-compliance-for-Profiler-Extensio.patch": ""
|
||||
}
|
||||
}
|
||||
17
third_party/modules/xla/metadata.json
vendored
Normal file
17
third_party/modules/xla/metadata.json
vendored
Normal file
@ -0,0 +1,17 @@
|
||||
{
|
||||
"homepage": "https://github.com/openxla/xla",
|
||||
"maintainers": [
|
||||
{
|
||||
"email": "bzlmod@zml.ai",
|
||||
"github": "zml",
|
||||
"name": "ZML Engineering Team"
|
||||
}
|
||||
],
|
||||
"repository": [
|
||||
"github:openxla/xla"
|
||||
],
|
||||
"versions": [
|
||||
"20240902.0-d18cd64"
|
||||
],
|
||||
"yanked_versions": {}
|
||||
}
|
||||
8
third_party/modules/zig-protobuf/20240722.0-c644d11/MODULE.bazel
vendored
Normal file
8
third_party/modules/zig-protobuf/20240722.0-c644d11/MODULE.bazel
vendored
Normal file
@ -0,0 +1,8 @@
|
||||
module(
|
||||
name = "zig-protobuf",
|
||||
version = "20240722.0-c644d11",
|
||||
compatibility_level = 1,
|
||||
)
|
||||
|
||||
bazel_dep(name = "rules_zig", version = "20240904.0-010da15")
|
||||
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||
32
third_party/modules/zig-protobuf/20240722.0-c644d11/overlay/BUILD.bazel
vendored
Normal file
32
third_party/modules/zig-protobuf/20240722.0-c644d11/overlay/BUILD.bazel
vendored
Normal file
@ -0,0 +1,32 @@
|
||||
load("@rules_proto//proto:defs.bzl", "proto_lang_toolchain")
|
||||
load("@rules_zig//zig:defs.bzl", "BINARY_KIND", "zig_binary", "zig_library")
|
||||
|
||||
zig_library(
|
||||
name = "protobuf",
|
||||
import_name = "protobuf",
|
||||
main = "src/protobuf.zig",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
zig_binary(
|
||||
name = "generator",
|
||||
srcs = [
|
||||
"bootstrapped-generator/FullName.zig",
|
||||
"bootstrapped-generator/google/protobuf/compiler/plugin.pb.zig",
|
||||
"bootstrapped-generator/google/protobuf/descriptor.pb.zig",
|
||||
],
|
||||
kind = BINARY_KIND.exe,
|
||||
main = "bootstrapped-generator/main.zig",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":protobuf"],
|
||||
)
|
||||
|
||||
proto_lang_toolchain(
|
||||
name = "zig_toolchain",
|
||||
command_line = "--zig_out=$(OUT)",
|
||||
output_files = "multiple",
|
||||
plugin = ":generator",
|
||||
plugin_format_flag = "--plugin=protoc-gen-zig=%s",
|
||||
runtime = ":protobuf",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user