diff --git a/paddle/pir/core/op_trait.cc b/paddle/pir/core/op_trait.cc new file mode 100644 index 0000000000000..ccea4e3f06d9b --- /dev/null +++ b/paddle/pir/core/op_trait.cc @@ -0,0 +1,196 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pir/core/op_trait.h" +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/type_util.h" + +namespace pir::detail { + +void VerifySameOperandsShapeTrait(Operation *op) { + VLOG(4) << "Verify SameOperandsShapeTrait for : " << op->name(); + + IR_ENFORCE(op->num_operands() > 0, + "Op %s with SameOperandsShapeTrait requires at least 1 operands, " + "but got %u operands.", + op->name(), + op->num_operands()); + + std::vector operands = op->operands(); + std::vector types; + std::for_each(operands.begin(), operands.end(), [&types](pir::OpOperand op) { + types.push_back(op.type()); + }); + + IR_ENFORCE(VerifyCompatibleShapes(types), + "Op %s with SameOperandsShapeTrait requires the same shape for " + "all operands.", + op->name()); +} + +void VerifySameOperandsAndResultShapeTrait(Operation *op) { + VLOG(4) << "Verify SameOperandsAndResultShapeTrait for : " << op->name(); + + IR_ENFORCE(op->num_operands() > 0, + "Op %s with SameOperandsAndResultShapeTrait requires at least 1 " + "operands, but got %u operands.", + op->name(), + op->num_operands()); + + IR_ENFORCE(op->num_results() > 0, + "Op %s with SameOperandsAndResultShapeTrait requires at least 1 " + "results, but got %u results.", + op->name(), + op->num_results()); + + std::vector operands = op->operands(); + std::vector results = op->results(); + + std::vector types; + + std::for_each(operands.begin(), operands.end(), [&types](pir::OpOperand op) { + types.push_back(op.type()); + }); + + std::for_each(results.begin(), results.end(), [&types](pir::OpResult op) { + types.push_back(op.type()); + }); + + IR_ENFORCE(VerifyCompatibleShapes(types), + "Op %s with SameOperandsAndResultShapeTrait requires compatible " + "shapes for operands and results.", + op->name()); +} + +void VerifySameOperandsElementTypeTrait(Operation *op) { + VLOG(4) << "Verify SameOperandsElementTypeTrait for : " << op->name(); + + IR_ENFORCE(op->num_operands() > 0, + "Op %s with SameOperandsElementTypeTrait requires at least 1 " + "operands, but got %u operands.", + op->name(), + op->num_operands()); + + auto elementType = GetElementTypeOrSelf(op->result(0).type()); + for (auto operand : op->operands()) { + IR_ENFORCE(GetElementTypeOrSelf(operand.type()) == elementType, + "Op %s with SameOperandsElementTypeTrait requires the same " + "element type for all operands.", + op->name()); + } +} + +void VerifySameOperandsAndResultElementTypeTrait(Operation *op) { + VLOG(4) << "Verify SameOperandsAndResultElementTypeTrait for : " + << op->name(); + + IR_ENFORCE(op->num_operands() > 0, + "Op %s with SameOperandsAndResultElementTypeTrait requires at " + "least 1 operands, but got %u operands.", + op->name(), + op->num_operands()); + + IR_ENFORCE(op->num_results() > 0, + "Op %s with SameOperandsAndResultElementTypeTrait requires at " + "least 1 results, but got %u results.", + op->name(), + op->num_results()); + + auto elementType = GetElementTypeOrSelf(op->result(0).type()); + + // Verify result element type matches first result's element type. + for (auto result : op->results()) { + IR_ENFORCE(GetElementTypeOrSelf(result.type()) == elementType, + "Op %s with SameOperandsAndResultElementTypeTrait requires the " + "same element type for all operands and results.", + op->name()); + } + + // Verify operand's element type matches first result's element type. + for (auto operand : op->operands()) { + IR_ENFORCE(GetElementTypeOrSelf(operand.type()) == elementType, + "Op %s with SameOperandsAndResultElementTypeTrait requires the " + "same element type for all operands and results.", + op->name()); + } +} + +void VerifySameOperandsAndResultTypeTrait(Operation *op) { + VLOG(4) << "Verify SameOperandsAndResultTypeTrait for : " << op->name(); + + IR_ENFORCE(op->num_operands() > 0, + "Op %s with SameOperandsAndResultTypeTrait requires at least 1 " + "operands, but got %u operands.", + op->name(), + op->num_operands()); + + IR_ENFORCE(op->num_results() > 0, + "Op %s with SameOperandsAndResultTypeTrait requires at least 1 " + "results, but got %u results.", + op->name(), + op->num_results()); + + auto type = op->result(0).type(); + auto elementType = GetElementTypeOrSelf(type); + + for (auto result : op->results()) { + IR_ENFORCE(GetElementTypeOrSelf(result.type()) == elementType, + "Op %s with SameOperandsAndResultTypeTrait requires the same " + "type for all operands and results.", + op->name()); + + IR_ENFORCE(VerifyCompatibleShape(result.type(), type), + "Op %s with SameOperandsAndResultTypeTrait requires the same " + "type for all operands and results.", + op->name()); + } + + for (auto operand : op->operands()) { + IR_ENFORCE(GetElementTypeOrSelf(operand.type()) == elementType, + "Op %s with SameOperandsAndResultTypeTrait requires the same " + "type for all operands and results.", + op->name()); + + IR_ENFORCE(VerifyCompatibleShape(operand.type(), type), + "Op %s with SameOperandsAndResultTypeTrait requires the same " + "type for all operands and results.", + op->name()); + } +} + +void VerifySameTypeOperandsTrait(Operation *op) { + VLOG(4) << "Verify SameTypeOperandsTrait for : " << op->name(); + + // For zero or only one operand. + unsigned operand_nums = op->num_operands(); + if (operand_nums < 2) return; + + auto type = op->operand(0).type(); + + for (auto operand : op->operands()) { + IR_ENFORCE(operand.type() == type, + "Op %s with SameTypeOperandsTrait requires all operands to have " + "the same type.", + op->name()); + } +} + +} // namespace pir::detail + +IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsShapeTrait) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultShapeTrait) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsElementTypeTrait) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultElementTypeTrait) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultTypeTrait) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameTypeOperandsTrait) diff --git a/paddle/pir/core/op_trait.h b/paddle/pir/core/op_trait.h new file mode 100644 index 0000000000000..760799fd16165 --- /dev/null +++ b/paddle/pir/core/op_trait.h @@ -0,0 +1,121 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/op_base.h" + +namespace pir { + +namespace detail { +void VerifySameOperandsShapeTrait(Operation *op); +void VerifySameOperandsAndResultShapeTrait(Operation *op); +void VerifySameOperandsElementTypeTrait(Operation *op); +void VerifySameOperandsAndResultElementTypeTrait(Operation *op); +void VerifySameOperandsAndResultTypeTrait(Operation *op); +void VerifySameTypeOperandsTrait(Operation *op); +} // namespace detail + +/// +/// \brief Provides verification for ops that are known to have the +/// same operand shape. +/// +class SameOperandsShapeTrait : public pir::OpTraitBase { + public: + explicit SameOperandsShapeTrait(pir::Operation *op) + : pir::OpTraitBase(op) {} + static void Verify(Operation *op) { + return detail::VerifySameOperandsShapeTrait(op); + } +}; + +/// +/// \brief Provides verification for ops that are known to have the +/// same operand and result shape. +/// +class SameOperandsAndResultShapeTrait + : public pir::OpTraitBase { + public: + explicit SameOperandsAndResultShapeTrait(pir::Operation *op) + : pir::OpTraitBase(op) {} + static void Verify(Operation *op) { + return detail::VerifySameOperandsAndResultShapeTrait(op); + } +}; + +/// +/// \brief Provides verification for ops that are known to have the +/// same operand element type (or the type itself if it is scalar). +/// +class SameOperandsElementTypeTrait + : public pir::OpTraitBase { + public: + explicit SameOperandsElementTypeTrait(pir::Operation *op) + : pir::OpTraitBase(op) {} + static void Verify(Operation *op) { + return detail::VerifySameOperandsElementTypeTrait(op); + } +}; + +/// +/// \brief Provides verification for ops that are known to have the +/// same operand and result element type (or the type itself if it is scalar). +/// +class SameOperandsAndResultElementTypeTrait + : public pir::OpTraitBase { + public: + explicit SameOperandsAndResultElementTypeTrait(pir::Operation *op) + : pir::OpTraitBase(op) {} + static void Verify(Operation *op) { + return detail::VerifySameOperandsAndResultElementTypeTrait(op); + } +}; + +/// +/// \brief Provides verification for ops that are known to have the +/// same operand and result type. It Subsumes both +/// SameOperandsAndResultShapeTrait and SameOperandsAndResultElementTypeTrait +/// +class SameOperandsAndResultTypeTrait + : public pir::OpTraitBase { + public: + explicit SameOperandsAndResultTypeTrait(pir::Operation *op) + : pir::OpTraitBase(op) {} + + static void Verify(Operation *op) { + return detail::VerifySameOperandsAndResultTypeTrait(op); + } +}; + +/// +/// \brief Provides verification that all operands of the specified op have the +/// same type. +/// +class SameTypeOperandsTrait : public pir::OpTraitBase { + public: + explicit SameTypeOperandsTrait(pir::Operation *op) + : pir::OpTraitBase(op) {} + static void Verify(Operation *op) { + return detail::VerifySameTypeOperandsTrait(op); + } +}; + +} // namespace pir + +IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsShapeTrait) +IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultShapeTrait) +IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsElementTypeTrait) +IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultElementTypeTrait) +IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultTypeTrait) +IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameTypeOperandsTrait) diff --git a/paddle/pir/core/type_util.cc b/paddle/pir/core/type_util.cc new file mode 100644 index 0000000000000..0d6d137a897f0 --- /dev/null +++ b/paddle/pir/core/type_util.cc @@ -0,0 +1,129 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pir/core/type_util.h" +#include + +namespace pir { + +Type GetElementTypeOrSelf(Type type) { + if (auto sType = type.dyn_cast()) + return sType.GetElementType(); + return type; +} + +bool VerifyCompatibleShape(const phi::DDim &lhs_shape, + const phi::DDim &rhs_shape) { + if (lhs_shape.size() != rhs_shape.size()) return false; + + for (auto dim1 : phi::vectorize(lhs_shape)) { + for (auto dim2 : phi::vectorize(rhs_shape)) { + if (!ShapedTypeInterface::IsDynamic(dim1) && + !ShapedTypeInterface::IsDynamic(dim2) && dim1 != dim2) + return false; + } + } + return true; +} + +bool VerifyCompatibleShape(Type lhs_type, Type rhs_type) { + auto lhs_shaped_type = lhs_type.dyn_cast(); + auto rhs_shaped_type = rhs_type.dyn_cast(); + + // Either both or neither type should be shaped. + if (!lhs_shaped_type) return !rhs_shaped_type; + if (!rhs_shaped_type) return false; + + if (!lhs_shaped_type.HasRank() || !rhs_shaped_type.HasRank()) return true; + + return VerifyCompatibleShape(lhs_shaped_type.GetShape(), + rhs_shaped_type.GetShape()); +} + +bool VerifyCompatibleDims(const std::vector &dims) { + if (dims.empty()) return true; + auto static_dim = std::accumulate( + dims.begin(), dims.end(), dims.front(), [](auto &fold, auto &dim) { + return ShapedTypeInterface::IsDynamic(dim) ? fold : dim; + }); + return std::all_of(dims.begin(), dims.begin(), [&](auto dim) { + return ShapedTypeInterface::IsDynamic(dim) || dim == static_dim; + }); +} + +bool VerifyCompatibleShapes(const std::vector &lhs_types, + const std::vector &rhs_types) { + if (lhs_types.size() != rhs_types.size()) return false; + + for (auto it1 : lhs_types) { + for (auto it2 : rhs_types) { + if (!VerifyCompatibleShape(it1, it2)) return false; + } + } + return true; +} + +bool VerifyCompatibleShapes(const std::vector &types) { + std::vector shaped_type_interfaces; + + std::for_each( + types.begin(), types.end(), [&shaped_type_interfaces](Type type) { + shaped_type_interfaces.push_back(type.dyn_cast()); + }); + + // Return false if some, but not all are not shaped. Return early if none + // are shaped also. + if (std::none_of(shaped_type_interfaces.begin(), + shaped_type_interfaces.end(), + [](auto t) { return t; })) + return true; + + if (!std::all_of(shaped_type_interfaces.begin(), + shaped_type_interfaces.end(), + [](auto t) { return t; })) + return false; + + // Remove all unranked shapes + std::vector shapes; + + std::for_each(shaped_type_interfaces.begin(), + shaped_type_interfaces.end(), + [&shapes](ShapedTypeInterface type) { + if (type.HasRank()) + shapes.push_back(type.dyn_cast()); + }); + if (shapes.empty()) return true; + + // All ranks should be equal + int64_t firstRank = shapes.front().GetRank(); + + if (std::any_of(shapes.begin(), shapes.end(), [&](auto shape) { + return firstRank != shape.GetRank(); + })) + return false; + + for (unsigned i = 0; i < firstRank; ++i) { + // For all ranked dimensions + std::vector dims; + std::for_each(shapes.begin(), shapes.end(), [&](ShapedTypeInterface shape) { + dims.push_back(shape.GetDimSize(i)); + }); + + if (!VerifyCompatibleDims(dims)) return false; + } + + return true; +} + +} // namespace pir diff --git a/paddle/pir/core/type_util.h b/paddle/pir/core/type_util.h new file mode 100644 index 0000000000000..5704ba2abea78 --- /dev/null +++ b/paddle/pir/core/type_util.h @@ -0,0 +1,65 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +/// +/// \brief Utility Functions +/// + +#include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/builtin_type_interfaces.h" + +namespace pir { +/// +/// \brief Return the element type or return the type itself. +/// +Type GetElementTypeOrSelf(Type type); + +/// +/// \brief Returns true if the given two shapes are compatible. That is, they +/// have the same size and each pair of the elements are equal or one of them is +/// dynamic. +/// +bool VerifyCompatibleShape(const phi::DDim& lhs_shape, + const phi::DDim& rhs_shape); + +/// +/// \brief Returns true if the given two types have compatible shape. That +/// is, they are both scalars (not shaped), or they are both shaped types and at +/// least one is unranked or they have compatible dimensions. Dimensions are +/// compatible if at least one is dynamic or both are equal. The element type +/// does not matter. +/// +bool VerifyCompatibleShape(Type lhs_type, Type rhs_type); + +/// +/// \brief Dimensions are compatible if all non-dynamic dims are equal. +/// +bool VerifyCompatibleDims(const std::vector& dims); + +/// +/// \brief Returns true if the given two arrays have the same number of elements +/// and each pair wise entries have compatible shape. +/// +bool VerifyCompatibleShapes(const std::vector& lhs_types, + const std::vector& rhs_types); + +/// +/// \brief Returns true if all given types have compatible shapes. That is, +/// they are all scalars (not shaped), or they are all shaped types and any +/// ranked shapes have compatible dimensions. Dimensions are compatible if all +/// non-dynamic dims are equal. The element type does not matter. +/// +bool VerifyCompatibleShapes(const std::vector& types); +} // namespace pir diff --git a/test/cpp/pir/core/CMakeLists.txt b/test/cpp/pir/core/CMakeLists.txt index 0d65bc5b454c3..0f0ec568bb50a 100644 --- a/test/cpp/pir/core/CMakeLists.txt +++ b/test/cpp/pir/core/CMakeLists.txt @@ -8,14 +8,15 @@ cc_test_old( pd_op_dialect) cc_test_old(ir_attribute_test SRCS ir_attribute_test.cc DEPS pir gtest) cc_test_old(ir_value_test SRCS ir_value_test.cc DEPS pir gtest) -cc_test_old( +paddle_test( ir_op_test SRCS ir_op_test.cc DEPS pir gtest - test_dialect) + test_dialect + pd_op_dialect) cc_test_old(ir_region_test SRCS ir_region_test.cc DEPS pir gtest) cc_test_old(ir_builder_test SRCS ir_builder_test.cc DEPS pir gtest) cc_test_old( @@ -139,3 +140,9 @@ cc_test_old( test_dialect gtest pir) + +if(WITH_ONNXRUNTIME AND WIN32) + # Copy onnxruntime for some c++ test in Windows, since the test will + # be build only in CI, so suppose the generator in Windows is Ninja. + copy_onnx(ir_op_test) +endif() diff --git a/test/cpp/pir/core/ir_op_test.cc b/test/cpp/pir/core/ir_op_test.cc index c512ea753e3c0..596519ba57d4c 100644 --- a/test/cpp/pir/core/ir_op_test.cc +++ b/test/cpp/pir/core/ir_op_test.cc @@ -15,6 +15,8 @@ #include #include +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/phi/core/tensor_meta.h" #include "paddle/pir/core/block.h" #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_attribute.h" @@ -43,6 +45,27 @@ pir::AttributeMap CreateAttributeMap( return attr_map; } +pir::Operation *CreateDenseTensorOp( + pir::IrContext *ctx, + const phi::DDim &dims, + const std::vector &attribute_names, + const std::vector &attributes, + const pir::Type &dtype = + pir::Float32Type::get(pir::IrContext::Instance())) { + std::vector op_inputs = {}; + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + std::vector op_output_types = { + pir::DenseTensorType::get(ctx, dtype, dims, data_layout, lod, offset)}; + pir::Operation *op = + pir::Operation::Create(op_inputs, + CreateAttributeMap(attribute_names, attributes), + op_output_types, + pir::OpInfo()); + return op; +} + TEST(op_test, region_test) { // (1) Register Dialect, Operation1, Operation2 into IrContext. pir::IrContext *ctx = pir::IrContext::Instance(); @@ -126,3 +149,367 @@ TEST(op_test, trait_and_interface) { pir::OperationArgument argument(&ctx, "test.region"); EXPECT_THROW(builder.Build(std::move(argument)), pir::IrNotMetException); } + +TEST(op_test, op_traits_test) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Program program(ctx); + auto block = program.block(); + pir::Builder builder(ctx, block); + + pir::Type dtype = pir::Float32Type::get(ctx); + phi::DDim dims = {2, 2}; + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + + pir::DenseTensorType dense_tensor_dtype = + pir::DenseTensorType::get(ctx, dtype, dims, data_layout, lod, offset); + + pir::Operation *op1 = + CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); + pir::Operation *op2 = + CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); + + auto op3 = builder.Build( + op1->result(0), op2->result(0), dense_tensor_dtype); + + EXPECT_EQ(op3->HasTrait(), true); + EXPECT_EQ(op3->HasTrait(), true); + EXPECT_EQ(op3->HasTrait(), true); + EXPECT_EQ(op3->HasTrait(), true); + EXPECT_EQ(op3->HasTrait(), true); + EXPECT_EQ(op3->HasTrait(), true); +} + +TEST(op_test, same_operands_shape_trait_test1) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Program program(ctx); + auto block = program.block(); + pir::Builder builder(ctx, block); + + EXPECT_THROW(builder.Build(), + pir::IrNotMetException); +} + +TEST(op_test, same_operands_shape_trait_test2) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Program program(ctx); + auto block = program.block(); + pir::Builder builder(ctx, block); + + pir::Type dtype1 = pir::Float32Type::get(ctx); + phi::DDim dims1 = {2, 2}; + + pir::Type dtype2 = pir::Float64Type::get(ctx); + phi::DDim dims2 = {2, 2, 2}; + + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + + pir::DenseTensorType dense_tensor_dtype = + pir::DenseTensorType::get(ctx, dtype1, dims1, data_layout, lod, offset); + + pir::Operation *op1 = + CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype1); + pir::Operation *op2 = + CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype2); + + EXPECT_THROW(builder.Build( + op1->result(0), op2->result(0), dense_tensor_dtype), + pir::IrNotMetException); +} + +TEST(op_test, same_operands_and_result_shape_trait_test1) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Program program(ctx); + auto block = program.block(); + pir::Builder builder(ctx, block); + + EXPECT_THROW(builder.Build(), + pir::IrNotMetException); +} + +TEST(op_test, same_operands_and_result_shape_trait_test2) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Program program(ctx); + auto block = program.block(); + pir::Builder builder(ctx, block); + + pir::Type dtype = pir::Float64Type::get(ctx); + phi::DDim dims = {2, 2, 2}; + + pir::Operation *op1 = + CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); + pir::Operation *op2 = + CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); + + EXPECT_THROW(builder.Build( + op1->result(0), op2->result(0)), + pir::IrNotMetException); +} + +TEST(op_test, same_operands_and_result_shape_trait_test3) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Program program(ctx); + auto block = program.block(); + pir::Builder builder(ctx, block); + + pir::Type dtype1 = pir::Float32Type::get(ctx); + phi::DDim dims1 = {2, 2}; + + pir::Type dtype2 = pir::Float64Type::get(ctx); + phi::DDim dims2 = {2, 2, 2}; + + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + + pir::DenseTensorType dense_tensor_dtype = + pir::DenseTensorType::get(ctx, dtype1, dims1, data_layout, lod, offset); + + pir::Operation *op1 = + CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype1); + pir::Operation *op2 = + CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype2); + + EXPECT_THROW(builder.Build( + op1->result(0), op2->result(0), dense_tensor_dtype), + pir::IrNotMetException); +} + +TEST(op_test, same_operands_element_type_trait_test1) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Program program(ctx); + auto block = program.block(); + pir::Builder builder(ctx, block); + + EXPECT_THROW(builder.Build(), + pir::IrNotMetException); +} + +TEST(op_test, same_operands_element_type_trait_test2) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Program program(ctx); + auto block = program.block(); + pir::Builder builder(ctx, block); + + pir::Type dtype1 = pir::Float32Type::get(ctx); + pir::Type dtype2 = pir::Float64Type::get(ctx); + + phi::DDim dims = {2, 2}; + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + + pir::DenseTensorType dense_tensor_dtype = + pir::DenseTensorType::get(ctx, dtype1, dims, data_layout, lod, offset); + + pir::Operation *op1 = + CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype1); + pir::Operation *op2 = + CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype2); + + EXPECT_THROW(builder.Build( + op1->result(0), op2->result(0), dense_tensor_dtype), + pir::IrNotMetException); +} + +TEST(op_test, same_operands_and_result_element_type_trait_test1) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Program program(ctx); + auto block = program.block(); + pir::Builder builder(ctx, block); + + EXPECT_THROW(builder.Build(), + pir::IrNotMetException); +} + +TEST(op_test, same_operands_and_result_element_type_trait_test2) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Program program(ctx); + auto block = program.block(); + pir::Builder builder(ctx, block); + + pir::Type dtype = pir::Float32Type::get(ctx); + phi::DDim dims = {2, 2}; + + pir::Operation *op1 = + CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); + pir::Operation *op2 = + CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); + + EXPECT_THROW(builder.Build( + op1->result(0), op2->result(0)), + pir::IrNotMetException); +} + +TEST(op_test, same_operands_and_result_element_type_trait_test3) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Program program(ctx); + auto block = program.block(); + pir::Builder builder(ctx, block); + + pir::Type dtype1 = pir::Float32Type::get(ctx); + phi::DDim dims1 = {2, 2}; + + pir::Type dtype2 = pir::Float64Type::get(ctx); + phi::DDim dims2 = {2, 2, 2}; + + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + + pir::DenseTensorType dense_tensor_dtype1 = + pir::DenseTensorType::get(ctx, dtype1, dims1, data_layout, lod, offset); + pir::DenseTensorType dense_tensor_dtype2 = + pir::DenseTensorType::get(ctx, dtype2, dims2, data_layout, lod, offset); + + pir::Operation *op1 = + CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype1); + pir::Operation *op2 = + CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype2); + + EXPECT_THROW(builder.Build( + op1->result(0), + op2->result(0), + dense_tensor_dtype1, + dense_tensor_dtype1), + pir::IrNotMetException); + EXPECT_THROW(builder.Build( + op1->result(0), + op1->result(0), + dense_tensor_dtype1, + dense_tensor_dtype2), + pir::IrNotMetException); +} + +TEST(op_test, same_operands_and_result_type_trait_test1) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Program program(ctx); + auto block = program.block(); + pir::Builder builder(ctx, block); + + EXPECT_THROW(builder.Build(), + pir::IrNotMetException); +} + +TEST(op_test, same_operands_and_result_type_trait_test2) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Program program(ctx); + auto block = program.block(); + pir::Builder builder(ctx, block); + + pir::Type dtype = pir::Float32Type::get(ctx); + phi::DDim dims = {2, 2}; + + pir::Operation *op1 = + CreateDenseTensorOp(ctx, dims, {"op1_temp"}, {"op1_attr"}, dtype); + pir::Operation *op2 = + CreateDenseTensorOp(ctx, dims, {"op2_temp"}, {"op2_attr"}, dtype); + + EXPECT_THROW(builder.Build( + op1->result(0), op2->result(0)), + pir::IrNotMetException); +} + +TEST(op_test, same_operands_and_result_type_trait_test3) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Program program(ctx); + auto block = program.block(); + pir::Builder builder(ctx, block); + + pir::Type dtype1 = pir::Float32Type::get(ctx); + phi::DDim dims1 = {2, 2}; + + pir::Type dtype2 = pir::Float64Type::get(ctx); + phi::DDim dims2 = {2, 2, 2}; + + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + + pir::DenseTensorType dense_tensor_dtype1 = + pir::DenseTensorType::get(ctx, dtype1, dims1, data_layout, lod, offset); + + pir::DenseTensorType dense_tensor_dtype2 = + pir::DenseTensorType::get(ctx, dtype2, dims2, data_layout, lod, offset); + + pir::DenseTensorType dense_tensor_dtype3 = + pir::DenseTensorType::get(ctx, dtype1, dims2, data_layout, lod, offset); + + pir::Operation *op1 = + CreateDenseTensorOp(ctx, dims1, {"op1_temp"}, {"op1_attr"}, dtype2); + pir::Operation *op2 = + CreateDenseTensorOp(ctx, dims2, {"op2_temp"}, {"op2_attr"}, dtype1); + + EXPECT_THROW(builder.Build( + op1->result(0), + op2->result(0), + dense_tensor_dtype1, + dense_tensor_dtype2), + pir::IrNotMetException); + + EXPECT_THROW(builder.Build( + op1->result(0), + op2->result(0), + dense_tensor_dtype1, + dense_tensor_dtype3), + pir::IrNotMetException); + + EXPECT_THROW(builder.Build( + op1->result(0), + op2->result(0), + dense_tensor_dtype1, + dense_tensor_dtype1), + pir::IrNotMetException); + + EXPECT_THROW(builder.Build( + op2->result(0), + op1->result(0), + dense_tensor_dtype1, + dense_tensor_dtype1), + pir::IrNotMetException); +} diff --git a/test/cpp/pir/core/type_test.cc b/test/cpp/pir/core/type_test.cc index ada08b5f9bf1a..0f3581732784f 100644 --- a/test/cpp/pir/core/type_test.cc +++ b/test/cpp/pir/core/type_test.cc @@ -24,6 +24,7 @@ #include "paddle/pir/core/type.h" #include "paddle/pir/core/type_base.h" #include "paddle/pir/core/type_name.h" +#include "paddle/pir/core/type_util.h" #include "paddle/pir/core/utils.h" class TypeA {}; @@ -260,6 +261,36 @@ TEST(type_test, pd_op_dialect) { EXPECT_EQ(select_rows_dtype.offset(), offset); } +TEST(type_test, type_util) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + + pir::Type fp32_dtype = pir::Float32Type::get(ctx); + phi::DDim dims1 = {2, 2}; + phi::DDim dims2 = {2, 2, 3}; + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + + paddle::dialect::SelectedRowsType select_rows_dtype1 = + paddle::dialect::SelectedRowsType::get( + ctx, fp32_dtype, dims1, data_layout, lod, offset); + + paddle::dialect::SelectedRowsType select_rows_dtype2 = + paddle::dialect::SelectedRowsType::get( + ctx, fp32_dtype, dims2, data_layout, lod, offset); + + std::vector types1 = { + select_rows_dtype1, select_rows_dtype1, select_rows_dtype1}; + std::vector types2 = { + select_rows_dtype1, select_rows_dtype1, select_rows_dtype1}; + std::vector types3 = { + select_rows_dtype2, select_rows_dtype2, select_rows_dtype2}; + + EXPECT_TRUE(pir::VerifyCompatibleShapes(types1, types2)); + EXPECT_FALSE(pir::VerifyCompatibleShapes(types1, types3)); +} + namespace TestNamespace { class TestClass {}; } // namespace TestNamespace diff --git a/test/cpp/pir/tools/test_dialect.cc b/test/cpp/pir/tools/test_dialect.cc index 49fb4a6951dd7..e3000a418119b 100644 --- a/test/cpp/pir/tools/test_dialect.cc +++ b/test/cpp/pir/tools/test_dialect.cc @@ -21,7 +21,24 @@ TestDialect::TestDialect(pir::IrContext *context) initialize(); } void TestDialect::initialize() { - RegisterOps(); + RegisterOps(); } void TestDialect::PrintOperation(pir::Operation *op, diff --git a/test/cpp/pir/tools/test_op.cc b/test/cpp/pir/tools/test_op.cc index b67dd24c5dc04..6041efec0e652 100644 --- a/test/cpp/pir/tools/test_op.cc +++ b/test/cpp/pir/tools/test_op.cc @@ -21,8 +21,8 @@ void RegionOp::Build(pir::Builder &builder, pir::OperationArgument &argument) { argument.AddRegion(nullptr); } -void BranchOp::Build(pir::Builder &builder, // NOLINT - pir::OperationArgument &argument, +void BranchOp::Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT const std::vector &target_operands, pir::Block *target) { argument.AddInputs(target_operands.begin(), target_operands.end()); @@ -35,9 +35,7 @@ void BranchOp::Verify() const { IR_ENFORCE((*this)->successor(0), "successor[0] can't be nullptr"); } -const char *Operation1::attributes_name[2] = { // NOLINT - "op1_attr1", - "op1_attr2"}; +const char *Operation1::attributes_name[2] = {"op1_attr1", "op1_attr2"}; void Operation1::Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument) { // NOLINT @@ -58,9 +56,120 @@ void Operation1::Verify() const { throw("Type of attribute: parameter_name is not right."); } } + +void TraitExampleOp::Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value l_operand, + pir::Value r_operand, + pir::Type out_type) { + argument.AddInput(l_operand); + argument.AddInput(r_operand); + argument.AddOutput(out_type); +} + +void SameOperandsShapeTraitOp2::Build( + pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value l_operand, + pir::Value r_operand, + pir::Type out_type) { + argument.AddInput(l_operand); + argument.AddInput(r_operand); + argument.AddOutput(out_type); +} + +void SameOperandsAndResultShapeTraitOp2::Build( + pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value l_operand, + pir::Value r_operand) { + argument.AddInput(l_operand); + argument.AddInput(r_operand); +} + +void SameOperandsAndResultShapeTraitOp3::Build( + pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value l_operand, + pir::Value r_operand, + pir::Type out_type) { + argument.AddInput(l_operand); + argument.AddInput(r_operand); + argument.AddOutput(out_type); +} + +void SameOperandsElementTypeTraitOp2::Build( + pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value l_operand, + pir::Value r_operand, + pir::Type out_type) { + argument.AddInput(l_operand); + argument.AddInput(r_operand); + argument.AddOutput(out_type); +} + +void SameOperandsAndResultElementTypeTraitOp2::Build( + pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value l_operand, + pir::Value r_operand) { + argument.AddInput(l_operand); + argument.AddInput(r_operand); +} + +void SameOperandsAndResultElementTypeTraitOp3::Build( + pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value l_operand, + pir::Value r_operand, + pir::Type out_type1, + pir::Type out_type2) { + argument.AddInput(l_operand); + argument.AddInput(r_operand); + argument.AddOutput(out_type1); + argument.AddOutput(out_type2); +} + +void SameOperandsAndResultTypeTraitOp2::Build( + pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value l_operand, + pir::Value r_operand) { + argument.AddInput(l_operand); + argument.AddInput(r_operand); +} + +void SameOperandsAndResultTypeTraitOp3::Build( + pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value l_operand, + pir::Value r_operand, + pir::Type out_type1, + pir::Type out_type2) { + argument.AddInput(l_operand); + argument.AddInput(r_operand); + argument.AddOutput(out_type1); + argument.AddOutput(out_type2); +} + } // namespace test IR_DEFINE_EXPLICIT_TYPE_ID(test::RegionOp) IR_DEFINE_EXPLICIT_TYPE_ID(test::BranchOp) IR_DEFINE_EXPLICIT_TYPE_ID(test::Operation1) IR_DEFINE_EXPLICIT_TYPE_ID(test::Operation2) +IR_DEFINE_EXPLICIT_TYPE_ID(test::TraitExampleOp) +IR_DEFINE_EXPLICIT_TYPE_ID(test::SameOperandsShapeTraitOp1) +IR_DEFINE_EXPLICIT_TYPE_ID(test::SameOperandsShapeTraitOp2) +IR_DEFINE_EXPLICIT_TYPE_ID(test::SameOperandsAndResultShapeTraitOp1) +IR_DEFINE_EXPLICIT_TYPE_ID(test::SameOperandsAndResultShapeTraitOp2) +IR_DEFINE_EXPLICIT_TYPE_ID(test::SameOperandsAndResultShapeTraitOp3) +IR_DEFINE_EXPLICIT_TYPE_ID(test::SameOperandsElementTypeTraitOp1) +IR_DEFINE_EXPLICIT_TYPE_ID(test::SameOperandsElementTypeTraitOp2) +IR_DEFINE_EXPLICIT_TYPE_ID(test::SameOperandsAndResultElementTypeTraitOp1) +IR_DEFINE_EXPLICIT_TYPE_ID(test::SameOperandsAndResultElementTypeTraitOp2) +IR_DEFINE_EXPLICIT_TYPE_ID(test::SameOperandsAndResultElementTypeTraitOp3) +IR_DEFINE_EXPLICIT_TYPE_ID(test::SameOperandsAndResultTypeTraitOp1) +IR_DEFINE_EXPLICIT_TYPE_ID(test::SameOperandsAndResultTypeTraitOp2) +IR_DEFINE_EXPLICIT_TYPE_ID(test::SameOperandsAndResultTypeTraitOp3) diff --git a/test/cpp/pir/tools/test_op.h b/test/cpp/pir/tools/test_op.h index 8d4ccd49a38ed..98f01db37614d 100644 --- a/test/cpp/pir/tools/test_op.h +++ b/test/cpp/pir/tools/test_op.h @@ -17,6 +17,7 @@ #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/op_base.h" +#include "paddle/pir/core/op_trait.h" #include "paddle/pir/core/operation_utils.h" #include "test/cpp/pir/tools/test_interface.h" #include "test/cpp/pir/tools/test_trait.h" @@ -58,7 +59,7 @@ class Operation1 : public pir::Op { using Op::Op; static const char *name() { return "test.operation1"; } static constexpr uint32_t attributes_num = 2; - static const char *attributes_name[attributes_num]; // NOLINT + static const char *attributes_name[attributes_num]; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument); // NOLINT void Verify() const; @@ -71,16 +72,269 @@ class Operation2 using Op::Op; static const char *name() { return "test.operation2"; } static constexpr uint32_t attributes_num = 0; - static constexpr const char **attributes_name = nullptr; // NOLINT - static void Build(pir::Builder &builder, // NOLINT - pir::OperationArgument &argument) {} // NOLINT + static constexpr const char **attributes_name = nullptr; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument) {} // NOLINT void Verify() const {} static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; } }; +// Define TraitExampleOp. +class TraitExampleOp + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "test.trait_example_op"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value l_operand, + pir::Value r_operand, + pir::Type out_type); + void Verify() const {} +}; + +// Define SameOperandsShapeTraitOp1. +class SameOperandsShapeTraitOp1 + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "test.same_operands_shape_op1"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument) {} // NOLINT + void Verify() const {} +}; + +// Define SameOperandsShapeTraitOp2. +class SameOperandsShapeTraitOp2 + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "test.same_operands_shape_op2"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value l_operand, + pir::Value r_operand, + pir::Type out_type); + void Verify() const {} +}; + +// Define SameOperandsAndResultShapeTraitOp1. +class SameOperandsAndResultShapeTraitOp1 + : public pir::Op { + public: + using Op::Op; + static const char *name() { + return "test.same_operands_and_result_shape_op1"; + } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument) {} // NOLINT + void Verify() const {} +}; + +// Define SameOperandsAndResultShapeTraitOp2. +class SameOperandsAndResultShapeTraitOp2 + : public pir::Op { + public: + using Op::Op; + static const char *name() { + return "test.same_operands_and_result_shape_op2"; + } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value l_operand, + pir::Value r_operand); + void Verify() const {} +}; + +// Define SameOperandsAndResultShapeTraitOp3. +class SameOperandsAndResultShapeTraitOp3 + : public pir::Op { + public: + using Op::Op; + static const char *name() { + return "test.same_operands_and_result_shape_op3"; + } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value l_operand, + pir::Value r_operand, + pir::Type out_type); + void Verify() const {} +}; + +// Define SameOperandsElementTypeTraitOp1. +class SameOperandsElementTypeTraitOp1 + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "test.same_operands_element_type_op1"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument) {} // NOLINT + void Verify() const {} +}; + +// Define SameOperandsElementTypeTraitOp2. +class SameOperandsElementTypeTraitOp2 + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "test.same_operands_element_type_op1"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value l_operand, + pir::Value r_operand, + pir::Type out_type); + void Verify() const {} +}; + +// Define SameOperandsAndResultElementTypeTraitOp1. +class SameOperandsAndResultElementTypeTraitOp1 + : public pir::Op { + public: + using Op::Op; + static const char *name() { + return "test.same_operands_and_result_element_type_op1"; + } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument) {} // NOLINT + void Verify() const {} +}; + +// Define SameOperandsAndResultElementTypeTraitOp2. +class SameOperandsAndResultElementTypeTraitOp2 + : public pir::Op { + public: + using Op::Op; + static const char *name() { + return "test.same_operands_and_result_element_type_op2"; + } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value l_operand, + pir::Value r_operand); + void Verify() const {} +}; + +// Define SameOperandsAndResultElementTypeTraitOp3. +class SameOperandsAndResultElementTypeTraitOp3 + : public pir::Op { + public: + using Op::Op; + static const char *name() { + return "test.same_operands_and_result_element_type_op3"; + } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value l_operand, + pir::Value r_operand, + pir::Type out_type1, + pir::Type out_type2); + void Verify() const {} +}; + +// Define SameOperandsAndResultTypeTraitOp1. +class SameOperandsAndResultTypeTraitOp1 + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "test.same_operands_and_result_type_op1"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument) {} // NOLINT + void Verify() const {} +}; + +// Define SameOperandsAndResultTypeTraitOp2. +class SameOperandsAndResultTypeTraitOp2 + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "test.same_operands_and_result_type_op2"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value l_operand, + pir::Value r_operand); + void Verify() const {} +}; + +// Define SameOperandsAndResultTypeTraitOp3. +class SameOperandsAndResultTypeTraitOp3 + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "test.same_operands_and_result_type_op3"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value l_operand, + pir::Value r_operand, + pir::Type out_type1, + pir::Type out_type2); + + void Verify() const {} +}; + } // namespace test IR_DECLARE_EXPLICIT_TYPE_ID(test::RegionOp) IR_DECLARE_EXPLICIT_TYPE_ID(test::BranchOp) IR_DECLARE_EXPLICIT_TYPE_ID(test::Operation1) IR_DECLARE_EXPLICIT_TYPE_ID(test::Operation2) +IR_DECLARE_EXPLICIT_TYPE_ID(test::TraitExampleOp) +IR_DECLARE_EXPLICIT_TYPE_ID(test::SameOperandsShapeTraitOp1) +IR_DECLARE_EXPLICIT_TYPE_ID(test::SameOperandsShapeTraitOp2) +IR_DECLARE_EXPLICIT_TYPE_ID(test::SameOperandsAndResultShapeTraitOp1) +IR_DECLARE_EXPLICIT_TYPE_ID(test::SameOperandsAndResultShapeTraitOp2) +IR_DECLARE_EXPLICIT_TYPE_ID(test::SameOperandsAndResultShapeTraitOp3) +IR_DECLARE_EXPLICIT_TYPE_ID(test::SameOperandsElementTypeTraitOp1) +IR_DECLARE_EXPLICIT_TYPE_ID(test::SameOperandsElementTypeTraitOp2) +IR_DECLARE_EXPLICIT_TYPE_ID(test::SameOperandsAndResultElementTypeTraitOp1) +IR_DECLARE_EXPLICIT_TYPE_ID(test::SameOperandsAndResultElementTypeTraitOp2) +IR_DECLARE_EXPLICIT_TYPE_ID(test::SameOperandsAndResultElementTypeTraitOp3) +IR_DECLARE_EXPLICIT_TYPE_ID(test::SameOperandsAndResultTypeTraitOp1) +IR_DECLARE_EXPLICIT_TYPE_ID(test::SameOperandsAndResultTypeTraitOp2) +IR_DECLARE_EXPLICIT_TYPE_ID(test::SameOperandsAndResultTypeTraitOp3)