diff --git a/zml/nn.zig b/zml/nn.zig
index c3120d2..104d793 100644
--- a/zml/nn.zig
+++ b/zml/nn.zig
@@ -1,21 +1,20 @@
 //! Common layer definition and functions for Neural Networks (NN)
 const std = @import("std");
-const stdx = @import("stdx");
-
-const cuda = @import("nn/cuda.zig");
-const helpers = @import("helpers.zig");
-const meta = @import("meta.zig");
-const ops = @import("ops.zig");
-const zml = @import("zml.zig");
-
-const DataType = @import("dtype.zig").DataType;
-const Shape = @import("shape.zig").Shape;
-const Tensor = @import("tensor.zig").Tensor;
-
 const assert = std.debug.assert;
-const log = std.log.scoped(.@"zml/tensor");
 const testing = std.testing;
 
+const stdx = @import("stdx");
+
+const DataType = @import("dtype.zig").DataType;
+const helpers = @import("helpers.zig");
+const meta = @import("meta.zig");
+const cuda = @import("nn/cuda.zig");
+const ops = @import("ops.zig");
+const Shape = @import("shape.zig").Shape;
+const Tensor = @import("tensor.zig").Tensor;
+const zml = @import("zml.zig");
+
+const log = std.log.scoped(.@"zml/tensor");
 test {
     _ = cuda;
     std.testing.refAllDecls(@This());
@@ -108,6 +107,15 @@ pub const LayerNorm = struct {
     }
 };
 
+pub fn rmsNorm(x: Tensor, axis: anytype, eps: f32) Tensor {
+    const ax = x.axis(axis);
+    // upcast to improve precision
+    const xf32 = x.convert(.f32);
+    const mean = xf32.mul(xf32).mean(ax);
+    const rsqrt = Tensor.rsqrt(mean.addConstant(eps)).convert(x.dtype());
+    return x.mul(rsqrt.broad(x.shape()));
+}
+
 /// Center and scale by the variance.
 /// normalize(x, eps) = (x - mean(x)) / sqrt(var(x) + eps)
 /// Work on the last axis.
@@ -142,13 +150,46 @@ test normalizeL2 {
 }
 
 pub const RopeOpts = struct {
-    /// There are two implementations corresponding to how to split `x` in real/imag parts.
+    layout: Layout = .sequential,
+    freq_base: f32 = 10_000,
+    scaling: Scaling = .default,
+
+    /// There are two layouts corresponding to how to split `x` in real/imag parts.
     /// * Interleaved means that the real/imag of each scalar is contiguous.
     /// * Sequential means that you first read all real values then all imag values.
-    pub const Implementation = enum { interleaved, sequential };
+    /// Typically HF models use sequential, while GGUF use interleaved.
+    pub const Layout = enum { interleaved, sequential };
 
-    impl: Implementation,
-    freq_base: f32 = 10_000,
+    /// There are several ways to init the scaling aka "inv_freq"
+    pub const Scaling = union(enum) {
+        default: void,
+        custom: []const f32,
+        llama3: Llama3,
+
+        pub const Llama3 = struct {
+            factor: f32,
+            high_freq_factor: f32,
+            low_freq_factor: f32,
+            original_max_position_embeddings: u32,
+        };
+
+        /// Read a Rope scaling config from HF config.json format.
+        pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) !Scaling {
+            const content = try std.json.Value.jsonParse(allocator, source, options);
+            if (content != .object) return error.InvalidEnumTag;
+
+            const obj = content.object;
+            const impl = obj.get("rope_type") orelse return error.MissingField;
+            if (impl != .string) return error.InvalidEnumTag;
+            if (std.mem.eql(u8, impl.string, "llama3")) {
+                // Note: Llama3 only has scalar fields so nothing is allocated, but pass the real allocator rather than `undefined` so a future field can't trigger UB.
+                return .{ .llama3 = try std.json.parseFromValueLeaky(Llama3, allocator, content, .{ .ignore_unknown_fields = true }) };
+            } else {
+                log.warn("Unsupported Rope implementation: {s}, will use the default one which will produce altered results", .{impl.string});
+                return .{ .default = {} };
+            }
+        }
+    };
 };
 
 /// Rotary position embedding modify queries and keys tensor before compute Q * K in self attention.
@@ -170,10 +211,10 @@ pub fn rope(x: Tensor, pos_idx: ?Tensor, opts: RopeOpts) Tensor {
         stdx.debug.assert(x.shape().hasTags(.{ .s, .hd }), "rope expects x argument to have both .s and .hd axes got: rope(x={})", .{x});
         break :blk Tensor.arange(.{ .end = x.dim(.s) }, .f32).withTags(.{.s});
     };
