Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pten]Refactor the Elementwise_add Kernel #37043

Merged
merged 12 commits into from
Nov 12, 2021
22 changes: 1 addition & 21 deletions paddle/fluid/operators/elementwise/elementwise_add_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,34 +23,14 @@ limitations under the License. */
// only can include the headers in paddle/top/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/nn.h"
#include "paddle/pten/include/math.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

namespace paddle {
namespace operators {

// CUDA specialization of the elementwise-add kernel.  It no longer computes
// anything itself: the fluid LoDTensors are wrapped as pten::DenseTensor
// views and the work is forwarded to the refactored pten::ElementwiseAdd
// kernel.
template <typename T>
class ElementwiseAddKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const int axis = ctx.Attr<int>("axis");
    auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
    auto* in_x = ctx.Input<framework::LoDTensor>("X");
    auto* in_y = ctx.Input<framework::LoDTensor>("Y");
    auto* out = ctx.Output<framework::LoDTensor>("Out");
    // Allocate the output buffer first so the pten view below wraps valid
    // device memory.
    out->mutable_data<T>(ctx.GetPlace());
    auto pten_x = paddle::experimental::MakePtenDenseTensor(*in_x);
    auto pten_y = paddle::experimental::MakePtenDenseTensor(*in_y);
    auto pten_out = paddle::experimental::MakePtenDenseTensor(*out);
    pten::ElementwiseAdd<T>(dev_ctx, *pten_x, *pten_y, axis, pten_out.get());
  }
};

