From 18d8266e06607b1317d0b296039fc606a51d0837 Mon Sep 17 00:00:00 2001 From: Feiyue Chen Date: Wed, 15 May 2024 08:23:53 +0000 Subject: [PATCH] Added ifdef marco for some later added ops Type: Code Improvement Signed-off-by: Feiyue Chen --- include/tim/vx/ops/scatternd_onnx_v16.h | 3 +++ include/tim/vx/ops/simple_operations.h | 27 ++++++++++++++----- src/tim/transform/layout_inference.cc | 12 ++++++++- .../ops/simple_ops_layout_inference.h | 26 ++++++++++++++---- src/tim/vx/ops/scatternd_onnx_v16.cc | 4 ++- src/tim/vx/ops/scatternd_onnx_v16_test.cc | 2 ++ src/tim/vx/ops/simple_operations.cc | 26 ++++++++++++++---- src/tim/vx/ops/simple_operations_test.cc | 6 ++++- 8 files changed, 87 insertions(+), 19 deletions(-) diff --git a/include/tim/vx/ops/scatternd_onnx_v16.h b/include/tim/vx/ops/scatternd_onnx_v16.h index 9698e9a0c..024c2e265 100644 --- a/include/tim/vx/ops/scatternd_onnx_v16.h +++ b/include/tim/vx/ops/scatternd_onnx_v16.h @@ -25,6 +25,8 @@ #define TIM_VX_OPS_SCATTERND_ONNX_V16_H_ #include "tim/vx/builtin_op.h" +#ifdef VSI_FEAT_OP_SCATTER_ND_UPDATE + namespace tim { namespace vx { namespace ops { @@ -57,4 +59,5 @@ class ScatterND_ONNX_V16 : public BuiltinOp { } // namespace vx } // namespace tim +#endif #endif /* TIM_VX_OPS_SCATTERND_ONNX_V16_H_ */ diff --git a/include/tim/vx/ops/simple_operations.h b/include/tim/vx/ops/simple_operations.h index ea26dcfbf..2f4d52611 100644 --- a/include/tim/vx/ops/simple_operations.h +++ b/include/tim/vx/ops/simple_operations.h @@ -65,7 +65,7 @@ namespace ops { * ## ATan * * ATan(x) : arctan(x) - * + * * ## ACosh * * ACosh(x) : arccosh(x) @@ -119,11 +119,6 @@ DECLARE_SIMPLE_OP(DataConvert) DECLARE_SIMPLE_OP(Neg) DECLARE_SIMPLE_OP(Abs) DECLARE_SIMPLE_OP(Sin) -DECLARE_SIMPLE_OP(Cos) -DECLARE_SIMPLE_OP(Tan) -DECLARE_SIMPLE_OP(ATan) -DECLARE_SIMPLE_OP(ATanh) -DECLARE_SIMPLE_OP(ACosh) DECLARE_SIMPLE_OP(Exp) DECLARE_SIMPLE_OP(Log) DECLARE_SIMPLE_OP(Sqrt) @@ -136,6 +131,26 @@ DECLARE_SIMPLE_OP(Round) DECLARE_SIMPLE_OP(Cast) DECLARE_SIMPLE_OP(Rcp) +#ifdef VSI_FEAT_OP_COS +DECLARE_SIMPLE_OP(Cos) +#endif + +#ifdef VSI_FEAT_OP_TAN +DECLARE_SIMPLE_OP(Tan) +#endif + +#ifdef VSI_FEAT_OP_ATAN +DECLARE_SIMPLE_OP(ATan) +#endif + +#ifdef VSI_FEAT_OP_ATANH +DECLARE_SIMPLE_OP(ATanh) +#endif + +#ifdef VSI_FEAT_OP_ACOSH +DECLARE_SIMPLE_OP(ACosh) +#endif + #undef DECLARE_SIMPLE_OP } // namespace ops diff --git a/src/tim/transform/layout_inference.cc b/src/tim/transform/layout_inference.cc index a464c4d29..7479470f7 100644 --- a/src/tim/transform/layout_inference.cc +++ b/src/tim/transform/layout_inference.cc @@ -262,12 +262,22 @@ std::vector> HandleLayoutInfer( REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_NEG, Neg); REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ABS, Abs); REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SIN, Sin); + REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_TANH, Tanh); +#ifdef VSI_FEAT_OP_COS REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_COS, Cos); +#endif +#ifdef VSI_FEAT_OP_TAN REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_TAN, Tan); - REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_TANH, Tanh); +#endif +#ifdef VSI_FEAT_OP_ATAN REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ATAN, ATan); +#endif +#ifdef VSI_FEAT_OP_ATANH REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ATANH, ATanh); +#endif +#ifdef VSI_FEAT_OP_ACOSH REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_ACOSH, ACosh); +#endif REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_EXP, Exp); REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_LOG, Log); REGISTER_LAYOUT_INFERENCE(VSI_NN_OP_SQRT, Sqrt); diff --git a/src/tim/transform/ops/simple_ops_layout_inference.h b/src/tim/transform/ops/simple_ops_layout_inference.h index 7100fc6d0..01b313afa 100644 --- a/src/tim/transform/ops/simple_ops_layout_inference.h +++ b/src/tim/transform/ops/simple_ops_layout_inference.h @@ -60,11 +60,7 @@ using DataConvertLayoutInfer = SimpleOpsLayoutInfer; using NegLayoutInfer = SimpleOpsLayoutInfer; using AbsLayoutInfer = SimpleOpsLayoutInfer; using SinLayoutInfer = SimpleOpsLayoutInfer; -using CosLayoutInfer = SimpleOpsLayoutInfer; -using TanLayoutInfer = SimpleOpsLayoutInfer; -using ATanLayoutInfer = SimpleOpsLayoutInfer; -using ATanhLayoutInfer = SimpleOpsLayoutInfer; -using ACoshLayoutInfer = SimpleOpsLayoutInfer; + using ExpLayoutInfer = SimpleOpsLayoutInfer; using LogLayoutInfer = SimpleOpsLayoutInfer; using SqrtLayoutInfer = SimpleOpsLayoutInfer; @@ -72,6 +68,26 @@ using RsqrtLayoutInfer = SimpleOpsLayoutInfer; using SquareLayoutInfer = SimpleOpsLayoutInfer; using LogicalNotLayoutInfer = SimpleOpsLayoutInfer; +#ifdef VSI_FEAT_OP_COS +using CosLayoutInfer = SimpleOpsLayoutInfer; +#endif + +#ifdef VSI_FEAT_OP_TAN +using TanLayoutInfer = SimpleOpsLayoutInfer; +#endif + +#ifdef VSI_FEAT_OP_ATAN +using ATanLayoutInfer = SimpleOpsLayoutInfer; +#endif + +#ifdef VSI_FEAT_OP_ATANH +using ATanhLayoutInfer = SimpleOpsLayoutInfer; +#endif + +#ifdef VSI_FEAT_OP_ACOSH +using ACoshLayoutInfer = SimpleOpsLayoutInfer; +#endif + } // namespace transform } // namespace tim diff --git a/src/tim/vx/ops/scatternd_onnx_v16.cc b/src/tim/vx/ops/scatternd_onnx_v16.cc index 7e9d7375c..8f64d54d6 100644 --- a/src/tim/vx/ops/scatternd_onnx_v16.cc +++ b/src/tim/vx/ops/scatternd_onnx_v16.cc @@ -22,6 +22,7 @@ * *****************************************************************************/ #include "tim/vx/ops/scatternd_onnx_v16.h" +#ifdef VSI_FEAT_OP_SCATTER_ND_UPDATE #include "builtin_op_impl.h" #include "vsi_nn_pub.h" @@ -60,4 +61,5 @@ std::shared_ptr ScatterND_ONNX_V16::Clone(std::shared_ptr& gra } // namespace ops } // namespace vx -} // namespace tim \ No newline at end of file +} // namespace tim +#endif \ No newline at end of file diff --git a/src/tim/vx/ops/scatternd_onnx_v16_test.cc b/src/tim/vx/ops/scatternd_onnx_v16_test.cc index ef8c28e84..dff3c8522 100644 --- a/src/tim/vx/ops/scatternd_onnx_v16_test.cc +++ b/src/tim/vx/ops/scatternd_onnx_v16_test.cc @@ -24,6 +24,7 @@ #include "tim/vx/context.h" #include "tim/vx/graph.h" #include "tim/vx/ops/scatternd_onnx_v16.h" +#ifdef VSI_FEAT_OP_SCATTER_ND_UPDATE #include "gtest/gtest.h" @@ -71,3 +72,4 @@ TEST(ScatterND_ONNX_V16, shape_8) { EXPECT_EQ(golden, output); } +#endif \ No newline at end of file diff --git a/src/tim/vx/ops/simple_operations.cc b/src/tim/vx/ops/simple_operations.cc index d6b881aae..ff8a27473 100644 --- a/src/tim/vx/ops/simple_operations.cc +++ b/src/tim/vx/ops/simple_operations.cc @@ -40,11 +40,6 @@ DEFINE_SIMPLE_OP(DataConvert, VSI_NN_OP_DATACONVERT) DEFINE_SIMPLE_OP(Neg, VSI_NN_OP_NEG) DEFINE_SIMPLE_OP(Abs, VSI_NN_OP_ABS) DEFINE_SIMPLE_OP(Sin, VSI_NN_OP_SIN) -DEFINE_SIMPLE_OP(Cos, VSI_NN_OP_COS) -DEFINE_SIMPLE_OP(Tan, VSI_NN_OP_TAN) -DEFINE_SIMPLE_OP(ATan, VSI_NN_OP_ATAN) -DEFINE_SIMPLE_OP(ATanh, VSI_NN_OP_ATANH) -DEFINE_SIMPLE_OP(ACosh, VSI_NN_OP_ACOSH) DEFINE_SIMPLE_OP(Exp, VSI_NN_OP_EXP) DEFINE_SIMPLE_OP(Log, VSI_NN_OP_LOG) DEFINE_SIMPLE_OP(Sqrt, VSI_NN_OP_SQRT) @@ -57,6 +52,27 @@ DEFINE_SIMPLE_OP(Round, VSI_NN_OP_ROUND) DEFINE_SIMPLE_OP(Cast, VSI_NN_OP_CAST) DEFINE_SIMPLE_OP(Rcp, VSI_NN_OP_RCP) +#ifdef VSI_FEAT_OP_COS +DEFINE_SIMPLE_OP(Cos, VSI_NN_OP_COS) +#endif + +#ifdef VSI_FEAT_OP_TAN +DEFINE_SIMPLE_OP(Tan, VSI_NN_OP_TAN) +#endif + +#ifdef VSI_FEAT_OP_ATAN +DEFINE_SIMPLE_OP(ATan, VSI_NN_OP_ATAN) +#endif + +#ifdef VSI_FEAT_OP_ATANH +DEFINE_SIMPLE_OP(ATanh, VSI_NN_OP_ATANH) +#endif + +#ifdef VSI_FEAT_OP_ACOSH +DEFINE_SIMPLE_OP(ACosh, VSI_NN_OP_ACOSH) +#endif + + #undef DEFINE_SIMPLE_OP } // namespace ops diff --git a/src/tim/vx/ops/simple_operations_test.cc b/src/tim/vx/ops/simple_operations_test.cc index 81cb19a56..7db4ce18c 100644 --- a/src/tim/vx/ops/simple_operations_test.cc +++ b/src/tim/vx/ops/simple_operations_test.cc @@ -263,6 +263,7 @@ TEST(Rcp, shape_5_1_fp32) { EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f)); } +#ifdef VSI_FEAT_OP_COS TEST(Cos, shape_5_1_fp32) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); @@ -290,7 +291,9 @@ TEST(Cos, shape_5_1_fp32) { EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f)); } +#endif +#ifdef VSI_FEAT_OP_TAN TEST(Tan, shape_5_1_fp32) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); @@ -317,4 +320,5 @@ TEST(Tan, shape_5_1_fp32) { std::vector output(5); EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data())); EXPECT_TRUE(ArraysMatch(golden, output, 1e-4f)); -} \ No newline at end of file +} +#endif \ No newline at end of file