(To be discussed) (Idea) feature: multi dimensional reward #225

Open · wants to merge 13 commits into base: develop
6 changes: 0 additions & 6 deletions handyrl/environment.py
@@ -96,12 +96,6 @@ def terminal(self):
def reward(self):
return {}

#
# Should be defined in all games
#
def outcome(self):
raise NotImplementedError()

#
# Should be defined in all games
#
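Note on the base-class change above: with outcome() removed from BaseEnvironment, reward() becomes the single hook that reports per-player rewards, so callers that used to read outcome() once at the end of a game now accumulate reward() every step. A minimal sketch of that pattern, assuming an already-constructed env and a hypothetical choose_action helper:

    import numpy as np

    # Sketch only: accumulate reward() each step instead of reading outcome()
    # at the end; np.array(r) lets the same loop handle a scalar reward or a
    # multi-dimensional reward vector.
    totals = {p: 0 for p in env.players()}
    env.reset()
    while not env.terminal():
        env.play(choose_action(env))             # choose_action is hypothetical
        for p, r in env.reward().items():
            totals[p] = totals[p] + np.array(r)  # works for scalar or vector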
24 changes: 14 additions & 10 deletions handyrl/envs/geister.py
@@ -162,8 +162,9 @@ def forward(self, x, hidden):
h_p = torch.cat([h_p_move, h_p_set], -1)
h_v = self.head_v(h)
h_r = self.head_r(h)
h_v = torch.cat([torch.tanh(h_v), h_r], -1)

return {'policy': h_p, 'value': torch.tanh(h_v), 'return': h_r, 'hidden': hidden}
return {'policy': h_p, 'value': h_v, 'hidden': hidden}


class Environment(BaseEnvironment):
@@ -431,17 +432,17 @@ def terminal(self):
return self.win_color is not None

def reward(self):
# return immediate rewards
return {p: -0.01 for p in self.players()}
# immediate rewards
turn_rewards = [-0.01, -0.01]

def outcome(self):
# return terminal outcomes
outcomes = [0, 0]
# terminal reward
terminal_rewards = [0, 0]
if self.win_color == self.BLACK:
outcomes = [1, -1]
terminal_rewards = [1, -1]
elif self.win_color == self.WHITE:
outcomes = [-1, 1]
return {p: outcomes[idx] for idx, p in enumerate(self.players())}
terminal_rewards = [-1, 1]

return {p: [terminal_rewards[idx], turn_rewards[idx]] for idx, p in enumerate(self.players())}

def legal(self, action):
if self.turn_count < 0:
@@ -541,11 +542,14 @@ def net(self):
if __name__ == '__main__':
e = Environment()
for _ in range(100):
total_rewards = {}
e.reset()
while not e.terminal():
print(e)
actions = e.legal_actions()
print([e.action2str(a, e.turn()) for a in actions])
e.play(random.choice(actions))
for p, r in e.reward().items():
total_rewards[p] = total_rewards.get(p, 0) + np.array(r)
print(e)
print(e.outcome())
print(total_rewards)
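With this change, geister's reward() returns a two-component vector per player, roughly {0: [1, -0.01], 1: [-1, -0.01]} on the step where the first player wins, and the network's 'value' output grows to match: a tanh-bounded terminal-outcome estimate concatenated with an unbounded return estimate. A standalone sketch of that head (head_v/head_r mirror the names in the diff; the hidden size is an assumption, and this is not the actual GeisterNet):

    import torch
    import torch.nn as nn

    # Illustrative two-column value head: column 0 is the tanh-bounded
    # terminal-outcome estimate, column 1 the unbounded return estimate,
    # mirroring the torch.cat in the diff above.
    class TwoDimValueHead(nn.Module):
        def __init__(self, hidden_size=64):
            super().__init__()
            self.head_v = nn.Linear(hidden_size, 1)
            self.head_r = nn.Linear(hidden_size, 1)

        def forward(self, h):
            h_v = torch.tanh(self.head_v(h))   # bounded to [-1, 1]
            h_r = self.head_r(h)               # unbounded discounted return
            return torch.cat([h_v, h_r], -1)   # shape (..., 2)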
5 changes: 4 additions & 1 deletion handyrl/envs/parallel_tictactoe.py
@@ -61,6 +61,7 @@ def turns(self):
if __name__ == '__main__':
e = Environment()
for _ in range(100):
total_rewards = {}
e.reset()
while not e.terminal():
print(e)
@@ -70,5 +71,7 @@ def turns(self):
print([e.action2str(a) for a in actions])
action_map[p] = random.choice(actions)
e.step(action_map)
for p, r in e.reward().items():
total_rewards[p] = total_rewards.get(p, 0) + r
print(e)
print(e.outcome())
print(total_rewards)
16 changes: 9 additions & 7 deletions handyrl/envs/tictactoe.py
@@ -137,14 +137,13 @@ def terminal(self):
# check whether the state is terminal
return self.win_color != 0 or len(self.record) == 3 * 3

def outcome(self):
# terminal outcome
outcomes = [0, 0]
def reward(self):
rewards = [0, 0]
if self.win_color > 0:
outcomes = [1, -1]
rewards = [1, -1]
if self.win_color < 0:
outcomes = [-1, 1]
return {p: outcomes[idx] for idx, p in enumerate(self.players())}
rewards = [-1, 1]
return {p: rewards[idx] for idx, p in enumerate(self.players())}

def legal_actions(self, _=None):
# legal action list
@@ -171,11 +170,14 @@ def observation(self, player=None):
if __name__ == '__main__':
e = Environment()
for _ in range(100):
total_rewards = {}
e.reset()
while not e.terminal():
print(e)
actions = e.legal_actions()
print([e.action2str(a) for a in actions])
e.play(random.choice(actions))
for p, r in e.reward().items():
total_rewards[p] = total_rewards.get(p, 0) + r
print(e)
print(e.outcome())
print(total_rewards)
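Tic-tac-toe keeps a scalar reward (the old outcome folded into reward()), while geister now returns a vector; downstream code can treat both uniformly because NumPy promotes either form. A tiny sketch of the normalization the evaluation code below relies on:

    import numpy as np

    # reshape(-1) turns a scalar reward into a length-1 vector, so scalar and
    # multi-dimensional environments accumulate with the same code path.
    print(np.array(1).reshape(-1))           # [1]
    print(np.array([1, -0.01]).reshape(-1))  # [ 1.   -0.01]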
32 changes: 20 additions & 12 deletions handyrl/evaluation.py
@@ -7,6 +7,8 @@
import time
import multiprocessing as mp

import numpy as np

from .environment import prepare_env, make_env
from .connection import send_recv, accept_socket_connections, connect_socket_connection
from .agent import RandomAgent, RuleBasedAgent, Agent, EnsembleAgent, SoftAgent
@@ -43,8 +45,8 @@ def run(self):
break
if command == 'quit':
break
elif command == 'outcome':
print('outcome = %f' % args[0])
elif command == 'reward':
print('reward = %s' % args[0])
elif hasattr(self.agent, command):
if command == 'action' or command == 'observe':
view(self.env)
@@ -84,8 +86,10 @@ def exec_match(env, agents, critic=None, show=False, game_args={}):
''' match with shared game environment '''
if env.reset(game_args):
return None
for agent in agents.values():
total_rewards = {}
for p, agent in agents.items():
agent.reset(env, show=show)
total_rewards[p] = 0
while not env.terminal():
if show:
view(env)
@@ -103,19 +107,22 @@ def exec_match(env, agents, critic=None, show=False, game_args={}):
return None
if show:
view_transition(env)
outcome = env.outcome()
for p, reward in env.reward().items():
total_rewards[p] += np.array(reward).reshape(-1)
if show:
print('final outcome = %s' % outcome)
return outcome
print('total rewards = %s' % total_rewards)
return total_rewards


def exec_network_match(env, network_agents, critic=None, show=False, game_args={}):
''' match with divided game environment '''
if env.reset(game_args):
return None
total_rewards = {}
for p, agent in network_agents.items():
info = env.diff_info(p)
agent.update(info, True)
total_rewards[p] = 0
while not env.terminal():
if show:
view(env)
@@ -132,13 +139,14 @@ def exec_network_match(env, network_agents, critic=None, show=False, game_args={}):
agent.observe(p)
if env.step(actions):
return None
for p, reward in env.reward().items():
total_rewards[p] += np.array(reward).reshape(-1)
for p, agent in network_agents.items():
info = env.diff_info(p)
agent.update(info, False)
outcome = env.outcome()
for p, agent in network_agents.items():
agent.outcome(outcome[p])
return outcome
agent.reward(total_rewards[p])
return total_rewards


def build_agent(raw, env=None):
@@ -170,11 +178,11 @@ def execute(self, models, args):
else:
agents[p] = Agent(model)

outcome = exec_match(self.env, agents)
if outcome is None:
total_rewards = exec_match(self.env, agents)
if total_rewards is None:
print('None episode in evaluation!')
return None
return {'args': args, 'result': outcome, 'opponent': opponent}
return {'args': args, 'total_reward': total_rewards, 'opponent': opponent}


def wp_func(results):
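For clarity, this is the per-step aggregation exec_match and exec_network_match now perform, shown on an illustrative geister-style terminal step (the numbers are examples, not fixed values):

    import numpy as np

    # total_rewards starts at 0 per player and becomes a vector as soon as a
    # multi-dimensional reward arrives; reshape(-1) keeps scalar rewards compatible.
    total_rewards = {0: 0, 1: 0}
    step_reward = {0: [1, -0.01], 1: [-1, -0.01]}   # example terminal-step reward
    for p, reward in step_reward.items():
        total_rewards[p] += np.array(reward).reshape(-1)
    print(total_rewards)  # {0: array([ 1.  , -0.01]), 1: array([-1.  , -0.01])}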
7 changes: 5 additions & 2 deletions handyrl/generation.py
@@ -21,8 +21,10 @@ def generate(self, models, args):
# episode generation
moments = []
hidden = {}
total_rewards = {}
for player in self.env.players():
hidden[player] = models[player].init_hidden()
total_rewards[player] = 0

err = self.env.reset()
if err:
@@ -68,6 +70,7 @@
reward = self.env.reward()
for player in self.env.players():
moment['reward'][player] = reward.get(player, None)
total_rewards[player] += np.array(reward.get(player, 0)).reshape(-1)

moment['turn'] = turn_players
moments.append(moment)
@@ -78,12 +81,12 @@
for player in self.env.players():
ret = 0
for i, m in reversed(list(enumerate(moments))):
ret = (m['reward'][player] or 0) + self.args['gamma'] * ret
ret = np.array(m['reward'][player] or 0) + np.array(self.args['gamma']) * ret
moments[i]['return'][player] = ret

episode = {
'args': args, 'steps': len(moments),
'outcome': self.env.outcome(),
'total_reward': total_rewards,
'moment': [
bz2.compress(pickle.dumps(moments[i:i+self.args['compress_steps']]))
for i in range(0, len(moments), self.args['compress_steps'])
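The return computation above is now vectorized with np.array so it broadcasts over the reward dimensions, which would also allow gamma to be a per-dimension vector (for example, no discount on the terminal component and discounting only the per-turn component). A worked sketch with illustrative values, not the project's defaults:

    import numpy as np

    # Hypothetical 3-step episode with [terminal, per-turn] rewards and a
    # per-dimension gamma; values are illustrative only.
    gamma = np.array([1.0, 0.9])
    rewards = [[0, -0.01], [0, -0.01], [1, -0.01]]   # win arrives on the last step

    ret = 0
    returns = []
    for r in reversed(rewards):
        ret = np.array(r) + gamma * ret
        returns.append(ret)
    returns.reverse()
    print(returns[0])   # [ 1.     -0.0271] -- each dimension discounted separately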