Column sampling at individual nodes (splits). (#3971)
* Column sampling at individual nodes (splits).
* Documented colsample_bynode parameter.
  - Also updated documentation for colsample_by* parameters.
* Updated documentation.
* GetFeatureSet() returns a shared pointer to std::vector.
* Sync sampled columns across multiple processes.
This commit is contained in:
committed by
Jiaming Yuan
parent
e0a279114e
commit
42bf90eb8f
@@ -7,14 +7,15 @@
|
||||
#ifndef XGBOOST_COMMON_RANDOM_H_
|
||||
#define XGBOOST_COMMON_RANDOM_H_
|
||||
|
||||
#include <rabit/rabit.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
#include "host_device_vector.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
@@ -75,27 +76,36 @@ GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
|
||||
/**
|
||||
* \class ColumnSampler
|
||||
*
|
||||
* \brief Handles selection of columns due to colsample_bytree and
|
||||
* colsample_bylevel parameters. Should be initialised before tree
|
||||
* construction and to reset when tree construction is completed.
|
||||
* \brief Handles selection of columns due to colsample_bytree, colsample_bylevel and
|
||||
* colsample_bynode parameters. Should be initialised before tree construction and to
|
||||
* reset when tree construction is completed.
|
||||
*/
|
||||
|
||||
class ColumnSampler {
|
||||
HostDeviceVector<int> feature_set_tree_;
|
||||
std::map<int, HostDeviceVector<int>> feature_set_level_;
|
||||
std::shared_ptr<std::vector<int>> feature_set_tree_;
|
||||
std::map<int, std::shared_ptr<std::vector<int>>> feature_set_level_;
|
||||
float colsample_bylevel_{1.0f};
|
||||
float colsample_bytree_{1.0f};
|
||||
float colsample_bynode_{1.0f};
|
||||
|
||||
std::vector<int> ColSample(std::vector<int> features, float colsample) const {
|
||||
if (colsample == 1.0f) return features;
|
||||
std::shared_ptr<std::vector<int>> ColSample
|
||||
(std::shared_ptr<std::vector<int>> p_features, float colsample) const {
|
||||
if (colsample == 1.0f) return p_features;
|
||||
const auto& features = *p_features;
|
||||
CHECK_GT(features.size(), 0);
|
||||
int n = std::max(1, static_cast<int>(colsample * features.size()));
|
||||
auto p_new_features = std::make_shared<std::vector<int>>();
|
||||
auto& new_features = *p_new_features;
|
||||
new_features.resize(features.size());
|
||||
std::copy(features.begin(), features.end(), new_features.begin());
|
||||
std::shuffle(new_features.begin(), new_features.end(), common::GlobalRandom());
|
||||
new_features.resize(n);
|
||||
std::sort(new_features.begin(), new_features.end());
|
||||
|
||||
std::shuffle(features.begin(), features.end(), common::GlobalRandom());
|
||||
features.resize(n);
|
||||
std::sort(features.begin(), features.end());
|
||||
// ensure that new_features are the same across ranks
|
||||
rabit::Broadcast(&new_features, 0);
|
||||
|
||||
return features;
|
||||
return p_new_features;
|
||||
}
|
||||
|
||||
public:
|
||||
@@ -103,44 +113,60 @@ class ColumnSampler {
|
||||
* \brief Initialise this object before use.
|
||||
*
|
||||
* \param num_col
|
||||
* \param colsample_bynode
|
||||
* \param colsample_bylevel
|
||||
* \param colsample_bytree
|
||||
* \param skip_index_0 (Optional) True to skip index 0.
|
||||
*/
|
||||
void Init(int64_t num_col, float colsample_bylevel, float colsample_bytree,
|
||||
bool skip_index_0 = false) {
|
||||
this->colsample_bylevel_ = colsample_bylevel;
|
||||
this->colsample_bytree_ = colsample_bytree;
|
||||
this->Reset();
|
||||
void Init(int64_t num_col, float colsample_bynode, float colsample_bylevel,
|
||||
float colsample_bytree, bool skip_index_0 = false) {
|
||||
colsample_bylevel_ = colsample_bylevel;
|
||||
colsample_bytree_ = colsample_bytree;
|
||||
colsample_bynode_ = colsample_bynode;
|
||||
|
||||
if (feature_set_tree_ == nullptr) {
|
||||
feature_set_tree_ = std::make_shared<std::vector<int>>();
|
||||
}
|
||||
Reset();
|
||||
|
||||
int begin_idx = skip_index_0 ? 1 : 0;
|
||||
auto& feature_set_h = feature_set_tree_.HostVector();
|
||||
feature_set_h.resize(num_col - begin_idx);
|
||||
feature_set_tree_->resize(num_col - begin_idx);
|
||||
std::iota(feature_set_tree_->begin(), feature_set_tree_->end(), begin_idx);
|
||||
|
||||
std::iota(feature_set_h.begin(), feature_set_h.end(), begin_idx);
|
||||
feature_set_h = ColSample(feature_set_h, this->colsample_bytree_);
|
||||
feature_set_tree_ = ColSample(feature_set_tree_, colsample_bytree_);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Resets this object.
|
||||
*/
|
||||
void Reset() {
|
||||
feature_set_tree_.HostVector().clear();
|
||||
feature_set_tree_->clear();
|
||||
feature_set_level_.clear();
|
||||
}
|
||||
|
||||
HostDeviceVector<int>& GetFeatureSet(int depth) {
|
||||
if (this->colsample_bylevel_ == 1.0f) {
|
||||
/**
|
||||
* \brief Samples a feature set.
|
||||
*
|
||||
* \param depth The tree depth of the node at which to sample.
|
||||
* \return The sampled feature set.
|
||||
* \note If colsample_bynode_ < 1.0, this method creates a new feature set each time it
|
||||
* is called. Therefore, it should be called only once per node.
|
||||
*/
|
||||
std::shared_ptr<std::vector<int>> GetFeatureSet(int depth) {
|
||||
if (colsample_bylevel_ == 1.0f && colsample_bynode_ == 1.0f) {
|
||||
return feature_set_tree_;
|
||||
}
|
||||
|
||||
if (feature_set_level_.count(depth) == 0) {
|
||||
// Level sampling, level does not yet exist so generate it
|
||||
auto& level = feature_set_level_[depth].HostVector();
|
||||
level = ColSample(feature_set_tree_.HostVector(), this->colsample_bylevel_);
|
||||
feature_set_level_[depth] = ColSample(feature_set_tree_, colsample_bylevel_);
|
||||
}
|
||||
// Level sampling
|
||||
return feature_set_level_[depth];
|
||||
if (colsample_bynode_ == 1.0f) {
|
||||
// Level sampling
|
||||
return feature_set_level_[depth];
|
||||
}
|
||||
// Need to sample for the node individually
|
||||
return ColSample(feature_set_level_[depth], colsample_bynode_);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user