-rw-r--r--  candle-transformers/src/models/mixformer.rs  38
1 file changed, 33 insertions(+), 5 deletions(-)
diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index 61eaea54..6a3b5515 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -75,6 +75,20 @@ impl Module for Embedding {
     }
 }
 
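+// Build the (size, size) causal mask as u8: entry (i, j) is 1 when key
+// position j lies in the future of query position i, and 0 otherwise.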
+fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+    let mask: Vec<_> = (0..size)
+        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .collect();
+    Tensor::from_slice(&mask, (size, size), device)
+}
+
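+// Where `mask` is non-zero, replace the entries of `on_false` with the scalar
+// `on_true` (broadcast to the mask's shape); elsewhere keep `on_false`.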
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+    let shape = mask.shape();
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
+    Ok(m)
+}
+
 #[derive(Debug)]
 struct RotaryEmbedding {
     sin: Tensor,
@@ -198,6 +212,7 @@ struct MHA {
     rotary_emb: RotaryEmbedding,
     kv_cache: Option<(Tensor, Tensor)>,
     head_dim: usize,
+    n_head: usize,
     softmax_scale: f64,
     span: tracing::Span,
 }
@@ -214,6 +229,7 @@ impl MHA {
             wqkv,
             out_proj,
             head_dim,
+            n_head: cfg.n_head,
             kv_cache: None,
             rotary_emb,
             softmax_scale,
@@ -221,7 +237,7 @@ impl MHA {
         })
     }
 
-    fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
         let _enter = self.span.enter();
         let (b_size, seq_len, _n_embd) = xs.dims3()?;
         let qkv = self
@@ -249,9 +265,16 @@ impl MHA {
         let v = v.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
 
         let attn_weights = (q.matmul(&k.t()?)? * self.softmax_scale)?; // b*h, t, s
-        // TODO: Add the causal mask.
         // causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1)
         // scores = scores + causal_mask.to(dtype=scores.dtype)
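+        // Rust equivalent of the PyTorch reference above: set future positions
+        // to -inf so the softmax below assigns them zero attention weight.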
+        let attn_weights = match mask {
+            None => attn_weights,
+            Some(mask) => masked_fill(
+                &attn_weights,
+                &mask.broadcast_left(b_size * self.n_head)?,
+                f32::NEG_INFINITY,
+            )?,
+        };
         let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
 
         // output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
@@ -287,11 +310,11 @@ impl ParallelBlock {
         })
     }
 
-    fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
         let _enter = self.span.enter();
         let residual = xs;
         let xs = xs.apply(&self.ln)?;
-        let attn_outputs = self.mixer.forward(&xs)?;
+        let attn_outputs = self.mixer.forward(&xs, mask)?;
         let feed_forward_hidden_states = self.mlp.forward(&xs)?;
         attn_outputs + feed_forward_hidden_states + residual
     }
@@ -327,8 +350,13 @@ impl MixFormerSequentialForCausalLM {
         let _enter = self.span.enter();
         let (_b_size, seq_len) = xs.dims2()?;
         let mut xs = xs.apply(&self.embedding)?;
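+        // The mask is only needed when attending over more than one query
+        // position; single-token decode steps (running against the kv cache)
+        // may see every cached position.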
+        let mask = if seq_len <= 1 {
+            None
+        } else {
+            Some(get_mask(seq_len, xs.device())?)
+        };
         for block in self.blocks.iter_mut() {
-            xs = block.forward(&xs)?
+            xs = block.forward(&xs, mask.as_ref())?
         }
         xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
     }
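
For illustration only (not part of the commit): a minimal sketch of the masking
scheme above, assuming the candle-core and candle-nn crates as dependencies
(in-tree, candle-core is imported as `candle`). For a 3-token prompt, the
get_mask construction yields a strictly upper-triangular u8 mask, and the
masked_fill step pushes the masked scores to -inf so softmax gives them zero
weight:

    use candle_core::{DType, Device, Result, Tensor, D};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        let size = 3;
        // Same construction as get_mask: 1 where key j is ahead of query i.
        // [[0, 1, 1],
        //  [0, 0, 1],
        //  [0, 0, 0]]
        let mask: Vec<_> = (0..size)
            .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
            .collect();
        let mask = Tensor::from_slice(&mask, (size, size), &device)?;
        // Stand-in for the attention scores (uniform here, for clarity).
        let scores = Tensor::zeros((size, size), DType::F32, &device)?;
        // Same as masked_fill: broadcast -inf over the mask's shape and select.
        let neg_inf = Tensor::new(f32::NEG_INFINITY, &device)?
            .broadcast_as(mask.shape().dims())?;
        let masked = mask.where_cond(&neg_inf, &scores)?;
        let probs = candle_nn::ops::softmax(&masked, D::Minus1)?;
        // Row i attends uniformly over positions 0..=i:
        // [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.333, 0.333, 0.333]]
        println!("{probs}");
        Ok(())
    }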