Skip to content

Commit

Permalink
Flattened space/point dtype mismatch (#2070)
Browse files Browse the repository at this point in the history
* add test showing mismatch in flattened space dtype and flattened point dtype

* fix mismatch in flattened space dtype and flattened point dtype

* fix typo

* enhance test to detect when flattened dtype is incorrect

* fix incorrect flattened dtype

* remove inaccurate comment

* change flatten to always use space.dtype

* added testing for unflattened dtypes

* fix unflatten dtypes

* swtich flatten_space to use space.dtype for hardcoded space dtypes

* fix failure in python 3.5
  • Loading branch information
wmmc88 committed Nov 6, 2020
1 parent 28c42b6 commit eee9b28
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 13 deletions.
56 changes: 54 additions & 2 deletions gym/spaces/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from collections import OrderedDict

import numpy as np
import pytest

from gym.spaces import utils
from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple, utils


@pytest.mark.parametrize(["space", "flatdim"], [
Expand Down Expand Up @@ -118,3 +118,55 @@ def compare_nested(left, right):
return res
else:
return left == right

'''
Expecteded flattened types are based off:
1. The type that the space is hardcoded as(ie. multi_discrete=np.int64, discrete=np.int64, multi_binary=np.int8)
2. The type that the space is instantiated with(ie. box=np.float32 by default unless instantiated with a different type)
3. The smallest type that the composite space(tuple, dict) can be represented as. In flatten, this is determined
internally by numpy when np.concatenate is called.
'''
@pytest.mark.parametrize(["original_space", "expected_flattened_dtype"], [
(Discrete(3), np.int64),
(Box(low=0., high=np.inf, shape=(2, 2)), np.float32),
(Box(low=0., high=np.inf, shape=(2, 2), dtype=np.float16), np.float16),
(Tuple([Discrete(5), Discrete(10)]), np.int64),
(Tuple([Discrete(5), Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)]), np.float64),
(Tuple((Discrete(5), Discrete(2), Discrete(2))), np.int64),
(MultiDiscrete([2, 2, 100]), np.int64),
(MultiBinary(10), np.int8),
(Dict({"position": Discrete(5),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float16)}), np.float64),
])
def test_dtypes(original_space, expected_flattened_dtype):
flattened_space = utils.flatten_space(original_space)

original_sample = original_space.sample()
flattened_sample = utils.flatten(original_space, original_sample)
unflattened_sample = utils.unflatten(original_space, flattened_sample)

assert flattened_space.contains(flattened_sample), "Expected flattened_space to contain flattened_sample"
assert flattened_space.dtype == expected_flattened_dtype, "Expected flattened_space's dtype to equal " \
"{}".format(expected_flattened_dtype)

assert flattened_sample.dtype == flattened_space.dtype, "Expected flattened_space's dtype to equal " \
"flattened_sample's dtype "

compare_sample_types(original_space, original_sample, unflattened_sample)


def compare_sample_types(original_space, original_sample, unflattened_sample):
if isinstance(original_space, Discrete):
assert isinstance(unflattened_sample, int), "Expected unflattened_sample to be an int. unflattened_sample: " \
"{} original_sample: {}".format(unflattened_sample, original_sample)
elif isinstance(original_space, Tuple):
for index in range(len(original_space)):
compare_sample_types(original_space.spaces[index], original_sample[index], unflattened_sample[index])
elif isinstance(original_space, Dict):
for key, space in original_space.spaces.items():
compare_sample_types(space, original_sample[key], unflattened_sample[key])
else:
assert unflattened_sample.dtype == original_sample.dtype, "Expected unflattened_sample's dtype to equal " \
"original_sample's dtype. unflattened_sample: " \
"{} original_sample: {}".format(unflattened_sample,
original_sample)
29 changes: 18 additions & 11 deletions gym/spaces/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ def flatten(space, x):
``gym.spaces``.
"""
if isinstance(space, Box):
return np.asarray(x, dtype=np.float32).flatten()
return np.asarray(x, dtype=space.dtype).flatten()
elif isinstance(space, Discrete):
onehot = np.zeros(space.n, dtype=np.float32)
onehot[x] = 1.0
onehot = np.zeros(space.n, dtype=space.dtype)
onehot[x] = 1
return onehot
elif isinstance(space, Tuple):
return np.concatenate(
Expand All @@ -55,9 +55,9 @@ def flatten(space, x):
return np.concatenate(
[flatten(s, x[key]) for key, s in space.spaces.items()])
elif isinstance(space, MultiBinary):
return np.asarray(x).flatten()
return np.asarray(x, dtype=space.dtype).flatten()
elif isinstance(space, MultiDiscrete):
return np.asarray(x).flatten()
return np.asarray(x, dtype=space.dtype).flatten()
else:
raise NotImplementedError

Expand All @@ -73,7 +73,7 @@ def unflatten(space, x):
defined in ``gym.spaces``.
"""
if isinstance(space, Box):
return np.asarray(x, dtype=np.float32).reshape(space.shape)
return np.asarray(x, dtype=space.dtype).reshape(space.shape)
elif isinstance(space, Discrete):
return int(np.nonzero(x)[0][0])
elif isinstance(space, Tuple):
Expand All @@ -94,9 +94,9 @@ def unflatten(space, x):
]
return OrderedDict(list_unflattened)
elif isinstance(space, MultiBinary):
return np.asarray(x).reshape(space.shape)
return np.asarray(x, dtype=space.dtype).reshape(space.shape)
elif isinstance(space, MultiDiscrete):
return np.asarray(x).reshape(space.shape)
return np.asarray(x, dtype=space.dtype).reshape(space.shape)
else:
raise NotImplementedError

Expand Down Expand Up @@ -140,26 +140,33 @@ def flatten_space(space):
True
"""
if isinstance(space, Box):
return Box(space.low.flatten(), space.high.flatten())
return Box(space.low.flatten(), space.high.flatten(), dtype=space.dtype)
if isinstance(space, Discrete):
return Box(low=0, high=1, shape=(space.n, ))
return Box(low=0, high=1, shape=(space.n, ), dtype=space.dtype)
if isinstance(space, Tuple):
space = [flatten_space(s) for s in space.spaces]
return Box(
low=np.concatenate([s.low for s in space]),
high=np.concatenate([s.high for s in space]),
dtype=np.result_type(*[s.dtype for s in space])
)
if isinstance(space, Dict):
space = [flatten_space(s) for s in space.spaces.values()]
return Box(
low=np.concatenate([s.low for s in space]),
high=np.concatenate([s.high for s in space]),
dtype=np.result_type(*[s.dtype for s in space])
)
if isinstance(space, MultiBinary):
return Box(low=0, high=1, shape=(space.n, ))
return Box(low=0,
high=1,
shape=(space.n, ),
dtype=space.dtype
)
if isinstance(space, MultiDiscrete):
return Box(
low=np.zeros_like(space.nvec),
high=space.nvec,
dtype=space.dtype
)
raise NotImplementedError

0 comments on commit eee9b28

Please sign in to comment.