author     Laurent Mazare <laurent.mazare@gmail.com>    2024-08-14 09:01:12 +0100
committer  GitHub <noreply@github.com>                   2024-08-14 10:01:12 +0200
commit     68aa9c73208c8847bc7623a7eea4cbcbda0b31d6
tree       45e3b22d718a894f446d14ee4a967c7a859d5efe
parent     35e5f313977b6b1006ae98ee4443e0a27d14528d
Fix the device for the bert attention mask. (#2414)
Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/bert.rs | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index 42486a2d..2262aa1a 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -501,5 +501,6 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
};
let attention_mask = attention_mask.to_dtype(dtype)?;
// torch.finfo(dtype).min
- (attention_mask.ones_like()? - attention_mask)?.broadcast_mul(&Tensor::try_from(f32::MIN)?)
+ (attention_mask.ones_like()? - &attention_mask)?
+ .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
}
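
Note: the bug fixed here is a device mismatch. Tensor::try_from(f32::MIN) allocates the scalar tensor on the CPU, so broadcast-multiplying it against an attention mask that lives on a GPU fails at runtime; the fix moves the scalar onto the mask's device first. A minimal sketch of the pattern, assuming candle-core (the device choice and tensor shape are illustrative, not from the patch):

// Sketch of the device-mismatch fix, using candle-core.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // Use a GPU when the build has one; on a CPU-only build both tensors
    // share the same device, so the mismatch never surfaces.
    let device = Device::cuda_if_available(0)?;
    let attention_mask = Tensor::ones((1, 4), DType::F32, &device)?;

    // Tensor::try_from builds the scalar on the CPU. Before this patch,
    // broadcast_mul against a GPU-resident mask returned a device
    // mismatch error.
    let min = Tensor::try_from(f32::MIN)?;

    // The fix: move the scalar onto the mask's device before the op.
    let min = min.to_device(attention_mask.device())?;
    let inverted = (attention_mask.ones_like()? - &attention_mask)?;
    let extended = inverted.broadcast_mul(&min)?;
    println!("{extended}");
    Ok(())
}

The same pattern applies anywhere a freshly created scalar tensor meets a tensor that may live off-CPU: build the scalar, then call to_device with the other tensor's device before any broadcast operation.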