summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models')
-rw-r--r--candle-transformers/src/models/mamba.rs6
-rw-r--r--candle-transformers/src/models/rwkv_v5.rs12
2 files changed, 9 insertions, 9 deletions
diff --git a/candle-transformers/src/models/mamba.rs b/candle-transformers/src/models/mamba.rs
index da254bd1..81828ad5 100644
--- a/candle-transformers/src/models/mamba.rs
+++ b/candle-transformers/src/models/mamba.rs
@@ -32,9 +32,9 @@ impl Config {
}
pub struct State {
- hs: Vec<Tensor>,
- prev_xs: Vec<[Tensor; D_CONV]>,
- pos: usize,
+ pub hs: Vec<Tensor>,
+ pub prev_xs: Vec<[Tensor; D_CONV]>,
+ pub pos: usize,
}
impl State {
diff --git a/candle-transformers/src/models/rwkv_v5.rs b/candle-transformers/src/models/rwkv_v5.rs
index d11cdedd..38b1e450 100644
--- a/candle-transformers/src/models/rwkv_v5.rs
+++ b/candle-transformers/src/models/rwkv_v5.rs
@@ -22,15 +22,15 @@ pub struct Config {
pub rescale_every: usize,
}
-struct StatePerLayer {
- extract_key_value: Tensor,
- linear_attention: Tensor,
- feed_forward: Tensor,
+pub struct StatePerLayer {
+ pub extract_key_value: Tensor,
+ pub linear_attention: Tensor,
+ pub feed_forward: Tensor,
}
pub struct State {
- per_layer: Vec<StatePerLayer>,
- pos: usize,
+ pub per_layer: Vec<StatePerLayer>,
+ pub pos: usize,
}
impl State {