Skip to content

Commit

Permalink
【PPSCI Doc No.38-40】 (#826)
Browse files Browse the repository at this point in the history
* ppsci.equation.PDE.parameters/state_dict/set_state_dict api fix

* ppsci.equation.PDE.parameters/state_dict/set_state_dict api fix

---------

Co-authored-by: krp <2934631798@qq.com>
  • Loading branch information
1want2sleep and UnityLiker committed Mar 29, 2024
1 parent 166caf5 commit 65f4447
Showing 1 changed file with 57 additions and 5 deletions.
62 changes: 57 additions & 5 deletions ppsci/equation/pde/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,71 @@ def add_equation(self, name: str, equation: Callable):
self.equations.update({name: equation})

def parameters(self) -> List[paddle.Tensor]:
"""Return parameters contained in PDE.
"""Return learnable parameters contained in PDE.
Args:
None
Returns:
List[Tensor]: A list of parameters.
List[Tensor]: A list of learnable parameters.
Examples:
>>> import ppsci
>>> pde = ppsci.equation.Vibration(2, -4, 0)
>>> print(pde.parameters())
[Parameter containing:
Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=False,
-4.), Parameter containing:
Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=False,
0.)]
"""
return self.learnable_parameters.parameters()

def state_dict(self) -> Dict[str, paddle.Tensor]:
"""Return named parameters in dict."""
"""Return named learnable parameters in dict.
Args:
None
Returns:
Dict[str, Tensor]: A dict of states(str) and learnable parameters(Tensor).
Examples:
>>> import ppsci
>>> pde = ppsci.equation.Vibration(2, -4, 0)
>>> print(pde.state_dict())
OrderedDict([('0', Parameter containing:
Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False,
-4.)), ('1', Parameter containing:
Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False,
0.))])
"""

return self.learnable_parameters.state_dict()

def set_state_dict(self, state_dict):
"""Set state dict from dict."""
def set_state_dict(self, state_dict: Dict[str, paddle.Tensor]):
"""Set state dict from dict.
Args:
state_dict (Dict[str, paddle.Tensor]): The state dict to be set.
Returns:
None
Examples:
>>> import paddle
>>> import ppsci
>>> paddle.set_default_dtype("float64")
>>> pde = ppsci.equation.Vibration(2, -4, 0)
>>> state = pde.state_dict()
>>> state['0'] = paddle.to_tensor(-3.1)
>>> pde.set_state_dict(state)
>>> print(state)
OrderedDict([('0', Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=True,
-3.10000000)), ('1', Parameter containing:
Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False,
0.))])
"""
self.learnable_parameters.set_state_dict(state_dict)

def __str__(self):
Expand Down

0 comments on commit 65f4447

Please sign in to comment.