Radix/examples/loader/main.zig

68 lines
2.1 KiB
Zig

const std = @import("std");
const stdx = @import("stdx");
const zml = @import("zml");
const asynk = @import("async");
const asyncc = asynk.asyncc;
pub fn main() !void {
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
}
pub fn asyncMain() !void {
// Short lived allocations
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
defer _ = gpa.deinit();
const allocator = gpa.allocator();
var args = std.process.args();
// Skip executable path
_ = args.next().?;
const file = if (args.next()) |path| blk: {
std.debug.print("File path: {s}\n", .{path});
break :blk path;
} else {
std.debug.print("Missing file path argument\n", .{});
std.debug.print("Try: bazel run -c opt //loader:safetensors -- /path/to/mymodel.safetensors or /path/to/model.safetensors.index.json \n", .{});
std.process.exit(0);
};
var buffer_store = try zml.aio.safetensors.open(allocator, file);
defer buffer_store.deinit();
var context = try zml.Context.init();
defer context.deinit();
const platform = context.autoPlatform();
context.printAvailablePlatforms(platform);
var buffers = try gpa.allocator().alloc(zml.Buffer, buffer_store.buffers.count());
defer {
for (buffers) |*buf| {
buf.deinit();
}
gpa.allocator().free(buffers);
}
var total_bytes: usize = 0;
var timer = try std.time.Timer.start();
var it = buffer_store.buffers.iterator();
var i: usize = 0;
std.debug.print("\nStart to read {d} buffers from store..\n", .{buffer_store.buffers.count()});
while (it.next()) |entry| : (i += 1) {
const host_buffer = entry.value_ptr.*;
total_bytes += host_buffer.data.len;
std.debug.print("Buffer: {s} ({any} / {any})\n", .{ entry.key_ptr.*, i + 1, buffer_store.buffers.count() });
buffers[i] = try zml.Buffer.from(platform, host_buffer);
}
const stop = timer.read();
const time_in_s = stdx.math.divFloat(f64, stop, std.time.ns_per_s);
const mbs = stdx.math.divFloat(f64, total_bytes, 1024 * 1024);
std.debug.print("\nLoading speed: {d:.2} MB/s\n\n", .{mbs / time_in_s});
}