Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--  candle-core/src/tensor.rs  70
1 file changed, 2 insertions(+), 68 deletions(-)
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index a4b9795b..46f9c53f 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -124,7 +124,7 @@ macro_rules! broadcast_binary_op {
}
/// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides.
-fn from_storage<S: Into<Shape>>(
+pub(crate) fn from_storage<S: Into<Shape>>(
storage: Storage,
shape: S,
op: BackpropOp,
@@ -787,72 +787,6 @@ impl Tensor {
self.cmp(rhs, CmpOp::Le)
}
- /// Applies a 1D convolution over the input tensor.
- pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
- let (c_out, c_in_k, k_size) = kernel.dims3()?;
- let (b_size, c_in, l_in) = self.dims3()?;
- if c_in != c_in_k {
- Err(Error::Conv1dInvalidArgs {
- inp_shape: self.shape().clone(),
- k_shape: kernel.shape().clone(),
- padding,
- stride,
- msg: "the number of in-channels on the input doesn't match the kernel size",
- }
- .bt())?
- }
- let params = crate::conv::ParamsConv1D {
- b_size,
- l_in,
- c_out,
- c_in,
- k_size,
- padding,
- stride,
- };
- let storage =
- self.storage()
- .conv1d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
- let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
- arg,
- kernel,
- padding,
- stride,
- });
- let out_dims = params.out_dims();
- Ok(from_storage(storage, out_dims, op, false))
- }
-
- pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
- let (b_size, c_in, i_h, i_w) = self.dims4()?;
- let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
- if c_in != c_in_k {
- crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
- }
- let params = crate::conv::ParamsConv2D {
- b_size,
- i_h,
- i_w,
- k_h,
- k_w,
- c_out,
- c_in,
- padding,
- stride,
- };
- let storage =
- self.storage()
- .conv2d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
- let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
- arg,
- kernel,
- padding,
- stride,
- });
- let out_dims = params.out_dims();
- Ok(from_storage(storage, out_dims, op, false))
- }
-
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
@@ -1920,7 +1854,7 @@ impl Tensor {
}
}
- fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
+ pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap()
}