diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-13 11:04:40 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-13 11:04:40 +0100 |
commit | 50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb (patch) | |
tree | c48c4ecc686748e10b678d347af8d46cb0955a6c /candle-examples/examples/bert | |
parent | a3663ce2f2b03263075099baed677340974b7f4c (diff) | |
download | candle-50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb.tar.gz candle-50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb.tar.bz2 candle-50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb.zip |
Tensor mutability (#154)
* Working towards tensor mutability.
* Use a ref-cell to provide tensor mutability.
Diffstat (limited to 'candle-examples/examples/bert')
-rw-r--r-- | candle-examples/examples/bert/main.rs | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index d8f6921e..1c3c429b 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -196,7 +196,7 @@ impl BertEmbeddings { if let Some(position_embeddings) = &self.position_embeddings { // TODO: Proper absolute positions? let position_ids = (0..seq_len as u32).collect::<Vec<_>>(); - let position_ids = Tensor::new(&position_ids[..], &input_ids.device())?; + let position_ids = Tensor::new(&position_ids[..], input_ids.device())?; embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)? } let embeddings = self.layer_norm.forward(&embeddings)?; |