Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support different data type between input and output #32823

Merged
merged 2 commits into from
May 10, 2021

Conversation

ZzSean
Copy link
Contributor

@ZzSean ZzSean commented May 10, 2021

PR types

Performance optimization

PR changes

OPs

Describe

elementwise 实现模版支持输入和输出的数据类型不同,并适配 abs 验证,性能数据如下:

config dt paddle old paddle new pro pytorch pro
[16, 128, 257, 257] fp16 852.23us 668.53us 27.48% 750.56us 12.27%
[16, 128, 257, 257] fp32 1.3771ms 1.3109ms 5.05% 1.3115ms 0.05%
优化后,性能已经打平甚至优于 pytorch

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

Copy link
Contributor

@Xreki Xreki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM and great work~

int tid = blockIdx.x * blockDim.x + threadIdx.x;
int remain = tid < size ? 1 : 0;
ScalarKernelImpl(data, func, tid, remain);
}

template <ElementwiseType ET, typename T, typename Functor>
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以设置默认值OutT = InT吗?

int GetVectorizedSize(const std::vector<const framework::Tensor *> &ins,
const std::vector<framework::Tensor *> &outs) {
int vec_size = 4;
for (auto iter = ins.begin(); iter != ins.end(); ++iter) {
vec_size =
std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<T>()));
std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<InT>()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个循环的写法其实可以简写成如下:

for (auto in : ins) {
  vec_size =
        std::min<int>(vec_size, GetVectorizedSizeImpl(in->data<InT>()));
}

@Xreki Xreki merged commit 3419de5 into PaddlePaddle:develop May 10, 2021
@ZzSean ZzSean deleted the support_two_dt branch May 18, 2021 05:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants