Diffstat (limited to 'candle-core/src')
-rw-r--r--  candle-core/src/backprop.rs             | 26
-rw-r--r--  candle-core/src/indexer.rs              |  2
-rw-r--r--  candle-core/src/op.rs                   |  6
-rw-r--r--  candle-core/src/quantized/avx.rs        |  4
-rw-r--r--  candle-core/src/quantized/gguf_file.rs  |  2
-rw-r--r--  candle-core/src/tensor.rs               | 27
6 files changed, 55 insertions, 12 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index fc0c79a2..c152f31f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -114,7 +114,7 @@ impl Tensor {
                 | Op::Unary(_node, UnaryOp::Round) => nodes,
                 Op::Reshape(node)
                 | Op::UpsampleNearest1D(node)
-                | Op::UpsampleNearest2D(node)
+                | Op::UpsampleNearest2D { arg: node, .. }
                 | Op::AvgPool2D { arg: node, .. }
                 | Op::MaxPool2D { arg: node, .. }
                 | Op::Copy(node)
@@ -350,9 +350,27 @@ impl Tensor {
                         Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
                             op: "upsample-nearest1d",
                         })?,
-                        Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
-                            op: "upsample-nearest2d",
-                        })?,
+                        Op::UpsampleNearest2D {
+                            arg,
+                            target_h,
+                            target_w,
+                        } => {
+                            let (_n, c, h, w) = arg.dims4()?;
+                            if target_h % h != 0 || target_w % w != 0 {
+                                crate::bail!("backward not supported for non integer upscaling factors")
+                            }
+                            let scale_h = target_h / h;
+                            let scale_w = target_w / w;
+
+                            if scale_h != scale_w {
+                                crate::bail!("backward not supported for non uniform upscaling factors")
+                            };
+                            let kernel =
+                                Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
+                            let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
+                            let sum_grad = grads.or_insert(arg)?;
+                            *sum_grad = conv_sum;
+                        }
                         Op::SliceScatter0(lhs, rhs, start_rhs) => {
                             let rhs_sum_grad = grads.or_insert(rhs)?;
                             let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
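The backward rule added above relies on a window-sum identity: nearest upsampling by an integer factor copies each input element into a scale x scale output block, so that element's gradient is the sum of the incoming gradient over the block, and a grouped conv2d with an all-ones (c, 1, scale, scale) kernel, stride scale, and groups = c computes exactly that per-channel sum. A minimal standalone sketch of the identity, assuming a candle-core build that already includes this change (values and names are illustrative, not from the commit):

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Incoming gradient with the upsampled shape (n=1, c=1, 4, 4),
    // for an input of shape (1, 1, 2, 2) upscaled by a factor of 2.
    let grad = Tensor::ones((1, 1, 4, 4), DType::F32, &dev)?;
    let (c, scale): (usize, usize) = (1, 2);
    // All-ones depthwise kernel: stride = scale and groups = c turn the
    // convolution into a per-channel sum over each scale x scale block.
    let kernel = Tensor::ones((c, 1, scale, scale), grad.dtype(), &dev)?;
    // Argument order: conv2d(kernel, padding, stride, dilation, groups).
    let sum_grad = grad.conv2d(&kernel, 0, scale, 1, c)?;
    assert_eq!(sum_grad.dims(), &[1, 1, 2, 2]);
    // Each input element was replicated into a 2x2 block, so every entry is 4.
    println!("{sum_grad}");
    Ok(())
}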
diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs
index df106b73..e3ed41e5 100644
--- a/candle-core/src/indexer.rs
+++ b/candle-core/src/indexer.rs
@@ -64,7 +64,7 @@ impl Tensor {
 #[derive(Debug)]
 /// Generic structure used to index a slice of the tensor
 pub enum TensorIndexer {
-    /// This selects the elemnts for which an index has some specific value.
+    /// This selects the elements for which an index has some specific value.
     Select(usize),
     /// This is a regular slice, purely indexing a chunk of the tensor
     Narrow(Bound<usize>, Bound<usize>),
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index fbb20f6c..868673e7 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -132,7 +132,11 @@ pub enum Op {
     },

     UpsampleNearest1D(Tensor),
-    UpsampleNearest2D(Tensor),
+    UpsampleNearest2D {
+        arg: Tensor,
+        target_h: usize,
+        target_w: usize,
+    },

     Cat(Vec<Tensor>, usize),
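Turning UpsampleNearest2D from a tuple variant into a struct variant is what carries target_h/target_w through to the backward pass; the `..` in the backprop.rs traversal arm above simply ignores the new fields. A toy illustration of that matching pattern, with MiniOp as a made-up stand-in for candle's Op:

// `MiniOp` is a hypothetical, reduced stand-in for candle's `Op` enum.
#[allow(dead_code)]
enum MiniOp {
    UpsampleNearest1D(String),
    UpsampleNearest2D { arg: String, target_h: usize, target_w: usize },
}

// Mirrors the traversal arm: `arg: node` rebinds the field, `..` skips the rest.
fn node_of(op: &MiniOp) -> &String {
    match op {
        MiniOp::UpsampleNearest1D(node)
        | MiniOp::UpsampleNearest2D { arg: node, .. } => node,
    }
}

fn main() {
    let op = MiniOp::UpsampleNearest2D { arg: "x".into(), target_h: 8, target_w: 8 };
    println!("{} upsampled to 8x8", node_of(&op));
}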
diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs
index 5c3ac822..664f7653 100644
--- a/candle-core/src/quantized/avx.rs
+++ b/candle-core/src/quantized/avx.rs
@@ -353,7 +353,7 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
             q3 = q3.add(32);

             // Prepare low and high bits
-            // We hardcode the shifts here to avoid loading them into a seperate register
+            // We hardcode the shifts here to avoid loading them into a separate register
             let q3l_0 = _mm256_and_si256(q3bits, m3);
             let q3h_0 = if j == 0 {
                 _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
@@ -586,7 +586,7 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
             let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
             q5 = q5.add(32);

-            //Similar to q3k we hardcode the shifts here to avoid loading them into a seperate register
+            //Similar to q3k we hardcode the shifts here to avoid loading them into a separate register
             let q5l_0 = _mm256_and_si256(q5bits, m4);
             let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
             let q5l_0_right_shift = match j {
diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs
index 620bc037..1e9dc517 100644
--- a/candle-core/src/quantized/gguf_file.rs
+++ b/candle-core/src/quantized/gguf_file.rs
@@ -463,7 +463,7 @@ impl Content {
     ) -> Result<QTensor> {
         let tensor_info = match self.tensor_infos.get(name) {
             Some(tensor_info) => tensor_info,
-            None => crate::bail!("cannot find tensor-infor for {name}"),
+            None => crate::bail!("cannot find tensor info for {name}"),
         };
         tensor_info.read(reader, self.tensor_data_offset)
     }
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index e478869a..f15f8c1c 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,4 +1,4 @@
-//! Tensors are N-dimenional matrixes of elements using a single data type.
+//! Tensors are N-dimensional matrixes of elements using a single data type.
 #![allow(clippy::redundant_closure_call)]
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::op::{
@@ -361,6 +361,16 @@ impl Tensor {
         Self::new_impl(array, shape, device, false)
     }

+    /// Returns a new tensor with all the elements having the same specified value. Note that
+    /// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
+    pub fn full<D: crate::WithDType, S: Into<Shape>>(
+        value: D,
+        shape: S,
+        device: &Device,
+    ) -> Result<Self> {
+        Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
+    }
+
     /// Creates a new 1D tensor from an iterator.
     pub fn from_iter<D: crate::WithDType>(
         iter: impl IntoIterator<Item = D>,
@@ -669,7 +679,7 @@ impl Tensor {
     }

     /// Split a tensor into the specified number of chunks, this may return less chunks than
-    /// specificed.
+    /// specified.
     pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
         let dim = dim.to_index(self.shape(), "chunk")?;
         let size = self.dim(dim)?;
@@ -994,7 +1004,11 @@ impl Tensor {
     /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
     pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
         let (n, c, _h, _w) = self.dims4()?;
-        let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
+        let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
+            arg,
+            target_h,
+            target_w,
+        });
         let storage = self
             .storage()
             .upsample_nearest2d(self.layout(), target_h, target_w)?;
@@ -2558,6 +2572,13 @@ impl Tensor {
         }
         mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
     }
+
+    /// Returns log(sum(exp(tensor), dim)).
+    pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
+        let exp = self.exp()?;
+        let sum = exp.sum(sum_dims)?;
+        sum.log()
+    }
 }

 macro_rules! bin_trait {
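A short usage sketch for the two new tensor methods, assuming a candle-core build that includes this change (inputs are illustrative). Note that full returns a broadcast view, and this logsumexp is the plain exp-sum-log chain, which can overflow for large inputs; the usual stabilization is m + log(sum(exp(x - m))) with m = max(x):

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;

    // Broadcast-backed constant tensor; call `.contiguous()` if a packed
    // layout is needed downstream.
    let t = Tensor::full(3.5f32, (2, 4), &dev)?;
    assert_eq!(t.dims(), &[2, 4]);

    // log(sum(exp(x), dim)) over dim 1; the summed dim is squeezed away.
    let x = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.]], &dev)?;
    let lse = x.logsumexp(1)?;
    assert_eq!(lse.dims(), &[2]);
    println!("{lse}");
    Ok(())
}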