diff options
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 28 |
1 files changed, 9 insertions, 19 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 060e8792..c326a5ac 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -842,45 +842,35 @@ impl Tensor { Ok(from_storage(storage, shape, op, false)) } - /// Returns a tensor with the values from the `rhs` tensor at the index corresponding to the + /// Returns a tensor with the values from the `self` tensor at the index corresponding to the /// values hold in the `ids` tensor. /// /// # Arguments /// + /// * `self` - A tensor with dimensions `v, h`. /// * `ids` - A tensor with dimensions `s` and with integer values between 0 and v (exclusive). - /// * `rhs` - A tensor with dimensions `v, h`. /// /// The resulting tensor has dimensions `s, h`. `s` is called the sequence length, `v` the /// vocabulary size, and `h` the hidden size. /// /// ```rust /// use candle::{Tensor, Device}; - /// let rhs = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; /// let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?; - /// let emb = Tensor::embedding(&ids, &rhs)?; + /// let emb = values.embedding(&ids)?; /// assert_eq!(emb.to_vec2::<f32>()?, &[[4., 5.], [2., 3.], [4., 5.]]); /// # Ok::<(), candle::Error>(()) /// ``` - pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> { - if !rhs.is_contiguous() { - Err(Error::RequiresContiguous { op: "embedding" }.bt())? - } else if rhs.rank() != 2 || ids.rank() != 1 { + pub fn embedding(&self, ids: &Self) -> Result<Self> { + if self.rank() != 2 || ids.rank() != 1 { Err(Error::ShapeMismatchBinaryOp { - lhs: ids.shape().clone(), - rhs: rhs.shape().clone(), + lhs: self.shape().clone(), + rhs: ids.shape().clone(), op: "embedding", } .bt())? } - let ids_shape = ids.shape(); - let seq_len = ids_shape.dims1()?; - let (_, hidden_size) = rhs.dims2()?; - let storage = ids - .storage() - .embedding(ids.layout(), &rhs.storage(), rhs.layout())?; - let shape: Shape = (seq_len, hidden_size).into(); - let op = BackpropOp::new2(ids, rhs, Op::Embedding); - Ok(from_storage(storage, shape, op, false)) + self.index_select(ids, 0) } pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> { |