author | Santiago Medina <santiagm08@gmail.com> | 2024-04-04 22:03:33 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-05 07:03:33 +0200 |
commit | ace282e5c2ef24ca2fb90683babb852936d4df17 (patch) | |
tree | e61e2223906b6f0cbf37c5b3af30312e4c174293 | |
parent | c87381fc9643ca15648c2e8379e44a596ba1854b (diff) | |
Add flag to run Moondream in f16 precision (#2015)
* moondream implementation
* add moondream example
* change config default activation
* Add assets and integrate phi mixformer with example
* Make use of kv cache and fix seq_len bug; Clean up example code
* Add README link to example
* Remove pos_embed scaling; Remove assets; Add to README; Expand VisionConfig
* Delete image
* Use apply instead of forward
* Use latest release special token; Fix token/s accuracy; Use GeluPytorchTanh in VisionConfig v2
* Add flag to use f16
* Avoid breaking the quantized version on cuda.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
-rw-r--r-- | candle-examples/examples/moondream/main.rs | 11 |
1 file changed, 10 insertions(+), 1 deletion(-)
```diff
diff --git a/candle-examples/examples/moondream/main.rs b/candle-examples/examples/moondream/main.rs
index dfd83037..c7500ed9 100644
--- a/candle-examples/examples/moondream/main.rs
+++ b/candle-examples/examples/moondream/main.rs
@@ -194,6 +194,10 @@ struct Args {
     #[arg(long)]
     quantized: bool,
 
+    /// Use f16 precision for all the computations rather than f32.
+    #[arg(long)]
+    f16: bool,
+
     #[arg(long)]
     model_file: Option<String>,
 
@@ -283,7 +287,12 @@ async fn main() -> anyhow::Result<()> {
     let start = std::time::Instant::now();
     let device = candle_examples::device(args.cpu)?;
     let config = moondream::Config::v2();
-    let dtype = if device.is_cuda() && !args.quantized {
+    let dtype = if args.quantized {
+        if args.f16 {
+            anyhow::bail!("Quantized model does not support f16");
+        }
+        DType::F32
+    } else if device.is_cuda() || args.f16 {
         DType::F16
     } else {
         DType::F32
```
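For readability, here is a minimal standalone sketch of the dtype selection this commit introduces. The `Args` struct, `DType` enum, and `select_dtype` helper below are simplified stand-ins for the example's clap-derived arguments and `candle_core::DType`, not code from the repository:

```rust
// Simplified stand-in for the clap-derived Args in the moondream example.
struct Args {
    quantized: bool,
    f16: bool,
    cpu: bool,
}

// Stand-in for candle_core::DType; only the two variants used here.
#[derive(Debug, PartialEq)]
enum DType {
    F16,
    F32,
}

// Mirrors the dtype selection added by this commit: the quantized path
// always runs in f32 and rejects --f16; otherwise f16 is picked on CUDA
// or whenever --f16 is passed explicitly.
fn select_dtype(args: &Args, is_cuda: bool) -> Result<DType, String> {
    if args.quantized {
        if args.f16 {
            return Err("Quantized model does not support f16".to_string());
        }
        Ok(DType::F32)
    } else if is_cuda || args.f16 {
        Ok(DType::F16)
    } else {
        Ok(DType::F32)
    }
}

fn main() {
    // --f16 now forces f16 even on CPU, which the old
    // `device.is_cuda() && !args.quantized` condition never did.
    let args = Args { quantized: false, f16: true, cpu: true };
    assert_eq!(select_dtype(&args, !args.cpu), Ok(DType::F16));

    // --quantized --f16 is rejected instead of silently running in f32.
    let args = Args { quantized: true, f16: true, cpu: false };
    assert!(select_dtype(&args, !args.cpu).is_err());
}
```

In the actual example the rejection is raised with `anyhow::bail!` and the quantized path stays on f32, which lines up with the last commit message bullet about not breaking the quantized version on CUDA.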