Support FP16 mean (#38289)
* mean first version

* fix scalar mean

* add fp16 dtype for api
sneaxiy committed Dec 21, 2021
1 parent c197d73 commit 643a268
Showing 10 changed files with 205 additions and 63 deletions.
15 changes: 11 additions & 4 deletions paddle/fluid/operators/kernel_primitives/functor_primitives.h
@@ -14,7 +14,10 @@
 
 #pragma once
 
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/platform/eigen_ext.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
@@ -74,16 +77,20 @@ struct IdentityFunctor {
  */
 template <typename Tx, typename Ty = Tx>
 struct DivideFunctor {
-  HOSTDEVICE inline DivideFunctor() { n_inv = static_cast<Tx>(1.0f); }
+ private:
+  using MPType = typename ::paddle::operators::details::MPTypeTrait<Tx>::Type;
+
+ public:
+  HOSTDEVICE inline DivideFunctor() { n_inv = static_cast<MPType>(1.0f); }
 
-  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((Tx)(1.0 / n)) {}
+  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((MPType)(1.0 / n)) {}
 
   HOSTDEVICE inline Ty operator()(const Tx& x) const {
-    return static_cast<Ty>(x * n_inv);
+    return static_cast<Ty>(static_cast<MPType>(x) * n_inv);
   }
 
  private:
-  Tx n_inv;
+  MPType n_inv;
 };
 
 /**
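The key ingredient above is `MPTypeTrait`, which picks a higher-precision compute type for the functor. A minimal sketch of the idea, assuming the trait simply maps `float16` to `float` and is an identity for every other type (the names below are illustrative stand-ins, not Paddle's headers):

```cpp
#include <cstdint>

// Illustrative stand-in for platform::float16 (a storage-only half type).
struct float16 {
  uint16_t bits;
};

// Sketch of the MPTypeTrait idea: select the "multi-precision" compute type.
template <typename T>
struct MPTypeTrait {
  using Type = T;  // default: compute in the input type itself
};

template <>
struct MPTypeTrait<float16> {
  using Type = float;  // fp16 data is accumulated and divided in fp32
};
```

With this, `DivideFunctor<float16>` stores `n_inv` as `float`, so `x * n_inv` is evaluated in fp32 and rounded back to `Ty` only once, at the end.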
53 changes: 24 additions & 29 deletions paddle/fluid/operators/mean_op.cu
@@ -18,30 +18,23 @@ limitations under the License. */
 #include <hipcub/hipcub.hpp>
 namespace cub = hipcub;
 #endif
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
+#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
 #include "paddle/fluid/operators/mean_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
 
-template <typename T>
-struct DivideFunctor {
-  HOSTDEVICE explicit inline DivideFunctor(int n)
-      : n_inv(static_cast<T>(1.0 / n)) {}
-
-  HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
-
- private:
-  T n_inv;
-};
-
 template <typename T>
 __global__ void MeanRunKernel(const T* in_data, T* out_data, int N) {
+  using MT = typename details::MPTypeTrait<T>::Type;
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
-  T data = in_data[0];
+  auto data = static_cast<MT>(in_data[0]);
   for (; idx < N; idx += blockDim.x * gridDim.x) {
-    out_data[idx] = data / (static_cast<T>(N));
+    out_data[idx] = static_cast<T>(data / (static_cast<MT>(N)));
   }
 }
 
@@ -52,27 +45,29 @@ class MeanCUDAKernel : public framework::OpKernel<T> {
     auto* input = context.Input<Tensor>("X");
     auto* output = context.Output<Tensor>("Out");
 
-    output->mutable_data<T>(context.GetPlace());
-    auto size_prob = input->numel();
     const T* in_data = input->data<T>();
     T* out_data = output->mutable_data<T>(context.GetPlace());
+    auto numel = input->numel();
+    auto rank = input->dims().size();
+    auto place = context.GetPlace();
    auto stream = context.cuda_device_context().stream();
 
-    DivideFunctor<T> transformer(size_prob);
-    cub::TransformInputIterator<T, DivideFunctor<T>, const T*> trans_x(
-        in_data, transformer);
-    size_t temp_storage_bytes = 0;
+    if (rank == 0) {  // scalar
+      auto gpu_place = BOOST_GET(platform::CUDAPlace, place);
+      memory::Copy(gpu_place, out_data, gpu_place, in_data, numel * sizeof(T),
+                   stream);
+      return;
+    }
 
-    auto err = cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, trans_x,
-                                      out_data, size_prob, stream);
-    PADDLE_ENFORCE_GPU_SUCCESS(err);
-    framework::Tensor tmp;
-    auto* temp_storage = tmp.mutable_data<uint8_t>(
-        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
-        context.GetPlace());
-    err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes, trans_x,
-                                 out_data, size_prob, stream);
-    PADDLE_ENFORCE_GPU_SUCCESS(err);
+    using MT = typename details::MPTypeTrait<T>::Type;
+    using Div = kernel_primitives::DivideFunctor<T, MT>;
+    std::vector<int> reduce_dims;
+    reduce_dims.reserve(rank);
+    for (decltype(rank) i = 0; i < rank; ++i) {
+      reduce_dims.push_back(i);
+    }
+    TensorReduceFunctorImpl<T, T, kernel_primitives::AddFunctor, Div>(
+        *input, output, Div(numel), reduce_dims, stream);
   }
 };
 
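The rewritten kernel no longer drives `cub::DeviceReduce::Sum` directly: rank-0 (scalar) inputs are copied through unchanged, and everything else delegates to `TensorReduceFunctorImpl` over all axes with the shared `DivideFunctor`. A host-side sketch of the numeric recipe, assuming `MPType` is `float` for fp16 inputs (illustrative code, not the CUDA kernel itself):

```cpp
#include <cstddef>
#include <vector>

// mean(x) as a transform-then-reduce: each element is cast up to MPType,
// scaled by 1/N, and accumulated; only the final result rounds back to T.
template <typename T, typename MPType>
T MeanReduce(const std::vector<T>& x) {
  const MPType n_inv =
      static_cast<MPType>(1) / static_cast<MPType>(x.size());
  MPType acc = static_cast<MPType>(0);
  for (const T& v : x) {
    acc += static_cast<MPType>(v) * n_inv;  // DivideFunctor + AddFunctor
  }
  return static_cast<T>(acc);
}
```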
2 changes: 1 addition & 1 deletion paddle/fluid/operators/reduce_ops/reduce_functor_op.h
@@ -77,7 +77,7 @@ struct CustomSub {
 
 template <typename Tx, typename Ty = Tx>
 struct CustomMean {
-  using Transformer = kps::DivideFunctor<Tx>;
+  using Transformer = kps::DivideFunctor<Tx, Ty>;
 
   inline Ty initial() { return static_cast<Ty>(0.0f); }
 
2 changes: 2 additions & 0 deletions paddle/fluid/operators/reduce_ops/reduce_mean_op.cu
@@ -19,5 +19,7 @@
 REGISTER_OP_CUDA_KERNEL(
     reduce_mean,
     ops::ReduceCudaKernel<bool, kps::AddFunctor, kps::DivideFunctor>,
+    ops::ReduceCudaKernel<paddle::platform::float16, kps::AddFunctor,
+                          kps::DivideFunctor>,
     ops::ReduceCudaKernel<float, kps::AddFunctor, kps::DivideFunctor>,
     ops::ReduceCudaKernel<double, kps::AddFunctor, kps::DivideFunctor>);
13 changes: 13 additions & 0 deletions paddle/fluid/operators/reduce_ops/reduce_mean_op.h
@@ -35,5 +35,18 @@ struct MeanGradFunctor {
   }
 };
 
+// TODO(zengjinle): Should refine the numeric stability of FP16 reduce_mean
+// and reduce_mean_grad later.
+struct FP16MeanGradFunctor {
+  template <typename DeviceContext, typename X, typename Y, typename DX,
+            typename DY, typename Dim>
+  void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
+                  const Dim& dim, int size) {
+    dx->device(place) = (dy->template cast<float>().broadcast(dim) /
+                         dx->template cast<float>().constant(size))
+                            .template cast<platform::float16>();
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
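`FP16MeanGradFunctor` casts to fp32 before dividing by `size`, since `size` can easily be large enough that the quotient loses most of its bits in half precision. In scalar form the computation is simply the following (a sketch only; `float` stands in for the Eigen expressions, and `MeanGrad` is a hypothetical helper):

```cpp
#include <vector>

// Gradient of mean: every input element receives dy / size. Computing the
// quotient in fp32 and casting once avoids an extra fp16 rounding step.
template <typename T>
std::vector<T> MeanGrad(T dy, int size) {
  const float g = static_cast<float>(dy) / static_cast<float>(size);
  return std::vector<T>(static_cast<size_t>(size), static_cast<T>(g));
}
```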
6 changes: 6 additions & 0 deletions paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu
@@ -20,6 +20,12 @@ using CUDAReduceMeanGradKernel =
     ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
                           ops::MeanGradFunctor, true>;
 
+using FP16CUDAReduceMeanGradKernel =
+    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
+                          paddle::platform::float16, ops::FP16MeanGradFunctor,
+                          true>;
+
 REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel<bool>,
+                        FP16CUDAReduceMeanGradKernel,
                         CUDAReduceMeanGradKernel<float>,
                         CUDAReduceMeanGradKernel<double>);
57 changes: 38 additions & 19 deletions paddle/fluid/operators/reduce_ops/reduce_op.cu.h
@@ -38,7 +38,9 @@ namespace cub = hipcub;
 #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
+#include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/fast_divmod.h"
+#include "paddle/fluid/string/string_helper.h"
 
 // Reduce split or not, Whether to use ReduceHigherDim
 #define REDUCE_SPLIT_BOUNDARY 512
@@ -814,11 +816,42 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
   }
 }
 
+template <typename Tx, typename Ty, template <typename> class ReduceOp,
+          typename TransformOp>
+static typename std::enable_if<!std::is_same<Tx, platform::float16>::value,
+                               void>::type
+CubTensorReduceFunctorImpl(const Tx* x_data, Ty* y_data,
+                           const TransformOp& transform, int reduce_num,
+                           const platform::Place& place, gpuStream_t stream) {
+  auto reducer = ReduceOp<Ty>();
+  cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
+                                                                  transform);
+  size_t temp_storage_bytes = 0;
+  cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
+                            reduce_num, reducer, reducer.initial(), stream);
+  framework::Tensor tmp;
+  auto* temp_storage = tmp.mutable_data<uint8_t>(
+      framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}), place);
+  cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
+                            reduce_num, reducer, reducer.initial(), stream);
+}
+
+template <typename Tx, typename Ty, template <typename> class ReduceOp,
+          typename TransformOp>
+static typename std::enable_if<std::is_same<Tx, platform::float16>::value,
+                               void>::type
+CubTensorReduceFunctorImpl(const Tx* x_data, Ty* y_data,
+                           const TransformOp& transform, int reduce_num,
+                           const platform::Place& place, gpuStream_t stream) {
+  PADDLE_THROW(platform::errors::InvalidArgument(
+      "Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
+}
+
 template <typename Tx, typename Ty, template <typename> class ReduceOp,
           typename TransformOp>
 void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
                              const TransformOp& transform,
-                             std::vector<int> origin_reduce_dims,
+                             const std::vector<int>& origin_reduce_dims,
                              gpuStream_t stream) {
   auto x_dim = framework::vectorize<int>(x.dims());
   auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
@@ -848,25 +881,11 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
   }
 
   config.SetOutputData(y_data, x.place(), &tmp);
-  bool use_cub_reduce = (config.reduce_num == numel) &&
-                        (!std::is_same<Tx, paddle::platform::float16>::value);
+  constexpr bool kIsTxFP16 = std::is_same<Tx, paddle::platform::float16>::value;
+  bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16;
   if (use_cub_reduce) {
-    // launch CUB::Reduce
-    auto reducer = ReduceOp<Ty>();
-    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
-                                                                    transform);
-    size_t temp_storage_bytes = 0;
-    cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
-                              config.reduce_num, reducer, reducer.initial(),
-                              stream);
-    framework::Tensor tmp;
-    auto* temp_storage = tmp.mutable_data<uint8_t>(
-        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
-        x.place());
-    cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
-                              config.reduce_num, reducer, reducer.initial(),
-                              stream);
-
+    CubTensorReduceFunctorImpl<Tx, Ty, ReduceOp, TransformOp>(
+        x_data, y_data, transform, config.reduce_num, x.place(), stream);
     return;
   }
 
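Hoisting the CUB call into `CubTensorReduceFunctorImpl` and splitting it with `std::enable_if` keeps the fp16 instantiation of `cub::DeviceReduce::Reduce` from ever being compiled: the runtime check `!kIsTxFP16` alone would still instantiate both branches of the `if`. A minimal sketch of the dispatch pattern, using illustrative stand-in types rather than Paddle's:

```cpp
#include <iostream>
#include <type_traits>

struct float16 {};  // stand-in for platform::float16

// Selected for every type except float16: the CUB fast path.
template <typename T>
typename std::enable_if<!std::is_same<T, float16>::value, void>::type
ReduceAll(const T* /*data*/) {
  std::cout << "cub::DeviceReduce path\n";
}

// Selected only for float16: never taken at runtime, but required so that
// only the overload above ever instantiates CUB with its element type.
template <typename T>
typename std::enable_if<std::is_same<T, float16>::value, void>::type
ReduceAll(const T* /*data*/) {
  std::cout << "fp16: report an error instead of calling CUB\n";
}

int main() {
  float f = 0.f;
  float16 h;
  ReduceAll(&f);  // prints the CUB-path message
  ReduceAll(&h);  // prints the fp16 error-path message
}
```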
8 changes: 5 additions & 3 deletions paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -703,7 +703,7 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
     std::vector<int> reduce_dims =
         GetReduceDim(dims, input->dims().size(), reduce_all);
     int reduce_num = 1;
-    for (int i = 0; i < input->dims().size(); i++) {
+    for (auto i : reduce_dims) {
       reduce_num *= (input->dims())[i];
     }
     gpuStream_t stream = context.cuda_device_context().stream();
@@ -713,8 +713,10 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
       TensorReduceFunc<T, ReduceOp, TransformOp>(
           *input, output, reduce_dims, reduce_num, stream));
     } else {
-      TensorReduceFunctorImpl<T, T, ReduceOp, TransformOp<T, T>>(
-          *input, output, TransformOp<T, T>(reduce_num), reduce_dims, stream);
+      using MPType = typename details::MPTypeTrait<T>::Type;
+      TensorReduceFunctorImpl<T, T, ReduceOp, TransformOp<T, MPType>>(
+          *input, output, TransformOp<T, MPType>(reduce_num), reduce_dims,
+          stream);
     }
   }
 };
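Besides switching the transform to `TransformOp<T, MPType>` (so fp16 reductions divide in fp32, matching the `DivideFunctor` change in the first file), the loop change fixes `reduce_num`: it must be the product of only the reduced extents, not of every dimension. A quick illustrative check with a hypothetical helper, not Paddle code:

```cpp
#include <cassert>
#include <vector>

// Product of the extents of the axes being reduced over.
int ReduceNum(const std::vector<int>& dims,
              const std::vector<int>& reduce_dims) {
  int n = 1;
  for (int i : reduce_dims) n *= dims[i];
  return n;
}

int main() {
  assert(ReduceNum({2, 3, 4}, {1}) == 3);         // partial reduce: axis 1
  assert(ReduceNum({2, 3, 4}, {0, 1, 2}) == 24);  // full reduce: all axes
}
```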
