diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-05-23 17:07:21 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-23 17:07:21 +0200 |
commit | 45e235a7473d473df5c1e50f55504a97e28be822 (patch) | |
tree | 6e8518249f1bbbc634431327b96d24f3e270ebc5 /candle-transformers | |
parent | 31cf64147b9ab4a3d68849bef0ea59bdb0c113d6 (diff) | |
download | candle-45e235a7473d473df5c1e50f55504a97e28be822.tar.gz candle-45e235a7473d473df5c1e50f55504a97e28be822.tar.bz2 candle-45e235a7473d473df5c1e50f55504a97e28be822.zip |
Simplify the KvCache api. (#2207)
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/quantized_phi3.rs | 8 |
1 file changed, 1 insertion, 7 deletions
```diff
diff --git a/candle-transformers/src/models/quantized_phi3.rs b/candle-transformers/src/models/quantized_phi3.rs
index a1161722..f9b55d9d 100644
--- a/candle-transformers/src/models/quantized_phi3.rs
+++ b/candle-transformers/src/models/quantized_phi3.rs
@@ -203,7 +203,6 @@ fn precomput_freqs_cis(
 
 impl ModelWeights {
     pub fn from_gguf<R: std::io::Seek + std::io::Read>(
-        batch_size: usize,
         use_flash_attn: bool,
         ct: gguf_file::Content,
         reader: &mut R,
@@ -252,12 +251,7 @@ impl ModelWeights {
             )?;
             let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
             let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
-            let kv_cache = KvCache::new(
-                2,
-                (batch_size, head_count_kv, max_seq_len, head_dim),
-                DType::F32,
-                device,
-            )?;
+            let kv_cache = KvCache::new(2, max_seq_len);
             layers.push(LayerWeights {
                 attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
                 attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
```