diff options
| | | |
|---|---|---|
| author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-13 19:18:49 +0100 |
| committer | GitHub <noreply@github.com> | 2023-07-13 19:18:49 +0100 |
| commit | 3c02ea56b0dcc39c30ad0d41d942384cc28f65c2 (patch) | |
| tree | afb9203915f9ab71af7fecb1917a66a8445eea66 | |
| parent | ded93a116983da7c84ea224a6191bcbc3a7fdef1 (diff) | |
| download | candle-3c02ea56b0dcc39c30ad0d41d942384cc28f65c2.tar.gz candle-3c02ea56b0dcc39c30ad0d41d942384cc28f65c2.tar.bz2 candle-3c02ea56b0dcc39c30ad0d41d942384cc28f65c2.zip | |
Add a cli argument to easily switch the dtype. (#161)
-rw-r--r-- | candle-examples/examples/falcon/main.rs | 16 | ||||
-rw-r--r-- | candle-examples/examples/llama/main.rs | 13 |
2 files changed, 17 insertions, 12 deletions
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs index 4757d2b1..5cc7b065 100644 --- a/candle-examples/examples/falcon/main.rs +++ b/candle-examples/examples/falcon/main.rs @@ -14,11 +14,6 @@ use tokenizers::Tokenizer; mod model; use model::{Config, Falcon}; -#[cfg(feature = "mkl")] -const DTYPE: DType = DType::F32; -#[cfg(not(feature = "mkl"))] -const DTYPE: DType = DType::BF16; - struct TextGeneration { model: Falcon, device: Device, @@ -99,6 +94,10 @@ struct Args { #[arg(long)] prompt: String, + /// Use f32 computations rather than bf16. + #[arg(long)] + use_f32: bool, + /// The temperature used to generate samples. #[arg(long)] temperature: Option<f64>, @@ -151,7 +150,12 @@ fn main() -> Result<()> { .map(|f| Ok(f.deserialize()?)) .collect::<Result<Vec<_>>>()?; - let vb = VarBuilder::from_safetensors(weights, DTYPE, &device); + let dtype = if args.use_f32 { + DType::F32 + } else { + DType::BF16 + }; + let vb = VarBuilder::from_safetensors(weights, dtype, &device); let config = Config::falcon7b(); config.validate()?; let model = Falcon::load(vb, config)?; diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index 301b870a..7ba87c70 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -24,10 +24,6 @@ mod model; use model::{Config, Llama}; const MAX_SEQ_LEN: usize = 4096; -#[cfg(feature = "mkl")] -const DTYPE: DType = DType::F32; -#[cfg(not(feature = "mkl"))] -const DTYPE: DType = DType::F16; const DEFAULT_PROMPT: &str = r" EDWARD: I wonder how our princely father 'scaped, @@ -127,6 +123,10 @@ struct Args { /// The initial prompt. #[arg(long)] prompt: Option<String>, + + /// Use f32 computations rather than f16. 
+ #[arg(long)] + use_f32: bool, } fn main() -> Result<()> { @@ -140,9 +140,10 @@ fn main() -> Result<()> { }; let config = Config::config_7b(); let cache = model::Cache::new(!args.no_kv_cache, &config, &device); + let dtype = if args.use_f32 { DType::F32 } else { DType::F16 }; let (llama, tokenizer_filename) = match args.npy { Some(filename) => { - let vb = VarBuilder::from_npz(filename, DTYPE, &device)?; + let vb = VarBuilder::from_npz(filename, dtype, &device)?; let tokenizer = std::path::PathBuf::from("llama-tokenizer.json"); (Llama::load(vb, &cache, &config)?, tokenizer) } @@ -170,7 +171,7 @@ fn main() -> Result<()> { .map(|h| Ok(h.deserialize()?)) .collect::<Result<Vec<_>>>()?; - let vb = VarBuilder::from_safetensors(tensors, DTYPE, &device); + let vb = VarBuilder::from_safetensors(tensors, dtype, &device); (Llama::load(vb, &cache, &config)?, tokenizer_filename) } }; |