summaryrefslogtreecommitdiff
path: root/candle-examples/examples/gemma/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/gemma/main.rs')
-rw-r--r--candle-examples/examples/gemma/main.rs5
1 files changed, 4 insertions, 1 deletions
diff --git a/candle-examples/examples/gemma/main.rs b/candle-examples/examples/gemma/main.rs
index a5f7d591..31c55618 100644
--- a/candle-examples/examples/gemma/main.rs
+++ b/candle-examples/examples/gemma/main.rs
@@ -193,6 +193,9 @@ struct Args {
/// The model to use.
#[arg(long, default_value = "2b")]
which: Which,
+
+ #[arg(long)]
+ use_flash_attn: bool,
}
fn main() -> Result<()> {
@@ -270,7 +273,7 @@ fn main() -> Result<()> {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
- let model = Model::new(&config, vb)?;
+ let model = Model::new(args.use_flash_attn, &config, vb)?;
println!("loaded the model in {:?}", start.elapsed());