Make `HistCutMatrix::Init' be aware of groups. (#4115)
* Add checks for group size. * Simple docs. * Search group index during hist cut matrix initialization. Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com> Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -474,12 +474,16 @@ class LearnerImpl : public Learner {
|
||||
|
||||
void UpdateOneIter(int iter, DMatrix* train) override {
|
||||
monitor_.Start("UpdateOneIter");
|
||||
|
||||
// TODO(trivialfis): Merge the duplicated code with BoostOneIter
|
||||
CHECK(ModelInitialized())
|
||||
<< "Always call InitModel or LoadModel before update";
|
||||
if (tparam_.seed_per_iteration || rabit::IsDistributed()) {
|
||||
common::GlobalRandom().seed(tparam_.seed * kRandSeedMagic + iter);
|
||||
}
|
||||
this->ValidateDMatrix(train);
|
||||
this->PerformTreeMethodHeuristic(train);
|
||||
|
||||
monitor_.Start("PredictRaw");
|
||||
this->PredictRaw(train, &preds_);
|
||||
monitor_.Stop("PredictRaw");
|
||||
@@ -493,10 +497,15 @@ class LearnerImpl : public Learner {
|
||||
void BoostOneIter(int iter, DMatrix* train,
|
||||
HostDeviceVector<GradientPair>* in_gpair) override {
|
||||
monitor_.Start("BoostOneIter");
|
||||
|
||||
CHECK(ModelInitialized())
|
||||
<< "Always call InitModel or LoadModel before boost.";
|
||||
if (tparam_.seed_per_iteration || rabit::IsDistributed()) {
|
||||
common::GlobalRandom().seed(tparam_.seed * kRandSeedMagic + iter);
|
||||
}
|
||||
this->ValidateDMatrix(train);
|
||||
this->PerformTreeMethodHeuristic(train);
|
||||
|
||||
gbm_->DoBoost(train, in_gpair);
|
||||
monitor_.Stop("BoostOneIter");
|
||||
}
|
||||
@@ -711,7 +720,7 @@ class LearnerImpl : public Learner {
|
||||
mparam_.num_feature = num_feature;
|
||||
}
|
||||
CHECK_NE(mparam_.num_feature, 0)
|
||||
<< "0 feature is supplied. Are you using raw Booster?";
|
||||
<< "0 feature is supplied. Are you using raw Booster interface?";
|
||||
// setup
|
||||
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
||||
CHECK(obj_ == nullptr && gbm_ == nullptr);
|
||||
@@ -736,6 +745,19 @@ class LearnerImpl : public Learner {
|
||||
gbm_->PredictBatch(data, out_preds, ntree_limit);
|
||||
}
|
||||
|
||||
void ValidateDMatrix(DMatrix* p_fmat) {
|
||||
MetaInfo const& info = p_fmat->Info();
|
||||
auto const& weights = info.weights_.HostVector();
|
||||
if (info.group_ptr_.size() != 0 && weights.size() != 0) {
|
||||
CHECK(weights.size() == info.group_ptr_.size() - 1)
|
||||
<< "\n"
|
||||
<< "weights size: " << weights.size() << ", "
|
||||
<< "groups size: " << info.group_ptr_.size() -1 << ", "
|
||||
<< "num rows: " << p_fmat->Info().num_row_ << "\n"
|
||||
<< "Number of weights should be equal to number of groups in ranking task.";
|
||||
}
|
||||
}
|
||||
|
||||
// model parameter
|
||||
LearnerModelParam mparam_;
|
||||
// training parameter
|
||||
|
||||
Reference in New Issue
Block a user