Fix empty DMatrix with categorical features. (#8739)
src/data/iterative_dmatrix.cc
@@ -3,15 +3,20 @@
  */
 #include "iterative_dmatrix.h"
 
 #include <algorithm>  // std::copy
+#include <cstddef>  // std::size_t
+#include <type_traits>  // std::underlying_type_t
+#include <vector>  // std::vector
 
 #include "../collective/communicator-inl.h"
+#include "../common/categorical.h"  // common::IsCat
 #include "../common/column_matrix.h"
 #include "../tree/param.h"  // FIXME(jiamingy): Find a better way to share this parameter.
 #include "gradient_index.h"
 #include "proxy_dmatrix.h"
 #include "simple_batch_iterator.h"
+#include "xgboost/data.h"  // FeatureType
 #include "xgboost/logging.h"
 
 namespace xgboost {
 namespace data {
@@ -79,6 +84,27 @@ void GetCutsFromRef(std::shared_ptr<DMatrix> ref_, bst_feature_t n_features, Bat
       << "Invalid ref DMatrix, different number of features.";
 }
 
+namespace {
+// Synchronize feature type in case of empty DMatrix
+void SyncFeatureType(std::vector<FeatureType>* p_h_ft) {
+  if (!collective::IsDistributed()) {
+    return;
+  }
+  auto& h_ft = *p_h_ft;
+  auto n_ft = h_ft.size();
+  collective::Allreduce<collective::Operation::kMax>(&n_ft, 1);
+  if (!h_ft.empty()) {
+    // Check correct size if this is not an empty DMatrix.
+    CHECK_EQ(h_ft.size(), n_ft);
+  }
+  if (n_ft > 0) {
+    h_ft.resize(n_ft);
+    auto ptr = reinterpret_cast<std::underlying_type_t<FeatureType>*>(h_ft.data());
+    collective::Allreduce<collective::Operation::kMax>(ptr, h_ft.size());
+  }
+}
+}  // anonymous namespace
+
 void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
                                    std::shared_ptr<DMatrix> ref) {
   DMatrixProxy* proxy = MakeProxy(proxy_);
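The hunk above is the heart of the fix: when one worker holds an empty DMatrix, its feature-type vector is empty, so SyncFeatureType first takes an allreduce-max over the vector lengths and then over the type values themselves, letting the empty worker adopt the categorical flags seen on its peers. The stand-alone sketch below mimics that reduction locally; the enum and ReduceFeatureTypes are illustrative stand-ins, not xgboost's collective API.

#include <algorithm>  // std::max
#include <cstddef>    // std::size_t
#include <cstdint>    // std::uint8_t
#include <iostream>
#include <vector>

// Illustrative stand-in for xgboost's FeatureType enum.
enum class FeatureType : std::uint8_t { kNumerical = 0, kCategorical = 1 };

// Element-wise max over every worker's feature-type vector, mimicking
// Allreduce<kMax> on the underlying integers: an empty worker contributes
// nothing and simply adopts the reduced result.
std::vector<FeatureType> ReduceFeatureTypes(
    std::vector<std::vector<FeatureType>> const& per_worker) {
  // Step 1: reduce the sizes; a worker with an empty DMatrix reports 0.
  std::size_t n_ft = 0;
  for (auto const& ft : per_worker) n_ft = std::max(n_ft, ft.size());
  // Step 2: reduce the values; kCategorical (1) wins over kNumerical (0).
  std::vector<FeatureType> result(n_ft, FeatureType::kNumerical);
  for (auto const& ft : per_worker) {
    for (std::size_t i = 0; i < ft.size(); ++i) {
      result[i] = static_cast<FeatureType>(
          std::max(static_cast<std::uint8_t>(result[i]),
                   static_cast<std::uint8_t>(ft[i])));
    }
  }
  return result;
}

int main() {
  // Worker 0 saw no data; worker 1 saw a numerical and a categorical column.
  auto synced = ReduceFeatureTypes(
      {{}, {FeatureType::kNumerical, FeatureType::kCategorical}});
  for (auto ft : synced) {
    std::cout << (ft == FeatureType::kCategorical ? "cat " : "num ");
  }
  std::cout << "\n";  // prints: num cat
}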
@@ -96,13 +122,14 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
     return HostAdapterDispatch(proxy, [](auto const& value) { return value.NumCols(); });
   };
-  std::vector<size_t> column_sizes;
+
+  std::vector<std::size_t> column_sizes;
   auto const is_valid = data::IsValidFunctor{missing};
   auto nnz_cnt = [&]() {
     return HostAdapterDispatch(proxy, [&](auto const& value) {
       size_t n_threads = ctx_.Threads();
       size_t n_features = column_sizes.size();
-      linalg::Tensor<size_t, 2> column_sizes_tloc({n_threads, n_features}, Context::kCpuId);
+      linalg::Tensor<std::size_t, 2> column_sizes_tloc({n_threads, n_features}, Context::kCpuId);
       column_sizes_tloc.Data()->Fill(0ul);
       auto view = column_sizes_tloc.HostView();
       common::ParallelFor(value.Size(), n_threads, common::Sched::Static(256), [&](auto i) {
         auto const& line = value.GetLine(i);
@@ -139,7 +166,8 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
   if (n_features == 0) {
     n_features = num_cols();
     collective::Allreduce<collective::Operation::kMax>(&n_features, 1);
-    column_sizes.resize(n_features);
+    column_sizes.clear();
+    column_sizes.resize(n_features, 0);
     info_.num_col_ = n_features;
   } else {
     CHECK_EQ(n_features, num_cols()) << "Inconsistent number of columns.";
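The clear()/resize(n, 0) pair replaces a bare resize(n) to make zero-initialization of the per-column counters explicit: std::vector::resize value-initializes only elements appended beyond the current size, so any existing entries keep their old values. A minimal standard-library sketch of the distinction (not part of the commit):

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  std::vector<std::size_t> counts{3, 4};     // pretend these are stale counts
  counts.resize(4);                          // value-initializes only the two NEW elements
  assert(counts[0] == 3 && counts[1] == 4);  // old entries survive
  counts.clear();
  counts.resize(4, 0);                       // now every entry is guaranteed to start at zero
  for (auto c : counts) assert(c == 0);
  return 0;
}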
@@ -166,14 +194,18 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
    * Generate quantiles
    */
   accumulated_rows = 0;
+  std::vector<FeatureType> h_ft;
   if (ref) {
     GetCutsFromRef(ref, Info().num_col_, batch_param_, &cuts);
+    h_ft = ref->Info().feature_types.HostVector();
   } else {
     size_t i = 0;
     while (iter.Next()) {
       if (!p_sketch) {
+        h_ft = proxy->Info().feature_types.ConstHostVector();
+        SyncFeatureType(&h_ft);
         p_sketch.reset(new common::HostSketchContainer{
-            batch_param_.max_bin, proxy->Info().feature_types.ConstHostSpan(), column_sizes, false,
+            batch_param_.max_bin, h_ft, column_sizes, false,
             proxy->Info().data_split_mode == DataSplitMode::kCol, ctx_.Threads()});
       }
       HostAdapterDispatch(proxy, [&](auto const& batch) {
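Passing the synchronized h_ft into HostSketchContainer matters because the feature type changes how cut values are produced: categorical columns keep their distinct category values as split candidates, while numerical columns are quantile-sketched. The toy below illustrates that dispatch under simplified assumptions; MakeCuts is a deliberately crude stand-in, not xgboost's sketching code.

#include <algorithm>  // std::sort
#include <iostream>
#include <set>
#include <vector>

enum class FeatureType { kNumerical, kCategorical };

// Toy cut finder: a categorical column yields one cut per distinct
// category, a numerical column a few "quantile" points (min/median/max).
std::vector<float> MakeCuts(std::vector<float> column, FeatureType ft) {
  std::sort(column.begin(), column.end());
  if (ft == FeatureType::kCategorical) {
    std::set<float> cats(column.begin(), column.end());
    return {cats.begin(), cats.end()};
  }
  return {column.front(), column[column.size() / 2], column.back()};
}

int main() {
  std::vector<float> col{5, 5, 7, 5, 7, 5};
  auto cat_cuts = MakeCuts(col, FeatureType::kCategorical);  // {5, 7}
  auto num_cuts = MakeCuts(col, FeatureType::kNumerical);    // {5, 5, 7}
  std::cout << cat_cuts.size() << " categorical cuts vs "
            << num_cuts.size() << " numerical cuts\n";
}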
@@ -191,6 +223,9 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
     CHECK(p_sketch);
     p_sketch->MakeCuts(&cuts);
   }
+  if (!h_ft.empty()) {
+    CHECK_EQ(h_ft.size(), n_features);
+  }
 
   /**
    * Generate gradient index.
@@ -202,8 +237,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
   while (iter.Next()) {
     HostAdapterDispatch(proxy, [&](auto const& batch) {
       proxy->Info().num_nonzero_ = batch_nnz[i];
-      this->ghist_->PushAdapterBatch(&ctx_, rbegin, prev_sum, batch, missing,
-                                     proxy->Info().feature_types.ConstHostSpan(),
+      this->ghist_->PushAdapterBatch(&ctx_, rbegin, prev_sum, batch, missing, h_ft,
                                      batch_param_.sparse_thresh, Info().num_row_);
     });
     if (n_batches != 1) {
@@ -236,6 +270,8 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
     this->info_.num_col_ = n_features;  // proxy might be empty.
     CHECK_EQ(proxy->Info().labels.Size(), 0);
   }
+
+  Info().feature_types.HostVector() = h_ft;
 }
 
 BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const& param) {