diff options
Diffstat (limited to 'candle-examples/examples/gemma/main.rs')
-rw-r--r-- | candle-examples/examples/gemma/main.rs | 5 |
1 file changed, 4 insertions, 1 deletion
diff --git a/candle-examples/examples/gemma/main.rs b/candle-examples/examples/gemma/main.rs
index a5f7d591..31c55618 100644
--- a/candle-examples/examples/gemma/main.rs
+++ b/candle-examples/examples/gemma/main.rs
@@ -193,6 +193,9 @@ struct Args {
     /// The model to use.
     #[arg(long, default_value = "2b")]
     which: Which,
+
+    #[arg(long)]
+    use_flash_attn: bool,
 }
 
 fn main() -> Result<()> {
@@ -270,7 +273,7 @@ fn main() -> Result<()> {
         DType::F32
     };
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-    let model = Model::new(&config, vb)?;
+    let model = Model::new(args.use_flash_attn, &config, vb)?;
 
     println!("loaded the model in {:?}", start.elapsed());
 