diff --git a/handyrl/train.py b/handyrl/train.py index 75c0a31b..59c9b82c 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -451,7 +451,8 @@ def feed_episodes(self, episodes): if episode is None: continue for p in episode['args']['player']: - model_id = episode['args']['model_id'][p] + #model_id = episode['args']['model_id'][p] + model_id = self.model_epoch outcome = episode['outcome'][p] n, r, r2 = self.generation_results.get(model_id, (0, 0, 0)) self.generation_results[model_id] = n + 1, r + outcome, r2 + outcome ** 2 @@ -479,7 +480,8 @@ def feed_results(self, results): if result is None: continue for p in result['args']['player']: - model_id = result['args']['model_id'][p] + #model_id = result['args']['model_id'][p] + model_id = self.model_epoch res = result['result'][p] n, r, r2 = self.results.get(model_id, (0, 0, 0)) self.results[model_id] = n + 1, r + res, r2 + res ** 2