Move SimpleDMatrix constructor to .cc file (#5188)
This commit is contained in:
parent
9049c7c653
commit
9559f81377
@ -75,5 +75,96 @@ BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Unconditionally returns true for SimpleDMatrix.
// NOTE(review): presumably signals that all columns are served as a single
// block -- confirm against the DMatrix::SingleColBlock contract.
bool SimpleDMatrix::SingleColBlock() const { return true; }
|
bool SimpleDMatrix::SingleColBlock() const { return true; }
|
||||||
|
|
||||||
|
template <typename AdapterT>
|
||||||
|
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||||
|
// Set number of threads but keep old value so we can reset it after
|
||||||
|
const int nthreadmax = omp_get_max_threads();
|
||||||
|
if (nthread <= 0) nthread = nthreadmax;
|
||||||
|
int nthread_original = omp_get_max_threads();
|
||||||
|
omp_set_num_threads(nthread);
|
||||||
|
|
||||||
|
source_.reset(new SimpleCSRSource());
|
||||||
|
SimpleCSRSource& mat = *reinterpret_cast<SimpleCSRSource*>(source_.get());
|
||||||
|
std::vector<uint64_t> qids;
|
||||||
|
uint64_t default_max = std::numeric_limits<uint64_t>::max();
|
||||||
|
uint64_t last_group_id = default_max;
|
||||||
|
bst_uint group_size = 0;
|
||||||
|
auto& offset_vec = mat.page_.offset.HostVector();
|
||||||
|
auto& data_vec = mat.page_.data.HostVector();
|
||||||
|
uint64_t inferred_num_columns = 0;
|
||||||
|
|
||||||
|
adapter->BeforeFirst();
|
||||||
|
// Iterate over batches of input data
|
||||||
|
while (adapter->Next()) {
|
||||||
|
auto& batch = adapter->Value();
|
||||||
|
auto batch_max_columns = mat.page_.Push(batch, missing, nthread);
|
||||||
|
inferred_num_columns = std::max(batch_max_columns, inferred_num_columns);
|
||||||
|
// Append meta information if available
|
||||||
|
if (batch.Labels() != nullptr) {
|
||||||
|
auto& labels = mat.info.labels_.HostVector();
|
||||||
|
labels.insert(labels.end(), batch.Labels(),
|
||||||
|
batch.Labels() + batch.Size());
|
||||||
|
}
|
||||||
|
if (batch.Weights() != nullptr) {
|
||||||
|
auto& weights = mat.info.weights_.HostVector();
|
||||||
|
weights.insert(weights.end(), batch.Weights(),
|
||||||
|
batch.Weights() + batch.Size());
|
||||||
|
}
|
||||||
|
if (batch.Qid() != nullptr) {
|
||||||
|
qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size());
|
||||||
|
// get group
|
||||||
|
for (size_t i = 0; i < batch.Size(); ++i) {
|
||||||
|
const uint64_t cur_group_id = batch.Qid()[i];
|
||||||
|
if (last_group_id == default_max || last_group_id != cur_group_id) {
|
||||||
|
mat.info.group_ptr_.push_back(group_size);
|
||||||
|
}
|
||||||
|
last_group_id = cur_group_id;
|
||||||
|
++group_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (last_group_id != default_max) {
|
||||||
|
if (group_size > mat.info.group_ptr_.back()) {
|
||||||
|
mat.info.group_ptr_.push_back(group_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deal with empty rows/columns if necessary
|
||||||
|
if (adapter->NumColumns() == kAdapterUnknownSize) {
|
||||||
|
mat.info.num_col_ = inferred_num_columns;
|
||||||
|
} else {
|
||||||
|
mat.info.num_col_ = adapter->NumColumns();
|
||||||
|
}
|
||||||
|
// Synchronise worker columns
|
||||||
|
rabit::Allreduce<rabit::op::Max>(&mat.info.num_col_, 1);
|
||||||
|
|
||||||
|
if (adapter->NumRows() == kAdapterUnknownSize) {
|
||||||
|
mat.info.num_row_ = offset_vec.size() - 1;
|
||||||
|
} else {
|
||||||
|
if (offset_vec.empty()) {
|
||||||
|
offset_vec.emplace_back(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
while (offset_vec.size() - 1 < adapter->NumRows()) {
|
||||||
|
offset_vec.emplace_back(offset_vec.back());
|
||||||
|
}
|
||||||
|
mat.info.num_row_ = adapter->NumRows();
|
||||||
|
}
|
||||||
|
mat.info.num_nonzero_ = data_vec.size();
|
||||||
|
omp_set_num_threads(nthread_original);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Explicit instantiations of the adapter constructor. Its template definition
// lives in this source file, so one instantiation is emitted here for each
// adapter type listed below.
template SimpleDMatrix::SimpleDMatrix(DenseAdapter*, float, int);
template SimpleDMatrix::SimpleDMatrix(CSRAdapter*, float, int);
template SimpleDMatrix::SimpleDMatrix(CSCAdapter*, float, int);
template SimpleDMatrix::SimpleDMatrix(DataTableAdapter*, float, int);
template SimpleDMatrix::SimpleDMatrix(FileAdapter*, float, int);
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -30,82 +30,7 @@ class SimpleDMatrix : public DMatrix {
|
|||||||
: source_(std::move(source)) {}
|
: source_(std::move(source)) {}
|
||||||
|
|
||||||
template <typename AdapterT>
|
template <typename AdapterT>
|
||||||
explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread);
|
||||||
// Set number of threads but keep old value so we can reset it after
|
|
||||||
const int nthreadmax = omp_get_max_threads();
|
|
||||||
if (nthread <= 0) nthread = nthreadmax;
|
|
||||||
int nthread_original = omp_get_max_threads();
|
|
||||||
omp_set_num_threads(nthread);
|
|
||||||
|
|
||||||
source_.reset(new SimpleCSRSource());
|
|
||||||
SimpleCSRSource& mat = *reinterpret_cast<SimpleCSRSource*>(source_.get());
|
|
||||||
std::vector<uint64_t> qids;
|
|
||||||
uint64_t default_max = std::numeric_limits<uint64_t>::max();
|
|
||||||
uint64_t last_group_id = default_max;
|
|
||||||
bst_uint group_size = 0;
|
|
||||||
auto& offset_vec = mat.page_.offset.HostVector();
|
|
||||||
auto& data_vec = mat.page_.data.HostVector();
|
|
||||||
uint64_t inferred_num_columns = 0;
|
|
||||||
|
|
||||||
adapter->BeforeFirst();
|
|
||||||
// Iterate over batches of input data
|
|
||||||
while (adapter->Next()) {
|
|
||||||
auto& batch = adapter->Value();
|
|
||||||
auto batch_max_columns = mat.page_.Push(batch, missing, nthread);
|
|
||||||
inferred_num_columns = std::max(batch_max_columns, inferred_num_columns);
|
|
||||||
// Append meta information if available
|
|
||||||
if (batch.Labels() != nullptr) {
|
|
||||||
auto& labels = mat.info.labels_.HostVector();
|
|
||||||
labels.insert(labels.end(), batch.Labels(), batch.Labels() + batch.Size());
|
|
||||||
}
|
|
||||||
if (batch.Weights() != nullptr) {
|
|
||||||
auto& weights = mat.info.weights_.HostVector();
|
|
||||||
weights.insert(weights.end(), batch.Weights(), batch.Weights() + batch.Size());
|
|
||||||
}
|
|
||||||
if (batch.Qid() != nullptr) {
|
|
||||||
qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size());
|
|
||||||
// get group
|
|
||||||
for (size_t i = 0; i < batch.Size(); ++i) {
|
|
||||||
const uint64_t cur_group_id = batch.Qid()[i];
|
|
||||||
if (last_group_id == default_max || last_group_id != cur_group_id) {
|
|
||||||
mat.info.group_ptr_.push_back(group_size);
|
|
||||||
}
|
|
||||||
last_group_id = cur_group_id;
|
|
||||||
++group_size;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (last_group_id != default_max) {
|
|
||||||
if (group_size > mat.info.group_ptr_.back()) {
|
|
||||||
mat.info.group_ptr_.push_back(group_size);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deal with empty rows/columns if necessary
|
|
||||||
if (adapter->NumColumns() == kAdapterUnknownSize) {
|
|
||||||
mat.info.num_col_ = inferred_num_columns;
|
|
||||||
} else {
|
|
||||||
mat.info.num_col_ = adapter->NumColumns();
|
|
||||||
}
|
|
||||||
// Synchronise worker columns
|
|
||||||
rabit::Allreduce<rabit::op::Max>(&mat.info.num_col_, 1);
|
|
||||||
|
|
||||||
if (adapter->NumRows() == kAdapterUnknownSize) {
|
|
||||||
mat.info.num_row_ = offset_vec.size() - 1;
|
|
||||||
} else {
|
|
||||||
if (offset_vec.empty()) {
|
|
||||||
offset_vec.emplace_back(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
while (offset_vec.size() - 1 < adapter->NumRows()) {
|
|
||||||
offset_vec.emplace_back(offset_vec.back());
|
|
||||||
}
|
|
||||||
mat.info.num_row_ = adapter->NumRows();
|
|
||||||
}
|
|
||||||
mat.info.num_nonzero_ = data_vec.size();
|
|
||||||
omp_set_num_threads(nthread_original);
|
|
||||||
}
|
|
||||||
|
|
||||||
MetaInfo& Info() override;
|
MetaInfo& Info() override;
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user