[AutoParallel] add fetch_list in engine api #43312

Merged

python/paddle/distributed/auto_parallel/engine.py (74 changes: 49 additions & 25 deletions)

@@ -28,7 +28,7 @@
 from paddle.fluid.layers.utils import flatten
 from paddle.fluid.executor import global_scope
 from paddle.fluid.backward import append_backward
-from paddle.fluid.framework import Operator
+from paddle.fluid.framework import Operator, Variable
 from paddle.fluid.framework import _current_expected_place as _get_device
 from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.distributed import fleet
@@ -256,6 +256,7 @@ def fit(self,
             train_data,
             batch_size=1,
             epochs=1,
+            fetch_list=None,
             steps_per_epoch=None,
             use_program_cache=False,
             return_numpy=True):
@@ -266,13 +267,14 @@ def fit(self,
"train model is not ready, please call `engine.prepare()` first."
train_dataloader = self._create_dataloader(train_data, batch_size,
epochs, steps_per_epoch)
self._usr_fetch_list = fetch_list

outputs = []
for epoch in range(epochs):
for step, data in enumerate(train_dataloader):
logs, loss = self._train_step(data, use_program_cache,
logs, outs = self._train_step(data, use_program_cache,
return_numpy)
outputs.append(loss)
outputs.append(outs)
train_logs = {
"train_" + name: val
for name, val in logs.items()
@@ -283,94 +285,116 @@
     def evaluate(self,
                  eval_data,
                  batch_size=1,
+                 fetch_list=None,
                  use_program_cache=False,
                  return_numpy=True):
         self.mode = 'eval'
         assert self.mode in self._dist_main_progs, \
             "eval model is not ready, please call `engine.prepare()` first."
         eval_dataloader = self._create_dataloader(eval_data, batch_size)
+        self._usr_fetch_list = fetch_list

         for step, data in enumerate(eval_dataloader):
             eval_logs = dict()
-            outs = self._eval_step(data, use_program_cache, return_numpy)
+            logs, outs = self._eval_step(data, use_program_cache, return_numpy)
             eval_logs["eval_loss"] = outs[0] if len(outs) > 0 else []
             for metric in self._metrics:
                 results = metric.accumulate()
                 for i, res in enumerate(to_list(results)):
                     eval_logs["eval_" + metric.name()[i]] = res
+            for name, val in logs.items():
+                eval_logs["eval_" + name] = val
             self._logger.info(eval_logs)
         return eval_logs

     def predict(self,
                 test_data,
                 batch_size=1,
+                fetch_list=None,
                 use_program_cache=False,
                 return_numpy=True):
         self.mode = 'predict'
         assert self.mode in self._dist_main_progs, \
             "predict model is not ready, please call `engine.prepare()` first."
         test_dataloader = self._create_dataloader(test_data, batch_size)
+        self._usr_fetch_list = fetch_list

         outputs = []
         for step, data in enumerate(test_dataloader):
             logs, outs = self._predict_step(data, use_program_cache,
                                             return_numpy)
             outputs.append(outs)
-            predict_logs = {
-                "predict_" + name: val
-                for name, val in logs.items()
-            }
+            predict_logs = {"pred_" + name: val for name, val in logs.items()}
             self._logger.info(predict_logs)
         return outputs

     def _train_step(self, data, use_program_cache=False, return_numpy=True):
         logs = {}
         fetch_vars = self._fetch_vars[self.mode]["loss"]
-        fetch_list = self._fetch_list(fetch_vars)
+        fetch_list, usr_fetch_list = self._fetch_list(fetch_vars)
+        fetch_list += usr_fetch_list

-        loss = self._executor.run(self.main_program,
+        outs = self._executor.run(self.main_program,
                                   fetch_list=fetch_list,
                                   use_program_cache=use_program_cache,
                                   return_numpy=return_numpy)
-        logs["loss"] = loss
-        return logs, loss
+        for i, out in enumerate(outs):
+            logs[fetch_list[i]] = out
+        return logs, outs

     def _eval_step(self, data, use_program_cache=False, return_numpy=True):
         logs = {}
         metrics = self._fetch_vars[self.mode]["metrics"]
         losses = self._fetch_vars[self.mode]["loss"]
-        fetch_loss = self._fetch_list(losses)
-        fetch_metrics = self._fetch_list(metrics)
+        fetch_loss, usr_fetch_list = self._fetch_list(losses)
+        fetch_metrics, usr_fetch_list = self._fetch_list(metrics)
         fetch_list = fetch_loss + fetch_metrics

-        res = self._executor.run(self.main_program,
-                                 fetch_list=fetch_list,
-                                 use_program_cache=use_program_cache,
-                                 return_numpy=return_numpy)
-        if not res[len(fetch_loss):]:
-            return res[:len(fetch_loss)]
+        outs = self._executor.run(self.main_program,
+                                  fetch_list=fetch_list + usr_fetch_list,
+                                  use_program_cache=use_program_cache,
+                                  return_numpy=return_numpy)
+        usr_out = outs[len(fetch_list):]
+        for i, out in enumerate(usr_out):
+            logs[usr_fetch_list[i]] = out
+        outs = outs[:len(fetch_list)]
+        if not outs[len(fetch_loss):]:
+            return logs, outs[:len(fetch_loss)]
         for metric in self._metrics:
-            metric.update(*res[len(fetch_loss):])
-        return res[:len(fetch_loss)]
+            metric.update(*outs[len(fetch_loss):])
+        return logs, outs[:len(fetch_loss)]
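
To make the slicing above concrete, a toy walk-through with made-up names and values (illustrative only, not part of the PR):

    fetch_loss = ["loss_0"]                  # framework loss fetches
    fetch_metrics = ["acc_0"]                # framework metric fetches
    fetch_list = fetch_loss + fetch_metrics  # executor fetch order: loss, metrics, then user
    usr_fetch_list = ["label"]               # user-requested fetches
    outs = [0.31, 0.92, [1, 0, 1]]           # what executor.run would return in that order
    usr_out = outs[len(fetch_list):]         # [[1, 0, 1]] -> logged under "label"
    outs = outs[:len(fetch_list)]            # [0.31, 0.92] -> loss and metric values
    assert outs[:len(fetch_loss)] == [0.31]  # loss slice is what gets returned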

     def _predict_step(self, data, use_program_cache=False, return_numpy=True):
         logs = {}
         fetch_vars = self._fetch_vars[self.mode]["outputs"]
-        fetch_list = self._fetch_list(fetch_vars)
+        fetch_list, usr_fetch_list = self._fetch_list(fetch_vars)
+        fetch_list += usr_fetch_list

         outs = self._executor.run(self.main_program,
                                   fetch_list=fetch_list,
                                   use_program_cache=use_program_cache,
                                   return_numpy=return_numpy)
-        logs["pred"] = outs
+        for i, out in enumerate(outs):
+            logs[fetch_list[i]] = out
         return logs, outs

     def _fetch_list(self, fetch_vars):
         fetch_list = []
         for var in fetch_vars:
             if var.name in self.main_program.global_block().vars:
                 fetch_list.append(var.name)
-        return fetch_list
+        usr_fetch_list = []
+        if self._usr_fetch_list:
+            assert isinstance(self._usr_fetch_list,
+                              list), "'fetch_list' type should be list."
+            for var in self._usr_fetch_list:
+                if isinstance(var, str):
+                    if var in self.main_program.global_block().vars:
Review comment (Contributor): If var is not in the block, an error should be raised to tell the user directly, rather than skipping it; otherwise the failure only surfaces later, when the executor actually performs the fetch.
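
A minimal sketch of the reviewer's suggestion (hypothetical; the change as merged skips missing variables instead of raising):

    for var in self._usr_fetch_list:
        # Resolve a Variable to its name; entries are assumed to be str or Variable.
        var_name = var.name if isinstance(var, Variable) else var
        if var_name not in self.main_program.global_block().vars:
            # Fail fast with a clear message instead of deferring the error
            # to the executor's actual fetch.
            raise ValueError(
                "fetch variable '{}' is not in the main program".format(var_name))
        usr_fetch_list.append(var_name)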

+                        usr_fetch_list.append(var)
+                elif isinstance(var, Variable):
+                    if var.name in self.main_program.global_block().vars:
+                        usr_fetch_list.append(var.name)
+        return fetch_list, usr_fetch_list

     def _create_dataloader(self,
                            dataset,

@@ -133,15 +133,16 @@ def train():
     train_dataset = MyDataset(batch_num * batch_size)
     engine.fit(train_dataset,
                batch_size=batch_size,
-               steps_per_epoch=batch_num * batch_size)
+               steps_per_epoch=batch_num * batch_size,
+               fetch_list=['label'])

     # eval
     eval_dataset = MyDataset(batch_size)
-    engine.evaluate(eval_dataset, batch_size)
+    engine.evaluate(eval_dataset, batch_size, fetch_list=['label'])

     # predict
     test_dataset = MyDataset(batch_size)
-    engine.predict(test_dataset, batch_size)
+    engine.predict(test_dataset, batch_size, fetch_list=['label'])

     # save
     engine.save('./mlp_inf', training=False, mode='predict')
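
Since _fetch_list accepts both variable names and Variable objects, the new argument can be passed either way. A short sketch under that assumption (label_var stands for any Variable present in the main program; it is not defined in this test):

    # Fetch by name ...
    engine.fit(train_dataset, batch_size=batch_size, fetch_list=['label'])
    # ... or by Variable object; _fetch_list resolves it to var.name internally.
    engine.evaluate(eval_dataset, batch_size, fetch_list=[label_var])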