Merge pull request #488 from sinhrks/pyfeaturenames
Support feature names in Python package
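In short, DMatrix and Booster now carry an optional list of feature names, which is validated (unique, alphanumeric, one name per column) and used when dumping the model or computing feature importance. A minimal usage sketch, based on the test added below in this commit (the data and training parameters here are illustrative only, not part of the commit):

    # sketch of the API added by this PR; values below are made up
    import numpy as np
    import xgboost as xgb

    data = np.random.randn(100, 5)
    label = np.array([0, 1] * 50)

    # New: DMatrix accepts feature_names (unique alphanumeric strings, one per column)
    dm = xgb.DMatrix(data, label=label,
                     feature_names=['Feature1', 'Feature2', 'Feature3',
                                    'Feature4', 'Feature5'])

    bst = xgb.train({'objective': 'binary:logistic', 'eta': 0.3}, dm,
                    num_boost_round=10)

    # Feature importances and model dumps are now keyed by the given names
    print(bst.get_fscore())
    print(bst.get_dump()[0])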
commit
ae43fd7c7a
@@ -1,5 +1,5 @@
# coding: utf-8
# pylint: disable=too-many-arguments
# pylint: disable=too-many-arguments, too-many-branches
"""Core XGBoost Library."""
from __future__ import absolute_import

@@ -23,8 +23,9 @@ class XGBoostError(Exception):
if sys.version_info[0] == 3:
# pylint: disable=invalid-name
# pylint: disable=invalid-name, redefined-builtin
STRING_TYPES = str,
unicode = str
else:
# pylint: disable=invalid-name
STRING_TYPES = basestring,

@@ -131,7 +132,11 @@ class DMatrix(object):
which is optimized for both memory efficiency and training speed.
You can construct DMatrix from numpy.arrays
"""
def __init__(self, data, label=None, missing=0.0, weight=None, silent=False):
feature_names = None # for previous version's pickle
def __init__(self, data, label=None, missing=0.0,
weight=None, silent=False, feature_names=None):
"""
Data matrix used in XGBoost.

@@ -149,6 +154,8 @@ class DMatrix(object):
Weight for each instance.
silent : boolean, optional
Whether to print messages during construction
feature_names : list, optional
Labels for features.
"""
# force into void_p, mac need to pass things in as void_p
if data is None:

@@ -176,6 +183,21 @@ class DMatrix(object):
if weight is not None:
self.set_weight(weight)
# validate feature name
if not feature_names is None:
if not isinstance(feature_names, list):
feature_names = list(feature_names)
if len(feature_names) != len(set(feature_names)):
raise ValueError('feature_names must be unique')
if len(feature_names) != self.num_col():
msg = 'feature_names must have the same length as data'
raise ValueError(msg)
# prohibit the use of symbols that may affect parsing, e.g. ``[]=.``
if not all(isinstance(f, STRING_TYPES) and f.isalnum()
for f in feature_names):
raise ValueError('all feature_names must be alphanumerics')
self.feature_names = feature_names
def _init_from_csr(self, csr):
"""
Initialize data from a CSR matrix.

@@ -391,6 +413,18 @@ class DMatrix(object):
ctypes.byref(ret)))
return ret.value
def num_col(self):
"""Get the number of columns (features) in the DMatrix.
Returns
-------
number of columns : int
"""
ret = ctypes.c_uint()
_check_call(_LIB.XGDMatrixNumCol(self.handle,
ctypes.byref(ret)))
return ret.value
def slice(self, rindex):
"""Slice the DMatrix and return a new DMatrix that only contains `rindex`.

@@ -404,7 +438,7 @@ class DMatrix(object):
res : DMatrix
A new DMatrix containing only selected indices.
"""
res = DMatrix(None)
res = DMatrix(None, feature_names=self.feature_names)
res.handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixSliceDMatrix(self.handle,
c_array(ctypes.c_int, rindex),

@@ -419,6 +453,9 @@ class Booster(object):
Booster is the model of xgboost, that contains low level routines for
training, prediction and evaluation.
"""
feature_names = None
def __init__(self, params=None, cache=(), model_file=None):
# pylint: disable=invalid-name
"""Initialize the Booster.

@@ -435,6 +472,7 @@ class Booster(object):
for d in cache:
if not isinstance(d, DMatrix):
raise TypeError('invalid cache item: {}'.format(type(d).__name__))
self._validate_feature_names(d)
dmats = c_array(ctypes.c_void_p, [d.handle for d in cache])
self.handle = ctypes.c_void_p()
_check_call(_LIB.XGBoosterCreate(dmats, len(cache), ctypes.byref(self.handle)))

@@ -519,6 +557,8 @@ class Booster(object):
"""
if not isinstance(dtrain, DMatrix):
raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
self._validate_feature_names(dtrain)
if fobj is None:
_check_call(_LIB.XGBoosterUpdateOneIter(self.handle, iteration, dtrain.handle))
else:

@@ -543,6 +583,8 @@ class Booster(object):
raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess)))
if not isinstance(dtrain, DMatrix):
raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
self._validate_feature_names(dtrain)
_check_call(_LIB.XGBoosterBoostOneIter(self.handle, dtrain.handle,
c_array(ctypes.c_float, grad),
c_array(ctypes.c_float, hess),

@@ -572,6 +614,8 @@ class Booster(object):
raise TypeError('expected DMatrix, got {}'.format(type(d[0]).__name__))
if not isinstance(d[1], STRING_TYPES):
raise TypeError('expected string, got {}'.format(type(d[1]).__name__))
self._validate_feature_names(d[0])
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
msg = ctypes.c_char_p()

@@ -605,6 +649,7 @@ class Booster(object):
result: str
Evaluation result string.
"""
self._validate_feature_names(data)
return self.eval_set([(data, name)], iteration)
def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False):

@@ -642,6 +687,9 @@ class Booster(object):
option_mask |= 0x01
if pred_leaf:
option_mask |= 0x02
self._validate_feature_names(data)
length = ctypes.c_ulong()
preds = ctypes.POINTER(ctypes.c_float)()
_check_call(_LIB.XGBoosterPredict(self.handle, data.handle,

@@ -731,16 +779,46 @@ class Booster(object):
"""
Returns the model dump as a list of strings.
"""
|
||||
|
||||
length = ctypes.c_ulong()
|
||||
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
||||
_check_call(_LIB.XGBoosterDumpModel(self.handle,
|
||||
c_str(fmap),
|
||||
int(with_stats),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
if self.feature_names is not None and fmap == '':
|
||||
flen = int(len(self.feature_names))
|
||||
fname = (ctypes.c_char_p * flen)()
|
||||
ftype = (ctypes.c_char_p * flen)()
|
||||
|
||||
# supports quantitative type only
|
||||
# {'q': quantitative, 'i': indicator}
|
||||
if sys.version_info[0] == 3:
|
||||
features = [bytes(f, 'utf-8') for f in self.feature_names]
|
||||
types = [bytes('q', 'utf-8')] * flen
|
||||
else:
|
||||
features = [f.encode('utf-8') if isinstance(f, unicode) else f
|
||||
for f in self.feature_names]
|
||||
types = ['q'] * flen
|
||||
|
||||
fname[:] = features
|
||||
ftype[:] = types
|
||||
_check_call(_LIB.XGBoosterDumpModelWithFeatures(self.handle,
|
||||
flen,
|
||||
fname,
|
||||
ftype,
|
||||
int(with_stats),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
else:
|
||||
_check_call(_LIB.XGBoosterDumpModel(self.handle,
|
||||
c_str(fmap),
|
||||
int(with_stats),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
|
||||
res = []
|
||||
for i in range(length.value):
|
||||
res.append(str(sarr[i].decode('ascii')))
|
||||
try:
|
||||
res.append(str(sarr[i].decode('ascii')))
|
||||
except UnicodeDecodeError:
|
||||
res.append(unicode(sarr[i].decode('utf-8')))
|
||||
return res
|
||||
|
||||
def get_fscore(self, fmap=''):
|
||||
@ -765,3 +843,17 @@ class Booster(object):
|
||||
else:
|
||||
fmap[fid] += 1
|
||||
return fmap
|
||||
|
||||
def _validate_feature_names(self, data):
|
||||
"""
|
||||
Validate Booster and data's feature_names are identical
|
||||
"""
|
||||
if self.feature_names is None:
|
||||
self.feature_names = data.feature_names
|
||||
else:
|
||||
# Booster can't accept data with different feature names
|
||||
if self.feature_names != data.feature_names:
|
||||
msg = 'feature_names mismatch: {0} {1}'
|
||||
raise ValueError(msg.format(self.feature_names,
|
||||
data.feature_names))
|
||||
|
||||
|
||||
@ -5,15 +5,3 @@ if [ ${TRAVIS_OS_NAME} != "osx" ]; then
|
||||
fi
|
||||
|
||||
brew update
|
||||
|
||||
if [ ${TASK} == "python-package" ]; then
|
||||
brew install python git graphviz
|
||||
easy_install pip
|
||||
pip install numpy scipy matplotlib nose
|
||||
fi
|
||||
|
||||
if [ ${TASK} == "python-package3" ]; then
|
||||
brew install python3 git graphviz
|
||||
sudo pip3 install --upgrade setuptools
|
||||
pip3 install numpy scipy matplotlib nose graphviz
|
||||
fi
|
||||
|
||||
@ -33,30 +33,44 @@ if [ ${TASK} == "R-package" ]; then
|
||||
scripts/travis_R_script.sh || exit -1
|
||||
fi
|
||||
|
||||
if [ ${TASK} == "python-package" ]; then
|
||||
sudo apt-get install graphviz
|
||||
sudo apt-get install python-numpy python-scipy python-matplotlib python-nose
|
||||
sudo python -m pip install graphviz
|
||||
make all CXX=${CXX} || exit -1
|
||||
nosetests tests/python || exit -1
|
||||
fi
|
||||
if [ ${TASK} == "python-package" -o ${TASK} == "python-package3" ]; then
|
||||
|
||||
if [ ${TASK} == "python-package3" ]; then
|
||||
sudo apt-get install graphviz
|
||||
# python3-matplotlib is unavailable on Ubuntu 12.04
sudo apt-get install python3-dev
sudo apt-get install python3-numpy python3-scipy python3-nose python3-setuptools
make all CXX=${CXX} || exit -1
if [ ${TRAVIS_OS_NAME} != "osx" ]; then
sudo easy_install3 pip
sudo easy_install3 -U distribute
sudo pip install graphviz matplotlib
nosetests3 tests/python || exit -1
if [ ${TRAVIS_OS_NAME} == "osx" ]; then
brew install graphviz
if [ ${TASK} == "python-package3" ]; then
wget -O conda.sh https://repo.continuum.io/miniconda/Miniconda3-latest-MacOSX-x86_64.sh
else
wget -O conda.sh https://repo.continuum.io/miniconda/Miniconda-latest-MacOSX-x86_64.sh
fi
else
nosetests tests/python || exit -1
sudo apt-get install graphviz
if [ ${TASK} == "python-package3" ]; then
wget -O conda.sh https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
else
wget -O conda.sh https://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh
fi
fi
bash conda.sh -b -p $HOME/miniconda
export PATH="$HOME/miniconda/bin:$PATH"
hash -r
conda config --set always_yes yes --set changeps1 no
conda update -q conda
# Useful for debugging any issues with conda
conda info -a
if [ ${TASK} == "python-package3" ]; then
conda create -n myenv python=3.4
else
conda create -n myenv python=2.7
fi
source activate myenv
conda install numpy scipy matplotlib nose
python -m pip install graphviz
make all CXX=${CXX} || exit -1
python -m nose tests/python || exit -1
python --version
fi
# only test java under linux for now

@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import numpy as np
import xgboost as xgb

@@ -29,6 +30,30 @@ def test_basic():
# assert they are the same
assert np.sum(np.abs(preds2-preds)) == 0
def test_feature_names():
data = np.random.randn(100, 5)
target = np.array([0, 1] * 50)
cases = [['Feature1', 'Feature2', 'Feature3', 'Feature4', 'Feature5'],
[u'要因1', u'要因2', u'要因3', u'要因4', u'要因5']]
for features in cases:
dm = xgb.DMatrix(data, label=target,
feature_names=features)
assert dm.feature_names == features
assert dm.num_row() == 100
assert dm.num_col() == 5
params={'objective': 'multi:softprob',
'eval_metric': 'mlogloss',
'eta': 0.3,
'num_class': 3}
bst = xgb.train(params, dm, num_boost_round=10)
scores = bst.get_fscore()
assert list(sorted(k for k in scores)) == features
def test_plotting():
bst2 = xgb.Booster(model_file='xgb.model')
# plotting

@@ -435,6 +435,7 @@ int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
*out_dptr = BeginPtr(vec);
API_END();
}
int XGDMatrixNumRow(const DMatrixHandle handle,
bst_ulong *out) {
API_BEGIN();

@@ -442,6 +443,13 @@ int XGDMatrixNumRow(const DMatrixHandle handle,
API_END();
}
int XGDMatrixNumCol(const DMatrixHandle handle,
bst_ulong *out) {
API_BEGIN();
*out = static_cast<size_t>(static_cast<const DataMatrix*>(handle)->info.num_col());
API_END();
}
// xgboost implementation
int XGBoosterCreate(DMatrixHandle dmats[],
bst_ulong len,

@@ -572,3 +580,20 @@ int XGBoosterDumpModel(BoosterHandle handle,
featmap, with_stats != 0, len);
API_END();
}
int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
int fnum,
const char **fname,
const char **ftype,
int with_stats,
bst_ulong *len,
const char ***out_models) {
API_BEGIN();
utils::FeatMap featmap;
for (int i = 0; i < fnum; ++i) {
featmap.PushBack(i, fname[i], ftype[i]);
}
*out_models = static_cast<Booster*>(handle)->GetModelDump(
featmap, with_stats != 0, len);
API_END();
}

@@ -184,6 +184,13 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
*/
XGB_DLL int XGDMatrixNumRow(DMatrixHandle handle,
bst_ulong *out);
/*!
* \brief get number of columns
* \param handle the handle to the DMatrix
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle,
bst_ulong *out);
// --- start XGBoost class
/*!
* \brief create xgboost learner

@@ -324,4 +331,24 @@ XGB_DLL int XGBoosterDumpModel(BoosterHandle handle,
int with_stats,
bst_ulong *out_len,
const char ***out_dump_array);
/*!
* \brief dump model, return array of strings representing model dump
* \param handle handle
* \param fnum number of features
* \param fname names of features
* \param ftype types of features
* \param with_stats whether to dump with statistics
* \param len length of output array
* \param out_models pointer to hold representing dump of each model
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
int fnum,
const char **fname,
const char **ftype,
int with_stats,
bst_ulong *len,
const char ***out_models);

#endif // XGBOOST_WRAPPER_H_
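For completeness, a minimal sketch (not part of the commit) of driving the new C entry point XGBoosterDumpModelWithFeatures directly through ctypes, mirroring what Booster.get_dump() does above; `lib` (the loaded xgboost wrapper shared library) and `handle` (a valid BoosterHandle) are assumed, illustrative names:

    import ctypes

    def dump_with_feature_names(lib, handle, feature_names, with_stats=False):
        # build parallel C arrays of feature names and types ('q' = quantitative)
        flen = len(feature_names)
        fname = (ctypes.c_char_p * flen)(*[f.encode('utf-8') for f in feature_names])
        ftype = (ctypes.c_char_p * flen)(*([b'q'] * flen))
        length = ctypes.c_ulong()
        sarr = ctypes.POINTER(ctypes.c_char_p)()
        ret = lib.XGBoosterDumpModelWithFeatures(handle, flen, fname, ftype,
                                                 int(with_stats),
                                                 ctypes.byref(length),
                                                 ctypes.byref(sarr))
        if ret != 0:
            raise RuntimeError('XGBoosterDumpModelWithFeatures failed')
        # copy the returned string array into a Python list
        return [sarr[i].decode('utf-8') for i in range(length.value)]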