From 1e70140b994d0496460f3b4fe4b167d1fb7003e8 Mon Sep 17 00:00:00 2001
From: GhostScreaming
Date: Wed, 14 Sep 2022 10:10:35 +0000
Subject: [PATCH 01/10] Fix bug of reduce_sum op. When input.numel() >
 INT32_MAX, its result is wrong.

---
 paddle/phi/kernels/kps/reduce_sum_kernel.cu | 85 ++++++++++++++++++++-
 1 file changed, 83 insertions(+), 2 deletions(-)

diff --git a/paddle/phi/kernels/kps/reduce_sum_kernel.cu b/paddle/phi/kernels/kps/reduce_sum_kernel.cu
index f3d3246854f33..751c171ee3ede 100644
--- a/paddle/phi/kernels/kps/reduce_sum_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_sum_kernel.cu
@@ -12,12 +12,61 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <climits>
 #include "paddle/phi/kernels/reduce_sum_kernel.h"
 
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/gpu/reduce.h"
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/platform/enforce.h"
 
 namespace phi {
 
+template <typename T,
+          typename OutT,
+          size_t EigenDimSize = 5,
+          size_t ReducedDimSize = 1,
+          bool ReduceAll = false>
+void ReduceSumEigen(const KPDevice& dev_ctx,
+                    const DenseTensor& x,
+                    bool reduce_all,
+                    const std::vector<int64_t>& dims,
+                    DataType out_dtype,
+                    DenseTensor* out) {
+  // Resize Input Tensor
+  auto new_x = x;
+  std::vector<int> reduce_dims =
+      phi::funcs::details::GetReduceDim(dims, x.dims().size(), reduce_all);
+  int added_dims = EigenDimSize - x.dims().size();
+  std::vector<int64_t> new_dim(added_dims, 1);
+  for (int i = 0; i < x.dims().size(); i++) {
+    new_dim.push_back(x.dims().at(i));
+  }
+  new_x.Resize(phi::make_ddim(new_dim));
+  auto eigen_x_tensor = EigenTensor<T, EigenDimSize>::From(new_x);
+
+  // Create Out Tensor
+  dev_ctx.Alloc<OutT>(out);
+  // Resize Out Tensor
+  std::vector<int64_t> new_reduced_dim(added_dims, 1);
+  for (int i = 0; i < out->dims().size(); i++) {
+    new_reduced_dim.push_back(out->dims().at(i));
+  }
+  out->Resize(phi::make_ddim(new_reduced_dim));
+  constexpr int kReduceOutRank = ReduceAll ? 1 : EigenDimSize - ReducedDimSize;
+  auto eigen_out_tensor = EigenTensor<OutT, kReduceOutRank>::From(*out);
+  for (int i = 0; i < reduce_dims.size(); i++) {
+    reduce_dims[i] += added_dims;
+  }
+  auto eigen_reduce_dim =
+      EigenDim<ReducedDimSize>::From(phi::make_ddim(reduce_dims));
+  // Calculate
+  eigen_out_tensor.device(*dev_ctx.eigen_device()) =
+      eigen_x_tensor.sum(eigen_reduce_dim);
+  std::vector<int64_t> final_out_dim;
+  for (int i = added_dims; i < out->dims().size(); i++) {
+    final_out_dim.push_back(out->dims().at(i));
+  }
+  out->Resize(phi::make_ddim(final_out_dim));
+}
+
 template <typename T, typename Context>
 void SumRawKernel(const Context& dev_ctx,
                   const DenseTensor& x,
@@ -29,10 +78,42 @@ void SumRawKernel(const Context& dev_ctx,
   if (out_dtype == DataType::UNDEFINED && out->dtype() != x.dtype()) {
     out_dtype = out->dtype();
   }
-  phi::Reduce<T, kps::AddFunctor, kps::IdentityFunctor>(
+  if (x.numel() > INT_MAX) {
+#ifndef PADDLE_WITH_XPU_KP
+    std::vector<int> reduce_dims = phi::funcs::details::GetReduceDim(
+        dims.GetData(), x.dims().size(), reduce_all);
+
+#define CALL_EIGEN_REDUCE_SUM_KERNEL(reduce_rank)                  \
+  case reduce_rank: {                                              \
+    if (reduce_all) {                                              \
+      ReduceSumEigen<T, T, 5, reduce_rank, true>(                  \
+          dev_ctx, x, reduce_all, dims.GetData(), out_dtype, out); \
+    } else {                                                       \
+      ReduceSumEigen<T, T, 5, reduce_rank, false>(                 \
+          dev_ctx, x, reduce_all, dims.GetData(), out_dtype, out); \
+    }                                                              \
+    break;                                                         \
+  }
+
+    switch (reduce_dims.size()) {
+      CALL_EIGEN_REDUCE_SUM_KERNEL(1);
+      CALL_EIGEN_REDUCE_SUM_KERNEL(2);
+      CALL_EIGEN_REDUCE_SUM_KERNEL(3);
+      CALL_EIGEN_REDUCE_SUM_KERNEL(4);
+      CALL_EIGEN_REDUCE_SUM_KERNEL(5);
+      default:
+        PADDLE_THROW(phi::errors::Fatal(
+            "If Input.numel() > INT32_MAX, reduce_sum kernel uses EigenTensor "
+            "sum for reduce_sum function. As a result, its dim should be <= "
+            "5."));
+        break;
+    }
+#undef CALL_EIGEN_REDUCE_SUM_KERNEL
+#endif
+  } else {
+    phi::Reduce<T, kps::AddFunctor, kps::IdentityFunctor>(
         dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
+  }
 }
-
 }  // namespace phi
 
 #ifdef PADDLE_WITH_XPU_KP
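The overflow path added above can be exercised end to end from Python. A minimal sketch (shapes are illustrative; it assumes a GPU with roughly 9 GB free, and the printed value is subject to float32 accumulation error):

```python
import paddle

# Before this patch, reducing a tensor with numel() > INT32_MAX on GPU
# produced a wrong sum because element indices overflowed 32-bit integers.
# 2**31 + 8 float32 elements is about 8.6 GB of device memory.
x = paddle.ones([2**31 + 8], dtype='float32')
out = paddle.sum(x)
print(out)  # expect roughly 2147483656, not a truncated value
```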
From f4fe24fd1742cdce654f65497ef4dd082efa2917 Mon Sep 17 00:00:00 2001
From: sneaxiy
Date: Wed, 21 Sep 2022 18:13:14 +0800
Subject: [PATCH 02/10] support pure bfloat16

---
 paddle/fluid/eager/pylayer/py_layer_node.cc   |   6 +-
 .../platform/device/gpu/gpu_primitives.h      | 111 +++++++++---------
 paddle/phi/kernels/empty_kernel.cc            |   1 +
 paddle/phi/kernels/funcs/activation_functor.h |   8 +-
 paddle/phi/kernels/funcs/eigen/broadcast.cu   |   1 +
 .../phi/kernels/gpu/activation_grad_kernel.cu |   3 +-
 paddle/phi/kernels/gpu/activation_kernel.cu   |  11 +-
 paddle/phi/kernels/gpu/adam_kernel.cu         |   6 +-
 paddle/phi/kernels/gpu/clip_grad_kernel.cu    |   1 +
 paddle/phi/kernels/gpu/clip_kernel.cu         |   3 +-
 .../phi/kernels/gpu/embedding_grad_kernel.cu  |   6 +-
 paddle/phi/kernels/gpu/embedding_kernel.cu    |   3 +-
 paddle/phi/kernels/gpu/gelu_grad_kernel.cu    |   3 +-
 paddle/phi/kernels/gpu/gelu_kernel.cu         |   3 +-
 paddle/phi/kernels/gpu/pad3d_grad_kernel.cu   |   3 +-
 paddle/phi/kernels/gpu/pad3d_kernel.cu        |   1 +
 .../kernels/gpu/pixel_shuffle_grad_kernel.cu  |   4 +-
 .../phi/kernels/gpu/pixel_shuffle_kernel.cu   |  10 +-
 paddle/phi/kernels/gpu/tile_grad_kernel.cu    |   3 +-
 paddle/phi/kernels/gpu/where_grad_kernel.cu   |   5 +-
 paddle/phi/kernels/gpu/where_kernel.cu        |  11 +-
 python/paddle/fluid/clip.py                   |  14 ++-
 python/paddle/fluid/dygraph/amp/auto_cast.py  |   6 +-
 python/paddle/optimizer/adam.py               |   2 +-
 python/paddle/tensor/stat.py                  |   6 +-
 25 files changed, 140 insertions(+), 91 deletions(-)

diff --git a/paddle/fluid/eager/pylayer/py_layer_node.cc b/paddle/fluid/eager/pylayer/py_layer_node.cc
index 6fb78d20e8a8b..4c6b0f8dc64b3 100644
--- a/paddle/fluid/eager/pylayer/py_layer_node.cc
+++ b/paddle/fluid/eager/pylayer/py_layer_node.cc
@@ -155,7 +155,11 @@ GradNodePyLayer::operator()(
       if (ctx->forward_input_tensor_is_duplicable[i]) {
         grad_out.push_back(paddle::pybind::GetTensorListFromPyObject(obj));
       } else {
-        grad_out.push_back({paddle::pybind::GetTensorFromPyObject(obj)});
+        if (obj && obj != Py_None) {
+          grad_out.push_back({paddle::pybind::GetTensorFromPyObject(obj)});
+        } else {
+          grad_out.push_back({});
+        }
       }
     }
   } else {
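The py_layer_node.cc change above (reverted again in PATCH 04) guards the case where a PyLayer backward returns None for one of its inputs. A minimal sketch of that situation (the layer itself is a hypothetical example):

```python
import paddle
from paddle.autograd import PyLayer

class Double(PyLayer):
    @staticmethod
    def forward(ctx, x):
        return 2 * x

    @staticmethod
    def backward(ctx, dy):
        # Returning None instead of a Tensor is exactly the obj == Py_None
        # case that the patched branch turns into an empty grad_out slot.
        return None

x = paddle.ones([2, 2])
x.stop_gradient = False
y = Double.apply(x)
```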
diff --git a/paddle/fluid/platform/device/gpu/gpu_primitives.h b/paddle/fluid/platform/device/gpu/gpu_primitives.h
index b99d6de5dbbb4..96eddf09237d9 100644
--- a/paddle/fluid/platform/device/gpu/gpu_primitives.h
+++ b/paddle/fluid/platform/device/gpu/gpu_primitives.h
@@ -198,61 +198,6 @@ __device__ __forceinline__ void fastAtomicAdd(T *arr, T value) {
   CudaAtomicAdd(arr + index, value);
 }
-
-#ifdef PADDLE_WITH_CUDA
-/*
- * One thread block deals with elementwise atomicAdd for vector of len.
- * @in: [x1, x2, x3, ...]
- * @out:[y1+x1, y2+x2, y3+x3, ...]
- * */
-template <typename T,
-          typename std::enable_if<
-              !std::is_same<platform::float16, T>::value>::type * = nullptr>
-__device__ __forceinline__ void VectorizedAtomicAddPerBlock(
-    const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
-  for (int i = tid; i < len; i += threads_per_block) {
-    CudaAtomicAdd(&out[i], in[i]);
-  }
-}
-
-// Note: assume that len is even. If len is odd, call fastAtomicAdd directly.
-template <typename T,
-          typename std::enable_if<
-              std::is_same<platform::float16, T>::value>::type * = nullptr>
-__device__ __forceinline__ void VectorizedAtomicAddPerBlock(
-    const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
-#if ((CUDA_VERSION < 10000) || \
-     (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
-  for (int i = tid; i < len; i += threads_per_block) {
-    CudaAtomicAdd(&out[i], in[i]);
-  }
-#else
-  int i = 0;
-  int loops = len / 2 * 2;
-
-  bool aligned_half2 =
-      (reinterpret_cast<std::uintptr_t>(out) % sizeof(__half2) == 0);
-
-  if (aligned_half2) {
-    for (i = tid * 2; i < loops; i += threads_per_block * 2) {
-      __half2 value2;
-      T value_1 = in[i];
-      T value_2 = in[i + 1];
-      value2.x = *reinterpret_cast<__half *>(&value_1);
-      value2.y = *reinterpret_cast<__half *>(&value_2);
-      atomicAdd(reinterpret_cast<__half2 *>(&out[i]), value2);
-    }
-    for (; i < len; i += threads_per_block) {
-      fastAtomicAdd(out, i, len, in[i]);
-    }
-  } else {
-    for (int i = tid; i < len; i += threads_per_block) {
-      fastAtomicAdd(out, i, len, in[i]);
-    }
-  }
-#endif
-}
-#endif
 #endif
 
 // NOTE(zhangbo): cuda do not have atomicCAS for __nv_bfloat16.
@@ -601,5 +546,61 @@ CUDA_ATOMIC_WRAPPER(Min, float16) {
 }
 #endif
 
+#ifdef PADDLE_CUDA_FP16
+#ifdef PADDLE_WITH_CUDA
+/*
+ * One thread block deals with elementwise atomicAdd for vector of len.
+ * @in: [x1, x2, x3, ...]
+ * @out:[y1+x1, y2+x2, y3+x3, ...]
+ * */
+template <typename T,
+          typename std::enable_if<
+              !std::is_same<platform::float16, T>::value>::type * = nullptr>
+__device__ __forceinline__ void VectorizedAtomicAddPerBlock(
+    const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
+  for (int i = tid; i < len; i += threads_per_block) {
+    CudaAtomicAdd(&out[i], in[i]);
+  }
+}
+
+// Note: assume that len is even. If len is odd, call fastAtomicAdd directly.
+template <typename T,
+          typename std::enable_if<
+              std::is_same<platform::float16, T>::value>::type * = nullptr>
+__device__ __forceinline__ void VectorizedAtomicAddPerBlock(
+    const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
+#if ((CUDA_VERSION < 10000) || \
+     (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
+  for (int i = tid; i < len; i += threads_per_block) {
+    CudaAtomicAdd(&out[i], in[i]);
+  }
+#else
+  int i = 0;
+  int loops = len / 2 * 2;
+
+  bool aligned_half2 =
+      (reinterpret_cast<std::uintptr_t>(out) % sizeof(__half2) == 0);
+
+  if (aligned_half2) {
+    for (i = tid * 2; i < loops; i += threads_per_block * 2) {
+      __half2 value2;
+      T value_1 = in[i];
+      T value_2 = in[i + 1];
+      value2.x = *reinterpret_cast<__half *>(&value_1);
+      value2.y = *reinterpret_cast<__half *>(&value_2);
+      atomicAdd(reinterpret_cast<__half2 *>(&out[i]), value2);
+    }
+    for (; i < len; i += threads_per_block) {
+      fastAtomicAdd(out, i, len, in[i]);
+    }
+  } else {
+    for (int i = tid; i < len; i += threads_per_block) {
+      fastAtomicAdd(out, i, len, in[i]);
+    }
+  }
+#endif
+}
+#endif
+#endif
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/phi/kernels/empty_kernel.cc b/paddle/phi/kernels/empty_kernel.cc
index 2c969cc43d2f1..01b07c438a527 100644
--- a/paddle/phi/kernels/empty_kernel.cc
+++ b/paddle/phi/kernels/empty_kernel.cc
@@ -88,6 +88,7 @@ PD_REGISTER_KERNEL(empty,
                    int64_t,
                    bool,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
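With bfloat16 registered for the empty kernel, uninitialized bf16 tensors can be requested directly from Python. A one-line sketch (it assumes this build accepts the 'bfloat16' dtype string in paddle.empty):

```python
import paddle

x = paddle.empty([2, 3], dtype='bfloat16')  # allocates without initialization
print(x.dtype)  # paddle.bfloat16
```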
diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 51420c5ecb6dc..2af106ca38c48 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -2169,12 +2169,14 @@ struct CudaSeluFunctor : public BaseActivationFunctor<T> {
   }
 
   __device__ __forceinline__ T operator()(const T x) const {
-    T res = x;
-    if (res <= zero) {
+    using MT =
+        typename std::conditional<(sizeof(T) > sizeof(float)), T, float>::type;
+    MT res = static_cast<MT>(x);
+    if (x <= zero) {
       res = alpha * expf(res) - alpha;
     }
     res *= scale;
-    return res;
+    return static_cast<T>(res);
   }
 
  private:
diff --git a/paddle/phi/kernels/funcs/eigen/broadcast.cu b/paddle/phi/kernels/funcs/eigen/broadcast.cu
index 0b749f5c009a5..0c5a3408872c4 100644
--- a/paddle/phi/kernels/funcs/eigen/broadcast.cu
+++ b/paddle/phi/kernels/funcs/eigen/broadcast.cu
@@ -84,6 +84,7 @@ INSTANTIATION(EigenBroadcast, int);
 INSTANTIATION(EigenBroadcast, int64_t);
 INSTANTIATION(EigenBroadcastGrad, bool);
 INSTANTIATION(EigenBroadcastGrad, float);
+INSTANTIATION(EigenBroadcastGrad, dtype::bfloat16);
 INSTANTIATION(EigenBroadcastGrad, dtype::float16);
 INSTANTIATION(EigenBroadcastGrad, double);
 INSTANTIATION(EigenBroadcastGrad, dtype::complex<float>);
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index 53f727ec51a39..b947c70cb89d4 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -449,4 +449,5 @@ PD_REGISTER_KERNEL(pow_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index 0e9e754a99706..e57332c40756a 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -265,5 +265,12 @@ PD_REGISTER_KERNEL(pow,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
-PD_REGISTER_KERNEL(selu, GPU, ALL_LAYOUT, phi::SeluKernel, float, double) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+PD_REGISTER_KERNEL(selu,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SeluKernel,
+                   float,
+                   double,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/adam_kernel.cu b/paddle/phi/kernels/gpu/adam_kernel.cu
index b20e8610fefaf..b85e1dcdb8259 100644
--- a/paddle/phi/kernels/gpu/adam_kernel.cu
+++ b/paddle/phi/kernels/gpu/adam_kernel.cu
@@ -373,7 +373,8 @@ PD_REGISTER_KERNEL(adam,
                    phi::AdamDenseKernel,
                    float,
                    double,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   // Skip beta1_pow, beta2_pow, skip_update data transform
   kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
@@ -386,7 +387,8 @@ PD_REGISTER_KERNEL(merged_adam,
                    phi::MergedAdamKernel,
                    float,
                    double,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   // Skip beta1_pow, beta2_pow data transform
   kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
diff --git a/paddle/phi/kernels/gpu/clip_grad_kernel.cu b/paddle/phi/kernels/gpu/clip_grad_kernel.cu
index 4566e8468ec16..60d311a2555a0 100644
--- a/paddle/phi/kernels/gpu/clip_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/clip_grad_kernel.cu
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(clip_grad,
                    double,
                    int,
                    int64_t,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/gpu/clip_kernel.cu b/paddle/phi/kernels/gpu/clip_kernel.cu
index 9e0050db7fdbf..e8d519a5d3a2b 100644
--- a/paddle/phi/kernels/gpu/clip_kernel.cu
+++ b/paddle/phi/kernels/gpu/clip_kernel.cu
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(clip,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
index 
6694216214c31..e10d01ce9e4a5 100644 --- a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu @@ -256,7 +256,8 @@ PD_REGISTER_KERNEL(embedding_grad, phi::EmbeddingGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(embedding_sparse_grad, GPU, @@ -264,4 +265,5 @@ PD_REGISTER_KERNEL(embedding_sparse_grad, phi::EmbeddingSparseGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/embedding_kernel.cu b/paddle/phi/kernels/gpu/embedding_kernel.cu index 90f3cc8d36032..bb22fea5f6493 100644 --- a/paddle/phi/kernels/gpu/embedding_kernel.cu +++ b/paddle/phi/kernels/gpu/embedding_kernel.cu @@ -125,4 +125,5 @@ PD_REGISTER_KERNEL(embedding, phi::EmbeddingKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/gelu_grad_kernel.cu b/paddle/phi/kernels/gpu/gelu_grad_kernel.cu index 1f33d5c901f29..b1ffa921f912b 100644 --- a/paddle/phi/kernels/gpu/gelu_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/gelu_grad_kernel.cu @@ -99,4 +99,5 @@ PD_REGISTER_KERNEL(gelu_grad, phi::GeluGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/gelu_kernel.cu b/paddle/phi/kernels/gpu/gelu_kernel.cu index 509a5ccf4d177..e0792c387d751 100644 --- a/paddle/phi/kernels/gpu/gelu_kernel.cu +++ b/paddle/phi/kernels/gpu/gelu_kernel.cu @@ -93,4 +93,5 @@ PD_REGISTER_KERNEL(gelu, phi::GeluKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu b/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu index e9f820a318482..fb7f1a2325790 100644 --- a/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu @@ -509,4 +509,5 @@ PD_REGISTER_KERNEL(pad3d_grad, phi::Pad3dGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/pad3d_kernel.cu b/paddle/phi/kernels/gpu/pad3d_kernel.cu index d1b1d70667673..fa85c650bc854 100644 --- a/paddle/phi/kernels/gpu/pad3d_kernel.cu +++ b/paddle/phi/kernels/gpu/pad3d_kernel.cu @@ -583,6 +583,7 @@ PD_REGISTER_KERNEL(pad3d, ALL_LAYOUT, phi::Pad3dKernel, phi::dtype::float16, + phi::dtype::bfloat16, float, double, int, diff --git a/paddle/phi/kernels/gpu/pixel_shuffle_grad_kernel.cu b/paddle/phi/kernels/gpu/pixel_shuffle_grad_kernel.cu index 1414fb9df0b41..5c88bbbf42532 100644 --- a/paddle/phi/kernels/gpu/pixel_shuffle_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/pixel_shuffle_grad_kernel.cu @@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(pixel_shuffle_grad, ALL_LAYOUT, phi::PixelShuffleGradKernel, float, - double) {} + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/pixel_shuffle_kernel.cu b/paddle/phi/kernels/gpu/pixel_shuffle_kernel.cu index e43d6f961236a..09eb0485a297f 100644 --- a/paddle/phi/kernels/gpu/pixel_shuffle_kernel.cu +++ b/paddle/phi/kernels/gpu/pixel_shuffle_kernel.cu @@ -18,5 +18,11 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/pixel_shuffle_kernel_impl.h" -PD_REGISTER_KERNEL( - pixel_shuffle, GPU, ALL_LAYOUT, phi::PixelShuffleKernel, float, double) {} +PD_REGISTER_KERNEL(pixel_shuffle, + GPU, + ALL_LAYOUT, + phi::PixelShuffleKernel, + float, + double, + 
phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/tile_grad_kernel.cu b/paddle/phi/kernels/gpu/tile_grad_kernel.cu
index c092609e623d3..d1e356df401a8 100644
--- a/paddle/phi/kernels/gpu/tile_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/tile_grad_kernel.cu
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(tile_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/where_grad_kernel.cu b/paddle/phi/kernels/gpu/where_grad_kernel.cu
index 709dddcb82c7e..8451e59d7c8f7 100644
--- a/paddle/phi/kernels/gpu/where_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/where_grad_kernel.cu
@@ -25,10 +25,10 @@ __global__ void WhereGradCUDAKernel(
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
   for (; idx < N; idx += blockDim.x * gridDim.x) {
     if (dx != nullptr) {
-      dx[idx] = cond[idx] ? dout[idx] : 0.;
+      dx[idx] = cond[idx] ? dout[idx] : static_cast<T>(0.);
     }
     if (dy != nullptr) {
-      dy[idx] = cond[idx] ? 0. : dout[idx];
+      dy[idx] = cond[idx] ? static_cast<T>(0.) : dout[idx];
     }
   }
 }
@@ -61,6 +61,7 @@ PD_REGISTER_KERNEL(where_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::WhereGradKernel,
+                   phi::dtype::bfloat16,
                    float,
                    double,
                    int,
diff --git a/paddle/phi/kernels/gpu/where_kernel.cu b/paddle/phi/kernels/gpu/where_kernel.cu
index 441be02b99efa..a0d8272ea20e0 100644
--- a/paddle/phi/kernels/gpu/where_kernel.cu
+++ b/paddle/phi/kernels/gpu/where_kernel.cu
@@ -45,5 +45,12 @@ void WhereKernel(const Context& ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    where, GPU, ALL_LAYOUT, phi::WhereKernel, float, double, int, int64_t) {}
+PD_REGISTER_KERNEL(where,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::WhereKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::bfloat16) {}
diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py
index f83947bf6cd07..31f676b33e321 100644
--- a/python/paddle/fluid/clip.py
+++ b/python/paddle/fluid/clip.py
@@ -52,8 +52,9 @@ def _clip_by_global_norm_using_mp_type(*args):
 
 
 def _cast_to_mp_type_if_enabled(x):
-    if x.dtype == core.VarDesc.VarType.FP16 and _clip_by_global_norm_using_mp_type(
-    ):
+    if (x.dtype == core.VarDesc.VarType.FP16
+            or x.dtype == core.VarDesc.VarType.BF16
+        ) and _clip_by_global_norm_using_mp_type():
         return x.astype(core.VarDesc.VarType.FP32)
     else:
         return x
@@ -65,7 +66,8 @@ def _squared_l2_norm(x):
     """
     x = _cast_to_mp_type_if_enabled(x)
 
-    if core.is_compiled_with_xpu() or x.dtype == core.VarDesc.VarType.FP16:
+    if core.is_compiled_with_xpu(
+    ) or x.dtype == core.VarDesc.VarType.FP16 or x.dtype == core.VarDesc.VarType.BF16:
         square = layers.square(x)
         sum_square = layers.reduce_sum(square)
         return sum_square
@@ -501,7 +503,7 @@ def _dygraph_clip(self, params_grads):
                     merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
 
                 sum_square = _squared_l2_norm(merge_grad)
-                if sum_square.dtype == core.VarDesc.VarType.FP16:
+                if sum_square.dtype == core.VarDesc.VarType.FP16 or sum_square.dtype == core.VarDesc.VarType.BF16:
                     sum_square_list_fp16.append(sum_square)
                 elif sum_square.dtype == core.VarDesc.VarType.FP32:
                     sum_square_list_fp32.append(sum_square)
@@ -554,8 +556,8 @@ def _dygraph_clip(self, params_grads):
                 continue
             # TODO(wangxi): use inplace elementwise_mul if need_clip
             if need_clip:
-                clip_input = (clip_var.astype('float16') if g.dtype
-                              == core.VarDesc.VarType.FP16 else clip_var)
+                clip_input = (clip_var.astype(g.dtype)
+                              if clip_var.dtype != g.dtype else clip_var)
                 new_grad = layers.elementwise_mul(g, clip_input)
                 params_and_grads.append((p, new_grad))
             else:
diff --git 
a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index d70c976858b96..39d95fbbfc91b 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -91,10 +91,10 @@ } BF16_WHITE_LIST = {'conv2d', 'matmul_v2'} -BF16_BLACK_LIST = {''} +BF16_BLACK_LIST = set() -PURE_BF16_WHITE_LIST = {''} -PURE_BF16_BLACK_LIST = {''} +PURE_BF16_WHITE_LIST = set() +PURE_BF16_BLACK_LIST = set() _g_amp_state_ = None diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index 26d082690b7b7..4b34b3cc1a74d 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -275,7 +275,7 @@ def _get_accumulator(self, name, param): def _add_moments_pows(self, p): acc_dtype = p.dtype - if acc_dtype == core.VarDesc.VarType.FP16: + if acc_dtype == core.VarDesc.VarType.FP16 or acc_dtype == core.VarDesc.VarType.BF16: acc_dtype = core.VarDesc.VarType.FP32 self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype) self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index b5946459d344c..144620f3c6ea4 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -159,8 +159,10 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None): u = mean(x, axis, True, name) out = paddle.sum((x - u)**2, axis, keepdim=keepdim, name=name) - n = paddle.cast(paddle.numel(x), x.dtype) \ - / paddle.cast(paddle.numel(out), x.dtype) + dtype = x.dtype + n = paddle.cast(paddle.numel(x), paddle.int64) \ + / paddle.cast(paddle.numel(out), paddle.int64) + n = n.astype(dtype) if unbiased: one_const = paddle.ones([1], x.dtype) n = where(n > one_const, n - 1., one_const) From b420a328c6c8c6a2d1e05b829708eeae92b6dea3 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Wed, 21 Sep 2022 18:33:23 +0800 Subject: [PATCH 03/10] support bf16 linear --- .../fluid/operators/fused/fused_gemm_epilogue_op.cu | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu index 22340210b5715..07870faff73e3 100644 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu @@ -17,6 +17,7 @@ limitations under the License. 
*/
 
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/framework/scope_guard.h"
+#include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/dynload/cublasLt.h"
 #include "paddle/fluid/platform/float16.h"
 
@@ -62,6 +63,9 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
     if (std::is_same<T, platform::float16>::value) {
       mat_type = CUDA_R_16F;
     }
+    if (std::is_same<T, platform::bfloat16>::value) {
+      mat_type = CUDA_R_16BF;
+    }
     if (std::is_same<T, double>::value) {
       mat_type = CUDA_R_64F;
       scale_type = CUDA_R_64F;
@@ -352,6 +356,9 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
     if (std::is_same<T, platform::float16>::value) {
       mat_type = CUDA_R_16F;
     }
+    if (std::is_same<T, platform::bfloat16>::value) {
+      mat_type = CUDA_R_16BF;
+    }
     if (std::is_same<T, double>::value) {
       mat_type = CUDA_R_64F;
       scale_type = CUDA_R_64F;
@@ -686,12 +693,14 @@
 REGISTER_OP_CUDA_KERNEL(
     fused_gemm_epilogue,
     ops::FusedGemmEpilogueKernel<phi::GPUContext, float>,
     ops::FusedGemmEpilogueKernel<phi::GPUContext, double>,
-    ops::FusedGemmEpilogueKernel<phi::GPUContext, paddle::platform::float16>);
+    ops::FusedGemmEpilogueKernel<phi::GPUContext, paddle::platform::float16>,
+    ops::FusedGemmEpilogueKernel<phi::GPUContext, paddle::platform::bfloat16>);
 REGISTER_OP_CUDA_KERNEL(
     fused_gemm_epilogue_grad,
     ops::FusedGemmEpilogueGradKernel<phi::GPUContext, float>,
     ops::FusedGemmEpilogueGradKernel<phi::GPUContext, double>,
     ops::FusedGemmEpilogueGradKernel<phi::GPUContext,
-                                     paddle::platform::float16>);
+                                     paddle::platform::float16>,
+    ops::FusedGemmEpilogueGradKernel<phi::GPUContext, paddle::platform::bfloat16>);
 #endif

From 7ff13881509ee3188056b722ef64866f4fb37d3d Mon Sep 17 00:00:00 2001
From: sneaxiy
Date: Tue, 27 Sep 2022 16:44:04 +0800
Subject: [PATCH 04/10] update PR to pass CI

---
 paddle/fluid/eager/pylayer/py_layer_node.cc | 6 +-----
 paddle/phi/kernels/gpu/where_kernel.cu      | 1 +
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/eager/pylayer/py_layer_node.cc b/paddle/fluid/eager/pylayer/py_layer_node.cc
index 4c6b0f8dc64b3..6fb78d20e8a8b 100644
--- a/paddle/fluid/eager/pylayer/py_layer_node.cc
+++ b/paddle/fluid/eager/pylayer/py_layer_node.cc
@@ -155,11 +155,7 @@ GradNodePyLayer::operator()(
       if (ctx->forward_input_tensor_is_duplicable[i]) {
         grad_out.push_back(paddle::pybind::GetTensorListFromPyObject(obj));
       } else {
-        if (obj && obj != Py_None) {
-          grad_out.push_back({paddle::pybind::GetTensorFromPyObject(obj)});
-        } else {
-          grad_out.push_back({});
-        }
+        grad_out.push_back({paddle::pybind::GetTensorFromPyObject(obj)});
       }
     }
   } else {
diff --git a/paddle/phi/kernels/gpu/where_kernel.cu b/paddle/phi/kernels/gpu/where_kernel.cu
index a0d8272ea20e0..09a974fbc2340 100644
--- a/paddle/phi/kernels/gpu/where_kernel.cu
+++ b/paddle/phi/kernels/gpu/where_kernel.cu
@@ -53,4 +53,5 @@ PD_REGISTER_KERNEL(where,
                    double,
                    int,
                    int64_t,
+                   phi::dtype::float16,
                    phi::dtype::bfloat16) {}

From b9a7c14a85bfac18fdef3e3138f217dd200edf17 Mon Sep 17 00:00:00 2001
From: sneaxiy
Date: Tue, 27 Sep 2022 16:46:41 +0800
Subject: [PATCH 05/10] tiny fix where_grad_kernel.cu

---
 paddle/phi/kernels/gpu/where_grad_kernel.cu | 1 +
 1 file changed, 1 insertion(+)

diff --git a/paddle/phi/kernels/gpu/where_grad_kernel.cu b/paddle/phi/kernels/gpu/where_grad_kernel.cu
index 8451e59d7c8f7..4c411bfb9cd5a 100644
--- a/paddle/phi/kernels/gpu/where_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/where_grad_kernel.cu
@@ -61,6 +61,7 @@ PD_REGISTER_KERNEL(where_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::WhereGradKernel,
+                   phi::dtype::float16,
                    phi::dtype::bfloat16,
                    float,
                    double,
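PATCH 04 and PATCH 05 restore float16 alongside bfloat16 for where and where_grad. A small dygraph sketch of the path they enable (values are illustrative):

```python
import paddle

cond = paddle.to_tensor([True, False, True, False])
x = paddle.rand([4]).astype('float16')
y = paddle.rand([4]).astype('float16')
x.stop_gradient = False
y.stop_gradient = False
out = paddle.where(cond, x, y)
# The grad kernel now writes static_cast<T>(0.) rather than a double
# literal, so backward works for float16/bfloat16 element types too.
out.sum().backward()
```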
From 817d7ee069e610236ccd282ecdb3cadf374d6b49 Mon Sep 17 00:00:00 2001
From: GhostScreaming
Date: Wed, 28 Sep 2022 02:35:34 +0000
Subject: [PATCH 06/10] Support bfloat16 type for reducer and sharding.

---
 paddle/fluid/distributed/collective/reducer.cc             | 8 ++++++++
 paddle/phi/kernels/cpu/fill_kernel.cc                      | 1 +
 paddle/phi/kernels/gpu/fill_kernel.cu                      | 1 +
 .../sharding/group_sharded_optimizer_stage2.py             | 1 +
 .../fleet/meta_parallel/sharding/group_sharded_stage2.py   | 6 ++++++
 .../fleet/meta_parallel/sharding/group_sharded_storage.py  | 2 ++
 .../fleet/meta_parallel/sharding/sharding_utils.py         | 1 +
 7 files changed, 20 insertions(+)

diff --git a/paddle/fluid/distributed/collective/reducer.cc b/paddle/fluid/distributed/collective/reducer.cc
index 75a16bac37130..0d46425b2e832 100644
--- a/paddle/fluid/distributed/collective/reducer.cc
+++ b/paddle/fluid/distributed/collective/reducer.cc
@@ -254,6 +254,10 @@ static void ConcatTensorsWithType(
       ConcatTensorsForAllReduce<DeviceContext, double>()(
           context, dense_tensors_, p_dense_contents);
       break;
+    case phi::DataType::BFLOAT16:
+      ConcatTensorsForAllReduce<DeviceContext, platform::bfloat16>()(
+          context, dense_tensors_, p_dense_contents);
+      break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
           "Data type (%s) is not supported when it concats tensors for "
@@ -281,6 +285,10 @@ static void SplitTensorsWithType(const DeviceContext &context,
       SplitTensorsForAllReduce<DeviceContext, double>()(
          context, p_dense_contents, p_dense_tensors);
       break;
+    case phi::DataType::BFLOAT16:
+      SplitTensorsForAllReduce<DeviceContext, platform::bfloat16>()(
+          context, p_dense_contents, p_dense_tensors);
+      break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
           "Data type (%s) is not supported when it splits tensors for "
diff --git a/paddle/phi/kernels/cpu/fill_kernel.cc b/paddle/phi/kernels/cpu/fill_kernel.cc
index ee8dac7f6770c..adca39e6ab95d 100644
--- a/paddle/phi/kernels/cpu/fill_kernel.cc
+++ b/paddle/phi/kernels/cpu/fill_kernel.cc
@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(fill,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/paddle/phi/kernels/gpu/fill_kernel.cu b/paddle/phi/kernels/gpu/fill_kernel.cu
index 141e47b8cb109..3fedb4118ff9e 100644
--- a/paddle/phi/kernels/gpu/fill_kernel.cu
+++ b/paddle/phi/kernels/gpu/fill_kernel.cu
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(fill,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index ec5a832aa1c47..ce3b122be966e 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -41,6 +41,7 @@
 alignment = {"gpu": 256, "cpu": 4096}
 align = {
     Type.fp16.value: 2,
+    Type.bf16.value: 2,
     Type.fp32.value: 4,
 }
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
index 5c0e61cd8be0e..46dd3ae99995f 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
@@ -514,6 +514,12 @@ def _rank_buffer_size(self, buffer_max_size, model_size):
                 "====== FP16 GradStorage size: {:.2f}M parameters, Model size {:.2f}M parameters ======"
                 .format(rank_buffer_size[Type.fp16.value] / 2**19,
                         model_size / 2**19))
+        if Type.bf16.value in rank_buffer_size.keys():
+            # BF16 GradStorage and model size
+            logger_.info(
+                "====== BF16 GradStorage size: {:.2f}M parameters, Model size {:.2f}M parameters ======"
+                .format(rank_buffer_size[Type.bf16.value] / 2**19,
+                        model_size / 2**19))
         if Type.fp32.value in rank_buffer_size.keys():
             # FP32 GradStorage and model size
             logger_.info(
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
index c44872491093e..71d51bafbf00a 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
@@ -53,6 +53,8 @@ def __init__(self, size, dtype, device, convert_cpu=False):
                 dtype=np.float16) if Type.fp16.value == dtype else np.zeros(
                     size, dtype=np.float32)
             self.buffer = core.eager.Tensor(value=value, place=core.CPUPlace())
+            if dtype == Type.bf16.value:
+                self.buffer = paddle.cast(self.buffer, dtype = paddle.bfloat16)
         else:
             self.buffer = paddle.zeros(size, dtype=dtype)
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
index d21502bcc16b8..42f43ce537748 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
@@ -45,6 +45,7 @@ class Type(Enum):
     Type of trainable parameters
     """
     fp16 = paddle.float16
+    bf16 = paddle.bfloat16
     fp32 = paddle.float32
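A sketch of how PATCH 06 is consumed: stage-2 sharding over a pure-bf16 model. API names are as of paddle 2.3/2.4, and both the dtype argument of paddle.amp.decorate and a multi-process launch via paddle.distributed.launch are assumptions here:

```python
import paddle
from paddle.distributed import init_parallel_env
from paddle.distributed.sharding import group_sharded_parallel

init_parallel_env()
model = paddle.nn.Linear(16, 16)
# Cast parameters to bfloat16; level='O2' keeps fp32 master weights.
model = paddle.amp.decorate(models=model, level='O2', dtype='bfloat16')
opt = paddle.optimizer.Adam(parameters=model.parameters())
# 'os_g' shards optimizer state plus gradients (stage 2); the bf16 bucket
# concat/split added to reducer.cc runs underneath this call.
model, opt, scaler = group_sharded_parallel(model, opt, level='os_g')
```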
From 44abf06a9db978efc2fe97dd6bc09a89548f85be Mon Sep 17 00:00:00 2001
From: GhostScreaming
Date: Wed, 28 Sep 2022 03:53:46 +0000
Subject: [PATCH 07/10] Fix some bug.

---
 .../fleet/meta_parallel/sharding/group_sharded_utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
index 8cff407363a3b..7eb7b1e8784aa 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
@@ -41,6 +41,7 @@ class Type(Enum):
     Type of trainable parameters
     """
     fp16 = paddle.float16
+    bf16 = paddle.bfloat16
     fp32 = paddle.float32

From 384d4971ea097fd9da816c072712ca636ead9e53 Mon Sep 17 00:00:00 2001
From: GhostScreaming
Date: Fri, 14 Oct 2022 11:28:15 +0000
Subject: [PATCH 08/10] Polish code.

---
 .../fleet/meta_parallel/sharding/group_sharded_storage.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
index 71d51bafbf00a..5b9ab7343f08c 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
@@ -54,7 +54,7 @@ def __init__(self, size, dtype, device, convert_cpu=False):
                     size, dtype=np.float32)
             self.buffer = core.eager.Tensor(value=value, place=core.CPUPlace())
             if dtype == Type.bf16.value:
-                self.buffer = paddle.cast(self.buffer, dtype = paddle.bfloat16)
+                self.buffer = paddle.cast(self.buffer, dtype=paddle.bfloat16)
         else:
             self.buffer = paddle.zeros(size, dtype=dtype)
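The storage line PATCH 08 reformats exists because numpy has no native bfloat16 dtype: the CPU buffer is first materialized in fp32 and then cast. The same trick in isolation (buffer size is illustrative):

```python
import numpy as np
import paddle

value = np.zeros(1024, dtype=np.float32)  # numpy cannot allocate bf16
buffer = paddle.to_tensor(value, place=paddle.CPUPlace())
buffer = paddle.cast(buffer, dtype=paddle.bfloat16)  # cast on CPU instead
print(buffer.dtype)  # paddle.bfloat16
```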
From d012390cad7935031e92a3b09e1a05ed47d616c0 Mon Sep 17 00:00:00 2001
From: GhostScreaming
Date: Sun, 16 Oct 2022 14:55:19 +0000
Subject: [PATCH 09/10] Polish code.

---
 paddle/phi/kernels/kps/reduce_sum_kernel.cu | 1 -
 1 file changed, 1 deletion(-)

diff --git a/paddle/phi/kernels/kps/reduce_sum_kernel.cu b/paddle/phi/kernels/kps/reduce_sum_kernel.cu
index 4fef44551a0a5..c5a30a6a634a8 100644
--- a/paddle/phi/kernels/kps/reduce_sum_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_sum_kernel.cu
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <climits>
 #include "paddle/phi/kernels/reduce_sum_kernel.h"
 #include <climits>
 #include "paddle/phi/core/enforce.h"

From 480c7326c5376f11395edb0a44507a81fe641212 Mon Sep 17 00:00:00 2001
From: GhostScreaming
Date: Mon, 17 Oct 2022 03:13:57 +0000
Subject: [PATCH 10/10] Add bfloat16 datatype in fill_grad kernels.

---
 paddle/phi/kernels/cpu/fill_grad_kernel.cc | 1 +
 paddle/phi/kernels/gpu/fill_grad_kernel.cu | 1 +
 2 files changed, 2 insertions(+)

diff --git a/paddle/phi/kernels/cpu/fill_grad_kernel.cc b/paddle/phi/kernels/cpu/fill_grad_kernel.cc
index ee676773762ca..07448c85a57d6 100644
--- a/paddle/phi/kernels/cpu/fill_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/fill_grad_kernel.cc
@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(fill_grad,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
diff --git a/paddle/phi/kernels/gpu/fill_grad_kernel.cu b/paddle/phi/kernels/gpu/fill_grad_kernel.cu
index 32559ba95dfbc..e18bb5c6dbb24 100644
--- a/paddle/phi/kernels/gpu/fill_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/fill_grad_kernel.cu
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(fill_grad,
                    int64_t,
                    int,
                    paddle::platform::float16,
+                   paddle::platform::bfloat16,
                    bool) {}
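Taken together, the series makes a pure-bf16 dygraph training step like the following possible. This is a sketch, not part of the series: the dtype='bfloat16' arguments to paddle.amp.decorate/auto_cast and an Ampere-or-newer GPU are assumed:

```python
import paddle

model = paddle.nn.Linear(16, 16)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)  # bf16 path from clip.py
opt = paddle.optimizer.Adam(parameters=model.parameters(), grad_clip=clip)
model, opt = paddle.amp.decorate(
    models=model, optimizers=opt, level='O2', dtype='bfloat16')

x = paddle.randn([8, 16])
with paddle.amp.auto_cast(level='O2', dtype='bfloat16'):
    loss = model(x).mean()
loss.backward()  # bf16 kernels registered in PATCH 02 run here
opt.step()       # bf16 Adam keeps fp32 moments (adam.py change)
opt.clear_grad()
```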