From f00538667ea33b3a299892f4eee86b28ec72da2c Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Fri, 16 Jun 2023 14:34:18 +0000 Subject: [PATCH] =?UTF-8?q?zml.nn:=20add=20dynamic=20sampling=20with=20sup?= =?UTF-8?q?port=20for=20top=E2=80=91k,=20top=E2=80=91p,=20and=20min?= =?UTF-8?q?=E2=80=91p=20settings.=20Implements=20token=20index=20computati?= =?UTF-8?q?on=20based=20on=20the=20selected=20sampling=20strategy,=20inclu?= =?UTF-8?q?ding=20options=20for=20top=5Fk,=20max=5Ftop=5Fk,=20top=5Fp,=20a?= =?UTF-8?q?nd=20min=5Fp.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- zml/nn.zig | 168 +++++++++++++++++++++++++++++++++++++++++++++++++ zml/tensor.zig | 7 ++- 2 files changed, 172 insertions(+), 3 deletions(-) diff --git a/zml/nn.zig b/zml/nn.zig index 2793bf7..44a9065 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -910,6 +910,7 @@ fn sdpaChunk(q: Tensor, k: Tensor, v: Tensor, opts: SdpaOpts) PartialAttn { .max_value = partial.max_value, }; } + test "sdpaMemEfficient without mask" { const platform = zml.testing.env(); const allocator = std.testing.allocator; @@ -1010,3 +1011,170 @@ pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng // log.debug("sampleTokens({}) -> {} -> {} -> {}", .{ activations, topk.indices, topk_idx, next_tokens }); return .{ next_tokens, next_rng }; } + +test sampleTokens { + const platform = zml.testing.env(); + const allocator = std.testing.allocator; + + const inf = std.math.inf(f32); + var rng_buff = try zml.Tensor.Rng.init(platform, 0xdeadbeef); + defer rng_buff._state.deinit(); + + const mod = try zml.compileFn(allocator, sampleTokens, .{ Shape.init(.{ .voc = 4 }, .f32), .{ .topk = 4, .temperature = 2.0 }, zml.Tensor.Rng.shape() }, platform); + defer mod.deinit(); + + inline for (.{ + .{ [_]f32{ inf, 3.0, 2.0, 1.0 }, 0 }, + .{ [_]f32{ -inf, 3.0, -inf, -inf }, 1 }, + .{ [_]f32{ 3.0, 2, inf, inf }, 2 }, + }) |logits_expected| { + const logits, const expected: i32 = 
logits_expected;
+        var logits_buff = try zml.Buffer.fromArray(platform, logits);
+        defer logits_buff.deinit();
+        var sampled, rng_buff = mod.call(.{ logits_buff, undefined, rng_buff });
+        defer sampled.deinit();
+        try zml.testing.expectEqual(expected, try sampled.getValue(i32));
+    }
+}
+
+pub const DynamicSamplingStrategy = struct {
+    max_top_k: u32,
+    top_k: Tensor,
+    temperature: Tensor,
+    top_p: Tensor,
+    min_p: Tensor,
+
+    pub fn shapes(dtype: DataType, max_top_k: u32) zml.ShapeOf(DynamicSamplingStrategy) {
+        const scalar_float = Shape.init(.{}, dtype);
+        const scalar_i32 = Shape.init(.{}, .i32);
+        return .{
+            .max_top_k = max_top_k,
+            .top_k = scalar_i32,
+            .temperature = scalar_float,
+            .top_p = scalar_float,
+            .min_p = scalar_float,
+        };
+    }
+
+    pub fn makeBuffers(
+        platform: zml.Platform,
+        dtype: zml.DataType,
+        args: struct {
+            top_k: u32,
+            temperature: f32 = 1.0,
+            top_p: f32 = 1.0,
+            min_p: f32 = 0.0,
+        },
+    ) !zml.Bufferized(DynamicSamplingStrategy) {
+        return .{
+            .max_top_k = 0,
+            .top_k = try zml.Buffer.scalar(platform, args.top_k, .i32),
+            .temperature = try zml.Buffer.scalar(platform, args.temperature, dtype),
+            .top_p = try zml.Buffer.scalar(platform, args.top_p, dtype),
+            .min_p = try zml.Buffer.scalar(platform, args.min_p, dtype),
+        };
+    }
+};
+
+/// Given the output of the last layer of an LM with a `.voc` axis,
+/// compute indices for the next tokens, following the given sampling strategy.
+/// The dynamic sampling strategy is more expressive but top_p requires computing the softmax.
+///
+/// Options are:
+///
+/// * top_k: only sample among the k top scoring tokens,
+/// * max_top_k: limit at compilation time the max possible runtime value for top_k, saving memory and compute by not having to fully sort the tokens.
+/// * top_p: only sample among top scoring tokens whose probabilities sum up to top_p
+/// * min_p: drop tokens whose probabilities are lower than a ratio of the most likely token
+pub fn sampleTokensDynamic(logits: Tensor, opts: DynamicSamplingStrategy, rng: Tensor.Rng) struct { Tensor, Tensor.Rng } {
+    var x, const topk_indices = fixupLogits(logits, opts);
+
+    // the rest is similar to sampleTokens
+    const next_rng, const gumbel_noise = rng.gumbel(x.shape());
+    x = x.add(gumbel_noise);
+
+    const topk_idx = x.argMax(.topk, .i32).indices;
+    const next_tokens = topk_indices.gatherValues(.voc, topk_idx.squeeze(.topk), .{});
+    return .{ next_tokens, next_rng };
+}
+
+fn fixupLogits(logits: Tensor, opts: DynamicSamplingStrategy) [2]Tensor {
+    const min_inf = Tensor.constant(.{}, logits.dtype().minValue());
+
+    // First reduce the vocab size to a reasonable subset of candidates.
+    const full_topk = if (opts.max_top_k > 0)
+        logits.topK(opts.max_top_k, .voc, .{ .descending = true })
+    else
+        logits.sort(.voc, .{ .descending = true });
+
+    // After the topk, we don't have .voc indices anymore, only topk.
+    var x = full_topk.values.rename(.{ .voc = .topk });
+    // mask values above the dynamic top_k
+    x = Tensor.iota(x.shape(), .topk).cmp(.GE, opts.top_k).select(min_inf, x);
+    x = x.mul(opts.temperature);
+
+    // if there are high values in x, softmax can overflow and will create NaNs in full probs
+    // this propagates to probs_sum and probs_max.
+    const probs = x.softmax(.topk);
+    const probs_sum = probs.cumulativeSum(.topk);
+    const probs_max = probs.slice1d(.topk, .{ .start = 0, .end = 1 });
+
+    const top_p = opts.top_p.broad(x.shape());
+    const min_p = probs_max.mul(opts.min_p).broad(x.shape());
+
+    // * if first candidate has very high prob, then probs_sum is always greater than top_p and candidate is full false
+    // * if first candidate score is even bigger, the probs become NaN because of the softmax,
+    //   then cmp is full false, and candidate is full false too.
+    const candidate = probs_sum.cmp(.LE, top_p).logical(.AND, probs.cmp(.GE, min_p));
+    // * so we explicitly always accept first candidate.
+    const first_token = Tensor.iota(x.shape(), .topk).cmp(.EQ, Tensor.scalar(0, .i32));
+    x = candidate.logical(.OR, first_token).select(x, min_inf);
+
+    return .{ x, full_topk.indices };
+}
+
+test sampleTokensDynamic {
+    const platform = zml.testing.env();
+    const allocator = std.testing.allocator;
+
+    const ___ = -std.math.inf(f32);
+    const logits = [_]f32{ @log(2.0), @log(1.0), @log(4.0), @log(3.0) };
+    const top_k_indices = [_]i32{ 2, 3, 0, 1 };
+    const logits_buff = try zml.Buffer.fromArray(platform, logits);
+    const mod = try zml.compileFn(allocator, fixupLogits, .{ Shape.init(.{ .voc = logits.len }, .f32), DynamicSamplingStrategy.shapes(.f32, 0) }, platform);
+    defer mod.deinit();
+
+    inline for (.{
+        // top_k == logits.len -> just sort the input
+        .{ .{ .top_k = 4 }, [_]f32{ @log(4.0), @log(3.0), @log(2.0), @log(1.0) } },
+        .{ .{ .top_k = 2 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
+        .{ .{ .top_k = 2, .temperature = 0.1 }, [_]f32{ @log(4.0) * 0.1, @log(3.0) * 0.1, ___, ___ } },
+        // top_k == logits.len and small top_p -> make sure at least one is returned
+        .{ .{ .top_k = 4, .top_p = 0.1 }, [_]f32{ @log(4.0), ___, ___, ___ } },
+        .{ .{ .top_k = 4, .top_p = 0.701 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
+        .{ .{ .top_k = 4, .top_p = 0.901 }, [_]f32{ @log(4.0), @log(3.0), @log(2.0), ___ } },
+        // Here top_p is computed on the top 3 items, so 0.701 isn't enough anymore to allow @log(3.0)
+        .{ .{ .top_k = 3, .top_p = 0.701 }, [_]f32{ @log(4.0), ___, ___, ___ } },
+        // Here top_p allows the first 3 results, but min_p only accepts the first two.
+ .{ .{ .top_k = 4, .top_p = 0.901, .min_p = 0.6 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } }, + }) |args_expected| { + const args, const expected = args_expected; + const new_logits, const indices = mod.call(.{ logits_buff, try DynamicSamplingStrategy.makeBuffers(platform, .f32, args) }); + try std.testing.expectEqual(top_k_indices, try indices.getValue(@TypeOf(top_k_indices))); + try zml.testing.expectEqual(expected, try new_logits.getValue(@TypeOf(expected))); + } + + { + // Similar but use bf16, and uses infinity to trigger nans after the softmax. + const bf16 = zml.floats.BFloat16; + + const mod_bf16 = try zml.compileFn(allocator, fixupLogits, .{ Shape.init(.{ .voc = logits.len }, .bf16), DynamicSamplingStrategy.shapes(.bf16, 0) }, platform); + defer mod_bf16.deinit(); + const boost = bf16.inf(); + const nerf = bf16.minusInf(); + const logits_buff_2 = try zml.Buffer.fromArray(platform, [4]bf16{ boost, boost, bf16.fromF32(2), nerf }); + const new_logits, const indices = mod_bf16.call(.{ logits_buff_2, try DynamicSamplingStrategy.makeBuffers(platform, .bf16, .{ .top_k = 4, .top_p = 0.9, .min_p = 0.1 }) }); + try std.testing.expectEqual([_]i32{ 0, 1, 2, 3 }, try indices.getValue([4]i32)); + try zml.testing.expectEqual([_]bf16{ boost, nerf, nerf, nerf }, try new_logits.getValue([4]bf16)); + } +} diff --git a/zml/tensor.zig b/zml/tensor.zig index ba79165..69aeddd 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -656,11 +656,12 @@ pub const Tensor = struct { /// Note: we only implement the μ=0, β=1 version. pub fn gumbel(self: Rng, shape_: Shape) struct { Rng, Tensor } { const rand, const u = self.uniform( - shape_, + // Always use .f32 to have a big enough mantissa. + shape_.withDtype(.f32), // We don't want 0 to be sampled otherwise `log` will return -inf. 
- .{ .min = std.math.floatEps(f64), .max = 1 }, + .{ .min = std.math.floatEps(f32), .max = 1 }, ); - return .{ rand, u.log().scale(-1).log().scale(-1) }; + return .{ rand, u.log().scale(-1).log().scale(-1).convert(shape_.dtype()) }; } test gumbel {