Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set more attrs in ReplaceScaleLossGradOp #44576

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/fluid/framework/details/scale_loss_grad_op_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ struct ScaleLossGradOpHandle : public OpHandleBase {

~ScaleLossGradOpHandle() final;

proto::VarType::Type DType() const { return out_dtype_; }

std::string Name() const override;

platform::Place GetPlace() const { return place_; }
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ cc_library(
cc_library(
graph_helper
SRCS graph_helper.cc
DEPS graph)
DEPS graph scale_loss_grad_op_handle)
cc_library(
pass
SRCS pass.cc
Expand Down
18 changes: 16 additions & 2 deletions paddle/fluid/framework/ir/graph_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License. */
#include <stack>

#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/op_proto_maker.h"

DECLARE_bool(convert_all_blocks);
Expand Down Expand Up @@ -448,11 +449,23 @@ std::vector<ir::Node *> TopologySortGraphByDescOrder(const Graph &graph) {

static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
desc->SetType("fill_constant");
desc->SetAttr("shape", std::vector<int64_t>({1}));
desc->SetAttr("value", 1.0f);

if (node.IsWrappedBy<details::OpHandleBase>()) {
details::OpHandleBase &op_hander =
const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
desc->SetAttr(
"dtype",
dynamic_cast<details::ScaleLossGradOpHandle *>(&op_hander)->DType());
}

desc->SetAttr("force_cpu", false);
desc->SetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName(),
(static_cast<int>(OpRole::kBackward) | static_cast<int>(OpRole::kLoss)));
desc->SetAttr("value", 1.0f);
desc->SetAttr("shape", std::vector<int64_t>({1}));
// TODO(Ruibiao) : Set OpDeviceAttrName when needed

std::vector<std::string> output_names;
for (auto out : node.outputs) {
output_names.emplace_back(out->Name());
Expand Down Expand Up @@ -482,6 +495,7 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes,

// create fill_constant op
if (n->Name() == "scale_loss_grad") {
VLOG(4) << "convert op node scale_loss_grad to desc fill_constant";
ops->emplace_back();
auto &desc = ops->back();
ReplaceScaleLossGradOp(*n, &desc);
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/fluid/tests/custom_op/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ if(WITH_GPU OR APPLE)

# Compiling shared library will cost some time, but running process is very fast.
set_tests_properties(test_custom_relu_op_setup PROPERTIES TIMEOUT 250)
set_tests_properties(test_custom_relu_op_setup
PROPERTIES ENVIRONMENT FLAGS_CONVERT_GRAPH_TO_PROGRAM=1)
set_tests_properties(test_custom_relu_op_jit PROPERTIES TIMEOUT 180)
set_tests_properties(test_custom_relu_model PROPERTIES TIMEOUT 180)
set_tests_properties(test_context_pool PROPERTIES TIMEOUT 180)
Expand Down