From 2f65f596d6c31e808dac452b459e4116c7953c9f Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Tue, 4 Jan 2022 07:54:17 +0000 Subject: [PATCH 1/7] fix expand_v2 and expand_as_v2 bug --- paddle/fluid/operators/expand_as_v2_op.cc | 4 +++ paddle/fluid/operators/expand_as_v2_op.h | 33 +++++++++++++++++------ paddle/fluid/operators/expand_v2_op.cc | 6 ++++- python/paddle/tensor/manipulation.py | 2 +- 4 files changed, 35 insertions(+), 10 deletions(-) mode change 100644 => 100755 paddle/fluid/operators/expand_as_v2_op.cc mode change 100644 => 100755 paddle/fluid/operators/expand_as_v2_op.h mode change 100644 => 100755 paddle/fluid/operators/expand_v2_op.cc mode change 100644 => 100755 python/paddle/tensor/manipulation.py diff --git a/paddle/fluid/operators/expand_as_v2_op.cc b/paddle/fluid/operators/expand_as_v2_op.cc old mode 100644 new mode 100755 index 5296a144f6247..8b3103179bbde --- a/paddle/fluid/operators/expand_as_v2_op.cc +++ b/paddle/fluid/operators/expand_as_v2_op.cc @@ -50,6 +50,10 @@ class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "(Tensor, default Tensor). A tensor with rank in [1, 6]." "X is the input to be expanded."); + AddInput("Y", + "(Tensor, default Tensor). A tensor with rank in [1, 6]." + "Expand X according to the shape of Y.") + .AsDispensable(); AddOutput("Out", "(Tensor, default Tensor). A tensor with rank in [1, 6]." "The rank of Output(Out) have the same with Input(X). " diff --git a/paddle/fluid/operators/expand_as_v2_op.h b/paddle/fluid/operators/expand_as_v2_op.h old mode 100644 new mode 100755 index 3e8f7d15880bc..9e683a792c61f --- a/paddle/fluid/operators/expand_as_v2_op.h +++ b/paddle/fluid/operators/expand_as_v2_op.h @@ -91,17 +91,34 @@ class ExpandAsV2Kernel : public framework::OpKernel { PADDLE_ENFORCE_NE(target_shape[i], 0, platform::errors::InvalidArgument( "The value of target shape cannot be zero.")); - if (vec_in_dims[i] != 1) { + if (i < diff) { + PADDLE_ENFORCE_GT( + target_shape[i], 0, + platform::errors::InvalidArgument( + "The expanded size (%d) for non-existing dimensions must be " + "positive for expand_as_v2 op.", + target_shape[i])); + repeat_times[i] = target_shape[i]; + } else if (target_shape[i] > 0) { + if (vec_in_dims[i] != 1) { + PADDLE_ENFORCE_EQ( + vec_in_dims[i], target_shape[i], + platform::errors::InvalidArgument( + "The value (%d) of the non-singleton dimension does not match" + " the corresponding value (%d) in shape for expand_as_v2 op.", + vec_in_dims[i], target_shape[i])); + repeat_times[i] = 1; + } else { + repeat_times[i] = target_shape[i]; + } + } else { PADDLE_ENFORCE_EQ( - vec_in_dims[i], target_shape[i], + target_shape[i], -1, platform::errors::InvalidArgument( - "The value (%d) of the non-singleton dimension does not match" - " the corresponding value (%d) in " - "target tensor for expand_as_v2 op.", - vec_in_dims[i], target_shape[i])); + "When the value in shape is negative for expand_as_v2 op, " + "only -1 is supported, but the value received is %d.", + target_shape[i])); repeat_times[i] = 1; - } else { - repeat_times[i] = target_shape[i]; } } auto* out0 = context.Output("Out"); diff --git a/paddle/fluid/operators/expand_v2_op.cc b/paddle/fluid/operators/expand_v2_op.cc old mode 100644 new mode 100755 index dc6da979671e5..71ac5e70e3770 --- a/paddle/fluid/operators/expand_v2_op.cc +++ b/paddle/fluid/operators/expand_v2_op.cc @@ -65,7 +65,11 @@ class ExpandV2Op : public framework::OperatorWithKernel { if (x_dims[i] == -1) { out_shape[i] = -1; } else if (expand_shape[i] == -1) { - out_shape[i] = x_dims[i]; + if(x_dims.size() >= i+1){ + out_shape[i] = x_dims[i]; + }else{ + out_shape[i] = -1; + } } else if (expand_shape[i] == -2) { // We use -2 to represent the element in expand_shape is a var. out_shape[i] = -1; diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py old mode 100644 new mode 100755 index b54c3596a26a9..a15c1af391f9f --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1838,7 +1838,7 @@ def expand_as(x, y, name=None): "you must set its stop_gradient to be False by " "some_var.stop_gradient = True, supporting " "some_var as the input 'x'.") - inputs = {"X": [x]} + inputs = {"X": [x], "Y": [y]} helper = LayerHelper('expand_as', **locals()) dtype = helper.input_dtype(input_param_name='x') From 6e76b29915645a93488ca0e58b5403a5a21518a2 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Tue, 4 Jan 2022 08:20:33 +0000 Subject: [PATCH 2/7] fix expand_v2 and expand_as_v2 bug --- paddle/fluid/operators/expand_v2_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/expand_v2_op.cc b/paddle/fluid/operators/expand_v2_op.cc index 71ac5e70e3770..175e964ea6830 100755 --- a/paddle/fluid/operators/expand_v2_op.cc +++ b/paddle/fluid/operators/expand_v2_op.cc @@ -65,7 +65,7 @@ class ExpandV2Op : public framework::OperatorWithKernel { if (x_dims[i] == -1) { out_shape[i] = -1; } else if (expand_shape[i] == -1) { - if(x_dims.size() >= i+1){ + if(static_cast(x_dims.size()) > i){ out_shape[i] = x_dims[i]; }else{ out_shape[i] = -1; From b2f44f0d00e54f6e7432eaebce75b3535555f9a1 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Tue, 4 Jan 2022 19:55:24 +0800 Subject: [PATCH 3/7] fix code style --- paddle/fluid/operators/expand_v2_op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/expand_v2_op.cc b/paddle/fluid/operators/expand_v2_op.cc index 175e964ea6830..3693b0220748d 100755 --- a/paddle/fluid/operators/expand_v2_op.cc +++ b/paddle/fluid/operators/expand_v2_op.cc @@ -65,9 +65,9 @@ class ExpandV2Op : public framework::OperatorWithKernel { if (x_dims[i] == -1) { out_shape[i] = -1; } else if (expand_shape[i] == -1) { - if(static_cast(x_dims.size()) > i){ + if (static_cast(x_dims.size()) > i) { out_shape[i] = x_dims[i]; - }else{ + } else{ out_shape[i] = -1; } } else if (expand_shape[i] == -2) { From c688b5ae1443641d97fcd6e10356e0735ab88860 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Wed, 5 Jan 2022 10:19:42 +0800 Subject: [PATCH 4/7] fix code style --- paddle/fluid/operators/expand_v2_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/expand_v2_op.cc b/paddle/fluid/operators/expand_v2_op.cc index 3693b0220748d..6d803c500d90f 100755 --- a/paddle/fluid/operators/expand_v2_op.cc +++ b/paddle/fluid/operators/expand_v2_op.cc @@ -67,7 +67,7 @@ class ExpandV2Op : public framework::OperatorWithKernel { } else if (expand_shape[i] == -1) { if (static_cast(x_dims.size()) > i) { out_shape[i] = x_dims[i]; - } else{ + } else { out_shape[i] = -1; } } else if (expand_shape[i] == -2) { From 562f897b30871955ea60f5aed5f65e5bafb2a644 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Wed, 5 Jan 2022 16:14:23 +0800 Subject: [PATCH 5/7] add op_version --- paddle/fluid/operators/expand_as_v2_op.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/paddle/fluid/operators/expand_as_v2_op.cc b/paddle/fluid/operators/expand_as_v2_op.cc index 8b3103179bbde..1fb863c6aba0a 100755 --- a/paddle/fluid/operators/expand_as_v2_op.cc +++ b/paddle/fluid/operators/expand_as_v2_op.cc @@ -10,6 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/expand_as_v2_op.h" +#include "paddle/fluid/framework/op_version_registry.h" #include #include @@ -148,3 +149,9 @@ REGISTER_OP_CUDA_KERNEL( ops::ExpandAsV2GradKernel, ops::ExpandAsV2GradKernel); #endif + +REGISTER_OP_VERSION(expand_as_v2) + .AddCheckpoint( + R"ROC(fix expand_as_v2 and add new input [Y])ROC", + paddle::framework::compatible::OpVersionDesc() + .NewInput("Y", "Expand X according to the shape of Y")); \ No newline at end of file From 325f0725bad78938e628d6d16ab34808621d500b Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Wed, 5 Jan 2022 19:44:41 +0800 Subject: [PATCH 6/7] fix code style --- paddle/fluid/operators/expand_as_v2_op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/expand_as_v2_op.cc b/paddle/fluid/operators/expand_as_v2_op.cc index 1fb863c6aba0a..3f356f579925a 100755 --- a/paddle/fluid/operators/expand_as_v2_op.cc +++ b/paddle/fluid/operators/expand_as_v2_op.cc @@ -153,5 +153,5 @@ REGISTER_OP_CUDA_KERNEL( REGISTER_OP_VERSION(expand_as_v2) .AddCheckpoint( R"ROC(fix expand_as_v2 and add new input [Y])ROC", - paddle::framework::compatible::OpVersionDesc() - .NewInput("Y", "Expand X according to the shape of Y")); \ No newline at end of file + paddle::framework::compatible::OpVersionDesc().NewInput( + "Y", "Expand X according to the shape of Y")); \ No newline at end of file From 37dc3420b9d6b00048e1112bb3b671a5f08ec078 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Thu, 6 Jan 2022 10:48:46 +0800 Subject: [PATCH 7/7] fix code style --- paddle/fluid/operators/expand_as_v2_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/expand_as_v2_op.cc b/paddle/fluid/operators/expand_as_v2_op.cc index 3f356f579925a..cc293a5aaa0b2 100755 --- a/paddle/fluid/operators/expand_as_v2_op.cc +++ b/paddle/fluid/operators/expand_as_v2_op.cc @@ -10,9 +10,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/expand_as_v2_op.h" -#include "paddle/fluid/framework/op_version_registry.h" #include #include +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace operators {