Update Llama example to use renamed zml.aio.Metadata (formerly Value) and reflect torch loader changes.

This commit is contained in:
Foke Singh 2023-04-05 14:09:59 +00:00
parent e25f70d923
commit aea23c720e

View File

@@ -192,8 +192,8 @@ pub fn asyncMain() !void {
defer ts.deinit();
var llama = try zml.aio.populateModel(LlamaLM, model_arena, ts);
-const num_heads: i64 = cli_args.num_heads orelse ts.metadata("num_heads", .int64) orelse @panic("--num_heads is required for this model");
-const num_kv_heads: i64 = cli_args.num_kv_heads orelse ts.metadata("num_kv_heads", .int64) orelse num_heads;
+const num_heads = cli_args.num_heads orelse ts.metadata("num_heads", .int) orelse @panic("--num_heads is required for this model");
+const num_kv_heads = cli_args.num_kv_heads orelse ts.metadata("num_kv_heads", .int) orelse num_heads;
const rope_impl = if (ts.metadata("rope_impl", .string)) |val|
std.meta.stringToEnum(zml.nn.RopeOpts.Implementation, val).?
@@ -208,10 +208,10 @@ pub fn asyncMain() !void {
.topk = cli_args.topk,
.temperature = @floatFromInt(cli_args.temperature),
},
-.rms_norm_eps = @floatCast(ts.metadata("rms_norm_eps", .float64) orelse 1e-5),
+.rms_norm_eps = @floatCast(ts.metadata("rms_norm_eps", .float) orelse 1e-5),
.rope_opts = .{
.impl = rope_impl,
-.freq_base = @floatCast(ts.metadata("rope_freq_base", .float64) orelse @as(f32, @floatFromInt(cli_args.rope_freq_base orelse 10_000))),
+.freq_base = @floatCast(ts.metadata("rope_freq_base", .float) orelse @as(f32, @floatFromInt(cli_args.rope_freq_base orelse 10_000))),
},
};
log.info("✅ Parsed llama config: {}", .{llama_options});