Skip to content

Commit

Permalink
[Auto Parallel] Fix bugs caused by the inconsistent outputs of Engine…
Browse files Browse the repository at this point in the history
… API (#46633)

* [Auto Parallel] Unify the logger and outputs of Engine API

* [Auto Parallel] Fix the bugs of to_static

* [Auto Parallel] Adjust the test_to_static.py
  • Loading branch information
aoyulong committed Oct 10, 2022
1 parent 21612be commit 0ce5554
Show file tree
Hide file tree
Showing 12 changed files with 211 additions and 170 deletions.
24 changes: 18 additions & 6 deletions python/paddle/distributed/auto_parallel/dist_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,24 @@ def _restore_serial_feed_vars(self):
def _restore_serial_fetch_vars(self):
for key, var_list in self._original_serial_fetch_vars.items():
new_var_list = []
for var in var_list:
block_idx = var.block.idx
var_name = var.name
var = self._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
new_var_list.append(var)
# metrics is a list of list
if key == "metrics":
for inner_var_list in var_list:
new_inner_var_list = []
for var in inner_var_list:
block_idx = var.block.idx
var_name = var.name
var = self._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
new_inner_var_list.append(var)
new_var_list.append(new_inner_var_list)
else:
for var in var_list:
block_idx = var.block.idx
var_name = var.name
var = self._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
new_var_list.append(var)
self._serial_fetch_vars[key] = new_var_list

def _restore_serial_info(self, mode="to_backup"):
Expand Down
241 changes: 139 additions & 102 deletions python/paddle/distributed/auto_parallel/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,6 @@ def _prepare_feed(self, user_feeds=None, mode="train"):
"user_feeds must be a dict, but receive {}".format(type(user_feeds).__name__)
feeds = {}
# TODO: add inputs and labels feed dict
for name, var in get_collection(CollectionNames.FEEDS):
assert name is not None, "No name defined for feed var"
feeds[name] = var
if user_feeds is not None:
for name, var in user_feeds.items():
feeds[name] = var
Expand All @@ -227,42 +224,120 @@ def _prepare_fetch(self, user_fetches=None, mode="train"):
assert isinstance(user_fetches, list), \
"user_fetches must be a list, but receive {}".format(type(user_fetches).__name__)
fetch_names = []
fetch_new_names = []
fetch_sections = {}
cnt = 0
fetch_indices = []

def _process_section(section_name, var_list):
nonlocal cnt
section_start = cnt
def _process_fetch_group(group_name, var_list):
group_indices = []
for var in var_list:
new_name = None
# Rename the loss
if section_name == "loss":
new_name = "loss"
if isinstance(var, tuple):
assert len(var) == 2, "Length of tuple {} must be 2".format(
var)
new_name, var = var
if self._is_local_var(var) and var.name not in fetch_names:
fetch_names.append(var.name)
fetch_new_names.append(var.name)
cnt += 1
if self._is_local_var(var) and new_name is not None:
fetch_new_names[fetch_names.index(var.name)] = new_name
section_end = cnt
fetch_sections[section_name] = (section_start, section_end)

for name, var_list in self._fetch_vars[mode].items():
if name == "loss" and mode != "predict":
_process_section("loss", var_list)
if name == "metrics" and mode != "predict":
_process_section("metrics", var_list)
if name == "outputs" and mode == "predict":
_process_section("metrics", var_list)
var_list = (get_collection(CollectionNames.FETCHES)
or []) + (user_fetches or [])
_process_section("user_fetches", var_list)
return fetch_names, fetch_new_names, fetch_sections
# Remove duplicate var_names
if self._is_local_var(var):
var_name = _to_name_str(var)
if var_name not in fetch_names:
fetch_names.append(var_name)
group_indices.append(fetch_names.index(var_name))
fetch_indices.append(group_indices)

if mode != "predict":
_process_fetch_group("loss", self._fetch_vars[mode]["loss"])
if mode != "predict":
metrics = self._fetch_vars[mode]["metrics"]
for i, var_list in enumerate(metrics):
_process_fetch_group("metrics_" + str(i), var_list)
if mode == "predict":
_process_fetch_group("outputs", self._fetch_vars[mode]["outputs"])
user_fetches_collection = [
item[1] for item in get_collection(CollectionNames.FETCHES)
]
var_list = (user_fetches_collection or []) + (user_fetches or [])
_process_fetch_group("fetches", var_list)
return fetch_names, fetch_indices

def _prepare_logger(self,
outs,
mode="train",
epoch=None,
step=None,
lr=None,
fetch_names=None,
fetch_indices=None,
profiler_log=""):
logs = "[{}] ".format(mode)
if epoch is not None:
logs += "epoch: {:d} ".format(epoch)
if step is not None:
logs += "step: {:d} ".format(step)
if lr is not None:
logs += "lr: {:5e} ".format(lr)
group_idx = 0
# logging loss
if mode != "predict":
loss_indices = fetch_indices[group_idx]
for idx in loss_indices:
logs += "loss: {:8f} ".format(outs[idx][0])
group_idx += 1
# logging metrics
if mode != "predict":
for metric in self._metrics:
metrics_indices = fetch_indices[group_idx]
metric_out = []
for idx in metrics_indices:
metric_out.append(outs[idx])
if metric_out:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
logs += "{}: {:8f} ".format(metric.name()[i], res)
group_idx += 1
# Skip logging outputs
if mode == "predict":
group_idx += 1
# logging user fetches
fetches_logging = get_collection(CollectionNames.LOGGING)
for name, var in fetches_logging:
if var.name in fetch_names:
idx = fetch_names.index(var.name)
# Use the user defined name for logging
logs += "{}: {} ".format(name, outs[idx])
self._logger.info(logs)

def _prepare_history(self, outs, mode="train", fetch_indices=None):
history = {}
group_idx = 0
# store loss
if mode != "predict":
loss_indices = fetch_indices[group_idx]
loss_values = []
for idx in loss_indices:
loss_values.append(outs[idx][0])
history["loss"] = loss_values
group_idx += 1
# store metrics
if mode != "predict":
for metric in self._metrics:
metrics_indices = fetch_indices[group_idx]
metric_out = []
for idx in metrics_indices:
metric_out.append(outs[idx])
if metric_out:
metric.update(*metric_out)
results = metric.accumulate()
history[tuple(metric.name())] = to_list(results)
group_idx += 1
# store outputs
if mode == "predict":
outputs_indices = fetch_indices[group_idx]
outputs_values = []
for idx in outputs_indices:
outputs_values.append(outs[idx])
history["outputs"] = outputs_values
group_idx += 1
# store user fetches
fetches_indices = fetch_indices[group_idx]
fetches_values = []
for idx in fetches_indices:
fetches_values.append(outs[idx])
history["fetches"] = fetches_values
return history

def _build(self, mode):
if _non_static_mode() or self._dygraph_mode:
Expand Down Expand Up @@ -311,7 +386,7 @@ def _build(self, mode):

if mode != "predict":
for metric in self._metrics:
metrics.extend(
metrics.append(
to_list(metric.compute(*(outputs + labels))))

default_ctx = get_default_distributed_context()
Expand Down Expand Up @@ -547,58 +622,20 @@ def __call__(self,
fetches=None,
mode="train"):
feed_dict = self._prepare_feed(feeds, mode)
fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
fetches, mode)
fetch_names, fetch_indices = self._prepare_fetch(fetches, mode)
try:
outs = self._executor.run(
self.main_program,
feed=feed_dict,
fetch_list=fetch_list,
fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy)
except core.EOFException:
pass
self._print_log(outs, self.mode, None, None, None, fetch_new_names,
fetch_sections)
return outs

# TODO: need a better to print the log
def _print_log(self,
outs,
mode="train",
epoch=None,
step=None,
lr=None,
fetch_new_names=None,
fetch_sections=None,
profiler_log=""):
prefix = "[{}] ".format(mode)
logs = {}
if epoch is not None:
logs["epoch: {:d} "] = epoch
if step is not None:
logs["step: {:d} "] = step
if lr is not None:
logs["lr: {:5e} "] = lr
if fetch_sections is not None:
assert fetch_new_names is not None
for section_name, section in fetch_sections.items():
section_start, section_end = section
if section_name == "metrics" and section_start < section_end:
metric_out = outs[section_start:section_end]
for metric in self._metrics:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
logs[metric.name()[i] + ": {:8f} "] = res
elif section_name == "loss" and section_start < section_end:
for i in range(section_start, section_end):
logs[fetch_new_names[i] + ": {:8f} "] = outs[i][0]
else:
for i in range(section_start, section_end):
logs[fetch_new_names[i] + ": {} "] = outs[i]
string = prefix + ''.join(list(logs.keys())) + profiler_log
self._logger.info(string.format(*list(logs.values())))
self._prepare_logger(outs, self.mode, None, None, None, fetch_names,
fetch_indices)
history = self._prepare_history(outs, self.mode, fetch_indices)
return history

def fit(self,
train_data,
Expand Down Expand Up @@ -692,8 +729,7 @@ def fit(self,
epochs, steps_per_epoch,
collate_fn)

fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
mode=self.mode)
fetch_names, fetch_indices = self._prepare_fetch(mode=self.mode)
lr_scheduler = self._get_lr_scheduler(self.main_program)

with profiler.Profiler(timer_only=True) as prof:
Expand All @@ -702,7 +738,7 @@ def fit(self,
try:
outs = self._executor.run(
self.main_program,
fetch_list=fetch_list,
fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy)
except core.EOFException:
Expand All @@ -713,17 +749,19 @@ def fit(self,

prof.step()

self._print_log(outs, self.mode, epoch, step, lr,
fetch_new_names, fetch_sections,
prof.step_info())
self._prepare_logger(outs, self.mode, epoch, step, lr,
fetch_names, fetch_indices,
prof.step_info())
history = self._prepare_history(outs, self.mode,
fetch_indices)

if valid_data and epoch % valid_freq == 0:
self.evaluate(valid_data, valid_sample_split, batch_size,
valid_steps, collate_fn, callbacks)
self._switch_mode("train")
else:
self._reset_metrics()
return outs
return history

def evaluate(self,
valid_data,
Expand Down Expand Up @@ -793,23 +831,22 @@ def evaluate(self,
steps_per_epoch=steps,
collate_fn=collate_fn)

fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
mode=self.mode)
fetch_names, fetch_indices = self._prepare_fetch(mode=self.mode)

outputs = defaultdict(list)
for step, _ in enumerate(valid_dataloader):
try:
outs = self._executor.run(
self.main_program,
fetch_list=fetch_list,
fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy)
except core.EOFException:
break
self._print_log(outs, self.mode, None, step, None, fetch_new_names,
fetch_sections)
self._prepare_logger(outs, self.mode, None, step, None, fetch_names,
fetch_indices)
history = self._prepare_history(outs, self.mode, fetch_indices)
self._reset_metrics()
return outputs
return history

def predict(self,
test_data,
Expand Down Expand Up @@ -876,22 +913,22 @@ def predict(self,
steps_per_epoch=steps,
collate_fn=collate_fn)

fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
mode=self.mode)
fetch_names, fetch_indices = self._prepare_fetch(mode=self.mode)

for step, _ in enumerate(test_dataloader):
try:
outs = self._executor.run(
self.main_program,
fetch_list=fetch_list,
fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy)
except core.EOFException:
break
self._print_log(outs, self.mode, None, step, None, fetch_new_names,
fetch_sections)
self._prepare_logger(outs, self.mode, None, step, None, fetch_names,
fetch_indices)
history = self._prepare_history(outs, self.mode, fetch_indices)

return outs
return history

def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
self.mode = 'train'
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def call_metrics(self, inputs):
"""
outs = []
for metric in self.metrics:
outs.extend(metric.compute(*inputs))
outs.append(to_list(metric.compute(*inputs)))

return outs

Expand Down
Loading

0 comments on commit 0ce5554

Please sign in to comment.