Diffstat (limited to 'candle-examples/examples/reinforcement-learning/policy_gradient.rs')
-rw-r--r-- | candle-examples/examples/reinforcement-learning/policy_gradient.rs | 8 |
1 file changed, 4 insertions, 4 deletions
diff --git a/candle-examples/examples/reinforcement-learning/policy_gradient.rs b/candle-examples/examples/reinforcement-learning/policy_gradient.rs
index 044cbfcd..6c355fe6 100644
--- a/candle-examples/examples/reinforcement-learning/policy_gradient.rs
+++ b/candle-examples/examples/reinforcement-learning/policy_gradient.rs
@@ -74,7 +74,7 @@ pub fn run() -> Result<()> {
         loop {
             let action = {
                 let action_probs: Vec<f32> =
-                    softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
+                    softmax(&model.forward(&state.detach().unsqueeze(0)?)?, 1)?
                         .squeeze(0)?
                         .to_vec1()?;
                 weighted_sample(action_probs, &mut rng)? as i64
@@ -109,7 +109,7 @@ pub fn run() -> Result<()> {
         let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
             .to_dtype(DType::F32)?
-            .detach()?;
+            .detach();
 
         let actions_mask = {
             let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
@@ -126,12 +126,12 @@ pub fn run() -> Result<()> {
                     .unwrap()
                 })
                 .collect();
-            Tensor::stack(&actions_mask, 0)?.detach()?
+            Tensor::stack(&actions_mask, 0)?.detach()
         };
 
         let states = {
             let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
-            Tensor::stack(&states, 0)?.detach()?
+            Tensor::stack(&states, 0)?.detach()
         };
 
         let log_probs = actions_mask
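
For context, every hunk above makes the same change: the `?` after `detach()` is dropped. The sketch below is a minimal, standalone illustration of that call pattern, assuming (as the new hunks indicate) that candle-core's `Tensor::detach` now returns `Tensor` directly instead of `Result<Tensor>`. The toy tensors, shapes, and values are placeholders for illustration only and are not taken from the example.

// Minimal sketch of the updated call pattern. Assumption: Tensor::detach
// returns Tensor (not Result<Tensor>) in this candle version, so no `?`.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // Toy stand-in for the environment state fed to the policy network.
    let state = Tensor::new(&[0.1f32, -0.2, 0.3, 0.0], &device)?;

    // Before this commit: state.detach()?.unsqueeze(0)?
    // After this commit:  state.detach().unsqueeze(0)?
    let batched = state.detach().unsqueeze(0)?;
    println!("batched dims: {:?}", batched.dims());

    // Same pattern for tensors built once and treated as constants:
    let rewards = Tensor::from_vec(vec![1.0f32, 0.5, 0.25], 3, &device)?
        .to_dtype(DType::F32)?
        .detach();
    println!("rewards dims: {:?}", rewards.dims());

    let stacked = Tensor::stack(&[state.clone(), state], 0)?.detach();
    println!("stacked dims: {:?}", stacked.dims());

    Ok(())
}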