summaryrefslogtreecommitdiff
path: root/candle-nn/tests
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-31 09:43:10 +0200
committerGitHub <noreply@github.com>2023-08-31 08:43:10 +0100
commitdb598160876a841fcf372ae923e4c5328a47f653 (patch)
tree8d927dad45d515859d8e5b8a5f4c5822fafcc636 /candle-nn/tests
parentd210c71d77a6044c2a42c2e75487b6180e957158 (diff)
downloadcandle-db598160876a841fcf372ae923e4c5328a47f653.tar.gz
candle-db598160876a841fcf372ae923e4c5328a47f653.tar.bz2
candle-db598160876a841fcf372ae923e4c5328a47f653.zip
Add a GRU layer. (#688)
* Add a GRU layer. * Fix the n gate computation.
Diffstat (limited to 'candle-nn/tests')
-rw-r--r--candle-nn/tests/rnn.rs44
1 files changed, 44 insertions, 0 deletions
diff --git a/candle-nn/tests/rnn.rs b/candle-nn/tests/rnn.rs
index eda1a381..498c9188 100644
--- a/candle-nn/tests/rnn.rs
+++ b/candle-nn/tests/rnn.rs
@@ -55,3 +55,47 @@ fn lstm() -> Result<()> {
assert_eq!(to_vec2_round(c, 4)?, &[[5.725, 0.4458, -0.2908]]);
Ok(())
}
+
+/* The following test can be verified against PyTorch using the following snippet.
+import torch
+from torch import nn
+gru = nn.GRU(2, 3, 1)
+gru.weight_ih_l0 = torch.nn.Parameter(torch.arange(0., 18.).reshape(9, 2).cos())
+gru.weight_hh_l0 = torch.nn.Parameter(torch.arange(0., 27.).reshape(9, 3).sin())
+gru.bias_ih_l0 = torch.nn.Parameter(torch.tensor([-1., 1., -0.5, 2, -1, 1, -0.5, 2, -1]))
+gru.bias_hh_l0 = torch.nn.Parameter(torch.tensor([-1., 1., -0.5, 2, -1, 1, -0.5, 2, -1]).cos())
+state = torch.zeros((1, 3))
+for inp in [3., 1., 4., 1., 5., 9., 2.]:
+ inp = torch.tensor([[inp, inp * 0.5]])
+ _out, state = gru(inp, state)
+print(state)
+# tensor([[ 0.0579, 0.8836, -0.9991]], grad_fn=<SqueezeBackward1>)
+*/
+#[test]
+fn gru() -> Result<()> {
+ let cpu = &Device::Cpu;
+ let w_ih = Tensor::arange(0f32, 18f32, cpu)?.reshape((9, 2))?;
+ let w_ih = w_ih.cos()?;
+ let w_hh = Tensor::arange(0f32, 27f32, cpu)?.reshape((9, 3))?;
+ let w_hh = w_hh.sin()?;
+ let b_ih = Tensor::new(&[-1f32, 1., -0.5, 2., -1., 1., -0.5, 2., -1.], cpu)?;
+ let b_hh = b_ih.cos()?;
+ let tensors: std::collections::HashMap<_, _> = [
+ ("weight_ih_l0".to_string(), w_ih),
+ ("weight_hh_l0".to_string(), w_hh),
+ ("bias_ih_l0".to_string(), b_ih),
+ ("bias_hh_l0".to_string(), b_hh),
+ ]
+ .into_iter()
+ .collect();
+ let vb = candle_nn::VarBuilder::from_tensors(tensors, DType::F32, cpu);
+ let gru = candle_nn::gru(2, 3, Default::default(), vb)?;
+ let mut state = gru.zero_state(1)?;
+ for inp in [3f32, 1., 4., 1., 5., 9., 2.] {
+ let inp = Tensor::new(&[[inp, inp * 0.5]], cpu)?;
+ state = gru.step(&inp, &state)?
+ }
+ let h = state.h();
+ assert_eq!(to_vec2_round(h, 4)?, &[[0.0579, 0.8836, -0.9991]]);
+ Ok(())
+}