-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add piano graph executor #36
Changes from 15 commits
55a43ad
33e7e0a
8464aae
ef9c568
3ed2400
c1d3dd3
7f0f347
9a465a9
472c0aa
e98b242
29b8be2
c1158b2
60e126a
d0ca5fe
bb37a3c
1fe3090
03b2738
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/compiler/paddle2piano/piano_graph_executor.h" | ||
|
||
#include <queue> | ||
#include <unordered_map> | ||
#include <unordered_set> | ||
|
||
#include "paddle/fluid/compiler/paddle2piano/piano_op_kernel_context.h" | ||
#include "paddle/fluid/compiler/paddle2piano/vartype2notetype.h" | ||
#include "paddle/fluid/compiler/piano/symbolization/meta_op.h" | ||
#include "paddle/fluid/framework/data_type.h" | ||
#include "paddle/fluid/platform/enforce.h" | ||
|
||
namespace paddle { | ||
namespace piano { | ||
|
||
using framework::ir::Node; | ||
using framework::proto::VarType; | ||
|
||
namespace { | ||
const std::unordered_map<framework::proto::VarType::Type, | ||
note::ElementTypeProto>& | ||
GetVarType2NoteTypeMap() { | ||
static std::unordered_map<framework::proto::VarType::Type, | ||
note::ElementTypeProto> | ||
vartype2notetype = {{framework::proto::VarType::BOOL, note::B1}, | ||
{framework::proto::VarType::INT8, note::S8}, | ||
{framework::proto::VarType::INT16, note::S16}, | ||
{framework::proto::VarType::INT32, note::S32}, | ||
{framework::proto::VarType::INT64, note::S64}, | ||
{framework::proto::VarType::FP16, note::F16}, | ||
{framework::proto::VarType::FP32, note::F32}, | ||
{framework::proto::VarType::FP64, note::F64}, | ||
{framework::proto::VarType::UINT8, note::U8}, | ||
{framework::proto::VarType::SIZE_T, note::U64}}; | ||
return vartype2notetype; | ||
} | ||
} // namespace | ||
|
||
note::ElementTypeProto VarType2NoteType(framework::proto::VarType::Type type) { | ||
const auto& vartype2notetype = GetVarType2NoteTypeMap(); | ||
PADDLE_ENFORCE_NE(vartype2notetype.find(type), vartype2notetype.end(), | ||
"Unsupported value data type (%s)", | ||
framework::DataTypeToString(type).c_str()); | ||
return vartype2notetype.at(type); | ||
} | ||
|
||
VarType::Type PianoGraphExecutor::GetVarDataType( | ||
const framework::VarDesc* var) { | ||
// non-pod type list | ||
static std::unordered_set<VarType::Type> non_pod_types = { | ||
VarType::LOD_TENSOR, VarType::SELECTED_ROWS, VarType::LOD_TENSOR_ARRAY}; | ||
|
||
const auto& var_type = var->GetType(); | ||
if (non_pod_types.count(var_type) != 0) { | ||
// if the value type is non-pod type | ||
return var->GetDataType(); | ||
} else if (GetVarType2NoteTypeMap().count(var_type) != 0) { | ||
// if value is supported type | ||
return var_type; | ||
} | ||
PADDLE_THROW(platform::errors::Unimplemented( | ||
"Unsupported value data type (%s)", | ||
framework::DataTypeToString(var_type).c_str())); | ||
return framework::proto::VarType::RAW; | ||
} | ||
|
||
void PianoGraphExecutor::CreateInputOperand( | ||
PianoScope* scope, symbolization::NoteBuilder* builder) const { | ||
for (int64_t id = 0; id < cluster_inputs_.size(); ++id) { | ||
auto* node = cluster_inputs_.at(id); | ||
PADDLE_ENFORCE_EQ(node->IsVar(), true, | ||
platform::errors::InvalidArgument( | ||
"Cluster Sub-Graph Input should be var")); | ||
|
||
const auto& var_name = node->Name(); | ||
|
||
// create operand shape | ||
const auto& var_shape = node->Var()->GetShape(); | ||
const auto& var_type = GetVarDataType(node->Var()); | ||
|
||
// convert framework vartype to piano note type | ||
note::ElementTypeProto element_type = VarType2NoteType(var_type); | ||
Shape operand_shape(element_type, var_shape); | ||
|
||
// create Operand | ||
symbolization::Operand op = | ||
symbolization::Parameter(builder, id, operand_shape, var_name); | ||
|
||
// store into PianoScope | ||
scope->SetOperand(var_name, op); | ||
} | ||
} | ||
|
||
void PianoGraphExecutor::TopologicSortCluster( | ||
GraphNodeVec* cluster_sorted) const { | ||
std::unordered_set<Node*> cluster_set(cluster_.cbegin(), cluster_.cend()); | ||
|
||
std::unordered_map<Node*, size_t> indegree; | ||
std::unordered_map<Node*, std::unordered_map<Node*, size_t>> adj_list; | ||
std::queue<Node*> topo_queue; | ||
|
||
// record all op's input op and output op | ||
for (auto* n : cluster_) { | ||
PADDLE_ENFORCE_EQ(n->IsOp(), true, | ||
platform::errors::PreconditionNotMet( | ||
"Cluster's node all should be op node")); | ||
PADDLE_ENFORCE_EQ(PianoOpRegistry::IsPianoOp(n->Name()), true, | ||
platform::errors::PreconditionNotMet( | ||
"Cluster's op all should be piano op")); | ||
// the op's input is var | ||
for (auto* in_var : n->inputs) { | ||
// the var's input is op | ||
for (auto* in_op : in_var->inputs) { | ||
if (cluster_set.find(in_op) != cluster_set.end()) { | ||
++indegree[n]; | ||
++adj_list[in_op][n]; | ||
} | ||
} | ||
} | ||
} | ||
|
||
// find topology entrance | ||
for (auto* n : cluster_) { | ||
if (indegree[n] == 0) { | ||
topo_queue.push(n); | ||
} | ||
} | ||
|
||
// topological sorting | ||
while (!topo_queue.empty()) { | ||
auto* cur_op = topo_queue.front(); | ||
topo_queue.pop(); | ||
|
||
cluster_sorted->emplace_back(cur_op); | ||
for (const auto& adj_pair : adj_list[cur_op]) { | ||
// decrease output op's in-degree | ||
indegree.at(adj_pair.first) -= adj_pair.second; | ||
|
||
// if empty, push into queue | ||
if (indegree.at(adj_pair.first) == 0) { | ||
topo_queue.push(adj_pair.first); | ||
} | ||
} | ||
} | ||
|
||
PADDLE_ENFORCE_EQ(cluster_sorted->size(), cluster_.size(), | ||
platform::errors::PreconditionNotMet( | ||
"Cluster Sub-Graph shouldn't contain cycle.")); | ||
} | ||
|
||
void PianoGraphExecutor::RunCompile(const GraphNodeVec& cluster, | ||
PianoScope* scope, | ||
symbolization::NoteBuilder* builder) const { | ||
for (auto* n : cluster) { | ||
const auto& op_name = n->Name(); | ||
const auto* op_desc = n->Op(); | ||
|
||
const auto& op_kernel_map = PianoOpRegistry::AllPianoOpKernels(op_name); | ||
// TODO(jiangcheng05): how to distinguish library's kernel, like cudnn? | ||
op_kernel_map.at("PLAIN")(PianoOpKernelContext(op_desc, scope, builder)); | ||
} | ||
} | ||
|
||
note::ModuleProto PianoGraphExecutor::operator()() const { | ||
// Step1: create unique NoteBuilder | ||
std::string builder_name = "NoteBuilderOfGraph_"; | ||
builder_name.append(std::to_string(graph_id_)); | ||
|
||
symbolization::NoteBuilder builder(builder_name); | ||
|
||
// Step2: create graph's input operand | ||
PianoScope scope; | ||
CreateInputOperand(&scope, &builder); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 1)将 详见:https://zh-google-styleguide.readthedocs.io/en/latest/google-cpp-styleguide/classes/ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
// Step3: topo sort graph | ||
GraphNodeVec cluster_sorted; | ||
TopologicSortCluster(&cluster_sorted); | ||
|
||
// Step4: get PianoOpKernel and run compile | ||
RunCompile(cluster_sorted, &scope, &builder); | ||
|
||
// Step5: build and return module | ||
return builder.Build(); | ||
} | ||
|
||
} // namespace piano | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include <vector> | ||
|
||
#include "paddle/fluid/compiler/paddle2piano/piano_scope.h" | ||
#include "paddle/fluid/compiler/piano/note/note.pb.h" | ||
#include "paddle/fluid/compiler/piano/symbolization/note_builder.h" | ||
#include "paddle/fluid/framework/ir/node.h" | ||
|
||
namespace paddle { | ||
namespace piano { | ||
|
||
// An executor accept sub-graph which is generated by PianoCompilePass, | ||
// run each op's PianoOpKernel, finally return the graph's ModuleProto. | ||
// | ||
// Parameter: | ||
// 1. graph_id: the unique graph id, used for generating unique notebuilder name | ||
// 2. cluster: a vector which contains all graph op, non-topological-sorting. | ||
// 3. cluster_inputs: a vector which contains all graph's input var, the var's | ||
// input are outside op, the output are inside op | ||
// 4. cluster_outputs: a vector which contains all graph's output var, the var's | ||
// input are inside op, the output are outside op | ||
// 5. cluster_internals: a vector which contains all graph's internal var, the | ||
// var's input and output are inside op | ||
// | ||
// Example: | ||
// -------------------------> op3 -> var4 -> | ||
// / / | ||
// -> var1 -> op1 -> var2 -> op2 -> var3 | ||
// | ||
// cluster: [op1, op2, op3] | ||
// cluster_inputs: [var1] | ||
// cluster_outputs: [var4] | ||
// cluster_internals: [var2, var3] | ||
// | ||
// Describe: | ||
// The executor consisted by the following step: | ||
// 1. create a NoteBuilder, it's name is unique for each graph | ||
// 2. create PianoScope, initially, scope only consist graph's input var and its | ||
// operand | ||
// 3. topological sorting graph | ||
// 4. create PianoOpKernelContext and run each op's PianoOpKernel | ||
// 5. run NoteBuilder's Build function to generate graph's ModuleProto | ||
class PianoGraphExecutor { | ||
public: | ||
using GraphNodeVec = std::vector<framework::ir::Node*>; | ||
|
||
PianoGraphExecutor(int64_t graph_id, const GraphNodeVec& cluster, | ||
const GraphNodeVec& cluster_inputs, | ||
const GraphNodeVec& cluster_outputs, | ||
const GraphNodeVec& cluster_internals) | ||
: graph_id_(graph_id), | ||
cluster_(cluster), | ||
cluster_inputs_(cluster_inputs), | ||
cluster_outputs_(cluster_outputs), | ||
cluster_internals_(cluster_internals) {} | ||
|
||
note::ModuleProto operator()() const; | ||
|
||
private: | ||
const int64_t graph_id_; | ||
const GraphNodeVec& cluster_; | ||
const GraphNodeVec& cluster_inputs_; | ||
const GraphNodeVec& cluster_outputs_; | ||
const GraphNodeVec& cluster_internals_; | ||
|
||
// get var's data type | ||
static framework::proto::VarType::Type GetVarDataType( | ||
const framework::VarDesc* var); | ||
|
||
// create graph's input operand from cluster_inputs_ | ||
void CreateInputOperand(PianoScope* scope, | ||
symbolization::NoteBuilder* builder) const; | ||
|
||
// run PianoOpKernel's Compile | ||
void RunCompile(const GraphNodeVec& cluster, PianoScope* scope, | ||
symbolization::NoteBuilder* builder) const; | ||
|
||
// topologic sorting graph node | ||
void TopologicSortCluster(GraphNodeVec* cluster_sorted) const; | ||
}; | ||
|
||
} // namespace piano | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感觉
VarType2NoteType
和GetVarDataType
的定义可以拿出来放在vartype_utils.cc中,并将vartype2notetype.h
更改为vartype_utils.h
,让后将VarType2NoteType
和GetVarDataType
的声明放在vartype_utils.h
。GetVarDataType放在类中作为static函数,但是又感觉与PianoGraphExecutor无关。There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done