Remove .print() calls from globalAttnMask() and localAttnMask() in ModernBERT example to resolve compilation sharding error.

This commit is contained in:
Foke Singh 2025-01-15 16:59:26 +00:00
parent 09c43b8759
commit 7324a49da3

View File

@@ -104,7 +104,7 @@ pub const ModernBertModel = struct {
// Broadcast to the desired output shape. // Broadcast to the desired output shape.
const seq_len = ids.dim(.k); const seq_len = ids.dim(.k);
const pad_mask_shape = zml.Shape.init(.{ .b = pad_mask.dim(.b), .q = seq_len, .k = seq_len }, dt); const pad_mask_shape = zml.Shape.init(.{ .b = pad_mask.dim(.b), .q = seq_len, .k = seq_len }, dt);
return pad_mask.broad(pad_mask_shape).print(); return pad_mask.broad(pad_mask_shape);
} }
/// Restrict global attn mask to a sliding window. /// Restrict global attn mask to a sliding window.
@@ -121,7 +121,7 @@ pub const ModernBertModel = struct {
// Create sliding window mask (1 for positions within window, 0 outside) // Create sliding window mask (1 for positions within window, 0 outside)
const window_mask = distance.cmp(.LE, Tensor.scalar(@divExact(window_size, 2), .i32)); const window_mask = distance.cmp(.LE, Tensor.scalar(@divExact(window_size, 2), .i32));
const minus_inf = Tensor.constant(mask_shape, mask_shape.dtype().minValue()); const minus_inf = Tensor.constant(mask_shape, mask_shape.dtype().minValue());
return window_mask.select(global_mask, minus_inf).print(); return window_mask.select(global_mask, minus_inf);
} }
}; };