Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add piano graph executor #36

Merged
merged 17 commits into from
Sep 2, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions paddle/fluid/compiler/paddle2piano/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ cc_test(piano_op_registry_test SRCS piano_op_registry_test.cc DEPS piano_op_regi

cc_library(piano_op_kernel_context SRCS piano_op_kernel_context.cc DEPS piano_op_registry proto_desc piano_symbolization_builder)
cc_test(piano_op_kernel_context_test SRCS piano_op_kernel_context_test.cc DEPS piano_op_kernel_context op_registry)

cc_library(piano_graph_executor SRCS piano_graph_executor.cc DEPS piano_op_kernel_context piano_symbolization_meat_op)
cc_test(piano_graph_executor_test SRCS piano_graph_executor_test.cc DEPS piano_graph_executor node)
201 changes: 201 additions & 0 deletions paddle/fluid/compiler/paddle2piano/piano_graph_executor.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/compiler/paddle2piano/piano_graph_executor.h"

#include <queue>
#include <unordered_map>
#include <unordered_set>

#include "paddle/fluid/compiler/paddle2piano/piano_op_kernel_context.h"
#include "paddle/fluid/compiler/paddle2piano/vartype2notetype.h"
#include "paddle/fluid/compiler/piano/symbolization/meta_op.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace piano {

using framework::ir::Node;
using framework::proto::VarType;

namespace {
const std::unordered_map<framework::proto::VarType::Type,
note::ElementTypeProto>&
GetVarType2NoteTypeMap() {
static std::unordered_map<framework::proto::VarType::Type,
note::ElementTypeProto>
vartype2notetype = {{framework::proto::VarType::BOOL, note::B1},
{framework::proto::VarType::INT8, note::S8},
{framework::proto::VarType::INT16, note::S16},
{framework::proto::VarType::INT32, note::S32},
{framework::proto::VarType::INT64, note::S64},
{framework::proto::VarType::FP16, note::F16},
{framework::proto::VarType::FP32, note::F32},
{framework::proto::VarType::FP64, note::F64},
{framework::proto::VarType::UINT8, note::U8},
{framework::proto::VarType::SIZE_T, note::U64}};
return vartype2notetype;
}
} // namespace

note::ElementTypeProto VarType2NoteType(framework::proto::VarType::Type type) {
const auto& vartype2notetype = GetVarType2NoteTypeMap();
PADDLE_ENFORCE_NE(vartype2notetype.find(type), vartype2notetype.end(),
"Unsupported value data type (%s)",
framework::DataTypeToString(type).c_str());
return vartype2notetype.at(type);
}

VarType::Type PianoGraphExecutor::GetVarDataType(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感觉VarType2NoteTypeGetVarDataType的定义可以拿出来放在vartype_utils.cc中,并将vartype2notetype.h更改为vartype_utils.h,让后将VarType2NoteTypeGetVarDataType的声明放在vartype_utils.h。GetVarDataType放在类中作为static函数,但是又感觉与PianoGraphExecutor无关。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

const framework::VarDesc* var) {
// non-pod type list
static std::unordered_set<VarType::Type> non_pod_types = {
VarType::LOD_TENSOR, VarType::SELECTED_ROWS, VarType::LOD_TENSOR_ARRAY};

const auto& var_type = var->GetType();
if (non_pod_types.count(var_type) != 0) {
// if the value type is non-pod type
return var->GetDataType();
} else if (GetVarType2NoteTypeMap().count(var_type) != 0) {
// if value is supported type
return var_type;
}
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported value data type (%s)",
framework::DataTypeToString(var_type).c_str()));
return framework::proto::VarType::RAW;
}

void PianoGraphExecutor::CreateInputOperand(
PianoScope* scope, symbolization::NoteBuilder* builder) const {
for (int64_t id = 0; id < cluster_inputs_.size(); ++id) {
auto* node = cluster_inputs_.at(id);
PADDLE_ENFORCE_EQ(node->IsVar(), true,
platform::errors::InvalidArgument(
"Cluster Sub-Graph Input should be var"));

const auto& var_name = node->Name();

// create operand shape
const auto& var_shape = node->Var()->GetShape();
const auto& var_type = GetVarDataType(node->Var());

// convert framework vartype to piano note type
note::ElementTypeProto element_type = VarType2NoteType(var_type);
Shape operand_shape(element_type, var_shape);

// create Operand
symbolization::Operand op =
symbolization::Parameter(builder, id, operand_shape, var_name);

// store into PianoScope
scope->SetOperand(var_name, op);
}
}

void PianoGraphExecutor::TopologicSortCluster(
GraphNodeVec* cluster_sorted) const {
std::unordered_set<Node*> cluster_set(cluster_.cbegin(), cluster_.cend());

std::unordered_map<Node*, size_t> indegree;
std::unordered_map<Node*, std::unordered_map<Node*, size_t>> adj_list;
std::queue<Node*> topo_queue;

// record all op's input op and output op
for (auto* n : cluster_) {
PADDLE_ENFORCE_EQ(n->IsOp(), true,
platform::errors::PreconditionNotMet(
"Cluster's node all should be op node"));
PADDLE_ENFORCE_EQ(PianoOpRegistry::IsPianoOp(n->Name()), true,
platform::errors::PreconditionNotMet(
"Cluster's op all should be piano op"));
// the op's input is var
for (auto* in_var : n->inputs) {
// the var's input is op
for (auto* in_op : in_var->inputs) {
if (cluster_set.find(in_op) != cluster_set.end()) {
++indegree[n];
++adj_list[in_op][n];
}
}
}
}

// find topology entrance
for (auto* n : cluster_) {
if (indegree[n] == 0) {
topo_queue.push(n);
}
}

// topological sorting
while (!topo_queue.empty()) {
auto* cur_op = topo_queue.front();
topo_queue.pop();

cluster_sorted->emplace_back(cur_op);
for (const auto& adj_pair : adj_list[cur_op]) {
// decrease output op's in-degree
indegree.at(adj_pair.first) -= adj_pair.second;

// if empty, push into queue
if (indegree.at(adj_pair.first) == 0) {
topo_queue.push(adj_pair.first);
}
}
}

PADDLE_ENFORCE_EQ(cluster_sorted->size(), cluster_.size(),
platform::errors::PreconditionNotMet(
"Cluster Sub-Graph shouldn't contain cycle."));
}

