summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/falcon/model.rs
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-07-13 11:04:40 +0100
committer: GitHub <noreply@github.com> 2023-07-13 11:04:40 +0100
commit: 50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb (patch)
tree: c48c4ecc686748e10b678d347af8d46cb0955a6c /candle-examples/examples/falcon/model.rs
parent: a3663ce2f2b03263075099baed677340974b7f4c (diff)
download: candle-50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb.tar.gz
candle-50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb.tar.bz2
candle-50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb.zip
Tensor mutability (#154)
* Working towards tensor mutability.
* Use a ref-cell to provide tensor mutability.
Diffstat (limited to 'candle-examples/examples/falcon/model.rs')
-rw-r--r-- candle-examples/examples/falcon/model.rs 6
1 files changed, 3 insertions, 3 deletions
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
index 82c5d4b2..60821add 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-examples/examples/falcon/model.rs
@@ -183,7 +183,7 @@ impl FalconRotaryEmbedding {
past_kv_len: usize,
) -> Result<(Tensor, Tensor)> {
let (_batch, seq_len, _head_dim) = query.shape().r3()?;
- let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, &query.device(), query.dtype())?;
+ let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
let cos = cos.narrow(0, past_kv_len, seq_len)?;
let sin = sin.narrow(0, past_kv_len, seq_len)?;
let qs = (query.broadcast_mul(&cos)? + &rotate_half(query)?.broadcast_mul(&sin)?)?;
@@ -194,7 +194,7 @@ impl FalconRotaryEmbedding {
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
- let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
+ let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
@@ -471,7 +471,7 @@ impl Falcon {
Some((k, _)) => k.dim(1)?,
None => 0,
};
- let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(&input_ids.device())?;
+ let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?;
for block in self.blocks.iter_mut() {
hidden_state = block.forward(&hidden_state, &causal_mask, past_kv_len)?;
}