author     Santiago Medina <santiagm08@gmail.com>    2024-04-04 22:03:33 -0700
committer  GitHub <noreply@github.com>               2024-04-05 07:03:33 +0200
commit     ace282e5c2ef24ca2fb90683babb852936d4df17 (patch)
tree       e61e2223906b6f0cbf37c5b3af30312e4c174293
parent     c87381fc9643ca15648c2e8379e44a596ba1854b (diff)
Add flag to run Moondream in f16 precision (#2015)
* moondream implementation
* add moondream example
* change config default activation
* Add assets and integrate phi mixformer with example
* Make use of kv cache and fix seq_len bug; Clean up example code
* Add README link to example
* Remove pos_embed scaling; Remove assets; Add to README; Expand VisionConfig
* Delete image
* Use apply instead of forward
* Use latest release special token; Fix token/s accuracy; Use GeluPytorchTanh in VisionConfig v2
* Add flag to use f16
* Avoid breaking the quantized version on cuda.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
-rw-r--r--    candle-examples/examples/moondream/main.rs    11
1 file changed, 10 insertions(+), 1 deletion(-)
diff --git a/candle-examples/examples/moondream/main.rs b/candle-examples/examples/moondream/main.rs
index dfd83037..c7500ed9 100644
--- a/candle-examples/examples/moondream/main.rs
+++ b/candle-examples/examples/moondream/main.rs
@@ -194,6 +194,10 @@ struct Args {
     #[arg(long)]
     quantized: bool,
+    /// Use f16 precision for all the computations rather than f32.
+    #[arg(long)]
+    f16: bool,
+
     #[arg(long)]
     model_file: Option<String>,
@@ -283,7 +287,12 @@ async fn main() -> anyhow::Result<()> {
     let start = std::time::Instant::now();
     let device = candle_examples::device(args.cpu)?;
     let config = moondream::Config::v2();
-    let dtype = if device.is_cuda() && !args.quantized {
+    let dtype = if args.quantized {
+        if args.f16 {
+            anyhow::bail!("Quantized model does not support f16");
+        }
+        DType::F32
+    } else if device.is_cuda() || args.f16 {
         DType::F16
     } else {
         DType::F32
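
For reference, the dtype selection introduced by this patch can be restated as a small standalone function. The sketch below is illustrative only; the select_dtype helper is hypothetical and not part of the patch, and it assumes the candle-core crate (imported here as candle_core) together with anyhow for error reporting.

use anyhow::bail;
use candle_core::DType;

// Hypothetical helper mirroring the patched logic: quantized models stay in
// f32 and reject an f16 request, otherwise f16 is used on CUDA or when
// explicitly asked for, and f32 remains the fallback everywhere else.
fn select_dtype(quantized: bool, f16: bool, is_cuda: bool) -> anyhow::Result<DType> {
    if quantized {
        if f16 {
            bail!("Quantized model does not support f16");
        }
        Ok(DType::F32)
    } else if is_cuda || f16 {
        Ok(DType::F16)
    } else {
        Ok(DType::F32)
    }
}

With the patch applied, the example accepts a new --f16 command-line flag (generated by clap from the f16: bool field), and combining it with --quantized produces the error shown above.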