From c9cbbdbde04423e1e7a761b7dac18f6ff89a2081 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Thu, 5 Jan 2023 19:11:06 +0900 Subject: [PATCH 1/2] feature: model pool in each worker --- handyrl/worker.py | 39 ++++++++++++++++----------------------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/handyrl/worker.py b/handyrl/worker.py index 0cf47b63..ebd676ab 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -29,7 +29,7 @@ def __init__(self, args, conn, wid): self.worker_id = wid self.args = args self.conn = conn - self.latest_model = -1, None + self.model_pool = {} self.env = make_env({**args['env'], 'id': wid}) self.generator = Generator(self.env, self.args) @@ -41,27 +41,20 @@ def __del__(self): print('closed worker %d' % self.worker_id) def _gather_models(self, model_ids): - model_pool = {} for model_id in model_ids: - if model_id not in model_pool: - if model_id < 0: - model_pool[model_id] = None - elif model_id == self.latest_model[0]: - # use latest model - model_pool[model_id] = self.latest_model[1] - else: - # get model from server - model = pickle.loads(send_recv(self.conn, ('model', model_id))) - if model_id == 0: - # use random model - self.env.reset() - obs = self.env.observation(self.env.players()[0]) - model = RandomModel(model, obs) - model_pool[model_id] = ModelWrapper(model) - # update latest model - if model_id > self.latest_model[0]: - self.latest_model = model_id, model_pool[model_id] - return model_pool + if model_id is not None and model_id not in self.model_pool: + # get model from server + model = pickle.loads(send_recv(self.conn, ('model', model_id))) + if model_id == 0: + # use random model + self.env.reset() + obs = self.env.observation(self.env.players()[0]) + model = RandomModel(model, obs) + # update latest model + if len(self.model_pool) >= 1: + oldest_model_id = list(self.model_pool.keys())[0] + self.model_pool.pop(oldest_model_id) + self.model_pool[model_id] = ModelWrapper(model) def run(self): while True: @@ -73,11 +66,11 @@ def run(self): models = {} if 'model_id' in args: model_ids = list(args['model_id'].values()) - model_pool = self._gather_models(model_ids) + self._gather_models(model_ids) # make dict of models for p, model_id in args['model_id'].items(): - models[p] = model_pool[model_id] + models[p] = self.model_pool.get(model_id, None) if role == 'g': episode = self.generator.execute(models, args) From f4b74edb58a1e7619ed631fcea55b35ae2c02db2 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 9 Jan 2023 15:48:40 +0900 Subject: [PATCH 2/2] fix: for the case model_id < 0 --- handyrl/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handyrl/worker.py b/handyrl/worker.py index ebd676ab..97ba1a6c 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -42,7 +42,7 @@ def __del__(self): def _gather_models(self, model_ids): for model_id in model_ids: - if model_id is not None and model_id not in self.model_pool: + if model_id is not None and model_id >= 0 and model_id not in self.model_pool: # get model from server model = pickle.loads(send_recv(self.conn, ('model', model_id))) if model_id == 0: