Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HybridParallel]Support fp16 in flatten op to avoid error in recompute #35588

Merged
merged 1 commit into from
Sep 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions paddle/fluid/operators/flatten_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/flatten_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
flatten, ops::FlattenKernel<paddle::platform::CUDADeviceContext, float>,
Expand Down Expand Up @@ -50,6 +51,8 @@ REGISTER_OP_CUDA_KERNEL(
flatten_contiguous_range,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
float>,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
double>,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
Expand All @@ -63,6 +66,8 @@ REGISTER_OP_CUDA_KERNEL(
flatten_contiguous_range_grad,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
float>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
double>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def _split_activation(tensor):

# use inplace operation to save memory
data = tensor.flatten_()

part_size = tensor_numel // mp_degree
start = part_size * mp_rank
end = start + part_size
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _update_list(self):
'softmax',
'softmax_with_cross_entropy',
'sigmoid_cross_entropy_with_logits',
'c_softmax_with_cross_entropy',
'cross_entropy',
'cross_entropy2',
# fp16 is slower than fp32, though fp16 is supported.
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/dygraph/amp/auto_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
'softmax',
'softmax_with_cross_entropy',
'sigmoid_cross_entropy_with_logits',
'c_softmax_with_cross_entropy',
'cross_entropy',
'cross_entropy2',
# default fp32 can avoid return inf when the sum value large than 65504
Expand Down