Skip to content

Commit

Permalink
Merge pull request #43100 from cms-tau-pog/CMSSW_13_3_X_tau-pog_pnetT…
Browse files Browse the repository at this point in the history
…auAtHLT

Enable HLT tau filtering based on jet-tags (ParticleNet)
  • Loading branch information
cmsbuild committed Oct 27, 2023
2 parents 3df8fcb + 0cffcda commit 666097a
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 52 deletions.
2 changes: 1 addition & 1 deletion RecoTauTag/HLTProducers/src/L2TauTagFilter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class L2TauTagFilter : public HLTFilter {
}
for (size_t l1_idx = 0; l1_idx < l1Taus.size(); l1_idx++) {
if (L2Outcomes[l1_idx] >= discrWP_ || l1Taus[l1_idx]->pt() > l1PtTh_) {
filterproduct.addObject(nTauPassed, l1Taus[l1_idx]);
filterproduct.addObject(trigger::TriggerL1Tau, l1Taus[l1_idx]);
nTauPassed++;
}
}
Expand Down
171 changes: 171 additions & 0 deletions RecoTauTag/HLTProducers/src/TauTagFilter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* \class TauTagFilter
*
* Filter tau candidates based on tagger scores.
*
* \author Konstantin Androsov, EPFL and ETHZ
*/

#include "DataFormats/BTauReco/interface/JetTag.h"
#include "DataFormats/HLTReco/interface/TriggerTypeDefs.h"
#include "FWCore/Framework/interface/Event.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "FWCore/Framework/interface/Frameworkfwd.h"
#include "FWCore/Utilities/interface/InputTag.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
#include "HLTrigger/HLTcore/interface/HLTFilter.h"
#include "RecoTauTag/RecoTau/interface/TauWPThreshold.h"

class TauTagFilter : public HLTFilter {
public:
using TauCollection = reco::PFJetCollection;
using TauTagCollection = reco::JetTagCollection;
using TauRef = reco::PFJetRef;
using Selector = tau::TauWPThreshold;
using LorentzVectorM = math::PtEtaPhiMLorentzVector;

explicit TauTagFilter(const edm::ParameterSet& cfg)
: HLTFilter(cfg),
nExpected_(cfg.getParameter<int>("nExpected")),
tausSrc_(cfg.getParameter<edm::InputTag>("taus")),
tausToken_(consumes<TauCollection>(tausSrc_)),
tauTagsToken_(consumes<TauTagCollection>(cfg.getParameter<edm::InputTag>("tauTags"))),
tauPtCorrToken_(mayConsume<TauTagCollection>(cfg.getParameter<edm::InputTag>("tauPtCorr"))),
seedsSrc_(mayConsume<trigger::TriggerFilterObjectWithRefs>(cfg.getParameter<edm::InputTag>("seeds"))),
seedTypes_(cfg.getParameter<std::vector<int>>("seedTypes")),
selector_(cfg.getParameter<std::string>("selection")),
minPt_(cfg.getParameter<double>("minPt")),
maxEta_(cfg.getParameter<double>("maxEta")),
usePtCorr_(cfg.getParameter<bool>("usePtCorr")),
matchWithSeeds_(cfg.getParameter<bool>("matchWithSeeds") && cfg.getParameter<double>("matchingdR") >= 0),
matchingdR2_(std::pow(cfg.getParameter<double>("matchingdR"), 2)) {
if (cfg.getParameter<bool>("matchWithSeeds") && cfg.getParameter<double>("matchingdR") < 0)
edm::LogWarning("TauTagFilter") << "Matching with seeds is disabled because matchingdR < 0";

extractMomenta(); // checking that all seed types are supported
}

static void fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
edm::ParameterSetDescription desc;
makeHLTFilterDescription(desc);
desc.add<int>("nExpected", 2)->setComment("number of expected taus per event");
desc.add<edm::InputTag>("taus", edm::InputTag(""))->setComment("input collection of taus");
desc.add<edm::InputTag>("tauTags", edm::InputTag(""))->setComment("input collection of tau tagger scores");
desc.add<edm::InputTag>("tauPtCorr", edm::InputTag(""))
->setComment("input collection of multiplicative tau pt corrections");
desc.add<edm::InputTag>("seeds", edm::InputTag(""))->setComment("input collection of seeds");
desc.add<std::vector<int>>("seedTypes",
{trigger::TriggerL1Tau, trigger::TriggerL1Jet, trigger::TriggerTau, trigger::TriggerJet})
->setComment("list of seed object types");
desc.add<std::string>("selection", "0")->setComment("selection formula");
desc.add<double>("minPt", 20)->setComment("minimal tau pt");
desc.add<double>("maxEta", 2.5)->setComment("maximal tau abs(eta)");
desc.add<bool>("usePtCorr", false)->setComment("use multiplicative tau pt corrections");
desc.add<bool>("matchWithSeeds", false)->setComment("apply match with seeds");
desc.add<double>("matchingdR", 0.5)->setComment("deltaR for matching with seeds");
descriptions.addWithDefaultLabel(desc);
}

