diff options
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/quantized_llama.rs | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index 6b326fbe..20363aea 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -351,13 +351,16 @@ impl ModelWeights { let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?; let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?; - let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; - let tok_embeddings = tok_embeddings.dequantize(device)?; + let tok_embeddings_q = ct.tensor(reader, "token_embd.weight", device)?; + let tok_embeddings = tok_embeddings_q.dequantize(device)?; let norm = RmsNorm::from_qtensor( ct.tensor(reader, "output_norm.weight", device)?, rms_norm_eps, )?; - let output = ct.tensor(reader, "output.weight", device)?; + let output = match ct.tensor(reader, "output.weight", device) { + Ok(tensor) => tensor, + Err(_) => tok_embeddings_q, + }; let mut layers = Vec::with_capacity(block_count); for layer_idx in 0..block_count { let prefix = format!("blk.{layer_idx}"); |