diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-05 13:06:33 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-05 13:06:33 +0100 |
commit | 93896f6596e44285f6250f4966ada8c08fa85f09 (patch) | |
tree | fee5a01b56231a6d1472fd925f76c73aa8b93ac0 /candle-core/src/tensor.rs | |
parent | d8f75ceeaa4702b641a9f71ec348fc54a32f4cd7 (diff) | |
parent | bce28ab7938b27931fd51e59c8bcad37038e0337 (diff) | |
download | candle-93896f6596e44285f6250f4966ada8c08fa85f09.tar.gz candle-93896f6596e44285f6250f4966ada8c08fa85f09.tar.bz2 candle-93896f6596e44285f6250f4966ada8c08fa85f09.zip |
Merge branch 'main' into upgrade_bert
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 38 |
1 files changed, 37 insertions, 1 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index a468d879..95f663f0 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -326,7 +326,7 @@ impl Tensor { } let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { let data = S::cpu_storage_as_slice(cpu_storage)?; - Ok::<_, Error>(data[0]) + Ok::<_, Error>(data[self.layout().start_offset()]) }; match self.storage.as_ref() { Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage), @@ -432,6 +432,42 @@ impl Tensor { Ok(from_storage(storage, dims, op, false)) } + pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> { + let (c_out, c_in_k, k_size) = kernel.shape().r3()?; + let (b_size, c_in, l_in) = match *self.dims() { + [b_size, c_in, l_in] => (Some(b_size), c_in, l_in), + [c_in, l_in] => (None, c_in, l_in), + _ => todo!("proper error message"), + }; + if c_in != c_in_k { + todo!("proper error message") + } + let params = crate::conv::ParamsConv1D { + b_size, + l_in, + c_out, + c_in, + k_size, + padding, + stride, + }; + let storage = + self.storage + .conv1d(self.layout(), &kernel.storage, kernel.layout(), ¶ms)?; + let op = if self.track_op() || kernel.track_op() { + Some(Op::Conv1D { + arg: self.clone(), + kernel: kernel.clone(), + padding, + stride, + }) + } else { + None + }; + let out_dims = params.out_dims(); + Ok(from_storage(storage, out_dims, op, false)) + } + pub fn matmul(&self, rhs: &Self) -> Result<Self> { let a_dims = self.shape().dims(); let b_dims = rhs.shape().dims(); |