Skip to content

Commit

Permalink
polish xpu enforce msg, test=kunlun (PaddlePaddle#45749)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwhql authored and Caozhou1995 committed Sep 8, 2022
1 parent 8752869 commit 2190212
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 33 deletions.
24 changes: 4 additions & 20 deletions paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,8 @@ void BatchNormGradKernel(const Context &dev_ctx,
C,
epsilon_data,
global_inv_std_data);
PADDLE_ENFORCE_EQ(r1,
XPU_SUCCESS,
phi::errors::External("XPU API(batch_norm_grad "
"CalculateInvVar function) "
"return wrong value[%d %s]",
r1,
XPUAPIErrorMsg[r1]));
PADDLE_ENFORCE_XDNN_SUCCESS(r1,
"batch_norm_grad CalculateInvVar function");
}

// Here is a trick, x is a const input,
Expand All @@ -209,13 +204,7 @@ void BatchNormGradKernel(const Context &dev_ctx,
C,
H * W,
x.data<T>());
PADDLE_ENFORCE_EQ(r2,
XPU_SUCCESS,
phi::errors::External("XPU API(batch_norm_grad "
"CalculateInvBNY function) "
"return wrong value[%d %s]",
r2,
XPUAPIErrorMsg[r2]));
PADDLE_ENFORCE_XDNN_SUCCESS(r2, "batch_norm_grad CalculateInvBNY function");
}

int r3;
Expand Down Expand Up @@ -263,12 +252,7 @@ void BatchNormGradKernel(const Context &dev_ctx,
bias_grad_data,
is_nchw);
}
PADDLE_ENFORCE_EQ(r3,
XPU_SUCCESS,
phi::errors::External("XPU API(batch_norm_grad) return "
"wrong value[%d %s]",
r3,
XPUAPIErrorMsg[r3]));
PADDLE_ENFORCE_XDNN_SUCCESS(r3, "batch_norm_grad");
}

} // namespace phi
Expand Down
15 changes: 2 additions & 13 deletions paddle/phi/kernels/xpu/batch_norm_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,7 @@ void BatchNormKernel(const Context& dev_ctx,
mean_out_data,
variance_out_data,
is_nchw);
PADDLE_ENFORCE_EQ(r,
xpu::Error_t::SUCCESS,
phi::errors::External(
"The batch_norm XPU API return wrong value[%d %s]",
r,
XPUAPIErrorMsg[r]));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm");
} else {
const auto* mean_data = mean.data<float>();
const auto* variance_data = variance.data<float>();
Expand All @@ -124,13 +119,7 @@ void BatchNormKernel(const Context& dev_ctx,
mean_data,
variance_data,
is_nchw);
PADDLE_ENFORCE_EQ(
r,
xpu::Error_t::SUCCESS,
phi::errors::External(
"The batch_norm_infer XPU API return wrong value[%d %s]",
r,
XPUAPIErrorMsg[r]));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_infer");
}
}

Expand Down

0 comments on commit 2190212

Please sign in to comment.