
Commit

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into dev/bf16_op_2
zhangbo9674 committed Feb 9, 2022
2 parents 9e5a70b + 772be4f commit a33373e
Showing 18 changed files with 391 additions and 57 deletions.
7 changes: 7 additions & 0 deletions paddle/fluid/framework/var_type_traits.h
@@ -47,6 +47,10 @@
#include "xpu/bkcl.h"
#endif

#if defined(PADDLE_WITH_CNCL)
#include <cncl.h>
#endif

namespace pten {
class DenseTensor;
class SelectedRows;
@@ -181,6 +185,9 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
 #endif
 #if defined(PADDLE_WITH_XPU_BKCL)
     BKCLUniqueId, platform::BKCLCommunicator,
 #endif
+#if defined(PADDLE_WITH_CNCL)
+    cnclCliqueId,
+#endif
     int, float, Vocab>;
 template <typename T>
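The hunk above registers cnclCliqueId with VarTypeRegistry only when CNCL is compiled in, so the registry's compile-time type list always matches the available backends. A minimal standalone sketch of the idea behind such a variadic type registry (an illustration under assumptions, not Paddle's actual VarTypeRegistryImpl):

#include <type_traits>

// Compile-time position of T in a type list; -1 if T is not registered.
template <typename T, typename... Ts>
struct IndexOf;

template <typename T>
struct IndexOf<T> {
  static constexpr int value = -1;
};

template <typename T, typename Head, typename... Tail>
struct IndexOf<T, Head, Tail...> {
  static constexpr int tail = IndexOf<T, Tail...>::value;
  static constexpr int value =
      std::is_same<T, Head>::value ? 0 : (tail == -1 ? -1 : 1 + tail);
};

struct FakeCliqueId {};  // stands in for cnclCliqueId when CNCL is enabled

static_assert(IndexOf<float, int, float, FakeCliqueId>::value == 1,
              "float is the second registered type");
static_assert(IndexOf<double, int, float>::value == -1, "not registered");

int main() { return 0; }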
20 changes: 14 additions & 6 deletions paddle/fluid/operators/collective/c_comm_init_op.cc
@@ -20,12 +20,15 @@ limitations under the License. */
 #if defined(PADDLE_WITH_XPU_BKCL)
 #include "xpu/bkcl.h"
 #endif
+#if defined(PADDLE_WITH_CNCL)
+#include <cncl.h>
+#endif
 #include <string>

 #include "paddle/fluid/framework/op_registry.h"

 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
-    defined(PADDLE_WITH_XPU_BKCL)
+    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CNCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #endif

@@ -56,18 +59,23 @@ class CCommInitOp : public framework::OperatorBase {
     using UniqueId = BKCLUniqueId;
     using Place = platform::XPUPlace;
     using CommContext = platform::BKCLCommContext;
+#elif defined(PADDLE_WITH_CNCL)
+    using UniqueId = cnclCliqueId;
+    using Place = platform::MLUPlace;
+    using CommContext = platform::CNCLCommContext;
 #else
     PADDLE_THROW(platform::errors::PreconditionNotMet(
-        "PaddlePaddle should be compiled with GPU or XPU."));
+        "PaddlePaddle should be compiled with GPU or XPU or MLU."));
 #endif

     PADDLE_ENFORCE_EQ(
-        platform::is_gpu_place(place) || platform::is_xpu_place(place), true,
-        platform::errors::PreconditionNotMet(
-            "CCommInitOp can run on gpu or xpu place only."));
+        platform::is_gpu_place(place) || platform::is_xpu_place(place) ||
+            platform::is_mlu_place(place),
+        true, platform::errors::PreconditionNotMet(
+            "CCommInitOp can run on gpu or xpu or mlu place only."));

 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
-    defined(PADDLE_WITH_XPU_BKCL)
+    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CNCL)
     auto var = scope.FindVar(Input("X"));
     PADDLE_ENFORCE_NOT_NULL(
         var, platform::errors::InvalidArgument("Input con not be empty."));
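The operator body stays backend-agnostic because every backend-specific type is funneled through the UniqueId/Place/CommContext aliases chosen by the preprocessor. A stripped-down sketch of that dispatch pattern (macro and type names here are placeholders, not the real Paddle ones):

#include <cstdio>

#define DEMO_WITH_CNCL  // pretend this build targets the MLU backend

#if defined(DEMO_WITH_NCCL)
struct NcclUniqueId {};
using UniqueId = NcclUniqueId;
#elif defined(DEMO_WITH_CNCL)
struct CnclCliqueId {};
using UniqueId = CnclCliqueId;
#else
#error "compile with one collective backend enabled"
#endif

int main() {
  UniqueId id;  // the rest of the code never names the backend directly
  (void)id;
  std::puts("backend selected at compile time");
  return 0;
}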
55 changes: 55 additions & 0 deletions paddle/fluid/operators/gaussian_random_op_mlu.cc
@@ -0,0 +1,55 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <random>

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class MLUGaussianRandomKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    float mean = context.Attr<float>("mean");
    float std = context.Attr<float>("std");
    auto* tensor = context.Output<framework::Tensor>("Out");
    tensor->mutable_data<T>(context.GetPlace());

    Tensor cpu_tensor(tensor->type());
    cpu_tensor.Resize(tensor->dims());
    T* cpu_data = cpu_tensor.mutable_data<T>(platform::CPUPlace());
    std::normal_distribution<T> dist(mean, std);

    int64_t size = tensor->numel();

    unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
    auto engine = framework::GetCPURandomEngine(seed);
    for (int64_t i = 0; i < size; ++i) {
      cpu_data[i] = dist(*engine);
    }
    auto& dev_ctx =
        context.template device_context<paddle::platform::MLUDeviceContext>();
    framework::TensorCopy(cpu_tensor, context.GetPlace(), dev_ctx, tensor);
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_MLU_KERNEL(gaussian_random, ops::MLUGaussianRandomKernel<float>);
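The kernel above sidesteps device-side RNG entirely: it samples on the CPU with the framework's seeded engine and then copies the buffer to the MLU via framework::TensorCopy. A standalone sketch of just the host-side sampling step (plain std::mt19937_64 here; Paddle's GetCPURandomEngine wraps a similar engine and, as far as I can tell, treats seed == 0 as a request for a non-deterministic global engine):

#include <iostream>
#include <random>
#include <vector>

int main() {
  const float mean = 0.0f;
  const float stddev = 1.0f;
  const unsigned int seed = 42;

  std::mt19937_64 engine(seed);
  std::normal_distribution<float> dist(mean, stddev);

  std::vector<float> cpu_data(8);  // plays the role of cpu_tensor's buffer
  for (float& v : cpu_data) v = dist(engine);

  // In the kernel, this buffer would now be copied to the device tensor.
  for (float v : cpu_data) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}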
10 changes: 2 additions & 8 deletions paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu
@@ -17,15 +17,9 @@

 template <typename T>
 using CUDAReduceMeanGradKernel =
-    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
-                          ops::MeanGradFunctor, true>;
-
-using FP16CUDAReduceMeanGradKernel =
-    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
-                          paddle::platform::float16, ops::FP16MeanGradFunctor,
-                          true>;
+    ops::ReduceCudaGradKernel<T, kps::DivideFunctor>;

 REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel<bool>,
-                        FP16CUDAReduceMeanGradKernel,
+                        CUDAReduceMeanGradKernel<paddle::platform::float16>,
                         CUDAReduceMeanGradKernel<float>,
                         CUDAReduceMeanGradKernel<double>);
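The swap to ReduceCudaGradKernel<T, kps::DivideFunctor> also retires the dedicated FP16 kernel: the template now covers float16 like any other type. The DivideFunctor choice encodes the math of the mean gradient, d(mean)/dx_i = 1/n. A sketch of what such a transform plausibly looks like (an assumption about kps::DivideFunctor, not its verbatim source):

#include <iostream>

template <typename T, typename MPType = T>
struct DivideFunctorSketch {
  explicit DivideFunctorSketch(int n) : n_inv_(static_cast<MPType>(1.0 / n)) {}
  T operator()(const T& x) const {
    return static_cast<T>(static_cast<MPType>(x) * n_inv_);
  }

 private:
  MPType n_inv_;  // precomputed reciprocal of the reduced element count
};

int main() {
  DivideFunctorSketch<float> f(4);  // mean over 4 elements
  std::cout << f(1.0f) << '\n';     // upstream grad 1.0 -> 0.25 per input
  return 0;
}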
59 changes: 55 additions & 4 deletions paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -30,6 +30,7 @@ limitations under the License. */

 #if defined(__HIPCC__) || defined(__NVCC__)
 #include "paddle/pten/kernels/gpu/reduce.h"
+#include "paddle/pten/kernels/gpu/reduce_grad.h"
 #endif

 namespace paddle {
@@ -620,11 +621,12 @@ class ReduceGradOp : public framework::OperatorWithKernel {
  protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    int in_dtype = ctx.Attr<int>("in_dtype");
+    int out_dtype = ctx.Attr<int>("out_dtype");
     auto input_data_type =
-        (in_dtype >= 0) ? static_cast<framework::proto::VarType::Type>(in_dtype)
-                        : OperatorWithKernel::IndicateVarDataType(
-                              ctx, framework::GradVarName("Out"));
+        (out_dtype >= 0)
+            ? static_cast<framework::proto::VarType::Type>(out_dtype)
+            : OperatorWithKernel::IndicateVarDataType(
+                  ctx, framework::GradVarName("Out"));
 #ifdef PADDLE_WITH_MKLDNN
     auto CanMKLDNNReduceGradBeUsed = [&]() {
       auto dx_dims = ctx.Input<Tensor>("X")->dims();
@@ -730,6 +732,55 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
         dev_ctx, *input, reduce_all, dims_int64, false, pt_out_dtype, output);
   }
 };
+
+template <typename T, template <typename, typename> class TransformOp>
+class ReduceCudaGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    bool reduce_all = context.Attr<bool>("reduce_all");
+    std::vector<int> dims = context.Attr<std::vector<int>>("dim");
+    auto* in_x = context.Input<Tensor>("X");
+    auto* d_out =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto out_dtype = context.Attr<int>("in_dtype");
+    // get reduce_dim and reduce_num for reduce_mean_grad
+    int dim_size = in_x->dims().size();
+    std::vector<int> reduce_dims = GetReduceDim(dims, dim_size, reduce_all);
+    auto update_dims = vectorize(d_x->dims());
+    int reduce_num = 1;
+    for (auto i : reduce_dims) {
+      reduce_num *= (in_x->dims())[i];
+      update_dims[i] = 1;
+    }
+    // make new tensor
+    framework::Tensor new_d_out(d_out->type());
+    new_d_out.ShareDataWith(*d_out);
+    new_d_out.Resize(paddle::framework::make_ddim(update_dims));
+    auto& dev_ctx = context.cuda_device_context();
+    if (out_dtype > 0) {
+      d_x->mutable_data(
+          dev_ctx.GetPlace(),
+          static_cast<framework::proto::VarType::Type>(out_dtype));
+    } else {
+      d_x->mutable_data(
+          dev_ctx.GetPlace(),
+          static_cast<framework::proto::VarType::Type>(d_out->type()));
+    }
+    auto pt_d_out = paddle::experimental::MakePtenDenseTensor(new_d_out);
+    auto pt_d_x = paddle::experimental::MakePtenDenseTensor(*d_x);
+    auto pt_out_dtype = pten::TransToPtenDataType(
+        static_cast<framework::proto::VarType::Type>(out_dtype));
+    if (out_dtype <= 0) {
+      pt_out_dtype = pten::TransToPtenDataType(
+          static_cast<framework::proto::VarType::Type>(d_out->type()));
+    }
+    using MPType = typename kps::details::MPTypeTrait<T>::Type;
+    pten::ReduceGrad<T, TransformOp<T, MPType>>(
+        dev_ctx, pt_d_out.get(), pt_d_x.get(), pt_out_dtype,
+        TransformOp<T, MPType>(reduce_num));
+  }
+};
 #endif

 } // namespace operators
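The shape bookkeeping in ReduceCudaGradKernel is the core trick: reduced axes are collapsed to 1 in d_out's view so the gradient can be broadcast back over the input, and reduce_num parameterizes the transform. A standalone walk-through with concrete dims:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Suppose X has shape {2, 3, 4} and the forward op reduced over dim 1.
  std::vector<int> x_dims = {2, 3, 4};
  std::vector<int> reduce_dims = {1};

  std::vector<int> update_dims = x_dims;  // starts as d_x's shape
  int reduce_num = 1;
  for (int i : reduce_dims) {
    reduce_num *= x_dims[i];  // 3 elements were folded into each output
    update_dims[i] = 1;       // collapse the reduced axis
  }

  // d_out is viewed as {2, 1, 4} and broadcast against {2, 3, 4};
  // reduce_num == 3 feeds the transform (e.g. 1/3 for reduce_mean_grad).
  std::cout << "reduce_num = " << reduce_num << ", d_out view = {";
  for (std::size_t i = 0; i < update_dims.size(); ++i) {
    std::cout << update_dims[i] << (i + 1 < update_dims.size() ? ", " : "}\n");
  }
  return 0;
}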
2 changes: 1 addition & 1 deletion paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
@@ -50,7 +50,7 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {

   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const {
-    int in_dtype = ctx.Attr<int>("in_dtype");
+    int in_dtype = ctx.Attr<int>("out_dtype");
     if (in_dtype >= 0) {
       return framework::OpKernelType(
           static_cast<framework::proto::VarType::Type>(in_dtype),
2 changes: 1 addition & 1 deletion paddle/fluid/operators/reduce_ops/reduce_sum_op.h
@@ -74,7 +74,7 @@ class ReduceSumGradKernel : public framework::OpKernel<T> {
     auto dims = context.Attr<std::vector<int>>("dim");
     if (context.GetPlace().GetType() == platform::CPUPlace().GetType() &&
         dims.size() == 1) {
-      int in_dtype = context.Attr<int>("in_dtype");
+      int in_dtype = context.Attr<int>("out_dtype");

       if (in_dtype >= 0) {
         Tensor tmp_tensor;
3 changes: 1 addition & 2 deletions paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu
@@ -17,8 +17,7 @@

 template <typename T>
 using CUDAReduceSumGradKernel =
-    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
-                          ops::SumGradFunctor, true>;
+    ops::ReduceCudaGradKernel<T, kps::IdentityFunctor>;

 REGISTER_OP_CUDA_KERNEL(
     reduce_sum_grad, CUDAReduceSumGradKernel<bool>,
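For reduce_sum the same generic kernel is instantiated with kps::IdentityFunctor instead: d(sum)/dx_i = 1, so the gradient is a pure broadcast with no rescaling. A matching sketch, under the same kind of assumption as the DivideFunctor one above:

#include <iostream>

template <typename T, typename MPType = T>
struct IdentityFunctorSketch {
  explicit IdentityFunctorSketch(int /*reduce_num*/) {}  // n is irrelevant here
  T operator()(const T& x) const { return x; }
};

int main() {
  IdentityFunctorSketch<float> f(4);
  std::cout << f(1.0f) << '\n';  // 1: sum's gradient just broadcasts d_out
  return 0;
}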
3 changes: 2 additions & 1 deletion paddle/infrt/naive/CMakeLists.txt
@@ -4,4 +4,5 @@ cc_library(infrt_naive SRCS meta_tensor.cc
   infershaped/infershaped_kernel_launchers.cc
   )

-cc_test_tiny(test_infrt_infershape_launchers SRCS infershaped/infershape_launchers_test.cc DEPS infrt)
+cc_test_tiny(test_infrt_infershape_launchers SRCS
+    infershaped/infershape_launchers_test.cc DEPS infrt)
30 changes: 14 additions & 16 deletions paddle/infrt/naive/infershaped/elementwise_add.h
@@ -17,6 +17,7 @@

#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/naive/infershaped/infershaped_utils.h"

// This file contains a example of the infershape ElementwiseAdd kernel.
// Some of the following code should be generated from PTEN by script.
@@ -32,39 +33,36 @@ static void ElementwiseAddInferShape(const MetaTensor& a,
   *c->mutable_shape() = a.shape();
 }

-static void ElementwiseAdd(const tensor::DenseHostTensor& a,
+static void ElementwiseAdd(tensor::DenseHostTensor* /*Context*/,
+                           const tensor::DenseHostTensor& a,
                            const tensor::DenseHostTensor& b,
                            tensor::DenseHostTensor* c) {}

 // TODO(zhiqiang) This class should be generated by a script offline.
-class ElementwiseAddLauncher : public InferShapedKernelLauncher {
+template <typename KernelFunc,
+          KernelFunc kernel,
+          typename InferShapedFunc,
+          InferShapedFunc infershape>
+class KernelLauncher : public InferShapedKernelLauncher {
  public:
-  static const uint16_t input_tensor_indices[2];
-  static const uint16_t num_input_tensors{2};
+  static const uint16_t num_input_tensors{InferShapeHelper<KernelFunc>::count};
   static const bool turn_on_infer_shape_cache{true};

   void Invoke(host_context::KernelFrame* frame) override {
     // Build the infershape KernelFrame if needed.
     // TODO(Superjomn) add unlikely here.
     if (infershape_kernel_frame_builder.IsEmpty()) {
       CreateKernelFrameForInferShape(frame);
     }
-    if (turn_on_infer_shape_cache) {
-      if (IsShapeChanged(input_tensor_indices, num_input_tensors)) {
-        INFRT_KERNEL(ElementwiseAddInferShape)
-        (&infershape_kernel_frame_builder);
-        BuildInferShapeCache(input_tensor_indices, num_input_tensors);
-      }
-    } else {
-      INFRT_KERNEL(ElementwiseAddInferShape)(&infershape_kernel_frame_builder);
-      BuildInferShapeCache(input_tensor_indices, num_input_tensors);
+    if (!turn_on_infer_shape_cache || IsShapeChanged(num_input_tensors)) {
+      ::infrt::host_context::KernelImpl<InferShapedFunc, infershape>::Invoke(
+          &infershape_kernel_frame_builder);
+      BuildInferShapeCache(num_input_tensors);
     }
-
-    INFRT_KERNEL(ElementwiseAdd)(frame);
+    ::infrt::host_context::KernelImpl<KernelFunc, kernel>::Invoke(frame);
   }
 };

-const uint16_t ElementwiseAddLauncher::input_tensor_indices[2] = {0, 1};
-
 } // namespace naive
 } // namespace infrt
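With the launcher templated over both the kernel and its infershape function, the old hand-written ElementwiseAddLauncher (and its input_tensor_indices table) reduces to a one-line instantiation. A hypothetical example of how such an alias could look; the actual registration presumably lives elsewhere in the tree:

using ElementwiseAddLauncher =
    KernelLauncher<decltype(&ElementwiseAdd),
                   &ElementwiseAdd,
                   decltype(&ElementwiseAddInferShape),
                   &ElementwiseAddInferShape>;

This also lines up with the other changes in this commit: the test below now pushes a leading context Value(0) into the kernel frame, and CreateKernelFrameForInferShape skips argument 0 accordingly.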
14 changes: 14 additions & 0 deletions paddle/infrt/naive/infershaped/infershape_launchers_test.cc
@@ -17,11 +17,24 @@
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"
#include "paddle/infrt/naive/infershaped/infershaped_registry.h"
#include "paddle/infrt/naive/infershaped/infershaped_utils.h"
#include "paddle/infrt/tensor/dense_host_tensor.h"

namespace infrt {
namespace naive {

namespace {
static void ElementwiseAddTest(const tensor::DenseHostTensor& a,
const tensor::DenseHostTensor& b,
tensor::DenseHostTensor* c);
}

TEST(utils, registry) {
constexpr uint8_t count =
InferShapeHelper<decltype(&ElementwiseAddTest)>::count;
CHECK_EQ(count, 2U);
}

TEST(ElementwiseAdd, registry) {
InferShapedKernelRegistry registry;
RegisterInferShapeLaunchers(&registry);
@@ -35,6 +48,7 @@ TEST(ElementwiseAdd, registry) {
   tensor::DenseHostTensor c({2, 8}, GetDType<float>());

   host_context::KernelFrameBuilder kernel_frame_builder;
+  kernel_frame_builder.AddArgument(new host_context::Value(0));
   kernel_frame_builder.AddArgument(new host_context::Value(std::move(a)));
   kernel_frame_builder.AddArgument(new host_context::Value(std::move(b)));
   kernel_frame_builder.SetResults({new host_context::Value(std::move(c))});
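TEST(utils, registry) pins down what InferShapeHelper is expected to compute: the number of const tensor-reference parameters in a kernel's signature, with the output pointer not counted. A self-contained sketch of a trait with that behavior (an assumption about what infershaped_utils.h provides, not its verbatim contents):

#include <cstdint>
#include <type_traits>

struct TensorStub {};  // stands in for tensor::DenseHostTensor

template <typename T>
constexpr int IsInputTensor() {
  return std::is_same<T, const TensorStub&>::value ? 1 : 0;
}

template <typename F>
struct CountTensorArgs;

template <typename Ret, typename... Args>
struct CountTensorArgs<Ret (*)(Args...)> {
  static constexpr std::uint8_t count =
      static_cast<std::uint8_t>((0 + ... + IsInputTensor<Args>()));
};

void AddTest(const TensorStub& a, const TensorStub& b, TensorStub* c);

static_assert(CountTensorArgs<decltype(&AddTest)>::count == 2,
              "two const-ref inputs; the output pointer is not counted");

int main() { return 0; }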
17 changes: 7 additions & 10 deletions paddle/infrt/naive/infershaped/infershaped_kernel_launcher.cc
@@ -20,7 +20,7 @@ namespace naive {
 void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
     host_context::KernelFrame* frame) {
   for (host_context::Value* value :
-       frame->GetValues(0, frame->GetNumElements())) {
+       frame->GetValues(1, frame->GetNumElements() - 1)) {
     // TODO(Superjomn) To extend this.
     if (value->is_type<tensor::DenseHostTensor>()) {
       values.emplace_back(MetaTensor{&value->get<tensor::DenseHostTensor>()});
@@ -32,27 +32,24 @@ }
 }

 void InferShapedKernelLauncher::BuildInferShapeCache(
-    const uint16_t* input_indices, const uint16_t num_inputs) {
+    const uint16_t num_inputs) {
   tensor_shape_cache.resize(num_inputs);
   for (uint16_t i = 0; i < num_inputs; i++) {
     tensor_shape_cache[i] =
-        infershape_kernel_frame_builder.GetArgAt(input_indices[i])
-            ->get<MetaTensor>()
-            .shape();
+        infershape_kernel_frame_builder.GetArgAt(i)->get<MetaTensor>().shape();
   }
 }

 bool InferShapedKernelLauncher::IsShapeChanged(
-    const uint16_t* input_indices, const uint16_t num_inputs) const {
+    const uint16_t num_inputs) const {
   if (tensor_shape_cache.empty() && !infershape_kernel_frame_builder.IsEmpty())
     return true;

   bool changed = false;
   for (uint16_t i = 0; i < num_inputs && !changed; i++) {
-    changed = changed || (tensor_shape_cache[i] !=
-                          infershape_kernel_frame_builder
-                              .GetArgAt<MetaTensor>(input_indices[i])
-                              .shape());
+    changed = changed ||
+              (tensor_shape_cache[i] !=
+               infershape_kernel_frame_builder.GetArgAt<MetaTensor>(i).shape());
   }
   return changed;
 }