Skip to content

Commit

Permalink
simplify ir::Node,ir::Graph in popart_canonicalization_pass (PaddlePa…
Browse files Browse the repository at this point in the history
…ddle#105)

* simplify ir::Node,ir::graph
  • Loading branch information
gglin001 committed Aug 30, 2021
1 parent 8ceadf0 commit aea7440
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,38 +21,37 @@ namespace framework {
namespace ipu {
namespace {

ir::Node *activation_op_handler(ir::Graph *graph, ir::Node *node,
const std::string &type) {
Node *activation_op_handler(Graph *graph, Node *node, const std::string &type) {
auto new_node =
CreateBaseOp(graph, node, type, {GetInputNode("X", node)}, node->outputs);
return new_node;
}

ir::Node *relu_handler(ir::Graph *graph, ir::Node *node) {
Node *relu_handler(Graph *graph, Node *node) {
return activation_op_handler(graph, node, "popart_relu");
}

ir::Node *tanh_handler(ir::Graph *graph, ir::Node *node) {
Node *tanh_handler(Graph *graph, Node *node) {
return activation_op_handler(graph, node, "popart_tanh");
}

ir::Node *log_handler(ir::Graph *graph, ir::Node *node) {
Node *log_handler(Graph *graph, Node *node) {
return activation_op_handler(graph, node, "popart_log");
}

ir::Node *sigmoid_handler(ir::Graph *graph, ir::Node *node) {
Node *sigmoid_handler(Graph *graph, Node *node) {
return activation_op_handler(graph, node, "popart_sigmoid");
}

ir::Node *sqrt_handler(ir::Graph *graph, ir::Node *node) {
Node *sqrt_handler(Graph *graph, Node *node) {
return activation_op_handler(graph, node, "popart_sqrt");
}

ir::Node *gelu_handler(ir::Graph *graph, ir::Node *node) {
Node *gelu_handler(Graph *graph, Node *node) {
return activation_op_handler(graph, node, "popart_gelu");
}

ir::Node *log_softmax_handler(ir::Graph *graph, ir::Node *node) {
Node *log_softmax_handler(Graph *graph, Node *node) {
auto axis_ = BOOST_GET_CONST(int, node->Op()->GetAttr("axis"));
return CreateBaseOp(graph, node, "popart_logsoftmax", node->inputs,
node->outputs, {{"axis", int64_t{axis_}}});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ SymbolHandler GetHandler(const std::string &kind) {
return {};
}

void ConnectNodes(ir::Node *first_node, ir::Node *next_node) {
void ConnectNodes(Node *first_node, Node *next_node) {
first_node->outputs.push_back(next_node);
next_node->inputs.push_back(first_node);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,22 @@ namespace paddle {
namespace framework {
namespace ipu {

// TODO(alleng) remove ir::
using ir::Graph;
using ir::Node;

#define REGISTER_HANDLER(name, func) \
static bool __UNUSED_##name = \
paddle::framework::ipu::RegisterHandler(#name, func)

using SymbolHandler = std::function<ir::Node *(ir::Graph *, ir::Node *)>;
using SymbolHandler = std::function<Node *(Graph *, Node *)>;

std::unordered_map<std::string, SymbolHandler> &SymbolHandlers();

bool RegisterHandler(const std::string &, const SymbolHandler &);

SymbolHandler GetHandler(const std::string &);

void ConnectNodes(ir::Node *first_node, ir::Node *next_node);
void ConnectNodes(Node *first_node, Node *next_node);
void DisConnectNodes(Node *first_node, Node *next_node);
void ClearNode(Node *node);
void CopyOpAttr(const std::string &attr_name, OpDesc *op, OpDesc *new_op,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ namespace framework {
namespace ipu {
namespace {

ir::Node *elementwise_op_handler(ir::Graph *graph, ir::Node *node,
const std::string &type) {
Node *elementwise_op_handler(Graph *graph, Node *node,
const std::string &type) {
auto *op = node->Op();
auto x_shape = op->Block()->FindVar(op->Input("X").front())->GetShape();
int64_t x_rank = x_shape.size();
Expand Down Expand Up @@ -59,35 +59,35 @@ ir::Node *elementwise_op_handler(ir::Graph *graph, ir::Node *node,
}
}

ir::Node *elementwise_add_handler(ir::Graph *graph, ir::Node *node) {
Node *elementwise_add_handler(Graph *graph, Node *node) {
return elementwise_op_handler(graph, node, "popart_add");
}

ir::Node *elementwise_sub_handler(ir::Graph *graph, ir::Node *node) {
Node *elementwise_sub_handler(Graph *graph, Node *node) {
return elementwise_op_handler(graph, node, "popart_sub");
}

ir::Node *elementwise_div_handler(ir::Graph *graph, ir::Node *node) {
Node *elementwise_div_handler(Graph *graph, Node *node) {
return elementwise_op_handler(graph, node, "popart_div");
}

ir::Node *elementwise_mul_handler(ir::Graph *graph, ir::Node *node) {
Node *elementwise_mul_handler(Graph *graph, Node *node) {
return elementwise_op_handler(graph, node, "popart_mul");
}

ir::Node *elementwise_min_handler(ir::Graph *graph, ir::Node *node) {
Node *elementwise_min_handler(Graph *graph, Node *node) {
return elementwise_op_handler(graph, node, "popart_min");
}

ir::Node *elementwise_max_handler(ir::Graph *graph, ir::Node *node) {
Node *elementwise_max_handler(Graph *graph, Node *node) {
return elementwise_op_handler(graph, node, "popart_max");
}

ir::Node *elementwise_pow_handler(ir::Graph *graph, ir::Node *node) {
Node *elementwise_pow_handler(Graph *graph, Node *node) {
return elementwise_op_handler(graph, node, "popart_pow");
}

ir::Node *elementwise_mod_handler(ir::Graph *graph, ir::Node *node) {
Node *elementwise_mod_handler(Graph *graph, Node *node) {
return elementwise_op_handler(graph, node, "popart_mod");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace framework {
namespace ipu {
namespace {

ir::Node *equal_handler(ir::Graph *graph, ir::Node *node) {
Node *equal_handler(Graph *graph, Node *node) {
auto new_node = CreateBaseOp(
graph, node, "popart_equal",
{GetInputNode("X", node), GetInputNode("Y", node)}, node->outputs);
Expand Down
20 changes: 10 additions & 10 deletions paddle/fluid/framework/ipu/popart_canonicalization/math_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace framework {
namespace ipu {
namespace {

ir::Node *reduce_mean_handler(ir::Graph *graph, ir::Node *node) {
Node *reduce_mean_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto attrs = AttributeMap{};
auto reduce_all = BOOST_GET_CONST(bool, op->GetAttr("reduce_all"));
Expand All @@ -37,15 +37,15 @@ ir::Node *reduce_mean_handler(ir::Graph *graph, ir::Node *node) {
node->outputs, attrs);
}

ir::Node *mean_handler(ir::Graph *graph, ir::Node *node) {
Node *mean_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_reducemean",
{GetInputNode("X", node)}, {GetOutputNode("Out", node)},
{
{"keepdims", int64_t{0}},
});
}

ir::Node *pow_handler(ir::Graph *graph, ir::Node *node) {
Node *pow_handler(Graph *graph, Node *node) {
// Op(pow) -> Op(Constant)->Var(const_out)->Op(Pow)
auto *op = node->Op();
auto value_ = BOOST_GET_CONST(float, op->GetAttr("factor"));
Expand All @@ -57,7 +57,7 @@ ir::Node *pow_handler(ir::Graph *graph, ir::Node *node) {
node->outputs);
}

ir::Node *mul_handler(ir::Graph *graph, ir::Node *node) {
Node *mul_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto x_num_col_dims = BOOST_GET_CONST(int, op->GetAttr("x_num_col_dims"));
auto y_num_col_dims = BOOST_GET_CONST(int, op->GetAttr("y_num_col_dims"));
Expand Down Expand Up @@ -92,7 +92,7 @@ ir::Node *mul_handler(ir::Graph *graph, ir::Node *node) {
node->outputs, {});
}

ir::Node *matmul_handler(ir::Graph *graph, ir::Node *node) {
Node *matmul_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto transpose_x = BOOST_GET_CONST(bool, op->GetAttr("transpose_X"));
auto transpose_y = BOOST_GET_CONST(bool, op->GetAttr("transpose_Y"));
Expand Down Expand Up @@ -141,11 +141,11 @@ ir::Node *matmul_handler(ir::Graph *graph, ir::Node *node) {
}
}

ir::Node *sum_handler(ir::Graph *graph, ir::Node *node) {
Node *sum_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_sum", node->inputs, node->outputs);
}

ir::Node *softmax_handler(ir::Graph *graph, ir::Node *node) {
Node *softmax_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto axis_ = BOOST_GET_CONST(int, op->GetAttr("axis"));
auto axis = int64_t{axis_};
Expand All @@ -155,7 +155,7 @@ ir::Node *softmax_handler(ir::Graph *graph, ir::Node *node) {
});
}

ir::Node *scale_handler(ir::Graph *graph, ir::Node *node) {
Node *scale_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto scale_ = BOOST_GET_CONST(float, op->GetAttr("scale"));
auto bias_ = BOOST_GET_CONST(float, op->GetAttr("bias"));
Expand All @@ -182,7 +182,7 @@ ir::Node *scale_handler(ir::Graph *graph, ir::Node *node) {
auto new_node_cast = CreateCast(graph, node, {GetInputNode("X", node)}, {},
static_cast<int>(proto::VarType::FP32));

ir::Node *result = nullptr;
Node *result = nullptr;
if (bias_after_scale_) {
auto new_node_mul = CreateBaseOp(
graph, node, "popart_mul",
Expand All @@ -205,7 +205,7 @@ ir::Node *scale_handler(ir::Graph *graph, ir::Node *node) {
}
}

ir::Node *cross_entropy2_handler(ir::Graph *graph, ir::Node *node) {
Node *cross_entropy2_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto ignoreIndex = BOOST_GET_CONST(int, op->GetAttr("ignore_index"));
auto new_cast = CreateCast(graph, node, {GetInputNode("Label", node)}, {},
Expand Down
36 changes: 18 additions & 18 deletions paddle/fluid/framework/ipu/popart_canonicalization/nn_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace framework {
namespace ipu {
namespace {

ir::Node *conv2d_handler(ir::Graph *graph, ir::Node *node) {
Node *conv2d_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto dilations_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("dilations"));
auto dilations = std::vector<int64_t>{dilations_.begin(), dilations_.end()};
Expand Down Expand Up @@ -52,15 +52,15 @@ ir::Node *conv2d_handler(ir::Graph *graph, ir::Node *node) {
}
}

ir::Node *batch_norm_handler(ir::Graph *graph, ir::Node *node) {
Node *batch_norm_handler(Graph *graph, Node *node) {
auto *op = node->Op();
std::vector<ir::Node *> inputs;
std::vector<Node *> inputs;
inputs.push_back(GetInputNode("X", node));
inputs.push_back(GetInputNode("Scale", node));
inputs.push_back(GetInputNode("Bias", node));
inputs.push_back(GetInputNode("Mean", node));
inputs.push_back(GetInputNode("Variance", node));
std::vector<ir::Node *> outputs;
std::vector<Node *> outputs;
outputs.push_back(GetOutputNode("Y", node));
outputs.push_back(GetOutputNode("MeanOut", node));
outputs.push_back(GetOutputNode("VarianceOut", node));
Expand All @@ -79,7 +79,7 @@ ir::Node *batch_norm_handler(ir::Graph *graph, ir::Node *node) {
});
}

ir::Node *pool2d_handler(ir::Graph *graph, ir::Node *node) {
Node *pool2d_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto global_pooling = BOOST_GET_CONST(bool, op->GetAttr("global_pooling"));
if (global_pooling) {
Expand Down Expand Up @@ -130,37 +130,37 @@ ir::Node *pool2d_handler(ir::Graph *graph, ir::Node *node) {
}
}

ir::Node *group_norm_handler(ir::Graph *graph, ir::Node *node) {
Node *group_norm_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto epsilon_ = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
auto groups_ = BOOST_GET_CONST(int, op->GetAttr("groups"));
auto groups = int64_t{groups_};
auto attrs_ = AttributeMap{{"epsilon", epsilon_}, {"num_groups", groups}};

std::vector<ir::Node *> inputs_ = {GetInputNode("X", node),
GetInputNode("Scale", node),
GetInputNode("Bias", node)};
std::vector<ir::Node *> outputs_ = {GetOutputNode("Y", node),
GetOutputNode("Mean", node),
GetOutputNode("Variance", node)};
std::vector<Node *> inputs_ = {GetInputNode("X", node),
GetInputNode("Scale", node),
GetInputNode("Bias", node)};
std::vector<Node *> outputs_ = {GetOutputNode("Y", node),
GetOutputNode("Mean", node),
GetOutputNode("Variance", node)};
return CreateBaseOp(graph, node, "popart_groupnormalization", inputs_,
outputs_, attrs_);
}

ir::Node *instance_norm_handler(ir::Graph *graph, ir::Node *node) {
Node *instance_norm_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto epsilon_ = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
auto attrs_ = AttributeMap{{"epsilon", epsilon_}};

std::vector<ir::Node *> inputs_ = {GetInputNode("X", node),
GetInputNode("Scale", node),
GetInputNode("Bias", node)};
std::vector<ir::Node *> outputs_ = {GetOutputNode("Y", node)};
std::vector<Node *> inputs_ = {GetInputNode("X", node),
GetInputNode("Scale", node),
GetInputNode("Bias", node)};
std::vector<Node *> outputs_ = {GetOutputNode("Y", node)};
return CreateBaseOp(graph, node, "popart_instancenormalization", inputs_,
outputs_, attrs_);
}

ir::Node *layer_norm_handler(ir::Graph *graph, ir::Node *node) {
Node *layer_norm_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto begin_norm_axis_ = BOOST_GET_CONST(int, op->GetAttr("begin_norm_axis"));
auto input_shape_ = op->Block()->FindVar(op->Input("X")[0])->GetShape();
Expand Down
Loading

0 comments on commit aea7440

Please sign in to comment.