2024-10-28 11:21:46 +00:00
const std = @import ( " std " ) ;
const asynk = @import ( " async " ) ;
2023-06-21 14:45:14 +00:00
const c = @import ( " c " ) ;
const stdx = @import ( " stdx " ) ;
2023-01-02 14:28:25 +00:00
pub const safetensors = @import ( " aio/safetensors.zig " ) ;
2025-08-07 15:09:27 +00:00
pub const torch = @import ( " aio/torch.zig " ) ;
2023-01-02 14:28:25 +00:00
const HostBuffer = @import ( " hostbuffer.zig " ) . HostBuffer ;
2024-10-28 11:21:46 +00:00
const posix = @import ( " posix.zig " ) ;
const zml = @import ( " zml.zig " ) ;
2023-01-02 14:28:25 +00:00
2024-10-28 11:21:46 +00:00
pub const log = std . log . scoped ( . @ " zml/aio " ) ;
2023-01-27 14:35:11 +00:00
test {
    // Reference all declarations (and the sub-modules' tests) so they get compiled and run.
    std.testing.refAllDecls(@This());
    std.testing.refAllDecls(safetensors);
    std.testing.refAllDecls(torch);
}
2024-10-28 11:21:46 +00:00
// TODO error set for weight loading
2023-01-02 14:28:25 +00:00
/// Detects the format of the model file (based on its filename extension) and opens it.
/// Panics when the extension is not one of the supported formats.
pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8) !BufferStore {
    return if (std.mem.endsWith(u8, model_path, ".safetensors"))
        try safetensors.open(allocator, model_path)
    else if (std.mem.endsWith(u8, model_path, ".safetensors.index.json"))
        try safetensors.open(allocator, model_path)
    else if (std.mem.endsWith(u8, model_path, ".pt"))
        try torch.open(allocator, model_path)
        // else if (std.mem.endsWith(u8, model_path, ".gguf"))
        //     try gguf.open(allocator, model_path)
        // else if (std.mem.endsWith(u8, model_path, ".tinyllama"))
        //     try tinyllama.open(allocator, model_path)
    else {
        std.debug.panic("File extension not recognized: {s}", .{model_path});
    };
}
/// Creates a Model struct with tensor shapes read from the given BufferStore.
/// The result can be used to pass to `compileModel`.
///
/// * The `Tensor` field `Model.a.b` will be populated with a `Tensor`
///   whose shape is read from the "a.b" tensor.
/// * If `Model` contains a list of layers, then the field:
///   `Model.layers[2].a.b` will be populated from the "layers.2.a.b" tensor.
pub fn populateModel(comptime Model: type, allocator: std.mem.Allocator, store: BufferStore) !Model {
    // Same as populateModelWithPrefix with an empty prefix.
    return populateModelWithPrefix(Model, allocator, store, "");
}
/// Creates a Model struct with tensor shapes read from the given TensorStore,
/// using a given prefix.
/// The result can be used to pass to `compileWithModel`.
///
/// * The `Tensor` field `Model.a.b` will be populated with a `Tensor`
///   whose shape is read from the "prefix.a.b" tensor.
/// * If `Model` contains a list of layers, then the field:
///   `Model.layers[2].a.b` will be populated from the "prefix.layers.2.a.b" tensor.
pub fn populateModelWithPrefix(comptime Model: type, allocator: std.mem.Allocator, store: BufferStore, prefix: []const u8) !Model {
    var model: Model = undefined;
    var prefix_builder: PrefixBuilder = .{};
    try prefix_builder.push(allocator, prefix);
    defer prefix_builder.deinit(allocator);

    const ok = _populateStruct(allocator, &prefix_builder, store, &model, true) catch |err| {
        // Fix: was `@typeName(type)`, which prints the literal string "type"
        // instead of the actual model type's name.
        std.debug.panic("Can't populate model of type {s}: {s}", .{ @typeName(Model), @errorName(err) });
    };
    if (!ok) return error.TensorNotFound;
    return model;
}
/// A struct containing all the buffers and metadata found in a model file.
pub const BufferStore = struct {
    pub const Buffers = std.StringArrayHashMapUnmanaged(HostBuffer);
    pub const Metadatas = std.StringArrayHashMapUnmanaged(Metadata);

    /// Global counter handing out a disjoint id range to each store,
    /// so a `buffer_id` alone identifies both the store and the buffer index.
    var _unique_store_id: std.atomic.Value(u64) = .init(0);
    /// Size of the id range reserved per store (maximum number of addressable buffers).
    const _store_id_range: u64 = 1024 * 1024 * 1024;

    arena: std.heap.ArenaAllocator,
    files: []MemoryMappedFile = &.{},
    buffers: Buffers = .{},
    _metadata: Metadatas = .{},
    _unique_id: u64,

    /// Create an empty BufferStore.
    pub fn init(allocator: std.mem.Allocator) BufferStore {
        return .{
            .arena = std.heap.ArenaAllocator.init(allocator),
            ._unique_id = _unique_store_id.fetchAdd(_store_id_range, .monotonic),
        };
    }

    /// Create a BufferStore backed by the given memory-mapped files.
    /// Takes ownership of the given files.
    pub fn initWithFiles(allocator: std.mem.Allocator, files: []const MemoryMappedFile) error{OutOfMemory}!BufferStore {
        var self: BufferStore = .init(allocator);
        self.files = try self.arena.allocator().dupe(MemoryMappedFile, files);
        return self;
    }

    pub fn deinit(self: BufferStore) void {
        for (self.files) |*file| {
            file.deinit();
        }
        self.arena.deinit();
    }

    pub fn get(self: BufferStore, key: []const u8) ?HostBuffer {
        return self.buffers.get(key);
    }

    /// Fetches the HostBuffer backing a Tensor previously returned by `store.getTensor()`,
    /// and uploads it to the given platform.
    /// Panics when the tensor was not created by this store.
    pub fn loadBufferById(self: BufferStore, x: zml.Tensor, platform: zml.Platform) !zml.Buffer {
        var host_buffer: zml.HostBuffer = switch (x._id) {
            .buffer_id => |id| hb: {
                // Ids outside this store's reserved range come from another store.
                if (id < self._unique_id or self._unique_id + _store_id_range <= id) {
                    @panic("`store.loadBufferById()` only works on Tensor created by `store.getTensor()`, using the same store object.");
                }
                break :hb self.buffers.values()[id - self._unique_id];
            },
            else => @panic("`store.loadBufferById()` only works on Tensor created by `store.getTensor()`"),
        };
        // Use the sharding information stored in the tensor.
        host_buffer._shape = x.shape();
        return try host_buffer.toDevice(platform);
    }

    /// Creates a bufferized version of a model from the given BufferStore.
    ///
    /// This will represent the weights of the model, loaded on a specific platform.
    /// It can be used with a `module.Exe` (a compiled version of the same Model), to make a
    /// `module.ExeWithWeights` ready to be called.
    pub fn loadModelById(self: BufferStore, Model: type, allocator: std.mem.Allocator, model: Model, platform: zml.Platform) !zml.Bufferized(Model) {
        const Ctx = struct {
            platform: *const zml.Platform,
            store: *const BufferStore,

            pub fn cb(ctx: @This(), x: zml.Tensor) zml.Buffer {
                return ctx.store.loadBufferById(x, ctx.platform.*) catch @panic("Failed to load buffer to device");
            }
        };
        var res: zml.Bufferized(Model) = undefined;
        try zml.meta.mapAlloc(Ctx.cb, allocator, .{ .platform = &platform, .store = &self }, model, &res);
        return res;
    }

    /// Like `getTensorOrNull`, but panics (after logging similar keys) when the key is missing.
    pub fn getTensor(self: BufferStore, key: []const u8) zml.Tensor {
        return self.getTensorOrNull(key) orelse {
            log.err("Tensor not found: {s}", .{key});
            self.findSimilarBufferKeys(std.heap.smp_allocator, key);
            @panic("Tensor not found");
        };
    }

    /// Returns a Tensor whose shape comes from the stored buffer, and whose id
    /// encodes (store, buffer index) so `loadBufferById` can find it back.
    pub fn getTensorOrNull(self: BufferStore, key: []const u8) ?zml.Tensor {
        return if (self.buffers.getIndex(key)) |entry_idx|
            .{
                ._shape = self.buffers.values()[entry_idx].shape(),
                ._id = .{ .buffer_id = self._unique_id + entry_idx },
                ._donation = .input_buffer,
            }
        else
            return null;
    }

    /// Count layers starting with the given prefix.
    /// Keys are expected to look like "<prefix>.<index>.…"; returns max index + 1.
    pub fn countLayers(self: BufferStore, prefix: []const u8) usize {
        // Note: This is kinda inefficient
        const digit_start_index = prefix.len + 1;
        var it = self.buffers.iterator();
        var maybe_max_index: ?usize = null;
        while (it.next()) |entry| {
            const key = entry.key_ptr.*;
            if (!std.mem.startsWith(u8, key, prefix)) continue;
            // Fix: a key equal to (or barely longer than) the prefix used to make
            // `indexOfScalarPos` start past the end of the key (assertion failure).
            // Also require an actual '.' separator after the prefix.
            if (key.len < digit_start_index or key[prefix.len] != '.') continue;
            const next_dot_index = std.mem.indexOfScalarPos(u8, key, digit_start_index, '.') orelse key.len;
            const index = std.fmt.parseInt(usize, key[digit_start_index..next_dot_index], 10) catch continue;
            if (maybe_max_index) |*max_index| {
                max_index.* = @max(max_index.*, index);
            } else {
                maybe_max_index = index;
            }
        }
        return if (maybe_max_index) |index| index + 1 else 0;
    }

    /// Returns the metadata stored under `key`, interpreted as the given variant.
    /// Panics when the stored metadata has a different type.
    pub fn metadata(self: BufferStore, key: []const u8, comptime tag: std.meta.FieldEnum(Metadata)) ?std.meta.FieldType(Metadata, tag) {
        const wrapped_value = self._metadata.get(key) orelse return null;
        if (wrapped_value != tag) {
            zml.log.err("Tried to interpret metadata '{s}' as {}, but was of type {}", .{ key, tag, wrapped_value });
            @panic("invalid metadata type");
        }
        return @field(wrapped_value, @tagName(tag));
    }

    /// Returns the array metadata stored under `key` with the given element kind,
    /// or null when missing or of another type.
    pub fn metadataSlice(self: BufferStore, key: []const u8, comptime tag: Metadata.ItemType) ?[]const tag.toZigType() {
        const wrapped_value = self._metadata.get(key) orelse return null;
        // Map the ItemType (e.g. .int) onto the corresponding array variant (.array_int).
        const true_tag = std.meta.stringToEnum(std.meta.FieldEnum(Metadata), @tagName(tag)).?;
        if (wrapped_value == true_tag) {
            return @field(wrapped_value, "array_" ++ @tagName(tag));
        }
        return null;
    }

    /// Assists in debugging `BufferNotFound` errors.
    /// This is useful when a buffer key is not found and you want to identify possible alternatives (or typos).
    pub fn findSimilarBufferKeys(store: BufferStore, tmp_alloc: std.mem.Allocator, original_key: []const u8) void {
        // Fix: the list used to start with a blank suffix, which `endsWith` always
        // matched first, so ".weight"/".bias" were never actually stripped.
        const suffixes = [_][]const u8{ ".weight", ".bias" };
        var shown_keys = std.StringHashMap(void).init(tmp_alloc);
        defer shown_keys.deinit();

        // remove suffix .weight and .bias
        var base_key = original_key;
        for (suffixes) |suffix| {
            if (std.mem.endsWith(u8, original_key, suffix)) {
                base_key = original_key[0 .. original_key.len - suffix.len];
                break;
            }
        }

        // first test: look for exact matches
        var matches: usize = 0;
        var it = store.buffers.iterator();
        while (it.next()) |entry| {
            const key = entry.key_ptr.*;
            if (std.mem.startsWith(u8, key, base_key)) {
                if (matches == 0) log.warn("Similar buffers found:", .{});
                if (!shown_keys.contains(key)) {
                    log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() });
                    shown_keys.put(key, {}) catch continue;
                    matches += 1;
                }
            }
        }

        // second test: progressive partial matches
        if (matches == 0) {
            var components = std.mem.splitScalar(u8, base_key, '.');
            while (components.next()) |component| {
                matches = 0;
                it = store.buffers.iterator();
                while (it.next()) |entry| {
                    const key = entry.key_ptr.*;
                    if (std.mem.indexOf(u8, key, component) != null and !shown_keys.contains(key)) {
                        if (matches == 0) log.warn("Partial matches for '{s}':", .{component});
                        log.warn(" - {s}: {f}", .{ key, entry.value_ptr.*.shape() });
                        shown_keys.put(key, {}) catch continue;
                        matches += 1;
                        // Cap the output at 5 keys per component.
                        if (matches >= 5) break;
                    }
                }
                if (matches > 0) break;
            }
        }
    }
};
2023-01-02 14:28:25 +00:00
2023-04-07 16:45:58 +00:00
/// A single metadata value (scalar or array) attached to a model file.
pub const Metadata = union(enum) {
    null: void,
    int: i64,
    float: f64,
    bool: bool,
    string: []const u8,
    array_bool: []const bool,
    array_int: []const i64,
    array_float: []const f64,
    array_string: []const []const u8,

    /// Scalar element kinds supported inside metadata arrays.
    pub const ItemType = enum {
        int,
        float,
        bool,
        string,

        /// Maps an item kind to the Zig type used to store it.
        pub fn toZigType(comptime kind: ItemType) type {
            return switch (kind) {
                .int => i64,
                .float => f64,
                .bool => bool,
                .string => []const u8,
            };
        }
    };

    /// Wraps a scalar Zig value, widening integers to i64 and floats to f64.
    pub fn wrap(x: anytype) Metadata {
        return switch (@TypeOf(x)) {
            inline u8, i8, u16, i16, u32, i32, u64, i64 => .{ .int = @intCast(x) },
            inline f16, f32, f64 => .{ .float = @floatCast(x) },
            bool => .{ .bool = x },
            []const u8 => .{ .string = x },
            else => @panic("Unsupported type for zml.aio.Value: " ++ @typeName(@TypeOf(x))),
        };
    }

    /// Copies a slice of scalars into an owned array variant, widening as in `wrap`.
    /// Caller owns the returned memory (through `allocator`).
    pub fn copySlice(allocator: std.mem.Allocator, any_slice: anytype) !Metadata {
        return switch (@TypeOf(any_slice[0])) {
            inline u8, i8, u16, i16, u32, i32, u64, i64 => {
                const widened = try allocator.alloc(i64, any_slice.len);
                for (widened, any_slice) |*dst, src| dst.* = @intCast(src);
                return .{ .array_int = widened };
            },
            inline f16, f32, f64 => {
                const widened = try allocator.alloc(f64, any_slice.len);
                for (widened, any_slice) |*dst, src| dst.* = @floatCast(src);
                return .{ .array_float = widened };
            },
            bool => .{ .array_bool = try allocator.dupe(bool, any_slice) },
            []const u8 => .{ .array_string = try allocator.dupe([]const u8, @alignCast(any_slice)) },
            else => @panic("Unsupported type for zml.aio.Value: " ++ @typeName(@TypeOf(any_slice))),
        };
    }

    pub fn format(
        self: Metadata,
        comptime fmt: []const u8,
        options: std.fmt.FormatOptions,
        writer: anytype,
    ) !void {
        _ = fmt;
        _ = options;
        switch (self) {
            .null => _ = try writer.write("null"),
            inline .bool, .array_bool => |b| try writer.print("{any}", .{b}),
            inline else => |v| try writer.print("{d}", .{v}),
        }
    }
};
/// A file containing contiguous/non-contiguous buffers, that can be read with mmap
/// (assumes contiguous if `strides` is `null`).
/// This struct is meant to be wrapped into a format specific struct, like io.gguf.File.
pub const MemoryMappedFile = struct {
    /// underlying file handle
    file: asynk.File,
    /// The whole file, mapped read-only.
    data: []align(std.heap.page_size_min) const u8,
    /// Offset applied by `mappedSlice` (e.g. to skip a header).
    data_offset: u64 = 0,

    /// Maps the whole file into memory (read-only, private mapping) and
    /// advises the kernel that access will be sequential.
    pub fn init(file: asynk.File) !MemoryMappedFile {
        const file_len: usize = (try file.stat()).size;
        const mapped = try asynk.callBlocking(std.posix.mmap, .{
            null,
            file_len,
            std.posix.PROT.READ,
            std.posix.system.MAP{ .TYPE = .PRIVATE },
            file.handle(),
            0,
        });
        try asynk.callBlocking(posix.madvise, .{
            mapped.ptr,
            @as(usize, @intCast(mapped.len)),
            @as(u32, @intCast(c.MADV_SEQUENTIAL)),
        });
        return .{
            .file = file,
            .data = mapped,
        };
    }

    /// Returns `len` bytes starting `start` bytes after `data_offset`.
    pub fn mappedSlice(self: MemoryMappedFile, start: usize, len: usize) []const u8 {
        return self.data[self.data_offset + start ..][0..len];
    }

    pub fn deinit(self: *MemoryMappedFile) void {
        std.posix.munmap(self.data);
        self.file.close() catch @panic("failed to close file");
        self.* = undefined;
    }
};
/// Helper handling prefix building.
///
/// This allows to easily push/pop prefixes and handles the generation of the string with the correct format.
pub const PrefixBuilder = struct {
    /// Stores the computed prefix.
    data: std.ArrayList(u8) = .{},
    /// Stack storing the size of the intermediary prefix.
    subprefixes: std.ArrayList(u32) = .{},

    pub fn initCapacity(allocator: std.mem.Allocator, capacity: usize) !PrefixBuilder {
        return .{
            .data = try .initCapacity(allocator, capacity),
            // Heuristic: one stack entry for roughly every 4 prefix bytes.
            .subprefixes = try .initCapacity(allocator, @divFloor(capacity, 4)),
        };
    }

    pub fn deinit(self: *PrefixBuilder, allocator: std.mem.Allocator) void {
        self.data.deinit(allocator);
        self.subprefixes.deinit(allocator);
    }

    /// The current dotted prefix.
    pub fn items(self: PrefixBuilder) []const u8 {
        return self.data.items;
    }

    /// Returns the current prefix extended with `prefix`, without allocating
    /// (fails loudly if capacity is insufficient). Valid until the next push.
    pub fn concat(self: *PrefixBuilder, prefix: []const u8) []const u8 {
        self.push(stdx.noalloc, prefix) catch unreachable;
        const res = self.items();
        self.pop();
        return res;
    }

    /// Appends `.prefix` (or just `prefix` when the builder is empty).
    pub fn push(self: *PrefixBuilder, allocator: std.mem.Allocator, prefix: []const u8) !void {
        const old_len: u32 = @intCast(self.data.items.len);
        try self.subprefixes.append(allocator, old_len);
        // Keep data/subprefixes consistent if the append below fails.
        errdefer _ = self.subprefixes.pop();

        if (old_len == 0) {
            try self.data.appendSlice(allocator, prefix);
        } else {
            try self.data.ensureUnusedCapacity(allocator, prefix.len + 1);
            self.data.appendAssumeCapacity('.');
            self.data.appendSliceAssumeCapacity(prefix);
        }
    }

    /// Appends `.idx` (decimal) to the current prefix.
    pub fn pushDigit(self: *PrefixBuilder, allocator: std.mem.Allocator, idx: usize) !void {
        const old_len: u32 = @intCast(self.data.items.len);
        try self.subprefixes.append(allocator, old_len);
        errdefer _ = self.subprefixes.pop();

        // 16 bytes is enough for any usize in decimal plus the separator.
        try self.data.ensureUnusedCapacity(allocator, 16);
        if (old_len > 0) {
            self.data.appendAssumeCapacity('.');
        }
        try self.data.writer(allocator).print("{d}", .{idx});
    }

    /// Undoes the last `push`/`pushDigit`.
    pub fn pop(self: *PrefixBuilder) void {
        const last_prefix_len = self.subprefixes.pop() orelse unreachable;
        self.data.shrinkRetainingCapacity(last_prefix_len);
    }

    /// Snapshot of (prefix length, stack depth) for later `restore`.
    pub fn checkpoint(self: PrefixBuilder) [2]usize {
        return .{ self.data.items.len, self.subprefixes.items.len };
    }

    pub fn restore(self: *PrefixBuilder, ckpt: [2]usize) void {
        self.data.items.len, self.subprefixes.items.len = ckpt;
    }
};
/// Recursively fills `obj` (pointer to a model value) with Tensors read from `store`,
/// using `prefix_builder` to build the dotted key of each field.
/// Returns whether the value could be fully populated.
fn _populateStruct(
    allocator: std.mem.Allocator,
    prefix_builder: *PrefixBuilder,
    store: BufferStore,
    obj: anytype,
    required: bool,
) !bool {
    const err_msg = "_populateStruct must be called with a pointer to type. Received ";
    const type_info, const T = switch (@typeInfo(@TypeOf(obj))) {
        .pointer => |ptr_info| switch (ptr_info.size) {
            .one => .{ @typeInfo(ptr_info.child), ptr_info.child },
            else => @compileError(err_msg ++ @typeName(@TypeOf(obj))),
        },
        else => @compileError(err_msg ++ @typeName(@TypeOf(obj))),
    };
    const prefix = prefix_builder.data.items;

    // Leaf case: the current prefix should name a tensor in the store.
    if (T == zml.Tensor) {
        return if (store.getTensorOrNull(prefix)) |tensor| {
            obj.* = tensor;
            return true;
        } else {
            if (required) {
                log.err("Tensor not found: {s} ({d})", .{ prefix, store.buffers.count() });
            }
            return false;
        };
    }

    return switch (type_info) {
        .pointer => |ptr_info| {
            if (ptr_info.size == .slice) {
                obj.* = &.{};
                // A slice maps to keys "<prefix>.0", "<prefix>.1", ...
                const len = store.countLayers(prefix);
                if (len > 0) {
                    obj.* = try allocator.alloc(ptr_info.child, len);
                    for (obj.*, 0..) |*value, i| {
                        try prefix_builder.pushDigit(allocator, i);
                        defer prefix_builder.pop();
                        const found = try _populateStruct(allocator, prefix_builder, store, value, required);
                        if (!found) {
                            log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(ptr_info.child) });
                            return false;
                        }
                    }
                } else if (required) {
                    log.warn("No layer found at {s}", .{prefix});
                }
                return true;
            } else {
                log.err("{s} - {s}: {s} type not supported", .{ @src().fn_name, prefix, @typeName(T) });
                return false;
            }
        },
        .array => |arr_info| {
            for (obj, 0..) |*value, i| {
                try prefix_builder.pushDigit(allocator, i);
                defer prefix_builder.pop();
                const found = try _populateStruct(allocator, prefix_builder, store, value, required);
                if (!found) {
                    log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(arr_info.child) });
                    return false;
                }
            }
            return true;
        },
        .@"struct" => |struct_info| {
            var partial_struct = false;
            inline for (struct_info.fields) |field| {
                if (field.is_comptime or @sizeOf(field.type) == 0) continue;
                try prefix_builder.push(allocator, field.name);
                defer prefix_builder.pop();

                var has_default = false;
                if (field.default_value_ptr) |_| has_default = true;

                // Only recurse into fields that (transitively) contain Tensors.
                if (zml.meta.Contains(field.type, zml.Tensor)) {
                    const field_found = try _populateStruct(allocator, prefix_builder, store, &@field(obj, field.name), required and !has_default);
                    partial_struct = partial_struct or field_found;
                    if (!field_found) {
                        if (field.default_value_ptr) |v| {
                            @field(obj, field.name) = @as(*const field.type, @ptrCast(@alignCast(v))).*;
                        } else {
                            if (partial_struct) {
                                log.warn("Incomplete struct '{0s}': {1s}. Missing field: '{2s}'. '{0s}' will be ignored.", .{ prefix, @typeName(T), field.name });
                                obj.* = undefined;
                                return false;
                            }
                            return false;
                        }
                        // NOTE(review): this also returns false after a default value was
                        // successfully applied above — confirm that is intended.
                        return false;
                    }
                }
            }
            return true;
        },
        .optional => |opt_info| {
            // Try to populate the payload; fall back to null when absent.
            obj.* = @as(opt_info.child, undefined);
            const found = try _populateStruct(allocator, prefix_builder, store, &(obj.*.?), false);
            if (!found) obj.* = null;
            return true;
        },
        .int => {
            // Plain integers carry no tensor data; left for Model.init to fill.
            obj.* = undefined;
            return true;
        },
        .float => {
            // NaN makes accidental use of an uninitialized hyperparameter visible.
            obj.* = std.math.nan(@TypeOf(obj.*));
            return true;
        },
        .void => true,
        .@"union" => true,
        .bool => {
            obj.* = undefined;
            return true;
        },
        else => if (required) {
            log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) });
            return error.UnsupportedMetadataType;
        } else return false,
    };
}
test populateModel {
    const Model = struct {
        a: zml.Tensor,
        b: struct { a: zml.Tensor, b: u32 },
        c: []zml.Tensor,
        d: []struct { a: zml.Tensor, b: u32 },
        e: struct { zml.Tensor, u32, struct { a: u32, b: zml.Tensor, c: void } },
        f: ?zml.Tensor,
        g: ?zml.Tensor,

        // Create a fake HostBuffer, we use the given integer to identify the created buffer.
        fn _newHostBuffer(n: u32) zml.HostBuffer {
            return .{ ._shape = zml.Shape.init(.{n}, .f16), ._strides = undefined, ._data = undefined };
        }
    };

    var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator);
    defer arena_state.deinit();
    var store: BufferStore = .init(arena_state.allocator());

    try store.buffers.ensureUnusedCapacity(arena_state.allocator(), 16);
    store.buffers.putAssumeCapacity("a", Model._newHostBuffer(10));
    store.buffers.putAssumeCapacity("b.a", Model._newHostBuffer(20));
    store.buffers.putAssumeCapacity("c.0", Model._newHostBuffer(30));
    store.buffers.putAssumeCapacity("c.1", Model._newHostBuffer(31));
    store.buffers.putAssumeCapacity("c.2", Model._newHostBuffer(32));
    store.buffers.putAssumeCapacity("d.0.a", Model._newHostBuffer(40));
    store.buffers.putAssumeCapacity("d.1.a", Model._newHostBuffer(41));
    store.buffers.putAssumeCapacity("d.2.a", Model._newHostBuffer(42));
    store.buffers.putAssumeCapacity("e.0", Model._newHostBuffer(50));
    store.buffers.putAssumeCapacity("e.2.b", Model._newHostBuffer(51));
    store.buffers.putAssumeCapacity("f", Model._newHostBuffer(60));
    // no entry for g.
    store.buffers.putAssumeCapacity("unused_entry", Model._newHostBuffer(1000));

    const model = try populateModel(Model, arena_state.allocator(), store);

    try std.testing.expectEqual(10, model.a.dim(0));
    try std.testing.expectEqual(20, model.b.a.dim(0));
    try std.testing.expectEqual(3, model.c.len);
    try std.testing.expectEqual(30, model.c[0].dim(0));
    try std.testing.expectEqual(31, model.c[1].dim(0));
    try std.testing.expectEqual(32, model.c[2].dim(0));
    try std.testing.expectEqual(3, model.d.len);
    try std.testing.expectEqual(40, model.d[0].a.dim(0));
    try std.testing.expectEqual(41, model.d[1].a.dim(0));
    try std.testing.expectEqual(42, model.d[2].a.dim(0));
    try std.testing.expectEqual(50, model.e[0].dim(0));
    try std.testing.expectEqual(51, model.e[2].b.dim(0));
    try std.testing.expectEqual(60, model.f.?.dim(0));
    try std.testing.expectEqual(null, model.g);
}
/// Creates a bufferized version of a Model from the given BufferStore. For details about
/// bufferization, see the documentation of Bufferized(T).
///
/// This will represent the weights of the model, loaded on a specific platform.
/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a
/// `module.ExeWithWeights` ready to be called.
///
/// The `init_args` are used to initialize the non Buffer fields, using `Model.init` function.
pub fn loadBuffers(
    comptime Model: type,
    init_args: if (@hasDecl(Model, "init")) stdx.meta.Tail(stdx.meta.FnArgs(Model.init)) else void,
    store: BufferStore,
    allocator: std.mem.Allocator,
    platform: zml.Platform,
) !zml.Bufferized(Model) {
    // Delegate to the prefixed variant with an empty prefix.
    return loadBuffersWithPrefix(Model, init_args, store, allocator, platform, "");
}
/// Creates a bufferized version of a Model from the given BufferStore with a specified prefix.
/// For details about bufferization, see the documentation of Bufferized(T).
///
/// This will represent the weights of the model, loaded on a specific platform.
/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a
/// `module.ExeWithWeights` ready to be called.
///
/// The `init_args` are used to initialize the non Buffer fields, using `Model.init` function.
pub fn loadBuffersWithPrefix(
    comptime Model: type,
    init_args: if (@hasDecl(Model, "init")) stdx.meta.Tail(stdx.meta.FnArgs(Model.init)) else void,
    store: BufferStore,
    allocator: std.mem.Allocator,
    platform: zml.Platform,
    prefix: []const u8,
) !zml.Bufferized(Model) {
    // The populated Model only lives for the duration of this call;
    // the returned Bufferized struct is allocated with `allocator`.
    var arena_state = std.heap.ArenaAllocator.init(allocator);
    defer arena_state.deinit();
    const arena = arena_state.allocator();

    // Get model structure with tensor shapes from the buffer store with prefix
    var model: Model = try zml.aio.populateModelWithPrefix(Model, arena, store, prefix);

    // If the Model has a "init" function, call it with the given parameters.
    if (@hasDecl(Model, "init")) {
        @call(.auto, Model.init, .{&model} ++ init_args);
    }

    return loadModelBuffersWithPrefix(Model, model, store, allocator, platform, prefix);
}
/// Creates a bufferized version of a Model from the given BufferStore. For details about
/// bufferization, see the documentation of Bufferized(T).
///
/// This will represent the weights of the model, loaded on a specific platform.
/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a
/// `module.ExeWithWeights` ready to be called.
pub fn loadModelBuffers(
    comptime Model: type,
    model: Model,
    store: BufferStore,
    allocator: std.mem.Allocator,
    platform: zml.Platform,
) !zml.Bufferized(Model) {
    // Delegate to the prefixed variant with an empty prefix.
    return try loadModelBuffersWithPrefix(Model, model, store, allocator, platform, "");
}
2023-04-07 16:45:58 +00:00
2023-02-24 17:33:14 +00:00
/// Creates a bufferized version of a Model from the given BufferStore and the given prefix.
/// For details about bufferization, see the documentation of Bufferized(T).
///
/// This will represent the weights of the model, loaded on a specific platform.
/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a
/// `module.ExeWithWeights` ready to be called.
pub fn loadModelBuffersWithPrefix(
    comptime Model: type,
    model: Model,
    store: BufferStore,
    allocator: std.mem.Allocator,
    platform: zml.Platform,
    prefix: []const u8,
) !zml.Bufferized(Model) {
    // Allocate the bufferized version.
    // We copy the shape, and let visitStructAndLoadBuffer write the other fields.
    var res: zml.Bufferized(Model) = undefined;
    try zml.meta.mapAlloc(struct {
        pub fn initBuffer(_: void, tensor: zml.Tensor) zml.Buffer {
            return .{ ._shape = tensor.shape(), ._api = undefined, ._shards = undefined };
        }
    }.initBuffer, allocator, {}, model, &res);

    var prefix_builder: PrefixBuilder = .{};
    try prefix_builder.push(allocator, prefix);
    defer prefix_builder.deinit(allocator);

    try visitStructAndLoadBuffer(allocator, &prefix_builder, store, &res, platform);
    return res;
}
/// Takes a bufferized version of a `model`, ie a mirror struct of the `model`, and deinit all the
/// Buffer found.
pub fn unloadBuffers(model: anytype) void {
    // Visit every zml.Buffer leaf in the struct tree and release it.
    zml.meta.visit((struct {
        fn cb(_: void, buffer: *zml.Buffer) void {
            buffer.deinit();
        }
    }).cb, {}, model);
}
2024-05-02 17:10:11 +00:00
/// Waits for every `zml.Buffer` in the given struct to be ready, replacing each
/// buffer with its awaited value.
/// (The previous doc comment — "deinit all buffers" — was copied from `unloadBuffers`
/// and did not describe this function.)
pub fn awaitAll(buffers: anytype) !void {
    zml.meta.visit((struct {
        fn cb(_: void, buffer: *zml.Buffer) void {
            buffer.* = buffer.awaitt() catch unreachable;
        }
    }).cb, {}, buffers);
}
2025-08-28 14:39:21 +00:00
/// Recursively walks `obj` (pointer into a Bufferized model) and loads every
/// `zml.Buffer` leaf from `store` onto `platform`, using `prefix_builder` to
/// build the dotted key of each field.
fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, store: BufferStore, obj: anytype, platform: zml.Platform) !void {
    const err_msg = "visitStructAndLoadBuffer must be called with a pointer to type. Received ";
    const type_info, const T = switch (@typeInfo(@TypeOf(obj))) {
        .pointer => |ptr_info| switch (ptr_info.size) {
            .one => .{ @typeInfo(ptr_info.child), ptr_info.child },
            else => @compileError(err_msg ++ @typeName(@TypeOf(obj))),
        },
        else => @compileError(err_msg ++ @typeName(@TypeOf(obj))),
    };
    const prefix = prefix_builder.data.items;

    // Leaf case: the current prefix should name a host buffer in the store.
    if (T == zml.Buffer) {
        return if (store.get(prefix)) |host_buffer| {
            // obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us.
            var buf_with_metadata = host_buffer;
            log.debug("Loading buffer {s} ({f})", .{ prefix, obj._shape });
            stdx.debug.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {f} and {f} for tensor {s}", .{ obj._shape, host_buffer, prefix });
            buf_with_metadata._shape = obj._shape;
            obj.* = try zml.Buffer.from(platform, buf_with_metadata, .{});
        } else {
            log.err("Buffer not found: {s}", .{prefix});
            // Help the user spot typos /近-misses in the key.
            store.findSimilarBufferKeys(allocator, prefix);
            return error.BufferNotFound;
        };
    } else if (T == zml.Shape) return;

    switch (type_info) {
        .pointer => |ptr_info| {
            if (ptr_info.size == .slice) {
                for (obj.*, 0..) |*value, i| {
                    try prefix_builder.pushDigit(allocator, i);
                    defer prefix_builder.pop();
                    try visitStructAndLoadBuffer(allocator, prefix_builder, store, value, platform);
                }
            } else stdx.debug.compileError("type not supported by visitStructAndLoadBuffer: {}", .{T});
        },
        .array => {
            for (obj, 0..) |*value, i| {
                try prefix_builder.pushDigit(allocator, i);
                defer prefix_builder.pop();
                try visitStructAndLoadBuffer(allocator, prefix_builder, store, value, platform);
            }
        },
        .@"struct" => |struct_info| {
            inline for (struct_info.fields) |field| {
                if (field.is_comptime or @sizeOf(field.type) == 0) continue;
                try prefix_builder.push(allocator, field.name);
                defer prefix_builder.pop();
                try visitStructAndLoadBuffer(allocator, prefix_builder, store, &@field(obj, field.name), platform);
            }
        },
        .optional => {
            if (obj.*) |*obj_val| {
                try visitStructAndLoadBuffer(allocator, prefix_builder, store, obj_val, platform);
            }
        },
        else => {},
    }
}