summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/llava/config.rs
diff options
context:
space:
mode:
authorEric Buehler <65165915+EricLBuehler@users.noreply.github.com>2024-07-26 15:32:26 -0400
committerGitHub <noreply@github.com>2024-07-26 21:32:26 +0200
commit0f5cbb08b36a2d962470ec590a2d2bd9770bd12d (patch)
treea5d6911051646e96fc833664b44c530c76fe4416 /candle-transformers/src/models/llava/config.rs
parentddafc61055601002622778b7762c15bd60057c1f (diff)
downloadcandle-0f5cbb08b36a2d962470ec590a2d2bd9770bd12d.tar.gz
candle-0f5cbb08b36a2d962470ec590a2d2bd9770bd12d.tar.bz2
candle-0f5cbb08b36a2d962470ec590a2d2bd9770bd12d.zip
Add support for Llama 3.1 (#2359)
* Add Llama 3.1 rope * Clippy * Format * Clippy * Add support for multiple eos tokens: * Untagged either * Remove either dep and fix settings.json * Make the max positional embeddings configurable
Diffstat (limited to 'candle-transformers/src/models/llava/config.rs')
-rw-r--r--candle-transformers/src/models/llava/config.rs6
1 files changed, 4 insertions, 2 deletions
diff --git a/candle-transformers/src/models/llava/config.rs b/candle-transformers/src/models/llava/config.rs
index d2d47003..5dca6870 100644
--- a/candle-transformers/src/models/llava/config.rs
+++ b/candle-transformers/src/models/llava/config.rs
@@ -2,7 +2,7 @@ use std::collections::HashMap;
use crate::models::{
clip::{text_model::Activation, vision_model::ClipVisionConfig},
- llama::Config,
+ llama::{Config, LlamaEosToks},
};
use serde::{Deserialize, Serialize};
@@ -73,8 +73,10 @@ impl LLaVAConfig {
rms_norm_eps: self.rms_norm_eps as f64,
rope_theta: self.rope_theta,
bos_token_id: Some(self.bos_token_id as u32),
- eos_token_id: Some(self.eos_token_id as u32),
+ eos_token_id: Some(LlamaEosToks::Single(self.eos_token_id as u32)),
use_flash_attn: false,
+ rope_scaling: None, // Assume we don't have LLaVA for Llama 3.1
+ max_position_embeddings: self.max_position_embeddings,
}
}
}