summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2024-02-13 16:28:56 +0100
committerGitHub <noreply@github.com>2024-02-13 16:28:56 +0100
commitc1b418586c9477a85150ce6c15dcfe4c93d3a27d (patch)
treedaabda27c6689f7ea9db68ceb8f6bd74ebf8f4b1
parentad73e93da2cf7311cb5c5bc39250aa335c5f9b76 (diff)
downloadcandle-c1b418586c9477a85150ce6c15dcfe4c93d3a27d.tar.gz
candle-c1b418586c9477a85150ce6c15dcfe4c93d3a27d.tar.bz2
candle-c1b418586c9477a85150ce6c15dcfe4c93d3a27d.zip
Fixing quantized llama demo on metal. (#1703)
-rw-r--r--candle-core/src/quantized/ggml_file.rs3
-rw-r--r--candle-core/src/quantized/metal.rs4
-rw-r--r--candle-core/src/quantized/mod.rs12
-rw-r--r--candle-transformers/src/models/quantized_llama.rs28
4 files changed, 34 insertions, 13 deletions
diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs
index 38238580..e6f5791c 100644
--- a/candle-core/src/quantized/ggml_file.rs
+++ b/candle-core/src/quantized/ggml_file.rs
@@ -233,6 +233,7 @@ pub struct Content {
pub hparams: HParams,
pub vocab: Vocab,
pub tensors: HashMap<String, super::QTensor>,
+ pub device: Device,
}
impl Content {
@@ -252,11 +253,13 @@ impl Content {
let (name, tensor) = read_one_tensor(reader, magic, device)?;
tensors.insert(name, tensor);
}
+ let device = device.clone();
Ok(Self {
magic,
hparams,
vocab,
tensors,
+ device,
})
}
diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs
index fe57ce14..94105327 100644
--- a/candle-core/src/quantized/metal.rs
+++ b/candle-core/src/quantized/metal.rs
@@ -14,6 +14,10 @@ impl QMetalStorage {
self.dtype
}
+ pub fn device(&self) -> &MetalDevice {
+ &self.device
+ }
+
pub fn buffer(&self) -> &Buffer {
&self.buffer
}
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 1dc5fe8f..366552d9 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -76,6 +76,14 @@ impl QStorage {
}
}
+ fn device(&self) -> Device {
+ match self {
+ QStorage::Cpu(_storage) => Device::Cpu,
+ #[cfg(feature = "metal")]
+ QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
+ }
+ }
+
fn size_in_bytes(&self) -> usize {
match self {
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
@@ -336,6 +344,10 @@ impl QTensor {
self.storage.dtype()
}
+ pub fn device(&self) -> Device {
+ self.storage.device()
+ }
+
pub fn rank(&self) -> usize {
self.shape.rank()
}
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 8aa06088..eb4136f6 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -16,7 +16,7 @@ struct RmsNorm {
impl RmsNorm {
fn new(scale: QTensor, eps: f32) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
- let scale = scale.dequantize(&Device::Cpu)?;
+ let scale = scale.dequantize(&scale.device())?;
let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
Ok(Self { inner, span })
}
@@ -275,13 +275,17 @@ pub struct ModelWeights {
span_output: tracing::Span,
}
-fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
+fn precomput_freqs_cis(
+ head_dim: usize,
+ freq_base: f32,
+ device: &Device,
+) -> Result<(Tensor, Tensor)> {
let theta: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
.collect();
- let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
- let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
+ let theta = Tensor::new(theta.as_slice(), device)?;
+ let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
@@ -292,11 +296,10 @@ fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tenso
impl ModelWeights {
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
- let cpu = &Device::Cpu;
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
- let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
+ let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?;
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
- let tok_embeddings = tok_embeddings.dequantize(cpu)?;
+ let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
let output = ct.remove("output.weight")?;
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
@@ -358,7 +361,6 @@ impl ModelWeights {
reader: &mut R,
device: &Device,
) -> Result<Self> {
- let cpu = &Device::Cpu;
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
@@ -382,10 +384,10 @@ impl ModelWeights {
let rope_freq_base = md_get("llama.rope.freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(10000f32);
- let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
+ let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
- let tok_embeddings = tok_embeddings.dequantize(cpu)?;
+ let tok_embeddings = tok_embeddings.dequantize(device)?;
let norm = RmsNorm::new(
ct.tensor(reader, "output_norm.weight", device)?,
rms_norm_eps,
@@ -472,14 +474,14 @@ impl ModelWeights {
})
}
- fn mask(&mut self, t: usize) -> Result<Tensor> {
+ fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) {
Ok(mask.clone())
} else {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
- let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
+ let mask = Tensor::from_slice(&mask, (t, t), device)?;
self.masks.insert(t, mask.clone());
Ok(mask)
}
@@ -487,7 +489,7 @@ impl ModelWeights {
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?;
- let mask = self.mask(seq_len)?;
+ let mask = self.mask(seq_len, x.device())?;
let _enter = self.span.enter();
let mut layer_in = self.tok_embeddings.forward(x)?;
for layer in self.layers.iter_mut() {