summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/quantized/cuda.rs2
-rw-r--r--candle-examples/examples/mistral/main.rs7
-rw-r--r--candle-examples/examples/quantized/main.rs6
3 files changed, 11 insertions, 4 deletions
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index a8f0d622..64404beb 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -10,7 +10,7 @@ pub struct QCudaStorage {
device: CudaDevice,
}
-static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(true);
+static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
pub fn set_force_dmmv(f: bool) {
FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed)
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index a972279c..c00af3fe 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -196,6 +196,10 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
+
+ /// Use the slower dmmv cuda kernel.
+ #[arg(long)]
+ force_dmmv: bool,
}
fn main() -> Result<()> {
@@ -203,6 +207,9 @@ fn main() -> Result<()> {
use tracing_subscriber::prelude::*;
let args = Args::parse();
+ #[cfg(feature = "cuda")]
+ candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
+
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 3cabc3a4..b03768ed 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -236,9 +236,9 @@ struct Args {
#[arg(long)]
gqa: Option<usize>,
- /// Use the (experimental) fast cuda kernels.
+ /// Use the slower dmmv cuda kernel.
#[arg(long)]
- fast_cuda: bool,
+ force_dmmv: bool,
}
impl Args {
@@ -347,7 +347,7 @@ fn main() -> anyhow::Result<()> {
let args = Args::parse();
#[cfg(feature = "cuda")]
- candle::quantized::cuda::set_force_dmmv(!args.fast_cuda);
+ candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
let temperature = if args.temperature == 0. {
None