From a4d3a1ce0c18e1d1b31a9cc0b45beba290ee114c Mon Sep 17 00:00:00 2001 From: liqun Fu Date: Sat, 27 Jul 2024 15:58:36 -0700 Subject: [PATCH] pick changes from https://github.com/onnx/onnx/pull/6195 to fix heap-buffer-overflow in onnx::convPoolShapeInference (#21507) ### Description onnx 1.16.2 is not available before ort 1.19.0 code freeze. Thus pick the needed change as patch --- cmake/patches/onnx/onnx.patch | 383 ++++++++++++++++++ .../providers/cpu/generator/random_test.cc | 8 +- .../core/graph/training_op_defs.cc | 104 +++-- 3 files changed, 447 insertions(+), 48 deletions(-) diff --git a/cmake/patches/onnx/onnx.patch b/cmake/patches/onnx/onnx.patch index 162d33581a5c..6ac3555eeecf 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -86,3 +86,386 @@ index 0aab3e26..398ac2d6 100644 +#endif + #endif // ! ONNX_ONNX_PB_H +diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc +index c315a2a7..58963154 100644 +--- a/onnx/defs/math/defs.cc ++++ b/onnx/defs/math/defs.cc +@@ -3472,6 +3472,9 @@ ONNX_OPERATOR_SET_SCHEMA( + } + + auto& input_shape = getInputShape(ctx, 0); ++ if (input_shape.dim_size() < 2) { ++ fail_shape_inference("First input should have at least 2 dimensions in ", ctx.getDisplayName(), "."); ++ } + auto signal_dim = input_shape.dim(1); + if (!signal_dim.has_dim_value()) { + return; +diff --git a/onnx/defs/nn/defs.cc b/onnx/defs/nn/defs.cc +index be6a851d..fad595d0 100644 +--- a/onnx/defs/nn/defs.cc ++++ b/onnx/defs/nn/defs.cc +@@ -126,6 +126,9 @@ void convPoolShapeInference( + residual -= stride; + } + } ++ if (i >= static_cast(effective_kernel_shape.size())) { ++ fail_shape_inference("kernel shape should have ", input_dims_size, " values in ", ctx.getDisplayName(), "."); ++ } + int64_t total_pad = residual == 0 ? effective_kernel_shape[i] - stride : effective_kernel_shape[i] - residual; + if (total_pad < 0) + total_pad = 0; +@@ -959,19 +962,21 @@ ONNX_OPERATOR_SET_SCHEMA( + auto w_type = ctx.getInputType(3); + if (nullptr == x_type || nullptr == w_type || x_type->value_case() != TypeProto::kTensorType || + w_type->value_case() != TypeProto::kTensorType) { +- fail_type_inference("inputs are expected to have tensor type."); ++ fail_type_inference("inputs are expected to have tensor type in ", ctx.getDisplayName(), "."); + } + + auto x_zero_point_type = ctx.getInputType(2); + if (nullptr == x_zero_point_type || + x_zero_point_type->tensor_type().elem_type() != x_type->tensor_type().elem_type()) { +- fail_type_inference("input and zero_point pair is expected to have be same type."); ++ fail_type_inference( ++ "input and zero_point pair is expected to have be same type in ", ctx.getDisplayName(), "."); + } + + auto w_zero_point_type = ctx.getInputType(5); + if (nullptr == w_zero_point_type || + w_zero_point_type->tensor_type().elem_type() != w_type->tensor_type().elem_type()) { +- fail_type_inference("weight and zero_point pair is expected to have same type."); ++ fail_type_inference( ++ "weight and zero_point pair is expected to have same type in ", ctx.getDisplayName(), "."); + } + + propagateElemTypeFromInputToOutput(ctx, 7, 0); +@@ -2647,7 +2652,8 @@ ONNX_OPERATOR_SET_SCHEMA( + if (!hasNInputShapes(ctx, 1)) { + return; + } +- auto& input_shape = ctx.getInputType(0)->tensor_type().shape(); ++ ++ auto& input_shape = getInputShape(ctx, 0); + int64_t input_ndim = input_shape.dim_size(); + int64_t axis = -1; + auto axis_proto = ctx.getAttribute("axis"); +@@ -2659,7 +2665,16 @@ ONNX_OPERATOR_SET_SCHEMA( + // positive value. + axis += input_ndim; + } +- ++ if (axis < 0) { ++ fail_shape_inference( ++ "Unexpected axis value (", ++ axis, ++ ") rank of first input is ", ++ input_ndim, ++ " in ", ++ ctx.getDisplayName(), ++ "."); ++ } + if (ctx.getNumOutputs() > 1) { + auto mean_shape = ctx.getOutputType(1)->mutable_tensor_type()->mutable_shape(); + mean_shape->CopyFrom(input_shape); +diff --git a/onnx/defs/nn/old.cc b/onnx/defs/nn/old.cc +index 57f8e2a4..8b2dc07f 100644 +--- a/onnx/defs/nn/old.cc ++++ b/onnx/defs/nn/old.cc +@@ -201,6 +201,9 @@ void convPoolShapeInference_opset19( + residual -= stride; + } + } ++ if (i >= static_cast(effective_kernel_shape.size())) { ++ fail_shape_inference("kernel shape should have ", input_dims_size, " values in ", ctx.getDisplayName(), "."); ++ } + int64_t total_pad = residual == 0 ? effective_kernel_shape[i] - stride : effective_kernel_shape[i] - residual; + if (total_pad < 0) + total_pad = 0; +diff --git a/onnx/defs/shape_inference.h b/onnx/defs/shape_inference.h +index a80473b3..d1bcd401 100644 +--- a/onnx/defs/shape_inference.h ++++ b/onnx/defs/shape_inference.h +@@ -105,6 +105,10 @@ struct InferenceContext { + virtual const SparseTensorProto* getInputSparseData(size_t index) const = 0; + // Gets the shape inputs computed by partial data propagation. + virtual const TensorShapeProto* getSymbolicInput(size_t index) const = 0; ++ // To display a name the user can use to narrow its search. ++ virtual std::string getDisplayName() const { ++ return ""; ++ } + }; + + // We use data propagation to perform partial evaluation of the model, to compute statically +@@ -263,7 +267,15 @@ inline void propagateElemTypeFromDtypeToOutput( + } else { + // This is not expected to happen + fail_type_inference( +- "Output ", outputIndex, " expected to have: ", expected_value_case, " or UNDEFINED. Got: ", output_value_case); ++ "Output ", ++ outputIndex, ++ " expected to have: ", ++ expected_value_case, ++ " or UNDEFINED. Got: ", ++ output_value_case, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + } + +@@ -277,18 +289,18 @@ inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const Attr + const auto attr_type = attr->type(); + if (attr_type == AttributeProto::TENSOR) { + if (attr->t().dims().size() != 1) { +- fail_type_inference("Attribute expected to have a one-dim tensor"); ++ fail_type_inference("Attribute expected to have a one-dim tensor in ", ctx.getDisplayName(), "."); + } + data_type = attr->t().data_type(); + expected_value_case = TypeProto::kTensorType; + } else if (attr_type == AttributeProto::SPARSE_TENSOR) { + if (attr->sparse_tensor().dims().size() != 1) { +- fail_type_inference("Attribute expected to have a one-dim sparse tensor"); ++ fail_type_inference("Attribute expected to have a one-dim sparse tensor in ", ctx.getDisplayName(), "."); + } + data_type = attr->sparse_tensor().values().data_type(); + expected_value_case = TypeProto::kSparseTensorType; + } else { +- fail_type_inference("Attribute expected to have tensor or sparse tensor type"); ++ fail_type_inference("Attribute expected to have tensor or sparse tensor type in ", ctx.getDisplayName(), "."); + } + + propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, expected_value_case); +@@ -326,7 +338,10 @@ inline const TensorShapeProto& getInputShape(const InferenceContext& ctx, size_t + const auto* input_type = ctx.getInputType(n); + const auto value_case = input_type->value_case(); + if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) { +- fail_type_inference("Attribute expected to have tensor or sparse tensor type"); ++ fail_type_inference("Input ", n, "expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), "."); ++ } ++ if (!hasShape(*input_type)) { ++ fail_shape_inference("Input ", n, " must have a non null shape in ", ctx.getDisplayName(), "."); + } + if (value_case == TypeProto::kTensorType) { + return input_type->tensor_type().shape(); +@@ -344,7 +359,7 @@ inline const TensorShapeProto* getOptionalInputShape(InferenceContext& ctx, size + + const auto value_case = input_type->value_case(); + if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) { +- fail_type_inference("Attribute expected to have tensor or sparse tensor type"); ++ fail_type_inference("Input ", n, "expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), "."); + } + if (value_case == TypeProto::kTensorType) { + return &input_type->tensor_type().shape(); +@@ -372,7 +387,10 @@ inline void appendSingleDimCopiedFromInputTypeToOutputType( + " does not match type of output: ", + outputIndex, + "type: ", +- output_value_case); ++ output_value_case, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + if (TypeProto::kTensorType == input_value_case) { + auto* dim = output_type->mutable_tensor_type()->mutable_shape()->add_dim(); +@@ -382,7 +400,13 @@ inline void appendSingleDimCopiedFromInputTypeToOutputType( + *dim = input_type->sparse_tensor_type().shape().dim(static_cast(fromDimIndex)); + } else { + fail_type_inference( +- "Input ", inputIndex, " and Output ", outputIndex, " expected to have tensor or sparse tensor type"); ++ "Input ", ++ inputIndex, ++ " and Output ", ++ outputIndex, ++ " expected to have tensor or sparse tensor type in ", ++ ctx.getDisplayName(), ++ "."); + } + } + +@@ -440,7 +464,14 @@ updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType + setTensorElementType(elemType, expected_type, *output_type); + } else { + // This is not expected to happen +- fail_type_inference("Output ", outputIndex, " expected to have tensor or sparse tensor type: ", expected_type); ++ fail_type_inference( ++ "Output ", ++ outputIndex, ++ " expected to have tensor or sparse tensor type: ", ++ expected_type, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + } + +@@ -462,16 +493,17 @@ inline void propagateElemTypeFromAttributeToOutput( + updateOutputElemType(ctx, outputIndex, default_value, expected_type); + return; + } else { +- fail_type_inference("Value of attribute ", attributeName, " not specified"); ++ fail_type_inference("Value of attribute ", attributeName, " not specified in ", ctx.getDisplayName(), "."); + } + } + if (!attr_proto->has_i()) { +- fail_type_inference("Attribute ", attributeName, " should be of integer type and specify a type."); ++ fail_type_inference( ++ "Attribute ", attributeName, " should be of integer type and specify a type in ", ctx.getDisplayName(), "."); + } + auto attr_value = attr_proto->i(); + auto elem_type = static_cast(attr_value); + if (!TensorProto_DataType_IsValid(elem_type)) { +- fail_type_inference("Attribute ", attributeName, " does not specify a valid type."); ++ fail_type_inference("Attribute ", attributeName, " does not specify a valid type in ", ctx.getDisplayName(), "."); + } + updateOutputElemType(ctx, outputIndex, elem_type, expected_type); + } +@@ -497,7 +529,7 @@ inline TensorShapeProto* + getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_type = TypeProto::kTensorType) { + auto output_type = ctx.getOutputType(n); + if (output_type == nullptr) { +- fail_type_inference("Output ", n, " expected to have tensor or sparse type"); ++ fail_type_inference("Output ", n, " expected to have tensor or sparse type in ", ctx.getDisplayName(), "."); + } + const auto output_value_case = output_type->value_case(); + if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) { +@@ -505,7 +537,7 @@ getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_typ + } else if (output_value_case == TypeProto::VALUE_NOT_SET) { + return getTensorMutableShape(default_type, *output_type); + } else { +- fail_type_inference("Output ", n, " expected to have tensor type"); ++ fail_type_inference("Output ", n, " expected to have tensor type in ", ctx.getDisplayName(), "."); + } + } + +@@ -562,13 +594,13 @@ inline void propagateShapeFromAttributeToOutput( + auto attr_proto = ctx.getAttribute(attributeName); + if ((nullptr == attr_proto) || (!attr_proto->has_type()) || + (attr_proto->type() != AttributeProto_AttributeType_INTS)) { +- fail_shape_inference("Attribute ", attributeName, " should specify a shape"); ++ fail_shape_inference("Attribute ", attributeName, " should specify a shape in ", ctx.getDisplayName(), "."); + } + auto& int_list = attr_proto->ints(); + TensorShapeProto shape; + for (auto dim_size : int_list) { + if (dim_size < 0) { +- fail_shape_inference("Negative values are not allowed in a shape specification"); ++ fail_shape_inference("Negative values are not allowed in a shape specification in ", ctx.getDisplayName(), "."); + } + shape.add_dim()->set_dim_value(dim_size); + } +@@ -745,7 +777,16 @@ inline void checkInputRank(InferenceContext& ctx, size_t input_index, int expect + if (hasInputShape(ctx, input_index)) { + auto rank = getInputShape(ctx, input_index).dim_size(); + if (rank != expected_rank) { +- fail_shape_inference("Input ", input_index, " expected to have rank ", expected_rank, " but has rank ", rank); ++ fail_shape_inference( ++ "Input ", ++ input_index, ++ " expected to have rank ", ++ expected_rank, ++ " but has rank ", ++ rank, ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + } + } +@@ -798,7 +839,15 @@ inline void unifyInputDim(InferenceContext& ctx, size_t input_index, int dim_ind + // This shape is expected to have rank > dim_index: + if (input_shape.dim_size() <= dim_index) { + fail_shape_inference( +- "Input ", input_index, " expected to have rank >", dim_index, " but has rank ", input_shape.dim_size()); ++ "Input ", ++ input_index, ++ " expected to have rank >", ++ dim_index, ++ " but has rank ", ++ input_shape.dim_size(), ++ " in ", ++ ctx.getDisplayName(), ++ "."); + } + const Dim& input_dim = input_shape.dim(dim_index); + // Now, unify dim and input_dim: +diff --git a/onnx/shape_inference/implementation.cc b/onnx/shape_inference/implementation.cc +index 8723dcd4..8249fc59 100644 +--- a/onnx/shape_inference/implementation.cc ++++ b/onnx/shape_inference/implementation.cc +@@ -906,7 +906,7 @@ struct FunctionInferenceContext : public InferenceContext { + const std::vector& input_types, + const std::vector& attributes, + const ShapeInferenceOptions& options) +- : input_types_(input_types), options_(options) { ++ : input_types_(input_types), options_(options), func_proto_(&func_proto) { + for (const auto& attr : attributes) { + attributesByName_[attr.name()] = &attr; + } +@@ -971,11 +971,25 @@ struct FunctionInferenceContext : public InferenceContext { + return std::move(output_types_); + } + ++ std::string getDisplayName() const override { ++ if (func_proto_ == nullptr) ++ return ""; ++ if (func_proto_->domain().empty()) { ++ if (func_proto_->name().empty()) ++ return ""; ++ return MakeString("function ", func_proto_->name()); ++ } ++ if (func_proto_->name().empty()) ++ return MakeString("function [", func_proto_->domain(), "]"); ++ return MakeString("function ", func_proto_->name(), "[", func_proto_->domain(), "]"); ++ } ++ + private: + const std::vector& input_types_; + std::vector output_types_; + std::unordered_map attributesByName_; + ShapeInferenceOptions options_; ++ const FunctionProto* func_proto_; + }; + + std::vector InferFunctionOutputTypes( +diff --git a/onnx/shape_inference/implementation.h b/onnx/shape_inference/implementation.h +index 2c63c910..b0e4c32d 100644 +--- a/onnx/shape_inference/implementation.h ++++ b/onnx/shape_inference/implementation.h +@@ -146,7 +146,7 @@ struct InferenceContextImpl : public InferenceContext { + const ShapeInferenceOptions& options, + DataValueMap* generatedShapeData = nullptr, + GraphInferenceContext* graphInferenceContext = nullptr) +- : graphInferenceContext_{graphInferenceContext}, options_(options) { ++ : graphInferenceContext_{graphInferenceContext}, options_(options), node_(&n) { + for (auto& attr : *n.mutable_attribute()) { + attributesByName_[attr.name()] = &attr; + if (attr.has_g()) { +@@ -277,6 +277,19 @@ struct InferenceContextImpl : public InferenceContext { + return inferencer; + } + ++ std::string getDisplayName() const override { ++ if (node_ == nullptr) ++ return ""; ++ if (node_->domain().empty()) { ++ if (node_->name().empty()) ++ return MakeString("node ", node_->op_type()); ++ return MakeString("node ", node_->op_type(), " (", node_->name(), ")"); ++ } ++ if (node_->name().empty()) ++ return MakeString("node ", node_->op_type(), "[", node_->domain(), "]"); ++ return MakeString("node ", node_->op_type(), "[", node_->domain(), "]", " (", node_->name(), ")"); ++ } ++ + std::vector allInputData_; + std::vector allInputSparseData_; + std::vector allShapeInputData_; +@@ -289,6 +302,7 @@ struct InferenceContextImpl : public InferenceContext { + // mutable as internal cache of GraphInferencer instances + mutable std::unordered_map> graphAttributeInferencers_; + ShapeInferenceOptions options_; ++ NodeProto* node_; + }; + + struct DataPropagationContextImpl : public DataPropagationContext { diff --git a/onnxruntime/test/providers/cpu/generator/random_test.cc b/onnxruntime/test/providers/cpu/generator/random_test.cc index ec9b1614488a..f42f32d63d1f 100644 --- a/onnxruntime/test/providers/cpu/generator/random_test.cc +++ b/onnxruntime/test/providers/cpu/generator/random_test.cc @@ -178,7 +178,7 @@ TEST(Random, InvalidDType) { test.AddAttribute("shape", dims); test.AddOutput("Y", dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectFailure, "Attribute dtype does not specify a valid type."); + test.Run(OpTester::ExpectResult::kExpectFailure, "Node (node1) Op (RandomNormal) [TypeInferenceError] Attribute dtype does not specify a valid type in ."); } { @@ -194,7 +194,7 @@ TEST(Random, InvalidDType) { test.AddAttribute("shape", dims); test.AddOutput("Y", dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectFailure, "Attribute dtype does not specify a valid type."); + test.Run(OpTester::ExpectResult::kExpectFailure, "Node (node1) Op (RandomUniform) [TypeInferenceError] Attribute dtype does not specify a valid type in ."); } { @@ -210,7 +210,7 @@ TEST(Random, InvalidDType) { test.AddInput("X", dims, input); test.AddOutput("Y", dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectFailure, "Attribute dtype does not specify a valid type."); + test.Run(OpTester::ExpectResult::kExpectFailure, "Node (node1) Op (RandomNormalLike) [TypeInferenceError] Attribute dtype does not specify a valid type in ."); } { @@ -226,7 +226,7 @@ TEST(Random, InvalidDType) { test.AddInput("X", dims, input); test.AddOutput("Y", dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectFailure, "Attribute dtype does not specify a valid type."); + test.Run(OpTester::ExpectResult::kExpectFailure, "Node (node1) Op (RandomUniformLike) [TypeInferenceError] Attribute dtype does not specify a valid type in ."); } } diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 2a8d2de982e7..92f803030ada 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -181,6 +181,64 @@ static void propagateRecvOutputTensorElemTypes( } } +void SendShapeInfer(ONNX_NAMESPACE::InferenceContext& ctx) { + if (ctx.getNumInputs() < 3) { + fail_shape_inference("Send must have at least three inputs."); + } else { + if (hasInputShape(ctx, 0)) { + auto& signal_input_shape = getInputShape(ctx, 0); + if (static_cast(signal_input_shape.dim_size()) != 0) { + fail_shape_inference("InputSignal of Send must be a scalar."); + } + } + if (hasInputShape(ctx, 1)) { + auto& remote_input_shape = getInputShape(ctx, 1); + if (static_cast(remote_input_shape.dim_size()) != 0) { + fail_shape_inference("Remote of Send must be a scalar."); + } + } + + checkSendInputTensorElemTypes(ctx, "element_types", ctx.getNumInputs() - 2); + } + + if (ctx.getNumOutputs() != 1) { + fail_shape_inference("Send must have one output."); + } + + auto output_element_type = ctx.getOutputType(0)->mutable_tensor_type(); + output_element_type->set_elem_type(TensorProto::BOOL); + ONNX_NAMESPACE::TensorShapeProto output_shape; + updateOutputShape(ctx, 0, {}); + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); +} + +void RecvShapeInfer(ONNX_NAMESPACE::InferenceContext& ctx) { + if (ctx.getNumInputs() != 2) { + fail_shape_inference("Recv must have two inputs."); + } else { + if (hasInputShape(ctx, 0)) { + auto& signal_input_shape = getInputShape(ctx, 0); + if (static_cast(signal_input_shape.dim_size()) != 0) { + fail_shape_inference("InputSignal of Recv must be a scalar."); + } + } + if (hasInputShape(ctx, 1)) { + auto& remote_input_shape = getInputShape(ctx, 1); + if (static_cast(remote_input_shape.dim_size()) != 0) { + fail_shape_inference("Remote of Recv must be a scalar."); + } + } + } + + if (ctx.getNumOutputs() < 2) { + fail_shape_inference("Recv must have at least two outputs."); + } + + updateOutputShape(ctx, 0, {}); + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); + propagateRecvOutputTensorElemTypes(ctx, "element_types", ctx.getNumOutputs() - 1); +} + TensorProto ToDimensionOneFloatTensor(float value) { auto t = ToTensor(std::vector({value})); t.add_dims(1); @@ -3388,30 +3446,7 @@ Return true if all elements are true and false otherwise. "Constrain types to boolean tensors.") .TypeConstraint("V", OpSchema::all_tensor_types(), "All Tensor types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - if (ctx.getNumInputs() < 3) { - fail_shape_inference("Send must have at least three inputs."); - } else { - auto& signal_input_shape = getInputShape(ctx, 0); - if (static_cast(signal_input_shape.dim_size()) != 0) { - fail_shape_inference("InputSignal of Send must be a scalar."); - } - auto& remote_input_shape = getInputShape(ctx, 1); - if (static_cast(remote_input_shape.dim_size()) != 0) { - fail_shape_inference("Remote of Send must be a scalar."); - } - - checkSendInputTensorElemTypes(ctx, "element_types", ctx.getNumInputs() - 2); - } - - if (ctx.getNumOutputs() != 1) { - fail_shape_inference("Send must have one output."); - } - - auto output_element_type = ctx.getOutputType(0)->mutable_tensor_type(); - output_element_type->set_elem_type(TensorProto::BOOL); - ONNX_NAMESPACE::TensorShapeProto output_shape; - updateOutputShape(ctx, 0, {}); - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); + SendShapeInfer(ctx); }); ONNX_CONTRIB_OPERATOR_SCHEMA(Recv) @@ -3437,26 +3472,7 @@ Return true if all elements are true and false otherwise. "Constrain types to boolean tensors.") .TypeConstraint("V", OpSchema::all_tensor_types(), "All Tensor types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - if (ctx.getNumInputs() != 2) { - fail_shape_inference("Recv must have two inputs."); - } else { - auto& signal_input_shape = getInputShape(ctx, 0); - if (static_cast(signal_input_shape.dim_size()) != 0) { - fail_shape_inference("InputSignal of Recv must be a scalar."); - } - auto& remote_input_shape = getInputShape(ctx, 1); - if (static_cast(remote_input_shape.dim_size()) != 0) { - fail_shape_inference("Remote of Recv must be a scalar."); - } - } - - if (ctx.getNumOutputs() < 2) { - fail_shape_inference("Recv must have at least two outputs."); - } - - updateOutputShape(ctx, 0, {}); - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); - propagateRecvOutputTensorElemTypes(ctx, "element_types", ctx.getNumOutputs() - 1); + RecvShapeInfer(ctx); }); ONNX_CONTRIB_OPERATOR_SCHEMA(MegatronF)