Add Llama example showcasing the new func.call emission and function caching behavior.
This commit is contained in:
parent
7d36913b31
commit
37de7b9613
@ -180,10 +180,8 @@ pub const Llama = struct {
|
||||
var updated_kv_cache = kv_cache0;
|
||||
for (self.layers, 0..) |layer, i| {
|
||||
hidden, updated_kv_cache = zml.call(layer, .forward, .{ hidden, token_index, updated_kv_cache.atLayer(i) });
|
||||
hidden = hidden.withPartialTags(.{ .s, .d });
|
||||
}
|
||||
// TODO: tags seem to be lost by `callFunc`.
|
||||
const output = zml.call(self.norm, .forward, .{hidden.withPartialTags(.{ .s, .d })});
|
||||
const output = zml.call(self.norm, .forward, .{hidden});
|
||||
|
||||
return .{ output, updated_kv_cache.reuseBuffer(kv_cache0) };
|
||||
}
|
||||
|
||||
@ -21,12 +21,10 @@ const log = std.log.scoped(.llama);
|
||||
const show_mlir = true;
|
||||
|
||||
pub const std_options = .{
|
||||
.log_level = .err,
|
||||
.log_level = .warn,
|
||||
.log_scope_levels = &[_]std.log.ScopeLevel{
|
||||
.{ .scope = .pjrt, .level = if (show_mlir) .debug else .err },
|
||||
.{ .scope = .zml_module, .level = if (show_mlir) .debug else .err },
|
||||
.{ .scope = .zml, .level = if (show_mlir) .debug else .err },
|
||||
.{ .scope = .llama, .level = if (show_mlir) .debug else .info },
|
||||
.{ .scope = .zml_module, .level = if (show_mlir) .debug else .warn },
|
||||
.{ .scope = .llama, .level = .info },
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user