Diffstat (limited to 'candle-core/src')
 -rw-r--r--  candle-core/src/backprop.rs              26
 -rw-r--r--  candle-core/src/indexer.rs                2
 -rw-r--r--  candle-core/src/op.rs                     6
 -rw-r--r--  candle-core/src/quantized/avx.rs          4
 -rw-r--r--  candle-core/src/quantized/gguf_file.rs    2
 -rw-r--r--  candle-core/src/tensor.rs                27
 6 files changed, 55 insertions, 12 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index fc0c79a2..c152f31f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -114,7 +114,7 @@ impl Tensor {
| Op::Unary(_node, UnaryOp::Round) => nodes,
Op::Reshape(node)
| Op::UpsampleNearest1D(node)
- | Op::UpsampleNearest2D(node)
+ | Op::UpsampleNearest2D { arg: node, .. }
| Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. }
| Op::Copy(node)
@@ -350,9 +350,27 @@ impl Tensor {
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest1d",
})?,
- Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
- op: "upsample-nearest2d",
- })?,
+ Op::UpsampleNearest2D {
+ arg,
+ target_h,
+ target_w,
+ } => {
+ let (_n, c, h, w) = arg.dims4()?;
+ if target_h % h != 0 || target_w % w != 0 {
+ crate::bail!("backward not supported for non-integer upscaling factors")
+ }
+ let scale_h = target_h / h;
+ let scale_w = target_w / w;
+
+ if scale_h != scale_w {
+ crate::bail!("backward not supported for non-uniform upscaling factors")
+ };
+ let kernel =
+ Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
+ let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = conv_sum;
+ }
Op::SliceScatter0(lhs, rhs, start_rhs) => {
let rhs_sum_grad = grads.or_insert(rhs)?;
let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
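
Note on the backward rule above: nearest-neighbour upsampling copies each input pixel into a scale_h x scale_w output patch, so its gradient is the sum of the upstream gradient over that patch. The grouped conv2d with an all-ones (c, 1, scale_h, scale_w) kernel, stride scale_h, and groups = c computes exactly that per-channel patch sum. A minimal sketch exercising the rule (not part of this commit; the helper name is hypothetical):

    use candle_core::{DType, Device, Result, Var};

    fn check_upsample_grad() -> Result<()> {
        let dev = Device::Cpu;
        let x = Var::ones((1, 1, 2, 2), DType::F32, &dev)?; // (n, c, h, w)
        let y = x.interpolate2d(4, 4)?; // target 4x4 => scale_h = scale_w = 2
        let grads = y.sum_all()?.backward()?;
        let gx = grads.get(&x).expect("missing gradient");
        // Each input element fans out to a 2x2 patch, so it collects
        // 2 * 2 = 4 unit gradients from the all-ones gradient of sum_all.
        assert_eq!(gx.flatten_all()?.to_vec1::<f32>()?, vec![4.0; 4]);
        Ok(())
    }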
diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs
index df106b73..e3ed41e5 100644
--- a/candle-core/src/indexer.rs
+++ b/candle-core/src/indexer.rs
@@ -64,7 +64,7 @@ impl Tensor {
#[derive(Debug)]
/// Generic structure used to index a slice of the tensor
pub enum TensorIndexer {
- /// This selects the elemnts for which an index has some specific value.
+ /// This selects the elements for which an index has some specific value.
Select(usize),
/// This is a regular slice, purely indexing a chunk of the tensor
Narrow(Bound<usize>, Bound<usize>),
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index fbb20f6c..868673e7 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -132,7 +132,11 @@ pub enum Op {
},
UpsampleNearest1D(Tensor),
- UpsampleNearest2D(Tensor),
+ UpsampleNearest2D {
+ arg: Tensor,
+ target_h: usize,
+ target_w: usize,
+ },
Cat(Vec<Tensor>, usize),
diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs
index 5c3ac822..664f7653 100644
--- a/candle-core/src/quantized/avx.rs
+++ b/candle-core/src/quantized/avx.rs
@@ -353,7 +353,7 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
q3 = q3.add(32);
// Prepare low and high bits
- // We hardcode the shifts here to avoid loading them into a seperate register
+ // We hardcode the shifts here to avoid loading them into a separate register
let q3l_0 = _mm256_and_si256(q3bits, m3);
let q3h_0 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
@@ -586,7 +586,7 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
q5 = q5.add(32);
- //Similar to q3k we hardcode the shifts here to avoid loading them into a seperate register
+ // Similar to q3k we hardcode the shifts here to avoid loading them into a separate register
let q5l_0 = _mm256_and_si256(q5bits, m4);
let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
let q5l_0_right_shift = match j {
diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs
index 620bc037..1e9dc517 100644
--- a/candle-core/src/quantized/gguf_file.rs
+++ b/candle-core/src/quantized/gguf_file.rs
@@ -463,7 +463,7 @@ impl Content {
) -> Result<QTensor> {
let tensor_info = match self.tensor_infos.get(name) {
Some(tensor_info) => tensor_info,
- None => crate::bail!("cannot find tensor-infor for {name}"),
+ None => crate::bail!("cannot find tensor info for {name}"),
};
tensor_info.read(reader, self.tensor_data_offset)
}
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index e478869a..f15f8c1c 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,4 +1,4 @@
-//! Tensors are N-dimenional matrixes of elements using a single data type.
+//! Tensors are N-dimensional matrices of elements using a single data type.
#![allow(clippy::redundant_closure_call)]
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{
@@ -361,6 +361,16 @@ impl Tensor {
Self::new_impl(array, shape, device, false)
}
+ /// Returns a new tensor with all the elements having the same specified value. Note that
+ /// the tensor is not contiguous, so call `.contiguous()` on it if needed.
+ pub fn full<D: crate::WithDType, S: Into<Shape>>(
+ value: D,
+ shape: S,
+ device: &Device,
+ ) -> Result<Self> {
+ Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
+ }
+
/// Creates a new 1D tensor from an iterator.
pub fn from_iter<D: crate::WithDType>(
iter: impl IntoIterator<Item = D>,
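
Note on `Tensor::full` above: the value is materialized once and then broadcast to the requested shape, which is why the doc comment warns about contiguity. A short usage sketch (not part of this commit):

    use candle_core::{Device, Tensor};

    fn main() -> candle_core::Result<()> {
        let t = Tensor::full(7f32, (2, 3), &Device::Cpu)?;
        assert_eq!(t.dims(), &[2, 3]);
        // broadcast_as yields a non-contiguous (stride-0) view; compact it
        // when a downstream kernel needs a contiguous layout.
        let t = t.contiguous()?;
        assert!(t.is_contiguous());
        Ok(())
    }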
@@ -669,7 +679,7 @@ impl Tensor {
}
/// Split a tensor into the specified number of chunks, this may return less chunks than
- /// specificed.
+ /// specified.
pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
let dim = dim.to_index(self.shape(), "chunk")?;
let size = self.dim(dim)?;
@@ -994,7 +1004,11 @@ impl Tensor {
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?;
- let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
+ let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
+ arg,
+ target_h,
+ target_w,
+ });
let storage = self
.storage()
.upsample_nearest2d(self.layout(), target_h, target_w)?;
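
Note on `interpolate2d` above: only the op recorded for autograd changes; the forward path still calls `upsample_nearest2d` with the same target sizes, and the recorded target_h/target_w are what the new backward pass uses to recover the scale factors. Shape-wise (a sketch, not part of this commit):

    use candle_core::{DType, Device, Tensor};

    fn main() -> candle_core::Result<()> {
        // (batch, channels, h, w) -> (batch, channels, target_h, target_w)
        let x = Tensor::zeros((1, 3, 8, 8), DType::F32, &Device::Cpu)?;
        let y = x.interpolate2d(16, 16)?;
        assert_eq!(y.dims4()?, (1, 3, 16, 16));
        Ok(())
    }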
@@ -2558,6 +2572,13 @@ impl Tensor {
}
mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
}
+
+ /// Returns log(sum(exp(tensor), dim)).
+ pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
+ let exp = self.exp()?;
+ let sum = exp.sum(sum_dims)?;
+ sum.log()
+ }
}
macro_rules! bin_trait {
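
Note on `logsumexp` above: it computes log(sum(exp(x), dims)) directly, so the exp can overflow for large inputs; the usual remedy is the identity logsumexp(x) = m + log(sum(exp(x - m))) with m = max(x), which this commit does not apply. A small usage sketch (not part of this commit; the tolerance is arbitrary):

    use candle_core::{Device, Tensor};

    fn main() -> candle_core::Result<()> {
        let xs = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
        let lse = xs.logsumexp(0)?; // reduces dim 0 away, leaving a scalar
        // ln(e^1 + e^2 + e^3) = 3.40760...
        assert!((lse.to_scalar::<f32>()? - 3.4076).abs() < 1e-3);
        Ok(())
    }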