summaryrefslogtreecommitdiff
path: root/candle-core
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core')
-rw-r--r--candle-core/Cargo.toml2
-rw-r--r--candle-core/examples/basics.rs2
-rw-r--r--candle-core/examples/cpu_benchmarks.rs4
-rw-r--r--candle-core/examples/cuda_basics.rs2
-rw-r--r--candle-core/src/conv.rs112
-rw-r--r--candle-core/src/tensor.rs70
-rw-r--r--candle-core/tests/conv_tests.rs14
7 files changed, 126 insertions, 80 deletions
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index b190c55e..3b3e4eb7 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -12,7 +12,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
-candle-kernels = { path = "../candle-kernels", version = "0.1.2", optional = true }
+candle-kernels = { path = "../candle-kernels", version = "0.1.3", optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
diff --git a/candle-core/examples/basics.rs b/candle-core/examples/basics.rs
index efce913a..9d4734de 100644
--- a/candle-core/examples/basics.rs
+++ b/candle-core/examples/basics.rs
@@ -11,7 +11,7 @@ fn main() -> Result<()> {
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
let start = std::time::Instant::now();
- let res = inp.conv2d(&w, 0, 1);
+ let res = inp.conv2d(&w, 0, 1, 1)?;
println!("{:?}", start.elapsed());
println!("{res:?}");
Ok(())
diff --git a/candle-core/examples/cpu_benchmarks.rs b/candle-core/examples/cpu_benchmarks.rs
index d7f60f81..1ebd9b75 100644
--- a/candle-core/examples/cpu_benchmarks.rs
+++ b/candle-core/examples/cpu_benchmarks.rs
@@ -40,7 +40,7 @@ impl Benchmark for Conv1d {
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
- d.0.conv1d(&d.1, 0, 1)
+ d.0.conv1d(&d.1, 0, 1, 1)
}
const ITERS: usize = 5;
@@ -59,7 +59,7 @@ impl Benchmark for Conv2d {
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
- d.0.conv2d(&d.1, 0, 1)
+ d.0.conv2d(&d.1, 0, 1, 1)
}
const ITERS: usize = 1;
diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs
index 12febb60..ac435488 100644
--- a/candle-core/examples/cuda_basics.rs
+++ b/candle-core/examples/cuda_basics.rs
@@ -11,7 +11,7 @@ fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
- let res = t.conv2d(&w, 1, 1)?;
+ let res = t.conv2d(&w, 1, 1, 1)?;
println!("{res:?}");
Ok(())
}
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index e3fea861..d4b7a76d 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -1,3 +1,5 @@
+use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
+
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
pub(crate) b_size: usize,
@@ -51,3 +53,113 @@ impl ParamsConv2D {
vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
}
}
+
+impl Tensor {
+ fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
+ let storage =
+ self.storage()
+ .conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
+ let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
+ arg,
+ kernel,
+ padding: params.padding,
+ stride: params.stride,
+ });
+ let out_dims = params.out_dims();
+ Ok(crate::tensor::from_storage(storage, out_dims, op, false))
+ }
+
+ /// Applies a 1D convolution over the input tensor.
+ pub fn conv1d(
+ &self,
+ kernel: &Self,
+ padding: usize,
+ stride: usize,
+ groups: usize,
+ ) -> Result<Self> {
+ let (c_out, c_in_k, k_size) = kernel.dims3()?;
+ let (b_size, c_in, l_in) = self.dims3()?;
+ if c_in != c_in_k * groups {
+ Err(Error::Conv1dInvalidArgs {
+ inp_shape: self.shape().clone(),
+ k_shape: kernel.shape().clone(),
+ padding,
+ stride,
+ msg: "the number of in-channels on the input doesn't match the kernel size",
+ }
+ .bt())?
+ }
+
+ let params = ParamsConv1D {
+ b_size,
+ l_in,
+ c_out,
+ c_in,
+ k_size,
+ padding,
+ stride,
+ };
+ if groups == 1 {
+ self.conv1d_single_group(kernel, &params)
+ } else {
+ let blocks = self.chunk(groups, 1)?;
+ let blocks = blocks
+ .iter()
+ .map(|block| block.conv1d_single_group(kernel, &params))
+ .collect::<Result<Vec<_>>>()?;
+ Tensor::cat(&blocks, 1)
+ }
+ }
+
+ fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
+ let storage =
+ self.storage()
+ .conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
+ let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
+ arg,
+ kernel,
+ padding: params.padding,
+ stride: params.stride,
+ });
+ let out_dims = params.out_dims();
+ Ok(crate::tensor::from_storage(storage, out_dims, op, false))
+ }
+
+ /// Applies a 2D convolution over the input tensor.
+ pub fn conv2d(
+ &self,
+ kernel: &Self,
+ padding: usize,
+ stride: usize,
+ groups: usize,
+ ) -> Result<Self> {
+ let (b_size, c_in, i_h, i_w) = self.dims4()?;
+ let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
+ if c_in != c_in_k * groups {
+ crate::bail!(
+ "in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})"
+ )
+ }
+ let params = ParamsConv2D {
+ b_size,
+ i_h,
+ i_w,
+ k_h,
+ k_w,
+ c_out,
+ c_in,
+ padding,
+ stride,
+ };
+ if groups == 1 {
+ self.conv2d_single_group(kernel, &params)
+ } else {
+ let blocks = self.chunk(groups, 1)?;
+ let blocks = blocks
+ .iter()
+ .map(|block| block.conv2d_single_group(kernel, &params))
+ .collect::<Result<Vec<_>>>()?;
+ Tensor::cat(&blocks, 1)
+ }
+ }
+}
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index a4b9795b..46f9c53f 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -124,7 +124,7 @@ macro_rules! broadcast_binary_op {
}
/// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides.
-fn from_storage<S: Into<Shape>>(
+pub(crate) fn from_storage<S: Into<Shape>>(
storage: Storage,
shape: S,
op: BackpropOp,
@@ -787,72 +787,6 @@ impl Tensor {
self.cmp(rhs, CmpOp::Le)
}
- /// Applies a 1D convolution over the input tensor.
- pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
- let (c_out, c_in_k, k_size) = kernel.dims3()?;
- let (b_size, c_in, l_in) = self.dims3()?;
- if c_in != c_in_k {
- Err(Error::Conv1dInvalidArgs {
- inp_shape: self.shape().clone(),
- k_shape: kernel.shape().clone(),
- padding,
- stride,
- msg: "the number of in-channels on the input doesn't match the kernel size",
- }
- .bt())?
- }
- let params = crate::conv::ParamsConv1D {
- b_size,
- l_in,
- c_out,
- c_in,
- k_size,
- padding,
- stride,
- };
- let storage =
- self.storage()
- .conv1d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
- let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
- arg,
- kernel,
- padding,
- stride,
- });
- let out_dims = params.out_dims();
- Ok(from_storage(storage, out_dims, op, false))
- }
-
- pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
- let (b_size, c_in, i_h, i_w) = self.dims4()?;
- let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
- if c_in != c_in_k {
- crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
- }
- let params = crate::conv::ParamsConv2D {
- b_size,
- i_h,
- i_w,
- k_h,
- k_w,
- c_out,
- c_in,
- padding,
- stride,
- };
- let storage =
- self.storage()
- .conv2d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
- let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
- arg,
- kernel,
- padding,
- stride,
- });
- let out_dims = params.out_dims();
- Ok(from_storage(storage, out_dims, op, false))
- }
-
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
@@ -1920,7 +1854,7 @@ impl Tensor {
}
}
- fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
+ pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap()
}
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index c777fec7..d09fa344 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -33,13 +33,13 @@ fn conv1d(dev: &Device) -> Result<()> {
dev,
)?
.reshape((2, 4, 3))?;
- let res = t.conv1d(&w, 0, 1)?;
+ let res = t.conv1d(&w, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
);
- let res = t.conv1d(&w, /*padding*/ 1, 1)?;
+ let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 5]);
// Same as pytorch default padding: use zeros.
assert_eq!(
@@ -52,13 +52,13 @@ fn conv1d(dev: &Device) -> Result<()> {
fn conv1d_small(dev: &Device) -> Result<()> {
let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
- let res = t.conv1d(&w, 0, 1)?;
+ let res = t.conv1d(&w, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 2]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.4056, -0.8689]
);
- let res = t.conv1d(&w, /*padding*/ 1, 1)?;
+ let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 4]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -109,7 +109,7 @@ fn conv2d(dev: &Device) -> Result<()> {
)?;
let t = t.reshape((1, 4, 5, 5))?;
let w = w.reshape((2, 4, 3, 3))?;
- let res = t.conv2d(&w, 0, 1)?;
+ let res = t.conv2d(&w, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -143,7 +143,7 @@ fn conv2d_small(dev: &Device) -> Result<()> {
let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?;
let t = t.reshape((1, 2, 3, 3))?;
let w = w.reshape((1, 2, 1, 1))?;
- let res = t.conv2d(&w, 0, 1)?;
+ let res = t.conv2d(&w, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -162,7 +162,7 @@ fn conv2d_smaller(dev: &Device) -> Result<()> {
let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?;
let t = t.reshape((1, 1, 3, 3))?;
let w = w.reshape((1, 1, 3, 3))?;
- let res = t.conv2d(&w, 0, 1)?;
+ let res = t.conv2d(&w, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 1, 1]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,