summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/cpu_backend/mod.rs4
-rw-r--r--candle-core/src/cuda_backend/mod.rs75
-rw-r--r--candle-kernels/src/conv.cu65
3 files changed, 140 insertions, 4 deletions
diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs
index 299b1e6e..18b73e9b 100644
--- a/candle-core/src/cpu_backend/mod.rs
+++ b/candle-core/src/cpu_backend/mod.rs
@@ -10,7 +10,7 @@ pub use utils::{
};
const USE_IM2COL_CONV1D: bool = true;
-const USE_IM2COL_CONV1D_TR: bool = true;
+const USE_COL2IM_CONV1D_TR: bool = true;
const USE_IM2COL_CONV2D: bool = true;
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@@ -2249,7 +2249,7 @@ impl BackendStorage for CpuStorage {
&& params.dilation == 1
&& params.padding == 0
&& params.output_padding == 0;
- if USE_IM2COL_CONV1D_TR && can_use_col2im {
+ if USE_COL2IM_CONV1D_TR && can_use_col2im {
let (b_size, c_in, l_in) = l.shape().dims3()?;
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
if !kernel_l.is_contiguous() {
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index 1ea9beaf..88f325f4 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -630,6 +630,31 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
+struct Col2Im1D {
+ stride: usize,
+}
+
+impl Map1 for Col2Im1D {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ col: &CudaSlice<T>,
+ dev: &CudaDevice,
+ l: &Layout,
+ ) -> Result<CudaSlice<T>> {
+ let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
+ let stride = self.stride;
+ let l_out = (l_in - 1) * stride + k_size;
+ let dst_el = b_size * c_out * l_out;
+ let mut im = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+
+ let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+ let params = (dst_el, l_out, l_in, c_out, k_size, stride, col, &mut im);
+ let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(im)
+ }
+}
+
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1366,9 +1391,55 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
+ const USE_COL2IM_CONV1D_TR: bool = true;
+
let device = self.device().clone();
- let slice =
- ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+ let can_use_col2im = kernel_l.is_contiguous()
+ && params.dilation == 1
+ && params.padding == 0
+ && params.output_padding == 0;
+ let slice = if USE_COL2IM_CONV1D_TR && can_use_col2im {
+ let (b_size, c_in, l_in) = l.shape().dims3()?;
+ let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
+ if !kernel_l.is_contiguous() {
+ crate::bail!(
+ "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
+ )
+ }
+ if c_in != c_in2 {
+ crate::bail!(
+ "convtr1d: shape mismatch on c_in {:?} {:?}",
+ l.shape(),
+ kernel_l.shape()
+ )
+ }
+ let col = {
+ // This merges the last two dimensions of the kernel together.
+ let kernel_l_mm = Layout::new(
+ (b_size, c_in, k_size * c_out).into(),
+ vec![0, k_size * c_out, 1],
+ kernel_l.start_offset(),
+ );
+ self.matmul(
+ kernel,
+ (
+ b_size,
+ /* m */ l_in,
+ /* n */ c_out * k_size,
+ /* k */ c_in,
+ ),
+ &l.transpose(1, 2)?,
+ &kernel_l_mm,
+ )?
+ };
+ let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
+ Col2Im1D {
+ stride: params.stride,
+ }
+ .map(&col.slice, &device, &col_l)?
+ } else {
+ ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?
+ };
Ok(Self { slice, device })
}
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
index fed920f1..fa834faa 100644
--- a/candle-kernels/src/conv.cu
+++ b/candle-kernels/src/conv.cu
@@ -98,6 +98,50 @@ __device__ void im2col1d(
}
template <typename T>
+__device__ void col2im1d(
+ const size_t dst_el,
+ const size_t l_out,
+ const size_t l_in,
+ const size_t c_out,
+ const size_t k_size,
+ const size_t stride,
+ const T *src,
+ T *dst
+) {
+ const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
+ // src: (b_size, l_in, c_out, l_k)
+ // dst: (b_size, c_out, l_out)
+ if (dst_i >= dst_el) {
+ return;
+ }
+
+ const size_t dst_s0 = c_out * l_out;
+ const size_t dst_s1 = l_out;
+ const size_t src_s0 = c_out * k_size * l_in;
+ const size_t src_s1 = c_out * k_size;
+ const size_t src_s2 = k_size;
+
+ size_t tmp_dst_i = dst_i;
+ const size_t b_idx = tmp_dst_i / dst_s0;
+ tmp_dst_i -= b_idx * dst_s0;
+ const size_t c_idx = tmp_dst_i / dst_s1;
+ tmp_dst_i -= c_idx * dst_s1;
+ const int l_out_idx = tmp_dst_i;
+
+ dst[dst_i] = static_cast<T>(0);
+
+ int l_in_idx = l_out_idx / stride;
+ int k0 = l_out_idx - l_in_idx * stride;
+ // l_out_idx = l_in_idx * stride + k0
+ for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) {
+ if (l_in_idx < l_in) {
+ const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0;
+ dst[dst_i] += src[src_i];
+ }
+ }
+}
+
+template <typename T>
__device__ void im2col(
const size_t dst_numel,
const size_t h_out,
@@ -542,6 +586,20 @@ extern "C" __global__ void FN_NAME( \
im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
} \
+#define COL2IM1D_OP(TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+ const size_t dst_el, \
+ const size_t l_out, \
+ const size_t l_in, \
+ const size_t c_out, \
+ const size_t k_size, \
+ const size_t stride, \
+ const TYPENAME *src, \
+ TYPENAME *dst \
+) { \
+ col2im1d<TYPENAME>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst); \
+} \
+
#define IM2COL_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t dst_numel, \
@@ -643,6 +701,7 @@ MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
IM2COL_OP(__nv_bfloat16, im2col_bf16)
IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
+COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16)
#endif
#if __CUDA_ARCH__ >= 530
@@ -655,6 +714,7 @@ MAX_POOL2D_OP(__half, max_pool2d_f16)
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
IM2COL_OP(__half, im2col_f16)
IM2COL1D_OP(__half, im2col1d_f16)
+COL2IM1D_OP(__half, col2im1d_f16)
#endif
CONV1D_OP(float, float, conv1d_f32)
@@ -701,3 +761,8 @@ IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(double, im2col1d_f64)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)
+
+COL2IM1D_OP(float, col2im1d_f32)
+COL2IM1D_OP(double, col2im1d_f64)
+COL2IM1D_OP(uint8_t, col2im1d_u8)
+COL2IM1D_OP(uint32_t, col2im1d_u32)