template <typename T>
static __global__ void SimpleElemwiseAddGradCUDAKernel(
const T* __restrict__ dout, int size, int vec_size, T* dx, T* dy) {
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/elementwise/elementwise_add_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ limitations under the License. */

#include "paddle/fluid/framework/pten_utils.h"

// only can include the headers in paddle/top/api dirs
// only can include the headers in paddle/pten/include dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/nn.h"
#include "paddle/pten/include/math.h"

namespace paddle {
namespace operators {
Expand Down
12 changes: 10 additions & 2 deletions paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,11 @@ void LaunchBroadcastElementwiseCudaKernel(
std::vector<framework::Tensor *> *outs, int axis, Functor func) {
std::vector<const pten::DenseTensor *> pt_inputs;
std::vector<pten::DenseTensor *> pt_outputs;
// *_tmp for cache DenseTensor
// TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary
// DenseTensor obj
// generated by MakePtenDenseTensor can be destroyed when exits loop. *_tmp
// can be deleted
// when DenseTensor support copy constructor.
std::vector<std::unique_ptr<pten::DenseTensor>> pt_inputs_tmp;
std::vector<std::unique_ptr<pten::DenseTensor>> pt_outputs_tmp;
for (auto in : ins) {
Expand Down Expand Up @@ -197,7 +201,11 @@ void LaunchElementwiseCudaKernel(
std::vector<framework::Tensor *> *outs, int axis, Functor func) {
std::vector<const pten::DenseTensor *> pt_inputs;
std::vector<pten::DenseTensor *> pt_outputs;
// *_tmp for cache DenseTensor
// TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary
// DenseTensor obj
// generated by MakePtenDenseTensor can be destroyed when exits loop. *_tmp
// can be deleted
// when DenseTensor support copy constructor.
std::vector<std::unique_ptr<pten::DenseTensor>> pt_inputs_tmp;
std::vector<std::unique_ptr<pten::DenseTensor>> pt_outputs_tmp;
for (auto in : ins) {
Expand Down
202 changes: 20 additions & 182 deletions paddle/fluid/operators/elementwise/elementwise_op_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ limitations under the License. */
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/transform.h"

// only can include the headers in paddle/top/api dirs
// only can include the headers in paddle/pten/include dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/kernels/cpu/funcs/elementwise.h"
#include "paddle/pten/kernels/functions/cpu/elementwise.h"
#include "paddle/pten/kernels/functions/general/elementwise_base.h"

#if defined(__NVCC__) || defined(__HIPCC__)
Expand Down Expand Up @@ -136,28 +136,6 @@ int PackTensorsIntoVector(const framework::ExecutionContext &ctx,
return axis;
}

/*
 * Out = X ⊙ Y
 * If Y's shape does not match X's shape, they will be reshaped.
 * For example:
 *  1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
 *     pre=2, n=3*4, post=5
 *     x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5)
 *  2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
 *     pre=2*3, n=4*5, post=1
 *     x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
 *
 * New parameter: *is_run_common_broadcast* is a flag to record whether to run
 * common broadcast code.
 *
 * NOTE(review): this is now a thin compatibility shim that forwards to
 * pten::general::get_mid_dims so existing fluid callers keep working; the
 * real implementation lives in the pten library.
 */
inline void get_mid_dims(const framework::DDim &x_dims,
                         const framework::DDim &y_dims, const int axis,
                         int *pre, int *n, int *post,
                         int *is_run_common_broadcast) {
  pten::general::get_mid_dims(x_dims, y_dims, axis, pre, n, post,
                              is_run_common_broadcast);
}

inline int GetElementwiseIndex(const int *x_dims_array, const int max_dim,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这层壳的保留是必要的吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

const int *index_array) {
return pten::GetElementwiseIndex(x_dims_array, max_dim, index_array);
Expand Down Expand Up @@ -1098,154 +1076,6 @@ inline framework::DDim trim_trailing_singular_dims(
return pten::general::trim_trailing_singular_dims(dims);
}

template <typename T, typename DeviceContext>
class RowwiseTransformIterator;

template <typename T, typename DeviceContext>
class MidWiseTransformIterator;

// NOTE(dzhwinter): ptrdiff_t in iterator is deperecated in c++17
// CPU row-wise broadcast iterator: cycles over a buffer of `n` elements,
// wrapping back to the start after every n increments (it yields ptr_[i % n]),
// so the smaller operand's row is repeated across the larger tensor.
//
// NOTE(dzhwinter): inheriting from std::iterator (with its ptrdiff_t typedef)
// is deprecated in C++17, so the iterator traits are spelled out as member
// aliases instead of deriving from std::iterator.
template <typename T>
class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
 public:
  using iterator_category = std::random_access_iterator_tag;
  using value_type = T;
  using difference_type = std::ptrdiff_t;
  using pointer = T *;
  using reference = T &;

  RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {}

  RowwiseTransformIterator<T, platform::CPUDeviceContext> &operator++() {
    ++i_;
    if (UNLIKELY(i_ == n_)) {
      i_ = 0;
    }
    return *this;
  }

  // NOTE(review): unlike a standard random-access iterator, operator+
  // advances *this in place and returns a reference to it rather than a new
  // iterator; existing callers appear to rely on this, so the original
  // behavior is preserved.
  RowwiseTransformIterator<T, platform::CPUDeviceContext> &operator+(int n) {
    while (n-- > 0) {
      ++i_;
      if (UNLIKELY(i_ == n_)) {
        i_ = 0;
      }
    }
    return *this;
  }

  bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext>
                      &rhs) const {
    return (ptr_ + i_) == &(*rhs);
  }

  bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext>
                      &rhs) const {
    return (ptr_ + i_) != &(*rhs);
  }

  // Made const so the const operator==/operator!= above can dereference a
  // const rhs; the previous non-const overload was ill-formed there.
  const T &operator*() const { return ptr_[i_]; }

 private:
  const T *ptr_;
  int i_;
  int64_t n_;
};

// CPU mid-wise broadcast iterator: walks a buffer of `n` elements where each
// element is yielded `post` consecutive times, wrapping after n * post steps
// (it yields ptr_[(step / post) % n]), matching the (pre, n, post) broadcast
// decomposition produced by get_mid_dims.
//
// NOTE(dzhwinter): inheriting from std::iterator (with its ptrdiff_t typedef)
// is deprecated in C++17, so the iterator traits are spelled out as member
// aliases instead of deriving from std::iterator.
template <typename T>
class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
 public:
  using iterator_category = std::random_access_iterator_tag;
  using value_type = T;
  using difference_type = std::ptrdiff_t;
  using pointer = T *;
  using reference = T &;

  MidWiseTransformIterator(const T *ptr, int n, int post)
      : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}

  MidWiseTransformIterator<T, platform::CPUDeviceContext> &operator++() {
    ++j_;
    if (UNLIKELY(j_ == post_)) {
      ++i_;
      j_ = 0;
      if (UNLIKELY(i_ == n_)) {
        i_ = 0;
      }
    }
    return *this;
  }

  // NOTE(review): advances *this in place and returns a reference instead of
  // producing a new iterator (non-standard for random-access iterators);
  // the original behavior is preserved for existing callers.
  MidWiseTransformIterator<T, platform::CPUDeviceContext> &operator+(int n) {
    while (n-- > 0) {
      ++j_;
      if (UNLIKELY(j_ == post_)) {
        ++i_;
        j_ = 0;
        if (UNLIKELY(i_ == n_)) {
          i_ = 0;
        }
      }
    }
    return *this;
  }

  // NOTE(review): equality deliberately ignores j_ and compares only the
  // current element address, matching the original implementation.
  bool operator==(const MidWiseTransformIterator<T, platform::CPUDeviceContext>
                      &rhs) const {
    return (ptr_ + i_) == &(*rhs);
  }

  bool operator!=(const MidWiseTransformIterator<T, platform::CPUDeviceContext>
                      &rhs) const {
    return (ptr_ + i_) != &(*rhs);
  }

  // Made const so the const comparison operators above can dereference a
  // const rhs; the previous non-const overload was ill-formed there.
  const T &operator*() const { return ptr_[i_]; }

 private:
  const T *ptr_;
  int64_t i_;
  int64_t j_;
  int64_t n_;
  int64_t post_;
};

#if defined(__NVCC__) || defined(__HIPCC__)
// GPU counterpart of the row-wise broadcast iterator, implemented on top of
// thrust::iterator_adaptor over a raw const T*.  Dereferencing maps the
// current offset from the start of the buffer to begin_[offset % n_], so the
// n_-element row repeats indefinitely.
template <typename T>
class RowwiseTransformIterator<T, platform::CUDADeviceContext>
    : public thrust::iterator_adaptor<
          RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T *> {
 public:
  typedef thrust::iterator_adaptor<
      RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T *>
      super_t;
  HOSTDEVICE RowwiseTransformIterator(const T *x, int n)
      : super_t(x), begin_(x), n_(n) {}
  // Grant thrust access to the private dereference() hook below.
  friend class thrust::iterator_core_access;

 private:
  unsigned int n_;
  const T *begin_;
  // this->base() is the adaptor's current raw pointer; wrap its distance
  // from begin_ modulo n_ to repeat the row.
  HOSTDEVICE typename super_t::reference dereference() const {
    return *(begin_ + (this->base() - begin_) % n_);
  }
};

// GPU counterpart of the mid-wise broadcast iterator: each of the n_ source
// elements is yielded post_ consecutive times, wrapping after n_ * post_
// steps (begin_[(offset / post_) % n_]).
template <typename T>
class MidWiseTransformIterator<T, platform::CUDADeviceContext>
    : public thrust::iterator_adaptor<
          MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T *> {
 public:
  typedef thrust::iterator_adaptor<
      MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T *>
      super_t;
  HOSTDEVICE MidWiseTransformIterator(const T *x, int n, int post)
      : super_t(x), begin_(x), n_(n), post_(post) {}
  // Grant thrust access to the private dereference() hook below.
  friend class thrust::iterator_core_access;

 private:
  unsigned int post_;
  unsigned int n_;
  const T *begin_;
  HOSTDEVICE typename super_t::reference dereference() const {
    return *(begin_ + (((this->base() - begin_) / post_) % n_));
  }
};
#endif

template <typename Functor, typename T, typename DeviceContext,
typename OutType = T>
class TransformFunctor {
Expand Down Expand Up @@ -1274,21 +1104,27 @@ class TransformFunctor {
platform::Transform<DeviceContext> trans;
if (is_xsize_larger_) {
trans(ctx_, x_, x_ + nx_,
RowwiseTransformIterator<T, DeviceContext>(y_, n), z_, func_);
pten::general::RowwiseTransformIterator<T, DeviceContext>(y_, n),
z_, func_);
} else {
trans(ctx_, y_, y_ + nx_,
RowwiseTransformIterator<T, DeviceContext>(x_, n), z_, func_);
pten::general::RowwiseTransformIterator<T, DeviceContext>(x_, n),
z_, func_);
}
}

// Applies func_ element-wise with mid-wise broadcasting of the smaller
// operand, delegating iteration to the pten::general::MidWiseTransformIterator.
// `pre` is part of the (pre, n, post) decomposition but is not needed here:
// the iterator's wrap-around covers the leading dimension implicitly.
//
// NOTE(review): the original text contained both the pre- and post-refactor
// argument lines fused together (diff residue), which is not valid C++; this
// is the reconstructed post-refactor version using the pten iterators.
inline void RunMidWise(int n, int pre, int post) const {
  platform::Transform<DeviceContext> trans;
  if (is_xsize_larger_) {
    trans(ctx_, x_, x_ + nx_,
          pten::general::MidWiseTransformIterator<T, DeviceContext>(y_, n,
                                                                    post),
          z_, func_);
  } else {
    trans(ctx_, y_, y_ + nx_,
          pten::general::MidWiseTransformIterator<T, DeviceContext>(x_, n,
                                                                    post),
          z_, func_);
  }
}

Expand Down Expand Up @@ -1617,13 +1453,13 @@ void ElemwiseGradComputeWithBroadcast(
if (is_xsize_larger) {
auto y_dims_trimed = trim_trailing_singular_dims(y_dims);
axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis;
get_mid_dims(x_dims, y_dims_trimed, axis_trim, &pre, &n, &post,
&is_run_common_broadcast);
pten::general::get_mid_dims(x_dims, y_dims_trimed, axis_trim, &pre, &n,
&post, &is_run_common_broadcast);
} else {
auto x_dims_trimed = trim_trailing_singular_dims(x_dims);
axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis;
get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post,
&is_run_common_broadcast);
pten::general::get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n,
&post, &is_run_common_broadcast);
}
// special case for common backward implementation.
if (is_run_common_broadcast) {
Expand Down Expand Up @@ -2020,7 +1856,8 @@ void FusedElemwiseAndActComputeWithBroadcast(
axis = (y_dim.size() == 0) ? x_dim.size() : axis;

int pre, n, post, is_run_common_broadcast;
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post, &is_run_common_broadcast);
pten::general::get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post,
Copy link
Contributor

@chenwhql chenwhql Nov 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

函数命名建议按照code style统一,采用驼峰式命名,这里原先就不太规范,可在后续PR更改

&is_run_common_broadcast);
if (post == 1) {
int h = pre;
int w = n;
Expand Down Expand Up @@ -2567,7 +2404,8 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
axis = (y_dim.size() == 0) ? x_dim.size() : axis;

int pre, n, post, is_run_common_broadcast;
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post, &is_run_common_broadcast);
pten::general::get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post,
&is_run_common_broadcast);
const T *x_data = nullptr;
const T *y_data = nullptr;
if (x->IsInitialized()) x_data = x->data<T>();
Expand Down
8 changes: 6 additions & 2 deletions paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ limitations under the License. */
// only can include the headers in paddle/top/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/kernels/cuda/funcs/elementwise/elementwise.h"
#include "paddle/pten/kernels/functions/cuda/elementwise/elementwise.h"

