DOC/TST: Fix Python sklearn dep
This commit is contained in:
@@ -48,11 +48,13 @@ try:
|
||||
from sklearn.cross_validation import KFold, StratifiedKFold
|
||||
SKLEARN_INSTALLED = True
|
||||
|
||||
XGBKFold = KFold
|
||||
XGBStratifiedKFold = StratifiedKFold
|
||||
XGBModelBase = BaseEstimator
|
||||
XGBRegressorBase = RegressorMixin
|
||||
XGBClassifierBase = ClassifierMixin
|
||||
|
||||
XGBKFold = KFold
|
||||
XGBStratifiedKFold = StratifiedKFold
|
||||
XGBLabelEncoder = LabelEncoder
|
||||
except ImportError:
|
||||
SKLEARN_INSTALLED = False
|
||||
|
||||
@@ -60,5 +62,7 @@ except ImportError:
|
||||
XGBModelBase = object
|
||||
XGBClassifierBase = object
|
||||
XGBRegressorBase = object
|
||||
|
||||
XGBKFold = None
|
||||
XGBStratifiedKFold = None
|
||||
XGBLabelEncoder = None
|
||||
|
||||
@@ -7,8 +7,10 @@ import numpy as np
|
||||
from .core import Booster, DMatrix, XGBoostError
|
||||
from .training import train
|
||||
|
||||
# Do not use class names on scikit-learn directly.
|
||||
# Re-define the classes on .compat to guarantee the behavior without scikit-learn
|
||||
from .compat import (SKLEARN_INSTALLED, XGBModelBase,
|
||||
XGBClassifierBase, XGBRegressorBase, LabelEncoder)
|
||||
XGBClassifierBase, XGBRegressorBase, XGBLabelEncoder)
|
||||
|
||||
|
||||
def _objective_decorator(func):
|
||||
@@ -398,7 +400,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
|
||||
self._features_count = X.shape[1]
|
||||
|
||||
self._le = LabelEncoder().fit(y)
|
||||
self._le = XGBLabelEncoder().fit(y)
|
||||
training_labels = self._le.transform(y)
|
||||
|
||||
if sample_weight is not None:
|
||||
|
||||
22
python-package/xgboost/testing.py
Normal file
22
python-package/xgboost/testing.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# coding: utf-8
|
||||
|
||||
import nose
|
||||
|
||||
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
|
||||
|
||||
|
||||
def _skip_if_no_sklearn():
|
||||
if not SKLEARN_INSTALLED:
|
||||
raise nose.SkipTest()
|
||||
|
||||
|
||||
def _skip_if_no_pandas():
|
||||
if not PANDAS_INSTALLED:
|
||||
raise nose.SkipTest()
|
||||
|
||||
|
||||
def _skip_if_no_matplotlib():
|
||||
try:
|
||||
import matplotlib.pyplot as plt # noqa
|
||||
except ImportError:
|
||||
raise nose.SkipTest()
|
||||
Reference in New Issue
Block a user