// -*- C++ -*-
#include "Rivet/Analysis.hh"
#include "Rivet/Projections/FastJets.hh"
#include "Rivet/Projections/FinalState.hh"
#include "Rivet/Projections/LeptonFinder.hh"
#include "Rivet/Projections/VetoedFinalState.hh"
#include "Rivet/Projections/PromptFinalState.hh"
#include "Rivet/Projections/MissingMomentum.hh"

namespace Rivet {


  /// ATLAS pTmiss+jets cross-section ratios at 13 TeV
  class ATLAS_2017_I1609448 : public Analysis {
  public:

    /// Constructor
    RIVET_DEFAULT_ANALYSIS_CTOR(ATLAS_2017_I1609448);


    struct HistoHandler {
      Histo1DPtr histo;
      Estimate1DPtr estimate;
      unsigned int d, x, y;

      void fill(double value) {
        histo->fill(value);
      }
    };


    /// Initialize
    void init() {

      // Get options from the new option system
      _mode = 0;
      if ( getOption("LMODE") == "NU" ) _mode = 0; // using Z -> nunu channel by default
      if ( getOption("LMODE") == "MU" ) _mode = 1;
      if ( getOption("LMODE") == "EL" ) _mode = 2;

      // Prompt photons
      PromptFinalState photon_fs(Cuts::abspid == PID::PHOTON && Cuts::abseta < 4.9);
      // Prompt electrons
      PromptFinalState el_fs(Cuts::abseta < 4.9 && Cuts::abspid == PID::ELECTRON);
      // Prompt muons
      PromptFinalState mu_fs(Cuts::abseta < 4.9 && Cuts::abspid == PID::MUON);

      // Dressed leptons
      Cut lep_cuts = Cuts::pT > 7*GeV && Cuts::abseta < 2.5;
      LeptonFinder dressed_leps((_mode == 2 ? el_fs : mu_fs), photon_fs, 0.1, lep_cuts);
      declare(dressed_leps, "LeptonFinder");

      // In-acceptance leptons for lepton veto
      PromptFinalState veto_lep_fs(Cuts::abseta < 4.9 && (Cuts::abspid == PID::ELECTRON || Cuts::abspid == PID::MUON));
      veto_lep_fs.acceptTauDecays();
      veto_lep_fs.acceptMuonDecays();
      LeptonFinder veto_lep(veto_lep_fs, photon_fs, 0.1, lep_cuts);
      declare(veto_lep, "VetoLeptons");

      // MET
      VetoedFinalState met_fs(Cuts::abseta > 2.5 && Cuts::abspid == PID::MUON); // veto out-of-acceptance muons
      if (_mode) met_fs.addVetoOnThisFinalState(dressed_leps);
      declare(MissingMomentum(met_fs), "MET");

      // Jet collection
      FastJets jets(FinalState(Cuts::abseta < 4.9), JetAlg::ANTIKT, 0.4, JetMuons::NONE, JetInvisibles::NONE);
      declare(jets, "Jets");

      _h["met_mono"] = bookHandler(1, 1, 2);
      _h["met_vbf" ] = bookHandler(2, 1, 2);
      _h["mjj_vbf" ] = bookHandler(3, 1, 2);
      _h["dphijj_vbf"] = bookHandler(4, 1, 2);
    }


    HistoHandler bookHandler(unsigned int id_d, unsigned int id_x, unsigned int id_y) {
      HistoHandler dummy;
      if (_mode < 2) {  // numerator mode
        const string histName = "_" + mkAxisCode(id_d, id_x, id_y);
        book(dummy.histo, histName, refData(id_d, id_x, id_y)); // hidden auxiliary output
        book(dummy.estimate, id_d, id_x, id_y - 1); // ratio
        dummy.d = id_d;
        dummy.x = id_x;
        dummy.y = id_y;
      } else {
        book(dummy.histo, id_d, id_x, 4); // denominator mode
      }
      return dummy;
    }


    bool isBetweenJets(const Jet& probe, const Jet& boundary1, const Jet& boundary2) {
      const double y_p = probe.rapidity();
      const double y_b1 = boundary1.rapidity();
      const double y_b2 = boundary2.rapidity();
      const double y_min = std::min(y_b1, y_b2);
      const double y_max = std::max(y_b1, y_b2);
      return (y_p > y_min && y_p < y_max);
    }


    int centralJetVeto(Jets& jets) {
      if (jets.size() < 2) return 0;
      const Jet bj1 = jets.at(0);
      const Jet bj2 = jets.at(1);

      // Start loop at the 3rd hardest pT jet
      int n_between = 0;
      for (size_t i = 2; i < jets.size(); ++i) {
        const Jet j = jets.at(i);
        if (isBetweenJets(j, bj1, bj2) && j.pT() > 25*GeV)  ++n_between;
      }
      return n_between;
    }


