author    Laurent Mazare <laurent.mazare@gmail.com>  2023-10-14 16:46:43 +0100
committer GitHub <noreply@github.com>  2023-10-14 16:46:43 +0100
commit    29c7f2565d9a62b3451bec45ae3d031c19fd9d7a (patch)
tree      f2e5c5f7340701b314fad81840e7d5ea3a93542b /candle-examples/examples/reinforcement-learning
parent    9309cfc47d3a73605cc6dea8669bbea5b0a5784c (diff)
Add some reinforcement learning example. (#1090)
* Add some reinforcement learning example.
* Python initialization.
* Get the example to run.
* Vectorized gym envs for the atari wrappers.
* Get some simulation loop to run.
Diffstat (limited to 'candle-examples/examples/reinforcement-learning')
-rw-r--r--  candle-examples/examples/reinforcement-learning/README.md           16
-rw-r--r--  candle-examples/examples/reinforcement-learning/atari_wrappers.py  308
-rw-r--r--  candle-examples/examples/reinforcement-learning/gym_env.rs         108
-rw-r--r--  candle-examples/examples/reinforcement-learning/main.rs             75
-rw-r--r--  candle-examples/examples/reinforcement-learning/vec_gym_env.rs      91
5 files changed, 598 insertions, 0 deletions
diff --git a/candle-examples/examples/reinforcement-learning/README.md b/candle-examples/examples/reinforcement-learning/README.md
new file mode 100644
index 00000000..2d3d14b0
--- /dev/null
+++ b/candle-examples/examples/reinforcement-learning/README.md
@@ -0,0 +1,16 @@
+# candle-reinforcement-learning
+
+Reinforcement Learning examples for candle.
+
+This has been tested with `gymnasium` version `0.29.1`. You can install the
+Python package with:
+```bash
+pip install "gymnasium[accept-rom-license]"
+```
+
+To run the example, use the following command. Note the additional
+`--package` flag, which avoids a conflict with the `candle-pyo3` crate.
+```bash
+cargo run --example reinforcement-learning --features=pyo3 --package candle-examples
+```
diff --git a/candle-examples/examples/reinforcement-learning/atari_wrappers.py b/candle-examples/examples/reinforcement-learning/atari_wrappers.py
new file mode 100644
index 00000000..b5c4665d
--- /dev/null
+++ b/candle-examples/examples/reinforcement-learning/atari_wrappers.py
@@ -0,0 +1,308 @@
+import gymnasium as gym
+import numpy as np
+from collections import deque
+from PIL import Image
+from multiprocessing import Process, Pipe
+
+# atari_wrappers.py
+class NoopResetEnv(gym.Wrapper):
+ def __init__(self, env, noop_max=30):
+ """Sample initial states by taking random number of no-ops on reset.
+ No-op is assumed to be action 0.
+ """
+ gym.Wrapper.__init__(self, env)
+ self.noop_max = noop_max
+ self.override_num_noops = None
+ assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
+
+ def reset(self):
+ """ Do no-op action for a number of steps in [1, noop_max]."""
+ self.env.reset()
+ if self.override_num_noops is not None:
+ noops = self.override_num_noops
+ else:
+ noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) #pylint: disable=E1101
+ assert noops > 0
+ obs = None
+ for _ in range(noops):
+ obs, _, done, _ = self.env.step(0)
+ if done:
+ obs = self.env.reset()
+ return obs
+
+class FireResetEnv(gym.Wrapper):
+ def __init__(self, env):
+ """Take action on reset for environments that are fixed until firing."""
+ gym.Wrapper.__init__(self, env)
+ assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
+ assert len(env.unwrapped.get_action_meanings()) >= 3
+
+ def reset(self):
+ self.env.reset()
+ obs, _, done, _ = self.env.step(1)
+ if done:
+ self.env.reset()
+ obs, _, done, _ = self.env.step(2)
+ if done:
+ self.env.reset()
+ return obs
+
+class ImageSaver(gym.Wrapper):
+ def __init__(self, env, img_path, rank):
+ gym.Wrapper.__init__(self, env)
+ self._cnt = 0
+ self._img_path = img_path
+ self._rank = rank
+
+ def step(self, action):
+ step_result = self.env.step(action)
+ obs, _, _, _ = step_result
+ img = Image.fromarray(obs, 'RGB')
+ img.save('%s/out%d-%05d.png' % (self._img_path, self._rank, self._cnt))
+ self._cnt += 1
+ return step_result
+
+class EpisodicLifeEnv(gym.Wrapper):
+ def __init__(self, env):
+ """Make end-of-life == end-of-episode, but only reset on true game over.
+ Done by DeepMind for the DQN and co. since it helps value estimation.
+ """
+ gym.Wrapper.__init__(self, env)
+ self.lives = 0
+ self.was_real_done = True
+
+ def step(self, action):
+ obs, reward, done, info = self.env.step(action)
+ self.was_real_done = done
+ # check current lives, make loss of life terminal,
+ # then update lives to handle bonus lives
+ lives = self.env.unwrapped.ale.lives()
+ if lives < self.lives and lives > 0:
+ # for Qbert, sometimes we stay in the lives == 0 condition for a few frames,
+ # so it's important to keep lives > 0 so that we only reset once
+ # the environment advertises done.
+ done = True
+ self.lives = lives
+ return obs, reward, done, info
+
+ def reset(self):
+ """Reset only when lives are exhausted.
+ This way all states are still reachable even though lives are episodic,
+ and the learner need not know about any of this behind-the-scenes.
+ """
+ if self.was_real_done:
+ obs = self.env.reset()
+ else:
+ # no-op step to advance from terminal/lost life state
+ obs, _, _, _ = self.env.step(0)
+ self.lives = self.env.unwrapped.ale.lives()
+ return obs
+
+class MaxAndSkipEnv(gym.Wrapper):
+ def __init__(self, env, skip=4):
+ """Return only every `skip`-th frame"""
+ gym.Wrapper.__init__(self, env)
+ # most recent raw observations (for max pooling across time steps)
+ self._obs_buffer = deque(maxlen=2)
+ self._skip = skip
+
+ def step(self, action):
+ """Repeat action, sum reward, and max over last observations."""
+ total_reward = 0.0
+ done = None
+ for _ in range(self._skip):
+ obs, reward, done, info = self.env.step(action)
+ self._obs_buffer.append(obs)
+ total_reward += reward
+ if done:
+ break
+ max_frame = np.max(np.stack(self._obs_buffer), axis=0)
+
+ return max_frame, total_reward, done, info
+
+ def reset(self):
+ """Clear past frame buffer and init. to first obs. from inner env."""
+ self._obs_buffer.clear()
+ obs = self.env.reset()
+ self._obs_buffer.append(obs)
+ return obs
+
+class ClipRewardEnv(gym.RewardWrapper):
+ def reward(self, reward):
+ """Bin reward to {+1, 0, -1} by its sign."""
+ return np.sign(reward)
+
+class WarpFrame(gym.ObservationWrapper):
+ def __init__(self, env):
+ """Warp frames to 84x84 as done in the Nature paper and later work."""
+ gym.ObservationWrapper.__init__(self, env)
+ self.res = 84
+ self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.res, self.res, 1), dtype='uint8')
+
+ def observation(self, obs):
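+ # RGB -> grayscale using the standard luma weights (0.299, 0.587, 0.114),
+ # then bilinear-resize to 84x84 and keep a single channel.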
+ frame = np.dot(obs.astype('float32'), np.array([0.299, 0.587, 0.114], 'float32'))
+ frame = np.array(Image.fromarray(frame).resize((self.res, self.res),
+ resample=Image.BILINEAR), dtype=np.uint8)
+ return frame.reshape((self.res, self.res, 1))
+
+class FrameStack(gym.Wrapper):
+ def __init__(self, env, k):
+ """Buffer observations and stack across channels (last axis)."""
+ gym.Wrapper.__init__(self, env)
+ self.k = k
+ self.frames = deque([], maxlen=k)
+ shp = env.observation_space.shape
+ assert shp[2] == 1 # can only stack 1-channel frames
+ self.observation_space = gym.spaces.Box(low=0, high=255, shape=(shp[0], shp[1], k), dtype='uint8')
+
+ def reset(self):
+ """Clear buffer and re-fill by duplicating the first observation."""
+ ob = self.env.reset()
+ for _ in range(self.k): self.frames.append(ob)
+ return self.observation()
+
+ def step(self, action):
+ ob, reward, done, info = self.env.step(action)
+ self.frames.append(ob)
+ return self.observation(), reward, done, info
+
+ def observation(self):
+ assert len(self.frames) == self.k
+ return np.concatenate(self.frames, axis=2)
+
+def wrap_deepmind(env, episode_life=True, clip_rewards=True):
+ """Configure environment for DeepMind-style Atari.
+
+ Note: this does not include frame stacking!"""
+ assert 'NoFrameskip' in env.spec.id # required for DeepMind-style skip
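+ # Wrapper order below: episodic life -> random no-op resets -> frame skip with
+ # max-pooling -> FIRE reset (if supported) -> 84x84 grayscale -> reward clipping.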
+ if episode_life:
+ env = EpisodicLifeEnv(env)
+ env = NoopResetEnv(env, noop_max=30)
+ env = MaxAndSkipEnv(env, skip=4)
+ if 'FIRE' in env.unwrapped.get_action_meanings():
+ env = FireResetEnv(env)
+ env = WarpFrame(env)
+ if clip_rewards:
+ env = ClipRewardEnv(env)
+ return env
+
+# envs.py
+def make_env(env_id, img_dir, seed, rank):
+ def _thunk():
+ env = gym.make(env_id)
+ env.reset(seed=(seed + rank))
+ if img_dir is not None:
+ env = ImageSaver(env, img_dir, rank)
+ env = wrap_deepmind(env)
+ env = WrapPyTorch(env)
+ return env
+
+ return _thunk
+
+class WrapPyTorch(gym.ObservationWrapper):
+ def __init__(self, env=None):
+ super(WrapPyTorch, self).__init__(env)
+ self.observation_space = gym.spaces.Box(0.0, 1.0, [1, 84, 84], dtype='float32')
+
+ def observation(self, observation):
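+ # HWC -> CHW so observations match the (channel, height, width) layout
+ # expected by PyTorch-style convolutional models.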
+ return observation.transpose(2, 0, 1)
+
+# vecenv.py
+class VecEnv(object):
+ """
+ Vectorized environment base class
+ """
+ def step(self, vac):
+ """
+ Apply sequence of actions to sequence of environments
+ actions -> (observations, rewards, news)
+
+ where 'news' is a boolean vector indicating whether each element is new.
+ """
+ raise NotImplementedError
+ def reset(self):
+ """
+ Reset all environments
+ """
+ raise NotImplementedError
+ def close(self):
+ pass
+
+# subproc_vec_env.py
+def worker(remote, env_fn_wrapper):
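+ # Runs in a child process: build the environment here, then serve the
+ # 'step' / 'reset' / 'close' / 'get_spaces' commands sent by the parent over the pipe.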
+ env = env_fn_wrapper.x()
+ while True:
+ cmd, data = remote.recv()
+ if cmd == 'step':
+ ob, reward, done, info = env.step(data)
+ if done:
+ ob = env.reset()
+ remote.send((ob, reward, done, info))
+ elif cmd == 'reset':
+ ob = env.reset()
+ remote.send(ob)
+ elif cmd == 'close':
+ remote.close()
+ break
+ elif cmd == 'get_spaces':
+ remote.send((env.action_space, env.observation_space))
+ else:
+ raise NotImplementedError
+
+class CloudpickleWrapper(object):
+ """
+ Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
+ """
+ def __init__(self, x):
+ self.x = x
+ def __getstate__(self):
+ import cloudpickle
+ return cloudpickle.dumps(self.x)
+ def __setstate__(self, ob):
+ import pickle
+ self.x = pickle.loads(ob)
+
+class SubprocVecEnv(VecEnv):
+ def __init__(self, env_fns):
+ """
+ env_fns: list of callables, each creating a gym environment to run in a subprocess
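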
+ """
+ nenvs = len(env_fns)
+ self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
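+ # One pipe per environment: the parent keeps `remotes`, each worker process
+ # gets the matching end from `work_remotes`.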
+ self.ps = [Process(target=worker, args=(work_remote, CloudpickleWrapper(env_fn)))
+ for (work_remote, env_fn) in zip(self.work_remotes, env_fns)]
+ for p in self.ps:
+ p.start()
+
+ self.remotes[0].send(('get_spaces', None))
+ self.action_space, self.observation_space = self.remotes[0].recv()
+
+
+ def step(self, actions):
+ for remote, action in zip(self.remotes, actions):
+ remote.send(('step', action))
+ results = [remote.recv() for remote in self.remotes]
+ obs, rews, dones, infos = zip(*results)
+ return np.stack(obs), np.stack(rews), np.stack(dones), infos
+
+ def reset(self):
+ for remote in self.remotes:
+ remote.send(('reset', None))
+ return np.stack([remote.recv() for remote in self.remotes])
+
+ def close(self):
+ for remote in self.remotes:
+ remote.send(('close', None))
+ for p in self.ps:
+ p.join()
+
+ @property
+ def num_envs(self):
+ return len(self.remotes)
+
+# Create the environment.
+def make(env_name, img_dir, num_processes):
+ envs = SubprocVecEnv([
+ make_env(env_name, img_dir, 1337, i) for i in range(num_processes)
+ ])
+ return envs
diff --git a/candle-examples/examples/reinforcement-learning/gym_env.rs b/candle-examples/examples/reinforcement-learning/gym_env.rs
new file mode 100644
index 00000000..b98be6bc
--- /dev/null
+++ b/candle-examples/examples/reinforcement-learning/gym_env.rs
@@ -0,0 +1,108 @@
+#![allow(unused)]
+//! Wrappers around the Python API of Gymnasium (the new version of OpenAI gym)
+use candle::{Device, Result, Tensor};
+use pyo3::prelude::*;
+use pyo3::types::PyDict;
+
+/// The return value for a step.
+#[derive(Debug)]
+pub struct Step<A> {
+ pub obs: Tensor,
+ pub action: A,
+ pub reward: f64,
+ pub is_done: bool,
+}
+
+impl<A: Copy> Step<A> {
+ /// Returns a copy of this step changing the observation tensor.
+ pub fn copy_with_obs(&self, obs: &Tensor) -> Step<A> {
+ Step {
+ obs: obs.clone(),
+ action: self.action,
+ reward: self.reward,
+ is_done: self.is_done,
+ }
+ }
+}
+
+/// An OpenAI Gym session.
+pub struct GymEnv {
+ env: PyObject,
+ action_space: usize,
+ observation_space: Vec<usize>,
+}
+
+fn w(res: PyErr) -> candle::Error {
+ candle::Error::wrap(res)
+}
+
+impl GymEnv {
+ /// Creates a new session of the specified OpenAI Gym environment.
+ pub fn new(name: &str) -> Result<GymEnv> {
+ Python::with_gil(|py| {
+ let gym = py.import("gymnasium")?;
+ let make = gym.getattr("make")?;
+ let env = make.call1((name,))?;
+ let action_space = env.getattr("action_space")?;
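+ // Discrete action spaces expose `n`; continuous (Box) spaces only expose
+ // `shape`, so fall back to its first dimension.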
+ let action_space = if let Ok(val) = action_space.getattr("n") {
+ val.extract()?
+ } else {
+ let action_space: Vec<usize> = action_space.getattr("shape")?.extract()?;
+ action_space[0]
+ };
+ let observation_space = env.getattr("observation_space")?;
+ let observation_space = observation_space.getattr("shape")?.extract()?;
+ Ok(GymEnv {
+ env: env.into(),
+ action_space,
+ observation_space,
+ })
+ })
+ .map_err(w)
+ }
+
+ /// Resets the environment, returning the observation tensor.
+ pub fn reset(&self, seed: u64) -> Result<Tensor> {
+ let obs: Vec<f32> = Python::with_gil(|py| {
+ let kwargs = PyDict::new(py);
+ kwargs.set_item("seed", seed)?;
+ let obs = self.env.call_method(py, "reset", (), Some(kwargs))?;
+ obs.as_ref(py).get_item(0)?.extract()
+ })
+ .map_err(w)?;
+ Tensor::new(obs, &Device::Cpu)
+ }
+
+ /// Applies an environment step using the specified action.
+ pub fn step<A: pyo3::IntoPy<pyo3::Py<pyo3::PyAny>> + Clone>(
+ &self,
+ action: A,
+ ) -> Result<Step<A>> {
+ let (obs, reward, is_done) = Python::with_gil(|py| {
+ let step = self.env.call_method(py, "step", (action.clone(),), None)?;
+ let step = step.as_ref(py);
+ let obs: Vec<f32> = step.get_item(0)?.extract()?;
+ let reward: f64 = step.get_item(1)?.extract()?;
+ let is_done: bool = step.get_item(2)?.extract()?;
+ Ok((obs, reward, is_done))
+ })
+ .map_err(w)?;
+ let obs = Tensor::new(obs, &Device::Cpu)?;
+ Ok(Step {
+ obs,
+ reward,
+ is_done,
+ action,
+ })
+ }
+
+ /// Returns the number of allowed actions for this environment.
+ pub fn action_space(&self) -> usize {
+ self.action_space
+ }
+
+ /// Returns the shape of the observation tensors.
+ pub fn observation_space(&self) -> &[usize] {
+ &self.observation_space
+ }
+}
diff --git a/candle-examples/examples/reinforcement-learning/main.rs b/candle-examples/examples/reinforcement-learning/main.rs
new file mode 100644
index 00000000..f16f042e
--- /dev/null
+++ b/candle-examples/examples/reinforcement-learning/main.rs
@@ -0,0 +1,75 @@
+#![allow(unused)]
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+mod gym_env;
+mod vec_gym_env;
+
+use candle::Result;
+use clap::Parser;
+use rand::Rng;
+
+// The total number of episodes.
+const MAX_EPISODES: usize = 100;
+// The maximum length of an episode.
+const EPISODE_LENGTH: usize = 200;
+
+#[derive(Parser, Debug, Clone)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+}
+
+fn main() -> Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
+ let args = Args::parse();
+
+ let _guard = if args.tracing {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
+
+ let env = gym_env::GymEnv::new("Pendulum-v1")?;
+ println!("action space: {}", env.action_space());
+ println!("observation space: {:?}", env.observation_space());
+
+ let _num_obs = env.observation_space().iter().product::<usize>();
+ let _num_actions = env.action_space();
+
+ let mut rng = rand::thread_rng();
+
+ for episode in 0..MAX_EPISODES {
+ let mut obs = env.reset(episode as u64)?;
+
+ let mut total_reward = 0.0;
+ for _ in 0..EPISODE_LENGTH {
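+ // Pendulum-v1 takes a single continuous action (torque) in [-2.0, 2.0];
+ // sample it uniformly at random.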
+ let actions = rng.gen_range(-2.0..2.0);
+
+ let step = env.step(vec![actions])?;
+ total_reward += step.reward;
+
+ if step.is_done {
+ break;
+ }
+ obs = step.obs;
+ }
+
+ println!("episode {episode} with total reward of {total_reward}");
+ }
+ Ok(())
+}
diff --git a/candle-examples/examples/reinforcement-learning/vec_gym_env.rs b/candle-examples/examples/reinforcement-learning/vec_gym_env.rs
new file mode 100644
index 00000000..8f8f30bd
--- /dev/null
+++ b/candle-examples/examples/reinforcement-learning/vec_gym_env.rs
@@ -0,0 +1,91 @@
+#![allow(unused)]
+//! Vectorized version of the gym environment.
+use candle::{DType, Device, Result, Tensor};
+use pyo3::prelude::*;
+use pyo3::types::PyDict;
+
+#[derive(Debug)]
+pub struct Step {
+ pub obs: Tensor,
+ pub reward: Tensor,
+ pub is_done: Tensor,
+}
+
+pub struct VecGymEnv {
+ env: PyObject,
+ action_space: usize,
+ observation_space: Vec<usize>,
+}
+
+fn w(res: PyErr) -> candle::Error {
+ candle::Error::wrap(res)
+}
+
+impl VecGymEnv {
+ pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
+ Python::with_gil(|py| {
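+ // Add the example directory to `sys.path` so the bundled `atari_wrappers.py`
+ // module can be imported.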
+ let sys = py.import("sys")?;
+ let path = sys.getattr("path")?;
+ let _ = path.call_method1(
+ "append",
+ ("candle-examples/examples/reinforcement-learning",),
+ )?;
+ let gym = py.import("atari_wrappers")?;
+ let make = gym.getattr("make")?;
+ let env = make.call1((name, img_dir, nprocesses))?;
+ let action_space = env.getattr("action_space")?;
+ let action_space = action_space.getattr("n")?.extract()?;
+ let observation_space = env.getattr("observation_space")?;
+ let observation_space: Vec<usize> = observation_space.getattr("shape")?.extract()?;
+ let observation_space =
+ [vec![nprocesses].as_slice(), observation_space.as_slice()].concat();
+ Ok(VecGymEnv {
+ env: env.into(),
+ action_space,
+ observation_space,
+ })
+ })
+ .map_err(w)
+ }
+
+ pub fn reset(&self) -> Result<Tensor> {
+ let obs = Python::with_gil(|py| {
+ let obs = self.env.call_method0(py, "reset")?;
+ let obs = obs.call_method0(py, "flatten")?;
+ obs.extract::<Vec<f32>>(py)
+ })
+ .map_err(w)?;
+ Tensor::new(obs, &Device::Cpu)?.reshape(self.observation_space.as_slice())
+ }
+
+ pub fn step(&self, action: Vec<usize>) -> Result<Step> {
+ let (obs, reward, is_done) = Python::with_gil(|py| {
+ let step = self.env.call_method(py, "step", (action,), None)?;
+ let step = step.as_ref(py);
+ let obs = step.get_item(0)?.call_method("flatten", (), None)?;
+ let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?;
+ let obs: Vec<u8> = obs_buffer.to_vec(py)?;
+ let reward: Vec<f32> = step.get_item(1)?.extract()?;
+ let is_done: Vec<f32> = step.get_item(2)?.extract()?;
+ Ok((obs, reward, is_done))
+ })
+ .map_err(w)?;
+ let obs = Tensor::from_vec(obs, self.observation_space.as_slice(), &Device::Cpu)?
+ .to_dtype(DType::F32)?;
+ let reward = Tensor::new(reward, &Device::Cpu)?;
+ let is_done = Tensor::new(is_done, &Device::Cpu)?;
+ Ok(Step {
+ obs,
+ reward,
+ is_done,
+ })
+ }
+
+ pub fn action_space(&self) -> usize {
+ self.action_space
+ }
+
+ pub fn observation_space(&self) -> &[usize] {
+ &self.observation_space
+ }
+}
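
The commit adds `vec_gym_env.rs` as a module, but `main.rs` only exercises the single-environment `GymEnv`. For reference, below is a minimal, hypothetical sketch of how the vectorized environment could be driven with a random policy, using only the API shown above (`new`, `reset`, `step`, `action_space`); the environment name, process count, and step budget are illustrative, and the Python-side Atari dependencies are assumed to be installed.

```rust
// Hypothetical usage sketch, not part of this commit. Assumes it sits alongside
// the `vec_gym_env` module declared in main.rs.
use rand::Rng;

fn run_vectorized(nprocesses: usize) -> candle::Result<()> {
    // `None` disables the Python-side ImageSaver wrapper.
    let env = vec_gym_env::VecGymEnv::new("PongNoFrameskip-v4", None, nprocesses)?;
    println!("action space: {}", env.action_space());
    println!("observation space: {:?}", env.observation_space());

    let mut rng = rand::thread_rng();
    let mut obs = env.reset()?;
    for _ in 0..200 {
        // One random discrete action per sub-environment.
        let actions: Vec<usize> = (0..nprocesses)
            .map(|_| rng.gen_range(0..env.action_space()))
            .collect();
        let step = env.step(actions)?;
        obs = step.obs;
    }
    println!("final obs shape: {:?}", obs.shape());
    Ok(())
}
```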