[Phi] move ops: maxout/take_along_axis/put_along_axis (#39959)
* [Phi] move put_along_axis/take_along_axis/maxout

* use phi::Copy
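
The second bullet refers to replacing legacy copy helpers with phi's generic tensor copy. Purely as a hedged sketch — the header path and exact signature are assumptions from the phi tree of this period, and some revisions of phi::Copy also take an explicit destination Place argument:

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/copy_kernel.h"  // assumed location of phi::Copy

// Sketch: deep-copy `src` into `dst` on the context's device. `blocking`
// selects a synchronous copy vs. a stream-ordered one on GPU-like backends.
template <typename Context>
void CopyOut(const Context& dev_ctx,
             const phi::DenseTensor& src,
             phi::DenseTensor* dst) {
  phi::Copy(dev_ctx, src, /*blocking=*/false, dst);
}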
m3ngyang committed Mar 8, 2022
1 parent 00566ea commit 48b4366
Showing 36 changed files with 1,191 additions and 710 deletions.
151 changes: 76 additions & 75 deletions paddle/fluid/operators/math/maxouting.cc
@@ -14,106 +14,107 @@ limitations under the License. */

#include "paddle/fluid/operators/math/maxouting.h"

#include "paddle/phi/backends/cpu/cpu_context.h"

namespace paddle {
namespace operators {
namespace math {

The explicit specialization template <typename T> class MaxOutFunctor<platform::CPUDeviceContext, T> is replaced by a device-generic out-of-line definition with the same body:

// All tensors are in NCHW or NHWC format, and the groups must be greater than 1
template <typename DeviceContext, typename T>
void MaxOutFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
                                                 const framework::Tensor& input,
                                                 framework::Tensor* output,
                                                 const int groups,
                                                 const int axis) {
  const int batch_size = input.dims()[0];
  const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
  const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
  const int output_channels = output->dims()[axis];
  int fea_size = input_height * input_width;
  // c_size means the output size of each sample
  int c_size = fea_size * output_channels;
  const T* input_data = input.data<T>();
  T* output_data = output->mutable_data<T>(context.GetPlace());
  for (int i = 0; i < batch_size; ++i) {
    int new_bindex = c_size * i;
    for (int c = 0; c < output_channels; ++c) {
      int new_cindex = fea_size * c;
      for (int f = 0; f < fea_size; ++f) {
        T ele = static_cast<T>(-FLT_MAX);
        int input_idx, output_idx;
        for (int ph = 0; ph < groups; ++ph) {
          if (axis == 1) {
            input_idx = (new_bindex + new_cindex) * groups + ph * fea_size + f;
          } else {
            input_idx = (new_bindex + f * output_channels + c) * groups + ph;
          }
          T x = input_data[input_idx];
          ele = ele > x ? ele : x;
        }
        if (axis == 1) {
          output_idx = new_bindex + new_cindex + f;
        } else {
          output_idx = new_bindex + f * output_channels + c;
        }
        output_data[output_idx] = ele;
      }
    }
  }
}
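
For orientation: maxout with group count g maps an NCHW input of shape (N, C*g, H, W) (the axis == 1 case) to (N, C, H, W) by taking an element-wise max over each run of g consecutive channels. A self-contained sketch with hypothetical sizes, using std::vector in place of framework::Tensor, that reproduces the functor's indexing:

#include <algorithm>
#include <cfloat>
#include <cstdio>
#include <vector>

// Maxout, NCHW: input (N, C*groups, H, W) -> output (N, C, H, W). Output
// channel c is the element-wise max over input channels [c*groups, (c+1)*groups).
int main() {
  const int N = 1, C = 2, groups = 3, H = 2, W = 2;  // hypothetical sizes
  const int fea = H * W;
  std::vector<float> in(N * C * groups * fea);
  for (size_t i = 0; i < in.size(); ++i) in[i] = static_cast<float>(i % 7);
  std::vector<float> out(N * C * fea);
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c)
      for (int f = 0; f < fea; ++f) {
        float best = -FLT_MAX;
        for (int g = 0; g < groups; ++g) {
          // Same flat index as the functor's axis == 1 branch:
          // (new_bindex + new_cindex) * groups + g * fea + f.
          int idx = ((n * C + c) * groups + g) * fea + f;
          best = std::max(best, in[idx]);
        }
        out[(n * C + c) * fea + f] = best;
      }
  std::printf("out[0] = %f\n", out[0]);
  return 0;
}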

MaxOutGradFunctor likewise drops its CPUDeviceContext specialization in favor of a device-generic definition:

template <typename DeviceContext, typename T>
void MaxOutGradFunctor<DeviceContext, T>::operator()(
    const DeviceContext& context, const framework::Tensor& input,
    framework::Tensor* input_grad, const framework::Tensor& output,
    const framework::Tensor& output_grad, const int groups, const int axis) {
  const int batch_size = input.dims()[0];
  const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
  const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
  const int output_channels = output.dims()[axis];
  int fea_size = input_height * input_width;
  const T* input_data = input.data<T>();
  const T* output_data = output.data<T>();
  const T* output_grad_data = output_grad.data<T>();
  T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

  for (int i = 0; i < batch_size; ++i) {
    int blen = fea_size * output_channels * i;
    for (int c = 0; c < output_channels; ++c) {
      int clen = fea_size * c;
      for (int f = 0; f < fea_size; ++f) {
        int input_idx0, output_idx;
        bool continue_match = true;
        if (axis == 1) {
          input_idx0 = (blen + clen) * groups + f;
          output_idx = blen + clen + f;
        } else {
          input_idx0 = (blen + f * output_channels + c) * groups;
          output_idx = blen + f * output_channels + c;
        }
        for (int g = 0; g < groups && continue_match; ++g) {
          int idx_offset = (axis == 1 ? fea_size * g : g);
          int input_idx = input_idx0 + idx_offset;
          if (input_data[input_idx] == output_data[output_idx]) {
            input_grad_data[input_idx] += output_grad_data[output_idx];
            continue_match = false;
          }
        }
      }
    }
  }
}

template class MaxOutGradFunctor<platform::CPUDeviceContext, float>;
template class MaxOutGradFunctor<platform::CPUDeviceContext, double>;
template class MaxOutFunctor<platform::CPUDeviceContext, float>;
template class MaxOutFunctor<platform::CPUDeviceContext, double>;

template class MaxOutGradFunctor<phi::CPUContext, float>;
template class MaxOutGradFunctor<phi::CPUContext, double>;
template class MaxOutFunctor<phi::CPUContext, float>;
template class MaxOutFunctor<phi::CPUContext, double>;

} // namespace math
} // namespace operators
} // namespace paddle
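
Since this commit moves the functor bodies out of per-device specializations into generic templates defined in the .cc/.cu files, the explicit instantiations above are what give the linker concrete symbols for each (context, type) pair. A minimal single-file sketch of the pattern (names are hypothetical):

// Sketch of the declaration / out-of-line definition / explicit
// instantiation pattern used by maxouting.h and maxouting.cc.
struct CpuCtx {};  // stand-in for phi::CPUContext

template <typename Context, typename T>
struct Functor {
  void operator()(const Context& ctx, T x);  // declared, not defined, here
};

// In the .cc file: one generic definition shared by all contexts...
template <typename Context, typename T>
void Functor<Context, T>::operator()(const Context&, T) { /* kernel body */ }

// ...plus explicit instantiations so callers can link. A (context, type)
// pair that is never listed here fails at link time as an undefined symbol.
template struct Functor<CpuCtx, float>;
template struct Functor<CpuCtx, double>;

int main() {
  Functor<CpuCtx, float>()(CpuCtx{}, 1.0f);
  return 0;
}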
107 changes: 55 additions & 52 deletions paddle/fluid/operators/math/maxouting.cu
@@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/fluid/operators/math/maxouting.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"

namespace paddle {
namespace operators {
@@ -95,68 +96,70 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
As on the CPU side, the CUDADeviceContext specialization of MaxOutFunctor becomes a device-generic definition; the kernel launch itself is unchanged:

/*
 * All tensors are in NCHW or NHWC format.
 */
template <typename DeviceContext, typename T>
void MaxOutFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
                                                 const framework::Tensor& input,
                                                 framework::Tensor* output,
                                                 const int groups,
                                                 const int axis) {
  const int batch_size = input.dims()[0];
  const int input_channels = input.dims()[axis];
  const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
  const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
  const int output_channels = output->dims()[axis];

  const T* input_data = input.data<T>();
  T* output_data = output->mutable_data<T>(context.GetPlace());
  int nthreads = output->numel();
  int blocks = (nthreads + 1024 - 1) / 1024;
  dim3 threads(1024, 1);
  dim3 grid(blocks, 1);

  KernelMaxOut<T><<<grid, threads, 0, context.stream()>>>(
      nthreads, input_data, input_channels, input_height, input_width, groups,
      axis, output_data);
}
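
The launch configuration is the usual ceiling division: the smallest block count b with b * 1024 >= nthreads, so every output element gets a thread (the kernel presumably guards indices against nthreads; its body is collapsed out of this hunk). A quick host-side check of the arithmetic:

#include <cassert>

// Grid size used above: ceil(nthreads / 1024) via integer arithmetic.
int Blocks(int nthreads) { return (nthreads + 1024 - 1) / 1024; }

int main() {
  assert(Blocks(1) == 1);
  assert(Blocks(1024) == 1);
  assert(Blocks(1025) == 2);
  assert(Blocks(2050) == 3);  // 3 * 1024 = 3072 >= 2050
  return 0;
}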

MaxOutGradFunctor on GPU gets the same treatment:

/*
 * All tensors are in NCHW or NHWC format.
 */
template <typename DeviceContext, typename T>
void MaxOutGradFunctor<DeviceContext, T>::operator()(
    const DeviceContext& context, const framework::Tensor& input,
    framework::Tensor* input_grad, const framework::Tensor& output,
    const framework::Tensor& output_grad, const int groups, const int axis) {
  const int batch_size = input.dims()[0];
  const int input_channels = input.dims()[axis];
  const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
  const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
  const int output_channels = output.dims()[axis];

  const T* input_data = input.data<T>();
  const T* output_data = output.data<T>();
  const T* output_grad_data = output_grad.data<T>();
  T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
  int nthreads = output.numel();
  int blocks = (nthreads + 1024 - 1) / 1024;
  dim3 threads(1024, 1);
  dim3 grid(blocks, 1);

  KernelMaxoutGrad<T><<<grid, threads, 0, context.stream()>>>(
      nthreads, input_data, output_data, output_grad_data, input_grad_data,
      input_channels, input_height, input_width, groups, axis);
}

template class MaxOutGradFunctor<platform::CUDADeviceContext, float>;
template class MaxOutGradFunctor<platform::CUDADeviceContext, double>;

template class MaxOutFunctor<platform::CUDADeviceContext, float>;
template class MaxOutFunctor<platform::CUDADeviceContext, double>;

template class MaxOutGradFunctor<phi::GPUContext, float>;
template class MaxOutGradFunctor<phi::GPUContext, double>;

template class MaxOutFunctor<phi::GPUContext, float>;
template class MaxOutFunctor<phi::GPUContext, double>;

} // namespace math
} // namespace operators
} // namespace paddle
2 changes: 1 addition & 1 deletion paddle/fluid/operators/math/maxouting.h
@@ -30,7 +30,7 @@ class MaxOutFunctor {
const int axis = 1);
};

-template <typename DeviceContext, class T>
+template <typename DeviceContext, typename T>
class MaxOutGradFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
13 changes: 3 additions & 10 deletions paddle/fluid/operators/maxout_op.cc
@@ -12,14 +12,14 @@
* See the License for the specific language governing permissions and
* limitations under the License. */

#include "paddle/fluid/operators/maxout_op.h"
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {

using framework::Tensor;

class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
@@ -130,10 +130,3 @@ REGISTER_OPERATOR(
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
REGISTER_OPERATOR(maxout_grad, ops::MaxOutOpGrad);
-REGISTER_OP_CPU_KERNEL(
-    maxout, ops::MaxOutKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::MaxOutKernel<paddle::platform::CPUDeviceContext, double>);
-REGISTER_OP_CPU_KERNEL(
-    maxout_grad,
-    ops::MaxOutGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::MaxOutGradKernel<paddle::platform::CPUDeviceContext, double>);
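
These fluid-side CPU registrations are deleted because the kernels now register through phi elsewhere in this PR. Purely as a hedged illustration — the file, kernel names, and macro arguments below are assumptions from phi's conventions of this period, not shown in this hunk — a phi CPU registration looks roughly like:

// Hypothetical phi-side replacement for the removed REGISTER_OP_CPU_KERNEL
// calls (e.g. in paddle/phi/kernels/cpu/maxout_kernel.cc):
PD_REGISTER_KERNEL(maxout, CPU, ALL_LAYOUT, phi::MaxOutKernel, float, double) {}
PD_REGISTER_KERNEL(
    maxout_grad, CPU, ALL_LAYOUT, phi::MaxOutGradKernel, float, double) {}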
24 changes: 0 additions & 24 deletions paddle/fluid/operators/maxout_op.cu.cc

This file was deleted.
