Add PushCSC for SparsePage. (#4193)

* Add PushCSC for SparsePage.

* Move Push* definitions into cc file.
* Add std:: prefix to `size_t` to make clang++ happy.
* Address monitor count == 0.
This commit is contained in:
Jiaming Yuan
2019-03-02 01:58:08 +08:00
committed by GitHub
parent 74009afcac
commit 7ea5675679
8 changed files with 199 additions and 50 deletions

View File

@@ -317,6 +317,95 @@ data::SparsePageFormat::DecideFormat(const std::string& cache_prefix) {
}
}
// Append the entries and offsets of `batch` after the content already held
// by this page. Appended offsets are shifted so they continue from the
// current end of this page's data.
void SparsePage::Push(const SparsePage &batch) {
  auto& dst_entries = data.HostVector();
  auto& dst_offsets = offset.HostVector();
  const auto& src_offsets = batch.offset.HostVector();
  const auto& src_entries = batch.data.HostVector();

  // The last stored offset marks where the appended entries start.
  size_t base = dst_offsets.back();
  dst_entries.resize(base + batch.data.Size());
  std::memcpy(dmlc::BeginPtr(dst_entries) + base,
              dmlc::BeginPtr(src_entries),
              sizeof(Entry) * batch.data.Size());

  size_t first_new = offset.Size();
  dst_offsets.resize(first_new + batch.Size());
  // The batch's leading offset is skipped; every following offset is
  // rebased onto `base`.
  for (size_t k = 1; k <= batch.Size(); ++k) {
    dst_offsets[first_new + k - 1] = base + src_offsets[k];
  }
}
// Append the contents of a dmlc row block to this page.
// One offset is appended per row of the block, then one entry per
// (index, value) pair in the block's offset range. When the block carries no
// value array, every entry gets the value 1.0f.
void SparsePage::Push(const dmlc::RowBlock<uint32_t>& batch) {
  auto& entries = data.HostVector();
  auto& offsets = offset.HostVector();

  // Range of positions inside batch.index / batch.value covered by the block.
  size_t const src_begin = batch.offset[0];
  size_t const src_end = batch.offset[batch.size];

  entries.reserve(data.Size() + src_end - src_begin);
  offsets.reserve(offset.Size() + batch.size);
  CHECK(batch.index != nullptr);

  // Extend the running offsets by the length of each incoming row.
  for (size_t row = 0; row < batch.size; ++row) {
    offsets.push_back(offsets.back() + batch.offset[row + 1] - batch.offset[row]);
  }

  bool const has_values = batch.value != nullptr;
  for (size_t pos = src_begin; pos < src_end; ++pos) {
    uint32_t const feature = batch.index[pos];
    bst_float const fvalue = has_values ? batch.value[pos] : 1.0f;
    entries.emplace_back(feature, fvalue);
  }
  CHECK_EQ(offsets.back(), data.Size());
}
void SparsePage::PushCSC(const SparsePage &batch) {
std::vector<xgboost::Entry>& self_data = data.HostVector();
std::vector<size_t>& self_offset = offset.HostVector();
auto const& other_data = batch.data.ConstHostVector();
auto const& other_offset = batch.offset.ConstHostVector();
if (other_data.empty()) {
return;
}
if (!self_data.empty()) {
CHECK_EQ(self_offset.size(), other_offset.size())
<< "self_data.size(): " << this->data.Size() << ", "
<< "other_data.size(): " << other_data.size() << std::flush;
} else {
self_data = other_data;
self_offset = other_offset;
return;
}
std::vector<size_t> offset(other_offset.size());
offset[0] = 0;
std::vector<xgboost::Entry> data(self_data.size() + batch.data.Size());
// n_cols in original csr data matrix, here in csc is n_rows
size_t const n_features = other_offset.size() - 1;
size_t beg = 0;
size_t ptr = 1;
for (size_t i = 0; i < n_features; ++i) {
size_t const self_beg = self_offset.at(i);
size_t const self_length = self_offset.at(i+1) - self_beg;
CHECK_LT(beg, data.size());
std::memcpy(dmlc::BeginPtr(data)+beg,
dmlc::BeginPtr(self_data) + self_beg,
sizeof(Entry) * self_length);
beg += self_length;
size_t const other_beg = other_offset.at(i);
size_t const other_length = other_offset.at(i+1) - other_beg;
CHECK_LT(beg, data.size());
std::memcpy(dmlc::BeginPtr(data)+beg,
dmlc::BeginPtr(other_data) + other_beg,
sizeof(Entry) * other_length);
beg += other_length;
CHECK_LT(ptr, offset.size());
offset.at(ptr) = beg;
ptr++;
}
self_data = std::move(data);
self_offset = std::move(offset);
}
namespace data {
// List of files that will be force linked in static links.
DMLC_REGISTRY_LINK_TAG(sparse_page_raw_format);

View File

@@ -126,7 +126,7 @@ bool SparsePageSource::CacheExist(const std::string& cache_info,
}
void SparsePageSource::CreateRowPage(dmlc::Parser<uint32_t>* src,
const std::string& cache_info) {
const std::string& cache_info) {
const std::string page_type = ".row.page";
std::vector<std::string> cache_shards = GetCacheShards(cache_info);
CHECK_NE(cache_shards.size(), 0U);
@@ -216,7 +216,8 @@ void SparsePageSource::CreateRowPage(dmlc::Parser<uint32_t>* src,
CHECK(info.qids_.empty() || info.qids_.size() == info.num_row_);
info.SaveBinary(fo.get());
}
LOG(CONSOLE) << "SparsePageSource: Finished writing to " << name_info;
LOG(CONSOLE) << "SparsePageSource::CreateRowPage Finished writing to "
<< name_info;
}
void SparsePageSource::CreatePageFromDMatrix(DMatrix* src,
@@ -246,9 +247,9 @@ void SparsePageSource::CreatePageFromDMatrix(DMatrix* src,
} else if (page_type == ".col.page") {
page->Push(batch.GetTranspose(src->Info().num_col_));
} else if (page_type == ".sorted.col.page") {
auto tmp = batch.GetTranspose(src->Info().num_col_);
tmp.SortRows();
page->Push(tmp);
SparsePage tmp = batch.GetTranspose(src->Info().num_col_);
page->PushCSC(tmp);
page->SortRows();
} else {
LOG(FATAL) << "Unknown page type: " << page_type;
}