Add new Zig example programs (benchmark, llama, loader, mnist, simple_layer) and include a test for the llama example.
This commit is contained in:
parent
9b7eea8ac2
commit
7985716562
File diff suppressed because it is too large
Load Diff
@ -17,9 +17,7 @@ const Benchmark = struct {
|
||||
};
|
||||
|
||||
pub fn main() !void {
|
||||
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
|
||||
defer _ = gpa.deinit();
|
||||
try asynk.AsyncThread.main(gpa.allocator(), asyncMain, .{});
|
||||
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
|
||||
}
|
||||
|
||||
pub fn asyncMain() !void {
|
||||
@ -48,35 +46,7 @@ pub fn asyncMain() !void {
|
||||
const platform = context.autoPlatform().withCompilationOptions(.{
|
||||
.sharding_enabled = true,
|
||||
});
|
||||
{
|
||||
// List available targets
|
||||
std.debug.print("Available Platforms:\n", .{});
|
||||
const selected_prefix = "✅";
|
||||
const not_selected_prefix = "• ";
|
||||
const selected_postfix = "(AUTO-SELECTED)\n";
|
||||
const not_selected_postfix = "\n";
|
||||
for (zml.platform.available_targets) |target| {
|
||||
std.debug.print(" {s} {s} {s}", .{
|
||||
if (target == platform.target) selected_prefix else not_selected_prefix,
|
||||
@tagName(target),
|
||||
if (target == platform.target) selected_postfix else not_selected_postfix,
|
||||
});
|
||||
|
||||
// now the platform's devices
|
||||
if (context.platforms.get(target)) |pfm| {
|
||||
for (pfm.getDevices(), 0..) |device, index| {
|
||||
const deviceKind = device.getDescription(platform.pjrt_api).getKind(platform.pjrt_api);
|
||||
std.debug.print(" ◦ #{d}: {s}\n", .{
|
||||
index,
|
||||
deviceKind,
|
||||
});
|
||||
// we only list 1 CPU device
|
||||
if (target == .cpu and platform.sharding().num_partitions == 1) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
std.debug.print("\n", .{});
|
||||
}
|
||||
context.printAvailablePlatforms(platform);
|
||||
|
||||
var args = std.process.args();
|
||||
const cli_args = flags.parse(&args, CliArgs);
|
||||
|
||||
@ -13,6 +13,7 @@ zig_cc_binary(
|
||||
deps = [
|
||||
"//third_party/tigerbeetle:flags",
|
||||
"@zml//async",
|
||||
"@zml//stdx",
|
||||
"@zml//zml",
|
||||
],
|
||||
)
|
||||
@ -126,6 +127,7 @@ zig_cc_binary(
|
||||
deps = [
|
||||
"//third_party/tigerbeetle:flags",
|
||||
"@zml//async",
|
||||
"@zml//metax",
|
||||
"@zml//zml",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1,16 +1,16 @@
|
||||
const std = @import("std");
|
||||
const testing = std.testing;
|
||||
|
||||
const zml = @import("zml");
|
||||
const meta = zml.meta;
|
||||
const flags = @import("tigerbeetle/flags");
|
||||
const std = @import("std");
|
||||
const stdx = @import("stdx");
|
||||
const zml = @import("zml");
|
||||
|
||||
const log = std.log.scoped(.llama);
|
||||
const gguf = zml.io.gguf;
|
||||
const testing = std.testing;
|
||||
const Buffer = zml.Buffer;
|
||||
const Tensor = zml.Tensor;
|
||||
const ShapeOf = zml.ShapeOf;
|
||||
|
||||
const gguf = zml.io.gguf;
|
||||
const expectClose = zml.testing.expectClose;
|
||||
const log = std.log.scoped(.llama);
|
||||
|
||||
pub const LlamaOptions = struct {
|
||||
gen_opts: zml.nn.SamplingStrategy,
|
||||
@ -72,7 +72,7 @@ pub const LlamaLM = struct {
|
||||
kv_cache: ?KvCache,
|
||||
rng: Tensor.Rng,
|
||||
) struct { Tensor, Tensor, KvCache, Tensor.Rng } {
|
||||
meta.assert(tokens_.dtype() == .i32 and tokens_.rank() >= 1 and token_index.dtype() == .i32 and token_index.rank() == 0, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {} and {}", .{ tokens_, token_index });
|
||||
stdx.debug.assert(tokens_.dtype() == .i32 and tokens_.rank() >= 1 and token_index.dtype() == .i32 and token_index.rank() == 0, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {} and {}", .{ tokens_, token_index });
|
||||
|
||||
var tokens = tokens_.withPartialTags(.{.s});
|
||||
const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, if (kv_cache == null) null else token_index, kv_cache });
|
||||
@ -219,7 +219,7 @@ pub const TransformerLayer = struct {
|
||||
) struct { Tensor, KvCache } {
|
||||
// Self Attention
|
||||
//log.debug("TransformerLayer({}) -> {}", .{ x0, self.input_layernorm.forward(x0) });
|
||||
meta.assert(x0.rank() >= 2 and x0.shape().hasTags(.{ .s, .d }), "TransformerLayer expected input shape: {{..., .s, .d}}, received: {}", .{x0});
|
||||
stdx.debug.assert(x0.rank() >= 2 and x0.shape().hasTags(.{ .s, .d }), "TransformerLayer expected input shape: {{..., .s, .d}}, received: {}", .{x0});
|
||||
|
||||
const x0_normalized = zml.call(self.input_layernorm, .forward, .{x0});
|
||||
const delta0, const updated_kv_cache = zml.call(self.self_attn, .forward, .{ x0_normalized, token_index, kv_cache });
|
||||
@ -313,7 +313,7 @@ pub const SelfAttn = struct {
|
||||
|
||||
const new_kv_cache = kv_cache.update(k, v, token_index orelse Tensor.scalar(0, .i32));
|
||||
if (token_index) |_| {
|
||||
meta.assert(q.dim(.q) == 1, "Expected dimension .q to be 1, got {}", .{q.dim(.q)});
|
||||
stdx.debug.assert(q.dim(.q) == 1, "Expected dimension .q to be 1, got {}", .{q.dim(.q)});
|
||||
k = new_kv_cache.keys();
|
||||
v = new_kv_cache.values();
|
||||
}
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
const std = @import("std");
|
||||
|
||||
const zml = @import("zml");
|
||||
const meta = zml.meta;
|
||||
const asynk = @import("async");
|
||||
const flags = @import("tigerbeetle/flags");
|
||||
const std = @import("std");
|
||||
const stdx = @import("stdx");
|
||||
const zml = @import("zml");
|
||||
|
||||
const llama_mod = @import("llama.zig");
|
||||
|
||||
const LlamaLM = llama_mod.LlamaLM;
|
||||
@ -93,7 +93,7 @@ pub fn generateText(
|
||||
std.debug.print("{s}\n", .{output.items[n..]});
|
||||
const end = std.time.microTimestamp();
|
||||
|
||||
const duration = zml.meta.divFloat(f64, end - start, std.time.us_per_s);
|
||||
const duration = stdx.math.divFloor(f64, end - start, std.time.us_per_s);
|
||||
const speed = @as(f64, @floatFromInt(max_seq_len)) / duration;
|
||||
log.info("✅ Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ max_seq_len, duration, speed });
|
||||
|
||||
@ -106,7 +106,7 @@ pub fn generateText(
|
||||
}
|
||||
|
||||
pub fn main() !void {
|
||||
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain, .{});
|
||||
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
|
||||
}
|
||||
|
||||
pub fn asyncMain() !void {
|
||||
@ -131,9 +131,7 @@ pub fn asyncMain() !void {
|
||||
|
||||
log.info(" LLama was compiled with {}", .{@import("builtin").mode});
|
||||
|
||||
var gpa = std.heap.GeneralPurposeAllocator(.{ .thread_safe = true }){};
|
||||
defer _ = gpa.deinit();
|
||||
const allocator = gpa.allocator();
|
||||
const allocator = std.heap.c_allocator;
|
||||
|
||||
const tmp = try std.fs.openDirAbsolute("/tmp", .{});
|
||||
try tmp.makePath("zml/llama/cache");
|
||||
@ -147,35 +145,7 @@ pub fn asyncMain() !void {
|
||||
};
|
||||
|
||||
const platform = context.autoPlatform().withCompilationOptions(compilation_options);
|
||||
{
|
||||
// List available targets
|
||||
std.debug.print("\nSupported Platforms:\n", .{});
|
||||
const selected_prefix = "✅";
|
||||
const not_selected_prefix = "• ";
|
||||
const selected_postfix = "(AUTO-SELECTED)\n";
|
||||
const not_selected_postfix = "\n";
|
||||
for (zml.platform.available_targets) |target| {
|
||||
std.debug.print(" {s} {s} {s}", .{
|
||||
if (target == platform.target) selected_prefix else not_selected_prefix,
|
||||
@tagName(target),
|
||||
if (target == platform.target) selected_postfix else not_selected_postfix,
|
||||
});
|
||||
|
||||
// now the platform's devices
|
||||
if (context.platforms.get(target)) |pfm| {
|
||||
for (pfm.getDevices(), 0..) |device, index| {
|
||||
const deviceKind = device.getDescription(platform.pjrt_api).getKind(platform.pjrt_api);
|
||||
std.debug.print(" ◦ #{d}: {s}\n", .{
|
||||
index,
|
||||
deviceKind,
|
||||
});
|
||||
// we only list 1 CPU device
|
||||
if (target == .cpu) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
std.debug.print("\n", .{});
|
||||
}
|
||||
context.printAvailablePlatforms(platform);
|
||||
|
||||
var args = std.process.args();
|
||||
const cli_args = flags.parse(&args, CliArgs);
|
||||
@ -213,7 +183,7 @@ pub fn asyncMain() !void {
|
||||
.freq_base = @floatCast(ts.metadata("rope_freq_base", .float) orelse @as(f32, @floatFromInt(cli_args.rope_freq_base orelse 10_000))),
|
||||
},
|
||||
};
|
||||
log.info("✅ Parsed llama config: {}", .{llama_options});
|
||||
log.info("✅\tParsed llama config: {}", .{llama_options});
|
||||
llama.init(llama_options);
|
||||
|
||||
if (cli_args.tokenizer == null and !std.mem.endsWith(u8, cli_args.model, ".gguf")) {
|
||||
@ -221,9 +191,9 @@ pub fn asyncMain() !void {
|
||||
@panic("No tokenizer provided");
|
||||
}
|
||||
const tokenizer_path = cli_args.tokenizer orelse cli_args.model;
|
||||
log.info(" Loading tokenizer from {s}", .{tokenizer_path});
|
||||
log.info("\tLoading tokenizer from {s}", .{tokenizer_path});
|
||||
var tokenizer = try zml.aio.detectFormatAndLoadTokenizer(allocator, tokenizer_path);
|
||||
log.info("✅ Loaded tokenizer from {s}", .{tokenizer_path});
|
||||
log.info("✅\tLoaded tokenizer from {s}", .{tokenizer_path});
|
||||
defer tokenizer.deinit();
|
||||
|
||||
const dims = llama.model.shape();
|
||||
@ -238,25 +208,23 @@ pub fn asyncMain() !void {
|
||||
const kv_cache_shape: ?ShapeOf(KvCache) = KvCache.initShape(kv_shape);
|
||||
const rng_shape = Tensor.Rng.shape();
|
||||
|
||||
const compile_start = std.time.milliTimestamp();
|
||||
var fut_mod_prefill = try asynk.asyncGeneric(zml.compile, .{ allocator, LlamaLM, .{llama_options}, .forward, .{ tokens_shape, token_idx_shape, null, rng_shape }, ts, platform });
|
||||
var fut_mod = try asynk.asyncGeneric(zml.compile, .{ allocator, LlamaLM, .{llama_options}, .forward, .{ tokens_shape, token_idx_shape, kv_cache_shape, rng_shape }, ts, platform });
|
||||
var start = try std.time.Timer.start();
|
||||
var fut_mod_prefill = try asynk.asyncc(zml.compile, .{ allocator, LlamaLM.forward, .{llama_options}, .{ tokens_shape, token_idx_shape, null, rng_shape }, ts, platform });
|
||||
var fut_mod = try asynk.asyncc(zml.compile, .{ allocator, LlamaLM.forward, .{llama_options}, .{ tokens_shape, token_idx_shape, kv_cache_shape, rng_shape }, ts, platform });
|
||||
|
||||
log.info("Starting loading weights", .{});
|
||||
log.info("\tLoading Llama weights from {s}...", .{cli_args.model});
|
||||
var llama_weights = try zml.aio.loadBuffers(LlamaLM, .{llama_options}, ts, model_arena, platform);
|
||||
defer zml.aio.unloadBuffers(&llama_weights);
|
||||
log.info("✅ Done loading weights", .{});
|
||||
log.info("✅ Llama model loaded from {s}", .{cli_args.model});
|
||||
log.info("✅\tLoaded weights in {d}ms", .{start.read() / std.time.ns_per_ms});
|
||||
|
||||
var llama_module_prefill = try (try fut_mod_prefill.await_()).prepare(allocator, llama_weights);
|
||||
defer llama_module_prefill.deinit();
|
||||
var llama_module = try (try fut_mod.await_()).prepare(allocator, llama_weights);
|
||||
defer llama_module.deinit();
|
||||
const compile_end = std.time.milliTimestamp();
|
||||
log.info("✅ Compiled model in {d} milliseconds! \n", .{compile_end - compile_start});
|
||||
log.info("✅\tCompiled model in {d}ms", .{start.read() / std.time.ns_per_ms});
|
||||
|
||||
const prompt = cli_args.prompt orelse "Once upon a time, there was a little girl named Lily.";
|
||||
log.info("✅ Prompt: {s}\n", .{prompt});
|
||||
log.info("✅\tPrompt: {s}", .{prompt});
|
||||
|
||||
const seed = cli_args.seed orelse @as(u128, @bitCast(std.time.nanoTimestamp()));
|
||||
const story = try generateText(llama, llama_module_prefill, llama_module, tokenizer, allocator, seed, prompt);
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
const std = @import("std");
|
||||
|
||||
const zml = @import("zml");
|
||||
const asynk = @import("async");
|
||||
const flags = @import("tigerbeetle/flags");
|
||||
const std = @import("std");
|
||||
const stdx = @import("stdx");
|
||||
const zml = @import("zml");
|
||||
|
||||
const llama_mod = @import("./llama.zig");
|
||||
const LlamaLM = llama_mod.LlamaLM;
|
||||
|
||||
@ -1,13 +1,12 @@
|
||||
const std = @import("std");
|
||||
const stdx = @import("stdx");
|
||||
const zml = @import("zml");
|
||||
const asynk = @import("async");
|
||||
|
||||
const asyncc = asynk.asyncc;
|
||||
|
||||
pub fn main() !void {
|
||||
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
|
||||
defer _ = gpa.deinit();
|
||||
try asynk.AsyncThread.main(gpa.allocator(), asyncMain, .{});
|
||||
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
|
||||
}
|
||||
|
||||
pub fn asyncMain() !void {
|
||||
@ -36,11 +35,7 @@ pub fn asyncMain() !void {
|
||||
defer context.deinit();
|
||||
|
||||
const platform = context.autoPlatform();
|
||||
const devices = platform.getDevices();
|
||||
|
||||
for (devices) |device| {
|
||||
std.debug.print("Device visible: {s}\n", .{device.getDescription(platform.pjrt_api).debugString(platform.pjrt_api)});
|
||||
}
|
||||
context.printAvailablePlatforms(platform);
|
||||
|
||||
var buffers = try gpa.allocator().alloc(zml.Buffer, buffer_store.buffers.count());
|
||||
defer {
|
||||
@ -65,8 +60,8 @@ pub fn asyncMain() !void {
|
||||
}
|
||||
|
||||
const stop = timer.read();
|
||||
const time_in_s = zml.meta.divFloat(f64, stop, std.time.ns_per_s);
|
||||
const mbs = zml.meta.divFloat(f64, total_bytes, 1024 * 1024);
|
||||
const time_in_s = stdx.math.divFloor(f64, stop, std.time.ns_per_s);
|
||||
const mbs = stdx.math.divFloor(f64, total_bytes, 1024 * 1024);
|
||||
|
||||
std.debug.print("\nLoading speed: {d:.2} MB/s\n\n", .{mbs / time_in_s});
|
||||
}
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
const asynk = @import("async");
|
||||
const std = @import("std");
|
||||
const zml = @import("zml");
|
||||
const asynk = @import("async");
|
||||
|
||||
const show_mlir = true;
|
||||
|
||||
const log = std.log.scoped(.mnist);
|
||||
|
||||
/// Model definition
|
||||
const Mnist = struct {
|
||||
fc1: Layer,
|
||||
@ -31,56 +33,21 @@ const Mnist = struct {
|
||||
};
|
||||
|
||||
pub fn main() !void {
|
||||
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
|
||||
defer _ = gpa.deinit();
|
||||
const allocator = gpa.allocator();
|
||||
|
||||
try asynk.AsyncThread.main(allocator, asyncMain, .{});
|
||||
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
|
||||
}
|
||||
|
||||
pub fn asyncMain() !void {
|
||||
// Short lived allocations
|
||||
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
|
||||
defer _ = gpa.deinit();
|
||||
const allocator = gpa.allocator();
|
||||
const allocator = std.heap.c_allocator;
|
||||
|
||||
// Create ZML context
|
||||
// // Create ZML context
|
||||
var context = try zml.Context.init();
|
||||
defer context.deinit();
|
||||
|
||||
std.debug.print("\n===========================\n== ZML MNIST Example ==\n===========================\n\n", .{});
|
||||
// log.info("\n===========================\n== ZML MNIST Example ==\n===========================\n\n", .{});
|
||||
|
||||
// Auto-select platform
|
||||
// // Auto-select platform
|
||||
const platform = context.autoPlatform();
|
||||
{
|
||||
// List available targets
|
||||
std.debug.print("Available Platforms:\n", .{});
|
||||
const selected_prefix = "✅";
|
||||
const not_selected_prefix = "• ";
|
||||
const selected_postfix = "(AUTO-SELECTED)\n";
|
||||
const not_selected_postfix = "\n";
|
||||
for (zml.platform.available_targets) |target| {
|
||||
std.debug.print(" {s} {s} {s}", .{
|
||||
if (target == platform.target) selected_prefix else not_selected_prefix,
|
||||
@tagName(target),
|
||||
if (target == platform.target) selected_postfix else not_selected_postfix,
|
||||
});
|
||||
|
||||
// now the platform's devices
|
||||
if (context.platforms.get(target)) |pfm| {
|
||||
for (pfm.getDevices(), 0..) |device, index| {
|
||||
const deviceKind = device.getDescription(platform.pjrt_api).getKind(platform.pjrt_api);
|
||||
std.debug.print(" ◦ #{d}: {s}\n", .{
|
||||
index,
|
||||
deviceKind,
|
||||
});
|
||||
// we only list 1 CPU device
|
||||
if (target == .cpu) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
std.debug.print("\n", .{});
|
||||
}
|
||||
context.printAvailablePlatforms(platform);
|
||||
|
||||
// Parse program args
|
||||
const process_args = try std.process.argsAlloc(allocator);
|
||||
@ -99,32 +66,26 @@ pub fn asyncMain() !void {
|
||||
defer buffer_store.deinit();
|
||||
|
||||
const mnist_model = try zml.aio.populateModel(Mnist, allocator, buffer_store);
|
||||
std.debug.print("✅ Read model shapes from PyTorch file {s}\n", .{pt_model});
|
||||
log.info("Reading model shapes from PyTorch file {s}...", .{pt_model});
|
||||
|
||||
// Start loading weights
|
||||
// Start compiling
|
||||
log.info("Compiling model to MLIR....", .{});
|
||||
var start_time = try std.time.Timer.start();
|
||||
var compilation = try asynk.asyncc(zml.compile, .{ allocator, Mnist.forward, .{}, .{zml.Shape.init(.{ 28, 28 }, .u8)}, buffer_store, platform });
|
||||
|
||||
// While compiling, start loading weights on the platform
|
||||
var model_weights = try zml.aio.loadModelBuffers(Mnist, mnist_model, buffer_store, arena, platform);
|
||||
defer zml.aio.unloadBuffers(&model_weights);
|
||||
|
||||
// Start compiling
|
||||
const comp_start_time = std.time.milliTimestamp();
|
||||
if (show_mlir) {
|
||||
std.debug.print("\nCompiling model to MLIR....\n", .{});
|
||||
std.debug.print("-" ** 160 ++ "\n", .{});
|
||||
} else {
|
||||
std.debug.print("Compiling model to MLIR....\r", .{});
|
||||
}
|
||||
var compilation = try asynk.asyncGeneric(zml.compile, .{ allocator, Mnist, .{}, .forward, .{zml.Shape.init(.{ 28, 28 }, .u8)}, buffer_store, platform });
|
||||
|
||||
// Wait for end of compilation and end of weights loading.
|
||||
const compiled_mnist = try compilation.await_();
|
||||
const comp_end_time = std.time.milliTimestamp();
|
||||
if (show_mlir) std.debug.print("-" ** 160 ++ "\n", .{});
|
||||
std.debug.print("✅ Compiled MNIST model in {d} milliseconds! \n", .{comp_end_time - comp_start_time});
|
||||
const compiled_mnist = try compilation.wait();
|
||||
log.info("✅ Compiled model in {d}ms", .{start_time.read() / std.time.ns_per_ms});
|
||||
|
||||
// send weights to accelerator / GPU
|
||||
var mnist = try compiled_mnist.prepare(allocator, model_weights);
|
||||
defer mnist.deinit();
|
||||
std.debug.print("✅ Weights transferred, starting inference...\n\n", .{});
|
||||
log.info("✅ Weights transferred in {d}ms", .{start_time.read() / std.time.ns_per_ms});
|
||||
|
||||
log.info("Starting inference...", .{});
|
||||
|
||||
// Load a random digit image from the dataset.
|
||||
const dataset = try asynk.File.open(t10kfilename, .{ .mode = .read_only });
|
||||
@ -143,17 +104,18 @@ pub fn asyncMain() !void {
|
||||
var result: zml.Buffer = mnist.call(.{input});
|
||||
defer result.deinit();
|
||||
|
||||
std.debug.print("\n✅ RECOGNIZED DIGIT:\n", .{});
|
||||
std.debug.print(" +-------------+\n", .{});
|
||||
std.debug.print("{s}\n", .{digits[try result.getValue(u8)]});
|
||||
std.debug.print(" +-------------+\n\n", .{});
|
||||
log.info(
|
||||
\\✅ RECOGNIZED DIGIT:
|
||||
\\ +-------------+
|
||||
\\{s}
|
||||
\\ +-------------+
|
||||
\\
|
||||
, .{digits[try result.getValue(u8)]});
|
||||
}
|
||||
}
|
||||
|
||||
fn printDigit(digit: [28 * 28]u8) void {
|
||||
var buffer: [28][30][2]u8 = undefined;
|
||||
std.debug.print(" R E C O G N I Z I N G I N P U T I M A G E :\n", .{});
|
||||
std.debug.print("+---------------------------------------------------------+\n", .{});
|
||||
for (0..28) |y| {
|
||||
buffer[y][0] = .{ '|', ' ' };
|
||||
buffer[y][29] = .{ '|', '\n' };
|
||||
@ -168,8 +130,14 @@ fn printDigit(digit: [28 * 28]u8) void {
|
||||
};
|
||||
}
|
||||
}
|
||||
std.fmt.format(asynk.StdOut().writer(), "{s}", .{std.mem.asBytes(&buffer)}) catch unreachable;
|
||||
std.debug.print("+---------------------------------------------------------+\n", .{});
|
||||
|
||||
log.info(
|
||||
\\
|
||||
\\ R E C O G N I Z I N G I N P U T I M A G E :
|
||||
\\+---------------------------------------------------------+
|
||||
\\{s}+---------------------------------------------------------+
|
||||
\\
|
||||
, .{std.mem.asBytes(&buffer)});
|
||||
}
|
||||
|
||||
const digits = [_][]const u8{
|
||||
@ -256,10 +224,6 @@ const digits = [_][]const u8{
|
||||
};
|
||||
|
||||
pub const std_options = .{
|
||||
// Set the global log level to err
|
||||
.log_level = .err,
|
||||
.log_scope_levels = &[_]std.log.ScopeLevel{
|
||||
.{ .scope = .pjrt, .level = .err },
|
||||
.{ .scope = .zml_module, .level = if (show_mlir) .debug else .err },
|
||||
},
|
||||
.logFn = asynk.logFn,
|
||||
.log_level = .info,
|
||||
};
|
||||
|
||||
@ -17,9 +17,7 @@ const Layer = struct {
|
||||
};
|
||||
|
||||
pub fn main() !void {
|
||||
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
|
||||
defer _ = gpa.deinit();
|
||||
try asynk.AsyncThread.main(gpa.allocator(), asyncMain, .{});
|
||||
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
|
||||
}
|
||||
|
||||
pub fn asyncMain() !void {
|
||||
@ -37,6 +35,7 @@ pub fn asyncMain() !void {
|
||||
defer context.deinit();
|
||||
|
||||
const platform = context.autoPlatform();
|
||||
context.printAvailablePlatforms(platform);
|
||||
|
||||
// Our weights and bias to use
|
||||
var weights = [4]f16{ 2.0, 2.0, 2.0, 2.0 };
|
||||
|
||||
Loading…
Reference in New Issue
Block a user