fix crash when elementwise inputs have different ranks
When two INPUT tensors have different ranks, AlignPermuteVectorForElementWise()
will force-align them and crash.

Type: Bug fix

Signed-off-by: Chen <jack.chen@verisilicon.com>
Chen committed Dec 6, 2023
1 parent 5173979 commit 9cbf37e
Showing 2 changed files with 53 additions and 8 deletions.
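
The crash is easiest to see end to end. Below is a minimal repro sketch, assuming the usual TIM-VX graph API; the shapes, tensor names, and exact crash site are illustrative, not taken from the commit. An elementwise Add whose two non-constant inputs differ in rank previously failed inside AlignPermuteVectorForElementWise() during layout inference.

    #include "tim/transform/layout_inference.h"
    #include "tim/vx/context.h"
    #include "tim/vx/graph.h"
    #include "tim/vx/ops/elementwise.h"

    int main() {
      auto ctx = tim::vx::Context::Create();
      auto graph = ctx->CreateGraph();

      // Rank-4 and rank-1 inputs, broadcast-compatible on the lowest dim.
      tim::vx::TensorSpec a_spec(tim::vx::DataType::FLOAT32, {2, 3, 4, 1},
                                 tim::vx::TensorAttribute::INPUT);
      tim::vx::TensorSpec b_spec(tim::vx::DataType::FLOAT32, {2},
                                 tim::vx::TensorAttribute::INPUT);
      tim::vx::TensorSpec out_spec(tim::vx::DataType::FLOAT32, {2, 3, 4, 1},
                                   tim::vx::TensorAttribute::OUTPUT);

      auto a = graph->CreateTensor(a_spec);
      auto b = graph->CreateTensor(b_spec);
      auto out = graph->CreateTensor(out_spec);

      auto add = graph->CreateOperation<tim::vx::ops::Add>();
      (*add).BindInputs({a, b}).BindOutputs({out});

      // Before this fix, the rank mismatch between a (rank 4) and b (rank 1)
      // crashed here while the transform aligned their permute vectors.
      auto transformed = tim::transform::LayoutInference(graph, ctx);
      return transformed.first ? 0 : 1;
    }
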
48 changes: 48 additions & 0 deletions src/tim/transform/ops/elementwise_layout_inference.h
@@ -42,6 +42,30 @@ class ElementWiseLayoutInfer : public OpLayoutInfer {
 
   void OnInputs(
       std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
+    auto in_0 = op_->impl()->InputsTensor()[0];
+    auto in_1 = op_->impl()->InputsTensor()[1];
+    std::shared_ptr<tim::vx::Tensor> short_tensor =
+        in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0;
+    std::shared_ptr<tim::vx::Tensor> long_tensor =
+        in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0;
+    if (in_0->GetSpec().attr_ != tim::vx::CONSTANT &&
+        in_1->GetSpec().attr_ != tim::vx::CONSTANT &&
+        in_0->GetShape().size() != in_1->GetShape().size()) {
+      auto pv_long = context_->GetPermuteVector(long_tensor);
+      auto pv_short = context_->GetPermuteVector(short_tensor);
+      auto rank_long = pv_long->Rank();
+      auto rank_short = pv_short->Rank();
+      auto expanded_pv = MakeShared(rank_long);
+      // Ranks differ: expand the short pv to the long tensor's rank.
+      for (uint32_t i = 0; i < rank_short; ++i) {
+        expanded_pv->At(i) = pv_short->At(i);  // low dims take the short pv
+      }
+      auto expanded_shape =
+          GetExpandedShape(long_tensor->GetShape(), short_tensor->GetShape());
+      short_tensor->GetSpec().SetShape(expanded_shape);
+
+      context_->SetPermuteVector(short_tensor, expanded_pv);  // register the expanded pv
+    }
     auto required_pv = AlignPermuteVectorForElementWise();
     auto elementwise = context_->infer_graph_->CreateOperation<OpType>();
     for (const auto& i_src : op_->impl()->InputsTensor()) {
@@ -63,6 +87,30 @@ class MultiplyLayoutInfer : public OpLayoutInfer {
 
   void OnInputs(
       std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
+    auto in_0 = op_->impl()->InputsTensor()[0];
+    auto in_1 = op_->impl()->InputsTensor()[1];
+    std::shared_ptr<tim::vx::Tensor> short_tensor =
+        in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0;
+    std::shared_ptr<tim::vx::Tensor> long_tensor =
+        in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0;
+    if (in_0->GetSpec().attr_ != tim::vx::CONSTANT &&
+        in_1->GetSpec().attr_ != tim::vx::CONSTANT &&
+        in_0->GetShape().size() != in_1->GetShape().size()) {
+      auto pv_long = context_->GetPermuteVector(long_tensor);
+      auto pv_short = context_->GetPermuteVector(short_tensor);
+      auto rank_long = pv_long->Rank();
+      auto rank_short = pv_short->Rank();
+      auto expanded_pv = MakeShared(rank_long);
+      // Ranks differ: expand the short pv to the long tensor's rank.
+      for (uint32_t i = 0; i < rank_short; ++i) {
+        expanded_pv->At(i) = pv_short->At(i);  // low dims take the short pv
+      }
+      auto expanded_shape =
+          GetExpandedShape(long_tensor->GetShape(), short_tensor->GetShape());
+      short_tensor->GetSpec().SetShape(expanded_shape);
+
+      context_->SetPermuteVector(short_tensor, expanded_pv);  // register the expanded pv
+    }
     auto required_pv = AlignPermuteVectorForElementWise();
     auto multiply =
         context_->infer_graph_->CreateOperation<tim::vx::ops::Multiply>(
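
Both hunks perform the same pre-alignment, so one worked illustration covers them. The values below are hypothetical, and plain std::vector stands in for the transform's PermuteVector type: the short input's permute vector is copied into the low dims of an identity vector of the long rank, and its shape is padded with trailing 1s.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      std::vector<uint32_t> pv_short = {0};  // permute vector of the rank-1 input
      const uint32_t rank_long = 4;          // rank of the other input

      // MakeShared(rank_long) yields an identity permute vector {0, 1, 2, 3}.
      std::vector<uint32_t> expanded_pv(rank_long);
      for (uint32_t i = 0; i < rank_long; ++i) expanded_pv[i] = i;
      // The fix copies the short pv into the low dims, as in the hunks above.
      for (size_t i = 0; i < pv_short.size(); ++i) expanded_pv[i] = pv_short[i];

      // GetExpandedShape pads the short shape with trailing 1s; TIM-VX stores
      // the lowest dimension first, so trailing 1s are the broadcast dims.
      std::vector<uint32_t> long_shape = {5, 7, 7, 1};
      std::vector<uint32_t> expanded_shape = {5};  // short input's shape
      while (expanded_shape.size() < long_shape.size()) expanded_shape.push_back(1);

      for (auto d : expanded_shape) std::cout << d << ' ';  // prints: 5 1 1 1
      std::cout << '\n';
      return 0;
    }

After this step AlignPermuteVectorForElementWise() sees two inputs of equal rank, which is what its alignment logic assumes.
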
13 changes: 5 additions & 8 deletions src/tim/transform/ops/op_layout_inference.cc
@@ -297,14 +297,11 @@ void OpLayoutInfer::ReverseInputsPermuteVector() {
 std::vector<uint32_t> OpLayoutInfer::GetExpandedShape(
     const std::vector<uint32_t>& ref_shape,
     const std::vector<uint32_t>& origin_shape) {
-  std::vector<uint32_t> expanded_shape;
-  for (uint32_t i = 0, j = 0; i < ref_shape.size(); ++i) {
-    if (ref_shape[i] == origin_shape[j] && j < origin_shape.size()) {
-      expanded_shape.push_back(origin_shape[j]);
-      ++j;
-    } else {
-      expanded_shape.push_back(1);
-    }
+  std::vector<uint32_t> expanded_shape(origin_shape);
+  auto ref_rank = ref_shape.size();
+  auto origin_rank = origin_shape.size();
+  for (uint32_t i = 0; i < ref_rank; ++i) {
+    if (i >= origin_rank) expanded_shape.push_back(1);
   }
   return expanded_shape;
 }
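
Worth noting why the old loop could crash: it evaluated ref_shape[i] == origin_shape[j] before checking j < origin_shape.size(), so once j reached the end of origin_shape the subscript read out of bounds. The rewrite keeps origin_shape intact and only appends trailing 1s until the ranks match; for example (shapes hypothetical), GetExpandedShape({4, 4, 16, 1}, {4, 4}) now returns {4, 4, 1, 1}, where the old code would have read origin_shape[2] past the end.
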
