2024-10-28 11:21:46 +00:00
const std = @import ( " std " ) ;
2023-06-21 14:45:14 +00:00
const builtin = @import ( " builtin " ) ;
2024-10-28 11:21:46 +00:00
const asynk = @import ( " async " ) ;
2023-06-21 14:45:14 +00:00
const c = @import ( " c " ) ;
const stdx = @import ( " stdx " ) ;
2023-01-02 14:28:25 +00:00
pub const gguf = @import ( " aio/gguf.zig " ) ;
pub const nemo = @import ( " aio/nemo.zig " ) ;
pub const safetensors = @import ( " aio/safetensors.zig " ) ;
2024-04-05 15:07:29 +00:00
pub const tinyllama = @import ( " aio/tinyllama.zig " ) ;
2023-01-02 14:28:25 +00:00
pub const torch = @import ( " aio/torch.zig " ) ;
pub const yaml = @import ( " aio/yaml.zig " ) ;
const HostBuffer = @import ( " hostbuffer.zig " ) . HostBuffer ;
2024-10-28 11:21:46 +00:00
const posix = @import ( " posix.zig " ) ;
const zml = @import ( " zml.zig " ) ;
2023-01-02 14:28:25 +00:00
2024-10-28 11:21:46 +00:00
/// Scoped logger: messages from this module appear under the "zml/aio" log scope.
pub const log = std.log.scoped(.@"zml/aio");
2023-01-27 14:35:11 +00:00
test {
    // Reference every public decl (and each aio submodule) so their nested tests run.
    std.testing.refAllDecls(@This());
    inline for (.{ gguf, nemo, safetensors, torch, yaml }) |submodule| {
        std.testing.refAllDecls(submodule);
    }
}
2024-10-28 11:21:46 +00:00
// TODO error set for weight loading

