Skip to content

Commit

Permalink
[Wrappers]: TimeAwareObservation (#1490)
Browse files Browse the repository at this point in the history
* Create time_aware_observation.py

* Update __init__.py

* Create test_time_aware_observation.py

* Update time_aware_observation.py

* Update time_aware_observation.py

* Update time_aware_observation.py

Co-authored-by: pzhokhov <peterz@openai.com>
  • Loading branch information
zuoxingdong and pzhokhov committed Nov 5, 2020
1 parent 57968ca commit 28c42b6
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 0 deletions.
1 change: 1 addition & 0 deletions gym/wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from gym.wrappers.time_limit import TimeLimit
from gym.wrappers.filter_observation import FilterObservation
from gym.wrappers.atari_preprocessing import AtariPreprocessing
from gym.wrappers.time_aware_observation import TimeAwareObservation
from gym.wrappers.rescale_action import RescaleAction
from gym.wrappers.flatten_observation import FlattenObservation
from gym.wrappers.gray_scale_observation import GrayScaleObservation
Expand Down
33 changes: 33 additions & 0 deletions gym/wrappers/test_time_aware_observation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest

import gym
from gym.wrappers import TimeAwareObservation


@pytest.mark.parametrize('env_id', ['CartPole-v1', 'Pendulum-v0'])
def test_time_aware_observation(env_id):
    """Check that the wrapper appends a step counter to each observation.

    The counter must read 0 after ``reset``, grow by one per ``step``, and
    return to 0 on the next ``reset``; the observation must always be one
    element longer than the unwrapped one.
    """
    base_env = gym.make(env_id)
    timed_env = TimeAwareObservation(base_env)

    # The wrapped space gains exactly one extra dimension for the time step.
    assert timed_env.observation_space.shape[0] == base_env.observation_space.shape[0] + 1

    base_obs = base_env.reset()

    def check(timed_obs, expected_t):
        # Same three checks, in the same order, after every reset/step.
        assert timed_env.t == expected_t
        assert timed_obs[-1] == expected_t
        assert timed_obs.shape[0] == base_obs.shape[0] + 1

    # Fresh episode: counter starts at zero.
    check(timed_env.reset(), 0.0)

    # Two consecutive steps: counter advances by one each time.
    for expected_t in (1.0, 2.0):
        timed_obs, _, _, _ = timed_env.step(base_env.action_space.sample())
        check(timed_obs, expected_t)

    # Resetting starts a new episode and rewinds the counter.
    check(timed_env.reset(), 0.0)
32 changes: 32 additions & 0 deletions gym/wrappers/time_aware_observation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import numpy as np

from gym.spaces import Box
from gym import ObservationWrapper


class TimeAwareObservation(ObservationWrapper):
    r"""Augment the observation with the current time step in the trajectory.

    The number of steps taken since the last ``reset`` is appended as the
    final element of every observation, and the observation space is
    extended by one dimension with bounds ``[0, inf)`` to hold it.

    Attributes:
        t: Number of ``step`` calls since the last ``reset`` (0 right after
            construction or ``reset``).

    .. note::
        Currently it only works with one-dimensional observation space. It doesn't
        support pixel observation space yet. ``np.append`` flattens its inputs,
        so a multi-dimensional observation would come back flattened.
    """

    def __init__(self, env):
        """Wrap ``env`` and extend its observation space by one time dimension.

        Args:
            env: Environment whose ``observation_space`` is a ``Box`` with
                dtype ``np.float32``.

        Raises:
            AssertionError: If the observation space is not a float32 ``Box``.
        """
        super(TimeAwareObservation, self).__init__(env)
        assert isinstance(env.observation_space, Box), (
            "TimeAwareObservation requires a Box observation space, got {}".format(
                type(env.observation_space).__name__))
        assert env.observation_space.dtype == np.float32, (
            "TimeAwareObservation requires a float32 observation space, got {}".format(
                env.observation_space.dtype))
        # np.append promotes float32 + Python float to float64; cast back so
        # the bounds already have the dtype the Box is created with.
        low = np.append(self.observation_space.low, 0.0).astype(np.float32)
        high = np.append(self.observation_space.high, np.inf).astype(np.float32)
        self.observation_space = Box(low, high, dtype=np.float32)
        # Define the counter up front so step()/observation() do not fail on
        # a missing attribute when called before the first reset().
        self.t = 0

    def observation(self, observation):
        """Return ``observation`` with the current time step appended.

        The result is cast to ``np.float32`` so it matches the dtype declared
        by ``self.observation_space`` (appending the integer counter would
        otherwise promote the array to float64).
        """
        return np.append(observation, self.t).astype(np.float32, copy=False)

    def step(self, action):
        """Advance the counter, then step the wrapped environment.

        The counter is incremented before delegating so that the observation
        produced for this step is built with the updated count.
        """
        self.t += 1
        return super(TimeAwareObservation, self).step(action)

    def reset(self, **kwargs):
        """Rewind the counter to zero and reset the wrapped environment."""
        self.t = 0
        return super(TimeAwareObservation, self).reset(**kwargs)

0 comments on commit 28c42b6

Please sign in to comment.