Make QuantileDMatrix default to sklearn esitmators. (#8220)
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
|
||||
#include "../common/column_matrix.h"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
|
||||
#include "gradient_index.h"
|
||||
#include "proxy_dmatrix.h"
|
||||
#include "simple_batch_iterator.h"
|
||||
@@ -14,6 +15,38 @@
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
|
||||
std::shared_ptr<DMatrix> ref, DataIterResetCallback* reset,
|
||||
XGDMatrixCallbackNext* next, float missing, int nthread,
|
||||
bst_bin_t max_bin)
|
||||
: proxy_{proxy}, reset_{reset}, next_{next} {
|
||||
// fetch the first batch
|
||||
auto iter =
|
||||
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{iter_handle, reset_, next_};
|
||||
iter.Reset();
|
||||
bool valid = iter.Next();
|
||||
CHECK(valid) << "Iterative DMatrix must have at least 1 batch.";
|
||||
|
||||
auto d = MakeProxy(proxy_)->DeviceIdx();
|
||||
|
||||
StringView msg{"All batch should be on the same device."};
|
||||
if (batch_param_.gpu_id != Context::kCpuId) {
|
||||
CHECK_EQ(d, batch_param_.gpu_id) << msg;
|
||||
}
|
||||
|
||||
batch_param_ = BatchParam{d, max_bin};
|
||||
// hardcoded parameter.
|
||||
batch_param_.sparse_thresh = tree::TrainParam::DftSparseThreshold();
|
||||
|
||||
ctx_.UpdateAllowUnknown(
|
||||
Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}});
|
||||
if (ctx_.IsCPU()) {
|
||||
this->InitFromCPU(iter_handle, missing, ref);
|
||||
} else {
|
||||
this->InitFromCUDA(iter_handle, missing, ref);
|
||||
}
|
||||
}
|
||||
|
||||
void GetCutsFromRef(std::shared_ptr<DMatrix> ref_, bst_feature_t n_features, BatchParam p,
|
||||
common::HistogramCuts* p_cuts) {
|
||||
CHECK(ref_);
|
||||
@@ -199,6 +232,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
if (n_batches == 1) {
|
||||
this->info_ = std::move(proxy->Info());
|
||||
this->info_.num_nonzero_ = nnz;
|
||||
this->info_.num_col_ = n_features; // proxy might be empty.
|
||||
CHECK_EQ(proxy->Info().labels.Size(), 0);
|
||||
}
|
||||
}
|
||||
@@ -210,6 +244,10 @@ BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const&
|
||||
ghist_ = std::make_shared<GHistIndexMatrix>(&ctx_, Info(), *ellpack_, param);
|
||||
}
|
||||
|
||||
if (param.sparse_thresh != tree::TrainParam::DftSparseThreshold()) {
|
||||
LOG(WARNING) << "`sparse_threshold` can not be changed when `QuantileDMatrix` is used instead "
|
||||
"of `DMatrix`.";
|
||||
}
|
||||
auto begin_iter =
|
||||
BatchIterator<GHistIndexMatrix>(new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_));
|
||||
return BatchSet<GHistIndexMatrix>(begin_iter);
|
||||
|
||||
Reference in New Issue
Block a user