Enforce correct data shape. (#5191)
* Fix syncing DMatrix columns. * notes for tree method. * Enable feature validation for all interfaces except for jvm. * Better tests for boosting from predictions. * Disable validation on JVM.
This commit is contained in:
@@ -71,6 +71,12 @@ class TrainingObserver {
|
||||
auto const& h_vec = vec.HostVector();
|
||||
this->Observe(h_vec, name);
|
||||
}
|
||||
template <typename T>
|
||||
void Observe(HostDeviceVector<T>* vec, std::string name) const {
|
||||
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||
this->Observe(*vec, name);
|
||||
}
|
||||
|
||||
/*\brief Observe objects with `XGBoostParamer' type. */
|
||||
template <typename Parameter,
|
||||
typename std::enable_if<
|
||||
|
||||
@@ -295,32 +295,8 @@ void DMatrix::SaveToLocalFile(const std::string& fname) {
|
||||
DMatrix* DMatrix::Create(std::unique_ptr<DataSource<SparsePage>>&& source,
|
||||
const std::string& cache_prefix) {
|
||||
if (cache_prefix.length() == 0) {
|
||||
// FIXME(trivialfis): Currently distcol is broken so we here check for number of rows.
|
||||
// If we bring back column split this check will break.
|
||||
bool is_distributed { rabit::IsDistributed() };
|
||||
if (is_distributed) {
|
||||
auto world_size = rabit::GetWorldSize();
|
||||
auto rank = rabit::GetRank();
|
||||
std::vector<uint64_t> ncols(world_size, 0);
|
||||
ncols[rank] = source->info.num_col_;
|
||||
rabit::Allreduce<rabit::op::Sum>(ncols.data(), ncols.size());
|
||||
auto max_cols = std::max_element(ncols.cbegin(), ncols.cend());
|
||||
auto max_ind = std::distance(ncols.cbegin(), max_cols);
|
||||
// FIXME(trivialfis): This is a hack, we should store a reference to global shape if possible.
|
||||
if (source->info.num_col_ == 0 && source->info.num_row_ == 0) {
|
||||
LOG(WARNING) << "DMatrix at rank: " << rank << " worker is empty.";
|
||||
source->info.num_col_ = *max_cols;
|
||||
}
|
||||
|
||||
// validate the number of columns across all workers.
|
||||
for (size_t i = 0; i < ncols.size(); ++i) {
|
||||
auto v = ncols[i];
|
||||
CHECK(v == 0 || v == *max_cols)
|
||||
<< "DMatrix at rank: " << i << " worker "
|
||||
<< "has different number of columns than rank: " << max_ind << " worker. "
|
||||
<< "(" << v << " vs. " << *max_cols << ")";
|
||||
}
|
||||
}
|
||||
// Data split mode is fixed to be row right now.
|
||||
rabit::Allreduce<rabit::op::Max>(&source->info.num_col_, 1);
|
||||
return new data::SimpleDMatrix(std::move(source));
|
||||
} else {
|
||||
#if DMLC_ENABLE_STD_THREAD
|
||||
@@ -336,6 +312,7 @@ template <typename AdapterT>
|
||||
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
|
||||
const std::string& cache_prefix, size_t page_size ) {
|
||||
if (cache_prefix.length() == 0) {
|
||||
// Data split mode is fixed to be row right now.
|
||||
return new data::SimpleDMatrix(adapter, missing, nthread);
|
||||
} else {
|
||||
#if DMLC_ENABLE_STD_THREAD
|
||||
|
||||
@@ -124,7 +124,6 @@ void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) {
|
||||
return;
|
||||
}
|
||||
|
||||
tparam_.updater_seq = "grow_histmaker,prune";
|
||||
if (rabit::IsDistributed()) {
|
||||
LOG(WARNING) <<
|
||||
"Tree method is automatically selected to be 'approx' "
|
||||
|
||||
@@ -925,6 +925,25 @@ class LearnerImpl : public Learner {
|
||||
<< "num rows: " << p_fmat->Info().num_row_ << "\n"
|
||||
<< "Number of weights should be equal to number of groups in ranking task.";
|
||||
}
|
||||
|
||||
auto const row_based_split = [this]() {
|
||||
return tparam_.dsplit == DataSplitMode::kRow ||
|
||||
tparam_.dsplit == DataSplitMode::kAuto;
|
||||
};
|
||||
bool const valid_features =
|
||||
!row_based_split() ||
|
||||
(learner_model_param_.num_feature == p_fmat->Info().num_col_);
|
||||
std::string const msg {
|
||||
"Number of columns does not match number of features in booster."
|
||||
};
|
||||
if (generic_parameters_.validate_features) {
|
||||
CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_) << msg;
|
||||
} else if (!valid_features) {
|
||||
// Remove this and make the equality check fatal once spark can fix all failing tests.
|
||||
LOG(WARNING) << msg << " "
|
||||
<< "Columns: " << p_fmat->Info().num_col_ << " "
|
||||
<< "Features: " << learner_model_param_.num_feature;
|
||||
}
|
||||
}
|
||||
|
||||
// model parameter
|
||||
|
||||
@@ -80,6 +80,10 @@ class ColMaker: public TreeUpdater {
|
||||
void Update(HostDeviceVector<GradientPair> *gpair,
|
||||
DMatrix* dmat,
|
||||
const std::vector<RegTree*> &trees) override {
|
||||
if (rabit::IsDistributed()) {
|
||||
LOG(FATAL) << "Updater `grow_colmaker` or `exact` tree method doesn't "
|
||||
"support distributed training.";
|
||||
}
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param_.learning_rate;
|
||||
param_.learning_rate = lr / trees.size();
|
||||
|
||||
Reference in New Issue
Block a user