diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-02-25 20:50:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-25 20:50:08 +0100 |
commit | 1a6043af5123bf9e189063d3baf110b39cf47617 (patch) | |
tree | 3400ac112e92d7d83a0b98a1c66ae046fbbf82df /candle-transformers | |
parent | 2f22afd80ef6bc3e0ac7f6d55e4a4dc4dd480190 (diff) | |
download | candle-1a6043af5123bf9e189063d3baf110b39cf47617.tar.gz candle-1a6043af5123bf9e189063d3baf110b39cf47617.tar.bz2 candle-1a6043af5123bf9e189063d3baf110b39cf47617.zip |
Tweak the VarMap set type. (#1758)
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/mamba.rs | 6 | ||||
-rw-r--r-- | candle-transformers/src/models/rwkv_v5.rs | 12 |
2 files changed, 9 insertions, 9 deletions
diff --git a/candle-transformers/src/models/mamba.rs b/candle-transformers/src/models/mamba.rs index da254bd1..81828ad5 100644 --- a/candle-transformers/src/models/mamba.rs +++ b/candle-transformers/src/models/mamba.rs @@ -32,9 +32,9 @@ impl Config { } pub struct State { - hs: Vec<Tensor>, - prev_xs: Vec<[Tensor; D_CONV]>, - pos: usize, + pub hs: Vec<Tensor>, + pub prev_xs: Vec<[Tensor; D_CONV]>, + pub pos: usize, } impl State { diff --git a/candle-transformers/src/models/rwkv_v5.rs b/candle-transformers/src/models/rwkv_v5.rs index d11cdedd..38b1e450 100644 --- a/candle-transformers/src/models/rwkv_v5.rs +++ b/candle-transformers/src/models/rwkv_v5.rs @@ -22,15 +22,15 @@ pub struct Config { pub rescale_every: usize, } -struct StatePerLayer { - extract_key_value: Tensor, - linear_attention: Tensor, - feed_forward: Tensor, +pub struct StatePerLayer { + pub extract_key_value: Tensor, + pub linear_attention: Tensor, + pub feed_forward: Tensor, } pub struct State { - per_layer: Vec<StatePerLayer>, - pos: usize, + pub per_layer: Vec<StatePerLayer>, + pub pos: usize, } impl State { |