Diffstat (limited to 'candle-examples/examples/reinforcement-learning/ddpg.rs')
-rw-r--r-- | candle-examples/examples/reinforcement-learning/ddpg.rs | 105
1 file changed, 105 insertions, 0 deletions
diff --git a/candle-examples/examples/reinforcement-learning/ddpg.rs b/candle-examples/examples/reinforcement-learning/ddpg.rs
index c6d72fed..1ce4889e 100644
--- a/candle-examples/examples/reinforcement-learning/ddpg.rs
+++ b/candle-examples/examples/reinforcement-learning/ddpg.rs
@@ -8,6 +8,8 @@ use candle_nn::{
 };
 use rand::{distributions::Uniform, thread_rng, Rng};
 
+use super::gym_env::GymEnv;
+
 pub struct OuNoise {
     mu: f64,
     theta: f64,
@@ -449,3 +451,106 @@ impl DDPG<'_> {
         Ok(())
     }
 }
+
+// The impact of the q value of the next state on the current state's q value.
+const GAMMA: f64 = 0.99;
+// The weight for updating the target networks.
+const TAU: f64 = 0.005;
+// The capacity of the replay buffer used for sampling training data.
+const REPLAY_BUFFER_CAPACITY: usize = 100_000;
+// The training batch size for each training iteration.
+const TRAINING_BATCH_SIZE: usize = 100;
+// The total number of episodes.
+const MAX_EPISODES: usize = 100;
+// The maximum length of an episode.
+const EPISODE_LENGTH: usize = 200;
+// The number of training iterations after one episode finishes.
+const TRAINING_ITERATIONS: usize = 200;
+
+// Ornstein-Uhlenbeck process parameters.
+const MU: f64 = 0.0;
+const THETA: f64 = 0.15;
+const SIGMA: f64 = 0.1;
+
+const ACTOR_LEARNING_RATE: f64 = 1e-4;
+const CRITIC_LEARNING_RATE: f64 = 1e-3;
+
+pub fn run() -> Result<()> {
+    let env = GymEnv::new("Pendulum-v1")?;
+    println!("action space: {}", env.action_space());
+    println!("observation space: {:?}", env.observation_space());
+
+    let size_state = env.observation_space().iter().product::<usize>();
+    let size_action = env.action_space();
+
+    let mut agent = DDPG::new(
+        &Device::Cpu,
+        size_state,
+        size_action,
+        true,
+        ACTOR_LEARNING_RATE,
+        CRITIC_LEARNING_RATE,
+        GAMMA,
+        TAU,
+        REPLAY_BUFFER_CAPACITY,
+        OuNoise::new(MU, THETA, SIGMA, size_action)?,
+    )?;
+
+    let mut rng = rand::thread_rng();
+
+    for episode in 0..MAX_EPISODES {
+        // let mut state = env.reset(episode as u64)?;
+        let mut state = env.reset(rng.gen::<u64>())?;
+
+        let mut total_reward = 0.0;
+        for _ in 0..EPISODE_LENGTH {
+            let mut action = 2.0 * agent.actions(&state)?;
+            action = action.clamp(-2.0, 2.0);
+
+            let step = env.step(vec![action])?;
+            total_reward += step.reward;
+
+            agent.remember(
+                &state,
+                &Tensor::new(vec![action], &Device::Cpu)?,
+                &Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
+                &step.state,
+                step.terminated,
+                step.truncated,
+            );
+
+            if step.terminated || step.truncated {
+                break;
+            }
+            state = step.state;
+        }
+
+        println!("episode {episode} with total reward of {total_reward}");
+
+        for _ in 0..TRAINING_ITERATIONS {
+            agent.train(TRAINING_BATCH_SIZE)?;
+        }
+    }
+
+    println!("Testing...");
+    agent.train = false;
+    for episode in 0..10 {
+        // let mut state = env.reset(episode as u64)?;
+        let mut state = env.reset(rng.gen::<u64>())?;
+        let mut total_reward = 0.0;
+        for _ in 0..EPISODE_LENGTH {
+            let mut action = 2.0 * agent.actions(&state)?;
+            action = action.clamp(-2.0, 2.0);
+
+            let step = env.step(vec![action])?;
+            total_reward += step.reward;
+
+            if step.terminated || step.truncated {
+                break;
+            }
+            state = step.state;
+        }
+        println!("episode {episode} with total reward of {total_reward}");
+    }
+    Ok(())
+}
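
For reference, the MU/THETA/SIGMA constants added above parameterize an Ornstein-Uhlenbeck process, which supplies temporally correlated exploration noise for the actor. The standalone Rust sketch below illustrates the standard scalar OU update, x <- x + theta * (mu - x) + sigma * N(0, 1); it is only an illustration of the formula under that assumption, not the tensor-based OuNoise implementation earlier in this file, and the `ScalarOuNoise` name plus the `rand`/`rand_distr` dependencies are introduced here for the example.

use rand::thread_rng;
use rand_distr::{Distribution, Normal};

// Scalar Ornstein-Uhlenbeck process: the state drifts back towards `mu`
// at rate `theta` while being perturbed by Gaussian noise scaled by `sigma`.
struct ScalarOuNoise {
    mu: f64,
    theta: f64,
    sigma: f64,
    state: f64,
    normal: Normal<f64>,
}

impl ScalarOuNoise {
    fn new(mu: f64, theta: f64, sigma: f64) -> Self {
        Self {
            mu,
            theta,
            sigma,
            state: mu,
            normal: Normal::new(0.0, 1.0).unwrap(),
        }
    }

    // x <- x + theta * (mu - x) + sigma * N(0, 1)
    fn sample(&mut self) -> f64 {
        let dx = self.theta * (self.mu - self.state)
            + self.sigma * self.normal.sample(&mut thread_rng());
        self.state += dx;
        self.state
    }
}

fn main() {
    // Same values as the MU, THETA, SIGMA constants in the diff.
    let mut noise = ScalarOuNoise::new(0.0, 0.15, 0.1);
    for _ in 0..5 {
        println!("{}", noise.sample());
    }
}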
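Similarly, TAU is the coefficient of the soft (Polyak) target-network update used by DDPG, target <- tau * online + (1 - tau) * target, which is why a small value such as 0.005 makes the target networks trail the online networks slowly. The sketch below only demonstrates that formula with plain f64 slices standing in for the candle tensors; the actual update performed by this example lives in the DDPG implementation earlier in the file.

// Soft (Polyak) update: blend each online-network parameter into the
// corresponding target-network parameter with weight `tau`.
fn soft_update(target: &mut [f64], online: &[f64], tau: f64) {
    for (t, o) in target.iter_mut().zip(online.iter()) {
        *t = tau * *o + (1.0 - tau) * *t;
    }
}

fn main() {
    let mut target = vec![0.0, 0.0];
    let online = vec![1.0, -1.0];
    soft_update(&mut target, &online, 0.005); // TAU = 0.005
    println!("{target:?}"); // [0.005, -0.005]
}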