Commit
Merge pull request #1 from PaddlePaddle/develop
update
AnnaTrainingG committed Mar 25, 2021
2 parents 1eb927f + bf09dcb commit 7d58b91
Showing 26 changed files with 322 additions and 474 deletions.
@@ -103,6 +103,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
"merge_lod_tensor",
"equal",
"sequence_pool",
"recurrent",
"lod_reset"};
for (auto* tmp : node->inputs) {
CHECK(tmp->IsOp());
284 changes: 1 addition & 283 deletions paddle/fluid/operators/activation_op.cu
@@ -10,276 +10,8 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using float16 = paddle::platform::float16;

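// CudaVecType maps a scalar type T to the widest vector type the kernels
// below load and store it as, plus how many scalars one vector packs:
// platform::float16 pairs into a __half2 and float quadruples into a float4,
// while the generic case (e.g. double) stays scalar. Wide accesses let each
// thread issue one memory transaction instead of several.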
template <typename T>
struct CudaVecType {
using type = T;
static constexpr int vecsize = 1;
};

template <>
struct CudaVecType<platform::float16> {
using type = __half2;
static constexpr int vecsize = 2;
};

template <>
struct CudaVecType<float> {
using type = float4;
static constexpr int vecsize = 4;
};

template <typename T>
class BaseGPUFunctor {
public:
using ELEMENT_TYPE = T;
};

/* ========================================================================== */

/* =========================== relu forward ============================ */
template <typename T>
class ReluGPUFuctor : public BaseGPUFunctor<T> {
private:
T zero_;

public:
ReluGPUFuctor() { zero_ = static_cast<T>(0.0f); }

// for relu forward when T is double
__device__ __forceinline__ typename CudaVecType<T>::type Compute(
const typename CudaVecType<T>::type* x);

// Used for the tail elements when num % vecsize != 0
__device__ __forceinline__ T ComputeRemainder(const T x) {
return x > zero_ ? x : zero_;
}
};

template <>
__device__ __forceinline__ CudaVecType<double>::type
ReluGPUFuctor<double>::Compute(const CudaVecType<double>::type* x) {
// relu forward : out = max(x, 0)
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350)
return __ldg(x) > zero_ ? __ldg(x) : zero_;
#else
return (*x) > zero_ ? (*x) : zero_;
#endif
}

template <>
__device__ __forceinline__ CudaVecType<float>::type
ReluGPUFuctor<float>::Compute(const CudaVecType<float>::type* xx) {
// relu forward : out = max(xx, 0)
return make_float4((xx->x > zero_) * (xx->x), (xx->y > zero_) * (xx->y),
(xx->z > zero_) * (xx->z), (xx->w > zero_) * (xx->w));
}

template <>
__device__ __forceinline__ CudaVecType<float16>::type
ReluGPUFuctor<float16>::Compute(const CudaVecType<float16>::type* in) {
// relu forward : out = max(in, 0)
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350)
const half2 kzero = __float2half2_rn(0.0f);
return __hmul2(__hgt2(__ldg(in), kzero), __ldg(in));
#else
const float2 xx = __half22float2(*in);
return __floats2half2_rn((xx.x > 0.0f) * static_cast<float>(xx.x),
(xx.y > 0.0f) * static_cast<float>(xx.y));
#endif
}
/* ========================================================================== */

/* =========================== relu backward ============================ */

template <typename T>
class ReluGradGPUFunctor : public BaseGPUFunctor<T> {
private:
T zero_;

public:
ReluGradGPUFunctor() { zero_ = static_cast<T>(0.0f); }

// for relu backward when T is double
__device__ __forceinline__ typename CudaVecType<T>::type Compute(
const typename CudaVecType<T>::type* out,
const typename CudaVecType<T>::type* dout);

// Used for the tail elements when num % vecsize != 0
__device__ __forceinline__ T ComputeRemainder(const T out, const T dout) {
// relu backward : dx = out > 0 ? dout : 0;
return out > zero_ ? dout : zero_;
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <>
__device__ __forceinline__ CudaVecType<double>::type
ReluGradGPUFunctor<double>::Compute(const CudaVecType<double>::type* out,
const CudaVecType<double>::type* dout) {
// relu backward : dx = out > 0 ? dout : 0;
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350)
return __ldg(out) > zero_ ? __ldg(dout) : zero_;
#else
return (*out) > zero_ ? (*dout) : zero_;
#endif
}

template <>
__device__ __forceinline__ CudaVecType<float>::type
ReluGradGPUFunctor<float>::Compute(const CudaVecType<float>::type* out,
const CudaVecType<float>::type* dout) {
// relu backward : dx = out > 0 ? dout : 0;
return make_float4((out->x > zero_) * (dout->x), (out->y > zero_) * (dout->y),
(out->z > zero_) * (dout->z),
(out->w > zero_) * (dout->w));
}

template <>
__device__ __forceinline__ CudaVecType<float16>::type
ReluGradGPUFunctor<float16>::Compute(const CudaVecType<float16>::type* out,
const CudaVecType<float16>::type* dout) {
// relu backward : dx = out > 0 ? dout : 0;
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350)
const half2 kzero = __float2half2_rn(0.0f);
return __hmul2(__hgt2(__ldg(out), kzero), __ldg(dout));
#else
const float2 xx = __half22float2(*out);
const float2 yy = __half22float2(*dout);
return __floats2half2_rn((xx.x > 0.0f) * static_cast<float>(yy.x),
(xx.y > 0.0f) * static_cast<float>(yy.y));
#endif
}

/* ========================================================================== */

template <typename T, typename Functor>
__global__ void ActivationGradKernelVec(const T* forward_data, const T* dout,
T* dx, int num, Functor functor) {
using VecType = typename CudaVecType<T>::type;
constexpr int vecsize = CudaVecType<T>::vecsize;
int idx = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
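// loop counts the whole vectors in the array; tail counts the leftover
// scalars that do not fill a complete vector.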
int loop = num / vecsize;
int tail = num % vecsize;
const VecType* in_forward = reinterpret_cast<const VecType*>(forward_data);
const VecType* in_dout = reinterpret_cast<const VecType*>(dout);
VecType* out = reinterpret_cast<VecType*>(dx);

for (int i = idx; i < loop; i += stride) {
out[i] = functor.Compute((in_forward + i), (in_dout + i));
}

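// Only the thread with idx == loop (if the grid reaches that index)
// serializes the leftover tail elements.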
while (idx == loop && tail) {
dx[num - tail] =
functor.ComputeRemainder(forward_data[num - tail], dout[num - tail]);
--tail;
}
}

template <typename T, typename Functor>
__global__ void ActivationkernelVec(const T* src, T* dst, int num,
Functor functor) {
constexpr int vecsize = CudaVecType<T>::vecsize;
using VecType = typename CudaVecType<T>::type;
int idx = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
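// Same decomposition as the gradient kernel: whole vectors first, the
// scalar tail afterwards.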
int loop = num / vecsize;
int tail = num % vecsize;
const VecType* in = reinterpret_cast<const VecType*>(src);
VecType* out = reinterpret_cast<VecType*>(dst);

for (int i = idx; i < loop; i += stride) {
out[i] = functor.Compute((in + i));
}

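// A lone thread (idx == loop) drains the scalar tail one element at a time.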
while (idx == loop && tail) {
dst[num - tail] = functor.ComputeRemainder(src[num - tail]);
--tail;
}
}

template <typename DeviceContext, typename Functor>
class ActivationGPUKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = nullptr;
framework::Tensor* out = nullptr;
ExtractActivationTensor(context, &in_x, &out);
auto& dev_ctx = context.template device_context<DeviceContext>();

int num = in_x->numel();
const T* input_data = in_x->data<T>();
T* output_data = out->mutable_data<T>(dev_ctx.GetPlace(),
static_cast<size_t>(num * sizeof(T)));

int block = 512;
#ifdef __HIPCC__
block = 256;
#endif
Functor functor;
constexpr int vecsize = CudaVecType<T>::vecsize;
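// One thread per vector element, rounded up to whole blocks, with at
// least one block for tiny inputs.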
int grid = max((num / vecsize + block - 1) / block, 1);
ActivationkernelVec<T, Functor><<<grid, block>>>(input_data, output_data,
num, functor);
}
};

template <typename DeviceContext, typename Functor>
class ActivationGradGPUKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor *x, *out, *d_out;
framework::Tensor* d_x = nullptr;
x = out = d_out = nullptr;
ExtractActivationGradTensor<Functor::FwdDeps()>(context, &x, &out, &d_out,
&d_x);
int numel = d_out->numel();
auto& dev_ctx = context.template device_context<DeviceContext>();
auto* dx_data = d_x->mutable_data<T>(
dev_ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
auto* dout_data = d_out->data<T>();

auto* forward_data = dout_data;
if (static_cast<int>(Functor::FwdDeps()) == static_cast<int>(kDepOut)) {
// Only need forward output Out
forward_data = out->data<T>();
} else if (static_cast<int>(Functor::FwdDeps()) ==
static_cast<int>(kDepX)) {
// Only need forward input X
forward_data = x->data<T>();
}

int block = 512;
#ifdef __HIPCC__
block = 256;
#endif
Functor functor;
constexpr int vecsize = CudaVecType<T>::vecsize;
int grid = max((numel / vecsize + block - 1) / block, 1);
ActivationGradKernelVec<T, Functor><<<grid, block>>>(
forward_data, dout_data, dx_data, numel, functor);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

@@ -328,21 +60,7 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */

/* =========================== relu register ============================ */
REGISTER_OP_CUDA_KERNEL(
relu, ops::ActivationGPUKernel<paddle::platform::CUDADeviceContext,
ops::ReluGPUFuctor<float>>,
ops::ActivationGPUKernel<paddle::platform::CUDADeviceContext,
ops::ReluGPUFuctor<double>>,
ops::ActivationGPUKernel<plat::CUDADeviceContext,
ops::ReluGPUFuctor<plat::float16>>);

REGISTER_OP_CUDA_KERNEL(
relu_grad, ops::ActivationGradGPUKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradGPUFunctor<float>>,
ops::ActivationGradGPUKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradGPUFunctor<double>>,
ops::ActivationGradGPUKernel<plat::CUDADeviceContext,
ops::ReluGradGPUFunctor<plat::float16>>);
REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, ReluCUDAFunctor, ReluGradFunctor);
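// REGISTER_ACTIVATION_CUDA_KERNEL replaces the three per-dtype registrations
// deleted above with a single line. A sketch of the pairing it presumably
// expands to; the ActivationCudaKernel/ActivationGradCudaKernel class names
// are an assumption inferred from the functor arguments, not shown in this
// diff:
//
//   REGISTER_OP_CUDA_KERNEL(
//       relu, ops::ActivationCudaKernel<plat::CUDADeviceContext,
//                                       ops::ReluCUDAFunctor<float>>,
//       /* ... double and plat::float16 variants ... */);
//   REGISTER_OP_CUDA_KERNEL(
//       relu_grad,
//       ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
//                                     ops::ReluGradFunctor<float>>,
//       /* ... double and plat::float16 variants ... */);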

REGISTER_OP_CUDA_KERNEL(
relu_grad_grad,
3 changes: 2 additions & 1 deletion paddle/fluid/operators/detection/polygon_box_transform_op.cu
@@ -45,7 +45,8 @@ class PolygonBoxTransformOpCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::InvalidArgument("It must use CUDAPlace."));
platform::errors::InvalidArgument(
"The polygon_box_transform operator needs to be executed on GPU."));
auto* in = ctx.Input<Tensor>("Input");
auto in_dims = in->dims();
const T* in_data = in->data<T>();
2 changes: 1 addition & 1 deletion paddle/fluid/operators/dot_op.h
@@ -160,7 +160,7 @@ struct DotGradFunction<DeviceContext, T, math::DisableComplex<T>> {
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) {
#ifdef __NVCC__
#if defined(__NVCC__) || defined(__HIPCC__)
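// defined(__HIPCC__) lets the GPU-only Eigen path below compile under ROCm,
// whose compiler defines __HIPCC__ instead of __NVCC__.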
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);

2 changes: 1 addition & 1 deletion paddle/fluid/operators/matmul_op.cc
@@ -587,7 +587,7 @@ class MatMulOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_,
platform::errors::InvalidArgument(
"Input X's width should be equal to the Y's height, "
"but received X's shape: [%s],"
"but received X's shape: [%s], "
"Y's shape: [%s].",
dim_x, dim_y));
#endif
20 changes: 14 additions & 6 deletions paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
@@ -71,6 +71,15 @@ static const std::vector<const Tensor*> ReduceMultiInput(
return reduced;
}

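// Builds the primitive-cache key from the first input's full shape plus the
// leading dimension of every remaining input, so concat calls whose inputs
// differ in those dims do not collide in the cache.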
static const std::vector<int> GetDimsForKey(
const std::vector<const Tensor*>& inputs) {
auto dims_key = paddle::framework::vectorize<int>(inputs[0]->dims());
for (auto it = std::next(inputs.begin()); it != inputs.end(); ++it) {
dims_key.push_back((*it)->dims()[0]);
}
return dims_key;
}

template <typename T>
class ConcatPrimitiveFactory {
public:
@@ -134,6 +143,8 @@ template <typename T>
class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
// If any of the multiple inputs of concat has an input size of 0, the
// actual size of the multi_input will change
auto multi_input = ReduceMultiInput(ctx.MultiInput<Tensor>("X"));
EnforceLayouts(multi_input);
Tensor* output = ctx.Output<Tensor>("Out");
@@ -156,12 +167,9 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
paddle::framework::ToMKLDNNDataType(multi_input[0]->type());

ConcatPrimitiveFactory<T> prim_creator;
// If one of the multiple inputs of concat has an input size of 0, the
// actual size of the multi_input will change
std::string key = platform::CreateKey(
dev_ctx, paddle::framework::vectorize<int>(multi_input[0]->dims()),
multi_input.size(), ctx.OutputName("Out"), dt,
platform::ThreadIDasStr());
std::string key =
platform::CreateKey(dev_ctx, GetDimsForKey(multi_input),
multi_input.size(), ctx.OutputName("Out"), dt);
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);

const std::string key_prim = key + "@concat_p";
5 changes: 4 additions & 1 deletion paddle/fluid/operators/nll_loss_op.h
@@ -36,7 +36,10 @@ static void nll_loss_1D(T* out_data, T* total_weight_data, const T* x_data,
}
PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes, true,
platform::errors::InvalidArgument(
"label should not be out of bounds."));
"Label value is out of range. "
"Expected label value in range of [0, %d), but "
"received value is %d.",
n_classes, cur_label));

const auto cur_weight =
weight_data ? weight_data[cur_label] : static_cast<T>(1);