diff --git a/BUILD.bazel b/BUILD.bazel index 63540f5..aa1656d 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -11,6 +11,7 @@ zls_completion( visibility = ["//visibility:public"], deps = [ "//async", + "//examples/llama", "//stdx", "//zml", ], diff --git a/MODULE.bazel b/MODULE.bazel index 7f34bc9..72ad53a 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -18,7 +18,7 @@ bazel_dep(name = "rules_oci", version = "2.2.6") bazel_dep(name = "rules_proto", version = "7.1.0") bazel_dep(name = "rules_python", version = "1.5.3") bazel_dep(name = "rules_rust", version = "0.63.0") -bazel_dep(name = "rules_zig", version = "20250821.0-be53625") +bazel_dep(name = "rules_zig", version = "20250827.0-35b6d57") bazel_dep(name = "toolchains_llvm_bootstrapped", version = "0.2.4") bazel_dep(name = "with_cfg.bzl", version = "0.11.0") @@ -26,7 +26,7 @@ bazel_dep(name = "buildifier_prebuilt", version = "8.2.0.2", dev_dependency = Tr zig = use_extension("@rules_zig//zig:extensions.bzl", "zig") zig.index(file = "//bazel:zig_index.json") -zig.toolchain(zig_version = "0.14.1") +zig.toolchain(zig_version = "0.15.1") zig.mirrors(urls = [ "https://mirror.zml.ai/zig", "https://ziglang.org/builds/", diff --git a/async/async.zig b/async/async.zig index f14608b..e105252 100644 --- a/async/async.zig +++ b/async/async.zig @@ -228,36 +228,88 @@ pub const AsyncThread = struct { }; pub fn getStdIn() !File { - return File.init(std.io.getStdIn()) catch @panic("Unable to open stdin"); + return File.initStreaming(std.fs.File.stdin()) catch @panic("Unable to open stdin"); } pub fn getStdOut() File { - return File.init(std.io.getStdOut()) catch @panic("Unable to open stdout"); + return File.initStreaming(std.fs.File.stdout()) catch @panic("Unable to open stdout"); } pub fn getStdErr() File { - return File.init(std.io.getStdErr()) catch @panic("Unable to open stderr"); + return File.initStreaming(std.fs.File.stderr()) catch @panic("Unable to open stderr"); } pub const File = struct { + pub const Reader = struct { + 
interface: std.Io.Reader, + file: File, + mode: std.fs.File.Reader.Mode = .positional, + pos: u64 = 0, + + fn stream(r: *std.Io.Reader, w: *std.Io.Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize { + const self: *Reader = @alignCast(@fieldParentPtr("interface", r)); + const dest = limit.slice(try w.writableSliceGreedy(1)); + const n = switch (self.mode) { + .streaming => self.file.read(dest), + .positional => self.file.pread(dest, self.pos), + else => @panic("UNSUPPORTED"), + } catch { + return std.Io.Reader.StreamError.ReadFailed; + }; + if (n == 0) { + return std.Io.Reader.StreamError.EndOfStream; + } + self.pos += n; + w.advance(n); + return n; + } + }; + + pub const Writer = struct { + interface: std.Io.Writer, + file: File, + mode: std.fs.File.Writer.Mode = .positional, + pos: u64 = 0, + + fn write(self: *Writer, buf: []const u8) !usize { + const n = switch (self.mode) { + .streaming => self.file.write(buf), + .positional => self.file.pwrite(buf, self.pos), + else => unreachable, + } catch { + return std.Io.Writer.Error.WriteFailed; + }; + self.pos += n; + return n; + } + + fn drain(w: *std.Io.Writer, data: []const []const u8, splat: usize) std.Io.Writer.Error!usize { + // TODO: implement splat + _ = splat; + const self: *Writer = @alignCast(@fieldParentPtr("interface", w)); + var total: usize = 0; + if (w.buffered().len > 0) { + total += w.consume(try self.write(w.buffered())); + } + for (data) |d| { + const n = try self.write(d); + total += n; + if (n < d.len) { + return total; + } + } + return total; + } + }; + pub const SeekError = stdx.meta.FnSignature(File.seekTo, null).ReturnErrorSet.? 
|| stdx.meta.FnSignature(File.seekBy, null).ReturnErrorSet.?; pub const GetSeekPosError = SeekError || stdx.meta.FnSignature(File.stat, null).ReturnErrorSet.?; - pub const Reader = std.io.GenericReader(File, stdx.meta.FnSignature(File.read, null).ReturnErrorSet.?, File.read); - pub const Writer = std.io.GenericWriter(File, stdx.meta.FnSignature(File.write, null).ReturnErrorSet.?, File.write); - pub const SeekableStream = std.io.SeekableStream( - File, - SeekError, - GetSeekPosError, - seekTo, - seekBy, - getPos, - getEndPos, - ); _handle: std.fs.File.Handle, inner: aio.File, + is_streaming: bool = false, - fn asFile(self: File) std.fs.File { + pub fn asFile(self: File) std.fs.File { return .{ .handle = self._handle }; } @@ -272,6 +324,14 @@ pub const File = struct { }; } + pub fn initStreaming(file_: std.fs.File) !File { + return .{ + ._handle = file_.handle, + .inner = aio.File.init(AsyncThread.current.executor, try xev.File.init(file_)), + .is_streaming = true, + }; + } + pub fn open(path: []const u8, flags: std.fs.File.OpenFlags) !File { return init(try callBlocking(std.fs.Dir.openFile, .{ std.fs.cwd(), path, flags })); } @@ -280,6 +340,21 @@ pub const File = struct { return try callBlocking(std.fs.Dir.access, .{ std.fs.cwd(), path, flags }); } + pub fn reader(self: File, buffer: []u8) Reader { + return .{ + .file = self, + .interface = .{ + .vtable = &.{ + .stream = Reader.stream, + }, + .buffer = buffer, + .seek = 0, + .end = 0, + }, + .mode = if (self.is_streaming) .streaming else .positional, + }; + } + pub fn read(self: File, buf: []u8) !usize { // NOTE(Corentin): Early return is required to avoid error with xev on Linux with io_uring backend. 
if (buf.len == 0) { @@ -310,6 +385,19 @@ pub const File = struct { return self.inner.write(.{ .slice = buf }); } + pub fn writer(self: File, buffer: []u8) Writer { + return .{ + .file = self, + .interface = .{ + .vtable = &.{ + .drain = Writer.drain, + }, + .buffer = buffer, + }, + .mode = if (self.is_streaming) .streaming else .positional, + }; + } + pub fn pwrite(self: File, buf: []const u8, offset: u64) !usize { return self.inner.pwrite(.{ .slice = buf }, offset); } @@ -318,18 +406,6 @@ pub const File = struct { return self.inner.close(); } - pub fn reader(self: File) Reader { - return .{ .context = self }; - } - - pub fn seekableStream(file: File) SeekableStream { - return .{ .context = file }; - } - - pub fn writer(self: File) Writer { - return .{ .context = self }; - } - pub fn stat(self: File) !std.fs.File.Stat { return try callBlocking(std.fs.File.stat, .{self.asFile()}); } @@ -375,8 +451,47 @@ pub const Socket = struct { pub const TCP = struct { const Inner = aio.TCP; - pub const Reader = std.io.GenericReader(TCP, stdx.meta.FnSignature(TCP.read, null).ReturnErrorSet.?, TCP.read); - pub const Writer = std.io.GenericWriter(TCP, stdx.meta.FnSignature(TCP.write, null).ReturnErrorSet.?, TCP.write); + pub const Reader = struct { + interface: std.Io.Reader, + socket: TCP, + + fn stream(r: *std.Io.Reader, w: *std.Io.Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize { + const self: *Reader = @alignCast(@fieldParentPtr("interface", r)); + const dest = limit.slice(try w.writableSliceGreedy(1)); + const n = self.socket.read(dest) catch { + return std.Io.Reader.StreamError.ReadFailed; + }; + w.advance(n); + return n; + } + }; + + pub const Writer = struct { + interface: std.Io.Writer, + socket: TCP, + + fn drain(w: *std.Io.Writer, data: []const []const u8, splat: usize) std.Io.Writer.Error!usize { + // TODO: implement splat + _ = splat; + const self: *Writer = @alignCast(@fieldParentPtr("interface", w)); + var total: usize = 0; + if (w.buffered().len > 0) { 
+ total += w.consume(self.socket.write(w.buffered()) catch { + return std.Io.Writer.Error.WriteFailed; + }); + } + for (data) |d| { + const n = self.socket.write(d) catch { + return std.Io.Writer.Error.WriteFailed; + }; + total += n; + if (n < d.len) { + return total; + } + } + return total; + } + }; inner: aio.TCP, @@ -418,12 +533,30 @@ pub const Socket = struct { return self.inner.close(); } - pub fn reader(self: TCP) Reader { - return .{ .context = self }; + pub fn reader(self: TCP, buffer: []u8) Reader { + return .{ + .socket = self, + .interface = .{ + .vtable = &.{ + .stream = Reader.stream, + }, + .buffer = buffer, + .seek = 0, + .end = 0, + }, + }; } - pub fn writer(self: TCP) Writer { - return .{ .context = self }; + pub fn writer(self: TCP, buffer: []u8) Writer { + return .{ + .socket = self, + .interface = .{ + .vtable = &.{ + .drain = Writer.drain, + }, + .buffer = buffer, + }, + }; } }; @@ -564,9 +697,8 @@ pub fn logFn(comptime fallbackLogFn: LogFn) LogFn { const level_txt = comptime message_level.asText(); const prefix2 = if (scope == .default) ": " else "(" ++ @tagName(scope) ++ "): "; - const stderr = getStdErr().writer(); - var bw = std.io.bufferedWriter(stderr); - const writer = bw.writer(); + var buffer: [1024]u8 = undefined; + var stderr = getStdErr().writer(&buffer); var mutex = Self.mu orelse blk: { Self.mu = Mutex.init(); @@ -575,8 +707,8 @@ pub fn logFn(comptime fallbackLogFn: LogFn) LogFn { mutex.lock(); defer mutex.unlock(); nosuspend { - writer.print(level_txt ++ prefix2 ++ format ++ "\n", args) catch return; - bw.flush() catch return; + stderr.interface.print(level_txt ++ prefix2 ++ format ++ "\n", args) catch return; + stderr.interface.flush() catch return; } } }.call; diff --git a/async/coro.zig b/async/coro.zig index be26b95..a4d0345 100644 --- a/async/coro.zig +++ b/async/coro.zig @@ -156,7 +156,7 @@ const Coro = struct { return self; } - fn runcoro(from: *base.Coro, this: *base.Coro) callconv(.C) noreturn { + fn 
runcoro(from: *base.Coro, this: *base.Coro) callconv(.c) noreturn { const from_coro: *Coro = @fieldParentPtr("impl", from); const this_coro: *Coro = @fieldParentPtr("impl", this); log(.debug, "coro start {any}", .{this_coro.id}); diff --git a/async/coro_base.zig b/async/coro_base.zig index 9cd6f1c..e1a4d6c 100644 --- a/async/coro_base.zig +++ b/async/coro_base.zig @@ -49,7 +49,7 @@ pub const Coro = packed struct { const Func = *const fn ( from: *Coro, self: *Coro, - ) callconv(.C) noreturn; + ) callconv(.c) noreturn; pub fn init(func: Func, stack: []align(stack_alignment) u8) !Self { stdx.debug.assertComptime(@sizeOf(usize) == 8, "usize expected to take 8 bytes", .{}); diff --git a/bazel/zig_index.json b/bazel/zig_index.json index afe965e..2dec408 100644 --- a/bazel/zig_index.json +++ b/bazel/zig_index.json @@ -1,83 +1,264 @@ { "master": { - "version": "0.15.0-dev.650+4f3b59f70", - "date": "2025-05-29", + "version": "0.16.0-dev.27+83f773fc6", + "date": "2025-08-24", "docs": "https://ziglang.org/documentation/master/", "stdDocs": "https://ziglang.org/documentation/master/std/", "src": { - "tarball": "https://ziglang.org/builds/zig-0.15.0-dev.650+4f3b59f70.tar.xz", - "shasum": "c14764ee9fd16f4437f2e2e7092cc7ec7ff76469f719be04155c29fa3bcc52cd", - "size": "21279148" + "tarball": "https://ziglang.org/builds/zig-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "afddafbede9becaa6d94a544d3115893cc6aa6492a58545b4f638c58990586dc", + "size": "21370308" }, "bootstrap": { - "tarball": "https://ziglang.org/builds/zig-bootstrap-0.15.0-dev.650+4f3b59f70.tar.xz", - "shasum": "142343992733282138b88b2205f29173ce3febcf5d73de4a19111fb36630f5a6", - "size": "52649088" + "tarball": "https://ziglang.org/builds/zig-bootstrap-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "109cdda833baf951ce8116d3fdfd50b3770b02cff5c057f777601ca694f44b7c", + "size": "52732704" }, "x86_64-macos": { - "tarball": "https://ziglang.org/builds/zig-x86_64-macos-0.15.0-dev.650+4f3b59f70.tar.xz", - "shasum": 
"70117a96313ddfe57bd0b3fbbbdb2beffc27fb81f5993ba9b2825628bb9f5aaa", - "size": "55795264" + "tarball": "https://ziglang.org/builds/zig-x86_64-macos-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "1ac88b947a6001eb30d41c15c404847209dfe945a27dbce6d1b7d401e4aef325", + "size": "55817716" }, "aarch64-macos": { - "tarball": "https://ziglang.org/builds/zig-aarch64-macos-0.15.0-dev.650+4f3b59f70.tar.xz", - "shasum": "991ef1b1871852f87e1d53f65c3dfb3b3f7187a0dc1f8fe6e7be279227b868d3", - "size": "50625312" + "tarball": "https://ziglang.org/builds/zig-aarch64-macos-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "75243ad6d2e9fcd634862dda9003883484024fd2e067fc7e4a0f6c524cb6c86e", + "size": "50659200" }, "x86_64-linux": { - "tarball": "https://ziglang.org/builds/zig-x86_64-linux-0.15.0-dev.650+4f3b59f70.tar.xz", - "shasum": "2c2f65db1ad72d415b5a5bfb0ccd437bfc16e91b18a58035d28ea3177d07b1e2", - "size": "53712376" + "tarball": "https://ziglang.org/builds/zig-x86_64-linux-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "c260aefaf5bf10bbd48101217874ccc9c0a37513729dabbbfe09bcf18fff1b5b", + "size": "53759360" }, "aarch64-linux": { - "tarball": "https://ziglang.org/builds/zig-aarch64-linux-0.15.0-dev.650+4f3b59f70.tar.xz", - "shasum": "98bfd9c33b737aa66d5cc19886b772cdc99801e66d8c891d400b0370fc626455", - "size": "49500380" + "tarball": "https://ziglang.org/builds/zig-aarch64-linux-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "658241c15ad827a89878b2e35e13261cd1c2aaefe9a8bb94ccb907725cf80564", + "size": "49485912" }, - "armv7a-linux": { - "tarball": "https://ziglang.org/builds/zig-armv7a-linux-0.15.0-dev.650+4f3b59f70.tar.xz", - "shasum": "c204b60e27930702a64d55768fbbf2c819109dd0400168e0daf2994f4e57ed8b", - "size": "50413108" + "arm-linux": { + "tarball": "https://ziglang.org/builds/zig-arm-linux-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "6039c11f41bd032ec3cbda2d92e38b409aa9d9e2935b21266033b820ecb1a668", + "size": "50475500" }, "riscv64-linux": { - "tarball": 
"https://ziglang.org/builds/zig-riscv64-linux-0.15.0-dev.650+4f3b59f70.tar.xz", - "shasum": "2212613f09d114296f542dc3ca4583d3e1cfa72642e0d192a450ef1704722295", - "size": "53645244" + "tarball": "https://ziglang.org/builds/zig-riscv64-linux-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "f56712849af75bb1bfda502c8cdfc52dcc41c18d3d7eddd329d94be85538d258", + "size": "53610616" }, "powerpc64le-linux": { - "tarball": "https://ziglang.org/builds/zig-powerpc64le-linux-0.15.0-dev.650+4f3b59f70.tar.xz", - "shasum": "6a68c45d8bb5a3c1c3a259ad5ef965011c02feb0d17e8c96abd76aa48af03fe9", - "size": "53559816" + "tarball": "https://ziglang.org/builds/zig-powerpc64le-linux-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "0562fdd578b5ae65a17e614454e67466314790a1d373279a4b054398f6a7b364", + "size": "53585484" }, "x86-linux": { - "tarball": "https://ziglang.org/builds/zig-x86-linux-0.15.0-dev.650+4f3b59f70.tar.xz", - "shasum": "802e860891aa12979c3279e6d0571d08c0ed0398f362ac3f2a2fd945ef59bd1c", - "size": "56328204" + "tarball": "https://ziglang.org/builds/zig-x86-linux-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "6eafb1b1d81118066dc03b17bed6612d7a9da057794f1d0f2451c249cf8a231d", + "size": "56334072" }, "loongarch64-linux": { - "tarball": "https://ziglang.org/builds/zig-loongarch64-linux-0.15.0-dev.650+4f3b59f70.tar.xz", - "shasum": "7fb8e715cdf44a2158afced5eecb3d2ae0f44d0a161f43ea8dfb04e95b22c89f", - "size": "50835068" + "tarball": "https://ziglang.org/builds/zig-loongarch64-linux-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "4fddf62a6698dbece51f504b4de8360a55d24efb0c7a56117eef64c8b717d13f", + "size": "50823012" }, "s390x-linux": { - "tarball": "https://ziglang.org/builds/zig-s390x-linux-0.15.0-dev.650+4f3b59f70.tar.xz", - "shasum": "cbbcbf324db5158232d1fd5a1efcc2d694e058ad66772495e291828fa5227450", - "size": "53386612" + "tarball": "https://ziglang.org/builds/zig-s390x-linux-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "12ed086066c373e4098dc9f132e42bbc478a3be6bd5fc6967f39396c43543f89", + 
"size": "53525496" }, "x86_64-windows": { - "tarball": "https://ziglang.org/builds/zig-x86_64-windows-0.15.0-dev.650+4f3b59f70.zip", - "shasum": "0d4e148a60e859f0fa5ce6d293dcec7ba79a7e1da25df151b34e3ce71d0fa84f", - "size": "94297867" + "tarball": "https://ziglang.org/builds/zig-x86_64-windows-0.16.0-dev.27+83f773fc6.zip", + "shasum": "ae987b8d93eec8a923bba26a74d20b5d15eac7ee9694c44db82c232247810a51", + "size": "93312425" }, "aarch64-windows": { - "tarball": "https://ziglang.org/builds/zig-aarch64-windows-0.15.0-dev.650+4f3b59f70.zip", - "shasum": "6b1a094cb4c666d0078531ce3d4f970dcf01f63e5b885d9af0f498726bdf5e43", - "size": "90200802" + "tarball": "https://ziglang.org/builds/zig-aarch64-windows-0.16.0-dev.27+83f773fc6.zip", + "shasum": "2e3c95d044dd36d302a301fb029ea080b2bda4ef8e42bbb1b2a0246973952c53", + "size": "89157759" }, "x86-windows": { - "tarball": "https://ziglang.org/builds/zig-x86-windows-0.15.0-dev.650+4f3b59f70.zip", - "shasum": "f18b6f9472a27abea8edbfb43687a75b61298e6d4da42247218dc3fd23560aa7", - "size": "96232989" + "tarball": "https://ziglang.org/builds/zig-x86-windows-0.16.0-dev.27+83f773fc6.zip", + "shasum": "b696f07f7104da430c7b37750bc6e36571c3ffcd5bc49240d8a19aa4b431049f", + "size": "95215201" + }, + "aarch64-freebsd": { + "tarball": "https://ziglang.org/builds/zig-aarch64-freebsd-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "0295a3f6bebb6d25fbc916cd4e3cff1285acab75b9591695069ad9a878301203", + "size": "49384624" + }, + "arm-freebsd": { + "tarball": "https://ziglang.org/builds/zig-arm-freebsd-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "a293670c7ee89ce6cf11f03dee297de2a31ffdc9d26f2c0e6e78fd66b11b2327", + "size": "50925296" + }, + "powerpc64-freebsd": { + "tarball": "https://ziglang.org/builds/zig-powerpc64-freebsd-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "503922dd11374ddc44013ae3d7b07a9a17a1404da296bc30bf8fdc6d7273fd82", + "size": "52148468" + }, + "powerpc64le-freebsd": { + "tarball": 
"https://ziglang.org/builds/zig-powerpc64le-freebsd-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "1cf44603dec137a35551916ab5e542c36e6c1ec54df4efd7783f8ae66c80acc4", + "size": "53516496" + }, + "riscv64-freebsd": { + "tarball": "https://ziglang.org/builds/zig-riscv64-freebsd-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "c4b1c691328abacf7f7b07f1900dd44c2f2852a0b2611378a2531975abb6f260", + "size": "53716116" + }, + "x86_64-freebsd": { + "tarball": "https://ziglang.org/builds/zig-x86_64-freebsd-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "989caf4343e1dfb4c9a68f3d3146cac96729e6f6e58276001654307f56c43981", + "size": "53808720" + }, + "aarch64-netbsd": { + "tarball": "https://ziglang.org/builds/zig-aarch64-netbsd-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "bbbe2aa2163d3c25d7414795537088245f9ec3cb53f10518b5824e1a31c35ff9", + "size": "49387264" + }, + "arm-netbsd": { + "tarball": "https://ziglang.org/builds/zig-arm-netbsd-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "264346f44664055d61c91087d0c55fdd04e3bdf1f2f0f969cd5efcf47d54b247", + "size": "52034816" + }, + "x86-netbsd": { + "tarball": "https://ziglang.org/builds/zig-x86-netbsd-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "f5e3cdab5d7e24552eee2200b1e4ac45cd84f13d5d64ae0f558ad0210d6883d5", + "size": "56915916" + }, + "x86_64-netbsd": { + "tarball": "https://ziglang.org/builds/zig-x86_64-netbsd-0.16.0-dev.27+83f773fc6.tar.xz", + "shasum": "b6418c9d5036d2328271078687dbf2d772279e65787546974943855423fc6545", + "size": "53809524" + } + }, + "0.15.1": { + "date": "2025-08-19", + "docs": "https://ziglang.org/documentation/master/", + "stdDocs": "https://ziglang.org/documentation/master/std/", + "notes": "https://ziglang.org/download/0.15.1/release-notes.html", + "src": { + "tarball": "https://ziglang.org/download/0.15.1/zig-0.15.1.tar.xz", + "shasum": "816c0303ab313f59766ce2097658c9fff7fafd1504f61f80f9507cd11652865f", + "size": "21359884" + }, + "bootstrap": { + "tarball": 
"https://ziglang.org/download/0.15.1/zig-bootstrap-0.15.1.tar.xz", + "shasum": "4c0cfbcf12da144955761ca43f89e3c74956bce978694fc1d0a63555f5c0a199", + "size": "52711548" + }, + "x86_64-macos": { + "tarball": "https://ziglang.org/download/0.15.1/zig-x86_64-macos-0.15.1.tar.xz", + "shasum": "9919392e0287cccc106dfbcbb46c7c1c3fa05d919567bb58d7eb16bca4116184", + "size": "55791880" + }, + "aarch64-macos": { + "tarball": "https://ziglang.org/download/0.15.1/zig-aarch64-macos-0.15.1.tar.xz", + "shasum": "c4bd624d901c1268f2deb9d8eb2d86a2f8b97bafa3f118025344242da2c54d7b", + "size": "50644996" + }, + "x86_64-linux": { + "tarball": "https://ziglang.org/download/0.15.1/zig-x86_64-linux-0.15.1.tar.xz", + "shasum": "c61c5da6edeea14ca51ecd5e4520c6f4189ef5250383db33d01848293bfafe05", + "size": "53734456" + }, + "aarch64-linux": { + "tarball": "https://ziglang.org/download/0.15.1/zig-aarch64-linux-0.15.1.tar.xz", + "shasum": "bb4a8d2ad735e7fba764c497ddf4243cb129fece4148da3222a7046d3f1f19fe", + "size": "49493872" + }, + "arm-linux": { + "tarball": "https://ziglang.org/download/0.15.1/zig-arm-linux-0.15.1.tar.xz", + "shasum": "3f4bf3b06b67d14e3f38be30798488c1abe3cf5b33de570cd0e87bbf09b978ad", + "size": "50477464" + }, + "riscv64-linux": { + "tarball": "https://ziglang.org/download/0.15.1/zig-riscv64-linux-0.15.1.tar.xz", + "shasum": "7ca7a3e621436fb31d66a253132fc39574a13d2a1b4d8458af4f2e7c6e4374fe", + "size": "53597792" + }, + "powerpc64le-linux": { + "tarball": "https://ziglang.org/download/0.15.1/zig-powerpc64le-linux-0.15.1.tar.xz", + "shasum": "339e2106496be70b614e32d444298216e676c36d08f5cd7bee3dd1dbd4567fd7", + "size": "53566944" + }, + "x86-linux": { + "tarball": "https://ziglang.org/download/0.15.1/zig-x86-linux-0.15.1.tar.xz", + "shasum": "dff166f25fdd06e8341d831a71211b5ba7411463a6b264bdefa8868438690b6a", + "size": "56311228" + }, + "loongarch64-linux": { + "tarball": "https://ziglang.org/download/0.15.1/zig-loongarch64-linux-0.15.1.tar.xz", + "shasum": 
"0af18a012f3c4cffbef29ab5f42021484d92517f921a1380faacc89c50c1f89d", + "size": "50794376" + }, + "s390x-linux": { + "tarball": "https://ziglang.org/download/0.15.1/zig-s390x-linux-0.15.1.tar.xz", + "shasum": "bcd13e5c88cf2d0da7100a48572195560069a94402c9bec3b316398204aa27e2", + "size": "53508068" + }, + "x86_64-windows": { + "tarball": "https://ziglang.org/download/0.15.1/zig-x86_64-windows-0.15.1.zip", + "shasum": "91e69e887ca8c943ce9a515df3af013d95a66a190a3df3f89221277ebad29e34", + "size": "92612958" + }, + "aarch64-windows": { + "tarball": "https://ziglang.org/download/0.15.1/zig-aarch64-windows-0.15.1.zip", + "shasum": "1f1bf16228b0ffcc882b713dc5e11a6db4219cb30997e13c72e8e723c2104ec6", + "size": "88458549" + }, + "x86-windows": { + "tarball": "https://ziglang.org/download/0.15.1/zig-x86-windows-0.15.1.zip", + "shasum": "fb1c07cffbb43615d3158ab8b8f5db5da1d48875eca99e1d7a8a0064ff63fc5b", + "size": "94516052" + }, + "aarch64-freebsd": { + "tarball": "https://ziglang.org/download/0.15.1/zig-aarch64-freebsd-0.15.1.tar.xz", + "shasum": "4d9d25c775828d49ea037b2284310c295d951793da8ebe94827a54fed4cca3ce", + "size": "49358464" + }, + "arm-freebsd": { + "tarball": "https://ziglang.org/download/0.15.1/zig-arm-freebsd-0.15.1.tar.xz", + "shasum": "9707f3a5f7e1a3d99c40db9a74de1acc61016a197ad289c2ad964f93cb213a18", + "size": "50904124" + }, + "powerpc64-freebsd": { + "tarball": "https://ziglang.org/download/0.15.1/zig-powerpc64-freebsd-0.15.1.tar.xz", + "shasum": "79448884372db04e62f77a46b92245fce805063e69534729537f75cb4681e7e3", + "size": "52099020" + }, + "powerpc64le-freebsd": { + "tarball": "https://ziglang.org/download/0.15.1/zig-powerpc64le-freebsd-0.15.1.tar.xz", + "shasum": "f18ee12ba9c98a20b8d2ad0410c679e7aa5591bc7917f169fd6b377833d2c7ad", + "size": "53480976" + }, + "riscv64-freebsd": { + "tarball": "https://ziglang.org/download/0.15.1/zig-riscv64-freebsd-0.15.1.tar.xz", + "shasum": "ee9f864a6fd8b57c1f4fdbb11daa06578746a6f8253afe3f5ddb5a76f2eddd2d", + "size": 
"53677800" + }, + "x86_64-freebsd": { + "tarball": "https://ziglang.org/download/0.15.1/zig-x86_64-freebsd-0.15.1.tar.xz", + "shasum": "9714f8ac3d3dc908b1599837c6167f857c1efaa930f0cfa840699458de7c3cd0", + "size": "53782112" + }, + "aarch64-netbsd": { + "tarball": "https://ziglang.org/download/0.15.1/zig-aarch64-netbsd-0.15.1.tar.xz", + "shasum": "b2a528399777583b85b89c54ccd45488af7709d6dd29a27323ec2a229db40910", + "size": "49368000" + }, + "arm-netbsd": { + "tarball": "https://ziglang.org/download/0.15.1/zig-arm-netbsd-0.15.1.tar.xz", + "shasum": "93dc70109cbf5d2e022d20dfb56211978c4ea3c0b1e67aaabff947d8d1583aab", + "size": "52028184" + }, + "x86-netbsd": { + "tarball": "https://ziglang.org/download/0.15.1/zig-x86-netbsd-0.15.1.tar.xz", + "shasum": "a91b26051822ff17f3143f859b87dce5b4a13e90928bd6daa6f07a895d3410f0", + "size": "56881864" + }, + "x86_64-netbsd": { + "tarball": "https://ziglang.org/download/0.15.1/zig-x86_64-netbsd-0.15.1.tar.xz", + "shasum": "6d7ba6eca5b4434351ebdb971b7303c9934514f9bb8481852251dbd5b52b03d6", + "size": "53794836" } }, "0.14.1": { diff --git a/ffi/zig_allocator.zig b/ffi/zig_allocator.zig index 412dcd1..35f5fc3 100644 --- a/ffi/zig_allocator.zig +++ b/ffi/zig_allocator.zig @@ -10,13 +10,13 @@ pub const ZigAllocator = struct { }; } - pub fn alloc(ctx: ?*const anyopaque, elem: usize, nelems: usize, alignment: usize) callconv(.C) ?*anyopaque { + pub fn alloc(ctx: ?*const anyopaque, elem: usize, nelems: usize, alignment: usize) callconv(.c) ?*anyopaque { const self: *const std.mem.Allocator = @ptrCast(@alignCast(ctx)); const ret = self.rawAlloc(elem * nelems, std.math.log2_int(usize, alignment), @returnAddress()) orelse return null; return @ptrCast(ret); } - pub fn free(ctx: ?*const anyopaque, ptr: ?*anyopaque, elem: usize, nelems: usize, alignment: usize) callconv(.C) void { + pub fn free(ctx: ?*const anyopaque, ptr: ?*anyopaque, elem: usize, nelems: usize, alignment: usize) callconv(.c) void { const self: *const std.mem.Allocator = 
@ptrCast(@alignCast(ctx)); const memory: [*c]u8 = @ptrCast(ptr); const size = elem * nelems; diff --git a/mlir/dialects/BUILD.bazel b/mlir/dialects/BUILD.bazel index 845775e..101f67f 100644 --- a/mlir/dialects/BUILD.bazel +++ b/mlir/dialects/BUILD.bazel @@ -15,6 +15,7 @@ zig_library( deps = [ "//mlir", "//mlir/dialects/stablehlo", + "//stdx", ], ) diff --git a/mlir/dialects/func.zig b/mlir/dialects/func.zig index 9db3612..868ff3b 100644 --- a/mlir/dialects/func.zig +++ b/mlir/dialects/func.zig @@ -1,6 +1,7 @@ const std = @import("std"); const mlir = @import("mlir"); +const stdx = @import("stdx"); pub fn func( ctx: mlir.Context, @@ -14,7 +15,7 @@ pub fn func( location: mlir.Location, }, ) mlir.Operation { - var attrs_tuple_buffer = std.BoundedArray(mlir.AttrTuple, 4){}; + var attrs_tuple_buffer = stdx.BoundedArray(mlir.AttrTuple, 4){}; attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", .string(ctx, args.sym_name) }); attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", .type_(.function(ctx, args.args, args.results)) }); if (args.arg_attrs.len > 0) { diff --git a/mlir/dialects/stablehlo/stablehlo.zig b/mlir/dialects/stablehlo/stablehlo.zig index 8010985..0c60fd0 100644 --- a/mlir/dialects/stablehlo/stablehlo.zig +++ b/mlir/dialects/stablehlo/stablehlo.zig @@ -761,7 +761,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa ); } - var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{}; + var attrs: stdx.BoundedArray(mlir.AttrTuple, 32) = .{}; attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{ .{ "api_version", .int(ctx, .i32, @intFromEnum(opts.api_version)) }, .{ "call_target_name", .string(ctx, opts.call_target_name) }, @@ -770,7 +770,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa }); { - var output_operand_aliases: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; + var output_operand_aliases: stdx.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; for (opts.output_operand_aliases) 
|alias| { output_operand_aliases.appendAssumeCapacity( OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).asAttr(), @@ -789,14 +789,14 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa }; if (opts.operand_layouts) |layouts| { - var operand_layouts: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; + var operand_layouts: stdx.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; for (layouts) |ol| { operand_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol)); } attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) }); } else { const operand_layouts = blk: { - var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; + var ret: stdx.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; for (inputs) |input| { const ranked_type = input.getType().as(mlir.RankedTensorType).?; const ol = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_type.getRank() ..]; @@ -808,14 +808,14 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa } if (opts.result_layouts) |layouts| { - var result_layouts: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; + var result_layouts: stdx.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; for (layouts) |rl| { result_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl)); } attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) }); } else { const result_layouts = blk: { - var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; + var ret: stdx.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; for (res_types) |t| { const ranked_t = t.as(mlir.RankedTensorType).?; const rl = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_t.getRank() ..]; @@ -1271,7 +1271,7 @@ pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehlo const WriterContext = @TypeOf(context); c.stablehloVersionFromCompatibilityRequirement(req, (struct { - pub fn 
callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { + pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void { const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; } @@ -1292,7 +1292,7 @@ pub fn stablehloGetSmallerVersion(version1: []const u8, version2: []const u8) [] const WriterContext = @TypeOf(context); _ = c.stablehloGetSmallerVersion(mlir.stringRef(version1), mlir.stringRef(version2), (struct { - pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { + pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void { const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; } @@ -1313,7 +1313,7 @@ pub fn getCurrentVersion() []const u8 { const ContextWriter = @TypeOf(writer_); c.stablehloGetCurrentVersion((struct { - pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { + pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void { const writer: *ContextWriter = @ptrCast(@alignCast(userdata)); _ = writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; } @@ -1339,7 +1339,7 @@ pub fn getMinimumVersion() []const u8 { const WriterContext = @TypeOf(context); c.stablehloGetMinimumVersion((struct { - pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { + pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void { const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; } @@ -1358,7 +1358,7 @@ pub fn serializePortableArtifact(bytecode: []const u8, target_version: []const u const WriterContext = @TypeOf(context); try 
mlir.successOr(c.stablehloSerializePortableArtifactFromStringRef(mlir.stringRef(bytecode), mlir.stringRef(target_version), (struct { - pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { + pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void { const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; } diff --git a/mlir/mlir.zig b/mlir/mlir.zig index 10d26c9..a222c40 100755 --- a/mlir/mlir.zig +++ b/mlir/mlir.zig @@ -40,7 +40,7 @@ pub fn successOr(res: c.MlirLogicalResult, err: anytype) !void { } /// Alternative to MlirWrapperType -pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.C) void; +pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.c) void; pub const Registry = struct { _inner: c.MlirDialectRegistry, @@ -171,7 +171,7 @@ pub const PassManager = struct { } }; -fn _mlir_passpipeline_error(err: c.MlirStringRef, ctx: ?*anyopaque) callconv(.C) void { +fn _mlir_passpipeline_error(err: c.MlirStringRef, ctx: ?*anyopaque) callconv(.c) void { _ = ctx; std.debug.print(">>ERROR: {s}\n", .{err.data}); } @@ -754,7 +754,7 @@ pub const Operation = struct { state.addOperands(operands); } else if (args.variadic_operands) |operands_segments| { const MAX_SEGMENTS = 32; - var segments: std.BoundedArray(i32, MAX_SEGMENTS) = .{}; + var segments: stdx.BoundedArray(i32, MAX_SEGMENTS) = .{}; for (operands_segments) |operands| { state.addOperands(operands); @@ -764,7 +764,7 @@ pub const Operation = struct { } else if (args.tt_variadic_operands) |operands_segments| { // stablehlo and triton seems to disagree on the expected type of operandSegmentSizes, let's fix that. 
const MAX_SEGMENTS = 32; - var segments: std.BoundedArray(i32, MAX_SEGMENTS) = .{}; + var segments: stdx.BoundedArray(i32, MAX_SEGMENTS) = .{}; for (operands_segments) |operands| { state.addOperands(operands); @@ -811,7 +811,7 @@ pub const Operation = struct { @panic("Failed to create MLIR operation"); }; if (args.verify and new_op.verify() == false) { - log.err("Failed to verify MLIR operation:\n{}", .{new_op.mlirFormatter(.{ .debug_info = true })}); + log.err("Failed to verify MLIR operation:\n{f}", .{new_op.mlirFormatter(.{ .debug_info = true })}); @panic("Failed to verify MLIR operation"); } return new_op; @@ -888,7 +888,7 @@ pub const Operation = struct { c.mlirOperationWriteBytecode( self._inner, (struct { - pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.C) void { + pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void { const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_)); _ = inner_writer_context.writer.write(str.data[0..str.length]) catch unreachable; } @@ -916,7 +916,7 @@ pub const Operation = struct { self._inner, cfg, (struct { - pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.C) void { + pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void { const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_)); _ = inner_writer_context.writer.write(str.data[0..str.length]) catch |err| { inner_writer_context.write_error = err; @@ -939,29 +939,25 @@ pub const Operation = struct { op: Operation, flags: OpPrintingFlags, - pub fn format(self: @This(), comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { - _ = fmt; - _ = options; + pub fn format(self: @This(), writer: anytype) !void { self.op.print(writer, self.flags); } }; - pub fn print(self: Self, writer: anytype, flags: OpPrintingFlags) void { + pub fn print(self: Self, writer: *std.Io.Writer, flags: OpPrintingFlags) void { const pflags = flags.create(); defer 
c.mlirOpPrintingFlagsDestroy(pflags); - var writer_context = .{ .writer = writer }; - const WriterContext = @TypeOf(writer_context); c.mlirOperationPrintWithFlags( self._inner, pflags, (struct { - pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.C) void { - const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_)); - _ = inner_writer_context.writer.write(str.data[0..str.length]) catch unreachable; + pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void { + const _writer: *std.Io.Writer = @ptrCast(@alignCast(ctx_)); + _writer.writeAll(str.data[0..str.length]) catch @panic("Mlir print failed"); } }).callback, - &writer_context, + writer, ); } @@ -991,7 +987,7 @@ pub const Operation = struct { c.mlirOperationWalk( self._inner, (struct { - pub fn callback(op: c.MlirOperation, ctx_: ?*anyopaque) callconv(.C) c.MlirWalkResult { + pub fn callback(op: c.MlirOperation, ctx_: ?*anyopaque) callconv(.c) c.MlirWalkResult { const inner_ctx_: *ContextType = @ptrCast(@alignCast(ctx_)); return @intFromEnum(walkfn(inner_ctx_.ctx, .{ ._inner = op })); } @@ -1017,24 +1013,26 @@ pub const Operation = struct { return c.mlirOperationRemoveAttributeByName(self._inner, stringRef(name_)); } + /// Hash the canonicalized IR, without debug information that can change across builds. pub fn hash(op: Operation, hasher: *std.hash.XxHash64) void { - const NoError = error{}; - const write = struct { - fn write(hasher_: *std.hash.XxHash64, bytes: []const u8) NoError!usize { - hasher_.update(bytes); - return bytes.len; - } - }.write; - const HashWriter = std.io.Writer(*std.hash.XxHash64, NoError, write); - const writer: HashWriter = .{ .context = hasher }; - - // Hash the canonicalized IR, without debug information that can change across builds. // Note: before we where using op.writeBytecode(writer), // but it crashes on some inputs, notably for unused variables. // So we use the text representation of the mlir. 
// See https://github.com/zml/zml/issues/97. - // Writes can't fail because we are writing to a hasher. - op.print(writer, .{ .debug_info = false }); + const flags = OpPrintingFlags.create(.{ .debug_info = false }); + defer c.mlirOpPrintingFlagsDestroy(flags); + + c.mlirOperationPrintWithFlags( + op._inner, + flags, + (struct { + pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void { + const _hasher: *std.hash.XxHash64 = @ptrCast(@alignCast(ctx_)); + _hasher.update(str.data[0..str.length]); + } + }).callback, + hasher, + ); } }; @@ -1185,9 +1183,9 @@ pub const BlockArgument = struct { return @bitCast(c.mlirBlockArgumentGetArgNumber(arg._inner)); } - pub fn format(self: BlockArgument, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { + pub fn format(self: BlockArgument, writer: anytype) !void { const value = Value{ ._inner = self._inner }; - return value.format(fmt, options, writer); + return value.format(writer); } }; @@ -1234,8 +1232,8 @@ pub const Type = struct { pub fn formatAny(SpecificType: type) fn (SpecificType, SpecificType) type { return struct { - pub fn format(self: SpecificType, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { - return try Type.format(self.asType(), fmt, options, writer); + pub fn format(self: SpecificType, writer: anytype) !void { + return try Type.format(self.asType(), writer); } }; } @@ -1344,7 +1342,7 @@ pub fn IntegerType(comptime it: IntegerTypes) type { pub const eql = Type.eqlAny(Int); pub const format = helpers.format(Int, c.mlirTypePrint); - fn typeIsAIntegerExact(typ: c.MlirType) callconv(.C) bool { + fn typeIsAIntegerExact(typ: c.MlirType) callconv(.c) bool { const bit_width = Config[0]; const is_sign = Config[2]; return c.mlirTypeIsAInteger(typ) and (c.mlirIntegerTypeGetWidth(typ) == bit_width) and is_sign(typ); @@ -1422,20 +1420,20 @@ pub fn ComplexType(comptime ct: ComplexTypes) type { _inner: c.MlirType, const Complex = @This(); - fn 
mlirC64TypeGet(ctx: c.MlirContext) callconv(.C) c.MlirType { + fn mlirC64TypeGet(ctx: c.MlirContext) callconv(.c) c.MlirType { return c.mlirComplexTypeGet(c.mlirF32TypeGet(ctx)); } - fn mlirC128TypeGet(ctx: c.MlirContext) callconv(.C) c.MlirType { + fn mlirC128TypeGet(ctx: c.MlirContext) callconv(.c) c.MlirType { return c.mlirComplexTypeGet(c.mlirF64TypeGet(ctx)); } - fn mlirTypeIsAC64(typ: c.MlirType) callconv(.C) bool { + fn mlirTypeIsAC64(typ: c.MlirType) callconv(.c) bool { const element_type: c.MlirType = c.mlirComplexTypeGetElementType(typ); return c.mlirTypeIsAF32(element_type); } - fn mlirTypeIsAC128(typ: c.MlirType) callconv(.C) bool { + fn mlirTypeIsAC128(typ: c.MlirType) callconv(.c) bool { const element_type: c.MlirType = c.mlirComplexTypeGetElementType(typ); return c.mlirTypeIsAF64(element_type); } @@ -1446,7 +1444,7 @@ pub fn ComplexType(comptime ct: ComplexTypes) type { .unknown => .{ c.mlirTypeIsAComplex, null }, }; - fn typeIsAUnknownComplex(typ: c.MlirType) callconv(.C) bool { + fn typeIsAUnknownComplex(typ: c.MlirType) callconv(.c) bool { return c.mlirTypeIsAComplex(typ); } @@ -1685,7 +1683,7 @@ pub const Block = struct { .op_result => |parent_op| self.appendOperationRecursive(parent_op, opt), .block_argument => |arg| { // Hermetic blocks are not allowed to use arguments from other blocks. - stdx.debug.assert(opt == .open or self.eql(arg.block()), "Can't add {} from {?x} block to {?x} block", .{ arg, arg.block()._inner.ptr, self._inner.ptr }); + stdx.debug.assert(opt == .open or self.eql(arg.block()), "Can't add {f} from {*} block to {*} block", .{ arg, arg.block()._inner.ptr, self._inner.ptr }); }, .null => @panic("InvalidMlir"), } @@ -1694,7 +1692,7 @@ pub const Block = struct { pub fn appendOperationRecursive(self: Block, op: Operation, opt: RecursiveOpts) void { if (op.block()) |prev_block| { // Hermetic blocks are not allowed to reference values from other blocks. 
- stdx.debug.assert(opt == .open or self.equals(prev_block), "Can't add {} from {?x} block to {?x} block", .{ op, prev_block._inner.ptr, self._inner.ptr }); + stdx.debug.assert(opt == .open or self.equals(prev_block), "Can't add {} from {*} block to {*} block", .{ op, prev_block._inner.ptr, self._inner.ptr }); return; } for (0..op.numOperands()) |i| { @@ -1705,7 +1703,7 @@ pub const Block = struct { }; pub const helpers = struct { - pub fn eql(T: type, equal_fn: fn (@FieldType(T, "_inner"), @FieldType(T, "_inner")) callconv(.C) bool) fn (T, T) bool { + pub fn eql(T: type, equal_fn: fn (@FieldType(T, "_inner"), @FieldType(T, "_inner")) callconv(.c) bool) fn (T, T) bool { return struct { fn eql(a: T, b: T) bool { return equal_fn(a._inner, b._inner); @@ -1713,7 +1711,7 @@ pub const helpers = struct { }.eql; } - pub fn deinit(T: type, deinit_fn: fn (@FieldType(T, "_inner")) callconv(.C) void) fn (*T) void { + pub fn deinit(T: type, deinit_fn: fn (@FieldType(T, "_inner")) callconv(.c) void) fn (*T) void { return struct { fn deinit(a: *T) void { deinit_fn(a._inner); @@ -1722,7 +1720,7 @@ pub const helpers = struct { }.deinit; } - pub fn dump(T: type, dump_fn: fn (@FieldType(T, "_inner")) callconv(.C) void) fn (T) void { + pub fn dump(T: type, dump_fn: fn (@FieldType(T, "_inner")) callconv(.c) void) fn (T) void { return struct { fn dump(a: T) void { return dump_fn(a._inner); @@ -1730,7 +1728,7 @@ pub const helpers = struct { }.dump; } - pub fn isNull(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.C) bool) fn (T) bool { + pub fn isNull(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) fn (T) bool { return struct { fn isNull(a: T) bool { return is_null_fn(a._inner); @@ -1738,21 +1736,13 @@ pub const helpers = struct { }.isNull; } - pub fn format(Any: type, print_fn: fn (@FieldType(Any, "_inner"), ?*const MlirStrCallback, ?*anyopaque) callconv(.C) void) type { + pub fn format(Any: type, print_fn: fn (@FieldType(Any, "_inner"), ?*const 
MlirStrCallback, ?*anyopaque) callconv(.c) void) type { return struct { - pub fn format( - self: Any, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - writer: anytype, - ) !void { - _ = fmt; - _ = options; - - const Writer = struct { - writer: @TypeOf(writer), - err: ?@TypeOf(writer).Error = null, - fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.C) void { + pub fn format(self: Any, writer: *std.Io.Writer) !void { + const WriterWithErr = struct { + writer: *std.Io.Writer, + err: ?std.Io.Writer.Error = null, + fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.c) void { var ctx: *@This() = @alignCast(@ptrCast(opaque_ctx)); if (ctx.err) |_| return; _ = ctx.writer.write(mlir_str.data[0..mlir_str.length]) catch |err| { @@ -1762,14 +1752,14 @@ pub const helpers = struct { } }; - var context: Writer = .{ .writer = writer }; - print_fn(self._inner, &Writer.printCallback, &context); + var context: WriterWithErr = .{ .writer = writer }; + print_fn(self._inner, &WriterWithErr.printCallback, &context); if (context.err) |err| return err; } }; } - pub fn wrapOr(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.C) bool) fn (@FieldType(T, "_inner")) ?T { + pub fn wrapOr(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) fn (@FieldType(T, "_inner")) ?T { return struct { fn wrapOr(inner: @FieldType(T, "_inner")) ?T { if (is_null_fn(inner)) return null; @@ -1778,7 +1768,7 @@ pub const helpers = struct { }.wrapOr; } - pub fn init(T: type, inner: @FieldType(T, "_inner"), is_null_fn: fn (@FieldType(T, "_inner")) callconv(.C) bool) ?T { + pub fn init(T: type, inner: @FieldType(T, "_inner"), is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) ?T { if (is_null_fn(inner)) return null; return .{ ._inner = inner }; } diff --git a/pjrt/ffi.zig b/pjrt/ffi.zig index d4fb6ee..b79a789 100644 --- a/pjrt/ffi.zig +++ b/pjrt/ffi.zig @@ -424,7 +424,7 @@ pub const CallFrame = extern struct { } 
}; -pub const Handler = fn (*CallFrame) callconv(.C) ?*Error; +pub const Handler = fn (*CallFrame) callconv(.c) ?*Error; pub const ErrorCode = enum(c.XLA_FFI_Error_Code) { cancelled = c.XLA_FFI_Error_Code_CANCELLED, diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index 00749d5..87f3d17 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -88,7 +88,7 @@ pub const Api = struct { return err; }, }; - const DynGetPjrtApi = lib.lookup(*const fn () callconv(.C) *const Api, "GetPjrtApi") orelse { + const DynGetPjrtApi = lib.lookup(*const fn () callconv(.c) *const Api, "GetPjrtApi") orelse { std.debug.panic("Unable to find GetPjrtApi symbol in library: {s}", .{library}); }; @@ -100,7 +100,7 @@ pub const Api = struct { } fn CallFnArgType(comptime func: Funcs) type { - const fti = @typeInfo(std.meta.FieldType(c.PJRT_Api, func)); + const fti = @typeInfo(@FieldType(c.PJRT_Api, @tagName(func))); const fn_ptr = @typeInfo(fti.optional.child); const fn_type_info = @typeInfo(fn_ptr.pointer.child); const arg_array_type_info = @typeInfo(fn_type_info.@"fn".params[0].type.?); @@ -403,8 +403,8 @@ pub const Client = opaque { element_type: BufferType, layout: MemoryLayout, device: *const Device, - on_delete_callback: *const fn (device_buffer_ptr: ?*anyopaque, ctx: ?*anyopaque) callconv(.C) void = &struct { - fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.C) void {} + on_delete_callback: *const fn (device_buffer_ptr: ?*anyopaque, ctx: ?*anyopaque) callconv(.c) void = &struct { + fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.c) void {} }.call, on_delete_callback_arg: ?*anyopaque = null, stream: ?*const Stream = null, @@ -637,7 +637,7 @@ pub const GetCostAnalysisError = std.mem.Allocator.Error || ApiError; pub const SerializeResult = struct { bytes: []const u8, handle: *anyopaque, - deleter: *const fn (?*anyopaque) callconv(.C) void, + deleter: *const fn (?*anyopaque) callconv(.c) void, pub fn deinit(self: *SerializeResult) void { self.deleter(self.handle); @@ -1036,7 +1036,7 @@ pub const 
Event = opaque { }); } - pub fn onReady(self: *Event, api: *const Api, func: *const fn (err: ?*Error, user_arg: ?*anyopaque) callconv(.C) void, user_arg: ?*anyopaque) ApiError!void { + pub fn onReady(self: *Event, api: *const Api, func: *const fn (err: ?*Error, user_arg: ?*anyopaque) callconv(.c) void, user_arg: ?*anyopaque) ApiError!void { _ = try api.call(.PJRT_Event_OnReady, .{ .event = self.inner(), .callback = @ptrCast(func), diff --git a/stdx/BUILD.bazel b/stdx/BUILD.bazel index 76389c5..90456d1 100644 --- a/stdx/BUILD.bazel +++ b/stdx/BUILD.bazel @@ -3,6 +3,7 @@ load("@rules_zig//zig:defs.bzl", "zig_library", "zig_test") zig_library( name = "stdx", srcs = [ + "bounded_array.zig", "debug.zig", "flags.zig", "fmt.zig", diff --git a/stdx/bounded_array.zig b/stdx/bounded_array.zig new file mode 100644 index 0000000..670201a --- /dev/null +++ b/stdx/bounded_array.zig @@ -0,0 +1,412 @@ +const std = @import("std"); +const assert = std.debug.assert; +const mem = std.mem; +const testing = std.testing; +const Alignment = std.mem.Alignment; + +/// A structure with an array and a length, that can be used as a slice. +/// +/// Useful to pass around small arrays whose exact size is only known at +/// runtime, but whose maximum size is known at comptime, without requiring +/// an `Allocator`. +/// +/// ```zig +/// var actual_size = 32; +/// var a = try BoundedArray(u8, 64).init(actual_size); +/// var slice = a.slice(); // a slice of the 64-byte array +/// var a_clone = a; // creates a copy - the structure doesn't use any internal pointers +/// ``` +pub fn BoundedArray(comptime T: type, comptime buffer_capacity: usize) type { + return BoundedArrayAligned(T, .of(T), buffer_capacity); +} + +/// A structure with an array, length and alignment, that can be used as a +/// slice. +/// +/// Useful to pass around small explicitly-aligned arrays whose exact size is +/// only known at runtime, but whose maximum size is known at comptime, without +/// requiring an `Allocator`. 
+/// ```zig +// var a = try BoundedArrayAligned(u8, 16, 2).init(0); +// try a.append(255); +// try a.append(255); +// const b = @ptrCast(*const [1]u16, a.constSlice().ptr); +// try testing.expectEqual(@as(u16, 65535), b[0]); +/// ``` +pub fn BoundedArrayAligned( + comptime T: type, + comptime alignment: Alignment, + comptime buffer_capacity: usize, +) type { + return struct { + const Self = @This(); + buffer: [buffer_capacity]T align(alignment.toByteUnits()) = undefined, + len: usize = 0, + + /// Set the actual length of the slice. + /// Returns error.Overflow if it exceeds the length of the backing array. + pub fn init(len: usize) error{Overflow}!Self { + if (len > buffer_capacity) return error.Overflow; + return Self{ .len = len }; + } + + /// View the internal array as a slice whose size was previously set. + pub fn slice(self: anytype) switch (@TypeOf(&self.buffer)) { + *align(alignment.toByteUnits()) [buffer_capacity]T => []align(alignment.toByteUnits()) T, + *align(alignment.toByteUnits()) const [buffer_capacity]T => []align(alignment.toByteUnits()) const T, + else => unreachable, + } { + return self.buffer[0..self.len]; + } + + /// View the internal array as a constant slice whose size was previously set. + pub fn constSlice(self: *const Self) []align(alignment.toByteUnits()) const T { + return self.slice(); + } + + /// Adjust the slice's length to `len`. + /// Does not initialize added items if any. + pub fn resize(self: *Self, len: usize) error{Overflow}!void { + if (len > buffer_capacity) return error.Overflow; + self.len = len; + } + + /// Remove all elements from the slice. + pub fn clear(self: *Self) void { + self.len = 0; + } + + /// Copy the content of an existing slice. + pub fn fromSlice(m: []const T) error{Overflow}!Self { + var list = try init(m.len); + @memcpy(list.slice(), m); + return list; + } + + /// Return the element at index `i` of the slice. 
+ pub fn get(self: Self, i: usize) T { + return self.constSlice()[i]; + } + + /// Set the value of the element at index `i` of the slice. + pub fn set(self: *Self, i: usize, item: T) void { + self.slice()[i] = item; + } + + /// Return the maximum length of a slice. + pub fn capacity(self: Self) usize { + return self.buffer.len; + } + + /// Check that the slice can hold at least `additional_count` items. + pub fn ensureUnusedCapacity(self: Self, additional_count: usize) error{Overflow}!void { + if (self.len + additional_count > buffer_capacity) { + return error.Overflow; + } + } + + /// Increase length by 1, returning a pointer to the new item. + pub fn addOne(self: *Self) error{Overflow}!*T { + try self.ensureUnusedCapacity(1); + return self.addOneAssumeCapacity(); + } + + /// Increase length by 1, returning pointer to the new item. + /// Asserts that there is space for the new item. + pub fn addOneAssumeCapacity(self: *Self) *T { + assert(self.len < buffer_capacity); + self.len += 1; + return &self.slice()[self.len - 1]; + } + + /// Resize the slice, adding `n` new elements, which have `undefined` values. + /// The return value is a pointer to the array of uninitialized elements. + pub fn addManyAsArray(self: *Self, comptime n: usize) error{Overflow}!*align(alignment.toByteUnits()) [n]T { + const prev_len = self.len; + try self.resize(self.len + n); + return self.slice()[prev_len..][0..n]; + } + + /// Resize the slice, adding `n` new elements, which have `undefined` values. + /// The return value is a slice pointing to the uninitialized elements. + pub fn addManyAsSlice(self: *Self, n: usize) error{Overflow}![]align(alignment.toByteUnits()) T { + const prev_len = self.len; + try self.resize(self.len + n); + return self.slice()[prev_len..][0..n]; + } + + /// Remove and return the last element from the slice, or return `null` if the slice is empty. 
+ pub fn pop(self: *Self) ?T { + if (self.len == 0) return null; + const item = self.get(self.len - 1); + self.len -= 1; + return item; + } + + /// Return a slice of only the extra capacity after items. + /// This can be useful for writing directly into it. + /// Note that such an operation must be followed up with a + /// call to `resize()` + pub fn unusedCapacitySlice(self: *Self) []align(alignment.toByteUnits()) T { + return self.buffer[self.len..]; + } + + /// Insert `item` at index `i` by moving `slice[n .. slice.len]` to make room. + /// This operation is O(N). + pub fn insert( + self: *Self, + i: usize, + item: T, + ) error{Overflow}!void { + if (i > self.len) { + return error.Overflow; + } + _ = try self.addOne(); + var s = self.slice(); + mem.copyBackwards(T, s[i + 1 .. s.len], s[i .. s.len - 1]); + self.buffer[i] = item; + } + + /// Insert slice `items` at index `i` by moving `slice[i .. slice.len]` to make room. + /// This operation is O(N). + pub fn insertSlice(self: *Self, i: usize, items: []const T) error{Overflow}!void { + try self.ensureUnusedCapacity(items.len); + self.len += items.len; + mem.copyBackwards(T, self.slice()[i + items.len .. self.len], self.constSlice()[i .. self.len - items.len]); + @memcpy(self.slice()[i..][0..items.len], items); + } + + /// Replace range of elements `slice[start..][0..len]` with `new_items`. + /// Grows slice if `len < new_items.len`. + /// Shrinks slice if `len > new_items.len`. 
+ pub fn replaceRange( + self: *Self, + start: usize, + len: usize, + new_items: []const T, + ) error{Overflow}!void { + const after_range = start + len; + var range = self.slice()[start..after_range]; + + if (range.len == new_items.len) { + @memcpy(range[0..new_items.len], new_items); + } else if (range.len < new_items.len) { + const first = new_items[0..range.len]; + const rest = new_items[range.len..]; + @memcpy(range[0..first.len], first); + try self.insertSlice(after_range, rest); + } else { + @memcpy(range[0..new_items.len], new_items); + const after_subrange = start + new_items.len; + for (self.constSlice()[after_range..], 0..) |item, i| { + self.slice()[after_subrange..][i] = item; + } + self.len -= len - new_items.len; + } + } + + /// Extend the slice by 1 element. + pub fn append(self: *Self, item: T) error{Overflow}!void { + const new_item_ptr = try self.addOne(); + new_item_ptr.* = item; + } + + /// Extend the slice by 1 element, asserting the capacity is already + /// enough to store the new item. + pub fn appendAssumeCapacity(self: *Self, item: T) void { + const new_item_ptr = self.addOneAssumeCapacity(); + new_item_ptr.* = item; + } + + /// Remove the element at index `i`, shift elements after index + /// `i` forward, and return the removed element. + /// Asserts the slice has at least one item. + /// This operation is O(N). + pub fn orderedRemove(self: *Self, i: usize) T { + const newlen = self.len - 1; + if (newlen == i) return self.pop().?; + const old_item = self.get(i); + for (self.slice()[i..newlen], 0..) |*b, j| b.* = self.get(i + 1 + j); + self.set(newlen, undefined); + self.len = newlen; + return old_item; + } + + /// Remove the element at the specified index and return it. + /// The empty slot is filled from the end of the slice. + /// This operation is O(1). 
+ pub fn swapRemove(self: *Self, i: usize) T { + if (self.len - 1 == i) return self.pop().?; + const old_item = self.get(i); + self.set(i, self.pop().?); + return old_item; + } + + /// Append the slice of items to the slice. + pub fn appendSlice(self: *Self, items: []const T) error{Overflow}!void { + try self.ensureUnusedCapacity(items.len); + self.appendSliceAssumeCapacity(items); + } + + /// Append the slice of items to the slice, asserting the capacity is already + /// enough to store the new items. + pub fn appendSliceAssumeCapacity(self: *Self, items: []const T) void { + const old_len = self.len; + self.len += items.len; + @memcpy(self.slice()[old_len..][0..items.len], items); + } + + /// Append a value to the slice `n` times. + /// Allocates more memory as necessary. + pub fn appendNTimes(self: *Self, value: T, n: usize) error{Overflow}!void { + const old_len = self.len; + try self.resize(old_len + n); + @memset(self.slice()[old_len..self.len], value); + } + + /// Append a value to the slice `n` times. + /// Asserts the capacity is enough. + pub fn appendNTimesAssumeCapacity(self: *Self, value: T, n: usize) void { + const old_len = self.len; + self.len += n; + assert(self.len <= buffer_capacity); + @memset(self.slice()[old_len..self.len], value); + } + + pub const Writer = if (T != u8) + @compileError("The Writer interface is only defined for BoundedArray(u8, ...) " ++ + "but the given type is BoundedArray(" ++ @typeName(T) ++ ", ...)") + else + std.io.GenericWriter(*Self, error{Overflow}, appendWrite); + + /// Initializes a writer which will write into the array. + pub fn writer(self: *Self) Writer { + return .{ .context = self }; + } + + /// Same as `appendSlice` except it returns the number of bytes written, which is always the same + /// as `m.len`. The purpose of this function existing is to match `std.io.GenericWriter` API. 
+ fn appendWrite(self: *Self, m: []const u8) error{Overflow}!usize { + try self.appendSlice(m); + return m.len; + } + }; +} + +test BoundedArray { + var a = try BoundedArray(u8, 64).init(32); + + try testing.expectEqual(a.capacity(), 64); + try testing.expectEqual(a.slice().len, 32); + try testing.expectEqual(a.constSlice().len, 32); + + try a.resize(48); + try testing.expectEqual(a.len, 48); + + const x = [_]u8{1} ** 10; + a = try BoundedArray(u8, 64).fromSlice(&x); + try testing.expectEqualSlices(u8, &x, a.constSlice()); + + var a2 = a; + try testing.expectEqualSlices(u8, a.constSlice(), a2.constSlice()); + a2.set(0, 0); + try testing.expect(a.get(0) != a2.get(0)); + + try testing.expectError(error.Overflow, a.resize(100)); + try testing.expectError(error.Overflow, BoundedArray(u8, x.len - 1).fromSlice(&x)); + + try a.resize(0); + try a.ensureUnusedCapacity(a.capacity()); + (try a.addOne()).* = 0; + try a.ensureUnusedCapacity(a.capacity() - 1); + try testing.expectEqual(a.len, 1); + + const uninitialized = try a.addManyAsArray(4); + try testing.expectEqual(uninitialized.len, 4); + try testing.expectEqual(a.len, 5); + + try a.append(0xff); + try testing.expectEqual(a.len, 6); + try testing.expectEqual(a.pop(), 0xff); + + a.appendAssumeCapacity(0xff); + try testing.expectEqual(a.len, 6); + try testing.expectEqual(a.pop(), 0xff); + + try a.resize(1); + try testing.expectEqual(a.pop(), 0); + try testing.expectEqual(a.pop(), null); + var unused = a.unusedCapacitySlice(); + @memset(unused[0..8], 2); + unused[8] = 3; + unused[9] = 4; + try testing.expectEqual(unused.len, a.capacity()); + try a.resize(10); + + try a.insert(5, 0xaa); + try testing.expectEqual(a.len, 11); + try testing.expectEqual(a.get(5), 0xaa); + try testing.expectEqual(a.get(9), 3); + try testing.expectEqual(a.get(10), 4); + + try a.insert(11, 0xbb); + try testing.expectEqual(a.len, 12); + try testing.expectEqual(a.pop(), 0xbb); + + try a.appendSlice(&x); + try testing.expectEqual(a.len, 11 + x.len); + 
+ try a.appendNTimes(0xbb, 5); + try testing.expectEqual(a.len, 11 + x.len + 5); + try testing.expectEqual(a.pop(), 0xbb); + + a.appendNTimesAssumeCapacity(0xcc, 5); + try testing.expectEqual(a.len, 11 + x.len + 5 - 1 + 5); + try testing.expectEqual(a.pop(), 0xcc); + + try testing.expectEqual(a.len, 29); + try a.replaceRange(1, 20, &x); + try testing.expectEqual(a.len, 29 + x.len - 20); + + try a.insertSlice(0, &x); + try testing.expectEqual(a.len, 29 + x.len - 20 + x.len); + + try a.replaceRange(1, 5, &x); + try testing.expectEqual(a.len, 29 + x.len - 20 + x.len + x.len - 5); + + try a.append(10); + try testing.expectEqual(a.pop(), 10); + + try a.append(20); + const removed = a.orderedRemove(5); + try testing.expectEqual(removed, 1); + try testing.expectEqual(a.len, 34); + + a.set(0, 0xdd); + a.set(a.len - 1, 0xee); + const swapped = a.swapRemove(0); + try testing.expectEqual(swapped, 0xdd); + try testing.expectEqual(a.get(0), 0xee); + + const added_slice = try a.addManyAsSlice(3); + try testing.expectEqual(added_slice.len, 3); + try testing.expectEqual(a.len, 36); + + while (a.pop()) |_| {} + const w = a.writer(); + const s = "hello, this is a test string"; + try w.writeAll(s); + try testing.expectEqualStrings(s, a.constSlice()); +} + +test "BoundedArrayAligned" { + var a = try BoundedArrayAligned(u8, .@"16", 4).init(0); + try a.append(0); + try a.append(0); + try a.append(255); + try a.append(255); + + const b = @as(*const [2]u16, @ptrCast(a.constSlice().ptr)); + try testing.expectEqual(@as(u16, 0), b[0]); + try testing.expectEqual(@as(u16, 65535), b[1]); +} diff --git a/stdx/flags.zig b/stdx/flags.zig index 7aab69e..467cea6 100644 --- a/stdx/flags.zig +++ b/stdx/flags.zig @@ -47,8 +47,7 @@ const debug = @import("debug.zig"); /// Format and print an error message to stderr, then exit with an exit code of 1. 
pub fn fatal(comptime fmt_string: []const u8, args: anytype) noreturn { - const stderr = std.io.getStdErr().writer(); - stderr.print("error: " ++ fmt_string ++ "\n", args) catch {}; + std.debug.print("error: " ++ fmt_string ++ "\n", args); std.posix.exit(1); } diff --git a/stdx/fmt.zig b/stdx/fmt.zig index 3934de1..259b3cd 100644 --- a/stdx/fmt.zig +++ b/stdx/fmt.zig @@ -43,8 +43,8 @@ pub const IntFmt = struct { }; pub const FloatFmt = enum(u8) { - scientific = @intFromEnum(std.fmt.format_float.Format.scientific), - decimal = @intFromEnum(std.fmt.format_float.Format.decimal), + scientific = @intFromEnum(std.fmt.Number.Mode.scientific), + decimal = @intFromEnum(std.fmt.Number.Mode.decimal), hex, pub fn parseComptime(comptime fmt_: []const u8) FloatFmt { @@ -71,44 +71,34 @@ pub fn formatValue(value: anytype, full: FullFormatOptions, writer: anytype) !vo }; } -pub fn formatFloatValue(value: anytype, full: FullFormatOptions, writer: anytype) !void { - const formatFloat = std.fmt.format_float.formatFloat; - var buf: [std.fmt.format_float.bufferSize(.decimal, f64)]u8 = undefined; - +pub fn formatFloatValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void { const x = switch (@typeInfo(@TypeOf(value))) { .@"struct" => value.toF32(), .float => value, else => @compileError("formatFloatValue expects a float, got: " ++ @typeName(@TypeOf(value))), }; - const s_or_err = switch (full.fmt.float) { - .scientific => formatFloat(&buf, x, .{ .mode = .scientific, .precision = full.options.precision }), - .decimal => formatFloat(&buf, x, .{ .mode = .decimal, .precision = full.options.precision }), - .hex => hex: { - var buf_stream = std.io.fixedBufferStream(&buf); - std.fmt.formatFloatHexadecimal(x, full.options, buf_stream.writer()) catch unreachable; - break :hex buf_stream.getWritten(); - }, + try switch (full.fmt.float) { + .scientific => writer.printFloat(x, .{ .mode = .scientific, .precision = full.options.precision }), + .decimal => writer.printFloat(x, .{ 
.mode = .decimal, .precision = full.options.precision }), + .hex => writer.printFloatHexOptions(x, .{ .mode = .hex }), }; - - const s = s_or_err catch "(float)"; - return std.fmt.formatBuf(s, full.options, writer); } -pub fn formatIntValue(value: anytype, full: FullFormatOptions, writer: anytype) !void { +pub fn formatIntValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void { switch (@typeInfo(@TypeOf(value))) { .int => {}, else => @compileError("formatIntValue expects an int, got: " ++ @typeName(@TypeOf(value))), } - return std.fmt.formatInt(value, full.fmt.int.base, full.fmt.int.case, full.options, writer); + return writer.printInt(value, full.fmt.int.base, full.fmt.int.case, full.options); } -pub fn formatAnyValue(value: anytype, full: FullFormatOptions, writer: anytype) !void { +pub fn formatAnyValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void { var buf: [48]u8 = undefined; const s = std.fmt.bufPrint(&buf, "{any}", .{value}) catch blk: { buf[45..].* = "...".*; break :blk buf[0..]; }; - return std.fmt.formatBuf(s, full.options, writer); + return try writer.alignBufferOptions(s, full.options); } pub fn formatSliceCustom(fmt_func: anytype, values: anytype, full: FullFormatOptions, writer: anytype) !void { diff --git a/stdx/signature.zig b/stdx/signature.zig index 9938fb2..2f042ce 100644 --- a/stdx/signature.zig +++ b/stdx/signature.zig @@ -26,7 +26,7 @@ pub fn ArgsTuple(comptime funcT: anytype, comptime ArgsT: ?type) type { var num_buf: [8]u8 = undefined; tuple_fields[i] = .{ .name = blk: { - const s = std.fmt.formatIntBuf(&num_buf, i, 10, .lower, .{}); + const s = std.fmt.printInt(&num_buf, i, 10, .lower, .{}); num_buf[s] = 0; break :blk num_buf[0..s :0]; }, diff --git a/stdx/stdx.zig b/stdx/stdx.zig index 87cc231..661dea6 100644 --- a/stdx/stdx.zig +++ b/stdx/stdx.zig @@ -1,3 +1,5 @@ +pub const BoundedArray = @import("bounded_array.zig").BoundedArray; +pub const BoundedArrayAligned = 
@import("bounded_array.zig").BoundedArrayAligned; pub const debug = @import("debug.zig"); pub const flags = @import("flags.zig"); pub const fmt = @import("fmt.zig"); diff --git a/stdx/time.zig b/stdx/time.zig index b77a489..48d38fd 100644 --- a/stdx/time.zig +++ b/stdx/time.zig @@ -11,14 +11,11 @@ pub const Duration = struct { return (1 * std.time.ns_per_s) / self.ns; } - pub fn format( - self: Duration, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - writer: anytype, - ) @TypeOf(writer).Error!void { - return try std.fmt.fmtDuration(self.ns).format(fmt, options, writer); + pub fn formatDuration(duration: Duration, writer: *std.io.Writer) std.io.Writer.Error!void { + try writer.printDuration(duration.ns, .{}); } + + pub const format = formatDuration; }; pub const Timer = struct { diff --git a/third_party/com_github_hejsil_clap/repo.bzl b/third_party/com_github_hejsil_clap/repo.bzl index f665c74..4359f51 100644 --- a/third_party/com_github_hejsil_clap/repo.bzl +++ b/third_party/com_github_hejsil_clap/repo.bzl @@ -4,6 +4,6 @@ def repo(): new_git_repository( name = "com_github_hejsil_clap", remote = "https://github.com/Hejsil/zig-clap.git", - commit = "068c38f89814079635692c7d0be9f58508c86173", + commit = "5289e0753cd274d65344bef1c114284c633536ea", build_file = "//:third_party/com_github_hejsil_clap/clap.bazel", ) diff --git a/third_party/modules/rules_zig/20250827.0-35b6d57/MODULE.bazel b/third_party/modules/rules_zig/20250827.0-35b6d57/MODULE.bazel new file mode 100644 index 0000000..7620c53 --- /dev/null +++ b/third_party/modules/rules_zig/20250827.0-35b6d57/MODULE.bazel @@ -0,0 +1,75 @@ +module( + name = "rules_zig", + version = "20250827.0-35b6d57", + compatibility_level = 1, +) + +bazel_dep(name = "aspect_bazel_lib", version = "2.8.1") +bazel_dep(name = "bazel_skylib", version = "1.7.1") +bazel_dep(name = "platforms", version = "0.0.10") + +zig = use_extension("//zig:extensions.bzl", "zig") +zig.index(file = "//zig/private:versions.json") 
+use_repo(zig, "zig_toolchains") + +register_toolchains("@rules_zig//zig/target:all") + +register_toolchains("@zig_toolchains//:all") + +zig_dev = use_extension( + "//zig:extensions.bzl", + "zig", + dev_dependency = True, +) +zig_dev.toolchain(zig_version = "0.13.0") +zig_dev.toolchain(zig_version = "0.12.1") +zig_dev.toolchain(zig_version = "0.12.0") +zig_dev.toolchain(zig_version = "0.11.0") + +bazel_dep(name = "rules_cc", version = "0.0.9") + +bazel_dep(name = "stardoc", version = "0.7.0", dev_dependency = True, repo_name = "io_bazel_stardoc") +bazel_dep(name = "gazelle", version = "0.38.0", dev_dependency = True, repo_name = "bazel_gazelle") +bazel_dep(name = "bazel_skylib_gazelle_plugin", version = "1.7.1", dev_dependency = True) +bazel_dep( + name = "buildifier_prebuilt", + version = "7.3.1", + dev_dependency = True, +) +bazel_dep(name = "rules_multirun", version = "0.9.0", dev_dependency = True) +bazel_dep(name = "rules_python", version = "0.35.0", dev_dependency = True) +bazel_dep( + name = "rules_bazel_integration_test", + version = "0.25.0", + dev_dependency = True, +) + +bazel_binaries = use_extension( + "@rules_bazel_integration_test//:extensions.bzl", + "bazel_binaries", + dev_dependency = True, +) + +# NOTE: Keep in sync with WORKSPACE. +bazel_binaries.download(version_file = "//:.bazelversion") +bazel_binaries.download(version = "7.0.0") +use_repo( + bazel_binaries, + "bazel_binaries", + "bazel_binaries_bazelisk", + "build_bazel_bazel_.bazelversion", + "build_bazel_bazel_7_0_0", +) + +# TODO[AH] Should be an implicit transitive dependency through rules_bazel_integration_test. +# However, if we do not include it explicitly, then the runfiles resolution for +# cgrindel_bazel_starlib/shlib/lib/message.sh fails in +# rules_bazel_integration_test/tools/update_deleted_packages.sh when invoked +# through the rules_multirun target //util:update. 
+bazel_dep(name = "cgrindel_bazel_starlib", version = "0.21.0", dev_dependency = True) + +# Hack to get around a cc_common.link limitation. +# See https://github.com/bazelbuild/bazel/pull/23838 +cc_common_link = use_repo_rule("//zig:extensions.bzl", "cc_common_link") + +cc_common_link(name = "build_bazel_rules_android") diff --git a/third_party/modules/rules_zig/20250827.0-35b6d57/source.json b/third_party/modules/rules_zig/20250827.0-35b6d57/source.json new file mode 100644 index 0000000..e42d62b --- /dev/null +++ b/third_party/modules/rules_zig/20250827.0-35b6d57/source.json @@ -0,0 +1,5 @@ +{ + "strip_prefix": "rules_zig-35b6d57e94f3eb08e978c1133314606f4f38e216", + "url": "https://github.com/zml/rules_zig/archive/35b6d57e94f3eb08e978c1133314606f4f38e216.tar.gz", + "integrity": "sha256-FDnAqynTD2LB3W/IaBgocmnLz8CA9nyHZYYntC4plUU=" +} diff --git a/third_party/modules/rules_zig/metadata.json b/third_party/modules/rules_zig/metadata.json index 00757b8..e7937f1 100644 --- a/third_party/modules/rules_zig/metadata.json +++ b/third_party/modules/rules_zig/metadata.json @@ -17,7 +17,8 @@ "20250519.0-233b207", "20250613.0-567662a", "20250714.0-b14a4f1", - "20250821.0-be53625" + "20250821.0-be53625", + "20250827.0-35b6d57" ], "yanked_versions": {} } diff --git a/tools/BUILD.bazel b/tools/BUILD.bazel index af65de3..5dc3d21 100644 --- a/tools/BUILD.bazel +++ b/tools/BUILD.bazel @@ -1,15 +1,7 @@ load("@rules_python//python:py_library.bzl", "py_library") -load("@rules_python//python/entry_points:py_console_script_binary.bzl", "py_console_script_binary") py_library( name = "zml_utils", srcs = ["zml_utils.py"], visibility = ["//visibility:public"], ) - -py_console_script_binary( - name = "hf", - pkg = "@huggingface_hub//huggingface_hub:pkg", - script = "hf", - visibility = ["//visibility:public"], -) diff --git a/upb/upb.zig b/upb/upb.zig index 27d1502..1a233ec 100644 --- a/upb/upb.zig +++ b/upb/upb.zig @@ -153,6 +153,6 @@ pub const Allocator = struct { @panic("Unsupported 
case"); } - return @ptrCast(self.allocator.alignedAlloc(u8, @alignOf(*anyopaque), size) catch return null); + return @ptrCast(self.allocator.alignedAlloc(u8, std.mem.Alignment.of(*anyopaque), size) catch return null); } }; diff --git a/zml/aio.zig b/zml/aio.zig index 08f352e..67e1876 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -608,7 +608,7 @@ fn findSimilarBufferKeys(original_key: []const u8, store: BufferStore, temp_allo if (std.mem.startsWith(u8, key, base_key)) { if (matches == 0) log.warn("Similar buffers found:", .{}); if (!shown_keys.contains(key)) { - log.warn(" - {s}: {}", .{ key, entry.value_ptr.*.shape() }); + log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() }); shown_keys.put(key, {}) catch continue; matches += 1; } @@ -625,7 +625,7 @@ fn findSimilarBufferKeys(original_key: []const u8, store: BufferStore, temp_allo const key = entry.key_ptr.*; if (std.mem.indexOf(u8, key, component) != null and !shown_keys.contains(key)) { if (matches == 0) log.warn("Partial matches for '{s}':", .{component}); - log.warn(" - {s}: {}", .{ key, entry.value_ptr.*.shape() }); + log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() }); shown_keys.put(key, {}) catch continue; matches += 1; if (matches >= 5) break; @@ -660,8 +660,8 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi return if (buffer_store.get(prefix)) |host_buffer| { // obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us. 
var buf_with_metadata = host_buffer; - log.debug("Loading buffer {s} ({})", .{ prefix, obj._shape }); - stdx.debug.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix }); + log.debug("Loading buffer {s} ({f})", .{ prefix, obj._shape }); + stdx.debug.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {f} and {f} for tensor {s}", .{ obj._shape, host_buffer, prefix }); buf_with_metadata._shape = obj._shape; obj.* = try zml.Buffer.from(platform, buf_with_metadata, .{}); } else { diff --git a/zml/aio/json.zig b/zml/aio/json.zig index 9c02a5d..a42e561 100644 --- a/zml/aio/json.zig +++ b/zml/aio/json.zig @@ -69,7 +69,7 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, prefix: var new_prefix = prefix; if (prefix.items.len > 0) new_prefix.appendAssumeCapacity('.'); - new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); + new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); try parseMetadata(allocator, store, new_prefix, item); } }; diff --git a/zml/aio/safetensors.zig b/zml/aio/safetensors.zig index 581ab77..08952b8 100644 --- a/zml/aio/safetensors.zig +++ b/zml/aio/safetensors.zig @@ -1,12 +1,15 @@ -const asynk = @import("async"); const std = @import("std"); -const zml = @import("../zml.zig"); -const json = @import("json.zig"); -const HostBuffer = zml.HostBuffer; +const Allocator = std.mem.Allocator; + +const asynk = @import("async"); +const stdx = @import("stdx"); + const MemoryMappedFile = @import("../aio.zig").MemoryMappedFile; +const zml = @import("../zml.zig"); +const HostBuffer = zml.HostBuffer; +const json = @import("json.zig"); const StringBuilder = std.ArrayListUnmanaged(u8); -const Allocator = std.mem.Allocator; const log = 
std.log.scoped(.@"zml/io"); pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore { @@ -16,7 +19,7 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore errdefer res.arena.deinit(); const arena = res.arena.allocator(); - var files = std.ArrayList(MemoryMappedFile).init(arena); + var files = std.array_list.Managed(MemoryMappedFile).init(arena); errdefer files.deinit(); if (std.mem.endsWith(u8, path, ".safetensors.index.json")) { @@ -28,17 +31,19 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore return res; } -fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void { +fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.array_list.Managed(MemoryMappedFile), path: []const u8) !void { const file = asynk.File.open(path, .{}) catch |err| { log.err("Failed to open {s}: {}", .{ path, err }); return err; }; errdefer file.close() catch unreachable; - var r = file.reader(); + var buffer: [4096]u8 = undefined; + var r = file.reader(&buffer); - const json_data = try allocator.alloc(u8, (try file.stat()).size); - _ = try r.readAtLeast(json_data, json_data.len); - const index = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data, .{ .allocate = .alloc_if_needed }); + // const json_data = try allocator.alloc(u8, (try file.stat()).size); + var json_reader = std.json.Reader.init(allocator, &r.interface); + // _ = try r.readAtLeast(json_data, json_data.len); + const index = try std.json.parseFromTokenSourceLeaky(std.json.Value, allocator, &json_reader, .{ .allocate = .alloc_if_needed }); var loaded_files = std.StringHashMap(void).init(allocator); const weight_map = index.object.get("weight_map").?.object; @@ -62,23 +67,24 @@ fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std. 
} } -fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void { +fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.array_list.Managed(MemoryMappedFile), path: []const u8) !void { const file = asynk.File.open(path, .{}) catch |err| { log.err("Failed to open {s}: {}", .{ path, err }); return err; }; errdefer file.close() catch unreachable; - var r = file.reader(); + var buffer: [16 * 1024]u8 = undefined; + var r = file.reader(&buffer); - const json_header_length: usize = @intCast(try r.readInt(u64, std.builtin.Endian.little)); + const json_header_length: usize = @intCast(try r.interface.takeInt(u64, .little)); const json_data = try allocator.alloc(u8, json_header_length); - const n = try r.readAll(json_data); - if (n != json_header_length) { - log.err("Failed to read the full {} bytes of json header from file {s}", .{ n, path }); - return error.CorruptedFile; - } + try r.interface.readSliceAll(json_data); + // if (n != json_header_length) { + // log.err("Failed to read the full {} bytes of json header from file {s}", .{ n, path }); + // return error.CorruptedFile; + // } - const metadata = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data[0..n], .{}); + const metadata = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data, .{}); var buffer_file = try MemoryMappedFile.init(file); errdefer buffer_file.deinit(); buffer_file.data_offset = 8 + json_header_length; @@ -105,7 +111,7 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.Array const start: usize = @intCast(offset_field.array.items[0].integer); const end: usize = @intCast(offset_field.array.items[1].integer); const dtype = try stringToDtype(val.object.get("dtype").?.string); - var dims: std.BoundedArray(i64, zml.Shape.MAX_RANK) = .{}; + var dims: stdx.BoundedArray(i64, zml.Shape.MAX_RANK) = .{}; for (shape_field.items) |d| { 
dims.appendAssumeCapacity(d.integer); } diff --git a/zml/buffer.zig b/zml/buffer.zig index 36f6125..3553bb9 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -45,7 +45,7 @@ pub const Buffer = struct { _shards: Shards, pub const MAX_NUM_SHARDS: u8 = Platform.MAX_NUM_DEVICES; - pub const Shards = std.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS); + pub const Shards = stdx.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS); pub const FromOptions = struct { wait: bool = true, @@ -67,7 +67,7 @@ pub const Buffer = struct { const n_partitions = platform.sharding().num_partitions; const chunk_size = if (sharding_ax) |ax| cs: { // This kind of sharding error should be detected earlier on. - stdx.debug.assert(@rem(host_buffer.dim(ax), n_partitions) == 0, "Buffer.from({}) expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ host_buffer, ax, n_partitions }); + stdx.debug.assert(@rem(host_buffer.dim(ax), n_partitions) == 0, "Buffer.from({f}) expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ host_buffer, ax, n_partitions }); break :cs @divExact(host_buffer.dim(ax), n_partitions); } else 0; @@ -201,7 +201,7 @@ pub const Buffer = struct { const duration_ms = stdx.math.divFloat(f32, start.read(), std.time.ns_per_ms); if (duration_ms > 100) { const size_gb = stdx.math.divFloat(f32, shape_.byteSize(), 1024 * 1024 * 1024); - log.info("Wrote constant({_}) to device ({d:.2}Gb) in {d:.0}ms: {d:.2}Gb/s", .{ shape_, size_gb, duration_ms, size_gb / duration_ms * 1000 }); + log.debug("Wrote constant({f}) to device ({d:.2}Gb) in {d:.0}ms: {d:.2}Gb/s", .{ shape_, size_gb, duration_ms, size_gb / duration_ms * 1000 }); } } @@ -301,7 +301,7 @@ pub const Buffer = struct { /// Fetches the content of the given buffer into a stack variable of the given type. 
pub fn getValue(self: Buffer, T: type) !T { - stdx.debug.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) }); + stdx.debug.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {f} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) }); var res: T = undefined; stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); const maybe_event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res)); @@ -375,13 +375,9 @@ pub const Buffer = struct { pub fn format( self: Buffer, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, writer: anytype, ) !void { - _ = fmt; - _ = options; - try writer.print("Buffer({_})", .{self._shape}); + try writer.print("Buffer({f})", .{self._shape}); } pub fn getMemory(self: Buffer) *const pjrt.Memory { diff --git a/zml/context.zig b/zml/context.zig index bde867f..15acb93 100644 --- a/zml/context.zig +++ b/zml/context.zig @@ -224,7 +224,7 @@ const CustomCall = struct { try ffi.register(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback, .{}); } - fn hostBufferCallback(call_frame: *pjrt.ffi.CallFrame) callconv(.C) ?*pjrt.ffi.Error { + fn hostBufferCallback(call_frame: *pjrt.ffi.CallFrame) callconv(.c) ?*pjrt.ffi.Error { if (call_frame.registeringHook()) return null; const callback_attr = call_frame.attrs.getByName(.scalar, "callback") orelse unreachable; @@ -275,3 +275,69 @@ fn hostBufferFromPinnedBuffer(buffer_desc: *const pjrt.ffi.Buffer) HostBuffer { buffer_desc.data[0..buffer_shape.byteSize()], ); } + +pub const cuda = struct { + pub var streamSynchronize: StreamSynchronize = @ptrFromInt(0xdeadc00da00); + pub var cuLaunchHostFunc: CuLaunchHostFunc = @ptrFromInt(0xdeadc00da00); + var _memcpyAsync: MemcpyAsync = @ptrFromInt(0xdeadc00da00); 
+ var _memcpyBlocking: MemcpyBlocking = @ptrFromInt(0xdeadc00da00); + + pub const MemcpyKind = enum(c_int) { + host_to_host = 0, + host_to_device = 1, + device_to_host = 2, + device_to_device = 3, + inferred = 4, + }; + + const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: ?*anyopaque) callconv(.c) c_int; + const MemcpyBlocking = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind) callconv(.c) c_int; + const StreamSynchronize = *const fn (stream: *anyopaque) callconv(.c) c_int; + const CuLaunchHostFunc = *const fn (stream: *anyopaque, host_func: *const fn (user_data: *const anyopaque) callconv(.c) void, user_data: *const anyopaque) callconv(.c) c_int; + + pub fn init() void { + var cudart = std.DynLib.open("libcudart.so.12") catch { + log.err("cudart not found, callback will segfault", .{}); + return; + }; + defer cudart.close(); + + _memcpyAsync = cudart.lookup(MemcpyAsync, "cudaMemcpyAsync") orelse { + @panic("cudaMemcpyAsync not found"); + }; + _memcpyBlocking = cudart.lookup(MemcpyBlocking, "cudaMemcpy") orelse { + @panic("cudaMemcpy not found"); + }; + streamSynchronize = cudart.lookup(StreamSynchronize, "cudaStreamSynchronize") orelse { + @panic("cudaStreamSynchronize not found"); + }; + cuLaunchHostFunc = cudart.lookup(CuLaunchHostFunc, "cudaLaunchHostFunc") orelse { + @panic("cudaLaunchHostFunc not found"); + }; + } + + pub fn memcpyToHostBlocking(dst: []u8, src: *const anyopaque) void { + const err = _memcpyBlocking(dst.ptr, src, dst.len, .device_to_host); + check(err); + } + + pub fn memcpyToDeviceBlocking(dst: *anyopaque, src: []const u8) void { + const err = _memcpyBlocking(dst, src.ptr, src.len, .host_to_device); + check(err); + } + + pub fn memcpyToDeviceAsync(dst: *anyopaque, src: []const u8, stream: ?*anyopaque) void { + const err = _memcpyAsync(dst, src.ptr, src.len, .host_to_device, stream); + check(err); + } + + pub fn memcpyToHostAsync(dst: []u8, src: 
*const anyopaque, stream: ?*anyopaque) void { + const err = _memcpyAsync(dst.ptr, src, dst.len, .device_to_host, stream); + check(err); + } + + pub fn check(err: c_int) void { + if (err == 0) return; + stdx.debug.panic("CUDA error: {d}", .{err}); + } +}; diff --git a/zml/exe.zig b/zml/exe.zig index 0923f17..625fecf 100644 --- a/zml/exe.zig +++ b/zml/exe.zig @@ -394,8 +394,8 @@ fn fillBuffers(v: anytype, shapes: []const Shape, buffers: []const [*]*pjrt.Buff fn cb(ctx: *LocalContext, buffer: *const Buffer) void { // stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index }); const model_sharding = ctx.buffers.len; - stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len }); - stdx.debug.assert(ctx.shapes[ctx.index].eql(buffer.shape()), "Executable expected argument {} to have shape {}, got {}", .{ ctx.index, ctx.shapes[ctx.index], buffer.shape() }); + stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {d}-sharded tensor into a {d}-sharded model", .{ buffer._shards.len, ctx.buffers.len }); + stdx.debug.assert(ctx.shapes[ctx.index].eql(buffer.shape()), "Executable expected argument {} to have shape {f}, got {f}", .{ ctx.index, ctx.shapes[ctx.index], buffer.shape() }); for (buffer._shards.constSlice(), 0..) 
|shard, d| { ctx.buffers[d][ctx.index] = shard; } diff --git a/zml/floats.zig b/zml/floats.zig index 74a8b99..57bf515 100644 --- a/zml/floats.zig +++ b/zml/floats.zig @@ -87,17 +87,10 @@ fn FloatHelpers(Float: type) type { return std.math.maxInt(std.meta.Int(.unsigned, exponent_bits - 1)); } - pub fn format( - float: Float, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - writer: anytype, - ) !void { - _ = options; - if (fmt.len == 1 and fmt[0] == '_') { - try writer.print("{{ .sign={}, .exp={}, .mantissa={} }}", .{ float.sign, float.exponent, float.mantissa }); - } else { - try writer.print("{" ++ fmt ++ "}", .{float.toF32()}); + pub fn formatNumber(x: Float, writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void { + switch (n.mode) { + .binary, .octal, .hex => try writer.print("{{ .sign={}, .exp={}, .mantissa={} }}", .{ x.sign, x.exponent, x.mantissa }), + else => try writer.printFloat(x.toF32(), n), } } }; @@ -113,7 +106,7 @@ pub const Float32 = packed struct(u32) { pub const neg = Helpers.neg; pub const fromF32 = Helpers.fromF32; pub const toF32 = Helpers.toF32; - pub const format = Helpers.format; + pub const formatNumber = Helpers.formatNumber; }; const f32_exp_bias = FloatHelpers(Float32).expBias(); @@ -128,7 +121,7 @@ pub const Float64 = packed struct(u64) { pub const neg = Helpers.neg; pub const fromF32 = Helpers.fromF32; pub const toF32 = Helpers.toF32; - pub const format = Helpers.format; + pub const formatNumber = Helpers.formatNumber; }; pub const Float8E4M3B11FNUZ = packed struct(u8) { @@ -151,7 +144,7 @@ pub const Float8E4M3B11FNUZ = packed struct(u8) { pub const neg = Helpers.neg; pub const fromF32 = Helpers.fromF32; pub const toF32 = Helpers.toF32; - pub const format = Helpers.format; + pub const formatNumber = Helpers.formatNumber; }; pub const Float8E4M3FN = packed struct(u8) { @@ -169,7 +162,7 @@ pub const Float8E4M3FN = packed struct(u8) { pub const neg = Helpers.neg; pub const fromF32 = Helpers.fromF32; pub const 
toF32 = Helpers.toF32; - pub const format = Helpers.format; + pub const formatNumber = Helpers.formatNumber; }; pub const Float8E4M3FNUZ = packed struct(u8) { @@ -192,7 +185,7 @@ pub const Float8E4M3FNUZ = packed struct(u8) { pub const neg = Helpers.neg; pub const fromF32 = Helpers.fromF32; pub const toF32 = Helpers.toF32; - pub const format = Helpers.format; + pub const formatNumber = Helpers.formatNumber; }; test "Float8E4" { @@ -247,7 +240,7 @@ pub const Float8E5M2 = packed struct(u8) { pub const neg = Helpers.neg; pub const fromF32 = Helpers.fromF32; pub const toF32 = Helpers.toF32; - pub const format = Helpers.format; + pub const formatNumber = Helpers.formatNumber; }; pub const Float8E5M2FNUZ = packed struct(u8) { @@ -266,7 +259,7 @@ pub const Float8E5M2FNUZ = packed struct(u8) { pub const neg = Helpers.neg; pub const fromF32 = Helpers.fromF32; pub const toF32 = Helpers.toF32; - pub const format = Helpers.format; + pub const formatNumber = Helpers.formatNumber; }; test "Float8E5" { @@ -322,7 +315,7 @@ pub const BFloat16 = packed struct(u16) { const Helpers = FloatHelpers(@This()); pub const zero = Helpers.zero; pub const neg = Helpers.neg; - pub const format = Helpers.format; + pub const formatNumber = Helpers.formatNumber; }; test BFloat16 { diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig index 8f6fcb1..117cb1e 100644 --- a/zml/hostbuffer.zig +++ b/zml/hostbuffer.zig @@ -31,7 +31,7 @@ pub const HostBuffer = struct { return .{ ._shape = sh, ._strides = sh.computeStrides().buffer, - ._data = (try allocator.alignedAlloc(u8, 64, sh.byteSize())).ptr, + ._data = (try allocator.alignedAlloc(u8, .@"64", sh.byteSize())).ptr, ._memory = .{ .managed = .@"64" }, }; } @@ -170,8 +170,8 @@ pub const HostBuffer = struct { /// Strided buffers can't use this method. pub fn items(self: HostBuffer, comptime T: type) []const T { // TODO we should allow interpreting the output as @Vector(8, f32) when the tensor is f32. 
- stdx.debug.assert(DataType.fromZigType(T) == self.dtype(), "Can't reinterpret {} as {s}", .{ self, @typeName(T) }); - stdx.debug.assert(self.isContiguous(), "{} isn't contiguous, can't interpret as []const u8", .{self}); + stdx.debug.assert(DataType.fromZigType(T) == self.dtype(), "Can't reinterpret {f} as {s}", .{ self, @typeName(T) }); + stdx.debug.assert(self.isContiguous(), "{f} isn't contiguous, can't interpret as []const u8", .{self}); const ptr: [*]const T = @alignCast(@ptrCast(self._data)); return ptr[0..self._shape.count()]; } @@ -181,7 +181,7 @@ pub const HostBuffer = struct { } pub fn bytes(self: HostBuffer) []const u8 { - stdx.debug.assert(self.isContiguous(), "{} isn't contiguous, can't interpret as []const u8", .{self}); + stdx.debug.assert(self.isContiguous(), "{f} isn't contiguous, can't interpret as []const u8", .{self}); return self._data[0..self._shape.byteSize()]; } @@ -233,7 +233,7 @@ pub const HostBuffer = struct { } pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer { - stdx.debug.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {}", .{self}); + stdx.debug.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {f}", .{self}); var res = self; res._shape = self._shape.reshape(shape_); res._strides = res._shape.computeStrides().buffer; @@ -252,9 +252,9 @@ pub const HostBuffer = struct { const start: i64 = if (s.start < 0) s.start + d else s.start; var end = s.end orelse d; if (end < 0) end += d; - stdx.debug.assert(start >= 0 and start < d, "slice1d({}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, s }); - stdx.debug.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, s }); - stdx.debug.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({}), got: {}", .{ self, ax, start, end, s }); + stdx.debug.assert(start >= 0 and start < d, "slice1d({f}, {}) expects the 
slice start to be between 0 and {} got: {}", .{ self, ax, d, s }); + stdx.debug.assert(end >= 1 and end <= d, "slice1d({f}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, s }); + stdx.debug.assert(start < end, "slice1d({f}, {}) expects the slice start ({}) to be smaller than the end ({}), got: {}", .{ self, ax, start, end, s }); const offset: usize = @intCast(start * self._strides[ax]); const new_shape = self.shape().set(ax, end - start); @@ -308,9 +308,9 @@ pub const HostBuffer = struct { pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer { const ax = self._shape.axis(axis_); - stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {}", .{ ax, self }); + stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {f}", .{ ax, self }); - var strd: std.BoundedArray(i64, Shape.MAX_RANK) = .{ .buffer = self._strides, .len = self.rank() }; + var strd: stdx.BoundedArray(i64, Shape.MAX_RANK) = .{ .buffer = self._strides, .len = self.rank() }; _ = strd.orderedRemove(ax); return .{ @@ -323,16 +323,11 @@ pub const HostBuffer = struct { pub fn format( self: HostBuffer, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, writer: anytype, ) !void { - _ = options; - if (std.mem.eql(u8, fmt, "v")) { - try writer.print("HostBuffer(.{_})@0x{x}", .{ self._shape, @intFromPtr(self._data) }); - } else { - try writer.print("HostBuffer(.{_})", .{self._shape}); - } + // TODO debug option + // try writer.print("HostBuffer(.{f})@0x{x}", .{ self._shape, @intFromPtr(self._data) }); + try writer.print("HostBuffer(.{f})", .{self._shape}); } /// Formatter for a HostBuffer that also print the values not just the shape. 
@@ -344,21 +339,23 @@ pub const HostBuffer = struct { pub const PrettyPrinter = struct { x: HostBuffer, - pub fn format(self: PrettyPrinter, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { + // TODO(0.15.0) revisit pretty printer + pub fn format(self: PrettyPrinter, writer: anytype) !void { const fmt_: stdx.fmt.Fmt = switch (self.x.dtype().class()) { - .integer => .parse(i32, fmt), - .float => .parse(f32, fmt), - else => .parse(void, fmt), + .integer => .parse(i32, "d"), + .float => .parse(f32, "d"), + else => .parse(void, ""), }; + const options: std.fmt.FormatOptions = .{}; try prettyPrint(self.x, writer, .{ .fmt = fmt_, .options = options }); } }; - pub fn prettyPrint(self: HostBuffer, writer: anytype, options: stdx.fmt.FullFormatOptions) !void { + pub fn prettyPrint(self: HostBuffer, writer: *std.Io.Writer, options: stdx.fmt.FullFormatOptions) !void { return self.prettyPrintIndented(writer, 4, 0, options); } - fn prettyPrintIndented(self: HostBuffer, writer: anytype, num_rows: u8, indent_level: u8, options: stdx.fmt.FullFormatOptions) !void { + fn prettyPrintIndented(self: HostBuffer, writer: *std.Io.Writer, num_rows: u8, indent_level: u8, options: stdx.fmt.FullFormatOptions) !void { if (self.rank() == 0) { // Special case input tensor is a scalar return switch (self.dtype()) { @@ -376,7 +373,7 @@ pub const HostBuffer = struct { if (self.rank() == 1) { // Print a contiguous slice of items from the buffer in one line. // The number of items printed is controlled by the user through format syntax. - try writer.writeByteNTimes(' ', indent_level); + try writer.splatByteAll(' ', indent_level); switch (self.dtype()) { inline else => |dt| { const values = self.items(dt.toZigType()); @@ -391,10 +388,10 @@ pub const HostBuffer = struct { return; } // TODO: consider removing the \n if dim is 1 for this axis. 
- try writer.writeByteNTimes(' ', indent_level); + try writer.splatByteAll(' ', indent_level); _ = try writer.write("{\n"); defer { - writer.writeByteNTimes(' ', indent_level) catch {}; + writer.splatByteAll(' ', indent_level) catch {}; _ = writer.write("},\n") catch {}; } @@ -409,7 +406,7 @@ pub const HostBuffer = struct { if (n < num_rows) return; // Skip middle rows if (n > 2 * num_rows) { - try writer.writeByteNTimes(' ', indent_level + 2); + try writer.splatByteAll(' ', indent_level + 2); _ = try writer.write("...\n"); } // Write last rows diff --git a/zml/meta.zig b/zml/meta.zig index ad50fe0..ee2fbc6 100644 --- a/zml/meta.zig +++ b/zml/meta.zig @@ -358,10 +358,11 @@ pub fn MapRestrict(From: type, To: type) type { const fields = union_info.fields; var union_fields: [fields.len]std.builtin.Type.UnionField = undefined; for (0.., fields) |i, field| { + const FT = map(field.type); union_fields[i] = .{ .name = field.name, - .type = map(field.type), - .alignment = 0, + .type = FT, + .alignment = @alignOf(FT), }; } return @Type(.{ .@"union" = .{ @@ -453,7 +454,7 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void { else => {}, } - // Handle std.BoundedArray that contains uninitalized data. + // Handle stdx.BoundedArray that contains uninitialized data. 
if (@typeInfo(Child) == .@"struct" and @hasDecl(Child, "constSlice") and @hasDecl(Child, "slice")) { return visit(cb, ctx, if (mutating_cb) v.slice() else v.constSlice()); } @@ -511,7 +512,7 @@ test visit { const NestedAttrOptional = struct { nested: ?Attr }; const SimpleStruct = struct { prop: Attr }; const MultipleTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: ?Attr }; - const NestedTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: NestedAttr, prop4: NestedAttrOptional, prop5: std.BoundedArray(Attr, 8) }; + const NestedTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: NestedAttr, prop4: NestedAttrOptional, prop5: stdx.BoundedArray(Attr, 8) }; const LocalContext = struct { result: usize }; @@ -565,7 +566,7 @@ test visit { } { var context: LocalContext = .{ .result = 0 }; - const prop5: std.BoundedArray(Attr, 8) = .{ + const prop5: stdx.BoundedArray(Attr, 8) = .{ .buffer = @splat(.{ .data = 4 }), .len = 2, }; @@ -677,11 +678,11 @@ test zip { /// Given a func(X) -> Y or a func(Ctx, X) -> Y, /// finds all X in the given object, and write the result of func(X) into an arraylist. 
-pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT), obj: anytype) error{OutOfMemory}!void { +pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.array_list.Managed(stdx.meta.FnSignature(func, null).ReturnT), obj: anytype) error{OutOfMemory}!void { stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).@"fn".params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)}); const LocalContext = struct { func_ctx: _CollectCtx(func), - out: *std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT), + out: *std.array_list.Managed(stdx.meta.FnSignature(func, null).ReturnT), oom: bool = false, }; var context = LocalContext{ .func_ctx = func_ctx, .out = out }; diff --git a/zml/module.zig b/zml/module.zig index 1f9beaf..4e40882 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -51,7 +51,7 @@ pub const CompilationContext = struct { _module: mlir.Module, - _blocks: std.BoundedArray(TaggedBlock, 64) = .{}, + _blocks: stdx.BoundedArray(TaggedBlock, 64) = .{}, _fn_cache: FnCache = .{}, _block_args: TensorToBlockArg = .{}, @@ -63,7 +63,7 @@ pub const CompilationContext = struct { const TaggedBlock = struct { mlir.Block, mlir.Block.RecursiveOpts }; const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation }); - const AttributeList = std.BoundedArray(mlir.NamedAttribute, 3); + const AttributeList = stdx.BoundedArray(mlir.NamedAttribute, 3); pub fn init(allocator_: std.mem.Allocator, full_name: []const u8, platform: Platform) !CompilationContext { const mlir_registry = mlir.Registry.init() catch unreachable; @@ -185,7 +185,9 @@ pub const CompilationContext = struct { // Write the mlir to a file. All errors are discarded, since this is for debugging only. 
const mlir_name = "module.mlir"; if (cache_dir.createFile(mlir_name, .{ .truncate = true })) |file| { - module.op().print(file.writer(), .{ .debug_info = true, .debug_info_pretty_form = false }); + var write_buf: [4096]u8 = undefined; + var writer = file.writer(&write_buf); + module.op().print(&writer.interface, .{ .debug_info = true, .debug_info_pretty_form = false }); log.info("Wrote MLIR to {s}/{s}", .{ module_dir.?, mlir_name }); } else |_| { log.warn("Failed to open {s}", .{mlir_name}); @@ -219,7 +221,7 @@ pub const CompilationContext = struct { }; log.debug("******** ZML generated MLIR ********", .{}); - log.debug("{}", .{module.op().mlirFormatter(.{})}); + log.debug("{f}", .{module.op().mlirFormatter(.{})}); if (timer) |*t| { const time_ms = @divFloor(t.lap(), std.time.ns_per_ms); @@ -339,7 +341,7 @@ pub const CompilationContext = struct { const locations = try arena.alloc(mlir.Location, tensor_count); @memset(locations, mlir.Location.unknown(mlir_ctx)); - var input_shapes = try std.ArrayList(Shape).initCapacity(res_allocator, tensor_count); + var input_shapes: std.array_list.Managed(Shape) = try .initCapacity(res_allocator, tensor_count); meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable; stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{}); @@ -416,7 +418,7 @@ pub const CompilationContext = struct { defer self._tracer.frameEnd(canonicalize_frame, "emitMlir.canonicalize"); self._mlir_canonicalizer.runOnOp(mlir_fn) catch |err| switch (err) { error.InvalidMlir => { - log.err("Failed to canonicalize invalid mlir: {}", .{mlir_fn.mlirFormatter(.{})}); + log.err("Failed to canonicalize invalid mlir: {f}", .{mlir_fn.mlirFormatter(.{})}); // user errors should have triggered a panic before we reach this. @panic("ZML generated invalid mlir. Please open a bug report"); }, @@ -464,7 +466,7 @@ pub const CompilationContext = struct { // This will break the day we writer another attribute before donation. 
// When the time come, do a more fancy lookup here to check if an argument // is donated twice. - stdx.debug.assert(attributes[a].len == 0, "Donation error ! Argument {} has been donated twice ! To {} and to {}", .{ a, index, attributes[a].buffer[0] }); + stdx.debug.assert(attributes[a].len == 0, "Donation error ! Argument {d} has been donated twice ! To {d} and to {any}", .{ a, index, attributes[a].buffer[0] }); attributes[a].appendAssumeCapacity(.named(ctx, "tf.aliasing_output", .int(ctx, .i32, @intCast(index)))); // log.debug("attribute: {}", .{attributes[a].constSlice()}); }, @@ -504,9 +506,9 @@ pub const CompilationContext = struct { var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } }; const f = try comp.emitMlir(Local._fwd, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main }); - var mlir_bytecode = std.ArrayList(u8).init(std.testing.allocator); + var mlir_bytecode = std.array_list.Managed(u8).init(std.testing.allocator); defer mlir_bytecode.deinit(); - try mlir_bytecode.writer().print("{}", .{f.mlir_fn.mlirFormatter(.{})}); + try mlir_bytecode.writer().print("{f}", .{f.mlir_fn.mlirFormatter(.{})}); // Check that the `x` input argument gives its buffer to the result tensor. // `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`. 
@@ -545,7 +547,7 @@ pub const CompilationContext = struct { pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute { const ctx = self.mlirCtx(); const num_partitions = self.numPartitions(); - var sharding_str: std.BoundedArray(u8, 128) = .{}; + var sharding_str: stdx.BoundedArray(u8, 128) = .{}; writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable; return mlir.Attribute.string(ctx, sharding_str.constSlice()); } @@ -622,7 +624,7 @@ pub const CompilationContext = struct { const full_name: [:0]const u8 = if (std.mem.eql(u8, "main", func_name)) try self.allocator().dupeZ(u8, func_name) else - try std.fmt.allocPrintZ(self.allocator(), "{s}_{x}", .{ func_name, key.input_hash }); + try std.fmt.allocPrintSentinel(self.allocator(), "{s}_{x}", .{ func_name, key.input_hash }, 0); var arg_id: u16 = 0; var tensor_args: @TypeOf(args) = args; @@ -702,7 +704,7 @@ pub const CompilationContext = struct { const res = ctx.self._block_args.getOrPutAssumeCapacity(tensor._id); if (res.found_existing) { - stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {} and {} at index {} ({}).", .{ res.value_ptr.*[0], tensor, ctx.index, tensor._id }); + stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {f} and {f} at index {} ({}).", .{ res.value_ptr.*[0], tensor, ctx.index, tensor._id }); } else { res.value_ptr.* = .{ arg_value, .{ .arg = @intCast(ctx.index) } }; } @@ -777,7 +779,7 @@ pub const CompilationContext = struct { .buffer_id, .arg_id => if (self._block_args.get(tensor._id)) |res| .{ res[0], res[1] } else { - log.err("Found unknown tensor id {}({})", .{ tensor, tensor._id }); + log.err("Found unknown tensor id {f}({})", .{ tensor, tensor._id }); @panic("Found unknown tensor id"); }, .mlir => |v| .{ v, tensor._donation }, diff --git a/zml/nn.zig b/zml/nn.zig index ffdb7a9..631bce1 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -40,8 +40,8 @@ pub 
const TokenEmbedding = struct { weight: Tensor, pub fn forward(self: TokenEmbedding, idx: Tensor) Tensor { - stdx.debug.assert(idx.dtype().isInteger(), "TokenEmbedding expects an integer input, received: {}", .{idx}); - stdx.debug.assert(self.weight.rank() == 2, "TokenEmbedding expects it's weight Tensor to be a 2D matrix, got {}", .{self.weight}); + stdx.debug.assert(idx.dtype().isInteger(), "TokenEmbedding expects an integer input, received: {f}", .{idx}); + stdx.debug.assert(self.weight.rank() == 2, "TokenEmbedding expects it's weight Tensor to be a 2D matrix, got {f}", .{self.weight}); return self.weight.gatherValues(0, idx, .{}); } }; @@ -204,13 +204,13 @@ pub const RopeOpts = struct { /// - pos_idx: optional tensor which indicates which positions are needed. /// When not set `rope` return all positions from 0 to x.dim(.s) which is the max seq len. pub fn rope(x: Tensor, pos_idx: ?Tensor, opts: RopeOpts) Tensor { - stdx.debug.assert(@mod(x.dim(.hd), 2) == 0, "rope expects a even head dim (.hd), got {}", .{x}); + stdx.debug.assert(@mod(x.dim(.hd), 2) == 0, "rope expects a even head dim (.hd), got {f}", .{x}); const idx = if (pos_idx) |idx| blk: { - stdx.debug.assert(x.shape().hasTags(.{.hd}), "rope expects x argument to have .hd axes got: rope(x={}, idx={})", .{ x, idx }); + stdx.debug.assert(x.shape().hasTags(.{.hd}), "rope expects x argument to have .hd axes got: rope(x={f}, idx={f})", .{ x, idx }); break :blk idx; } else blk: { - stdx.debug.assert(x.shape().hasTags(.{ .s, .hd }), "rope expects x argument to have both .s and .hd axes got: rope(x={})", .{x}); + stdx.debug.assert(x.shape().hasTags(.{ .s, .hd }), "rope expects x argument to have both .s and .hd axes got: rope(x={f})", .{x}); break :blk Tensor.arange(.{ .end = x.dim(.s) }, .f32).withTags(.{.s}); }; const x_real, const x_imag = splitRealImg(x, opts.layout); @@ -273,7 +273,7 @@ fn _invFreq(opts: RopeOpts, inv_freq: []f32) void { switch (opts.scaling) { .default => {}, .custom => { - 
stdx.debug.assert(opts.scaling.custom.len == N, "rope expected custom inv_freq to match half head dimension {}, got {}", .{ N, opts.scaling.custom.len }); + stdx.debug.assert(opts.scaling.custom.len == N, "rope expected custom inv_freq to match half head dimension {d}, got {d}", .{ N, opts.scaling.custom.len }); @memcpy(inv_freq, opts.scaling.custom); }, .llama3 => |s| { @@ -318,7 +318,7 @@ test invFreq { var inv_freq: @TypeOf(llama_freq) = undefined; _invFreq(llama_conf, &inv_freq); for (llama_freq, inv_freq, 0..) |expected, actual, i| { - errdefer log.err("Mismatch at position {d}.\nExpected: {d}\nActual: {d}", .{ i, llama_freq, inv_freq }); + errdefer log.err("Mismatch at position {d}.\nExpected: {any}\nActual: {any}", .{ i, llama_freq, inv_freq }); try std.testing.expectApproxEqRel(expected, actual, 1e-5); } } @@ -462,7 +462,7 @@ pub fn upsample( ) Tensor { // TODO(james): make `nearest` compatible with resizeBilinear and resizeBicubic, and wrap them here. // resize* have API which are more explicit, this assume you want to scale the N-2 last axes. 
- stdx.debug.assert(3 <= input.rank() and input.rank() <= 5, "upsample is only implemented for (3,4,5)-D tensors, received {}", .{input}); + stdx.debug.assert(3 <= input.rank() and input.rank() <= 5, "upsample is only implemented for (3,4,5)-D tensors, received {f}", .{input}); stdx.debug.assert(opts.scale_factor.len == 1 or opts.scale_factor.len == input.rank() - 2, "scale factors", .{}); return switch (opts.mode) { .nearest => { @@ -791,7 +791,7 @@ pub fn causalAttnMask( attn_window_len: ?u32, ) Tensor { const attn_shape = Shape.init(attn_shape_, dtype); - stdx.debug.assert(attn_shape.rank() == 2, "causalAttnMask({}) shape need to be exactly 2 axes", .{attn_shape}); + stdx.debug.assert(attn_shape.rank() == 2, "causalAttnMask({f}) shape need to be exactly 2 axes", .{attn_shape}); const qlen = attn_shape.dim(-2); const q_idx = Tensor.iota(attn_shape, -2); const klen = attn_shape.dim(-1); @@ -843,7 +843,7 @@ pub const SdpaOpts = struct { pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { var q, var k, var v = .{ q_, k_, v_ }; - const err_template = "sdpa(q: {}, k: {}, v: {}, attn: {?}) is invalid ! "; + const err_template = "sdpa(q: {f}, k: {f}, v: {f}, attn: {?f}) is invalid ! 
"; const err_args = .{ q, k, v, opts.attn_mask }; stdx.debug.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ "q is missing tags {{.h, .q, .hd}}", err_args); stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args); @@ -909,8 +909,8 @@ const SdpaMemEfficient = struct { chunking: SdpaChunks, fn forward(self: SdpaMemEfficient) Tensor { - stdx.debug.assert(@mod(self.q.dim(.q), self.chunking.q_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({}, {})", .{ self.q, self.chunking }); - stdx.debug.assert(@mod(self.k.dim(.k), self.chunking.k_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({}, {})", .{ self.k, self.chunking }); + stdx.debug.assert(@mod(self.q.dim(.q), self.chunking.q_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({f}, {})", .{ self.q, self.chunking }); + stdx.debug.assert(@mod(self.k.dim(.k), self.chunking.k_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({f}, {})", .{ self.k, self.chunking }); const n_q_chunks: u32 = @intCast(@divExact(self.q.dim(.q), self.chunking.q_chunk_size)); const ctx = zml.module.CompilationContext.current(); @@ -1054,7 +1054,7 @@ pub fn sdpaChunk(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) PartialSoft // Consider implementing sdpa from sdpaChunk. var q, var k, var v = .{ q_, k_, v_ }; - const err_template = "sdpa(q: {}, k: {}, v: {}, attn: {?}) is invalid ! "; + const err_template = "sdpa(q: {f}, k: {f}, v: {f}, attn: {?f}) is invalid ! 
"; const err_args = .{ q, k, v, opts.attn_mask }; stdx.debug.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ "q is missing tags {{.h, .q, .hd}}", err_args); stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args); diff --git a/zml/nn/cuda.zig b/zml/nn/cuda.zig index d0de654..4a682f3 100644 --- a/zml/nn/cuda.zig +++ b/zml/nn/cuda.zig @@ -51,7 +51,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { var fba = std.heap.FixedBufferAllocator.init(&buffer); const allocator = fba.allocator(); - const backend_config = std.fmt.allocPrintZ( + const backend_config = std.fmt.allocPrintSentinel( allocator, \\{{ \\ "operation_queue_id":"0", @@ -110,6 +110,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { q.dim(.q), k.dim(.k), }, + 0, ) catch unreachable; var bias = Tensor.constant(Shape.init(.{ .b = q.dim(.b), .h = q.dim(.h), .q = q.dim(.q), .k = k.dim(.k) }, q.dtype()), Data.init(q.dtype(), 0)); diff --git a/zml/ops.zig b/zml/ops.zig index e0af64e..c26eac3 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -155,7 +155,7 @@ pub fn reduce( // To that order, we initialize `result` to `inputs`, then we use stdx.meta.visit, // to find the correct mlir.Value, but we first broadcast before creating the final // Tensor struct. - var broadcasting_axes: std.BoundedArray(i64, Tensor.MAX_RANK) = .{}; + var broadcasting_axes: stdx.BoundedArray(i64, Tensor.MAX_RANK) = .{}; for (0..Tensor.MAX_RANK) |i| { if (std.mem.indexOfScalar(i64, axes, @intCast(i)) == null) { broadcasting_axes.append(@intCast(i)) catch unreachable; @@ -437,15 +437,15 @@ pub fn if_( @compileError("true_branch_fn and false_branch_fn return types don't match ! 
" ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return)); } - stdx.debug.assert(pred.dtype() == .bool and pred.count() == 1, "zml.ops.if_ expects the condition to have exactly one element of dtype .bool, got {}", .{pred}); + stdx.debug.assert(pred.dtype() == .bool and pred.count() == 1, "zml.ops.if_ expects the condition to have exactly one element of dtype .bool, got {f}", .{pred}); const ctx = CompilationContext.current(); const true_branch_block, const true_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &true_branch_fn, blkctx, {}); const false_branch_block, const false_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &false_branch_fn, blkctx, {}); - var true_shapes = std.ArrayList(Shape).init(ctx.allocator()); + var true_shapes = std.array_list.Managed(Shape).init(ctx.allocator()); defer true_shapes.deinit(); - var false_shapes = std.ArrayList(Shape).init(ctx.allocator()); + var false_shapes = std.array_list.Managed(Shape).init(ctx.allocator()); defer false_shapes.deinit(); var failed_to_collect = false; @@ -456,9 +456,9 @@ pub fn if_( failed_to_collect = true; }; if (!failed_to_collect) { - stdx.debug.assert(true_shapes.items.len == false_shapes.items.len, "zml.ops.if_ expects the true and false branch to produce the same number of tensors. Got: \n - true branch: {_}\n -false branch: {_}", .{ true_shapes.items, false_shapes.items }); + stdx.debug.assert(true_shapes.items.len == false_shapes.items.len, "zml.ops.if_ expects the true and false branch to produce the same number of tensors. Got: \n - true branch: {any}\n -false branch: {any}", .{ true_shapes.items, false_shapes.items }); for (true_shapes.items, false_shapes.items) |true_shape, false_shape| { - stdx.debug.assert(true_shape.eqlWithTags(false_shape), "zml.ops.if_ expects the true and false branch to produce tensors of the same shape. 
Got: \n - true branch: {_}\n -false branch: {_}", .{ true_shapes.items, false_shapes.items }); + stdx.debug.assert(true_shape.eqlWithTags(false_shape), "zml.ops.if_ expects the true and false branch to produce tensors of the same shape. Got: \n - true branch: {any}\n -false branch: {any}", .{ true_shapes.items, false_shapes.items }); } } @@ -751,7 +751,7 @@ pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base meta.visit((struct { fn cb(inner_ctx: *LocalContext, tensor: *Tensor) void { var new = Tensor.fromMlirValue(inner_ctx.op.result(inner_ctx.index)); - stdx.debug.internalAssert(new.rank() == tensor.rank(), "expected operand result to have rank {} but got {}", .{ tensor.rank(), new }); + stdx.debug.internalAssert(new.rank() == tensor.rank(), "expected operand result to have rank {} but got {f}", .{ tensor.rank(), new }); // copy tags and sharding info over // some ops can change dims eg reduceWindow, so we trust mlir here. new._shape._tags = tensor._shape._tags; @@ -932,7 +932,7 @@ pub fn scatter( const n_inputs = meta.count(Tensor, &inputs); const n_updates = meta.count(Tensor, &updates); - stdx.debug.assert(n_inputs == n_updates, "zml.ops.scatter expects the same number of tensors in inputs and updates, got {} and {}", .{ n_inputs, n_updates }); + stdx.debug.assert(n_inputs == n_updates, "zml.ops.scatter expects the same number of tensors in inputs and updates, got {d} and {d}", .{ n_inputs, n_updates }); // Note: I was a bit lazy here, and I only look at tags on the first tensor. // we probably should check all of them. 
@@ -944,7 +944,7 @@ pub fn scatter( // validate coord axes: all coord_axes should exist inside self for (indices_axes.constSlice()) |t| { - stdx.debug.assert(self._shape.hasTag(t) != null, "zml.ops.scatter expects axes of indices to be axes of inputs, got input={_} and indices={s}", .{ self, indices_axes.constSlice() }); + stdx.debug.assert(self._shape.hasTag(t) != null, "zml.ops.scatter expects axes of indices to be axes of inputs, got input={f} and indices={any}", .{ self, indices_axes.constSlice() }); } // Handle scalar indices by broadcasting them to the indices with the highest rank. @@ -958,8 +958,8 @@ pub fn scatter( break :blk higher_rank; }; for (indices_per_axis.slice()) |*idx| { - stdx.debug.assert(idx.shape().canBroadcastTo(indices_shape), "zml.ops.scatter expects all indices tensor to have the same shape, got {_}", .{indices_per_axis.slice()}); - stdx.debug.assert(idx.dtype() == indices_shape.dtype(), "zml.ops.scatter expects all indices tensor to have the same dtype, got {_}", .{indices_per_axis.slice()}); + stdx.debug.assert(idx.shape().canBroadcastTo(indices_shape), "zml.ops.scatter expects all indices tensor to have the same shape, got {any}", .{indices_per_axis.slice()}); + stdx.debug.assert(idx.dtype() == indices_shape.dtype(), "zml.ops.scatter expects all indices tensor to have the same dtype, got {any}", .{indices_per_axis.slice()}); idx.* = idx.broad(indices_shape); } @@ -972,7 +972,7 @@ pub fn scatter( var config = scatterConfig(self.shape(), update.shape(), indices_per_axis, indices_axes); const indices = scatterPrepareIndices(&config, self.shape(), update.shape(), &indices_per_axis, &indices_axes); // const n_indices_axes = update.rank() - _collectAxes(AxisKind, up_kind, .update_window).len; - // stdx.debug.assert(n_indices_axe == indices_axes.len, "scatter({_}, {any}) expects 'updates' to contain all axes from 'indices', got indices={s}, updates={_}", .{ self, index_tensors, indices_axes.constSlice(), update }); + // 
stdx.debug.assert(n_indices_axe == indices_axes.len, "scatter({f}, {any}) expects 'updates' to contain all axes from 'indices', got indices={s}, updates={f}", .{ self, index_tensors, indices_axes.constSlice(), update }); const mlir_ctx = ctx.mlirCtx(); var _scalar: T = inputs; @@ -985,10 +985,10 @@ pub fn scatter( const UpdateS = BlockSign(update_fn); const update_block, _ = ctx.makeBlock(.hermetic, UpdateS, update_fn, blkctx, .{ _scalar, _scalar }); - var input_values = std.ArrayList(mlir.Value).initCapacity(ctx.allocator(), n_inputs) catch @panic("OOM"); + var input_values = std.array_list.Managed(mlir.Value).initCapacity(ctx.allocator(), n_inputs) catch @panic("OOM"); defer input_values.deinit(); meta.collect(CompilationContext.getValue, ctx, &input_values, &inputs) catch unreachable; - var updates_values = std.ArrayList(mlir.Value).initCapacity(ctx.allocator(), n_updates) catch @panic("OOM"); + var updates_values = std.array_list.Managed(mlir.Value).initCapacity(ctx.allocator(), n_updates) catch @panic("OOM"); defer updates_values.deinit(); meta.collect(CompilationContext.getValue, ctx, &updates_values, &updates) catch unreachable; @@ -1029,8 +1029,8 @@ pub fn scatter( } const ScatterConfig = struct { - op_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{}, - up_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{}, + op_kind: stdx.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{}, + up_kind: stdx.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{}, indices_batch_axes: Shape.DimsArray = .{}, scatter_to_operand_axes: Shape.DimsArray = .{}, updates_transpose: Shape.AxesArray = .{}, @@ -1041,11 +1041,11 @@ const AxisKind = enum { batching, update_window, inserted_window, window_id }; fn scatterConfig( op: Shape, update: Shape, - indices_per_axis: std.BoundedArray(Tensor, Tensor.MAX_RANK), + indices_per_axis: stdx.BoundedArray(Tensor, Tensor.MAX_RANK), indices_axes: Shape.TagsArray, ) ScatterConfig { - var op_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{}; - 
var up_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{}; + var op_kind: stdx.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{}; + var up_kind: stdx.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{}; var indices_batch_axes: Shape.DimsArray = .{}; var scatter_to_operand_axes: Shape.DimsArray = .{}; var updates_transpose: Shape.AxesArray = .{}; @@ -1058,7 +1058,7 @@ fn scatterConfig( scatter_to_operand_axes.appendAssumeCapacity(op.axis(t)); } for (indices.tags()) |t| { - stdx.debug.assert(update.hasTag(t) != null, "scatter expects 'updates' to have all axes of 'indices', got self={_}, updates={_} and indices={_}", .{ op, update, indices }); + stdx.debug.assert(update.hasTag(t) != null, "scatter expects 'updates' to have all axes of 'indices', got self={f}, updates={f} and indices={f}", .{ op, update, indices }); updates_transpose.appendAssumeCapacity(update.axis(t)); } @@ -1094,11 +1094,11 @@ fn scatterConfig( if (indices.hasTag(t) != null) { up_kind.appendAssumeCapacity(.window_id); } else if (op.hasTag(t)) |self_ax| { - stdx.debug.assert(update.dim(up_ax) <= op.dim(self_ax), "scatter expects the slices described in 'updates' to fit inside 'op', but along axis .{s} it doesn't. Got op={_}, updates={_}.", .{ t, op, update }); + stdx.debug.assert(update.dim(up_ax) <= op.dim(self_ax), "scatter expects the slices described in 'updates' to fit inside 'op', but along axis .{s} it doesn't. Got op={f}, updates={f}.", .{ t, op, update }); up_kind.appendAssumeCapacity(.update_window); } else { // TODO: consider accepting untagged update here. 
- std.debug.panic("scatter expects 'updates' to be made of axes from op={_} and from indices={s}, got unknown tag {s} in {_}", .{ op, indices_axes.constSlice(), t, update }); + std.debug.panic("scatter expects 'updates' to be made of axes from op={f} and from indices={any}, got unknown tag {s} in {f}", .{ op, indices_axes.constSlice(), std.mem.sliceTo(t, 0), update }); } } } else { @@ -1174,7 +1174,7 @@ fn scatterPrepareIndices( cfg: *ScatterConfig, op: Shape, update: Shape, - indices_per_axis: *std.BoundedArray(Tensor, Tensor.MAX_RANK), + indices_per_axis: *stdx.BoundedArray(Tensor, Tensor.MAX_RANK), indices_axes: *Shape.TagsArray, ) Tensor { var old_scatter_to_op_axes = cfg.scatter_to_operand_axes; @@ -1194,7 +1194,7 @@ fn scatterPrepareIndices( // Reorder the axes so that in indices_per_axis is ordered like in op if possible. // TODO: transpose updates if needed - var indices: std.BoundedArray(Tensor, Tensor.MAX_RANK) = .{}; + var indices: stdx.BoundedArray(Tensor, Tensor.MAX_RANK) = .{}; var scatter_to_op_axes: Shape.DimsArray = .{}; while (old_scatter_to_op_axes.len > 0) { @@ -1209,7 +1209,7 @@ fn scatterPrepareIndices( for (scatter_to_op_axes.constSlice(), 0..) |sc_ax, i| { if (i != sc_ax) { - log.warn("Found a slow scatter pattern, which is going to generate a while loop: scatter({_}, {any}, {_}). Because the index axes aren't the major ones in the input tensor.", .{ op, scatter_to_op_axes.constSlice(), update }); + log.warn("Found a slow scatter pattern, which is going to generate a while loop: scatter({f}, {any}, {f}). 
Because the index axes aren't the major ones in the input tensor.", .{ op, scatter_to_op_axes.constSlice(), update }); break; } } diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index b12bdb2..84343a6 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -75,14 +75,14 @@ pub const Client = opaque { } fn compileSync(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable { - var bytecode = std.ArrayList(u8).init(allocator); + var bytecode: std.array_list.Managed(u8) = .init(allocator); defer bytecode.deinit(); module.op().writeBytecodeWithConfig(bytecode.writer(), .{ .desiredEmitedVersion = 1 }) catch |err| { log.err("failed to write module bytecode: {}", .{err}); return err; }; - var serialized_buffer = std.ArrayList(u8).init(allocator); + var serialized_buffer: std.array_list.Managed(u8) = .init(allocator); defer serialized_buffer.deinit(); const stablehlo_version = blk: { @@ -220,7 +220,7 @@ pub const Event = opaque { }{}; try self.inner().onReady(api, &(struct { - fn call(err: ?*pjrt.Error, user_arg: ?*anyopaque) callconv(.C) void { + fn call(err: ?*pjrt.Error, user_arg: ?*anyopaque) callconv(.c) void { const ctx_: *@TypeOf(ctx) = @ptrCast(@alignCast(user_arg.?)); ctx_.err = err; ctx_.event.set(); diff --git a/zml/platform.zig b/zml/platform.zig index 6f48137..0bd8a2a 100644 --- a/zml/platform.zig +++ b/zml/platform.zig @@ -15,7 +15,7 @@ pub const CompilationOptions = struct { xla_dump_fusion_visualization: bool = false, xla_dump_hlo_pass_re: ?[]const u8 = null, sharding_enabled: bool = false, - sharding_axes: std.BoundedArray([*:0]const u8, 8) = .{}, + sharding_axes: stdx.BoundedArray([*:0]const u8, 8) = .{}, }; pub const Platform = struct { diff --git a/zml/shape.zig b/zml/shape.zig index ce16ec9..bd5f2a2 100644 --- a/zml/shape.zig +++ b/zml/shape.zig @@ -22,9 +22,9 @@ pub const Shape = struct { pub const TagUnknown = "_".ptr; const TagLast = "last".ptr; - pub const 
DimsArray = std.BoundedArray(i64, MAX_RANK); - pub const TagsArray = std.BoundedArray(Tag, MAX_RANK); - pub const AxesArray = std.BoundedArray(u3, MAX_RANK); + pub const DimsArray = stdx.BoundedArray(i64, MAX_RANK); + pub const TagsArray = stdx.BoundedArray(Tag, MAX_RANK); + pub const AxesArray = stdx.BoundedArray(u3, MAX_RANK); pub const ShardingInfo = @Vector(MAX_RANK, bool); const UnknownTags: TagsArray = .{ .len = 0, .buffer = [_]Tag{TagUnknown} ** MAX_RANK }; @@ -300,7 +300,7 @@ pub const Shape = struct { fn axisFromInt(self: Shape, a: isize) u3 { const rk: i8 = self.rank(); if (a < -rk or a > rk) { - stdx.debug.panic("Tensor {} doesn't have dimension: {d}", .{ self, a }); + stdx.debug.panic("Tensor {f} doesn't have dimension: {d}", .{ self, a }); } return if (a < 0) @intCast(a + rk) @@ -341,9 +341,9 @@ pub const Shape = struct { } fn axisFromTag(self: Shape, d: Tag) u3 { - stdx.debug.assert(d != TagUnknown, "The unknown tag .{s} can't be used to fetch axis in {}", .{ d, self }); + stdx.debug.assert(d != TagUnknown, "The unknown tag .{s} can't be used to fetch axis in {f}", .{ d, self }); return self.axisFromTagMaybe(d) orelse { - stdx.debug.panic("Tensor {} doesn't have dimension with tag: {s}", .{ self, d }); + stdx.debug.panic("Tensor {f} doesn't have dimension with tag: {s}", .{ self, d }); }; } @@ -357,7 +357,7 @@ pub const Shape = struct { pub fn count(self: Shape) usize { var res: i64 = 1; for (self.dims()) |d| { - stdx.debug.assert(d >= 0, "Can't count elements in shape with negative dimension: {}", .{self}); + stdx.debug.assert(d >= 0, "Can't count elements in shape with negative dimension: {f}", .{self}); res *= d; } return @intCast(res); @@ -388,12 +388,11 @@ pub const Shape = struct { /// Bare format {_}: "{.a=10, .b=20}, dtype=.f32" pub fn format( self: Shape, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, writer: anytype, ) !void { - _ = options; - const bare_fmt = fmt.len == 1 and fmt[0] == '_'; + // TODO: impl alternative format 
+ // const bare_fmt = fmt.len == 1 and fmt[0] == '_'; + const bare_fmt = true; _ = try writer.write(if (bare_fmt) "{" else "Shape({"); var need_comma = false; @@ -441,12 +440,12 @@ pub const Shape = struct { var new_shape: Shape = .{ ._dtype = self.dtype() }; new_shape._dims, new_shape._tags = parseDimensions(new_shape_); new_shape.inferMissingAxis(self.count()); - stdx.debug.assert(self.count() == new_shape.count(), "Can't reshape {d} to {d}", .{ self.dims(), new_shape.dims() }); + stdx.debug.assert(self.count() == new_shape.count(), "Can't reshape {any} to {any}", .{ self.dims(), new_shape.dims() }); return new_shape; } fn inferMissingAxis(self: *Shape, n_: usize) void { - stdx.debug.assert(std.mem.count(i64, self.dims(), &.{-1}) < 2, "Cannot infer multiple dimensions when reshaping to: {}", .{self.*}); + stdx.debug.assert(std.mem.count(i64, self.dims(), &.{-1}) < 2, "Cannot infer multiple dimensions when reshaping to: {f}", .{self.*}); const inferred_ax = std.mem.indexOfScalar(i64, self.dims(), -1) orelse return; // We can't use `self.count()` yet cause we have negative dims. @@ -524,7 +523,7 @@ pub const Shape = struct { } pub fn insertTag(self: Shape, axis_: anytype, d: i64, tag_: anytype) Shape { - stdx.debug.assert(self.rank() < MAX_RANK - 1, "Can't insert new axis in {}, it's already at max rank.", .{self}); + stdx.debug.assert(self.rank() < MAX_RANK - 1, "Can't insert new axis in {f}, it's already at max rank.", .{self}); const ax = if (@TypeOf(axis_) == EnumLiteral and axis_ == .last) self.rank() @@ -652,7 +651,7 @@ pub const Shape = struct { var res = self; if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) { - stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {any}", .{ self, tagz }); + stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {f}, got {any}", .{ self, tagz }); for (tagz, 0..) 
|tag_, i| { res._tags.set(i, toTag(tag_)); } @@ -660,7 +659,7 @@ pub const Shape = struct { } if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) { - stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {}", .{ self, tagz }); + stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {f}, got {}", .{ self, tagz }); inline for (tagz, 0..) |tag_, i| { res._tags.set(i, toTag(tag_)); } @@ -699,7 +698,7 @@ pub const Shape = struct { var res = self; if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) { - stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {any}", .{ self, tagz }); + stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {f}, got {any}", .{ self, tagz }); for (tagz, self.rank() - tagz.len..) |tag_, i| { res._tags.set(i, toTag(tag_)); } @@ -707,7 +706,7 @@ pub const Shape = struct { } if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) { - stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {}", .{ self, tagz }); + stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {f}, got {}", .{ self, tagz }); inline for (tagz, self.rank() - tagz.len..) 
|tag_, i| { res._tags.set(i, toTag(tag_)); } @@ -765,7 +764,7 @@ pub const Shape = struct { var res = self; inline for (std.meta.fields(T)) |field| { const new_field = @field(renames, field.name); - stdx.debug.assert(self.hasTag(new_field) == null, "{}.rename({any}) failed because of duplicated axis {}", .{ self, renames, new_field }); + stdx.debug.assert(self.hasTag(new_field) == null, "{f}.rename({any}) failed because of duplicated axis {}", .{ self, renames, new_field }); res._tags.set(self.axis(field), toTag(new_field)); } return res; @@ -785,9 +784,9 @@ pub const Shape = struct { } } - pub fn computeStrides(self: Shape) std.BoundedArray(i64, MAX_RANK) { + pub fn computeStrides(self: Shape) stdx.BoundedArray(i64, MAX_RANK) { const rk = self.rank(); - var strides: std.BoundedArray(i64, MAX_RANK) = .{ .len = rk }; + var strides: stdx.BoundedArray(i64, MAX_RANK) = .{ .len = rk }; if (rk == 0) return strides; const V = @Vector(MAX_RANK, i64); @@ -907,7 +906,7 @@ pub const Shape = struct { var new_dim: i64 = 1; for (axes__.constSlice(), first_axis..) |ax, counter| { new_dim *= self.dim(ax); - stdx.debug.assert(ax == counter, "Can't merge shape {} along non-contiguous axes {any}", .{ self, axes_ }); + stdx.debug.assert(ax == counter, "Can't merge shape {f} along non-contiguous axes {any}", .{ self, axes_ }); } var new_shape = self; @@ -991,10 +990,10 @@ pub const Shape = struct { return res; } - pub fn parseStruct(T: type, v: anytype) struct { std.BoundedArray(T, MAX_RANK), TagsArray } { + pub fn parseStruct(T: type, v: anytype) struct { stdx.BoundedArray(T, MAX_RANK), TagsArray } { const V = @TypeOf(v); - var vals_: std.BoundedArray(T, MAX_RANK) = .{}; + var vals_: stdx.BoundedArray(T, MAX_RANK) = .{}; var tags_: TagsArray = .{}; if (comptime stdx.meta.isSliceOf(V, T)) { @@ -1029,10 +1028,10 @@ pub const Shape = struct { } /// Parses a struct literal into a list of options for each axes. 
- pub fn parseAxesOptions(self: Shape, T: type, options: anytype, default: T) std.BoundedArray(T, MAX_RANK) { + pub fn parseAxesOptions(self: Shape, T: type, options: anytype, default: T) stdx.BoundedArray(T, MAX_RANK) { const V = @TypeOf(options); - var res: std.BoundedArray(T, MAX_RANK) = .{}; + var res: stdx.BoundedArray(T, MAX_RANK) = .{}; if (comptime stdx.meta.isSliceOf(V, T)) { stdx.debug.assert(options.len == self.rank(), "expects exactly {} options in slice, for {} got {}", .{ self.rank(), self, options.len }); for (options) |d| { @@ -1084,7 +1083,7 @@ pub const Shape = struct { for (0..other.rank()) |ax| { if (other.tag(ax) != Shape.TagUnknown) { if (self.hasTag(other.tag(ax))) |batching_ax| { - stdx.debug.assert(batching_ax == batching_axes and batching_ax == ax, "outer expects batching dims to be the first dims in both tensors, got outer({}, {})", .{ self, other }); + stdx.debug.assert(batching_ax == batching_axes and batching_ax == ax, "outer expects batching dims to be the first dims in both tensors, got outer({f}, {f})", .{ self, other }); batching_axes += 1; } } diff --git a/zml/tensor.zig b/zml/tensor.zig index 74bf177..26f670d 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -53,13 +53,12 @@ pub const Tensor = struct { pub fn format( self: Tensor, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, writer: anytype, ) !void { - _ = options; - const bare_fmt = fmt.len == 1 and fmt[0] == '_'; - try writer.print(if (bare_fmt) "{_}" else "Tensor({_})", .{self._shape}); + // TODO(0.15.0) handle format + // const bare_fmt = fmt.len == 1 and fmt[0] == '_'; + const bare_fmt = false; + try writer.print(if (bare_fmt) "{f}" else "Tensor({f})", .{self._shape}); } /// Returns the shape of a Tensor. @@ -99,7 +98,7 @@ pub const Tensor = struct { if (builtin.mode == .Debug) { // Check that the MLIR value actually have the same shape. 
const other = fromMlirValue(val); - stdx.debug.internalAssert(sh.eql(other._shape), "Created a {} from Mlir value but expected {}", .{ other._shape, res._shape }); + stdx.debug.internalAssert(sh.eql(other._shape), "Created a {f} from Mlir value but expected {f}", .{ other._shape, res._shape }); } return res; @@ -145,7 +144,7 @@ pub const Tensor = struct { /// Returns the indices of each of the given axes. /// /// 'axis_' can be an integer or a tag. - pub fn axes(self: Tensor, axes_: anytype) std.BoundedArray(u3, Tensor.MAX_RANK) { + pub fn axes(self: Tensor, axes_: anytype) stdx.BoundedArray(u3, Tensor.MAX_RANK) { return self._shape.axes(axes_); } @@ -260,7 +259,7 @@ pub const Tensor = struct { /// For `reuseBuffer` to be effective, it needs to propagate all the way through the output. pub fn reuseBuffer(self: Tensor, origin: Tensor) Tensor { // Note: check donation docs, this may be too permissive. - stdx.debug.assert(self.byteSize() == origin.byteSize(), "Can't reuse buffers between tensors of different size: {} and {}", .{ self, origin }); + stdx.debug.assert(self.byteSize() == origin.byteSize(), "Can't reuse buffers between tensors of different size: {f} and {f}", .{ self, origin }); // TODO: should we store all donations inside the context ? 
var res = self; @@ -410,7 +409,7 @@ pub const Tensor = struct { stdx.debug.assert(self.dtype() == other.dtype(), "triangularSolve expects tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() }); stdx.debug.assert(self.rank() <= 2 and self.rank() == other.rank(), "triangularSolve expects tensors to have the same rank and be <= 2, got {} and {}", .{ self.rank(), other.rank() }); - const loc = self.getContext().location(@src(), "triangularSolve({_}, {})", .{ self, opts }); + const loc = self.getContext().location(@src(), "triangularSolve({f}, {})", .{ self, opts }); const op = dialect.stablehlo.triangular_solve(self.getContext().mlirCtx(), self.value(), other.value(), loc, opts); return _result(self._shape, op.result(0)); } @@ -435,7 +434,7 @@ pub const Tensor = struct { /// Returns a Tensor of complex number converted from a pair of real and imaginary Tensors. pub fn complex(re: Tensor, im: Tensor) Tensor { - stdx.debug.assert(re._shape.eql(im._shape), "complex expects tensor shapes to match, got {} and {}", .{ re._shape, im._shape }); + stdx.debug.assert(re._shape.eql(im._shape), "complex expects tensor shapes to match, got {f} and {f}", .{ re._shape, im._shape }); stdx.debug.assert(re.dtype() == .f32 or re.dtype() == .f64, "complex expects tensors type to be f32 or f64, got {}", .{re.dtype()}); const loc = re.getContext().mlirCtx().location(@src()); @@ -523,7 +522,7 @@ pub const Tensor = struct { }, }; - const loc = self.getContext().location(@src(), "fft({_},{})", .{ self, opts }); + const loc = self.getContext().location(@src(), "fft({f},{})", .{ self, opts }); const op = dialect.stablehlo.fft(self.getContext().mlirCtx(), self.value(), loc, opts); return _result(sh, op.result(0)); } @@ -551,7 +550,7 @@ pub const Tensor = struct { /// but it is not guaranteed to be deterministic between implementations. 
pub fn bitGenerator(self: Rng, sh: Shape) struct { Rng, Tensor } { const ctx = CompilationContext.current(); - const loc = ctx.location(@src(), "rand.bitGen({_})", .{sh}); + const loc = ctx.location(@src(), "rand.bitGen({f})", .{sh}); const op = dialect.stablehlo.rng_bit_generator( ctx.mlirCtx(), self.algorithm, @@ -589,7 +588,7 @@ pub const Tensor = struct { 16 => .u16, 32 => .u32, 64 => .u64, - else => stdx.debug.panic("uniform don't support non-byte aligned dtype. Got: {}", .{shape_}), + else => stdx.debug.panic("uniform don't support non-byte aligned dtype. Got: {f}", .{shape_}), }; const rng, const bits = self.bitGenerator(shape_.withDtype(uint_dtype)); @@ -674,7 +673,7 @@ pub const Tensor = struct { stdx.debug.assert(sh.dtype().isFloat(), "normal expects tensor type to be a float, got {}", .{sh.dtype()}); const ctx = CompilationContext.current(); - const loc = ctx.location(@src(), "rand.normal({_}, mean={},stddev={})", .{ sh, opts.mean, opts.stddev }); + const loc = ctx.location(@src(), "rand.normal({f}, mean={},stddev={})", .{ sh, opts.mean, opts.stddev }); const a = Tensor.constant(.{}, Data.init(sh.dtype(), opts.mean)); const b = Tensor.constant(.{}, Data.init(sh.dtype(), opts.stddev)); const res_shape = Tensor.constantTensor(HostBuffer.fromSlice(.{sh.rank()}, sh.dims())); @@ -1046,7 +1045,7 @@ pub const Tensor = struct { if (to == self.dtype()) { return self; } - const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) }); + const loc = self.getContext().location(@src(), "convert({f},to={s})", .{ self, @tagName(to) }); const mlir_ctx = self.getContext().mlirCtx(); const res_type = mlirx.tensorType(mlir_ctx, self.shape().withDtype(to)); @@ -1160,7 +1159,7 @@ pub const Tensor = struct { ) Tensor { stdx.debug.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {} and {}", .{ lhs.dtype(), rhs.dtype() }); - const Axes = std.BoundedArray(i64, MAX_RANK); + const Axes = 
stdx.BoundedArray(i64, MAX_RANK); var res_shape: Shape = .{ ._dtype = lhs.dtype() }; // Validate batching axes @@ -1168,7 +1167,7 @@ pub const Tensor = struct { var rhs_batching_axes: Axes = .{}; for (batching_axes) |b_axes| { const l, const r = b_axes; - stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs }); + stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {} and {} in {f} and {f}", .{ l, r, lhs, rhs }); var t = lhs._shape.tag(l); if (t == Shape.TagUnknown) t = rhs._shape.tag(r); res_shape = res_shape.appendDim(lhs._shape.dim(l), t); @@ -1181,7 +1180,7 @@ pub const Tensor = struct { var rhs_contracting_axes: Axes = .{}; for (contracting_axes) |c_axes| { const l, const r = c_axes; - stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects contracting dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs }); + stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects contracting dimensions to be equal, got {} and {} in {f} and {f}", .{ l, r, lhs, rhs }); lhs_contracting_axes.appendAssumeCapacity(lhs._shape.axis(l)); rhs_contracting_axes.appendAssumeCapacity(rhs._shape.axis(r)); } @@ -1209,7 +1208,7 @@ pub const Tensor = struct { } const mlir_ctx = lhs.getContext().mlirCtx(); - const loc = lhs.getContext().location(@src(), "dot({_},{_},contracting={any},batching={any}", .{ lhs, rhs, contracting_axes, batching_axes }); + const loc = lhs.getContext().location(@src(), "dot({f},{f},contracting={any},batching={any}", .{ lhs, rhs, contracting_axes, batching_axes }); const op = dialect.stablehlo.dot_general( mlir_ctx, lhs.value(), @@ -1406,7 +1405,7 @@ pub const Tensor = struct { else toI64(axes__); - stdx.debug.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {_} and {d}", .{ self, 
permutation[0..@min(permutation.len, MAX_RANK + 2)] }); + stdx.debug.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {f} and {any}", .{ self, permutation[0..@min(permutation.len, MAX_RANK + 2)] }); if (std.mem.eql(i64, permutation, no_op[0..self.rank()])) { return self; @@ -1417,7 +1416,7 @@ pub const Tensor = struct { return self.reshape(res_shape); } - const loc = self.getContext().location(@src(), "transpose({_}, {d})", .{ self, permutation }); + const loc = self.getContext().location(@src(), "transpose({f}, {any})", .{ self, permutation }); const op = dialect.stablehlo.transpose( self.getContext().mlirCtx(), self.value(), @@ -1505,7 +1504,7 @@ pub const Tensor = struct { // stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis }); const new_shape = old_shape.remove(a + 1).set(a, old_shape.dim(a) * old_shape.dim(a + 1)); - const loc = self.getContext().location(@src(), "flatten({_},{})", .{ self, axis_ }); + const loc = self.getContext().location(@src(), "flatten({f},{})", .{ self, axis_ }); const reshaped_val = dialect.stablehlo.reshape( self.getContext().mlirCtx(), self.value(), @@ -1684,7 +1683,7 @@ pub const Tensor = struct { const res_shape = shape0.insertTag(axis_, 1, tag); for (tensors[1..]) |tensor| { - stdx.debug.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {} and {}", .{ shape0, tensor._shape }); + stdx.debug.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {f} and {f}", .{ shape0, tensor._shape }); } var reshaped: [32]Tensor = undefined; @@ -1859,7 +1858,7 @@ pub const Tensor = struct { const dt: DataType = if (sh.dim(a) <= std.math.maxInt(i32)) .i32 else .i64; const res_shape = sh.withDtype(dt); const ctx = CompilationContext.current(); - const loc = ctx.location(@src(), "iota({_}, {})", .{ res_shape, a }); + const loc = ctx.location(@src(), "iota({f}, {})", .{ res_shape, a 
}); const mlir_ctx = ctx.mlirCtx(); var op = dialect.stablehlo.iota( @@ -1931,7 +1930,7 @@ pub const Tensor = struct { pub fn constant(dimz: anytype, val: Data) Tensor { const sh = Shape.init(dimz, val.dtype()); const ctx = CompilationContext.current().mlirCtx(); - const loc = CompilationContext.current().location(@src(), "dims={d}, value={}", .{ sh, val }); + const loc = CompilationContext.current().location(@src(), "dims={f}, value={}", .{ sh, val }); var constant_op = if (mlirx.denseElementAttrType(val.dtype())) |elem_type| dialect.stablehlo.constant(ctx, &.{}, elem_type, val.constSlice(), loc) @@ -1951,7 +1950,7 @@ pub const Tensor = struct { pub fn constantTensor(val: HostBuffer) Tensor { const ctx = CompilationContext.current().mlirCtx(); const loc = ctx.location(@src()); - const elem_type = mlirx.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()}); + const elem_type = mlirx.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {f}", .{val.shape()}); const constant_op = dialect.stablehlo.constant(ctx, val.shape().dims(), elem_type, val.bytes(), loc); return _result(val.shape(), constant_op.result(0)); } @@ -1975,10 +1974,10 @@ pub const Tensor = struct { /// you will lose the tags. /// To avoid use favorise `.broad(shape)` when working with tagged tensors. 
pub fn broadcast(self: Tensor, output_shape: Shape, axes_: []const i64) Tensor { - stdx.debug.assert(axes_.len == self.rank(), "broadcast expects axes_ to map all axes from self to axes of the output shape, got broadcast({}, {}, {d})", .{ self, output_shape, axes_ }); + stdx.debug.assert(axes_.len == self.rank(), "broadcast expects axes_ to map all axes from self to axes of the output shape, got broadcast({f}, {f}, {any})", .{ self, output_shape, axes_ }); for (0.., axes_) |self_ax, other_ax| { const d = self.dim(self_ax); - stdx.debug.assert(d == 1 or d == output_shape.dim(other_ax), "broadcast expects shape axes to either be 1-sized or to match the target size. got broadcast({}, {}, {d}), error on self axis {} mapping to other axis {}", .{ self, output_shape, axes_, self_ax, other_ax }); + stdx.debug.assert(d == 1 or d == output_shape.dim(other_ax), "broadcast expects shape axes to either be 1-sized or to match the target size. got broadcast({f}, {f}, {any}), error on self axis {d} mapping to other axis {d}", .{ self, output_shape, axes_, self_ax, other_ax }); } const res_shape = output_shape.withDtype(self.dtype()); @@ -1989,7 +1988,7 @@ pub const Tensor = struct { } const ctx = self.getContext(); const result_type = mlirx.tensorType(ctx.mlirCtx(), res_shape); - const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ }); + const loc = ctx.location(@src(), "broadcast({f}, {f}, axes={any})", .{ self, res_shape, axes_ }); const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc); return _result(res_shape, broadcast_op.result(0)); @@ -1997,7 +1996,7 @@ pub const Tensor = struct { /// Broadcasts a Tensor to the given shape, adding axes at the beginning. 
pub fn broadcastLeft(self: Tensor, output_shape: Shape) Tensor { - stdx.debug.assert(self.rank() <= output_shape.rank(), "broadcastLeft expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() }); + stdx.debug.assert(self.rank() <= output_shape.rank(), "broadcastLeft expects tensor rank to be less than output tensor rank, got {d} and {d}", .{ self.rank(), output_shape.rank() }); const a = output_shape.rank() - self.rank(); if (self.rank() == output_shape.rank() and std.mem.eql(i64, self.dims(), output_shape.dims())) { @@ -2009,7 +2008,7 @@ pub const Tensor = struct { /// Broadcasts a Tensor to the given shape, adding axes at the end. pub fn broadcastRight(self: Tensor, output_shape: Shape) Tensor { - stdx.debug.assert(self.rank() <= output_shape.rank(), "broadcastRight expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() }); + stdx.debug.assert(self.rank() <= output_shape.rank(), "broadcastRight expects tensor rank to be less than output tensor rank, got {d} and {d}", .{ self.rank(), output_shape.rank() }); if (self.rank() == output_shape.rank() and self._shape.eql(output_shape)) { return self; @@ -2022,7 +2021,7 @@ pub const Tensor = struct { pub fn broad(self: Tensor, other: Shape) Tensor { // TODO: broad is too restrictive because sometime you only want to specify one specific axis // Note: if you code below, make sure to update Shape.canBroadcastTo. 
- stdx.debug.assert(self._shape.canBroadcastTo(other), "Can't broadcast {} to {}", .{ self, other }); + stdx.debug.assert(self._shape.canBroadcastTo(other), "Can't broadcast {f} to {f}", .{ self, other }); // Already the right shape if (std.mem.eql(i64, self.dims(), other.dims())) return self; @@ -2036,7 +2035,7 @@ pub const Tensor = struct { } // check that each axis of self maps to an axis of other - var axes_: std.BoundedArray(i64, MAX_RANK) = .{}; + var axes_: stdx.BoundedArray(i64, MAX_RANK) = .{}; for (self._shape.tags()) |t| { axes_.appendAssumeCapacity(@intCast(other.axis(t))); } @@ -2047,14 +2046,14 @@ pub const Tensor = struct { pub fn reshape(self: Tensor, output_shape_: anytype) Tensor { const output_shape = self._shape.reshape(output_shape_); const tensor_type = mlirx.tensorType(self.getContext().mlirCtx(), output_shape); - const loc = self.getContext().location(@src(), "reshape({any})", .{output_shape}); + const loc = self.getContext().location(@src(), "reshape({f})", .{output_shape}); const reshape_value = dialect.stablehlo.reshape(self.getContext().mlirCtx(), self.value(), tensor_type, loc); return _result(output_shape, reshape_value.result(0)); } /// Converts the given 1 element Tensor into a 0-rank Tensor. pub fn asScalar(self: Tensor) Tensor { - stdx.debug.assert(self.count() == 1, "Tensor.asScalar expects an input with exactly 1-element got {}", .{self}); + stdx.debug.assert(self.count() == 1, "Tensor.asScalar expects an input with exactly 1-element got {f}", .{self}); return self.reshape(.{}); } @@ -2177,12 +2176,12 @@ pub const Tensor = struct { stdx.debug.assert(coord_axes_.len > 0, "gatherValues expects 1 or more axes to operate one, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{}); for (coord_axes_.constSlice(), 0..) |a, i| { if (i > 0) { - stdx.debug.assert(a == coord_axes_.get(i - 1) + 1, "gatherValues expects 'coord_axes' to be sequential. 
But {any} aren't sequential in {}", .{ coord_axes, self }); + stdx.debug.assert(a == coord_axes_.get(i - 1) + 1, "gatherValues expects 'coord_axes' to be sequential. But {any} aren't sequential in {f}", .{ coord_axes, self }); } } const AxisKind = enum { batching, offset, collapsed, indices }; - var self_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{}; + var self_kind: stdx.BoundedArray(AxisKind, MAX_RANK) = .{}; var indices_batch_axes: Shape.DimsArray = .{}; for (self._shape.tags(), 0..self.rank()) |t, self_ax| { const maybe_coord_ax = std.mem.indexOfScalar(u3, coord_axes_.constSlice(), @intCast(self_ax)); @@ -2191,7 +2190,7 @@ pub const Tensor = struct { // Note: tags are required for batching. self_kind.appendAssumeCapacity(.batching); indices_batch_axes.appendAssumeCapacity(id_ax); - stdx.debug.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. Axis {s} has been found both in 'self={any}', in 'coord_axes_={any}' and in 'indices={}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices }); + stdx.debug.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. 
Axis {s} has been found both in 'self={f}', in 'coord_axes_={any}' and in 'indices={f}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices }); } else if (maybe_coord_ax) |_| { // for gatherValues we collapsed all gathered axes // (contrary to gatherSlices where we collapse none) @@ -2208,13 +2207,13 @@ pub const Tensor = struct { indices.rank() else blk: { const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1); - stdx.debug.assert(indices.dim(ax) == coord_axes_.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ coord_axes, coord_axes_.len, indices }); + stdx.debug.assert(indices.dim(ax) == coord_axes_.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {f}", .{ coord_axes, coord_axes_.len, indices }); break :blk ax; }; // compute res shape var res_shape = Shape.init(.{}, self.dtype()); - var res_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{}; + var res_kind: stdx.BoundedArray(AxisKind, MAX_RANK) = .{}; for (self_kind.constSlice(), 0..) |kind, ax_usize| { const ax: u3 = @intCast(ax_usize); if (ax == coord_axes_.get(0)) { @@ -2275,7 +2274,7 @@ pub const Tensor = struct { ); const mlir_shape = fromMlirValue(gather_op.result(0)).shape(); - stdx.debug.assert(mlir_shape.eql(res_shape), "gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={}, indices={}. You should transpose one or the other.", .{ self, indices }); + stdx.debug.assert(mlir_shape.eql(res_shape), "gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={f}, indices={f}. You should transpose one or the other.", .{ self, indices }); return _result(res_shape, gather_op.result(0)); } @@ -2347,29 +2346,29 @@ pub const Tensor = struct { /// and gatherSlices can copy data by group of C'*D elements. 
pub fn gatherSlices(self: Tensor, slice_shape_: anytype, indices: Tensor, opts: GatherOpts) Tensor { const slice_shape = if (@TypeOf(slice_shape_) == Shape) slice_shape_ else Shape.init(slice_shape_, .i32); - // scoped_log.debug("gatherSlice({}, {_}, {})", .{ self, slice_shape, indices }); + // scoped_log.debug("gatherSlice({}, {f}, {})", .{ self, slice_shape, indices }); const tagged_api = slice_shape.isFullyTagged(); if (tagged_api) { for (slice_shape.tags()) |t| { - stdx.debug.assert(self._shape.hasTag(t) != null, "gatherSlices expects `slices_shape` to only use tags from `self`. But {s} wasn't found in {}", .{ t, self }); + stdx.debug.assert(self._shape.hasTag(t) != null, "gatherSlices expects `slices_shape` to only use tags from `self`. But {s} wasn't found in {f}", .{ t, self }); } } else { // For untagged api, we require all slices to be specified. // Note: we could relax this and right align the slice. - stdx.debug.assert(slice_shape.rank() == self.rank(), "gatherSlices expects `slice_shape.rank()` to match `self.rank()`. Got: gatherSlices({}, slice={_}). To avoid specifying all axes in `slice_shape`, you can use tags.", .{ self, slice_shape }); + stdx.debug.assert(slice_shape.rank() == self.rank(), "gatherSlices expects `slice_shape.rank()` to match `self.rank()`. Got: gatherSlices({f}, slice={f}). 
To avoid specifying all axes in `slice_shape`, you can use tags.", .{ self, slice_shape }); } const index_coord_axis = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1); - stdx.debug.assert(indices.dim(index_coord_axis) == slice_shape.rank(), "gatherSlices({}, slice={_}, indices) expects 'indices' to be a tensor [..., {}], got {}", .{ self, slice_shape, slice_shape.rank(), indices }); + stdx.debug.assert(indices.dim(index_coord_axis) == slice_shape.rank(), "gatherSlices({f}, slice={f}, indices) expects 'indices' to be a tensor [..., {}], got {f}", .{ self, slice_shape, slice_shape.rank(), indices }); // Compute result shape var res_shape = indices._shape.remove(index_coord_axis).withDtype(self.dtype()); var slice_dims = self._shape._dims; - var self_batch_axes: std.BoundedArray(i64, MAX_RANK) = .{}; - var indices_batch_axes: std.BoundedArray(i64, MAX_RANK) = .{}; - var start_index_map: std.BoundedArray(i64, MAX_RANK) = .{}; - var self_offset_axes: std.BoundedArray(i64, MAX_RANK) = .{}; + var self_batch_axes: stdx.BoundedArray(i64, MAX_RANK) = .{}; + var indices_batch_axes: stdx.BoundedArray(i64, MAX_RANK) = .{}; + var start_index_map: stdx.BoundedArray(i64, MAX_RANK) = .{}; + var self_offset_axes: stdx.BoundedArray(i64, MAX_RANK) = .{}; for (self._shape.tags(), 0..self.rank()) |t, self_ax| { const maybe_slice_ax: ?u3 = if (tagged_api) slice_shape.hasTag(t) else @intCast(self_ax); @@ -2379,12 +2378,12 @@ pub const Tensor = struct { self_batch_axes.appendAssumeCapacity(@intCast(self_ax)); indices_batch_axes.appendAssumeCapacity(indices._shape.axis(t)); slice_dims.set(self_ax, 1); - stdx.debug.assert(slice_shape.hasTag(t) == null, "gatherSlices expect axes to be either batches or slices axes. Axis {s} has been found both in `slices={_}` and `indices={}`", .{ t, slice_shape, indices }); + stdx.debug.assert(slice_shape.hasTag(t) == null, "gatherSlices expect axes to be either batches or slices axes. 
Axis {s} has been found both in `slices={f}` and `indices={f}`", .{ t, slice_shape, indices }); } else if (maybe_slice_ax) |slice_ax| { // Specified axes contains the start offset of the slices, // and are collected in `start_index_map`. const slice_dim = slice_shape.dim(slice_ax); - stdx.debug.assert(slice_dim <= self._shape.dim(self_ax), "gatherSlices expects `slice_shape` to be smaller than `self.shape()`. On axis {s}, got {} > {}.", .{ t, slice_shape, self._shape }); + stdx.debug.assert(slice_dim <= self._shape.dim(self_ax), "gatherSlices expects `slice_shape` to be smaller than `self.shape()`. On axis {s}, got {f} > {f}.", .{ t, slice_shape, self._shape }); slice_dims.set(self_ax, slice_dim); res_shape = res_shape.appendDim(slice_dim, t); start_index_map.appendAssumeCapacity(@intCast(self_ax)); @@ -2396,7 +2395,7 @@ pub const Tensor = struct { } } - const loc = self.getContext().location(@src(), "gatherSlices({_}, slice_shape={_}, idx={_})", .{ self, slice_shape, indices }); + const loc = self.getContext().location(@src(), "gatherSlices({f}, slice_shape={f}, idx={f})", .{ self, slice_shape, indices }); const gather_op = dialect.stablehlo.gather( self.getContext().mlirCtx(), self.value(), @@ -3172,7 +3171,7 @@ pub const Tensor = struct { /// Note: this doesn't support tagging, if you have tags, /// you should use `dynamicSlice` directly. 
pub fn dynamicSlice1d(self: Tensor, axis_: i8, slice_: DynSlice) Tensor { - stdx.debug.assert(slice_.start.rank() == 0, "dynamicSlice1d expects 'slice_.start' tensor rank to be a scalar, got {}", .{slice_.start}); + stdx.debug.assert(slice_.start.rank() == 0, "dynamicSlice1d expects 'slice_.start' tensor rank to be a scalar, got {f}", .{slice_.start}); const a = self.axis(axis_); const new_shape = self._shape.set(a, slice_.len); @@ -3226,17 +3225,17 @@ pub const Tensor = struct { const offset = slice_.start; const len = slice_.len; if (slices_tags.len == 0) { - stdx.debug.assert(self.rank() == slices.len, "dynamicSlice expects tensor rank and 'slices_' length to be equal, got {} and {}", .{ self.rank(), slices.len }); + stdx.debug.assert(self.rank() == slices.len, "dynamicSlice expects tensor rank and 'slices_' length to be equal, got {d} and {d}", .{ self.rank(), slices.len }); offset_values[i] = offset.value(); res_shape._dims.set(i, len); - stdx.debug.assert(len <= self.dim(i), "dynamicSlice expects slices 'len' to be less than or equal to their corresponding dimension in input tensor, got {} and {} for index {}", .{ len, self.dim(i), i }); + stdx.debug.assert(len <= self.dim(i), "dynamicSlice expects slices 'len' to be less than or equal to their corresponding dimension in input tensor, got {d} and {d} for index {d}", .{ len, self.dim(i), i }); } else { const t = slices_tags.get(i); - const a = res_shape.hasTag(t) orelse stdx.debug.panic("dynamicSlice expects input tensor to have tags used in 'slices_' but {s} is missing (input shape is {})", .{ t, self._shape }); + const a = res_shape.hasTag(t) orelse stdx.debug.panic("dynamicSlice expects input tensor to have tags used in 'slices_' but {s} is missing (input shape is {f})", .{ t, self._shape }); - stdx.debug.assert(len <= self.dim(a), "dynamicSlice expects slices 'len' to be less than their corresponding dimension in input tensor, got {} and {} for axis {s}", .{ len, self.dim(a), t }); + stdx.debug.assert(len 
<= self.dim(a), "dynamicSlice expects slices 'len' to be less than their corresponding dimension in input tensor, got {d} and {d} for axis {s}", .{ len, self.dim(a), t }); offset_values[a] = offset.value(); res_shape._dims.set(a, len); @@ -3304,14 +3303,14 @@ pub const Tensor = struct { if (tagged_api) { // Check that all update tags are known. for (update._shape._tags.constSlice()) |t| { - stdx.debug.assert(self._shape.hasTag(t) != null, "dynamicUpdateSlice expects 'update_' tensor tags to be a subset of input tensor tags but {s} is missing (input shape is {})", .{ t, self._shape }); + stdx.debug.assert(self._shape.hasTag(t) != null, "dynamicUpdateSlice expects 'update_' tensor tags to be a subset of input tensor tags but {s} is missing (input shape is {f})", .{ t, self._shape }); } var update_shape = self._shape; var prev_ax: i8 = -1; for (self._shape.tags(), 0..) |t, self_ax| { if (update._shape.hasTag(t)) |up_ax| { - stdx.debug.assert(up_ax == prev_ax + 1, "dynamicUpdateSlice expects 'update_' and input tensor axis to have the same order, got {} and {}. (hint: you need to explicitly transpose 'update_')", .{ update_, self }); + stdx.debug.assert(up_ax == prev_ax + 1, "dynamicUpdateSlice expects 'update_' and input tensor axis to have the same order, got {f} and {f}. (hint: you need to explicitly transpose 'update_')", .{ update_, self }); update_shape._dims.set(self_ax, update.dim(up_ax)); prev_ax = up_ax; @@ -3322,7 +3321,7 @@ pub const Tensor = struct { update = update.reshape(update_shape); } - stdx.debug.assert(self.rank() == update.rank(), "dynamicUpdateSlice expects input and computed update tensors to have the same rank, got {} and {}", .{ self, update }); + stdx.debug.assert(self.rank() == update.rank(), "dynamicUpdateSlice expects input and computed update tensors to have the same rank, got {f} and {f}", .{ self, update }); for (self.dims(), update.dims(), 0..) 
|self_d, up_d, ax| { const t = self._shape.debugTag(ax); @@ -3350,7 +3349,7 @@ pub const Tensor = struct { // This is only allowed when using tagged sliced. offset_values = .{zero} ** MAX_RANK; for (offset.constSlice(), offset_tags.constSlice()) |start, t| { - const a = self._shape.hasTag(t) orelse stdx.debug.panic("dynamicUpdateSlice expects input tensor to have tags used in 'offset_' but {s} is missing (input shape is {})", .{ t, self._shape }); + const a = self._shape.hasTag(t) orelse stdx.debug.panic("dynamicUpdateSlice expects input tensor to have tags used in 'offset_' but {s} is missing (input shape is {f})", .{ t, self._shape }); offset_values[a] = start.value(); } } @@ -3469,12 +3468,12 @@ pub const Tensor = struct { /// Returns a Tensor containing the element-wise result of the given 'cmp' comparison between the two input Tensors. pub fn cmp(self: Tensor, direction: dialect.stablehlo.ComparisonDirection.Direction, other: Tensor) Tensor { - stdx.debug.assert(self.dtype() == other.dtype(), "cmp expects input tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() }); + stdx.debug.assert(self.dtype() == other.dtype(), "cmp expects input tensors to be of the same type, got {t} and {t}", .{ self.dtype(), other.dtype() }); if (self.rank() == 0 and other.rank() != 0) return self.broadcast(other._shape, &.{}).cmp(direction, other); if (self.rank() != 0 and other.rank() == 0) return self.cmp(direction, other.broadcast(self._shape, &.{})); - stdx.debug.assert(self._shape.eql(other._shape), "cmp expects input tensor shapes to match, got {} and {}", .{ self._shape, other._shape }); + stdx.debug.assert(self._shape.eql(other._shape), "cmp expects input tensor shapes to match, got {f} and {f}", .{ self._shape, other._shape }); const loc = self.getContext().location(@src(), "cmp(.{s})", .{@tagName(direction)}); const op = dialect.stablehlo.compare( @@ -3492,7 +3491,7 @@ pub const Tensor = struct { /// For each vector in the input tensor, /// 
creates a diagonal-matrix where diagonal values are set to the vector values. pub fn toDiagonal(self: Tensor, axis_: anytype, new_tags: [2]EnumLiteral) Tensor { - stdx.debug.assert(self.rank() < MAX_RANK - 1, "toDiagonal expects input up to {} rank, got {}", .{ MAX_RANK - 1, self }); + stdx.debug.assert(self.rank() < MAX_RANK - 1, "toDiagonal expects input up to {d} rank, got {f}", .{ MAX_RANK - 1, self }); const a = self.axis(axis_); const d = self.dim(a); var res_shape = self._shape; @@ -3622,8 +3621,8 @@ pub const Tensor = struct { return bool_tensor.select(on_true, on_false.broad(bool_tensor.shape())); } - stdx.debug.assert(bool_tensor._shape.eqlDims(on_true._shape), "select expects input tensor and 'on_true' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_true._shape }); - stdx.debug.assert(bool_tensor._shape.eqlDims(on_false._shape), "select expects input tensor and 'on_false' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_false._shape }); + stdx.debug.assert(bool_tensor._shape.eqlDims(on_true._shape), "select expects input tensor and 'on_true' tensor dimensions to match, got {f} and {f}", .{ bool_tensor._shape, on_true._shape }); + stdx.debug.assert(bool_tensor._shape.eqlDims(on_false._shape), "select expects input tensor and 'on_false' tensor dimensions to match, got {f} and {f}", .{ bool_tensor._shape, on_false._shape }); const loc = bool_tensor.getContext().mlirCtx().location(@src()); const op = dialect.stablehlo.select( @@ -3762,7 +3761,7 @@ pub const Tensor = struct { /// /// - res[a, b, c, d] == (A[a], B[b], C[c], D[d]) pub fn cartesianProductStacked(vectors: []const Tensor) Tensor { - var out = std.BoundedArray(Tensor, Tensor.MAX_RANK).init(vectors.len) catch unreachable; + var out = stdx.BoundedArray(Tensor, Tensor.MAX_RANK).init(vectors.len) catch unreachable; _cartesianProduct(vectors, out.slice()); return Tensor.stack(out.constSlice(), .last, .coord); @@ -3801,7 +3800,7 @@ pub const Tensor = struct { ) 
fn (Tensor, Tensor) Tensor { return struct { pub fn binaryOpHelper(self: Tensor, other: Tensor) Tensor { - stdx.debug.assert(self.dtype() == other.dtype(), "{s} expects tensor to be of same type, got {} and {}", .{ op_name, self, other }); + stdx.debug.assert(self.dtype() == other.dtype(), "{s} expects tensor to be of same type, got {f} and {f}", .{ op_name, self, other }); if (self.rank() == 0 and other.rank() != 0) { return binaryOpHelper(self.broad(other._shape), other); @@ -3811,10 +3810,10 @@ pub const Tensor = struct { return binaryOpHelper(self, other.broad(self._shape)); } - stdx.debug.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {} and {}", .{ op_name, self._shape, other._shape }); + stdx.debug.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {f} and {f}", .{ op_name, self._shape, other._shape }); const ctx = self.getContext(); - const location = ctx.location(src, "{s}({_}, {_})", .{ op_name, self, other }); + const location = ctx.location(src, "{s}({f}, {f})", .{ op_name, self, other }); const ret = @call(.auto, op_fn, .{ ctx.mlirCtx(), self.value(), other.value(), location }); return _result(self._shape, ret.result(0)); } @@ -3837,7 +3836,7 @@ pub const Tensor = struct { fn printCallback(_: ?*anyopaque, inputs: []const HostBuffer, outputs: []const HostBuffer) void { const host_buffer = inputs[0]; - std.log.defaultLog(.info, .zml, "Device buffer: {}: {}", .{ host_buffer.shape(), host_buffer.pretty() }); + std.log.defaultLog(.info, .zml, "Device buffer: {f}: {f}", .{ host_buffer.shape(), host_buffer.pretty() }); // This is true because of the operand aliases. // Since the result is already pointing to the input we don't need to modify the buffer. 
std.debug.assert(host_buffer._data == outputs[0]._data); @@ -4060,8 +4059,8 @@ test shapesOf { } } -pub fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK), value: T) std.BoundedArray(i64, Tensor.MAX_RANK) { - var res: std.BoundedArray(i64, Tensor.MAX_RANK) = .{}; +pub fn _collectAxes(T: type, bounded_array: stdx.BoundedArray(T, Tensor.MAX_RANK), value: T) stdx.BoundedArray(i64, Tensor.MAX_RANK) { + var res: stdx.BoundedArray(i64, Tensor.MAX_RANK) = .{}; for (bounded_array.constSlice(), 0..) |v, ax| { if (v == value) { res.appendAssumeCapacity(@intCast(ax)); @@ -4070,12 +4069,12 @@ pub fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK) return res; } -fn _parseGatherCoord(self: Tensor, axes_: anytype) struct { bool, std.BoundedArray(u3, Tensor.MAX_RANK) } { +fn _parseGatherCoord(self: Tensor, axes_: anytype) struct { bool, stdx.BoundedArray(u3, Tensor.MAX_RANK) } { const AxesT = @TypeOf(axes_); const axes_is_scalar = AxesT == EnumLiteral or AxesT == comptime_int or @typeInfo(AxesT) == .int; const coord_axes = if (axes_is_scalar) - std.BoundedArray(u3, Tensor.MAX_RANK).fromSlice(&.{self.axis(axes_)}) catch unreachable + stdx.BoundedArray(u3, Tensor.MAX_RANK).fromSlice(&.{self.axis(axes_)}) catch unreachable else self.axes(axes_); @@ -4099,7 +4098,7 @@ inline fn toI64(values: anytype) []i64 { } fn transposeIsJustAReshape(x: Shape, permutation: []const i64) bool { - var perm: std.BoundedArray(struct { u8, bool }, Tensor.MAX_RANK) = .{}; + var perm: stdx.BoundedArray(struct { u8, bool }, Tensor.MAX_RANK) = .{}; // Don't rewrite on invalid inputs. 
if (permutation.len > x.rank()) return false; for (permutation) |ax| { diff --git a/zml/test_runner.zig b/zml/test_runner.zig index ed8ee11..8b8d655 100644 --- a/zml/test_runner.zig +++ b/zml/test_runner.zig @@ -31,7 +31,7 @@ pub fn asyncMain() !void { .root_name = "Test", .estimated_total_items = test_fn_list.len, }); - const have_tty = std.io.getStdErr().isTty(); + const have_tty = std.fs.File.stderr().isTty(); var args = std.process.args(); // Skip executable path diff --git a/zml/testing.zig b/zml/testing.zig index 0dcf394..b236400 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -51,10 +51,10 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void { if (should_free_left) left.deinit(allocator); if (should_free_right) right.deinit(allocator); } - errdefer log.err("\n--> Left: {}\n--> Right: {}", .{ left.pretty(), right.pretty() }); + errdefer log.err("\n--> Left: {f}\n--> Right: {f}", .{ left.pretty(), right.pretty() }); if (!std.mem.eql(i64, left.shape().dims(), right.shape().dims())) { - log.err("left.shape() {} != right.shape() {}", .{ left.shape(), right.shape() }); + log.err("left.shape() {f} != right.shape() {f}", .{ left.shape(), right.shape() }); return error.TestUnexpectedResult; } if (left.dtype() != right.dtype() and !(left.dtype() == .f16 and right.dtype() == .bf16)) { @@ -89,7 +89,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void { const right_data = right.items(R); for (left_data, right_data, 0..) 
|l, r, i| { if (!approxEq(f32, zml.floats.floatCast(f32, l), zml.floats.floatCast(f32, r), tolerance)) { - log.err("left.data != right_data.\n < {d:.3} \n > {d:.3}\n error at idx {d}: {d:.3} != {d:.3}", .{ center(left_data, i), center(right_data, i), i, left_data[i], right_data[i] }); + log.err("left.data != right_data.\n < {any:.3} \n > {any:.3}\n error at idx {any}: {any:.3} != {any:.3}", .{ center(left_data, i), center(right_data, i), i, left_data[i], right_data[i] }); return error.TestUnexpectedResult; } } @@ -108,7 +108,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void { pub fn expectEqualShapes(expected: zml.Shape, actual: zml.Shape) error{TestExpectedEqual}!void { if (expected.eqlWithTags(actual)) return; - std.debug.print("Expected {}, got {}", .{ expected, actual }); + std.debug.print("Expected {f}, got {f}", .{ expected, actual }); return error.TestExpectedEqual; } diff --git a/zml/tokenizer/BUILD.bazel b/zml/tokenizer/BUILD.bazel index 4489d77..9b2f9c7 100644 --- a/zml/tokenizer/BUILD.bazel +++ b/zml/tokenizer/BUILD.bazel @@ -9,6 +9,7 @@ zig_library( deps = [ "//async", "//ffi:zig", + "//stdx", "//zml/tokenizer/hftokenizers", "//zml/tokenizer/sentencepiece", ], diff --git a/zml/tokenizer/hftokenizers/BUILD.bazel b/zml/tokenizer/hftokenizers/BUILD.bazel index 54a01c8..a07d855 100644 --- a/zml/tokenizer/hftokenizers/BUILD.bazel +++ b/zml/tokenizer/hftokenizers/BUILD.bazel @@ -27,5 +27,6 @@ zig_library( deps = [ ":hftokenizers_cc", "//ffi:zig", + "//stdx", ], ) diff --git a/zml/tokenizer/hftokenizers/hftokenizers.zig b/zml/tokenizer/hftokenizers/hftokenizers.zig index 538a25a..7b36dbe 100644 --- a/zml/tokenizer/hftokenizers/hftokenizers.zig +++ b/zml/tokenizer/hftokenizers/hftokenizers.zig @@ -1,6 +1,8 @@ const std = @import("std"); + const c = @import("c"); const ffi = @import("ffi"); +const stdx = @import("stdx"); pub const Encoder = struct { inner: *HFTokenizer, @@ -33,8 +35,8 @@ pub const Encoder = struct { }; pub const 
Decoder = struct { - const StringBuffer = std.BoundedArray(u8, 128); - const TokensIdsBuffer = std.BoundedArray(u32, 4); + const StringBuffer = stdx.BoundedArray(u8, 128); + const TokensIdsBuffer = stdx.BoundedArray(u32, 4); inner: *HFTokenizer, current_string: ?[]const u8 = null, diff --git a/zml/tokenizer/homemade.zig b/zml/tokenizer/homemade.zig index 98f574c..9b5898d 100644 --- a/zml/tokenizer/homemade.zig +++ b/zml/tokenizer/homemade.zig @@ -2,10 +2,11 @@ //! Disclaimer this is not a very robust implementation: //! In particular the normalization is pretty minimalist, only works with ascii, and don't do unicode normalization. //! Mostly used for testing models that don't have an official HF/sentencepiece tokenizer. -const builtin = @import("builtin"); const std = @import("std"); - const testing = std.testing; +const builtin = @import("builtin"); + +const stdx = @import("stdx"); const log = std.log.scoped(.@"zml/tokenizer"); @@ -87,12 +88,11 @@ pub const Tokenizer = struct { } /// Reads a new word directly into the tokenizer arena. 
- pub fn readTokenInto(self: *Tokenizer, score: f32, len: usize, tok_reader: anytype) !void { + pub fn readTokenInto(self: *Tokenizer, score: f32, len: usize, tok_reader: *std.Io.Reader) !void { const arena = self.arena_state.allocator(); const token = try arena.alloc(u8, len); - const n = try tok_reader.readAll(token); - std.debug.assert(n == len); + try tok_reader.readSliceAll(token); return self.addOwnedToken(score, token); } @@ -190,9 +190,9 @@ pub const Tokenizer = struct { if (options.debug) { var _debug_buf: [256]u8 = undefined; var _debug_alloc = std.heap.FixedBufferAllocator.init(&_debug_buf); - var debug_progress = std.ArrayList(u8).init(_debug_alloc.allocator()); + var debug_progress = std.array_list.Managed(u8).init(_debug_alloc.allocator()); self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {}; - log.debug("tokens: {d} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items }); + log.debug("tokens: {any} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items }); } var best_score: f32 = -1e10; var best_token: u32 = 0; @@ -312,7 +312,7 @@ pub const Tokenizer = struct { /// Note that if the tokenizer allows sub-unicode bytes, it's possible /// the output is not valid utf8. pub fn decode(self: *const Tokenizer, allocator: std.mem.Allocator, input: []const u32) error{OutOfMemory}![]u8 { - var output = std.ArrayList(u8).init(allocator); + var output = std.array_list.Managed(u8).init(allocator); errdefer output.deinit(); try self.decodeWithOpts(&output, input, .{}); @@ -321,7 +321,7 @@ pub const Tokenizer = struct { pub fn decodeWithOpts( self: *const Tokenizer, - output: *std.ArrayList(u8), + output: *std.array_list.Managed(u8), input: []const u32, opts: struct { sep: []const u8 = "" }, ) error{OutOfMemory}!void { @@ -363,7 +363,8 @@ pub const Tokenizer = struct { // First lookup the byte fallback entry. // Note: we assume upper case, but we could try both upper and lower case if needed. 
- _ = std.fmt.bufPrintIntToSlice(byte_fallback_buf[3..5], c, 16, .upper, .{ .fill = '0', .width = 2 }); + var writer: std.Io.Writer = .fixed(byte_fallback_buf[3..5]); + try writer.printInt(c, 16, .upper, .{ .fill = '0', .width = 2 }); const entry = tokenizer.token_lookup.getEntry(&byte_fallback_buf) orelse { log.err("Tokenizer has \"byte_fallback\" = true, but doesn't contains the byte fallback token {s}", .{byte_fallback_buf}); return error.InvalidInput; @@ -443,8 +444,8 @@ pub const Encoder = struct { }; pub const Decoder = struct { - const StringBuffer = std.BoundedArray(u8, 128); - const TokensIdsBuffer = std.BoundedArray(u32, 4); + const StringBuffer = stdx.BoundedArray(u8, 128); + const TokensIdsBuffer = stdx.BoundedArray(u32, 4); inner: *Tokenizer, arena: std.heap.ArenaAllocator, @@ -571,7 +572,7 @@ test CharTokenIterator { { tokenizer.byte_fallback = false; var it: CharTokenIterator = .{ .input = "ζℳL" }; - var res: std.BoundedArray(u32, 8) = .{}; + var res: stdx.BoundedArray(u32, 8) = .{}; while (try it.nextCodepointToken(&tokenizer)) |token| { res.appendAssumeCapacity(token); } @@ -582,7 +583,7 @@ test CharTokenIterator { { tokenizer.byte_fallback = true; var it: CharTokenIterator = .{ .input = "ζℳL" }; - var res: std.BoundedArray(u32, 8) = .{}; + var res: stdx.BoundedArray(u32, 8) = .{}; while (try it.nextCodepointToken(&tokenizer)) |token| { res.appendAssumeCapacity(token); } @@ -596,7 +597,7 @@ pub const Normalizer = struct { /// Space token used by sentencepiece derived tokenizer. 
pub const sentencepiece_space = "▁"; // \xe2\x96\x81 - _whitespace: std.BoundedArray(u8, 8) = .{}, + _whitespace: stdx.BoundedArray(u8, 8) = .{}, flags: packed struct { remove_extra_whitespaces: bool, @@ -610,7 +611,7 @@ pub const Normalizer = struct { split_on_punct_ascii: bool, }, - pub fn init(flags: std.meta.FieldType(Normalizer, .flags), escaped_whitespace: ?[]const u8) Normalizer { + pub fn init(flags: @FieldType(Normalizer, "flags"), escaped_whitespace: ?[]const u8) Normalizer { var res: Normalizer = .{ .flags = flags }; if (escaped_whitespace) |escaped| { res._whitespace.appendSliceAssumeCapacity(escaped); @@ -622,7 +623,7 @@ pub const Normalizer = struct { return if (self._whitespace.len > 1) self._whitespace.constSlice() else null; } - fn addSlice(data: []const u8, consumed: usize, normalized: *std.ArrayList(u8), normalized_to_origin: *std.ArrayList(usize)) !void { + fn addSlice(data: []const u8, consumed: usize, normalized: *std.array_list.Managed(u8), normalized_to_origin: *std.array_list.Managed(usize)) !void { try normalized.appendSlice(data); for (data) |_| try normalized_to_origin.append(consumed); } @@ -672,9 +673,9 @@ pub const Normalizer = struct { // Pre-allocate outputs const space = self.escapedSpace() orelse " "; const overhead = if (self.flags.split_on_punct_ascii) space.len + 1 else space.len; - var normalized = try std.ArrayList(u8).initCapacity(allocator, trimmed_input.len * overhead + 2 * space.len); + var normalized = try std.array_list.Managed(u8).initCapacity(allocator, trimmed_input.len * overhead + 2 * space.len); errdefer normalized.deinit(); - var normalized_to_origin = try std.ArrayList(usize).initCapacity(allocator, normalized.capacity); + var normalized_to_origin = try std.array_list.Managed(usize).initCapacity(allocator, normalized.capacity); errdefer normalized_to_origin.deinit(); // If the spec asks for it, add a whitespace at the beginning. 
@@ -965,7 +966,7 @@ test Normalizer { /// This implementation precompupte a mapping between bytes encoded with GPT2 algorithm, /// into utf8 bytes, and do lookups at runtime. pub const Gpt2TextDecoder = struct { - const Code = std.BoundedArray(u8, 2); + const Code = stdx.BoundedArray(u8, 2); // TODO: benchmark this is more efficient than doing the conversion at runtime. code_to_byte: std.AutoArrayHashMap(Code, u8), @@ -982,7 +983,7 @@ pub const Gpt2TextDecoder = struct { var code: Code = .{ .buffer = .{ 0, 0 }, .len = 0 }; // 0-init const i: u8 = @intCast(index); if (isPrintableByte(i)) { - if (std.ascii.isASCII(i)) { + if (std.ascii.isAscii(i)) { code.appendAssumeCapacity(i); } else { const codepoint: u21 = @as(u21, @intCast(i)); @@ -1005,7 +1006,7 @@ pub const Gpt2TextDecoder = struct { /// Transform bytes representing text under the gpt2 encoding, /// and write to the `unicode` buffer utf-8 bytes. - pub fn decode(self: Gpt2TextDecoder, unicode: *std.ArrayList(u8), bytes: []const u8) ![]const u8 { + pub fn decode(self: Gpt2TextDecoder, unicode: *std.array_list.Managed(u8), bytes: []const u8) ![]const u8 { const start = unicode.items.len; var it = std.unicode.Utf8Iterator{ .i = 0, .bytes = bytes }; while (it.nextCodepointSlice()) |codepoint| { @@ -1029,7 +1030,7 @@ test Gpt2TextDecoder { var decoder = try Gpt2TextDecoder.init(testing.allocator); defer decoder.deinit(); - var out = std.ArrayList(u8).init(testing.allocator); + var out = std.array_list.Managed(u8).init(testing.allocator); defer out.deinit(); // Ascii is not changed. @@ -1076,7 +1077,7 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok // Buffer containing all concatenated tokens. // Reserve a big chunk, to avoid grow event, but release over-allocated memory. 
- var all_tokens = try std.ArrayList(u8).initCapacity(tokenizer.arena_state.allocator(), file_content.len); + var all_tokens = try std.array_list.Managed(u8).initCapacity(tokenizer.arena_state.allocator(), file_content.len); const original_alloc = all_tokens.items.ptr; // A re-alloc event here means we have invalidated all slices inside the tokenizer. // If this is too annoying we could switch to a custom type instead of slices. @@ -1164,7 +1165,7 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok } /// Returns a copy of the given string, stored inside the given ArrayList. -fn dup(buffer: *std.ArrayList(u8), str: []const u8) ![]const u8 { +fn dup(buffer: *std.array_list.Managed(u8), str: []const u8) ![]const u8 { const n = buffer.items.len; try buffer.appendSlice(str); return buffer.items[n..]; @@ -1175,7 +1176,7 @@ fn objectGet( object: std.json.ObjectMap, comptime kind: std.meta.FieldEnum(std.json.Value), key: []const u8, -) ?std.meta.FieldType(std.json.Value, kind) { +) ?@FieldType(std.json.Value, @tagName(kind)) { const val = object.get(key) orelse return null; if (val != kind) return null; return @field(val, @tagName(kind)); @@ -1184,10 +1185,11 @@ fn objectGet( pub fn fromTinyLlamaFile(allocator: std.mem.Allocator, tokenizer_path: []const u8, vocab_size: u32) !Tokenizer { const tokenizer_file = try std.fs.cwd().openFile(tokenizer_path, .{}); defer tokenizer_file.close(); - var tok_reader = std.io.bufferedReader(tokenizer_file.reader()); - const r = tok_reader.reader(); + var read_buff: [4096]u8 = undefined; + var tok_reader = tokenizer_file.reader(&read_buff); + const r: *std.Io.Reader = &tok_reader.interface; - const max_token_len = try r.readInt(u32, .little); + const max_token_len = try readValueLE(u32, r); const special_tokens: Tokenizer.SpecialTokens = .{ .unk = 0, .bos = 1, @@ -1195,11 +1197,11 @@ pub fn fromTinyLlamaFile(allocator: std.mem.Allocator, tokenizer_path: []const u }; var tokenizer = try 
Tokenizer.init(allocator, vocab_size, max_token_len, null, special_tokens, true); var i: u32 = 0; - while (readToken(&tokenizer, &r)) : (i += 1) { + while (readToken(&tokenizer, r)) : (i += 1) { // Pass } else |_| { if (i < vocab_size) { - log.info("Read {d} words out of {?d}", .{ i, vocab_size }); + log.info("Read {d} words out of {d}", .{ i, vocab_size }); } tokenizer.vocab_size = i; } @@ -1207,8 +1209,14 @@ pub fn fromTinyLlamaFile(allocator: std.mem.Allocator, tokenizer_path: []const u return tokenizer; } -fn readToken(tokenizer: *Tokenizer, tok_reader: anytype) !void { - const score: f32 = @bitCast(try tok_reader.readInt(u32, .little)); - const len: usize = @intCast(try tok_reader.readInt(u32, .little)); +fn readToken(tokenizer: *Tokenizer, tok_reader: *std.Io.Reader) !void { + const score: f32 = try readValueLE(f32, tok_reader); + const len: usize = try readValueLE(u32, tok_reader); try tokenizer.readTokenInto(score, len, tok_reader); } + +fn readValueLE(T: type, reader: *std.Io.Reader) !T { + var res: [1]T = undefined; + try reader.readSliceEndian(T, &res, .little); + return res[0]; +} diff --git a/zml/tokenizer/main.zig b/zml/tokenizer/main.zig index d3554f1..b6f91dc 100644 --- a/zml/tokenizer/main.zig +++ b/zml/tokenizer/main.zig @@ -1,10 +1,11 @@ const std = @import("std"); -const log = std.log.scoped(.@"//zml/tokenizer"); const asynk = @import("async"); const stdx = @import("stdx"); const zml_tokenizer = @import("zml/tokenizer"); +const log = std.log.scoped(.@"//zml/tokenizer"); + const Flags = struct { tokenizer: []const u8, prompt: []const u8, @@ -35,7 +36,7 @@ pub fn asyncMain() !void { const prompt_tok = try encoder.encode(args.prompt); - log.info("Input: {s}\nOutput: {d}", .{ args.prompt, prompt_tok }); + log.info("Input: {s}\nOutput: {any}", .{ args.prompt, prompt_tok }); var errors: u8 = 0; { @@ -47,14 +48,14 @@ pub fn asyncMain() !void { } if (args.expected.len > 0) { - var expected = try std.ArrayList(u32).initCapacity(allocator, 
args.prompt.len); + var expected = try std.array_list.Managed(u32).initCapacity(allocator, args.prompt.len); var it = std.mem.splitSequence(u8, args.expected, ","); while (it.next()) |int_token| { const tok = try std.fmt.parseInt(u32, int_token, 10); try expected.append(tok); } if (!std.mem.eql(u32, expected.items, prompt_tok)) { - log.err("Doesn't match expected: {d}", .{expected.items}); + log.err("Doesn't match expected: {any}", .{expected.items}); errors += 1; } } diff --git a/zml/tokenizer/sentencepiece/BUILD.bazel b/zml/tokenizer/sentencepiece/BUILD.bazel index 354b289..dba86dd 100644 --- a/zml/tokenizer/sentencepiece/BUILD.bazel +++ b/zml/tokenizer/sentencepiece/BUILD.bazel @@ -19,5 +19,6 @@ zig_library( deps = [ ":sentencepiece_swig", "//ffi:zig", + "//stdx", ], ) diff --git a/zml/tokenizer/sentencepiece/sentencepiece.zig b/zml/tokenizer/sentencepiece/sentencepiece.zig index 7f4edc6..c3623cb 100644 --- a/zml/tokenizer/sentencepiece/sentencepiece.zig +++ b/zml/tokenizer/sentencepiece/sentencepiece.zig @@ -1,6 +1,8 @@ const std = @import("std"); + const c = @import("c"); const ffi = @import("ffi"); +const stdx = @import("stdx"); const StringToTokenRatio = 3; @@ -81,7 +83,7 @@ pub const Encoder = struct { pub const Decoder = struct { const StringBufferSize = 64; - const StringBuffer = std.BoundedArray(u8, StringBufferSize); + const StringBuffer = stdx.BoundedArray(u8, StringBufferSize); const TokenIdsBufferSize = 4; inner: *SentencePieceProcessor, diff --git a/zml/torch.zig b/zml/torch.zig index 5889353..e982aae 100644 --- a/zml/torch.zig +++ b/zml/torch.zig @@ -14,7 +14,7 @@ const Tensor = zml.Tensor; /// * `matmul(.{10}, .{10}) -> .{}` /// * `matmul(.{10}, .{10}) -> .{}` pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor { - stdx.debug.assert(lhs.rank() >= 1 and rhs.rank() >= 1, "Can't matmul({}, {}) ! The two tensors need to have at least rank 1.", .{ lhs, rhs }); + stdx.debug.assert(lhs.rank() >= 1 and rhs.rank() >= 1, "Can't matmul({f}, {f}) ! 
The two tensors need to have at least rank 1.", .{ lhs, rhs }); const contracting = [_][2]i8{.{ -1, if (rhs.rank() >= 2) rhs.rank() - 2 else 0 }}; if (lhs.rank() == 1 or rhs.rank() <= 2) { @@ -22,7 +22,7 @@ pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor { return lhs.dotGeneral(rhs, &contracting, &.{}); } - stdx.debug.assert(lhs.rank() == 2, "Can't matmul({}, {}) ! One of the two tensors need to have a rank less than 2.", .{ lhs, rhs }); + stdx.debug.assert(lhs.rank() == 2, "Can't matmul({f}, {f}) ! One of the two tensors need to have a rank less than 2.", .{ lhs, rhs }); // Pytorch treats the extra dimensions of rhs has batching dimensions, // and implicitly broadcast lhs along those. @@ -91,7 +91,7 @@ pub fn unsqueeze( self: Tensor, axis_: anytype, ) Tensor { - stdx.debug.assert(self.rank() < Tensor.MAX_RANK - 1, "Can't unsqueeze {}, it's already at max rank.", .{self}); + stdx.debug.assert(self.rank() < Tensor.MAX_RANK - 1, "Can't unsqueeze {f}, it's already at max rank.", .{self}); const a = switch (@typeInfo(@TypeOf(axis_))) { .int, .comptime_int => if (axis_ < 0) @as(i8, self.rank()) + 1 + axis_ @@ -125,9 +125,9 @@ test unsqueeze { /// ref: https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#pixelshuffle pub fn pixelShuffle(tensor: Tensor, upscale_factor: u32) Tensor { const shape = tensor.shape(); - stdx.debug.assert(shape.hasTags(.{ .c, .w, .h }), "pixelShuffle({}) is invalide. Missing tags {{.c, .w, .h}}", .{tensor}); + stdx.debug.assert(shape.hasTags(.{ .c, .w, .h }), "pixelShuffle({f}) is invalide. Missing tags {{.c, .w, .h}}", .{tensor}); - stdx.debug.assert(@mod(shape.dim(.c), upscale_factor * upscale_factor) == 0, "pixelShuffle({}) is invalide. Number of channels {}, isn't divisible by upscale factor {}**2", .{ tensor, shape.dim(.c), upscale_factor }); + stdx.debug.assert(@mod(shape.dim(.c), upscale_factor * upscale_factor) == 0, "pixelShuffle({f}) is invalide. 
Number of channels {}, isn't divisible by upscale factor {}**2", .{ tensor, shape.dim(.c), upscale_factor }); const s = tensor.splitAxis(.c, .{ .c = -1, .upscale_h = upscale_factor, .upscale_w = upscale_factor }); const perm = s.shape().contiguousPerm(.{ .h, .upscale_h, .w, .upscale_w });