diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-02-13 14:26:32 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-13 14:26:32 +0100 |
commit | ad73e93da2cf7311cb5c5bc39250aa335c5f9b76 (patch) | |
tree | 5b5ea591d0fda870f4499869e3a8feb9718cfebf /candle-examples/examples/reinforcement-learning | |
parent | 13c67226e68de216b731707067f7e68af0438821 (diff) | |
download | candle-ad73e93da2cf7311cb5c5bc39250aa335c5f9b76.tar.gz candle-ad73e93da2cf7311cb5c5bc39250aa335c5f9b76.tar.bz2 candle-ad73e93da2cf7311cb5c5bc39250aa335c5f9b76.zip |
Detach the tensors on batch-norm eval. (#1702)
* Detach the tensors on batch-norm eval.
* Fix pyo3 bindings.
* Black tweak.
* Formatting.
* Also update the pyo3-onnx formatting.
* Apply black.
Diffstat (limited to 'candle-examples/examples/reinforcement-learning')
-rw-r--r-- | candle-examples/examples/reinforcement-learning/ddpg.rs | 2 | ||||
-rw-r--r-- | candle-examples/examples/reinforcement-learning/policy_gradient.rs | 8 |
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/candle-examples/examples/reinforcement-learning/ddpg.rs b/candle-examples/examples/reinforcement-learning/ddpg.rs index 1ce4889e..5309eaf6 100644 --- a/candle-examples/examples/reinforcement-learning/ddpg.rs +++ b/candle-examples/examples/reinforcement-learning/ddpg.rs @@ -411,7 +411,7 @@ impl DDPG<'_> { pub fn actions(&mut self, state: &Tensor) -> Result<f32> { let actions = self .actor - .forward(&state.detach()?.unsqueeze(0)?)? + .forward(&state.detach().unsqueeze(0)?)? .squeeze(0)?; let actions = if self.train { (actions + self.ou_noise.sample()?)? diff --git a/candle-examples/examples/reinforcement-learning/policy_gradient.rs b/candle-examples/examples/reinforcement-learning/policy_gradient.rs index 044cbfcd..6c355fe6 100644 --- a/candle-examples/examples/reinforcement-learning/policy_gradient.rs +++ b/candle-examples/examples/reinforcement-learning/policy_gradient.rs @@ -74,7 +74,7 @@ pub fn run() -> Result<()> { loop { let action = { let action_probs: Vec<f32> = - softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)? + softmax(&model.forward(&state.detach().unsqueeze(0)?)?, 1)? .squeeze(0)? .to_vec1()?; weighted_sample(action_probs, &mut rng)? as i64 @@ -109,7 +109,7 @@ pub fn run() -> Result<()> { let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)? .to_dtype(DType::F32)? - .detach()?; + .detach(); let actions_mask = { let actions: Vec<i64> = steps.iter().map(|s| s.action).collect(); @@ -126,12 +126,12 @@ pub fn run() -> Result<()> { .unwrap() }) .collect(); - Tensor::stack(&actions_mask, 0)?.detach()? + Tensor::stack(&actions_mask, 0)?.detach() }; let states = { let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect(); - Tensor::stack(&states, 0)?.detach()? + Tensor::stack(&states, 0)?.detach() }; let log_probs = actions_mask |