Feature weights (#5962)
This commit is contained in:
parent
a418278064
commit
4d99c58a5f
@ -68,6 +68,7 @@
|
||||
#include "../src/learner.cc"
|
||||
#include "../src/logging.cc"
|
||||
#include "../src/common/common.cc"
|
||||
#include "../src/common/random.cc"
|
||||
#include "../src/common/charconv.cc"
|
||||
#include "../src/common/timer.cc"
|
||||
#include "../src/common/quantile.cc"
|
||||
|
||||
49
demo/guide-python/feature_weights.py
Normal file
49
demo/guide-python/feature_weights.py
Normal file
@ -0,0 +1,49 @@
|
||||
'''Using feature weight to change column sampling.
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
'''
|
||||
|
||||
import numpy as np
|
||||
import xgboost
|
||||
from matplotlib import pyplot as plt
|
||||
import argparse
|
||||
|
||||
|
||||
def main(args):
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
kRows = 1000
|
||||
kCols = 10
|
||||
|
||||
X = rng.randn(kRows, kCols)
|
||||
y = rng.randn(kRows)
|
||||
fw = np.ones(shape=(kCols,))
|
||||
for i in range(kCols):
|
||||
fw[i] *= float(i)
|
||||
|
||||
dtrain = xgboost.DMatrix(X, y)
|
||||
dtrain.set_info(feature_weights=fw)
|
||||
|
||||
bst = xgboost.train({'tree_method': 'hist',
|
||||
'colsample_bynode': 0.5},
|
||||
dtrain, num_boost_round=10,
|
||||
evals=[(dtrain, 'd')])
|
||||
featue_map = bst.get_fscore()
|
||||
# feature zero has 0 weight
|
||||
assert featue_map.get('f0', None) is None
|
||||
assert max(featue_map.values()) == featue_map.get('f9')
|
||||
|
||||
if args.plot:
|
||||
xgboost.plot_importance(bst)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--plot',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Set to 0 to disable plotting the evaluation history.')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@ -94,7 +94,7 @@ class Tree:
|
||||
|
||||
class Model:
|
||||
'''Gradient boosted tree model.'''
|
||||
def __init__(self, m: dict):
|
||||
def __init__(self, model: dict):
|
||||
'''Construct the Model from JSON object.
|
||||
|
||||
parameters
|
||||
|
||||
@ -107,6 +107,10 @@ Parameters for Tree Booster
|
||||
'colsample_bynode':0.5}`` with 64 features will leave 8 features to choose from at
|
||||
each split.
|
||||
|
||||
On Python interface, one can set the ``feature_weights`` for DMatrix to define the
|
||||
probability of each feature being selected when using column sampling. There's a
|
||||
similar parameter for ``fit`` method in sklearn interface.
|
||||
|
||||
* ``lambda`` [default=1, alias: ``reg_lambda``]
|
||||
|
||||
- L2 regularization term on weights. Increasing this value will make model more conservative.
|
||||
@ -224,7 +228,7 @@ Parameters for Tree Booster
|
||||
list is a group of indices of features that are allowed to interact with each other.
|
||||
See tutorial for more information
|
||||
|
||||
Additional parameters for ``hist`` and ```gpu_hist`` tree method
|
||||
Additional parameters for ``hist`` and ``gpu_hist`` tree method
|
||||
================================================================
|
||||
|
||||
* ``single_precision_histogram``, [default=``false``]
|
||||
|
||||
@ -483,6 +483,34 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
|
||||
bst_ulong *size,
|
||||
const char ***out_features);
|
||||
|
||||
/*!
|
||||
* \brief Set meta info from dense matrix. Valid field names are:
|
||||
*
|
||||
* - label
|
||||
* - weight
|
||||
* - base_margin
|
||||
* - group
|
||||
* - label_lower_bound
|
||||
* - label_upper_bound
|
||||
* - feature_weights
|
||||
*
|
||||
* \param handle An instance of data matrix
|
||||
* \param field Feild name
|
||||
* \param data Pointer to consecutive memory storing data.
|
||||
* \param size Size of the data, this is relative to size of type. (Meaning NOT number
|
||||
* of bytes.)
|
||||
* \param type Indicator of data type. This is defined in xgboost::DataType enum class.
|
||||
*
|
||||
* float = 1
|
||||
* double = 2
|
||||
* uint32_t = 3
|
||||
* uint64_t = 4
|
||||
*
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
|
||||
void *data, bst_ulong size, int type);
|
||||
|
||||
/*!
|
||||
* \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix
|
||||
* \param handle a instance of data matrix
|
||||
|
||||
@ -88,34 +88,17 @@ class MetaInfo {
|
||||
* \brief Type of each feature. Automatically set when feature_type_names is specifed.
|
||||
*/
|
||||
HostDeviceVector<FeatureType> feature_types;
|
||||
/*
|
||||
* \brief Weight of each feature, used to define the probability of each feature being
|
||||
* selected when using column sampling.
|
||||
*/
|
||||
HostDeviceVector<float> feature_weigths;
|
||||
|
||||
/*! \brief default constructor */
|
||||
MetaInfo() = default;
|
||||
MetaInfo(MetaInfo&& that) = default;
|
||||
MetaInfo& operator=(MetaInfo&& that) = default;
|
||||
MetaInfo& operator=(MetaInfo const& that) {
|
||||
this->num_row_ = that.num_row_;
|
||||
this->num_col_ = that.num_col_;
|
||||
this->num_nonzero_ = that.num_nonzero_;
|
||||
|
||||
this->labels_.Resize(that.labels_.Size());
|
||||
this->labels_.Copy(that.labels_);
|
||||
|
||||
this->group_ptr_ = that.group_ptr_;
|
||||
|
||||
this->weights_.Resize(that.weights_.Size());
|
||||
this->weights_.Copy(that.weights_);
|
||||
|
||||
this->base_margin_.Resize(that.base_margin_.Size());
|
||||
this->base_margin_.Copy(that.base_margin_);
|
||||
|
||||
this->labels_lower_bound_.Resize(that.labels_lower_bound_.Size());
|
||||
this->labels_lower_bound_.Copy(that.labels_lower_bound_);
|
||||
|
||||
this->labels_upper_bound_.Resize(that.labels_upper_bound_.Size());
|
||||
this->labels_upper_bound_.Copy(that.labels_upper_bound_);
|
||||
return *this;
|
||||
}
|
||||
MetaInfo& operator=(MetaInfo const& that) = delete;
|
||||
|
||||
/*!
|
||||
* \brief Validate all metainfo.
|
||||
|
||||
@ -455,7 +455,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
label_lower_bound=None,
|
||||
label_upper_bound=None,
|
||||
feature_names=None,
|
||||
feature_types=None):
|
||||
feature_types=None,
|
||||
feature_weights=None):
|
||||
'''Set meta info for DMatrix.'''
|
||||
if label is not None:
|
||||
self.set_label(label)
|
||||
@ -473,6 +474,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
self.feature_names = feature_names
|
||||
if feature_types is not None:
|
||||
self.feature_types = feature_types
|
||||
if feature_weights is not None:
|
||||
from .data import dispatch_meta_backend
|
||||
dispatch_meta_backend(matrix=self, data=feature_weights,
|
||||
name='feature_weights')
|
||||
|
||||
def get_float_info(self, field):
|
||||
"""Get float property from the DMatrix.
|
||||
|
||||
@ -530,22 +530,38 @@ def dispatch_data_backend(data, missing, threads,
|
||||
raise TypeError('Not supported type for data.' + str(type(data)))
|
||||
|
||||
|
||||
def _to_data_type(dtype: str, name: str):
|
||||
dtype_map = {'float32': 1, 'float64': 2, 'uint32': 3, 'uint64': 4}
|
||||
if dtype not in dtype_map.keys():
|
||||
raise TypeError(
|
||||
f'Expecting float32, float64, uint32, uint64, got {dtype} ' +
|
||||
f'for {name}.')
|
||||
return dtype_map[dtype]
|
||||
|
||||
|
||||
def _validate_meta_shape(data):
|
||||
if hasattr(data, 'shape'):
|
||||
assert len(data.shape) == 1 or (
|
||||
len(data.shape) == 2 and
|
||||
(data.shape[1] == 0 or data.shape[1] == 1))
|
||||
|
||||
|
||||
def _meta_from_numpy(data, field, dtype, handle):
|
||||
data = _maybe_np_slice(data, dtype)
|
||||
if dtype == 'uint32':
|
||||
c_data = c_array(ctypes.c_uint32, data)
|
||||
_check_call(_LIB.XGDMatrixSetUIntInfo(handle,
|
||||
c_str(field),
|
||||
c_array(ctypes.c_uint, data),
|
||||
c_bst_ulong(len(data))))
|
||||
elif dtype == 'float':
|
||||
c_data = c_array(ctypes.c_float, data)
|
||||
_check_call(_LIB.XGDMatrixSetFloatInfo(handle,
|
||||
c_str(field),
|
||||
c_data,
|
||||
c_bst_ulong(len(data))))
|
||||
else:
|
||||
raise TypeError('Unsupported type ' + str(dtype) + ' for:' + field)
|
||||
interface = data.__array_interface__
|
||||
assert interface.get('mask', None) is None, 'Masked array is not supported'
|
||||
size = data.shape[0]
|
||||
|
||||
c_type = _to_data_type(str(data.dtype), field)
|
||||
ptr = interface['data'][0]
|
||||
ptr = ctypes.c_void_p(ptr)
|
||||
_check_call(_LIB.XGDMatrixSetDenseInfo(
|
||||
handle,
|
||||
c_str(field),
|
||||
ptr,
|
||||
c_bst_ulong(size),
|
||||
c_type
|
||||
))
|
||||
|
||||
|
||||
def _meta_from_list(data, field, dtype, handle):
|
||||
@ -595,6 +611,7 @@ def _meta_from_dt(data, field, dtype, handle):
|
||||
def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
|
||||
'''Dispatch for meta info.'''
|
||||
handle = matrix.handle
|
||||
_validate_meta_shape(data)
|
||||
if data is None:
|
||||
return
|
||||
if _is_list(data):
|
||||
|
||||
@ -441,6 +441,7 @@ class XGBModel(XGBModelBase):
|
||||
def fit(self, X, y, sample_weight=None, base_margin=None,
|
||||
eval_set=None, eval_metric=None, early_stopping_rounds=None,
|
||||
verbose=True, xgb_model=None, sample_weight_eval_set=None,
|
||||
feature_weights=None,
|
||||
callbacks=None):
|
||||
# pylint: disable=invalid-name,attribute-defined-outside-init
|
||||
"""Fit gradient boosting model
|
||||
@ -459,9 +460,6 @@ class XGBModel(XGBModelBase):
|
||||
A list of (X, y) tuple pairs to use as validation sets, for which
|
||||
metrics will be computed.
|
||||
Validation metrics will help us track the performance of the model.
|
||||
sample_weight_eval_set : list, optional
|
||||
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
|
||||
instance weights on the i-th validation set.
|
||||
eval_metric : str, list of str, or callable, optional
|
||||
If a str, should be a built-in evaluation metric to use. See
|
||||
doc/parameter.rst.
|
||||
@ -490,6 +488,13 @@ class XGBModel(XGBModelBase):
|
||||
xgb_model : str
|
||||
file name of stored XGBoost model or 'Booster' instance XGBoost model to be
|
||||
loaded before training (allows training continuation).
|
||||
sample_weight_eval_set : list, optional
|
||||
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
|
||||
instance weights on the i-th validation set.
|
||||
feature_weights: array_like
|
||||
Weight for each feature, defines the probability of each feature
|
||||
being selected when colsample is being used. All values must be
|
||||
greater than 0, otherwise a `ValueError` is thrown.
|
||||
callbacks : list of callback functions
|
||||
List of callback functions that are applied at end of each iteration.
|
||||
It is possible to use predefined callbacks by using :ref:`callback_api`.
|
||||
@ -498,6 +503,7 @@ class XGBModel(XGBModelBase):
|
||||
.. code-block:: python
|
||||
|
||||
[xgb.callback.reset_learning_rate(custom_rates)]
|
||||
|
||||
"""
|
||||
self.n_features_in_ = X.shape[1]
|
||||
|
||||
@ -505,6 +511,7 @@ class XGBModel(XGBModelBase):
|
||||
base_margin=base_margin,
|
||||
missing=self.missing,
|
||||
nthread=self.n_jobs)
|
||||
train_dmatrix.set_info(feature_weights=feature_weights)
|
||||
|
||||
evals_result = {}
|
||||
|
||||
@ -759,7 +766,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
def fit(self, X, y, sample_weight=None, base_margin=None,
|
||||
eval_set=None, eval_metric=None,
|
||||
early_stopping_rounds=None, verbose=True, xgb_model=None,
|
||||
sample_weight_eval_set=None, callbacks=None):
|
||||
sample_weight_eval_set=None, feature_weights=None, callbacks=None):
|
||||
# pylint: disable = attribute-defined-outside-init,arguments-differ
|
||||
|
||||
evals_result = {}
|
||||
@ -821,6 +828,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
missing=self.missing, nthread=self.n_jobs)
|
||||
train_dmatrix.set_info(feature_weights=feature_weights)
|
||||
|
||||
self._Booster = train(xgb_options, train_dmatrix,
|
||||
self.get_num_boosting_rounds(),
|
||||
@ -1101,10 +1109,10 @@ class XGBRanker(XGBModel):
|
||||
raise ValueError("please use XGBRanker for ranking task")
|
||||
|
||||
def fit(self, X, y, group, sample_weight=None, base_margin=None,
|
||||
eval_set=None,
|
||||
sample_weight_eval_set=None, eval_group=None, eval_metric=None,
|
||||
eval_set=None, sample_weight_eval_set=None,
|
||||
eval_group=None, eval_metric=None,
|
||||
early_stopping_rounds=None, verbose=False, xgb_model=None,
|
||||
callbacks=None):
|
||||
feature_weights=None, callbacks=None):
|
||||
# pylint: disable = attribute-defined-outside-init,arguments-differ
|
||||
"""Fit gradient boosting ranker
|
||||
|
||||
@ -1170,6 +1178,10 @@ class XGBRanker(XGBModel):
|
||||
xgb_model : str
|
||||
file name of stored XGBoost model or 'Booster' instance XGBoost
|
||||
model to be loaded before training (allows training continuation).
|
||||
feature_weights: array_like
|
||||
Weight for each feature, defines the probability of each feature
|
||||
being selected when colsample is being used. All values must be
|
||||
greater than 0, otherwise a `ValueError` is thrown.
|
||||
callbacks : list of callback functions
|
||||
List of callback functions that are applied at end of each
|
||||
iteration. It is possible to use predefined callbacks by using
|
||||
@ -1205,6 +1217,7 @@ class XGBRanker(XGBModel):
|
||||
train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
missing=self.missing, nthread=self.n_jobs)
|
||||
train_dmatrix.set_info(feature_weights=feature_weights)
|
||||
train_dmatrix.set_group(group)
|
||||
|
||||
evals_result = {}
|
||||
|
||||
@ -316,6 +316,17 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
|
||||
void *data, xgboost::bst_ulong size,
|
||||
int type) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
|
||||
CHECK(type >= 1 && type <= 4);
|
||||
info.SetInfo(field, data, static_cast<DataType>(type), size);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
|
||||
const unsigned* group,
|
||||
xgboost::bst_ulong len) {
|
||||
|
||||
@ -9,12 +9,15 @@
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <exception>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <numeric>
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
#include <thrust/system/cuda/error.h>
|
||||
@ -160,6 +163,15 @@ inline void AssertOneAPISupport() {
|
||||
#endif // XGBOOST_USE_ONEAPI
|
||||
}
|
||||
|
||||
template <typename Idx, typename V, typename Comp = std::less<V>>
|
||||
std::vector<Idx> ArgSort(std::vector<V> const &array, Comp comp = std::less<V>{}) {
|
||||
std::vector<Idx> result(array.size());
|
||||
std::iota(result.begin(), result.end(), 0);
|
||||
std::stable_sort(
|
||||
result.begin(), result.end(),
|
||||
[&array, comp](Idx const &l, Idx const &r) { return comp(array[l], array[r]); });
|
||||
return result;
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_COMMON_H_
|
||||
|
||||
38
src/common/random.cc
Normal file
38
src/common/random.cc
Normal file
@ -0,0 +1,38 @@
|
||||
/*!
|
||||
* Copyright 2020 by XGBoost Contributors
|
||||
* \file random.cc
|
||||
*/
|
||||
#include "random.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features,
|
||||
float colsample) {
|
||||
if (colsample == 1.0f) {
|
||||
return p_features;
|
||||
}
|
||||
const auto &features = p_features->HostVector();
|
||||
CHECK_GT(features.size(), 0);
|
||||
|
||||
int n = std::max(1, static_cast<int>(colsample * features.size()));
|
||||
auto p_new_features = std::make_shared<HostDeviceVector<bst_feature_t>>();
|
||||
auto &new_features = *p_new_features;
|
||||
|
||||
if (feature_weights_.size() != 0) {
|
||||
new_features.HostVector() = WeightedSamplingWithoutReplacement(
|
||||
p_features->HostVector(), feature_weights_, n);
|
||||
} else {
|
||||
new_features.Resize(features.size());
|
||||
std::copy(features.begin(), features.end(),
|
||||
new_features.HostVector().begin());
|
||||
std::shuffle(new_features.HostVector().begin(),
|
||||
new_features.HostVector().end(), rng_);
|
||||
new_features.Resize(n);
|
||||
}
|
||||
std::sort(new_features.HostVector().begin(), new_features.HostVector().end());
|
||||
return p_new_features;
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* Copyright 2015-2020 by Contributors
|
||||
* \file random.h
|
||||
* \brief Utility related to random.
|
||||
* \author Tianqi Chen
|
||||
@ -10,14 +10,17 @@
|
||||
#include <rabit/rabit.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
#include <utility>
|
||||
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "common.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
@ -75,6 +78,38 @@ using GlobalRandomEngine = RandomEngine;
|
||||
*/
|
||||
GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
|
||||
|
||||
/*
|
||||
* Original paper:
|
||||
* Weighted Random Sampling (2005; Efraimidis, Spirakis)
|
||||
*
|
||||
* Blog:
|
||||
* https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
|
||||
*/
|
||||
template <typename T>
|
||||
std::vector<T> WeightedSamplingWithoutReplacement(
|
||||
std::vector<T> const &array, std::vector<float> const &weights, size_t n) {
|
||||
// ES sampling.
|
||||
CHECK_EQ(array.size(), weights.size());
|
||||
std::vector<float> keys(weights.size());
|
||||
std::uniform_real_distribution<float> dist;
|
||||
auto& rng = GlobalRandom();
|
||||
for (size_t i = 0; i < array.size(); ++i) {
|
||||
auto w = std::max(weights.at(i), kRtEps);
|
||||
auto u = dist(rng);
|
||||
auto k = std::log(u) / w;
|
||||
keys[i] = k;
|
||||
}
|
||||
auto ind = ArgSort<size_t>(keys, std::greater<>{});
|
||||
ind.resize(n);
|
||||
|
||||
std::vector<T> results(ind.size());
|
||||
for (size_t k = 0; k < ind.size(); ++k) {
|
||||
auto idx = ind[k];
|
||||
results[k] = array[idx];
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
/**
|
||||
* \class ColumnSampler
|
||||
*
|
||||
@ -82,36 +117,18 @@ GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
|
||||
* colsample_bynode parameters. Should be initialised before tree construction and to
|
||||
* reset when tree construction is completed.
|
||||
*/
|
||||
|
||||
class ColumnSampler {
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> feature_set_tree_;
|
||||
std::map<int, std::shared_ptr<HostDeviceVector<bst_feature_t>>> feature_set_level_;
|
||||
std::vector<float> feature_weights_;
|
||||
float colsample_bylevel_{1.0f};
|
||||
float colsample_bytree_{1.0f};
|
||||
float colsample_bynode_{1.0f};
|
||||
GlobalRandomEngine rng_;
|
||||
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColSample(
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features, float colsample) {
|
||||
if (colsample == 1.0f) return p_features;
|
||||
const auto& features = p_features->HostVector();
|
||||
CHECK_GT(features.size(), 0);
|
||||
int n = std::max(1, static_cast<int>(colsample * features.size()));
|
||||
auto p_new_features = std::make_shared<HostDeviceVector<bst_feature_t>>();
|
||||
auto& new_features = *p_new_features;
|
||||
new_features.Resize(features.size());
|
||||
std::copy(features.begin(), features.end(),
|
||||
new_features.HostVector().begin());
|
||||
std::shuffle(new_features.HostVector().begin(),
|
||||
new_features.HostVector().end(), rng_);
|
||||
new_features.Resize(n);
|
||||
std::sort(new_features.HostVector().begin(),
|
||||
new_features.HostVector().end());
|
||||
|
||||
return p_new_features;
|
||||
}
|
||||
|
||||
public:
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColSample(
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features, float colsample);
|
||||
/**
|
||||
* \brief Column sampler constructor.
|
||||
* \note This constructor manually sets the rng seed
|
||||
@ -139,8 +156,10 @@ class ColumnSampler {
|
||||
* \param colsample_bytree
|
||||
* \param skip_index_0 (Optional) True to skip index 0.
|
||||
*/
|
||||
void Init(int64_t num_col, float colsample_bynode, float colsample_bylevel,
|
||||
void Init(int64_t num_col, std::vector<float> feature_weights,
|
||||
float colsample_bynode, float colsample_bylevel,
|
||||
float colsample_bytree, bool skip_index_0 = false) {
|
||||
feature_weights_ = std::move(feature_weights);
|
||||
colsample_bylevel_ = colsample_bylevel;
|
||||
colsample_bytree_ = colsample_bytree;
|
||||
colsample_bynode_ = colsample_bynode;
|
||||
|
||||
@ -293,6 +293,9 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
|
||||
} else {
|
||||
out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs);
|
||||
}
|
||||
|
||||
out.feature_weigths.Resize(this->feature_weigths.Size());
|
||||
out.feature_weigths.Copy(this->feature_weigths);
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -377,6 +380,16 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
|
||||
labels.resize(num);
|
||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, labels.begin()));
|
||||
} else if (!std::strcmp(key, "feature_weights")) {
|
||||
auto &h_feature_weights = feature_weigths.HostVector();
|
||||
h_feature_weights.resize(num);
|
||||
DISPATCH_CONST_PTR(
|
||||
dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, h_feature_weights.begin()));
|
||||
bool valid =
|
||||
std::all_of(h_feature_weights.cbegin(), h_feature_weights.cend(),
|
||||
[](float w) { return w >= 0; });
|
||||
CHECK(valid) << "Feature weight must be greater than 0.";
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown key for MetaInfo: " << key;
|
||||
}
|
||||
@ -396,6 +409,8 @@ void MetaInfo::GetInfo(char const *key, bst_ulong *out_len, DataType dtype,
|
||||
vec = &this->labels_lower_bound_.HostVector();
|
||||
} else if (!std::strcmp(key, "label_upper_bound")) {
|
||||
vec = &this->labels_upper_bound_.HostVector();
|
||||
} else if (!std::strcmp(key, "feature_weights")) {
|
||||
vec = &this->feature_weigths.HostVector();
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown float field name: " << key;
|
||||
}
|
||||
@ -497,6 +512,11 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows) {
|
||||
auto &h_feature_types = feature_types.HostVector();
|
||||
LoadFeatureType(this->feature_type_names, &h_feature_types);
|
||||
}
|
||||
if (!that.feature_weigths.Empty()) {
|
||||
this->feature_weigths.Resize(that.feature_weigths.Size());
|
||||
this->feature_weigths.SetDevice(that.feature_weigths.DeviceIdx());
|
||||
this->feature_weigths.Copy(that.feature_weigths);
|
||||
}
|
||||
}
|
||||
|
||||
void MetaInfo::Validate(int32_t device) const {
|
||||
@ -538,6 +558,11 @@ void MetaInfo::Validate(int32_t device) const {
|
||||
check_device(labels_lower_bound_);
|
||||
return;
|
||||
}
|
||||
if (feature_weigths.Size() != 0) {
|
||||
CHECK_EQ(feature_weigths.Size(), num_col_)
|
||||
<< "Size of feature_weights must equal to number of columns.";
|
||||
check_device(feature_weigths);
|
||||
}
|
||||
if (labels_upper_bound_.Size() != 0) {
|
||||
CHECK_EQ(labels_upper_bound_.Size(), num_row_)
|
||||
<< "Size of label_upper_bound must equal to number of rows.";
|
||||
|
||||
@ -58,6 +58,15 @@ void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
|
||||
std::partial_sum(out->begin(), out->end(), out->begin());
|
||||
}
|
||||
|
||||
namespace {
|
||||
// thrust::all_of tries to copy lambda function.
|
||||
struct AllOfOp {
|
||||
__device__ bool operator()(float w) {
|
||||
return w >= 0;
|
||||
}
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()});
|
||||
auto const& j_arr = get<Array>(j_interface);
|
||||
@ -82,6 +91,21 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
} else if (key == "group") {
|
||||
CopyGroupInfoImpl(array_interface, &group_ptr_);
|
||||
return;
|
||||
} else if (key == "label_lower_bound") {
|
||||
CopyInfoImpl(array_interface, &labels_lower_bound_);
|
||||
return;
|
||||
} else if (key == "label_upper_bound") {
|
||||
CopyInfoImpl(array_interface, &labels_upper_bound_);
|
||||
return;
|
||||
} else if (key == "feature_weights") {
|
||||
CopyInfoImpl(array_interface, &feature_weigths);
|
||||
auto d_feature_weights = feature_weigths.ConstDeviceSpan();
|
||||
auto valid =
|
||||
thrust::all_of(thrust::device, d_feature_weights.data(),
|
||||
d_feature_weights.data() + d_feature_weights.size(),
|
||||
AllOfOp{});
|
||||
CHECK(valid) << "Feature weight must be greater than 0.";
|
||||
return;
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown metainfo: " << key;
|
||||
}
|
||||
|
||||
@ -235,8 +235,10 @@ class ColMaker: public TreeUpdater {
|
||||
}
|
||||
}
|
||||
{
|
||||
column_sampler_.Init(fmat.Info().num_col_, param_.colsample_bynode,
|
||||
param_.colsample_bylevel, param_.colsample_bytree);
|
||||
column_sampler_.Init(fmat.Info().num_col_,
|
||||
fmat.Info().feature_weigths.ConstHostVector(),
|
||||
param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree);
|
||||
}
|
||||
{
|
||||
// setup temp space for each thread
|
||||
|
||||
@ -266,8 +266,10 @@ struct GPUHistMakerDevice {
|
||||
// Note that the column sampler must be passed by value because it is not
|
||||
// thread safe
|
||||
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
|
||||
this->column_sampler.Init(num_columns, param.colsample_bynode,
|
||||
param.colsample_bylevel, param.colsample_bytree);
|
||||
auto const& info = dmat->Info();
|
||||
this->column_sampler.Init(num_columns, info.feature_weigths.HostVector(),
|
||||
param.colsample_bynode, param.colsample_bylevel,
|
||||
param.colsample_bytree);
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
this->interaction_constraints.Reset();
|
||||
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
|
||||
|
||||
@ -841,11 +841,13 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
|
||||
// store a pointer to the tree
|
||||
p_last_tree_ = &tree;
|
||||
if (data_layout_ == kDenseDataOneBased) {
|
||||
column_sampler_.Init(info.num_col_, param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree, true);
|
||||
column_sampler_.Init(info.num_col_, info.feature_weigths.ConstHostVector(),
|
||||
param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree, true);
|
||||
} else {
|
||||
column_sampler_.Init(info.num_col_, param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree, false);
|
||||
column_sampler_.Init(info.num_col_, info.feature_weigths.ConstHostVector(),
|
||||
param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree, false);
|
||||
}
|
||||
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
|
||||
/* specialized code for dense data:
|
||||
|
||||
13
tests/cpp/common/test_common.cc
Normal file
13
tests/cpp/common/test_common.cc
Normal file
@ -0,0 +1,13 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include "../../../src/common/common.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
TEST(ArgSort, Basic) {
|
||||
std::vector<float> inputs {3.0, 2.0, 1.0};
|
||||
auto ret = ArgSort<bst_feature_t>(inputs);
|
||||
std::vector<bst_feature_t> sol{2, 1, 0};
|
||||
ASSERT_EQ(ret, sol);
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
@ -8,9 +8,10 @@ namespace common {
|
||||
TEST(ColumnSampler, Test) {
|
||||
int n = 128;
|
||||
ColumnSampler cs;
|
||||
std::vector<float> feature_weights;
|
||||
|
||||
// No node sampling
|
||||
cs.Init(n, 1.0f, 0.5f, 0.5f);
|
||||
cs.Init(n, feature_weights, 1.0f, 0.5f, 0.5f);
|
||||
auto set0 = cs.GetFeatureSet(0);
|
||||
ASSERT_EQ(set0->Size(), 32);
|
||||
|
||||
@ -23,7 +24,7 @@ TEST(ColumnSampler, Test) {
|
||||
ASSERT_EQ(set2->Size(), 32);
|
||||
|
||||
// Node sampling
|
||||
cs.Init(n, 0.5f, 1.0f, 0.5f);
|
||||
cs.Init(n, feature_weights, 0.5f, 1.0f, 0.5f);
|
||||
auto set3 = cs.GetFeatureSet(0);
|
||||
ASSERT_EQ(set3->Size(), 32);
|
||||
|
||||
@ -33,19 +34,19 @@ TEST(ColumnSampler, Test) {
|
||||
ASSERT_EQ(set4->Size(), 32);
|
||||
|
||||
// No level or node sampling, should be the same at different depth
|
||||
cs.Init(n, 1.0f, 1.0f, 0.5f);
|
||||
cs.Init(n, feature_weights, 1.0f, 1.0f, 0.5f);
|
||||
ASSERT_EQ(cs.GetFeatureSet(0)->HostVector(),
|
||||
cs.GetFeatureSet(1)->HostVector());
|
||||
|
||||
cs.Init(n, 1.0f, 1.0f, 1.0f);
|
||||
cs.Init(n, feature_weights, 1.0f, 1.0f, 1.0f);
|
||||
auto set5 = cs.GetFeatureSet(0);
|
||||
ASSERT_EQ(set5->Size(), n);
|
||||
cs.Init(n, 1.0f, 1.0f, 1.0f);
|
||||
cs.Init(n, feature_weights, 1.0f, 1.0f, 1.0f);
|
||||
auto set6 = cs.GetFeatureSet(0);
|
||||
ASSERT_EQ(set5->HostVector(), set6->HostVector());
|
||||
|
||||
// Should always be a minimum of one feature
|
||||
cs.Init(n, 1e-16f, 1e-16f, 1e-16f);
|
||||
cs.Init(n, feature_weights, 1e-16f, 1e-16f, 1e-16f);
|
||||
ASSERT_EQ(cs.GetFeatureSet(0)->Size(), 1);
|
||||
}
|
||||
|
||||
@ -56,13 +57,13 @@ TEST(ColumnSampler, ThreadSynchronisation) {
|
||||
size_t iterations = 10;
|
||||
size_t levels = 5;
|
||||
std::vector<bst_feature_t> reference_result;
|
||||
bool success =
|
||||
true; // Cannot use google test asserts in multithreaded region
|
||||
std::vector<float> feature_weights;
|
||||
bool success = true; // Cannot use google test asserts in multithreaded region
|
||||
#pragma omp parallel num_threads(num_threads)
|
||||
{
|
||||
for (auto j = 0ull; j < iterations; j++) {
|
||||
ColumnSampler cs(j);
|
||||
cs.Init(n, 0.5f, 0.5f, 0.5f);
|
||||
cs.Init(n, feature_weights, 0.5f, 0.5f, 0.5f);
|
||||
for (auto level = 0ull; level < levels; level++) {
|
||||
auto result = cs.GetFeatureSet(level)->ConstHostVector();
|
||||
#pragma omp single
|
||||
@ -76,5 +77,54 @@ TEST(ColumnSampler, ThreadSynchronisation) {
|
||||
}
|
||||
ASSERT_TRUE(success);
|
||||
}
|
||||
|
||||
TEST(ColumnSampler, WeightedSampling) {
|
||||
auto test_basic = [](int first) {
|
||||
std::vector<float> feature_weights(2);
|
||||
feature_weights[0] = std::abs(first - 1.0f);
|
||||
feature_weights[1] = first - 0.0f;
|
||||
ColumnSampler cs{0};
|
||||
cs.Init(2, feature_weights, 1.0, 1.0, 0.5);
|
||||
auto feature_sets = cs.GetFeatureSet(0);
|
||||
auto const &h_feat_set = feature_sets->HostVector();
|
||||
ASSERT_EQ(h_feat_set.size(), 1);
|
||||
ASSERT_EQ(h_feat_set[0], first - 0);
|
||||
};
|
||||
|
||||
test_basic(0);
|
||||
test_basic(1);
|
||||
|
||||
size_t constexpr kCols = 64;
|
||||
std::vector<float> feature_weights(kCols);
|
||||
SimpleLCG rng;
|
||||
SimpleRealUniformDistribution<float> dist(.0f, 12.0f);
|
||||
std::generate(feature_weights.begin(), feature_weights.end(), [&]() { return dist(&rng); });
|
||||
ColumnSampler cs{0};
|
||||
cs.Init(kCols, feature_weights, 0.5f, 1.0f, 1.0f);
|
||||
std::vector<bst_feature_t> features(kCols);
|
||||
std::iota(features.begin(), features.end(), 0);
|
||||
std::vector<float> freq(kCols, 0);
|
||||
for (size_t i = 0; i < 1024; ++i) {
|
||||
auto fset = cs.GetFeatureSet(0);
|
||||
ASSERT_EQ(kCols * 0.5, fset->Size());
|
||||
auto const& h_fset = fset->HostVector();
|
||||
for (auto f : h_fset) {
|
||||
freq[f] += 1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
auto norm = std::accumulate(freq.cbegin(), freq.cend(), .0f);
|
||||
for (auto& f : freq) {
|
||||
f /= norm;
|
||||
}
|
||||
norm = std::accumulate(feature_weights.cbegin(), feature_weights.cend(), .0f);
|
||||
for (auto& f : feature_weights) {
|
||||
f /= norm;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < feature_weights.size(); ++i) {
|
||||
EXPECT_NEAR(freq[i], feature_weights[i], 1e-2);
|
||||
}
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
@ -204,12 +204,11 @@ TEST(GpuHist, EvaluateRootSplit) {
|
||||
ASSERT_EQ(maker.hist.Data().size(), hist.size());
|
||||
thrust::copy(hist.begin(), hist.end(),
|
||||
maker.hist.Data().begin());
|
||||
std::vector<float> feature_weights;
|
||||
|
||||
maker.column_sampler.Init(kNCols,
|
||||
param.colsample_bynode,
|
||||
param.colsample_bylevel,
|
||||
param.colsample_bytree,
|
||||
false);
|
||||
maker.column_sampler.Init(kNCols, feature_weights, param.colsample_bynode,
|
||||
param.colsample_bylevel, param.colsample_bytree,
|
||||
false);
|
||||
|
||||
RegTree tree;
|
||||
MetaInfo info;
|
||||
|
||||
@ -16,6 +16,20 @@ class TestDeviceQuantileDMatrix(unittest.TestCase):
|
||||
match='is not supported for DeviceQuantileDMatrix'):
|
||||
xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64))
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_dmatrix_feature_weights(self):
|
||||
import cupy as cp
|
||||
rng = cp.random.RandomState(1994)
|
||||
data = rng.randn(5, 5)
|
||||
m = xgb.DMatrix(data)
|
||||
|
||||
feature_weights = rng.uniform(size=5)
|
||||
m.set_info(feature_weights=feature_weights)
|
||||
|
||||
cp.testing.assert_array_equal(
|
||||
cp.array(m.get_float_info('feature_weights')),
|
||||
feature_weights.astype(np.float32))
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_dmatrix_cupy_init(self):
|
||||
import cupy as cp
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import pytest
|
||||
import testing as tm
|
||||
|
||||
|
||||
CURRENT_DIR = os.path.dirname(__file__)
|
||||
ROOT_DIR = os.path.dirname(os.path.dirname(CURRENT_DIR))
|
||||
ROOT_DIR = tm.PROJECT_ROOT
|
||||
DEMO_DIR = os.path.join(ROOT_DIR, 'demo')
|
||||
PYTHON_DEMO_DIR = os.path.join(DEMO_DIR, 'guide-python')
|
||||
|
||||
@ -19,21 +17,27 @@ def test_basic_walkthrough():
|
||||
os.remove('dump.raw.txt')
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_matplotlib())
|
||||
def test_custom_multiclass_objective():
|
||||
script = os.path.join(PYTHON_DEMO_DIR, 'custom_softmax.py')
|
||||
cmd = ['python', script, '--plot=0']
|
||||
subprocess.check_call(cmd)
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_matplotlib())
|
||||
def test_custom_rmsle_objective():
|
||||
major, minor = sys.version_info[:2]
|
||||
if minor < 6:
|
||||
pytest.skip('Skipping RMLSE test due to Python version being too low.')
|
||||
script = os.path.join(PYTHON_DEMO_DIR, 'custom_rmsle.py')
|
||||
cmd = ['python', script, '--plot=0']
|
||||
subprocess.check_call(cmd)
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_matplotlib())
|
||||
def test_feature_weights_demo():
|
||||
script = os.path.join(PYTHON_DEMO_DIR, 'feature_weights.py')
|
||||
cmd = ['python', script, '--plot=0']
|
||||
subprocess.check_call(cmd)
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_sklearn_demo():
|
||||
script = os.path.join(PYTHON_DEMO_DIR, 'sklearn_examples.py')
|
||||
|
||||
@ -99,6 +99,11 @@ class TestDMatrix(unittest.TestCase):
|
||||
X = rng.randn(100, 100)
|
||||
y = rng.randint(low=0, high=3, size=100)
|
||||
d = xgb.DMatrix(X, y)
|
||||
np.testing.assert_equal(d.get_label(), y.astype(np.float32))
|
||||
|
||||
fw = rng.uniform(size=100).astype(np.float32)
|
||||
d.set_info(feature_weights=fw)
|
||||
|
||||
eval_res_0 = {}
|
||||
booster = xgb.train(
|
||||
{'num_class': 3, 'objective': 'multi:softprob'}, d,
|
||||
@ -106,19 +111,23 @@ class TestDMatrix(unittest.TestCase):
|
||||
|
||||
predt = booster.predict(d)
|
||||
predt = predt.reshape(100 * 3, 1)
|
||||
|
||||
d.set_base_margin(predt)
|
||||
|
||||
ridxs = [1, 2, 3, 4, 5, 6]
|
||||
d = d.slice(ridxs)
|
||||
sliced_margin = d.get_float_info('base_margin')
|
||||
sliced = d.slice(ridxs)
|
||||
|
||||
sliced_margin = sliced.get_float_info('base_margin')
|
||||
assert sliced_margin.shape[0] == len(ridxs) * 3
|
||||
|
||||
eval_res_1 = {}
|
||||
xgb.train({'num_class': 3, 'objective': 'multi:softprob'}, d,
|
||||
num_boost_round=2, evals=[(d, 'd')], evals_result=eval_res_1)
|
||||
xgb.train({'num_class': 3, 'objective': 'multi:softprob'}, sliced,
|
||||
num_boost_round=2, evals=[(sliced, 'd')],
|
||||
evals_result=eval_res_1)
|
||||
|
||||
eval_res_0 = eval_res_0['d']['merror']
|
||||
eval_res_1 = eval_res_1['d']['merror']
|
||||
|
||||
for i in range(len(eval_res_0)):
|
||||
assert abs(eval_res_0[i] - eval_res_1[i]) < 0.02
|
||||
|
||||
@ -196,13 +205,33 @@ class TestDMatrix(unittest.TestCase):
|
||||
dtrain.get_float_info('base_margin')
|
||||
dtrain.get_uint_info('group_ptr')
|
||||
|
||||
def test_feature_weights(self):
|
||||
kRows = 10
|
||||
kCols = 50
|
||||
rng = np.random.RandomState(1994)
|
||||
fw = rng.uniform(size=kCols)
|
||||
X = rng.randn(kRows, kCols)
|
||||
m = xgb.DMatrix(X)
|
||||
m.set_info(feature_weights=fw)
|
||||
np.testing.assert_allclose(fw, m.get_float_info('feature_weights'))
|
||||
# Handle empty
|
||||
m.set_info(feature_weights=np.empty((0, 0)))
|
||||
|
||||
assert m.get_float_info('feature_weights').shape[0] == 0
|
||||
|
||||
fw -= 1
|
||||
|
||||
def assign_weight():
|
||||
m.set_info(feature_weights=fw)
|
||||
self.assertRaises(ValueError, assign_weight)
|
||||
|
||||
def test_sparse_dmatrix_csr(self):
|
||||
nrow = 100
|
||||
ncol = 1000
|
||||
x = rand(nrow, ncol, density=0.0005, format='csr', random_state=rng)
|
||||
assert x.indices.max() < ncol - 1
|
||||
x.data[:] = 1
|
||||
dtrain = xgb.DMatrix(x, label=np.random.binomial(1, 0.3, nrow))
|
||||
dtrain = xgb.DMatrix(x, label=rng.binomial(1, 0.3, nrow))
|
||||
assert (dtrain.num_row(), dtrain.num_col()) == (nrow, ncol)
|
||||
watchlist = [(dtrain, 'train')]
|
||||
param = {'max_depth': 3, 'objective': 'binary:logistic', 'verbosity': 0}
|
||||
@ -215,7 +244,7 @@ class TestDMatrix(unittest.TestCase):
|
||||
x = rand(nrow, ncol, density=0.0005, format='csc', random_state=rng)
|
||||
assert x.indices.max() < nrow - 1
|
||||
x.data[:] = 1
|
||||
dtrain = xgb.DMatrix(x, label=np.random.binomial(1, 0.3, nrow))
|
||||
dtrain = xgb.DMatrix(x, label=rng.binomial(1, 0.3, nrow))
|
||||
assert (dtrain.num_row(), dtrain.num_col()) == (nrow, ncol)
|
||||
watchlist = [(dtrain, 'train')]
|
||||
param = {'max_depth': 3, 'objective': 'binary:logistic', 'verbosity': 0}
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import collections
|
||||
import importlib.util
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
from xgboost.sklearn import XGBoostLabelEncoder
|
||||
@ -654,6 +656,7 @@ def test_validation_weights_xgbmodel():
|
||||
eval_set=[(X_train, y_train), (X_test, y_test)],
|
||||
sample_weight_eval_set=[weights_train])
|
||||
|
||||
|
||||
def test_validation_weights_xgbclassifier():
|
||||
from sklearn.datasets import make_hastie_10_2
|
||||
|
||||
@ -920,6 +923,64 @@ def test_pandas_input():
|
||||
np.array([0, 1]))
|
||||
|
||||
|
||||
def run_feature_weights(increasing):
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
kRows = 512
|
||||
kCols = 64
|
||||
colsample_bynode = 0.5
|
||||
reg = xgb.XGBRegressor(tree_method='hist',
|
||||
colsample_bynode=colsample_bynode)
|
||||
X = rng.randn(kRows, kCols)
|
||||
y = rng.randn(kRows)
|
||||
fw = np.ones(shape=(kCols,))
|
||||
for i in range(kCols):
|
||||
if increasing:
|
||||
fw[i] *= float(i)
|
||||
else:
|
||||
fw[i] *= float(kCols - i)
|
||||
|
||||
reg.fit(X, y, feature_weights=fw)
|
||||
model_path = os.path.join(tmpdir, 'model.json')
|
||||
reg.save_model(model_path)
|
||||
with open(model_path) as fd:
|
||||
model = json.load(fd)
|
||||
|
||||
parser_path = os.path.join(tm.PROJECT_ROOT, 'demo', 'json-model',
|
||||
'json_parser.py')
|
||||
spec = importlib.util.spec_from_file_location("JsonParser",
|
||||
parser_path)
|
||||
foo = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(foo)
|
||||
model = foo.Model(model)
|
||||
splits = {}
|
||||
total_nodes = 0
|
||||
for tree in model.trees:
|
||||
n_nodes = len(tree.nodes)
|
||||
total_nodes += n_nodes
|
||||
for n in range(n_nodes):
|
||||
if tree.is_leaf(n):
|
||||
continue
|
||||
if splits.get(tree.split_index(n), None) is None:
|
||||
splits[tree.split_index(n)] = 1
|
||||
else:
|
||||
splits[tree.split_index(n)] += 1
|
||||
|
||||
od = collections.OrderedDict(sorted(splits.items()))
|
||||
tuples = [(k, v) for k, v in od.items()]
|
||||
k, v = list(zip(*tuples))
|
||||
w = np.polyfit(k, v, deg=1)
|
||||
return w
|
||||
|
||||
|
||||
def test_feature_weights():
|
||||
poly_increasing = run_feature_weights(True)
|
||||
poly_decreasing = run_feature_weights(False)
|
||||
# Approxmated test, this is dependent on the implementation of random
|
||||
# number generator in std library.
|
||||
assert poly_increasing[0] > 0.08
|
||||
assert poly_decreasing[0] < -0.08
|
||||
|
||||
|
||||
class TestBoostFromPrediction(unittest.TestCase):
|
||||
def run_boost_from_prediction(self, tree_method):
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user