/// Detects the format of the model file (based on filename) and opens it.
/// Recognized extensions: .safetensors, .safetensors.index.json, .gguf, .pt, .tinyllama.
/// Panics on an unrecognized extension.
pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8) !BufferStore {
    const endsWith = std.mem.endsWith;
    // Both the single-file and the sharded-index safetensors layouts are handled by
    // safetensors.open, so the two branches are merged.
    return if (endsWith(u8, model_path, ".safetensors") or endsWith(u8, model_path, ".safetensors.index.json"))
        try safetensors.open(allocator, model_path)
    else if (endsWith(u8, model_path, ".gguf"))
        try gguf.open(allocator, model_path)
    else if (endsWith(u8, model_path, ".pt"))
        try torch.open(allocator, model_path)
    else if (endsWith(u8, model_path, ".tinyllama"))
        try tinyllama.open(allocator, model_path)
    else {
        std.debug.panic("File extension not recognized: {s}", .{model_path});
    };
}
/// Creates a Model struct with tensor shapes read from the given BufferStore.
/// The result can be used to pass to `compileModel`.
///
/// * The `Tensor` field `Model.a.b` will be populated with a `Tensor`
///   whose shape is read from the "a.b" tensor.
/// * If `Model` contains a list of layers, then the field:
///   `Model.layers[2].a.b` will be populated from the "layers.2.a.b" tensor.
pub fn populateModel(comptime Model: type, allocator: std.mem.Allocator, buffer_store: BufferStore) !Model {
    // Delegates to the prefixed variant with an empty prefix (tensor names are used as-is).
    return populateModelWithPrefix(Model, allocator, buffer_store, "");
}
/// Creates a Model struct with tensor shapes read from the given TensorStore,
/// using a given prefix.
/// The result can be used to pass to `compileWithModel`.
///
/// * The `Tensor` field `Model.a.b` will be populated with a `Tensor`
///   whose shape is read from the "prefix.a.b" tensor.
/// * If `Model` contains a list of layers, then the field:
///   `Model.layers[2].a.b` will be populated from the "prefix.layers.2.a.b" tensor.
///
/// Returns error.TensorNotFound if a required tensor is missing from the store;
/// panics if population fails for any other reason (e.g. out of memory).
pub fn populateModelWithPrefix(comptime Model: type, allocator: std.mem.Allocator, store: BufferStore, prefix: []const u8) !Model {
    var model: Model = undefined;
    var prefix_builder: PrefixBuilder = .{};
    try prefix_builder.push(allocator, prefix);
    defer prefix_builder.deinit(allocator);

    // Reserve one id per stored buffer so each populated Tensor gets a unique, stable id.
    const unique_id = zml.Tensor._reserveIdRange(@intCast(store.buffers.count()));

    const ok = _populateStruct(allocator, &prefix_builder, unique_id, store, &model, true) catch |err| {
        // Fix: report the actual model type; `@typeName(type)` printed the literal string "type".
        std.debug.panic("Can't populate model of type {s}: {s}", .{ @typeName(Model), @errorName(err) });
    };
    if (!ok) return error.TensorNotFound;
    return model;
}
/// A struct containing all the buffers and metadata found in a model file.
pub const BufferStore = struct {
    pub const Buffers = std.StringArrayHashMapUnmanaged(HostBuffer);
    pub const Metadatas = std.StringArrayHashMapUnmanaged(Metadata);

    /// Owns all memory for keys, buffers and metadata values.
    arena: std.heap.ArenaAllocator,
    /// Memory-mapped files backing the buffers; owned by this store.
    files: []MemoryMappedFile = &.{},
    /// Tensor name -> host buffer view into the mapped files.
    buffers: Buffers = .{},
    _metadata: Metadatas = .{},

    /// Create an empty BufferStore. Takes ownership of the given files.
    pub fn init(allocator: std.mem.Allocator, files: []const MemoryMappedFile) error{OutOfMemory}!BufferStore {
        var self: zml.aio.BufferStore = .{
            .arena = std.heap.ArenaAllocator.init(allocator),
        };
        self.files = try self.arena.allocator().dupe(MemoryMappedFile, files);
        return self;
    }

    /// Unmaps/closes all owned files and frees the arena.
    pub fn deinit(self: BufferStore) void {
        for (self.files) |*file| {
            file.deinit();
        }
        self.arena.deinit();
    }

    /// Returns the host buffer registered under `key`, if any.
    pub fn get(self: BufferStore, key: []const u8) ?HostBuffer {
        return self.buffers.get(key);
    }

    /// Count layers starting with the given prefix.
    /// Layer keys look like "<prefix>.<index>[.…]"; returns max index + 1, or 0 if none found.
    pub fn countLayers(self: BufferStore, prefix: []const u8) usize {
        // Note: This is kinda inefficient (full scan of all keys).
        // Skip the '.' separating the prefix from the layer index.
        const digit_start_index = prefix.len + 1;
        var it = self.buffers.iterator();
        var maybe_max_index: ?usize = null;
        while (it.next()) |entry| {
            if (!std.mem.startsWith(u8, entry.key_ptr.*, prefix)) continue;
            const next_dot_index = std.mem.indexOfScalarPos(u8, entry.key_ptr.*, digit_start_index, '.') orelse entry.key_ptr.len;
            // Keys whose segment after the prefix isn't a number are ignored.
            const index = std.fmt.parseInt(usize, entry.key_ptr.*[digit_start_index..next_dot_index], 10) catch continue;
            if (maybe_max_index) |*max_index| {
                max_index.* = @max(max_index.*, index);
            } else {
                maybe_max_index = index;
            }
        }
        return if (maybe_max_index) |index| index + 1 else 0;
    }

    /// Returns the metadata value stored under `key`, interpreted as the given scalar tag.
    /// Panics if the stored value is of a different type.
    pub fn metadata(self: BufferStore, key: []const u8, comptime tag: std.meta.FieldEnum(Metadata)) ?std.meta.FieldType(Metadata, tag) {
        const wrapped_value = self._metadata.get(key) orelse return null;
        if (wrapped_value != tag) {
            zml.log.err("Tried to interpret metadata '{s}' as {}, but was of type {}", .{ key, tag, wrapped_value });
            @panic("invalid metadata type");
        }
        return @field(wrapped_value, @tagName(tag));
    }

    /// Returns the metadata array stored under `key` as a typed slice,
    /// or null if missing or of a different element type.
    pub fn metadataSlice(self: BufferStore, key: []const u8, comptime tag: Metadata.ItemType) ?[]const tag.toZigType() {
        const wrapped_value = self._metadata.get(key) orelse return null;
        // Fix: resolve the *array* field tag (.array_int for .int, …). The previous code
        // compared against the scalar tag, so the comparison could never match a stored
        // array, and a scalar match would have read an inactive union field.
        const true_tag = std.meta.stringToEnum(std.meta.FieldEnum(Metadata), "array_" ++ @tagName(tag)).?;
        if (wrapped_value == true_tag) {
            return @field(wrapped_value, "array_" ++ @tagName(tag));
        }
        return null;
    }
};
2023-01-02 14:28:25 +00:00
2023-04-07 16:45:58 +00:00
/// A dynamically-typed metadata value attached to a model file
/// (e.g. a gguf or safetensors header key/value pair).
pub const Metadata = union(enum) {
    null: void,
    int: i64,
    float: f64,
    bool: bool,
    string: []const u8,
    array_bool: []const bool,
    array_int: []const i64,
    array_float: []const f64,
    array_string: []const []const u8,

    /// Scalar item types that can appear in a metadata array.
    pub const ItemType = enum {
        int,
        float,
        bool,
        string,

        /// Maps an ItemType tag to the Zig type used to store it.
        pub fn toZigType(comptime kind: ItemType) type {
            return switch (kind) {
                .int => i64,
                .float => f64,
                .bool => bool,
                .string => []const u8,
            };
        }
    };

    /// Wraps a scalar Zig value into a Metadata, widening ints to i64 and floats to f64.
    /// Panics at comptime on unsupported types.
    pub fn wrap(x: anytype) Metadata {
        return switch (@TypeOf(x)) {
            inline u8, i8, u16, i16, u32, i32, u64, i64 => .{ .int = @intCast(x) },
            inline f16, f32, f64 => .{ .float = @floatCast(x) },
            bool => .{ .bool = x },
            []const u8 => .{ .string = x },
            else => @panic("Unsupported type for zml.aio.Value: " ++ @typeName(@TypeOf(x))),
        };
    }

    /// Copies `any_slice` into a freshly allocated, widened array Metadata.
    /// The caller's allocator owns the result.
    pub fn copySlice(allocator: std.mem.Allocator, any_slice: anytype) !Metadata {
        return switch (@TypeOf(any_slice[0])) {
            inline u8, i8, u16, i16, u32, i32, u64, i64 => {
                const res = try allocator.alloc(i64, any_slice.len);
                for (res, any_slice) |*r, val| r.* = @intCast(val);
                return .{ .array_int = res };
            },
            inline f16, f32, f64 => {
                const res = try allocator.alloc(f64, any_slice.len);
                for (res, any_slice) |*r, val| r.* = @floatCast(val);
                return .{ .array_float = res };
            },
            bool => .{ .array_bool = try allocator.dupe(bool, any_slice) },
            []const u8 => .{ .array_string = try allocator.dupe([]const u8, @alignCast(any_slice)) },
            else => @panic("Unsupported type for zml.aio.Value: " ++ @typeName(@TypeOf(any_slice))),
        };
    }

    /// Human-readable formatting for log messages.
    pub fn format(
        self: Metadata,
        comptime fmt: []const u8,
        options: std.fmt.FormatOptions,
        writer: anytype,
    ) !void {
        _ = fmt;
        _ = options;
        switch (self) {
            .null => _ = try writer.write("null"),
            // Fix: print strings as text. Previously they fell into the `{d}` catch-all,
            // which rendered them as lists of decimal byte values.
            .string => |s| try writer.print("{s}", .{s}),
            .array_string => |strs| {
                _ = try writer.write("{ ");
                for (strs, 0..) |s, i| {
                    if (i > 0) _ = try writer.write(", ");
                    try writer.print("\"{s}\"", .{s});
                }
                _ = try writer.write(" }");
            },
            inline .bool, .array_bool => |b| try writer.print("{any}", .{b}),
            inline else => |v| try writer.print("{d}", .{v}),
        }
    }
};
/// A file containing contiguous/non-contiguous buffers, that can be read with mmap
/// (assumes contiguous if `strides` is `null`).
/// This struct is meant to be wrapped into a format specific struct, like io.gguf.File.
pub const MemoryMappedFile = struct {
    /// underlying file handle
    file: asynk.File,
    /// The whole file, memory-mapped and page-aligned.
    data: []align(std.heap.page_size_min) const u8,
    /// Offset of the payload inside `data`; `mappedSlice` indexes relative to it.
    data_offset: u64 = 0,

    /// Maps the entire file read-only as a private mapping and advises the kernel
    /// that access will be sequential. Takes ownership of `file` (closed in `deinit`).
    pub fn init(file: asynk.File) !MemoryMappedFile {
        const data_len: usize = (try file.stat()).size;
        // mmap is a blocking syscall: route it through the blocking-call executor.
        const data_ = try asynk.callBlocking(std.posix.mmap, .{
            null,
            data_len,
            std.posix.PROT.READ,
            std.posix.system.MAP{ .TYPE = .PRIVATE },
            file.handle(),
            0,
        });
        // Hint the kernel we will read the mapping front-to-back (enables readahead).
        try asynk.callBlocking(posix.madvise, .{
            data_.ptr,
            @as(usize, @intCast(data_.len)),
            @as(u32, @intCast(c.MADV_SEQUENTIAL)),
        });
        return .{
            .file = file,
            .data = data_,
        };
    }

    /// Returns `len` bytes starting at `start`, offset by `data_offset`.
    /// No bounds checking beyond the slice operation itself.
    pub fn mappedSlice(self: MemoryMappedFile, start: usize, len: usize) []const u8 {
        return self.data[self.data_offset + start ..][0..len];
    }

    /// Unmaps the file and closes the handle.
    /// Deallocation must succeed, so a close failure panics.
    pub fn deinit(self: *MemoryMappedFile) void {
        std.posix.munmap(self.data);
        self.file.close() catch @panic("failed to close file");
        self.* = undefined;
    }
};
/// Helper handling prefix building.
///
/// Pushing appends a dot-separated component to the current prefix string and
/// popping restores the previous one, producing names like "layers.2.a.b".
const PrefixBuilder = struct {
    /// Stores the computed prefix.
    data: std.ArrayListUnmanaged(u8) = .{},
    /// Stack of prefix lengths, one checkpoint per push.
    subprefixes: std.ArrayListUnmanaged(u32) = .{},

    pub fn deinit(self: *PrefixBuilder, allocator: std.mem.Allocator) void {
        self.data.deinit(allocator);
        self.subprefixes.deinit(allocator);
    }

    /// Appends `prefix` as a new dotted component; undo with `pop`.
    pub fn push(self: *PrefixBuilder, allocator: std.mem.Allocator, prefix: []const u8) !void {
        const checkpoint: u32 = @intCast(self.data.items.len);
        try self.subprefixes.append(allocator, checkpoint);
        errdefer _ = self.subprefixes.pop();

        if (checkpoint != 0) {
            // Reserve '.' + component up front so the appends below cannot fail.
            try self.data.ensureUnusedCapacity(allocator, prefix.len + 1);
            self.data.appendAssumeCapacity('.');
            self.data.appendSliceAssumeCapacity(prefix);
        } else {
            // First component: no leading dot.
            try self.data.appendSlice(allocator, prefix);
        }
    }

    /// Appends a decimal index component (e.g. "layers" -> "layers.3"); undo with `pop`.
    pub fn pushDigit(self: *PrefixBuilder, allocator: std.mem.Allocator, idx: usize) !void {
        const checkpoint: u32 = @intCast(self.data.items.len);
        try self.subprefixes.append(allocator, checkpoint);
        errdefer _ = self.subprefixes.pop();

        try self.data.ensureUnusedCapacity(allocator, 16);
        if (checkpoint > 0) self.data.appendAssumeCapacity('.');
        try self.data.writer(allocator).print("{d}", .{idx});
    }

    /// Restores the prefix to its state before the matching `push`/`pushDigit`.
    pub fn pop(self: *PrefixBuilder) void {
        const checkpoint = self.subprefixes.pop() orelse unreachable;
        self.data.shrinkRetainingCapacity(checkpoint);
    }
};
/// Recursively fills `obj` (a pointer to a Model or one of its sub-fields) with
/// `zml.Tensor`s whose shapes are read from `buffer_store`, keyed by the dotted
/// path accumulated in `prefix_builder`.
///
/// Returns true when `obj` was fully populated, false when a required tensor is
/// missing (details are logged). `unique_id` is the base of the id range reserved
/// for this store: each tensor's id is `unique_id + its index in the store`.
fn _populateStruct(
    allocator: std.mem.Allocator,
    prefix_builder: *PrefixBuilder,
    unique_id: u64,
    buffer_store: BufferStore,
    obj: anytype,
    required: bool,
) !bool {
    const err_msg = "_populateStruct must be called with a pointer to type. Received ";
    // Only single-item pointers are accepted; everything else is a comptime error.
    const type_info, const T = switch (@typeInfo(@TypeOf(obj))) {
        .pointer => |ptr_info| switch (ptr_info.size) {
            .one => .{ @typeInfo(ptr_info.child), ptr_info.child },
            else => @compileError(err_msg ++ @typeName(@TypeOf(obj))),
        },
        else => @compileError(err_msg ++ @typeName(@TypeOf(obj))),
    };
    const prefix = prefix_builder.data.items;
    // Leaf case: a Tensor is looked up directly by its full dotted name.
    if (T == zml.Tensor) {
        return if (buffer_store.buffers.getIndex(prefix)) |entry_idx| {
            const buffer = buffer_store.get(prefix).?;
            obj.* = zml.Tensor{
                ._shape = buffer.shape(),
                // The store index makes the id unique within the reserved range.
                ._id = .{ .buffer_id = unique_id + entry_idx },
                ._donation = .input_buffer,
            };
            return true;
        } else {
            if (required) {
                log.err("Tensor not found: {s} ({d})", .{ prefix, buffer_store.buffers.count() });
            }
            return false;
        };
    }
    return switch (type_info) {
        .pointer => |ptr_info| {
            if (ptr_info.size == .slice) {
                // Slices model repeated layers: "prefix.0", "prefix.1", …
                obj.* = &.{};
                const len = buffer_store.countLayers(prefix);
                if (len > 0) {
                    obj.* = try allocator.alloc(ptr_info.child, len);
                    for (obj.*, 0..) |*value, i| {
                        try prefix_builder.pushDigit(allocator, i);
                        defer prefix_builder.pop();
                        const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required);
                        if (!found) {
                            log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(ptr_info.child) });
                            return false;
                        }
                    }
                } else if (required) {
                    log.warn("No layer found at {s}", .{prefix});
                }
                return true;
            } else {
                // Non-slice pointers (C pointers, multi-item, …) are not supported.
                log.err("{s} - {s}: {s} type not supported", .{ @src().fn_name, prefix, @typeName(T) });
                return false;
            }
        },
        .array => |arr_info| {
            // Fixed-size arrays use the same numeric sub-keys as slices.
            for (obj, 0..) |*value, i| {
                try prefix_builder.pushDigit(allocator, i);
                defer prefix_builder.pop();
                const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required);
                if (!found) {
                    log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(arr_info.child) });
                    return false;
                }
            }
            return true;
        },
        .@"struct" => |struct_info| {
            // Tracks whether at least one earlier field was found, to diagnose
            // partially-present structs.
            var partial_struct = false;
            inline for (struct_info.fields) |field| {
                if (field.is_comptime or @sizeOf(field.type) == 0) continue;
                try prefix_builder.push(allocator, field.name);
                defer prefix_builder.pop();
                var has_default = false;
                if (field.default_value_ptr) |_| has_default = true;
                // Fields with a default value are never "required": the default is used instead.
                const field_found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &@field(obj, field.name), required and !has_default);
                partial_struct = partial_struct or field_found;
                if (!field_found) {
                    if (field.default_value_ptr) |v| {
                        // Copy the comptime default value into the runtime field.
                        @field(obj, field.name) = @as(*const field.type, @alignCast(@ptrCast(v))).*;
                    } else {
                        if (partial_struct) {
                            log.warn("Incomplete metadata '{0s}': {1s}. Missing field: '{2s}'. '{0s}' will be ignored.", .{ prefix, @typeName(T), field.name });
                            obj.* = undefined;
                            return false;
                        }
                        return false;
                    }
                }
            }
            return true;
        },
        .optional => |opt_info| {
            // Optionals are attempted as non-required; a miss simply yields null.
            obj.* = @as(opt_info.child, undefined);
            const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &(obj.*.?), false);
            if (!found) obj.* = null;
            return true;
        },
        .int => {
            // Plain ints are hyperparameters, not tensors: left for Model.init to set.
            obj.* = undefined;
            return true;
        },
        .float => {
            // NaN-poisoned so an unset hyperparameter is noticeable.
            obj.* = std.math.nan(@TypeOf(obj.*));
            return true;
        },
        .void => true,
        .@"union" => true,
        .bool => {
            obj.* = undefined;
            return true;
        },
        else => if (required) {
            log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) });
            return error.UnsupportedMetadataType;
        } else return false,
    };
}
test populateModel {
    const Model = struct {
        a: zml.Tensor,
        b: struct { a: zml.Tensor, b: u32 },
        c: []zml.Tensor,
        d: []struct { a: zml.Tensor, b: u32 },
        e: struct { zml.Tensor, u32, struct { a: u32, b: zml.Tensor, c: void } },
        f: ?zml.Tensor,
        g: ?zml.Tensor,

        // Create a fake HostBuffer; the given integer identifies it through dim(0).
        fn _newHostBuffer(n: u32) zml.HostBuffer {
            return .{ ._shape = zml.Shape.init(.{n}, .f16), ._strides = undefined, ._data = undefined };
        }
    };

    var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator);
    defer arena_state.deinit();
    var store: BufferStore = .{ .arena = arena_state };
    try store.buffers.ensureUnusedCapacity(arena_state.allocator(), 16);

    // Table of (key, identifying dim) pairs registered in the store.
    const fixtures = [_]struct { []const u8, u32 }{
        .{ "a", 10 },
        .{ "b.a", 20 },
        .{ "c.0", 30 },
        .{ "c.1", 31 },
        .{ "c.2", 32 },
        .{ "d.0.a", 40 },
        .{ "d.1.a", 41 },
        .{ "d.2.a", 42 },
        .{ "e.0", 50 },
        .{ "e.2.b", 51 },
        .{ "f", 60 },
        // no entry for g.
        .{ "unused_entry", 1000 },
    };
    for (fixtures) |fixture| {
        store.buffers.putAssumeCapacity(fixture[0], Model._newHostBuffer(fixture[1]));
    }

    const model = try populateModel(Model, arena_state.allocator(), store);

    try std.testing.expectEqual(10, model.a.dim(0));
    try std.testing.expectEqual(20, model.b.a.dim(0));
    try std.testing.expectEqual(3, model.c.len);
    try std.testing.expectEqual(30, model.c[0].dim(0));
    try std.testing.expectEqual(31, model.c[1].dim(0));
    try std.testing.expectEqual(32, model.c[2].dim(0));
    try std.testing.expectEqual(3, model.d.len);
    try std.testing.expectEqual(40, model.d[0].a.dim(0));
    try std.testing.expectEqual(41, model.d[1].a.dim(0));
    try std.testing.expectEqual(42, model.d[2].a.dim(0));
    try std.testing.expectEqual(50, model.e[0].dim(0));
    try std.testing.expectEqual(51, model.e[2].b.dim(0));
    try std.testing.expectEqual(60, model.f.?.dim(0));
    try std.testing.expectEqual(null, model.g);
}
/// Creates a bufferized version of a Model from the given BufferStore. For details about
/// bufferization, see the documentation of Bufferized(T).
///
/// This will represent the weights of the model, loaded on a specific platform.
/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a
/// `module.ExeWithWeights` ready to be called.
///
/// The `init_args` are used to initialize the non Buffer fields, using `Model.init` function.
pub fn loadBuffers(
    comptime Model: type,
    // All arguments of `Model.init` except the first (the `*Model` receiver);
    // `void` when Model declares no `init`.
    init_args: if (@hasDecl(Model, "init")) stdx.meta.Tail(stdx.meta.FnArgs(Model.init)) else void,
    buffer_store: BufferStore,
    allocator: std.mem.Allocator,
    platform: zml.Platform,
) !zml.Bufferized(Model) {
    // The shape-only Model is temporary: it lives in this arena and is discarded
    // once the device buffers have been created.
    var arena_state = std.heap.ArenaAllocator.init(allocator);
    defer arena_state.deinit();
    const arena = arena_state.allocator();
    var model: Model = try zml.aio.populateModel(Model, arena, buffer_store);

    // If the Model has a "init" function, call it with the given parameters.
    if (@hasDecl(Model, "init")) {
        @call(.auto, Model.init, .{&model} ++ init_args);
    }

    return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, "");
}
/// Creates a bufferized version of a Model from the given BufferStore. For details about
/// bufferization, see the documentation of Bufferized(T).
///
/// This will represent the weights of the model, loaded on a specific platform.
/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a
/// `module.ExeWithWeights` ready to be called.
pub fn loadModelBuffers(
    comptime Model: type,
    model: Model,
    buffer_store: BufferStore,
    allocator: std.mem.Allocator,
    platform: zml.Platform,
) !zml.Bufferized(Model) {
    // Same as the prefixed variant, with tensor names used as-is (empty prefix).
    return try loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, "");
}
2023-04-07 16:45:58 +00:00
2023-02-24 17:33:14 +00:00
/// Creates a bufferized version of a Model from the given BufferStore and the given prefix.
/// For details about bufferization, see the documentation of Bufferized(T).
///
/// This will represent the weights of the model, loaded on a specific platform.
/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a
/// `module.ExeWithWeights` ready to be called.
pub fn loadModelBuffersWithPrefix(
    comptime Model: type,
    model: Model,
    buffer_store: BufferStore,
    allocator: std.mem.Allocator,
    platform: zml.Platform,
    prefix: []const u8,
) !zml.Bufferized(Model) {
    // Allocate the bufferized version.
    // We copy the shape, and let visitStructAndLoadBuffer write the other fields
    // just afterward.
    var res: zml.Bufferized(Model) = undefined;
    try zml.meta.mapAlloc(struct {
        pub fn initBuffer(_: void, tensor: zml.Tensor) zml.Buffer {
            // Only the shape is real; _api/_shards are filled when the buffer is uploaded.
            return .{ ._shape = tensor.shape(), ._api = undefined, ._shards = undefined };
        }
    }.initBuffer, allocator, {}, model, &res);

    var prefix_builder: PrefixBuilder = .{};
    try prefix_builder.push(allocator, prefix);
    defer prefix_builder.deinit(allocator);

    try visitStructAndLoadBuffer(allocator, &prefix_builder, buffer_store, &res, platform);
    return res;
}
/// Takes a bufferized version of a `model`, ie a mirror struct of the `model`,
/// and deinits every `Buffer` found in it.
pub fn unloadBuffers(model: anytype) void {
    const Visitor = struct {
        fn deinitBuffer(_: void, buffer: *zml.Buffer) void {
            buffer.deinit();
        }
    };
    zml.meta.visit(Visitor.deinitBuffer, {}, model);
}
2024-06-14 15:27:06 +00:00
/// Assists in debugging `BufferNotFound` errors.
/// This is useful when a buffer key is not found and you want to identify possible alternatives (or typos).
/// Best-effort diagnostics only: allocation failures are silently tolerated.
fn findSimilarBufferKeys(original_key: []const u8, store: BufferStore, temp_allocator: std.mem.Allocator) void {
    // NOTE(review): the first suffix "" matches every key and breaks out of the strip
    // loop immediately, so the .weight/.bias stripping below never runs — verify the
    // intended ordering (the empty suffix looks like it should be the last fallback).
    const suffixes = [_][]const u8{ "", ".weight", ".bias" };
    // Tracks keys already logged so the two passes don't print duplicates.
    var shown_keys = std.StringHashMap(void).init(temp_allocator);
    defer shown_keys.deinit();

    // remove suffix .weight and .bias
    var base_key = original_key;
    for (suffixes) |suffix| {
        if (std.mem.endsWith(u8, original_key, suffix)) {
            base_key = original_key[0 .. original_key.len - suffix.len];
            break;
        }
    }

    // first test: look for exact matches (keys sharing the base key as a prefix)
    var matches: usize = 0;
    var it = store.buffers.iterator();
    while (it.next()) |entry| {
        const key = entry.key_ptr.*;
        if (std.mem.startsWith(u8, key, base_key)) {
            // Print the header once, before the first match.
            if (matches == 0) log.warn("Similar buffers found:", .{});
            if (!shown_keys.contains(key)) {
                log.warn("- {s}: {}", .{ key, entry.value_ptr.*.shape() });
                // On OOM just skip remembering this key.
                shown_keys.put(key, {}) catch continue;
                matches += 1;
            }
        }
    }

    // second test: progressive partial matches, one dotted component at a time,
    // stopping at the first component that yields any match (cap of 5 per component).
    if (matches == 0) {
        var components = std.mem.splitScalar(u8, base_key, '.');
        while (components.next()) |component| {
            matches = 0;
            it = store.buffers.iterator();
            while (it.next()) |entry| {
                const key = entry.key_ptr.*;
                if (std.mem.indexOf(u8, key, component) != null and !shown_keys.contains(key)) {
                    if (matches == 0) log.warn("Partial matches for '{s}':", .{component});
                    log.warn("- {s}: {}", .{ key, entry.value_ptr.*.shape() });
                    shown_keys.put(key, {}) catch continue;
                    matches += 1;
                    if (matches >= 5) break;
                }
            }
            if (matches > 0) break;
        }
    }
}
2024-05-02 17:10:11 +00:00
/// Awaits every `Buffer` in the given (possibly nested) struct, replacing each
/// pending buffer with its resolved value.
/// NOTE(review): the previous comment said "deinit all buffers", which did not match the code.
pub fn awaitAll(buffers: anytype) !void {
    zml.meta.visit((struct {
        fn cb(_: void, buffer: *zml.Buffer) void {
            // Awaiting is treated as infallible here; a failure would be a programming bug.
            buffer.* = buffer.awaitt() catch unreachable;
        }
    }).cb, {}, buffers);
}
2023-01-02 14:28:25 +00:00
/// Recursively walks `obj` (a pointer into a Bufferized(Model) mirror struct) and,
/// for every `zml.Buffer` leaf, uploads the matching host buffer from `buffer_store`
/// (keyed by the dotted path in `prefix_builder`) to `platform`.
/// Returns error.BufferNotFound (after logging similar keys) when a buffer is missing.
fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, buffer_store: BufferStore, obj: anytype, platform: zml.Platform) !void {
    const err_msg = "visitStructAndLoadBuffer must be called with a pointer to type. Received ";
    // Only single-item pointers are accepted; everything else is a comptime error.
    const type_info, const T = switch (@typeInfo(@TypeOf(obj))) {
        .pointer => |ptr_info| switch (ptr_info.size) {
            .one => .{ @typeInfo(ptr_info.child), ptr_info.child },
            else => @compileError(err_msg ++ @typeName(@TypeOf(obj))),
        },
        else => @compileError(err_msg ++ @typeName(@TypeOf(obj))),
    };
    const prefix = prefix_builder.data.items;
    // Leaf case: a Buffer is uploaded from the host buffer of the same dotted name.
    if (T == zml.Buffer) {
        return if (buffer_store.get(prefix)) |host_buffer| {
            // obj._shape has been set inside `loadModelBuffersWithPrefix`, before calling us.
            var buf_with_metadata = host_buffer;
            log.debug("Loading buffer {s} ({})", .{ prefix, obj._shape });
            stdx.debug.assert(host_buffer.shape().eql(obj._shape), "loadModelBuffers expects to find the same shapes in the model and in the buffer store, got {} and {} for tensor {s}", .{ obj._shape, host_buffer, prefix });
            // Carry the model's shape (tags, …) onto the host buffer before upload.
            buf_with_metadata._shape = obj._shape;
            obj.* = try zml.Buffer.from(platform, buf_with_metadata, .{});
        } else {
            log.err("Buffer not found: {s}", .{prefix});
            // Diagnostic help: suggest close-looking keys before failing.
            findSimilarBufferKeys(prefix, buffer_store, allocator);
            return error.BufferNotFound;
        };
    } else if (T == zml.Shape) return;
    switch (type_info) {
        .pointer => |ptr_info| {
            if (ptr_info.size == .slice) {
                // Slices model repeated layers: visit "prefix.0", "prefix.1", …
                for (obj.*, 0..) |*value, i| {
                    try prefix_builder.pushDigit(allocator, i);
                    defer prefix_builder.pop();
                    try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, value, platform);
                }
            } else stdx.debug.compileError("type not supported by visitStructAndLoadBuffer: {}", .{T});
        },
        .array => {
            // Fixed-size arrays use the same numeric sub-keys as slices.
            for (obj, 0..) |*value, i| {
                try prefix_builder.pushDigit(allocator, i);
                defer prefix_builder.pop();
                try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, value, platform);
            }
        },
        .@"struct" => |struct_info| {
            inline for (struct_info.fields) |field| {
                if (field.is_comptime or @sizeOf(field.type) == 0) continue;
                try prefix_builder.push(allocator, field.name);
                defer prefix_builder.pop();
                try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &@field(obj, field.name), platform);
            }
        },
        .optional => {
            // Only visit present optionals; absent ones have no buffers to load.
            if (obj.*) |*obj_val| {
                try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, obj_val, platform);
            }
        },
        else => {},
    }
}