ENH: allow python to handle feature names
This commit is contained in:
parent dd3126735b
commit 6506a1c490

@@ -4,6 +4,7 @@
 from __future__ import absolute_import
 
 import os
+import re
 import sys
 import ctypes
 import platform
@@ -131,7 +132,11 @@ class DMatrix(object):
     which is optimized for both memory efficiency and training speed.
     You can construct DMatrix from numpy.arrays
     """
-    def __init__(self, data, label=None, missing=0.0, weight=None, silent=False):
+
+    feature_names = None  # for previous version's pickle
+
+    def __init__(self, data, label=None, missing=0.0,
+                 weight=None, silent=False, feature_names=None):
         """
         Data matrix used in XGBoost.
 
@@ -149,6 +154,8 @@ class DMatrix(object):
             Weight for each instance.
         silent : boolean, optional
             Whether print messages during construction
+        feature_names : list, optional
+            Labels for features.
         """
         # force into void_p, mac need to pass things in as void_p
         if data is None:
@@ -176,6 +183,18 @@ class DMatrix(object):
         if weight is not None:
             self.set_weight(weight)
 
+        # validate feature name
+        if not isinstance(feature_names, list):
+            feature_names = list(feature_names)
+        if len(feature_names) != len(set(feature_names)):
+            raise ValueError('feature_names must be unique')
+        if len(feature_names) != self.num_col():
+            raise ValueError('feature_names must have the same length as data')
+        if not all(isinstance(f, STRING_TYPES) and f.isalnum()
+                   for f in feature_names):
+            raise ValueError('all feature_names must be alphanumerics')
+        self.feature_names = feature_names
+
     def _init_from_csr(self, csr):
         """
         Initialize data from a CSR matrix.
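
The checks above require feature_names to be a list of unique, purely alphanumeric strings with one entry per column. A minimal usage sketch (assuming the wrapper is importable as xgboost and numpy is available; the data and names are made up for illustration):

    import numpy as np
    import xgboost as xgb

    data = np.random.rand(5, 3)                  # 5 rows, 3 feature columns
    dtrain = xgb.DMatrix(data, label=np.zeros(5),
                         feature_names=['featA', 'featB', 'featC'])

    # Each rule above maps to a ValueError, e.g.:
    # xgb.DMatrix(data, feature_names=['featA', 'featA', 'featC'])   # not unique
    # xgb.DMatrix(data, feature_names=['featA', 'featB'])            # wrong length
    # xgb.DMatrix(data, feature_names=['feat-A', 'featB', 'featC'])  # not alphanumeric
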
@@ -391,6 +410,18 @@ class DMatrix(object):
                                          ctypes.byref(ret)))
         return ret.value
 
+    def num_col(self):
+        """Get the number of columns in the DMatrix.
+
+        Returns
+        -------
+        number of columns : int
+        """
+        ret = ctypes.c_ulong()
+        _check_call(_LIB.XGDMatrixNumCol(self.handle,
+                                         ctypes.byref(ret)))
+        return ret.value
+
     def slice(self, rindex):
         """Slice the DMatrix and return a new DMatrix that only contains `rindex`.
 
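
Continuing the hypothetical dtrain from the sketch above, the new method mirrors the existing num_row():

    dtrain.num_col()   # 3
    dtrain.num_row()   # 5
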
@@ -404,7 +435,7 @@ class DMatrix(object):
         res : DMatrix
             A new DMatrix containing only selected indices.
         """
-        res = DMatrix(None)
+        res = DMatrix(None, feature_names=self.feature_names)
         res.handle = ctypes.c_void_p()
         _check_call(_LIB.XGDMatrixSliceDMatrix(self.handle,
                                                c_array(ctypes.c_int, rindex),
@@ -419,6 +450,9 @@ class Booster(object):
     Booster is the model of xgboost, that contains low level routines for
     training, prediction and evaluation.
     """
+
+    feature_names = None
+
     def __init__(self, params=None, cache=(), model_file=None):
         # pylint: disable=invalid-name
         """Initialize the Booster.
@@ -435,6 +469,7 @@ class Booster(object):
         for d in cache:
             if not isinstance(d, DMatrix):
                 raise TypeError('invalid cache item: {}'.format(type(d).__name__))
+            self._validate_feature_names(d)
         dmats = c_array(ctypes.c_void_p, [d.handle for d in cache])
         self.handle = ctypes.c_void_p()
         _check_call(_LIB.XGBoosterCreate(dmats, len(cache), ctypes.byref(self.handle)))
@@ -519,6 +554,8 @@ class Booster(object):
         """
         if not isinstance(dtrain, DMatrix):
             raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
+        self._validate_feature_names(dtrain)
+
         if fobj is None:
             _check_call(_LIB.XGBoosterUpdateOneIter(self.handle, iteration, dtrain.handle))
         else:
@@ -543,6 +580,8 @@ class Booster(object):
             raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess)))
         if not isinstance(dtrain, DMatrix):
             raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
+        self._validate_feature_names(dtrain)
+
         _check_call(_LIB.XGBoosterBoostOneIter(self.handle, dtrain.handle,
                                                c_array(ctypes.c_float, grad),
                                                c_array(ctypes.c_float, hess),
@@ -572,6 +611,8 @@ class Booster(object):
                 raise TypeError('expected DMatrix, got {}'.format(type(d[0]).__name__))
             if not isinstance(d[1], STRING_TYPES):
                 raise TypeError('expected string, got {}'.format(type(d[1]).__name__))
+            self._validate_feature_names(d)
+
         dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
         evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
         msg = ctypes.c_char_p()
@@ -605,6 +646,7 @@ class Booster(object):
         result: str
             Evaluation result string.
         """
+        self._validate_feature_names(data)
         return self.eval_set([(data, name)], iteration)
 
     def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False):
@@ -642,6 +684,9 @@ class Booster(object):
             option_mask |= 0x01
         if pred_leaf:
             option_mask |= 0x02
+
+        self._validate_feature_names(data)
+
         length = ctypes.c_ulong()
         preds = ctypes.POINTER(ctypes.c_float)()
         _check_call(_LIB.XGBoosterPredict(self.handle, data.handle,
@@ -731,6 +776,7 @@ class Booster(object):
         """
         Returns the dump the model as a list of strings.
         """
+        res = []
         length = ctypes.c_ulong()
         sarr = ctypes.POINTER(ctypes.c_char_p)()
         _check_call(_LIB.XGBoosterDumpModel(self.handle,
@@ -738,9 +784,19 @@ class Booster(object):
                                             int(with_stats),
                                             ctypes.byref(length),
                                             ctypes.byref(sarr)))
-        res = []
         for i in range(length.value):
             res.append(str(sarr[i].decode('ascii')))
+
+        if self.feature_names is not None:
+            defaults = ['f{0}'.format(i) for i in
+                        range(len(self.feature_names))]
+            rep = dict((re.escape(k), v) for k, v in
+                       zip(defaults, self.feature_names))
+            pattern = re.compile("|".join(rep))
+            def _replace(expr):
+                """ Replace matched group to corresponding values """
+                return pattern.sub(lambda m: rep[re.escape(m.group(0))], expr)
+            res = [_replace(r) for r in res]
         return res
 
     def get_fscore(self, fmap=''):
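
The substitution above rewrites the default f0, f1, ... feature tokens in the dump text into the user-supplied names. A standalone sketch of the technique, with hypothetical names and a hypothetical dump line (the real input comes from XGBoosterDumpModel):

    import re

    feature_names = ['age', 'income', 'ownsHouse']
    defaults = ['f{0}'.format(i) for i in range(len(feature_names))]
    rep = dict((re.escape(k), v) for k, v in zip(defaults, feature_names))
    pattern = re.compile("|".join(rep))

    dump_line = '0:[f1<3.5] yes=1,no=2,missing=1'
    print(pattern.sub(lambda m: rep[re.escape(m.group(0))], dump_line))
    # prints: 0:[income<3.5] yes=1,no=2,missing=1
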
@@ -765,3 +821,17 @@ class Booster(object):
             else:
                 fmap[fid] += 1
         return fmap
+
+    def _validate_feature_names(self, data):
+        """
+        Validate Booster and data's feature_names are identical
+        """
+        if self.feature_names is None:
+            self.feature_names = data.feature_names
+        else:
+            # Booster can't accept data with different feature names
+            if self.feature_names != data.feature_names:
+                msg = 'feature_names mismatch: {0} {1}'
+                raise ValueError(msg.format(self.feature_names,
+                                            data.feature_names))
+
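
Continuing the earlier sketch: the Booster adopts the feature names of the first DMatrix it sees (here via the cache), and any later DMatrix with different names is rejected before the native call is made. Illustrative only, not verified against this exact revision:

    bst = xgb.Booster({'silent': 1}, [dtrain])   # adopts dtrain.feature_names
    other = xgb.DMatrix(np.random.rand(2, 3),
                        feature_names=['x1', 'x2', 'x3'])
    bst.predict(other)
    # ValueError: feature_names mismatch: ['featA', 'featB', 'featC'] ['x1', 'x2', 'x3']
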

@@ -435,6 +435,7 @@ int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
   *out_dptr = BeginPtr(vec);
   API_END();
 }
+
 int XGDMatrixNumRow(const DMatrixHandle handle,
                     bst_ulong *out) {
   API_BEGIN();
@@ -442,6 +443,13 @@ int XGDMatrixNumRow(const DMatrixHandle handle,
   API_END();
 }
 
+int XGDMatrixNumCol(const DMatrixHandle handle,
+                    bst_ulong *out) {
+  API_BEGIN();
+  *out = static_cast<size_t>(static_cast<const DataMatrix*>(handle)->info.num_col());
+  API_END();
+}
+
 // xgboost implementation
 int XGBoosterCreate(DMatrixHandle dmats[],
                     bst_ulong len,

@@ -184,6 +184,13 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
  */
 XGB_DLL int XGDMatrixNumRow(DMatrixHandle handle,
                             bst_ulong *out);
+/*!
+ * \brief get number of columns
+ * \param handle the handle to the DMatrix
+ * \return 0 when success, -1 when failure happens
+ */
+XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle,
+                            bst_ulong *out);
 // --- start XGBoost class
 /*!
  * \brief create xgboost learner