Skip to content

Commit

Permalink
initial meta op interface and build implement (PaddlePaddle#14)
Browse files Browse the repository at this point in the history
* initial meta op interface and build implement

* clear comment code

* add kBrocast api and implement

* add symbolization subdirectory and move releated files

* meta op implement and basic shape inference

* add unit test for meta_op and shape_inference

* add symbolization namespace under paddle::piano

* update the usage of NoteBuilder and Operand on paddle2piano

* fix compile error

* remove remainder temporarily

* fix compile

* fix compile

* enhance annotations for critical codes
  • Loading branch information
CtfGo committed Sep 1, 2021
1 parent 963ef8f commit 5cd1723
Show file tree
Hide file tree
Showing 20 changed files with 638 additions and 35 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/compiler/paddle2piano/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ cc_test(piano_compile_pass_test SRCS piano_compile_pass_tester.cc DEPS piano_com
cc_library(piano_op_registry SRCS piano_op_registry.cc DEPS framework_proto op_registry note_proto piano_data_description)
cc_test(piano_op_registry_test SRCS piano_op_registry_test.cc DEPS piano_op_registry operator op_registry)

cc_library(piano_op_kernel_context SRCS piano_op_kernel_context.cc DEPS piano_op_registry proto_desc note_builder)
cc_library(piano_op_kernel_context SRCS piano_op_kernel_context.cc DEPS piano_op_registry proto_desc piano_symbolization_builder)
cc_test(piano_op_kernel_context_test SRCS piano_op_kernel_context_test.cc DEPS piano_op_kernel_context op_registry)
5 changes: 3 additions & 2 deletions paddle/fluid/compiler/paddle2piano/piano_op_kernel_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ framework::Attribute PianoOpKernelContext::GetAttr(
return op_->GetAttr(name);
}

Operand PianoOpKernelContext::GetInput(const std::string& name) const {
symbolization::Operand PianoOpKernelContext::GetInput(
const std::string& name) const {
PADDLE_ENFORCE_EQ(
HasInput(name), true,
platform::errors::NotFound("Input %s is not found in op %s.",
Expand All @@ -53,7 +54,7 @@ Operand PianoOpKernelContext::GetInput(const std::string& name) const {
}

void PianoOpKernelContext::SetOutput(const std::string& name,
const Operand& op) const {
const symbolization::Operand& op) const {
PADDLE_ENFORCE_EQ(
op_->HasOutput(name), true,
platform::errors::NotFound("Output %s is not found in op %s.",
Expand Down
19 changes: 10 additions & 9 deletions paddle/fluid/compiler/paddle2piano/piano_op_kernel_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ limitations under the License. */

#include "paddle/fluid/compiler/paddle2piano/piano_op_registry.h"
#include "paddle/fluid/compiler/paddle2piano/piano_scope.h"
#include "paddle/fluid/compiler/piano/note_builder.h"
#include "paddle/fluid/compiler/piano/symbolization/note_builder.h"
#include "paddle/fluid/framework/op_desc.h"

namespace paddle {
Expand All @@ -30,28 +30,29 @@ namespace piano {
// lower level of piano note IR.
// "OpDesc" is the operator information.
// "PianoScope" is an association of a name to operand.
// "builder" is the operand's NoteBuilder.
// "builder" is the operand's symbolization::NoteBuilder.
class PianoOpKernelContext {
public:
PianoOpKernelContext(const framework::OpDesc* op_desc, PianoScope* scope,
NoteBuilder* builder)
symbolization::NoteBuilder* builder)
: op_(op_desc), scope_(scope), builder_(builder) {}

// cannot returning reference to temporary
std::string Type() const { return op_->Type(); }

NoteBuilder* Builder() const { return builder_; }
symbolization::NoteBuilder* Builder() const { return builder_; }

bool HasInput(const std::string& name) const {
return op_->Inputs().find(name) != op_->Inputs().end();
}

Operand GetInput(const std::string& name) const;
symbolization::Operand GetInput(const std::string& name) const;

// Map the outputs's operand into scope, the operand is created by
// NoteBuilder, and be careful the output name must existed in op's
// outputs.
void SetOutput(const std::string& name, const Operand& op) const;
// symbolization::NoteBuilder, and be careful the output
// name must existed in op's outputs.
void SetOutput(const std::string& name,
const symbolization::Operand& op) const;

const std::unordered_set<note::ElementTypeProto>& DataTypes() const {
return PianoOpRegistry::PianoOpDataTypes(Type());
Expand All @@ -72,7 +73,7 @@ class PianoOpKernelContext {
private:
const framework::OpDesc* op_;
mutable PianoScope* scope_;
NoteBuilder* builder_;
symbolization::NoteBuilder* builder_;
};

} // namespace piano
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/compiler/paddle2piano/piano_op_kernel.h"
#include "paddle/fluid/compiler/paddle2piano/piano_op_registry.h"
#include "paddle/fluid/compiler/paddle2piano/piano_scope.h"
#include "paddle/fluid/compiler/piano/note_builder.h"
#include "paddle/fluid/compiler/piano/symbolization/note_builder.h"
#include "paddle/fluid/framework/program_desc.h"

namespace paddle {
Expand Down Expand Up @@ -114,7 +114,7 @@ TEST(PianoContextTest, scope) {

// add operand
std::string name1 = "op1";
Operand op1;
symbolization::Operand op1;
scope.SetOperand(name1, op1);

ASSERT_TRUE(scope.HasLocalOperand(name1));
Expand All @@ -130,7 +130,7 @@ TEST(PianoContextTest, scope) {
ASSERT_FALSE(scope.HasKid(tmp_scope.get()));

std::string name2 = "op2";
Operand op2;
symbolization::Operand op2;
tmp_scope->SetOperand(name2, op2);

ASSERT_FALSE(scope.HasLocalOperand(name2));
Expand Down Expand Up @@ -165,10 +165,10 @@ TEST(PianoContextTest, basic) {

// create scope and NoteBuilder
PianoScope scope;
Operand op_x;
symbolization::Operand op_x;
scope.SetOperand("X", op_x);

NoteBuilder builder("test_expand");
symbolization::NoteBuilder builder("test_expand");

// create PianoOpKernelContext
PianoOpKernelContext context(op, &scope, &builder);
Expand All @@ -186,7 +186,7 @@ TEST(PianoContextTest, basic) {
ASSERT_ANY_THROW(ctx.GetInput("Y"));

// test output
Operand op_out;
symbolization::Operand op_out;
ASSERT_NO_THROW(ctx.SetOutput("Out", op_out));
ASSERT_ANY_THROW(ctx.SetOutput("Y", op_out));

Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/compiler/paddle2piano/piano_op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ limitations under the License. */
#include <vector>

#include "paddle/fluid/compiler/piano/note/note.pb.h"
#include "paddle/fluid/compiler/piano/note_builder.h"
#include "paddle/fluid/compiler/piano/symbolization/note_builder.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
Expand All @@ -41,7 +41,7 @@ class PianoOpRegistry final {
// `name` is the backend name. `supported_types` is data type list,
// this backend can only accept the data type in list. `filter_func` is
// a function, return false if the backend refuse this op.
using BackendFilterFunc = bool (*)(Operand*);
using BackendFilterFunc = bool (*)(symbolization::Operand*);
static void RegisterBackend(
const std::string& backend_name,
const std::unordered_set<note::ElementTypeProto>& supported_types,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ std::unordered_set<note::ElementTypeProto> TestDatatypes() {
return supported_types;
}

bool TestFilterFunc(Operand* op) {
bool TestFilterFunc(symbolization::Operand* op) {
// TODO(jiangcheng05) : fill some change of Operand
return true;
}
Expand Down
10 changes: 5 additions & 5 deletions paddle/fluid/compiler/paddle2piano/piano_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ limitations under the License. */
#include <unordered_map>
#include <vector>

#include "paddle/fluid/compiler/piano/note_builder.h"
#include "paddle/fluid/compiler/piano/symbolization/note_builder.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
Expand Down Expand Up @@ -84,7 +84,7 @@ class PianoScope {
}

// return the operand in local scope if founded.
Operand GetLocalOperand(const std::string& name) const {
symbolization::Operand GetLocalOperand(const std::string& name) const {
PADDLE_ENFORCE_EQ(
HasLocalOperand(name), true,
platform::errors::NotFound("Operand %s not founded in scope %p",
Expand All @@ -93,7 +93,7 @@ class PianoScope {
}

// return the operand in local scope or its ancestor scope if founded
Operand GetOperand(const std::string& name) const {
symbolization::Operand GetOperand(const std::string& name) const {
if (HasLocalOperand(name)) {
return operands_.at(name);
}
Expand All @@ -104,7 +104,7 @@ class PianoScope {
}

// insert the operand into local scope
void SetOperand(const std::string& name, const Operand& op) {
void SetOperand(const std::string& name, const symbolization::Operand& op) {
PADDLE_ENFORCE_EQ(HasOperand(name), false,
platform::errors::AlreadyExists(
"Operand %s already existed in scope %p.",
Expand All @@ -125,7 +125,7 @@ class PianoScope {

DISABLE_COPY_AND_ASSIGN(PianoScope);

std::unordered_map<std::string, Operand> operands_;
std::unordered_map<std::string, symbolization::Operand> operands_;

const PianoScope* parent_;
mutable std::vector<std::unique_ptr<PianoScope>> kids_;
Expand Down
4 changes: 1 addition & 3 deletions paddle/fluid/compiler/piano/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
add_subdirectory(backends)
add_subdirectory(note)
add_subdirectory(symbolization)

cc_library(piano_data_description SRCS layout.cc shape.cc DEPS string_helper note_proto)
cc_test(piano_layout_test SRCS layout_test.cc DEPS piano_data_description)
cc_test(piano_shape_test SRCS shape_test.cc DEPS piano_data_description)

cc_library(note_builder SRCS note_builder.cc DEPS string_helper note_opcode piano_data_description)
cc_test(note_builder_test SRCS note_builder_test.cc DEPS note_builder)
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ target_link_libraries(nvptx_ir_emitter_test ${LLVM_LIBS})
cc_library(llvm_compiler SRCS llvm_compiler.cc DEPS nvptx_ir_emitter nvptx_primitive_ir_emitter note_ir llvm)
cc_library(nvptx_compiler SRCS nvptx_compiler.cc DEPS llvm_compiler llvm dynload_cuda)
cc_test(nvptx_compiler_test SRCS nvptx_compiler_test.cc
DEPS nvptx_compiler note_ir note_builder)
DEPS nvptx_compiler note_ir piano_symbolization_builder)
target_link_libraries(nvptx_compiler_test ${LLVM_LIBS})
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include "paddle/fluid/compiler/piano/note/instruction.h"
#include "paddle/fluid/compiler/piano/note/note.pb.h"
#include "paddle/fluid/compiler/piano/note/opcode.h"
#include "paddle/fluid/compiler/piano/note_builder.h"
#include "paddle/fluid/compiler/piano/symbolization/note_builder.h"
#include "paddle/fluid/platform/gpu_info.h"

namespace paddle {
Expand Down Expand Up @@ -95,7 +95,7 @@ TEST(NvptxCompiler, Apply) {
platform::SetDeviceId(0);

// note builder
NoteBuilder note_builder("test_note_builder");
symbolization::NoteBuilder note_builder("test_note_builder");
{
note::InstructionProto a_proto, b_proto, c_proto;
a_proto.set_name("A");
Expand All @@ -120,7 +120,7 @@ TEST(NvptxCompiler, Apply) {
c_shape->add_dimensions(32);

// build note module
std::vector<Operand> ops;
std::vector<symbolization::Operand> ops;
ops.push_back(note_builder.AppendInstruction(std::move(a_proto),
note::OpCode::kConstant, {}));
ops.push_back(note_builder.AppendInstruction(std::move(b_proto),
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/compiler/piano/symbolization/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
cc_library(piano_symbolization_builder SRCS note_builder.cc DEPS string_helper note_opcode piano_data_description)
cc_test(symbolization_note_builder_test SRCS note_builder_test.cc DEPS piano_symbolization_builder)

cc_library(piano_symbolization_meat_op SRCS meta_op.cc shape_inference.cc DEPS note_opcode piano_symbolization_builder note_template_util)
cc_test(symbolization_meta_op_test SRCS meta_op_test.cc DEPS piano_symbolization_meat_op)
cc_test(symbolization_shape_inference_test SRCS shape_inference_test.cc DEPS piano_symbolization_meat_op)
139 changes: 139 additions & 0 deletions paddle/fluid/compiler/piano/symbolization/meta_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/compiler/piano/symbolization/meta_op.h"
#include <algorithm>
#include <numeric>
#include <utility>
#include "paddle/fluid/compiler/piano/note/note.pb.h"
#include "paddle/fluid/compiler/piano/symbolization/shape_inference.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace piano {
namespace symbolization {

Operand Parameter(NoteBuilder* builder, int64_t parameter_index,
const Shape& shape, const std::string& name) {
PADDLE_ENFORCE_GE(parameter_index, 0, platform::errors::InvalidArgument(
"Parameter_index should be >= 0"));
note::InstructionProto instr;
instr.set_parameter_index(parameter_index);
instr.set_name(name);
*instr.mutable_shape() = shape.ToProto();
return builder->AppendInstruction(std::move(instr), note::OpCode::kParameter,
{});
}

Operand Broadcast(Operand x, const std::vector<int64_t>& out_dimensions,
const std::vector<int64_t>& dimensions_alignment) {
// generate a default alignment for numpy's like broadcast operation
std::vector<int64_t> to_right_alignment;
if (dimensions_alignment.empty()) {
PADDLE_ENFORCE_LE(x.Shape().Rank(), out_dimensions.size(),
platform::errors::InvalidArgument(
"Rank of operand should be less than output"));
to_right_alignment.resize(x.Shape().Rank());
std::iota(to_right_alignment.begin(), to_right_alignment.end(), 0);
auto gap_len = out_dimensions.size() - x.Shape().Rank();
// original operand is aligned to the rightmost of out_dimensions
std::transform(to_right_alignment.begin(), to_right_alignment.end(),
to_right_alignment.begin(),
[gap_len](const auto& x) { return x + gap_len; });
}

const auto& alignment_array =
dimensions_alignment.empty() ? to_right_alignment : dimensions_alignment;
auto&& result_shape =
InferBroadcastShape(x.Shape(), out_dimensions, alignment_array);

note::InstructionProto instr;
*instr.mutable_shape() = result_shape.ToProto();
// fill the alignment array to kBroadcast attribute
auto* attrs_map = instr.mutable_attrs();
note::AttrValueProto attr_value;
note::PopulateAttrValueProto(alignment_array, &attr_value);
(*attrs_map)[note::kBroadcastAlignment] = attr_value;
return x.Builder()->AppendInstruction(std::move(instr),
note::OpCode::kBroadcast, {x});
}

Operand UnaryOp(note::OpCode unop, Operand x) {
note::InstructionProto instr;
auto&& shape = InferUnaryOpShape(unop, x.Shape());
*instr.mutable_shape() = shape.ToProto();
return x.Builder()->AppendInstruction(std::move(instr), unop, {x});
}

Operand operator-(Operand x) { return Neg(x); }
Operand operator~(Operand x) { return Not(x); }
Operand Neg(Operand x) { return UnaryOp(note::OpCode::kNegative, x); }
Operand Not(Operand x) { return UnaryOp(note::OpCode::kNot, x); }

Operand BinaryOp(note::OpCode binop, Operand x, Operand y) {
// add broadcast if shape of the operands are not same
x = x.Shape().Rank() < y.Shape().Rank() ? Broadcast(x, y.Shape().dimensions())
: x;
y = y.Shape().Rank() < x.Shape().Rank() ? Broadcast(y, x.Shape().dimensions())
: y;
// ensure shape are equal
PADDLE_ENFORCE_EQ(x.Shape(), y.Shape(),
platform::errors::InvalidArgument(
"Shape of operands should be euqal on Binary Op"));

note::InstructionProto instr;
auto&& shape = InferBinaryOpShape(binop, x.Shape(), y.Shape());
*instr.mutable_shape() = shape.ToProto();
return x.Builder()->AppendInstruction(std::move(instr), binop, {x, y});
}

Operand operator+(Operand x, Operand y) { return Add(x, y); }
Operand operator-(Operand x, Operand y) { return Sub(x, y); }
Operand operator*(Operand x, Operand y) { return Mul(x, y); }
Operand operator/(Operand x, Operand y) { return Div(x, y); }
Operand operator&(Operand x, Operand y) { return And(x, y); }
Operand operator|(Operand x, Operand y) { return Or(x, y); }
Operand operator^(Operand x, Operand y) { return Xor(x, y); }

Operand Add(Operand x, Operand y) { return BinaryOp(note::OpCode::kAdd, x, y); }

Operand Sub(Operand x, Operand y) {
return BinaryOp(note::OpCode::kSubtract, x, y);
}

Operand Mul(Operand x, Operand y) {
return BinaryOp(note::OpCode::kMultiply, x, y);
}

Operand Div(Operand x, Operand y) {
return BinaryOp(note::OpCode::kDivide, x, y);
}

Operand Max(Operand x, Operand y) {
return BinaryOp(note::OpCode::kMaximum, x, y);
}

Operand Min(Operand x, Operand y) {
return BinaryOp(note::OpCode::kMinimum, x, y);
}

Operand And(Operand x, Operand y) { return BinaryOp(note::OpCode::kAnd, x, y); }

Operand Or(Operand x, Operand y) { return BinaryOp(note::OpCode::kOr, x, y); }

Operand Xor(Operand x, Operand y) { return BinaryOp(note::OpCode::kXor, x, y); }

} // namespace symbolization
} // namespace piano
} // namespace paddle
Loading

0 comments on commit 5cd1723

Please sign in to comment.