Automatically remove NaN from input data when it is sparse. (#2062)

* [DATALoad] Automatically remove NaN when loading from a sparse matrix

* add log
This commit is contained in:
Tianqi Chen
2017-02-25 08:59:17 -08:00
committed by GitHub
parent 5d093a7f4c
commit fd19b7a188
4 changed files with 41 additions and 20 deletions

View File

@@ -238,22 +238,31 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr,
API_BEGIN();
data::SimpleCSRSource& mat = *source;
mat.row_ptr_.resize(nindptr);
for (size_t i = 0; i < nindptr; ++i) {
mat.row_ptr_[i] = indptr[i];
}
mat.row_data_.resize(nelem);
for (size_t i = 0; i < nelem; ++i) {
mat.row_data_[i] = RowBatch::Entry(indices[i], data[i]);
mat.info.num_col = std::max(mat.info.num_col,
static_cast<uint64_t>(indices[i] + 1));
mat.row_ptr_.reserve(nindptr);
mat.row_data_.reserve(nelem);
mat.row_ptr_.resize(1);
mat.row_ptr_[0] = 0;
size_t num_column = 0;
for (size_t i = 1; i < nindptr; ++i) {
for (size_t j = indptr[i - 1]; j < indptr[i]; ++j) {
if (!common::CheckNAN(data[j])) {
// automatically skip nan.
mat.row_data_.emplace_back(RowBatch::Entry(indices[j], data[j]));
num_column = std::max(num_column, static_cast<size_t>(indices[j] + 1));
}
}
mat.row_ptr_.push_back(mat.row_data_.size());
}
mat.info.num_col = num_column;
if (num_col > 0) {
CHECK_LE(mat.info.num_col, num_col);
CHECK_LE(mat.info.num_col, num_col)
<< "num_col=" << num_col << " vs " << mat.info.num_col;
mat.info.num_col = num_col;
}
mat.info.num_row = nindptr - 1;
mat.info.num_nonzero = nelem;
LOG(INFO) << "num_row=" << mat.info.num_row;
mat.info.num_nonzero = mat.row_data_.size();
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
API_END();
}
@@ -291,7 +300,9 @@ XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t* col_ptr,
for (omp_ulong i = 0; i < static_cast<omp_ulong>(ncol); ++i) { // NOLINT(*)
int tid = omp_get_thread_num();
for (size_t j = col_ptr[i]; j < col_ptr[i+1]; ++j) {
builder.AddBudget(indices[j], tid);
if (!common::CheckNAN(data[j])) {
builder.AddBudget(indices[j], tid);
}
}
}
builder.InitStorage();
@@ -299,9 +310,11 @@ XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t* col_ptr,
for (omp_ulong i = 0; i < static_cast<omp_ulong>(ncol); ++i) { // NOLINT(*)
int tid = omp_get_thread_num();
for (size_t j = col_ptr[i]; j < col_ptr[i+1]; ++j) {
builder.Push(indices[j],
RowBatch::Entry(static_cast<bst_uint>(i), data[j]),
tid);
if (!common::CheckNAN(data[j])) {
builder.Push(indices[j],
RowBatch::Entry(static_cast<bst_uint>(i), data[j]),
tid);
}
}
}
mat.info.num_row = mat.row_ptr_.size() - 1;