diff options
Diffstat (limited to 'candle-transformers/src/models/quantized_llama.rs')
-rw-r--r-- | candle-transformers/src/models/quantized_llama.rs | 28 |
1 files changed, 15 insertions, 13 deletions
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index 8aa06088..eb4136f6 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -16,7 +16,7 @@ struct RmsNorm { impl RmsNorm { fn new(scale: QTensor, eps: f32) -> Result<Self> { let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); - let scale = scale.dequantize(&Device::Cpu)?; + let scale = scale.dequantize(&scale.device())?; let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64); Ok(Self { inner, span }) } @@ -275,13 +275,17 @@ pub struct ModelWeights { span_output: tracing::Span, } -fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> { +fn precomput_freqs_cis( + head_dim: usize, + freq_base: f32, + device: &Device, +) -> Result<(Tensor, Tensor)> { let theta: Vec<_> = (0..head_dim) .step_by(2) .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) .collect(); - let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?; - let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)? + let theta = Tensor::new(theta.as_slice(), device)?; + let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? .to_dtype(DType::F32)? .reshape((MAX_SEQ_LEN, 1))? .matmul(&theta.reshape((1, theta.elem_count()))?)?; @@ -292,11 +296,10 @@ fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tenso impl ModelWeights { pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> { - let cpu = &Device::Cpu; let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; - let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?; + let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?; let tok_embeddings = ct.remove("tok_embeddings.weight")?; - let tok_embeddings = tok_embeddings.dequantize(cpu)?; + let tok_embeddings = tok_embeddings.dequantize(&ct.device)?; let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?; let output = ct.remove("output.weight")?; let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize); @@ -358,7 +361,6 @@ impl ModelWeights { reader: &mut R, device: &Device, ) -> Result<Self> { - let cpu = &Device::Cpu; let md_get = |s: &str| match ct.metadata.get(s) { None => candle::bail!("cannot find {s} in metadata"), Some(v) => Ok(v), @@ -382,10 +384,10 @@ impl ModelWeights { let rope_freq_base = md_get("llama.rope.freq_base") .and_then(|m| m.to_f32()) .unwrap_or(10000f32); - let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?; + let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?; let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; - let tok_embeddings = tok_embeddings.dequantize(cpu)?; + let tok_embeddings = tok_embeddings.dequantize(device)?; let norm = RmsNorm::new( ct.tensor(reader, "output_norm.weight", device)?, rms_norm_eps, @@ -472,14 +474,14 @@ impl ModelWeights { }) } - fn mask(&mut self, t: usize) -> Result<Tensor> { + fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> { if let Some(mask) = self.masks.get(&t) { Ok(mask.clone()) } else { let mask: Vec<_> = (0..t) .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) .collect(); - let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?; + let mask = Tensor::from_slice(&mask, (t, t), device)?; self.masks.insert(t, mask.clone()); Ok(mask) } @@ -487,7 +489,7 @@ impl ModelWeights { pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> { let (_b_sz, seq_len) = x.dims2()?; - let mask = self.mask(seq_len)?; + let mask = self.mask(seq_len, x.device())?; let _enter = self.span.enter(); let mut layer_in = self.tok_embeddings.forward(x)?; for layer in self.layers.iter_mut() { |