Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Doc] Fix description err in aneurysm.md #619

Merged
merged 4 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/zh/examples/aneurysm.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ examples/aneurysm/aneurysm.py:149:157

### 3.5 超参数设定

接下来需要指定训练轮数和学习率,此处按实验经验,使用 1500 轮训练轮数。
接下来需要指定训练轮数和学习率,此处按实验经验,使用 1500 轮训练轮数,0.001 的初始学习率

``` py linenums="59"
--8<--
Expand All @@ -210,7 +210,7 @@ examples/aneurysm/conf/aneurysm.yaml:59:75

### 3.6 优化器构建

训练过程会调用优化器来更新模型参数,此处选择较为常用的 `Adam` 优化器,并配合使用机器学习中常用的 OneCycle 学习率调整策略。
训练过程会调用优化器来更新模型参数,此处选择较为常用的 `Adam` 优化器,并配合使用机器学习中常用的 ExponentialDecay 学习率调整策略。

``` py linenums="159"
--8<--
Expand Down
2 changes: 1 addition & 1 deletion ppsci/data/dataset/array_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __getitem__(self, idx):

if self.transforms is not None:
input_item, label_item, weight_item = self.transforms(
(input_item, label_item, weight_item)
input_item, label_item, weight_item
)

return (input_item, label_item, weight_item)
Expand Down
10 changes: 8 additions & 2 deletions ppsci/data/dataset/csv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __getitem__(self, idx):

if self.transforms is not None:
input_item, label_item, weight_item = self.transforms(
(input_item, label_item, weight_item)
input_item, label_item, weight_item
)

return (input_item, label_item, weight_item)
Expand Down Expand Up @@ -265,7 +265,13 @@ def num_samples(self):
return self._len

def __iter__(self):
yield self.input, self.label, self.weight
if callable(self.transforms):
input_, label_, weight_ = self.transforms(
self.input, self.label, self.weight
)
yield input_, label_, weight_
else:
yield self.input, self.label, self.weight

def __len__(self):
return 1
4 changes: 2 additions & 2 deletions ppsci/data/dataset/era5_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def __getitem__(self, global_idx):

if self.transforms is not None:
input_item, label_item, weight_item = self.transforms(
(input_item, label_item, weight_item)
input_item, label_item, weight_item
)

return input_item, label_item, weight_item
Expand Down Expand Up @@ -230,7 +230,7 @@ def __getitem__(self, global_idx):

if self.transforms is not None:
input_item, label_item, weight_item = self.transforms(
(input_item, label_item, weight_item)
input_item, label_item, weight_item
)

return input_item, label_item, weight_item
10 changes: 8 additions & 2 deletions ppsci/data/dataset/mat_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __getitem__(self, idx):

if self.transforms is not None:
input_item, label_item, weight_item = self.transforms(
(input_item, label_item, weight_item)
input_item, label_item, weight_item
)

return (input_item, label_item, weight_item)
Expand Down Expand Up @@ -265,7 +265,13 @@ def num_samples(self):
return self._len

def __iter__(self):
yield self.input, self.label, self.weight
if callable(self.transforms):
input_, label_, weight_ = self.transforms(
self.input, self.label, self.weight
)
yield input_, label_, weight_
else:
yield self.input, self.label, self.weight

def __len__(self):
return 1
10 changes: 8 additions & 2 deletions ppsci/data/dataset/npz_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __getitem__(self, idx):

if self.transforms is not None:
input_item, label_item, weight_item = self.transforms(
(input_item, label_item, weight_item)
input_item, label_item, weight_item
)

return (input_item, label_item, weight_item)
Expand Down Expand Up @@ -261,7 +261,13 @@ def num_samples(self):
return self._len

def __iter__(self):
yield self.input, self.label, self.weight
if callable(self.transforms):
input_, label_, weight_ = self.transforms(
self.input, self.label, self.weight
)
yield input_, label_, weight_
else:
yield self.input, self.label, self.weight

def __len__(self):
return 1