diff --git a/design_bench/oracles/exact/ant_morphology_oracle.py b/design_bench/oracles/exact/ant_morphology_oracle.py index 5de6dee..8af6458 100644 --- a/design_bench/oracles/exact/ant_morphology_oracle.py +++ b/design_bench/oracles/exact/ant_morphology_oracle.py @@ -102,7 +102,7 @@ def is_simulated(cls): return True - def protected_predict(self, x): + def protected_predict(self, x, render=False, **render_kwargs): """Score function to be implemented by oracle subclasses, where x is either a batch of designs if self.is_batched is True or is a single design when self._is_batched is False @@ -147,6 +147,8 @@ def mlp_policy(h): sum_of_rewards = np.zeros([1], dtype=np.float32) for t in range(self.rollout_horizon): obs, rew, done, info = env.step(mlp_policy(obs)) + if render: + env.render(**render_kwargs) sum_of_rewards += rew.astype(np.float32) if done: break diff --git a/design_bench/oracles/exact/dkitty_morphology_oracle.py b/design_bench/oracles/exact/dkitty_morphology_oracle.py index 73efe28..d42b8ad 100644 --- a/design_bench/oracles/exact/dkitty_morphology_oracle.py +++ b/design_bench/oracles/exact/dkitty_morphology_oracle.py @@ -102,7 +102,7 @@ def is_simulated(cls): return True - def protected_predict(self, x): + def protected_predict(self, x, render=False, **render_kwargs): """Score function to be implemented by oracle subclasses, where x is either a batch of designs if self.is_batched is True or is a single design when self._is_batched is False @@ -147,6 +147,8 @@ def mlp_policy(h): sum_of_rewards = np.zeros([1], dtype=np.float32) for t in range(self.rollout_horizon): obs, rew, done, info = env.step(mlp_policy(obs)) + if render: + env.render(**render_kwargs) sum_of_rewards += rew.astype(np.float32) if done: break diff --git a/design_bench/oracles/exact/hopper_controller_oracle.py b/design_bench/oracles/exact/hopper_controller_oracle.py index de5dcd7..c34aed8 100644 --- a/design_bench/oracles/exact/hopper_controller_oracle.py +++ b/design_bench/oracles/exact/hopper_controller_oracle.py @@ -97,7 +97,7 @@ def is_simulated(cls): return True - def protected_predict(self, x): + def protected_predict(self, x, render=False, **render_kwargs): """Score function to be implemented by oracle subclasses, where x is either a batch of designs if self.is_batched is True or is a single design when self._is_batched is False @@ -148,6 +148,8 @@ def mlp_policy(h): path_returns = np.zeros([1], dtype=np.float32) while not done: obs, rew, done, info = env.step(mlp_policy(obs)) + if render: + env.render(**render_kwargs) path_returns += rew.astype(np.float32) # return the sum of rewards for a single trajectory diff --git a/design_bench/task.py b/design_bench/task.py index 373fa98..d2561d2 100644 --- a/design_bench/task.py +++ b/design_bench/task.py @@ -795,7 +795,7 @@ def to_logits(self, x): raise ValueError("only supported on discrete datasets") return self.dataset.to_logits(x) - def predict(self, x_batch): + def predict(self, x_batch, **kwargs): """a function that accepts a batch of design values 'x' as input and for each design computes a prediction value 'y' which corresponds to the score in a model-based optimization problem @@ -816,7 +816,7 @@ def predict(self, x_batch): """ - return self.oracle.predict(x_batch) + return self.oracle.predict(x_batch, **kwargs) def oracle_to_dataset_x(self, x_batch): """Helper function for converting from designs in the format of the