From 9f61a8aacb49e8d43ab78fcf6b867278f8b924f2 Mon Sep 17 00:00:00 2001 From: Foke Singh Date: Fri, 4 Oct 2024 17:49:07 +0000 Subject: [PATCH] Update example Zig code for llama3 rope scaling and modernbert usage. --- examples/llama/llama.zig | 19 +++++++------------ examples/modernbert/modernbert.zig | 6 +++--- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/examples/llama/llama.zig b/examples/llama/llama.zig index 9663035..8314a30 100644 --- a/examples/llama/llama.zig +++ b/examples/llama/llama.zig @@ -1,14 +1,12 @@ const std = @import("std"); +const testing = std.testing; + const stdx = @import("stdx"); const zml = @import("zml"); - -const testing = std.testing; const Buffer = zml.Buffer; const Tensor = zml.Tensor; const ShapeOf = zml.ShapeOf; -const gguf = zml.io.gguf; -const expectClose = zml.testing.expectClose; const log = std.log.scoped(.llama); /// Llama architecture, using huggingface transformers naming. @@ -27,6 +25,7 @@ pub const LlamaLM = struct { max_position_embeddings: usize, rms_norm_eps: f32, hf_rope_impl: bool = true, + rope_scaling: zml.nn.RopeOpts.Scaling = .{ .default = {} }, }; pub const Options = struct { @@ -48,8 +47,9 @@ pub const LlamaLM = struct { self.model.num_heads = @intCast(config.num_attention_heads); self.model.num_kv_heads = @intCast(config.num_key_value_heads); self.model.rope_opts = .{ - .impl = if (config.hf_rope_impl) .sequential else .interleaved, + .layout = if (config.hf_rope_impl) .sequential else .interleaved, .freq_base = config.rope_theta, + .scaling = config.rope_scaling, }; for (self.model.layers) |*layer| { layer.self_attn.num_heads = self.model.num_heads; @@ -131,7 +131,7 @@ pub const Llama = struct { num_heads: i64 = 32, num_kv_heads: i64 = 32, rope_opts: zml.nn.RopeOpts = .{ - .impl = .interleaved, + .layout = .interleaved, .freq_base = 10_000, }, @@ -221,12 +221,7 @@ const RmsNorm = struct { /// L2 normalization of input tensor along `.d` axis. pub fn forward(self: RmsNorm, input: Tensor) Tensor { const x = if (input.shape().isFullyTagged()) input else input.withPartialTags(.{.d}); - // upcast to improve precision - const xf32 = x.convert(.f32); - const mean = xf32.mul(xf32).mean(.d); - const rsqrt = Tensor.rsqrt(mean.addConstant(self.eps)).convert(x.dtype()); - const normalized = x.mul(rsqrt.broad(x.shape())); - + const normalized = zml.nn.rmsNorm(x, .d, self.eps); return normalized.mul(self.weight.convert(x.dtype()).withTags(.{.d}).broad(x.shape())); } }; diff --git a/examples/modernbert/modernbert.zig b/examples/modernbert/modernbert.zig index 4a1182b..831610c 100644 --- a/examples/modernbert/modernbert.zig +++ b/examples/modernbert/modernbert.zig @@ -1,12 +1,12 @@ const std = @import("std"); -const log = std.log.scoped(.modernbert); const asynk = @import("async"); const stdx = @import("stdx"); const zml = @import("zml"); - const Tensor = zml.Tensor; +const log = std.log.scoped(.modernbert); + pub const ModernBertOptions = struct { num_attention_heads: i64, pad_token: u32, @@ -222,7 +222,7 @@ pub const ModernBertAttention = struct { // Layer 0, 3, 6, 9, 12 ... use global RoPE // Layer 1, 2, 4, 5, 7, 8, 10, 11 ... use local RoPE const rope_opts = zml.nn.RopeOpts{ - .impl = .sequential, + .layout = .sequential, .freq_base = if (self.is_global_attention) 160_000 else 10_000, };