summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-11-26 17:31:22 +0000
committerGitHub <noreply@github.com>2023-11-26 17:31:22 +0000
commit481c45d78da4e76951b214f43d4a5789d97c3620 (patch)
tree3d8decae4d19176723fbc94684b9db78d56bfc18 /candle-core/src/tensor.rs
parent14a2bdc06232066c4be06825b8894a22666ef1ca (diff)
downloadcandle-481c45d78da4e76951b214f43d4a5789d97c3620.tar.gz
candle-481c45d78da4e76951b214f43d4a5789d97c3620.tar.bz2
candle-481c45d78da4e76951b214f43d4a5789d97c3620.zip
Add a basic implementation for slice-assign. (#1377)
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs58
1 files changed, 58 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index ce5858fa..87323a84 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2503,6 +2503,64 @@ impl Tensor {
t.transpose(dim, last)
}
}
+
+ /// Returns a copy of `self` where the values within `ranges` have been replaced with the
+ /// content of `src`.
+ pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
+ &self,
+ ranges: &[D],
+ src: &Tensor,
+ ) -> Result<Self> {
+ let src_dims = src.dims();
+ let self_dims = self.dims();
+ if self_dims.len() != src_dims.len() {
+ crate::bail!(
+ "slice-assign requires input with the same rank {} <> {}",
+ self_dims.len(),
+ src_dims.len()
+ )
+ }
+ if self_dims.len() != ranges.len() {
+ crate::bail!(
+ "slice-assign requires input with the same rank as there are ranges {} <> {}",
+ self_dims.len(),
+ ranges.len()
+ )
+ }
+ let mut src = src.clone();
+ let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
+ for (i, range) in ranges.iter().enumerate() {
+ let start_included = match range.start_bound() {
+ std::ops::Bound::Unbounded => 0,
+ std::ops::Bound::Included(v) => *v,
+ std::ops::Bound::Excluded(v) => *v + 1,
+ };
+ let end_excluded = match range.end_bound() {
+ std::ops::Bound::Unbounded => self_dims[i],
+ std::ops::Bound::Included(v) => *v + 1,
+ std::ops::Bound::Excluded(v) => *v,
+ };
+ if end_excluded <= start_included {
+ crate::bail!(
+ "slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
+ )
+ }
+ if self_dims[i] < end_excluded {
+ crate::bail!(
+ "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
+ self_dims[i]
+ )
+ }
+ if end_excluded - start_included != src_dims[i] {
+ crate::bail!(
+ "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
+ )
+ }
+ src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
+ mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
+ }
+ mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
+ }
}
macro_rules! bin_trait {