
fit should recreate the solver for new data #34

@kjgm

Description

Currently, `fit` reuses the persistent solver when refitting. This is fine when refitting on the same data (for example, with a higher node budget), because the cache can be reused.
If the data is different, however, the cache is no longer valid, so the solver should be re-initialized.
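A minimal sketch of the proposed behavior, assuming `fit` can fingerprint its training data and rebuild the solver only when that fingerprint changes. `_ToySolver`, `_fingerprint`, `RefittableModel` and `solver_builds` are hypothetical names, not part of pymurtree's API; the toy solver only mimics a cache that is valid for one dataset.

```python
import hashlib

import numpy as np


class _ToySolver:
    """Hypothetical stand-in for a solver whose cache is valid for one dataset only."""

    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.cache = {}  # results computed for *this* data only

    def solve(self, max_depth, max_num_nodes):
        key = (max_depth, max_num_nodes)
        if key not in self.cache:
            # Placeholder "solution"; a real solver would search for the optimal tree here.
            self.cache[key] = ("tree", max_depth, max_num_nodes)
        return self.cache[key]


def _fingerprint(x, y):
    """Hash shape, dtype and contents of the training data."""
    h = hashlib.sha256()
    for a in (np.ascontiguousarray(x), np.ascontiguousarray(y)):
        h.update(str(a.shape).encode())
        h.update(str(a.dtype).encode())
        h.update(a.tobytes())
    return h.hexdigest()


class RefittableModel:
    """Hypothetical estimator that reuses its solver only for identical data."""

    def __init__(self, max_depth=4):
        self.max_depth = max_depth
        self._solver = None
        self._data_key = None
        self.solver_builds = 0  # how many times a fresh solver was created

    def fit(self, x, y, max_num_nodes=None):
        key = _fingerprint(x, y)
        if self._solver is None or key != self._data_key:
            # First fit or different data: the old cache is invalid, so rebuild the solver.
            self._solver = _ToySolver(x, y)
            self._data_key = key
            self.solver_builds += 1
        # Same data (e.g. only a larger node budget): reuse the solver and its cache.
        self.tree_ = self._solver.solve(self.max_depth, max_num_nodes)
        return self


x1 = np.array([[0, 1, 0, 1], [1, 0, 0, 1], [1, 1, 0, 0]])
x2 = np.array([[1, 0, 1, 0], [0, 1, 1, 0], [0, 0, 1, 1]])
y = np.array([5, 5, 4])

model = RefittableModel(max_depth=4)
model.fit(x1, y, max_num_nodes=3)
model.fit(x1, y, max_num_nodes=7)  # same data: solver and cache reused
model.fit(x2, y)                   # new data: solver rebuilt
print(model.solver_builds)  # → 2
```

Hashing the inputs costs one pass over the data per `fit`, which should be small next to the tree search; storing a copy of the arrays and comparing with `np.array_equal` would be an equally valid (if more memory-hungry) alternative.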

Example to reproduce the error:

import pymurtree
import numpy
from sklearn.tree import DecisionTreeClassifier

# Create training data
x1 = numpy.array([[0, 1, 0, 1], [1, 0, 0, 1], [1, 1, 0, 0]])
x2 = numpy.array([[1, 0, 1, 0], [0, 1, 1, 0], [0, 0, 1, 1]])
y = numpy.array([5, 5, 4])  # labels

model = pymurtree.OptimalDecisionTreeClassifier(max_depth=4, verbose=False)
# model = DecisionTreeClassifier(max_depth=4)  # For reference: with the sklearn DT classifier, this works

# Fit on the first data set
model.fit(x1, y)
labels = model.predict(x1)
assert all(labels == y)
print("First tree computed perfectly")

# Fit on the second data set
model.fit(x2, y)

# Predict labels for the second data set
labels = model.predict(x2)
assert all(labels == y)
print("Second tree computed perfectly")

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests