Skip to content

Commit

Permalink
Fix problem with variable space sizes/bounds in KellyCoinflipGenerali…
Browse files Browse the repository at this point in the history
…zedEnv (#2020)

* Clip/resample game parameters to be able to fix obs/action space sizes/bounds

* Make KellyCoinflip distribution clipping optional
  • Loading branch information
AlexKuhnle committed Sep 11, 2020
1 parent 811ad53 commit 8cf2685
Showing 1 changed file with 31 additions and 12 deletions.
43 changes: 31 additions & 12 deletions gym/envs/toy_text/kellycoinflip.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from scipy.stats import genpareto
from scipy.stats import genpareto, norm
import numpy as np

import gym
Expand Down Expand Up @@ -110,7 +110,9 @@ class KellyCoinflipGeneralizedEnv(gym.Env):

def __init__(self, initial_wealth=25.0, edge_prior_alpha=7, edge_prior_beta=3,
max_wealth_alpha=5.0, max_wealth_m=200.0, max_rounds_mean=300.0,
max_rounds_sd=25.0, reseed=True):
max_rounds_sd=25.0, reseed=True, clip_distributions=False):
# clip_distributions=True asserts that state and action space are not modified at reset()

# store the hyper-parameters for passing back into __init__() during resets so
# the same hyper-parameters govern the next game's parameters, as the user
# expects:
Expand All @@ -122,15 +124,31 @@ def __init__(self, initial_wealth=25.0, edge_prior_alpha=7, edge_prior_beta=3,
self.max_wealth_m = max_wealth_m
self.max_rounds_mean = max_rounds_mean
self.max_rounds_sd = max_rounds_sd
self.clip_distributions = clip_distributions

if reseed or not hasattr(self, 'np_random'):
self.seed()

# draw this game's set of parameters:
edge = self.np_random.beta(edge_prior_alpha, edge_prior_beta)
max_wealth = round(genpareto.rvs(max_wealth_alpha, max_wealth_m,
random_state=self.np_random))
max_rounds = int(round(self.np_random.normal(max_rounds_mean, max_rounds_sd)))
if self.clip_distributions:
# (clip/resample some parameters to be able to fix obs/action space sizes/bounds)
max_wealth_bound = round(genpareto.ppf(0.85, max_wealth_alpha, max_wealth_m))
max_wealth = max_wealth_bound + 1.0
while max_wealth > max_wealth_bound:
max_wealth = round(genpareto.rvs(max_wealth_alpha, max_wealth_m,
random_state=self.np_random))
max_rounds_bound = int(round(norm.ppf(0.99, max_rounds_mean, max_rounds_sd)))
max_rounds = max_rounds_bound + 1
while max_rounds > max_rounds_bound:
max_rounds = int(round(self.np_random.normal(max_rounds_mean, max_rounds_sd)))

else:
max_wealth = round(genpareto.rvs(max_wealth_alpha, max_wealth_m,
random_state=self.np_random))
max_wealth_bound = max_wealth
max_rounds = int(round(self.np_random.normal(max_rounds_mean, max_rounds_sd)))
max_rounds_bound = max_rounds

# add an additional global variable which is the sufficient statistic for the
# Pareto distribution on wealth cap; alpha doesn't update, but x_m does, and
Expand All @@ -143,13 +161,13 @@ def __init__(self, initial_wealth=25.0, edge_prior_alpha=7, edge_prior_beta=3,
self.rounds_elapsed = 0

# the rest proceeds as before:
self.action_space = spaces.Discrete(int(max_wealth*100))
self.action_space = spaces.Discrete(int(max_wealth_bound*100))
self.observation_space = spaces.Tuple((
spaces.Box(0, max_wealth, shape=[1], dtype=np.float32), # current wealth
spaces.Discrete(max_rounds+1), # rounds elapsed
spaces.Discrete(max_rounds+1), # wins
spaces.Discrete(max_rounds+1), # losses
spaces.Box(0, max_wealth, [1], dtype=np.float32))) # maximum observed wealth
spaces.Box(0, max_wealth_bound, shape=[1], dtype=np.float32), # current wealth
spaces.Discrete(max_rounds_bound+1), # rounds elapsed
spaces.Discrete(max_rounds_bound+1), # wins
spaces.Discrete(max_rounds_bound+1), # losses
spaces.Box(0, max_wealth_bound, [1], dtype=np.float32))) # maximum observed wealth
self.reward_range = (0, max_wealth)
self.edge = edge
self.wealth = self.initial_wealth
Expand Down Expand Up @@ -195,7 +213,8 @@ def reset(self):
max_wealth_m=self.max_wealth_m,
max_rounds_mean=self.max_rounds_mean,
max_rounds_sd=self.max_rounds_sd,
reseed=False)
reseed=False,
clip_distributions=self.clip_distributions)
return self._get_obs()

def render(self, mode='human'):
Expand Down

0 comments on commit 8cf2685

Please sign in to comment.