Add Zig 0.15 compatibility: update BUILD files, async primitives, stdx utilities, MLIR dialects, and PJRT FFI.

Tarry Singh 2025-07-28 13:54:28 +00:00
parent e3b7705e3d
commit 488a844a0f
53 changed files with 1376 additions and 520 deletions

View File

@@ -11,6 +11,7 @@ zls_completion(
    visibility = ["//visibility:public"],
    deps = [
        "//async",
+       "//examples/llama",
        "//stdx",
        "//zml",
    ],

View File

@@ -18,7 +18,7 @@ bazel_dep(name = "rules_oci", version = "2.2.6")
bazel_dep(name = "rules_proto", version = "7.1.0")
bazel_dep(name = "rules_python", version = "1.5.3")
bazel_dep(name = "rules_rust", version = "0.63.0")
-bazel_dep(name = "rules_zig", version = "20250821.0-be53625")
+bazel_dep(name = "rules_zig", version = "20250827.0-35b6d57")
bazel_dep(name = "toolchains_llvm_bootstrapped", version = "0.2.4")
bazel_dep(name = "with_cfg.bzl", version = "0.11.0")

@@ -26,7 +26,7 @@ bazel_dep(name = "buildifier_prebuilt", version = "8.2.0.2", dev_dependency = Tr
zig = use_extension("@rules_zig//zig:extensions.bzl", "zig")
zig.index(file = "//bazel:zig_index.json")
-zig.toolchain(zig_version = "0.14.1")
+zig.toolchain(zig_version = "0.15.1")
zig.mirrors(urls = [
    "https://mirror.zml.ai/zig",
    "https://ziglang.org/builds/",

View File

@@ -228,36 +228,88 @@ pub const AsyncThread = struct {
};

pub fn getStdIn() !File {
-    return File.init(std.io.getStdIn()) catch @panic("Unable to open stdin");
+    return File.initStreaming(std.fs.File.stdin()) catch @panic("Unable to open stdin");
}

pub fn getStdOut() File {
-    return File.init(std.io.getStdOut()) catch @panic("Unable to open stdout");
+    return File.initStreaming(std.fs.File.stdout()) catch @panic("Unable to open stdout");
}

pub fn getStdErr() File {
-    return File.init(std.io.getStdErr()) catch @panic("Unable to open stderr");
+    return File.initStreaming(std.fs.File.stderr()) catch @panic("Unable to open stderr");
}

pub const File = struct {
+    pub const Reader = struct {
+        interface: std.Io.Reader,
+        file: File,
+        mode: std.fs.File.Reader.Mode = .positional,
+        pos: u64 = 0,
+
+        fn stream(r: *std.Io.Reader, w: *std.Io.Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize {
+            const self: *Reader = @alignCast(@fieldParentPtr("interface", r));
+            const dest = limit.slice(try w.writableSliceGreedy(1));
+            const n = switch (self.mode) {
+                .streaming => self.file.read(dest),
+                .positional => self.file.pread(dest, self.pos),
+                else => @panic("UNSUPPORTED"),
+            } catch {
+                return std.Io.Reader.StreamError.ReadFailed;
+            };
+            if (n == 0) {
+                return std.Io.Reader.StreamError.EndOfStream;
+            }
+            self.pos += n;
+            w.advance(n);
+            return n;
+        }
+    };
+
+    pub const Writer = struct {
+        interface: std.Io.Writer,
+        file: File,
+        mode: std.fs.File.Writer.Mode = .positional,
+        pos: u64 = 0,
+
+        fn write(self: *Writer, buf: []const u8) !usize {
+            const n = switch (self.mode) {
+                .streaming => self.file.write(buf),
+                .positional => self.file.pwrite(buf, self.pos),
+                else => unreachable,
+            } catch {
+                return std.Io.Writer.Error.WriteFailed;
+            };
+            self.pos += n;
+            return n;
+        }
+
+        fn drain(w: *std.Io.Writer, data: []const []const u8, splat: usize) std.Io.Writer.Error!usize {
+            // TODO: implement splat
+            _ = splat;
+            const self: *Writer = @alignCast(@fieldParentPtr("interface", w));
+            var total: usize = 0;
+            if (w.buffered().len > 0) {
+                total += w.consume(try self.write(w.buffered()));
+            }
+            for (data) |d| {
+                const n = try self.write(d);
+                total += n;
+                if (n < d.len) {
+                    return total;
+                }
+            }
+            return total;
+        }
+    };
+
    pub const SeekError = stdx.meta.FnSignature(File.seekTo, null).ReturnErrorSet.? || stdx.meta.FnSignature(File.seekBy, null).ReturnErrorSet.?;
    pub const GetSeekPosError = SeekError || stdx.meta.FnSignature(File.stat, null).ReturnErrorSet.?;

-    pub const Reader = std.io.GenericReader(File, stdx.meta.FnSignature(File.read, null).ReturnErrorSet.?, File.read);
-    pub const Writer = std.io.GenericWriter(File, stdx.meta.FnSignature(File.write, null).ReturnErrorSet.?, File.write);
-    pub const SeekableStream = std.io.SeekableStream(
-        File,
-        SeekError,
-        GetSeekPosError,
-        seekTo,
-        seekBy,
-        getPos,
-        getEndPos,
-    );

    _handle: std.fs.File.Handle,
    inner: aio.File,
+    is_streaming: bool = false,

-    fn asFile(self: File) std.fs.File {
+    pub fn asFile(self: File) std.fs.File {
        return .{ .handle = self._handle };
    }
@@ -272,6 +324,14 @@ pub const File = struct {
        };
    }

+    pub fn initStreaming(file_: std.fs.File) !File {
+        return .{
+            ._handle = file_.handle,
+            .inner = aio.File.init(AsyncThread.current.executor, try xev.File.init(file_)),
+            .is_streaming = true,
+        };
+    }
+
    pub fn open(path: []const u8, flags: std.fs.File.OpenFlags) !File {
        return init(try callBlocking(std.fs.Dir.openFile, .{ std.fs.cwd(), path, flags }));
    }

@@ -280,6 +340,21 @@ pub const File = struct {
        return try callBlocking(std.fs.Dir.access, .{ std.fs.cwd(), path, flags });
    }

+    pub fn reader(self: File, buffer: []u8) Reader {
+        return .{
+            .file = self,
+            .interface = .{
+                .vtable = &.{
+                    .stream = Reader.stream,
+                },
+                .buffer = buffer,
+                .seek = 0,
+                .end = 0,
+            },
+            .mode = if (self.is_streaming) .streaming else .positional,
+        };
+    }
+
    pub fn read(self: File, buf: []u8) !usize {
        // NOTE(Corentin): Early return is required to avoid error with xev on Linux with io_uring backend.
        if (buf.len == 0) {

@@ -310,6 +385,19 @@ pub const File = struct {
        return self.inner.write(.{ .slice = buf });
    }

+    pub fn writer(self: File, buffer: []u8) Writer {
+        return .{
+            .file = self,
+            .interface = .{
+                .vtable = &.{
+                    .drain = Writer.drain,
+                },
+                .buffer = buffer,
+            },
+            .mode = if (self.is_streaming) .streaming else .positional,
+        };
+    }
+
    pub fn pwrite(self: File, buf: []const u8, offset: u64) !usize {
        return self.inner.pwrite(.{ .slice = buf }, offset);
    }

@@ -318,18 +406,6 @@ pub const File = struct {
        return self.inner.close();
    }

-    pub fn reader(self: File) Reader {
-        return .{ .context = self };
-    }
-
-    pub fn seekableStream(file: File) SeekableStream {
-        return .{ .context = file };
-    }
-
-    pub fn writer(self: File) Writer {
-        return .{ .context = self };
-    }
-
    pub fn stat(self: File) !std.fs.File.Stat {
        return try callBlocking(std.fs.File.stat, .{self.asFile()});
    }
@@ -375,8 +451,47 @@ pub const Socket = struct {
    pub const TCP = struct {
        const Inner = aio.TCP;

-        pub const Reader = std.io.GenericReader(TCP, stdx.meta.FnSignature(TCP.read, null).ReturnErrorSet.?, TCP.read);
-        pub const Writer = std.io.GenericWriter(TCP, stdx.meta.FnSignature(TCP.write, null).ReturnErrorSet.?, TCP.write);
+        pub const Reader = struct {
+            interface: std.Io.Reader,
+            socket: TCP,
+
+            fn stream(r: *std.Io.Reader, w: *std.Io.Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize {
+                const self: *Reader = @alignCast(@fieldParentPtr("interface", r));
+                const dest = limit.slice(try w.writableSliceGreedy(1));
+                const n = self.socket.read(dest) catch {
+                    return std.Io.Reader.StreamError.ReadFailed;
+                };
+                w.advance(n);
+                return n;
+            }
+        };
+
+        pub const Writer = struct {
+            interface: std.Io.Writer,
+            socket: TCP,
+
+            fn drain(w: *std.Io.Writer, data: []const []const u8, splat: usize) std.Io.Writer.Error!usize {
+                // TODO: implement splat
+                _ = splat;
+                const self: *Writer = @alignCast(@fieldParentPtr("interface", w));
+                var total: usize = 0;
+                if (w.buffered().len >= 0) {
+                    total += w.consume(self.socket.write(w.buffered()) catch {
+                        return std.Io.Writer.Error.WriteFailed;
+                    });
+                }
+                for (data) |d| {
+                    const n = self.socket.write(d) catch {
+                        return std.Io.Writer.Error.WriteFailed;
+                    };
+                    total += n;
+                    if (n < d.len) {
+                        return total;
+                    }
+                }
+                return total;
+            }
+        };

        inner: aio.TCP,

@@ -418,12 +533,30 @@ pub const Socket = struct {
            return self.inner.close();
        }

-        pub fn reader(self: TCP) Reader {
-            return .{ .context = self };
+        pub fn reader(self: TCP, buffer: []u8) Reader {
+            return .{
+                .socket = self,
+                .interface = .{
+                    .vtable = &.{
+                        .stream = Reader.stream,
+                    },
+                    .buffer = buffer,
+                    .seek = 0,
+                    .end = 0,
+                },
+            };
        }

-        pub fn writer(self: TCP) Writer {
-            return .{ .context = self };
+        pub fn writer(self: TCP, buffer: []u8) Writer {
+            return .{
+                .socket = self,
+                .interface = .{
+                    .vtable = &.{
+                        .drain = Writer.drain,
+                    },
+                    .buffer = buffer,
+                },
+            };
        }
    };
@@ -564,9 +697,8 @@ pub fn logFn(comptime fallbackLogFn: LogFn) LogFn {
            const level_txt = comptime message_level.asText();
            const prefix2 = if (scope == .default) ": " else "(" ++ @tagName(scope) ++ "): ";

-            const stderr = getStdErr().writer();
-            var bw = std.io.bufferedWriter(stderr);
-            const writer = bw.writer();
+            var buffer: [1024]u8 = undefined;
+            var stderr = getStdErr().writer(&buffer);

            var mutex = Self.mu orelse blk: {
                Self.mu = Mutex.init();

@@ -575,8 +707,8 @@ pub fn logFn(comptime fallbackLogFn: LogFn) LogFn {
            mutex.lock();
            defer mutex.unlock();
            nosuspend {
-                writer.print(level_txt ++ prefix2 ++ format ++ "\n", args) catch return;
-                bw.flush() catch return;
+                stderr.interface.print(level_txt ++ prefix2 ++ format ++ "\n", args) catch unreachable;
+                stderr.interface.flush() catch unreachable;
            }
        }
    }.call;
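
Zig 0.15 note: the async File now adapts to std.Io.Reader/std.Io.Writer, where the caller owns the buffer behind the interface, and formatted output goes through the .interface field and must be flushed explicitly. A minimal usage sketch, assuming the module is imported as @import("async") per the "//async" BUILD dep; the path and buffer sizes are illustrative:

const std = @import("std");
const asynk = @import("async"); // assumed module name, matching the "//async" BUILD dep

fn firstLineToStdout() !void {
    var file = try asynk.File.open("/tmp/example.txt", .{}); // illustrative path
    defer file.close() catch {};

    var rbuf: [4096]u8 = undefined;
    var reader = file.reader(&rbuf); // caller-supplied buffer, Zig 0.15 style
    const line = try reader.interface.takeDelimiterExclusive('\n');

    var wbuf: [256]u8 = undefined;
    var out = asynk.getStdOut().writer(&wbuf);
    try out.interface.print("first line: {s}\n", .{line});
    try out.interface.flush(); // nothing reaches the fd until flush
}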

View File

@@ -156,7 +156,7 @@ const Coro = struct {
        return self;
    }

-    fn runcoro(from: *base.Coro, this: *base.Coro) callconv(.C) noreturn {
+    fn runcoro(from: *base.Coro, this: *base.Coro) callconv(.c) noreturn {
        const from_coro: *Coro = @fieldParentPtr("impl", from);
        const this_coro: *Coro = @fieldParentPtr("impl", this);
        log(.debug, "coro start {any}", .{this_coro.id});
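
The callconv(.C) to callconv(.c) renames here (and throughout the MLIR, FFI, and PJRT files below) follow the std.builtin.CallingConvention rework: the tag is spelled lowercase .c since Zig 0.14, and the uppercase spelling is no longer accepted by 0.15. A toy sketch of the new spelling (not from this repo):

const std = @import("std");

// Zig 0.15 spells the C calling convention as .c (lowercase).
fn add(a: c_int, b: c_int) callconv(.c) c_int {
    return a + b;
}

test "callconv(.c) functions remain callable from Zig" {
    try std.testing.expectEqual(@as(c_int, 7), add(3, 4));
}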

View File

@@ -49,7 +49,7 @@ pub const Coro = packed struct {
    const Func = *const fn (
        from: *Coro,
        self: *Coro,
-    ) callconv(.C) noreturn;
+    ) callconv(.c) noreturn;

    pub fn init(func: Func, stack: []align(stack_alignment) u8) !Self {
        stdx.debug.assertComptime(@sizeOf(usize) == 8, "usize expected to take 8 bytes", .{});

View File

@@ -1,83 +1,264 @@
{
  "master": {
-    "version": "0.15.0-dev.650+4f3b59f70",
-    "date": "2025-05-29",
+    "version": "0.16.0-dev.27+83f773fc6",
+    "date": "2025-08-24",
    "docs": "https://ziglang.org/documentation/master/",
    "stdDocs": "https://ziglang.org/documentation/master/std/",
    "src": {
-      "tarball": "https://ziglang.org/builds/zig-0.15.0-dev.650+4f3b59f70.tar.xz",
-      "shasum": "c14764ee9fd16f4437f2e2e7092cc7ec7ff76469f719be04155c29fa3bcc52cd",
-      "size": "21279148"
+      "tarball": "https://ziglang.org/builds/zig-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "afddafbede9becaa6d94a544d3115893cc6aa6492a58545b4f638c58990586dc",
+      "size": "21370308"
    },
    "bootstrap": {
-      "tarball": "https://ziglang.org/builds/zig-bootstrap-0.15.0-dev.650+4f3b59f70.tar.xz",
-      "shasum": "142343992733282138b88b2205f29173ce3febcf5d73de4a19111fb36630f5a6",
-      "size": "52649088"
+      "tarball": "https://ziglang.org/builds/zig-bootstrap-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "109cdda833baf951ce8116d3fdfd50b3770b02cff5c057f777601ca694f44b7c",
+      "size": "52732704"
    },
    "x86_64-macos": {
-      "tarball": "https://ziglang.org/builds/zig-x86_64-macos-0.15.0-dev.650+4f3b59f70.tar.xz",
-      "shasum": "70117a96313ddfe57bd0b3fbbbdb2beffc27fb81f5993ba9b2825628bb9f5aaa",
-      "size": "55795264"
+      "tarball": "https://ziglang.org/builds/zig-x86_64-macos-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "1ac88b947a6001eb30d41c15c404847209dfe945a27dbce6d1b7d401e4aef325",
+      "size": "55817716"
    },
    "aarch64-macos": {
-      "tarball": "https://ziglang.org/builds/zig-aarch64-macos-0.15.0-dev.650+4f3b59f70.tar.xz",
-      "shasum": "991ef1b1871852f87e1d53f65c3dfb3b3f7187a0dc1f8fe6e7be279227b868d3",
-      "size": "50625312"
+      "tarball": "https://ziglang.org/builds/zig-aarch64-macos-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "75243ad6d2e9fcd634862dda9003883484024fd2e067fc7e4a0f6c524cb6c86e",
+      "size": "50659200"
    },
    "x86_64-linux": {
-      "tarball": "https://ziglang.org/builds/zig-x86_64-linux-0.15.0-dev.650+4f3b59f70.tar.xz",
-      "shasum": "2c2f65db1ad72d415b5a5bfb0ccd437bfc16e91b18a58035d28ea3177d07b1e2",
-      "size": "53712376"
+      "tarball": "https://ziglang.org/builds/zig-x86_64-linux-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "c260aefaf5bf10bbd48101217874ccc9c0a37513729dabbbfe09bcf18fff1b5b",
+      "size": "53759360"
    },
    "aarch64-linux": {
-      "tarball": "https://ziglang.org/builds/zig-aarch64-linux-0.15.0-dev.650+4f3b59f70.tar.xz",
-      "shasum": "98bfd9c33b737aa66d5cc19886b772cdc99801e66d8c891d400b0370fc626455",
-      "size": "49500380"
+      "tarball": "https://ziglang.org/builds/zig-aarch64-linux-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "658241c15ad827a89878b2e35e13261cd1c2aaefe9a8bb94ccb907725cf80564",
+      "size": "49485912"
    },
-    "armv7a-linux": {
-      "tarball": "https://ziglang.org/builds/zig-armv7a-linux-0.15.0-dev.650+4f3b59f70.tar.xz",
-      "shasum": "c204b60e27930702a64d55768fbbf2c819109dd0400168e0daf2994f4e57ed8b",
-      "size": "50413108"
+    "arm-linux": {
+      "tarball": "https://ziglang.org/builds/zig-arm-linux-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "6039c11f41bd032ec3cbda2d92e38b409aa9d9e2935b21266033b820ecb1a668",
+      "size": "50475500"
    },
    "riscv64-linux": {
-      "tarball": "https://ziglang.org/builds/zig-riscv64-linux-0.15.0-dev.650+4f3b59f70.tar.xz",
-      "shasum": "2212613f09d114296f542dc3ca4583d3e1cfa72642e0d192a450ef1704722295",
-      "size": "53645244"
+      "tarball": "https://ziglang.org/builds/zig-riscv64-linux-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "f56712849af75bb1bfda502c8cdfc52dcc41c18d3d7eddd329d94be85538d258",
+      "size": "53610616"
    },
    "powerpc64le-linux": {
-      "tarball": "https://ziglang.org/builds/zig-powerpc64le-linux-0.15.0-dev.650+4f3b59f70.tar.xz",
-      "shasum": "6a68c45d8bb5a3c1c3a259ad5ef965011c02feb0d17e8c96abd76aa48af03fe9",
-      "size": "53559816"
+      "tarball": "https://ziglang.org/builds/zig-powerpc64le-linux-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "0562fdd578b5ae65a17e614454e67466314790a1d373279a4b054398f6a7b364",
+      "size": "53585484"
    },
    "x86-linux": {
-      "tarball": "https://ziglang.org/builds/zig-x86-linux-0.15.0-dev.650+4f3b59f70.tar.xz",
-      "shasum": "802e860891aa12979c3279e6d0571d08c0ed0398f362ac3f2a2fd945ef59bd1c",
-      "size": "56328204"
+      "tarball": "https://ziglang.org/builds/zig-x86-linux-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "6eafb1b1d81118066dc03b17bed6612d7a9da057794f1d0f2451c249cf8a231d",
+      "size": "56334072"
    },
    "loongarch64-linux": {
-      "tarball": "https://ziglang.org/builds/zig-loongarch64-linux-0.15.0-dev.650+4f3b59f70.tar.xz",
-      "shasum": "7fb8e715cdf44a2158afced5eecb3d2ae0f44d0a161f43ea8dfb04e95b22c89f",
-      "size": "50835068"
+      "tarball": "https://ziglang.org/builds/zig-loongarch64-linux-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "4fddf62a6698dbece51f504b4de8360a55d24efb0c7a56117eef64c8b717d13f",
+      "size": "50823012"
    },
    "s390x-linux": {
-      "tarball": "https://ziglang.org/builds/zig-s390x-linux-0.15.0-dev.650+4f3b59f70.tar.xz",
-      "shasum": "cbbcbf324db5158232d1fd5a1efcc2d694e058ad66772495e291828fa5227450",
-      "size": "53386612"
+      "tarball": "https://ziglang.org/builds/zig-s390x-linux-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "12ed086066c373e4098dc9f132e42bbc478a3be6bd5fc6967f39396c43543f89",
+      "size": "53525496"
    },
    "x86_64-windows": {
-      "tarball": "https://ziglang.org/builds/zig-x86_64-windows-0.15.0-dev.650+4f3b59f70.zip",
-      "shasum": "0d4e148a60e859f0fa5ce6d293dcec7ba79a7e1da25df151b34e3ce71d0fa84f",
-      "size": "94297867"
+      "tarball": "https://ziglang.org/builds/zig-x86_64-windows-0.16.0-dev.27+83f773fc6.zip",
+      "shasum": "ae987b8d93eec8a923bba26a74d20b5d15eac7ee9694c44db82c232247810a51",
+      "size": "93312425"
    },
    "aarch64-windows": {
-      "tarball": "https://ziglang.org/builds/zig-aarch64-windows-0.15.0-dev.650+4f3b59f70.zip",
-      "shasum": "6b1a094cb4c666d0078531ce3d4f970dcf01f63e5b885d9af0f498726bdf5e43",
-      "size": "90200802"
+      "tarball": "https://ziglang.org/builds/zig-aarch64-windows-0.16.0-dev.27+83f773fc6.zip",
+      "shasum": "2e3c95d044dd36d302a301fb029ea080b2bda4ef8e42bbb1b2a0246973952c53",
+      "size": "89157759"
    },
    "x86-windows": {
-      "tarball": "https://ziglang.org/builds/zig-x86-windows-0.15.0-dev.650+4f3b59f70.zip",
-      "shasum": "f18b6f9472a27abea8edbfb43687a75b61298e6d4da42247218dc3fd23560aa7",
-      "size": "96232989"
+      "tarball": "https://ziglang.org/builds/zig-x86-windows-0.16.0-dev.27+83f773fc6.zip",
+      "shasum": "b696f07f7104da430c7b37750bc6e36571c3ffcd5bc49240d8a19aa4b431049f",
+      "size": "95215201"
+    },
+    "aarch64-freebsd": {
+      "tarball": "https://ziglang.org/builds/zig-aarch64-freebsd-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "0295a3f6bebb6d25fbc916cd4e3cff1285acab75b9591695069ad9a878301203",
+      "size": "49384624"
+    },
+    "arm-freebsd": {
+      "tarball": "https://ziglang.org/builds/zig-arm-freebsd-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "a293670c7ee89ce6cf11f03dee297de2a31ffdc9d26f2c0e6e78fd66b11b2327",
+      "size": "50925296"
+    },
+    "powerpc64-freebsd": {
+      "tarball": "https://ziglang.org/builds/zig-powerpc64-freebsd-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "503922dd11374ddc44013ae3d7b07a9a17a1404da296bc30bf8fdc6d7273fd82",
+      "size": "52148468"
+    },
+    "powerpc64le-freebsd": {
+      "tarball": "https://ziglang.org/builds/zig-powerpc64le-freebsd-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "1cf44603dec137a35551916ab5e542c36e6c1ec54df4efd7783f8ae66c80acc4",
+      "size": "53516496"
+    },
+    "riscv64-freebsd": {
+      "tarball": "https://ziglang.org/builds/zig-riscv64-freebsd-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "c4b1c691328abacf7f7b07f1900dd44c2f2852a0b2611378a2531975abb6f260",
+      "size": "53716116"
+    },
+    "x86_64-freebsd": {
+      "tarball": "https://ziglang.org/builds/zig-x86_64-freebsd-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "989caf4343e1dfb4c9a68f3d3146cac96729e6f6e58276001654307f56c43981",
+      "size": "53808720"
+    },
+    "aarch64-netbsd": {
+      "tarball": "https://ziglang.org/builds/zig-aarch64-netbsd-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "bbbe2aa2163d3c25d7414795537088245f9ec3cb53f10518b5824e1a31c35ff9",
+      "size": "49387264"
+    },
+    "arm-netbsd": {
+      "tarball": "https://ziglang.org/builds/zig-arm-netbsd-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "264346f44664055d61c91087d0c55fdd04e3bdf1f2f0f969cd5efcf47d54b247",
+      "size": "52034816"
+    },
+    "x86-netbsd": {
+      "tarball": "https://ziglang.org/builds/zig-x86-netbsd-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "f5e3cdab5d7e24552eee2200b1e4ac45cd84f13d5d64ae0f558ad0210d6883d5",
+      "size": "56915916"
+    },
+    "x86_64-netbsd": {
+      "tarball": "https://ziglang.org/builds/zig-x86_64-netbsd-0.16.0-dev.27+83f773fc6.tar.xz",
+      "shasum": "b6418c9d5036d2328271078687dbf2d772279e65787546974943855423fc6545",
+      "size": "53809524"
+    }
+  },
+  "0.15.1": {
+    "date": "2025-08-19",
+    "docs": "https://ziglang.org/documentation/master/",
+    "stdDocs": "https://ziglang.org/documentation/master/std/",
+    "notes": "https://ziglang.org/download/0.15.1/release-notes.html",
+    "src": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-0.15.1.tar.xz",
+      "shasum": "816c0303ab313f59766ce2097658c9fff7fafd1504f61f80f9507cd11652865f",
+      "size": "21359884"
+    },
+    "bootstrap": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-bootstrap-0.15.1.tar.xz",
+      "shasum": "4c0cfbcf12da144955761ca43f89e3c74956bce978694fc1d0a63555f5c0a199",
+      "size": "52711548"
+    },
+    "x86_64-macos": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-x86_64-macos-0.15.1.tar.xz",
+      "shasum": "9919392e0287cccc106dfbcbb46c7c1c3fa05d919567bb58d7eb16bca4116184",
+      "size": "55791880"
+    },
+    "aarch64-macos": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-aarch64-macos-0.15.1.tar.xz",
+      "shasum": "c4bd624d901c1268f2deb9d8eb2d86a2f8b97bafa3f118025344242da2c54d7b",
+      "size": "50644996"
+    },
+    "x86_64-linux": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-x86_64-linux-0.15.1.tar.xz",
+      "shasum": "c61c5da6edeea14ca51ecd5e4520c6f4189ef5250383db33d01848293bfafe05",
+      "size": "53734456"
+    },
+    "aarch64-linux": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-aarch64-linux-0.15.1.tar.xz",
+      "shasum": "bb4a8d2ad735e7fba764c497ddf4243cb129fece4148da3222a7046d3f1f19fe",
+      "size": "49493872"
+    },
+    "arm-linux": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-arm-linux-0.15.1.tar.xz",
+      "shasum": "3f4bf3b06b67d14e3f38be30798488c1abe3cf5b33de570cd0e87bbf09b978ad",
+      "size": "50477464"
+    },
+    "riscv64-linux": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-riscv64-linux-0.15.1.tar.xz",
+      "shasum": "7ca7a3e621436fb31d66a253132fc39574a13d2a1b4d8458af4f2e7c6e4374fe",
+      "size": "53597792"
+    },
+    "powerpc64le-linux": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-powerpc64le-linux-0.15.1.tar.xz",
+      "shasum": "339e2106496be70b614e32d444298216e676c36d08f5cd7bee3dd1dbd4567fd7",
+      "size": "53566944"
+    },
+    "x86-linux": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-x86-linux-0.15.1.tar.xz",
+      "shasum": "dff166f25fdd06e8341d831a71211b5ba7411463a6b264bdefa8868438690b6a",
+      "size": "56311228"
+    },
+    "loongarch64-linux": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-loongarch64-linux-0.15.1.tar.xz",
+      "shasum": "0af18a012f3c4cffbef29ab5f42021484d92517f921a1380faacc89c50c1f89d",
+      "size": "50794376"
+    },
+    "s390x-linux": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-s390x-linux-0.15.1.tar.xz",
+      "shasum": "bcd13e5c88cf2d0da7100a48572195560069a94402c9bec3b316398204aa27e2",
+      "size": "53508068"
+    },
+    "x86_64-windows": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-x86_64-windows-0.15.1.zip",
+      "shasum": "91e69e887ca8c943ce9a515df3af013d95a66a190a3df3f89221277ebad29e34",
+      "size": "92612958"
+    },
+    "aarch64-windows": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-aarch64-windows-0.15.1.zip",
+      "shasum": "1f1bf16228b0ffcc882b713dc5e11a6db4219cb30997e13c72e8e723c2104ec6",
+      "size": "88458549"
+    },
+    "x86-windows": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-x86-windows-0.15.1.zip",
+      "shasum": "fb1c07cffbb43615d3158ab8b8f5db5da1d48875eca99e1d7a8a0064ff63fc5b",
+      "size": "94516052"
+    },
+    "aarch64-freebsd": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-aarch64-freebsd-0.15.1.tar.xz",
+      "shasum": "4d9d25c775828d49ea037b2284310c295d951793da8ebe94827a54fed4cca3ce",
+      "size": "49358464"
+    },
+    "arm-freebsd": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-arm-freebsd-0.15.1.tar.xz",
+      "shasum": "9707f3a5f7e1a3d99c40db9a74de1acc61016a197ad289c2ad964f93cb213a18",
+      "size": "50904124"
+    },
+    "powerpc64-freebsd": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-powerpc64-freebsd-0.15.1.tar.xz",
+      "shasum": "79448884372db04e62f77a46b92245fce805063e69534729537f75cb4681e7e3",
+      "size": "52099020"
+    },
+    "powerpc64le-freebsd": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-powerpc64le-freebsd-0.15.1.tar.xz",
+      "shasum": "f18ee12ba9c98a20b8d2ad0410c679e7aa5591bc7917f169fd6b377833d2c7ad",
+      "size": "53480976"
+    },
+    "riscv64-freebsd": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-riscv64-freebsd-0.15.1.tar.xz",
+      "shasum": "ee9f864a6fd8b57c1f4fdbb11daa06578746a6f8253afe3f5ddb5a76f2eddd2d",
+      "size": "53677800"
+    },
+    "x86_64-freebsd": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-x86_64-freebsd-0.15.1.tar.xz",
+      "shasum": "9714f8ac3d3dc908b1599837c6167f857c1efaa930f0cfa840699458de7c3cd0",
+      "size": "53782112"
+    },
+    "aarch64-netbsd": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-aarch64-netbsd-0.15.1.tar.xz",
+      "shasum": "b2a528399777583b85b89c54ccd45488af7709d6dd29a27323ec2a229db40910",
+      "size": "49368000"
+    },
+    "arm-netbsd": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-arm-netbsd-0.15.1.tar.xz",
+      "shasum": "93dc70109cbf5d2e022d20dfb56211978c4ea3c0b1e67aaabff947d8d1583aab",
+      "size": "52028184"
+    },
+    "x86-netbsd": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-x86-netbsd-0.15.1.tar.xz",
+      "shasum": "a91b26051822ff17f3143f859b87dce5b4a13e90928bd6daa6f07a895d3410f0",
+      "size": "56881864"
+    },
+    "x86_64-netbsd": {
+      "tarball": "https://ziglang.org/download/0.15.1/zig-x86_64-netbsd-0.15.1.tar.xz",
+      "shasum": "6d7ba6eca5b4434351ebdb971b7303c9934514f9bb8481852251dbd5b52b03d6",
+      "size": "53794836"
    }
  },
  "0.14.1": {

View File

@@ -10,13 +10,13 @@ pub const ZigAllocator = struct {
        };
    }

-    pub fn alloc(ctx: ?*const anyopaque, elem: usize, nelems: usize, alignment: usize) callconv(.C) ?*anyopaque {
+    pub fn alloc(ctx: ?*const anyopaque, elem: usize, nelems: usize, alignment: usize) callconv(.c) ?*anyopaque {
        const self: *const std.mem.Allocator = @ptrCast(@alignCast(ctx));
        const ret = self.rawAlloc(elem * nelems, std.math.log2_int(usize, alignment), @returnAddress()) orelse return null;
        return @ptrCast(ret);
    }

-    pub fn free(ctx: ?*const anyopaque, ptr: ?*anyopaque, elem: usize, nelems: usize, alignment: usize) callconv(.C) void {
+    pub fn free(ctx: ?*const anyopaque, ptr: ?*anyopaque, elem: usize, nelems: usize, alignment: usize) callconv(.c) void {
        const self: *const std.mem.Allocator = @ptrCast(@alignCast(ctx));
        const memory: [*c]u8 = @ptrCast(ptr);
        const size = elem * nelems;

View File

@@ -15,6 +15,7 @@ zig_library(
    deps = [
        "//mlir",
        "//mlir/dialects/stablehlo",
+       "//stdx",
    ],
)

View File

@ -1,6 +1,7 @@
const std = @import("std"); const std = @import("std");
const mlir = @import("mlir"); const mlir = @import("mlir");
const stdx = @import("stdx");
pub fn func( pub fn func(
ctx: mlir.Context, ctx: mlir.Context,
@ -14,7 +15,7 @@ pub fn func(
location: mlir.Location, location: mlir.Location,
}, },
) mlir.Operation { ) mlir.Operation {
var attrs_tuple_buffer = std.BoundedArray(mlir.AttrTuple, 4){}; var attrs_tuple_buffer = stdx.BoundedArray(mlir.AttrTuple, 4){};
attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", .string(ctx, args.sym_name) }); attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", .string(ctx, args.sym_name) });
attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", .type_(.function(ctx, args.args, args.results)) }); attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", .type_(.function(ctx, args.args, args.results)) });
if (args.arg_attrs.len > 0) { if (args.arg_attrs.len > 0) {
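
std.BoundedArray is gone from the standard library in Zig 0.15, so the commit vendors a copy as stdx.BoundedArray (see stdx/bounded_array.zig below) and rewrites every use site to it. A minimal sketch of the drop-in replacement; the element type is illustrative:

const std = @import("std");
const stdx = @import("stdx");

test "stdx.BoundedArray mirrors the removed std.BoundedArray" {
    var attrs: stdx.BoundedArray(u32, 4) = .{}; // fixed capacity, no allocator
    attrs.appendAssumeCapacity(1);
    attrs.appendAssumeCapacity(2);
    try std.testing.expectEqualSlices(u32, &.{ 1, 2 }, attrs.constSlice());
}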

View File

@@ -761,7 +761,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
        );
    }

-    var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{};
+    var attrs: stdx.BoundedArray(mlir.AttrTuple, 32) = .{};
    attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{
        .{ "api_version", .int(ctx, .i32, @intFromEnum(opts.api_version)) },
        .{ "call_target_name", .string(ctx, opts.call_target_name) },

@@ -770,7 +770,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
    });

    {
-        var output_operand_aliases: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
+        var output_operand_aliases: stdx.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
        for (opts.output_operand_aliases) |alias| {
            output_operand_aliases.appendAssumeCapacity(
                OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).asAttr(),

@@ -789,14 +789,14 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
    };

    if (opts.operand_layouts) |layouts| {
-        var operand_layouts: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
+        var operand_layouts: stdx.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
        for (layouts) |ol| {
            operand_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol));
        }
        attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) });
    } else {
        const operand_layouts = blk: {
-            var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
+            var ret: stdx.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
            for (inputs) |input| {
                const ranked_type = input.getType().as(mlir.RankedTensorType).?;
                const ol = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_type.getRank() ..];

@@ -808,14 +808,14 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
    }

    if (opts.result_layouts) |layouts| {
-        var result_layouts: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
+        var result_layouts: stdx.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
        for (layouts) |rl| {
            result_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl));
        }
        attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) });
    } else {
        const result_layouts = blk: {
-            var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
+            var ret: stdx.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
            for (res_types) |t| {
                const ranked_t = t.as(mlir.RankedTensorType).?;
                const rl = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_t.getRank() ..];

@@ -1271,7 +1271,7 @@ pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehlo
    const WriterContext = @TypeOf(context);
    c.stablehloVersionFromCompatibilityRequirement(req, (struct {
-        pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
+        pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void {
            const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
            _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
        }

@@ -1292,7 +1292,7 @@ pub fn stablehloGetSmallerVersion(version1: []const u8, version2: []const u8) []
    const WriterContext = @TypeOf(context);
    _ = c.stablehloGetSmallerVersion(mlir.stringRef(version1), mlir.stringRef(version2), (struct {
-        pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
+        pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void {
            const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
            _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
        }

@@ -1313,7 +1313,7 @@ pub fn getCurrentVersion() []const u8 {
    const ContextWriter = @TypeOf(writer_);
    c.stablehloGetCurrentVersion((struct {
-        pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
+        pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void {
            const writer: *ContextWriter = @ptrCast(@alignCast(userdata));
            _ = writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
        }

@@ -1339,7 +1339,7 @@ pub fn getMinimumVersion() []const u8 {
    const WriterContext = @TypeOf(context);
    c.stablehloGetMinimumVersion((struct {
-        pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
+        pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void {
            const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
            _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
        }

@@ -1358,7 +1358,7 @@ pub fn serializePortableArtifact(bytecode: []const u8, target_version: []const u
    const WriterContext = @TypeOf(context);
    try mlir.successOr(c.stablehloSerializePortableArtifactFromStringRef(mlir.stringRef(bytecode), mlir.stringRef(target_version), (struct {
-        pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
+        pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void {
            const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
            _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
        }

View File

@@ -40,7 +40,7 @@ pub fn successOr(res: c.MlirLogicalResult, err: anytype) !void {
}

/// Alternative to MlirWrapperType
-pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.C) void;
+pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.c) void;

pub const Registry = struct {
    _inner: c.MlirDialectRegistry,
@@ -171,7 +171,7 @@ pub const PassManager = struct {
    }
};

-fn _mlir_passpipeline_error(err: c.MlirStringRef, ctx: ?*anyopaque) callconv(.C) void {
+fn _mlir_passpipeline_error(err: c.MlirStringRef, ctx: ?*anyopaque) callconv(.c) void {
    _ = ctx;
    std.debug.print(">>ERROR: {s}\n", .{err.data});
}
@@ -754,7 +754,7 @@ pub const Operation = struct {
            state.addOperands(operands);
        } else if (args.variadic_operands) |operands_segments| {
            const MAX_SEGMENTS = 32;
-            var segments: std.BoundedArray(i32, MAX_SEGMENTS) = .{};
+            var segments: stdx.BoundedArray(i32, MAX_SEGMENTS) = .{};
            for (operands_segments) |operands| {
                state.addOperands(operands);

@@ -764,7 +764,7 @@ pub const Operation = struct {
        } else if (args.tt_variadic_operands) |operands_segments| {
            // stablehlo and triton seems to disagree on the expected type of operandSegmentSizes, let's fix that.
            const MAX_SEGMENTS = 32;
-            var segments: std.BoundedArray(i32, MAX_SEGMENTS) = .{};
+            var segments: stdx.BoundedArray(i32, MAX_SEGMENTS) = .{};
            for (operands_segments) |operands| {
                state.addOperands(operands);

@@ -811,7 +811,7 @@ pub const Operation = struct {
            @panic("Failed to create MLIR operation");
        };
        if (args.verify and new_op.verify() == false) {
-            log.err("Failed to verify MLIR operation:\n{}", .{new_op.mlirFormatter(.{ .debug_info = true })});
+            log.err("Failed to verify MLIR operation:\n{f}", .{new_op.mlirFormatter(.{ .debug_info = true })});
            @panic("Failed to verify MLIR operation");
        }
        return new_op;
@@ -888,7 +888,7 @@ pub const Operation = struct {
        c.mlirOperationWriteBytecode(
            self._inner,
            (struct {
-                pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.C) void {
+                pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void {
                    const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_));
                    _ = inner_writer_context.writer.write(str.data[0..str.length]) catch unreachable;
                }

@@ -916,7 +916,7 @@ pub const Operation = struct {
            self._inner,
            cfg,
            (struct {
-                pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.C) void {
+                pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void {
                    const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_));
                    _ = inner_writer_context.writer.write(str.data[0..str.length]) catch |err| {
                        inner_writer_context.write_error = err;
@@ -939,29 +939,25 @@ pub const Operation = struct {
        op: Operation,
        flags: OpPrintingFlags,

-        pub fn format(self: @This(), comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
-            _ = fmt;
-            _ = options;
+        pub fn format(self: @This(), writer: anytype) !void {
            self.op.print(writer, self.flags);
        }
    };

-    pub fn print(self: Self, writer: anytype, flags: OpPrintingFlags) void {
+    pub fn print(self: Self, writer: *std.Io.Writer, flags: OpPrintingFlags) void {
        const pflags = flags.create();
        defer c.mlirOpPrintingFlagsDestroy(pflags);

-        var writer_context = .{ .writer = writer };
-        const WriterContext = @TypeOf(writer_context);
        c.mlirOperationPrintWithFlags(
            self._inner,
            pflags,
            (struct {
-                pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.C) void {
-                    const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_));
-                    _ = inner_writer_context.writer.write(str.data[0..str.length]) catch unreachable;
+                pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void {
+                    const _writer: *std.Io.Writer = @ptrCast(@alignCast(ctx_));
+                    _writer.writeAll(str.data[0..str.length]) catch @panic("Mlir print failed");
                }
            }).callback,
-            &writer_context,
+            writer,
        );
    }
@@ -991,7 +987,7 @@ pub const Operation = struct {
        c.mlirOperationWalk(
            self._inner,
            (struct {
-                pub fn callback(op: c.MlirOperation, ctx_: ?*anyopaque) callconv(.C) c.MlirWalkResult {
+                pub fn callback(op: c.MlirOperation, ctx_: ?*anyopaque) callconv(.c) c.MlirWalkResult {
                    const inner_ctx_: *ContextType = @ptrCast(@alignCast(ctx_));
                    return @intFromEnum(walkfn(inner_ctx_.ctx, .{ ._inner = op }));
                }
@@ -1017,24 +1013,26 @@ pub const Operation = struct {
        return c.mlirOperationRemoveAttributeByName(self._inner, stringRef(name_));
    }

+    /// Hash the canonicalized IR, without debug information that can change across builds.
    pub fn hash(op: Operation, hasher: *std.hash.XxHash64) void {
-        const NoError = error{};
-        const write = struct {
-            fn write(hasher_: *std.hash.XxHash64, bytes: []const u8) NoError!usize {
-                hasher_.update(bytes);
-                return bytes.len;
-            }
-        }.write;
-        const HashWriter = std.io.Writer(*std.hash.XxHash64, NoError, write);
-        const writer: HashWriter = .{ .context = hasher };
-        // Hash the canonicalized IR, without debug information that can change across builds.
        // Note: before we where using op.writeBytecode(writer),
        // but it crashes on some inputs, notably for unused variables.
        // So we use the text representation of the mlir.
        // See https://github.com/zml/zml/issues/97.
-        // Writes can't fail because we are writing to a hasher.
-        op.print(writer, .{ .debug_info = false });
+        const flags = OpPrintingFlags.create(.{ .debug_info = false });
+        defer c.mlirOpPrintingFlagsDestroy(flags);
+        c.mlirOperationPrintWithFlags(
+            op._inner,
+            flags,
+            (struct {
+                pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void {
+                    const _hasher: *std.hash.XxHash64 = @ptrCast(@alignCast(ctx_));
+                    _hasher.update(str.data[0..str.length]);
+                }
+            }).callback,
+            hasher,
+        );
    }
};
@@ -1185,9 +1183,9 @@ pub const BlockArgument = struct {
        return @bitCast(c.mlirBlockArgumentGetArgNumber(arg._inner));
    }

-    pub fn format(self: BlockArgument, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
+    pub fn format(self: BlockArgument, writer: anytype) !void {
        const value = Value{ ._inner = self._inner };
-        return value.format(fmt, options, writer);
+        return value.format(writer);
    }
};
@@ -1234,8 +1232,8 @@ pub const Type = struct {
    pub fn formatAny(SpecificType: type) fn (SpecificType, SpecificType) type {
        return struct {
-            pub fn format(self: SpecificType, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
-                return try Type.format(self.asType(), fmt, options, writer);
+            pub fn format(self: SpecificType, writer: anytype) !void {
+                return try Type.format(self.asType(), writer);
            }
        };
    }
@@ -1344,7 +1342,7 @@ pub fn IntegerType(comptime it: IntegerTypes) type {
        pub const eql = Type.eqlAny(Int);
        pub const format = helpers.format(Int, c.mlirTypePrint);

-        fn typeIsAIntegerExact(typ: c.MlirType) callconv(.C) bool {
+        fn typeIsAIntegerExact(typ: c.MlirType) callconv(.c) bool {
            const bit_width = Config[0];
            const is_sign = Config[2];
            return c.mlirTypeIsAInteger(typ) and (c.mlirIntegerTypeGetWidth(typ) == bit_width) and is_sign(typ);
@@ -1422,20 +1420,20 @@ pub fn ComplexType(comptime ct: ComplexTypes) type {
        _inner: c.MlirType,
        const Complex = @This();

-        fn mlirC64TypeGet(ctx: c.MlirContext) callconv(.C) c.MlirType {
+        fn mlirC64TypeGet(ctx: c.MlirContext) callconv(.c) c.MlirType {
            return c.mlirComplexTypeGet(c.mlirF32TypeGet(ctx));
        }

-        fn mlirC128TypeGet(ctx: c.MlirContext) callconv(.C) c.MlirType {
+        fn mlirC128TypeGet(ctx: c.MlirContext) callconv(.c) c.MlirType {
            return c.mlirComplexTypeGet(c.mlirF64TypeGet(ctx));
        }

-        fn mlirTypeIsAC64(typ: c.MlirType) callconv(.C) bool {
+        fn mlirTypeIsAC64(typ: c.MlirType) callconv(.c) bool {
            const element_type: c.MlirType = c.mlirComplexTypeGetElementType(typ);
            return c.mlirTypeIsAF32(element_type);
        }

-        fn mlirTypeIsAC128(typ: c.MlirType) callconv(.C) bool {
+        fn mlirTypeIsAC128(typ: c.MlirType) callconv(.c) bool {
            const element_type: c.MlirType = c.mlirComplexTypeGetElementType(typ);
            return c.mlirTypeIsAF64(element_type);
        }

@@ -1446,7 +1444,7 @@ pub fn ComplexType(comptime ct: ComplexTypes) type {
            .unknown => .{ c.mlirTypeIsAComplex, null },
        };

-        fn typeIsAUnknownComplex(typ: c.MlirType) callconv(.C) bool {
+        fn typeIsAUnknownComplex(typ: c.MlirType) callconv(.c) bool {
            return c.mlirTypeIsAComplex(typ);
        }
@@ -1685,7 +1683,7 @@ pub const Block = struct {
            .op_result => |parent_op| self.appendOperationRecursive(parent_op, opt),
            .block_argument => |arg| {
                // Hermetic blocks are not allowed to use arguments from other blocks.
-                stdx.debug.assert(opt == .open or self.eql(arg.block()), "Can't add {} from {?x} block to {?x} block", .{ arg, arg.block()._inner.ptr, self._inner.ptr });
+                stdx.debug.assert(opt == .open or self.eql(arg.block()), "Can't add {f} from {*} block to {*} block", .{ arg, arg.block()._inner.ptr, self._inner.ptr });
            },
            .null => @panic("InvalidMlir"),
        }

@@ -1694,7 +1692,7 @@ pub const Block = struct {
    pub fn appendOperationRecursive(self: Block, op: Operation, opt: RecursiveOpts) void {
        if (op.block()) |prev_block| {
            // Hermetic blocks are not allowed to reference values from other blocks.
-            stdx.debug.assert(opt == .open or self.equals(prev_block), "Can't add {} from {?x} block to {?x} block", .{ op, prev_block._inner.ptr, self._inner.ptr });
+            stdx.debug.assert(opt == .open or self.equals(prev_block), "Can't add {} from {*} block to {*} block", .{ op, prev_block._inner.ptr, self._inner.ptr });
            return;
        }
        for (0..op.numOperands()) |i| {
@@ -1705,7 +1703,7 @@ pub const Block = struct {
};

pub const helpers = struct {
-    pub fn eql(T: type, equal_fn: fn (@FieldType(T, "_inner"), @FieldType(T, "_inner")) callconv(.C) bool) fn (T, T) bool {
+    pub fn eql(T: type, equal_fn: fn (@FieldType(T, "_inner"), @FieldType(T, "_inner")) callconv(.c) bool) fn (T, T) bool {
        return struct {
            fn eql(a: T, b: T) bool {
                return equal_fn(a._inner, b._inner);

@@ -1713,7 +1711,7 @@ pub const helpers = struct {
        }.eql;
    }

-    pub fn deinit(T: type, deinit_fn: fn (@FieldType(T, "_inner")) callconv(.C) void) fn (*T) void {
+    pub fn deinit(T: type, deinit_fn: fn (@FieldType(T, "_inner")) callconv(.c) void) fn (*T) void {
        return struct {
            fn deinit(a: *T) void {
                deinit_fn(a._inner);

@@ -1722,7 +1720,7 @@ pub const helpers = struct {
        }.deinit;
    }

-    pub fn dump(T: type, dump_fn: fn (@FieldType(T, "_inner")) callconv(.C) void) fn (T) void {
+    pub fn dump(T: type, dump_fn: fn (@FieldType(T, "_inner")) callconv(.c) void) fn (T) void {
        return struct {
            fn dump(a: T) void {
                return dump_fn(a._inner);

@@ -1730,7 +1728,7 @@ pub const helpers = struct {
        }.dump;
    }

-    pub fn isNull(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.C) bool) fn (T) bool {
+    pub fn isNull(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) fn (T) bool {
        return struct {
            fn isNull(a: T) bool {
                return is_null_fn(a._inner);
@@ -1738,21 +1736,13 @@ pub const helpers = struct {
        }.isNull;
    }

-    pub fn format(Any: type, print_fn: fn (@FieldType(Any, "_inner"), ?*const MlirStrCallback, ?*anyopaque) callconv(.C) void) type {
+    pub fn format(Any: type, print_fn: fn (@FieldType(Any, "_inner"), ?*const MlirStrCallback, ?*anyopaque) callconv(.c) void) type {
        return struct {
-            pub fn format(
-                self: Any,
-                comptime fmt: []const u8,
-                options: std.fmt.FormatOptions,
-                writer: anytype,
-            ) !void {
-                _ = fmt;
-                _ = options;
-                const Writer = struct {
-                    writer: @TypeOf(writer),
-                    err: ?@TypeOf(writer).Error = null,
-                    fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.C) void {
+            pub fn format(self: Any, writer: *std.Io.Writer) !void {
+                const WriterWithErr = struct {
+                    writer: *std.Io.Writer,
+                    err: ?std.Io.Writer.Error = null,
+                    fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.c) void {
                        var ctx: *@This() = @alignCast(@ptrCast(opaque_ctx));
                        if (ctx.err) |_| return;
                        _ = ctx.writer.write(mlir_str.data[0..mlir_str.length]) catch |err| {

@@ -1762,14 +1752,14 @@ pub const helpers = struct {
                    }
                };

-                var context: Writer = .{ .writer = writer };
-                print_fn(self._inner, &Writer.printCallback, &context);
+                var context: WriterWithErr = .{ .writer = writer };
+                print_fn(self._inner, &WriterWithErr.printCallback, &context);
                if (context.err) |err| return err;
            }
        };
    }

-    pub fn wrapOr(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.C) bool) fn (@FieldType(T, "_inner")) ?T {
+    pub fn wrapOr(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) fn (@FieldType(T, "_inner")) ?T {
        return struct {
            fn wrapOr(inner: @FieldType(T, "_inner")) ?T {
                if (is_null_fn(inner)) return null;

@@ -1778,7 +1768,7 @@ pub const helpers = struct {
        }.wrapOr;
    }

-    pub fn init(T: type, inner: @FieldType(T, "_inner"), is_null_fn: fn (@FieldType(T, "_inner")) callconv(.C) bool) ?T {
+    pub fn init(T: type, inner: @FieldType(T, "_inner"), is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) ?T {
        if (is_null_fn(inner)) return null;
        return .{ ._inner = inner };
    }
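
The formatting edits above track Zig 0.15's new std.fmt contract: format methods take only a writer (no fmt string, no std.fmt.FormatOptions), and values whose type declares such a method are printed with the {f} specifier, hence the "{}" to "{f}" changes. A standalone sketch of the new contract (the type and values are illustrative):

const std = @import("std");

const Point = struct {
    x: i32,
    y: i32,

    // Zig 0.15 signature: just the writer, no fmt string or options.
    pub fn format(self: Point, writer: *std.Io.Writer) std.Io.Writer.Error!void {
        try writer.print("({d}, {d})", .{ self.x, self.y });
    }
};

test "custom format methods are selected with {f}" {
    var buf: [32]u8 = undefined;
    const s = try std.fmt.bufPrint(&buf, "{f}", .{Point{ .x = 1, .y = 2 }});
    try std.testing.expectEqualStrings("(1, 2)", s);
}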

View File

@@ -424,7 +424,7 @@ pub const CallFrame = extern struct {
    }
};

-pub const Handler = fn (*CallFrame) callconv(.C) ?*Error;
+pub const Handler = fn (*CallFrame) callconv(.c) ?*Error;

pub const ErrorCode = enum(c.XLA_FFI_Error_Code) {
    cancelled = c.XLA_FFI_Error_Code_CANCELLED,

View File

@@ -88,7 +88,7 @@ pub const Api = struct {
                return err;
            },
        };

-        const DynGetPjrtApi = lib.lookup(*const fn () callconv(.C) *const Api, "GetPjrtApi") orelse {
+        const DynGetPjrtApi = lib.lookup(*const fn () callconv(.c) *const Api, "GetPjrtApi") orelse {
            std.debug.panic("Unable to find GetPjrtApi symbol in library: {s}", .{library});
        };

@@ -100,7 +100,7 @@ pub const Api = struct {
    }

    fn CallFnArgType(comptime func: Funcs) type {
-        const fti = @typeInfo(std.meta.FieldType(c.PJRT_Api, func));
+        const fti = @typeInfo(@FieldType(c.PJRT_Api, @tagName(func)));
        const fn_ptr = @typeInfo(fti.optional.child);
        const fn_type_info = @typeInfo(fn_ptr.pointer.child);
        const arg_array_type_info = @typeInfo(fn_type_info.@"fn".params[0].type.?);

@@ -403,8 +403,8 @@ pub const Client = opaque {
        element_type: BufferType,
        layout: MemoryLayout,
        device: *const Device,
-        on_delete_callback: *const fn (device_buffer_ptr: ?*anyopaque, ctx: ?*anyopaque) callconv(.C) void = &struct {
-            fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.C) void {}
+        on_delete_callback: *const fn (device_buffer_ptr: ?*anyopaque, ctx: ?*anyopaque) callconv(.c) void = &struct {
+            fn call(_: ?*anyopaque, _: ?*anyopaque) callconv(.c) void {}
        }.call,
        on_delete_callback_arg: ?*anyopaque = null,
        stream: ?*const Stream = null,

@@ -637,7 +637,7 @@ pub const GetCostAnalysisError = std.mem.Allocator.Error || ApiError;
pub const SerializeResult = struct {
    bytes: []const u8,
    handle: *anyopaque,
-    deleter: *const fn (?*anyopaque) callconv(.C) void,
+    deleter: *const fn (?*anyopaque) callconv(.c) void,

    pub fn deinit(self: *SerializeResult) void {
        self.deleter(self.handle);

@@ -1036,7 +1036,7 @@ pub const Event = opaque {
        });
    }

-    pub fn onReady(self: *Event, api: *const Api, func: *const fn (err: ?*Error, user_arg: ?*anyopaque) callconv(.C) void, user_arg: ?*anyopaque) ApiError!void {
+    pub fn onReady(self: *Event, api: *const Api, func: *const fn (err: ?*Error, user_arg: ?*anyopaque) callconv(.c) void, user_arg: ?*anyopaque) ApiError!void {
        _ = try api.call(.PJRT_Event_OnReady, .{
            .event = self.inner(),
            .callback = @ptrCast(func),
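
std.meta.FieldType is replaced here by the @FieldType builtin (added in Zig 0.14), which takes the field name as a comptime string, hence the added @tagName(func) above. A small sketch (the struct is illustrative):

const std = @import("std");

const Config = struct {
    name: []const u8,
    retries: u8,
};

test "@FieldType replaces std.meta.FieldType" {
    // The builtin takes the field name as a comptime string.
    try std.testing.expectEqual(u8, @FieldType(Config, "retries"));
}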

View File

@@ -3,6 +3,7 @@ load("@rules_zig//zig:defs.bzl", "zig_library", "zig_test")
zig_library(
    name = "stdx",
    srcs = [
+       "bounded_array.zig",
        "debug.zig",
        "flags.zig",
        "fmt.zig",

stdx/bounded_array.zig Normal file

View File

@@ -0,0 +1,412 @@
const std = @import("std");
const assert = std.debug.assert;
const mem = std.mem;
const testing = std.testing;
const Alignment = std.mem.Alignment;
/// A structure with an array and a length, that can be used as a slice.
///
/// Useful to pass around small arrays whose exact size is only known at
/// runtime, but whose maximum size is known at comptime, without requiring
/// an `Allocator`.
///
/// ```zig
/// var actual_size = 32;
/// var a = try BoundedArray(u8, 64).init(actual_size);
/// var slice = a.slice(); // a slice of the 64-byte array
/// var a_clone = a; // creates a copy - the structure doesn't use any internal pointers
/// ```
pub fn BoundedArray(comptime T: type, comptime buffer_capacity: usize) type {
return BoundedArrayAligned(T, .of(T), buffer_capacity);
}
/// A structure with an array, length and alignment, that can be used as a
/// slice.
///
/// Useful to pass around small explicitly-aligned arrays whose exact size is
/// only known at runtime, but whose maximum size is known at comptime, without
/// requiring an `Allocator`.
/// ```zig
/// var a = try BoundedArrayAligned(u8, .@"16", 2).init(0);
/// try a.append(255);
/// try a.append(255);
/// const b: *const [1]u16 = @ptrCast(a.constSlice().ptr);
/// try testing.expectEqual(@as(u16, 65535), b[0]);
/// ```
pub fn BoundedArrayAligned(
comptime T: type,
comptime alignment: Alignment,
comptime buffer_capacity: usize,
) type {
return struct {
const Self = @This();
buffer: [buffer_capacity]T align(alignment.toByteUnits()) = undefined,
len: usize = 0,
/// Set the actual length of the slice.
/// Returns error.Overflow if it exceeds the length of the backing array.
pub fn init(len: usize) error{Overflow}!Self {
if (len > buffer_capacity) return error.Overflow;
return Self{ .len = len };
}
/// View the internal array as a slice whose size was previously set.
pub fn slice(self: anytype) switch (@TypeOf(&self.buffer)) {
*align(alignment.toByteUnits()) [buffer_capacity]T => []align(alignment.toByteUnits()) T,
*align(alignment.toByteUnits()) const [buffer_capacity]T => []align(alignment.toByteUnits()) const T,
else => unreachable,
} {
return self.buffer[0..self.len];
}
/// View the internal array as a constant slice whose size was previously set.
pub fn constSlice(self: *const Self) []align(alignment.toByteUnits()) const T {
return self.slice();
}
/// Adjust the slice's length to `len`.
/// Does not initialize added items if any.
pub fn resize(self: *Self, len: usize) error{Overflow}!void {
if (len > buffer_capacity) return error.Overflow;
self.len = len;
}
/// Remove all elements from the slice.
pub fn clear(self: *Self) void {
self.len = 0;
}
/// Copy the content of an existing slice.
pub fn fromSlice(m: []const T) error{Overflow}!Self {
var list = try init(m.len);
@memcpy(list.slice(), m);
return list;
}
/// Return the element at index `i` of the slice.
pub fn get(self: Self, i: usize) T {
return self.constSlice()[i];
}
/// Set the value of the element at index `i` of the slice.
pub fn set(self: *Self, i: usize, item: T) void {
self.slice()[i] = item;
}
/// Return the maximum length of a slice.
pub fn capacity(self: Self) usize {
return self.buffer.len;
}
/// Check that the slice can hold at least `additional_count` items.
pub fn ensureUnusedCapacity(self: Self, additional_count: usize) error{Overflow}!void {
if (self.len + additional_count > buffer_capacity) {
return error.Overflow;
}
}
/// Increase length by 1, returning a pointer to the new item.
pub fn addOne(self: *Self) error{Overflow}!*T {
try self.ensureUnusedCapacity(1);
return self.addOneAssumeCapacity();
}
/// Increase length by 1, returning pointer to the new item.
/// Asserts that there is space for the new item.
pub fn addOneAssumeCapacity(self: *Self) *T {
assert(self.len < buffer_capacity);
self.len += 1;
return &self.slice()[self.len - 1];
}
/// Resize the slice, adding `n` new elements, which have `undefined` values.
/// The return value is a pointer to the array of uninitialized elements.
pub fn addManyAsArray(self: *Self, comptime n: usize) error{Overflow}!*align(alignment.toByteUnits()) [n]T {
const prev_len = self.len;
try self.resize(self.len + n);
return self.slice()[prev_len..][0..n];
}
/// Resize the slice, adding `n` new elements, which have `undefined` values.
/// The return value is a slice pointing to the uninitialized elements.
pub fn addManyAsSlice(self: *Self, n: usize) error{Overflow}![]align(alignment.toByteUnits()) T {
const prev_len = self.len;
try self.resize(self.len + n);
return self.slice()[prev_len..][0..n];
}
/// Remove and return the last element from the slice, or return `null` if the slice is empty.
pub fn pop(self: *Self) ?T {
if (self.len == 0) return null;
const item = self.get(self.len - 1);
self.len -= 1;
return item;
}
/// Return a slice of only the extra capacity after items.
/// This can be useful for writing directly into it.
/// Note that such an operation must be followed up with a
/// call to `resize()`
pub fn unusedCapacitySlice(self: *Self) []align(alignment.toByteUnits()) T {
return self.buffer[self.len..];
}
/// Insert `item` at index `i` by moving `slice[n .. slice.len]` to make room.
/// This operation is O(N).
pub fn insert(
self: *Self,
i: usize,
item: T,
) error{Overflow}!void {
if (i > self.len) {
return error.Overflow;
}
_ = try self.addOne();
var s = self.slice();
mem.copyBackwards(T, s[i + 1 .. s.len], s[i .. s.len - 1]);
self.buffer[i] = item;
}
/// Insert slice `items` at index `i` by moving `slice[i .. slice.len]` to make room.
/// This operation is O(N).
pub fn insertSlice(self: *Self, i: usize, items: []const T) error{Overflow}!void {
try self.ensureUnusedCapacity(items.len);
self.len += items.len;
mem.copyBackwards(T, self.slice()[i + items.len .. self.len], self.constSlice()[i .. self.len - items.len]);
@memcpy(self.slice()[i..][0..items.len], items);
}
/// Replace range of elements `slice[start..][0..len]` with `new_items`.
/// Grows slice if `len < new_items.len`.
/// Shrinks slice if `len > new_items.len`.
pub fn replaceRange(
self: *Self,
start: usize,
len: usize,
new_items: []const T,
) error{Overflow}!void {
const after_range = start + len;
var range = self.slice()[start..after_range];
if (range.len == new_items.len) {
@memcpy(range[0..new_items.len], new_items);
} else if (range.len < new_items.len) {
const first = new_items[0..range.len];
const rest = new_items[range.len..];
@memcpy(range[0..first.len], first);
try self.insertSlice(after_range, rest);
} else {
@memcpy(range[0..new_items.len], new_items);
const after_subrange = start + new_items.len;
for (self.constSlice()[after_range..], 0..) |item, i| {
self.slice()[after_subrange..][i] = item;
}
self.len -= len - new_items.len;
}
}
/// Extend the slice by 1 element.
pub fn append(self: *Self, item: T) error{Overflow}!void {
const new_item_ptr = try self.addOne();
new_item_ptr.* = item;
}
/// Extend the slice by 1 element, asserting the capacity is already
/// enough to store the new item.
pub fn appendAssumeCapacity(self: *Self, item: T) void {
const new_item_ptr = self.addOneAssumeCapacity();
new_item_ptr.* = item;
}
/// Remove the element at index `i`, shift elements after index
/// `i` forward, and return the removed element.
/// Asserts the slice has at least one item.
/// This operation is O(N).
pub fn orderedRemove(self: *Self, i: usize) T {
const newlen = self.len - 1;
if (newlen == i) return self.pop().?;
const old_item = self.get(i);
for (self.slice()[i..newlen], 0..) |*b, j| b.* = self.get(i + 1 + j);
self.set(newlen, undefined);
self.len = newlen;
return old_item;
}
/// Remove the element at the specified index and return it.
/// The empty slot is filled from the end of the slice.
/// This operation is O(1).
pub fn swapRemove(self: *Self, i: usize) T {
if (self.len - 1 == i) return self.pop().?;
const old_item = self.get(i);
self.set(i, self.pop().?);
return old_item;
}
/// Append the slice of items to the slice.
pub fn appendSlice(self: *Self, items: []const T) error{Overflow}!void {
try self.ensureUnusedCapacity(items.len);
self.appendSliceAssumeCapacity(items);
}
/// Append the slice of items to the slice, asserting the capacity is already
/// enough to store the new items.
pub fn appendSliceAssumeCapacity(self: *Self, items: []const T) void {
const old_len = self.len;
self.len += items.len;
@memcpy(self.slice()[old_len..][0..items.len], items);
}
/// Append a value to the slice `n` times.
/// Allocates more memory as necessary.
pub fn appendNTimes(self: *Self, value: T, n: usize) error{Overflow}!void {
const old_len = self.len;
try self.resize(old_len + n);
@memset(self.slice()[old_len..self.len], value);
}
/// Append a value to the slice `n` times.
/// Asserts the capacity is enough.
pub fn appendNTimesAssumeCapacity(self: *Self, value: T, n: usize) void {
const old_len = self.len;
self.len += n;
assert(self.len <= buffer_capacity);
@memset(self.slice()[old_len..self.len], value);
}
pub const Writer = if (T != u8)
@compileError("The Writer interface is only defined for BoundedArray(u8, ...) " ++
"but the given type is BoundedArray(" ++ @typeName(T) ++ ", ...)")
else
std.io.GenericWriter(*Self, error{Overflow}, appendWrite);
/// Initializes a writer which will write into the array.
pub fn writer(self: *Self) Writer {
return .{ .context = self };
}
/// Same as `appendSlice` except it returns the number of bytes written, which is always the same
/// as `m.len`. The purpose of this function existing is to match `std.io.GenericWriter` API.
fn appendWrite(self: *Self, m: []const u8) error{Overflow}!usize {
try self.appendSlice(m);
return m.len;
}
};
}
test BoundedArray {
var a = try BoundedArray(u8, 64).init(32);
try testing.expectEqual(a.capacity(), 64);
try testing.expectEqual(a.slice().len, 32);
try testing.expectEqual(a.constSlice().len, 32);
try a.resize(48);
try testing.expectEqual(a.len, 48);
const x = [_]u8{1} ** 10;
a = try BoundedArray(u8, 64).fromSlice(&x);
try testing.expectEqualSlices(u8, &x, a.constSlice());
var a2 = a;
try testing.expectEqualSlices(u8, a.constSlice(), a2.constSlice());
a2.set(0, 0);
try testing.expect(a.get(0) != a2.get(0));
try testing.expectError(error.Overflow, a.resize(100));
try testing.expectError(error.Overflow, BoundedArray(u8, x.len - 1).fromSlice(&x));
try a.resize(0);
try a.ensureUnusedCapacity(a.capacity());
(try a.addOne()).* = 0;
try a.ensureUnusedCapacity(a.capacity() - 1);
try testing.expectEqual(a.len, 1);
const uninitialized = try a.addManyAsArray(4);
try testing.expectEqual(uninitialized.len, 4);
try testing.expectEqual(a.len, 5);
try a.append(0xff);
try testing.expectEqual(a.len, 6);
try testing.expectEqual(a.pop(), 0xff);
a.appendAssumeCapacity(0xff);
try testing.expectEqual(a.len, 6);
try testing.expectEqual(a.pop(), 0xff);
try a.resize(1);
try testing.expectEqual(a.pop(), 0);
try testing.expectEqual(a.pop(), null);
var unused = a.unusedCapacitySlice();
@memset(unused[0..8], 2);
unused[8] = 3;
unused[9] = 4;
try testing.expectEqual(unused.len, a.capacity());
try a.resize(10);
try a.insert(5, 0xaa);
try testing.expectEqual(a.len, 11);
try testing.expectEqual(a.get(5), 0xaa);
try testing.expectEqual(a.get(9), 3);
try testing.expectEqual(a.get(10), 4);
try a.insert(11, 0xbb);
try testing.expectEqual(a.len, 12);
try testing.expectEqual(a.pop(), 0xbb);
try a.appendSlice(&x);
try testing.expectEqual(a.len, 11 + x.len);
try a.appendNTimes(0xbb, 5);
try testing.expectEqual(a.len, 11 + x.len + 5);
try testing.expectEqual(a.pop(), 0xbb);
a.appendNTimesAssumeCapacity(0xcc, 5);
try testing.expectEqual(a.len, 11 + x.len + 5 - 1 + 5);
try testing.expectEqual(a.pop(), 0xcc);
try testing.expectEqual(a.len, 29);
try a.replaceRange(1, 20, &x);
try testing.expectEqual(a.len, 29 + x.len - 20);
try a.insertSlice(0, &x);
try testing.expectEqual(a.len, 29 + x.len - 20 + x.len);
try a.replaceRange(1, 5, &x);
try testing.expectEqual(a.len, 29 + x.len - 20 + x.len + x.len - 5);
try a.append(10);
try testing.expectEqual(a.pop(), 10);
try a.append(20);
const removed = a.orderedRemove(5);
try testing.expectEqual(removed, 1);
try testing.expectEqual(a.len, 34);
a.set(0, 0xdd);
a.set(a.len - 1, 0xee);
const swapped = a.swapRemove(0);
try testing.expectEqual(swapped, 0xdd);
try testing.expectEqual(a.get(0), 0xee);
const added_slice = try a.addManyAsSlice(3);
try testing.expectEqual(added_slice.len, 3);
try testing.expectEqual(a.len, 36);
while (a.pop()) |_| {}
const w = a.writer();
const s = "hello, this is a test string";
try w.writeAll(s);
try testing.expectEqualStrings(s, a.constSlice());
}
test "BoundedArrayAligned" {
var a = try BoundedArrayAligned(u8, .@"16", 4).init(0);
try a.append(0);
try a.append(0);
try a.append(255);
try a.append(255);
const b = @as(*const [2]u16, @ptrCast(a.constSlice().ptr));
try testing.expectEqual(@as(u16, 0), b[0]);
try testing.expectEqual(@as(u16, 65535), b[1]);
}
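A small usage sketch of the vendored container, assuming it is consumed through the `stdx` module (see the `root.zig` re-exports further down); `pop()` returning an optional mirrors `std.ArrayList` in 0.15:

```zig
const std = @import("std");
const stdx = @import("stdx");

test "collect dims without an allocator" {
    // Capacity fixed at comptime, length known at runtime, no Allocator needed.
    var dims: stdx.BoundedArray(i64, 8) = .{};
    try dims.append(2);
    try dims.appendSlice(&.{ 3, 4 });
    try std.testing.expectEqualSlices(i64, &.{ 2, 3, 4 }, dims.constSlice());
    try std.testing.expectEqual(@as(usize, 8), dims.capacity());
    try std.testing.expectEqual(@as(?i64, 4), dims.pop());
}
```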


@ -47,8 +47,7 @@ const debug = @import("debug.zig");
/// Format and print an error message to stderr, then exit with an exit code of 1. /// Format and print an error message to stderr, then exit with an exit code of 1.
pub fn fatal(comptime fmt_string: []const u8, args: anytype) noreturn { pub fn fatal(comptime fmt_string: []const u8, args: anytype) noreturn {
const stderr = std.io.getStdErr().writer(); std.debug.print("error: " ++ fmt_string ++ "\n", args);
stderr.print("error: " ++ fmt_string ++ "\n", args) catch {};
std.posix.exit(1); std.posix.exit(1);
} }


@ -43,8 +43,8 @@ pub const IntFmt = struct {
}; };
pub const FloatFmt = enum(u8) { pub const FloatFmt = enum(u8) {
scientific = @intFromEnum(std.fmt.format_float.Format.scientific), scientific = @intFromEnum(std.fmt.Number.Mode.scientific),
decimal = @intFromEnum(std.fmt.format_float.Format.decimal), decimal = @intFromEnum(std.fmt.Number.Mode.decimal),
hex, hex,
pub fn parseComptime(comptime fmt_: []const u8) FloatFmt { pub fn parseComptime(comptime fmt_: []const u8) FloatFmt {
@ -71,44 +71,34 @@ pub fn formatValue(value: anytype, full: FullFormatOptions, writer: anytype) !vo
}; };
} }
pub fn formatFloatValue(value: anytype, full: FullFormatOptions, writer: anytype) !void { pub fn formatFloatValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void {
const formatFloat = std.fmt.format_float.formatFloat;
var buf: [std.fmt.format_float.bufferSize(.decimal, f64)]u8 = undefined;
const x = switch (@typeInfo(@TypeOf(value))) { const x = switch (@typeInfo(@TypeOf(value))) {
.@"struct" => value.toF32(), .@"struct" => value.toF32(),
.float => value, .float => value,
else => @compileError("formatFloatValue expects a float, got: " ++ @typeName(@TypeOf(value))), else => @compileError("formatFloatValue expects a float, got: " ++ @typeName(@TypeOf(value))),
}; };
const s_or_err = switch (full.fmt.float) { try switch (full.fmt.float) {
.scientific => formatFloat(&buf, x, .{ .mode = .scientific, .precision = full.options.precision }), .scientific => writer.printFloat(x, .{ .mode = .scientific, .precision = full.options.precision }),
.decimal => formatFloat(&buf, x, .{ .mode = .decimal, .precision = full.options.precision }), .decimal => writer.printFloat(x, .{ .mode = .decimal, .precision = full.options.precision }),
.hex => hex: { .hex => writer.printFloatHexOptions(x, .{ .mode = .hex }),
var buf_stream = std.io.fixedBufferStream(&buf);
std.fmt.formatFloatHexadecimal(x, full.options, buf_stream.writer()) catch unreachable;
break :hex buf_stream.getWritten();
},
}; };
const s = s_or_err catch "(float)";
return std.fmt.formatBuf(s, full.options, writer);
} }
pub fn formatIntValue(value: anytype, full: FullFormatOptions, writer: anytype) !void { pub fn formatIntValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void {
switch (@typeInfo(@TypeOf(value))) { switch (@typeInfo(@TypeOf(value))) {
.int => {}, .int => {},
else => @compileError("formatIntValue expects an int, got: " ++ @typeName(@TypeOf(value))), else => @compileError("formatIntValue expects an int, got: " ++ @typeName(@TypeOf(value))),
} }
return std.fmt.formatInt(value, full.fmt.int.base, full.fmt.int.case, full.options, writer); return writer.printInt(value, full.fmt.int.base, full.fmt.int.case, full.options);
} }
pub fn formatAnyValue(value: anytype, full: FullFormatOptions, writer: anytype) !void { pub fn formatAnyValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void {
var buf: [48]u8 = undefined; var buf: [48]u8 = undefined;
const s = std.fmt.bufPrint(&buf, "{any}", .{value}) catch blk: { const s = std.fmt.bufPrint(&buf, "{any}", .{value}) catch blk: {
buf[45..].* = "...".*; buf[45..].* = "...".*;
break :blk buf[0..]; break :blk buf[0..];
}; };
return std.fmt.formatBuf(s, full.options, writer); return try writer.alignBufferOptions(s, full.options);
} }
pub fn formatSliceCustom(fmt_func: anytype, values: anytype, full: FullFormatOptions, writer: anytype) !void { pub fn formatSliceCustom(fmt_func: anytype, values: anytype, full: FullFormatOptions, writer: anytype) !void {
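These helpers now print through a concrete `*std.Io.Writer` (`printInt`, `printFloat`) instead of the removed `std.fmt.format*` free functions. A sketch of the new calls; the `.fixed`/`buffered()` harness is our assumption of the simplest test setup, not part of this file:

```zig
const std = @import("std");

test "printing through *std.Io.Writer" {
    var buf: [32]u8 = undefined;
    var w: std.Io.Writer = .fixed(&buf);
    // printInt(value, base, case, options) replaces std.fmt.formatInt.
    try w.printInt(@as(u32, 255), 16, .lower, .{});
    try std.testing.expectEqualStrings("ff", w.buffered());
}
```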


@ -26,7 +26,7 @@ pub fn ArgsTuple(comptime funcT: anytype, comptime ArgsT: ?type) type {
var num_buf: [8]u8 = undefined; var num_buf: [8]u8 = undefined;
tuple_fields[i] = .{ tuple_fields[i] = .{
.name = blk: { .name = blk: {
const s = std.fmt.formatIntBuf(&num_buf, i, 10, .lower, .{}); const s = std.fmt.printInt(&num_buf, i, 10, .lower, .{});
num_buf[s] = 0; num_buf[s] = 0;
break :blk num_buf[0..s :0]; break :blk num_buf[0..s :0];
}, },


@ -1,3 +1,5 @@
pub const BoundedArray = @import("bounded_array.zig").BoundedArray;
pub const BoundedArrayAligned = @import("bounded_array.zig").BoundedArrayAligned;
pub const debug = @import("debug.zig"); pub const debug = @import("debug.zig");
pub const flags = @import("flags.zig"); pub const flags = @import("flags.zig");
pub const fmt = @import("fmt.zig"); pub const fmt = @import("fmt.zig");


@ -11,14 +11,11 @@ pub const Duration = struct {
return (1 * std.time.ns_per_s) / self.ns; return (1 * std.time.ns_per_s) / self.ns;
} }
pub fn format( pub fn formatDuration(duration: Duration, writer: *std.io.Writer) std.io.Writer.Error!void {
self: Duration, try writer.printDuration(duration.ns, .{});
comptime fmt: []const u8,
options: std.fmt.FormatOptions,
writer: anytype,
) @TypeOf(writer).Error!void {
return try std.fmt.fmtDuration(self.ns).format(fmt, options, writer);
} }
pub const format = formatDuration;
}; };
pub const Timer = struct { pub const Timer = struct {


@ -4,6 +4,6 @@ def repo():
new_git_repository( new_git_repository(
name = "com_github_hejsil_clap", name = "com_github_hejsil_clap",
remote = "https://github.com/Hejsil/zig-clap.git", remote = "https://github.com/Hejsil/zig-clap.git",
commit = "068c38f89814079635692c7d0be9f58508c86173", commit = "5289e0753cd274d65344bef1c114284c633536ea",
build_file = "//:third_party/com_github_hejsil_clap/clap.bazel", build_file = "//:third_party/com_github_hejsil_clap/clap.bazel",
) )


@ -0,0 +1,75 @@
module(
name = "rules_zig",
version = "20250827.0-35b6d57",
compatibility_level = 1,
)
bazel_dep(name = "aspect_bazel_lib", version = "2.8.1")
bazel_dep(name = "bazel_skylib", version = "1.7.1")
bazel_dep(name = "platforms", version = "0.0.10")
zig = use_extension("//zig:extensions.bzl", "zig")
zig.index(file = "//zig/private:versions.json")
use_repo(zig, "zig_toolchains")
register_toolchains("@rules_zig//zig/target:all")
register_toolchains("@zig_toolchains//:all")
zig_dev = use_extension(
"//zig:extensions.bzl",
"zig",
dev_dependency = True,
)
zig_dev.toolchain(zig_version = "0.13.0")
zig_dev.toolchain(zig_version = "0.12.1")
zig_dev.toolchain(zig_version = "0.12.0")
zig_dev.toolchain(zig_version = "0.11.0")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "stardoc", version = "0.7.0", dev_dependency = True, repo_name = "io_bazel_stardoc")
bazel_dep(name = "gazelle", version = "0.38.0", dev_dependency = True, repo_name = "bazel_gazelle")
bazel_dep(name = "bazel_skylib_gazelle_plugin", version = "1.7.1", dev_dependency = True)
bazel_dep(
name = "buildifier_prebuilt",
version = "7.3.1",
dev_dependency = True,
)
bazel_dep(name = "rules_multirun", version = "0.9.0", dev_dependency = True)
bazel_dep(name = "rules_python", version = "0.35.0", dev_dependency = True)
bazel_dep(
name = "rules_bazel_integration_test",
version = "0.25.0",
dev_dependency = True,
)
bazel_binaries = use_extension(
"@rules_bazel_integration_test//:extensions.bzl",
"bazel_binaries",
dev_dependency = True,
)
# NOTE: Keep in sync with WORKSPACE.
bazel_binaries.download(version_file = "//:.bazelversion")
bazel_binaries.download(version = "7.0.0")
use_repo(
bazel_binaries,
"bazel_binaries",
"bazel_binaries_bazelisk",
"build_bazel_bazel_.bazelversion",
"build_bazel_bazel_7_0_0",
)
# TODO[AH] Should be an implicit transitive dependency through rules_bazel_integration_test.
# However, if we do not include it explicitly, then the runfiles resolution for
# cgrindel_bazel_starlib/shlib/lib/message.sh fails in
# rules_bazel_integration_test/tools/update_deleted_packages.sh when invoked
# through the rules_multirun target //util:update.
bazel_dep(name = "cgrindel_bazel_starlib", version = "0.21.0", dev_dependency = True)
# Hack to get around a cc_common.link limitation.
# See https://github.com/bazelbuild/bazel/pull/23838
cc_common_link = use_repo_rule("//zig:extensions.bzl", "cc_common_link")
cc_common_link(name = "build_bazel_rules_android")


@ -0,0 +1,5 @@
{
"strip_prefix": "rules_zig-35b6d57e94f3eb08e978c1133314606f4f38e216",
"url": "https://github.com/zml/rules_zig/archive/35b6d57e94f3eb08e978c1133314606f4f38e216.tar.gz",
"integrity": "sha256-FDnAqynTD2LB3W/IaBgocmnLz8CA9nyHZYYntC4plUU="
}


@ -17,7 +17,8 @@
"20250519.0-233b207", "20250519.0-233b207",
"20250613.0-567662a", "20250613.0-567662a",
"20250714.0-b14a4f1", "20250714.0-b14a4f1",
"20250821.0-be53625" "20250821.0-be53625",
"20250827.0-35b6d57"
], ],
"yanked_versions": {} "yanked_versions": {}
} }


@ -1,15 +1,7 @@
load("@rules_python//python:py_library.bzl", "py_library") load("@rules_python//python:py_library.bzl", "py_library")
load("@rules_python//python/entry_points:py_console_script_binary.bzl", "py_console_script_binary")
py_library( py_library(
name = "zml_utils", name = "zml_utils",
srcs = ["zml_utils.py"], srcs = ["zml_utils.py"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
py_console_script_binary(
name = "hf",
pkg = "@huggingface_hub//huggingface_hub:pkg",
script = "hf",
visibility = ["//visibility:public"],
)


@ -153,6 +153,6 @@ pub const Allocator = struct {
@panic("Unsupported case"); @panic("Unsupported case");
} }
return @ptrCast(self.allocator.alignedAlloc(u8, @alignOf(*anyopaque), size) catch return null); return @ptrCast(self.allocator.alignedAlloc(u8, std.mem.Alignment.of(*anyopaque), size) catch return null);
} }
}; };


@ -608,7 +608,7 @@ fn findSimilarBufferKeys(original_key: []const u8, store: BufferStore, temp_allo
if (std.mem.startsWith(u8, key, base_key)) { if (std.mem.startsWith(u8, key, base_key)) {
if (matches == 0) log.warn("Similar buffers found:", .{}); if (matches == 0) log.warn("Similar buffers found:", .{});
if (!shown_keys.contains(key)) { if (!shown_keys.contains(key)) {
log.warn(" - {s}: {}", .{ key, entry.value_ptr.*.shape() }); log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() });
shown_keys.put(key, {}) catch continue; shown_keys.put(key, {}) catch continue;
matches += 1; matches += 1;
} }
@ -625,7 +625,7 @@ fn findSimilarBufferKeys(original_key: []const u8, store: BufferStore, temp_allo
const key = entry.key_ptr.*; const key = entry.key_ptr.*;
if (std.mem.indexOf(u8, key, component) != null and !shown_keys.contains(key)) { if (std.mem.indexOf(u8, key, component) != null and !shown_keys.contains(key)) {
if (matches == 0) log.warn("Partial matches for '{s}':", .{component}); if (matches == 0) log.warn("Partial matches for '{s}':", .{component});
log.warn(" - {s}: {}", .{ key, entry.value_ptr.*.shape() }); log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() });
shown_keys.put(key, {}) catch continue; shown_keys.put(key, {}) catch continue;
matches += 1; matches += 1;
if (matches >= 5) break; if (matches >= 5) break;
@ -660,8 +660,8 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
return if (buffer_store.get(prefix)) |host_buffer| { return if (buffer_store.get(prefix)) |host_buffer| {
// obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us. // obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us.
var buf_with_metadata = host_buffer; var buf_with_metadata = host_buffer;
log.debug("Loading buffer {s} ({})", .{ prefix, obj._shape }); log.debug("Loading buffer {s} ({f})", .{ prefix, obj._shape });
stdx.debug.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix }); stdx.debug.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {f} and {f} for tensor {s}", .{ obj._shape, host_buffer, prefix });
buf_with_metadata._shape = obj._shape; buf_with_metadata._shape = obj._shape;
obj.* = try zml.Buffer.from(platform, buf_with_metadata, .{}); obj.* = try zml.Buffer.from(platform, buf_with_metadata, .{});
} else { } else {
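The `{}` → `{f}` switches in this file follow the 0.15 rule that a new-style `format(self, *std.Io.Writer)` method is only invoked through the explicit `{f}` specifier. A self-contained sketch; `Dims` is a stand-in type, not `zml.Shape`:

```zig
const std = @import("std");

const Dims = struct {
    dims: []const i64,
    // 0.15-style format method: no fmt string, no options, a concrete writer.
    pub fn format(self: Dims, writer: *std.Io.Writer) std.Io.Writer.Error!void {
        try writer.writeAll("Dims(");
        for (self.dims, 0..) |d, i| {
            if (i > 0) try writer.writeByte('x');
            try writer.print("{d}", .{d});
        }
        try writer.writeByte(')');
    }
};

test "{f} dispatches to the format method" {
    var buf: [32]u8 = undefined;
    var w: std.Io.Writer = .fixed(&buf);
    try w.print("{f}", .{Dims{ .dims = &.{ 2, 3 } }});
    try std.testing.expectEqualStrings("Dims(2x3)", w.buffered());
}
```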


@ -69,7 +69,7 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, prefix:
var new_prefix = prefix; var new_prefix = prefix;
if (prefix.items.len > 0) if (prefix.items.len > 0)
new_prefix.appendAssumeCapacity('.'); new_prefix.appendAssumeCapacity('.');
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); new_prefix.items.len += std.fmt.printInt(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
try parseMetadata(allocator, store, new_prefix, item); try parseMetadata(allocator, store, new_prefix, item);
} }
}; };


@ -1,12 +1,15 @@
const asynk = @import("async");
const std = @import("std"); const std = @import("std");
const zml = @import("../zml.zig"); const Allocator = std.mem.Allocator;
const json = @import("json.zig");
const HostBuffer = zml.HostBuffer; const asynk = @import("async");
const stdx = @import("stdx");
const MemoryMappedFile = @import("../aio.zig").MemoryMappedFile; const MemoryMappedFile = @import("../aio.zig").MemoryMappedFile;
const zml = @import("../zml.zig");
const HostBuffer = zml.HostBuffer;
const json = @import("json.zig");
const StringBuilder = std.ArrayListUnmanaged(u8); const StringBuilder = std.ArrayListUnmanaged(u8);
const Allocator = std.mem.Allocator;
const log = std.log.scoped(.@"zml/io"); const log = std.log.scoped(.@"zml/io");
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore { pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
@ -16,7 +19,7 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
errdefer res.arena.deinit(); errdefer res.arena.deinit();
const arena = res.arena.allocator(); const arena = res.arena.allocator();
var files = std.ArrayList(MemoryMappedFile).init(arena); var files = std.array_list.Managed(MemoryMappedFile).init(arena);
errdefer files.deinit(); errdefer files.deinit();
if (std.mem.endsWith(u8, path, ".safetensors.index.json")) { if (std.mem.endsWith(u8, path, ".safetensors.index.json")) {
@ -28,17 +31,19 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
return res; return res;
} }
fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void { fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.array_list.Managed(MemoryMappedFile), path: []const u8) !void {
const file = asynk.File.open(path, .{}) catch |err| { const file = asynk.File.open(path, .{}) catch |err| {
log.err("Failed to open {s}: {}", .{ path, err }); log.err("Failed to open {s}: {}", .{ path, err });
return err; return err;
}; };
errdefer file.close() catch unreachable; errdefer file.close() catch unreachable;
var r = file.reader(); var buffer: [4096]u8 = undefined;
var r = file.reader(&buffer);
const json_data = try allocator.alloc(u8, (try file.stat()).size);
_ = try r.readAtLeast(json_data, json_data.len); var json_reader = std.json.Reader.init(allocator, &r.interface);
const index = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data, .{ .allocate = .alloc_if_needed });
const index = try std.json.parseFromTokenSourceLeaky(std.json.Value, allocator, &json_reader, .{ .allocate = .alloc_if_needed });
var loaded_files = std.StringHashMap(void).init(allocator); var loaded_files = std.StringHashMap(void).init(allocator);
const weight_map = index.object.get("weight_map").?.object; const weight_map = index.object.get("weight_map").?.object;
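The index JSON is now parsed straight off the buffered file reader rather than slurped into memory first. The same pattern in isolation, with a hypothetical helper and path:

```zig
const std = @import("std");

fn readConfig(allocator: std.mem.Allocator, path: []const u8) !std.json.Parsed(std.json.Value) {
    var file = try std.fs.cwd().openFile(path, .{});
    defer file.close();
    // 0.15: the caller supplies the read buffer explicitly.
    var buf: [4096]u8 = undefined;
    var file_reader = file.reader(&buf);
    var json_reader = std.json.Reader.init(allocator, &file_reader.interface);
    defer json_reader.deinit();
    // The caller owns the returned Parsed and must deinit it.
    return std.json.parseFromTokenSource(std.json.Value, allocator, &json_reader, .{});
}
```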
@ -62,23 +67,24 @@ fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.
} }
} }
fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void { fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.array_list.Managed(MemoryMappedFile), path: []const u8) !void {
const file = asynk.File.open(path, .{}) catch |err| { const file = asynk.File.open(path, .{}) catch |err| {
log.err("Failed to open {s}: {}", .{ path, err }); log.err("Failed to open {s}: {}", .{ path, err });
return err; return err;
}; };
errdefer file.close() catch unreachable; errdefer file.close() catch unreachable;
var r = file.reader(); var buffer: [16 * 1024]u8 = undefined;
var r = file.reader(&buffer);
const json_header_length: usize = @intCast(try r.readInt(u64, std.builtin.Endian.little)); const json_header_length: usize = @intCast(try r.interface.takeInt(u64, .little));
const json_data = try allocator.alloc(u8, json_header_length); const json_data = try allocator.alloc(u8, json_header_length);
const n = try r.readAll(json_data); try r.interface.readSliceAll(json_data);
if (n != json_header_length) {
log.err("Failed to read the full {} bytes of json header from file {s}", .{ n, path });
return error.CorruptedFile;
}
const metadata = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data[0..n], .{}); const metadata = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data, .{});
var buffer_file = try MemoryMappedFile.init(file); var buffer_file = try MemoryMappedFile.init(file);
errdefer buffer_file.deinit(); errdefer buffer_file.deinit();
buffer_file.data_offset = 8 + json_header_length; buffer_file.data_offset = 8 + json_header_length;
@ -105,7 +111,7 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.Array
const start: usize = @intCast(offset_field.array.items[0].integer); const start: usize = @intCast(offset_field.array.items[0].integer);
const end: usize = @intCast(offset_field.array.items[1].integer); const end: usize = @intCast(offset_field.array.items[1].integer);
const dtype = try stringToDtype(val.object.get("dtype").?.string); const dtype = try stringToDtype(val.object.get("dtype").?.string);
var dims: std.BoundedArray(i64, zml.Shape.MAX_RANK) = .{}; var dims: stdx.BoundedArray(i64, zml.Shape.MAX_RANK) = .{};
for (shape_field.items) |d| { for (shape_field.items) |d| {
dims.appendAssumeCapacity(d.integer); dims.appendAssumeCapacity(d.integer);
} }
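The header read above leans on the new reader primitives: `takeInt` replaces `readInt`, and `readSliceAll` already errors on a short read, which is why the manual length check could go. A sketch of that idiom under the same assumptions (little-endian u64 length prefix):

```zig
const std = @import("std");

fn readHeader(allocator: std.mem.Allocator, file: std.fs.File) ![]u8 {
    var buf: [4096]u8 = undefined;
    var r = file.reader(&buf);
    const len: usize = @intCast(try r.interface.takeInt(u64, .little));
    const header = try allocator.alloc(u8, len);
    errdefer allocator.free(header);
    // Fails with an error instead of returning a short count.
    try r.interface.readSliceAll(header);
    return header;
}
```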


@ -45,7 +45,7 @@ pub const Buffer = struct {
_shards: Shards, _shards: Shards,
pub const MAX_NUM_SHARDS: u8 = Platform.MAX_NUM_DEVICES; pub const MAX_NUM_SHARDS: u8 = Platform.MAX_NUM_DEVICES;
pub const Shards = std.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS); pub const Shards = stdx.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS);
pub const FromOptions = struct { pub const FromOptions = struct {
wait: bool = true, wait: bool = true,
@ -67,7 +67,7 @@ pub const Buffer = struct {
const n_partitions = platform.sharding().num_partitions; const n_partitions = platform.sharding().num_partitions;
const chunk_size = if (sharding_ax) |ax| cs: { const chunk_size = if (sharding_ax) |ax| cs: {
// This kind of sharding error should be detected earlier on. // This kind of sharding error should be detected earlier on.
stdx.debug.assert(@rem(host_buffer.dim(ax), n_partitions) == 0, "Buffer.from({}) expects the sharding axis {} to have a dimension divisible by the number of devices ({}).", .{ host_buffer, ax, n_partitions }); stdx.debug.assert(@rem(host_buffer.dim(ax), n_partitions) == 0, "Buffer.from({f}) expects the sharding axis {} to have a dimension divisible by the number of devices ({}).", .{ host_buffer, ax, n_partitions });
break :cs @divExact(host_buffer.dim(ax), n_partitions); break :cs @divExact(host_buffer.dim(ax), n_partitions);
} else 0; } else 0;
@ -201,7 +201,7 @@ pub const Buffer = struct {
const duration_ms = stdx.math.divFloat(f32, start.read(), std.time.ns_per_ms); const duration_ms = stdx.math.divFloat(f32, start.read(), std.time.ns_per_ms);
if (duration_ms > 100) { if (duration_ms > 100) {
const size_gb = stdx.math.divFloat(f32, shape_.byteSize(), 1024 * 1024 * 1024); const size_gb = stdx.math.divFloat(f32, shape_.byteSize(), 1024 * 1024 * 1024);
log.info("Wrote constant({_}) to device ({d:.2}Gb) in {d:.0}ms: {d:.2}Gb/s", .{ shape_, size_gb, duration_ms, size_gb / duration_ms * 1000 }); log.debug("Wrote constant({f}) to device ({d:.2}Gb) in {d:.0}ms: {d:.2}Gb/s", .{ shape_, size_gb, duration_ms, size_gb / duration_ms * 1000 });
} }
} }
@ -301,7 +301,7 @@ pub const Buffer = struct {
/// Fetches the content of the given buffer into a stack variable of the given type. /// Fetches the content of the given buffer into a stack variable of the given type.
pub fn getValue(self: Buffer, T: type) !T { pub fn getValue(self: Buffer, T: type) !T {
stdx.debug.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) }); stdx.debug.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {f} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) });
var res: T = undefined; var res: T = undefined;
stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res)); const maybe_event = try self._shards.get(0).toHostBuffer(self._api, std.mem.asBytes(&res));
@ -375,13 +375,9 @@ pub const Buffer = struct {
pub fn format( pub fn format(
self: Buffer, self: Buffer,
comptime fmt: []const u8,
options: std.fmt.FormatOptions,
writer: anytype, writer: anytype,
) !void { ) !void {
_ = fmt; try writer.print("Buffer({f})", .{self._shape});
_ = options;
try writer.print("Buffer({_})", .{self._shape});
} }
pub fn getMemory(self: Buffer) *const pjrt.Memory { pub fn getMemory(self: Buffer) *const pjrt.Memory {


@ -224,7 +224,7 @@ const CustomCall = struct {
try ffi.register(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback, .{}); try ffi.register(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback, .{});
} }
fn hostBufferCallback(call_frame: *pjrt.ffi.CallFrame) callconv(.C) ?*pjrt.ffi.Error { fn hostBufferCallback(call_frame: *pjrt.ffi.CallFrame) callconv(.c) ?*pjrt.ffi.Error {
if (call_frame.registeringHook()) return null; if (call_frame.registeringHook()) return null;
const callback_attr = call_frame.attrs.getByName(.scalar, "callback") orelse unreachable; const callback_attr = call_frame.attrs.getByName(.scalar, "callback") orelse unreachable;
@ -275,3 +275,69 @@ fn hostBufferFromPinnedBuffer(buffer_desc: *const pjrt.ffi.Buffer) HostBuffer {
buffer_desc.data[0..buffer_shape.byteSize()], buffer_desc.data[0..buffer_shape.byteSize()],
); );
} }
pub const cuda = struct {
pub var streamSynchronize: StreamSynchronize = @ptrFromInt(0xdeadc00da00);
pub var cuLaunchHostFunc: CuLaunchHostFunc = @ptrFromInt(0xdeadc00da00);
var _memcpyAsync: MemcpyAsync = @ptrFromInt(0xdeadc00da00);
var _memcpyBlocking: MemcpyBlocking = @ptrFromInt(0xdeadc00da00);
pub const MemcpyKind = enum(c_int) {
host_to_host = 0,
host_to_device = 1,
device_to_host = 2,
device_to_device = 3,
inferred = 4,
};
const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: ?*anyopaque) callconv(.c) c_int;
const MemcpyBlocking = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind) callconv(.c) c_int;
const StreamSynchronize = *const fn (stream: *anyopaque) callconv(.c) c_int;
const CuLaunchHostFunc = *const fn (stream: *anyopaque, host_func: *const fn (user_data: *const anyopaque) callconv(.c) void, user_data: *const anyopaque) callconv(.c) c_int;
pub fn init() void {
var cudart = std.DynLib.open("libcudart.so.12") catch {
log.err("cudart not found, callback will segfault", .{});
return;
};
defer cudart.close();
_memcpyAsync = cudart.lookup(MemcpyAsync, "cudaMemcpyAsync") orelse {
@panic("cudaMemcpyAsync not found");
};
_memcpyBlocking = cudart.lookup(MemcpyBlocking, "cudaMemcpy") orelse {
@panic("cudaMemcpy not found");
};
streamSynchronize = cudart.lookup(StreamSynchronize, "cudaStreamSynchronize") orelse {
@panic("cudaStreamSynchronize not found");
};
cuLaunchHostFunc = cudart.lookup(CuLaunchHostFunc, "cudaLaunchHostFunc") orelse {
@panic("cudaLaunchHostFunc not found");
};
}
pub fn memcpyToHostBlocking(dst: []u8, src: *const anyopaque) void {
const err = _memcpyBlocking(dst.ptr, src, dst.len, .device_to_host);
check(err);
}
pub fn memcpyToDeviceBlocking(dst: *anyopaque, src: []const u8) void {
const err = _memcpyBlocking(dst, src.ptr, src.len, .host_to_device);
check(err);
}
pub fn memcpyToDeviceAsync(dst: *anyopaque, src: []const u8, stream: ?*anyopaque) void {
const err = _memcpyAsync(dst, src.ptr, src.len, .host_to_device, stream);
check(err);
}
pub fn memcpyToHostAsync(dst: []u8, src: *const anyopaque, stream: ?*anyopaque) void {
const err = _memcpyAsync(dst.ptr, src, dst.len, .device_to_host, stream);
check(err);
}
pub fn check(err: c_int) void {
if (err == 0) return;
stdx.debug.panic("CUDA error: {d}", .{err});
}
};
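The loader resolves cudart symbols with `std.DynLib`; closing the handle right after lookup is tolerable there only because the PJRT CUDA plugin already keeps `libcudart` loaded. A standalone sketch (library and symbol names are illustrative) keeps the handle open instead:

```zig
const std = @import("std");

const AbsFn = *const fn (c_int) callconv(.c) c_int;

var abs_fn: AbsFn = undefined;

fn initAbs() !void {
    // Deliberately never closed: looked-up pointers must not outlive dlclose.
    var lib = try std.DynLib.open("libc.so.6");
    abs_fn = lib.lookup(AbsFn, "abs") orelse return error.SymbolNotFound;
}
```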


@ -394,8 +394,8 @@ fn fillBuffers(v: anytype, shapes: []const Shape, buffers: []const [*]*pjrt.Buff
fn cb(ctx: *LocalContext, buffer: *const Buffer) void { fn cb(ctx: *LocalContext, buffer: *const Buffer) void {
// stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index }); // stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index });
const model_sharding = ctx.buffers.len; const model_sharding = ctx.buffers.len;
stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len }); stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {d}-sharded tensor into a {d}-sharded model", .{ buffer._shards.len, ctx.buffers.len });
stdx.debug.assert(ctx.shapes[ctx.index].eql(buffer.shape()), "Executable expected argument {} to have shape {}, got {}", .{ ctx.index, ctx.shapes[ctx.index], buffer.shape() }); stdx.debug.assert(ctx.shapes[ctx.index].eql(buffer.shape()), "Executable expected argument {} to have shape {f}, got {f}", .{ ctx.index, ctx.shapes[ctx.index], buffer.shape() });
for (buffer._shards.constSlice(), 0..) |shard, d| { for (buffer._shards.constSlice(), 0..) |shard, d| {
ctx.buffers[d][ctx.index] = shard; ctx.buffers[d][ctx.index] = shard;
} }


@ -87,17 +87,10 @@ fn FloatHelpers(Float: type) type {
return std.math.maxInt(std.meta.Int(.unsigned, exponent_bits - 1)); return std.math.maxInt(std.meta.Int(.unsigned, exponent_bits - 1));
} }
pub fn format( pub fn formatNumber(x: Float, writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
float: Float, switch (n.mode) {
comptime fmt: []const u8, .binary, .octal, .hex => try writer.print("{{ .sign={}, .exp={}, .mantissa={} }}", .{ x.sign, x.exponent, x.mantissa }),
options: std.fmt.FormatOptions, else => try writer.printFloat(x.toF32(), n),
writer: anytype,
) !void {
_ = options;
if (fmt.len == 1 and fmt[0] == '_') {
try writer.print("{{ .sign={}, .exp={}, .mantissa={} }}", .{ float.sign, float.exponent, float.mantissa });
} else {
try writer.print("{" ++ fmt ++ "}", .{float.toF32()});
} }
} }
}; };
@ -113,7 +106,7 @@ pub const Float32 = packed struct(u32) {
pub const neg = Helpers.neg; pub const neg = Helpers.neg;
pub const fromF32 = Helpers.fromF32; pub const fromF32 = Helpers.fromF32;
pub const toF32 = Helpers.toF32; pub const toF32 = Helpers.toF32;
pub const format = Helpers.format; pub const formatNumber = Helpers.formatNumber;
}; };
const f32_exp_bias = FloatHelpers(Float32).expBias(); const f32_exp_bias = FloatHelpers(Float32).expBias();
@ -128,7 +121,7 @@ pub const Float64 = packed struct(u64) {
pub const neg = Helpers.neg; pub const neg = Helpers.neg;
pub const fromF32 = Helpers.fromF32; pub const fromF32 = Helpers.fromF32;
pub const toF32 = Helpers.toF32; pub const toF32 = Helpers.toF32;
pub const format = Helpers.format; pub const formatNumber = Helpers.formatNumber;
}; };
pub const Float8E4M3B11FNUZ = packed struct(u8) { pub const Float8E4M3B11FNUZ = packed struct(u8) {
@ -151,7 +144,7 @@ pub const Float8E4M3B11FNUZ = packed struct(u8) {
pub const neg = Helpers.neg; pub const neg = Helpers.neg;
pub const fromF32 = Helpers.fromF32; pub const fromF32 = Helpers.fromF32;
pub const toF32 = Helpers.toF32; pub const toF32 = Helpers.toF32;
pub const format = Helpers.format; pub const formatNumber = Helpers.formatNumber;
}; };
pub const Float8E4M3FN = packed struct(u8) { pub const Float8E4M3FN = packed struct(u8) {
@ -169,7 +162,7 @@ pub const Float8E4M3FN = packed struct(u8) {
pub const neg = Helpers.neg; pub const neg = Helpers.neg;
pub const fromF32 = Helpers.fromF32; pub const fromF32 = Helpers.fromF32;
pub const toF32 = Helpers.toF32; pub const toF32 = Helpers.toF32;
pub const format = Helpers.format; pub const formatNumber = Helpers.formatNumber;
}; };
pub const Float8E4M3FNUZ = packed struct(u8) { pub const Float8E4M3FNUZ = packed struct(u8) {
@ -192,7 +185,7 @@ pub const Float8E4M3FNUZ = packed struct(u8) {
pub const neg = Helpers.neg; pub const neg = Helpers.neg;
pub const fromF32 = Helpers.fromF32; pub const fromF32 = Helpers.fromF32;
pub const toF32 = Helpers.toF32; pub const toF32 = Helpers.toF32;
pub const format = Helpers.format; pub const formatNumber = Helpers.formatNumber;
}; };
test "Float8E4" { test "Float8E4" {
@ -247,7 +240,7 @@ pub const Float8E5M2 = packed struct(u8) {
pub const neg = Helpers.neg; pub const neg = Helpers.neg;
pub const fromF32 = Helpers.fromF32; pub const fromF32 = Helpers.fromF32;
pub const toF32 = Helpers.toF32; pub const toF32 = Helpers.toF32;
pub const format = Helpers.format; pub const formatNumber = Helpers.formatNumber;
}; };
pub const Float8E5M2FNUZ = packed struct(u8) { pub const Float8E5M2FNUZ = packed struct(u8) {
@ -266,7 +259,7 @@ pub const Float8E5M2FNUZ = packed struct(u8) {
pub const neg = Helpers.neg; pub const neg = Helpers.neg;
pub const fromF32 = Helpers.fromF32; pub const fromF32 = Helpers.fromF32;
pub const toF32 = Helpers.toF32; pub const toF32 = Helpers.toF32;
pub const format = Helpers.format; pub const formatNumber = Helpers.formatNumber;
}; };
test "Float8E5" { test "Float8E5" {
@ -322,7 +315,7 @@ pub const BFloat16 = packed struct(u16) {
const Helpers = FloatHelpers(@This()); const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero; pub const zero = Helpers.zero;
pub const neg = Helpers.neg; pub const neg = Helpers.neg;
pub const format = Helpers.format; pub const formatNumber = Helpers.formatNumber;
}; };
test BFloat16 { test BFloat16 {


@ -31,7 +31,7 @@ pub const HostBuffer = struct {
return .{ return .{
._shape = sh, ._shape = sh,
._strides = sh.computeStrides().buffer, ._strides = sh.computeStrides().buffer,
._data = (try allocator.alignedAlloc(u8, 64, sh.byteSize())).ptr, ._data = (try allocator.alignedAlloc(u8, .@"64", sh.byteSize())).ptr,
._memory = .{ .managed = .@"64" }, ._memory = .{ .managed = .@"64" },
}; };
} }
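Alignment is now an enum, `std.mem.Alignment`, rather than a bare integer: `.@"64"` is the literal form and `.of(T)` mirrors `@alignOf(T)`. Both spellings from this commit in one sketch:

```zig
const std = @import("std");

test "std.mem.Alignment spellings" {
    const gpa = std.testing.allocator;
    // alignedAlloc takes an Alignment value in 0.15.
    const data = try gpa.alignedAlloc(u8, .@"64", 128);
    defer gpa.free(data);
    try std.testing.expect(std.mem.isAligned(@intFromPtr(data.ptr), 64));
    try std.testing.expectEqual(@alignOf(u64), std.mem.Alignment.of(u64).toByteUnits());
}
```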
@ -170,8 +170,8 @@ pub const HostBuffer = struct {
/// Strided buffers can't use this method. /// Strided buffers can't use this method.
pub fn items(self: HostBuffer, comptime T: type) []const T { pub fn items(self: HostBuffer, comptime T: type) []const T {
// TODO we should allow interpreting the output as @Vector(8, f32) when the tensor is f32. // TODO we should allow interpreting the output as @Vector(8, f32) when the tensor is f32.
stdx.debug.assert(DataType.fromZigType(T) == self.dtype(), "Can't reinterpret {} as {s}", .{ self, @typeName(T) }); stdx.debug.assert(DataType.fromZigType(T) == self.dtype(), "Can't reinterpret {f} as {s}", .{ self, @typeName(T) });
stdx.debug.assert(self.isContiguous(), "{} isn't contiguous, can't interpret as []const u8", .{self}); stdx.debug.assert(self.isContiguous(), "{f} isn't contiguous, can't interpret as []const u8", .{self});
const ptr: [*]const T = @alignCast(@ptrCast(self._data)); const ptr: [*]const T = @alignCast(@ptrCast(self._data));
return ptr[0..self._shape.count()]; return ptr[0..self._shape.count()];
} }
@ -181,7 +181,7 @@ pub const HostBuffer = struct {
} }
pub fn bytes(self: HostBuffer) []const u8 { pub fn bytes(self: HostBuffer) []const u8 {
stdx.debug.assert(self.isContiguous(), "{} isn't contiguous, can't interpret as []const u8", .{self}); stdx.debug.assert(self.isContiguous(), "{f} isn't contiguous, can't interpret as []const u8", .{self});
return self._data[0..self._shape.byteSize()]; return self._data[0..self._shape.byteSize()];
} }
@ -233,7 +233,7 @@ pub const HostBuffer = struct {
} }
pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer { pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer {
stdx.debug.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {}", .{self}); stdx.debug.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {f}", .{self});
var res = self; var res = self;
res._shape = self._shape.reshape(shape_); res._shape = self._shape.reshape(shape_);
res._strides = res._shape.computeStrides().buffer; res._strides = res._shape.computeStrides().buffer;
@ -252,9 +252,9 @@ pub const HostBuffer = struct {
const start: i64 = if (s.start < 0) s.start + d else s.start; const start: i64 = if (s.start < 0) s.start + d else s.start;
var end = s.end orelse d; var end = s.end orelse d;
if (end < 0) end += d; if (end < 0) end += d;
stdx.debug.assert(start >= 0 and start < d, "slice1d({}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, s }); stdx.debug.assert(start >= 0 and start < d, "slice1d({f}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, s });
stdx.debug.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, s }); stdx.debug.assert(end >= 1 and end <= d, "slice1d({f}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, s });
stdx.debug.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({}), got: {}", .{ self, ax, start, end, s }); stdx.debug.assert(start < end, "slice1d({f}, {}) expects the slice start ({}) to be smaller than the end ({}), got: {}", .{ self, ax, start, end, s });
const offset: usize = @intCast(start * self._strides[ax]); const offset: usize = @intCast(start * self._strides[ax]);
const new_shape = self.shape().set(ax, end - start); const new_shape = self.shape().set(ax, end - start);
@ -308,9 +308,9 @@ pub const HostBuffer = struct {
pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer { pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer {
const ax = self._shape.axis(axis_); const ax = self._shape.axis(axis_);
stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {}", .{ ax, self }); stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {f}", .{ ax, self });
var strd: std.BoundedArray(i64, Shape.MAX_RANK) = .{ .buffer = self._strides, .len = self.rank() }; var strd: stdx.BoundedArray(i64, Shape.MAX_RANK) = .{ .buffer = self._strides, .len = self.rank() };
_ = strd.orderedRemove(ax); _ = strd.orderedRemove(ax);
return .{ return .{
@ -323,16 +323,11 @@ pub const HostBuffer = struct {
pub fn format( pub fn format(
self: HostBuffer, self: HostBuffer,
comptime fmt: []const u8,
options: std.fmt.FormatOptions,
writer: anytype, writer: anytype,
) !void { ) !void {
_ = options; // TODO debug option
if (std.mem.eql(u8, fmt, "v")) { // try writer.print("HostBuffer(.{f})@0x{x}", .{ self._shape, @intFromPtr(self._data) });
try writer.print("HostBuffer(.{_})@0x{x}", .{ self._shape, @intFromPtr(self._data) }); try writer.print("HostBuffer(.{f})", .{self._shape});
} else {
try writer.print("HostBuffer(.{_})", .{self._shape});
}
} }
/// Formatter for a HostBuffer that also print the values not just the shape. /// Formatter for a HostBuffer that also print the values not just the shape.
@ -344,21 +339,23 @@ pub const HostBuffer = struct {
pub const PrettyPrinter = struct { pub const PrettyPrinter = struct {
x: HostBuffer, x: HostBuffer,
pub fn format(self: PrettyPrinter, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { // TODO(0.15.0) revisit pretty printer
pub fn format(self: PrettyPrinter, writer: anytype) !void {
const fmt_: stdx.fmt.Fmt = switch (self.x.dtype().class()) { const fmt_: stdx.fmt.Fmt = switch (self.x.dtype().class()) {
.integer => .parse(i32, fmt), .integer => .parse(i32, "d"),
.float => .parse(f32, fmt), .float => .parse(f32, "d"),
else => .parse(void, fmt), else => .parse(void, ""),
}; };
const options: std.fmt.FormatOptions = .{};
try prettyPrint(self.x, writer, .{ .fmt = fmt_, .options = options }); try prettyPrint(self.x, writer, .{ .fmt = fmt_, .options = options });
} }
}; };
pub fn prettyPrint(self: HostBuffer, writer: anytype, options: stdx.fmt.FullFormatOptions) !void { pub fn prettyPrint(self: HostBuffer, writer: *std.Io.Writer, options: stdx.fmt.FullFormatOptions) !void {
return self.prettyPrintIndented(writer, 4, 0, options); return self.prettyPrintIndented(writer, 4, 0, options);
} }
fn prettyPrintIndented(self: HostBuffer, writer: anytype, num_rows: u8, indent_level: u8, options: stdx.fmt.FullFormatOptions) !void { fn prettyPrintIndented(self: HostBuffer, writer: *std.Io.Writer, num_rows: u8, indent_level: u8, options: stdx.fmt.FullFormatOptions) !void {
if (self.rank() == 0) { if (self.rank() == 0) {
// Special case input tensor is a scalar // Special case input tensor is a scalar
return switch (self.dtype()) { return switch (self.dtype()) {
@ -376,7 +373,7 @@ pub const HostBuffer = struct {
if (self.rank() == 1) { if (self.rank() == 1) {
// Print a contiguous slice of items from the buffer in one line. // Print a contiguous slice of items from the buffer in one line.
// The number of items printed is controlled by the user through format syntax. // The number of items printed is controlled by the user through format syntax.
try writer.writeByteNTimes(' ', indent_level); try writer.splatByteAll(' ', indent_level);
switch (self.dtype()) { switch (self.dtype()) {
inline else => |dt| { inline else => |dt| {
const values = self.items(dt.toZigType()); const values = self.items(dt.toZigType());
@ -391,10 +388,10 @@ pub const HostBuffer = struct {
return; return;
} }
// TODO: consider removing the \n if dim is 1 for this axis. // TODO: consider removing the \n if dim is 1 for this axis.
try writer.writeByteNTimes(' ', indent_level); try writer.splatByteAll(' ', indent_level);
_ = try writer.write("{\n"); _ = try writer.write("{\n");
defer { defer {
writer.writeByteNTimes(' ', indent_level) catch {}; writer.splatByteAll(' ', indent_level) catch {};
_ = writer.write("},\n") catch {}; _ = writer.write("},\n") catch {};
} }
@ -409,7 +406,7 @@ pub const HostBuffer = struct {
if (n < num_rows) return; if (n < num_rows) return;
// Skip middle rows // Skip middle rows
if (n > 2 * num_rows) { if (n > 2 * num_rows) {
try writer.writeByteNTimes(' ', indent_level + 2); try writer.splatByteAll(' ', indent_level + 2);
_ = try writer.write("...\n"); _ = try writer.write("...\n");
} }
// Write last rows // Write last rows
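`writeByteNTimes` is gone from the writer API; the pretty-printer now calls `splatByteAll` on the concrete writer. The indentation idiom in isolation (the `.fixed` test harness is our assumption):

```zig
const std = @import("std");

fn writeIndented(w: *std.Io.Writer, indent: usize, line: []const u8) std.Io.Writer.Error!void {
    // splatByteAll(byte, n) replaces writeByteNTimes.
    try w.splatByteAll(' ', indent);
    try w.writeAll(line);
    try w.writeByte('\n');
}

test writeIndented {
    var buf: [32]u8 = undefined;
    var w: std.Io.Writer = .fixed(&buf);
    try writeIndented(&w, 4, "{");
    try std.testing.expectEqualStrings("    {\n", w.buffered());
}
```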


@ -358,10 +358,11 @@ pub fn MapRestrict(From: type, To: type) type {
const fields = union_info.fields; const fields = union_info.fields;
var union_fields: [fields.len]std.builtin.Type.UnionField = undefined; var union_fields: [fields.len]std.builtin.Type.UnionField = undefined;
for (0.., fields) |i, field| { for (0.., fields) |i, field| {
const FT = map(field.type);
union_fields[i] = .{ union_fields[i] = .{
.name = field.name, .name = field.name,
.type = map(field.type), .type = FT,
.alignment = 0, .alignment = @alignOf(FT),
}; };
} }
return @Type(.{ .@"union" = .{ return @Type(.{ .@"union" = .{
@ -453,7 +454,7 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
else => {}, else => {},
} }
// Handle std.BoundedArray that contains uninitialized data. // Handle stdx.BoundedArray that contains uninitialized data.
if (@typeInfo(Child) == .@"struct" and @hasDecl(Child, "constSlice") and @hasDecl(Child, "slice")) { if (@typeInfo(Child) == .@"struct" and @hasDecl(Child, "constSlice") and @hasDecl(Child, "slice")) {
return visit(cb, ctx, if (mutating_cb) v.slice() else v.constSlice()); return visit(cb, ctx, if (mutating_cb) v.slice() else v.constSlice());
} }
@ -511,7 +512,7 @@ test visit {
const NestedAttrOptional = struct { nested: ?Attr }; const NestedAttrOptional = struct { nested: ?Attr };
const SimpleStruct = struct { prop: Attr }; const SimpleStruct = struct { prop: Attr };
const MultipleTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: ?Attr }; const MultipleTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: ?Attr };
const NestedTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: NestedAttr, prop4: NestedAttrOptional, prop5: std.BoundedArray(Attr, 8) }; const NestedTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: NestedAttr, prop4: NestedAttrOptional, prop5: stdx.BoundedArray(Attr, 8) };
const LocalContext = struct { result: usize }; const LocalContext = struct { result: usize };
@ -565,7 +566,7 @@ test visit {
} }
{ {
var context: LocalContext = .{ .result = 0 }; var context: LocalContext = .{ .result = 0 };
const prop5: std.BoundedArray(Attr, 8) = .{ const prop5: stdx.BoundedArray(Attr, 8) = .{
.buffer = @splat(.{ .data = 4 }), .buffer = @splat(.{ .data = 4 }),
.len = 2, .len = 2,
}; };
@ -677,11 +678,11 @@ test zip {
/// Given a func(X) -> Y or a func(Ctx, X) -> Y, /// Given a func(X) -> Y or a func(Ctx, X) -> Y,
/// finds all X in the given object, and write the result of func(X) into an arraylist. /// finds all X in the given object, and write the result of func(X) into an arraylist.
pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT), obj: anytype) error{OutOfMemory}!void { pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.array_list.Managed(stdx.meta.FnSignature(func, null).ReturnT), obj: anytype) error{OutOfMemory}!void {
stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).@"fn".params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)}); stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).@"fn".params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)});
const LocalContext = struct { const LocalContext = struct {
func_ctx: _CollectCtx(func), func_ctx: _CollectCtx(func),
out: *std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT), out: *std.array_list.Managed(stdx.meta.FnSignature(func, null).ReturnT),
oom: bool = false, oom: bool = false,
}; };
var context = LocalContext{ .func_ctx = func_ctx, .out = out }; var context = LocalContext{ .func_ctx = func_ctx, .out = out };
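`collect` changes its parameter type because 0.15 makes `std.ArrayList` unmanaged and moves the allocator-carrying variant to `std.array_list.Managed`. A minimal sketch of both flavors:

```zig
const std = @import("std");

test "managed and unmanaged lists in 0.15" {
    const gpa = std.testing.allocator;

    // Managed: stores its allocator, as collect()'s out parameter expects.
    var managed = std.array_list.Managed(i64).init(gpa);
    defer managed.deinit();
    try managed.append(1);

    // Unmanaged (the new std.ArrayList): allocator passed per call.
    var unmanaged: std.ArrayList(i64) = .empty;
    defer unmanaged.deinit(gpa);
    try unmanaged.append(gpa, 2);
}
```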


@@ -51,7 +51,7 @@ pub const CompilationContext = struct {
_module: mlir.Module,
-_blocks: std.BoundedArray(TaggedBlock, 64) = .{},
+_blocks: stdx.BoundedArray(TaggedBlock, 64) = .{},
_fn_cache: FnCache = .{},
_block_args: TensorToBlockArg = .{},
@@ -63,7 +63,7 @@ pub const CompilationContext = struct {
const TaggedBlock = struct { mlir.Block, mlir.Block.RecursiveOpts };
const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation });
-const AttributeList = std.BoundedArray(mlir.NamedAttribute, 3);
+const AttributeList = stdx.BoundedArray(mlir.NamedAttribute, 3);
pub fn init(allocator_: std.mem.Allocator, full_name: []const u8, platform: Platform) !CompilationContext {
const mlir_registry = mlir.Registry.init() catch unreachable;
@@ -185,7 +185,9 @@ pub const CompilationContext = struct {
// Write the mlir to a file. All errors are discarded, since this is for debugging only.
const mlir_name = "module.mlir";
if (cache_dir.createFile(mlir_name, .{ .truncate = true })) |file| {
-module.op().print(file.writer(), .{ .debug_info = true, .debug_info_pretty_form = false });
+var write_buf: [4096]u8 = undefined;
+var writer = file.writer(&write_buf);
+module.op().print(&writer.interface, .{ .debug_info = true, .debug_info_pretty_form = false });
log.info("Wrote MLIR to {s}/{s}", .{ module_dir.?, mlir_name });
} else |_| {
log.warn("Failed to open {s}", .{mlir_name});
@@ -219,7 +221,7 @@ pub const CompilationContext = struct {
};
log.debug("******** ZML generated MLIR ********", .{});
-log.debug("{}", .{module.op().mlirFormatter(.{})});
+log.debug("{f}", .{module.op().mlirFormatter(.{})});
if (timer) |*t| {
const time_ms = @divFloor(t.lap(), std.time.ns_per_ms);
@@ -339,7 +341,7 @@ pub const CompilationContext = struct {
const locations = try arena.alloc(mlir.Location, tensor_count);
@memset(locations, mlir.Location.unknown(mlir_ctx));
-var input_shapes = try std.ArrayList(Shape).initCapacity(res_allocator, tensor_count);
+var input_shapes: std.array_list.Managed(Shape) = try .initCapacity(res_allocator, tensor_count);
meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable;
stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});
@@ -416,7 +418,7 @@ pub const CompilationContext = struct {
defer self._tracer.frameEnd(canonicalize_frame, "emitMlir.canonicalize");
self._mlir_canonicalizer.runOnOp(mlir_fn) catch |err| switch (err) {
error.InvalidMlir => {
-log.err("Failed to canonicalize invalid mlir: {}", .{mlir_fn.mlirFormatter(.{})});
+log.err("Failed to canonicalize invalid mlir: {f}", .{mlir_fn.mlirFormatter(.{})});
// user errors should have triggered a panic before we reach this.
@panic("ZML generated invalid mlir. Please open a bug report");
},
@@ -464,7 +466,7 @@ pub const CompilationContext = struct {
// This will break the day we writer another attribute before donation.
// When the time come, do a more fancy lookup here to check if an argument
// is donated twice.
-stdx.debug.assert(attributes[a].len == 0, "Donation error ! Argument {} has been donated twice ! To {} and to {}", .{ a, index, attributes[a].buffer[0] });
+stdx.debug.assert(attributes[a].len == 0, "Donation error ! Argument {d} has been donated twice ! To {d} and to {any}", .{ a, index, attributes[a].buffer[0] });
attributes[a].appendAssumeCapacity(.named(ctx, "tf.aliasing_output", .int(ctx, .i32, @intCast(index))));
// log.debug("attribute: {}", .{attributes[a].constSlice()});
},
@@ -504,9 +506,9 @@ pub const CompilationContext = struct {
var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } };
const f = try comp.emitMlir(Local._fwd, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main });
-var mlir_bytecode = std.ArrayList(u8).init(std.testing.allocator);
+var mlir_bytecode = std.array_list.Managed(u8).init(std.testing.allocator);
defer mlir_bytecode.deinit();
-try mlir_bytecode.writer().print("{}", .{f.mlir_fn.mlirFormatter(.{})});
+try mlir_bytecode.writer().print("{f}", .{f.mlir_fn.mlirFormatter(.{})});
// Check that the `x` input argument gives its buffer to the result tensor.
// `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`.
@@ -545,7 +547,7 @@ pub const CompilationContext = struct {
pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute {
const ctx = self.mlirCtx();
const num_partitions = self.numPartitions();
-var sharding_str: std.BoundedArray(u8, 128) = .{};
+var sharding_str: stdx.BoundedArray(u8, 128) = .{};
writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable;
return mlir.Attribute.string(ctx, sharding_str.constSlice());
}
@@ -622,7 +624,7 @@ pub const CompilationContext = struct {
const full_name: [:0]const u8 = if (std.mem.eql(u8, "main", func_name))
try self.allocator().dupeZ(u8, func_name)
else
-try std.fmt.allocPrintZ(self.allocator(), "{s}_{x}", .{ func_name, key.input_hash });
+try std.fmt.allocPrintSentinel(self.allocator(), "{s}_{x}", .{ func_name, key.input_hash }, 0);
var arg_id: u16 = 0;
var tensor_args: @TypeOf(args) = args;
@@ -702,7 +704,7 @@ pub const CompilationContext = struct {
const res = ctx.self._block_args.getOrPutAssumeCapacity(tensor._id);
if (res.found_existing) {
-stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {} and {} at index {} ({}).", .{ res.value_ptr.*[0], tensor, ctx.index, tensor._id });
+stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {f} and {f} at index {} ({}).", .{ res.value_ptr.*[0], tensor, ctx.index, tensor._id });
} else {
res.value_ptr.* = .{ arg_value, .{ .arg = @intCast(ctx.index) } };
}
@@ -777,7 +779,7 @@ pub const CompilationContext = struct {
.buffer_id, .arg_id => if (self._block_args.get(tensor._id)) |res|
.{ res[0], res[1] }
else {
-log.err("Found unknown tensor id {}({})", .{ tensor, tensor._id });
+log.err("Found unknown tensor id {f}({})", .{ tensor, tensor._id });
@panic("Found unknown tensor id");
},
.mlir => |v| .{ v, tensor._donation },

View File

@@ -40,8 +40,8 @@ pub const TokenEmbedding = struct {
weight: Tensor,
pub fn forward(self: TokenEmbedding, idx: Tensor) Tensor {
-stdx.debug.assert(idx.dtype().isInteger(), "TokenEmbedding expects an integer input, received: {}", .{idx});
-stdx.debug.assert(self.weight.rank() == 2, "TokenEmbedding expects it's weight Tensor to be a 2D matrix, got {}", .{self.weight});
+stdx.debug.assert(idx.dtype().isInteger(), "TokenEmbedding expects an integer input, received: {f}", .{idx});
+stdx.debug.assert(self.weight.rank() == 2, "TokenEmbedding expects it's weight Tensor to be a 2D matrix, got {f}", .{self.weight});
return self.weight.gatherValues(0, idx, .{});
}
};
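The recurring `{}` to `{f}` edits follow Zig 0.15's stricter formatting rules: a type with a custom `format` method must now be printed with `{f}`, and the method no longer receives a format string or `std.fmt.FormatOptions`. A minimal sketch under those rules (the `Point` type is illustrative):

const std = @import("std");

const Point = struct {
    x: i32,
    y: i32,
    // 0.15 signature: just the value and a *std.Io.Writer.
    pub fn format(self: Point, writer: *std.Io.Writer) std.Io.Writer.Error!void {
        try writer.print("({d}, {d})", .{ self.x, self.y });
    }
};

test "{f} invokes the custom format method" {
    var buf: [32]u8 = undefined;
    // Plain `{}` on a type with a `format` decl is a compile error in 0.15.
    const s = try std.fmt.bufPrint(&buf, "{f}", .{Point{ .x = 1, .y = 2 }});
    try std.testing.expectEqualStrings("(1, 2)", s);
}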
@@ -204,13 +204,13 @@ pub const RopeOpts = struct {
/// - pos_idx: optional tensor which indicates which positions are needed.
/// When not set `rope` return all positions from 0 to x.dim(.s) which is the max seq len.
pub fn rope(x: Tensor, pos_idx: ?Tensor, opts: RopeOpts) Tensor {
-stdx.debug.assert(@mod(x.dim(.hd), 2) == 0, "rope expects a even head dim (.hd), got {}", .{x});
+stdx.debug.assert(@mod(x.dim(.hd), 2) == 0, "rope expects a even head dim (.hd), got {f}", .{x});
const idx = if (pos_idx) |idx| blk: {
-stdx.debug.assert(x.shape().hasTags(.{.hd}), "rope expects x argument to have .hd axes got: rope(x={}, idx={})", .{ x, idx });
+stdx.debug.assert(x.shape().hasTags(.{.hd}), "rope expects x argument to have .hd axes got: rope(x={f}, idx={f})", .{ x, idx });
break :blk idx;
} else blk: {
-stdx.debug.assert(x.shape().hasTags(.{ .s, .hd }), "rope expects x argument to have both .s and .hd axes got: rope(x={})", .{x});
+stdx.debug.assert(x.shape().hasTags(.{ .s, .hd }), "rope expects x argument to have both .s and .hd axes got: rope(x={f})", .{x});
break :blk Tensor.arange(.{ .end = x.dim(.s) }, .f32).withTags(.{.s});
};
const x_real, const x_imag = splitRealImg(x, opts.layout);
@@ -273,7 +273,7 @@ fn _invFreq(opts: RopeOpts, inv_freq: []f32) void {
switch (opts.scaling) {
.default => {},
.custom => {
-stdx.debug.assert(opts.scaling.custom.len == N, "rope expected custom inv_freq to match half head dimension {}, got {}", .{ N, opts.scaling.custom.len });
+stdx.debug.assert(opts.scaling.custom.len == N, "rope expected custom inv_freq to match half head dimension {d}, got {d}", .{ N, opts.scaling.custom.len });
@memcpy(inv_freq, opts.scaling.custom);
},
.llama3 => |s| {
@@ -318,7 +318,7 @@ test invFreq {
var inv_freq: @TypeOf(llama_freq) = undefined;
_invFreq(llama_conf, &inv_freq);
for (llama_freq, inv_freq, 0..) |expected, actual, i| {
-errdefer log.err("Mismatch at position {d}.\nExpected: {d}\nActual: {d}", .{ i, llama_freq, inv_freq });
+errdefer log.err("Mismatch at position {d}.\nExpected: {any}\nActual: {any}", .{ i, llama_freq, inv_freq });
try std.testing.expectApproxEqRel(expected, actual, 1e-5);
}
}
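Alongside `{f}`, the diff switches aggregate arguments (arrays, slices, BoundedArrays) from `{d}` or the old `{_}` to `{any}`, since 0.15 no longer lets numeric specifiers format whole aggregates. A small illustrative sketch, assuming the default `{ 1, 2, 3 }` rendering for slices:

const std = @import("std");

test "{any} for aggregates, {d} for single numbers" {
    var buf: [64]u8 = undefined;
    const xs: []const i32 = &.{ 1, 2, 3 };
    const s = try std.fmt.bufPrint(&buf, "{any} has {d} items", .{ xs, xs.len });
    try std.testing.expect(std.mem.startsWith(u8, s, "{ 1, 2, 3 }"));
}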
@@ -462,7 +462,7 @@ pub fn upsample(
) Tensor {
// TODO(james): make `nearest` compatible with resizeBilinear and resizeBicubic, and wrap them here.
// resize* have API which are more explicit, this assume you want to scale the N-2 last axes.
-stdx.debug.assert(3 <= input.rank() and input.rank() <= 5, "upsample is only implemented for (3,4,5)-D tensors, received {}", .{input});
+stdx.debug.assert(3 <= input.rank() and input.rank() <= 5, "upsample is only implemented for (3,4,5)-D tensors, received {f}", .{input});
stdx.debug.assert(opts.scale_factor.len == 1 or opts.scale_factor.len == input.rank() - 2, "scale factors", .{});
return switch (opts.mode) {
.nearest => {
@@ -791,7 +791,7 @@ pub fn causalAttnMask(
attn_window_len: ?u32,
) Tensor {
const attn_shape = Shape.init(attn_shape_, dtype);
-stdx.debug.assert(attn_shape.rank() == 2, "causalAttnMask({}) shape need to be exactly 2 axes", .{attn_shape});
+stdx.debug.assert(attn_shape.rank() == 2, "causalAttnMask({f}) shape need to be exactly 2 axes", .{attn_shape});
const qlen = attn_shape.dim(-2);
const q_idx = Tensor.iota(attn_shape, -2);
const klen = attn_shape.dim(-1);
@@ -843,7 +843,7 @@ pub const SdpaOpts = struct {
pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
var q, var k, var v = .{ q_, k_, v_ };
-const err_template = "sdpa(q: {}, k: {}, v: {}, attn: {?}) is invalid ! ";
+const err_template = "sdpa(q: {f}, k: {f}, v: {f}, attn: {?f}) is invalid ! ";
const err_args = .{ q, k, v, opts.attn_mask };
stdx.debug.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ "q is missing tags {{.h, .q, .hd}}", err_args);
stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args);
@@ -909,8 +909,8 @@ const SdpaMemEfficient = struct {
chunking: SdpaChunks,
fn forward(self: SdpaMemEfficient) Tensor {
-stdx.debug.assert(@mod(self.q.dim(.q), self.chunking.q_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({}, {})", .{ self.q, self.chunking });
-stdx.debug.assert(@mod(self.k.dim(.k), self.chunking.k_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({}, {})", .{ self.k, self.chunking });
+stdx.debug.assert(@mod(self.q.dim(.q), self.chunking.q_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({f}, {})", .{ self.q, self.chunking });
+stdx.debug.assert(@mod(self.k.dim(.k), self.chunking.k_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({f}, {})", .{ self.k, self.chunking });
const n_q_chunks: u32 = @intCast(@divExact(self.q.dim(.q), self.chunking.q_chunk_size));
const ctx = zml.module.CompilationContext.current();
@@ -1054,7 +1054,7 @@ pub fn sdpaChunk(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) PartialSoft
// Consider implementing sdpa from sdpaChunk.
var q, var k, var v = .{ q_, k_, v_ };
-const err_template = "sdpa(q: {}, k: {}, v: {}, attn: {?}) is invalid ! ";
+const err_template = "sdpa(q: {f}, k: {f}, v: {f}, attn: {?f}) is invalid ! ";
const err_args = .{ q, k, v, opts.attn_mask };
stdx.debug.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ "q is missing tags {{.h, .q, .hd}}", err_args);
stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args);

View File

@@ -51,7 +51,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
var fba = std.heap.FixedBufferAllocator.init(&buffer);
const allocator = fba.allocator();
-const backend_config = std.fmt.allocPrintZ(
+const backend_config = std.fmt.allocPrintSentinel(
allocator,
\\{{
\\ "operation_queue_id":"0",
@@ -110,6 +110,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
q.dim(.q),
k.dim(.k),
},
+0,
) catch unreachable;
var bias = Tensor.constant(Shape.init(.{ .b = q.dim(.b), .h = q.dim(.h), .q = q.dim(.q), .k = k.dim(.k) }, q.dtype()), Data.init(q.dtype(), 0));
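`std.fmt.allocPrintZ` is gone in 0.15; `allocPrintSentinel` takes the terminator as an explicit comptime argument, which is what the added trailing `0` above is. A minimal sketch (the format string and arguments are illustrative):

const std = @import("std");

test "allocPrintSentinel replaces allocPrintZ" {
    // Returns a [:0]u8 slice; the sentinel (0) is now spelled out.
    const s = try std.fmt.allocPrintSentinel(std.testing.allocator, "{s}_{x}", .{ "fn", 0xbeef }, 0);
    defer std.testing.allocator.free(s);
    try std.testing.expectEqualStrings("fn_beef", s);
}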

View File

@@ -155,7 +155,7 @@ pub fn reduce(
// To that order, we initialize `result` to `inputs`, then we use stdx.meta.visit,
// to find the correct mlir.Value, but we first broadcast before creating the final
// Tensor struct.
-var broadcasting_axes: std.BoundedArray(i64, Tensor.MAX_RANK) = .{};
+var broadcasting_axes: stdx.BoundedArray(i64, Tensor.MAX_RANK) = .{};
for (0..Tensor.MAX_RANK) |i| {
if (std.mem.indexOfScalar(i64, axes, @intCast(i)) == null) {
broadcasting_axes.append(@intCast(i)) catch unreachable;
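`std.BoundedArray` was dropped from the standard library in 0.15, so the commit swaps every use for a `stdx.BoundedArray` with the same surface: default-initializable with `.{}`, a public `buffer`/`len` pair, and the usual append/slice helpers. A hypothetical minimal stand-in, assuming the //stdx version keeps the 0.14-era API (the real implementation may differ):

// Hypothetical fixed-capacity array, sketching the API this diff relies on.
pub fn BoundedArray(comptime T: type, comptime capacity: usize) type {
    return struct {
        buffer: [capacity]T = undefined,
        len: usize = 0,

        const Self = @This();

        pub fn append(self: *Self, item: T) error{Overflow}!void {
            if (self.len >= capacity) return error.Overflow;
            self.appendAssumeCapacity(item);
        }

        pub fn appendAssumeCapacity(self: *Self, item: T) void {
            self.buffer[self.len] = item;
            self.len += 1;
        }

        pub fn constSlice(self: *const Self) []const T {
            return self.buffer[0..self.len];
        }

        pub fn slice(self: *Self) []T {
            return self.buffer[0..self.len];
        }
    };
}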
@@ -437,15 +437,15 @@ pub fn if_(
@compileError("true_branch_fn and false_branch_fn return types don't match ! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return));
}
-stdx.debug.assert(pred.dtype() == .bool and pred.count() == 1, "zml.ops.if_ expects the condition to have exactly one element of dtype .bool, got {}", .{pred});
+stdx.debug.assert(pred.dtype() == .bool and pred.count() == 1, "zml.ops.if_ expects the condition to have exactly one element of dtype .bool, got {f}", .{pred});
const ctx = CompilationContext.current();
const true_branch_block, const true_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &true_branch_fn, blkctx, {});
const false_branch_block, const false_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &false_branch_fn, blkctx, {});
-var true_shapes = std.ArrayList(Shape).init(ctx.allocator());
+var true_shapes = std.array_list.Managed(Shape).init(ctx.allocator());
defer true_shapes.deinit();
-var false_shapes = std.ArrayList(Shape).init(ctx.allocator());
+var false_shapes = std.array_list.Managed(Shape).init(ctx.allocator());
defer false_shapes.deinit();
var failed_to_collect = false;
@@ -456,9 +456,9 @@ pub fn if_(
failed_to_collect = true;
};
if (!failed_to_collect) {
-stdx.debug.assert(true_shapes.items.len == false_shapes.items.len, "zml.ops.if_ expects the true and false branch to produce the same number of tensors. Got: \n - true branch: {_}\n -false branch: {_}", .{ true_shapes.items, false_shapes.items });
+stdx.debug.assert(true_shapes.items.len == false_shapes.items.len, "zml.ops.if_ expects the true and false branch to produce the same number of tensors. Got: \n - true branch: {any}\n -false branch: {any}", .{ true_shapes.items, false_shapes.items });
for (true_shapes.items, false_shapes.items) |true_shape, false_shape| {
-stdx.debug.assert(true_shape.eqlWithTags(false_shape), "zml.ops.if_ expects the true and false branch to produce tensors of the same shape. Got: \n - true branch: {_}\n -false branch: {_}", .{ true_shapes.items, false_shapes.items });
+stdx.debug.assert(true_shape.eqlWithTags(false_shape), "zml.ops.if_ expects the true and false branch to produce tensors of the same shape. Got: \n - true branch: {any}\n -false branch: {any}", .{ true_shapes.items, false_shapes.items });
}
}
@@ -751,7 +751,7 @@ pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base
meta.visit((struct {
fn cb(inner_ctx: *LocalContext, tensor: *Tensor) void {
var new = Tensor.fromMlirValue(inner_ctx.op.result(inner_ctx.index));
-stdx.debug.internalAssert(new.rank() == tensor.rank(), "expected operand result to have rank {} but got {}", .{ tensor.rank(), new });
+stdx.debug.internalAssert(new.rank() == tensor.rank(), "expected operand result to have rank {} but got {f}", .{ tensor.rank(), new });
// copy tags and sharding info over
// some ops can change dims eg reduceWindow, so we trust mlir here.
new._shape._tags = tensor._shape._tags;
@@ -932,7 +932,7 @@ pub fn scatter(
const n_inputs = meta.count(Tensor, &inputs);
const n_updates = meta.count(Tensor, &updates);
-stdx.debug.assert(n_inputs == n_updates, "zml.ops.scatter expects the same number of tensors in inputs and updates, got {} and {}", .{ n_inputs, n_updates });
+stdx.debug.assert(n_inputs == n_updates, "zml.ops.scatter expects the same number of tensors in inputs and updates, got {d} and {d}", .{ n_inputs, n_updates });
// Note: I was a bit lazy here, and I only look at tags on the first tensor.
// we probably should check all of them.
@@ -944,7 +944,7 @@ pub fn scatter(
// validate coord axes: all coord_axes should exist inside self
for (indices_axes.constSlice()) |t| {
-stdx.debug.assert(self._shape.hasTag(t) != null, "zml.ops.scatter expects axes of indices to be axes of inputs, got input={_} and indices={s}", .{ self, indices_axes.constSlice() });
+stdx.debug.assert(self._shape.hasTag(t) != null, "zml.ops.scatter expects axes of indices to be axes of inputs, got input={f} and indices={any}", .{ self, indices_axes.constSlice() });
}
// Handle scalar indices by broadcasting them to the indices with the highest rank.
@@ -958,8 +958,8 @@ pub fn scatter(
break :blk higher_rank;
};
for (indices_per_axis.slice()) |*idx| {
-stdx.debug.assert(idx.shape().canBroadcastTo(indices_shape), "zml.ops.scatter expects all indices tensor to have the same shape, got {_}", .{indices_per_axis.slice()});
-stdx.debug.assert(idx.dtype() == indices_shape.dtype(), "zml.ops.scatter expects all indices tensor to have the same dtype, got {_}", .{indices_per_axis.slice()});
+stdx.debug.assert(idx.shape().canBroadcastTo(indices_shape), "zml.ops.scatter expects all indices tensor to have the same shape, got {any}", .{indices_per_axis.slice()});
+stdx.debug.assert(idx.dtype() == indices_shape.dtype(), "zml.ops.scatter expects all indices tensor to have the same dtype, got {any}", .{indices_per_axis.slice()});
idx.* = idx.broad(indices_shape);
}
@@ -972,7 +972,7 @@ pub fn scatter(
var config = scatterConfig(self.shape(), update.shape(), indices_per_axis, indices_axes);
const indices = scatterPrepareIndices(&config, self.shape(), update.shape(), &indices_per_axis, &indices_axes);
// const n_indices_axes = update.rank() - _collectAxes(AxisKind, up_kind, .update_window).len;
-// stdx.debug.assert(n_indices_axe == indices_axes.len, "scatter({_}, {any}) expects 'updates' to contain all axes from 'indices', got indices={s}, updates={_}", .{ self, index_tensors, indices_axes.constSlice(), update });
+// stdx.debug.assert(n_indices_axe == indices_axes.len, "scatter({f}, {any}) expects 'updates' to contain all axes from 'indices', got indices={s}, updates={f}", .{ self, index_tensors, indices_axes.constSlice(), update });
const mlir_ctx = ctx.mlirCtx();
var _scalar: T = inputs;
@@ -985,10 +985,10 @@ pub fn scatter(
const UpdateS = BlockSign(update_fn);
const update_block, _ = ctx.makeBlock(.hermetic, UpdateS, update_fn, blkctx, .{ _scalar, _scalar });
-var input_values = std.ArrayList(mlir.Value).initCapacity(ctx.allocator(), n_inputs) catch @panic("OOM");
+var input_values = std.array_list.Managed(mlir.Value).initCapacity(ctx.allocator(), n_inputs) catch @panic("OOM");
defer input_values.deinit();
meta.collect(CompilationContext.getValue, ctx, &input_values, &inputs) catch unreachable;
-var updates_values = std.ArrayList(mlir.Value).initCapacity(ctx.allocator(), n_updates) catch @panic("OOM");
+var updates_values = std.array_list.Managed(mlir.Value).initCapacity(ctx.allocator(), n_updates) catch @panic("OOM");
defer updates_values.deinit();
meta.collect(CompilationContext.getValue, ctx, &updates_values, &updates) catch unreachable;
@@ -1029,8 +1029,8 @@ pub fn scatter(
}
const ScatterConfig = struct {
-op_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{},
-up_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{},
+op_kind: stdx.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{},
+up_kind: stdx.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{},
indices_batch_axes: Shape.DimsArray = .{},
scatter_to_operand_axes: Shape.DimsArray = .{},
updates_transpose: Shape.AxesArray = .{},
@@ -1041,11 +1041,11 @@ const AxisKind = enum { batching, update_window, inserted_window, window_id };
fn scatterConfig(
op: Shape,
update: Shape,
-indices_per_axis: std.BoundedArray(Tensor, Tensor.MAX_RANK),
+indices_per_axis: stdx.BoundedArray(Tensor, Tensor.MAX_RANK),
indices_axes: Shape.TagsArray,
) ScatterConfig {
-var op_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{};
-var up_kind: std.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{};
+var op_kind: stdx.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{};
+var up_kind: stdx.BoundedArray(AxisKind, Tensor.MAX_RANK) = .{};
var indices_batch_axes: Shape.DimsArray = .{};
var scatter_to_operand_axes: Shape.DimsArray = .{};
var updates_transpose: Shape.AxesArray = .{};
@@ -1058,7 +1058,7 @@ fn scatterConfig(
scatter_to_operand_axes.appendAssumeCapacity(op.axis(t));
}
for (indices.tags()) |t| {
-stdx.debug.assert(update.hasTag(t) != null, "scatter expects 'updates' to have all axes of 'indices', got self={_}, updates={_} and indices={_}", .{ op, update, indices });
+stdx.debug.assert(update.hasTag(t) != null, "scatter expects 'updates' to have all axes of 'indices', got self={f}, updates={f} and indices={f}", .{ op, update, indices });
updates_transpose.appendAssumeCapacity(update.axis(t));
}
@@ -1094,11 +1094,11 @@ fn scatterConfig(
if (indices.hasTag(t) != null) {
up_kind.appendAssumeCapacity(.window_id);
} else if (op.hasTag(t)) |self_ax| {
-stdx.debug.assert(update.dim(up_ax) <= op.dim(self_ax), "scatter expects the slices described in 'updates' to fit inside 'op', but along axis .{s} it doesn't. Got op={_}, updates={_}.", .{ t, op, update });
+stdx.debug.assert(update.dim(up_ax) <= op.dim(self_ax), "scatter expects the slices described in 'updates' to fit inside 'op', but along axis .{s} it doesn't. Got op={f}, updates={f}.", .{ t, op, update });
up_kind.appendAssumeCapacity(.update_window);
} else {
// TODO: consider accepting untagged update here.
-std.debug.panic("scatter expects 'updates' to be made of axes from op={_} and from indices={s}, got unknown tag {s} in {_}", .{ op, indices_axes.constSlice(), t, update });
+std.debug.panic("scatter expects 'updates' to be made of axes from op={f} and from indices={any}, got unknown tag {s} in {f}", .{ op, indices_axes.constSlice(), std.mem.sliceTo(t, 0), update });
}
}
} else {
@@ -1174,7 +1174,7 @@ fn scatterPrepareIndices(
cfg: *ScatterConfig,
op: Shape,
update: Shape,
-indices_per_axis: *std.BoundedArray(Tensor, Tensor.MAX_RANK),
+indices_per_axis: *stdx.BoundedArray(Tensor, Tensor.MAX_RANK),
indices_axes: *Shape.TagsArray,
) Tensor {
var old_scatter_to_op_axes = cfg.scatter_to_operand_axes;
@@ -1194,7 +1194,7 @@ fn scatterPrepareIndices(
// Reorder the axes so that in indices_per_axis is ordered like in op if possible.
// TODO: transpose updates if needed
-var indices: std.BoundedArray(Tensor, Tensor.MAX_RANK) = .{};
+var indices: stdx.BoundedArray(Tensor, Tensor.MAX_RANK) = .{};
var scatter_to_op_axes: Shape.DimsArray = .{};
while (old_scatter_to_op_axes.len > 0) {
@@ -1209,7 +1209,7 @@ fn scatterPrepareIndices(
for (scatter_to_op_axes.constSlice(), 0..) |sc_ax, i| {
if (i != sc_ax) {
-log.warn("Found a slow scatter pattern, which is going to generate a while loop: scatter({_}, {any}, {_}). Because the index axes aren't the major ones in the input tensor.", .{ op, scatter_to_op_axes.constSlice(), update });
+log.warn("Found a slow scatter pattern, which is going to generate a while loop: scatter({f}, {any}, {f}). Because the index axes aren't the major ones in the input tensor.", .{ op, scatter_to_op_axes.constSlice(), update });
break;
}
}

View File

@@ -75,14 +75,14 @@ pub const Client = opaque {
}
fn compileSync(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable {
-var bytecode = std.ArrayList(u8).init(allocator);
+var bytecode: std.array_list.Managed(u8) = .init(allocator);
defer bytecode.deinit();
module.op().writeBytecodeWithConfig(bytecode.writer(), .{ .desiredEmitedVersion = 1 }) catch |err| {
log.err("failed to write module bytecode: {}", .{err});
return err;
};
-var serialized_buffer = std.ArrayList(u8).init(allocator);
+var serialized_buffer: std.array_list.Managed(u8) = .init(allocator);
defer serialized_buffer.deinit();
const stablehlo_version = blk: {
@@ -220,7 +220,7 @@ pub const Event = opaque {
}{};
try self.inner().onReady(api, &(struct {
-fn call(err: ?*pjrt.Error, user_arg: ?*anyopaque) callconv(.C) void {
+fn call(err: ?*pjrt.Error, user_arg: ?*anyopaque) callconv(.c) void {
const ctx_: *@TypeOf(ctx) = @ptrCast(@alignCast(user_arg.?));
ctx_.err = err;
ctx_.event.set();
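Zig 0.15 spells calling-convention tags in lowercase, so `callconv(.C)` becomes `callconv(.c)`; this goes with the reworked, payload-carrying `std.builtin.CallingConvention`. A minimal sketch of a C-ABI callback under the new spelling (names are illustrative):

const std = @import("std");

fn onEvent(err_code: c_int, user_arg: ?*anyopaque) callconv(.c) void {
    const ok: *bool = @ptrCast(@alignCast(user_arg.?));
    ok.* = (err_code == 0);
}

test "callconv(.c) callback" {
    var ok = false;
    const cb: *const fn (c_int, ?*anyopaque) callconv(.c) void = &onEvent;
    cb(0, &ok);
    try std.testing.expect(ok);
}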

View File

@@ -15,7 +15,7 @@ pub const CompilationOptions = struct {
xla_dump_fusion_visualization: bool = false,
xla_dump_hlo_pass_re: ?[]const u8 = null,
sharding_enabled: bool = false,
-sharding_axes: std.BoundedArray([*:0]const u8, 8) = .{},
+sharding_axes: stdx.BoundedArray([*:0]const u8, 8) = .{},
};
pub const Platform = struct {

View File

@@ -22,9 +22,9 @@ pub const Shape = struct {
pub const TagUnknown = "_".ptr;
const TagLast = "last".ptr;
-pub const DimsArray = std.BoundedArray(i64, MAX_RANK);
-pub const TagsArray = std.BoundedArray(Tag, MAX_RANK);
-pub const AxesArray = std.BoundedArray(u3, MAX_RANK);
+pub const DimsArray = stdx.BoundedArray(i64, MAX_RANK);
+pub const TagsArray = stdx.BoundedArray(Tag, MAX_RANK);
+pub const AxesArray = stdx.BoundedArray(u3, MAX_RANK);
pub const ShardingInfo = @Vector(MAX_RANK, bool);
const UnknownTags: TagsArray = .{ .len = 0, .buffer = [_]Tag{TagUnknown} ** MAX_RANK };
@@ -300,7 +300,7 @@ pub const Shape = struct {
fn axisFromInt(self: Shape, a: isize) u3 {
const rk: i8 = self.rank();
if (a < -rk or a > rk) {
-stdx.debug.panic("Tensor {} doesn't have dimension: {d}", .{ self, a });
+stdx.debug.panic("Tensor {f} doesn't have dimension: {d}", .{ self, a });
}
return if (a < 0)
@intCast(a + rk)
@@ -341,9 +341,9 @@ pub const Shape = struct {
}
fn axisFromTag(self: Shape, d: Tag) u3 {
-stdx.debug.assert(d != TagUnknown, "The unknown tag .{s} can't be used to fetch axis in {}", .{ d, self });
+stdx.debug.assert(d != TagUnknown, "The unknown tag .{s} can't be used to fetch axis in {f}", .{ d, self });
return self.axisFromTagMaybe(d) orelse {
-stdx.debug.panic("Tensor {} doesn't have dimension with tag: {s}", .{ self, d });
+stdx.debug.panic("Tensor {f} doesn't have dimension with tag: {s}", .{ self, d });
};
}
@@ -357,7 +357,7 @@ pub const Shape = struct {
pub fn count(self: Shape) usize {
var res: i64 = 1;
for (self.dims()) |d| {
-stdx.debug.assert(d >= 0, "Can't count elements in shape with negative dimension: {}", .{self});
+stdx.debug.assert(d >= 0, "Can't count elements in shape with negative dimension: {f}", .{self});
res *= d;
}
return @intCast(res);
@@ -388,12 +388,11 @@ pub const Shape = struct {
/// Bare format {_}: "{.a=10, .b=20}, dtype=.f32"
pub fn format(
self: Shape,
-comptime fmt: []const u8,
-options: std.fmt.FormatOptions,
writer: anytype,
) !void {
-_ = options;
-const bare_fmt = fmt.len == 1 and fmt[0] == '_';
+// TODO: impl alternative format
+// const bare_fmt = fmt.len == 1 and fmt[0] == '_';
+const bare_fmt = true;
_ = try writer.write(if (bare_fmt) "{" else "Shape({");
var need_comma = false;
@@ -441,12 +440,12 @@ pub const Shape = struct {
var new_shape: Shape = .{ ._dtype = self.dtype() };
new_shape._dims, new_shape._tags = parseDimensions(new_shape_);
new_shape.inferMissingAxis(self.count());
-stdx.debug.assert(self.count() == new_shape.count(), "Can't reshape {d} to {d}", .{ self.dims(), new_shape.dims() });
+stdx.debug.assert(self.count() == new_shape.count(), "Can't reshape {any} to {any}", .{ self.dims(), new_shape.dims() });
return new_shape;
}
fn inferMissingAxis(self: *Shape, n_: usize) void {
-stdx.debug.assert(std.mem.count(i64, self.dims(), &.{-1}) < 2, "Cannot infer multiple dimensions when reshaping to: {}", .{self.*});
+stdx.debug.assert(std.mem.count(i64, self.dims(), &.{-1}) < 2, "Cannot infer multiple dimensions when reshaping to: {f}", .{self.*});
const inferred_ax = std.mem.indexOfScalar(i64, self.dims(), -1) orelse return;
// We can't use `self.count()` yet cause we have negative dims.
@@ -524,7 +523,7 @@ pub const Shape = struct {
}
pub fn insertTag(self: Shape, axis_: anytype, d: i64, tag_: anytype) Shape {
-stdx.debug.assert(self.rank() < MAX_RANK - 1, "Can't insert new axis in {}, it's already at max rank.", .{self});
+stdx.debug.assert(self.rank() < MAX_RANK - 1, "Can't insert new axis in {f}, it's already at max rank.", .{self});
const ax = if (@TypeOf(axis_) == EnumLiteral and axis_ == .last)
self.rank()
@@ -652,7 +651,7 @@ pub const Shape = struct {
var res = self;
if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) {
-stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {any}", .{ self, tagz });
+stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {f}, got {any}", .{ self, tagz });
for (tagz, 0..) |tag_, i| {
res._tags.set(i, toTag(tag_));
}
@@ -660,7 +659,7 @@ pub const Shape = struct {
}
if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) {
-stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {}, got {}", .{ self, tagz });
+stdx.debug.assert(tagz.len == self.rank(), "Not enough tags for shape {f}, got {}", .{ self, tagz });
inline for (tagz, 0..) |tag_, i| {
res._tags.set(i, toTag(tag_));
}
@@ -699,7 +698,7 @@ pub const Shape = struct {
var res = self;
if (comptime stdx.meta.isSliceOf(T, Tag) or stdx.meta.isSliceOf(T, EnumLiteral)) {
-stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {any}", .{ self, tagz });
+stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {f}, got {any}", .{ self, tagz });
for (tagz, self.rank() - tagz.len..) |tag_, i| {
res._tags.set(i, toTag(tag_));
}
@@ -707,7 +706,7 @@ pub const Shape = struct {
}
if (comptime stdx.meta.isTupleOf(T, Tag) or stdx.meta.isTupleOf(T, EnumLiteral)) {
-stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {}, got {}", .{ self, tagz });
+stdx.debug.assert(tagz.len <= self.rank(), "Too many tags for shape {f}, got {}", .{ self, tagz });
inline for (tagz, self.rank() - tagz.len..) |tag_, i| {
res._tags.set(i, toTag(tag_));
}
@@ -765,7 +764,7 @@ pub const Shape = struct {
var res = self;
inline for (std.meta.fields(T)) |field| {
const new_field = @field(renames, field.name);
-stdx.debug.assert(self.hasTag(new_field) == null, "{}.rename({any}) failed because of duplicated axis {}", .{ self, renames, new_field });
+stdx.debug.assert(self.hasTag(new_field) == null, "{f}.rename({any}) failed because of duplicated axis {}", .{ self, renames, new_field });
res._tags.set(self.axis(field), toTag(new_field));
}
return res;
@@ -785,9 +784,9 @@ pub const Shape = struct {
}
}
-pub fn computeStrides(self: Shape) std.BoundedArray(i64, MAX_RANK) {
+pub fn computeStrides(self: Shape) stdx.BoundedArray(i64, MAX_RANK) {
const rk = self.rank();
-var strides: std.BoundedArray(i64, MAX_RANK) = .{ .len = rk };
+var strides: stdx.BoundedArray(i64, MAX_RANK) = .{ .len = rk };
if (rk == 0) return strides;
const V = @Vector(MAX_RANK, i64);
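For reference, `computeStrides` produces strides from a trailing product of dimensions: the last axis has stride 1 and each earlier axis multiplies by the dimension after it. A scalar sketch of that computation, assuming row-major layout (the actual implementation above is vectorized):

const std = @import("std");

fn rowMajorStrides(dims: []const i64, out: []i64) void {
    var acc: i64 = 1;
    var i = dims.len;
    while (i > 0) {
        i -= 1;
        out[i] = acc; // stride of axis i
        acc *= dims[i]; // elements spanned by one step of the axis before it
    }
}

test rowMajorStrides {
    var out: [3]i64 = undefined;
    rowMajorStrides(&.{ 2, 3, 4 }, &out);
    try std.testing.expectEqualSlices(i64, &.{ 12, 4, 1 }, &out);
}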
@@ -907,7 +906,7 @@ pub const Shape = struct {
var new_dim: i64 = 1;
for (axes__.constSlice(), first_axis..) |ax, counter| {
new_dim *= self.dim(ax);
-stdx.debug.assert(ax == counter, "Can't merge shape {} along non-contiguous axes {any}", .{ self, axes_ });
+stdx.debug.assert(ax == counter, "Can't merge shape {f} along non-contiguous axes {any}", .{ self, axes_ });
}
var new_shape = self;
@@ -991,10 +990,10 @@ pub const Shape = struct {
return res;
}
-pub fn parseStruct(T: type, v: anytype) struct { std.BoundedArray(T, MAX_RANK), TagsArray } {
+pub fn parseStruct(T: type, v: anytype) struct { stdx.BoundedArray(T, MAX_RANK), TagsArray } {
const V = @TypeOf(v);
-var vals_: std.BoundedArray(T, MAX_RANK) = .{};
+var vals_: stdx.BoundedArray(T, MAX_RANK) = .{};
var tags_: TagsArray = .{};
if (comptime stdx.meta.isSliceOf(V, T)) {
@@ -1029,10 +1028,10 @@ pub const Shape = struct {
}
/// Parses a struct literal into a list of options for each axes.
-pub fn parseAxesOptions(self: Shape, T: type, options: anytype, default: T) std.BoundedArray(T, MAX_RANK) {
+pub fn parseAxesOptions(self: Shape, T: type, options: anytype, default: T) stdx.BoundedArray(T, MAX_RANK) {
const V = @TypeOf(options);
-var res: std.BoundedArray(T, MAX_RANK) = .{};
+var res: stdx.BoundedArray(T, MAX_RANK) = .{};
if (comptime stdx.meta.isSliceOf(V, T)) {
stdx.debug.assert(options.len == self.rank(), "expects exactly {} options in slice, for {} got {}", .{ self.rank(), self, options.len });
for (options) |d| {
@@ -1084,7 +1083,7 @@ pub const Shape = struct {
for (0..other.rank()) |ax| {
if (other.tag(ax) != Shape.TagUnknown) {
if (self.hasTag(other.tag(ax))) |batching_ax| {
-stdx.debug.assert(batching_ax == batching_axes and batching_ax == ax, "outer expects batching dims to be the first dims in both tensors, got outer({}, {})", .{ self, other });
+stdx.debug.assert(batching_ax == batching_axes and batching_ax == ax, "outer expects batching dims to be the first dims in both tensors, got outer({f}, {f})", .{ self, other });
batching_axes += 1;
}
}

View File

@@ -53,13 +53,12 @@ pub const Tensor = struct {
pub fn format(
self: Tensor,
-comptime fmt: []const u8,
-options: std.fmt.FormatOptions,
writer: anytype,
) !void {
-_ = options;
-const bare_fmt = fmt.len == 1 and fmt[0] == '_';
-try writer.print(if (bare_fmt) "{_}" else "Tensor({_})", .{self._shape});
+// TODO(0.15.0) handle format
+// const bare_fmt = fmt.len == 1 and fmt[0] == '_';
+const bare_fmt = false;
+try writer.print(if (bare_fmt) "{f}" else "Tensor({f})", .{self._shape});
}
/// Returns the shape of a Tensor.
@@ -99,7 +98,7 @@ pub const Tensor = struct {
if (builtin.mode == .Debug) {
// Check that the MLIR value actually have the same shape.
const other = fromMlirValue(val);
-stdx.debug.internalAssert(sh.eql(other._shape), "Created a {} from Mlir value but expected {}", .{ other._shape, res._shape });
+stdx.debug.internalAssert(sh.eql(other._shape), "Created a {f} from Mlir value but expected {f}", .{ other._shape, res._shape });
}
return res;
@@ -145,7 +144,7 @@ pub const Tensor = struct {
/// Returns the indices of each of the given axes.
///
/// 'axis_' can be an integer or a tag.
-pub fn axes(self: Tensor, axes_: anytype) std.BoundedArray(u3, Tensor.MAX_RANK) {
+pub fn axes(self: Tensor, axes_: anytype) stdx.BoundedArray(u3, Tensor.MAX_RANK) {
return self._shape.axes(axes_);
}
@@ -260,7 +259,7 @@ pub const Tensor = struct {
/// For `reuseBuffer` to be effective, it needs to propagate all the way through the output.
pub fn reuseBuffer(self: Tensor, origin: Tensor) Tensor {
// Note: check donation docs, this may be too permissive.
-stdx.debug.assert(self.byteSize() == origin.byteSize(), "Can't reuse buffers between tensors of different size: {} and {}", .{ self, origin });
+stdx.debug.assert(self.byteSize() == origin.byteSize(), "Can't reuse buffers between tensors of different size: {f} and {f}", .{ self, origin });
// TODO: should we store all donations inside the context ?
var res = self;
@@ -410,7 +409,7 @@ pub const Tensor = struct {
stdx.debug.assert(self.dtype() == other.dtype(), "triangularSolve expects tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() });
stdx.debug.assert(self.rank() <= 2 and self.rank() == other.rank(), "triangularSolve expects tensors to have the same rank and be <= 2, got {} and {}", .{ self.rank(), other.rank() });
-const loc = self.getContext().location(@src(), "triangularSolve({_}, {})", .{ self, opts });
+const loc = self.getContext().location(@src(), "triangularSolve({f}, {})", .{ self, opts });
const op = dialect.stablehlo.triangular_solve(self.getContext().mlirCtx(), self.value(), other.value(), loc, opts);
return _result(self._shape, op.result(0));
}
@@ -435,7 +434,7 @@ pub const Tensor = struct {
/// Returns a Tensor of complex number converted from a pair of real and imaginary Tensors.
pub fn complex(re: Tensor, im: Tensor) Tensor {
-stdx.debug.assert(re._shape.eql(im._shape), "complex expects tensor shapes to match, got {} and {}", .{ re._shape, im._shape });
+stdx.debug.assert(re._shape.eql(im._shape), "complex expects tensor shapes to match, got {f} and {f}", .{ re._shape, im._shape });
stdx.debug.assert(re.dtype() == .f32 or re.dtype() == .f64, "complex expects tensors type to be f32 or f64, got {}", .{re.dtype()});
const loc = re.getContext().mlirCtx().location(@src());
@@ -523,7 +522,7 @@ pub const Tensor = struct {
},
};
-const loc = self.getContext().location(@src(), "fft({_},{})", .{ self, opts });
+const loc = self.getContext().location(@src(), "fft({f},{})", .{ self, opts });
const op = dialect.stablehlo.fft(self.getContext().mlirCtx(), self.value(), loc, opts);
return _result(sh, op.result(0));
}
@@ -551,7 +550,7 @@ pub const Tensor = struct {
/// but it is not guaranteed to be deterministic between implementations.
pub fn bitGenerator(self: Rng, sh: Shape) struct { Rng, Tensor } {
const ctx = CompilationContext.current();
-const loc = ctx.location(@src(), "rand.bitGen({_})", .{sh});
+const loc = ctx.location(@src(), "rand.bitGen({f})", .{sh});
const op = dialect.stablehlo.rng_bit_generator(
ctx.mlirCtx(),
self.algorithm,
@@ -589,7 +588,7 @@ pub const Tensor = struct {
16 => .u16,
32 => .u32,
64 => .u64,
-else => stdx.debug.panic("uniform don't support non-byte aligned dtype. Got: {}", .{shape_}),
+else => stdx.debug.panic("uniform don't support non-byte aligned dtype. Got: {f}", .{shape_}),
};
const rng, const bits = self.bitGenerator(shape_.withDtype(uint_dtype));
@@ -674,7 +673,7 @@ pub const Tensor = struct {
stdx.debug.assert(sh.dtype().isFloat(), "normal expects tensor type to be a float, got {}", .{sh.dtype()});
const ctx = CompilationContext.current();
-const loc = ctx.location(@src(), "rand.normal({_}, mean={},stddev={})", .{ sh, opts.mean, opts.stddev });
+const loc = ctx.location(@src(), "rand.normal({f}, mean={},stddev={})", .{ sh, opts.mean, opts.stddev });
const a = Tensor.constant(.{}, Data.init(sh.dtype(), opts.mean));
const b = Tensor.constant(.{}, Data.init(sh.dtype(), opts.stddev));
const res_shape = Tensor.constantTensor(HostBuffer.fromSlice(.{sh.rank()}, sh.dims()));
@@ -1046,7 +1045,7 @@ pub const Tensor = struct {
if (to == self.dtype()) {
return self;
}
-const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) });
+const loc = self.getContext().location(@src(), "convert({f},to={s})", .{ self, @tagName(to) });
const mlir_ctx = self.getContext().mlirCtx();
const res_type = mlirx.tensorType(mlir_ctx, self.shape().withDtype(to));
@@ -1160,7 +1159,7 @@ pub const Tensor = struct {
) Tensor {
stdx.debug.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {} and {}", .{ lhs.dtype(), rhs.dtype() });
-const Axes = std.BoundedArray(i64, MAX_RANK);
+const Axes = stdx.BoundedArray(i64, MAX_RANK);
var res_shape: Shape = .{ ._dtype = lhs.dtype() };
// Validate batching axes
@@ -1168,7 +1167,7 @@ pub const Tensor = struct {
var rhs_batching_axes: Axes = .{};
for (batching_axes) |b_axes| {
const l, const r = b_axes;
-stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs });
+stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {} and {} in {f} and {f}", .{ l, r, lhs, rhs });
var t = lhs._shape.tag(l);
if (t == Shape.TagUnknown) t = rhs._shape.tag(r);
res_shape = res_shape.appendDim(lhs._shape.dim(l), t);
@@ -1181,7 +1180,7 @@ pub const Tensor = struct {
var rhs_contracting_axes: Axes = .{};
for (contracting_axes) |c_axes| {
const l, const r = c_axes;
-stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects contracting dimensions to be equal, got {} and {} in {} and {}", .{ l, r, lhs, rhs });
+stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects contracting dimensions to be equal, got {} and {} in {f} and {f}", .{ l, r, lhs, rhs });
lhs_contracting_axes.appendAssumeCapacity(lhs._shape.axis(l));
rhs_contracting_axes.appendAssumeCapacity(rhs._shape.axis(r));
}
@@ -1209,7 +1208,7 @@ pub const Tensor = struct {
}
const mlir_ctx = lhs.getContext().mlirCtx();
-const loc = lhs.getContext().location(@src(), "dot({_},{_},contracting={any},batching={any}", .{ lhs, rhs, contracting_axes, batching_axes });
+const loc = lhs.getContext().location(@src(), "dot({f},{f},contracting={any},batching={any}", .{ lhs, rhs, contracting_axes, batching_axes });
const op = dialect.stablehlo.dot_general(
mlir_ctx,
lhs.value(),
@@ -1406,7 +1405,7 @@ pub const Tensor = struct {
else
toI64(axes__);
-stdx.debug.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {_} and {d}", .{ self, permutation[0..@min(permutation.len, MAX_RANK + 2)] });
+stdx.debug.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {f} and {any}", .{ self, permutation[0..@min(permutation.len, MAX_RANK + 2)] });
if (std.mem.eql(i64, permutation, no_op[0..self.rank()])) {
return self;
@@ -1417,7 +1416,7 @@ pub const Tensor = struct {
return self.reshape(res_shape);
}
-const loc = self.getContext().location(@src(), "transpose({_}, {d})", .{ self, permutation });
+const loc = self.getContext().location(@src(), "transpose({f}, {any})", .{ self, permutation });
const op = dialect.stablehlo.transpose(
self.getContext().mlirCtx(),
self.value(),
@@ -1505,7 +1504,7 @@ pub const Tensor = struct {
// stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis });
const new_shape = old_shape.remove(a + 1).set(a, old_shape.dim(a) * old_shape.dim(a + 1));
-const loc = self.getContext().location(@src(), "flatten({_},{})", .{ self, axis_ });
+const loc = self.getContext().location(@src(), "flatten({f},{})", .{ self, axis_ });
const reshaped_val = dialect.stablehlo.reshape(
self.getContext().mlirCtx(),
self.value(),
@@ -1684,7 +1683,7 @@ pub const Tensor = struct {
const res_shape = shape0.insertTag(axis_, 1, tag);
for (tensors[1..]) |tensor| {
-stdx.debug.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {} and {}", .{ shape0, tensor._shape });
+stdx.debug.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {f} and {f}", .{ shape0, tensor._shape });
}
var reshaped: [32]Tensor = undefined;
@@ -1859,7 +1858,7 @@ pub const Tensor = struct {
const dt: DataType = if (sh.dim(a) <= std.math.maxInt(i32)) .i32 else .i64;
const res_shape = sh.withDtype(dt);
const ctx = CompilationContext.current();
-const loc = ctx.location(@src(), "iota({_}, {})", .{ res_shape, a });
+const loc = ctx.location(@src(), "iota({f}, {})", .{ res_shape, a });
const mlir_ctx = ctx.mlirCtx();
var op = dialect.stablehlo.iota(
@@ -1931,7 +1930,7 @@ pub const Tensor = struct {
pub fn constant(dimz: anytype, val: Data) Tensor {
const sh = Shape.init(dimz, val.dtype());
const ctx = CompilationContext.current().mlirCtx();
-const loc = CompilationContext.current().location(@src(), "dims={d}, value={}", .{ sh, val });
+const loc = CompilationContext.current().location(@src(), "dims={f}, value={}", .{ sh, val });
var constant_op = if (mlirx.denseElementAttrType(val.dtype())) |elem_type|
dialect.stablehlo.constant(ctx, &.{}, elem_type, val.constSlice(), loc)
@@ -1951,7 +1950,7 @@ pub const Tensor = struct {
pub fn constantTensor(val: HostBuffer) Tensor {
const ctx = CompilationContext.current().mlirCtx();
const loc = ctx.location(@src());
-const elem_type = mlirx.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()});
+const elem_type = mlirx.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {f}", .{val.shape()});
const constant_op = dialect.stablehlo.constant(ctx, val.shape().dims(), elem_type, val.bytes(), loc); const constant_op = dialect.stablehlo.constant(ctx, val.shape().dims(), elem_type, val.bytes(), loc);
return _result(val.shape(), constant_op.result(0)); return _result(val.shape(), constant_op.result(0));
} }
@ -1975,10 +1974,10 @@ pub const Tensor = struct {
/// you will lose the tags. /// you will lose the tags.
/// To avoid this, prefer `.broad(shape)` when working with tagged tensors. /// To avoid this, prefer `.broad(shape)` when working with tagged tensors.
pub fn broadcast(self: Tensor, output_shape: Shape, axes_: []const i64) Tensor { pub fn broadcast(self: Tensor, output_shape: Shape, axes_: []const i64) Tensor {
stdx.debug.assert(axes_.len == self.rank(), "broadcast expects axes_ to map all axes from self to axes of the output shape, got broadcast({}, {}, {d})", .{ self, output_shape, axes_ }); stdx.debug.assert(axes_.len == self.rank(), "broadcast expects axes_ to map all axes from self to axes of the output shape, got broadcast({f}, {f}, {any})", .{ self, output_shape, axes_ });
for (0.., axes_) |self_ax, other_ax| { for (0.., axes_) |self_ax, other_ax| {
const d = self.dim(self_ax); const d = self.dim(self_ax);
stdx.debug.assert(d == 1 or d == output_shape.dim(other_ax), "broadcast expects shape axes to either be 1-sized or to match the target size. got broadcast({}, {}, {d}), error on self axis {} mapping to other axis {}", .{ self, output_shape, axes_, self_ax, other_ax }); stdx.debug.assert(d == 1 or d == output_shape.dim(other_ax), "broadcast expects shape axes to either be 1-sized or to match the target size. got broadcast({f}, {f}, {any}), error on self axis {d} mapping to other axis {d}", .{ self, output_shape, axes_, self_ax, other_ax });
} }
const res_shape = output_shape.withDtype(self.dtype()); const res_shape = output_shape.withDtype(self.dtype());
@ -1989,7 +1988,7 @@ pub const Tensor = struct {
} }
const ctx = self.getContext(); const ctx = self.getContext();
const result_type = mlirx.tensorType(ctx.mlirCtx(), res_shape); const result_type = mlirx.tensorType(ctx.mlirCtx(), res_shape);
const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ }); const loc = ctx.location(@src(), "broadcast({f}, {f}, axes={any})", .{ self, res_shape, axes_ });
const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc); const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc);
return _result(res_shape, broadcast_op.result(0)); return _result(res_shape, broadcast_op.result(0));
@ -1997,7 +1996,7 @@ pub const Tensor = struct {
/// Broadcasts a Tensor to the given shape, adding axes at the beginning. /// Broadcasts a Tensor to the given shape, adding axes at the beginning.
pub fn broadcastLeft(self: Tensor, output_shape: Shape) Tensor { pub fn broadcastLeft(self: Tensor, output_shape: Shape) Tensor {
stdx.debug.assert(self.rank() <= output_shape.rank(), "broadcastLeft expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() }); stdx.debug.assert(self.rank() <= output_shape.rank(), "broadcastLeft expects tensor rank to be less than output tensor rank, got {d} and {d}", .{ self.rank(), output_shape.rank() });
const a = output_shape.rank() - self.rank(); const a = output_shape.rank() - self.rank();
if (self.rank() == output_shape.rank() and std.mem.eql(i64, self.dims(), output_shape.dims())) { if (self.rank() == output_shape.rank() and std.mem.eql(i64, self.dims(), output_shape.dims())) {
@ -2009,7 +2008,7 @@ pub const Tensor = struct {
/// Broadcasts a Tensor to the given shape, adding axes at the end. /// Broadcasts a Tensor to the given shape, adding axes at the end.
pub fn broadcastRight(self: Tensor, output_shape: Shape) Tensor { pub fn broadcastRight(self: Tensor, output_shape: Shape) Tensor {
stdx.debug.assert(self.rank() <= output_shape.rank(), "broadcastRight expects tensor rank to be less than output tensor rank, got {} and {}", .{ self.rank(), output_shape.rank() }); stdx.debug.assert(self.rank() <= output_shape.rank(), "broadcastRight expects tensor rank to be less than output tensor rank, got {d} and {d}", .{ self.rank(), output_shape.rank() });
if (self.rank() == output_shape.rank() and self._shape.eql(output_shape)) { if (self.rank() == output_shape.rank() and self._shape.eql(output_shape)) {
return self; return self;
@ -2022,7 +2021,7 @@ pub const Tensor = struct {
pub fn broad(self: Tensor, other: Shape) Tensor { pub fn broad(self: Tensor, other: Shape) Tensor {
// TODO: broad is too restrictive because sometime you only want to specify one specific axis // TODO: broad is too restrictive because sometime you only want to specify one specific axis
// Note: if you change the code below, make sure to update Shape.canBroadcastTo. // Note: if you change the code below, make sure to update Shape.canBroadcastTo.
stdx.debug.assert(self._shape.canBroadcastTo(other), "Can't broadcast {} to {}", .{ self, other }); stdx.debug.assert(self._shape.canBroadcastTo(other), "Can't broadcast {f} to {f}", .{ self, other });
// Already the right shape // Already the right shape
if (std.mem.eql(i64, self.dims(), other.dims())) return self; if (std.mem.eql(i64, self.dims(), other.dims())) return self;
@ -2036,7 +2035,7 @@ pub const Tensor = struct {
} }
// check that each axis of self maps to an axis of other // check that each axis of self maps to an axis of other
var axes_: std.BoundedArray(i64, MAX_RANK) = .{}; var axes_: stdx.BoundedArray(i64, MAX_RANK) = .{};
for (self._shape.tags()) |t| { for (self._shape.tags()) |t| {
axes_.appendAssumeCapacity(@intCast(other.axis(t))); axes_.appendAssumeCapacity(@intCast(other.axis(t)));
} }
@ -2047,14 +2046,14 @@ pub const Tensor = struct {
pub fn reshape(self: Tensor, output_shape_: anytype) Tensor { pub fn reshape(self: Tensor, output_shape_: anytype) Tensor {
const output_shape = self._shape.reshape(output_shape_); const output_shape = self._shape.reshape(output_shape_);
const tensor_type = mlirx.tensorType(self.getContext().mlirCtx(), output_shape); const tensor_type = mlirx.tensorType(self.getContext().mlirCtx(), output_shape);
const loc = self.getContext().location(@src(), "reshape({any})", .{output_shape}); const loc = self.getContext().location(@src(), "reshape({f})", .{output_shape});
const reshape_value = dialect.stablehlo.reshape(self.getContext().mlirCtx(), self.value(), tensor_type, loc); const reshape_value = dialect.stablehlo.reshape(self.getContext().mlirCtx(), self.value(), tensor_type, loc);
return _result(output_shape, reshape_value.result(0)); return _result(output_shape, reshape_value.result(0));
} }
/// Converts the given 1 element Tensor into a 0-rank Tensor. /// Converts the given 1 element Tensor into a 0-rank Tensor.
pub fn asScalar(self: Tensor) Tensor { pub fn asScalar(self: Tensor) Tensor {
stdx.debug.assert(self.count() == 1, "Tensor.asScalar expects an input with exactly 1-element got {}", .{self}); stdx.debug.assert(self.count() == 1, "Tensor.asScalar expects an input with exactly 1-element got {f}", .{self});
return self.reshape(.{}); return self.reshape(.{});
} }
@ -2177,12 +2176,12 @@ pub const Tensor = struct {
stdx.debug.assert(coord_axes_.len > 0, "gatherValues expects 1 or more axes to operate on, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{}); stdx.debug.assert(coord_axes_.len > 0, "gatherValues expects 1 or more axes to operate on, received none. Example: `x.gatherValues(.a, indices, .{{}})`", .{});
for (coord_axes_.constSlice(), 0..) |a, i| { for (coord_axes_.constSlice(), 0..) |a, i| {
if (i > 0) { if (i > 0) {
stdx.debug.assert(a == coord_axes_.get(i - 1) + 1, "gatherValues expects 'coord_axes' to be sequential. But {any} aren't sequential in {}", .{ coord_axes, self }); stdx.debug.assert(a == coord_axes_.get(i - 1) + 1, "gatherValues expects 'coord_axes' to be sequential. But {any} aren't sequential in {f}", .{ coord_axes, self });
} }
} }
const AxisKind = enum { batching, offset, collapsed, indices }; const AxisKind = enum { batching, offset, collapsed, indices };
var self_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{}; var self_kind: stdx.BoundedArray(AxisKind, MAX_RANK) = .{};
var indices_batch_axes: Shape.DimsArray = .{}; var indices_batch_axes: Shape.DimsArray = .{};
for (self._shape.tags(), 0..self.rank()) |t, self_ax| { for (self._shape.tags(), 0..self.rank()) |t, self_ax| {
const maybe_coord_ax = std.mem.indexOfScalar(u3, coord_axes_.constSlice(), @intCast(self_ax)); const maybe_coord_ax = std.mem.indexOfScalar(u3, coord_axes_.constSlice(), @intCast(self_ax));
@ -2191,7 +2190,7 @@ pub const Tensor = struct {
// Note: tags are required for batching. // Note: tags are required for batching.
self_kind.appendAssumeCapacity(.batching); self_kind.appendAssumeCapacity(.batching);
indices_batch_axes.appendAssumeCapacity(id_ax); indices_batch_axes.appendAssumeCapacity(id_ax);
stdx.debug.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. Axis {s} has been found both in 'self={any}', in 'coord_axes_={any}' and in 'indices={}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices }); stdx.debug.assert(maybe_coord_ax == null, "gatherValues expects axes to appear at most twice. Axis {s} has been found both in 'self={f}', in 'coord_axes_={any}' and in 'indices={f}'", .{ self._shape._tags.get(self_ax), self, coord_axes, indices });
} else if (maybe_coord_ax) |_| { } else if (maybe_coord_ax) |_| {
// for gatherValues we collapsed all gathered axes // for gatherValues we collapsed all gathered axes
// (contrary to gatherSlices where we collapse none) // (contrary to gatherSlices where we collapse none)
@ -2208,13 +2207,13 @@ pub const Tensor = struct {
indices.rank() indices.rank()
else blk: { else blk: {
const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1); const ax = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
stdx.debug.assert(indices.dim(ax) == coord_axes_.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {}", .{ coord_axes, coord_axes_.len, indices }); stdx.debug.assert(indices.dim(ax) == coord_axes_.len, "gatherValues with axes={any}, expects indices to be of shape [..., {}], got: {f}", .{ coord_axes, coord_axes_.len, indices });
break :blk ax; break :blk ax;
}; };
// compute res shape // compute res shape
var res_shape = Shape.init(.{}, self.dtype()); var res_shape = Shape.init(.{}, self.dtype());
var res_kind: std.BoundedArray(AxisKind, MAX_RANK) = .{}; var res_kind: stdx.BoundedArray(AxisKind, MAX_RANK) = .{};
for (self_kind.constSlice(), 0..) |kind, ax_usize| { for (self_kind.constSlice(), 0..) |kind, ax_usize| {
const ax: u3 = @intCast(ax_usize); const ax: u3 = @intCast(ax_usize);
if (ax == coord_axes_.get(0)) { if (ax == coord_axes_.get(0)) {
@ -2275,7 +2274,7 @@ pub const Tensor = struct {
); );
const mlir_shape = fromMlirValue(gather_op.result(0)).shape(); const mlir_shape = fromMlirValue(gather_op.result(0)).shape();
stdx.debug.assert(mlir_shape.eql(res_shape), "gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={}, indices={}. You should transpose one or the other.", .{ self, indices }); stdx.debug.assert(mlir_shape.eql(res_shape), "gatherValues expects that batching indices appear in the same order in 'self' and 'indices', got: self={f}, indices={f}. You should transpose one or the other.", .{ self, indices });
return _result(res_shape, gather_op.result(0)); return _result(res_shape, gather_op.result(0));
} }
@ -2347,29 +2346,29 @@ pub const Tensor = struct {
/// and gatherSlices can copy data by groups of C'*D elements. /// and gatherSlices can copy data by groups of C'*D elements.
pub fn gatherSlices(self: Tensor, slice_shape_: anytype, indices: Tensor, opts: GatherOpts) Tensor { pub fn gatherSlices(self: Tensor, slice_shape_: anytype, indices: Tensor, opts: GatherOpts) Tensor {
const slice_shape = if (@TypeOf(slice_shape_) == Shape) slice_shape_ else Shape.init(slice_shape_, .i32); const slice_shape = if (@TypeOf(slice_shape_) == Shape) slice_shape_ else Shape.init(slice_shape_, .i32);
// scoped_log.debug("gatherSlice({}, {_}, {})", .{ self, slice_shape, indices }); // scoped_log.debug("gatherSlice({}, {f}, {})", .{ self, slice_shape, indices });
const tagged_api = slice_shape.isFullyTagged(); const tagged_api = slice_shape.isFullyTagged();
if (tagged_api) { if (tagged_api) {
for (slice_shape.tags()) |t| { for (slice_shape.tags()) |t| {
stdx.debug.assert(self._shape.hasTag(t) != null, "gatherSlices expects `slices_shape` to only use tags from `self`. But {s} wasn't found in {}", .{ t, self }); stdx.debug.assert(self._shape.hasTag(t) != null, "gatherSlices expects `slices_shape` to only use tags from `self`. But {s} wasn't found in {f}", .{ t, self });
} }
} else { } else {
// For untagged api, we require all slices to be specified. // For untagged api, we require all slices to be specified.
// Note: we could relax this and right align the slice. // Note: we could relax this and right align the slice.
stdx.debug.assert(slice_shape.rank() == self.rank(), "gatherSlices expects `slice_shape.rank()` to match `self.rank()`. Got: gatherSlices({}, slice={_}). To avoid specifying all axes in `slice_shape`, you can use tags.", .{ self, slice_shape }); stdx.debug.assert(slice_shape.rank() == self.rank(), "gatherSlices expects `slice_shape.rank()` to match `self.rank()`. Got: gatherSlices({f}, slice={f}). To avoid specifying all axes in `slice_shape`, you can use tags.", .{ self, slice_shape });
} }
const index_coord_axis = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1); const index_coord_axis = indices._shape.hasTag(.coord) orelse indices._shape.axis(-1);
stdx.debug.assert(indices.dim(index_coord_axis) == slice_shape.rank(), "gatherSlices({}, slice={_}, indices) expects 'indices' to be a tensor [..., {}], got {}", .{ self, slice_shape, slice_shape.rank(), indices }); stdx.debug.assert(indices.dim(index_coord_axis) == slice_shape.rank(), "gatherSlices({f}, slice={f}, indices) expects 'indices' to be a tensor [..., {}], got {f}", .{ self, slice_shape, slice_shape.rank(), indices });
// Compute result shape // Compute result shape
var res_shape = indices._shape.remove(index_coord_axis).withDtype(self.dtype()); var res_shape = indices._shape.remove(index_coord_axis).withDtype(self.dtype());
var slice_dims = self._shape._dims; var slice_dims = self._shape._dims;
var self_batch_axes: std.BoundedArray(i64, MAX_RANK) = .{}; var self_batch_axes: stdx.BoundedArray(i64, MAX_RANK) = .{};
var indices_batch_axes: std.BoundedArray(i64, MAX_RANK) = .{}; var indices_batch_axes: stdx.BoundedArray(i64, MAX_RANK) = .{};
var start_index_map: std.BoundedArray(i64, MAX_RANK) = .{}; var start_index_map: stdx.BoundedArray(i64, MAX_RANK) = .{};
var self_offset_axes: std.BoundedArray(i64, MAX_RANK) = .{}; var self_offset_axes: stdx.BoundedArray(i64, MAX_RANK) = .{};
for (self._shape.tags(), 0..self.rank()) |t, self_ax| { for (self._shape.tags(), 0..self.rank()) |t, self_ax| {
const maybe_slice_ax: ?u3 = if (tagged_api) slice_shape.hasTag(t) else @intCast(self_ax); const maybe_slice_ax: ?u3 = if (tagged_api) slice_shape.hasTag(t) else @intCast(self_ax);
@ -2379,12 +2378,12 @@ pub const Tensor = struct {
self_batch_axes.appendAssumeCapacity(@intCast(self_ax)); self_batch_axes.appendAssumeCapacity(@intCast(self_ax));
indices_batch_axes.appendAssumeCapacity(indices._shape.axis(t)); indices_batch_axes.appendAssumeCapacity(indices._shape.axis(t));
slice_dims.set(self_ax, 1); slice_dims.set(self_ax, 1);
stdx.debug.assert(slice_shape.hasTag(t) == null, "gatherSlices expect axes to be either batches or slices axes. Axis {s} has been found both in `slices={_}` and `indices={}`", .{ t, slice_shape, indices }); stdx.debug.assert(slice_shape.hasTag(t) == null, "gatherSlices expect axes to be either batches or slices axes. Axis {s} has been found both in `slices={f}` and `indices={f}`", .{ t, slice_shape, indices });
} else if (maybe_slice_ax) |slice_ax| { } else if (maybe_slice_ax) |slice_ax| {
// Specified axes contain the start offsets of the slices, // Specified axes contain the start offsets of the slices,
// and are collected in `start_index_map`. // and are collected in `start_index_map`.
const slice_dim = slice_shape.dim(slice_ax); const slice_dim = slice_shape.dim(slice_ax);
stdx.debug.assert(slice_dim <= self._shape.dim(self_ax), "gatherSlices expects `slice_shape` to be smaller than `self.shape()`. On axis {s}, got {} > {}.", .{ t, slice_shape, self._shape }); stdx.debug.assert(slice_dim <= self._shape.dim(self_ax), "gatherSlices expects `slice_shape` to be smaller than `self.shape()`. On axis {s}, got {f} > {f}.", .{ t, slice_shape, self._shape });
slice_dims.set(self_ax, slice_dim); slice_dims.set(self_ax, slice_dim);
res_shape = res_shape.appendDim(slice_dim, t); res_shape = res_shape.appendDim(slice_dim, t);
start_index_map.appendAssumeCapacity(@intCast(self_ax)); start_index_map.appendAssumeCapacity(@intCast(self_ax));
@ -2396,7 +2395,7 @@ pub const Tensor = struct {
} }
} }
const loc = self.getContext().location(@src(), "gatherSlices({_}, slice_shape={_}, idx={_})", .{ self, slice_shape, indices }); const loc = self.getContext().location(@src(), "gatherSlices({f}, slice_shape={f}, idx={f})", .{ self, slice_shape, indices });
const gather_op = dialect.stablehlo.gather( const gather_op = dialect.stablehlo.gather(
self.getContext().mlirCtx(), self.getContext().mlirCtx(),
self.value(), self.value(),
@ -3172,7 +3171,7 @@ pub const Tensor = struct {
/// Note: this doesn't support tagging; if you have tags, /// Note: this doesn't support tagging; if you have tags,
/// you should use `dynamicSlice` directly. /// you should use `dynamicSlice` directly.
pub fn dynamicSlice1d(self: Tensor, axis_: i8, slice_: DynSlice) Tensor { pub fn dynamicSlice1d(self: Tensor, axis_: i8, slice_: DynSlice) Tensor {
stdx.debug.assert(slice_.start.rank() == 0, "dynamicSlice1d expects 'slice_.start' tensor rank to be a scalar, got {}", .{slice_.start}); stdx.debug.assert(slice_.start.rank() == 0, "dynamicSlice1d expects 'slice_.start' tensor rank to be a scalar, got {f}", .{slice_.start});
const a = self.axis(axis_); const a = self.axis(axis_);
const new_shape = self._shape.set(a, slice_.len); const new_shape = self._shape.set(a, slice_.len);
@ -3226,17 +3225,17 @@ pub const Tensor = struct {
const offset = slice_.start; const offset = slice_.start;
const len = slice_.len; const len = slice_.len;
if (slices_tags.len == 0) { if (slices_tags.len == 0) {
stdx.debug.assert(self.rank() == slices.len, "dynamicSlice expects tensor rank and 'slices_' length to be equal, got {} and {}", .{ self.rank(), slices.len }); stdx.debug.assert(self.rank() == slices.len, "dynamicSlice expects tensor rank and 'slices_' length to be equal, got {d} and {d}", .{ self.rank(), slices.len });
offset_values[i] = offset.value(); offset_values[i] = offset.value();
res_shape._dims.set(i, len); res_shape._dims.set(i, len);
stdx.debug.assert(len <= self.dim(i), "dynamicSlice expects slices 'len' to be less than or equal to their corresponding dimension in input tensor, got {} and {} for index {}", .{ len, self.dim(i), i }); stdx.debug.assert(len <= self.dim(i), "dynamicSlice expects slices 'len' to be less than or equal to their corresponding dimension in input tensor, got {d} and {d} for index {d}", .{ len, self.dim(i), i });
} else { } else {
const t = slices_tags.get(i); const t = slices_tags.get(i);
const a = res_shape.hasTag(t) orelse stdx.debug.panic("dynamicSlice expects input tensor to have tags used in 'slices_' but {s} is missing (input shape is {})", .{ t, self._shape }); const a = res_shape.hasTag(t) orelse stdx.debug.panic("dynamicSlice expects input tensor to have tags used in 'slices_' but {s} is missing (input shape is {f})", .{ t, self._shape });
stdx.debug.assert(len <= self.dim(a), "dynamicSlice expects slices 'len' to be less than or equal to their corresponding dimension in input tensor, got {} and {} for axis {s}", .{ len, self.dim(a), t }); stdx.debug.assert(len <= self.dim(a), "dynamicSlice expects slices 'len' to be less than or equal to their corresponding dimension in input tensor, got {d} and {d} for axis {s}", .{ len, self.dim(a), t });
offset_values[a] = offset.value(); offset_values[a] = offset.value();
res_shape._dims.set(a, len); res_shape._dims.set(a, len);
@ -3304,14 +3303,14 @@ pub const Tensor = struct {
if (tagged_api) { if (tagged_api) {
// Check that all update tags are known. // Check that all update tags are known.
for (update._shape._tags.constSlice()) |t| { for (update._shape._tags.constSlice()) |t| {
stdx.debug.assert(self._shape.hasTag(t) != null, "dynamicUpdateSlice expects 'update_' tensor tags to be a subset of input tensor tags but {s} is missing (input shape is {})", .{ t, self._shape }); stdx.debug.assert(self._shape.hasTag(t) != null, "dynamicUpdateSlice expects 'update_' tensor tags to be a subset of input tensor tags but {s} is missing (input shape is {f})", .{ t, self._shape });
} }
var update_shape = self._shape; var update_shape = self._shape;
var prev_ax: i8 = -1; var prev_ax: i8 = -1;
for (self._shape.tags(), 0..) |t, self_ax| { for (self._shape.tags(), 0..) |t, self_ax| {
if (update._shape.hasTag(t)) |up_ax| { if (update._shape.hasTag(t)) |up_ax| {
stdx.debug.assert(up_ax == prev_ax + 1, "dynamicUpdateSlice expects 'update_' and input tensor axis to have the same order, got {} and {}. (hint: you need to explicitly transpose 'update_')", .{ update_, self }); stdx.debug.assert(up_ax == prev_ax + 1, "dynamicUpdateSlice expects 'update_' and input tensor axis to have the same order, got {f} and {f}. (hint: you need to explicitly transpose 'update_')", .{ update_, self });
update_shape._dims.set(self_ax, update.dim(up_ax)); update_shape._dims.set(self_ax, update.dim(up_ax));
prev_ax = up_ax; prev_ax = up_ax;
@ -3322,7 +3321,7 @@ pub const Tensor = struct {
update = update.reshape(update_shape); update = update.reshape(update_shape);
} }
stdx.debug.assert(self.rank() == update.rank(), "dynamicUpdateSlice expects input and computed update tensors to have the same rank, got {} and {}", .{ self, update }); stdx.debug.assert(self.rank() == update.rank(), "dynamicUpdateSlice expects input and computed update tensors to have the same rank, got {f} and {f}", .{ self, update });
for (self.dims(), update.dims(), 0..) |self_d, up_d, ax| { for (self.dims(), update.dims(), 0..) |self_d, up_d, ax| {
const t = self._shape.debugTag(ax); const t = self._shape.debugTag(ax);
@ -3350,7 +3349,7 @@ pub const Tensor = struct {
// This is only allowed when using tagged slices. // This is only allowed when using tagged slices.
offset_values = .{zero} ** MAX_RANK; offset_values = .{zero} ** MAX_RANK;
for (offset.constSlice(), offset_tags.constSlice()) |start, t| { for (offset.constSlice(), offset_tags.constSlice()) |start, t| {
const a = self._shape.hasTag(t) orelse stdx.debug.panic("dynamicUpdateSlice expects input tensor to have tags used in 'offset_' but {s} is missing (input shape is {})", .{ t, self._shape }); const a = self._shape.hasTag(t) orelse stdx.debug.panic("dynamicUpdateSlice expects input tensor to have tags used in 'offset_' but {s} is missing (input shape is {f})", .{ t, self._shape });
offset_values[a] = start.value(); offset_values[a] = start.value();
} }
} }
@ -3469,12 +3468,12 @@ pub const Tensor = struct {
/// Returns a Tensor containing the element-wise result of the given 'cmp' comparison between the two input Tensors. /// Returns a Tensor containing the element-wise result of the given 'cmp' comparison between the two input Tensors.
pub fn cmp(self: Tensor, direction: dialect.stablehlo.ComparisonDirection.Direction, other: Tensor) Tensor { pub fn cmp(self: Tensor, direction: dialect.stablehlo.ComparisonDirection.Direction, other: Tensor) Tensor {
stdx.debug.assert(self.dtype() == other.dtype(), "cmp expects input tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() }); stdx.debug.assert(self.dtype() == other.dtype(), "cmp expects input tensors to be of the same type, got {t} and {t}", .{ self.dtype(), other.dtype() });
if (self.rank() == 0 and other.rank() != 0) return self.broadcast(other._shape, &.{}).cmp(direction, other); if (self.rank() == 0 and other.rank() != 0) return self.broadcast(other._shape, &.{}).cmp(direction, other);
if (self.rank() != 0 and other.rank() == 0) return self.cmp(direction, other.broadcast(self._shape, &.{})); if (self.rank() != 0 and other.rank() == 0) return self.cmp(direction, other.broadcast(self._shape, &.{}));
stdx.debug.assert(self._shape.eql(other._shape), "cmp expects input tensor shapes to match, got {} and {}", .{ self._shape, other._shape }); stdx.debug.assert(self._shape.eql(other._shape), "cmp expects input tensor shapes to match, got {f} and {f}", .{ self._shape, other._shape });
const loc = self.getContext().location(@src(), "cmp(.{s})", .{@tagName(direction)}); const loc = self.getContext().location(@src(), "cmp(.{s})", .{@tagName(direction)});
const op = dialect.stablehlo.compare( const op = dialect.stablehlo.compare(
@ -3492,7 +3491,7 @@ pub const Tensor = struct {
/// For each vector in the input tensor, /// For each vector in the input tensor,
/// creates a diagonal-matrix where diagonal values are set to the vector values. /// creates a diagonal-matrix where diagonal values are set to the vector values.
pub fn toDiagonal(self: Tensor, axis_: anytype, new_tags: [2]EnumLiteral) Tensor { pub fn toDiagonal(self: Tensor, axis_: anytype, new_tags: [2]EnumLiteral) Tensor {
stdx.debug.assert(self.rank() < MAX_RANK - 1, "toDiagonal expects input of rank up to {}, got {}", .{ MAX_RANK - 1, self }); stdx.debug.assert(self.rank() < MAX_RANK - 1, "toDiagonal expects input of rank up to {d}, got {f}", .{ MAX_RANK - 1, self });
const a = self.axis(axis_); const a = self.axis(axis_);
const d = self.dim(a); const d = self.dim(a);
var res_shape = self._shape; var res_shape = self._shape;
@ -3622,8 +3621,8 @@ pub const Tensor = struct {
return bool_tensor.select(on_true, on_false.broad(bool_tensor.shape())); return bool_tensor.select(on_true, on_false.broad(bool_tensor.shape()));
} }
stdx.debug.assert(bool_tensor._shape.eqlDims(on_true._shape), "select expects input tensor and 'on_true' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_true._shape }); stdx.debug.assert(bool_tensor._shape.eqlDims(on_true._shape), "select expects input tensor and 'on_true' tensor dimensions to match, got {f} and {f}", .{ bool_tensor._shape, on_true._shape });
stdx.debug.assert(bool_tensor._shape.eqlDims(on_false._shape), "select expects input tensor and 'on_false' tensor dimensions to match, got {} and {}", .{ bool_tensor._shape, on_false._shape }); stdx.debug.assert(bool_tensor._shape.eqlDims(on_false._shape), "select expects input tensor and 'on_false' tensor dimensions to match, got {f} and {f}", .{ bool_tensor._shape, on_false._shape });
const loc = bool_tensor.getContext().mlirCtx().location(@src()); const loc = bool_tensor.getContext().mlirCtx().location(@src());
const op = dialect.stablehlo.select( const op = dialect.stablehlo.select(
@ -3762,7 +3761,7 @@ pub const Tensor = struct {
/// ///
/// - res[a, b, c, d] == (A[a], B[b], C[c], D[d]) /// - res[a, b, c, d] == (A[a], B[b], C[c], D[d])
pub fn cartesianProductStacked(vectors: []const Tensor) Tensor { pub fn cartesianProductStacked(vectors: []const Tensor) Tensor {
var out = std.BoundedArray(Tensor, Tensor.MAX_RANK).init(vectors.len) catch unreachable; var out = stdx.BoundedArray(Tensor, Tensor.MAX_RANK).init(vectors.len) catch unreachable;
_cartesianProduct(vectors, out.slice()); _cartesianProduct(vectors, out.slice());
return Tensor.stack(out.constSlice(), .last, .coord); return Tensor.stack(out.constSlice(), .last, .coord);
@ -3801,7 +3800,7 @@ pub const Tensor = struct {
) fn (Tensor, Tensor) Tensor { ) fn (Tensor, Tensor) Tensor {
return struct { return struct {
pub fn binaryOpHelper(self: Tensor, other: Tensor) Tensor { pub fn binaryOpHelper(self: Tensor, other: Tensor) Tensor {
stdx.debug.assert(self.dtype() == other.dtype(), "{s} expects tensor to be of same type, got {} and {}", .{ op_name, self, other }); stdx.debug.assert(self.dtype() == other.dtype(), "{s} expects tensor to be of same type, got {f} and {f}", .{ op_name, self, other });
if (self.rank() == 0 and other.rank() != 0) { if (self.rank() == 0 and other.rank() != 0) {
return binaryOpHelper(self.broad(other._shape), other); return binaryOpHelper(self.broad(other._shape), other);
@ -3811,10 +3810,10 @@ pub const Tensor = struct {
return binaryOpHelper(self, other.broad(self._shape)); return binaryOpHelper(self, other.broad(self._shape));
} }
stdx.debug.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {} and {}", .{ op_name, self._shape, other._shape }); stdx.debug.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {f} and {f}", .{ op_name, self._shape, other._shape });
const ctx = self.getContext(); const ctx = self.getContext();
const location = ctx.location(src, "{s}({_}, {_})", .{ op_name, self, other }); const location = ctx.location(src, "{s}({f}, {f})", .{ op_name, self, other });
const ret = @call(.auto, op_fn, .{ ctx.mlirCtx(), self.value(), other.value(), location }); const ret = @call(.auto, op_fn, .{ ctx.mlirCtx(), self.value(), other.value(), location });
return _result(self._shape, ret.result(0)); return _result(self._shape, ret.result(0));
} }
@ -3837,7 +3836,7 @@ pub const Tensor = struct {
fn printCallback(_: ?*anyopaque, inputs: []const HostBuffer, outputs: []const HostBuffer) void { fn printCallback(_: ?*anyopaque, inputs: []const HostBuffer, outputs: []const HostBuffer) void {
const host_buffer = inputs[0]; const host_buffer = inputs[0];
std.log.defaultLog(.info, .zml, "Device buffer: {}: {}", .{ host_buffer.shape(), host_buffer.pretty() }); std.log.defaultLog(.info, .zml, "Device buffer: {f}: {f}", .{ host_buffer.shape(), host_buffer.pretty() });
// This is true because of the operand aliases. // This is true because of the operand aliases.
// Since the result is already pointing to the input we don't need to modify the buffer. // Since the result is already pointing to the input we don't need to modify the buffer.
std.debug.assert(host_buffer._data == outputs[0]._data); std.debug.assert(host_buffer._data == outputs[0]._data);
@ -4060,8 +4059,8 @@ test shapesOf {
} }
} }
pub fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK), value: T) std.BoundedArray(i64, Tensor.MAX_RANK) { pub fn _collectAxes(T: type, bounded_array: stdx.BoundedArray(T, Tensor.MAX_RANK), value: T) stdx.BoundedArray(i64, Tensor.MAX_RANK) {
var res: std.BoundedArray(i64, Tensor.MAX_RANK) = .{}; var res: stdx.BoundedArray(i64, Tensor.MAX_RANK) = .{};
for (bounded_array.constSlice(), 0..) |v, ax| { for (bounded_array.constSlice(), 0..) |v, ax| {
if (v == value) { if (v == value) {
res.appendAssumeCapacity(@intCast(ax)); res.appendAssumeCapacity(@intCast(ax));
@ -4070,12 +4069,12 @@ pub fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK)
return res; return res;
} }
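`std.BoundedArray` was removed from the standard library in 0.15, which is what the `stdx.BoundedArray` renames throughout this commit compensate for. Judging from the call sites, `stdx.BoundedArray` keeps the old std surface (`appendAssumeCapacity`, `constSlice`, `fromSlice`, ...); a minimal usage sketch under that assumption:

const std = @import("std");
const stdx = @import("stdx");

test "stdx.BoundedArray keeps the removed std.BoundedArray API" {
    // Assumption: same surface as the 0.14 std.BoundedArray, as the call sites suggest.
    var axes: stdx.BoundedArray(i64, 8) = .{};
    axes.appendAssumeCapacity(3);
    axes.appendAssumeCapacity(5);
    try std.testing.expectEqualSlices(i64, &.{ 3, 5 }, axes.constSlice());
}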
fn _parseGatherCoord(self: Tensor, axes_: anytype) struct { bool, std.BoundedArray(u3, Tensor.MAX_RANK) } { fn _parseGatherCoord(self: Tensor, axes_: anytype) struct { bool, stdx.BoundedArray(u3, Tensor.MAX_RANK) } {
const AxesT = @TypeOf(axes_); const AxesT = @TypeOf(axes_);
const axes_is_scalar = AxesT == EnumLiteral or AxesT == comptime_int or @typeInfo(AxesT) == .int; const axes_is_scalar = AxesT == EnumLiteral or AxesT == comptime_int or @typeInfo(AxesT) == .int;
const coord_axes = if (axes_is_scalar) const coord_axes = if (axes_is_scalar)
std.BoundedArray(u3, Tensor.MAX_RANK).fromSlice(&.{self.axis(axes_)}) catch unreachable stdx.BoundedArray(u3, Tensor.MAX_RANK).fromSlice(&.{self.axis(axes_)}) catch unreachable
else else
self.axes(axes_); self.axes(axes_);
@ -4099,7 +4098,7 @@ inline fn toI64(values: anytype) []i64 {
} }
fn transposeIsJustAReshape(x: Shape, permutation: []const i64) bool { fn transposeIsJustAReshape(x: Shape, permutation: []const i64) bool {
var perm: std.BoundedArray(struct { u8, bool }, Tensor.MAX_RANK) = .{}; var perm: stdx.BoundedArray(struct { u8, bool }, Tensor.MAX_RANK) = .{};
// Don't rewrite on invalid inputs. // Don't rewrite on invalid inputs.
if (permutation.len > x.rank()) return false; if (permutation.len > x.rank()) return false;
for (permutation) |ax| { for (permutation) |ax| {

View File

@ -31,7 +31,7 @@ pub fn asyncMain() !void {
.root_name = "Test", .root_name = "Test",
.estimated_total_items = test_fn_list.len, .estimated_total_items = test_fn_list.len,
}); });
const have_tty = std.io.getStdErr().isTty(); const have_tty = std.fs.File.stderr().isTty();
var args = std.process.args(); var args = std.process.args();
// Skip executable path // Skip executable path
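`std.io.getStdErr()` and friends are gone in 0.15; standard streams are now plain constructors on `std.fs.File`, and writers take a caller-owned buffer, exposing the generic `*std.Io.Writer` through `.interface`. A short sketch of the new pattern (buffer size is arbitrary):

const std = @import("std");

pub fn main() !void {
    const stderr_file = std.fs.File.stderr();
    _ = stderr_file.isTty(); // isTty() still lives on std.fs.File
    var buf: [256]u8 = undefined;
    var w = stderr_file.writer(&buf);
    try w.interface.print("hello from stderr\n", .{});
    try w.interface.flush(); // buffered: nothing reaches the fd until flush
}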

View File

@ -51,10 +51,10 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
if (should_free_left) left.deinit(allocator); if (should_free_left) left.deinit(allocator);
if (should_free_right) right.deinit(allocator); if (should_free_right) right.deinit(allocator);
} }
errdefer log.err("\n--> Left: {}\n--> Right: {}", .{ left.pretty(), right.pretty() }); errdefer log.err("\n--> Left: {f}\n--> Right: {f}", .{ left.pretty(), right.pretty() });
if (!std.mem.eql(i64, left.shape().dims(), right.shape().dims())) { if (!std.mem.eql(i64, left.shape().dims(), right.shape().dims())) {
log.err("left.shape() {} != right.shape() {}", .{ left.shape(), right.shape() }); log.err("left.shape() {f} != right.shape() {f}", .{ left.shape(), right.shape() });
return error.TestUnexpectedResult; return error.TestUnexpectedResult;
} }
if (left.dtype() != right.dtype() and !(left.dtype() == .f16 and right.dtype() == .bf16)) { if (left.dtype() != right.dtype() and !(left.dtype() == .f16 and right.dtype() == .bf16)) {
@ -89,7 +89,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
const right_data = right.items(R); const right_data = right.items(R);
for (left_data, right_data, 0..) |l, r, i| { for (left_data, right_data, 0..) |l, r, i| {
if (!approxEq(f32, zml.floats.floatCast(f32, l), zml.floats.floatCast(f32, r), tolerance)) { if (!approxEq(f32, zml.floats.floatCast(f32, l), zml.floats.floatCast(f32, r), tolerance)) {
log.err("left.data != right_data.\n < {d:.3} \n > {d:.3}\n error at idx {d}: {d:.3} != {d:.3}", .{ center(left_data, i), center(right_data, i), i, left_data[i], right_data[i] }); log.err("left.data != right_data.\n < {any:.3} \n > {any:.3}\n error at idx {any}: {any:.3} != {any:.3}", .{ center(left_data, i), center(right_data, i), i, left_data[i], right_data[i] });
return error.TestUnexpectedResult; return error.TestUnexpectedResult;
} }
} }
@ -108,7 +108,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
pub fn expectEqualShapes(expected: zml.Shape, actual: zml.Shape) error{TestExpectedEqual}!void { pub fn expectEqualShapes(expected: zml.Shape, actual: zml.Shape) error{TestExpectedEqual}!void {
if (expected.eqlWithTags(actual)) return; if (expected.eqlWithTags(actual)) return;
std.debug.print("Expected {}, got {}", .{ expected, actual }); std.debug.print("Expected {f}, got {f}", .{ expected, actual });
return error.TestExpectedEqual; return error.TestExpectedEqual;
} }

View File

@ -9,6 +9,7 @@ zig_library(
deps = [ deps = [
"//async", "//async",
"//ffi:zig", "//ffi:zig",
"//stdx",
"//zml/tokenizer/hftokenizers", "//zml/tokenizer/hftokenizers",
"//zml/tokenizer/sentencepiece", "//zml/tokenizer/sentencepiece",
], ],

View File

@ -27,5 +27,6 @@ zig_library(
deps = [ deps = [
":hftokenizers_cc", ":hftokenizers_cc",
"//ffi:zig", "//ffi:zig",
"//stdx",
], ],
) )

View File

@ -1,6 +1,8 @@
const std = @import("std"); const std = @import("std");
const c = @import("c"); const c = @import("c");
const ffi = @import("ffi"); const ffi = @import("ffi");
const stdx = @import("stdx");
pub const Encoder = struct { pub const Encoder = struct {
inner: *HFTokenizer, inner: *HFTokenizer,
@ -33,8 +35,8 @@ pub const Encoder = struct {
}; };
pub const Decoder = struct { pub const Decoder = struct {
const StringBuffer = std.BoundedArray(u8, 128); const StringBuffer = stdx.BoundedArray(u8, 128);
const TokensIdsBuffer = std.BoundedArray(u32, 4); const TokensIdsBuffer = stdx.BoundedArray(u32, 4);
inner: *HFTokenizer, inner: *HFTokenizer,
current_string: ?[]const u8 = null, current_string: ?[]const u8 = null,

View File

@ -2,10 +2,11 @@
//! Disclaimer this is not a very robust implementation: //! Disclaimer this is not a very robust implementation:
//! In particular the normalization is pretty minimalist, only works with ascii, and doesn't do unicode normalization. //! In particular the normalization is pretty minimalist, only works with ascii, and doesn't do unicode normalization.
//! Mostly used for testing models that don't have an official HF/sentencepiece tokenizer. //! Mostly used for testing models that don't have an official HF/sentencepiece tokenizer.
const builtin = @import("builtin");
const std = @import("std"); const std = @import("std");
const testing = std.testing; const testing = std.testing;
const builtin = @import("builtin");
const stdx = @import("stdx");
const log = std.log.scoped(.@"zml/tokenizer"); const log = std.log.scoped(.@"zml/tokenizer");
@ -87,12 +88,11 @@ pub const Tokenizer = struct {
} }
/// Reads a new word directly into the tokenizer arena. /// Reads a new word directly into the tokenizer arena.
pub fn readTokenInto(self: *Tokenizer, score: f32, len: usize, tok_reader: anytype) !void { pub fn readTokenInto(self: *Tokenizer, score: f32, len: usize, tok_reader: *std.Io.Reader) !void {
const arena = self.arena_state.allocator(); const arena = self.arena_state.allocator();
const token = try arena.alloc(u8, len); const token = try arena.alloc(u8, len);
const n = try tok_reader.readAll(token); try tok_reader.readSliceAll(token);
std.debug.assert(n == len);
return self.addOwnedToken(score, token); return self.addOwnedToken(score, token);
} }
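`readTokenInto` now takes a concrete `*std.Io.Reader` instead of an `anytype` reader, and `readSliceAll` replaces the old readAll-plus-length-assert dance: it fills the whole destination or returns an error. A sketch against an in-memory reader:

const std = @import("std");

test "readSliceAll fills the destination or fails" {
    var r: std.Io.Reader = .fixed("abcdef");
    var token: [3]u8 = undefined;
    try r.readSliceAll(&token); // no byte count to double-check afterwards
    try std.testing.expectEqualStrings("abc", &token);
}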
@ -190,9 +190,9 @@ pub const Tokenizer = struct {
if (options.debug) { if (options.debug) {
var _debug_buf: [256]u8 = undefined; var _debug_buf: [256]u8 = undefined;
var _debug_alloc = std.heap.FixedBufferAllocator.init(&_debug_buf); var _debug_alloc = std.heap.FixedBufferAllocator.init(&_debug_buf);
var debug_progress = std.ArrayList(u8).init(_debug_alloc.allocator()); var debug_progress = std.array_list.Managed(u8).init(_debug_alloc.allocator());
self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {}; self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {};
log.debug("tokens: {d} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items }); log.debug("tokens: {any} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items });
} }
var best_score: f32 = -1e10; var best_score: f32 = -1e10;
var best_token: u32 = 0; var best_token: u32 = 0;
@ -312,7 +312,7 @@ pub const Tokenizer = struct {
/// Note that if the tokenizer allows sub-unicode bytes, it's possible /// Note that if the tokenizer allows sub-unicode bytes, it's possible
/// the output is not valid utf8. /// the output is not valid utf8.
pub fn decode(self: *const Tokenizer, allocator: std.mem.Allocator, input: []const u32) error{OutOfMemory}![]u8 { pub fn decode(self: *const Tokenizer, allocator: std.mem.Allocator, input: []const u32) error{OutOfMemory}![]u8 {
var output = std.ArrayList(u8).init(allocator); var output = std.array_list.Managed(u8).init(allocator);
errdefer output.deinit(); errdefer output.deinit();
try self.decodeWithOpts(&output, input, .{}); try self.decodeWithOpts(&output, input, .{});
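In 0.15, `std.ArrayList` is the unmanaged variant and the old allocator-carrying type lives at `std.array_list.Managed`, hence the mechanical renames in this file. The managed API itself is unchanged:

const std = @import("std");

test "managed ArrayList under its 0.15 name" {
    var list = std.array_list.Managed(u8).init(std.testing.allocator);
    defer list.deinit(); // the list still owns its allocator
    try list.appendSlice("abc");
    try std.testing.expectEqualStrings("abc", list.items);
}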
@ -321,7 +321,7 @@ pub const Tokenizer = struct {
pub fn decodeWithOpts( pub fn decodeWithOpts(
self: *const Tokenizer, self: *const Tokenizer,
output: *std.ArrayList(u8), output: *std.array_list.Managed(u8),
input: []const u32, input: []const u32,
opts: struct { sep: []const u8 = "" }, opts: struct { sep: []const u8 = "" },
) error{OutOfMemory}!void { ) error{OutOfMemory}!void {
@ -363,7 +363,8 @@ pub const Tokenizer = struct {
// First lookup the byte fallback entry. // First lookup the byte fallback entry.
// Note: we assume upper case, but we could try both upper and lower case if needed. // Note: we assume upper case, but we could try both upper and lower case if needed.
_ = std.fmt.bufPrintIntToSlice(byte_fallback_buf[3..5], c, 16, .upper, .{ .fill = '0', .width = 2 }); var writer: std.Io.Writer = .fixed(byte_fallback_buf[3..5]);
try writer.printInt(c, 16, .upper, .{ .fill = '0', .width = 2 });
const entry = tokenizer.token_lookup.getEntry(&byte_fallback_buf) orelse { const entry = tokenizer.token_lookup.getEntry(&byte_fallback_buf) orelse {
log.err("Tokenizer has \"byte_fallback\" = true, but doesn't contains the byte fallback token {s}", .{byte_fallback_buf}); log.err("Tokenizer has \"byte_fallback\" = true, but doesn't contains the byte fallback token {s}", .{byte_fallback_buf});
return error.InvalidInput; return error.InvalidInput;
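`std.fmt.bufPrintIntToSlice` was dropped; the replacement above is a fixed `std.Io.Writer` over the destination slice plus `printInt`. The same pattern, standalone:

const std = @import("std");

test "printInt into a fixed writer" {
    var buf: [2]u8 = undefined;
    var w: std.Io.Writer = .fixed(&buf);
    try w.printInt(@as(u8, 0xab), 16, .upper, .{ .fill = '0', .width = 2 });
    // For a fixed writer, buffered() is everything written so far.
    try std.testing.expectEqualStrings("AB", w.buffered());
}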
@ -443,8 +444,8 @@ pub const Encoder = struct {
}; };
pub const Decoder = struct { pub const Decoder = struct {
const StringBuffer = std.BoundedArray(u8, 128); const StringBuffer = stdx.BoundedArray(u8, 128);
const TokensIdsBuffer = std.BoundedArray(u32, 4); const TokensIdsBuffer = stdx.BoundedArray(u32, 4);
inner: *Tokenizer, inner: *Tokenizer,
arena: std.heap.ArenaAllocator, arena: std.heap.ArenaAllocator,
@ -571,7 +572,7 @@ test CharTokenIterator {
{ {
tokenizer.byte_fallback = false; tokenizer.byte_fallback = false;
var it: CharTokenIterator = .{ .input = "ζL" }; var it: CharTokenIterator = .{ .input = "ζL" };
var res: std.BoundedArray(u32, 8) = .{}; var res: stdx.BoundedArray(u32, 8) = .{};
while (try it.nextCodepointToken(&tokenizer)) |token| { while (try it.nextCodepointToken(&tokenizer)) |token| {
res.appendAssumeCapacity(token); res.appendAssumeCapacity(token);
} }
@ -582,7 +583,7 @@ test CharTokenIterator {
{ {
tokenizer.byte_fallback = true; tokenizer.byte_fallback = true;
var it: CharTokenIterator = .{ .input = "ζL" }; var it: CharTokenIterator = .{ .input = "ζL" };
var res: std.BoundedArray(u32, 8) = .{}; var res: stdx.BoundedArray(u32, 8) = .{};
while (try it.nextCodepointToken(&tokenizer)) |token| { while (try it.nextCodepointToken(&tokenizer)) |token| {
res.appendAssumeCapacity(token); res.appendAssumeCapacity(token);
} }
@ -596,7 +597,7 @@ pub const Normalizer = struct {
/// Space token used by sentencepiece derived tokenizer. /// Space token used by sentencepiece derived tokenizer.
pub const sentencepiece_space = "▁"; // \xe2\x96\x81 pub const sentencepiece_space = "▁"; // \xe2\x96\x81
_whitespace: std.BoundedArray(u8, 8) = .{}, _whitespace: stdx.BoundedArray(u8, 8) = .{},
flags: packed struct { flags: packed struct {
remove_extra_whitespaces: bool, remove_extra_whitespaces: bool,
@ -610,7 +611,7 @@ pub const Normalizer = struct {
split_on_punct_ascii: bool, split_on_punct_ascii: bool,
}, },
pub fn init(flags: std.meta.FieldType(Normalizer, .flags), escaped_whitespace: ?[]const u8) Normalizer { pub fn init(flags: @FieldType(Normalizer, "flags"), escaped_whitespace: ?[]const u8) Normalizer {
var res: Normalizer = .{ .flags = flags }; var res: Normalizer = .{ .flags = flags };
if (escaped_whitespace) |escaped| { if (escaped_whitespace) |escaped| {
res._whitespace.appendSliceAssumeCapacity(escaped); res._whitespace.appendSliceAssumeCapacity(escaped);
@ -622,7 +623,7 @@ pub const Normalizer = struct {
return if (self._whitespace.len > 1) self._whitespace.constSlice() else null; return if (self._whitespace.len > 1) self._whitespace.constSlice() else null;
} }
fn addSlice(data: []const u8, consumed: usize, normalized: *std.ArrayList(u8), normalized_to_origin: *std.ArrayList(usize)) !void { fn addSlice(data: []const u8, consumed: usize, normalized: *std.array_list.Managed(u8), normalized_to_origin: *std.array_list.Managed(usize)) !void {
try normalized.appendSlice(data); try normalized.appendSlice(data);
for (data) |_| try normalized_to_origin.append(consumed); for (data) |_| try normalized_to_origin.append(consumed);
} }
@ -672,9 +673,9 @@ pub const Normalizer = struct {
// Pre-allocate outputs // Pre-allocate outputs
const space = self.escapedSpace() orelse " "; const space = self.escapedSpace() orelse " ";
const overhead = if (self.flags.split_on_punct_ascii) space.len + 1 else space.len; const overhead = if (self.flags.split_on_punct_ascii) space.len + 1 else space.len;
var normalized = try std.ArrayList(u8).initCapacity(allocator, trimmed_input.len * overhead + 2 * space.len); var normalized = try std.array_list.Managed(u8).initCapacity(allocator, trimmed_input.len * overhead + 2 * space.len);
errdefer normalized.deinit(); errdefer normalized.deinit();
var normalized_to_origin = try std.ArrayList(usize).initCapacity(allocator, normalized.capacity); var normalized_to_origin = try std.array_list.Managed(usize).initCapacity(allocator, normalized.capacity);
errdefer normalized_to_origin.deinit(); errdefer normalized_to_origin.deinit();
// If the spec asks for it, add a whitespace at the beginning. // If the spec asks for it, add a whitespace at the beginning.
@ -965,7 +966,7 @@ test Normalizer {
/// This implementation precomputes a mapping from bytes encoded with the GPT2 algorithm /// This implementation precomputes a mapping from bytes encoded with the GPT2 algorithm
/// to utf8 bytes, and does lookups at runtime. /// to utf8 bytes, and does lookups at runtime.
pub const Gpt2TextDecoder = struct { pub const Gpt2TextDecoder = struct {
const Code = std.BoundedArray(u8, 2); const Code = stdx.BoundedArray(u8, 2);
// TODO: benchmark this is more efficient than doing the conversion at runtime. // TODO: benchmark this is more efficient than doing the conversion at runtime.
code_to_byte: std.AutoArrayHashMap(Code, u8), code_to_byte: std.AutoArrayHashMap(Code, u8),
@ -982,7 +983,7 @@ pub const Gpt2TextDecoder = struct {
var code: Code = .{ .buffer = .{ 0, 0 }, .len = 0 }; // 0-init var code: Code = .{ .buffer = .{ 0, 0 }, .len = 0 }; // 0-init
const i: u8 = @intCast(index); const i: u8 = @intCast(index);
if (isPrintableByte(i)) { if (isPrintableByte(i)) {
if (std.ascii.isASCII(i)) { if (std.ascii.isAscii(i)) {
code.appendAssumeCapacity(i); code.appendAssumeCapacity(i);
} else { } else {
const codepoint: u21 = @as(u21, @intCast(i)); const codepoint: u21 = @as(u21, @intCast(i));
@ -1005,7 +1006,7 @@ pub const Gpt2TextDecoder = struct {
/// Transform bytes representing text under the gpt2 encoding, /// Transform bytes representing text under the gpt2 encoding,
/// and write utf-8 bytes to the `unicode` buffer. /// and write utf-8 bytes to the `unicode` buffer.
pub fn decode(self: Gpt2TextDecoder, unicode: *std.ArrayList(u8), bytes: []const u8) ![]const u8 { pub fn decode(self: Gpt2TextDecoder, unicode: *std.array_list.Managed(u8), bytes: []const u8) ![]const u8 {
const start = unicode.items.len; const start = unicode.items.len;
var it = std.unicode.Utf8Iterator{ .i = 0, .bytes = bytes }; var it = std.unicode.Utf8Iterator{ .i = 0, .bytes = bytes };
while (it.nextCodepointSlice()) |codepoint| { while (it.nextCodepointSlice()) |codepoint| {
@ -1029,7 +1030,7 @@ test Gpt2TextDecoder {
var decoder = try Gpt2TextDecoder.init(testing.allocator); var decoder = try Gpt2TextDecoder.init(testing.allocator);
defer decoder.deinit(); defer decoder.deinit();
var out = std.ArrayList(u8).init(testing.allocator); var out = std.array_list.Managed(u8).init(testing.allocator);
defer out.deinit(); defer out.deinit();
// Ascii is not changed. // Ascii is not changed.
@ -1076,7 +1077,7 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
// Buffer containing all concatenated tokens. // Buffer containing all concatenated tokens.
// Reserve a big chunk, to avoid grow events, but release over-allocated memory. // Reserve a big chunk, to avoid grow events, but release over-allocated memory.
var all_tokens = try std.ArrayList(u8).initCapacity(tokenizer.arena_state.allocator(), file_content.len); var all_tokens = try std.array_list.Managed(u8).initCapacity(tokenizer.arena_state.allocator(), file_content.len);
const original_alloc = all_tokens.items.ptr; const original_alloc = all_tokens.items.ptr;
// A re-alloc event here means we have invalidated all slices inside the tokenizer. // A re-alloc event here means we have invalidated all slices inside the tokenizer.
// If this is too annoying we could switch to a custom type instead of slices. // If this is too annoying we could switch to a custom type instead of slices.
@ -1164,7 +1165,7 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
} }
/// Returns a copy of the given string, stored inside the given ArrayList. /// Returns a copy of the given string, stored inside the given ArrayList.
fn dup(buffer: *std.ArrayList(u8), str: []const u8) ![]const u8 { fn dup(buffer: *std.array_list.Managed(u8), str: []const u8) ![]const u8 {
const n = buffer.items.len; const n = buffer.items.len;
try buffer.appendSlice(str); try buffer.appendSlice(str);
return buffer.items[n..]; return buffer.items[n..];
@ -1175,7 +1176,7 @@ fn objectGet(
object: std.json.ObjectMap, object: std.json.ObjectMap,
comptime kind: std.meta.FieldEnum(std.json.Value), comptime kind: std.meta.FieldEnum(std.json.Value),
key: []const u8, key: []const u8,
) ?std.meta.FieldType(std.json.Value, kind) { ) ?@FieldType(std.json.Value, @tagName(kind)) {
const val = object.get(key) orelse return null; const val = object.get(key) orelse return null;
if (val != kind) return null; if (val != kind) return null;
return @field(val, @tagName(kind)); return @field(val, @tagName(kind));
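`std.meta.FieldType` gave way to the `@FieldType` builtin, which takes the field name as a string — hence the `@tagName(kind)` bridge above for the comptime `FieldEnum` value. A minimal sketch with an illustrative union:

const std = @import("std");

const Value = union(enum) { integer: i64, string: []const u8 };

test "@FieldType replaces std.meta.FieldType" {
    comptime std.debug.assert(@FieldType(Value, "integer") == i64);
    comptime std.debug.assert(@FieldType(Value, "string") == []const u8);
}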
@ -1184,10 +1185,11 @@ fn objectGet(
pub fn fromTinyLlamaFile(allocator: std.mem.Allocator, tokenizer_path: []const u8, vocab_size: u32) !Tokenizer { pub fn fromTinyLlamaFile(allocator: std.mem.Allocator, tokenizer_path: []const u8, vocab_size: u32) !Tokenizer {
const tokenizer_file = try std.fs.cwd().openFile(tokenizer_path, .{}); const tokenizer_file = try std.fs.cwd().openFile(tokenizer_path, .{});
defer tokenizer_file.close(); defer tokenizer_file.close();
var tok_reader = std.io.bufferedReader(tokenizer_file.reader()); var read_buff: [4096]u8 = undefined;
const r = tok_reader.reader(); var tok_reader = tokenizer_file.reader(&read_buff);
const r: *std.Io.Reader = &tok_reader.interface;
const max_token_len = try r.readInt(u32, .little); const max_token_len = try readValueLE(u32, r);
const special_tokens: Tokenizer.SpecialTokens = .{ const special_tokens: Tokenizer.SpecialTokens = .{
.unk = 0, .unk = 0,
.bos = 1, .bos = 1,
@ -1195,11 +1197,11 @@ pub fn fromTinyLlamaFile(allocator: std.mem.Allocator, tokenizer_path: []const u
}; };
var tokenizer = try Tokenizer.init(allocator, vocab_size, max_token_len, null, special_tokens, true); var tokenizer = try Tokenizer.init(allocator, vocab_size, max_token_len, null, special_tokens, true);
var i: u32 = 0; var i: u32 = 0;
while (readToken(&tokenizer, &r)) : (i += 1) { while (readToken(&tokenizer, r)) : (i += 1) {
// Pass // Pass
} else |_| { } else |_| {
if (i < vocab_size) { if (i < vocab_size) {
log.info("Read {d} words out of {?d}", .{ i, vocab_size }); log.info("Read {d} words out of {d}", .{ i, vocab_size });
} }
tokenizer.vocab_size = i; tokenizer.vocab_size = i;
} }
@ -1207,8 +1209,14 @@ pub fn fromTinyLlamaFile(allocator: std.mem.Allocator, tokenizer_path: []const u
return tokenizer; return tokenizer;
} }
fn readToken(tokenizer: *Tokenizer, tok_reader: anytype) !void { fn readToken(tokenizer: *Tokenizer, tok_reader: *std.Io.Reader) !void {
const score: f32 = @bitCast(try tok_reader.readInt(u32, .little)); const score: f32 = try readValueLE(f32, tok_reader);
const len: usize = @intCast(try tok_reader.readInt(u32, .little)); const len: usize = try readValueLE(u32, tok_reader);
try tokenizer.readTokenInto(score, len, tok_reader); try tokenizer.readTokenInto(score, len, tok_reader);
} }
fn readValueLE(T: type, reader: *std.Io.Reader) !T {
var res: [1]T = undefined;
try reader.readSliceEndian(T, &res, .little);
return res[0];
}
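`readValueLE` leans on `std.Io.Reader.readSliceEndian`, which decodes fixed-size values with an explicit byte order; a one-element slice recovers the old `readInt(T, .little)` behavior. Sketch against an in-memory reader:

const std = @import("std");

test "one-element readSliceEndian behaves like the old readInt" {
    var r: std.Io.Reader = .fixed(&.{ 0x2a, 0x00, 0x00, 0x00 });
    var out: [1]u32 = undefined;
    try r.readSliceEndian(u32, &out, .little);
    try std.testing.expectEqual(@as(u32, 42), out[0]);
}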

View File

@ -1,10 +1,11 @@
const std = @import("std"); const std = @import("std");
const log = std.log.scoped(.@"//zml/tokenizer");
const asynk = @import("async"); const asynk = @import("async");
const stdx = @import("stdx"); const stdx = @import("stdx");
const zml_tokenizer = @import("zml/tokenizer"); const zml_tokenizer = @import("zml/tokenizer");
const log = std.log.scoped(.@"//zml/tokenizer");
const Flags = struct { const Flags = struct {
tokenizer: []const u8, tokenizer: []const u8,
prompt: []const u8, prompt: []const u8,
@ -35,7 +36,7 @@ pub fn asyncMain() !void {
const prompt_tok = try encoder.encode(args.prompt); const prompt_tok = try encoder.encode(args.prompt);
log.info("Input: {s}\nOutput: {d}", .{ args.prompt, prompt_tok }); log.info("Input: {s}\nOutput: {any}", .{ args.prompt, prompt_tok });
var errors: u8 = 0; var errors: u8 = 0;
{ {
@ -47,14 +48,14 @@ pub fn asyncMain() !void {
} }
if (args.expected.len > 0) { if (args.expected.len > 0) {
var expected = try std.ArrayList(u32).initCapacity(allocator, args.prompt.len); var expected = try std.array_list.Managed(u32).initCapacity(allocator, args.prompt.len);
var it = std.mem.splitSequence(u8, args.expected, ","); var it = std.mem.splitSequence(u8, args.expected, ",");
while (it.next()) |int_token| { while (it.next()) |int_token| {
const tok = try std.fmt.parseInt(u32, int_token, 10); const tok = try std.fmt.parseInt(u32, int_token, 10);
try expected.append(tok); try expected.append(tok);
} }
if (!std.mem.eql(u32, expected.items, prompt_tok)) { if (!std.mem.eql(u32, expected.items, prompt_tok)) {
log.err("Doesn't match expected: {d}", .{expected.items}); log.err("Doesn't match expected: {any}", .{expected.items});
errors += 1; errors += 1;
} }
} }
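
Two more 0.15 changes surface in this hunk: `std.ArrayList` is now unmanaged (the allocator is passed per call, with the old behaviour relocated to `std.array_list.Managed`), and slices are printed with `{any}` rather than `{d}`. A minimal sketch of both, under those assumptions:

    const std = @import("std");

    test "managed vs unmanaged lists in 0.15" {
        const allocator = std.testing.allocator;

        // 0.15 default: unmanaged, allocator passed per call.
        var list: std.ArrayList(u32) = .empty;
        defer list.deinit(allocator);
        try list.append(allocator, 1);

        // Old managed behaviour, relocated to std.array_list.Managed.
        var managed = std.array_list.Managed(u32).init(allocator);
        defer managed.deinit();
        try managed.append(2);

        // Slices now need {any}; {d} formats single integers.
        std.log.info("tokens: {any}", .{list.items});
    }
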

View File

@ -19,5 +19,6 @@ zig_library(
deps = [ deps = [
":sentencepiece_swig", ":sentencepiece_swig",
"//ffi:zig", "//ffi:zig",
"//stdx",
], ],
) )

View File

@ -1,6 +1,8 @@
const std = @import("std"); const std = @import("std");
const c = @import("c"); const c = @import("c");
const ffi = @import("ffi"); const ffi = @import("ffi");
const stdx = @import("stdx");
const StringToTokenRatio = 3; const StringToTokenRatio = 3;
@ -81,7 +83,7 @@ pub const Encoder = struct {
pub const Decoder = struct { pub const Decoder = struct {
const StringBufferSize = 64; const StringBufferSize = 64;
const StringBuffer = std.BoundedArray(u8, StringBufferSize); const StringBuffer = stdx.BoundedArray(u8, StringBufferSize);
const TokenIdsBufferSize = 4; const TokenIdsBufferSize = 4;
inner: *SentencePieceProcessor, inner: *SentencePieceProcessor,
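
`std.BoundedArray` was removed from the standard library in 0.15, which is why the decoder now reaches for a project-local `stdx.BoundedArray`. A minimal sketch, assuming `stdx.BoundedArray` mirrors the old `std.BoundedArray` API (`init`, `appendSlice`, `slice`):

    const std = @import("std");
    const stdx = @import("stdx");

    fn demo() !void {
        // Fixed-capacity, stack-allocated buffer: capacity 64, initial length 0.
        var buf = try stdx.BoundedArray(u8, 64).init(0);
        try buf.appendSlice("hello");
        // .slice() returns the currently used portion as []u8.
        std.debug.print("{s}\n", .{buf.slice()});
    }
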

View File

@ -14,7 +14,7 @@ const Tensor = zml.Tensor;
/// * `matmul(.{10}, .{10}) -> .{}` /// * `matmul(.{10}, .{10}) -> .{}`
/// * `matmul(.{10}, .{10}) -> .{}` /// * `matmul(.{10}, .{10}) -> .{}`
pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor { pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor {
stdx.debug.assert(lhs.rank() >= 1 and rhs.rank() >= 1, "Can't matmul({}, {}) ! The two tensors need to have at least rank 1.", .{ lhs, rhs }); stdx.debug.assert(lhs.rank() >= 1 and rhs.rank() >= 1, "Can't matmul({f}, {f}) ! The two tensors need to have at least rank 1.", .{ lhs, rhs });
const contracting = [_][2]i8{.{ -1, if (rhs.rank() >= 2) rhs.rank() - 2 else 0 }}; const contracting = [_][2]i8{.{ -1, if (rhs.rank() >= 2) rhs.rank() - 2 else 0 }};
if (lhs.rank() == 1 or rhs.rank() <= 2) { if (lhs.rank() == 1 or rhs.rank() <= 2) {
@ -22,7 +22,7 @@ pub fn matmul(lhs: Tensor, rhs: Tensor) Tensor {
return lhs.dotGeneral(rhs, &contracting, &.{}); return lhs.dotGeneral(rhs, &contracting, &.{});
} }
stdx.debug.assert(lhs.rank() == 2, "Can't matmul({}, {}) ! One of the two tensors needs to have a rank less than 2.", .{ lhs, rhs }); stdx.debug.assert(lhs.rank() == 2, "Can't matmul({f}, {f}) ! One of the two tensors needs to have a rank less than 2.", .{ lhs, rhs });
// Pytorch treats the extra dimensions of rhs as batching dimensions, // Pytorch treats the extra dimensions of rhs as batching dimensions,
// and implicitly broadcast lhs along those. // and implicitly broadcast lhs along those.
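
The `{}` → `{f}` edits in this file follow the 0.15 formatting rework: a type's custom `format` method now takes the new `*std.Io.Writer` and must be selected explicitly with `{f}`. A minimal sketch of the new interface, using a hypothetical `Shape` type rather than the real `zml.Tensor`:

    const std = @import("std");

    const Shape = struct {
        rank: u8,
        pub fn format(self: Shape, w: *std.Io.Writer) std.Io.Writer.Error!void {
            try w.print("Shape(rank={d})", .{self.rank});
        }
    };

    test "custom formatter via {f}" {
        var buf: [32]u8 = undefined;
        const s = try std.fmt.bufPrint(&buf, "{f}", .{Shape{ .rank = 2 }});
        try std.testing.expectEqualStrings("Shape(rank=2)", s);
    }
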
@ -91,7 +91,7 @@ pub fn unsqueeze(
self: Tensor, self: Tensor,
axis_: anytype, axis_: anytype,
) Tensor { ) Tensor {
stdx.debug.assert(self.rank() < Tensor.MAX_RANK - 1, "Can't unsqueeze {}, it's already at max rank.", .{self}); stdx.debug.assert(self.rank() < Tensor.MAX_RANK - 1, "Can't unsqueeze {f}, it's already at max rank.", .{self});
const a = switch (@typeInfo(@TypeOf(axis_))) { const a = switch (@typeInfo(@TypeOf(axis_))) {
.int, .comptime_int => if (axis_ < 0) .int, .comptime_int => if (axis_ < 0)
@as(i8, self.rank()) + 1 + axis_ @as(i8, self.rank()) + 1 + axis_
@ -125,9 +125,9 @@ test unsqueeze {
/// ref: https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#pixelshuffle /// ref: https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#pixelshuffle
pub fn pixelShuffle(tensor: Tensor, upscale_factor: u32) Tensor { pub fn pixelShuffle(tensor: Tensor, upscale_factor: u32) Tensor {
const shape = tensor.shape(); const shape = tensor.shape();
stdx.debug.assert(shape.hasTags(.{ .c, .w, .h }), "pixelShuffle({}) is invalid. Missing tags {{.c, .w, .h}}", .{tensor}); stdx.debug.assert(shape.hasTags(.{ .c, .w, .h }), "pixelShuffle({f}) is invalid. Missing tags {{.c, .w, .h}}", .{tensor});
stdx.debug.assert(@mod(shape.dim(.c), upscale_factor * upscale_factor) == 0, "pixelShuffle({}) is invalid. Number of channels {} isn't divisible by upscale factor {}**2", .{ tensor, shape.dim(.c), upscale_factor }); stdx.debug.assert(@mod(shape.dim(.c), upscale_factor * upscale_factor) == 0, "pixelShuffle({f}) is invalid. Number of channels {} isn't divisible by upscale factor {}**2", .{ tensor, shape.dim(.c), upscale_factor });
const s = tensor.splitAxis(.c, .{ .c = -1, .upscale_h = upscale_factor, .upscale_w = upscale_factor }); const s = tensor.splitAxis(.c, .{ .c = -1, .upscale_h = upscale_factor, .upscale_w = upscale_factor });
const perm = s.shape().contiguousPerm(.{ .h, .upscale_h, .w, .upscale_w }); const perm = s.shape().contiguousPerm(.{ .h, .upscale_h, .w, .upscale_w });