diff options
Diffstat (limited to 'candle-examples')
-rw-r--r-- | candle-examples/examples/mistral/main.rs | 7 | ||||
-rw-r--r-- | candle-examples/examples/quantized/main.rs | 6 |
2 files changed, 10 insertions, 3 deletions
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs index a972279c..c00af3fe 100644 --- a/candle-examples/examples/mistral/main.rs +++ b/candle-examples/examples/mistral/main.rs @@ -196,6 +196,10 @@ struct Args { /// The context size to consider for the repeat penalty. #[arg(long, default_value_t = 64)] repeat_last_n: usize, + + /// Use the slower dmmv cuda kernel. + #[arg(long)] + force_dmmv: bool, } fn main() -> Result<()> { @@ -203,6 +207,9 @@ fn main() -> Result<()> { use tracing_subscriber::prelude::*; let args = Args::parse(); + #[cfg(feature = "cuda")] + candle::quantized::cuda::set_force_dmmv(args.force_dmmv); + let _guard = if args.tracing { let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); tracing_subscriber::registry().with(chrome_layer).init(); diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index 3cabc3a4..b03768ed 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -236,9 +236,9 @@ struct Args { #[arg(long)] gqa: Option<usize>, - /// Use the (experimental) fast cuda kernels. + /// Use the slower dmmv cuda kernel. #[arg(long)] - fast_cuda: bool, + force_dmmv: bool, } impl Args { @@ -347,7 +347,7 @@ fn main() -> anyhow::Result<()> { let args = Args::parse(); #[cfg(feature = "cuda")] - candle::quantized::cuda::set_force_dmmv(!args.fast_cuda); + candle::quantized::cuda::set_force_dmmv(args.force_dmmv); let temperature = if args.temperature == 0. { None |