Missing Support for BatchNorm and AdaptiveAvgPool in HBP methods (KFAC, KFRA, KFLR) and GGNMP #322

Open
satnavpt opened this issue Feb 29, 2024 · 8 comments

Comments

@satnavpt

I'm unsure whether this is similar to a previous issue where a link was simply missing, or whether there is a more fundamental reason why these modules, and others that the documentation claims are supported, are not actually supported.

@f-dangel
Owner

f-dangel commented Mar 1, 2024

Hi, thanks for bringing up this documentation inconsistency. You are right that the docs claim support too broadly, although it is true that most of the second-order extensions support the mentioned layers.

If you would like to see support for a specific layer and quantity that is currently missing, please feel free to specify.

@satnavpt
Author

satnavpt commented Mar 1, 2024

Hi, thanks for getting back so quickly. I am looking to experiment with KFLR and a conjugate gradient optimiser (using GGNMP) on a ResNet-18 model. I am fine running in eval mode, but neither extension currently has definitions for BatchNorm and AdaptiveAvgPool, so an error is thrown. Support for these would be greatly appreciated.

@satnavpt
Author

satnavpt commented Mar 1, 2024

I tried making some changes manually, but the issue seems to go beyond missing links to the module extensions...

@f-dangel
Owner

f-dangel commented Mar 1, 2024

Thanks for boiling things down. I think what you are requesting requires a lot of new additions to BackPACK. They are not impossible to implement, but you would mostly be on your own implementing them.

  • For ResNets, BackPACK converts the computation graph such that all operations are nn.Modules. This means you will need to add support for BackPACK's custom SumModule. For KFLR, the required functionality is already implemented, so the only thing you would have to do is write an HBPSumModule extension that uses the SumModuleDerivatives from the core package. For GGNMP, you will also have to implement new core functionality, namely SumModuleDerivatives._jac_mat_prod, then write the associated module extension.
  • For AdaptiveAvgPool, I verified that the core functionality is already there, so all you would have to do to add support is write module extensions for GGNMP and KFLR that use the AdaptiveAvgPoolDerivatives from the core.
  • For BatchNorm2d, the situation is similar to AdaptiveAvgPool in that most of the low-level functionality is already implemented and you have to write the corresponding module extensions. Since BatchNorm2d also has trainable parameters, you will have to specify how to compute KFLR for .weight and .bias yourself.
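The BatchNorm2d point above can be made concrete for the eval-mode setting mentioned earlier. A minimal sketch, assuming eval mode so that BatchNorm2d uses its running statistics: the layer then computes y = weight * x_hat + bias per channel, with x_hat = (x - running_mean) / sqrt(running_var + eps), so transposed-Jacobian-vector products with respect to .weight and .bias reduce to sums over the batch and spatial dimensions. The helper bn_param_jac_t_vec is hypothetical, not part of BackPACK, and the sketch only checks these products against autograd; it does not settle how to Kronecker-factor the resulting curvature blocks for KFLR.

```python
import torch

torch.manual_seed(0)
bn = torch.nn.BatchNorm2d(3).double()
with torch.no_grad():
    # Non-trivial running statistics and affine parameters for the check
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
bn.eval()  # use running statistics, so the layer is a per-channel affine map


def bn_param_jac_t_vec(bn, x, v):
    """Hypothetical helper: J^T v w.r.t. BatchNorm2d .weight and .bias in eval mode."""
    mean = bn.running_mean.view(1, -1, 1, 1)
    var = bn.running_var.view(1, -1, 1, 1)
    x_hat = (x - mean) / torch.sqrt(var + bn.eps)
    # y = weight * x_hat + bias, so dy/dweight_c = x_hat[:, c] and dy/dbias_c = 1;
    # contract with v and sum over the batch and spatial dimensions.
    jac_t_weight = (x_hat * v).sum(dim=(0, 2, 3))
    jac_t_bias = v.sum(dim=(0, 2, 3))
    return jac_t_weight, jac_t_bias


x = torch.randn(4, 3, 5, 5, dtype=torch.float64)
v = torch.randn(4, 3, 5, 5, dtype=torch.float64)  # a vector in output space
manual = bn_param_jac_t_vec(bn, x, v)
reference = torch.autograd.grad(bn(x), (bn.weight, bn.bias), grad_outputs=v)
print(all(torch.allclose(m, r) for m, r in zip(manual, reference)))
```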

@satnavpt
Author

satnavpt commented Mar 4, 2024

I've made the following changes to get GGNMP working with the SumModule, but it seems my gradients are vanishing:

Accumulation for GGNMP:

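```python
from typing import Callable
def accumulate_backpropagated_quantities(self, existing: Callable, other: Callable) -> Callable:
    # Combine two backpropagated GGN-matrix-product closures by summing their results
    return lambda mat: existing(mat) + other(mat)
```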
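The summation rule can be examined in isolation. The sketch below uses plain PyTorch, not BackPACK: random matrices stand in for the Jacobians J1 and J2 of two branches that share an input and are added by a SumModule, and H for a symmetric positive semi-definite loss Hessian with respect to the sum's output. Assuming each branch passes back a closure of the form M -> J_i^T H J_i M, summing the closures gives J1^T H J1 + J2^T H J2, whereas the GGN through the sum is (J1 + J2)^T H (J1 + J2); the two differ by the cross terms J1^T H J2 + J2^T H J1.

```python
import torch

torch.manual_seed(0)
kw = dict(dtype=torch.float64)
d_out, d_in, n_cols = 4, 3, 2
A = torch.randn(d_out, d_out, **kw)
H = A @ A.T  # symmetric PSD stand-in for the loss Hessian at the sum's output
J1 = torch.randn(d_out, d_in, **kw)  # Jacobian of branch 1 w.r.t. the shared input
J2 = torch.randn(d_out, d_in, **kw)  # Jacobian of branch 2 w.r.t. the shared input


def branch_closure(J):
    """GGN-matrix product M -> J^T H J M for a map with Jacobian J."""
    return lambda mat: J.T @ (H @ (J @ mat))


def accumulate(existing, other):
    """Same summation rule as the proposed accumulate_backpropagated_quantities."""
    return lambda mat: existing(mat) + other(mat)


summed = accumulate(branch_closure(J1), branch_closure(J2))
exact = branch_closure(J1 + J2)  # Jacobian of branch1(x) + branch2(x) is J1 + J2
cross = J1.T @ H @ J2 + J2.T @ H @ J1

M = torch.randn(d_in, n_cols, **kw)
print(torch.allclose(summed(M), exact(M)))  # False: the cross terms are missing
print(torch.allclose(summed(M) + cross @ M, exact(M)))  # True
```

If the closures are indeed propagated this way, the summed closure would omit the cross terms wherever branches merge, as in residual blocks. Quantities backpropagated as square-root factors S do not have this issue, since J1^T S + J2^T S = (J1 + J2)^T S.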

SumModule extension for GGNMP:

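```python
from backpack.core.derivatives.sum_module import SumModuleDerivatives
from backpack.extensions.curvmatprod.ggnmp.ggnmpbase import GGNMPBase


class GGNMPSumModule(GGNMPBase):
    """GGNMP extension for SumModule."""

    def __init__(self):
        """Initialization."""
        super().__init__(derivatives=SumModuleDerivatives())
```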

SumModuleDerivatives._jac_mat_prod:

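```python
from typing import List, Tuple

from torch import Tensor

from backpack.custom_module.branching import SumModule

# Method of SumModuleDerivatives
def _jac_mat_prod(
    self,
    module: SumModule,
    g_inp: Tuple[Tensor],
    g_out: Tuple[Tensor],
    mat: Tensor,
    subsampling: List[int] = None,
) -> Tensor:
    # The output is the sum of the inputs, so the Jacobian w.r.t. each input is the identity
    return mat
```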

Are you able to help?
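The identity returned by this _jac_mat_prod can be sanity-checked without BackPACK: for y = x1 + x2, the Jacobian with respect to either input is the identity, so applying it to a stack of vectors leaves the stack unchanged. The sketch below assumes the vectors are stacked along a leading dimension V and uses torch.autograd.functional.jvp with the second input held fixed; sum_module_forward is a name chosen for illustration.

```python
import torch
from torch.autograd.functional import jvp

torch.manual_seed(0)
x1 = torch.randn(2, 3, 4, 4)
x2 = torch.randn(2, 3, 4, 4)
mat = torch.randn(5, 2, 3, 4, 4)  # V = 5 vectors stacked along a leading dimension


def sum_module_forward(a):
    """Output of a two-input sum, viewed as a function of the first input only."""
    return a + x2


# Jacobian-vector product for each of the V stacked vectors
jac_mat = torch.stack([jvp(sum_module_forward, x1, v)[1] for v in mat])
print(torch.allclose(jac_mat, mat))
```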

@f-dangel
Owner

f-dangel commented Mar 4, 2024

Hi, thanks for the update.
Your changes look good to me. What is the error you're currently seeing? I don't understand what you mean by 'gradients are vanishing'.

Best,
Felix

@satnavpt
Author

satnavpt commented Mar 4, 2024

I am using GGNMP alongside the implementation of a conjugate gradient optimiser provided as an example here. When I print the gradients (p.grad) at each optimisation step, they become all zeros after a number of steps. This may be due to the modified ResNet I am testing (I have disabled average pooling and BatchNorm for the time being, since I only wanted to check whether the SumModule implementation was correct).

Thanks,
Pranav
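A hedged sketch for automating the gradient check described above: zero_grad_report is a helper name chosen here, not part of BackPACK or the CG example. Called after each backward pass, it lists parameters whose .grad is missing or identically zero, which narrows down where gradients first disappear. It only inspects .grad, so on its own it cannot tell whether GGNMP is involved.

```python
import torch


def zero_grad_report(model: torch.nn.Module):
    """Hypothetical helper: names of parameters whose .grad is missing or all zeros."""
    missing, zero = [], []
    for name, p in model.named_parameters():
        if p.grad is None:
            missing.append(name)
        elif torch.count_nonzero(p.grad).item() == 0:
            zero.append(name)
    return missing, zero


torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.Tanh(), torch.nn.Linear(3, 1))
model(torch.randn(5, 4)).sum().backward()
model[2].weight.grad.zero_()  # simulate a gradient that has vanished
missing, zero = zero_grad_report(model)
print(missing, zero)  # [] ['2.weight']
```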

@f-dangel
Owner

f-dangel commented Mar 4, 2024

I'm not sure debugging the correctness of GGNMP through the CG optimizer is the most direct approach.
You could try comparing GGNMP with BackPACK's hessianfree.ggn_vector_product_from_plist to see whether the matrix-vector product with the GGN works properly. There could be another effect, unrelated to GGNMP, that is giving you zero gradients.
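The product such a reference computes is G v = J^T H J v, where J is the Jacobian of the model output with respect to the parameters and H the Hessian of the loss with respect to the output. A minimal pure-PyTorch sketch of it (ggn_vector_product is a name chosen here, not a BackPACK function) chains a Jacobian-vector product, a Hessian-vector product with respect to the model output, and a vector-Jacobian product. The usage below checks it on a model that is linear in its parameters, where the GGN coincides with the Hessian; its output for a parameter list could then be compared against what GGNMP returns for the same vector.

```python
import torch


def ggn_vector_product(loss, output, params, vec):
    """Reference G v = J^T H J v, with J = d output / d params and H = d^2 loss / d output^2."""
    # J v via the double-vjp trick: g = J^T u is linear in the dummy u, so d(g . v)/du = J v
    u = torch.zeros_like(output, requires_grad=True)
    g = torch.autograd.grad(output, params, grad_outputs=u, create_graph=True)
    Jv = torch.autograd.grad(g, u, grad_outputs=vec, retain_graph=True)[0]
    # H (J v): Hessian-vector product of the loss w.r.t. the model output
    grad_out = torch.autograd.grad(loss, output, create_graph=True)[0]
    HJv = torch.autograd.grad(grad_out, output, grad_outputs=Jv, retain_graph=True)[0]
    # J^T (H J v): vector-Jacobian product back to the parameters
    return torch.autograd.grad(output, params, grad_outputs=HJv, retain_graph=True)


torch.manual_seed(0)
model = torch.nn.Linear(3, 2).double()
X = torch.randn(8, 3, dtype=torch.float64)
Y = torch.randn(8, 2, dtype=torch.float64)
output = model(X)
loss = torch.nn.functional.mse_loss(output, Y)
params = list(model.parameters())
vec = [torch.randn_like(p) for p in params]

Gv = ggn_vector_product(loss, output, params, vec)
# The model is linear in its parameters, so here the GGN equals the Hessian
grads = torch.autograd.grad(loss, params, create_graph=True)
Hv = torch.autograd.grad(grads, params, grad_outputs=vec)
print(all(torch.allclose(a, b) for a, b in zip(Gv, Hv)))
```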
