author     laurent <laurent.mazare@gmail.com>  2023-06-29 19:07:52 +0100
committer  laurent <laurent.mazare@gmail.com>  2023-06-29 19:07:52 +0100
commit     b50bd880ce472d7c20d09d6e5c7f49fcdf95f8db (patch)
tree       3dc788d3a1f1c7dda9dbb3a931a54f5a48df4cf6
parent     3232df9458e41c7414d51459b23e493b75a3949c (diff)
download   candle-b50bd880ce472d7c20d09d6e5c7f49fcdf95f8db.tar.gz
           candle-b50bd880ce472d7c20d09d6e5c7f49fcdf95f8db.tar.bz2
           candle-b50bd880ce472d7c20d09d6e5c7f49fcdf95f8db.zip
Only narrow when needed + deactivate the kv cache.
-rw-r--r--  candle-core/examples/llama/main.rs |  8
-rw-r--r--  candle-core/src/error.rs           |  8
-rw-r--r--  candle-core/src/tensor.rs          | 41
3 files changed, 41 insertions(+), 16 deletions(-)
diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index 9d70921c..5a8a15d3 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -24,7 +24,7 @@ mod var_store;
 mod weights;
 
 const CONTEXT_SIZE: usize = 512;
-const USE_KV_CACHE: bool = true;
+const USE_KV_CACHE: bool = false;
 const START_PROMPT: &str = r"
 EDWARD:
 I wonder how our princely father 'scaped,
@@ -268,7 +268,11 @@ impl CausalSelfAttention {
 
     fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
         let mut dims = x.dims().to_vec();
-        let freqs_cis = freqs_cis.narrow(1, freqs_cis.dims()[1] - dims[1], dims[1])?;
+        let freqs_cis = if dims[1] < CONTEXT_SIZE {
+            freqs_cis.narrow(1, CONTEXT_SIZE - dims[1], dims[1])?
+        } else {
+            freqs_cis.clone()
+        };
         let v = dims.pop().unwrap();
         dims.push(v / 2);
         dims.push(2);
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 83d3e66d..637fd8b7 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -10,6 +10,14 @@ pub enum Error {
         got: DType,
     },
 
+    #[error("invalid args for narrow: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
+    NarrowInvalidArgs {
+        shape: Shape,
+        dim: usize,
+        start: usize,
+        len: usize,
+    },
+
     #[error("{op} only supports contiguous tensors")]
     RequiresContiguous { op: &'static str },
 
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 6586834c..2f05094b 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -349,21 +349,34 @@ impl Tensor {
     }
 
     /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
-    /// ranges from `start` to `start + length`.
-    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
-        let op = if self.track_op() {
-            Some(Op::Narrow(self.clone(), dim, start, length))
+    /// ranges from `start` to `start + len`.
+    pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
+        let dims = self.dims();
+        if dim >= dims.len() || start + len > dims[dim] {
+            Err(Error::NarrowInvalidArgs {
+                shape: self.shape().clone(),
+                dim,
+                start,
+                len,
+            })?
+        }
+        if start == 0 && dims[dim] == len {
+            Ok(self.clone())
         } else {
-            None
-        };
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage: self.storage.clone(),
-            layout: self.layout().narrow(dim, start, length)?,
-            op,
-            is_variable: false,
-        };
-        Ok(Tensor(Arc::new(tensor_)))
+            let op = if self.track_op() {
+                Some(Op::Narrow(self.clone(), dim, start, len))
+            } else {
+                None
+            };
+            let tensor_ = Tensor_ {
+                id: TensorId::new(),
+                storage: self.storage.clone(),
+                layout: self.layout().narrow(dim, start, len)?,
+                op,
+                is_variable: false,
+            };
+            Ok(Tensor(Arc::new(tensor_)))
+        }
     }
 
     pub fn softmax(&self, dim: usize) -> Result<Self> {
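The narrow change above does two things: it validates the arguments eagerly (surfacing Error::NarrowInvalidArgs instead of a later layout failure) and short-circuits when the requested range already covers the whole dimension. As a quick illustration of the new caller-facing behavior, here is a minimal sketch. It is not part of the commit, and it assumes the crate is imported as candle_core and that a Tensor::zeros(shape, dtype, device) constructor exists with this signature; both may differ at this point in the tree.

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // Hypothetical setup (assumed constructor): a (4, 512) zero tensor on the CPU.
    let t = Tensor::zeros((4, 512), DType::F32, &Device::Cpu)?;

    // A strict sub-range still builds a narrowed view over the shared storage.
    let sub = t.narrow(1, 0, 128)?;
    assert_eq!(sub.dims(), &[4, 128]);

    // Requesting the full range is now a cheap clone: no new layout, no op node.
    let full = t.narrow(1, 0, 512)?;
    assert_eq!(full.dims(), t.dims());

    // Invalid ranges now fail eagerly with Error::NarrowInvalidArgs
    // (500 + 64 exceeds the size of dimension 1).
    assert!(t.narrow(1, 500, 64).is_err());
    Ok(())
}

The same idea drives the llama example change: when dims[1] equals CONTEXT_SIZE, narrowing freqs_cis would select the whole dimension anyway, so the call site now guards with an explicit freqs_cis.clone() instead of paying for a redundant narrow.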