Feature weights (#5962)

Jiaming Yuan 2020-08-18 19:55:41 +08:00 committed by GitHub
parent a418278064
commit 4d99c58a5f
25 changed files with 509 additions and 104 deletions

View File

@ -68,6 +68,7 @@
#include "../src/learner.cc" #include "../src/learner.cc"
#include "../src/logging.cc" #include "../src/logging.cc"
#include "../src/common/common.cc" #include "../src/common/common.cc"
#include "../src/common/random.cc"
#include "../src/common/charconv.cc" #include "../src/common/charconv.cc"
#include "../src/common/timer.cc" #include "../src/common/timer.cc"
#include "../src/common/quantile.cc" #include "../src/common/quantile.cc"

View File

@ -0,0 +1,49 @@
'''Using feature weight to change column sampling.
.. versionadded:: 1.3.0
'''
import numpy as np
import xgboost
from matplotlib import pyplot as plt
import argparse
def main(args):
rng = np.random.RandomState(1994)
kRows = 1000
kCols = 10
X = rng.randn(kRows, kCols)
y = rng.randn(kRows)
fw = np.ones(shape=(kCols,))
for i in range(kCols):
fw[i] *= float(i)
dtrain = xgboost.DMatrix(X, y)
dtrain.set_info(feature_weights=fw)
bst = xgboost.train({'tree_method': 'hist',
'colsample_bynode': 0.5},
dtrain, num_boost_round=10,
evals=[(dtrain, 'd')])
feature_map = bst.get_fscore()
# feature zero has 0 weight
assert feature_map.get('f0', None) is None
assert max(feature_map.values()) == feature_map.get('f9')
if args.plot:
xgboost.plot_importance(bst)
plt.show()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--plot',
type=int,
default=1,
help='Set to 0 to disable plotting the feature importance.')
args = parser.parse_args()
main(args)

View File

@ -94,7 +94,7 @@ class Tree:
class Model: class Model:
'''Gradient boosted tree model.''' '''Gradient boosted tree model.'''
def __init__(self, m: dict): def __init__(self, model: dict):
'''Construct the Model from JSON object. '''Construct the Model from JSON object.
parameters parameters

View File

@ -107,6 +107,10 @@ Parameters for Tree Booster
'colsample_bynode':0.5}`` with 64 features will leave 8 features to choose from at 'colsample_bynode':0.5}`` with 64 features will leave 8 features to choose from at
each split. each split.
On the Python interface, one can set the ``feature_weights`` for DMatrix to define the
probability of each feature being selected when using column sampling. There's a
similar parameter for the ``fit`` method in the sklearn interface.
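For illustration, a minimal sketch of how this might look from Python (hypothetical data and parameter values; the ``set_info`` keyword is the one introduced by this change):

.. code-block:: python

    import numpy as np
    import xgboost

    X = np.random.randn(128, 8)
    y = np.random.randn(128)

    dtrain = xgboost.DMatrix(X, y)
    # Later features get a proportionally higher chance of being sampled;
    # a weight of 0 means the feature is never selected.
    dtrain.set_info(feature_weights=np.arange(8, dtype=np.float32))

    booster = xgboost.train({'colsample_bynode': 0.5}, dtrain, num_boost_round=4)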
* ``lambda`` [default=1, alias: ``reg_lambda``] * ``lambda`` [default=1, alias: ``reg_lambda``]
- L2 regularization term on weights. Increasing this value will make model more conservative. - L2 regularization term on weights. Increasing this value will make model more conservative.
@ -224,7 +228,7 @@ Parameters for Tree Booster
list is a group of indices of features that are allowed to interact with each other. list is a group of indices of features that are allowed to interact with each other.
See tutorial for more information See tutorial for more information
Additional parameters for ``hist`` and ```gpu_hist`` tree method Additional parameters for ``hist`` and ``gpu_hist`` tree method
================================================================ ================================================================
* ``single_precision_histogram``, [default=``false``] * ``single_precision_histogram``, [default=``false``]

View File

@ -483,6 +483,34 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
bst_ulong *size, bst_ulong *size,
const char ***out_features); const char ***out_features);
/*!
* \brief Set meta info from dense matrix. Valid field names are:
*
* - label
* - weight
* - base_margin
* - group
* - label_lower_bound
* - label_upper_bound
* - feature_weights
*
* \param handle An instance of data matrix
* \param field Field name
* \param data Pointer to consecutive memory storing data.
* \param size Size of the data, measured in number of elements of the given type
* (NOT the number of bytes).
* \param type Indicator of data type. This is defined in xgboost::DataType enum class.
*
* float = 1
* double = 2
* uint32_t = 3
* uint64_t = 4
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
void *data, bst_ulong size, int type);
/*! /*!
* \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix * \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix
* \param handle a instance of data matrix * \param handle a instance of data matrix
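As a rough illustration (not part of this commit), the new entry point could be driven from Python through ctypes in much the same way the updated dispatch code later in this commit does; the private helpers ``_LIB``, ``_check_call``, ``c_str`` and ``c_bst_ulong`` from ``xgboost.core`` are assumed here:

import ctypes
import numpy as np
from xgboost.core import _LIB, _check_call, c_str, c_bst_ulong

def set_dense_info(handle, field, array):
    # Contiguous float32 buffer; 1 is the DataType code for float (see the list above).
    array = np.ascontiguousarray(array, dtype=np.float32)
    ptr = ctypes.c_void_p(array.__array_interface__['data'][0])
    _check_call(_LIB.XGDMatrixSetDenseInfo(
        handle, c_str(field), ptr, c_bst_ulong(array.shape[0]), ctypes.c_int(1)))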

View File

@ -88,34 +88,17 @@ class MetaInfo {
* \brief Type of each feature. Automatically set when feature_type_names is specifed. * \brief Type of each feature. Automatically set when feature_type_names is specifed.
*/ */
HostDeviceVector<FeatureType> feature_types; HostDeviceVector<FeatureType> feature_types;
/*
* \brief Weight of each feature, used to define the probability of each feature being
* selected when using column sampling.
*/
HostDeviceVector<float> feature_weigths;
/*! \brief default constructor */ /*! \brief default constructor */
MetaInfo() = default; MetaInfo() = default;
MetaInfo(MetaInfo&& that) = default; MetaInfo(MetaInfo&& that) = default;
MetaInfo& operator=(MetaInfo&& that) = default; MetaInfo& operator=(MetaInfo&& that) = default;
MetaInfo& operator=(MetaInfo const& that) { MetaInfo& operator=(MetaInfo const& that) = delete;
this->num_row_ = that.num_row_;
this->num_col_ = that.num_col_;
this->num_nonzero_ = that.num_nonzero_;
this->labels_.Resize(that.labels_.Size());
this->labels_.Copy(that.labels_);
this->group_ptr_ = that.group_ptr_;
this->weights_.Resize(that.weights_.Size());
this->weights_.Copy(that.weights_);
this->base_margin_.Resize(that.base_margin_.Size());
this->base_margin_.Copy(that.base_margin_);
this->labels_lower_bound_.Resize(that.labels_lower_bound_.Size());
this->labels_lower_bound_.Copy(that.labels_lower_bound_);
this->labels_upper_bound_.Resize(that.labels_upper_bound_.Size());
this->labels_upper_bound_.Copy(that.labels_upper_bound_);
return *this;
}
/*! /*!
* \brief Validate all metainfo. * \brief Validate all metainfo.

View File

@ -455,7 +455,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
label_lower_bound=None, label_lower_bound=None,
label_upper_bound=None, label_upper_bound=None,
feature_names=None, feature_names=None,
feature_types=None): feature_types=None,
feature_weights=None):
'''Set meta info for DMatrix.''' '''Set meta info for DMatrix.'''
if label is not None: if label is not None:
self.set_label(label) self.set_label(label)
@ -473,6 +474,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
self.feature_names = feature_names self.feature_names = feature_names
if feature_types is not None: if feature_types is not None:
self.feature_types = feature_types self.feature_types = feature_types
if feature_weights is not None:
from .data import dispatch_meta_backend
dispatch_meta_backend(matrix=self, data=feature_weights,
name='feature_weights')
def get_float_info(self, field): def get_float_info(self, field):
"""Get float property from the DMatrix. """Get float property from the DMatrix.

View File

@ -530,22 +530,38 @@ def dispatch_data_backend(data, missing, threads,
raise TypeError('Not supported type for data.' + str(type(data))) raise TypeError('Not supported type for data.' + str(type(data)))
def _to_data_type(dtype: str, name: str):
dtype_map = {'float32': 1, 'float64': 2, 'uint32': 3, 'uint64': 4}
if dtype not in dtype_map.keys():
raise TypeError(
f'Expecting float32, float64, uint32, uint64, got {dtype} ' +
f'for {name}.')
return dtype_map[dtype]
def _validate_meta_shape(data):
if hasattr(data, 'shape'):
assert len(data.shape) == 1 or (
len(data.shape) == 2 and
(data.shape[1] == 0 or data.shape[1] == 1))
def _meta_from_numpy(data, field, dtype, handle): def _meta_from_numpy(data, field, dtype, handle):
data = _maybe_np_slice(data, dtype) data = _maybe_np_slice(data, dtype)
if dtype == 'uint32': interface = data.__array_interface__
c_data = c_array(ctypes.c_uint32, data) assert interface.get('mask', None) is None, 'Masked array is not supported'
_check_call(_LIB.XGDMatrixSetUIntInfo(handle, size = data.shape[0]
c_str(field),
c_array(ctypes.c_uint, data), c_type = _to_data_type(str(data.dtype), field)
c_bst_ulong(len(data)))) ptr = interface['data'][0]
elif dtype == 'float': ptr = ctypes.c_void_p(ptr)
c_data = c_array(ctypes.c_float, data) _check_call(_LIB.XGDMatrixSetDenseInfo(
_check_call(_LIB.XGDMatrixSetFloatInfo(handle, handle,
c_str(field), c_str(field),
c_data, ptr,
c_bst_ulong(len(data)))) c_bst_ulong(size),
else: c_type
raise TypeError('Unsupported type ' + str(dtype) + ' for:' + field) ))
def _meta_from_list(data, field, dtype, handle): def _meta_from_list(data, field, dtype, handle):
@ -595,6 +611,7 @@ def _meta_from_dt(data, field, dtype, handle):
def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None): def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
'''Dispatch for meta info.''' '''Dispatch for meta info.'''
handle = matrix.handle handle = matrix.handle
_validate_meta_shape(data)
if data is None: if data is None:
return return
if _is_list(data): if _is_list(data):

View File

@ -441,6 +441,7 @@ class XGBModel(XGBModelBase):
def fit(self, X, y, sample_weight=None, base_margin=None, def fit(self, X, y, sample_weight=None, base_margin=None,
eval_set=None, eval_metric=None, early_stopping_rounds=None, eval_set=None, eval_metric=None, early_stopping_rounds=None,
verbose=True, xgb_model=None, sample_weight_eval_set=None, verbose=True, xgb_model=None, sample_weight_eval_set=None,
feature_weights=None,
callbacks=None): callbacks=None):
# pylint: disable=invalid-name,attribute-defined-outside-init # pylint: disable=invalid-name,attribute-defined-outside-init
"""Fit gradient boosting model """Fit gradient boosting model
@ -459,9 +460,6 @@ class XGBModel(XGBModelBase):
A list of (X, y) tuple pairs to use as validation sets, for which A list of (X, y) tuple pairs to use as validation sets, for which
metrics will be computed. metrics will be computed.
Validation metrics will help us track the performance of the model. Validation metrics will help us track the performance of the model.
sample_weight_eval_set : list, optional
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
instance weights on the i-th validation set.
eval_metric : str, list of str, or callable, optional eval_metric : str, list of str, or callable, optional
If a str, should be a built-in evaluation metric to use. See If a str, should be a built-in evaluation metric to use. See
doc/parameter.rst. doc/parameter.rst.
@ -490,6 +488,13 @@ class XGBModel(XGBModelBase):
xgb_model : str xgb_model : str
file name of stored XGBoost model or 'Booster' instance XGBoost model to be file name of stored XGBoost model or 'Booster' instance XGBoost model to be
loaded before training (allows training continuation). loaded before training (allows training continuation).
sample_weight_eval_set : list, optional
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
instance weights on the i-th validation set.
feature_weights: array_like
Weight for each feature, defining the probability of each feature being
selected when colsample is used. All values must be greater than or
equal to 0, otherwise a `ValueError` is thrown.
callbacks : list of callback functions callbacks : list of callback functions
List of callback functions that are applied at end of each iteration. List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using :ref:`callback_api`. It is possible to use predefined callbacks by using :ref:`callback_api`.
@ -498,6 +503,7 @@ class XGBModel(XGBModelBase):
.. code-block:: python .. code-block:: python
[xgb.callback.reset_learning_rate(custom_rates)] [xgb.callback.reset_learning_rate(custom_rates)]
""" """
self.n_features_in_ = X.shape[1] self.n_features_in_ = X.shape[1]
@ -505,6 +511,7 @@ class XGBModel(XGBModelBase):
base_margin=base_margin, base_margin=base_margin,
missing=self.missing, missing=self.missing,
nthread=self.n_jobs) nthread=self.n_jobs)
train_dmatrix.set_info(feature_weights=feature_weights)
evals_result = {} evals_result = {}
@ -759,7 +766,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
def fit(self, X, y, sample_weight=None, base_margin=None, def fit(self, X, y, sample_weight=None, base_margin=None,
eval_set=None, eval_metric=None, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True, xgb_model=None, early_stopping_rounds=None, verbose=True, xgb_model=None,
sample_weight_eval_set=None, callbacks=None): sample_weight_eval_set=None, feature_weights=None, callbacks=None):
# pylint: disable = attribute-defined-outside-init,arguments-differ # pylint: disable = attribute-defined-outside-init,arguments-differ
evals_result = {} evals_result = {}
@ -821,6 +828,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight, train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
base_margin=base_margin, base_margin=base_margin,
missing=self.missing, nthread=self.n_jobs) missing=self.missing, nthread=self.n_jobs)
train_dmatrix.set_info(feature_weights=feature_weights)
self._Booster = train(xgb_options, train_dmatrix, self._Booster = train(xgb_options, train_dmatrix,
self.get_num_boosting_rounds(), self.get_num_boosting_rounds(),
@ -1101,10 +1109,10 @@ class XGBRanker(XGBModel):
raise ValueError("please use XGBRanker for ranking task") raise ValueError("please use XGBRanker for ranking task")
def fit(self, X, y, group, sample_weight=None, base_margin=None, def fit(self, X, y, group, sample_weight=None, base_margin=None,
eval_set=None, eval_set=None, sample_weight_eval_set=None,
sample_weight_eval_set=None, eval_group=None, eval_metric=None, eval_group=None, eval_metric=None,
early_stopping_rounds=None, verbose=False, xgb_model=None, early_stopping_rounds=None, verbose=False, xgb_model=None,
callbacks=None): feature_weights=None, callbacks=None):
# pylint: disable = attribute-defined-outside-init,arguments-differ # pylint: disable = attribute-defined-outside-init,arguments-differ
"""Fit gradient boosting ranker """Fit gradient boosting ranker
@ -1170,6 +1178,10 @@ class XGBRanker(XGBModel):
xgb_model : str xgb_model : str
file name of stored XGBoost model or 'Booster' instance XGBoost file name of stored XGBoost model or 'Booster' instance XGBoost
model to be loaded before training (allows training continuation). model to be loaded before training (allows training continuation).
feature_weights: array_like
Weight for each feature, defining the probability of each feature being
selected when colsample is used. All values must be greater than or
equal to 0, otherwise a `ValueError` is thrown.
callbacks : list of callback functions callbacks : list of callback functions
List of callback functions that are applied at end of each List of callback functions that are applied at end of each
iteration. It is possible to use predefined callbacks by using iteration. It is possible to use predefined callbacks by using
@ -1205,6 +1217,7 @@ class XGBRanker(XGBModel):
train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight, train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
base_margin=base_margin, base_margin=base_margin,
missing=self.missing, nthread=self.n_jobs) missing=self.missing, nthread=self.n_jobs)
train_dmatrix.set_info(feature_weights=feature_weights)
train_dmatrix.set_group(group) train_dmatrix.set_group(group)
evals_result = {} evals_result = {}
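A short usage sketch of the new ``fit`` argument (hypothetical data; mirrors the sklearn test added later in this commit):

import numpy as np
import xgboost

X = np.random.randn(256, 16)
y = np.random.randn(256)
# Higher weight -> higher chance of being picked by column sampling.
fw = np.arange(16, dtype=np.float32)

reg = xgboost.XGBRegressor(tree_method='hist', colsample_bynode=0.5)
reg.fit(X, y, feature_weights=fw)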

View File

@ -316,6 +316,17 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
API_END(); API_END();
} }
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
void *data, xgboost::bst_ulong size,
int type) {
API_BEGIN();
CHECK_HANDLE();
auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
CHECK(type >= 1 && type <= 4);
info.SetInfo(field, data, static_cast<DataType>(type), size);
API_END();
}
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
const unsigned* group, const unsigned* group,
xgboost::bst_ulong len) { xgboost::bst_ulong len) {

View File

@ -9,12 +9,15 @@
#include <xgboost/base.h> #include <xgboost/base.h>
#include <xgboost/logging.h> #include <xgboost/logging.h>
#include <algorithm>
#include <exception> #include <exception>
#include <functional>
#include <limits> #include <limits>
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
#include <string> #include <string>
#include <sstream> #include <sstream>
#include <numeric>
#if defined(__CUDACC__) #if defined(__CUDACC__)
#include <thrust/system/cuda/error.h> #include <thrust/system/cuda/error.h>
@ -160,6 +163,15 @@ inline void AssertOneAPISupport() {
#endif // XGBOOST_USE_ONEAPI #endif // XGBOOST_USE_ONEAPI
} }
template <typename Idx, typename V, typename Comp = std::less<V>>
std::vector<Idx> ArgSort(std::vector<V> const &array, Comp comp = std::less<V>{}) {
std::vector<Idx> result(array.size());
std::iota(result.begin(), result.end(), 0);
std::stable_sort(
result.begin(), result.end(),
[&array, comp](Idx const &l, Idx const &r) { return comp(array[l], array[r]); });
return result;
}
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_COMMON_COMMON_H_ #endif // XGBOOST_COMMON_COMMON_H_

src/common/random.cc (new file, 38 lines)
View File

@ -0,0 +1,38 @@
/*!
* Copyright 2020 by XGBoost Contributors
* \file random.cc
*/
#include "random.h"
namespace xgboost {
namespace common {
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features,
float colsample) {
if (colsample == 1.0f) {
return p_features;
}
const auto &features = p_features->HostVector();
CHECK_GT(features.size(), 0);
int n = std::max(1, static_cast<int>(colsample * features.size()));
auto p_new_features = std::make_shared<HostDeviceVector<bst_feature_t>>();
auto &new_features = *p_new_features;
if (feature_weights_.size() != 0) {
new_features.HostVector() = WeightedSamplingWithoutReplacement(
p_features->HostVector(), feature_weights_, n);
} else {
new_features.Resize(features.size());
std::copy(features.begin(), features.end(),
new_features.HostVector().begin());
std::shuffle(new_features.HostVector().begin(),
new_features.HostVector().end(), rng_);
new_features.Resize(n);
}
std::sort(new_features.HostVector().begin(), new_features.HostVector().end());
return p_new_features;
}
} // namespace common
} // namespace xgboost

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2015 by Contributors * Copyright 2015-2020 by Contributors
* \file random.h * \file random.h
* \brief Utility related to random. * \brief Utility related to random.
* \author Tianqi Chen * \author Tianqi Chen
@ -10,14 +10,17 @@
#include <rabit/rabit.h> #include <rabit/rabit.h>
#include <xgboost/logging.h> #include <xgboost/logging.h>
#include <algorithm> #include <algorithm>
#include <functional>
#include <vector> #include <vector>
#include <limits> #include <limits>
#include <map> #include <map>
#include <memory> #include <memory>
#include <numeric> #include <numeric>
#include <random> #include <random>
#include <utility>
#include "xgboost/host_device_vector.h" #include "xgboost/host_device_vector.h"
#include "common.h"
namespace xgboost { namespace xgboost {
namespace common { namespace common {
@ -75,6 +78,38 @@ using GlobalRandomEngine = RandomEngine;
*/ */
GlobalRandomEngine& GlobalRandom(); // NOLINT(*) GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
/*
* Original paper:
* Weighted Random Sampling (2005; Efraimidis, Spirakis)
*
* Blog:
* https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
*/
template <typename T>
std::vector<T> WeightedSamplingWithoutReplacement(
std::vector<T> const &array, std::vector<float> const &weights, size_t n) {
// ES sampling.
CHECK_EQ(array.size(), weights.size());
std::vector<float> keys(weights.size());
std::uniform_real_distribution<float> dist;
auto& rng = GlobalRandom();
for (size_t i = 0; i < array.size(); ++i) {
auto w = std::max(weights.at(i), kRtEps);
auto u = dist(rng);
auto k = std::log(u) / w;
keys[i] = k;
}
auto ind = ArgSort<size_t>(keys, std::greater<>{});
ind.resize(n);
std::vector<T> results(ind.size());
for (size_t k = 0; k < ind.size(); ++k) {
auto idx = ind[k];
results[k] = array[idx];
}
return results;
}
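A NumPy sketch (not part of the commit) of the same Efraimidis-Spirakis key trick used by the template above: draw one uniform number per item, use log(u) / w as its key, and keep the n items with the largest keys; the 1e-6 floor stands in for ``kRtEps``:

import numpy as np

def weighted_sample_without_replacement(items, weights, n, rng=np.random):
    weights = np.maximum(np.asarray(weights, dtype=float), 1e-6)  # avoid division by zero
    keys = np.log(rng.uniform(size=len(items))) / weights
    chosen = np.argsort(-keys)[:n]  # largest keys win
    return [items[i] for i in chosen]

# Feature 3 (weight 8.0) is picked far more often than feature 0 (weight 0.1).
print(weighted_sample_without_replacement(list(range(4)), [0.1, 1.0, 2.0, 8.0], 2))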
/** /**
* \class ColumnSampler * \class ColumnSampler
* *
@ -82,36 +117,18 @@ GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
* colsample_bynode parameters. Should be initialised before tree construction and to * colsample_bynode parameters. Should be initialised before tree construction and to
* reset when tree construction is completed. * reset when tree construction is completed.
*/ */
class ColumnSampler { class ColumnSampler {
std::shared_ptr<HostDeviceVector<bst_feature_t>> feature_set_tree_; std::shared_ptr<HostDeviceVector<bst_feature_t>> feature_set_tree_;
std::map<int, std::shared_ptr<HostDeviceVector<bst_feature_t>>> feature_set_level_; std::map<int, std::shared_ptr<HostDeviceVector<bst_feature_t>>> feature_set_level_;
std::vector<float> feature_weights_;
float colsample_bylevel_{1.0f}; float colsample_bylevel_{1.0f};
float colsample_bytree_{1.0f}; float colsample_bytree_{1.0f};
float colsample_bynode_{1.0f}; float colsample_bynode_{1.0f};
GlobalRandomEngine rng_; GlobalRandomEngine rng_;
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColSample(
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features, float colsample) {
if (colsample == 1.0f) return p_features;
const auto& features = p_features->HostVector();
CHECK_GT(features.size(), 0);
int n = std::max(1, static_cast<int>(colsample * features.size()));
auto p_new_features = std::make_shared<HostDeviceVector<bst_feature_t>>();
auto& new_features = *p_new_features;
new_features.Resize(features.size());
std::copy(features.begin(), features.end(),
new_features.HostVector().begin());
std::shuffle(new_features.HostVector().begin(),
new_features.HostVector().end(), rng_);
new_features.Resize(n);
std::sort(new_features.HostVector().begin(),
new_features.HostVector().end());
return p_new_features;
}
public: public:
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColSample(
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features, float colsample);
/** /**
* \brief Column sampler constructor. * \brief Column sampler constructor.
* \note This constructor manually sets the rng seed * \note This constructor manually sets the rng seed
@ -139,8 +156,10 @@ class ColumnSampler {
* \param colsample_bytree * \param colsample_bytree
* \param skip_index_0 (Optional) True to skip index 0. * \param skip_index_0 (Optional) True to skip index 0.
*/ */
void Init(int64_t num_col, float colsample_bynode, float colsample_bylevel, void Init(int64_t num_col, std::vector<float> feature_weights,
float colsample_bynode, float colsample_bylevel,
float colsample_bytree, bool skip_index_0 = false) { float colsample_bytree, bool skip_index_0 = false) {
feature_weights_ = std::move(feature_weights);
colsample_bylevel_ = colsample_bylevel; colsample_bylevel_ = colsample_bylevel;
colsample_bytree_ = colsample_bytree; colsample_bytree_ = colsample_bytree;
colsample_bynode_ = colsample_bynode; colsample_bynode_ = colsample_bynode;

View File

@ -293,6 +293,9 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
} else { } else {
out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs); out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs);
} }
out.feature_weigths.Resize(this->feature_weigths.Size());
out.feature_weigths.Copy(this->feature_weigths);
return out; return out;
} }
@ -377,6 +380,16 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
labels.resize(num); labels.resize(num);
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, labels.begin())); std::copy(cast_dptr, cast_dptr + num, labels.begin()));
} else if (!std::strcmp(key, "feature_weights")) {
auto &h_feature_weights = feature_weigths.HostVector();
h_feature_weights.resize(num);
DISPATCH_CONST_PTR(
dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, h_feature_weights.begin()));
bool valid =
std::all_of(h_feature_weights.cbegin(), h_feature_weights.cend(),
[](float w) { return w >= 0; });
CHECK(valid) << "Feature weight must be greater than 0.";
} else { } else {
LOG(FATAL) << "Unknown key for MetaInfo: " << key; LOG(FATAL) << "Unknown key for MetaInfo: " << key;
} }
@ -396,6 +409,8 @@ void MetaInfo::GetInfo(char const *key, bst_ulong *out_len, DataType dtype,
vec = &this->labels_lower_bound_.HostVector(); vec = &this->labels_lower_bound_.HostVector();
} else if (!std::strcmp(key, "label_upper_bound")) { } else if (!std::strcmp(key, "label_upper_bound")) {
vec = &this->labels_upper_bound_.HostVector(); vec = &this->labels_upper_bound_.HostVector();
} else if (!std::strcmp(key, "feature_weights")) {
vec = &this->feature_weigths.HostVector();
} else { } else {
LOG(FATAL) << "Unknown float field name: " << key; LOG(FATAL) << "Unknown float field name: " << key;
} }
@ -497,6 +512,11 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows) {
auto &h_feature_types = feature_types.HostVector(); auto &h_feature_types = feature_types.HostVector();
LoadFeatureType(this->feature_type_names, &h_feature_types); LoadFeatureType(this->feature_type_names, &h_feature_types);
} }
if (!that.feature_weigths.Empty()) {
this->feature_weigths.Resize(that.feature_weigths.Size());
this->feature_weigths.SetDevice(that.feature_weigths.DeviceIdx());
this->feature_weigths.Copy(that.feature_weigths);
}
} }
void MetaInfo::Validate(int32_t device) const { void MetaInfo::Validate(int32_t device) const {
@ -538,6 +558,11 @@ void MetaInfo::Validate(int32_t device) const {
check_device(labels_lower_bound_); check_device(labels_lower_bound_);
return; return;
} }
if (feature_weigths.Size() != 0) {
CHECK_EQ(feature_weigths.Size(), num_col_)
<< "Size of feature_weights must equal to number of columns.";
check_device(feature_weigths);
}
if (labels_upper_bound_.Size() != 0) { if (labels_upper_bound_.Size() != 0) {
CHECK_EQ(labels_upper_bound_.Size(), num_row_) CHECK_EQ(labels_upper_bound_.Size(), num_row_)
<< "Size of label_upper_bound must equal to number of rows."; << "Size of label_upper_bound must equal to number of rows.";

View File

@ -58,6 +58,15 @@ void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
std::partial_sum(out->begin(), out->end(), out->begin()); std::partial_sum(out->begin(), out->end(), out->begin());
} }
namespace {
// thrust::all_of tries to copy the lambda, so use a named functor instead.
struct AllOfOp {
__device__ bool operator()(float w) {
return w >= 0;
}
};
} // anonymous namespace
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()}); Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()});
auto const& j_arr = get<Array>(j_interface); auto const& j_arr = get<Array>(j_interface);
@ -82,6 +91,21 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
} else if (key == "group") { } else if (key == "group") {
CopyGroupInfoImpl(array_interface, &group_ptr_); CopyGroupInfoImpl(array_interface, &group_ptr_);
return; return;
} else if (key == "label_lower_bound") {
CopyInfoImpl(array_interface, &labels_lower_bound_);
return;
} else if (key == "label_upper_bound") {
CopyInfoImpl(array_interface, &labels_upper_bound_);
return;
} else if (key == "feature_weights") {
CopyInfoImpl(array_interface, &feature_weigths);
auto d_feature_weights = feature_weigths.ConstDeviceSpan();
auto valid =
thrust::all_of(thrust::device, d_feature_weights.data(),
d_feature_weights.data() + d_feature_weights.size(),
AllOfOp{});
CHECK(valid) << "Feature weight must be greater than 0.";
return;
} else { } else {
LOG(FATAL) << "Unknown metainfo: " << key; LOG(FATAL) << "Unknown metainfo: " << key;
} }

View File

@ -235,8 +235,10 @@ class ColMaker: public TreeUpdater {
} }
} }
{ {
column_sampler_.Init(fmat.Info().num_col_, param_.colsample_bynode, column_sampler_.Init(fmat.Info().num_col_,
param_.colsample_bylevel, param_.colsample_bytree); fmat.Info().feature_weigths.ConstHostVector(),
param_.colsample_bynode, param_.colsample_bylevel,
param_.colsample_bytree);
} }
{ {
// setup temp space for each thread // setup temp space for each thread

View File

@ -266,8 +266,10 @@ struct GPUHistMakerDevice {
// Note that the column sampler must be passed by value because it is not // Note that the column sampler must be passed by value because it is not
// thread safe // thread safe
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) { void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
this->column_sampler.Init(num_columns, param.colsample_bynode, auto const& info = dmat->Info();
param.colsample_bylevel, param.colsample_bytree); this->column_sampler.Init(num_columns, info.feature_weigths.HostVector(),
param.colsample_bynode, param.colsample_bylevel,
param.colsample_bytree);
dh::safe_cuda(cudaSetDevice(device_id)); dh::safe_cuda(cudaSetDevice(device_id));
this->interaction_constraints.Reset(); this->interaction_constraints.Reset();
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),

View File

@ -841,11 +841,13 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
// store a pointer to the tree // store a pointer to the tree
p_last_tree_ = &tree; p_last_tree_ = &tree;
if (data_layout_ == kDenseDataOneBased) { if (data_layout_ == kDenseDataOneBased) {
column_sampler_.Init(info.num_col_, param_.colsample_bynode, param_.colsample_bylevel, column_sampler_.Init(info.num_col_, info.feature_weigths.ConstHostVector(),
param_.colsample_bytree, true); param_.colsample_bynode, param_.colsample_bylevel,
param_.colsample_bytree, true);
} else { } else {
column_sampler_.Init(info.num_col_, param_.colsample_bynode, param_.colsample_bylevel, column_sampler_.Init(info.num_col_, info.feature_weigths.ConstHostVector(),
param_.colsample_bytree, false); param_.colsample_bynode, param_.colsample_bylevel,
param_.colsample_bytree, false);
} }
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) { if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
/* specialized code for dense data: /* specialized code for dense data:

View File

@ -0,0 +1,13 @@
#include <gtest/gtest.h>
#include "../../../src/common/common.h"
namespace xgboost {
namespace common {
TEST(ArgSort, Basic) {
std::vector<float> inputs {3.0, 2.0, 1.0};
auto ret = ArgSort<bst_feature_t>(inputs);
std::vector<bst_feature_t> sol{2, 1, 0};
ASSERT_EQ(ret, sol);
}
} // namespace common
} // namespace xgboost

View File

@ -8,9 +8,10 @@ namespace common {
TEST(ColumnSampler, Test) { TEST(ColumnSampler, Test) {
int n = 128; int n = 128;
ColumnSampler cs; ColumnSampler cs;
std::vector<float> feature_weights;
// No node sampling // No node sampling
cs.Init(n, 1.0f, 0.5f, 0.5f); cs.Init(n, feature_weights, 1.0f, 0.5f, 0.5f);
auto set0 = cs.GetFeatureSet(0); auto set0 = cs.GetFeatureSet(0);
ASSERT_EQ(set0->Size(), 32); ASSERT_EQ(set0->Size(), 32);
@ -23,7 +24,7 @@ TEST(ColumnSampler, Test) {
ASSERT_EQ(set2->Size(), 32); ASSERT_EQ(set2->Size(), 32);
// Node sampling // Node sampling
cs.Init(n, 0.5f, 1.0f, 0.5f); cs.Init(n, feature_weights, 0.5f, 1.0f, 0.5f);
auto set3 = cs.GetFeatureSet(0); auto set3 = cs.GetFeatureSet(0);
ASSERT_EQ(set3->Size(), 32); ASSERT_EQ(set3->Size(), 32);
@ -33,19 +34,19 @@ TEST(ColumnSampler, Test) {
ASSERT_EQ(set4->Size(), 32); ASSERT_EQ(set4->Size(), 32);
// No level or node sampling, should be the same at different depth // No level or node sampling, should be the same at different depth
cs.Init(n, 1.0f, 1.0f, 0.5f); cs.Init(n, feature_weights, 1.0f, 1.0f, 0.5f);
ASSERT_EQ(cs.GetFeatureSet(0)->HostVector(), ASSERT_EQ(cs.GetFeatureSet(0)->HostVector(),
cs.GetFeatureSet(1)->HostVector()); cs.GetFeatureSet(1)->HostVector());
cs.Init(n, 1.0f, 1.0f, 1.0f); cs.Init(n, feature_weights, 1.0f, 1.0f, 1.0f);
auto set5 = cs.GetFeatureSet(0); auto set5 = cs.GetFeatureSet(0);
ASSERT_EQ(set5->Size(), n); ASSERT_EQ(set5->Size(), n);
cs.Init(n, 1.0f, 1.0f, 1.0f); cs.Init(n, feature_weights, 1.0f, 1.0f, 1.0f);
auto set6 = cs.GetFeatureSet(0); auto set6 = cs.GetFeatureSet(0);
ASSERT_EQ(set5->HostVector(), set6->HostVector()); ASSERT_EQ(set5->HostVector(), set6->HostVector());
// Should always be a minimum of one feature // Should always be a minimum of one feature
cs.Init(n, 1e-16f, 1e-16f, 1e-16f); cs.Init(n, feature_weights, 1e-16f, 1e-16f, 1e-16f);
ASSERT_EQ(cs.GetFeatureSet(0)->Size(), 1); ASSERT_EQ(cs.GetFeatureSet(0)->Size(), 1);
} }
@ -56,13 +57,13 @@ TEST(ColumnSampler, ThreadSynchronisation) {
size_t iterations = 10; size_t iterations = 10;
size_t levels = 5; size_t levels = 5;
std::vector<bst_feature_t> reference_result; std::vector<bst_feature_t> reference_result;
bool success = std::vector<float> feature_weights;
true; // Cannot use google test asserts in multithreaded region bool success = true; // Cannot use google test asserts in multithreaded region
#pragma omp parallel num_threads(num_threads) #pragma omp parallel num_threads(num_threads)
{ {
for (auto j = 0ull; j < iterations; j++) { for (auto j = 0ull; j < iterations; j++) {
ColumnSampler cs(j); ColumnSampler cs(j);
cs.Init(n, 0.5f, 0.5f, 0.5f); cs.Init(n, feature_weights, 0.5f, 0.5f, 0.5f);
for (auto level = 0ull; level < levels; level++) { for (auto level = 0ull; level < levels; level++) {
auto result = cs.GetFeatureSet(level)->ConstHostVector(); auto result = cs.GetFeatureSet(level)->ConstHostVector();
#pragma omp single #pragma omp single
@ -76,5 +77,54 @@ TEST(ColumnSampler, ThreadSynchronisation) {
} }
ASSERT_TRUE(success); ASSERT_TRUE(success);
} }
TEST(ColumnSampler, WeightedSampling) {
auto test_basic = [](int first) {
std::vector<float> feature_weights(2);
feature_weights[0] = std::abs(first - 1.0f);
feature_weights[1] = first - 0.0f;
ColumnSampler cs{0};
cs.Init(2, feature_weights, 1.0, 1.0, 0.5);
auto feature_sets = cs.GetFeatureSet(0);
auto const &h_feat_set = feature_sets->HostVector();
ASSERT_EQ(h_feat_set.size(), 1);
ASSERT_EQ(h_feat_set[0], first - 0);
};
test_basic(0);
test_basic(1);
size_t constexpr kCols = 64;
std::vector<float> feature_weights(kCols);
SimpleLCG rng;
SimpleRealUniformDistribution<float> dist(.0f, 12.0f);
std::generate(feature_weights.begin(), feature_weights.end(), [&]() { return dist(&rng); });
ColumnSampler cs{0};
cs.Init(kCols, feature_weights, 0.5f, 1.0f, 1.0f);
std::vector<bst_feature_t> features(kCols);
std::iota(features.begin(), features.end(), 0);
std::vector<float> freq(kCols, 0);
for (size_t i = 0; i < 1024; ++i) {
auto fset = cs.GetFeatureSet(0);
ASSERT_EQ(kCols * 0.5, fset->Size());
auto const& h_fset = fset->HostVector();
for (auto f : h_fset) {
freq[f] += 1.0f;
}
}
auto norm = std::accumulate(freq.cbegin(), freq.cend(), .0f);
for (auto& f : freq) {
f /= norm;
}
norm = std::accumulate(feature_weights.cbegin(), feature_weights.cend(), .0f);
for (auto& f : feature_weights) {
f /= norm;
}
for (size_t i = 0; i < feature_weights.size(); ++i) {
EXPECT_NEAR(freq[i], feature_weights[i], 1e-2);
}
}
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost

View File

@ -204,12 +204,11 @@ TEST(GpuHist, EvaluateRootSplit) {
ASSERT_EQ(maker.hist.Data().size(), hist.size()); ASSERT_EQ(maker.hist.Data().size(), hist.size());
thrust::copy(hist.begin(), hist.end(), thrust::copy(hist.begin(), hist.end(),
maker.hist.Data().begin()); maker.hist.Data().begin());
std::vector<float> feature_weights;
maker.column_sampler.Init(kNCols, maker.column_sampler.Init(kNCols, feature_weights, param.colsample_bynode,
param.colsample_bynode, param.colsample_bylevel, param.colsample_bytree,
param.colsample_bylevel, false);
param.colsample_bytree,
false);
RegTree tree; RegTree tree;
MetaInfo info; MetaInfo info;

View File

@ -16,6 +16,20 @@ class TestDeviceQuantileDMatrix(unittest.TestCase):
match='is not supported for DeviceQuantileDMatrix'): match='is not supported for DeviceQuantileDMatrix'):
xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64)) xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64))
@pytest.mark.skipif(**tm.no_cupy())
def test_dmatrix_feature_weights(self):
import cupy as cp
rng = cp.random.RandomState(1994)
data = rng.randn(5, 5)
m = xgb.DMatrix(data)
feature_weights = rng.uniform(size=5)
m.set_info(feature_weights=feature_weights)
cp.testing.assert_array_equal(
cp.array(m.get_float_info('feature_weights')),
feature_weights.astype(np.float32))
@pytest.mark.skipif(**tm.no_cupy()) @pytest.mark.skipif(**tm.no_cupy())
def test_dmatrix_cupy_init(self): def test_dmatrix_cupy_init(self):
import cupy as cp import cupy as cp

View File

@ -1,12 +1,10 @@
import os import os
import subprocess import subprocess
import sys
import pytest import pytest
import testing as tm import testing as tm
CURRENT_DIR = os.path.dirname(__file__) ROOT_DIR = tm.PROJECT_ROOT
ROOT_DIR = os.path.dirname(os.path.dirname(CURRENT_DIR))
DEMO_DIR = os.path.join(ROOT_DIR, 'demo') DEMO_DIR = os.path.join(ROOT_DIR, 'demo')
PYTHON_DEMO_DIR = os.path.join(DEMO_DIR, 'guide-python') PYTHON_DEMO_DIR = os.path.join(DEMO_DIR, 'guide-python')
@ -19,21 +17,27 @@ def test_basic_walkthrough():
os.remove('dump.raw.txt') os.remove('dump.raw.txt')
@pytest.mark.skipif(**tm.no_matplotlib())
def test_custom_multiclass_objective(): def test_custom_multiclass_objective():
script = os.path.join(PYTHON_DEMO_DIR, 'custom_softmax.py') script = os.path.join(PYTHON_DEMO_DIR, 'custom_softmax.py')
cmd = ['python', script, '--plot=0'] cmd = ['python', script, '--plot=0']
subprocess.check_call(cmd) subprocess.check_call(cmd)
@pytest.mark.skipif(**tm.no_matplotlib())
def test_custom_rmsle_objective(): def test_custom_rmsle_objective():
major, minor = sys.version_info[:2]
if minor < 6:
pytest.skip('Skipping RMLSE test due to Python version being too low.')
script = os.path.join(PYTHON_DEMO_DIR, 'custom_rmsle.py') script = os.path.join(PYTHON_DEMO_DIR, 'custom_rmsle.py')
cmd = ['python', script, '--plot=0'] cmd = ['python', script, '--plot=0']
subprocess.check_call(cmd) subprocess.check_call(cmd)
@pytest.mark.skipif(**tm.no_matplotlib())
def test_feature_weights_demo():
script = os.path.join(PYTHON_DEMO_DIR, 'feature_weights.py')
cmd = ['python', script, '--plot=0']
subprocess.check_call(cmd)
@pytest.mark.skipif(**tm.no_sklearn()) @pytest.mark.skipif(**tm.no_sklearn())
def test_sklearn_demo(): def test_sklearn_demo():
script = os.path.join(PYTHON_DEMO_DIR, 'sklearn_examples.py') script = os.path.join(PYTHON_DEMO_DIR, 'sklearn_examples.py')

View File

@ -99,6 +99,11 @@ class TestDMatrix(unittest.TestCase):
X = rng.randn(100, 100) X = rng.randn(100, 100)
y = rng.randint(low=0, high=3, size=100) y = rng.randint(low=0, high=3, size=100)
d = xgb.DMatrix(X, y) d = xgb.DMatrix(X, y)
np.testing.assert_equal(d.get_label(), y.astype(np.float32))
fw = rng.uniform(size=100).astype(np.float32)
d.set_info(feature_weights=fw)
eval_res_0 = {} eval_res_0 = {}
booster = xgb.train( booster = xgb.train(
{'num_class': 3, 'objective': 'multi:softprob'}, d, {'num_class': 3, 'objective': 'multi:softprob'}, d,
@ -106,19 +111,23 @@ class TestDMatrix(unittest.TestCase):
predt = booster.predict(d) predt = booster.predict(d)
predt = predt.reshape(100 * 3, 1) predt = predt.reshape(100 * 3, 1)
d.set_base_margin(predt) d.set_base_margin(predt)
ridxs = [1, 2, 3, 4, 5, 6] ridxs = [1, 2, 3, 4, 5, 6]
d = d.slice(ridxs) sliced = d.slice(ridxs)
sliced_margin = d.get_float_info('base_margin')
sliced_margin = sliced.get_float_info('base_margin')
assert sliced_margin.shape[0] == len(ridxs) * 3 assert sliced_margin.shape[0] == len(ridxs) * 3
eval_res_1 = {} eval_res_1 = {}
xgb.train({'num_class': 3, 'objective': 'multi:softprob'}, d, xgb.train({'num_class': 3, 'objective': 'multi:softprob'}, sliced,
num_boost_round=2, evals=[(d, 'd')], evals_result=eval_res_1) num_boost_round=2, evals=[(sliced, 'd')],
evals_result=eval_res_1)
eval_res_0 = eval_res_0['d']['merror'] eval_res_0 = eval_res_0['d']['merror']
eval_res_1 = eval_res_1['d']['merror'] eval_res_1 = eval_res_1['d']['merror']
for i in range(len(eval_res_0)): for i in range(len(eval_res_0)):
assert abs(eval_res_0[i] - eval_res_1[i]) < 0.02 assert abs(eval_res_0[i] - eval_res_1[i]) < 0.02
@ -196,13 +205,33 @@ class TestDMatrix(unittest.TestCase):
dtrain.get_float_info('base_margin') dtrain.get_float_info('base_margin')
dtrain.get_uint_info('group_ptr') dtrain.get_uint_info('group_ptr')
def test_feature_weights(self):
kRows = 10
kCols = 50
rng = np.random.RandomState(1994)
fw = rng.uniform(size=kCols)
X = rng.randn(kRows, kCols)
m = xgb.DMatrix(X)
m.set_info(feature_weights=fw)
np.testing.assert_allclose(fw, m.get_float_info('feature_weights'))
# Handle empty
m.set_info(feature_weights=np.empty((0, 0)))
assert m.get_float_info('feature_weights').shape[0] == 0
fw -= 1
def assign_weight():
m.set_info(feature_weights=fw)
self.assertRaises(ValueError, assign_weight)
def test_sparse_dmatrix_csr(self): def test_sparse_dmatrix_csr(self):
nrow = 100 nrow = 100
ncol = 1000 ncol = 1000
x = rand(nrow, ncol, density=0.0005, format='csr', random_state=rng) x = rand(nrow, ncol, density=0.0005, format='csr', random_state=rng)
assert x.indices.max() < ncol - 1 assert x.indices.max() < ncol - 1
x.data[:] = 1 x.data[:] = 1
dtrain = xgb.DMatrix(x, label=np.random.binomial(1, 0.3, nrow)) dtrain = xgb.DMatrix(x, label=rng.binomial(1, 0.3, nrow))
assert (dtrain.num_row(), dtrain.num_col()) == (nrow, ncol) assert (dtrain.num_row(), dtrain.num_col()) == (nrow, ncol)
watchlist = [(dtrain, 'train')] watchlist = [(dtrain, 'train')]
param = {'max_depth': 3, 'objective': 'binary:logistic', 'verbosity': 0} param = {'max_depth': 3, 'objective': 'binary:logistic', 'verbosity': 0}
@ -215,7 +244,7 @@ class TestDMatrix(unittest.TestCase):
x = rand(nrow, ncol, density=0.0005, format='csc', random_state=rng) x = rand(nrow, ncol, density=0.0005, format='csc', random_state=rng)
assert x.indices.max() < nrow - 1 assert x.indices.max() < nrow - 1
x.data[:] = 1 x.data[:] = 1
dtrain = xgb.DMatrix(x, label=np.random.binomial(1, 0.3, nrow)) dtrain = xgb.DMatrix(x, label=rng.binomial(1, 0.3, nrow))
assert (dtrain.num_row(), dtrain.num_col()) == (nrow, ncol) assert (dtrain.num_row(), dtrain.num_col()) == (nrow, ncol)
watchlist = [(dtrain, 'train')] watchlist = [(dtrain, 'train')]
param = {'max_depth': 3, 'objective': 'binary:logistic', 'verbosity': 0} param = {'max_depth': 3, 'objective': 'binary:logistic', 'verbosity': 0}

View File

@ -1,3 +1,5 @@
import collections
import importlib.util
import numpy as np import numpy as np
import xgboost as xgb import xgboost as xgb
from xgboost.sklearn import XGBoostLabelEncoder from xgboost.sklearn import XGBoostLabelEncoder
@ -654,6 +656,7 @@ def test_validation_weights_xgbmodel():
eval_set=[(X_train, y_train), (X_test, y_test)], eval_set=[(X_train, y_train), (X_test, y_test)],
sample_weight_eval_set=[weights_train]) sample_weight_eval_set=[weights_train])
def test_validation_weights_xgbclassifier(): def test_validation_weights_xgbclassifier():
from sklearn.datasets import make_hastie_10_2 from sklearn.datasets import make_hastie_10_2
@ -920,6 +923,64 @@ def test_pandas_input():
np.array([0, 1])) np.array([0, 1]))
def run_feature_weights(increasing):
with TemporaryDirectory() as tmpdir:
kRows = 512
kCols = 64
colsample_bynode = 0.5
reg = xgb.XGBRegressor(tree_method='hist',
colsample_bynode=colsample_bynode)
X = rng.randn(kRows, kCols)
y = rng.randn(kRows)
fw = np.ones(shape=(kCols,))
for i in range(kCols):
if increasing:
fw[i] *= float(i)
else:
fw[i] *= float(kCols - i)
reg.fit(X, y, feature_weights=fw)
model_path = os.path.join(tmpdir, 'model.json')
reg.save_model(model_path)
with open(model_path) as fd:
model = json.load(fd)
parser_path = os.path.join(tm.PROJECT_ROOT, 'demo', 'json-model',
'json_parser.py')
spec = importlib.util.spec_from_file_location("JsonParser",
parser_path)
foo = importlib.util.module_from_spec(spec)
spec.loader.exec_module(foo)
model = foo.Model(model)
splits = {}
total_nodes = 0
for tree in model.trees:
n_nodes = len(tree.nodes)
total_nodes += n_nodes
for n in range(n_nodes):
if tree.is_leaf(n):
continue
if splits.get(tree.split_index(n), None) is None:
splits[tree.split_index(n)] = 1
else:
splits[tree.split_index(n)] += 1
od = collections.OrderedDict(sorted(splits.items()))
tuples = [(k, v) for k, v in od.items()]
k, v = list(zip(*tuples))
w = np.polyfit(k, v, deg=1)
return w
def test_feature_weights():
poly_increasing = run_feature_weights(True)
poly_decreasing = run_feature_weights(False)
# Approximate test; this is dependent on the implementation of the random
# number generator in the standard library.
assert poly_increasing[0] > 0.08
assert poly_decreasing[0] < -0.08
class TestBoostFromPrediction(unittest.TestCase): class TestBoostFromPrediction(unittest.TestCase):
def run_boost_from_prediction(self, tree_method): def run_boost_from_prediction(self, tree_method):
from sklearn.datasets import load_breast_cancer from sklearn.datasets import load_breast_cancer