diff --git a/design_bench/oracles/exact/ant_morphology_oracle.py b/design_bench/oracles/exact/ant_morphology_oracle.py index fb08156..5de6dee 100644 --- a/design_bench/oracles/exact/ant_morphology_oracle.py +++ b/design_bench/oracles/exact/ant_morphology_oracle.py @@ -125,10 +125,15 @@ def protected_predict(self, x): # create a policy forward pass in numpy def mlp_policy(h): - h = np.maximum(0.0, self.policy[0] @ h + self.policy[1]) - h = np.maximum(0.0, self.policy[2] @ h + self.policy[3]) - return np.tanh(np.split( - self.policy[4] @ h + self.policy[5], 2)[0]) + h = np.maximum(0.0, self.policy[ + "base_network.network.0.weight"] @ h + self.policy[ + "base_network.network.0.bias"]) + h = np.maximum(0.0, self.policy[ + "base_network.network.2.weight"] @ h + self.policy[ + "base_network.network.2.bias"]) + return np.tanh(np.split(self.policy[ + "base_network.network.4.weight"] @ h + self.policy[ + "base_network.network.4.bias"], 2)[0]) # convert vectors to morphologies env = MorphingAntEnv(expose_design=True, diff --git a/design_bench/oracles/exact/dkitty_morphology_oracle.py b/design_bench/oracles/exact/dkitty_morphology_oracle.py index c6f2515..73efe28 100644 --- a/design_bench/oracles/exact/dkitty_morphology_oracle.py +++ b/design_bench/oracles/exact/dkitty_morphology_oracle.py @@ -125,10 +125,15 @@ def protected_predict(self, x): # create a policy forward pass in numpy def mlp_policy(h): - h = np.maximum(0.0, self.policy[0] @ h + self.policy[1]) - h = np.maximum(0.0, self.policy[2] @ h + self.policy[3]) - return np.tanh(np.split( - self.policy[4] @ h + self.policy[5], 2)[0]) + h = np.maximum(0.0, self.policy[ + "base_network.network.0.weight"] @ h + self.policy[ + "base_network.network.0.bias"]) + h = np.maximum(0.0, self.policy[ + "base_network.network.2.weight"] @ h + self.policy[ + "base_network.network.2.bias"]) + return np.tanh(np.split(self.policy[ + "base_network.network.4.weight"] @ h + self.policy[ + "base_network.network.4.bias"], 2)[0]) # convert vectors to morphologies env = MorphingDKittyEnv(expose_design=True,