Implement buffer‑ID based loading by moving tensor ID handling into BufferStore, fix zml.call tag hashing, and expose CPU device count.
This commit is contained in:
parent
6e7617918d
commit
7913c00d70
55
stdx/fmt.zig
55
stdx/fmt.zig
@ -1,5 +1,6 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
|
/// Properly format a slice of numbers.
|
||||||
pub fn slice(any_slice: anytype) FmtSlice(std.meta.Elem(@TypeOf(any_slice))) {
|
pub fn slice(any_slice: anytype) FmtSlice(std.meta.Elem(@TypeOf(any_slice))) {
|
||||||
return .{ .slice = any_slice };
|
return .{ .slice = any_slice };
|
||||||
}
|
}
|
||||||
@ -115,3 +116,57 @@ pub fn formatComplexSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io
|
|||||||
pub fn formatBoolSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
pub fn formatBoolSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||||
return try formatSliceCustom(formatBool, values, spec, writer);
|
return try formatSliceCustom(formatBool, values, spec, writer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Format a struct using `format` method of subfields when possible.
|
||||||
|
pub fn any(any_val: anytype) FmtAny(@TypeOf(any_val)) {
|
||||||
|
return .{ .data = any_val };
|
||||||
|
}
|
||||||
|
|
||||||
|
fn FmtAny(Data: type) type {
|
||||||
|
return struct {
|
||||||
|
data: Data,
|
||||||
|
|
||||||
|
pub inline fn format(self: @This(), writer: *std.io.Writer) std.io.Writer.Error!void {
|
||||||
|
try printValue(writer, .{}, self.data, std.options.fmt_max_depth);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fix up of std.io.Writer.printValue that uses `format` method of subfields when possible
|
||||||
|
fn printValue(
|
||||||
|
w: *std.io.Writer,
|
||||||
|
options: std.fmt.Options,
|
||||||
|
value: anytype,
|
||||||
|
max_depth: usize,
|
||||||
|
) std.io.Writer.Error!void {
|
||||||
|
const T = @TypeOf(value);
|
||||||
|
if (std.meta.hasMethod(T, "format")) {
|
||||||
|
return try value.format(w);
|
||||||
|
}
|
||||||
|
if (std.meta.hasMethod(T, "formatNumber")) {
|
||||||
|
return try value.formatNumber(w, options.toNumber(.decimal, .lower));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (max_depth == 0) {
|
||||||
|
try w.writeAll(".{ ... }");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (@typeInfo(T)) {
|
||||||
|
.@"struct" => |info| {
|
||||||
|
try w.writeAll(".{ ");
|
||||||
|
inline for (info.fields, 0..) |f, i| {
|
||||||
|
if (i > 0) try w.writeAll(", ");
|
||||||
|
|
||||||
|
if (!info.is_tuple) {
|
||||||
|
try w.writeByte('.');
|
||||||
|
try w.writeAll(f.name);
|
||||||
|
try w.writeAll(" = ");
|
||||||
|
}
|
||||||
|
try printValue(w, options, @field(value, f.name), max_depth - 1);
|
||||||
|
}
|
||||||
|
try w.writeAll(" }");
|
||||||
|
},
|
||||||
|
inline else => try w.printValue("any", options, value, max_depth - 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -1,3 +1,6 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
const builtin = @import("builtin");
|
||||||
|
|
||||||
pub const BoundedArray = @import("bounded_array.zig").BoundedArray;
|
pub const BoundedArray = @import("bounded_array.zig").BoundedArray;
|
||||||
pub const BoundedArrayAligned = @import("bounded_array.zig").BoundedArrayAligned;
|
pub const BoundedArrayAligned = @import("bounded_array.zig").BoundedArrayAligned;
|
||||||
pub const debug = @import("debug.zig");
|
pub const debug = @import("debug.zig");
|
||||||
@ -11,7 +14,6 @@ pub const queue = @import("queue.zig");
|
|||||||
pub const time = @import("time.zig");
|
pub const time = @import("time.zig");
|
||||||
|
|
||||||
test {
|
test {
|
||||||
const std = @import("std");
|
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -20,3 +22,5 @@ pub inline fn stackSlice(comptime max_len: usize, T: type, len: usize) []T {
|
|||||||
var storage: [max_len]T = undefined;
|
var storage: [max_len]T = undefined;
|
||||||
return storage[0..len];
|
return storage[0..len];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub const noalloc: std.mem.Allocator = if (builtin.mode == .ReleaseFast) undefined else std.testing.failing_allocator;
|
||||||
|
|||||||
294
zml/aio.zig
294
zml/aio.zig
@ -43,8 +43,8 @@ pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8)
|
|||||||
/// whose shape is read from the "a.b" tensor.
|
/// whose shape is read from the "a.b" tensor.
|
||||||
/// * If `Model` contains a list of layers, then the field:
|
/// * If `Model` contains a list of layers, then the field:
|
||||||
/// `Model.layers[2].a.b` will be populated from the "layers.2.a.b" tensor.
|
/// `Model.layers[2].a.b` will be populated from the "layers.2.a.b" tensor.
|
||||||
pub fn populateModel(comptime Model: type, allocator: std.mem.Allocator, buffer_store: BufferStore) !Model {
|
pub fn populateModel(comptime Model: type, allocator: std.mem.Allocator, store: BufferStore) !Model {
|
||||||
return populateModelWithPrefix(Model, allocator, buffer_store, "");
|
return populateModelWithPrefix(Model, allocator, store, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a Model struct with tensor shapes read from the given TensorStore,
|
/// Creates a Model struct with tensor shapes read from the given TensorStore,
|
||||||
@ -62,8 +62,7 @@ pub fn populateModelWithPrefix(comptime Model: type, allocator: std.mem.Allocato
|
|||||||
try prefix_builder.push(allocator, prefix);
|
try prefix_builder.push(allocator, prefix);
|
||||||
defer prefix_builder.deinit(allocator);
|
defer prefix_builder.deinit(allocator);
|
||||||
|
|
||||||
const unique_id = zml.Tensor._reserveIdRange(@intCast(store.buffers.count()));
|
const ok = _populateStruct(allocator, &prefix_builder, store, &model, true) catch |err| {
|
||||||
const ok = _populateStruct(allocator, &prefix_builder, unique_id, store, &model, true) catch |err| {
|
|
||||||
std.debug.panic("Can't populate model of type {s}: {s}", .{ @typeName(type), @errorName(err) });
|
std.debug.panic("Can't populate model of type {s}: {s}", .{ @typeName(type), @errorName(err) });
|
||||||
};
|
};
|
||||||
if (!ok) return error.TensorNotFound;
|
if (!ok) return error.TensorNotFound;
|
||||||
@ -74,17 +73,28 @@ pub fn populateModelWithPrefix(comptime Model: type, allocator: std.mem.Allocato
|
|||||||
pub const BufferStore = struct {
|
pub const BufferStore = struct {
|
||||||
pub const Buffers = std.StringArrayHashMapUnmanaged(HostBuffer);
|
pub const Buffers = std.StringArrayHashMapUnmanaged(HostBuffer);
|
||||||
pub const Metadatas = std.StringArrayHashMapUnmanaged(Metadata);
|
pub const Metadatas = std.StringArrayHashMapUnmanaged(Metadata);
|
||||||
|
var _unique_store_id: std.atomic.Value(u64) = .init(0);
|
||||||
|
const _store_id_range: u64 = 1024 * 1024 * 1024;
|
||||||
|
|
||||||
arena: std.heap.ArenaAllocator,
|
arena: std.heap.ArenaAllocator,
|
||||||
files: []MemoryMappedFile = &.{},
|
files: []MemoryMappedFile = &.{},
|
||||||
buffers: Buffers = .{},
|
buffers: Buffers = .{},
|
||||||
_metadata: Metadatas = .{},
|
_metadata: Metadatas = .{},
|
||||||
|
_unique_id: u64,
|
||||||
|
|
||||||
/// Create an empty BufferStore. Takes ownership of the given files.
|
/// Create an empty BufferStore.
|
||||||
pub fn init(allocator: std.mem.Allocator, files: []const MemoryMappedFile) error{OutOfMemory}!BufferStore {
|
/// Takes ownership of the given files.
|
||||||
var self: zml.aio.BufferStore = .{
|
pub fn init(allocator: std.mem.Allocator) BufferStore {
|
||||||
|
return .{
|
||||||
.arena = std.heap.ArenaAllocator.init(allocator),
|
.arena = std.heap.ArenaAllocator.init(allocator),
|
||||||
|
._unique_id = _unique_store_id.fetchAdd(_store_id_range, .monotonic),
|
||||||
};
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an empty BufferStore.
|
||||||
|
/// Takes ownership of the given files.
|
||||||
|
pub fn initWithFiles(allocator: std.mem.Allocator, files: []const MemoryMappedFile) error{OutOfMemory}!BufferStore {
|
||||||
|
var self: BufferStore = .init(allocator);
|
||||||
self.files = try self.arena.allocator().dupe(MemoryMappedFile, files);
|
self.files = try self.arena.allocator().dupe(MemoryMappedFile, files);
|
||||||
return self;
|
return self;
|
||||||
}
|
}
|
||||||
@ -100,6 +110,61 @@ pub const BufferStore = struct {
|
|||||||
return self.buffers.get(key);
|
return self.buffers.get(key);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn loadBufferById(self: BufferStore, x: zml.Tensor, platform: zml.Platform) !zml.Buffer {
|
||||||
|
var host_buffer: zml.HostBuffer = switch (x._id) {
|
||||||
|
.buffer_id => |id| hb: {
|
||||||
|
if (id < self._unique_id or self._unique_id + _store_id_range <= id) {
|
||||||
|
@panic("`store.loadBufferById()` only works on Tensor created by `store.getTensor()`, using the same store object.");
|
||||||
|
}
|
||||||
|
break :hb self.buffers.values()[id - self._unique_id];
|
||||||
|
},
|
||||||
|
else => @panic("`store.loadBufferById()` only works on Tensor created by `store.getTensor()`"),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Use the sharding information stored in the tensor.
|
||||||
|
host_buffer._shape = x.shape();
|
||||||
|
return try host_buffer.toDevice(platform);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a bufferized version of a model from the given BufferStore.
|
||||||
|
///
|
||||||
|
/// This will represent the weights of the model, loaded on a specific platform.
|
||||||
|
/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a
|
||||||
|
/// `module.ExeWithWeights` ready to be called.
|
||||||
|
pub fn loadModelById(self: BufferStore, Model: type, allocator: std.mem.Allocator, model: Model, platform: zml.Platform) !zml.Bufferized(Model) {
|
||||||
|
const Ctx = struct {
|
||||||
|
platform: *const zml.Platform,
|
||||||
|
store: *const BufferStore,
|
||||||
|
|
||||||
|
pub fn cb(ctx: @This(), x: zml.Tensor) zml.Buffer {
|
||||||
|
return ctx.store.loadBufferById(x, ctx.platform.*) catch @panic("Failed to load buffer to device");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
var res: zml.Bufferized(Model) = undefined;
|
||||||
|
try zml.meta.mapAlloc(Ctx.cb, allocator, .{ .platform = &platform, .store = &self }, model, &res);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn getTensor(self: BufferStore, key: []const u8) zml.Tensor {
|
||||||
|
return self.getTensorOrNull(key) orelse {
|
||||||
|
log.err("Tensor not found: {s}", .{key});
|
||||||
|
self.findSimilarBufferKeys(std.heap.smp_allocator, key);
|
||||||
|
@panic("Tensor not found");
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn getTensorOrNull(self: BufferStore, key: []const u8) ?zml.Tensor {
|
||||||
|
return if (self.buffers.getIndex(key)) |entry_idx|
|
||||||
|
.{
|
||||||
|
._shape = self.buffers.values()[entry_idx].shape(),
|
||||||
|
._id = .{ .buffer_id = self._unique_id + entry_idx },
|
||||||
|
._donation = .input_buffer,
|
||||||
|
}
|
||||||
|
else
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
/// Count layers starting with the given prefix.
|
/// Count layers starting with the given prefix.
|
||||||
pub fn countLayers(self: BufferStore, prefix: []const u8) usize {
|
pub fn countLayers(self: BufferStore, prefix: []const u8) usize {
|
||||||
// Note: This is kinda inefficient
|
// Note: This is kinda inefficient
|
||||||
@ -140,6 +205,58 @@ pub const BufferStore = struct {
|
|||||||
|
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Assists in debugging `BufferNotFound` error
|
||||||
|
/// This is useful when a buffer key is not found and you want to identify possible alternatives (or typos)
|
||||||
|
pub fn findSimilarBufferKeys(store: BufferStore, tmp_alloc: std.mem.Allocator, original_key: []const u8) void {
|
||||||
|
const suffixes = [_][]const u8{ "", ".weight", ".bias" };
|
||||||
|
var shown_keys = std.StringHashMap(void).init(tmp_alloc);
|
||||||
|
defer shown_keys.deinit();
|
||||||
|
|
||||||
|
// remove suffix .weight and .bias
|
||||||
|
var base_key = original_key;
|
||||||
|
for (suffixes) |suffix| {
|
||||||
|
if (std.mem.endsWith(u8, original_key, suffix)) {
|
||||||
|
base_key = original_key[0 .. original_key.len - suffix.len];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// first test: look for exact matches
|
||||||
|
var matches: usize = 0;
|
||||||
|
var it = store.buffers.iterator();
|
||||||
|
while (it.next()) |entry| {
|
||||||
|
const key = entry.key_ptr.*;
|
||||||
|
if (std.mem.startsWith(u8, key, base_key)) {
|
||||||
|
if (matches == 0) log.warn("Similar buffers found:", .{});
|
||||||
|
if (!shown_keys.contains(key)) {
|
||||||
|
log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() });
|
||||||
|
shown_keys.put(key, {}) catch continue;
|
||||||
|
matches += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// second test: progressive partial matches
|
||||||
|
if (matches == 0) {
|
||||||
|
var components = std.mem.splitScalar(u8, base_key, '.');
|
||||||
|
while (components.next()) |component| {
|
||||||
|
matches = 0;
|
||||||
|
it = store.buffers.iterator();
|
||||||
|
while (it.next()) |entry| {
|
||||||
|
const key = entry.key_ptr.*;
|
||||||
|
if (std.mem.indexOf(u8, key, component) != null and !shown_keys.contains(key)) {
|
||||||
|
if (matches == 0) log.warn("Partial matches for '{s}':", .{component});
|
||||||
|
log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() });
|
||||||
|
shown_keys.put(key, {}) catch continue;
|
||||||
|
matches += 1;
|
||||||
|
if (matches >= 5) break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (matches > 0) break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Metadata = union(enum) {
|
pub const Metadata = union(enum) {
|
||||||
@ -260,17 +377,35 @@ pub const MemoryMappedFile = struct {
|
|||||||
/// Helper handling prefix building.
|
/// Helper handling prefix building.
|
||||||
///
|
///
|
||||||
/// This allows to easily push/pop prefixes and handles the generation of the string with the correct format.
|
/// This allows to easily push/pop prefixes and handles the generation of the string with the correct format.
|
||||||
const PrefixBuilder = struct {
|
pub const PrefixBuilder = struct {
|
||||||
/// Stores the computed prefix.
|
/// Stores the computed prefix.
|
||||||
data: std.ArrayListUnmanaged(u8) = .{},
|
data: std.ArrayList(u8) = .{},
|
||||||
/// Stack storing the size of the intermediary prefix.
|
/// Stack storing the size of the intermediary prefix.
|
||||||
subprefixes: std.ArrayListUnmanaged(u32) = .{},
|
subprefixes: std.ArrayList(u32) = .{},
|
||||||
|
|
||||||
|
pub fn initCapacity(allocator: std.mem.Allocator, capacity: usize) !PrefixBuilder {
|
||||||
|
return .{
|
||||||
|
.data = try .initCapacity(allocator, capacity),
|
||||||
|
.subprefixes = try .initCapacity(allocator, @divFloor(capacity, 4)),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
pub fn deinit(self: *PrefixBuilder, allocator: std.mem.Allocator) void {
|
pub fn deinit(self: *PrefixBuilder, allocator: std.mem.Allocator) void {
|
||||||
self.data.deinit(allocator);
|
self.data.deinit(allocator);
|
||||||
self.subprefixes.deinit(allocator);
|
self.subprefixes.deinit(allocator);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn items(self: PrefixBuilder) []const u8 {
|
||||||
|
return self.data.items;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn concat(self: *PrefixBuilder, prefix: []const u8) []const u8 {
|
||||||
|
self.push(stdx.noalloc, prefix) catch unreachable;
|
||||||
|
const res = self.items();
|
||||||
|
self.pop();
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
pub fn push(self: *PrefixBuilder, allocator: std.mem.Allocator, prefix: []const u8) !void {
|
pub fn push(self: *PrefixBuilder, allocator: std.mem.Allocator, prefix: []const u8) !void {
|
||||||
const old_len: u32 = @intCast(self.data.items.len);
|
const old_len: u32 = @intCast(self.data.items.len);
|
||||||
try self.subprefixes.append(allocator, old_len);
|
try self.subprefixes.append(allocator, old_len);
|
||||||
@ -301,13 +436,20 @@ const PrefixBuilder = struct {
|
|||||||
const last_prefix_len = self.subprefixes.pop() orelse unreachable;
|
const last_prefix_len = self.subprefixes.pop() orelse unreachable;
|
||||||
self.data.shrinkRetainingCapacity(last_prefix_len);
|
self.data.shrinkRetainingCapacity(last_prefix_len);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn checkpoint(self: PrefixBuilder) [2]usize {
|
||||||
|
return .{ self.data.items.len, self.subprefixes.items.len };
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn restore(self: *PrefixBuilder, ckpt: [2]usize) void {
|
||||||
|
self.data.items.len, self.subprefixes.items.len = ckpt;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
fn _populateStruct(
|
fn _populateStruct(
|
||||||
allocator: std.mem.Allocator,
|
allocator: std.mem.Allocator,
|
||||||
prefix_builder: *PrefixBuilder,
|
prefix_builder: *PrefixBuilder,
|
||||||
unique_id: u64,
|
store: BufferStore,
|
||||||
buffer_store: BufferStore,
|
|
||||||
obj: anytype,
|
obj: anytype,
|
||||||
required: bool,
|
required: bool,
|
||||||
) !bool {
|
) !bool {
|
||||||
@ -322,17 +464,12 @@ fn _populateStruct(
|
|||||||
|
|
||||||
const prefix = prefix_builder.data.items;
|
const prefix = prefix_builder.data.items;
|
||||||
if (T == zml.Tensor) {
|
if (T == zml.Tensor) {
|
||||||
return if (buffer_store.buffers.getIndex(prefix)) |entry_idx| {
|
return if (store.getTensorOrNull(prefix)) |tensor| {
|
||||||
const buffer = buffer_store.get(prefix).?;
|
obj.* = tensor;
|
||||||
obj.* = zml.Tensor{
|
|
||||||
._shape = buffer.shape(),
|
|
||||||
._id = .{ .buffer_id = unique_id + entry_idx },
|
|
||||||
._donation = .input_buffer,
|
|
||||||
};
|
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
if (required) {
|
if (required) {
|
||||||
log.err("Tensor not found: {s} ({d})", .{ prefix, buffer_store.buffers.count() });
|
log.err("Tensor not found: {s} ({d})", .{ prefix, store.buffers.count() });
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
@ -343,14 +480,14 @@ fn _populateStruct(
|
|||||||
if (ptr_info.size == .slice) {
|
if (ptr_info.size == .slice) {
|
||||||
obj.* = &.{};
|
obj.* = &.{};
|
||||||
|
|
||||||
const len = buffer_store.countLayers(prefix);
|
const len = store.countLayers(prefix);
|
||||||
if (len > 0) {
|
if (len > 0) {
|
||||||
obj.* = try allocator.alloc(ptr_info.child, len);
|
obj.* = try allocator.alloc(ptr_info.child, len);
|
||||||
|
|
||||||
for (obj.*, 0..) |*value, i| {
|
for (obj.*, 0..) |*value, i| {
|
||||||
try prefix_builder.pushDigit(allocator, i);
|
try prefix_builder.pushDigit(allocator, i);
|
||||||
defer prefix_builder.pop();
|
defer prefix_builder.pop();
|
||||||
const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required);
|
const found = try _populateStruct(allocator, prefix_builder, store, value, required);
|
||||||
if (!found) {
|
if (!found) {
|
||||||
log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(ptr_info.child) });
|
log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(ptr_info.child) });
|
||||||
return false;
|
return false;
|
||||||
@ -369,7 +506,7 @@ fn _populateStruct(
|
|||||||
for (obj, 0..) |*value, i| {
|
for (obj, 0..) |*value, i| {
|
||||||
try prefix_builder.pushDigit(allocator, i);
|
try prefix_builder.pushDigit(allocator, i);
|
||||||
defer prefix_builder.pop();
|
defer prefix_builder.pop();
|
||||||
const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required);
|
const found = try _populateStruct(allocator, prefix_builder, store, value, required);
|
||||||
if (!found) {
|
if (!found) {
|
||||||
log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(arr_info.child) });
|
log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(arr_info.child) });
|
||||||
return false;
|
return false;
|
||||||
@ -386,15 +523,20 @@ fn _populateStruct(
|
|||||||
|
|
||||||
var has_default = false;
|
var has_default = false;
|
||||||
if (field.default_value_ptr) |_| has_default = true;
|
if (field.default_value_ptr) |_| has_default = true;
|
||||||
const field_found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &@field(obj, field.name), required and !has_default);
|
|
||||||
partial_struct = partial_struct or field_found;
|
if (zml.meta.Contains(field.type, zml.Tensor)) {
|
||||||
if (!field_found) {
|
const field_found = try _populateStruct(allocator, prefix_builder, store, &@field(obj, field.name), required and !has_default);
|
||||||
if (field.default_value_ptr) |v| {
|
partial_struct = partial_struct or field_found;
|
||||||
@field(obj, field.name) = @as(*const field.type, @ptrCast(@alignCast(v))).*;
|
if (!field_found) {
|
||||||
} else {
|
if (field.default_value_ptr) |v| {
|
||||||
if (partial_struct) {
|
@field(obj, field.name) = @as(*const field.type, @ptrCast(@alignCast(v))).*;
|
||||||
log.warn("Incomplete metadata '{0s}': {1s}. Missing field: '{2s}'. '{0s}' will be ignored.", .{ prefix, @typeName(T), field.name });
|
} else {
|
||||||
obj.* = undefined;
|
if (partial_struct) {
|
||||||
|
log.warn("Incomplete struct '{0s}': {1s}. Missing field: '{2s}'. '{0s}' will be ignored.", .{ prefix, @typeName(T), field.name });
|
||||||
|
obj.* = undefined;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -406,7 +548,7 @@ fn _populateStruct(
|
|||||||
},
|
},
|
||||||
.optional => |opt_info| {
|
.optional => |opt_info| {
|
||||||
obj.* = @as(opt_info.child, undefined);
|
obj.* = @as(opt_info.child, undefined);
|
||||||
const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &(obj.*.?), false);
|
const found = try _populateStruct(allocator, prefix_builder, store, &(obj.*.?), false);
|
||||||
if (!found) obj.* = null;
|
if (!found) obj.* = null;
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
@ -449,7 +591,7 @@ test populateModel {
|
|||||||
|
|
||||||
var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator);
|
var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator);
|
||||||
defer arena_state.deinit();
|
defer arena_state.deinit();
|
||||||
var store: BufferStore = .{ .arena = arena_state };
|
var store: BufferStore = .init(arena_state.allocator());
|
||||||
try store.buffers.ensureUnusedCapacity(arena_state.allocator(), 16);
|
try store.buffers.ensureUnusedCapacity(arena_state.allocator(), 16);
|
||||||
store.buffers.putAssumeCapacity("a", Model._newHostBuffer(10));
|
store.buffers.putAssumeCapacity("a", Model._newHostBuffer(10));
|
||||||
store.buffers.putAssumeCapacity("b.a", Model._newHostBuffer(20));
|
store.buffers.putAssumeCapacity("b.a", Model._newHostBuffer(20));
|
||||||
@ -494,11 +636,11 @@ test populateModel {
|
|||||||
pub fn loadBuffers(
|
pub fn loadBuffers(
|
||||||
comptime Model: type,
|
comptime Model: type,
|
||||||
init_args: if (@hasDecl(Model, "init")) stdx.meta.Tail(stdx.meta.FnArgs(Model.init)) else void,
|
init_args: if (@hasDecl(Model, "init")) stdx.meta.Tail(stdx.meta.FnArgs(Model.init)) else void,
|
||||||
buffer_store: BufferStore,
|
store: BufferStore,
|
||||||
allocator: std.mem.Allocator,
|
allocator: std.mem.Allocator,
|
||||||
platform: zml.Platform,
|
platform: zml.Platform,
|
||||||
) !zml.Bufferized(Model) {
|
) !zml.Bufferized(Model) {
|
||||||
return loadBuffersWithPrefix(Model, init_args, buffer_store, allocator, platform, "");
|
return loadBuffersWithPrefix(Model, init_args, store, allocator, platform, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a bufferized version of a Model from the given BufferStore with a specified prefix.
|
/// Creates a bufferized version of a Model from the given BufferStore with a specified prefix.
|
||||||
@ -512,7 +654,7 @@ pub fn loadBuffers(
|
|||||||
pub fn loadBuffersWithPrefix(
|
pub fn loadBuffersWithPrefix(
|
||||||
comptime Model: type,
|
comptime Model: type,
|
||||||
init_args: if (@hasDecl(Model, "init")) stdx.meta.Tail(stdx.meta.FnArgs(Model.init)) else void,
|
init_args: if (@hasDecl(Model, "init")) stdx.meta.Tail(stdx.meta.FnArgs(Model.init)) else void,
|
||||||
buffer_store: BufferStore,
|
store: BufferStore,
|
||||||
allocator: std.mem.Allocator,
|
allocator: std.mem.Allocator,
|
||||||
platform: zml.Platform,
|
platform: zml.Platform,
|
||||||
prefix: []const u8,
|
prefix: []const u8,
|
||||||
@ -522,14 +664,14 @@ pub fn loadBuffersWithPrefix(
|
|||||||
const arena = arena_state.allocator();
|
const arena = arena_state.allocator();
|
||||||
|
|
||||||
// Get model structure with tensor shapes from the buffer store with prefix
|
// Get model structure with tensor shapes from the buffer store with prefix
|
||||||
var model: Model = try zml.aio.populateModelWithPrefix(Model, arena, buffer_store, prefix);
|
var model: Model = try zml.aio.populateModelWithPrefix(Model, arena, store, prefix);
|
||||||
|
|
||||||
// If the Model has a "init" function, call it with the given parameters.
|
// If the Model has a "init" function, call it with the given parameters.
|
||||||
if (@hasDecl(Model, "init")) {
|
if (@hasDecl(Model, "init")) {
|
||||||
@call(.auto, Model.init, .{&model} ++ init_args);
|
@call(.auto, Model.init, .{&model} ++ init_args);
|
||||||
}
|
}
|
||||||
|
|
||||||
return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, prefix);
|
return loadModelBuffersWithPrefix(Model, model, store, allocator, platform, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a bufferized version of a Model from the given BufferStore. For details about
|
/// Creates a bufferized version of a Model from the given BufferStore. For details about
|
||||||
@ -541,11 +683,11 @@ pub fn loadBuffersWithPrefix(
|
|||||||
pub fn loadModelBuffers(
|
pub fn loadModelBuffers(
|
||||||
comptime Model: type,
|
comptime Model: type,
|
||||||
model: Model,
|
model: Model,
|
||||||
buffer_store: BufferStore,
|
store: BufferStore,
|
||||||
allocator: std.mem.Allocator,
|
allocator: std.mem.Allocator,
|
||||||
platform: zml.Platform,
|
platform: zml.Platform,
|
||||||
) !zml.Bufferized(Model) {
|
) !zml.Bufferized(Model) {
|
||||||
return try loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, "");
|
return try loadModelBuffersWithPrefix(Model, model, store, allocator, platform, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a bufferized version of a Model from the given BufferStore and the given prefix.
|
/// Creates a bufferized version of a Model from the given BufferStore and the given prefix.
|
||||||
@ -557,7 +699,7 @@ pub fn loadModelBuffers(
|
|||||||
pub fn loadModelBuffersWithPrefix(
|
pub fn loadModelBuffersWithPrefix(
|
||||||
comptime Model: type,
|
comptime Model: type,
|
||||||
model: Model,
|
model: Model,
|
||||||
buffer_store: BufferStore,
|
store: BufferStore,
|
||||||
allocator: std.mem.Allocator,
|
allocator: std.mem.Allocator,
|
||||||
platform: zml.Platform,
|
platform: zml.Platform,
|
||||||
prefix: []const u8,
|
prefix: []const u8,
|
||||||
@ -576,7 +718,7 @@ pub fn loadModelBuffersWithPrefix(
|
|||||||
try prefix_builder.push(allocator, prefix);
|
try prefix_builder.push(allocator, prefix);
|
||||||
defer prefix_builder.deinit(allocator);
|
defer prefix_builder.deinit(allocator);
|
||||||
|
|
||||||
try visitStructAndLoadBuffer(allocator, &prefix_builder, buffer_store, &res, platform);
|
try visitStructAndLoadBuffer(allocator, &prefix_builder, store, &res, platform);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -590,58 +732,6 @@ pub fn unloadBuffers(model: anytype) void {
|
|||||||
}).cb, {}, model);
|
}).cb, {}, model);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Assists in debugging `BufferNotFound` error
|
|
||||||
/// This is useful when a buffer key is not found and you want to identify possible alternatives (or typos)
|
|
||||||
fn findSimilarBufferKeys(original_key: []const u8, store: BufferStore, temp_allocator: std.mem.Allocator) void {
|
|
||||||
const suffixes = [_][]const u8{ "", ".weight", ".bias" };
|
|
||||||
var shown_keys = std.StringHashMap(void).init(temp_allocator);
|
|
||||||
defer shown_keys.deinit();
|
|
||||||
|
|
||||||
// remove suffix .weight and .bias
|
|
||||||
var base_key = original_key;
|
|
||||||
for (suffixes) |suffix| {
|
|
||||||
if (std.mem.endsWith(u8, original_key, suffix)) {
|
|
||||||
base_key = original_key[0 .. original_key.len - suffix.len];
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// first test: look for exact matches
|
|
||||||
var matches: usize = 0;
|
|
||||||
var it = store.buffers.iterator();
|
|
||||||
while (it.next()) |entry| {
|
|
||||||
const key = entry.key_ptr.*;
|
|
||||||
if (std.mem.startsWith(u8, key, base_key)) {
|
|
||||||
if (matches == 0) log.warn("Similar buffers found:", .{});
|
|
||||||
if (!shown_keys.contains(key)) {
|
|
||||||
log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() });
|
|
||||||
shown_keys.put(key, {}) catch continue;
|
|
||||||
matches += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// second test: progressive partial matches
|
|
||||||
if (matches == 0) {
|
|
||||||
var components = std.mem.splitScalar(u8, base_key, '.');
|
|
||||||
while (components.next()) |component| {
|
|
||||||
matches = 0;
|
|
||||||
it = store.buffers.iterator();
|
|
||||||
while (it.next()) |entry| {
|
|
||||||
const key = entry.key_ptr.*;
|
|
||||||
if (std.mem.indexOf(u8, key, component) != null and !shown_keys.contains(key)) {
|
|
||||||
if (matches == 0) log.warn("Partial matches for '{s}':", .{component});
|
|
||||||
log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() });
|
|
||||||
shown_keys.put(key, {}) catch continue;
|
|
||||||
matches += 1;
|
|
||||||
if (matches >= 5) break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (matches > 0) break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// deinit all buffers in the given struct
|
/// deinit all buffers in the given struct
|
||||||
pub fn awaitAll(buffers: anytype) !void {
|
pub fn awaitAll(buffers: anytype) !void {
|
||||||
zml.meta.visit((struct {
|
zml.meta.visit((struct {
|
||||||
@ -651,7 +741,7 @@ pub fn awaitAll(buffers: anytype) !void {
|
|||||||
}).cb, {}, buffers);
|
}).cb, {}, buffers);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, buffer_store: BufferStore, obj: anytype, platform: zml.Platform) !void {
|
fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, store: BufferStore, obj: anytype, platform: zml.Platform) !void {
|
||||||
const err_msg = "visitStructAndLoadBuffer must be called with a pointer to type. Received ";
|
const err_msg = "visitStructAndLoadBuffer must be called with a pointer to type. Received ";
|
||||||
const type_info, const T = switch (@typeInfo(@TypeOf(obj))) {
|
const type_info, const T = switch (@typeInfo(@TypeOf(obj))) {
|
||||||
.pointer => |ptr_info| switch (ptr_info.size) {
|
.pointer => |ptr_info| switch (ptr_info.size) {
|
||||||
@ -663,7 +753,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
|
|
||||||
const prefix = prefix_builder.data.items;
|
const prefix = prefix_builder.data.items;
|
||||||
if (T == zml.Buffer) {
|
if (T == zml.Buffer) {
|
||||||
return if (buffer_store.get(prefix)) |host_buffer| {
|
return if (store.get(prefix)) |host_buffer| {
|
||||||
// obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us.
|
// obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us.
|
||||||
var buf_with_metadata = host_buffer;
|
var buf_with_metadata = host_buffer;
|
||||||
log.debug("Loading buffer {s} ({f})", .{ prefix, obj._shape });
|
log.debug("Loading buffer {s} ({f})", .{ prefix, obj._shape });
|
||||||
@ -673,7 +763,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
} else {
|
} else {
|
||||||
log.err("Buffer not found: {s}", .{prefix});
|
log.err("Buffer not found: {s}", .{prefix});
|
||||||
|
|
||||||
findSimilarBufferKeys(prefix, buffer_store, allocator);
|
store.findSimilarBufferKeys(allocator, prefix);
|
||||||
|
|
||||||
return error.BufferNotFound;
|
return error.BufferNotFound;
|
||||||
};
|
};
|
||||||
@ -686,7 +776,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
try prefix_builder.pushDigit(allocator, i);
|
try prefix_builder.pushDigit(allocator, i);
|
||||||
defer prefix_builder.pop();
|
defer prefix_builder.pop();
|
||||||
|
|
||||||
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, value, platform);
|
try visitStructAndLoadBuffer(allocator, prefix_builder, store, value, platform);
|
||||||
}
|
}
|
||||||
} else stdx.debug.compileError("type not supported by visitStructAndLoadBuffer: {}", .{T});
|
} else stdx.debug.compileError("type not supported by visitStructAndLoadBuffer: {}", .{T});
|
||||||
},
|
},
|
||||||
@ -694,7 +784,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
for (obj, 0..) |*value, i| {
|
for (obj, 0..) |*value, i| {
|
||||||
try prefix_builder.pushDigit(allocator, i);
|
try prefix_builder.pushDigit(allocator, i);
|
||||||
defer prefix_builder.pop();
|
defer prefix_builder.pop();
|
||||||
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, value, platform);
|
try visitStructAndLoadBuffer(allocator, prefix_builder, store, value, platform);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
@ -704,12 +794,12 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
try prefix_builder.push(allocator, field.name);
|
try prefix_builder.push(allocator, field.name);
|
||||||
defer prefix_builder.pop();
|
defer prefix_builder.pop();
|
||||||
|
|
||||||
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &@field(obj, field.name), platform);
|
try visitStructAndLoadBuffer(allocator, prefix_builder, store, &@field(obj, field.name), platform);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
.optional => {
|
.optional => {
|
||||||
if (obj.*) |*obj_val| {
|
if (obj.*) |*obj_val| {
|
||||||
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, obj_val, platform);
|
try visitStructAndLoadBuffer(allocator, prefix_builder, store, obj_val, platform);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
else => {},
|
else => {},
|
||||||
|
|||||||
@ -13,9 +13,7 @@ const StringBuilder = std.ArrayListUnmanaged(u8);
|
|||||||
const log = std.log.scoped(.@"zml/io");
|
const log = std.log.scoped(.@"zml/io");
|
||||||
|
|
||||||
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
|
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
|
||||||
var res: zml.aio.BufferStore = .{
|
var res: zml.aio.BufferStore = .init(allocator);
|
||||||
.arena = std.heap.ArenaAllocator.init(allocator),
|
|
||||||
};
|
|
||||||
errdefer res.arena.deinit();
|
errdefer res.arena.deinit();
|
||||||
const arena = res.arena.allocator();
|
const arena = res.arena.allocator();
|
||||||
|
|
||||||
|
|||||||
@ -25,9 +25,7 @@ const TinyLlamaConfig = extern struct {
|
|||||||
/// For convenience we use the same layer names
|
/// For convenience we use the same layer names
|
||||||
/// than the one used by the Llama-3.1 models.
|
/// than the one used by the Llama-3.1 models.
|
||||||
pub fn open(allocator: std.mem.Allocator, model_path: []const u8) !zml.aio.BufferStore {
|
pub fn open(allocator: std.mem.Allocator, model_path: []const u8) !zml.aio.BufferStore {
|
||||||
var res: zml.aio.BufferStore = .{
|
var res: zml.aio.BufferStore = .init(allocator);
|
||||||
.arena = std.heap.ArenaAllocator.init(allocator),
|
|
||||||
};
|
|
||||||
errdefer res.arena.deinit();
|
errdefer res.arena.deinit();
|
||||||
const arena = res.arena.allocator();
|
const arena = res.arena.allocator();
|
||||||
|
|
||||||
|
|||||||
@ -36,7 +36,7 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
|
|||||||
const py_values = try eval.evaluate(tmp_alloc, ops, true);
|
const py_values = try eval.evaluate(tmp_alloc, ops, true);
|
||||||
|
|
||||||
// file ownership is transferred to the BufferStore
|
// file ownership is transferred to the BufferStore
|
||||||
var res = try zml.aio.BufferStore.init(allocator, &.{torch_file.mmap_file});
|
var res = try zml.aio.BufferStore.initWithFiles(allocator, &.{mmap_file});
|
||||||
try torch_file.parseModel(py_values, &res);
|
try torch_file.parseModel(py_values, &res);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -480,7 +480,7 @@ test "Read pickle (zipped)" {
|
|||||||
// torch.save({ "model": model, "tensor": tensor}, "simple.pt")
|
// torch.save({ "model": model, "tensor": tensor}, "simple.pt")
|
||||||
const file = try asynk.File.open("zml/aio/torch/simple.pt", .{ .mode = .read_only });
|
const file = try asynk.File.open("zml/aio/torch/simple.pt", .{ .mode = .read_only });
|
||||||
const mmap_file = try zml.aio.MemoryMappedFile.init(file);
|
const mmap_file = try zml.aio.MemoryMappedFile.init(file);
|
||||||
var store = try zml.aio.BufferStore.init(testing.allocator, &.{mmap_file});
|
var store = try zml.aio.BufferStore.initWithFiles(testing.allocator, &.{mmap_file});
|
||||||
defer store.deinit();
|
defer store.deinit();
|
||||||
|
|
||||||
{
|
{
|
||||||
|
|||||||
@ -77,7 +77,7 @@ pub fn compileModel(
|
|||||||
) !FnExe(func) {
|
) !FnExe(func) {
|
||||||
const ModelT = ModuleSignature(func).ModelT;
|
const ModelT = ModuleSignature(func).ModelT;
|
||||||
const name = @typeName(ModelT) ++ ".forward";
|
const name = @typeName(ModelT) ++ ".forward";
|
||||||
log.info("Compiling {s} with {}", .{ name, args_shapes });
|
log.info("Compiling {s} with {f}", .{ name, stdx.fmt.any(args_shapes) });
|
||||||
|
|
||||||
var context = try CompilationContext.init(allocator, name, platform);
|
var context = try CompilationContext.init(allocator, name, platform);
|
||||||
defer context.deinit();
|
defer context.deinit();
|
||||||
@ -209,9 +209,9 @@ pub const BaseExe = struct {
|
|||||||
|
|
||||||
var execute_context: ?*pjrt.ExecuteContext = null;
|
var execute_context: ?*pjrt.ExecuteContext = null;
|
||||||
if (platform.pjrt_api.ffi()) |ffi| {
|
if (platform.pjrt_api.ffi()) |ffi| {
|
||||||
log.info("Created context execution {*} for {*}", .{ execute_context, exe });
|
|
||||||
execute_context = try platform.pjrt_api.createExecuteContext();
|
execute_context = try platform.pjrt_api.createExecuteContext();
|
||||||
try callback.bindInternalCallbacks(allocator, platform, ffi, execute_context.?);
|
try callback.bindInternalCallbacks(allocator, platform, ffi, execute_context.?);
|
||||||
|
// log.info("Created context execution {*} for {*}", .{ execute_context, exe });
|
||||||
}
|
}
|
||||||
|
|
||||||
return .{
|
return .{
|
||||||
|
|||||||
@ -141,6 +141,7 @@ test MapType {
|
|||||||
/// Any `To` struct inside `from` will be copied over to the target.
|
/// Any `To` struct inside `from` will be copied over to the target.
|
||||||
pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam(cb, 0), from: anytype, to: anytype) !void {
|
pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam(cb, 0), from: anytype, to: anytype) !void {
|
||||||
// TODO: handle tuple to slice conversion
|
// TODO: handle tuple to slice conversion
|
||||||
|
// TODO: handle error bubbling up
|
||||||
const From = FnParam(cb, 1);
|
const From = FnParam(cb, 1);
|
||||||
const To = stdx.meta.FnResult(cb);
|
const To = stdx.meta.FnResult(cb);
|
||||||
const FromStruct = @TypeOf(from);
|
const FromStruct = @TypeOf(from);
|
||||||
|
|||||||
@ -1129,7 +1129,7 @@ pub fn hashShape(hasher: *std.hash.Wyhash, shape: Shape) void {
|
|||||||
hash(hasher, shape._dtype, .Shallow);
|
hash(hasher, shape._dtype, .Shallow);
|
||||||
hash(hasher, shape._sharding_info, .Shallow);
|
hash(hasher, shape._sharding_info, .Shallow);
|
||||||
for (shape.tags()) |tag| {
|
for (shape.tags()) |tag| {
|
||||||
hash(hasher, @intFromPtr(tag), .Shallow);
|
hashArray(hasher, std.mem.span(tag), .Shallow);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -82,9 +82,7 @@ pub const Platform = struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const _CreateOptions = struct {
|
const _CreateOptions = struct {
|
||||||
// XLA CPU client doesn't read options
|
cpu: Cpu = .{ .device_count = 4 },
|
||||||
// https://github.com/openxla/xla/blob/42496a28c374bd35f493cc5dbde74805407245dc/xla/pjrt/c/pjrt_c_api_cpu_internal.cc#L33-L46
|
|
||||||
cpu: struct {} = .{},
|
|
||||||
|
|
||||||
// bump memory fraction from XLA defaults of 75% to 90%.
|
// bump memory fraction from XLA defaults of 75% to 90%.
|
||||||
// Even on a 8GB GPU it should leave enough space for the Cuda driver
|
// Even on a 8GB GPU it should leave enough space for the Cuda driver
|
||||||
@ -94,6 +92,14 @@ const _CreateOptions = struct {
|
|||||||
tpu: struct {} = .{},
|
tpu: struct {} = .{},
|
||||||
neuron: struct {} = .{},
|
neuron: struct {} = .{},
|
||||||
|
|
||||||
|
pub const Cpu = struct {
|
||||||
|
device_count: u32,
|
||||||
|
|
||||||
|
fn writeNamedValues(self: Cpu, values: *std.ArrayListUnmanaged(pjrt.NamedValue)) void {
|
||||||
|
values.appendAssumeCapacity(pjrt.NamedValue.from("cpu_device_count", @as(i64, self.device_count)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
pub const Cuda = struct {
|
pub const Cuda = struct {
|
||||||
allocator: Allocator = .{ .bfc = .{} },
|
allocator: Allocator = .{ .bfc = .{} },
|
||||||
// TODO support all of https://github.com/openxla/xla/blob/3d31c48c719d331d432132b3e0c2c5ce52650675/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L76-L86
|
// TODO support all of https://github.com/openxla/xla/blob/3d31c48c719d331d432132b3e0c2c5ce52650675/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L76-L86
|
||||||
@ -118,7 +124,7 @@ const _CreateOptions = struct {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn writeNamedValues(self: Cuda, values: *std.ArrayListUnmanaged(pjrt.NamedValue)) void {
|
fn writeNamedValues(self: Cuda, values: *std.ArrayListUnmanaged(pjrt.NamedValue)) void {
|
||||||
switch (self.allocator) {
|
switch (self.allocator) {
|
||||||
.platform => {
|
.platform => {
|
||||||
values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform"));
|
values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform"));
|
||||||
@ -142,6 +148,7 @@ const _CreateOptions = struct {
|
|||||||
var values = std.ArrayListUnmanaged(pjrt.NamedValue).fromOwnedSlice(out);
|
var values = std.ArrayListUnmanaged(pjrt.NamedValue).fromOwnedSlice(out);
|
||||||
values.shrinkRetainingCapacity(0);
|
values.shrinkRetainingCapacity(0);
|
||||||
switch (target) {
|
switch (target) {
|
||||||
|
.cpu => self.cpu.writeNamedValues(&values),
|
||||||
.cuda => self.cuda.writeNamedValues(&values),
|
.cuda => self.cuda.writeNamedValues(&values),
|
||||||
inline else => |t| {
|
inline else => |t| {
|
||||||
stdx.debug.assertComptime(@hasField(_CreateOptions, @tagName(t)), "zml.platform.CreateOptions doesn't list target {s}", .{@tagName(t)});
|
stdx.debug.assertComptime(@hasField(_CreateOptions, @tagName(t)), "zml.platform.CreateOptions doesn't list target {s}", .{@tagName(t)});
|
||||||
|
|||||||
@ -269,18 +269,6 @@ pub const Tensor = struct {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
var _global_tensor_counter: u64 = 0;
|
|
||||||
|
|
||||||
/// Internal use
|
|
||||||
pub fn _reserveIdRange(len: u32) u64 {
|
|
||||||
return @atomicRmw(u64, &_global_tensor_counter, .Add, len, .seq_cst);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Internal use
|
|
||||||
pub fn setUniqueId(self: *Tensor) void {
|
|
||||||
self._id = .{ .buffer_id = _reserveIdRange(1) };
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a Tensor containing the absolute value of each element of the input Tensor.
|
/// Returns a Tensor containing the absolute value of each element of the input Tensor.
|
||||||
pub fn abs(self: Tensor) Tensor {
|
pub fn abs(self: Tensor) Tensor {
|
||||||
const loc = self.getContext().mlirCtx().location(@src());
|
const loc = self.getContext().mlirCtx().location(@src());
|
||||||
|
|||||||
@ -298,7 +298,7 @@ test testLayer {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// create a buffer store containing the activations:
|
// create a buffer store containing the activations:
|
||||||
var activations = try zml.aio.BufferStore.init(std.testing.allocator, &.{});
|
var activations = zml.aio.BufferStore.init(std.testing.allocator);
|
||||||
defer activations.deinit();
|
defer activations.deinit();
|
||||||
{
|
{
|
||||||
const input = zml.HostBuffer.fromArray(&[2]f32{ 1, -1 });
|
const input = zml.HostBuffer.fromArray(&[2]f32{ 1, -1 });
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user