
modified reduce_max reduce_min reduce_prod for higher_performance and fix a bug in reduce_op.cuh #32974

Merged · 43 commits into PaddlePaddle:develop · Jun 22, 2021

Conversation

@AnnaTrainingG (Contributor) commented May 18, 2021

PR types

Function optimization

PR changes

OPs

Describe

Modified reduce_min, reduce_max, reduce_prod, reduce_all, and reduce_any.

ctest results:

```
Test project /paddle_test/commit/Paddle/build
    Start 709: test_max_op
100% tests passed, 0 tests failed out of 1
Total Test time (real) = 7.51 sec

Test project /paddle_test/commit/Paddle/build
    Start 719: test_min_op
100% tests passed, 0 tests failed out of 1
Total Test time (real) = 6.81 sec

Test project /paddle_test/commit/Paddle/build
    Start 826: test_prod_op
100% tests passed, 0 tests failed out of 1
Total Test time (real) = 7.78 sec
```

Performance comparison, using max as an example (times in μs):

| axis | case | PyTorch (μs) | paddle_old (μs) | paddle_new (μs) | speedup old/new | speedup PyTorch/paddle_new | benchmark case? |
| --- | --- | --- | --- | --- | --- | --- | --- |
| axis=0 | [512, 2048] | 12.442 | 28.272 | 10.821 | 2.61 | 1.15 | |
| axis=0 | [128, 1024] | 5.595 | 5.181 | 3.711 | 1.40 | 1.51 | |
| axis=0 | [30522, 1024] | 162.77 | 1767.3 | 152.229 | 11.61 | 1.07 | |
| axis=0 | [1024, 16] | 4.703 | 2.471 | 3.509 | 0.70 | 1.34 | |
| axis=0 | [256, 12800] | 18.756 | 81.647 | 17.734 | 4.60 | 1.06 | |
| axis=0 | [256, 10240] | 15.742 | 59.888 | 15.379 | 3.89 | 1.02 | |
| axis=0 | [1024, 1280] | 11.625 | 33.204 | 8.399 | 3.95 | 1.38 | |
| axis=0 | [32768, 1280] | 205.95 | 3504.7 | 198.15 | 17.69 | 1.04 | |
| axis=0 | [30522, 10240] | 1414.6 | 32643 | 1437.523 | 22.71 | 0.98 | |
| axis=0 | [256, 10240] | 15.257 | 65.901 | 14.79 | 4.46 | 1.03 | |
| axis=0 | [1024, 1280] | 8.265 | 31.31 | 7.158 | 4.37 | 1.15 | |
| axis=0 | [32768, 1280] | 207.58 | 3501 | 198.297 | 17.66 | 1.05 | |
| axis=0 | [30522, 10240] | 1415.5 | 32554 | 1438.646 | 22.63 | 0.98 | |
| axis=0 | [2560, 10240] | 127.21 | 585.19 | 126.275 | 4.63 | 1.01 | |
| axis=0 | [10240, 1280] | 76.668 | 413.34 | 67.667 | 6.11 | 1.13 | |
| axis=0 | [32768, 2560] | 390.23 | 8323.7 | 383.609 | 21.70 | 1.02 | |
| axis=0 | [30522, 1024] | 160.21 | 1808.7 | 151.341 | 11.95 | 1.06 | |
| axis=0 | [16, 16, 1, 1] | 2.884 | 1.332 | 1.44 | 0.93 | 2.00 | |

Benchmark performance data:

| axis | case | PyTorch | paddle | paddle_new_last | old/new | PyTorch/new |
| --- | --- | --- | --- | --- | --- | --- |
| axis: [2, 3] | [16, 2048, 33, 33] | 171.1 | 199.8 | 164.36 | 1.22 | 1.04 |
| axis: [1] | [16, 8, 128] | 3.285 | 4.234 | 1.322 | 3.20 | 2.48 |
| axis: [0] | [16, 16, 1, 1] | 2.884 | 1.568 | 1.44 | 1.09 | 2.00 |
| axis: [] | [30522, 1024] | 146.45 | 143.12 | 142.99 | 1.00 | 1.02 |

Performance change of reduce_sum before and after optimization:

| reduce axis | speedup | vs. PyTorch |
| --- | --- | --- |
| axis = 0 | 1.4 ~ 22.7 | On par with or better than PyTorch. |
| axis = -1 | 1.0 ~ 1.3 | On par with or better than PyTorch; 2 of 17 cases are slower, at roughly 2× PyTorch's time. |
| axis = 1 | 2.44 ~ 24.88 | On par with or better than PyTorch; 1 of 17 cases is slower, at roughly 2× PyTorch's time. |
| axis = [] | 1.0 ~ 1.03 | On par with or better than PyTorch; 1 of 17 cases is slower, at roughly 2× PyTorch's time. |

@paddle-bot-old commented:

Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@CLAassistant commented May 27, 2021:

CLA assistant check — all committers have signed the CLA.

@AnnaTrainingG changed the title from "Reduce max min prod all any" to "Reduce max min prod" on May 28, 2021
@AnnaTrainingG changed the title from "Reduce max min prod" to "modified reduce_max reduce_min reduce_prod for higher_performance and fix a bug in reduce_op.cuh" on May 28, 2021
Resolved (outdated) review threads:
- paddle/fluid/operators/reduce_ops/reduce_functor_op.h
- paddle/fluid/operators/reduce_ops/reduce_max_op.cu (3 threads)
- paddle/fluid/operators/reduce_ops/reduce_op.cuh (2 threads)
@xingfeng01 (Contributor):

LGTM
1 similar comment
@ZzSean (Contributor) commented Jun 21, 2021:

LGTM

@Xreki (Contributor) left a comment:

LGTM

```cpp
  }
}

// module function designed for global function
```
@Xreki (Contributor) commented Jun 22, 2021:
The templates could be simplified a bit more; some parameters do not need to be passed as template arguments, e.g. ReduceType. The if-check on ReduceType executes only once and is not inside a loop, so passing it as a regular function argument would not noticeably affect performance. Cutting down on templates should also shorten compile time somewhat.

As for TransformOp: it seems that LaunchReduceKernel and LaunchKernel do not need TransformOp as a template parameter? ReduceKernelFunction does appear to need it. In addition, the names LaunchReduceKernel and LaunchKernel lack distinctiveness and do not accurately convey what each function does.
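A minimal sketch of the first suggestion (hypothetical names, not the actual reduce_op.cuh code): ReduceType becomes a plain kernel argument instead of a template parameter. Each thread evaluates the branch once, outside the reduction loop, so performance is essentially unchanged while the number of template instantiations (and compile time) shrinks:

```cuda
// Hypothetical sketch, not PaddlePaddle code: ReduceType is passed at
// runtime rather than as a template parameter. The branch executes once
// per thread, outside the reduction loop.
enum ReduceType { kReduceLastDim, kReduceHigherDim };

template <typename T, typename ReduceOp>
__global__ void ReduceKernelSketch(const T* x, T* y, int rows, int cols,
                                   ReduceOp reducer, ReduceType type) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (type == kReduceLastDim) {
    // One thread per row: reduce along the innermost dimension.
    if (idx < rows) {
      T acc = x[idx * cols];
      for (int j = 1; j < cols; ++j) acc = reducer(acc, x[idx * cols + j]);
      y[idx] = acc;
    }
  } else {
    // One thread per column: reduce along the leading dimension.
    if (idx < cols) {
      T acc = x[idx];
      for (int i = 1; i < rows; ++i) acc = reducer(acc, x[i * cols + idx]);
      y[idx] = acc;
    }
  }
}

// Example functor, analogous to the max reduction in this PR.
struct MaxFunctor {
  template <typename T>
  __device__ T operator()(const T& a, const T& b) const {
    return a > b ? a : b;
  }
};
```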

```diff
@@ -141,21 +174,24 @@ struct ReduceConfig {
  void Run() {
```

Comment on L170: it is suggested that the input parameters all be changed to the const std::vector& type (avoiding a copy on every call).
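A sketch of that signature change (hypothetical, simplified struct; the real Run() computes the full reduce configuration):

```cpp
#include <vector>

// Sketch: taking the shape/axis vectors by const reference means Run()
// no longer copies them on every call; it only reads them.
struct ReduceConfigSketch {
  int reduce_num = 1;
  void Run(const std::vector<int>& x_dim, const std::vector<int>& reduce_dim) {
    reduce_num = 1;
    for (int axis : reduce_dim) reduce_num *= x_dim[axis];
  }
};
```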

```diff
@@ -523,22 +606,22 @@ static void launchKernel(const Tx* x_data, Ty* y_data,
  ReduceKernelFunction<
      Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, 128, kRank, kReduceRank,
      ReduceType::kReduceHigherDim><<<grid, block, 0, stream>>>(
```

Comment on L597 - L599: could this just be written as CUB_REDUCE_TYPE_CASE(ReduceType::kReduceLastDim)? And if ReduceType is no longer a template parameter, this switch-case would not be needed at all.
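A self-contained sketch of what that macro could look like (hypothetical names throughout; LaunchKernelForType stands in for the actual kernel-launch code):

```cpp
#include <cstdio>

enum class ReduceType { kReduceLastDim, kReduceHigherDim };

// Hypothetical stand-in for the real templated kernel-launch helper.
template <ReduceType kType>
void LaunchKernelForType() { std::printf("launch %d\n", static_cast<int>(kType)); }

// Hypothetical macro: generates one switch case per ReduceType, replacing
// near-identical hand-written case blocks.
#define CUB_REDUCE_TYPE_CASE(reduce_type) \
  case (reduce_type):                     \
    LaunchKernelForType<(reduce_type)>(); \
    break;

void Dispatch(ReduceType type) {
  switch (type) {
    CUB_REDUCE_TYPE_CASE(ReduceType::kReduceLastDim)
    CUB_REDUCE_TYPE_CASE(ReduceType::kReduceHigherDim)
  }
}
```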

```diff
  // SetOutputData for ReduceHigherDim when should_reduce_again is true,
  // temp_output should be stored temp_data in output_data space or stored in
  // y_data;
- config.SetOutputData(y_data, x.place(), tmp);
+ framework::Tensor tmp;
+ config.SetOutputData(y_data, x.place(), &tmp);

  if (config.reduce_num == 1) {
    auto out_dims = y->dims();
    framework::TensorCopy(x, y->place(), y);
    y->Resize(out_dims);
    return;
  }
```

Could L684 - L689 be moved up, before L674 or L677?

```cpp
  }
};

template <typename T, template <typename, typename> class ReduceOp>
```

The implementations above may all need to be reused by other operators (e.g. the backward pass of broadcast), but ReduceCudaKernel is only used to implement the reduce_xxx operators, so it would be better not to put L749 - L771 in this header file.

```cpp
        int, ops::ProdFunctor>,
    ops::ReduceKernel<paddle::platform::CUDADeviceContext,
                      int64_t, ops::ProdFunctor>);
REGISTER_OP_CUDA_KERNEL(
```

Judging from the comment, this ifdef was originally added because reduce used to be implemented with Eigen, and Eigen's support for double was problematic. Now that everything has been switched to the cuda+cub implementation, perhaps the ifdef can be removed.
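If so, the registration would presumably just list the double kernel unconditionally; a sketch extrapolated from the quoted registration above (not the actual diff, and the op name is assumed from ProdFunctor):

```cpp
// Sketch: register the double kernel without the ifdef, since the
// Eigen-based path (whose double support was problematic) is gone.
REGISTER_OP_CUDA_KERNEL(
    reduce_prod, ops::ReduceKernel<paddle::platform::CUDADeviceContext,
                                   float, ops::ProdFunctor>,
    ops::ReduceKernel<paddle::platform::CUDADeviceContext, double,
                      ops::ProdFunctor>,
    ops::ReduceKernel<paddle::platform::CUDADeviceContext, int,
                      ops::ProdFunctor>,
    ops::ReduceKernel<paddle::platform::CUDADeviceContext, int64_t,
                      ops::ProdFunctor>);
```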

@Xreki merged commit 480b284 into PaddlePaddle:develop on Jun 22, 2021.