diff --git a/handyrl/evaluation.py b/handyrl/evaluation.py index 2d391bc..45c4d22 100755 --- a/handyrl/evaluation.py +++ b/handyrl/evaluation.py @@ -106,7 +106,7 @@ def exec_match(env, agents, critic=None, show=False, game_args={}): outcome = env.outcome() if show: print('final outcome = %s' % outcome) - return outcome + return {'result': outcome} def exec_network_match(env, network_agents, critic=None, show=False, game_args={}): @@ -138,7 +138,7 @@ def exec_network_match(env, network_agents, critic=None, show=False, game_args={ outcome = env.outcome() for p, agent in network_agents.items(): agent.outcome(outcome[p]) - return outcome + return {'result': outcome} def build_agent(raw, env=None): @@ -170,11 +170,11 @@ def execute(self, models, args): else: agents[p] = Agent(model) - outcome = exec_match(self.env, agents) - if outcome is None: + results = exec_match(self.env, agents) + if results is None: print('None episode in evaluation!') return None - return {'args': args, 'result': outcome, 'opponent': opponent} + return {'args': args, 'opponent': opponent, **results} def wp_func(results): @@ -196,10 +196,10 @@ def eval_process_mp_child(agents, critic, env_args, index, in_queue, out_queue, print('*** Game %d ***' % g) agent_map = {env.players()[p]: agents[ai] for p, ai in enumerate(agent_ids)} if isinstance(list(agent_map.values())[0], NetworkAgent): - outcome = exec_network_match(env, agent_map, critic, show=show, game_args=game_args) + results = exec_network_match(env, agent_map, critic, show=show, game_args=game_args) else: - outcome = exec_match(env, agent_map, critic, show=show, game_args=game_args) - out_queue.put((pat_idx, agent_ids, outcome)) + results = exec_match(env, agent_map, critic, show=show, game_args=game_args) + out_queue.put((pat_idx, agent_ids, results)) out_queue.put(None) @@ -246,7 +246,8 @@ def evaluate_mp(env, agents, critic, env_args, args_patterns, num_process, num_g if ret is None: finished_cnt += 1 continue - pat_idx, agent_ids, outcome = ret + pat_idx, agent_ids, results = ret + outcome = results.get('result') if outcome is not None: for idx, p in enumerate(env.players()): agent_id = agent_ids[idx]