Python code for tree ensemble interpretation proposed in the following paper.

- S. Hara, K. Hayashi, Making Tree Ensembles Interpretable: A Bayesian Model Selection Approach. In Proceedings of the 21st International Conference on Artificial Intelligence and Statistics (AISTATS'18), pages 77--85, 2018.
## Requirements

To use defragTrees:

- Python 3.x
- NumPy
- Pandas

To run the example codes in the `example` directory:

- Python: XGBoost, scikit-learn
- R: randomForest

To replicate the paper results in the `paper` directory:

- Python: scikit-learn, Matplotlib, pylab
- R: randomForest, inTrees, nodeHarvest
## Usage

Prepare data:

- Input `X`: feature matrix, a numpy array of size (num, dim).
- Output `y`: output array, a numpy array of size (num,).
  - For regression, `y` is a real value.
  - For classification, `y` is a class index (i.e., 0, 1, 2, ..., C-1, for C classes).
- Splitter `splitter`: thresholds of the tree ensembles, a numpy array of size (# of split rules, 2).
  - Each row of `splitter` is (feature index, threshold). For example, if the split rule is `second feature < 0.5`, the corresponding row of `splitter` is (1, 0.5).
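As a sketch of how such a `splitter` array might be assembled, the snippet below extracts the (feature index, threshold) pairs from a scikit-learn random forest. This manual extraction is only an illustration of the expected format, not part of defragTrees itself; it relies on scikit-learn's `tree_.feature` and `tree_.threshold` attributes, where leaf nodes carry a feature index of -2. If the repository provides its own parser utilities, prefer those.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

# toy regression data: y depends on the second feature
rng = np.random.default_rng(0)
X = rng.random((100, 2))
y = (X[:, 1] > 0.5).astype(float)

forest = RandomForestRegressor(n_estimators=5, random_state=0).fit(X, y)

# collect (feature index, threshold) for every internal node of every tree;
# in scikit-learn, leaf nodes are marked with a feature index of -2
rows = [(f, t)
        for est in forest.estimators_
        for f, t in zip(est.tree_.feature, est.tree_.threshold)
        if f >= 0]
splitter = np.array(rows)  # shape: (# of split rules, 2)
```

Each row of the resulting array then encodes one split rule in the (feature index, threshold) convention described above.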
Import the class:

```python
from defragTrees import DefragModel
```

Fit the simplified model:

```python
Kmax = 10  # upper bound on the number of rules to be fitted
mdl = DefragModel(modeltype='regression')  # change to 'classification' if necessary
mdl.fit(X, y, splitter, Kmax)
# mdl.fit(X, y, splitter, Kmax, fittype='EM')  # use this when one wants exactly Kmax rules to be fitted
```

Check the learned rules:

```python
print(mdl)
```

For further details, see `defragTrees.py`.
In IPython, one can check the documentation:

```python
import defragTrees
defragTrees?
```

## Examples

See the `example` directory.

## Replicating Paper Results

See the `paper` directory.