Update example programs (benchmark, llama, mnist, simple_layer) to use the new Exe API and reflect BaseExe allocation changes.

This commit is contained in:
Foke Singh 2023-10-10 11:12:34 +00:00
parent 3bc6ad98be
commit 35395c13f8
4 changed files with 13 additions and 25 deletions

View File

@@ -5,16 +5,12 @@ const flags = @import("tigerbeetle/flags");
// set log level to debug to print the generated IR
pub const std_options = .{
.log_level = .debug,
.log_level = .warn,
};
/// Model definition
const Benchmark = struct {
pub fn forward(self: Benchmark, a: zml.Tensor, b: zml.Tensor) zml.Tensor {
_ = self;
pub fn benchmark(a: zml.Tensor, b: zml.Tensor) zml.Tensor {
return a.withSharding(.{.k}).dot(b.withSharding(.{.k}), .{.k}).withSharding(.{.m});
}
};
pub fn main() !void {
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
@@ -34,11 +30,6 @@ pub fn asyncMain() !void {
defer _ = gpa.deinit();
const allocator = gpa.allocator();
// Arena allocator for BufferStore etc.
var arena_state = std.heap.ArenaAllocator.init(allocator);
defer arena_state.deinit();
const arena = arena_state.allocator();
var context = try zml.Context.init();
defer context.deinit();
@@ -60,18 +51,15 @@ pub fn asyncMain() !void {
// Start compiling.
// The shape of the input tensor, we have to pass in manually.
timer.reset();
var compilation = try asynk.asyncc(zml.module.compileModel, .{ allocator, Benchmark.forward, Benchmark{}, .{ a_shape, b_shape }, platform });
var compilation = try asynk.asyncc(zml.compileFn, .{ allocator, benchmark, .{ a_shape, b_shape }, platform });
// Wait for compilation to finish
const compiled = try compilation.awaitt();
const executable = try compilation.awaitt();
defer executable.deinit();
const compilation_elapsed = timer.lap() / std.time.ns_per_ms;
std.debug.print("-" ** 160 ++ "\n\n", .{});
std.debug.print("✅ Compiled Benchmark model in {d} milliseconds! \n", .{compilation_elapsed});
// pass the model weights to the compiled module to create an executable module
var executable = try compiled.prepare(arena, .{});
defer executable.deinit();
var rng = std.Random.DefaultPrng.init(0);
const random = rng.random();

View File

@@ -32,8 +32,8 @@ pub const std_options = .{
pub fn generateText(
llama: LlamaLM,
mod_prefill: zml.module.ExeWithWeights(LlamaLM.forward),
mod: zml.module.ExeWithWeights(LlamaLM.forward),
mod_prefill: zml.ModuleExe(LlamaLM.forward),
mod: zml.ModuleExe(LlamaLM.forward),
tokenizer: zml.tokenizer.Tokenizer,
allocator: std.mem.Allocator,
seed: u128,
@@ -221,9 +221,9 @@ pub fn asyncMain() !void {
defer zml.aio.unloadBuffers(&llama_weights);
log.info("\tLoaded weights in {d}ms", .{start.read() / std.time.ns_per_ms});
var llama_module_prefill = try (try fut_mod_prefill.awaitt()).prepare(allocator, llama_weights);
var llama_module_prefill = (try fut_mod_prefill.awaitt()).prepare(llama_weights);
defer llama_module_prefill.deinit();
var llama_module = try (try fut_mod.awaitt()).prepare(allocator, llama_weights);
var llama_module = (try fut_mod.awaitt()).prepare(llama_weights);
defer llama_module.deinit();
log.info("\tCompiled model in {d}ms", .{start.read() / std.time.ns_per_ms});

View File

@@ -86,7 +86,7 @@ pub fn asyncMain() !void {
const compiled_mnist = try compilation.awaitt();
log.info("✅ Compiled model in {d}ms", .{start_time.read() / std.time.ns_per_ms});
var mnist = try compiled_mnist.prepare(allocator, model_weights);
const mnist = compiled_mnist.prepare(model_weights);
defer mnist.deinit();
log.info("✅ Weights transferred in {d}ms", .{start_time.read() / std.time.ns_per_ms});

View File

@@ -74,7 +74,7 @@ pub fn asyncMain() !void {
const compiled = try compilation.awaitt();
// pass the model weights to the compiled module to create an executable module
var executable = try compiled.prepare(arena, model_weights);
var executable = compiled.prepare(model_weights);
defer executable.deinit();
// prepare an input buffer