Migrate pylint check to Python 3 (#4381)
* Migrate lint to Python 3
* Fix lint errors
* Use Miniconda3 to use Python 3.7
* Use latest pylint and astroid
This commit is contained in:
parent
5e97de6a41
commit
bbe0dbd7ec
4
Makefile
4
Makefile
@ -173,10 +173,10 @@ xgboost: $(CLI_OBJ) $(ALL_DEP)
|
|||||||
$(CXX) $(CFLAGS) -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
|
$(CXX) $(CFLAGS) -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
|
||||||
|
|
||||||
rcpplint:
|
rcpplint:
|
||||||
python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} R-package/src
|
python3 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} R-package/src
|
||||||
|
|
||||||
lint: rcpplint
|
lint: rcpplint
|
||||||
python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} include src plugin python-package
|
python3 dmlc-core/scripts/lint.py --pylint-rc ${PWD}/python-package/.pylintrc xgboost ${LINT_LANG} include src plugin python-package
|
||||||
|
|
||||||
pylint:
|
pylint:
|
||||||
flake8 --ignore E501 python-package
|
flake8 --ignore E501 python-package
|
||||||
|
|||||||
@ -4,7 +4,7 @@ ignore=tests
|
|||||||
|
|
||||||
extension-pkg-whitelist=numpy
|
extension-pkg-whitelist=numpy
|
||||||
|
|
||||||
disiable=unexpected-special-method-signature,too-many-nested-blocks
|
disable=unexpected-special-method-signature,too-many-nested-blocks,useless-object-inheritance
|
||||||
|
|
||||||
dummy-variables-rgx=(unused|)_.*
|
dummy-variables-rgx=(unused|)_.*
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
# pylint: disable= invalid-name
|
# pylint: disable=invalid-name, too-many-statements
|
||||||
"""Training Library containing training routines."""
|
"""Training Library containing training routines."""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
@ -20,13 +20,11 @@ def _fmt_metric(value, show_stdv=True):
|
|||||||
"""format metric string"""
|
"""format metric string"""
|
||||||
if len(value) == 2:
|
if len(value) == 2:
|
||||||
return '%s:%g' % (value[0], value[1])
|
return '%s:%g' % (value[0], value[1])
|
||||||
elif len(value) == 3:
|
if len(value) == 3:
|
||||||
if show_stdv:
|
if show_stdv:
|
||||||
return '%s:%g+%g' % (value[0], value[1], value[2])
|
return '%s:%g+%g' % (value[0], value[1], value[2])
|
||||||
else:
|
return '%s:%g' % (value[0], value[1])
|
||||||
return '%s:%g' % (value[0], value[1])
|
raise ValueError("wrong metric value")
|
||||||
else:
|
|
||||||
raise ValueError("wrong metric value")
|
|
||||||
|
|
||||||
|
|
||||||
def print_evaluation(period=1, show_stdv=True):
|
def print_evaluation(period=1, show_stdv=True):
|
||||||
@ -50,10 +48,10 @@ def print_evaluation(period=1, show_stdv=True):
|
|||||||
"""
|
"""
|
||||||
def callback(env):
|
def callback(env):
|
||||||
"""internal function"""
|
"""internal function"""
|
||||||
if env.rank != 0 or len(env.evaluation_result_list) == 0 or period is False or period == 0:
|
if env.rank != 0 or (not env.evaluation_result_list) or period is False or period == 0:
|
||||||
return
|
return
|
||||||
i = env.iteration
|
i = env.iteration
|
||||||
if (i % period == 0 or i + 1 == env.begin_iteration or i + 1 == env.end_iteration):
|
if i % period == 0 or i + 1 == env.begin_iteration or i + 1 == env.end_iteration:
|
||||||
msg = '\t'.join([_fmt_metric(x, show_stdv) for x in env.evaluation_result_list])
|
msg = '\t'.join([_fmt_metric(x, show_stdv) for x in env.evaluation_result_list])
|
||||||
rabit.tracker_print('[%d]\t%s\n' % (i, msg))
|
rabit.tracker_print('[%d]\t%s\n' % (i, msg))
|
||||||
return callback
|
return callback
|
||||||
@ -89,7 +87,7 @@ def record_evaluation(eval_result):
|
|||||||
|
|
||||||
def callback(env):
|
def callback(env):
|
||||||
"""internal function"""
|
"""internal function"""
|
||||||
if len(eval_result) == 0:
|
if not eval_result:
|
||||||
init(env)
|
init(env)
|
||||||
for k, v in env.evaluation_result_list:
|
for k, v in env.evaluation_result_list:
|
||||||
pos = k.index('-')
|
pos = k.index('-')
|
||||||
@ -182,14 +180,14 @@ def early_stop(stopping_rounds, maximize=False, verbose=True):
|
|||||||
"""internal function"""
|
"""internal function"""
|
||||||
bst = env.model
|
bst = env.model
|
||||||
|
|
||||||
if len(env.evaluation_result_list) == 0:
|
if not env.evaluation_result_list:
|
||||||
raise ValueError('For early stopping you need at least one set in evals.')
|
raise ValueError('For early stopping you need at least one set in evals.')
|
||||||
if len(env.evaluation_result_list) > 1 and verbose:
|
if len(env.evaluation_result_list) > 1 and verbose:
|
||||||
msg = ("Multiple eval metrics have been passed: "
|
msg = ("Multiple eval metrics have been passed: "
|
||||||
"'{0}' will be used for early stopping.\n\n")
|
"'{0}' will be used for early stopping.\n\n")
|
||||||
rabit.tracker_print(msg.format(env.evaluation_result_list[-1][0]))
|
rabit.tracker_print(msg.format(env.evaluation_result_list[-1][0]))
|
||||||
maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg')
|
maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg')
|
||||||
maximize_at_n_metrics = ('auc@', 'aucpr@' 'map@', 'ndcg@')
|
maximize_at_n_metrics = ('auc@', 'aucpr@', 'map@', 'ndcg@')
|
||||||
maximize_score = maximize
|
maximize_score = maximize
|
||||||
metric_label = env.evaluation_result_list[-1][0]
|
metric_label = env.evaluation_result_list[-1][0]
|
||||||
metric = metric_label.split('-', 1)[-1]
|
metric = metric_label.split('-', 1)[-1]
|
||||||
@ -225,7 +223,7 @@ def early_stop(stopping_rounds, maximize=False, verbose=True):
|
|||||||
def callback(env):
|
def callback(env):
|
||||||
"""internal function"""
|
"""internal function"""
|
||||||
score = env.evaluation_result_list[-1][1]
|
score = env.evaluation_result_list[-1][1]
|
||||||
if len(state) == 0:
|
if not state:
|
||||||
init(env)
|
init(env)
|
||||||
best_score = state['best_score']
|
best_score = state['best_score']
|
||||||
best_iteration = state['best_iteration']
|
best_iteration = state['best_iteration']
|
||||||
|
|||||||
@ -11,14 +11,13 @@ PY3 = (sys.version_info[0] == 3)
|
|||||||
|
|
||||||
if PY3:
|
if PY3:
|
||||||
# pylint: disable=invalid-name, redefined-builtin
|
# pylint: disable=invalid-name, redefined-builtin
|
||||||
STRING_TYPES = str,
|
STRING_TYPES = (str,)
|
||||||
|
|
||||||
def py_str(x):
|
def py_str(x):
|
||||||
"""convert c string back to python string"""
|
"""convert c string back to python string"""
|
||||||
return x.decode('utf-8')
|
return x.decode('utf-8')
|
||||||
else:
|
else:
|
||||||
# pylint: disable=invalid-name
|
STRING_TYPES = (basestring,) # pylint: disable=undefined-variable
|
||||||
STRING_TYPES = basestring,
|
|
||||||
|
|
||||||
def py_str(x):
|
def py_str(x):
|
||||||
"""convert c string back to python string"""
|
"""convert c string back to python string"""
|
||||||
@ -37,13 +36,13 @@ try:
|
|||||||
PANDAS_INSTALLED = True
|
PANDAS_INSTALLED = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
||||||
|
# pylint: disable=too-few-public-methods
|
||||||
class MultiIndex(object):
|
class MultiIndex(object):
|
||||||
""" dummy for pandas.MultiIndex """
|
""" dummy for pandas.MultiIndex """
|
||||||
pass
|
|
||||||
|
|
||||||
|
# pylint: disable=too-few-public-methods
|
||||||
class DataFrame(object):
|
class DataFrame(object):
|
||||||
""" dummy for pandas.DataFrame """
|
""" dummy for pandas.DataFrame """
|
||||||
pass
|
|
||||||
|
|
||||||
PANDAS_INSTALLED = False
|
PANDAS_INSTALLED = False
|
||||||
|
|
||||||
@ -57,9 +56,9 @@ try:
|
|||||||
DT_INSTALLED = True
|
DT_INSTALLED = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
||||||
|
# pylint: disable=too-few-public-methods
|
||||||
class DataTable(object):
|
class DataTable(object):
|
||||||
""" dummy for datatable.DataTable """
|
""" dummy for datatable.DataTable """
|
||||||
pass
|
|
||||||
|
|
||||||
DT_INSTALLED = False
|
DT_INSTALLED = False
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
|
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
|
||||||
# pylint: disable=too-many-branches, too-many-lines, W0141
|
# pylint: disable=too-many-branches, too-many-lines, too-many-locals
|
||||||
"""Core XGBoost Library."""
|
"""Core XGBoost Library."""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
import collections
|
import collections
|
||||||
@ -30,7 +30,6 @@ c_bst_ulong = ctypes.c_uint64
|
|||||||
|
|
||||||
class XGBoostError(Exception):
|
class XGBoostError(Exception):
|
||||||
"""Error thrown by xgboost trainer."""
|
"""Error thrown by xgboost trainer."""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class EarlyStopException(Exception):
|
class EarlyStopException(Exception):
|
||||||
@ -67,18 +66,16 @@ def from_pystr_to_cstr(data):
|
|||||||
list of str
|
list of str
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(data, list):
|
if not isinstance(data, list):
|
||||||
pointers = (ctypes.c_char_p * len(data))()
|
|
||||||
if PY3:
|
|
||||||
data = [bytes(d, 'utf-8') for d in data]
|
|
||||||
else:
|
|
||||||
data = [d.encode('utf-8') if isinstance(d, unicode) else d
|
|
||||||
for d in data]
|
|
||||||
pointers[:] = data
|
|
||||||
return pointers
|
|
||||||
else:
|
|
||||||
# copy from above when we actually use it
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
pointers = (ctypes.c_char_p * len(data))()
|
||||||
|
if PY3:
|
||||||
|
data = [bytes(d, 'utf-8') for d in data]
|
||||||
|
else:
|
||||||
|
data = [d.encode('utf-8') if isinstance(d, unicode) else d # pylint: disable=undefined-variable
|
||||||
|
for d in data]
|
||||||
|
pointers[:] = data
|
||||||
|
return pointers
|
||||||
|
|
||||||
|
|
||||||
def from_cstr_to_pystr(data, length):
|
def from_cstr_to_pystr(data, length):
|
||||||
@ -104,6 +101,7 @@ def from_cstr_to_pystr(data, length):
|
|||||||
try:
|
try:
|
||||||
res.append(str(data[i].decode('ascii')))
|
res.append(str(data[i].decode('ascii')))
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
|
# pylint: disable=undefined-variable
|
||||||
res.append(unicode(data[i].decode('utf-8')))
|
res.append(unicode(data[i].decode('utf-8')))
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@ -123,7 +121,7 @@ def _get_log_callback_func():
|
|||||||
def _load_lib():
|
def _load_lib():
|
||||||
"""Load xgboost Library."""
|
"""Load xgboost Library."""
|
||||||
lib_paths = find_lib_path()
|
lib_paths = find_lib_path()
|
||||||
if len(lib_paths) == 0:
|
if not lib_paths:
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
pathBackup = os.environ['PATH'].split(os.pathsep)
|
pathBackup = os.environ['PATH'].split(os.pathsep)
|
||||||
@ -243,7 +241,7 @@ def _maybe_pandas_data(data, feature_names, feature_types):
|
|||||||
if feature_names is None:
|
if feature_names is None:
|
||||||
if isinstance(data.columns, MultiIndex):
|
if isinstance(data.columns, MultiIndex):
|
||||||
feature_names = [
|
feature_names = [
|
||||||
' '.join(map(str, i))
|
' '.join([str(x) for x in i])
|
||||||
for i in data.columns
|
for i in data.columns
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
@ -267,8 +265,7 @@ def _maybe_pandas_label(label):
|
|||||||
label_dtypes = label.dtypes
|
label_dtypes = label.dtypes
|
||||||
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in label_dtypes):
|
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in label_dtypes):
|
||||||
raise ValueError('DataFrame.dtypes for label must be int, float or bool')
|
raise ValueError('DataFrame.dtypes for label must be int, float or bool')
|
||||||
else:
|
label = label.values.astype('float')
|
||||||
label = label.values.astype('float')
|
|
||||||
# pd.Series can be passed to xgb as it is
|
# pd.Series can be passed to xgb as it is
|
||||||
|
|
||||||
return label
|
return label
|
||||||
@ -301,8 +298,7 @@ def _maybe_dt_data(data, feature_names, feature_types):
|
|||||||
# always return stypes for dt ingestion
|
# always return stypes for dt ingestion
|
||||||
if feature_types is not None:
|
if feature_types is not None:
|
||||||
raise ValueError('DataTable has own feature types, cannot pass them in')
|
raise ValueError('DataTable has own feature types, cannot pass them in')
|
||||||
else:
|
feature_types = np.vectorize(DT_TYPE_MAPPER2.get)(data_types_names)
|
||||||
feature_types = np.vectorize(DT_TYPE_MAPPER2.get)(data_types_names)
|
|
||||||
|
|
||||||
return data, feature_names, feature_types
|
return data, feature_names, feature_types
|
||||||
|
|
||||||
@ -512,7 +508,7 @@ class DMatrix(object):
|
|||||||
ptrs[icol] = ctypes.c_void_p(ptr)
|
ptrs[icol] = ctypes.c_void_p(ptr)
|
||||||
else:
|
else:
|
||||||
# datatable<=0.8.0
|
# datatable<=0.8.0
|
||||||
from datatable.internal import frame_column_data_r
|
from datatable.internal import frame_column_data_r # pylint: disable=no-name-in-module,import-error
|
||||||
for icol in range(data.ncols):
|
for icol in range(data.ncols):
|
||||||
ptrs[icol] = frame_column_data_r(data, icol)
|
ptrs[icol] = frame_column_data_r(data, icol)
|
||||||
|
|
||||||
@ -1039,8 +1035,7 @@ class Booster(object):
|
|||||||
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
|
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
|
||||||
if success.value != 0:
|
if success.value != 0:
|
||||||
return py_str(ret.value)
|
return py_str(ret.value)
|
||||||
else:
|
return None
|
||||||
return None
|
|
||||||
|
|
||||||
def attributes(self):
|
def attributes(self):
|
||||||
"""Get attributes stored in the Booster as a dictionary.
|
"""Get attributes stored in the Booster as a dictionary.
|
||||||
@ -1056,8 +1051,7 @@ class Booster(object):
|
|||||||
ctypes.byref(length),
|
ctypes.byref(length),
|
||||||
ctypes.byref(sarr)))
|
ctypes.byref(sarr)))
|
||||||
attr_names = from_cstr_to_pystr(sarr, length)
|
attr_names = from_cstr_to_pystr(sarr, length)
|
||||||
res = dict([(n, self.attr(n)) for n in attr_names])
|
return {n: self.attr(n) for n in attr_names}
|
||||||
return res
|
|
||||||
|
|
||||||
def set_attr(self, **kwargs):
|
def set_attr(self, **kwargs):
|
||||||
"""Set the attribute of the Booster.
|
"""Set the attribute of the Booster.
|
||||||
@ -1399,13 +1393,13 @@ class Booster(object):
|
|||||||
ret = self.get_dump(fmap, with_stats, dump_format)
|
ret = self.get_dump(fmap, with_stats, dump_format)
|
||||||
if dump_format == 'json':
|
if dump_format == 'json':
|
||||||
fout.write('[\n')
|
fout.write('[\n')
|
||||||
for i in range(len(ret)):
|
for i, _ in enumerate(ret):
|
||||||
fout.write(ret[i])
|
fout.write(ret[i])
|
||||||
if i < len(ret) - 1:
|
if i < len(ret) - 1:
|
||||||
fout.write(",\n")
|
fout.write(",\n")
|
||||||
fout.write('\n]')
|
fout.write('\n]')
|
||||||
else:
|
else:
|
||||||
for i in range(len(ret)):
|
for i, _ in enumerate(ret):
|
||||||
fout.write('booster[{}]:\n'.format(i))
|
fout.write('booster[{}]:\n'.format(i))
|
||||||
fout.write(ret[i])
|
fout.write(ret[i])
|
||||||
if need_close:
|
if need_close:
|
||||||
@ -1538,51 +1532,50 @@ class Booster(object):
|
|||||||
|
|
||||||
return fmap
|
return fmap
|
||||||
|
|
||||||
else:
|
average_over_splits = True
|
||||||
average_over_splits = True
|
if importance_type == 'total_gain':
|
||||||
if importance_type == 'total_gain':
|
importance_type = 'gain'
|
||||||
importance_type = 'gain'
|
average_over_splits = False
|
||||||
average_over_splits = False
|
elif importance_type == 'total_cover':
|
||||||
elif importance_type == 'total_cover':
|
importance_type = 'cover'
|
||||||
importance_type = 'cover'
|
average_over_splits = False
|
||||||
average_over_splits = False
|
|
||||||
|
|
||||||
trees = self.get_dump(fmap, with_stats=True)
|
trees = self.get_dump(fmap, with_stats=True)
|
||||||
|
|
||||||
importance_type += '='
|
importance_type += '='
|
||||||
fmap = {}
|
fmap = {}
|
||||||
gmap = {}
|
gmap = {}
|
||||||
for tree in trees:
|
for tree in trees:
|
||||||
for line in tree.split('\n'):
|
for line in tree.split('\n'):
|
||||||
# look for the opening square bracket
|
# look for the opening square bracket
|
||||||
arr = line.split('[')
|
arr = line.split('[')
|
||||||
# if no opening bracket (leaf node), ignore this line
|
# if no opening bracket (leaf node), ignore this line
|
||||||
if len(arr) == 1:
|
if len(arr) == 1:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# look for the closing bracket, extract only info within that bracket
|
# look for the closing bracket, extract only info within that bracket
|
||||||
fid = arr[1].split(']')
|
fid = arr[1].split(']')
|
||||||
|
|
||||||
# extract gain or cover from string after closing bracket
|
# extract gain or cover from string after closing bracket
|
||||||
g = float(fid[1].split(importance_type)[1].split(',')[0])
|
g = float(fid[1].split(importance_type)[1].split(',')[0])
|
||||||
|
|
||||||
# extract feature name from string before closing bracket
|
# extract feature name from string before closing bracket
|
||||||
fid = fid[0].split('<')[0]
|
fid = fid[0].split('<')[0]
|
||||||
|
|
||||||
if fid not in fmap:
|
if fid not in fmap:
|
||||||
# if the feature hasn't been seen yet
|
# if the feature hasn't been seen yet
|
||||||
fmap[fid] = 1
|
fmap[fid] = 1
|
||||||
gmap[fid] = g
|
gmap[fid] = g
|
||||||
else:
|
else:
|
||||||
fmap[fid] += 1
|
fmap[fid] += 1
|
||||||
gmap[fid] += g
|
gmap[fid] += g
|
||||||
|
|
||||||
# calculate average value (gain/cover) for each feature
|
# calculate average value (gain/cover) for each feature
|
||||||
if average_over_splits:
|
if average_over_splits:
|
||||||
for fid in gmap:
|
for fid in gmap:
|
||||||
gmap[fid] = gmap[fid] / fmap[fid]
|
gmap[fid] = gmap[fid] / fmap[fid]
|
||||||
|
|
||||||
return gmap
|
return gmap
|
||||||
|
|
||||||
def trees_to_dataframe(self, fmap=''):
|
def trees_to_dataframe(self, fmap=''):
|
||||||
"""Parse a boosted tree model text dump into a pandas DataFrame structure.
|
"""Parse a boosted tree model text dump into a pandas DataFrame structure.
|
||||||
@ -1721,9 +1714,9 @@ class Booster(object):
|
|||||||
xgdump = self.get_dump(fmap=fmap)
|
xgdump = self.get_dump(fmap=fmap)
|
||||||
values = []
|
values = []
|
||||||
regexp = re.compile(r"\[{0}<([\d.Ee+-]+)\]".format(feature))
|
regexp = re.compile(r"\[{0}<([\d.Ee+-]+)\]".format(feature))
|
||||||
for i in range(len(xgdump)):
|
for i, _ in enumerate(xgdump):
|
||||||
m = re.findall(regexp, xgdump[i])
|
m = re.findall(regexp, xgdump[i])
|
||||||
values.extend(map(float, m))
|
values.extend([float(x) for x in m])
|
||||||
|
|
||||||
n_unique = len(np.unique(values))
|
n_unique = len(np.unique(values))
|
||||||
bins = max(min(n_unique, bins) if bins is not None else n_unique, 1)
|
bins = max(min(n_unique, bins) if bins is not None else n_unique, 1)
|
||||||
@ -1734,9 +1727,7 @@ class Booster(object):
|
|||||||
|
|
||||||
if as_pandas and PANDAS_INSTALLED:
|
if as_pandas and PANDAS_INSTALLED:
|
||||||
return DataFrame(nph, columns=['SplitValue', 'Count'])
|
return DataFrame(nph, columns=['SplitValue', 'Count'])
|
||||||
elif as_pandas and not PANDAS_INSTALLED:
|
if as_pandas and not PANDAS_INSTALLED:
|
||||||
sys.stderr.write(
|
sys.stderr.write(
|
||||||
"Returning histogram as ndarray (as_pandas == True, but pandas is not installed).")
|
"Returning histogram as ndarray (as_pandas == True, but pandas is not installed).")
|
||||||
return nph
|
return nph
|
||||||
else:
|
|
||||||
return nph
|
|
||||||
|
|||||||
@ -8,7 +8,6 @@ import sys
|
|||||||
|
|
||||||
class XGBoostLibraryNotFound(Exception):
|
class XGBoostLibraryNotFound(Exception):
|
||||||
"""Error thrown by when xgboost is not found"""
|
"""Error thrown by when xgboost is not found"""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def find_lib_path():
|
def find_lib_path():
|
||||||
|
|||||||
@ -55,7 +55,6 @@ def plot_importance(booster, ax=None, height=0.2,
|
|||||||
-------
|
-------
|
||||||
ax : matplotlib Axes
|
ax : matplotlib Axes
|
||||||
"""
|
"""
|
||||||
# TODO: move this to compat.py
|
|
||||||
try:
|
try:
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -70,11 +69,12 @@ def plot_importance(booster, ax=None, height=0.2,
|
|||||||
else:
|
else:
|
||||||
raise ValueError('tree must be Booster, XGBModel or dict instance')
|
raise ValueError('tree must be Booster, XGBModel or dict instance')
|
||||||
|
|
||||||
if len(importance) == 0:
|
if not importance:
|
||||||
raise ValueError('Booster.get_score() results in empty')
|
raise ValueError('Booster.get_score() results in empty')
|
||||||
|
|
||||||
tuples = [(k, importance[k]) for k in importance]
|
tuples = [(k, importance[k]) for k in importance]
|
||||||
if max_num_features is not None:
|
if max_num_features is not None:
|
||||||
|
# pylint: disable=invalid-unary-operand-type
|
||||||
tuples = sorted(tuples, key=lambda x: x[1])[-max_num_features:]
|
tuples = sorted(tuples, key=lambda x: x[1])[-max_num_features:]
|
||||||
else:
|
else:
|
||||||
tuples = sorted(tuples, key=lambda x: x[1])
|
tuples = sorted(tuples, key=lambda x: x[1])
|
||||||
|
|||||||
@ -3,9 +3,9 @@
|
|||||||
"""Scikit-Learn Wrapper interface for XGBoost."""
|
"""Scikit-Learn Wrapper interface for XGBoost."""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import warnings
|
import warnings
|
||||||
import json
|
import json
|
||||||
|
import numpy as np
|
||||||
from .core import Booster, DMatrix, XGBoostError
|
from .core import Booster, DMatrix, XGBoostError
|
||||||
from .training import train
|
from .training import train
|
||||||
|
|
||||||
@ -107,15 +107,15 @@ class XGBModel(XGBModelBase):
|
|||||||
importance_type: string, default "gain"
|
importance_type: string, default "gain"
|
||||||
The feature importance type for the feature_importances_ property: either "gain",
|
The feature importance type for the feature_importances_ property: either "gain",
|
||||||
"weight", "cover", "total_gain" or "total_cover".
|
"weight", "cover", "total_gain" or "total_cover".
|
||||||
\*\*kwargs : dict, optional
|
\\*\\*kwargs : dict, optional
|
||||||
Keyword arguments for XGBoost Booster object. Full documentation of parameters can
|
Keyword arguments for XGBoost Booster object. Full documentation of parameters can
|
||||||
be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
|
be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
|
||||||
Attempting to set a parameter via the constructor args and \*\*kwargs dict simultaneously
|
Attempting to set a parameter via the constructor args and \\*\\*kwargs dict simultaneously
|
||||||
will result in a TypeError.
|
will result in a TypeError.
|
||||||
|
|
||||||
.. note:: \*\*kwargs unsupported by scikit-learn
|
.. note:: \\*\\*kwargs unsupported by scikit-learn
|
||||||
|
|
||||||
\*\*kwargs is unsupported by scikit-learn. We do not guarantee that parameters
|
\\*\\*kwargs is unsupported by scikit-learn. We do not guarantee that parameters
|
||||||
passed via this argument will interact properly with scikit-learn.
|
passed via this argument will interact properly with scikit-learn.
|
||||||
|
|
||||||
Note
|
Note
|
||||||
@ -597,7 +597,7 @@ class XGBModel(XGBModelBase):
|
|||||||
|
|
||||||
|
|
||||||
class XGBClassifier(XGBModel, XGBClassifierBase):
|
class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||||
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
|
# pylint: disable=missing-docstring,too-many-arguments,invalid-name,too-many-instance-attributes
|
||||||
__doc__ = "Implementation of the scikit-learn API for XGBoost classification.\n\n" \
|
__doc__ = "Implementation of the scikit-learn API for XGBoost classification.\n\n" \
|
||||||
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
|
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
|
||||||
|
|
||||||
@ -834,10 +834,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
validate_features=validate_features)
|
validate_features=validate_features)
|
||||||
if self.objective == "multi:softprob":
|
if self.objective == "multi:softprob":
|
||||||
return class_probs
|
return class_probs
|
||||||
else:
|
classone_probs = class_probs
|
||||||
classone_probs = class_probs
|
classzero_probs = 1.0 - classone_probs
|
||||||
classzero_probs = 1.0 - classone_probs
|
return np.vstack((classzero_probs, classone_probs)).transpose()
|
||||||
return np.vstack((classzero_probs, classone_probs)).transpose()
|
|
||||||
|
|
||||||
def evals_result(self):
|
def evals_result(self):
|
||||||
"""Return the evaluation results.
|
"""Return the evaluation results.
|
||||||
@ -1008,15 +1007,15 @@ class XGBRanker(XGBModel):
|
|||||||
missing : float, optional
|
missing : float, optional
|
||||||
Value in the data which needs to be present as a missing value. If
|
Value in the data which needs to be present as a missing value. If
|
||||||
None, defaults to np.nan.
|
None, defaults to np.nan.
|
||||||
\*\*kwargs : dict, optional
|
\\*\\*kwargs : dict, optional
|
||||||
Keyword arguments for XGBoost Booster object. Full documentation of parameters can
|
Keyword arguments for XGBoost Booster object. Full documentation of parameters can
|
||||||
be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
|
be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
|
||||||
Attempting to set a parameter via the constructor args and \*\*kwargs dict
|
Attempting to set a parameter via the constructor args and \\*\\*kwargs dict
|
||||||
simultaneously will result in a TypeError.
|
simultaneously will result in a TypeError.
|
||||||
|
|
||||||
.. note:: \*\*kwargs unsupported by scikit-learn
|
.. note:: \\*\\*kwargs unsupported by scikit-learn
|
||||||
|
|
||||||
\*\*kwargs is unsupported by scikit-learn. We do not guarantee that parameters
|
\\*\\*kwargs is unsupported by scikit-learn. We do not guarantee that parameters
|
||||||
passed via this argument will interact properly with scikit-learn.
|
passed via this argument will interact properly with scikit-learn.
|
||||||
|
|
||||||
Note
|
Note
|
||||||
@ -1073,7 +1072,7 @@ class XGBRanker(XGBModel):
|
|||||||
random_state=random_state, seed=seed, missing=missing, **kwargs)
|
random_state=random_state, seed=seed, missing=missing, **kwargs)
|
||||||
if callable(self.objective):
|
if callable(self.objective):
|
||||||
raise ValueError("custom objective function not supported by XGBRanker")
|
raise ValueError("custom objective function not supported by XGBRanker")
|
||||||
elif "rank:" not in self.objective:
|
if "rank:" not in self.objective:
|
||||||
raise ValueError("please use XGBRanker for ranking task")
|
raise ValueError("please use XGBRanker for ranking task")
|
||||||
|
|
||||||
def fit(self, X, y, group, sample_weight=None, eval_set=None, sample_weight_eval_set=None,
|
def fit(self, X, y, group, sample_weight=None, eval_set=None, sample_weight_eval_set=None,
|
||||||
@ -1158,9 +1157,9 @@ class XGBRanker(XGBModel):
|
|||||||
if eval_set is not None:
|
if eval_set is not None:
|
||||||
if eval_group is None:
|
if eval_group is None:
|
||||||
raise ValueError("eval_group is required if eval_set is not None")
|
raise ValueError("eval_group is required if eval_set is not None")
|
||||||
elif len(eval_group) != len(eval_set):
|
if len(eval_group) != len(eval_set):
|
||||||
raise ValueError("length of eval_group should match that of eval_set")
|
raise ValueError("length of eval_group should match that of eval_set")
|
||||||
elif any(group is None for group in eval_group):
|
if any(group is None for group in eval_group):
|
||||||
raise ValueError("group is required for all eval datasets for ranking task")
|
raise ValueError("group is required for all eval datasets for ranking task")
|
||||||
|
|
||||||
def _dmat_init(group, **params):
|
def _dmat_init(group, **params):
|
||||||
|
|||||||
@ -49,7 +49,7 @@ def _train_internal(params, dtrain,
|
|||||||
|
|
||||||
# Distributed code: Load the checkpoint from rabit.
|
# Distributed code: Load the checkpoint from rabit.
|
||||||
version = bst.load_rabit_checkpoint()
|
version = bst.load_rabit_checkpoint()
|
||||||
assert(rabit.get_world_size() != 1 or version == 0)
|
assert rabit.get_world_size() != 1 or version == 0
|
||||||
rank = rabit.get_rank()
|
rank = rabit.get_rank()
|
||||||
start_iteration = int(version / 2)
|
start_iteration = int(version / 2)
|
||||||
nboost += start_iteration
|
nboost += start_iteration
|
||||||
@ -75,12 +75,12 @@ def _train_internal(params, dtrain,
|
|||||||
bst.save_rabit_checkpoint()
|
bst.save_rabit_checkpoint()
|
||||||
version += 1
|
version += 1
|
||||||
|
|
||||||
assert(rabit.get_world_size() == 1 or version == rabit.version_number())
|
assert rabit.get_world_size() == 1 or version == rabit.version_number()
|
||||||
|
|
||||||
nboost += 1
|
nboost += 1
|
||||||
evaluation_result_list = []
|
evaluation_result_list = []
|
||||||
# check evaluation result.
|
# check evaluation result.
|
||||||
if len(evals) != 0:
|
if evals:
|
||||||
bst_eval_set = bst.eval_set(evals, i, feval)
|
bst_eval_set = bst.eval_set(evals, i, feval)
|
||||||
if isinstance(bst_eval_set, STRING_TYPES):
|
if isinstance(bst_eval_set, STRING_TYPES):
|
||||||
msg = bst_eval_set
|
msg = bst_eval_set
|
||||||
@ -402,7 +402,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
|||||||
else:
|
else:
|
||||||
params = dict((k, v) for k, v in params.items())
|
params = dict((k, v) for k, v in params.items())
|
||||||
|
|
||||||
if len(metrics) == 0 and 'eval_metric' in params:
|
if (not metrics) and 'eval_metric' in params:
|
||||||
if isinstance(params['eval_metric'], list):
|
if isinstance(params['eval_metric'], list):
|
||||||
metrics = params['eval_metric']
|
metrics = params['eval_metric']
|
||||||
else:
|
else:
|
||||||
@ -462,7 +462,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
|||||||
rank=0,
|
rank=0,
|
||||||
evaluation_result_list=res))
|
evaluation_result_list=res))
|
||||||
except EarlyStopException as e:
|
except EarlyStopException as e:
|
||||||
for k in results.keys():
|
for k in results:
|
||||||
results[k] = results[k][:(e.best_iteration + 1)]
|
results[k] = results[k][:(e.best_iteration + 1)]
|
||||||
break
|
break
|
||||||
if as_pandas:
|
if as_pandas:
|
||||||
|
|||||||
@ -1,6 +1,9 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
if [ ${TASK} == "lint" ]; then
|
if [ ${TASK} == "lint" ]; then
|
||||||
|
source activate python3
|
||||||
|
conda install numpy scipy
|
||||||
|
python -m pip install cpplint pylint astroid
|
||||||
make lint || exit -1
|
make lint || exit -1
|
||||||
echo "Check documentations..."
|
echo "Check documentations..."
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,19 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
if [ ${TASK} == "lint" ]; then
|
if [ ${TASK} == "lint" ]; then
|
||||||
pip install --user cpplint 'pylint==1.4.4' 'astroid==1.3.6'
|
if [ ${TRAVIS_OS_NAME} == "osx" ]; then
|
||||||
|
wget -O conda.sh https://repo.continuum.io/miniconda/Miniconda3-latest-MacOSX-x86_64.sh
|
||||||
|
else
|
||||||
|
wget -O conda.sh https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
|
||||||
|
fi
|
||||||
|
bash conda.sh -b -p $HOME/miniconda
|
||||||
|
export PATH="$HOME/miniconda/bin:$PATH"
|
||||||
|
hash -r
|
||||||
|
conda config --set always_yes yes --set changeps1 no
|
||||||
|
conda update -q conda
|
||||||
|
# Useful for debugging any issues with conda
|
||||||
|
conda info -a
|
||||||
|
conda create -n python3 python=3.7
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
@ -18,6 +30,6 @@ if [ ${TASK} == "python_test" ] || [ ${TASK} == "python_lightweight_test" ] || [
|
|||||||
conda update -q conda
|
conda update -q conda
|
||||||
# Useful for debugging any issues with conda
|
# Useful for debugging any issues with conda
|
||||||
conda info -a
|
conda info -a
|
||||||
conda create -n python3 python=3.5
|
conda create -n python3 python=3.7
|
||||||
conda create -n python2 python=2.7
|
conda create -n python2 python=2.7
|
||||||
fi
|
fi
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user