From e6c4e0dd1b46e399fc973f5de1151042ccf20341 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Tue, 2 Aug 2022 19:51:46 +0900 Subject: [PATCH 1/2] feature: config eval_rate -> eval_coef --- config.yaml | 2 +- handyrl/train.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/config.yaml b/config.yaml index 2bf65d85..4512b9e7 100755 --- a/config.yaml +++ b/config.yaml @@ -20,7 +20,7 @@ train_args: maximum_episodes: 100000 epochs: -1 num_batchers: 2 - eval_rate: 0.1 + eval_coef: 0.85 worker: num_parallel: 6 lambda: 0.7 diff --git a/handyrl/train.py b/handyrl/train.py index 75c0a31b..45cfddfb 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -403,8 +403,8 @@ def __init__(self, args, net=None, remote=False): random.seed(args['seed']) self.env = make_env(env_args) - eval_modify_rate = (args['update_episodes'] ** 0.85) / args['update_episodes'] - self.eval_rate = max(args['eval_rate'], eval_modify_rate) + self.eval_rate_fn = lambda n: (n ** self.args['eval_coef']) / n + self.eval_rate = self.eval_rate_fn(self.args['minimum_episodes'] + self.args['update_episodes']) self.shutdown_flag = False self.flags = set() @@ -525,6 +525,8 @@ def output_wp(name, results): model = self.model self.update_model(model, steps) + # update evaluation ratio + self.eval_rate = self.eval_rate_fn(self.args['update_episodes']) # clear flags self.flags = set() From 79243d22ad483f3a6ad18347dc7507daa9ffb585 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Wed, 3 Aug 2022 03:43:03 +0900 Subject: [PATCH 2/2] fix: reset num_results and num_episodes after epoch 0 --- handyrl/train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/handyrl/train.py b/handyrl/train.py index 45cfddfb..0e96f4fc 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -520,6 +520,10 @@ def output_wp(name, results): std = (r2 / (n + 1e-6) - mean ** 2) ** 0.5 print('generation stats = %.3f +- %.3f' % (mean, std)) + if self.model_epoch == 0: + self.num_episodes = 0 + self.num_results = 0 + model, steps
= self.trainer.update() if model is None: model = self.model