[Pten] Refactor elementwise_add grad / double grad / triple grad Kernel and move them to pten (#39048)

* refactor elementwise add grad

* fix compile bugs

* fix unit test bugs

* fix file conflicts

* fix bugs when buildPtenContext
YuanRisheng committed Jan 24, 2022
1 parent 43919d0 commit 3bf3a6e
Showing 17 changed files with 743 additions and 396 deletions.
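For orientation: "moving a kernel to pten" here means replacing operator-bound functors that take a framework::ExecutionContext with free kernel functions templated on the data type and device context that operate on DenseTensor arguments. The sketch below shows that general shape only; the exact kernel name, parameter list, and namespace are illustrative assumptions, not copied from this diff.

// Illustrative sketch, not taken from this commit: a pten-style grad kernel
// is a free template function over (T, Context) that works on DenseTensor
// arguments instead of a framework::ExecutionContext.
namespace pten {
class DenseTensor;  // forward declaration, for the sketch only

template <typename T, typename Context>
void AddGradKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   const DenseTensor& y,
                   const DenseTensor& dout,
                   int axis,
                   DenseTensor* dx,
                   DenseTensor* dy);
}  // namespace pten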
4 changes: 4 additions & 0 deletions paddle/fluid/imperative/prepared_operator.cc
@@ -369,6 +369,10 @@ static void BuildDygraphPtenKernelContext(
    size_t end_idx = start_idx + outs_vector.size();

    for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
      if (outs_vector[offset] == nullptr) {
        kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
        continue;
      }
      auto* var = outs_vector[offset]->MutableVar();
      framework::Tensor* tensor_out = nullptr;
      if (var->template IsType<framework::LoDTensor>()) {
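The four added lines make BuildDygraphPtenKernelContext tolerate absent outputs: when a grad op does not request a particular output (dx or dy may be skipped by the caller), a nullptr placeholder is still pushed so the remaining outputs keep their positions in the kernel context. A small self-contained sketch of that positional-placeholder idea, with hypothetical names:

#include <cstdio>
#include <vector>

// Hypothetical, minimal stand-in for a kernel context that stores outputs by
// position; it only illustrates why a nullptr placeholder is pushed for an
// absent output instead of skipping the slot entirely.
struct KernelCtxSketch {
  std::vector<void*> outputs;
  void EmplaceBackOutput(void* out) { outputs.push_back(out); }
};

int main() {
  int dx = 0;
  std::vector<int*> outs = {&dx, nullptr};  // second output was not requested
  KernelCtxSketch ctx;
  for (int* out : outs) {
    if (out == nullptr) {
      ctx.EmplaceBackOutput(nullptr);  // keep the slot so later indices stay aligned
      continue;
    }
    ctx.EmplaceBackOutput(out);
  }
  std::printf("output slots: %zu\n", ctx.outputs.size());  // prints 2, not 1
  return 0;
}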
28 changes: 0 additions & 28 deletions paddle/fluid/operators/elementwise/elementwise_add_op.cc
@@ -33,34 +33,6 @@ class CPUDeviceContext;
namespace paddle {
namespace operators {

template <typename T>
struct SameDimsElemwiseAdd<
    platform::CPUDeviceContext, T,
    typename std::enable_if<std::is_floating_point<T>::value>::type> {
  void operator()(const framework::ExecutionContext &ctx,
                  const framework::Tensor *x, const framework::Tensor *y,
                  framework::Tensor *z) {
    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
    blas.VADD(x->numel(), x->data<T>(), y->data<T>(), z->data<T>());
  }
};

template <typename T>
struct SameDimsElemwiseAdd<
    platform::CPUDeviceContext, T,
    typename std::enable_if<!std::is_floating_point<T>::value>::type> {
  void operator()(const framework::ExecutionContext &ctx,
                  const framework::Tensor *x, const framework::Tensor *y,
                  framework::Tensor *z) {
    auto eigen_x = framework::EigenVector<T>::Flatten(*x);
    auto eigen_y = framework::EigenVector<T>::Flatten(*y);
    auto eigen_z = framework::EigenVector<T>::Flatten(*z);
    auto &place = *ctx.template device_context<platform::CPUDeviceContext>()
                       .eigen_device();
    eigen_z.device(place) = eigen_x + eigen_y;
  }
};

class ElementwiseAddOpMaker : public ElementwiseOpMaker {
 protected:
  std::string GetName() const override { return "Add"; }
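The two deleted SameDimsElemwiseAdd specializations picked an implementation at compile time with std::enable_if: floating-point element types went through the BLAS VADD routine, every other type through an Eigen expression. They are removed here because the same-dims add path now lives in pten. A self-contained sketch of that enable_if dispatch pattern, using generic names rather than Paddle's types:

#include <iostream>
#include <type_traits>
#include <vector>

// Sketch of compile-time dispatch on the element type via std::enable_if,
// mirroring the pattern of the removed SameDimsElemwiseAdd specializations.
template <typename T, typename Enable = void>
struct SameDimsAdd;

// Floating-point types: stand-in for the BLAS VADD fast path.
template <typename T>
struct SameDimsAdd<T,
    typename std::enable_if<std::is_floating_point<T>::value>::type> {
  void operator()(const std::vector<T>& x, const std::vector<T>& y,
                  std::vector<T>* z) {
    for (size_t i = 0; i < x.size(); ++i) (*z)[i] = x[i] + y[i];
  }
};

// All other types: stand-in for the Eigen expression path.
template <typename T>
struct SameDimsAdd<T,
    typename std::enable_if<!std::is_floating_point<T>::value>::type> {
  void operator()(const std::vector<T>& x, const std::vector<T>& y,
                  std::vector<T>* z) {
    for (size_t i = 0; i < x.size(); ++i) (*z)[i] = x[i] + y[i];
  }
};

int main() {
  std::vector<float> xf{1.f, 2.f}, yf{3.f, 4.f}, zf(2);
  SameDimsAdd<float>()(xf, yf, &zf);  // selects the floating-point specialization
  std::vector<int> xi{1, 2}, yi{3, 4}, zi(2);
  SameDimsAdd<int>()(xi, yi, &zi);    // selects the non-floating-point specialization
  std::cout << zf[0] << " " << zi[0] << "\n";  // prints: 4 4
  return 0;
}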
130 changes: 2 additions & 128 deletions paddle/fluid/operators/elementwise/elementwise_add_op.cu
@@ -13,139 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/pten/kernels/gpu/elementwise.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

namespace paddle {
namespace operators {

template <typename T>
static __global__ void SimpleElemwiseAddGradCUDAKernel(
    const T* __restrict__ dout, int size, int vec_size, T* dx, T* dy) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  int loop = size / vec_size;
  int remainder = size % vec_size;
  const float4* dout_vec = reinterpret_cast<const float4*>(dout);
  float4* dx_vec = reinterpret_cast<float4*>(dx);
  float4* dy_vec = reinterpret_cast<float4*>(dy);
  float4 tmp_loop;

  for (int i = tid; i < loop; i += stride) {
    tmp_loop = dout_vec[i];
    dx_vec[i] = tmp_loop;
    dy_vec[i] = tmp_loop;
  }

  if (tid == loop && remainder != 0) {
    T tmp_rem;
    while (remainder) {
      int idx = size - remainder;
      remainder--;
      tmp_rem = dout[idx];
      dx[idx] = tmp_rem;
      dy[idx] = tmp_rem;
    }
  }
}

template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_add_grad(const framework::ExecutionContext& ctx,
                             const framework::Tensor* x,
                             const framework::Tensor* y,
                             const framework::Tensor* out,
                             const framework::Tensor* dout,
                             framework::Tensor* dx, framework::Tensor* dy) {
  int axis = ctx.Attr<int>("axis");
  auto* dout_data = dout->data<T>();

  // dx
  if (dx != nullptr) {
    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
    if (dx->dims() == dout->dims()) {
      if (dx_data != dout_data) {
        framework::TensorCopy(
            *dout, ctx.GetPlace(),
            ctx.template device_context<platform::DeviceContext>(), dx);
      }
    } else {
      // For inplace strategy, dx will be stored in addr of dout, which makes
      // the result of dy wrong.
      if (dx->IsSharedBufferWith(*dout)) {
        dx->clear();
        dx->mutable_data<T>(x->dims(), ctx.GetPlace());
      }
      std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
      gpuStream_t stream = ctx.cuda_device_context().stream();
      TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
          *dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
    }
  }
  // dy
  if (dy != nullptr) {
    auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
    if (dy->dims() == dout->dims()) {
      if (dy_data != dout_data) {
        framework::TensorCopy(
            *dout, ctx.GetPlace(),
            ctx.template device_context<platform::DeviceContext>(), dy);
      }
    } else {
      std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
      gpuStream_t stream = ctx.cuda_device_context().stream();
      TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
          *dout, dy, kps::IdentityFunctor<T>(), reduce_dims, stream);
    }
  }
}

template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext& ctx,
                     const framework::Tensor* x, const framework::Tensor* y,
                     const framework::Tensor* out,
                     const framework::Tensor* dout, framework::Tensor* dx,
                     framework::Tensor* dy) {
  auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
  auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
  auto* dout_data = dout->data<T>();
  if (dx_data == dout_data && dy_data != dout_data) {
    VLOG(4) << "Special case when dx_data is the same as dout_data, "
               "only need copy dout to dy";
    framework::TensorCopy(
        *dout, ctx.GetPlace(),
        ctx.template device_context<platform::DeviceContext>(), dy);
  } else if (dx_data != dout_data && dy_data == dout_data) {
    VLOG(4) << "Special case when dy_data is the same as dout_data, "
               "only need copy dout to dx";
    framework::TensorCopy(
        *dout, ctx.GetPlace(),
        ctx.template device_context<platform::DeviceContext>(), dx);
  } else if (dx_data != dout_data && dy_data != dout_data) {
    auto size = x->numel();
    int vec_size = max(static_cast<int>(sizeof(float4) / sizeof(T)), 1);
    dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
    dim3 grid_size =
        dim3(((size + vec_size - 1) / vec_size + PREDEFINED_BLOCK_SIZE - 1) /
                 PREDEFINED_BLOCK_SIZE,
             1);
    SimpleElemwiseAddGradCUDAKernel<
        T><<<grid_size, block_size, 0,
             ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
        dout->data<T>(), size, vec_size, dx->mutable_data<T>(ctx.GetPlace()),
        dy->mutable_data<T>(ctx.GetPlace()));
  } else {
    VLOG(4) << "Special case when dy_data is the same as dout_data, "
               "and dx_data is the same as dout_data, do not need "
               "any operator";
  }
}

} // namespace operators
namespace operators {} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
    elementwise_add, ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
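Taken together, the deleted CUDA code implemented the add gradient in two regimes: when dx or dy has the same shape as dout, the gradient is a plain copy of dout (skipped when the buffers are already shared in the inplace case, and done with a float4-vectorized kernel when both copies are needed); when the forward pass broadcast one input, dout is instead summed over the broadcast dimensions (GetReduceDim plus TensorReduceFunctorImpl above). A minimal CPU sketch of that reduce-over-broadcast-axes rule for a 2-D example, purely illustrative:

#include <cstdio>
#include <vector>

// Sketch: gradient of z = x + y where y of shape [cols] was broadcast over
// x of shape [rows, cols]. dx is dout itself; dy sums dout over the
// broadcast (row) axis. Plain CPU loops stand in for the GPU reduction.
int main() {
  const int rows = 2, cols = 3;
  std::vector<float> dout = {1, 2, 3, 4, 5, 6};  // shape [rows, cols]
  std::vector<float> dx(dout);                   // same shape as dout: copy
  std::vector<float> dy(cols, 0.f);              // reduced over the row axis
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c) dy[c] += dout[r * cols + c];
  std::printf("dy = %.0f %.0f %.0f\n", dy[0], dy[1], dy[2]);  // prints: 5 7 9
  return 0;
}

The pten kernels this commit introduces should perform the same arithmetic, with the reduction executed on the device rather than in host loops.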