diff --git a/examples/llama/main.zig b/examples/llama/main.zig index 2577f9b..43a3d8f 100644 --- a/examples/llama/main.zig +++ b/examples/llama/main.zig @@ -192,8 +192,8 @@ pub fn asyncMain() !void { defer ts.deinit(); var llama = try zml.aio.populateModel(LlamaLM, model_arena, ts); - const num_heads: i64 = cli_args.num_heads orelse ts.metadata("num_heads", .int64) orelse @panic("--num_heads is required for this model"); - const num_kv_heads: i64 = cli_args.num_kv_heads orelse ts.metadata("num_kv_heads", .int64) orelse num_heads; + const num_heads = cli_args.num_heads orelse ts.metadata("num_heads", .int) orelse @panic("--num_heads is required for this model"); + const num_kv_heads = cli_args.num_kv_heads orelse ts.metadata("num_kv_heads", .int) orelse num_heads; const rope_impl = if (ts.metadata("rope_impl", .string)) |val| std.meta.stringToEnum(zml.nn.RopeOpts.Implementation, val).? @@ -208,10 +208,10 @@ pub fn asyncMain() !void { .topk = cli_args.topk, .temperature = @floatFromInt(cli_args.temperature), }, - .rms_norm_eps = @floatCast(ts.metadata("rms_norm_eps", .float64) orelse 1e-5), + .rms_norm_eps = @floatCast(ts.metadata("rms_norm_eps", .float) orelse 1e-5), .rope_opts = .{ .impl = rope_impl, - .freq_base = @floatCast(ts.metadata("rope_freq_base", .float64) orelse @as(f32, @floatFromInt(cli_args.rope_freq_base orelse 10_000))), + .freq_base = @floatCast(ts.metadata("rope_freq_base", .float) orelse @as(f32, @floatFromInt(cli_args.rope_freq_base orelse 10_000))), }, }; log.info("✅ Parsed llama config: {}", .{llama_options});