Move GHistIndex into DMatrix. (#7064)
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
#include "../common/random.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "adapter.h"
|
||||
#include "gradient_index.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
@@ -89,6 +90,20 @@ BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param)
|
||||
return BatchSet<EllpackPage>(begin_iter);
|
||||
}
|
||||
|
||||
BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& param) {
|
||||
if (!(batch_param_ != BatchParam{})) {
|
||||
CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
|
||||
}
|
||||
if (!gradient_index_ || (batch_param_ != param && param != BatchParam{})) {
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin));
|
||||
batch_param_ = param;
|
||||
}
|
||||
auto begin_iter = BatchIterator<GHistIndexMatrix>(
|
||||
new SimpleBatchIteratorImpl<GHistIndexMatrix>(gradient_index_.get()));
|
||||
return BatchSet<GHistIndexMatrix>(begin_iter);
|
||||
}
|
||||
|
||||
template <typename AdapterT>
|
||||
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||
std::vector<uint64_t> qids;
|
||||
|
||||
Reference in New Issue
Block a user