Column sampling at individual nodes (splits). (#3971)
* Column sampling at individual nodes (splits).
* Documented colsample_bynode parameter; also updated documentation for colsample_by* parameters.
* Updated documentation.
* GetFeatureSet() returns shared pointer to std::vector.
* Sync sampled columns across multiple processes.
parent e0a279114e
commit 42bf90eb8f
doc/parameter.rst
@@ -82,15 +82,22 @@ Parameters for Tree Booster

   - Subsample ratio of the training instances. Setting it to 0.5 means that XGBoost would randomly sample half of the training data prior to growing trees, and this will prevent overfitting. Subsampling will occur once in every boosting iteration.
   - range: (0,1]

-* ``colsample_bytree`` [default=1]
+* ``colsample_bytree``, ``colsample_bylevel``, ``colsample_bynode`` [default=1]

-  - Subsample ratio of columns when constructing each tree. Subsampling will occur once in every boosting iteration.
-  - range: (0,1]
-
-* ``colsample_bylevel`` [default=1]
-
-  - Subsample ratio of columns for each split, in each level. Subsampling will occur each time a new split is made.
-  - range: (0,1]
+  - This is a family of parameters for subsampling of columns.
+  - All ``colsample_by*`` parameters have a range of (0, 1], the default value of 1,
+    and specify the fraction of columns to be subsampled.
+  - ``colsample_bytree`` is the subsample ratio of columns when constructing each
+    tree. Subsampling occurs once for every tree constructed.
+  - ``colsample_bylevel`` is the subsample ratio of columns for each level. Subsampling
+    occurs once for every new depth level reached in a tree. Columns are subsampled from
+    the set of columns chosen for the current tree.
+  - ``colsample_bynode`` is the subsample ratio of columns for each node
+    (split). Subsampling occurs once every time a new split is evaluated. Columns are
+    subsampled from the set of columns chosen for the current level.
+  - ``colsample_by*`` parameters work cumulatively. For instance, the combination
+    ``{'colsample_bytree':0.5, 'colsample_bylevel':0.5, 'colsample_bynode':0.5}`` with
+    64 features will leave 8 features to choose from at each split.

 * ``lambda`` [default=1, alias: ``reg_lambda``]
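To make the cumulative arithmetic above concrete, here is a small standalone illustration (not part of this commit; Sampled() is a hypothetical helper mirroring the max(1, colsample * n) rounding that ColSample() below applies at each stage):

#include <algorithm>
#include <iostream>

// Illustration only: how many candidate features survive each stage of
// cumulative colsample_by* sampling, using the same rounding as ColSample().
int Sampled(int n, float colsample) {
  return std::max(1, static_cast<int>(colsample * n));
}

int main() {
  int per_tree  = Sampled(64, 0.5f);         // 32 columns kept for the tree
  int per_level = Sampled(per_tree, 0.5f);   // 16 columns kept per level
  int per_node  = Sampled(per_level, 0.5f);  //  8 columns offered at each split
  std::cout << per_tree << " " << per_level << " " << per_node << "\n";
  return 0;
}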
src/common/random.h
@@ -7,14 +7,15 @@
 #ifndef XGBOOST_COMMON_RANDOM_H_
 #define XGBOOST_COMMON_RANDOM_H_

+#include <rabit/rabit.h>
 #include <xgboost/logging.h>
 #include <algorithm>
 #include <vector>
 #include <limits>
 #include <map>
+#include <memory>
 #include <numeric>
 #include <random>
-#include "host_device_vector.h"

 namespace xgboost {
 namespace common {
@@ -75,27 +76,36 @@ GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
 /**
  * \class ColumnSampler
  *
- * \brief Handles selection of columns due to colsample_bytree and
- * colsample_bylevel parameters. Should be initialised before tree
- * construction and to reset when tree construction is completed.
+ * \brief Handles selection of columns due to colsample_bytree, colsample_bylevel and
+ * colsample_bynode parameters. Should be initialised before tree construction and
+ * reset when tree construction is completed.
  */
 class ColumnSampler {
-  HostDeviceVector<int> feature_set_tree_;
-  std::map<int, HostDeviceVector<int>> feature_set_level_;
+  std::shared_ptr<std::vector<int>> feature_set_tree_;
+  std::map<int, std::shared_ptr<std::vector<int>>> feature_set_level_;
   float colsample_bylevel_{1.0f};
   float colsample_bytree_{1.0f};
+  float colsample_bynode_{1.0f};

-  std::vector<int> ColSample(std::vector<int> features, float colsample) const {
-    if (colsample == 1.0f) return features;
+  std::shared_ptr<std::vector<int>> ColSample(
+      std::shared_ptr<std::vector<int>> p_features, float colsample) const {
+    if (colsample == 1.0f) return p_features;
+    const auto& features = *p_features;
     CHECK_GT(features.size(), 0);
     int n = std::max(1, static_cast<int>(colsample * features.size()));
+    auto p_new_features = std::make_shared<std::vector<int>>();
+    auto& new_features = *p_new_features;
+    new_features.resize(features.size());
+    std::copy(features.begin(), features.end(), new_features.begin());
+    std::shuffle(new_features.begin(), new_features.end(), common::GlobalRandom());
+    new_features.resize(n);
+    std::sort(new_features.begin(), new_features.end());

-    std::shuffle(features.begin(), features.end(), common::GlobalRandom());
-    features.resize(n);
-    std::sort(features.begin(), features.end());
+    // ensure that new_features are the same across ranks
+    rabit::Broadcast(&new_features, 0);

-    return features;
+    return p_new_features;
   }

  public:
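The rabit::Broadcast call above is the "sync sampled columns across multiple processes" part of this commit: each worker shuffles with its own RNG state, so rank 0's sample is broadcast to keep every rank evaluating splits over the same columns. A minimal standalone sketch of that pattern (hypothetical program, not from this commit; it assumes only the rabit calls already used above):

#include <rabit/rabit.h>

#include <cstdio>
#include <vector>

int main(int argc, char* argv[]) {
  rabit::Init(argc, argv);
  // Each rank would arrive here with its own locally shuffled sample ...
  std::vector<int> sampled_columns{3, 7, 11};
  // ... after the broadcast, every rank holds rank 0's copy.
  rabit::Broadcast(&sampled_columns, 0);
  std::printf("rank %d sees %zu columns\n",
              rabit::GetRank(), sampled_columns.size());
  rabit::Finalize();
  return 0;
}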
@@ -103,44 +113,60 @@ class ColumnSampler {
   * \brief Initialise this object before use.
   *
   * \param num_col
+  * \param colsample_bynode
   * \param colsample_bylevel
   * \param colsample_bytree
   * \param skip_index_0 (Optional) True to skip index 0.
   */
-  void Init(int64_t num_col, float colsample_bylevel, float colsample_bytree,
-            bool skip_index_0 = false) {
-    this->colsample_bylevel_ = colsample_bylevel;
-    this->colsample_bytree_ = colsample_bytree;
-    this->Reset();
+  void Init(int64_t num_col, float colsample_bynode, float colsample_bylevel,
+            float colsample_bytree, bool skip_index_0 = false) {
+    colsample_bylevel_ = colsample_bylevel;
+    colsample_bytree_ = colsample_bytree;
+    colsample_bynode_ = colsample_bynode;
+
+    if (feature_set_tree_ == nullptr) {
+      feature_set_tree_ = std::make_shared<std::vector<int>>();
+    }
+    Reset();

     int begin_idx = skip_index_0 ? 1 : 0;
-    auto& feature_set_h = feature_set_tree_.HostVector();
-    feature_set_h.resize(num_col - begin_idx);
-
-    std::iota(feature_set_h.begin(), feature_set_h.end(), begin_idx);
-    feature_set_h = ColSample(feature_set_h, this->colsample_bytree_);
+    feature_set_tree_->resize(num_col - begin_idx);
+    std::iota(feature_set_tree_->begin(), feature_set_tree_->end(), begin_idx);
+
+    feature_set_tree_ = ColSample(feature_set_tree_, colsample_bytree_);
   }

   /**
    * \brief Resets this object.
    */
   void Reset() {
-    feature_set_tree_.HostVector().clear();
+    feature_set_tree_->clear();
     feature_set_level_.clear();
   }

-  HostDeviceVector<int>& GetFeatureSet(int depth) {
-    if (this->colsample_bylevel_ == 1.0f) {
+  /**
+   * \brief Samples a feature set.
+   *
+   * \param depth The tree depth of the node at which to sample.
+   * \return The sampled feature set.
+   * \note If colsample_bynode_ < 1.0, this method creates a new feature set each time it
+   * is called. Therefore, it should be called only once per node.
+   */
+  std::shared_ptr<std::vector<int>> GetFeatureSet(int depth) {
+    if (colsample_bylevel_ == 1.0f && colsample_bynode_ == 1.0f) {
       return feature_set_tree_;
     }

     if (feature_set_level_.count(depth) == 0) {
       // Level sampling, level does not yet exist so generate it
-      auto& level = feature_set_level_[depth].HostVector();
-      level = ColSample(feature_set_tree_.HostVector(), this->colsample_bylevel_);
+      feature_set_level_[depth] = ColSample(feature_set_tree_, colsample_bylevel_);
     }
-    // Level sampling
-    return feature_set_level_[depth];
+    if (colsample_bynode_ == 1.0f) {
+      // Level sampling
+      return feature_set_level_[depth];
+    }
+    // Need to sample for the node individually
+    return ColSample(feature_set_level_[depth], colsample_bynode_);
   }
 };
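Taken together: Init() draws the per-tree sample, GetFeatureSet() caches one per-level sample per depth, and node sampling draws a fresh subset on every call. A sketch of how a caller is expected to consume the API (hypothetical code, not from this commit; EnumerateSplit and the include path are stand-ins):

#include <memory>
#include <vector>

#include "common/random.h"  // assumed include path for xgboost::common::ColumnSampler

void EnumerateSplit(int fidx);  // stand-in for the real per-feature split search

// Fetch the feature set exactly once per node: with colsample_bynode < 1,
// every GetFeatureSet() call draws a fresh sample.
void EvaluateNode(xgboost::common::ColumnSampler* sampler, int depth) {
  auto p_feature_set = sampler->GetFeatureSet(depth);  // shared_ptr keeps it alive
  for (int fidx : *p_feature_set) {
    EnumerateSplit(fidx);
  }
}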
src/tree/param.h
@@ -50,7 +50,9 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
   float max_delta_step;
   // whether we want to do subsample
   float subsample;
-  // whether to subsample columns each split, in each level
+  // whether to subsample columns in each split (node)
+  float colsample_bynode;
+  // whether to subsample columns in each level
   float colsample_bylevel;
   // whether to subsample columns during tree construction
   float colsample_bytree;
@@ -149,6 +151,10 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
         .set_range(0.0f, 1.0f)
         .set_default(1.0f)
         .describe("Row subsample ratio of training instance.");
+    DMLC_DECLARE_FIELD(colsample_bynode)
+        .set_range(0.0f, 1.0f)
+        .set_default(1.0f)
+        .describe("Subsample ratio of columns, resample on each node (split).");
     DMLC_DECLARE_FIELD(colsample_bylevel)
         .set_range(0.0f, 1.0f)
         .set_default(1.0f)
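For context on the registration above: DMLC_DECLARE_FIELD exposes a struct member as a named, range-checked, documented training parameter. A self-contained sketch of the same pattern (ToySamplingParam is hypothetical, not from this commit):

#include <dmlc/parameter.h>

#include <map>
#include <string>

// Hypothetical parameter struct illustrating the registration pattern.
struct ToySamplingParam : public dmlc::Parameter<ToySamplingParam> {
  float colsample_bynode;
  DMLC_DECLARE_PARAMETER(ToySamplingParam) {
    DMLC_DECLARE_FIELD(colsample_bynode)
        .set_range(0.0f, 1.0f)
        .set_default(1.0f)
        .describe("Subsample ratio of columns, resampled at each node (split).");
  }
};
DMLC_REGISTER_PARAMETER(ToySamplingParam);

int main() {
  ToySamplingParam param;
  // Values outside set_range() would make Init() raise a parameter error.
  param.Init(std::map<std::string, std::string>{{"colsample_bynode", "0.5"}});
  return 0;
}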
src/tree/updater_colmaker.cc
@@ -168,8 +168,8 @@ class ColMaker: public TreeUpdater {
       }
     }
     {
-      column_sampler_.Init(fmat.Info().num_col_, param_.colsample_bylevel,
-                           param_.colsample_bytree);
+      column_sampler_.Init(fmat.Info().num_col_, param_.colsample_bynode,
+                           param_.colsample_bylevel, param_.colsample_bytree);
     }
     {
       // setup temp space for each thread
@@ -625,7 +625,8 @@ class ColMaker: public TreeUpdater {
                            const std::vector<GradientPair> &gpair,
                            DMatrix *p_fmat,
                            RegTree *p_tree) {
-      const std::vector<int> &feat_set = column_sampler_.GetFeatureSet(depth).HostVector();
+      auto p_feature_set = column_sampler_.GetFeatureSet(depth);
+      const auto& feat_set = *p_feature_set;
       for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
         this->UpdateSolution(batch, feat_set, gpair, p_fmat);
       }
src/tree/updater_gpu_hist.cu
@@ -499,6 +499,8 @@ struct DeviceShard {
   dh::DVec<GradientPair> node_sum_gradients_d;
   /*! \brief row offset in SparsePage (the input data). */
   thrust::device_vector<size_t> row_ptrs;
+  /*! \brief On-device feature set, only actually used on one of the devices */
+  thrust::device_vector<int> feature_set_d;
   /*! The row offset for this shard. */
   bst_uint row_begin_idx;
   bst_uint row_end_idx;
@@ -579,28 +581,31 @@ struct DeviceShard {
   }

   DeviceSplitCandidate EvaluateSplit(int nidx,
-                                     const HostDeviceVector<int>& feature_set,
+                                     const std::vector<int>& feature_set,
                                      ValueConstraint value_constraint) {
     dh::safe_cuda(cudaSetDevice(device_id_));
-    auto d_split_candidates = temp_memory.GetSpan<DeviceSplitCandidate>(feature_set.Size());
+    auto d_split_candidates = temp_memory.GetSpan<DeviceSplitCandidate>(feature_set.size());
+    feature_set_d.resize(feature_set.size());
+    auto d_features = common::Span<int>(feature_set_d.data().get(),
+                                        feature_set_d.size());
+    dh::safe_cuda(cudaMemcpy(d_features.data(), feature_set.data(),
+                             d_features.size_bytes(), cudaMemcpyDefault));
     DeviceNodeStats node(node_sum_gradients[nidx], nidx, param);
-    feature_set.Reshard(GPUSet::Range(device_id_, 1));

     // One block for each feature
     int constexpr BLOCK_THREADS = 256;
     EvaluateSplitKernel<BLOCK_THREADS, GradientSumT>
-        <<<uint32_t(feature_set.Size()), BLOCK_THREADS, 0>>>(
-            hist.GetNodeHistogram(nidx), feature_set.DeviceSpan(device_id_), node,
+        <<<uint32_t(feature_set.size()), BLOCK_THREADS, 0>>>
+        (hist.GetNodeHistogram(nidx), d_features, node,
             cut_.feature_segments.GetSpan(), cut_.min_fvalue.GetSpan(),
             cut_.gidx_fvalue_map.GetSpan(), GPUTrainingParam(param),
             d_split_candidates, value_constraint, monotone_constraints.GetSpan());

     dh::safe_cuda(cudaDeviceSynchronize());
-    std::vector<DeviceSplitCandidate> split_candidates(feature_set.Size());
-    dh::safe_cuda(
-        cudaMemcpy(split_candidates.data(), d_split_candidates.data(),
-                   split_candidates.size() * sizeof(DeviceSplitCandidate),
-                   cudaMemcpyDeviceToHost));
+    std::vector<DeviceSplitCandidate> split_candidates(feature_set.size());
+    dh::safe_cuda(cudaMemcpy(split_candidates.data(), d_split_candidates.data(),
+                             split_candidates.size() * sizeof(DeviceSplitCandidate),
+                             cudaMemcpyDeviceToHost));
     DeviceSplitCandidate best_split;
     for (auto candidate : split_candidates) {
       best_split.Update(candidate, param);
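Because the sampler now hands back a plain host-side std::vector, the shard copies the feature ids into its own device buffer rather than resharding a HostDeviceVector. A standalone sketch of that copy (illustration only; assumes just the CUDA runtime and Thrust):

#include <thrust/device_vector.h>

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> feature_set{1, 4, 9, 16};  // host-side sampled feature ids
  thrust::device_vector<int> feature_set_d(feature_set.size());
  // cudaMemcpyDefault lets the runtime infer the host-to-device direction.
  cudaMemcpy(thrust::raw_pointer_cast(feature_set_d.data()), feature_set.data(),
             feature_set.size() * sizeof(int), cudaMemcpyDefault);
  std::printf("copied %zu feature ids to the device\n", feature_set_d.size());
  return 0;
}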
@@ -1009,7 +1014,8 @@ class GPUHistMakerSpecialised {
     }
     monitor_.Stop("InitDataOnce", dist_.Devices());

-    column_sampler_.Init(info_->num_col_, param_.colsample_bylevel, param_.colsample_bytree);
+    column_sampler_.Init(info_->num_col_, param_.colsample_bynode,
+                         param_.colsample_bylevel, param_.colsample_bytree);

     // Copy gpair & reset memory
     monitor_.Start("InitDataReset", dist_.Devices());
@@ -1100,7 +1106,7 @@ class GPUHistMakerSpecialised {

   DeviceSplitCandidate EvaluateSplit(int nidx, RegTree* p_tree) {
     return shards_.front()->EvaluateSplit(
-        nidx, column_sampler_.GetFeatureSet(p_tree->GetDepth(nidx)),
+        nidx, *column_sampler_.GetFeatureSet(p_tree->GetDepth(nidx)),
         node_value_constraints_[nidx]);
   }
src/tree/updater_quantile_hist.cc
@@ -354,11 +354,11 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
     p_last_fmat_ = &fmat;
     // initialize feature index
     if (data_layout_ == kDenseDataOneBased) {
-      column_sampler_.Init(info.num_col_, param_.colsample_bylevel,
-                           param_.colsample_bytree, true);
+      column_sampler_.Init(info.num_col_, param_.colsample_bynode,
+                           param_.colsample_bylevel, param_.colsample_bytree, true);
     } else {
-      column_sampler_.Init(info.num_col_, param_.colsample_bylevel,
-                           param_.colsample_bytree, false);
+      column_sampler_.Init(info.num_col_, param_.colsample_bynode,
+                           param_.colsample_bylevel, param_.colsample_bytree, false);
     }
   }
   if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
@@ -400,8 +400,8 @@ void QuantileHistMaker::Builder::EvaluateSplit(int nid,
                                                const RegTree& tree) {
   // start enumeration
   const MetaInfo& info = fmat.Info();
-  const auto& feature_set = column_sampler_.GetFeatureSet(
-      tree.GetDepth(nid)).HostVector();
+  auto p_feature_set = column_sampler_.GetFeatureSet(tree.GetDepth(nid));
+  const auto& feature_set = *p_feature_set;
   const auto nfeature = static_cast<bst_uint>(feature_set.size());
   const auto nthread = static_cast<bst_omp_uint>(this->nthread_);
   best_split_tloc_.resize(nthread);
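Binding the shared_ptr to a local before dereferencing, as above, is not cosmetic: with node sampling GetFeatureSet() returns a freshly allocated vector, and a reference bound directly to the dereferenced temporary would dangle once the statement ends. A minimal illustration (hypothetical code, not from this commit):

#include <memory>
#include <vector>

std::shared_ptr<std::vector<int>> MakeSet() {  // stand-in for GetFeatureSet()
  return std::make_shared<std::vector<int>>(std::vector<int>{1, 2, 3});
}

void Caller() {
  // BAD: the temporary shared_ptr dies at the end of this statement, so
  // 'dangling' refers to a vector that has already been freed.
  const std::vector<int>& dangling = *MakeSet();
  (void)dangling;

  // GOOD (the pattern used in the patch): a named shared_ptr keeps the
  // sampled vector alive for as long as the reference is used.
  auto p_set = MakeSet();
  const std::vector<int>& safe = *p_set;
  (void)safe;
}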
tests/cpp/common/test_random.cc
@@ -5,33 +5,45 @@
 namespace xgboost {
 namespace common {
 TEST(ColumnSampler, Test) {
-  int n = 100;
+  int n = 128;
   ColumnSampler cs;
-  cs.Init(n, 0.5f, 0.5f);
-  auto &set0 = cs.GetFeatureSet(0).HostVector();
-  ASSERT_EQ(set0.size(), 25);

-  auto &set1 = cs.GetFeatureSet(0).HostVector();
+  // No node sampling
+  cs.Init(n, 1.0f, 0.5f, 0.5f);
+  auto set0 = *cs.GetFeatureSet(0);
+  ASSERT_EQ(set0.size(), 32);
+
+  auto set1 = *cs.GetFeatureSet(0);
   ASSERT_EQ(set0, set1);

-  auto &set2 = cs.GetFeatureSet(1).HostVector();
+  auto set2 = *cs.GetFeatureSet(1);
   ASSERT_NE(set1, set2);
-  ASSERT_EQ(set2.size(), 25);
+  ASSERT_EQ(set2.size(), 32);

-  // No level sampling, should be the same at different depth
-  cs.Init(n, 1.0f, 0.5f);
-  ASSERT_EQ(cs.GetFeatureSet(0).HostVector(), cs.GetFeatureSet(1).HostVector());
-
-  cs.Init(n, 1.0f, 1.0f);
-  auto &set3 = cs.GetFeatureSet(0).HostVector();
-  ASSERT_EQ(set3.size(), n);
-  cs.Init(n, 1.0f, 1.0f);
-  auto &set4 = cs.GetFeatureSet(0).HostVector();
-  ASSERT_EQ(set3, set4);
+  // Node sampling
+  cs.Init(n, 0.5f, 1.0f, 0.5f);
+  auto set3 = *cs.GetFeatureSet(0);
+  ASSERT_EQ(set3.size(), 32);
+
+  auto set4 = *cs.GetFeatureSet(0);
+  ASSERT_NE(set3, set4);
+  ASSERT_EQ(set4.size(), 32);
+
+  // No level or node sampling, should be the same at different depth
+  cs.Init(n, 1.0f, 1.0f, 0.5f);
+  ASSERT_EQ(*cs.GetFeatureSet(0), *cs.GetFeatureSet(1));
+
+  cs.Init(n, 1.0f, 1.0f, 1.0f);
+  auto set5 = *cs.GetFeatureSet(0);
+  ASSERT_EQ(set5.size(), n);
+  cs.Init(n, 1.0f, 1.0f, 1.0f);
+  auto set6 = *cs.GetFeatureSet(0);
+  ASSERT_EQ(set5, set6);

   // Should always be a minimum of one feature
-  cs.Init(n, 1e-16f, 1e-16f);
-  ASSERT_EQ(cs.GetFeatureSet(0).HostVector().size(), 1);
+  cs.Init(n, 1e-16f, 1e-16f, 1e-16f);
+  ASSERT_EQ(cs.GetFeatureSet(0)->size(), 1);
 }
 }  // namespace common
 }  // namespace xgboost
tests/cpp/tree/test_gpu_hist.cu
@@ -227,6 +227,7 @@ TEST(GpuHist, EvaluateSplits) {
   TrainParam param;
   param.max_depth = 1;
   param.n_gpus = 1;
+  param.colsample_bynode = 1;
   param.colsample_bylevel = 1;
   param.colsample_bytree = 1;
   param.min_child_weight = 0.01;
@@ -284,6 +285,7 @@ TEST(GpuHist, EvaluateSplits) {
   hist_maker.param_ = param;
   hist_maker.shards_.push_back(std::move(shard));
   hist_maker.column_sampler_.Init(n_cols,
+                                  param.colsample_bynode,
                                   param.colsample_bylevel,
                                   param.colsample_bytree,
                                   false);