diff options
Diffstat (limited to 'candle-examples/examples/musicgen/t5_model.rs')
-rw-r--r-- | candle-examples/examples/musicgen/t5_model.rs | 59 |
1 file changed, 50 insertions, 9 deletions
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 33b11b95..607b5c93 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -96,10 +96,9 @@ impl T5LayerNorm {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let dtype = xs.dtype();
         let xs_f32 = xs.to_dtype(DType::F32)?;
-        let xs2_f32 = (&xs_f32 * &xs_f32)?;
-        let sum_xs2_f32 = xs2_f32.sum_keepdim(D::Minus1)?;
-        let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
-        let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
+        // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
+        let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
         let xs = xs.to_dtype(dtype)?;
         let xs = xs.broadcast_mul(&self.weight)?;
         Ok(xs)
@@ -167,6 +166,9 @@ struct T5Attention {
     n_heads: usize,
     d_kv: usize,
     relative_attention_bias: Option<Embedding>,
+    relative_attention_num_buckets: usize,
+    relative_attention_max_distance: usize,
+    inner_dim: usize,
 }
 
 impl T5Attention {
@@ -194,6 +196,9 @@ impl T5Attention {
             n_heads: cfg.num_heads,
            d_kv: cfg.d_kv,
             relative_attention_bias,
+            relative_attention_num_buckets: cfg.relative_attention_num_buckets,
+            relative_attention_max_distance: cfg.relative_attention_max_distance,
+            inner_dim,
         })
     }
 
@@ -206,17 +211,53 @@ impl T5Attention {
         let v = self.v.forward(xs)?;
         let q = q
             .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let k = k
             .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let v = v
             .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let scores = q.matmul(&k.t()?)?;
-        // TODO: position_bias_masked
+
+        let scores = match &self.relative_attention_bias {
+            None => scores,
+            Some(relative_attention_bias) => {
+                let query_length = seq_len;
+                let key_length = seq_len;
+                // This only handles the bidirectional case.
+                let num_buckets = self.relative_attention_num_buckets / 2;
+                let relative_position = (0..query_length as u32)
+                    .map(|i| {
+                        (0..key_length as u32)
+                            .map(|j| {
+                                if i < j {
+                                    j - i + num_buckets as u32
+                                } else {
+                                    i - j
+                                }
+                            })
+                            .collect::<Vec<u32>>()
+                    })
+                    .collect::<Vec<Vec<_>>>();
+                let relative_buckets = Tensor::new(relative_position, q.device())?;
+                let position_bias = relative_attention_bias
+                    .forward(&relative_buckets)?
+                    .permute((2, 0, 1))?
+                    .unsqueeze(0)?;
+                (scores + position_bias)?
+                // TODO: position_bias_masked?
+            }
+        };
+
         let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
         let attn_output = attn_weights.matmul(&v)?;
+        let attn_output = attn_output
+            .transpose(1, 2)?
+            .reshape((b_sz, seq_len, self.inner_dim))?;
         let attn_output = self.o.forward(&attn_output)?;
         Ok(attn_output)
     }
@@ -324,7 +365,7 @@ impl T5Stack {
 
     fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
-        let (_b_sz, _seq_len) = input_embeds.dims2()?;
+        let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
 
         let mut hidden_states = input_embeds;
         for block in self.block.iter() {
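
For reference, the added match arm in the largest hunk computes a simplified, bidirectional form of T5's relative-position bucketing: a key to the right of the query is shifted into the upper half of the bucket range, while a key at or to the left keeps its raw distance. The logarithmic bucketing of large distances from the reference implementation (driven by relative_attention_max_distance, which this change stores but does not yet use) is not applied. The sketch below is illustrative only and not part of the diff; the relative_position_buckets helper is hypothetical and simply mirrors the nested map/collect above in plain Rust.

// Hypothetical sketch of the bucketing logic added in this diff (bidirectional case only).
fn relative_position_buckets(query_length: u32, key_length: u32, num_buckets: u32) -> Vec<Vec<u32>> {
    (0..query_length)
        .map(|i| {
            (0..key_length)
                .map(|j| {
                    if i < j {
                        // Key lies to the right of the query: shift into the upper half of the buckets.
                        j - i + num_buckets
                    } else {
                        // Key lies at or to the left of the query: use the raw distance.
                        i - j
                    }
                })
                .collect()
        })
        .collect()
}

fn main() {
    // Mirrors `self.relative_attention_num_buckets / 2` with T5's default of 32 buckets.
    let buckets = relative_position_buckets(4, 4, 32 / 2);
    for row in &buckets {
        println!("{row:?}");
    }
}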