author     Laurent Mazare <laurent.mazare@gmail.com>   2024-03-25 15:31:04 +0100
committer  GitHub <noreply@github.com>                 2024-03-25 15:31:04 +0100
commit     d3a8d291d5f2ff5addb9ff97cf881307afbd7b6a (patch)
tree       06a083abb9e35c708a4a5736740858f46bfd45ad /candle-transformers/src/models/falcon.rs
parent     cd254074f354c4066bc73e1c5cc5ecc84d25a2db (diff)
Avoid the attention mask where possible. (#1933)
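Why skipping the mask is safe: during autoregressive decoding with the KV cache, each forward pass feeds a single new token (seq_len == 1), and the causal mask row for that lone query position is all zeros, so adding it before the softmax changes nothing. A minimal sketch of that equivalence, assuming candle-core and candle-nn as dependencies (illustrative only, not part of this commit):

    use candle_core::{DType, Device, Result, Tensor, D};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // One query position attending over 4 cached key/value positions:
        // the causal mask row for the newest position is all zeros.
        let scores = Tensor::new(&[[0.5f32, -1.0, 2.0, 0.25]], &dev)?;
        let zero_mask = Tensor::zeros((1, 4), DType::F32, &dev)?;

        let with_mask = candle_nn::ops::softmax(&scores.broadcast_add(&zero_mask)?, D::Minus1)?;
        let without_mask = candle_nn::ops::softmax(&scores, D::Minus1)?;

        // Both print the same attention weights, so the masked_fill /
        // broadcast_add / squeeze work can be dropped when seq_len == 1.
        println!("{:?}", with_mask.to_vec2::<f32>()?);
        println!("{:?}", without_mask.to_vec2::<f32>()?);
        Ok(())
    }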
Diffstat (limited to 'candle-transformers/src/models/falcon.rs')
-rw-r--r--   candle-transformers/src/models/falcon.rs   32
1 file changed, 20 insertions(+), 12 deletions(-)
diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs
index 86cf8451..24fd3c46 100644
--- a/candle-transformers/src/models/falcon.rs
+++ b/candle-transformers/src/models/falcon.rs
@@ -247,7 +247,7 @@ impl FalconAttention {
         }
     }
 
-    fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, past_kv_len: usize) -> Result<Tensor> {
         let fused_qkv = self.query_key_value.forward(x)?;
         let head_dim = self.head_dim;
         let (query, key, value) = self.split_heads(&fused_qkv)?;
@@ -267,7 +267,6 @@ impl FalconAttention {
             (query, key)
         };
         let (mut key, mut value) = (key, value);
-        let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?.to_dtype(query.dtype())?;
         if self.use_cache {
             if let Some((cache_k, cache_v)) = &self.kv_cache {
                 // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for
@@ -293,13 +292,18 @@ impl FalconAttention {
 
         // Only handle the case where alibi is None here, and non-flash attention.
         let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
-        let attention_scores = candle_nn::ops::softmax(
-            &attention_scores
-                .broadcast_add(&mask.squeeze(1)?)?
-                .to_dtype(DType::F32)?,
-            D::Minus1,
-        )?
-        .to_dtype(x.dtype())?;
+        let attention_scores = match mask {
+            None => attention_scores,
+            Some(mask) => {
+                let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?
+                    .to_dtype(query.dtype())?;
+                attention_scores.broadcast_add(&mask.squeeze(1)?)?
+            }
+        };
+
+        let attention_scores =
+            candle_nn::ops::softmax(&attention_scores.to_dtype(DType::F32)?, D::Minus1)?
+                .to_dtype(x.dtype())?;
         let attn_output = attention_scores
             .matmul(&value)?
             .reshape((b_sz, self.num_heads, seq_len, head_dim))?
@@ -372,7 +376,7 @@ impl FalconDecoderLayer {
         })
     }
 
-    fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, past_kv_len: usize) -> Result<Tensor> {
         let residual = x.clone();
         let ln_attn = self.inp_layernorm.forward(x)?;
         let attn_output = self.self_attention.forward(&ln_attn, mask, past_kv_len)?;
@@ -457,9 +461,13 @@ impl Falcon {
             Some((k, _)) => k.dim(1)?,
             None => 0,
         };
-        let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?;
+        let causal_mask = if seq_len <= 1 {
+            None
+        } else {
+            Some(prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?)
+        };
         for block in self.blocks.iter_mut() {
-            hidden_state = block.forward(&hidden_state, &causal_mask, past_kv_len)?;
+            hidden_state = block.forward(&hidden_state, causal_mask.as_ref(), past_kv_len)?;
         }
         let hidden_state = self.ln_f.forward(&hidden_state)?;
         let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?;
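A note on the Option plumbing in the last hunk: `causal_mask` is now an `Option<Tensor>` owned by `Falcon::forward`, while each block only borrows it, hence `causal_mask.as_ref()` to turn `&Option<Tensor>` into `Option<&Tensor>` on every loop iteration. A standalone sketch of the pattern (the `Mask` type here is a hypothetical stand-in, not candle code):

    // `Mask` stands in for candle's Tensor.
    struct Mask;

    fn forward_block(mask: Option<&Mask>) {
        match mask {
            Some(_) => println!("apply mask"),
            None => println!("skip mask work entirely"),
        }
    }

    fn main() {
        let seq_len = 1;
        // Build the mask once, or not at all for single-token decode steps.
        let causal_mask: Option<Mask> = if seq_len <= 1 { None } else { Some(Mask) };
        for _block in 0..3 {
            // as_ref() lets every block borrow the same Option's contents.
            forward_block(causal_mask.as_ref());
        }
    }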