Remove unnecessary DMatrix methods (#5324)

This commit is contained in:
Rory Mitchell
2020-02-25 12:40:39 +13:00
committed by GitHub
parent 655cf17b60
commit b0ed3f0a66
10 changed files with 43 additions and 72 deletions

View File

@@ -16,21 +16,6 @@ MetaInfo& SimpleDMatrix::Info() { return info; }
const MetaInfo& SimpleDMatrix::Info() const { return info; }
/*!
 * \brief Density of column `cidx`: one minus the ratio of rows without an
 *        entry in that column to the total number of rows.
 * \param cidx index of the column to inspect.
 * \return 1 - (num_row_ - column size) / num_row_, computed in float.
 */
float SimpleDMatrix::GetColDensity(size_t cidx) {
  size_t n_present = 0;
  // Read the column from whichever column representation already exists:
  // the sorted page when `sorted_column_page_` is set, CSCPage batches
  // otherwise. Only the first batch is consulted in either case.
  if (!sorted_column_page_) {
    auto csc_batches = this->GetBatches<CSCPage>();
    n_present = (*csc_batches.begin())[cidx].size();
  } else {
    auto sorted_batches = this->GetBatches<SortedCSCPage>();
    n_present = (*sorted_batches.begin())[cidx].size();
  }
  // Rows that contribute no entry to this column.
  size_t const n_missing = this->Info().num_row_ - n_present;
  return 1.0f - (static_cast<float>(n_missing)) / this->Info().num_row_;
}
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
// CSR is the default data structure, so `source_` is always available.
auto begin_iter = BatchIterator<SparsePage>(
@@ -76,8 +61,6 @@ BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param)
return BatchSet<EllpackPage>(begin_iter);
}
// Always reports true for SimpleDMatrix.
// NOTE(review): presumably signals that all column data is available as a
// single batch — confirm against the DMatrix interface.
bool SimpleDMatrix::SingleColBlock() const { return true; }
template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
// Set number of threads but keep old value so we can reset it after

View File

@@ -30,9 +30,7 @@ class SimpleDMatrix : public DMatrix {
const MetaInfo& Info() const override;
float GetColDensity(size_t cidx) override;
bool SingleColBlock() const override;
/*! \brief always true for SimpleDMatrix (presumably: column data forms a single block — confirm against DMatrix) */
bool SingleColBlock() const override { return true; }
/*! \brief magic number used to identify SimpleDMatrix binary files */
static const int kMagic = 0xffffab01;

View File

@@ -58,28 +58,6 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& par
return BatchSet<EllpackPage>(begin_iter);
}
/*!
 * \brief Density of column `cidx`, served from `col_density_`.
 *
 * While `col_density_` is empty, the densities of all columns are computed
 * from the CSCPage batches and stored; later calls reuse the stored values.
 * \param cidx index of the column to look up.
 * \return the stored density for `cidx`.
 */
float SparsePageDMatrix::GetColDensity(size_t cidx) {
  // Fast path: densities were already computed by an earlier call.
  if (!col_density_.empty()) {
    return col_density_.at(cidx);
  }

  // Sum the size of every column over all CSC batches.
  std::vector<size_t> entries_per_column(this->Info().num_col_);
  for (const auto &page : this->GetBatches<CSCPage>()) {
    for (unsigned col = 0; col < page.Size(); ++col) {
      entries_per_column[col] += page[col].size();
    }
  }

  // Convert the per-column counts into densities relative to the row count.
  col_density_.resize(entries_per_column.size());
  for (unsigned col = 0; col < col_density_.size(); ++col) {
    size_t const n_missing = this->Info().num_row_ - entries_per_column[col];
    col_density_[col] =
        1.0f - (static_cast<float>(n_missing)) / this->Info().num_row_;
  }
  return col_density_.at(cidx);
}
// Always reports false for SparsePageDMatrix.
// NOTE(review): presumably because column data may span several batches —
// confirm against the DMatrix interface.
bool SparsePageDMatrix::SingleColBlock() const { return false; }
} // namespace data
} // namespace xgboost
#endif // DMLC_ENABLE_STD_THREAD

View File

@@ -37,9 +37,7 @@ class SparsePageDMatrix : public DMatrix {
const MetaInfo& Info() const override;
float GetColDensity(size_t cidx) override;
bool SingleColBlock() const override;
/*! \brief always false for SparsePageDMatrix (presumably: column data may span several blocks — confirm against DMatrix) */
bool SingleColBlock() const override { return false; }
private:
BatchSet<SparsePage> GetRowBatches() override;

View File

@@ -61,7 +61,10 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
CHECK(p_fmat->SingleColBlock());
SparsePage const& batch = *(p_fmat->GetBatches<CSCPage>().begin());
if ( IsEmpty() ) { return; }
if (IsEmpty()) {
return;
}
dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id));
// The begin and end indices for the section of each column associated with
// this device

View File

@@ -77,6 +77,24 @@ class ColMaker: public TreeUpdater {
return "grow_colmaker";
}
/*!
 * \brief Populate `column_densities_` from `dmat` if it is still empty.
 *
 * Per column, the density is one minus the ratio of rows without an entry in
 * that column to the total row count. Once `column_densities_` is non-empty
 * this function does nothing, so the values from the first call are kept.
 * \param dmat matrix whose SortedCSCPage batches and MetaInfo are read.
 */
