[dy2st]Add ProgramHelper to polish build program logic in autoparallel.Engine #44513

Merged: 2 commits, Jul 25, 2022
python/paddle/distributed/auto_parallel/engine.py (88 changes: 15 additions, 73 deletions)
@@ -37,6 +37,7 @@
 from paddle.distributed.utils import get_logger
 from paddle.distributed.passes import new_pass, PassContext
 
+from .hepler import ProgramHelper
 from ..collective import _get_global_env
 from .cluster import Cluster, get_default_cluster
 from .planner_v2 import Planner
@@ -141,87 +142,28 @@ def _prepare_single_mode(self, mode):
         self._mode_init_states[mode] = True
 
     def _build(self, mode):
-
         if _non_static_mode() or self._dygraph_mode:
             paddle.disable_static()
             self._dygraph_mode = True
+            self._logger.info("Building model with 'to_static' method.")
 
+            program_helper = ProgramHelper(self.model, self._loss,
+                                           self._metrics, self.inputs_spec,
+                                           self.labels_spec)
             # build forward main program
-            self.static_model = to_static(self.model,
-                                          input_spec=self.inputs_spec)
-            inputs = self.static_model.forward.inputs
-            outputs = self.static_model.forward.outputs
-            forward_main_prog = self.static_model.forward.main_program
-            forward_startup_prog = self.static_model.forward.concrete_program.startup_program
-            self.concrete_program = self.static_model.forward.concrete_program
-
-            # build loss main program
-            outputs_spec = []
-            outputs_name = []
-            for out in outputs:
-                outputs_spec.append(InputSpec(out.shape, out.dtype, out.name))
-                outputs_name.append(out.name)
-            if isinstance(self._loss, paddle.nn.Layer):
-                self.static_loss = to_static(self._loss.forward,
-                                             input_spec=outputs_spec +
-                                             self.labels_spec)
-                loss_main_prog = self.static_loss.main_program
-            elif callable(self._loss):
-                self.static_loss = to_static(self._loss,
-                                             input_spec=outputs_spec +
-                                             self.labels_spec)
-                loss_main_prog = self.static_loss.main_program
-
-            # build startup program
-            for param in self.concrete_program.parameters:
-                Parameter(name=param.name,
-                          desc=param,
-                          type=param.type,
-                          shape=param.shape,
-                          dtype=param.dtype,
-                          stop_gradient=param.stop_gradient,
-                          block=forward_startup_prog.global_block())
+            program_helper.build_program(mode)
 
-            paddle.enable_static()
+            self.concrete_program = program_helper.concrete_program
+            serial_main_prog = program_helper.main_program
+            serial_startup_prog = program_helper.startup_program
 
-            # NOTE: pure program will loss dist_attr
-            # feeded_var_names = [var.name for var in inputs]
-            # main_prog_0 = main_prog_0._prune_with_input(
-            #     feeded_var_names=feeded_var_names, targets=outputs)
-
-            labels = []
-            losses = []
-            metrics = []
-            # concat forward and loss prog
-            if mode != 'predict' and self._loss:
-                forward_block = forward_main_prog.global_block()
-                loss_block = loss_main_prog.global_block()
-                for idx, op in enumerate(loss_block.ops):
-                    op_desc = forward_block.desc.append_op()
-                    op_desc.copy_from(op.desc)
-                    for in_name in op.input_arg_names:
-                        if in_name in outputs_name:
-                            continue
-                        in_var = forward_block._clone_variable(
-                            loss_block.vars[in_name], force_persistable=False)
-                        if loss_block.vars[in_name].is_data:
-                            labels.append(in_var)
-                    for out_name in op.output_arg_names:
-                        out_var = forward_block._clone_variable(
-                            loss_block.vars[out_name], force_persistable=False)
-                        if idx == len(loss_block.ops) - 1:
-                            losses.append(out_var)
-                forward_block._sync_with_cpp()
-            serial_main_prog = forward_main_prog
-            serial_startup_prog = forward_startup_prog
-            # update metrics op in program
-            with static.program_guard(serial_main_prog, serial_startup_prog), \
-                utils.unique_name.guard():
-                if mode != "predict":
-                    for metric in self._metrics:
-                        metrics.extend(
-                            to_list(metric.compute(*(outputs + labels))))
+            inputs = program_helper.input_vars
+            outputs = program_helper.output_vars
+            labels = program_helper.label_vars
+            losses = program_helper.loss_vars
+            metrics = program_helper.metric_vars
 
+            paddle.enable_static()
         else:
             # build program in static mode
             serial_main_prog = self._serial_main_progs.get(mode, None)
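For reference, the new ProgramHelper can also be driven on its own, outside Engine. The following is a minimal sketch only: the Linear model, MSELoss, spec shapes, and variable names are illustrative assumptions and not part of this PR; the import path follows the hepler.py filename introduced below.

import paddle
from paddle.static import InputSpec
from paddle.distributed.auto_parallel.hepler import ProgramHelper

# Tracing starts from dygraph mode, mirroring the dy2st branch of Engine._build.
paddle.disable_static()

model = paddle.nn.Linear(8, 2)                         # illustrative model
loss_fn = paddle.nn.MSELoss()                          # illustrative loss
inputs_spec = [InputSpec([None, 8], 'float32', 'x')]   # assumed feed spec
labels_spec = [InputSpec([None, 2], 'float32', 'y')]   # assumed label spec

# metrics may be an empty list; ProxyLayer.call_metrics simply iterates it.
helper = ProgramHelper(model, loss_fn, [], inputs_spec, labels_spec)
helper.build_program('train')   # traces ProxyLayer._train once, cached per mode

# The generated IR and feed/fetch variables are exposed as properties.
serial_main_prog = helper.main_program
serial_startup_prog = helper.startup_program
loss_names = [var.name for var in helper.loss_vars]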
python/paddle/distributed/auto_parallel/hepler.py (244 changes: 244 additions, 0 deletions)
@@ -0,0 +1,244 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from paddle.nn import Layer
from paddle.jit import to_static, not_to_static
from paddle.distributed.utils import get_logger
from paddle.fluid.framework import Operator, Parameter, _non_static_mode

from .utils import to_list


class ProxyLayer(Layer):
    """
    ProxyLayer implements all logic for converting a dygraph model into
    a static Program IR. Meanwhile, it provides convenient interfaces for
    auto parallel to access feed/fetch/loss/metric variables.
    """

    def __init__(self, layer, loss_func, metrics):
        super(ProxyLayer, self).__init__()
        # NOTE: All verification logic is finished in Engine.Prepare
        self.inner_layer = layer
        self.loss_func = loss_func
        self.metrics = metrics
        # train / eval / predict
        self.mode = None

        # generated program vars
        self.input_vars = []
        self.label_vars = []
        self.output_vars = []
        self.loss_vars = []
        self.metric_vars = []

    def _train(self, inputs, labels):
        """
        Train process of inner_layer with forward/loss/metric logic.
        """
        # step 1. save feed variables of Program
        self.input_vars = inputs
        self.label_vars = labels

        # step 2. call inner_layer.forward
        self.output_vars = self.inner_layer(*inputs)

        # step 3. calculate loss if needed
        new_inputs = self._prepare(self.output_vars, labels)
        self.loss_vars = self.call_loss(new_inputs)

        # step 4. calculate metrics if needed
        self.metric_vars = self.call_metrics(new_inputs)

    def _eval(self, inputs, labels):
        """
        Evaluate process of inner_layer with forward/loss/metric logic.
        """
        # TODO(dev): consider reusing the code in self._train once we have
        # confirmed the two paths are identical.

        # step 1. save feed variables of Program
        self.input_vars = inputs
        self.label_vars = labels

        # step 2. call inner_layer.forward
        self.output_vars = self.inner_layer(*inputs)

        # step 3. calculate loss if needed
        new_inputs = self._prepare(self.output_vars, labels)
        self.loss_vars = self.call_loss(new_inputs)

        # step 4. calculate metrics if needed
        self.metric_vars = self.call_metrics(new_inputs)

    def _predict(self, inputs):
        """
        Predict process of inner_layer with forward logic.
        """
        # step 1. save feed variables of Program
        self.input_vars = inputs

        # step 2. call inner_layer.forward
        self.output_vars = self.inner_layer(*inputs)

    @not_to_static
    def _prepare(self, outputs, labels):
        """
        Concatenate outputs and labels into a single list.

        NOTE(dev): We use @not_to_static to avoid AST analysis.
        """
        return to_list(outputs) + to_list(labels)

    def call_loss(self, inputs):
        """
        Apply the loss function to outputs and labels.

        Args:
            inputs: List[Variable]

        Returns: List[Variable]
        """
        res = []
        if self.loss_func is not None:
            res = self.loss_func(*inputs)
        return res

    def call_metrics(self, inputs):
        """
        Apply the metric functions to outputs and labels.

        Args:
            inputs: List[Variable]

        Returns: List[Variable]
        """
        outs = []
        for metric in self.metrics:
            outs.extend(metric.compute(*inputs))

        return outs

    def set_mode(self, mode):
        self.mode = mode
        self.training = mode == 'train'


class BuildInfo:

    def __init__(self, mode=None, state=False):
        self.mode = mode
        self.state = state

    def has_cache(self, mode):
        return self.mode == mode and self.state is True


class ProgramHelper(object):
    """
    A helper class for Engine that provides different Program IRs
    according to the specified 'mode'.
    """

    def __init__(self, layer, loss_func, metrics, inputs_spec, labels_spec):
        # original model config information
        # TODO(Aurelius84): Implement append_backward and optimizer in ProxyLayer
        # after the distributed engine satisfies the basic conditions.
        self.proxy_layer = ProxyLayer(layer, loss_func, metrics)
        self.inputs_spec = inputs_spec
        self.labels_spec = labels_spec

        self.build_info = BuildInfo()
        self._logger = get_logger(logging.INFO)

    def build_program(self, mode):
        """
        Convert the dygraph model into a static Program IR.
        """
        assert mode in ['train', 'eval', 'predict']
        # skip if we have already built the program.
        if self.build_info.has_cache(mode):
            self._logger.info(
                "Already built program for mode = %s, use cached program." %
                mode)
            return

        self._logger.info("start to build program for mode = %s." % mode)
        self.proxy_layer.mode = mode
        input_spec = [self.inputs_spec, self.labels_spec
                      ] if mode != 'predict' else [self.inputs_spec]
        static_func = to_static(self.static_func(), input_spec=input_spec)

        func_name = '_' + mode
        setattr(self.proxy_layer, func_name, static_func)

        # NOTE(dev): Because @to_static is a lazy mechanism, we explicitly access
        # concrete_program to trigger generating the Program IR immediately.
        getattr(self.proxy_layer, func_name).concrete_program

    def _build_startup_program(self):
        """
        Create and sync parameters into the startup program.
        """
        for param in self.concrete_program.parameters:
            Parameter(name=param.name,
                      desc=param,
                      type=param.type,
                      shape=param.shape,
                      dtype=param.dtype,
                      stop_gradient=param.stop_gradient,
                      block=self.startup_program.global_block())

    def static_func(self):
        """
        Return the function for the target mode.
        """
        assert self.proxy_layer.mode in [
            'train', 'eval', 'predict'
        ], "Please call build_program(mode) first."
        func_name = '_' + self.proxy_layer.mode
        return getattr(self.proxy_layer, func_name)

    @property
    def concrete_program(self):
        return self.static_func().concrete_program

    @property
    def main_program(self):
        return self.concrete_program.main_program

    @property
    def startup_program(self):
        return self.concrete_program.startup_program

    @property
    def input_vars(self):
        return to_list(self.proxy_layer.input_vars)

    @property
    def output_vars(self):
        return to_list(self.proxy_layer.output_vars)

    @property
    def label_vars(self):
        return to_list(self.proxy_layer.label_vars)

    @property
    def loss_vars(self):
        return to_list(self.proxy_layer.loss_vars)

    @property
    def metric_vars(self):
        return to_list(self.proxy_layer.metric_vars)
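
As the NOTE in build_program points out, @to_static is lazy: wrapping a function generates no Program until concrete_program is accessed. A standalone sketch of that trigger pattern, independent of this PR's classes (the relu function is just an example):

import paddle
from paddle.jit import to_static
from paddle.static import InputSpec

def fn(x):
    return paddle.nn.functional.relu(x)

# Wrapping alone performs no tracing yet.
static_fn = to_static(fn, input_spec=[InputSpec([None, 4], 'float32', 'x')])

# Accessing concrete_program runs the AST transform and tracing, materializing
# the main/startup Programs, which is what ProgramHelper.build_program forces
# via getattr(self.proxy_layer, func_name).concrete_program.
main_prog = static_fn.concrete_program.main_program
print(main_prog)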