Dmatrix refactor stage 1 (#3301)

* Use sparse page as singular CSR matrix representation

* Simplify dmatrix methods

* Reduce statefullness of batch iterators

* BREAKING CHANGE: Remove prob_buffer_row parameter. Users are instead recommended to sample their dataset as a preprocessing step before using XGBoost.
This commit is contained in:
Rory Mitchell
2018-06-07 10:25:58 +12:00
committed by GitHub
parent 286dccb8e8
commit a96039141a
47 changed files with 650 additions and 1036 deletions

View File

@@ -238,20 +238,20 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr,
API_BEGIN();
data::SimpleCSRSource& mat = *source;
mat.row_ptr_.reserve(nindptr);
mat.row_data_.reserve(nelem);
mat.row_ptr_.resize(1);
mat.row_ptr_[0] = 0;
mat.page_.offset.reserve(nindptr);
mat.page_.data.reserve(nelem);
mat.page_.offset.resize(1);
mat.page_.offset[0] = 0;
size_t num_column = 0;
for (size_t i = 1; i < nindptr; ++i) {
for (size_t j = indptr[i - 1]; j < indptr[i]; ++j) {
if (!common::CheckNAN(data[j])) {
// automatically skip nan.
mat.row_data_.emplace_back(RowBatch::Entry(indices[j], data[j]));
mat.page_.data.emplace_back(Entry(indices[j], data[j]));
num_column = std::max(num_column, static_cast<size_t>(indices[j] + 1));
}
}
mat.row_ptr_.push_back(mat.row_data_.size());
mat.page_.offset.push_back(mat.page_.data.size());
}
mat.info.num_col_ = num_column;
@@ -261,7 +261,7 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr,
mat.info.num_col_ = num_col;
}
mat.info.num_row_ = nindptr - 1;
mat.info.num_nonzero_ = mat.row_data_.size();
mat.info.num_nonzero_ = mat.page_.data.size();
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
API_END();
}
@@ -293,7 +293,7 @@ XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t* col_ptr,
// FIXME: User should be able to control number of threads
const int nthread = omp_get_max_threads();
data::SimpleCSRSource& mat = *source;
common::ParallelGroupBuilder<RowBatch::Entry> builder(&mat.row_ptr_, &mat.row_data_);
common::ParallelGroupBuilder<Entry> builder(&mat.page_.offset, &mat.page_.data);
builder.InitBudget(0, nthread);
size_t ncol = nindptr - 1; // NOLINT(*)
#pragma omp parallel for schedule(static)
@@ -312,12 +312,12 @@ XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t* col_ptr,
for (size_t j = col_ptr[i]; j < col_ptr[i+1]; ++j) {
if (!common::CheckNAN(data[j])) {
builder.Push(indices[j],
RowBatch::Entry(static_cast<bst_uint>(i), data[j]),
Entry(static_cast<bst_uint>(i), data[j]),
tid);
}
}
}
mat.info.num_row_ = mat.row_ptr_.size() - 1;
mat.info.num_row_ = mat.page_.offset.size() - 1;
if (num_row > 0) {
CHECK_LE(mat.info.num_row_, num_row);
mat.info.num_row_ = num_row;
@@ -351,7 +351,7 @@ XGB_DLL int XGDMatrixCreateFromMat(const bst_float* data,
API_BEGIN();
data::SimpleCSRSource& mat = *source;
mat.row_ptr_.resize(1+nrow);
mat.page_.offset.resize(1+nrow);
bool nan_missing = common::CheckNAN(missing);
mat.info.num_row_ = nrow;
mat.info.num_col_ = ncol;
@@ -371,9 +371,9 @@ XGB_DLL int XGDMatrixCreateFromMat(const bst_float* data,
}
}
}
mat.row_ptr_[i+1] = mat.row_ptr_[i] + nelem;
mat.page_.offset[i+1] = mat.page_.offset[i] + nelem;
}
mat.row_data_.resize(mat.row_data_.size() + mat.row_ptr_.back());
mat.page_.data.resize(mat.page_.data.size() + mat.page_.offset.back());
data = data0;
for (xgboost::bst_ulong i = 0; i < nrow; ++i, data += ncol) {
@@ -382,14 +382,14 @@ XGB_DLL int XGDMatrixCreateFromMat(const bst_float* data,
if (common::CheckNAN(data[j])) {
} else {
if (nan_missing || data[j] != missing) {
mat.row_data_[mat.row_ptr_[i] + matj] = RowBatch::Entry(j, data[j]);
mat.page_.data[mat.page_.offset[i] + matj] = Entry(j, data[j]);
++matj;
}
}
}
}
mat.info.num_nonzero_ = mat.row_data_.size();
mat.info.num_nonzero_ = mat.page_.data.size();
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
API_END();
}
@@ -443,7 +443,7 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
data::SimpleCSRSource& mat = *source;
mat.row_ptr_.resize(1+nrow);
mat.page_.offset.resize(1+nrow);
mat.info.num_row_ = nrow;
mat.info.num_col_ = ncol;
@@ -469,7 +469,7 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT
++nelem;
}
}
mat.row_ptr_[i+1] = nelem;
mat.page_.offset[i+1] = nelem;
}
}
// Inform about any NaNs and resize data matrix
@@ -478,8 +478,8 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT
}
// do cumulative sum (to avoid otherwise need to copy)
PrefixSum(&mat.row_ptr_[0], mat.row_ptr_.size());
mat.row_data_.resize(mat.row_data_.size() + mat.row_ptr_.back());
PrefixSum(&mat.page_.offset[0], mat.page_.offset.size());
mat.page_.data.resize(mat.page_.data.size() + mat.page_.offset.back());
// Fill data matrix (now that know size, no need for slow push_back())
#pragma omp parallel num_threads(nthread)
@@ -490,15 +490,15 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT
for (xgboost::bst_ulong j = 0; j < ncol; ++j) {
if (common::CheckNAN(data[ncol * i + j])) {
} else if (nan_missing || data[ncol * i + j] != missing) {
mat.row_data_[mat.row_ptr_[i] + matj] =
RowBatch::Entry(j, data[ncol * i + j]);
mat.page_.data[mat.page_.offset[i] + matj] =
Entry(j, data[ncol * i + j]);
++matj;
}
}
}
}
mat.info.num_nonzero_ = mat.row_data_.size();
mat.info.num_nonzero_ = mat.page_.data.size();
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
API_END();
}
@@ -521,18 +521,18 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
ret.info.num_row_ = len;
ret.info.num_col_ = src.info.num_col_;
dmlc::DataIter<RowBatch>* iter = &src;
auto iter = &src;
iter->BeforeFirst();
CHECK(iter->Next());
const RowBatch& batch = iter->Value();
const auto& batch = iter->Value();
for (xgboost::bst_ulong i = 0; i < len; ++i) {
const int ridx = idxset[i];
RowBatch::Inst inst = batch[ridx];
CHECK_LT(static_cast<xgboost::bst_ulong>(ridx), batch.size);
ret.row_data_.insert(ret.row_data_.end(), inst.data,
auto inst = batch[ridx];
CHECK_LT(static_cast<xgboost::bst_ulong>(ridx), batch.Size());
ret.page_.data.insert(ret.page_.data.end(), inst.data,
inst.data + inst.length);
ret.row_ptr_.push_back(ret.row_ptr_.back() + inst.length);
ret.page_.offset.push_back(ret.page_.offset.back() + inst.length);
ret.info.num_nonzero_ += inst.length;
if (src.info.labels_.size() != 0) {