Replace log with select for generating the attention mask to avoid NaNs on zero values.

This commit is contained in:
Tarry Singh 2023-02-16 10:36:23 +00:00
parent 24a7c98476
commit 639f5cd994

View File

@@ -699,12 +699,16 @@ pub fn causalAttnMask(
mask = mask.logical(.AND, window_mask); mask = mask.logical(.AND, window_mask);
} }
} }
mask = mask.convert(dtype);
if (dtype.isFloat()) { if (dtype.isFloat()) {
// use log to convert "true" (ie 1) to 0, and "false" (ie 0) to -inf
meta.guard(dtype.isFloat(), @src()); // -inf only exists for floats meta.guard(dtype.isFloat(), @src()); // -inf only exists for floats
mask = mask.log(); const zeros = Tensor.constant(mask.shape(), dtype.zero());
const minus_inf = Tensor.constant(mask.shape(), dtype.minValue());
mask = Tensor.select(mask, zeros, minus_inf);
} else {
mask = mask.convert(dtype);
} }
return mask; return mask;
} }