diff --git a/handyrl/evaluation.py b/handyrl/evaluation.py index 2d391bce..c83e2a17 100755 --- a/handyrl/evaluation.py +++ b/handyrl/evaluation.py @@ -147,6 +147,12 @@ def build_agent(raw, env=None): elif raw.startswith('rulebase'): key = raw.split('-')[1] if '-' in raw else None return RuleBasedAgent(key) + + if env is not None: + model = load_model(model_path, env.net()) + agent = Agent(model) + return agent + return None @@ -383,14 +389,7 @@ def eval_main(args, argv): num_games = int(argv[1]) if len(argv) >= 2 else 100 num_process = int(argv[2]) if len(argv) >= 3 else 1 - def resolve_agent(model_path): - agent = build_agent(model_path, env) - if agent is None: - model = load_model(model_path, env.net()) - agent = Agent(model) - return agent - - main_agent = resolve_agent(model_paths[0]) + main_agent = build_agent(model_paths[0], env) critic = None print('%d process, %d games' % (num_process, num_games)) @@ -399,7 +398,7 @@ def resolve_agent(model_path): print('seed = %d' % seed) opponent = model_paths[1] if len(model_paths) > 1 else 'random' - agents = [main_agent] + [resolve_agent(opponent) for _ in range(len(env.players()) - 1)] + agents = [main_agent] + [build_agent(opponent, env) for _ in range(len(env.players()) - 1)] evaluate_mp(env, agents, critic, env_args, {'default': {}}, num_process, num_games, seed)