-
Notifications
You must be signed in to change notification settings - Fork 8.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Wrappers]: TimeAwareObservation (#1490)
* Create time_aware_observation.py * Update __init__.py * Create test_time_aware_observation.py * Update time_aware_observation.py * Update time_aware_observation.py * Update time_aware_observation.py Co-authored-by: pzhokhov <peterz@openai.com>
- Loading branch information
1 parent
57968ca
commit 28c42b6
Showing
3 changed files
with
66 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import pytest | ||
|
||
import gym | ||
from gym.wrappers import TimeAwareObservation | ||
|
||
|
||
@pytest.mark.parametrize('env_id', ['CartPole-v1', 'Pendulum-v0']) | ||
def test_time_aware_observation(env_id): | ||
env = gym.make(env_id) | ||
wrapped_env = TimeAwareObservation(env) | ||
|
||
assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1 | ||
|
||
obs = env.reset() | ||
wrapped_obs = wrapped_env.reset() | ||
assert wrapped_env.t == 0.0 | ||
assert wrapped_obs[-1] == 0.0 | ||
assert wrapped_obs.shape[0] == obs.shape[0] + 1 | ||
|
||
wrapped_obs, _, _, _ = wrapped_env.step(env.action_space.sample()) | ||
assert wrapped_env.t == 1.0 | ||
assert wrapped_obs[-1] == 1.0 | ||
assert wrapped_obs.shape[0] == obs.shape[0] + 1 | ||
|
||
wrapped_obs, _, _, _ = wrapped_env.step(env.action_space.sample()) | ||
assert wrapped_env.t == 2.0 | ||
assert wrapped_obs[-1] == 2.0 | ||
assert wrapped_obs.shape[0] == obs.shape[0] + 1 | ||
|
||
wrapped_obs = wrapped_env.reset() | ||
assert wrapped_env.t == 0.0 | ||
assert wrapped_obs[-1] == 0.0 | ||
assert wrapped_obs.shape[0] == obs.shape[0] + 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import numpy as np | ||
|
||
from gym.spaces import Box | ||
from gym import ObservationWrapper | ||
|
||
|
||
class TimeAwareObservation(ObservationWrapper): | ||
r"""Augment the observation with current time step in the trajectory. | ||
.. note:: | ||
Currently it only works with one-dimensional observation space. It doesn't | ||
support pixel observation space yet. | ||
""" | ||
def __init__(self, env): | ||
super(TimeAwareObservation, self).__init__(env) | ||
assert isinstance(env.observation_space, Box) | ||
assert env.observation_space.dtype == np.float32 | ||
low = np.append(self.observation_space.low, 0.0) | ||
high = np.append(self.observation_space.high, np.inf) | ||
self.observation_space = Box(low, high, dtype=np.float32) | ||
|
||
def observation(self, observation): | ||
return np.append(observation, self.t) | ||
|
||
def step(self, action): | ||
self.t += 1 | ||
return super(TimeAwareObservation, self).step(action) | ||
|
||
def reset(self, **kwargs): | ||
self.t = 0 | ||
return super(TimeAwareObservation, self).reset(**kwargs) |