Skip to content

Commit

Permalink
adding render argument
Browse files Browse the repository at this point in the history
  • Loading branch information
brandontrabucco committed Jun 7, 2021
1 parent cfcab39 commit a383866
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 5 deletions.
4 changes: 3 additions & 1 deletion design_bench/oracles/exact/ant_morphology_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def is_simulated(cls):

return True

- def protected_predict(self, x):
+ def protected_predict(self, x, render=False, **render_kwargs):
"""Score function to be implemented by oracle subclasses, where x is
either a batch of designs if self.is_batched is True or is a
single design when self._is_batched is False
Expand Down Expand Up @@ -147,6 +147,8 @@ def mlp_policy(h):
sum_of_rewards = np.zeros([1], dtype=np.float32)
for t in range(self.rollout_horizon):
obs, rew, done, info = env.step(mlp_policy(obs))
+ if render:
+ env.render(**render_kwargs)
sum_of_rewards += rew.astype(np.float32)
if done:
break
Expand Down
4 changes: 3 additions & 1 deletion design_bench/oracles/exact/dkitty_morphology_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def is_simulated(cls):

return True

- def protected_predict(self, x):
+ def protected_predict(self, x, render=False, **render_kwargs):
"""Score function to be implemented by oracle subclasses, where x is
either a batch of designs if self.is_batched is True or is a
single design when self._is_batched is False
Expand Down Expand Up @@ -147,6 +147,8 @@ def mlp_policy(h):
sum_of_rewards = np.zeros([1], dtype=np.float32)
for t in range(self.rollout_horizon):
obs, rew, done, info = env.step(mlp_policy(obs))
+ if render:
+ env.render(**render_kwargs)
sum_of_rewards += rew.astype(np.float32)
if done:
break
Expand Down
4 changes: 3 additions & 1 deletion design_bench/oracles/exact/hopper_controller_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def is_simulated(cls):

return True

- def protected_predict(self, x):
+ def protected_predict(self, x, render=False, **render_kwargs):
"""Score function to be implemented by oracle subclasses, where x is
either a batch of designs if self.is_batched is True or is a
single design when self._is_batched is False
Expand Down Expand Up @@ -148,6 +148,8 @@ def mlp_policy(h):
path_returns = np.zeros([1], dtype=np.float32)
while not done:
obs, rew, done, info = env.step(mlp_policy(obs))
+ if render:
+ env.render(**render_kwargs)
path_returns += rew.astype(np.float32)

# return the sum of rewards for a single trajectory
Expand Down
4 changes: 2 additions & 2 deletions design_bench/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ def to_logits(self, x):
raise ValueError("only supported on discrete datasets")
return self.dataset.to_logits(x)

- def predict(self, x_batch):
+ def predict(self, x_batch, **kwargs):
"""a function that accepts a batch of design values 'x' as input and
for each design computes a prediction value 'y' which corresponds
to the score in a model-based optimization problem
Expand All @@ -816,7 +816,7 @@ def predict(self, x_batch):
"""

- return self.oracle.predict(x_batch)
+ return self.oracle.predict(x_batch, **kwargs)

def oracle_to_dataset_x(self, x_batch):
"""Helper function for converting from designs in the format of the
Expand Down

0 comments on commit a383866

Please sign in to comment.