From 3cb6f65e23474346cf2cecc9eea7d473934356e8 Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Sat, 23 Oct 2021 11:37:42 +0800 Subject: [PATCH] Add transformer of paddle desc and cinn desc (#36100) * add transformer of paddle desc and cinn desc * change LOG(FATAL) to PADDLE_THROW for ci * full error imformation for ci * fix some problem as review advice * fix some bug * move vat type utils to tansform_desc header file * add if NOT WITH_CINN control whether compile * build_strategy check whether open WITH_CINN * add control WITH_CINN in cmake --- .../framework/paddle2cinn/CMakeLists.txt | 5 + .../framework/paddle2cinn/transform_desc.cc | 348 ++++++++++++++++++ .../framework/paddle2cinn/transform_desc.h | 79 ++++ .../paddle2cinn/transform_desc_test.cc | 236 ++++++++++++ 4 files changed, 668 insertions(+) create mode 100644 paddle/fluid/framework/paddle2cinn/transform_desc.cc create mode 100644 paddle/fluid/framework/paddle2cinn/transform_desc.h create mode 100644 paddle/fluid/framework/paddle2cinn/transform_desc_test.cc diff --git a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt index 4a65333217727..d1c17c7a70953 100644 --- a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt +++ b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt @@ -3,6 +3,11 @@ cc_library(cinn_compiled_object SRCS cinn_compiled_object.cc DEPS feed_fetch_met cc_library(cinn_runner SRCS cinn_runner.cc DEPS cinn_cache_key cinn_compiled_object feed_fetch_method graph lod_tensor scope) cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector) +if (WITH_CINN) + cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn) + cc_test(test_transform_desc SRCS transform_desc_test.cc DEPS transform_desc) +endif() + cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key) cc_test(cinn_runner_test SRCS cinn_runner_test.cc DEPS cinn_runner proto_desc) cc_test(cinn_compiled_object_test SRCS cinn_compiled_object_test.cc DEPS cinn_compiled_object) diff --git a/paddle/fluid/framework/paddle2cinn/transform_desc.cc b/paddle/fluid/framework/paddle2cinn/transform_desc.cc new file mode 100644 index 0000000000000..52b1395c732ac --- /dev/null +++ b/paddle/fluid/framework/paddle2cinn/transform_desc.cc @@ -0,0 +1,348 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/paddle2cinn/transform_desc.h" + +#include +#include +#include + +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace paddle2cinn { + +using PbVarType = framework::proto::VarType; +namespace cpp = ::cinn::frontend::paddle::cpp; + +::cinn::frontend::paddle::cpp::VarDescAPI::Type TransformVarTypeToCinn( + const ::paddle::framework::proto::VarType::Type &type) { +#define SET_TYPE_CASE_ITEM(type__) \ + case ::paddle::framework::proto::VarType::type__: \ + return ::cinn::frontend::paddle::cpp::VarDescAPI::Type::type__; \ + break; + + switch (type) { + SET_TYPE_CASE_ITEM(LOD_TENSOR); + SET_TYPE_CASE_ITEM(LOD_TENSOR_ARRAY); + SET_TYPE_CASE_ITEM(LOD_RANK_TABLE); + SET_TYPE_CASE_ITEM(SELECTED_ROWS); + SET_TYPE_CASE_ITEM(FEED_MINIBATCH); + SET_TYPE_CASE_ITEM(FETCH_LIST); + SET_TYPE_CASE_ITEM(STEP_SCOPES); + SET_TYPE_CASE_ITEM(PLACE_LIST); + SET_TYPE_CASE_ITEM(READER); + default: + PADDLE_THROW(platform::errors::NotFound("Cannot found var type")); + } +#undef SET_TYPE_CASE_ITEM +} + +::paddle::framework::proto::VarType::Type TransformVarTypeFromCinn( + const ::cinn::frontend::paddle::cpp::VarDescAPI::Type &type) { +#define SET_TYPE_CASE_ITEM(type__) \ + case ::cinn::frontend::paddle::cpp::VarDescAPI::Type::type__: \ + return ::paddle::framework::proto::VarType::type__; \ + break; + + switch (type) { + SET_TYPE_CASE_ITEM(LOD_TENSOR); + SET_TYPE_CASE_ITEM(LOD_TENSOR_ARRAY); + SET_TYPE_CASE_ITEM(LOD_RANK_TABLE); + SET_TYPE_CASE_ITEM(SELECTED_ROWS); + SET_TYPE_CASE_ITEM(FEED_MINIBATCH); + SET_TYPE_CASE_ITEM(FETCH_LIST); + SET_TYPE_CASE_ITEM(STEP_SCOPES); + SET_TYPE_CASE_ITEM(PLACE_LIST); + SET_TYPE_CASE_ITEM(READER); + default: + PADDLE_THROW(platform::errors::NotFound("Cannot found var type")); + } +#undef SET_TYPE_CASE_ITEM +} + +::cinn::frontend::paddle::cpp::VarDescAPI::Type TransformVarDataTypeToCinn( + const ::paddle::framework::proto::VarType::Type &type) { +#define SET_DATA_TYPE_CASE_ITEM(type__) \ + case ::paddle::framework::proto::VarType::type__: \ + return ::cinn::frontend::paddle::cpp::VarDescAPI::Type::type__; \ + break; + + switch (type) { + SET_DATA_TYPE_CASE_ITEM(BOOL); + SET_DATA_TYPE_CASE_ITEM(SIZE_T); + SET_DATA_TYPE_CASE_ITEM(UINT8); + SET_DATA_TYPE_CASE_ITEM(INT8); + SET_DATA_TYPE_CASE_ITEM(INT16); + SET_DATA_TYPE_CASE_ITEM(INT32); + SET_DATA_TYPE_CASE_ITEM(INT64); + SET_DATA_TYPE_CASE_ITEM(FP16); + SET_DATA_TYPE_CASE_ITEM(FP32); + SET_DATA_TYPE_CASE_ITEM(FP64); + default: + PADDLE_THROW(platform::errors::NotFound("Cannot found var data type")); + } +#undef SET_DATA_TYPE_CASE_ITEM +} + +::paddle::framework::proto::VarType::Type TransformVarDataTypeFromCpp( + const ::cinn::frontend::paddle::cpp::VarDescAPI::Type &type) { +#define SET_DATA_TYPE_CASE_ITEM(type__) \ + case ::cinn::frontend::paddle::cpp::VarDescAPI::Type::type__: \ + return ::paddle::framework::proto::VarType::type__; \ + break; + + switch (type) { + SET_DATA_TYPE_CASE_ITEM(BOOL); + SET_DATA_TYPE_CASE_ITEM(SIZE_T); + SET_DATA_TYPE_CASE_ITEM(UINT8); + SET_DATA_TYPE_CASE_ITEM(INT8); + SET_DATA_TYPE_CASE_ITEM(INT16); + SET_DATA_TYPE_CASE_ITEM(INT32); + SET_DATA_TYPE_CASE_ITEM(INT64); + SET_DATA_TYPE_CASE_ITEM(FP16); + SET_DATA_TYPE_CASE_ITEM(FP32); + SET_DATA_TYPE_CASE_ITEM(FP64); + default: + PADDLE_THROW(platform::errors::NotFound("Cannot found var data type")); + } +#undef SET_DATA_TYPE_CASE_ITEM +} + +void TransformVarDescToCinn(framework::VarDesc *pb_desc, + cpp::VarDesc *cpp_desc) { + cpp_desc->SetName(pb_desc->Name()); + cpp_desc->SetType(TransformVarTypeToCinn(pb_desc->GetType())); + cpp_desc->SetPersistable(pb_desc->Persistable()); + if (pb_desc->Name() != "feed" && pb_desc->Name() != "fetch") { + cpp_desc->SetDataType(TransformVarDataTypeToCinn(pb_desc->GetDataType())); + cpp_desc->SetShape(pb_desc->GetShape()); + } +} + +void TransformVarDescFromCinn(const cpp::VarDesc &cpp_desc, + framework::VarDesc *pb_desc) { + pb_desc->Proto()->Clear(); + pb_desc->SetName(cpp_desc.Name()); + pb_desc->SetType(TransformVarTypeFromCinn(cpp_desc.GetType())); + pb_desc->SetPersistable(cpp_desc.Persistable()); + if (cpp_desc.Name() != "feed" && cpp_desc.Name() != "fetch") { + pb_desc->SetShape(cpp_desc.GetShape()); + pb_desc->SetDataType(TransformVarDataTypeFromCpp(cpp_desc.GetDataType())); + } +} + +/// For OpDesc transform +void OpInputsToCinn(framework::OpDesc *pb_desc, cpp::OpDesc *cpp_desc) { + for (const std::string ¶m : pb_desc->InputNames()) { + cpp_desc->SetInput(param, pb_desc->Input(param)); + } +} + +void OpInputsFromCinn(const cpp::OpDesc &cpp_desc, framework::OpDesc *pb_desc) { + pb_desc->MutableInputs()->clear(); + for (const std::string ¶m : cpp_desc.InputArgumentNames()) { + pb_desc->SetInput(param, cpp_desc.Input(param)); + } +} + +void OpOutputsToCinn(framework::OpDesc *pb_desc, cpp::OpDesc *cpp_desc) { + for (const std::string ¶m : pb_desc->OutputNames()) { + cpp_desc->SetOutput(param, pb_desc->Output(param)); + } +} + +void OpOutputsFromCinn(const cpp::OpDesc &cpp_desc, + framework::OpDesc *pb_desc) { + pb_desc->MutableOutputs()->clear(); + for (const std::string ¶m : cpp_desc.OutputArgumentNames()) { + pb_desc->SetOutput(param, cpp_desc.Output(param)); + } +} + +void OpAttrsToCinn(framework::OpDesc *pb_desc, cpp::OpDesc *cpp_desc) { + using AttrType = framework::proto::AttrType; + auto set_attr = [&](const std::string &name, AttrType type) { + switch (type) { +#define IMPL_ONE(type__, T) \ + case AttrType::type__: \ + cpp_desc->SetAttr(name, pb_desc->GetAttrIfExists(name)); \ + break; + IMPL_ONE(INT, int32_t); + IMPL_ONE(FLOAT, float); + IMPL_ONE(STRING, std::string); + IMPL_ONE(STRINGS, std::vector); + IMPL_ONE(FLOATS, std::vector); + IMPL_ONE(INTS, std::vector); + IMPL_ONE(BOOLEAN, bool); + IMPL_ONE(LONG, int64_t); + IMPL_ONE(LONGS, std::vector); + case AttrType::BLOCK: { + auto i = pb_desc->GetAttrIfExists(name); + cpp_desc->SetAttr(name, i); + break; + } + default: + PADDLE_THROW(platform::errors::NotFound( + "Unsupported attr type %d found ", static_cast(type))); + } + }; +#undef IMPL_ONE + + for (const auto &attr_name : pb_desc->AttrNames()) { + auto type = pb_desc->GetAttrType(attr_name); + set_attr(attr_name, type); + } +} + +void OpAttrsFromCinn(const cpp::OpDesc &cpp_desc, framework::OpDesc *pb_desc) { + pb_desc->MutableAttrMap()->clear(); + using AttrType = cpp::OpDescAPI::AttrType; + auto set_attr = [&](const std::string &name, AttrType type) { + switch (type) { +#define IMPL_ONE(type__, T) \ + case AttrType::type__: \ + pb_desc->SetAttr(name, cpp_desc.GetAttr(name)); \ + break; + IMPL_ONE(INT, int32_t); + IMPL_ONE(FLOAT, float); + IMPL_ONE(STRING, std::string); + IMPL_ONE(STRINGS, std::vector); + IMPL_ONE(FLOATS, std::vector); + IMPL_ONE(INTS, std::vector); + IMPL_ONE(BOOLEAN, bool); + IMPL_ONE(LONG, int64_t); + IMPL_ONE(LONGS, std::vector); + default: + PADDLE_THROW(platform::errors::NotFound( + "Unsupported attr type %d found ", static_cast(type))); + } + }; +#undef IMPL_ONE + + for (const auto &attr_name : cpp_desc.AttrNames()) { + auto type = cpp_desc.GetAttrType(attr_name); + set_attr(attr_name, type); + } +} + +void TransformOpDescToCinn(framework::OpDesc *pb_desc, cpp::OpDesc *cpp_desc) { + cpp_desc->SetType(pb_desc->Type()); + OpInputsToCinn(pb_desc, cpp_desc); + OpOutputsToCinn(pb_desc, cpp_desc); + OpAttrsToCinn(pb_desc, cpp_desc); +} + +void TransformOpDescFromCinn(const cpp::OpDesc &cpp_desc, + framework::OpDesc *pb_desc) { + pb_desc->Proto()->Clear(); + pb_desc->SetType(cpp_desc.Type()); + OpInputsFromCinn(cpp_desc, pb_desc); + OpOutputsFromCinn(cpp_desc, pb_desc); + OpAttrsFromCinn(cpp_desc, pb_desc); +} + +/// For BlockDesc transform +void TransformBlockDescToCinn(framework::BlockDesc *pb_desc, + cpp::BlockDesc *cpp_desc) { + cpp_desc->SetIdx(pb_desc->ID()); + cpp_desc->SetParentIdx(pb_desc->Parent()); + cpp_desc->SetForwardBlockIdx(pb_desc->ForwardBlockID()); + + cpp_desc->ClearOps(); + const auto &all_ops = pb_desc->AllOps(); + for (const auto &op : all_ops) { + auto *cpp_op_desc = cpp_desc->AddOp(); + TransformOpDescToCinn(op, cpp_op_desc); + } + + cpp_desc->ClearVars(); + const auto &all_vars = pb_desc->AllVars(); + for (const auto &var : all_vars) { + auto *cpp_var_desc = cpp_desc->AddVar(); + TransformVarDescToCinn(var, cpp_var_desc); + } +} + +void TransformBlockDescFromCinn(const cpp::BlockDesc &cpp_desc, + framework::BlockDesc *pb_desc) { + pb_desc->Proto()->Clear(); + + pb_desc->Proto()->set_idx(cpp_desc.Idx()); + pb_desc->Proto()->set_parent_idx(cpp_desc.ParentIdx()); + pb_desc->Proto()->set_forward_block_idx(cpp_desc.ForwardBlockIdx()); + + for (size_t i = 0; i < cpp_desc.OpsSize(); ++i) { + const auto &cpp_op_desc = + cpp_desc.template GetConstOp(static_cast(i)); + auto *pb_op_desc = pb_desc->AppendOp(); + TransformOpDescFromCinn(cpp_op_desc, pb_op_desc); + } + + for (size_t i = 0; i < cpp_desc.VarsSize(); ++i) { + const auto &cpp_var_desc = + cpp_desc.template GetConstVar(static_cast(i)); + auto *pb_var_desc = pb_desc->Var(cpp_var_desc.Name()); + TransformVarDescFromCinn(cpp_var_desc, pb_var_desc); + } +} + +/// For ProgramDesc transform +void TransformProgramDescToCinn(framework::ProgramDesc *pb_desc, + cpp::ProgramDesc *cpp_desc) { + if (pb_desc->Proto()->version().has_version()) { + cpp_desc->SetVersion(pb_desc->Version()); + } + + cpp_desc->ClearBlocks(); + for (size_t i = 0; i < pb_desc->Size(); ++i) { + auto *pb_block_desc = pb_desc->MutableBlock(i); + auto *cpp_block_desc = cpp_desc->AddBlock(); + TransformBlockDescToCinn(pb_block_desc, cpp_block_desc); + } +} + +void TransformProgramDescFromCinn(const cpp::ProgramDesc &cpp_desc, + framework::ProgramDesc *pb_desc) { + pb_desc->Proto()->Clear(); + + if (cpp_desc.HasVersion()) { + pb_desc->SetVersion(cpp_desc.Version()); + } + + // For paddle proto program, the only way to add block is invoke + // AppendBlock(), + // the AppendBlock need one necessary parameter: const BlockDesc &parent, + // but the only function of parent is set the block's parent_idx value. + // Meanwhile a program has at least one block, so we set block0 to all + // sub-block's parent in initial and cannot remove. + // Don't worry, it will be change in "TransformBlockDescFromCinn". + auto *block0 = pb_desc->MutableBlock(0); + + for (size_t i = 0; i < cpp_desc.BlocksSize(); ++i) { + const auto &cpp_block_desc = cpp_desc.GetConstBlock(i); + framework::BlockDesc *pb_block_desc = nullptr; + if (i < pb_desc->Size()) { + pb_block_desc = pb_desc->MutableBlock(i); + } else { + pb_block_desc = pb_desc->AppendBlock(*block0); + } + TransformBlockDescFromCinn(cpp_block_desc, pb_block_desc); + } +} + +} // namespace paddle2cinn +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/paddle2cinn/transform_desc.h b/paddle/fluid/framework/paddle2cinn/transform_desc.h new file mode 100644 index 0000000000000..76a4f812730df --- /dev/null +++ b/paddle/fluid/framework/paddle2cinn/transform_desc.h @@ -0,0 +1,79 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/var_desc.h" + +#include "cinn/frontend/paddle/cpp/block_desc.h" +#include "cinn/frontend/paddle/cpp/desc_api.h" +#include "cinn/frontend/paddle/cpp/op_desc.h" +#include "cinn/frontend/paddle/cpp/program_desc.h" +#include "cinn/frontend/paddle/cpp/var_desc.h" + +namespace paddle { +namespace framework { +namespace paddle2cinn { + +::cinn::frontend::paddle::cpp::VarDescAPI::Type TransformVarTypeToCinn( + const ::paddle::framework::proto::VarType::Type& type); + +::paddle::framework::proto::VarType::Type TransformVarTypeFromCinn( + const ::cinn::frontend::paddle::cpp::VarDescAPI::Type& type); + +::cinn::frontend::paddle::cpp::VarDescAPI::Type TransformVarDataTypeToCinn( + const ::paddle::framework::proto::VarType::Type& type); + +::paddle::framework::proto::VarType::Type TransformVarDataTypeFromCpp( + const ::cinn::frontend::paddle::cpp::VarDescAPI::Type& type); + +// Why use framework::VarDesc* rather than const framework::VarDesc& here? +// framework::VarDesc lack of many API like clear(), etc. On the other hand, +// the paddle node return framework::Desc* even if the node is const +void TransformVarDescToCinn(framework::VarDesc* pb_desc, + ::cinn::frontend::paddle::cpp::VarDesc* cpp_desc); + +void TransformVarDescFromCinn( + const ::cinn::frontend::paddle::cpp::VarDesc& cpp_desc, + framework::VarDesc* pb_desc); + +void TransformOpDescToCinn(framework::OpDesc* pb_desc, + ::cinn::frontend::paddle::cpp::OpDesc* cpp_desc); + +void TransformOpDescFromCinn( + const ::cinn::frontend::paddle::cpp::OpDesc& cpp_desc, + framework::OpDesc* pb_desc); + +void TransformBlockDescToCinn( + framework::BlockDesc* pb_desc, + ::cinn::frontend::paddle::cpp::BlockDesc* cpp_desc); + +void TransformBlockDescFromCinn( + const ::cinn::frontend::paddle::cpp::BlockDesc& cpp_desc, + framework::BlockDesc* pb_desc); + +void TransformProgramDescToCinn( + framework::ProgramDesc* pb_desc, + ::cinn::frontend::paddle::cpp::ProgramDesc* cpp_desc); + +void TransformProgramDescFromCinn( + const ::cinn::frontend::paddle::cpp::ProgramDesc& cpp_desc, + framework::ProgramDesc* pb_desc); + +} // namespace paddle2cinn +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/paddle2cinn/transform_desc_test.cc b/paddle/fluid/framework/paddle2cinn/transform_desc_test.cc new file mode 100644 index 0000000000000..ba324295cad72 --- /dev/null +++ b/paddle/fluid/framework/paddle2cinn/transform_desc_test.cc @@ -0,0 +1,236 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/paddle2cinn/transform_desc.h" + +namespace paddle { +namespace framework { +namespace paddle2cinn { + +using PbVarType = framework::proto::VarType; +namespace cpp = ::cinn::frontend::paddle::cpp; + +// check VarDesc +cpp::VarDesc CreateCppVarDesc() { + cpp::VarDesc var("test"); + var.SetType(cpp::VarDescAPI::Type::LOD_TENSOR); + var.SetPersistable(true); + var.SetDataType(cpp::VarDescAPI::Type::FP32); + var.SetShape({100, 200, 300}); + return var; +} + +framework::VarDesc CreatePbVarDesc() { + framework::VarDesc var("test"); + var.SetType(PbVarType::LOD_TENSOR); + var.SetPersistable(true); + var.SetDataType(PbVarType::FP32); + var.SetShape({100, 200, 300}); + return var; +} + +TEST(TransformVarDesc, cpp2pb) { + auto cpp_var = CreateCppVarDesc(); + framework::VarDesc pb_var("init"); + TransformVarDescFromCinn(cpp_var, &pb_var); + + auto correct_var = CreatePbVarDesc(); + ASSERT_EQ(pb_var.Name(), correct_var.Name()); + ASSERT_EQ(pb_var.GetType(), correct_var.GetType()); + ASSERT_EQ(pb_var.Persistable(), correct_var.Persistable()); + ASSERT_EQ(pb_var.GetDataType(), correct_var.GetDataType()); + ASSERT_EQ(pb_var.GetShape(), correct_var.GetShape()); +} + +TEST(TransformVarDesc, pb2cpp) { + auto pb_var = CreatePbVarDesc(); + cpp::VarDesc cpp_var; + TransformVarDescToCinn(&pb_var, &cpp_var); + + auto correct_var = CreateCppVarDesc(); + ASSERT_EQ(cpp_var.Name(), correct_var.Name()); + ASSERT_EQ(cpp_var.GetType(), correct_var.GetType()); + ASSERT_EQ(cpp_var.Persistable(), correct_var.Persistable()); + ASSERT_EQ(cpp_var.GetDataType(), correct_var.GetDataType()); + ASSERT_EQ(cpp_var.GetShape(), correct_var.GetShape()); +} + +// check OpDesc +cpp::OpDesc CreateCppOpDesc() { + cpp::OpDesc op; + op.SetType("test"); + op.SetInput("X", {"x1"}); + op.SetInput("Y", {"y1", "y2"}); + op.SetOutput("Out", {"out1"}); + op.SetAttr("attr_f", 0.1f); + op.SetAttr("attr_str", "test_attr"); + return op; +} + +framework::OpDesc CreatePbOpDesc() { + framework::OpDesc op; + op.SetType("test"); + op.SetInput("X", {"x1"}); + op.SetInput("Y", {"y1", "y2"}); + op.SetOutput("Out", {"out1"}); + op.SetAttr("attr_f", 0.1f); + op.SetAttr("attr_str", std::string("test_attr")); + return op; +} + +TEST(TransformOpDesc, cpp2pb) { + auto cpp_op = CreateCppOpDesc(); + framework::OpDesc pb_op; + TransformOpDescFromCinn(cpp_op, &pb_op); + + auto correct_op = CreatePbOpDesc(); + ASSERT_EQ(pb_op.Type(), correct_op.Type()); + ASSERT_EQ(pb_op.Inputs(), correct_op.Inputs()); + ASSERT_EQ(pb_op.Outputs(), correct_op.Outputs()); + ASSERT_EQ(pb_op.AttrNames(), correct_op.AttrNames()); + + for (const auto &attr_name : pb_op.AttrNames()) { + ASSERT_EQ(pb_op.GetAttrType(attr_name), correct_op.GetAttrType(attr_name)); + } + ASSERT_EQ(pb_op.GetAttrIfExists("attr_f"), + correct_op.GetAttrIfExists("attr_f")); + ASSERT_EQ(pb_op.GetAttrIfExists("attr_str"), + correct_op.GetAttrIfExists("attr_str")); +} + +TEST(TransformOpDesc, pb2cpp) { + auto pb_op = CreatePbOpDesc(); + cpp::OpDesc cpp_op; + TransformOpDescToCinn(&pb_op, &cpp_op); + + auto correct_op = CreateCppOpDesc(); + ASSERT_EQ(cpp_op.Type(), correct_op.Type()); + ASSERT_EQ(cpp_op.inputs(), correct_op.inputs()); + ASSERT_EQ(cpp_op.outputs(), correct_op.outputs()); + ASSERT_EQ(cpp_op.AttrNames(), correct_op.AttrNames()); + ASSERT_EQ(cpp_op.attr_types(), correct_op.attr_types()); + + ASSERT_EQ(cpp_op.GetAttr("attr_f"), + correct_op.GetAttr("attr_f")); + ASSERT_EQ(cpp_op.GetAttr("attr_str"), + correct_op.GetAttr("attr_str")); +} + +// check BlockDesc +// framework::BlockDesc is DISABLE_COPY_AND_ASSIGN, so can not return +void CreateCppBlockDesc(cpp::BlockDesc *block) { + block->SetIdx(42); + block->SetParentIdx(4); + block->SetForwardBlockIdx(32); + + auto *op = block->AddOp(); + *op = CreateCppOpDesc(); + + auto *var = block->AddVar(); + *var = CreateCppVarDesc(); +} + +void CreatePbBlockDesc(framework::BlockDesc *block) { + block->Proto()->set_idx(42); + block->Proto()->set_parent_idx(4); + block->Proto()->set_forward_block_idx(32); + + auto *op = block->AppendOp(); + *op = CreatePbOpDesc(); + + auto *var = block->Var("init"); + *var = CreatePbVarDesc(); +} + +TEST(TransformBlockDesc, cpp2pb) { + cpp::BlockDesc cpp_block; + CreateCppBlockDesc(&cpp_block); + + framework::ProgramDesc pb_prog; + auto *pb_block = pb_prog.MutableBlock(0); + TransformBlockDescFromCinn(cpp_block, pb_block); + + framework::ProgramDesc correct_prog; + auto *correct_block = correct_prog.MutableBlock(0); + CreatePbBlockDesc(correct_block); + ASSERT_EQ(pb_block->ID(), correct_block->ID()); + ASSERT_EQ(pb_block->Parent(), correct_block->Parent()); + ASSERT_EQ(pb_block->ForwardBlockID(), correct_block->ForwardBlockID()); + ASSERT_EQ(pb_block->OpSize(), correct_block->OpSize()); + ASSERT_EQ(pb_block->AllVars().size(), correct_block->AllVars().size()); +} + +TEST(TransformBlockDesc, pb2cpp) { + framework::ProgramDesc pb_prog; + auto *pb_block = pb_prog.MutableBlock(0); + CreatePbBlockDesc(pb_block); + + cpp::BlockDesc cpp_block; + TransformBlockDescToCinn(pb_block, &cpp_block); + + cpp::BlockDesc correct_block; + CreateCppBlockDesc(&correct_block); + ASSERT_EQ(cpp_block.Idx(), correct_block.Idx()); + ASSERT_EQ(cpp_block.ParentIdx(), correct_block.ParentIdx()); + ASSERT_EQ(cpp_block.ForwardBlockIdx(), correct_block.ForwardBlockIdx()); + ASSERT_EQ(cpp_block.OpsSize(), correct_block.OpsSize()); + ASSERT_EQ(cpp_block.VarsSize(), correct_block.VarsSize()); +} + +// check ProgramDesc +cpp::ProgramDesc CreateCppProgramDesc() { + cpp::ProgramDesc prog; + prog.SetVersion(22); + + auto *block = prog.AddBlock(); + CreateCppBlockDesc(block); + + return prog; +} + +framework::ProgramDesc CreatePbProgramDesc() { + framework::ProgramDesc prog; + prog.SetVersion(22); + + auto *block = prog.MutableBlock(0); + CreatePbBlockDesc(block); + return prog; +} + +TEST(TransformProgramDesc, cpp2pb) { + auto cpp_prog = CreateCppProgramDesc(); + framework::ProgramDesc pb_prog; + TransformProgramDescFromCinn(cpp_prog, &pb_prog); + + auto correct_prog = CreatePbProgramDesc(); + ASSERT_EQ(pb_prog.Version(), correct_prog.Version()); + ASSERT_EQ(pb_prog.Size(), correct_prog.Size()); +} + +TEST(TransformProgramDesc, pb2cpp) { + auto pb_prog = CreatePbProgramDesc(); + cpp::ProgramDesc cpp_prog; + TransformProgramDescToCinn(&pb_prog, &cpp_prog); + + auto correct_prog = CreateCppProgramDesc(); + ASSERT_EQ(cpp_prog.Version(), correct_prog.Version()); + ASSERT_EQ(cpp_prog.BlocksSize(), correct_prog.BlocksSize()); +} + +} // namespace paddle2cinn +} // namespace framework +} // namespace paddle