summaryrefslogtreecommitdiff
path: root/candle-examples/examples/reinforcement-learning/atari_wrappers.py
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/reinforcement-learning/atari_wrappers.py')
-rw-r--r--candle-examples/examples/reinforcement-learning/atari_wrappers.py308
1 files changed, 308 insertions, 0 deletions
diff --git a/candle-examples/examples/reinforcement-learning/atari_wrappers.py b/candle-examples/examples/reinforcement-learning/atari_wrappers.py
new file mode 100644
index 00000000..b5c4665d
--- /dev/null
+++ b/candle-examples/examples/reinforcement-learning/atari_wrappers.py
@@ -0,0 +1,308 @@
+import gymnasium as gym
+import numpy as np
+from collections import deque
+from PIL import Image
+from multiprocessing import Process, Pipe
+
+# atari_wrappers.py
+class NoopResetEnv(gym.Wrapper):
+ def __init__(self, env, noop_max=30):
+ """Sample initial states by taking random number of no-ops on reset.
+ No-op is assumed to be action 0.
+ """
+ gym.Wrapper.__init__(self, env)
+ self.noop_max = noop_max
+ self.override_num_noops = None
+ assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
+
+ def reset(self):
+ """ Do no-op action for a number of steps in [1, noop_max]."""
+ self.env.reset()
+ if self.override_num_noops is not None:
+ noops = self.override_num_noops
+ else:
+ noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) #pylint: disable=E1101
+ assert noops > 0
+ obs = None
+ for _ in range(noops):
+ obs, _, done, _ = self.env.step(0)
+ if done:
+ obs = self.env.reset()
+ return obs
+
+class FireResetEnv(gym.Wrapper):
+ def __init__(self, env):
+ """Take action on reset for environments that are fixed until firing."""
+ gym.Wrapper.__init__(self, env)
+ assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
+ assert len(env.unwrapped.get_action_meanings()) >= 3
+
+ def reset(self):
+ self.env.reset()
+ obs, _, done, _ = self.env.step(1)
+ if done:
+ self.env.reset()
+ obs, _, done, _ = self.env.step(2)
+ if done:
+ self.env.reset()
+ return obs
+
+class ImageSaver(gym.Wrapper):
+ def __init__(self, env, img_path, rank):
+ gym.Wrapper.__init__(self, env)
+ self._cnt = 0
+ self._img_path = img_path
+ self._rank = rank
+
+ def step(self, action):
+ step_result = self.env.step(action)
+ obs, _, _, _ = step_result
+ img = Image.fromarray(obs, 'RGB')
+ img.save('%s/out%d-%05d.png' % (self._img_path, self._rank, self._cnt))
+ self._cnt += 1
+ return step_result
+
+class EpisodicLifeEnv(gym.Wrapper):
+ def __init__(self, env):
+ """Make end-of-life == end-of-episode, but only reset on true game over.
+ Done by DeepMind for the DQN and co. since it helps value estimation.
+ """
+ gym.Wrapper.__init__(self, env)
+ self.lives = 0
+ self.was_real_done = True
+
+ def step(self, action):
+ obs, reward, done, info = self.env.step(action)
+ self.was_real_done = done
+ # check current lives, make loss of life terminal,
+ # then update lives to handle bonus lives
+ lives = self.env.unwrapped.ale.lives()
+ if lives < self.lives and lives > 0:
+ # for Qbert somtimes we stay in lives == 0 condtion for a few frames
+ # so its important to keep lives > 0, so that we only reset once
+ # the environment advertises done.
+ done = True
+ self.lives = lives
+ return obs, reward, done, info
+
+ def reset(self):
+ """Reset only when lives are exhausted.
+ This way all states are still reachable even though lives are episodic,
+ and the learner need not know about any of this behind-the-scenes.
+ """
+ if self.was_real_done:
+ obs = self.env.reset()
+ else:
+ # no-op step to advance from terminal/lost life state
+ obs, _, _, _ = self.env.step(0)
+ self.lives = self.env.unwrapped.ale.lives()
+ return obs
+
+class MaxAndSkipEnv(gym.Wrapper):
+ def __init__(self, env, skip=4):
+ """Return only every `skip`-th frame"""
+ gym.Wrapper.__init__(self, env)
+ # most recent raw observations (for max pooling across time steps)
+ self._obs_buffer = deque(maxlen=2)
+ self._skip = skip
+
+ def step(self, action):
+ """Repeat action, sum reward, and max over last observations."""
+ total_reward = 0.0
+ done = None
+ for _ in range(self._skip):
+ obs, reward, done, info = self.env.step(action)
+ self._obs_buffer.append(obs)
+ total_reward += reward
+ if done:
+ break
+ max_frame = np.max(np.stack(self._obs_buffer), axis=0)
+
+ return max_frame, total_reward, done, info
+
+ def reset(self):
+ """Clear past frame buffer and init. to first obs. from inner env."""
+ self._obs_buffer.clear()
+ obs = self.env.reset()
+ self._obs_buffer.append(obs)
+ return obs
+
+class ClipRewardEnv(gym.RewardWrapper):
+ def reward(self, reward):
+ """Bin reward to {+1, 0, -1} by its sign."""
+ return np.sign(reward)
+
+class WarpFrame(gym.ObservationWrapper):
+ def __init__(self, env):
+ """Warp frames to 84x84 as done in the Nature paper and later work."""
+ gym.ObservationWrapper.__init__(self, env)
+ self.res = 84
+ self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.res, self.res, 1), dtype='uint8')
+
+ def observation(self, obs):
+ frame = np.dot(obs.astype('float32'), np.array([0.299, 0.587, 0.114], 'float32'))
+ frame = np.array(Image.fromarray(frame).resize((self.res, self.res),
+ resample=Image.BILINEAR), dtype=np.uint8)
+ return frame.reshape((self.res, self.res, 1))
+
+class FrameStack(gym.Wrapper):
+ def __init__(self, env, k):
+ """Buffer observations and stack across channels (last axis)."""
+ gym.Wrapper.__init__(self, env)
+ self.k = k
+ self.frames = deque([], maxlen=k)
+ shp = env.observation_space.shape
+ assert shp[2] == 1 # can only stack 1-channel frames
+ self.observation_space = gym.spaces.Box(low=0, high=255, shape=(shp[0], shp[1], k), dtype='uint8')
+
+ def reset(self):
+ """Clear buffer and re-fill by duplicating the first observation."""
+ ob = self.env.reset()
+ for _ in range(self.k): self.frames.append(ob)
+ return self.observation()
+
+ def step(self, action):
+ ob, reward, done, info = self.env.step(action)
+ self.frames.append(ob)
+ return self.observation(), reward, done, info
+
+ def observation(self):
+ assert len(self.frames) == self.k
+ return np.concatenate(self.frames, axis=2)
+
+def wrap_deepmind(env, episode_life=True, clip_rewards=True):
+ """Configure environment for DeepMind-style Atari.
+
+ Note: this does not include frame stacking!"""
+ assert 'NoFrameskip' in env.spec.id # required for DeepMind-style skip
+ if episode_life:
+ env = EpisodicLifeEnv(env)
+ env = NoopResetEnv(env, noop_max=30)
+ env = MaxAndSkipEnv(env, skip=4)
+ if 'FIRE' in env.unwrapped.get_action_meanings():
+ env = FireResetEnv(env)
+ env = WarpFrame(env)
+ if clip_rewards:
+ env = ClipRewardEnv(env)
+ return env
+
+# envs.py
+def make_env(env_id, img_dir, seed, rank):
+ def _thunk():
+ env = gym.make(env_id)
+ env.reset(seed=(seed + rank))
+ if img_dir is not None:
+ env = ImageSaver(env, img_dir, rank)
+ env = wrap_deepmind(env)
+ env = WrapPyTorch(env)
+ return env
+
+ return _thunk
+
+class WrapPyTorch(gym.ObservationWrapper):
+ def __init__(self, env=None):
+ super(WrapPyTorch, self).__init__(env)
+ self.observation_space = gym.spaces.Box(0.0, 1.0, [1, 84, 84], dtype='float32')
+
+ def observation(self, observation):
+ return observation.transpose(2, 0, 1)
+
+# vecenv.py
+class VecEnv(object):
+ """
+ Vectorized environment base class
+ """
+ def step(self, vac):
+ """
+ Apply sequence of actions to sequence of environments
+ actions -> (observations, rewards, news)
+
+ where 'news' is a boolean vector indicating whether each element is new.
+ """
+ raise NotImplementedError
+ def reset(self):
+ """
+ Reset all environments
+ """
+ raise NotImplementedError
+ def close(self):
+ pass
+
+# subproc_vec_env.py
+def worker(remote, env_fn_wrapper):
+ env = env_fn_wrapper.x()
+ while True:
+ cmd, data = remote.recv()
+ if cmd == 'step':
+ ob, reward, done, info = env.step(data)
+ if done:
+ ob = env.reset()
+ remote.send((ob, reward, done, info))
+ elif cmd == 'reset':
+ ob = env.reset()
+ remote.send(ob)
+ elif cmd == 'close':
+ remote.close()
+ break
+ elif cmd == 'get_spaces':
+ remote.send((env.action_space, env.observation_space))
+ else:
+ raise NotImplementedError
+
+class CloudpickleWrapper(object):
+ """
+ Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
+ """
+ def __init__(self, x):
+ self.x = x
+ def __getstate__(self):
+ import cloudpickle
+ return cloudpickle.dumps(self.x)
+ def __setstate__(self, ob):
+ import pickle
+ self.x = pickle.loads(ob)
+
+class SubprocVecEnv(VecEnv):
+ def __init__(self, env_fns):
+ """
+ envs: list of gym environments to run in subprocesses
+ """
+ nenvs = len(env_fns)
+ self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
+ self.ps = [Process(target=worker, args=(work_remote, CloudpickleWrapper(env_fn)))
+ for (work_remote, env_fn) in zip(self.work_remotes, env_fns)]
+ for p in self.ps:
+ p.start()
+
+ self.remotes[0].send(('get_spaces', None))
+ self.action_space, self.observation_space = self.remotes[0].recv()
+
+
+ def step(self, actions):
+ for remote, action in zip(self.remotes, actions):
+ remote.send(('step', action))
+ results = [remote.recv() for remote in self.remotes]
+ obs, rews, dones, infos = zip(*results)
+ return np.stack(obs), np.stack(rews), np.stack(dones), infos
+
+ def reset(self):
+ for remote in self.remotes:
+ remote.send(('reset', None))
+ return np.stack([remote.recv() for remote in self.remotes])
+
+ def close(self):
+ for remote in self.remotes:
+ remote.send(('close', None))
+ for p in self.ps:
+ p.join()
+
+ @property
+ def num_envs(self):
+ return len(self.remotes)
+
+# Create the environment.
+def make(env_name, img_dir, num_processes):
+ envs = SubprocVecEnv([
+ make_env(env_name, img_dir, 1337, i) for i in range(num_processes)
+ ])
+ return envs