summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/falcon.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-10 18:10:01 +0200
committerGitHub <noreply@github.com>2024-04-10 18:10:01 +0200
commitb81ecf712d1854598d6c9f9cfa06fbf0093f3bc9 (patch)
treec1f68535524010bd32bbd2a3f267f46aa7126c64 /candle-transformers/src/models/falcon.rs
parenta4d5a414e3ae79642ecfd6b7bb410c26a8a62a06 (diff)
downloadcandle-b81ecf712d1854598d6c9f9cfa06fbf0093f3bc9.tar.gz
candle-b81ecf712d1854598d6c9f9cfa06fbf0093f3bc9.tar.bz2
candle-b81ecf712d1854598d6c9f9cfa06fbf0093f3bc9.zip
Support alternative dtypes for mamba (#2036)
* Allow different dtypes in mamba. * Add a dtype flag.
Diffstat (limited to 'candle-transformers/src/models/falcon.rs')
-rw-r--r--candle-transformers/src/models/falcon.rs4
1 files changed, 3 insertions, 1 deletions
diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs
index 5fea27b9..e9d4af7e 100644
--- a/candle-transformers/src/models/falcon.rs
+++ b/candle-transformers/src/models/falcon.rs
@@ -179,7 +179,9 @@ impl FalconRotaryEmbedding {
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
- let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+ let on_true = Tensor::new(on_true, on_false.device())?
+ .to_dtype(on_false.dtype())?
+ .broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}