void PianoGraphExecutor::RunCompile(const GraphNodeVec& cluster,
PianoScope* scope,
symbolization::NoteBuilder* builder) const {
for (auto* n : cluster) {
const auto& op_name = n->Name();
const auto* op_desc = n->Op();

const auto& op_kernel_map = PianoOpRegistry::AllPianoOpKernels(op_name);
// TODO(jiangcheng05): how to distinguish library's kernel, like cudnn?
op_kernel_map.at("PLAIN")(PianoOpKernelContext(op_desc, scope, builder));
}
}

note::ModuleProto PianoGraphExecutor::operator()() const {
// Step1: create unique NoteBuilder
std::string builder_name = "NoteBuilderOfGraph_";
builder_name.append(std::to_string(graph_id_));

symbolization::NoteBuilder builder(builder_name);

// Step2: create graph's input operand
PianoScope scope;
CreateInputOperand(&scope, &builder);
Copy link
Owner

@wzzju wzzju Sep 1, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1)将CreateInputOperand声明为PianoScope CreateInputOperand(symbolization::NoteBuilder* builder),scope作为返回值。
2)TopologicSortCluster改名为SortInternalCluster并将签名换为GraphNodeVec SortInternalCluster()
以上两点建议是不是更好呢?

详见:https://zh-google-styleguide.readthedocs.io/en/latest/google-cpp-styleguide/classes/
29b2f3b069f080b8dd34b41c59d114c8

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PianoScopeDISABLE_COPY_AND_ASSIGN的,所以不能直接返回,改为返回std::unique_ptr<PianoScope>。另外SortInternalCluster改为返回GraphNodeVec并在调用处通过auto&&避免拷贝。


// Step3: topo sort graph
GraphNodeVec cluster_sorted;
TopologicSortCluster(&cluster_sorted);

// Step4: get PianoOpKernel and run compile
RunCompile(cluster_sorted, &scope, &builder);

// Step5: build and return module
return builder.Build();
}

} // namespace piano
} // namespace paddle
98 changes: 98 additions & 0 deletions paddle/fluid/compiler/paddle2piano/piano_graph_executor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>

#include "paddle/fluid/compiler/paddle2piano/piano_scope.h"
#include "paddle/fluid/compiler/piano/note/note.pb.h"
#include "paddle/fluid/compiler/piano/symbolization/note_builder.h"
#include "paddle/fluid/framework/ir/node.h"

namespace paddle {
namespace piano {

// An executor accept sub-graph which is generated by PianoCompilePass,
// run each op's PianoOpKernel, finally return the graph's ModuleProto.
//
// Parameter:
// 1. graph_id: the unique graph id, used for generating unique notebuilder name
// 2. cluster: a vector which contains all graph op, non-topological-sorting.
// 3. cluster_inputs: a vector which contains all graph's input var, the var's
// input are outside op, the output are inside op
// 4. cluster_outputs: a vector which contains all graph's output var, the var's
// input are inside op, the output are outside op
// 5. cluster_internals: a vector which contains all graph's internal var, the
// var's input and output are inside op
//
// Example:
// -------------------------> op3 -> var4 ->
// / /
// -> var1 -> op1 -> var2 -> op2 -> var3
//
// cluster: [op1, op2, op3]
// cluster_inputs: [var1]
// cluster_outputs: [var4]
// cluster_internals: [var2, var3]
//
// Describe:
// The executor consisted by the following step:
// 1. create a NoteBuilder, it's name is unique for each graph
// 2. create PianoScope, initially, scope only consist graph's input var and its
// operand
// 3. topological sorting graph
// 4. create PianoOpKernelContext and run each op's PianoOpKernel
// 5. run NoteBuilder's Build function to generate graph's ModuleProto
class PianoGraphExecutor {
public:
using GraphNodeVec = std::vector<framework::ir::Node*>;

PianoGraphExecutor(int64_t graph_id, const GraphNodeVec& cluster,
const GraphNodeVec& cluster_inputs,
const GraphNodeVec& cluster_outputs,
const GraphNodeVec& cluster_internals)
: graph_id_(graph_id),
cluster_(cluster),
cluster_inputs_(cluster_inputs),
cluster_outputs_(cluster_outputs),
cluster_internals_(cluster_internals) {}

note::ModuleProto operator()() const;

private:
const int64_t graph_id_;
const GraphNodeVec& cluster_;
const GraphNodeVec& cluster_inputs_;
const GraphNodeVec& cluster_outputs_;
const GraphNodeVec& cluster_internals_;

// get var's data type
static framework::proto::VarType::Type GetVarDataType(
const framework::VarDesc* var);

// create graph's input operand from cluster_inputs_
void CreateInputOperand(PianoScope* scope,
symbolization::NoteBuilder* builder) const;

// run PianoOpKernel's Compile
void RunCompile(const GraphNodeVec& cluster, PianoScope* scope,
symbolization::NoteBuilder* builder) const;

// topologic sorting graph node
void TopologicSortCluster(GraphNodeVec* cluster_sorted) const;
};

} // namespace piano
} // namespace paddle
Loading