Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: Specify opponent by path #337

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 108 additions & 19 deletions handyrl/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@

# evaluation of policies or planning algorithms

import functools
import importlib
import random
import re
import time
import multiprocessing as mp
from pathlib import Path

from .environment import prepare_env, make_env
from .connection import send_recv, accept_socket_connections, connect_socket_connection
Expand All @@ -14,6 +18,11 @@

# Port number for network matches (judging by the name; its use is outside
# this view -- TODO confirm against the socket helpers).
network_match_port = 9876

# Registry mapping a lower-case agent alias to the agent class it names.
# register_agent() inserts entries under str(alias).lower(), and build_agent()
# looks agent names up here using agent_name.lower().
_agent_aliases = {
'random': RandomAgent,
'rulebase': RuleBasedAgent,
}


def view(env, player=None):
if hasattr(env, 'view'):
Expand Down Expand Up @@ -141,13 +150,103 @@ def exec_network_match(env, network_agents, critic=None, show=False, game_args={
return outcome


def split_except_quoted_field(string, delimiter=','):
    """Split ``string`` on ``delimiter``, ignoring delimiters inside quotes.

    Whitespace surrounding each delimiter is discarded.  Single and double
    quotes are treated interchangeably when deciding whether a delimiter
    sits inside a quoted field.

    Args:
        string: Text to split.
        delimiter: Literal separator text.  It is escaped before being
            placed in the regular expression, so characters such as ``.``
            or ``|`` are matched literally rather than as regex syntax.

    Returns:
        List of fields; quoted fields keep their quotes.

    Example:
        >>> agent_config = 'transformer, temperature=1.0, hoge="hoge, fuga"'
        >>> split_except_quoted_field(agent_config, ',')
        ['transformer', 'temperature=1.0', 'hoge="hoge, fuga"']

    .. python - Split by comma and how to exclude comma from quotes in split
       https://stackoverflow.com/a/64333329
    """
    # A delimiter only counts when the remainder of the string holds an even
    # number of quote characters, i.e. we are not inside a quoted field.
    outside_quotes = r"(?=(?:[^\"']*[\"'][^\"']*[\"'])*[^\"']*$)"
    pattern = r"\s*" + re.escape(delimiter) + outside_quotes + r"\s*"
    return re.split(pattern, string)


def parse_args(arg_list):
    """Turn textual argument fields into ``(args, kwargs)``.

    Each item of ``arg_list`` is either a positional value (``'1.0'``) or a
    ``key=value`` pair.  Values that are valid Python literals (numbers,
    quoted strings, lists, dicts, booleans, ``None``) are converted to the
    corresponding object; anything else is kept as the original string.

    Args:
        arg_list: Iterable of argument strings.

    Returns:
        Tuple ``(args, kwargs)`` with a list of positional values and a dict
        of keyword values.

    Raises:
        SyntaxError: If a positional argument follows a keyword argument, or
            an item contains more than one unquoted ``=``.
    """
    # Local import keeps this block self-contained.
    import ast

    def _convert_dtype(value):
        # NOTE: the values come from the command line, so they are parsed
        # with ast.literal_eval rather than eval().  eval() would execute
        # arbitrary expressions and would also resolve bare words against
        # module globals/builtins (e.g. 'random' -> the random module).
        try:
            return ast.literal_eval(value.strip())
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Not a literal: treat it as a plain string.
            return value

    args = []
    kwargs = {}
    for arg in arg_list:
        fields = split_except_quoted_field(arg, '=')
        if len(fields) == 1:  # positional
            if kwargs:
                raise SyntaxError('positional argument follows keyword argument')
            args.append(_convert_dtype(arg))
        elif len(fields) == 2:  # keyword
            key, value = fields
            kwargs[key] = _convert_dtype(value)
        else:
            raise SyntaxError('invalid syntax')

    return args, kwargs


def register_agent(alias=None):
    """Class decorator that registers an agent class under an alias.

    The class is stored in ``_agent_aliases`` under ``str(alias).lower()``;
    when ``alias`` is omitted the class name is used.  The decorated class
    itself is returned unchanged.

    Both ``@register_agent(alias="name")`` / ``@register_agent()`` and the
    bare form ``@register_agent`` are accepted.

    Example:
        >>> @register_agent(alias="transformer")
        ... class TransformerAgent:
        ...     pass
        >>> env = Environment({})
        >>> agent = build_agent("transformer", env)

    Note:
        Registration happens when the decorated class definition is
        executed, so the module that defines the class must be imported
        before the alias is used.  For example, if the class lives in
        `src/agent.py`, add the following line to `main.py`.

        ```python
        import src.agent
        ```
    """

    def _registered_class(cls, alias=None):
        name = cls.__name__ if alias is None else alias
        # Aliases are stored lower-cased.
        _agent_aliases[str(name).lower()] = cls
        return cls

    if isinstance(alias, type):
        # Bare usage (``@register_agent`` without parentheses): the class was
        # passed in the ``alias`` slot.  Register it under its own name
        # instead of returning a partial that would replace the class.
        return _registered_class(alias)

    return functools.partial(_registered_class, alias=alias)


def build_agent(raw, env=None):
    """Build an agent from a textual specification.

    ``raw`` has the form ``NAME[, ARG ...][, KEY=VALUE ...]``.  ``NAME`` is
    resolved in this order:

    1. a registered alias in ``_agent_aliases`` (case-insensitive),
    2. the legacy spelling ``rulebase-<key>``,
    3. an existing model file path, e.g. ``models/latest.pth``,
    4. a dotted import path, e.g. ``agents.custom_agent.CustomAgent``.

    The remaining fields are parsed by ``parse_args`` and forwarded to the
    agent constructor.

    Args:
        raw: Agent specification string.
        env: Environment whose ``net()`` is passed to ``load_model``; only
            needed when ``NAME`` is a model file path.

    Returns:
        The constructed agent.

    Raises:
        ValueError: If ``NAME`` cannot be resolved, or it is a model path
            and no ``env`` was supplied.
    """
    agent_name, *arg_list = split_except_quoted_field(raw, ',')
    args, kwargs = parse_args(arg_list)
    alias = agent_name.lower()

    if alias in _agent_aliases:
        return _agent_aliases[alias](*args, **kwargs)

    if agent_name.startswith('rulebase-'):
        # Legacy spelling kept for backward compatibility: 'rulebase-<key>'.
        key = agent_name.split('-')[1]
        return RuleBasedAgent(key, *args, **kwargs)

    if Path(agent_name).exists():
        # model path e.g. models/latest.pth, models/latest.pth.onnx
        if env is None:
            raise ValueError(
                f"An environment is required to load model: {agent_name}")
        model = load_model(agent_name, env.net())
        return Agent(model, *args, **kwargs)

    # custom agent e.g. agents.custom_agent.CustomAgent
    module_path, _, class_name = agent_name.rpartition('.')
    if not module_path:
        raise ValueError(f"Unknown agent: {agent_name}")
    module = importlib.import_module(module_path)
    CustomAgent = getattr(module, class_name)
    return CustomAgent(*args, **kwargs)


class Evaluator:
Expand Down Expand Up @@ -368,9 +467,6 @@ def load_model(model_path, model=None):
def client_mp_child(env_args, model_path, conn):
    """Create an environment and an agent, then run a network agent client.

    The agent is obtained from ``build_agent``; if that yields ``None`` the
    spec is treated as a model file and loaded with ``load_model``.
    """
    environment = make_env(env_args)
    player = build_agent(model_path, environment)
    if player is None:
        # Spec was not recognised by build_agent: load it as a model file.
        player = Agent(load_model(model_path, environment.net()))
    client = NetworkAgentClient(player, environment, conn)
    client.run()


Expand All @@ -379,18 +475,11 @@ def eval_main(args, argv):
prepare_env(env_args)
env = make_env(env_args)

model_paths = argv[0].split(':') if len(argv) >= 1 else ['models/latest.pth']
model_paths = split_except_quoted_field(argv[0], ':') if len(argv) >= 1 else ['models/latest.pth']
num_games = int(argv[1]) if len(argv) >= 2 else 100
num_process = int(argv[2]) if len(argv) >= 3 else 1

def resolve_agent(model_path):
agent = build_agent(model_path, env)
if agent is None:
model = load_model(model_path, env.net())
agent = Agent(model)
return agent

main_agent = resolve_agent(model_paths[0])
main_agent = build_agent(model_paths[0])
critic = None

print('%d process, %d games' % (num_process, num_games))
Expand All @@ -399,7 +488,7 @@ def resolve_agent(model_path):
print('seed = %d' % seed)

opponent = model_paths[1] if len(model_paths) > 1 else 'random'
agents = [main_agent] + [resolve_agent(opponent) for _ in range(len(env.players()) - 1)]
agents = [main_agent] + [build_agent(opponent) for _ in range(len(env.players()) - 1)]

evaluate_mp(env, agents, critic, env_args, {'default': {}}, num_process, num_games, seed)

Expand Down