summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/cuda_backend/cudnn.rs11
-rw-r--r--candle-core/src/cuda_backend/mod.rs13
2 files changed, 14 insertions, 10 deletions
diff --git a/candle-core/src/cuda_backend/cudnn.rs b/candle-core/src/cuda_backend/cudnn.rs
index d604863d..f5b4db90 100644
--- a/candle-core/src/cuda_backend/cudnn.rs
+++ b/candle-core/src/cuda_backend/cudnn.rs
@@ -26,6 +26,7 @@ impl From<cudarc::driver::DriverError> for crate::Error {
pub(crate) fn launch_conv2d<
T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
+ Y: cudarc::cudnn::CudnnDataType,
>(
src: &CudaView<T>,
src_l: &crate::Layout,
@@ -48,7 +49,7 @@ pub(crate) fn launch_conv2d<
}
c
})?;
- let conv = cudnn.create_conv2d::<T>(
+ let conv = cudnn.create_conv2d::<Y>(
/* pad */ [params.padding as i32, params.padding as i32],
/* stride */ [params.stride as i32, params.stride as i32],
/* dilation */ [params.dilation as i32, params.dilation as i32],
@@ -62,18 +63,18 @@ pub(crate) fn launch_conv2d<
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
- cudnn.create_4d_tensor(
+ cudnn.create_4d_tensor::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
x_shape,
)?
} else {
let s = src_l.stride();
- cudnn.create_4d_tensor_ex(
+ cudnn.create_4d_tensor_ex::<T>(
x_shape,
[s[0] as i32, s[1] as i32, s[2] as i32, s[3] as i32],
)?
};
- let w = cudnn.create_4d_filter(
+ let w = cudnn.create_4d_filter::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[
params.c_out as i32,
@@ -83,7 +84,7 @@ pub(crate) fn launch_conv2d<
],
)?;
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
- let y = cudnn.create_4d_tensor(
+ let y = cudnn.create_4d_tensor::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[params.b_size as i32, params.c_out as i32, h_out, w_out],
)?;
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index 07bb1785..f14e00d5 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -1522,7 +1522,7 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<u8>(dst_el) }.w()?;
- crate::cudnn::launch_conv2d::<u8>(inp, inp_l, k, &mut out, params, &device)
+ crate::cudnn::launch_conv2d::<u8, u8>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::U8(out)
}
@@ -1530,7 +1530,10 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<bf16>(dst_el) }.w()?;
- crate::cudnn::launch_conv2d::<bf16>(inp, inp_l, k, &mut out, params, &device)
+ // Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16"
+ // version.
+ // https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88
+ crate::cudnn::launch_conv2d::<bf16, f32>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::BF16(out)
}
@@ -1538,7 +1541,7 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f16>(dst_el) }.w()?;
- crate::cudnn::launch_conv2d::<f16>(inp, inp_l, k, &mut out, params, &device)
+ crate::cudnn::launch_conv2d::<f16, f16>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F16(out)
}
@@ -1546,7 +1549,7 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f32>(dst_el) }.w()?;
- crate::cudnn::launch_conv2d::<f32>(inp, inp_l, k, &mut out, params, &device)
+ crate::cudnn::launch_conv2d::<f32, f32>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F32(out)
}
@@ -1554,7 +1557,7 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f64>(dst_el) }.w()?;
- crate::cudnn::launch_conv2d::<f64>(inp, inp_l, k, &mut out, params, &device)
+ crate::cudnn::launch_conv2d::<f64, f64>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F64(out)
}