author    s-casci <simone.cascino546@gmail.com>  2023-12-30 09:01:29 +0100
committer GitHub <noreply@github.com>            2023-12-30 09:01:29 +0100
commit    51e577a682ab9497d6022b4080f3b54bbbd75f1b (patch)
tree      cbf8f7f7a4f79665d5ff5ff372f4bd4fe27cbe5d /candle-examples/examples/reinforcement-learning
parent    0a245e6fa46c16f332555a58271dbd49a8058a9c (diff)
Add Policy Gradient to Reinforcement Learning examples (#1500)
* added policy_gradient, modified main, ddpg and README
* fixed typo in README
* removed unnecessary imports
* small refactor
* Use clap for picking up the subcommand to run.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples/examples/reinforcement-learning')
-rw-r--r--  candle-examples/examples/reinforcement-learning/README.md           |  11
-rw-r--r--  candle-examples/examples/reinforcement-learning/ddpg.rs             | 105
-rw-r--r--  candle-examples/examples/reinforcement-learning/main.rs             | 137
-rw-r--r--  candle-examples/examples/reinforcement-learning/policy_gradient.rs  | 146

4 files changed, 275 insertions(+), 124 deletions(-)
diff --git a/candle-examples/examples/reinforcement-learning/README.md b/candle-examples/examples/reinforcement-learning/README.md
index 2d3d14b0..28819067 100644
--- a/candle-examples/examples/reinforcement-learning/README.md
+++ b/candle-examples/examples/reinforcement-learning/README.md
@@ -8,9 +8,16 @@ Python package with:
pip install "gymnasium[accept-rom-license]"
```
-In order to run the example, use the following command. Note the additional
+In order to run the examples, use the following commands. Note the additional
`--package` flag to ensure that there is no conflict with the `candle-pyo3`
crate.
+
+For the Policy Gradient example:
+```bash
+cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- pg
+```
+
+For the Deep Deterministic Policy Gradient example:
```bash
-cargo run --example reinforcement-learning --features=pyo3 --package candle-examples
+cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- ddpg
```
diff --git a/candle-examples/examples/reinforcement-learning/ddpg.rs b/candle-examples/examples/reinforcement-learning/ddpg.rs
index c6d72fed..1ce4889e 100644
--- a/candle-examples/examples/reinforcement-learning/ddpg.rs
+++ b/candle-examples/examples/reinforcement-learning/ddpg.rs
@@ -8,6 +8,8 @@ use candle_nn::{
};
use rand::{distributions::Uniform, thread_rng, Rng};
+use super::gym_env::GymEnv;
+
pub struct OuNoise {
mu: f64,
theta: f64,
@@ -449,3 +451,106 @@ impl DDPG<'_> {
Ok(())
}
}
+
+// The impact of the q value of the next state on the current state's q value.
+const GAMMA: f64 = 0.99;
+// The weight for updating the target networks.
+const TAU: f64 = 0.005;
+// The capacity of the replay buffer used for sampling training data.
+const REPLAY_BUFFER_CAPACITY: usize = 100_000;
+// The training batch size for each training iteration.
+const TRAINING_BATCH_SIZE: usize = 100;
+// The total number of episodes.
+const MAX_EPISODES: usize = 100;
+// The maximum length of an episode.
+const EPISODE_LENGTH: usize = 200;
+// The number of training iterations after one episode finishes.
+const TRAINING_ITERATIONS: usize = 200;
+
+// Ornstein-Uhlenbeck process parameters.
+const MU: f64 = 0.0;
+const THETA: f64 = 0.15;
+const SIGMA: f64 = 0.1;
+
+const ACTOR_LEARNING_RATE: f64 = 1e-4;
+const CRITIC_LEARNING_RATE: f64 = 1e-3;
+
+pub fn run() -> Result<()> {
+ let env = GymEnv::new("Pendulum-v1")?;
+ println!("action space: {}", env.action_space());
+ println!("observation space: {:?}", env.observation_space());
+
+ let size_state = env.observation_space().iter().product::<usize>();
+ let size_action = env.action_space();
+
+ let mut agent = DDPG::new(
+ &Device::Cpu,
+ size_state,
+ size_action,
+ true,
+ ACTOR_LEARNING_RATE,
+ CRITIC_LEARNING_RATE,
+ GAMMA,
+ TAU,
+ REPLAY_BUFFER_CAPACITY,
+ OuNoise::new(MU, THETA, SIGMA, size_action)?,
+ )?;
+
+ let mut rng = rand::thread_rng();
+
+ for episode in 0..MAX_EPISODES {
+ // let mut state = env.reset(episode as u64)?;
+ let mut state = env.reset(rng.gen::<u64>())?;
+
+ let mut total_reward = 0.0;
+ for _ in 0..EPISODE_LENGTH {
+ let mut action = 2.0 * agent.actions(&state)?;
+ action = action.clamp(-2.0, 2.0);
+
+ let step = env.step(vec![action])?;
+ total_reward += step.reward;
+
+ agent.remember(
+ &state,
+ &Tensor::new(vec![action], &Device::Cpu)?,
+ &Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
+ &step.state,
+ step.terminated,
+ step.truncated,
+ );
+
+ if step.terminated || step.truncated {
+ break;
+ }
+ state = step.state;
+ }
+
+ println!("episode {episode} with total reward of {total_reward}");
+
+ for _ in 0..TRAINING_ITERATIONS {
+ agent.train(TRAINING_BATCH_SIZE)?;
+ }
+ }
+
+ println!("Testing...");
+ agent.train = false;
+ for episode in 0..10 {
+ // let mut state = env.reset(episode as u64)?;
+ let mut state = env.reset(rng.gen::<u64>())?;
+ let mut total_reward = 0.0;
+ for _ in 0..EPISODE_LENGTH {
+ let mut action = 2.0 * agent.actions(&state)?;
+ action = action.clamp(-2.0, 2.0);
+
+ let step = env.step(vec![action])?;
+ total_reward += step.reward;
+
+ if step.terminated || step.truncated {
+ break;
+ }
+ state = step.state;
+ }
+ println!("episode {episode} with total reward of {total_reward}");
+ }
+ Ok(())
+}
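
For reference, the `MU`/`THETA`/`SIGMA` constants above parameterize the Ornstein-Uhlenbeck exploration noise that DDPG adds to the actor's actions. Below is a minimal standalone sketch of the discrete OU update, assuming a unit timestep; the `ou_step` helper and the plain `Vec<f64>` state are illustrative only (they are not the `OuNoise` type from ddpg.rs), and `rand_distr` is an extra dependency not used by the example itself:

```rust
use rand::Rng;
use rand_distr::{Distribution, StandardNormal};

// Discrete Ornstein-Uhlenbeck update with unit timestep:
//   x <- x + theta * (mu - x) + sigma * eps,  eps ~ N(0, 1),
// applied independently to each action dimension.
fn ou_step(state: &mut [f64], mu: f64, theta: f64, sigma: f64, rng: &mut impl Rng) {
    for x in state.iter_mut() {
        let eps: f64 = StandardNormal.sample(rng);
        *x += theta * (mu - *x) + sigma * eps;
    }
}

fn main() {
    let mut rng = rand::thread_rng();
    // MU = 0.0, THETA = 0.15, SIGMA = 0.1, matching the constants above;
    // Pendulum-v1 has a one-dimensional action space.
    let mut noise = vec![0.0];
    for _ in 0..5 {
        ou_step(&mut noise, 0.0, 0.15, 0.1, &mut rng);
        println!("{noise:?}");
    }
}
```

The `theta` term pulls the noise back toward `mu` while `sigma` injects fresh randomness, yielding temporally correlated exploration that suits a continuous control task like Pendulum-v1 better than i.i.d. Gaussian noise.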
diff --git a/candle-examples/examples/reinforcement-learning/main.rs b/candle-examples/examples/reinforcement-learning/main.rs
index 96d7102d..e87afae2 100644
--- a/candle-examples/examples/reinforcement-learning/main.rs
+++ b/candle-examples/examples/reinforcement-learning/main.rs
@@ -6,139 +6,32 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
+use candle::Result;
+use clap::{Parser, Subcommand};
+
mod gym_env;
mod vec_gym_env;
mod ddpg;
+mod policy_gradient;
-use candle::{Device, Result, Tensor};
-use clap::Parser;
-use rand::Rng;
-
-// The impact of the q value of the next state on the current state's q value.
-const GAMMA: f64 = 0.99;
-// The weight for updating the target networks.
-const TAU: f64 = 0.005;
-// The capacity of the replay buffer used for sampling training data.
-const REPLAY_BUFFER_CAPACITY: usize = 100_000;
-// The training batch size for each training iteration.
-const TRAINING_BATCH_SIZE: usize = 100;
-// The total number of episodes.
-const MAX_EPISODES: usize = 100;
-// The maximum length of an episode.
-const EPISODE_LENGTH: usize = 200;
-// The number of training iterations after one episode finishes.
-const TRAINING_ITERATIONS: usize = 200;
-
-// Ornstein-Uhlenbeck process parameters.
-const MU: f64 = 0.0;
-const THETA: f64 = 0.15;
-const SIGMA: f64 = 0.1;
-
-const ACTOR_LEARNING_RATE: f64 = 1e-4;
-const CRITIC_LEARNING_RATE: f64 = 1e-3;
-
-#[derive(Parser, Debug, Clone)]
-#[command(author, version, about, long_about = None)]
+#[derive(Parser)]
struct Args {
- /// Run on CPU rather than on GPU.
- #[arg(long)]
- cpu: bool,
+ #[command(subcommand)]
+ command: Command,
+}
- /// Enable tracing (generates a trace-timestamp.json file).
- #[arg(long)]
- tracing: bool,
+#[derive(Subcommand)]
+enum Command {
+ Pg,
+ Ddpg,
}
fn main() -> Result<()> {
- use tracing_chrome::ChromeLayerBuilder;
- use tracing_subscriber::prelude::*;
-
let args = Args::parse();
-
- let _guard = if args.tracing {
- let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
- tracing_subscriber::registry().with(chrome_layer).init();
- Some(guard)
- } else {
- None
- };
-
- let env = gym_env::GymEnv::new("Pendulum-v1")?;
- println!("action space: {}", env.action_space());
- println!("observation space: {:?}", env.observation_space());
-
- let size_state = env.observation_space().iter().product::<usize>();
- let size_action = env.action_space();
-
- let mut agent = ddpg::DDPG::new(
- &Device::Cpu,
- size_state,
- size_action,
- true,
- ACTOR_LEARNING_RATE,
- CRITIC_LEARNING_RATE,
- GAMMA,
- TAU,
- REPLAY_BUFFER_CAPACITY,
- ddpg::OuNoise::new(MU, THETA, SIGMA, size_action)?,
- )?;
-
- let mut rng = rand::thread_rng();
-
- for episode in 0..MAX_EPISODES {
- // let mut state = env.reset(episode as u64)?;
- let mut state = env.reset(rng.gen::<u64>())?;
-
- let mut total_reward = 0.0;
- for _ in 0..EPISODE_LENGTH {
- let mut action = 2.0 * agent.actions(&state)?;
- action = action.clamp(-2.0, 2.0);
-
- let step = env.step(vec![action])?;
- total_reward += step.reward;
-
- agent.remember(
- &state,
- &Tensor::new(vec![action], &Device::Cpu)?,
- &Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
- &step.state,
- step.terminated,
- step.truncated,
- );
-
- if step.terminated || step.truncated {
- break;
- }
- state = step.state;
- }
-
- println!("episode {episode} with total reward of {total_reward}");
-
- for _ in 0..TRAINING_ITERATIONS {
- agent.train(TRAINING_BATCH_SIZE)?;
- }
- }
-
- println!("Testing...");
- agent.train = false;
- for episode in 0..10 {
- // let mut state = env.reset(episode as u64)?;
- let mut state = env.reset(rng.gen::<u64>())?;
- let mut total_reward = 0.0;
- for _ in 0..EPISODE_LENGTH {
- let mut action = 2.0 * agent.actions(&state)?;
- action = action.clamp(-2.0, 2.0);
-
- let step = env.step(vec![action])?;
- total_reward += step.reward;
-
- if step.terminated || step.truncated {
- break;
- }
- state = step.state;
- }
- println!("episode {episode} with total reward of {total_reward}");
+ match args.command {
+ Command::Pg => policy_gradient::run()?,
+ Command::Ddpg => ddpg::run()?,
}
Ok(())
}
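
One detail worth noting about the clap refactor: clap's derive macro renames subcommand variants to kebab-case by default (nothing is set explicitly here), so `Command::Pg` and `Command::Ddpg` are matched as the `pg` and `ddpg` subcommands used in the README, e.g.:

```bash
# Dispatches to policy_gradient::run() via Command::Pg:
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- pg
```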
diff --git a/candle-examples/examples/reinforcement-learning/policy_gradient.rs b/candle-examples/examples/reinforcement-learning/policy_gradient.rs
new file mode 100644
index 00000000..044cbfcd
--- /dev/null
+++ b/candle-examples/examples/reinforcement-learning/policy_gradient.rs
@@ -0,0 +1,146 @@
+use super::gym_env::{GymEnv, Step};
+use candle::{DType, Device, Error, Module, Result, Tensor};
+use candle_nn::{
+ linear, ops::log_softmax, ops::softmax, sequential::seq, Activation, AdamW, Optimizer,
+ ParamsAdamW, VarBuilder, VarMap,
+};
+use rand::{distributions::Distribution, rngs::ThreadRng, Rng};
+
+fn new_model(
+ input_shape: &[usize],
+ num_actions: usize,
+ dtype: DType,
+ device: &Device,
+) -> Result<(impl Module, VarMap)> {
+ let input_size = input_shape.iter().product();
+
+ let mut varmap = VarMap::new();
+ let var_builder = VarBuilder::from_varmap(&varmap, dtype, device);
+
+ let model = seq()
+ .add(linear(input_size, 32, var_builder.pp("lin1"))?)
+ .add(Activation::Relu)
+ .add(linear(32, num_actions, var_builder.pp("lin2"))?);
+
+ Ok((model, varmap))
+}
+
+fn accumulate_rewards(steps: &[Step<i64>]) -> Vec<f64> {
+ let mut rewards: Vec<f64> = steps.iter().map(|s| s.reward).collect();
+ let mut acc_reward = 0f64;
+ for (i, reward) in rewards.iter_mut().enumerate().rev() {
+ if steps[i].terminated {
+ acc_reward = 0.0;
+ }
+ acc_reward += *reward;
+ *reward = acc_reward;
+ }
+ rewards
+}
+
+fn weighted_sample(probs: Vec<f32>, rng: &mut ThreadRng) -> Result<usize> {
+ let distribution = rand::distributions::WeightedIndex::new(probs).map_err(Error::wrap)?;
+ let mut rng = rng;
+ Ok(distribution.sample(&mut rng))
+}
+
+pub fn run() -> Result<()> {
+ let env = GymEnv::new("CartPole-v1")?;
+
+ println!("action space: {:?}", env.action_space());
+ println!("observation space: {:?}", env.observation_space());
+
+ let (model, varmap) = new_model(
+ env.observation_space(),
+ env.action_space(),
+ DType::F32,
+ &Device::Cpu,
+ )?;
+
+ let optimizer_params = ParamsAdamW {
+ lr: 0.01,
+ weight_decay: 0.01,
+ ..Default::default()
+ };
+
+ let mut optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?;
+
+ let mut rng = rand::thread_rng();
+
+ for epoch_idx in 0..100 {
+ let mut state = env.reset(rng.gen::<u64>())?;
+ let mut steps: Vec<Step<i64>> = vec![];
+
+ loop {
+ let action = {
+ let action_probs: Vec<f32> =
+ softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
+ .squeeze(0)?
+ .to_vec1()?;
+ weighted_sample(action_probs, &mut rng)? as i64
+ };
+
+ let step = env.step(action)?;
+ steps.push(step.copy_with_obs(&state));
+
+ if step.terminated || step.truncated {
+ state = env.reset(rng.gen::<u64>())?;
+ if steps.len() > 5000 {
+ break;
+ }
+ } else {
+ state = step.state;
+ }
+ }
+
+ let total_reward: f64 = steps.iter().map(|s| s.reward).sum();
+ let episodes: i64 = steps
+ .iter()
+ .map(|s| (s.terminated || s.truncated) as i64)
+ .sum();
+ println!(
+ "epoch: {:<3} episodes: {:<5} avg reward per episode: {:.2}",
+ epoch_idx,
+ episodes,
+ total_reward / episodes as f64
+ );
+
+ let batch_size = steps.len();
+
+ let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
+ .to_dtype(DType::F32)?
+ .detach()?;
+
+ let actions_mask = {
+ let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
+ let actions_mask: Vec<Tensor> = actions
+ .iter()
+ .map(|&action| {
+ // One-hot encoding
+ let mut action_mask = vec![0.0; env.action_space()];
+ action_mask[action as usize] = 1.0;
+
+ Tensor::from_vec(action_mask, env.action_space(), &Device::Cpu)
+ .unwrap()
+ .to_dtype(DType::F32)
+ .unwrap()
+ })
+ .collect();
+ Tensor::stack(&actions_mask, 0)?.detach()?
+ };
+
+ let states = {
+ let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
+ Tensor::stack(&states, 0)?.detach()?
+ };
+
+ let log_probs = actions_mask
+ .mul(&log_softmax(&model.forward(&states)?, 1)?)?
+ .sum(1)?;
+
+ let loss = rewards.mul(&log_probs)?.neg()?.mean_all()?;
+ optimizer.backward_step(&loss)?;
+ }
+
+ Ok(())
+}
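
For clarity, the loss computed at the end of each epoch above is the standard Monte Carlo policy-gradient (REINFORCE) objective. A sketch of what the tensor ops implement, assuming N collected steps and the undiscounted reward-to-go G_t produced by `accumulate_rewards` (which resets its accumulator at episode termination):

```latex
% actions_mask .* log_softmax(model(states)), summed over the action axis,
% picks out log pi_theta(a_t | s_t) for the action actually taken, so:
\mathcal{L}(\theta) = -\frac{1}{N}\sum_{t=1}^{N} G_t \,\log \pi_\theta(a_t \mid s_t),
\qquad
G_t = \sum_{k=t}^{T} r_k
```

Minimizing this loss with AdamW ascends the expected return: the gradient scales each step's log-probability by its return-to-go, reinforcing actions that preceded high cumulative reward. Note that, unlike the ddpg example's `GAMMA`, `accumulate_rewards` applies no discount factor, so the returns here are effectively computed with gamma = 1.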