summaryrefslogtreecommitdiff
path: root/candle-examples/examples/reinforcement-learning
diff options
context:
space:
mode:
authorTravis Hammond <dashdeckers@gmail.com>2023-10-28 20:53:34 +0200
committerGitHub <noreply@github.com>2023-10-28 19:53:34 +0100
commit498c50348ce13456d683c987ad9aef319a45eb4a (patch)
treefbe33f6770e554767f334f38c948e583e26c9b71 /candle-examples/examples/reinforcement-learning
parent012ae0090e70da67987a0308ef18587e9e8a8e44 (diff)
downloadcandle-498c50348ce13456d683c987ad9aef319a45eb4a.tar.gz
candle-498c50348ce13456d683c987ad9aef319a45eb4a.tar.bz2
candle-498c50348ce13456d683c987ad9aef319a45eb4a.zip
Add DDPG and fix Gym wrapper (#1207)
* Fix Gym wrapper - It was returning things in the wrong order - Gym now differentiates between terminated and truncated * Add DDPG * Apply fixes * Remove Result annotations * Also remove Vec annotation * rustfmt * Various small improvements (avoid cloning, mutability, get clippy to pass, ...) --------- Co-authored-by: Travis Hammond <travis.hammond@alexanderthamm.com> Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples/examples/reinforcement-learning')
-rw-r--r--candle-examples/examples/reinforcement-learning/ddpg.rs451
-rw-r--r--candle-examples/examples/reinforcement-learning/gym_env.rs38
-rw-r--r--candle-examples/examples/reinforcement-learning/main.rs85
3 files changed, 549 insertions, 25 deletions
diff --git a/candle-examples/examples/reinforcement-learning/ddpg.rs b/candle-examples/examples/reinforcement-learning/ddpg.rs
new file mode 100644
index 00000000..c6d72fed
--- /dev/null
+++ b/candle-examples/examples/reinforcement-learning/ddpg.rs
@@ -0,0 +1,451 @@
+use std::collections::VecDeque;
+use std::fmt::Display;
+
+use candle::{DType, Device, Error, Module, Result, Tensor, Var};
+use candle_nn::{
+ func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
+ VarBuilder, VarMap,
+};
+use rand::{distributions::Uniform, thread_rng, Rng};
+
+pub struct OuNoise { // Ornstein-Uhlenbeck style noise source; `sample` advances and returns `state`
+ mu: f64, // value the state is pulled towards in `sample`
+ theta: f64, // strength of the pull towards `mu`
+ sigma: f64, // scale applied to the standard-normal draw
+ state: Tensor, // current noise value, one entry per action dimension
+}
+impl OuNoise {
+ pub fn new(mu: f64, theta: f64, sigma: f64, size_action: usize) -> Result<Self> { // builds a CPU F32 state of length `size_action`; errors if tensor creation fails
+ Ok(Self {
+ mu,
+ theta,
+ sigma,
+ state: (mu * Tensor::ones(size_action, DType::F32, &Device::Cpu)?)?, // start at the mean `mu` (ones * mu) rather than at a fixed 1.0
+ })
+ }
+
+ pub fn sample(&mut self) -> Result<Tensor> { // advances the process one step and returns a clone of the new state
+ let rand = Tensor::randn_like(&self.state, 0.0, 1.0)?; // normal(0, 1) draw shaped like `state`
+ let dx = ((self.theta * (self.mu - &self.state)?)? + (self.sigma * rand)?)?; // dx = theta * (mu - state) + sigma * rand
+ self.state = (&self.state + dx)?;
+ Ok(self.state.clone())
+ }
+}
+
+#[derive(Clone)]
+struct Transition { // one stored environment step, as recorded by `ReplayBuffer::push`
+ state: Tensor,
+ action: Tensor,
+ reward: Tensor,
+ next_state: Tensor,
+ terminated: bool,
+ truncated: bool,
+}
+impl Transition {
+ fn new( // builds a transition that owns clones of the four borrowed tensors
+ state: &Tensor,
+ action: &Tensor,
+ reward: &Tensor,
+ next_state: &Tensor,
+ terminated: bool,
+ truncated: bool,
+ ) -> Self {
+ Self {
+ state: state.clone(),
+ action: action.clone(),
+ reward: reward.clone(),
+ next_state: next_state.clone(),
+ terminated,
+ truncated,
+ }
+ }
+}
+
+pub struct ReplayBuffer { // bounded FIFO of transitions; the oldest entry is dropped once `capacity` is reached
+ buffer: VecDeque<Transition>,
+ capacity: usize, // maximum number of transitions kept
+ size: usize, // number of transitions counted so far, capped at `capacity`
+}
+impl ReplayBuffer {
+ pub fn new(capacity: usize) -> Self { // creates an empty buffer with storage reserved for `capacity` transitions
+ Self {
+ buffer: VecDeque::with_capacity(capacity),
+ capacity,
+ size: 0,
+ }
+ }
+
+ pub fn push( // appends one transition (tensors are cloned by `Transition::new`), evicting the oldest when full
+ &mut self,
+ state: &Tensor,
+ action: &Tensor,
+ reward: &Tensor,
+ next_state: &Tensor,
+ terminated: bool,
+ truncated: bool,
+ ) {
+ if self.size == self.capacity { // full: drop the oldest transition before appending
+ self.buffer.pop_front();
+ } else {
+ self.size += 1;
+ }
+ self.buffer.push_back(Transition::new( // NOTE(review): with capacity == 0 `size` stays 0 while `buffer` keeps growing — confirm callers never pass 0
+ state, action, reward, next_state, terminated, truncated,
+ ));
+ }
+
+ #[allow(clippy::type_complexity)]
+ pub fn random_batch( // returns Ok(None) while fewer than `batch_size` transitions are stored, else a random batch
+ &self,
+ batch_size: usize,
+ ) -> Result<Option<(Tensor, Tensor, Tensor, Tensor, Vec<bool>, Vec<bool>)>> { // (states, actions, rewards, next_states, terminateds, truncateds)
+ if self.size < batch_size {
+ Ok(None)
+ } else {
+ let transitions: Vec<&Transition> = thread_rng()
+ .sample_iter(Uniform::from(0..self.size)) // each index is drawn independently, so a transition may appear more than once
+ .take(batch_size)
+ .map(|i| self.buffer.get(i).unwrap()) // relies on `size` never exceeding `buffer.len()`
+ .collect();
+
+ let states: Vec<Tensor> = transitions
+ .iter()
+ .map(|t| t.state.unsqueeze(0)) // add a leading dimension so samples can be concatenated along it
+ .collect::<Result<_>>()?;
+ let actions: Vec<Tensor> = transitions
+ .iter()
+ .map(|t| t.action.unsqueeze(0))
+ .collect::<Result<_>>()?;
+ let rewards: Vec<Tensor> = transitions
+ .iter()
+ .map(|t| t.reward.unsqueeze(0))
+ .collect::<Result<_>>()?;
+ let next_states: Vec<Tensor> = transitions
+ .iter()
+ .map(|t| t.next_state.unsqueeze(0))
+ .collect::<Result<_>>()?;
+ let terminateds: Vec<bool> = transitions.iter().map(|t| t.terminated).collect();
+ let truncateds: Vec<bool> = transitions.iter().map(|t| t.truncated).collect();
+
+ Ok(Some((
+ Tensor::cat(&states, 0)?, // concatenate along the new leading (batch) dimension
+ Tensor::cat(&actions, 0)?,
+ Tensor::cat(&rewards, 0)?,
+ Tensor::cat(&next_states, 0)?,
+ terminateds,
+ truncateds,
+ )))
+ }
+ }
+}
+
+fn track( // blends network parameters into the target parameters: target = tau * network + (1 - tau) * target
+ varmap: &mut VarMap, // receives the blended values under the target names
+ vb: &VarBuilder, // used to look up the current target and network tensors by name
+ target_prefix: &str,
+ network_prefix: &str,
+ dims: &[(usize, usize)], // (in_dim, out_dim) per layer; layer i is named "{prefix}-fc{i}"
+ tau: f64, // blend weight; 1.0 copies the network values into the target
+) -> Result<()> {
+ for (i, &(in_dim, out_dim)) in dims.iter().enumerate() {
+ let target_w = vb.get((out_dim, in_dim), &format!("{target_prefix}-fc{i}.weight"))?; // weights are requested with shape (out_dim, in_dim)
+ let network_w = vb.get((out_dim, in_dim), &format!("{network_prefix}-fc{i}.weight"))?;
+ varmap.set_one(
+ format!("{target_prefix}-fc{i}.weight"),
+ ((tau * network_w)? + ((1.0 - tau) * target_w)?)?,
+ )?;
+
+ let target_b = vb.get(out_dim, &format!("{target_prefix}-fc{i}.bias"))?; // same blend for the bias vector
+ let network_b = vb.get(out_dim, &format!("{network_prefix}-fc{i}.bias"))?;
+ varmap.set_one(
+ format!("{target_prefix}-fc{i}.bias"),
+ ((tau * network_b)? + ((1.0 - tau) * target_b)?)?,
+ )?;
+ }
+ Ok(())
+}
+
+struct Actor<'a> { // policy network plus a target copy, both stored in one VarMap under "actor" / "target-actor" names
+ varmap: VarMap,
+ vb: VarBuilder<'a>,
+ network: Sequential,
+ target_network: Sequential,
+ size_state: usize,
+ size_action: usize,
+ dims: Vec<(usize, usize)>, // (in_dim, out_dim) of each linear layer, in order
+}
+
+impl Actor<'_> {
+ fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> { // builds both networks and makes the target equal to the network; errors if any layer or the sync fails
+ let mut varmap = VarMap::new();
+ let vb = VarBuilder::from_varmap(&varmap, dtype, device);
+
+ let dims = vec![(size_state, 400), (400, 300), (300, size_action)];
+
+ let make_network = |prefix: &str| { // three linear layers with ReLU between them and tanh on the output
+ let seq = seq()
+ .add(linear(
+ dims[0].0,
+ dims[0].1,
+ vb.pp(format!("{prefix}-fc0")),
+ )?)
+ .add(Activation::Relu)
+ .add(linear(
+ dims[1].0,
+ dims[1].1,
+ vb.pp(format!("{prefix}-fc1")),
+ )?)
+ .add(Activation::Relu)
+ .add(linear(
+ dims[2].0,
+ dims[2].1,
+ vb.pp(format!("{prefix}-fc2")),
+ )?)
+ .add(func(|xs| xs.tanh()));
+ Ok::<Sequential, Error>(seq)
+ };
+
+ let network = make_network("actor")?;
+ let target_network = make_network("target-actor")?;
+
+ // this sets the two networks to be equal to each other using tau = 1.0
+ track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0)?; // propagate a failed sync instead of silently dropping the Result
+
+ Ok(Self {
+ varmap,
+ vb,
+ network,
+ target_network,
+ size_state,
+ size_action,
+ dims,
+ })
+ }
+
+ fn forward(&self, state: &Tensor) -> Result<Tensor> { // runs the trainable network on `state`
+ self.network.forward(state)
+ }
+
+ fn target_forward(&self, state: &Tensor) -> Result<Tensor> { // runs the target network on `state`
+ self.target_network.forward(state)
+ }
+
+ fn track(&mut self, tau: f64) -> Result<()> { // moves the target parameters towards the network parameters by weight `tau`
+ track(
+ &mut self.varmap,
+ &self.vb,
+ "target-actor",
+ "actor",
+ &self.dims,
+ tau,
+ )
+ }
+}
+
+struct Critic<'a> { // value network plus a target copy, both stored in one VarMap under "critic" / "target-critic" names
+ varmap: VarMap,
+ vb: VarBuilder<'a>,
+ network: Sequential,
+ target_network: Sequential,
+ size_state: usize,
+ size_action: usize,
+ dims: Vec<(usize, usize)>, // (in_dim, out_dim) of each linear layer, in order
+}
+
+impl Critic<'_> {
+ fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> { // builds both networks and makes the target equal to the network; errors if any layer or the sync fails
+ let mut varmap = VarMap::new();
+ let vb = VarBuilder::from_varmap(&varmap, dtype, device);
+
+ let dims: Vec<(usize, usize)> = vec![(size_state + size_action, 400), (400, 300), (300, 1)]; // input is action and state concatenated; output is a single value
+
+ let make_network = |prefix: &str| { // three linear layers with ReLU between them and no output activation
+ let seq = seq()
+ .add(linear(
+ dims[0].0,
+ dims[0].1,
+ vb.pp(format!("{prefix}-fc0")),
+ )?)
+ .add(Activation::Relu)
+ .add(linear(
+ dims[1].0,
+ dims[1].1,
+ vb.pp(format!("{prefix}-fc1")),
+ )?)
+ .add(Activation::Relu)
+ .add(linear(
+ dims[2].0,
+ dims[2].1,
+ vb.pp(format!("{prefix}-fc2")),
+ )?);
+ Ok::<Sequential, Error>(seq)
+ };
+
+ let network = make_network("critic")?;
+ let target_network = make_network("target-critic")?;
+
+ // this sets the two networks to be equal to each other using tau = 1.0
+ track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0)?; // propagate a failed sync instead of silently dropping the Result
+
+ Ok(Self {
+ varmap,
+ vb,
+ network,
+ target_network,
+ size_state,
+ size_action,
+ dims,
+ })
+ }
+
+ fn forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> { // evaluates the trainable network on a batch of (state, action) pairs
+ let xs = Tensor::cat(&[action, state], 1)?; // action columns first, then state columns, joined along dim 1
+ self.network.forward(&xs)
+ }
+
+ fn target_forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> { // same as `forward`, but through the target network
+ let xs = Tensor::cat(&[action, state], 1)?;
+ self.target_network.forward(&xs)
+ }
+
+ fn track(&mut self, tau: f64) -> Result<()> { // moves the target parameters towards the network parameters by weight `tau`
+ track(
+ &mut self.varmap,
+ &self.vb,
+ "target-critic",
+ "critic",
+ &self.dims,
+ tau,
+ )
+ }
+}
+
+#[allow(clippy::upper_case_acronyms)]
+pub struct DDPG<'a> { // agent state: actor and critic with their optimizers, replay buffer and exploration noise
+ actor: Actor<'a>,
+ actor_optim: AdamW,
+ critic: Critic<'a>,
+ critic_optim: AdamW,
+ gamma: f64, // factor applied to the target critic value in `train`
+ tau: f64, // blend weight passed to the target-network `track` calls
+ replay_buffer: ReplayBuffer,
+ ou_noise: OuNoise,
+
+ size_state: usize,
+ size_action: usize,
+ pub train: bool, // when true, `actions` adds a noise sample to the actor output
+}
+
+impl DDPG<'_> {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new( // builds F32 actor and critic, one AdamW optimizer each, and an empty replay buffer
+ device: &Device,
+ size_state: usize,
+ size_action: usize,
+ train: bool,
+ actor_lr: f64,
+ critic_lr: f64,
+ gamma: f64,
+ tau: f64,
+ buffer_capacity: usize,
+ ou_noise: OuNoise,
+ ) -> Result<Self> {
+ let filter_by_prefix = |varmap: &VarMap, prefix: &str| { // collects the variables whose name starts with `prefix`
+ varmap
+ .data()
+ .lock()
+ .unwrap()
+ .iter()
+ .filter_map(|(name, var)| name.starts_with(prefix).then_some(var.clone())) // "target-..." names do not start with "actor"/"critic", so target variables are left out
+ .collect::<Vec<Var>>()
+ };
+
+ let actor = Actor::new(device, DType::F32, size_state, size_action)?;
+ let actor_optim = AdamW::new(
+ filter_by_prefix(&actor.varmap, "actor"),
+ ParamsAdamW {
+ lr: actor_lr,
+ ..Default::default()
+ },
+ )?;
+
+ let critic = Critic::new(device, DType::F32, size_state, size_action)?;
+ let critic_optim = AdamW::new(
+ filter_by_prefix(&critic.varmap, "critic"),
+ ParamsAdamW {
+ lr: critic_lr,
+ ..Default::default()
+ },
+ )?;
+
+ Ok(Self {
+ actor,
+ actor_optim,
+ critic,
+ critic_optim,
+ gamma,
+ tau,
+ replay_buffer: ReplayBuffer::new(buffer_capacity),
+ ou_noise,
+ size_state,
+ size_action,
+ train,
+ })
+ }
+
+ pub fn remember( // stores one transition in the replay buffer
+ &mut self,
+ state: &Tensor,
+ action: &Tensor,
+ reward: &Tensor,
+ next_state: &Tensor,
+ terminated: bool,
+ truncated: bool,
+ ) {
+ self.replay_buffer
+ .push(state, action, reward, next_state, terminated, truncated)
+ }
+
+ pub fn actions(&mut self, state: &Tensor) -> Result<f32> { // returns the actor output for one state as a scalar, with a noise sample added when `self.train` is set
+ let actions = self
+ .actor
+ .forward(&state.detach()?.unsqueeze(0)?)? // add a leading dimension of size 1 before the forward pass
+ .squeeze(0)?; // and remove it again afterwards
+ let actions = if self.train {
+ (actions + self.ou_noise.sample()?)?
+ } else {
+ actions
+ };
+ actions.squeeze(0)?.to_scalar::<f32>() // NOTE(review): scalar extraction presumably requires size_action == 1 — confirm
+ }
+
+ pub fn train(&mut self, batch_size: usize) -> Result<()> { // one critic step, one actor step and a target update; does nothing when no batch is available
+ let (states, actions, rewards, next_states, _, _) = // NOTE(review): the terminated/truncated flags are discarded, so the target always bootstraps from next_states — confirm intended
+ match self.replay_buffer.random_batch(batch_size)? {
+ Some(v) => v,
+ _ => return Ok(()), // no batch returned by the buffer: skip this iteration
+ };
+
+ let q_target = self
+ .critic
+ .target_forward(&next_states, &self.actor.target_forward(&next_states)?)?; // target critic evaluated on the target actor's action
+ let q_target = (rewards + (self.gamma * q_target)?.detach())?; // rewards + gamma * target value
+ let q = self.critic.forward(&states, &actions)?;
+ let diff = (q_target - q)?;
+
+ let critic_loss = diff.sqr()?.mean_all()?; // mean squared difference between target and current estimate
+ self.critic_optim.backward_step(&critic_loss)?;
+
+ let actor_loss = self
+ .critic
+ .forward(&states, &self.actor.forward(&states)?)?
+ .mean_all()?
+ .neg()?; // negated mean critic value, so minimizing it raises the critic's score of the actor's actions
+ self.actor_optim.backward_step(&actor_loss)?;
+
+ self.critic.track(self.tau)?; // move both target networks towards the trained networks
+ self.actor.track(self.tau)?;
+
+ Ok(())
+ }
+}
diff --git a/candle-examples/examples/reinforcement-learning/gym_env.rs b/candle-examples/examples/reinforcement-learning/gym_env.rs
index b98be6bc..8868c188 100644
--- a/candle-examples/examples/reinforcement-learning/gym_env.rs
+++ b/candle-examples/examples/reinforcement-learning/gym_env.rs
@@ -7,20 +7,22 @@ use pyo3::types::PyDict;
/// The return value for a step.
#[derive(Debug)]
pub struct Step<A> {
- pub obs: Tensor,
+ pub state: Tensor,
pub action: A,
pub reward: f64,
- pub is_done: bool,
+ pub terminated: bool,
+ pub truncated: bool,
}
impl<A: Copy> Step<A> {
/// Returns a copy of this step changing the observation tensor.
- pub fn copy_with_obs(&self, obs: &Tensor) -> Step<A> {
+ pub fn copy_with_obs(&self, state: &Tensor) -> Step<A> {
Step {
- obs: obs.clone(),
+ state: state.clone(),
action: self.action,
reward: self.reward,
- is_done: self.is_done,
+ terminated: self.terminated,
+ truncated: self.truncated,
}
}
}
@@ -63,14 +65,14 @@ impl GymEnv {
/// Resets the environment, returning the observation tensor.
pub fn reset(&self, seed: u64) -> Result<Tensor> {
- let obs: Vec<f32> = Python::with_gil(|py| {
+ let state: Vec<f32> = Python::with_gil(|py| {
let kwargs = PyDict::new(py);
kwargs.set_item("seed", seed)?;
- let obs = self.env.call_method(py, "reset", (), Some(kwargs))?;
- obs.as_ref(py).get_item(0)?.extract()
+ let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
+ state.as_ref(py).get_item(0)?.extract()
})
.map_err(w)?;
- Tensor::new(obs, &Device::Cpu)
+ Tensor::new(state, &Device::Cpu)
}
/// Applies an environment step using the specified action.
@@ -78,21 +80,23 @@ impl GymEnv {
&self,
action: A,
) -> Result<Step<A>> {
- let (obs, reward, is_done) = Python::with_gil(|py| {
+ let (state, reward, terminated, truncated) = Python::with_gil(|py| {
let step = self.env.call_method(py, "step", (action.clone(),), None)?;
let step = step.as_ref(py);
- let obs: Vec<f32> = step.get_item(0)?.extract()?;
+ let state: Vec<f32> = step.get_item(0)?.extract()?;
let reward: f64 = step.get_item(1)?.extract()?;
- let is_done: bool = step.get_item(2)?.extract()?;
- Ok((obs, reward, is_done))
+ let terminated: bool = step.get_item(2)?.extract()?;
+ let truncated: bool = step.get_item(3)?.extract()?;
+ Ok((state, reward, terminated, truncated))
})
.map_err(w)?;
- let obs = Tensor::new(obs, &Device::Cpu)?;
+ let state = Tensor::new(state, &Device::Cpu)?;
Ok(Step {
- obs,
- reward,
- is_done,
+ state,
action,
+ reward,
+ terminated,
+ truncated,
})
}
diff --git a/candle-examples/examples/reinforcement-learning/main.rs b/candle-examples/examples/reinforcement-learning/main.rs
index f16f042e..96d7102d 100644
--- a/candle-examples/examples/reinforcement-learning/main.rs
+++ b/candle-examples/examples/reinforcement-learning/main.rs
@@ -9,14 +9,34 @@ extern crate accelerate_src;
mod gym_env;
mod vec_gym_env;
-use candle::Result;
+mod ddpg;
+
+use candle::{Device, Result, Tensor};
use clap::Parser;
use rand::Rng;
+// The discount factor: the weight of the next state's q value in the current state's q value.
+const GAMMA: f64 = 0.99;
+// The weight for updating the target networks.
+const TAU: f64 = 0.005;
+// The capacity of the replay buffer used for sampling training data.
+const REPLAY_BUFFER_CAPACITY: usize = 100_000;
+// The training batch size for each training iteration.
+const TRAINING_BATCH_SIZE: usize = 100;
// The total number of episodes.
const MAX_EPISODES: usize = 100;
// The maximum length of an episode.
const EPISODE_LENGTH: usize = 200;
+// The number of training iterations after one episode finishes.
+const TRAINING_ITERATIONS: usize = 200;
+
+// Ornstein-Uhlenbeck process parameters.
+const MU: f64 = 0.0;
+const THETA: f64 = 0.15;
+const SIGMA: f64 = 0.1;
+
+const ACTOR_LEARNING_RATE: f64 = 1e-4;
+const CRITIC_LEARNING_RATE: f64 = 1e-3;
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
@@ -48,28 +68,77 @@ fn main() -> Result<()> {
println!("action space: {}", env.action_space());
println!("observation space: {:?}", env.observation_space());
- let _num_obs = env.observation_space().iter().product::<usize>();
- let _num_actions = env.action_space();
+ let size_state = env.observation_space().iter().product::<usize>();
+ let size_action = env.action_space();
+
+ let mut agent = ddpg::DDPG::new(
+ &Device::Cpu,
+ size_state,
+ size_action,
+ true,
+ ACTOR_LEARNING_RATE,
+ CRITIC_LEARNING_RATE,
+ GAMMA,
+ TAU,
+ REPLAY_BUFFER_CAPACITY,
+ ddpg::OuNoise::new(MU, THETA, SIGMA, size_action)?,
+ )?;
let mut rng = rand::thread_rng();
for episode in 0..MAX_EPISODES {
- let mut obs = env.reset(episode as u64)?;
+ // let mut state = env.reset(episode as u64)?;
+ let mut state = env.reset(rng.gen::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
- let actions = rng.gen_range(-2.0..2.0);
+ let mut action = 2.0 * agent.actions(&state)?;
+ action = action.clamp(-2.0, 2.0);
- let step = env.step(vec![actions])?;
+ let step = env.step(vec![action])?;
total_reward += step.reward;
- if step.is_done {
+ agent.remember(
+ &state,
+ &Tensor::new(vec![action], &Device::Cpu)?,
+ &Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
+ &step.state,
+ step.terminated,
+ step.truncated,
+ );
+
+ if step.terminated || step.truncated {
break;
}
- obs = step.obs;
+ state = step.state;
}
println!("episode {episode} with total reward of {total_reward}");
+
+ for _ in 0..TRAINING_ITERATIONS {
+ agent.train(TRAINING_BATCH_SIZE)?;
+ }
+ }
+
+ println!("Testing...");
+ agent.train = false;
+ for episode in 0..10 {
+ // let mut state = env.reset(episode as u64)?;
+ let mut state = env.reset(rng.gen::<u64>())?;
+ let mut total_reward = 0.0;
+ for _ in 0..EPISODE_LENGTH {
+ let mut action = 2.0 * agent.actions(&state)?;
+ action = action.clamp(-2.0, 2.0);
+
+ let step = env.step(vec![action])?;
+ total_reward += step.reward;
+
+ if step.terminated || step.truncated {
+ break;
+ }
+ state = step.state;
+ }
+ println!("episode {episode} with total reward of {total_reward}");
}
Ok(())
}