diff --git a/handyrl/worker.py b/handyrl/worker.py index 0cf47b6..1b8bb85 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -197,36 +197,31 @@ def __init__(self, args): def run(self): # prepare listening connections - def entry_server(port): - print('started entry server %d' % port) - conn_acceptor = accept_socket_connections(port=port) - while True: - conn = next(conn_acceptor) - worker_args = conn.recv() - print('accepted connection from %s!' % worker_args['address']) - worker_args['base_worker_id'] = self.total_worker_count - self.total_worker_count += worker_args['num_parallel'] - args = copy.deepcopy(self.args) - args['worker'] = worker_args - conn.send(args) - conn.close() - print('finished entry server') - def worker_server(port): print('started worker server %d' % port) conn_acceptor = accept_socket_connections(port=port) while True: conn = next(conn_acceptor) - self.add_connection(conn) + type, worker_args = conn.recv() + if type == 'entry': + print('accepted connection from %s!' % worker_args['address']) + worker_args['base_worker_id'] = self.total_worker_count + self.total_worker_count += worker_args['num_parallel'] + args = copy.deepcopy(self.args) + args['worker'] = worker_args + conn.send(args) + conn.close() + else: + self.add_connection(conn) print('finished worker server') - threading.Thread(target=entry_server, args=(9999,), daemon=True).start() + # use thread list of super class threading.Thread(target=worker_server, args=(9998,), daemon=True).start() def entry(worker_args): - conn = connect_socket_connection(worker_args['server_address'], 9999) - conn.send(worker_args) + conn = connect_socket_connection(worker_args['server_address'], 9998) + conn.send(('entry', worker_args)) args = conn.recv() conn.close() return args @@ -250,6 +245,7 @@ def run(self): try: for i in range(self.args['num_gathers']): conn = connect_socket_connection(self.args['server_address'], 9998) + conn.send(('worker', None)) p = mp.Process(target=gather_loop, args=(args, conn, i)) p.start() conn.close()