summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/reinforcement-learning
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2024-02-13 14:26:32 +0100
committer: GitHub <noreply@github.com> 2024-02-13 14:26:32 +0100
commit: ad73e93da2cf7311cb5c5bc39250aa335c5f9b76 (patch)
tree: 5b5ea591d0fda870f4499869e3a8feb9718cfebf /candle-examples/examples/reinforcement-learning
parent: 13c67226e68de216b731707067f7e68af0438821 (diff)
download: candle-ad73e93da2cf7311cb5c5bc39250aa335c5f9b76.tar.gz
download: candle-ad73e93da2cf7311cb5c5bc39250aa335c5f9b76.tar.bz2
download: candle-ad73e93da2cf7311cb5c5bc39250aa335c5f9b76.zip
Detach the tensors on batch-norm eval. (#1702)
* Detach the tensors on batch-norm eval. * Fix pyo3 bindings. * Black tweak. * Formatting. * Also update the pyo3-onnx formatting. * Apply black.
Diffstat (limited to 'candle-examples/examples/reinforcement-learning')
-rw-r--r--candle-examples/examples/reinforcement-learning/ddpg.rs2
-rw-r--r--candle-examples/examples/reinforcement-learning/policy_gradient.rs8
2 files changed, 5 insertions, 5 deletions
diff --git a/candle-examples/examples/reinforcement-learning/ddpg.rs b/candle-examples/examples/reinforcement-learning/ddpg.rs
index 1ce4889e..5309eaf6 100644
--- a/candle-examples/examples/reinforcement-learning/ddpg.rs
+++ b/candle-examples/examples/reinforcement-learning/ddpg.rs
@@ -411,7 +411,7 @@ impl DDPG<'_> {
pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
let actions = self
.actor
- .forward(&state.detach()?.unsqueeze(0)?)?
+ .forward(&state.detach().unsqueeze(0)?)?
.squeeze(0)?;
let actions = if self.train {
(actions + self.ou_noise.sample()?)?
diff --git a/candle-examples/examples/reinforcement-learning/policy_gradient.rs b/candle-examples/examples/reinforcement-learning/policy_gradient.rs
index 044cbfcd..6c355fe6 100644
--- a/candle-examples/examples/reinforcement-learning/policy_gradient.rs
+++ b/candle-examples/examples/reinforcement-learning/policy_gradient.rs
@@ -74,7 +74,7 @@ pub fn run() -> Result<()> {
loop {
let action = {
let action_probs: Vec<f32> =
- softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
+ softmax(&model.forward(&state.detach().unsqueeze(0)?)?, 1)?
.squeeze(0)?
.to_vec1()?;
weighted_sample(action_probs, &mut rng)? as i64
@@ -109,7 +109,7 @@ pub fn run() -> Result<()> {
let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
.to_dtype(DType::F32)?
- .detach()?;
+ .detach();
let actions_mask = {
let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
@@ -126,12 +126,12 @@ pub fn run() -> Result<()> {
.unwrap()
})
.collect();
- Tensor::stack(&actions_mask, 0)?.detach()?
+ Tensor::stack(&actions_mask, 0)?.detach()
};
let states = {
let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
- Tensor::stack(&states, 0)?.detach()?
+ Tensor::stack(&states, 0)?.detach()
};
let log_probs = actions_mask