Update example programs (benchmark, llama, mnist, simple_layer) to use the new Exe API and reflect BaseExe allocation changes.
This commit is contained in:
parent
3bc6ad98be
commit
35395c13f8
@ -5,16 +5,12 @@ const flags = @import("tigerbeetle/flags");
|
|||||||
|
|
||||||
// set log level to debug to print the generated IR
|
// set log level to debug to print the generated IR
|
||||||
pub const std_options = .{
|
pub const std_options = .{
|
||||||
.log_level = .debug,
|
.log_level = .warn,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Model definition
|
pub fn benchmark(a: zml.Tensor, b: zml.Tensor) zml.Tensor {
|
||||||
const Benchmark = struct {
|
|
||||||
pub fn forward(self: Benchmark, a: zml.Tensor, b: zml.Tensor) zml.Tensor {
|
|
||||||
_ = self;
|
|
||||||
return a.withSharding(.{.k}).dot(b.withSharding(.{.k}), .{.k}).withSharding(.{.m});
|
return a.withSharding(.{.k}).dot(b.withSharding(.{.k}), .{.k}).withSharding(.{.m});
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
pub fn main() !void {
|
pub fn main() !void {
|
||||||
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
|
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
|
||||||
@ -34,11 +30,6 @@ pub fn asyncMain() !void {
|
|||||||
defer _ = gpa.deinit();
|
defer _ = gpa.deinit();
|
||||||
const allocator = gpa.allocator();
|
const allocator = gpa.allocator();
|
||||||
|
|
||||||
// Arena allocator for BufferStore etc.
|
|
||||||
var arena_state = std.heap.ArenaAllocator.init(allocator);
|
|
||||||
defer arena_state.deinit();
|
|
||||||
const arena = arena_state.allocator();
|
|
||||||
|
|
||||||
var context = try zml.Context.init();
|
var context = try zml.Context.init();
|
||||||
defer context.deinit();
|
defer context.deinit();
|
||||||
|
|
||||||
@ -60,18 +51,15 @@ pub fn asyncMain() !void {
|
|||||||
// Start compiling.
|
// Start compiling.
|
||||||
// The shape of the input tensor, we have to pass in manually.
|
// The shape of the input tensor, we have to pass in manually.
|
||||||
timer.reset();
|
timer.reset();
|
||||||
var compilation = try asynk.asyncc(zml.module.compileModel, .{ allocator, Benchmark.forward, Benchmark{}, .{ a_shape, b_shape }, platform });
|
var compilation = try asynk.asyncc(zml.compileFn, .{ allocator, benchmark, .{ a_shape, b_shape }, platform });
|
||||||
|
|
||||||
// Wait for compilation to finish
|
// Wait for compilation to finish
|
||||||
const compiled = try compilation.awaitt();
|
const executable = try compilation.awaitt();
|
||||||
|
defer executable.deinit();
|
||||||
const compilation_elapsed = timer.lap() / std.time.ns_per_ms;
|
const compilation_elapsed = timer.lap() / std.time.ns_per_ms;
|
||||||
std.debug.print("-" ** 160 ++ "\n\n", .{});
|
std.debug.print("-" ** 160 ++ "\n\n", .{});
|
||||||
std.debug.print("✅ Compiled Benchmark model in {d} milliseconds! \n", .{compilation_elapsed});
|
std.debug.print("✅ Compiled Benchmark model in {d} milliseconds! \n", .{compilation_elapsed});
|
||||||
|
|
||||||
// pass the model weights to the compiled module to create an executable module
|
|
||||||
var executable = try compiled.prepare(arena, .{});
|
|
||||||
defer executable.deinit();
|
|
||||||
|
|
||||||
var rng = std.Random.DefaultPrng.init(0);
|
var rng = std.Random.DefaultPrng.init(0);
|
||||||
const random = rng.random();
|
const random = rng.random();
|
||||||
|
|
||||||
|
|||||||
@ -32,8 +32,8 @@ pub const std_options = .{
|
|||||||
|
|
||||||
pub fn generateText(
|
pub fn generateText(
|
||||||
llama: LlamaLM,
|
llama: LlamaLM,
|
||||||
mod_prefill: zml.module.ExeWithWeights(LlamaLM.forward),
|
mod_prefill: zml.ModuleExe(LlamaLM.forward),
|
||||||
mod: zml.module.ExeWithWeights(LlamaLM.forward),
|
mod: zml.ModuleExe(LlamaLM.forward),
|
||||||
tokenizer: zml.tokenizer.Tokenizer,
|
tokenizer: zml.tokenizer.Tokenizer,
|
||||||
allocator: std.mem.Allocator,
|
allocator: std.mem.Allocator,
|
||||||
seed: u128,
|
seed: u128,
|
||||||
@ -221,9 +221,9 @@ pub fn asyncMain() !void {
|
|||||||
defer zml.aio.unloadBuffers(&llama_weights);
|
defer zml.aio.unloadBuffers(&llama_weights);
|
||||||
log.info("✅\tLoaded weights in {d}ms", .{start.read() / std.time.ns_per_ms});
|
log.info("✅\tLoaded weights in {d}ms", .{start.read() / std.time.ns_per_ms});
|
||||||
|
|
||||||
var llama_module_prefill = try (try fut_mod_prefill.awaitt()).prepare(allocator, llama_weights);
|
var llama_module_prefill = (try fut_mod_prefill.awaitt()).prepare(llama_weights);
|
||||||
defer llama_module_prefill.deinit();
|
defer llama_module_prefill.deinit();
|
||||||
var llama_module = try (try fut_mod.awaitt()).prepare(allocator, llama_weights);
|
var llama_module = (try fut_mod.awaitt()).prepare(llama_weights);
|
||||||
defer llama_module.deinit();
|
defer llama_module.deinit();
|
||||||
log.info("✅\tCompiled model in {d}ms", .{start.read() / std.time.ns_per_ms});
|
log.info("✅\tCompiled model in {d}ms", .{start.read() / std.time.ns_per_ms});
|
||||||
|
|
||||||
|
|||||||
@ -86,7 +86,7 @@ pub fn asyncMain() !void {
|
|||||||
const compiled_mnist = try compilation.awaitt();
|
const compiled_mnist = try compilation.awaitt();
|
||||||
log.info("✅ Compiled model in {d}ms", .{start_time.read() / std.time.ns_per_ms});
|
log.info("✅ Compiled model in {d}ms", .{start_time.read() / std.time.ns_per_ms});
|
||||||
|
|
||||||
var mnist = try compiled_mnist.prepare(allocator, model_weights);
|
const mnist = compiled_mnist.prepare(model_weights);
|
||||||
defer mnist.deinit();
|
defer mnist.deinit();
|
||||||
log.info("✅ Weights transferred in {d}ms", .{start_time.read() / std.time.ns_per_ms});
|
log.info("✅ Weights transferred in {d}ms", .{start_time.read() / std.time.ns_per_ms});
|
||||||
|
|
||||||
|
|||||||
@ -74,7 +74,7 @@ pub fn asyncMain() !void {
|
|||||||
const compiled = try compilation.awaitt();
|
const compiled = try compilation.awaitt();
|
||||||
|
|
||||||
// pass the model weights to the compiled module to create an executable module
|
// pass the model weights to the compiled module to create an executable module
|
||||||
var executable = try compiled.prepare(arena, model_weights);
|
var executable = compiled.prepare(model_weights);
|
||||||
defer executable.deinit();
|
defer executable.deinit();
|
||||||
|
|
||||||
// prepare an input buffer
|
// prepare an input buffer
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user