From aa79699748be90941db064c2c099c4527184adb3 Mon Sep 17 00:00:00 2001
From: ikki407
Date: Fri, 18 Mar 2022 00:45:56 +0900
Subject: [PATCH 1/2] feature: apply BN update for SWA model

---
 scripts/aux_swa.py        | 56 ------
 scripts/make_swa_model.py | 87 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 87 insertions(+), 56 deletions(-)
 delete mode 100644 scripts/aux_swa.py
 create mode 100644 scripts/make_swa_model.py

diff --git a/scripts/aux_swa.py b/scripts/aux_swa.py
deleted file mode 100644
index 6c10a164..00000000
--- a/scripts/aux_swa.py
+++ /dev/null
@@ -1,56 +0,0 @@
-
-# Usage: python3 script/aux_swa.py [FINAL_EPOCH] [EPOCHS] [EPOCH_STEP]
-
-import os
-import sys
-
-sys.path.append('./')
-
-import yaml
-
-import torch
-from torch.optim.swa_utils import AveragedModel
-
-from handyrl.environment import make_env
-
-
-#
-# SWA (running equal averaging)
-#
-
-model_dir = 'models'
-saved_model_path = os.path.join('models', 'swa.pth')
-
-ed, length = int(sys.argv[1]), int(sys.argv[2])
-step = 1
-if len(sys.argv) >= 4:
-    step = int(sys.argv[3])
-
-model_ids = [str(i) + '.pth' for i in range(ed - length + 1, ed + 1, step)]
-
-with open('config.yaml') as f:
-    args = yaml.safe_load(f)
-
-env = make_env(args['env_args'])
-model = env.net()
-model.load_state_dict(torch.load(os.path.join(model_dir, model_ids[0])), strict=True)
-
-def _avg_fn(averaged_model_parameter, model_parameter, num_averaged):
-    return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)
-
-swa_model = AveragedModel(model, avg_fn=_avg_fn)
-
-for model_id in model_ids:
-    model.load_state_dict(torch.load(os.path.join(model_dir, model_id)), strict=True)
-    swa_model.update_parameters(model)
-
-torch.save(swa_model.module.state_dict(), saved_model_path)
-
-print('Saved %s' % saved_model_path)
-
-#
-# Test (load in strict=True)
-#
-
-model = env.net()
-model.load_state_dict(torch.load(saved_model_path), strict=True)
diff --git a/scripts/make_swa_model.py b/scripts/make_swa_model.py
new file mode 100644
index 00000000..fff376c1
--- /dev/null
+++ b/scripts/make_swa_model.py
@@ -0,0 +1,87 @@
+
+# Usage: python3 script/make_swa_model.py [FINAL_EPOCH] [EPOCHS] [EPOCH_STEP]
+
+import os
+import sys
+
+sys.path.append('./')
+
+import yaml
+
+import torch
+from torch.optim.swa_utils import AveragedModel
+
+from handyrl.environment import make_env
+
+
+#
+# SWA (running equal averaging)
+#
+
+model_dir = 'models'
+saved_model_path = os.path.join('models', 'swa.pth')
+
+ed, length = int(sys.argv[1]), int(sys.argv[2])
+step = 1
+if len(sys.argv) >= 4:
+    step = int(sys.argv[3])
+
+model_ids = [str(i) + '.pth' for i in range(ed - length + 1, ed + 1, step)]
+
+with open('config.yaml') as f:
+    args = yaml.safe_load(f)
+
+env = make_env(args['env_args'])
+model = env.net()
+model.load_state_dict(torch.load(os.path.join(model_dir, model_ids[0])), strict=True)
+
+def _avg_fn(averaged_model_parameter, model_parameter, num_averaged):
+    return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)
+
+swa_model = AveragedModel(model, avg_fn=_avg_fn)
+
+for model_id in model_ids:
+    model.load_state_dict(torch.load(os.path.join(model_dir, model_id)), strict=True)
+    swa_model.update_parameters(model)
+
+# Update BN
+def _get_running_stats(module, running_stats):
+    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
+        if module not in running_stats:
+            running_stats[module] = {}
+        running_stats[module]['running_mean'] = module.running_mean
+        running_stats[module]['running_var'] = module.running_var
+
+# get averaged running stats from SWA model
+averaged_running_stats = {}
+for model_id in model_ids:
+    running_stats = {}
+    model.load_state_dict(torch.load(os.path.join(model_dir, model_id)), strict=True)
+    model.apply(lambda module: _get_running_stats(module, running_stats))
+    for name, module in dict(model.named_modules()).items():
+        if not issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
+            continue
+        if name not in averaged_running_stats:
+            averaged_running_stats[name] = {
+                'running_mean': torch.zeros_like(module.running_mean),
+                'running_var': torch.zeros_like(module.running_var),}
+        averaged_running_stats[name]['running_mean'] += module.running_mean / len(model_ids)
+        averaged_running_stats[name]['running_var'] += module.running_var / len(model_ids)
+
+# set averaged running stats into SWA model
+for name, module in dict(swa_model.module.named_modules()).items():
+    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
+        module.running_mean.copy_(averaged_running_stats[name]['running_mean'])
+        module.running_var.copy_(averaged_running_stats[name]['running_var'])
+
+# Save SWA model
+torch.save(swa_model.module.state_dict(), saved_model_path)
+
+print('Saved %s' % saved_model_path)
+
+#
+# Test (load in strict=True)
+#
+
+model = env.net()
+model.load_state_dict(torch.load(saved_model_path), strict=True)

From 7e28cec22167efb75ea20a8f904347a9207c4067 Mon Sep 17 00:00:00 2001
From: ikki407
Date: Fri, 18 Mar 2022 01:55:04 +0900
Subject: [PATCH 2/2] fix: remove unused function

---
 scripts/make_swa_model.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/scripts/make_swa_model.py b/scripts/make_swa_model.py
index fff376c1..d017b423 100644
--- a/scripts/make_swa_model.py
+++ b/scripts/make_swa_model.py
@@ -45,19 +45,10 @@ def _avg_fn(averaged_model_parameter, model_parameter, num_averaged):
     swa_model.update_parameters(model)
 
 # Update BN
-def _get_running_stats(module, running_stats):
-    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
-        if module not in running_stats:
-            running_stats[module] = {}
-        running_stats[module]['running_mean'] = module.running_mean
-        running_stats[module]['running_var'] = module.running_var
-
 # get averaged running stats from SWA model
 averaged_running_stats = {}
 for model_id in model_ids:
-    running_stats = {}
     model.load_state_dict(torch.load(os.path.join(model_dir, model_id)), strict=True)
-    model.apply(lambda module: _get_running_stats(module, running_stats))
     for name, module in dict(model.named_modules()).items():
         if not issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
             continue
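
Supplementary note (not part of either patch above): by default, AveragedModel averages parameters but not the BatchNorm running_mean/running_var buffers, which is why make_swa_model.py averages those buffers across the same checkpoints before saving; PyTorch's data-driven alternative, torch.optim.swa_utils.update_bn, would instead recompute the stats with a forward pass over a data loader, which this offline script does not have. The sketch below restates the averaging step in a standalone form, stripped of the HandyRL-specific loading; the function name average_bn_stats and its arguments are illustrative, not part of the patch.

import torch


def average_bn_stats(swa_model, checkpoint_state_dicts, template_model):
    """Average BN running stats over checkpoints into swa_model's buffers (in place)."""
    n = len(checkpoint_state_dicts)
    sums = {}
    for state_dict in checkpoint_state_dicts:
        # Load each checkpoint into a scratch model with the same architecture.
        template_model.load_state_dict(state_dict, strict=True)
        for name, module in template_model.named_modules():
            # _BatchNorm is the common base of BatchNorm1d/2d/3d; layers with
            # track_running_stats=False (running_mean is None) are skipped.
            if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.running_mean is not None:
                entry = sums.setdefault(name, {
                    'running_mean': torch.zeros_like(module.running_mean),
                    'running_var': torch.zeros_like(module.running_var),
                })
                entry['running_mean'] += module.running_mean / n
                entry['running_var'] += module.running_var / n
    # Copy the averaged statistics into the matching BN modules of the SWA model.
    for name, module in swa_model.named_modules():
        if name in sums:
            module.running_mean.copy_(sums[name]['running_mean'])
            module.running_var.copy_(sums[name]['running_var'])


# Hypothetical usage with the names from the script above (paths illustrative):
# state_dicts = [torch.load(os.path.join(model_dir, model_id)) for model_id in model_ids]
# average_bn_stats(swa_model.module, state_dicts, env.net())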