Feature weights (#5962)

This commit is contained in:
Jiaming Yuan
2020-08-18 19:55:41 +08:00
committed by GitHub
parent a418278064
commit 4d99c58a5f
25 changed files with 509 additions and 104 deletions

View File

@@ -235,8 +235,10 @@ class ColMaker: public TreeUpdater {
}
}
{
column_sampler_.Init(fmat.Info().num_col_, param_.colsample_bynode,
param_.colsample_bylevel, param_.colsample_bytree);
column_sampler_.Init(fmat.Info().num_col_,
fmat.Info().feature_weigths.ConstHostVector(),
param_.colsample_bynode, param_.colsample_bylevel,
param_.colsample_bytree);
}
{
// setup temp space for each thread

View File

@@ -266,8 +266,10 @@ struct GPUHistMakerDevice {
// Note that the column sampler must be passed by value because it is not
// thread safe
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
this->column_sampler.Init(num_columns, param.colsample_bynode,
param.colsample_bylevel, param.colsample_bytree);
auto const& info = dmat->Info();
this->column_sampler.Init(num_columns, info.feature_weigths.HostVector(),
param.colsample_bynode, param.colsample_bylevel,
param.colsample_bytree);
dh::safe_cuda(cudaSetDevice(device_id));
this->interaction_constraints.Reset();
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),

View File

@@ -841,11 +841,13 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
// store a pointer to the tree
p_last_tree_ = &tree;
if (data_layout_ == kDenseDataOneBased) {
column_sampler_.Init(info.num_col_, param_.colsample_bynode, param_.colsample_bylevel,
param_.colsample_bytree, true);
column_sampler_.Init(info.num_col_, info.feature_weigths.ConstHostVector(),
param_.colsample_bynode, param_.colsample_bylevel,
param_.colsample_bytree, true);
} else {
column_sampler_.Init(info.num_col_, param_.colsample_bynode, param_.colsample_bylevel,
param_.colsample_bytree, false);
column_sampler_.Init(info.num_col_, info.feature_weigths.ConstHostVector(),
param_.colsample_bynode, param_.colsample_bylevel,
param_.colsample_bytree, false);
}
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
/* specialized code for dense data: