Make `HistCutMatrix::Init` be aware of groups. (#4115)

* Add checks for group size.
* Simple docs.
* Search group index during hist cut matrix initialization.

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2019-02-16 04:39:41 +08:00
committed by GitHub
parent 37ddfd7d6e
commit 754fe8142b
6 changed files with 188 additions and 22 deletions

View File

@@ -474,12 +474,16 @@ class LearnerImpl : public Learner {
// Run one full training iteration on `train`: validate the data, apply the
// tree-method heuristic, compute raw predictions, then (in code elided from
// this diff hunk) derive gradients and boost. Timing is recorded via monitor_.
void UpdateOneIter(int iter, DMatrix* train) override {
monitor_.Start("UpdateOneIter");
// TODO(trivialfis): Merge the duplicated code with BoostOneIter
CHECK(ModelInitialized())
<< "Always call InitModel or LoadModel before update";
// Re-seed the global RNG each iteration when explicitly requested, or when
// running distributed -- presumably to keep workers' random streams in
// sync; confirm against rabit usage elsewhere.
if (tparam_.seed_per_iteration || rabit::IsDistributed()) {
common::GlobalRandom().seed(tparam_.seed * kRandSeedMagic + iter);
}
// Enforce group/weight consistency (see ValidateDMatrix) before training.
this->ValidateDMatrix(train);
this->PerformTreeMethodHeuristic(train);
monitor_.Start("PredictRaw");
this->PredictRaw(train, &preds_);
monitor_.Stop("PredictRaw");
@@ -493,10 +497,15 @@ class LearnerImpl : public Learner {
// Run one boosting iteration with caller-supplied gradient pairs instead of
// gradients computed from the objective (contrast with UpdateOneIter).
void BoostOneIter(int iter, DMatrix* train,
                  HostDeviceVector<GradientPair>* in_gpair) override {
  monitor_.Start("BoostOneIter");
  CHECK(ModelInitialized()) << "Always call InitModel or LoadModel before boost.";
  // Re-seed the global RNG for this iteration when configured to do so, or
  // whenever we run distributed -- presumably for reproducibility across
  // workers; confirm.
  const bool reseed_rng =
      tparam_.seed_per_iteration || rabit::IsDistributed();
  if (reseed_rng) {
    common::GlobalRandom().seed(tparam_.seed * kRandSeedMagic + iter);
  }
  // Check data consistency (e.g. group/weight sizes) before boosting.
  this->ValidateDMatrix(train);
  this->PerformTreeMethodHeuristic(train);
  gbm_->DoBoost(train, in_gpair);
  monitor_.Stop("BoostOneIter");
}
@@ -711,7 +720,7 @@ class LearnerImpl : public Learner {
mparam_.num_feature = num_feature;
}
CHECK_NE(mparam_.num_feature, 0)
<< "0 feature is supplied. Are you using raw Booster?";
<< "0 feature is supplied. Are you using raw Booster interface?";
// setup
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
CHECK(obj_ == nullptr && gbm_ == nullptr);
@@ -736,6 +745,19 @@ class LearnerImpl : public Learner {
gbm_->PredictBatch(data, out_preds, ntree_limit);
}
// Validate a DMatrix's metadata before it is used for training/boosting.
// For ranking tasks where both query groups and instance weights are set,
// each weight applies to one whole group, so the number of weights must
// equal the number of groups. group_ptr_ holds num_groups + 1 offsets,
// hence the `- 1`. Fails via CHECK (fatal) on mismatch.
void ValidateDMatrix(DMatrix* p_fmat) {
MetaInfo const& info = p_fmat->Info();
auto const& weights = info.weights_.HostVector();
// Only enforce the constraint when both groups and weights are present;
// either may legitimately be empty on its own.
if (info.group_ptr_.size() != 0 && weights.size() != 0) {
  // CHECK_EQ matches the macro idiom used elsewhere in this file and
  // reports both operands on failure.
  CHECK_EQ(weights.size(), info.group_ptr_.size() - 1)
      << "\n"
      << "weights size: " << weights.size() << ", "
      << "groups size: " << info.group_ptr_.size() - 1 << ", "
      << "num rows: " << info.num_row_ << "\n"
      << "Number of weights should be equal to number of groups in ranking task.";
}
}
// model parameter
LearnerModelParam mparam_;
// training parameter