summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: laurent <laurent.mazare@gmail.com> 2023-06-29 19:07:52 +0100
committer: laurent <laurent.mazare@gmail.com> 2023-06-29 19:07:52 +0100
commit: b50bd880ce472d7c20d09d6e5c7f49fcdf95f8db (patch)
tree: 3dc788d3a1f1c7dda9dbb3a931a54f5a48df4cf6
parent: 3232df9458e41c7414d51459b23e493b75a3949c (diff)
download: candle-b50bd880ce472d7c20d09d6e5c7f49fcdf95f8db.tar.gz
candle-b50bd880ce472d7c20d09d6e5c7f49fcdf95f8db.tar.bz2
candle-b50bd880ce472d7c20d09d6e5c7f49fcdf95f8db.zip
Only narrow when needed + deactivate the kv cache.
-rw-r--r-- candle-core/examples/llama/main.rs 8
-rw-r--r-- candle-core/src/error.rs 8
-rw-r--r-- candle-core/src/tensor.rs 41
3 files changed, 41 insertions, 16 deletions
diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index 9d70921c..5a8a15d3 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -24,7 +24,7 @@ mod var_store;
mod weights;
const CONTEXT_SIZE: usize = 512;
-const USE_KV_CACHE: bool = true;
+const USE_KV_CACHE: bool = false;
const START_PROMPT: &str = r"
EDWARD:
I wonder how our princely father 'scaped,
@@ -268,7 +268,11 @@ impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
let mut dims = x.dims().to_vec();
- let freqs_cis = freqs_cis.narrow(1, freqs_cis.dims()[1] - dims[1], dims[1])?;
+ let freqs_cis = if dims[1] < CONTEXT_SIZE {
+ freqs_cis.narrow(1, CONTEXT_SIZE - dims[1], dims[1])?
+ } else {
+ freqs_cis.clone()
+ };
let v = dims.pop().unwrap();
dims.push(v / 2);
dims.push(2);
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 83d3e66d..637fd8b7 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -10,6 +10,14 @@ pub enum Error {
got: DType,
},
+ #[error("invalid args for narrow: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
+ NarrowInvalidArgs {
+ shape: Shape,
+ dim: usize,
+ start: usize,
+ len: usize,
+ },
+
#[error("{op} only supports contiguous tensors")]
RequiresContiguous { op: &'static str },
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 6586834c..2f05094b 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -349,21 +349,34 @@ impl Tensor {
}
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
- /// ranges from `start` to `start + length`.
- pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
- let op = if self.track_op() {
- Some(Op::Narrow(self.clone(), dim, start, length))
+ /// ranges from `start` to `start + len`.
+ pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
+ let dims = self.dims();
+ if dim >= dims.len() || start + len > dims[dim] {
+ Err(Error::NarrowInvalidArgs {
+ shape: self.shape().clone(),
+ dim,
+ start,
+ len,
+ })?
+ }
+ if start == 0 && dims[dim] == len {
+ Ok(self.clone())
} else {
- None
- };
- let tensor_ = Tensor_ {
- id: TensorId::new(),
- storage: self.storage.clone(),
- layout: self.layout().narrow(dim, start, length)?,
- op,
- is_variable: false,
- };
- Ok(Tensor(Arc::new(tensor_)))
+ let op = if self.track_op() {
+ Some(Op::Narrow(self.clone(), dim, start, len))
+ } else {
+ None
+ };
+ let tensor_ = Tensor_ {
+ id: TensorId::new(),
+ storage: self.storage.clone(),
+ layout: self.layout().narrow(dim, start, len)?,
+ op,
+ is_variable: false,
+ };
+ Ok(Tensor(Arc::new(tensor_)))
+ }
}
pub fn softmax(&self, dim: usize) -> Result<Self> {