summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-25 10:31:53 +0100
committerGitHub <noreply@github.com>2023-09-25 10:31:53 +0100
commitdc47224ab9d34c8f4ea0e6ce87d964a030eae89c (patch)
tree147fb6b3e9e61740604726791934079b74f03892 /candle-core/src
parent1ce7fe25436db6afc1076e5e2a678ec6e129d95f (diff)
downloadcandle-dc47224ab9d34c8f4ea0e6ce87d964a030eae89c.tar.gz
candle-dc47224ab9d34c8f4ea0e6ce87d964a030eae89c.tar.bz2
candle-dc47224ab9d34c8f4ea0e6ce87d964a030eae89c.zip
Override the default cudnn heuristics. (#957)
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/conv.rs16
-rw-r--r--candle-core/src/cudnn.rs18
2 files changed, 33 insertions, 1 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 1f3ef582..f92c05b2 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -25,6 +25,20 @@ impl ParamsConv1D {
}
}
+#[allow(unused)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum CudnnFwdAlgo {
+ ImplicitGemm,
+ ImplicitPrecompGemm,
+ Gemm,
+ Direct,
+ Fft,
+ FftTiling,
+ Winograd,
+ WinogradNonFused,
+ Count,
+}
+
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
pub(crate) b_size: usize,
@@ -37,6 +51,7 @@ pub struct ParamsConv2D {
pub(crate) padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
+ pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
impl ParamsConv2D {
@@ -188,6 +203,7 @@ impl Tensor {
padding,
stride,
dilation,
+ cudnn_fwd_algo: None,
};
if groups == 1 {
self.conv2d_single_group(kernel, &params)
diff --git a/candle-core/src/cudnn.rs b/candle-core/src/cudnn.rs
index dd466ba2..0c149cd0 100644
--- a/candle-core/src/cudnn.rs
+++ b/candle-core/src/cudnn.rs
@@ -34,6 +34,9 @@ pub(crate) fn launch_conv2d<
params: &crate::conv::ParamsConv2D,
dev: &crate::cuda_backend::CudaDevice,
) -> crate::Result<()> {
+ use crate::conv::CudnnFwdAlgo as CandleAlgo;
+ use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
+
let device_id = dev.id();
let cudnn = CUDNN.with(|cudnn| {
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
@@ -90,7 +93,20 @@ pub(crate) fn launch_conv2d<
w: &w,
y: &y,
};
- let alg = conv2d.pick_algorithm()?;
+ let alg = match params.cudnn_fwd_algo {
+ None => conv2d.pick_algorithm()?,
+ Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
+ Some(CandleAlgo::ImplicitPrecompGemm) => {
+ A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
+ }
+ Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
+ Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
+ Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
+ Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
+ Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
+ Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
+ Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
+ };
let workspace_size = conv2d.get_workspace_size(alg)?;
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
unsafe {