Fix Llama3 rope scaling implementation in the neural network module (zml/nn.zig)
This commit is contained in:
parent
9f61a8aacb
commit
aacbf2ee04
175
zml/nn.zig
175
zml/nn.zig
@ -1,21 +1,20 @@
|
||||
//! Common layer definition and functions for Neural Networks (NN)
|
||||
const std = @import("std");
|
||||
const stdx = @import("stdx");
|
||||
|
||||
const cuda = @import("nn/cuda.zig");
|
||||
const helpers = @import("helpers.zig");
|
||||
const meta = @import("meta.zig");
|
||||
const ops = @import("ops.zig");
|
||||
const zml = @import("zml.zig");
|
||||
|
||||
const DataType = @import("dtype.zig").DataType;
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
const Tensor = @import("tensor.zig").Tensor;
|
||||
|
||||
const assert = std.debug.assert;
|
||||
const log = std.log.scoped(.@"zml/tensor");
|
||||
const testing = std.testing;
|
||||
|
||||
const stdx = @import("stdx");
|
||||
|
||||
const DataType = @import("dtype.zig").DataType;
|
||||
const helpers = @import("helpers.zig");
|
||||
const meta = @import("meta.zig");
|
||||
const cuda = @import("nn/cuda.zig");
|
||||
const ops = @import("ops.zig");
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
const Tensor = @import("tensor.zig").Tensor;
|
||||
const zml = @import("zml.zig");
|
||||
|
||||
const log = std.log.scoped(.@"zml/tensor");
|
||||
test {
|
||||
_ = cuda;
|
||||
std.testing.refAllDecls(@This());
|
||||
@ -108,6 +107,15 @@ pub const LayerNorm = struct {
|
||||
}
|
||||
};
|
||||
|
||||
/// Root-mean-square normalization of `x` along `axis`.
///
/// Computes `x * rsqrt(mean(x * x) + eps)`, where the mean is taken over `axis`.
/// The statistics are computed in f32, then converted back to the dtype of `x`
/// before the final multiplication, so the result has the dtype of `x`.
pub fn rmsNorm(x: Tensor, axis: anytype, eps: f32) Tensor {
    const reduce_axis = x.axis(axis);

    // Upcast so that squaring and averaging do not lose precision on low-precision inputs.
    const wide = x.convert(.f32);
    const mean_of_squares = wide.mul(wide).mean(reduce_axis);

    // Inverse RMS, brought back to the input dtype and broadcast to the input shape.
    const inv_rms = Tensor.rsqrt(mean_of_squares.addConstant(eps)).convert(x.dtype());
    const scale = inv_rms.broad(x.shape());

    return x.mul(scale);
}
|
||||
|
||||
/// Center and scale by the variance.
|
||||
/// normalize(x, eps) = (x - mean(x)) / sqrt(var(x) + eps)
|
||||
/// Work on the last axis.
|
||||
@ -142,13 +150,46 @@ test normalizeL2 {
|
||||
}
|
||||
|
||||
pub const RopeOpts = struct {
|
||||
/// There are two implementations corresponding to how to split `x` in real/imag parts.
|
||||
layout: Layout = .sequential,
|
||||
freq_base: f32 = 10_000,
|
||||
scaling: Scaling = .default,
|
||||
|
||||
/// There are two layouts corresponding to how to split `x` in real/imag parts.
|
||||
/// * Interleaved means that the real/imag of each scalar is contiguous.
|
||||
/// * Sequential means that you first read all real values then all imag values.
|
||||
pub const Implementation = enum { interleaved, sequential };
|
||||
/// Typically HF models use sequential, while GGUF use interleaved.
|
||||
pub const Layout = enum { interleaved, sequential };
|
||||
|
||||
impl: Implementation,
|
||||
freq_base: f32 = 10_000,
|
||||
/// There are several ways to init the scaling aka "inv_freq"
|
||||
pub const Scaling = union(enum) {
|
||||
default: void,
|
||||
custom: []const f32,
|
||||
llama3: Llama3,
|
||||
|
||||
pub const Llama3 = struct {
|
||||
factor: f32,
|
||||
high_freq_factor: f32,
|
||||
low_freq_factor: f32,
|
||||
original_max_position_embeddings: u32,
|
||||
};
|
||||
|
||||
/// Read a Rope scaling config from HF config.json format.
|
||||
pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) !Scaling {
|
||||
const content = try std.json.Value.jsonParse(allocator, source, options);
|
||||
if (content != .object) return error.InvalidEnumTag;
|
||||
|
||||
const obj = content.object;
|
||||
const impl = obj.get("rope_type") orelse return error.MissingField;
|
||||
if (impl != .string) return error.InvalidEnumTag;
|
||||
if (std.mem.eql(u8, impl.string, "llama3")) {
|
||||
// Note: leaky is fine here because parsing the Llama3 struct doesn't need to allocate memory.
|
||||
return .{ .llama3 = try std.json.parseFromValueLeaky(Llama3, undefined, content, .{ .ignore_unknown_fields = true }) };
|
||||
} else {
|
||||
log.warn("Unsupported Rope implementation: {s}, will use the default one which will produce altered results", .{impl.string});
|
||||
return .{ .default = {} };
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
/// Rotary position embeddings modify the query and key tensors before computing Q * K in self-attention.
|
||||
@ -170,10 +211,10 @@ pub fn rope(x: Tensor, pos_idx: ?Tensor, opts: RopeOpts) Tensor {
|
||||
stdx.debug.assert(x.shape().hasTags(.{ .s, .hd }), "rope expects x argument to have both .s and .hd axes got: rope(x={})", .{x});
|
||||
break :blk Tensor.arange(.{ .end = x.dim(.s) }, .f32).withTags(.{.s});
|
||||
};
|
||||
const x_real, const x_imag = splitRealImg(x, opts.impl);
|
||||
const x_real, const x_imag = splitRealImg(x, opts.layout);
|
||||
|
||||
// compute sin and cos in f32 before downcasting to x type.
|
||||
const inv_freq = invFreq(x.dim(.hd), opts.freq_base, .f32).withTags(.{.hd});
|
||||
const inv_freq = invFreq(x.dim(.hd), opts).withTags(.{.hd});
|
||||
const inv_freq_pos = Tensor.outer(idx.convert(.f32), inv_freq);
|
||||
const cos = inv_freq_pos.cos().convert(x.dtype()).broad(x_real.shape());
|
||||
const sin = inv_freq_pos.sin().convert(x.dtype()).broad(x_real.shape());
|
||||
@ -183,13 +224,13 @@ pub fn rope(x: Tensor, pos_idx: ?Tensor, opts: RopeOpts) Tensor {
|
||||
const y_imag = x_real.mul(sin).add(x_imag.mul(cos));
|
||||
|
||||
// flatten last dimensions
|
||||
return mergeRealImg(y_real, y_imag, opts.impl);
|
||||
return mergeRealImg(y_real, y_imag, opts.layout);
|
||||
}
|
||||
|
||||
pub fn splitRealImg(x: Tensor, impl: RopeOpts.Implementation) [2]Tensor {
|
||||
pub fn splitRealImg(x: Tensor, layout: RopeOpts.Layout) [2]Tensor {
|
||||
const n = x.dim(-1);
|
||||
|
||||
return switch (impl) {
|
||||
return switch (layout) {
|
||||
.sequential => .{
|
||||
x.slice1d(-1, .{ .end = @divExact(n, 2) }),
|
||||
x.slice1d(-1, .{ .start = @divExact(n, 2), .end = n }),
|
||||
@ -201,8 +242,8 @@ pub fn splitRealImg(x: Tensor, impl: RopeOpts.Implementation) [2]Tensor {
|
||||
};
|
||||
}
|
||||
|
||||
pub fn mergeRealImg(x_real: Tensor, x_imag: Tensor, impl: RopeOpts.Implementation) Tensor {
|
||||
return switch (impl) {
|
||||
pub fn mergeRealImg(x_real: Tensor, x_imag: Tensor, layout: RopeOpts.Layout) Tensor {
|
||||
return switch (layout) {
|
||||
.sequential => Tensor.concatenate(&.{ x_real, x_imag }, -1),
|
||||
.interleaved => Tensor.concatenate(&.{
|
||||
x_real.appendAxes(.{.interleaved_real_img}),
|
||||
@ -212,21 +253,83 @@ pub fn mergeRealImg(x_real: Tensor, x_imag: Tensor, impl: RopeOpts.Implementatio
|
||||
}
|
||||
|
||||
/// {exp( - n * ln(10_000) / N ) | n in [0..N] }
|
||||
pub fn invFreq(N: i64, theta: f32, dtype: DataType) Tensor {
|
||||
const freq = -@log(theta) / @as(f32, @floatFromInt(N));
|
||||
const range = Tensor.arange(.{ .start = 0, .end = N, .step = 2 }, dtype).scale(freq);
|
||||
return range.exp();
|
||||
pub fn invFreq(N: i64, opts: RopeOpts) Tensor {
|
||||
const allocator = zml.module.CompilationContext.current().allocator();
|
||||
const N_half: usize = @intCast(@divExact(N, 2));
|
||||
const inv_freq = allocator.alloc(f32, N_half) catch @panic("OOM");
|
||||
_invFreq(opts, inv_freq);
|
||||
return zml.Tensor.constantTensor(.fromSlice(.{@divExact(N, 2)}, inv_freq));
|
||||
}
|
||||
|
||||
/// Fills `inv_freq` in place with the rope inverse frequencies described by `opts`.
///
/// Every slot `n` first receives the default value
/// `exp(-ln(opts.freq_base) * stdx.math.divFloat(f32, n, inv_freq.len))`,
/// then `opts.scaling` is applied on top:
/// * `.default`: the default values are kept as is.
/// * `.custom`: the values are overwritten with the provided slice,
///   which is asserted to have the same length as `inv_freq`.
/// * `.llama3`: frequencies above the high threshold are kept, frequencies below
///   the low threshold are divided by `factor`, and the ones in between are
///   scaled by a linear interpolation between those two regimes.
fn _invFreq(opts: RopeOpts, inv_freq: []f32) void {
    const count = inv_freq.len;

    // Default frequencies
    for (inv_freq, 0..) |*slot, i| {
        slot.* = @exp(-@log(opts.freq_base) * stdx.math.divFloat(f32, i, count));
    }

    switch (opts.scaling) {
        .default => {},
        .custom => |custom| {
            stdx.debug.assert(custom.len == count, "rope expected custom inv_freq to match half head dimension {}, got {}", .{ count, custom.len });
            @memcpy(inv_freq, custom);
        },
        .llama3 => |cfg| {
            // https://arxiv.org/pdf/2309.16039
            // After Llama2 they observed that the rope frequencies were too sharp and hurting long distance attention.
            // In Llama3 they used a higher base freq and also downscaled low frequencies.
            std.debug.assert(cfg.low_freq_factor < cfg.high_freq_factor);

            // Frequency thresholds are expressed relative to the original context length.
            const original_len: f64 = @floatFromInt(cfg.original_max_position_embeddings);
            const f_high = cfg.high_freq_factor * (2 * std.math.pi) / original_len;
            const f_low = cfg.low_freq_factor * (2 * std.math.pi) / original_len;
            const downscaling = 1.0 / cfg.factor;

            for (inv_freq) |*slot| {
                const freq = slot.*;

                // High frequencies keep the default value.
                if (freq > f_high) continue;

                if (freq < f_low) {
                    // Low frequencies are uniformly downscaled.
                    slot.* *= downscaling;
                    continue;
                }

                // Middle band: linear interpolation between "downscaled" (lerp = 0) and "untouched" (lerp = 1).
                const lerp: f64 = (freq - f_low) / (f_high - f_low);
                const blend: f32 = @floatCast(lerp + (1 - lerp) * downscaling);
                slot.* *= blend;
            }
        },
    }
}
|
||||
|
||||
test invFreq {
    // Reference inverse frequencies for the Llama 3.2-1B config below (32 values).
    const expected_freq = [_]f32{
        1.000000000000e+00, 6.636012792587e-01, 4.403666257858e-01, 2.922278344631e-01,
        1.939227581024e-01, 1.286873817444e-01, 8.539710193872e-02, 5.666961893439e-02,
        3.760603070259e-02, 2.495540864766e-02, 1.656044088304e-02, 1.098952908069e-02,
        7.292665075511e-03, 4.839421249926e-03, 3.211446106434e-03, 1.290548010729e-03,
        4.295567050576e-04, 9.708286233945e-05, 1.946163865796e-05, 1.291476746701e-05,
        8.570255886298e-06, 5.687232260243e-06, 3.774054448513e-06, 2.504467147446e-06,
        1.661967417022e-06, 1.102883629756e-06, 7.318749339902e-07, 4.856731266045e-07,
        3.222932889457e-07, 2.138742303259e-07, 1.419272024350e-07, 9.418306490261e-08,
    };

    // Llama 3.2-1B config
    const opts: RopeOpts = .{
        .freq_base = 500_000,
        .scaling = .{ .llama3 = .{
            .factor = 32,
            .high_freq_factor = 4,
            .low_freq_factor = 1,
            .original_max_position_embeddings = 8192,
        } },
    };

    // Output buffer, fully written by `_invFreq` before being read.
    var actual_freq: [expected_freq.len]f32 = undefined;
    _invFreq(opts, &actual_freq);

    for (expected_freq, actual_freq, 0..) |want, got, idx| {
        // On the first mismatch, dump both full arrays to ease debugging.
        errdefer log.err("Mismatch at position {d}.\nExpected: {d}\nActual: {d}", .{ idx, expected_freq, actual_freq });
        try std.testing.expectApproxEqRel(want, got, 1e-5);
    }
}
|
||||
|
||||
test "real/img" {
|
||||
const platform = zml.testing.env();
|
||||
|
||||
const Fns = struct {
|
||||
fn testSplitMergeIsId(impl: RopeOpts.Implementation) Tensor {
|
||||
fn testSplitMergeIsId(layout: RopeOpts.Layout) Tensor {
|
||||
const x = Tensor.arange(.{ .end = 20 }, .f32).reshape(.{ 5, 4 });
|
||||
const real, const imag = splitRealImg(x, impl);
|
||||
const y = mergeRealImg(real, imag, impl);
|
||||
const real2, const imag2 = splitRealImg(y, impl);
|
||||
const real, const imag = splitRealImg(x, layout);
|
||||
const y = mergeRealImg(real, imag, layout);
|
||||
const real2, const imag2 = splitRealImg(y, layout);
|
||||
return real.cmp(.EQ, real2).flatten(0).convert(.i32).sum(-1).add(
|
||||
imag.cmp(.EQ, imag2).flatten(0).convert(.i32).sum(-1),
|
||||
);
|
||||
@ -312,13 +415,13 @@ test rope {
|
||||
{
|
||||
// Convert input to the requested format
|
||||
const real, const imag = splitRealImg(input, .sequential);
|
||||
input = mergeRealImg(real, imag, opts.impl);
|
||||
input = mergeRealImg(real, imag, opts.layout);
|
||||
}
|
||||
var res = rope(input, null, opts).squeeze(0);
|
||||
|
||||
{
|
||||
// Convert back to sequential
|
||||
const real, const imag = splitRealImg(res, opts.impl);
|
||||
const real, const imag = splitRealImg(res, opts.layout);
|
||||
res = mergeRealImg(real, imag, .sequential);
|
||||
}
|
||||
return res;
|
||||
@ -328,8 +431,8 @@ test rope {
|
||||
// x is built such that the interleaved and sequential representations are the same.
|
||||
// So the two implementations should give the same results.
|
||||
const x = try zml.Buffer.fromSlice(platform, .{ .b = 1, .s = 5, .hd = 4 }, &[_]f32{ 1.0, 0.1, -1.0, -0.5 } ** 5);
|
||||
const res1 = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, RopeOpts{ .impl = .interleaved } });
|
||||
const res2 = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, RopeOpts{ .impl = .sequential } });
|
||||
const res1 = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, RopeOpts{ .layout = .interleaved } });
|
||||
const res2 = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, RopeOpts{ .layout = .sequential } });
|
||||
try zml.testing.expectClose(res1, res2, 1e-4);
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user