From 400a9c1163a5191d3d46b04112a365052e783aeb Mon Sep 17 00:00:00 2001
From: YuriCat
Date: Wed, 30 Nov 2022 23:45:06 +0900
Subject: [PATCH] feature: remove resolve_agent and use build_agent to build
 trained model agent

---
 handyrl/evaluation.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/handyrl/evaluation.py b/handyrl/evaluation.py
index 2d391bce..c83e2a17 100755
--- a/handyrl/evaluation.py
+++ b/handyrl/evaluation.py
@@ -147,6 +147,12 @@ def build_agent(raw, env=None):
     elif raw.startswith('rulebase'):
         key = raw.split('-')[1] if '-' in raw else None
         return RuleBasedAgent(key)
+
+    if env is not None:
+        model = load_model(raw, env.net())
+        agent = Agent(model)
+        return agent
+
     return None
 
 
@@ -383,14 +389,7 @@ def eval_main(args, argv):
     num_games = int(argv[1]) if len(argv) >= 2 else 100
     num_process = int(argv[2]) if len(argv) >= 3 else 1
 
-    def resolve_agent(model_path):
-        agent = build_agent(model_path, env)
-        if agent is None:
-            model = load_model(model_path, env.net())
-            agent = Agent(model)
-        return agent
-
-    main_agent = resolve_agent(model_paths[0])
+    main_agent = build_agent(model_paths[0], env)
     critic = None
 
     print('%d process, %d games' % (num_process, num_games))
@@ -399,7 +398,7 @@ def resolve_agent(model_path):
     print('seed = %d' % seed)
 
     opponent = model_paths[1] if len(model_paths) > 1 else 'random'
-    agents = [main_agent] + [resolve_agent(opponent) for _ in range(len(env.players()) - 1)]
+    agents = [main_agent] + [build_agent(opponent, env) for _ in range(len(env.players()) - 1)]
 
     evaluate_mp(env, agents, critic, env_args, {'default': {}}, num_process, num_games, seed)
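
Note (not part of the patch): a minimal sketch of how agent resolution works through the
consolidated build_agent after this change. It assumes the RandomAgent, RuleBasedAgent,
Agent, and load_model helpers already present in handyrl/evaluation.py; the model path
and the 'simple' rulebase key below are hypothetical placeholders.

    # Sketch only: post-patch resolution flow of build_agent.
    from handyrl.evaluation import build_agent

    # Named agents resolve without an environment.
    agent = build_agent('random')            # -> RandomAgent()
    agent = build_agent('rulebase-simple')   # -> RuleBasedAgent('simple')

    # Any other string is treated as a model path; env must be passed so that
    # env.net() can supply the network architecture for load_model().
    # agent = build_agent('models/latest.pth', env)  # -> Agent(load_model(...))

    # With env omitted, an unrecognized string still yields None.
    assert build_agent('models/latest.pth') is None

Because the model-loading fallback now lives inside build_agent itself, eval_main no
longer needs its local resolve_agent wrapper and calls build_agent directly.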