summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/tensor.rs58
-rw-r--r--candle-core/tests/indexing_tests.rs29
2 files changed, 87 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index ce5858fa..87323a84 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2503,6 +2503,64 @@ impl Tensor {
t.transpose(dim, last)
}
}
+
+ /// Returns a copy of `self` where the values within `ranges` have been replaced with the
+ /// content of `src`.
+ pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
+ &self,
+ ranges: &[D],
+ src: &Tensor,
+ ) -> Result<Self> {
+ let src_dims = src.dims();
+ let self_dims = self.dims();
+ if self_dims.len() != src_dims.len() {
+ crate::bail!(
+ "slice-assign requires input with the same rank {} <> {}",
+ self_dims.len(),
+ src_dims.len()
+ )
+ }
+ if self_dims.len() != ranges.len() {
+ crate::bail!(
+ "slice-assign requires input with the same rank as there are ranges {} <> {}",
+ self_dims.len(),
+ ranges.len()
+ )
+ }
+ let mut src = src.clone();
+ let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
+ for (i, range) in ranges.iter().enumerate() {
+ let start_included = match range.start_bound() {
+ std::ops::Bound::Unbounded => 0,
+ std::ops::Bound::Included(v) => *v,
+ std::ops::Bound::Excluded(v) => *v + 1,
+ };
+ let end_excluded = match range.end_bound() {
+ std::ops::Bound::Unbounded => self_dims[i],
+ std::ops::Bound::Included(v) => *v + 1,
+ std::ops::Bound::Excluded(v) => *v,
+ };
+ if end_excluded <= start_included {
+ crate::bail!(
+ "slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
+ )
+ }
+ if self_dims[i] < end_excluded {
+ crate::bail!(
+ "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
+ self_dims[i]
+ )
+ }
+ if end_excluded - start_included != src_dims[i] {
+ crate::bail!(
+ "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
+ )
+ }
+ src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
+ mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
+ }
+ mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
+ }
}
macro_rules! bin_trait {
diff --git a/candle-core/tests/indexing_tests.rs b/candle-core/tests/indexing_tests.rs
index 9c88f319..047205a3 100644
--- a/candle-core/tests/indexing_tests.rs
+++ b/candle-core/tests/indexing_tests.rs
@@ -91,3 +91,32 @@ fn index_3d() -> Result<()> {
assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
Ok(())
}
+
+#[test]
+fn slice_assign() -> Result<()> {
+ let dev = Device::Cpu;
+
+ let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
+ let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
+ let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
+ assert_eq!(
+ out.to_vec2::<u32>()?,
+ &[
+ [0, 1, 2, 3, 4],
+ [5, 6, 7, 0, 1],
+ [10, 11, 12, 2, 3],
+ [15, 16, 17, 4, 5]
+ ]
+ );
+ let out = tensor.slice_assign(&[0..3, 0..2], &src)?;
+ assert_eq!(
+ out.to_vec2::<u32>()?,
+ &[
+ [0, 1, 2, 3, 4],
+ [2, 3, 7, 8, 9],
+ [4, 5, 12, 13, 14],
+ [15, 16, 17, 18, 19]
+ ]
+ );
+ Ok(())
+}