Diffstat (limited to 'candle-examples/examples/reinforcement-learning/ddpg.rs')
-rw-r--r--  candle-examples/examples/reinforcement-learning/ddpg.rs  105
1 file changed, 105 insertions, 0 deletions
diff --git a/candle-examples/examples/reinforcement-learning/ddpg.rs b/candle-examples/examples/reinforcement-learning/ddpg.rs
index c6d72fed..1ce4889e 100644
--- a/candle-examples/examples/reinforcement-learning/ddpg.rs
+++ b/candle-examples/examples/reinforcement-learning/ddpg.rs
@@ -8,6 +8,8 @@ use candle_nn::{
};
use rand::{distributions::Uniform, thread_rng, Rng};
+use super::gym_env::GymEnv;
+
pub struct OuNoise {
mu: f64,
theta: f64,
@@ -449,3 +451,106 @@ impl DDPG<'_> {
Ok(())
}
}
+
+// The discount factor: how much the Q-value of the next state contributes to the current state's Q-value target.
+const GAMMA: f64 = 0.99;
+// The soft-update (Polyak averaging) weight used when blending the online networks into the target networks.
+const TAU: f64 = 0.005;
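+// As a sketch of where GAMMA and TAU enter the update step (standard DDPG notation,
+// not the exact variable names used further down in this file):
+//   q_target = reward + GAMMA * target_critic(next_state, target_actor(next_state))
+//     (the bootstrap term is dropped for terminal transitions)
+//   target_weights <- TAU * online_weights + (1.0 - TAU) * target_weights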
+// The capacity of the replay buffer used for sampling training data.
+const REPLAY_BUFFER_CAPACITY: usize = 100_000;
+// The training batch size for each training iteration.
+const TRAINING_BATCH_SIZE: usize = 100;
+// The total number of episodes.
+const MAX_EPISODES: usize = 100;
+// The maximum length of an episode.
+const EPISODE_LENGTH: usize = 200;
+// The number of training iterations after one episode finishes.
+const TRAINING_ITERATIONS: usize = 200;
+
+// Ornstein-Uhlenbeck process parameters.
+const MU: f64 = 0.0;
+const THETA: f64 = 0.15;
+const SIGMA: f64 = 0.1;
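+// Each noise sample follows the usual discretised Ornstein-Uhlenbeck update (a
+// sketch; the actual implementation lives in OuNoise above):
+//   state = state + THETA * (MU - state) + SIGMA * standard_normal_sample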
+
+const ACTOR_LEARNING_RATE: f64 = 1e-4;
+const CRITIC_LEARNING_RATE: f64 = 1e-3;
+
+pub fn run() -> Result<()> {
+    let env = GymEnv::new("Pendulum-v1")?;
+    println!("action space: {}", env.action_space());
+    println!("observation space: {:?}", env.observation_space());
+
+    let size_state = env.observation_space().iter().product::<usize>();
+    let size_action = env.action_space();
+
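+    // The `true` flag puts the agent in training mode, so exploration noise is added
+    // to the actor's actions; it is switched back off before the test episodes below.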
+    let mut agent = DDPG::new(
+        &Device::Cpu,
+        size_state,
+        size_action,
+        true,
+        ACTOR_LEARNING_RATE,
+        CRITIC_LEARNING_RATE,
+        GAMMA,
+        TAU,
+        REPLAY_BUFFER_CAPACITY,
+        OuNoise::new(MU, THETA, SIGMA, size_action)?,
+    )?;
+
+    let mut rng = rand::thread_rng();
+
+    for episode in 0..MAX_EPISODES {
+        // let mut state = env.reset(episode as u64)?;
+        let mut state = env.reset(rng.gen::<u64>())?;
+
+        let mut total_reward = 0.0;
+        for _ in 0..EPISODE_LENGTH {
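+            // The actor's output is squashed to [-1, 1], while Pendulum-v1 expects
+            // torques in [-2, 2]; scaling by 2 maps one range onto the other and the
+            // clamp is a safety net.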
+            let mut action = 2.0 * agent.actions(&state)?;
+            action = action.clamp(-2.0, 2.0);
+
+            let step = env.step(vec![action])?;
+            total_reward += step.reward;
+
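+            // Store the transition in the replay buffer so it can be sampled later
+            // for off-policy training.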
+            agent.remember(
+                &state,
+                &Tensor::new(vec![action], &Device::Cpu)?,
+                &Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
+                &step.state,
+                step.terminated,
+                step.truncated,
+            );
+
+            if step.terminated || step.truncated {
+                break;
+            }
+            state = step.state;
+        }
+
+        println!("episode {episode} with total reward of {total_reward}");
+
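+        // After each episode, run a fixed number of gradient steps on minibatches
+        // sampled from the replay buffer.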
+        for _ in 0..TRAINING_ITERATIONS {
+            agent.train(TRAINING_BATCH_SIZE)?;
+        }
+    }
+
+    println!("Testing...");
+    agent.train = false;
+    for episode in 0..10 {
+        // let mut state = env.reset(episode as u64)?;
+        let mut state = env.reset(rng.gen::<u64>())?;
+        let mut total_reward = 0.0;
+        for _ in 0..EPISODE_LENGTH {
+            let mut action = 2.0 * agent.actions(&state)?;
+            action = action.clamp(-2.0, 2.0);
+
+            let step = env.step(vec![action])?;
+            total_reward += step.reward;
+
+            if step.terminated || step.truncated {
+                break;
+            }
+            state = step.state;
+        }
+        println!("episode {episode} with total reward of {total_reward}");
+    }
+    Ok(())
+}