Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip evaluation of TensorFlow model if inputs are empty #45139

Merged
merged 1 commit into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions PhysicsTools/TensorFlow/interface/TensorFlow.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ namespace tensorflow {
// version of the function above that accepts a const session
bool closeSession(const Session*& session);

bool checkEmptyInputs(const NamedTensorList& inputs);

// run the session with inputs and outputNames, store output tensors, and control the underlying
// thread pool using threadPoolOptions
// used for thread scheduling with custom thread pool options
Expand Down
17 changes: 17 additions & 0 deletions PhysicsTools/TensorFlow/src/TensorFlow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,19 @@ namespace tensorflow {
return state;
}

bool checkEmptyInputs(const NamedTensorList& inputs) {
// check for empty tensors in the inputs
bool isEmpty = false;
for (const auto& input : inputs) {
// Checking using the shape
if (input.second.shape().num_elements() == 0) {
isEmpty = true;
break;
}
}
return isEmpty;
}

void run(Session* session,
const NamedTensorList& inputs,
const std::vector<std::string>& outputNames,
Expand All @@ -268,6 +281,10 @@ namespace tensorflow {
// create empty run options
RunOptions runOptions;

// Check if the inputs are empty
if (checkEmptyInputs(inputs))
return;

// run and check the status
Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
if (!status.ok()) {
Expand Down
7 changes: 7 additions & 0 deletions PhysicsTools/TensorFlow/test/BuildFile.xml
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@
<use name="PhysicsTools/TensorFlow"/>
</bin>

<bin name="testTFEmptyInputs" file="testRunner.cpp,testEmptyInputs.cc">
<use name="boost_filesystem"/>
<use name="cppunit"/>
<use name="PhysicsTools/TensorFlow"/>
</bin>



<iftool name="cuda">
<bin name="testTFVisibleDevicesCUDA" file="testRunner.cpp,testVisibleDevicesCUDA.cc">
Expand Down
57 changes: 57 additions & 0 deletions PhysicsTools/TensorFlow/test/testEmptyInputs.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Tests for working with empty inputs
*
*/

#include <stdexcept>
#include <cppunit/extensions/HelperMacros.h>

#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

#include "testBase.h"

class testEmptyInputs : public testBase {
CPPUNIT_TEST_SUITE(testEmptyInputs);
CPPUNIT_TEST(test);
CPPUNIT_TEST_SUITE_END();

public:
std::string pyScript() const override;
void test() override;
};

CPPUNIT_TEST_SUITE_REGISTRATION(testEmptyInputs);

std::string testEmptyInputs::pyScript() const { return "createconstantgraph.py"; }

void testEmptyInputs::test() {
std::string pbFile = dataPath_ + "/constantgraph.pb";

std::cout << "Testing CPU backend" << std::endl;
tensorflow::Backend backend = tensorflow::Backend::cpu;

// load the graph
tensorflow::Options options{backend};
tensorflow::GraphDef* graphDef = tensorflow::loadGraphDef(pbFile);
CPPUNIT_ASSERT(graphDef != nullptr);

// create a new session and add the graphDef
const tensorflow::Session* session = tensorflow::createSession(graphDef, options);
CPPUNIT_ASSERT(session != nullptr);

// example evaluation with empty tensor
tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, 0});
tensorflow::Tensor scale(tensorflow::DT_FLOAT, {});
scale.scalar<float>()() = 1.0;
std::vector<tensorflow::Tensor> outputs;

// run using the convenience helper
outputs.clear();
tensorflow::run(session, {{"input", input}, {"scale", scale}}, {"output"}, &outputs);
CPPUNIT_ASSERT(outputs.size() == 0);

// cleanup
CPPUNIT_ASSERT(tensorflow::closeSession(session));
CPPUNIT_ASSERT(session == nullptr);
delete graphDef;
}