diff options
Diffstat (limited to 'candle-transformers/src/models')
-rw-r--r-- | candle-transformers/src/models/mamba.rs | 6 | ||||
-rw-r--r-- | candle-transformers/src/models/rwkv_v5.rs | 12 |
2 files changed, 9 insertions, 9 deletions
diff --git a/candle-transformers/src/models/mamba.rs b/candle-transformers/src/models/mamba.rs index da254bd1..81828ad5 100644 --- a/candle-transformers/src/models/mamba.rs +++ b/candle-transformers/src/models/mamba.rs @@ -32,9 +32,9 @@ impl Config { } pub struct State { - hs: Vec<Tensor>, - prev_xs: Vec<[Tensor; D_CONV]>, - pos: usize, + pub hs: Vec<Tensor>, + pub prev_xs: Vec<[Tensor; D_CONV]>, + pub pos: usize, } impl State { diff --git a/candle-transformers/src/models/rwkv_v5.rs b/candle-transformers/src/models/rwkv_v5.rs index d11cdedd..38b1e450 100644 --- a/candle-transformers/src/models/rwkv_v5.rs +++ b/candle-transformers/src/models/rwkv_v5.rs @@ -22,15 +22,15 @@ pub struct Config { pub rescale_every: usize, } -struct StatePerLayer { - extract_key_value: Tensor, - linear_attention: Tensor, - feed_forward: Tensor, +pub struct StatePerLayer { + pub extract_key_value: Tensor, + pub linear_attention: Tensor, + pub feed_forward: Tensor, } pub struct State { - per_layer: Vec<StatePerLayer>, - pos: usize, + pub per_layer: Vec<StatePerLayer>, + pub pos: usize, } impl State { |