diff --git a/quantgov/estimator/evaluation.py b/quantgov/estimator/evaluation.py index c264c74..03b0914 100644 --- a/quantgov/estimator/evaluation.py +++ b/quantgov/estimator/evaluation.py @@ -1,9 +1,14 @@ import configparser import logging -import sklearn.model_selection import pandas as pd +try: + from sklearn.model_selection import KFold, GridSearchCV +except ImportError: # sklearn 0.17 + from sklearn.cross_validation import KFold + from sklearn.grid_search import GridSearchCV + from . import utils as eutils log = logging.getLogger(name=__name__) @@ -25,13 +30,13 @@ def evaluate_model(model, X, y, folds, scoring): """ log.info('Evaluating {}'.format(model.name)) if hasattr(y[0], '__getitem__'): - cv = sklearn.model_selection.KFold(folds, shuffle=True) + cv = KFold(folds, shuffle=True) if '_' not in scoring: log.warning("No averaging method specified, assuming macro") scoring += '_macro' else: - cv = sklearn.model_selection.KFold(folds, shuffle=True) - gs = sklearn.model_selection.GridSearchCV( + cv = KFold(folds, shuffle=True) + gs = GridSearchCV( estimator=model.model, param_grid=model.parameters, cv=cv,