diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-25 10:31:53 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-25 10:31:53 +0100 |
commit | dc47224ab9d34c8f4ea0e6ce87d964a030eae89c (patch) | |
tree | 147fb6b3e9e61740604726791934079b74f03892 /candle-core/src | |
parent | 1ce7fe25436db6afc1076e5e2a678ec6e129d95f (diff) | |
download | candle-dc47224ab9d34c8f4ea0e6ce87d964a030eae89c.tar.gz candle-dc47224ab9d34c8f4ea0e6ce87d964a030eae89c.tar.bz2 candle-dc47224ab9d34c8f4ea0e6ce87d964a030eae89c.zip |
Override the default cudnn heuristics. (#957)
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/conv.rs | 16 | ||||
-rw-r--r-- | candle-core/src/cudnn.rs | 18 |
2 files changed, 33 insertions, 1 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index 1f3ef582..f92c05b2 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -25,6 +25,20 @@ impl ParamsConv1D { } } +#[allow(unused)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum CudnnFwdAlgo { + ImplicitGemm, + ImplicitPrecompGemm, + Gemm, + Direct, + Fft, + FftTiling, + Winograd, + WinogradNonFused, + Count, +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct ParamsConv2D { pub(crate) b_size: usize, @@ -37,6 +51,7 @@ pub struct ParamsConv2D { pub(crate) padding: usize, pub(crate) stride: usize, pub(crate) dilation: usize, + pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>, } impl ParamsConv2D { @@ -188,6 +203,7 @@ impl Tensor { padding, stride, dilation, + cudnn_fwd_algo: None, }; if groups == 1 { self.conv2d_single_group(kernel, ¶ms) diff --git a/candle-core/src/cudnn.rs b/candle-core/src/cudnn.rs index dd466ba2..0c149cd0 100644 --- a/candle-core/src/cudnn.rs +++ b/candle-core/src/cudnn.rs @@ -34,6 +34,9 @@ pub(crate) fn launch_conv2d< params: &crate::conv::ParamsConv2D, dev: &crate::cuda_backend::CudaDevice, ) -> crate::Result<()> { + use crate::conv::CudnnFwdAlgo as CandleAlgo; + use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A; + let device_id = dev.id(); let cudnn = CUDNN.with(|cudnn| { if let Some(cudnn) = cudnn.borrow().get(&device_id) { @@ -90,7 +93,20 @@ pub(crate) fn launch_conv2d< w: &w, y: &y, }; - let alg = conv2d.pick_algorithm()?; + let alg = match params.cudnn_fwd_algo { + None => conv2d.pick_algorithm()?, + Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, + Some(CandleAlgo::ImplicitPrecompGemm) => { + A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM + } + Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM, + Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, + Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT, + Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, + Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, + Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, + Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT, + }; let workspace_size = conv2d.get_workspace_size(alg)?; let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?; unsafe { |