diff options
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 93846160..f64bd6f2 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -481,10 +481,10 @@ impl Tensor { } let ids_shape = ids.shape(); let seq_len = ids_shape.r1()?; - let (vocab_size, hidden_size) = rhs.shape().r2()?; + let (_, hidden_size) = rhs.shape().r2()?; let storage = ids .storage - .embedding(ids.layout(), &rhs.storage, hidden_size, vocab_size)?; + .embedding(ids.layout(), &rhs.storage, rhs.layout())?; let shape: Shape = (seq_len, hidden_size).into(); let op = if ids.track_op() || rhs.track_op() { Some(Op::Embedding(ids.clone(), rhs.clone())) |