bool hltFilter(edm::Event& event,
const edm::EventSetup& eventsetup,
trigger::TriggerFilterObjectWithRefs& filterproduct) const override {
if (saveTags())
filterproduct.addCollectionTag(tausSrc_);

int nTauPassed = 0;

const auto tausHandle = event.getHandle(tausToken_);
const auto& taus = *tausHandle;

const std::vector<LorentzVectorM> seed_p4s = extractMomenta(&event);
auto hasMatch = [&](const LorentzVectorM& p4) {
for (const auto& seed_p4 : seed_p4s) {
if (reco::deltaR2(p4, seed_p4) < matchingdR2_)
return true;
}
return false;
};

const auto& tauTags = event.get(tauTagsToken_);
const TauTagCollection* tauPtCorr = nullptr;
if (usePtCorr_)
tauPtCorr = &event.get(tauPtCorrToken_);

if (taus.size() != tauTags.size())
throw cms::Exception("Inconsistent Data", "TauTagFilter::hltFilter") << "taus.size() != tauTags.size()";
if (usePtCorr_ && taus.size() != tauPtCorr->size())
throw cms::Exception("Inconsistent Data", "TauTagFilter::hltFilter") << "taus.size() != tauPtCorr.size()";

for (size_t tau_idx = 0; tau_idx < taus.size(); ++tau_idx) {
const auto& tau = taus[tau_idx];
double pt = tau.pt();
if (usePtCorr_)
pt *= (*tauPtCorr)[tau_idx].second;
const double eta = std::abs(tau.eta());
if (pt > minPt_ && eta < maxEta_ && (!matchWithSeeds_ || hasMatch(tau.polarP4()))) {
const double tag = tauTags[tau_idx].second;
const double tag_thr = selector_(tau);
if (tag > tag_thr) {
filterproduct.addObject(trigger::TriggerTau, TauRef(tausHandle, tau_idx));
nTauPassed++;
}
}
}

return nTauPassed >= nExpected_;
}

private:
std::vector<LorentzVectorM> extractMomenta(const edm::Event* event = nullptr) const {
std::vector<LorentzVectorM> seed_p4s;
if (matchWithSeeds_) {
const trigger::TriggerFilterObjectWithRefs* seeds = nullptr;
if (event)
seeds = &event->get(seedsSrc_);
for (const int seedType : seedTypes_) {
if (seedType == trigger::TriggerL1Tau) {
extractMomenta<l1t::TauVectorRef>(seeds, seedType, seed_p4s);
} else if (seedType == trigger::TriggerL1Jet) {
extractMomenta<l1t::JetVectorRef>(seeds, seedType, seed_p4s);
} else if (seedType == trigger::TriggerTau) {
extractMomenta<std::vector<reco::PFTauRef>>(seeds, seedType, seed_p4s);
} else if (seedType == trigger::TriggerJet) {
extractMomenta<std::vector<reco::PFJetRef>>(seeds, seedType, seed_p4s);
} else
throw cms::Exception("Invalid seed type", "TauTagFilter::extractMomenta")
<< "Unsupported seed type: " << seedType;
}
}
return seed_p4s;
}

template <typename Collection>
static void extractMomenta(const trigger::TriggerRefsCollections* triggerObjects,
int objType,
std::vector<LorentzVectorM>& p4s) {
if (triggerObjects) {
Collection objects;
triggerObjects->getObjects(objType, objects);
for (const auto& obj : objects)
p4s.push_back(obj->polarP4());
}
}

private:
const int nExpected_;
const edm::InputTag tausSrc_;
const edm::EDGetTokenT<TauCollection> tausToken_;
const edm::EDGetTokenT<TauTagCollection> tauTagsToken_, tauPtCorrToken_;
const edm::EDGetTokenT<trigger::TriggerFilterObjectWithRefs> seedsSrc_;
const std::vector<int> seedTypes_;
const Selector selector_;
const double minPt_, maxEta_;
const bool usePtCorr_;
const bool matchWithSeeds_;
const double matchingdR2_;
};

//define this as a plug-in
#include "FWCore/Framework/interface/MakerMacros.h"
DEFINE_FWK_MODULE(TauTagFilter);
64 changes: 64 additions & 0 deletions RecoTauTag/RecoTau/interface/TauWPThreshold.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#ifndef RecoTauTag_RecoTau_TauWPThreshold_h
#define RecoTauTag_RecoTau_TauWPThreshold_h

#include "DataFormats/TauReco/interface/BaseTau.h"
#include "DataFormats/PatCandidates/interface/Tau.h"
#include <TF1.h>

