summaryrefslogtreecommitdiff
path: root/candle-core/examples
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2023-06-27 17:37:09 +0100
committerlaurent <laurent.mazare@gmail.com>2023-06-27 17:37:09 +0100
commitc44e5346f40f3825c60dc8cab113867753916400 (patch)
tree1af042500193773b6c42348a47ef52ffb9d0cb7a /candle-core/examples
parentefc39b71c577f4e4ab067cc4301aa08cdb7c1bb1 (diff)
downloadcandle-c44e5346f40f3825c60dc8cab113867753916400.tar.gz
candle-c44e5346f40f3825c60dc8cab113867753916400.tar.bz2
candle-c44e5346f40f3825c60dc8cab113867753916400.zip
Add some helper functions.
Diffstat (limited to 'candle-core/examples')
-rw-r--r--candle-core/examples/llama/main.rs7
1 files changed, 1 insertions, 6 deletions
diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index baf0cdb8..eb681f4b 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -306,12 +306,7 @@ impl CausalSelfAttention {
let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
let rope = Tensor::cat(&[&re, &im], rank - 1)?;
- // TODO: Add the flatten op.
- let mut dims = rope.dims().to_vec();
- let v1 = dims.pop().unwrap();
- let v2 = dims.pop().unwrap();
- dims.push(v1 * v2);
- let rope = rope.reshape(dims)?;
+ let rope = rope.flatten(Some(rope.rank() - 2), None)?;
Ok(rope)
}