summaryrefslogtreecommitdiff
path: root/candle-core/src/conv.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/conv.rs')
-rw-r--r--candle-core/src/conv.rs16
1 files changed, 16 insertions, 0 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 1f3ef582..f92c05b2 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -25,6 +25,20 @@ impl ParamsConv1D {
}
}
+#[allow(unused)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum CudnnFwdAlgo {
+ ImplicitGemm,
+ ImplicitPrecompGemm,
+ Gemm,
+ Direct,
+ Fft,
+ FftTiling,
+ Winograd,
+ WinogradNonFused,
+ Count,
+}
+
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
pub(crate) b_size: usize,
@@ -37,6 +51,7 @@ pub struct ParamsConv2D {
pub(crate) padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
+ pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
impl ParamsConv2D {
@@ -188,6 +203,7 @@ impl Tensor {
padding,
stride,
dilation,
+ cudnn_fwd_algo: None,
};
if groups == 1 {
self.conv2d_single_group(kernel, &params)