forked from jurra/pymurtree
Open
Labels
bug: Something isn't working
Description
Currently, fit reuses the persistent solver when refitting. This is fine when refitting on the same data with a higher node budget, since the solver's cache can be reused.
But if the data is different, the cache is no longer valid, so the solver should be re-initialized.
Example to reproduce the error:
import pymurtree
import numpy
from sklearn.tree import DecisionTreeClassifier
# Create training data
x1 = numpy.array([[0, 1, 0, 1], [1, 0, 0, 1], [1, 1, 0, 0]])
x2 = numpy.array([[1, 0, 1, 0], [0, 1, 1, 0], [0, 0, 1, 1]])
y = numpy.array([5, 5, 4]) # labels
model = pymurtree.OptimalDecisionTreeClassifier(max_depth=4, verbose=False)
# model = DecisionTreeClassifier(max_depth=4)  # For reference: with the sklearn DT classifier, refitting on new data works
# fit on first data set
model.fit(x1, y)
labels = model.predict(x1)
assert all(labels == y)
print("First tree computed perfectly")
# fit on second data set
model.fit(x2, y)
# Predict labels for a new set of features
labels = model.predict(x2)
assert all(labels == y)
print("Second tree computed perfectly")
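
One possible way to get the behaviour described above is to fingerprint the training data and re-create the solver only when the fingerprint changes. The sketch below is a rough illustration, not pymurtree's actual code: StubSolver, fingerprint and RefittableModel are hypothetical names, and the real solver's constructor and cache will look different.

```python
import hashlib


class StubSolver:
    """Hypothetical stand-in for the persistent solver; not pymurtree's real class."""

    def __init__(self, features, labels):
        self.features = features
        self.labels = labels
        self.cache = {}  # would hold cached sub-solutions for this data set


def fingerprint(features, labels):
    """Hash the training data so that a change in the data can be detected."""
    digest = hashlib.sha256()
    for part in (features, labels):
        # NumPy arrays expose tolist(); plain nested lists are used as-is.
        as_list = part.tolist() if hasattr(part, "tolist") else part
        digest.update(repr(as_list).encode("utf-8"))
    return digest.hexdigest()


class RefittableModel:
    """Hypothetical wrapper: reuse the solver only when the data is unchanged."""

    def __init__(self):
        self._solver = None
        self._data_key = None

    def fit(self, features, labels):
        key = fingerprint(features, labels)
        if self._solver is None or key != self._data_key:
            # New or different data: the old cache is invalid, so start fresh.
            self._solver = StubSolver(features, labels)
            self._data_key = key
        # Same data: keep the solver so its cache can be reused.
        return self._solver
```

With this check, refitting on x1 twice would keep the same solver (and its cache), while refitting on x2 would build a new one. Hashing the full data on every fit costs time linear in the data size, which is likely negligible next to the search itself.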