2023-01-03 10:21:07 +00:00
const std = @import ( " std " ) ;
2024-10-25 10:20:04 +00:00
const asynk = @import ( " async " ) ;
2023-06-27 14:23:22 +00:00
const stdx = @import ( " stdx " ) ;
2023-01-03 10:21:07 +00:00
const zml = @import ( " zml " ) ;
/// Process entry point: hands `asyncMain` to `asynk.AsyncThread.main`,
/// using the C allocator, and propagates any error it returns.
pub fn main() !void {
    const allocator = std.heap.c_allocator;
    try asynk.AsyncThread.main(allocator, asyncMain);
}
/// Opens the safetensors file named by the first command-line argument, creates a
/// `zml.Buffer` on the platform returned by `context.autoPlatform` for every entry
/// of the store, and prints the measured loading speed.
///
/// When the path argument is missing, a usage hint is printed and the process exits.
/// Errors from opening the store, creating the context, allocating, or creating a
/// buffer are propagated to the caller.
pub fn asyncMain() !void {
    // Short lived allocations
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    var args = std.process.args();
    // Skip executable path
    _ = args.next().?;

    const file = if (args.next()) |path| blk: {
        std.debug.print(" File path: {s} \n ", .{path});
        break :blk path;
    } else {
        std.debug.print(" Missing file path argument \n ", .{});
        std.debug.print(" Try: bazel run --config=release //loader:safetensors -- /path/to/mymodel.safetensors or /path/to/model.safetensors.index.json \n ", .{});
        std.process.exit(0);
    };

    var buffer_store = try zml.aio.safetensors.open(allocator, file);
    defer buffer_store.deinit();

    var context = try zml.Context.init();
    defer context.deinit();

    const platform = context.autoPlatform(.{});
    context.printAvailablePlatforms(platform);

    const buffer_count = buffer_store.buffers.count();
    const buffers = try allocator.alloc(zml.Buffer, buffer_count);

    // Number of leading entries of `buffers` that have been initialized so far.
    // It must be declared before the `defer` below: the slice comes back from
    // `alloc` uninitialized, so if `zml.Buffer.from` fails midway only the
    // entries already written may be deinitialized.
    var loaded: usize = 0;
    defer {
        // Note: no allocator is passed to buf.deinit() because the buffer is allocated on the device.
        for (buffers[0..loaded]) |*buf| buf.deinit();
        allocator.free(buffers);
    }

    var total_bytes: usize = 0;
    var timer = try std.time.Timer.start();
    var it = buffer_store.buffers.iterator();

    std.debug.print(" \n Start to read {d} buffers from store.. \n ", .{buffer_count});

    // `loaded` is only incremented after `zml.Buffer.from` succeeded, so on an
    // error return it still counts exactly the initialized entries.
    while (it.next()) |entry| : (loaded += 1) {
        const host_buffer = entry.value_ptr.*;
        total_bytes += host_buffer.shape().byteSize();
        std.debug.print(" Buffer: {s} ({any} / {any}) \n ", .{ entry.key_ptr.*, loaded + 1, buffer_count });
        buffers[loaded] = try zml.Buffer.from(platform, host_buffer);
    }

    // Elapsed time covers the whole loop, including the progress prints.
    const stop = timer.read();
    const time_in_s = stdx.math.divFloat(f64, stop, std.time.ns_per_s);
    // Bytes are divided by 1024 * 1024, so the reported unit is binary megabytes.
    const mbs = stdx.math.divFloat(f64, total_bytes, 1024 * 1024);

    std.debug.print(" \n Loading speed: {d:.2} MB/s \n \n ", .{mbs / time_in_s});
}