Add Llama example showcasing the new func.call emission and function caching behavior.

This commit is contained in:
Foke Singh 2023-10-17 11:00:37 +00:00
parent 7d36913b31
commit 37de7b9613
2 changed files with 4 additions and 8 deletions

View File

@@ -180,10 +180,8 @@ pub const Llama = struct {
var updated_kv_cache = kv_cache0; var updated_kv_cache = kv_cache0;
for (self.layers, 0..) |layer, i| { for (self.layers, 0..) |layer, i| {
hidden, updated_kv_cache = zml.call(layer, .forward, .{ hidden, token_index, updated_kv_cache.atLayer(i) }); hidden, updated_kv_cache = zml.call(layer, .forward, .{ hidden, token_index, updated_kv_cache.atLayer(i) });
hidden = hidden.withPartialTags(.{ .s, .d });
} }
// TODO: tags seem to be lost by `callFunc`. const output = zml.call(self.norm, .forward, .{hidden});
const output = zml.call(self.norm, .forward, .{hidden.withPartialTags(.{ .s, .d })});
return .{ output, updated_kv_cache.reuseBuffer(kv_cache0) }; return .{ output, updated_kv_cache.reuseBuffer(kv_cache0) };
} }

View File

@@ -21,12 +21,10 @@ const log = std.log.scoped(.llama);
const show_mlir = true; const show_mlir = true;
pub const std_options = .{ pub const std_options = .{
.log_level = .err, .log_level = .warn,
.log_scope_levels = &[_]std.log.ScopeLevel{ .log_scope_levels = &[_]std.log.ScopeLevel{
.{ .scope = .pjrt, .level = if (show_mlir) .debug else .err }, .{ .scope = .zml_module, .level = if (show_mlir) .debug else .warn },
.{ .scope = .zml_module, .level = if (show_mlir) .debug else .err }, .{ .scope = .llama, .level = .info },
.{ .scope = .zml, .level = if (show_mlir) .debug else .err },
.{ .scope = .llama, .level = if (show_mlir) .debug else .info },
}, },
}; };