Skip to content

Commit

Permalink
Inference support Ascend910 (#34101)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiweibo committed Jul 14, 2021
1 parent a4028b4 commit 4e3fb21
Show file tree
Hide file tree
Showing 14 changed files with 257 additions and 10 deletions.
57 changes: 55 additions & 2 deletions paddle/fluid/inference/api/analysis_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ PassStrategy *AnalysisConfig::pass_builder() const {
pass_builder_.reset(new GpuPassStrategy);
} else if (use_xpu_) {
pass_builder_.reset(new XpuPassStrategy);
} else if (use_npu_) {
pass_builder_.reset(new NpuPassStrategy);
} else {
LOG(INFO) << "Create CPU IR passes";
pass_builder_.reset(new CpuPassStrategy);
Expand Down Expand Up @@ -110,6 +112,18 @@ void AnalysisConfig::EnableXpu(int l3_workspace_size, bool locked,
Update();
}

void AnalysisConfig::EnableNpu(int device_id) {
#ifdef PADDLE_WITH_ASCEND_CL
use_npu_ = true;
npu_device_id_ = device_id;
#else
LOG(ERROR) << "Please compile with npu to EnableNpu()";
use_npu_ = false;
#endif

Update();
}

AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
#define CP_MEMBER(member__) member__ = other.member__;

Expand All @@ -127,7 +141,6 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(use_gpu_);
CP_MEMBER(use_cudnn_);
CP_MEMBER(gpu_device_id_);
CP_MEMBER(xpu_device_id_);
CP_MEMBER(memory_pool_init_size_mb_);

CP_MEMBER(enable_memory_optim_);
Expand Down Expand Up @@ -167,14 +180,20 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(lite_ops_filter_);
CP_MEMBER(lite_zero_copy_);

// XPU related.
CP_MEMBER(use_xpu_);
CP_MEMBER(xpu_device_id_);
CP_MEMBER(xpu_l3_workspace_size_);
CP_MEMBER(xpu_locked_);
CP_MEMBER(xpu_autotune_);
CP_MEMBER(xpu_autotune_file_);
CP_MEMBER(xpu_precision_);
CP_MEMBER(xpu_adaptive_seqlen_);

// NPU related.
CP_MEMBER(use_npu_);
CP_MEMBER(npu_device_id_);

// profile related.
CP_MEMBER(with_profile_);

Expand Down Expand Up @@ -202,6 +221,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
} else if (use_xpu_) {
pass_builder_.reset(new XpuPassStrategy(
*static_cast<XpuPassStrategy *>(other.pass_builder())));
} else if (use_npu_) {
pass_builder_.reset(new NpuPassStrategy(
*static_cast<NpuPassStrategy *>(other.pass_builder())));
} else {
pass_builder_.reset(new CpuPassStrategy(
*static_cast<CpuPassStrategy *>(other.pass_builder())));
Expand Down Expand Up @@ -376,7 +398,9 @@ void AnalysisConfig::Update() {
if (info == serialized_info_cache_) return;

// Transfer pass_builder and copy the existing compatible passes.
if (!pass_builder_ || ((use_gpu() ^ pass_builder_->use_gpu()))) {
if (!pass_builder_ || ((use_gpu() ^ pass_builder_->use_gpu())) ||
((use_xpu() ^ pass_builder_->use_xpu())) ||
((use_npu() ^ pass_builder_->use_npu()))) {
if (use_gpu()) {
pass_builder_.reset(new GpuPassStrategy);

Expand All @@ -390,6 +414,12 @@ void AnalysisConfig::Update() {
platform::errors::InvalidArgument(
"Only one choice can be made between CPU and XPU."));
pass_builder_.reset(new XpuPassStrategy);
} else if (use_npu()) {
PADDLE_ENFORCE_EQ(
use_gpu(), false,
platform::errors::InvalidArgument(
"Only one choice can be made between GPU and NPU."));
pass_builder_.reset(new NpuPassStrategy);
} else {
pass_builder_.reset(new CpuPassStrategy);
}
Expand All @@ -405,6 +435,13 @@ void AnalysisConfig::Update() {
"Only one choice can be made between CPU and XPU."));
pass_builder_.reset(new XpuPassStrategy(
*static_cast<XpuPassStrategy *>(pass_builder_.get())));
} else if (use_npu()) {
PADDLE_ENFORCE_EQ(
use_gpu(), false,
platform::errors::InvalidArgument(
"Only one choice can be made between GPU and NPU."));
pass_builder_.reset(new NpuPassStrategy(
*static_cast<NpuPassStrategy *>(pass_builder_.get())));
} else {
pass_builder_.reset(new CpuPassStrategy(
*static_cast<CpuPassStrategy *>(pass_builder_.get())));
Expand Down Expand Up @@ -502,6 +539,19 @@ void AnalysisConfig::Update() {
#endif
}

if (use_npu_) {
#ifdef PADDLE_WITH_ASCEND_CL
PADDLE_ENFORCE_EQ(use_gpu_, false,
platform::errors::Unavailable(
"Currently, NPU and GPU cannot be enabled in the "
"same analysis configuration."));
#else
PADDLE_THROW(platform::errors::Unavailable(
"You tried to use an NPU device, but Paddle was not compiled "
"with NPU-runtime."));
#endif
}

if (ir_debug_) {
pass_builder()->TurnOnDebug();
}
Expand Down Expand Up @@ -566,6 +616,9 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << xpu_precision_;
ss << xpu_adaptive_seqlen_;

ss << use_npu_;
ss << npu_device_id_;

ss << thread_local_stream_;

return ss.str();
Expand Down
14 changes: 14 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,14 @@ bool AnalysisPredictor::CreateExecutor() {
"with WITH_XPU."));
#endif // PADDLE_WITH_XPU
}
} else if (config_.use_npu()) {
#ifdef PADDLE_WITH_ASCEND_CL
place_ = paddle::platform::NPUPlace(config_.npu_device_id());
#else
PADDLE_THROW(platform::errors::Unavailable(
"You tried to use NPU forward propagation, but Paddle was not compiled "
"with WITH_ASCEND_CL."));
#endif
} else {
place_ = paddle::platform::CPUPlace();
}
Expand Down Expand Up @@ -847,6 +855,9 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetInputTensor(
auto xpu_place = BOOST_GET_CONST(platform::XPUPlace, place_);
res->SetPlace(PaddlePlace::kXPU, xpu_place.GetDeviceId());
}
} else if (platform::is_npu_place(place_)) {
auto npu_place = BOOST_GET_CONST(platform::NPUPlace, place_);
res->SetPlace(PaddlePlace::kNPU, npu_place.GetDeviceId());
} else {
auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, place_);
res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
Expand Down Expand Up @@ -879,6 +890,9 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetOutputTensor(
auto xpu_place = BOOST_GET_CONST(platform::XPUPlace, place_);
res->SetPlace(PaddlePlace::kXPU, xpu_place.GetDeviceId());
}
} else if (platform::is_npu_place(place_)) {
auto npu_place = BOOST_GET_CONST(platform::NPUPlace, place_);
res->SetPlace(PaddlePlace::kNPU, npu_place.GetDeviceId());
} else {
auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, place_);
res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
Expand Down
19 changes: 18 additions & 1 deletion paddle/fluid/inference/api/api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"

