diff options
author | laurent <laurent.mazare@gmail.com> | 2023-06-28 15:43:03 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2023-06-28 15:43:03 +0100 |
commit | 3f0d9fbb257baf94acde184de76eb9667e0fa025 (patch) | |
tree | 9bd3217971362a991faac24968f9bf77bf663476 /candle-core/src/tensor.rs | |
parent | cca699be6c8167f565067ceb3c940dd3c1d87503 (diff) | |
download | candle-3f0d9fbb257baf94acde184de76eb9667e0fa025.tar.gz candle-3f0d9fbb257baf94acde184de76eb9667e0fa025.tar.bz2 candle-3f0d9fbb257baf94acde184de76eb9667e0fa025.zip |
Adapt the cuda bits.
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 93846160..f64bd6f2 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -481,10 +481,10 @@ impl Tensor { } let ids_shape = ids.shape(); let seq_len = ids_shape.r1()?; - let (vocab_size, hidden_size) = rhs.shape().r2()?; + let (_, hidden_size) = rhs.shape().r2()?; let storage = ids .storage - .embedding(ids.layout(), &rhs.storage, hidden_size, vocab_size)?; + .embedding(ids.layout(), &rhs.storage, rhs.layout())?; let shape: Shape = (seq_len, hidden_size).into(); let op = if ids.track_op() || rhs.track_op() { Some(Op::Embedding(ids.clone(), rhs.clone())) |