[Phi] migrate sync_batch_norm to phi (PaddlePaddle#44369)
affectionlu authored and Aurelius84 committed Jul 29, 2022
1 parent b50ef42 commit 4c88215
Showing 12 changed files with 1,027 additions and 823 deletions.
130 changes: 83 additions & 47 deletions paddle/fluid/operators/inplace_abn_op.cu
@@ -13,17 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/inplace_abn_op.h"
#include <iostream>
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/operators/sync_batch_norm_op.cu.h"
#include "paddle/phi/kernels/batch_norm_grad_kernel.h"
#include "paddle/phi/kernels/batch_norm_kernel.h"
#include "paddle/phi/kernels/gpu/sync_batch_norm_utils.h"
#include "paddle/phi/kernels/sync_batch_norm_grad_kernel.h"
#include "paddle/phi/kernels/sync_batch_norm_kernel.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class InplaceABNKernel
: public paddle::operators::SyncBatchNormKernel<DeviceContext, T> {
class InplaceABNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* y = ctx.Output<Tensor>("Y");
@@ -36,29 +38,49 @@ class InplaceABNKernel
GetInplaceABNActivationType(ctx.Attr<std::string>("activation"));
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();

auto* scale = ctx.Input<Tensor>("Scale");
auto* bias = ctx.Input<Tensor>("Bias");
auto* mean = ctx.Input<Tensor>("Mean");
auto* variance = ctx.Input<Tensor>("Variance");

auto momentum = ctx.Attr<float>("momentum");
auto epsilon = ctx.Attr<float>("epsilon");
auto data_layout = ctx.Attr<std::string>("data_layout");
auto is_test = ctx.Attr<bool>("is_test");
auto use_global_stats = ctx.Attr<bool>("use_global_stats");
auto trainable_statistics = ctx.Attr<bool>("trainable_statistics");
auto fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");

auto* mean_out = ctx.Output<Tensor>("MeanOut");
auto* variance_out = ctx.Output<Tensor>("VarianceOut");
auto* saved_mean = ctx.Output<Tensor>("SavedMean");
auto* saved_variance = ctx.Output<Tensor>("SavedVariance");
auto* reserve_space = ctx.Output<Tensor>("ReserveSpace");

if (ctx.Attr<bool>("use_sync_bn")) {
SyncBatchNormKernel<DeviceContext, T>::Compute(ctx);
auto& dev_ctx = ctx.device_context<DeviceContext>();
phi::SyncBatchNormKernel<T>(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*x,
*scale,
*bias,
*mean,
*variance,
momentum,
epsilon,
data_layout,
is_test,
use_global_stats,
trainable_statistics,
fuse_with_relu,
y,
mean_out,
variance_out,
saved_mean,
saved_variance,
reserve_space);
} else {
// BatchNormKernel<DeviceContext, T>::Compute(ctx);
auto* scale = ctx.Input<Tensor>("Scale");
auto* bias = ctx.Input<Tensor>("Bias");
auto* mean = ctx.Input<Tensor>("Mean");
auto* variance = ctx.Input<Tensor>("Variance");

auto momentum = ctx.Attr<float>("momentum");
auto epsilon = ctx.Attr<float>("epsilon");
auto data_layout = ctx.Attr<std::string>("data_layout");
auto is_test = ctx.Attr<bool>("is_test");
auto use_global_stats = ctx.Attr<bool>("use_global_stats");
auto trainable_statistics = ctx.Attr<bool>("trainable_statistics");
auto fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");

auto* mean_out = ctx.Output<Tensor>("MeanOut");
auto* variance_out = ctx.Output<Tensor>("VarianceOut");
auto* saved_mean = ctx.Output<Tensor>("SavedMean");
auto* saved_variance = ctx.Output<Tensor>("SavedVariance");
auto* reserve_space = ctx.Output<Tensor>("ReserveSpace");

auto& dev_ctx = ctx.device_context<DeviceContext>();
phi::BatchNormKernel<T>(
static_cast<const typename framework::ConvertToPhiContext<
@@ -92,8 +114,7 @@ class InplaceABNKernel
// Deriving the Gradient for the Backward Pass of Batch Normalization
// https://kevinzakka.github.io/2016/09/14/batch_normalization/
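// In the notation of that derivation (x_hat = (x - mu) / sqrt(var + eps),
// y = gamma * x_hat + beta), the per-channel gradients over N elements are:
//   dbeta  = sum_i dy_i
//   dgamma = sum_i dy_i * x_hat_i
//   dx_i   = gamma / (N * sqrt(var + eps)) * (N * dy_i - dbeta - x_hat_i * dgamma)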
template <typename DeviceContext, typename T>
class InplaceABNGradKernel
: public paddle::operators::SyncBatchNormGradKernel<DeviceContext, T> {
class InplaceABNGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* y = ctx.Input<Tensor>("Y");
@@ -115,29 +136,44 @@ class InplaceABNGradKernel
InplaceABNActivation<DeviceContext, T> functor;
functor.GradCompute(ctx, activation, place, cur_y, cur_y, cur_dy, cur_dy);

auto* scale = ctx.Input<Tensor>("Scale");
auto* bias = ctx.Input<Tensor>("Bias");
auto* saved_mean = ctx.Input<Tensor>("SavedMean");
auto* saved_variance = ctx.Input<Tensor>("SavedVariance");

auto momentum = ctx.Attr<float>("momentum");
auto epsilon = ctx.Attr<float>("epsilon");
auto data_layout = ctx.Attr<std::string>("data_layout");
auto is_test = ctx.Attr<bool>("is_test");
auto use_global_stats = ctx.Attr<bool>("use_global_stats");
auto trainable_statistics = ctx.Attr<bool>("trainable_statistics");
auto fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");

auto* scale_grad = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto* bias_grad = ctx.Output<Tensor>(framework::GradVarName("Bias"));

auto* reserve_space = ctx.Input<Tensor>("ReserveSpace");
auto* mean = ctx.Input<Tensor>("ReserveSpace");
auto* variance = ctx.Input<Tensor>("ReserveSpace");

if (ctx.Attr<bool>("use_sync_bn")) {
SyncBatchNormGradKernel<DeviceContext, T>::Compute(ctx);
auto& dev_ctx = ctx.device_context<DeviceContext>();
phi::SyncBatchNormGradFunctor<T>(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
nullptr,
y,
*scale,
*bias,
*saved_mean,
*saved_variance,
*d_y,
epsilon,
data_layout,
d_x,
scale_grad,
bias_grad);
} else {
auto* scale = ctx.Input<Tensor>("Scale");
auto* bias = ctx.Input<Tensor>("Bias");
auto* saved_mean = ctx.Input<Tensor>("SavedMean");
auto* saved_variance = ctx.Input<Tensor>("SavedVariance");

auto momentum = ctx.Attr<float>("momentum");
auto epsilon = ctx.Attr<float>("epsilon");
auto data_layout = ctx.Attr<std::string>("data_layout");
auto is_test = ctx.Attr<bool>("is_test");
auto use_global_stats = ctx.Attr<bool>("use_global_stats");
auto trainable_statistics = ctx.Attr<bool>("trainable_statistics");
auto fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");

auto* scale_grad = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto* bias_grad = ctx.Output<Tensor>(framework::GradVarName("Bias"));

auto* reserve_space = ctx.Input<Tensor>("ReserveSpace");
auto* mean = ctx.Input<Tensor>("ReserveSpace");
auto* variance = ctx.Input<Tensor>("ReserveSpace");

paddle::optional<Tensor> space_opt;
paddle::optional<Tensor> mean_opt;
paddle::optional<Tensor> variance_opt;
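
For reference, the forward call above pins down the shape of the phi kernel it targets. The declaration below is reconstructed from that call site, not copied from the header; the authoritative signature lives in paddle/phi/kernels/sync_batch_norm_kernel.h, and the <T, Context> template-parameter order is an assumption consistent with the explicit <T> at the call site.

// Sketch reconstructed from the call site in inplace_abn_op.cu above.
#include <string>

namespace phi {
class DenseTensor;  // defined in paddle/phi/core/dense_tensor.h

template <typename T, typename Context>
void SyncBatchNormKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         const DenseTensor& scale,
                         const DenseTensor& bias,
                         const DenseTensor& mean,
                         const DenseTensor& variance,
                         float momentum,
                         float epsilon,
                         const std::string& data_layout,
                         bool is_test,
                         bool use_global_stats,
                         bool trainable_statistics,
                         bool fuse_with_relu,
                         DenseTensor* y,
                         DenseTensor* mean_out,
                         DenseTensor* variance_out,
                         DenseTensor* saved_mean,
                         DenseTensor* saved_variance,
                         DenseTensor* reserve_space);
}  // namespace phi
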
137 changes: 0 additions & 137 deletions paddle/fluid/operators/sync_batch_norm_op.cu

This file was deleted.
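
Its CUDA implementation now lives behind the phi interface included at the top of inplace_abn_op.cu. The shape of the refactor — a fluid kernel that formerly inherited SyncBatchNormKernel and invoked its Compute() now forwards explicit tensors and attributes to a free-function kernel — can be shown with a self-contained sketch; every name here is an illustrative stand-in, not Paddle code.

#include <cstdio>
#include <vector>

struct GPUContext {};  // stand-in for the phi device context

// phi style: a free function taking the context and explicit arguments.
template <typename T>
void SyncBatchNormKernelLike(const GPUContext& /*ctx*/,
                             const std::vector<T>& x,
                             float epsilon,
                             std::vector<T>* y) {
  y->assign(x.begin(), x.end());  // placeholder for the real normalization
  std::printf("free-function kernel, epsilon=%g\n", epsilon);
}

// fluid style after the migration: a plain op kernel that forwards to the
// free function instead of inheriting a sibling kernel class and reusing
// its Compute(const ExecutionContext&).
template <typename T>
struct InplaceABNKernelLike {
  void Compute(const GPUContext& ctx,
               const std::vector<T>& x,
               std::vector<T>* y) const {
    SyncBatchNormKernelLike<T>(ctx, x, /*epsilon=*/1e-5f, y);
  }
};

int main() {
  GPUContext ctx;
  std::vector<float> x{1.0f, 2.0f, 3.0f};
  std::vector<float> y;
  InplaceABNKernelLike<float>{}.Compute(ctx, x, &y);
  std::printf("%zu outputs\n", y.size());
  return 0;
}

The division of labor mirrors the real diff: the op-style wrapper keeps the argument extraction, while the math sits behind an interface that takes the device context and all inputs explicitly.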

[Diffs for the remaining changed files were not loaded.]
