
Implement MAML meta-opt #23

Merged · 18 commits · May 9, 2022

Conversation

@wj-Mcat wj-Mcat commented Apr 14, 2022

Description

Try to refactor the MAML meta-learning algorithm to make it more reusable in Paddle-based applications.

Consideration

In the MAML module, there are three things that differ from ordinary model-training code (sketched in the example after this list):

  • clone the module while retaining the computation graph
  • accumulate the gradient on the cloned model
  • backpropagate the gradient from the query-set data
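
For concreteness, here is a minimal sketch of the meta-step these three points describe, written against the public paddle API on a single hand-rolled linear layer; the shapes, inner learning rate, and random data are arbitrary placeholders, and this is not the code proposed in this PR:

import paddle
import paddle.nn.functional as F

# Illustrative single MAML meta-step on a hand-rolled linear classifier.
w = paddle.randn([784, 5])
b = paddle.zeros([5])
w.stop_gradient = False
b.stop_gradient = False
meta_opt = paddle.optimizer.Adam(learning_rate=1e-3, parameters=[w, b])
inner_lr = 0.4

support_x, support_y = paddle.randn([25, 784]), paddle.randint(0, 5, [25])
query_x, query_y = paddle.randn([75, 784]), paddle.randint(0, 5, [75])

# (1) keep the computation graph: gradients of the support loss are taken
#     with create_graph=True so the outer loss can differentiate through them
support_loss = F.cross_entropy(paddle.matmul(support_x, w) + b, support_y)
gw, gb = paddle.grad(support_loss, [w, b], create_graph=True)

# (2) "cloned" fast weights: the source weights after one inner SGD step
fw, fb = w - inner_lr * gw, b - inner_lr * gb

# (3) backpropagate the query-set loss through the fast weights back to the
#     source weights, then let the outer optimizer update them
query_loss = F.cross_entropy(paddle.matmul(query_x, fw) + fb, query_y)
query_loss.backward()
meta_opt.step()
meta_opt.clear_grad()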

Design

from abc import ABC, abstractmethod

from paddle import Tensor
from paddle.nn import Layer
from paddle.optimizer import Optimizer


class BaseLearner(ABC):
    """Abstract base learner class."""

    def __init__(self, module: Layer, optimizer: Optimizer) -> None:
        """The constructor of BaseLearner.

        Args:
            module (Layer): the model to be trained
            optimizer (Optimizer): the optimizer that updates the source model
        """
        super().__init__()
        self._source_module = module
        self.cloned_module = None
        self.optimizer = optimizer

    def new_cloned_model(self) -> Layer:
        """Get the cloned model and keep the computation graph.

        Returns:
            Layer: the cloned model
        """
        # clone_model is a project helper that copies the module while
        # retaining the computation graph between the source and the clone.
        self.cloned_module = clone_model(self._source_module)
        return self.cloned_module

    @abstractmethod
    def adapt(self, train_loss: Tensor) -> None:
        """Adapt the cloned model to the current training loss.

        Args:
            train_loss (Tensor): the current training loss
        """
        raise NotImplementedError

    @abstractmethod
    def step(self) -> None:
        """Run an optimizer step on the parameters of the source model."""
        raise NotImplementedError

    def clear_grad(self):
        """Clear the gradients tracked by the optimizer."""
        self.optimizer.clear_grad()
  • adapt: accumulate the gradient on the cloned model
  • new_cloned_model: clone the model and keep a reference to the clone
  • step: run an optimizer step on the parameters of the source model
  • clear_grad: clear the accumulated gradients (usage of these methods is sketched below)
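
A hedged usage sketch of how an outer loop might drive this interface; MAMLLearner, loss_fn, model, meta_optimizer, and task_loader are illustrative placeholders, not code from this PR:

# Hypothetical outer loop built on the BaseLearner interface above.
learner = MAMLLearner(module=model, optimizer=meta_optimizer)

for (support_x, support_y), (query_x, query_y) in task_loader:
    cloned = learner.new_cloned_model()        # clone and keep the computation graph
    support_loss = loss_fn(cloned(support_x), support_y)
    learner.adapt(support_loss)                # accumulate gradients on the clone
    query_loss = loss_fn(cloned(query_x), query_y)
    query_loss.backward()                      # backpropagate through the cloned graph
    learner.step()                             # optimizer step on the source model
    learner.clear_grad()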

wj-Mcat commented Apr 14, 2022

Trying to accomplish the task in #17.

wj-Mcat commented Apr 19, 2022

Changes in Learner Structure

class BaseLearner(Layer):
    """Abstract Base Learner Class"""
    def __init__(self, module: Layer) -> None:
        """The constructor of BaseLearner

        Args:
            module (Layer): the model to be trained
        """
        super().__init__()
        self.module = module

    @abstractmethod
    def adapt(self, loss: Tensor) -> None:
        """Adapt the model to the current training loss

        Args:
            loss (Tensor): the current training loss
        """
        raise NotImplementedError

-   def new_cloned_model(self) -> Layer:
+   def clone(self: Type[Learner]) -> Learner:
        """Create the cloned module and keep the computation graph.

        Args:
            self (Type[Learner]): the sub-learner

        Returns:
            Learner: the cloned model
        """
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

-   @abstractmethod
-   def step(self) -> None:
-        """Perform a step of training
-
-        Args: 
-           loss (float): _description_
-
-        Raises:
-            NotImplementedError: _description_
-        """
-       raise NotImplementedError
-
-   def clear_grad(self):
-        """clear the gradient in the computation graph
-        """
-        self.optimizer.clear_grad()

As the code above shows, there are two main changes (an outer-loop sketch follows this list):

  • rename new_cloned_model to clone, so the meta-learning algorithm can be reused.
  • remove the optimizer from the learner; the outer loop handles it instead.
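
To illustrate the second point, here is a hedged sketch of what the outer loop might look like once the optimizer is moved out of the learner; MAMLLearner, model, loss_fn, and task_loader are illustrative placeholders, not code from this PR:

import paddle

# Hypothetical outer loop under the revised design: the learner only exposes
# clone()/adapt()/forward(), while the optimizer lives in the outer loop.
meta_learner = MAMLLearner(module=model)
meta_opt = paddle.optimizer.Adam(learning_rate=1e-3,
                                 parameters=meta_learner.parameters())

for (support_x, support_y), (query_x, query_y) in task_loader:
    fast_learner = meta_learner.clone()        # per-task copy, graph preserved
    support_loss = loss_fn(fast_learner(support_x), support_y)
    fast_learner.adapt(support_loss)           # inner-loop update on the clone
    query_loss = loss_fn(fast_learner(query_x), query_y)
    query_loss.backward()                      # gradients flow back to the source module
    meta_opt.step()                            # the outer loop now owns the optimizer
    meta_opt.clear_grad()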

@tata1661 (Owner) left a comment

@wj-Mcat Can you provide a detailed README.md to compare your empirical results with existing ones provided by PaddleFSL? You can put it in examples/optim. Thanks for the contribution!

wj-Mcat commented Apr 22, 2022

Sorry, I took a few days off. I will run the experiments and get the empirical results in the next few days.

wj-Mcat commented May 9, 2022

To verify the effectiveness of the optim method, I ran experiments with both the model-zoo implementation and the optim implementation. Below are my conclusions:

Omniglot - MAML

  • model-zoo: [result screenshot]
  • optim: [result screenshot]

Omniglot - ANIL

  • model-zoo: [result screenshot]
  • optim: [result screenshot]

MiniImagenet - ANIL

  • model-zoo: [result screenshot]
  • optim: [result screenshot]

CIFAR-FS - MAML

  • model-zoo: [result screenshot]
  • optim: [result screenshot]

wj-Mcat commented May 9, 2022

Metric Overview

Dataset      | Algo | Model zoo (first order) | Optim (first order)
Omniglot     | MAML | 97.25 ± 1.7             | 97.07 ± 2.4
Omniglot     | ANIL | 93.62 ± 2.08            | 94.80 ± 3.7
MiniImageNet | ANIL | 52.56 ± 3.5             | 57.50 ± 3.2
CIFAR-FS     | MAML | 46.88 ± 3.4             | 49.44 ± 4.7

@tata1661 tata1661 merged commit 4e0dae5 into tata1661:master May 9, 2022
wj-Mcat commented May 9, 2022

Thanks for merging. I will try to fix #28 in another PR.
