Skip to content

Commit

Permalink
feature: assing cumulative worker index
Browse files Browse the repository at this point in the history
  • Loading branch information
YuriCat committed Jan 24, 2022
1 parent 57a1967 commit 897ea32
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions handyrl/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def run(self):
send_recv(self.conn, ('result', result))


def make_worker_args(args, n_ga, gaid, wid, conn):
return args, conn, wid * n_ga + gaid
def make_worker_args(args, n_ga, gaid, base_wid, wid, conn):
return args, conn, base_wid + wid * n_ga + gaid


def open_worker(args, conn, wid):
Expand All @@ -107,10 +107,12 @@ def __init__(self, args, conn, gaid):
n_pro, n_ga = args['worker']['num_parallel'], args['worker']['num_gathers']

num_workers_per_gather = (n_pro // n_ga) + int(gaid < n_pro % n_ga)
base_wid = args['worker'].get('base_worker_id', 0)

worker_conns = open_multiprocessing_connections(
num_workers_per_gather,
open_worker,
functools.partial(make_worker_args, args, n_ga, gaid)
functools.partial(make_worker_args, args, n_ga, gaid, base_wid)
)

for conn in worker_conns:
Expand Down Expand Up @@ -188,6 +190,7 @@ class WorkerServer(QueueCommunicator):
def __init__(self, args):
super().__init__()
self.args = args
self.total_worker_count = 0

def run(self):
# prepare listening connections
Expand All @@ -199,6 +202,8 @@ def entry_server(port):
if conn is not None:
worker_args = conn.recv()
print('accepted connection from %s!' % worker_args['address'])
worker_args['base_worker_id'] = self.total_worker_count
self.total_worker_count += worker_args['num_parallel']
args = copy.deepcopy(self.args)
args['worker'] = worker_args
conn.send(args)
Expand Down

0 comments on commit 897ea32

Please sign in to comment.