| author    | Laurent Mazare <laurent.mazare@gmail.com> | 2024-02-14 10:58:32 +0100 |
| committer | GitHub <noreply@github.com>               | 2024-02-14 10:58:32 +0100 |
| commit    | 2d5f2a728d9ade10ce4b7b618ee4dba8075064dd  | |
| tree      | 304d99d8330c116bea92c2997474311c199e579a  | candle-transformers/src/models/llama.rs |
| parent    | 68f76558956f7f56cb5014bb5f7c7c5534436b72  | |
Add the RWKV model (v5). (#1707)
* Start adding the RWKV model.
* More of the forward step.
* Handle rescaling.
* FeedForward.
* More work on RWKV.
* Better state tracking.
* Finish a first pass on forward.
* Fix the shape mismatches.
* Do not rescale in f32.
* Rename to rwkv-v5.
* Add the new models to the readme.
Diffstat (limited to 'candle-transformers/src/models/llama.rs')
-rw-r--r-- | candle-transformers/src/models/llama.rs | 3 |
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 7a920cb8..f8126394 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -1,13 +1,12 @@
 use super::with_tracing::{linear_no_bias as linear, Linear};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{embedding, Embedding, Module, VarBuilder};
-use serde::Deserialize;
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};

 pub const MAX_SEQ_LEN: usize = 4096;

-#[derive(Debug, Clone, Deserialize)]
+#[derive(Debug, Clone, serde::Deserialize)]
 pub struct LlamaConfig {
     pub hidden_size: usize,
     pub intermediate_size: usize,
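The llama.rs change above swaps `use serde::Deserialize;` plus `#[derive(Deserialize)]` for a single path-qualified `#[derive(serde::Deserialize)]`, dropping the import. Rust allows any derive macro to be named by its full path. A minimal sketch of the same pattern, using std's own derive macros rather than serde so it compiles without external crates (the field values are made up for illustration):

```rust
// Path-qualified derives: no `use` statements needed for Debug or Clone,
// mirroring how the commit qualifies `serde::Deserialize` in the attribute.
#[derive(std::fmt::Debug, std::clone::Clone)]
pub struct LlamaConfig {
    pub hidden_size: usize,
    pub intermediate_size: usize,
}

fn main() {
    let cfg = LlamaConfig {
        hidden_size: 4096,
        intermediate_size: 11008,
    };
    // Clone and Debug both work without importing either trait's macro.
    let copy = cfg.clone();
    println!("{:?}", copy);
}
```

Qualifying the derive keeps the module's import list free of names that are only used inside attributes, which is presumably why the commit makes this cleanup alongside the RWKV work.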