diff --git a/examples/benchmark/main.zig b/examples/benchmark/main.zig index ee032fc..0a8aedb 100644 --- a/examples/benchmark/main.zig +++ b/examples/benchmark/main.zig @@ -5,16 +5,12 @@ const flags = @import("tigerbeetle/flags"); // set log level to debug to print the generated IR pub const std_options = .{ - .log_level = .debug, + .log_level = .warn, }; -/// Model definition -const Benchmark = struct { - pub fn forward(self: Benchmark, a: zml.Tensor, b: zml.Tensor) zml.Tensor { - _ = self; - return a.withSharding(.{.k}).dot(b.withSharding(.{.k}), .{.k}).withSharding(.{.m}); - } -}; +pub fn benchmark(a: zml.Tensor, b: zml.Tensor) zml.Tensor { + return a.withSharding(.{.k}).dot(b.withSharding(.{.k}), .{.k}).withSharding(.{.m}); +} pub fn main() !void { try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain); @@ -34,11 +30,6 @@ pub fn asyncMain() !void { defer _ = gpa.deinit(); const allocator = gpa.allocator(); - // Arena allocator for BufferStore etc. - var arena_state = std.heap.ArenaAllocator.init(allocator); - defer arena_state.deinit(); - const arena = arena_state.allocator(); - var context = try zml.Context.init(); defer context.deinit(); @@ -60,18 +51,15 @@ pub fn asyncMain() !void { // Start compiling. // The shape of the input tensor, we have to pass in manually. timer.reset(); - var compilation = try asynk.asyncc(zml.module.compileModel, .{ allocator, Benchmark.forward, Benchmark{}, .{ a_shape, b_shape }, platform }); + var compilation = try asynk.asyncc(zml.compileFn, .{ allocator, benchmark, .{ a_shape, b_shape }, platform }); // Wait for compilation to finish - const compiled = try compilation.awaitt(); + const executable = try compilation.awaitt(); + defer executable.deinit(); const compilation_elapsed = timer.lap() / std.time.ns_per_ms; std.debug.print("-" ** 160 ++ "\n\n", .{}); std.debug.print("✅ Compiled Benchmark model in {d} milliseconds! \n", .{compilation_elapsed}); - // pass the model weights to the compiled module to create an executable module - var executable = try compiled.prepare(arena, .{}); - defer executable.deinit(); - var rng = std.Random.DefaultPrng.init(0); const random = rng.random(); diff --git a/examples/llama/main.zig b/examples/llama/main.zig index 38ae79c..1442a7c 100644 --- a/examples/llama/main.zig +++ b/examples/llama/main.zig @@ -32,8 +32,8 @@ pub const std_options = .{ pub fn generateText( llama: LlamaLM, - mod_prefill: zml.module.ExeWithWeights(LlamaLM.forward), - mod: zml.module.ExeWithWeights(LlamaLM.forward), + mod_prefill: zml.ModuleExe(LlamaLM.forward), + mod: zml.ModuleExe(LlamaLM.forward), tokenizer: zml.tokenizer.Tokenizer, allocator: std.mem.Allocator, seed: u128, @@ -221,9 +221,9 @@ pub fn asyncMain() !void { defer zml.aio.unloadBuffers(&llama_weights); log.info("✅\tLoaded weights in {d}ms", .{start.read() / std.time.ns_per_ms}); - var llama_module_prefill = try (try fut_mod_prefill.awaitt()).prepare(allocator, llama_weights); + var llama_module_prefill = (try fut_mod_prefill.awaitt()).prepare(llama_weights); defer llama_module_prefill.deinit(); - var llama_module = try (try fut_mod.awaitt()).prepare(allocator, llama_weights); + var llama_module = (try fut_mod.awaitt()).prepare(llama_weights); defer llama_module.deinit(); log.info("✅\tCompiled model in {d}ms", .{start.read() / std.time.ns_per_ms}); diff --git a/examples/mnist/mnist.zig b/examples/mnist/mnist.zig index f15a86e..98fab69 100644 --- a/examples/mnist/mnist.zig +++ b/examples/mnist/mnist.zig @@ -86,7 +86,7 @@ pub fn asyncMain() !void { const compiled_mnist = try compilation.awaitt(); log.info("✅ Compiled model in {d}ms", .{start_time.read() / std.time.ns_per_ms}); - var mnist = try compiled_mnist.prepare(allocator, model_weights); + const mnist = compiled_mnist.prepare(model_weights); defer mnist.deinit(); log.info("✅ Weights transferred in {d}ms", .{start_time.read() / std.time.ns_per_ms}); diff --git a/examples/simple_layer/main.zig b/examples/simple_layer/main.zig index ec88aaf..475f10b 100644 --- a/examples/simple_layer/main.zig +++ b/examples/simple_layer/main.zig @@ -74,7 +74,7 @@ pub fn asyncMain() !void { const compiled = try compilation.awaitt(); // pass the model weights to the compiled module to create an executable module - var executable = try compiled.prepare(arena, model_weights); + var executable = compiled.prepare(model_weights); defer executable.deinit(); // prepare an input buffer