Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

elementwise max min #7538

Merged
merged 13 commits into from
Jan 18, 2018
Merged

Conversation

JiayiFeng
Copy link
Collaborator

@JiayiFeng JiayiFeng commented Jan 15, 2018

solve #7567

};

template <typename DeviceContext, typename T>
class ElementwiseMaxKernel : public framework::OpKernel<T> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds great.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

}
}
};

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JiayiFeng We can reduce more codes. But we can do it later, if we have more element-wise operators to implement. Anyway, let's find a general pattern and clean code later.

template <typename T>
struct MaxGradFunctor1 {
  template <typename Device, typename Dxe, typename Xe, typename Ye, typename Dze>
  void operator()(Device d, Dxe dx_e, Xe, x_e, Ye, y_e, Dze, dz_e) {
    dx_e.device(d) = (x_e > y_e).template cast<T>() * dz_e;
  }
};

template <typename T>
struct MinGradFunctor1 {
  template <typename Device, typename Dxe, typename Xe, typename Ye, typename Dze>
  void operator()(Device d, Dxe dx_e, Xe, x_e, Ye, y_e, Dze, dz_e) {
    dx_e.device(d) = (x_e < y_e).template cast<T>() * dz_e;
  }
};

template <typename T, typename Functor1, typename Functor2>
struct ElementwiseMinMaxGradFunctor {
  template <typename Device, typename X, typename Y, typename Z, typename dX,
            typename dY, typename dZ>
  void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
    auto x_e = framework::EigenVector<T>::Flatten(*x);
    auto y_e = framework::EigenVector<T>::Flatten(*y);
    auto dz_e = framework::EigenVector<T>::Flatten(*dz);

    if (dx) {
      auto dx_e = framework::EigenVector<T>::Flatten(*dx);
      Functor1 functor;
      functor(d, dx_e, x_e, y_e, dz_e);
    }
    if (dy) {
      auto dy_e = framework::EigenVector<T>::Flatten(*dy);
      Functor2 functor;
      functor(d, dy_e, x_e, y_e, dz_e);
    }
  }
};

Copy link
Collaborator Author

@JiayiFeng JiayiFeng Jan 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@QiJune
I agree. Finding a general pattern in the backward pass is not as easy as that in forwarding pass, for backward op takes Z, dZ, X, Y while forwarding op only takes X and Y.

QiJune
QiJune previously approved these changes Jan 17, 2018
Copy link
Member

@QiJune QiJune left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@JiayiFeng JiayiFeng merged commit 37a9437 into PaddlePaddle:develop Jan 18, 2018
@JiayiFeng JiayiFeng deleted the dev_elementwise_max_min branch January 18, 2018 01:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants