summary | refs | log | tree | commit | diff
path: root/candle-transformers/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src')
-rw-r--r--  candle-transformers/src/models/t5.rs  51
1 file changed, 48 insertions(+), 3 deletions(-)
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 84e072a2..8ba0c1c1 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -1,12 +1,38 @@
// T5 Text Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
-use crate::models::with_tracing::{linear_no_bias, Embedding, Linear};
+use crate::models::with_tracing::Embedding;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use serde::Deserialize;
use std::sync::Arc;
+#[derive(Debug, Clone)]
+pub struct Linear {
+ weight: Tensor,
+ span: tracing::Span,
+}
+
+pub fn linear_no_bias(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
+ let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
+ let weight = vb.get_with_hints((d2, d1), "weight", init_ws)?;
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ Ok(Linear { weight, span })
+}
+
+impl Module for Linear {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let weight = self.weight.to_dtype(xs.dtype())?;
+ let w = match *xs.dims() {
+ [b1, b2, _, _] => weight.broadcast_left((b1, b2))?.t()?,
+ [bsize, _, _] => weight.broadcast_left(bsize)?.t()?,
+ _ => weight.t()?,
+ };
+ xs.matmul(&w)
+ }
+}
+
fn default_relative_attention_max_distance() -> usize {
128
}
@@ -185,7 +211,7 @@ impl Module for T5LayerNorm {
let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
let xs = xs_f32.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
let xs = xs.to_dtype(dtype)?;
- let xs = xs.broadcast_mul(&self.weight)?;
+ let xs = xs.broadcast_mul(&self.weight.to_dtype(dtype)?)?;
Ok(xs)
}
}
@@ -472,7 +498,8 @@ impl T5Attention {
let position_bias = relative_attention_bias
.forward(&relative_buckets)?
.permute((2, 0, 1))?
- .unsqueeze(0)?;
+ .unsqueeze(0)?
+ .to_dtype(scores.dtype())?;
(scores.broadcast_add(&position_bias)?, Some(position_bias))
// TODO: position_bias_masked?
}
@@ -679,8 +706,21 @@ impl T5Stack {
input_ids: &Tensor,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
+ self.forward_dt(input_ids, encoder_hidden_states, None)
+ }
+
+ fn forward_dt(
+ &mut self,
+ input_ids: &Tensor,
+ encoder_hidden_states: Option<&Tensor>,
+ dtype: Option<DType>,
+ ) -> Result<Tensor> {
let _enter = self.span.enter();
let input_embeds = self.shared.as_ref().forward(input_ids)?;
+ let input_embeds = match dtype {
+ None => input_embeds,
+ Some(dtype) => input_embeds.to_dtype(dtype)?,
+ };
let mut hidden_states = input_embeds;
let mut position_bias = None;
for block in self.block.iter_mut() {
@@ -729,6 +769,11 @@ impl T5EncoderModel {
self.encoder.forward(input_ids, None)
}
+ pub fn forward_dt(&mut self, input_ids: &Tensor, dtype: Option<DType>) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.encoder.forward_dt(input_ids, None, dtype)
+ }
+
pub fn device(&self) -> &Device {
&self.device
}