Feature weights (#5962)
This commit is contained in:
@@ -235,8 +235,10 @@ class ColMaker: public TreeUpdater {
|
||||
}
|
||||
}
|
||||
{
|
||||
column_sampler_.Init(fmat.Info().num_col_, param_.colsample_bynode,
|
||||
param_.colsample_bylevel, param_.colsample_bytree);
|
||||
column_sampler_.Init(fmat.Info().num_col_,
|
||||
fmat.Info().feature_weigths.ConstHostVector(),
|
||||
param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree);
|
||||
}
|
||||
{
|
||||
// setup temp space for each thread
|
||||
|
||||
@@ -266,8 +266,10 @@ struct GPUHistMakerDevice {
|
||||
// Note that the column sampler must be passed by value because it is not
|
||||
// thread safe
|
||||
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
|
||||
this->column_sampler.Init(num_columns, param.colsample_bynode,
|
||||
param.colsample_bylevel, param.colsample_bytree);
|
||||
auto const& info = dmat->Info();
|
||||
this->column_sampler.Init(num_columns, info.feature_weigths.HostVector(),
|
||||
param.colsample_bynode, param.colsample_bylevel,
|
||||
param.colsample_bytree);
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
this->interaction_constraints.Reset();
|
||||
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
|
||||
|
||||
@@ -841,11 +841,13 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
|
||||
// store a pointer to the tree
|
||||
p_last_tree_ = &tree;
|
||||
if (data_layout_ == kDenseDataOneBased) {
|
||||
column_sampler_.Init(info.num_col_, param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree, true);
|
||||
column_sampler_.Init(info.num_col_, info.feature_weigths.ConstHostVector(),
|
||||
param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree, true);
|
||||
} else {
|
||||
column_sampler_.Init(info.num_col_, param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree, false);
|
||||
column_sampler_.Init(info.num_col_, info.feature_weigths.ConstHostVector(),
|
||||
param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree, false);
|
||||
}
|
||||
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
|
||||
/* specialized code for dense data:
|
||||
|
||||
Reference in New Issue
Block a user