diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-11-26 17:31:22 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-26 17:31:22 +0000 |
commit | 481c45d78da4e76951b214f43d4a5789d97c3620 (patch) | |
tree | 3d8decae4d19176723fbc94684b9db78d56bfc18 /candle-core/src/tensor.rs | |
parent | 14a2bdc06232066c4be06825b8894a22666ef1ca (diff) | |
download | candle-481c45d78da4e76951b214f43d4a5789d97c3620.tar.gz candle-481c45d78da4e76951b214f43d4a5789d97c3620.tar.bz2 candle-481c45d78da4e76951b214f43d4a5789d97c3620.zip |
Add a basic implementation for slice-assign. (#1377)
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index ce5858fa..87323a84 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2503,6 +2503,64 @@ impl Tensor { t.transpose(dim, last) } } + + /// Returns a copy of `self` where the values within `ranges` have been replaced with the + /// content of `src`. + pub fn slice_assign<D: std::ops::RangeBounds<usize>>( + &self, + ranges: &[D], + src: &Tensor, + ) -> Result<Self> { + let src_dims = src.dims(); + let self_dims = self.dims(); + if self_dims.len() != src_dims.len() { + crate::bail!( + "slice-assign requires input with the same rank {} <> {}", + self_dims.len(), + src_dims.len() + ) + } + if self_dims.len() != ranges.len() { + crate::bail!( + "slice-assign requires input with the same rank as there are ranges {} <> {}", + self_dims.len(), + ranges.len() + ) + } + let mut src = src.clone(); + let mut mask = Self::ones(src.shape(), DType::U8, src.device())?; + for (i, range) in ranges.iter().enumerate() { + let start_included = match range.start_bound() { + std::ops::Bound::Unbounded => 0, + std::ops::Bound::Included(v) => *v, + std::ops::Bound::Excluded(v) => *v + 1, + }; + let end_excluded = match range.end_bound() { + std::ops::Bound::Unbounded => self_dims[i], + std::ops::Bound::Included(v) => *v + 1, + std::ops::Bound::Excluded(v) => *v, + }; + if end_excluded <= start_included { + crate::bail!( + "slice-assign: empty range for dim {i}, {start_included} {end_excluded}" + ) + } + if self_dims[i] < end_excluded { + crate::bail!( + "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}", + self_dims[i] + ) + } + if end_excluded - start_included != src_dims[i] { + crate::bail!( + "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i] + ) + } + src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?; + mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)? + } + mask.where_cond(/* on_true= */ &src, /* on_false= */ self) + } } macro_rules! bin_trait { |