Diffstat (limited to 'candle-core/src/tensor.rs')
 candle-core/src/tensor.rs | 28 +++++++++-------------------
 1 file changed, 9 insertions(+), 19 deletions(-)
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 060e8792..c326a5ac 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -842,45 +842,35 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
- /// Returns a tensor with the values from the `rhs` tensor at the index corresponding to the
+ /// Returns a tensor with the values from the `self` tensor at the index corresponding to the
/// values held in the `ids` tensor.
///
/// # Arguments
///
+ /// * `self` - A tensor with dimensions `v, h`.
/// * `ids` - A tensor with dimensions `s` and with integer values between 0 and v (exclusive).
- /// * `rhs` - A tensor with dimensions `v, h`.
///
/// The resulting tensor has dimensions `s, h`. `s` is called the sequence length, `v` the
/// vocabulary size, and `h` the hidden size.
///
/// ```rust
/// use candle::{Tensor, Device};
- /// let rhs = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
+ /// let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
- /// let emb = Tensor::embedding(&ids, &rhs)?;
+ /// let emb = values.embedding(&ids)?;
/// assert_eq!(emb.to_vec2::<f32>()?, &[[4., 5.], [2., 3.], [4., 5.]]);
/// # Ok::<(), candle::Error>(())
/// ```
- pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
- if !rhs.is_contiguous() {
- Err(Error::RequiresContiguous { op: "embedding" }.bt())?
- } else if rhs.rank() != 2 || ids.rank() != 1 {
+ pub fn embedding(&self, ids: &Self) -> Result<Self> {
+ if self.rank() != 2 || ids.rank() != 1 {
Err(Error::ShapeMismatchBinaryOp {
- lhs: ids.shape().clone(),
- rhs: rhs.shape().clone(),
+ lhs: self.shape().clone(),
+ rhs: ids.shape().clone(),
op: "embedding",
}
.bt())?
}
- let ids_shape = ids.shape();
- let seq_len = ids_shape.dims1()?;
- let (_, hidden_size) = rhs.dims2()?;
- let storage = ids
- .storage()
- .embedding(ids.layout(), &rhs.storage(), rhs.layout())?;
- let shape: Shape = (seq_len, hidden_size).into();
- let op = BackpropOp::new2(ids, rhs, Op::Embedding);
- Ok(from_storage(storage, shape, op, false))
+ self.index_select(ids, 0)
}
pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
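
For context, after this change `embedding` is a thin wrapper over `index_select` along dimension 0, so an embedding lookup and a row-wise index-select coincide. Below is a minimal sketch reusing the doc example from the diff; the `main` wrapper and variable names are illustrative, and it assumes the `candle` crate exactly as in the doctest above:

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    // A 3x2 "vocabulary" of row vectors (v = 3, h = 2).
    let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
    // Indices into the rows of `values` (s = 3), each in 0..v.
    let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;

    // With the new implementation, both calls produce the same `s x h` tensor.
    let emb = values.embedding(&ids)?;
    let picked = values.index_select(&ids, 0)?;
    assert_eq!(emb.to_vec2::<f32>()?, picked.to_vec2::<f32>()?);
    Ok(())
}
```

Note the design consequence visible in the diff: the dedicated storage call and `Op::Embedding` backprop node are gone, so gradients for `embedding` now flow through the generic `index_select` path, and the explicit contiguity check on the values tensor is no longer needed here.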