Diffstat (limited to 'candle-examples/examples/reinforcement-learning/atari_wrappers.py')
-rw-r--r-- | candle-examples/examples/reinforcement-learning/atari_wrappers.py | 308 |
1 file changed, 308 insertions, 0 deletions
diff --git a/candle-examples/examples/reinforcement-learning/atari_wrappers.py b/candle-examples/examples/reinforcement-learning/atari_wrappers.py
new file mode 100644
index 00000000..b5c4665d
--- /dev/null
+++ b/candle-examples/examples/reinforcement-learning/atari_wrappers.py
@@ -0,0 +1,308 @@
+import gymnasium as gym
+import numpy as np
+from collections import deque
+from PIL import Image
+from multiprocessing import Process, Pipe
+
+# atari_wrappers.py
+class NoopResetEnv(gym.Wrapper):
+    def __init__(self, env, noop_max=30):
+        """Sample initial states by taking a random number of no-ops on reset.
+        The no-op is assumed to be action 0.
+        """
+        gym.Wrapper.__init__(self, env)
+        self.noop_max = noop_max
+        self.override_num_noops = None
+        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
+
+    def reset(self, **kwargs):
+        """Do the no-op action for a number of steps in [1, noop_max]."""
+        obs, info = self.env.reset(**kwargs)
+        if self.override_num_noops is not None:
+            noops = self.override_num_noops
+        else:
+            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)  # pylint: disable=E1101
+        assert noops > 0
+        for _ in range(noops):
+            # gymnasium's step returns (obs, reward, terminated, truncated, info)
+            obs, _, terminated, truncated, info = self.env.step(0)
+            if terminated or truncated:
+                obs, info = self.env.reset(**kwargs)
+        return obs, info
+
+class FireResetEnv(gym.Wrapper):
+    def __init__(self, env):
+        """Take the FIRE action on reset for environments that are fixed until firing."""
+        gym.Wrapper.__init__(self, env)
+        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
+        assert len(env.unwrapped.get_action_meanings()) >= 3
+
+    def reset(self, **kwargs):
+        self.env.reset(**kwargs)
+        obs, _, terminated, truncated, info = self.env.step(1)
+        if terminated or truncated:
+            self.env.reset(**kwargs)
+        obs, _, terminated, truncated, info = self.env.step(2)
+        if terminated or truncated:
+            obs, info = self.env.reset(**kwargs)
+        return obs, info
+
+class ImageSaver(gym.Wrapper):
+    def __init__(self, env, img_path, rank):
+        gym.Wrapper.__init__(self, env)
+        self._cnt = 0
+        self._img_path = img_path
+        self._rank = rank
+
+    def step(self, action):
+        step_result = self.env.step(action)
+        obs = step_result[0]
+        img = Image.fromarray(obs, 'RGB')
+        img.save('%s/out%d-%05d.png' % (self._img_path, self._rank, self._cnt))
+        self._cnt += 1
+        return step_result
+
+class EpisodicLifeEnv(gym.Wrapper):
+    def __init__(self, env):
+        """Make end-of-life == end-of-episode, but only reset on true game over.
+        Done by DeepMind for the DQN and co. since it helps value estimation.
+        """
+        gym.Wrapper.__init__(self, env)
+        self.lives = 0
+        self.was_real_done = True
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        self.was_real_done = terminated or truncated
+        # check current lives, make loss of life terminal,
+        # then update lives to handle bonus lives
+        lives = self.env.unwrapped.ale.lives()
+        if 0 < lives < self.lives:
+            # for Qbert we sometimes stay in the lives == 0 condition for a few
+            # frames, so it is important to require lives > 0 here so that we
+            # only reset once the environment advertises done.
+            terminated = True
+        self.lives = lives
+        return obs, reward, terminated, truncated, info
+
+    def reset(self, **kwargs):
+        """Reset only when lives are exhausted.
+        This way all states are still reachable even though lives are episodic,
+        and the learner need not know about any of this behind the scenes.
+ """ + if self.was_real_done: + obs = self.env.reset() + else: + # no-op step to advance from terminal/lost life state + obs, _, _, _ = self.env.step(0) + self.lives = self.env.unwrapped.ale.lives() + return obs + +class MaxAndSkipEnv(gym.Wrapper): + def __init__(self, env, skip=4): + """Return only every `skip`-th frame""" + gym.Wrapper.__init__(self, env) + # most recent raw observations (for max pooling across time steps) + self._obs_buffer = deque(maxlen=2) + self._skip = skip + + def step(self, action): + """Repeat action, sum reward, and max over last observations.""" + total_reward = 0.0 + done = None + for _ in range(self._skip): + obs, reward, done, info = self.env.step(action) + self._obs_buffer.append(obs) + total_reward += reward + if done: + break + max_frame = np.max(np.stack(self._obs_buffer), axis=0) + + return max_frame, total_reward, done, info + + def reset(self): + """Clear past frame buffer and init. to first obs. from inner env.""" + self._obs_buffer.clear() + obs = self.env.reset() + self._obs_buffer.append(obs) + return obs + +class ClipRewardEnv(gym.RewardWrapper): + def reward(self, reward): + """Bin reward to {+1, 0, -1} by its sign.""" + return np.sign(reward) + +class WarpFrame(gym.ObservationWrapper): + def __init__(self, env): + """Warp frames to 84x84 as done in the Nature paper and later work.""" + gym.ObservationWrapper.__init__(self, env) + self.res = 84 + self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.res, self.res, 1), dtype='uint8') + + def observation(self, obs): + frame = np.dot(obs.astype('float32'), np.array([0.299, 0.587, 0.114], 'float32')) + frame = np.array(Image.fromarray(frame).resize((self.res, self.res), + resample=Image.BILINEAR), dtype=np.uint8) + return frame.reshape((self.res, self.res, 1)) + +class FrameStack(gym.Wrapper): + def __init__(self, env, k): + """Buffer observations and stack across channels (last axis).""" + gym.Wrapper.__init__(self, env) + self.k = k + self.frames = deque([], maxlen=k) + shp = env.observation_space.shape + assert shp[2] == 1 # can only stack 1-channel frames + self.observation_space = gym.spaces.Box(low=0, high=255, shape=(shp[0], shp[1], k), dtype='uint8') + + def reset(self): + """Clear buffer and re-fill by duplicating the first observation.""" + ob = self.env.reset() + for _ in range(self.k): self.frames.append(ob) + return self.observation() + + def step(self, action): + ob, reward, done, info = self.env.step(action) + self.frames.append(ob) + return self.observation(), reward, done, info + + def observation(self): + assert len(self.frames) == self.k + return np.concatenate(self.frames, axis=2) + +def wrap_deepmind(env, episode_life=True, clip_rewards=True): + """Configure environment for DeepMind-style Atari. 
+
+    Note: this does not include frame stacking!"""
+    assert 'NoFrameskip' in env.spec.id  # frame skipping is handled by MaxAndSkipEnv
+    if episode_life:
+        env = EpisodicLifeEnv(env)
+    env = NoopResetEnv(env, noop_max=30)
+    env = MaxAndSkipEnv(env, skip=4)
+    if 'FIRE' in env.unwrapped.get_action_meanings():
+        env = FireResetEnv(env)
+    env = WarpFrame(env)
+    if clip_rewards:
+        env = ClipRewardEnv(env)
+    return env
+
+# envs.py
+def make_env(env_id, img_dir, seed, rank):
+    def _thunk():
+        env = gym.make(env_id)
+        env.reset(seed=(seed + rank))
+        if img_dir is not None:
+            env = ImageSaver(env, img_dir, rank)
+        env = wrap_deepmind(env)
+        env = WrapPyTorch(env)
+        return env
+
+    return _thunk
+
+class WrapPyTorch(gym.ObservationWrapper):
+    def __init__(self, env=None):
+        super(WrapPyTorch, self).__init__(env)
+        # frames are transposed from HWC to CHW for PyTorch-style consumers;
+        # they remain uint8 in [0, 255]
+        self.observation_space = gym.spaces.Box(0, 255, [1, 84, 84], dtype='uint8')
+
+    def observation(self, observation):
+        return observation.transpose(2, 0, 1)
+
+# vecenv.py
+class VecEnv(object):
+    """
+    Vectorized environment base class.
+    """
+    def step(self, vac):
+        """
+        Apply a sequence of actions to the sequence of environments:
+        actions -> (observations, rewards, dones)
+
+        where 'dones' is a boolean vector indicating whether each episode just ended.
+        """
+        raise NotImplementedError
+
+    def reset(self):
+        """
+        Reset all environments.
+        """
+        raise NotImplementedError
+
+    def close(self):
+        pass
+
+# subproc_vec_env.py
+def worker(remote, env_fn_wrapper):
+    env = env_fn_wrapper.x()
+    while True:
+        cmd, data = remote.recv()
+        if cmd == 'step':
+            ob, reward, terminated, truncated, info = env.step(data)
+            done = terminated or truncated
+            if done:
+                # start the next episode right away so the batch never stalls
+                ob, _ = env.reset()
+            remote.send((ob, reward, done, info))
+        elif cmd == 'reset':
+            ob, _ = env.reset()
+            remote.send(ob)
+        elif cmd == 'close':
+            remote.close()
+            break
+        elif cmd == 'get_spaces':
+            remote.send((env.action_space, env.observation_space))
+        else:
+            raise NotImplementedError
+
+class CloudpickleWrapper(object):
+    """
+    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle).
+    """
+    def __init__(self, x):
+        self.x = x
+
+    def __getstate__(self):
+        import cloudpickle
+        return cloudpickle.dumps(self.x)
+
+    def __setstate__(self, ob):
+        import pickle
+        self.x = pickle.loads(ob)
+
+class SubprocVecEnv(VecEnv):
+    def __init__(self, env_fns):
+        """
+        env_fns: list of thunks, each building one gym environment to run in a subprocess
+        """
+        nenvs = len(env_fns)
+        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
+        self.ps = [Process(target=worker, args=(work_remote, CloudpickleWrapper(env_fn)))
+                   for (work_remote, env_fn) in zip(self.work_remotes, env_fns)]
+        for p in self.ps:
+            p.start()
+
+        self.remotes[0].send(('get_spaces', None))
+        self.action_space, self.observation_space = self.remotes[0].recv()
+
+    def step(self, actions):
+        for remote, action in zip(self.remotes, actions):
+            remote.send(('step', action))
+        results = [remote.recv() for remote in self.remotes]
+        obs, rews, dones, infos = zip(*results)
+        return np.stack(obs), np.stack(rews), np.stack(dones), infos
+
+    def reset(self):
+        for remote in self.remotes:
+            remote.send(('reset', None))
+        return np.stack([remote.recv() for remote in self.remotes])
+
+    def close(self):
+        for remote in self.remotes:
+            remote.send(('close', None))
+        for p in self.ps:
+            p.join()
+
+    @property
+    def num_envs(self):
+        return len(self.remotes)
+
+# Create the environment.
+def make(env_name, img_dir, num_processes):
+    envs = SubprocVecEnv([
+        make_env(env_name, img_dir, 1337, i) for i in range(num_processes)
+    ])
+    return envs
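
For reference, a minimal sketch of how this module might be driven, assuming gymnasium with the Atari extras (ale-py) is installed and that PongNoFrameskip-v4 is registered. The script name, environment name, and process count are illustrative, not part of the commit:

    # drive_atari_wrappers.py -- hypothetical usage sketch, not part of the commit
    import numpy as np
    import atari_wrappers

    if __name__ == '__main__':  # SubprocVecEnv forks worker processes, so guard the entry point
        envs = atari_wrappers.make('PongNoFrameskip-v4', None, 4)  # no image dumping, 4 workers
        obs = envs.reset()  # stacked batch of shape (4, 1, 84, 84), dtype uint8
        for _ in range(100):
            # uniformly random actions, one per worker env
            actions = np.random.randint(0, envs.action_space.n, size=envs.num_envs)
            obs, rewards, dones, infos = envs.step(actions)  # rewards are sign-clipped
        envs.close()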
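
As the wrap_deepmind docstring notes, frame stacking is intentionally excluded, and make_env does not add it either, presumably leaving temporal stacking to the consumer of these observations. If stacking were wanted on the Python side instead, a hypothetical variant of make_env could splice FrameStack in between wrap_deepmind and WrapPyTorch; note that WrapPyTorch's hard-coded observation_space shape would then need k channels rather than 1:

    # make_env_stacked -- hypothetical variant, not part of the commit
    def make_env_stacked(env_id, img_dir, seed, rank, k=4):
        def _thunk():
            env = gym.make(env_id)
            env.reset(seed=(seed + rank))
            if img_dir is not None:
                env = ImageSaver(env, img_dir, rank)
            env = wrap_deepmind(env)
            env = FrameStack(env, k)  # observations become (84, 84, k) uint8
            env = WrapPyTorch(env)    # transposed to (k, 84, 84)
            return env
        return _thunk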