[Inplace api] Add copy for inplace #54683

Merged: 11 commits, Jun 28, 2023 (changes shown from all commits)
92 changes: 64 additions & 28 deletions paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -226,20 +226,27 @@ class {} : public egr::GradNodeBase {{
VLOG(5) << \"Running C++ API: \" << \"{}\";
// Before log info
{}

bool trace_backward = egr::Controller::Instance().HasGrad();
bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});

// Node Declaration
std::shared_ptr<{}> grad_node;

// Set grad_node before API Call
{}

// Forward API Call
{}
// Check NaN and Inf if needed
{}
// Get Outputs
{}
// Get Output AutoGradMeta
{}
// Check Inplace if needed
{}{}
// Set grad_node after API call
{}

VLOG(4) << \"Finish AD API: {}\";
@@ -296,10 +303,8 @@ class {} : public egr::GradNodeBase {{
}}
"""

FORWARD_BODY_BEFORE_API_CALL_TEMPLATE = """ if(require_any_grad) {{
{}
// Node Construction
{}
// Set for forward trace
@@ -310,6 +315,13 @@ class {} : public egr::GradNodeBase {{
{}
// Set TensorWrappers for Forward Inputs if needed
{}
}}
"""

FORWARD_BODY_AFTER_API_CALL_TEMPLATE = """ if(require_any_grad) {{

egr::EagerUtils::PassStopGradient({});

// SetGradOutMeta & SetEdges
{}
// SetOutRank & SetHistory & SetGradInMeta
@@ -914,7 +926,7 @@ def GetPassStopGradientArgsList(self, forward_outputs_position_map):
pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)
return pass_stop_gradient_args_str

def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False):
forward_api_name = self.forward_api_name
forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map
@@ -937,6 +949,7 @@ def GenerateNodeCreationCodes(self, for_backward=False):
num_backward_inputs = len(forward_outputs_position_map.keys())
num_backward_outputs = len(forward_inputs_position_map.keys())
grad_node_name = GetGradNodeName(self.backward_api_name)
self.grad_node_name = grad_node_name

# Helper
indent = GetIndent(2)
@@ -946,6 +959,7 @@
# See https://stackoverflow.com/questions/31228656/how-can-shared-ptr-disrupt-alignment
# and https://github.com/MRtrix3/mrtrix3/issues/957
node_construction_str = f"{indent}auto grad_node = std::shared_ptr<{grad_node_name}>(new {grad_node_name}({num_backward_inputs}, {num_backward_outputs}));"
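# Unlike node_construction_str above, node_assignment_str only assigns to a grad_node
# that the generated forward function has already declared before the API call
# (see the "// Node Declaration" slot in the forward template).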
node_assignment_str = f"{indent}grad_node = std::shared_ptr<{grad_node_name}>(new {grad_node_name}({num_backward_inputs}, {num_backward_outputs}));"

# SetAttributes
set_attributes_list = []
@@ -973,14 +987,25 @@
            pos,
        ) in backward_forward_inputs_map.items():
            is_optional = name in optional_inputs
            is_inplace_input = (
                is_inplaced and name in self.forward_inplace_map.keys()
            )

            if is_fwd_input:
                if is_optional:
                    if is_inplace_input:
                        # The forward call overwrites an inplace input, so save a
                        # clone of its original value for the grad node.
                        set_tensor_wrappers = """{indent}if({name}) {{
{indent}  auto {name}_clone = paddle::experimental::assign({name});
{indent}  grad_node->SetTensorWrapper{name}(*{name}_clone);
{indent}}}""".format_map(
                            {"indent": indent, "name": name}
                        )
                    else:
                        set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});"
                else:
                    if is_inplace_input:
                        set_tensor_wrappers = f"{indent}auto {name}_clone = paddle::experimental::assign({name});\n{indent}grad_node->SetTensorWrapper{name}({name}_clone);"
                    else:
                        set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name});"
                set_input_tensor_wrappers_list.append(set_tensor_wrappers)
            else:  # Forward's output as backward's input
                if num_fwd_outputs > 1:
@@ -1074,18 +1099,25 @@

node_event_name = forward_api_name + " node_creation"
node_creation_event_str = f"{indent}paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::OperatorInner, 1);\n"
        self.node_creation_str = ""
        if not for_backward:
            self.node_creation_before_call_str = (
                FORWARD_BODY_BEFORE_API_CALL_TEMPLATE.format(
                    node_creation_event_str,
                    node_assignment_str,
                    set_attributes_str,
                    set_input_tensor_wrappers_str,
                )
            )
            self.node_creation_after_call_str = (
                FORWARD_BODY_AFTER_API_CALL_TEMPLATE.format(
                    pass_stop_gradient_args_str,
                    set_grad_out_meta_str,
                    set_out_rank_str,
                    set_history_str,
                    set_grad_in_meta_str,
                    set_output_tensor_wrappers_str,
                )
            )
        else:
            self.node_creation_str = (
@@ -1615,8 +1647,10 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)

# Node Creation
        self.GenerateNodeCreationCodes(is_inplaced=is_inplaced)
        node_creation_before_call_str = self.node_creation_before_call_str
        node_creation_after_call_str = self.node_creation_after_call_str

dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);\n"
forward_ad_function_name = GetDygraphForwardFunctionName(
@@ -1726,14 +1760,16 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
            inputs_autograd_meta_str,
            forward_api_name,
            before_log_str,
            compute_require_grad_args_str,
            self.grad_node_name,
            node_creation_before_call_str,
            forward_call_str,
            check_nan_inf_str,
            get_outputs_str,
            outputs_autograd_meta_str,
            check_inplace_str,
            bump_inplace_version_str,
            node_creation_after_call_str,
            forward_api_name,
            log_str,
            returns_str,
@@ -1882,7 +1918,7 @@ def GenerateHigherOrderNodeCreationCode(self):
namespace,
)
next_node_generator.run()
next_node_generator.GenerateNodeCreationCodes(for_backward=True)

next_grad_node_creation_str = next_node_generator.node_creation_str
next_grad_node_out_list = next_node_generator.grad_node_out_list
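Taken together, the generator changes implement one rule: when a forward API runs inplace, the grad node is created before the kernel call and captures a clone (via paddle::experimental::assign) of any input the kernel will overwrite. A minimal dygraph sketch of the behavior this enables, assuming the pow_ API added below and recalling that Paddle rejects inplace ops on leaf tensors that require grad:

```python
import paddle

x = paddle.to_tensor([2.0, 3.0])
x.stop_gradient = False
v = x * 1.0            # non-leaf copy; inplace on a grad-requiring leaf is rejected
y = paddle.pow_(v, 2)  # overwrites v's storage in place: v becomes [4., 9.]

# d(v**2)/dv = 2 * v_original = [4., 6.]. This is only correct because the
# grad node saved a clone of v taken before the kernel ran.
y.sum().backward()
print(x.grad)          # expected: [4., 6.]
```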
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -263,6 +263,7 @@
func : ElementwiseInferMeta
kernel :
func : elementwise_pow
inplace : (x -> out)
backward : elementwise_pow_grad

- op : embedding
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -1762,6 +1762,7 @@
kernel :
func : pow
data_type : x
inplace : (x -> out)
backward : pow_grad

- op : prelu
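Declaring `inplace : (x -> out)` in the op yaml is what makes the code generator emit the `_C_ops.pow_` / `_C_ops.elementwise_pow_` variants used below, with `out` reusing `x`'s buffer. A quick sanity check of the aliasing, in the same style as Paddle's existing inplace tests (illustrative, not part of the PR):

```python
import paddle

x = paddle.to_tensor([2.0, 3.0])
y = paddle.pow_(x, 2)

print(id(y) == id(x))  # True: the output is the (mutated) input itself
print(x)               # [4., 9.]: x was overwritten in place
```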
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
@@ -224,6 +224,7 @@
from .tensor.math import log10 # noqa: F401
from .tensor.math import multiplex # noqa: F401
from .tensor.math import pow # noqa: F401
from .tensor.math import pow_ # noqa: F401
from .tensor.math import reciprocal # noqa: F401
from .tensor.math import all # noqa: F401
from .tensor.math import any # noqa: F401
@@ -557,6 +558,7 @@
'abs',
'tril',
'pow',
'pow_',
'zeros_like',
'maximum',
'topk',
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
@@ -161,6 +161,7 @@
from .math import log # noqa: F401
from .math import multiplex # noqa: F401
from .math import pow # noqa: F401
from .math import pow_ # noqa: F401
from .math import reciprocal # noqa: F401
from .math import reciprocal_ # noqa: F401
from .math import round # noqa: F401
@@ -360,6 +361,7 @@
'logsumexp',
'multiplex',
'pow',
'pow_',
'prod',
'reciprocal',
'reciprocal_',
16 changes: 16 additions & 0 deletions python/paddle/tensor/math.py
@@ -474,6 +474,22 @@ def pow(x, y, name=None):
)


@inplace_apis_in_dygraph_only
def pow_(x, y, name=None):
    """
    Inplace version of ``pow`` API, the output Tensor will be inplaced with input ``x``.
    Please refer to :ref:`api_tensor_pow`.
Contributor (review comment): This label seems to be incorrect, so the cross-reference does not resolve. Try changing it to api_paddle_pow, or add this label to api_label. (screenshot attached)
"""
if isinstance(y, (int, float)):
return _C_ops.pow_(x, y)
elif isinstance(y, (paddle.Tensor, Variable)):
return _C_ops.elementwise_pow_(x, y)
else:
raise TypeError(
'y must be scalar or tensor type, but received: %s ' % (type(y))
)
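A quick usage sketch of the dispatch above (values assume default dygraph mode; this is illustrative, not part of the PR):

```python
import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0])
paddle.pow_(x, 2)            # int/float exponent -> _C_ops.pow_
print(x)                     # [1., 4., 9.]

y = paddle.full([3], 3.0)
paddle.pow_(x, y)            # Tensor exponent -> _C_ops.elementwise_pow_
print(x)                     # [1., 64., 729.]
```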


OP_NAMEMAPPING = {
'elementwise_max': 'maximum',
'elementwise_min': 'minimum',
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/utils/inplace_utils.py
@@ -22,6 +22,8 @@
# NOTE(pangyoki): The inplace APIs with a trailing underscore (`_`) are only valid
# when calling `_C_ops` in dygraph mode. If static graph mode is used, the inplace
# mechanism is not applied, and the static method of the original API is called instead.
# NOTE(GGBond8488): Simply running the original version of the API under static graph
# mode carries a low probability that the result is inconsistent with dynamic graph mode.
def _inplace_apis_in_dygraph_only_(func):
    def __impl__(*args, **kwargs):
        if not in_dynamic_mode():
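The diff above is truncated. As a rough sketch, the decorator pattern the NOTE describes looks something like the following; the fallback path, warning text, and function name are assumptions for illustration, not the exact Paddle implementation:

```python
import warnings

import paddle


def inplace_in_dygraph_only(func):
    """Run the inplace API in dygraph mode; in static graph mode, fall back to
    the out-of-place original (the name without the trailing underscore)."""

    def __impl__(*args, **kwargs):
        if not paddle.in_dynamic_mode():
            origin_name = func.__name__[:-1]  # e.g. 'pow_' -> 'pow'
            warnings.warn(
                f"In static mode, {func.__name__} falls back to {origin_name}."
            )
            return getattr(paddle, origin_name)(*args, **kwargs)
        return func(*args, **kwargs)

    return __impl__
```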
36 changes: 36 additions & 0 deletions test/legacy_test/test_pow.py
@@ -15,6 +15,7 @@
import unittest

import numpy as np
from test_inplace import TestDygraphInplace

import paddle
from paddle.fluid import core
@@ -213,5 +214,40 @@ def test_errors(self):
        self.assertRaises(TypeError, paddle.pow, x, str(y))


class TestInplacePowerScalar(TestDygraphInplace):
    def set_np_compare_func(self):
        self.np_compare = np.allclose

    def inplace_api_processing(self, var):
        return paddle.pow_(var, 2)

    def non_inplace_api_processing(self, var):
        return paddle.pow(var, 2)


class TestInplacePowerTensor(TestDygraphInplace):
    def init_data(self):
        self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1])
        self.dtype = "float32"
        self.y = paddle.ones([10, 20, 1], dtype="float32") * 2

    def set_np_compare_func(self):
        self.np_compare = np.allclose

    def inplace_api_processing(self, var):
        return paddle.pow_(var, self.y)

    def non_inplace_api_processing(self, var):
        return paddle.pow(var, self.y)

    def test_type_error(self):
        var = paddle.to_tensor(self.input_var_numpy, dtype=self.dtype)
        with self.assertRaisesRegex(
            TypeError,
            'y must be scalar or tensor type, but received: %s ' % (type([2])),
        ):
            paddle.pow_(var, [2])

Contributor (review comment): Shall we also check the gradients of the inplace paddle.pow_?

Contributor Author (reply): TestInplacePowerScalar and TestInplacePowerTensor inherit from TestDygraphInplace, which already contains the backward test.
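For reference, the inherited backward check amounts to something like the following simplified sketch of what TestDygraphInplace does (not the verbatim test code):

```python
import numpy as np

import paddle

a = paddle.to_tensor(np.random.uniform(-5, 5, [10, 20, 1]).astype("float32"))
a.stop_gradient = False

# Out-of-place reference gradient.
x = a * 1.0
paddle.pow(x, 2).sum().backward()
grad_ref = a.grad.numpy()
a.clear_gradient()

# Inplace path: correct only if the original x was cloned for backward.
x = a * 1.0
paddle.pow_(x, 2).sum().backward()
np.testing.assert_allclose(a.grad.numpy(), grad_ref, rtol=1e-05)
```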


if __name__ == '__main__':
    unittest.main()