From 65f4447a73d4916131a007cc546ef254f56052ac Mon Sep 17 00:00:00 2001 From: hyDONG <116695878+1want2sleep@users.noreply.github.com> Date: Fri, 29 Mar 2024 10:35:01 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PPSCI=20Doc=20No.38-40=E3=80=91=20(#82?= =?UTF-8?q?6)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ppsci.equation.PDE.parameters/state_dict/set_state_dict api fix * ppsci.equation.PDE.parameters/state_dict/set_state_dict api fix --------- Co-authored-by: krp <2934631798@qq.com> --- ppsci/equation/pde/base.py | 62 +++++++++++++++++++++++++++++++++++--- 1 file changed, 57 insertions(+), 5 deletions(-) diff --git a/ppsci/equation/pde/base.py b/ppsci/equation/pde/base.py index 41a0089de..9a0c0c53e 100644 --- a/ppsci/equation/pde/base.py +++ b/ppsci/equation/pde/base.py @@ -83,19 +83,71 @@ def add_equation(self, name: str, equation: Callable): self.equations.update({name: equation}) def parameters(self) -> List[paddle.Tensor]: - """Return parameters contained in PDE. + """Return learnable parameters contained in PDE. + + Args: + None Returns: - List[Tensor]: A list of parameters. + List[Tensor]: A list of learnable parameters. + + Examples: + >>> import ppsci + >>> pde = ppsci.equation.Vibration(2, -4, 0) + >>> print(pde.parameters()) + [Parameter containing: + Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=False, + -4.), Parameter containing: + Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=False, + 0.)] """ return self.learnable_parameters.parameters() def state_dict(self) -> Dict[str, paddle.Tensor]: - """Return named parameters in dict.""" + """Return named learnable parameters in dict. + + Args: + None + + Returns: + Dict[str, Tensor]: A dict of states(str) and learnable parameters(Tensor). + + Examples: + >>> import ppsci + >>> pde = ppsci.equation.Vibration(2, -4, 0) + >>> print(pde.state_dict()) + OrderedDict([('0', Parameter containing: + Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False, + -4.)), ('1', Parameter containing: + Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False, + 0.))]) + """ + return self.learnable_parameters.state_dict() - def set_state_dict(self, state_dict): - """Set state dict from dict.""" + def set_state_dict(self, state_dict: Dict[str, paddle.Tensor]): + """Set state dict from dict. + + Args: + state_dict (Dict[str, paddle.Tensor]): The state dict to be set. + + Returns: + None + + Examples: + >>> import paddle + >>> import ppsci + >>> paddle.set_default_dtype("float64") + >>> pde = ppsci.equation.Vibration(2, -4, 0) + >>> state = pde.state_dict() + >>> state['0'] = paddle.to_tensor(-3.1) + >>> pde.set_state_dict(state) + >>> print(state) + OrderedDict([('0', Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=True, + -3.10000000)), ('1', Parameter containing: + Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False, + 0.))]) + """ self.learnable_parameters.set_state_dict(state_dict) def __str__(self):