diff --git a/ppsci/equation/pde/base.py b/ppsci/equation/pde/base.py index 41a0089de..9a0c0c53e 100644 --- a/ppsci/equation/pde/base.py +++ b/ppsci/equation/pde/base.py @@ -83,19 +83,71 @@ def add_equation(self, name: str, equation: Callable): self.equations.update({name: equation}) def parameters(self) -> List[paddle.Tensor]: - """Return parameters contained in PDE. + """Return learnable parameters contained in PDE. + + Args: + None Returns: - List[Tensor]: A list of parameters. + List[Tensor]: A list of learnable parameters. + + Examples: + >>> import ppsci + >>> pde = ppsci.equation.Vibration(2, -4, 0) + >>> print(pde.parameters()) + [Parameter containing: + Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=False, + -4.), Parameter containing: + Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=False, + 0.)] """ return self.learnable_parameters.parameters() def state_dict(self) -> Dict[str, paddle.Tensor]: - """Return named parameters in dict.""" + """Return named learnable parameters in dict. + + Args: + None + + Returns: + Dict[str, Tensor]: A dict of states(str) and learnable parameters(Tensor). + + Examples: + >>> import ppsci + >>> pde = ppsci.equation.Vibration(2, -4, 0) + >>> print(pde.state_dict()) + OrderedDict([('0', Parameter containing: + Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False, + -4.)), ('1', Parameter containing: + Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False, + 0.))]) + """ + return self.learnable_parameters.state_dict() - def set_state_dict(self, state_dict): - """Set state dict from dict.""" + def set_state_dict(self, state_dict: Dict[str, paddle.Tensor]): + """Set state dict from dict. + + Args: + state_dict (Dict[str, paddle.Tensor]): The state dict to be set. + + Returns: + None + + Examples: + >>> import paddle + >>> import ppsci + >>> paddle.set_default_dtype("float64") + >>> pde = ppsci.equation.Vibration(2, -4, 0) + >>> state = pde.state_dict() + >>> state['0'] = paddle.to_tensor(-3.1) + >>> pde.set_state_dict(state) + >>> print(state) + OrderedDict([('0', Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=True, + -3.10000000)), ('1', Parameter containing: + Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False, + 0.))]) + """ self.learnable_parameters.set_state_dict(state_dict) def __str__(self):