runtimes/rocm: implement zmlxrocm in Zig

Also, sandbox `amdgpu.ids` and restore safetensors json parsing.
This commit is contained in:
Tarry Singh 2025-07-07 16:48:07 +00:00
parent a512b9c8a2
commit c488b634fc
17 changed files with 280 additions and 1597 deletions

View File

@ -54,7 +54,7 @@ def cc_import(
patched_name = "{}.patchelf".format(name)
patchelf(
name = patched_name,
shared_library = shared_library,
src = shared_library,
soname = soname,
add_needed = add_needed,
remove_needed = remove_needed,

View File

@ -1,5 +1,5 @@
def _patchelf_impl(ctx):
output_name = ctx.file.shared_library.basename
output_name = ctx.file.src.basename
if ctx.attr.soname:
output_name = ctx.attr.soname
output = ctx.actions.declare_file("{}/{}".format(ctx.attr.name, output_name))
@ -43,9 +43,9 @@ def _patchelf_impl(ctx):
ctx.actions.write(renamed_syms, "")
ctx.actions.run_shell(
inputs = [ctx.file.shared_library, renamed_syms],
inputs = [ctx.file.src, renamed_syms],
outputs = [output],
arguments = [ctx.executable._patchelf.path, ctx.file.shared_library.path, output.path],
arguments = [ctx.executable._patchelf.path, ctx.file.src.path, output.path],
command = "\n".join(commands),
tools = [ctx.executable._patchelf],
)
@ -59,7 +59,7 @@ def _patchelf_impl(ctx):
patchelf = rule(
implementation = _patchelf_impl,
attrs = {
"shared_library": attr.label(allow_single_file = True, mandatory = True),
"src": attr.label(allow_single_file = True, mandatory = True),
"soname": attr.string(),
"add_needed": attr.string_list(),
"remove_needed": attr.string_list(),

View File

@ -1,63 +0,0 @@
load("@rules_cc//cc:cc_binary.bzl", "cc_binary")
load("@rules_cc//cc:cc_test.bzl", "cc_test")
load("@rules_zig//zig:defs.bzl", "BINARY_KIND", "zig_binary")
def zig_cc_binary(
name,
copts = [],
args = None,
env = None,
data = [],
deps = [],
tags = [],
visibility = None,
**kwargs):
zig_binary(
name = "{}_lib".format(name),
kind = BINARY_KIND.static_lib,
copts = copts + ["-lc"],
deps = deps,
visibility = visibility,
**kwargs
)
cc_binary(
name = name,
args = args,
env = env,
data = data,
deps = [":{}_lib".format(name)],
tags = tags,
visibility = visibility,
)
def zig_cc_test(
name,
copts = [],
env = None,
data = [],
deps = [],
test_runner = None,
tags = [],
visibility = None,
**kwargs):
zig_binary(
name = "{}_test_lib".format(name),
kind = BINARY_KIND.test_lib,
test_runner = test_runner,
tags = tags,
copts = copts + ["-lc"],
deps = deps + [
"@rules_zig//zig/lib:libc",
],
visibility = visibility,
**kwargs
)
cc_test(
name = name,
env = env,
data = data,
deps = [":{}_test_lib".format(name)],
tags = tags,
visibility = visibility,
linkstatic = True,
)

View File

@ -1,9 +0,0 @@
#include <mlir-c/BuiltinAttributes.h>
#include <mlir-c/BuiltinTypes.h>
#include <mlir-c/Dialect/Arith.h>
#include <mlir-c/Dialect/Func.h>
#include <mlir-c/Dialect/Math.h>
#include <mlir-c/Dialect/SCF.h>
#include <mlir-c/IR.h>
#include <mlir-c/Pass.h>
#include <mlir-c/Transforms.h>

File diff suppressed because it is too large Load Diff

View File

@ -10,7 +10,7 @@ cc_shared_library(
patchelf(
name = "libpjrt_cuda.patchelf",
shared_library = "libpjrt_cuda.so",
src = "libpjrt_cuda.so",
add_needed = [
"libzmlxcuda.so.0",
],

View File

@ -31,7 +31,7 @@ _NEURON_PACKAGES = {
),
packages.patchelf(
name = "libnrt.patchelf",
shared_library = "lib/libnrt.so.1",
src = "lib/libnrt.so.1",
set_rpath = '$ORIGIN',
add_needed = [
# readelf -d ./opt/aws/neuron/libl/libncfw.so
@ -43,7 +43,7 @@ _NEURON_PACKAGES = {
),
packages.patchelf(
name = "libncfw.patchelf",
shared_library = "lib/libncfw.so",
src = "lib/libncfw.so",
soname = "libncfw.so.2",
),
]),

View File

@ -1,13 +1,14 @@
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_zig//zig:defs.bzl", "zig_library")
load("@rules_zig//zig:defs.bzl", "zig_library", "zig_shared_library")
cc_library(
name = "zmlxrocm_lib",
srcs = ["zmlxrocm.c"],
linkopts = [
"-lc",
"-ldl",
],
zig_shared_library(
name = "zmlxrocm",
main = "zmlxrocm.zig",
# Use Clang's compiler-rt, but disable stack checking
# to avoid requiring on the _zig_probe_stack symbol.
copts = ["-fno-stack-check"],
shared_lib_name = "libzmlxrocm.so.0",
deps = ["//stdx"],
visibility = ["@libpjrt_rocm//:__subpackages__"],
)
@ -51,6 +52,6 @@ zig_library(
filegroup(
name = "layers",
srcs = ["@libpjrt_rocm//:amdgpu_ids_layer"],
srcs = [],
visibility = ["//visibility:public"],
)

View File

@ -20,26 +20,12 @@ config_setting(
flag_values = {":hipblaslt": "True"},
)
cc_shared_library(
name = "zmlxrocm_so",
shared_lib_name = "lib/libzmlxrocm.so.0",
deps = ["@zml//runtimes/rocm:zmlxrocm_lib"],
)
patchelf(
name = "libpjrt_rocm.patchelf",
shared_library = "libpjrt_rocm.so",
name = "libpjrt_rocm_so",
src = "libpjrt_rocm.so",
add_needed = [
"libzmlxrocm.so.0",
# So that RPATH is taken into account.
"librocblas.so.4",
"libMIOpen.so.1",
] + select({
"_hipblaslt": [
"libhipblaslt.so.0",
],
"//conditions:default": [],
}),
rename_dynamic_symbols = {
"dlopen": "zmlxrocm_dlopen",
},
@ -49,49 +35,52 @@ patchelf(
copy_to_directory(
name = "sandbox",
srcs = [
":zmlxrocm_so",
":libpjrt_rocm.patchelf",
":libpjrt_rocm_so",
"@comgr//:amd_comgr",
"@hip-runtime-amd//:amdhip_patched",
"@hip-runtime-amd//:amdhip",
"@hip-runtime-amd//:hiprtc",
"@hipblaslt//:hipblaslt",
"@hipfft",
"@hipsolver",
"@hsa-amd-aqlprofile//:hsa-amd-aqlprofile",
"@hsa-rocr//:hsa-runtime",
"@libdrm-amdgpu-amdgpu1",
"@libdrm-amdgpu-common//:amdgpu_ids",
"@libdrm2-amdgpu",
"@libelf1",
"@libnuma1",
"@libtinfo6",
"@libzstd1",
"@miopen-hip//:MIOpen",
"@rccl",
"@rocblas//:rocblas",
"@rocblas//:runfiles",
"@rocfft",
"@rocm-core",
"@rocm-device-libs//:runfiles",
"@rocm-smi-lib//:rocm_smi",
"@rocprofiler-register",
"@rocfft",
"@rocsolver",
"@roctracer",
"@roctracer//:roctx",
"@libelf1",
"@libdrm2-amdgpu",
"@libnuma1",
"@libzstd1",
"@libdrm-amdgpu-amdgpu1",
"@libtinfo6",
"@zlib1g",
"@zml//runtimes/rocm:zmlxrocm",
] + select({
":_hipblaslt": ["@hipblaslt//:runfiles"],
"//conditions:default": [],
}),
replace_prefixes = {
"libpjrt_rocm.patchelf": "lib",
"lib/x86_64-linux-gnu": "lib",
"usr/lib/x86_64-linux-gnu": "lib",
"libelf1": "lib",
"amdhip": "lib",
"hipblaslt": "lib",
"rocblas": "lib",
"opt/amdgpu/lib/x86_64-linux-gnu": "lib",
"lib/x86_64-linux-gnu": "lib",
"libdrm-amdgpu-amdgpu1": "lib",
"amdhip_patched": "lib",
"libelf1": "lib",
"libpjrt_rocm_so": "lib",
"opt/amdgpu/lib/x86_64-linux-gnu": "lib",
"opt/amdgpu/share": "share",
"rocblas": "lib",
"runtimes/rocm": "lib",
"usr/lib/x86_64-linux-gnu": "lib",
},
add_directory_to_runfiles = True,
include_external_repositories = ["**"],

View File

@ -1277,7 +1277,13 @@
},
{
"arch": "amd64",
"dependencies": [],
"dependencies": [
{
"key": "rocm-core_6.4.1.60401-83_22.04_amd64",
"name": "rocm-core",
"version": "6.4.1.60401-83~22.04"
}
],
"key": "roctracer_4.1.60401.60401-83_22.04_amd64",
"name": "roctracer",
"sha256": "58cead537cf07c8a8770bfe28346c3b3c92cc4297b51e307c9032b04434b187c",
@ -4153,6 +4159,49 @@
"https://repo.radeon.com/rocm/apt/6.4.1/pool/main/r/rocfft/rocfft_1.0.32.60401-83~22.04_amd64.deb"
],
"version": "1.0.32.60401-83~22.04"
},
{
"arch": "amd64",
"dependencies": [
{
"key": "hipblaslt_0.12.1.60401-83_22.04_amd64",
"name": "hipblaslt",
"version": "0.12.1.60401-83~22.04"
},
{
"key": "rocm-core_6.4.1.60401-83_22.04_amd64",
"name": "rocm-core",
"version": "6.4.1.60401-83~22.04"
},
{
"key": "roctracer_4.1.60401.60401-83_22.04_amd64",
"name": "roctracer",
"version": "4.1.60401.60401-83~22.04"
},
{
"key": "hipblas-common-dev_1.0.0.60401-83_22.04_amd64",
"name": "hipblas-common-dev",
"version": "1.0.0.60401-83~22.04"
}
],
"key": "hipblaslt-dev_0.12.1.60401-83_22.04_amd64",
"name": "hipblaslt-dev",
"sha256": "46eb2285c76d246b162eb54cc7f9e5cb7bcdd0aa83d57ecaea440e57260f2f4a",
"urls": [
"https://repo.radeon.com/rocm/apt/6.4.1/pool/main/h/hipblaslt-dev/hipblaslt-dev_0.12.1.60401-83~22.04_amd64.deb"
],
"version": "0.12.1.60401-83~22.04"
},
{
"arch": "amd64",
"dependencies": [],
"key": "hipblas-common-dev_1.0.0.60401-83_22.04_amd64",
"name": "hipblas-common-dev",
"sha256": "5df3e4a8a1959cbf94106f7bf87d7fb71bf06e726cde00c6092ef29bbd8156f0",
"urls": [
"https://repo.radeon.com/rocm/apt/6.4.1/pool/main/h/hipblas-common-dev/hipblas-common-dev_1.0.0.60401-83~22.04_amd64.deb"
],
"version": "1.0.0.60401-83~22.04"
}
],
"version": 1

View File

@ -35,7 +35,7 @@ packages:
- "rocsolver"
- "hipsolver"
- "hipfft"
# - "roctracer"
- "roctracer"
- "hipblaslt"
# - "hipblaslt-dev"
- "hipblaslt-dev"
- "hip-runtime-amd"

View File

@ -14,7 +14,7 @@ _UBUNTU_PACKAGES = {
packages.load_("@zml//bazel:patchelf.bzl", "patchelf"),
packages.patchelf(
name = "libelf1",
shared_library = "usr/lib/x86_64-linux-gnu/libelf.so.1",
src = "usr/lib/x86_64-linux-gnu/libelf.so.1",
set_rpath = "$ORIGIN",
),
]),
@ -25,8 +25,12 @@ _UBUNTU_PACKAGES = {
packages.load_("@zml//bazel:patchelf.bzl", "patchelf"),
packages.patchelf(
name = "libdrm-amdgpu-amdgpu1",
shared_library = "opt/amdgpu/lib/x86_64-linux-gnu/libdrm_amdgpu.so.1",
src = "opt/amdgpu/lib/x86_64-linux-gnu/libdrm_amdgpu.so.1",
add_needed = ["libzmlxrocm.so.0"],
set_rpath = "$ORIGIN",
rename_dynamic_symbols = {
"fopen64": "zmlxrocm_fopen64",
},
),
]),
"libtinfo6": packages.filegroup(name = "libtinfo6", srcs = ["lib/x86_64-linux-gnu/libtinfo.so.6"]),
@ -38,14 +42,7 @@ _ROCM_PACKAGES = {
"rocm-smi-lib": packages.filegroup(name = "rocm_smi", srcs = ["lib/librocm_smi64.so.7"]),
"hsa-rocr": packages.filegroup(name = "hsa-runtime", srcs = ["lib/libhsa-runtime64.so.1"]),
"hsa-amd-aqlprofile": packages.filegroup(name = "hsa-amd-aqlprofile", srcs = ["lib/libhsa-amd-aqlprofile64.so.1"]),
"comgr": "\n".join([
packages.filegroup(
name = "amd_comgr",
srcs = [
"lib/libamd_comgr.so.3",
],
),
]),
"comgr": packages.filegroup(name = "amd_comgr", srcs = ["lib/libamd_comgr.so.3"]),
"rocprofiler-register": packages.filegroup(name = "rocprofiler-register", srcs = ["lib/librocprofiler-register.so.0"]),
"miopen-hip": "\n".join([
packages.filegroup(name = "MIOpen", srcs = ["lib/libMIOpen.so.1"]),
@ -59,7 +56,7 @@ _ROCM_PACKAGES = {
packages.load_("@zml//runtimes/rocm:gfx.bzl", "bytecode_select"),
packages.patchelf(
name = "rocblas",
shared_library = "lib/librocblas.so.4",
src = "lib/librocblas.so.4",
add_needed = ["libzmlxrocm.so.0"],
rename_dynamic_symbols = {
"dlopen": "zmlxrocm_dlopen",
@ -90,7 +87,7 @@ _ROCM_PACKAGES = {
packages.load_("@zml//runtimes/rocm:gfx.bzl", "bytecode_select"),
packages.patchelf(
name = "hipblaslt",
shared_library = "lib/libhipblaslt.so.0",
src = "lib/libhipblaslt.so.0",
add_needed = ["libzmlxrocm.so.0"],
rename_dynamic_symbols = {
"dlopen": "zmlxrocm_dlopen",
@ -116,10 +113,9 @@ _ROCM_PACKAGES = {
"hipfft": packages.filegroup(name = "hipfft", srcs = ["lib/libhipfft.so.0"]),
"hip-runtime-amd": "\n".join([
packages.load_("@zml//bazel:patchelf.bzl", "patchelf"),
packages.filegroup(name = "amdhip", srcs = ["lib/libamdhip64.so.6"]),
packages.patchelf(
name = "amdhip_patched",
shared_library = ":amdhip",
name = "amdhip",
src = "lib/libamdhip64.so.6",
add_needed = ["libzmlxrocm.so.0"],
rename_dynamic_symbols = {
"dlopen": "zmlxrocm_dlopen",

View File

@ -1,5 +1,5 @@
const builtin = @import("builtin");
const std = @import("std");
const builtin = @import("builtin");
const asynk = @import("async");
const bazel_builtin = @import("bazel_builtin");
@ -10,20 +10,6 @@ const stdx = @import("stdx");
const log = std.log.scoped(.@"zml/runtime/rocm");
const ROCmEnvEntry = struct {
name: [:0]const u8,
rpath: []const u8,
dirname: bool,
mandatory: bool,
};
const rocm_env_entries: []const ROCmEnvEntry = &.{
.{ .name = "HIPBLASLT_EXT_OP_LIBRARY_PATH", .rpath = "/lib/hipblaslt/library/hipblasltExtOpLibrary.dat", .dirname = false, .mandatory = false },
.{ .name = "HIPBLASLT_TENSILE_LIBPATH", .rpath = "/lib/hipblaslt/library/TensileManifest.txt", .dirname = true, .mandatory = false },
.{ .name = "ROCBLAS_TENSILE_LIBPATH", .rpath = "/lib/rocblas/library/TensileManifest.txt", .dirname = true, .mandatory = true },
.{ .name = "ROCM_PATH", .rpath = "/", .dirname = false, .mandatory = true },
};
pub fn isEnabled() bool {
return @hasDecl(c, "ZML_RUNTIME_ROCM");
}
@ -35,23 +21,9 @@ fn hasRocmDevices() bool {
return true;
}
fn setupRocmEnv(allocator: std.mem.Allocator, rocm_data_dir: []const u8) !void {
for (rocm_env_entries) |entry| {
var real_path: []const u8 = std.fmt.allocPrintZ(allocator, "{s}/{s}", .{ rocm_data_dir, entry.rpath }) catch null orelse {
if (entry.mandatory) {
stdx.debug.panic("Unable to find {s} in {s}\n", .{ entry.name, bazel_builtin.current_repository });
}
continue;
};
if (entry.dirname) {
real_path = std.fs.path.dirname(real_path) orelse {
stdx.debug.panic("Unable to dirname on {s}", .{real_path});
};
}
_ = c.setenv(entry.name, try allocator.dupeZ(u8, real_path), 1);
}
fn setupRocmEnv(rocm_data_dir: []const u8) !void {
var buf: [std.fs.max_path_bytes]u8 = undefined;
_ = c.setenv("ROCM_PATH", try stdx.fs.path.bufJoinZ(&buf, &.{rocm_data_dir}), 1); // must be zero terminated
}
pub fn load() !*const pjrt.Api {
@ -81,7 +53,7 @@ pub fn load() !*const pjrt.Api {
return error.FileNotFound;
};
try setupRocmEnv(arena.allocator(), sandbox_path);
try setupRocmEnv(sandbox_path);
var lib_path_buf: [std.fs.max_path_bytes]u8 = undefined;
const lib_path = try stdx.fs.path.bufJoinZ(&lib_path_buf, &.{ sandbox_path, "lib", "libpjrt_rocm.so" });

View File

@ -1,52 +0,0 @@
#include <dlfcn.h>
#include <errno.h>
#include <stdlib.h>
#include <string.h>
void *zmlxrocm_dlopen(const char *filename, int flags) __attribute__((visibility("default")))
{
if (filename != NULL)
{
char *replacements[] = {
"librocm-core.so",
"librocm-core.so.1",
"librocm_smi64.so",
"librocm_smi64.so.7",
"libhsa-runtime64.so",
"libhsa-runtime64.so.1",
"libhsa-amd-aqlprofile64.so",
"libhsa-amd-aqlprofile64.so.1",
"libamd_comgr.so",
"libamd_comgr.so.3",
"librocprofiler-register.so",
"librocprofiler-register.so.0",
"libMIOpen.so",
"libMIOpen.so.1",
"librccl.so",
"librccl.so.1",
"librocblas.so",
"librocblas.so.4",
"libroctracer64.so",
"libroctracer64.so.4",
"libroctx64.so",
"libroctx64.so.4",
"libhipblaslt.so",
"libhipblaslt.so.0",
"libamdhip64.so",
"libamdhip64.so.6",
"libhiprtc.so",
"libhiprtc.so.6",
NULL,
NULL,
};
for (int i = 0; replacements[i] != NULL; i += 2)
{
if (strcmp(filename, replacements[i]) == 0)
{
filename = replacements[i + 1];
break;
}
}
}
return dlopen(filename, flags);
}

View File

@ -0,0 +1,50 @@
const std = @import("std");
const stdx = @import("stdx");
pub export fn zmlxrocm_dlopen(filename: [*c]const u8, flags: c_int) ?*anyopaque {
const replacements: std.StaticStringMap([:0]const u8) = .initComptime(.{
.{ "librocm-core.so", "librocm-core.so.1" },
.{ "librocm_smi64.so", "librocm_smi64.so.7" },
.{ "libhsa-runtime64.so", "libhsa-runtime64.so.1" },
.{ "libhsa-amd-aqlprofile64.so", "libhsa-amd-aqlprofile64.so.1" },
.{ "libamd_comgr.so", "libamd_comgr.so.3" },
.{ "librocprofiler-register.so", "librocprofiler-register.so.0" },
.{ "libMIOpen.so", "libMIOpen.so.1" },
.{ "librccl.so", "librccl.so.1" },
.{ "librocblas.so", "librocblas.so.4" },
.{ "libroctracer64.so", "libroctracer64.so.4" },
.{ "libroctx64.so", "libroctx64.so.4" },
.{ "libhipblaslt.so", "libhipblaslt.so.0" },
.{ "libamdhip64.so", "libamdhip64.so.6" },
.{ "libhiprtc.so", "libhiprtc.so.6" },
});
var buf: [std.fs.max_path_bytes]u8 = undefined;
const new_filename: [*c]const u8 = if (filename) |f| blk: {
const replacement = replacements.get(std.fs.path.basename(std.mem.span(f))) orelse break :blk f;
break :blk stdx.fs.path.bufJoinZ(&buf, &.{
stdx.fs.selfSharedObjectDirPath(),
replacement,
}) catch unreachable;
} else null;
return std.c.dlopen(new_filename, @bitCast(flags));
}
pub export fn zmlxrocm_fopen64(pathname: [*c]const u8, mode: [*c]const u8) ?*std.c.FILE {
const replacements: std.StaticStringMap([]const u8) = .initComptime(.{
.{ "/opt/amdgpu/share/libdrm/amdgpu.ids", "../share/libdrm/amdgpu.ids" },
});
var buf: [std.fs.max_path_bytes]u8 = undefined;
const new_pathname: [*c]const u8 = blk: {
const replacement = replacements.get(std.mem.span(pathname)) orelse break :blk pathname;
break :blk stdx.fs.path.bufJoinZ(&buf, &.{
stdx.fs.selfSharedObjectDirPath(),
replacement,
}) catch unreachable;
};
return std.c.fopen64(new_pathname, mode);
}

107
zml/aio/json.zig Normal file
View File

@ -0,0 +1,107 @@
const asynk = @import("async");
const std = @import("std");
const zml = @import("../zml.zig");
const StringBuilder = std.ArrayListUnmanaged(u8);
const Allocator = std.mem.Allocator;
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
const file = try std.fs.cwd().openFile(path, .{});
defer file.close();
var res: zml.aio.BufferStore = .{
.arena = std.heap.ArenaAllocator.init(allocator),
};
errdefer res.arena.deinit();
const arena = res.arena.allocator();
const json_data = try file.reader().readAllAlloc(arena, (try file.metadata()).size());
const metadata = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data, .{ .allocate = .alloc_if_needed });
var it = metadata.object.iterator();
while (it.next()) |entry| {
var prefix_buf: [1024]u8 = undefined;
try parseMetadata(allocator, &res, StringBuilder.initBuffer(&prefix_buf), entry.value_ptr.*);
}
return res;
}
pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, val: std.json.Value) !void {
const metadata = &store._metadata;
const key = prefix.items;
return switch (val) {
.null => try metadata.put(allocator, try allocator.dupe(u8, key), .null),
.bool => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .bool = v }),
.integer => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .int = v }),
.float => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .float = v }),
.number_string, .string => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .string = try allocator.dupe(u8, v) }),
.array => |v| {
if (v.items.len == 0) return;
return if (validSlice(v)) |item_type| {
const data: zml.aio.Metadata = switch (item_type) {
.bool => blk: {
const values = try allocator.alloc(bool, v.items.len);
for (v.items, 0..) |item, i| values[i] = item.bool;
break :blk .{ .array_bool = values };
},
.integer => blk: {
const values = try allocator.alloc(i64, v.items.len);
for (v.items, 0..) |item, i| values[i] = item.integer;
break :blk .{ .array_int = values };
},
.float => blk: {
const values = try allocator.alloc(f64, v.items.len);
for (v.items, 0..) |item, i| values[i] = item.float;
break :blk .{ .array_float = values };
},
inline .string, .number_string => |tag| blk: {
const values = try allocator.alloc([]const u8, v.items.len);
for (v.items, 0..) |item, i| {
values[i] = @field(item, @tagName(tag));
}
break :blk .{ .array_string = values };
},
.null, .array, .object => unreachable,
};
try metadata.put(allocator, try allocator.dupe(u8, key), data);
} else {
for (v.items, 0..) |item, i| {
var new_prefix = prefix;
if (prefix.items.len > 0)
new_prefix.appendAssumeCapacity('.');
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
try parseMetadata(allocator, store, new_prefix, item);
}
};
},
.object => |v| {
var obj_iter = v.iterator();
while (obj_iter.next()) |entry| {
var new_prefix = prefix;
if (prefix.items.len > 0)
new_prefix.appendAssumeCapacity('.');
new_prefix.appendSliceAssumeCapacity(entry.key_ptr.*);
try parseMetadata(allocator, store, new_prefix, entry.value_ptr.*);
}
},
};
}
/// We can only create a Zig slice out of json array, if all values
/// in the array have the same type.
fn validSlice(v: std.json.Array) ?std.meta.Tag(std.json.Value) {
if (v.items.len == 0) return null;
const item_type: std.meta.Tag(std.json.Value) = v.items[0];
switch (item_type) {
.null, .array, .object => return null,
else => {},
}
for (v.items[1..]) |item| {
if (item != item_type)
return null;
}
return item_type;
}

View File

@ -1,13 +1,12 @@
const std = @import("std");
const Allocator = std.mem.Allocator;
const asynk = @import("async");
const MemoryMappedFile = @import("../aio.zig").MemoryMappedFile;
const std = @import("std");
const zml = @import("../zml.zig");
const json = @import("json.zig");
const HostBuffer = zml.HostBuffer;
const MemoryMappedFile = @import("../aio.zig").MemoryMappedFile;
const StringBuilder = std.ArrayListUnmanaged(u8);
const Allocator = std.mem.Allocator;
const log = std.log.scoped(.@"zml/io");
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
@ -56,6 +55,11 @@ fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.
const full_filename = try std.fs.path.join(allocator, &.{ std.fs.path.dirname(path).?, filename });
try loadFile(allocator, store, files, full_filename);
}
if (index.object.get("__metadata__")) |metadata| {
var prefix_buf: [1024]u8 = undefined;
try json.parseMetadata(allocator, store, StringBuilder.initBuffer(&prefix_buf), metadata);
}
}
fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void {
@ -85,6 +89,11 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.Array
var it = metadata.object.iterator();
while (it.next()) |entry| {
const key = entry.key_ptr.*;
if (std.mem.eql(u8, key, "__metadata__")) {
var prefix_buf: [1024]u8 = undefined;
try json.parseMetadata(allocator, store, StringBuilder.initBuffer(&prefix_buf), entry.value_ptr.*);
continue;
}
const val = entry.value_ptr.*;
const shape_field = val.object.get("shape").?.array;
if (shape_field.items.len > zml.Shape.MAX_RANK) {