Feature weights (#5962)
This commit is contained in:
@@ -316,6 +316,17 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
|
||||
API_END();
|
||||
}
|
||||
|
||||
/*!
 * \brief Set a dense meta-info field (e.g. "label", "weight", "feature_weights")
 *        on a DMatrix from a contiguous buffer.
 *
 * \param handle Handle to the DMatrix (a std::shared_ptr<DMatrix>*).
 * \param field  Name of the meta-info field to set.
 * \param data   Pointer to `size` contiguous values of the given `type`.
 * \param size   Number of elements in `data`.
 * \param type   Data type code; must match the xgboost::DataType enum
 *               (1 = float32, 2 = float64, 3 = uint32, 4 = uint64).
 * \return 0 on success, -1 on failure (error captured by API_END).
 */
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
                                  void *data, xgboost::bst_ulong size,
                                  int type) {
  API_BEGIN();
  CHECK_HANDLE();
  auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
  // Reject type codes outside the DataType enum before casting; the bare
  // CHECK previously failed without explaining what went wrong.
  CHECK(type >= 1 && type <= 4)
      << "Invalid data type code: " << type
      << ", expected 1 (float32), 2 (float64), 3 (uint32) or 4 (uint64).";
  info.SetInfo(field, data, static_cast<DataType>(type), size);
  API_END();
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
|
||||
const unsigned* group,
|
||||
xgboost::bst_ulong len) {
|
||||
|
||||
@@ -9,12 +9,15 @@
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <exception>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <numeric>
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
#include <thrust/system/cuda/error.h>
|
||||
@@ -160,6 +163,15 @@ inline void AssertOneAPISupport() {
|
||||
#endif // XGBOOST_USE_ONEAPI
|
||||
}
|
||||
|
||||
/*!
 * \brief Return the indices that would sort `array`.
 *
 * \tparam Idx  Integral type of the returned indices.
 * \tparam V    Element type of the input array.
 * \tparam Comp Comparator over V; defaults to std::less (ascending order).
 *
 * Uses a stable sort, so elements that compare equal keep their original
 * relative order in the result.
 */
template <typename Idx, typename V, typename Comp = std::less<V>>
std::vector<Idx> ArgSort(std::vector<V> const &array, Comp comp = std::less<V>{}) {
  std::vector<Idx> indices(array.size());
  std::iota(indices.begin(), indices.end(), 0);
  // Compare positions by the values they point at.
  auto by_value = [&array, comp](Idx const &lhs, Idx const &rhs) {
    return comp(array[lhs], array[rhs]);
  };
  std::stable_sort(indices.begin(), indices.end(), by_value);
  return indices;
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_COMMON_H_
|
||||
|
||||
38
src/common/random.cc
Normal file
@@ -0,0 +1,38 @@
|
||||
/*!
|
||||
* Copyright 2020 by XGBoost Contributors
|
||||
* \file random.cc
|
||||
*/
|
||||
#include "random.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
/*!
 * \brief Sample a subset of features (columns) from `p_features`.
 *
 * \param p_features Candidate feature indices to sample from.
 * \param colsample  Fraction of features to keep, in (0, 1].
 * \return The sampled feature set, sorted ascending.  When colsample == 1.0f
 *         the input shared_ptr is returned as-is (no copy is made, so callers
 *         share the same vector in that case).
 */
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
    std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features,
    float colsample) {
  if (colsample == 1.0f) {
    return p_features;
  }
  const auto &features = p_features->HostVector();
  CHECK_GT(features.size(), 0);

  // Keep at least one feature regardless of how small colsample is.
  int n = std::max(1, static_cast<int>(colsample * features.size()));
  auto p_new_features = std::make_shared<HostDeviceVector<bst_feature_t>>();
  auto &new_features = *p_new_features;

  if (feature_weights_.size() != 0) {
    // Weighted path: draw n features without replacement, biased by
    // feature_weights_ (Efraimidis-Spirakis sampling).
    new_features.HostVector() = WeightedSamplingWithoutReplacement(
        p_features->HostVector(), feature_weights_, n);
  } else {
    // Unweighted path: uniform sample via full shuffle then truncate to n.
    // Uses the member rng_ so results are reproducible per sampler instance.
    new_features.Resize(features.size());
    std::copy(features.begin(), features.end(),
              new_features.HostVector().begin());
    std::shuffle(new_features.HostVector().begin(),
                 new_features.HostVector().end(), rng_);
    new_features.Resize(n);
  }
  // Downstream code expects feature indices in ascending order.
  std::sort(new_features.HostVector().begin(), new_features.HostVector().end());
  return p_new_features;
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* Copyright 2015-2020 by Contributors
|
||||
* \file random.h
|
||||
* \brief Utility related to random.
|
||||
* \author Tianqi Chen
|
||||
@@ -10,14 +10,17 @@
|
||||
#include <rabit/rabit.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
#include <utility>
|
||||
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "common.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
@@ -75,6 +78,38 @@ using GlobalRandomEngine = RandomEngine;
|
||||
*/
|
||||
GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
|
||||
|
||||
/*
 * \brief Draw `n` elements from `array` without replacement, with selection
 *        probability proportional to the corresponding entry in `weights`.
 *
 * Original paper:
 * Weighted Random Sampling (2005; Efraimidis, Spirakis)
 *
 * Blog:
 * https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
 */
template <typename T>
std::vector<T> WeightedSamplingWithoutReplacement(
    std::vector<T> const &array, std::vector<float> const &weights, size_t n) {
  // ES sampling: each item gets key u^(1/w) for u ~ Uniform(0,1); taking the
  // n largest keys yields a weighted sample without replacement.  Working in
  // log space, log(u)/w, gives the same ordering without the pow().
  CHECK_EQ(array.size(), weights.size());
  std::vector<float> keys(weights.size());
  std::uniform_real_distribution<float> dist;
  auto& rng = GlobalRandom();
  for (size_t i = 0; i < array.size(); ++i) {
    // Clamp to kRtEps so a zero weight does not divide log(u) by zero; zero-
    // weight items thus retain a tiny (not exactly zero) chance of selection.
    auto w = std::max(weights.at(i), kRtEps);
    auto u = dist(rng);
    auto k = std::log(u) / w;
    keys[i] = k;
  }
  // Indices of the largest keys first.
  auto ind = ArgSort<size_t>(keys, std::greater<>{});
  // NOTE(review): assumes n <= array.size(); a larger n would zero-extend
  // `ind` and duplicate array[0] — callers must guarantee the bound.
  ind.resize(n);

  std::vector<T> results(ind.size());
  for (size_t k = 0; k < ind.size(); ++k) {
    auto idx = ind[k];
    results[k] = array[idx];
  }
  return results;
}
|
||||
|
||||
/**
|
||||
* \class ColumnSampler
|
||||
*
|
||||
@@ -82,36 +117,18 @@ GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
|
||||
* colsample_bynode parameters. Should be initialised before tree construction and to
|
||||
* reset when tree construction is completed.
|
||||
*/
|
||||
|
||||
class ColumnSampler {
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> feature_set_tree_;
|
||||
std::map<int, std::shared_ptr<HostDeviceVector<bst_feature_t>>> feature_set_level_;
|
||||
std::vector<float> feature_weights_;
|
||||
float colsample_bylevel_{1.0f};
|
||||
float colsample_bytree_{1.0f};
|
||||
float colsample_bynode_{1.0f};
|
||||
GlobalRandomEngine rng_;
|
||||
|
||||
// Sample a `colsample` fraction of `p_features` uniformly at random (at least
// one feature), returning the chosen indices sorted ascending.  When
// colsample == 1.0f the input shared_ptr is returned unchanged (no copy).
// Uses the member rng_, so draws are reproducible per sampler instance.
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColSample(
    std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features, float colsample) {
  if (colsample == 1.0f) return p_features;
  const auto& features = p_features->HostVector();
  CHECK_GT(features.size(), 0);
  // Keep at least one feature regardless of how small colsample is.
  int n = std::max(1, static_cast<int>(colsample * features.size()));
  auto p_new_features = std::make_shared<HostDeviceVector<bst_feature_t>>();
  auto& new_features = *p_new_features;
  // Uniform sample: copy, shuffle, then truncate to the first n entries.
  new_features.Resize(features.size());
  std::copy(features.begin(), features.end(),
            new_features.HostVector().begin());
  std::shuffle(new_features.HostVector().begin(),
               new_features.HostVector().end(), rng_);
  new_features.Resize(n);
  // Downstream code expects feature indices in ascending order.
  std::sort(new_features.HostVector().begin(),
            new_features.HostVector().end());

  return p_new_features;
}
|
||||
|
||||
public:
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColSample(
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features, float colsample);
|
||||
/**
|
||||
* \brief Column sampler constructor.
|
||||
* \note This constructor manually sets the rng seed
|
||||
@@ -139,8 +156,10 @@ class ColumnSampler {
|
||||
* \param colsample_bytree
|
||||
* \param skip_index_0 (Optional) True to skip index 0.
|
||||
*/
|
||||
void Init(int64_t num_col, float colsample_bynode, float colsample_bylevel,
|
||||
void Init(int64_t num_col, std::vector<float> feature_weights,
|
||||
float colsample_bynode, float colsample_bylevel,
|
||||
float colsample_bytree, bool skip_index_0 = false) {
|
||||
feature_weights_ = std::move(feature_weights);
|
||||
colsample_bylevel_ = colsample_bylevel;
|
||||
colsample_bytree_ = colsample_bytree;
|
||||
colsample_bynode_ = colsample_bynode;
|
||||
|
||||
@@ -293,6 +293,9 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
|
||||
} else {
|
||||
out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs);
|
||||
}
|
||||
|
||||
out.feature_weigths.Resize(this->feature_weigths.Size());
|
||||
out.feature_weigths.Copy(this->feature_weigths);
|
||||
return out;
|
||||
}
|
||||
|
||||
@@ -377,6 +380,16 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
|
||||
labels.resize(num);
|
||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, labels.begin()));
|
||||
} else if (!std::strcmp(key, "feature_weights")) {
|
||||
auto &h_feature_weights = feature_weigths.HostVector();
|
||||
h_feature_weights.resize(num);
|
||||
DISPATCH_CONST_PTR(
|
||||
dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, h_feature_weights.begin()));
|
||||
bool valid =
|
||||
std::all_of(h_feature_weights.cbegin(), h_feature_weights.cend(),
|
||||
[](float w) { return w >= 0; });
|
||||
CHECK(valid) << "Feature weight must be greater than 0.";
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown key for MetaInfo: " << key;
|
||||
}
|
||||
@@ -396,6 +409,8 @@ void MetaInfo::GetInfo(char const *key, bst_ulong *out_len, DataType dtype,
|
||||
vec = &this->labels_lower_bound_.HostVector();
|
||||
} else if (!std::strcmp(key, "label_upper_bound")) {
|
||||
vec = &this->labels_upper_bound_.HostVector();
|
||||
} else if (!std::strcmp(key, "feature_weights")) {
|
||||
vec = &this->feature_weigths.HostVector();
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown float field name: " << key;
|
||||
}
|
||||
@@ -497,6 +512,11 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows) {
|
||||
auto &h_feature_types = feature_types.HostVector();
|
||||
LoadFeatureType(this->feature_type_names, &h_feature_types);
|
||||
}
|
||||
if (!that.feature_weigths.Empty()) {
|
||||
this->feature_weigths.Resize(that.feature_weigths.Size());
|
||||
this->feature_weigths.SetDevice(that.feature_weigths.DeviceIdx());
|
||||
this->feature_weigths.Copy(that.feature_weigths);
|
||||
}
|
||||
}
|
||||
|
||||
void MetaInfo::Validate(int32_t device) const {
|
||||
@@ -538,6 +558,11 @@ void MetaInfo::Validate(int32_t device) const {
|
||||
check_device(labels_lower_bound_);
|
||||
return;
|
||||
}
|
||||
if (feature_weigths.Size() != 0) {
|
||||
CHECK_EQ(feature_weigths.Size(), num_col_)
|
||||
<< "Size of feature_weights must equal to number of columns.";
|
||||
check_device(feature_weigths);
|
||||
}
|
||||
if (labels_upper_bound_.Size() != 0) {
|
||||
CHECK_EQ(labels_upper_bound_.Size(), num_row_)
|
||||
<< "Size of label_upper_bound must equal to number of rows.";
|
||||
|
||||
@@ -58,6 +58,15 @@ void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
|
||||
std::partial_sum(out->begin(), out->end(), out->begin());
|
||||
}
|
||||
|
||||
namespace {
// Device-side predicate for thrust::all_of: passing a lambda directly would
// make thrust copy it, so a plain functor is used instead.
// Accepts non-negative weights (w == 0 passes).  NOTE(review): the error
// message emitted on failure says "greater than 0" — confirm whether zero
// weights are intentionally allowed or the message should read "non-negative".
struct AllOfOp {
  __device__ bool operator()(float w) {
    return w >= 0;
  }
};
}  // anonymous namespace
|
||||
|
||||
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()});
|
||||
auto const& j_arr = get<Array>(j_interface);
|
||||
@@ -82,6 +91,21 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
} else if (key == "group") {
|
||||
CopyGroupInfoImpl(array_interface, &group_ptr_);
|
||||
return;
|
||||
} else if (key == "label_lower_bound") {
|
||||
CopyInfoImpl(array_interface, &labels_lower_bound_);
|
||||
return;
|
||||
} else if (key == "label_upper_bound") {
|
||||
CopyInfoImpl(array_interface, &labels_upper_bound_);
|
||||
return;
|
||||
} else if (key == "feature_weights") {
|
||||
CopyInfoImpl(array_interface, &feature_weigths);
|
||||
auto d_feature_weights = feature_weigths.ConstDeviceSpan();
|
||||
auto valid =
|
||||
thrust::all_of(thrust::device, d_feature_weights.data(),
|
||||
d_feature_weights.data() + d_feature_weights.size(),
|
||||
AllOfOp{});
|
||||
CHECK(valid) << "Feature weight must be greater than 0.";
|
||||
return;
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown metainfo: " << key;
|
||||
}
|
||||
|
||||
@@ -235,8 +235,10 @@ class ColMaker: public TreeUpdater {
|
||||
}
|
||||
}
|
||||
{
|
||||
column_sampler_.Init(fmat.Info().num_col_, param_.colsample_bynode,
|
||||
param_.colsample_bylevel, param_.colsample_bytree);
|
||||
column_sampler_.Init(fmat.Info().num_col_,
|
||||
fmat.Info().feature_weigths.ConstHostVector(),
|
||||
param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree);
|
||||
}
|
||||
{
|
||||
// setup temp space for each thread
|
||||
|
||||
@@ -266,8 +266,10 @@ struct GPUHistMakerDevice {
|
||||
// Note that the column sampler must be passed by value because it is not
|
||||
// thread safe
|
||||
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
|
||||
this->column_sampler.Init(num_columns, param.colsample_bynode,
|
||||
param.colsample_bylevel, param.colsample_bytree);
|
||||
auto const& info = dmat->Info();
|
||||
this->column_sampler.Init(num_columns, info.feature_weigths.HostVector(),
|
||||
param.colsample_bynode, param.colsample_bylevel,
|
||||
param.colsample_bytree);
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
this->interaction_constraints.Reset();
|
||||
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
|
||||
|
||||
@@ -841,11 +841,13 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
|
||||
// store a pointer to the tree
|
||||
p_last_tree_ = &tree;
|
||||
if (data_layout_ == kDenseDataOneBased) {
|
||||
column_sampler_.Init(info.num_col_, param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree, true);
|
||||
column_sampler_.Init(info.num_col_, info.feature_weigths.ConstHostVector(),
|
||||
param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree, true);
|
||||
} else {
|
||||
column_sampler_.Init(info.num_col_, param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree, false);
|
||||
column_sampler_.Init(info.num_col_, info.feature_weigths.ConstHostVector(),
|
||||
param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree, false);
|
||||
}
|
||||
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
|
||||
/* specialized code for dense data:
|
||||
|
||||
Reference in New Issue
Block a user