Skip to content

Commit

Permalink
Merge pull request #45139 from valsdav/tf_empty_inputs
Browse files Browse the repository at this point in the history
Skip evaluation of TensorFlow model if inputs are empty
  • Loading branch information
cmsbuild committed Jun 6, 2024
2 parents db28858 + e5ab42f commit 581c010
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 0 deletions.
2 changes: 2 additions & 0 deletions PhysicsTools/TensorFlow/interface/TensorFlow.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ namespace tensorflow {
// version of the function above that accepts a const session
bool closeSession(const Session*& session);

bool checkEmptyInputs(const NamedTensorList& inputs);

// run the session with inputs and outputNames, store output tensors, and control the underlying
// thread pool using threadPoolOptions
// used for thread scheduling with custom thread pool options
Expand Down
17 changes: 17 additions & 0 deletions PhysicsTools/TensorFlow/src/TensorFlow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,19 @@ namespace tensorflow {
return state;
}

bool checkEmptyInputs(const NamedTensorList& inputs) {
// check for empty tensors in the inputs
bool isEmpty = false;
for (const auto& input : inputs) {
// Checking using the shape
if (input.second.shape().num_elements() == 0) {
isEmpty = true;
break;
}
}
return isEmpty;
}

void run(Session* session,
const NamedTensorList& inputs,
const std::vector<std::string>& outputNames,
Expand All @@ -277,6 +290,10 @@ namespace tensorflow {
// create empty run options
RunOptions runOptions;

// Check if the inputs are empty
if (checkEmptyInputs(inputs))
return;

// run and check the status
Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
if (!status.ok()) {
Expand Down
7 changes: 7 additions & 0 deletions PhysicsTools/TensorFlow/test/BuildFile.xml
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,13 @@
<use name="PhysicsTools/TensorFlow"/>
</bin>

<bin name="testTFEmptyInputs" file="testRunner.cpp,testEmptyInputs.cc">
<use name="boost_filesystem"/>
<use name="cppunit"/>
<use name="PhysicsTools/TensorFlow"/>
</bin>



<iftool name="tf_cuda_support">
<bin name="testTFVisibleDevicesCUDA" file="testRunner.cpp,testVisibleDevicesCUDA.cc">
Expand Down
57 changes: 57 additions & 0 deletions PhysicsTools/TensorFlow/test/testEmptyInputs.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Tests for working with empty inputs
*
*/

#include <stdexcept>
#include <cppunit/extensions/HelperMacros.h>

#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

#include "testBase.h"

class testEmptyInputs : public testBase {
CPPUNIT_TEST_SUITE(testEmptyInputs);
CPPUNIT_TEST(test);
CPPUNIT_TEST_SUITE_END();

public:
std::string pyScript() const override;
void test() override;
};

CPPUNIT_TEST_SUITE_REGISTRATION(testEmptyInputs);

std::string testEmptyInputs::pyScript() const { return "createconstantgraph.py"; }

void testEmptyInputs::test() {
std::string pbFile = dataPath_ + "/constantgraph.pb";

std::cout << "Testing CPU backend" << std::endl;
tensorflow::Backend backend = tensorflow::Backend::cpu;

// load the graph
tensorflow::Options options{backend};
tensorflow::GraphDef* graphDef = tensorflow::loadGraphDef(pbFile);
CPPUNIT_ASSERT(graphDef != nullptr);

// create a new session and add the graphDef
const tensorflow::Session* session = tensorflow::createSession(graphDef, options);
CPPUNIT_ASSERT(session != nullptr);

// example evaluation with empty tensor
tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, 0});
tensorflow::Tensor scale(tensorflow::DT_FLOAT, {});
scale.scalar<float>()() = 1.0;
std::vector<tensorflow::Tensor> outputs;

// run using the convenience helper
outputs.clear();
tensorflow::run(session, {{"input", input}, {"scale", scale}}, {"output"}, &outputs);
CPPUNIT_ASSERT(outputs.size() == 0);

// cleanup
CPPUNIT_ASSERT(tensorflow::closeSession(session));
CPPUNIT_ASSERT(session == nullptr);
delete graphDef;
}

0 comments on commit 581c010

Please sign in to comment.