Remove SimpleCSRSource (#5315)

This commit is contained in:
Rory Mitchell
2020-02-18 16:49:17 +13:00
committed by GitHub
parent 9f77c18b0d
commit b2b2c4e231
18 changed files with 121 additions and 286 deletions

View File

@@ -8,6 +8,7 @@
#include <xgboost/data.h>
#include "../common/random.h"
#include "./simple_dmatrix.h"
#include "../common/math.h"
#include "device_adapter.cuh"
namespace xgboost {
@@ -112,38 +113,36 @@ void CopyDataRowMajor(AdapterT* adapter, common::Span<Entry> data,
// be supported in future. Does not currently support inferring row/column size
// NOTE(review): this span looks like a rendered diff whose +/- markers were
// stripped. Each changed statement appears twice: first in a form that goes
// through `mat` (a reference to the SimpleCSRSource that the commit title says
// is being removed), then in a form that uses `sparse_page_` / `info` directly.
// Presumably only one line of each pair exists in the real file -- confirm
// against the repository before editing the code itself.
//
// Constructs a SimpleDMatrix from a device adapter: computes the row offsets,
// copies the adapter's entries into the page data buffer, and records the
// number of non-zeros, rows and columns.
//
// adapter: source of the input batch; must report known row and column counts
//          and yield exactly one batch (both enforced with CHECK below).
// missing: forwarded to CountRowOffsets and the CopyData* helpers.
// nthread: not referenced anywhere in the visible body.
template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
// Pre-commit form: allocate a SimpleCSRSource and alias it as `mat`.
source_.reset(new SimpleCSRSource());
SimpleCSRSource& mat = *reinterpret_cast<SimpleCSRSource*>(source_.get());
// Inferring the matrix shape is not supported here, so both dimensions must
// already be known to the adapter.
CHECK(adapter->NumRows() != kAdapterUnknownSize);
CHECK(adapter->NumColumns() != kAdapterUnknownSize);
// Rewind the adapter, advance to its first batch and fetch it.
adapter->BeforeFirst();
adapter->Next();
auto& batch = adapter->Value();
// Bind the offset and data vectors to the adapter's device
// (pre-commit pair first, then post-commit pair).
mat.page_.offset.SetDevice(adapter->DeviceIdx());
mat.page_.data.SetDevice(adapter->DeviceIdx());
sparse_page_.offset.SetDevice(adapter->DeviceIdx());
sparse_page_.data.SetDevice(adapter->DeviceIdx());
// Enforce single batch
CHECK(!adapter->Next());
// One offset slot per row plus one extra; the last slot is read below as the
// total number of non-zero entries.
mat.page_.offset.Resize(adapter->NumRows() + 1);
auto s_offset = mat.page_.offset.DeviceSpan();
sparse_page_.offset.Resize(adapter->NumRows() + 1);
auto s_offset = sparse_page_.offset.DeviceSpan();
// Fill the offsets through the device span obtained above.
CountRowOffsets(batch, s_offset, adapter->DeviceIdx(), missing);
// The last offset gives num_nonzero_, which is then used to size the data
// buffer.
mat.info.num_nonzero_ = mat.page_.offset.HostVector().back();
mat.page_.data.Resize(mat.info.num_nonzero_);
info.num_nonzero_ = sparse_page_.offset.HostVector().back();
sparse_page_.data.Resize(info.num_nonzero_);
// Choose the copy routine that matches the adapter's reported layout.
if (adapter->IsRowMajor()) {
CopyDataRowMajor(adapter, mat.page_.data.DeviceSpan(),
CopyDataRowMajor(adapter, sparse_page_.data.DeviceSpan(),
adapter->DeviceIdx(), missing, s_offset);
} else {
CopyDataColumnMajor(adapter, mat.page_.data.DeviceSpan(),
CopyDataColumnMajor(adapter, sparse_page_.data.DeviceSpan(),
adapter->DeviceIdx(), missing, s_offset);
}
// Sync
// NOTE(review): HostVector() is called only for its side effect (the result
// is discarded); presumably it makes the copied data available on the host --
// confirm against HostDeviceVector.
mat.page_.data.HostVector();
sparse_page_.data.HostVector();
// Record the matrix shape as reported by the adapter.
mat.info.num_col_ = adapter->NumColumns();
mat.info.num_row_ = adapter->NumRows();
info.num_col_ = adapter->NumColumns();
info.num_row_ = adapter->NumRows();
// Synchronise worker columns
// Allreduce with the Max operator over num_col_.
rabit::Allreduce<rabit::op::Max>(&mat.info.num_col_, 1);
rabit::Allreduce<rabit::op::Max>(&info.num_col_, 1);
}
template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,