Skip to content

Commit

Permalink
[PIR] Add op_trait and type_util (PaddlePaddle#57580)
Browse files Browse the repository at this point in the history
* op_trait and type_util

* add unit test

* add expect_throw for ci converage

* fix for win ci
  • Loading branch information
zhangbopd authored and jiahy0825 committed Oct 16, 2023
1 parent 9fc980e commit 34e5a14
Show file tree
Hide file tree
Showing 10 changed files with 1,328 additions and 12 deletions.
196 changes: 196 additions & 0 deletions paddle/pir/core/op_trait.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/pir/core/op_trait.h"
#include "paddle/pir/core/enforce.h"
#include "paddle/pir/core/type_util.h"

namespace pir::detail {

void VerifySameOperandsShapeTrait(Operation *op) {
VLOG(4) << "Verify SameOperandsShapeTrait for : " << op->name();

IR_ENFORCE(op->num_operands() > 0,
"Op %s with SameOperandsShapeTrait requires at least 1 operands, "
"but got %u operands.",
op->name(),
op->num_operands());

std::vector<pir::OpOperand> operands = op->operands();
std::vector<pir::Type> types;
std::for_each(operands.begin(), operands.end(), [&types](pir::OpOperand op) {
types.push_back(op.type());
});

IR_ENFORCE(VerifyCompatibleShapes(types),
"Op %s with SameOperandsShapeTrait requires the same shape for "
"all operands.",
op->name());
}

void VerifySameOperandsAndResultShapeTrait(Operation *op) {
VLOG(4) << "Verify SameOperandsAndResultShapeTrait for : " << op->name();

IR_ENFORCE(op->num_operands() > 0,
"Op %s with SameOperandsAndResultShapeTrait requires at least 1 "
"operands, but got %u operands.",
op->name(),
op->num_operands());

IR_ENFORCE(op->num_results() > 0,
"Op %s with SameOperandsAndResultShapeTrait requires at least 1 "
"results, but got %u results.",
op->name(),
op->num_results());

std::vector<pir::OpOperand> operands = op->operands();
std::vector<pir::OpResult> results = op->results();

std::vector<pir::Type> types;

std::for_each(operands.begin(), operands.end(), [&types](pir::OpOperand op) {
types.push_back(op.type());
});

std::for_each(results.begin(), results.end(), [&types](pir::OpResult op) {
types.push_back(op.type());
});

IR_ENFORCE(VerifyCompatibleShapes(types),
"Op %s with SameOperandsAndResultShapeTrait requires compatible "
"shapes for operands and results.",
op->name());
}

void VerifySameOperandsElementTypeTrait(Operation *op) {
VLOG(4) << "Verify SameOperandsElementTypeTrait for : " << op->name();

IR_ENFORCE(op->num_operands() > 0,
"Op %s with SameOperandsElementTypeTrait requires at least 1 "
"operands, but got %u operands.",
op->name(),
op->num_operands());

auto elementType = GetElementTypeOrSelf(op->result(0).type());
for (auto operand : op->operands()) {
IR_ENFORCE(GetElementTypeOrSelf(operand.type()) == elementType,
"Op %s with SameOperandsElementTypeTrait requires the same "
"element type for all operands.",
op->name());
}
}

void VerifySameOperandsAndResultElementTypeTrait(Operation *op) {
VLOG(4) << "Verify SameOperandsAndResultElementTypeTrait for : "
<< op->name();

IR_ENFORCE(op->num_operands() > 0,
"Op %s with SameOperandsAndResultElementTypeTrait requires at "
"least 1 operands, but got %u operands.",
op->name(),
op->num_operands());

IR_ENFORCE(op->num_results() > 0,
"Op %s with SameOperandsAndResultElementTypeTrait requires at "
"least 1 results, but got %u results.",
op->name(),
op->num_results());

auto elementType = GetElementTypeOrSelf(op->result(0).type());

// Verify result element type matches first result's element type.
for (auto result : op->results()) {
IR_ENFORCE(GetElementTypeOrSelf(result.type()) == elementType,
"Op %s with SameOperandsAndResultElementTypeTrait requires the "
"same element type for all operands and results.",
op->name());
}

// Verify operand's element type matches first result's element type.
for (auto operand : op->operands()) {
IR_ENFORCE(GetElementTypeOrSelf(operand.type()) == elementType,
"Op %s with SameOperandsAndResultElementTypeTrait requires the "
"same element type for all operands and results.",
op->name());
}
}

void VerifySameOperandsAndResultTypeTrait(Operation *op) {
VLOG(4) << "Verify SameOperandsAndResultTypeTrait for : " << op->name();

IR_ENFORCE(op->num_operands() > 0,
"Op %s with SameOperandsAndResultTypeTrait requires at least 1 "
"operands, but got %u operands.",
op->name(),
op->num_operands());

IR_ENFORCE(op->num_results() > 0,
"Op %s with SameOperandsAndResultTypeTrait requires at least 1 "
"results, but got %u results.",
op->name(),
op->num_results());

auto type = op->result(0).type();
auto elementType = GetElementTypeOrSelf(type);

for (auto result : op->results()) {
IR_ENFORCE(GetElementTypeOrSelf(result.type()) == elementType,
"Op %s with SameOperandsAndResultTypeTrait requires the same "
"type for all operands and results.",
op->name());

IR_ENFORCE(VerifyCompatibleShape(result.type(), type),
"Op %s with SameOperandsAndResultTypeTrait requires the same "
"type for all operands and results.",
op->name());
}

for (auto operand : op->operands()) {
IR_ENFORCE(GetElementTypeOrSelf(operand.type()) == elementType,
"Op %s with SameOperandsAndResultTypeTrait requires the same "
"type for all operands and results.",
op->name());

IR_ENFORCE(VerifyCompatibleShape(operand.type(), type),
"Op %s with SameOperandsAndResultTypeTrait requires the same "
"type for all operands and results.",
op->name());
}
}

void VerifySameTypeOperandsTrait(Operation *op) {
VLOG(4) << "Verify SameTypeOperandsTrait for : " << op->name();

// For zero or only one operand.
unsigned operand_nums = op->num_operands();
if (operand_nums < 2) return;

auto type = op->operand(0).type();

for (auto operand : op->operands()) {
IR_ENFORCE(operand.type() == type,
"Op %s with SameTypeOperandsTrait requires all operands to have "
"the same type.",
op->name());
}
}

} // namespace pir::detail

IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsShapeTrait)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultShapeTrait)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsElementTypeTrait)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultElementTypeTrait)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultTypeTrait)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::SameTypeOperandsTrait)
121 changes: 121 additions & 0 deletions paddle/pir/core/op_trait.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/pir/core/op_base.h"

