path: root/candle-transformers
author    Laurent Mazare <laurent.mazare@gmail.com>    2023-10-30 18:29:36 +0100
committer GitHub <noreply@github.com>                  2023-10-30 17:29:36 +0000
commit    4c967b9184834cd1e166dfdd6d88450d16bad8f2 (patch)
tree      1b758b5839804607aab2531dd8bf86a97637552b /candle-transformers
parent    c05c0a8213eb5518901b2aa87503e8c6b65b9d0f (diff)
Use the hub files for the marian example. (#1220)
* Use the hub files for the marian example.
* Use the secondary decoder.
* Add a readme.
* More readme.
Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/marian.rs  39
1 file changed, 29 insertions(+), 10 deletions(-)
diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs
index 71f17720..2bcfd2f7 100644
--- a/candle-transformers/src/models/marian.rs
+++ b/candle-transformers/src/models/marian.rs
@@ -135,7 +135,12 @@ impl Attention {
.contiguous()
}
- fn forward(&self, xs: &Tensor, kv_states: Option<&Tensor>) -> Result<Tensor> {
+ fn forward(
+ &self,
+ xs: &Tensor,
+ kv_states: Option<&Tensor>,
+ attn_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
let is_cross_attn = kv_states.is_some();
let (b_sz, tgt_len, _) = xs.dims3()?;
let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
@@ -156,7 +161,10 @@ impl Attention {
let key_states = key_states.reshape(proj_shape)?;
let value_states = value_states.reshape(proj_shape)?;
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
- // todo: attn_mask
+ let attn_weights = match attn_mask {
+ None => attn_weights,
+ Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?,
+ };
let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_output = attn_probs.matmul(&value_states)?;
attn_output
@@ -196,8 +204,8 @@ impl EncoderLayer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
- let xs =
- (self.self_attn.forward(xs, None)? + residual)?.apply(&self.self_attn_layer_norm)?;
+ let xs = (self.self_attn.forward(xs, None, None)? + residual)?
+ .apply(&self.self_attn_layer_norm)?;
let residual = &xs;
let xs = xs
.apply(&self.fc1)?
@@ -241,15 +249,20 @@ impl DecoderLayer {
})
}
- fn forward(&self, xs: &Tensor, encoder_xs: Option<&Tensor>) -> Result<Tensor> {
+ fn forward(
+ &self,
+ xs: &Tensor,
+ encoder_xs: Option<&Tensor>,
+ attn_mask: &Tensor,
+ ) -> Result<Tensor> {
let residual = xs;
- let xs =
- (self.self_attn.forward(xs, None)? + residual)?.apply(&self.self_attn_layer_norm)?;
+ let xs = (self.self_attn.forward(xs, None, Some(attn_mask))? + residual)?
+ .apply(&self.self_attn_layer_norm)?;
let xs = match encoder_xs {
None => xs,
Some(encoder_xs) => {
let residual = &xs;
- let xs = self.encoder_attn.forward(&xs, Some(encoder_xs))?;
+ let xs = self.encoder_attn.forward(&xs, Some(encoder_xs), None)?;
(residual + xs)?.apply(&self.encoder_attn_layer_norm)?
}
};
@@ -346,6 +359,7 @@ impl Decoder {
xs: &Tensor,
encoder_xs: Option<&Tensor>,
past_kv_len: usize,
+ attn_mask: &Tensor,
) -> Result<Tensor> {
let xs = xs.apply(&self.embed_tokens)?;
let xs = match self.embed_scale {
@@ -358,7 +372,7 @@ impl Decoder {
.unsqueeze(0)?;
let mut xs = xs.broadcast_add(&embed_pos)?;
for layer in self.layers.iter() {
- xs = layer.forward(&xs, encoder_xs)?;
+ xs = layer.forward(&xs, encoder_xs, attn_mask)?;
}
Ok(xs)
}
@@ -413,9 +427,14 @@ impl MTModel {
}
pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result<Tensor> {
+ let seq_len = xs.dim(1)?;
+ let mask: Vec<_> = (0..seq_len)
+ .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
+ .collect();
+ let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
self.model
.decoder
- .forward(xs, Some(encoder_xs), 0)?
+ .forward(xs, Some(encoder_xs), 0, &mask)?
.apply(&self.lm_head)?
.broadcast_add(&self.final_logits_bias)
}
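
For reference, a minimal standalone sketch of the causal masking this change introduces: the decoder builds a (seq_len, seq_len) matrix that is 0 on and below the diagonal and f32::NEG_INFINITY above it, and Attention::forward adds it to the raw attention scores before the softmax so each position cannot attend to later positions. The snippet below is not part of the commit; it only illustrates the technique, and the candle_core/candle_nn crate paths, the causal_mask helper name, and the toy 3-token scores are assumptions.

use candle_core::{DType, Device, Result, Tensor};

// Hypothetical helper mirroring the mask built in MTModel::decode:
// 0.0 for positions a token may attend to, -inf for future positions.
fn causal_mask(seq_len: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<f32> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
        .collect();
    Tensor::from_vec(mask, (seq_len, seq_len), device)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Toy attention scores for a single head over a 3-token sequence.
    let attn_weights = Tensor::ones((3, 3), DType::F32, &device)?;
    // Same combination as in Attention::forward: add the mask, then softmax.
    let attn_weights = attn_weights.broadcast_add(&causal_mask(3, &device)?)?;
    let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
    // Row i now assigns zero probability to every column j > i.
    println!("{attn_probs}");
    Ok(())
}

Because the mask is added before softmax_last_dim, the masked entries become exp(-inf) = 0 in the normalized weights, which is why the encoder path (and cross-attention) can simply pass None and skip the mask entirely.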