summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorEric Buehler <65165915+EricLBuehler@users.noreply.github.com>2024-07-26 15:32:26 -0400
committerGitHub <noreply@github.com>2024-07-26 21:32:26 +0200
commit0f5cbb08b36a2d962470ec590a2d2bd9770bd12d (patch)
treea5d6911051646e96fc833664b44c530c76fe4416 /candle-nn
parentddafc61055601002622778b7762c15bd60057c1f (diff)
downloadcandle-0f5cbb08b36a2d962470ec590a2d2bd9770bd12d.tar.gz
candle-0f5cbb08b36a2d962470ec590a2d2bd9770bd12d.tar.bz2
candle-0f5cbb08b36a2d962470ec590a2d2bd9770bd12d.zip
Add support for Llama 3.1 (#2359)
* Add Llama 3.1 rope * Clippy * Format * Clippy * Add support for multiple eos tokens: * Untagged either * Remove either dep and fix settings.json * Make the max positional embeddings configurable
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/activation.rs6
-rw-r--r--candle-nn/src/var_builder.rs1
2 files changed, 4 insertions, 3 deletions
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index b9745375..fc1819f5 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -93,9 +93,9 @@ impl candle::Module for PReLU {
/// # Arguments
///
/// * `num_channels` - The number of channels. Use `None` to have as single trainable value and
-/// `Some` for a 1D vector with the appropriate number of channels. When applying the `forward`
-/// function, the input tensor shape `s` should either be one dimension with this number of
-/// channels or if `s.len() >= 2` it should have `s[1]` equal to this number.
+/// `Some` for a 1D vector with the appropriate number of channels. When applying the `forward`
+/// function, the input tensor shape `s` should either be one dimension with this number of
+/// channels or if `s.len() >= 2` it should have `s[1]` equal to this number.
pub fn prelu(num_channels: Option<usize>, vs: crate::VarBuilder) -> Result<PReLU> {
let init_ws = crate::init::Init::Const(0.25);
// When using a scalar weight, the PyTorch encoding is to use a 1d vector of length 1.
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index d6f6214f..f6e6160b 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -264,6 +264,7 @@ impl SimpleBackend for VarMap {
}
}
+#[allow(dead_code)]
pub struct SafeTensorWithRouting<'a> {
routing: HashMap<String, usize>,
safetensors: Vec<SafeTensors<'a>>,