author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-05 13:06:33 +0100
committer | GitHub <noreply@github.com> | 2023-07-05 13:06:33 +0100
commit | 93896f6596e44285f6250f4966ada8c08fa85f09 (patch)
tree | fee5a01b56231a6d1472fd925f76c73aa8b93ac0 /candle-core/src
parent | d8f75ceeaa4702b641a9f71ec348fc54a32f4cd7 (diff)
parent | bce28ab7938b27931fd51e59c8bcad37038e0337 (diff)
Merge branch 'main' into upgrade_bert
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/backprop.rs | 8
-rw-r--r-- | candle-core/src/conv.rs | 27
-rw-r--r-- | candle-core/src/cpu_backend.rs | 68
-rw-r--r-- | candle-core/src/cuda_backend.rs | 10
-rw-r--r-- | candle-core/src/dummy_cuda_backend.rs | 10
-rw-r--r-- | candle-core/src/lib.rs | 1
-rw-r--r-- | candle-core/src/op.rs | 8
-rw-r--r-- | candle-core/src/storage.rs | 26
-rw-r--r-- | candle-core/src/tensor.rs | 38
9 files changed, 194 insertions, 2 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 45448505..a44f732f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -33,7 +33,12 @@ impl Tensor {
                 track_grad |= tg;
                 nodes
             }
-            Op::Add(lhs, rhs)
+            Op::Conv1D {
+                arg: lhs,
+                kernel: rhs,
+                ..
+            }
+            | Op::Add(lhs, rhs)
             | Op::Mul(lhs, rhs)
             | Op::Sub(lhs, rhs)
             | Op::Div(lhs, rhs)
@@ -147,6 +152,7 @@ impl Tensor {
                     let f_grad = pred.where_cond(&zeros, &grad)?;
                     *f_sum_grad = f_sum_grad.add(&f_grad)?;
                 }
+                Op::Conv1D { .. } => return Err(Error::BackwardNotSupported { op: "conv1d" }),
                 Op::Embedding(_lhs, _rhs) => {
                     return Err(Error::BackwardNotSupported { op: "embedding" })
                 }
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
new file mode 100644
index 00000000..041bb6fb
--- /dev/null
+++ b/candle-core/src/conv.rs
@@ -0,0 +1,27 @@
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) struct ParamsConv1D {
+    pub(crate) b_size: Option<usize>,
+    // Maybe we should have a version without l_in as this bit depends on the input and not only on
+    // the weights.
+    pub(crate) l_in: usize,
+    pub(crate) c_out: usize,
+    pub(crate) c_in: usize,
+    pub(crate) k_size: usize,
+    pub(crate) padding: usize,
+    pub(crate) stride: usize,
+}
+
+impl ParamsConv1D {
+    pub(crate) fn l_out(&self) -> usize {
+        let dilation = 1;
+        (self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
+    }
+
+    pub(crate) fn out_dims(&self) -> Vec<usize> {
+        let l_out = self.l_out();
+        match self.b_size {
+            None => vec![self.c_out, l_out],
+            Some(n) => vec![n, self.c_out, l_out],
+        }
+    }
+}
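The `l_out` formula in `conv.rs` is the standard output-length rule for a strided, padded 1D convolution, with dilation hardwired to 1. As a quick sanity check, here is a minimal standalone sketch of the same arithmetic; the example values in `main` are hypothetical and chosen only to illustrate it:

```rust
// Standalone copy of the ParamsConv1D::l_out arithmetic, for illustration only.
fn l_out(l_in: usize, padding: usize, k_size: usize, stride: usize) -> usize {
    let dilation = 1; // hardwired to 1, as in conv.rs above
    (l_in + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1
}

fn main() {
    // k_size = 3 with padding = 1 and stride = 1 preserves the length ("same" conv).
    assert_eq!(l_out(10, 1, 3, 1), 10);
    // No padding, stride = 2: (10 - 2 - 1) / 2 + 1 = 4.
    assert_eq!(l_out(10, 0, 3, 2), 4);
}
```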
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 0871175f..b2345756 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -202,6 +202,64 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
     }
 }
 
+struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
+
+impl<'a> Map2 for Conv1D<'a> {
+    const OP: &'static str = "conv1d";
+    fn f<T: 'static + num_traits::NumAssign + Copy>(
+        &self,
+        inp: &[T],
+        inp_l: &Layout,
+        k: &[T],
+        k_l: &Layout,
+    ) -> Result<Vec<T>> {
+        // TODO: Optimize this (proper algorithm, simd, multithread, remove bound checks, etc).
+        let p = self.0;
+        let inp = &inp[inp_l.start_offset()..];
+        let k = &k[k_l.start_offset()..];
+        let inp_stride = inp_l.stride();
+        let (inp_stride0, inp_stride) = if inp_stride.len() == 3 {
+            (inp_stride[0], &inp_stride[1..])
+        } else {
+            (0, inp_stride) // This value never gets used anyway
+        };
+        let k_stride = k_l.stride();
+        let k_over_2 = p.k_size / 2;
+        let l_out = p.l_out();
+        let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
+        let mut dst = vec![T::zero(); dst_elems];
+        // The output shape is [b_size, c_out, l_out]
+        for b_idx in 0..p.b_size.unwrap_or(1) {
+            let inp_idx = b_idx * inp_stride0;
+            let dst_idx = b_idx * p.c_out * l_out;
+            for dst_c_idx in 0..p.c_out {
+                let dst_idx = dst_idx + dst_c_idx * l_out;
+                for dst_l in 0..l_out {
+                    let dst_idx = dst_idx + dst_l;
+                    let mut d = T::zero();
+                    for offset in 0..p.k_size {
+                        let src_l_plus = p.stride * dst_l + offset;
+                        // inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
+                        if k_over_2 <= src_l_plus && src_l_plus < k_over_2 + p.l_in {
+                            let src_l = src_l_plus - k_over_2;
+                            for src_c_idx in 0..p.c_in {
+                                let inp_idx =
+                                    inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
+                                let k_idx = dst_c_idx * k_stride[0]
+                                    + src_c_idx * k_stride[1]
+                                    + offset * k_stride[2];
+                                d += inp[inp_idx] * k[k_idx]
+                            }
+                        }
+                    }
+                    dst[dst_idx] = d
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
 struct MatMul((usize, usize, usize, usize));
 
 impl Map2 for MatMul {
@@ -627,6 +685,16 @@ impl CpuStorage {
         WCond(pred, layout).map(t, t_l, f, f_l)
     }
 
+    pub(crate) fn conv1d(
+        &self,
+        l: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        params: &crate::conv::ParamsConv1D,
+    ) -> Result<Self> {
+        Conv1D(params).map(self, l, kernel, kernel_l)
+    }
+
     pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
         let ids = self.as_slice::<u32>()?;
         let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 0c87004b..917655fc 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -801,6 +801,16 @@ impl CudaStorage {
         Ok(Self { slice, device })
     }
 
+    pub(crate) fn conv1d(
+        &self,
+        _l: &Layout,
+        _kernel: &Self,
+        _kernel_l: &Layout,
+        _params: &crate::conv::ParamsConv1D,
+    ) -> Result<Self> {
+        todo!()
+    }
+
     pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
         let device = self.device().clone();
         let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index b025eeab..0dbd8d54 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -100,6 +100,16 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub(crate) fn conv1d(
+        &self,
+        _l: &Layout,
+        _kernel: &Self,
+        _kernel_l: &Layout,
+        _params: &crate::conv::ParamsConv1D,
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 0d4c2a8d..2365a34d 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -1,4 +1,5 @@
 mod backprop;
+mod conv;
 mod cpu_backend;
 #[cfg(feature = "cuda")]
 mod cuda_backend;
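The `Conv1D` kernel in `cpu_backend.rs` above is a direct, unoptimized translation of the convolution definition; note that the loop centers the kernel window with `k_over_2 = k_size / 2` rather than reading `p.padding` inside the window arithmetic. To make the indexing easier to follow, here is a minimal standalone sketch of the same loop nest, specialized to a contiguous, non-batched `f32` input and assuming `padding = k_size / 2` (the centering the loop implicitly performs); it is an illustration, not the backend code:

```rust
/// Naive 1D convolution of a contiguous [c_in, l_in] input with a
/// [c_out, c_in, k_size] kernel, mirroring the loop nest in the diff above.
/// Assumes padding = k_size / 2, matching the implicit k_over_2 centering.
fn conv1d_naive(
    inp: &[f32], // c_in * l_in values, row-major
    k: &[f32],   // c_out * c_in * k_size values, row-major
    c_in: usize,
    l_in: usize,
    c_out: usize,
    k_size: usize,
    stride: usize,
) -> Vec<f32> {
    let k_over_2 = k_size / 2;
    // Same formula as ParamsConv1D::l_out with padding = k_over_2, dilation = 1.
    let l_out = (l_in + 2 * k_over_2 - (k_size - 1) - 1) / stride + 1;
    let mut dst = vec![0f32; c_out * l_out];
    for dst_c in 0..c_out {
        for dst_l in 0..l_out {
            let mut d = 0f32;
            for offset in 0..k_size {
                let src_l_plus = stride * dst_l + offset;
                // Positions outside this range fall in the implicit zero padding.
                if k_over_2 <= src_l_plus && src_l_plus < k_over_2 + l_in {
                    let src_l = src_l_plus - k_over_2;
                    for src_c in 0..c_in {
                        d += inp[src_c * l_in + src_l]
                            * k[dst_c * c_in * k_size + src_c * k_size + offset];
                    }
                }
            }
            dst[dst_c * l_out + dst_l] = d;
        }
    }
    dst
}

fn main() {
    // Identity check: a [0, 1, 0] kernel should return the input unchanged.
    let inp = [1f32, 2., 3., 4.];
    let out = conv1d_naive(&inp, &[0f32, 1., 0.], 1, 4, 1, 3, 1);
    assert_eq!(out, inp);
}
```

Compared with the backend version, the strides are specialized away here because the input is assumed contiguous; the real code reads `inp_stride` and `k_stride` from the layouts so it also works on transposed or otherwise strided tensors.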
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 860be0b3..ee57b325 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -12,6 +12,14 @@
 pub(crate) enum Op {
     Embedding(Tensor, Tensor),
     WhereCond(Tensor, Tensor, Tensor),
+    #[allow(dead_code)]
+    Conv1D {
+        arg: Tensor,
+        kernel: Tensor,
+        padding: usize,
+        stride: usize,
+    },
+
     Cat(Vec<Tensor>, usize),
 
     #[allow(dead_code)] // add is currently unused.
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 4e630a58..53ea1544 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -144,6 +144,32 @@ impl Storage {
         }
     }
 
+    pub(crate) fn conv1d(
+        &self,
+        l: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        params: &crate::conv::ParamsConv1D,
+    ) -> Result<Self> {
+        self.same_device(kernel, "conv1d")?;
+        self.same_dtype(kernel, "conv1d")?;
+        match (self, &kernel) {
+            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
+                let s = inp.conv1d(l, kernel, kernel_l, params)?;
+                Ok(Self::Cpu(s))
+            }
+            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
+                let s = inp.conv1d(l, kernel, kernel_l, params)?;
+                Ok(Self::Cuda(s))
+            }
+            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+                lhs: lhs.device().location(),
+                rhs: rhs.device().location(),
+                op: "conv1d",
+            }),
+        }
+    }
+
     pub(crate) fn where_cond(
         &self,
         layout: &Layout,
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index a468d879..95f663f0 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -326,7 +326,7 @@ impl Tensor {
         }
         let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
             let data = S::cpu_storage_as_slice(cpu_storage)?;
-            Ok::<_, Error>(data[0])
+            Ok::<_, Error>(data[self.layout().start_offset()])
         };
         match self.storage.as_ref() {
             Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
@@ -432,6 +432,42 @@ impl Tensor {
         Ok(from_storage(storage, dims, op, false))
     }
 
+    pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
+        let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
+        let (b_size, c_in, l_in) = match *self.dims() {
+            [b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
+            [c_in, l_in] => (None, c_in, l_in),
+            _ => todo!("proper error message"),
+        };
+        if c_in != c_in_k {
+            todo!("proper error message")
+        }
+        let params = crate::conv::ParamsConv1D {
+            b_size,
+            l_in,
+            c_out,
+            c_in,
+            k_size,
+            padding,
+            stride,
+        };
+        let storage =
+            self.storage
+                .conv1d(self.layout(), &kernel.storage, kernel.layout(), &params)?;
+        let op = if self.track_op() || kernel.track_op() {
+            Some(Op::Conv1D {
+                arg: self.clone(),
+                kernel: kernel.clone(),
+                padding,
+                stride,
+            })
+        } else {
+            None
+        };
+        let out_dims = params.out_dims();
+        Ok(from_storage(storage, out_dims, op, false))
+    }
+
     pub fn matmul(&self, rhs: &Self) -> Result<Self> {
         let a_dims = self.shape().dims();
         let b_dims = rhs.shape().dims();
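Taken together, the new op is exposed as `Tensor::conv1d(kernel, padding, stride)`, accepting either a `[b_size, c_in, l_in]` or `[c_in, l_in]` input and a `[c_out, c_in, k_size]` kernel. A minimal usage sketch on the CPU backend follows; it assumes the crate is imported as `candle_core` and that `Tensor::new` and `Device::Cpu` are available as in later versions of the crate, and the shapes and values below are hypothetical:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Input [c_in = 1, l_in = 4]; no batch dimension, so b_size is None.
    let inp = Tensor::new(&[[1f32, 2., 3., 4.]], &dev)?;
    // Kernel [c_out = 1, c_in = 1, k_size = 3].
    let kernel = Tensor::new(&[[[0.5f32, 1., 0.5]]], &dev)?;
    // padding = 1, stride = 1 -> out_dims() gives [c_out, l_out] = [1, 4].
    let out = inp.conv1d(&kernel, 1, 1)?;
    println!("{:?}", out.shape());
    Ok(())
}
```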