author     Nicolas Patry <patry.nicolas@protonmail.com>   2024-02-13 16:28:56 +0100
committer  GitHub <noreply@github.com>                    2024-02-13 16:28:56 +0100
commit     c1b418586c9477a85150ce6c15dcfe4c93d3a27d
tree       daabda27c6689f7ea9db68ceb8f6bd74ebf8f4b1
parent     ad73e93da2cf7311cb5c5bc39250aa335c5f9b76
Fixing quantized llama demo on metal. (#1703)
-rw-r--r--  candle-core/src/quantized/ggml_file.rs              3
-rw-r--r--  candle-core/src/quantized/metal.rs                  4
-rw-r--r--  candle-core/src/quantized/mod.rs                   12
-rw-r--r--  candle-transformers/src/models/quantized_llama.rs  28
4 files changed, 34 insertions, 13 deletions
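
Note: the common thread across these four files is that a quantized tensor now remembers the device it was loaded on, so dequantization, the RoPE cos/sin tables, and the causal mask can all be allocated on that device instead of being hard-coded to Device::Cpu. The following is a minimal, self-contained sketch of that pattern; every type in it (Device, QStorage, QTensor, the device_id field) is a simplified hypothetical stand-in for illustration, not the actual candle API.

// Toy sketch of the device-propagation pattern in this commit.
// All types here are hypothetical stand-ins, not candle's real ones.

#[derive(Clone, Debug, PartialEq)]
enum Device {
    Cpu,
    Metal(usize), // stand-in for a MetalDevice handle
}

enum QStorage {
    Cpu(Vec<u8>),
    Metal { device_id: usize, buffer: Vec<u8> },
}

impl QStorage {
    // Mirrors the new QStorage::device(): recover the device from
    // whichever storage variant the quantized data actually lives in.
    fn device(&self) -> Device {
        match self {
            QStorage::Cpu(_) => Device::Cpu,
            QStorage::Metal { device_id, .. } => Device::Metal(*device_id),
        }
    }
}

struct QTensor {
    storage: QStorage,
}

impl QTensor {
    fn device(&self) -> Device {
        self.storage.device()
    }

    // Callers used to write `dequantize(&Device::Cpu)`; now they can ask
    // the tensor where it lives and stay on that device.
    fn dequantize(&self, device: &Device) -> Vec<f32> {
        assert_eq!(device, &self.device(), "sketch: no cross-device copies");
        vec![0.0; 8] // placeholder for the real dequantization kernel
    }
}

fn main() {
    let t = QTensor {
        storage: QStorage::Metal { device_id: 0, buffer: Vec::new() },
    };
    // The RmsNorm::new fix in miniature: dequantize on the tensor's device.
    let _scale = t.dequantize(&t.device());
}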
diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs
index 38238580..e6f5791c 100644
--- a/candle-core/src/quantized/ggml_file.rs
+++ b/candle-core/src/quantized/ggml_file.rs
@@ -233,6 +233,7 @@ pub struct Content {
     pub hparams: HParams,
     pub vocab: Vocab,
     pub tensors: HashMap<String, super::QTensor>,
+    pub device: Device,
 }
 
 impl Content {
@@ -252,11 +253,13 @@ impl Content {
             let (name, tensor) = read_one_tensor(reader, magic, device)?;
             tensors.insert(name, tensor);
         }
+        let device = device.clone();
         Ok(Self {
             magic,
             hparams,
             vocab,
             tensors,
+            device,
         })
     }
diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs
index fe57ce14..94105327 100644
--- a/candle-core/src/quantized/metal.rs
+++ b/candle-core/src/quantized/metal.rs
@@ -14,6 +14,10 @@ impl QMetalStorage {
         self.dtype
     }
 
+    pub fn device(&self) -> &MetalDevice {
+        &self.device
+    }
+
     pub fn buffer(&self) -> &Buffer {
         &self.buffer
     }
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 1dc5fe8f..366552d9 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -76,6 +76,14 @@ impl QStorage {
         }
     }
 
+    fn device(&self) -> Device {
+        match self {
+            QStorage::Cpu(_storage) => Device::Cpu,
+            #[cfg(feature = "metal")]
+            QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
+        }
+    }
+
     fn size_in_bytes(&self) -> usize {
         match self {
             QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
@@ -336,6 +344,10 @@ impl QTensor {
         self.storage.dtype()
     }
 
+    pub fn device(&self) -> Device {
+        self.storage.device()
+    }
+
     pub fn rank(&self) -> usize {
         self.shape.rank()
     }
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 8aa06088..eb4136f6 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -16,7 +16,7 @@ struct RmsNorm {
 impl RmsNorm {
     fn new(scale: QTensor, eps: f32) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let scale = scale.dequantize(&Device::Cpu)?;
+        let scale = scale.dequantize(&scale.device())?;
         let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
         Ok(Self { inner, span })
     }
@@ -275,13 +275,17 @@ pub struct ModelWeights {
     span_output: tracing::Span,
 }
 
-fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
+fn precomput_freqs_cis(
+    head_dim: usize,
+    freq_base: f32,
+    device: &Device,
+) -> Result<(Tensor, Tensor)> {
     let theta: Vec<_> = (0..head_dim)
         .step_by(2)
         .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
         .collect();
-    let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
-    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
+    let theta = Tensor::new(theta.as_slice(), device)?;
+    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
         .to_dtype(DType::F32)?
         .reshape((MAX_SEQ_LEN, 1))?
         .matmul(&theta.reshape((1, theta.elem_count()))?)?;
@@ -292,11 +296,10 @@ fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
 
 impl ModelWeights {
     pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
-        let cpu = &Device::Cpu;
         let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
-        let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
+        let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?;
         let tok_embeddings = ct.remove("tok_embeddings.weight")?;
-        let tok_embeddings = tok_embeddings.dequantize(cpu)?;
+        let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
         let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
         let output = ct.remove("output.weight")?;
         let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
@@ -358,7 +361,6 @@ impl ModelWeights {
         reader: &mut R,
         device: &Device,
     ) -> Result<Self> {
-        let cpu = &Device::Cpu;
         let md_get = |s: &str| match ct.metadata.get(s) {
             None => candle::bail!("cannot find {s} in metadata"),
             Some(v) => Ok(v),
@@ -382,10 +384,10 @@ impl ModelWeights {
         let rope_freq_base = md_get("llama.rope.freq_base")
             .and_then(|m| m.to_f32())
             .unwrap_or(10000f32);
-        let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
+        let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
 
         let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
-        let tok_embeddings = tok_embeddings.dequantize(cpu)?;
+        let tok_embeddings = tok_embeddings.dequantize(device)?;
         let norm = RmsNorm::new(
             ct.tensor(reader, "output_norm.weight", device)?,
             rms_norm_eps,
@@ -472,14 +474,14 @@ impl ModelWeights {
         })
     }
 
-    fn mask(&mut self, t: usize) -> Result<Tensor> {
+    fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
        if let Some(mask) = self.masks.get(&t) {
            Ok(mask.clone())
        } else {
            let mask: Vec<_> = (0..t)
                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
                .collect();
-            let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
+            let mask = Tensor::from_slice(&mask, (t, t), device)?;
            self.masks.insert(t, mask.clone());
            Ok(mask)
        }
    }
@@ -487,7 +489,7 @@ impl ModelWeights {
 
     pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (_b_sz, seq_len) = x.dims2()?;
-        let mask = self.mask(seq_len)?;
+        let mask = self.mask(seq_len, x.device())?;
         let _enter = self.span.enter();
         let mut layer_in = self.tok_embeddings.forward(x)?;
         for layer in self.layers.iter_mut() {
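
For reference, the fixed demo path can be exercised roughly as below. This is a sketch against the APIs visible in this patch (gguf_file::Content::read, ModelWeights::from_gguf taking a device) plus Device::new_metal, which I believe requires building candle with the metal feature; the model path is a placeholder and error handling is delegated to the anyhow crate.

use candle_core::quantized::gguf_file;
use candle_core::Device;
use candle_transformers::models::quantized_llama::ModelWeights;

fn main() -> anyhow::Result<()> {
    // Pick the Metal GPU when candle is built with `--features metal`,
    // otherwise fall back to the CPU.
    let device = Device::new_metal(0).unwrap_or(Device::Cpu);

    // Placeholder path: any llama-family GGUF checkpoint.
    let mut file = std::fs::File::open("model.gguf")?;
    let content = gguf_file::Content::read(&mut file)?;

    // After this patch the RoPE tables, token embeddings, and causal mask
    // are created on `device` rather than being pinned to the CPU.
    let _model = ModelWeights::from_gguf(content, &mut file, &device)?;
    Ok(())
}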