namespace pir {

namespace detail {
void VerifySameOperandsShapeTrait(Operation *op);
void VerifySameOperandsAndResultShapeTrait(Operation *op);
void VerifySameOperandsElementTypeTrait(Operation *op);
void VerifySameOperandsAndResultElementTypeTrait(Operation *op);
void VerifySameOperandsAndResultTypeTrait(Operation *op);
void VerifySameTypeOperandsTrait(Operation *op);
} // namespace detail

///
/// \brief Provides verification for ops that are known to have the
/// same operand shape.
///
class SameOperandsShapeTrait : public pir::OpTraitBase<SameOperandsShapeTrait> {
public:
explicit SameOperandsShapeTrait(pir::Operation *op)
: pir::OpTraitBase<SameOperandsShapeTrait>(op) {}
static void Verify(Operation *op) {
return detail::VerifySameOperandsShapeTrait(op);
}
};

///
/// \brief Provides verification for ops that are known to have the
/// same operand and result shape.
///
class SameOperandsAndResultShapeTrait
: public pir::OpTraitBase<SameOperandsAndResultShapeTrait> {
public:
explicit SameOperandsAndResultShapeTrait(pir::Operation *op)
: pir::OpTraitBase<SameOperandsAndResultShapeTrait>(op) {}
static void Verify(Operation *op) {
return detail::VerifySameOperandsAndResultShapeTrait(op);
}
};

///
/// \brief Provides verification for ops that are known to have the
/// same operand element type (or the type itself if it is scalar).
///
class SameOperandsElementTypeTrait
: public pir::OpTraitBase<SameOperandsElementTypeTrait> {
public:
explicit SameOperandsElementTypeTrait(pir::Operation *op)
: pir::OpTraitBase<SameOperandsElementTypeTrait>(op) {}
static void Verify(Operation *op) {
return detail::VerifySameOperandsElementTypeTrait(op);
}
};

///
/// \brief Provides verification for ops that are known to have the
/// same operand and result element type (or the type itself if it is scalar).
///
class SameOperandsAndResultElementTypeTrait
: public pir::OpTraitBase<SameOperandsAndResultElementTypeTrait> {
public:
explicit SameOperandsAndResultElementTypeTrait(pir::Operation *op)
: pir::OpTraitBase<SameOperandsAndResultElementTypeTrait>(op) {}
static void Verify(Operation *op) {
return detail::VerifySameOperandsAndResultElementTypeTrait(op);
}
};

///
/// \brief Provides verification for ops that are known to have the
/// same operand and result type. It Subsumes both
/// SameOperandsAndResultShapeTrait and SameOperandsAndResultElementTypeTrait
///
class SameOperandsAndResultTypeTrait
: public pir::OpTraitBase<SameOperandsAndResultTypeTrait> {
public:
explicit SameOperandsAndResultTypeTrait(pir::Operation *op)
: pir::OpTraitBase<SameOperandsAndResultTypeTrait>(op) {}

static void Verify(Operation *op) {
return detail::VerifySameOperandsAndResultTypeTrait(op);
}
};

///
/// \brief Provides verification that all operands of the specified op have the
/// same type.
///
class SameTypeOperandsTrait : public pir::OpTraitBase<SameTypeOperandsTrait> {
public:
explicit SameTypeOperandsTrait(pir::Operation *op)
: pir::OpTraitBase<SameTypeOperandsTrait>(op) {}
static void Verify(Operation *op) {
return detail::VerifySameTypeOperandsTrait(op);
}
};

} // namespace pir

IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsShapeTrait)
IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultShapeTrait)
IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsElementTypeTrait)
IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultElementTypeTrait)
IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameOperandsAndResultTypeTrait)
IR_DECLARE_EXPLICIT_TYPE_ID(pir::SameTypeOperandsTrait)
Loading

0 comments on commit 34e5a14

Please sign in to comment.