Skip to content

Commit

Permalink
update keys for policy dict
Browse files Browse the repository at this point in the history
  • Loading branch information
brandontrabucco committed Jun 7, 2021
1 parent 075f2da commit cfcab39
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
13 changes: 9 additions & 4 deletions design_bench/oracles/exact/ant_morphology_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,15 @@ def protected_predict(self, x):

# create a policy forward pass in numpy
def mlp_policy(h):
h = np.maximum(0.0, self.policy[0] @ h + self.policy[1])
h = np.maximum(0.0, self.policy[2] @ h + self.policy[3])
return np.tanh(np.split(
self.policy[4] @ h + self.policy[5], 2)[0])
h = np.maximum(0.0, self.policy[
"base_network.network.0.weight"] @ h + self.policy[
"base_network.network.0.bias"])
h = np.maximum(0.0, self.policy[
"base_network.network.2.weight"] @ h + self.policy[
"base_network.network.2.bias"])
return np.tanh(np.split(self.policy[
"base_network.network.4.weight"] @ h + self.policy[
"base_network.network.4.bias"], 2)[0])

# convert vectors to morphologies
env = MorphingAntEnv(expose_design=True,
Expand Down
13 changes: 9 additions & 4 deletions design_bench/oracles/exact/dkitty_morphology_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,15 @@ def protected_predict(self, x):

# create a policy forward pass in numpy
def mlp_policy(h):
h = np.maximum(0.0, self.policy[0] @ h + self.policy[1])
h = np.maximum(0.0, self.policy[2] @ h + self.policy[3])
return np.tanh(np.split(
self.policy[4] @ h + self.policy[5], 2)[0])
h = np.maximum(0.0, self.policy[
"base_network.network.0.weight"] @ h + self.policy[
"base_network.network.0.bias"])
h = np.maximum(0.0, self.policy[
"base_network.network.2.weight"] @ h + self.policy[
"base_network.network.2.bias"])
return np.tanh(np.split(self.policy[
"base_network.network.4.weight"] @ h + self.policy[
"base_network.network.4.bias"], 2)[0])

# convert vectors to morphologies
env = MorphingDKittyEnv(expose_design=True,
Expand Down

0 comments on commit cfcab39

Please sign in to comment.