Implement buffer‑ID based loading by moving tensor ID handling into BufferStore, fix zml.call tag hashing, and expose CPU device count.

This commit is contained in:
Tarry Singh 2025-08-28 14:39:21 +00:00
parent 6e7617918d
commit 7913c00d70
13 changed files with 272 additions and 131 deletions

View File

@ -1,5 +1,6 @@
const std = @import("std"); const std = @import("std");
/// Properly format a slice of numbers.
pub fn slice(any_slice: anytype) FmtSlice(std.meta.Elem(@TypeOf(any_slice))) { pub fn slice(any_slice: anytype) FmtSlice(std.meta.Elem(@TypeOf(any_slice))) {
return .{ .slice = any_slice }; return .{ .slice = any_slice };
} }
@ -115,3 +116,57 @@ pub fn formatComplexSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io
pub fn formatBoolSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void { pub fn formatBoolSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
return try formatSliceCustom(formatBool, values, spec, writer); return try formatSliceCustom(formatBool, values, spec, writer);
} }
/// Format a struct using `format` method of subfields when possible.
/// Returns a formatter wrapper meant to be passed as a `{f}` argument,
/// e.g. `log.info("{f}", .{fmt.any(value)})`.
pub fn any(any_val: anytype) FmtAny(@TypeOf(any_val)) {
return .{ .data = any_val };
}
/// Formatter wrapper type returned by `any`.
/// Its `format` method delegates to `printValue`, capping recursion at the
/// global `std.options.fmt_max_depth`.
fn FmtAny(Data: type) type {
return struct {
data: Data,
pub inline fn format(self: @This(), writer: *std.io.Writer) std.io.Writer.Error!void {
try printValue(writer, .{}, self.data, std.options.fmt_max_depth);
}
};
}
/// Fix up of std.io.Writer.printValue that uses `format` method of subfields when possible.
/// Recurses into struct fields so any field exposing a `format` (or
/// `formatNumber`) method is rendered through it; everything else falls back
/// to the stock `Writer.printValue`.
fn printValue(
w: *std.io.Writer,
options: std.fmt.Options,
value: anytype,
max_depth: usize,
) std.io.Writer.Error!void {
const T = @TypeOf(value);
// Prefer the type's own formatting methods when available.
if (std.meta.hasMethod(T, "format")) {
return try value.format(w);
}
if (std.meta.hasMethod(T, "formatNumber")) {
return try value.formatNumber(w, options.toNumber(.decimal, .lower));
}
// Depth guard: elide deeper nesting instead of recursing forever.
if (max_depth == 0) {
try w.writeAll(".{ ... }");
return;
}
switch (@typeInfo(T)) {
.@"struct" => |info| {
try w.writeAll(".{ ");
inline for (info.fields, 0..) |f, i| {
if (i > 0) try w.writeAll(", ");
// Tuples have positional fields: skip the `.name = ` prefix.
if (!info.is_tuple) {
try w.writeByte('.');
try w.writeAll(f.name);
try w.writeAll(" = ");
}
try printValue(w, options, @field(value, f.name), max_depth - 1);
}
try w.writeAll(" }");
},
// Non-struct values: delegate to the stock implementation.
inline else => try w.printValue("any", options, value, max_depth - 1),
}
}

View File

@ -1,3 +1,6 @@
const std = @import("std");
const builtin = @import("builtin");
pub const BoundedArray = @import("bounded_array.zig").BoundedArray; pub const BoundedArray = @import("bounded_array.zig").BoundedArray;
pub const BoundedArrayAligned = @import("bounded_array.zig").BoundedArrayAligned; pub const BoundedArrayAligned = @import("bounded_array.zig").BoundedArrayAligned;
pub const debug = @import("debug.zig"); pub const debug = @import("debug.zig");
@ -11,7 +14,6 @@ pub const queue = @import("queue.zig");
pub const time = @import("time.zig"); pub const time = @import("time.zig");
test { test {
const std = @import("std");
std.testing.refAllDecls(@This()); std.testing.refAllDecls(@This());
} }
@ -20,3 +22,5 @@ pub inline fn stackSlice(comptime max_len: usize, T: type, len: usize) []T {
var storage: [max_len]T = undefined; var storage: [max_len]T = undefined;
return storage[0..len]; return storage[0..len];
} }
/// Allocator for call sites that must not allocate: in debug/test builds any
/// allocation fails loudly via `std.testing.failing_allocator`; in ReleaseFast
/// it is `undefined`. NOTE(review): assumes ReleaseFast code paths never
/// actually allocate through it — confirm at call sites.
pub const noalloc: std.mem.Allocator = if (builtin.mode == .ReleaseFast) undefined else std.testing.failing_allocator;

View File

