summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/llama.rs
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2024-02-14 10:58:32 +0100
committer: GitHub <noreply@github.com> 2024-02-14 10:58:32 +0100
commit2d5f2a728d9ade10ce4b7b618ee4dba8075064dd (patch)
tree304d99d8330c116bea92c2997474311c199e579a /candle-transformers/src/models/llama.rs
parent68f76558956f7f56cb5014bb5f7c7c5534436b72 (diff)
downloadcandle-2d5f2a728d9ade10ce4b7b618ee4dba8075064dd.tar.gz
candle-2d5f2a728d9ade10ce4b7b618ee4dba8075064dd.tar.bz2
candle-2d5f2a728d9ade10ce4b7b618ee4dba8075064dd.zip
Add the RWKV model (v5). (#1707)
* Start adding the RWKV model.
* More of the forward step.
* Handle rescaling.
* FeedForward.
* More work on RWKV.
* Better state tracking.
* Finish a first pass on forward.
* Fix the shape mismatches.
* Do not rescale in f32.
* Rename to rwkv-v5.
* Add the new models to the readme.
Diffstat (limited to 'candle-transformers/src/models/llama.rs')
-rw-r--r--candle-transformers/src/models/llama.rs3
1 file changed, 1 insertion, 2 deletions
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 7a920cb8..f8126394 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -1,13 +1,12 @@
use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
-use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
pub const MAX_SEQ_LEN: usize = 4096;
-#[derive(Debug, Clone, Deserialize)]
+#[derive(Debug, Clone, serde::Deserialize)]
pub struct LlamaConfig {
pub hidden_size: usize,
pub intermediate_size: usize,