Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for the Module class. #33

Merged
merged 11 commits into from
Aug 23, 2021
2 changes: 1 addition & 1 deletion paddle/fluid/compiler/piano/note/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ cc_test(note_opcode_test SRCS opcode_test.cc DEPS note_opcode)
proto_library(note_proto SRCS note.proto)
target_compile_options(note_proto PUBLIC "-Wno-extra")

cc_library(note_ir SRCS instruction.cc function.cc DEPS note_opcode note_proto piano_data_description)
cc_library(note_ir SRCS instruction.cc function.cc module.cc DEPS note_opcode note_proto piano_data_description)
cc_test(note_ir_test SRCS note_ir_test.cc DEPS note_ir)
19 changes: 17 additions & 2 deletions paddle/fluid/compiler/piano/note/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,27 @@ Function::Function(
instr->set_parent(this);
// set parameter(input) instructions field
if (instr->opcode() == OpCode::kParameter) {
PADDLE_ENFORCE_EQ(
instr->valid_parameter_index(), true,
platform::errors::PreconditionNotMet(
"The parameter instruction %s doesn't have a valid index.",
instr->name()));

param_instrs_.push_back(instr.get());
}
instr_index[instr_proto.id()] = instr.get();
inverted_index[instr.get()] = instr_proto.id();

auto instr_id = instr_proto.id();
PADDLE_ENFORCE_EQ(
instr_index.count(instr_id), 0,
platform::errors::PreconditionNotMet(
"The global id (%ld) of Instruction %s is the same as the previous "
"Instruction %s.",
instr_id, instr->name(), instr_index[instr_id]->name()));
instr_index[instr_id] = instr.get();
inverted_index[instr.get()] = instr_id;
instructions_.emplace_back(std::move(instr));
}

