From 2d385bc698b196623f41431a6d8276c9134b8295 Mon Sep 17 00:00:00 2001
From: YuriCat
Date: Wed, 26 Jan 2022 12:25:08 +0900
Subject: [PATCH] feature: entry in gather_loop

---
 handyrl/agent.py        | 17 ++++------
 handyrl/environment.py  |  6 ++++
 handyrl/envs/geister.py | 18 ++++-------
 handyrl/evaluation.py   |  8 +++--
 handyrl/generation.py   | 45 ++++++++++++++------------
 handyrl/model.py        |  6 +++-
 handyrl/train.py        | 11 +++----
 handyrl/worker.py       | 72 +++++++++++++++++------------------------
 8 files changed, 88 insertions(+), 95 deletions(-)

diff --git a/handyrl/agent.py b/handyrl/agent.py
index f26dfb87..4af84f5e 100755
--- a/handyrl/agent.py
+++ b/handyrl/agent.py
@@ -41,11 +41,10 @@ def print_outputs(env, prob, v):
 
 
 class Agent:
-    def __init__(self, model, observation=False, temperature=0.0):
+    def __init__(self, model, temperature=0.0):
         # model might be a neural net, or some planning algorithm such as game tree search
         self.model = model
         self.hidden = None
-        self.observation = observation
         self.temperature = temperature
 
     def reset(self, env, show=False):
@@ -75,12 +74,10 @@ def action(self, env, player, show=False):
             return random.choices(np.arange(len(p)), weights=softmax(p / self.temperature))[0]
 
     def observe(self, env, player, show=False):
-        v = None
-        if self.observation:
-            outputs = self.plan(env.observation(player))
-            v = outputs.get('value', None)
-            if show:
-                print_outputs(env, None, v)
+        outputs = self.plan(env.observation(player))
+        v = outputs.get('value', None)
+        if show:
+            print_outputs(env, None, v)
         return v if v is not None else [0.0]
 
 
@@ -103,5 +100,5 @@
 
 
 class SoftAgent(Agent):
-    def __init__(self, model, observation=False):
-        super().__init__(model, observation=observation, temperature=1.0)
+    def __init__(self, model):
+        super().__init__(model, temperature=1.0)
diff --git a/handyrl/environment.py b/handyrl/environment.py
index 9bca1713..02bb836c 100755
--- a/handyrl/environment.py
+++ b/handyrl/environment.py
@@ -77,6 +77,12 @@ def turn(self):
     def turns(self):
         return [self.turn()]
 
+    #
+    # Should be defined if players except turn player also observe game states
+    #
+    def observers(self):
+        return []
+
     #
     # Should be defined in all games
     #
diff --git a/handyrl/envs/geister.py b/handyrl/envs/geister.py
index 9a395659..a82bd6af 100755
--- a/handyrl/envs/geister.py
+++ b/handyrl/envs/geister.py
@@ -34,16 +34,10 @@ def __init__(self, input_dim, hidden_dim, kernel_size, bias):
         )
 
     def init_hidden(self, input_size, batch_size):
-        if batch_size is None:  # for inference
-            return tuple([
-                np.zeros((self.hidden_dim, *input_size), dtype=np.float32),
-                np.zeros((self.hidden_dim, *input_size), dtype=np.float32)
-            ])
-        else:  # for training
-            return tuple([
-                torch.zeros(*batch_size, self.hidden_dim, *input_size),
-                torch.zeros(*batch_size, self.hidden_dim, *input_size)
-            ])
+        return tuple([
+            torch.zeros(*batch_size, self.hidden_dim, *input_size),
+            torch.zeros(*batch_size, self.hidden_dim, *input_size)
+        ])
 
     def forward(self, input_tensor, cur_state):
         h_cur, c_cur = cur_state
@@ -150,7 +144,7 @@ def __init__(self):
         self.head_v = ScalarHead((filters * 2, 6, 6), 1, 1)
         self.head_r = ScalarHead((filters * 2, 6, 6), 1, 1)
 
-    def init_hidden(self, batch_size=None):
+    def init_hidden(self, batch_size=[]):
         return self.body.init_hidden(self.input_size[1:], batch_size)
 
     def forward(self, x, hidden):
@@ -453,6 +447,8 @@ def legal(self, action):
         if self.turn_count < 0:
             layout = action - 4 * 6 * 6
             return 0 <= layout < 70
+        elif not 0 <= action < 4 * 6 * 6:
+            return False
 
         pos_from = self.action2from(action, self.color)
         pos_to = self.action2to(action, self.color)
diff --git a/handyrl/evaluation.py b/handyrl/evaluation.py
index a6d0ddf1..b3a7cb4a 100755
--- a/handyrl/evaluation.py
+++ b/handyrl/evaluation.py
@@ -88,11 +88,12 @@ def exec_match(env, agents, critic, show=False, game_args={}):
         if show and critic is not None:
             print('cv = ', critic.observe(env, None, show=False)[0])
         turn_players = env.turns()
+        observers = env.observers()
         actions = {}
         for p, agent in agents.items():
             if p in turn_players:
                 actions[p] = agent.action(env, p, show=show)
-            else:
+            elif p in observers:
                 agent.observe(env, p, show=show)
         if env.step(actions):
             return None
@@ -117,12 +118,13 @@ def exec_network_match(env, network_agents, critic, show=False, game_args={}):
         if show and critic is not None:
             print('cv = ', critic.observe(env, None, show=False)[0])
         turn_players = env.turns()
+        observers = env.observers()
         actions = {}
         for p, agent in network_agents.items():
             if p in turn_players:
                 action = agent.action(p)
                 actions[p] = env.str2action(action, p)
-            else:
+            elif p in observers:
                 agent.observe(p)
         if env.step(actions):
             return None
@@ -161,7 +163,7 @@ def execute(self, models, args):
             if model is None:
                 agents[p] = build_agent(opponent, self.env)
             else:
-                agents[p] = Agent(model, self.args['observation'])
+                agents[p] = Agent(model)
 
         outcome = exec_match(self.env, agents, None)
         if outcome is None:
diff --git a/handyrl/generation.py b/handyrl/generation.py
index 63b7e553..3ad3b76e 100755
--- a/handyrl/generation.py
+++ b/handyrl/generation.py
@@ -33,28 +33,31 @@ def generate(self, models, args):
             moment = {key: {p: None for p in self.env.players()} for key in moment_keys}
 
             turn_players = self.env.turns()
+            observers = self.env.observers()
             for player in self.env.players():
-                if player in turn_players or self.args['observation']:
-                    obs = self.env.observation(player)
-                    model = models[player]
-                    outputs = model.inference(obs, hidden[player])
-                    hidden[player] = outputs.get('hidden', None)
-                    v = outputs.get('value', None)
-
-                    moment['observation'][player] = obs
-                    moment['value'][player] = v
-
-                    if player in turn_players:
-                        p_ = outputs['policy']
-                        legal_actions = self.env.legal_actions(player)
-                        action_mask = np.ones_like(p_) * 1e32
-                        action_mask[legal_actions] = 0
-                        p = p_ - action_mask
-                        action = random.choices(legal_actions, weights=softmax(p[legal_actions]))[0]
-
-                        moment['policy'][player] = p
-                        moment['action_mask'][player] = action_mask
-                        moment['action'][player] = action
+                if player not in turn_players + observers:
+                    continue
+
+                obs = self.env.observation(player)
+                model = models[player]
+                outputs = model.inference(obs, hidden[player])
+                hidden[player] = outputs.get('hidden', None)
+                v = outputs.get('value', None)
+
+                moment['observation'][player] = obs
+                moment['value'][player] = v
+
+                if player in turn_players:
+                    p_ = outputs['policy']
+                    legal_actions = self.env.legal_actions(player)
+                    action_mask = np.ones_like(p_) * 1e32
+                    action_mask[legal_actions] = 0
+                    p = p_ - action_mask
+                    action = random.choices(legal_actions, weights=softmax(p[legal_actions]))[0]
+
+                    moment['policy'][player] = p
+                    moment['action_mask'][player] = action_mask
+                    moment['action'][player] = action
 
             err = self.env.step(moment['action'])
             if err:
diff --git a/handyrl/model.py b/handyrl/model.py
index 621d703f..9eb7b94b 100755
--- a/handyrl/model.py
+++ b/handyrl/model.py
@@ -37,7 +37,11 @@ def __init__(self, model):
 
     def init_hidden(self, batch_size=None):
         if hasattr(self.model, 'init_hidden'):
-            return self.model.init_hidden(batch_size)
+            if batch_size is None:  # for inference
+                hidden = self.model.init_hidden([])
+                return map_r(hidden, lambda h: h.detach().numpy() if isinstance(h, torch.Tensor) else h)
+            else:  # for training
+                return self.model.init_hidden(batch_size)
         return None
 
     def forward(self, *args, **kwargs):
diff --git a/handyrl/train.py b/handyrl/train.py
index 79ae2afc..0cee1013 100755
--- a/handyrl/train.py
+++ b/handyrl/train.py
@@ -26,7 +26,6 @@
 from .model import to_torch, to_gpu, ModelWrapper
 from .losses import compute_target
 from .connection import MultiProcessJobExecutor
-from .connection import accept_socket_connections
 from .worker import WorkerCluster, WorkerServer
 
 
@@ -60,7 +59,7 @@ def replace_none(a, b):
     obs_zeros = map_r(moments[0]['observation'][moments[0]['turn'][0]], lambda o: np.zeros_like(o))  # template for padding
     p_zeros = np.zeros_like(moments[0]['policy'][moments[0]['turn'][0]])  # template for padding
 
-    # data that is chainge by training configuration
+    # data that is changed by training configuration
     if args['turn_based_training'] and not args['observation']:
         obs = [[m['observation'][m['turn'][0]]] for m in moments]
         p = np.array([[m['policy'][m['turn'][0]]] for m in moments])
@@ -154,7 +153,7 @@ def forward_prediction(model, hidden, batch, args):
             outputs_ = model(obs, hidden_)
             for k, o in outputs_.items():
                 if k == 'hidden':
-                    next_hidden = outputs_['hidden']
+                    next_hidden = o
                 else:
                     outputs[k] = outputs.get(k, []) + [o]
         next_hidden = bimap_r(next_hidden, hidden, lambda nh, h: nh.view(h.size(0), -1, *h.size()[2:]))  # (..., B, P or 1, ...)
@@ -349,8 +348,8 @@ def shutdown(self):
 
     def train(self):
         if self.optimizer is None:  # non-parametric model
-            print()
-            return
+            time.sleep(0.1)
+            return self.model
 
         batch_cnt, data_cnt, loss_sum = 0, 0, {}
         if self.gpu > 0:
@@ -395,7 +394,7 @@ def run(self):
             if len(self.episodes) < self.args['minimum_episodes']:
                 time.sleep(1)
                 continue
-            if self.steps == 0:
+            if self.steps == 0 and self.optimizer is not None:
                 self.batcher.run()
                 print('started training')
             model = self.train()
diff --git a/handyrl/worker.py b/handyrl/worker.py
index 7048097c..a76d875b 100755
--- a/handyrl/worker.py
+++ b/handyrl/worker.py
@@ -84,8 +84,8 @@ def run(self):
                 send_recv(self.conn, ('result', result))
 
 
-def make_worker_args(args, n_ga, gaid, base_wid, wid, conn):
-    return args, conn, base_wid + wid * n_ga + gaid
+def make_worker_args(args, base_wid, wid, conn):
+    return args, conn, base_wid + wid
 
 
 def open_worker(args, conn, wid):
@@ -94,25 +94,20 @@
 
 
 class Gather(QueueCommunicator):
-    def __init__(self, args, conn, gaid):
-        print('started gather %d' % gaid)
+    def __init__(self, args, conn, gather_id, base_worker_id, num_workers):
+        print('started gather %d' % gather_id)
         super().__init__()
-        self.gather_id = gaid
+        self.gather_id = gather_id
         self.server_conn = conn
         self.args_queue = deque([])
         self.data_map = {'model': {}}
         self.result_send_map = {}
         self.result_send_cnt = 0
 
-        n_pro, n_ga = args['worker']['num_parallel'], args['worker']['num_gathers']
-
-        num_workers_per_gather = (n_pro // n_ga) + int(gaid < n_pro % n_ga)
-        base_wid = args['worker'].get('base_worker_id', 0)
-
-        worker_conns = open_multiprocessing_connections(
-            num_workers_per_gather,
+        worker_conns = open_multiprocessing_connections(
+            num_workers,
             open_worker,
-            functools.partial(make_worker_args, args, n_ga, gaid, base_wid)
+            functools.partial(make_worker_args, args, base_worker_id)
         )
 
         for conn in worker_conns:
@@ -162,9 +157,25 @@ def run(self):
                 self.result_send_cnt = 0
 
 
-def gather_loop(args, conn, gaid):
+def gather_loop(args, conn, gather_id):
+    n_pro, n_ga = args['worker']['num_parallel'], args['worker']['num_gathers']
+    n_pro_w = (n_pro // n_ga) + int(gather_id < n_pro % n_ga)
+    args['worker']['num_parallel_per_gather'] = n_pro_w
+    base_worker_id = 0
+
+    if conn is None:
+        # entry
+        conn = connect_socket_connection(args['worker']['server_address'], 9998)
+        conn.send(args['worker'])
+        args = conn.recv()
+
+        if gather_id == 0:  # call once at every machine
+            print(args)
+            prepare_env(args['env'])
+        base_worker_id = args['worker'].get('base_worker_id', 0)
+
     try:
-        gather = Gather(args, conn, gaid)
+        gather = Gather(args, conn, gather_id, base_worker_id, n_pro_w)
         gather.run()
     finally:
         gather.shutdown()
@@ -194,8 +205,8 @@ def __init__(self, args):
 
     def run(self):
         # prepare listening connections
-        def entry_server(port):
-            print('started entry server %d' % port)
+        def worker_server(port):
+            print('started worker server %d' % port)
             conn_acceptor = accept_socket_connections(port=port, timeout=0.3)
             while not self.shutdown_flag:
                 conn = next(conn_acceptor)
@@ -203,37 +214,18 @@
                     worker_args = conn.recv()
                     print('accepted connection from %s!' % worker_args['address'])
                     worker_args['base_worker_id'] = self.total_worker_count
-                    self.total_worker_count += worker_args['num_parallel']
+                    self.total_worker_count += worker_args['num_parallel_per_gather']
                     args = copy.deepcopy(self.args)
                     args['worker'] = worker_args
                     conn.send(args)
-                    conn.close()
-            print('finished entry server')
-
-        def worker_server(port):
-            conn_acceptor = accept_socket_connections(port=port, timeout=0.3)
-            print('started worker server %d' % port)
-            while not self.shutdown_flag:  # use super class's flag
-                conn = next(conn_acceptor)
-                if conn is not None:
                     self.add_connection(conn)
             print('finished worker server')
 
         # use thread list of super class
-        self.threads.append(threading.Thread(target=entry_server, args=(9999,)))
         self.threads.append(threading.Thread(target=worker_server, args=(9998,)))
-        self.threads[-2].start()
         self.threads[-1].start()
 
 
-def entry(worker_args):
-    conn = connect_socket_connection(worker_args['server_address'], 9999)
-    conn.send(worker_args)
-    args = conn.recv()
-    conn.close()
-    return args
-
-
 class RemoteWorkerCluster:
     def __init__(self, args):
         args['address'] = gethostname()
@@ -243,18 +235,12 @@ def __init__(self, args):
         self.args = args
 
     def run(self):
-        args = entry(self.args)
-        print(args)
-        prepare_env(args['env'])
-
         # open worker
         process = []
         try:
             for i in range(self.args['num_gathers']):
-                conn = connect_socket_connection(self.args['server_address'], 9998)
-                p = mp.Process(target=gather_loop, args=(args, conn, i))
+                p = mp.Process(target=gather_loop, args=({'worker': self.args}, None, i))
                 p.start()
-                conn.close()
                 process.append(p)
             while True:
                 time.sleep(100)
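
Illustrative note (not part of the patch): the observers() hook added to handyrl/environment.py defaults to an empty list, so after this change only turn players are fed to the model unless a game overrides the method; it takes over the role the 'observation' config flag used to play inside the generation and evaluation loops. The sketch below is a minimal, standalone Python example of that pattern; the Env and FullyObservedEnv classes and the toy loop are hypothetical, and only the players()/turns()/observers() method names mirror the patched interface.

    # Hypothetical toy environment; only the method names players()/turns()/observers()
    # correspond to the interface touched by this patch.
    class Env:
        def players(self):
            return [0, 1]

        def turns(self):
            return [0]  # player 0 is the turn player in this toy step

        def observers(self):
            return []  # default after this patch: nobody but the turn player observes


    class FullyObservedEnv(Env):
        def observers(self):
            # Let every non-turn player observe the state as well, so the
            # generation/evaluation loops also query its value estimate.
            return [p for p in self.players() if p not in self.turns()]


    env = FullyObservedEnv()
    turn_players, observers = env.turns(), env.observers()
    for player in env.players():
        if player not in turn_players + observers:
            continue  # skipped entirely, as in the rewritten generation loop
        role = 'acts' if player in turn_players else 'observes only'
        print('player %d %s' % (player, role))

Running the sketch prints "player 0 acts" and "player 1 observes only", which is the same split the patched loops make between acting and observing players.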