From 37de7b961344af946f37cb186ce4b11fb6668b64 Mon Sep 17 00:00:00 2001
From: Foke Singh
Date: Tue, 17 Oct 2023 11:00:37 +0000
Subject: [PATCH] Add Llama example showcasing the new `func.call` emission and function caching behavior.

---
 examples/llama/llama.zig | 4 +---
 examples/llama/main.zig  | 8 +++-----
 2 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/examples/llama/llama.zig b/examples/llama/llama.zig
index b28b86b..6fe74cc 100644
--- a/examples/llama/llama.zig
+++ b/examples/llama/llama.zig
@@ -180,10 +180,8 @@ pub const Llama = struct {
         var updated_kv_cache = kv_cache0;
         for (self.layers, 0..) |layer, i| {
             hidden, updated_kv_cache = zml.call(layer, .forward, .{ hidden, token_index, updated_kv_cache.atLayer(i) });
-            hidden = hidden.withPartialTags(.{ .s, .d });
         }
-        // TODO: tags seem to be lost by `callFunc`.
-        const output = zml.call(self.norm, .forward, .{hidden.withPartialTags(.{ .s, .d })});
+        const output = zml.call(self.norm, .forward, .{hidden});
         return .{ output, updated_kv_cache.reuseBuffer(kv_cache0) };
     }
 
diff --git a/examples/llama/main.zig b/examples/llama/main.zig
index 1442a7c..3909b46 100644
--- a/examples/llama/main.zig
+++ b/examples/llama/main.zig
@@ -21,12 +21,10 @@ const log = std.log.scoped(.llama);
 const show_mlir = true;
 
 pub const std_options = .{
-    .log_level = .err,
+    .log_level = .warn,
     .log_scope_levels = &[_]std.log.ScopeLevel{
-        .{ .scope = .pjrt, .level = if (show_mlir) .debug else .err },
-        .{ .scope = .zml_module, .level = if (show_mlir) .debug else .err },
-        .{ .scope = .zml, .level = if (show_mlir) .debug else .err },
-        .{ .scope = .llama, .level = if (show_mlir) .debug else .info },
+        .{ .scope = .zml_module, .level = if (show_mlir) .debug else .warn },
+        .{ .scope = .llama, .level = .info },
     },
 };
 