author     Laurent Mazare <laurent.mazare@gmail.com>  2023-08-11 19:57:06 +0200
committer  GitHub <noreply@github.com>                2023-08-11 18:57:06 +0100
commit     1d0157bbc4807f993cecc0de7dbbe0f305a68cd4 (patch)
tree       492d559f4b09b5332127a1af27be49be62ab5d93 /candle-examples/examples/stable-diffusion
parent     91dbf907d3ee45dd4777efa82c1f431907ce8125 (diff)
Stable diffusion: retrieve the model files from the HF hub. (#414)
* Retrieve the model files from the HF hub in the stable diffusion example.
* Add to the readme.
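For context, here is a minimal standalone sketch of the hub-download pattern this commit introduces. The repo and file names mirror the ones in the diff below, and anyhow is assumed for error handling, as in the example crate; this is an illustration, not part of the committed code.

    // Sketch of the hf-hub sync API used by the new ModelFile::get helper:
    // fetch a file from a Hub repo and get back its local cache path.
    use hf_hub::api::sync::Api;
    use std::path::PathBuf;

    fn fetch_v1_5_unet() -> anyhow::Result<PathBuf> {
        let api = Api::new()?;
        let repo = api.model("runwayml/stable-diffusion-v1-5".to_string());
        // Downloads on first use, then resolves to the local cache on later runs.
        let path = repo.get("unet/diffusion_pytorch_model.safetensors")?;
        Ok(path)
    }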
Diffstat (limited to 'candle-examples/examples/stable-diffusion')
-rw-r--r--  candle-examples/examples/stable-diffusion/main.rs              | 91
-rw-r--r--  candle-examples/examples/stable-diffusion/stable_diffusion.rs  | 14
2 files changed, 71 insertions(+), 34 deletions(-)
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index ac31e855..5ec40f7d 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -45,21 +45,21 @@ struct Args {
#[arg(long)]
width: Option<usize>,
- /// The UNet weight file, in .ot or .safetensors format.
+ /// The UNet weight file, in .safetensors format.
#[arg(long, value_name = "FILE")]
unet_weights: Option<String>,
- /// The CLIP weight file, in .ot or .safetensors format.
+ /// The CLIP weight file, in .safetensors format.
#[arg(long, value_name = "FILE")]
clip_weights: Option<String>,
- /// The VAE weight file, in .ot or .safetensors format.
+ /// The VAE weight file, in .safetensors format.
#[arg(long, value_name = "FILE")]
vae_weights: Option<String>,
#[arg(long, value_name = "FILE")]
/// The file specifying the tokenizer to be used for tokenization.
- tokenizer: String,
+ tokenizer: Option<String>,
/// The size of the sliced attention or 0 for automatic slicing (disabled by default)
#[arg(long)]
@@ -91,34 +91,63 @@ enum StableDiffusionVersion {
V2_1,
}
-impl Args {
- fn clip_weights(&self) -> String {
- match &self.clip_weights {
- Some(w) => w.clone(),
- None => match self.sd_version {
- StableDiffusionVersion::V1_5 => "data/pytorch_model.safetensors".to_string(),
- StableDiffusionVersion::V2_1 => "data/clip_v2.1.safetensors".to_string(),
- },
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum ModelFile {
+ Tokenizer,
+ Clip,
+ Unet,
+ Vae,
+}
+
+impl StableDiffusionVersion {
+ fn repo(&self) -> &'static str {
+ match self {
+ Self::V2_1 => "stabilityai/stable-diffusion-2-1",
+ Self::V1_5 => "runwayml/stable-diffusion-v1-5",
+ }
+ }
+
+ fn unet_file(&self) -> &'static str {
+ match self {
+ Self::V1_5 | Self::V2_1 => "unet/diffusion_pytorch_model.safetensors",
}
}
- fn vae_weights(&self) -> String {
- match &self.vae_weights {
- Some(w) => w.clone(),
- None => match self.sd_version {
- StableDiffusionVersion::V1_5 => "data/vae.safetensors".to_string(),
- StableDiffusionVersion::V2_1 => "data/vae_v2.1.safetensors".to_string(),
- },
+ fn vae_file(&self) -> &'static str {
+ match self {
+ Self::V1_5 | Self::V2_1 => "vae/diffusion_pytorch_model.safetensors",
}
}
- fn unet_weights(&self) -> String {
- match &self.unet_weights {
- Some(w) => w.clone(),
- None => match self.sd_version {
- StableDiffusionVersion::V1_5 => "data/unet.safetensors".to_string(),
- StableDiffusionVersion::V2_1 => "data/unet_v2.1.safetensors".to_string(),
- },
+ fn clip_file(&self) -> &'static str {
+ match self {
+ Self::V1_5 | Self::V2_1 => "text_encoder/model.safetensors",
+ }
+ }
+}
+
+impl ModelFile {
+ const TOKENIZER_REPO: &str = "openai/clip-vit-base-patch32";
+ const TOKENIZER_PATH: &str = "tokenizer.json";
+
+ fn get(
+ &self,
+ filename: Option<String>,
+ version: StableDiffusionVersion,
+ ) -> Result<std::path::PathBuf> {
+ use hf_hub::api::sync::Api;
+ match filename {
+ Some(filename) => Ok(std::path::PathBuf::from(filename)),
+ None => {
+ let (repo, path) = match self {
+ Self::Tokenizer => (Self::TOKENIZER_REPO, Self::TOKENIZER_PATH),
+ Self::Clip => (version.repo(), version.clip_file()),
+ Self::Unet => (version.repo(), version.unet_file()),
+ Self::Vae => (version.repo(), version.vae_file()),
+ };
+ let filename = Api::new()?.model(repo.to_string()).get(path)?;
+ Ok(filename)
+ }
}
}
}
@@ -151,9 +180,6 @@ fn output_filename(
}
fn run(args: Args) -> Result<()> {
- let clip_weights = args.clip_weights();
- let vae_weights = args.vae_weights();
- let unet_weights = args.unet_weights();
let Args {
prompt,
uncond_prompt,
@@ -166,6 +192,9 @@ fn run(args: Args) -> Result<()> {
sliced_attention_size,
num_samples,
sd_version,
+ clip_weights,
+ vae_weights,
+ unet_weights,
..
} = args;
let sd_config = match sd_version {
@@ -180,6 +209,7 @@ fn run(args: Args) -> Result<()> {
let scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
+ let tokenizer = ModelFile::Tokenizer.get(tokenizer, sd_version)?;
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let pad_id = match &sd_config.clip.pad_with {
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
@@ -207,14 +237,17 @@ fn run(args: Args) -> Result<()> {
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
println!("Building the Clip transformer.");
+ let clip_weights = ModelFile::Clip.get(clip_weights, sd_version)?;
let text_model = sd_config.build_clip_transformer(&clip_weights, &device)?;
let text_embeddings = text_model.forward(&tokens)?;
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
println!("Building the autoencoder.");
+ let vae_weights = ModelFile::Vae.get(vae_weights, sd_version)?;
let vae = sd_config.build_vae(&vae_weights, &device)?;
println!("Building the unet.");
+ let unet_weights = ModelFile::Unet.get(unet_weights, sd_version)?;
let unet = sd_config.build_unet(&unet_weights, &device, 4)?;
let bsize = 1;
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
index c250ed56..023d8630 100644
--- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs
+++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
@@ -172,7 +172,11 @@ impl StableDiffusionConfig {
)
}
- pub fn build_vae(&self, vae_weights: &str, device: &Device) -> Result<vae::AutoEncoderKL> {
+ pub fn build_vae<P: AsRef<std::path::Path>>(
+ &self,
+ vae_weights: P,
+ device: &Device,
+ ) -> Result<vae::AutoEncoderKL> {
let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
let weights = weights.deserialize()?;
let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
@@ -181,9 +185,9 @@ impl StableDiffusionConfig {
Ok(autoencoder)
}
- pub fn build_unet(
+ pub fn build_unet<P: AsRef<std::path::Path>>(
&self,
- unet_weights: &str,
+ unet_weights: P,
device: &Device,
in_channels: usize,
) -> Result<unet_2d::UNet2DConditionModel> {
@@ -198,9 +202,9 @@ impl StableDiffusionConfig {
ddim::DDIMScheduler::new(n_steps, self.scheduler)
}
- pub fn build_clip_transformer(
+ pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
&self,
- clip_weights: &str,
+ clip_weights: P,
device: &Device,
) -> Result<clip::ClipTextTransformer> {
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
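The signature changes in this second file make the builders generic over AsRef<Path>, so callers can pass either a hard-coded &str or the PathBuf returned by the hub download without any conversion. A standalone sketch of that pattern follows; open_weights is a hypothetical stand-in used only for illustration, not part of the candle API.

    // Illustrative only: the same AsRef<Path> bound the builders adopt.
    use std::path::{Path, PathBuf};

    fn open_weights<P: AsRef<Path>>(weights: P) -> std::io::Result<u64> {
        // Any path-like argument works: &str, String, Path, or PathBuf.
        std::fs::metadata(weights.as_ref()).map(|m| m.len())
    }

    fn main() -> std::io::Result<()> {
        let a = open_weights("data/vae.safetensors")?;                    // &str literal
        let b = open_weights(PathBuf::from("/tmp/weights.safetensors"))?; // PathBuf, e.g. from the hub cache
        println!("{a} and {b} bytes");
        Ok(())
    }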