diff --git a/src/tim/transform/ops/activation_layout_inference.h b/src/tim/transform/ops/activation_layout_inference.h index d2ea49ac7..11659bd9b 100644 --- a/src/tim/transform/ops/activation_layout_inference.h +++ b/src/tim/transform/ops/activation_layout_inference.h @@ -73,40 +73,56 @@ class PReluLayoutInfer : public OpLayoutInfer { auto slope_shape = src_slope->GetShape(); auto input_pv = context_->GetPermuteVector(src_input); std::vector boardcast_shape; - for (uint32_t i = 0; i < input_shape.size(); ++i) { - if (i < slope_shape.size()) { - boardcast_shape.push_back(slope_shape[i]); - } else { - boardcast_shape.push_back(1); + if (slope_shape.size() != 1) { // Need to be transposed along with the input + for (uint32_t i = 0; i < input_shape.size(); ++i) { + if (i < slope_shape.size()) { + boardcast_shape.push_back(slope_shape[i]); + } else { + boardcast_shape.push_back(1); + } } - } - if (src_slope->IsConstTensor()) { - std::vector dataRef(src_slope->GetSpec().GetByteSize()); - src_slope->CopyDataFromTensor(dataRef.data()); - auto infer_slope_spec = src_slope->GetSpec(); - infer_slope_spec.SetShape(boardcast_shape); - auto infer_slope = context_->infer_graph_->CreateTensor( - infer_slope_spec, (const void*)dataRef.data()); + if (src_slope->IsConstTensor()) { + std::vector dataRef(src_slope->GetSpec().GetByteSize()); + src_slope->CopyDataFromTensor(dataRef.data()); + auto infer_slope_spec = src_slope->GetSpec(); + infer_slope_spec.SetShape(boardcast_shape); + auto infer_slope = context_->infer_graph_->CreateTensor( + infer_slope_spec, (const void*)dataRef.data()); - if (!input_pv->IsAligned()) { - //The dimension of slop is already the same as input, directly use input_pv to convert - auto out_slope = PermuteConstTensor(infer_slope, input_pv); - context_->UpdateTensorMap(src_slope, out_slope); + if (!input_pv->IsAligned()) { + //The dimension of slop is already the same as input, directly use input_pv to convert + auto out_slope = PermuteConstTensor(infer_slope, input_pv); + context_->UpdateTensorMap(src_slope, out_slope); + } else { + context_->UpdateTensorMap(src_slope, infer_slope); + } } else { + auto infer_slope_spec = src_slope->GetSpec().AsTransientSpec(); + auto reshape_out = + context_->infer_graph_->CreateTensor(infer_slope_spec); + boardcast_shape = + MapMultipleAxis(input_pv->AsStdVec(), boardcast_shape); + auto reshape = + context_->infer_graph_->CreateOperation( + boardcast_shape); + (*reshape) + .BindInput(context_->GetMapedTensor(src_slope)) + .BindOutput(reshape_out); + context_->UpdateTensorMap(src_slope, reshape_out); + } + context_->SetPermuteVector(src_slope, input_pv); + } else { // 1d slope tensor need not transpose + if (src_slope->IsConstTensor()) { + std::vector dataRef(src_slope->GetSpec().GetByteSize()); + src_slope->CopyDataFromTensor(dataRef.data()); + auto infer_slope_spec = src_slope->GetSpec(); + auto infer_slope = context_->infer_graph_->CreateTensor( + infer_slope_spec, (const void*)dataRef.data()); context_->UpdateTensorMap(src_slope, infer_slope); + context_->SetPermuteVector(src_slope, MakeShared(1)); } - } else { - auto infer_slope_spec = src_slope->GetSpec().AsTransientSpec(); - auto reshape_out = context_->infer_graph_->CreateTensor(infer_slope_spec); - boardcast_shape = MapMultipleAxis(input_pv->AsStdVec(), boardcast_shape); - auto reshape = context_->infer_graph_->CreateOperation(boardcast_shape); - (*reshape) - .BindInput(context_->GetMapedTensor(src_slope)) - .BindOutput(reshape_out); - context_->UpdateTensorMap(src_slope, reshape_out); } - context_->SetPermuteVector(src_slope, input_pv); auto axis = MapAxis(input_pv->AsStdVec(), op_->impl()->node()->nn_param.prelu.axis);