Add new Zig example programs (benchmark, llama, loader, mnist, simple_layer) and include a test for the llama example.

This commit is contained in:
Foke Singh 2023-06-27 14:23:22 +00:00
parent 9b7eea8ac2
commit 7985716562
9 changed files with 993 additions and 1095 deletions

File diff suppressed because it is too large. (Load Diff)

View File

@ -17,9 +17,7 @@ const Benchmark = struct {
};
pub fn main() !void {
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
defer _ = gpa.deinit();
try asynk.AsyncThread.main(gpa.allocator(), asyncMain, .{});
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
}
pub fn asyncMain() !void {
@ -48,35 +46,7 @@ pub fn asyncMain() !void {
const platform = context.autoPlatform().withCompilationOptions(.{
.sharding_enabled = true,
});
{
// List available targets
std.debug.print("Available Platforms:\n", .{});
const selected_prefix = "";
const not_selected_prefix = "";
const selected_postfix = "(AUTO-SELECTED)\n";
const not_selected_postfix = "\n";
for (zml.platform.available_targets) |target| {
std.debug.print(" {s} {s} {s}", .{
if (target == platform.target) selected_prefix else not_selected_prefix,
@tagName(target),
if (target == platform.target) selected_postfix else not_selected_postfix,
});
// now the platform's devices
if (context.platforms.get(target)) |pfm| {
for (pfm.getDevices(), 0..) |device, index| {
const deviceKind = device.getDescription(platform.pjrt_api).getKind(platform.pjrt_api);
std.debug.print(" ◦ #{d}: {s}\n", .{
index,
deviceKind,
});
// we only list 1 CPU device
if (target == .cpu and platform.sharding().num_partitions == 1) break;
}
}
}
std.debug.print("\n", .{});
}
context.printAvailablePlatforms(platform);
var args = std.process.args();
const cli_args = flags.parse(&args, CliArgs);

View File

@ -13,6 +13,7 @@ zig_cc_binary(
deps = [
"//third_party/tigerbeetle:flags",
"@zml//async",
"@zml//stdx",
"@zml//zml",
],
)
@ -126,6 +127,7 @@ zig_cc_binary(
deps = [
"//third_party/tigerbeetle:flags",
"@zml//async",
"@zml//metax",
"@zml//zml",
],
)

View File

@ -1,16 +1,16 @@
const std = @import("std");
const testing = std.testing;
const zml = @import("zml");
const meta = zml.meta;
const flags = @import("tigerbeetle/flags");
const std = @import("std");
const stdx = @import("stdx");
const zml = @import("zml");
const log = std.log.scoped(.llama);
const gguf = zml.io.gguf;
const testing = std.testing;
const Buffer = zml.Buffer;
const Tensor = zml.Tensor;
const ShapeOf = zml.ShapeOf;
const gguf = zml.io.gguf;
const expectClose = zml.testing.expectClose;
const log = std.log.scoped(.llama);
pub const LlamaOptions = struct {
gen_opts: zml.nn.SamplingStrategy,
@ -72,7 +72,7 @@ pub const LlamaLM = struct {
kv_cache: ?KvCache,
rng: Tensor.Rng,
) struct { Tensor, Tensor, KvCache, Tensor.Rng } {
meta.assert(tokens_.dtype() == .i32 and tokens_.rank() >= 1 and token_index.dtype() == .i32 and token_index.rank() == 0, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {} and {}", .{ tokens_, token_index });
stdx.debug.assert(tokens_.dtype() == .i32 and tokens_.rank() >= 1 and token_index.dtype() == .i32 and token_index.rank() == 0, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {} and {}", .{ tokens_, token_index });
var tokens = tokens_.withPartialTags(.{.s});
const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, if (kv_cache == null) null else token_index, kv_cache });
@ -219,7 +219,7 @@ pub const TransformerLayer = struct {
) struct { Tensor, KvCache } {
// Self Attention
//log.debug("TransformerLayer({}) -> {}", .{ x0, self.input_layernorm.forward(x0) });
meta.assert(x0.rank() >= 2 and x0.shape().hasTags(.{ .s, .d }), "TransformerLayer expected input shape: {{..., .s, .d}}, received: {}", .{x0});
stdx.debug.assert(x0.rank() >= 2 and x0.shape().hasTags(.{ .s, .d }), "TransformerLayer expected input shape: {{..., .s, .d}}, received: {}", .{x0});
const x0_normalized = zml.call(self.input_layernorm, .forward, .{x0});
const delta0, const updated_kv_cache = zml.call(self.self_attn, .forward, .{ x0_normalized, token_index, kv_cache });
@ -313,7 +313,7 @@ pub const SelfAttn = struct {
const new_kv_cache = kv_cache.update(k, v, token_index orelse Tensor.scalar(0, .i32));
if (token_index) |_| {
meta.assert(q.dim(.q) == 1, "Expected dimension .q to be 1, got {}", .{q.dim(.q)});
stdx.debug.assert(q.dim(.q) == 1, "Expected dimension .q to be 1, got {}", .{q.dim(.q)});
k = new_kv_cache.keys();
v = new_kv_cache.values();
}

View File

@ -1,9 +1,9 @@
const std = @import("std");
const zml = @import("zml");
const meta = zml.meta;
const asynk = @import("async");
const flags = @import("tigerbeetle/flags");
const std = @import("std");
const stdx = @import("stdx");
const zml = @import("zml");
const llama_mod = @import("llama.zig");
const LlamaLM = llama_mod.LlamaLM;
@ -93,7 +93,7 @@ pub fn generateText(
std.debug.print("{s}\n", .{output.items[n..]});
const end = std.time.microTimestamp();
const duration = zml.meta.divFloat(f64, end - start, std.time.us_per_s);
const duration = stdx.math.divFloor(f64, end - start, std.time.us_per_s);
const speed = @as(f64, @floatFromInt(max_seq_len)) / duration;
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ max_seq_len, duration, speed });
@ -106,7 +106,7 @@ pub fn generateText(
}
pub fn main() !void {
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain, .{});
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
}
pub fn asyncMain() !void {
@ -131,9 +131,7 @@ pub fn asyncMain() !void {
log.info(" LLama was compiled with {}", .{@import("builtin").mode});
var gpa = std.heap.GeneralPurposeAllocator(.{ .thread_safe = true }){};
defer _ = gpa.deinit();
const allocator = gpa.allocator();
const allocator = std.heap.c_allocator;
const tmp = try std.fs.openDirAbsolute("/tmp", .{});
try tmp.makePath("zml/llama/cache");
@ -147,35 +145,7 @@ pub fn asyncMain() !void {
};
const platform = context.autoPlatform().withCompilationOptions(compilation_options);
{
// List available targets
std.debug.print("\nSupported Platforms:\n", .{});
const selected_prefix = "";
const not_selected_prefix = "";
const selected_postfix = "(AUTO-SELECTED)\n";
const not_selected_postfix = "\n";
for (zml.platform.available_targets) |target| {
std.debug.print(" {s} {s} {s}", .{
if (target == platform.target) selected_prefix else not_selected_prefix,
@tagName(target),
if (target == platform.target) selected_postfix else not_selected_postfix,
});
// now the platform's devices
if (context.platforms.get(target)) |pfm| {
for (pfm.getDevices(), 0..) |device, index| {
const deviceKind = device.getDescription(platform.pjrt_api).getKind(platform.pjrt_api);
std.debug.print(" ◦ #{d}: {s}\n", .{
index,
deviceKind,
});
// we only list 1 CPU device
if (target == .cpu) break;
}
}
}
std.debug.print("\n", .{});
}
context.printAvailablePlatforms(platform);
var args = std.process.args();
const cli_args = flags.parse(&args, CliArgs);
@ -213,7 +183,7 @@ pub fn asyncMain() !void {
.freq_base = @floatCast(ts.metadata("rope_freq_base", .float) orelse @as(f32, @floatFromInt(cli_args.rope_freq_base orelse 10_000))),
},
};
log.info(" Parsed llama config: {}", .{llama_options});
log.info("\tParsed llama config: {}", .{llama_options});
llama.init(llama_options);
if (cli_args.tokenizer == null and !std.mem.endsWith(u8, cli_args.model, ".gguf")) {
@ -221,9 +191,9 @@ pub fn asyncMain() !void {
@panic("No tokenizer provided");
}
const tokenizer_path = cli_args.tokenizer orelse cli_args.model;
log.info(" Loading tokenizer from {s}", .{tokenizer_path});
log.info("\tLoading tokenizer from {s}", .{tokenizer_path});
var tokenizer = try zml.aio.detectFormatAndLoadTokenizer(allocator, tokenizer_path);
log.info(" Loaded tokenizer from {s}", .{tokenizer_path});
log.info("\tLoaded tokenizer from {s}", .{tokenizer_path});
defer tokenizer.deinit();
const dims = llama.model.shape();
@ -238,25 +208,23 @@ pub fn asyncMain() !void {
const kv_cache_shape: ?ShapeOf(KvCache) = KvCache.initShape(kv_shape);
const rng_shape = Tensor.Rng.shape();
const compile_start = std.time.milliTimestamp();
var fut_mod_prefill = try asynk.asyncGeneric(zml.compile, .{ allocator, LlamaLM, .{llama_options}, .forward, .{ tokens_shape, token_idx_shape, null, rng_shape }, ts, platform });
var fut_mod = try asynk.asyncGeneric(zml.compile, .{ allocator, LlamaLM, .{llama_options}, .forward, .{ tokens_shape, token_idx_shape, kv_cache_shape, rng_shape }, ts, platform });
var start = try std.time.Timer.start();
var fut_mod_prefill = try asynk.asyncc(zml.compile, .{ allocator, LlamaLM.forward, .{llama_options}, .{ tokens_shape, token_idx_shape, null, rng_shape }, ts, platform });
var fut_mod = try asynk.asyncc(zml.compile, .{ allocator, LlamaLM.forward, .{llama_options}, .{ tokens_shape, token_idx_shape, kv_cache_shape, rng_shape }, ts, platform });
log.info("Starting loading weights", .{});
log.info("\tLoading Llama weights from {s}...", .{cli_args.model});
var llama_weights = try zml.aio.loadBuffers(LlamaLM, .{llama_options}, ts, model_arena, platform);
defer zml.aio.unloadBuffers(&llama_weights);
log.info("✅ Done loading weights", .{});
log.info("✅ Llama model loaded from {s}", .{cli_args.model});
log.info("\tLoaded weights in {d}ms", .{start.read() / std.time.ns_per_ms});
var llama_module_prefill = try (try fut_mod_prefill.await_()).prepare(allocator, llama_weights);
defer llama_module_prefill.deinit();
var llama_module = try (try fut_mod.await_()).prepare(allocator, llama_weights);
defer llama_module.deinit();
const compile_end = std.time.milliTimestamp();
log.info("✅ Compiled model in {d} milliseconds! \n", .{compile_end - compile_start});
log.info("\tCompiled model in {d}ms", .{start.read() / std.time.ns_per_ms});
const prompt = cli_args.prompt orelse "Once upon a time, there was a little girl named Lily.";
log.info(" Prompt: {s}\n", .{prompt});
log.info("\tPrompt: {s}", .{prompt});
const seed = cli_args.seed orelse @as(u128, @bitCast(std.time.nanoTimestamp()));
const story = try generateText(llama, llama_module_prefill, llama_module, tokenizer, allocator, seed, prompt);

View File

@ -1,8 +1,8 @@
const std = @import("std");
const zml = @import("zml");
const asynk = @import("async");
const flags = @import("tigerbeetle/flags");
const std = @import("std");
const stdx = @import("stdx");
const zml = @import("zml");
const llama_mod = @import("./llama.zig");
const LlamaLM = llama_mod.LlamaLM;

View File

@ -1,13 +1,12 @@
const std = @import("std");
const stdx = @import("stdx");
const zml = @import("zml");
const asynk = @import("async");
const asyncc = asynk.asyncc;
pub fn main() !void {
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
defer _ = gpa.deinit();
try asynk.AsyncThread.main(gpa.allocator(), asyncMain, .{});
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
}
pub fn asyncMain() !void {
@ -36,11 +35,7 @@ pub fn asyncMain() !void {
defer context.deinit();
const platform = context.autoPlatform();
const devices = platform.getDevices();
for (devices) |device| {
std.debug.print("Device visible: {s}\n", .{device.getDescription(platform.pjrt_api).debugString(platform.pjrt_api)});
}
context.printAvailablePlatforms(platform);
var buffers = try gpa.allocator().alloc(zml.Buffer, buffer_store.buffers.count());
defer {
@ -65,8 +60,8 @@ pub fn asyncMain() !void {
}
const stop = timer.read();
const time_in_s = zml.meta.divFloat(f64, stop, std.time.ns_per_s);
const mbs = zml.meta.divFloat(f64, total_bytes, 1024 * 1024);
const time_in_s = stdx.math.divFloor(f64, stop, std.time.ns_per_s);
const mbs = stdx.math.divFloor(f64, total_bytes, 1024 * 1024);
std.debug.print("\nLoading speed: {d:.2} MB/s\n\n", .{mbs / time_in_s});
}

View File

@ -1,9 +1,11 @@
const asynk = @import("async");
const std = @import("std");
const zml = @import("zml");
const asynk = @import("async");
const show_mlir = true;
const log = std.log.scoped(.mnist);
/// Model definition
const Mnist = struct {
fc1: Layer,
@ -31,56 +33,21 @@ const Mnist = struct {
};
pub fn main() !void {
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
defer _ = gpa.deinit();
const allocator = gpa.allocator();
try asynk.AsyncThread.main(allocator, asyncMain, .{});
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
}
pub fn asyncMain() !void {
// Short lived allocations
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
defer _ = gpa.deinit();
const allocator = gpa.allocator();
const allocator = std.heap.c_allocator;
// Create ZML context
// // Create ZML context
var context = try zml.Context.init();
defer context.deinit();
std.debug.print("\n===========================\n== ZML MNIST Example ==\n===========================\n\n", .{});
// log.info("\n===========================\n== ZML MNIST Example ==\n===========================\n\n", .{});
// Auto-select platform
// // Auto-select platform
const platform = context.autoPlatform();
{
// List available targets
std.debug.print("Available Platforms:\n", .{});
const selected_prefix = "";
const not_selected_prefix = "";
const selected_postfix = "(AUTO-SELECTED)\n";
const not_selected_postfix = "\n";
for (zml.platform.available_targets) |target| {
std.debug.print(" {s} {s} {s}", .{
if (target == platform.target) selected_prefix else not_selected_prefix,
@tagName(target),
if (target == platform.target) selected_postfix else not_selected_postfix,
});
// now the platform's devices
if (context.platforms.get(target)) |pfm| {
for (pfm.getDevices(), 0..) |device, index| {
const deviceKind = device.getDescription(platform.pjrt_api).getKind(platform.pjrt_api);
std.debug.print(" ◦ #{d}: {s}\n", .{
index,
deviceKind,
});
// we only list 1 CPU device
if (target == .cpu) break;
}
}
}
std.debug.print("\n", .{});
}
context.printAvailablePlatforms(platform);
// Parse program args
const process_args = try std.process.argsAlloc(allocator);
@ -99,32 +66,26 @@ pub fn asyncMain() !void {
defer buffer_store.deinit();
const mnist_model = try zml.aio.populateModel(Mnist, allocator, buffer_store);
std.debug.print("✅ Read model shapes from PyTorch file {s}\n", .{pt_model});
log.info("Reading model shapes from PyTorch file {s}...", .{pt_model});
// Start loading weights
// Start compiling
log.info("Compiling model to MLIR....", .{});
var start_time = try std.time.Timer.start();
var compilation = try asynk.asyncc(zml.compile, .{ allocator, Mnist.forward, .{}, .{zml.Shape.init(.{ 28, 28 }, .u8)}, buffer_store, platform });
// While compiling, start loading weights on the platform
var model_weights = try zml.aio.loadModelBuffers(Mnist, mnist_model, buffer_store, arena, platform);
defer zml.aio.unloadBuffers(&model_weights);
// Start compiling
const comp_start_time = std.time.milliTimestamp();
if (show_mlir) {
std.debug.print("\nCompiling model to MLIR....\n", .{});
std.debug.print("-" ** 160 ++ "\n", .{});
} else {
std.debug.print("Compiling model to MLIR....\r", .{});
}
var compilation = try asynk.asyncGeneric(zml.compile, .{ allocator, Mnist, .{}, .forward, .{zml.Shape.init(.{ 28, 28 }, .u8)}, buffer_store, platform });
// Wait for end of compilation and end of weights loading.
const compiled_mnist = try compilation.await_();
const comp_end_time = std.time.milliTimestamp();
if (show_mlir) std.debug.print("-" ** 160 ++ "\n", .{});
std.debug.print("✅ Compiled MNIST model in {d} milliseconds! \n", .{comp_end_time - comp_start_time});
const compiled_mnist = try compilation.wait();
log.info("✅ Compiled model in {d}ms", .{start_time.read() / std.time.ns_per_ms});
// send weights to accelerator / GPU
var mnist = try compiled_mnist.prepare(allocator, model_weights);
defer mnist.deinit();
std.debug.print("✅ Weights transferred, starting inference...\n\n", .{});
log.info("✅ Weights transferred in {d}ms", .{start_time.read() / std.time.ns_per_ms});
log.info("Starting inference...", .{});
// Load a random digit image from the dataset.
const dataset = try asynk.File.open(t10kfilename, .{ .mode = .read_only });
@ -143,17 +104,18 @@ pub fn asyncMain() !void {
var result: zml.Buffer = mnist.call(.{input});
defer result.deinit();
std.debug.print("\n✅ RECOGNIZED DIGIT:\n", .{});
std.debug.print(" +-------------+\n", .{});
std.debug.print("{s}\n", .{digits[try result.getValue(u8)]});
std.debug.print(" +-------------+\n\n", .{});
log.info(
\\✅ RECOGNIZED DIGIT:
\\ +-------------+
\\{s}
\\ +-------------+
\\
, .{digits[try result.getValue(u8)]});
}
}
fn printDigit(digit: [28 * 28]u8) void {
var buffer: [28][30][2]u8 = undefined;
std.debug.print(" R E C O G N I Z I N G I N P U T I M A G E :\n", .{});
std.debug.print("+---------------------------------------------------------+\n", .{});
for (0..28) |y| {
buffer[y][0] = .{ '|', ' ' };
buffer[y][29] = .{ '|', '\n' };
@ -168,8 +130,14 @@ fn printDigit(digit: [28 * 28]u8) void {
};
}
}
std.fmt.format(asynk.StdOut().writer(), "{s}", .{std.mem.asBytes(&buffer)}) catch unreachable;
std.debug.print("+---------------------------------------------------------+\n", .{});
log.info(
\\
\\ R E C O G N I Z I N G I N P U T I M A G E :
\\+---------------------------------------------------------+
\\{s}+---------------------------------------------------------+
\\
, .{std.mem.asBytes(&buffer)});
}
const digits = [_][]const u8{
@ -256,10 +224,6 @@ const digits = [_][]const u8{
};
pub const std_options = .{
// Set the global log level to err
.log_level = .err,
.log_scope_levels = &[_]std.log.ScopeLevel{
.{ .scope = .pjrt, .level = .err },
.{ .scope = .zml_module, .level = if (show_mlir) .debug else .err },
},
.logFn = asynk.logFn,
.log_level = .info,
};

View File

@ -17,9 +17,7 @@ const Layer = struct {
};
pub fn main() !void {
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
defer _ = gpa.deinit();
try asynk.AsyncThread.main(gpa.allocator(), asyncMain, .{});
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
}
pub fn asyncMain() !void {
@ -37,6 +35,7 @@ pub fn asyncMain() !void {
defer context.deinit();
const platform = context.autoPlatform();
context.printAvailablePlatforms(platform);
// Our weights and bias to use
var weights = [4]f16{ 2.0, 2.0, 2.0, 2.0 };