@ -43,8 +43,8 @@ pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8)
/// whose shape is read from the "a.b" tensor. /// whose shape is read from the "a.b" tensor.
/// * If `Model` contains a list of layers, then the field: /// * If `Model` contains a list of layers, then the field:
/// `Model.layers[2].a.b` will be populated from the "layers.2.a.b" tensor. /// `Model.layers[2].a.b` will be populated from the "layers.2.a.b" tensor.
pub fn populateModel(comptime Model: type, allocator: std.mem.Allocator, buffer_store: BufferStore) !Model { pub fn populateModel(comptime Model: type, allocator: std.mem.Allocator, store: BufferStore) !Model {
return populateModelWithPrefix(Model, allocator, buffer_store, ""); return populateModelWithPrefix(Model, allocator, store, "");
} }
/// Creates a Model struct with tensor shapes read from the given TensorStore, /// Creates a Model struct with tensor shapes read from the given TensorStore,
@ -62,8 +62,7 @@ pub fn populateModelWithPrefix(comptime Model: type, allocator: std.mem.Allocato
try prefix_builder.push(allocator, prefix); try prefix_builder.push(allocator, prefix);
defer prefix_builder.deinit(allocator); defer prefix_builder.deinit(allocator);
const unique_id = zml.Tensor._reserveIdRange(@intCast(store.buffers.count())); const ok = _populateStruct(allocator, &prefix_builder, store, &model, true) catch |err| {
const ok = _populateStruct(allocator, &prefix_builder, unique_id, store, &model, true) catch |err| {
std.debug.panic("Can't populate model of type {s}: {s}", .{ @typeName(type), @errorName(err) }); std.debug.panic("Can't populate model of type {s}: {s}", .{ @typeName(type), @errorName(err) });
}; };
if (!ok) return error.TensorNotFound; if (!ok) return error.TensorNotFound;
@ -74,17 +73,28 @@ pub fn populateModelWithPrefix(comptime Model: type, allocator: std.mem.Allocato
pub const BufferStore = struct { pub const BufferStore = struct {
pub const Buffers = std.StringArrayHashMapUnmanaged(HostBuffer); pub const Buffers = std.StringArrayHashMapUnmanaged(HostBuffer);
pub const Metadatas = std.StringArrayHashMapUnmanaged(Metadata); pub const Metadatas = std.StringArrayHashMapUnmanaged(Metadata);
var _unique_store_id: std.atomic.Value(u64) = .init(0);
const _store_id_range: u64 = 1024 * 1024 * 1024;
arena: std.heap.ArenaAllocator, arena: std.heap.ArenaAllocator,
files: []MemoryMappedFile = &.{}, files: []MemoryMappedFile = &.{},
buffers: Buffers = .{}, buffers: Buffers = .{},
_metadata: Metadatas = .{}, _metadata: Metadatas = .{},
_unique_id: u64,
/// Create an empty BufferStore. Takes owneship of the given files. /// Create an empty BufferStore.
pub fn init(allocator: std.mem.Allocator, files: []const MemoryMappedFile) error{OutOfMemory}!BufferStore { /// Takes owneship of the given files.
var self: zml.aio.BufferStore = .{ pub fn init(allocator: std.mem.Allocator) BufferStore {
return .{
.arena = std.heap.ArenaAllocator.init(allocator), .arena = std.heap.ArenaAllocator.init(allocator),
._unique_id = _unique_store_id.fetchAdd(_store_id_range, .monotonic),
}; };
}
/// Create an empty BufferStore.
/// Takes ownership of the given files.
pub fn initWithFiles(allocator: std.mem.Allocator, files: []const MemoryMappedFile) error{OutOfMemory}!BufferStore {
var self: BufferStore = .init(allocator);
self.files = try self.arena.allocator().dupe(MemoryMappedFile, files); self.files = try self.arena.allocator().dupe(MemoryMappedFile, files);
return self; return self;
} }
@ -100,6 +110,61 @@ pub const BufferStore = struct {
return self.buffers.get(key); return self.buffers.get(key);
} }
/// Loads onto `platform` the host buffer backing the tensor `x`.
/// `x` must have been created by `getTensor`/`getTensorOrNull` on this exact
/// store object: its `.buffer_id` encodes this store's id base plus the
/// buffer's index in `buffers`.
pub fn loadBufferById(self: BufferStore, x: zml.Tensor, platform: zml.Platform) !zml.Buffer {
    var host_buffer: zml.HostBuffer = switch (x._id) {
        .buffer_id => |id| hb: {
            // Ids handed out by `getTensor` are `_unique_id + entry_idx` with
            // `entry_idx < buffers.count()`, so validate against the actual
            // number of buffers rather than the whole reserved id range.
            // This also rejects ids that would index past `values()`.
            if (id < self._unique_id or self._unique_id + self.buffers.count() <= id) {
                @panic("`store.loadBufferById()` only works on Tensor created by `store.getTensor()`, using the same store object.");
            }
            break :hb self.buffers.values()[id - self._unique_id];
        },
        else => @panic("`store.loadBufferById()` only works on Tensor created by `store.getTensor()`"),
    };
    // Use the sharding information stored in the tensor.
    host_buffer._shape = x.shape();
    return try host_buffer.toDevice(platform);
}
/// Creates a bufferized version of a model from the given BufferStore.
///
/// This will represent the weights of the model, loaded on a specific platform.
/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a
/// `module.ExeWithWeights` ready to be called.
///
/// Every Tensor inside `model` must come from `getTensor` on this same store
/// (see `loadBufferById`); a failed device transfer panics rather than
/// returning an error, because the mapping callback cannot propagate errors.
pub fn loadModelById(self: BufferStore, Model: type, allocator: std.mem.Allocator, model: Model, platform: zml.Platform) !zml.Bufferized(Model) {
const Ctx = struct {
platform: *const zml.Platform,
store: *const BufferStore,
// mapAlloc callback: maps each Tensor of `model` to a device Buffer.
pub fn cb(ctx: @This(), x: zml.Tensor) zml.Buffer {
return ctx.store.loadBufferById(x, ctx.platform.*) catch @panic("Failed to load buffer to device");
}
};
var res: zml.Bufferized(Model) = undefined;
try zml.meta.mapAlloc(Ctx.cb, allocator, .{ .platform = &platform, .store = &self }, model, &res);
return res;
}
/// Like `getTensorOrNull`, but panics when `key` is missing, after logging
/// similarly-named buffer keys to help diagnose typos.
pub fn getTensor(self: BufferStore, key: []const u8) zml.Tensor {
return self.getTensorOrNull(key) orelse {
log.err("Tensor not found: {s}", .{key});
// Best-effort suggestions before aborting; the allocator choice only
// matters for the short-lived dedup set.
self.findSimilarBufferKeys(std.heap.smp_allocator, key);
@panic("Tensor not found");
};
}
/// Returns a Tensor whose shape matches the host buffer stored under `key`,
/// or null if the key is absent. The tensor id encodes this store's id base
/// plus the buffer index, so `loadBufferById` can recover the backing data.
pub fn getTensorOrNull(self: BufferStore, key: []const u8) ?zml.Tensor {
    // `orelse return null` replaces the previous `return if (...) else return null`,
    // which nested a redundant `return` inside the returned expression.
    const entry_idx = self.buffers.getIndex(key) orelse return null;
    return .{
        ._shape = self.buffers.values()[entry_idx].shape(),
        ._id = .{ .buffer_id = self._unique_id + entry_idx },
        ._donation = .input_buffer,
    };
}
/// Count layers starting with the given prefix. /// Count layers starting with the given prefix.
pub fn countLayers(self: BufferStore, prefix: []const u8) usize { pub fn countLayers(self: BufferStore, prefix: []const u8) usize {
// Note: This is kinda inefficient // Note: This is kinda inefficient
@ -140,6 +205,58 @@ pub const BufferStore = struct {
return null; return null;
} }
/// Assists in debugging `BufferNotFound` errors.
/// This is useful when a buffer key is not found and you want to identify
/// possible alternatives (or typos). Purely diagnostic: logs up to a few
/// candidate keys, never fails (allocation errors in the dedup set are
/// swallowed via `catch continue`).
pub fn findSimilarBufferKeys(store: BufferStore, tmp_alloc: std.mem.Allocator, original_key: []const u8) void {
    // Strip a trailing ".weight"/".bias" so sibling parameters of the same
    // layer are suggested too.
    // Bug fix: the list previously started with the empty suffix "", which
    // `endsWith` always matches, so the stripping was dead code. The
    // "no suffix" fallback is now the implicit default of `base_key`.
    const suffixes = [_][]const u8{ ".weight", ".bias" };
    var base_key = original_key;
    for (suffixes) |suffix| {
        if (std.mem.endsWith(u8, original_key, suffix)) {
            base_key = original_key[0 .. original_key.len - suffix.len];
            break;
        }
    }

    // Tracks keys already logged so the two passes don't repeat them.
    var shown_keys = std.StringHashMap(void).init(tmp_alloc);
    defer shown_keys.deinit();

    // first test: look for keys sharing the stripped prefix
    var matches: usize = 0;
    var it = store.buffers.iterator();
    while (it.next()) |entry| {
        const key = entry.key_ptr.*;
        if (std.mem.startsWith(u8, key, base_key)) {
            if (matches == 0) log.warn("Similar buffers found:", .{});
            if (!shown_keys.contains(key)) {
                log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() });
                shown_keys.put(key, {}) catch continue;
                matches += 1;
            }
        }
    }

    // second test: progressive partial matches on each dotted component,
    // stopping at the first component that yields anything (capped at 5).
    if (matches == 0) {
        var components = std.mem.splitScalar(u8, base_key, '.');
        while (components.next()) |component| {
            matches = 0;
            it = store.buffers.iterator();
            while (it.next()) |entry| {
                const key = entry.key_ptr.*;
                if (std.mem.indexOf(u8, key, component) != null and !shown_keys.contains(key)) {
                    if (matches == 0) log.warn("Partial matches for '{s}':", .{component});
                    log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() });
                    shown_keys.put(key, {}) catch continue;
                    matches += 1;
                    if (matches >= 5) break;
                }
            }
            if (matches > 0) break;
        }
    }
}
}; };
pub const Metadata = union(enum) { pub const Metadata = union(enum) {
@ -260,17 +377,35 @@ pub const MemoryMappedFile = struct {
/// Helper handling prefix building. /// Helper handling prefix building.
/// ///
/// This allows to easily push/pop prefixes and handles the generation of the string with the correct format. /// This allows to easily push/pop prefixes and handles the generation of the string with the correct format.
const PrefixBuilder = struct { pub const PrefixBuilder = struct {
/// Stores the computed prefix. /// Stores the computed prefix.
data: std.ArrayListUnmanaged(u8) = .{}, data: std.ArrayList(u8) = .{},
/// Stack storing the size of the intermediary prefix. /// Stack storing the size of the intermediary prefix.
subprefixes: std.ArrayListUnmanaged(u32) = .{}, subprefixes: std.ArrayList(u32) = .{},
/// Create a PrefixBuilder with `capacity` bytes pre-reserved for the prefix
/// string, so later `push` calls with a non-growing allocator (see `concat`)
/// are unlikely to need growth.
pub fn initCapacity(allocator: std.mem.Allocator, capacity: usize) !PrefixBuilder {
return .{
.data = try .initCapacity(allocator, capacity),
// One u32 stack slot per ~4 prefix bytes is a sizing heuristic.
.subprefixes = try .initCapacity(allocator, @divFloor(capacity, 4)),
};
}
pub fn deinit(self: *PrefixBuilder, allocator: std.mem.Allocator) void { pub fn deinit(self: *PrefixBuilder, allocator: std.mem.Allocator) void {
self.data.deinit(allocator); self.data.deinit(allocator);
self.subprefixes.deinit(allocator); self.subprefixes.deinit(allocator);
} }
/// The accumulated prefix as a string slice.
/// Points into internal storage: invalidated by any later push that grows.
pub fn items(self: PrefixBuilder) []const u8 {
return self.data.items;
}
/// Returns the current prefix extended with `prefix`, leaving the builder's
/// state unchanged, without allocating (`stdx.noalloc`).
/// NOTE(review): assumes enough capacity was pre-reserved (e.g. via
/// `initCapacity`); otherwise the push fails and hits `unreachable` — confirm.
/// The returned slice points into internal storage and is only valid until
/// the next push (`pop` shrinks lengths but retains capacity).
pub fn concat(self: *PrefixBuilder, prefix: []const u8) []const u8 {
self.push(stdx.noalloc, prefix) catch unreachable;
const res = self.items();
self.pop();
return res;
}
pub fn push(self: *PrefixBuilder, allocator: std.mem.Allocator, prefix: []const u8) !void { pub fn push(self: *PrefixBuilder, allocator: std.mem.Allocator, prefix: []const u8) !void {
const old_len: u32 = @intCast(self.data.items.len); const old_len: u32 = @intCast(self.data.items.len);
try self.subprefixes.append(allocator, old_len); try self.subprefixes.append(allocator, old_len);
@ -301,13 +436,20 @@ const PrefixBuilder = struct {
const last_prefix_len = self.subprefixes.pop() orelse unreachable; const last_prefix_len = self.subprefixes.pop() orelse unreachable;
self.data.shrinkRetainingCapacity(last_prefix_len); self.data.shrinkRetainingCapacity(last_prefix_len);
} }
/// Snapshot of the current state: { prefix string length, stack depth }.
pub fn checkpoint(self: PrefixBuilder) [2]usize {
return .{ self.data.items.len, self.subprefixes.items.len };
}
/// Rolls back to a previous `checkpoint`, undoing all pushes since then.
/// Only shrinks lengths; capacity (and thus memory) is retained.
pub fn restore(self: *PrefixBuilder, ckpt: [2]usize) void {
self.data.items.len, self.subprefixes.items.len = ckpt;
}
}; };
fn _populateStruct( fn _populateStruct(
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
prefix_builder: *PrefixBuilder, prefix_builder: *PrefixBuilder,
unique_id: u64, store: BufferStore,
buffer_store: BufferStore,
obj: anytype, obj: anytype,
required: bool, required: bool,
) !bool { ) !bool {
@ -322,17 +464,12 @@ fn _populateStruct(
const prefix = prefix_builder.data.items; const prefix = prefix_builder.data.items;
if (T == zml.Tensor) { if (T == zml.Tensor) {
return if (buffer_store.buffers.getIndex(prefix)) |entry_idx| { return if (store.getTensorOrNull(prefix)) |tensor| {
const buffer = buffer_store.get(prefix).?; obj.* = tensor;
obj.* = zml.Tensor{
._shape = buffer.shape(),
._id = .{ .buffer_id = unique_id + entry_idx },
._donation = .input_buffer,
};
return true; return true;
} else { } else {
if (required) { if (required) {
log.err("Tensor not found: {s} ({d})", .{ prefix, buffer_store.buffers.count() }); log.err("Tensor not found: {s} ({d})", .{ prefix, store.buffers.count() });
} }
return false; return false;
}; };
@ -343,14 +480,14 @@ fn _populateStruct(
if (ptr_info.size == .slice) { if (ptr_info.size == .slice) {
obj.* = &.{}; obj.* = &.{};
const len = buffer_store.countLayers(prefix); const len = store.countLayers(prefix);
if (len > 0) { if (len > 0) {
obj.* = try allocator.alloc(ptr_info.child, len); obj.* = try allocator.alloc(ptr_info.child, len);
for (obj.*, 0..) |*value, i| { for (obj.*, 0..) |*value, i| {
try prefix_builder.pushDigit(allocator, i); try prefix_builder.pushDigit(allocator, i);
defer prefix_builder.pop(); defer prefix_builder.pop();
const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required); const found = try _populateStruct(allocator, prefix_builder, store, value, required);
if (!found) { if (!found) {
log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(ptr_info.child) }); log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(ptr_info.child) });
return false; return false;
@ -369,7 +506,7 @@ fn _populateStruct(
for (obj, 0..) |*value, i| { for (obj, 0..) |*value, i| {
try prefix_builder.pushDigit(allocator, i); try prefix_builder.pushDigit(allocator, i);
defer prefix_builder.pop(); defer prefix_builder.pop();
const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required); const found = try _populateStruct(allocator, prefix_builder, store, value, required);
if (!found) { if (!found) {
log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(arr_info.child) }); log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(arr_info.child) });
return false; return false;
@ -386,15 +523,20 @@ fn _populateStruct(
var has_default = false; var has_default = false;
if (field.default_value_ptr) |_| has_default = true; if (field.default_value_ptr) |_| has_default = true;
const field_found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &@field(obj, field.name), required and !has_default);
partial_struct = partial_struct or field_found; if (zml.meta.Contains(field.type, zml.Tensor)) {
if (!field_found) { const field_found = try _populateStruct(allocator, prefix_builder, store, &@field(obj, field.name), required and !has_default);
if (field.default_value_ptr) |v| { partial_struct = partial_struct or field_found;
@field(obj, field.name) = @as(*const field.type, @ptrCast(@alignCast(v))).*; if (!field_found) {
} else { if (field.default_value_ptr) |v| {
if (partial_struct) { @field(obj, field.name) = @as(*const field.type, @ptrCast(@alignCast(v))).*;
log.warn("Incomplete metadata '{0s}': {1s}. Missing field: '{2s}'. '{0s}' will be ignored.", .{ prefix, @typeName(T), field.name }); } else {
obj.* = undefined; if (partial_struct) {
log.warn("Incomplete struct '{0s}': {1s}. Missing field: '{2s}'. '{0s}' will be ignored.", .{ prefix, @typeName(T), field.name });
obj.* = undefined;
return false;
}
return false; return false;
} }
@ -406,7 +548,7 @@ fn _populateStruct(
}, },
.optional => |opt_info| { .optional => |opt_info| {
obj.* = @as(opt_info.child, undefined); obj.* = @as(opt_info.child, undefined);
const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &(obj.*.?), false); const found = try _populateStruct(allocator, prefix_builder, store, &(obj.*.?), false);
if (!found) obj.* = null; if (!found) obj.* = null;
return true; return true;
}, },
@ -449,7 +591,7 @@ test populateModel {
var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator); var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator);
defer arena_state.deinit(); defer arena_state.deinit();
var store: BufferStore = .{ .arena = arena_state }; var store: BufferStore = .init(arena_state.allocator());
try store.buffers.ensureUnusedCapacity(arena_state.allocator(), 16); try store.buffers.ensureUnusedCapacity(arena_state.allocator(), 16);
store.buffers.putAssumeCapacity("a", Model._newHostBuffer(10)); store.buffers.putAssumeCapacity("a", Model._newHostBuffer(10));
store.buffers.putAssumeCapacity("b.a", Model._newHostBuffer(20)); store.buffers.putAssumeCapacity("b.a", Model._newHostBuffer(20));
@ -494,11 +636,11 @@ test populateModel {
pub fn loadBuffers( pub fn loadBuffers(
comptime Model: type, comptime Model: type,
init_args: if (@hasDecl(Model, "init")) stdx.meta.Tail(stdx.meta.FnArgs(Model.init)) else void, init_args: if (@hasDecl(Model, "init")) stdx.meta.Tail(stdx.meta.FnArgs(Model.init)) else void,
buffer_store: BufferStore, store: BufferStore,
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
platform: zml.Platform, platform: zml.Platform,
) !zml.Bufferized(Model) { ) !zml.Bufferized(Model) {
return loadBuffersWithPrefix(Model, init_args, buffer_store, allocator, platform, ""); return loadBuffersWithPrefix(Model, init_args, store, allocator, platform, "");
} }
/// Creates a bufferized version of a Model from the given BufferStore with a specified prefix. /// Creates a bufferized version of a Model from the given BufferStore with a specified prefix.
@ -512,7 +654,7 @@ pub fn loadBuffers(
pub fn loadBuffersWithPrefix( pub fn loadBuffersWithPrefix(
comptime Model: type, comptime Model: type,
init_args: if (@hasDecl(Model, "init")) stdx.meta.Tail(stdx.meta.FnArgs(Model.init)) else void, init_args: if (@hasDecl(Model, "init")) stdx.meta.Tail(stdx.meta.FnArgs(Model.init)) else void,
buffer_store: BufferStore, store: BufferStore,
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
platform: zml.Platform, platform: zml.Platform,
prefix: []const u8, prefix: []const u8,
@ -522,14 +664,14 @@ pub fn loadBuffersWithPrefix(
const arena = arena_state.allocator(); const arena = arena_state.allocator();
// Get model structure with tensor shapes from the buffer store with prefix // Get model structure with tensor shapes from the buffer store with prefix
var model: Model = try zml.aio.populateModelWithPrefix(Model, arena, buffer_store, prefix); var model: Model = try zml.aio.populateModelWithPrefix(Model, arena, store, prefix);
// If the Model has a "init" function, call it with the given parameters. // If the Model has a "init" function, call it with the given parameters.
if (@hasDecl(Model, "init")) { if (@hasDecl(Model, "init")) {
@call(.auto, Model.init, .{&model} ++ init_args); @call(.auto, Model.init, .{&model} ++ init_args);
} }
return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, prefix); return loadModelBuffersWithPrefix(Model, model, store, allocator, platform, prefix);
} }
/// Creates a bufferized version of a Model from the given BufferStore. For details about /// Creates a bufferized version of a Model from the given BufferStore. For details about
@ -541,11 +683,11 @@ pub fn loadBuffersWithPrefix(
pub fn loadModelBuffers( pub fn loadModelBuffers(
comptime Model: type, comptime Model: type,
model: Model, model: Model,
buffer_store: BufferStore, store: BufferStore,
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
platform: zml.Platform, platform: zml.Platform,
) !zml.Bufferized(Model) { ) !zml.Bufferized(Model) {
return try loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, ""); return try loadModelBuffersWithPrefix(Model, model, store, allocator, platform, "");
} }
/// Creates a bufferized version of a Model from the given BufferStore and the given prefix. /// Creates a bufferized version of a Model from the given BufferStore and the given prefix.
@ -557,7 +699,7 @@ pub fn loadModelBuffers(
pub fn loadModelBuffersWithPrefix( pub fn loadModelBuffersWithPrefix(
comptime Model: type, comptime Model: type,
model: Model, model: Model,
buffer_store: BufferStore, store: BufferStore,
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
platform: zml.Platform, platform: zml.Platform,
prefix: []const u8, prefix: []const u8,
@ -576,7 +718,7 @@ pub fn loadModelBuffersWithPrefix(
try prefix_builder.push(allocator, prefix); try prefix_builder.push(allocator, prefix);
defer prefix_builder.deinit(allocator); defer prefix_builder.deinit(allocator);
try visitStructAndLoadBuffer(allocator, &prefix_builder, buffer_store, &res, platform); try visitStructAndLoadBuffer(allocator, &prefix_builder, store, &res, platform);
return res; return res;
} }
@ -590,58 +732,6 @@ pub fn unloadBuffers(model: anytype) void {
}).cb, {}, model); }).cb, {}, model);
} }
/// Assists in debuggigng `BufferNotFound` error
/// This is useful when a buffer key is not found and you want to identify possible alternatives (or typos)
fn findSimilarBufferKeys(original_key: []const u8, store: BufferStore, temp_allocator: std.mem.Allocator) void {
const suffixes = [_][]const u8{ "", ".weight", ".bias" };
var shown_keys = std.StringHashMap(void).init(temp_allocator);
defer shown_keys.deinit();
// remove suffix .weight and .bias
var base_key = original_key;
for (suffixes) |suffix| {
if (std.mem.endsWith(u8, original_key, suffix)) {
base_key = original_key[0 .. original_key.len - suffix.len];
break;
}
}
// first test: look for exact matches
var matches: usize = 0;
var it = store.buffers.iterator();
while (it.next()) |entry| {
const key = entry.key_ptr.*;
if (std.mem.startsWith(u8, key, base_key)) {
if (matches == 0) log.warn("Similar buffers found:", .{});
if (!shown_keys.contains(key)) {
log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() });
shown_keys.put(key, {}) catch continue;
matches += 1;
}
}
}
// second test: progressive partial matches
if (matches == 0) {
var components = std.mem.splitScalar(u8, base_key, '.');
while (components.next()) |component| {
matches = 0;
it = store.buffers.iterator();
while (it.next()) |entry| {
const key = entry.key_ptr.*;
if (std.mem.indexOf(u8, key, component) != null and !shown_keys.contains(key)) {
if (matches == 0) log.warn("Partial matches for '{s}':", .{component});
log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() });
shown_keys.put(key, {}) catch continue;
matches += 1;
if (matches >= 5) break;
}
}
if (matches > 0) break;
}
}
}
/// deinit all buffers in the given struct /// deinit all buffers in the given struct
pub fn awaitAll(buffers: anytype) !void { pub fn awaitAll(buffers: anytype) !void {
zml.meta.visit((struct { zml.meta.visit((struct {
@ -651,7 +741,7 @@ pub fn awaitAll(buffers: anytype) !void {
}).cb, {}, buffers); }).cb, {}, buffers);
} }
fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, buffer_store: BufferStore, obj: anytype, platform: zml.Platform) !void { fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, store: BufferStore, obj: anytype, platform: zml.Platform) !void {
const err_msg = "visitStructAndLoadBuffer must be called with a pointer to type. Received "; const err_msg = "visitStructAndLoadBuffer must be called with a pointer to type. Received ";
const type_info, const T = switch (@typeInfo(@TypeOf(obj))) { const type_info, const T = switch (@typeInfo(@TypeOf(obj))) {
.pointer => |ptr_info| switch (ptr_info.size) { .pointer => |ptr_info| switch (ptr_info.size) {
@ -663,7 +753,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
const prefix = prefix_builder.data.items; const prefix = prefix_builder.data.items;
if (T == zml.Buffer) { if (T == zml.Buffer) {
return if (buffer_store.get(prefix)) |host_buffer| { return if (store.get(prefix)) |host_buffer| {
// obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us. // obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us.
var buf_with_metadata = host_buffer; var buf_with_metadata = host_buffer;
log.debug("Loading buffer {s} ({f})", .{ prefix, obj._shape }); log.debug("Loading buffer {s} ({f})", .{ prefix, obj._shape });
@ -673,7 +763,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
} else { } else {
log.err("Buffer not found: {s}", .{prefix}); log.err("Buffer not found: {s}", .{prefix});
findSimilarBufferKeys(prefix, buffer_store, allocator); store.findSimilarBufferKeys(allocator, prefix);
return error.BufferNotFound; return error.BufferNotFound;
}; };
@ -686,7 +776,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
try prefix_builder.pushDigit(allocator, i); try prefix_builder.pushDigit(allocator, i);
defer prefix_builder.pop(); defer prefix_builder.pop();
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, value, platform); try visitStructAndLoadBuffer(allocator, prefix_builder, store, value, platform);
} }
} else stdx.debug.compileError("type not supported by visitStructAndLoadBuffer: {}", .{T}); } else stdx.debug.compileError("type not supported by visitStructAndLoadBuffer: {}", .{T});
}, },
@ -694,7 +784,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
for (obj, 0..) |*value, i| { for (obj, 0..) |*value, i| {
try prefix_builder.pushDigit(allocator, i); try prefix_builder.pushDigit(allocator, i);
defer prefix_builder.pop(); defer prefix_builder.pop();
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, value, platform); try visitStructAndLoadBuffer(allocator, prefix_builder, store, value, platform);
} }
}, },
@ -704,12 +794,12 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
try prefix_builder.push(allocator, field.name); try prefix_builder.push(allocator, field.name);
defer prefix_builder.pop(); defer prefix_builder.pop();
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &@field(obj, field.name), platform); try visitStructAndLoadBuffer(allocator, prefix_builder, store, &@field(obj, field.name), platform);
} }
}, },
.optional => { .optional => {
if (obj.*) |*obj_val| { if (obj.*) |*obj_val| {
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, obj_val, platform); try visitStructAndLoadBuffer(allocator, prefix_builder, store, obj_val, platform);
} }
}, },
else => {}, else => {},

