Update Llama example to use renamed zml.aio.Metadata (formerly Value) and reflect torch loader changes.
This commit is contained in:
parent
e25f70d923
commit
aea23c720e
@ -192,8 +192,8 @@ pub fn asyncMain() !void {
|
||||
defer ts.deinit();
|
||||
|
||||
var llama = try zml.aio.populateModel(LlamaLM, model_arena, ts);
|
||||
const num_heads: i64 = cli_args.num_heads orelse ts.metadata("num_heads", .int64) orelse @panic("--num_heads is required for this model");
|
||||
const num_kv_heads: i64 = cli_args.num_kv_heads orelse ts.metadata("num_kv_heads", .int64) orelse num_heads;
|
||||
const num_heads = cli_args.num_heads orelse ts.metadata("num_heads", .int) orelse @panic("--num_heads is required for this model");
|
||||
const num_kv_heads = cli_args.num_kv_heads orelse ts.metadata("num_kv_heads", .int) orelse num_heads;
|
||||
|
||||
const rope_impl = if (ts.metadata("rope_impl", .string)) |val|
|
||||
std.meta.stringToEnum(zml.nn.RopeOpts.Implementation, val).?
|
||||
@ -208,10 +208,10 @@ pub fn asyncMain() !void {
|
||||
.topk = cli_args.topk,
|
||||
.temperature = @floatFromInt(cli_args.temperature),
|
||||
},
|
||||
.rms_norm_eps = @floatCast(ts.metadata("rms_norm_eps", .float64) orelse 1e-5),
|
||||
.rms_norm_eps = @floatCast(ts.metadata("rms_norm_eps", .float) orelse 1e-5),
|
||||
.rope_opts = .{
|
||||
.impl = rope_impl,
|
||||
.freq_base = @floatCast(ts.metadata("rope_freq_base", .float64) orelse @as(f32, @floatFromInt(cli_args.rope_freq_base orelse 10_000))),
|
||||
.freq_base = @floatCast(ts.metadata("rope_freq_base", .float) orelse @as(f32, @floatFromInt(cli_args.rope_freq_base orelse 10_000))),
|
||||
},
|
||||
};
|
||||
log.info("✅ Parsed llama config: {}", .{llama_options});
|
||||
|
||||
Loading…
Reference in New Issue
Block a user