Skip to content

Commit

Permalink
add test_cumprod_op
Browse files Browse the repository at this point in the history
  • Loading branch information
liyagit21 committed Aug 27, 2021
1 parent bb01b12 commit c96cf6d
Show file tree
Hide file tree
Showing 10 changed files with 1,040 additions and 30 deletions.
5 changes: 4 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,7 @@ repos:
entry: python ./tools/codestyle/copyright.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py|sh)$
exclude: (?!.*third_party)^.*$ | (?!.*book)^.*$
exclude: |
(?x)^(
paddle/utils/.*
)$
6 changes: 6 additions & 0 deletions paddle/fluid/framework/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ struct complex;
namespace paddle {
namespace framework {

template <typename T>
struct IsComplex : public std::false_type {};

template <typename T>
struct IsComplex<platform::complex<T>> : public std::true_type {};

template <typename T>
struct DataTypeTrait {};

Expand Down
96 changes: 96 additions & 0 deletions paddle/fluid/operators/cumprod_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/cumprod_op.h"

namespace paddle {
namespace operators {

class CumprodOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "cumprod");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "cumprod");

ctx->ShareDim("X", "Out");
ctx->ShareLoD("X", "Out");
}
};

class CumprodOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of cumprod op.");
AddOutput("Out", "(Tensor), The output tensor of cumprod op.");
AddAttr<int>(
"dim",
"(int), The dim along which the input tensors will be cumproded");
AddComment(
R"DOC(Cumprod operator. Return the cumprod results of the input elements along the dim.
For example, if input X is a vector of size N, the output will also be a vector of size N,
with elements yi = x1 * x2 * x3 *...* xi (0<=i<=N))DOC");
}
};

template <typename T>
class CumprodGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("cumprod_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput("Out", this->Output("Out"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};

class CumprodGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
ctx->ShareDim(framework::GradVarName("Out"), framework::GradVarName("X"));
ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X"));
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(cumprod, ops::CumprodOp, ops::CumprodOpMaker,
ops::CumprodGradOpMaker<paddle::framework::OpDesc>,
ops::CumprodGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(cumprod_grad, ops::CumprodGradOp);

REGISTER_OP_CPU_KERNEL(
cumprod, ops::CumprodOpCPUKernel<float>, ops::CumprodOpCPUKernel<double>,
ops::CumprodOpCPUKernel<int>, ops::CumprodOpCPUKernel<int64_t>,
ops::CumprodOpCPUKernel<paddle::platform::complex<float>>,
ops::CumprodOpCPUKernel<paddle::platform::complex<double>>);

REGISTER_OP_CPU_KERNEL(
cumprod_grad, ops::CumprodGradOpCPUKernel<float>,
ops::CumprodGradOpCPUKernel<double>, ops::CumprodGradOpCPUKernel<int>,
ops::CumprodGradOpCPUKernel<int64_t>,
ops::CumprodGradOpCPUKernel<paddle::platform::complex<float>>,
ops::CumprodGradOpCPUKernel<paddle::platform::complex<double>>);
Loading

0 comments on commit c96cf6d

Please sign in to comment.