diff --git a/stdx/fmt.zig b/stdx/fmt.zig index 145db76..03bf7a3 100644 --- a/stdx/fmt.zig +++ b/stdx/fmt.zig @@ -1,5 +1,6 @@ const std = @import("std"); +/// Properly format a slice of numbers. pub fn slice(any_slice: anytype) FmtSlice(std.meta.Elem(@TypeOf(any_slice))) { return .{ .slice = any_slice }; } @@ -115,3 +116,57 @@ pub fn formatComplexSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io pub fn formatBoolSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void { return try formatSliceCustom(formatBool, values, spec, writer); } + +/// Format a struct using `format` method of subfields when possible. +pub fn any(any_val: anytype) FmtAny(@TypeOf(any_val)) { + return .{ .data = any_val }; +} + +fn FmtAny(Data: type) type { + return struct { + data: Data, + + pub inline fn format(self: @This(), writer: *std.io.Writer) std.io.Writer.Error!void { + try printValue(writer, .{}, self.data, std.options.fmt_max_depth); + } + }; +} + +/// Fix up of std.io.Writer.printValue that uses `format` method of subfields when possible +fn printValue( + w: *std.io.Writer, + options: std.fmt.Options, + value: anytype, + max_depth: usize, +) std.io.Writer.Error!void { + const T = @TypeOf(value); + if (std.meta.hasMethod(T, "format")) { + return try value.format(w); + } + if (std.meta.hasMethod(T, "formatNumber")) { + return try value.formatNumber(w, options.toNumber(.decimal, .lower)); + } + + if (max_depth == 0) { + try w.writeAll(".{ ... }"); + return; + } + + switch (@typeInfo(T)) { + .@"struct" => |info| { + try w.writeAll(".{ "); + inline for (info.fields, 0..) 
|f, i| { + if (i > 0) try w.writeAll(", "); + + if (!info.is_tuple) { + try w.writeByte('.'); + try w.writeAll(f.name); + try w.writeAll(" = "); + } + try printValue(w, options, @field(value, f.name), max_depth - 1); + } + try w.writeAll(" }"); + }, + inline else => try w.printValue("any", options, value, max_depth - 1), + } +} diff --git a/stdx/stdx.zig b/stdx/stdx.zig index 9bfa532..4698349 100644 --- a/stdx/stdx.zig +++ b/stdx/stdx.zig @@ -1,3 +1,6 @@ +const std = @import("std"); +const builtin = @import("builtin"); + pub const BoundedArray = @import("bounded_array.zig").BoundedArray; pub const BoundedArrayAligned = @import("bounded_array.zig").BoundedArrayAligned; pub const debug = @import("debug.zig"); @@ -11,7 +14,6 @@ pub const queue = @import("queue.zig"); pub const time = @import("time.zig"); test { - const std = @import("std"); std.testing.refAllDecls(@This()); } @@ -20,3 +22,5 @@ pub inline fn stackSlice(comptime max_len: usize, T: type, len: usize) []T { var storage: [max_len]T = undefined; return storage[0..len]; } + +pub const noalloc: std.mem.Allocator = if (builtin.mode == .ReleaseFast) undefined else std.testing.failing_allocator; diff --git a/zml/aio.zig b/zml/aio.zig index 1b926a9..a778a62 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -43,8 +43,8 @@ pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8) /// whose shape is read from the "a.b" tensor. /// * If `Model` contains a list of layers, then the field: /// `Model.layers[2].a.b` will be populated from the "layers.2.a.b" tensor. 
-pub fn populateModel(comptime Model: type, allocator: std.mem.Allocator, buffer_store: BufferStore) !Model { - return populateModelWithPrefix(Model, allocator, buffer_store, ""); +pub fn populateModel(comptime Model: type, allocator: std.mem.Allocator, store: BufferStore) !Model { + return populateModelWithPrefix(Model, allocator, store, ""); } /// Creates a Model struct with tensor shapes read from the given TensorStore, @@ -62,8 +62,7 @@ pub fn populateModelWithPrefix(comptime Model: type, allocator: std.mem.Allocato try prefix_builder.push(allocator, prefix); defer prefix_builder.deinit(allocator); - const unique_id = zml.Tensor._reserveIdRange(@intCast(store.buffers.count())); - const ok = _populateStruct(allocator, &prefix_builder, unique_id, store, &model, true) catch |err| { + const ok = _populateStruct(allocator, &prefix_builder, store, &model, true) catch |err| { std.debug.panic("Can't populate model of type {s}: {s}", .{ @typeName(type), @errorName(err) }); }; if (!ok) return error.TensorNotFound; @@ -74,17 +73,28 @@ pub fn populateModelWithPrefix(comptime Model: type, allocator: std.mem.Allocato pub const BufferStore = struct { pub const Buffers = std.StringArrayHashMapUnmanaged(HostBuffer); pub const Metadatas = std.StringArrayHashMapUnmanaged(Metadata); + var _unique_store_id: std.atomic.Value(u64) = .init(0); + const _store_id_range: u64 = 1024 * 1024 * 1024; arena: std.heap.ArenaAllocator, files: []MemoryMappedFile = &.{}, buffers: Buffers = .{}, _metadata: Metadatas = .{}, + _unique_id: u64, - /// Create an empty BufferStore. Takes owneship of the given files. - pub fn init(allocator: std.mem.Allocator, files: []const MemoryMappedFile) error{OutOfMemory}!BufferStore { - var self: zml.aio.BufferStore = .{ + /// Create an empty BufferStore. + /// Use `initWithFiles` to also take ownership of memory mapped files. 
+ pub fn init(allocator: std.mem.Allocator) BufferStore { + return .{ .arena = std.heap.ArenaAllocator.init(allocator), + ._unique_id = _unique_store_id.fetchAdd(_store_id_range, .monotonic), }; + } + + /// Create an empty BufferStore. + /// Takes ownership of the given files. + pub fn initWithFiles(allocator: std.mem.Allocator, files: []const MemoryMappedFile) error{OutOfMemory}!BufferStore { + var self: BufferStore = .init(allocator); self.files = try self.arena.allocator().dupe(MemoryMappedFile, files); return self; } @@ -100,6 +110,61 @@ pub const BufferStore = struct { return self.buffers.get(key); } + pub fn loadBufferById(self: BufferStore, x: zml.Tensor, platform: zml.Platform) !zml.Buffer { + var host_buffer: zml.HostBuffer = switch (x._id) { + .buffer_id => |id| hb: { + if (id < self._unique_id or self._unique_id + _store_id_range <= id) { + @panic("`store.loadBufferById()` only works on Tensor created by `store.getTensor()`, using the same store object."); + } + break :hb self.buffers.values()[id - self._unique_id]; + }, + else => @panic("`store.loadBufferById()` only works on Tensor created by `store.getTensor()`"), + }; + + // Use the sharding information stored in the tensor. + host_buffer._shape = x.shape(); + return try host_buffer.toDevice(platform); + } + + /// Creates a bufferized version of a model from the given BufferStore. + /// + /// This will represent the weights of the model, loaded on a specific platform. + /// It can be used with a `module.Exe` (a compiled version of the same Model), to make a + /// `module.ExeWithWeights` ready to be called. 
+ pub fn loadModelById(self: BufferStore, Model: type, allocator: std.mem.Allocator, model: Model, platform: zml.Platform) !zml.Bufferized(Model) { + const Ctx = struct { + platform: *const zml.Platform, + store: *const BufferStore, + + pub fn cb(ctx: @This(), x: zml.Tensor) zml.Buffer { + return ctx.store.loadBufferById(x, ctx.platform.*) catch @panic("Failed to load buffer to device"); + } + }; + + var res: zml.Bufferized(Model) = undefined; + try zml.meta.mapAlloc(Ctx.cb, allocator, .{ .platform = &platform, .store = &self }, model, &res); + return res; + } + + pub fn getTensor(self: BufferStore, key: []const u8) zml.Tensor { + return self.getTensorOrNull(key) orelse { + log.err("Tensor not found: {s}", .{key}); + self.findSimilarBufferKeys(std.heap.smp_allocator, key); + @panic("Tensor not found"); + }; + } + + pub fn getTensorOrNull(self: BufferStore, key: []const u8) ?zml.Tensor { + return if (self.buffers.getIndex(key)) |entry_idx| + .{ + ._shape = self.buffers.values()[entry_idx].shape(), + ._id = .{ .buffer_id = self._unique_id + entry_idx }, + ._donation = .input_buffer, + } + else + return null; + } + /// Count layers starting with the given prefix. pub fn countLayers(self: BufferStore, prefix: []const u8) usize { // Note: This is kinda inefficient @@ -140,6 +205,58 @@ pub const BufferStore = struct { return null; } + + /// Assists in debugging `BufferNotFound` error + /// This is useful when a buffer key is not found and you want to identify possible alternatives (or typos) + pub fn findSimilarBufferKeys(store: BufferStore, tmp_alloc: std.mem.Allocator, original_key: []const u8) void { + const suffixes = [_][]const u8{ "", ".weight", ".bias" }; + var shown_keys = std.StringHashMap(void).init(tmp_alloc); + defer shown_keys.deinit(); + + // remove suffix .weight and .bias + var base_key = original_key; + for (suffixes) |suffix| { + if (std.mem.endsWith(u8, original_key, suffix)) { + base_key = original_key[0 .. 
original_key.len - suffix.len]; + break; + } + } + + // first test: look for exact matches + var matches: usize = 0; + var it = store.buffers.iterator(); + while (it.next()) |entry| { + const key = entry.key_ptr.*; + if (std.mem.startsWith(u8, key, base_key)) { + if (matches == 0) log.warn("Similar buffers found:", .{}); + if (!shown_keys.contains(key)) { + log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() }); + shown_keys.put(key, {}) catch continue; + matches += 1; + } + } + } + + // second test: progressive partial matches + if (matches == 0) { + var components = std.mem.splitScalar(u8, base_key, '.'); + while (components.next()) |component| { + matches = 0; + it = store.buffers.iterator(); + while (it.next()) |entry| { + const key = entry.key_ptr.*; + if (std.mem.indexOf(u8, key, component) != null and !shown_keys.contains(key)) { + if (matches == 0) log.warn("Partial matches for '{s}':", .{component}); + log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() }); + shown_keys.put(key, {}) catch continue; + matches += 1; + if (matches >= 5) break; + } + } + if (matches > 0) break; + } + } + } }; pub const Metadata = union(enum) { @@ -260,17 +377,35 @@ pub const MemoryMappedFile = struct { /// Helper handling prefix building. /// /// This allows to easily push/pop prefixes and handles the generation of the string with the correct format. -const PrefixBuilder = struct { +pub const PrefixBuilder = struct { /// Stores the computed prefix. - data: std.ArrayListUnmanaged(u8) = .{}, + data: std.ArrayList(u8) = .{}, /// Stack storing the size of the intermediary prefix. 
- subprefixes: std.ArrayListUnmanaged(u32) = .{}, + subprefixes: std.ArrayList(u32) = .{}, + + pub fn initCapacity(allocator: std.mem.Allocator, capacity: usize) !PrefixBuilder { + return .{ + .data = try .initCapacity(allocator, capacity), + .subprefixes = try .initCapacity(allocator, @divFloor(capacity, 4)), + }; + } pub fn deinit(self: *PrefixBuilder, allocator: std.mem.Allocator) void { self.data.deinit(allocator); self.subprefixes.deinit(allocator); } + pub fn items(self: PrefixBuilder) []const u8 { + return self.data.items; + } + + pub fn concat(self: *PrefixBuilder, prefix: []const u8) []const u8 { + self.push(stdx.noalloc, prefix) catch unreachable; + const res = self.items(); + self.pop(); + return res; + } + pub fn push(self: *PrefixBuilder, allocator: std.mem.Allocator, prefix: []const u8) !void { const old_len: u32 = @intCast(self.data.items.len); try self.subprefixes.append(allocator, old_len); @@ -301,13 +436,20 @@ const PrefixBuilder = struct { const last_prefix_len = self.subprefixes.pop() orelse unreachable; self.data.shrinkRetainingCapacity(last_prefix_len); } + + pub fn checkpoint(self: PrefixBuilder) [2]usize { + return .{ self.data.items.len, self.subprefixes.items.len }; + } + + pub fn restore(self: *PrefixBuilder, ckpt: [2]usize) void { + self.data.items.len, self.subprefixes.items.len = ckpt; + } }; fn _populateStruct( allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, - unique_id: u64, - buffer_store: BufferStore, + store: BufferStore, obj: anytype, required: bool, ) !bool { @@ -322,17 +464,12 @@ fn _populateStruct( const prefix = prefix_builder.data.items; if (T == zml.Tensor) { - return if (buffer_store.buffers.getIndex(prefix)) |entry_idx| { - const buffer = buffer_store.get(prefix).?; - obj.* = zml.Tensor{ - ._shape = buffer.shape(), - ._id = .{ .buffer_id = unique_id + entry_idx }, - ._donation = .input_buffer, - }; + return if (store.getTensorOrNull(prefix)) |tensor| { + obj.* = tensor; return true; } else { if (required) { 
- log.err("Tensor not found: {s} ({d})", .{ prefix, buffer_store.buffers.count() }); + log.err("Tensor not found: {s} ({d})", .{ prefix, store.buffers.count() }); } return false; }; @@ -343,14 +480,14 @@ fn _populateStruct( if (ptr_info.size == .slice) { obj.* = &.{}; - const len = buffer_store.countLayers(prefix); + const len = store.countLayers(prefix); if (len > 0) { obj.* = try allocator.alloc(ptr_info.child, len); for (obj.*, 0..) |*value, i| { try prefix_builder.pushDigit(allocator, i); defer prefix_builder.pop(); - const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required); + const found = try _populateStruct(allocator, prefix_builder, store, value, required); if (!found) { log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(ptr_info.child) }); return false; @@ -369,7 +506,7 @@ fn _populateStruct( for (obj, 0..) |*value, i| { try prefix_builder.pushDigit(allocator, i); defer prefix_builder.pop(); - const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required); + const found = try _populateStruct(allocator, prefix_builder, store, value, required); if (!found) { log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(arr_info.child) }); return false; @@ -386,15 +523,20 @@ fn _populateStruct( var has_default = false; if (field.default_value_ptr) |_| has_default = true; - const field_found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &@field(obj, field.name), required and !has_default); - partial_struct = partial_struct or field_found; - if (!field_found) { - if (field.default_value_ptr) |v| { - @field(obj, field.name) = @as(*const field.type, @ptrCast(@alignCast(v))).*; - } else { - if (partial_struct) { - log.warn("Incomplete metadata '{0s}': {1s}. Missing field: '{2s}'. 
'{0s}' will be ignored.", .{ prefix, @typeName(T), field.name }); - obj.* = undefined; + + if (zml.meta.Contains(field.type, zml.Tensor)) { + const field_found = try _populateStruct(allocator, prefix_builder, store, &@field(obj, field.name), required and !has_default); + partial_struct = partial_struct or field_found; + if (!field_found) { + if (field.default_value_ptr) |v| { + @field(obj, field.name) = @as(*const field.type, @ptrCast(@alignCast(v))).*; + } else { + if (partial_struct) { + log.warn("Incomplete struct '{0s}': {1s}. Missing field: '{2s}'. '{0s}' will be ignored.", .{ prefix, @typeName(T), field.name }); + obj.* = undefined; + return false; + } + return false; } @@ -406,7 +548,7 @@ fn _populateStruct( }, .optional => |opt_info| { obj.* = @as(opt_info.child, undefined); - const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &(obj.*.?), false); + const found = try _populateStruct(allocator, prefix_builder, store, &(obj.*.?), false); if (!found) obj.* = null; return true; }, @@ -449,7 +591,7 @@ test populateModel { var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator); defer arena_state.deinit(); - var store: BufferStore = .{ .arena = arena_state }; + var store: BufferStore = .init(arena_state.allocator()); try store.buffers.ensureUnusedCapacity(arena_state.allocator(), 16); store.buffers.putAssumeCapacity("a", Model._newHostBuffer(10)); store.buffers.putAssumeCapacity("b.a", Model._newHostBuffer(20)); @@ -494,11 +636,11 @@ test populateModel { pub fn loadBuffers( comptime Model: type, init_args: if (@hasDecl(Model, "init")) stdx.meta.Tail(stdx.meta.FnArgs(Model.init)) else void, - buffer_store: BufferStore, + store: BufferStore, allocator: std.mem.Allocator, platform: zml.Platform, ) !zml.Bufferized(Model) { - return loadBuffersWithPrefix(Model, init_args, buffer_store, allocator, platform, ""); + return loadBuffersWithPrefix(Model, init_args, store, allocator, platform, ""); } /// Creates a bufferized 
version of a Model from the given BufferStore with a specified prefix. @@ -512,7 +654,7 @@ pub fn loadBuffers( pub fn loadBuffersWithPrefix( comptime Model: type, init_args: if (@hasDecl(Model, "init")) stdx.meta.Tail(stdx.meta.FnArgs(Model.init)) else void, - buffer_store: BufferStore, + store: BufferStore, allocator: std.mem.Allocator, platform: zml.Platform, prefix: []const u8, @@ -522,14 +664,14 @@ pub fn loadBuffersWithPrefix( const arena = arena_state.allocator(); // Get model structure with tensor shapes from the buffer store with prefix - var model: Model = try zml.aio.populateModelWithPrefix(Model, arena, buffer_store, prefix); + var model: Model = try zml.aio.populateModelWithPrefix(Model, arena, store, prefix); // If the Model has a "init" function, call it with the given parameters. if (@hasDecl(Model, "init")) { @call(.auto, Model.init, .{&model} ++ init_args); } - return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, prefix); + return loadModelBuffersWithPrefix(Model, model, store, allocator, platform, prefix); } /// Creates a bufferized version of a Model from the given BufferStore. For details about @@ -541,11 +683,11 @@ pub fn loadBuffersWithPrefix( pub fn loadModelBuffers( comptime Model: type, model: Model, - buffer_store: BufferStore, + store: BufferStore, allocator: std.mem.Allocator, platform: zml.Platform, ) !zml.Bufferized(Model) { - return try loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, ""); + return try loadModelBuffersWithPrefix(Model, model, store, allocator, platform, ""); } /// Creates a bufferized version of a Model from the given BufferStore and the given prefix. 
@@ -557,7 +699,7 @@ pub fn loadModelBuffers( pub fn loadModelBuffersWithPrefix( comptime Model: type, model: Model, - buffer_store: BufferStore, + store: BufferStore, allocator: std.mem.Allocator, platform: zml.Platform, prefix: []const u8, @@ -576,7 +718,7 @@ pub fn loadModelBuffersWithPrefix( try prefix_builder.push(allocator, prefix); defer prefix_builder.deinit(allocator); - try visitStructAndLoadBuffer(allocator, &prefix_builder, buffer_store, &res, platform); + try visitStructAndLoadBuffer(allocator, &prefix_builder, store, &res, platform); return res; } @@ -590,58 +732,6 @@ pub fn unloadBuffers(model: anytype) void { }).cb, {}, model); } -/// Assists in debuggigng `BufferNotFound` error -/// This is useful when a buffer key is not found and you want to identify possible alternatives (or typos) -fn findSimilarBufferKeys(original_key: []const u8, store: BufferStore, temp_allocator: std.mem.Allocator) void { - const suffixes = [_][]const u8{ "", ".weight", ".bias" }; - var shown_keys = std.StringHashMap(void).init(temp_allocator); - defer shown_keys.deinit(); - - // remove suffix .weight and .bias - var base_key = original_key; - for (suffixes) |suffix| { - if (std.mem.endsWith(u8, original_key, suffix)) { - base_key = original_key[0 .. 
original_key.len - suffix.len]; - break; - } - } - - // first test: look for exact matches - var matches: usize = 0; - var it = store.buffers.iterator(); - while (it.next()) |entry| { - const key = entry.key_ptr.*; - if (std.mem.startsWith(u8, key, base_key)) { - if (matches == 0) log.warn("Similar buffers found:", .{}); - if (!shown_keys.contains(key)) { - log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() }); - shown_keys.put(key, {}) catch continue; - matches += 1; - } - } - } - - // second test: progressive partial matches - if (matches == 0) { - var components = std.mem.splitScalar(u8, base_key, '.'); - while (components.next()) |component| { - matches = 0; - it = store.buffers.iterator(); - while (it.next()) |entry| { - const key = entry.key_ptr.*; - if (std.mem.indexOf(u8, key, component) != null and !shown_keys.contains(key)) { - if (matches == 0) log.warn("Partial matches for '{s}':", .{component}); - log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() }); - shown_keys.put(key, {}) catch continue; - matches += 1; - if (matches >= 5) break; - } - } - if (matches > 0) break; - } - } -} - /// deinit all buffers in the given struct pub fn awaitAll(buffers: anytype) !void { zml.meta.visit((struct { @@ -651,7 +741,7 @@ pub fn awaitAll(buffers: anytype) !void { }).cb, {}, buffers); } -fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, buffer_store: BufferStore, obj: anytype, platform: zml.Platform) !void { +fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, store: BufferStore, obj: anytype, platform: zml.Platform) !void { const err_msg = "visitStructAndLoadBuffer must be called with a pointer to type. 
Received "; const type_info, const T = switch (@typeInfo(@TypeOf(obj))) { .pointer => |ptr_info| switch (ptr_info.size) { @@ -663,7 +753,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi const prefix = prefix_builder.data.items; if (T == zml.Buffer) { - return if (buffer_store.get(prefix)) |host_buffer| { + return if (store.get(prefix)) |host_buffer| { // obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us. var buf_with_metadata = host_buffer; log.debug("Loading buffer {s} ({f})", .{ prefix, obj._shape }); @@ -673,7 +763,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi } else { log.err("Buffer not found: {s}", .{prefix}); - findSimilarBufferKeys(prefix, buffer_store, allocator); + store.findSimilarBufferKeys(allocator, prefix); return error.BufferNotFound; }; @@ -686,7 +776,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi try prefix_builder.pushDigit(allocator, i); defer prefix_builder.pop(); - try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, value, platform); + try visitStructAndLoadBuffer(allocator, prefix_builder, store, value, platform); } } else stdx.debug.compileError("type not supported by visitStructAndLoadBuffer: {}", .{T}); }, @@ -694,7 +784,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi for (obj, 0..) 
|*value, i| { try prefix_builder.pushDigit(allocator, i); defer prefix_builder.pop(); - try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, value, platform); + try visitStructAndLoadBuffer(allocator, prefix_builder, store, value, platform); } }, @@ -704,12 +794,12 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi try prefix_builder.push(allocator, field.name); defer prefix_builder.pop(); - try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &@field(obj, field.name), platform); + try visitStructAndLoadBuffer(allocator, prefix_builder, store, &@field(obj, field.name), platform); } }, .optional => { if (obj.*) |*obj_val| { - try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, obj_val, platform); + try visitStructAndLoadBuffer(allocator, prefix_builder, store, obj_val, platform); } }, else => {}, diff --git a/zml/aio/safetensors.zig b/zml/aio/safetensors.zig index 08952b8..49ba5fc 100644 --- a/zml/aio/safetensors.zig +++ b/zml/aio/safetensors.zig @@ -13,9 +13,7 @@ const StringBuilder = std.ArrayListUnmanaged(u8); const log = std.log.scoped(.@"zml/io"); pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore { - var res: zml.aio.BufferStore = .{ - .arena = std.heap.ArenaAllocator.init(allocator), - }; + var res: zml.aio.BufferStore = .init(allocator); errdefer res.arena.deinit(); const arena = res.arena.allocator(); diff --git a/zml/aio/tinyllama.zig b/zml/aio/tinyllama.zig index 8114c90..c51319d 100644 --- a/zml/aio/tinyllama.zig +++ b/zml/aio/tinyllama.zig @@ -25,9 +25,7 @@ const TinyLlamaConfig = extern struct { /// For convenience we use the same layer names /// than the one used by the Llama-3.1 models. 
pub fn open(allocator: std.mem.Allocator, model_path: []const u8) !zml.aio.BufferStore { - var res: zml.aio.BufferStore = .{ - .arena = std.heap.ArenaAllocator.init(allocator), - }; + var res: zml.aio.BufferStore = .init(allocator); errdefer res.arena.deinit(); const arena = res.arena.allocator(); diff --git a/zml/aio/torch.zig b/zml/aio/torch.zig index 79487f6..3f1511d 100644 --- a/zml/aio/torch.zig +++ b/zml/aio/torch.zig @@ -36,7 +36,7 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore const py_values = try eval.evaluate(tmp_alloc, ops, true); // file ownership is transferred to the BufferStore - var res = try zml.aio.BufferStore.init(allocator, &.{torch_file.mmap_file}); + var res = try zml.aio.BufferStore.initWithFiles(allocator, &.{mmap_file}); try torch_file.parseModel(py_values, &res); return res; } diff --git a/zml/aio/torch/file.zig b/zml/aio/torch/file.zig index 0307d51..fe8811f 100644 --- a/zml/aio/torch/file.zig +++ b/zml/aio/torch/file.zig @@ -480,7 +480,7 @@ test "Read pickle (zipped)" { // torch.save({ "model": model, "tensor": tensor}, "simple.pt") const file = try asynk.File.open("zml/aio/torch/simple.pt", .{ .mode = .read_only }); const mmap_file = try zml.aio.MemoryMappedFile.init(file); - var store = try zml.aio.BufferStore.init(testing.allocator, &.{mmap_file}); + var store = try zml.aio.BufferStore.initWithFiles(testing.allocator, &.{mmap_file}); defer store.deinit(); { diff --git a/zml/exe.zig b/zml/exe.zig index d4c426b..dffc16c 100644 --- a/zml/exe.zig +++ b/zml/exe.zig @@ -77,7 +77,7 @@ pub fn compileModel( ) !FnExe(func) { const ModelT = ModuleSignature(func).ModelT; const name = @typeName(ModelT) ++ ".forward"; - log.info("Compiling {s} with {}", .{ name, args_shapes }); + log.info("Compiling {s} with {f}", .{ name, stdx.fmt.any(args_shapes) }); var context = try CompilationContext.init(allocator, name, platform); defer context.deinit(); @@ -209,9 +209,9 @@ pub const BaseExe = struct { var 
execute_context: ?*pjrt.ExecuteContext = null; if (platform.pjrt_api.ffi()) |ffi| { - log.info("Created context execution {*} for {*}", .{ execute_context, exe }); execute_context = try platform.pjrt_api.createExecuteContext(); try callback.bindInternalCallbacks(allocator, platform, ffi, execute_context.?); + // log.info("Created context execution {*} for {*}", .{ execute_context, exe }); } return .{ diff --git a/zml/meta.zig b/zml/meta.zig index ee2fbc6..855318a 100644 --- a/zml/meta.zig +++ b/zml/meta.zig @@ -141,6 +141,7 @@ test MapType { /// Any `To` struct inside `from` will be copied over to the target. pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam(cb, 0), from: anytype, to: anytype) !void { // TODO: handle tuple to slice conversion + // TODO: handle error bubbling up const From = FnParam(cb, 1); const To = stdx.meta.FnResult(cb); const FromStruct = @TypeOf(from); diff --git a/zml/module.zig b/zml/module.zig index 6263d3f..1b63ac7 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -1129,7 +1129,7 @@ pub fn hashShape(hasher: *std.hash.Wyhash, shape: Shape) void { hash(hasher, shape._dtype, .Shallow); hash(hasher, shape._sharding_info, .Shallow); for (shape.tags()) |tag| { - hash(hasher, @intFromPtr(tag), .Shallow); + hashArray(hasher, std.mem.span(tag), .Shallow); } } diff --git a/zml/platform.zig b/zml/platform.zig index 0bc792e..6e13d7f 100644 --- a/zml/platform.zig +++ b/zml/platform.zig @@ -82,9 +82,7 @@ pub const Platform = struct { }; const _CreateOptions = struct { - // XLA CPU client doesn't read options - // https://github.com/openxla/xla/blob/42496a28c374bd35f493cc5dbde74805407245dc/xla/pjrt/c/pjrt_c_api_cpu_internal.cc#L33-L46 - cpu: struct {} = .{}, + cpu: Cpu = .{ .device_count = 4 }, // bump memory fraction from XLA defaults of 75% to 90%. 
// Even on a 8GB GPU it should leave enough space for the Cuda driver @@ -94,6 +92,14 @@ const _CreateOptions = struct { tpu: struct {} = .{}, neuron: struct {} = .{}, + pub const Cpu = struct { + device_count: u32, + + fn writeNamedValues(self: Cpu, values: *std.ArrayListUnmanaged(pjrt.NamedValue)) void { + values.appendAssumeCapacity(pjrt.NamedValue.from("cpu_device_count", @as(i64, self.device_count))); + } + }; + pub const Cuda = struct { allocator: Allocator = .{ .bfc = .{} }, // TODO support all of https://github.com/openxla/xla/blob/3d31c48c719d331d432132b3e0c2c5ce52650675/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L76-L86 @@ -118,7 +124,7 @@ const _CreateOptions = struct { }; }; - pub fn writeNamedValues(self: Cuda, values: *std.ArrayListUnmanaged(pjrt.NamedValue)) void { + fn writeNamedValues(self: Cuda, values: *std.ArrayListUnmanaged(pjrt.NamedValue)) void { switch (self.allocator) { .platform => { values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform")); @@ -142,6 +148,7 @@ const _CreateOptions = struct { var values = std.ArrayListUnmanaged(pjrt.NamedValue).fromOwnedSlice(out); values.shrinkRetainingCapacity(0); switch (target) { + .cpu => self.cpu.writeNamedValues(&values), .cuda => self.cuda.writeNamedValues(&values), inline else => |t| { stdx.debug.assertComptime(@hasField(_CreateOptions, @tagName(t)), "zml.platform.CreateOptions doesn't list target {s}", .{@tagName(t)}); diff --git a/zml/tensor.zig b/zml/tensor.zig index 63ab44d..9d55ebf 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -269,18 +269,6 @@ pub const Tensor = struct { return res; } - var _global_tensor_counter: u64 = 0; - - /// Internal use - pub fn _reserveIdRange(len: u32) u64 { - return @atomicRmw(u64, &_global_tensor_counter, .Add, len, .seq_cst); - } - - /// Internal use - pub fn setUniqueId(self: *Tensor) void { - self._id = .{ .buffer_id = _reserveIdRange(1) }; - } - /// Returns a Tensor containing the absolute value of each element of the input Tensor. 
pub fn abs(self: Tensor) Tensor { const loc = self.getContext().mlirCtx().location(@src()); diff --git a/zml/testing.zig b/zml/testing.zig index 5fec596..59778ea 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -298,7 +298,7 @@ test testLayer { }; // create a buffer store containing the activations: - var activations = try zml.aio.BufferStore.init(std.testing.allocator, &.{}); + var activations = zml.aio.BufferStore.init(std.testing.allocator); defer activations.deinit(); { const input = zml.HostBuffer.fromArray(&[2]f32{ 1, -1 });