From 28c42b63c81e34afda516e450719a775c9e00568 Mon Sep 17 00:00:00 2001 From: Xingdong Zuo Date: Thu, 5 Nov 2020 22:44:37 +0100 Subject: [PATCH] [Wrappers]: TimeAwareObservation (#1490) * Create time_aware_observation.py * Update __init__.py * Create test_time_aware_observation.py * Update time_aware_observation.py * Update time_aware_observation.py * Update time_aware_observation.py Co-authored-by: pzhokhov --- gym/wrappers/__init__.py | 1 + gym/wrappers/test_time_aware_observation.py | 33 +++++++++++++++++++++ gym/wrappers/time_aware_observation.py | 32 ++++++++++++++++++++ 3 files changed, 66 insertions(+) create mode 100644 gym/wrappers/test_time_aware_observation.py create mode 100644 gym/wrappers/time_aware_observation.py diff --git a/gym/wrappers/__init__.py b/gym/wrappers/__init__.py index 2681f4daed8..6c6c4341a93 100644 --- a/gym/wrappers/__init__.py +++ b/gym/wrappers/__init__.py @@ -3,6 +3,7 @@ from gym.wrappers.time_limit import TimeLimit from gym.wrappers.filter_observation import FilterObservation from gym.wrappers.atari_preprocessing import AtariPreprocessing +from gym.wrappers.time_aware_observation import TimeAwareObservation from gym.wrappers.rescale_action import RescaleAction from gym.wrappers.flatten_observation import FlattenObservation from gym.wrappers.gray_scale_observation import GrayScaleObservation diff --git a/gym/wrappers/test_time_aware_observation.py b/gym/wrappers/test_time_aware_observation.py new file mode 100644 index 00000000000..923e747c3ee --- /dev/null +++ b/gym/wrappers/test_time_aware_observation.py @@ -0,0 +1,33 @@ +import pytest + +import gym +from gym.wrappers import TimeAwareObservation + + +@pytest.mark.parametrize('env_id', ['CartPole-v1', 'Pendulum-v0']) +def test_time_aware_observation(env_id): + env = gym.make(env_id) + wrapped_env = TimeAwareObservation(env) + + assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1 + + obs = env.reset() + wrapped_obs = wrapped_env.reset() + assert 
class TimeAwareObservation(ObservationWrapper):
    r"""Augment the observation with current time step in the trajectory.

    The step counter ``t`` is appended as the last element of every
    observation. ``t`` is 0 after :meth:`reset` and grows by one on every
    :meth:`step`. The wrapped observation space gains one extra dimension
    bounded by ``[0, inf)`` for that counter.

    .. note::
        Currently it only works with one-dimensional observation space. It doesn't
        support pixel observation space yet.

    Args:
        env: Environment to wrap. Its observation space must be a ``Box``
            with dtype ``np.float32``.

    Raises:
        AssertionError: If the observation space is not a float32 ``Box``.

    """
    def __init__(self, env):
        super(TimeAwareObservation, self).__init__(env)
        assert isinstance(env.observation_space, Box), (
            'TimeAwareObservation requires a Box observation space, got '
            '{}'.format(type(env.observation_space).__name__)
        )
        assert env.observation_space.dtype == np.float32, (
            'TimeAwareObservation requires a float32 observation space, got '
            '{}'.format(env.observation_space.dtype)
        )
        # Extend the bounds with the range of the time feature: it starts at
        # 0 and has no upper limit.
        low = np.append(self.observation_space.low, 0.0)
        high = np.append(self.observation_space.high, np.inf)
        self.observation_space = Box(low, high, dtype=np.float32)
        # Create the counter here so that step()/observation() do not fail on
        # an undefined attribute when called before the first reset().
        self.t = 0

    def observation(self, observation):
        """Return ``observation`` with the current time step appended.

        ``np.append`` is called without an axis, so the result is a flat
        array whose last element is ``self.t``. Appending a Python int to a
        float32 array promotes the result to float64; it is cast back to
        float32 so the returned array matches the dtype declared by
        ``self.observation_space``.
        """
        return np.append(observation, self.t).astype(np.float32)

    def step(self, action):
        """Advance the time counter by one, then step the wrapped env."""
        # Increment first so the observation produced by this step carries
        # the updated counter.
        self.t += 1
        return super(TimeAwareObservation, self).step(action)

    def reset(self, **kwargs):
        """Reset the time counter to 0, then reset the wrapped env.

        Keyword arguments are forwarded unchanged to the parent ``reset``.
        """
        self.t = 0
        return super(TimeAwareObservation, self).reset(**kwargs)