diff options
Diffstat (limited to 'candle-examples/examples/stable-diffusion-3/main.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion-3/main.rs | 44 |
1 files changed, 37 insertions, 7 deletions
diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs index d0bf4bb8..31d3fc42 100644 --- a/candle-examples/examples/stable-diffusion-3/main.rs +++ b/candle-examples/examples/stable-diffusion-3/main.rs @@ -19,13 +19,15 @@ enum Which { V3_5Large, #[value(name = "3.5-large-turbo")] V3_5LargeTurbo, + #[value(name = "3.5-medium")] + V3_5Medium, } impl Which { fn is_3_5(&self) -> bool { match self { Self::V3Medium => false, - Self::V3_5Large | Self::V3_5LargeTurbo => true, + Self::V3_5Large | Self::V3_5LargeTurbo | Self::V3_5Medium => true, } } } @@ -117,36 +119,59 @@ fn main() -> Result<()> { let default_inference_steps = match which { Which::V3_5Large => 28, Which::V3_5LargeTurbo => 4, + Which::V3_5Medium => 28, Which::V3Medium => 28, }; let num_inference_steps = num_inference_steps.unwrap_or(default_inference_steps); let default_cfg_scale = match which { Which::V3_5Large => 4.0, Which::V3_5LargeTurbo => 1.0, + Which::V3_5Medium => 4.0, Which::V3Medium => 4.0, }; let cfg_scale = cfg_scale.unwrap_or(default_cfg_scale); let api = hf_hub::api::sync::Api::new()?; let (mmdit_config, mut triple, vb) = if which.is_3_5() { - let sai_repo = { + let sai_repo_for_text_encoders = { + let name = match which { + Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large", + Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo", + + // Unfortunately, stabilityai/stable-diffusion-3.5-medium doesn't have the monolithic text encoders that are usually + // placed under the text_encoders directory, as is the case in stabilityai/stable-diffusion-3.5-large and -large-turbo. + // To make things worse, it currently only has partitioned model.fp16-00001-of-00002.safetensors and model.fp16-00002-of-00002.safetensors + // under the text_encoder_3 directory, for the t5xxl_fp16.safetensors model. This means that we need to merge the two partitions + // to get the monolithic text encoders. This is not a trivial task.
+ // Since the situation can change, we do not want to spend effort handling the uniqueness of stabilityai/stable-diffusion-3.5-medium, + // which involves different paths and merging the two partition files for t5xxl_fp16.safetensors. + // So for now, we'll use the text encoder models from the stabilityai/stable-diffusion-3.5-large repository. + // TODO: Change to "stabilityai/stable-diffusion-3.5-medium" once the maintainers of the repository add back the monolithic text encoders. + Which::V3_5Medium => "stabilityai/stable-diffusion-3.5-large", + Which::V3Medium => unreachable!(), + }; + api.repo(hf_hub::Repo::model(name.to_string())) + }; + let sai_repo_for_mmdit = { let name = match which { Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large", Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo", + Which::V3_5Medium => "stabilityai/stable-diffusion-3.5-medium", Which::V3Medium => unreachable!(), }; api.repo(hf_hub::Repo::model(name.to_string())) }; - let clip_g_file = sai_repo.get("text_encoders/clip_g.safetensors")?; - let clip_l_file = sai_repo.get("text_encoders/clip_l.safetensors")?; - let t5xxl_file = sai_repo.get("text_encoders/t5xxl_fp16.safetensors")?; + let clip_g_file = sai_repo_for_text_encoders.get("text_encoders/clip_g.safetensors")?; + let clip_l_file = sai_repo_for_text_encoders.get("text_encoders/clip_l.safetensors")?; + let t5xxl_file = sai_repo_for_text_encoders.get("text_encoders/t5xxl_fp16.safetensors")?; let model_file = { let model_file = match which { Which::V3_5Large => "sd3.5_large.safetensors", Which::V3_5LargeTurbo => "sd3.5_large_turbo.safetensors", + Which::V3_5Medium => "sd3.5_medium.safetensors", Which::V3Medium => unreachable!(), }; - sai_repo.get(model_file)? + sai_repo_for_mmdit.get(model_file)?
}; let triple = StableDiffusion3TripleClipWithTokenizer::new_split( &clip_g_file, @@ -157,7 +182,12 @@ fn main() -> Result<()> { let vb = unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F16, &device)? }; - (MMDiTConfig::sd3_5_large(), triple, vb) + match which { + Which::V3_5Large => (MMDiTConfig::sd3_5_large(), triple, vb), + Which::V3_5LargeTurbo => (MMDiTConfig::sd3_5_large(), triple, vb), + Which::V3_5Medium => (MMDiTConfig::sd3_5_medium(), triple, vb), + Which::V3Medium => unreachable!(), + } } else { let sai_repo = { let name = "stabilityai/stable-diffusion-3-medium"; |