Logo ROOT   6.10/00
Reference Guide
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
DecisionTree.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Jan Therhaag, Eckhard von Toerne
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : DecisionTree *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation of a Decision Tree *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
16  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
17  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18  * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
19  * *
20  * Copyright (c) 2005-2011: *
21  * CERN, Switzerland *
22  * U. of Victoria, Canada *
23  * MPI-K Heidelberg, Germany *
24  * U. of Bonn, Germany *
25  * *
26  * Redistribution and use in source and binary forms, with or without *
27  * modification, are permitted according to the terms listed in LICENSE *
28  * (http://mva.sourceforge.net/license.txt) *
29  * *
30  **********************************************************************************/
31 
32 #ifndef ROOT_TMVA_DecisionTree
33 #define ROOT_TMVA_DecisionTree
34 
35 //////////////////////////////////////////////////////////////////////////
36 // //
37 // DecisionTree //
38 // //
39 // Implementation of a Decision Tree //
40 // //
41 //////////////////////////////////////////////////////////////////////////
42 
43 #include "TH2.h"
44 
45 #include "TMVA/Types.h"
46 #include "TMVA/DecisionTreeNode.h"
47 #include "TMVA/BinaryTree.h"
48 #include "TMVA/BinarySearchTree.h"
49 #include "TMVA/SeparationBase.h"
51 #include "TMVA/DataSetInfo.h"
52 
53 class TRandom3;
54 
55 namespace TMVA {
56 
57  class Event;
58 
 // NOTE(review): this copy of the listing was scraped from a Doxygen page and several
 // source lines were dropped (the listing numbers skip). The missing declarations are
 // flagged below where the generated index on this page identifies them; the header as
 // shown here is NOT complete and should not be used as a source file.
59  class DecisionTree : public BinaryTree {
60 
61  private:
62 
63  static const Int_t fgRandomSeed; // seed used as the default for iSeed; set nonzero for debugging and zero for random seeds
64 
65  public:
66 
67  typedef std::vector<TMVA::Event*> EventList;
68  typedef std::vector<const TMVA::Event*> EventConstList;
69 
70  // the constructor needed for the "reading" of the decision tree from weight files
71  DecisionTree( void );
72 
73  // the constructor needed for constructing the decision tree via training with events
74  DecisionTree( SeparationBase *sepType, Float_t minSize,
75  Int_t nCuts, DataSetInfo* = NULL,
76  UInt_t cls =0,
77  Bool_t randomisedTree=kFALSE, Int_t useNvars=0, Bool_t usePoissonNvars=kFALSE,
78  UInt_t nMaxDepth=9999999,
79  Int_t iSeed=fgRandomSeed, Float_t purityLimit=0.5,
80  Int_t treeID = 0);
81 
82  // copy constructor
 // NOTE(review): a copy constructor and a virtual destructor are declared, but no
 // copy-assignment operator, so the implicit operator= copies raw pointer members such as
 // fMyTrandom and fRegType memberwise. If the destructor (defined out of line, not shown)
 // deletes them, assigning one tree to another would double-delete -- confirm against
 // DecisionTree.cxx and consider declaring operator= deleted or implementing it.
83  DecisionTree (const DecisionTree &d);
84 
85  virtual ~DecisionTree( void );
86 
87  // Retrieves the address of the root node (fRoot downcast to DecisionTreeNode*)
88  virtual DecisionTreeNode* GetRoot() const { return static_cast<TMVA::DecisionTreeNode*>(fRoot); }
 // factory overrides: create an empty DecisionTreeNode / an empty default-constructed DecisionTree
89  virtual DecisionTreeNode * CreateNode(UInt_t) const { return new DecisionTreeNode(); }
90  virtual BinaryTree* CreateTree() const { return new DecisionTree(); }
91  static DecisionTree* CreateFromXML(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
92  virtual const char* ClassName() const { return "DecisionTree"; }
93 
94  // building of a tree by recursively splitting the nodes
95 
96  // UInt_t BuildTree( const EventList & eventSample,
97  // DecisionTreeNode *node = NULL);
98  UInt_t BuildTree( const EventConstList & eventSample,
99  DecisionTreeNode *node = NULL);
100  // determine how a node is split (which variable, which cut value);
 // TrainNode simply forwards to TrainNodeFast
101 
102  Double_t TrainNode( const EventConstList & eventSample, DecisionTreeNode *node ) { return TrainNodeFast( eventSample, node ); }
103  Double_t TrainNodeFast( const EventConstList & eventSample, DecisionTreeNode *node );
104  Double_t TrainNodeFull( const EventConstList & eventSample, DecisionTreeNode *node );
105  void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t & nVars);
106  std::vector<Double_t> GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher);
107 
108  // fill a tree that already has a given structure (just see how many signal/bkg
109  // events end up in each node)
110 
111  void FillTree( const EventList & eventSample);
112 
113  // fill the existing decision tree structure by passing the event
114  // in from the top node and seeing where it ends up
115  void FillEvent( const TMVA::Event & event,
116  TMVA::DecisionTreeNode *node );
117 
118  // returns: 1 = Signal (right), -1 = Bkg (left)
119 
120  Double_t CheckEvent( const TMVA::Event * , Bool_t UseYesNoLeaf = kFALSE ) const;
 // NOTE(review): listing line 121 is missing here; presumably the GetEventNode(const Event&) const
 // declaration that appears in the index -- confirm against the real header.
122 
123  // return the individual relative variable importance
124  std::vector< Double_t > GetVariableImportance();
125 
 // NOTE(review): listing line 126 is missing here (content unknown from this copy).
127 
128  // clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree
129 
130  void ClearTree();
131 
132  // set pruning method
 // NOTE(review): listing lines 133-134 are missing. The index lists
 // SetPruneMethod(EPruneMethod m=kCostComplexityPruning) at line 134; the EPruneMethod type
 // used by fPruneMethod below is presumably declared at line 133 -- confirm.
135 
136  // recursive pruning of the tree, validation sample required for automatic pruning
137  Double_t PruneTree( const EventConstList* validationSample = NULL );
138 
139  // manage the pruning strength parameter (if < 0 -> automate the pruning process)
 // NOTE(review): listing lines 140-141 are missing; the index lists
 // SetPruneStrength(Double_t p) at 140 and GetPruneStrength() const at 141.
142 
143  // apply pruning validation sample to a decision tree
144  void ApplyValidationSample( const EventConstList* validationSample ) const;
145 
146  // return the misclassification rate of a pruned tree
147  Double_t TestPrunedTreeQuality( const DecisionTreeNode* dt = NULL, Int_t mode=0 ) const;
148 
149  // pass a single validation event through a pruned decision tree
150  void CheckEventWithPrunedTree( const TMVA::Event* ) const;
151 
152  // calculate the normalization factor for a pruning validation sample
153  Double_t GetSumWeights( const EventConstList* validationSample ) const;
154 
 // NOTE(review): listing lines 155-156 are missing; the index lists
 // SetNodePurityLimit(Double_t p) at 155 and GetNodePurityLimit() const at 156.
157 
158  void DescendTree( Node *n = NULL );
159  void SetParentTreeInNodes( Node *n = NULL );
160 
161  // retrieve node from the tree. Its position (up to a maximal tree depth of 64)
162  // is coded as a sequence of left-right moves starting from the root, coded as
163  // 0-1 bit patterns stored in the "long-integer" together with the depth
164  Node* GetNode( ULong_t sequence, UInt_t depth );
165 
 // NOTE(review): listing line 166 is missing; presumably CleanTree(DecisionTreeNode *node=NULL),
 // which the index lists without a line number -- confirm.
167 
168  void PruneNode(TMVA::DecisionTreeNode *node);
169 
170  // prune a node from the tree without deleting its descendants; allows one to
171  // effectively prune a tree many times without making deep copies
 // NOTE(review): the declaration this comment describes (listing line 172) is missing;
 // presumably PruneNodeInPlace(TMVA::DecisionTreeNode *node), listed in the index.
173 
 // NOTE(review): listing line 174 is missing; the index lists GetNNodesBeforePruning() there.
175 
176 
 // NOTE(review): listing line 177 is missing; presumably CountLeafNodes(TMVA::Node *n=NULL),
 // listed in the index without a line number -- confirm.
178 
179  void SetTreeID(Int_t treeID){fTreeID = treeID;};
180  Int_t GetTreeID(){return fTreeID;};
181 
 // NOTE(review): listing lines 182-187 are missing; the index lists DoRegression() const (182),
 // SetAnalysisType (183), GetAnalysisType (184), SetUseFisherCuts (185),
 // SetMinLinCorrForFisher (186) and SetUseExclusiveVars (187).
 // NOTE(review): SetNVars takes a signed Int_t but stores into UInt_t fNvars; a negative n
 // wraps to a large unsigned value.
188  inline void SetNVars(Int_t n){fNvars = n;}
189 
190 
191  private:
192  // utility functions
193 
194  // calculate the Purity out of the number of sig and bkg events collected
195  // from individual samples.
 // NOTE(review): no declaration follows the comment above in the original listing.
196 
197  // calculates the purity S/(S+B) of a given event sample
 // NOTE(review): eventSample is taken by value, copying the vector on each call; const& would
 // avoid the copy but also requires changing the out-of-line definition.
198  Double_t SamplePurity(EventList eventSample);
199 
200  UInt_t fNvars; // number of variables used to separate S and B
201  Int_t fNCuts; // number of grid points in variable cut scans
202  Bool_t fUseFisherCuts; // use multivariate splits using the Fisher criterion
203  Double_t fMinLinCorrForFisher; // the minimum linear correlation between two variables demanded for use in the Fisher criterion in node splitting
204  Bool_t fUseExclusiveVars; // individual variables already used in the Fisher criterion are no longer analysed individually for node splitting
205 
206  SeparationBase *fSepType; // the separation criterion
 // NOTE(review): RegressionVariance is not declared by any include visible in this copy; the
 // dropped listing line 50 presumably included its header -- confirm.
207  RegressionVariance *fRegType; // the separation criterion used in Regression
208 
209  Double_t fMinSize; // min number of events in node
210  Double_t fMinNodeSize; // min fraction of training events in node
211  Double_t fMinSepGain; // minimum separation gain required to perform node splitting
212 
213  Bool_t fUseSearchTree; // cut scan done with binary trees or simple event loop.
214  Double_t fPruneStrength; // a parameter to set the "amount" of pruning; needs to be adjusted
215 
216  EPruneMethod fPruneMethod; // method used for pruning
217  Int_t fNNodesBeforePruning; // remember this one (in case of pruning, it allows monitoring the before/after)
218 
219  Double_t fNodePurityLimit;// purity limit to decide whether a node is signal
220 
221  Bool_t fRandomisedTree; // choose at each node splitting a random set of variables
222  Int_t fUseNvars; // the number of variables used in randomised trees;
223  Bool_t fUsePoissonNvars; // use "fUseNvars" not as a fixed number but as the mean of a Poisson distribution in each split
224 
225  TRandom3 *fMyTrandom; // random number generator for randomised trees
226 
227  std::vector< Double_t > fVariableImportance; // the relative importance of the different variables
228 
229  UInt_t fMaxDepth; // max depth
230  UInt_t fSigClass; // class which is treated as signal when building the tree
231  static const Int_t fgDebugLevel = 0; // debug level determining some printout/control plots etc.
232  Int_t fTreeID; // just an ID number given to the tree; makes debugging easier as the tree knows its own ID.
233 
234  Types::EAnalysisType fAnalysisType; // kClassification(=0=false) or kRegression(=1=true)
235 
 // NOTE(review): listing line 236 is missing; the index lists DataSetInfo* fDataSetInfo there.
237 
238 
239  ClassDef(DecisionTree,0); // implementation of a Decision Tree
240  };
241 
242 } // namespace TMVA
243 
244 #endif
void SetPruneMethod(EPruneMethod m=kCostComplexityPruning)
Definition: DecisionTree.h:134
virtual DecisionTreeNode * CreateNode(UInt_t) const
Definition: DecisionTree.h:89
DataSetInfo * fDataSetInfo
Definition: DecisionTree.h:236
Random number generator class based on M.
Definition: TRandom3.h:27
#define TMVA_VERSION_CODE
Definition: Version.h:47
float Float_t
Definition: RtypesCore.h:53
EPruneMethod fPruneMethod
Definition: DecisionTree.h:216
Double_t GetNodePurityLimit() const
Definition: DecisionTree.h:156
EAnalysisType
Definition: Types.h:125
Double_t GetPruneStrength() const
Definition: DecisionTree.h:141
Types::EAnalysisType GetAnalysisType(void)
Definition: DecisionTree.h:184
Calculate the "SeparationGain" for Regression analysis separation criteria used in various training a...
TMVA::DecisionTreeNode * GetEventNode(const TMVA::Event &e) const
get the pointer to the leaf node where a particular event ends up in...
std::vector< Double_t > GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher)
calculate the fisher coefficients for the event sample and the variables used
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
void SetUseExclusiveVars(Bool_t t=kTRUE)
Definition: DecisionTree.h:187
Double_t fNodePurityLimit
Definition: DecisionTree.h:219
virtual ~DecisionTree(void)
destructor
#define NULL
Definition: RtypesCore.h:88
void SetNodePurityLimit(Double_t p)
Definition: DecisionTree.h:155
std::vector< Double_t > GetVariableImportance()
Return the relative variable importance, normalized to all variables together having the importance 1...
void SetAnalysisType(Types::EAnalysisType t)
Definition: DecisionTree.h:183
virtual DecisionTreeNode * GetRoot() const
Definition: DecisionTree.h:88
std::vector< const TMVA::Event * > EventConstList
Definition: DecisionTree.h:68
Base class for BinarySearch and Decision Trees.
Definition: BinaryTree.h:62
Double_t GetSumWeights(const EventConstList *validationSample) const
calculate the normalization factor for a pruning validation sample
#define ClassDef(name, id)
Definition: Rtypes.h:297
static const Int_t fgRandomSeed
Definition: DecisionTree.h:63
void FillTree(const EventList &eventSample)
fill the existing decision tree structure by filling events in from the top node and seeing where the...
Double_t SamplePurity(EventList eventSample)
calculates the purity S/(S+B) of a given event sample
std::vector< Double_t > fVariableImportance
Definition: DecisionTree.h:227
Class that contains all the data information.
Definition: DataSetInfo.h:60
void SetTreeID(Int_t treeID)
Definition: DecisionTree.h:179
UInt_t CountLeafNodes(TMVA::Node *n=NULL)
return the number of terminal nodes in the sub-tree below Node n
Double_t TrainNodeFast(const EventConstList &eventSample, DecisionTreeNode *node)
Decide how to split a node using one of the variables that gives the best separation of signal/backgr...
void DescendTree(Node *n=NULL)
descend a tree to find all its leaf nodes
void FillEvent(const TMVA::Event &event, TMVA::DecisionTreeNode *node)
fill the existing decision tree structure by filling the event in from the top node and seeing where the...
Double_t fPruneStrength
Definition: DecisionTree.h:214
Double_t CheckEvent(const TMVA::Event *, Bool_t UseYesNoLeaf=kFALSE) const
the event e is put into the decision tree (starting at the root node) and the output is NodeType (sig...
Double_t fMinLinCorrForFisher
Definition: DecisionTree.h:203
void SetNVars(Int_t n)
Definition: DecisionTree.h:188
void SetMinLinCorrForFisher(Double_t min)
Definition: DecisionTree.h:186
UInt_t CleanTree(DecisionTreeNode *node=NULL)
remove those last splits that result in two leaf nodes that are both of the same type (i.e.
Int_t GetNNodesBeforePruning()
Definition: DecisionTree.h:174
void SetPruneStrength(Double_t p)
Definition: DecisionTree.h:140
void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t &nVars)
Implementation of a Decision Tree.
Definition: DecisionTree.h:59
unsigned int UInt_t
Definition: RtypesCore.h:42
TMarker * m
Definition: textangle.C:8
Double_t TrainNodeFull(const EventConstList &eventSample, DecisionTreeNode *node)
train a node by finding the single optimal cut for a single variable that best separates signal and b...
void SetParentTreeInNodes(Node *n=NULL)
descend a tree to find all its leaf nodes, fill max depth reached in the tree at the same time...
void CheckEventWithPrunedTree(const TMVA::Event *) const
pass a single validation event through a pruned decision tree on the way down the tree...
An interface to calculate the "SeparationGain" for different separation criteria used in various trai...
void PruneNodeInPlace(TMVA::DecisionTreeNode *node)
prune a node temporarily (without actually deleting its descendants which allows testing the pruned t...
std::vector< TMVA::Event * > EventList
Definition: DecisionTree.h:67
void SetUseFisherCuts(Bool_t t=kTRUE)
Definition: DecisionTree.h:185
virtual BinaryTree * CreateTree() const
Definition: DecisionTree.h:90
const Bool_t kFALSE
Definition: RtypesCore.h:92
TRandom3 * fMyTrandom
Definition: DecisionTree.h:225
double Double_t
Definition: RtypesCore.h:55
Node * GetNode(ULong_t sequence, UInt_t depth)
retrieve node from the tree.
static const Int_t fgDebugLevel
Definition: DecisionTree.h:231
void ClearTree()
clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree ...
Types::EAnalysisType fAnalysisType
Definition: DecisionTree.h:234
unsigned long ULong_t
Definition: RtypesCore.h:51
static DecisionTree * CreateFromXML(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
re-create a new tree (decision tree or search tree) from XML
UInt_t GetNNodes() const
Definition: BinaryTree.h:86
you should not use this method at all Int_t Int_t Double_t Double_t Double_t e
Definition: TRolke.cxx:630
virtual const char * ClassName() const
Definition: DecisionTree.h:92
RegressionVariance * fRegType
Definition: DecisionTree.h:207
SeparationBase * fSepType
Definition: DecisionTree.h:206
Double_t PruneTree(const EventConstList *validationSample=NULL)
prune (get rid of internal nodes) the Decision tree to avoid overtraining; several different pruning m...
Bool_t DoRegression() const
Definition: DecisionTree.h:182
Node for the BinarySearch or Decision Trees.
Definition: Node.h:56
UInt_t BuildTree(const EventConstList &eventSample, DecisionTreeNode *node=NULL)
building the decision tree by recursively calling the splitting of one (root-) node into two daughter...
Double_t TestPrunedTreeQuality(const DecisionTreeNode *dt=NULL, Int_t mode=0) const
return the misclassification rate of a pruned tree; a "pruned tree" may have set the variable "IsTermi...
DecisionTree(void)
default constructor using the GiniIndex as separation criterion, no restrictions on minimum number of ...
const Bool_t kTRUE
Definition: RtypesCore.h:91
const Int_t n
Definition: legend1.C:16
Double_t TrainNode(const EventConstList &eventSample, DecisionTreeNode *node)
Definition: DecisionTree.h:102
void ApplyValidationSample(const EventConstList *validationSample) const
run the validation sample through the (pruned) tree and fill in the nodes the variables NSValidation ...
void PruneNode(TMVA::DecisionTreeNode *node)
prune away the subtree below the node