summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama2-c/model.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/llama2-c/model.rs')
-rw-r--r--candle-examples/examples/llama2-c/model.rs8
1 files changed, 4 insertions, 4 deletions
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
index 9b982ddd..07a6e2f2 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-examples/examples/llama2-c/model.rs
@@ -36,9 +36,9 @@ pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
pub use_kv_cache: bool,
#[allow(clippy::type_complexity)]
- kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
- cos: Tensor,
- sin: Tensor,
+ pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
+ pub cos: Tensor,
+ pub sin: Tensor,
device: Device,
}
@@ -75,7 +75,7 @@ impl Cache {
})
}
- fn mask(&self, t: usize) -> Result<Tensor> {
+ pub fn mask(&self, t: usize) -> Result<Tensor> {
let mut masks = self.masks.lock().unwrap();
if let Some(mask) = masks.get(&t) {
Ok(mask.clone())