summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/backend.rs8
-rw-r--r--candle-core/src/conv.rs29
-rw-r--r--candle-core/src/cpu_backend.rs25
-rw-r--r--candle-core/src/cuda_backend.rs10
-rw-r--r--candle-core/src/dummy_cuda_backend.rs10
-rw-r--r--candle-core/src/storage.rs27
-rw-r--r--candle-core/src/tensor.rs30
7 files changed, 137 insertions, 2 deletions
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index c3f8aa3c..a8e5ac52 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -37,6 +37,14 @@ pub trait BackendStorage: Sized {
_params: &crate::conv::ParamsConv1D,
) -> Result<Self>;
+ fn conv2d(
+ &self,
+ _l: &Layout,
+ _kernel: &Self,
+ _kernel_l: &Layout,
+ _params: &crate::conv::ParamsConv2D,
+ ) -> Result<Self>;
+
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 4cf9d0ad..30799459 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -25,3 +25,32 @@ impl ParamsConv1D {
}
}
}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ParamsConv2D {
+ pub(crate) b_size: usize,
+ pub(crate) i_h: usize,
+ pub(crate) i_w: usize,
+ pub(crate) k_h: usize,
+ pub(crate) k_w: usize,
+ pub(crate) c_out: usize,
+ pub(crate) c_in: usize,
+ pub(crate) padding: usize,
+ pub(crate) stride: usize,
+}
+
+impl ParamsConv2D {
+ pub(crate) fn out_h(&self) -> usize {
+ let dilation = 1;
+ (self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1
+ }
+
+ pub(crate) fn out_w(&self) -> usize {
+ let dilation = 1;
+ (self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1
+ }
+
+ pub(crate) fn out_dims(&self) -> Vec<usize> {
+ vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
+ }
+}
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index a04ed9a0..c997d767 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1033,6 +1033,21 @@ impl<'a> Map2 for Conv1D<'a> {
}
}
+struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
+
+impl<'a> Map2 for Conv2D<'a> {
+ const OP: &'static str = "conv2d";
+ fn f<T: 'static + num_traits::NumAssign + Copy>(
+ &self,
+ _inp: &[T],
+ _inp_l: &Layout,
+ _k: &[T],
+ _k_l: &Layout,
+ ) -> Result<Vec<T>> {
+ todo!()
+ }
+}
+
struct MatMul((usize, usize, usize, usize));
impl MatMul {
@@ -1804,6 +1819,16 @@ impl BackendStorage for CpuStorage {
Conv1D(params).map(self, l, kernel, kernel_l)
}
+ fn conv2d(
+ &self,
+ l: &Layout,
+ kernel: &Self,
+ kernel_l: &Layout,
+ params: &crate::conv::ParamsConv2D,
+ ) -> Result<Self> {
+ Conv2D(params).map(self, l, kernel, kernel_l)
+ }
+
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
match ids {
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 3c37373a..727ea073 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1381,6 +1381,16 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
+ fn conv2d(
+ &self,
+ _l: &Layout,
+ _kernel: &Self,
+ _kernel_l: &Layout,
+ _params: &crate::conv::ParamsConv2D,
+ ) -> Result<Self> {
+ todo!()
+ }
+
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
todo!()
}
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 99cb7c4e..ae4dd09f 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -75,6 +75,16 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
+ fn conv2d(
+ &self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: &crate::conv::ParamsConv2D,
+ ) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index b4fa02e8..3ed38e6a 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -266,6 +266,33 @@ impl Storage {
}
}
+ pub(crate) fn conv2d(
+ &self,
+ l: &Layout,
+ kernel: &Self,
+ kernel_l: &Layout,
+ params: &crate::conv::ParamsConv2D,
+ ) -> Result<Self> {
+ self.same_device(kernel, "conv2d")?;
+ self.same_dtype(kernel, "conv2d")?;
+ match (self, &kernel) {
+ (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
+ let s = inp.conv2d(l, kernel, kernel_l, params)?;
+ Ok(Self::Cpu(s))
+ }
+ (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
+ let s = inp.conv2d(l, kernel, kernel_l, params)?;
+ Ok(Self::Cuda(s))
+ }
+ (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+ lhs: lhs.device().location(),
+ rhs: rhs.device().location(),
+ op: "conv2d",
+ }
+ .bt()),
+ }
+ }
+
pub(crate) fn avg_pool2d(
&self,
layout: &Layout,
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index ffa4bf8c..adba7376 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -817,8 +817,34 @@ impl Tensor {
Ok(from_storage(storage, out_dims, op, false))
}
- pub fn conv2d(&self, _kernel: &Self, _padding: usize, _stride: usize) -> Result<Self> {
- todo!()
+ pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
+ let (b_size, c_in, i_h, i_w) = self.dims4()?;
+ let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
+ if c_in != c_in_k {
+ crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
+ }
+ let params = crate::conv::ParamsConv2D {
+ b_size,
+ i_h,
+ i_w,
+ k_h,
+ k_w,
+ c_out,
+ c_in,
+ padding,
+ stride,
+ };
+ let storage =
+ self.storage()
+ .conv2d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
+ let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
+ arg,
+ kernel,
+ padding,
+ stride,
+ });
+ let out_dims = params.out_dims();
+ Ok(from_storage(storage, out_dims, op, false))
}
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {