Fix Llama3 rope scaling implementation in the neural network module (zml/nn.zig)
This commit is contained in:
parent
9f61a8aacb
commit
aacbf2ee04
175
zml/nn.zig
175
zml/nn.zig
@ -1,21 +1,20 @@
|
|||||||
//! Common layer definition and functions for Neural Networks (NN)
|
//! Common layer definition and functions for Neural Networks (NN)
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const stdx = @import("stdx");
|
|
||||||
|
|
||||||
const cuda = @import("nn/cuda.zig");
|
|
||||||
const helpers = @import("helpers.zig");
|
|
||||||
const meta = @import("meta.zig");
|
|
||||||
const ops = @import("ops.zig");
|
|
||||||
const zml = @import("zml.zig");
|
|
||||||
|
|
||||||
const DataType = @import("dtype.zig").DataType;
|
|
||||||
const Shape = @import("shape.zig").Shape;
|
|
||||||
const Tensor = @import("tensor.zig").Tensor;
|
|
||||||
|
|
||||||
const assert = std.debug.assert;
|
const assert = std.debug.assert;
|
||||||
const log = std.log.scoped(.@"zml/tensor");
|
|
||||||
const testing = std.testing;
|
const testing = std.testing;
|
||||||
|
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
|
const DataType = @import("dtype.zig").DataType;
|
||||||
|
const helpers = @import("helpers.zig");
|
||||||
|
const meta = @import("meta.zig");
|
||||||
|
const cuda = @import("nn/cuda.zig");
|
||||||
|
const ops = @import("ops.zig");
|
||||||
|
const Shape = @import("shape.zig").Shape;
|
||||||
|
const Tensor = @import("tensor.zig").Tensor;
|
||||||
|
const zml = @import("zml.zig");
|
||||||
|
|
||||||
|
const log = std.log.scoped(.@"zml/tensor");
|
||||||
test {
|
test {
|
||||||
_ = cuda;
|
_ = cuda;
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
@ -108,6 +107,15 @@ pub const LayerNorm = struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub fn rmsNorm(x: Tensor, axis: anytype, eps: f32) Tensor {
|
||||||
|
const ax = x.axis(axis);
|
||||||
|
// upcast to improve precision
|
||||||
|
const xf32 = x.convert(.f32);
|
||||||
|
const mean = xf32.mul(xf32).mean(ax);
|
||||||
|
const rsqrt = Tensor.rsqrt(mean.addConstant(eps)).convert(x.dtype());
|
||||||
|
return x.mul(rsqrt.broad(x.shape()));
|
||||||
|
}
|
||||||
|
|
||||||
/// Center and scale by the variance.
|
/// Center and scale by the variance.
|
||||||
/// normalize(x, eps) = (x - mean(x)) / sqrt(var(x) + eps)
|
/// normalize(x, eps) = (x - mean(x)) / sqrt(var(x) + eps)
|
||||||
/// Work on the last axis.
|
/// Work on the last axis.
|
||||||
@ -142,13 +150,46 @@ test normalizeL2 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub const RopeOpts = struct {
|
pub const RopeOpts = struct {
|
||||||
/// There are two implementations corresponding to how to split `x` in real/imag parts.
|
layout: Layout = .sequential,
|
||||||
|
freq_base: f32 = 10_000,
|
||||||
|
scaling: Scaling = .default,
|
||||||
|
|
||||||
|
/// There are two layouts corresponding to how to split `x` in real/imag parts.
|
||||||
/// * Interleaved means that the real/imag of each scalar is contiguous.
|
/// * Interleaved means that the real/imag of each scalar is contiguous.
|
||||||
/// * Sequential means that you first read all real values then all imag values.
|
/// * Sequential means that you first read all real values then all imag values.
|
||||||
pub const Implementation = enum { interleaved, sequential };
|
/// Typically HF models use sequential, while GGUF use interleaved.
|
||||||
|
pub const Layout = enum { interleaved, sequential };
|
||||||
|
|
||||||
impl: Implementation,
|
/// There are several ways to init the scaling aka "inv_freq"
|
||||||
freq_base: f32 = 10_000,
|
pub const Scaling = union(enum) {
|
||||||
|
default: void,
|
||||||
|
custom: []const f32,
|
||||||
|
llama3: Llama3,
|
||||||
|
|
||||||
|
pub const Llama3 = struct {
|
||||||
|
factor: f32,
|
||||||
|
high_freq_factor: f32,
|
||||||
|
low_freq_factor: f32,
|
||||||
|
original_max_position_embeddings: u32,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Read a Rope scaling config from HF config.json format.
|
||||||
|
pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) !Scaling {
|
||||||
|
const content = try std.json.Value.jsonParse(allocator, source, options);
|
||||||
|
if (content != .object) return error.InvalidEnumTag;
|
||||||
|
|
||||||
|
const obj = content.object;
|
||||||
|
const impl = obj.get("rope_type") orelse return error.MissingField;
|
||||||
|
if (impl != .string) return error.InvalidEnumTag;
|
||||||
|
if (std.mem.eql(u8, impl.string, "llama3")) {
|
||||||
|
// Note: leaky is fine here because the Llama3 struct doesn't need to allocate memory.
|
||||||
|
return .{ .llama3 = try std.json.parseFromValueLeaky(Llama3, undefined, content, .{ .ignore_unknown_fields = true }) };
|
||||||
|
} else {
|
||||||
|
log.warn("Unsupported Rope implementation: {s}, will use the default one which will produce altered results", .{impl.string});
|
||||||
|
return .{ .default = {} };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Rotary position embedding modifies the query and key tensors before computing Q * K in self-attention.
|
/// Rotary position embedding modifies the query and key tensors before computing Q * K in self-attention.
|
||||||
@ -170,10 +211,10 @@ pub fn rope(x: Tensor, pos_idx: ?Tensor, opts: RopeOpts) Tensor {
|
|||||||
stdx.debug.assert(x.shape().hasTags(.{ .s, .hd }), "rope expects x argument to have both .s and .hd axes got: rope(x={})", .{x});
|
stdx.debug.assert(x.shape().hasTags(.{ .s, .hd }), "rope expects x argument to have both .s and .hd axes got: rope(x={})", .{x});
|
||||||
break :blk Tensor.arange(.{ .end = x.dim(.s) }, .f32).withTags(.{.s});
|
break :blk Tensor.arange(.{ .end = x.dim(.s) }, .f32).withTags(.{.s});
|
||||||
};
|
};
|
||||||
const x_real, const x_imag = splitRealImg(x, opts.impl);
|
const x_real, const x_imag = splitRealImg(x, opts.layout);
|
||||||
|
|
||||||
// compute sin and cos in f32 before downcasting to x type.
|
// compute sin and cos in f32 before downcasting to x type.
|
||||||
const inv_freq = invFreq(x.dim(.hd), opts.freq_base, .f32).withTags(.{.hd});
|
const inv_freq = invFreq(x.dim(.hd), opts).withTags(.{.hd});
|
||||||
const inv_freq_pos = Tensor.outer(idx.convert(.f32), inv_freq);
|
const inv_freq_pos = Tensor.outer(idx.convert(.f32), inv_freq);
|
||||||
const cos = inv_freq_pos.cos().convert(x.dtype()).broad(x_real.shape());
|
const cos = inv_freq_pos.cos().convert(x.dtype()).broad(x_real.shape());
|
||||||
const sin = inv_freq_pos.sin().convert(x.dtype()).broad(x_real.shape());
|
const sin = inv_freq_pos.sin().convert(x.dtype()).broad(x_real.shape());
|
||||||
@ -183,13 +224,13 @@ pub fn rope(x: Tensor, pos_idx: ?Tensor, opts: RopeOpts) Tensor {
|
|||||||
const y_imag = x_real.mul(sin).add(x_imag.mul(cos));
|
const y_imag = x_real.mul(sin).add(x_imag.mul(cos));
|
||||||
|
|
||||||
// flatten last dimensions
|
// flatten last dimensions
|
||||||
return mergeRealImg(y_real, y_imag, opts.impl);
|
return mergeRealImg(y_real, y_imag, opts.layout);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn splitRealImg(x: Tensor, impl: RopeOpts.Implementation) [2]Tensor {
|
pub fn splitRealImg(x: Tensor, layout: RopeOpts.Layout) [2]Tensor {
|
||||||
const n = x.dim(-1);
|
const n = x.dim(-1);
|
||||||
|
|
||||||
return switch (impl) {
|
return switch (layout) {
|
||||||
.sequential => .{
|
.sequential => .{
|
||||||
x.slice1d(-1, .{ .end = @divExact(n, 2) }),
|
x.slice1d(-1, .{ .end = @divExact(n, 2) }),
|
||||||
x.slice1d(-1, .{ .start = @divExact(n, 2), .end = n }),
|
x.slice1d(-1, .{ .start = @divExact(n, 2), .end = n }),
|
||||||
@ -201,8 +242,8 @@ pub fn splitRealImg(x: Tensor, impl: RopeOpts.Implementation) [2]Tensor {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn mergeRealImg(x_real: Tensor, x_imag: Tensor, impl: RopeOpts.Implementation) Tensor {
|
pub fn mergeRealImg(x_real: Tensor, x_imag: Tensor, layout: RopeOpts.Layout) Tensor {
|
||||||
return switch (impl) {
|
return switch (layout) {
|
||||||
.sequential => Tensor.concatenate(&.{ x_real, x_imag }, -1),
|
.sequential => Tensor.concatenate(&.{ x_real, x_imag }, -1),
|
||||||
.interleaved => Tensor.concatenate(&.{
|
.interleaved => Tensor.concatenate(&.{
|
||||||
x_real.appendAxes(.{.interleaved_real_img}),
|
x_real.appendAxes(.{.interleaved_real_img}),
|
||||||
@ -212,21 +253,83 @@ pub fn mergeRealImg(x_real: Tensor, x_imag: Tensor, impl: RopeOpts.Implementatio
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// {exp( - n * ln(10_000) / N ) | n in [0..N] }
|
/// {exp( - n * ln(10_000) / N ) | n in [0..N] }
|
||||||
pub fn invFreq(N: i64, theta: f32, dtype: DataType) Tensor {
|
pub fn invFreq(N: i64, opts: RopeOpts) Tensor {
|
||||||
const freq = -@log(theta) / @as(f32, @floatFromInt(N));
|
const allocator = zml.module.CompilationContext.current().allocator();
|
||||||
const range = Tensor.arange(.{ .start = 0, .end = N, .step = 2 }, dtype).scale(freq);
|
const N_half: usize = @intCast(@divExact(N, 2));
|
||||||
return range.exp();
|
const inv_freq = allocator.alloc(f32, N_half) catch @panic("OOM");
|
||||||
|
_invFreq(opts, inv_freq);
|
||||||
|
return zml.Tensor.constantTensor(.fromSlice(.{@divExact(N, 2)}, inv_freq));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn _invFreq(opts: RopeOpts, inv_freq: []f32) void {
|
||||||
|
const N = inv_freq.len;
|
||||||
|
// Default frequencies
|
||||||
|
for (0.., inv_freq) |n, *f| {
|
||||||
|
f.* = @exp(-@log(opts.freq_base) * stdx.math.divFloat(f32, n, N));
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (opts.scaling) {
|
||||||
|
.default => {},
|
||||||
|
.custom => {
|
||||||
|
stdx.debug.assert(opts.scaling.custom.len == N, "rope expected custom inv_freq to match half head dimension {}, got {}", .{ N, opts.scaling.custom.len });
|
||||||
|
@memcpy(inv_freq, opts.scaling.custom);
|
||||||
|
},
|
||||||
|
.llama3 => |s| {
|
||||||
|
// https://arxiv.org/pdf/2309.16039
|
||||||
|
// After Llama2 they observed that the rope frequencies were too sharp and hurting long distance attention.
|
||||||
|
// In Llama3 they used a higher base freq and also downscaled low frequencies.
|
||||||
|
std.debug.assert(s.low_freq_factor < s.high_freq_factor);
|
||||||
|
const M: f64 = @floatFromInt(s.original_max_position_embeddings);
|
||||||
|
const f_high = s.high_freq_factor * (2 * std.math.pi) / M;
|
||||||
|
const f_low = s.low_freq_factor * (2 * std.math.pi) / M;
|
||||||
|
const downscaling = 1.0 / s.factor;
|
||||||
|
|
||||||
|
for (0..N, inv_freq) |n, f| {
|
||||||
|
if (f > f_high) {
|
||||||
|
// High freq match default implem
|
||||||
|
} else if (f < f_low) {
|
||||||
|
// Downscaling for low freq
|
||||||
|
inv_freq[n] *= downscaling;
|
||||||
|
} else {
|
||||||
|
// Linear interpolation for middle freq
|
||||||
|
const lerp: f64 = (inv_freq[n] - f_low) / (f_high - f_low);
|
||||||
|
inv_freq[n] *= @floatCast(lerp + (1 - lerp) * downscaling);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test invFreq {
|
||||||
|
// Llama 3.2-1B config
|
||||||
|
const llama_conf: RopeOpts = .{
|
||||||
|
.freq_base = 500_000,
|
||||||
|
.scaling = .{ .llama3 = .{
|
||||||
|
.factor = 32,
|
||||||
|
.high_freq_factor = 4,
|
||||||
|
.low_freq_factor = 1,
|
||||||
|
.original_max_position_embeddings = 8192,
|
||||||
|
} },
|
||||||
|
};
|
||||||
|
const llama_freq = [_]f32{ 1.000000000000e+00, 6.636012792587e-01, 4.403666257858e-01, 2.922278344631e-01, 1.939227581024e-01, 1.286873817444e-01, 8.539710193872e-02, 5.666961893439e-02, 3.760603070259e-02, 2.495540864766e-02, 1.656044088304e-02, 1.098952908069e-02, 7.292665075511e-03, 4.839421249926e-03, 3.211446106434e-03, 1.290548010729e-03, 4.295567050576e-04, 9.708286233945e-05, 1.946163865796e-05, 1.291476746701e-05, 8.570255886298e-06, 5.687232260243e-06, 3.774054448513e-06, 2.504467147446e-06, 1.661967417022e-06, 1.102883629756e-06, 7.318749339902e-07, 4.856731266045e-07, 3.222932889457e-07, 2.138742303259e-07, 1.419272024350e-07, 9.418306490261e-08 };
|
||||||
|
|
||||||
|
var inv_freq: @TypeOf(llama_freq) = undefined;
|
||||||
|
_invFreq(llama_conf, &inv_freq);
|
||||||
|
for (llama_freq, inv_freq, 0..) |expected, actual, i| {
|
||||||
|
errdefer log.err("Mismatch at position {d}.\nExpected: {d}\nActual: {d}", .{ i, llama_freq, inv_freq });
|
||||||
|
try std.testing.expectApproxEqRel(expected, actual, 1e-5);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
test "real/img" {
|
test "real/img" {
|
||||||
const platform = zml.testing.env();
|
const platform = zml.testing.env();
|
||||||
|
|
||||||
const Fns = struct {
|
const Fns = struct {
|
||||||
fn testSplitMergeIsId(impl: RopeOpts.Implementation) Tensor {
|
fn testSplitMergeIsId(layout: RopeOpts.Layout) Tensor {
|
||||||
const x = Tensor.arange(.{ .end = 20 }, .f32).reshape(.{ 5, 4 });
|
const x = Tensor.arange(.{ .end = 20 }, .f32).reshape(.{ 5, 4 });
|
||||||
const real, const imag = splitRealImg(x, impl);
|
const real, const imag = splitRealImg(x, layout);
|
||||||
const y = mergeRealImg(real, imag, impl);
|
const y = mergeRealImg(real, imag, layout);
|
||||||
const real2, const imag2 = splitRealImg(y, impl);
|
const real2, const imag2 = splitRealImg(y, layout);
|
||||||
return real.cmp(.EQ, real2).flatten(0).convert(.i32).sum(-1).add(
|
return real.cmp(.EQ, real2).flatten(0).convert(.i32).sum(-1).add(
|
||||||
imag.cmp(.EQ, imag2).flatten(0).convert(.i32).sum(-1),
|
imag.cmp(.EQ, imag2).flatten(0).convert(.i32).sum(-1),
|
||||||
);
|
);
|
||||||
@ -312,13 +415,13 @@ test rope {
|
|||||||
{
|
{
|
||||||
// Convert input to the requested format
|
// Convert input to the requested format
|
||||||
const real, const imag = splitRealImg(input, .sequential);
|
const real, const imag = splitRealImg(input, .sequential);
|
||||||
input = mergeRealImg(real, imag, opts.impl);
|
input = mergeRealImg(real, imag, opts.layout);
|
||||||
}
|
}
|
||||||
var res = rope(input, null, opts).squeeze(0);
|
var res = rope(input, null, opts).squeeze(0);
|
||||||
|
|
||||||
{
|
{
|
||||||
// Convert back to sequential
|
// Convert back to sequential
|
||||||
const real, const imag = splitRealImg(res, opts.impl);
|
const real, const imag = splitRealImg(res, opts.layout);
|
||||||
res = mergeRealImg(real, imag, .sequential);
|
res = mergeRealImg(real, imag, .sequential);
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
@ -328,8 +431,8 @@ test rope {
|
|||||||
// x is made such that the interleaved and sequential reps are the same.
|
// x is made such that the interleaved and sequential reps are the same.
|
||||||
// So the two implementations should give the same results.
|
// So the two implementations should give the same results.
|
||||||
const x = try zml.Buffer.fromSlice(platform, .{ .b = 1, .s = 5, .hd = 4 }, &[_]f32{ 1.0, 0.1, -1.0, -0.5 } ** 5);
|
const x = try zml.Buffer.fromSlice(platform, .{ .b = 1, .s = 5, .hd = 4 }, &[_]f32{ 1.0, 0.1, -1.0, -0.5 } ** 5);
|
||||||
const res1 = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, RopeOpts{ .impl = .interleaved } });
|
const res1 = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, RopeOpts{ .layout = .interleaved } });
|
||||||
const res2 = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, RopeOpts{ .impl = .sequential } });
|
const res2 = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, RopeOpts{ .layout = .sequential } });
|
||||||
try zml.testing.expectClose(res1, res2, 1e-4);
|
try zml.testing.expectClose(res1, res2, 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user