Merge pull request #488 from sinhrks/pyfeaturenames

Support feature names in Python package
This commit is contained in:
Tianqi Chen 2015-09-15 09:56:55 -07:00
commit ae43fd7c7a
6 changed files with 214 additions and 43 deletions

View File

@ -1,5 +1,5 @@
# coding: utf-8 # coding: utf-8
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments, too-many-branches
"""Core XGBoost Library.""" """Core XGBoost Library."""
from __future__ import absolute_import from __future__ import absolute_import
@ -23,8 +23,9 @@ class XGBoostError(Exception):
if sys.version_info[0] == 3: if sys.version_info[0] == 3:
# pylint: disable=invalid-name # pylint: disable=invalid-name, redefined-builtin
STRING_TYPES = str, STRING_TYPES = str,
unicode = str
else: else:
# pylint: disable=invalid-name # pylint: disable=invalid-name
STRING_TYPES = basestring, STRING_TYPES = basestring,
@ -131,7 +132,11 @@ class DMatrix(object):
which is optimized for both memory efficiency and training speed. which is optimized for both memory efficiency and training speed.
You can construct DMatrix from numpy.arrays You can construct DMatrix from numpy.arrays
""" """
def __init__(self, data, label=None, missing=0.0, weight=None, silent=False):
feature_names = None # for previous version's pickle
def __init__(self, data, label=None, missing=0.0,
weight=None, silent=False, feature_names=None):
""" """
Data matrix used in XGBoost. Data matrix used in XGBoost.
@ -149,6 +154,8 @@ class DMatrix(object):
Weight for each instance. Weight for each instance.
silent : boolean, optional silent : boolean, optional
Whether print messages during construction Whether print messages during construction
feature_names : list, optional
Labels for features.
""" """
# force into void_p, mac need to pass things in as void_p # force into void_p, mac need to pass things in as void_p
if data is None: if data is None:
@ -176,6 +183,21 @@ class DMatrix(object):
if weight is not None: if weight is not None:
self.set_weight(weight) self.set_weight(weight)
# validate feature name
if not feature_names is None:
if not isinstance(feature_names, list):
feature_names = list(feature_names)
if len(feature_names) != len(set(feature_names)):
raise ValueError('feature_names must be unique')
if len(feature_names) != self.num_col():
msg = 'feature_names must have the same length as data'
raise ValueError(msg)
# prohibit symbols that may interfere with parsing, e.g. ``[]=.``
if not all(isinstance(f, STRING_TYPES) and f.isalnum()
for f in feature_names):
raise ValueError('all feature_names must be alphanumerics')
self.feature_names = feature_names
def _init_from_csr(self, csr): def _init_from_csr(self, csr):
""" """
Initialize data from a CSR matrix. Initialize data from a CSR matrix.
@ -391,6 +413,18 @@ class DMatrix(object):
ctypes.byref(ret))) ctypes.byref(ret)))
return ret.value return ret.value
def num_col(self):
    """Get the number of columns (features) in the DMatrix.

    Returns
    -------
    number of columns : int
    """
    # XGDMatrixNumCol reports its result through a ``bst_ulong *``
    # out-parameter. Receive it in a c_ulong, the same ctypes type this
    # module uses for other bst_ulong outputs, so the buffer is never
    # narrower than what the library writes.
    ret = ctypes.c_ulong()
    _check_call(_LIB.XGDMatrixNumCol(self.handle,
                                     ctypes.byref(ret)))
    return ret.value