void LazyGetColumnDensity(DMatrix *dmat) {
  // Densities already available from an earlier call; nothing to do.
  if (!column_densities_.empty()) {
    return;
  }

  // Sum the size of every column over all sorted CSC batches.
  std::vector<size_t> entries_per_column(dmat->Info().num_col_);
  for (const auto &page : dmat->GetBatches<SortedCSCPage>()) {
    for (unsigned col = 0; col < page.Size(); ++col) {
      entries_per_column[col] += page[col].size();
    }
  }

  // Turn the per-column counts into densities relative to the row count.
  column_densities_.resize(entries_per_column.size());
  for (unsigned col = 0; col < column_densities_.size(); ++col) {
    size_t const n_missing = dmat->Info().num_row_ - entries_per_column[col];
    column_densities_[col] =
        1.0f - (static_cast<float>(n_missing)) / dmat->Info().num_row_;
  }
}
void Update(HostDeviceVector<GradientPair> *gpair,
DMatrix* dmat,
const std::vector<RegTree*> &trees) override {
@@ -84,6 +102,7 @@ class ColMaker: public TreeUpdater {
LOG(FATAL) << "Updater `grow_colmaker` or `exact` tree method doesn't "
"support distributed training.";
}
this->LazyGetColumnDensity(dmat);
// rescale learning rate according to size of trees
float lr = param_.learning_rate;
param_.learning_rate = lr / trees.size();
@@ -94,7 +113,7 @@ class ColMaker: public TreeUpdater {
param_,
colmaker_param_,
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
interaction_constraints_);
interaction_constraints_, column_densities_);
builder.Update(gpair->ConstHostVector(), dmat, tree);
}
param_.learning_rate = lr;
@@ -106,6 +125,7 @@ class ColMaker: public TreeUpdater {
ColMakerTrainParam colmaker_param_;
// SplitEvaluator that will be cloned for each Builder
std::unique_ptr<SplitEvaluator> spliteval_;
std::vector<float> column_densities_;
FeatureInteractionConstraintHost interaction_constraints_;
// data structure
@@ -139,11 +159,13 @@ class ColMaker: public TreeUpdater {
explicit Builder(const TrainParam& param,
const ColMakerTrainParam& colmaker_train_param,
std::unique_ptr<SplitEvaluator> spliteval,
FeatureInteractionConstraintHost _interaction_constraints)
FeatureInteractionConstraintHost _interaction_constraints,
const std::vector<float> &column_densities)
: param_(param), colmaker_train_param_{colmaker_train_param},
nthread_(omp_get_max_threads()),
spliteval_(std::move(spliteval)),
interaction_constraints_{std::move(_interaction_constraints)} {}
interaction_constraints_{std::move(_interaction_constraints)},
column_densities_(column_densities) {}
// update one tree, growing
virtual void Update(const std::vector<GradientPair>& gpair,
DMatrix* p_fmat,
@@ -433,22 +455,14 @@ class ColMaker: public TreeUpdater {
#endif // defined(_OPENMP)
{
std::vector<float> densities(num_features);
CHECK_EQ(feat_set.size(), num_features);
for (bst_omp_uint i = 0; i < num_features; ++i) {
bst_feature_t const fid = feat_set[i];
densities.at(i) = p_fmat->GetColDensity(fid);
}
#pragma omp parallel for schedule(dynamic, batch_size)
for (bst_omp_uint i = 0; i < num_features; ++i) {
bst_feature_t const fid = feat_set[i];
int32_t const tid = omp_get_thread_num();
auto c = batch[fid];
const bool ind = c.size() != 0 && c[0].fvalue == c[c.size() - 1].fvalue;
auto const density = densities[i];
if (colmaker_train_param_.NeedForwardSearch(
param_.default_direction, density, ind)) {
param_.default_direction, column_densities_[fid], ind)) {
this->EnumerateSplit(c.data(), c.data() + c.size(), +1,
fid, gpair, stemp_[tid]);
}
@@ -598,6 +612,7 @@ class ColMaker: public TreeUpdater {
std::unique_ptr<SplitEvaluator> spliteval_;
FeatureInteractionConstraintHost interaction_constraints_;
const std::vector<float> &column_densities_;
};
};
@@ -620,11 +635,12 @@ class DistColMaker : public ColMaker {
DMatrix* dmat,
const std::vector<RegTree*> &trees) override {
CHECK_EQ(trees.size(), 1U) << "DistColMaker: only support one tree at a time";
this->LazyGetColumnDensity(dmat);
Builder builder(
param_,
colmaker_param_,
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
interaction_constraints_);
interaction_constraints_, column_densities_);
// build the tree
builder.Update(gpair->ConstHostVector(), dmat, trees[0]);
//// prune the tree, note that pruner will sync the tree
@@ -637,12 +653,14 @@ class DistColMaker : public ColMaker {
class Builder : public ColMaker::Builder {
public:
explicit Builder(const TrainParam &param,
ColMakerTrainParam const& colmaker_train_param,
ColMakerTrainParam const &colmaker_train_param,
std::unique_ptr<SplitEvaluator> spliteval,
FeatureInteractionConstraintHost _interaction_constraints)
FeatureInteractionConstraintHost _interaction_constraints,
const std::vector<float> &column_densities)
: ColMaker::Builder(param, colmaker_train_param,
std::move(spliteval),
std::move(_interaction_constraints)) {}
std::move(_interaction_constraints),
column_densities) {}
inline void UpdatePosition(DMatrix* p_fmat, const RegTree &tree) {
const auto ndata = static_cast<bst_omp_uint>(p_fmat->Info().num_row_);
#pragma omp parallel for schedule(static)