From 639f5cd9944a552d9923cad05dc4ecef2acdc8c2 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Thu, 16 Feb 2023 10:36:23 +0000 Subject: [PATCH] Replace `log` with `select` for generating the attention mask to avoid NaNs on zero values. --- zml/nn.zig | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/zml/nn.zig b/zml/nn.zig index eff5803..d609f3c 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -699,12 +699,16 @@ pub fn causalAttnMask( mask = mask.logical(.AND, window_mask); } } - mask = mask.convert(dtype); + if (dtype.isFloat()) { - // use log to convert "true" (ie 1) to 0, and "false" (ie 0) to -inf meta.guard(dtype.isFloat(), @src()); // -inf only exists for floats - mask = mask.log(); + const zeros = Tensor.constant(mask.shape(), dtype.zero()); + const minus_inf = Tensor.constant(mask.shape(), dtype.minValue()); + mask = Tensor.select(mask, zeros, minus_inf); + } else { + mask = mask.convert(dtype); } + return mask; }