Enforce correct data shape. (#5191)

* Fix syncing DMatrix columns.
* notes for tree method.
* Enable feature validation for all interfaces except for jvm.
* Better tests for boosting from predictions.
* Disable validation on JVM.
This commit is contained in:
Jiaming Yuan
2020-01-13 15:48:17 +08:00
committed by GitHub
parent 8cbcc53ccb
commit 7b65698187
14 changed files with 108 additions and 60 deletions

View File

@@ -295,32 +295,8 @@ void DMatrix::SaveToLocalFile(const std::string& fname) {
DMatrix* DMatrix::Create(std::unique_ptr<DataSource<SparsePage>>&& source,
const std::string& cache_prefix) {
if (cache_prefix.length() == 0) {
// FIXME(trivialfis): Currently distcol is broken so we here check for number of rows.
// If we bring back column split this check will break.
bool is_distributed { rabit::IsDistributed() };
if (is_distributed) {
auto world_size = rabit::GetWorldSize();
auto rank = rabit::GetRank();
std::vector<uint64_t> ncols(world_size, 0);
ncols[rank] = source->info.num_col_;
rabit::Allreduce<rabit::op::Sum>(ncols.data(), ncols.size());
auto max_cols = std::max_element(ncols.cbegin(), ncols.cend());
auto max_ind = std::distance(ncols.cbegin(), max_cols);
// FIXME(trivialfis): This is a hack, we should store a reference to global shape if possible.
if (source->info.num_col_ == 0 && source->info.num_row_ == 0) {
LOG(WARNING) << "DMatrix at rank: " << rank << " worker is empty.";
source->info.num_col_ = *max_cols;
}
// validate the number of columns across all workers.
for (size_t i = 0; i < ncols.size(); ++i) {
auto v = ncols[i];
CHECK(v == 0 || v == *max_cols)
<< "DMatrix at rank: " << i << " worker "
<< "has different number of columns than rank: " << max_ind << " worker. "
<< "(" << v << " vs. " << *max_cols << ")";
}
}
// Data split mode is fixed to be row right now.
rabit::Allreduce<rabit::op::Max>(&source->info.num_col_, 1);
return new data::SimpleDMatrix(std::move(source));
} else {
#if DMLC_ENABLE_STD_THREAD
@@ -336,6 +312,7 @@ template <typename AdapterT>
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size ) {
if (cache_prefix.length() == 0) {
// Data split mode is fixed to be row right now.
return new data::SimpleDMatrix(adapter, missing, nthread);
} else {
#if DMLC_ENABLE_STD_THREAD