| field     | value                                                          | date                      |
|-----------|----------------------------------------------------------------|---------------------------|
| author    | Laurent Mazare <laurent.mazare@gmail.com>                      | 2023-10-30 18:29:36 +0100 |
| committer | GitHub <noreply@github.com>                                    | 2023-10-30 17:29:36 +0000 |
| commit    | 4c967b9184834cd1e166dfdd6d88450d16bad8f2 (patch)               |                           |
| tree      | 1b758b5839804607aab2531dd8bf86a97637552b /candle-transformers  |                           |
| parent    | c05c0a8213eb5518901b2aa87503e8c6b65b9d0f (diff)                |                           |
Use the hub files for the marian example. (#1220)
* Use the hub files for the marian example.
* Use the secondary decoder.
* Add a readme.
* More readme.
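The commit message's first item refers to fetching the model files from the Hugging Face Hub instead of relying on local paths. A rough sketch of what that can look like with the `hf-hub` crate is below; the repository id, the `model.safetensors` file name, and the `Config::opus_mt_fr_en` constructor are illustrative assumptions, not necessarily the exact code of the candle marian example.

```rust
// Sketch: loading marian weights from the Hugging Face Hub for use with
// candle-transformers. Repo id, file name, and config constructor are assumed.
use candle_core::{DType, Device};
use candle_nn::VarBuilder;
use candle_transformers::models::marian;

fn load_model() -> anyhow::Result<marian::MTModel> {
    // Resolve (and cache) the weights file from the hub.
    let api = hf_hub::api::sync::Api::new()?;
    let repo = api.model("Helsinki-NLP/opus-mt-fr-en".to_string());
    let model_file = repo.get("model.safetensors")?;

    // Memory-map the safetensors and build the model on CPU.
    let device = Device::Cpu;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let config = marian::Config::opus_mt_fr_en();
    Ok(marian::MTModel::new(&config, vb)?)
}
```

The same `ApiRepo::get` call can fetch the tokenizer files the example needs; everything is pulled into the local hub cache on first use.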
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/marian.rs | 39 |
1 file changed, 29 insertions, 10 deletions
diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs
index 71f17720..2bcfd2f7 100644
--- a/candle-transformers/src/models/marian.rs
+++ b/candle-transformers/src/models/marian.rs
@@ -135,7 +135,12 @@ impl Attention {
             .contiguous()
     }
 
-    fn forward(&self, xs: &Tensor, kv_states: Option<&Tensor>) -> Result<Tensor> {
+    fn forward(
+        &self,
+        xs: &Tensor,
+        kv_states: Option<&Tensor>,
+        attn_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
         let is_cross_attn = kv_states.is_some();
         let (b_sz, tgt_len, _) = xs.dims3()?;
         let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
@@ -156,7 +161,10 @@ impl Attention {
         let key_states = key_states.reshape(proj_shape)?;
         let value_states = value_states.reshape(proj_shape)?;
         let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
-        // todo: attn_mask
+        let attn_weights = match attn_mask {
+            None => attn_weights,
+            Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?,
+        };
         let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
         let attn_output = attn_probs.matmul(&value_states)?;
         attn_output
@@ -196,8 +204,8 @@ impl EncoderLayer {
 
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let residual = xs;
-        let xs =
-            (self.self_attn.forward(xs, None)? + residual)?.apply(&self.self_attn_layer_norm)?;
+        let xs = (self.self_attn.forward(xs, None, None)? + residual)?
+            .apply(&self.self_attn_layer_norm)?;
         let residual = &xs;
         let xs = xs
             .apply(&self.fc1)?
@@ -241,15 +249,20 @@ impl DecoderLayer {
         })
     }
 
-    fn forward(&self, xs: &Tensor, encoder_xs: Option<&Tensor>) -> Result<Tensor> {
+    fn forward(
+        &self,
+        xs: &Tensor,
+        encoder_xs: Option<&Tensor>,
+        attn_mask: &Tensor,
+    ) -> Result<Tensor> {
         let residual = xs;
-        let xs =
-            (self.self_attn.forward(xs, None)? + residual)?.apply(&self.self_attn_layer_norm)?;
+        let xs = (self.self_attn.forward(xs, None, Some(attn_mask))? + residual)?
+            .apply(&self.self_attn_layer_norm)?;
         let xs = match encoder_xs {
             None => xs,
             Some(encoder_xs) => {
                 let residual = &xs;
-                let xs = self.encoder_attn.forward(&xs, Some(encoder_xs))?;
+                let xs = self.encoder_attn.forward(&xs, Some(encoder_xs), None)?;
                 (residual + xs)?.apply(&self.encoder_attn_layer_norm)?
             }
         };
@@ -346,6 +359,7 @@ impl Decoder {
         xs: &Tensor,
         encoder_xs: Option<&Tensor>,
         past_kv_len: usize,
+        attn_mask: &Tensor,
     ) -> Result<Tensor> {
         let xs = xs.apply(&self.embed_tokens)?;
         let xs = match self.embed_scale {
@@ -358,7 +372,7 @@ impl Decoder {
             .unsqueeze(0)?;
         let mut xs = xs.broadcast_add(&embed_pos)?;
         for layer in self.layers.iter() {
-            xs = layer.forward(&xs, encoder_xs)?;
+            xs = layer.forward(&xs, encoder_xs, attn_mask)?;
         }
         Ok(xs)
     }
@@ -413,9 +427,14 @@ impl MTModel {
     }
 
     pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result<Tensor> {
+        let seq_len = xs.dim(1)?;
+        let mask: Vec<_> = (0..seq_len)
+            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
+            .collect();
+        let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
         self.model
             .decoder
-            .forward(xs, Some(encoder_xs), 0)?
+            .forward(xs, Some(encoder_xs), 0, &mask)?
             .apply(&self.lm_head)?
             .broadcast_add(&self.final_logits_bias)
     }
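For context on the new `attn_mask` plumbing: the mask built in `MTModel::decode` places `f32::NEG_INFINITY` at every position above the diagonal, so after the softmax in `Attention::forward` each decoder token can only attend to itself and earlier tokens. Below is a minimal standalone sketch of the same pattern; the helper name, the toy scores tensor, and the `main` wrapper are made up for illustration and are not part of this commit.

```rust
// Sketch of the causal mask used in MTModel::decode, applied to fake scores.
use candle_core::{DType, Device, Result, Tensor};

// Build a (seq_len, seq_len) mask with -inf above the diagonal, 0 elsewhere.
fn causal_mask(seq_len: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<f32> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
        .collect();
    Tensor::from_vec(mask, (seq_len, seq_len), device)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let mask = causal_mask(4, &device)?;
    // Uniform fake attention scores for a 4-token sequence; adding the mask
    // before the softmax gives zero probability to all future positions.
    let scores = Tensor::zeros((1, 4, 4), DType::F32, &device)?;
    let masked = scores.broadcast_add(&mask)?;
    let probs = candle_nn::ops::softmax_last_dim(&masked)?;
    println!("{probs}");
    Ok(())
}
```

Adding the mask to the logits before the softmax, rather than zeroing probabilities afterwards, keeps each row normalized without an extra renormalization step, which is why the diff uses `broadcast_add` on `attn_weights` instead of touching `attn_probs`.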