diff --git a/src/tim/transform/ops/elementwise_layout_inference.h b/src/tim/transform/ops/elementwise_layout_inference.h index 102609d5d..e4a96eeb5 100644 --- a/src/tim/transform/ops/elementwise_layout_inference.h +++ b/src/tim/transform/ops/elementwise_layout_inference.h @@ -42,6 +42,30 @@ class ElementWiseLayoutInfer : public OpLayoutInfer { void OnInputs( std::vector>& next_tensors) override { + auto in_0 = op_->impl()->InputsTensor()[0]; + auto in_1 = op_->impl()->InputsTensor()[1]; + std::shared_ptr short_tensor = + in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0; + std::shared_ptr long_tensor = + in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0; + if (in_0->GetSpec().attr_ != tim::vx::CONSTANT && + in_1->GetSpec().attr_ != tim::vx::CONSTANT && + in_0->GetShape().size() != in_1->GetShape().size()) { + auto pv_long = context_->GetPermuteVector(long_tensor); + auto pv_short = context_->GetPermuteVector(short_tensor); + auto rank_long = pv_long->Rank(); + auto rank_short = pv_short->Rank(); + auto expanded_pv = MakeShared(rank_long); + // if different size, expand short pv to long pv + for (uint32_t i = 0; i < rank_short; ++i) { + expanded_pv->At(i) = pv_short->At(i); // replace low dims with short pv + } + auto expanded_shape = + GetExpandedShape(long_tensor->GetShape(), short_tensor->GetShape()); + short_tensor->GetSpec().SetShape(expanded_shape); + + context_->SetPermuteVector(short_tensor, expanded_pv); // set new expand pv + } auto required_pv = AlignPermuteVectorForElementWise(); auto elementwise = context_->infer_graph_->CreateOperation(); for (const auto& i_src : op_->impl()->InputsTensor()) { @@ -63,6 +87,30 @@ class MultiplyLayoutInfer : public OpLayoutInfer { void OnInputs( std::vector>& next_tensors) override { + auto in_0 = op_->impl()->InputsTensor()[0]; + auto in_1 = op_->impl()->InputsTensor()[1]; + std::shared_ptr short_tensor = + in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0; + std::shared_ptr long_tensor = + in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0; + if (in_0->GetSpec().attr_ != tim::vx::CONSTANT && + in_1->GetSpec().attr_ != tim::vx::CONSTANT && + in_0->GetShape().size() != in_1->GetShape().size()) { + auto pv_long = context_->GetPermuteVector(long_tensor); + auto pv_short = context_->GetPermuteVector(short_tensor); + auto rank_long = pv_long->Rank(); + auto rank_short = pv_short->Rank(); + auto expanded_pv = MakeShared(rank_long); + // if different size, expand short pv to long pv + for (uint32_t i = 0; i < rank_short; ++i) { + expanded_pv->At(i) = pv_short->At(i); // replace low dims with short pv + } + auto expanded_shape = + GetExpandedShape(long_tensor->GetShape(), short_tensor->GetShape()); + short_tensor->GetSpec().SetShape(expanded_shape); + + context_->SetPermuteVector(short_tensor, expanded_pv); // set new expand pv + } auto required_pv = AlignPermuteVectorForElementWise(); auto multiply = context_->infer_graph_->CreateOperation( diff --git a/src/tim/transform/ops/op_layout_inference.cc b/src/tim/transform/ops/op_layout_inference.cc index 1d1f0e83a..6efa405f1 100644 --- a/src/tim/transform/ops/op_layout_inference.cc +++ b/src/tim/transform/ops/op_layout_inference.cc @@ -297,14 +297,11 @@ void OpLayoutInfer::ReverseInputsPermuteVector() { std::vector OpLayoutInfer::GetExpandedShape( const std::vector& ref_shape, const std::vector& origin_shape) { - std::vector expanded_shape; - for (uint32_t i = 0, j = 0; i < ref_shape.size(); ++i) { - if (ref_shape[i] == origin_shape[j] && j < origin_shape.size()) { - expanded_shape.push_back(origin_shape[j]); - ++j; - } else { - expanded_shape.push_back(1); - } + std::vector expanded_shape(origin_shape); + auto ref_rank = ref_shape.size(); + auto origin_rank = origin_shape.size(); + for (uint32_t i = 0; i < ref_rank; ++i) { + if (i >= origin_rank) expanded_shape.push_back(1); } return expanded_shape; }