use crate::WithDType; use cudarc; use cudarc::cudnn::safe::{ConvForward, Cudnn}; use cudarc::driver::{CudaSlice, CudaView, DeviceRepr, ValidAsZeroBits}; use std::cell::RefCell; use std::collections::HashMap; use std::sync::Arc; // The cudnn handles are stored per thread here rather than on the CudaDevice as they are neither // send nor sync. thread_local! { static CUDNN: RefCell>> = HashMap::new().into(); } impl From for crate::Error { fn from(err: cudarc::cudnn::CudnnError) -> Self { crate::Error::wrap(err) } } impl From for crate::Error { fn from(err: cudarc::driver::DriverError) -> Self { crate::Error::wrap(err) } } pub(crate) fn launch_conv2d< T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType, Y: cudarc::cudnn::CudnnDataType, >( src: &CudaView, src_l: &crate::Layout, filter: &CudaView, dst: &mut CudaSlice, params: &crate::conv::ParamsConv2D, dev: &crate::cuda_backend::CudaDevice, ) -> crate::Result<()> { use crate::conv::CudnnFwdAlgo as CandleAlgo; use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A; let device_id = dev.id(); let cudnn = CUDNN.with(|cudnn| { if let Some(cudnn) = cudnn.borrow().get(&device_id) { return Ok(cudnn.clone()); } let c = Cudnn::new(dev.cuda_device()); if let Ok(c) = &c { cudnn.borrow_mut().insert(device_id, c.clone()); } c })?; let conv = cudnn.create_conv2d::( /* pad */ [params.padding as i32, params.padding as i32], /* stride */ [params.stride as i32, params.stride as i32], /* dilation */ [params.dilation as i32, params.dilation as i32], cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION, )?; let x_shape = [ params.b_size as i32, params.c_in as i32, params.i_h as i32, params.i_w as i32, ]; // Note that `src` already starts at the proper offset. let x = if src_l.is_contiguous() { cudnn.create_4d_tensor::( cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, x_shape, )? } else { let s = src_l.stride(); cudnn.create_4d_tensor_ex::( x_shape, [s[0] as i32, s[1] as i32, s[2] as i32, s[3] as i32], )? }; let w = cudnn.create_4d_filter::( cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, [ params.c_out as i32, params.c_in as i32, params.k_h as i32, params.k_w as i32, ], )?; let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32); let y = cudnn.create_4d_tensor::( cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, [params.b_size as i32, params.c_out as i32, h_out, w_out], )?; let conv2d = ConvForward { conv: &conv, x: &x, w: &w, y: &y, }; let alg = match params.cudnn_fwd_algo { None => conv2d.pick_algorithm()?, Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, Some(CandleAlgo::ImplicitPrecompGemm) => { A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM } Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM, Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT, Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT, }; let workspace_size = conv2d.get_workspace_size(alg)?; let mut workspace = dev.cuda_device().alloc_zeros::(workspace_size)?; unsafe { conv2d.launch::, _, _, _>( alg, Some(&mut workspace), (T::one(), T::zero()), src, filter, dst, )?; } Ok(()) }