
[pnorm] optimize p_norm for special cases #37685

Merged — 15 commits merged into PaddlePaddle:develop, Dec 13, 2021

Conversation

@LemonNoel (Contributor) commented Nov 29, 2021

PR types

Performance optimization

PR changes

OPs

Describe

Optimize p_norm for two kinds of special cases:
(1) shape=[2, 1000, 1000], reduce axis=0
(2) shape=[1, 2000000, 1], reduce axis=1
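The reduction these special cases perform can be sketched in a few lines of NumPy (this is an illustration of the math only, not the fused CUDA kernels this PR adds; the actual kernels fuse |x|^p, the sum over the reduce axis, and the final 1/p power on the GPU):

```python
import numpy as np

def p_norm(x, axis, p):
    # |x|^p summed over `axis`, then the 1/p root
    return np.sum(np.abs(x) ** p, axis=axis) ** (1.0 / p)

rng = np.random.default_rng(0)

x1 = rng.random((2, 4, 4))        # stands in for shape [2, 1000, 1000]
r1 = p_norm(x1, axis=0, p=3)      # reduce axis 0 -> result shape (4, 4)

x2 = rng.random((1, 8, 1))        # stands in for shape [1, 2000000, 1]
r2 = p_norm(x2, axis=1, p=3)      # reduce axis 1 -> result shape (1, 1)
```

Case (1) reduces over a very small leading axis; case (2) reduces a huge middle axis with trivial pre/post dimensions, so each benefits from a dedicated kernel rather than the generic reduce path.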

The baseline is paddlepaddle-gpu == 2.2.1. Times are in seconds per 1k steps.

• Forward Cases

| Tensor Shape | OP | Original Time | Current Time | Speedup |
|---|---|---|---|---|
| [2, 1k, 1k] | paddle.norm(x, axis=0, p=3) | 4.0680373 | 0.0948312 | 45.2 |
| [1, 2m, 1] | paddle.norm(x, axis=1, p=3) | 2.5773182 | 0.0548460 | 51.6 |
| [1m, 2] | paddle.norm(x, axis=1, p=3) | 4.0455515 | 0.0652089 | 57.9 |
| [1k, 1k] | paddle.norm(x, axis=0, p=3) | 0.0214734 | 0.0451496 | - |

• Backward Cases

| Shape | OP | Original Time | Current Time | Speedup |
|---|---|---|---|---|
| [2, 1k, 1k] | paddle.norm(x, axis=0, p=3).sum().backward() | 4.6668179 | 0.3521516 | 13.3 |
| [1, 2m, 1] | paddle.norm(x, axis=1, p=3).sum().backward() | 4.8750155 | 0.2606518 | 18.8 |
| [1m, 2] | paddle.norm(x, axis=1, p=3).sum().backward() | 0.0756586 | 0.0661562 | 1.1 |
| [1k, 1k] | paddle.norm(x, axis=0, p=3).sum().backward() | 0.0624404 | 0.0668254 | - |

@paddle-bot-old commented:

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@LemonNoel LemonNoel marked this pull request as ready for review November 29, 2021 15:04
@LemonNoel LemonNoel marked this pull request as draft November 29, 2021 15:05
@LemonNoel LemonNoel marked this pull request as ready for review November 29, 2021 15:05
@LemonNoel LemonNoel marked this pull request as draft November 29, 2021 15:06
@PaddlePaddle PaddlePaddle locked and limited conversation to collaborators Dec 6, 2021
@PaddlePaddle PaddlePaddle unlocked this conversation Dec 6, 2021
@ZHUI ZHUI marked this pull request as ready for review December 6, 2021 13:49
@LemonNoel (Contributor, Author) commented:

@Avin0323 Hi, please take a look at the benchmark CI. This PR changes the p_norm code and optimizes performance for special shapes. The CI results show that the p_norm test times have all decreased. It also modifies the cmake file; please review that as well.

@ZHUI previously approved these changes Dec 8, 2021
HOSTDEVICE explicit inline NonzeroFunctor(int n) {}
template <typename T>
HOSTDEVICE inline T operator()(const T& x) const {
return static_cast<T>(static_cast<double>(x) != 0);
Collaborator commented:

Why cast x to double with static_cast<double>(x) first?

Contributor Author replied:

This keeps the original implementation.
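For context on what NonzeroFunctor computes: for p = 0 the p-"norm" just counts nonzero elements, and widening to double before the `!= 0` comparison sidesteps surprises with low-precision element types. A hedged NumPy sketch of that behavior (an illustration, not Paddle's code):

```python
import numpy as np

def zero_norm(x, axis):
    # widen before comparing with 0 (mirrors the static_cast<double> in
    # the functor), then cast the 0/1 mask back to the input dtype
    mask = (x.astype(np.float64) != 0).astype(x.dtype)
    return np.sum(mask, axis=axis)

x = np.array([[0.0, 1.5, 0.0],
              [2.0, 0.0, 3.0]], dtype=np.float16)
zero_norm(x, axis=1)  # -> [1., 2.]  (count of nonzeros per row)
```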

auto xdim = in_x->dims();
auto ndim = out_norm->dims();
float porder = ctx.Attr<float>("porder");
int axis = ctx.Attr<int>("axis");
bool asvector = ctx.Attr<bool>("asvector");
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post, asvector);
std::vector<int> reduce_axis = {axis};

auto& dev_ctx = ctx.cuda_device_context();
Collaborator commented:

Is dev_ctx unused in this function?

Contributor Author replied:

Yes, I will remove it in a later PR.
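The GetDims call in the snippet above splits the input shape around the reduce axis into pre * n * post factors, so a kernel can view the tensor as [pre, n, post] and reduce over the middle n. A sketch of that decomposition (an assumption about the helper's intent, not its exact source):

```python
from functools import reduce
import operator

def get_dims(shape, axis, asvector=False):
    prod = lambda s: reduce(operator.mul, s, 1)
    if asvector:
        # treat the whole tensor as one flat vector
        return 1, prod(shape), 1
    pre = prod(shape[:axis])       # everything before the reduce axis
    n = shape[axis]                # the reduced extent
    post = prod(shape[axis + 1:])  # everything after the reduce axis
    return pre, n, post

get_dims([2, 1000, 1000], axis=0)   # -> (1, 2, 1000000)
get_dims([1, 2000000, 1], axis=1)   # -> (1, 2000000, 1)
```

The two special cases in this PR correspond to pre == 1 with a small n (case 1) and pre == post == 1 with a huge n (case 2).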

} else {
Pnorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
porder, norm);
framework::Tensor tmp_x;
Collaborator commented:

Noting a TODO: the tmp_x here should be removed as soon as possible; it significantly increases runtime GPU memory usage.

Contributor Author replied:

TODO noted; a tracking card has been filed.

auto negs = dx->constant(static_cast<T>(-1.));
auto zeros = dx->constant(static_cast<T>(0.));
auto positives = (*x) > zeros;
dx->device(place) = dy->broadcast(dim) * equals.select(ones, zeros) *
Collaborator commented:

Is the backward pass here computed entirely via Eigen?

Contributor Author replied:

The computation uses Eigen tensors.
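The Eigen expression quoted above (equals.select(ones, zeros) combined with positives/negs) looks like the infinity-norm gradient, where gradient flows only to elements whose |x| attains the norm, with the sign of x. A NumPy sketch under that assumption (illustrative only; the actual kernel is the Eigen expression in the diff):

```python
import numpy as np

def inf_norm_grad(x, dy):
    norm = np.max(np.abs(x))
    # equals.select(ones, zeros): mask of elements attaining the norm
    mask = (np.abs(x) == norm).astype(x.dtype)
    # positives/negs select +1 or -1, i.e. sign(x)
    return dy * mask * np.sign(x)

x = np.array([1.0, -3.0, 2.0])
inf_norm_grad(x, dy=1.0)  # -> [0., -1., 0.]
```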

@@ -260,32 +254,38 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> {
float porder = ctx.Attr<float>("porder");
T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
int axis = ctx.Attr<int>("axis");
bool reduce_all = ((axis < 0) || (in_norm->numel() == 1));
Collaborator commented:

Does axis < 0 correspond to reduce_all?

Contributor Author replied:

Yes.

bool asvector = ctx.Attr<bool>("asvector");
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post, asvector);
const std::vector<int> dims = {axis};

auto& dev_ctx = ctx.cuda_device_context();
Collaborator commented:

Is dev_ctx still used here?

Contributor Author replied:

It has been removed.

@Avin0323 (Contributor) commented:

LGTM for PR-CI-OP-benchmark and the changes to unity_build_rule.cmake.

@ZHUI (Collaborator) commented:

LGTM

@ZHUI ZHUI merged commit 10d9ab4 into PaddlePaddle:develop Dec 13, 2021