Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Paddle compiler #35

Open
wants to merge 19 commits into
base: paddle_compiler
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions paddle/fluid/compiler/piano/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,8 @@ cc_library(piano_data_description SRCS layout.cc shape.cc DEPS string_helper not
cc_test(piano_layout_test SRCS layout_test.cc DEPS piano_data_description)
cc_test(piano_shape_test SRCS shape_test.cc DEPS piano_data_description)

cc_library(piano_pass SRCS all_passes.cc DEPS note_opcode note_ir note_proto)
cc_test(piano_pass_test SRCS pass_test.cc DEPS piano_pass)

cc_library(note_builder SRCS note_builder.cc DEPS string_helper note_opcode piano_data_description)
cc_test(note_builder_test SRCS note_builder_test.cc DEPS note_builder)
39 changes: 39 additions & 0 deletions paddle/fluid/compiler/piano/all_passes.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/compiler/piano/pass.h"
#include "paddle/fluid/compiler/piano/all_passes.h"
#include "glog/logging.h"

namespace paddle {
namespace piano {

int verify_all_passes() {
int count = 0;
#define VAR(pass) _##pass
#define CHECK_PASS(pass) \
auto VAR(pass) = PASS_CTOR(pass); \
count++; \
LOG(INFO) << "Check pass: " << VAR(pass).name(); \
{
// Expand the pass list
PASS_ALL(CHECK_PASS)
}
return count;
#undef CHECK_PASS
#undef VAR
}

}
}
119 changes: 119 additions & 0 deletions paddle/fluid/compiler/piano/all_passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/compiler/piano/pass.h"
#include "paddle/fluid/compiler/piano/note/instruction.h"
#include "paddle/fluid/compiler/piano/note/function.h"
#include "paddle/fluid/compiler/piano/note/note.pb.h"
#include <type_traits>

// include all pass class definition headers here


namespace paddle {
namespace piano {

class ATestPass : public Pass {
using Function = note::Function;
using Instruction = note::Instruction;
using OpCode = note::OpCode;
public:
ATestPass() : Pass() {}
~ATestPass() override = default;
bool run(void *fn) override {
bool changed = false;
auto* ir = static_cast<Function*>(fn);
auto dead_ins = std::vector<Instruction*>();
for (auto& instruction : ir->instructions()) {
if (instruction.ctrl_predecessors().empty() &&
instruction.ctrl_successors().empty() &&
instruction.opcode() != OpCode::kParameter)
dead_ins.push_back(&instruction);
}
// (TODO) remove dead instructions from function
changed = !dead_ins.empty();
return changed;
}
std::string name() const override {
return "a_test_pass";
}
};

// Put all the piano optimization passes here so that they can be hooked
// with the make_pass function.
#define PASS_ALL(__macro) \
__macro(ATest)

// Pass id enum is used as key for dispatching pass classes
enum class PassId {
#define ID(pass) pass,
PASS_NA,
PASS_ALL(ID)
#undef ID
};

#define INC(name) +1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个应该是\而不是+1吧?另外这个name好像没有用到唉?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个 是 +1,展开之后相当于 +1 +1 +1 ... = number of pass classes

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

受教了,还没这么用过

constexpr int Total_Num_Passes = PASS_ALL(INC);
#undef INC

// Following are basic utilities for constructing pass objects

template<typename P>
static P *do_make_pass() {
static_assert(std::is_base_of<Pass, P>::value);
return new P();
}

template<PassId T>
struct PassClass {};

#define PASS_ID(pass) PassId::pass
#define PASS_CLASS(pass) pass##Pass

#define SPECIALIZE_PASSCLASS(pass) \
template<> \
struct PassClass<PASS_ID(pass)> { \
using type = PASS_CLASS(pass); \
};
PASS_ALL(SPECIALIZE_PASSCLASS)

// Use this macro as the public interface for constructing heap allocated
// pass object.
// Code example:
// {
// auto* dce_pass = make_pass(ModuleDCE);
// dce->run(module_ir);
// }
#define make_pass(pass) \
do_make_pass<PassClass<PASS_ID(pass)>::type>();

// use this macro as the public interface for constructing stack allocated
// pass object.
// Code example:
// {
// auto dce_pass = PASS_CTOR(ModuleDCE);
// dce.run(module_ir);
// }
#define PASS_CTOR(pass) \
PassClass<PASS_ID(pass)>::type();

int verify_all_passes();

#undef DEF_PASS_MAKER
#undef SPECIALIZE_PASSCLASS

}
}
30 changes: 30 additions & 0 deletions paddle/fluid/compiler/piano/pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <string>

namespace paddle {
namespace piano {

class Pass {
public:
Pass() {}
virtual ~Pass() {};
virtual bool run(void *ir) = 0;
virtual std::string name() const = 0;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不大明白为什么要把name设计为一个纯虚函数?作为构造函数参数不更好么?

class Pass {
 public:
  Pass(const std::string& name) : name_(name) {}
  virtual ~Pass() =default;
  const std::string& name() { return name_;}
 private:
  std::string name_;
};

class SubPass : public  Pass {
 public:
  SubPass() : Pass("SubPass") {}
};

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello,构造参数就要为每个对象分配 name 空间,虚函数的用意是一个类只需要一个常量 string,constexpr 也可以?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不管有没有name,虚函数意味着总是得给对象分配空间的,多个name感觉没啥负担

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

另外,constexpr应该不行,但const std::string& name应该可以,只要保证name只会在构造函数中被赋值且不会被修改。引用的负担也不大

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块我想过,比较下来还是虚函数最好,name严格说都不算pass的本质属性,另外name()不频繁调用,对象只包含一个虚函数表指针。

};

}
}
173 changes: 173 additions & 0 deletions paddle/fluid/compiler/piano/pass_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/compiler/piano/pass.h"
#include "paddle/fluid/compiler/piano/all_passes.h"
#include "paddle/fluid/compiler/piano/note/instruction.h"
#include "paddle/fluid/compiler/piano/note/function.h"
#include "paddle/fluid/compiler/piano/note/note.pb.h"
#include "glog/logging.h"
#include "gtest/gtest.h"

namespace paddle {
namespace piano {

using Function = note::Function;
using Instruction = note::Instruction;
using OpCode = note::OpCode;
using FunctionProto = note::FunctionProto;
using SignatureProto = note::SignatureProto;
using InstructionProto = note::InstructionProto;
using AttrValueProto = note::AttrValueProto;
using ProtoMapType = note::ProtoMapType;

class BTestPass : Pass {
public:
BTestPass() : Pass() {}
~BTestPass() override = default;
bool run(void *fn) override {
bool changed = false;
auto* ir = static_cast<Function*>(fn);
auto dead_ins = std::vector<Instruction*>();
for (auto& instruction : ir->instructions()) {
if (instruction.ctrl_predecessors().empty() &&
instruction.ctrl_successors().empty() &&
instruction.opcode() != OpCode::kParameter)
dead_ins.push_back(&instruction);
}
// (TODO) remove dead instructions from function
changed = !dead_ins.empty();
return changed;
}
std::string name() const override {
return "b_test_pass";
}
};

class PassClassTest : public ::testing::Test {
virtual void SetUp() {
// input shapes
const Shape arg1_shape(note::F32, {3, 6});
const Shape arg2_shape(note::F32, {3, 6});
// output shape
const Shape result_shape(note::F32, {3, 6});
// function signature
const Signature signature({arg1_shape, arg2_shape}, {"arg1.1", "arg2.2"},
result_shape);
SignatureProto signature_proto = signature.ToProto();
signature_proto_.Swap(&signature_proto);

// set instr1_proto_
instr1_proto_.set_name("arg1.1");
instr1_proto_.set_opcode(GetOpName(OpCode::kParameter));
instr1_proto_.set_id(1);
instr1_proto_.set_parameter_index(0);
*instr1_proto_.mutable_shape() = arg1_shape.ToProto();
auto* attrs1_map = instr1_proto_.mutable_attrs();
AttrValueProto val1_proto;
val1_proto.set_d(3.141);
attrs1_map->insert(ProtoMapType::value_type("test_double", val1_proto));
auto* strings = val1_proto.mutable_strings()->mutable_value();
*strings->Add() = "hello";
*strings->Add() = "world";
attrs1_map->insert(ProtoMapType::value_type("test_strings", val1_proto));
auto* bools = val1_proto.mutable_bools()->mutable_value();
bools->Add(true);
bools->Add(false);
attrs1_map->insert(ProtoMapType::value_type("test_bools", val1_proto));
auto* ints = val1_proto.mutable_ints()->mutable_value();
ints->Add(8);
ints->Add(26);
attrs1_map->insert(ProtoMapType::value_type("test_ints", val1_proto));

// set instr2_proto_
instr2_proto_.set_name("arg2.2");
instr2_proto_.set_opcode(GetOpName(OpCode::kParameter));
instr2_proto_.set_id(2);
instr2_proto_.set_parameter_index(1);
*instr2_proto_.mutable_shape() = arg2_shape.ToProto();
auto* attrs2_map = instr2_proto_.mutable_attrs();
AttrValueProto val2_proto;
val2_proto.set_b(true);
attrs2_map->insert(ProtoMapType::value_type("test_bool", val2_proto));
auto* longs = val2_proto.mutable_longs()->mutable_value();
longs->Add(8l);
longs->Add(16l);
attrs2_map->insert(ProtoMapType::value_type("test_longs", val2_proto));
auto* floats = val2_proto.mutable_floats()->mutable_value();
floats->Add(8.6f);
floats->Add(7.6f);
attrs2_map->insert(ProtoMapType::value_type("test_floats", val2_proto));
auto* doubles = val2_proto.mutable_doubles()->mutable_value();
doubles->Add(5.66);
doubles->Add(6.66);
attrs2_map->insert(ProtoMapType::value_type("test_doubles", val2_proto));

// set instr3_proto_
instr3_proto_.set_name("add.3");
instr3_proto_.set_opcode(GetOpName(OpCode::kAdd));
instr3_proto_.set_id(3);
*instr3_proto_.mutable_shape() = result_shape.ToProto();
instr3_proto_.add_operand_ids(1);
instr3_proto_.add_operand_ids(2);
auto* attrs3_map = instr3_proto_.mutable_attrs();
AttrValueProto val3_proto;
val3_proto.set_s("Add");
attrs3_map->insert(ProtoMapType::value_type("test_string", val3_proto));
val3_proto.set_i(-1);
attrs3_map->insert(ProtoMapType::value_type("test_int", val3_proto));
val3_proto.set_l(-100l);
attrs3_map->insert(ProtoMapType::value_type("test_long", val3_proto));
val3_proto.set_f(-1.414f);
attrs3_map->insert(ProtoMapType::value_type("test_float", val3_proto));

// set func_proto_
func_proto_.set_name(func_name_);
*func_proto_.mutable_signature() = signature_proto_;
func_proto_.set_return_id(instr3_proto_.id());
function_id_ = instr3_proto_.id() + 1;
func_proto_.set_id(function_id_);
*func_proto_.add_instructions() = instr1_proto_;
*func_proto_.add_instructions() = instr2_proto_;
*func_proto_.add_instructions() = instr3_proto_;

}
protected:
std::string func_name_{"union_12510013719728903619"};
FunctionProto func_proto_;
std::int64_t function_id_;
SignatureProto signature_proto_;
InstructionProto instr1_proto_;
InstructionProto instr2_proto_;
InstructionProto instr3_proto_;
};

TEST_F(PassClassTest, VerifyPasses) {
int num_passes = verify_all_passes();
// Control reaches here if compilation succeeds
EXPECT_EQ(num_passes, Total_Num_Passes);
}

TEST_F(PassClassTest, SimpleFunctionPass) {
std::unordered_map<std::int64_t, Function*> func_index;
Function func(func_proto_, func_index);
auto* a_pass = make_pass(ATest);
EXPECT_EQ(a_pass->run(&func), true);
LOG(INFO) << "A simple function pass detecting dead instructions.";

}


}
}