Skip to content

Commit

Permalink
[Fix] Fix weight dtype and bug in symbolic (PaddlePaddle#709)
Browse files Browse the repository at this point in the history
* fix dtype of weight to avoid unnecessary type promotion

* fix bug in symbolic
  • Loading branch information
HydrogenSulfate committed Dec 21, 2023
1 parent 77873b2 commit d6382c2
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 9 deletions.
4 changes: 2 additions & 2 deletions examples/hpinns/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def init_lambda(output_dict: Dict[str, paddle.Tensor], bound: int):
"""
global lambda_re, lambda_im, loss_weight
x, y = output_dict["x"], output_dict["y"]
lambda_re = np.zeros((len(x[bound:]), 1))
lambda_im = np.zeros((len(y[bound:]), 1))
lambda_re = np.zeros((len(x[bound:]), 1), paddle.get_default_dtype())
lambda_im = np.zeros((len(y[bound:]), 1), paddle.get_default_dtype())
# loss_weight: [PDE loss 1, PDE loss 2, Lagrangian loss 1, Lagrangian loss 2, objective loss]
if train_mode == "aug_lag":
loss_weight = [0.5 * mu] * 2 + [1.0, 1.0] + [1.0]
Expand Down
5 changes: 4 additions & 1 deletion ppsci/data/dataset/array_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ def __init__(
self.input_keys = tuple(input.keys())
self.label_keys = tuple(self.label.keys())
self.weight = (
{key: paddle.to_tensor(value) for key, value in weight.items()}
{
key: paddle.to_tensor(value, paddle.get_default_dtype())
for key, value in weight.items()
}
if weight is not None
else None
)
Expand Down
6 changes: 4 additions & 2 deletions ppsci/data/dataset/trphysx_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ def __getitem__(self, i):

weight_shape = [1] * len(data_item.shape)
weight_item = {
key: np.full(weight_shape, value) for key, value in self.weight_dict.items()
key: np.full(weight_shape, value, paddle.get_default_dtype())
for key, value in self.weight_dict.items()
}
return (input_item, label_item, weight_item)

Expand Down Expand Up @@ -307,6 +308,7 @@ def __getitem__(self, i):
label_item[self.label_keys[1]] = data_item[1:, :]
weight_shape = [1] * len(data_item.shape)
weight_item = {
key: np.full(weight_shape, value) for key, value in self.weight_dict.items()
key: np.full(weight_shape, value, paddle.get_default_dtype())
for key, value in self.weight_dict.items()
}
return (input_item, label_item, weight_item)
8 changes: 4 additions & 4 deletions ppsci/utils/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ def _minimum_operator_func(self, data_dict: DATA_DICT) -> DATA_DICT:
)
for i in range(2, len(self.childs)):
data_dict[self.key] = paddle.minimum(
data_dict[data_dict[self.key]],
data_dict[data_dict[self.childs[i]]],
data_dict[self.key],
data_dict[self.childs[i]],
)
return data_dict

Expand All @@ -267,8 +267,8 @@ def _maximum_operator_func(self, data_dict: DATA_DICT) -> DATA_DICT:
)
for i in range(2, len(self.childs)):
data_dict[self.key] = paddle.maximum(
data_dict[data_dict[self.key]],
data_dict[data_dict[self.childs[i]]],
data_dict[self.key],
data_dict[self.childs[i]],
)
return data_dict

Expand Down

0 comments on commit d6382c2

Please sign in to comment.