Skip to content

Commit

Permalink
feature: entry in gather_loop
Browse files Browse the repository at this point in the history
  • Loading branch information
YuriCat committed Jan 26, 2022
1 parent f926365 commit 103dca1
Showing 1 changed file with 27 additions and 26 deletions.
53 changes: 27 additions & 26 deletions handyrl/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def run(self):
send_recv(self.conn, ('result', result))


def make_worker_args(args, n_ga, gaid, base_wid, wid, conn):
return args, conn, base_wid + wid * n_ga + gaid
def make_worker_args(args, base_wid, wid, conn):
return args, conn, base_wid + wid


def open_worker(args, conn, wid):
Expand All @@ -94,25 +94,20 @@ def open_worker(args, conn, wid):


class Gather(QueueCommunicator):
def __init__(self, args, conn, gaid):
print('started gather %d' % gaid)
def __init__(self, args, conn, gather_id, base_worker_id, num_workers):
print('started gather %d' % gather_id)
super().__init__()
self.gather_id = gaid
self.gather_id = gather_id
self.server_conn = conn
self.args_queue = deque([])
self.data_map = {'model': {}}
self.result_send_map = {}
self.result_send_cnt = 0

n_pro, n_ga = args['worker']['num_parallel'], args['worker']['num_gathers']

num_workers_per_gather = (n_pro // n_ga) + int(gaid < n_pro % n_ga)
base_wid = args['worker'].get('base_worker_id', 0)

worker_conns = open_multiprocessing_connections(
num_workers_per_gather,
num_workers,
open_worker,
functools.partial(make_worker_args, args, n_ga, gaid, base_wid)
functools.partial(make_worker_args, args, base_worker_id)
)

for conn in worker_conns:
Expand Down Expand Up @@ -162,9 +157,25 @@ def run(self):
self.result_send_cnt = 0


def gather_loop(args, conn, gaid):
def gather_loop(args, conn, gather_id):
n_pro, n_ga = args['worker']['num_parallel'], args['worker']['num_gathers']
n_pro_w = (n_pro // n_ga) + int(gather_id < n_pro % n_ga)
args['worker']['num_parallel_per_gather'] = n_pro_w
base_worker_id = 0

if conn is None:
# entry
conn = connect_socket_connection(args['worker']['server_address'], 9998)
conn.send(args['worker'])
args = conn.recv()

if gather_id == 0: # call once at every machine
print(args)
prepare_env(args['env'])
base_worker_id = args['worker'].get('base_worker_id', 0)

try:
gather = Gather(args, conn, gaid)
gather = Gather(args, conn, gather_id, base_worker_id, n_pro_w)
gather.run()
finally:
gather.shutdown()
Expand Down Expand Up @@ -203,7 +214,7 @@ def worker_server(port):
worker_args = conn.recv()
print('accepted connection from %s!' % worker_args['address'])
worker_args['base_worker_id'] = self.total_worker_count
self.total_worker_count += worker_args['num_parallel']
self.total_worker_count += worker_args['num_parallel_per_gather']
args = copy.deepcopy(self.args)
args['worker'] = worker_args
conn.send(args)
Expand All @@ -228,18 +239,8 @@ def run(self):
process = []
try:
for i in range(self.args['num_gathers']):
# entry
conn = connect_socket_connection(self.args['server_address'], 9998)
conn.send(self.args)
args = conn.recv()

if i == 0: # call once at every machine
print(args)
prepare_env(args['env'])

p = mp.Process(target=gather_loop, args=(args, conn, i))
p = mp.Process(target=gather_loop, args=({'worker': self.args}, None, i))
p.start()
conn.close()
process.append(p)
while True:
time.sleep(100)
Expand Down

0 comments on commit 103dca1

Please sign in to comment.