summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-05 13:06:33 +0100
committerGitHub <noreply@github.com>2023-07-05 13:06:33 +0100
commit93896f6596e44285f6250f4966ada8c08fa85f09 (patch)
treefee5a01b56231a6d1472fd925f76c73aa8b93ac0 /candle-core/src/tensor.rs
parentd8f75ceeaa4702b641a9f71ec348fc54a32f4cd7 (diff)
parentbce28ab7938b27931fd51e59c8bcad37038e0337 (diff)
downloadcandle-93896f6596e44285f6250f4966ada8c08fa85f09.tar.gz
candle-93896f6596e44285f6250f4966ada8c08fa85f09.tar.bz2
candle-93896f6596e44285f6250f4966ada8c08fa85f09.zip
Merge branch 'main' into upgrade_bert
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs38
1 files changed, 37 insertions, 1 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index a468d879..95f663f0 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -326,7 +326,7 @@ impl Tensor {
}
let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
let data = S::cpu_storage_as_slice(cpu_storage)?;
- Ok::<_, Error>(data[0])
+ Ok::<_, Error>(data[self.layout().start_offset()])
};
match self.storage.as_ref() {
Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
@@ -432,6 +432,42 @@ impl Tensor {
Ok(from_storage(storage, dims, op, false))
}
+ pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
+ let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
+ let (b_size, c_in, l_in) = match *self.dims() {
+ [b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
+ [c_in, l_in] => (None, c_in, l_in),
+ _ => todo!("proper error message"),
+ };
+ if c_in != c_in_k {
+ todo!("proper error message")
+ }
+ let params = crate::conv::ParamsConv1D {
+ b_size,
+ l_in,
+ c_out,
+ c_in,
+ k_size,
+ padding,
+ stride,
+ };
+ let storage =
+ self.storage
+ .conv1d(self.layout(), &kernel.storage, kernel.layout(), &params)?;
+ let op = if self.track_op() || kernel.track_op() {
+ Some(Op::Conv1D {
+ arg: self.clone(),
+ kernel: kernel.clone(),
+ padding,
+ stride,
+ })
+ } else {
+ None
+ };
+ let out_dims = params.out_dims();
+ Ok(from_storage(storage, out_dims, op, false))
+ }
+
pub fn matmul(&self, rhs: &Self) -> Result<Self> {
let a_dims = self.shape().dims();
let b_dims = rhs.shape().dims();