diff options
-rw-r--r-- | candle-core/src/cuda_backend/cudnn.rs | 11 |
-rw-r--r-- | candle-core/src/cuda_backend/mod.rs | 13 |
2 files changed, 14 insertions, 10 deletions
diff --git a/candle-core/src/cuda_backend/cudnn.rs b/candle-core/src/cuda_backend/cudnn.rs index d604863d..f5b4db90 100644 --- a/candle-core/src/cuda_backend/cudnn.rs +++ b/candle-core/src/cuda_backend/cudnn.rs @@ -26,6 +26,7 @@ impl From<cudarc::driver::DriverError> for crate::Error { pub(crate) fn launch_conv2d< T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType, + Y: cudarc::cudnn::CudnnDataType, >( src: &CudaView<T>, src_l: &crate::Layout, @@ -48,7 +49,7 @@ pub(crate) fn launch_conv2d< } c })?; - let conv = cudnn.create_conv2d::<T>( + let conv = cudnn.create_conv2d::<Y>( /* pad */ [params.padding as i32, params.padding as i32], /* stride */ [params.stride as i32, params.stride as i32], /* dilation */ [params.dilation as i32, params.dilation as i32], @@ -62,18 +63,18 @@ pub(crate) fn launch_conv2d< ]; // Note that `src` already starts at the proper offset. let x = if src_l.is_contiguous() { - cudnn.create_4d_tensor( + cudnn.create_4d_tensor::<T>( cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, x_shape, )? } else { let s = src_l.stride(); - cudnn.create_4d_tensor_ex( + cudnn.create_4d_tensor_ex::<T>( x_shape, [s[0] as i32, s[1] as i32, s[2] as i32, s[3] as i32], )? 
}; - let w = cudnn.create_4d_filter( + let w = cudnn.create_4d_filter::<T>( cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, [ params.c_out as i32, @@ -83,7 +84,7 @@ pub(crate) fn launch_conv2d< ], )?; let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32); - let y = cudnn.create_4d_tensor( + let y = cudnn.create_4d_tensor::<T>( cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, [params.b_size as i32, params.c_out as i32, h_out, w_out], )?; diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 07bb1785..f14e00d5 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -1522,7 +1522,7 @@ impl BackendStorage for CudaStorage { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::<u8>(dst_el) }.w()?; - crate::cudnn::launch_conv2d::<u8>(inp, inp_l, k, &mut out, params, &device) + crate::cudnn::launch_conv2d::<u8, u8>(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::U8(out) } @@ -1530,7 +1530,10 @@ impl BackendStorage for CudaStorage { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::<bf16>(dst_el) }.w()?; - crate::cudnn::launch_conv2d::<bf16>(inp, inp_l, k, &mut out, params, &device) + // Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16" + // version. 
+ // https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88 + crate::cudnn::launch_conv2d::<bf16, f32>(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::BF16(out) } @@ -1538,7 +1541,7 @@ impl BackendStorage for CudaStorage { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::<f16>(dst_el) }.w()?; - crate::cudnn::launch_conv2d::<f16>(inp, inp_l, k, &mut out, params, &device) + crate::cudnn::launch_conv2d::<f16, f16>(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::F16(out) } @@ -1546,7 +1549,7 @@ impl BackendStorage for CudaStorage { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::<f32>(dst_el) }.w()?; - crate::cudnn::launch_conv2d::<f32>(inp, inp_l, k, &mut out, params, &device) + crate::cudnn::launch_conv2d::<f32, f32>(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::F32(out) } @@ -1554,7 +1557,7 @@ impl BackendStorage for CudaStorage { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::<f64>(dst_el) }.w()?; - crate::cudnn::launch_conv2d::<f64>(inp, inp_l, k, &mut out, params, &device) + crate::cudnn::launch_conv2d::<f64, f64>(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::F64(out) } |