From 875313dea181bbf563c292acac7ae893b781fa2f Mon Sep 17 00:00:00 2001 From: YuriCat Date: Thu, 22 Dec 2022 22:56:53 +0900 Subject: [PATCH 1/2] feature: return dict from evaluation function (same key) --- handyrl/evaluation.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/handyrl/evaluation.py b/handyrl/evaluation.py index 2d391bce..4ffe0192 100755 --- a/handyrl/evaluation.py +++ b/handyrl/evaluation.py @@ -106,7 +106,7 @@ def exec_match(env, agents, critic=None, show=False, game_args={}): outcome = env.outcome() if show: print('final outcome = %s' % outcome) - return outcome + return {'result': outcome} def exec_network_match(env, network_agents, critic=None, show=False, game_args={}): @@ -138,7 +138,7 @@ def exec_network_match(env, network_agents, critic=None, show=False, game_args={ outcome = env.outcome() for p, agent in network_agents.items(): agent.outcome(outcome[p]) - return outcome + return {'result': outcome} def build_agent(raw, env=None): @@ -170,11 +170,11 @@ def execute(self, models, args): else: agents[p] = Agent(model) - outcome = exec_match(self.env, agents) - if outcome is None: + results = exec_match(self.env, agents) + if results is None: print('None episode in evaluation!') return None - return {'args': args, 'result': outcome, 'opponent': opponent} + return {'args': args, 'opponent': opponent, **results} def wp_func(results): @@ -196,10 +196,10 @@ def eval_process_mp_child(agents, critic, env_args, index, in_queue, out_queue, print('*** Game %d ***' % g) agent_map = {env.players()[p]: agents[ai] for p, ai in enumerate(agent_ids)} if isinstance(list(agent_map.values())[0], NetworkAgent): - outcome = exec_network_match(env, agent_map, critic, show=show, game_args=game_args) + results = exec_network_match(env, agent_map, critic, show=show, game_args=game_args) else: - outcome = exec_match(env, agent_map, critic, show=show, game_args=game_args) - out_queue.put((pat_idx, agent_ids, outcome)) + results = exec_match(env, agent_map, critic, show=show, game_args=game_args) + out_queue.put((pat_idx, agent_ids, results)) out_queue.put(None) @@ -246,7 +246,8 @@ def evaluate_mp(env, agents, critic, env_args, args_patterns, num_process, num_g if ret is None: finished_cnt += 1 continue - pat_idx, agent_ids, outcome = ret + pat_idx, agent_ids, results = ret + outcome = results.get('outcome') if outcome is not None: for idx, p in enumerate(env.players()): agent_id = agent_ids[idx] From 50dcefdb020f607a3e1ec5a128ed98a21294ad65 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Thu, 22 Dec 2022 23:46:02 +0900 Subject: [PATCH 2/2] fix: output dict key outcome -> result --- handyrl/evaluation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handyrl/evaluation.py b/handyrl/evaluation.py index 4ffe0192..45c4d225 100755 --- a/handyrl/evaluation.py +++ b/handyrl/evaluation.py @@ -247,7 +247,7 @@ def evaluate_mp(env, agents, critic, env_args, args_patterns, num_process, num_g finished_cnt += 1 continue pat_idx, agent_ids, results = ret - outcome = results.get('outcome') + outcome = results.get('result') if outcome is not None: for idx, p in enumerate(env.players()): agent_id = agent_ids[idx]