summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs4
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 93846160..f64bd6f2 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -481,10 +481,10 @@ impl Tensor {
}
let ids_shape = ids.shape();
let seq_len = ids_shape.r1()?;
- let (vocab_size, hidden_size) = rhs.shape().r2()?;
+ let (_, hidden_size) = rhs.shape().r2()?;
let storage = ids
.storage
- .embedding(ids.layout(), &rhs.storage, hidden_size, vocab_size)?;
+ .embedding(ids.layout(), &rhs.storage, rhs.layout())?;
let shape: Shape = (seq_len, hidden_size).into();
let op = if ids.track_op() || rhs.track_op() {
Some(Op::Embedding(ids.clone(), rhs.clone()))