Remove .print() calls from globalAttnMask() and localAttnMask() in ModernBERT example to resolve compilation sharding error.
This commit is contained in:
parent
09c43b8759
commit
7324a49da3
@ -104,7 +104,7 @@ pub const ModernBertModel = struct {
|
||||
// Broadcast to the desired output shape.
|
||||
const seq_len = ids.dim(.k);
|
||||
const pad_mask_shape = zml.Shape.init(.{ .b = pad_mask.dim(.b), .q = seq_len, .k = seq_len }, dt);
|
||||
return pad_mask.broad(pad_mask_shape).print();
|
||||
return pad_mask.broad(pad_mask_shape);
|
||||
}
|
||||
|
||||
/// Restrict global attn mask to a sliding window.
|
||||
@ -121,7 +121,7 @@ pub const ModernBertModel = struct {
|
||||
// Create sliding window mask (1 for positions within window, 0 outside)
|
||||
const window_mask = distance.cmp(.LE, Tensor.scalar(@divExact(window_size, 2), .i32));
|
||||
const minus_inf = Tensor.constant(mask_shape, mask_shape.dtype().minValue());
|
||||
return window_mask.select(global_mask, minus_inf).print();
|
||||
return window_mask.select(global_mask, minus_inf);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user