Add Llama example showcasing the new func.call emission and function caching behavior.
This commit is contained in:
parent
7d36913b31
commit
37de7b9613
@ -180,10 +180,8 @@ pub const Llama = struct {
|
|||||||
var updated_kv_cache = kv_cache0;
|
var updated_kv_cache = kv_cache0;
|
||||||
for (self.layers, 0..) |layer, i| {
|
for (self.layers, 0..) |layer, i| {
|
||||||
hidden, updated_kv_cache = zml.call(layer, .forward, .{ hidden, token_index, updated_kv_cache.atLayer(i) });
|
hidden, updated_kv_cache = zml.call(layer, .forward, .{ hidden, token_index, updated_kv_cache.atLayer(i) });
|
||||||
hidden = hidden.withPartialTags(.{ .s, .d });
|
|
||||||
}
|
}
|
||||||
// TODO: tags seem to be lost by `callFunc`.
|
const output = zml.call(self.norm, .forward, .{hidden});
|
||||||
const output = zml.call(self.norm, .forward, .{hidden.withPartialTags(.{ .s, .d })});
|
|
||||||
|
|
||||||
return .{ output, updated_kv_cache.reuseBuffer(kv_cache0) };
|
return .{ output, updated_kv_cache.reuseBuffer(kv_cache0) };
|
||||||
}
|
}
|
||||||
|
|||||||
@ -21,12 +21,10 @@ const log = std.log.scoped(.llama);
|
|||||||
const show_mlir = true;
|
const show_mlir = true;
|
||||||
|
|
||||||
pub const std_options = .{
|
pub const std_options = .{
|
||||||
.log_level = .err,
|
.log_level = .warn,
|
||||||
.log_scope_levels = &[_]std.log.ScopeLevel{
|
.log_scope_levels = &[_]std.log.ScopeLevel{
|
||||||
.{ .scope = .pjrt, .level = if (show_mlir) .debug else .err },
|
.{ .scope = .zml_module, .level = if (show_mlir) .debug else .warn },
|
||||||
.{ .scope = .zml_module, .level = if (show_mlir) .debug else .err },
|
.{ .scope = .llama, .level = .info },
|
||||||
.{ .scope = .zml, .level = if (show_mlir) .debug else .err },
|
|
||||||
.{ .scope = .llama, .level = if (show_mlir) .debug else .info },
|
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user