Replace log with select for generating the attention mask to avoid NaNs on zero values.
This commit is contained in:
parent
24a7c98476
commit
639f5cd994
10
zml/nn.zig
10
zml/nn.zig
@ -699,12 +699,16 @@ pub fn causalAttnMask(
|
|||||||
mask = mask.logical(.AND, window_mask);
|
mask = mask.logical(.AND, window_mask);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
mask = mask.convert(dtype);
|
|
||||||
if (dtype.isFloat()) {
|
if (dtype.isFloat()) {
|
||||||
// use log to convert "true" (ie 1) to 0, and "false" (ie 0) to -inf
|
|
||||||
meta.guard(dtype.isFloat(), @src()); // -inf only exists for floats
|
meta.guard(dtype.isFloat(), @src()); // -inf only exists for floats
|
||||||
mask = mask.log();
|
const zeros = Tensor.constant(mask.shape(), dtype.zero());
|
||||||
|
const minus_inf = Tensor.constant(mask.shape(), dtype.minValue());
|
||||||
|
mask = Tensor.select(mask, zeros, minus_inf);
|
||||||
|
} else {
|
||||||
|
mask = mask.convert(dtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
return mask;
|
return mask;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user