    /// Perform the per-event analysis
    void analyze(const Event& event) {

      // Require 0 (Znunu) or 2 (Zll) dressed leptons
      bool isZll = bool(_mode);
      const DressedLeptons &vetoLeptons = apply<LeptonFinder>(event, "VetoLeptons").dressedLeptons();
      const DressedLeptons &all_leps = apply<LeptonFinder>(event, "LeptonFinder").dressedLeptons();
      if (!isZll && vetoLeptons.size())    vetoEvent;
      if ( isZll && all_leps.size() != 2)  vetoEvent;

      DressedLeptons leptons;
      bool pass_Zll = true;
      if (isZll) {
        // Sort dressed leptons by pT
        if (all_leps[0].pt() > all_leps[1].pt()) {
          leptons.push_back(all_leps[0]);
          leptons.push_back(all_leps[1]);
        } else {
          leptons.push_back(all_leps[1]);
          leptons.push_back(all_leps[0]);
        }
        // Leading lepton pT cut
        pass_Zll &= leptons[0].pT() > 80*GeV;
        // Opposite-charge requirement
        pass_Zll &= charge3(leptons[0]) + charge3(leptons[1]) == 0;
        // Z-mass requirement
        const double Zmass = (leptons[0].mom() + leptons[1].mom()).mass();
        pass_Zll &= (Zmass >= 66*GeV && Zmass <= 116*GeV);
      }
      if (!pass_Zll)  vetoEvent;


      // Get jets and remove those within dR = 0.5 of a dressed lepton
      Jets jets = apply<FastJets>(event, "Jets").jetsByPt(Cuts::pT > 25*GeV && Cuts::absrap < 4.4);
      for (const DressedLepton& lep : leptons)
        idiscard(jets, deltaRLess(lep, 0.5));

      const size_t njets = jets.size();
      if (!njets)  vetoEvent;
      const int njets_gap = centralJetVeto(jets);

      double jpt1 = jets[0].pT();
      double jeta1 = jets[0].eta();
      double mjj = 0., jpt2 = 0., dphijj = 0.;
      if (njets >= 2) {
        mjj = (jets[0].momentum() + jets[1].momentum()).mass();
        jpt2 = jets[1].pT();
        dphijj = deltaPhi(jets[0], jets[1]);
      }

      // MET
      Vector3 met_vec = apply<MissingMomentum>(event, "MET").vectorMPT();
      double met = met_vec.mod();

      // Cut on deltaPhi between MET and first 4 jets, but only if jet pT > 30 GeV
      bool dphi_fail = false;
      for (size_t i = 0; i < jets.size() && i < 4; ++i) {
        dphi_fail |= (deltaPhi(jets[i], met_vec) < 0.4 && jets[i].pT() > 30*GeV);
      }

      const bool pass_met_dphi = met > 200*GeV && !dphi_fail;
      const bool pass_vbf = pass_met_dphi && mjj > 200*GeV && jpt1 > 80*GeV && jpt2 > 50*GeV && njets >= 2 && !njets_gap;
      const bool pass_mono = pass_met_dphi && jpt1 > 120*GeV && fabs(jeta1) < 2.4;
      if (pass_mono)  _h["met_mono"].fill(met);
      if (pass_vbf) {
        _h["met_vbf"].fill(met/GeV);
        _h["mjj_vbf"].fill(mjj/GeV);
        _h["dphijj_vbf"].fill(dphijj);
      }
    }


    /// Normalise, scale and otherwise manipulate histograms here
    void finalize() {
      const double sf(crossSection() / femtobarn / sumOfWeights());
      for (auto& item : _h) {
        scale(item.second.histo, sf);
        if (_mode < 2)  constructRmiss(item.second);
      }
    }


    void constructRmiss(HistoHandler& handler) {
      // Load transfer function from reference data file
      const YODA::Estimate1D& rmiss = refData(handler.d, handler.x, handler.y);
      const YODA::Estimate1D& numer = refData(handler.d, handler.x, handler.y + 1);
      const YODA::Estimate1D& denom = refData(handler.d, handler.x, handler.y + 2);
      const YODA::Estimate1D& bsm = handler.histo->mkEstimate();
      for (size_t i = 1; i < handler.estimate->numBins()+1; ++i) {
        const auto& r = rmiss.bin(i); // SM Rmiss
        const auto& n = numer.bin(i); // SM numerator
        const auto& d = denom.bin(i); // SM denominator
        const auto& b = bsm.bin(i); // BSM
        // Rmiss central value
        const double rmiss_y = safediv(n.val() + b.val(), d.val());
        // Ratio error (Rmiss = SM_num/SM_denom + BSM/SM_denom ~ Rmiss_SM + BSM/SM_denom
        const double rmiss_p = sqrt(sqr(r.totalErrPos())  + safediv(sqr(b.val()? b.totalErrPos() : 0.), sqr(d.val())));
        const double rmiss_m = sqrt(sqr(r.totalErrNeg()) + safediv(sqr(b.val()? b.totalErrNeg() : 0.), sqr(d.val())));
        // Set new values
        handler.estimate->bin(i).set(rmiss_y, {-rmiss_m, rmiss_p});
      }
    }


  protected:

    // Analysis-mode switch
    size_t _mode;

    /// Histograms
    map<string, HistoHandler> _h;

  };


  // Declare ATLAS_2017_I1609448 as a Rivet plugin via the RIVET_DECLARE_PLUGIN macro.
  RIVET_DECLARE_PLUGIN(ATLAS_2017_I1609448);
}
