Add Zig 0.15 compatibility: update BUILD files, async primitives, stdx utilities, MLIR dialects, and PJRT FFI.
This commit is contained in:
parent
e3b7705e3d
commit
488a844a0f
@ -11,6 +11,7 @@ zls_completion(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
"//async",
|
"//async",
|
||||||
|
"//examples/llama",
|
||||||
"//stdx",
|
"//stdx",
|
||||||
"//zml",
|
"//zml",
|
||||||
],
|
],
|
||||||
|
|||||||
@ -18,7 +18,7 @@ bazel_dep(name = "rules_oci", version = "2.2.6")
|
|||||||
bazel_dep(name = "rules_proto", version = "7.1.0")
|
bazel_dep(name = "rules_proto", version = "7.1.0")
|
||||||
bazel_dep(name = "rules_python", version = "1.5.3")
|
bazel_dep(name = "rules_python", version = "1.5.3")
|
||||||
bazel_dep(name = "rules_rust", version = "0.63.0")
|
bazel_dep(name = "rules_rust", version = "0.63.0")
|
||||||
bazel_dep(name = "rules_zig", version = "20250821.0-be53625")
|
bazel_dep(name = "rules_zig", version = "20250827.0-35b6d57")
|
||||||
bazel_dep(name = "toolchains_llvm_bootstrapped", version = "0.2.4")
|
bazel_dep(name = "toolchains_llvm_bootstrapped", version = "0.2.4")
|
||||||
bazel_dep(name = "with_cfg.bzl", version = "0.11.0")
|
bazel_dep(name = "with_cfg.bzl", version = "0.11.0")
|
||||||
|
|
||||||
@ -26,7 +26,7 @@ bazel_dep(name = "buildifier_prebuilt", version = "8.2.0.2", dev_dependency = Tr
|
|||||||
|
|
||||||
zig = use_extension("@rules_zig//zig:extensions.bzl", "zig")
|
zig = use_extension("@rules_zig//zig:extensions.bzl", "zig")
|
||||||
zig.index(file = "//bazel:zig_index.json")
|
zig.index(file = "//bazel:zig_index.json")
|
||||||
zig.toolchain(zig_version = "0.14.1")
|
zig.toolchain(zig_version = "0.15.1")
|
||||||
zig.mirrors(urls = [
|
zig.mirrors(urls = [
|
||||||
"https://mirror.zml.ai/zig",
|
"https://mirror.zml.ai/zig",
|
||||||
"https://ziglang.org/builds/",
|
"https://ziglang.org/builds/",
|
||||||
|
|||||||
208
async/async.zig
208
async/async.zig
@ -228,36 +228,88 @@ pub const AsyncThread = struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
pub fn getStdIn() !File {
|
pub fn getStdIn() !File {
|
||||||
return File.init(std.io.getStdIn()) catch @panic("Unable to open stdin");
|
return File.initStreaming(std.fs.File.stdin()) catch @panic("Unable to open stdin");
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getStdOut() File {
|
pub fn getStdOut() File {
|
||||||
return File.init(std.io.getStdOut()) catch @panic("Unable to open stdout");
|
return File.initStreaming(std.fs.File.stdout()) catch @panic("Unable to open stdout");
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getStdErr() File {
|
pub fn getStdErr() File {
|
||||||
return File.init(std.io.getStdErr()) catch @panic("Unable to open stderr");
|
return File.initStreaming(std.fs.File.stderr()) catch @panic("Unable to open stderr");
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const File = struct {
|
pub const File = struct {
|
||||||
|
/// Adapter exposing an async `File` through the `std.Io.Reader` interface.
///
/// `stream` recovers the adapter from the interface pointer with
/// `@fieldParentPtr`, so `interface` must be used in place (through
/// `&reader.interface`) and never copied out of this struct.
pub const Reader = struct {
    interface: std.Io.Reader,
    file: File,
    /// `.streaming` reads with `File.read`, `.positional` reads with
    /// `File.pread` at `pos`. Any other mode panics in `stream`.
    mode: std.fs.File.Reader.Mode = .positional,
    /// Total number of bytes read so far; the offset used for positional reads.
    pos: u64 = 0,

    /// `std.Io.Reader.VTable.stream` implementation: performs one read from
    /// the file directly into the free space of `w`, bounded by `limit`.
    ///
    /// Returns the number of bytes transferred, `error.ReadFailed` when the
    /// underlying read fails, and `error.EndOfStream` when a read into a
    /// non-empty destination yields zero bytes.
    fn stream(r: *std.Io.Reader, w: *std.Io.Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize {
        const self: *Reader = @alignCast(@fieldParentPtr("interface", r));
        const dest = limit.slice(try w.writableSliceGreedy(1));
        // Nothing can be read into an empty destination; report zero bytes
        // instead of letting the `n == 0` check below signal end of stream.
        if (dest.len == 0) return 0;

        const n = switch (self.mode) {
            .streaming => self.file.read(dest),
            .positional => self.file.pread(dest, self.pos),
            else => @panic("UNSUPPORTED"),
        } catch {
            return std.Io.Reader.StreamError.ReadFailed;
        };
        if (n == 0) {
            return std.Io.Reader.StreamError.EndOfStream;
        }
        self.pos += n;
        w.advance(n);
        return n;
    }
};
|
||||||
|
|
||||||
|
/// Adapter exposing an async `File` through the `std.Io.Writer` interface.
///
/// `drain` recovers the adapter from the interface pointer with
/// `@fieldParentPtr`, so `interface` must be used in place (through
/// `&writer.interface`) and never copied out of this struct.
pub const Writer = struct {
    interface: std.Io.Writer,
    file: File,
    /// `.streaming` writes with `File.write`, `.positional` writes with
    /// `File.pwrite` at `pos`. Other modes are not expected here.
    mode: std.fs.File.Writer.Mode = .positional,
    /// Total number of bytes written so far; the offset used for positional writes.
    pos: u64 = 0,

    /// Performs a single write of `buf` according to `mode` and advances
    /// `pos` by the number of bytes accepted. Any underlying failure is
    /// reported as `error.WriteFailed`.
    fn write(self: *Writer, buf: []const u8) std.Io.Writer.Error!usize {
        const n = switch (self.mode) {
            .streaming => self.file.write(buf),
            .positional => self.file.pwrite(buf, self.pos),
            else => unreachable,
        } catch {
            return std.Io.Writer.Error.WriteFailed;
        };
        self.pos += n;
        return n;
    }

    /// `std.Io.Writer.VTable.drain` implementation.
    ///
    /// Buffered bytes are sent first. Then every slice of `data` is sent in
    /// order, with the last slice repeated `splat` times (possibly zero).
    /// Returns the number of bytes consumed from `data` only; buffered bytes
    /// are accounted for through `w.consume`. Stops at the first short write
    /// so the caller can retry with the remainder.
    fn drain(w: *std.Io.Writer, data: []const []const u8, splat: usize) std.Io.Writer.Error!usize {
        const self: *Writer = @alignCast(@fieldParentPtr("interface", w));

        const pending = w.buffered();
        if (pending.len > 0) {
            const flushed = try self.write(pending);
            _ = w.consume(flushed);
            // Bytes still buffered must reach the file before anything from
            // `data`, otherwise output would be reordered.
            if (flushed < pending.len) return 0;
        }

        if (data.len == 0) return 0;
        const last = data.len - 1;

        var total: usize = 0;
        for (data[0..last]) |d| {
            // Skip empty slices rather than issuing zero-length writes.
            if (d.len == 0) continue;
            const n = try self.write(d);
            total += n;
            if (n < d.len) return total;
        }

        const pattern = data[last];
        if (pattern.len == 0) return total;
        for (0..splat) |_| {
            const n = try self.write(pattern);
            total += n;
            if (n < pattern.len) return total;
        }
        return total;
    }
};
|
||||||
|
|
||||||
pub const SeekError = stdx.meta.FnSignature(File.seekTo, null).ReturnErrorSet.? || stdx.meta.FnSignature(File.seekBy, null).ReturnErrorSet.?;
|
pub const SeekError = stdx.meta.FnSignature(File.seekTo, null).ReturnErrorSet.? || stdx.meta.FnSignature(File.seekBy, null).ReturnErrorSet.?;
|
||||||
pub const GetSeekPosError = SeekError || stdx.meta.FnSignature(File.stat, null).ReturnErrorSet.?;
|
pub const GetSeekPosError = SeekError || stdx.meta.FnSignature(File.stat, null).ReturnErrorSet.?;
|
||||||
pub const Reader = std.io.GenericReader(File, stdx.meta.FnSignature(File.read, null).ReturnErrorSet.?, File.read);
|
|
||||||
pub const Writer = std.io.GenericWriter(File, stdx.meta.FnSignature(File.write, null).ReturnErrorSet.?, File.write);
|
|
||||||
pub const SeekableStream = std.io.SeekableStream(
|
|
||||||
File,
|
|
||||||
SeekError,
|
|
||||||
GetSeekPosError,
|
|
||||||
seekTo,
|
|
||||||
seekBy,
|
|
||||||
getPos,
|
|
||||||
getEndPos,
|
|
||||||
);
|
|
||||||
|
|
||||||
_handle: std.fs.File.Handle,
|
_handle: std.fs.File.Handle,
|
||||||
inner: aio.File,
|
inner: aio.File,
|
||||||
|
is_streaming: bool = false,
|
||||||
|
|
||||||
fn asFile(self: File) std.fs.File {
|
pub fn asFile(self: File) std.fs.File {
|
||||||
return .{ .handle = self._handle };
|
return .{ .handle = self._handle };
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -272,6 +324,14 @@ pub const File = struct {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Wraps an already-open `std.fs.File` as an async `File` flagged as
/// streaming, so readers and writers created from it use sequential
/// `read`/`write` instead of positional I/O.
///
/// Fails with whatever error `xev.File.init` reports.
pub fn initStreaming(file_: std.fs.File) !File {
    const xev_file = try xev.File.init(file_);
    const executor = AsyncThread.current.executor;
    return .{
        ._handle = file_.handle,
        .inner = aio.File.init(executor, xev_file),
        .is_streaming = true,
    };
}
|
||||||
|
|
||||||
pub fn open(path: []const u8, flags: std.fs.File.OpenFlags) !File {
|
pub fn open(path: []const u8, flags: std.fs.File.OpenFlags) !File {
|
||||||
return init(try callBlocking(std.fs.Dir.openFile, .{ std.fs.cwd(), path, flags }));
|
return init(try callBlocking(std.fs.Dir.openFile, .{ std.fs.cwd(), path, flags }));
|
||||||
}
|
}
|
||||||
@ -280,6 +340,21 @@ pub const File = struct {
|
|||||||
return try callBlocking(std.fs.Dir.access, .{ std.fs.cwd(), path, flags });
|
return try callBlocking(std.fs.Dir.access, .{ std.fs.cwd(), path, flags });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Builds a `Reader` over this file whose `std.Io.Reader` interface uses
/// `buffer` as its backing storage.
///
/// The reader is put in `.streaming` mode when the file was flagged as
/// streaming and in `.positional` mode otherwise.
pub fn reader(self: File, buffer: []u8) Reader {
    const mode: std.fs.File.Reader.Mode = if (self.is_streaming) .streaming else .positional;
    return .{
        .file = self,
        .mode = mode,
        .interface = .{
            .vtable = &.{ .stream = Reader.stream },
            .buffer = buffer,
            .seek = 0,
            .end = 0,
        },
    };
}
|
||||||
|
|
||||||
pub fn read(self: File, buf: []u8) !usize {
|
pub fn read(self: File, buf: []u8) !usize {
|
||||||
// NOTE(Corentin): Early return is required to avoid error with xev on Linux with io_uring backend.
|
// NOTE(Corentin): Early return is required to avoid error with xev on Linux with io_uring backend.
|
||||||
if (buf.len == 0) {
|
if (buf.len == 0) {
|
||||||
@ -310,6 +385,19 @@ pub const File = struct {
|
|||||||
return self.inner.write(.{ .slice = buf });
|
return self.inner.write(.{ .slice = buf });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Builds a `Writer` over this file whose `std.Io.Writer` interface uses
/// `buffer` as its backing storage.
///
/// The writer is put in `.streaming` mode when the file was flagged as
/// streaming and in `.positional` mode otherwise.
pub fn writer(self: File, buffer: []u8) Writer {
    const mode: std.fs.File.Writer.Mode = if (self.is_streaming) .streaming else .positional;
    return .{
        .file = self,
        .mode = mode,
        .interface = .{
            .vtable = &.{ .drain = Writer.drain },
            .buffer = buffer,
        },
    };
}
|
||||||
|
|
||||||
pub fn pwrite(self: File, buf: []const u8, offset: u64) !usize {
|
pub fn pwrite(self: File, buf: []const u8, offset: u64) !usize {
|
||||||
return self.inner.pwrite(.{ .slice = buf }, offset);
|
return self.inner.pwrite(.{ .slice = buf }, offset);
|
||||||
}
|
}
|
||||||
@ -318,18 +406,6 @@ pub const File = struct {
|
|||||||
return self.inner.close();
|
return self.inner.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn reader(self: File) Reader {
|
|
||||||
return .{ .context = self };
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn seekableStream(file: File) SeekableStream {
|
|
||||||
return .{ .context = file };
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn writer(self: File) Writer {
|
|
||||||
return .{ .context = self };
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn stat(self: File) !std.fs.File.Stat {
|
pub fn stat(self: File) !std.fs.File.Stat {
|
||||||
return try callBlocking(std.fs.File.stat, .{self.asFile()});
|
return try callBlocking(std.fs.File.stat, .{self.asFile()});
|
||||||
}
|
}
|
||||||
@ -375,8 +451,47 @@ pub const Socket = struct {
|
|||||||
pub const TCP = struct {
|
pub const TCP = struct {
|
||||||
const Inner = aio.TCP;
|
const Inner = aio.TCP;
|
||||||
|
|
||||||
pub const Reader = std.io.GenericReader(TCP, stdx.meta.FnSignature(TCP.read, null).ReturnErrorSet.?, TCP.read);
|
pub const Reader = struct {
|
||||||
pub const Writer = std.io.GenericWriter(TCP, stdx.meta.FnSignature(TCP.write, null).ReturnErrorSet.?, TCP.write);
|
interface: std.Io.Reader,
|
||||||
|
socket: TCP,
|
||||||
|
|
||||||
|
/// `std.Io.Reader.VTable.stream` implementation: performs one socket read
/// directly into the free space of `w`, bounded by `limit`.
///
/// Returns the number of bytes transferred, `error.ReadFailed` when the
/// socket read fails, and `error.EndOfStream` when a read into a non-empty
/// destination yields zero bytes.
fn stream(r: *std.Io.Reader, w: *std.Io.Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize {
    const self: *Reader = @alignCast(@fieldParentPtr("interface", r));
    const dest = limit.slice(try w.writableSliceGreedy(1));
    // Nothing can be read into an empty destination; report zero bytes
    // instead of treating it as end of stream below.
    if (dest.len == 0) return 0;

    const n = self.socket.read(dest) catch {
        return std.Io.Reader.StreamError.ReadFailed;
    };
    // A zero-byte read is taken to mean there is nothing more to read; a
    // plain `0` return does not signal end of stream to `std.Io.Reader`
    // consumers, which would keep polling.
    if (n == 0) {
        return std.Io.Reader.StreamError.EndOfStream;
    }
    w.advance(n);
    return n;
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Adapter exposing a `TCP` socket through the `std.Io.Writer` interface.
///
/// `drain` recovers the adapter from the interface pointer with
/// `@fieldParentPtr`, so `interface` must be used in place (through
/// `&writer.interface`) and never copied out of this struct.
pub const Writer = struct {
    interface: std.Io.Writer,
    socket: TCP,

    /// Performs a single socket write, reporting any underlying failure as
    /// `error.WriteFailed`.
    fn send(self: *Writer, bytes: []const u8) std.Io.Writer.Error!usize {
        return self.socket.write(bytes) catch return std.Io.Writer.Error.WriteFailed;
    }

    /// `std.Io.Writer.VTable.drain` implementation.
    ///
    /// Buffered bytes are sent first. Then every slice of `data` is sent in
    /// order, with the last slice repeated `splat` times (possibly zero).
    /// Returns the number of bytes consumed from `data` only; buffered bytes
    /// are accounted for through `w.consume`. Stops at the first short write
    /// so the caller can retry with the remainder.
    fn drain(w: *std.Io.Writer, data: []const []const u8, splat: usize) std.Io.Writer.Error!usize {
        const self: *Writer = @alignCast(@fieldParentPtr("interface", w));

        const pending = w.buffered();
        // Only touch the socket when there is something buffered.
        if (pending.len > 0) {
            const flushed = try self.send(pending);
            _ = w.consume(flushed);
            // Bytes still buffered must be sent before anything from `data`,
            // otherwise the stream would be reordered.
            if (flushed < pending.len) return 0;
        }

        if (data.len == 0) return 0;
        const last = data.len - 1;

        var total: usize = 0;
        for (data[0..last]) |d| {
            // Skip empty slices rather than issuing zero-length writes.
            if (d.len == 0) continue;
            const n = try self.send(d);
            total += n;
            if (n < d.len) return total;
        }

        const pattern = data[last];
        if (pattern.len == 0) return total;
        for (0..splat) |_| {
            const n = try self.send(pattern);
            total += n;
            if (n < pattern.len) return total;
        }
        return total;
    }
};
|
||||||
|
|
||||||
inner: aio.TCP,
|
inner: aio.TCP,
|
||||||
|
|
||||||
@ -418,12 +533,30 @@ pub const Socket = struct {
|
|||||||
return self.inner.close();
|
return self.inner.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn reader(self: TCP) Reader {
|
pub fn reader(self: TCP, buffer: []u8) Reader {
|
||||||
return .{ .context = self };
|
return .{
|
||||||
|
.socket = self,
|
||||||
|
.interface = .{
|
||||||
|
.vtable = &.{
|
||||||
|
.stream = Reader.stream,
|
||||||
|
},
|
||||||
|
.buffer = buffer,
|
||||||
|
.seek = 0,
|
||||||
|
.end = 0,
|
||||||
|
},
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn writer(self: TCP) Writer {
|
pub fn writer(self: TCP, buffer: []u8) Writer {
|
||||||
return .{ .context = self };
|
return .{
|
||||||
|
.socket = self,
|
||||||
|
.interface = .{
|
||||||
|
.vtable = &.{
|
||||||
|
.drain = Writer.drain,
|
||||||
|
},
|
||||||
|
.buffer = buffer,
|
||||||
|
},
|
||||||
|
};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -564,9 +697,8 @@ pub fn logFn(comptime fallbackLogFn: LogFn) LogFn {
|
|||||||
|
|
||||||
const level_txt = comptime message_level.asText();
|
const level_txt = comptime message_level.asText();
|
||||||
const prefix2 = if (scope == .default) ": " else "(" ++ @tagName(scope) ++ "): ";
|
const prefix2 = if (scope == .default) ": " else "(" ++ @tagName(scope) ++ "): ";
|
||||||
const stderr = getStdErr().writer();
|
var buffer: [1024]u8 = undefined;
|
||||||
var bw = std.io.bufferedWriter(stderr);
|
var stderr = getStdErr().writer(&buffer);
|
||||||
const writer = bw.writer();
|
|
||||||
|
|
||||||
var mutex = Self.mu orelse blk: {
|
var mutex = Self.mu orelse blk: {
|
||||||
Self.mu = Mutex.init();
|
Self.mu = Mutex.init();
|
||||||
@ -575,8 +707,8 @@ pub fn logFn(comptime fallbackLogFn: LogFn) LogFn {
|
|||||||
mutex.lock();
|
mutex.lock();
|
||||||
defer mutex.unlock();
|
defer mutex.unlock();
|
||||||
nosuspend {
|
nosuspend {
|
||||||
writer.print(level_txt ++ prefix2 ++ format ++ "\n", args) catch return;
|
stderr.interface.print(level_txt ++ prefix2 ++ format ++ "\n", args) catch unreachable;
|
||||||
bw.flush() catch return;
|
stderr.interface.flush() catch unreachable;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}.call;
|
}.call;
|
||||||
|
|||||||
@ -156,7 +156,7 @@ const Coro = struct {
|
|||||||
return self;
|
return self;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn runcoro(from: *base.Coro, this: *base.Coro) callconv(.C) noreturn {
|
fn runcoro(from: *base.Coro, this: *base.Coro) callconv(.c) noreturn {
|
||||||
const from_coro: *Coro = @fieldParentPtr("impl", from);
|
const from_coro: *Coro = @fieldParentPtr("impl", from);
|
||||||
const this_coro: *Coro = @fieldParentPtr("impl", this);
|
const this_coro: *Coro = @fieldParentPtr("impl", this);
|
||||||
log(.debug, "coro start {any}", .{this_coro.id});
|
log(.debug, "coro start {any}", .{this_coro.id});
|
||||||
|
|||||||
@ -49,7 +49,7 @@ pub const Coro = packed struct {
|
|||||||
const Func = *const fn (
|
const Func = *const fn (
|
||||||
from: *Coro,
|
from: *Coro,
|
||||||
self: *Coro,
|
self: *Coro,
|
||||||
) callconv(.C) noreturn;
|
) callconv(.c) noreturn;
|
||||||
|
|
||||||
pub fn init(func: Func, stack: []align(stack_alignment) u8) !Self {
|
pub fn init(func: Func, stack: []align(stack_alignment) u8) !Self {
|
||||||
stdx.debug.assertComptime(@sizeOf(usize) == 8, "usize expected to take 8 bytes", .{});
|
stdx.debug.assertComptime(@sizeOf(usize) == 8, "usize expected to take 8 bytes", .{});
|
||||||
|
|||||||
@ -1,83 +1,264 @@
|
|||||||
{
|
{
|
||||||
"master": {
|
"master": {
|
||||||
"version": "0.15.0-dev.650+4f3b59f70",
|
"version": "0.16.0-dev.27+83f773fc6",
|
||||||
"date": "2025-05-29",
|
"date": "2025-08-24",
|
||||||
"docs": "https://ziglang.org/documentation/master/",
|
"docs": "https://ziglang.org/documentation/master/",
|
||||||
"stdDocs": "https://ziglang.org/documentation/master/std/",
|
"stdDocs": "https://ziglang.org/documentation/master/std/",
|
||||||
"src": {
|
"src": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-0.15.0-dev.650+4f3b59f70.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
"shasum": "c14764ee9fd16f4437f2e2e7092cc7ec7ff76469f719be04155c29fa3bcc52cd",
|
"shasum": "afddafbede9becaa6d94a544d3115893cc6aa6492a58545b4f638c58990586dc",
|
||||||
"size": "21279148"
|
"size": "21370308"
|
||||||
},
|
},
|
||||||
"bootstrap": {
|
"bootstrap": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-bootstrap-0.15.0-dev.650+4f3b59f70.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-bootstrap-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
"shasum": "142343992733282138b88b2205f29173ce3febcf5d73de4a19111fb36630f5a6",
|
"shasum": "109cdda833baf951ce8116d3fdfd50b3770b02cff5c057f777601ca694f44b7c",
|
||||||
"size": "52649088"
|
"size": "52732704"
|
||||||
},
|
},
|
||||||
"x86_64-macos": {
|
"x86_64-macos": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-x86_64-macos-0.15.0-dev.650+4f3b59f70.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-x86_64-macos-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
"shasum": "70117a96313ddfe57bd0b3fbbbdb2beffc27fb81f5993ba9b2825628bb9f5aaa",
|
"shasum": "1ac88b947a6001eb30d41c15c404847209dfe945a27dbce6d1b7d401e4aef325",
|
||||||
"size": "55795264"
|
"size": "55817716"
|
||||||
},
|
},
|
||||||
"aarch64-macos": {
|
"aarch64-macos": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-aarch64-macos-0.15.0-dev.650+4f3b59f70.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-aarch64-macos-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
"shasum": "991ef1b1871852f87e1d53f65c3dfb3b3f7187a0dc1f8fe6e7be279227b868d3",
|
"shasum": "75243ad6d2e9fcd634862dda9003883484024fd2e067fc7e4a0f6c524cb6c86e",
|
||||||
"size": "50625312"
|
"size": "50659200"
|
||||||
},
|
},
|
||||||
"x86_64-linux": {
|
"x86_64-linux": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-x86_64-linux-0.15.0-dev.650+4f3b59f70.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-x86_64-linux-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
"shasum": "2c2f65db1ad72d415b5a5bfb0ccd437bfc16e91b18a58035d28ea3177d07b1e2",
|
"shasum": "c260aefaf5bf10bbd48101217874ccc9c0a37513729dabbbfe09bcf18fff1b5b",
|
||||||
"size": "53712376"
|
"size": "53759360"
|
||||||
},
|
},
|
||||||
"aarch64-linux": {
|
"aarch64-linux": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-aarch64-linux-0.15.0-dev.650+4f3b59f70.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-aarch64-linux-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
"shasum": "98bfd9c33b737aa66d5cc19886b772cdc99801e66d8c891d400b0370fc626455",
|
"shasum": "658241c15ad827a89878b2e35e13261cd1c2aaefe9a8bb94ccb907725cf80564",
|
||||||
"size": "49500380"
|
"size": "49485912"
|
||||||
},
|
},
|
||||||
"armv7a-linux": {
|
"arm-linux": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-armv7a-linux-0.15.0-dev.650+4f3b59f70.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-arm-linux-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
"shasum": "c204b60e27930702a64d55768fbbf2c819109dd0400168e0daf2994f4e57ed8b",
|
"shasum": "6039c11f41bd032ec3cbda2d92e38b409aa9d9e2935b21266033b820ecb1a668",
|
||||||
"size": "50413108"
|
"size": "50475500"
|
||||||
},
|
},
|
||||||
"riscv64-linux": {
|
"riscv64-linux": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-riscv64-linux-0.15.0-dev.650+4f3b59f70.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-riscv64-linux-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
"shasum": "2212613f09d114296f542dc3ca4583d3e1cfa72642e0d192a450ef1704722295",
|
"shasum": "f56712849af75bb1bfda502c8cdfc52dcc41c18d3d7eddd329d94be85538d258",
|
||||||
"size": "53645244"
|
"size": "53610616"
|
||||||
},
|
},
|
||||||
"powerpc64le-linux": {
|
"powerpc64le-linux": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-powerpc64le-linux-0.15.0-dev.650+4f3b59f70.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-powerpc64le-linux-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
"shasum": "6a68c45d8bb5a3c1c3a259ad5ef965011c02feb0d17e8c96abd76aa48af03fe9",
|
"shasum": "0562fdd578b5ae65a17e614454e67466314790a1d373279a4b054398f6a7b364",
|
||||||
"size": "53559816"
|
"size": "53585484"
|
||||||
},
|
},
|
||||||
"x86-linux": {
|
"x86-linux": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-x86-linux-0.15.0-dev.650+4f3b59f70.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-x86-linux-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
"shasum": "802e860891aa12979c3279e6d0571d08c0ed0398f362ac3f2a2fd945ef59bd1c",
|
"shasum": "6eafb1b1d81118066dc03b17bed6612d7a9da057794f1d0f2451c249cf8a231d",
|
||||||
"size": "56328204"
|
"size": "56334072"
|
||||||
},
|
},
|
||||||
"loongarch64-linux": {
|
"loongarch64-linux": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-loongarch64-linux-0.15.0-dev.650+4f3b59f70.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-loongarch64-linux-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
"shasum": "7fb8e715cdf44a2158afced5eecb3d2ae0f44d0a161f43ea8dfb04e95b22c89f",
|
"shasum": "4fddf62a6698dbece51f504b4de8360a55d24efb0c7a56117eef64c8b717d13f",
|
||||||
"size": "50835068"
|
"size": "50823012"
|
||||||
},
|
},
|
||||||
"s390x-linux": {
|
"s390x-linux": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-s390x-linux-0.15.0-dev.650+4f3b59f70.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-s390x-linux-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
"shasum": "cbbcbf324db5158232d1fd5a1efcc2d694e058ad66772495e291828fa5227450",
|
"shasum": "12ed086066c373e4098dc9f132e42bbc478a3be6bd5fc6967f39396c43543f89",
|
||||||
"size": "53386612"
|
"size": "53525496"
|
||||||
},
|
},
|
||||||
"x86_64-windows": {
|
"x86_64-windows": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-x86_64-windows-0.15.0-dev.650+4f3b59f70.zip",
|
"tarball": "https://ziglang.org/builds/zig-x86_64-windows-0.16.0-dev.27+83f773fc6.zip",
|
||||||
"shasum": "0d4e148a60e859f0fa5ce6d293dcec7ba79a7e1da25df151b34e3ce71d0fa84f",
|
"shasum": "ae987b8d93eec8a923bba26a74d20b5d15eac7ee9694c44db82c232247810a51",
|
||||||
"size": "94297867"
|
"size": "93312425"
|
||||||
},
|
},
|
||||||
"aarch64-windows": {
|
"aarch64-windows": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-aarch64-windows-0.15.0-dev.650+4f3b59f70.zip",
|
"tarball": "https://ziglang.org/builds/zig-aarch64-windows-0.16.0-dev.27+83f773fc6.zip",
|
||||||
"shasum": "6b1a094cb4c666d0078531ce3d4f970dcf01f63e5b885d9af0f498726bdf5e43",
|
"shasum": "2e3c95d044dd36d302a301fb029ea080b2bda4ef8e42bbb1b2a0246973952c53",
|
||||||
"size": "90200802"
|
"size": "89157759"
|
||||||
},
|
},
|
||||||
"x86-windows": {
|
"x86-windows": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-x86-windows-0.15.0-dev.650+4f3b59f70.zip",
|
"tarball": "https://ziglang.org/builds/zig-x86-windows-0.16.0-dev.27+83f773fc6.zip",
|
||||||
"shasum": "f18b6f9472a27abea8edbfb43687a75b61298e6d4da42247218dc3fd23560aa7",
|
"shasum": "b696f07f7104da430c7b37750bc6e36571c3ffcd5bc49240d8a19aa4b431049f",
|
||||||
"size": "96232989"
|
"size": "95215201"
|
||||||
|
},
|
||||||
|
"aarch64-freebsd": {
|
||||||
|
"tarball": "https://ziglang.org/builds/zig-aarch64-freebsd-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
|
"shasum": "0295a3f6bebb6d25fbc916cd4e3cff1285acab75b9591695069ad9a878301203",
|
||||||
|
"size": "49384624"
|
||||||
|
},
|
||||||
|
"arm-freebsd": {
|
||||||
|
"tarball": "https://ziglang.org/builds/zig-arm-freebsd-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
|
"shasum": "a293670c7ee89ce6cf11f03dee297de2a31ffdc9d26f2c0e6e78fd66b11b2327",
|
||||||
|
"size": "50925296"
|
||||||
|
},
|
||||||
|
"powerpc64-freebsd": {
|
||||||
|
"tarball": "https://ziglang.org/builds/zig-powerpc64-freebsd-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
|
"shasum": "503922dd11374ddc44013ae3d7b07a9a17a1404da296bc30bf8fdc6d7273fd82",
|
||||||
|
"size": "52148468"
|
||||||
|
},
|
||||||
|
"powerpc64le-freebsd": {
|
||||||
|
"tarball": "https://ziglang.org/builds/zig-powerpc64le-freebsd-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
|
"shasum": "1cf44603dec137a35551916ab5e542c36e6c1ec54df4efd7783f8ae66c80acc4",
|
||||||
|
"size": "53516496"
|
||||||
|
},
|
||||||
|
"riscv64-freebsd": {
|
||||||
|
"tarball": "https://ziglang.org/builds/zig-riscv64-freebsd-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
|
"shasum": "c4b1c691328abacf7f7b07f1900dd44c2f2852a0b2611378a2531975abb6f260",
|
||||||
|
"size": "53716116"
|
||||||
|
},
|
||||||
|
"x86_64-freebsd": {
|
||||||
|
"tarball": "https://ziglang.org/builds/zig-x86_64-freebsd-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
|
"shasum": "989caf4343e1dfb4c9a68f3d3146cac96729e6f6e58276001654307f56c43981",
|
||||||
|
"size": "53808720"
|
||||||
|
},
|
||||||
|
"aarch64-netbsd": {
|
||||||
|
"tarball": "https://ziglang.org/builds/zig-aarch64-netbsd-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
|
"shasum": "bbbe2aa2163d3c25d7414795537088245f9ec3cb53f10518b5824e1a31c35ff9",
|
||||||
|
"size": "49387264"
|
||||||
|
},
|
||||||
|
"arm-netbsd": {
|
||||||
|
"tarball": "https://ziglang.org/builds/zig-arm-netbsd-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
|
"shasum": "264346f44664055d61c91087d0c55fdd04e3bdf1f2f0f969cd5efcf47d54b247",
|
||||||
|
"size": "52034816"
|
||||||
|
},
|
||||||
|
"x86-netbsd": {
|
||||||
|
"tarball": "https://ziglang.org/builds/zig-x86-netbsd-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
|
"shasum": "f5e3cdab5d7e24552eee2200b1e4ac45cd84f13d5d64ae0f558ad0210d6883d5",
|
||||||
|
"size": "56915916"
|
||||||
|
},
|
||||||
|
"x86_64-netbsd": {
|
||||||
|
"tarball": "https://ziglang.org/builds/zig-x86_64-netbsd-0.16.0-dev.27+83f773fc6.tar.xz",
|
||||||
|
"shasum": "b6418c9d5036d2328271078687dbf2d772279e65787546974943855423fc6545",
|
||||||
|
"size": "53809524"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"0.15.1": {
|
||||||
|
"date": "2025-08-19",
|
||||||
|
"docs": "https://ziglang.org/documentation/master/",
|
||||||
|
"stdDocs": "https://ziglang.org/documentation/master/std/",
|
||||||
|
"notes": "https://ziglang.org/download/0.15.1/release-notes.html",
|
||||||
|
"src": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-0.15.1.tar.xz",
|
||||||
|
"shasum": "816c0303ab313f59766ce2097658c9fff7fafd1504f61f80f9507cd11652865f",
|
||||||
|
"size": "21359884"
|
||||||
|
},
|
||||||
|
"bootstrap": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-bootstrap-0.15.1.tar.xz",
|
||||||
|
"shasum": "4c0cfbcf12da144955761ca43f89e3c74956bce978694fc1d0a63555f5c0a199",
|
||||||
|
"size": "52711548"
|
||||||
|
},
|
||||||
|
"x86_64-macos": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-x86_64-macos-0.15.1.tar.xz",
|
||||||
|
"shasum": "9919392e0287cccc106dfbcbb46c7c1c3fa05d919567bb58d7eb16bca4116184",
|
||||||
|
"size": "55791880"
|
||||||
|
},
|
||||||
|
"aarch64-macos": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-aarch64-macos-0.15.1.tar.xz",
|
||||||
|
"shasum": "c4bd624d901c1268f2deb9d8eb2d86a2f8b97bafa3f118025344242da2c54d7b",
|
||||||
|
"size": "50644996"
|
||||||
|
},
|
||||||
|
"x86_64-linux": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-x86_64-linux-0.15.1.tar.xz",
|
||||||
|
"shasum": "c61c5da6edeea14ca51ecd5e4520c6f4189ef5250383db33d01848293bfafe05",
|
||||||
|
"size": "53734456"
|
||||||
|
},
|
||||||
|
"aarch64-linux": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-aarch64-linux-0.15.1.tar.xz",
|
||||||
|
"shasum": "bb4a8d2ad735e7fba764c497ddf4243cb129fece4148da3222a7046d3f1f19fe",
|
||||||
|
"size": "49493872"
|
||||||
|
},
|
||||||
|
"arm-linux": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-arm-linux-0.15.1.tar.xz",
|
||||||
|
"shasum": "3f4bf3b06b67d14e3f38be30798488c1abe3cf5b33de570cd0e87bbf09b978ad",
|
||||||
|
"size": "50477464"
|
||||||
|
},
|
||||||
|
"riscv64-linux": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-riscv64-linux-0.15.1.tar.xz",
|
||||||
|
"shasum": "7ca7a3e621436fb31d66a253132fc39574a13d2a1b4d8458af4f2e7c6e4374fe",
|
||||||
|
"size": "53597792"
|
||||||
|
},
|
||||||
|
"powerpc64le-linux": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-powerpc64le-linux-0.15.1.tar.xz",
|
||||||
|
"shasum": "339e2106496be70b614e32d444298216e676c36d08f5cd7bee3dd1dbd4567fd7",
|
||||||
|
"size": "53566944"
|
||||||
|
},
|
||||||
|
"x86-linux": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-x86-linux-0.15.1.tar.xz",
|
||||||
|
"shasum": "dff166f25fdd06e8341d831a71211b5ba7411463a6b264bdefa8868438690b6a",
|
||||||
|
"size": "56311228"
|
||||||
|
},
|
||||||
|
"loongarch64-linux": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-loongarch64-linux-0.15.1.tar.xz",
|
||||||
|
"shasum": "0af18a012f3c4cffbef29ab5f42021484d92517f921a1380faacc89c50c1f89d",
|
||||||
|
"size": "50794376"
|
||||||
|
},
|
||||||
|
"s390x-linux": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-s390x-linux-0.15.1.tar.xz",
|
||||||
|
"shasum": "bcd13e5c88cf2d0da7100a48572195560069a94402c9bec3b316398204aa27e2",
|
||||||
|
"size": "53508068"
|
||||||
|
},
|
||||||
|
"x86_64-windows": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-x86_64-windows-0.15.1.zip",
|
||||||
|
"shasum": "91e69e887ca8c943ce9a515df3af013d95a66a190a3df3f89221277ebad29e34",
|
||||||
|
"size": "92612958"
|
||||||
|
},
|
||||||
|
"aarch64-windows": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-aarch64-windows-0.15.1.zip",
|
||||||
|
"shasum": "1f1bf16228b0ffcc882b713dc5e11a6db4219cb30997e13c72e8e723c2104ec6",
|
||||||
|
"size": "88458549"
|
||||||
|
},
|
||||||
|
"x86-windows": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-x86-windows-0.15.1.zip",
|
||||||
|
"shasum": "fb1c07cffbb43615d3158ab8b8f5db5da1d48875eca99e1d7a8a0064ff63fc5b",
|
||||||
|
"size": "94516052"
|
||||||
|
},
|
||||||
|
"aarch64-freebsd": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-aarch64-freebsd-0.15.1.tar.xz",
|
||||||
|
"shasum": "4d9d25c775828d49ea037b2284310c295d951793da8ebe94827a54fed4cca3ce",
|
||||||
|
"size": "49358464"
|
||||||
|
},
|
||||||
|
"arm-freebsd": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-arm-freebsd-0.15.1.tar.xz",
|
||||||
|
"shasum": "9707f3a5f7e1a3d99c40db9a74de1acc61016a197ad289c2ad964f93cb213a18",
|
||||||
|
"size": "50904124"
|
||||||
|
},
|
||||||
|
"powerpc64-freebsd": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-powerpc64-freebsd-0.15.1.tar.xz",
|
||||||
|
"shasum": "79448884372db04e62f77a46b92245fce805063e69534729537f75cb4681e7e3",
|
||||||
|
"size": "52099020"
|
||||||
|
},
|
||||||
|
"powerpc64le-freebsd": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-powerpc64le-freebsd-0.15.1.tar.xz",
|
||||||
|
"shasum": "f18ee12ba9c98a20b8d2ad0410c679e7aa5591bc7917f169fd6b377833d2c7ad",
|
||||||
|
"size": "53480976"
|
||||||
|
},
|
||||||
|
"riscv64-freebsd": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-riscv64-freebsd-0.15.1.tar.xz",
|
||||||
|
"shasum": "ee9f864a6fd8b57c1f4fdbb11daa06578746a6f8253afe3f5ddb5a76f2eddd2d",
|
||||||
|
"size": "53677800"
|
||||||
|
},
|
||||||
|
"x86_64-freebsd": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-x86_64-freebsd-0.15.1.tar.xz",
|
||||||
|
"shasum": "9714f8ac3d3dc908b1599837c6167f857c1efaa930f0cfa840699458de7c3cd0",
|
||||||
|
"size": "53782112"
|
||||||
|
},
|
||||||
|
"aarch64-netbsd": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-aarch64-netbsd-0.15.1.tar.xz",
|
||||||
|
"shasum": "b2a528399777583b85b89c54ccd45488af7709d6dd29a27323ec2a229db40910",
|
||||||
|
"size": "49368000"
|
||||||
|
},
|
||||||
|
"arm-netbsd": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-arm-netbsd-0.15.1.tar.xz",
|
||||||
|
"shasum": "93dc70109cbf5d2e022d20dfb56211978c4ea3c0b1e67aaabff947d8d1583aab",
|
||||||
|
"size": "52028184"
|
||||||
|
},
|
||||||
|
"x86-netbsd": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-x86-netbsd-0.15.1.tar.xz",
|
||||||
|
"shasum": "a91b26051822ff17f3143f859b87dce5b4a13e90928bd6daa6f07a895d3410f0",
|
||||||
|
"size": "56881864"
|
||||||
|
},
|
||||||
|
"x86_64-netbsd": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.15.1/zig-x86_64-netbsd-0.15.1.tar.xz",
|
||||||
|
"shasum": "6d7ba6eca5b4434351ebdb971b7303c9934514f9bb8481852251dbd5b52b03d6",
|
||||||
|
"size": "53794836"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"0.14.1": {
|
"0.14.1": {
|
||||||
|
|||||||
@ -10,13 +10,13 @@ pub const ZigAllocator = struct {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn alloc(ctx: ?*const anyopaque, elem: usize, nelems: usize, alignment: usize) callconv(.C) ?*anyopaque {
|
pub fn alloc(ctx: ?*const anyopaque, elem: usize, nelems: usize, alignment: usize) callconv(.c) ?*anyopaque {
|
||||||
const self: *const std.mem.Allocator = @ptrCast(@alignCast(ctx));
|
const self: *const std.mem.Allocator = @ptrCast(@alignCast(ctx));
|
||||||
const ret = self.rawAlloc(elem * nelems, std.math.log2_int(usize, alignment), @returnAddress()) orelse return null;
|
const ret = self.rawAlloc(elem * nelems, std.math.log2_int(usize, alignment), @returnAddress()) orelse return null;
|
||||||
return @ptrCast(ret);
|
return @ptrCast(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn free(ctx: ?*const anyopaque, ptr: ?*anyopaque, elem: usize, nelems: usize, alignment: usize) callconv(.C) void {
|
pub fn free(ctx: ?*const anyopaque, ptr: ?*anyopaque, elem: usize, nelems: usize, alignment: usize) callconv(.c) void {
|
||||||
const self: *const std.mem.Allocator = @ptrCast(@alignCast(ctx));
|
const self: *const std.mem.Allocator = @ptrCast(@alignCast(ctx));
|
||||||
const memory: [*c]u8 = @ptrCast(ptr);
|
const memory: [*c]u8 = @ptrCast(ptr);
|
||||||
const size = elem * nelems;
|
const size = elem * nelems;
|
||||||
|
|||||||
@ -15,6 +15,7 @@ zig_library(
|
|||||||
deps = [
|
deps = [
|
||||||
"//mlir",
|
"//mlir",
|
||||||
"//mlir/dialects/stablehlo",
|
"//mlir/dialects/stablehlo",
|
||||||
|
"//stdx",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
const mlir = @import("mlir");
|
const mlir = @import("mlir");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
pub fn func(
|
pub fn func(
|
||||||
ctx: mlir.Context,
|
ctx: mlir.Context,
|
||||||
@ -14,7 +15,7 @@ pub fn func(
|
|||||||
location: mlir.Location,
|
location: mlir.Location,
|
||||||
},
|
},
|
||||||
) mlir.Operation {
|
) mlir.Operation {
|
||||||
var attrs_tuple_buffer = std.BoundedArray(mlir.AttrTuple, 4){};
|
var attrs_tuple_buffer = stdx.BoundedArray(mlir.AttrTuple, 4){};
|
||||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", .string(ctx, args.sym_name) });
|
attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", .string(ctx, args.sym_name) });
|
||||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", .type_(.function(ctx, args.args, args.results)) });
|
attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", .type_(.function(ctx, args.args, args.results)) });
|
||||||
if (args.arg_attrs.len > 0) {
|
if (args.arg_attrs.len > 0) {
|
||||||
|
|||||||
@ -761,7 +761,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{};
|
var attrs: stdx.BoundedArray(mlir.AttrTuple, 32) = .{};
|
||||||
attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{
|
attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{
|
||||||
.{ "api_version", .int(ctx, .i32, @intFromEnum(opts.api_version)) },
|
.{ "api_version", .int(ctx, .i32, @intFromEnum(opts.api_version)) },
|
||||||
.{ "call_target_name", .string(ctx, opts.call_target_name) },
|
.{ "call_target_name", .string(ctx, opts.call_target_name) },
|
||||||
@ -770,7 +770,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
|||||||
});
|
});
|
||||||
|
|
||||||
{
|
{
|
||||||
var output_operand_aliases: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
var output_operand_aliases: stdx.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
||||||
for (opts.output_operand_aliases) |alias| {
|
for (opts.output_operand_aliases) |alias| {
|
||||||
output_operand_aliases.appendAssumeCapacity(
|
output_operand_aliases.appendAssumeCapacity(
|
||||||
OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).asAttr(),
|
OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).asAttr(),
|
||||||
@ -789,14 +789,14 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
|||||||
};
|
};
|
||||||
|
|
||||||
if (opts.operand_layouts) |layouts| {
|
if (opts.operand_layouts) |layouts| {
|
||||||
var operand_layouts: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
|
var operand_layouts: stdx.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
|
||||||
for (layouts) |ol| {
|
for (layouts) |ol| {
|
||||||
operand_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol));
|
operand_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol));
|
||||||
}
|
}
|
||||||
attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) });
|
attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) });
|
||||||
} else {
|
} else {
|
||||||
const operand_layouts = blk: {
|
const operand_layouts = blk: {
|
||||||
var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
|
var ret: stdx.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
|
||||||
for (inputs) |input| {
|
for (inputs) |input| {
|
||||||
const ranked_type = input.getType().as(mlir.RankedTensorType).?;
|
const ranked_type = input.getType().as(mlir.RankedTensorType).?;
|
||||||
const ol = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_type.getRank() ..];
|
const ol = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_type.getRank() ..];
|
||||||
@ -808,14 +808,14 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (opts.result_layouts) |layouts| {
|
if (opts.result_layouts) |layouts| {
|
||||||
var result_layouts: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
var result_layouts: stdx.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
||||||
for (layouts) |rl| {
|
for (layouts) |rl| {
|
||||||
result_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl));
|
result_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl));
|
||||||
}
|
}
|
||||||
attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) });
|
attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) });
|
||||||
} else {
|
} else {
|
||||||
const result_layouts = blk: {
|
const result_layouts = blk: {
|
||||||
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
var ret: stdx.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
||||||
for (res_types) |t| {
|
for (res_types) |t| {
|
||||||
const ranked_t = t.as(mlir.RankedTensorType).?;
|
const ranked_t = t.as(mlir.RankedTensorType).?;
|
||||||
const rl = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_t.getRank() ..];
|
const rl = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_t.getRank() ..];
|
||||||
@ -1271,7 +1271,7 @@ pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehlo
|
|||||||
const WriterContext = @TypeOf(context);
|
const WriterContext = @TypeOf(context);
|
||||||
|
|
||||||
c.stablehloVersionFromCompatibilityRequirement(req, (struct {
|
c.stablehloVersionFromCompatibilityRequirement(req, (struct {
|
||||||
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
|
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void {
|
||||||
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
||||||
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
||||||
}
|
}
|
||||||
@ -1292,7 +1292,7 @@ pub fn stablehloGetSmallerVersion(version1: []const u8, version2: []const u8) []
|
|||||||
const WriterContext = @TypeOf(context);
|
const WriterContext = @TypeOf(context);
|
||||||
|
|
||||||
_ = c.stablehloGetSmallerVersion(mlir.stringRef(version1), mlir.stringRef(version2), (struct {
|
_ = c.stablehloGetSmallerVersion(mlir.stringRef(version1), mlir.stringRef(version2), (struct {
|
||||||
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
|
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void {
|
||||||
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
||||||
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
||||||
}
|
}
|
||||||
@ -1313,7 +1313,7 @@ pub fn getCurrentVersion() []const u8 {
|
|||||||
const ContextWriter = @TypeOf(writer_);
|
const ContextWriter = @TypeOf(writer_);
|
||||||
|
|
||||||
c.stablehloGetCurrentVersion((struct {
|
c.stablehloGetCurrentVersion((struct {
|
||||||
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
|
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void {
|
||||||
const writer: *ContextWriter = @ptrCast(@alignCast(userdata));
|
const writer: *ContextWriter = @ptrCast(@alignCast(userdata));
|
||||||
_ = writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
_ = writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
||||||
}
|
}
|
||||||
@ -1339,7 +1339,7 @@ pub fn getMinimumVersion() []const u8 {
|
|||||||
const WriterContext = @TypeOf(context);
|
const WriterContext = @TypeOf(context);
|
||||||
|
|
||||||
c.stablehloGetMinimumVersion((struct {
|
c.stablehloGetMinimumVersion((struct {
|
||||||
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
|
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void {
|
||||||
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
||||||
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
||||||
}
|
}
|
||||||
@ -1358,7 +1358,7 @@ pub fn serializePortableArtifact(bytecode: []const u8, target_version: []const u
|
|||||||
const WriterContext = @TypeOf(context);
|
const WriterContext = @TypeOf(context);
|
||||||
|
|
||||||
try mlir.successOr(c.stablehloSerializePortableArtifactFromStringRef(mlir.stringRef(bytecode), mlir.stringRef(target_version), (struct {
|
try mlir.successOr(c.stablehloSerializePortableArtifactFromStringRef(mlir.stringRef(bytecode), mlir.stringRef(target_version), (struct {
|
||||||
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
|
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void {
|
||||||
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
||||||
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
||||||
}
|
}
|
||||||
|
|||||||
120
mlir/mlir.zig
120
mlir/mlir.zig
@ -40,7 +40,7 @@ pub fn successOr(res: c.MlirLogicalResult, err: anytype) !void {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Alternative to MlirWrapperType
|
/// Alternative to MlirWrapperType
|
||||||
pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.C) void;
|
pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.c) void;
|
||||||
|
|
||||||
pub const Registry = struct {
|
pub const Registry = struct {
|
||||||
_inner: c.MlirDialectRegistry,
|
_inner: c.MlirDialectRegistry,
|
||||||
@ -171,7 +171,7 @@ pub const PassManager = struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
fn _mlir_passpipeline_error(err: c.MlirStringRef, ctx: ?*anyopaque) callconv(.C) void {
|
fn _mlir_passpipeline_error(err: c.MlirStringRef, ctx: ?*anyopaque) callconv(.c) void {
|
||||||
_ = ctx;
|
_ = ctx;
|
||||||
std.debug.print(">>ERROR: {s}\n", .{err.data});
|
std.debug.print(">>ERROR: {s}\n", .{err.data});
|
||||||
}
|
}
|
||||||
@ -754,7 +754,7 @@ pub const Operation = struct {
|
|||||||
state.addOperands(operands);
|
state.addOperands(operands);
|
||||||
} else if (args.variadic_operands) |operands_segments| {
|
} else if (args.variadic_operands) |operands_segments| {
|
||||||
const MAX_SEGMENTS = 32;
|
const MAX_SEGMENTS = 32;
|
||||||
var segments: std.BoundedArray(i32, MAX_SEGMENTS) = .{};
|
var segments: stdx.BoundedArray(i32, MAX_SEGMENTS) = .{};
|
||||||
|
|
||||||
for (operands_segments) |operands| {
|
for (operands_segments) |operands| {
|
||||||
state.addOperands(operands);
|
state.addOperands(operands);
|
||||||
@ -764,7 +764,7 @@ pub const Operation = struct {
|
|||||||
} else if (args.tt_variadic_operands) |operands_segments| {
|
} else if (args.tt_variadic_operands) |operands_segments| {
|
||||||
// stablehlo and triton seems to disagree on the expected type of operandSegmentSizes, let's fix that.
|
// stablehlo and triton seems to disagree on the expected type of operandSegmentSizes, let's fix that.
|
||||||
const MAX_SEGMENTS = 32;
|
const MAX_SEGMENTS = 32;
|
||||||
var segments: std.BoundedArray(i32, MAX_SEGMENTS) = .{};
|
var segments: stdx.BoundedArray(i32, MAX_SEGMENTS) = .{};
|
||||||
|
|
||||||
for (operands_segments) |operands| {
|
for (operands_segments) |operands| {
|
||||||
state.addOperands(operands);
|
state.addOperands(operands);
|
||||||
@ -811,7 +811,7 @@ pub const Operation = struct {
|
|||||||
@panic("Failed to create MLIR operation");
|
@panic("Failed to create MLIR operation");
|
||||||
};
|
};
|
||||||
if (args.verify and new_op.verify() == false) {
|
if (args.verify and new_op.verify() == false) {
|
||||||
log.err("Failed to verify MLIR operation:\n{}", .{new_op.mlirFormatter(.{ .debug_info = true })});
|
log.err("Failed to verify MLIR operation:\n{f}", .{new_op.mlirFormatter(.{ .debug_info = true })});
|
||||||
@panic("Failed to verify MLIR operation");
|
@panic("Failed to verify MLIR operation");
|
||||||
}
|
}
|
||||||
return new_op;
|
return new_op;
|
||||||
@ -888,7 +888,7 @@ pub const Operation = struct {
|
|||||||
c.mlirOperationWriteBytecode(
|
c.mlirOperationWriteBytecode(
|
||||||
self._inner,
|
self._inner,
|
||||||
(struct {
|
(struct {
|
||||||
pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.C) void {
|
pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void {
|
||||||
const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_));
|
const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_));
|
||||||
_ = inner_writer_context.writer.write(str.data[0..str.length]) catch unreachable;
|
_ = inner_writer_context.writer.write(str.data[0..str.length]) catch unreachable;
|
||||||
}
|
}
|
||||||
@ -916,7 +916,7 @@ pub const Operation = struct {
|
|||||||
self._inner,
|
self._inner,
|
||||||
cfg,
|
cfg,
|
||||||
(struct {
|
(struct {
|
||||||
pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.C) void {
|
pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void {
|
||||||
const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_));
|
const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_));
|
||||||
_ = inner_writer_context.writer.write(str.data[0..str.length]) catch |err| {
|
_ = inner_writer_context.writer.write(str.data[0..str.length]) catch |err| {
|
||||||
inner_writer_context.write_error = err;
|
inner_writer_context.write_error = err;
|
||||||
@ -939,29 +939,25 @@ pub const Operation = struct {
|
|||||||
op: Operation,
|
op: Operation,
|
||||||
flags: OpPrintingFlags,
|
flags: OpPrintingFlags,
|
||||||
|
|
||||||
pub fn format(self: @This(), comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
|
pub fn format(self: @This(), writer: anytype) !void {
|
||||||
_ = fmt;
|
|
||||||
_ = options;
|
|
||||||
self.op.print(writer, self.flags);
|
self.op.print(writer, self.flags);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn print(self: Self, writer: anytype, flags: OpPrintingFlags) void {
|
pub fn print(self: Self, writer: *std.Io.Writer, flags: OpPrintingFlags) void {
|
||||||
const pflags = flags.create();
|
const pflags = flags.create();
|
||||||
defer c.mlirOpPrintingFlagsDestroy(pflags);
|
defer c.mlirOpPrintingFlagsDestroy(pflags);
|
||||||
|
|
||||||
var writer_context = .{ .writer = writer };
|
|
||||||
const WriterContext = @TypeOf(writer_context);
|
|
||||||
c.mlirOperationPrintWithFlags(
|
c.mlirOperationPrintWithFlags(
|
||||||
self._inner,
|
self._inner,
|
||||||
pflags,
|
pflags,
|
||||||
(struct {
|
(struct {
|
||||||
pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.C) void {
|
pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void {
|
||||||
const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_));
|
const _writer: *std.Io.Writer = @ptrCast(@alignCast(ctx_));
|
||||||
_ = inner_writer_context.writer.write(str.data[0..str.length]) catch unreachable;
|
_writer.writeAll(str.data[0..str.length]) catch @panic("Mlir print failed");
|
||||||
}
|
}
|
||||||
}).callback,
|
}).callback,
|
||||||
&writer_context,
|
writer,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -991,7 +987,7 @@ pub const Operation = struct {
|
|||||||
c.mlirOperationWalk(
|
c.mlirOperationWalk(
|
||||||
self._inner,
|
self._inner,
|
||||||
(struct {
|
(struct {
|
||||||
pub fn callback(op: c.MlirOperation, ctx_: ?*anyopaque) callconv(.C) c.MlirWalkResult {
|
pub fn callback(op: c.MlirOperation, ctx_: ?*anyopaque) callconv(.c) c.MlirWalkResult {
|
||||||
const inner_ctx_: *ContextType = @ptrCast(@alignCast(ctx_));
|
const inner_ctx_: *ContextType = @ptrCast(@alignCast(ctx_));
|
||||||
return @intFromEnum(walkfn(inner_ctx_.ctx, .{ ._inner = op }));
|
return @intFromEnum(walkfn(inner_ctx_.ctx, .{ ._inner = op }));
|
||||||
}
|
}
|
||||||
@ -1017,24 +1013,26 @@ pub const Operation = struct {
|
|||||||
return c.mlirOperationRemoveAttributeByName(self._inner, stringRef(name_));
|
return c.mlirOperationRemoveAttributeByName(self._inner, stringRef(name_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Hash the canonicalized IR, without debug information that can change across builds.
|
||||||
pub fn hash(op: Operation, hasher: *std.hash.XxHash64) void {
|
pub fn hash(op: Operation, hasher: *std.hash.XxHash64) void {
|
||||||
const NoError = error{};
|
|
||||||
const write = struct {
|
|
||||||
fn write(hasher_: *std.hash.XxHash64, bytes: []const u8) NoError!usize {
|
|
||||||
hasher_.update(bytes);
|
|
||||||
return bytes.len;
|
|
||||||
}
|
|
||||||
}.write;
|
|
||||||
const HashWriter = std.io.Writer(*std.hash.XxHash64, NoError, write);
|
|
||||||
const writer: HashWriter = .{ .context = hasher };
|
|
||||||
|
|
||||||
// Hash the canonicalized IR, without debug information that can change across builds.
|
|
||||||
// Note: before we where using op.writeBytecode(writer),
|
// Note: before we where using op.writeBytecode(writer),
|
||||||
// but it crashes on some inputs, notably for unused variables.
|
// but it crashes on some inputs, notably for unused variables.
|
||||||
// So we use the text representation of the mlir.
|
// So we use the text representation of the mlir.
|
||||||
// See https://github.com/zml/zml/issues/97.
|
// See https://github.com/zml/zml/issues/97.
|
||||||
// Writes can't fail because we are writing to a hasher.
|
const flags = OpPrintingFlags.create(.{ .debug_info = false });
|
||||||
op.print(writer, .{ .debug_info = false });
|
defer c.mlirOpPrintingFlagsDestroy(flags);
|
||||||
|
|
||||||
|
c.mlirOperationPrintWithFlags(
|
||||||
|
op._inner,
|
||||||
|
flags,
|
||||||
|
(struct {
|
||||||
|
pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void {
|
||||||
|
const _hasher: *std.hash.XxHash64 = @ptrCast(@alignCast(ctx_));
|
||||||
|
_hasher.update(str.data[0..str.length]);
|
||||||
|
}
|
||||||
|
}).callback,
|
||||||
|
hasher,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1185,9 +1183,9 @@ pub const BlockArgument = struct {
|
|||||||
return @bitCast(c.mlirBlockArgumentGetArgNumber(arg._inner));
|
return @bitCast(c.mlirBlockArgumentGetArgNumber(arg._inner));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(self: BlockArgument, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
|
pub fn format(self: BlockArgument, writer: anytype) !void {
|
||||||
const value = Value{ ._inner = self._inner };
|
const value = Value{ ._inner = self._inner };
|
||||||
return value.format(fmt, options, writer);
|
return value.format(writer);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1234,8 +1232,8 @@ pub const Type = struct {
|
|||||||
|
|
||||||
pub fn formatAny(SpecificType: type) fn (SpecificType, SpecificType) type {
|
pub fn formatAny(SpecificType: type) fn (SpecificType, SpecificType) type {
|
||||||
return struct {
|
return struct {
|
||||||
pub fn format(self: SpecificType, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
|
pub fn format(self: SpecificType, writer: anytype) !void {
|
||||||
return try Type.format(self.asType(), fmt, options, writer);
|
return try Type.format(self.asType(), writer);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -1344,7 +1342,7 @@ pub fn IntegerType(comptime it: IntegerTypes) type {
|
|||||||
pub const eql = Type.eqlAny(Int);
|
pub const eql = Type.eqlAny(Int);
|
||||||
pub const format = helpers.format(Int, c.mlirTypePrint);
|
pub const format = helpers.format(Int, c.mlirTypePrint);
|
||||||
|
|
||||||
fn typeIsAIntegerExact(typ: c.MlirType) callconv(.C) bool {
|
fn typeIsAIntegerExact(typ: c.MlirType) callconv(.c) bool {
|
||||||
const bit_width = Config[0];
|
const bit_width = Config[0];
|
||||||
const is_sign = Config[2];
|
const is_sign = Config[2];
|
||||||
return c.mlirTypeIsAInteger(typ) and (c.mlirIntegerTypeGetWidth(typ) == bit_width) and is_sign(typ);
|
return c.mlirTypeIsAInteger(typ) and (c.mlirIntegerTypeGetWidth(typ) == bit_width) and is_sign(typ);
|
||||||
@ -1422,20 +1420,20 @@ pub fn ComplexType(comptime ct: ComplexTypes) type {
|
|||||||
_inner: c.MlirType,
|
_inner: c.MlirType,
|
||||||
const Complex = @This();
|
const Complex = @This();
|
||||||
|
|
||||||
fn mlirC64TypeGet(ctx: c.MlirContext) callconv(.C) c.MlirType {
|
fn mlirC64TypeGet(ctx: c.MlirContext) callconv(.c) c.MlirType {
|
||||||
return c.mlirComplexTypeGet(c.mlirF32TypeGet(ctx));
|
return c.mlirComplexTypeGet(c.mlirF32TypeGet(ctx));
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mlirC128TypeGet(ctx: c.MlirContext) callconv(.C) c.MlirType {
|
fn mlirC128TypeGet(ctx: c.MlirContext) callconv(.c) c.MlirType {
|
||||||
return c.mlirComplexTypeGet(c.mlirF64TypeGet(ctx));
|
return c.mlirComplexTypeGet(c.mlirF64TypeGet(ctx));
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mlirTypeIsAC64(typ: c.MlirType) callconv(.C) bool {
|
fn mlirTypeIsAC64(typ: c.MlirType) callconv(.c) bool {
|
||||||
const element_type: c.MlirType = c.mlirComplexTypeGetElementType(typ);
|
const element_type: c.MlirType = c.mlirComplexTypeGetElementType(typ);
|
||||||
return c.mlirTypeIsAF32(element_type);
|
return c.mlirTypeIsAF32(element_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mlirTypeIsAC128(typ: c.MlirType) callconv(.C) bool {
|
fn mlirTypeIsAC128(typ: c.MlirType) callconv(.c) bool {
|
||||||
const element_type: c.MlirType = c.mlirComplexTypeGetElementType(typ);
|
const element_type: c.MlirType = c.mlirComplexTypeGetElementType(typ);
|
||||||
return c.mlirTypeIsAF64(element_type);
|
return c.mlirTypeIsAF64(element_type);
|
||||||
}
|
}
|
||||||
@ -1446,7 +1444,7 @@ pub fn ComplexType(comptime ct: ComplexTypes) type {
|
|||||||
.unknown => .{ c.mlirTypeIsAComplex, null },
|
.unknown => .{ c.mlirTypeIsAComplex, null },
|
||||||
};
|
};
|
||||||
|
|
||||||
fn typeIsAUnknownComplex(typ: c.MlirType) callconv(.C) bool {
|
fn typeIsAUnknownComplex(typ: c.MlirType) callconv(.c) bool {
|
||||||
return c.mlirTypeIsAComplex(typ);
|
return c.mlirTypeIsAComplex(typ);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1685,7 +1683,7 @@ pub const Block = struct {
|
|||||||
.op_result => |parent_op| self.appendOperationRecursive(parent_op, opt),
|
.op_result => |parent_op| self.appendOperationRecursive(parent_op, opt),
|
||||||
.block_argument => |arg| {
|
.block_argument => |arg| {
|
||||||
// Hermetic blocks are not allowed to use arguments from other blocks.
|
// Hermetic blocks are not allowed to use arguments from other blocks.
|
||||||
stdx.debug.assert(opt == .open or self.eql(arg.block()), "Can't add {} from {?x} block to {?x} block", .{ arg, arg.block()._inner.ptr, self._inner.ptr });
|
stdx.debug.assert(opt == .open or self.eql(arg.block()), "Can't add {f} from {*} block to {*} block", .{ arg, arg.block()._inner.ptr, self._inner.ptr });
|
||||||
},
|
},
|
||||||
.null => @panic("InvalidMlir"),
|
.null => @panic("InvalidMlir"),
|
||||||
}
|
}
|
||||||
@ -1694,7 +1692,7 @@ pub const Block = struct {
|
|||||||
pub fn appendOperationRecursive(self: Block, op: Operation, opt: RecursiveOpts) void {
|
pub fn appendOperationRecursive(self: Block, op: Operation, opt: RecursiveOpts) void {
|
||||||
if (op.block()) |prev_block| {
|
if (op.block()) |prev_block| {
|
||||||
// Hermetic blocks are not allowed to reference values from other blocks.
|
// Hermetic blocks are not allowed to reference values from other blocks.
|
||||||
stdx.debug.assert(opt == .open or self.equals(prev_block), "Can't add {} from {?x} block to {?x} block", .{ op, prev_block._inner.ptr, self._inner.ptr });
|
stdx.debug.assert(opt == .open or self.equals(prev_block), "Can't add {} from {*} block to {*} block", .{ op, prev_block._inner.ptr, self._inner.ptr });
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (0..op.numOperands()) |i| {
|
for (0..op.numOperands()) |i| {
|
||||||
@ -1705,7 +1703,7 @@ pub const Block = struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
pub const helpers = struct {
|
pub const helpers = struct {
|
||||||
pub fn eql(T: type, equal_fn: fn (@FieldType(T, "_inner"), @FieldType(T, "_inner")) callconv(.C) bool) fn (T, T) bool {
|
pub fn eql(T: type, equal_fn: fn (@FieldType(T, "_inner"), @FieldType(T, "_inner")) callconv(.c) bool) fn (T, T) bool {
|
||||||
return struct {
|
return struct {
|
||||||
fn eql(a: T, b: T) bool {
|
fn eql(a: T, b: T) bool {
|
||||||
return equal_fn(a._inner, b._inner);
|
return equal_fn(a._inner, b._inner);
|
||||||
@ -1713,7 +1711,7 @@ pub const helpers = struct {
|
|||||||
}.eql;
|
}.eql;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn deinit(T: type, deinit_fn: fn (@FieldType(T, "_inner")) callconv(.C) void) fn (*T) void {
|
pub fn deinit(T: type, deinit_fn: fn (@FieldType(T, "_inner")) callconv(.c) void) fn (*T) void {
|
||||||
return struct {
|
return struct {
|
||||||
fn deinit(a: *T) void {
|
fn deinit(a: *T) void {
|
||||||
deinit_fn(a._inner);
|
deinit_fn(a._inner);
|
||||||
@ -1722,7 +1720,7 @@ pub const helpers = struct {
|
|||||||
}.deinit;
|
}.deinit;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dump(T: type, dump_fn: fn (@FieldType(T, "_inner")) callconv(.C) void) fn (T) void {
|
pub fn dump(T: type, dump_fn: fn (@FieldType(T, "_inner")) callconv(.c) void) fn (T) void {
|
||||||
return struct {
|
return struct {
|
||||||
fn dump(a: T) void {
|
fn dump(a: T) void {
|
||||||
return dump_fn(a._inner);
|
return dump_fn(a._inner);
|
||||||
@ -1730,7 +1728,7 @@ pub const helpers = struct {
|
|||||||
}.dump;
|
}.dump;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn isNull(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.C) bool) fn (T) bool {
|
pub fn isNull(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) fn (T) bool {
|
||||||
return struct {
|
return struct {
|
||||||
fn isNull(a: T) bool {
|
fn isNull(a: T) bool {
|
||||||
return is_null_fn(a._inner);
|
return is_null_fn(a._inner);
|
||||||
@ -1738,21 +1736,13 @@ pub const helpers = struct {
|
|||||||
}.isNull;
|
}.isNull;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(Any: type, print_fn: fn (@FieldType(Any, "_inner"), ?*const MlirStrCallback, ?*anyopaque) callconv(.C) void) type {
|
pub fn format(Any: type, print_fn: fn (@FieldType(Any, "_inner"), ?*const MlirStrCallback, ?*anyopaque) callconv(.c) void) type {
|
||||||
return struct {
|
return struct {
|
||||||
pub fn format(
|
pub fn format(self: Any, writer: *std.Io.Writer) !void {
|
||||||
self: Any,
|
const WriterWithErr = struct {
|
||||||
comptime fmt: []const u8,
|
writer: *std.Io.Writer,
|
||||||
options: std.fmt.FormatOptions,
|
err: ?std.Io.Writer.Error = null,
|
||||||
writer: anytype,
|
fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.c) void {
|
||||||
) !void {
|
|
||||||
_ = fmt;
|
|
||||||
_ = options;
|
|
||||||
|
|
||||||
const Writer = struct {
|
|
||||||
writer: @TypeOf(writer),
|
|
||||||
err: ?@TypeOf(writer).Error = null,
|
|
||||||
fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.C) void {
|
|
||||||
var ctx: *@This() = @alignCast(@ptrCast(opaque_ctx));
|
var ctx: *@This() = @alignCast(@ptrCast(opaque_ctx));
|
||||||
if (ctx.err) |_| return;
|
if (ctx.err) |_| return;
|
||||||
_ = ctx.writer.write(mlir_str.data[0..mlir_str.length]) catch |err| {
|
_ = ctx.writer.write(mlir_str.data[0..mlir_str.length]) catch |err| {
|
||||||
@ -1762,14 +1752,14 @@ pub const helpers = struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
var context: Writer = .{ .writer = writer };
|
var context: WriterWithErr = .{ .writer = writer };
|
||||||
print_fn(self._inner, &Writer.printCallback, &context);
|
print_fn(self._inner, &WriterWithErr.printCallback, &context);
|
||||||
if (context.err) |err| return err;
|
if (context.err) |err| return err;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn wrapOr(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.C) bool) fn (@FieldType(T, "_inner")) ?T {
|
pub fn wrapOr(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) fn (@FieldType(T, "_inner")) ?T {
|
||||||
return struct {
|
return struct {
|
||||||
fn wrapOr(inner: @FieldType(T, "_inner")) ?T {
|
fn wrapOr(inner: @FieldType(T, "_inner")) ?T {
|
||||||
if (is_null_fn(inner)) return null;
|
if (is_null_fn(inner)) return null;
|
||||||
@ -1778,7 +1768,7 @@ pub const helpers = struct {
|
|||||||
}.wrapOr;
|
}.wrapOr;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn init(T: type, inner: @FieldType(T, "_inner"), is_null_fn: fn (@FieldType(T, "_inner")) callconv(.C) bool) ?T {
|
pub fn init(T: type, inner: @FieldType(T, "_inner"), is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) ?T {
|
||||||
if (is_null_fn(inner)) return null;
|
if (is_null_fn(inner)) return null;
|
||||||
return .{ ._inner = inner };
|
return .{ ._inner = inner };
|
||||||
}
|
}
|
||||||
|
|||||||
@ -424,7 +424,7 @@ pub const CallFrame = extern struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Handler = fn (*CallFrame) callconv(.C) ?*Error;
|
pub const Handler = fn (*CallFrame) callconv(.c) ?*Error;
|
||||||
|
|
||||||
pub const ErrorCode = enum(c.XLA_FFI_Error_Code) {
|
pub const ErrorCode = enum(c.XLA_FFI_Error_Code) {
|
||||||
cancelled = c.XLA_FFI_Error_Code_CANCELLED,
|
cancelled = c.XLA_FFI_Error_Code_CANCELLED,
|
||||||
|
|||||||
@ -88,7 +88,7 @@ pub const Api = struct {
|
|||||||
return err;
|
return err;
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
const DynGetPjrtApi = lib.lookup(*const fn () callconv(.C) *const Api, "GetPjrtApi") orelse {
|
const DynGetPjrtApi = lib.lookup(*const fn () callconv(.c) *const Api, "GetPjrtApi") orelse {
|
||||||
std.debug.panic("Unable to find GetPjrtApi symbol in library: {s}", .{library});
|
std.debug.panic("Unable to find GetPjrtApi symbol in library: {s}", .{library});
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -100,7 +100,7 @@ pub const Api = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn CallFnArgType(comptime func: Funcs) type {
|
fn CallFnArgType(comptime func: Funcs) type {
|
||||||
const fti = @typeInfo(std.meta.FieldType(c.PJRT_Api, func));
|
const fti = @typeInfo(@FieldType(c.PJRT_Api, @tagName(func)));
|
||||||
const fn_ptr = @typeInfo(fti.optional.child);
|
const fn_ptr = @typeInfo(fti.optional.child);
|
||||||
const fn_type_info = @typeInfo(fn_ptr.pointer.child);
|
const fn_type_info = @typeInfo(fn_ptr.pointer.child);
|
||||||
const arg_array_type_info = @typeInfo(fn_type_info.@"fn".params[0].type.?);
|
const arg_array_type_info = @typeInfo(fn_type_info.@"fn".params[0].type.?);
|
||||||
@ -403,8 +403,8 @@ pub const Client = opaque {
|
|||||||
element_type: BufferType,
|
element_type: BufferType,
|
||||||
layout: MemoryLayout,
|
layout: MemoryLayout,
|
||||||
device: *const Device,
|
device: *const Device,
|
||||||
on_delete_callback: *const fn (device_buffer_ptr: ?*anyopaque, ctx: ?*anyopaque) callconv(.C) void = &struct {
|
on_delete_callback: *const fn (device_buffer_ptr: ?*anyopaque, ctx: ?*anyopaque) callconv(.c) void = &struct {
|
||||||
fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.C) void {}
|
fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.c) void {}
|
||||||
}.call,
|
}.call,
|
||||||
on_delete_callback_arg: ?*anyopaque = null,
|
on_delete_callback_arg: ?*anyopaque = null,
|
||||||
stream: ?*const Stream = null,
|
stream: ?*const Stream = null,
|
||||||
@ -637,7 +637,7 @@ pub const GetCostAnalysisError = std.mem.Allocator.Error || ApiError;
|
|||||||
pub const SerializeResult = struct {
|
pub const SerializeResult = struct {
|
||||||
bytes: []const u8,
|
bytes: []const u8,
|
||||||
handle: *anyopaque,
|
handle: *anyopaque,
|
||||||
deleter: *const fn (?*anyopaque) callconv(.C) void,
|
deleter: *const fn (?*anyopaque) callconv(.c) void,
|
||||||
|
|
||||||
pub fn deinit(self: *SerializeResult) void {
|
pub fn deinit(self: *SerializeResult) void {
|
||||||
self.deleter(self.handle);
|
self.deleter(self.handle);
|
||||||
@ -1036,7 +1036,7 @@ pub const Event = opaque {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn onReady(self: *Event, api: *const Api, func: *const fn (err: ?*Error, user_arg: ?*anyopaque) callconv(.C) void, user_arg: ?*anyopaque) ApiError!void {
|
pub fn onReady(self: *Event, api: *const Api, func: *const fn (err: ?*Error, user_arg: ?*anyopaque) callconv(.c) void, user_arg: ?*anyopaque) ApiError!void {
|
||||||
_ = try api.call(.PJRT_Event_OnReady, .{
|
_ = try api.call(.PJRT_Event_OnReady, .{
|
||||||
.event = self.inner(),
|
.event = self.inner(),
|
||||||
.callback = @ptrCast(func),
|
.callback = @ptrCast(func),
|
||||||
|
|||||||
@ -3,6 +3,7 @@ load("@rules_zig//zig:defs.bzl", "zig_library", "zig_test")
|
|||||||
zig_library(
|
zig_library(
|
||||||
name = "stdx",
|
name = "stdx",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
"bounded_array.zig",
|
||||||
"debug.zig",
|
"debug.zig",
|
||||||
"flags.zig",
|
"flags.zig",
|
||||||
"fmt.zig",
|
"fmt.zig",
|
||||||
|
|||||||
412
stdx/bounded_array.zig
Normal file
412
stdx/bounded_array.zig
Normal file
@ -0,0 +1,412 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
const assert = std.debug.assert;
|
||||||
|
const mem = std.mem;
|
||||||
|
const testing = std.testing;
|
||||||
|
const Alignment = std.mem.Alignment;
|
||||||
|
|
||||||
|
/// A fixed-capacity array paired with a runtime length, usable as a slice.
///
/// Handy for passing around small arrays whose exact size is only known at
/// runtime but whose maximum size is known at comptime, without needing an
/// `Allocator`. The element storage uses the natural alignment of `T`.
///
/// ```zig
/// const actual_size: usize = 32;
/// var a = try BoundedArray(u8, 64).init(actual_size);
/// const slice = a.slice(); // a 32-element view into the 64-byte array
/// const a_clone = a; // creates a copy - the structure doesn't use any internal pointers
/// ```
pub fn BoundedArray(comptime T: type, comptime buffer_capacity: usize) type {
    return BoundedArrayAligned(T, Alignment.of(T), buffer_capacity);
}
|
||||||
|
|
||||||
|
/// A structure with an array, length and alignment, that can be used as a
/// slice.
///
/// Useful to pass around small explicitly-aligned arrays whose exact size is
/// only known at runtime, but whose maximum size is known at comptime, without
/// requiring an `Allocator`.
///
/// ```zig
/// var a = try BoundedArrayAligned(u8, .@"16", 2).init(0);
/// try a.append(255);
/// try a.append(255);
/// const b: *const [1]u16 = @ptrCast(a.constSlice().ptr);
/// try testing.expectEqual(@as(u16, 65535), b[0]);
/// ```
pub fn BoundedArrayAligned(
    comptime T: type,
    comptime alignment: Alignment,
    comptime buffer_capacity: usize,
) type {
    return struct {
        const Self = @This();
        buffer: [buffer_capacity]T align(alignment.toByteUnits()) = undefined,
        len: usize = 0,

        /// Set the actual length of the slice.
        /// Returns error.Overflow if it exceeds the length of the backing array.
        pub fn init(len: usize) error{Overflow}!Self {
            if (len > buffer_capacity) return error.Overflow;
            return Self{ .len = len };
        }

        /// View the internal array as a slice whose size was previously set.
        /// The result is mutable when called through a mutable receiver and
        /// const otherwise.
        pub fn slice(self: anytype) switch (@TypeOf(&self.buffer)) {
            *align(alignment.toByteUnits()) [buffer_capacity]T => []align(alignment.toByteUnits()) T,
            *align(alignment.toByteUnits()) const [buffer_capacity]T => []align(alignment.toByteUnits()) const T,
            else => unreachable,
        } {
            return self.buffer[0..self.len];
        }

        /// View the internal array as a constant slice whose size was previously set.
        pub fn constSlice(self: *const Self) []align(alignment.toByteUnits()) const T {
            return self.slice();
        }

        /// Adjust the slice's length to `len`.
        /// Does not initialize added items if any.
        /// Returns error.Overflow if `len` exceeds the capacity.
        pub fn resize(self: *Self, len: usize) error{Overflow}!void {
            if (len > buffer_capacity) return error.Overflow;
            self.len = len;
        }

        /// Remove all elements from the slice.
        pub fn clear(self: *Self) void {
            self.len = 0;
        }

        /// Copy the content of an existing slice.
        /// Returns error.Overflow if `m` is longer than the capacity.
        pub fn fromSlice(m: []const T) error{Overflow}!Self {
            var list = try init(m.len);
            @memcpy(list.slice(), m);
            return list;
        }

        /// Return the element at index `i` of the slice.
        pub fn get(self: Self, i: usize) T {
            return self.constSlice()[i];
        }

        /// Set the value of the element at index `i` of the slice.
        pub fn set(self: *Self, i: usize, item: T) void {
            self.slice()[i] = item;
        }

        /// Return the maximum length of a slice.
        pub fn capacity(self: Self) usize {
            return self.buffer.len;
        }

        /// Check that the slice can hold at least `additional_count` items.
        /// Returns error.Overflow otherwise, including when
        /// `len + additional_count` does not fit in a `usize`.
        pub fn ensureUnusedCapacity(self: Self, additional_count: usize) error{Overflow}!void {
            // Checked addition: an oversized request must be reported as
            // Overflow rather than trapping (safe modes) or wrapping.
            const needed = std.math.add(usize, self.len, additional_count) catch return error.Overflow;
            if (needed > buffer_capacity) return error.Overflow;
        }

        /// Increase length by 1, returning a pointer to the new item.
        pub fn addOne(self: *Self) error{Overflow}!*T {
            try self.ensureUnusedCapacity(1);
            return self.addOneAssumeCapacity();
        }

        /// Increase length by 1, returning pointer to the new item.
        /// Asserts that there is space for the new item.
        pub fn addOneAssumeCapacity(self: *Self) *T {
            assert(self.len < buffer_capacity);
            self.len += 1;
            return &self.slice()[self.len - 1];
        }

        /// Resize the slice, adding `n` new elements, which have `undefined` values.
        /// The return value is a pointer to the array of uninitialized elements.
        pub fn addManyAsArray(self: *Self, comptime n: usize) error{Overflow}!*align(alignment.toByteUnits()) [n]T {
            const prev_len = self.len;
            // Validate before touching `len` so the addition below cannot overflow.
            try self.ensureUnusedCapacity(n);
            self.len = prev_len + n;
            return self.slice()[prev_len..][0..n];
        }

        /// Resize the slice, adding `n` new elements, which have `undefined` values.
        /// The return value is a slice pointing to the uninitialized elements.
        pub fn addManyAsSlice(self: *Self, n: usize) error{Overflow}![]align(alignment.toByteUnits()) T {
            const prev_len = self.len;
            // Validate before touching `len` so the addition below cannot overflow.
            try self.ensureUnusedCapacity(n);
            self.len = prev_len + n;
            return self.slice()[prev_len..][0..n];
        }

        /// Remove and return the last element from the slice, or return `null` if the slice is empty.
        pub fn pop(self: *Self) ?T {
            if (self.len == 0) return null;
            const item = self.get(self.len - 1);
            self.len -= 1;
            return item;
        }

        /// Return a slice of only the extra capacity after items.
        /// This can be useful for writing directly into it.
        /// Note that such an operation must be followed up with a
        /// call to `resize()`
        pub fn unusedCapacitySlice(self: *Self) []align(alignment.toByteUnits()) T {
            return self.buffer[self.len..];
        }

        /// Insert `item` at index `i` by moving `slice[i .. slice.len]` to make room.
        /// Returns error.Overflow if `i` is past the end or the array is full.
        /// This operation is O(N).
        pub fn insert(
            self: *Self,
            i: usize,
            item: T,
        ) error{Overflow}!void {
            if (i > self.len) return error.Overflow;
            _ = try self.addOne();
            const s = self.slice();
            mem.copyBackwards(T, s[i + 1 .. s.len], s[i .. s.len - 1]);
            s[i] = item;
        }

        /// Insert slice `items` at index `i` by moving `slice[i .. slice.len]` to make room.
        /// Returns error.Overflow if `i` is past the end (matching `insert`) or
        /// if the items do not fit; the array is left unchanged in both cases.
        /// This operation is O(N).
        pub fn insertSlice(self: *Self, i: usize, items: []const T) error{Overflow}!void {
            if (i > self.len) return error.Overflow;
            try self.ensureUnusedCapacity(items.len);
            const old_len = self.len;
            self.len = old_len + items.len;
            const s = self.slice();
            // Shift the tail right (back to front, the ranges overlap), then fill the gap.
            mem.copyBackwards(T, s[i + items.len ..], s[i..old_len]);
            @memcpy(s[i..][0..items.len], items);
        }

        /// Replace range of elements `slice[start..][0..len]` with `new_items`.
        /// Grows slice if `len < new_items.len`.
        /// Shrinks slice if `len > new_items.len`.
        /// The range must lie within the current slice (bounds-checked by slicing).
        /// Returns error.Overflow if growing would exceed the capacity, in
        /// which case the array is left unchanged.
        pub fn replaceRange(
            self: *Self,
            start: usize,
            len: usize,
            new_items: []const T,
        ) error{Overflow}!void {
            const after_range = start + len;
            const range = self.slice()[start..after_range];

            if (range.len == new_items.len) {
                @memcpy(range, new_items);
            } else if (range.len < new_items.len) {
                const first = new_items[0..range.len];
                const rest = new_items[range.len..];
                // Check capacity up front so a failed growth does not leave
                // the range half-overwritten.
                try self.ensureUnusedCapacity(rest.len);
                @memcpy(range, first);
                try self.insertSlice(after_range, rest);
            } else {
                @memcpy(range[0..new_items.len], new_items);
                const after_subrange = start + new_items.len;
                const s = self.slice();
                const tail = s[after_range..];
                // Close the gap: move the tail left, front to back.
                for (s[after_subrange..][0..tail.len], tail) |*dst, src| dst.* = src;
                self.len -= len - new_items.len;
            }
        }

        /// Extend the slice by 1 element.
        pub fn append(self: *Self, item: T) error{Overflow}!void {
            const new_item_ptr = try self.addOne();
            new_item_ptr.* = item;
        }

        /// Extend the slice by 1 element, asserting the capacity is already
        /// enough to store the new item.
        pub fn appendAssumeCapacity(self: *Self, item: T) void {
            const new_item_ptr = self.addOneAssumeCapacity();
            new_item_ptr.* = item;
        }

        /// Remove the element at index `i`, shift elements after index
        /// `i` forward, and return the removed element.
        /// Asserts the slice has at least one item.
        /// This operation is O(N).
        pub fn orderedRemove(self: *Self, i: usize) T {
            assert(self.len > 0);
            const newlen = self.len - 1;
            if (newlen == i) return self.pop().?;
            const old_item = self.get(i);
            const s = self.slice();
            // Shift everything after `i` one slot to the left, front to back.
            for (s[i..newlen], s[i + 1 ..]) |*dst, src| dst.* = src;
            s[newlen] = undefined;
            self.len = newlen;
            return old_item;
        }

        /// Remove the element at the specified index and return it.
        /// The empty slot is filled from the end of the slice.
        /// Asserts the slice has at least one item.
        /// This operation is O(1).
        pub fn swapRemove(self: *Self, i: usize) T {
            assert(self.len > 0);
            if (self.len - 1 == i) return self.pop().?;
            const old_item = self.get(i);
            self.set(i, self.pop().?);
            return old_item;
        }

        /// Append the slice of items to the slice.
        pub fn appendSlice(self: *Self, items: []const T) error{Overflow}!void {
            try self.ensureUnusedCapacity(items.len);
            self.appendSliceAssumeCapacity(items);
        }

        /// Append the slice of items to the slice, asserting the capacity is already
        /// enough to store the new items.
        pub fn appendSliceAssumeCapacity(self: *Self, items: []const T) void {
            const old_len = self.len;
            // Mirror `appendNTimesAssumeCapacity`: make the documented assertion explicit.
            assert(old_len + items.len <= buffer_capacity);
            self.len = old_len + items.len;
            @memcpy(self.slice()[old_len..][0..items.len], items);
        }

        /// Append a value to the slice `n` times.
        /// Returns error.Overflow if the capacity is insufficient; no
        /// allocation ever takes place.
        pub fn appendNTimes(self: *Self, value: T, n: usize) error{Overflow}!void {
            const old_len = self.len;
            try self.ensureUnusedCapacity(n);
            self.len = old_len + n;
            @memset(self.slice()[old_len..self.len], value);
        }

        /// Append a value to the slice `n` times.
        /// Asserts the capacity is enough.
        pub fn appendNTimesAssumeCapacity(self: *Self, value: T, n: usize) void {
            const old_len = self.len;
            self.len += n;
            assert(self.len <= buffer_capacity);
            @memset(self.slice()[old_len..self.len], value);
        }

        pub const Writer = if (T != u8)
            @compileError("The Writer interface is only defined for BoundedArray(u8, ...) " ++
                "but the given type is BoundedArray(" ++ @typeName(T) ++ ", ...)")
        else
            std.io.GenericWriter(*Self, error{Overflow}, appendWrite);

        /// Initializes a writer which will write into the array.
        pub fn writer(self: *Self) Writer {
            return .{ .context = self };
        }

        /// Same as `appendSlice` except it returns the number of bytes written, which is always the same
        /// as `m.len`. The purpose of this function existing is to match `std.io.GenericWriter` API.
        fn appendWrite(self: *Self, m: []const u8) error{Overflow}!usize {
            try self.appendSlice(m);
            return m.len;
        }
    };
}
|
||||||
|
|
||||||
|
// Exercises the whole BoundedArray API on a 64-byte array. Every
// `expectEqual` passes the expected value first so that failure messages
// ("expected X, found Y") report the two sides correctly.
test BoundedArray {
    var a = try BoundedArray(u8, 64).init(32);

    try testing.expectEqual(64, a.capacity());
    try testing.expectEqual(32, a.slice().len);
    try testing.expectEqual(32, a.constSlice().len);

    try a.resize(48);
    try testing.expectEqual(48, a.len);

    const x = [_]u8{1} ** 10;
    a = try BoundedArray(u8, 64).fromSlice(&x);
    try testing.expectEqualSlices(u8, &x, a.constSlice());

    // Copies are independent: the structure holds no internal pointers.
    var a2 = a;
    try testing.expectEqualSlices(u8, a.constSlice(), a2.constSlice());
    a2.set(0, 0);
    try testing.expect(a.get(0) != a2.get(0));

    try testing.expectError(error.Overflow, a.resize(100));
    try testing.expectError(error.Overflow, BoundedArray(u8, x.len - 1).fromSlice(&x));

    try a.resize(0);
    try a.ensureUnusedCapacity(a.capacity());
    (try a.addOne()).* = 0;
    try a.ensureUnusedCapacity(a.capacity() - 1);
    try testing.expectEqual(1, a.len);

    const uninitialized = try a.addManyAsArray(4);
    try testing.expectEqual(4, uninitialized.len);
    try testing.expectEqual(5, a.len);

    try a.append(0xff);
    try testing.expectEqual(6, a.len);
    try testing.expectEqual(0xff, a.pop());

    a.appendAssumeCapacity(0xff);
    try testing.expectEqual(6, a.len);
    try testing.expectEqual(0xff, a.pop());

    try a.resize(1);
    try testing.expectEqual(0, a.pop());
    try testing.expectEqual(null, a.pop());

    // Write directly into the spare capacity, then publish it with resize().
    const unused = a.unusedCapacitySlice();
    @memset(unused[0..8], 2);
    unused[8] = 3;
    unused[9] = 4;
    try testing.expectEqual(a.capacity(), unused.len);
    try a.resize(10);

    try a.insert(5, 0xaa);
    try testing.expectEqual(11, a.len);
    try testing.expectEqual(0xaa, a.get(5));
    try testing.expectEqual(3, a.get(9));
    try testing.expectEqual(4, a.get(10));

    try a.insert(11, 0xbb);
    try testing.expectEqual(12, a.len);
    try testing.expectEqual(0xbb, a.pop());

    try a.appendSlice(&x);
    try testing.expectEqual(11 + x.len, a.len);

    try a.appendNTimes(0xbb, 5);
    try testing.expectEqual(11 + x.len + 5, a.len);
    try testing.expectEqual(0xbb, a.pop());

    a.appendNTimesAssumeCapacity(0xcc, 5);
    try testing.expectEqual(11 + x.len + 5 - 1 + 5, a.len);
    try testing.expectEqual(0xcc, a.pop());

    try testing.expectEqual(29, a.len);
    try a.replaceRange(1, 20, &x);
    try testing.expectEqual(29 + x.len - 20, a.len);

    try a.insertSlice(0, &x);
    try testing.expectEqual(29 + x.len - 20 + x.len, a.len);

    try a.replaceRange(1, 5, &x);
    try testing.expectEqual(29 + x.len - 20 + x.len + x.len - 5, a.len);

    try a.append(10);
    try testing.expectEqual(10, a.pop());

    try a.append(20);
    const removed = a.orderedRemove(5);
    try testing.expectEqual(1, removed);
    try testing.expectEqual(34, a.len);

    a.set(0, 0xdd);
    a.set(a.len - 1, 0xee);
    const swapped = a.swapRemove(0);
    try testing.expectEqual(0xdd, swapped);
    try testing.expectEqual(0xee, a.get(0));

    const added_slice = try a.addManyAsSlice(3);
    try testing.expectEqual(3, added_slice.len);
    try testing.expectEqual(36, a.len);

    // Drain the array, then fill it again through the writer interface.
    while (a.pop()) |_| {}
    const w = a.writer();
    const s = "hello, this is a test string";
    try w.writeAll(s);
    try testing.expectEqualStrings(s, a.constSlice());
}
|
||||||
|
|
||||||
|
// Checks that an over-aligned byte array can be reinterpreted as wider
// integers: four appended bytes are read back as two u16 values.
test "BoundedArrayAligned" {
    var a = try BoundedArrayAligned(u8, .@"16", 4).init(0);
    for ([_]u8{ 0, 0, 255, 255 }) |byte| {
        try a.append(byte);
    }

    const b: *const [2]u16 = @ptrCast(a.constSlice().ptr);
    try testing.expectEqual(@as(u16, 0), b[0]);
    try testing.expectEqual(@as(u16, 65535), b[1]);
}
|
||||||
@ -47,8 +47,7 @@ const debug = @import("debug.zig");
|
|||||||
|
|
||||||
/// Format and print an error message to stderr, then exit with an exit code of 1.
|
/// Format and print an error message to stderr, then exit with an exit code of 1.
|
||||||
pub fn fatal(comptime fmt_string: []const u8, args: anytype) noreturn {
|
pub fn fatal(comptime fmt_string: []const u8, args: anytype) noreturn {
|
||||||
const stderr = std.io.getStdErr().writer();
|
std.debug.print("error: " ++ fmt_string ++ "\n", args);
|
||||||
stderr.print("error: " ++ fmt_string ++ "\n", args) catch {};
|
|
||||||
std.posix.exit(1);
|
std.posix.exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
32
stdx/fmt.zig
32
stdx/fmt.zig
@ -43,8 +43,8 @@ pub const IntFmt = struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
pub const FloatFmt = enum(u8) {
|
pub const FloatFmt = enum(u8) {
|
||||||
scientific = @intFromEnum(std.fmt.format_float.Format.scientific),
|
scientific = @intFromEnum(std.fmt.Number.Mode.scientific),
|
||||||
decimal = @intFromEnum(std.fmt.format_float.Format.decimal),
|
decimal = @intFromEnum(std.fmt.Number.Mode.decimal),
|
||||||
hex,
|
hex,
|
||||||
|
|
||||||
pub fn parseComptime(comptime fmt_: []const u8) FloatFmt {
|
pub fn parseComptime(comptime fmt_: []const u8) FloatFmt {
|
||||||
@ -71,44 +71,34 @@ pub fn formatValue(value: anytype, full: FullFormatOptions, writer: anytype) !vo
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatFloatValue(value: anytype, full: FullFormatOptions, writer: anytype) !void {
|
pub fn formatFloatValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void {
|
||||||
const formatFloat = std.fmt.format_float.formatFloat;
|
|
||||||
var buf: [std.fmt.format_float.bufferSize(.decimal, f64)]u8 = undefined;
|
|
||||||
|
|
||||||
const x = switch (@typeInfo(@TypeOf(value))) {
|
const x = switch (@typeInfo(@TypeOf(value))) {
|
||||||
.@"struct" => value.toF32(),
|
.@"struct" => value.toF32(),
|
||||||
.float => value,
|
.float => value,
|
||||||
else => @compileError("formatFloatValue expects a float, got: " ++ @typeName(@TypeOf(value))),
|
else => @compileError("formatFloatValue expects a float, got: " ++ @typeName(@TypeOf(value))),
|
||||||
};
|
};
|
||||||
const s_or_err = switch (full.fmt.float) {
|
try switch (full.fmt.float) {
|
||||||
.scientific => formatFloat(&buf, x, .{ .mode = .scientific, .precision = full.options.precision }),
|
.scientific => writer.printFloat(x, .{ .mode = .scientific, .precision = full.options.precision }),
|
||||||
.decimal => formatFloat(&buf, x, .{ .mode = .decimal, .precision = full.options.precision }),
|
.decimal => writer.printFloat(x, .{ .mode = .decimal, .precision = full.options.precision }),
|
||||||
.hex => hex: {
|
.hex => writer.printFloatHexOptions(x, .{ .mode = .hex }),
|
||||||
var buf_stream = std.io.fixedBufferStream(&buf);
|
|
||||||
std.fmt.formatFloatHexadecimal(x, full.options, buf_stream.writer()) catch unreachable;
|
|
||||||
break :hex buf_stream.getWritten();
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const s = s_or_err catch "(float)";
|
|
||||||
return std.fmt.formatBuf(s, full.options, writer);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatIntValue(value: anytype, full: FullFormatOptions, writer: anytype) !void {
|
pub fn formatIntValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void {
|
||||||
switch (@typeInfo(@TypeOf(value))) {
|
switch (@typeInfo(@TypeOf(value))) {
|
||||||
.int => {},
|
.int => {},
|
||||||
else => @compileError("formatIntValue expects an int, got: " ++ @typeName(@TypeOf(value))),
|
else => @compileError("formatIntValue expects an int, got: " ++ @typeName(@TypeOf(value))),
|
||||||
}
|
}
|
||||||
return std.fmt.formatInt(value, full.fmt.int.base, full.fmt.int.case, full.options, writer);
|
return writer.printInt(value, full.fmt.int.base, full.fmt.int.case, full.options);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatAnyValue(value: anytype, full: FullFormatOptions, writer: anytype) !void {
|
pub fn formatAnyValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void {
|
||||||
var buf: [48]u8 = undefined;
|
var buf: [48]u8 = undefined;
|
||||||
const s = std.fmt.bufPrint(&buf, "{any}", .{value}) catch blk: {
|
const s = std.fmt.bufPrint(&buf, "{any}", .{value}) catch blk: {
|
||||||
buf[45..].* = "...".*;
|
buf[45..].* = "...".*;
|
||||||
break :blk buf[0..];
|
break :blk buf[0..];
|
||||||
};
|
};
|
||||||
return std.fmt.formatBuf(s, full.options, writer);
|
return try writer.alignBufferOptions(s, full.options);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatSliceCustom(fmt_func: anytype, values: anytype, full: FullFormatOptions, writer: anytype) !void {
|
pub fn formatSliceCustom(fmt_func: anytype, values: anytype, full: FullFormatOptions, writer: anytype) !void {
|
||||||
|
|||||||
@ -26,7 +26,7 @@ pub fn ArgsTuple(comptime funcT: anytype, comptime ArgsT: ?type) type {
|
|||||||
var num_buf: [8]u8 = undefined;
|
var num_buf: [8]u8 = undefined;
|
||||||
tuple_fields[i] = .{
|
tuple_fields[i] = .{
|
||||||
.name = blk: {
|
.name = blk: {
|
||||||
const s = std.fmt.formatIntBuf(&num_buf, i, 10, .lower, .{});
|
const s = std.fmt.printInt(&num_buf, i, 10, .lower, .{});
|
||||||
num_buf[s] = 0;
|
num_buf[s] = 0;
|
||||||
break :blk num_buf[0..s :0];
|
break :blk num_buf[0..s :0];
|
||||||
},
|
},
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
pub const BoundedArray = @import("bounded_array.zig").BoundedArray;
|
||||||
|
pub const BoundedArrayAligned = @import("bounded_array.zig").BoundedArrayAligned;
|
||||||
pub const debug = @import("debug.zig");
|
pub const debug = @import("debug.zig");
|
||||||
pub const flags = @import("flags.zig");
|
pub const flags = @import("flags.zig");
|
||||||
pub const fmt = @import("fmt.zig");
|
pub const fmt = @import("fmt.zig");
|
||||||
|
|||||||
@ -11,14 +11,11 @@ pub const Duration = struct {
|
|||||||
return (1 * std.time.ns_per_s) / self.ns;
|
return (1 * std.time.ns_per_s) / self.ns;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(
|
pub fn formatDuration(duration: Duration, writer: *std.io.Writer) std.io.Writer.Error!void {
|
||||||
self: Duration,
|
try writer.printDuration(duration.ns, .{});
|
||||||
comptime fmt: []const u8,
|
|
||||||
options: std.fmt.FormatOptions,
|
|
||||||
writer: anytype,
|
|
||||||
) @TypeOf(writer).Error!void {
|
|
||||||
return try std.fmt.fmtDuration(self.ns).format(fmt, options, writer);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub const format = formatDuration;
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Timer = struct {
|
pub const Timer = struct {
|
||||||
|
|||||||
2
third_party/com_github_hejsil_clap/repo.bzl
vendored
2
third_party/com_github_hejsil_clap/repo.bzl
vendored
@ -4,6 +4,6 @@ def repo():
|
|||||||
new_git_repository(
|
new_git_repository(
|
||||||
name = "com_github_hejsil_clap",
|
name = "com_github_hejsil_clap",
|
||||||
remote = "https://github.com/Hejsil/zig-clap.git",
|
remote = "https://github.com/Hejsil/zig-clap.git",
|
||||||
commit = "068c38f89814079635692c7d0be9f58508c86173",
|
commit = "5289e0753cd274d65344bef1c114284c633536ea",
|
||||||
build_file = "//:third_party/com_github_hejsil_clap/clap.bazel",
|
build_file = "//:third_party/com_github_hejsil_clap/clap.bazel",
|
||||||
)
|
)
|
||||||
|
|||||||
75
third_party/modules/rules_zig/20250827.0-35b6d57/MODULE.bazel
vendored
Normal file
75
third_party/modules/rules_zig/20250827.0-35b6d57/MODULE.bazel
vendored
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
module(
|
||||||
|
name = "rules_zig",
|
||||||
|
version = "20250827.0-35b6d57",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "aspect_bazel_lib", version = "2.8.1")
|
||||||
|
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||||
|
bazel_dep(name = "platforms", version = "0.0.10")
|
||||||
|
|
||||||
|
zig = use_extension("//zig:extensions.bzl", "zig")
|
||||||
|
zig.index(file = "//zig/private:versions.json")
|
||||||
|
use_repo(zig, "zig_toolchains")
|
||||||
|
|
||||||
|
register_toolchains("@rules_zig//zig/target:all")
|
||||||
|
|
||||||
|
register_toolchains("@zig_toolchains//:all")
|
||||||
|
|
||||||
|
zig_dev = use_extension(
|
||||||
|
"//zig:extensions.bzl",
|
||||||
|
"zig",
|
||||||
|
dev_dependency = True,
|
||||||
|
)
|
||||||
|
zig_dev.toolchain(zig_version = "0.13.0")
|
||||||
|
zig_dev.toolchain(zig_version = "0.12.1")
|
||||||
|
zig_dev.toolchain(zig_version = "0.12.0")
|
||||||
|
zig_dev.toolchain(zig_version = "0.11.0")
|
||||||
|
|
||||||
|
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||||
|
|
||||||
|
bazel_dep(name = "stardoc", version = "0.7.0", dev_dependency = True, repo_name = "io_bazel_stardoc")
|
||||||
|
bazel_dep(name = "gazelle", version = "0.38.0", dev_dependency = True, repo_name = "bazel_gazelle")
|
||||||
|
bazel_dep(name = "bazel_skylib_gazelle_plugin", version = "1.7.1", dev_dependency = True)
|
||||||
|
bazel_dep(
|
||||||
|
name = "buildifier_prebuilt",
|
||||||
|
version = "7.3.1",
|
||||||
|
dev_dependency = True,
|
||||||
|
)
|
||||||
|
bazel_dep(name = "rules_multirun", version = "0.9.0", dev_dependency = True)
|
||||||
|
bazel_dep(name = "rules_python", version = "0.35.0", dev_dependency = True)
|
||||||
|
bazel_dep(
|
||||||
|
name = "rules_bazel_integration_test",
|
||||||
|
version = "0.25.0",
|
||||||
|
dev_dependency = True,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_binaries = use_extension(
|
||||||
|
"@rules_bazel_integration_test//:extensions.bzl",
|
||||||
|
"bazel_binaries",
|
||||||
|
dev_dependency = True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE: Keep in sync with WORKSPACE.
|
||||||
|
bazel_binaries.download(version_file = "//:.bazelversion")
|
||||||
|
bazel_binaries.download(version = "7.0.0")
|
||||||
|
use_repo(
|
||||||
|
bazel_binaries,
|
||||||
|
"bazel_binaries",
|
||||||
|
"bazel_binaries_bazelisk",
|
||||||
|
"build_bazel_bazel_.bazelversion",
|
||||||
|
"build_bazel_bazel_7_0_0",
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO[AH] Should be an implicit transitive dependency through rules_bazel_integration_test.
|
||||||
|
# However, if we do not include it explicitly, then the runfiles resolution for
|
||||||
|
# cgrindel_bazel_starlib/shlib/lib/message.sh fails in
|
||||||
|
# rules_bazel_integration_test/tools/update_deleted_packages.sh when invoked
|
||||||
|
# through the rules_multirun target //util:update.
|
||||||
|
bazel_dep(name = "cgrindel_bazel_starlib", version = "0.21.0", dev_dependency = True)
|
||||||
|
|
||||||
|
# Hack to get around a cc_common.link limitation.
|
||||||
|
# See https://github.com/bazelbuild/bazel/pull/23838
|
||||||
|
cc_common_link = use_repo_rule("//zig:extensions.bzl", "cc_common_link")
|
||||||
|
|
||||||
|
cc_common_link(name = "build_bazel_rules_android")
|
||||||
5
third_party/modules/rules_zig/20250827.0-35b6d57/source.json
vendored
Normal file
5
third_party/modules/rules_zig/20250827.0-35b6d57/source.json
vendored
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"strip_prefix": "rules_zig-35b6d57e94f3eb08e978c1133314606f4f38e216",
|
||||||
|
"url": "https://github.com/zml/rules_zig/archive/35b6d57e94f3eb08e978c1133314606f4f38e216.tar.gz",
|
||||||
|
"integrity": "sha256-FDnAqynTD2LB3W/IaBgocmnLz8CA9nyHZYYntC4plUU="
|
||||||
|
}
|
||||||
3
third_party/modules/rules_zig/metadata.json
vendored
3
third_party/modules/rules_zig/metadata.json
vendored
@ -17,7 +17,8 @@
|
|||||||
"20250519.0-233b207",
|
"20250519.0-233b207",
|
||||||
"20250613.0-567662a",
|
"20250613.0-567662a",
|
||||||
"20250714.0-b14a4f1",
|
"20250714.0-b14a4f1",
|
||||||
"20250821.0-be53625"
|
"20250821.0-be53625",
|
||||||
|
"20250827.0-35b6d57"
|
||||||
],
|
],
|
||||||
"yanked_versions": {}
|
"yanked_versions": {}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,15 +1,7 @@
|
|||||||
load("@rules_python//python:py_library.bzl", "py_library")
|
load("@rules_python//python:py_library.bzl", "py_library")
|
||||||
load("@rules_python//python/entry_points:py_console_script_binary.bzl", "py_console_script_binary")
|
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "zml_utils",
|
name = "zml_utils",
|
||||||
srcs = ["zml_utils.py"],
|
srcs = ["zml_utils.py"],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_console_script_binary(
|
|
||||||
name = "hf",
|
|
||||||
pkg = "@huggingface_hub//huggingface_hub:pkg",
|
|
||||||
script = "hf",
|
|
||||||
visibility = ["//visibility:public"],
|
|
||||||
)
|
|
||||||
|
|||||||
@ -153,6 +153,6 @@ pub const Allocator = struct {
|
|||||||
@panic("Unsupported case");
|
@panic("Unsupported case");
|
||||||
}
|
}
|
||||||
|
|
||||||
return @ptrCast(self.allocator.alignedAlloc(u8, @alignOf(*anyopaque), size) catch return null);
|
return @ptrCast(self.allocator.alignedAlloc(u8, std.mem.Alignment.of(*anyopaque), size) catch return null);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -608,7 +608,7 @@ fn findSimilarBufferKeys(original_key: []const u8, store: BufferStore, temp_allo
|
|||||||
if (std.mem.startsWith(u8, key, base_key)) {
|
if (std.mem.startsWith(u8, key, base_key)) {
|
||||||
if (matches == 0) log.warn("Similar buffers found:", .{});
|
if (matches == 0) log.warn("Similar buffers found:", .{});
|
||||||
if (!shown_keys.contains(key)) {
|
if (!shown_keys.contains(key)) {
|
||||||
log.warn(" - {s}: {}", .{ key, entry.value_ptr.*.shape() });
|
log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() });
|
||||||
shown_keys.put(key, {}) catch continue;
|
shown_keys.put(key, {}) catch continue;
|
||||||
matches += 1;
|
matches += 1;
|
||||||
}
|
}
|
||||||
@ -625,7 +625,7 @@ fn findSimilarBufferKeys(original_key: []const u8, store: BufferStore, temp_allo
|
|||||||
const key = entry.key_ptr.*;
|
const key = entry.key_ptr.*;
|
||||||
if (std.mem.indexOf(u8, key, component) != null and !shown_keys.contains(key)) {
|
if (std.mem.indexOf(u8, key, component) != null and !shown_keys.contains(key)) {
|
||||||
if (matches == 0) log.warn("Partial matches for '{s}':", .{component});
|
if (matches == 0) log.warn("Partial matches for '{s}':", .{component});
|
||||||
log.warn(" - {s}: {}", .{ key, entry.value_ptr.*.shape() });
|
log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() });
|
||||||
shown_keys.put(key, {}) catch continue;
|
shown_keys.put(key, {}) catch continue;
|
||||||
matches += 1;
|
matches += 1;
|
||||||
if (matches >= 5) break;
|
if (matches >= 5) break;
|
||||||
@ -660,8 +660,8 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
return if (buffer_store.get(prefix)) |host_buffer| {
|
return if (buffer_store.get(prefix)) |host_buffer| {
|
||||||
// obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us.
|
// obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us.
|
||||||
var buf_with_metadata = host_buffer;
|
var buf_with_metadata = host_buffer;
|
||||||
log.debug("Loading buffer {s} ({})", .{ prefix, obj._shape });
|
log.debug("Loading buffer {s} ({f})", .{ prefix, obj._shape });
|
||||||
stdx.debug.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix });
|
stdx.debug.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {f} and {f} for tensor {s}", .{ obj._shape, host_buffer, prefix });
|
||||||
buf_with_metadata._shape = obj._shape;
|
buf_with_metadata._shape = obj._shape;
|
||||||
obj.* = try zml.Buffer.from(platform, buf_with_metadata, .{});
|
obj.* = try zml.Buffer.from(platform, buf_with_metadata, .{});
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -69,7 +69,7 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, prefix:
|
|||||||
var new_prefix = prefix;
|
var new_prefix = prefix;
|
||||||
if (prefix.items.len > 0)
|
if (prefix.items.len > 0)
|
||||||
new_prefix.appendAssumeCapacity('.');
|
new_prefix.appendAssumeCapacity('.');
|
||||||
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
|
new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
|
||||||
try parseMetadata(allocator, store, new_prefix, item);
|
try parseMetadata(allocator, store, new_prefix, item);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -1,12 +1,15 @@
|
|||||||
const asynk = @import("async");
|
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const zml = @import("../zml.zig");
|
const Allocator = std.mem.Allocator;
|
||||||
const json = @import("json.zig");
|
|
||||||
const HostBuffer = zml.HostBuffer;
|
const asynk = @import("async");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const MemoryMappedFile = @import("../aio.zig").MemoryMappedFile;
|
const MemoryMappedFile = @import("../aio.zig").MemoryMappedFile;
|
||||||
|
const zml = @import("../zml.zig");
|
||||||
|
const HostBuffer = zml.HostBuffer;
|
||||||
|
const json = @import("json.zig");
|
||||||
|
|
||||||
const StringBuilder = std.ArrayListUnmanaged(u8);
|
const StringBuilder = std.ArrayListUnmanaged(u8);
|
||||||
const Allocator = std.mem.Allocator;
|
|
||||||
const log = std.log.scoped(.@"zml/io");
|
const log = std.log.scoped(.@"zml/io");
|
||||||
|
|
||||||
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
|
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
|
||||||
@ -16,7 +19,7 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
|
|||||||
errdefer res.arena.deinit();
|
errdefer res.arena.deinit();
|
||||||
const arena = res.arena.allocator();
|
const arena = res.arena.allocator();
|
||||||
|
|
||||||
var files = std.ArrayList(MemoryMappedFile).init(arena);
|
var files = std.array_list.Managed(MemoryMappedFile).init(arena);
|
||||||
errdefer files.deinit();
|
errdefer files.deinit();
|
||||||
|
|
||||||
if (std.mem.endsWith(u8, path, ".safetensors.index.json")) {
|
if (std.mem.endsWith(u8, path, ".safetensors.index.json")) {
|
||||||
@ -28,17 +31,19 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void {
|
fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.array_list.Managed(MemoryMappedFile), path: []const u8) !void {
|
||||||
const file = asynk.File.open(path, .{}) catch |err| {
|
const file = asynk.File.open(path, .{}) catch |err| {
|
||||||
log.err("Failed to open {s}: {}", .{ path, err });
|
log.err("Failed to open {s}: {}", .{ path, err });
|
||||||
return err;
|
return err;
|
||||||
};
|
};
|
||||||
errdefer file.close() catch unreachable;
|
errdefer file.close() catch unreachable;
|
||||||
var r = file.reader();
|
var buffer: [4096]u8 = undefined;
|
||||||
|
var r = file.reader(&buffer);
|
||||||
|
|
||||||
const json_data = try allocator.alloc(u8, (try file.stat()).size);
|
// const json_data = try allocator.alloc(u8, (try file.stat()).size);
|
||||||
_ = try r.readAtLeast(json_data, json_data.len);
|
var json_reader = std.json.Reader.init(allocator, &r.interface);
|
||||||
const index = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data, .{ .allocate = .alloc_if_needed });
|
// _ = try r.readAtLeast(json_data, json_data.len);
|
||||||
|
const index = try std.json.parseFromTokenSourceLeaky(std.json.Value, allocator, &json_reader, .{ .allocate = .alloc_if_needed });
|
||||||
var loaded_files = std.StringHashMap(void).init(allocator);
|
var loaded_files = std.StringHashMap(void).init(allocator);
|
||||||
|
|
||||||
const weight_map = index.object.get("weight_map").?.object;
|
const weight_map = index.object.get("weight_map").?.object;
|
||||||
@ -62,23 +67,24 @@ fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void {
|
fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.array_list.Managed(MemoryMappedFile), path: []const u8) !void {
|
||||||
const file = asynk.File.open(path, .{}) catch |err| {
|
const file = asynk.File.open(path, .{}) catch |err| {
|
||||||
log.err("Failed to open {s}: {}", .{ path, err });
|
log.err("Failed to open {s}: {}", .{ path, err });
|
||||||
return err;
|
return err;
|
||||||
};
|
};
|
||||||
errdefer file.close() catch unreachable;
|
errdefer file.close() catch unreachable;
|
||||||
var r = file.reader();
|
var buffer: [16 * 1024]u8 = undefined;
|
||||||
|
var r = file.reader(&buffer);
|
||||||
|
|
||||||
const json_header_length: usize = @intCast(try r.readInt(u64, std.builtin.Endian.little));
|
const json_header_length: usize = @intCast(try r.interface.takeInt(u64, .little));
|
||||||
const json_data = try allocator.alloc(u8, json_header_length);
|
const json_data = try allocator.alloc(u8, json_header_length);
|
||||||
const n = try r.readAll(json_data);
|
try r.interface.readSliceAll(json_data);
|
||||||
if (n != json_header_length) {
|
// if (n != json_header_length) {
|
||||||
log.err("Failed to read the full {} bytes of json header from file {s}", .{ n, path });
|
// log.err("Failed to read the full {} bytes of json header from file {s}", .{ n, path });
|
||||||
return error.CorruptedFile;
|
// return error.CorruptedFile;
|
||||||
}
|
// }
|
||||||
|
|
||||||
const metadata = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data[0..n], .{});
|
const metadata = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data, .{});
|
||||||
var buffer_file = try MemoryMappedFile.init(file);
|
var buffer_file = try MemoryMappedFile.init(file);
|
||||||
errdefer buffer_file.deinit();
|
errdefer buffer_file.deinit();
|
||||||
buffer_file.data_offset = 8 + json_header_length;
|
buffer_file.data_offset = 8 + json_header_length;
|
||||||
@ -105,7 +111,7 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.Array
|
|||||||
const start: usize = @intCast(offset_field.array.items[0].integer);
|
const start: usize = @intCast(offset_field.array.items[0].integer);
|
||||||
const end: usize = @intCast(offset_field.array.items[1].integer);
|
const end: usize = @intCast(offset_field.array.items[1].integer);
|
||||||
const dtype = try stringToDtype(val.object.get("dtype").?.string);
|
const dtype = try stringToDtype(val.object.get("dtype").?.string);
|
||||||
var dims: std.BoundedArray(i64, zml.Shape.MAX_RANK) = .{};
|
var dims: stdx.BoundedArray(i64, zml.Shape.MAX_RANK) = .{};
|
||||||
for (shape_field.items) |d| {
|
for (shape_field.items) |d| {
|
||||||
dims.appendAssumeCapacity(d.integer);
|
dims.appendAssumeCapacity(d.integer);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -45,7 +45,7 @@ pub const Buffer = struct {
|
|||||||
_shards: Shards,
|
_shards: Shards,
|
||||||
|
|
||||||
pub const MAX_NUM_SHARDS: u8 = Platform.MAX_NUM_DEVICES;
|
pub const MAX_NUM_SHARDS: u8 = Platform.MAX_NUM_DEVICES;
|
||||||
pub const Shards = std.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS);
|
pub const Shards = stdx.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS);
|
||||||
|
|
||||||
pub const FromOptions = struct {
|
pub const FromOptions = struct {
|
||||||
wait: bool = true,
|
wait: bool = true,
|
||||||
@ -67,7 +67,7 @@ pub const Buffer = struct {
|
|||||||
const n_partitions = platform.sharding().num_partitions;
|
const n_partitions = platform.sharding().num_partitions;
|
||||||
const chunk_size = if (sharding_ax) |ax| cs: {
|
const chunk_size = if (sharding_ax) |ax| cs: {
|
||||||
// This kind of sharding error should be detected earlier on.
|
// This kind of sharding error should be detected earlier on.
|
||||||
stdx.debug.assert(@rem(host_buffer.dim(ax), n_partitions) == 0, "Buffer.from({}) expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ host_buffer, ax, n_partitions });
|
stdx.debug.assert(@rem(host_buffer.dim(ax), n_partitions) == 0, "Buffer.from({f}) expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ host_buffer, ax, n_partitions });
|
||||||
break :cs @divExact(host_buffer.dim(ax), n_partitions);
|
break :cs @divExact(host_buffer.dim(ax), n_partitions);
|
||||||
} else 0;
|
} else 0;
|
||||||
|
|
||||||
@ -201,7 +201,7 @@ pub const Buffer = struct {
|
|||||||
const duration_ms = stdx.math.divFloat(f32, start.read(), std.time.ns_per_ms);
|
const duration_ms = stdx.math.divFloat(f32, start.read(), std.time.ns_per_ms);
|
||||||
if (duration_ms > 100) {
|
if (duration_ms > 100) {
|
||||||
const size_gb = stdx.math.divFloat(f32, shape_.byteSize(), 1024 * 1024 * 1024);
|
const size_gb = stdx.math.divFloat(f32, shape_.byteSize(), 1024 * 1024 * 1024);
|
||||||
log.info("Wrote constant({_}) to device ({d:.2}Gb) in {d:.0}ms: {d:.2}Gb/s", .{ shape_, size_gb, duration_ms, size_gb / duration_ms * 1000 });
|
log.debug("Wrote constant({f}) to device ({d:.2}Gb) in {d:.0}ms: {d:.2}Gb/s", .{ shape_, size_gb, duration_ms, size_gb / duration_ms * 1000 });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -301,7 +301,7 @@ pub const Buffer = struct {
|
|||||||
|
|
||||||
/// Fetches the content of the given buffer into a stack variable of the given type.
|
/// Fetches the content of the given buffer into a stack variable of the given type.
|
||||||
pub fn getValue(self: Buffer, T: type) !T {
|
pub fn getValue(self: Buffer, T: type) !T {
|
||||||
stdx.debug.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) });
|
stdx.debug.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {f} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) });
|
||||||
var res: T = undefined;
|
var res: T = undefined;
|
||||||
stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
|
stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
|
||||||
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res));
|
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res));
|
||||||
@ -375,13 +375,9 @@ pub const Buffer = struct {
|
|||||||
|
|
||||||
pub fn format(
|
pub fn format(
|
||||||
self: Buffer,
|
self: Buffer,
|
||||||
comptime fmt: []const u8,
|
|
||||||
options: std.fmt.FormatOptions,
|
|
||||||
writer: anytype,
|
writer: anytype,
|
||||||
) !void {
|
) !void {
|
||||||
_ = fmt;
|
try writer.print("Buffer({f})", .{self._shape});
|
||||||
_ = options;
|
|
||||||
try writer.print("Buffer({_})", .{self._shape});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getMemory(self: Buffer) *const pjrt.Memory {
|
pub fn getMemory(self: Buffer) *const pjrt.Memory {
|
||||||
|
|||||||
@ -224,7 +224,7 @@ const CustomCall = struct {
|
|||||||
try ffi.register(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback, .{});
|
try ffi.register(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback, .{});
|
||||||
}
|
}
|
||||||
|
|
||||||
fn hostBufferCallback(call_frame: *pjrt.ffi.CallFrame) callconv(.C) ?*pjrt.ffi.Error {
|
fn hostBufferCallback(call_frame: *pjrt.ffi.CallFrame) callconv(.c) ?*pjrt.ffi.Error {
|
||||||
if (call_frame.registeringHook()) return null;
|
if (call_frame.registeringHook()) return null;
|
||||||
|
|
||||||
const callback_attr = call_frame.attrs.getByName(.scalar, "callback") orelse unreachable;
|
const callback_attr = call_frame.attrs.getByName(.scalar, "callback") orelse unreachable;
|
||||||
@ -275,3 +275,69 @@ fn hostBufferFromPinnedBuffer(buffer_desc: *const pjrt.ffi.Buffer) HostBuffer {
|
|||||||
buffer_desc.data[0..buffer_shape.byteSize()],
|
buffer_desc.data[0..buffer_shape.byteSize()],
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub const cuda = struct {
|
||||||
|
pub var streamSynchronize: StreamSynchronize = @ptrFromInt(0xdeadc00da00);
|
||||||
|
pub var cuLaunchHostFunc: CuLaunchHostFunc = @ptrFromInt(0xdeadc00da00);
|
||||||
|
var _memcpyAsync: MemcpyAsync = @ptrFromInt(0xdeadc00da00);
|
||||||
|
var _memcpyBlocking: MemcpyBlocking = @ptrFromInt(0xdeadc00da00);
|
||||||
|
|
||||||
|
pub const MemcpyKind = enum(c_int) {
|
||||||
|
host_to_host = 0,
|
||||||
|
host_to_device = 1,
|
||||||
|
device_to_host = 2,
|
||||||
|
device_to_device = 3,
|
||||||
|
inferred = 4,
|
||||||
|
};
|
||||||
|
|
||||||
|
const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: ?*anyopaque) callconv(.c) c_int;
|
||||||
|
const MemcpyBlocking = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind) callconv(.c) c_int;
|
||||||
|
const StreamSynchronize = *const fn (stream: *anyopaque) callconv(.c) c_int;
|
||||||
|
const CuLaunchHostFunc = *const fn (stream: *anyopaque, host_func: *const fn (user_data: *const anyopaque) callconv(.c) void, user_data: *const anyopaque) callconv(.c) c_int;
|
||||||
|
|
||||||
|
pub fn init() void {
|
||||||
|
var cudart = std.DynLib.open("libcudart.so.12") catch {
|
||||||
|
log.err("cudart not found, callback will segfault", .{});
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
defer cudart.close();
|
||||||
|
|
||||||
|
_memcpyAsync = cudart.lookup(MemcpyAsync, "cudaMemcpyAsync") orelse {
|
||||||
|
@panic("cudaMemcpyAsync not found");
|
||||||
|
};
|
||||||
|
_memcpyBlocking = cudart.lookup(MemcpyBlocking, "cudaMemcpy") orelse {
|
||||||
|
@panic("cudaMemcpy not found");
|
||||||
|
};
|
||||||
|
streamSynchronize = cudart.lookup(StreamSynchronize, "cudaStreamSynchronize") orelse {
|
||||||
|
@panic("cudaStreamSynchronize not found");
|
||||||
|
};
|
||||||
|
cuLaunchHostFunc = cudart.lookup(CuLaunchHostFunc, "cudaLaunchHostFunc") orelse {
|
||||||
|
@panic("cudaLaunchHostFunc not found");
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn memcpyToHostBlocking(dst: []u8, src: *const anyopaque) void {
|
||||||
|
const err = _memcpyBlocking(dst.ptr, src, dst.len, .device_to_host);
|
||||||
|
check(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn memcpyToDeviceBlocking(dst: *anyopaque, src: []const u8) void {
|
||||||
|
const err = _memcpyBlocking(dst, src.ptr, src.len, .host_to_device);
|
||||||
|
check(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn memcpyToDeviceAsync(dst: *anyopaque, src: []const u8, stream: ?*anyopaque) void {
|
||||||
|
const err = _memcpyAsync(dst, src.ptr, src.len, .host_to_device, stream);
|
||||||
|
check(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn memcpyToHostAsync(dst: []u8, src: *const anyopaque, stream: ?*anyopaque) void {
|
||||||
|
const err = _memcpyAsync(dst.ptr, src, dst.len, .device_to_host, stream);
|
||||||
|
check(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check(err: c_int) void {
|
||||||
|
if (err == 0) return;
|
||||||
|
stdx.debug.panic("CUDA error: {d}", .{err});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|||||||
@ -394,8 +394,8 @@ fn fillBuffers(v: anytype, shapes: []const Shape, buffers: []const [*]*pjrt.Buff
|
|||||||
fn cb(ctx: *LocalContext, buffer: *const Buffer) void {
|
fn cb(ctx: *LocalContext, buffer: *const Buffer) void {
|
||||||
// stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index });
|
// stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index });
|
||||||
const model_sharding = ctx.buffers.len;
|
const model_sharding = ctx.buffers.len;
|
||||||
stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len });
|
stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {d}-sharded tensor into a {d}-sharded model", .{ buffer._shards.len, ctx.buffers.len });
|
||||||
stdx.debug.assert(ctx.shapes[ctx.index].eql(buffer.shape()), "Executable expected argument {} to have shape {}, got {}", .{ ctx.index, ctx.shapes[ctx.index], buffer.shape() });
|
stdx.debug.assert(ctx.shapes[ctx.index].eql(buffer.shape()), "Executable expected argument {} to have shape {f}, got {f}", .{ ctx.index, ctx.shapes[ctx.index], buffer.shape() });
|
||||||
for (buffer._shards.constSlice(), 0..) |shard, d| {
|
for (buffer._shards.constSlice(), 0..) |shard, d| {
|
||||||
ctx.buffers[d][ctx.index] = shard;
|
ctx.buffers[d][ctx.index] = shard;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -87,17 +87,10 @@ fn FloatHelpers(Float: type) type {
|
|||||||
return std.math.maxInt(std.meta.Int(.unsigned, exponent_bits - 1));
|
return std.math.maxInt(std.meta.Int(.unsigned, exponent_bits - 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(
|
pub fn formatNumber(x: Float, writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
|
||||||
float: Float,
|
switch (n.mode) {
|
||||||
comptime fmt: []const u8,
|
.binary, .octal, .hex => try writer.print("{{ .sign={}, .exp={}, .mantissa={} }}", .{ x.sign, x.exponent, x.mantissa }),
|
||||||
options: std.fmt.FormatOptions,
|
else => try writer.printFloat(x.toF32(), n),
|
||||||
writer: anytype,
|
|
||||||
) !void {
|
|
||||||
_ = options;
|
|
||||||
if (fmt.len == 1 and fmt[0] == '_') {
|
|
||||||
try writer.print("{{ .sign={}, .exp={}, .mantissa={} }}", .{ float.sign, float.exponent, float.mantissa });
|
|
||||||
} else {
|
|
||||||
try writer.print("{" ++ fmt ++ "}", .{float.toF32()});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -113,7 +106,7 @@ pub const Float32 = packed struct(u32) {
|
|||||||
pub const neg = Helpers.neg;
|
pub const neg = Helpers.neg;
|
||||||
pub const fromF32 = Helpers.fromF32;
|
pub const fromF32 = Helpers.fromF32;
|
||||||
pub const toF32 = Helpers.toF32;
|
pub const toF32 = Helpers.toF32;
|
||||||
pub const format = Helpers.format;
|
pub const formatNumber = Helpers.formatNumber;
|
||||||
};
|
};
|
||||||
|
|
||||||
const f32_exp_bias = FloatHelpers(Float32).expBias();
|
const f32_exp_bias = FloatHelpers(Float32).expBias();
|
||||||
@ -128,7 +121,7 @@ pub const Float64 = packed struct(u64) {
|
|||||||
pub const neg = Helpers.neg;
|
pub const neg = Helpers.neg;
|
||||||
pub const fromF32 = Helpers.fromF32;
|
pub const fromF32 = Helpers.fromF32;
|
||||||
pub const toF32 = Helpers.toF32;
|
pub const toF32 = Helpers.toF32;
|
||||||
pub const format = Helpers.format;
|
pub const formatNumber = Helpers.formatNumber;
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Float8E4M3B11FNUZ = packed struct(u8) {
|
pub const Float8E4M3B11FNUZ = packed struct(u8) {
|
||||||
@ -151,7 +144,7 @@ pub const Float8E4M3B11FNUZ = packed struct(u8) {
|
|||||||
pub const neg = Helpers.neg;
|
pub const neg = Helpers.neg;
|
||||||
pub const fromF32 = Helpers.fromF32;
|
pub const fromF32 = Helpers.fromF32;
|
||||||
pub const toF32 = Helpers.toF32;
|
pub const toF32 = Helpers.toF32;
|
||||||
pub const format = Helpers.format;
|
pub const formatNumber = Helpers.formatNumber;
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Float8E4M3FN = packed struct(u8) {
|
pub const Float8E4M3FN = packed struct(u8) {
|
||||||
@ -169,7 +162,7 @@ pub const Float8E4M3FN = packed struct(u8) {
|
|||||||
pub const neg = Helpers.neg;
|
pub const neg = Helpers.neg;
|
||||||
pub const fromF32 = Helpers.fromF32;
|
pub const fromF32 = Helpers.fromF32;
|
||||||
pub const toF32 = Helpers.toF32;
|
pub const toF32 = Helpers.toF32;
|
||||||
pub const format = Helpers.format;
|
pub const formatNumber = Helpers.formatNumber;
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Float8E4M3FNUZ = packed struct(u8) {
|
pub const Float8E4M3FNUZ = packed struct(u8) {
|
||||||
@ -192,7 +185,7 @@ pub const Float8E4M3FNUZ = packed struct(u8) {
|
|||||||
pub const neg = Helpers.neg;
|
pub const neg = Helpers.neg;
|
||||||
pub const fromF32 = Helpers.fromF32;
|
pub const fromF32 = Helpers.fromF32;
|
||||||
pub const toF32 = Helpers.toF32;
|
pub const toF32 = Helpers.toF32;
|
||||||
pub const format = Helpers.format;
|
pub const formatNumber = Helpers.formatNumber;
|
||||||
};
|
};
|
||||||
|
|
||||||
test "Float8E4" {
|
test "Float8E4" {
|
||||||
@ -247,7 +240,7 @@ pub const Float8E5M2 = packed struct(u8) {
|
|||||||
pub const neg = Helpers.neg;
|
pub const neg = Helpers.neg;
|
||||||
pub const fromF32 = Helpers.fromF32;
|
pub const fromF32 = Helpers.fromF32;
|
||||||
pub const toF32 = Helpers.toF32;
|
pub const toF32 = Helpers.toF32;
|
||||||
pub const format = Helpers.format;
|
pub const formatNumber = Helpers.formatNumber;
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Float8E5M2FNUZ = packed struct(u8) {
|
pub const Float8E5M2FNUZ = packed struct(u8) {
|
||||||
@ -266,7 +259,7 @@ pub const Float8E5M2FNUZ = packed struct(u8) {
|
|||||||
pub const neg = Helpers.neg;
|
pub const neg = Helpers.neg;
|
||||||
pub const fromF32 = Helpers.fromF32;
|
pub const fromF32 = Helpers.fromF32;
|
||||||
pub const toF32 = Helpers.toF32;
|
pub const toF32 = Helpers.toF32;
|
||||||
pub const format = Helpers.format;
|
pub const formatNumber = Helpers.formatNumber;
|
||||||
};
|
};
|
||||||
|
|
||||||
test "Float8E5" {
|
test "Float8E5" {
|
||||||
@ -322,7 +315,7 @@ pub const BFloat16 = packed struct(u16) {
|
|||||||
const Helpers = FloatHelpers(@This());
|
const Helpers = FloatHelpers(@This());
|
||||||
pub const zero = Helpers.zero;
|
pub const zero = Helpers.zero;
|
||||||
pub const neg = Helpers.neg;
|
pub const neg = Helpers.neg;
|
||||||
pub const format = Helpers.format;
|
pub const formatNumber = Helpers.formatNumber;
|
||||||
};
|
};
|
||||||
|
|
||||||
test BFloat16 {
|
test BFloat16 {
|
||||||
|
|||||||
@ -31,7 +31,7 @@ pub const HostBuffer = struct {
|
|||||||
return .{
|
return .{
|
||||||
._shape = sh,
|
._shape = sh,
|
||||||
._strides = sh.computeStrides().buffer,
|
._strides = sh.computeStrides().buffer,
|
||||||
._data = (try allocator.alignedAlloc(u8, 64, sh.byteSize())).ptr,
|
._data = (try allocator.alignedAlloc(u8, .@"64", sh.byteSize())).ptr,
|
||||||
._memory = .{ .managed = .@"64" },
|
._memory = .{ .managed = .@"64" },
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -170,8 +170,8 @@ pub const HostBuffer = struct {
|
|||||||
/// Strided buffers can't use this method.
|
/// Strided buffers can't use this method.
|
||||||
pub fn items(self: HostBuffer, comptime T: type) []const T {
|
pub fn items(self: HostBuffer, comptime T: type) []const T {
|
||||||
// TODO we should allow interpreting the output as @Vector(8, f32) when the tensor is f32.
|
// TODO we should allow interpreting the output as @Vector(8, f32) when the tensor is f32.
|
||||||
stdx.debug.assert(DataType.fromZigType(T) == self.dtype(), "Can't reinterpret {} as {s}", .{ self, @typeName(T) });
|
stdx.debug.assert(DataType.fromZigType(T) == self.dtype(), "Can't reinterpret {f} as {s}", .{ self, @typeName(T) });
|
||||||
stdx.debug.assert(self.isContiguous(), "{} isn't contiguous, can't interpret as []const u8", .{self});
|
stdx.debug.assert(self.isContiguous(), "{f} isn't contiguous, can't interpret as []const u8", .{self});
|
||||||
const ptr: [*]const T = @alignCast(@ptrCast(self._data));
|
const ptr: [*]const T = @alignCast(@ptrCast(self._data));
|
||||||
return ptr[0..self._shape.count()];
|
return ptr[0..self._shape.count()];
|
||||||
}
|
}
|
||||||
@ -181,7 +181,7 @@ pub const HostBuffer = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn bytes(self: HostBuffer) []const u8 {
|
pub fn bytes(self: HostBuffer) []const u8 {
|
||||||
stdx.debug.assert(self.isContiguous(), "{} isn't contiguous, can't interpret as []const u8", .{self});
|
stdx.debug.assert(self.isContiguous(), "{f} isn't contiguous, can't interpret as []const u8", .{self});
|
||||||
return self._data[0..self._shape.byteSize()];
|
return self._data[0..self._shape.byteSize()];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -233,7 +233,7 @@ pub const HostBuffer = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer {
|
pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer {
|
||||||
stdx.debug.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {}", .{self});
|
stdx.debug.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {f}", .{self});
|
||||||
var res = self;
|
var res = self;
|
||||||
res._shape = self._shape.reshape(shape_);
|
res._shape = self._shape.reshape(shape_);
|
||||||
res._strides = res._shape.computeStrides().buffer;
|
res._strides = res._shape.computeStrides().buffer;
|
||||||
@ -252,9 +252,9 @@ pub const HostBuffer = struct {
|
|||||||
const start: i64 = if (s.start < 0) s.start + d else s.start;
|
const start: i64 = if (s.start < 0) s.start + d else s.start;
|
||||||
var end = s.end orelse d;
|
var end = s.end orelse d;
|
||||||
if (end < 0) end += d;
|
if (end < 0) end += d;
|
||||||
stdx.debug.assert(start >= 0 and start < d, "slice1d({}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, s });
|
stdx.debug.assert(start >= 0 and start < d, "slice1d({f}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, s });
|
||||||
stdx.debug.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, s });
|
stdx.debug.assert(end >= 1 and end <= d, "slice1d({f}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, s });
|
||||||
stdx.debug.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({}), got: {}", .{ self, ax, start, end, s });
|
stdx.debug.assert(start < end, "slice1d({f}, {}) expects the slice start ({}) to be smaller than the end ({}), got: {}", .{ self, ax, start, end, s });
|
||||||
|
|
||||||
const offset: usize = @intCast(start * self._strides[ax]);
|
const offset: usize = @intCast(start * self._strides[ax]);
|
||||||
const new_shape = self.shape().set(ax, end - start);
|
const new_shape = self.shape().set(ax, end - start);
|
||||||
@ -308,9 +308,9 @@ pub const HostBuffer = struct {
|
|||||||
|
|
||||||
pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer {
|
pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer {
|
||||||
const ax = self._shape.axis(axis_);
|
const ax = self._shape.axis(axis_);
|
||||||
stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {}", .{ ax, self });
|
stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {f}", .{ ax, self });
|
||||||
|
|
||||||
var strd: std.BoundedArray(i64, Shape.MAX_RANK) = .{ .buffer = self._strides, .len = self.rank() };
|
var strd: stdx.BoundedArray(i64, Shape.MAX_RANK) = .{ .buffer = self._strides, .len = self.rank() };
|
||||||
_ = strd.orderedRemove(ax);
|
_ = strd.orderedRemove(ax);
|
||||||
|
|
||||||
return .{
|
return .{
|
||||||
@ -323,16 +323,11 @@ pub const HostBuffer = struct {
|
|||||||
|
|
||||||
pub fn format(
|
pub fn format(
|
||||||
self: HostBuffer,
|
self: HostBuffer,
|
||||||
comptime fmt: []const u8,
|
|
||||||
options: std.fmt.FormatOptions,
|
|
||||||
writer: anytype,
|
writer: anytype,
|
||||||
) !void {
|
) !void {
|
||||||
_ = options;
|
// TODO debug option
|
||||||
if (std.mem.eql(u8, fmt, "v")) {
|
// try writer.print("HostBuffer(.{f})@0x{x}", .{ self._shape, @intFromPtr(self._data) });
|
||||||
try writer.print("HostBuffer(.{_})@0x{x}", .{ self._shape, @intFromPtr(self._data) });
|
try writer.print("HostBuffer(.{f})", .{self._shape});
|
||||||
} else {
|
|
||||||
try writer.print("HostBuffer(.{_})", .{self._shape});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Formatter for a HostBuffer that also print the values not just the shape.
|
/// Formatter for a HostBuffer that also print the values not just the shape.
|
||||||
@ -344,21 +339,23 @@ pub const HostBuffer = struct {
|
|||||||
pub const PrettyPrinter = struct {
|
pub const PrettyPrinter = struct {
|
||||||
x: HostBuffer,
|
x: HostBuffer,
|
||||||
|
|
||||||
pub fn format(self: PrettyPrinter, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
|
// TODO(0.15.0) revisit pretty printer
|
||||||
|
pub fn format(self: PrettyPrinter, writer: anytype) !void {
|
||||||
const fmt_: stdx.fmt.Fmt = switch (self.x.dtype().class()) {
|
const fmt_: stdx.fmt.Fmt = switch (self.x.dtype().class()) {
|
||||||
.integer => .parse(i32, fmt),
|
.integer => .parse(i32, "d"),
|
||||||
.float => .parse(f32, fmt),
|
.float => .parse(f32, "d"),
|
||||||
else => .parse(void, fmt),
|
else => .parse(void, ""),
|
||||||
};
|
};
|
||||||
|
const options: std.fmt.FormatOptions = .{};
|
||||||
try prettyPrint(self.x, writer, .{ .fmt = fmt_, .options = options });
|
try prettyPrint(self.x, writer, .{ .fmt = fmt_, .options = options });
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn prettyPrint(self: HostBuffer, writer: anytype, options: stdx.fmt.FullFormatOptions) !void {
|
pub fn prettyPrint(self: HostBuffer, writer: *std.Io.Writer, options: stdx.fmt.FullFormatOptions) !void {
|
||||||
return self.prettyPrintIndented(writer, 4, 0, options);
|
return self.prettyPrintIndented(writer, 4, 0, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prettyPrintIndented(self: HostBuffer, writer: anytype, num_rows: u8, indent_level: u8, options: stdx.fmt.FullFormatOptions) !void {
|
fn prettyPrintIndented(self: HostBuffer, writer: *std.Io.Writer, num_rows: u8, indent_level: u8, options: stdx.fmt.FullFormatOptions) !void {
|
||||||
if (self.rank() == 0) {
|
if (self.rank() == 0) {
|
||||||
// Special case input tensor is a scalar
|
// Special case input tensor is a scalar
|
||||||
return switch (self.dtype()) {
|
return switch (self.dtype()) {
|
||||||
@ -376,7 +373,7 @@ pub const HostBuffer = struct {
|
|||||||
if (self.rank() == 1) {
|
if (self.rank() == 1) {
|
||||||
// Print a contiguous slice of items from the buffer in one line.
|
// Print a contiguous slice of items from the buffer in one line.
|
||||||
// The number of items printed is controlled by the user through format syntax.
|
// The number of items printed is controlled by the user through format syntax.
|
||||||
try writer.writeByteNTimes(' ', indent_level);
|
try writer.splatByteAll(' ', indent_level);
|
||||||
switch (self.dtype()) {
|
switch (self.dtype()) {
|
||||||
inline else => |dt| {
|
inline else => |dt| {
|
||||||
const values = self.items(dt.toZigType());
|
const values = self.items(dt.toZigType());
|
||||||
@ -391,10 +388,10 @@ pub const HostBuffer = struct {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// TODO: consider removing the \n if dim is 1 for this axis.
|
// TODO: consider removing the \n if dim is 1 for this axis.
|
||||||
try writer.writeByteNTimes(' ', indent_level);
|
try writer.splatByteAll(' ', indent_level);
|
||||||
_ = try writer.write("{\n");
|
_ = try writer.write("{\n");
|
||||||
defer {
|
defer {
|
||||||
writer.writeByteNTimes(' ', indent_level) catch {};
|
writer.splatByteAll(' ', indent_level) catch {};
|
||||||
_ = writer.write("},\n") catch {};
|
_ = writer.write("},\n") catch {};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -409,7 +406,7 @@ pub const HostBuffer = struct {
|
|||||||
if (n < num_rows) return;
|
if (n < num_rows) return;
|
||||||
// Skip middle rows
|
// Skip middle rows
|
||||||
if (n > 2 * num_rows) {
|
if (n > 2 * num_rows) {
|
||||||
try writer.writeByteNTimes(' ', indent_level + 2);
|
try writer.splatByteAll(' ', indent_level + 2);
|
||||||
_ = try writer.write("...\n");
|
_ = try writer.write("...\n");
|
||||||
}
|
}
|
||||||
// Write last rows
|
// Write last rows
|
||||||
|
|||||||
15
zml/meta.zig
15
zml/meta.zig
@ -358,10 +358,11 @@ pub fn MapRestrict(From: type, To: type) type {
|
|||||||
const fields = union_info.fields;
|
const fields = union_info.fields;
|
||||||
var union_fields: [fields.len]std.builtin.Type.UnionField = undefined;
|
var union_fields: [fields.len]std.builtin.Type.UnionField = undefined;
|
||||||
for (0.., fields) |i, field| {
|
for (0.., fields) |i, field| {
|
||||||
|
const FT = map(field.type);
|
||||||
union_fields[i] = .{
|
union_fields[i] = .{
|
||||||
.name = field.name,
|
.name = field.name,
|
||||||
.type = map(field.type),
|
.type = FT,
|
||||||
.alignment = 0,
|
.alignment = @alignOf(FT),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
return @Type(.{ .@"union" = .{
|
return @Type(.{ .@"union" = .{
|
||||||
@ -453,7 +454,7 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
|||||||
else => {},
|
else => {},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle std.BoundedArray that contains uninitalized data.
|
// Handle stdx.BoundedArray that contains uninitalized data.
|
||||||
if (@typeInfo(Child) == .@"struct" and @hasDecl(Child, "constSlice") and @hasDecl(Child, "slice")) {
|
if (@typeInfo(Child) == .@"struct" and @hasDecl(Child, "constSlice") and @hasDecl(Child, "slice")) {
|
||||||
return visit(cb, ctx, if (mutating_cb) v.slice() else v.constSlice());
|
return visit(cb, ctx, if (mutating_cb) v.slice() else v.constSlice());
|
||||||
}
|
}
|
||||||
@ -511,7 +512,7 @@ test visit {
|
|||||||
const NestedAttrOptional = struct { nested: ?Attr };
|
const NestedAttrOptional = struct { nested: ?Attr };
|
||||||
const SimpleStruct = struct { prop: Attr };
|
const SimpleStruct = struct { prop: Attr };
|
||||||
const MultipleTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: ?Attr };
|
const MultipleTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: ?Attr };
|
||||||
const NestedTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: NestedAttr, prop4: NestedAttrOptional, prop5: std.BoundedArray(Attr, 8) };
|
const NestedTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: NestedAttr, prop4: NestedAttrOptional, prop5: stdx.BoundedArray(Attr, 8) };
|
||||||
|
|
||||||
const LocalContext = struct { result: usize };
|
const LocalContext = struct { result: usize };
|
||||||
|
|
||||||
@ -565,7 +566,7 @@ test visit {
|
|||||||
}
|
}
|
||||||
{
|
{
|
||||||
var context: LocalContext = .{ .result = 0 };
|
var context: LocalContext = .{ .result = 0 };
|
||||||
const prop5: std.BoundedArray(Attr, 8) = .{
|
const prop5: stdx.BoundedArray(Attr, 8) = .{
|
||||||
.buffer = @splat(.{ .data = 4 }),
|
.buffer = @splat(.{ .data = 4 }),
|
||||||
.len = 2,
|
.len = 2,
|
||||||
};
|
};
|
||||||
@ -677,11 +678,11 @@ test zip {
|
|||||||
|
|
||||||
/// Given a func(X) -> Y or a func(Ctx, X) -> Y,
|
/// Given a func(X) -> Y or a func(Ctx, X) -> Y,
|
||||||
/// finds all X in the given object, and write the result of func(X) into an arraylist.
|
/// finds all X in the given object, and write the result of func(X) into an arraylist.
|
||||||
pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT), obj: anytype) error{OutOfMemory}!void {
|
pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.array_list.Managed(stdx.meta.FnSignature(func, null).ReturnT), obj: anytype) error{OutOfMemory}!void {
|
||||||
stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).@"fn".params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)});
|
stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).@"fn".params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)});
|
||||||
const LocalContext = struct {
|
const LocalContext = struct {
|
||||||
func_ctx: _CollectCtx(func),
|
func_ctx: _CollectCtx(func),
|
||||||
out: *std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT),
|
out: *std.array_list.Managed(stdx.meta.FnSignature(func, null).ReturnT),
|
||||||
oom: bool = false,
|
oom: bool = false,
|
||||||
};
|
};
|
||||||
var context = LocalContext{ .func_ctx = func_ctx, .out = out };
|
var context = LocalContext{ .func_ctx = func_ctx, .out = out };
|
||||||
|
|||||||
@ -51,7 +51,7 @@ pub const CompilationContext = struct {
|
|||||||
|
|
||||||
_module: mlir.Module,
|
_module: mlir.Module,
|
||||||
|
|
||||||
_blocks: std.BoundedArray(TaggedBlock, 64) = .{},
|
_blocks: stdx.BoundedArray(TaggedBlock, 64) = .{},
|
||||||
_fn_cache: FnCache = .{},
|
_fn_cache: FnCache = .{},
|
||||||
|
|
||||||
_block_args: TensorToBlockArg = .{},
|
_block_args: TensorToBlockArg = .{},
|
||||||
@ -63,7 +63,7 @@ pub const CompilationContext = struct {
|
|||||||
|
|
||||||
const TaggedBlock = struct { mlir.Block, mlir.Block.RecursiveOpts };
|
const TaggedBlock = struct { mlir.Block, mlir.Block.RecursiveOpts };
|
||||||
const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation });
|
const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation });
|
||||||
const AttributeList = std.BoundedArray(mlir.NamedAttribute, 3);
|
const AttributeList = stdx.BoundedArray(mlir.NamedAttribute, 3);
|
||||||
|
|
||||||
pub fn init(allocator_: std.mem.Allocator, full_name: []const u8, platform: Platform) !CompilationContext {
|
pub fn init(allocator_: std.mem.Allocator, full_name: []const u8, platform: Platform) !CompilationContext {
|
||||||
const mlir_registry = mlir.Registry.init() catch unreachable;
|
const mlir_registry = mlir.Registry.init() catch unreachable;
|
||||||
@ -185,7 +185,9 @@ pub const CompilationContext = struct {
|
|||||||
// Write the mlir to a file. All errors are discarded, since this is for debugging only.
|
// Write the mlir to a file. All errors are discarded, since this is for debugging only.
|
||||||
const mlir_name = "module.mlir";
|
const mlir_name = "module.mlir";
|
||||||
if (cache_dir.createFile(mlir_name, .{ .truncate = true })) |file| {
|
if (cache_dir.createFile(mlir_name, .{ .truncate = true })) |file| {
|
||||||
module.op().print(file.writer(), .{ .debug_info = true, .debug_info_pretty_form = false });
|
var write_buf: [4096]u8 = undefined;
|
||||||
|
var writer = file.writer(&write_buf);
|
||||||
|
module.op().print(&writer.interface, .{ .debug_info = true, .debug_info_pretty_form = false });
|
||||||
log.info("Wrote MLIR to {s}/{s}", .{ module_dir.?, mlir_name });
|
log.info("Wrote MLIR to {s}/{s}", .{ module_dir.?, mlir_name });
|
||||||
} else |_| {
|
} else |_| {
|
||||||
log.warn("Failed to open {s}", .{mlir_name});
|
log.warn("Failed to open {s}", .{mlir_name});
|
||||||
@ -219,7 +221,7 @@ pub const CompilationContext = struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
log.debug("******** ZML generated MLIR ********", .{});
|
log.debug("******** ZML generated MLIR ********", .{});
|
||||||
log.debug("{}", .{module.op().mlirFormatter(.{})});
|
log.debug("{f}", .{module.op().mlirFormatter(.{})});
|
||||||
|
|
||||||
if (timer) |*t| {
|
if (timer) |*t| {
|
||||||
const time_ms = @divFloor(t.lap(), std.time.ns_per_ms);
|
const time_ms = @divFloor(t.lap(), std.time.ns_per_ms);
|
||||||
@ -339,7 +341,7 @@ pub const CompilationContext = struct {
|
|||||||
const locations = try arena.alloc(mlir.Location, tensor_count);
|
const locations = try arena.alloc(mlir.Location, tensor_count);
|
||||||
@memset(locations, mlir.Location.unknown(mlir_ctx));
|
@memset(locations, mlir.Location.unknown(mlir_ctx));
|
||||||
|
|
||||||
var input_shapes = try std.ArrayList(Shape).initCapacity(res_allocator, tensor_count);
|
var input_shapes: std.array_list.Managed(Shape) = try .initCapacity(res_allocator, tensor_count);
|
||||||
meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable;
|
meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable;
|
||||||
stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});
|
stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});
|
||||||
|
|
||||||
@ -416,7 +418,7 @@ pub const CompilationContext = struct {
|
|||||||
defer self._tracer.frameEnd(canonicalize_frame, "emitMlir.canonicalize");
|
defer self._tracer.frameEnd(canonicalize_frame, "emitMlir.canonicalize");
|
||||||
self._mlir_canonicalizer.runOnOp(mlir_fn) catch |err| switch (err) {
|
self._mlir_canonicalizer.runOnOp(mlir_fn) catch |err| switch (err) {
|
||||||
error.InvalidMlir => {
|
error.InvalidMlir => {
|
||||||
log.err("Failed to canonicalize invalid mlir: {}", .{mlir_fn.mlirFormatter(.{})});
|
log.err("Failed to canonicalize invalid mlir: {f}", .{mlir_fn.mlirFormatter(.{})});
|
||||||
// user errors should have triggered a panic before we reach this.
|
// user errors should have triggered a panic before we reach this.
|
||||||
@panic("ZML generated invalid mlir. Please open a bug report");
|
@panic("ZML generated invalid mlir. Please open a bug report");
|
||||||
},
|
},
|
||||||
@ -464,7 +466,7 @@ pub const CompilationContext = struct {
|
|||||||
// This will break the day we writer another attribute before donation.
|
// This will break the day we writer another attribute before donation.
|
||||||
// When the time come, do a more fancy lookup here to check if an argument
|
// When the time come, do a more fancy lookup here to check if an argument
|
||||||
// is donated twice.
|
// is donated twice.
|
||||||
stdx.debug.assert(attributes[a].len == 0, "Donation error ! Argument {} has been donated twice ! To {} and to {}", .{ a, index, attributes[a].buffer[0] });
|
stdx.debug.assert(attributes[a].len == 0, "Donation error ! Argument {d} has been donated twice ! To {d} and to {any}", .{ a, index, attributes[a].buffer[0] });
|
||||||
attributes[a].appendAssumeCapacity(.named(ctx, "tf.aliasing_output", .int(ctx, .i32, @intCast(index))));
|
attributes[a].appendAssumeCapacity(.named(ctx, "tf.aliasing_output", .int(ctx, .i32, @intCast(index))));
|
||||||
// log.debug("attribute: {}", .{attributes[a].constSlice()});
|
// log.debug("attribute: {}", .{attributes[a].constSlice()});
|
||||||
},
|
},
|
||||||
@ -504,9 +506,9 @@ pub const CompilationContext = struct {
|
|||||||
var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } };
|
var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } };
|
||||||
const f = try comp.emitMlir(Local._fwd, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main });
|
const f = try comp.emitMlir(Local._fwd, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main });
|
||||||
|
|
||||||
var mlir_bytecode = std.ArrayList(u8).init(std.testing.allocator);
|
var mlir_bytecode = std.array_list.Managed(u8).init(std.testing.allocator);
|
||||||
defer mlir_bytecode.deinit();
|
defer mlir_bytecode.deinit();
|
||||||
try mlir_bytecode.writer().print("{}", .{f.mlir_fn.mlirFormatter(.{})});
|
try mlir_bytecode.writer().print("{f}", .{f.mlir_fn.mlirFormatter(.{})});
|
||||||
|
|
||||||
// Check that the `x` input argument gives its buffer to the result tensor.
|
// Check that the `x` input argument gives its buffer to the result tensor.
|
||||||
// `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`.
|
// `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`.
|
||||||
@ -545,7 +547,7 @@ pub const CompilationContext = struct {
|
|||||||
pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute {
|
pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute {
|
||||||
const ctx = self.mlirCtx();
|
const ctx = self.mlirCtx();
|
||||||
const num_partitions = self.numPartitions();
|
const num_partitions = self.numPartitions();
|
||||||
var sharding_str: std.BoundedArray(u8, 128) = .{};
|
var sharding_str: stdx.BoundedArray(u8, 128) = .{};
|
||||||
writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable;
|
writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable;
|
||||||
return mlir.Attribute.string(ctx, sharding_str.constSlice());
|
return mlir.Attribute.string(ctx, sharding_str.constSlice());
|
||||||
}
|
}
|
||||||
@ -622,7 +624,7 @@ pub const CompilationContext = struct {
|
|||||||
const full_name: [:0]const u8 = if (std.mem.eql(u8, "main", func_name))
|
const full_name: [:0]const u8 = if (std.mem.eql(u8, "main", func_name))
|
||||||
try self.allocator().dupeZ(u8, func_name)
|
try self.allocator().dupeZ(u8, func_name)
|
||||||
else
|
else
|
||||||
try std.fmt.allocPrintZ(self.allocator(), "{s}_{x}", .{ func_name, key.input_hash });
|
try std.fmt.allocPrintSentinel(self.allocator(), "{s}_{x}", .{ func_name, key.input_hash }, 0);
|
||||||
|
|
||||||
var arg_id: u16 = 0;
|
var arg_id: u16 = 0;
|
||||||
var tensor_args: @TypeOf(args) = args;
|
var tensor_args: @TypeOf(args) = args;
|
||||||
@ -702,7 +704,7 @@ pub const CompilationContext = struct {
|
|||||||
|
|
||||||
const res = ctx.self._block_args.getOrPutAssumeCapacity(tensor._id);
|
const res = ctx.self._block_args.getOrPutAssumeCapacity(tensor._id);
|
||||||
if (res.found_existing) {
|
if (res.found_existing) {
|
||||||
stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {} and {} at index {} ({}).", .{ res.value_ptr.*[0], tensor, ctx.index, tensor._id });
|
stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {f} and {f} at index {} ({}).", .{ res.value_ptr.*[0], tensor, ctx.index, tensor._id });
|
||||||
} else {
|
} else {
|
||||||
res.value_ptr.* = .{ arg_value, .{ .arg = @intCast(ctx.index) } };
|
res.value_ptr.* = .{ arg_value, .{ .arg = @intCast(ctx.index) } };
|
||||||
}
|
}
|
||||||
@ -777,7 +779,7 @@ pub const CompilationContext = struct {
|
|||||||
.buffer_id, .arg_id => if (self._block_args.get(tensor._id)) |res|
|
.buffer_id, .arg_id => if (self._block_args.get(tensor._id)) |res|
|
||||||
.{ res[0], res[1] }
|
.{ res[0], res[1] }
|
||||||
else {
|
else {
|
||||||
log.err("Found unknown tensor id {}({})", .{ tensor, tensor._id });
|
log.err("Found unknown tensor id {f}({})", .{ tensor, tensor._id });
|
||||||
@panic("Found unknown tensor id");
|
@panic("Found unknown tensor id");
|
||||||
},
|
},
|
||||||
.mlir => |v| .{ v, tensor._donation },
|
.mlir => |v| .{ v, tensor._donation },
|
||||||
|
|||||||
26
zml/nn.zig
26
zml/nn.zig
@ -40,8 +40,8 @@ pub const TokenEmbedding = struct {
|
|||||||
weight: Tensor,
|
weight: Tensor,
|
||||||
|
|
||||||
pub fn forward(self: TokenEmbedding, idx: Tensor) Tensor {
|
pub fn forward(self: TokenEmbedding, idx: Tensor) Tensor {
|
||||||
stdx.debug.assert(idx.dtype().isInteger(), "TokenEmbedding expects an integer input, received: {}", .{idx});
|
stdx.debug.assert(idx.dtype().isInteger(), "TokenEmbedding expects an integer input, received: {f}", .{idx});
|
||||||
stdx.debug.assert(self.weight.rank() == 2, "TokenEmbedding expects it's weight Tensor to be a 2D matrix, got {}", .{self.weight});
|
stdx.debug.assert(self.weight.rank() == 2, "TokenEmbedding expects it's weight Tensor to be a 2D matrix, got {f}", .{self.weight});
|
||||||
return self.weight.gatherValues(0, idx, .{});
|
return self.weight.gatherValues(0, idx, .{});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -204,13 +204,13 @@ pub const RopeOpts = struct {
|
|||||||
/// - pos_idx: optional tensor which indicates which positions are needed.
|
/// - pos_idx: optional tensor which indicates which positions are needed.
|
||||||
/// When not set `rope` return all positions from 0 to x.dim(.s) which is the max seq len.
|
/// When not set `rope` return all positions from 0 to x.dim(.s) which is the max seq len.
|
||||||
pub fn rope(x: Tensor, pos_idx: ?Tensor, opts: RopeOpts) Tensor {
|
pub fn rope(x: Tensor, pos_idx: ?Tensor, opts: RopeOpts) Tensor {
|
||||||
stdx.debug.assert(@mod(x.dim(.hd), 2) == 0, "rope expects a even head dim (.hd), got {}", .{x});
|
stdx.debug.assert(@mod(x.dim(.hd), 2) == 0, "rope expects a even head dim (.hd), got {f}", .{x});
|
||||||
|
|
||||||
const idx = if (pos_idx) |idx| blk: {
|
const idx = if (pos_idx) |idx| blk: {
|
||||||
stdx.debug.assert(x.shape().hasTags(.{.hd}), "rope expects x argument to have .hd axes got: rope(x={}, idx={})", .{ x, idx });
|
stdx.debug.assert(x.shape().hasTags(.{.hd}), "rope expects x argument to have .hd axes got: rope(x={f}, idx={f})", .{ x, idx });
|
||||||
break :blk idx;
|
break :blk idx;
|
||||||
} else blk: {
|
} else blk: {
|
||||||
stdx.debug.assert(x.shape().hasTags(.{ .s, .hd }), "rope expects x argument to have both .s and .hd axes got: rope(x={})", .{x});
|
stdx.debug.assert(x.shape().hasTags(.{ .s, .hd }), "rope expects x argument to have both .s and .hd axes got: rope(x={f})", .{x});
|
||||||
break :blk Tensor.arange(.{ .end = x.dim(.s) }, .f32).withTags(.{.s});
|
break :blk Tensor.arange(.{ .end = x.dim(.s) }, .f32).withTags(.{.s});
|
||||||
};
|
};
|
||||||
const x_real, const x_imag = splitRealImg(x, opts.layout);
|
const x_real, const x_imag = splitRealImg(x, opts.layout);
|
||||||
@ -273,7 +273,7 @@ fn _invFreq(opts: RopeOpts, inv_freq: []f32) void {
|
|||||||
switch (opts.scaling) {
|
switch (opts.scaling) {
|
||||||
.default => {},
|
.default => {},
|
||||||
.custom => {
|
.custom => {
|
||||||
stdx.debug.assert(opts.scaling.custom.len == N, "rope expected custom inv_freq to match half head dimension {}, got {}", .{ N, opts.scaling.custom.len });
|
stdx.debug.assert(opts.scaling.custom.len == N, "rope expected custom inv_freq to match half head dimension {d}, got {d}", .{ N, opts.scaling.custom.len });
|
||||||
@memcpy(inv_freq, opts.scaling.custom);
|
@memcpy(inv_freq, opts.scaling.custom);
|
||||||
},
|
},
|
||||||
.llama3 => |s| {
|
.llama3 => |s| {
|
||||||
@ -318,7 +318,7 @@ test invFreq {
|
|||||||
var inv_freq: @TypeOf(llama_freq) = undefined;
|
var inv_freq: @TypeOf(llama_freq) = undefined;
|
||||||
_invFreq(llama_conf, &inv_freq);
|
_invFreq(llama_conf, &inv_freq);
|
||||||
for (llama_freq, inv_freq, 0..) |expected, actual, i| {
|
for (llama_freq, inv_freq, 0..) |expected, actual, i| {
|
||||||
errdefer log.err("Mismatch at position {d}.\nExpected: {d}\nActual: {d}", .{ i, llama_freq, inv_freq });
|
errdefer log.err("Mismatch at position {d}.\nExpected: {any}\nActual: {any}", .{ i, llama_freq, inv_freq });
|
||||||
try std.testing.expectApproxEqRel(expected, actual, 1e-5);
|
try std.testing.expectApproxEqRel(expected, actual, 1e-5);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -462,7 +462,7 @@ pub fn upsample(
|
|||||||
) Tensor {
|
) Tensor {
|
||||||
// TODO(james): make `nearest` compatible with resizeBilinear and resizeBicubic, and wrap them here.
|
// TODO(james): make `nearest` compatible with resizeBilinear and resizeBicubic, and wrap them here.
|
||||||
// resize* have API which are more explicit, this assume you want to scale the N-2 last axes.
|
// resize* have API which are more explicit, this assume you want to scale the N-2 last axes.
|
||||||
stdx.debug.assert(3 <= input.rank() and input.rank() <= 5, "upsample is only implemented for (3,4,5)-D tensors, received {}", .{input});
|
stdx.debug.assert(3 <= input.rank() and input.rank() <= 5, "upsample is only implemented for (3,4,5)-D tensors, received {f}", .{input});
|
||||||
stdx.debug.assert(opts.scale_factor.len == 1 or opts.scale_factor.len == input.rank() - 2, "scale factors", .{});
|
stdx.debug.assert(opts.scale_factor.len == 1 or opts.scale_factor.len == input.rank() - 2, "scale factors", .{});
|
||||||
return switch (opts.mode) {
|
return switch (opts.mode) {
|
||||||
.nearest => {
|
.nearest => {
|
||||||
@ -791,7 +791,7 @@ pub fn causalAttnMask(
|
|||||||
attn_window_len: ?u32,
|
attn_window_len: ?u32,
|
||||||
) Tensor {
|
) Tensor {
|
||||||
const attn_shape = Shape.init(attn_shape_, dtype);
|
const attn_shape = Shape.init(attn_shape_, dtype);
|
||||||
stdx.debug.assert(attn_shape.rank() == 2, "causalAttnMask({}) shape need to be exactly 2 axes", .{attn_shape});
|
stdx.debug.assert(attn_shape.rank() == 2, "causalAttnMask({f}) shape need to be exactly 2 axes", .{attn_shape});
|
||||||
const qlen = attn_shape.dim(-2);
|
const qlen = attn_shape.dim(-2);
|
||||||
const q_idx = Tensor.iota(attn_shape, -2);
|
const q_idx = Tensor.iota(attn_shape, -2);
|
||||||
const klen = attn_shape.dim(-1);
|
const klen = attn_shape.dim(-1);
|
||||||
@ -843,7 +843,7 @@ pub const SdpaOpts = struct {
|
|||||||
pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
||||||
var q, var k, var v = .{ q_, k_, v_ };
|
var q, var k, var v = .{ q_, k_, v_ };
|
||||||
|
|
||||||
const err_template = "sdpa(q: {}, k: {}, v: {}, attn: {?}) is invalid ! ";
|
const err_template = "sdpa(q: {f}, k: {f}, v: {f}, attn: {?f}) is invalid ! ";
|
||||||
const err_args = .{ q, k, v, opts.attn_mask };
|
const err_args = .{ q, k, v, opts.attn_mask };
|
||||||
stdx.debug.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ "q is missing tags {{.h, .q, .hd}}", err_args);
|
stdx.debug.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ "q is missing tags {{.h, .q, .hd}}", err_args);
|
||||||
stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args);
|
stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args);
|
||||||
@ -909,8 +909,8 @@ const SdpaMemEfficient = struct {
|
|||||||
chunking: SdpaChunks,
|
chunking: SdpaChunks,
|
||||||
|
|
||||||
fn forward(self: SdpaMemEfficient) Tensor {
|
fn forward(self: SdpaMemEfficient) Tensor {
|
||||||
stdx.debug.assert(@mod(self.q.dim(.q), self.chunking.q_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({}, {})", .{ self.q, self.chunking });
|
stdx.debug.assert(@mod(self.q.dim(.q), self.chunking.q_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({f}, {})", .{ self.q, self.chunking });
|
||||||
stdx.debug.assert(@mod(self.k.dim(.k), self.chunking.k_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({}, {})", .{ self.k, self.chunking });
|
stdx.debug.assert(@mod(self.k.dim(.k), self.chunking.k_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({f}, {})", .{ self.k, self.chunking });
|
||||||
const n_q_chunks: u32 = @intCast(@divExact(self.q.dim(.q), self.chunking.q_chunk_size));
|
const n_q_chunks: u32 = @intCast(@divExact(self.q.dim(.q), self.chunking.q_chunk_size));
|
||||||
|
|
||||||
const ctx = zml.module.CompilationContext.current();
|
const ctx = zml.module.CompilationContext.current();
|
||||||
@ -1054,7 +1054,7 @@ pub fn sdpaChunk(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) PartialSoft
|
|||||||
// Consider implementing sdpa from sdpaChunk.
|
// Consider implementing sdpa from sdpaChunk.
|
||||||
var q, var k, var v = .{ q_, k_, v_ };
|
var q, var k, var v = .{ q_, k_, v_ };
|
||||||
|
|
||||||
const err_template = "sdpa(q: {}, k: {}, v: {}, attn: {?}) is invalid ! ";
|
const err_template = "sdpa(q: {f}, k: {f}, v: {f}, attn: {?f}) is invalid ! ";
|
||||||
const err_args = .{ q, k, v, opts.attn_mask };
|
const err_args = .{ q, k, v, opts.attn_mask };
|
||||||
stdx.debug.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ "q is missing tags {{.h, .q, .hd}}", err_args);
|
stdx.debug.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ "q is missing tags {{.h, .q, .hd}}", err_args);
|
||||||
stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args);
|
stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args);
|
||||||
|
|||||||
@ -51,7 +51,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
|||||||
var fba = std.heap.FixedBufferAllocator.init(&buffer);
|
var fba = std.heap.FixedBufferAllocator.init(&buffer);
|
||||||
const allocator = fba.allocator();
|
const allocator = fba.allocator();
|
||||||
|
|
||||||
const backend_config = std.fmt.allocPrintZ(
|
const backend_config = std.fmt.allocPrintSentinel(
|
||||||
allocator,
|
allocator,
|
||||||
\\{{
|
\\{{
|
||||||
\\ "operation_queue_id":"0",
|
\\ "operation_queue_id":"0",
|
||||||
@ -110,6 +110,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
|||||||
q.dim(.q),
|
q.dim(.q),
|
||||||
k.dim(.k),
|
k.dim(.k),
|
||||||
},
|
},
|
||||||
|
0,
|
||||||
) catch unreachable;
|
) catch unreachable;
|
||||||
|
|
||||||
var bias = Tensor.constant(Shape.init(.{ .b = q.dim(.b), .h = q.dim(.h), .q = q.dim(.q), .k = k.dim(.k) }, q.dtype()), Data.init(q.dtype(), 0));
|
var bias = Tensor.constant(Shape.init(.{ .b = q.dim(.b), .h = q.dim(.h), .q = q.dim(.q), .k = k.dim(.k) }, q.dtype()), Data.init(q.dtype(), 0));
|
||||||
|
|||||||
50
zml/ops.zig
50
zml/ops.zig
@ -155,7 +155,7 @@ pub fn reduce(
|
|||||||
// To that order, we initialize `result` to `inputs`, then we use stdx.meta.visit,
|
// To that order, we initialize `result` to `inputs`, then we use stdx.meta.visit,
|
||||||
// to find the correct mlir.Value, but we first broadcast before creating the final
|
// to find the correct mlir.Value, but we first broadcast before creating the final
|
||||||
// Tensor struct.
|
// Tensor struct.
|
||||||
var broadcasting_axes: std.BoundedArray(i64, Tensor.MAX_RANK) = .{};
|
var broadcasting_axes: stdx.BoundedArray(i64, Tensor.MAX_RANK) = .{};
|
||||||
for (0..Tensor.MAX_RANK) |i| {
|
for (0..Tensor.MAX_RANK) |i| {
|
||||||
if (std.mem.indexOfScalar(i64, axes, @intCast(i)) == null) {
|
if (std.mem.indexOfScalar(i64, axes, @intCast(i)) == null) {
|
||||||
broadcasting_axes.append(@intCast(i)) catch unreachable;
|
broadcasting_axes.append(@intCast(i)) catch unreachable;
|
||||||
@ -437,15 +437,15 @@ pub fn if_(
|
|||||||
@compileError("true_branch_fn and false_branch_fn return types don't match ! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return));
|
@compileError("true_branch_fn and false_branch_fn return types don't match ! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return));
|
||||||
}
|
}
|
||||||
|
|
||||||
stdx.debug.assert(pred.dtype() == .bool and pred.count() == 1, "zml.ops.if_ expects the condition to have exactly one element of dtype .bool, got {}", .{pred});
|
stdx.debug.assert(pred.dtype() == .bool and pred.count() == 1, "zml.ops.if_ expects the condition to have exactly one element of dtype .bool, got {f}", .{pred});
|
||||||
|
|
||||||
const ctx = CompilationContext.current();
|
const ctx = CompilationContext.current();
|
||||||
const true_branch_block, const true_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &true_branch_fn, blkctx, {});
|
const true_branch_block, const true_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &true_branch_fn, blkctx, {});
|
||||||
const false_branch_block, const false_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &false_branch_fn, blkctx, {});
|
const false_branch_block, const false_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &false_branch_fn, blkctx, {});
|
||||||
|
|
||||||
var true_shapes = std.ArrayList(Shape).init(ctx.allocator());
|
var true_shapes = std.array_list.Managed(Shape).init(ctx.allocator());
|
||||||
defer true_shapes.deinit();
|
defer true_shapes.deinit();
|
||||||
var false_shapes = std.ArrayList(Shape).init(ctx.allocator());
|
var false_shapes = std.array_list.Managed(Shape).init(ctx.allocator());
|
||||||
defer false_shapes.deinit();
|
defer false_shapes.deinit();
|
||||||
|
|
||||||
var failed_to_collect = false;
|
var failed_to_collect = false;
|
||||||
@ -456,9 +456,9 @@ pub fn if_(
|
|||||||
failed_to_collect = true;
|
failed_to_collect = true;
|
||||||
};
|
};
|
||||||
if (!failed_to_collect) {
|
if (!failed_to_collect) {
|
||||||
stdx.debug.assert(true_shapes.items.len == false_shapes.items.len, "zml.ops.if_ expects the true and false branch to produce the same number of tensors. Got: \n - true branch: {_}\n -false branch: {_}", .{ true_shapes.items, false_shapes.items });
|
stdx.debug.assert(true_shapes.items.len == false_shapes.items.len, "zml.ops.if_ expects the true and false branch to produce the same number of tensors. Got: \n - true branch: {any}\n -false branch: {any}", .{ true_shapes.items, false_shapes.items });
|
||||||
for (true_shapes.items, false_shapes.items) |true_shape, false_shape| {
|
for (true_shapes.items, false_shapes.items) |true_shape, false_shape| {
|
||||||
stdx.debug.assert(true_shape.eqlWithTags(false_shape), "zml.ops.if_ expects the true and false branch to produce tensors of the same shape. Got: \n - true branch: {_}\n -false branch: {_}", .{ true_shapes.items, false_shapes.items });
|
stdx.debug.assert(true_shape.eqlWithTags(false_shape), "zml.ops.if_ expects the true and false branch to produce tensors of the same shape. Got: \n - true branch: {any}\n -false branch: {any}", .{ true_shapes.items, false_shapes.items });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -751,7 +751,7 @@ pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base
|
|||||||
meta.visit((struct {
|
meta.visit((struct {
|
||||||
fn cb(inner_ctx: *LocalContext, tensor: *Tensor) void {
|
fn cb(inner_ctx: *LocalContext, tensor: *Tensor) void {
|
||||||
var new = Tensor.fromMlirValue(inner_ctx.op.result(inner_ctx.index));
|
var new = Tensor.fromMlirValue(inner_ctx.op.result(inner_ctx.index));
|
||||||
stdx.debug.internalAssert(new.rank() == tensor.rank(), "expected operand result to have rank {} but got {}", .{ tensor.rank(), new });
|
stdx.debug.internalAssert(new.rank() == tensor.rank(), "expected operand result to have rank {} but got {f}", .{ tensor.rank(), new });
|
||||||
// copy tags and sharding info over
|
// copy tags and sharding info over
|
||||||
// some ops can change dims eg reduceWindow, so we trust mlir here.
|
// some ops can change dims eg reduceWindow, so we trust mlir here.
|
||||||
new._shape._tags = tensor._shape._tags;
|
new._shape._tags = tensor._shape._tags;
|
||||||
@ -932,7 +932,7 @@ pub fn scatter(
|
|||||||
|
|
||||||
const n_inputs = meta.count(Tensor, &inputs);
|
const n_inputs = meta.count(Tensor, &inputs);
|
||||||
const n_updates = meta.count(Tensor, &updates);
|
const n_updates = meta.count(Tensor, &updates);
|
||||||
stdx.debug.assert(n_inputs == n_updates, "zml.ops.scatter expects the same number of tensors in inputs and updates, got {} and {}", .{ n_inputs, n_updates });
|
stdx.debug.assert(n_inputs == n_updates, "zml.ops.scatter expects the same number of tensors in inputs and updates, got {d} and {d}", .{ n_inputs, n_updates });
|
||||||
|
|
||||||
// Note: I was a bit lazy here, and I only look at tags on the first tensor.
|
// Note: I was a bit lazy here, and I only look at tags on the first tensor.
|
||||||
// we probably should check all of them.
|
// we probably should check all of them.
|
||||||
@ -944,7 +944,7 @@ pub fn scatter(
|
|||||||
|
|
||||||
// validate coord axes: all coord_axes should exist inside self
|
// validate coord axes: all coord_axes should exist inside self
|
||||||
for (indices_axes.constSlice()) |t| {
|
for (indices_axes.constSlice()) |t| {
|
||||||
stdx.debug.assert(self._shape.hasTag(t) != null, "zml.ops.scatter expects axes of indices to be axes of inputs, got input={_} and indices={s}", .{ self, indices_axes.constSlice() });
|
stdx.debug.assert(self._shape.hasTag(t) != null, "zml.ops.scatter expects axes of indices to be axes of inputs, got input={f} and indices={any}", .{ self, indices_axes.constSlice() });
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle scalar indices by broadcasting them to the indices with the highest rank.
|
// Handle scalar indices by broadcasting them to the indices with the highest rank.
|
||||||
@ -958,8 +958,8 @@ pub fn scatter(
|
|||||||
break :blk higher_rank;
|
break :blk higher_rank;
|
||||||
};
|
};
|
||||||
for (indices_per_axis.slice()) |*idx| {
|
for (indices_per_axis.slice()) |*idx| {
|
||||||
stdx.debug.assert(idx.shape().canBroadcastTo(indices_shape), "zml.ops.scatter expects all indices tensor to have the same shape, got {_}", .{indices_per_axis.slice()});
|
stdx.debug.assert(idx.shape().canBroadcastTo(indices_shape), "zml.ops.scatter expects all indices tensor to have the same shape, got {any}", .{indices_per_axis.slice()});
|
||||||
stdx.debug.assert(idx.dtype() == indices_shape.dtype(), "zml.ops.scatter expects all indices tensor to have the same dtype, got {_}", .{indices_per_axis.slice()});
|
stdx.debug.assert(idx.dtype() == indices_shape.dtype(), "zml.ops.scatter expects all indices tensor to have the same dtype, got {any}", .{indices_per_axis.slice()});
|
||||||
idx.* = idx.broad(indices_shape);
|
idx.* = idx.broad(indices_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -972,7 +972,7 @@ pub fn scatter(
|
|||||||
var config = scatterConfig(self.shape(), update.shape(), indices_per_axis, indices_axes);
|
var config = scatterConfig(self.shape(), update.shape(), indices_per_axis, indices_axes);
|
||||||
const indices = scatterPrepareIndices(&config, self.shape(), update.shape(), &indices_per_axis, &indices_axes);
|
const indices = scatterPrepareIndices(&config, self.shape(), update.shape(), &indices_per_axis, &indices_axes);
|
||||||
// const n_indices_axes = update.rank() - _collectAxes(AxisKind, up_kind, .update_window).len;
|
// const n_indices_axes = update.rank() - _collectAxes(AxisKind, up_kind, .update_window).len;
|
||||||
// stdx.debug.assert(n_indices_axe == indices_axes.len, "scatter({_}, {any}) expects 'updates' to contain all axes from 'indices', got indices={s}, updates={_}", .{ self, index_tensors, indices_axes.constSlice(), update });
|
// stdx.debug.assert(n_indices_axe == indices_axes.len, "scatter({f}, {any}) expects 'updates' to contain all axes from 'indices', got indices={s}, updates={f}", .{ self, index_tensors, indices_axes.constSlice(), update });
|
||||||
|
|
||||||
const mlir_ctx = ctx.mlirCtx();
|
const mlir_ctx = ctx.mlirCtx();
|
||||||
var _scalar: T = inputs;
|
var _scalar: T = inputs;
|
||||||
@ -985,10 +985,10 @@ pub fn scatter(
|
|||||||
const UpdateS = BlockSign(update_fn);
|
const UpdateS = BlockSign(update_fn);
|
||||||
const update_block, _ = ctx.makeBlock(.hermetic, UpdateS, update_fn, blkctx, .{ _scalar, _scalar });
|
const update_block, _ = ctx.makeBlock(.hermetic, UpdateS, update_fn, blkctx, .{ _scalar, _scalar });
|
||||||
|
|
||||||
var input_values = std.ArrayList(mlir.Value).initCapacity(ctx.allocator(), n_inputs) catch @panic("OOM");
|
var input_values = std.array_list.Managed(mlir.Value).initCapacity(ctx.allocator(), n_inputs) catch @panic("OOM");
|
||||||
defer input_values.deinit();
|
defer input_values.deinit();
|
||||||
meta.collect(CompilationContext.getValue, ctx, &input_values, &inputs) catch unreachable;
|
meta.collect(CompilationContext.getValue, ctx, &input_values, &inputs) catch unreachable;
|
||||||
var updates_values = std.ArrayList(mlir.Value).initCapacity(ctx.allocator(), n_updates) catch @panic("OOM");
|
var updates_values = std.array_list.Managed(mlir.Value).initCapacity(ctx.allocator(), n_updates) catch @panic("OOM");
|
||||||
defer updates_values.deinit();
|
defer updates_values.deinit();
|
||||||
meta.collect(CompilationContext.getValue, ctx, &updates_values, &updates) catch unreachable;
|
meta.collect(CompilationContext.getValue, ctx, &updates_values, &updates) catch unreachable;
|
||||||
|
|
||||||
@ -1029,8 +1029,8 @@ pub fn scatter(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const ScatterConfig = struct {
|
const ScatterConfig = struct {
|
||||||
op_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{},
|
op_kind: stdx.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{},
|
||||||
up_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{},
|
up_kind: stdx.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{},
|
||||||
indices_batch_axes: Shape.DimsArray = .{},
|
indices_batch_axes: Shape.DimsArray = .{},
|
||||||
scatter_to_operand_axes: Shape.DimsArray = .{},
|
scatter_to_operand_axes: Shape.DimsArray = .{},
|
||||||
updates_transpose: Shape.AxesArray = .{},
|
updates_transpose: Shape.AxesArray = .{},
|
||||||
@ -1041,11 +1041,11 @@ const AxisKind = enum { batching, update_window, inserted_window, window_id };
|
|||||||
fn scatterConfig(
|
fn scatterConfig(
|
||||||
op: Shape,
|
op: Shape,
|
||||||
update: Shape,
|
update: Shape,
|
||||||
indices_per_axis: std.BoundedArray(Tensor, Tensor.MAX_RANK),
|
indices_per_axis: stdx.BoundedArray(Tensor, Tensor.MAX_RANK),
|
||||||
indices_axes: Shape.TagsArray,
|
indices_axes: Shape.TagsArray,
|
||||||
) ScatterConfig {
|
) ScatterConfig {
|
||||||
var op_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{};
|
var op_kind: stdx.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{};
|
||||||
var up_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{};
|
var up_kind: stdx.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{};
|
||||||
var indices_batch_axes: Shape.DimsArray = .{};
|
var indices_batch_axes: Shape.DimsArray = .{};
|
||||||
var scatter_to_operand_axes: Shape.DimsArray = .{};
|
var scatter_to_operand_axes: Shape.DimsArray = .{};
|
||||||
var updates_transpose: Shape.AxesArray = .{};
|
var updates_transpose: Shape.AxesArray = .{};
|
||||||
@ -1058,7 +1058,7 @@ fn scatterConfig(
|
|||||||
scatter_to_operand_axes.appendAssumeCapacity(op.axis(t));
|
scatter_to_operand_axes.appendAssumeCapacity(op.axis(t));
|
||||||
}
|
}
|
||||||
for (indices.tags()) |t| {
|
for (indices.tags()) |t| {
|
||||||
stdx.debug.assert(update.hasTag(t) != null, "scatter expects 'updates' to have all axes of 'indices', got self={_}, updates={_} and indices={_}", .{ op, update, indices });
|
stdx.debug.assert(update.hasTag(t) != null, "scatter expects 'updates' to have all axes of 'indices', got self={f}, updates={f} and indices={f}", .{ op, update, indices });
|
||||||
updates_transpose.appendAssumeCapacity(update.axis(t));
|
updates_transpose.appendAssumeCapacity(update.axis(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1094,11 +1094,11 @@ fn scatterConfig(
|
|||||||
if (indices.hasTag(t) != null) {
|
if (indices.hasTag(t) != null) {
|
||||||
up_kind.appendAssumeCapacity(.window_id);
|
up_kind.appendAssumeCapacity(.window_id);
|
||||||
} else if (op.hasTag(t)) |self_ax| {
|
} else if (op.hasTag(t)) |self_ax| {
|
||||||
stdx.debug.assert(update.dim(up_ax) <= op.dim(self_ax), "scatter expects the slices described in 'updates' to fit inside 'op', but along axis .{s} it doesn't. Got op={_}, updates={_}.", .{ t, op, update });
|
stdx.debug.assert(update.dim(up_ax) <= op.dim(self_ax), "scatter expects the slices described in 'updates' to fit inside 'op', but along axis .{s} it doesn't. Got op={f}, updates={f}.", .{ t, op, update });
|
||||||
up_kind.appendAssumeCapacity(.update_window);
|
up_kind.appendAssumeCapacity(.update_window);
|
||||||
} else {
|
} else {
|
||||||
// TODO: consider accepting untagged update here.
|
// TODO: consider accepting untagged update here.
|
||||||
std.debug.panic("scatter expects 'updates' to be made of axes from op={_} and from indices={s}, got unknown tag {s} in {_}", .{ op, indices_axes.constSlice(), t, update });
|
std.debug.panic("scatter expects 'updates' to be made of axes from op={f} and from indices={any}, got unknown tag {s} in {f}", .{ op, indices_axes.constSlice(), std.mem.sliceTo(t, 0), update });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -1174,7 +1174,7 @@ fn scatterPrepareIndices(
|
|||||||
cfg: *ScatterConfig,
|
cfg: *ScatterConfig,
|
||||||
op: Shape,
|
op: Shape,
|
||||||
update: Shape,
|
update: Shape,
|
||||||
indices_per_axis: *std.BoundedArray(Tensor, Tensor.MAX_RANK),
|
indices_per_axis: *stdx.BoundedArray(Tensor, Tensor.MAX_RANK),
|
||||||
indices_axes: *Shape.TagsArray,
|
indices_axes: *Shape.TagsArray,
|
||||||
) Tensor {
|
) Tensor {
|
||||||
var old_scatter_to_op_axes = cfg.scatter_to_operand_axes;
|
var old_scatter_to_op_axes = cfg.scatter_to_operand_axes;
|
||||||
@ -1194,7 +1194,7 @@ fn scatterPrepareIndices(
|
|||||||
|
|
||||||
// Reorder the axes so that in indices_per_axis is ordered like in op if possible.
|
// Reorder the axes so that in indices_per_axis is ordered like in op if possible.
|
||||||
// TODO: transpose updates if needed
|
// TODO: transpose updates if needed
|
||||||
var indices: std.BoundedArray(Tensor, Tensor.MAX_RANK) = .{};
|
var indices: stdx.BoundedArray(Tensor, Tensor.MAX_RANK) = .{};
|
||||||
var scatter_to_op_axes: Shape.DimsArray = .{};
|
var scatter_to_op_axes: Shape.DimsArray = .{};
|
||||||
|
|
||||||
while (old_scatter_to_op_axes.len > 0) {
|
while (old_scatter_to_op_axes.len > 0) {
|
||||||
@ -1209,7 +1209,7 @@ fn scatterPrepareIndices(
|
|||||||
|
|
||||||
for (scatter_to_op_axes.constSlice(), 0..) |sc_ax, i| {
|
for (scatter_to_op_axes.constSlice(), 0..) |sc_ax, i| {
|
||||||
if (i != sc_ax) {
|
if (i != sc_ax) {
|
||||||
log.warn("Found a slow scatter pattern, which is going to generate a while loop: scatter({_}, {any}, {_}). Because the index axes aren't the major ones in the input tensor.", .{ op, scatter_to_op_axes.constSlice(), update });
|
log.warn("Found a slow scatter pattern, which is going to generate a while loop: scatter({f}, {any}, {f}). Because the index axes aren't the major ones in the input tensor.", .{ op, scatter_to_op_axes.constSlice(), update });
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -75,14 +75,14 @@ pub const Client = opaque {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn compileSync(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable {
|
fn compileSync(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable {
|
||||||
var bytecode = std.ArrayList(u8).init(allocator);
|
var bytecode: std.array_list.Managed(u8) = .init(allocator);
|
||||||
defer bytecode.deinit();
|
defer bytecode.deinit();
|
||||||
module.op().writeBytecodeWithConfig(bytecode.writer(), .{ .desiredEmitedVersion = 1 }) catch |err| {
|
module.op().writeBytecodeWithConfig(bytecode.writer(), .{ .desiredEmitedVersion = 1 }) catch |err| {
|
||||||
log.err("failed to write module bytecode: {}", .{err});
|
log.err("failed to write module bytecode: {}", .{err});
|
||||||
return err;
|
return err;
|
||||||
};
|
};
|
||||||
|
|
||||||
var serialized_buffer = std.ArrayList(u8).init(allocator);
|
var serialized_buffer: std.array_list.Managed(u8) = .init(allocator);
|
||||||
defer serialized_buffer.deinit();
|
defer serialized_buffer.deinit();
|
||||||
|
|
||||||
const stablehlo_version = blk: {
|
const stablehlo_version = blk: {
|
||||||
@ -220,7 +220,7 @@ pub const Event = opaque {
|
|||||||
}{};
|
}{};
|
||||||
|
|
||||||
try self.inner().onReady(api, &(struct {
|
try self.inner().onReady(api, &(struct {
|
||||||
fn call(err: ?*pjrt.Error, user_arg: ?*anyopaque) callconv(.C) void {
|
fn call(err: ?*pjrt.Error, user_arg: ?*anyopaque) callconv(.c) void {
|
||||||
const ctx_: *@TypeOf(ctx) = @ptrCast(@alignCast(user_arg.?));
|
const ctx_: *@TypeOf(ctx) = @ptrCast(@alignCast(user_arg.?));
|
||||||
ctx_.err = err;
|
ctx_.err = err;
|
||||||
ctx_.event.set();
|
ctx_.event.set();
|
||||||
|
|||||||
@ -15,7 +15,7 @@ pub const CompilationOptions = struct {
|
|||||||
xla_dump_fusion_visualization: bool = false,
|
xla_dump_fusion_visualization: bool = false,
|
||||||
xla_dump_hlo_pass_re: ?[]const u8 = null,
|
xla_dump_hlo_pass_re: ?[]const u8 = null,
|
||||||
sharding_enabled: bool = false,
|
sharding_enabled: bool = false,
|
||||||
sharding_axes: std.BoundedArray([*:0]const u8, 8) = .{},
|
sharding_axes: stdx.BoundedArray([*:0]const u8, 8) = .{},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Platform = struct {
|
pub const Platform = struct {
|
||||||
|
|||||||
@ -22,9 +22,9 @@ pub const Shape = struct {
|
|||||||
pub const TagUnknown = "_".ptr;
|
pub const TagUnknown = "_".ptr;
|
||||||
const TagLast = "last".ptr;
|
const TagLast = "last".ptr;
|
||||||
|
|
||||||
pub const DimsArray = std.BoundedArray(i64, MAX_RANK);
|
pub const DimsArray = stdx.BoundedArray(i64, MAX_RANK);
|
||||||
pub const TagsArray = std.BoundedArray(Tag, MAX_RANK);
|
pub const TagsArray = stdx.BoundedArray(Tag, MAX_RANK);
|
||||||
pub const AxesArray = std.BoundedArray(u3, MAX_RANK);
|
pub const AxesArray = stdx.BoundedArray(u3, MAX_RANK);
|
||||||
pub const ShardingInfo = @Vector(MAX_RANK, bool);
|
pub const ShardingInfo = @Vector(MAX_RANK, bool);
|
||||||
|
|
||||||
const UnknownTags: TagsArray = .{ .len = 0, .buffer = [_]Tag{TagUnknown} ** MAX_RANK };
|
const UnknownTags: TagsArray = .{ .len = 0, .buffer = [_]Tag{TagUnknown} ** MAX_RANK };
|
||||||
@ -300,7 +300,7 @@ pub const Shape = struct {
|
|||||||
fn axisFromInt(self: Shape, a: isize) u3 {
|
fn axisFromInt(self: Shape, a: isize) u3 {
|
||||||
const rk: i8 = self.rank();
|
const rk: i8 = self.rank();
|
||||||
if (a < -rk or a > rk) {
|
if (a < -rk or a > rk) {
|
||||||
stdx.debug.panic("Tensor {} doesn't have dimension: {d}", .{ self, a });
|
stdx.debug.panic("Tensor {f} doesn't have dimension: {d}", .{ self, a });
|
||||||
}
|
}
|
||||||
return if (a < 0)
|
return if (a < 0)
|
||||||
@intCast(a + rk)
|
@intCast(a + rk)
|
||||||
@ -341,9 +341,9 @@ pub const Shape = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn axisFromTag(self: Shape, d: Tag) u3 {
|
fn axisFromTag(self: Shape, d: Tag) u3 {
|
||||||
stdx.debug.assert(d != TagUnknown, "The unknown tag .{s} can't be used to fetch axis in {}", .{ d, self });
|
stdx.debug.assert(d != TagUnknown, "The unknown tag .{s} can't be used to fetch axis in {f}", .{ d, self });
|
||||||
return self.axisFromTagMaybe(d) orelse {
|
return self.axisFromTagMaybe(d) orelse {
|
||||||
stdx.debug.panic("Tensor {} doesn't have dimension with tag: {s}", .{ self, d });
|
stdx.debug.panic("Tensor {f} doesn't have dimension with tag: {s}", .{ self, d });
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -357,7 +357,7 @@ pub const Shape = struct {
|
|||||||
pub fn count(self: Shape) usize {
|
pub fn count(self: Shape) usize {
|
||||||
var res: i64 = 1;
|
var res: i64 = 1;
|
||||||
for (self.dims()) |d| {
|
for (self.dims()) |d| {
|
||||||
stdx.debug.assert(d >= 0, "Can't count elements in shape with negative dimension: {}", .{self});
|
stdx.debug.assert(d >= 0, "Can't count elements in shape with negative dimension: {f}", .{self});
|
||||||
res *= d;
|
res *= d;
|
||||||
}
|
}
|
||||||
return @intCast(res);
|
return @intCast(res);
|
||||||
@ -388,12 +388,11 @@ pub const Shape = struct {
|
|||||||
/// Bare format {_}: "{.a=10, .b=20}, dtype=.f32"
|
/// Bare format {_}: "{.a=10, .b=20}, dtype=.f32"
|
||||||
pub fn format(
|
pub fn format(
|
||||||
self: Shape,
|
self: Shape,
|
||||||
comptime fmt: []const u8,
|
|
||||||
options: std.fmt.FormatOptions,
|
|
||||||
writer: anytype,
|
writer: anytype,
|
||||||
) !void {
|
) !void {
|
||||||
_ = options;
|
// TODO: impl alternative format
|
||||||
const bare_fmt = fmt.len == 1 and fmt[0] == '_';
|
// const bare_fmt = fmt.len == 1 and fmt[0] == '_';
|
||||||
|
const bare_fmt = true;
|
||||||
_ = try writer.write(if (bare_fmt) "{" else "Shape({");
|
_ = try writer.write(if (bare_fmt) "{" else "Shape({");
|
||||||
|
|
||||||
var need_comma = false;
|
var need_comma = false;
|
||||||
@ -441,12 +440,12 @@ pub const Shape = struct {
|
|||||||
var new_shape: Shape = .{ ._dtype = self.dtype() };
|
var new_shape: Shape = .{ ._dtype = self.dtype() };
|
||||||
new_shape._dims, new_shape._tags = parseDimensions(new_shape_);
|
new_shape._dims, new_shape._tags = parseDimensions(new_shape_);
|
||||||
new_shape.inferMissingAxis(self.count());
|
new_shape.inferMissingAxis(self.count());
|
||||||
stdx.debug.assert(self.count() == new_shape.count(), "Can't reshape {d} to {d}", .{ self.dims(), new_shape.dims() });
|
stdx.debug.assert(self.count() == new_shape.count(), "Can't reshape {any} to {any}", .{ self.dims(), new_shape.dims() });
|
||||||
return new_shape;
|
return new_shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inferMissingAxis(self: *Shape, n_: usize) void {
|
fn inferMissingAxis(self: *Shape, n_: usize) void {
|
||||||
stdx.debug.assert(std.mem.count(i64, self.dims(), &.{-1}) < 2, "Cannot infer multiple dimensions when reshaping to: {}", .{self.*});
|
stdx.debug.assert(std.mem.count(i64, self.dims(), &.{-1}) < 2, "Cannot infer multiple dimensions when reshaping to: {f}", .{self.*});
|
||||||
|
|
||||||
const inferred_ax = std.mem.indexOfScalar(i64, self.dims(), -1) orelse return;
|
const inferred_ax = std.mem.indexOfScalar(i64, self.dims(), -1) orelse return;
|
||||||
// We can't use `self.count()` yet cause we have negative dims.
|
// We can't use `self.count()` yet cause we have negative dims.
|
||||||
@ -524,7 +523,7 @@ pub const Shape = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn insertTag(self: Shape, axis_: anytype, d: i64, tag_: anytype) Shape {
|
pub fn insertTag(self: Shape, axis_: anytype, d: i64, tag_: anytype) Shape {
|
||||||
stdx.debug.assert(self.rank() < MAX_RANK - 1, "Can't insert new axis in {}, it's already at max rank.", .{self});
|
stdx.debug.assert(self.rank() < MAX_RANK - 1, "Can't insert new axis in {f}, it's already at max rank.", .{self});
|
||||||
|
|
||||||
const ax = if (@TypeOf(axis_) == EnumLiteral and axis_ == .last)
|
const ax = if (@TypeOf(axis_) == EnumLiteral and axis_ == .last)
|
||||||
self.rank()
|
self.rank()
|
||||||
@ -652,7 +651,7 @@ pub const Shape = struct {
|
|||||||
var res = self;
|
var res = self;
|
||||||
|
|
||||||
if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) {
|
if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) {
|
||||||
stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {any}", .{ self, tagz });
|
stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {f}, got {any}", .{ self, tagz });
|
||||||
for (tagz, 0..) |tag_, i| {
|
for (tagz, 0..) |tag_, i| {
|
||||||
res._tags.set(i, toTag(tag_));
|
res._tags.set(i, toTag(tag_));
|
||||||
}
|
}
|
||||||
@ -660,7 +659,7 @@ pub const Shape = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) {
|
if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) {
|
||||||
stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {}", .{ self, tagz });
|
stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {f}, got {}", .{ self, tagz });
|
||||||
inline for (tagz, 0..) |tag_, i| {
|
inline for (tagz, 0..) |tag_, i| {
|
||||||
res._tags.set(i, toTag(tag_));
|
res._tags.set(i, toTag(tag_));
|
||||||
}
|
}
|
||||||
@ -699,7 +698,7 @@ pub const Shape = struct {
|
|||||||
var res = self;
|
var res = self;
|
||||||
|
|
||||||
if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) {
|
if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) {
|
||||||
stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {any}", .{ self, tagz });
|
stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {f}, got {any}", .{ self, tagz });
|
||||||
for (tagz, self.rank() - tagz.len..) |tag_, i| {
|
for (tagz, self.rank() - tagz.len..) |tag_, i| {
|
||||||
res._tags.set(i, toTag(tag_));
|
res._tags.set(i, toTag(tag_));
|
||||||
}
|
}
|
||||||
@ -707,7 +706,7 @@ pub const Shape = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) {
|
if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) {
|
||||||
stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {}", .{ self, tagz });
|
stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {f}, got {}", .{ self, tagz });
|
||||||
inline for (tagz, self.rank() - tagz.len..) |tag_, i| {
|
inline for (tagz, self.rank() - tagz.len..) |tag_, i| {
|
||||||
res._tags.set(i, toTag(tag_));
|
res._tags.set(i, toTag(tag_));
|
||||||
}
|
}
|
||||||
@ -765,7 +764,7 @@ pub const Shape = struct {
|
|||||||
var res = self;
|
var res = self;
|
||||||
inline for (std.meta.fields(T)) |field| {
|
inline for (std.meta.fields(T)) |field| {
|
||||||
const new_field = @field(renames, field.name);
|
const new_field = @field(renames, field.name);
|
||||||
stdx.debug.assert(self.hasTag(new_field) == null, "{}.rename({any}) failed because of duplicated axis {}", .{ self, renames, new_field });
|
stdx.debug.assert(self.hasTag(new_field) == null, "{f}.rename({any}) failed because of duplicated axis {}", .{ self, renames, new_field });
|
||||||
res._tags.set(self.axis(field), toTag(new_field));
|
res._tags.set(self.axis(field), toTag(new_field));
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
@ -785,9 +784,9 @@ pub const Shape = struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn computeStrides(self: Shape) std.BoundedArray(i64, MAX_RANK) {
|
pub fn computeStrides(self: Shape) stdx.BoundedArray(i64, MAX_RANK) {
|
||||||
const rk = self.rank();
|
const rk = self.rank();
|
||||||
var strides: std.BoundedArray(i64, MAX_RANK) = .{ .len = rk };
|
var strides: stdx.BoundedArray(i64, MAX_RANK) = .{ .len = rk };
|
||||||
if (rk == 0) return strides;
|
if (rk == 0) return strides;
|
||||||
|
|
||||||
const V = @Vector(MAX_RANK, i64);
|
const V = @Vector(MAX_RANK, i64);
|
||||||
@ -907,7 +906,7 @@ pub const Shape = struct {
|
|||||||
var new_dim: i64 = 1;
|
var new_dim: i64 = 1;
|
||||||
for (axes__.constSlice(), first_axis..) |ax, counter| {
|
for (axes__.constSlice(), first_axis..) |ax, counter| {
|
||||||
new_dim *= self.dim(ax);
|
new_dim *= self.dim(ax);
|
||||||
stdx.debug.assert(ax == counter, "Can't merge shape {} along non-contiguous axes {any}", .{ self, axes_ });
|
stdx.debug.assert(ax == counter, "Can't merge shape {f} along non-contiguous axes {any}", .{ self, axes_ });
|
||||||
}
|
}
|
||||||
|
|
||||||
var new_shape = self;
|
var new_shape = self;
|
||||||
@ -991,10 +990,10 @@ pub const Shape = struct {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn parseStruct(T: type, v: anytype) struct { std.BoundedArray(T, MAX_RANK), TagsArray } {
|
pub fn parseStruct(T: type, v: anytype) struct { stdx.BoundedArray(T, MAX_RANK), TagsArray } {
|
||||||
const V = @TypeOf(v);
|
const V = @TypeOf(v);
|
||||||
|
|
||||||
var vals_: std.BoundedArray(T, MAX_RANK) = .{};
|
var vals_: stdx.BoundedArray(T, MAX_RANK) = .{};
|
||||||
var tags_: TagsArray = .{};
|
var tags_: TagsArray = .{};
|
||||||
|
|
||||||
if (comptime stdx.meta.isSliceOf(V, T)) {
|
if (comptime stdx.meta.isSliceOf(V, T)) {
|
||||||
@ -1029,10 +1028,10 @@ pub const Shape = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Parses a struct literal into a list of options for each axes.
|
/// Parses a struct literal into a list of options for each axes.
|
||||||
pub fn parseAxesOptions(self: Shape, T: type, options: anytype, default: T) std.BoundedArray(T, MAX_RANK) {
|
pub fn parseAxesOptions(self: Shape, T: type, options: anytype, default: T) stdx.BoundedArray(T, MAX_RANK) {
|
||||||
const V = @TypeOf(options);
|
const V = @TypeOf(options);
|
||||||
|
|
||||||
var res: std.BoundedArray(T, MAX_RANK) = .{};
|
var res: stdx.BoundedArray(T, MAX_RANK) = .{};
|
||||||
if (comptime stdx.meta.isSliceOf(V, T)) {
|
if (comptime stdx.meta.isSliceOf(V, T)) {
|
||||||
stdx.debug.assert(options.len == self.rank(), "expects exactly {} options in slice, for {} got {}", .{ self.rank(), self, options.len });
|
stdx.debug.assert(options.len == self.rank(), "expects exactly {} options in slice, for {} got {}", .{ self.rank(), self, options.len });
|
||||||
for (options) |d| {
|
for (options) |d| {
|
||||||
@ -1084,7 +1083,7 @@ pub const Shape = struct {
|
|||||||
for (0..other.rank()) |ax| {
|
for (0..other.rank()) |ax| {
|
||||||
if (other.tag(ax) != Shape.TagUnknown) {
|
if (other.tag(ax) != Shape.TagUnknown) {
|
||||||
if (self.hasTag(other.tag(ax))) |batching_ax| {
|
if (self.hasTag(other.tag(ax))) |batching_ax| {
|
||||||
stdx.debug.assert(batching_ax == batching_axes and batching_ax == ax, "outer expects batching dims to be the first dims in both tensors, got outer({}, {})", .{ self, other });
|
stdx.debug.assert(batching_ax == batching_axes and batching_ax == ax, "outer expects batching dims to be the first dims in both tensors, got outer({f}, {f})", .{ self, other });
|
||||||
batching_axes += 1;
|
batching_axes += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
151
zml/tensor.zig
151
zml/tensor.zig
@ -53,13 +53,12 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
pub fn format(
|
pub fn format(
|
||||||
self: Tensor,
|
self: Tensor,
|
||||||
comptime fmt: []const u8,
|
|
||||||
options: std.fmt.FormatOptions,
|
|
||||||
writer: anytype,
|
writer: anytype,
|
||||||
) !void {
|
) !void {
|
||||||
_ = options;
|
// TODO(0.15.0) handle format
|
||||||
const bare_fmt = fmt.len == 1 and fmt[0] == '_';
|
// const bare_fmt = fmt.len == 1 and fmt[0] == '_';
|
||||||
try writer.print(if (bare_fmt) "{_}" else "Tensor({_})", .{self._shape});
|
const bare_fmt = false;
|
||||||
|
try writer.print(if (bare_fmt) "{f}" else "Tensor({f})", .{self._shape});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the shape of a Tensor.
|
/// Returns the shape of a Tensor.
|
||||||
@ -99,7 +98,7 @@ pub const Tensor = struct {
|
|||||||
if (builtin.mode == .Debug) {
|
if (builtin.mode == .Debug) {
|
||||||
// Check that the MLIR value actually have the same shape.
|
// Check that the MLIR value actually have the same shape.
|
||||||
const other = fromMlirValue(val);
|
const other = fromMlirValue(val);
|
||||||
stdx.debug.internalAssert(sh.eql(other._shape), "Created a {} from Mlir value but expected {}", .{ other._shape, res._shape });
|
stdx.debug.internalAssert(sh.eql(other._shape), "Created a {f} from Mlir value but expected {f}", .{ other._shape, res._shape });
|
||||||
}
|
}
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
@ -145,7 +144,7 @@ pub const Tensor = struct {
|
|||||||
/// Returns the indices of each of the given axes.
|
/// Returns the indices of each of the given axes.
|
||||||
///
|
///
|
||||||
/// 'axis_' can be an integer or a tag.
|
/// 'axis_' can be an integer or a tag.
|
||||||
pub fn axes(self: Tensor, axes_: anytype) std.BoundedArray(u3, Tensor.MAX_RANK) {
|
pub fn axes(self: Tensor, axes_: anytype) stdx.BoundedArray(u3, Tensor.MAX_RANK) {
|
||||||
return self._shape.axes(axes_);
|
return self._shape.axes(axes_);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -260,7 +259,7 @@ pub const Tensor = struct {
|
|||||||
/// For `reuseBuffer` to be effective, it needs to propagate all the way through the output.
|
/// For `reuseBuffer` to be effective, it needs to propagate all the way through the output.
|
||||||
pub fn reuseBuffer(self: Tensor, origin: Tensor) Tensor {
|
pub fn reuseBuffer(self: Tensor, origin: Tensor) Tensor {
|
||||||
// Note: check donation docs, this may be too permissive.
|
// Note: check donation docs, this may be too permissive.
|
||||||
stdx.debug.assert(self.byteSize() == origin.byteSize(), "Can't reuse buffers between tensors of different size: {} and {}", .{ self, origin });
|
stdx.debug.assert(self.byteSize() == origin.byteSize(), "Can't reuse buffers between tensors of different size: {f} and {f}", .{ self, origin });
|
||||||
|
|
||||||
// TODO: should we store all donations inside the context ?
|
// TODO: should we store all donations inside the context ?
|
||||||
var res = self;
|
var res = self;
|
||||||
@ -410,7 +409,7 @@ pub const Tensor = struct {
|
|||||||
stdx.debug.assert(self.dtype() == other.dtype(), "triangularSolve expects tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() });
|
stdx.debug.assert(self.dtype() == other.dtype(), "triangularSolve expects tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() });
|
||||||
stdx.debug.assert(self.rank() <= 2 and self.rank() == other.rank(), "triangularSolve expects tensors to have the same rank and be <= 2, got {} and {}", .{ self.rank(), other.rank() });
|
stdx.debug.assert(self.rank() <= 2 and self.rank() == other.rank(), "triangularSolve expects tensors to have the same rank and be <= 2, got {} and {}", .{ self.rank(), other.rank() });
|
||||||
|
|
||||||
const loc = self.getContext().location(@src(), "triangularSolve({_}, {})", .{ self, opts });
|
const loc = self.getContext().location(@src(), "triangularSolve({f}, {})", .{ self, opts });
|
||||||
const op = dialect.stablehlo.triangular_solve(self.getContext().mlirCtx(), self.value(), other.value(), loc, opts);
|
const op = dialect.stablehlo.triangular_solve(self.getContext().mlirCtx(), self.value(), other.value(), loc, opts);
|
||||||
return _result(self._shape, op.result(0));
|
return _result(self._shape, op.result(0));
|
||||||
}
|
}
|
||||||
@ -435,7 +434,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns a Tensor of complex number converted from a pair of real and imaginary Tensors.
|
/// Returns a Tensor of complex number converted from a pair of real and imaginary Tensors.
|
||||||
pub fn complex(re: Tensor, im: Tensor) Tensor {
|
pub fn complex(re: Tensor, im: Tensor) Tensor {
|
||||||
stdx.debug.assert(re._shape.eql(im._shape), "complex expects tensor shapes to match, got {} and {}", .{ re._shape, im._shape });
|
stdx.debug.assert(re._shape.eql(im._shape), "complex expects tensor shapes to match, got {f} and {f}", .{ re._shape, im._shape });
|
||||||
stdx.debug.assert(re.dtype() == .f32 or re.dtype() == .f64, "complex expects tensors type to be f32 or f64, got {}", .{re.dtype()});
|
stdx.debug.assert(re.dtype() == .f32 or re.dtype() == .f64, "complex expects tensors type to be f32 or f64, got {}", .{re.dtype()});
|
||||||
|
|
||||||
const loc = re.getContext().mlirCtx().location(@src());
|
const loc = re.getContext().mlirCtx().location(@src());
|
||||||
@ -523,7 +522,7 @@ pub const Tensor = struct {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
const loc = self.getContext().location(@src(), "fft({_},{})", .{ self, opts });
|
const loc = self.getContext().location(@src(), "fft({f},{})", .{ self, opts });
|
||||||
const op = dialect.stablehlo.fft(self.getContext().mlirCtx(), self.value(), loc, opts);
|
const op = dialect.stablehlo.fft(self.getContext().mlirCtx(), self.value(), loc, opts);
|
||||||
return _result(sh, op.result(0));
|
return _result(sh, op.result(0));
|
||||||
}
|
}
|
||||||
@ -551,7 +550,7 @@ pub const Tensor = struct {
|
|||||||
/// but it is not guaranteed to be deterministic between implementations.
|
/// but it is not guaranteed to be deterministic between implementations.
|
||||||
pub fn bitGenerator(self: Rng, sh: Shape) struct { Rng, Tensor } {
|
pub fn bitGenerator(self: Rng, sh: Shape) struct { Rng, Tensor } {
|
||||||
const ctx = CompilationContext.current();
|
const ctx = CompilationContext.current();
|
||||||
const loc = ctx.location(@src(), "rand.bitGen({_})", .{sh});
|
const loc = ctx.location(@src(), "rand.bitGen({f})", .{sh});
|
||||||
const op = dialect.stablehlo.rng_bit_generator(
|
const op = dialect.stablehlo.rng_bit_generator(
|
||||||
ctx.mlirCtx(),
|
ctx.mlirCtx(),
|
||||||
self.algorithm,
|
self.algorithm,
|
||||||
@ -589,7 +588,7 @@ pub const Tensor = struct {
|
|||||||
16 => .u16,
|
16 => .u16,
|
||||||
32 => .u32,
|
32 => .u32,
|
||||||
64 => .u64,
|
64 => .u64,
|
||||||
else => stdx.debug.panic("uniform don't support non-byte aligned dtype. Got: {}", .{shape_}),
|
else => stdx.debug.panic("uniform don't support non-byte aligned dtype. Got: {f}", .{shape_}),
|
||||||
};
|
};
|
||||||
|
|
||||||
const rng, const bits = self.bitGenerator(shape_.withDtype(uint_dtype));
|
const rng, const bits = self.bitGenerator(shape_.withDtype(uint_dtype));
|
||||||
@ -674,7 +673,7 @@ pub const Tensor = struct {
|
|||||||
stdx.debug.assert(sh.dtype().isFloat(), "normal expects tensor type to be a float, got {}", .{sh.dtype()});
|
stdx.debug.assert(sh.dtype().isFloat(), "normal expects tensor type to be a float, got {}", .{sh.dtype()});
|
||||||
|
|
||||||
const ctx = CompilationContext.current();
|
const ctx = CompilationContext.current();
|
||||||
const loc = ctx.location(@src(), "rand.normal({_}, mean={},stddev={})", .{ sh, opts.mean, opts.stddev });
|
const loc = ctx.location(@src(), "rand.normal({f}, mean={},stddev={})", .{ sh, opts.mean, opts.stddev });
|
||||||
const a = Tensor.constant(.{}, Data.init(sh.dtype(), opts.mean));
|
const a = Tensor.constant(.{}, Data.init(sh.dtype(), opts.mean));
|
||||||
const b = Tensor.constant(.{}, Data.init(sh.dtype(), opts.stddev));
|
const b = Tensor.constant(.{}, Data.init(sh.dtype(), opts.stddev));
|
||||||
const res_shape = Tensor.constantTensor(HostBuffer.fromSlice(.{sh.rank()}, sh.dims()));
|
const res_shape = Tensor.constantTensor(HostBuffer.fromSlice(.{sh.rank()}, sh.dims()));
|
||||||
@ -1046,7 +1045,7 @@ pub const Tensor = struct {
|
|||||||
if (to == self.dtype()) {
|
if (to == self.dtype()) {
|
||||||
return self;
|
return self;
|
||||||
}
|
}
|
||||||
const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) });
|
const loc = self.getContext().location(@src(), "convert({f},to={s})", .{ self, @tagName(to) });
|
||||||
|
|
||||||
const mlir_ctx = self.getContext().mlirCtx();
|
const mlir_ctx = self.getContext().mlirCtx();
|
||||||
const res_type = mlirx.tensorType(mlir_ctx, self.shape().withDtype(to));
|
const res_type = mlirx.tensorType(mlir_ctx, self.shape().withDtype(to));
|
||||||
@ -1160,7 +1159,7 @@ pub const Tensor = struct {
|
|||||||
) Tensor {
|
) Tensor {
|
||||||
stdx.debug.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {} and {}", .{ lhs.dtype(), rhs.dtype() });
|
stdx.debug.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {} and {}", .{ lhs.dtype(), rhs.dtype() });
|
||||||
|
|
||||||
const Axes = std.BoundedArray(i64, MAX_RANK);
|
const Axes = stdx.BoundedArray(i64, MAX_RANK);
|
||||||
|
|
||||||
var res_shape: Shape = .{ ._dtype = lhs.dtype() };
|
var res_shape: Shape = .{ ._dtype = lhs.dtype() };
|
||||||
// Validate batching axes
|
// Validate batching axes
|
||||||
@ -1168,7 +1167,7 @@ pub const Tensor = struct {
|
|||||||
var rhs_batching_axes: Axes = .{};
|
var rhs_batching_axes: Axes = .{};
|
||||||
for (batching_axes) |b_axes| {
|
for (batching_axes) |b_axes| {
|
||||||
const l, const r = b_axes;
|
const l, const r = b_axes;
|
||||||
stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs });
|
stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {} and {} in {f} and {f}", .{ l, r, lhs, rhs });
|
||||||
var t = lhs._shape.tag(l);
|
var t = lhs._shape.tag(l);
|
||||||
if (t == Shape.TagUnknown) t = rhs._shape.tag(r);
|
if (t == Shape.TagUnknown) t = rhs._shape.tag(r);
|
||||||
res_shape = res_shape.appendDim(lhs._shape.dim(l), t);
|
res_shape = res_shape.appendDim(lhs._shape.dim(l), t);
|
||||||
@ -1181,7 +1180,7 @@ pub const Tensor = struct {
|
|||||||
var rhs_contracting_axes: Axes = .{};
|
var rhs_contracting_axes: Axes = .{};
|
||||||
for (contracting_axes) |c_axes| {
|
for (contracting_axes) |c_axes| {
|
||||||
const l, const r = c_axes;
|
const l, const r = c_axes;
|
||||||
stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects contracting dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs });
|
stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects contracting dimensions to be equal, got {} and {} in {f} and {f}", .{ l, r, lhs, rhs });
|
||||||
lhs_contracting_axes.appendAssumeCapacity(lhs._shape.axis(l));
|
lhs_contracting_axes.appendAssumeCapacity(lhs._shape.axis(l));
|
||||||
rhs_contracting_axes.appendAssumeCapacity(rhs._shape.axis(r));
|
rhs_contracting_axes.appendAssumeCapacity(rhs._shape.axis(r));
|
||||||
}
|
}
|
||||||
@ -1209,7 +1208,7 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const mlir_ctx = lhs.getContext().mlirCtx();
|
const mlir_ctx = lhs.getContext().mlirCtx();
|
||||||
const loc = lhs.getContext().location(@src(), "dot({_},{_},contracting={any},batching={any}", .{ lhs, rhs, contracting_axes, batching_axes });
|
const loc = lhs.getContext().location(@src(), "dot({f},{f},contracting={any},batching={any}", .{ lhs, rhs, contracting_axes, batching_axes });
|
||||||
const op = dialect.stablehlo.dot_general(
|
const op = dialect.stablehlo.dot_general(
|
||||||
mlir_ctx,
|
mlir_ctx,
|
||||||
lhs.value(),
|
lhs.value(),
|
||||||
@ -1406,7 +1405,7 @@ pub const Tensor = struct {
|
|||||||
else
|
else
|
||||||
toI64(axes__);
|
toI64(axes__);
|
||||||
|
|
||||||
stdx.debug.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {_} and {d}", .{ self, permutation[0..@min(permutation.len, MAX_RANK + 2)] });
|
stdx.debug.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {f} and {any}", .{ self, permutation[0..@min(permutation.len, MAX_RANK + 2)] });
|
||||||
|
|
||||||
if (std.mem.eql(i64, permutation, no_op[0..self.rank()])) {
|
if (std.mem.eql(i64, permutation, no_op[0..self.rank()])) {
|
||||||
return self;
|
return self;
|
||||||
@ -1417,7 +1416,7 @@ pub const Tensor = struct {
|
|||||||
return self.reshape(res_shape);
|
return self.reshape(res_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
const loc = self.getContext().location(@src(), "transpose({_}, {d})", .{ self, permutation });
|
const loc = self.getContext().location(@src(), "transpose({f}, {any})", .{ self, permutation });
|
||||||
const op = dialect.stablehlo.transpose(
|
const op = dialect.stablehlo.transpose(
|
||||||
self.getContext().mlirCtx(),
|
self.getContext().mlirCtx(),
|
||||||
self.value(),
|
self.value(),
|
||||||
@ -1505,7 +1504,7 @@ pub const Tensor = struct {
|
|||||||
// stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis });
|
// stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis });
|
||||||
const new_shape = old_shape.remove(a + 1).set(a, old_shape.dim(a) * old_shape.dim(a + 1));
|
const new_shape = old_shape.remove(a + 1).set(a, old_shape.dim(a) * old_shape.dim(a + 1));
|
||||||
|
|
||||||
const loc = self.getContext().location(@src(), "flatten({_},{})", .{ self, axis_ });
|
const loc = self.getContext().location(@src(), "flatten({f},{})", .{ self, axis_ });
|
||||||
const reshaped_val = dialect.stablehlo.reshape(
|
const reshaped_val = dialect.stablehlo.reshape(
|
||||||
self.getContext().mlirCtx(),
|
self.getContext().mlirCtx(),
|
||||||
self.value(),
|
self.value(),
|
||||||
@ -1684,7 +1683,7 @@ pub const Tensor = struct {
|
|||||||
const res_shape = shape0.insertTag(axis_, 1, tag);
|
const res_shape = shape0.insertTag(axis_, 1, tag);
|
||||||
|
|
||||||
for (tensors[1..]) |tensor| {
|
for (tensors[1..]) |tensor| {
|
||||||
stdx.debug.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {} and {}", .{ shape0, tensor._shape });
|
stdx.debug.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {f} and {f}", .{ shape0, tensor._shape });
|
||||||
}
|
}
|
||||||
|
|
||||||
var reshaped: [32]Tensor = undefined;
|
var reshaped: [32]Tensor = undefined;
|
||||||
@ -1859,7 +1858,7 @@ pub const Tensor = struct {
|
|||||||
const dt: DataType = if (sh.dim(a) <= std.math.maxInt(i32)) .i32 else .i64;
|
const dt: DataType = if (sh.dim(a) <= std.math.maxInt(i32)) .i32 else .i64;
|
||||||
const res_shape = sh.withDtype(dt);
|
const res_shape = sh.withDtype(dt);
|
||||||
const ctx = CompilationContext.current();
|
const ctx = CompilationContext.current();
|
||||||
const loc = ctx.location(@src(), "iota({_}, {})", .{ res_shape, a });
|
const loc = ctx.location(@src(), "iota({f}, {})", .{ res_shape, a });
|
||||||
|
|
||||||
const mlir_ctx = ctx.mlirCtx();
|
const mlir_ctx = ctx.mlirCtx();
|
||||||
var op = dialect.stablehlo.iota(
|
var op = dialect.stablehlo.iota(
|
||||||
@ -1931,7 +1930,7 @@ pub const Tensor = struct {
|
|||||||
pub fn constant(dimz: anytype, val: Data) Tensor {
|
pub fn constant(dimz: anytype, val: Data) Tensor {
|
||||||
const sh = Shape.init(dimz, val.dtype());
|
const sh = Shape.init(dimz, val.dtype());
|
||||||
const ctx = CompilationContext.current().mlirCtx();
|
const ctx = CompilationContext.current().mlirCtx();
|
||||||
const loc = CompilationContext.current().location(@src(), "dims={d}, value={}", .{ sh, val });
|
const loc = CompilationContext.current().location(@src(), "dims={f}, value={}", .{ sh, val });
|
||||||
|
|
||||||
var constant_op = if (mlirx.denseElementAttrType(val.dtype())) |elem_type|
|
var constant_op = if (mlirx.denseElementAttrType(val.dtype())) |elem_type|
|
||||||
dialect.stablehlo.constant(ctx, &.{}, elem_type, val.constSlice(), loc)
|
dialect.stablehlo.constant(ctx, &.{}, elem_type, val.constSlice(), loc)
|
||||||
@ -1951,7 +1950,7 @@ pub const Tensor = struct {
|
|||||||
pub fn constantTensor(val: HostBuffer) Tensor {
|
pub fn constantTensor(val: HostBuffer) Tensor {
|
||||||
const ctx = CompilationContext.current().mlirCtx();
|
const ctx = CompilationContext.current().mlirCtx();
|
||||||
const loc = ctx.location(@src());
|
const loc = ctx.location(@src());
|
||||||
const elem_type = mlirx.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()});
|
const elem_type = mlirx.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {f}", .{val.shape()});
|
||||||
const constant_op = dialect.stablehlo.constant(ctx, val.shape().dims(), elem_type, val.bytes(), loc);
|
const constant_op = dialect.stablehlo.constant(ctx, val.shape().dims(), elem_type, val.bytes(), loc);
|
||||||
return _result(val.shape(), constant_op.result(0));
|
return _result(val.shape(), constant_op.result(0));
|
||||||
}
|
}
|
||||||
@ -1975,10 +1974,10 @@ pub const Tensor = struct {
|
|||||||
/// you will lose the tags.
|
/// you will lose the tags.
|
||||||
/// To avoid use favorise `.broad(shape)` when working with tagged tensors.
|
/// To avoid use favorise `.broad(shape)` when working with tagged tensors.
|
||||||
pub fn broadcast(self: Tensor, output_shape: Shape, axes_: []const i64) Tensor {
|
pub fn broadcast(self: Tensor, output_shape: Shape, axes_: []const i64) Tensor {
|
||||||
stdx.debug.assert(axes_.len == self.rank(), "broadcast expects axes_ to map all axes from self to axes of the output shape, got broadcast({}, {}, {d})", .{ self, output_shape, axes_ });
|
stdx.debug.assert(axes_.len == self.rank(), "broadcast expects axes_ to map all axes from self to axes of the output shape, got broadcast({f}, {f}, {any})", .{ self, output_shape, axes_ });
|
||||||
for (0.., axes_) |self_ax, other_ax| {
|
for (0.., axes_) |self_ax, other_ax| {
|
||||||
const d = self.dim(self_ax);
|
const d = self.dim(self_ax);
|
||||||
stdx.debug.assert(d == 1 or d == output_shape.dim(other_ax), "broadcast expects shape axes to either be 1-sized or to match the target size. got broadcast({}, {}, {d}), error on self axis {} mapping to other axis {}", .{ self, output_shape, axes_, self_ax, other_ax });
|
stdx.debug.assert(d == 1 or d == output_shape.dim(other_ax), "broadcast expects shape axes to either be 1-sized or to match the target size. got broadcast({f}, {f}, {any}), error on self axis {d} mapping to other axis {d}", .{ self, output_shape, axes_, self_ax, other_ax });
|
||||||
}
|
}
|
||||||
|
|
||||||
const res_shape = output_shape.withDtype(self.dtype());
|
const res_shape = output_shape.withDtype(self.dtype());
|
||||||
@ -1989,7 +1988,7 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
const ctx = self.getContext();
|
const ctx = self.getContext();
|
||||||
const result_type = mlirx.tensorType(ctx.mlirCtx(), res_shape);
|
const result_type = mlirx.tensorType(ctx.mlirCtx(), res_shape);
|
||||||
const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ });
|
const loc = ctx.location(@src(), "broadcast({f}, {f}, axes={any})", .{ self, res_shape, axes_ });
|
||||||
const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc);
|
const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc);
|
||||||
|
|
||||||
return _result(res_shape, broadcast_op.result(0));
|
return _result(res_shape, broadcast_op.result(0));
|
||||||
@ -1997,7 +1996,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Broadcasts a Tensor to the given shape, adding axes at the beginning.
|
/// Broadcasts a Tensor to the given shape, adding axes at the beginning.
|
||||||
pub fn broadcastLeft(self: Tensor, output_shape: Shape) Tensor {
|
pub fn broadcastLeft(self: Tensor, output_shape: Shape) Tensor {
|
||||||
stdx.debug.assert(self.rank() <= output_shape.rank(), "broadcastLeft expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() });
|
stdx.debug.assert(self.rank() <= output_shape.rank(), "broadcastLeft expects tensor rank to be less than output tensor rank, got {d} and {d}", .{ self.rank(), output_shape.rank() });
|
||||||
|
|
||||||
const a = output_shape.rank() - self.rank();
|
const a = output_shape.rank() - self.rank();
|
||||||
if (self.rank() == output_shape.rank() and std.mem.eql(i64, self.dims(), output_shape.dims())) {
|
if (self.rank() == output_shape.rank() and std.mem.eql(i64, self.dims(), output_shape.dims())) {
|
||||||
@ -2009,7 +2008,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Broadcasts a Tensor to the given shape, adding axes at the end.
|
/// Broadcasts a Tensor to the given shape, adding axes at the end.
|
||||||
pub fn broadcastRight(self: Tensor, output_shape: Shape) Tensor {
|
pub fn broadcastRight(self: Tensor, output_shape: Shape) Tensor {
|
||||||
stdx.debug.assert(self.rank() <= output_shape.rank(), "broadcastRight expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() });
|
stdx.debug.assert(self.rank() <= output_shape.rank(), "broadcastRight expects tensor rank to be less than output tensor rank, got {d} and {d}", .{ self.rank(), output_shape.rank() });
|
||||||
|
|
||||||
if (self.rank() == output_shape.rank() and self._shape.eql(output_shape)) {
|
if (self.rank() == output_shape.rank() and self._shape.eql(output_shape)) {
|
||||||
return self;
|
return self;
|
||||||
@ -2022,7 +2021,7 @@ pub const Tensor = struct {
|
|||||||
pub fn broad(self: Tensor, other: Shape) Tensor {
|
pub fn broad(self: Tensor, other: Shape) Tensor {
|
||||||
// TODO: broad is too restrictive because sometime you only want to specify one specific axis
|
// TODO: broad is too restrictive because sometime you only want to specify one specific axis
|
||||||
// Note: if you code below, make sure to update Shape.canBroadcastTo.
|
// Note: if you code below, make sure to update Shape.canBroadcastTo.
|
||||||
stdx.debug.assert(self._shape.canBroadcastTo(other), "Can't broadcast {} to {}", .{ self, other });
|
stdx.debug.assert(self._shape.canBroadcastTo(other), "Can't broadcast {f} to {f}", .{ self, other });
|
||||||
|
|
||||||
// Already the right shape
|
// Already the right shape
|
||||||
if (std.mem.eql(i64, self.dims(), other.dims())) return self;
|
if (std.mem.eql(i64, self.dims(), other.dims())) return self;
|
||||||
@ -2036,7 +2035,7 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check that each axis of self maps to an axis of other
|
// check that each axis of self maps to an axis of other
|
||||||
var axes_: std.BoundedArray(i64, MAX_RANK) = .{};
|
var axes_: stdx.BoundedArray(i64, MAX_RANK) = .{};
|
||||||
for (self._shape.tags()) |t| {
|
for (self._shape.tags()) |t| {
|
||||||
axes_.appendAssumeCapacity(@intCast(other.axis(t)));
|
axes_.appendAssumeCapacity(@intCast(other.axis(t)));
|
||||||
}
|
}
|
||||||
@ -2047,14 +2046,14 @@ pub const Tensor = struct {
|
|||||||
pub fn reshape(self: Tensor, output_shape_: anytype) Tensor {
|
pub fn reshape(self: Tensor, output_shape_: anytype) Tensor {
|
||||||
const output_shape = self._shape.reshape(output_shape_);
|
const output_shape = self._shape.reshape(output_shape_);
|
||||||
const tensor_type = mlirx.tensorType(self.getContext().mlirCtx(), output_shape);
|
const tensor_type = mlirx.tensorType(self.getContext().mlirCtx(), output_shape);
|
||||||
const loc = self.getContext().location(@src(), "reshape({any})", .{output_shape});
|
const loc = self.getContext().location(@src(), "reshape({f})", .{output_shape});
|
||||||
const reshape_value = dialect.stablehlo.reshape(self.getContext().mlirCtx(), self.value(), tensor_type, loc);
|
const reshape_value = dialect.stablehlo.reshape(self.getContext().mlirCtx(), self.value(), tensor_type, loc);
|
||||||
return _result(output_shape, reshape_value.result(0));
|
return _result(output_shape, reshape_value.result(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Converts the given 1 element Tensor into a 0-rank Tensor.
|
/// Converts the given 1 element Tensor into a 0-rank Tensor.
|
||||||
pub fn asScalar(self: Tensor) Tensor {
|
pub fn asScalar(self: Tensor) Tensor {
|
||||||
stdx.debug.assert(self.count() == 1, "Tensor.asScalar expects an input with exactly 1-element got {}", .{self});
|
stdx.debug.assert(self.count() == 1, "Tensor.asScalar expects an input with exactly 1-element got {f}", .{self});
|
||||||
return self.reshape(.{});
|
return self.reshape(.{});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2177,12 +2176,12 @@ pub const Tensor = struct {
|
|||||||
stdx.debug.assert(coord_axes_.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{});
|
stdx.debug.assert(coord_axes_.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{});
|
||||||
for (coord_axes_.constSlice(), 0..) |a, i| {
|
for (coord_axes_.constSlice(), 0..) |a, i| {
|
||||||
if (i > 0) {
|
if (i > 0) {
|
||||||
stdx.debug.assert(a == coord_axes_.get(i - 1) + 1, "gatherValues expects 'coord_axes' to be sequential. But {any} aren't sequential in {}", .{ coord_axes, self });
|
stdx.debug.assert(a == coord_axes_.get(i - 1) + 1, "gatherValues expects 'coord_axes' to be sequential. But {any} aren't sequential in {f}", .{ coord_axes, self });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const AxisKind = enum { batching, offset, collapsed, indices };
|
const AxisKind = enum { batching, offset, collapsed, indices };
|
||||||
var self_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{};
|
var self_kind: stdx.BoundedArray(AxisKind, MAX_RANK) = .{};
|
||||||
var indices_batch_axes: Shape.DimsArray = .{};
|
var indices_batch_axes: Shape.DimsArray = .{};
|
||||||
for (self._shape.tags(), 0..self.rank()) |t, self_ax| {
|
for (self._shape.tags(), 0..self.rank()) |t, self_ax| {
|
||||||
const maybe_coord_ax = std.mem.indexOfScalar(u3, coord_axes_.constSlice(), @intCast(self_ax));
|
const maybe_coord_ax = std.mem.indexOfScalar(u3, coord_axes_.constSlice(), @intCast(self_ax));
|
||||||
@ -2191,7 +2190,7 @@ pub const Tensor = struct {
|
|||||||
// Note: tags are required for batching.
|
// Note: tags are required for batching.
|
||||||
self_kind.appendAssumeCapacity(.batching);
|
self_kind.appendAssumeCapacity(.batching);
|
||||||
indices_batch_axes.appendAssumeCapacity(id_ax);
|
indices_batch_axes.appendAssumeCapacity(id_ax);
|
||||||
stdx.debug.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. Axis {s} has been found both in 'self={any}', in 'coord_axes_={any}' and in 'indices={}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices });
|
stdx.debug.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. Axis {s} has been found both in 'self={f}', in 'coord_axes_={any}' and in 'indices={f}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices });
|
||||||
} else if (maybe_coord_ax) |_| {
|
} else if (maybe_coord_ax) |_| {
|
||||||
// for gatherValues we collapsed all gathered axes
|
// for gatherValues we collapsed all gathered axes
|
||||||
// (contrary to gatherSlices where we collapse none)
|
// (contrary to gatherSlices where we collapse none)
|
||||||
@ -2208,13 +2207,13 @@ pub const Tensor = struct {
|
|||||||
indices.rank()
|
indices.rank()
|
||||||
else blk: {
|
else blk: {
|
||||||
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
|
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
|
||||||
stdx.debug.assert(indices.dim(ax) == coord_axes_.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ coord_axes, coord_axes_.len, indices });
|
stdx.debug.assert(indices.dim(ax) == coord_axes_.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {f}", .{ coord_axes, coord_axes_.len, indices });
|
||||||
break :blk ax;
|
break :blk ax;
|
||||||
};
|
};
|
||||||
|
|
||||||
// compute res shape
|
// compute res shape
|
||||||
var res_shape = Shape.init(.{}, self.dtype());
|
var res_shape = Shape.init(.{}, self.dtype());
|
||||||
var res_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{};
|
var res_kind: stdx.BoundedArray(AxisKind, MAX_RANK) = .{};
|
||||||
for (self_kind.constSlice(), 0..) |kind, ax_usize| {
|
for (self_kind.constSlice(), 0..) |kind, ax_usize| {
|
||||||
const ax: u3 = @intCast(ax_usize);
|
const ax: u3 = @intCast(ax_usize);
|
||||||
if (ax == coord_axes_.get(0)) {
|
if (ax == coord_axes_.get(0)) {
|
||||||
@ -2275,7 +2274,7 @@ pub const Tensor = struct {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const mlir_shape = fromMlirValue(gather_op.result(0)).shape();
|
const mlir_shape = fromMlirValue(gather_op.result(0)).shape();
|
||||||
stdx.debug.assert(mlir_shape.eql(res_shape), "gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={}, indices={}. You should transpose one or the other.", .{ self, indices });
|
stdx.debug.assert(mlir_shape.eql(res_shape), "gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={f}, indices={f}. You should transpose one or the other.", .{ self, indices });
|
||||||
return _result(res_shape, gather_op.result(0));
|
return _result(res_shape, gather_op.result(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2347,29 +2346,29 @@ pub const Tensor = struct {
|
|||||||
/// and gatherSlices can copy data by group of C'*D elements.
|
/// and gatherSlices can copy data by group of C'*D elements.
|
||||||
pub fn gatherSlices(self: Tensor, slice_shape_: anytype, indices: Tensor, opts: GatherOpts) Tensor {
|
pub fn gatherSlices(self: Tensor, slice_shape_: anytype, indices: Tensor, opts: GatherOpts) Tensor {
|
||||||
const slice_shape = if (@TypeOf(slice_shape_) == Shape) slice_shape_ else Shape.init(slice_shape_, .i32);
|
const slice_shape = if (@TypeOf(slice_shape_) == Shape) slice_shape_ else Shape.init(slice_shape_, .i32);
|
||||||
// scoped_log.debug("gatherSlice({}, {_}, {})", .{ self, slice_shape, indices });
|
// scoped_log.debug("gatherSlice({}, {f}, {})", .{ self, slice_shape, indices });
|
||||||
|
|
||||||
const tagged_api = slice_shape.isFullyTagged();
|
const tagged_api = slice_shape.isFullyTagged();
|
||||||
if (tagged_api) {
|
if (tagged_api) {
|
||||||
for (slice_shape.tags()) |t| {
|
for (slice_shape.tags()) |t| {
|
||||||
stdx.debug.assert(self._shape.hasTag(t) != null, "gatherSlices expects `slices_shape` to only use tags from `self`. But {s} wasn't found in {}", .{ t, self });
|
stdx.debug.assert(self._shape.hasTag(t) != null, "gatherSlices expects `slices_shape` to only use tags from `self`. But {s} wasn't found in {f}", .{ t, self });
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// For untagged api, we require all slices to be specified.
|
// For untagged api, we require all slices to be specified.
|
||||||
// Note: we could relax this and right align the slice.
|
// Note: we could relax this and right align the slice.
|
||||||
stdx.debug.assert(slice_shape.rank() == self.rank(), "gatherSlices expects `slice_shape.rank()` to match `self.rank()`. Got: gatherSlices({}, slice={_}). To avoid specifying all axes in `slice_shape`, you can use tags.", .{ self, slice_shape });
|
stdx.debug.assert(slice_shape.rank() == self.rank(), "gatherSlices expects `slice_shape.rank()` to match `self.rank()`. Got: gatherSlices({f}, slice={f}). To avoid specifying all axes in `slice_shape`, you can use tags.", .{ self, slice_shape });
|
||||||
}
|
}
|
||||||
|
|
||||||
const index_coord_axis = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
|
const index_coord_axis = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
|
||||||
stdx.debug.assert(indices.dim(index_coord_axis) == slice_shape.rank(), "gatherSlices({}, slice={_}, indices) expects 'indices' to be a tensor [..., {}], got {}", .{ self, slice_shape, slice_shape.rank(), indices });
|
stdx.debug.assert(indices.dim(index_coord_axis) == slice_shape.rank(), "gatherSlices({f}, slice={f}, indices) expects 'indices' to be a tensor [..., {}], got {f}", .{ self, slice_shape, slice_shape.rank(), indices });
|
||||||
|
|
||||||
// Compute result shape
|
// Compute result shape
|
||||||
var res_shape = indices._shape.remove(index_coord_axis).withDtype(self.dtype());
|
var res_shape = indices._shape.remove(index_coord_axis).withDtype(self.dtype());
|
||||||
var slice_dims = self._shape._dims;
|
var slice_dims = self._shape._dims;
|
||||||
var self_batch_axes: std.BoundedArray(i64, MAX_RANK) = .{};
|
var self_batch_axes: stdx.BoundedArray(i64, MAX_RANK) = .{};
|
||||||
var indices_batch_axes: std.BoundedArray(i64, MAX_RANK) = .{};
|
var indices_batch_axes: stdx.BoundedArray(i64, MAX_RANK) = .{};
|
||||||
var start_index_map: std.BoundedArray(i64, MAX_RANK) = .{};
|
var start_index_map: stdx.BoundedArray(i64, MAX_RANK) = .{};
|
||||||
var self_offset_axes: std.BoundedArray(i64, MAX_RANK) = .{};
|
var self_offset_axes: stdx.BoundedArray(i64, MAX_RANK) = .{};
|
||||||
for (self._shape.tags(), 0..self.rank()) |t, self_ax| {
|
for (self._shape.tags(), 0..self.rank()) |t, self_ax| {
|
||||||
const maybe_slice_ax: ?u3 = if (tagged_api) slice_shape.hasTag(t) else @intCast(self_ax);
|
const maybe_slice_ax: ?u3 = if (tagged_api) slice_shape.hasTag(t) else @intCast(self_ax);
|
||||||
|
|
||||||
@ -2379,12 +2378,12 @@ pub const Tensor = struct {
|
|||||||
self_batch_axes.appendAssumeCapacity(@intCast(self_ax));
|
self_batch_axes.appendAssumeCapacity(@intCast(self_ax));
|
||||||
indices_batch_axes.appendAssumeCapacity(indices._shape.axis(t));
|
indices_batch_axes.appendAssumeCapacity(indices._shape.axis(t));
|
||||||
slice_dims.set(self_ax, 1);
|
slice_dims.set(self_ax, 1);
|
||||||
stdx.debug.assert(slice_shape.hasTag(t) == null, "gatherSlices expect axes to be either batches or slices axes. Axis {s} has been found both in `slices={_}` and `indices={}`", .{ t, slice_shape, indices });
|
stdx.debug.assert(slice_shape.hasTag(t) == null, "gatherSlices expect axes to be either batches or slices axes. Axis {s} has been found both in `slices={f}` and `indices={f}`", .{ t, slice_shape, indices });
|
||||||
} else if (maybe_slice_ax) |slice_ax| {
|
} else if (maybe_slice_ax) |slice_ax| {
|
||||||
// Specified axes contains the start offset of the slices,
|
// Specified axes contains the start offset of the slices,
|
||||||
// and are collected in `start_index_map`.
|
// and are collected in `start_index_map`.
|
||||||
const slice_dim = slice_shape.dim(slice_ax);
|
const slice_dim = slice_shape.dim(slice_ax);
|
||||||
stdx.debug.assert(slice_dim <= self._shape.dim(self_ax), "gatherSlices expects `slice_shape` to be smaller than `self.shape()`. On axis {s}, got {} > {}.", .{ t, slice_shape, self._shape });
|
stdx.debug.assert(slice_dim <= self._shape.dim(self_ax), "gatherSlices expects `slice_shape` to be smaller than `self.shape()`. On axis {s}, got {f} > {f}.", .{ t, slice_shape, self._shape });
|
||||||
slice_dims.set(self_ax, slice_dim);
|
slice_dims.set(self_ax, slice_dim);
|
||||||
res_shape = res_shape.appendDim(slice_dim, t);
|
res_shape = res_shape.appendDim(slice_dim, t);
|
||||||
start_index_map.appendAssumeCapacity(@intCast(self_ax));
|
start_index_map.appendAssumeCapacity(@intCast(self_ax));
|
||||||
@ -2396,7 +2395,7 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const loc = self.getContext().location(@src(), "gatherSlices({_}, slice_shape={_}, idx={_})", .{ self, slice_shape, indices });
|
const loc = self.getContext().location(@src(), "gatherSlices({f}, slice_shape={f}, idx={f})", .{ self, slice_shape, indices });
|
||||||
const gather_op = dialect.stablehlo.gather(
|
const gather_op = dialect.stablehlo.gather(
|
||||||
self.getContext().mlirCtx(),
|
self.getContext().mlirCtx(),
|
||||||
self.value(),
|
self.value(),
|
||||||
@ -3172,7 +3171,7 @@ pub const Tensor = struct {
|
|||||||
/// Note: this doesn't support tagging, if you have tags,
|
/// Note: this doesn't support tagging, if you have tags,
|
||||||
/// you should use `dynamicSlice` directly.
|
/// you should use `dynamicSlice` directly.
|
||||||
pub fn dynamicSlice1d(self: Tensor, axis_: i8, slice_: DynSlice) Tensor {
|
pub fn dynamicSlice1d(self: Tensor, axis_: i8, slice_: DynSlice) Tensor {
|
||||||
stdx.debug.assert(slice_.start.rank() == 0, "dynamicSlice1d expects 'slice_.start' tensor rank to be a scalar, got {}", .{slice_.start});
|
stdx.debug.assert(slice_.start.rank() == 0, "dynamicSlice1d expects 'slice_.start' tensor rank to be a scalar, got {f}", .{slice_.start});
|
||||||
|
|
||||||
const a = self.axis(axis_);
|
const a = self.axis(axis_);
|
||||||
const new_shape = self._shape.set(a, slice_.len);
|
const new_shape = self._shape.set(a, slice_.len);
|
||||||
@ -3226,17 +3225,17 @@ pub const Tensor = struct {
|
|||||||
const offset = slice_.start;
|
const offset = slice_.start;
|
||||||
const len = slice_.len;
|
const len = slice_.len;
|
||||||
if (slices_tags.len == 0) {
|
if (slices_tags.len == 0) {
|
||||||
stdx.debug.assert(self.rank() == slices.len, "dynamicSlice expects tensor rank and 'slices_' length to be equal, got {} and {}", .{ self.rank(), slices.len });
|
stdx.debug.assert(self.rank() == slices.len, "dynamicSlice expects tensor rank and 'slices_' length to be equal, got {d} and {d}", .{ self.rank(), slices.len });
|
||||||
|
|
||||||
offset_values[i] = offset.value();
|
offset_values[i] = offset.value();
|
||||||
res_shape._dims.set(i, len);
|
res_shape._dims.set(i, len);
|
||||||
|
|
||||||
stdx.debug.assert(len <= self.dim(i), "dynamicSlice expects slices 'len' to be less than or equal to their corresponding dimension in input tensor, got {} and {} for index {}", .{ len, self.dim(i), i });
|
stdx.debug.assert(len <= self.dim(i), "dynamicSlice expects slices 'len' to be less than or equal to their corresponding dimension in input tensor, got {d} and {d} for index {d}", .{ len, self.dim(i), i });
|
||||||
} else {
|
} else {
|
||||||
const t = slices_tags.get(i);
|
const t = slices_tags.get(i);
|
||||||
const a = res_shape.hasTag(t) orelse stdx.debug.panic("dynamicSlice expects input tensor to have tags used in 'slices_' but {s} is missing (input shape is {})", .{ t, self._shape });
|
const a = res_shape.hasTag(t) orelse stdx.debug.panic("dynamicSlice expects input tensor to have tags used in 'slices_' but {s} is missing (input shape is {f})", .{ t, self._shape });
|
||||||
|
|
||||||
stdx.debug.assert(len <= self.dim(a), "dynamicSlice expects slices 'len' to be less than their corresponding dimension in input tensor, got {} and {} for axis {s}", .{ len, self.dim(a), t });
|
stdx.debug.assert(len <= self.dim(a), "dynamicSlice expects slices 'len' to be less than their corresponding dimension in input tensor, got {d} and {d} for axis {s}", .{ len, self.dim(a), t });
|
||||||
|
|
||||||
offset_values[a] = offset.value();
|
offset_values[a] = offset.value();
|
||||||
res_shape._dims.set(a, len);
|
res_shape._dims.set(a, len);
|
||||||
@ -3304,14 +3303,14 @@ pub const Tensor = struct {
|
|||||||
if (tagged_api) {
|
if (tagged_api) {
|
||||||
// Check that all update tags are known.
|
// Check that all update tags are known.
|
||||||
for (update._shape._tags.constSlice()) |t| {
|
for (update._shape._tags.constSlice()) |t| {
|
||||||
stdx.debug.assert(self._shape.hasTag(t) != null, "dynamicUpdateSlice expects 'update_' tensor tags to be a subset of input tensor tags but {s} is missing (input shape is {})", .{ t, self._shape });
|
stdx.debug.assert(self._shape.hasTag(t) != null, "dynamicUpdateSlice expects 'update_' tensor tags to be a subset of input tensor tags but {s} is missing (input shape is {f})", .{ t, self._shape });
|
||||||
}
|
}
|
||||||
|
|
||||||
var update_shape = self._shape;
|
var update_shape = self._shape;
|
||||||
var prev_ax: i8 = -1;
|
var prev_ax: i8 = -1;
|
||||||
for (self._shape.tags(), 0..) |t, self_ax| {
|
for (self._shape.tags(), 0..) |t, self_ax| {
|
||||||
if (update._shape.hasTag(t)) |up_ax| {
|
if (update._shape.hasTag(t)) |up_ax| {
|
||||||
stdx.debug.assert(up_ax == prev_ax + 1, "dynamicUpdateSlice expects 'update_' and input tensor axis to have the same order, got {} and {}. (hint: you need to explicitly transpose 'update_')", .{ update_, self });
|
stdx.debug.assert(up_ax == prev_ax + 1, "dynamicUpdateSlice expects 'update_' and input tensor axis to have the same order, got {f} and {f}. (hint: you need to explicitly transpose 'update_')", .{ update_, self });
|
||||||
|
|
||||||
update_shape._dims.set(self_ax, update.dim(up_ax));
|
update_shape._dims.set(self_ax, update.dim(up_ax));
|
||||||
prev_ax = up_ax;
|
prev_ax = up_ax;
|
||||||
@ -3322,7 +3321,7 @@ pub const Tensor = struct {
|
|||||||
update = update.reshape(update_shape);
|
update = update.reshape(update_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
stdx.debug.assert(self.rank() == update.rank(), "dynamicUpdateSlice expects input and computed update tensors to have the same rank, got {} and {}", .{ self, update });
|
stdx.debug.assert(self.rank() == update.rank(), "dynamicUpdateSlice expects input and computed update tensors to have the same rank, got {f} and {f}", .{ self, update });
|
||||||
|
|
||||||
for (self.dims(), update.dims(), 0..) |self_d, up_d, ax| {
|
for (self.dims(), update.dims(), 0..) |self_d, up_d, ax| {
|
||||||
const t = self._shape.debugTag(ax);
|
const t = self._shape.debugTag(ax);
|
||||||
@ -3350,7 +3349,7 @@ pub const Tensor = struct {
|
|||||||
// This is only allowed when using tagged sliced.
|
// This is only allowed when using tagged sliced.
|
||||||
offset_values = .{zero} ** MAX_RANK;
|
offset_values = .{zero} ** MAX_RANK;
|
||||||
for (offset.constSlice(), offset_tags.constSlice()) |start, t| {
|
for (offset.constSlice(), offset_tags.constSlice()) |start, t| {
|
||||||
const a = self._shape.hasTag(t) orelse stdx.debug.panic("dynamicUpdateSlice expects input tensor to have tags used in 'offset_' but {s} is missing (input shape is {})", .{ t, self._shape });
|
const a = self._shape.hasTag(t) orelse stdx.debug.panic("dynamicUpdateSlice expects input tensor to have tags used in 'offset_' but {s} is missing (input shape is {f})", .{ t, self._shape });
|
||||||
offset_values[a] = start.value();
|
offset_values[a] = start.value();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3469,12 +3468,12 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns a Tensor containing the element-wise result of the given 'cmp' comparison between the two input Tensors.
|
/// Returns a Tensor containing the element-wise result of the given 'cmp' comparison between the two input Tensors.
|
||||||
pub fn cmp(self: Tensor, direction: dialect.stablehlo.ComparisonDirection.Direction, other: Tensor) Tensor {
|
pub fn cmp(self: Tensor, direction: dialect.stablehlo.ComparisonDirection.Direction, other: Tensor) Tensor {
|
||||||
stdx.debug.assert(self.dtype() == other.dtype(), "cmp expects input tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() });
|
stdx.debug.assert(self.dtype() == other.dtype(), "cmp expects input tensors to be of the same type, got {t} and {t}", .{ self.dtype(), other.dtype() });
|
||||||
|
|
||||||
if (self.rank() == 0 and other.rank() != 0) return self.broadcast(other._shape, &.{}).cmp(direction, other);
|
if (self.rank() == 0 and other.rank() != 0) return self.broadcast(other._shape, &.{}).cmp(direction, other);
|
||||||
if (self.rank() != 0 and other.rank() == 0) return self.cmp(direction, other.broadcast(self._shape, &.{}));
|
if (self.rank() != 0 and other.rank() == 0) return self.cmp(direction, other.broadcast(self._shape, &.{}));
|
||||||
|
|
||||||
stdx.debug.assert(self._shape.eql(other._shape), "cmp expects input tensor shapes to match, got {} and {}", .{ self._shape, other._shape });
|
stdx.debug.assert(self._shape.eql(other._shape), "cmp expects input tensor shapes to match, got {f} and {f}", .{ self._shape, other._shape });
|
||||||
|
|
||||||
const loc = self.getContext().location(@src(), "cmp(.{s})", .{@tagName(direction)});
|
const loc = self.getContext().location(@src(), "cmp(.{s})", .{@tagName(direction)});
|
||||||
const op = dialect.stablehlo.compare(
|
const op = dialect.stablehlo.compare(
|
||||||
@ -3492,7 +3491,7 @@ pub const Tensor = struct {
|
|||||||
/// For each vector in the input tensor,
|
/// For each vector in the input tensor,
|
||||||
/// creates a diagonal-matrix where diagonal values are set to the vector values.
|
/// creates a diagonal-matrix where diagonal values are set to the vector values.
|
||||||
pub fn toDiagonal(self: Tensor, axis_: anytype, new_tags: [2]EnumLiteral) Tensor {
|
pub fn toDiagonal(self: Tensor, axis_: anytype, new_tags: [2]EnumLiteral) Tensor {
|
||||||
stdx.debug.assert(self.rank() < MAX_RANK - 1, "toDiagonal expects input up to {} rank, got {}", .{ MAX_RANK - 1, self });
|
stdx.debug.assert(self.rank() < MAX_RANK - 1, "toDiagonal expects input up to {d} rank, got {f}", .{ MAX_RANK - 1, self });
|
||||||
const a = self.axis(axis_);
|
const a = self.axis(axis_);
|
||||||
const d = self.dim(a);
|
const d = self.dim(a);
|
||||||
var res_shape = self._shape;
|
var res_shape = self._shape;
|
||||||
@ -3622,8 +3621,8 @@ pub const Tensor = struct {
|
|||||||
return bool_tensor.select(on_true, on_false.broad(bool_tensor.shape()));
|
return bool_tensor.select(on_true, on_false.broad(bool_tensor.shape()));
|
||||||
}
|
}
|
||||||
|
|
||||||
stdx.debug.assert(bool_tensor._shape.eqlDims(on_true._shape), "select expects input tensor and 'on_true' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_true._shape });
|
stdx.debug.assert(bool_tensor._shape.eqlDims(on_true._shape), "select expects input tensor and 'on_true' tensor dimensions to match, got {f} and {f}", .{ bool_tensor._shape, on_true._shape });
|
||||||
stdx.debug.assert(bool_tensor._shape.eqlDims(on_false._shape), "select expects input tensor and 'on_false' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_false._shape });
|
stdx.debug.assert(bool_tensor._shape.eqlDims(on_false._shape), "select expects input tensor and 'on_false' tensor dimensions to match, got {f} and {f}", .{ bool_tensor._shape, on_false._shape });
|
||||||
|
|
||||||
const loc = bool_tensor.getContext().mlirCtx().location(@src());
|
const loc = bool_tensor.getContext().mlirCtx().location(@src());
|
||||||
const op = dialect.stablehlo.select(
|
const op = dialect.stablehlo.select(
|
||||||
@ -3762,7 +3761,7 @@ pub const Tensor = struct {
|
|||||||
///
|
///
|
||||||
/// - res[a, b, c, d] == (A[a], B[b], C[c], D[d])
|
/// - res[a, b, c, d] == (A[a], B[b], C[c], D[d])
|
||||||
pub fn cartesianProductStacked(vectors: []const Tensor) Tensor {
|
pub fn cartesianProductStacked(vectors: []const Tensor) Tensor {
|
||||||
var out = std.BoundedArray(Tensor, Tensor.MAX_RANK).init(vectors.len) catch unreachable;
|
var out = stdx.BoundedArray(Tensor, Tensor.MAX_RANK).init(vectors.len) catch unreachable;
|
||||||
_cartesianProduct(vectors, out.slice());
|
_cartesianProduct(vectors, out.slice());
|
||||||
|
|
||||||
return Tensor.stack(out.constSlice(), .last, .coord);
|
return Tensor.stack(out.constSlice(), .last, .coord);
|
||||||
@ -3801,7 +3800,7 @@ pub const Tensor = struct {
|
|||||||
) fn (Tensor, Tensor) Tensor {
|
) fn (Tensor, Tensor) Tensor {
|
||||||
return struct {
|
return struct {
|
||||||
pub fn binaryOpHelper(self: Tensor, other: Tensor) Tensor {
|
pub fn binaryOpHelper(self: Tensor, other: Tensor) Tensor {
|
||||||
stdx.debug.assert(self.dtype() == other.dtype(), "{s} expects tensor to be of same type, got {} and {}", .{ op_name, self, other });
|
stdx.debug.assert(self.dtype() == other.dtype(), "{s} expects tensor to be of same type, got {f} and {f}", .{ op_name, self, other });
|
||||||
|
|
||||||
if (self.rank() == 0 and other.rank() != 0) {
|
if (self.rank() == 0 and other.rank() != 0) {
|
||||||
return binaryOpHelper(self.broad(other._shape), other);
|
return binaryOpHelper(self.broad(other._shape), other);
|
||||||
@ -3811,10 +3810,10 @@ pub const Tensor = struct {
|
|||||||
return binaryOpHelper(self, other.broad(self._shape));
|
return binaryOpHelper(self, other.broad(self._shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
stdx.debug.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {} and {}", .{ op_name, self._shape, other._shape });
|
stdx.debug.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {f} and {f}", .{ op_name, self._shape, other._shape });
|
||||||
|
|
||||||
const ctx = self.getContext();
|
const ctx = self.getContext();
|
||||||
const location = ctx.location(src, "{s}({_}, {_})", .{ op_name, self, other });
|
const location = ctx.location(src, "{s}({f}, {f})", .{ op_name, self, other });
|
||||||
const ret = @call(.auto, op_fn, .{ ctx.mlirCtx(), self.value(), other.value(), location });
|
const ret = @call(.auto, op_fn, .{ ctx.mlirCtx(), self.value(), other.value(), location });
|
||||||
return _result(self._shape, ret.result(0));
|
return _result(self._shape, ret.result(0));
|
||||||
}
|
}
|
||||||
@ -3837,7 +3836,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
fn printCallback(_: ?*anyopaque, inputs: []const HostBuffer, outputs: []const HostBuffer) void {
|
fn printCallback(_: ?*anyopaque, inputs: []const HostBuffer, outputs: []const HostBuffer) void {
|
||||||
const host_buffer = inputs[0];
|
const host_buffer = inputs[0];
|
||||||
std.log.defaultLog(.info, .zml, "Device buffer: {}: {}", .{ host_buffer.shape(), host_buffer.pretty() });
|
std.log.defaultLog(.info, .zml, "Device buffer: {f}: {f}", .{ host_buffer.shape(), host_buffer.pretty() });
|
||||||
// This is true because of the operand aliases.
|
// This is true because of the operand aliases.
|
||||||
// Since the result is already pointing to the input we don't need to modify the buffer.
|
// Since the result is already pointing to the input we don't need to modify the buffer.
|
||||||
std.debug.assert(host_buffer._data == outputs[0]._data);
|
std.debug.assert(host_buffer._data == outputs[0]._data);
|
||||||
@ -4060,8 +4059,8 @@ test shapesOf {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK), value: T) std.BoundedArray(i64, Tensor.MAX_RANK) {
|
pub fn _collectAxes(T: type, bounded_array: stdx.BoundedArray(T, Tensor.MAX_RANK), value: T) stdx.BoundedArray(i64, Tensor.MAX_RANK) {
|
||||||
var res: std.BoundedArray(i64, Tensor.MAX_RANK) = .{};
|
var res: stdx.BoundedArray(i64, Tensor.MAX_RANK) = .{};
|
||||||
for (bounded_array.constSlice(), 0..) |v, ax| {
|
for (bounded_array.constSlice(), 0..) |v, ax| {
|
||||||
if (v == value) {
|
if (v == value) {
|
||||||
res.appendAssumeCapacity(@intCast(ax));
|
res.appendAssumeCapacity(@intCast(ax));
|
||||||
@ -4070,12 +4069,12 @@ pub fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK)
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn _parseGatherCoord(self: Tensor, axes_: anytype) struct { bool, std.BoundedArray(u3, Tensor.MAX_RANK) } {
|
fn _parseGatherCoord(self: Tensor, axes_: anytype) struct { bool, stdx.BoundedArray(u3, Tensor.MAX_RANK) } {
|
||||||
const AxesT = @TypeOf(axes_);
|
const AxesT = @TypeOf(axes_);
|
||||||
const axes_is_scalar = AxesT == EnumLiteral or AxesT == comptime_int or @typeInfo(AxesT) == .int;
|
const axes_is_scalar = AxesT == EnumLiteral or AxesT == comptime_int or @typeInfo(AxesT) == .int;
|
||||||
|
|
||||||
const coord_axes = if (axes_is_scalar)
|
const coord_axes = if (axes_is_scalar)
|
||||||
std.BoundedArray(u3, Tensor.MAX_RANK).fromSlice(&.{self.axis(axes_)}) catch unreachable
|
stdx.BoundedArray(u3, Tensor.MAX_RANK).fromSlice(&.{self.axis(axes_)}) catch unreachable
|
||||||
else
|
else
|
||||||
self.axes(axes_);
|
self.axes(axes_);
|
||||||
|
|
||||||
@ -4099,7 +4098,7 @@ inline fn toI64(values: anytype) []i64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn transposeIsJustAReshape(x: Shape, permutation: []const i64) bool {
|
fn transposeIsJustAReshape(x: Shape, permutation: []const i64) bool {
|
||||||
var perm: std.BoundedArray(struct { u8, bool }, Tensor.MAX_RANK) = .{};
|
var perm: stdx.BoundedArray(struct { u8, bool }, Tensor.MAX_RANK) = .{};
|
||||||
// Don't rewrite on invalid inputs.
|
// Don't rewrite on invalid inputs.
|
||||||
if (permutation.len > x.rank()) return false;
|
if (permutation.len > x.rank()) return false;
|
||||||
for (permutation) |ax| {
|
for (permutation) |ax| {
|
||||||
|
|||||||
@ -31,7 +31,7 @@ pub fn asyncMain() !void {
|
|||||||
.root_name = "Test",
|
.root_name = "Test",
|
||||||
.estimated_total_items = test_fn_list.len,
|
.estimated_total_items = test_fn_list.len,
|
||||||
});
|
});
|
||||||
const have_tty = std.io.getStdErr().isTty();
|
const have_tty = std.fs.File.stderr().isTty();
|
||||||
|
|
||||||
var args = std.process.args();
|
var args = std.process.args();
|
||||||
// Skip executable path
|
// Skip executable path
|
||||||
|
|||||||
@ -51,10 +51,10 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
|
|||||||
if (should_free_left) left.deinit(allocator);
|
if (should_free_left) left.deinit(allocator);
|
||||||
if (should_free_right) right.deinit(allocator);
|
if (should_free_right) right.deinit(allocator);
|
||||||
}
|
}
|
||||||
errdefer log.err("\n--> Left: {}\n--> Right: {}", .{ left.pretty(), right.pretty() });
|
errdefer log.err("\n--> Left: {f}\n--> Right: {f}", .{ left.pretty(), right.pretty() });
|
||||||
|
|
||||||
if (!std.mem.eql(i64, left.shape().dims(), right.shape().dims())) {
|
if (!std.mem.eql(i64, left.shape().dims(), right.shape().dims())) {
|
||||||
log.err("left.shape() {} != right.shape() {}", .{ left.shape(), right.shape() });
|
log.err("left.shape() {f} != right.shape() {f}", .{ left.shape(), right.shape() });
|
||||||
return error.TestUnexpectedResult;
|
return error.TestUnexpectedResult;
|
||||||
}
|
}
|
||||||
if (left.dtype() != right.dtype() and !(left.dtype() == .f16 and right.dtype() == .bf16)) {
|
if (left.dtype() != right.dtype() and !(left.dtype() == .f16 and right.dtype() == .bf16)) {
|
||||||
@ -89,7 +89,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
|
|||||||
const right_data = right.items(R);
|
const right_data = right.items(R);
|
||||||
for (left_data, right_data, 0..) |l, r, i| {
|
for (left_data, right_data, 0..) |l, r, i| {
|
||||||
if (!approxEq(f32, zml.floats.floatCast(f32, l), zml.floats.floatCast(f32, r), tolerance)) {
|
if (!approxEq(f32, zml.floats.floatCast(f32, l), zml.floats.floatCast(f32, r), tolerance)) {
|
||||||
log.err("left.data != right_data.\n < {d:.3} \n > {d:.3}\n error at idx {d}: {d:.3} != {d:.3}", .{ center(left_data, i), center(right_data, i), i, left_data[i], right_data[i] });
|
log.err("left.data != right_data.\n < {any:.3} \n > {any:.3}\n error at idx {any}: {any:.3} != {any:.3}", .{ center(left_data, i), center(right_data, i), i, left_data[i], right_data[i] });
|
||||||
return error.TestUnexpectedResult;
|
return error.TestUnexpectedResult;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -108,7 +108,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
|
|||||||
pub fn expectEqualShapes(expected: zml.Shape, actual: zml.Shape) error{TestExpectedEqual}!void {
|
pub fn expectEqualShapes(expected: zml.Shape, actual: zml.Shape) error{TestExpectedEqual}!void {
|
||||||
if (expected.eqlWithTags(actual)) return;
|
if (expected.eqlWithTags(actual)) return;
|
||||||
|
|
||||||
std.debug.print("Expected {}, got {}", .{ expected, actual });
|
std.debug.print("Expected {f}, got {f}", .{ expected, actual });
|
||||||
return error.TestExpectedEqual;
|
return error.TestExpectedEqual;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,7 @@ zig_library(
|
|||||||
deps = [
|
deps = [
|
||||||
"//async",
|
"//async",
|
||||||
"//ffi:zig",
|
"//ffi:zig",
|
||||||
|
"//stdx",
|
||||||
"//zml/tokenizer/hftokenizers",
|
"//zml/tokenizer/hftokenizers",
|
||||||
"//zml/tokenizer/sentencepiece",
|
"//zml/tokenizer/sentencepiece",
|
||||||
],
|
],
|
||||||
|
|||||||
@ -27,5 +27,6 @@ zig_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":hftokenizers_cc",
|
":hftokenizers_cc",
|
||||||
"//ffi:zig",
|
"//ffi:zig",
|
||||||
|
"//stdx",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
const c = @import("c");
|
const c = @import("c");
|
||||||
const ffi = @import("ffi");
|
const ffi = @import("ffi");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
pub const Encoder = struct {
|
pub const Encoder = struct {
|
||||||
inner: *HFTokenizer,
|
inner: *HFTokenizer,
|
||||||
@ -33,8 +35,8 @@ pub const Encoder = struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
pub const Decoder = struct {
|
pub const Decoder = struct {
|
||||||
const StringBuffer = std.BoundedArray(u8, 128);
|
const StringBuffer = stdx.BoundedArray(u8, 128);
|
||||||
const TokensIdsBuffer = std.BoundedArray(u32, 4);
|
const TokensIdsBuffer = stdx.BoundedArray(u32, 4);
|
||||||
|
|
||||||
inner: *HFTokenizer,
|
inner: *HFTokenizer,
|
||||||
current_string: ?[]const u8 = null,
|
current_string: ?[]const u8 = null,
|
||||||
|
|||||||
@ -2,10 +2,11 @@
|
|||||||
//! Disclaimer this is not a very robust implementation:
|
//! Disclaimer this is not a very robust implementation:
|
||||||
//! In particular the normalization is pretty minimalist, only works with ascii, and don't do unicode normalization.
|
//! In particular the normalization is pretty minimalist, only works with ascii, and don't do unicode normalization.
|
||||||
//! Mostly used for testing models that don't have an official HF/sentencepiece tokenizer.
|
//! Mostly used for testing models that don't have an official HF/sentencepiece tokenizer.
|
||||||
const builtin = @import("builtin");
|
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
const testing = std.testing;
|
const testing = std.testing;
|
||||||
|
const builtin = @import("builtin");
|
||||||
|
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const log = std.log.scoped(.@"zml/tokenizer");
|
const log = std.log.scoped(.@"zml/tokenizer");
|
||||||
|
|
||||||
@ -87,12 +88,11 @@ pub const Tokenizer = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Reads a new word directly into the tokenizer arena.
|
/// Reads a new word directly into the tokenizer arena.
|
||||||
pub fn readTokenInto(self: *Tokenizer, score: f32, len: usize, tok_reader: anytype) !void {
|
pub fn readTokenInto(self: *Tokenizer, score: f32, len: usize, tok_reader: *std.Io.Reader) !void {
|
||||||
const arena = self.arena_state.allocator();
|
const arena = self.arena_state.allocator();
|
||||||
|
|
||||||
const token = try arena.alloc(u8, len);
|
const token = try arena.alloc(u8, len);
|
||||||
const n = try tok_reader.readAll(token);
|
try tok_reader.readSliceAll(token);
|
||||||
std.debug.assert(n == len);
|
|
||||||
return self.addOwnedToken(score, token);
|
return self.addOwnedToken(score, token);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -190,9 +190,9 @@ pub const Tokenizer = struct {
|
|||||||
if (options.debug) {
|
if (options.debug) {
|
||||||
var _debug_buf: [256]u8 = undefined;
|
var _debug_buf: [256]u8 = undefined;
|
||||||
var _debug_alloc = std.heap.FixedBufferAllocator.init(&_debug_buf);
|
var _debug_alloc = std.heap.FixedBufferAllocator.init(&_debug_buf);
|
||||||
var debug_progress = std.ArrayList(u8).init(_debug_alloc.allocator());
|
var debug_progress = std.array_list.Managed(u8).init(_debug_alloc.allocator());
|
||||||
self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {};
|
self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {};
|
||||||
log.debug("tokens: {d} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items });
|
log.debug("tokens: {any} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items });
|
||||||
}
|
}
|
||||||
var best_score: f32 = -1e10;
|
var best_score: f32 = -1e10;
|
||||||
var best_token: u32 = 0;
|
var best_token: u32 = 0;
|
||||||
@ -312,7 +312,7 @@ pub const Tokenizer = struct {
|
|||||||
/// Note that if the tokenizer allows sub-unicode bytes, it's possible
|
/// Note that if the tokenizer allows sub-unicode bytes, it's possible
|
||||||
/// the output is not valid utf8.
|
/// the output is not valid utf8.
|
||||||
pub fn decode(self: *const Tokenizer, allocator: std.mem.Allocator, input: []const u32) error{OutOfMemory}![]u8 {
|
pub fn decode(self: *const Tokenizer, allocator: std.mem.Allocator, input: []const u32) error{OutOfMemory}![]u8 {
|
||||||
var output = std.ArrayList(u8).init(allocator);
|
var output = std.array_list.Managed(u8).init(allocator);
|
||||||
errdefer output.deinit();
|
errdefer output.deinit();
|
||||||
|
|
||||||
try self.decodeWithOpts(&output, input, .{});
|
try self.decodeWithOpts(&output, input, .{});
|
||||||
@ -321,7 +321,7 @@ pub const Tokenizer = struct {
|
|||||||
|
|
||||||
pub fn decodeWithOpts(
|
pub fn decodeWithOpts(
|
||||||
self: *const Tokenizer,
|
self: *const Tokenizer,
|
||||||
output: *std.ArrayList(u8),
|
output: *std.array_list.Managed(u8),
|
||||||
input: []const u32,
|
input: []const u32,
|
||||||
opts: struct { sep: []const u8 = "" },
|
opts: struct { sep: []const u8 = "" },
|
||||||
) error{OutOfMemory}!void {
|
) error{OutOfMemory}!void {
|
||||||
@ -363,7 +363,8 @@ pub const Tokenizer = struct {
|
|||||||
|
|
||||||
// First lookup the byte fallback entry.
|
// First lookup the byte fallback entry.
|
||||||
// Note: we assume upper case, but we could try both upper and lower case if needed.
|
// Note: we assume upper case, but we could try both upper and lower case if needed.
|
||||||
_ = std.fmt.bufPrintIntToSlice(byte_fallback_buf[3..5], c, 16, .upper, .{ .fill = '0', .width = 2 });
|
var writer: std.Io.Writer = .fixed(byte_fallback_buf[3..5]);
|
||||||
|
try writer.printInt(c, 16, .upper, .{ .fill = '0', .width = 2 });
|
||||||
const entry = tokenizer.token_lookup.getEntry(&byte_fallback_buf) orelse {
|
const entry = tokenizer.token_lookup.getEntry(&byte_fallback_buf) orelse {
|
||||||
log.err("Tokenizer has \"byte_fallback\" = true, but doesn't contains the byte fallback token {s}", .{byte_fallback_buf});
|
log.err("Tokenizer has \"byte_fallback\" = true, but doesn't contains the byte fallback token {s}", .{byte_fallback_buf});
|
||||||
return error.InvalidInput;
|
return error.InvalidInput;
|
||||||
@ -443,8 +444,8 @@ pub const Encoder = struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
pub const Decoder = struct {
|
pub const Decoder = struct {
|
||||||
const StringBuffer = std.BoundedArray(u8, 128);
|
const StringBuffer = stdx.BoundedArray(u8, 128);
|
||||||
const TokensIdsBuffer = std.BoundedArray(u32, 4);
|
const TokensIdsBuffer = stdx.BoundedArray(u32, 4);
|
||||||
|
|
||||||
inner: *Tokenizer,
|
inner: *Tokenizer,
|
||||||
arena: std.heap.ArenaAllocator,
|
arena: std.heap.ArenaAllocator,
|
||||||
@ -571,7 +572,7 @@ test CharTokenIterator {
|
|||||||
{
|
{
|
||||||
tokenizer.byte_fallback = false;
|
tokenizer.byte_fallback = false;
|
||||||
var it: CharTokenIterator = .{ .input = "ζℳL" };
|
var it: CharTokenIterator = .{ .input = "ζℳL" };
|
||||||
var res: std.BoundedArray(u32, 8) = .{};
|
var res: stdx.BoundedArray(u32, 8) = .{};
|
||||||
while (try it.nextCodepointToken(&tokenizer)) |token| {
|
while (try it.nextCodepointToken(&tokenizer)) |token| {
|
||||||
res.appendAssumeCapacity(token);
|
res.appendAssumeCapacity(token);
|
||||||
}
|
}
|
||||||
@ -582,7 +583,7 @@ test CharTokenIterator {
|
|||||||
{
|
{
|
||||||
tokenizer.byte_fallback = true;
|
tokenizer.byte_fallback = true;
|
||||||
var it: CharTokenIterator = .{ .input = "ζℳL" };
|
var it: CharTokenIterator = .{ .input = "ζℳL" };
|
||||||
var res: std.BoundedArray(u32, 8) = .{};
|
var res: stdx.BoundedArray(u32, 8) = .{};
|
||||||
while (try it.nextCodepointToken(&tokenizer)) |token| {
|
while (try it.nextCodepointToken(&tokenizer)) |token| {
|
||||||
res.appendAssumeCapacity(token);
|
res.appendAssumeCapacity(token);
|
||||||
}
|
}
|
||||||
@ -596,7 +597,7 @@ pub const Normalizer = struct {
|
|||||||
/// Space token used by sentencepiece derived tokenizer.
|
/// Space token used by sentencepiece derived tokenizer.
|
||||||
pub const sentencepiece_space = "▁"; // \xe2\x96\x81
|
pub const sentencepiece_space = "▁"; // \xe2\x96\x81
|
||||||
|
|
||||||
_whitespace: std.BoundedArray(u8, 8) = .{},
|
_whitespace: stdx.BoundedArray(u8, 8) = .{},
|
||||||
|
|
||||||
flags: packed struct {
|
flags: packed struct {
|
||||||
remove_extra_whitespaces: bool,
|
remove_extra_whitespaces: bool,
|
||||||
@ -610,7 +611,7 @@ pub const Normalizer = struct {
|
|||||||
split_on_punct_ascii: bool,
|
split_on_punct_ascii: bool,
|
||||||
},
|
},
|
||||||
|
|
||||||
pub fn init(flags: std.meta.FieldType(Normalizer, .flags), escaped_whitespace: ?[]const u8) Normalizer {
|
pub fn init(flags: @FieldType(Normalizer, "flags"), escaped_whitespace: ?[]const u8) Normalizer {
|
||||||
var res: Normalizer = .{ .flags = flags };
|
var res: Normalizer = .{ .flags = flags };
|
||||||
if (escaped_whitespace) |escaped| {
|
if (escaped_whitespace) |escaped| {
|
||||||
res._whitespace.appendSliceAssumeCapacity(escaped);
|
res._whitespace.appendSliceAssumeCapacity(escaped);
|
||||||
@ -622,7 +623,7 @@ pub const Normalizer = struct {
|
|||||||
return if (self._whitespace.len > 1) self._whitespace.constSlice() else null;
|
return if (self._whitespace.len > 1) self._whitespace.constSlice() else null;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn addSlice(data: []const u8, consumed: usize, normalized: *std.ArrayList(u8), normalized_to_origin: *std.ArrayList(usize)) !void {
|
fn addSlice(data: []const u8, consumed: usize, normalized: *std.array_list.Managed(u8), normalized_to_origin: *std.array_list.Managed(usize)) !void {
|
||||||
try normalized.appendSlice(data);
|
try normalized.appendSlice(data);
|
||||||
for (data) |_| try normalized_to_origin.append(consumed);
|
for (data) |_| try normalized_to_origin.append(consumed);
|
||||||
}
|
}
|
||||||
@ -672,9 +673,9 @@ pub const Normalizer = struct {
|
|||||||
// Pre-allocate outputs
|
// Pre-allocate outputs
|
||||||
const space = self.escapedSpace() orelse " ";
|
const space = self.escapedSpace() orelse " ";
|
||||||
const overhead = if (self.flags.split_on_punct_ascii) space.len + 1 else space.len;
|
const overhead = if (self.flags.split_on_punct_ascii) space.len + 1 else space.len;
|
||||||
var normalized = try std.ArrayList(u8).initCapacity(allocator, trimmed_input.len * overhead + 2 * space.len);
|
var normalized = try std.array_list.Managed(u8).initCapacity(allocator, trimmed_input.len * overhead + 2 * space.len);
|
||||||
errdefer normalized.deinit();
|
errdefer normalized.deinit();
|
||||||
var normalized_to_origin = try std.ArrayList(usize).initCapacity(allocator, normalized.capacity);
|
var normalized_to_origin = try std.array_list.Managed(usize).initCapacity(allocator, normalized.capacity);
|
||||||
errdefer normalized_to_origin.deinit();
|
errdefer normalized_to_origin.deinit();
|
||||||
|
|
||||||
// If the spec asks for it, add a whitespace at the beginning.
|
// If the spec asks for it, add a whitespace at the beginning.
|
||||||
@ -965,7 +966,7 @@ test Normalizer {
|
|||||||
/// This implementation precompupte a mapping between bytes encoded with GPT2 algorithm,
|
/// This implementation precompupte a mapping between bytes encoded with GPT2 algorithm,
|
||||||
/// into utf8 bytes, and do lookups at runtime.
|
/// into utf8 bytes, and do lookups at runtime.
|
||||||
pub const Gpt2TextDecoder = struct {
|
pub const Gpt2TextDecoder = struct {
|
||||||
const Code = std.BoundedArray(u8, 2);
|
const Code = stdx.BoundedArray(u8, 2);
|
||||||
|
|
||||||
// TODO: benchmark this is more efficient than doing the conversion at runtime.
|
// TODO: benchmark this is more efficient than doing the conversion at runtime.
|
||||||
code_to_byte: std.AutoArrayHashMap(Code, u8),
|
code_to_byte: std.AutoArrayHashMap(Code, u8),
|
||||||
@ -982,7 +983,7 @@ pub const Gpt2TextDecoder = struct {
|
|||||||
var code: Code = .{ .buffer = .{ 0, 0 }, .len = 0 }; // 0-init
|
var code: Code = .{ .buffer = .{ 0, 0 }, .len = 0 }; // 0-init
|
||||||
const i: u8 = @intCast(index);
|
const i: u8 = @intCast(index);
|
||||||
if (isPrintableByte(i)) {
|
if (isPrintableByte(i)) {
|
||||||
if (std.ascii.isASCII(i)) {
|
if (std.ascii.isAscii(i)) {
|
||||||
code.appendAssumeCapacity(i);
|
code.appendAssumeCapacity(i);
|
||||||
} else {
|
} else {
|
||||||
const codepoint: u21 = @as(u21, @intCast(i));
|
const codepoint: u21 = @as(u21, @intCast(i));
|
||||||
@ -1005,7 +1006,7 @@ pub const Gpt2TextDecoder = struct {
|
|||||||
|
|
||||||
/// Transform bytes representing text under the gpt2 encoding,
|
/// Transform bytes representing text under the gpt2 encoding,
|
||||||
/// and write to the `unicode` buffer utf-8 bytes.
|
/// and write to the `unicode` buffer utf-8 bytes.
|
||||||
pub fn decode(self: Gpt2TextDecoder, unicode: *std.ArrayList(u8), bytes: []const u8) ![]const u8 {
|
pub fn decode(self: Gpt2TextDecoder, unicode: *std.array_list.Managed(u8), bytes: []const u8) ![]const u8 {
|
||||||
const start = unicode.items.len;
|
const start = unicode.items.len;
|
||||||
var it = std.unicode.Utf8Iterator{ .i = 0, .bytes = bytes };
|
var it = std.unicode.Utf8Iterator{ .i = 0, .bytes = bytes };
|
||||||
while (it.nextCodepointSlice()) |codepoint| {
|
while (it.nextCodepointSlice()) |codepoint| {
|
||||||
@ -1029,7 +1030,7 @@ test Gpt2TextDecoder {
|
|||||||
var decoder = try Gpt2TextDecoder.init(testing.allocator);
|
var decoder = try Gpt2TextDecoder.init(testing.allocator);
|
||||||
defer decoder.deinit();
|
defer decoder.deinit();
|
||||||
|
|
||||||
var out = std.ArrayList(u8).init(testing.allocator);
|
var out = std.array_list.Managed(u8).init(testing.allocator);
|
||||||
defer out.deinit();
|
defer out.deinit();
|
||||||
|
|
||||||
// Ascii is not changed.
|
// Ascii is not changed.
|
||||||
@ -1076,7 +1077,7 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
|
|||||||
|
|
||||||
// Buffer containing all concatenated tokens.
|
// Buffer containing all concatenated tokens.
|
||||||
// Reserve a big chunk, to avoid grow event, but release over-allocated memory.
|
// Reserve a big chunk, to avoid grow event, but release over-allocated memory.
|
||||||
var all_tokens = try std.ArrayList(u8).initCapacity(tokenizer.arena_state.allocator(), file_content.len);
|
var all_tokens = try std.array_list.Managed(u8).initCapacity(tokenizer.arena_state.allocator(), file_content.len);
|
||||||
const original_alloc = all_tokens.items.ptr;
|
const original_alloc = all_tokens.items.ptr;
|
||||||
// A re-alloc event here means we have invalidated all slices inside the tokenizer.
|
// A re-alloc event here means we have invalidated all slices inside the tokenizer.
|
||||||
// If this is too annoying we could switch to a custom type instead of slices.
|
// If this is too annoying we could switch to a custom type instead of slices.
|
||||||
@ -1164,7 +1165,7 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a copy of the given string, stored inside the given ArrayList.
|
/// Returns a copy of the given string, stored inside the given ArrayList.
|
||||||
fn dup(buffer: *std.ArrayList(u8), str: []const u8) ![]const u8 {
|
fn dup(buffer: *std.array_list.Managed(u8), str: []const u8) ![]const u8 {
|
||||||
const n = buffer.items.len;
|
const n = buffer.items.len;
|
||||||
try buffer.appendSlice(str);
|
try buffer.appendSlice(str);
|
||||||
return buffer.items[n..];
|
return buffer.items[n..];
|
||||||
@ -1175,7 +1176,7 @@ fn objectGet(
|
|||||||
object: std.json.ObjectMap,
|
object: std.json.ObjectMap,
|
||||||
comptime kind: std.meta.FieldEnum(std.json.Value),
|
comptime kind: std.meta.FieldEnum(std.json.Value),
|
||||||
key: []const u8,
|
key: []const u8,
|
||||||
) ?std.meta.FieldType(std.json.Value, kind) {
|
) ?@FieldType(std.json.Value, @tagName(kind)) {
|
||||||
const val = object.get(key) orelse return null;
|
const val = object.get(key) orelse return null;
|
||||||
if (val != kind) return null;
|
if (val != kind) return null;
|
||||||
return @field(val, @tagName(kind));
|
return @field(val, @tagName(kind));
|
||||||
@ -1184,10 +1185,11 @@ fn objectGet(
|
|||||||
pub fn fromTinyLlamaFile(allocator: std.mem.Allocator, tokenizer_path: []const u8, vocab_size: u32) !Tokenizer {
|
pub fn fromTinyLlamaFile(allocator: std.mem.Allocator, tokenizer_path: []const u8, vocab_size: u32) !Tokenizer {
|
||||||
const tokenizer_file = try std.fs.cwd().openFile(tokenizer_path, .{});
|
const tokenizer_file = try std.fs.cwd().openFile(tokenizer_path, .{});
|
||||||
defer tokenizer_file.close();
|
defer tokenizer_file.close();
|
||||||
var tok_reader = std.io.bufferedReader(tokenizer_file.reader());
|
var read_buff: [4096]u8 = undefined;
|
||||||
const r = tok_reader.reader();
|
var tok_reader = tokenizer_file.reader(&read_buff);
|
||||||
|
const r: *std.Io.Reader = &tok_reader.interface;
|
||||||
|
|
||||||
const max_token_len = try r.readInt(u32, .little);
|
const max_token_len = try readValueLE(u32, r);
|
||||||
const special_tokens: Tokenizer.SpecialTokens = .{
|
const special_tokens: Tokenizer.SpecialTokens = .{
|
||||||
.unk = 0,
|
.unk = 0,
|
||||||
.bos = 1,
|
.bos = 1,
|
||||||
@ -1195,11 +1197,11 @@ pub fn fromTinyLlamaFile(allocator: std.mem.Allocator, tokenizer_path: []const u
|
|||||||
};
|
};
|
||||||
var tokenizer = try Tokenizer.init(allocator, vocab_size, max_token_len, null, special_tokens, true);
|
var tokenizer = try Tokenizer.init(allocator, vocab_size, max_token_len, null, special_tokens, true);
|
||||||
var i: u32 = 0;
|
var i: u32 = 0;
|
||||||
while (readToken(&tokenizer, &r)) : (i += 1) {
|
while (readToken(&tokenizer, r)) : (i += 1) {
|
||||||
// Pass
|
// Pass
|
||||||
} else |_| {
|
} else |_| {
|
||||||
if (i < vocab_size) {
|
if (i < vocab_size) {
|
||||||
log.info("Read {d} words out of {?d}", .{ i, vocab_size });
|
log.info("Read {d} words out of {d}", .{ i, vocab_size });
|
||||||
}
|
}
|
||||||
tokenizer.vocab_size = i;
|
tokenizer.vocab_size = i;
|
||||||
}
|
}
|
||||||
@ -1207,8 +1209,14 @@ pub fn fromTinyLlamaFile(allocator: std.mem.Allocator, tokenizer_path: []const u
|
|||||||
return tokenizer;
|
return tokenizer;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn readToken(tokenizer: *Tokenizer, tok_reader: anytype) !void {
|
fn readToken(tokenizer: *Tokenizer, tok_reader: *std.Io.Reader) !void {
|
||||||
const score: f32 = @bitCast(try tok_reader.readInt(u32, .little));
|
const score: f32 = try readValueLE(f32, tok_reader);
|
||||||
const len: usize = @intCast(try tok_reader.readInt(u32, .little));
|
const len: usize = try readValueLE(u32, tok_reader);
|
||||||
try tokenizer.readTokenInto(score, len, tok_reader);
|
try tokenizer.readTokenInto(score, len, tok_reader);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn readValueLE(T: type, reader: *std.Io.Reader) !T {
|
||||||
|
var res: [1]T = undefined;
|
||||||
|
try reader.readSliceEndian(T, &res, .little);
|
||||||
|
return res[0];
|
||||||
|
}
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const log = std.log.scoped(.@"//zml/tokenizer");
|
|
||||||
|
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
const zml_tokenizer = @import("zml/tokenizer");
|
const zml_tokenizer = @import("zml/tokenizer");
|
||||||
|
|
||||||
|
const log = std.log.scoped(.@"//zml/tokenizer");
|
||||||
|
|
||||||
const Flags = struct {
|
const Flags = struct {
|
||||||
tokenizer: []const u8,
|
tokenizer: []const u8,
|
||||||
prompt: []const u8,
|
prompt: []const u8,
|
||||||
@ -35,7 +36,7 @@ pub fn asyncMain() !void {
|
|||||||
|
|
||||||
const prompt_tok = try encoder.encode(args.prompt);
|
const prompt_tok = try encoder.encode(args.prompt);
|
||||||
|
|
||||||
log.info("Input: {s}\nOutput: {d}", .{ args.prompt, prompt_tok });
|
log.info("Input: {s}\nOutput: {any}", .{ args.prompt, prompt_tok });
|
||||||
|
|
||||||
var errors: u8 = 0;
|
var errors: u8 = 0;
|
||||||
{
|
{
|
||||||
@ -47,14 +48,14 @@ pub fn asyncMain() !void {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (args.expected.len > 0) {
|
if (args.expected.len > 0) {
|
||||||
var expected = try std.ArrayList(u32).initCapacity(allocator, args.prompt.len);
|
var expected = try std.array_list.Managed(u32).initCapacity(allocator, args.prompt.len);
|
||||||
var it = std.mem.splitSequence(u8, args.expected, ",");
|
var it = std.mem.splitSequence(u8, args.expected, ",");
|
||||||
while (it.next()) |int_token| {
|
while (it.next()) |int_token| {
|
||||||
const tok = try std.fmt.parseInt(u32, int_token, 10);
|
const tok = try std.fmt.parseInt(u32, int_token, 10);
|
||||||
try expected.append(tok);
|
try expected.append(tok);
|
||||||
}
|
}
|
||||||
if (!std.mem.eql(u32, expected.items, prompt_tok)) {
|
if (!std.mem.eql(u32, expected.items, prompt_tok)) {
|
||||||
log.err("Doesn't match expected: {d}", .{expected.items});
|
log.err("Doesn't match expected: {any}", .{expected.items});
|
||||||
errors += 1;
|
errors += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -19,5 +19,6 @@ zig_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":sentencepiece_swig",
|
":sentencepiece_swig",
|
||||||
"//ffi:zig",
|
"//ffi:zig",
|
||||||
|
"//stdx",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
const c = @import("c");
|
const c = @import("c");
|
||||||
const ffi = @import("ffi");
|
const ffi = @import("ffi");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const StringToTokenRatio = 3;
|
const StringToTokenRatio = 3;
|
||||||
|
|
||||||
@ -81,7 +83,7 @@ pub const Encoder = struct {
|
|||||||
|
|
||||||
pub const Decoder = struct {
|
pub const Decoder = struct {
|
||||||
const StringBufferSize = 64;
|
const StringBufferSize = 64;
|
||||||
const StringBuffer = std.BoundedArray(u8, StringBufferSize);
|
const StringBuffer = stdx.BoundedArray(u8, StringBufferSize);
|
||||||
const TokenIdsBufferSize = 4;
|
const TokenIdsBufferSize = 4;
|
||||||
|
|
||||||
inner: *SentencePieceProcessor,
|
inner: *SentencePieceProcessor,
|
||||||
|
|||||||
@ -14,7 +14,7 @@ const Tensor = zml.Tensor;
|
|||||||
/// * `matmul(.{10}, .{10}) -> .{}`
|
/// * `matmul(.{10}, .{10}) -> .{}`
|
||||||
/// * `matmul(.{10}, .{10}) -> .{}`
|
/// * `matmul(.{10}, .{10}) -> .{}`
|
||||||
pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor {
|
pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor {
|
||||||
stdx.debug.assert(lhs.rank() >= 1 and rhs.rank() >= 1, "Can't matmul({}, {}) ! The two tensors need to have at least rank 1.", .{ lhs, rhs });
|
stdx.debug.assert(lhs.rank() >= 1 and rhs.rank() >= 1, "Can't matmul({f}, {f}) ! The two tensors need to have at least rank 1.", .{ lhs, rhs });
|
||||||
|
|
||||||
const contracting = [_][2]i8{.{ -1, if (rhs.rank() >= 2) rhs.rank() - 2 else 0 }};
|
const contracting = [_][2]i8{.{ -1, if (rhs.rank() >= 2) rhs.rank() - 2 else 0 }};
|
||||||
if (lhs.rank() == 1 or rhs.rank() <= 2) {
|
if (lhs.rank() == 1 or rhs.rank() <= 2) {
|
||||||
@ -22,7 +22,7 @@ pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor {
|
|||||||
return lhs.dotGeneral(rhs, &contracting, &.{});
|
return lhs.dotGeneral(rhs, &contracting, &.{});
|
||||||
}
|
}
|
||||||
|
|
||||||
stdx.debug.assert(lhs.rank() == 2, "Can't matmul({}, {}) ! One of the two tensors need to have a rank less than 2.", .{ lhs, rhs });
|
stdx.debug.assert(lhs.rank() == 2, "Can't matmul({f}, {f}) ! One of the two tensors need to have a rank less than 2.", .{ lhs, rhs });
|
||||||
|
|
||||||
// Pytorch treats the extra dimensions of rhs has batching dimensions,
|
// Pytorch treats the extra dimensions of rhs has batching dimensions,
|
||||||
// and implicitly broadcast lhs along those.
|
// and implicitly broadcast lhs along those.
|
||||||
@ -91,7 +91,7 @@ pub fn unsqueeze(
|
|||||||
self: Tensor,
|
self: Tensor,
|
||||||
axis_: anytype,
|
axis_: anytype,
|
||||||
) Tensor {
|
) Tensor {
|
||||||
stdx.debug.assert(self.rank() < Tensor.MAX_RANK - 1, "Can't unsqueeze {}, it's already at max rank.", .{self});
|
stdx.debug.assert(self.rank() < Tensor.MAX_RANK - 1, "Can't unsqueeze {f}, it's already at max rank.", .{self});
|
||||||
const a = switch (@typeInfo(@TypeOf(axis_))) {
|
const a = switch (@typeInfo(@TypeOf(axis_))) {
|
||||||
.int, .comptime_int => if (axis_ < 0)
|
.int, .comptime_int => if (axis_ < 0)
|
||||||
@as(i8, self.rank()) + 1 + axis_
|
@as(i8, self.rank()) + 1 + axis_
|
||||||
@ -125,9 +125,9 @@ test unsqueeze {
|
|||||||
/// ref: https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#pixelshuffle
|
/// ref: https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#pixelshuffle
|
||||||
pub fn pixelShuffle(tensor: Tensor, upscale_factor: u32) Tensor {
|
pub fn pixelShuffle(tensor: Tensor, upscale_factor: u32) Tensor {
|
||||||
const shape = tensor.shape();
|
const shape = tensor.shape();
|
||||||
stdx.debug.assert(shape.hasTags(.{ .c, .w, .h }), "pixelShuffle({}) is invalide. Missing tags {{.c, .w, .h}}", .{tensor});
|
stdx.debug.assert(shape.hasTags(.{ .c, .w, .h }), "pixelShuffle({f}) is invalide. Missing tags {{.c, .w, .h}}", .{tensor});
|
||||||
|
|
||||||
stdx.debug.assert(@mod(shape.dim(.c), upscale_factor * upscale_factor) == 0, "pixelShuffle({}) is invalide. Number of channels {}, isn't divisible by upscale factor {}**2", .{ tensor, shape.dim(.c), upscale_factor });
|
stdx.debug.assert(@mod(shape.dim(.c), upscale_factor * upscale_factor) == 0, "pixelShuffle({f}) is invalide. Number of channels {}, isn't divisible by upscale factor {}**2", .{ tensor, shape.dim(.c), upscale_factor });
|
||||||
|
|
||||||
const s = tensor.splitAxis(.c, .{ .c = -1, .upscale_h = upscale_factor, .upscale_w = upscale_factor });
|
const s = tensor.splitAxis(.c, .{ .c = -1, .upscale_h = upscale_factor, .upscale_w = upscale_factor });
|
||||||
const perm = s.shape().contiguousPerm(.{ .h, .upscale_h, .w, .upscale_w });
|
const perm = s.shape().contiguousPerm(.{ .h, .upscale_h, .w, .upscale_w });
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user