def split_except_quoted_field(string, delimiter=','):
    """Split ``string`` on ``delimiter``, ignoring delimiters inside quotes.

    Whitespace directly around each delimiter is dropped. Quote characters
    are kept in the returned fields.

    Example:
        >>> agent_config = 'transformer, temperature=1.0, hoge="hoge, fuga"'
        >>> split_except_quoted_field(agent_config, ',')
        ['transformer', 'temperature=1.0', 'hoge="hoge, fuga"']

    Args:
        string: Text to split.
        delimiter: Literal separator text. It is escaped before being placed
            in the pattern, so regex metacharacters such as ``.`` or ``|``
            are matched literally.

    Returns:
        List of fields in their original order.

    Note:
        Single and double quotes are counted together, so a field that mixes
        them (for example ``"it's"``) can prevent a split that was expected.

    .. python - Split by comma and how to exclude comma from quotes in split
       https://stackoverflow.com/a/64333329
    """
    # A delimiter counts only when an even number of quote characters follows
    # it up to the end of the string, i.e. when it is not inside a quoted field.
    outside_quotes = r"""(?=(?:[^"']*["'][^"']*["'])*[^"']*$)"""
    delimiter_matcher = re.compile(
        r"\s*" + re.escape(delimiter) + outside_quotes + r"\s*"
    )
    return delimiter_matcher.split(string)


def parse_args(arg_list):
    """Turn textual arguments into positional and keyword arguments.

    Each item is either ``value`` (positional) or ``key=value`` (keyword).
    Values that are Python literals (numbers, quoted strings, ``True``,
    ``None``, lists, dicts, ...) are converted to the corresponding object;
    anything else is kept as the original string.

    Args:
        arg_list: Iterable of argument strings.

    Returns:
        Tuple ``(args, kwargs)``: a list of positional values and a dict of
        keyword values.

    Raises:
        SyntaxError: If a positional argument follows a keyword argument, or
            an item contains more than one unquoted ``=``.
    """
    # Function-scope import keeps this block self-contained.
    import ast

    def _convert_dtype(value):
        # SECURITY: these strings come from the command line, so they are
        # parsed with ast.literal_eval rather than eval(). eval() would run
        # arbitrary expressions and would also turn barewords such as
        # "random" or "max" into whatever object carries that name in scope.
        try:
            return ast.literal_eval(value.strip())
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Not a literal: hand the text through unchanged.
            return value

    args = []
    kwargs = {}
    args_end_flag = False  # becomes True once the first keyword argument is seen
    for arg in arg_list:
        arg_split = split_except_quoted_field(arg, '=')
        if len(arg_split) == 1:  # positional argument
            if args_end_flag:
                raise SyntaxError('positional argument follows keyword argument')
            args.append(_convert_dtype(arg))
        elif len(arg_split) == 2:  # keyword argument
            args_end_flag = True
            key, value = arg_split
            kwargs[key] = _convert_dtype(value)
        else:
            raise SyntaxError('invalid syntax')

    return args, kwargs
def register_agent(alias=None):
    """Register a custom agent class under an alias.

    Can be used with an explicit alias, with empty parentheses, or bare:

    Example:
        >>> @register_agent(alias="transformer")
        ... class TransformerAgent:
        ...     pass
        >>> env = Environment({})
        >>> agent = build_agent("transformer", env)

    Args:
        alias: Name to register the class under. It is converted with
            ``str()`` and lower-cased. When omitted, the class name is used.

    Returns:
        A decorator that registers the class and returns it unchanged, or the
        class itself when used bare as ``@register_agent``.

    Note:
        The agent class must be defined in the current context.
        For example, if you define a class in `src/agent.py`, you need to add
        the following line to `main.py`.

        ```python
        import src.agent
        ```
    """

    def _registered_class(cls, alias=None):
        name = cls.__name__ if alias is None else alias

        # Aliases are stored lower-cased; build_agent lower-cases before lookup.
        _agent_aliases[str(name).lower()] = cls

        return cls

    # Bare usage (`@register_agent` without parentheses): the decorated class
    # arrives in the `alias` slot, so register it under its own name.
    if isinstance(alias, type):
        return _registered_class(alias)

    return functools.partial(_registered_class, alias=alias)


def build_agent(raw, env=None):
    """Build an agent from a textual specification.

    The specification is ``name[, arg, ...][, key=value, ...]``. ``name`` is
    resolved in this order:

    1. a registered alias (case-insensitive), see ``register_agent``;
    2. an existing filesystem path, loaded as a model and wrapped in ``Agent``;
    3. a dotted path ``package.module.ClassName``, imported dynamically.

    The remaining fields are passed to the agent constructor.

    Args:
        raw: Agent specification string.
        env: Environment whose ``net()`` supplies the network handed to
            ``load_model``. Only used for model paths.

    Returns:
        The constructed agent.

    Raises:
        ValueError: If the name is empty, or is not an alias, not an existing
            path and not a dotted path.
        SyntaxError: If the arguments are malformed (raised by ``parse_args``).
        ImportError, AttributeError: If a dotted path cannot be resolved.
    """
    agent_name, *arg_list = split_except_quoted_field(raw, ',')
    args, kwargs = parse_args(arg_list)

    agent_name = agent_name.strip()
    if not agent_name:
        # Path('') resolves to the current directory, which exists, so an
        # empty name would otherwise be treated as a model path below.
        raise ValueError("Agent specification is empty")

    AgentClass = _agent_aliases.get(agent_name.lower())
    if AgentClass is not None:
        return AgentClass(*args, **kwargs)

    if Path(agent_name).exists():
        # model path e.g. models/latest.pth, models/latest.pth.onnx
        # `env` is optional in the signature, so it must not be dereferenced
        # unconditionally.
        # NOTE(review): load_model declares `model=None` as its default;
        # confirm it can load this file type without a network instance.
        net = env.net() if env is not None else None
        model = load_model(agent_name, net)
        return Agent(model, *args, **kwargs)

    if "." not in agent_name:
        raise ValueError(f"Unknown agent: {agent_name}")

    # custom agent e.g. agents.custom_agent.CustomAgent
    module_path, class_name = agent_name.rsplit('.', maxsplit=1)
    module = importlib.import_module(module_path)
    CustomAgent = getattr(module, class_name)
    return CustomAgent(*args, **kwargs)