summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-04 23:30:10 +0200
committerGitHub <noreply@github.com>2024-04-04 23:30:10 +0200
commitc87381fc9643ca15648c2e8379e44a596ba1854b (patch)
tree681727d7f4b89a647ff50e740aacc21cdb618b5b
parentc5626b827147e5029c6bd3e37352ec8ac501cfc3 (diff)
downloadcandle-c87381fc9643ca15648c2e8379e44a596ba1854b.tar.gz
candle-c87381fc9643ca15648c2e8379e44a596ba1854b.tar.bz2
candle-c87381fc9643ca15648c2e8379e44a596ba1854b.zip
Use F16 for moondream on cuda. (#2013)
-rw-r--r--candle-examples/examples/moondream/main.rs12
-rw-r--r--candle-transformers/src/models/mixformer.rs13
2 files changed, 17 insertions, 8 deletions
diff --git a/candle-examples/examples/moondream/main.rs b/candle-examples/examples/moondream/main.rs
index bcc21337..dfd83037 100644
--- a/candle-examples/examples/moondream/main.rs
+++ b/candle-examples/examples/moondream/main.rs
@@ -283,6 +283,11 @@ async fn main() -> anyhow::Result<()> {
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let config = moondream::Config::v2();
+ let dtype = if device.is_cuda() && !args.quantized {
+ DType::F16
+ } else {
+ DType::F32
+ };
let model = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
&model_file,
@@ -291,15 +296,16 @@ async fn main() -> anyhow::Result<()> {
let model = quantized_moondream::Model::new(&config, vb)?;
Model::Quantized(model)
} else {
- let vb =
- unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let model = moondream::Model::new(&config, vb)?;
Model::Moondream(model)
};
println!("loaded the model in {:?}", start.elapsed());
let start = std::time::Instant::now();
- let image = load_image(args.image)?.to_device(&device)?;
+ let image = load_image(args.image)?
+ .to_device(&device)?
+ .to_dtype(dtype)?;
let image_embeds = image.unsqueeze(0)?;
let image_embeds = match model {
Model::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?,
diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index 65a1665a..de15c3a5 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -135,7 +135,9 @@ fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
- let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+ let on_true = Tensor::new(on_true, on_false.device())?
+ .to_dtype(on_false.dtype())?
+ .broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
@@ -147,7 +149,7 @@ struct RotaryEmbedding {
}
impl RotaryEmbedding {
- fn new(dim: usize, max_seq_len: usize, dev: &Device) -> Result<Self> {
+ fn new(dim: usize, max_seq_len: usize, dtype: DType, dev: &Device) -> Result<Self> {
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
@@ -159,8 +161,8 @@ impl RotaryEmbedding {
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
- sin: freqs.sin()?,
- cos: freqs.cos()?,
+ sin: freqs.sin()?.to_dtype(dtype)?,
+ cos: freqs.cos()?.to_dtype(dtype)?,
})
}
@@ -274,7 +276,8 @@ impl MHA {
let op_size = cfg.n_embd;
let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
- let rotary_emb = RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.device())?;
+ let rotary_emb =
+ RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.dtype(), vb.device())?;
let softmax_scale = 1f64 / (head_dim as f64).sqrt();
Ok(Self {
wqkv,