DEFINE_bool(profile, false, "Turn on profiler for fluid");
Expand Down Expand Up @@ -78,6 +79,8 @@ bool NativePaddlePredictor::Init(
place_ = paddle::platform::CUDAPlace(config_.device);
} else if (config_.use_xpu) {
place_ = paddle::platform::XPUPlace(config_.device);
} else if (config_.use_npu) {
place_ = paddle::platform::NPUPlace(config_.device);
} else {
place_ = paddle::platform::CPUPlace();
}
Expand Down Expand Up @@ -255,7 +258,7 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
PADDLE_THROW(platform::errors::Unavailable(
"Not compile with CUDA, should not reach here."));
#endif
} else {
} else if (platform::is_xpu_place(place_)) {
#ifdef PADDLE_WITH_XPU
auto dst_xpu_place = BOOST_GET_CONST(platform::XPUPlace, place_);
memory::Copy(dst_xpu_place, static_cast<void *>(input_ptr),
Expand All @@ -264,6 +267,20 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
#else
PADDLE_THROW(platform::errors::Unavailable(
"Not compile with XPU, should not reach here."));
#endif
} else {
#ifdef PADDLE_WITH_ASCEND_CL
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto *dev_ctx =
static_cast<const platform::NPUDeviceContext *>(pool.Get(place_));
auto dst_npu_place = BOOST_GET_CONST(platform::NPUPlace, place_);
memory::Copy(dst_npu_place, static_cast<void *>(input_ptr),
platform::CPUPlace(), inputs[i].data.data(),
inputs[i].data.length(), dev_ctx->stream());
#else
PADDLE_THROW(platform::errors::Unavailable(
"Not compile with NPU, should not reach here."));
#endif
}

Expand Down
13 changes: 13 additions & 0 deletions paddle/fluid/inference/api/api_impl_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ void MainWord2Vec(const paddle::PaddlePlace& place) {
auto predictor = CreatePaddlePredictor<NativeConfig>(config);
config.use_gpu = paddle::gpu_place_used(place);
config.use_xpu = paddle::xpu_place_used(place);
config.use_npu = paddle::npu_place_used(place);

framework::LoDTensor first_word, second_word, third_word, fourth_word;
framework::LoD lod{{0, 1}};
Expand Down Expand Up @@ -119,6 +120,7 @@ void MainImageClassification(const paddle::PaddlePlace& place) {
NativeConfig config = GetConfig();
config.use_gpu = paddle::gpu_place_used(place);
config.use_xpu = paddle::xpu_place_used(place);
config.use_npu = paddle::npu_place_used(place);
config.model_dir =
FLAGS_book_dirname + "/image_classification_resnet.inference.model";

Expand Down Expand Up @@ -163,6 +165,7 @@ void MainThreadsWord2Vec(const paddle::PaddlePlace& place) {
NativeConfig config = GetConfig();
config.use_gpu = paddle::gpu_place_used(place);
config.use_xpu = paddle::xpu_place_used(place);
config.use_npu = paddle::npu_place_used(place);
auto main_predictor = CreatePaddlePredictor<NativeConfig>(config);

// prepare inputs data and reference results
Expand Down Expand Up @@ -227,6 +230,7 @@ void MainThreadsImageClassification(const paddle::PaddlePlace& place) {
NativeConfig config = GetConfig();
config.use_gpu = paddle::gpu_place_used(place);
config.use_xpu = paddle::xpu_place_used(place);
config.use_npu = paddle::npu_place_used(place);
config.model_dir =
FLAGS_book_dirname + "/image_classification_resnet.inference.model";

Expand Down Expand Up @@ -297,6 +301,15 @@ TEST(inference_api_native, image_classification_xpu) {
}
#endif

#ifdef PADDLE_WITH_ASCEND_CL
TEST(inference_api_native, word2vec_npu) {
MainWord2Vec(paddle::PaddlePlace::kNPU);
}
// TEST(inference_api_native, image_classification_npu) {
// MainImageClassification(paddle::PaddlePlace::kNPU);
// }
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST(inference_api_native, word2vec_gpu) {
MainWord2Vec(paddle::PaddlePlace::kGPU);
Expand Down
37 changes: 35 additions & 2 deletions paddle/fluid/inference/api/details/zero_copy_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"

Expand Down Expand Up @@ -150,10 +151,26 @@ void Tensor::CopyFromCpu(const T *data) {
PADDLE_THROW(paddle::platform::errors::Unavailable(
"Can not create tensor with XPU place because paddle is not compiled "
"with XPU."));
#endif
} else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
paddle::platform::DeviceContextPool &pool =
paddle::platform::DeviceContextPool::Instance();
paddle::platform::NPUPlace npu_place(device_);
auto *t_data = tensor->mutable_data<T>(npu_place);
auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
pool.Get(npu_place));
paddle::memory::Copy(npu_place, static_cast<void *>(t_data),
paddle::platform::CPUPlace(), data, ele_size,
dev_ctx->stream());
#else
PADDLE_THROW(paddle::platform::errors::Unavailable(
"Can not create tensor with NPU place because paddle is not compiled "
"with NPU."));
#endif
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"The analysis predictor supports CPU, GPU and XPU now."));
"The analysis predictor supports CPU, GPU, NPU and XPU now."));
}
}

