Diffstat (limited to 'candle-transformers/src/models/quantized_llama.rs')
-rw-r--r--  candle-transformers/src/models/quantized_llama.rs | 41 +++++++++++++++++++++++++----------------
1 file changed, 25 insertions(+), 16 deletions(-)
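This patch threads an explicit `device: &Device` parameter through `ModelWeights::from_gguf`: every `ct.tensor(...)` call now receives the target device, so the caller decides where each quantized tensor is loaded. The token embeddings are still dequantized on the CPU, as before.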
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 1fb2d9e2..8aa06088 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -356,6 +356,7 @@ impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
+ device: &Device,
) -> Result<Self> {
let cpu = &Device::Cpu;
let md_get = |s: &str| match ct.metadata.get(s) {
@@ -383,21 +384,28 @@ impl ModelWeights {
.unwrap_or(10000f32);
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
- let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
+ let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
- let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
- let output = ct.tensor(reader, "output.weight")?;
+ let norm = RmsNorm::new(
+ ct.tensor(reader, "output_norm.weight", device)?,
+ rms_norm_eps,
+ )?;
+ let output = ct.tensor(reader, "output.weight", device)?;
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
let prefix = format!("blk.{layer_idx}");
- let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
- let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
- let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
- let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
+ let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
+ let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
+ let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
+ let attention_wo =
+ ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
let mlp_or_moe = if n_expert <= 1 {
- let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
- let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
- let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
+ let feed_forward_w1 =
+ ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
+ let feed_forward_w2 =
+ ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
+ let feed_forward_w3 =
+ ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
MlpOrMoe::Mlp(Mlp {
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
@@ -405,15 +413,15 @@ impl ModelWeights {
})
} else {
let feed_forward_gate_inp =
- ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"))?;
+ ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
let mut experts = Vec::with_capacity(n_expert);
for i in 0..n_expert {
let feed_forward_w1 =
- ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"))?;
+ ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
let feed_forward_w2 =
- ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"))?;
+ ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
let feed_forward_w3 =
- ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"))?;
+ ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
experts.push(Mlp {
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
@@ -426,8 +434,9 @@ impl ModelWeights {
experts,
}
};
- let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
- let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
+ let attention_norm =
+ ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
+ let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
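For reference, a call site updated for the new signature could look like the minimal sketch below. The crate paths assume the `candle` alias this workspace uses for the `candle-core` package; the GGUF path is a placeholder, not part of this patch.

use candle::quantized::gguf_file;
use candle::{Device, Result};
use candle_transformers::models::quantized_llama::ModelWeights;

fn load_model() -> Result<ModelWeights> {
    // Placeholder path; point this at a real GGUF checkpoint.
    let mut file = std::fs::File::open("model.gguf")?;
    let content = gguf_file::Content::read(&mut file)?;
    // New in this patch: the target device is passed explicitly and
    // forwarded to each ct.tensor(...) call inside from_gguf.
    let device = Device::Cpu; // or e.g. Device::new_cuda(0)? with the cuda feature
    ModelWeights::from_gguf(content, &mut file, &device)
}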