View File

@ -13,9 +13,7 @@ const StringBuilder = std.ArrayListUnmanaged(u8);
const log = std.log.scoped(.@"zml/io"); const log = std.log.scoped(.@"zml/io");
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore { pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
var res: zml.aio.BufferStore = .{ var res: zml.aio.BufferStore = .init(allocator);
.arena = std.heap.ArenaAllocator.init(allocator),
};
errdefer res.arena.deinit(); errdefer res.arena.deinit();
const arena = res.arena.allocator(); const arena = res.arena.allocator();

View File

@ -25,9 +25,7 @@ const TinyLlamaConfig = extern struct {
/// For convenience we use the same layer names /// For convenience we use the same layer names
/// than the one used by the Llama-3.1 models. /// than the one used by the Llama-3.1 models.
pub fn open(allocator: std.mem.Allocator, model_path: []const u8) !zml.aio.BufferStore { pub fn open(allocator: std.mem.Allocator, model_path: []const u8) !zml.aio.BufferStore {
var res: zml.aio.BufferStore = .{ var res: zml.aio.BufferStore = .init(allocator);
.arena = std.heap.ArenaAllocator.init(allocator),
};
errdefer res.arena.deinit(); errdefer res.arena.deinit();
const arena = res.arena.allocator(); const arena = res.arena.allocator();

View File

@ -36,7 +36,7 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
const py_values = try eval.evaluate(tmp_alloc, ops, true); const py_values = try eval.evaluate(tmp_alloc, ops, true);
// file ownership is transferred to the BufferStore // file ownership is transferred to the BufferStore
var res = try zml.aio.BufferStore.init(allocator, &.{torch_file.mmap_file}); var res = try zml.aio.BufferStore.initWithFiles(allocator, &.{mmap_file});
try torch_file.parseModel(py_values, &res); try torch_file.parseModel(py_values, &res);
return res; return res;
} }

View File

@ -480,7 +480,7 @@ test "Read pickle (zipped)" {
// torch.save({ "model": model, "tensor": tensor}, "simple.pt") // torch.save({ "model": model, "tensor": tensor}, "simple.pt")
const file = try asynk.File.open("zml/aio/torch/simple.pt", .{ .mode = .read_only }); const file = try asynk.File.open("zml/aio/torch/simple.pt", .{ .mode = .read_only });
const mmap_file = try zml.aio.MemoryMappedFile.init(file); const mmap_file = try zml.aio.MemoryMappedFile.init(file);
var store = try zml.aio.BufferStore.init(testing.allocator, &.{mmap_file}); var store = try zml.aio.BufferStore.initWithFiles(testing.allocator, &.{mmap_file});
defer store.deinit(); defer store.deinit();
{ {

View File

@ -77,7 +77,7 @@ pub fn compileModel(
) !FnExe(func) { ) !FnExe(func) {
const ModelT = ModuleSignature(func).ModelT; const ModelT = ModuleSignature(func).ModelT;
const name = @typeName(ModelT) ++ ".forward"; const name = @typeName(ModelT) ++ ".forward";
log.info("Compiling {s} with {}", .{ name, args_shapes }); log.info("Compiling {s} with {f}", .{ name, stdx.fmt.any(args_shapes) });
var context = try CompilationContext.init(allocator, name, platform); var context = try CompilationContext.init(allocator, name, platform);
defer context.deinit(); defer context.deinit();
@ -209,9 +209,9 @@ pub const BaseExe = struct {
var execute_context: ?*pjrt.ExecuteContext = null; var execute_context: ?*pjrt.ExecuteContext = null;
if (platform.pjrt_api.ffi()) |ffi| { if (platform.pjrt_api.ffi()) |ffi| {
log.info("Created context execution {*} for {*}", .{ execute_context, exe });
execute_context = try platform.pjrt_api.createExecuteContext(); execute_context = try platform.pjrt_api.createExecuteContext();
try callback.bindInternalCallbacks(allocator, platform, ffi, execute_context.?); try callback.bindInternalCallbacks(allocator, platform, ffi, execute_context.?);
// log.info("Created context execution {*} for {*}", .{ execute_context, exe });
} }
return .{ return .{

View File

@ -141,6 +141,7 @@ test MapType {
/// Any `To` struct inside `from` will be copied over to the target. /// Any `To` struct inside `from` will be copied over to the target.
pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam(cb, 0), from: anytype, to: anytype) !void { pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam(cb, 0), from: anytype, to: anytype) !void {
// TODO: handle tuple to slice conversion // TODO: handle tuple to slice conversion
// TODO: handle error bubbling up
const From = FnParam(cb, 1); const From = FnParam(cb, 1);
const To = stdx.meta.FnResult(cb); const To = stdx.meta.FnResult(cb);
const FromStruct = @TypeOf(from); const FromStruct = @TypeOf(from);

View File

@ -1129,7 +1129,7 @@ pub fn hashShape(hasher: *std.hash.Wyhash, shape: Shape) void {
hash(hasher, shape._dtype, .Shallow); hash(hasher, shape._dtype, .Shallow);
hash(hasher, shape._sharding_info, .Shallow); hash(hasher, shape._sharding_info, .Shallow);
for (shape.tags()) |tag| { for (shape.tags()) |tag| {
hash(hasher, @intFromPtr(tag), .Shallow); hashArray(hasher, std.mem.span(tag), .Shallow);
} }
} }

View File

@ -82,9 +82,7 @@ pub const Platform = struct {
}; };
const _CreateOptions = struct { const _CreateOptions = struct {
// XLA CPU client doesn't read options cpu: Cpu = .{ .device_count = 4 },
// https://github.com/openxla/xla/blob/42496a28c374bd35f493cc5dbde74805407245dc/xla/pjrt/c/pjrt_c_api_cpu_internal.cc#L33-L46
cpu: struct {} = .{},
// bump memory fraction from XLA defaults of 75% to 90%. // bump memory fraction from XLA defaults of 75% to 90%.
// Even on a 8GB GPU it should leave enough space for the Cuda driver // Even on a 8GB GPU it should leave enough space for the Cuda driver
@ -94,6 +92,14 @@ const _CreateOptions = struct {
tpu: struct {} = .{}, tpu: struct {} = .{},
neuron: struct {} = .{}, neuron: struct {} = .{},
pub const Cpu = struct {
device_count: u32,
fn writeNamedValues(self: Cpu, values: *std.ArrayListUnmanaged(pjrt.NamedValue)) void {
values.appendAssumeCapacity(pjrt.NamedValue.from("cpu_device_count", @as(i64, self.device_count)));
}
};
pub const Cuda = struct { pub const Cuda = struct {
allocator: Allocator = .{ .bfc = .{} }, allocator: Allocator = .{ .bfc = .{} },
// TODO support all of https://github.com/openxla/xla/blob/3d31c48c719d331d432132b3e0c2c5ce52650675/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L76-L86 // TODO support all of https://github.com/openxla/xla/blob/3d31c48c719d331d432132b3e0c2c5ce52650675/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L76-L86
@ -118,7 +124,7 @@ const _CreateOptions = struct {
}; };
}; };
pub fn writeNamedValues(self: Cuda, values: *std.ArrayListUnmanaged(pjrt.NamedValue)) void { fn writeNamedValues(self: Cuda, values: *std.ArrayListUnmanaged(pjrt.NamedValue)) void {
switch (self.allocator) { switch (self.allocator) {
.platform => { .platform => {
values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform")); values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform"));
@ -142,6 +148,7 @@ const _CreateOptions = struct {
var values = std.ArrayListUnmanaged(pjrt.NamedValue).fromOwnedSlice(out); var values = std.ArrayListUnmanaged(pjrt.NamedValue).fromOwnedSlice(out);
values.shrinkRetainingCapacity(0); values.shrinkRetainingCapacity(0);
switch (target) { switch (target) {
.cpu => self.cpu.writeNamedValues(&values),
.cuda => self.cuda.writeNamedValues(&values), .cuda => self.cuda.writeNamedValues(&values),
inline else => |t| { inline else => |t| {
stdx.debug.assertComptime(@hasField(_CreateOptions, @tagName(t)), "zml.platform.CreateOptions doesn't list target {s}", .{@tagName(t)}); stdx.debug.assertComptime(@hasField(_CreateOptions, @tagName(t)), "zml.platform.CreateOptions doesn't list target {s}", .{@tagName(t)});

View File

@ -269,18 +269,6 @@ pub const Tensor = struct {
return res; return res;
} }
var _global_tensor_counter: u64 = 0;
/// Internal use
pub fn _reserveIdRange(len: u32) u64 {
return @atomicRmw(u64, &_global_tensor_counter, .Add, len, .seq_cst);
}
/// Internal use
pub fn setUniqueId(self: *Tensor) void {
self._id = .{ .buffer_id = _reserveIdRange(1) };
}
/// Returns a Tensor containing the absolute value of each element of the input Tensor. /// Returns a Tensor containing the absolute value of each element of the input Tensor.
pub fn abs(self: Tensor) Tensor { pub fn abs(self: Tensor) Tensor {
const loc = self.getContext().mlirCtx().location(@src()); const loc = self.getContext().mlirCtx().location(@src());

View File

@ -298,7 +298,7 @@ test testLayer {
}; };
// create a buffer store containing the activations: // create a buffer store containing the activations:
var activations = try zml.aio.BufferStore.init(std.testing.allocator, &.{}); var activations = zml.aio.BufferStore.init(std.testing.allocator);
defer activations.deinit(); defer activations.deinit();
{ {
const input = zml.HostBuffer.fromArray(&[2]f32{ 1, -1 }); const input = zml.HostBuffer.fromArray(&[2]f32{ 1, -1 });