33 : fROCAVG(0.0), fROCCurves(std::make_shared<
TMultiGraph>())
44 return fROCCurves.get();
52 MsgLogger fLogger(
"HyperParameterOptimisation");
54 for(
UInt_t j=0; j<fFoldParameters.size(); ++j) {
55 fLogger<<kHEADER<<
"===========================================================" <<
Endl;
56 fLogger<<kINFO<<
"Optimisation for " << fMethodName <<
" fold " << j+1 <<
Endl;
58 for(
auto &it : fFoldParameters.at(j)) {
59 fLogger<<kINFO<< it.first <<
" " << it.second <<
Endl;
68 fFomType(
"Separation"),
72 fClassifier(new TMVA::
Factory(
"HyperParameterOptimisation",
"!V:!ROC:Silent:!ModelPersistence:!Color:!DrawProgressBar:AnalysisType=Classification"))
85 fDataLoader->MakeKFoldDataSet(fNumFolds);
93 TString methodOptions = fMethod.GetValue<
TString>(
"MethodOptions");
97 fDataLoader->MakeKFoldDataSet(fNumFolds);
100 fResults.fMethodName = methodName;
102 for(
UInt_t i = 0; i < fNumFolds; ++i) {
104 TString foldTitle = methodTitle;
111 auto smethod = fClassifier->BookMethod(fDataLoader.get(), methodName, methodTitle, methodOptions);
113 auto params=smethod->OptimizeTuningParameters(fFomType,fFitType);
114 fResults.fFoldParameters.push_back(params);
118 fClassifier->DeleteAllMethods();
120 fClassifier->fMethodsMap.clear();
HyperParameterOptimisationResult()
MsgLogger & Endl(MsgLogger &ml)
A TMultiGraph is a collection of TGraph (or derived) objects.
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Base class for all machine learning algorithms.
virtual void Evaluate()
Virtual method to be implemented with your algorithm.
void SetNumFolds(UInt_t folds)
HyperParameterOptimisation(DataLoader *dataloader)
~HyperParameterOptimisation()
This is the main MVA steering class.
ostringstream derivative to redirect and format output
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
~HyperParameterOptimisationResult()
static void EnableOutput()