zml.nn: add dynamic sampling with support for top_k, top_p, and min_p. Computes next-token indices according to the selected sampling strategy, with options for top_k, max_top_k (a compile-time bound on top_k), top_p, and min_p.

This commit is contained in:
Tarry Singh 2023-06-16 14:34:18 +00:00
parent b244a18621
commit f00538667e
2 changed files with 172 additions and 3 deletions

View File

@ -910,6 +910,7 @@ fn sdpaChunk(q: Tensor, k: Tensor, v: Tensor, opts: SdpaOpts) PartialAttn {
.max_value = partial.max_value, .max_value = partial.max_value,
}; };
} }
test "sdpaMemEfficient without mask" { test "sdpaMemEfficient without mask" {
const platform = zml.testing.env(); const platform = zml.testing.env();
const allocator = std.testing.allocator; const allocator = std.testing.allocator;
@ -1010,3 +1011,170 @@ pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng
// log.debug("sampleTokens({}) -> {} -> {} -> {}", .{ activations, topk.indices, topk_idx, next_tokens }); // log.debug("sampleTokens({}) -> {} -> {} -> {}", .{ activations, topk.indices, topk_idx, next_tokens });
return .{ next_tokens, next_rng }; return .{ next_tokens, next_rng };
} }
// Checks that `sampleTokens` (compiled with a static SamplingStrategy) returns the
// index of the dominating logit, regardless of its position in the vocabulary.
test sampleTokens {
    const platform = zml.testing.env();
    const allocator = std.testing.allocator;
    const inf = std.math.inf(f32);
    // Fixed seed so the gumbel noise, and therefore the test, is deterministic.
    var rng_buff = try zml.Tensor.Rng.init(platform, 0xdeadbeef);
    defer rng_buff._state.deinit();
    // The sampling options (`.{ .topk = 4, .temperature = 2.0 }`) are passed as a
    // value, not a Shape: they are baked into the compiled module.
    const mod = try zml.compileFn(allocator, sampleTokens, .{ Shape.init(.{ .voc = 4 }, .f32), .{ .topk = 4, .temperature = 2.0 }, zml.Tensor.Rng.shape() }, platform);
    defer mod.deinit();
    inline for (.{
        // Each case: logits over the 4-token vocabulary, and the index expected to win.
        // An `inf` logit dominates every finite one even after gumbel noise is added.
        .{ [_]f32{ inf, 3.0, 2.0, 1.0 }, 0 },
        .{ [_]f32{ -inf, 3.0, -inf, -inf }, 1 },
        .{ [_]f32{ 3.0, 2, inf, inf }, 2 },
    }) |logits_expected| {
        const logits, const expected: i32 = logits_expected;
        var logits_buff = try zml.Buffer.fromArray(platform, logits);
        defer logits_buff.deinit();
        // The strategy slot was fixed at compile time, so `undefined` is passed here.
        // NOTE(review): presumably `call` ignores comptime-baked arguments — confirm.
        // `rng_buff` is reassigned in place so the next iteration uses a fresh rng state.
        var sampled, rng_buff = mod.call(.{ logits_buff, undefined, rng_buff });
        defer sampled.deinit();
        try zml.testing.expectEqual(expected, try sampled.getValue(i32));
    }
}
/// Runtime-adjustable sampling configuration for `sampleTokensDynamic`.
/// `max_top_k` is a plain integer fixed at compilation time; the other fields are
/// scalar tensors whose values can change between calls without recompiling.
pub const DynamicSamplingStrategy = struct {
    max_top_k: u32,
    top_k: Tensor,
    temperature: Tensor,
    top_p: Tensor,
    min_p: Tensor,

    /// Shapes to compile a module against: scalar i32 for `top_k`,
    /// scalar `dtype` for the probability-related knobs.
    pub fn shapes(dtype: DataType, max_top_k: u32) zml.ShapeOf(DynamicSamplingStrategy) {
        return .{
            .max_top_k = max_top_k,
            .top_k = Shape.init(.{}, .i32),
            .temperature = Shape.init(.{}, dtype),
            .top_p = Shape.init(.{}, dtype),
            .min_p = Shape.init(.{}, dtype),
        };
    }

    /// Uploads the runtime knobs as scalar device buffers.
    /// `temperature`, `top_p` and `min_p` default to neutral values.
    pub fn makeBuffers(
        platform: zml.Platform,
        dtype: zml.DataType,
        args: struct {
            top_k: u32,
            temperature: f32 = 1.0,
            top_p: f32 = 1.0,
            min_p: f32 = 0.0,
        },
    ) !zml.Bufferized(DynamicSamplingStrategy) {
        const scalar = zml.Buffer.scalar;
        return .{
            // `max_top_k` was fixed at compilation time, so the buffered value is moot.
            .max_top_k = 0,
            .top_k = try scalar(platform, args.top_k, .i32),
            .temperature = try scalar(platform, args.temperature, dtype),
            .top_p = try scalar(platform, args.top_p, dtype),
            .min_p = try scalar(platform, args.min_p, dtype),
        };
    }
};
/// Given the output of the last layer of a LM with a `.voc` axis,
/// compute indices for the next tokens, following the given sampling strategy.
/// The dynamic sampling strategy is more expressive, but top_p requires computing the softmax.
///
/// Options are:
///
/// * top_k: only sample among the k top scoring tokens,
/// * max_top_k: limit at compilation time what the max possible runtime value for top_k is, saving memory and compute by not having to fully sort the tokens.
/// * top_p: only sample among top scoring tokens whose probabilities sum up to top_p
/// * min_p: drop tokens whose probabilities are lower than a ratio of the most likely token
pub fn sampleTokensDynamic(logits: Tensor, opts: DynamicSamplingStrategy, rng: Tensor.Rng) struct { Tensor, Tensor.Rng } {
    // Restrict the vocabulary to the surviving candidates (sorted along `.topk`)
    // and keep the mapping back to their original `.voc` indices.
    const filtered, const candidates = fixupLogits(logits, opts);
    // Same sampling scheme as `sampleTokens`: perturb the logits with gumbel noise,
    // then take the argmax (Gumbel-max sampling).
    const next_rng, const noise = rng.gumbel(filtered.shape());
    const winner = filtered.add(noise).argMax(.topk, .i32).indices;
    // Translate the winning position along `.topk` back into a `.voc` token id.
    const next_tokens = candidates.gatherValues(.voc, winner.squeeze(.topk), .{});
    return .{ next_tokens, next_rng };
}
/// Sorts the logits in descending order, then masks every candidate rejected by the
/// dynamic top_k / top_p / min_p filters by setting it to the dtype's minimum value.
/// Returns `.{ masked sorted logits (along .topk), original .voc index of each slot }`.
fn fixupLogits(logits: Tensor, opts: DynamicSamplingStrategy) [2]Tensor {
    // Minimum value of the logits dtype, used to "erase" rejected candidates.
    // (The bf16 test expects masked slots to read back as -inf.)
    const min_inf = Tensor.constant(.{}, logits.dtype().minValue());
    // First reduce the vocab size to a reasonable subset of candidates:
    // with a compile-time bound we only need a partial top-k, otherwise a full sort.
    const full_topk = if (opts.max_top_k > 0)
        logits.topK(opts.max_top_k, .voc, .{ .descending = true })
    else
        logits.sort(.voc, .{ .descending = true });
    // After the topk, we don't have .voc indices anymore, only positions along .topk.
    var x = full_topk.values.rename(.{ .voc = .topk });
    // Mask the positions at or beyond the runtime top_k.
    x = Tensor.iota(x.shape(), .topk).cmp(.GE, opts.top_k).select(min_inf, x);
    // NOTE(review): the logits are *multiplied* by temperature (the tests pin this
    // behavior), so it acts as an inverse temperature vs the usual `logits / T`.
    x = x.mul(opts.temperature);
    // If there are high values in x, softmax can overflow and will create nans in probs;
    // those propagate to probs_sum and probs_max.
    const probs = x.softmax(.topk);
    // Cumulative probability mass up to and including each sorted slot.
    const probs_sum = probs.cumulativeSum(.topk);
    // Probability of the most likely candidate: slot 0 after the descending sort.
    const probs_max = probs.slice1d(.topk, .{ .start = 0, .end = 1 });
    const top_p = opts.top_p.broad(x.shape());
    // min_p is relative: a ratio of the best candidate's probability.
    const min_p = probs_max.mul(opts.min_p).broad(x.shape());
    // * if the first candidate has a very high prob, then probs_sum is always greater
    //   than top_p, and `candidate` is all false;
    // * if the first candidate's score is even bigger, the probs become NaN because of
    //   the softmax, then each cmp is all false, and `candidate` is all false too.
    const candidate = probs_sum.cmp(.LE, top_p).logical(.AND, probs.cmp(.GE, min_p));
    // * so we explicitly always accept the first candidate.
    const first_token = Tensor.iota(x.shape(), .topk).cmp(.EQ, Tensor.scalar(0, .i32));
    x = candidate.logical(.OR, first_token).select(x, min_inf);
    return .{ x, full_topk.indices };
}
// Exercises `fixupLogits` directly — the deterministic part of dynamic sampling —
// checking both the returned sort indices and the masked logits for several strategies.
test sampleTokensDynamic {
    const platform = zml.testing.env();
    const allocator = std.testing.allocator;
    // Visual marker for "masked out" entries (-inf).
    const ___ = -std.math.inf(f32);
    // Logits chosen so softmax probs are proportional to {2, 1, 4, 3}, i.e. the sorted
    // probabilities are 0.4 / 0.3 / 0.2 / 0.1 with cumulative sums 0.4 / 0.7 / 0.9 / 1.0.
    const logits = [_]f32{ @log(2.0), @log(1.0), @log(4.0), @log(3.0) };
    // Descending sort order of `logits`, expected back from every call below.
    const top_k_indices = [_]i32{ 2, 3, 0, 1 };
    const logits_buff = try zml.Buffer.fromArray(platform, logits);
    // max_top_k = 0 -> the full vocabulary gets sorted.
    const mod = try zml.compileFn(allocator, fixupLogits, .{ Shape.init(.{ .voc = logits.len }, .f32), DynamicSamplingStrategy.shapes(.f32, 0) }, platform);
    defer mod.deinit();
    inline for (.{
        // top_k == logits.len -> just sort the input
        .{ .{ .top_k = 4 }, [_]f32{ @log(4.0), @log(3.0), @log(2.0), @log(1.0) } },
        .{ .{ .top_k = 2 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
        // temperature multiplies the surviving logits.
        .{ .{ .top_k = 2, .temperature = 0.1 }, [_]f32{ @log(4.0) * 0.1, @log(3.0) * 0.1, ___, ___ } },
        // top_k == logits.len and small top_p -> make sure at least one is returned
        .{ .{ .top_k = 4, .top_p = 0.1 }, [_]f32{ @log(4.0), ___, ___, ___ } },
        .{ .{ .top_k = 4, .top_p = 0.701 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
        .{ .{ .top_k = 4, .top_p = 0.901 }, [_]f32{ @log(4.0), @log(3.0), @log(2.0), ___ } },
        // Here top_p is computed on the top 3 items, so 0.701 isn't enough anymore to allow @log(3.0)
        .{ .{ .top_k = 3, .top_p = 0.701 }, [_]f32{ @log(4.0), ___, ___, ___ } },
        // Here top_p allows the first 3 results, but min_p only accepts the first two.
        .{ .{ .top_k = 4, .top_p = 0.901, .min_p = 0.6 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
    }) |args_expected| {
        const args, const expected = args_expected;
        const new_logits, const indices = mod.call(.{ logits_buff, try DynamicSamplingStrategy.makeBuffers(platform, .f32, args) });
        try std.testing.expectEqual(top_k_indices, try indices.getValue(@TypeOf(top_k_indices)));
        try zml.testing.expectEqual(expected, try new_logits.getValue(@TypeOf(expected)));
    }
    {
        // Similar, but uses bf16 and infinite logits to trigger NaNs after the softmax.
        const bf16 = zml.floats.BFloat16;
        const mod_bf16 = try zml.compileFn(allocator, fixupLogits, .{ Shape.init(.{ .voc = logits.len }, .bf16), DynamicSamplingStrategy.shapes(.bf16, 0) }, platform);
        defer mod_bf16.deinit();
        const boost = bf16.inf();
        const nerf = bf16.minusInf();
        const logits_buff_2 = try zml.Buffer.fromArray(platform, [4]bf16{ boost, boost, bf16.fromF32(2), nerf });
        const new_logits, const indices = mod_bf16.call(.{ logits_buff_2, try DynamicSamplingStrategy.makeBuffers(platform, .bf16, .{ .top_k = 4, .top_p = 0.9, .min_p = 0.1 }) });
        // Even with NaN probabilities, the first candidate must survive; all others are masked.
        try std.testing.expectEqual([_]i32{ 0, 1, 2, 3 }, try indices.getValue([4]i32));
        try zml.testing.expectEqual([_]bf16{ boost, nerf, nerf, nerf }, try new_logits.getValue([4]bf16));
    }
}

View File

@ -656,11 +656,12 @@ pub const Tensor = struct {
/// Note: we only implement the μ=0, β=1 version. /// Note: we only implement the μ=0, β=1 version.
pub fn gumbel(self: Rng, shape_: Shape) struct { Rng, Tensor } { pub fn gumbel(self: Rng, shape_: Shape) struct { Rng, Tensor } {
const rand, const u = self.uniform( const rand, const u = self.uniform(
shape_, // Always use .f32 to have a big enough mantissa.
shape_.withDtype(.f32),
// We don't want 0 to be sampled otherwise `log` will return -inf. // We don't want 0 to be sampled otherwise `log` will return -inf.
.{ .min = std.math.floatEps(f64), .max = 1 }, .{ .min = std.math.floatEps(f32), .max = 1 },
); );
return .{ rand, u.log().scale(-1).log().scale(-1) }; return .{ rand, u.log().scale(-1).log().scale(-1).convert(shape_.dtype()) };
} }
test gumbel { test gumbel {