diff --git a/scripts/aux_swa.py b/scripts/make_swa_model.py similarity index 50% rename from scripts/aux_swa.py rename to scripts/make_swa_model.py index 6c10a164..d017b423 100644 --- a/scripts/aux_swa.py +++ b/scripts/make_swa_model.py @@ -1,5 +1,5 @@ -# Usage: python3 script/aux_swa.py [FINAL_EPOCH] [EPOCHS] [EPOCH_STEP] +# Usage: python3 scripts/make_swa_model.py [FINAL_EPOCH] [EPOCHS] [EPOCH_STEP] import os import sys @@ -44,6 +44,28 @@ def _avg_fn(averaged_model_parameter, model_parameter, num_averaged): model.load_state_dict(torch.load(os.path.join(model_dir, model_id)), strict=True) swa_model.update_parameters(model) +# Update BN +# average the BatchNorm running stats over the source checkpoints +averaged_running_stats = {} +for model_id in model_ids: + model.load_state_dict(torch.load(os.path.join(model_dir, model_id)), strict=True) + for name, module in dict(model.named_modules()).items(): + if not issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): + continue + if name not in averaged_running_stats: + averaged_running_stats[name] = { + 'running_mean': torch.zeros_like(module.running_mean), + 'running_var': torch.zeros_like(module.running_var),} + averaged_running_stats[name]['running_mean'] += module.running_mean / len(model_ids) + averaged_running_stats[name]['running_var'] += module.running_var / len(model_ids) + +# set averaged running stats into SWA model +for name, module in dict(swa_model.module.named_modules()).items(): + if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): + module.running_mean.copy_(averaged_running_stats[name]['running_mean']) + module.running_var.copy_(averaged_running_stats[name]['running_var']) + +# Save SWA model torch.save(swa_model.module.state_dict(), saved_model_path) print('Saved %s' % saved_model_path)