diff options
Diffstat (limited to 'candle-transformers/src/models/quantized_llama.rs')
-rw-r--r-- | candle-transformers/src/models/quantized_llama.rs | 41 |
1 file changed, 25 insertions, 16 deletions
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index 1fb2d9e2..8aa06088 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -356,6 +356,7 @@ impl ModelWeights { pub fn from_gguf<R: std::io::Seek + std::io::Read>( ct: gguf_file::Content, reader: &mut R, + device: &Device, ) -> Result<Self> { let cpu = &Device::Cpu; let md_get = |s: &str| match ct.metadata.get(s) { @@ -383,21 +384,28 @@ impl ModelWeights { .unwrap_or(10000f32); let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?; - let tok_embeddings = ct.tensor(reader, "token_embd.weight")?; + let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; let tok_embeddings = tok_embeddings.dequantize(cpu)?; - let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?; - let output = ct.tensor(reader, "output.weight")?; + let norm = RmsNorm::new( + ct.tensor(reader, "output_norm.weight", device)?, + rms_norm_eps, + )?; + let output = ct.tensor(reader, "output.weight", device)?; let mut layers = Vec::with_capacity(block_count); for layer_idx in 0..block_count { let prefix = format!("blk.{layer_idx}"); - let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?; - let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?; - let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?; - let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?; + let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?; + let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?; + let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?; + let attention_wo = + ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?; let mlp_or_moe = if n_expert <= 1 { - let feed_forward_w1 = ct.tensor(reader, 
&format!("{prefix}.ffn_gate.weight"))?; - let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?; - let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?; + let feed_forward_w1 = + ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?; + let feed_forward_w2 = + ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?; + let feed_forward_w3 = + ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?; MlpOrMoe::Mlp(Mlp { feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, @@ -405,15 +413,15 @@ impl ModelWeights { }) } else { let feed_forward_gate_inp = - ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"))?; + ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?; let mut experts = Vec::with_capacity(n_expert); for i in 0..n_expert { let feed_forward_w1 = - ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"))?; + ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?; let feed_forward_w2 = - ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"))?; + ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?; let feed_forward_w3 = - ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"))?; + ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?; experts.push(Mlp { feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, @@ -426,8 +434,9 @@ impl ModelWeights { experts, } }; - let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?; - let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?; + let attention_norm = + ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?; + let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?; let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); let span_rot = 
tracing::span!(tracing::Level::TRACE, "attn-rot"); let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); |