-    const x_real, const x_imag = splitRealImg(x, opts.impl);
+    const x_real, const x_imag = splitRealImg(x, opts.layout);
 
     // compute sin and cos in f32 before downcasting to x type.
-    const inv_freq = invFreq(x.dim(.hd), opts.freq_base, .f32).withTags(.{.hd});
+    const inv_freq = invFreq(x.dim(.hd), opts).withTags(.{.hd});
     const inv_freq_pos = Tensor.outer(idx.convert(.f32), inv_freq);
     const cos = inv_freq_pos.cos().convert(x.dtype()).broad(x_real.shape());
     const sin = inv_freq_pos.sin().convert(x.dtype()).broad(x_real.shape());
@@ -183,13 +224,13 @@ pub fn rope(x: Tensor, pos_idx: ?Tensor, opts: RopeOpts) Tensor {
     const y_imag = x_real.mul(sin).add(x_imag.mul(cos));
 
     // flatten last dimensions
-    return mergeRealImg(y_real, y_imag, opts.impl);
+    return mergeRealImg(y_real, y_imag, opts.layout);
 }
 
-pub fn splitRealImg(x: Tensor, impl: RopeOpts.Implementation) [2]Tensor {
+pub fn splitRealImg(x: Tensor, layout: RopeOpts.Layout) [2]Tensor {
     const n = x.dim(-1);
 
-    return switch (impl) {
+    return switch (layout) {
         .sequential => .{
             x.slice1d(-1, .{ .end = @divExact(n, 2) }),
             x.slice1d(-1, .{ .start = @divExact(n, 2), .end = n }),
@@ -201,8 +242,8 @@ pub fn splitRealImg(x: Tensor, impl: RopeOpts.Implementation) [2]Tensor {
     };
 }
 
-pub fn mergeRealImg(x_real: Tensor, x_imag: Tensor, impl: RopeOpts.Implementation) Tensor {
-    return switch (impl) {
+pub fn mergeRealImg(x_real: Tensor, x_imag: Tensor, layout: RopeOpts.Layout) Tensor {
+    return switch (layout) {
         .sequential => Tensor.concatenate(&.{ x_real, x_imag }, -1),
         .interleaved => Tensor.concatenate(&.{
             x_real.appendAxes(.{.interleaved_real_img}),
@@ -212,21 +253,83 @@ pub fn mergeRealImg(x_real: Tensor, x_imag: Tensor, impl: RopeOpts.Implementatio
 }
 
-/// {exp( - n * ln(10_000) / N ) | n in [0..N] }
-pub fn invFreq(N: i64, theta: f32, dtype: DataType) Tensor {
-    const freq = -@log(theta) / @as(f32, @floatFromInt(N));
-    const range = Tensor.arange(.{ .start = 0, .end = N, .step = 2 }, dtype).scale(freq);
-    return range.exp();
+/// {exp( - 2n * ln(opts.freq_base) / N ) | n in [0..N/2) }, then adjusted according to `opts.scaling`.
+pub fn invFreq(N: i64, opts: RopeOpts) Tensor {
+    const allocator = zml.module.CompilationContext.current().allocator();
+    const N_half: usize = @intCast(@divExact(N, 2));
+    const inv_freq = allocator.alloc(f32, N_half) catch @panic("OOM");
+    _invFreq(opts, inv_freq);
+    return zml.Tensor.constantTensor(.fromSlice(.{@divExact(N, 2)}, inv_freq));
+}
+
+fn _invFreq(opts: RopeOpts, inv_freq: []f32) void {
+    const N = inv_freq.len;
+    // Default frequencies
+    for (0.., inv_freq) |n, *f| {
+        f.* = @exp(-@log(opts.freq_base) * stdx.math.divFloat(f32, n, N));
+    }
+
+    switch (opts.scaling) {
+        .default => {},
+        .custom => |custom| {
+            stdx.debug.assert(custom.len == N, "rope expected custom inv_freq to match half head dimension {}, got {}", .{ N, custom.len });
+            @memcpy(inv_freq, custom);
+        },
+        .llama3 => |s| {
+            // https://arxiv.org/pdf/2309.16039
+            // After Llama2 they observed that the rope frequencies were too sharp and hurting long distance attention.
+            // In Llama3 they used a higher base freq and also downscaled low frequencies.
+            stdx.debug.assert(s.low_freq_factor < s.high_freq_factor, "rope llama3 scaling expects low_freq_factor < high_freq_factor, got {d} and {d}", .{ s.low_freq_factor, s.high_freq_factor });
+            const M: f64 = @floatFromInt(s.original_max_position_embeddings);
+            const f_high = s.high_freq_factor * (2 * std.math.pi) / M;
+            const f_low = s.low_freq_factor * (2 * std.math.pi) / M;
+            const downscaling = 1.0 / s.factor;
+
+            for (inv_freq) |*f| {
+                if (f.* > f_high) {
+                    // High freq match default implem
+                } else if (f.* < f_low) {
+                    // Downscaling for low freq
+                    f.* *= downscaling;
+                } else {
+                    // Linear interpolation for middle freq
+                    const lerp: f64 = (f.* - f_low) / (f_high - f_low);
+                    f.* *= @floatCast(lerp + (1 - lerp) * downscaling);
+                }
+            }
+        },
+    }
+}
+
+test invFreq {
+    // Llama 3.2-1B config
+    const llama_conf: RopeOpts = .{
+        .freq_base = 500_000,
+        .scaling = .{ .llama3 = .{
+            .factor = 32,
+            .high_freq_factor = 4,
+            .low_freq_factor = 1,
+            .original_max_position_embeddings = 8192,
+        } },
+    };
+    const llama_freq = [_]f32{ 1.000000000000e+00, 6.636012792587e-01, 4.403666257858e-01, 2.922278344631e-01, 1.939227581024e-01, 1.286873817444e-01, 8.539710193872e-02, 5.666961893439e-02, 3.760603070259e-02, 2.495540864766e-02, 1.656044088304e-02, 1.098952908069e-02, 7.292665075511e-03, 4.839421249926e-03, 3.211446106434e-03, 1.290548010729e-03, 4.295567050576e-04, 9.708286233945e-05, 1.946163865796e-05, 1.291476746701e-05, 8.570255886298e-06, 5.687232260243e-06, 3.774054448513e-06, 2.504467147446e-06, 1.661967417022e-06, 1.102883629756e-06, 7.318749339902e-07, 4.856731266045e-07, 3.222932889457e-07, 2.138742303259e-07, 1.419272024350e-07, 9.418306490261e-08 };
+
+    var inv_freq: @TypeOf(llama_freq) = undefined;
+    _invFreq(llama_conf, &inv_freq);
+    for (llama_freq, inv_freq, 0..) |expected, actual, i| {
+        errdefer log.err("Mismatch at position {d}.\nExpected: {d}\nActual: {d}", .{ i, llama_freq, inv_freq });
+        try std.testing.expectApproxEqRel(expected, actual, 1e-5);
+    }
 }
 
 test "real/img" {
     const platform = zml.testing.env();
 
     const Fns = struct {
-        fn testSplitMergeIsId(impl: RopeOpts.Implementation) Tensor {
+        fn testSplitMergeIsId(layout: RopeOpts.Layout) Tensor {
             const x = Tensor.arange(.{ .end = 20 }, .f32).reshape(.{ 5, 4 });
-            const real, const imag = splitRealImg(x, impl);
-            const y = mergeRealImg(real, imag, impl);
-            const real2, const imag2 = splitRealImg(y, impl);
+            const real, const imag = splitRealImg(x, layout);
+            const y = mergeRealImg(real, imag, layout);
+            const real2, const imag2 = splitRealImg(y, layout);
             return real.cmp(.EQ, real2).flatten(0).convert(.i32).sum(-1).add(
                 imag.cmp(.EQ, imag2).flatten(0).convert(.i32).sum(-1),
             );
@@ -312,13 +415,13 @@ test rope {
             {
                 // Convert input to the requested format
                 const real, const imag = splitRealImg(input, .sequential);
-                input = mergeRealImg(real, imag, opts.impl);
+                input = mergeRealImg(real, imag, opts.layout);
             }
             var res = rope(input, null, opts).squeeze(0);
 
             {
                 // Convert back to sequential
-                const real, const imag = splitRealImg(res, opts.impl);
+                const real, const imag = splitRealImg(res, opts.layout);
                 res = mergeRealImg(real, imag, .sequential);
             }
             return res;
@@ -328,8 +431,8 @@ test rope {
     // x is made such as the interleaved and sequential reps are the same.
     // So the two implementations should give the same results.
     const x = try zml.Buffer.fromSlice(platform, .{ .b = 1, .s = 5, .hd = 4 }, &[_]f32{ 1.0, 0.1, -1.0, -0.5 } ** 5);
-    const res1 = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, RopeOpts{ .impl = .interleaved } });
-    const res2 = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, RopeOpts{ .impl = .sequential } });
+    const res1 = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, RopeOpts{ .layout = .interleaved } });
+    const res2 = try zml.testing.compileAndCall(platform, Local._fwd, .{ x, RopeOpts{ .layout = .sequential } });
     try zml.testing.expectClose(res1, res2, 1e-4);
 }
 