summaryrefslogtreecommitdiff
path: root/candle-examples/examples/bert
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-13 11:04:40 +0100
committerGitHub <noreply@github.com>2023-07-13 11:04:40 +0100
commit50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb (patch)
treec48c4ecc686748e10b678d347af8d46cb0955a6c /candle-examples/examples/bert
parenta3663ce2f2b03263075099baed677340974b7f4c (diff)
downloadcandle-50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb.tar.gz
candle-50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb.tar.bz2
candle-50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb.zip
Tensor mutability (#154)
* Working towards tensor mutability. * Use a ref-cell to provide tensor mutability.
Diffstat (limited to 'candle-examples/examples/bert')
-rw-r--r--candle-examples/examples/bert/main.rs2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index d8f6921e..1c3c429b 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -196,7 +196,7 @@ impl BertEmbeddings {
if let Some(position_embeddings) = &self.position_embeddings {
// TODO: Proper absolute positions?
let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
- let position_ids = Tensor::new(&position_ids[..], &input_ids.device())?;
+ let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
}
let embeddings = self.layer_norm.forward(&embeddings)?;