diff --git a/config.yaml b/config.yaml index 2bf65d85..4512b9e7 100755 --- a/config.yaml +++ b/config.yaml @@ -20,7 +20,7 @@ train_args: maximum_episodes: 100000 epochs: -1 num_batchers: 2 - eval_rate: 0.1 + eval_coef: 0.85 worker: num_parallel: 6 lambda: 0.7 diff --git a/handyrl/train.py b/handyrl/train.py index 75c0a31b..0e96f4fc 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -403,8 +403,8 @@ def __init__(self, args, net=None, remote=False): random.seed(args['seed']) self.env = make_env(env_args) - eval_modify_rate = (args['update_episodes'] ** 0.85) / args['update_episodes'] - self.eval_rate = max(args['eval_rate'], eval_modify_rate) + self.eval_rate_fn = lambda n: (n ** self.args['eval_coef']) / n + self.eval_rate = self.eval_rate_fn(self.args['minimum_episodes'] + self.args['update_episodes']) self.shutdown_flag = False self.flags = set() @@ -520,11 +520,17 @@ def output_wp(name, results): std = (r2 / (n + 1e-6) - mean ** 2) ** 0.5 print('generation stats = %.3f +- %.3f' % (mean, std)) + if self.model_epoch == 0: + self.num_episodes = 0 + self.num_results = 0 + model, steps = self.trainer.update() if model is None: model = self.model self.update_model(model, steps) + # update evaluation ratio + self.eval_rate = self.eval_rate_fn(self.args['update_episodes']) # clear flags self.flags = set()