44 return fROCCurves.get();
50 for(
auto &roc:fROCs) avg+=roc.second;
51 return avg/fROCs.size();
59 for(
auto &roc:fROCs) std+=
TMath::Power(roc.second-avg, 2);
69 fLogger << kHEADER <<
" ==== Results ====" <<
Endl;
71 fLogger << kINFO <<
Form(
"Fold %i ROC-Int : %.4f",item.first,item.second) << std::endl;
73 fLogger << kINFO <<
"------------------------" <<
Endl;
74 fLogger << kINFO <<
Form(
"Average ROC-Int : %.4f",GetROCAverage()) <<
Endl;
75 fLogger << kINFO <<
Form(
"Std-Dev ROC-Int : %.4f",GetROCStandardDeviation()) <<
Endl;
83 fROCCurves->
Draw(
"AL");
84 fROCCurves->GetXaxis()->SetTitle(
" Signal Efficiency ");
85 fROCCurves->GetYaxis()->SetTitle(
" Background Rejection ");
86 Float_t adjust=1+fROCs.size()*0.01;
88 c->
SetTitle(
"Cross Validation ROC Curves");
94 fNumFolds(5),fClassifier(new TMVA::
Factory(
"CrossValidation",
"!V:!ROC:Silent:!ModelPersistence:!Color:!DrawProgressBar:AnalysisType=Classification"))
107 fDataLoader->MakeKFoldDataSet(fNumFolds);
115 TString methodOptions = fMethod.GetValue<
TString>(
"MethodOptions");
116 if(methodName ==
"")
Log() << kFATAL <<
"No method booked for cross-validation" <<
Endl;
120 Log() << kINFO <<
"Evaluate method: " << methodTitle <<
Endl;
125 fDataLoader->MakeKFoldDataSet(fNumFolds);
130 for(
UInt_t i=0; i<fNumFolds; ++i){
131 Log() << kDEBUG <<
"Fold (" << methodTitle <<
"): " << i <<
Endl;
133 TString foldTitle = methodTitle;
134 foldTitle +=
"_fold";
138 MethodBase* smethod = fClassifier->BookMethod(fDataLoader.get(), methodName, methodTitle, methodOptions);
150 fResults.fROCs[i] = fClassifier->GetROCIntegral(fDataLoader->GetName(),methodTitle);
152 TGraph*
gr = fClassifier->GetROCCurve(fDataLoader->GetName(), methodTitle,
true);
156 fResults.fROCCurves->Add(gr);
173 fClassifier->DeleteAllMethods();
174 fClassifier->fMethodsMap.clear();
178 Log() << kINFO <<
"Evaluation done." <<
Endl;
183 if(fResults.fROCs.size()==0)
Log() << kFATAL <<
"No cross-validation results available" <<
Endl;
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
MsgLogger & Endl(MsgLogger &ml)
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
void SetTitle(const char *title="")
Set canvas title.
TCanvas * Draw(const TString name="CrossValidation") const
Float_t GetROCAverage() const
const CrossValidationResult & GetResults() const
Float_t GetROCStandardDeviation() const
A TMultiGraph is a collection of TGraph (or derived) objects.
Virtual base class for all MVA methods.
virtual void SetTitle(const char *title="")
Set graph title.
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
const TString & GetMethodName() const
const char * Data() const
Types::EAnalysisType GetAnalysisType() const
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
void SetNumFolds(UInt_t i)
Base class for all machine learning algorithms.
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
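fill background efficiency (resp. rejection) versus signal efficiency plots; returns the signal efficiency at the background efficiency indicated in theString.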
virtual void SetLineColor(Color_t lcolor)
Set the line color.
void DeleteResults(const TString &, Types::ETreeType type, Types::EAnalysisType analysistype)
delete the results stored for this particular Method instance.
char * Form(const char *fmt,...)
virtual Double_t GetSignificance() const
compute significance of mean difference
This is the main MVA steering class.
CrossValidation(DataLoader *loader)
virtual void Evaluate()
Virtual method to be implemented with your algorithm.
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
ostringstream derivative to redirect and format output
virtual void Draw(Option_t *option="")
Draw a canvas.
virtual Double_t GetSeparation(TH1 *, TH1 *) const
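compute "separation" defined as $\langle s^2 \rangle = \frac{1}{2} \int_{-\infty}^{+\infty} \frac{(S(x) - B(x))^2}{S(x) + B(x)} \, dx$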
virtual TLegend * BuildLegend(Double_t x1=0.3, Double_t y1=0.21, Double_t x2=0.3, Double_t y2=0.21, const char *title="", Option_t *option="")
Build a legend from the graphical objects in the pad.
std::map< UInt_t, Float_t > fROCs
A Graph is a graphics object made of two arrays X and Y with npoints each.
virtual Double_t GetTrainingEfficiency(const TString &)
Double_t Sqrt(Double_t x)
static void EnableOutput()
virtual void TestClassification()
initialization
std::shared_ptr< TMultiGraph > fROCCurves