Expand Down Expand Up @@ -212,10 +229,26 @@ void Tensor::CopyToCpu(T *data) {
PADDLE_THROW(paddle::platform::errors::Unavailable(
"Can not create tensor with XPU place because paddle is not compiled "
"with XPU."));
#endif
} else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
paddle::platform::DeviceContextPool &pool =
paddle::platform::DeviceContextPool::Instance();
auto npu_place = BOOST_GET_CONST(paddle::platform::NPUPlace, t_place);
auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
pool.Get(npu_place));
paddle::memory::Copy(paddle::platform::CPUPlace(),
static_cast<void *>(data), npu_place, t_data,
ele_num * sizeof(T), dev_ctx->stream());
aclrtSynchronizeStream(dev_ctx->stream());
#else
PADDLE_THROW(paddle::platform::errors::Unavailable(
"Can not create tensor with NPU place because paddle is not compiled "
"with NPU."));
#endif
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"The analysis predictor supports CPU, GPU and XPU now."));
"The analysis predictor supports CPU, GPU, NPU and XPU now."));
}
}
template PD_INFER_DECL void Tensor::CopyFromCpu<float>(const float *data);
Expand Down
27 changes: 25 additions & 2 deletions paddle/fluid/inference/api/paddle_analysis_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@ struct PD_INFER_DECL AnalysisConfig {
const std::string& precision = "int16",
bool adaptive_seqlen = false);
///
/// \brief Turn on NPU.
///
/// \param device_id device_id the NPU card to use (default is 0).
///
void EnableNpu(int device_id = 0);
///
/// \brief A boolean state telling whether the GPU is turned on.
///
/// \return bool Whether the GPU is turned on.
Expand All @@ -215,6 +221,12 @@ struct PD_INFER_DECL AnalysisConfig {
///
bool use_xpu() const { return use_xpu_; }
///
/// \brief A boolean state telling whether the NPU is turned on.
///
/// \return bool Whether the NPU is turned on.
///
bool use_npu() const { return use_npu_; }
///
/// \brief Get the GPU device id.
///
/// \return int The GPU device id.
Expand All @@ -227,6 +239,12 @@ struct PD_INFER_DECL AnalysisConfig {
///
int xpu_device_id() const { return xpu_device_id_; }
///
/// \brief Get the NPU device id.
///
/// \return int The NPU device id.
///
int npu_device_id() const { return npu_device_id_; }
///
/// \brief Get the initial size in MB of the GPU memory pool.
///
/// \return int The initial size in MB of the GPU memory pool.
Expand Down Expand Up @@ -619,11 +637,15 @@ struct PD_INFER_DECL AnalysisConfig {
// GPU related.
bool use_gpu_{false};
int gpu_device_id_{0};
int xpu_device_id_{0};
uint64_t memory_pool_init_size_mb_{100}; // initial size is 100MB.
bool thread_local_stream_{false};

bool use_cudnn_{false};

// NPU related
bool use_npu_{false};
int npu_device_id_{0};

// Padding related
bool use_fc_padding_{true};

Expand Down Expand Up @@ -689,8 +711,9 @@ struct PD_INFER_DECL AnalysisConfig {
Precision lite_precision_mode_;
bool lite_zero_copy_;

bool thread_local_stream_{false};
// XPU related.
bool use_xpu_{false};
int xpu_device_id_{0};
int xpu_l3_workspace_size_;
bool xpu_locked_;
bool xpu_autotune_;
Expand Down
Loading

0 comments on commit 4e3fb21

Please sign in to comment.