summary | refs | log | tree | commit | diff
path: root/candle-transformers
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com>, 2024-05-23 17:07:21 +0200
committer: GitHub <noreply@github.com>, 2024-05-23 17:07:21 +0200
commit: 45e235a7473d473df5c1e50f55504a97e28be822 (patch)
tree: 6e8518249f1bbbc634431327b96d24f3e270ebc5 /candle-transformers
parent: 31cf64147b9ab4a3d68849bef0ea59bdb0c113d6 (diff)
download: candle-45e235a7473d473df5c1e50f55504a97e28be822.tar.gz
candle-45e235a7473d473df5c1e50f55504a97e28be822.tar.bz2
candle-45e235a7473d473df5c1e50f55504a97e28be822.zip
Simplify the KvCache api. (#2207)
Diffstat (limited to 'candle-transformers')
-rw-r--r-- candle-transformers/src/models/quantized_phi3.rs | 8
1 file changed, 1 insertion, 7 deletions
diff --git a/candle-transformers/src/models/quantized_phi3.rs b/candle-transformers/src/models/quantized_phi3.rs
index a1161722..f9b55d9d 100644
--- a/candle-transformers/src/models/quantized_phi3.rs
+++ b/candle-transformers/src/models/quantized_phi3.rs
@@ -203,7 +203,6 @@ fn precomput_freqs_cis(
impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
- batch_size: usize,
use_flash_attn: bool,
ct: gguf_file::Content,
reader: &mut R,
@@ -252,12 +251,7 @@ impl ModelWeights {
)?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
- let kv_cache = KvCache::new(
- 2,
- (batch_size, head_count_kv, max_seq_len, head_dim),
- DType::F32,
- device,
- )?;
+ let kv_cache = KvCache::new(2, max_seq_len);
layers.push(LayerWeights {
attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,