Replace log with select for generating the attention mask to avoid NaNs on zero values.
This commit is contained in:
parent
24a7c98476
commit
639f5cd994
10
zml/nn.zig
10
zml/nn.zig
@ -699,12 +699,16 @@ pub fn causalAttnMask(
|
||||
mask = mask.logical(.AND, window_mask);
|
||||
}
|
||||
}
|
||||
mask = mask.convert(dtype);
|
||||
|
||||
if (dtype.isFloat()) {
|
||||
// use log to convert "true" (ie 1) to 0, and "false" (ie 0) to -inf
|
||||
meta.guard(dtype.isFloat(), @src()); // -inf only exists for floats
|
||||
mask = mask.log();
|
||||
const zeros = Tensor.constant(mask.shape(), dtype.zero());
|
||||
const minus_inf = Tensor.constant(mask.shape(), dtype.minValue());
|
||||
mask = Tensor.select(mask, zeros, minus_inf);
|
||||
} else {
|
||||
mask = mask.convert(dtype);
|
||||
}
|
||||
|
||||
return mask;
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user