Skip to content

Commit

Permalink
Merge pull request #7529 from JiayiFeng/remove_functor1
Browse files Browse the repository at this point in the history
remove `functor1` of ElementwiseGradCompute
  • Loading branch information
JiayiFeng committed Jan 15, 2018
2 parents f23691d + 6ee8a2e commit 5f44813
Show file tree
Hide file tree
Showing 5 changed files with 1 addition and 40 deletions.
18 changes: 0 additions & 18 deletions paddle/operators/elementwise_add_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,23 +81,6 @@ struct ElementwiseAddGradFunctor {
}
};

template <typename T>
struct ElementwiseAddOneGradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = dz_e.sum();
}
}
};

template <typename T>
struct ElementwiseAddBroadCastGradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
Expand Down Expand Up @@ -142,7 +125,6 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<DeviceContext, T, ElementwiseAddGradFunctor<T>,
ElementwiseAddOneGradFunctor<T>,
ElementwiseAddBroadCastGradFunctor<T>,
ElementwiseAddBroadCast2GradFunctor<T>>(ctx);
}
Expand Down
1 change: 0 additions & 1 deletion paddle/operators/elementwise_div_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ class ElementwiseDivGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<DeviceContext, T, ElementwiseDivGradFunctor<T>,
ElementwiseDivGradFunctor<T>,
ElementwiseDivBroadCastGradFunctor<T>,
ElementwiseDivBroadCast2GradFunctor<T>>(ctx);
}
Expand Down
1 change: 0 additions & 1 deletion paddle/operators/elementwise_mul_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ class ElementwiseMulGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<DeviceContext, T, ElementwiseMulGradFunctor<T>,
ElementwiseMulGradFunctor<T>,
ElementwiseMulBroadCastGradFunctor<T>,
ElementwiseMulBroadCast2GradFunctor<T>>(ctx);
}
Expand Down
3 changes: 1 addition & 2 deletions paddle/operators/elementwise_op_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,7 @@ EIGEN_FUNCTOR(Mul, EIGEN_MUL);
EIGEN_FUNCTOR(Div, EIGEN_DIV);

template <typename DeviceContext, typename T, typename functor,
typename functor1, typename broadcastfunctor,
typename broadcast2functor>
typename broadcastfunctor, typename broadcast2functor>
void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
using Tensor = framework::Tensor;

Expand Down
18 changes: 0 additions & 18 deletions paddle/operators/elementwise_sub_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,6 @@ struct ElementwiseSubGradFunctor {
}
};

template <typename T>
struct ElementwiseSubOneGradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = (-1.0) * dz_e.sum();
}
}
};

template <typename T>
struct ElementwiseSubBroadCastGradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
Expand Down Expand Up @@ -106,7 +89,6 @@ class ElementwiseSubGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<DeviceContext, T, ElementwiseSubGradFunctor<T>,
ElementwiseSubOneGradFunctor<T>,
ElementwiseSubBroadCastGradFunctor<T>,
ElementwiseSubBroadCast2GradFunctor<T>>(ctx);
}
Expand Down

0 comments on commit 5f44813

Please sign in to comment.