namespace tau {
class TauWPThreshold {
public:
explicit TauWPThreshold(const std::string& cut_str) {
bool simple_value = false;
try {
size_t pos = 0;
value_ = std::stod(cut_str, &pos);
simple_value = (pos == cut_str.size());
} catch (std::invalid_argument&) {
} catch (std::out_of_range&) {
}
if (!simple_value) {
static const std::string prefix =
"[&](double *x, double *p) { const int decayMode = p[0];"
"const double pt = p[1]; const double eta = p[2];";
static const int n_params = 3;
static const auto handler = [](int, Bool_t, const char*, const char*) -> void {};

std::string fn_str = prefix;
if (cut_str.find("return") == std::string::npos)
fn_str += " return " + cut_str + ";}";
else
fn_str += cut_str + "}";
auto old_handler = SetErrorHandler(handler);
fn_ = std::make_unique<TF1>("fn_", fn_str.c_str(), 0, 1, n_params);
SetErrorHandler(old_handler);
if (!fn_->IsValid())
throw cms::Exception("TauWPThreshold: invalid formula") << "Invalid WP cut formula = '" << cut_str << "'.";
}
}

double operator()(int dm, double pt, double eta) const {
if (!fn_)
return value_;

fn_->SetParameter(0, dm);
fn_->SetParameter(1, pt);
fn_->SetParameter(2, eta);
return fn_->Eval(0);
}

double operator()(const reco::BaseTau& tau, bool isPFTau) const {
const int dm =
isPFTau ? dynamic_cast<const reco::PFTau&>(tau).decayMode() : dynamic_cast<const pat::Tau&>(tau).decayMode();
return (*this)(dm, tau.pt(), tau.eta());
}

double operator()(const reco::Candidate& tau) const { return (*this)(-1, tau.pt(), tau.eta()); }

private:
std::unique_ptr<TF1> fn_;
double value_;
};
} // namespace tau

#endif
53 changes: 2 additions & 51 deletions RecoTauTag/RecoTau/plugins/DeepTauId.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
#include "DataFormats/PatCandidates/interface/PATTauDiscriminator.h"
#include "CommonTools/Utils/interface/StringObjectFunction.h"
#include "RecoTauTag/RecoTau/interface/PFRecoTauClusterVariables.h"
#include "RecoTauTag/RecoTau/interface/TauWPThreshold.h"
#include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
#include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
#include "DataFormats/Common/interface/View.h"
#include "DataFormats/Common/interface/RefToBase.h"
#include "DataFormats/Provenance/interface/ProductProvenance.h"
#include "DataFormats/Provenance/interface/ProcessHistoryID.h"
#include "FWCore/Common/interface/Provenance.h"
#include <TF1.h>
#include <map>
#include "RecoTauTag/RecoTau/interface/DeepTauScaling.h"
#include "FWCore/Utilities/interface/isFinite.h"
Expand All @@ -47,55 +47,6 @@ namespace deep_tau {
PUcorrPtSum
};

class TauWPThreshold {
public:
explicit TauWPThreshold(const std::string& cut_str) {
bool simple_value = false;
try {
size_t pos = 0;
value_ = std::stod(cut_str, &pos);
simple_value = (pos == cut_str.size());
} catch (std::invalid_argument&) {
} catch (std::out_of_range&) {
}
if (!simple_value) {
static const std::string prefix =
"[&](double *x, double *p) { const int decayMode = p[0];"
"const double pt = p[1]; const double eta = p[2];";
static const int n_params = 3;
static const auto handler = [](int, Bool_t, const char*, const char*) -> void {};

std::string fn_str = prefix;
if (cut_str.find("return") == std::string::npos)
fn_str += " return " + cut_str + ";}";
else
fn_str += cut_str + "}";
auto old_handler = SetErrorHandler(handler);
fn_ = std::make_unique<TF1>("fn_", fn_str.c_str(), 0, 1, n_params);
SetErrorHandler(old_handler);
if (!fn_->IsValid())
throw cms::Exception("TauWPThreshold: invalid formula") << "Invalid WP cut formula = '" << cut_str << "'.";
}
}
double operator()(const reco::BaseTau& tau, bool isPFTau) const {
if (!fn_) {
return value_;
}

if (isPFTau)
fn_->SetParameter(0, dynamic_cast<const reco::PFTau&>(tau).decayMode());
else
fn_->SetParameter(0, dynamic_cast<const pat::Tau&>(tau).decayMode());
fn_->SetParameter(1, tau.pt());
fn_->SetParameter(2, tau.eta());
return fn_->Eval(0);
}

private:
std::unique_ptr<TF1> fn_;
double value_;
};

class DeepTauCache {
public:
using GraphPtr = std::shared_ptr<tensorflow::GraphDef>;
Expand Down Expand Up @@ -951,7 +902,7 @@ class DeepTauId : public edm::stream::EDProducer<edm::GlobalCache<deep_tau::Deep
using ElectronCollection = pat::ElectronCollection;
using MuonCollection = pat::MuonCollection;
using LorentzVectorXYZ = ROOT::Math::LorentzVector<ROOT::Math::PxPyPzE4D<double>>;
using Cutter = deep_tau::TauWPThreshold;
using Cutter = tau::TauWPThreshold;
using CutterPtr = std::unique_ptr<Cutter>;
using WPList = std::vector<CutterPtr>;

Expand Down

0 comments on commit 666097a

Please sign in to comment.