author     Laurent Mazare <laurent.mazare@gmail.com>   2023-07-05 13:06:33 +0100
committer  GitHub <noreply@github.com>                 2023-07-05 13:06:33 +0100
commit     93896f6596e44285f6250f4966ada8c08fa85f09 (patch)
tree       fee5a01b56231a6d1472fd925f76c73aa8b93ac0 /candle-core
parent     d8f75ceeaa4702b641a9f71ec348fc54a32f4cd7 (diff)
parent     bce28ab7938b27931fd51e59c8bcad37038e0337 (diff)
Merge branch 'main' into upgrade_bert
Diffstat (limited to 'candle-core')
-rw-r--r--  candle-core/src/backprop.rs             8
-rw-r--r--  candle-core/src/conv.rs                27
-rw-r--r--  candle-core/src/cpu_backend.rs         68
-rw-r--r--  candle-core/src/cuda_backend.rs        10
-rw-r--r--  candle-core/src/dummy_cuda_backend.rs  10
-rw-r--r--  candle-core/src/lib.rs                  1
-rw-r--r--  candle-core/src/op.rs                   8
-rw-r--r--  candle-core/src/storage.rs             26
-rw-r--r--  candle-core/src/tensor.rs              38
9 files changed, 194 insertions, 2 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 45448505..a44f732f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -33,7 +33,12 @@ impl Tensor {
track_grad |= tg;
nodes
}
- Op::Add(lhs, rhs)
+ Op::Conv1D {
+ arg: lhs,
+ kernel: rhs,
+ ..
+ }
+ | Op::Add(lhs, rhs)
| Op::Mul(lhs, rhs)
| Op::Sub(lhs, rhs)
| Op::Div(lhs, rhs)
@@ -147,6 +152,7 @@ impl Tensor {
let f_grad = pred.where_cond(&zeros, &grad)?;
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
+ Op::Conv1D { .. } => return Err(Error::BackwardNotSupported { op: "conv1d" }),
Op::Embedding(_lhs, _rhs) => {
return Err(Error::BackwardNotSupported { op: "embedding" })
}
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
new file mode 100644
index 00000000..041bb6fb
--- /dev/null
+++ b/candle-core/src/conv.rs
@@ -0,0 +1,27 @@
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) struct ParamsConv1D {
+ pub(crate) b_size: Option<usize>,
+ // Maybe we should have a version without l_in as this bit depends on the input and not only on
+ // the weights.
+ pub(crate) l_in: usize,
+ pub(crate) c_out: usize,
+ pub(crate) c_in: usize,
+ pub(crate) k_size: usize,
+ pub(crate) padding: usize,
+ pub(crate) stride: usize,
+}
+
+impl ParamsConv1D {
+ pub(crate) fn l_out(&self) -> usize {
+ let dilation = 1;
+ (self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
+ }
+
+ pub(crate) fn out_dims(&self) -> Vec<usize> {
+ let l_out = self.l_out();
+ match self.b_size {
+ None => vec![self.c_out, l_out],
+ Some(n) => vec![n, self.c_out, l_out],
+ }
+ }
+}
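
Note on the new conv.rs: l_out() is the standard convolution output-length formula with dilation hard-coded to 1. A minimal sketch of the arithmetic (the parameter values are made up for illustration, and ParamsConv1D is crate-internal, so this only compiles inside candle-core):

    let params = crate::conv::ParamsConv1D {
        b_size: Some(2),
        l_in: 10,
        c_out: 4,
        c_in: 3,
        k_size: 3,
        padding: 0,
        stride: 1,
    };
    // (10 + 2*0 - 1*(3 - 1) - 1) / 1 + 1 = 8
    assert_eq!(params.l_out(), 8);
    assert_eq!(params.out_dims(), vec![2, 4, 8]);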
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 0871175f..b2345756 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -202,6 +202,64 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
}
}
+struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
+
+impl<'a> Map2 for Conv1D<'a> {
+ const OP: &'static str = "conv1d";
+ fn f<T: 'static + num_traits::NumAssign + Copy>(
+ &self,
+ inp: &[T],
+ inp_l: &Layout,
+ k: &[T],
+ k_l: &Layout,
+ ) -> Result<Vec<T>> {
+ // TODO: Optimize this (proper algorithm, simd, multithread, remove bound checks, etc).
+ let p = self.0;
+ let inp = &inp[inp_l.start_offset()..];
+ let k = &k[k_l.start_offset()..];
+ let inp_stride = inp_l.stride();
+ let (inp_stride0, inp_stride) = if inp_stride.len() == 3 {
+ (inp_stride[0], &inp_stride[1..])
+ } else {
+ (0, inp_stride) // This value never gets used anyway
+ };
+ let k_stride = k_l.stride();
+ let k_over_2 = p.k_size / 2;
+ let l_out = p.l_out();
+ let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
+ let mut dst = vec![T::zero(); dst_elems];
+ // The output shape is [b_size, c_out, l_out]
+ for b_idx in 0..p.b_size.unwrap_or(1) {
+ let inp_idx = b_idx * inp_stride0;
+ let dst_idx = b_idx * p.c_out * l_out;
+ for dst_c_idx in 0..p.c_out {
+ let dst_idx = dst_idx + dst_c_idx * l_out;
+ for dst_l in 0..l_out {
+ let dst_idx = dst_idx + dst_l;
+ let mut d = T::zero();
+ for offset in 0..p.k_size {
+ let src_l_plus = p.stride * dst_l + offset;
+ // inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
+ if k_over_2 <= src_l_plus && src_l_plus < k_over_2 + p.l_in {
+ let src_l = src_l_plus - k_over_2;
+ for src_c_idx in 0..p.c_in {
+ let inp_idx =
+ inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
+ let k_idx = dst_c_idx * k_stride[0]
+ + src_c_idx * k_stride[1]
+ + offset * k_stride[2];
+ d += inp[inp_idx] * k[k_idx]
+ }
+ }
+ }
+ dst[dst_idx] = d
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
struct MatMul((usize, usize, usize, usize));
impl Map2 for MatMul {
@@ -627,6 +685,16 @@ impl CpuStorage {
WCond(pred, layout).map(t, t_l, f, f_l)
}
+ pub(crate) fn conv1d(
+ &self,
+ l: &Layout,
+ kernel: &Self,
+ kernel_l: &Layout,
+ params: &crate::conv::ParamsConv1D,
+ ) -> Result<Self> {
+ Conv1D(params).map(self, l, kernel, kernel_l)
+ }
+
pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
let ids = self.as_slice::<u32>()?;
let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
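
Reading of the Conv1D kernel above: it is a naive direct convolution with no dilation, and the kernel is centred by k_size / 2. For each output element the loop nest accumulates

    dst[b, co, l] = sum over ci in 0..c_in and o in 0..k_size of
                    inp[b, ci, stride * l + o - k_size / 2] * k[co, ci, o]

and the k_over_2 bound check skips the terms whose input index would fall outside 0..l_in. Note that p.padding is not referenced in this CPU path; only l_out()/out_dims() use it.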
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 0c87004b..917655fc 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -801,6 +801,16 @@ impl CudaStorage {
Ok(Self { slice, device })
}
+ pub(crate) fn conv1d(
+ &self,
+ _l: &Layout,
+ _kernel: &Self,
+ _kernel_l: &Layout,
+ _params: &crate::conv::ParamsConv1D,
+ ) -> Result<Self> {
+ todo!()
+ }
+
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
let device = self.device().clone();
let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index b025eeab..0dbd8d54 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -100,6 +100,16 @@ impl CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
+ pub(crate) fn conv1d(
+ &self,
+ _l: &Layout,
+ _kernel: &Self,
+ _kernel_l: &Layout,
+ _params: &crate::conv::ParamsConv1D,
+ ) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 0d4c2a8d..2365a34d 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -1,4 +1,5 @@
mod backprop;
+mod conv;
mod cpu_backend;
#[cfg(feature = "cuda")]
mod cuda_backend;
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 860be0b3..ee57b325 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -12,6 +12,14 @@ pub(crate) enum Op {
Embedding(Tensor, Tensor),
WhereCond(Tensor, Tensor, Tensor),
+ #[allow(dead_code)]
+ Conv1D {
+ arg: Tensor,
+ kernel: Tensor,
+ padding: usize,
+ stride: usize,
+ },
+
Cat(Vec<Tensor>, usize),
#[allow(dead_code)] // add is currently unused.
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 4e630a58..53ea1544 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -144,6 +144,32 @@ impl Storage {
}
}
+ pub(crate) fn conv1d(
+ &self,
+ l: &Layout,
+ kernel: &Self,
+ kernel_l: &Layout,
+ params: &crate::conv::ParamsConv1D,
+ ) -> Result<Self> {
+ self.same_device(kernel, "conv1d")?;
+ self.same_dtype(kernel, "conv1d")?;
+ match (self, &kernel) {
+ (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
+ let s = inp.conv1d(l, kernel, kernel_l, params)?;
+ Ok(Self::Cpu(s))
+ }
+ (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
+ let s = inp.conv1d(l, kernel, kernel_l, params)?;
+ Ok(Self::Cuda(s))
+ }
+ (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+ lhs: lhs.device().location(),
+ rhs: rhs.device().location(),
+ op: "conv1d",
+ }),
+ }
+ }
+
pub(crate) fn where_cond(
&self,
layout: &Layout,
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index a468d879..95f663f0 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -326,7 +326,7 @@ impl Tensor {
}
let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
let data = S::cpu_storage_as_slice(cpu_storage)?;
- Ok::<_, Error>(data[0])
+ Ok::<_, Error>(data[self.layout().start_offset()])
};
match self.storage.as_ref() {
Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
@@ -432,6 +432,42 @@ impl Tensor {
Ok(from_storage(storage, dims, op, false))
}
+ pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
+ let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
+ let (b_size, c_in, l_in) = match *self.dims() {
+ [b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
+ [c_in, l_in] => (None, c_in, l_in),
+ _ => todo!("proper error message"),
+ };
+ if c_in != c_in_k {
+ todo!("proper error message")
+ }
+ let params = crate::conv::ParamsConv1D {
+ b_size,
+ l_in,
+ c_out,
+ c_in,
+ k_size,
+ padding,
+ stride,
+ };
+ let storage =
+ self.storage
+ .conv1d(self.layout(), &kernel.storage, kernel.layout(), &params)?;
+ let op = if self.track_op() || kernel.track_op() {
+ Some(Op::Conv1D {
+ arg: self.clone(),
+ kernel: kernel.clone(),
+ padding,
+ stride,
+ })
+ } else {
+ None
+ };
+ let out_dims = params.out_dims();
+ Ok(from_storage(storage, out_dims, op, false))
+ }
+
pub fn matmul(&self, rhs: &Self) -> Result<Self> {
let a_dims = self.shape().dims();
let b_dims = rhs.shape().dims();
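
For context, a hypothetical end-to-end use of the new Tensor::conv1d entry point. The crate name in the import, the zeros constructor, and the DType/Result items are assumptions about the surrounding API, not part of this diff; shapes follow the code above (input [b_size, c_in, l_in], kernel [c_out, c_in, k_size]):

    use candle::{DType, Device, Result, Tensor};

    fn conv1d_demo() -> Result<()> {
        let dev = Device::Cpu;
        let inp = Tensor::zeros((2, 3, 10), DType::F32, &dev)?; // [b_size, c_in, l_in]
        let k = Tensor::zeros((4, 3, 3), DType::F32, &dev)?; // [c_out, c_in, k_size]
        let out = inp.conv1d(&k, /* padding */ 0, /* stride */ 1)?;
        // l_out = (10 + 2*0 - (3 - 1) - 1) / 1 + 1 = 8
        assert_eq!(out.dims(), &[2, 4, 8]);
        Ok(())
    }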