author     Travis Hammond <dashdeckers@gmail.com>  2023-10-28 20:53:34 +0200
committer  GitHub <noreply@github.com>  2023-10-28 19:53:34 +0100
commit     498c50348ce13456d683c987ad9aef319a45eb4a (patch)
tree       fbe33f6770e554767f334f38c948e583e26c9b71 /candle-examples/examples/reinforcement-learning
parent     012ae0090e70da67987a0308ef18587e9e8a8e44 (diff)
Add DDPG and fix Gym wrapper (#1207)
* Fix Gym wrapper
- It was returning things in the wrong order
- Gym now differentiates between terminated and truncated
* Add DDPG
* Apply fixes
* Remove Result annotations
* Also remove Vec annotation
* rustfmt
* Various small improvements (avoid cloning, mutability, get clippy to pass, ...)
---------
Co-authored-by: Travis Hammond <travis.hammond@alexanderthamm.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
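
Background for the wrapper change: Gymnasium's `step` returns a five-tuple `(obs, reward, terminated, truncated, info)`, where `terminated` marks a genuine terminal state and `truncated` marks an external cutoff such as a time limit. The updated `Step` struct in the diff carries both flags instead of a single `is_done`. A minimal sketch of the distinction in plain Rust; the struct here is an illustrative stand-in (the real one stores a candle `Tensor` and is filled from Python via pyo3):

```rust
/// Illustrative stand-in for the wrapper's `Step` struct.
struct Step {
    state: Vec<f32>,
    reward: f64,
    terminated: bool, // the environment reached a terminal state
    truncated: bool,  // the episode was cut off, e.g. by a time limit
}

fn main() {
    let step = Step {
        state: vec![0.0; 3],
        reward: -1.0,
        terminated: false,
        truncated: true,
    };
    // Either flag ends the episode, but only `terminated` marks a true
    // terminal state; `truncated` episodes simply ran out of time.
    if step.terminated || step.truncated {
        println!("episode over, last reward {}", step.reward);
    }
    println!("next state has {} entries", step.state.len());
}
```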
Diffstat (limited to 'candle-examples/examples/reinforcement-learning')
3 files changed, 549 insertions, 25 deletions
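
To make the ddpg.rs diff below easier to follow: the critic is regressed toward the target `reward + GAMMA * Q_target(next_state, actor_target(next_state))`, the actor is updated to maximize `Q(state, actor(state))`, and both target networks then track their online counterparts through the soft ("Polyak") update performed by `track`. A minimal sketch of that soft update on plain slices rather than candle tensors (names here are illustrative):

```rust
// Sketch of the soft target update implemented by `track` in ddpg.rs:
//     target <- tau * network + (1 - tau) * target
// With tau = 1.0 the target becomes an exact copy, which is how the target
// networks are initialised; during training a small tau (0.005 here) makes
// them trail the online networks slowly.
fn soft_update(target: &mut [f64], network: &[f64], tau: f64) {
    for (t, n) in target.iter_mut().zip(network) {
        *t = tau * n + (1.0 - tau) * *t;
    }
}

fn main() {
    let network = [1.0, 2.0, 3.0];
    let mut target = [0.0; 3];
    soft_update(&mut target, &network, 0.005); // slow tracking during training
    println!("{target:?}");
    soft_update(&mut target, &network, 1.0); // hard copy at initialisation
    println!("{target:?}");
}
```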
diff --git a/candle-examples/examples/reinforcement-learning/ddpg.rs b/candle-examples/examples/reinforcement-learning/ddpg.rs
new file mode 100644
index 00000000..c6d72fed
--- /dev/null
+++ b/candle-examples/examples/reinforcement-learning/ddpg.rs
@@ -0,0 +1,451 @@
+use std::collections::VecDeque;
+use std::fmt::Display;
+
+use candle::{DType, Device, Error, Module, Result, Tensor, Var};
+use candle_nn::{
+    func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
+    VarBuilder, VarMap,
+};
+use rand::{distributions::Uniform, thread_rng, Rng};
+
+pub struct OuNoise {
+    mu: f64,
+    theta: f64,
+    sigma: f64,
+    state: Tensor,
+}
+impl OuNoise {
+    pub fn new(mu: f64, theta: f64, sigma: f64, size_action: usize) -> Result<Self> {
+        Ok(Self {
+            mu,
+            theta,
+            sigma,
+            state: Tensor::ones(size_action, DType::F32, &Device::Cpu)?,
+        })
+    }
+
+    pub fn sample(&mut self) -> Result<Tensor> {
+        let rand = Tensor::randn_like(&self.state, 0.0, 1.0)?;
+        let dx = ((self.theta * (self.mu - &self.state)?)? + (self.sigma * rand)?)?;
+        self.state = (&self.state + dx)?;
+        Ok(self.state.clone())
+    }
+}
+
+#[derive(Clone)]
+struct Transition {
+    state: Tensor,
+    action: Tensor,
+    reward: Tensor,
+    next_state: Tensor,
+    terminated: bool,
+    truncated: bool,
+}
+impl Transition {
+    fn new(
+        state: &Tensor,
+        action: &Tensor,
+        reward: &Tensor,
+        next_state: &Tensor,
+        terminated: bool,
+        truncated: bool,
+    ) -> Self {
+        Self {
+            state: state.clone(),
+            action: action.clone(),
+            reward: reward.clone(),
+            next_state: next_state.clone(),
+            terminated,
+            truncated,
+        }
+    }
+}
+
+pub struct ReplayBuffer {
+    buffer: VecDeque<Transition>,
+    capacity: usize,
+    size: usize,
+}
+impl ReplayBuffer {
+    pub fn new(capacity: usize) -> Self {
+        Self {
+            buffer: VecDeque::with_capacity(capacity),
+            capacity,
+            size: 0,
+        }
+    }
+
+    pub fn push(
+        &mut self,
+        state: &Tensor,
+        action: &Tensor,
+        reward: &Tensor,
+        next_state: &Tensor,
+        terminated: bool,
+        truncated: bool,
+    ) {
+        if self.size == self.capacity {
+            self.buffer.pop_front();
+        } else {
+            self.size += 1;
+        }
+        self.buffer.push_back(Transition::new(
+            state, action, reward, next_state, terminated, truncated,
+        ));
+    }
+
+    #[allow(clippy::type_complexity)]
+    pub fn random_batch(
+        &self,
+        batch_size: usize,
+    ) -> Result<Option<(Tensor, Tensor, Tensor, Tensor, Vec<bool>, Vec<bool>)>> {
+        if self.size < batch_size {
+            Ok(None)
+        } else {
+            let transitions: Vec<&Transition> = thread_rng()
+                .sample_iter(Uniform::from(0..self.size))
+                .take(batch_size)
+                .map(|i| self.buffer.get(i).unwrap())
+                .collect();
+
+            let states: Vec<Tensor> = transitions
+                .iter()
+                .map(|t| t.state.unsqueeze(0))
+                .collect::<Result<_>>()?;
+            let actions: Vec<Tensor> = transitions
+                .iter()
+                .map(|t| t.action.unsqueeze(0))
+                .collect::<Result<_>>()?;
+            let rewards: Vec<Tensor> = transitions
+                .iter()
+                .map(|t| t.reward.unsqueeze(0))
+                .collect::<Result<_>>()?;
+            let next_states: Vec<Tensor> = transitions
+                .iter()
+                .map(|t| t.next_state.unsqueeze(0))
+                .collect::<Result<_>>()?;
+            let terminateds: Vec<bool> = transitions.iter().map(|t| t.terminated).collect();
+            let truncateds: Vec<bool> = transitions.iter().map(|t| t.truncated).collect();
+
+            Ok(Some((
+                Tensor::cat(&states, 0)?,
+                Tensor::cat(&actions, 0)?,
+                Tensor::cat(&rewards, 0)?,
+                Tensor::cat(&next_states, 0)?,
+                terminateds,
+                truncateds,
+            )))
+        }
+    }
+}
+
+fn track(
+    varmap: &mut VarMap,
+    vb: &VarBuilder,
+    target_prefix: &str,
+    network_prefix: &str,
+    dims: &[(usize, usize)],
+    tau: f64,
+) -> Result<()> {
+    for (i, &(in_dim, out_dim)) in dims.iter().enumerate() {
+        let target_w = vb.get((out_dim, in_dim), &format!("{target_prefix}-fc{i}.weight"))?;
+        let network_w = vb.get((out_dim, in_dim), &format!("{network_prefix}-fc{i}.weight"))?;
+        varmap.set_one(
+            format!("{target_prefix}-fc{i}.weight"),
+            ((tau * network_w)? + ((1.0 - tau) * target_w)?)?,
+        )?;
+
+        let target_b = vb.get(out_dim, &format!("{target_prefix}-fc{i}.bias"))?;
+        let network_b = vb.get(out_dim, &format!("{network_prefix}-fc{i}.bias"))?;
+        varmap.set_one(
+            format!("{target_prefix}-fc{i}.bias"),
+            ((tau * network_b)? + ((1.0 - tau) * target_b)?)?,
+        )?;
+    }
+    Ok(())
+}
+
+struct Actor<'a> {
+    varmap: VarMap,
+    vb: VarBuilder<'a>,
+    network: Sequential,
+    target_network: Sequential,
+    size_state: usize,
+    size_action: usize,
+    dims: Vec<(usize, usize)>,
+}
+
+impl Actor<'_> {
+    fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
+        let mut varmap = VarMap::new();
+        let vb = VarBuilder::from_varmap(&varmap, dtype, device);
+
+        let dims = vec![(size_state, 400), (400, 300), (300, size_action)];
+
+        let make_network = |prefix: &str| {
+            let seq = seq()
+                .add(linear(
+                    dims[0].0,
+                    dims[0].1,
+                    vb.pp(format!("{prefix}-fc0")),
+                )?)
+                .add(Activation::Relu)
+                .add(linear(
+                    dims[1].0,
+                    dims[1].1,
+                    vb.pp(format!("{prefix}-fc1")),
+                )?)
+                .add(Activation::Relu)
+                .add(linear(
+                    dims[2].0,
+                    dims[2].1,
+                    vb.pp(format!("{prefix}-fc2")),
+                )?)
+                .add(func(|xs| xs.tanh()));
+            Ok::<Sequential, Error>(seq)
+        };
+
+        let network = make_network("actor")?;
+        let target_network = make_network("target-actor")?;
+
+        // this sets the two networks to be equal to each other using tau = 1.0
+        track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0);
+
+        Ok(Self {
+            varmap,
+            vb,
+            network,
+            target_network,
+            size_state,
+            size_action,
+            dims,
+        })
+    }
+
+    fn forward(&self, state: &Tensor) -> Result<Tensor> {
+        self.network.forward(state)
+    }
+
+    fn target_forward(&self, state: &Tensor) -> Result<Tensor> {
+        self.target_network.forward(state)
+    }
+
+    fn track(&mut self, tau: f64) -> Result<()> {
+        track(
+            &mut self.varmap,
+            &self.vb,
+            "target-actor",
+            "actor",
+            &self.dims,
+            tau,
+        )
+    }
+}
+
+struct Critic<'a> {
+    varmap: VarMap,
+    vb: VarBuilder<'a>,
+    network: Sequential,
+    target_network: Sequential,
+    size_state: usize,
+    size_action: usize,
+    dims: Vec<(usize, usize)>,
+}
+
+impl Critic<'_> {
+    fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
+        let mut varmap = VarMap::new();
+        let vb = VarBuilder::from_varmap(&varmap, dtype, device);
+
+        let dims: Vec<(usize, usize)> = vec![(size_state + size_action, 400), (400, 300), (300, 1)];
+
+        let make_network = |prefix: &str| {
+            let seq = seq()
+                .add(linear(
+                    dims[0].0,
+                    dims[0].1,
+                    vb.pp(format!("{prefix}-fc0")),
+                )?)
+                .add(Activation::Relu)
+                .add(linear(
+                    dims[1].0,
+                    dims[1].1,
+                    vb.pp(format!("{prefix}-fc1")),
+                )?)
+                .add(Activation::Relu)
+                .add(linear(
+                    dims[2].0,
+                    dims[2].1,
+                    vb.pp(format!("{prefix}-fc2")),
+                )?);
+            Ok::<Sequential, Error>(seq)
+        };
+
+        let network = make_network("critic")?;
+        let target_network = make_network("target-critic")?;
+
+        // this sets the two networks to be equal to each other using tau = 1.0
+        track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0);
+
+        Ok(Self {
+            varmap,
+            vb,
+            network,
+            target_network,
+            size_state,
+            size_action,
+            dims,
+        })
+    }
+
+    fn forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
+        let xs = Tensor::cat(&[action, state], 1)?;
+        self.network.forward(&xs)
+    }
+
+    fn target_forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
+        let xs = Tensor::cat(&[action, state], 1)?;
+        self.target_network.forward(&xs)
+    }
+
+    fn track(&mut self, tau: f64) -> Result<()> {
+        track(
+            &mut self.varmap,
+            &self.vb,
+            "target-critic",
+            "critic",
+            &self.dims,
+            tau,
+        )
+    }
+}
+
+#[allow(clippy::upper_case_acronyms)]
+pub struct DDPG<'a> {
+    actor: Actor<'a>,
+    actor_optim: AdamW,
+    critic: Critic<'a>,
+    critic_optim: AdamW,
+    gamma: f64,
+    tau: f64,
+    replay_buffer: ReplayBuffer,
+    ou_noise: OuNoise,
+
+    size_state: usize,
+    size_action: usize,
+    pub train: bool,
+}
+
+impl DDPG<'_> {
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        device: &Device,
+        size_state: usize,
+        size_action: usize,
+        train: bool,
+        actor_lr: f64,
+        critic_lr: f64,
+        gamma: f64,
+        tau: f64,
+        buffer_capacity: usize,
+        ou_noise: OuNoise,
+    ) -> Result<Self> {
+        let filter_by_prefix = |varmap: &VarMap, prefix: &str| {
+            varmap
+                .data()
+                .lock()
+                .unwrap()
+                .iter()
+                .filter_map(|(name, var)| name.starts_with(prefix).then_some(var.clone()))
+                .collect::<Vec<Var>>()
+        };
+
+        let actor = Actor::new(device, DType::F32, size_state, size_action)?;
+        let actor_optim = AdamW::new(
+            filter_by_prefix(&actor.varmap, "actor"),
+            ParamsAdamW {
+                lr: actor_lr,
+                ..Default::default()
+            },
+        )?;
+
+        let critic = Critic::new(device, DType::F32, size_state, size_action)?;
+        let critic_optim = AdamW::new(
+            filter_by_prefix(&critic.varmap, "critic"),
+            ParamsAdamW {
+                lr: critic_lr,
+                ..Default::default()
+            },
+        )?;
+
+        Ok(Self {
+            actor,
+            actor_optim,
+            critic,
+            critic_optim,
+            gamma,
+            tau,
+            replay_buffer: ReplayBuffer::new(buffer_capacity),
+            ou_noise,
+            size_state,
+            size_action,
+            train,
+        })
+    }
+
+    pub fn remember(
+        &mut self,
+        state: &Tensor,
+        action: &Tensor,
+        reward: &Tensor,
+        next_state: &Tensor,
+        terminated: bool,
+        truncated: bool,
+    ) {
+        self.replay_buffer
+            .push(state, action, reward, next_state, terminated, truncated)
+    }
+
+    pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
+        let actions = self
+            .actor
+            .forward(&state.detach()?.unsqueeze(0)?)?
+            .squeeze(0)?;
+        let actions = if self.train {
+            (actions + self.ou_noise.sample()?)?
+        } else {
+            actions
+        };
+        actions.squeeze(0)?.to_scalar::<f32>()
+    }
+
+    pub fn train(&mut self, batch_size: usize) -> Result<()> {
+        let (states, actions, rewards, next_states, _, _) =
+            match self.replay_buffer.random_batch(batch_size)? {
+                Some(v) => v,
+                _ => return Ok(()),
+            };
+
+        let q_target = self
+            .critic
+            .target_forward(&next_states, &self.actor.target_forward(&next_states)?)?;
+        let q_target = (rewards + (self.gamma * q_target)?.detach())?;
+        let q = self.critic.forward(&states, &actions)?;
+        let diff = (q_target - q)?;
+
+        let critic_loss = diff.sqr()?.mean_all()?;
+        self.critic_optim.backward_step(&critic_loss)?;
+
+        let actor_loss = self
+            .critic
+            .forward(&states, &self.actor.forward(&states)?)?
+            .mean_all()?
+            .neg()?;
+        self.actor_optim.backward_step(&actor_loss)?;
+
+        self.critic.track(self.tau)?;
+        self.actor.track(self.tau)?;
+
+        Ok(())
+    }
+}
diff --git a/candle-examples/examples/reinforcement-learning/gym_env.rs b/candle-examples/examples/reinforcement-learning/gym_env.rs
index b98be6bc..8868c188 100644
--- a/candle-examples/examples/reinforcement-learning/gym_env.rs
+++ b/candle-examples/examples/reinforcement-learning/gym_env.rs
@@ -7,20 +7,22 @@ use pyo3::types::PyDict;
 /// The return value for a step.
 #[derive(Debug)]
 pub struct Step<A> {
-    pub obs: Tensor,
+    pub state: Tensor,
     pub action: A,
     pub reward: f64,
-    pub is_done: bool,
+    pub terminated: bool,
+    pub truncated: bool,
 }
 
 impl<A: Copy> Step<A> {
     /// Returns a copy of this step changing the observation tensor.
-    pub fn copy_with_obs(&self, obs: &Tensor) -> Step<A> {
+    pub fn copy_with_obs(&self, state: &Tensor) -> Step<A> {
        Step {
-            obs: obs.clone(),
+            state: state.clone(),
            action: self.action,
            reward: self.reward,
-            is_done: self.is_done,
+            terminated: self.terminated,
+            truncated: self.truncated,
        }
    }
 }
@@ -63,14 +65,14 @@ impl GymEnv {
 
     /// Resets the environment, returning the observation tensor.
     pub fn reset(&self, seed: u64) -> Result<Tensor> {
-        let obs: Vec<f32> = Python::with_gil(|py| {
+        let state: Vec<f32> = Python::with_gil(|py| {
             let kwargs = PyDict::new(py);
             kwargs.set_item("seed", seed)?;
-            let obs = self.env.call_method(py, "reset", (), Some(kwargs))?;
-            obs.as_ref(py).get_item(0)?.extract()
+            let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
+            state.as_ref(py).get_item(0)?.extract()
         })
         .map_err(w)?;
-        Tensor::new(obs, &Device::Cpu)
+        Tensor::new(state, &Device::Cpu)
     }
 
     /// Applies an environment step using the specified action.
@@ -78,21 +80,23 @@ impl GymEnv {
         &self,
         action: A,
     ) -> Result<Step<A>> {
-        let (obs, reward, is_done) = Python::with_gil(|py| {
+        let (state, reward, terminated, truncated) = Python::with_gil(|py| {
             let step = self.env.call_method(py, "step", (action.clone(),), None)?;
             let step = step.as_ref(py);
-            let obs: Vec<f32> = step.get_item(0)?.extract()?;
+            let state: Vec<f32> = step.get_item(0)?.extract()?;
             let reward: f64 = step.get_item(1)?.extract()?;
-            let is_done: bool = step.get_item(2)?.extract()?;
-            Ok((obs, reward, is_done))
+            let terminated: bool = step.get_item(2)?.extract()?;
+            let truncated: bool = step.get_item(3)?.extract()?;
+            Ok((state, reward, terminated, truncated))
         })
         .map_err(w)?;
-        let obs = Tensor::new(obs, &Device::Cpu)?;
+        let state = Tensor::new(state, &Device::Cpu)?;
         Ok(Step {
-            obs,
-            reward,
-            is_done,
+            state,
             action,
+            reward,
+            terminated,
+            truncated,
         })
     }
diff --git a/candle-examples/examples/reinforcement-learning/main.rs b/candle-examples/examples/reinforcement-learning/main.rs
index f16f042e..96d7102d 100644
--- a/candle-examples/examples/reinforcement-learning/main.rs
+++ b/candle-examples/examples/reinforcement-learning/main.rs
@@ -9,14 +9,34 @@ extern crate accelerate_src;
 
 mod gym_env;
 mod vec_gym_env;
-use candle::Result;
+mod ddpg;
+
+use candle::{Device, Result, Tensor};
 use clap::Parser;
 use rand::Rng;
 
+// The impact of the q value of the next state on the current state's q value.
+const GAMMA: f64 = 0.99;
+// The weight for updating the target networks.
+const TAU: f64 = 0.005;
+// The capacity of the replay buffer used for sampling training data.
+const REPLAY_BUFFER_CAPACITY: usize = 100_000;
+// The training batch size for each training iteration.
+const TRAINING_BATCH_SIZE: usize = 100;
 // The total number of episodes.
 const MAX_EPISODES: usize = 100;
 // The maximum length of an episode.
 const EPISODE_LENGTH: usize = 200;
+// The number of training iterations after one episode finishes.
+const TRAINING_ITERATIONS: usize = 200;
+
+// Ornstein-Uhlenbeck process parameters.
+const MU: f64 = 0.0;
+const THETA: f64 = 0.15;
+const SIGMA: f64 = 0.1;
+
+const ACTOR_LEARNING_RATE: f64 = 1e-4;
+const CRITIC_LEARNING_RATE: f64 = 1e-3;
 
 #[derive(Parser, Debug, Clone)]
 #[command(author, version, about, long_about = None)]
@@ -48,28 +68,77 @@ fn main() -> Result<()> {
     println!("action space: {}", env.action_space());
     println!("observation space: {:?}", env.observation_space());
 
-    let _num_obs = env.observation_space().iter().product::<usize>();
-    let _num_actions = env.action_space();
+    let size_state = env.observation_space().iter().product::<usize>();
+    let size_action = env.action_space();
+
+    let mut agent = ddpg::DDPG::new(
+        &Device::Cpu,
+        size_state,
+        size_action,
+        true,
+        ACTOR_LEARNING_RATE,
+        CRITIC_LEARNING_RATE,
+        GAMMA,
+        TAU,
+        REPLAY_BUFFER_CAPACITY,
+        ddpg::OuNoise::new(MU, THETA, SIGMA, size_action)?,
+    )?;
 
     let mut rng = rand::thread_rng();
 
     for episode in 0..MAX_EPISODES {
-        let mut obs = env.reset(episode as u64)?;
+        // let mut state = env.reset(episode as u64)?;
+        let mut state = env.reset(rng.gen::<u64>())?;
 
         let mut total_reward = 0.0;
         for _ in 0..EPISODE_LENGTH {
-            let actions = rng.gen_range(-2.0..2.0);
+            let mut action = 2.0 * agent.actions(&state)?;
+            action = action.clamp(-2.0, 2.0);
 
-            let step = env.step(vec![actions])?;
+            let step = env.step(vec![action])?;
             total_reward += step.reward;
 
-            if step.is_done {
+            agent.remember(
+                &state,
+                &Tensor::new(vec![action], &Device::Cpu)?,
+                &Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
+                &step.state,
+                step.terminated,
+                step.truncated,
+            );
+
+            if step.terminated || step.truncated {
                 break;
             }
-            obs = step.obs;
+            state = step.state;
         }
         println!("episode {episode} with total reward of {total_reward}");
+
+        for _ in 0..TRAINING_ITERATIONS {
+            agent.train(TRAINING_BATCH_SIZE)?;
+        }
+    }
+
+    println!("Testing...");
+    agent.train = false;
+    for episode in 0..10 {
+        // let mut state = env.reset(episode as u64)?;
+        let mut state = env.reset(rng.gen::<u64>())?;
+        let mut total_reward = 0.0;
+        for _ in 0..EPISODE_LENGTH {
+            let mut action = 2.0 * agent.actions(&state)?;
+            action = action.clamp(-2.0, 2.0);
+
+            let step = env.step(vec![action])?;
+            total_reward += step.reward;
+
+            if step.terminated || step.truncated {
+                break;
+            }
+            state = step.state;
+        }
+        println!("episode {episode} with total reward of {total_reward}");
     }
     Ok(())
 }
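
The MU, THETA and SIGMA constants above parameterize the Ornstein-Uhlenbeck exploration noise that `OuNoise::sample` adds to the actor's output during training; the recurrence is `x <- x + THETA * (MU - x) + SIGMA * N(0, 1)`. A standalone sketch of the same process with scalar state; it assumes the `rand` and `rand_distr` crates rather than candle's tensor RNG:

```rust
use rand::Rng;
use rand_distr::StandardNormal;

/// Scalar version of the OU noise in ddpg.rs; the real implementation keeps a
/// `Tensor` state of length `size_action`, initialised to ones.
struct OuNoise {
    mu: f64,
    theta: f64,
    sigma: f64,
    state: f64,
}

impl OuNoise {
    fn sample(&mut self, rng: &mut impl Rng) -> f64 {
        let eps: f64 = rng.sample(StandardNormal);
        // dx = theta * (mu - x) + sigma * eps, mirroring OuNoise::sample.
        self.state += self.theta * (self.mu - self.state) + self.sigma * eps;
        self.state
    }
}

fn main() {
    let mut rng = rand::thread_rng();
    // Same values as the MU, THETA and SIGMA constants in main.rs.
    let mut noise = OuNoise { mu: 0.0, theta: 0.15, sigma: 0.1, state: 1.0 };
    for _ in 0..5 {
        println!("{:.4}", noise.sample(&mut rng));
    }
}
```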