summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/falcon.rs
diff options
context:
space:
mode:
author	Laurent Mazare <laurent.mazare@gmail.com>	2024-03-25 15:31:04 +0100
committer	GitHub <noreply@github.com>	2024-03-25 15:31:04 +0100
commit	d3a8d291d5f2ff5addb9ff97cf881307afbd7b6a (patch)
tree	06a083abb9e35c708a4a5736740858f46bfd45ad /candle-transformers/src/models/falcon.rs
parent	cd254074f354c4066bc73e1c5cc5ecc84d25a2db (diff)
downloadcandle-d3a8d291d5f2ff5addb9ff97cf881307afbd7b6a.tar.gz
candle-d3a8d291d5f2ff5addb9ff97cf881307afbd7b6a.tar.bz2
candle-d3a8d291d5f2ff5addb9ff97cf881307afbd7b6a.zip
Avoid the attention mask where possible. (#1933)
Diffstat (limited to 'candle-transformers/src/models/falcon.rs')
-rw-r--r--candle-transformers/src/models/falcon.rs32
1 file changed, 20 insertions, 12 deletions
diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs
index 86cf8451..24fd3c46 100644
--- a/candle-transformers/src/models/falcon.rs
+++ b/candle-transformers/src/models/falcon.rs
@@ -247,7 +247,7 @@ impl FalconAttention {
}
}
- fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+ fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, past_kv_len: usize) -> Result<Tensor> {
let fused_qkv = self.query_key_value.forward(x)?;
let head_dim = self.head_dim;
let (query, key, value) = self.split_heads(&fused_qkv)?;
@@ -267,7 +267,6 @@ impl FalconAttention {
(query, key)
};
let (mut key, mut value) = (key, value);
- let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?.to_dtype(query.dtype())?;
if self.use_cache {
if let Some((cache_k, cache_v)) = &self.kv_cache {
// TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for
@@ -293,13 +292,18 @@ impl FalconAttention {
// Only handle the case where alibi is None here, and non-flash attention.
let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
- let attention_scores = candle_nn::ops::softmax(
- &attention_scores
- .broadcast_add(&mask.squeeze(1)?)?
- .to_dtype(DType::F32)?,
- D::Minus1,
- )?
- .to_dtype(x.dtype())?;
+ let attention_scores = match mask {
+ None => attention_scores,
+ Some(mask) => {
+ let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?
+ .to_dtype(query.dtype())?;
+ attention_scores.broadcast_add(&mask.squeeze(1)?)?
+ }
+ };
+
+ let attention_scores =
+ candle_nn::ops::softmax(&attention_scores.to_dtype(DType::F32)?, D::Minus1)?
+ .to_dtype(x.dtype())?;
let attn_output = attention_scores
.matmul(&value)?
.reshape((b_sz, self.num_heads, seq_len, head_dim))?
@@ -372,7 +376,7 @@ impl FalconDecoderLayer {
})
}
- fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+ fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, past_kv_len: usize) -> Result<Tensor> {
let residual = x.clone();
let ln_attn = self.inp_layernorm.forward(x)?;
let attn_output = self.self_attention.forward(&ln_attn, mask, past_kv_len)?;
@@ -457,9 +461,13 @@ impl Falcon {
Some((k, _)) => k.dim(1)?,
None => 0,
};
- let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?;
+ let causal_mask = if seq_len <= 1 {
+ None
+ } else {
+ Some(prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?)
+ };
for block in self.blocks.iter_mut() {
- hidden_state = block.forward(&hidden_state, &causal_mask, past_kv_len)?;
+ hidden_state = block.forward(&hidden_state, causal_mask.as_ref(), past_kv_len)?;
}
let hidden_state = self.ln_f.forward(&hidden_state)?;
let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?;