PADDLE_ENFORCE_EQ(
proto.return_id() >= 0 && instr_index.count(proto.return_id()), true,
platform::errors::PreconditionNotMet(
Expand Down
87 changes: 62 additions & 25 deletions paddle/fluid/compiler/piano/note/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,18 @@ limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "boost/range/iterator_range.hpp"
#include "paddle/fluid/compiler/piano/note/instruction.h"
#include "paddle/fluid/compiler/piano/note/note.pb.h"
#include "paddle/fluid/compiler/piano/note/type_traits.h"
#include "paddle/fluid/compiler/piano/shape.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace piano {
namespace note {

class Instruction;
// class Module;
class Module;

class Function {
public:
Expand All @@ -46,47 +49,82 @@ class Function {
const std::string &name() const { return name_; }

// return instructions owned by this function
std::vector<Instruction *> instructions() const {
std::vector<Instruction *> instrs;
instrs.reserve(instructions_.size());
std::transform(
instructions_.cbegin(), instructions_.cend(),
std::back_inserter(instrs),
[](const std::unique_ptr<Instruction> &instr) { return instr.get(); });
return instrs;
// for(Instruction &instr : function->instructions()){...}
auto instructions() const {
using IteratorT = decltype(instructions_.cbegin());
return boost::make_iterator_range(
UnboxingIterator<IteratorT>{instructions_.cbegin()},
UnboxingIterator<IteratorT>{instructions_.cend()});
}

const Instruction *instruction(std::int64_t idx) const {
return instructions_.at(idx).get();
// return an instruction included in this function by the given index
Instruction *instruction(std::int64_t idx) const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

返回的instruction调用方有可能会修改吗

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这是暂时不确定,我这么改之后const对象和非const对象返回的instruction在外部都是可以修改的,给了比较大的权限。你觉得需要限制吗?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要修改啊

PADDLE_ENFORCE_EQ(
idx >= 0 && idx < static_cast<std::int64_t>(instructions_.size()), true,
platform::errors::PreconditionNotMet("Invalid index value %ld. Its "
"value should between 0(include) "
"and %zu(exclude).",
idx, instructions_.size()));
PADDLE_ENFORCE_NOT_NULL(
instructions_[idx].get(),
platform::errors::PreconditionNotMet(
"The instruction %ld should not be null.", idx));
return instructions_[idx].get();
}

Instruction *mutable_instruction(std::int64_t idx) {
return instructions_.at(idx).get();
}

// return the function signature
// return the immutable function signature
const Signature &signature() const { return signature_; }

// return the mutable function signature
Signature *mutable_signature() { return &signature_; }

// return the globally unique id of this function
std::int64_t global_id() const { return global_id_; }

// return the returned instruction of this function
const Instruction *return_instr() const { return return_instr_; }
const Instruction &return_instr() const {
PADDLE_ENFORCE_NOT_NULL(return_instr_,
platform::errors::PreconditionNotMet(
"The return instruction should not be null."));
return *return_instr_;
}

// const Module *parent() const { return parent_; }
// return the immutable module which includes this function
const Module &parent() const {
PADDLE_ENFORCE_NOT_NULL(parent_, platform::errors::PreconditionNotMet(
"The parent_(Module) of this function "
"is null, please set it first."));
return *parent_;
}

// Module *mutable_parent() { return parent_; }
// return the mutable module which includes this function
Module *mutable_parent() {
PADDLE_ENFORCE_NOT_NULL(parent_, platform::errors::PreconditionNotMet(
"The parent_(Module) of this function "
"is null, please set it first."));
return parent_;
}

// void set_parent(Module *module) { parent_ = module; }
// set the module in which this function resides
void set_parent(Module *mod) { parent_ = mod; }

const std::vector<Instruction *> &param_instrs() const {
return param_instrs_;
}

// return parameter instructions of this function
const Instruction *param_instr(std::int64_t idx) const {
return param_instrs_.at(idx);
const Instruction &param_instr(std::int64_t idx) const {
PADDLE_ENFORCE_EQ(
idx >= 0 && idx < static_cast<std::int64_t>(param_instrs_.size()), true,
platform::errors::PreconditionNotMet("Invalid index value %ld. Its "
"value should between 0(include) "
"and %zu(exclude).",
idx, param_instrs_.size()));
PADDLE_ENFORCE_NOT_NULL(
param_instrs_[idx],
platform::errors::PreconditionNotMet(
"The parameter instruction %ld should not be null.", idx));
return *param_instrs_[idx];
}

// return the parameter(input) number of this function
Expand All @@ -104,9 +142,8 @@ class Function {
// the returned instruction of this function
Instruction *return_instr_;

// TODO(wzzju): Add Module class.
// the module where this function is contained
// Module *parent_{nullptr};
Module *parent_{nullptr};

// parameter instructions of this function,
// which denote input parameters
Expand Down
21 changes: 8 additions & 13 deletions paddle/fluid/compiler/piano/note/instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Instruction::Instruction(

// add control dependency
for (auto id : proto.control_predecessor_ids()) {
PADDLE_ENFORCE_EQ(instr_index.at(id)->parent(), parent(),
PADDLE_ENFORCE_EQ(instr_index.at(id)->mutable_parent(), mutable_parent(),
platform::errors::PreconditionNotMet(
"The instruction and its dependent instruction are "
"not in the same function."));
Expand All @@ -68,16 +68,9 @@ Instruction::Instruction(
}
}

// set parameter number
if (proto.has_parameter_number()) {
PADDLE_ENFORCE_EQ(proto.parameter_number(), operands_.size(),
platform::errors::PreconditionNotMet(
"The number of operands(%ld) is not equal to the "
"parameter_number(%zu) in proto.",
proto.parameter_number(), operands_.size()));
parameter_number_ = proto.parameter_number();
} else {
parameter_number_ = static_cast<std::int64_t>(operands_.size());
// set parameter index
if (proto.has_parameter_index()) {
parameter_index_ = proto.parameter_index();
}

// set attrs
Expand All @@ -94,7 +87,9 @@ InstructionProto Instruction::ToProto() const {
proto.set_name(name_);
proto.set_opcode(GetOpName(opcode_));
proto.set_id(global_id_);
proto.set_parameter_number(parameter_number_);
if (valid_parameter_index()) {
proto.set_parameter_index(parameter_index_);
}

// serialize shape info
*proto.mutable_shape() = shape_.ToProto();
Expand Down Expand Up @@ -166,7 +161,7 @@ std::string Instruction::ToString() const {
string::join_strings(attr_strs, ", ").c_str());
}

void Instruction::Accept(backends::NoteVisitorBase* visitor) {
void Instruction::Accept(backends::NoteVisitorBase* visitor) const {
switch (opcode_) {
#define HANDLE_VISIT(enum_id, op_name, ...) \
case OpCode::k##enum_id: \
Expand Down
63 changes: 49 additions & 14 deletions paddle/fluid/compiler/piano/note/instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,48 +50,79 @@ class Instruction {

std::string ToString() const;

void Accept(backends::NoteVisitorBase *visitor);

void Accept(backends::NoteVisitorBase *visitor) const {
return const_cast<Instruction *>(this)->Accept(visitor);
}
void Accept(backends::NoteVisitorBase *visitor) const;

// return the name of this instruction
const std::string &name() const { return name_; }

// return the opcode of this instruction
OpCode opcode() const { return opcode_; }

// return the immutable result shape of this instruction
const Shape &shape() const { return shape_; }

// return the mutable result shape of this instruction
Shape *mutable_shape() { return &shape_; }

const Function *parent() const { return parent_; }
// return the immutable function which includes this instruction
const Function &parent() const {
PADDLE_ENFORCE_NOT_NULL(parent_,
platform::errors::PreconditionNotMet(
"The parent_(Function) of this instruction is "
"null, please set it first."));
return *parent_;
}

Function *mutable_parent() { return parent_; }
// return the mutable function which includes this instruction
Function *mutable_parent() {
PADDLE_ENFORCE_NOT_NULL(parent_,
platform::errors::PreconditionNotMet(
"The parent_(Function) of this instruction is "
"null, please set it first."));
return parent_;
}

// set the function in which this instruction resides
void set_parent(Function *func) { parent_ = func; }

// return the globally unique id of this instruction
std::int64_t global_id() const { return global_id_; }

// return instruction operands
const std::vector<Instruction *> &operands() const { return operands_; }

const Instruction *operand(std::int64_t idx) const {
return operands_.at(idx);
const Instruction &operand(std::int64_t idx) const {
PADDLE_ENFORCE_EQ(
idx >= 0 && idx < static_cast<std::int64_t>(operands_.size()), true,
platform::errors::PreconditionNotMet("Invalid index value %ld. Its "
"value should between 0(include) "
"and %zu(exclude).",
idx, operands_.size()));
PADDLE_ENFORCE_NOT_NULL(operands_[idx],
platform::errors::PreconditionNotMet(
"operand %ld should not be null.", idx));
return *operands_[idx];
}

Instruction *mutable_operand(std::int64_t idx) const {
Instruction *mutable_operand(std::int64_t idx) {
PADDLE_ENFORCE_EQ(
idx >= 0 && idx < static_cast<std::int64_t>(operands_.size()), true,
platform::errors::PreconditionNotMet("Invalid index value %ld. Its "
"value should between 0(include) "
"and %zu(exclude).",
idx, operands_.size()));
PADDLE_ENFORCE_NOT_NULL(operands_[idx],
platform::errors::PreconditionNotMet(
"operand %ld should not be null.", idx));
return operands_.at(idx);
return operands_[idx];
}

// return the control predecessors of this instruction
const std::vector<Instruction *> &ctrl_predecessors() const {
return ctrl_predecessors_;
}

// return the control successors of this instruction
const std::vector<Instruction *> &ctrl_successors() const {
return ctrl_successors_;
}
Expand All @@ -101,7 +132,11 @@ class Instruction {
return call_functions_;
}

std::int64_t parameter_number() const { return parameter_number_; }
// return the input index of this instruction
std::int64_t parameter_index() const { return parameter_index_; }

// only the Parameter instruction has a valid parameter index
bool valid_parameter_index() const { return parameter_index_ != -1; }

// return attributes of this instruction
const MapType &attrs() const { return attrs_; }
Expand Down Expand Up @@ -148,8 +183,8 @@ class Instruction {
std::vector<Instruction *> ctrl_successors_;
// functions called directly by this instruction
std::vector<Function *> call_functions_;
// the parameter number of this instruction
std::int64_t parameter_number_;
// the input index of this instruction
std::int64_t parameter_index_{-1};
// attributes belongs to this instruction
MapType attrs_;
// the function where this instruction is contained
Expand Down
Loading