diff --git a/handyrl/worker.py b/handyrl/worker.py index 0cf47b63..3541f0f1 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -87,8 +87,8 @@ def run(self): send_recv(self.conn, ('result', result)) -def make_worker_args(args, n_ga, gaid, base_wid, wid, conn): - return args, conn, base_wid + wid * n_ga + gaid +def make_worker_args(args, base_wid, wid, conn): + return args, conn, base_wid + wid def open_worker(args, conn, wid): @@ -97,25 +97,20 @@ def open_worker(args, conn, wid): class Gather(QueueCommunicator): - def __init__(self, args, conn, gaid): - print('started gather %d' % gaid) + def __init__(self, args, conn, gather_id, base_worker_id, num_workers): + print('started gather %d' % gather_id) super().__init__() - self.gather_id = gaid + self.gather_id = gather_id self.server_conn = conn self.args_queue = deque([]) self.data_map = {'model': {}} self.result_send_map = {} self.result_send_cnt = 0 - n_pro, n_ga = args['worker']['num_parallel'], args['worker']['num_gathers'] - - num_workers_per_gather = (n_pro // n_ga) + int(gaid < n_pro % n_ga) - base_wid = args['worker'].get('base_worker_id', 0) - worker_conns = open_multiprocessing_connections( - num_workers_per_gather, + num_workers, open_worker, - functools.partial(make_worker_args, args, n_ga, gaid, base_wid) + functools.partial(make_worker_args, args, base_worker_id) ) for conn in worker_conns: @@ -168,8 +163,24 @@ def run(self): self.result_send_cnt = 0 -def gather_loop(args, conn, gaid): - gather = Gather(args, conn, gaid) +def gather_loop(args, conn, gather_id): + n_pro, n_ga = args['worker']['num_parallel'], args['worker']['num_gathers'] + n_pro_w = (n_pro // n_ga) + int(gather_id < n_pro % n_ga) + args['worker']['num_parallel_per_gather'] = n_pro_w + base_worker_id = 0 + + if conn is None: + # entry + conn = connect_socket_connection(args['worker']['server_address'], 9998) + conn.send(args['worker']) + args = conn.recv() + + if gather_id == 0: # call once at every machine + print(args) + prepare_env(args['env']) + base_worker_id = args['worker'].get('base_worker_id', 0) + + gather = Gather(args, conn, gather_id, base_worker_id, n_pro_w) gather.run() @@ -197,39 +208,22 @@ def __init__(self, args): def run(self): # prepare listening connections - def entry_server(port): - print('started entry server %d' % port) + def worker_server(port): + print('started worker server %d' % port) conn_acceptor = accept_socket_connections(port=port) while True: conn = next(conn_acceptor) worker_args = conn.recv() print('accepted connection from %s!' % worker_args['address']) worker_args['base_worker_id'] = self.total_worker_count - self.total_worker_count += worker_args['num_parallel'] + self.total_worker_count += worker_args['num_parallel_per_gather'] args = copy.deepcopy(self.args) args['worker'] = worker_args conn.send(args) - conn.close() - print('finished entry server') - - def worker_server(port): - print('started worker server %d' % port) - conn_acceptor = accept_socket_connections(port=port) - while True: - conn = next(conn_acceptor) self.add_connection(conn) print('finished worker server') - threading.Thread(target=entry_server, args=(9999,), daemon=True).start() - threading.Thread(target=worker_server, args=(9998,), daemon=True).start() - - -def entry(worker_args): - conn = connect_socket_connection(worker_args['server_address'], 9999) - conn.send(worker_args) - args = conn.recv() - conn.close() - return args + threading.Thread(target=worker_server, args=(9998,)).start() class RemoteWorkerCluster: @@ -241,18 +235,12 @@ def __init__(self, args): self.args = args def run(self): - args = entry(self.args) - print(args) - prepare_env(args['env']) - # open worker process = [] try: for i in range(self.args['num_gathers']): - conn = connect_socket_connection(self.args['server_address'], 9998) - p = mp.Process(target=gather_loop, args=(args, conn, i)) + p = mp.Process(target=gather_loop, args=({'worker': self.args}, None, i)) p.start() - conn.close() process.append(p) while True: time.sleep(100)