zml.nn: add dynamic sampling with support for top‑k, top‑p, and min‑p settings. Implements token index computation based on the selected sampling strategy, including options for top_k, max_top_k, top_p, and min_p.
This commit is contained in:
parent
b244a18621
commit
f00538667e
168
zml/nn.zig
168
zml/nn.zig
@ -910,6 +910,7 @@ fn sdpaChunk(q: Tensor, k: Tensor, v: Tensor, opts: SdpaOpts) PartialAttn {
|
|||||||
.max_value = partial.max_value,
|
.max_value = partial.max_value,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
test "sdpaMemEfficient without mask" {
|
test "sdpaMemEfficient without mask" {
|
||||||
const platform = zml.testing.env();
|
const platform = zml.testing.env();
|
||||||
const allocator = std.testing.allocator;
|
const allocator = std.testing.allocator;
|
||||||
@ -1010,3 +1011,170 @@ pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng
|
|||||||
// log.debug("sampleTokens({}) -> {} -> {} -> {}", .{ activations, topk.indices, topk_idx, next_tokens });
|
// log.debug("sampleTokens({}) -> {} -> {} -> {}", .{ activations, topk.indices, topk_idx, next_tokens });
|
||||||
return .{ next_tokens, next_rng };
|
return .{ next_tokens, next_rng };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test sampleTokens {
|
||||||
|
const platform = zml.testing.env();
|
||||||
|
const allocator = std.testing.allocator;
|
||||||
|
|
||||||
|
const inf = std.math.inf(f32);
|
||||||
|
var rng_buff = try zml.Tensor.Rng.init(platform, 0xdeadbeef);
|
||||||
|
defer rng_buff._state.deinit();
|
||||||
|
|
||||||
|
const mod = try zml.compileFn(allocator, sampleTokens, .{ Shape.init(.{ .voc = 4 }, .f32), .{ .topk = 4, .temperature = 2.0 }, zml.Tensor.Rng.shape() }, platform);
|
||||||
|
defer mod.deinit();
|
||||||
|
|
||||||
|
inline for (.{
|
||||||
|
.{ [_]f32{ inf, 3.0, 2.0, 1.0 }, 0 },
|
||||||
|
.{ [_]f32{ -inf, 3.0, -inf, -inf }, 1 },
|
||||||
|
.{ [_]f32{ 3.0, 2, inf, inf }, 2 },
|
||||||
|
}) |logits_expected| {
|
||||||
|
const logits, const expected: i32 = logits_expected;
|
||||||
|
var logits_buff = try zml.Buffer.fromArray(platform, logits);
|
||||||
|
defer logits_buff.deinit();
|
||||||
|
var sampled, rng_buff = mod.call(.{ logits_buff, undefined, rng_buff });
|
||||||
|
defer sampled.deinit();
|
||||||
|
try zml.testing.expectEqual(expected, try sampled.getValue(i32));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const DynamicSamplingStrategy = struct {
|
||||||
|
max_top_k: u32,
|
||||||
|
top_k: Tensor,
|
||||||
|
temperature: Tensor,
|
||||||
|
top_p: Tensor,
|
||||||
|
min_p: Tensor,
|
||||||
|
|
||||||
|
pub fn shapes(dtype: DataType, max_top_k: u32) zml.ShapeOf(DynamicSamplingStrategy) {
|
||||||
|
const scalar_float = Shape.init(.{}, dtype);
|
||||||
|
const scalar_i32 = Shape.init(.{}, .i32);
|
||||||
|
return .{
|
||||||
|
.max_top_k = max_top_k,
|
||||||
|
.top_k = scalar_i32,
|
||||||
|
.temperature = scalar_float,
|
||||||
|
.top_p = scalar_float,
|
||||||
|
.min_p = scalar_float,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn makeBuffers(
|
||||||
|
platform: zml.Platform,
|
||||||
|
dtype: zml.DataType,
|
||||||
|
args: struct {
|
||||||
|
top_k: u32,
|
||||||
|
temperature: f32 = 1.0,
|
||||||
|
top_p: f32 = 1.0,
|
||||||
|
min_p: f32 = 0.0,
|
||||||
|
},
|
||||||
|
) !zml.Bufferized(DynamicSamplingStrategy) {
|
||||||
|
return .{
|
||||||
|
.max_top_k = 0,
|
||||||
|
.top_k = try zml.Buffer.scalar(platform, args.top_k, .i32),
|
||||||
|
.temperature = try zml.Buffer.scalar(platform, args.temperature, dtype),
|
||||||
|
.top_p = try zml.Buffer.scalar(platform, args.top_p, dtype),
|
||||||
|
.min_p = try zml.Buffer.scalar(platform, args.min_p, dtype),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Given the output of the last layer of a LM with a `.voc` axis,
|
||||||
|
/// Compute indices for the next tokens, following the given sampling strategy.
|
||||||
|
/// The dynamic sampling strategy is more expressive but top_p requires computing the softmax.
|
||||||
|
///
|
||||||
|
/// Options are:
|
||||||
|
///
|
||||||
|
/// * top_k: only sample among the k top scoring tokens,
|
||||||
|
/// * max_top_k: limit a compilation time what is the max possible runtime value for top_k, saving memory and compute by not having to fully sort the tokens.
|
||||||
|
/// * top_p: only sample among top scoring tokens whose probabilities sum up to top_p
|
||||||
|
/// * min_p: drop tokens whose probabilities are lower than a ratio of the most likely token
|
||||||
|
pub fn sampleTokensDynamic(logits: Tensor, opts: DynamicSamplingStrategy, rng: Tensor.Rng) struct { Tensor, Tensor.Rng } {
|
||||||
|
var x, const topk_indices = fixupLogits(logits, opts);
|
||||||
|
|
||||||
|
// the rest is similar to sampleTokens
|
||||||
|
const next_rng, const gumbel_noise = rng.gumbel(x.shape());
|
||||||
|
x = x.add(gumbel_noise);
|
||||||
|
|
||||||
|
const topk_idx = x.argMax(.topk, .i32).indices;
|
||||||
|
const next_tokens = topk_indices.gatherValues(.voc, topk_idx.squeeze(.topk), .{});
|
||||||
|
return .{ next_tokens, next_rng };
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fixupLogits(logits: Tensor, opts: DynamicSamplingStrategy) [2]Tensor {
|
||||||
|
const min_inf = Tensor.constant(.{}, logits.dtype().minValue());
|
||||||
|
|
||||||
|
// First reduce the vocab size to a reasonable sub set of candidate.
|
||||||
|
const full_topk = if (opts.max_top_k > 0)
|
||||||
|
logits.topK(opts.max_top_k, .voc, .{ .descending = true })
|
||||||
|
else
|
||||||
|
logits.sort(.voc, .{ .descending = true });
|
||||||
|
|
||||||
|
// After the topk, we don't have .voc indices, anymore, only topk.
|
||||||
|
var x = full_topk.values.rename(.{ .voc = .topk });
|
||||||
|
// mask values above the dynamic top_k
|
||||||
|
x = Tensor.iota(x.shape(), .topk).cmp(.GE, opts.top_k).select(min_inf, x);
|
||||||
|
x = x.mul(opts.temperature);
|
||||||
|
|
||||||
|
// if there are high values in x, softmax can overflow and will create nans in full probs
|
||||||
|
// this propagate to probs_sum and probs_max.
|
||||||
|
const probs = x.softmax(.topk);
|
||||||
|
const probs_sum = probs.cumulativeSum(.topk);
|
||||||
|
const probs_max = probs.slice1d(.topk, .{ .start = 0, .end = 1 });
|
||||||
|
|
||||||
|
const top_p = opts.top_p.broad(x.shape());
|
||||||
|
const min_p = probs_max.mul(opts.min_p).broad(x.shape());
|
||||||
|
|
||||||
|
// * if first candidate has very high prob, then probs_sum is always greater than top_p and candidate is full false
|
||||||
|
// * if first candidate score is even bigger, the probs become Nan because of the softmax,
|
||||||
|
// then cmp is is full false, and candidate is full false too.
|
||||||
|
const candidate = probs_sum.cmp(.LE, top_p).logical(.AND, probs.cmp(.GE, min_p));
|
||||||
|
// * so we explicitly always accept first candidate.
|
||||||
|
const first_token = Tensor.iota(x.shape(), .topk).cmp(.EQ, Tensor.scalar(0, .i32));
|
||||||
|
x = candidate.logical(.OR, first_token).select(x, min_inf);
|
||||||
|
|
||||||
|
return .{ x, full_topk.indices };
|
||||||
|
}
|
||||||
|
|
||||||
|
test sampleTokensDynamic {
|
||||||
|
const platform = zml.testing.env();
|
||||||
|
const allocator = std.testing.allocator;
|
||||||
|
|
||||||
|
const ___ = -std.math.inf(f32);
|
||||||
|
const logits = [_]f32{ @log(2.0), @log(1.0), @log(4.0), @log(3.0) };
|
||||||
|
const top_k_indices = [_]i32{ 2, 3, 0, 1 };
|
||||||
|
const logits_buff = try zml.Buffer.fromArray(platform, logits);
|
||||||
|
const mod = try zml.compileFn(allocator, fixupLogits, .{ Shape.init(.{ .voc = logits.len }, .f32), DynamicSamplingStrategy.shapes(.f32, 0) }, platform);
|
||||||
|
defer mod.deinit();
|
||||||
|
|
||||||
|
inline for (.{
|
||||||
|
// top_k == logits.len -> just sort the input
|
||||||
|
.{ .{ .top_k = 4 }, [_]f32{ @log(4.0), @log(3.0), @log(2.0), @log(1.0) } },
|
||||||
|
.{ .{ .top_k = 2 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
|
||||||
|
.{ .{ .top_k = 2, .temperature = 0.1 }, [_]f32{ @log(4.0) * 0.1, @log(3.0) * 0.1, ___, ___ } },
|
||||||
|
// top_k == logits.len and small top_p -> make sure at least one is returned
|
||||||
|
.{ .{ .top_k = 4, .top_p = 0.1 }, [_]f32{ @log(4.0), ___, ___, ___ } },
|
||||||
|
.{ .{ .top_k = 4, .top_p = 0.701 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
|
||||||
|
.{ .{ .top_k = 4, .top_p = 0.901 }, [_]f32{ @log(4.0), @log(3.0), @log(2.0), ___ } },
|
||||||
|
// Here top_p is computed on the top 3 items, so 0.701 isn't enougth anymore to allow @log(3.0)
|
||||||
|
.{ .{ .top_k = 3, .top_p = 0.701 }, [_]f32{ @log(4.0), ___, ___, ___ } },
|
||||||
|
// Here top_p allows the first 3 results, but min_p only accepts the first two.
|
||||||
|
.{ .{ .top_k = 4, .top_p = 0.901, .min_p = 0.6 }, [_]f32{ @log(4.0), @log(3.0), ___, ___ } },
|
||||||
|
}) |args_expected| {
|
||||||
|
const args, const expected = args_expected;
|
||||||
|
const new_logits, const indices = mod.call(.{ logits_buff, try DynamicSamplingStrategy.makeBuffers(platform, .f32, args) });
|
||||||
|
try std.testing.expectEqual(top_k_indices, try indices.getValue(@TypeOf(top_k_indices)));
|
||||||
|
try zml.testing.expectEqual(expected, try new_logits.getValue(@TypeOf(expected)));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Similar but use bf16, and uses infinity to trigger nans after the softmax.
|
||||||
|
const bf16 = zml.floats.BFloat16;
|
||||||
|
|
||||||
|
const mod_bf16 = try zml.compileFn(allocator, fixupLogits, .{ Shape.init(.{ .voc = logits.len }, .bf16), DynamicSamplingStrategy.shapes(.bf16, 0) }, platform);
|
||||||
|
defer mod_bf16.deinit();
|
||||||
|
const boost = bf16.inf();
|
||||||
|
const nerf = bf16.minusInf();
|
||||||
|
const logits_buff_2 = try zml.Buffer.fromArray(platform, [4]bf16{ boost, boost, bf16.fromF32(2), nerf });
|
||||||
|
const new_logits, const indices = mod_bf16.call(.{ logits_buff_2, try DynamicSamplingStrategy.makeBuffers(platform, .bf16, .{ .top_k = 4, .top_p = 0.9, .min_p = 0.1 }) });
|
||||||
|
try std.testing.expectEqual([_]i32{ 0, 1, 2, 3 }, try indices.getValue([4]i32));
|
||||||
|
try zml.testing.expectEqual([_]bf16{ boost, nerf, nerf, nerf }, try new_logits.getValue([4]bf16));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -656,11 +656,12 @@ pub const Tensor = struct {
|
|||||||
/// Note: we only implement the μ=0, β=1 version.
|
/// Note: we only implement the μ=0, β=1 version.
|
||||||
pub fn gumbel(self: Rng, shape_: Shape) struct { Rng, Tensor } {
|
pub fn gumbel(self: Rng, shape_: Shape) struct { Rng, Tensor } {
|
||||||
const rand, const u = self.uniform(
|
const rand, const u = self.uniform(
|
||||||
shape_,
|
// Always use .f32 to have a big enough mantissa.
|
||||||
|
shape_.withDtype(.f32),
|
||||||
// We don't want 0 to be sampled otherwise `log` will return -inf.
|
// We don't want 0 to be sampled otherwise `log` will return -inf.
|
||||||
.{ .min = std.math.floatEps(f64), .max = 1 },
|
.{ .min = std.math.floatEps(f32), .max = 1 },
|
||||||
);
|
);
|
||||||
return .{ rand, u.log().scale(-1).log().scale(-1) };
|
return .{ rand, u.log().scale(-1).log().scale(-1).convert(shape_.dtype()) };
|
||||||
}
|
}
|
||||||
|
|
||||||
test gumbel {
|
test gumbel {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user