#ifdef __HIPCC__
#define ELEMENTWISE_BLOCK_SIZE 256
Expand All @@ -45,7 +45,11 @@ void LaunchSameDimsElementwiseCudaKernel(
std::vector<framework::Tensor *> *outs, Functor func) {
std::vector<const pten::DenseTensor *> pt_inputs;
std::vector<pten::DenseTensor *> pt_outputs;
// *_tmp for cache DenseTensor
// TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary
// DenseTensor obj
// generated by MakePtenDenseTensor can be destroyed when exits loop. *_tmp
// can be deleted
// when DenseTensor support copy constructor.
std::vector<std::unique_ptr<pten::DenseTensor>> pt_inputs_tmp;
std::vector<std::unique_ptr<pten::DenseTensor>> pt_outputs_tmp;
for (auto in : ins) {
Expand Down
4 changes: 2 additions & 2 deletions paddle/pten/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ add_subdirectory(tests)

# make an unity target for compile deps
set(PTEN_DEPS convert_utils dense_tensor kernel_factory kernel_context)
set(PTEN_DEPS ${PTEN_DEPS} math_cpu linalg_cpu creation_cpu manipulation_cpu nn_cpu)
set(PTEN_DEPS ${PTEN_DEPS} math_cpu linalg_cpu creation_cpu manipulation_cpu)
set(PTEN_DEPS ${PTEN_DEPS} nary unary binary)
if(WITH_GPU OR WITH_ROCM)
set(PTEN_DEPS ${PTEN_DEPS} math_cuda linalg_cuda creation_cuda manipulation_cuda nn_cuda)
set(PTEN_DEPS ${PTEN_DEPS} math_cuda linalg_cuda creation_cuda manipulation_cuda)
endif()
if(WITH_XPU)
set(PTEN_DEPS ${PTEN_DEPS} manipulation_xpu)
Expand Down
2 changes: 1 addition & 1 deletion paddle/pten/api/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
add_subdirectory(lib)

cc_library(pten_api SRCS all.cc DEPS linalg_api math_api creation_api manipulation_api nn_api)
cc_library(pten_api SRCS all.cc DEPS linalg_api math_api creation_api manipulation_api)
2 changes: 2 additions & 0 deletions paddle/pten/api/include/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,7 @@ namespace experimental {
// TODO(chenweihang): move mean API into stat.h/cc
Tensor mean(const Tensor& x);

Tensor add(const Tensor& x, const Tensor& y);

} // namespace experimental
} // namespace paddle
Loading