diff options
author | laurent <laurent.mazare@gmail.com> | 2023-06-27 17:37:09 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2023-06-27 17:37:09 +0100 |
commit | c44e5346f40f3825c60dc8cab113867753916400 (patch) | |
tree | 1af042500193773b6c42348a47ef52ffb9d0cb7a /candle-core/examples | |
parent | efc39b71c577f4e4ab067cc4301aa08cdb7c1bb1 (diff) | |
download | candle-c44e5346f40f3825c60dc8cab113867753916400.tar.gz candle-c44e5346f40f3825c60dc8cab113867753916400.tar.bz2 candle-c44e5346f40f3825c60dc8cab113867753916400.zip |
Add some helper functions.
Diffstat (limited to 'candle-core/examples')
-rw-r--r-- | candle-core/examples/llama/main.rs | 7 |
1 file changed, 1 insertion(+), 6 deletions(-)
diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index baf0cdb8..eb681f4b 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -306,12 +306,7 @@ impl CausalSelfAttention {
         let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
         let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
         let rope = Tensor::cat(&[&re, &im], rank - 1)?;
-        // TODO: Add the flatten op.
-        let mut dims = rope.dims().to_vec();
-        let v1 = dims.pop().unwrap();
-        let v2 = dims.pop().unwrap();
-        dims.push(v1 * v2);
-        let rope = rope.reshape(dims)?;
+        let rope = rope.flatten(Some(rope.rank() - 2), None)?;
         Ok(rope)
     }