diff --git a/handyrl/worker.py b/handyrl/worker.py index 0cf47b6..5c6a46c 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -3,8 +3,8 @@ # worker and gather +import base64 import random -import threading import time import functools from socket import gethostname @@ -17,7 +17,6 @@ from .environment import prepare_env, make_env from .connection import QueueCommunicator from .connection import send_recv, open_multiprocessing_connections -from .connection import connect_socket_connection, accept_socket_connections from .evaluation import Evaluator from .generation import Generator from .model import ModelWrapper, RandomModel @@ -87,8 +86,8 @@ def run(self): send_recv(self.conn, ('result', result)) -def make_worker_args(args, n_ga, gaid, base_wid, wid, conn): - return args, conn, base_wid + wid * n_ga + gaid +def make_worker_args(args, base_wid, wid, conn): + return args, conn, base_wid + wid def open_worker(args, conn, wid): @@ -97,25 +96,20 @@ def open_worker(args, conn, wid): class Gather(QueueCommunicator): - def __init__(self, args, conn, gaid): - print('started gather %d' % gaid) + def __init__(self, args, conn, gather_id, base_worker_id, num_workers): + print('started gather %d' % gather_id) super().__init__() - self.gather_id = gaid + self.gather_id = gather_id self.server_conn = conn self.args_queue = deque([]) self.data_map = {'model': {}} self.result_send_map = {} self.result_send_cnt = 0 - n_pro, n_ga = args['worker']['num_parallel'], args['worker']['num_gathers'] - - num_workers_per_gather = (n_pro // n_ga) + int(gaid < n_pro % n_ga) - base_wid = args['worker'].get('base_worker_id', 0) - worker_conns = open_multiprocessing_connections( - num_workers_per_gather, + num_workers, open_worker, - functools.partial(make_worker_args, args, n_ga, gaid, base_wid) + functools.partial(make_worker_args, args, base_worker_id) ) for conn in worker_conns: @@ -168,8 +162,27 @@ def run(self): self.result_send_cnt = 0 -def gather_loop(args, conn, gaid): - gather = Gather(args, conn, gaid) +def gather_loop(args, conn, gather_id): + n_pro, n_ga = args['worker']['num_parallel'], args['worker']['num_gathers'] + n_pro_w = (n_pro // n_ga) + int(gather_id < n_pro % n_ga) + args['worker']['num_parallel_per_gather'] = n_pro_w + base_worker_id = 0 + + if conn is None: + # entry + port = int(args['worker'].get('server_port', 9998)) + conn = connect_websocket_connection(args['worker']['server_address'], port) + + conn.send(('entry', args['worker'])) + args = conn.recv() + print('entry finished') + + if gather_id == 0: # call once at every machine + print(args) + prepare_env(args['env']) + base_worker_id = args['worker'].get('base_worker_id', 0) + + gather = Gather(args, conn, gather_id, base_worker_id, n_pro_w) gather.run() @@ -189,47 +202,102 @@ def run(self): self.add_connection(conn0) -class WorkerServer(QueueCommunicator): +import base64 +import queue +import socket +from websocket import create_connection +from websocket_server import WebsocketServer + + +class WebsocketConnection: + def __init__(self, conn): + self.conn = conn + + @staticmethod + def dumps(data): + return base64.b64encode(pickle.dumps(data)) + + @staticmethod + def loads(message): + return pickle.loads(base64.b64decode(message)) + + def send(self, data): + message = self.dumps(data) + self.conn.send(message) + + def recv(self): + message = self.conn.recv() + return self.loads(message) + + def close(self): + self.conn.close() + + +def connect_websocket_connection(host, port): + host = socket.gethostbyname(host) + conn = create_connection('ws://%s:%d' % (host, port)) + return WebsocketConnection(conn) + + +class WorkerServer(WebsocketServer): def __init__(self, args): - super().__init__() + port = int(args['worker'].get('server_port', 9998)) + super().__init__(port=port, host='0.0.0.0') + self.input_queue = queue.Queue(maxsize=256) + self.output_queue = queue.Queue(maxsize=256) + self.shutdown_flag = False + self.args = args self.total_worker_count = 0 - def run(self): - # prepare listening connections - def entry_server(port): - print('started entry server %d' % port) - conn_acceptor = accept_socket_connections(port=port) - while True: - conn = next(conn_acceptor) - worker_args = conn.recv() - print('accepted connection from %s!' % worker_args['address']) - worker_args['base_worker_id'] = self.total_worker_count - self.total_worker_count += worker_args['num_parallel'] - args = copy.deepcopy(self.args) - args['worker'] = worker_args - conn.send(args) - conn.close() - print('finished entry server') - - def worker_server(port): - print('started worker server %d' % port) - conn_acceptor = accept_socket_connections(port=port) - while True: - conn = next(conn_acceptor) - self.add_connection(conn) - print('finished worker server') - - threading.Thread(target=entry_server, args=(9999,), daemon=True).start() - threading.Thread(target=worker_server, args=(9998,), daemon=True).start() + def connection_count(self): + return len(self.clients) - -def entry(worker_args): - conn = connect_socket_connection(worker_args['server_address'], 9999) - conn.send(worker_args) - args = conn.recv() - conn.close() - return args + def run(self): + self.set_fn_new_client(self._new_client) + self.set_fn_message_received(self._message_received) + self.run_forever(threaded=True) + + def shutdown(self): + self.shutdown_flag = True + self.shutdown_gracefully() + + def recv(self, timeout=None): + return self.input_queue.get(timeout=timeout) + + def send(self, client, send_data): + self.output_queue.put((client, send_data)) + + @staticmethod + def _new_client(client, server): + print('New client {}:{} has joined.'.format(client['address'][0], client['address'][1])) + + @staticmethod + def _message_received(client, server, message): + message_ = WebsocketConnection.loads(message) + if message_[0] == 'entry': + worker_args = message_[1] + print('accepted connection from %s' % worker_args['address']) + args = copy.deepcopy(server.args) + worker_args['base_worker_id'] = server.total_worker_count + server.total_worker_count += worker_args['num_parallel_per_gather'] + args['worker'] = worker_args + reply_message = args + else: + while not server.shutdown_flag: + try: + server.input_queue.put((client, message_), timeout=0.3) + break + except queue.Full: + pass + while not server.shutdown_flag: + try: + client, reply_message = server.output_queue.get(timeout=0.3) + break + except queue.Empty: + continue + reply_message_ = WebsocketConnection.dumps(reply_message) + server.send_message(client, reply_message_) class RemoteWorkerCluster: @@ -241,18 +309,12 @@ def __init__(self, args): self.args = args def run(self): - args = entry(self.args) - print(args) - prepare_env(args['env']) - # open worker process = [] try: for i in range(self.args['num_gathers']): - conn = connect_socket_connection(self.args['server_address'], 9998) - p = mp.Process(target=gather_loop, args=(args, conn, i)) + p = mp.Process(target=gather_loop, args=({'worker': self.args}, None, i)) p.start() - conn.close() process.append(p) while True: time.sleep(100) diff --git a/requirements.txt b/requirements.txt index 90443c0..5c353d4 100755 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,5 @@ numpy torch pytest psutil +websocket-server +websocket-client