Diffstat (limited to 'candle-examples/examples/reinforcement-learning/main.rs')
-rw-r--r--  candle-examples/examples/reinforcement-learning/main.rs  137
1 file changed, 15 insertions(+), 122 deletions(-)
diff --git a/candle-examples/examples/reinforcement-learning/main.rs b/candle-examples/examples/reinforcement-learning/main.rs
index 96d7102d..e87afae2 100644
--- a/candle-examples/examples/reinforcement-learning/main.rs
+++ b/candle-examples/examples/reinforcement-learning/main.rs
@@ -6,139 +6,32 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
+use candle::Result;
+use clap::{Parser, Subcommand};
+
mod gym_env;
mod vec_gym_env;
mod ddpg;
+mod policy_gradient;
-use candle::{Device, Result, Tensor};
-use clap::Parser;
-use rand::Rng;
-
-// The impact of the q value of the next state on the current state's q value.
-const GAMMA: f64 = 0.99;
-// The weight for updating the target networks.
-const TAU: f64 = 0.005;
-// The capacity of the replay buffer used for sampling training data.
-const REPLAY_BUFFER_CAPACITY: usize = 100_000;
-// The training batch size for each training iteration.
-const TRAINING_BATCH_SIZE: usize = 100;
-// The total number of episodes.
-const MAX_EPISODES: usize = 100;
-// The maximum length of an episode.
-const EPISODE_LENGTH: usize = 200;
-// The number of training iterations after one episode finishes.
-const TRAINING_ITERATIONS: usize = 200;
-
-// Ornstein-Uhlenbeck process parameters.
-const MU: f64 = 0.0;
-const THETA: f64 = 0.15;
-const SIGMA: f64 = 0.1;
-
-const ACTOR_LEARNING_RATE: f64 = 1e-4;
-const CRITIC_LEARNING_RATE: f64 = 1e-3;
-
-#[derive(Parser, Debug, Clone)]
-#[command(author, version, about, long_about = None)]
+#[derive(Parser)]
struct Args {
-    /// Run on CPU rather than on GPU.
-    #[arg(long)]
-    cpu: bool,
+    #[command(subcommand)]
+    command: Command,
+}
-    /// Enable tracing (generates a trace-timestamp.json file).
-    #[arg(long)]
-    tracing: bool,
+#[derive(Subcommand)]
+enum Command {
+    Pg,
+    Ddpg,
}
fn main() -> Result<()> {
-    use tracing_chrome::ChromeLayerBuilder;
-    use tracing_subscriber::prelude::*;
-
    let args = Args::parse();
-
-    let _guard = if args.tracing {
-        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
-        tracing_subscriber::registry().with(chrome_layer).init();
-        Some(guard)
-    } else {
-        None
-    };
-
-    let env = gym_env::GymEnv::new("Pendulum-v1")?;
-    println!("action space: {}", env.action_space());
-    println!("observation space: {:?}", env.observation_space());
-
-    let size_state = env.observation_space().iter().product::<usize>();
-    let size_action = env.action_space();
-
-    let mut agent = ddpg::DDPG::new(
-        &Device::Cpu,
-        size_state,
-        size_action,
-        true,
-        ACTOR_LEARNING_RATE,
-        CRITIC_LEARNING_RATE,
-        GAMMA,
-        TAU,
-        REPLAY_BUFFER_CAPACITY,
-        ddpg::OuNoise::new(MU, THETA, SIGMA, size_action)?,
-    )?;
-
-    let mut rng = rand::thread_rng();
-
-    for episode in 0..MAX_EPISODES {
-        // let mut state = env.reset(episode as u64)?;
-        let mut state = env.reset(rng.gen::<u64>())?;
-
-        let mut total_reward = 0.0;
-        for _ in 0..EPISODE_LENGTH {
-            let mut action = 2.0 * agent.actions(&state)?;
-            action = action.clamp(-2.0, 2.0);
-
-            let step = env.step(vec![action])?;
-            total_reward += step.reward;
-
-            agent.remember(
-                &state,
-                &Tensor::new(vec![action], &Device::Cpu)?,
-                &Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
-                &step.state,
-                step.terminated,
-                step.truncated,
-            );
-
-            if step.terminated || step.truncated {
-                break;
-            }
-            state = step.state;
-        }
-
-        println!("episode {episode} with total reward of {total_reward}");
-
-        for _ in 0..TRAINING_ITERATIONS {
-            agent.train(TRAINING_BATCH_SIZE)?;
-        }
-    }
-
- println!("Testing...");
- agent.train = false;
- for episode in 0..10 {
- // let mut state = env.reset(episode as u64)?;
- let mut state = env.reset(rng.gen::<u64>())?;
- let mut total_reward = 0.0;
- for _ in 0..EPISODE_LENGTH {
- let mut action = 2.0 * agent.actions(&state)?;
- action = action.clamp(-2.0, 2.0);
-
- let step = env.step(vec![action])?;
- total_reward += step.reward;
-
- if step.terminated || step.truncated {
- break;
- }
- state = step.state;
- }
- println!("episode {episode} with total reward of {total_reward}");
+    match args.command {
+        Command::Pg => policy_gradient::run()?,
+        Command::Ddpg => ddpg::run()?,
    }
    Ok(())
}
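
After this change the example exposes two trainers behind clap subcommands: clap derives the lowercased names `pg` and `ddpg` from the `Command` variants, and `main` only forwards to `policy_gradient::run()` or `ddpg::run()`. The standalone sketch below illustrates that dispatch pattern in isolation; the `println!` bodies stand in for the real `run()` calls, and the exact cargo features needed by the candle example are not shown here.

// Standalone sketch of the subcommand dispatch introduced above.
// Requires clap with the "derive" feature enabled.
use clap::{Parser, Subcommand};

#[derive(Parser)]
struct Args {
    #[command(subcommand)]
    command: Command,
}

#[derive(Subcommand)]
enum Command {
    // clap derives the subcommand names "pg" and "ddpg" from these variants.
    Pg,
    Ddpg,
}

fn main() {
    match Args::parse().command {
        Command::Pg => println!("would call policy_gradient::run()"),
        Command::Ddpg => println!("would call ddpg::run()"),
    }
}

With this layout the program is invoked with a `pg` or `ddpg` subcommand (for the candle example, passed after `--` on the `cargo run` command line).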
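
The constants deleted from main.rs (GAMMA, TAU, the replay-buffer capacity, and the Ornstein-Uhlenbeck parameters MU/THETA/SIGMA) presumably move with the rest of the DDPG setup behind ddpg::run(). As a reminder of what TAU and the OU parameters control, here is a small illustrative sketch of the two update rules they feed, written against plain slices with invented function names; it is not the ddpg.rs implementation.

use rand::Rng;

// Soft target-network update: target <- tau * online + (1 - tau) * target.
// With TAU = 0.005 the target network tracks the online network slowly.
fn soft_update(target: &mut [f64], online: &[f64], tau: f64) {
    for (t, &o) in target.iter_mut().zip(online.iter()) {
        *t = tau * o + (1.0 - tau) * *t;
    }
}

// One Ornstein-Uhlenbeck step, used as exploration noise on the continuous
// action: x <- x + theta * (mu - x) + sigma * eps, with eps ~ N(0, 1). The
// normal sample is approximated by a sum of 12 uniforms to keep the sketch
// free of extra dependencies.
fn ou_step(x: f64, mu: f64, theta: f64, sigma: f64, rng: &mut impl Rng) -> f64 {
    let eps: f64 = (0..12).map(|_| rng.gen::<f64>()).sum::<f64>() - 6.0;
    x + theta * (mu - x) + sigma * eps
}

fn main() {
    let mut rng = rand::thread_rng();
    let mut target = vec![0.0f64; 3];
    soft_update(&mut target, &[1.0, 1.0, 1.0], 0.005);
    let noise = ou_step(0.0, 0.0, 0.15, 0.1, &mut rng);
    println!("target after one soft update: {target:?}, OU noise sample: {noise}");
}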