def slice(self, rindex): def slice(self, rindex):
"""Slice the DMatrix and return a new DMatrix that only contains `rindex`. """Slice the DMatrix and return a new DMatrix that only contains `rindex`.
@ -404,7 +438,7 @@ class DMatrix(object):
res : DMatrix res : DMatrix
A new DMatrix containing only selected indices. A new DMatrix containing only selected indices.
""" """
res = DMatrix(None) res = DMatrix(None, feature_names=self.feature_names)
res.handle = ctypes.c_void_p() res.handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixSliceDMatrix(self.handle, _check_call(_LIB.XGDMatrixSliceDMatrix(self.handle,
c_array(ctypes.c_int, rindex), c_array(ctypes.c_int, rindex),
@ -419,6 +453,9 @@ class Booster(object):
Booster is the model of xgboost, that contains low level routines for Booster is the model of xgboost, that contains low level routines for
training, prediction and evaluation. training, prediction and evaluation.
""" """
feature_names = None
def __init__(self, params=None, cache=(), model_file=None): def __init__(self, params=None, cache=(), model_file=None):
# pylint: disable=invalid-name # pylint: disable=invalid-name
"""Initialize the Booster. """Initialize the Booster.
@ -435,6 +472,7 @@ class Booster(object):
for d in cache: for d in cache:
if not isinstance(d, DMatrix): if not isinstance(d, DMatrix):
raise TypeError('invalid cache item: {}'.format(type(d).__name__)) raise TypeError('invalid cache item: {}'.format(type(d).__name__))
self._validate_feature_names(d)
dmats = c_array(ctypes.c_void_p, [d.handle for d in cache]) dmats = c_array(ctypes.c_void_p, [d.handle for d in cache])
self.handle = ctypes.c_void_p() self.handle = ctypes.c_void_p()
_check_call(_LIB.XGBoosterCreate(dmats, len(cache), ctypes.byref(self.handle))) _check_call(_LIB.XGBoosterCreate(dmats, len(cache), ctypes.byref(self.handle)))
@ -519,6 +557,8 @@ class Booster(object):
""" """
if not isinstance(dtrain, DMatrix): if not isinstance(dtrain, DMatrix):
raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__)) raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
self._validate_feature_names(dtrain)
if fobj is None: if fobj is None:
_check_call(_LIB.XGBoosterUpdateOneIter(self.handle, iteration, dtrain.handle)) _check_call(_LIB.XGBoosterUpdateOneIter(self.handle, iteration, dtrain.handle))
else: else:
@ -543,6 +583,8 @@ class Booster(object):
raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess))) raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess)))
if not isinstance(dtrain, DMatrix): if not isinstance(dtrain, DMatrix):
raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__)) raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
self._validate_feature_names(dtrain)
_check_call(_LIB.XGBoosterBoostOneIter(self.handle, dtrain.handle, _check_call(_LIB.XGBoosterBoostOneIter(self.handle, dtrain.handle,
c_array(ctypes.c_float, grad), c_array(ctypes.c_float, grad),
c_array(ctypes.c_float, hess), c_array(ctypes.c_float, hess),
@ -572,6 +614,8 @@ class Booster(object):
raise TypeError('expected DMatrix, got {}'.format(type(d[0]).__name__)) raise TypeError('expected DMatrix, got {}'.format(type(d[0]).__name__))
if not isinstance(d[1], STRING_TYPES): if not isinstance(d[1], STRING_TYPES):
raise TypeError('expected string, got {}'.format(type(d[1]).__name__)) raise TypeError('expected string, got {}'.format(type(d[1]).__name__))
self._validate_feature_names(d[0])
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals]) dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals]) evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
msg = ctypes.c_char_p() msg = ctypes.c_char_p()
@ -605,6 +649,7 @@ class Booster(object):
result: str result: str
Evaluation result string. Evaluation result string.
""" """
self._validate_feature_names(data)
return self.eval_set([(data, name)], iteration) return self.eval_set([(data, name)], iteration)
def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False): def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False):
@ -642,6 +687,9 @@ class Booster(object):
option_mask |= 0x01 option_mask |= 0x01
if pred_leaf: if pred_leaf:
option_mask |= 0x02 option_mask |= 0x02
self._validate_feature_names(data)
length = ctypes.c_ulong() length = ctypes.c_ulong()
preds = ctypes.POINTER(ctypes.c_float)() preds = ctypes.POINTER(ctypes.c_float)()
_check_call(_LIB.XGBoosterPredict(self.handle, data.handle, _check_call(_LIB.XGBoosterPredict(self.handle, data.handle,
@ -731,16 +779,46 @@ class Booster(object):
""" """
Returns the dump the model as a list of strings. Returns the dump the model as a list of strings.
""" """
length = ctypes.c_ulong() length = ctypes.c_ulong()
sarr = ctypes.POINTER(ctypes.c_char_p)() sarr = ctypes.POINTER(ctypes.c_char_p)()
if self.feature_names is not None and fmap == '':
flen = int(len(self.feature_names))
fname = (ctypes.c_char_p * flen)()
ftype = (ctypes.c_char_p * flen)()
# supports quantitative type only
# {'q': quantitative, 'i': indicator}
if sys.version_info[0] == 3:
features = [bytes(f, 'utf-8') for f in self.feature_names]
types = [bytes('q', 'utf-8')] * flen
else:
features = [f.encode('utf-8') if isinstance(f, unicode) else f
for f in self.feature_names]
types = ['q'] * flen
fname[:] = features
ftype[:] = types
_check_call(_LIB.XGBoosterDumpModelWithFeatures(self.handle,
flen,
fname,
ftype,
int(with_stats),
ctypes.byref(length),
ctypes.byref(sarr)))
else:
_check_call(_LIB.XGBoosterDumpModel(self.handle, _check_call(_LIB.XGBoosterDumpModel(self.handle,
c_str(fmap), c_str(fmap),
int(with_stats), int(with_stats),
ctypes.byref(length), ctypes.byref(length),
ctypes.byref(sarr))) ctypes.byref(sarr)))
res = [] res = []
for i in range(length.value): for i in range(length.value):
try:
res.append(str(sarr[i].decode('ascii'))) res.append(str(sarr[i].decode('ascii')))
except UnicodeDecodeError:
res.append(unicode(sarr[i].decode('utf-8')))
return res return res
def get_fscore(self, fmap=''): def get_fscore(self, fmap=''):
@ -765,3 +843,17 @@ class Booster(object):
else: else:
fmap[fid] += 1 fmap[fid] += 1
return fmap return fmap
def _validate_feature_names(self, data):
    """
    Check that ``data`` carries the same feature_names as this Booster.

    If the Booster has no feature_names yet, it adopts those of ``data``.
    Otherwise a ValueError is raised when the two differ.
    """
    expected = self.feature_names
    if expected is None:
        # first data seen by this Booster: remember its feature names
        self.feature_names = data.feature_names
        return

    actual = data.feature_names
    # Booster can't accept data with different feature names
    if expected != actual:
        raise ValueError(
            'feature_names mismatch: {0} {1}'.format(expected, actual))

View File

@ -5,15 +5,3 @@ if [ ${TRAVIS_OS_NAME} != "osx" ]; then
fi fi
brew update brew update
if [ ${TASK} == "python-package" ]; then
brew install python git graphviz
easy_install pip
pip install numpy scipy matplotlib nose
fi
if [ ${TASK} == "python-package3" ]; then
brew install python3 git graphviz
sudo pip3 install --upgrade setuptools
pip3 install numpy scipy matplotlib nose graphviz
fi

View File

@ -33,30 +33,44 @@ if [ ${TASK} == "R-package" ]; then
scripts/travis_R_script.sh || exit -1 scripts/travis_R_script.sh || exit -1
fi fi
if [ ${TASK} == "python-package" ]; then if [ ${TASK} == "python-package" -o ${TASK} == "python-package3" ]; then
sudo apt-get install graphviz
sudo apt-get install python-numpy python-scipy python-matplotlib python-nose if [ ${TRAVIS_OS_NAME} == "osx" ]; then
sudo python -m pip install graphviz brew install graphviz
make all CXX=${CXX} || exit -1 if [ ${TASK} == "python-package3" ]; then
nosetests tests/python || exit -1 wget -O conda.sh https://repo.continuum.io/miniconda/Miniconda3-latest-MacOSX-x86_64.sh
else
wget -O conda.sh https://repo.continuum.io/miniconda/Miniconda-latest-MacOSX-x86_64.sh
fi fi
else
sudo apt-get install graphviz
if [ ${TASK} == "python-package3" ]; then
wget -O conda.sh https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
else
wget -O conda.sh https://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh
fi
fi
bash conda.sh -b -p $HOME/miniconda
export PATH="$HOME/miniconda/bin:$PATH"
hash -r
conda config --set always_yes yes --set changeps1 no
conda update -q conda
# Useful for debugging any issues with conda
conda info -a
if [ ${TASK} == "python-package3" ]; then if [ ${TASK} == "python-package3" ]; then
sudo apt-get install graphviz conda create -n myenv python=3.4
# python3-matplotlib is unavailale on Ubuntu 12.04 else
sudo apt-get install python3-dev conda create -n myenv python=2.7
sudo apt-get install python3-numpy python3-scipy python3-nose python3-setuptools fi
source activate myenv
conda install numpy scipy matplotlib nose
python -m pip install graphviz
make all CXX=${CXX} || exit -1 make all CXX=${CXX} || exit -1
if [ ${TRAVIS_OS_NAME} != "osx" ]; then python -m nose tests/python || exit -1
sudo easy_install3 pip python --version
sudo easy_install3 -U distribute
sudo pip install graphviz matplotlib
nosetests3 tests/python || exit -1
else
nosetests tests/python || exit -1
fi
fi fi
# only test java under linux for now # only test java under linux for now

View File

@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import numpy as np import numpy as np
import xgboost as xgb import xgboost as xgb
@ -29,6 +30,30 @@ def test_basic():
# assert they are the same # assert they are the same
assert np.sum(np.abs(preds2-preds)) == 0 assert np.sum(np.abs(preds2-preds)) == 0
def test_feature_names():
    """DMatrix keeps the given feature names and get_fscore reports them."""
    data = np.random.randn(100, 5)
    target = np.array([0, 1] * 50)

    ascii_names = ['Feature1', 'Feature2', 'Feature3', 'Feature4', 'Feature5']
    unicode_names = [u'要因1', u'要因2', u'要因3', u'要因4', u'要因5']

    for features in (ascii_names, unicode_names):
        dm = xgb.DMatrix(data, label=target, feature_names=features)

        # the matrix reports the names and shape it was built with
        assert dm.feature_names == features
        assert dm.num_row() == 100
        assert dm.num_col() == 5

        params = {
            'objective': 'multi:softprob',
            'eval_metric': 'mlogloss',
            'eta': 0.3,
            'num_class': 3,
        }
        bst = xgb.train(params, dm, num_boost_round=10)

        # keys of the feature-score mapping are the feature names
        scores = bst.get_fscore()
        assert sorted(scores) == features
def test_plotting(): def test_plotting():
bst2 = xgb.Booster(model_file='xgb.model') bst2 = xgb.Booster(model_file='xgb.model')
# plotting # plotting

View File

@ -435,6 +435,7 @@ int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
*out_dptr = BeginPtr(vec); *out_dptr = BeginPtr(vec);
API_END(); API_END();
} }
int XGDMatrixNumRow(const DMatrixHandle handle, int XGDMatrixNumRow(const DMatrixHandle handle,
bst_ulong *out) { bst_ulong *out) {
API_BEGIN(); API_BEGIN();
@ -442,6 +443,13 @@ int XGDMatrixNumRow(const DMatrixHandle handle,
API_END(); API_END();
} }
int XGDMatrixNumCol(const DMatrixHandle handle,
bst_ulong *out) {
API_BEGIN();
*out = static_cast<size_t>(static_cast<const DataMatrix*>(handle)->info.num_col());
API_END();
}
// xgboost implementation // xgboost implementation
int XGBoosterCreate(DMatrixHandle dmats[], int XGBoosterCreate(DMatrixHandle dmats[],
bst_ulong len, bst_ulong len,
@ -572,3 +580,20 @@ int XGBoosterDumpModel(BoosterHandle handle,
featmap, with_stats != 0, len); featmap, with_stats != 0, len);
API_END(); API_END();
} }
/*!
 * \brief dump the model using feature names/types supplied as arrays
 * \param handle handle of the booster
 * \param fnum number of entries in fname and ftype
 * \param fname feature names, one per feature index
 * \param ftype feature types, one per feature index
 * \param with_stats non-zero to include statistics in the dump
 * \param len receives the length of the returned array
 * \param out_models receives the array of model dump strings
 * \return 0 when success, -1 when failure happens
 */
int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
                                   int fnum,
                                   const char **fname,
                                   const char **ftype,
                                   int with_stats,
                                   bst_ulong *len,
                                   const char ***out_models) {
  API_BEGIN();
  // build the feature map from the parallel name/type arrays
  utils::FeatMap featmap;
  int idx = 0;
  while (idx < fnum) {
    featmap.PushBack(idx, fname[idx], ftype[idx]);
    ++idx;
  }
  Booster *booster = static_cast<Booster*>(handle);
  const bool dump_stats = (with_stats != 0);
  *out_models = booster->GetModelDump(featmap, dump_stats, len);
  API_END();
}

View File

@ -184,6 +184,13 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
*/ */
XGB_DLL int XGDMatrixNumRow(DMatrixHandle handle, XGB_DLL int XGDMatrixNumRow(DMatrixHandle handle,
bst_ulong *out); bst_ulong *out);
/*!
* \brief get number of columns
* \param handle the handle to the DMatrix
* \param out address that receives the number of columns
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle,
bst_ulong *out);
// --- start XGBoost class // --- start XGBoost class
/*! /*!
* \brief create xgboost learner * \brief create xgboost learner
@ -324,4 +331,24 @@ XGB_DLL int XGBoosterDumpModel(BoosterHandle handle,
int with_stats, int with_stats,
bst_ulong *out_len, bst_ulong *out_len,
const char ***out_dump_array); const char ***out_dump_array);
/*!
* \brief dump model, return array of strings representing model dump
* \param handle handle
* \param fnum number of features
* \param fname names of features
* \param ftype types of features
* \param with_stats whether to dump with statistics
* \param len length of output array
* \param out_models pointer to hold representing dump of each model
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
int fnum,
const char **fname,
const char **ftype,
int with_stats,
bst_ulong *len,
const char ***out_models);
#endif // XGBOOST_WRAPPER_H_ #endif // XGBOOST_WRAPPER_H_