Skip to content

Commit

Permalink
Set more attrs in ReplaceScaleLossGradOp (#44576)
Browse files Browse the repository at this point in the history
* Set more attrs in ReplaceScaleLossGradOp

* Fix typos

* Fix CI errors

* Add UT
  • Loading branch information
From00 committed Jul 26, 2022
1 parent 6198ff2 commit ab198b4
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 3 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/framework/details/scale_loss_grad_op_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ struct ScaleLossGradOpHandle : public OpHandleBase {

~ScaleLossGradOpHandle() final;

proto::VarType::Type DType() const { return out_dtype_; }

std::string Name() const override;

platform::Place GetPlace() const { return place_; }
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ cc_library(
cc_library(
graph_helper
SRCS graph_helper.cc
DEPS graph)
DEPS graph scale_loss_grad_op_handle)
cc_library(
pass
SRCS pass.cc
Expand Down
18 changes: 16 additions & 2 deletions paddle/fluid/framework/ir/graph_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License. */
#include <stack>

#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/op_proto_maker.h"

DECLARE_bool(convert_all_blocks);
Expand Down Expand Up @@ -469,11 +470,23 @@ void RemoveControlDepInputAndOuput(OpDesc *op_desc) {

static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
desc->SetType("fill_constant");
desc->SetAttr("shape", std::vector<int64_t>({1}));
desc->SetAttr("value", 1.0f);

if (node.IsWrappedBy<details::OpHandleBase>()) {
details::OpHandleBase &op_hander =
const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
desc->SetAttr(
"dtype",
dynamic_cast<details::ScaleLossGradOpHandle *>(&op_hander)->DType());
}

desc->SetAttr("force_cpu", false);
desc->SetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName(),
(static_cast<int>(OpRole::kBackward) | static_cast<int>(OpRole::kLoss)));
desc->SetAttr("value", 1.0f);
desc->SetAttr("shape", std::vector<int64_t>({1}));
// TODO(Ruibiao) : Set OpDeviceAttrName when needed

std::vector<std::string> output_names;
for (auto out : node.outputs) {
output_names.emplace_back(out->Name());
Expand Down Expand Up @@ -503,6 +516,7 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes,

// create fill_constant op
if (n->Name() == "scale_loss_grad") {
VLOG(4) << "convert op node scale_loss_grad to desc fill_constant";
ops->emplace_back();
auto &desc = ops->back();
ReplaceScaleLossGradOp(*n, &desc);
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/fluid/tests/custom_op/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ if(WITH_GPU OR APPLE)

# Compiling shared library will cost some time, but running process is very fast.
set_tests_properties(test_custom_relu_op_setup PROPERTIES TIMEOUT 250)
set_tests_properties(test_custom_relu_op_setup
PROPERTIES ENVIRONMENT FLAGS_CONVERT_GRAPH_TO_PROGRAM=1)
set_tests_properties(test_custom_relu_op_jit PROPERTIES TIMEOUT 180)
set_tests_properties(test_custom_relu_model PROPERTIES TIMEOUT 180)
set_tests_properties(test_context_pool PROPERTIES TIMEOUT 180)
Expand Down

0 comments on commit ab198b4

Please sign in to comment.