fix some format problem
wawltor committed Mar 4, 2022
1 parent 7f3613b commit 1429e04
Showing 3 changed files with 59 additions and 63 deletions.
13 changes: 6 additions & 7 deletions paddle/phi/kernels/cpu/graph_send_recv_funcs.h
@@ -16,7 +16,6 @@
#include <algorithm>
#include <vector>

#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h"
@@ -66,12 +65,12 @@ struct GraphSendRecvMaxFunctor {
};

template <typename T, typename IndexT, typename Functor>
-void elementwise_inner_operation(const DenseTensor& src,
-DenseTensor* dst,
-const IndexT& src_index,
-const IndexT& dst_index,
-const bool& first_flag,
-Functor functor) {
+void ElementwiseInnerOperation(const DenseTensor& src,
+DenseTensor* dst,
+const IndexT& src_index,
+const IndexT& dst_index,
+const bool& first_flag,
+Functor functor) {
auto src_slice = src.Slice(src_index, src_index + 1);
auto dst_slice = dst->Slice(dst_index, dst_index + 1);

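For context on the helper renamed above: it slices one source row and one destination row out of the tensors and applies a reduction functor to the pair, and the kernels in the other two files call it once per graph edge. Below is a minimal standalone sketch of that per-row pattern using plain std::vector rows instead of phi::DenseTensor; the names (CombineRow, SumFunctor, MaxFunctor) are illustrative and are not part of Paddle's API.

#include <algorithm>
#include <cstddef>
#include <vector>

// Simplified stand-ins for the sum/max reduction functors used by the kernels.
struct SumFunctor {
  void operator()(const float* src, float* dst, std::size_t width) const {
    for (std::size_t i = 0; i < width; ++i) dst[i] += src[i];
  }
};

struct MaxFunctor {
  void operator()(const float* src, float* dst, std::size_t width) const {
    for (std::size_t i = 0; i < width; ++i) dst[i] = std::max(dst[i], src[i]);
  }
};

// Rough analogue of ElementwiseInnerOperation: fold one source row into one
// destination row. first_flag initializes the destination on the first visit,
// which is how the MIN/MAX paths avoid comparing against stale values.
template <typename Functor>
void CombineRow(const std::vector<float>& src_row,
                std::vector<float>* dst_row,
                bool first_flag,
                Functor functor) {
  if (first_flag) {
    *dst_row = src_row;  // first contribution just copies the row
    return;
  }
  functor(src_row.data(), dst_row->data(), dst_row->size());
}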
62 changes: 30 additions & 32 deletions paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc
@@ -23,22 +23,22 @@
namespace phi {

template <typename T, typename IndexT, typename Functor>
-void graph_send_recv_cpu_for_loop_grad(const int& input_size,
-const int& index_size,
-const IndexT* s_index,
-const IndexT* d_index,
-const DenseTensor& src,
-DenseTensor* dst,
-const std::string& pool_type,
-const int* dst_count = nullptr,
-const DenseTensor* input = nullptr,
-const DenseTensor* output = nullptr) {
+void GraphSendRecvCpuGradLoop(const int& input_size,
+const int& index_size,
+const IndexT* s_index,
+const IndexT* d_index,
+const DenseTensor& src,
+DenseTensor* dst,
+const std::string& pool_type,
+const int* dst_count = nullptr,
+const DenseTensor* input = nullptr,
+const DenseTensor* output = nullptr) {
if (pool_type == "SUM") {
Functor functor;
for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i];
-elementwise_inner_operation<T, IndexT, Functor>(
+ElementwiseInnerOperation<T, IndexT, Functor>(
src, dst, src_idx, dst_idx, false, functor);
}
} else if (pool_type == "MEAN") {
@@ -96,33 +96,31 @@ void GraphSendRecvGradOpKernelLaunchHelper(
const IndexT* d_index = dst_index.data<IndexT>();

if (pool_type == "SUM") {
-graph_send_recv_cpu_for_loop_grad<T, IndexT, GraphSendRecvSumFunctor<T>>(
+GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(
src_dims[0], index_size, d_index, s_index, out_grad, x_grad, pool_type);
} else if (pool_type == "MEAN") {
const int* s_count = dst_count->data<int>();
// Functor not used here.
-graph_send_recv_cpu_for_loop_grad<T, IndexT, GraphSendRecvSumFunctor<T>>(
-src_dims[0],
-index_size,
-d_index,
-s_index,
-out_grad,
-x_grad,
-pool_type,
-s_count);
+GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(src_dims[0],
+index_size,
+d_index,
+s_index,
+out_grad,
+x_grad,
+pool_type,
+s_count);
} else if (pool_type == "MIN" || pool_type == "MAX") {
// Functor not used here.
-graph_send_recv_cpu_for_loop_grad<T, IndexT, GraphSendRecvMinFunctor<T>>(
-src_dims[0],
-index_size,
-d_index,
-s_index,
-out_grad,
-x_grad,
-pool_type,
-nullptr,
-x,
-out);
+GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvMinFunctor<T>>(src_dims[0],
+index_size,
+d_index,
+s_index,
+out_grad,
+x_grad,
+pool_type,
+nullptr,
+x,
+out);
}
}

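Note how the launch helper above passes the index arrays swapped (d_index before s_index): gradients are gathered from out_grad at the forward destination row and scattered back to the forward source row, and for MEAN the dst_count array scales each contribution by the number of rows that were averaged into that output. The following is a hedged sketch of that MEAN backward rule, reduced to one scalar feature per row and plain std arrays; MeanGradSketch is an illustrative name, not the Paddle kernel.

#include <cstddef>
#include <vector>

// Simplified MEAN backward rule: the forward output at row dst averaged
// dst_count[dst] incoming rows, so each contributing source row receives
// out_grad[dst] / dst_count[dst].
void MeanGradSketch(const std::vector<float>& out_grad,
                    const std::vector<int>& src_index,
                    const std::vector<int>& dst_index,
                    const std::vector<int>& dst_count,
                    std::vector<float>* x_grad) {
  for (std::size_t i = 0; i < src_index.size(); ++i) {
    const int src = src_index[i];
    const int dst = dst_index[i];
    if (dst_count[dst] > 0) {
      (*x_grad)[src] += out_grad[dst] / static_cast<float>(dst_count[dst]);
    }
  }
}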
47 changes: 23 additions & 24 deletions paddle/phi/kernels/cpu/graph_send_recv_kernel.cc
@@ -26,27 +26,27 @@
namespace phi {

template <typename T, typename IndexT, typename Functor>
-void graph_send_recv_cpu_for_loop(const int& input_size,
-const int& index_size,
-const IndexT* s_index,
-const IndexT* d_index,
-const DenseTensor& src,
-DenseTensor* dst,
-const std::string& pool_type,
-int* dst_count = nullptr) {
+void GraphSendRecvCpuLoop(const int& input_size,
+const int& index_size,
+const IndexT* s_index,
+const IndexT* d_index,
+const DenseTensor& src,
+DenseTensor* dst,
+const std::string& pool_type,
+int* dst_count = nullptr) {
Functor functor;
if (pool_type == "SUM") {
for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i];
-elementwise_inner_operation<T, IndexT, Functor>(
+ElementwiseInnerOperation<T, IndexT, Functor>(
src, dst, src_idx, dst_idx, false, functor);
}
} else if (pool_type == "MEAN") {
for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i];
-elementwise_inner_operation<T, IndexT, Functor>(
+ElementwiseInnerOperation<T, IndexT, Functor>(
src, dst, src_idx, dst_idx, false, functor);
}
for (int i = 0; i < index_size; ++i) {
@@ -66,11 +66,11 @@ void graph_send_recv_cpu_for_loop(const int& input_size,
const IndexT& dst_idx = d_index[i];
bool in_set = existed_dst.find(dst_idx) != existed_dst.end();
if (!in_set) {
-elementwise_inner_operation<T, IndexT, Functor>(
+ElementwiseInnerOperation<T, IndexT, Functor>(
src, dst, src_idx, dst_idx, true, functor);
existed_dst.emplace(dst_idx);
} else {
-elementwise_inner_operation<T, IndexT, Functor>(
+ElementwiseInnerOperation<T, IndexT, Functor>(
src, dst, src_idx, dst_idx, false, functor);
}
}
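The hunk above is the MIN/MAX path: existed_dst records which destination rows have already been written, and the first visit passes first_flag = true so the row is initialized from the source before later visits are compared against it. A small standalone illustration of that pattern for MAX, using flat arrays and illustrative names:

#include <algorithm>
#include <cstddef>
#include <unordered_set>
#include <vector>

// First message to a destination row copies the value; later messages
// take the running maximum.
void MaxPoolSketch(const std::vector<float>& x,
                   const std::vector<int>& src_index,
                   const std::vector<int>& dst_index,
                   std::vector<float>* out) {
  std::unordered_set<int> existed_dst;
  for (std::size_t i = 0; i < src_index.size(); ++i) {
    const int src = src_index[i];
    const int dst = dst_index[i];
    const bool in_set = existed_dst.find(dst) != existed_dst.end();
    if (!in_set) {
      (*out)[dst] = x[src];  // initialize on the first visit
      existed_dst.emplace(dst);
    } else {
      (*out)[dst] = std::max((*out)[dst], x[src]);  // combine afterwards
    }
  }
}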
@@ -100,27 +100,26 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx,
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();
if (pool_type == "SUM") {
-graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvSumFunctor<T>>(
+GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(
src_dims[0], index_size, s_index, d_index, x, out, pool_type);
} else if (pool_type == "MIN") {
-graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvMinFunctor<T>>(
+GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvMinFunctor<T>>(
src_dims[0], index_size, s_index, d_index, x, out, pool_type);
} else if (pool_type == "MAX") {
-graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvMaxFunctor<T>>(
+GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvMaxFunctor<T>>(
src_dims[0], index_size, s_index, d_index, x, out, pool_type);
} else if (pool_type == "MEAN") {
ctx.template Alloc<int>(dst_count);
int* p_dst_count = dst_count->data<int>();
memset(p_dst_count, 0, src_dims[0] * sizeof(int));
-graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvSumFunctor<T>>(
-src_dims[0],
-index_size,
-s_index,
-d_index,
-x,
-out,
-pool_type,
-p_dst_count);
+GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(src_dims[0],
+index_size,
+s_index,
+d_index,
+x,
+out,
+pool_type,
+p_dst_count);
}
}

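As a reference for the MEAN branch above: dst_count is zero-filled and the SUM loop is reused to accumulate, with the count tracking how many source rows land in each output row; each output row is then divided by its count (that normalization is in a part of the loop not visible in these hunks). An end-to-end toy example of that behaviour on a four-node graph, independent of phi types:

#include <cstdio>
#include <vector>

// Toy forward MEAN pass: out[dst] is the average of x[src] over all edges
// (src -> dst). dst_count records how many messages each output row received.
int main() {
  const std::vector<float> x = {1.f, 2.f, 3.f, 4.f};
  const std::vector<int> src_index = {0, 1, 2, 3};
  const std::vector<int> dst_index = {0, 0, 1, 1};

  std::vector<float> out(x.size(), 0.f);
  std::vector<int> dst_count(x.size(), 0);

  for (std::size_t i = 0; i < src_index.size(); ++i) {
    out[dst_index[i]] += x[src_index[i]];  // SUM accumulation
    ++dst_count[dst_index[i]];             // message count per output row
  }
  for (std::size_t r = 0; r < out.size(); ++r) {
    if (dst_count[r] > 0) out[r] /= static_cast<float>(dst_count[r]);
  }

  std::printf("out[0]=%.1f out[1]=%.1f\n", out[0], out[1]);  // 1.5 and 3.5
  return 0;
}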
