Update example programs (benchmark, llama, mnist, simple_layer) to use the new Exe API and reflect BaseExe allocation changes.
This commit is contained in:
parent
3bc6ad98be
commit
35395c13f8
@ -5,16 +5,12 @@ const flags = @import("tigerbeetle/flags");
|
||||
|
||||
// set log level to debug to print the generated IR
|
||||
pub const std_options = .{
|
||||
.log_level = .debug,
|
||||
.log_level = .warn,
|
||||
};
|
||||
|
||||
/// Model definition
|
||||
const Benchmark = struct {
|
||||
pub fn forward(self: Benchmark, a: zml.Tensor, b: zml.Tensor) zml.Tensor {
|
||||
_ = self;
|
||||
pub fn benchmark(a: zml.Tensor, b: zml.Tensor) zml.Tensor {
|
||||
return a.withSharding(.{.k}).dot(b.withSharding(.{.k}), .{.k}).withSharding(.{.m});
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn main() !void {
|
||||
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
|
||||
@ -34,11 +30,6 @@ pub fn asyncMain() !void {
|
||||
defer _ = gpa.deinit();
|
||||
const allocator = gpa.allocator();
|
||||
|
||||
// Arena allocator for BufferStore etc.
|
||||
var arena_state = std.heap.ArenaAllocator.init(allocator);
|
||||
defer arena_state.deinit();
|
||||
const arena = arena_state.allocator();
|
||||
|
||||
var context = try zml.Context.init();
|
||||
defer context.deinit();
|
||||
|
||||
@ -60,18 +51,15 @@ pub fn asyncMain() !void {
|
||||
// Start compiling.
|
||||
// The shape of the input tensor, we have to pass in manually.
|
||||
timer.reset();
|
||||
var compilation = try asynk.asyncc(zml.module.compileModel, .{ allocator, Benchmark.forward, Benchmark{}, .{ a_shape, b_shape }, platform });
|
||||
var compilation = try asynk.asyncc(zml.compileFn, .{ allocator, benchmark, .{ a_shape, b_shape }, platform });
|
||||
|
||||
// Wait for compilation to finish
|
||||
const compiled = try compilation.awaitt();
|
||||
const executable = try compilation.awaitt();
|
||||
defer executable.deinit();
|
||||
const compilation_elapsed = timer.lap() / std.time.ns_per_ms;
|
||||
std.debug.print("-" ** 160 ++ "\n\n", .{});
|
||||
std.debug.print("✅ Compiled Benchmark model in {d} milliseconds! \n", .{compilation_elapsed});
|
||||
|
||||
// pass the model weights to the compiled module to create an executable module
|
||||
var executable = try compiled.prepare(arena, .{});
|
||||
defer executable.deinit();
|
||||
|
||||
var rng = std.Random.DefaultPrng.init(0);
|
||||
const random = rng.random();
|
||||
|
||||
|
||||
@ -32,8 +32,8 @@ pub const std_options = .{
|
||||
|
||||
pub fn generateText(
|
||||
llama: LlamaLM,
|
||||
mod_prefill: zml.module.ExeWithWeights(LlamaLM.forward),
|
||||
mod: zml.module.ExeWithWeights(LlamaLM.forward),
|
||||
mod_prefill: zml.ModuleExe(LlamaLM.forward),
|
||||
mod: zml.ModuleExe(LlamaLM.forward),
|
||||
tokenizer: zml.tokenizer.Tokenizer,
|
||||
allocator: std.mem.Allocator,
|
||||
seed: u128,
|
||||
@ -221,9 +221,9 @@ pub fn asyncMain() !void {
|
||||
defer zml.aio.unloadBuffers(&llama_weights);
|
||||
log.info("✅\tLoaded weights in {d}ms", .{start.read() / std.time.ns_per_ms});
|
||||
|
||||
var llama_module_prefill = try (try fut_mod_prefill.awaitt()).prepare(allocator, llama_weights);
|
||||
var llama_module_prefill = (try fut_mod_prefill.awaitt()).prepare(llama_weights);
|
||||
defer llama_module_prefill.deinit();
|
||||
var llama_module = try (try fut_mod.awaitt()).prepare(allocator, llama_weights);
|
||||
var llama_module = (try fut_mod.awaitt()).prepare(llama_weights);
|
||||
defer llama_module.deinit();
|
||||
log.info("✅\tCompiled model in {d}ms", .{start.read() / std.time.ns_per_ms});
|
||||
|
||||
|
||||
@ -86,7 +86,7 @@ pub fn asyncMain() !void {
|
||||
const compiled_mnist = try compilation.awaitt();
|
||||
log.info("✅ Compiled model in {d}ms", .{start_time.read() / std.time.ns_per_ms});
|
||||
|
||||
var mnist = try compiled_mnist.prepare(allocator, model_weights);
|
||||
const mnist = compiled_mnist.prepare(model_weights);
|
||||
defer mnist.deinit();
|
||||
log.info("✅ Weights transferred in {d}ms", .{start_time.read() / std.time.ns_per_ms});
|
||||
|
||||
|
||||
@ -74,7 +74,7 @@ pub fn asyncMain() !void {
|
||||
const compiled = try compilation.awaitt();
|
||||
|
||||
// pass the model weights to the compiled module to create an executable module
|
||||
var executable = try compiled.prepare(arena, model_weights);
|
||||
var executable = compiled.prepare(model_weights);
|
||||
defer executable.deinit();
|
||||
|
||||
// prepare an input buffer
|
||||
|
||||
Loading…
Reference in New Issue
Block a user