summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2023-06-28 15:43:03 +0100
committerlaurent <laurent.mazare@gmail.com>2023-06-28 15:43:03 +0100
commit3f0d9fbb257baf94acde184de76eb9667e0fa025 (patch)
tree9bd3217971362a991faac24968f9bf77bf663476 /candle-core/src/tensor.rs
parentcca699be6c8167f565067ceb3c940dd3c1d87503 (diff)
downloadcandle-3f0d9fbb257baf94acde184de76eb9667e0fa025.tar.gz
candle-3f0d9fbb257baf94acde184de76eb9667e0fa025.tar.bz2
candle-3f0d9fbb257baf94acde184de76eb9667e0fa025.zip
Adapt the cuda bits.
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs4
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 93846160..f64bd6f2 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -481,10 +481,10 @@ impl Tensor {
}
let ids_shape = ids.shape();
let seq_len = ids_shape.r1()?;
- let (vocab_size, hidden_size) = rhs.shape().r2()?;
+ let (_, hidden_size) = rhs.shape().r2()?;
let storage = ids
.storage
- .embedding(ids.layout(), &rhs.storage, hidden_size, vocab_size)?;
+ .embedding(ids.layout(), &rhs.storage, rhs.layout())?;
let shape: Shape = (seq_len, hidden_size).into();
let op = if ids.track_op() || rhs.track_op() {
Some(Op::Embedding(ids.clone(), rhs.clone()))