summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-02-25 20:50:08 +0100
committerGitHub <noreply@github.com>2024-02-25 20:50:08 +0100
commit1a6043af5123bf9e189063d3baf110b39cf47617 (patch)
tree3400ac112e92d7d83a0b98a1c66ae046fbbf82df /candle-transformers
parent2f22afd80ef6bc3e0ac7f6d55e4a4dc4dd480190 (diff)
downloadcandle-1a6043af5123bf9e189063d3baf110b39cf47617.tar.gz
candle-1a6043af5123bf9e189063d3baf110b39cf47617.tar.bz2
candle-1a6043af5123bf9e189063d3baf110b39cf47617.zip
Tweak the VarMap set type. (#1758)
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/mamba.rs6
-rw-r--r--candle-transformers/src/models/rwkv_v5.rs12
2 files changed, 9 insertions, 9 deletions
diff --git a/candle-transformers/src/models/mamba.rs b/candle-transformers/src/models/mamba.rs
index da254bd1..81828ad5 100644
--- a/candle-transformers/src/models/mamba.rs
+++ b/candle-transformers/src/models/mamba.rs
@@ -32,9 +32,9 @@ impl Config {
}
pub struct State {
- hs: Vec<Tensor>,
- prev_xs: Vec<[Tensor; D_CONV]>,
- pos: usize,
+ pub hs: Vec<Tensor>,
+ pub prev_xs: Vec<[Tensor; D_CONV]>,
+ pub pos: usize,
}
impl State {
diff --git a/candle-transformers/src/models/rwkv_v5.rs b/candle-transformers/src/models/rwkv_v5.rs
index d11cdedd..38b1e450 100644
--- a/candle-transformers/src/models/rwkv_v5.rs
+++ b/candle-transformers/src/models/rwkv_v5.rs
@@ -22,15 +22,15 @@ pub struct Config {
pub rescale_every: usize,
}
-struct StatePerLayer {
- extract_key_value: Tensor,
- linear_attention: Tensor,
- feed_forward: Tensor,
+pub struct StatePerLayer {
+ pub extract_key_value: Tensor,
+ pub linear_attention: Tensor,
+ pub feed_forward: Tensor,
}
pub struct State {
- per_layer: Vec<StatePerLayer>,
- pos: usize,
+ pub per_layer: Vec<StatePerLayer>,
+ pub pos: usize,
}
impl State {