Update core Zig modules (async, mlir, pjrt, stdx) and third‑party Bazel definitions for the Zig 0.14.0 release.
This commit is contained in:
parent
16cc8c6658
commit
30f6be0e2f
@ -9,7 +9,7 @@ bazel_dep(name = "aspect_bazel_lib", version = "2.11.0")
|
|||||||
bazel_dep(name = "aspect_rules_py", version = "1.3.1")
|
bazel_dep(name = "aspect_rules_py", version = "1.3.1")
|
||||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||||
bazel_dep(name = "hermetic_cc_toolchain", version = "3.1.1")
|
bazel_dep(name = "hermetic_cc_toolchain", version = "3.1.1")
|
||||||
bazel_dep(name = "libxev", version = "20250222.0-07bcffa")
|
bazel_dep(name = "libxev", version = "20250313.0-5773f46")
|
||||||
bazel_dep(name = "llvm-raw", version = "20250217.0-0e779ad")
|
bazel_dep(name = "llvm-raw", version = "20250217.0-0e779ad")
|
||||||
bazel_dep(name = "patchelf", version = "0.18.0")
|
bazel_dep(name = "patchelf", version = "0.18.0")
|
||||||
bazel_dep(name = "pcre2", version = "10.43")
|
bazel_dep(name = "pcre2", version = "10.43")
|
||||||
@ -27,7 +27,7 @@ bazel_dep(name = "stablehlo", version = "20250217.0-4598975")
|
|||||||
bazel_dep(name = "toolchains_protoc", version = "0.3.7")
|
bazel_dep(name = "toolchains_protoc", version = "0.3.7")
|
||||||
bazel_dep(name = "with_cfg.bzl", version = "0.8.0")
|
bazel_dep(name = "with_cfg.bzl", version = "0.8.0")
|
||||||
bazel_dep(name = "xla", version = "20250204.1-6789523")
|
bazel_dep(name = "xla", version = "20250204.1-6789523")
|
||||||
bazel_dep(name = "zig-protobuf", version = "20240722.0-c644d11")
|
bazel_dep(name = "zig-protobuf", version = "20250213.0-5304067")
|
||||||
bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf")
|
bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf")
|
||||||
|
|
||||||
bazel_dep(name = "buildifier_prebuilt", version = "7.3.1", dev_dependency = True)
|
bazel_dep(name = "buildifier_prebuilt", version = "7.3.1", dev_dependency = True)
|
||||||
@ -47,9 +47,10 @@ register_toolchains("@toolchains_protoc_hub//:all")
|
|||||||
|
|
||||||
zig = use_extension("@rules_zig//zig:extensions.bzl", "zig")
|
zig = use_extension("@rules_zig//zig:extensions.bzl", "zig")
|
||||||
zig.index(file = "//bazel:zig_index.json")
|
zig.index(file = "//bazel:zig_index.json")
|
||||||
zig.toolchain(zig_version = "0.14.0-dev.363+c3faae6bf")
|
zig.toolchain(zig_version = "0.14.0")
|
||||||
zig.mirrors(urls = [
|
zig.mirrors(urls = [
|
||||||
"https://mirror.zml.ai/zig",
|
"https://mirror.zml.ai/zig",
|
||||||
|
"https://ziglang.org/builds/",
|
||||||
])
|
])
|
||||||
use_repo(zig, "zig_toolchains")
|
use_repo(zig, "zig_toolchains")
|
||||||
|
|
||||||
|
|||||||
@ -89,7 +89,7 @@ pub fn sleep(exec: *Executor, ms: u64) !void {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn waitForCompletionOutsideCoro(exec: *Executor, c: *xev.Completion) !void {
|
pub fn waitForCompletionOutsideCoro(exec: *Executor, c: *xev.Completion) !void {
|
||||||
@setCold(true);
|
@branchHint(.unlikely);
|
||||||
while (c.state() != .dead) {
|
while (c.state() != .dead) {
|
||||||
try exec.tick();
|
try exec.tick();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,73 +1,154 @@
|
|||||||
{
|
{
|
||||||
"master": {
|
"master": {
|
||||||
"version": "0.14.0-dev.363+c3faae6bf",
|
"version": "0.15.0-dev.56+d0911786c",
|
||||||
"date": "2024-07-18",
|
"date": "2025-03-12",
|
||||||
"docs": "https://ziglang.org/documentation/master/",
|
"docs": "https://ziglang.org/documentation/master/",
|
||||||
"stdDocs": "https://ziglang.org/documentation/master/std/",
|
"stdDocs": "https://ziglang.org/documentation/master/std/",
|
||||||
"src": {
|
"src": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-0.14.0-dev.363+c3faae6bf.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-0.15.0-dev.56+d0911786c.tar.xz",
|
||||||
"shasum": "55b48780575ee86668fe9e8119abcc0831d3ce93f5f848b7a9b610155c1a865e",
|
"shasum": "5f10f2763bc19ad540844821891b17981ac40d205d34f93679af732d90079c19",
|
||||||
"size": "17272356"
|
"size": "17775980"
|
||||||
},
|
},
|
||||||
"bootstrap": {
|
"bootstrap": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-bootstrap-0.14.0-dev.363+c3faae6bf.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-bootstrap-0.15.0-dev.56+d0911786c.tar.xz",
|
||||||
"shasum": "65cd278494293ff953561194607c4b7e9b96908aa903c7a8421f4e8014a6345b",
|
"shasum": "f4cb749d63db2a11f3d087ce7607ae1a26dfc1049a554147d7a60f7327f17ef8",
|
||||||
"size": "46507328"
|
"size": "48043388"
|
||||||
},
|
},
|
||||||
"x86_64-macos": {
|
"x86_64-macos": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-macos-x86_64-0.14.0-dev.363+c3faae6bf.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-macos-x86_64-0.15.0-dev.56+d0911786c.tar.xz",
|
||||||
"shasum": "256b09afa6a4e0cd5c4f8497ef3625ba9b01de2b75f57b728337bc1de4681c9c",
|
"shasum": "a737bf40b6b4627833c2346f4d1ab63c387e16e70c535cec421029efbf792826",
|
||||||
"size": "48937384"
|
"size": "51066200"
|
||||||
},
|
},
|
||||||
"aarch64-macos": {
|
"aarch64-macos": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-macos-aarch64-0.14.0-dev.363+c3faae6bf.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-macos-aarch64-0.15.0-dev.56+d0911786c.tar.xz",
|
||||||
"shasum": "cd9f563150b1adb7306912b5acff9b00e39ef283075a42b95186f39bda656862",
|
"shasum": "ef8f0429fa663c55807a60c3931fddc971276dd4570ca794a81c20c6cabfb56d",
|
||||||
"size": "44960052"
|
"size": "45933112"
|
||||||
},
|
},
|
||||||
"x86_64-linux": {
|
"x86_64-linux": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-linux-x86_64-0.14.0-dev.363+c3faae6bf.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-linux-x86_64-0.15.0-dev.56+d0911786c.tar.xz",
|
||||||
"shasum": "98ce531beaac0e683713ec1843023b8aa81a318686472ff13f2c075f0362bf0a",
|
"shasum": "54ef448d32520ca10641f18c4e0a4393f762461d1e351ff075683c391951628d",
|
||||||
"size": "47164832"
|
"size": "49113412"
|
||||||
},
|
},
|
||||||
"aarch64-linux": {
|
"aarch64-linux": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-linux-aarch64-0.14.0-dev.363+c3faae6bf.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-linux-aarch64-0.15.0-dev.56+d0911786c.tar.xz",
|
||||||
"shasum": "81e1c06740c017ad8aa3df451c544da17a2e23440c1e695954b8c4b612243af0",
|
"shasum": "55234d068a5a60851c39052431037762fb3447af691751f826c6faf5ab7d0850",
|
||||||
"size": "43190732"
|
"size": "44950392"
|
||||||
},
|
},
|
||||||
"armv7a-linux": {
|
"armv7a-linux": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-linux-armv7a-0.14.0-dev.363+c3faae6bf.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-linux-armv7a-0.15.0-dev.56+d0911786c.tar.xz",
|
||||||
"shasum": "cda6c3f2b51355c3f117814c13962c9fd65d9ba02f5981d45ea83c135b9019b6",
|
"shasum": "eae58573a8a9c1744782d3c73c930ee97de98f506c8a814eeb27c69b3bd7412c",
|
||||||
"size": "44096460"
|
"size": "46124004"
|
||||||
},
|
},
|
||||||
"riscv64-linux": {
|
"riscv64-linux": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-linux-riscv64-0.14.0-dev.363+c3faae6bf.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-linux-riscv64-0.15.0-dev.56+d0911786c.tar.xz",
|
||||||
"shasum": "aa488f1763ff65a910c8f6bdedc5ff8b9a05225d8b29a1446e3c0e0cb7bff683",
|
"shasum": "66112eecb04e26dbed26eec0f39fc3609814ad75744f0e46aaa70d8367a6bea3",
|
||||||
"size": "45637416"
|
"size": "48093348"
|
||||||
},
|
},
|
||||||
"powerpc64le-linux": {
|
"powerpc64le-linux": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-linux-powerpc64le-0.14.0-dev.363+c3faae6bf.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-linux-powerpc64le-0.15.0-dev.56+d0911786c.tar.xz",
|
||||||
"shasum": "8e9f715b53edf8d8b8c0d9dcc173a1faafdb3b659cebe4435d19d2b2ef24eda9",
|
"shasum": "2986e5244781677abf4dd8e90c5e049f8f59e6a759984214a9b6763645d49241",
|
||||||
"size": "46654844"
|
"size": "48754240"
|
||||||
},
|
},
|
||||||
"x86-linux": {
|
"x86-linux": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-linux-x86-0.14.0-dev.363+c3faae6bf.tar.xz",
|
"tarball": "https://ziglang.org/builds/zig-linux-x86-0.15.0-dev.56+d0911786c.tar.xz",
|
||||||
"shasum": "4633ce74826903cf019f390956726c434a5b032f06931aa085e0b22194b55bb5",
|
"shasum": "b01d947a285d6504539b05747616545c45c3cf5653f3f6926d6a1a141e387283",
|
||||||
"size": "52144844"
|
"size": "51633580"
|
||||||
|
},
|
||||||
|
"loongarch64-linux": {
|
||||||
|
"tarball": "https://ziglang.org/builds/zig-linux-loongarch64-0.15.0-dev.56+d0911786c.tar.xz",
|
||||||
|
"shasum": "7e09d716da24da8b0098af72dacf3bb1ee35de1c09478b2fef06e74db5a5da4a",
|
||||||
|
"size": "45856620"
|
||||||
},
|
},
|
||||||
"x86_64-windows": {
|
"x86_64-windows": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-windows-x86_64-0.14.0-dev.363+c3faae6bf.zip",
|
"tarball": "https://ziglang.org/builds/zig-windows-x86_64-0.15.0-dev.56+d0911786c.zip",
|
||||||
"shasum": "44a26238f1757723f54e9b5d4d08b508be8ebdfad9442218d0f2a3ae61f032a5",
|
"shasum": "05c71d9a820a883589fc34e2e82d49a7ce1263b5957d58ae83ab9f3de02aae14",
|
||||||
"size": "79857623"
|
"size": "82776510"
|
||||||
},
|
},
|
||||||
"aarch64-windows": {
|
"aarch64-windows": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-windows-aarch64-0.14.0-dev.363+c3faae6bf.zip",
|
"tarball": "https://ziglang.org/builds/zig-windows-aarch64-0.15.0-dev.56+d0911786c.zip",
|
||||||
"shasum": "8214699c48b5753d127f9632fbb78432fdfa1dd5715ff73106469b339bb641ef",
|
"shasum": "f620259d96ab0d60725ce86dedfe11b1c061acdff1d5f4e4cae5806d9d9477a2",
|
||||||
"size": "75829556"
|
"size": "78667620"
|
||||||
},
|
},
|
||||||
"x86-windows": {
|
"x86-windows": {
|
||||||
"tarball": "https://ziglang.org/builds/zig-windows-x86-0.14.0-dev.363+c3faae6bf.zip",
|
"tarball": "https://ziglang.org/builds/zig-windows-x86-0.15.0-dev.56+d0911786c.zip",
|
||||||
"shasum": "86be63cb2017f23371da631980b27ab6aa4a43b5d4ed96fd08f1311584d779a9",
|
"shasum": "1592b059d91296b04c66b08a36951ced1582bfd37bc0cf6ab998c0231d0765a8",
|
||||||
"size": "83972311"
|
"size": "84520518"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"0.14.0": {
|
||||||
|
"date": "2025-03-05",
|
||||||
|
"docs": "https://ziglang.org/documentation/0.14.0/",
|
||||||
|
"stdDocs": "https://ziglang.org/documentation/0.14.0/std/",
|
||||||
|
"notes": "https://ziglang.org/download/0.14.0/release-notes.html",
|
||||||
|
"src": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.14.0/zig-0.14.0.tar.xz",
|
||||||
|
"shasum": "c76638c03eb204c4432ae092f6fa07c208567e110fbd4d862d131a7332584046",
|
||||||
|
"size": "17772188"
|
||||||
|
},
|
||||||
|
"bootstrap": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.14.0/zig-bootstrap-0.14.0.tar.xz",
|
||||||
|
"shasum": "bf3fcb22be0b83f4791748adb567d3304779d66d7bf9b1bd557ef6c2e0232807",
|
||||||
|
"size": "48029040"
|
||||||
|
},
|
||||||
|
"x86_64-macos": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.14.0/zig-macos-x86_64-0.14.0.tar.xz",
|
||||||
|
"shasum": "685816166f21f0b8d6fc7aa6a36e91396dcd82ca6556dfbe3e329deffc01fec3",
|
||||||
|
"size": "51039964"
|
||||||
|
},
|
||||||
|
"aarch64-macos": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.14.0/zig-macos-aarch64-0.14.0.tar.xz",
|
||||||
|
"shasum": "b71e4b7c4b4be9953657877f7f9e6f7ee89114c716da7c070f4a238220e95d7e",
|
||||||
|
"size": "45902412"
|
||||||
|
},
|
||||||
|
"x86_64-linux": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.14.0/zig-linux-x86_64-0.14.0.tar.xz",
|
||||||
|
"shasum": "473ec26806133cf4d1918caf1a410f8403a13d979726a9045b421b685031a982",
|
||||||
|
"size": "49091960"
|
||||||
|
},
|
||||||
|
"aarch64-linux": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.14.0/zig-linux-aarch64-0.14.0.tar.xz",
|
||||||
|
"shasum": "ab64e3ea277f6fc5f3d723dcd95d9ce1ab282c8ed0f431b4de880d30df891e4f",
|
||||||
|
"size": "44922728"
|
||||||
|
},
|
||||||
|
"armv7a-linux": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.14.0/zig-linux-armv7a-0.14.0.tar.xz",
|
||||||
|
"shasum": "a67dbfa9bdf769228ec994f2098698c619f930883ca5ef638f50eee2d7788d10",
|
||||||
|
"size": "46112980"
|
||||||
|
},
|
||||||
|
"riscv64-linux": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.14.0/zig-linux-riscv64-0.14.0.tar.xz",
|
||||||
|
"shasum": "a2b14d3de326d3fd095548ef38bf5a67b15dadd62fbcc90836d63cc4355f8ef7",
|
||||||
|
"size": "48069188"
|
||||||
|
},
|
||||||
|
"powerpc64le-linux": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.14.0/zig-linux-powerpc64le-0.14.0.tar.xz",
|
||||||
|
"shasum": "3eabd60876ebc2748de8eb57b4b8cfa78861ba9bf7c6dd83f4e3e1d271d7c45e",
|
||||||
|
"size": "48707620"
|
||||||
|
},
|
||||||
|
"x86-linux": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.14.0/zig-linux-x86-0.14.0.tar.xz",
|
||||||
|
"shasum": "55d1ba21de5109686ffa675b9cc1dd66930093c202995a637ce3e397816e4c08",
|
||||||
|
"size": "51621460"
|
||||||
|
},
|
||||||
|
"loongarch64-linux": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.14.0/zig-linux-loongarch64-0.14.0.tar.xz",
|
||||||
|
"shasum": "31a2f07df55f8f528b92d540db9aae6c0b38643c34dc1ac33a0111d855e996ae",
|
||||||
|
"size": "45821860"
|
||||||
|
},
|
||||||
|
"x86_64-windows": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.14.0/zig-windows-x86_64-0.14.0.zip",
|
||||||
|
"shasum": "f53e5f9011ba20bbc3e0e6d0a9441b31eb227a97bac0e7d24172f1b8b27b4371",
|
||||||
|
"size": "82219809"
|
||||||
|
},
|
||||||
|
"aarch64-windows": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.14.0/zig-windows-aarch64-0.14.0.zip",
|
||||||
|
"shasum": "03e984383ebb8f85293557cfa9f48ee8698e7c400239570c9ff1aef3bffaf046",
|
||||||
|
"size": "78113283"
|
||||||
|
},
|
||||||
|
"x86-windows": {
|
||||||
|
"tarball": "https://ziglang.org/download/0.14.0/zig-windows-x86-0.14.0.zip",
|
||||||
|
"shasum": "1a867d808cf4fa9184358395d94441390b6b24ee8d00d356ca11ea7cbfd3a4ec",
|
||||||
|
"size": "83970029"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"0.13.0": {
|
"0.13.0": {
|
||||||
|
|||||||
@ -50,7 +50,7 @@ pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.C) void;
|
|||||||
|
|
||||||
fn MlirHelpersMethods(OuterT: type) type {
|
fn MlirHelpersMethods(OuterT: type) type {
|
||||||
switch (@typeInfo(OuterT)) {
|
switch (@typeInfo(OuterT)) {
|
||||||
.Struct => |info| {
|
.@"struct" => |info| {
|
||||||
if (info.fields.len != 1) @compileError("Mlir wrapper type can only wrap one Mlir value. Received: " ++ @typeName(OuterT));
|
if (info.fields.len != 1) @compileError("Mlir wrapper type can only wrap one Mlir value. Received: " ++ @typeName(OuterT));
|
||||||
},
|
},
|
||||||
else => @compileError("MlirHelpersMethods is only available on an Mlir wrapper struct. Received: " ++ @typeName(OuterT)),
|
else => @compileError("MlirHelpersMethods is only available on an Mlir wrapper struct. Received: " ++ @typeName(OuterT)),
|
||||||
|
|||||||
@ -101,10 +101,10 @@ pub const Api = struct {
|
|||||||
|
|
||||||
fn CallFnArgType(comptime func: Funcs) type {
|
fn CallFnArgType(comptime func: Funcs) type {
|
||||||
const fti = @typeInfo(std.meta.FieldType(c.PJRT_Api, func));
|
const fti = @typeInfo(std.meta.FieldType(c.PJRT_Api, func));
|
||||||
const fn_ptr = @typeInfo(fti.Optional.child);
|
const fn_ptr = @typeInfo(fti.optional.child);
|
||||||
const fn_type_info = @typeInfo(fn_ptr.Pointer.child);
|
const fn_type_info = @typeInfo(fn_ptr.pointer.child);
|
||||||
const arg_array_type_info = @typeInfo(fn_type_info.Fn.params[0].type.?);
|
const arg_array_type_info = @typeInfo(fn_type_info.@"fn".params[0].type.?);
|
||||||
return arg_array_type_info.Pointer.child;
|
return arg_array_type_info.pointer.child;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fn call(self: *const Api, comptime method: Funcs, arg: CallFnArgType(method)) ApiError!@TypeOf(arg) {
|
inline fn call(self: *const Api, comptime method: Funcs, arg: CallFnArgType(method)) ApiError!@TypeOf(arg) {
|
||||||
@ -681,8 +681,8 @@ pub const BufferType = enum(c.PJRT_Buffer_Type) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
pub const MemoryLayoutType = enum(c.PJRT_Buffer_MemoryLayout_Type) {
|
pub const MemoryLayoutType = enum(c.PJRT_Buffer_MemoryLayout_Type) {
|
||||||
Tiled = c.PJRT_Buffer_MemoryLayout_Type_Tiled,
|
tiled = c.PJRT_Buffer_MemoryLayout_Type_Tiled,
|
||||||
Strides = c.PJRT_Buffer_MemoryLayout_Type_Strides,
|
strides = c.PJRT_Buffer_MemoryLayout_Type_Strides,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const MemoryLayout = union(MemoryLayoutType) {
|
pub const MemoryLayout = union(MemoryLayoutType) {
|
||||||
@ -698,12 +698,12 @@ pub const MemoryLayout = union(MemoryLayoutType) {
|
|||||||
byte_strides: []const i64,
|
byte_strides: []const i64,
|
||||||
};
|
};
|
||||||
|
|
||||||
Tiled: Tiled,
|
tiled: Tiled,
|
||||||
Strides: Strides,
|
strides: Strides,
|
||||||
|
|
||||||
fn toCStruct(self: MemoryLayout) c.PJRT_Buffer_MemoryLayout {
|
fn toCStruct(self: MemoryLayout) c.PJRT_Buffer_MemoryLayout {
|
||||||
return pjrtStruct(switch (self) {
|
return pjrtStruct(switch (self) {
|
||||||
.Tiled => |v| c.PJRT_Buffer_MemoryLayout{
|
.tiled => |v| c.PJRT_Buffer_MemoryLayout{
|
||||||
.type = c.PJRT_Buffer_MemoryLayout_Type_Tiled,
|
.type = c.PJRT_Buffer_MemoryLayout_Type_Tiled,
|
||||||
.unnamed_0 = .{
|
.unnamed_0 = .{
|
||||||
.tiled = c.PJRT_Buffer_MemoryLayout_Tiled{
|
.tiled = c.PJRT_Buffer_MemoryLayout_Tiled{
|
||||||
@ -715,7 +715,7 @@ pub const MemoryLayout = union(MemoryLayoutType) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
.Strides => |v| c.PJRT_Buffer_MemoryLayout{
|
.strides => |v| c.PJRT_Buffer_MemoryLayout{
|
||||||
.type = c.PJRT_Buffer_MemoryLayout_Type_Strides,
|
.type = c.PJRT_Buffer_MemoryLayout_Type_Strides,
|
||||||
.unnamed_0 = .{
|
.unnamed_0 = .{
|
||||||
.strides = c.PJRT_Buffer_MemoryLayout_Strides{
|
.strides = c.PJRT_Buffer_MemoryLayout_Strides{
|
||||||
|
|||||||
@ -43,6 +43,7 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
const assert = std.debug.assert;
|
const assert = std.debug.assert;
|
||||||
|
const debug = @import("debug.zig");
|
||||||
|
|
||||||
/// Format and print an error message to stderr, then exit with an exit code of 1.
|
/// Format and print an error message to stderr, then exit with an exit code of 1.
|
||||||
pub fn fatal(comptime fmt_string: []const u8, args: anytype) noreturn {
|
pub fn fatal(comptime fmt_string: []const u8, args: anytype) noreturn {
|
||||||
@ -78,8 +79,8 @@ pub fn parse(args: *std.process.ArgIterator, comptime CliArgs: type) CliArgs {
|
|||||||
assert(args.skip()); // Discard executable name.
|
assert(args.skip()); // Discard executable name.
|
||||||
|
|
||||||
return switch (@typeInfo(CliArgs)) {
|
return switch (@typeInfo(CliArgs)) {
|
||||||
.Union => parse_commands(args, CliArgs),
|
.@"union" => parse_commands(args, CliArgs),
|
||||||
.Struct => parse_flags(args, CliArgs),
|
.@"struct" => parse_flags(args, CliArgs),
|
||||||
else => unreachable,
|
else => unreachable,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -127,7 +128,7 @@ fn parse_flags(args: *std.process.ArgIterator, comptime Flags: type) Flags {
|
|||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
assert(@typeInfo(Flags) == .Struct);
|
assert(@typeInfo(Flags) == .@"struct");
|
||||||
|
|
||||||
comptime var fields: [std.meta.fields(Flags).len]std.builtin.Type.StructField = undefined;
|
comptime var fields: [std.meta.fields(Flags).len]std.builtin.Type.StructField = undefined;
|
||||||
comptime var field_count = 0;
|
comptime var field_count = 0;
|
||||||
@ -136,7 +137,7 @@ fn parse_flags(args: *std.process.ArgIterator, comptime Flags: type) Flags {
|
|||||||
|
|
||||||
comptime for (std.meta.fields(Flags)) |field| {
|
comptime for (std.meta.fields(Flags)) |field| {
|
||||||
if (std.mem.eql(u8, field.name, "positional")) {
|
if (std.mem.eql(u8, field.name, "positional")) {
|
||||||
assert(@typeInfo(field.type) == .Struct);
|
assert(@typeInfo(field.type) == .@"struct");
|
||||||
positional_fields = std.meta.fields(field.type);
|
positional_fields = std.meta.fields(field.type);
|
||||||
var optional_tail = false;
|
var optional_tail = false;
|
||||||
for (positional_fields) |positional_field| {
|
for (positional_fields) |positional_field| {
|
||||||
@ -146,7 +147,7 @@ fn parse_flags(args: *std.process.ArgIterator, comptime Flags: type) Flags {
|
|||||||
optional_tail = true;
|
optional_tail = true;
|
||||||
}
|
}
|
||||||
switch (@typeInfo(positional_field.type)) {
|
switch (@typeInfo(positional_field.type)) {
|
||||||
.Optional => |optional| {
|
.optional => |optional| {
|
||||||
// optional flags should have a default
|
// optional flags should have a default
|
||||||
assert(default_value(positional_field) != null);
|
assert(default_value(positional_field) != null);
|
||||||
assert(default_value(positional_field).? == null);
|
assert(default_value(positional_field).? == null);
|
||||||
@ -162,16 +163,13 @@ fn parse_flags(args: *std.process.ArgIterator, comptime Flags: type) Flags {
|
|||||||
field_count += 1;
|
field_count += 1;
|
||||||
|
|
||||||
switch (@typeInfo(field.type)) {
|
switch (@typeInfo(field.type)) {
|
||||||
.Bool => {
|
.bool => {
|
||||||
// boolean flags should have a default
|
// boolean flags should have a default
|
||||||
assert(default_value(field) != null);
|
debug.assertComptime(default_value(field) != null and default_value(field).? == false, "boolean flag --{s} should default to false", .{field.name});
|
||||||
assert(default_value(field).? == false);
|
|
||||||
},
|
},
|
||||||
.Optional => |optional| {
|
.optional => |optional| {
|
||||||
// optional flags should have a default
|
// optional flags should have a default
|
||||||
assert(default_value(field) != null);
|
debug.assertComptime(default_value(field) != null and default_value(field).? == null, "optional flag --{s} should have a null default value", .{field.name});
|
||||||
assert(default_value(field).? == null);
|
|
||||||
|
|
||||||
assert_valid_value_type(optional.child);
|
assert_valid_value_type(optional.child);
|
||||||
},
|
},
|
||||||
else => {
|
else => {
|
||||||
@ -182,15 +180,15 @@ fn parse_flags(args: *std.process.ArgIterator, comptime Flags: type) Flags {
|
|||||||
};
|
};
|
||||||
|
|
||||||
var result: Flags = undefined;
|
var result: Flags = undefined;
|
||||||
// Would use std.enums.EnumFieldStruct(Flags, u32, 0) here but Flags is a Struct not an Enum.
|
// Would use std.enums.EnumFieldStruct(Flags, u32, 0) here but Flags is a struct not an Enum.
|
||||||
var counts = comptime blk: {
|
var counts = comptime blk: {
|
||||||
var count_fields = std.meta.fields(Flags)[0..std.meta.fields(Flags).len].*;
|
var count_fields = std.meta.fields(Flags)[0..std.meta.fields(Flags).len].*;
|
||||||
for (&count_fields) |*field| {
|
for (&count_fields) |*field| {
|
||||||
field.type = u32;
|
field.type = u32;
|
||||||
field.alignment = @alignOf(u32);
|
field.alignment = @alignOf(u32);
|
||||||
field.default_value = @ptrCast(&@as(u32, 0));
|
field.default_value_ptr = @ptrCast(&@as(u32, 0));
|
||||||
}
|
}
|
||||||
break :blk @Type(.{ .Struct = .{
|
break :blk @Type(.{ .@"struct" = .{
|
||||||
.layout = .auto,
|
.layout = .auto,
|
||||||
.fields = &count_fields,
|
.fields = &count_fields,
|
||||||
.decls = &.{},
|
.decls = &.{},
|
||||||
@ -288,10 +286,10 @@ fn parse_flags(args: *std.process.ArgIterator, comptime Flags: type) Flags {
|
|||||||
|
|
||||||
fn assert_valid_value_type(comptime T: type) void {
|
fn assert_valid_value_type(comptime T: type) void {
|
||||||
comptime {
|
comptime {
|
||||||
if (T == []const u8 or T == [:0]const u8 or T == ByteSize or @typeInfo(T) == .Int) return;
|
if (T == []const u8 or T == [:0]const u8 or T == ByteSize or @typeInfo(T) == .int) return;
|
||||||
|
|
||||||
if (@typeInfo(T) == .Enum) {
|
if (@typeInfo(T) == .@"enum") {
|
||||||
const info = @typeInfo(T).Enum;
|
const info = @typeInfo(T).@"enum";
|
||||||
assert(info.is_exhaustive);
|
assert(info.is_exhaustive);
|
||||||
assert(info.fields.len >= 2);
|
assert(info.fields.len >= 2);
|
||||||
return;
|
return;
|
||||||
@ -343,14 +341,14 @@ fn parse_value(comptime T: type, flag: []const u8, value: [:0]const u8) T {
|
|||||||
assert(value.len > 0);
|
assert(value.len > 0);
|
||||||
|
|
||||||
const V = switch (@typeInfo(T)) {
|
const V = switch (@typeInfo(T)) {
|
||||||
.Optional => |optional| optional.child,
|
.optional => |optional| optional.child,
|
||||||
else => T,
|
else => T,
|
||||||
};
|
};
|
||||||
|
|
||||||
if (V == []const u8 or V == [:0]const u8) return value;
|
if (V == []const u8 or V == [:0]const u8) return value;
|
||||||
if (V == ByteSize) return parse_value_size(flag, value);
|
if (V == ByteSize) return parse_value_size(flag, value);
|
||||||
if (@typeInfo(V) == .Int) return parse_value_int(V, flag, value);
|
if (@typeInfo(V) == .int) return parse_value_int(V, flag, value);
|
||||||
if (@typeInfo(V) == .Enum) return parse_value_enum(V, flag, value);
|
if (@typeInfo(V) == .@"enum") return parse_value_enum(V, flag, value);
|
||||||
comptime unreachable;
|
comptime unreachable;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -508,7 +506,7 @@ fn parse_value_int(comptime T: type, flag: []const u8, value: [:0]const u8) T {
|
|||||||
switch (err) {
|
switch (err) {
|
||||||
error.Overflow => fatal(
|
error.Overflow => fatal(
|
||||||
"{s}: value exceeds {d}-bit {s} integer: '{s}'",
|
"{s}: value exceeds {d}-bit {s} integer: '{s}'",
|
||||||
.{ flag, @typeInfo(T).Int.bits, @tagName(@typeInfo(T).Int.signedness), value },
|
.{ flag, @typeInfo(T).int.bits, @tagName(@typeInfo(T).int.signedness), value },
|
||||||
),
|
),
|
||||||
error.InvalidCharacter => fatal(
|
error.InvalidCharacter => fatal(
|
||||||
"{s}: expected an integer value, but found '{s}' (invalid digit)",
|
"{s}: expected an integer value, but found '{s}' (invalid digit)",
|
||||||
@ -520,7 +518,7 @@ fn parse_value_int(comptime T: type, flag: []const u8, value: [:0]const u8) T {
|
|||||||
|
|
||||||
fn parse_value_enum(comptime E: type, flag: []const u8, value: [:0]const u8) E {
|
fn parse_value_enum(comptime E: type, flag: []const u8, value: [:0]const u8) E {
|
||||||
assert((flag[0] == '-' and flag[1] == '-') or flag[0] == '<');
|
assert((flag[0] == '-' and flag[1] == '-') or flag[0] == '<');
|
||||||
comptime assert(@typeInfo(E).Enum.is_exhaustive);
|
comptime assert(@typeInfo(E).@"enum".is_exhaustive);
|
||||||
|
|
||||||
return std.meta.stringToEnum(E, value) orelse fatal(
|
return std.meta.stringToEnum(E, value) orelse fatal(
|
||||||
"{s}: expected one of {s}, but found '{s}'",
|
"{s}: expected one of {s}, but found '{s}'",
|
||||||
@ -564,7 +562,7 @@ pub fn flag_name(comptime field: std.builtin.Type.StructField) []const u8 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test flag_name {
|
test flag_name {
|
||||||
const field = @typeInfo(struct { statsd: bool }).Struct.fields[0];
|
const field = @typeInfo(struct { statsd: bool }).@"struct".fields[0];
|
||||||
try std.testing.expectEqualStrings(flag_name(field), "--statsd");
|
try std.testing.expectEqualStrings(flag_name(field), "--statsd");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -575,7 +573,7 @@ fn flag_name_positional(comptime field: std.builtin.Type.StructField) []const u8
|
|||||||
|
|
||||||
/// This is essentially `field.default_value`, but with a useful type instead of `?*anyopaque`.
|
/// This is essentially `field.default_value`, but with a useful type instead of `?*anyopaque`.
|
||||||
pub fn default_value(comptime field: std.builtin.Type.StructField) ?field.type {
|
pub fn default_value(comptime field: std.builtin.Type.StructField) ?field.type {
|
||||||
return if (field.default_value) |default_opaque|
|
return if (field.default_value_ptr) |default_opaque|
|
||||||
@as(*const field.type, @ptrCast(@alignCast(default_opaque))).*
|
@as(*const field.type, @ptrCast(@alignCast(default_opaque))).*
|
||||||
else
|
else
|
||||||
null;
|
null;
|
||||||
|
|||||||
@ -29,9 +29,9 @@ pub fn Union(comptime T: type) type {
|
|||||||
else => {},
|
else => {},
|
||||||
},
|
},
|
||||||
else => switch (@typeInfo(field.type)) {
|
else => switch (@typeInfo(field.type)) {
|
||||||
.Int => if (source == .integer) return .{ .value = @unionInit(T, field.name, @intCast(source.integer)) },
|
.int => if (source == .integer) return .{ .value = @unionInit(T, field.name, @intCast(source.integer)) },
|
||||||
.Float => if (source == .float) return .{ .value = @unionInit(T, field.name, @floatCast(source.float)) },
|
.float => if (source == .float) return .{ .value = @unionInit(T, field.name, @floatCast(source.float)) },
|
||||||
.Struct => if (source == .object) return .{ .value = @unionInit(T, field.name, try std.json.innerParseFromValue(field.type, allocator, source.object, options)) },
|
.@"struct" => if (source == .object) return .{ .value = @unionInit(T, field.name, try std.json.innerParseFromValue(field.type, allocator, source.object, options)) },
|
||||||
inline else => switch (source) {
|
inline else => switch (source) {
|
||||||
.number_string, .array => return .{ .value = @unionInit(T, field.name, try std.json.innerParseFromValue(field.type, allocator, source, options)) },
|
.number_string, .array => return .{ .value = @unionInit(T, field.name, try std.json.innerParseFromValue(field.type, allocator, source, options)) },
|
||||||
else => {},
|
else => {},
|
||||||
|
|||||||
@ -4,14 +4,14 @@ pub inline fn divFloat(comptime T: type, numerator: anytype, denominator: anytyp
|
|||||||
|
|
||||||
pub inline fn floatCast(comptime T: type, x: anytype) T {
|
pub inline fn floatCast(comptime T: type, x: anytype) T {
|
||||||
return switch (@typeInfo(@TypeOf(x))) {
|
return switch (@typeInfo(@TypeOf(x))) {
|
||||||
.Float => @floatCast(x),
|
.float => @floatCast(x),
|
||||||
else => @floatFromInt(x),
|
else => @floatFromInt(x),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub inline fn intCast(comptime T: type, x: anytype) T {
|
pub inline fn intCast(comptime T: type, x: anytype) T {
|
||||||
return switch (@typeInfo(@TypeOf(x))) {
|
return switch (@typeInfo(@TypeOf(x))) {
|
||||||
.Int => @intCast(x),
|
.int => @intCast(x),
|
||||||
else => @intFromFloat(x),
|
else => @intFromFloat(x),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,21 +8,21 @@ pub const Signature = @import("signature.zig").Signature;
|
|||||||
|
|
||||||
pub fn isStruct(comptime T: type) bool {
|
pub fn isStruct(comptime T: type) bool {
|
||||||
return switch (@typeInfo(T)) {
|
return switch (@typeInfo(T)) {
|
||||||
.Struct => true,
|
.@"struct" => true,
|
||||||
else => false,
|
else => false,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn isTuple(comptime T: type) bool {
|
pub fn isTuple(comptime T: type) bool {
|
||||||
return switch (@typeInfo(T)) {
|
return switch (@typeInfo(T)) {
|
||||||
.Struct => |info| info.is_tuple,
|
.@"struct" => |info| info.is_tuple,
|
||||||
else => false,
|
else => false,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn isStructOf(comptime T: type, comptime Elem: type) bool {
|
pub fn isStructOf(comptime T: type, comptime Elem: type) bool {
|
||||||
return switch (@typeInfo(T)) {
|
return switch (@typeInfo(T)) {
|
||||||
.Struct => |info| blk: {
|
.@"struct" => |info| blk: {
|
||||||
inline for (info.fields) |field| {
|
inline for (info.fields) |field| {
|
||||||
if (field.type != Elem) {
|
if (field.type != Elem) {
|
||||||
break :blk false;
|
break :blk false;
|
||||||
@ -36,7 +36,7 @@ pub fn isStructOf(comptime T: type, comptime Elem: type) bool {
|
|||||||
|
|
||||||
pub fn isStructOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
|
pub fn isStructOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
|
||||||
return switch (@typeInfo(T)) {
|
return switch (@typeInfo(T)) {
|
||||||
.Struct => |info| blk: {
|
.@"struct" => |info| blk: {
|
||||||
inline for (info.fields) |field| {
|
inline for (info.fields) |field| {
|
||||||
if (f(field.type) == false) {
|
if (f(field.type) == false) {
|
||||||
break :blk false;
|
break :blk false;
|
||||||
@ -58,11 +58,11 @@ pub fn isTupleOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool
|
|||||||
|
|
||||||
pub fn isSliceOf(comptime T: type, comptime Elem: type) bool {
|
pub fn isSliceOf(comptime T: type, comptime Elem: type) bool {
|
||||||
return switch (@typeInfo(T)) {
|
return switch (@typeInfo(T)) {
|
||||||
.Pointer => |info| switch (info.size) {
|
.pointer => |info| switch (info.size) {
|
||||||
.Slice => info.child == Elem,
|
.slice => info.child == Elem,
|
||||||
.One => switch (@typeInfo(info.child)) {
|
.one => switch (@typeInfo(info.child)) {
|
||||||
// As Zig, convert pointer to Array as a slice.
|
// As Zig, convert pointer to Array as a slice.
|
||||||
.Array => |arr_info| arr_info.child == Elem,
|
.array => |arr_info| arr_info.child == Elem,
|
||||||
else => false,
|
else => false,
|
||||||
},
|
},
|
||||||
else => false,
|
else => false,
|
||||||
@ -73,14 +73,14 @@ pub fn isSliceOf(comptime T: type, comptime Elem: type) bool {
|
|||||||
|
|
||||||
pub fn isInteger(comptime T: type) bool {
|
pub fn isInteger(comptime T: type) bool {
|
||||||
return switch (@typeInfo(T)) {
|
return switch (@typeInfo(T)) {
|
||||||
.Int, .ComptimeInt => true,
|
.int, .comptime_int => true,
|
||||||
else => false,
|
else => false,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn isSliceOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
|
pub fn isSliceOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
|
||||||
return switch (@typeInfo(T)) {
|
return switch (@typeInfo(T)) {
|
||||||
.Pointer => |info| info.size == .Slice and f(info.child),
|
.pointer => |info| info.size == .slice and f(info.child),
|
||||||
else => false,
|
else => false,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -95,8 +95,8 @@ pub fn DeclEnum(comptime T: type) type {
|
|||||||
|
|
||||||
pub fn UnwrapPtr(comptime T: type) type {
|
pub fn UnwrapPtr(comptime T: type) type {
|
||||||
return switch (@typeInfo(T)) {
|
return switch (@typeInfo(T)) {
|
||||||
.Pointer => |info| switch (info.size) {
|
.pointer => |info| switch (info.size) {
|
||||||
.One => info.child,
|
.one => info.child,
|
||||||
else => T,
|
else => T,
|
||||||
},
|
},
|
||||||
else => T,
|
else => T,
|
||||||
@ -106,11 +106,11 @@ pub fn UnwrapPtr(comptime T: type) type {
|
|||||||
pub fn asSlice(comptime T: type) type {
|
pub fn asSlice(comptime T: type) type {
|
||||||
const err_msg = "Type " ++ @typeName(T) ++ " can't be interpreted as a slice";
|
const err_msg = "Type " ++ @typeName(T) ++ " can't be interpreted as a slice";
|
||||||
return switch (@typeInfo(T)) {
|
return switch (@typeInfo(T)) {
|
||||||
.Pointer => |info| switch (info.size) {
|
.pointer => |info| switch (info.size) {
|
||||||
.Slice => info.child,
|
.slice => info.child,
|
||||||
.One => switch (@typeInfo(info.child)) {
|
.one => switch (@typeInfo(info.child)) {
|
||||||
// As Zig, convert pointer to Array as a slice.
|
// As Zig, convert pointer to Array as a slice.
|
||||||
.Array => |arr_info| arr_info.child,
|
.array => |arr_info| arr_info.child,
|
||||||
else => @compileError(err_msg),
|
else => @compileError(err_msg),
|
||||||
},
|
},
|
||||||
else => @compileError(err_msg),
|
else => @compileError(err_msg),
|
||||||
@ -137,7 +137,7 @@ pub fn TupleRangeX(comptime T: type, comptime start: usize, comptime end: usize)
|
|||||||
new_fields[j] = new_field;
|
new_fields[j] = new_field;
|
||||||
}
|
}
|
||||||
return @Type(.{
|
return @Type(.{
|
||||||
.Struct = .{
|
.@"struct" = .{
|
||||||
.is_tuple = true,
|
.is_tuple = true,
|
||||||
.layout = .auto,
|
.layout = .auto,
|
||||||
.decls = &.{},
|
.decls = &.{},
|
||||||
@ -147,26 +147,26 @@ pub fn TupleRangeX(comptime T: type, comptime start: usize, comptime end: usize)
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn FnParam(comptime func: anytype, comptime n: comptime_int) type {
|
pub fn FnParam(comptime func: anytype, comptime n: comptime_int) type {
|
||||||
return @typeInfo(@TypeOf(func)).Fn.params[n].type orelse @compileError("anytype is not supported");
|
return @typeInfo(@TypeOf(func)).@"fn".params[n].type orelse @compileError("anytype is not supported");
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn FnArgs(comptime func: anytype) type {
|
pub fn FnArgs(comptime func: anytype) type {
|
||||||
debug.assertComptime(!@typeInfo(@TypeOf(func)).Fn.is_generic, "FnArgs expects non generic function, got: {}", .{@TypeOf(func)});
|
debug.assertComptime(!@typeInfo(@TypeOf(func)).@"fn".is_generic, "FnArgs expects non generic function, got: {}", .{@TypeOf(func)});
|
||||||
return FnSignature(func, null).ArgsT;
|
return FnSignature(func, null).ArgsT;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn FnArgsWithHint(comptime func: anytype, ArgsT: type) type {
|
pub fn FnArgsWithHint(comptime func: anytype, ArgsT: type) type {
|
||||||
debug.assertComptime(@typeInfo(@TypeOf(func)).Fn.is_generic, "FnArgsWithHint expects a generic function, got: {}", .{@TypeOf(func)});
|
debug.assertComptime(@typeInfo(@TypeOf(func)).@"fn".is_generic, "FnArgsWithHint expects a generic function, got: {}", .{@TypeOf(func)});
|
||||||
return FnSignature(func, ArgsT).ArgsT;
|
return FnSignature(func, ArgsT).ArgsT;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn FnResult(comptime func: anytype) type {
|
pub fn FnResult(comptime func: anytype) type {
|
||||||
return @typeInfo(@TypeOf(func)).Fn.return_type orelse @compileError("anytype is not supported");
|
return @typeInfo(@TypeOf(func)).@"fn".return_type orelse @compileError("anytype is not supported");
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn Head(Tuple: type) type {
|
pub fn Head(Tuple: type) type {
|
||||||
return switch (@typeInfo(Tuple)) {
|
return switch (@typeInfo(Tuple)) {
|
||||||
.Struct => |struct_info| {
|
.@"struct" => |struct_info| {
|
||||||
if (struct_info.fields.len == 0) @compileError("Can't tail empty tuple");
|
if (struct_info.fields.len == 0) @compileError("Can't tail empty tuple");
|
||||||
return struct_info.fields[0].type;
|
return struct_info.fields[0].type;
|
||||||
},
|
},
|
||||||
@ -176,7 +176,7 @@ pub fn Head(Tuple: type) type {
|
|||||||
|
|
||||||
pub fn Tail(Tuple: type) type {
|
pub fn Tail(Tuple: type) type {
|
||||||
return switch (@typeInfo(Tuple)) {
|
return switch (@typeInfo(Tuple)) {
|
||||||
.Struct => |struct_info| {
|
.@"struct" => |struct_info| {
|
||||||
if (struct_info.fields.len == 0) @compileError("Can't tail empty tuple");
|
if (struct_info.fields.len == 0) @compileError("Can't tail empty tuple");
|
||||||
var types: [struct_info.fields.len - 1]type = undefined;
|
var types: [struct_info.fields.len - 1]type = undefined;
|
||||||
for (struct_info.fields[1..], 0..) |field, i| types[i] = field.type;
|
for (struct_info.fields[1..], 0..) |field, i| types[i] = field.type;
|
||||||
|
|||||||
@ -3,12 +3,12 @@ const std = @import("std");
|
|||||||
const compileError = @import("debug.zig").compileError;
|
const compileError = @import("debug.zig").compileError;
|
||||||
|
|
||||||
pub fn ArgsTuple(comptime funcT: anytype, comptime ArgsT: ?type) type {
|
pub fn ArgsTuple(comptime funcT: anytype, comptime ArgsT: ?type) type {
|
||||||
const params = @typeInfo(funcT).Fn.params;
|
const params = @typeInfo(funcT).@"fn".params;
|
||||||
if (params.len == 0) {
|
if (params.len == 0) {
|
||||||
return @TypeOf(.{});
|
return @TypeOf(.{});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (@typeInfo(funcT).Fn.is_generic == false) {
|
if (@typeInfo(funcT).@"fn".is_generic == false) {
|
||||||
return std.meta.ArgsTuple(funcT);
|
return std.meta.ArgsTuple(funcT);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -31,14 +31,14 @@ pub fn ArgsTuple(comptime funcT: anytype, comptime ArgsT: ?type) type {
|
|||||||
break :blk num_buf[0..s :0];
|
break :blk num_buf[0..s :0];
|
||||||
},
|
},
|
||||||
.type = T,
|
.type = T,
|
||||||
.default_value = null,
|
.default_value_ptr = null,
|
||||||
.is_comptime = false,
|
.is_comptime = false,
|
||||||
.alignment = if (@sizeOf(T) > 0) @alignOf(T) else 0,
|
.alignment = if (@sizeOf(T) > 0) @alignOf(T) else 0,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
return @Type(.{
|
return @Type(.{
|
||||||
.Struct = .{
|
.@"struct" = .{
|
||||||
.is_tuple = true,
|
.is_tuple = true,
|
||||||
.layout = .auto,
|
.layout = .auto,
|
||||||
.decls = &.{},
|
.decls = &.{},
|
||||||
@ -67,11 +67,11 @@ pub fn FnSignature(comptime func: anytype, comptime argsT_: ?type) Signature {
|
|||||||
.ArgsT = argsT,
|
.ArgsT = argsT,
|
||||||
.ReturnT = return_type,
|
.ReturnT = return_type,
|
||||||
.ReturnPayloadT = switch (@typeInfo(return_type)) {
|
.ReturnPayloadT = switch (@typeInfo(return_type)) {
|
||||||
.ErrorUnion => |u| u.payload,
|
.error_union => |u| u.payload,
|
||||||
else => return_type,
|
else => return_type,
|
||||||
},
|
},
|
||||||
.ReturnErrorSet = switch (@typeInfo(return_type)) {
|
.ReturnErrorSet = switch (@typeInfo(return_type)) {
|
||||||
.ErrorUnion => |u| u.error_set,
|
.error_union => |u| u.error_set,
|
||||||
else => null,
|
else => null,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|||||||
7
third_party/modules/libxev/20250313.0-5773f46/MODULE.bazel
vendored
Normal file
7
third_party/modules/libxev/20250313.0-5773f46/MODULE.bazel
vendored
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
module(
|
||||||
|
name = "libxev",
|
||||||
|
version = "20250313.0-5773f46",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "rules_zig", version = "20240904.0-010da15")
|
||||||
13
third_party/modules/libxev/20250313.0-5773f46/overlay/BUILD.bazel
vendored
Normal file
13
third_party/modules/libxev/20250313.0-5773f46/overlay/BUILD.bazel
vendored
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||||
|
|
||||||
|
zig_library(
|
||||||
|
name = "xev",
|
||||||
|
srcs = glob([
|
||||||
|
"src/*.zig",
|
||||||
|
"src/backend/*.zig",
|
||||||
|
"src/linux/*.zig",
|
||||||
|
"src/watcher/*.zig",
|
||||||
|
]),
|
||||||
|
main = "src/main.zig",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
7
third_party/modules/libxev/20250313.0-5773f46/overlay/MODULE.bazel
vendored
Normal file
7
third_party/modules/libxev/20250313.0-5773f46/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
module(
|
||||||
|
name = "libxev",
|
||||||
|
version = "20250313.0-5773f46",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "rules_zig", version = "20240904.0-010da15")
|
||||||
10
third_party/modules/libxev/20250313.0-5773f46/source.json
vendored
Normal file
10
third_party/modules/libxev/20250313.0-5773f46/source.json
vendored
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
{
|
||||||
|
"strip_prefix": "libxev-5773f46de3937e848db3ff87577add8ef7e8222d",
|
||||||
|
"url": "https://github.com/mitchellh/libxev/archive/5773f46de3937e848db3ff87577add8ef7e8222d.tar.gz",
|
||||||
|
"integrity": "sha256-k9cyG0DPhEKwM+/dKiLFT9WeM+c2kCokN8VftOriSxI=",
|
||||||
|
"overlay": {
|
||||||
|
"MODULE.bazel": "",
|
||||||
|
"BUILD.bazel": ""
|
||||||
|
},
|
||||||
|
"patch_strip": 1
|
||||||
|
}
|
||||||
3
third_party/modules/libxev/metadata.json
vendored
3
third_party/modules/libxev/metadata.json
vendored
@ -18,7 +18,8 @@
|
|||||||
"20241208.1-db6a52b",
|
"20241208.1-db6a52b",
|
||||||
"20241208.2-db6a52b",
|
"20241208.2-db6a52b",
|
||||||
"20250124.0-31eed4e",
|
"20250124.0-31eed4e",
|
||||||
"20250222.0-07bcffa"
|
"20250222.0-07bcffa",
|
||||||
|
"20250313.0-5773f46",
|
||||||
],
|
],
|
||||||
"yanked_versions": {}
|
"yanked_versions": {}
|
||||||
}
|
}
|
||||||
|
|||||||
8
third_party/modules/zig-protobuf/20250213.0-5304067/MODULE.bazel
vendored
Normal file
8
third_party/modules/zig-protobuf/20250213.0-5304067/MODULE.bazel
vendored
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
module(
|
||||||
|
name = "zig-protobuf",
|
||||||
|
version = "20250213.0-5304067",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "rules_zig", version = "20240904.0-010da15")
|
||||||
|
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||||
32
third_party/modules/zig-protobuf/20250213.0-5304067/overlay/BUILD.bazel
vendored
Normal file
32
third_party/modules/zig-protobuf/20250213.0-5304067/overlay/BUILD.bazel
vendored
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
load("@rules_proto//proto:defs.bzl", "proto_lang_toolchain")
|
||||||
|
load("@rules_zig//zig:defs.bzl", "BINARY_KIND", "zig_binary", "zig_library")
|
||||||
|
|
||||||
|
zig_library(
|
||||||
|
name = "protobuf",
|
||||||
|
import_name = "protobuf",
|
||||||
|
main = "src/protobuf.zig",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
|
zig_binary(
|
||||||
|
name = "generator",
|
||||||
|
srcs = [
|
||||||
|
"bootstrapped-generator/FullName.zig",
|
||||||
|
"bootstrapped-generator/google/protobuf/compiler/plugin.pb.zig",
|
||||||
|
"bootstrapped-generator/google/protobuf/descriptor.pb.zig",
|
||||||
|
],
|
||||||
|
kind = BINARY_KIND.exe,
|
||||||
|
main = "bootstrapped-generator/main.zig",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [":protobuf"],
|
||||||
|
)
|
||||||
|
|
||||||
|
proto_lang_toolchain(
|
||||||
|
name = "zig_toolchain",
|
||||||
|
command_line = "--zig_out=$(OUT)",
|
||||||
|
output_files = "multiple",
|
||||||
|
plugin = ":generator",
|
||||||
|
plugin_format_flag = "--plugin=protoc-gen-zig=%s",
|
||||||
|
runtime = ":protobuf",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
8
third_party/modules/zig-protobuf/20250213.0-5304067/overlay/MODULE.bazel
vendored
Normal file
8
third_party/modules/zig-protobuf/20250213.0-5304067/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
module(
|
||||||
|
name = "zig-protobuf",
|
||||||
|
version = "20250213.0-5304067",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "rules_zig", version = "20240904.0-010da15")
|
||||||
|
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||||
9
third_party/modules/zig-protobuf/20250213.0-5304067/source.json
vendored
Normal file
9
third_party/modules/zig-protobuf/20250213.0-5304067/source.json
vendored
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"strip_prefix": "zig-protobuf-5304067205135532c0ad57f78243bfe86dc1ad3f",
|
||||||
|
"url": "https://github.com/gwenzek/zig-protobuf/archive/5304067205135532c0ad57f78243bfe86dc1ad3f.tar.gz",
|
||||||
|
"integrity": "sha256-cdby+J1CGNOvLRvdeASPNlB7y1q59LYcgjFOooBvIcA=",
|
||||||
|
"overlay": {
|
||||||
|
"MODULE.bazel": "",
|
||||||
|
"BUILD.bazel": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -11,7 +11,8 @@
|
|||||||
"gwenzek/zig-protobuf"
|
"gwenzek/zig-protobuf"
|
||||||
],
|
],
|
||||||
"versions": [
|
"versions": [
|
||||||
"20240722.0-c644d11"
|
"20240722.0-c644d11",
|
||||||
|
"20250213.0-5304067",
|
||||||
],
|
],
|
||||||
"yanked_versions": {}
|
"yanked_versions": {}
|
||||||
}
|
}
|
||||||
|
|||||||
6
third_party/zls/zls.bzl
vendored
6
third_party/zls/zls.bzl
vendored
@ -1,17 +1,17 @@
|
|||||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||||
|
|
||||||
_VERSION = "0.13.0"
|
_VERSION = "0.14.0"
|
||||||
|
|
||||||
_ARCH = {
|
_ARCH = {
|
||||||
"x86_64-linux": struct(
|
"x86_64-linux": struct(
|
||||||
sha256 = "ec4c1b45caf88e2bcb9ebb16c670603cc596e4f621b96184dfbe837b39cd8410",
|
sha256 = "661f8d402ba3dc9b04b6e9bc3026495be7b838d2f18d148db2bd98bd699c1360",
|
||||||
exec_compatible_with = [
|
exec_compatible_with = [
|
||||||
"@platforms//os:linux",
|
"@platforms//os:linux",
|
||||||
"@platforms//cpu:x86_64",
|
"@platforms//cpu:x86_64",
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
"aarch64-macos": struct(
|
"aarch64-macos": struct(
|
||||||
sha256 = "9848514524f5e5d33997ac280b7d92388407209d4b8d4be3866dc3cf30ca6ca8",
|
sha256 = "dfb627e1f9603583678f552d8035a12dce878215c0a507b32d6f1b9d074d6c4d",
|
||||||
exec_compatible_with = [
|
exec_compatible_with = [
|
||||||
"@platforms//os:macos",
|
"@platforms//os:macos",
|
||||||
"@platforms//cpu:aarch64",
|
"@platforms//cpu:aarch64",
|
||||||
|
|||||||
46
zml/aio.zig
46
zml/aio.zig
@ -227,7 +227,7 @@ pub const Metadata = union(enum) {
|
|||||||
pub const MemoryMappedFile = struct {
|
pub const MemoryMappedFile = struct {
|
||||||
/// underlying file handle
|
/// underlying file handle
|
||||||
file: asynk.File,
|
file: asynk.File,
|
||||||
data: []align(std.mem.page_size) const u8,
|
data: []align(std.heap.page_size_min) const u8,
|
||||||
data_offset: u64 = 0,
|
data_offset: u64 = 0,
|
||||||
|
|
||||||
pub fn init(file: asynk.File) !MemoryMappedFile {
|
pub fn init(file: asynk.File) !MemoryMappedFile {
|
||||||
@ -236,7 +236,7 @@ pub const MemoryMappedFile = struct {
|
|||||||
null,
|
null,
|
||||||
data_len,
|
data_len,
|
||||||
std.posix.PROT.READ,
|
std.posix.PROT.READ,
|
||||||
.{ .TYPE = .PRIVATE },
|
std.posix.system.MAP{ .TYPE = .PRIVATE },
|
||||||
file.handle(),
|
file.handle(),
|
||||||
0,
|
0,
|
||||||
});
|
});
|
||||||
@ -305,7 +305,7 @@ const PrefixBuilder = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn pop(self: *PrefixBuilder) void {
|
pub fn pop(self: *PrefixBuilder) void {
|
||||||
const last_prefix_len = self.subprefixes.popOrNull() orelse unreachable;
|
const last_prefix_len = self.subprefixes.pop() orelse unreachable;
|
||||||
self.data.shrinkRetainingCapacity(last_prefix_len);
|
self.data.shrinkRetainingCapacity(last_prefix_len);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -320,8 +320,8 @@ fn _populateStruct(
|
|||||||
) !bool {
|
) !bool {
|
||||||
const err_msg = "_populateStruct must be called with a pointer to type. Received ";
|
const err_msg = "_populateStruct must be called with a pointer to type. Received ";
|
||||||
const type_info, const T = switch (@typeInfo(@TypeOf(obj))) {
|
const type_info, const T = switch (@typeInfo(@TypeOf(obj))) {
|
||||||
.Pointer => |ptr_info| switch (ptr_info.size) {
|
.pointer => |ptr_info| switch (ptr_info.size) {
|
||||||
.One => .{ @typeInfo(ptr_info.child), ptr_info.child },
|
.one => .{ @typeInfo(ptr_info.child), ptr_info.child },
|
||||||
else => @compileError(err_msg ++ @typeName(@TypeOf(obj))),
|
else => @compileError(err_msg ++ @typeName(@TypeOf(obj))),
|
||||||
},
|
},
|
||||||
else => @compileError(err_msg ++ @typeName(@TypeOf(obj))),
|
else => @compileError(err_msg ++ @typeName(@TypeOf(obj))),
|
||||||
@ -346,8 +346,8 @@ fn _populateStruct(
|
|||||||
}
|
}
|
||||||
|
|
||||||
return switch (type_info) {
|
return switch (type_info) {
|
||||||
.Pointer => |ptr_info| {
|
.pointer => |ptr_info| {
|
||||||
if (ptr_info.size == .Slice) {
|
if (ptr_info.size == .slice) {
|
||||||
obj.* = &.{};
|
obj.* = &.{};
|
||||||
|
|
||||||
const len = buffer_store.countLayers(prefix);
|
const len = buffer_store.countLayers(prefix);
|
||||||
@ -372,7 +372,7 @@ fn _populateStruct(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
.Array => |arr_info| {
|
.array => |arr_info| {
|
||||||
for (obj, 0..) |*value, i| {
|
for (obj, 0..) |*value, i| {
|
||||||
try prefix_builder.pushDigit(allocator, i);
|
try prefix_builder.pushDigit(allocator, i);
|
||||||
defer prefix_builder.pop();
|
defer prefix_builder.pop();
|
||||||
@ -384,7 +384,7 @@ fn _populateStruct(
|
|||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
.Struct => |struct_info| {
|
.@"struct" => |struct_info| {
|
||||||
var partial_struct = false;
|
var partial_struct = false;
|
||||||
inline for (struct_info.fields) |field| {
|
inline for (struct_info.fields) |field| {
|
||||||
if (field.is_comptime or @sizeOf(field.type) == 0) continue;
|
if (field.is_comptime or @sizeOf(field.type) == 0) continue;
|
||||||
@ -392,11 +392,11 @@ fn _populateStruct(
|
|||||||
defer prefix_builder.pop();
|
defer prefix_builder.pop();
|
||||||
|
|
||||||
var has_default = false;
|
var has_default = false;
|
||||||
if (field.default_value) |_| has_default = true;
|
if (field.default_value_ptr) |_| has_default = true;
|
||||||
const field_found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &@field(obj, field.name), required and !has_default);
|
const field_found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &@field(obj, field.name), required and !has_default);
|
||||||
partial_struct = partial_struct or field_found;
|
partial_struct = partial_struct or field_found;
|
||||||
if (!field_found) {
|
if (!field_found) {
|
||||||
if (field.default_value) |v| {
|
if (field.default_value_ptr) |v| {
|
||||||
@field(obj, field.name) = @as(*const field.type, @alignCast(@ptrCast(v))).*;
|
@field(obj, field.name) = @as(*const field.type, @alignCast(@ptrCast(v))).*;
|
||||||
} else {
|
} else {
|
||||||
if (partial_struct) {
|
if (partial_struct) {
|
||||||
@ -411,22 +411,22 @@ fn _populateStruct(
|
|||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
.Optional => |opt_info| {
|
.optional => |opt_info| {
|
||||||
obj.* = @as(opt_info.child, undefined);
|
obj.* = @as(opt_info.child, undefined);
|
||||||
const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &(obj.*.?), false);
|
const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &(obj.*.?), false);
|
||||||
if (!found) obj.* = null;
|
if (!found) obj.* = null;
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
.Int => {
|
.int => {
|
||||||
obj.* = undefined;
|
obj.* = undefined;
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
.Float => {
|
.float => {
|
||||||
obj.* = undefined;
|
obj.* = undefined;
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
.Void => true,
|
.void => true,
|
||||||
.Union => true,
|
.@"union" => true,
|
||||||
else => if (required) {
|
else => if (required) {
|
||||||
log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) });
|
log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) });
|
||||||
return error.UnsupportedMetadataType;
|
return error.UnsupportedMetadataType;
|
||||||
@ -635,8 +635,8 @@ pub fn awaitAll(buffers: anytype) !void {
|
|||||||
fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, buffer_store: BufferStore, obj: anytype, platform: zml.Platform) !void {
|
fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, buffer_store: BufferStore, obj: anytype, platform: zml.Platform) !void {
|
||||||
const err_msg = "visitStructAndLoadBuffer must be called with a pointer to type. Received ";
|
const err_msg = "visitStructAndLoadBuffer must be called with a pointer to type. Received ";
|
||||||
const type_info, const T = switch (@typeInfo(@TypeOf(obj))) {
|
const type_info, const T = switch (@typeInfo(@TypeOf(obj))) {
|
||||||
.Pointer => |ptr_info| switch (ptr_info.size) {
|
.pointer => |ptr_info| switch (ptr_info.size) {
|
||||||
.One => .{ @typeInfo(ptr_info.child), ptr_info.child },
|
.one => .{ @typeInfo(ptr_info.child), ptr_info.child },
|
||||||
else => @compileError(err_msg ++ @typeName(@TypeOf(obj))),
|
else => @compileError(err_msg ++ @typeName(@TypeOf(obj))),
|
||||||
},
|
},
|
||||||
else => @compileError(err_msg ++ @typeName(@TypeOf(obj))),
|
else => @compileError(err_msg ++ @typeName(@TypeOf(obj))),
|
||||||
@ -661,8 +661,8 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
} else if (T == zml.Shape) return;
|
} else if (T == zml.Shape) return;
|
||||||
|
|
||||||
switch (type_info) {
|
switch (type_info) {
|
||||||
.Pointer => |ptr_info| {
|
.pointer => |ptr_info| {
|
||||||
if (ptr_info.size == .Slice) {
|
if (ptr_info.size == .slice) {
|
||||||
for (obj.*, 0..) |*value, i| {
|
for (obj.*, 0..) |*value, i| {
|
||||||
try prefix_builder.pushDigit(allocator, i);
|
try prefix_builder.pushDigit(allocator, i);
|
||||||
defer prefix_builder.pop();
|
defer prefix_builder.pop();
|
||||||
@ -671,7 +671,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
}
|
}
|
||||||
} else stdx.debug.compileError("type not supported by visitStructAndLoadBuffer: {}", .{T});
|
} else stdx.debug.compileError("type not supported by visitStructAndLoadBuffer: {}", .{T});
|
||||||
},
|
},
|
||||||
.Array => {
|
.array => {
|
||||||
for (obj, 0..) |*value, i| {
|
for (obj, 0..) |*value, i| {
|
||||||
try prefix_builder.pushDigit(allocator, i);
|
try prefix_builder.pushDigit(allocator, i);
|
||||||
defer prefix_builder.pop();
|
defer prefix_builder.pop();
|
||||||
@ -679,7 +679,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
.Struct => |struct_info| {
|
.@"struct" => |struct_info| {
|
||||||
inline for (struct_info.fields) |field| {
|
inline for (struct_info.fields) |field| {
|
||||||
if (field.is_comptime or @sizeOf(field.type) == 0) continue;
|
if (field.is_comptime or @sizeOf(field.type) == 0) continue;
|
||||||
try prefix_builder.push(allocator, field.name);
|
try prefix_builder.push(allocator, field.name);
|
||||||
@ -688,7 +688,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &@field(obj, field.name), platform);
|
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &@field(obj, field.name), platform);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
.Optional => {
|
.optional => {
|
||||||
if (obj.*) |*obj_val| {
|
if (obj.*) |*obj_val| {
|
||||||
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, obj_val, platform);
|
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, obj_val, platform);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -45,7 +45,7 @@ fn loadMetadata(allocator: Allocator, store: *zml.aio.BufferStore, file: *core.G
|
|||||||
res.value_ptr.* = switch (entry.val) {
|
res.value_ptr.* = switch (entry.val) {
|
||||||
.array => |arr| switch (arr.child) {
|
.array => |arr| switch (arr.child) {
|
||||||
inline .uint8, .int8, .uint16, .int16, .uint32, .int32, .float32, .bool, .string, .uint64, .int64, .float64 => |tag| blk: {
|
inline .uint8, .int8, .uint16, .int16, .uint32, .int32, .float32, .bool, .string, .uint64, .int64, .float64 => |tag| blk: {
|
||||||
const T = std.meta.FieldType(core.GgufValue, tag);
|
const T = @FieldType(core.GgufValue, @tagName(tag));
|
||||||
break :blk try zml.aio.Metadata.copySlice(allocator, std.mem.bytesAsSlice(T, arr.data));
|
break :blk try zml.aio.Metadata.copySlice(allocator, std.mem.bytesAsSlice(T, arr.data));
|
||||||
},
|
},
|
||||||
else => blk: {
|
else => blk: {
|
||||||
|
|||||||
@ -84,7 +84,7 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.Array
|
|||||||
buffer_file.data_offset = 8 + json_header_length;
|
buffer_file.data_offset = 8 + json_header_length;
|
||||||
|
|
||||||
try files.append(buffer_file);
|
try files.append(buffer_file);
|
||||||
errdefer _ = files.popOrNull();
|
errdefer _ = files.pop();
|
||||||
|
|
||||||
var it = metadata.object.iterator();
|
var it = metadata.object.iterator();
|
||||||
while (it.next()) |entry| {
|
while (it.next()) |entry| {
|
||||||
|
|||||||
@ -401,11 +401,11 @@ test evaluate {
|
|||||||
try std.testing.expectEqualDeep(expected, entries);
|
try std.testing.expectEqualDeep(expected, entries);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn pop(values: *std.ArrayList(py.Any)) !py.Any {
|
pub fn pop(values: *std.ArrayList(py.Any)) error{StackUnderrun}!py.Any {
|
||||||
if (values.items.len == 0) {
|
if (values.items.len == 0) {
|
||||||
return error.StackUnderrun;
|
return error.StackUnderrun;
|
||||||
}
|
}
|
||||||
return values.pop();
|
return values.pop().?;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn popMark(values: *std.ArrayList(py.Any)) ![]py.Any {
|
fn popMark(values: *std.ArrayList(py.Any)) ![]py.Any {
|
||||||
|
|||||||
@ -1058,8 +1058,8 @@ fn _readSlice(reader: anytype, allocator: std.mem.Allocator, comptime len_bytes:
|
|||||||
return buf;
|
return buf;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn writeIntBuff(comptime T: type, value: T) [@divExact(@typeInfo(T).Int.bits, 8)]u8 {
|
fn writeIntBuff(comptime T: type, value: T) [@divExact(@typeInfo(T).int.bits, 8)]u8 {
|
||||||
var res: [@divExact(@typeInfo(T).Int.bits, 8)]u8 = undefined;
|
var res: [@divExact(@typeInfo(T).int.bits, 8)]u8 = undefined;
|
||||||
std.mem.writeInt(T, &res, value, .little);
|
std.mem.writeInt(T, &res, value, .little);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -27,7 +27,7 @@ const log = std.log.scoped(.zml);
|
|||||||
/// * loading weights from disk directly to the `device zml.aio.loadBuffers`
|
/// * loading weights from disk directly to the `device zml.aio.loadBuffers`
|
||||||
/// * can be created by calling `HostBuffer.toDevice(platform)`.
|
/// * can be created by calling `HostBuffer.toDevice(platform)`.
|
||||||
pub const Buffer = struct {
|
pub const Buffer = struct {
|
||||||
pub const Memory = enum(@typeInfo(pjrt.Memory.Kind).Enum.tag_type) {
|
pub const Memory = enum(@typeInfo(pjrt.Memory.Kind).@"enum".tag_type) {
|
||||||
host = @intFromEnum(pjrt.Memory.Kind.unpinned_host),
|
host = @intFromEnum(pjrt.Memory.Kind.unpinned_host),
|
||||||
host_pinned = @intFromEnum(pjrt.Memory.Kind.pinned_host),
|
host_pinned = @intFromEnum(pjrt.Memory.Kind.pinned_host),
|
||||||
device = @intFromEnum(pjrt.Memory.Kind.device),
|
device = @intFromEnum(pjrt.Memory.Kind.device),
|
||||||
@ -91,7 +91,7 @@ pub const Buffer = struct {
|
|||||||
const frame = try asynk.asyncc(pjrt.Client.bufferFromHostBuffer, .{
|
const frame = try asynk.asyncc(pjrt.Client.bufferFromHostBuffer, .{
|
||||||
platform.pjrt_client,
|
platform.pjrt_client,
|
||||||
platform.pjrt_api,
|
platform.pjrt_api,
|
||||||
.{
|
pjrt.Client.BufferFromHostBufferArgs{
|
||||||
.data = buf.data,
|
.data = buf.data,
|
||||||
.buffer_type = buffer_type,
|
.buffer_type = buffer_type,
|
||||||
.dims = buf.shape().dims(),
|
.dims = buf.shape().dims(),
|
||||||
@ -239,7 +239,7 @@ pub const Buffer = struct {
|
|||||||
// TODO: exposes sharding in the API.
|
// TODO: exposes sharding in the API.
|
||||||
.device = platform.getDevices()[0],
|
.device = platform.getDevices()[0],
|
||||||
.layout = .{
|
.layout = .{
|
||||||
.Tiled = .{
|
.tiled = .{
|
||||||
.minor_to_major = minor_to_major[Shape.MAX_RANK - shape_.rank() ..],
|
.minor_to_major = minor_to_major[Shape.MAX_RANK - shape_.rank() ..],
|
||||||
.tile_dims = &.{},
|
.tile_dims = &.{},
|
||||||
.tile_dims_sizes = &.{},
|
.tile_dims_sizes = &.{},
|
||||||
@ -404,12 +404,12 @@ pub fn dtypeFromBufferType(pjrt_type: pjrt.BufferType) DataType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test bufferTypeFromDtype {
|
test bufferTypeFromDtype {
|
||||||
inline for (@typeInfo(DataType).Enum.fields) |field| {
|
inline for (@typeInfo(DataType).@"enum".fields) |field| {
|
||||||
const dt: DataType = @enumFromInt(field.value);
|
const dt: DataType = @enumFromInt(field.value);
|
||||||
try std.testing.expectEqual(dt, dtypeFromBufferType(bufferTypeFromDtype(dt)));
|
try std.testing.expectEqual(dt, dtypeFromBufferType(bufferTypeFromDtype(dt)));
|
||||||
}
|
}
|
||||||
|
|
||||||
inline for (@typeInfo(pjrt.BufferType).Enum.fields) |field| {
|
inline for (@typeInfo(pjrt.BufferType).@"enum".fields) |field| {
|
||||||
const dt: pjrt.BufferType = @enumFromInt(field.value);
|
const dt: pjrt.BufferType = @enumFromInt(field.value);
|
||||||
if (dt == .INVALID) continue;
|
if (dt == .INVALID) continue;
|
||||||
try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt)));
|
try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt)));
|
||||||
|
|||||||
@ -97,15 +97,15 @@ pub const DataType = enum(u8) {
|
|||||||
|
|
||||||
pub fn fromSliceElementType(slice: anytype) DataType {
|
pub fn fromSliceElementType(slice: anytype) DataType {
|
||||||
const type_info = @typeInfo(@TypeOf(slice));
|
const type_info = @typeInfo(@TypeOf(slice));
|
||||||
if (type_info != .Pointer) {
|
if (type_info != .pointer) {
|
||||||
@compileError("`initFromSlice` expects a slice, got " ++ @tagName(type_info));
|
@compileError("`initFromSlice` expects a slice, got " ++ @tagName(type_info));
|
||||||
}
|
}
|
||||||
|
|
||||||
return switch (type_info.Pointer.size) {
|
return switch (type_info.pointer.size) {
|
||||||
.Slice, .C, .Many => DataType.fromZigType(type_info.Pointer.child),
|
.slice, .c, .many => DataType.fromZigType(type_info.pointer.child),
|
||||||
.One => b: {
|
.one => b: {
|
||||||
const child_type_info = @typeInfo(type_info.Pointer.child);
|
const child_type_info = @typeInfo(type_info.pointer.child);
|
||||||
break :b DataType.fromZigType(child_type_info.Array.child);
|
break :b DataType.fromZigType(child_type_info.array.child);
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -182,10 +182,10 @@ pub const DataType = enum(u8) {
|
|||||||
pub fn minValue(dtype: DataType) Data {
|
pub fn minValue(dtype: DataType) Data {
|
||||||
return switch (dtype) {
|
return switch (dtype) {
|
||||||
.bool => .{ .bool = false },
|
.bool => .{ .bool = false },
|
||||||
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz => |tag| @unionInit(Data, @tagName(tag), std.meta.FieldType(Data, tag).zero()),
|
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).zero()),
|
||||||
inline .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), std.meta.FieldType(Data, tag).minusInf()),
|
inline .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).minusInf()),
|
||||||
inline .f16, .f32, .f64 => |tag| @unionInit(Data, @tagName(tag), -std.math.inf(std.meta.FieldType(Data, tag))),
|
inline .f16, .f32, .f64 => |tag| @unionInit(Data, @tagName(tag), -std.math.inf(@FieldType(Data, @tagName(tag)))),
|
||||||
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| @unionInit(Data, @tagName(tag), std.math.minInt(std.meta.FieldType(Data, tag))),
|
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| @unionInit(Data, @tagName(tag), std.math.minInt(@FieldType(Data, @tagName(tag)))),
|
||||||
inline else => |tag| @panic("Unsupported type: " ++ @tagName(tag)),
|
inline else => |tag| @panic("Unsupported type: " ++ @tagName(tag)),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -194,9 +194,9 @@ pub const DataType = enum(u8) {
|
|||||||
return switch (dtype) {
|
return switch (dtype) {
|
||||||
.bool => .{ .bool = true },
|
.bool => .{ .bool = true },
|
||||||
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)),
|
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2fnuz => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)),
|
||||||
inline .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), std.meta.FieldType(Data, tag).inf()),
|
inline .f8e5m2, .bf16 => |tag| @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).inf()),
|
||||||
inline .f16, .f32, .f64 => |tag| @unionInit(Data, @tagName(tag), std.math.inf(std.meta.FieldType(Data, tag))),
|
inline .f16, .f32, .f64 => |tag| @unionInit(Data, @tagName(tag), std.math.inf(@FieldType(Data, @tagName(tag)))),
|
||||||
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| @unionInit(Data, @tagName(tag), std.math.maxInt(std.meta.FieldType(Data, tag))),
|
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| @unionInit(Data, @tagName(tag), std.math.maxInt(@FieldType(Data, @tagName(tag)))),
|
||||||
inline .c64, .c128 => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)),
|
inline .c64, .c128 => |tag| @panic("DataType doesn't have a max value: " ++ @tagName(tag)),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -241,30 +241,30 @@ pub const Data = union(DataType) {
|
|||||||
|
|
||||||
return switch (dtype_) {
|
return switch (dtype_) {
|
||||||
.bool => switch (Ti) {
|
.bool => switch (Ti) {
|
||||||
.Bool => .{ .bool = value },
|
.bool => .{ .bool = value },
|
||||||
.ComptimeInt, .Int, .ComptimeFloat, .Float => .{ .bool = value != 0 },
|
.comptime_int, .int, .comptime_float, .float => .{ .bool = value != 0 },
|
||||||
else => @panic("Could not create Data of type bool from value of type " ++ @typeName(T)),
|
else => @panic("Could not create Data of type bool from value of type " ++ @typeName(T)),
|
||||||
},
|
},
|
||||||
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16 => |tag| switch (Ti) {
|
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16 => |tag| switch (Ti) {
|
||||||
.ComptimeInt, .Int => @unionInit(Data, @tagName(tag), std.meta.FieldType(Data, tag).fromF32(@floatFromInt(value))),
|
.comptime_int, .int => @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).fromF32(@floatFromInt(value))),
|
||||||
.ComptimeFloat, .Float => @unionInit(Data, @tagName(tag), std.meta.FieldType(Data, tag).fromF32(@floatCast(value))),
|
.comptime_float, .float => @unionInit(Data, @tagName(tag), @FieldType(Data, @tagName(tag)).fromF32(@floatCast(value))),
|
||||||
else => @panic("Could not create Data of type bf16 from value of type " ++ @typeName(T)),
|
else => @panic("Could not create Data of type bf16 from value of type " ++ @typeName(T)),
|
||||||
},
|
},
|
||||||
inline .f16, .f32, .f64 => |tag| switch (Ti) {
|
inline .f16, .f32, .f64 => |tag| switch (Ti) {
|
||||||
.ComptimeInt, .Int => @unionInit(Data, @tagName(tag), @floatFromInt(value)),
|
.comptime_int, .int => @unionInit(Data, @tagName(tag), @floatFromInt(value)),
|
||||||
.ComptimeFloat, .Float => @unionInit(Data, @tagName(tag), @floatCast(value)),
|
.comptime_float, .float => @unionInit(Data, @tagName(tag), @floatCast(value)),
|
||||||
else => @panic("Could not create Data of type " ++ @tagName(tag) ++ " from value of type " ++ @typeName(T)),
|
else => @panic("Could not create Data of type " ++ @tagName(tag) ++ " from value of type " ++ @typeName(T)),
|
||||||
},
|
},
|
||||||
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| switch (Ti) {
|
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |tag| switch (Ti) {
|
||||||
.ComptimeInt => blk: {
|
.comptime_int => blk: {
|
||||||
const OutT = std.meta.FieldType(Data, tag);
|
const OutT = @FieldType(Data, @tagName(tag));
|
||||||
if (value >= std.math.minInt(OutT) and value <= std.math.maxInt(OutT)) {
|
if (value >= std.math.minInt(OutT) and value <= std.math.maxInt(OutT)) {
|
||||||
break :blk @unionInit(Data, @tagName(tag), @intCast(value));
|
break :blk @unionInit(Data, @tagName(tag), @intCast(value));
|
||||||
} else {
|
} else {
|
||||||
@panic("Could not create Data of type " ++ @tagName(tag) ++ " from value of type " ++ @typeName(T));
|
@panic("Could not create Data of type " ++ @tagName(tag) ++ " from value of type " ++ @typeName(T));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
.Int => @unionInit(Data, @tagName(tag), @intCast(value)),
|
.int => @unionInit(Data, @tagName(tag), @intCast(value)),
|
||||||
else => @panic("Could not create Data of type " ++ @tagName(tag) ++ " from value of type " ++ @typeName(T)),
|
else => @panic("Could not create Data of type " ++ @tagName(tag) ++ " from value of type " ++ @typeName(T)),
|
||||||
},
|
},
|
||||||
.c64 => switch (T) {
|
.c64 => switch (T) {
|
||||||
@ -316,13 +316,13 @@ pub const Data = union(DataType) {
|
|||||||
pub fn as(self: Data, comptime T: type) T {
|
pub fn as(self: Data, comptime T: type) T {
|
||||||
// TODO allow more lossless conversions
|
// TODO allow more lossless conversions
|
||||||
switch (@typeInfo(T)) {
|
switch (@typeInfo(T)) {
|
||||||
.Bool => return self.bool,
|
.bool => return self.bool,
|
||||||
.Float => switch (self) {
|
.float => switch (self) {
|
||||||
inline .f16, .f32, .f64 => |v| return @floatCast(v),
|
inline .f16, .f32, .f64 => |v| return @floatCast(v),
|
||||||
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16 => |v| return @floatCast(v.toF32()),
|
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz, .bf16 => |v| return @floatCast(v.toF32()),
|
||||||
else => {},
|
else => {},
|
||||||
},
|
},
|
||||||
.Int => switch (self) {
|
.int => switch (self) {
|
||||||
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |v| return @intCast(v),
|
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |v| return @intCast(v),
|
||||||
else => {},
|
else => {},
|
||||||
},
|
},
|
||||||
|
|||||||
@ -97,7 +97,7 @@ pub fn FnExe(comptime func: anytype) type {
|
|||||||
/// ```
|
/// ```
|
||||||
pub fn ModuleExe(comptime func: anytype) type {
|
pub fn ModuleExe(comptime func: anytype) type {
|
||||||
const AllArgs = stdx.meta.FnArgs(func);
|
const AllArgs = stdx.meta.FnArgs(func);
|
||||||
const len = @typeInfo(AllArgs).Struct.fields.len;
|
const len = @typeInfo(AllArgs).@"struct".fields.len;
|
||||||
stdx.debug.assertComptime(len > 0, "ModuleExe expects a function with at least one argument where the first one is treated as the module, got {}", .{func});
|
stdx.debug.assertComptime(len > 0, "ModuleExe expects a function with at least one argument where the first one is treated as the module, got {}", .{func});
|
||||||
return Exe(stdx.meta.Tail(AllArgs), stdx.meta.FnResult(func));
|
return Exe(stdx.meta.Tail(AllArgs), stdx.meta.FnResult(func));
|
||||||
}
|
}
|
||||||
@ -113,7 +113,7 @@ const Sign = struct {
|
|||||||
|
|
||||||
pub fn ModuleSignature(comptime func: anytype) Sign {
|
pub fn ModuleSignature(comptime func: anytype) Sign {
|
||||||
const AllArgsT = stdx.meta.FnArgs(func);
|
const AllArgsT = stdx.meta.FnArgs(func);
|
||||||
const len = @typeInfo(AllArgsT).Struct.fields.len;
|
const len = @typeInfo(AllArgsT).@"struct".fields.len;
|
||||||
stdx.debug.assertComptime(len > 0, "ModuleExe expects a function with at least one argument where the first one is treated as the module, got {}", .{func});
|
stdx.debug.assertComptime(len > 0, "ModuleExe expects a function with at least one argument where the first one is treated as the module, got {}", .{func});
|
||||||
|
|
||||||
return .{
|
return .{
|
||||||
|
|||||||
@ -129,12 +129,12 @@ fn ShapeStruct(comptime dims: anytype) type {
|
|||||||
struct_field.* = .{
|
struct_field.* = .{
|
||||||
.name = @tagName(axis),
|
.name = @tagName(axis),
|
||||||
.type = i64,
|
.type = i64,
|
||||||
.default_value = &default,
|
.default_value_ptr = &default,
|
||||||
.is_comptime = false,
|
.is_comptime = false,
|
||||||
.alignment = @alignOf(i64),
|
.alignment = @alignOf(i64),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
return @Type(.{ .Struct = .{
|
return @Type(.{ .@"struct" = .{
|
||||||
.layout = .@"extern",
|
.layout = .@"extern",
|
||||||
.fields = &struct_fields,
|
.fields = &struct_fields,
|
||||||
.decls = &.{},
|
.decls = &.{},
|
||||||
|
|||||||
@ -19,7 +19,7 @@ pub const HostBuffer = struct {
|
|||||||
_strides: ?[Shape.MAX_RANK]i64 = null,
|
_strides: ?[Shape.MAX_RANK]i64 = null,
|
||||||
data: []const u8,
|
data: []const u8,
|
||||||
_memory: union(enum) {
|
_memory: union(enum) {
|
||||||
managed: u5,
|
managed: std.mem.Alignment,
|
||||||
unmanaged,
|
unmanaged,
|
||||||
} = .unmanaged,
|
} = .unmanaged,
|
||||||
|
|
||||||
@ -29,8 +29,8 @@ pub const HostBuffer = struct {
|
|||||||
pub fn empty(allocator: std.mem.Allocator, sh: Shape) !HostBuffer {
|
pub fn empty(allocator: std.mem.Allocator, sh: Shape) !HostBuffer {
|
||||||
return .{
|
return .{
|
||||||
._shape = sh,
|
._shape = sh,
|
||||||
.data = try allocator.alignedAlloc(u8, std.atomic.cache_line, sh.byteSize()),
|
.data = try allocator.alignedAlloc(u8, 64, sh.byteSize()),
|
||||||
._memory = .{ .managed = std.math.log2_int(u16, std.atomic.cache_line) },
|
._memory = .{ .managed = .@"64" },
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -174,9 +174,10 @@ pub const HostBuffer = struct {
|
|||||||
return if (self._strides) |*strd| strd[0..self.rank()] else null;
|
return if (self._strides) |*strd| strd[0..self.rank()] else null;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn data(self: HostBuffer) []const u8 {
|
// TODO: rename .data into ._data and make it a [*]u8
|
||||||
return self.data;
|
// pub fn data(self: HostBuffer) []const u8 {
|
||||||
}
|
// return self.data;
|
||||||
|
// }
|
||||||
|
|
||||||
pub inline fn rank(self: HostBuffer) u4 {
|
pub inline fn rank(self: HostBuffer) u4 {
|
||||||
return self._shape.rank();
|
return self._shape.rank();
|
||||||
@ -336,7 +337,7 @@ pub const HostBuffer = struct {
|
|||||||
|
|
||||||
fn parseArrayInfo(T: type) Shape {
|
fn parseArrayInfo(T: type) Shape {
|
||||||
return switch (@typeInfo(T)) {
|
return switch (@typeInfo(T)) {
|
||||||
.Array => |arr| {
|
.array => |arr| {
|
||||||
const s = parseArrayInfo(arr.child);
|
const s = parseArrayInfo(arr.child);
|
||||||
return s.insert(0, .{arr.len});
|
return s.insert(0, .{arr.len});
|
||||||
},
|
},
|
||||||
|
|||||||
104
zml/meta.zig
104
zml/meta.zig
@ -23,7 +23,7 @@ pub fn MapType(From: type, To: type) type {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return switch (@typeInfo(T)) {
|
return switch (@typeInfo(T)) {
|
||||||
.Struct => |struct_infos| {
|
.@"struct" => |struct_infos| {
|
||||||
const fields = struct_infos.fields;
|
const fields = struct_infos.fields;
|
||||||
var same: bool = true;
|
var same: bool = true;
|
||||||
var struct_fields: [fields.len]std.builtin.Type.StructField = undefined;
|
var struct_fields: [fields.len]std.builtin.Type.StructField = undefined;
|
||||||
@ -36,7 +36,7 @@ pub fn MapType(From: type, To: type) type {
|
|||||||
struct_field.* = .{
|
struct_field.* = .{
|
||||||
.name = field.name,
|
.name = field.name,
|
||||||
.type = R,
|
.type = R,
|
||||||
.default_value = null,
|
.default_value_ptr = null,
|
||||||
.is_comptime = field.is_comptime,
|
.is_comptime = field.is_comptime,
|
||||||
.alignment = @alignOf(R),
|
.alignment = @alignOf(R),
|
||||||
};
|
};
|
||||||
@ -45,7 +45,7 @@ pub fn MapType(From: type, To: type) type {
|
|||||||
// Generic handling of default value is complicated,
|
// Generic handling of default value is complicated,
|
||||||
// it would require to call the callback at comptime.
|
// it would require to call the callback at comptime.
|
||||||
if (R == ?To) {
|
if (R == ?To) {
|
||||||
struct_field.default_value = &@as(R, null);
|
struct_field.default_value_ptr = &@as(R, null);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -53,26 +53,26 @@ pub fn MapType(From: type, To: type) type {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (same) return T;
|
if (same) return T;
|
||||||
return @Type(.{ .Struct = .{
|
return @Type(.{ .@"struct" = .{
|
||||||
.layout = .auto,
|
.layout = .auto,
|
||||||
.fields = struct_fields[0..],
|
.fields = struct_fields[0..],
|
||||||
.decls = &.{},
|
.decls = &.{},
|
||||||
.is_tuple = struct_infos.is_tuple,
|
.is_tuple = struct_infos.is_tuple,
|
||||||
} });
|
} });
|
||||||
},
|
},
|
||||||
.Array => |arr_info| [arr_info.len]map(arr_info.child),
|
.array => |arr_info| [arr_info.len]map(arr_info.child),
|
||||||
.Pointer => |ptr_info| switch (ptr_info.size) {
|
.pointer => |ptr_info| switch (ptr_info.size) {
|
||||||
.Slice => if (ptr_info.is_const)
|
.slice => if (ptr_info.is_const)
|
||||||
[]const map(ptr_info.child)
|
[]const map(ptr_info.child)
|
||||||
else
|
else
|
||||||
[]map(ptr_info.child),
|
[]map(ptr_info.child),
|
||||||
.One => if (ptr_info.is_const)
|
.one => if (ptr_info.is_const)
|
||||||
*const map(ptr_info.child)
|
*const map(ptr_info.child)
|
||||||
else
|
else
|
||||||
*map(ptr_info.child),
|
*map(ptr_info.child),
|
||||||
else => T,
|
else => T,
|
||||||
},
|
},
|
||||||
.Optional => |opt_info| ?map(opt_info.child),
|
.optional => |opt_info| ?map(opt_info.child),
|
||||||
else => T,
|
else => T,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -95,10 +95,10 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
|
|||||||
const To = stdx.meta.FnResult(cb);
|
const To = stdx.meta.FnResult(cb);
|
||||||
const FromStruct = @TypeOf(from);
|
const FromStruct = @TypeOf(from);
|
||||||
const type_info_to_ptr = @typeInfo(@TypeOf(to));
|
const type_info_to_ptr = @typeInfo(@TypeOf(to));
|
||||||
if (type_info_to_ptr != .Pointer) {
|
if (type_info_to_ptr != .pointer) {
|
||||||
stdx.debug.compileError("convertType is expecting a mutable `to` argument but received: {}", .{@TypeOf(to)});
|
stdx.debug.compileError("convertType is expecting a mutable `to` argument but received: {}", .{@TypeOf(to)});
|
||||||
}
|
}
|
||||||
const ToStruct = type_info_to_ptr.Pointer.child;
|
const ToStruct = type_info_to_ptr.pointer.child;
|
||||||
const type_info_to = @typeInfo(ToStruct);
|
const type_info_to = @typeInfo(ToStruct);
|
||||||
|
|
||||||
if (FromStruct == From) {
|
if (FromStruct == From) {
|
||||||
@ -123,13 +123,13 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
|
|||||||
if (@sizeOf(ToStruct) == 0) return;
|
if (@sizeOf(ToStruct) == 0) return;
|
||||||
|
|
||||||
switch (type_info_to) {
|
switch (type_info_to) {
|
||||||
.Struct => |info| inline for (info.fields) |field| {
|
.@"struct" => |info| inline for (info.fields) |field| {
|
||||||
if (field.is_comptime or @sizeOf(field.type) == 0) continue;
|
if (field.is_comptime or @sizeOf(field.type) == 0) continue;
|
||||||
const field_type_info = @typeInfo(field.type);
|
const field_type_info = @typeInfo(field.type);
|
||||||
// If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field.
|
// If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field.
|
||||||
switch (field_type_info) {
|
switch (field_type_info) {
|
||||||
// .Pointer => try convertType(From, To, allocator, @field(from, field.name), @field(to, field.name), Ctx, ctx, cb),
|
// .pointer => try convertType(From, To, allocator, @field(from, field.name), @field(to, field.name), Ctx, ctx, cb),
|
||||||
.Array, .Optional, .Union, .Struct, .Pointer => if (@hasField(FromStruct, field.name)) {
|
.array, .optional, .@"union", .@"struct", .pointer => if (@hasField(FromStruct, field.name)) {
|
||||||
try mapAlloc(
|
try mapAlloc(
|
||||||
cb,
|
cb,
|
||||||
allocator,
|
allocator,
|
||||||
@ -145,14 +145,14 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
|
|||||||
else => @field(to, field.name) = @field(from, field.name),
|
else => @field(to, field.name) = @field(from, field.name),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
.Array => for (from, to) |f, *t| {
|
.array => for (from, to) |f, *t| {
|
||||||
try mapAlloc(cb, allocator, ctx, f, t);
|
try mapAlloc(cb, allocator, ctx, f, t);
|
||||||
},
|
},
|
||||||
.Pointer => |ptr_info| switch (ptr_info.size) {
|
.pointer => |ptr_info| switch (ptr_info.size) {
|
||||||
.One => switch (type_info_to_ptr.Pointer.size) {
|
.one => switch (type_info_to_ptr.pointer.size) {
|
||||||
// pointer to array -> slice promotion
|
// pointer to array -> slice promotion
|
||||||
.Slice => {
|
.slice => {
|
||||||
const items = try allocator.alloc(type_info_to_ptr.Pointer.child, from.len);
|
const items = try allocator.alloc(type_info_to_ptr.pointer.child, from.len);
|
||||||
for (from, items) |f, *t| {
|
for (from, items) |f, *t| {
|
||||||
try mapAlloc(cb, allocator, ctx, f, t);
|
try mapAlloc(cb, allocator, ctx, f, t);
|
||||||
}
|
}
|
||||||
@ -160,8 +160,8 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
|
|||||||
},
|
},
|
||||||
else => try mapAlloc(cb, allocator, ctx, from.*, to.*),
|
else => try mapAlloc(cb, allocator, ctx, from.*, to.*),
|
||||||
},
|
},
|
||||||
.Slice => {
|
.slice => {
|
||||||
const items = try allocator.alloc(@typeInfo(ToStruct).Pointer.child, from.len);
|
const items = try allocator.alloc(@typeInfo(ToStruct).pointer.child, from.len);
|
||||||
for (from, items) |f, *t| {
|
for (from, items) |f, *t| {
|
||||||
try mapAlloc(cb, allocator, ctx, f, t);
|
try mapAlloc(cb, allocator, ctx, f, t);
|
||||||
}
|
}
|
||||||
@ -169,13 +169,13 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
|
|||||||
},
|
},
|
||||||
else => stdx.debug.compileError("zml.meta.mapAlloc doesn't support: {}", .{FromStruct}),
|
else => stdx.debug.compileError("zml.meta.mapAlloc doesn't support: {}", .{FromStruct}),
|
||||||
},
|
},
|
||||||
.Optional => if (from) |f| {
|
.optional => if (from) |f| {
|
||||||
to.* = @as(@typeInfo(type_info_to_ptr.Pointer.child).Optional.child, undefined);
|
to.* = @as(@typeInfo(type_info_to_ptr.pointer.child).optional.child, undefined);
|
||||||
try mapAlloc(cb, allocator, ctx, f, &(to.*.?));
|
try mapAlloc(cb, allocator, ctx, f, &(to.*.?));
|
||||||
} else {
|
} else {
|
||||||
to.* = null;
|
to.* = null;
|
||||||
},
|
},
|
||||||
.Int, .Float, .Enum, .Union => to.* = from,
|
.int, .float, .@"enum", .@"union" => to.* = from,
|
||||||
else => stdx.debug.compileError("zml.meta.mapAlloc doesn't support: {}", .{FromStruct}),
|
else => stdx.debug.compileError("zml.meta.mapAlloc doesn't support: {}", .{FromStruct}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -241,16 +241,16 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
|||||||
const T = @TypeOf(v);
|
const T = @TypeOf(v);
|
||||||
const type_info_v = @typeInfo(T);
|
const type_info_v = @typeInfo(T);
|
||||||
const K = switch (@typeInfo(FnParam(cb, 1))) {
|
const K = switch (@typeInfo(FnParam(cb, 1))) {
|
||||||
.Pointer => |info| info.child,
|
.pointer => |info| info.child,
|
||||||
else => stdx.debug.compileError("zml.meta.visit is expecting a callback with a pointer as second argument but found {}", .{FnParam(cb, 1)}),
|
else => stdx.debug.compileError("zml.meta.visit is expecting a callback with a pointer as second argument but found {}", .{FnParam(cb, 1)}),
|
||||||
};
|
};
|
||||||
|
|
||||||
if (type_info_v != .Pointer) {
|
if (type_info_v != .pointer) {
|
||||||
const Callback = @TypeOf(cb);
|
const Callback = @TypeOf(cb);
|
||||||
stdx.debug.compileError("zml.meta.visit is expecting a pointer input to go with following callback signature: {} but received: {}", .{ Callback, T });
|
stdx.debug.compileError("zml.meta.visit is expecting a pointer input to go with following callback signature: {} but received: {}", .{ Callback, T });
|
||||||
}
|
}
|
||||||
const ptr_info = type_info_v.Pointer;
|
const ptr_info = type_info_v.pointer;
|
||||||
if (@typeInfo(ptr_info.child) == .Fn) return;
|
if (@typeInfo(ptr_info.child) == .@"fn") return;
|
||||||
if (ptr_info.child == anyopaque) return;
|
if (ptr_info.child == anyopaque) return;
|
||||||
// This is important, because with trivial types like void,
|
// This is important, because with trivial types like void,
|
||||||
// Zig sometimes decide to call `visit` at comptime, but can't do
|
// Zig sometimes decide to call `visit` at comptime, but can't do
|
||||||
@ -262,24 +262,24 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
|||||||
// If we have a single pointer, two cases:
|
// If we have a single pointer, two cases:
|
||||||
// * It's a pointer to K, in which case we call the callback.
|
// * It's a pointer to K, in which case we call the callback.
|
||||||
// * It's a pointer to something else, in which case, we explore and recurse if needed.
|
// * It's a pointer to something else, in which case, we explore and recurse if needed.
|
||||||
.One => if (ptr_info.child == K) {
|
.one => if (ptr_info.child == K) {
|
||||||
cb(ctx, v);
|
cb(ctx, v);
|
||||||
} else if (ptr_info.child == ?K) {
|
} else if (ptr_info.child == ?K) {
|
||||||
if (v.*) |*val| cb(ctx, val);
|
if (v.*) |*val| cb(ctx, val);
|
||||||
} else switch (@typeInfo(ptr_info.child)) {
|
} else switch (@typeInfo(ptr_info.child)) {
|
||||||
.Struct => |s| inline for (s.fields) |field_info| {
|
.@"struct" => |s| inline for (s.fields) |field_info| {
|
||||||
if (field_info.is_comptime) continue;
|
if (field_info.is_comptime) continue;
|
||||||
const field_type_info = @typeInfo(field_info.type);
|
const field_type_info = @typeInfo(field_info.type);
|
||||||
// If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field.
|
// If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field.
|
||||||
switch (field_type_info) {
|
switch (field_type_info) {
|
||||||
.Pointer => visit(cb, ctx, @field(v, field_info.name)),
|
.pointer => visit(cb, ctx, @field(v, field_info.name)),
|
||||||
.Array, .Optional, .Union, .Struct => visit(cb, ctx, &@field(v, field_info.name)),
|
.array, .optional, .@"union", .@"struct" => visit(cb, ctx, &@field(v, field_info.name)),
|
||||||
else => {},
|
else => {},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
.Array => |_| for (v) |*elem| visit(cb, ctx, elem),
|
.array => |_| for (v) |*elem| visit(cb, ctx, elem),
|
||||||
.Optional => if (v.* != null) visit(cb, ctx, &v.*.?),
|
.optional => if (v.* != null) visit(cb, ctx, &v.*.?),
|
||||||
.Union => switch (v.*) {
|
.@"union" => switch (v.*) {
|
||||||
inline else => |*v_field| visit(cb, ctx, v_field),
|
inline else => |*v_field| visit(cb, ctx, v_field),
|
||||||
},
|
},
|
||||||
else => {},
|
else => {},
|
||||||
@ -287,23 +287,23 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
|||||||
// If we have a slice, two cases also:
|
// If we have a slice, two cases also:
|
||||||
// * It's a slice of K, in which case we call the callback for each element of the slice.
|
// * It's a slice of K, in which case we call the callback for each element of the slice.
|
||||||
// * It's a slice to something else, in which case, for each element we explore and recurse if needed.
|
// * It's a slice to something else, in which case, for each element we explore and recurse if needed.
|
||||||
.Slice => {
|
.slice => {
|
||||||
for (v) |*v_elem| {
|
for (v) |*v_elem| {
|
||||||
if (ptr_info.child == K) {
|
if (ptr_info.child == K) {
|
||||||
cb(ctx, v_elem);
|
cb(ctx, v_elem);
|
||||||
} else switch (@typeInfo(ptr_info.child)) {
|
} else switch (@typeInfo(ptr_info.child)) {
|
||||||
.Struct => |s| inline for (s.fields) |field_info| {
|
.@"struct" => |s| inline for (s.fields) |field_info| {
|
||||||
const field_type_info = @typeInfo(field_info.type);
|
const field_type_info = @typeInfo(field_info.type);
|
||||||
// If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field.
|
// If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field.
|
||||||
if (field_type_info == .Pointer) {
|
if (field_type_info == .pointer) {
|
||||||
visit(cb, ctx, @field(v_elem, field_info.name));
|
visit(cb, ctx, @field(v_elem, field_info.name));
|
||||||
} else {
|
} else {
|
||||||
visit(cb, ctx, &@field(v_elem, field_info.name));
|
visit(cb, ctx, &@field(v_elem, field_info.name));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
.Array => |_| for (v) |*elem| visit(cb, ctx, elem),
|
.array => |_| for (v) |*elem| visit(cb, ctx, elem),
|
||||||
.Optional => if (v.* != null) visit(cb, ctx, &v.*.?),
|
.optional => if (v.* != null) visit(cb, ctx, &v.*.?),
|
||||||
.Union => switch (v_elem.*) {
|
.@"union" => switch (v_elem.*) {
|
||||||
inline else => |*v_field| visit(cb, ctx, v_field),
|
inline else => |*v_field| visit(cb, ctx, v_field),
|
||||||
},
|
},
|
||||||
else => {},
|
else => {},
|
||||||
@ -419,7 +419,7 @@ pub fn first(T: type, value: anytype) T {
|
|||||||
/// Which means that zip only allocate temp memory, and nothing need to be freed after the call.
|
/// Which means that zip only allocate temp memory, and nothing need to be freed after the call.
|
||||||
pub fn zip(comptime func: anytype, allocator: std.mem.Allocator, values: anytype, args: anytype) error{OutOfMemory}!asSlice(@TypeOf(values)) {
|
pub fn zip(comptime func: anytype, allocator: std.mem.Allocator, values: anytype, args: anytype) error{OutOfMemory}!asSlice(@TypeOf(values)) {
|
||||||
const sliceT = @typeInfo(FnParam(func, 0));
|
const sliceT = @typeInfo(FnParam(func, 0));
|
||||||
const T = sliceT.Pointer.child;
|
const T = sliceT.pointer.child;
|
||||||
const V = asSlice(@TypeOf(values));
|
const V = asSlice(@TypeOf(values));
|
||||||
if (V == T) {
|
if (V == T) {
|
||||||
return @call(.auto, func, .{values} ++ args);
|
return @call(.auto, func, .{values} ++ args);
|
||||||
@ -427,12 +427,12 @@ pub fn zip(comptime func: anytype, allocator: std.mem.Allocator, values: anytype
|
|||||||
// const fn_args
|
// const fn_args
|
||||||
|
|
||||||
return switch (@typeInfo(V)) {
|
return switch (@typeInfo(V)) {
|
||||||
.Pointer => stdx.debug.compileError("zip only accept by value arguments. Received: {}", .{V}),
|
.pointer => stdx.debug.compileError("zip only accept by value arguments. Received: {}", .{V}),
|
||||||
.Struct => |struct_info| {
|
.@"struct" => |struct_info| {
|
||||||
var out: V = values[0];
|
var out: V = values[0];
|
||||||
inline for (struct_info.fields) |f| {
|
inline for (struct_info.fields) |f| {
|
||||||
if (f.is_comptime) continue;
|
if (f.is_comptime) continue;
|
||||||
if (@typeInfo(f.type) == .Pointer) {
|
if (@typeInfo(f.type) == .pointer) {
|
||||||
stdx.debug.compileError("zip doesn't follow pointers and don't accept struct containing them. Received: {}", .{V});
|
stdx.debug.compileError("zip doesn't follow pointers and don't accept struct containing them. Received: {}", .{V});
|
||||||
}
|
}
|
||||||
var fields = try allocator.alloc(f.type, values.len);
|
var fields = try allocator.alloc(f.type, values.len);
|
||||||
@ -444,8 +444,8 @@ pub fn zip(comptime func: anytype, allocator: std.mem.Allocator, values: anytype
|
|||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
},
|
},
|
||||||
.Array => |arr_info| {
|
.array => |arr_info| {
|
||||||
if (@typeInfo(arr_info.child) == .Pointer) {
|
if (@typeInfo(arr_info.child) == .pointer) {
|
||||||
stdx.debug.compileError("zip doesn't follow pointers and don't accept struct containing them. Received: {}", .{V});
|
stdx.debug.compileError("zip doesn't follow pointers and don't accept struct containing them. Received: {}", .{V});
|
||||||
}
|
}
|
||||||
var out: V = undefined;
|
var out: V = undefined;
|
||||||
@ -459,7 +459,7 @@ pub fn zip(comptime func: anytype, allocator: std.mem.Allocator, values: anytype
|
|||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
},
|
},
|
||||||
.Union, .Optional => stdx.debug.compileError("zip doesn't yet support {}", .{V}),
|
.@"union", .optional => stdx.debug.compileError("zip doesn't yet support {}", .{V}),
|
||||||
else => values[0],
|
else => values[0],
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -483,7 +483,7 @@ test zip {
|
|||||||
/// Given a func(X) -> Y or a func(Ctx, X) -> Y,
|
/// Given a func(X) -> Y or a func(Ctx, X) -> Y,
|
||||||
/// finds all X in the given object, and write the result of func(X) into an arraylist.
|
/// finds all X in the given object, and write the result of func(X) into an arraylist.
|
||||||
pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT), obj: anytype) error{OutOfMemory}!void {
|
pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT), obj: anytype) error{OutOfMemory}!void {
|
||||||
stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).Fn.params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)});
|
stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).@"fn".params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)});
|
||||||
const LocalContext = struct {
|
const LocalContext = struct {
|
||||||
func_ctx: _CollectCtx(func),
|
func_ctx: _CollectCtx(func),
|
||||||
out: *std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT),
|
out: *std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT),
|
||||||
@ -505,7 +505,7 @@ pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(s
|
|||||||
/// Given a func(X) -> Y or a func(Ctx, X) -> Y,
|
/// Given a func(X) -> Y or a func(Ctx, X) -> Y,
|
||||||
/// finds all X in the given object, and write the result of func(X) into an arraylist.
|
/// finds all X in the given object, and write the result of func(X) into an arraylist.
|
||||||
pub fn collectBuf(func: anytype, func_ctx: _CollectCtx(func), obj: anytype, out: []stdx.meta.FnResult(func)) void {
|
pub fn collectBuf(func: anytype, func_ctx: _CollectCtx(func), obj: anytype, out: []stdx.meta.FnResult(func)) void {
|
||||||
stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).Fn.params.len <= 2, "zml.meta.collectBuf expects a func with one or two arguments, got: {}", .{@TypeOf(func)});
|
stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).@"fn".params.len <= 2, "zml.meta.collectBuf expects a func with one or two arguments, got: {}", .{@TypeOf(func)});
|
||||||
const LocalContext = struct {
|
const LocalContext = struct {
|
||||||
func_ctx: _CollectCtx(func),
|
func_ctx: _CollectCtx(func),
|
||||||
out: @TypeOf(out),
|
out: @TypeOf(out),
|
||||||
@ -525,12 +525,12 @@ pub fn collectBuf(func: anytype, func_ctx: _CollectCtx(func), obj: anytype, out:
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn _CollectCtx(func: anytype) type {
|
fn _CollectCtx(func: anytype) type {
|
||||||
const params = @typeInfo(@TypeOf(func)).Fn.params;
|
const params = @typeInfo(@TypeOf(func)).@"fn".params;
|
||||||
if (params.len == 1) return void;
|
if (params.len == 1) return void;
|
||||||
return params[0].type orelse @compileError("anytype not supported in collect");
|
return params[0].type orelse @compileError("anytype not supported in collect");
|
||||||
}
|
}
|
||||||
|
|
||||||
fn _CollectArg(func: anytype) type {
|
fn _CollectArg(func: anytype) type {
|
||||||
const params = @typeInfo(@TypeOf(func)).Fn.params;
|
const params = @typeInfo(@TypeOf(func)).@"fn".params;
|
||||||
return params[params.len - 1].type orelse @compileError("anytype not supported in collect");
|
return params[params.len - 1].type orelse @compileError("anytype not supported in collect");
|
||||||
}
|
}
|
||||||
|
|||||||
@ -204,7 +204,7 @@ pub const CompilationContext = struct {
|
|||||||
var timer = std.time.Timer.start() catch null;
|
var timer = std.time.Timer.start() catch null;
|
||||||
const tensor_args = try self.tensorFromShapes(stdx.meta.FnArgs(func), arena, args);
|
const tensor_args = try self.tensorFromShapes(stdx.meta.FnArgs(func), arena, args);
|
||||||
// Run in a dedicated thread because compilation relies on `threadlocal`.
|
// Run in a dedicated thread because compilation relies on `threadlocal`.
|
||||||
const f = try asynk.callBlocking(CompilationContext.emitMlir, .{ self, func, &tensor_args, .{ .name = "main", .kind = .main } });
|
const f = try asynk.callBlocking(CompilationContext.emitMlir, .{ self, func, &tensor_args, CompilationContext.EmitMlirOpts{ .name = "main", .kind = .main } });
|
||||||
const module = self._module;
|
const module = self._module;
|
||||||
module.getBody().appendOperation(f.mlir_fn);
|
module.getBody().appendOperation(f.mlir_fn);
|
||||||
|
|
||||||
@ -296,7 +296,7 @@ pub const CompilationContext = struct {
|
|||||||
|
|
||||||
pub fn closeBlock(self: *CompilationContext, block: Block) void {
|
pub fn closeBlock(self: *CompilationContext, block: Block) void {
|
||||||
const popped = self._blocks.pop();
|
const popped = self._blocks.pop();
|
||||||
std.debug.assert(block.block().eql(popped.block()));
|
std.debug.assert(block.block().eql(popped.?.block()));
|
||||||
}
|
}
|
||||||
|
|
||||||
fn pushBlock(self: *CompilationContext, block: Block) void {
|
fn pushBlock(self: *CompilationContext, block: Block) void {
|
||||||
@ -348,6 +348,11 @@ pub const CompilationContext = struct {
|
|||||||
return .{ block.block(), block_res };
|
return .{ block.block(), block_res };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub const EmitMlirOpts = struct {
|
||||||
|
name: []const u8,
|
||||||
|
kind: MlirFn.Kind = .private,
|
||||||
|
};
|
||||||
|
|
||||||
/// Generate an MLIR function from a ZML function.
|
/// Generate an MLIR function from a ZML function.
|
||||||
/// The caller is responsible to have properly created the input
|
/// The caller is responsible to have properly created the input
|
||||||
/// tensors with unique tensor ids.
|
/// tensors with unique tensor ids.
|
||||||
@ -355,10 +360,7 @@ pub const CompilationContext = struct {
|
|||||||
self: *CompilationContext,
|
self: *CompilationContext,
|
||||||
comptime func: anytype,
|
comptime func: anytype,
|
||||||
args: *const stdx.meta.FnArgs(func),
|
args: *const stdx.meta.FnArgs(func),
|
||||||
opts: struct {
|
opts: EmitMlirOpts,
|
||||||
name: []const u8,
|
|
||||||
kind: MlirFn.Kind = .private,
|
|
||||||
},
|
|
||||||
) error{OutOfMemory}!MlirFn {
|
) error{OutOfMemory}!MlirFn {
|
||||||
const frame = self._tracer.frameStart("emitMlir.emit");
|
const frame = self._tracer.frameStart("emitMlir.emit");
|
||||||
errdefer self._tracer.frameEnd(frame, "emitMlir.emit");
|
errdefer self._tracer.frameEnd(frame, "emitMlir.emit");
|
||||||
@ -944,9 +946,9 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
|||||||
|
|
||||||
fn setFlag(options: *xla_pb.CompileOptionsProto, comptime flag: [:0]const u8, value: anytype) void {
|
fn setFlag(options: *xla_pb.CompileOptionsProto, comptime flag: [:0]const u8, value: anytype) void {
|
||||||
const option: xla_pb.OptionOverrideProto = switch (@typeInfo(@TypeOf(value))) {
|
const option: xla_pb.OptionOverrideProto = switch (@typeInfo(@TypeOf(value))) {
|
||||||
.Bool => .{ .value = .{ .bool_field = value } },
|
.bool => .{ .value = .{ .bool_field = value } },
|
||||||
.ComptimeInt, .Int => .{ .value = .{ .int_field = value } },
|
.comptime_int, .int => .{ .value = .{ .int_field = value } },
|
||||||
.ComptimeFloat, .Float => .{ .value = .{ .double_field = value } },
|
.comptime_float, .float => .{ .value = .{ .double_field = value } },
|
||||||
else => .{ .value = .{ .string_field = .{ .Const = value } } },
|
else => .{ .value = .{ .string_field = .{ .Const = value } } },
|
||||||
};
|
};
|
||||||
options.env_option_overrides.appendAssumeCapacity(.{ .key = .{ .Const = flag }, .value = option });
|
options.env_option_overrides.appendAssumeCapacity(.{ .key = .{ .Const = flag }, .value = option });
|
||||||
@ -1179,12 +1181,12 @@ pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: HashStrategy
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch (@typeInfo(Key)) {
|
switch (@typeInfo(Key)) {
|
||||||
.NoReturn, .Opaque, .Undefined, .Null, .ComptimeFloat, .ComptimeInt, .Type, .EnumLiteral, .Frame, .Void => return,
|
.noreturn, .@"opaque", .undefined, .null, .comptime_float, .comptime_int, .type, .enum_literal, .frame, .void => return,
|
||||||
|
|
||||||
// Help the optimizer see that hashing an int is easy by inlining!
|
// Help the optimizer see that hashing an int is easy by inlining!
|
||||||
// TODO Check if the situation is better after #561 is resolved.
|
// TODO Check if the situation is better after #561 is resolved.
|
||||||
.Int => |int| switch (int.signedness) {
|
.int => |int| switch (int.signedness) {
|
||||||
.signed => hash(hasher, @as(@Type(.{ .Int = .{
|
.signed => hash(hasher, @as(@Type(.{ .int = .{
|
||||||
.bits = int.bits,
|
.bits = int.bits,
|
||||||
.signedness = .unsigned,
|
.signedness = .unsigned,
|
||||||
} }), @bitCast(key)), strat),
|
} }), @bitCast(key)), strat),
|
||||||
@ -1202,21 +1204,21 @@ pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: HashStrategy
|
|||||||
// Note: contrary to Zig we accept hashing floats.
|
// Note: contrary to Zig we accept hashing floats.
|
||||||
// Typically the float we are going to hash here are hyperparameters,
|
// Typically the float we are going to hash here are hyperparameters,
|
||||||
// and not the result of an operation, so bytes should be the same everytime.
|
// and not the result of an operation, so bytes should be the same everytime.
|
||||||
.Float => hasher.update(std.mem.asBytes(&key)),
|
.float => hasher.update(std.mem.asBytes(&key)),
|
||||||
.Bool => hash(hasher, @intFromBool(key), strat),
|
.bool => hash(hasher, @intFromBool(key), strat),
|
||||||
.Enum => hash(hasher, @intFromEnum(key), strat),
|
.@"enum" => hash(hasher, @intFromEnum(key), strat),
|
||||||
.ErrorSet => hash(hasher, @intFromError(key), strat),
|
.error_set => hash(hasher, @intFromError(key), strat),
|
||||||
.AnyFrame, .Fn => hash(hasher, @intFromPtr(key), strat),
|
.@"anyframe", .@"fn" => hash(hasher, @intFromPtr(key), strat),
|
||||||
.Pointer => |info| switch (info.size) {
|
.pointer => |info| switch (info.size) {
|
||||||
.One => switch (strat) {
|
.one => switch (strat) {
|
||||||
.Shallow => hash(hasher, @intFromPtr(key), .Shallow),
|
.shallow => hash(hasher, @intFromPtr(key), .Shallow),
|
||||||
.Deep => hash(hasher, key.*, .Shallow),
|
.deep => hash(hasher, key.*, .Shallow),
|
||||||
.DeepRecursive => switch (@typeInfo(info.child)) {
|
.deeprecursive => switch (@typeInfo(info.child)) {
|
||||||
.Opaque, .Fn => hash(hasher, @intFromPtr(key), .Shallow),
|
.@"opaque", .@"fn" => hash(hasher, @intFromPtr(key), .Shallow),
|
||||||
else => hash(hasher, key.*, .DeepRecursive),
|
else => hash(hasher, key.*, .DeepRecursive),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
.Slice => {
|
.slice => {
|
||||||
switch (strat) {
|
switch (strat) {
|
||||||
.Shallow => hash(hasher, @intFromPtr(key.ptr), .Shallow),
|
.Shallow => hash(hasher, @intFromPtr(key.ptr), .Shallow),
|
||||||
.Deep => hashArray(hasher, key, .Shallow),
|
.Deep => hashArray(hasher, key, .Shallow),
|
||||||
@ -1224,21 +1226,21 @@ pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: HashStrategy
|
|||||||
}
|
}
|
||||||
hash(hasher, key.len, .Shallow);
|
hash(hasher, key.len, .Shallow);
|
||||||
},
|
},
|
||||||
.Many,
|
.many,
|
||||||
.C,
|
.c,
|
||||||
=> switch (strat) {
|
=> switch (strat) {
|
||||||
.Shallow => hash(hasher, @intFromPtr(key), .Shallow),
|
.shallow => hash(hasher, @intFromPtr(key), .Shallow),
|
||||||
else => @compileError(
|
else => @compileError(
|
||||||
\\ unknown-length pointers and C pointers cannot be hashed deeply.
|
\\ unknown-length pointers and C pointers cannot be hashed deeply.
|
||||||
\\ Consider providing your own hash function.
|
\\ Consider providing your own hash function.
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
.Optional => if (key) |k| hash(hasher, k, strat),
|
.optional => if (key) |k| hash(hasher, k, strat),
|
||||||
|
|
||||||
.Array => hashArray(hasher, key, strat),
|
.array => hashArray(hasher, key, strat),
|
||||||
|
|
||||||
.Vector => |info| {
|
.vector => |info| {
|
||||||
if (std.meta.hasUniqueRepresentation(Key)) {
|
if (std.meta.hasUniqueRepresentation(Key)) {
|
||||||
hasher.update(std.mem.asBytes(&key));
|
hasher.update(std.mem.asBytes(&key));
|
||||||
} else {
|
} else {
|
||||||
@ -1249,7 +1251,7 @@ pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: HashStrategy
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
.Struct => |info| {
|
.@"struct" => |info| {
|
||||||
inline for (info.fields) |field| {
|
inline for (info.fields) |field| {
|
||||||
// We reuse the hash of the previous field as the seed for the
|
// We reuse the hash of the previous field as the seed for the
|
||||||
// next one so that they're dependant.
|
// next one so that they're dependant.
|
||||||
@ -1257,7 +1259,7 @@ pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: HashStrategy
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
.Union => |info| {
|
.@"union" => |info| {
|
||||||
if (info.tag_type) |tag_type| {
|
if (info.tag_type) |tag_type| {
|
||||||
const tag = std.meta.activeTag(key);
|
const tag = std.meta.activeTag(key);
|
||||||
hash(hasher, tag, strat);
|
hash(hasher, tag, strat);
|
||||||
@ -1275,7 +1277,7 @@ pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: HashStrategy
|
|||||||
} else @compileError("cannot hash untagged union type: " ++ @typeName(Key) ++ ", provide your own hash function");
|
} else @compileError("cannot hash untagged union type: " ++ @typeName(Key) ++ ", provide your own hash function");
|
||||||
},
|
},
|
||||||
|
|
||||||
.ErrorUnion => blk: {
|
.error_union => blk: {
|
||||||
const payload = key catch |err| {
|
const payload = key catch |err| {
|
||||||
hash(hasher, err, strat);
|
hash(hasher, err, strat);
|
||||||
break :blk;
|
break :blk;
|
||||||
|
|||||||
25
zml/nn.zig
25
zml/nn.zig
@ -1174,6 +1174,13 @@ pub const DynamicSamplingStrategy = struct {
|
|||||||
top_p: Tensor,
|
top_p: Tensor,
|
||||||
min_p: Tensor,
|
min_p: Tensor,
|
||||||
|
|
||||||
|
pub const Opts = struct {
|
||||||
|
top_k: u32,
|
||||||
|
temperature: f32 = 1.0,
|
||||||
|
top_p: f32 = 1.0,
|
||||||
|
min_p: f32 = 0.0,
|
||||||
|
};
|
||||||
|
|
||||||
pub fn shapes(dtype: DataType, max_top_k: u32) zml.ShapeOf(DynamicSamplingStrategy) {
|
pub fn shapes(dtype: DataType, max_top_k: u32) zml.ShapeOf(DynamicSamplingStrategy) {
|
||||||
const scalar_float = Shape.init(.{}, dtype);
|
const scalar_float = Shape.init(.{}, dtype);
|
||||||
const scalar_i32 = Shape.init(.{}, .i32);
|
const scalar_i32 = Shape.init(.{}, .i32);
|
||||||
@ -1189,19 +1196,14 @@ pub const DynamicSamplingStrategy = struct {
|
|||||||
pub fn makeBuffers(
|
pub fn makeBuffers(
|
||||||
platform: zml.Platform,
|
platform: zml.Platform,
|
||||||
dtype: zml.DataType,
|
dtype: zml.DataType,
|
||||||
args: struct {
|
opts: Opts,
|
||||||
top_k: u32,
|
|
||||||
temperature: f32 = 1.0,
|
|
||||||
top_p: f32 = 1.0,
|
|
||||||
min_p: f32 = 0.0,
|
|
||||||
},
|
|
||||||
) !zml.Bufferized(DynamicSamplingStrategy) {
|
) !zml.Bufferized(DynamicSamplingStrategy) {
|
||||||
return .{
|
return .{
|
||||||
.max_top_k = 0,
|
.max_top_k = 0,
|
||||||
.top_k = try zml.Buffer.scalar(platform, args.top_k, .i32),
|
.top_k = try zml.Buffer.scalar(platform, opts.top_k, .i32),
|
||||||
.temperature = try zml.Buffer.scalar(platform, args.temperature, dtype),
|
.temperature = try zml.Buffer.scalar(platform, opts.temperature, dtype),
|
||||||
.top_p = try zml.Buffer.scalar(platform, args.top_p, dtype),
|
.top_p = try zml.Buffer.scalar(platform, opts.top_p, dtype),
|
||||||
.min_p = try zml.Buffer.scalar(platform, args.min_p, dtype),
|
.min_p = try zml.Buffer.scalar(platform, opts.min_p, dtype),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -1274,7 +1276,8 @@ test sampleTokensDynamic {
|
|||||||
const mod = try zml.compileFn(allocator, fixupLogits, .{ Shape.init(.{ .voc = logits.len }, .f32), DynamicSamplingStrategy.shapes(.f32, 0) }, platform);
|
const mod = try zml.compileFn(allocator, fixupLogits, .{ Shape.init(.{ .voc = logits.len }, .f32), DynamicSamplingStrategy.shapes(.f32, 0) }, platform);
|
||||||
defer mod.deinit();
|
defer mod.deinit();
|
||||||
|
|
||||||
inline for (.{
|
const Args = struct { DynamicSamplingStrategy.Opts, [4]f32 };
|
||||||
|
inline for ([_]Args{
|
||||||
// top_k == logits.len -> just sort the input
|
// top_k == logits.len -> just sort the input
|
||||||
.{ .{ .top_k = 4 }, [_]f32{ @log(4.0), @log(3.0), @log(2.0), @log(1.0) } },
|
.{ .{ .top_k = 4 }, [_]f32{ @log(4.0), @log(3.0), @log(2.0), @log(1.0) } },
|
||||||
.{ .{ .top_k = 2 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
|
.{ .{ .top_k = 2 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
|
||||||
|
|||||||
@ -51,8 +51,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
|||||||
var fba = std.heap.FixedBufferAllocator.init(&buffer);
|
var fba = std.heap.FixedBufferAllocator.init(&buffer);
|
||||||
const allocator = fba.allocator();
|
const allocator = fba.allocator();
|
||||||
|
|
||||||
const backend_config =
|
const backend_config = std.fmt.allocPrintZ(
|
||||||
std.fmt.allocPrintZ(
|
|
||||||
allocator,
|
allocator,
|
||||||
\\{{
|
\\{{
|
||||||
\\ "operation_queue_id":"0",
|
\\ "operation_queue_id":"0",
|
||||||
|
|||||||
18
zml/ops.zig
18
zml/ops.zig
@ -107,7 +107,7 @@ test "simple while" {
|
|||||||
|
|
||||||
const init_i = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{0});
|
const init_i = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{0});
|
||||||
const init_sum = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{0});
|
const init_sum = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{0});
|
||||||
const counter = .{
|
const counter: zml.Bufferized(CountInts) = .{
|
||||||
.step = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{1}),
|
.step = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{1}),
|
||||||
.end = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{10}),
|
.end = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{10}),
|
||||||
};
|
};
|
||||||
@ -301,7 +301,7 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_:
|
|||||||
// Reuse the first step Tensor.
|
// Reuse the first step Tensor.
|
||||||
// TODO: this is needed because of https://github.com/zml/zml/issues/97
|
// TODO: this is needed because of https://github.com/zml/zml/issues/97
|
||||||
// Normally I'd rather NOT reuse first_step to streamline the stablehlo IR.
|
// Normally I'd rather NOT reuse first_step to streamline the stablehlo IR.
|
||||||
return first_step.reshape(shape).pad(0, .{ ._0 = .{ .high = self.num_steps - 1 } });
|
return first_step.reshape(shape).pad(0, .{ ._0 = Tensor.Pad{ .high = self.num_steps - 1 } });
|
||||||
}
|
}
|
||||||
|
|
||||||
fn wrapFirstStep(tag_: @TypeOf(step_tag), x: Tensor) Tensor {
|
fn wrapFirstStep(tag_: @TypeOf(step_tag), x: Tensor) Tensor {
|
||||||
@ -404,12 +404,12 @@ test "nested for" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn scanRow(x: Tensor, i: Tensor) Tensor {
|
pub fn scanRow(x: Tensor, i: Tensor) Tensor {
|
||||||
const row = x.dynamicSlice(.{.{ .start = i, .len = 1 }});
|
const row = x.dynamicSlice(.{Tensor.DynSlice{ .start = i, .len = 1 }});
|
||||||
return for_(OuterProd.scanCol, .{ .x = x, .x_row = row }, .{x.dim(0)});
|
return for_(OuterProd.scanCol, .{ .x = x, .x_row = row }, .{x.dim(0)});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn scanCol(self: OuterProd, j: Tensor) Tensor {
|
pub fn scanCol(self: OuterProd, j: Tensor) Tensor {
|
||||||
const col = self.x.dynamicSlice(.{.{ .start = j, .len = 1 }});
|
const col = self.x.dynamicSlice(.{Tensor.DynSlice{ .start = j, .len = 1 }});
|
||||||
return self.x_row.mul(col);
|
return self.x_row.mul(col);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -663,10 +663,10 @@ pub fn fnInfo(comptime func: anytype) std.builtin.Type.Fn {
|
|||||||
}
|
}
|
||||||
const type_info = @typeInfo(@TypeOf(func));
|
const type_info = @typeInfo(@TypeOf(func));
|
||||||
const err_msg = "`func` must be a function and return one or more `Tensor`. Got: ";
|
const err_msg = "`func` must be a function and return one or more `Tensor`. Got: ";
|
||||||
if (type_info != .Fn or type_info.Fn.return_type == null) {
|
if (type_info != .@"fn" or type_info.@"fn".return_type == null) {
|
||||||
@compileError(err_msg ++ @typeName(@TypeOf(func)));
|
@compileError(err_msg ++ @typeName(@TypeOf(func)));
|
||||||
}
|
}
|
||||||
return type_info.Fn;
|
return type_info.@"fn";
|
||||||
}
|
}
|
||||||
|
|
||||||
fn _BlockSign(comptime func: anytype, blk_type: BlockType) BlockSignature {
|
fn _BlockSign(comptime func: anytype, blk_type: BlockType) BlockSignature {
|
||||||
@ -731,13 +731,13 @@ pub fn staticCountTensors(comptime T: type) ?usize {
|
|||||||
if (T == Tensor) return 1;
|
if (T == Tensor) return 1;
|
||||||
|
|
||||||
return switch (@typeInfo(T)) {
|
return switch (@typeInfo(T)) {
|
||||||
.Array => |array_info| array_info.len * (staticCountTensors(array_info.child) orelse return null),
|
.array => |array_info| array_info.len * (staticCountTensors(array_info.child) orelse return null),
|
||||||
.Pointer => |ptr_info| {
|
.pointer => |ptr_info| {
|
||||||
const n = staticCountTensors(ptr_info.child) orelse return null;
|
const n = staticCountTensors(ptr_info.child) orelse return null;
|
||||||
if (ptr_info.size != .One and n > 0) return null;
|
if (ptr_info.size != .One and n > 0) return null;
|
||||||
return n;
|
return n;
|
||||||
},
|
},
|
||||||
.Struct => |struct_info| {
|
.@"struct" => |struct_info| {
|
||||||
var count: usize = 0;
|
var count: usize = 0;
|
||||||
inline for (struct_info.fields) |field| {
|
inline for (struct_info.fields) |field| {
|
||||||
count += staticCountTensors(field.type) orelse return null;
|
count += staticCountTensors(field.type) orelse return null;
|
||||||
|
|||||||
@ -33,7 +33,7 @@ pub const Memory = pjrt.Memory;
|
|||||||
|
|
||||||
fn InnerMixin(comptime innerT: type) type {
|
fn InnerMixin(comptime innerT: type) type {
|
||||||
return struct {
|
return struct {
|
||||||
inline fn inner(self: anytype) if (@typeInfo(@TypeOf(self)).Pointer.is_const) *const innerT else *innerT {
|
inline fn inner(self: anytype) if (@typeInfo(@TypeOf(self)).pointer.is_const) *const innerT else *innerT {
|
||||||
return @ptrCast(self);
|
return @ptrCast(self);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const c = @import("c");
|
const c = @import("c");
|
||||||
|
|
||||||
pub fn madvise(ptr: [*]align(std.mem.page_size) u8, length: usize, advice: u32) std.posix.MadviseError!void {
|
pub fn madvise(ptr: [*]align(std.heap.page_size_min) u8, length: usize, advice: u32) std.posix.MadviseError!void {
|
||||||
switch (std.posix.errno(c.madvise(ptr, @intCast(length), @intCast(advice)))) {
|
switch (std.posix.errno(c.madvise(ptr, @intCast(length), @intCast(advice)))) {
|
||||||
.SUCCESS => return,
|
.SUCCESS => return,
|
||||||
.ACCES => return error.AccessDenied,
|
.ACCES => return error.AccessDenied,
|
||||||
|
|||||||
@ -149,7 +149,7 @@ pub const Shape = struct {
|
|||||||
|
|
||||||
pub fn rank(self: Shape) u4 {
|
pub fn rank(self: Shape) u4 {
|
||||||
self.ensureDimsAndTagsAreSync();
|
self.ensureDimsAndTagsAreSync();
|
||||||
return self._dims.len;
|
return @intCast(self._dims.len);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dim(self: Shape, ax: anytype) i64 {
|
pub fn dim(self: Shape, ax: anytype) i64 {
|
||||||
|
|||||||
@ -3978,7 +3978,7 @@ pub fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK)
|
|||||||
|
|
||||||
fn _parseGatherCoord(self: Tensor, axes_: anytype) struct { bool, std.BoundedArray(u3, Tensor.MAX_RANK) } {
|
fn _parseGatherCoord(self: Tensor, axes_: anytype) struct { bool, std.BoundedArray(u3, Tensor.MAX_RANK) } {
|
||||||
const AxesT = @TypeOf(axes_);
|
const AxesT = @TypeOf(axes_);
|
||||||
const axes_is_scalar = AxesT == EnumLiteral or AxesT == comptime_int or @typeInfo(AxesT) == .Int;
|
const axes_is_scalar = AxesT == EnumLiteral or AxesT == comptime_int or @typeInfo(AxesT) == .int;
|
||||||
|
|
||||||
const coord_axes = if (axes_is_scalar)
|
const coord_axes = if (axes_is_scalar)
|
||||||
std.BoundedArray(u3, Tensor.MAX_RANK).fromSlice(&.{self.axis(axes_)}) catch unreachable
|
std.BoundedArray(u3, Tensor.MAX_RANK).fromSlice(&.{self.axis(axes_)}) catch unreachable
|
||||||
|
|||||||
@ -11,7 +11,7 @@ const assert = std.debug.assert;
|
|||||||
// ref: https://github.com/ziglang/zig/issues/5738
|
// ref: https://github.com/ziglang/zig/issues/5738
|
||||||
const log_level: std.log.Level = .warn;
|
const log_level: std.log.Level = .warn;
|
||||||
|
|
||||||
pub const std_options = .{
|
pub const std_options: std.Options = .{
|
||||||
.log_level = log_level,
|
.log_level = log_level,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -56,7 +56,7 @@ pub const Tokenizer = struct {
|
|||||||
|
|
||||||
try token_lookup.ensureTotalCapacity(arena, @intCast(vocab_size));
|
try token_lookup.ensureTotalCapacity(arena, @intCast(vocab_size));
|
||||||
|
|
||||||
const tokens: [][]const u8 = if (alloc_tokens) try arena.alloc([]u8, vocab_size) else &.{};
|
const tokens: [][]const u8 = if (alloc_tokens) try arena.alloc([]const u8, vocab_size) else &.{};
|
||||||
errdefer if (alloc_tokens) arena.free(tokens);
|
errdefer if (alloc_tokens) arena.free(tokens);
|
||||||
|
|
||||||
const scores: []f32 = if (alloc_tokens) try arena.alloc(f32, vocab_size) else &.{};
|
const scores: []f32 = if (alloc_tokens) try arena.alloc(f32, vocab_size) else &.{};
|
||||||
@ -91,9 +91,8 @@ pub const Tokenizer = struct {
|
|||||||
const arena = self.arena_state.allocator();
|
const arena = self.arena_state.allocator();
|
||||||
|
|
||||||
const token = try arena.alloc(u8, len);
|
const token = try arena.alloc(u8, len);
|
||||||
const n = try tok_reader.read(token);
|
const n = try tok_reader.readAll(token);
|
||||||
std.debug.assert(n == len);
|
std.debug.assert(n == len);
|
||||||
|
|
||||||
return self.addOwnedToken(score, token);
|
return self.addOwnedToken(score, token);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -95,7 +95,7 @@ pub fn unsqueeze(
|
|||||||
) Tensor {
|
) Tensor {
|
||||||
stdx.debug.assert(self.rank() < Tensor.MAX_RANK - 1, "Can't unsqueeze {}, it's already at max rank.", .{self});
|
stdx.debug.assert(self.rank() < Tensor.MAX_RANK - 1, "Can't unsqueeze {}, it's already at max rank.", .{self});
|
||||||
const a = switch (@typeInfo(@TypeOf(axis_))) {
|
const a = switch (@typeInfo(@TypeOf(axis_))) {
|
||||||
.Int, .ComptimeInt => if (axis_ < 0)
|
.int, .comptime_int => if (axis_ < 0)
|
||||||
@as(i8, self.rank()) + 1 + axis_
|
@as(i8, self.rank()) + 1 + axis_
|
||||||
else
|
else
|
||||||
self.axis(axis_),
|
self.axis(axis_),
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user