| Field | Value | Date |
|---|---|---|
| author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-31 09:43:10 +0200 |
| committer | GitHub <noreply@github.com> | 2023-08-31 08:43:10 +0100 |
| commit | db598160876a841fcf372ae923e4c5328a47f653 (patch) | |
| tree | 8d927dad45d515859d8e5b8a5f4c5822fafcc636 /candle-nn/tests | |
| parent | d210c71d77a6044c2a42c2e75487b6180e957158 (diff) | |
Add a GRU layer. (#688)
* Add a GRU layer.
* Fix the n gate computation.
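For context, the recurrence that the new test is checked against is PyTorch's `nn.GRU` convention. This is inferred from the PyTorch verification snippet embedded in the test below, not quoted from the candle-nn source:

$$
\begin{aligned}
r_t &= \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}) \\
z_t &= \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}) \\
n_t &= \tanh\!\left(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})\right) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
$$

The "Fix the n gate computation" bullet most likely refers to the placement of the reset gate $r_t$: it scales the whole hidden-side term $W_{hn} h_{t-1} + b_{hn}$, bias included, inside the $\tanh$.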
Diffstat (limited to 'candle-nn/tests')
| Mode | File | Lines |
|---|---|---|
| -rw-r--r-- | candle-nn/tests/rnn.rs | 44 |

1 file changed, 44 insertions, 0 deletions
diff --git a/candle-nn/tests/rnn.rs b/candle-nn/tests/rnn.rs
index eda1a381..498c9188 100644
--- a/candle-nn/tests/rnn.rs
+++ b/candle-nn/tests/rnn.rs
@@ -55,3 +55,47 @@ fn lstm() -> Result<()> {
     assert_eq!(to_vec2_round(c, 4)?, &[[5.725, 0.4458, -0.2908]]);
     Ok(())
 }
+
+/* The following test can be verified against PyTorch using the following snippet.
+import torch
+from torch import nn
+gru = nn.GRU(2, 3, 1)
+gru.weight_ih_l0 = torch.nn.Parameter(torch.arange(0., 18.).reshape(9, 2).cos())
+gru.weight_hh_l0 = torch.nn.Parameter(torch.arange(0., 27.).reshape(9, 3).sin())
+gru.bias_ih_l0 = torch.nn.Parameter(torch.tensor([-1., 1., -0.5, 2, -1, 1, -0.5, 2, -1]))
+gru.bias_hh_l0 = torch.nn.Parameter(torch.tensor([-1., 1., -0.5, 2, -1, 1, -0.5, 2, -1]).cos())
+state = torch.zeros((1, 3))
+for inp in [3., 1., 4., 1., 5., 9., 2.]:
+    inp = torch.tensor([[inp, inp * 0.5]])
+    _out, state = gru(inp, state)
+print(state)
+# tensor([[ 0.0579, 0.8836, -0.9991]], grad_fn=<SqueezeBackward1>)
+*/
+#[test]
+fn gru() -> Result<()> {
+    let cpu = &Device::Cpu;
+    let w_ih = Tensor::arange(0f32, 18f32, cpu)?.reshape((9, 2))?;
+    let w_ih = w_ih.cos()?;
+    let w_hh = Tensor::arange(0f32, 27f32, cpu)?.reshape((9, 3))?;
+    let w_hh = w_hh.sin()?;
+    let b_ih = Tensor::new(&[-1f32, 1., -0.5, 2., -1., 1., -0.5, 2., -1.], cpu)?;
+    let b_hh = b_ih.cos()?;
+    let tensors: std::collections::HashMap<_, _> = [
+        ("weight_ih_l0".to_string(), w_ih),
+        ("weight_hh_l0".to_string(), w_hh),
+        ("bias_ih_l0".to_string(), b_ih),
+        ("bias_hh_l0".to_string(), b_hh),
+    ]
+    .into_iter()
+    .collect();
+    let vb = candle_nn::VarBuilder::from_tensors(tensors, DType::F32, cpu);
+    let gru = candle_nn::gru(2, 3, Default::default(), vb)?;
+    let mut state = gru.zero_state(1)?;
+    for inp in [3f32, 1., 4., 1., 5., 9., 2.] {
+        let inp = Tensor::new(&[[inp, inp * 0.5]], cpu)?;
+        state = gru.step(&inp, &state)?
+    }
+    let h = state.h();
+    assert_eq!(to_vec2_round(h, 4)?, &[[0.0579, 0.8836, -0.9991]]);
+    Ok(())
+}
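A note on the weight layout used in the test: PyTorch packs the three gates (reset r, update z, candidate n, in that order) along the first dimension of `weight_ih_l0` and `weight_hh_l0`, which is why the test builds `(9, 2)` and `(9, 3)` tensors and 9-element biases for `input_size = 2`, `hidden_size = 3`. The sketch below is not the candle-nn implementation; it is a minimal, hypothetical illustration of how the candidate (n) gate from the commit message would be computed with candle tensor ops, assuming the per-gate slices `w_in`, `w_hn`, `b_in`, `b_hn` have already been split out of the packed parameters, a reset gate `r` has already been computed, and candle-core is imported under the `candle` alias as elsewhere in candle-nn.

```rust
use candle::{Result, Tensor};

/// Hypothetical helper (not part of candle-nn): computes the GRU candidate gate
/// following the PyTorch convention the test above is verified against:
///   n = tanh(x W_in^T + b_in + r * (h W_hn^T + b_hn))
/// The reset gate scales the whole hidden-side term, bias included, before tanh.
fn n_gate(
    x: &Tensor,    // (batch, input_size)
    h: &Tensor,    // (batch, hidden_size)
    r: &Tensor,    // (batch, hidden_size), already passed through sigmoid
    w_in: &Tensor, // (hidden_size, input_size) slice of weight_ih_l0
    w_hn: &Tensor, // (hidden_size, hidden_size) slice of weight_hh_l0
    b_in: &Tensor, // (hidden_size,) slice of bias_ih_l0
    b_hn: &Tensor, // (hidden_size,) slice of bias_hh_l0
) -> Result<Tensor> {
    let input_side = x.matmul(&w_in.t()?)?.broadcast_add(b_in)?;
    let hidden_side = h.matmul(&w_hn.t()?)?.broadcast_add(b_hn)?;
    input_side.add(&r.mul(&hidden_side)?)?.tanh()
}
```

Getting the reset gate inside the `tanh`, applied to the hidden-side term together with `b_hn`, is exactly what the PyTorch snippet in the test's comment pins down, since the expected hidden state `[[0.0579, 0.8836, -0.9991]]` would differ under any other placement.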