summaryrefslogtreecommitdiff
path: root/candle-examples/examples
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--candle-examples/examples/bert/README.md44
-rw-r--r--candle-examples/examples/bert/main.rs11
-rw-r--r--candle-examples/examples/mixtral/README.md25
-rw-r--r--candle-examples/examples/mixtral/main.rs263
-rw-r--r--candle-examples/examples/phi/README.md27
-rw-r--r--candle-examples/examples/phi/main.rs49
-rw-r--r--candle-examples/examples/quantized/README.md13
-rw-r--r--candle-examples/examples/quantized/main.rs103
-rw-r--r--candle-examples/examples/reinforcement-learning/atari_wrappers.py2
-rw-r--r--candle-examples/examples/stable-diffusion/README.md22
-rw-r--r--candle-examples/examples/stable-diffusion/main.rs121
11 files changed, 601 insertions, 79 deletions
diff --git a/candle-examples/examples/bert/README.md b/candle-examples/examples/bert/README.md
index 82ca5f40..5a75b516 100644
--- a/candle-examples/examples/bert/README.md
+++ b/candle-examples/examples/bert/README.md
@@ -2,10 +2,10 @@
Bert is a general large language model. In this example it can be used for two
different tasks:
+
- Compute sentence embeddings for a prompt.
- Compute similarities between a set of sentences.
-
## Sentence embeddings
Bert is used to compute the sentence embeddings for a prompt. The model weights
@@ -24,6 +24,48 @@ cargo run --example bert --release -- --prompt "Here is a test sentence"
> Tensor[[1, 7, 384], f32]
```
+### Custom models
+
+You can specify different models, such as BGE, with the `--model-id` flag:
+
+```bash
+cargo run --example bert --release -- \
+--model-id BAAI/bge-large-zh-v1.5 \
+--prompt "Here is a test sentence"
+Loaded and encoded 435.70775ms
+[[[ 3.0944e-1, -7.8455e-5, -1.2768e0, ..., 1.3755e-2, -3.2371e-1, 2.3819e-1],
+ [-2.8506e-1, 1.9953e-1, -1.3076e0, ..., 6.9819e-2, 1.0833e-2, -1.1512e0],
+ [ 3.9892e-1, 2.0000e-1, -9.3178e-1, ..., -4.1393e-1, -4.9644e-2, -3.3786e-1],
+ ...
+ [ 6.0345e-1, 3.5744e-1, -1.2672e0, ..., -6.9165e-1, -3.4973e-3, -8.4214e-1],
+ [ 3.9218e-1, -3.2735e-1, -1.3123e0, ..., -4.9318e-1, -5.1334e-1, -3.6391e-1],
+ [ 3.0978e-1, 2.5662e-4, -1.2773e0, ..., 1.3357e-2, -3.2390e-1, 2.3858e-1]]]
+Tensor[[1, 9, 1024], f32]
+Took 176.744667ms
+```
+
+### Gelu approximation
+
+You can get a speedup by using an approximation of the gelu activation, with a
+small loss of precision, by passing the `--approximate-gelu` flag:
+
+```bash
+$ cargo run --example bert --release -- \
+--model-id BAAI/bge-large-zh-v1.5 \
+--prompt "Here is a test sentence" \
+--approximate-gelu
+Loaded and encoded 244.388042ms
+[[[ 3.1048e-1, -6.0339e-4, -1.2758e0, ..., 1.3718e-2, -3.2362e-1, 2.3775e-1],
+ [-2.8354e-1, 1.9984e-1, -1.3077e0, ..., 6.9390e-2, 9.9681e-3, -1.1531e0],
+ [ 3.9947e-1, 1.9917e-1, -9.3178e-1, ..., -4.1301e-1, -5.0719e-2, -3.3955e-1],
+ ...
+ [ 6.0499e-1, 3.5664e-1, -1.2642e0, ..., -6.9134e-1, -3.4581e-3, -8.4471e-1],
+ [ 3.9311e-1, -3.2812e-1, -1.3105e0, ..., -4.9291e-1, -5.1270e-1, -3.6543e-1],
+ [ 3.1082e-1, -2.6737e-4, -1.2762e0, ..., 1.3319e-2, -3.2381e-1, 2.3815e-1]]]
+Tensor[[1, 9, 1024], f32]
+Took 116.840791ms
+```
+
## Similarities
In this example, Bert is used to compute the sentence embeddings for a set of
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index fcd2eab9..88e29718 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -3,7 +3,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-use candle_transformers::models::bert::{BertModel, Config, DTYPE};
+use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
use anyhow::{Error as E, Result};
use candle::Tensor;
@@ -45,6 +45,10 @@ struct Args {
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
+
+ /// Use tanh based approximation for Gelu instead of erf implementation.
+ #[arg(long, default_value = "false")]
+ approximate_gelu: bool,
}
impl Args {
@@ -73,7 +77,7 @@ impl Args {
(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)?;
- let config: Config = serde_json::from_str(&config)?;
+ let mut config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let vb = if self.use_pth {
@@ -81,6 +85,9 @@ impl Args {
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
};
+ if self.approximate_gelu {
+ config.hidden_act = HiddenAct::GeluApproximate;
+ }
let model = BertModel::load(vb, &config)?;
Ok((model, tokenizer))
}
diff --git a/candle-examples/examples/mixtral/README.md b/candle-examples/examples/mixtral/README.md
new file mode 100644
index 00000000..aec5c148
--- /dev/null
+++ b/candle-examples/examples/mixtral/README.md
@@ -0,0 +1,25 @@
+# candle-mixtral: 8x7b LLM using a sparse mixture of experts.
+
+Mixtral-8x7B-v0.1 is a pretrained generative LLM with 56 billion parameters.
+
+- [Blog post](https://mistral.ai/news/mixtral-of-experts/) from Mistral announcing the model release.
+- [Model card](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) on the HuggingFace Hub.
+
+## Running the example
+
+```bash
+$ cargo run --example mixtral --release -- --prompt "def print_prime(n): "
+def print_prime(n): # n is the number of prime numbers to be printed
+ i = 2
+ count = 0
+ while (count < n):
+ if (isPrime(i)):
+ print(i)
+ count += 1
+ i += 1
+
+def isPrime(n):
+ for x in range(2, int(n**0.5)+1):
+ if (n % x == 0):
+ ...
+```
diff --git a/candle-examples/examples/mixtral/main.rs b/candle-examples/examples/mixtral/main.rs
new file mode 100644
index 00000000..fcde03c1
--- /dev/null
+++ b/candle-examples/examples/mixtral/main.rs
@@ -0,0 +1,263 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::{Error as E, Result};
+use clap::Parser;
+
+use candle_transformers::models::mixtral::{Config, Model};
+
+use candle::{DType, Device, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
+use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+struct TextGeneration {
+ model: Model,
+ device: Device,
+ tokenizer: TokenOutputStream,
+ logits_processor: LogitsProcessor,
+ repeat_penalty: f32,
+ repeat_last_n: usize,
+}
+
+impl TextGeneration {
+ #[allow(clippy::too_many_arguments)]
+ fn new(
+ model: Model,
+ tokenizer: Tokenizer,
+ seed: u64,
+ temp: Option<f64>,
+ top_p: Option<f64>,
+ repeat_penalty: f32,
+ repeat_last_n: usize,
+ device: &Device,
+ ) -> Self {
+ let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+ Self {
+ model,
+ tokenizer: TokenOutputStream::new(tokenizer),
+ logits_processor,
+ repeat_penalty,
+ repeat_last_n,
+ device: device.clone(),
+ }
+ }
+
+ fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
+ use std::io::Write;
+ self.tokenizer.clear();
+ let mut tokens = self
+ .tokenizer
+ .tokenizer()
+ .encode(prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ for &t in tokens.iter() {
+ if let Some(t) = self.tokenizer.next_token(t)? {
+ print!("{t}")
+ }
+ }
+ std::io::stdout().flush()?;
+
+ let mut generated_tokens = 0usize;
+ let eos_token = match self.tokenizer.get_token("</s>") {
+ Some(token) => token,
+ None => anyhow::bail!("cannot find the </s> token"),
+ };
+ let start_gen = std::time::Instant::now();
+ for index in 0..sample_len {
+ let context_size = if index > 0 { 1 } else { tokens.len() };
+ let start_pos = tokens.len().saturating_sub(context_size);
+ let ctxt = &tokens[start_pos..];
+ let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+ let logits = self.model.forward(&input, start_pos)?;
+ let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
+ let logits = if self.repeat_penalty == 1. {
+ logits
+ } else {
+ let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+ candle_transformers::utils::apply_repeat_penalty(
+ &logits,
+ self.repeat_penalty,
+ &tokens[start_at..],
+ )?
+ };
+
+ let next_token = self.logits_processor.sample(&logits)?;
+ tokens.push(next_token);
+ generated_tokens += 1;
+ if next_token == eos_token {
+ break;
+ }
+ if let Some(t) = self.tokenizer.next_token(next_token)? {
+ print!("{t}");
+ std::io::stdout().flush()?;
+ }
+ }
+ let dt = start_gen.elapsed();
+ if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
+ print!("{rest}");
+ }
+ std::io::stdout().flush()?;
+ println!(
+ "\n{generated_tokens} tokens generated ({:.2} token/s)",
+ generated_tokens as f64 / dt.as_secs_f64(),
+ );
+ Ok(())
+ }
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ #[arg(long)]
+ use_flash_attn: bool,
+
+ #[arg(long)]
+ prompt: String,
+
+ /// The temperature used to generate samples.
+ #[arg(long)]
+ temperature: Option<f64>,
+
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
+ /// The seed to use when generating random samples.
+ #[arg(long, default_value_t = 299792458)]
+ seed: u64,
+
+ /// The length of the sample to generate (in tokens).
+ #[arg(long, short = 'n', default_value_t = 100)]
+ sample_len: usize,
+
+ #[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]
+ model_id: String,
+
+ #[arg(long, default_value = "main")]
+ revision: String,
+
+ #[arg(long)]
+ tokenizer_file: Option<String>,
+
+ #[arg(long)]
+ weight_files: Option<String>,
+
+ /// Penalty to be applied for repeating tokens, 1. means no penalty.
+ #[arg(long, default_value_t = 1.1)]
+ repeat_penalty: f32,
+
+ /// The context size to consider for the repeat penalty.
+ #[arg(long, default_value_t = 64)]
+ repeat_last_n: usize,
+}
+
+fn main() -> Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
+ let args = Args::parse();
+ let _guard = if args.tracing {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
+ println!(
+ "avx: {}, neon: {}, simd128: {}, f16c: {}",
+ candle::utils::with_avx(),
+ candle::utils::with_neon(),
+ candle::utils::with_simd128(),
+ candle::utils::with_f16c()
+ );
+ println!(
+ "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+ args.temperature.unwrap_or(0.),
+ args.repeat_penalty,
+ args.repeat_last_n
+ );
+
+ let start = std::time::Instant::now();
+ let api = Api::new()?;
+ let repo = api.repo(Repo::with_revision(
+ args.model_id,
+ RepoType::Model,
+ args.revision,
+ ));
+ let tokenizer_filename = match args.tokenizer_file {
+ Some(file) => std::path::PathBuf::from(file),
+ None => repo.get("tokenizer.json")?,
+ };
+ let filenames = match args.weight_files {
+ Some(files) => files
+ .split(',')
+ .map(std::path::PathBuf::from)
+ .collect::<Vec<_>>(),
+ None => {
+ vec![
+ repo.get("model-00001-of-00019.safetensors")?,
+ repo.get("model-00002-of-00019.safetensors")?,
+ repo.get("model-00003-of-00019.safetensors")?,
+ repo.get("model-00004-of-00019.safetensors")?,
+ repo.get("model-00005-of-00019.safetensors")?,
+ repo.get("model-00006-of-00019.safetensors")?,
+ repo.get("model-00007-of-00019.safetensors")?,
+ repo.get("model-00008-of-00019.safetensors")?,
+ repo.get("model-00009-of-00019.safetensors")?,
+ repo.get("model-00010-of-00019.safetensors")?,
+ repo.get("model-00011-of-00019.safetensors")?,
+ repo.get("model-00012-of-00019.safetensors")?,
+ repo.get("model-00013-of-00019.safetensors")?,
+ repo.get("model-00014-of-00019.safetensors")?,
+ repo.get("model-00015-of-00019.safetensors")?,
+ repo.get("model-00016-of-00019.safetensors")?,
+ repo.get("model-00017-of-00019.safetensors")?,
+ repo.get("model-00018-of-00019.safetensors")?,
+ repo.get("model-00019-of-00019.safetensors")?,
+ ]
+ }
+ };
+ println!("retrieved the files in {:?}", start.elapsed());
+ let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+ let start = std::time::Instant::now();
+ let config = Config::v0_1_8x7b(args.use_flash_attn);
+ let device = candle_examples::device(args.cpu)?;
+ let dtype = if device.is_cuda() {
+ DType::BF16
+ } else {
+ DType::F32
+ };
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
+ let model = Model::new(&config, vb)?;
+ println!("loaded the model in {:?}", start.elapsed());
+
+ let mut pipeline = TextGeneration::new(
+ model,
+ tokenizer,
+ args.seed,
+ args.temperature,
+ args.top_p,
+ args.repeat_penalty,
+ args.repeat_last_n,
+ &device,
+ );
+ pipeline.run(&args.prompt, args.sample_len)?;
+ Ok(())
+}
diff --git a/candle-examples/examples/phi/README.md b/candle-examples/examples/phi/README.md
index 566411d1..70af6650 100644
--- a/candle-examples/examples/phi/README.md
+++ b/candle-examples/examples/phi/README.md
@@ -1,14 +1,33 @@
-# candle-phi: 1.3b LLM with state of the art performance for <10b models.
+# candle-phi: 1.3b and 2.7b LLM with state of the art performance for <10b models.
-[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) is a language model using
-only 1.3 billion parameters but with state of the art performance compared to
+[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) and
+[Phi-2](https://huggingface.co/microsoft/phi-2) are language models using
+only 1.3 and 2.7 billion parameters but with state of the art performance compared to
models with up to 10 billion parameters.
The candle implementation provides both the standard version as well as a
quantized variant.
-## Running some example
+## Running some examples
+For the v2 version.
+```bash
+$ cargo run --example phi --release -- --model 2 \
+ --prompt "A skier slides down a frictionless slope of height 40m and length 80m. What's the skier speed at the bottom?"
+
+A skier slides down a frictionless slope of height 40m and length 80m. What's the skier speed at the bottom?
+
+Solution:
+The potential energy of the skier is converted into kinetic energy as it slides down the slope. The formula for potential energy is mgh, where m is mass, g is acceleration due to gravity (9.8 m/s^2), and h is height. Since there's no friction, all the potential energy is converted into kinetic energy at the bottom of the slope. The formula for kinetic energy is 1/2mv^2, where v is velocity. We can equate these two formulas:
+mgh = 1/2mv^2
+Solving for v, we get:
+v = sqrt(2gh)
+Substituting the given values, we get:
+v = sqrt(2*9.8*40) = 28 m/s
+Therefore, the skier speed at the bottom of the slope is 28 m/s.
+```
+
+For the v1.5 version.
```bash
$ cargo run --example phi --release -- --prompt "def print_prime(n): "
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 720a4441..52d453b5 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -123,6 +123,8 @@ enum WhichModel {
V1,
#[value(name = "1.5")]
V1_5,
+ #[value(name = "2")]
+ V2,
PuffinPhiV2,
PhiHermes,
}
@@ -158,7 +160,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
- #[arg(long, short = 'n', default_value_t = 100)]
+ #[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,
#[arg(long)]
@@ -225,6 +227,7 @@ fn main() -> Result<()> {
match args.model {
WhichModel::V1 => "microsoft/phi-1".to_string(),
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
+ WhichModel::V2 => "microsoft/phi-2".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"lmz/candle-quantized-phi".to_string()
}
@@ -241,7 +244,9 @@ fn main() -> Result<()> {
match args.model {
WhichModel::V1 => "refs/pr/2".to_string(),
WhichModel::V1_5 => "refs/pr/18".to_string(),
- WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(),
+ WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
+ "main".to_string()
+ }
}
}
}
@@ -250,27 +255,32 @@ fn main() -> Result<()> {
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => match args.model {
- WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?,
+ WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => repo.get("tokenizer.json")?,
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
repo.get("tokenizer-puffin-phi-v2.json")?
}
},
};
- let filename = match args.weight_file {
- Some(weight_file) => std::path::PathBuf::from(weight_file),
+ let filenames = match args.weight_file {
+ Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => {
if args.quantized {
match args.model {
- WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
- WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
- WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?,
- WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B-q4k.gguf")?,
+ WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
+ WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
+ WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?],
+ WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
+ WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
}
} else {
match args.model {
- WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?,
- WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?,
- WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B.safetensors")?,
+ WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
+ WhichModel::V2 => vec![
+ repo.get("model-00001-of-00002.safetensors")?,
+ repo.get("model-00002-of-00002.safetensors")?,
+ ],
+ WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
+ WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
}
}
}
@@ -282,17 +292,24 @@ fn main() -> Result<()> {
let config = match args.model {
WhichModel::V1 => Config::v1(),
WhichModel::V1_5 => Config::v1_5(),
+ WhichModel::V2 => Config::v2(),
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
};
let (model, device) = if args.quantized {
- let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
- let model = QMixFormer::new(&config, vb)?;
+ let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
+ let model = match args.model {
+ WhichModel::V2 => QMixFormer::new_v2(&config, vb)?,
+ _ => QMixFormer::new(&config, vb)?,
+ };
(Model::Quantized(model), Device::Cpu)
} else {
let device = candle_examples::device(args.cpu)?;
- let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
- let model = MixFormer::new(&config, vb)?;
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+ let model = match args.model {
+ WhichModel::V2 => MixFormer::new_v2(&config, vb)?,
+ _ => MixFormer::new(&config, vb)?,
+ };
(Model::MixFormer(model), device)
};
println!("loaded the model in {:?}", start.elapsed());
diff --git a/candle-examples/examples/quantized/README.md b/candle-examples/examples/quantized/README.md
index bed09243..8144bffe 100644
--- a/candle-examples/examples/quantized/README.md
+++ b/candle-examples/examples/quantized/README.md
@@ -26,6 +26,19 @@ cargo run --example quantized --release -- --prompt "The best thing about coding
> The best thing about coding in rust is 1.) that I don’t need to worry about memory leaks, 2.) speed and 3.) my program will compile even on old machines.
```
+Using the mixtral sparse mixture of expert model:
+```bash
+
+$ cargo run --example quantized --release -- --which mixtral --prompt "Lebesgue's integral is superior to Riemann's because "
+> avx: true, neon: false, simd128: false, f16c: true
+> temp: 0.80 repeat-penalty: 1.10 repeat-last-n: 64
+> loaded 995 tensors (26.44GB) in 0.03s
+Lebesgue's integral is superior to Riemann's because 1. it is defined for a wider class of functions, those which are absolutely integrable; 2. the definition does not involve limits in two variables---one being computed before the other (which makes some computations more difficult); and 3. interchange of order of integration is easier to establish than with Riemann's integral. On the other hand, Lebesgue's integral applies only for bounded functions defined on finite intervals; it does not provide numerical values for improper integrals. The latter are best evaluated using Cauchy's limit definition.
+
+The reason $f(x) = x^2$ is discontinuous at the ends of its interval of definition, and Riemann's integral requires continuity on the whole of an open interval containing it (see our earlier post), sine no such function exists with this property, is that the endpoints are infinite in measure for Lebesgue's integral.
+ ```
+
+
## Command-line flags
Run with `--help` to see all options.
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index ab8a56ba..df758b4f 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -45,6 +45,10 @@ enum Which {
L13bCode,
#[value(name = "32b-code")]
L34bCode,
+ #[value(name = "7b-leo")]
+ Leo7b,
+ #[value(name = "13b-leo")]
+ Leo13b,
#[value(name = "7b-mistral")]
Mistral7b,
#[value(name = "7b-mistral-instruct")]
@@ -55,6 +59,12 @@ enum Which {
Zephyr7bBeta,
#[value(name = "7b-open-chat-3.5")]
OpenChat35,
+ #[value(name = "7b-starling-a")]
+ Starling7bAlpha,
+ #[value(name = "mixtral")]
+ Mixtral,
+ #[value(name = "mixtral-instruct")]
+ MixtralInstruct,
}
impl Which {
@@ -68,12 +78,17 @@ impl Which {
| Self::L70bChat
| Self::L7bCode
| Self::L13bCode
- | Self::L34bCode => false,
+ | Self::L34bCode
+ | Self::Leo7b
+ | Self::Leo13b => false,
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
- // same way.
+ // same way. Starling is a fine tuned version of OpenChat.
Self::OpenChat35
+ | Self::Starling7bAlpha
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta
+ | Self::Mixtral
+ | Self::MixtralInstruct
| Self::Mistral7b
| Self::Mistral7bInstruct => true,
}
@@ -90,15 +105,43 @@ impl Which {
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode
+ | Self::Leo7b
+ | Self::Leo13b
+ | Self::Mixtral
+ | Self::MixtralInstruct
| Self::Mistral7b
| Self::Mistral7bInstruct
- | Self::OpenChat35 => false,
+ | Self::OpenChat35
+ | Self::Starling7bAlpha => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
}
}
fn is_open_chat(&self) -> bool {
match self {
+ Self::L7b
+ | Self::L13b
+ | Self::L70b
+ | Self::L7bChat
+ | Self::L13bChat
+ | Self::L70bChat
+ | Self::L7bCode
+ | Self::L13bCode
+ | Self::L34bCode
+ | Self::Leo7b
+ | Self::Leo13b
+ | Self::Mixtral
+ | Self::MixtralInstruct
+ | Self::Mistral7b
+ | Self::Mistral7bInstruct
+ | Self::Zephyr7bAlpha
+ | Self::Zephyr7bBeta => false,
+ Self::OpenChat35 | Self::Starling7bAlpha => true,
+ }
+ }
+
+ fn tokenizer_repo(&self) -> &'static str {
+ match self {
Which::L7b
| Which::L13b
| Which::L70b
@@ -107,12 +150,17 @@ impl Which {
| Which::L70bChat
| Which::L7bCode
| Which::L13bCode
- | Which::L34bCode
- | Which::Mistral7b
+ | Which::L34bCode => "hf-internal-testing/llama-tokenizer",
+ Which::Leo7b => "LeoLM/leo-hessianai-7b",
+ Which::Leo13b => "LeoLM/leo-hessianai-13b",
+ Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
+ Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
+ Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Zephyr7bAlpha
- | Which::Zephyr7bBeta => false,
- Which::OpenChat35 => true,
+ | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
+ Which::OpenChat35 => "openchat/openchat_3.5",
+ Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
}
}
}
@@ -181,13 +229,7 @@ impl Args {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
- let repo = if self.which.is_open_chat() {
- "openchat/openchat_3.5"
- } else if self.which.is_mistral() {
- "mistralai/Mistral-7B-v0.1"
- } else {
- "hf-internal-testing/llama-tokenizer"
- };
+ let repo = self.which.tokenizer_repo();
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
}
@@ -218,6 +260,22 @@ impl Args {
Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
+ Which::Leo7b => (
+ "TheBloke/leo-hessianai-7B-GGUF",
+ "leo-hessianai-7b.Q4_K_M.gguf",
+ ),
+ Which::Leo13b => (
+ "TheBloke/leo-hessianai-13B-GGUF",
+ "leo-hessianai-13b.Q4_K_M.gguf",
+ ),
+ Which::Mixtral => (
+ "TheBloke/Mixtral-8x7B-v0.1-GGUF",
+ "mixtral-8x7b-v0.1.Q4_K_M.gguf",
+ ),
+ Which::MixtralInstruct => (
+ "TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF",
+ "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf",
+ ),
Which::Mistral7b => (
"TheBloke/Mistral-7B-v0.1-GGUF",
"mistral-7b-v0.1.Q4_K_S.gguf",
@@ -234,6 +292,10 @@ impl Args {
("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
}
Which::OpenChat35 => ("TheBloke/openchat_3.5-GGUF", "openchat_3.5.Q4_K_M.gguf"),
+ Which::Starling7bAlpha => (
+ "TheBloke/Starling-LM-7B-alpha-GGUF",
+ "starling-lm-7b-alpha.Q4_K_M.gguf",
+ ),
};
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string());
@@ -329,14 +391,19 @@ fn main() -> anyhow::Result<()> {
| Which::L13bChat
| Which::L7bCode
| Which::L13bCode
- | Which::L34bCode => 1,
- Which::Mistral7b
+ | Which::L34bCode
+ | Which::Leo7b
+ | Which::Leo13b => 1,
+ Which::Mixtral
+ | Which::MixtralInstruct
+ | Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta
| Which::L70b
| Which::L70bChat
- | Which::OpenChat35 => 8,
+ | Which::OpenChat35
+ | Which::Starling7bAlpha => 8,
};
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
}
@@ -369,7 +436,7 @@ fn main() -> anyhow::Result<()> {
}
}
if args.which.is_open_chat() {
- format!("User: {prompt}<|end_of_turn|>Assistant: ")
+ format!("GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:")
} else if args.which.is_zephyr() {
if prompt_index == 0 || is_interactive {
format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
diff --git a/candle-examples/examples/reinforcement-learning/atari_wrappers.py b/candle-examples/examples/reinforcement-learning/atari_wrappers.py
index b5c4665d..b76fb85d 100644
--- a/candle-examples/examples/reinforcement-learning/atari_wrappers.py
+++ b/candle-examples/examples/reinforcement-learning/atari_wrappers.py
@@ -78,7 +78,7 @@ class EpisodicLifeEnv(gym.Wrapper):
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives()
if lives < self.lives and lives > 0:
- # for Qbert somtimes we stay in lives == 0 condtion for a few frames
+ # for Qbert sometimes we stay in lives == 0 condition for a few frames
# so its important to keep lives > 0, so that we only reset once
# the environment advertises done.
done = True
diff --git a/candle-examples/examples/stable-diffusion/README.md b/candle-examples/examples/stable-diffusion/README.md
index b8736a2a..feb7ca56 100644
--- a/candle-examples/examples/stable-diffusion/README.md
+++ b/candle-examples/examples/stable-diffusion/README.md
@@ -8,7 +8,7 @@ XL using Rust and [candle](https://github.com/huggingface/candle).
The `stable-diffusion` example is a conversion of
[diffusers-rs](https://github.com/LaurentMazare/diffusers-rs) using candle
rather than libtorch. This implementation supports Stable Diffusion v1.5, v2.1,
-as well as Stable Diffusion XL 1.0.
+as well as Stable Diffusion XL 1.0, and Turbo.
## Getting the weights
@@ -23,16 +23,26 @@ cargo run --example stable-diffusion --release --features=cuda,cudnn \
-- --prompt "a cosmonaut on a horse (hd, realistic, high-def)"
```
-The final image is named `sd_final.png` by default.
-The default scheduler is the Denoising Diffusion Implicit Model scheduler (DDIM). The
-original paper and some code can be found in the [associated repo](https://github.com/ermongroup/ddim).
+The final image is named `sd_final.png` by default. The Turbo version is much
+faster than previous versions, to give it a try add a `--sd-version turbo` flag,
+e.g.:
+
+```bash
+cargo run --example stable-diffusion --release --features=cuda,cudnn \
+ -- --prompt "a cosmonaut on a horse (hd, realistic, high-def) --sd-version turbo"
+```
+
+The default scheduler for the v1.5, v2.1 and XL 1.0 version is the Denoising
+Diffusion Implicit Model scheduler (DDIM). The original paper and some code can
+be found in the [associated repo](https://github.com/ermongroup/ddim).
+The default scheduler for the XL Turbo version is the Euler Ancestral scheduler.
### Command-line flags
- `--prompt`: the prompt to be used to generate the image.
- `--uncond-prompt`: the optional unconditional prompt.
-- `--sd-version`: the Stable Diffusion version to use, can be `v1-5`, `v2-1`, or
- `xl`.
+- `--sd-version`: the Stable Diffusion version to use, can be `v1-5`, `v2-1`,
+ `xl`, or `turbo`.
- `--cpu`: use the cpu rather than the gpu (much slower).
- `--height`, `--width`: set the height and width for the generated image.
- `--n-steps`: the number of steps to be used in the diffusion process.
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 3e6de34d..8c3ca2ee 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -11,8 +11,6 @@ use candle::{DType, Device, IndexOp, Module, Tensor, D};
use clap::Parser;
use tokenizers::Tokenizer;
-const GUIDANCE_SCALE: f64 = 7.5;
-
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
@@ -63,8 +61,8 @@ struct Args {
sliced_attention_size: Option<usize>,
/// The number of steps to run the diffusion for.
- #[arg(long, default_value_t = 30)]
- n_steps: usize,
+ #[arg(long)]
+ n_steps: Option<usize>,
/// The number of samples to generate.
#[arg(long, default_value_t = 1)]
@@ -87,6 +85,9 @@ struct Args {
#[arg(long)]
use_f16: bool,
+ #[arg(long)]
+ guidance_scale: Option<f64>,
+
#[arg(long, value_name = "FILE")]
img2img: Option<String>,
@@ -102,6 +103,7 @@ enum StableDiffusionVersion {
V1_5,
V2_1,
Xl,
+ Turbo,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -120,12 +122,13 @@ impl StableDiffusionVersion {
Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
+ Self::Turbo => "stabilityai/sdxl-turbo",
}
}
fn unet_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 | Self::Xl => {
+ Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
if use_f16 {
"unet/diffusion_pytorch_model.fp16.safetensors"
} else {
@@ -137,7 +140,7 @@ impl StableDiffusionVersion {
fn vae_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 | Self::Xl => {
+ Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
if use_f16 {
"vae/diffusion_pytorch_model.fp16.safetensors"
} else {
@@ -149,7 +152,7 @@ impl StableDiffusionVersion {
fn clip_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 | Self::Xl => {
+ Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
if use_f16 {
"text_encoder/model.fp16.safetensors"
} else {
@@ -161,7 +164,7 @@ impl StableDiffusionVersion {
fn clip2_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 | Self::Xl => {
+ Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
if use_f16 {
"text_encoder_2/model.fp16.safetensors"
} else {
@@ -189,7 +192,7 @@ impl ModelFile {
StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
"openai/clip-vit-base-patch32"
}
- StableDiffusionVersion::Xl => {
+ StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => {
// This seems similar to the patch32 version except some very small
// difference in the split regex.
"openai/clip-vit-large-patch14"
@@ -206,7 +209,11 @@ impl ModelFile {
Self::Vae => {
// Override for SDXL when using f16 weights.
// See https://github.com/huggingface/candle/issues/1060
- if version == StableDiffusionVersion::Xl && use_f16 {
+ if matches!(
+ version,
+ StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo,
+ ) && use_f16
+ {
(
"madebyollin/sdxl-vae-fp16-fix",
"diffusion_pytorch_model.safetensors",
@@ -261,6 +268,7 @@ fn text_embeddings(
use_f16: bool,
device: &Device,
dtype: DType,
+ use_guide_scale: bool,
first: bool,
) -> Result<Tensor> {
let tokenizer_file = if first {
@@ -285,16 +293,6 @@ fn text_embeddings(
}
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
- let mut uncond_tokens = tokenizer
- .encode(uncond_prompt, true)
- .map_err(E::msg)?
- .get_ids()
- .to_vec();
- while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
- uncond_tokens.push(pad_id)
- }
- let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
-
println!("Building the Clip transformer.");
let clip_weights_file = if first {
ModelFile::Clip
@@ -310,8 +308,24 @@ fn text_embeddings(
let text_model =
stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
let text_embeddings = text_model.forward(&tokens)?;
- let uncond_embeddings = text_model.forward(&uncond_tokens)?;
- let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
+
+ let text_embeddings = if use_guide_scale {
+ let mut uncond_tokens = tokenizer
+ .encode(uncond_prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
+ uncond_tokens.push(pad_id)
+ }
+
+ let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
+ let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+
+ Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
+ } else {
+ text_embeddings.to_dtype(dtype)?
+ };
Ok(text_embeddings)
}
@@ -356,6 +370,7 @@ fn run(args: Args) -> Result<()> {
unet_weights,
tracing,
use_f16,
+ guidance_scale,
use_flash_attn,
img2img,
img2img_strength,
@@ -374,6 +389,24 @@ fn run(args: Args) -> Result<()> {
None
};
+ let guidance_scale = match guidance_scale {
+ Some(guidance_scale) => guidance_scale,
+ None => match sd_version {
+ StableDiffusionVersion::V1_5
+ | StableDiffusionVersion::V2_1
+ | StableDiffusionVersion::Xl => 7.5,
+ StableDiffusionVersion::Turbo => 0.,
+ },
+ };
+ let n_steps = match n_steps {
+ Some(n_steps) => n_steps,
+ None => match sd_version {
+ StableDiffusionVersion::V1_5
+ | StableDiffusionVersion::V2_1
+ | StableDiffusionVersion::Xl => 30,
+ StableDiffusionVersion::Turbo => 1,
+ },
+ };
let dtype = if use_f16 { DType::F16 } else { DType::F32 };
let sd_config = match sd_version {
StableDiffusionVersion::V1_5 => {
@@ -385,13 +418,19 @@ fn run(args: Args) -> Result<()> {
StableDiffusionVersion::Xl => {
stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
}
+ StableDiffusionVersion::Turbo => stable_diffusion::StableDiffusionConfig::sdxl_turbo(
+ sliced_attention_size,
+ height,
+ width,
+ ),
};
let scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
+ let use_guide_scale = guidance_scale > 1.0;
let which = match sd_version {
- StableDiffusionVersion::Xl => vec![true, false],
+ StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false],
_ => vec![true],
};
let text_embeddings = which
@@ -407,10 +446,12 @@ fn run(args: Args) -> Result<()> {
use_f16,
&device,
dtype,
+ use_guide_scale,
*first,
)
})
.collect::<Result<Vec<_>>>()?;
+
let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
println!("{text_embeddings:?}");
@@ -434,11 +475,19 @@ fn run(args: Args) -> Result<()> {
0
};
let bsize = 1;
+
+ let vae_scale = match sd_version {
+ StableDiffusionVersion::V1_5
+ | StableDiffusionVersion::V2_1
+ | StableDiffusionVersion::Xl => 0.18215,
+ StableDiffusionVersion::Turbo => 0.13025,
+ };
+
for idx in 0..num_samples {
let timesteps = scheduler.timesteps();
let latents = match &init_latent_dist {
Some(init_latent_dist) => {
- let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
+ let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;
if t_start < timesteps.len() {
let noise = latents.randn_like(0f64, 1f64)?;
scheduler.add_noise(&latents, noise, timesteps[t_start])?
@@ -465,21 +514,31 @@ fn run(args: Args) -> Result<()> {
continue;
}
let start_time = std::time::Instant::now();
- let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
+ let latent_model_input = if use_guide_scale {
+ Tensor::cat(&[&latents, &latents], 0)?
+ } else {
+ latents.clone()
+ };
let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
let noise_pred =
unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
- let noise_pred = noise_pred.chunk(2, 0)?;
- let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
- let noise_pred =
- (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?;
+
+ let noise_pred = if use_guide_scale {
+ let noise_pred = noise_pred.chunk(2, 0)?;
+ let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
+
+ (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * guidance_scale)?)?
+ } else {
+ noise_pred
+ };
+
latents = scheduler.step(&noise_pred, timestep, &latents)?;
let dt = start_time.elapsed().as_secs_f32();
println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
if args.intermediary_images {
- let image = vae.decode(&(&latents / 0.18215)?)?;
+ let image = vae.decode(&(&latents / vae_scale)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename =
@@ -493,7 +552,7 @@ fn run(args: Args) -> Result<()> {
idx + 1,
num_samples
);
- let image = vae.decode(&(&latents / 0.18215)?)?;
+ let image = vae.decode(&(&latents / vae_scale)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);