diff options
Diffstat (limited to 'candle-core/src/conv.rs')
-rw-r--r-- | candle-core/src/conv.rs | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index 1f3ef582..f92c05b2 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -25,6 +25,20 @@ impl ParamsConv1D { } } +#[allow(unused)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum CudnnFwdAlgo { + ImplicitGemm, + ImplicitPrecompGemm, + Gemm, + Direct, + Fft, + FftTiling, + Winograd, + WinogradNonFused, + Count, +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct ParamsConv2D { pub(crate) b_size: usize, @@ -37,6 +51,7 @@ pub struct ParamsConv2D { pub(crate) padding: usize, pub(crate) stride: usize, pub(crate) dilation: usize, + pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>, } impl ParamsConv2D { @@ -188,6 +203,7 @@ impl Tensor { padding, stride, dilation, + cudnn_fwd_algo: None, }; if groups == 1 { self.conv2d_single_group(kernel, ¶ms) |