summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-10-30 18:08:51 +0100
committerGitHub <noreply@github.com>2024-10-30 18:08:51 +0100
commit7ac0de15a9fafe59d9f97fb6d90662790488433e (patch)
treeadd9e1c47668a6992bb3cf11e41cc8d560d67288
parentd232e132f6af552c351bb046a38df4bce009c8aa (diff)
downloadcandle-7ac0de15a9fafe59d9f97fb6d90662790488433e.tar.gz
candle-7ac0de15a9fafe59d9f97fb6d90662790488433e.tar.bz2
candle-7ac0de15a9fafe59d9f97fb6d90662790488433e.zip
Lazy upcasting for t5. (#2589)
-rw-r--r--candle-examples/examples/stable-diffusion-3/clip.rs29
-rw-r--r--candle-examples/examples/stable-diffusion-3/main.rs13
-rw-r--r--candle-transformers/src/models/t5.rs51
3 files changed, 59 insertions, 34 deletions
diff --git a/candle-examples/examples/stable-diffusion-3/clip.rs b/candle-examples/examples/stable-diffusion-3/clip.rs
index d198366a..4891a1ba 100644
--- a/candle-examples/examples/stable-diffusion-3/clip.rs
+++ b/candle-examples/examples/stable-diffusion-3/clip.rs
@@ -118,7 +118,7 @@ impl T5WithTokenizer {
.to_vec();
tokens.resize(self.max_position_embeddings, 0);
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
- let embeddings = self.t5.forward(&input_token_ids)?;
+ let embeddings = self.t5.forward_dt(&input_token_ids, Some(DType::F32))?;
Ok(embeddings)
}
}
@@ -144,7 +144,7 @@ impl StableDiffusion3TripleClipWithTokenizer {
candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_l_file], DType::F16, device)?
};
let vb_t5 = unsafe {
- candle_nn::VarBuilder::from_mmaped_safetensors(&[t5xxl_file], DType::F32, device)?
+ candle_nn::VarBuilder::from_mmaped_safetensors(&[t5xxl_file], DType::F16, device)?
};
let max_position_embeddings = 77usize;
let clip_l = ClipWithTokenizer::new(
@@ -164,11 +164,6 @@ impl StableDiffusion3TripleClipWithTokenizer {
max_position_embeddings,
)?;
- // Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5.
- // This is a temporary workaround until the T5 implementation is updated to support fp16.
- // Also see:
- // https://github.com/huggingface/candle/issues/2480
- // https://github.com/huggingface/candle/pull/2481
let t5 = T5WithTokenizer::new(vb_t5, max_position_embeddings)?;
Ok(Self {
clip_l,
@@ -178,34 +173,26 @@ impl StableDiffusion3TripleClipWithTokenizer {
})
}
- pub fn new(vb_fp16: candle_nn::VarBuilder, vb_fp32: candle_nn::VarBuilder) -> Result<Self> {
+ pub fn new(vb: candle_nn::VarBuilder) -> Result<Self> {
let max_position_embeddings = 77usize;
let clip_l = ClipWithTokenizer::new(
- vb_fp16.pp("clip_l.transformer"),
+ vb.pp("clip_l.transformer"),
stable_diffusion::clip::Config::sdxl(),
"openai/clip-vit-large-patch14",
max_position_embeddings,
)?;
let clip_g = ClipWithTokenizer::new(
- vb_fp16.pp("clip_g.transformer"),
+ vb.pp("clip_g.transformer"),
stable_diffusion::clip::Config::sdxl2(),
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
max_position_embeddings,
)?;
- let text_projection = candle_nn::linear_no_bias(
- 1280,
- 1280,
- vb_fp16.pp("clip_g.transformer.text_projection"),
- )?;
+ let text_projection =
+ candle_nn::linear_no_bias(1280, 1280, vb.pp("clip_g.transformer.text_projection"))?;
- // Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5.
- // This is a temporary workaround until the T5 implementation is updated to support fp16.
- // Also see:
- // https://github.com/huggingface/candle/issues/2480
- // https://github.com/huggingface/candle/pull/2481
- let t5 = T5WithTokenizer::new(vb_fp32.pp("t5xxl.transformer"), max_position_embeddings)?;
+ let t5 = T5WithTokenizer::new(vb.pp("t5xxl.transformer"), max_position_embeddings)?;
Ok(Self {
clip_l,
clip_g,
diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs
index 31d3fc42..9ad057e3 100644
--- a/candle-examples/examples/stable-diffusion-3/main.rs
+++ b/candle-examples/examples/stable-diffusion-3/main.rs
@@ -194,18 +194,11 @@ fn main() -> Result<()> {
api.repo(hf_hub::Repo::model(name.to_string()))
};
let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
- let vb_fp16 = unsafe {
+ let vb = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], DType::F16, &device)?
};
-
- let vb_fp32 = unsafe {
- candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)?
- };
- let triple = StableDiffusion3TripleClipWithTokenizer::new(
- vb_fp16.pp("text_encoders"),
- vb_fp32.pp("text_encoders"),
- )?;
- (MMDiTConfig::sd3_medium(), triple, vb_fp16)
+ let triple = StableDiffusion3TripleClipWithTokenizer::new(vb.pp("text_encoders"))?;
+ (MMDiTConfig::sd3_medium(), triple, vb)
};
let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
let (context_uncond, y_uncond) =
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 84e072a2..8ba0c1c1 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -1,12 +1,38 @@
// T5 Text Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
-use crate::models::with_tracing::{linear_no_bias, Embedding, Linear};
+use crate::models::with_tracing::Embedding;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use serde::Deserialize;
use std::sync::Arc;
+#[derive(Debug, Clone)]
+pub struct Linear {
+ weight: Tensor,
+ span: tracing::Span,
+}
+
+pub fn linear_no_bias(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
+ let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
+ let weight = vb.get_with_hints((d2, d1), "weight", init_ws)?;
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ Ok(Linear { weight, span })
+}
+
+impl Module for Linear {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let weight = self.weight.to_dtype(xs.dtype())?;
+ let w = match *xs.dims() {
+ [b1, b2, _, _] => weight.broadcast_left((b1, b2))?.t()?,
+ [bsize, _, _] => weight.broadcast_left(bsize)?.t()?,
+ _ => weight.t()?,
+ };
+ xs.matmul(&w)
+ }
+}
+
fn default_relative_attention_max_distance() -> usize {
128
}
@@ -185,7 +211,7 @@ impl Module for T5LayerNorm {
let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
let xs = xs_f32.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
let xs = xs.to_dtype(dtype)?;
- let xs = xs.broadcast_mul(&self.weight)?;
+ let xs = xs.broadcast_mul(&self.weight.to_dtype(dtype)?)?;
Ok(xs)
}
}
@@ -472,7 +498,8 @@ impl T5Attention {
let position_bias = relative_attention_bias
.forward(&relative_buckets)?
.permute((2, 0, 1))?
- .unsqueeze(0)?;
+ .unsqueeze(0)?
+ .to_dtype(scores.dtype())?;
(scores.broadcast_add(&position_bias)?, Some(position_bias))
// TODO: position_bias_masked?
}
@@ -679,8 +706,21 @@ impl T5Stack {
input_ids: &Tensor,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
+ self.forward_dt(input_ids, encoder_hidden_states, None)
+ }
+
+ fn forward_dt(
+ &mut self,
+ input_ids: &Tensor,
+ encoder_hidden_states: Option<&Tensor>,
+ dtype: Option<DType>,
+ ) -> Result<Tensor> {
let _enter = self.span.enter();
let input_embeds = self.shared.as_ref().forward(input_ids)?;
+ let input_embeds = match dtype {
+ None => input_embeds,
+ Some(dtype) => input_embeds.to_dtype(dtype)?,
+ };
let mut hidden_states = input_embeds;
let mut position_bias = None;
for block in self.block.iter_mut() {
@@ -729,6 +769,11 @@ impl T5EncoderModel {
self.encoder.forward(input_ids, None)
}
+ pub fn forward_dt(&mut self, input_ids: &Tensor, dtype: Option<DType>) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.encoder.forward_dt(input_ids, None, dtype)
+ }
+
pub fn device(&self) -> &Device {
&self.device
}