diff options
Diffstat (limited to 'candle-transformers/src/models/t5.rs')
-rw-r--r-- | candle-transformers/src/models/t5.rs | 51 |
1 files changed, 48 insertions, 3 deletions
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index 84e072a2..8ba0c1c1 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -1,12 +1,38 @@ // T5 Text Model // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py -use crate::models::with_tracing::{linear_no_bias, Embedding, Linear}; +use crate::models::with_tracing::Embedding; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder}; use serde::Deserialize; use std::sync::Arc; +#[derive(Debug, Clone)] +pub struct Linear { + weight: Tensor, + span: tracing::Span, +} + +pub fn linear_no_bias(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> { + let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL; + let weight = vb.get_with_hints((d2, d1), "weight", init_ws)?; + let span = tracing::span!(tracing::Level::TRACE, "linear"); + Ok(Linear { weight, span }) +} + +impl Module for Linear { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let weight = self.weight.to_dtype(xs.dtype())?; + let w = match *xs.dims() { + [b1, b2, _, _] => weight.broadcast_left((b1, b2))?.t()?, + [bsize, _, _] => weight.broadcast_left(bsize)?.t()?, + _ => weight.t()?, + }; + xs.matmul(&w) + } +} + fn default_relative_attention_max_distance() -> usize { 128 } @@ -185,7 +211,7 @@ impl Module for T5LayerNorm { let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?; let xs = xs_f32.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?; let xs = xs.to_dtype(dtype)?; - let xs = xs.broadcast_mul(&self.weight)?; + let xs = xs.broadcast_mul(&self.weight.to_dtype(dtype)?)?; Ok(xs) } } @@ -472,7 +498,8 @@ impl T5Attention { let position_bias = relative_attention_bias .forward(&relative_buckets)? .permute((2, 0, 1))? - .unsqueeze(0)?; + .unsqueeze(0)? + .to_dtype(scores.dtype())?; (scores.broadcast_add(&position_bias)?, Some(position_bias)) // TODO: position_bias_masked? } @@ -679,8 +706,21 @@ impl T5Stack { input_ids: &Tensor, encoder_hidden_states: Option<&Tensor>, ) -> Result<Tensor> { + self.forward_dt(input_ids, encoder_hidden_states, None) + } + + fn forward_dt( + &mut self, + input_ids: &Tensor, + encoder_hidden_states: Option<&Tensor>, + dtype: Option<DType>, + ) -> Result<Tensor> { let _enter = self.span.enter(); let input_embeds = self.shared.as_ref().forward(input_ids)?; + let input_embeds = match dtype { + None => input_embeds, + Some(dtype) => input_embeds.to_dtype(dtype)?, + }; let mut hidden_states = input_embeds; let mut position_bias = None; for block in self.block.iter_mut() { @@ -729,6 +769,11 @@ impl T5EncoderModel { self.encoder.forward(input_ids, None) } + pub fn forward_dt(&mut self, input_ids: &Tensor, dtype: Option<DType>) -> Result<Tensor> { + let _enter = self.span.enter(); + self.encoder.forward_dt(input_ids, None, dtype) + } + pub fn device(&self) -> &Device { &self.device } |