summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-13 19:18:49 +0100
committerGitHub <noreply@github.com>2023-07-13 19:18:49 +0100
commit3c02ea56b0dcc39c30ad0d41d942384cc28f65c2 (patch)
treeafb9203915f9ab71af7fecb1917a66a8445eea66
parentded93a116983da7c84ea224a6191bcbc3a7fdef1 (diff)
downloadcandle-3c02ea56b0dcc39c30ad0d41d942384cc28f65c2.tar.gz
candle-3c02ea56b0dcc39c30ad0d41d942384cc28f65c2.tar.bz2
candle-3c02ea56b0dcc39c30ad0d41d942384cc28f65c2.zip
Add a cli argument to easily switch the dtype. (#161)
-rw-r--r--candle-examples/examples/falcon/main.rs16
-rw-r--r--candle-examples/examples/llama/main.rs13
2 files changed, 17 insertions, 12 deletions
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index 4757d2b1..5cc7b065 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -14,11 +14,6 @@ use tokenizers::Tokenizer;
mod model;
use model::{Config, Falcon};
-#[cfg(feature = "mkl")]
-const DTYPE: DType = DType::F32;
-#[cfg(not(feature = "mkl"))]
-const DTYPE: DType = DType::BF16;
-
struct TextGeneration {
model: Falcon,
device: Device,
@@ -99,6 +94,10 @@ struct Args {
#[arg(long)]
prompt: String,
+ /// Use f32 computations rather than bf16.
+ #[arg(long)]
+ use_f32: bool,
+
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
@@ -151,7 +150,12 @@ fn main() -> Result<()> {
.map(|f| Ok(f.deserialize()?))
.collect::<Result<Vec<_>>>()?;
- let vb = VarBuilder::from_safetensors(weights, DTYPE, &device);
+ let dtype = if args.use_f32 {
+ DType::F32
+ } else {
+ DType::BF16
+ };
+ let vb = VarBuilder::from_safetensors(weights, dtype, &device);
let config = Config::falcon7b();
config.validate()?;
let model = Falcon::load(vb, config)?;
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 301b870a..7ba87c70 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -24,10 +24,6 @@ mod model;
use model::{Config, Llama};
const MAX_SEQ_LEN: usize = 4096;
-#[cfg(feature = "mkl")]
-const DTYPE: DType = DType::F32;
-#[cfg(not(feature = "mkl"))]
-const DTYPE: DType = DType::F16;
const DEFAULT_PROMPT: &str = r"
EDWARD:
I wonder how our princely father 'scaped,
@@ -127,6 +123,10 @@ struct Args {
/// The initial prompt.
#[arg(long)]
prompt: Option<String>,
+
+ /// Use f32 computations rather than f16.
+ #[arg(long)]
+ use_f32: bool,
}
fn main() -> Result<()> {
@@ -140,9 +140,10 @@ fn main() -> Result<()> {
};
let config = Config::config_7b();
let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
+ let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
let (llama, tokenizer_filename) = match args.npy {
Some(filename) => {
- let vb = VarBuilder::from_npz(filename, DTYPE, &device)?;
+ let vb = VarBuilder::from_npz(filename, dtype, &device)?;
let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
(Llama::load(vb, &cache, &config)?, tokenizer)
}
@@ -170,7 +171,7 @@ fn main() -> Result<()> {
.map(|h| Ok(h.deserialize()?))
.collect::<Result<Vec<_>>>()?;
- let vb = VarBuilder::from_safetensors(tensors, DTYPE, &device);
+ let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
(Llama::load(vb, &cache, &config)?, tokenizer_filename)
}
};