diff --git a/handyrl/worker.py b/handyrl/worker.py index 796196df..f791f4c4 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -191,8 +191,8 @@ def __init__(self, args): def run(self): # prepare listening connections - def entry_server(port): - print('started entry server %d' % port) + def worker_server(port): + print('started worker server %d' % port) conn_acceptor = accept_socket_connections(port=port, timeout=0.3) while not self.shutdown_flag: conn = next(conn_acceptor) @@ -202,33 +202,14 @@ def entry_server(port): args = copy.deepcopy(self.args) args['worker'] = worker_args conn.send(args) - conn.close() - print('finished entry server') - - def worker_server(port): - print('started worker server %d' % port) - conn_acceptor = accept_socket_connections(port=port, timeout=0.3) - while not self.shutdown_flag: - conn = next(conn_acceptor) - if conn is not None: self.add_connection(conn) print('finished worker server') # use thread list of super class - self.threads.append(threading.Thread(target=entry_server, args=(9999,))) self.threads.append(threading.Thread(target=worker_server, args=(9998,))) - self.threads[-2].start() self.threads[-1].start() -def entry(worker_args): - conn = connect_socket_connection(worker_args['server_address'], 9999) - conn.send(worker_args) - args = conn.recv() - conn.close() - return args - - class RemoteWorkerCluster: def __init__(self, args): args['address'] = gethostname() @@ -238,15 +219,19 @@ def __init__(self, args): self.args = args def run(self): - args = entry(self.args) - print(args) - prepare_env(args['env']) - # open worker process = [] try: for i in range(self.args['num_gathers']): + # entry conn = connect_socket_connection(self.args['server_address'], 9998) + conn.send(self.args) + args = conn.recv() + + if i == 0: # call once at every machine + print(args) + prepare_env(args['env']) + p = mp.Process(target=gather_loop, args=(args, conn, i)) p.start() conn.close()