zml.nn: add dynamic sampling with support for top‑k, top‑p, and min‑p settings. Implements token index computation based on the selected sampling strategy, including options for top_k, max_top_k, top_p, and min_p.
This commit is contained in:
parent
b244a18621
commit
f00538667e
168
zml/nn.zig
168
zml/nn.zig
@ -910,6 +910,7 @@ fn sdpaChunk(q: Tensor, k: Tensor, v: Tensor, opts: SdpaOpts) PartialAttn {
|
||||
.max_value = partial.max_value,
|
||||
};
|
||||
}
|
||||
|
||||
test "sdpaMemEfficient without mask" {
|
||||
const platform = zml.testing.env();
|
||||
const allocator = std.testing.allocator;
|
||||
@ -1010,3 +1011,170 @@ pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng
|
||||
// log.debug("sampleTokens({}) -> {} -> {} -> {}", .{ activations, topk.indices, topk_idx, next_tokens });
|
||||
return .{ next_tokens, next_rng };
|
||||
}
|
||||
|
||||
test sampleTokens {
|
||||
const platform = zml.testing.env();
|
||||
const allocator = std.testing.allocator;
|
||||
|
||||
const inf = std.math.inf(f32);
|
||||
var rng_buff = try zml.Tensor.Rng.init(platform, 0xdeadbeef);
|
||||
defer rng_buff._state.deinit();
|
||||
|
||||
const mod = try zml.compileFn(allocator, sampleTokens, .{ Shape.init(.{ .voc = 4 }, .f32), .{ .topk = 4, .temperature = 2.0 }, zml.Tensor.Rng.shape() }, platform);
|
||||
defer mod.deinit();
|
||||
|
||||
inline for (.{
|
||||
.{ [_]f32{ inf, 3.0, 2.0, 1.0 }, 0 },
|
||||
.{ [_]f32{ -inf, 3.0, -inf, -inf }, 1 },
|
||||
.{ [_]f32{ 3.0, 2, inf, inf }, 2 },
|
||||
}) |logits_expected| {
|
||||
const logits, const expected: i32 = logits_expected;
|
||||
var logits_buff = try zml.Buffer.fromArray(platform, logits);
|
||||
defer logits_buff.deinit();
|
||||
var sampled, rng_buff = mod.call(.{ logits_buff, undefined, rng_buff });
|
||||
defer sampled.deinit();
|
||||
try zml.testing.expectEqual(expected, try sampled.getValue(i32));
|
||||
}
|
||||
}
|
||||
|
||||
pub const DynamicSamplingStrategy = struct {
|
||||
max_top_k: u32,
|
||||
top_k: Tensor,
|
||||
temperature: Tensor,
|
||||
top_p: Tensor,
|
||||
min_p: Tensor,
|
||||
|
||||
pub fn shapes(dtype: DataType, max_top_k: u32) zml.ShapeOf(DynamicSamplingStrategy) {
|
||||
const scalar_float = Shape.init(.{}, dtype);
|
||||
const scalar_i32 = Shape.init(.{}, .i32);
|
||||
return .{
|
||||
.max_top_k = max_top_k,
|
||||
.top_k = scalar_i32,
|
||||
.temperature = scalar_float,
|
||||
.top_p = scalar_float,
|
||||
.min_p = scalar_float,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn makeBuffers(
|
||||
platform: zml.Platform,
|
||||
dtype: zml.DataType,
|
||||
args: struct {
|
||||
top_k: u32,
|
||||
temperature: f32 = 1.0,
|
||||
top_p: f32 = 1.0,
|
||||
min_p: f32 = 0.0,
|
||||
},
|
||||
) !zml.Bufferized(DynamicSamplingStrategy) {
|
||||
return .{
|
||||
.max_top_k = 0,
|
||||
.top_k = try zml.Buffer.scalar(platform, args.top_k, .i32),
|
||||
.temperature = try zml.Buffer.scalar(platform, args.temperature, dtype),
|
||||
.top_p = try zml.Buffer.scalar(platform, args.top_p, dtype),
|
||||
.min_p = try zml.Buffer.scalar(platform, args.min_p, dtype),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
/// Given the output of the last layer of a LM with a `.voc` axis,
|
||||
/// Compute indices for the next tokens, following the given sampling strategy.
|
||||
/// The dynamic sampling strategy is more expressive but top_p requires computing the softmax.
|
||||
///
|
||||
/// Options are:
|
||||
///
|
||||
/// * top_k: only sample among the k top scoring tokens,
|
||||
/// * max_top_k: limit a compilation time what is the max possible runtime value for top_k, saving memory and compute by not having to fully sort the tokens.
|
||||
/// * top_p: only sample among top scoring tokens whose probabilities sum up to top_p
|
||||
/// * min_p: drop tokens whose probabilities are lower than a ratio of the most likely token
|
||||
pub fn sampleTokensDynamic(logits: Tensor, opts: DynamicSamplingStrategy, rng: Tensor.Rng) struct { Tensor, Tensor.Rng } {
|
||||
var x, const topk_indices = fixupLogits(logits, opts);
|
||||
|
||||
// the rest is similar to sampleTokens
|
||||
const next_rng, const gumbel_noise = rng.gumbel(x.shape());
|
||||
x = x.add(gumbel_noise);
|
||||
|
||||
const topk_idx = x.argMax(.topk, .i32).indices;
|
||||
const next_tokens = topk_indices.gatherValues(.voc, topk_idx.squeeze(.topk), .{});
|
||||
return .{ next_tokens, next_rng };
|
||||
}
|
||||
|
||||
fn fixupLogits(logits: Tensor, opts: DynamicSamplingStrategy) [2]Tensor {
|
||||
const min_inf = Tensor.constant(.{}, logits.dtype().minValue());
|
||||
|
||||
// First reduce the vocab size to a reasonable sub set of candidate.
|
||||
const full_topk = if (opts.max_top_k > 0)
|
||||
logits.topK(opts.max_top_k, .voc, .{ .descending = true })
|
||||
else
|
||||
logits.sort(.voc, .{ .descending = true });
|
||||
|
||||
// After the topk, we don't have .voc indices, anymore, only topk.
|
||||
var x = full_topk.values.rename(.{ .voc = .topk });
|
||||
// mask values above the dynamic top_k
|
||||
x = Tensor.iota(x.shape(), .topk).cmp(.GE, opts.top_k).select(min_inf, x);
|
||||
x = x.mul(opts.temperature);
|
||||
|
||||
// if there are high values in x, softmax can overflow and will create nans in full probs
|
||||
// this propagate to probs_sum and probs_max.
|
||||
const probs = x.softmax(.topk);
|
||||
const probs_sum = probs.cumulativeSum(.topk);
|
||||
const probs_max = probs.slice1d(.topk, .{ .start = 0, .end = 1 });
|
||||
|
||||
const top_p = opts.top_p.broad(x.shape());
|
||||
const min_p = probs_max.mul(opts.min_p).broad(x.shape());
|
||||
|
||||
// * if first candidate has very high prob, then probs_sum is always greater than top_p and candidate is full false
|
||||
// * if first candidate score is even bigger, the probs become Nan because of the softmax,
|
||||
// then cmp is is full false, and candidate is full false too.
|
||||
const candidate = probs_sum.cmp(.LE, top_p).logical(.AND, probs.cmp(.GE, min_p));
|
||||
// * so we explicitly always accept first candidate.
|
||||
const first_token = Tensor.iota(x.shape(), .topk).cmp(.EQ, Tensor.scalar(0, .i32));
|
||||
x = candidate.logical(.OR, first_token).select(x, min_inf);
|
||||
|
||||
return .{ x, full_topk.indices };
|
||||
}
|
||||
|
||||
test sampleTokensDynamic {
|
||||
const platform = zml.testing.env();
|
||||
const allocator = std.testing.allocator;
|
||||
|
||||
const ___ = -std.math.inf(f32);
|
||||
const logits = [_]f32{ @log(2.0), @log(1.0), @log(4.0), @log(3.0) };
|
||||
const top_k_indices = [_]i32{ 2, 3, 0, 1 };
|
||||
const logits_buff = try zml.Buffer.fromArray(platform, logits);
|
||||
const mod = try zml.compileFn(allocator, fixupLogits, .{ Shape.init(.{ .voc = logits.len }, .f32), DynamicSamplingStrategy.shapes(.f32, 0) }, platform);
|
||||
defer mod.deinit();
|
||||
|
||||
inline for (.{
|
||||
// top_k == logits.len -> just sort the input
|
||||
.{ .{ .top_k = 4 }, [_]f32{ @log(4.0), @log(3.0), @log(2.0), @log(1.0) } },
|
||||
.{ .{ .top_k = 2 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
|
||||
.{ .{ .top_k = 2, .temperature = 0.1 }, [_]f32{ @log(4.0) * 0.1, @log(3.0) * 0.1, ___, ___ } },
|
||||
// top_k == logits.len and small top_p -> make sure at least one is returned
|
||||
.{ .{ .top_k = 4, .top_p = 0.1 }, [_]f32{ @log(4.0), ___, ___, ___ } },
|
||||
.{ .{ .top_k = 4, .top_p = 0.701 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
|
||||
.{ .{ .top_k = 4, .top_p = 0.901 }, [_]f32{ @log(4.0), @log(3.0), @log(2.0), ___ } },
|
||||
// Here top_p is computed on the top 3 items, so 0.701 isn't enougth anymore to allow @log(3.0)
|
||||
.{ .{ .top_k = 3, .top_p = 0.701 }, [_]f32{ @log(4.0), ___, ___, ___ } },
|
||||
// Here top_p allows the first 3 results, but min_p only accepts the first two.
|
||||
.{ .{ .top_k = 4, .top_p = 0.901, .min_p = 0.6 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
|
||||
}) |args_expected| {
|
||||
const args, const expected = args_expected;
|
||||
const new_logits, const indices = mod.call(.{ logits_buff, try DynamicSamplingStrategy.makeBuffers(platform, .f32, args) });
|
||||
try std.testing.expectEqual(top_k_indices, try indices.getValue(@TypeOf(top_k_indices)));
|
||||
try zml.testing.expectEqual(expected, try new_logits.getValue(@TypeOf(expected)));
|
||||
}
|
||||
|
||||
{
|
||||
// Similar but use bf16, and uses infinity to trigger nans after the softmax.
|
||||
const bf16 = zml.floats.BFloat16;
|
||||
|
||||
const mod_bf16 = try zml.compileFn(allocator, fixupLogits, .{ Shape.init(.{ .voc = logits.len }, .bf16), DynamicSamplingStrategy.shapes(.bf16, 0) }, platform);
|
||||
defer mod_bf16.deinit();
|
||||
const boost = bf16.inf();
|
||||
const nerf = bf16.minusInf();
|
||||
const logits_buff_2 = try zml.Buffer.fromArray(platform, [4]bf16{ boost, boost, bf16.fromF32(2), nerf });
|
||||
const new_logits, const indices = mod_bf16.call(.{ logits_buff_2, try DynamicSamplingStrategy.makeBuffers(platform, .bf16, .{ .top_k = 4, .top_p = 0.9, .min_p = 0.1 }) });
|
||||
try std.testing.expectEqual([_]i32{ 0, 1, 2, 3 }, try indices.getValue([4]i32));
|
||||
try zml.testing.expectEqual([_]bf16{ boost, nerf, nerf, nerf }, try new_logits.getValue([4]bf16));
|
||||
}
|
||||
}
|
||||
|
||||
@ -656,11 +656,12 @@ pub const Tensor = struct {
|
||||
/// Note: we only implement the μ=0, β=1 version.
|
||||
pub fn gumbel(self: Rng, shape_: Shape) struct { Rng, Tensor } {
|
||||
const rand, const u = self.uniform(
|
||||
shape_,
|
||||
// Always use .f32 to have a big enough mantissa.
|
||||
shape_.withDtype(.f32),
|
||||
// We don't want 0 to be sampled otherwise `log` will return -inf.
|
||||
.{ .min = std.math.floatEps(f64), .max = 1 },
|
||||
.{ .min = std.math.floatEps(f32), .max = 1 },
|
||||
);
|
||||
return .{ rand, u.log().scale(-1).log().scale(-1) };
|
||||
return .{ rand, u.log().scale(-1).log().scale(-1).convert(shape_.dtype()) };
|
||||
}
|
||||
|
||||
test gumbel {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user