Update example Zig code for llama3 rope scaling and modernbert usage.
This commit is contained in:
parent
d0cf5d3042
commit
9f61a8aacb
@ -1,14 +1,12 @@
|
||||
const std = @import("std");
|
||||
const testing = std.testing;
|
||||
|
||||
const stdx = @import("stdx");
|
||||
const zml = @import("zml");
|
||||
|
||||
const testing = std.testing;
|
||||
const Buffer = zml.Buffer;
|
||||
const Tensor = zml.Tensor;
|
||||
const ShapeOf = zml.ShapeOf;
|
||||
|
||||
const gguf = zml.io.gguf;
|
||||
const expectClose = zml.testing.expectClose;
|
||||
const log = std.log.scoped(.llama);
|
||||
|
||||
/// Llama architecture, using huggingface transformers naming.
|
||||
@ -27,6 +25,7 @@ pub const LlamaLM = struct {
|
||||
max_position_embeddings: usize,
|
||||
rms_norm_eps: f32,
|
||||
hf_rope_impl: bool = true,
|
||||
rope_scaling: zml.nn.RopeOpts.Scaling = .{ .default = {} },
|
||||
};
|
||||
|
||||
pub const Options = struct {
|
||||
@ -48,8 +47,9 @@ pub const LlamaLM = struct {
|
||||
self.model.num_heads = @intCast(config.num_attention_heads);
|
||||
self.model.num_kv_heads = @intCast(config.num_key_value_heads);
|
||||
self.model.rope_opts = .{
|
||||
.impl = if (config.hf_rope_impl) .sequential else .interleaved,
|
||||
.layout = if (config.hf_rope_impl) .sequential else .interleaved,
|
||||
.freq_base = config.rope_theta,
|
||||
.scaling = config.rope_scaling,
|
||||
};
|
||||
for (self.model.layers) |*layer| {
|
||||
layer.self_attn.num_heads = self.model.num_heads;
|
||||
@ -131,7 +131,7 @@ pub const Llama = struct {
|
||||
num_heads: i64 = 32,
|
||||
num_kv_heads: i64 = 32,
|
||||
rope_opts: zml.nn.RopeOpts = .{
|
||||
.impl = .interleaved,
|
||||
.layout = .interleaved,
|
||||
.freq_base = 10_000,
|
||||
},
|
||||
|
||||
@ -221,12 +221,7 @@ const RmsNorm = struct {
|
||||
/// L2 normalization of input tensor along `.d` axis.
|
||||
pub fn forward(self: RmsNorm, input: Tensor) Tensor {
|
||||
const x = if (input.shape().isFullyTagged()) input else input.withPartialTags(.{.d});
|
||||
// upcast to improve precision
|
||||
const xf32 = x.convert(.f32);
|
||||
const mean = xf32.mul(xf32).mean(.d);
|
||||
const rsqrt = Tensor.rsqrt(mean.addConstant(self.eps)).convert(x.dtype());
|
||||
const normalized = x.mul(rsqrt.broad(x.shape()));
|
||||
|
||||
const normalized = zml.nn.rmsNorm(x, .d, self.eps);
|
||||
return normalized.mul(self.weight.convert(x.dtype()).withTags(.{.d}).broad(x.shape()));
|
||||
}
|
||||
};
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
const std = @import("std");
|
||||
const log = std.log.scoped(.modernbert);
|
||||
|
||||
const asynk = @import("async");
|
||||
const stdx = @import("stdx");
|
||||
const zml = @import("zml");
|
||||
|
||||
const Tensor = zml.Tensor;
|
||||
|
||||
const log = std.log.scoped(.modernbert);
|
||||
|
||||
pub const ModernBertOptions = struct {
|
||||
num_attention_heads: i64,
|
||||
pad_token: u32,
|
||||
@ -222,7 +222,7 @@ pub const ModernBertAttention = struct {
|
||||
// Layer 0, 3, 6, 9, 12 ... use global RoPE
|
||||
// Layer 1, 2, 4, 5, 7, 8, 10, 11 ... use local RoPE
|
||||
const rope_opts = zml.nn.RopeOpts{
|
||||
.impl = .sequential,
|
||||
.layout = .sequential,
|
||||
.freq_base = if (self.is_global_attention) 160_000 else 10_000,
|
||||
};
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user