diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-10 18:10:01 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-10 18:10:01 +0200 |
commit | b81ecf712d1854598d6c9f9cfa06fbf0093f3bc9 (patch) | |
tree | c1f68535524010bd32bbd2a3f267f46aa7126c64 /candle-transformers/src/models/falcon.rs | |
parent | a4d5a414e3ae79642ecfd6b7bb410c26a8a62a06 (diff) | |
download | candle-b81ecf712d1854598d6c9f9cfa06fbf0093f3bc9.tar.gz candle-b81ecf712d1854598d6c9f9cfa06fbf0093f3bc9.tar.bz2 candle-b81ecf712d1854598d6c9f9cfa06fbf0093f3bc9.zip |
Support alternative dtypes for mamba (#2036)
* Allow different dtypes in mamba.
* Add a dtype flag.
Diffstat (limited to 'candle-transformers/src/models/falcon.rs')
-rw-r--r-- | candle-transformers/src/models/falcon.rs | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs index 5fea27b9..e9d4af7e 100644 --- a/candle-transformers/src/models/falcon.rs +++ b/candle-transformers/src/models/falcon.rs @@ -179,7 +179,9 @@ impl FalconRotaryEmbedding { fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { let shape = mask.shape(); - let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let on_true = Tensor::new(on_true, on_false.device())? + .to_dtype(on_false.dtype())? + .broadcast_as(shape.dims())?; let m = mask.where_cond(&on_true, on_false)?; Ok(m) } |