author     Laurent Mazare <laurent.mazare@gmail.com>    2024-08-14 09:01:12 +0100
committer  GitHub <noreply@github.com>                  2024-08-14 10:01:12 +0200
commit     68aa9c73208c8847bc7623a7eea4cbcbda0b31d6 (patch)
tree       45e3b22d718a894f446d14ee4a967c7a859d5efe /candle-transformers
parent     35e5f313977b6b1006ae98ee4443e0a27d14528d (diff)
Fix the device for the bert attention mask. (#2414)
Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/bert.rs  3
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index 42486a2d..2262aa1a 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -501,5 +501,6 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
     };
     let attention_mask = attention_mask.to_dtype(dtype)?;
     // torch.finfo(dtype).min
-    (attention_mask.ones_like()? - attention_mask)?.broadcast_mul(&Tensor::try_from(f32::MIN)?)
+    (attention_mask.ones_like()? - &attention_mask)?
+        .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
 }
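
For context, here is a minimal standalone sketch (not part of the commit) of the bug this patch addresses: `Tensor::try_from(f32::MIN)` creates its scalar tensor on the CPU, so broadcasting it against an attention mask that lives on a CUDA or Metal device fails with a device-mismatch error unless the scalar is moved to the mask's device first. The `neg_inf_mask` helper and the `main` driver below are illustrative names, not from the repo.

```rust
// Sketch of the fixed logic from get_extended_attention_mask, simplified
// to a fixed F32 dtype so it compiles on its own against candle-core.
use candle_core::{DType, Device, Result, Tensor};

fn neg_inf_mask(attention_mask: &Tensor) -> Result<Tensor> {
    let attention_mask = attention_mask.to_dtype(DType::F32)?;
    // Without `.to_device(...)` this scalar would live on the CPU and the
    // broadcast below would fail whenever `attention_mask` is on a GPU.
    let min_value = Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?;
    // Invert the mask (1 -> 0, 0 -> 1) and scale to a large negative value,
    // mirroring torch.finfo(dtype).min in the upstream BERT code.
    (attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(&min_value)
}

fn main() -> Result<()> {
    // Exercises the GPU path when one is available; falls back to the CPU,
    // where the pre-fix code also happened to work.
    let device = Device::cuda_if_available(0)?;
    let mask = Tensor::new(&[1f32, 1., 0., 0.], &device)?;
    println!("{}", neg_inf_mask(&mask)?);
    Ok(())
}
```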