-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Paddle compiler #35
base: paddle_compiler
Are you sure you want to change the base?
Paddle compiler #35
Changes from all commits
c4a9ca8
6f10137
772f98d
8cd4153
fa97ce3
bb6e63d
9c94942
2d4a99f
a8a2426
a385fb8
5651f2b
13af23d
46288df
3aaa036
0e04f5a
0c02918
c75a57f
08106e2
82dd306
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/compiler/piano/pass.h" | ||
#include "paddle/fluid/compiler/piano/all_passes.h" | ||
#include "glog/logging.h" | ||
|
||
namespace paddle { | ||
namespace piano { | ||
|
||
int verify_all_passes() { | ||
int count = 0; | ||
#define VAR(pass) _##pass | ||
#define CHECK_PASS(pass) \ | ||
auto VAR(pass) = PASS_CTOR(pass); \ | ||
count++; \ | ||
LOG(INFO) << "Check pass: " << VAR(pass).name(); \ | ||
{ | ||
// Expand the pass list | ||
PASS_ALL(CHECK_PASS) | ||
} | ||
return count; | ||
#undef CHECK_PASS | ||
#undef VAR | ||
} | ||
|
||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include "paddle/fluid/compiler/piano/pass.h" | ||
#include "paddle/fluid/compiler/piano/note/instruction.h" | ||
#include "paddle/fluid/compiler/piano/note/function.h" | ||
#include "paddle/fluid/compiler/piano/note/note.pb.h" | ||
#include <type_traits> | ||
|
||
// include all pass class definition headers here | ||
|
||
|
||
namespace paddle { | ||
namespace piano { | ||
|
||
class ATestPass : public Pass { | ||
using Function = note::Function; | ||
using Instruction = note::Instruction; | ||
using OpCode = note::OpCode; | ||
public: | ||
ATestPass() : Pass() {} | ||
~ATestPass() override = default; | ||
bool run(void *fn) override { | ||
bool changed = false; | ||
auto* ir = static_cast<Function*>(fn); | ||
auto dead_ins = std::vector<Instruction*>(); | ||
for (auto& instruction : ir->instructions()) { | ||
if (instruction.ctrl_predecessors().empty() && | ||
instruction.ctrl_successors().empty() && | ||
instruction.opcode() != OpCode::kParameter) | ||
dead_ins.push_back(&instruction); | ||
} | ||
// (TODO) remove dead instructions from function | ||
changed = !dead_ins.empty(); | ||
return changed; | ||
} | ||
std::string name() const override { | ||
return "a_test_pass"; | ||
} | ||
}; | ||
|
||
// Put all the piano optimization passes here so that they can be hooked | ||
// with the make_pass function. | ||
#define PASS_ALL(__macro) \ | ||
__macro(ATest) | ||
|
||
// Pass id enum is used as key for dispatching pass classes | ||
enum class PassId { | ||
#define ID(pass) pass, | ||
PASS_NA, | ||
PASS_ALL(ID) | ||
#undef ID | ||
}; | ||
|
||
#define INC(name) +1 | ||
constexpr int Total_Num_Passes = PASS_ALL(INC); | ||
#undef INC | ||
|
||
// Following are basic utilities for constructing pass objects | ||
|
||
template<typename P> | ||
static P *do_make_pass() { | ||
static_assert(std::is_base_of<Pass, P>::value); | ||
return new P(); | ||
} | ||
|
||
template<PassId T> | ||
struct PassClass {}; | ||
|
||
#define PASS_ID(pass) PassId::pass | ||
#define PASS_CLASS(pass) pass##Pass | ||
|
||
#define SPECIALIZE_PASSCLASS(pass) \ | ||
template<> \ | ||
struct PassClass<PASS_ID(pass)> { \ | ||
using type = PASS_CLASS(pass); \ | ||
}; | ||
PASS_ALL(SPECIALIZE_PASSCLASS) | ||
|
||
// Use this macro as the public interface for constructing heap allocated | ||
// pass object. | ||
// Code example: | ||
// { | ||
// auto* dce_pass = make_pass(ModuleDCE); | ||
// dce->run(module_ir); | ||
// } | ||
#define make_pass(pass) \ | ||
do_make_pass<PassClass<PASS_ID(pass)>::type>(); | ||
|
||
// use this macro as the public interface for constructing stack allocated | ||
// pass object. | ||
// Code example: | ||
// { | ||
// auto dce_pass = PASS_CTOR(ModuleDCE); | ||
// dce.run(module_ir); | ||
// } | ||
#define PASS_CTOR(pass) \ | ||
PassClass<PASS_ID(pass)>::type(); | ||
|
||
int verify_all_passes(); | ||
|
||
#undef DEF_PASS_MAKER | ||
#undef SPECIALIZE_PASSCLASS | ||
|
||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
#include <string> | ||
|
||
namespace paddle { | ||
namespace piano { | ||
|
||
class Pass { | ||
public: | ||
Pass() {} | ||
virtual ~Pass() {}; | ||
virtual bool run(void *ir) = 0; | ||
virtual std::string name() const = 0; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 不大明白为什么要把
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hello,构造参数就要为每个对象分配 name 空间,虚函数的用意是一个类只需要一个常量 string,constexpr 也可以? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 不管有没有 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 另外,constexpr应该不行,但 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这块我想过,比较下来还是虚函数最好,name严格说都不算pass的本质属性,另外name()不频繁调用,对象只包含一个虚函数表指针。 |
||
}; | ||
|
||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/compiler/piano/pass.h" | ||
#include "paddle/fluid/compiler/piano/all_passes.h" | ||
#include "paddle/fluid/compiler/piano/note/instruction.h" | ||
#include "paddle/fluid/compiler/piano/note/function.h" | ||
#include "paddle/fluid/compiler/piano/note/note.pb.h" | ||
#include "glog/logging.h" | ||
#include "gtest/gtest.h" | ||
|
||
namespace paddle { | ||
namespace piano { | ||
|
||
using Function = note::Function; | ||
using Instruction = note::Instruction; | ||
using OpCode = note::OpCode; | ||
using FunctionProto = note::FunctionProto; | ||
using SignatureProto = note::SignatureProto; | ||
using InstructionProto = note::InstructionProto; | ||
using AttrValueProto = note::AttrValueProto; | ||
using ProtoMapType = note::ProtoMapType; | ||
|
||
class BTestPass : Pass { | ||
public: | ||
BTestPass() : Pass() {} | ||
~BTestPass() override = default; | ||
bool run(void *fn) override { | ||
bool changed = false; | ||
auto* ir = static_cast<Function*>(fn); | ||
auto dead_ins = std::vector<Instruction*>(); | ||
for (auto& instruction : ir->instructions()) { | ||
if (instruction.ctrl_predecessors().empty() && | ||
instruction.ctrl_successors().empty() && | ||
instruction.opcode() != OpCode::kParameter) | ||
dead_ins.push_back(&instruction); | ||
} | ||
// (TODO) remove dead instructions from function | ||
changed = !dead_ins.empty(); | ||
return changed; | ||
} | ||
std::string name() const override { | ||
return "b_test_pass"; | ||
} | ||
}; | ||
|
||
class PassClassTest : public ::testing::Test { | ||
virtual void SetUp() { | ||
// input shapes | ||
const Shape arg1_shape(note::F32, {3, 6}); | ||
const Shape arg2_shape(note::F32, {3, 6}); | ||
// output shape | ||
const Shape result_shape(note::F32, {3, 6}); | ||
// function signature | ||
const Signature signature({arg1_shape, arg2_shape}, {"arg1.1", "arg2.2"}, | ||
result_shape); | ||
SignatureProto signature_proto = signature.ToProto(); | ||
signature_proto_.Swap(&signature_proto); | ||
|
||
// set instr1_proto_ | ||
instr1_proto_.set_name("arg1.1"); | ||
instr1_proto_.set_opcode(GetOpName(OpCode::kParameter)); | ||
instr1_proto_.set_id(1); | ||
instr1_proto_.set_parameter_index(0); | ||
*instr1_proto_.mutable_shape() = arg1_shape.ToProto(); | ||
auto* attrs1_map = instr1_proto_.mutable_attrs(); | ||
AttrValueProto val1_proto; | ||
val1_proto.set_d(3.141); | ||
attrs1_map->insert(ProtoMapType::value_type("test_double", val1_proto)); | ||
auto* strings = val1_proto.mutable_strings()->mutable_value(); | ||
*strings->Add() = "hello"; | ||
*strings->Add() = "world"; | ||
attrs1_map->insert(ProtoMapType::value_type("test_strings", val1_proto)); | ||
auto* bools = val1_proto.mutable_bools()->mutable_value(); | ||
bools->Add(true); | ||
bools->Add(false); | ||
attrs1_map->insert(ProtoMapType::value_type("test_bools", val1_proto)); | ||
auto* ints = val1_proto.mutable_ints()->mutable_value(); | ||
ints->Add(8); | ||
ints->Add(26); | ||
attrs1_map->insert(ProtoMapType::value_type("test_ints", val1_proto)); | ||
|
||
// set instr2_proto_ | ||
instr2_proto_.set_name("arg2.2"); | ||
instr2_proto_.set_opcode(GetOpName(OpCode::kParameter)); | ||
instr2_proto_.set_id(2); | ||
instr2_proto_.set_parameter_index(1); | ||
*instr2_proto_.mutable_shape() = arg2_shape.ToProto(); | ||
auto* attrs2_map = instr2_proto_.mutable_attrs(); | ||
AttrValueProto val2_proto; | ||
val2_proto.set_b(true); | ||
attrs2_map->insert(ProtoMapType::value_type("test_bool", val2_proto)); | ||
auto* longs = val2_proto.mutable_longs()->mutable_value(); | ||
longs->Add(8l); | ||
longs->Add(16l); | ||
attrs2_map->insert(ProtoMapType::value_type("test_longs", val2_proto)); | ||
auto* floats = val2_proto.mutable_floats()->mutable_value(); | ||
floats->Add(8.6f); | ||
floats->Add(7.6f); | ||
attrs2_map->insert(ProtoMapType::value_type("test_floats", val2_proto)); | ||
auto* doubles = val2_proto.mutable_doubles()->mutable_value(); | ||
doubles->Add(5.66); | ||
doubles->Add(6.66); | ||
attrs2_map->insert(ProtoMapType::value_type("test_doubles", val2_proto)); | ||
|
||
// set instr3_proto_ | ||
instr3_proto_.set_name("add.3"); | ||
instr3_proto_.set_opcode(GetOpName(OpCode::kAdd)); | ||
instr3_proto_.set_id(3); | ||
*instr3_proto_.mutable_shape() = result_shape.ToProto(); | ||
instr3_proto_.add_operand_ids(1); | ||
instr3_proto_.add_operand_ids(2); | ||
auto* attrs3_map = instr3_proto_.mutable_attrs(); | ||
AttrValueProto val3_proto; | ||
val3_proto.set_s("Add"); | ||
attrs3_map->insert(ProtoMapType::value_type("test_string", val3_proto)); | ||
val3_proto.set_i(-1); | ||
attrs3_map->insert(ProtoMapType::value_type("test_int", val3_proto)); | ||
val3_proto.set_l(-100l); | ||
attrs3_map->insert(ProtoMapType::value_type("test_long", val3_proto)); | ||
val3_proto.set_f(-1.414f); | ||
attrs3_map->insert(ProtoMapType::value_type("test_float", val3_proto)); | ||
|
||
// set func_proto_ | ||
func_proto_.set_name(func_name_); | ||
*func_proto_.mutable_signature() = signature_proto_; | ||
func_proto_.set_return_id(instr3_proto_.id()); | ||
function_id_ = instr3_proto_.id() + 1; | ||
func_proto_.set_id(function_id_); | ||
*func_proto_.add_instructions() = instr1_proto_; | ||
*func_proto_.add_instructions() = instr2_proto_; | ||
*func_proto_.add_instructions() = instr3_proto_; | ||
|
||
} | ||
protected: | ||
std::string func_name_{"union_12510013719728903619"}; | ||
FunctionProto func_proto_; | ||
std::int64_t function_id_; | ||
SignatureProto signature_proto_; | ||
InstructionProto instr1_proto_; | ||
InstructionProto instr2_proto_; | ||
InstructionProto instr3_proto_; | ||
}; | ||
|
||
TEST_F(PassClassTest, VerifyPasses) { | ||
int num_passes = verify_all_passes(); | ||
// Control reaches here if compilation succeeds | ||
EXPECT_EQ(num_passes, Total_Num_Passes); | ||
} | ||
|
||
TEST_F(PassClassTest, SimpleFunctionPass) { | ||
std::unordered_map<std::int64_t, Function*> func_index; | ||
Function func(func_proto_, func_index); | ||
auto* a_pass = make_pass(ATest); | ||
EXPECT_EQ(a_pass->run(&func), true); | ||
LOG(INFO) << "A simple function pass detecting dead instructions."; | ||
|
||
} | ||
|
||
|
||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个应该是
\
而不是+1
吧?另外这个name
好像没有用到唉?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个 是 +1,展开之后相当于 +1 +1 +1 ... = number of pass classes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
受教了,还没这么用过